KitsuVp committed on
Commit
942b781
·
verified ·
1 Parent(s): 2bbbd04

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +8 -25
modeling_neollm.py CHANGED
@@ -1121,7 +1121,6 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
1121
  return_dict: Optional[bool] = None,
1122
  past_stack_state: Optional[torch.Tensor] = None,
1123
  past_stack_mask: Optional[torch.Tensor] = None,
1124
- past_first_layer_fan: Optional[torch.Tensor] = None,
1125
  **kwargs: Unpack[TransformersKwargs],
1126
  ) -> BaseModelOutputWithPast:
1127
  output_hidden_states = (
@@ -1160,8 +1159,7 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
1160
  position_embeddings = self.rotary_emb(hidden_states, position_ids)
1161
 
1162
  # ResFormer with first-layer feature propagation
1163
- # Retrieve persistent ResFormer state if provided (for inference)
1164
- self.first_layer_fan = past_first_layer_fan
1165
 
1166
  # Initialize Stack states
1167
  stack_state = past_stack_state
@@ -1192,13 +1190,8 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
1192
  stack_mask = layer_outputs[3]
1193
 
1194
  # ResFormer: capture H_fan_1 from the first layer
1195
- # If we didn't have it (and it wasn't passed via past), capture it now.
1196
- # For inference, if we just computed the prompt/first token, we keep it.
1197
  if self.first_layer_fan is None and hasattr(decoder_layer, 'current_layer_fan'):
1198
- # Crucial: For persistence, we might want to slice this if it's the prompt?
1199
- # But logic says: reuse same tensor. If seq_len > 1, it's prompt.
1200
- # If seq_len == 1, it's generation.
1201
- # If we are starting fresh (None), we capture what we have.
1202
  self.first_layer_fan = decoder_layer.current_layer_fan
1203
 
1204
  # Apply SeeDNorm for final normalization
@@ -1207,15 +1200,10 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
1207
  if output_hidden_states:
1208
  all_hidden_states = all_hidden_states + (hidden_states,)
1209
 
1210
- # Construct the persistence tuple (Stack + Fan)
1211
- # Note: We do not implement full KV cache yet, but we persist these states.
1212
  next_cache = None
1213
- if self.use_stack or self.first_layer_fan is not None:
1214
- # Capture the first token's FAN for ResFormer persistence
1215
- # If we have a sequence, we probably want to keep the FIRST token's fan for consistency?
1216
- # Or just keep the whole thing? The requirement is "reutilizar".
1217
- # We keep the object self.first_layer_fan.
1218
- next_cache = (stack_state, stack_mask, self.first_layer_fan)
1219
 
1220
  if not return_dict:
1221
  return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attentions] if v is not None)
@@ -1281,16 +1269,14 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
1281
  self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1282
  ):
1283
  # Extract custom states from past_key_values if present
1284
- # Structure: (stack_state, stack_mask, first_layer_fan)
1285
  past_stack_state = None
1286
  past_stack_mask = None
1287
- past_first_layer_fan = None
1288
 
1289
  if past_key_values is not None:
1290
  # We use the past_key_values as a container for our custom states
1291
- # Since we don't have standard KV cache yet, it should just be our tuple
1292
- if len(past_key_values) == 3:
1293
- past_stack_state, past_stack_mask, past_first_layer_fan = past_key_values
1294
 
1295
  # Helper for generation loop: input_ids should be just the last token if we have past
1296
  input_ids = input_ids[:, -1:]
@@ -1299,7 +1285,6 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
1299
  "input_ids": input_ids,
1300
  "past_stack_state": past_stack_state,
1301
  "past_stack_mask": past_stack_mask,
1302
- "past_first_layer_fan": past_first_layer_fan,
1303
  "use_cache": kwargs.get("use_cache"),
1304
  "position_ids": kwargs.get("position_ids", None),
1305
  "attention_mask": attention_mask,
@@ -1319,7 +1304,6 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
1319
  return_dict: Optional[bool] = None,
1320
  past_stack_state: Optional[torch.Tensor] = None,
1321
  past_stack_mask: Optional[torch.Tensor] = None,
1322
- past_first_layer_fan: Optional[torch.Tensor] = None,
1323
  **kwargs: Unpack[TransformersKwargs],
1324
  ) -> CausalLMOutputWithPast:
1325
  outputs: BaseModelOutputWithPast = self.model(
@@ -1331,7 +1315,6 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
1331
  return_dict=return_dict,
1332
  past_stack_state=past_stack_state,
1333
  past_stack_mask=past_stack_mask,
1334
- past_first_layer_fan=past_first_layer_fan,
1335
  **kwargs,
1336
  )
1337
 
 
1121
  return_dict: Optional[bool] = None,
1122
  past_stack_state: Optional[torch.Tensor] = None,
1123
  past_stack_mask: Optional[torch.Tensor] = None,
 
1124
  **kwargs: Unpack[TransformersKwargs],
1125
  ) -> BaseModelOutputWithPast:
1126
  output_hidden_states = (
 
1159
  position_embeddings = self.rotary_emb(hidden_states, position_ids)
1160
 
1161
  # ResFormer with first-layer feature propagation
1162
+ self.first_layer_fan = None
 
1163
 
1164
  # Initialize Stack states
1165
  stack_state = past_stack_state
 
1190
  stack_mask = layer_outputs[3]
1191
 
1192
  # ResFormer: capture H_fan_1 from the first layer
1193
+ # Dynamically capture for the current pass
 
1194
  if self.first_layer_fan is None and hasattr(decoder_layer, 'current_layer_fan'):
 
 
 
 
1195
  self.first_layer_fan = decoder_layer.current_layer_fan
1196
 
1197
  # Apply SeeDNorm for final normalization
 
1200
  if output_hidden_states:
1201
  all_hidden_states = all_hidden_states + (hidden_states,)
1202
 
1203
+ # Construct the persistence tuple (Stack only)
 
1204
  next_cache = None
1205
+ if self.use_stack:
1206
+ next_cache = (stack_state, stack_mask)
 
 
 
 
1207
 
1208
  if not return_dict:
1209
  return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attentions] if v is not None)
 
1269
  self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1270
  ):
1271
  # Extract custom states from past_key_values if present
1272
+ # Structure: (stack_state, stack_mask)
1273
  past_stack_state = None
1274
  past_stack_mask = None
 
1275
 
1276
  if past_key_values is not None:
1277
  # We use the past_key_values as a container for our custom states
1278
+ if len(past_key_values) == 2:
1279
+ past_stack_state, past_stack_mask = past_key_values
 
1280
 
1281
  # Helper for generation loop: input_ids should be just the last token if we have past
1282
  input_ids = input_ids[:, -1:]
 
1285
  "input_ids": input_ids,
1286
  "past_stack_state": past_stack_state,
1287
  "past_stack_mask": past_stack_mask,
 
1288
  "use_cache": kwargs.get("use_cache"),
1289
  "position_ids": kwargs.get("position_ids", None),
1290
  "attention_mask": attention_mask,
 
1304
  return_dict: Optional[bool] = None,
1305
  past_stack_state: Optional[torch.Tensor] = None,
1306
  past_stack_mask: Optional[torch.Tensor] = None,
 
1307
  **kwargs: Unpack[TransformersKwargs],
1308
  ) -> CausalLMOutputWithPast:
1309
  outputs: BaseModelOutputWithPast = self.model(
 
1315
  return_dict=return_dict,
1316
  past_stack_state=past_stack_state,
1317
  past_stack_mask=past_stack_mask,
 
1318
  **kwargs,
1319
  )
1320