Update modeling_neollm.py
Browse files- modeling_neollm.py +30 -28
modeling_neollm.py
CHANGED
|
@@ -4,7 +4,6 @@ NeoLLM Model with FANformer Integration in both Attention and FFN, Dropout Regul
|
|
| 4 |
SeeDNorm (Self-Rescaled Dynamic Normalization), ResFormer Value Residual Learning,
|
| 5 |
Learnable Multipliers for enhanced scale adaptation and information flow through deep layers,
|
| 6 |
and StackMemory for hierarchical pattern modeling.
|
| 7 |
-
|
| 8 |
Updated to include:
|
| 9 |
- Fourier Analysis Network (FAN) layer for effective periodicity modeling in attention (relational space)
|
| 10 |
- FAN layer in FFN for featural periodicity modeling (complementary coverage)
|
|
@@ -260,7 +259,6 @@ class SeeDNorm(nn.Module):
|
|
| 260 |
Self-Rescaled Dynamic Normalization (SeeDNorm) with dual dropout regularization.
|
| 261 |
|
| 262 |
SeeDNorm(x) = [σ(x·β^T)·α + γ] ⊙ x/RMS(x)
|
| 263 |
-
|
| 264 |
|
| 265 |
Args:
|
| 266 |
dim: Hidden dimension size
|
|
@@ -325,7 +323,6 @@ class SeeDNorm(nn.Module):
|
|
| 325 |
|
| 326 |
|
| 327 |
# ==================== STACK MEMORY MODULE ====================
|
| 328 |
-
|
| 329 |
class StackMemory(nn.Module):
|
| 330 |
"""
|
| 331 |
Differentiable Hidden State Stack for modeling Chomsky hierarchy grammars.
|
|
@@ -364,7 +361,7 @@ class StackMemory(nn.Module):
|
|
| 364 |
self.action_head = nn.Linear(self.stack_d_model, 3 * self.num_stack_heads, bias=True)
|
| 365 |
|
| 366 |
# Query projection for global reading (one per head)
|
| 367 |
-
self.gate_proj = nn.Linear(self.head_dim, 1, bias=
|
| 368 |
|
| 369 |
# Residual weight for gating stack contribution
|
| 370 |
self.res_weight = nn.Parameter(torch.ones(1))
|
|
@@ -479,19 +476,21 @@ class StackMemory(nn.Module):
|
|
| 479 |
new_stack, new_mask = self._vectorized_update(stack, mask, actions, k_values)
|
| 480 |
|
| 481 |
# Global reading via query-over-stack attention
|
| 482 |
-
# Apply mask before attention computation
|
| 483 |
-
masked_stack = new_stack * new_mask.unsqueeze(-1)
|
| 484 |
|
| 485 |
-
#
|
| 486 |
-
|
|
|
|
|
|
|
| 487 |
|
| 488 |
-
#
|
|
|
|
| 489 |
gate_scores = gate_scores + (1 - new_mask) * -1e9
|
| 490 |
|
| 491 |
# Softmax to get attention weights
|
| 492 |
gate_weights = F.softmax(gate_scores, dim=-1)
|
| 493 |
|
| 494 |
# Weighted sum over stack slots
|
|
|
|
| 495 |
memory_output = (new_stack * gate_weights.unsqueeze(-1)).sum(dim=3)
|
| 496 |
memory_output = memory_output.view(batch_size, seq_len, -1)
|
| 497 |
|
|
@@ -882,19 +881,18 @@ class NeoLLMMLP(nn.Module):
|
|
| 882 |
hidden = self.dropout(hidden)
|
| 883 |
return self.down_proj(hidden)
|
| 884 |
|
| 885 |
-
|
| 886 |
class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
| 887 |
"""
|
| 888 |
Decoder layer with standard residual connections and optional StackMemory.
|
| 889 |
|
| 890 |
-
Architecture:
|
| 891 |
-
1.
|
| 892 |
-
2.
|
| 893 |
-
3.
|
| 894 |
-
4.
|
| 895 |
-
5.
|
| 896 |
-
6.
|
| 897 |
-
7.
|
| 898 |
"""
|
| 899 |
|
| 900 |
def __init__(self, config: NeoLLMConfig, layer_idx: int):
|
|
@@ -954,8 +952,19 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 954 |
Returns:
|
| 955 |
Tuple of (hidden_states, attn_weights, stack_state, stack_mask)
|
| 956 |
"""
|
|
|
|
| 957 |
# ============================================================
|
| 958 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 959 |
# ============================================================
|
| 960 |
residual = hidden_states
|
| 961 |
|
|
@@ -981,7 +990,7 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 981 |
hidden_states = self.gpas_attn(hidden_states)
|
| 982 |
|
| 983 |
# ============================================================
|
| 984 |
-
# MLP Block with Standard Residual Connection
|
| 985 |
# ============================================================
|
| 986 |
residual = hidden_states
|
| 987 |
hidden_states = self.post_attention_layernorm(hidden_states)
|
|
@@ -998,14 +1007,7 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 998 |
# Apply GPAS after residual connection
|
| 999 |
hidden_states = self.gpas_mlp(hidden_states)
|
| 1000 |
|
| 1001 |
-
#
|
| 1002 |
-
# Stack Memory Module
|
| 1003 |
-
# ============================================================
|
| 1004 |
-
if self.use_stack:
|
| 1005 |
-
hidden_states, stack_state, stack_mask = self.stack_memory(
|
| 1006 |
-
hidden_states, stack_state, stack_mask
|
| 1007 |
-
)
|
| 1008 |
-
|
| 1009 |
if self.use_stack:
|
| 1010 |
return (hidden_states, attn_weights, stack_state, stack_mask)
|
| 1011 |
else:
|
|
|
|
| 4 |
SeeDNorm (Self-Rescaled Dynamic Normalization), ResFormer Value Residual Learning,
|
| 5 |
Learnable Multipliers for enhanced scale adaptation and information flow through deep layers,
|
| 6 |
and StackMemory for hierarchical pattern modeling.
|
|
|
|
| 7 |
Updated to include:
|
| 8 |
- Fourier Analysis Network (FAN) layer for effective periodicity modeling in attention (relational space)
|
| 9 |
- FAN layer in FFN for featural periodicity modeling (complementary coverage)
|
|
|
|
| 259 |
Self-Rescaled Dynamic Normalization (SeeDNorm) with dual dropout regularization.
|
| 260 |
|
| 261 |
SeeDNorm(x) = [σ(x·β^T)·α + γ] ⊙ x/RMS(x)
|
|
|
|
| 262 |
|
| 263 |
Args:
|
| 264 |
dim: Hidden dimension size
|
|
|
|
| 323 |
|
| 324 |
|
| 325 |
# ==================== STACK MEMORY MODULE ====================
|
|
|
|
| 326 |
class StackMemory(nn.Module):
|
| 327 |
"""
|
| 328 |
Differentiable Hidden State Stack for modeling Chomsky hierarchy grammars.
|
|
|
|
| 361 |
self.action_head = nn.Linear(self.stack_d_model, 3 * self.num_stack_heads, bias=True)
|
| 362 |
|
| 363 |
# Query projection for global reading (one per head)
|
| 364 |
+
self.gate_proj = nn.Linear(self.head_dim, 1, bias=True)
|
| 365 |
|
| 366 |
# Residual weight for gating stack contribution
|
| 367 |
self.res_weight = nn.Parameter(torch.ones(1))
|
|
|
|
| 476 |
new_stack, new_mask = self._vectorized_update(stack, mask, actions, k_values)
|
| 477 |
|
| 478 |
# Global reading via query-over-stack attention
|
|
|
|
|
|
|
| 479 |
|
| 480 |
+
# FIX: Project the raw stack content directly.
|
| 481 |
+
# Previously, masking before projection killed gradients for "empty" slots
|
| 482 |
+
# preventing them from ever becoming "full".
|
| 483 |
+
gate_scores = self.gate_proj(new_stack).squeeze(-1) # [batch, seq, heads, slots]
|
| 484 |
|
| 485 |
+
# Apply mask to the SCORES, not the features.
|
| 486 |
+
# Mask out invalid positions (add large negative value where mask is 0)
|
| 487 |
gate_scores = gate_scores + (1 - new_mask) * -1e9
|
| 488 |
|
| 489 |
# Softmax to get attention weights
|
| 490 |
gate_weights = F.softmax(gate_scores, dim=-1)
|
| 491 |
|
| 492 |
# Weighted sum over stack slots
|
| 493 |
+
# new_stack contains the features, gate_weights contains the validity/relevance
|
| 494 |
memory_output = (new_stack * gate_weights.unsqueeze(-1)).sum(dim=3)
|
| 495 |
memory_output = memory_output.view(batch_size, seq_len, -1)
|
| 496 |
|
|
|
|
| 881 |
hidden = self.dropout(hidden)
|
| 882 |
return self.down_proj(hidden)
|
| 883 |
|
|
|
|
| 884 |
class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
| 885 |
"""
|
| 886 |
Decoder layer with standard residual connections and optional StackMemory.
|
| 887 |
|
| 888 |
+
Architecture (Updated Flow):
|
| 889 |
+
1. Optional: StackMemory module (Pre-processing context injection)
|
| 890 |
+
2. Pre-norm (SeeDNorm) → LNS scaling → Self-Attention with ResFormer and Learnable Multipliers
|
| 891 |
+
3. Standard Residual Connection
|
| 892 |
+
4. GPAS activation scaling
|
| 893 |
+
5. Pre-norm (SeeDNorm) → LNS scaling → MLP with FANformer and Learnable Multipliers
|
| 894 |
+
6. Standard Residual Connection
|
| 895 |
+
7. GPAS activation scaling
|
| 896 |
"""
|
| 897 |
|
| 898 |
def __init__(self, config: NeoLLMConfig, layer_idx: int):
|
|
|
|
| 952 |
Returns:
|
| 953 |
Tuple of (hidden_states, attn_weights, stack_state, stack_mask)
|
| 954 |
"""
|
| 955 |
+
|
| 956 |
# ============================================================
|
| 957 |
+
# 1. Stack Memory Module (MOVED TO START)
|
| 958 |
+
# ============================================================
|
| 959 |
+
# We process memory first so the Attention layer can "see" the
|
| 960 |
+
# retrieved context. This eliminates the 1-layer lag.
|
| 961 |
+
if self.use_stack:
|
| 962 |
+
hidden_states, stack_state, stack_mask = self.stack_memory(
|
| 963 |
+
hidden_states, stack_state, stack_mask
|
| 964 |
+
)
|
| 965 |
+
|
| 966 |
+
# ============================================================
|
| 967 |
+
# 2. Attention Block with Standard Residual Connection
|
| 968 |
# ============================================================
|
| 969 |
residual = hidden_states
|
| 970 |
|
|
|
|
| 990 |
hidden_states = self.gpas_attn(hidden_states)
|
| 991 |
|
| 992 |
# ============================================================
|
| 993 |
+
# 3. MLP Block with Standard Residual Connection
|
| 994 |
# ============================================================
|
| 995 |
residual = hidden_states
|
| 996 |
hidden_states = self.post_attention_layernorm(hidden_states)
|
|
|
|
| 1007 |
# Apply GPAS after residual connection
|
| 1008 |
hidden_states = self.gpas_mlp(hidden_states)
|
| 1009 |
|
| 1010 |
+
# Return tuple matching the expected signature
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1011 |
if self.use_stack:
|
| 1012 |
return (hidden_states, attn_weights, stack_state, stack_mask)
|
| 1013 |
else:
|