KitsuVp committed on
Commit
ac61dcd
·
verified ·
1 Parent(s): 7e57200

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +139 -55
modeling_neollm.py CHANGED
@@ -37,7 +37,7 @@ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
37
  from transformers.processing_utils import Unpack
38
  from transformers.utils import TransformersKwargs, logging
39
  from transformers.utils.generic import check_model_inputs
40
- from .configuration_neollm import NeoLLMConfig
41
 
42
  from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
43
 
@@ -325,8 +325,6 @@ class SeeDNorm(nn.Module):
325
  # ==================== STACK MEMORY MODULE ====================
326
  class StackMemory(nn.Module):
327
  """
328
- Differentiable Hidden State Stack for modeling Chomsky hierarchy grammars.
329
-
330
  From "Improving Formal Reasoning of Transformer with State Stack":
331
  Implements a multi-head differentiable stack with soft push, pop, and no-op operations.
332
  Each head maintains its own stack and mask, which are updated based on learned action
@@ -354,8 +352,8 @@ class StackMemory(nn.Module):
354
 
355
  # Dimension reduction projections for efficiency
356
  # Uses standard nn.Linear
357
- self.down_proj = nn.Linear(config.hidden_size, self.stack_d_model, bias=False)
358
- self.up_proj = nn.Linear(self.stack_d_model, config.hidden_size, bias=False)
359
 
360
  # Action prediction: generates push/pop/no-op probabilities for each head
361
  self.action_head = nn.Linear(self.stack_d_model, 3 * self.num_stack_heads, bias=True)
@@ -365,6 +363,20 @@ class StackMemory(nn.Module):
365
 
366
  # Residual weight for gating stack contribution
367
  self.res_weight = nn.Parameter(torch.ones(1))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
 
369
  def _vectorized_update(
370
  self,
@@ -393,8 +405,10 @@ class StackMemory(nn.Module):
393
  batch_size, seq_len = actions.shape[:2]
394
 
395
  # Expand stack and mask along sequence dimension for parallel processing
396
- stack = stack.unsqueeze(1).expand(-1, seq_len, -1, -1, -1)
397
- mask = mask.unsqueeze(1).expand(-1, seq_len, -1, -1)
 
 
398
 
399
  # Generate pushed stack: new value at top, shift others down
400
  push_stack = torch.cat([
@@ -476,33 +490,93 @@ class StackMemory(nn.Module):
476
  new_stack, new_mask = self._vectorized_update(stack, mask, actions, k_values)
477
 
478
  # Global reading via query-over-stack attention
479
-
480
- # FIX: Project the raw stack content directly.
481
- # Previously, masking before projection killed gradients for "empty" slots
482
- # preventing them from ever becoming "full".
483
  gate_scores = self.gate_proj(new_stack).squeeze(-1) # [batch, seq, heads, slots]
484
 
485
- # Apply mask to the SCORES, not the features.
486
- # Mask out invalid positions (add large negative value where mask is 0)
487
- gate_scores = gate_scores + (1 - new_mask) * -1e9
488
-
489
- # Softmax to get attention weights
490
- gate_weights = F.softmax(gate_scores, dim=-1)
491
 
492
  # Weighted sum over stack slots
493
- # new_stack contains the features, gate_weights contains the validity/relevance
494
  memory_output = (new_stack * gate_weights.unsqueeze(-1)).sum(dim=3)
495
  memory_output = memory_output.view(batch_size, seq_len, -1)
496
-
497
- # Project back to original dimension
498
  memory_output = self.up_proj(memory_output)
499
 
500
- # Gated residual connection
501
  output = memory_output * self.res_weight + hidden_states
502
 
503
- # Return output and updated stack state (use last timestep's state)
 
 
 
504
  return output, new_stack[:, -1], new_mask[:, -1]
505
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
506
  # ==================== ROTARY EMBEDDING ====================
507
  class NeoLLMRotaryEmbedding(nn.Module):
508
  inv_freq: torch.Tensor # fix linting for `register_buffer`
@@ -1119,8 +1193,8 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
1119
  output_hidden_states: Optional[bool] = None,
1120
  output_attentions: Optional[bool] = None,
1121
  return_dict: Optional[bool] = None,
1122
- past_stack_state: Optional[torch.Tensor] = None,
1123
- past_stack_mask: Optional[torch.Tensor] = None,
1124
  **kwargs: Unpack[TransformersKwargs],
1125
  ) -> BaseModelOutputWithPast:
1126
  output_hidden_states = (
@@ -1152,6 +1226,7 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
1152
  )
1153
 
1154
  hidden_states = inputs_embeds
 
1155
  all_hidden_states = () if output_hidden_states else None
1156
  all_attentions = () if output_attentions else None
1157
 
@@ -1161,9 +1236,17 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
1161
  # ResFormer with first-layer feature propagation
1162
  self.first_layer_fan = None
1163
 
1164
- # Initialize Stack states
1165
- stack_state = past_stack_state
1166
- stack_mask = past_stack_mask
 
 
 
 
 
 
 
 
1167
 
1168
  for decoder_layer in self.layers:
1169
  if output_hidden_states:
@@ -1186,6 +1269,9 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
1186
  all_attentions = all_attentions + (layer_outputs[1],)
1187
 
1188
  if self.use_stack:
 
 
 
1189
  stack_state = layer_outputs[2]
1190
  stack_mask = layer_outputs[3]
1191
 
@@ -1199,18 +1285,13 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
1199
 
1200
  if output_hidden_states:
1201
  all_hidden_states = all_hidden_states + (hidden_states,)
1202
-
1203
- # Construct the persistence tuple (Stack only)
1204
- next_cache = None
1205
- if self.use_stack:
1206
- next_cache = (stack_state, stack_mask)
1207
 
1208
  if not return_dict:
1209
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attentions] if v is not None)
1210
 
1211
  return BaseModelOutputWithPast(
1212
  last_hidden_state=hidden_states,
1213
- past_key_values=next_cache,
1214
  hidden_states=all_hidden_states,
1215
  attentions=all_attentions,
1216
  )
@@ -1268,29 +1349,34 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
1268
  def prepare_inputs_for_generation(
1269
  self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1270
  ):
1271
- # Extract custom states from past_key_values if present
1272
- # Structure: (stack_state, stack_mask)
1273
- past_stack_state = None
1274
- past_stack_mask = None
1275
-
1276
- if past_key_values is not None:
1277
- # We use the past_key_values as a container for our custom states
1278
- if len(past_key_values) == 2:
1279
- past_stack_state, past_stack_mask = past_key_values
1280
 
1281
- # Helper for generation loop: input_ids should be just the last token if we have past
1282
- input_ids = input_ids[:, -1:]
1283
-
1284
- model_inputs = {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1285
  "input_ids": input_ids,
1286
- "past_stack_state": past_stack_state,
1287
- "past_stack_mask": past_stack_mask,
1288
  "use_cache": kwargs.get("use_cache"),
1289
- "position_ids": kwargs.get("position_ids", None),
1290
  "attention_mask": attention_mask,
1291
  "inputs_embeds": inputs_embeds,
1292
  }
1293
- return model_inputs
1294
 
1295
  def forward(
1296
  self,
@@ -1302,8 +1388,7 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
1302
  logits_to_keep: Union[int, torch.Tensor] = 0,
1303
  output_hidden_states: Optional[bool] = None,
1304
  return_dict: Optional[bool] = None,
1305
- past_stack_state: Optional[torch.Tensor] = None,
1306
- past_stack_mask: Optional[torch.Tensor] = None,
1307
  **kwargs: Unpack[TransformersKwargs],
1308
  ) -> CausalLMOutputWithPast:
1309
  outputs: BaseModelOutputWithPast = self.model(
@@ -1313,8 +1398,7 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
1313
  inputs_embeds=inputs_embeds,
1314
  output_hidden_states=output_hidden_states,
1315
  return_dict=return_dict,
1316
- past_stack_state=past_stack_state,
1317
- past_stack_mask=past_stack_mask,
1318
  **kwargs,
1319
  )
1320
 
 
37
  from transformers.processing_utils import Unpack
38
  from transformers.utils import TransformersKwargs, logging
39
  from transformers.utils.generic import check_model_inputs
40
+ from configuration_neollm import NeoLLMConfig
41
 
42
  from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
43
 
 
325
  # ==================== STACK MEMORY MODULE ====================
326
  class StackMemory(nn.Module):
327
  """
 
 
328
  From "Improving Formal Reasoning of Transformer with State Stack":
329
  Implements a multi-head differentiable stack with soft push, pop, and no-op operations.
330
  Each head maintains its own stack and mask, which are updated based on learned action
 
352
 
353
  # Dimension reduction projections for efficiency
354
  # Uses standard nn.Linear
355
+ self.down_proj = nn.Linear(config.hidden_size, self.stack_d_model, bias=True)
356
+ self.up_proj = nn.Linear(self.stack_d_model, config.hidden_size, bias=True)
357
 
358
  # Action prediction: generates push/pop/no-op probabilities for each head
359
  self.action_head = nn.Linear(self.stack_d_model, 3 * self.num_stack_heads, bias=True)
 
363
 
364
  # Residual weight for gating stack contribution
365
  self.res_weight = nn.Parameter(torch.ones(1))
366
+
367
+ # Cache for autoregressive generation (matches OLMo reference)
368
+ self.cache_size = getattr(config, "cache_size", 2048)
369
+ # Initialization fix: Register buffers for cache
370
+ # Default to batch_size=1 if forward_bs is not in config (standard inference)
371
+ forward_bs = getattr(config, 'forward_bs', 1)
372
+ self.register_buffer("k_cache", torch.zeros(forward_bs, self.cache_size, self.num_stack_heads, self.head_dim))
373
+ self.register_buffer("action_cache", torch.zeros(forward_bs, self.cache_size, self.num_stack_heads, 3))
374
+
375
+ self.cache_position = 0
376
+ self.enable_cache = False
377
+
378
+ def reset_cache(self):
379
+ self.cache_position = 0
380
 
381
  def _vectorized_update(
382
  self,
 
405
  batch_size, seq_len = actions.shape[:2]
406
 
407
  # Expand stack and mask along sequence dimension for parallel processing
408
+ # Only expand if checking against initial state dimensions (4D)
409
+ if stack.dim() == 4:
410
+ stack = stack.unsqueeze(1).expand(-1, seq_len, -1, -1, -1)
411
+ mask = mask.unsqueeze(1).expand(-1, seq_len, -1, -1)
412
 
413
  # Generate pushed stack: new value at top, shift others down
414
  push_stack = torch.cat([
 
490
  new_stack, new_mask = self._vectorized_update(stack, mask, actions, k_values)
491
 
492
  # Global reading via query-over-stack attention
 
 
 
 
493
  gate_scores = self.gate_proj(new_stack).squeeze(-1) # [batch, seq, heads, slots]
494
 
495
+ gate_weights = F.softmax(gate_scores + (1 - new_mask) * -1e9, dim=-1)
 
 
 
 
 
496
 
497
  # Weighted sum over stack slots
 
498
  memory_output = (new_stack * gate_weights.unsqueeze(-1)).sum(dim=3)
499
  memory_output = memory_output.view(batch_size, seq_len, -1)
500
+
 
501
  memory_output = self.up_proj(memory_output)
502
 
503
+ # Residual Connection
504
  output = memory_output * self.res_weight + hidden_states
505
 
506
+ # Update Cache Logic
507
+ if self.enable_cache:
508
+ self._update_cache(k_values.detach(), actions.detach())
509
+
510
  return output, new_stack[:, -1], new_mask[:, -1]
511
 
512
+ def _update_cache(self, k_values: torch.Tensor, actions: torch.Tensor):
513
+ seq_len = k_values.shape[1]
514
+ if self.cache_position + seq_len <= self.cache_size:
515
+ # Assumes standard batch processing for inference (usually batch_size=1)
516
+ self.k_cache[:, self.cache_position:self.cache_position+seq_len] = k_values
517
+ self.action_cache[:, self.cache_position:self.cache_position+seq_len] = actions
518
+ self.cache_position += seq_len
519
+ else:
520
+ self.reset_cache()
521
+
522
+ def step(self, hidden_state: torch.Tensor, stack: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
523
+ if not self.enable_cache:
524
+ return self.forward(hidden_state.unsqueeze(1), stack, mask)
525
+
526
+ batch_size = hidden_state.shape[0]
527
+
528
+ # Compute features for current token
529
+ new_hidden_states = self.down_proj(hidden_state)
530
+
531
+ action_logits = self.action_head(new_hidden_states) / math.sqrt(self.head_dim)
532
+ current_actions = F.softmax(
533
+ action_logits.view(batch_size, 1, self.num_stack_heads, 3),
534
+ dim=-1
535
+ )
536
+
537
+ current_k = new_hidden_states.view(batch_size, 1, self.num_stack_heads, self.head_dim)
538
+
539
+ # Reconstruct History
540
+ if self.cache_position > 0:
541
+ cached_k = self.k_cache[:, :self.cache_position]
542
+ cached_actions = self.action_cache[:, :self.cache_position]
543
+
544
+ k_values = torch.cat([cached_k, current_k], dim=1)
545
+ actions = torch.cat([cached_actions, current_actions], dim=1)
546
+ else:
547
+ k_values = current_k
548
+ actions = current_actions
549
+
550
+ # Dimension Fix: Pass sequences directly without unsqueeze(0)
551
+ # k_values is [batch, seq_len_total, heads, dim]
552
+ # actions is [batch, seq_len_total, heads, 3]
553
+
554
+ new_stack_seq, new_mask_seq = self._vectorized_update(
555
+ stack, # Initial stack [batch, heads, slots, dim]
556
+ mask,
557
+ actions,
558
+ k_values
559
+ )
560
+
561
+ # Extract last step
562
+ current_stack = new_stack_seq[:, -1]
563
+ current_mask = new_mask_seq[:, -1]
564
+
565
+ gate_scores = self.gate_proj(current_stack).squeeze(-1)
566
+ gate_weights = F.softmax(gate_scores + (1 - current_mask) * -1e9, dim=-1)
567
+
568
+ memory_output = (current_stack * gate_weights.unsqueeze(-1)).sum(dim=2)
569
+ memory_output = memory_output.view(batch_size, -1)
570
+
571
+ memory_output_proj = self.up_proj(memory_output)
572
+
573
+ self._update_cache(current_k, current_actions)
574
+
575
+ return (
576
+ memory_output_proj * self.res_weight + hidden_state,
577
+ current_stack,
578
+ current_mask
579
+ )
580
  # ==================== ROTARY EMBEDDING ====================
581
  class NeoLLMRotaryEmbedding(nn.Module):
582
  inv_freq: torch.Tensor # fix linting for `register_buffer`
 
1193
  output_hidden_states: Optional[bool] = None,
1194
  output_attentions: Optional[bool] = None,
1195
  return_dict: Optional[bool] = None,
1196
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1197
+ use_cache: Optional[bool] = None,
1198
  **kwargs: Unpack[TransformersKwargs],
1199
  ) -> BaseModelOutputWithPast:
1200
  output_hidden_states = (
 
1226
  )
1227
 
1228
  hidden_states = inputs_embeds
1229
+ next_decoder_cache = None
1230
  all_hidden_states = () if output_hidden_states else None
1231
  all_attentions = () if output_attentions else None
1232
 
 
1236
  # ResFormer with first-layer feature propagation
1237
  self.first_layer_fan = None
1238
 
1239
+ # Initialize Stack states (always None at start of forward, rebuilt via cache step or vertical flow)
1240
+ stack_state = None
1241
+ stack_mask = None
1242
+
1243
+ # Propagate use_cache and reset if starting a new sequence
1244
+ if self.use_stack:
1245
+ for layer in self.layers:
1246
+ if hasattr(layer, 'stack_memory'):
1247
+ layer.stack_memory.enable_cache = use_cache if use_cache is not None else False
1248
+ if past_key_values is None:
1249
+ layer.stack_memory.reset_cache()
1250
 
1251
  for decoder_layer in self.layers:
1252
  if output_hidden_states:
 
1269
  all_attentions = all_attentions + (layer_outputs[1],)
1270
 
1271
  if self.use_stack:
1272
+ # Vertical memory logic:
1273
+ # The layer returns updated stack for the next layer to use (Vertical passing)
1274
+ # But we do NOT persist it temporally here. The Module's internal cache handles temporal.
1275
  stack_state = layer_outputs[2]
1276
  stack_mask = layer_outputs[3]
1277
 
 
1285
 
1286
  if output_hidden_states:
1287
  all_hidden_states = all_hidden_states + (hidden_states,)
 
 
 
 
 
1288
 
1289
  if not return_dict:
1290
+ return tuple(v for v in [hidden_states, next_decoder_cache, all_hidden_states, all_attentions] if v is not None)
1291
 
1292
  return BaseModelOutputWithPast(
1293
  last_hidden_state=hidden_states,
1294
+ past_key_values=next_decoder_cache,
1295
  hidden_states=all_hidden_states,
1296
  attentions=all_attentions,
1297
  )
 
1349
  def prepare_inputs_for_generation(
1350
  self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1351
  ):
1352
+ if past_key_values:
1353
+ past_length = past_key_values[0][0].shape[2]
 
 
 
 
 
 
 
1354
 
1355
+ # If past_length > input_ids length, we are likely generating token by token
1356
+ if input_ids.shape[1] > past_length:
1357
+ remove_prefix_length = past_length
1358
+ else:
1359
+ # Default standard HF behavior
1360
+ remove_prefix_length = input_ids.shape[1] - 1
1361
+
1362
+ input_ids = input_ids[:, remove_prefix_length:]
1363
+
1364
+ position_ids = kwargs.get("position_ids", None)
1365
+ if attention_mask is not None and position_ids is None:
1366
+ # create position_ids on the fly for batch generation
1367
+ position_ids = attention_mask.long().cumsum(-1) - 1
1368
+ position_ids.masked_fill_(attention_mask == 0, 1)
1369
+ if past_key_values:
1370
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1371
+
1372
+ return {
1373
  "input_ids": input_ids,
1374
+ "past_key_values": past_key_values,
 
1375
  "use_cache": kwargs.get("use_cache"),
1376
+ "position_ids": position_ids,
1377
  "attention_mask": attention_mask,
1378
  "inputs_embeds": inputs_embeds,
1379
  }
 
1380
 
1381
  def forward(
1382
  self,
 
1388
  logits_to_keep: Union[int, torch.Tensor] = 0,
1389
  output_hidden_states: Optional[bool] = None,
1390
  return_dict: Optional[bool] = None,
1391
+
 
1392
  **kwargs: Unpack[TransformersKwargs],
1393
  ) -> CausalLMOutputWithPast:
1394
  outputs: BaseModelOutputWithPast = self.model(
 
1398
  inputs_embeds=inputs_embeds,
1399
  output_hidden_states=output_hidden_states,
1400
  return_dict=return_dict,
1401
+
 
1402
  **kwargs,
1403
  )
1404