amaresh8053 commited on
Commit
4bf12f4
·
1 Parent(s): 6feea4b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -225,7 +225,11 @@ def beam_generate_v2(model,src_tensor, beam=5, max_len=50, alpha=0.7):
225
 
226
  # decoder step: input is last token
227
  dec_in = torch.tensor([[seq[-1]]], device=DEVICE)
228
- out, new_h = model.decoder(dec_in, hid, enc_out)
 
 
 
 
229
  probs = F.log_softmax(out, dim=-1).squeeze(0) # [vocab]
230
 
231
  # penalty for repetition
 
225
 
226
  # decoder step: input is last token
227
  dec_in = torch.tensor([[seq[-1]]], device=DEVICE)
228
+ # Call decoder with correct arguments depending on model type
229
+ if isinstance(model.decoder, Decoder_with_attn):
230
+ out, new_h = model.decoder(dec_in, hid, enc_out)
231
+ else:
232
+ out, new_h = model.decoder(dec_in, hid)
233
  probs = F.log_softmax(out, dim=-1).squeeze(0) # [vocab]
234
 
235
  # penalty for repetition