Update modeling_my_gpt.py
Browse files — modeling_my_gpt.py (+9, -6)
modeling_my_gpt.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
from transformers import PreTrainedModel
|
|
|
|
|
|
|
| 4 |
from .configuration_my_gpt import MyGPTConfig
|
| 5 |
from .untrained_model import GPTModel
|
| 6 |
|
|
@@ -32,9 +34,9 @@ class MyGPTForCausalLM(PreTrainedModel):
|
|
| 32 |
|
| 33 |
self.post_init()
|
| 34 |
|
| 35 |
-
def forward(self, input_ids, labels=None):
|
| 36 |
logits = self.model(input_ids)
|
| 37 |
-
|
| 38 |
loss = None
|
| 39 |
if labels is not None:
|
| 40 |
shift_logits = logits[..., :-1, :].contiguous()
|
|
@@ -44,8 +46,9 @@ class MyGPTForCausalLM(PreTrainedModel):
|
|
| 44 |
shift_logits.view(-1, shift_logits.size(-1)),
|
| 45 |
shift_labels.view(-1)
|
| 46 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
-
return {
|
| 49 |
-
"loss": loss,
|
| 50 |
-
"logits": logits,
|
| 51 |
-
}
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
from transformers import PreTrainedModel
|
| 4 |
+
from transformers.modeling_outputs import CausalLMOutput
|
| 5 |
+
|
| 6 |
from .configuration_my_gpt import MyGPTConfig
|
| 7 |
from .untrained_model import GPTModel
|
| 8 |
|
|
|
|
| 34 |
|
| 35 |
self.post_init()
|
| 36 |
|
| 37 |
+
def forward(self, input_ids, labels=None, **kwargs):
|
| 38 |
logits = self.model(input_ids)
|
| 39 |
+
|
| 40 |
loss = None
|
| 41 |
if labels is not None:
|
| 42 |
shift_logits = logits[..., :-1, :].contiguous()
|
|
|
|
| 46 |
shift_logits.view(-1, shift_logits.size(-1)),
|
| 47 |
shift_labels.view(-1)
|
| 48 |
)
|
| 49 |
+
|
| 50 |
+
return CausalLMOutput(
|
| 51 |
+
loss=loss,
|
| 52 |
+
logits=logits,
|
| 53 |
+
)
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|