Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -95,14 +95,17 @@ load_model(model)
|
|
| 95 |
# Generate Response
|
| 96 |
def generate_response(model, query, max_length=200):
|
| 97 |
model.eval()
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
| 106 |
return tokenizer.decode(tgt.squeeze(0).tolist())
|
| 107 |
|
| 108 |
# FastAPI app
|
|
|
|
| 95 |
# Generate Response
|
| 96 |
def generate_response(model, query, max_length=200, *, sos_token_id=1, eos_token_id=2):
    """Greedily decode a text response from *model* for the given *query*.

    Encodes *query* with the module-level ``tokenizer``, then autoregressively
    feeds the growing target sequence back into ``model`` until the end token
    is produced or *max_length* steps have elapsed.

    Args:
        model: seq2seq model called as ``model(src, tgt)``; assumed to return
            logits of shape (batch, tgt_len, vocab_size) — confirm against the
            model definition elsewhere in this file.
        query: raw input text to encode.
        max_length: maximum number of decoding steps (default 200).
        sos_token_id: start-of-sequence token id (keyword-only; default 1,
            matching the value previously hard-coded in the body).
        eos_token_id: end-of-sequence token id (keyword-only; default 2,
            matching the value previously hard-coded in the body).

    Returns:
        The decoded string, exactly as ``tokenizer.decode`` renders the full
        target sequence (including the leading <SOS> and, if produced, the
        trailing <EOS> tokens).
    """
    model.eval()  # inference mode: disables dropout / batch-norm updates
    with torch.no_grad():  # no gradients needed during generation
        src = torch.tensor(tokenizer.encode(query)).unsqueeze(0).to(device)
        tgt = torch.tensor([[sos_token_id]]).to(device)  # seed with <SOS>

        for _ in range(max_length):
            output = model(src, tgt)
            # Greedy decoding: take the most probable token at the last position.
            next_token = output[:, -1, :].argmax(dim=-1, keepdim=True)
            tgt = torch.cat([tgt, next_token], dim=1)
            if next_token.item() == eos_token_id:  # stop once <EOS> is emitted
                break

    return tokenizer.decode(tgt.squeeze(0).tolist())
|
| 110 |
|
| 111 |
# FastAPI app
|