klemenk committed on
Commit
ffc5144
·
verified ·
1 Parent(s): 248af8b

Update modeling_auristream.py

Browse files
Files changed (1) hide show
  1. modeling_auristream.py +4 -0
modeling_auristream.py CHANGED
@@ -369,6 +369,9 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
369
  all_logits = [logits] if output_logits else None
370
 
371
  # Compute future head logits
 
 
 
372
  if self.future_heads is not None:
373
  for i, head in enumerate(self.future_heads):
374
  future_logits = head(x[:, :-(i + 1)])
@@ -378,6 +381,7 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
378
  # Compute loss if labels provided
379
  loss = None
380
  if labels is not None:
 
381
  loss = F.cross_entropy(
382
  logits.reshape(-1, self.config.vocab_size),
383
  labels.reshape(-1),
 
369
  all_logits = [logits] if output_logits else None
370
 
371
  # Compute future head logits
372
+ # lm_head is the first "standard" lm head which predicts token i+1 (as all GPT models have)
373
+ # self.future_heads holds all the other "MTP" future prediction heads, so self.future_heads[0]
374
+ # corresponds to the head that predicts token i+2 - aka the "second head"
375
  if self.future_heads is not None:
376
  for i, head in enumerate(self.future_heads):
377
  future_logits = head(x[:, :-(i + 1)])
 
381
  # Compute loss if labels provided
382
  loss = None
383
  if labels is not None:
384
+ # compute loss from the first "standard" lm head
385
  loss = F.cross_entropy(
386
  logits.reshape(-1, self.config.vocab_size),
387
  labels.reshape(-1),