import json

import torch

from model import GPTModel, ScratchTokenizer

# Run on the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the tokenizer from the vocabulary saved during training.
tokenizer = ScratchTokenizer()
with open("vocab.json", "r") as f:
    vocab = json.load(f)
tokenizer.word2idx = vocab["word2idx"]
# JSON keys are always strings, so convert the ids back to integers.
tokenizer.idx2word = {int(k): v for k, v in vocab["idx2word"].items()}
tokenizer.vocab_size = vocab["vocab_size"]
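
# For reference, vocab.json is assumed to have been written by the training
# script with roughly this shape (the entries shown are illustrative, not
# taken from the original project):
# {
#     "word2idx":   {"hello": 3, "world": 4, ...},
#     "idx2word":   {"3": "hello", "4": "world", ...},
#     "vocab_size": 5
# }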

# Restore the trained weights and switch to inference mode.
model = GPTModel(vocab_size=tokenizer.vocab_size)
model.load_state_dict(torch.load("gpt_model.pth", map_location=device))
model.to(device)
model.eval()


def generate_response(query, max_length=200):
    # Encode the query as the source sequence and seed the target with
    # the start-of-sequence token (id 1 in this vocabulary).
    src = torch.tensor(tokenizer.encode(query)).unsqueeze(0).to(device)
    tgt = torch.tensor([[1]]).to(device)

    # Greedy decoding: append the most likely next token until the
    # end-of-sequence token (id 2) is generated or max_length is reached.
    with torch.no_grad():
        for _ in range(max_length):
            output = model(src, tgt)
            next_word = output.argmax(-1)[:, -1].unsqueeze(1)
            tgt = torch.cat([tgt, next_word], dim=1)
            if next_word.item() == 2:
                break

    return tokenizer.decode(tgt.squeeze(0).tolist())
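

# A minimal usage sketch, not part of the original script: drive the model
# interactively from the command line. The prompt strings and the "quit"
# command are illustrative assumptions.
if __name__ == "__main__":
    while True:
        query = input("You: ")
        if query.strip().lower() in {"quit", "exit"}:
            break
        print("Bot:", generate_response(query))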