Xuezha committed
Commit 57c1748 · verified · 1 Parent(s): b524ec2

Update modeling.py

Files changed (1)
  1. modeling.py +11 -6
modeling.py CHANGED
@@ -150,20 +150,24 @@ class RecombinationTransformerForCausalLM(PreTrainedModel):
         self.final_rms_norm = RMSNorm(config.embed_dim)
         self.lm_head = nn.Linear(config.embed_dim, config.vocab_size, bias=False)
 
-    def forward(self, input_ids, attention_mask=None, past_key_values=None, return_dict=None):
+    def forward(self, input_ids, attention_mask=None, past_key_values=None, return_dict=None, **kwargs):
         if attention_mask is None:
             attention_mask = torch.ones(input_ids.shape, device=input_ids.device)
 
-        # Create causal mask
         batch_size, seq_length = input_ids.size()
         causal_mask = torch.tril(torch.ones((seq_length, seq_length), device=input_ids.device)).unsqueeze(0).expand(batch_size, -1, -1)
 
+        if past_key_values is None:
+            past_key_values = [None] * len(self.layers)
+
         # Embedding
         x = self.embed_tokens(input_ids)
 
-        # Apply layers
-        for layer in self.layers:
+        new_past_key_values = []
+        for i, layer in enumerate(self.layers):
+            past_key_value = past_key_values[i]
             x = layer(x, attn_mask=causal_mask)
+            new_past_key_values.append(x)
 
         # Final normalization
         x = self.final_rms_norm(x)
@@ -172,9 +176,10 @@ class RecombinationTransformerForCausalLM(PreTrainedModel):
         logits = self.lm_head(x)
 
         if not return_dict:
-            return (logits,)
+            return (logits, new_past_key_values)
+
+        return CausalLMOutputWithPast(logits=logits, past_key_values=new_past_key_values)
 
-        return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
 
     def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **kwargs):
         if past:
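
For reference, a minimal smoke-test sketch of the two return paths introduced by this commit. It is not part of the commit: how the model and config are constructed is left to the caller, and the helper name smoke_test is a placeholder; only the forward signature and its return values are taken from the diff above.

# Hypothetical usage sketch (not from this repository): exercise both return
# paths of the updated forward on an already-constructed model instance.
import torch

def smoke_test(model, vocab_size, batch=2, seq=16):
    model.eval()
    input_ids = torch.randint(0, vocab_size, (batch, seq))
    with torch.no_grad():
        # Tuple path: return_dict falsy -> (logits, new_past_key_values)
        logits, past = model(input_ids, return_dict=False)
        # Dataclass path: CausalLMOutputWithPast(logits=..., past_key_values=...)
        out = model(input_ids, past_key_values=past, return_dict=True)
    assert logits.shape == (batch, seq, vocab_size)
    assert len(out.past_key_values) == len(model.layers)
    return out

Note that both paths now surface new_past_key_values (one entry per layer), so callers that previously unpacked the old one-element tuple (logits,) need to account for the second item.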