Commit: Update modeling_auristream.py

Changed file: modeling_auristream.py (+4 lines, -0 lines)
@@ -369,6 +369,9 @@ class AuriStreamModel(AuriStreamPreTrainedModel):

Before (lines 369-374):

    369        all_logits = [logits] if output_logits else None
    370
    371        # Compute future head logits
    372        if self.future_heads is not None:
    373            for i, head in enumerate(self.future_heads):
    374                future_logits = head(x[:, :-(i + 1)])
@@ -378,6 +381,7 @@ class AuriStreamModel(AuriStreamPreTrainedModel):

Before (lines 378-383):

    378        # Compute loss if labels provided
    379        loss = None
    380        if labels is not None:
    381            loss = F.cross_entropy(
    382                logits.reshape(-1, self.config.vocab_size),
    383                labels.reshape(-1),
After (lines 369-377; three comment lines added):

    369        all_logits = [logits] if output_logits else None
    370
    371        # Compute future head logits
    372  +     # lm_head is the first "standard" lm head which predicts token i+1 (as all GPT models have)
    373  +     # self.future_heads holds all the other "MTP" future prediction heads, so self.future_heads[0]
    374  +     # corresponds to the head that predicts token i+2 - aka the "second head"
    375        if self.future_heads is not None:
    376            for i, head in enumerate(self.future_heads):
    377                future_logits = head(x[:, :-(i + 1)])
After (lines 381-387; one comment line added):

    381        # Compute loss if labels provided
    382        loss = None
    383        if labels is not None:
    384  +         # compute loss from the first "standard" lm head
    385            loss = F.cross_entropy(
    386                logits.reshape(-1, self.config.vocab_size),
    387                labels.reshape(-1),