Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -91,17 +91,32 @@ def load_model(model, path="gpt_model.pth"):
|
|
| 91 |
|
| 92 |
load_model(model)
|
| 93 |
|
| 94 |
-
# Generate Response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
def generate_response(model, query, max_length=200):
|
| 96 |
model.eval()
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
| 105 |
return tokenizer.decode(tgt.squeeze(0).tolist())
|
| 106 |
|
| 107 |
# Flask App
|
|
|
|
| 91 |
|
| 92 |
load_model(model)
|
| 93 |
|
| 94 |
+
# # Generate Response
|
| 95 |
+
# def generate_response(model, query, max_length=200):
|
| 96 |
+
# model.eval()
|
| 97 |
+
# src = torch.tensor(tokenizer.encode(query)).unsqueeze(0).to(device)
|
| 98 |
+
# tgt = torch.tensor([[1]]).to(device) # <SOS>
|
| 99 |
+
# for _ in range(max_length):
|
| 100 |
+
# output = model(src, tgt)
|
| 101 |
+
# next_word = output.argmax(-1)[:, -1].unsqueeze(1)
|
| 102 |
+
# tgt = torch.cat([tgt, next_word], dim=1)
|
| 103 |
+
# if next_word.item() == 2: # <EOS>
|
| 104 |
+
# break
|
| 105 |
+
# return tokenizer.decode(tgt.squeeze(0).tolist())
|
| 106 |
+
|
| 107 |
def generate_response(model, query, max_length=200):
    """Greedily decode a reply to *query* with the seq2seq model.

    Token id 1 is used as <SOS> and token id 2 as <EOS>; decoding stops
    once <EOS> is produced or after ``max_length`` steps.  Relies on the
    module-level ``tokenizer`` and ``device``.

    Args:
        model: the trained encoder-decoder model (callable as model(src, tgt)).
        query: raw user text to encode as the source sequence.
        max_length: hard cap on the number of generated tokens.

    Returns:
        The decoded token sequence as a string (includes <SOS>/<EOS> ids
        as produced by ``tokenizer.decode``).
    """
    model.eval()
    with torch.no_grad():  # inference only — no gradient tracking
        source_ids = torch.tensor(tokenizer.encode(query)).unsqueeze(0).to(device)
        generated = torch.tensor([[1]]).to(device)  # seed with <SOS>

        for _ in range(max_length):
            logits = model(source_ids, generated)
            # Greedy choice: highest-scoring token at the final position.
            chosen = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, chosen], dim=1)
            if chosen.item() == 2:  # <EOS> reached
                break

    return tokenizer.decode(generated.squeeze(0).tolist())
|
| 121 |
|
| 122 |
# Flask App
|