KitsuVp committed on
Commit
2bbbd04
·
verified ·
1 Parent(s): 470f299

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +65 -7
modeling_neollm.py CHANGED
@@ -37,7 +37,7 @@ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
37
  from transformers.processing_utils import Unpack
38
  from transformers.utils import TransformersKwargs, logging
39
  from transformers.utils.generic import check_model_inputs
40
- from .configuration_neollm import NeoLLMConfig
41
 
42
  from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
43
 
@@ -1119,6 +1119,9 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
1119
  output_hidden_states: Optional[bool] = None,
1120
  output_attentions: Optional[bool] = None,
1121
  return_dict: Optional[bool] = None,
 
 
 
1122
  **kwargs: Unpack[TransformersKwargs],
1123
  ) -> BaseModelOutputWithPast:
1124
  output_hidden_states = (
@@ -1157,9 +1160,12 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
1157
  position_embeddings = self.rotary_emb(hidden_states, position_ids)
1158
 
1159
  # ResFormer with first-layer feature propagation
1160
- self.first_layer_fan = None
1161
- stack_state = None
1162
- stack_mask = None
 
 
 
1163
 
1164
  for decoder_layer in self.layers:
1165
  if output_hidden_states:
@@ -1186,7 +1192,13 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
1186
  stack_mask = layer_outputs[3]
1187
 
1188
  # ResFormer: capture H_fan_1 from the first layer
 
 
1189
  if self.first_layer_fan is None and hasattr(decoder_layer, 'current_layer_fan'):
 
 
 
 
1190
  self.first_layer_fan = decoder_layer.current_layer_fan
1191
 
1192
  # Apply SeeDNorm for final normalization
@@ -1194,13 +1206,23 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
1194
 
1195
  if output_hidden_states:
1196
  all_hidden_states = all_hidden_states + (hidden_states,)
 
 
 
 
 
 
 
 
 
 
1197
 
1198
  if not return_dict:
1199
- return tuple(v for v in [hidden_states, None, all_hidden_states, all_attentions] if v is not None)
1200
 
1201
  return BaseModelOutputWithPast(
1202
  last_hidden_state=hidden_states,
1203
- past_key_values=None,
1204
  hidden_states=all_hidden_states,
1205
  attentions=all_attentions,
1206
  )
@@ -1255,6 +1277,36 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
1255
 
1256
  self.post_init()
1257
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1258
  def forward(
1259
  self,
1260
  input_ids: Optional[torch.LongTensor] = None,
@@ -1265,6 +1317,9 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
1265
  logits_to_keep: Union[int, torch.Tensor] = 0,
1266
  output_hidden_states: Optional[bool] = None,
1267
  return_dict: Optional[bool] = None,
 
 
 
1268
  **kwargs: Unpack[TransformersKwargs],
1269
  ) -> CausalLMOutputWithPast:
1270
  outputs: BaseModelOutputWithPast = self.model(
@@ -1274,6 +1329,9 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
1274
  inputs_embeds=inputs_embeds,
1275
  output_hidden_states=output_hidden_states,
1276
  return_dict=return_dict,
 
 
 
1277
  **kwargs,
1278
  )
1279
 
@@ -1298,7 +1356,7 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
1298
  return CausalLMOutputWithPast(
1299
  loss=loss,
1300
  logits=logits,
1301
- past_key_values=None,
1302
  hidden_states=outputs.hidden_states,
1303
  attentions=outputs.attentions,
1304
  )
 
37
  from transformers.processing_utils import Unpack
38
  from transformers.utils import TransformersKwargs, logging
39
  from transformers.utils.generic import check_model_inputs
40
+ from configuration_neollm import NeoLLMConfig
41
 
42
  from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
43
 
 
1119
  output_hidden_states: Optional[bool] = None,
1120
  output_attentions: Optional[bool] = None,
1121
  return_dict: Optional[bool] = None,
1122
+ past_stack_state: Optional[torch.Tensor] = None,
1123
+ past_stack_mask: Optional[torch.Tensor] = None,
1124
+ past_first_layer_fan: Optional[torch.Tensor] = None,
1125
  **kwargs: Unpack[TransformersKwargs],
1126
  ) -> BaseModelOutputWithPast:
1127
  output_hidden_states = (
 
1160
  position_embeddings = self.rotary_emb(hidden_states, position_ids)
1161
 
1162
  # ResFormer with first-layer feature propagation
1163
+ # Retrieve persistent ResFormer state if provided (for inference)
1164
+ self.first_layer_fan = past_first_layer_fan
1165
+
1166
+ # Initialize Stack states
1167
+ stack_state = past_stack_state
1168
+ stack_mask = past_stack_mask
1169
 
1170
  for decoder_layer in self.layers:
1171
  if output_hidden_states:
 
1192
  stack_mask = layer_outputs[3]
1193
 
1194
  # ResFormer: capture H_fan_1 from the first layer
1195
+ # If we didn't have it (and it wasn't passed via past), capture it now.
1196
+ # For inference, if we just computed the prompt/first token, we keep it.
1197
  if self.first_layer_fan is None and hasattr(decoder_layer, 'current_layer_fan'):
1198
+ # Crucial: For persistence, we might want to slice this if it's the prompt?
1199
+ # But logic says: reuse same tensor. If seq_len > 1, it's prompt.
1200
+ # If seq_len == 1, it's generation.
1201
+ # If we are starting fresh (None), we capture what we have.
1202
  self.first_layer_fan = decoder_layer.current_layer_fan
1203
 
1204
  # Apply SeeDNorm for final normalization
 
1206
 
1207
  if output_hidden_states:
1208
  all_hidden_states = all_hidden_states + (hidden_states,)
1209
+
1210
+ # Construct the persistence tuple (Stack + Fan)
1211
+ # Note: We do not implement full KV cache yet, but we persist these states.
1212
+ next_cache = None
1213
+ if self.use_stack or self.first_layer_fan is not None:
1214
+ # Capture the first token's FAN for ResFormer persistence
1215
+ # If we have a sequence, we probably want to keep the FIRST token's fan for consistency?
1216
+ # Or just keep the whole thing? The requirement is to "reuse" the state.
1217
+ # We keep the object self.first_layer_fan.
1218
+ next_cache = (stack_state, stack_mask, self.first_layer_fan)
1219
 
1220
  if not return_dict:
1221
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attentions] if v is not None)
1222
 
1223
  return BaseModelOutputWithPast(
1224
  last_hidden_state=hidden_states,
1225
+ past_key_values=next_cache,
1226
  hidden_states=all_hidden_states,
1227
  attentions=all_attentions,
1228
  )
 
1277
 
1278
  self.post_init()
1279
 
1280
+ def prepare_inputs_for_generation(
1281
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1282
+ ):
1283
+ # Extract custom states from past_key_values if present
1284
+ # Structure: (stack_state, stack_mask, first_layer_fan)
1285
+ past_stack_state = None
1286
+ past_stack_mask = None
1287
+ past_first_layer_fan = None
1288
+
1289
+ if past_key_values is not None:
1290
+ # We use the past_key_values as a container for our custom states
1291
+ # Since we don't have standard KV cache yet, it should just be our tuple
1292
+ if len(past_key_values) == 3:
1293
+ past_stack_state, past_stack_mask, past_first_layer_fan = past_key_values
1294
+
1295
+ # Helper for generation loop: input_ids should be just the last token if we have past
1296
+ input_ids = input_ids[:, -1:]
1297
+
1298
+ model_inputs = {
1299
+ "input_ids": input_ids,
1300
+ "past_stack_state": past_stack_state,
1301
+ "past_stack_mask": past_stack_mask,
1302
+ "past_first_layer_fan": past_first_layer_fan,
1303
+ "use_cache": kwargs.get("use_cache"),
1304
+ "position_ids": kwargs.get("position_ids", None),
1305
+ "attention_mask": attention_mask,
1306
+ "inputs_embeds": inputs_embeds,
1307
+ }
1308
+ return model_inputs
1309
+
1310
  def forward(
1311
  self,
1312
  input_ids: Optional[torch.LongTensor] = None,
 
1317
  logits_to_keep: Union[int, torch.Tensor] = 0,
1318
  output_hidden_states: Optional[bool] = None,
1319
  return_dict: Optional[bool] = None,
1320
+ past_stack_state: Optional[torch.Tensor] = None,
1321
+ past_stack_mask: Optional[torch.Tensor] = None,
1322
+ past_first_layer_fan: Optional[torch.Tensor] = None,
1323
  **kwargs: Unpack[TransformersKwargs],
1324
  ) -> CausalLMOutputWithPast:
1325
  outputs: BaseModelOutputWithPast = self.model(
 
1329
  inputs_embeds=inputs_embeds,
1330
  output_hidden_states=output_hidden_states,
1331
  return_dict=return_dict,
1332
+ past_stack_state=past_stack_state,
1333
+ past_stack_mask=past_stack_mask,
1334
+ past_first_layer_fan=past_first_layer_fan,
1335
  **kwargs,
1336
  )
1337
 
 
1356
  return CausalLMOutputWithPast(
1357
  loss=loss,
1358
  logits=logits,
1359
+ past_key_values=outputs.past_key_values,
1360
  hidden_states=outputs.hidden_states,
1361
  attentions=outputs.attentions,
1362
  )