KitsuVp committed on
Commit
470f299
·
verified ·
1 Parent(s): 3e2ca3a

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +30 -28
modeling_neollm.py CHANGED
@@ -4,7 +4,6 @@ NeoLLM Model with FANformer Integration in both Attention and FFN, Dropout Regul
4
  SeeDNorm (Self-Rescaled Dynamic Normalization), ResFormer Value Residual Learning,
5
  Learnable Multipliers for enhanced scale adaptation and information flow through deep layers,
6
  and StackMemory for hierarchical pattern modeling.
7
-
8
  Updated to include:
9
  - Fourier Analysis Network (FAN) layer for effective periodicity modeling in attention (relational space)
10
  - FAN layer in FFN for featural periodicity modeling (complementary coverage)
@@ -260,7 +259,6 @@ class SeeDNorm(nn.Module):
260
  Self-Rescaled Dynamic Normalization (SeeDNorm) with dual dropout regularization.
261
 
262
  SeeDNorm(x) = [σ(x·β^T)·α + γ] ⊙ x/RMS(x)
263
-
264
 
265
  Args:
266
  dim: Hidden dimension size
@@ -325,7 +323,6 @@ class SeeDNorm(nn.Module):
325
 
326
 
327
  # ==================== STACK MEMORY MODULE ====================
328
-
329
  class StackMemory(nn.Module):
330
  """
331
  Differentiable Hidden State Stack for modeling Chomsky hierarchy grammars.
@@ -364,7 +361,7 @@ class StackMemory(nn.Module):
364
  self.action_head = nn.Linear(self.stack_d_model, 3 * self.num_stack_heads, bias=True)
365
 
366
  # Query projection for global reading (one per head)
367
- self.gate_proj = nn.Linear(self.head_dim, 1, bias=False)
368
 
369
  # Residual weight for gating stack contribution
370
  self.res_weight = nn.Parameter(torch.ones(1))
@@ -479,19 +476,21 @@ class StackMemory(nn.Module):
479
  new_stack, new_mask = self._vectorized_update(stack, mask, actions, k_values)
480
 
481
  # Global reading via query-over-stack attention
482
- # Apply mask before attention computation
483
- masked_stack = new_stack * new_mask.unsqueeze(-1)
484
 
485
- # Compute attention scores for each head
486
- gate_scores = self.gate_proj(masked_stack).squeeze(-1) # [batch, seq, heads, slots]
 
 
487
 
488
- # Mask out invalid positions (add large negative value)
 
489
  gate_scores = gate_scores + (1 - new_mask) * -1e9
490
 
491
  # Softmax to get attention weights
492
  gate_weights = F.softmax(gate_scores, dim=-1)
493
 
494
  # Weighted sum over stack slots
 
495
  memory_output = (new_stack * gate_weights.unsqueeze(-1)).sum(dim=3)
496
  memory_output = memory_output.view(batch_size, seq_len, -1)
497
 
@@ -882,19 +881,18 @@ class NeoLLMMLP(nn.Module):
882
  hidden = self.dropout(hidden)
883
  return self.down_proj(hidden)
884
 
885
-
886
  class NeoLLMDecoderLayer(GradientCheckpointingLayer):
887
  """
888
  Decoder layer with standard residual connections and optional StackMemory.
889
 
890
- Architecture:
891
- 1. Pre-norm (SeeDNorm) LNS scaling → Self-Attention with ResFormer and Learnable Multipliers
892
- 2. Standard Residual Connection
893
- 3. GPAS activation scaling
894
- 4. Pre-norm (SeeDNorm) → LNS scaling → MLP with FANformer and Learnable Multipliers
895
- 5. Standard Residual Connection
896
- 6. GPAS activation scaling
897
- 7. Optional: StackMemory module
898
  """
899
 
900
  def __init__(self, config: NeoLLMConfig, layer_idx: int):
@@ -954,8 +952,19 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
954
  Returns:
955
  Tuple of (hidden_states, attn_weights, stack_state, stack_mask)
956
  """
 
957
  # ============================================================
958
- # Attention Block with Standard Residual Connection
 
 
 
 
 
 
 
 
 
 
959
  # ============================================================
960
  residual = hidden_states
961
 
@@ -981,7 +990,7 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
981
  hidden_states = self.gpas_attn(hidden_states)
982
 
983
  # ============================================================
984
- # MLP Block with Standard Residual Connection
985
  # ============================================================
986
  residual = hidden_states
987
  hidden_states = self.post_attention_layernorm(hidden_states)
@@ -998,14 +1007,7 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
998
  # Apply GPAS after residual connection
999
  hidden_states = self.gpas_mlp(hidden_states)
1000
 
1001
- # ============================================================
1002
- # Stack Memory Module
1003
- # ============================================================
1004
- if self.use_stack:
1005
- hidden_states, stack_state, stack_mask = self.stack_memory(
1006
- hidden_states, stack_state, stack_mask
1007
- )
1008
-
1009
  if self.use_stack:
1010
  return (hidden_states, attn_weights, stack_state, stack_mask)
1011
  else:
 
4
  SeeDNorm (Self-Rescaled Dynamic Normalization), ResFormer Value Residual Learning,
5
  Learnable Multipliers for enhanced scale adaptation and information flow through deep layers,
6
  and StackMemory for hierarchical pattern modeling.
 
7
  Updated to include:
8
  - Fourier Analysis Network (FAN) layer for effective periodicity modeling in attention (relational space)
9
  - FAN layer in FFN for featural periodicity modeling (complementary coverage)
 
259
  Self-Rescaled Dynamic Normalization (SeeDNorm) with dual dropout regularization.
260
 
261
  SeeDNorm(x) = [σ(x·β^T)·α + γ] ⊙ x/RMS(x)
 
262
 
263
  Args:
264
  dim: Hidden dimension size
 
323
 
324
 
325
  # ==================== STACK MEMORY MODULE ====================
 
326
  class StackMemory(nn.Module):
327
  """
328
  Differentiable Hidden State Stack for modeling Chomsky hierarchy grammars.
 
361
  self.action_head = nn.Linear(self.stack_d_model, 3 * self.num_stack_heads, bias=True)
362
 
363
  # Query projection for global reading (one per head)
364
+ self.gate_proj = nn.Linear(self.head_dim, 1, bias=True)
365
 
366
  # Residual weight for gating stack contribution
367
  self.res_weight = nn.Parameter(torch.ones(1))
 
476
  new_stack, new_mask = self._vectorized_update(stack, mask, actions, k_values)
477
 
478
  # Global reading via query-over-stack attention
 
 
479
 
480
+ # FIX: Project the raw stack content directly.
481
+ # Previously, masking before projection killed gradients for "empty" slots
482
+ # preventing them from ever becoming "full".
483
+ gate_scores = self.gate_proj(new_stack).squeeze(-1) # [batch, seq, heads, slots]
484
 
485
+ # Apply mask to the SCORES, not the features.
486
+ # Mask out invalid positions (add large negative value where mask is 0)
487
  gate_scores = gate_scores + (1 - new_mask) * -1e9
488
 
489
  # Softmax to get attention weights
490
  gate_weights = F.softmax(gate_scores, dim=-1)
491
 
492
  # Weighted sum over stack slots
493
+ # new_stack contains the features, gate_weights contains the validity/relevance
494
  memory_output = (new_stack * gate_weights.unsqueeze(-1)).sum(dim=3)
495
  memory_output = memory_output.view(batch_size, seq_len, -1)
496
 
 
881
  hidden = self.dropout(hidden)
882
  return self.down_proj(hidden)
883
 
 
884
  class NeoLLMDecoderLayer(GradientCheckpointingLayer):
885
  """
886
  Decoder layer with standard residual connections and optional StackMemory.
887
 
888
+ Architecture (Updated Flow):
889
+ 1. Optional: StackMemory module (Pre-processing context injection)
890
+ 2. Pre-norm (SeeDNorm) → LNS scaling → Self-Attention with ResFormer and Learnable Multipliers
891
+ 3. Standard Residual Connection
892
+ 4. GPAS activation scaling
893
+ 5. Pre-norm (SeeDNorm) → LNS scaling → MLP with FANformer and Learnable Multipliers
894
+ 6. Standard Residual Connection
895
+ 7. GPAS activation scaling
896
  """
897
 
898
  def __init__(self, config: NeoLLMConfig, layer_idx: int):
 
952
  Returns:
953
  Tuple of (hidden_states, attn_weights, stack_state, stack_mask)
954
  """
955
+
956
  # ============================================================
957
+ # 1. Stack Memory Module (MOVED TO START)
958
+ # ============================================================
959
+ # We process memory first so the Attention layer can "see" the
960
+ # retrieved context. This eliminates the 1-layer lag.
961
+ if self.use_stack:
962
+ hidden_states, stack_state, stack_mask = self.stack_memory(
963
+ hidden_states, stack_state, stack_mask
964
+ )
965
+
966
+ # ============================================================
967
+ # 2. Attention Block with Standard Residual Connection
968
  # ============================================================
969
  residual = hidden_states
970
 
 
990
  hidden_states = self.gpas_attn(hidden_states)
991
 
992
  # ============================================================
993
+ # 3. MLP Block with Standard Residual Connection
994
  # ============================================================
995
  residual = hidden_states
996
  hidden_states = self.post_attention_layernorm(hidden_states)
 
1007
  # Apply GPAS after residual connection
1008
  hidden_states = self.gpas_mlp(hidden_states)
1009
 
1010
+ # Return tuple matching the expected signature
 
 
 
 
 
 
 
1011
  if self.use_stack:
1012
  return (hidden_states, attn_weights, stack_state, stack_mask)
1013
  else: