Spaces:
Sleeping
Sleeping
Commit ·
4bf12f4
1
Parent(s): 6feea4b
Update app.py
Browse files
app.py
CHANGED
|
@@ -225,7 +225,11 @@ def beam_generate_v2(model,src_tensor, beam=5, max_len=50, alpha=0.7):
|
|
| 225 |
|
| 226 |
# decoder step: input is last token
|
| 227 |
dec_in = torch.tensor([[seq[-1]]], device=DEVICE)
|
| 228 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
probs = F.log_softmax(out, dim=-1).squeeze(0) # [vocab]
|
| 230 |
|
| 231 |
# penalty for repetition
|
|
|
|
| 225 |
|
| 226 |
# decoder step: input is last token
|
| 227 |
dec_in = torch.tensor([[seq[-1]]], device=DEVICE)
|
| 228 |
+
# Call decoder with correct arguments depending on model type
|
| 229 |
+
if isinstance(model.decoder, Decoder_with_attn):
|
| 230 |
+
out, new_h = model.decoder(dec_in, hid, enc_out)
|
| 231 |
+
else:
|
| 232 |
+
out, new_h = model.decoder(dec_in, hid)
|
| 233 |
probs = F.log_softmax(out, dim=-1).squeeze(0) # [vocab]
|
| 234 |
|
| 235 |
# penalty for repetition
|