Xuezha committed
Commit b524ec2 · verified · 1 parent: 148f88a

Update modeling.py

Files changed (1): modeling.py (+4, -1)
modeling.py CHANGED
@@ -150,7 +150,7 @@ class RecombinationTransformerForCausalLM(PreTrainedModel):
         self.final_rms_norm = RMSNorm(config.embed_dim)
         self.lm_head = nn.Linear(config.embed_dim, config.vocab_size, bias=False)
 
-    def forward(self, input_ids, attention_mask=None, past_key_values=None):
+    def forward(self, input_ids, attention_mask=None, past_key_values=None, return_dict=None):
         if attention_mask is None:
             attention_mask = torch.ones(input_ids.shape, device=input_ids.device)
 
@@ -171,6 +171,9 @@ class RecombinationTransformerForCausalLM(PreTrainedModel):
         # LM head
         logits = self.lm_head(x)
 
+        if not return_dict:
+            return (logits,)
+
         return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
 
     def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **kwargs):
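For reference, a minimal sketch of what this change does for callers (the `model`/`config` setup lines are illustrative assumptions, not part of the commit). Note that with the default `return_dict=None`, the check `not return_dict` is truthy, so the model now returns a plain tuple unless `return_dict=True` is passed explicitly:

# Minimal sketch, assuming a `config` exposing `vocab_size` as used above.
import torch

model = RecombinationTransformerForCausalLM(config)      # config assumed in scope
input_ids = torch.randint(0, config.vocab_size, (1, 8))  # dummy batch

out = model(input_ids, return_dict=True)   # CausalLMOutputWithPast, as before
logits = out.logits

out = model(input_ids)                     # default return_dict=None
logits = out[0]                            # plain tuple (logits,) after this commit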