Update modeling_neollm.py
Browse files — modeling_neollm.py (+8 lines, −25 lines)
modeling_neollm.py
CHANGED
|
@@ -1121,7 +1121,6 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 1121 |
return_dict: Optional[bool] = None,
|
| 1122 |
past_stack_state: Optional[torch.Tensor] = None,
|
| 1123 |
past_stack_mask: Optional[torch.Tensor] = None,
|
| 1124 |
-
past_first_layer_fan: Optional[torch.Tensor] = None,
|
| 1125 |
**kwargs: Unpack[TransformersKwargs],
|
| 1126 |
) -> BaseModelOutputWithPast:
|
| 1127 |
output_hidden_states = (
|
|
@@ -1160,8 +1159,7 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 1160 |
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 1161 |
|
| 1162 |
# ResFormer with first-layer feature propagation
|
| 1163 |
-
|
| 1164 |
-
self.first_layer_fan = past_first_layer_fan
|
| 1165 |
|
| 1166 |
# Initialize Stack states
|
| 1167 |
stack_state = past_stack_state
|
|
@@ -1192,13 +1190,8 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 1192 |
stack_mask = layer_outputs[3]
|
| 1193 |
|
| 1194 |
# ResFormer: capture H_fan_1 from the first layer
|
| 1195 |
-
#
|
| 1196 |
-
# For inference, if we just computed the prompt/first token, we keep it.
|
| 1197 |
if self.first_layer_fan is None and hasattr(decoder_layer, 'current_layer_fan'):
|
| 1198 |
-
# Crucial: For persistence, we might want to slice this if it's the prompt?
|
| 1199 |
-
# But logic says: reuse same tensor. If seq_len > 1, it's prompt.
|
| 1200 |
-
# If seq_len == 1, it's generation.
|
| 1201 |
-
# If we are starting fresh (None), we capture what we have.
|
| 1202 |
self.first_layer_fan = decoder_layer.current_layer_fan
|
| 1203 |
|
| 1204 |
# Apply SeeDNorm for final normalization
|
|
@@ -1207,15 +1200,10 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 1207 |
if output_hidden_states:
|
| 1208 |
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 1209 |
|
| 1210 |
-
# Construct the persistence tuple (stack_state, stack_mask, first_layer_fan)
|
| 1211 |
-
# Note: We do not implement full KV cache yet, but we persist these states.
|
| 1212 |
next_cache = None
|
| 1213 |
-
if self.use_stack:
|
| 1214 |
-
|
| 1215 |
-
# If we have a sequence, we probably want to keep the FIRST token's fan for consistency?
|
| 1216 |
-
# Or just keep the whole thing? The requirement is "reutilizar".
|
| 1217 |
-
# We keep the object self.first_layer_fan.
|
| 1218 |
-
next_cache = (stack_state, stack_mask, self.first_layer_fan)
|
| 1219 |
|
| 1220 |
if not return_dict:
|
| 1221 |
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attentions] if v is not None)
|
|
@@ -1281,16 +1269,14 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
|
| 1281 |
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
| 1282 |
):
|
| 1283 |
# Extract custom states from past_key_values if present
|
| 1284 |
-
# Structure: (stack_state, stack_mask, past_first_layer_fan)
|
| 1285 |
past_stack_state = None
|
| 1286 |
past_stack_mask = None
|
| 1287 |
-
past_first_layer_fan = None
|
| 1288 |
|
| 1289 |
if past_key_values is not None:
|
| 1290 |
# We use the past_key_values as a container for our custom states
|
| 1291 |
-
|
| 1292 |
-
|
| 1293 |
-
past_stack_state, past_stack_mask, past_first_layer_fan = past_key_values
|
| 1294 |
|
| 1295 |
# Helper for generation loop: input_ids should be just the last token if we have past
|
| 1296 |
input_ids = input_ids[:, -1:]
|
|
@@ -1299,7 +1285,6 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
|
| 1299 |
"input_ids": input_ids,
|
| 1300 |
"past_stack_state": past_stack_state,
|
| 1301 |
"past_stack_mask": past_stack_mask,
|
| 1302 |
-
"past_first_layer_fan": past_first_layer_fan,
|
| 1303 |
"use_cache": kwargs.get("use_cache"),
|
| 1304 |
"position_ids": kwargs.get("position_ids", None),
|
| 1305 |
"attention_mask": attention_mask,
|
|
@@ -1319,7 +1304,6 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
|
| 1319 |
return_dict: Optional[bool] = None,
|
| 1320 |
past_stack_state: Optional[torch.Tensor] = None,
|
| 1321 |
past_stack_mask: Optional[torch.Tensor] = None,
|
| 1322 |
-
past_first_layer_fan: Optional[torch.Tensor] = None,
|
| 1323 |
**kwargs: Unpack[TransformersKwargs],
|
| 1324 |
) -> CausalLMOutputWithPast:
|
| 1325 |
outputs: BaseModelOutputWithPast = self.model(
|
|
@@ -1331,7 +1315,6 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
|
| 1331 |
return_dict=return_dict,
|
| 1332 |
past_stack_state=past_stack_state,
|
| 1333 |
past_stack_mask=past_stack_mask,
|
| 1334 |
-
past_first_layer_fan=past_first_layer_fan,
|
| 1335 |
**kwargs,
|
| 1336 |
)
|
| 1337 |
|
|
|
|
| 1121 |
return_dict: Optional[bool] = None,
|
| 1122 |
past_stack_state: Optional[torch.Tensor] = None,
|
| 1123 |
past_stack_mask: Optional[torch.Tensor] = None,
|
|
|
|
| 1124 |
**kwargs: Unpack[TransformersKwargs],
|
| 1125 |
) -> BaseModelOutputWithPast:
|
| 1126 |
output_hidden_states = (
|
|
|
|
| 1159 |
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 1160 |
|
| 1161 |
# ResFormer with first-layer feature propagation
|
| 1162 |
+
self.first_layer_fan = None
|
|
|
|
| 1163 |
|
| 1164 |
# Initialize Stack states
|
| 1165 |
stack_state = past_stack_state
|
|
|
|
| 1190 |
stack_mask = layer_outputs[3]
|
| 1191 |
|
| 1192 |
# ResFormer: capture H_fan_1 from the first layer
|
| 1193 |
+
# Dynamically capture for the current pass
|
|
|
|
| 1194 |
if self.first_layer_fan is None and hasattr(decoder_layer, 'current_layer_fan'):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1195 |
self.first_layer_fan = decoder_layer.current_layer_fan
|
| 1196 |
|
| 1197 |
# Apply SeeDNorm for final normalization
|
|
|
|
| 1200 |
if output_hidden_states:
|
| 1201 |
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 1202 |
|
| 1203 |
+
# Construct the persistence tuple (Stack only)
|
|
|
|
| 1204 |
next_cache = None
|
| 1205 |
+
if self.use_stack:
|
| 1206 |
+
next_cache = (stack_state, stack_mask)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1207 |
|
| 1208 |
if not return_dict:
|
| 1209 |
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attentions] if v is not None)
|
|
|
|
| 1269 |
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
| 1270 |
):
|
| 1271 |
# Extract custom states from past_key_values if present
|
| 1272 |
+
# Structure: (stack_state, stack_mask)
|
| 1273 |
past_stack_state = None
|
| 1274 |
past_stack_mask = None
|
|
|
|
| 1275 |
|
| 1276 |
if past_key_values is not None:
|
| 1277 |
# We use the past_key_values as a container for our custom states
|
| 1278 |
+
if len(past_key_values) == 2:
|
| 1279 |
+
past_stack_state, past_stack_mask = past_key_values
|
|
|
|
| 1280 |
|
| 1281 |
# Helper for generation loop: input_ids should be just the last token if we have past
|
| 1282 |
input_ids = input_ids[:, -1:]
|
|
|
|
| 1285 |
"input_ids": input_ids,
|
| 1286 |
"past_stack_state": past_stack_state,
|
| 1287 |
"past_stack_mask": past_stack_mask,
|
|
|
|
| 1288 |
"use_cache": kwargs.get("use_cache"),
|
| 1289 |
"position_ids": kwargs.get("position_ids", None),
|
| 1290 |
"attention_mask": attention_mask,
|
|
|
|
| 1304 |
return_dict: Optional[bool] = None,
|
| 1305 |
past_stack_state: Optional[torch.Tensor] = None,
|
| 1306 |
past_stack_mask: Optional[torch.Tensor] = None,
|
|
|
|
| 1307 |
**kwargs: Unpack[TransformersKwargs],
|
| 1308 |
) -> CausalLMOutputWithPast:
|
| 1309 |
outputs: BaseModelOutputWithPast = self.model(
|
|
|
|
| 1315 |
return_dict=return_dict,
|
| 1316 |
past_stack_state=past_stack_state,
|
| 1317 |
past_stack_mask=past_stack_mask,
|
|
|
|
| 1318 |
**kwargs,
|
| 1319 |
)
|
| 1320 |
|