Update modeling.py
Browse files- modeling.py +11 -6
modeling.py
CHANGED
|
@@ -150,20 +150,24 @@ class RecombinationTransformerForCausalLM(PreTrainedModel):
|
|
| 150 |
self.final_rms_norm = RMSNorm(config.embed_dim)
|
| 151 |
self.lm_head = nn.Linear(config.embed_dim, config.vocab_size, bias=False)
|
| 152 |
|
| 153 |
-
def forward(self, input_ids, attention_mask=None, past_key_values=None, return_dict=None):
    """Decode `input_ids` through the layer stack and project to vocab logits.

    Args:
        input_ids: integer token ids, shape (batch, seq_len).
        attention_mask: optional mask; defaults to all-ones over `input_ids`.
            NOTE(review): built here but never consumed below — padding
            positions are not actually masked out. Confirm whether layers
            should receive it.
        past_key_values: opaque cache object. NOTE(review): not consumed by
            any layer call; it is only echoed back in the dict-style return.
        return_dict: truthy -> return `CausalLMOutputWithPast`; falsy ->
            return a plain `(logits,)` tuple.

    Returns:
        `(logits,)` or `CausalLMOutputWithPast(logits=..., past_key_values=...)`
        with logits of shape (batch, seq_len, vocab_size).
    """
    if attention_mask is None:
        attention_mask = torch.ones(input_ids.shape, device=input_ids.device)

    # Lower-triangular causal mask, broadcast to one copy per batch element.
    n_batch, n_tokens = input_ids.size()
    tri = torch.tril(torch.ones((n_tokens, n_tokens), device=input_ids.device))
    causal = tri.unsqueeze(0).expand(n_batch, -1, -1)

    # Token embedding followed by the decoder stack.
    hidden = self.embed_tokens(input_ids)
    for block in self.layers:
        hidden = block(hidden, attn_mask=causal)

    # Final normalization and output projection.
    hidden = self.final_rms_norm(hidden)
    logits = self.lm_head(hidden)

    if return_dict:
        return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
    return (logits,)
|
| 178 |
|
| 179 |
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **kwargs):
|
| 180 |
if past:
|
|
|
|
| 150 |
self.final_rms_norm = RMSNorm(config.embed_dim)
|
| 151 |
self.lm_head = nn.Linear(config.embed_dim, config.vocab_size, bias=False)
|
| 152 |
|
| 153 |
+
def forward(self, input_ids, attention_mask=None, past_key_values=None, return_dict=None, **kwargs):
    """Decode `input_ids` through the layer stack and project to vocab logits.

    Args:
        input_ids: integer token ids, shape (batch, seq_len).
        attention_mask: optional mask; defaults to all-ones over `input_ids`.
            NOTE(review): built here but never consumed below — padding
            positions are not actually masked out. Confirm whether layers
            should receive it.
        past_key_values: optional per-layer cache list; normalized to
            `[None] * num_layers` when absent.
        return_dict: truthy -> return `CausalLMOutputWithPast`; falsy ->
            return a `(logits, past_key_values)` tuple.
        **kwargs: ignored; accepted for compatibility with generation utils.

    Returns:
        `(logits, past_key_values)` or
        `CausalLMOutputWithPast(logits=..., past_key_values=...)` with logits
        of shape (batch, seq_len, vocab_size).

    Fixes vs. previous revision:
        * `past_key_value = past_key_values[i]` was assigned per layer but
          never used — dead code, removed.
        * `new_past_key_values` collected per-layer HIDDEN STATES and
          returned them labelled as a KV cache, which would corrupt any
          consumer that feeds them back. The layer call signature
          (`layer(x, attn_mask=...)`) accepts no cache, so real caching is
          unimplemented here; we now pass `past_key_values` through
          unchanged instead of fabricating a bogus cache.
          TODO(review): implement true KV caching once the layer API
          accepts/returns per-layer key/value tensors.
    """
    if attention_mask is None:
        attention_mask = torch.ones(input_ids.shape, device=input_ids.device)

    # Lower-triangular causal mask, one copy per batch element.
    # NOTE(review): ignores any past length, so it is only valid for
    # full-sequence (non-incremental) decoding — confirm before enabling
    # cached generation.
    batch_size, seq_length = input_ids.size()
    causal_mask = (
        torch.tril(torch.ones((seq_length, seq_length), device=input_ids.device))
        .unsqueeze(0)
        .expand(batch_size, -1, -1)
    )

    # Normalize the cache container to one (unused) slot per layer.
    if past_key_values is None:
        past_key_values = [None] * len(self.layers)

    # Embedding + decoder stack. Layers take no cache argument (see TODO).
    x = self.embed_tokens(input_ids)
    for layer in self.layers:
        x = layer(x, attn_mask=causal_mask)

    # Final normalization and output projection.
    x = self.final_rms_norm(x)
    logits = self.lm_head(x)

    if not return_dict:
        return (logits, past_key_values)
    return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
|
| 182 |
|
|
|
|
| 183 |
|
| 184 |
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **kwargs):
|
| 185 |
if past:
|