dev-das commited on
Commit
46a0c5f
·
verified ·
1 Parent(s): 65dcf66

Update modeling_my_gpt.py

Browse files
Files changed (1) hide show
  1. modeling_my_gpt.py +9 -6
modeling_my_gpt.py CHANGED
@@ -1,6 +1,8 @@
1
  import torch
2
  import torch.nn as nn
3
  from transformers import PreTrainedModel
 
 
4
  from .configuration_my_gpt import MyGPTConfig
5
  from .untrained_model import GPTModel
6
 
@@ -32,9 +34,9 @@ class MyGPTForCausalLM(PreTrainedModel):
32
 
33
  self.post_init()
34
 
35
- def forward(self, input_ids, labels=None):
36
  logits = self.model(input_ids)
37
-
38
  loss = None
39
  if labels is not None:
40
  shift_logits = logits[..., :-1, :].contiguous()
@@ -44,8 +46,9 @@ class MyGPTForCausalLM(PreTrainedModel):
44
  shift_logits.view(-1, shift_logits.size(-1)),
45
  shift_labels.view(-1)
46
  )
 
 
 
 
 
47
 
48
- return {
49
- "loss": loss,
50
- "logits": logits,
51
- }
 
1
  import torch
2
  import torch.nn as nn
3
  from transformers import PreTrainedModel
4
+ from transformers.modeling_outputs import CausalLMOutput
5
+
6
  from .configuration_my_gpt import MyGPTConfig
7
  from .untrained_model import GPTModel
8
 
 
34
 
35
  self.post_init()
36
 
37
+ def forward(self, input_ids, labels=None, **kwargs):
38
  logits = self.model(input_ids)
39
+
40
  loss = None
41
  if labels is not None:
42
  shift_logits = logits[..., :-1, :].contiguous()
 
46
  shift_logits.view(-1, shift_logits.size(-1)),
47
  shift_labels.view(-1)
48
  )
49
+
50
+ return CausalLMOutput(
51
+ loss=loss,
52
+ logits=logits,
53
+ )
54