Update modeling_progen.py
Browse files
- modeling_progen.py +0 -4
modeling_progen.py
CHANGED
|
@@ -613,11 +613,7 @@ class ProGenForCausalLM(ProGenPreTrainedModel):
|
|
| 613 |
# compute loss in fp32 to match with mesh-tf version
|
| 614 |
# https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179
|
| 615 |
lm_logits = self.lm_head(hidden_states).to(torch.float32)
|
| 616 |
-
print(f"lm_logits = {lm_logits.shape}")
|
| 617 |
-
print(f"logits_to_keep = {logits_to_keep}")
|
| 618 |
|
| 619 |
-
# Debug shape
|
| 620 |
-
print(f"Final logits shape: {lm_logits.shape}") # Should be [batch, seq, vocab]
|
| 621 |
loss = None
|
| 622 |
if labels is not None:
|
| 623 |
# Shift so that tokens < n predict n
|
|
|
|
| 613 |
# compute loss in fp32 to match with mesh-tf version
|
| 614 |
# https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179
|
| 615 |
lm_logits = self.lm_head(hidden_states).to(torch.float32)
|
|
|
|
|
|
|
| 616 |
|
|
|
|
|
|
|
| 617 |
loss = None
|
| 618 |
if labels is not None:
|
| 619 |
# Shift so that tokens < n predict n
|