Update modeling_neollm.py
Browse files- modeling_neollm.py +65 -7
modeling_neollm.py
CHANGED
|
@@ -37,7 +37,7 @@ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|
| 37 |
from transformers.processing_utils import Unpack
|
| 38 |
from transformers.utils import TransformersKwargs, logging
|
| 39 |
from transformers.utils.generic import check_model_inputs
|
| 40 |
-
from
|
| 41 |
|
| 42 |
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
| 43 |
|
|
@@ -1119,6 +1119,9 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 1119 |
output_hidden_states: Optional[bool] = None,
|
| 1120 |
output_attentions: Optional[bool] = None,
|
| 1121 |
return_dict: Optional[bool] = None,
|
|
|
|
|
|
|
|
|
|
| 1122 |
**kwargs: Unpack[TransformersKwargs],
|
| 1123 |
) -> BaseModelOutputWithPast:
|
| 1124 |
output_hidden_states = (
|
|
@@ -1157,9 +1160,12 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 1157 |
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 1158 |
|
| 1159 |
# ResFormer with first-layer feature propagation
|
| 1160 |
-
|
| 1161 |
-
|
| 1162 |
-
|
|
|
|
|
|
|
|
|
|
| 1163 |
|
| 1164 |
for decoder_layer in self.layers:
|
| 1165 |
if output_hidden_states:
|
|
@@ -1186,7 +1192,13 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 1186 |
stack_mask = layer_outputs[3]
|
| 1187 |
|
| 1188 |
# ResFormer: capture H_fan_1 from the first layer
|
|
|
|
|
|
|
| 1189 |
if self.first_layer_fan is None and hasattr(decoder_layer, 'current_layer_fan'):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1190 |
self.first_layer_fan = decoder_layer.current_layer_fan
|
| 1191 |
|
| 1192 |
# Apply SeeDNorm for final normalization
|
|
@@ -1194,13 +1206,23 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 1194 |
|
| 1195 |
if output_hidden_states:
|
| 1196 |
all_hidden_states = all_hidden_states + (hidden_states,)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1197 |
|
| 1198 |
if not return_dict:
|
| 1199 |
-
return tuple(v for v in [hidden_states,
|
| 1200 |
|
| 1201 |
return BaseModelOutputWithPast(
|
| 1202 |
last_hidden_state=hidden_states,
|
| 1203 |
-
past_key_values=
|
| 1204 |
hidden_states=all_hidden_states,
|
| 1205 |
attentions=all_attentions,
|
| 1206 |
)
|
|
@@ -1255,6 +1277,36 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
|
| 1255 |
|
| 1256 |
self.post_init()
|
| 1257 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1258 |
def forward(
|
| 1259 |
self,
|
| 1260 |
input_ids: Optional[torch.LongTensor] = None,
|
|
@@ -1265,6 +1317,9 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
|
| 1265 |
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 1266 |
output_hidden_states: Optional[bool] = None,
|
| 1267 |
return_dict: Optional[bool] = None,
|
|
|
|
|
|
|
|
|
|
| 1268 |
**kwargs: Unpack[TransformersKwargs],
|
| 1269 |
) -> CausalLMOutputWithPast:
|
| 1270 |
outputs: BaseModelOutputWithPast = self.model(
|
|
@@ -1274,6 +1329,9 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
|
| 1274 |
inputs_embeds=inputs_embeds,
|
| 1275 |
output_hidden_states=output_hidden_states,
|
| 1276 |
return_dict=return_dict,
|
|
|
|
|
|
|
|
|
|
| 1277 |
**kwargs,
|
| 1278 |
)
|
| 1279 |
|
|
@@ -1298,7 +1356,7 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
|
| 1298 |
return CausalLMOutputWithPast(
|
| 1299 |
loss=loss,
|
| 1300 |
logits=logits,
|
| 1301 |
-
past_key_values=
|
| 1302 |
hidden_states=outputs.hidden_states,
|
| 1303 |
attentions=outputs.attentions,
|
| 1304 |
)
|
|
|
|
| 37 |
from transformers.processing_utils import Unpack
|
| 38 |
from transformers.utils import TransformersKwargs, logging
|
| 39 |
from transformers.utils.generic import check_model_inputs
|
| 40 |
+
from configuration_neollm import NeoLLMConfig
|
| 41 |
|
| 42 |
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
| 43 |
|
|
|
|
| 1119 |
output_hidden_states: Optional[bool] = None,
|
| 1120 |
output_attentions: Optional[bool] = None,
|
| 1121 |
return_dict: Optional[bool] = None,
|
| 1122 |
+
past_stack_state: Optional[torch.Tensor] = None,
|
| 1123 |
+
past_stack_mask: Optional[torch.Tensor] = None,
|
| 1124 |
+
past_first_layer_fan: Optional[torch.Tensor] = None,
|
| 1125 |
**kwargs: Unpack[TransformersKwargs],
|
| 1126 |
) -> BaseModelOutputWithPast:
|
| 1127 |
output_hidden_states = (
|
|
|
|
| 1160 |
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 1161 |
|
| 1162 |
# ResFormer with first-layer feature propagation
|
| 1163 |
+
# Retrieve persistent ResFormer state if provided (for inference)
|
| 1164 |
+
self.first_layer_fan = past_first_layer_fan
|
| 1165 |
+
|
| 1166 |
+
# Initialize Stack states
|
| 1167 |
+
stack_state = past_stack_state
|
| 1168 |
+
stack_mask = past_stack_mask
|
| 1169 |
|
| 1170 |
for decoder_layer in self.layers:
|
| 1171 |
if output_hidden_states:
|
|
|
|
| 1192 |
stack_mask = layer_outputs[3]
|
| 1193 |
|
| 1194 |
# ResFormer: capture H_fan_1 from the first layer
|
| 1195 |
+
# If we didn't have it (and it wasn't passed via past), capture it now.
|
| 1196 |
+
# For inference, if we just computed the prompt/first token, we keep it.
|
| 1197 |
if self.first_layer_fan is None and hasattr(decoder_layer, 'current_layer_fan'):
|
| 1198 |
+
# Crucial: For persistence, we might want to slice this if it's the prompt?
|
| 1199 |
+
# But logic says: reuse same tensor. If seq_len > 1, it's prompt.
|
| 1200 |
+
# If seq_len == 1, it's generation.
|
| 1201 |
+
# If we are starting fresh (None), we capture what we have.
|
| 1202 |
self.first_layer_fan = decoder_layer.current_layer_fan
|
| 1203 |
|
| 1204 |
# Apply SeeDNorm for final normalization
|
|
|
|
| 1206 |
|
| 1207 |
if output_hidden_states:
|
| 1208 |
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 1209 |
+
|
| 1210 |
+
# Construct the persistence tuple (Stack + Fan)
|
| 1211 |
+
# Note: We do not implement full KV cache yet, but we persist these states.
|
| 1212 |
+
next_cache = None
|
| 1213 |
+
if self.use_stack or self.first_layer_fan is not None:
|
| 1214 |
+
# Capture the first token's FAN for ResFormer persistence
|
| 1215 |
+
# If we have a sequence, we probably want to keep the FIRST token's fan for consistency?
|
| 1216 |
+
# Or just keep the whole thing? The requirement is "reutilizar".
|
| 1217 |
+
# We keep the object self.first_layer_fan.
|
| 1218 |
+
next_cache = (stack_state, stack_mask, self.first_layer_fan)
|
| 1219 |
|
| 1220 |
if not return_dict:
|
| 1221 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attentions] if v is not None)
|
| 1222 |
|
| 1223 |
return BaseModelOutputWithPast(
|
| 1224 |
last_hidden_state=hidden_states,
|
| 1225 |
+
past_key_values=next_cache,
|
| 1226 |
hidden_states=all_hidden_states,
|
| 1227 |
attentions=all_attentions,
|
| 1228 |
)
|
|
|
|
| 1277 |
|
| 1278 |
self.post_init()
|
| 1279 |
|
| 1280 |
+
def prepare_inputs_for_generation(
    self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
    """Assemble the model inputs for one decoding step of `generate()`.

    ``past_key_values`` is reused here as a container for the model's custom
    persistent state rather than a standard KV cache.  The expected layout is
    the 3-tuple ``(stack_state, stack_mask, first_layer_fan)`` that the model's
    forward pass emits as ``past_key_values``.

    Args:
        input_ids: ``(batch, seq_len)`` token ids accumulated so far.
            NOTE(review): assumed 2-D — TODO confirm against the generation loop.
        past_key_values: optional 3-tuple of persistent ResFormer/stack states.
        attention_mask: optional mask, forwarded unchanged.
        inputs_embeds: optional precomputed embeddings, forwarded unchanged.
        **kwargs: may carry ``use_cache`` and ``position_ids``.

    Returns:
        dict of keyword arguments for the next ``forward`` call.
    """
    past_stack_state = None
    past_stack_mask = None
    past_first_layer_fan = None

    if past_key_values is not None:
        # The state container should be the 3-tuple produced by the model.
        # Guard on type AND length so an unexpected cache object (e.g. a
        # future standard transformers Cache, which may not support len())
        # is silently ignored instead of raising a TypeError here.
        if isinstance(past_key_values, (tuple, list)) and len(past_key_values) == 3:
            past_stack_state, past_stack_mask, past_first_layer_fan = past_key_values

        # With persistent state available, only the newest token needs to be
        # fed through the model on this step.
        input_ids = input_ids[:, -1:]

    model_inputs = {
        "input_ids": input_ids,
        "past_stack_state": past_stack_state,
        "past_stack_mask": past_stack_mask,
        "past_first_layer_fan": past_first_layer_fan,
        "use_cache": kwargs.get("use_cache"),
        "position_ids": kwargs.get("position_ids", None),
        "attention_mask": attention_mask,
        "inputs_embeds": inputs_embeds,
    }
    return model_inputs
|
| 1309 |
+
|
| 1310 |
def forward(
|
| 1311 |
self,
|
| 1312 |
input_ids: Optional[torch.LongTensor] = None,
|
|
|
|
| 1317 |
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 1318 |
output_hidden_states: Optional[bool] = None,
|
| 1319 |
return_dict: Optional[bool] = None,
|
| 1320 |
+
past_stack_state: Optional[torch.Tensor] = None,
|
| 1321 |
+
past_stack_mask: Optional[torch.Tensor] = None,
|
| 1322 |
+
past_first_layer_fan: Optional[torch.Tensor] = None,
|
| 1323 |
**kwargs: Unpack[TransformersKwargs],
|
| 1324 |
) -> CausalLMOutputWithPast:
|
| 1325 |
outputs: BaseModelOutputWithPast = self.model(
|
|
|
|
| 1329 |
inputs_embeds=inputs_embeds,
|
| 1330 |
output_hidden_states=output_hidden_states,
|
| 1331 |
return_dict=return_dict,
|
| 1332 |
+
past_stack_state=past_stack_state,
|
| 1333 |
+
past_stack_mask=past_stack_mask,
|
| 1334 |
+
past_first_layer_fan=past_first_layer_fan,
|
| 1335 |
**kwargs,
|
| 1336 |
)
|
| 1337 |
|
|
|
|
| 1356 |
return CausalLMOutputWithPast(
|
| 1357 |
loss=loss,
|
| 1358 |
logits=logits,
|
| 1359 |
+
past_key_values=outputs.past_key_values,
|
| 1360 |
hidden_states=outputs.hidden_states,
|
| 1361 |
attentions=outputs.attentions,
|
| 1362 |
)
|