rishiiitha committed on
Commit
3c819a7
·
verified ·
1 Parent(s): b1c3deb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -9
app.py CHANGED
@@ -91,17 +91,32 @@ def load_model(model, path="gpt_model.pth"):
91
 
92
  load_model(model)
93
 
94
- # Generate Response
 
 
 
 
 
 
 
 
 
 
 
 
95
  def generate_response(model, query, max_length=200):
96
  model.eval()
97
- src = torch.tensor(tokenizer.encode(query)).unsqueeze(0).to(device)
98
- tgt = torch.tensor([[1]]).to(device) # <SOS>
99
- for _ in range(max_length):
100
- output = model(src, tgt)
101
- next_word = output.argmax(-1)[:, -1].unsqueeze(1)
102
- tgt = torch.cat([tgt, next_word], dim=1)
103
- if next_word.item() == 2: # <EOS>
104
- break
 
 
 
105
  return tokenizer.decode(tgt.squeeze(0).tolist())
106
 
107
  # Flask App
 
91
 
92
  load_model(model)
93
 
94
+ # # Generate Response
95
+ # def generate_response(model, query, max_length=200):
96
+ # model.eval()
97
+ # src = torch.tensor(tokenizer.encode(query)).unsqueeze(0).to(device)
98
+ # tgt = torch.tensor([[1]]).to(device) # <SOS>
99
+ # for _ in range(max_length):
100
+ # output = model(src, tgt)
101
+ # next_word = output.argmax(-1)[:, -1].unsqueeze(1)
102
+ # tgt = torch.cat([tgt, next_word], dim=1)
103
+ # if next_word.item() == 2: # <EOS>
104
+ # break
105
+ # return tokenizer.decode(tgt.squeeze(0).tolist())
106
+
107
def generate_response(model, query, max_length=200):
    """Greedily decode a text reply to *query* from the seq2seq model.

    Args:
        model: encoder-decoder model; ``model(src, tgt)`` is expected to
            return logits shaped (batch, tgt_len, vocab) — TODO confirm
            against the model definition earlier in this file.
        query: raw input text; encoded via the module-level ``tokenizer``.
        max_length: cap on the number of generated tokens (default 200).

    Returns:
        The decoded response string, with the <SOS>/<EOS> marker ids
        stripped so they never leak into the user-visible reply.
    """
    model.eval()
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        src = torch.tensor(tokenizer.encode(query)).unsqueeze(0).to(device)
        tgt = torch.tensor([[1]], device=device)  # seed decoder with <SOS> (id 1)

        for _ in range(max_length):
            output = model(src, tgt)
            # Greedy decoding: take the highest-logit token at the last
            # decoder position only.
            next_token = output[:, -1, :].argmax(dim=-1, keepdim=True)
            tgt = torch.cat([tgt, next_token], dim=1)
            if next_token.item() == 2:  # <EOS> (id 2) terminates generation
                break

    # Fix: the original decoded tgt verbatim, so the <SOS> seed (and the
    # <EOS> stop token) appeared in the returned text. Drop both special
    # ids before decoding; harmless if the tokenizer already skips them.
    token_ids = [t for t in tgt.squeeze(0).tolist() if t not in (1, 2)]
    return tokenizer.decode(token_ids)
121
 
122
  # Flask App