Fix the loss computation
Browse files
- t5la_modeling.py +2 -2
t5la_modeling.py
CHANGED
|
@@ -327,14 +327,14 @@ class T5LaForConditionalGeneration(T5ForConditionalGeneration):
|
|
| 327 |
lookahead_targets.view(-1),
|
| 328 |
# vocab_size=self.config.vocab_size,
|
| 329 |
)
|
| 330 |
-
|
| 331 |
# If we simply add, the loss will be larger than a non-LA T5 model because
|
| 332 |
# in a normal T5, the number of tokens are much lower:
|
| 333 |
loss = (loss + lookahead_loss) / (1 + self.config.lookahead_size)
|
| 334 |
else:
|
| 335 |
loss = (loss * lm_logits.shape[1] + lookahead_loss * self.config.lookahead_size) / (
|
| 336 |
lm_logits.shape[1] + self.config.lookahead_size
|
| 337 |
-
)
|
| 338 |
|
| 339 |
if not return_dict:
|
| 340 |
output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
|
|
|
|
| 327 |
lookahead_targets.view(-1),
|
| 328 |
# vocab_size=self.config.vocab_size,
|
| 329 |
)
|
| 330 |
+
if self.config.lookahead_type == "la":
|
| 331 |
# If we simply add, the loss will be larger than a non-LA T5 model because
|
| 332 |
# in a normal T5, the number of tokens are much lower:
|
| 333 |
loss = (loss + lookahead_loss) / (1 + self.config.lookahead_size)
|
| 334 |
else:
|
| 335 |
loss = (loss * lm_logits.shape[1] + lookahead_loss * self.config.lookahead_size) / (
|
| 336 |
lm_logits.shape[1] + self.config.lookahead_size
|
| 337 |
+
)
|
| 338 |
|
| 339 |
if not return_dict:
|
| 340 |
output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
|