Update modeling.py
Browse files
- modeling.py +4 -1
modeling.py
CHANGED
|
@@ -150,7 +150,7 @@ class RecombinationTransformerForCausalLM(PreTrainedModel):
|
|
| 150 |
self.final_rms_norm = RMSNorm(config.embed_dim)
|
| 151 |
self.lm_head = nn.Linear(config.embed_dim, config.vocab_size, bias=False)
|
| 152 |
|
| 153 |
-
def forward(self, input_ids, attention_mask=None, past_key_values=None):
|
| 154 |
if attention_mask is None:
|
| 155 |
attention_mask = torch.ones(input_ids.shape, device=input_ids.device)
|
| 156 |
|
|
@@ -171,6 +171,9 @@ class RecombinationTransformerForCausalLM(PreTrainedModel):
|
|
| 171 |
# LM head
|
| 172 |
logits = self.lm_head(x)
|
| 173 |
|
|
|
|
|
|
|
|
|
|
| 174 |
return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
|
| 175 |
|
| 176 |
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **kwargs):
|
|
|
|
| 150 |
self.final_rms_norm = RMSNorm(config.embed_dim)
|
| 151 |
self.lm_head = nn.Linear(config.embed_dim, config.vocab_size, bias=False)
|
| 152 |
|
| 153 |
+
def forward(self, input_ids, attention_mask=None, past_key_values=None, return_dict=None):
|
| 154 |
if attention_mask is None:
|
| 155 |
attention_mask = torch.ones(input_ids.shape, device=input_ids.device)
|
| 156 |
|
|
|
|
| 171 |
# LM head
|
| 172 |
logits = self.lm_head(x)
|
| 173 |
|
| 174 |
+
if not return_dict:
|
| 175 |
+
return (logits,)
|
| 176 |
+
|
| 177 |
return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
|
| 178 |
|
| 179 |
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **kwargs):
|