hrezaei committed on
Commit
91aef84
·
verified ·
1 Parent(s): c586e86

Fix the loss computation

Browse files
Files changed (1) hide show
  1. t5la_modeling.py +2 -2
t5la_modeling.py CHANGED
@@ -327,14 +327,14 @@ class T5LaForConditionalGeneration(T5ForConditionalGeneration):
327
  lookahead_targets.view(-1),
328
  # vocab_size=self.config.vocab_size,
329
  )
330
- """if self.config.lookahead_type == "la":
331
  # If we simply add, the loss will be larger than a non-LA T5 model because
332
  # in a normal T5, the number of tokens are much lower:
333
  loss = (loss + lookahead_loss) / (1 + self.config.lookahead_size)
334
  else:
335
  loss = (loss * lm_logits.shape[1] + lookahead_loss * self.config.lookahead_size) / (
336
  lm_logits.shape[1] + self.config.lookahead_size
337
- )"""
338
 
339
  if not return_dict:
340
  output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
 
327
  lookahead_targets.view(-1),
328
  # vocab_size=self.config.vocab_size,
329
  )
330
+ if self.config.lookahead_type == "la":
331
  # If we simply add, the loss will be larger than a non-LA T5 model because
332
  # in a normal T5, the number of tokens are much lower:
333
  loss = (loss + lookahead_loss) / (1 + self.config.lookahead_size)
334
  else:
335
  loss = (loss * lm_logits.shape[1] + lookahead_loss * self.config.lookahead_size) / (
336
  lm_logits.shape[1] + self.config.lookahead_size
337
+ )
338
 
339
  if not return_dict:
340
  output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs