Akshitha1 committed on
Commit
b14f70c
·
verified ·
1 Parent(s): 9736d99

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -8
app.py CHANGED
@@ -95,14 +95,17 @@ load_model(model)
95
  # Generate Response
96
  def generate_response(model, query, max_length=200):
97
  model.eval()
98
- src = torch.tensor(tokenizer.encode(query)).unsqueeze(0).to(device)
99
- tgt = torch.tensor([[1]]).to(device) # <SOS>
100
- for _ in range(max_length):
101
- output = model(src, tgt)
102
- next_word = output.argmax(-1)[:, -1].unsqueeze(1)
103
- tgt = torch.cat([tgt, next_word], dim=1)
104
- if next_word.item() == 2: # <EOS>
105
- break
 
 
 
106
  return tokenizer.decode(tgt.squeeze(0).tolist())
107
 
108
  # FastAPI app
 
95
  # Generate Response
96
  def generate_response(model, query, max_length=200):
97
  model.eval()
98
+ with torch.no_grad(): # Disable gradient tracking
99
+ src = torch.tensor(tokenizer.encode(query)).unsqueeze(0).to(device)
100
+ tgt = torch.tensor([[1]]).to(device) # <SOS>
101
+
102
+ for _ in range(max_length):
103
+ output = model(src, tgt)
104
+ next_token = output[:, -1, :].argmax(dim=-1, keepdim=True)
105
+ tgt = torch.cat([tgt, next_token], dim=1)
106
+ if next_token.item() == 2: # <EOS>
107
+ break
108
+
109
  return tokenizer.decode(tgt.squeeze(0).tolist())
110
 
111
  # FastAPI app