#!/usr/bin/env python3 """ NeoLLM Model with FANformer Integration in both Attention and FFN, Dropout Regularization, SeeDNorm (Self-Rescaled Dynamic Normalization), ResFormer Value Residual Learning, Learnable Multipliers for enhanced scale adaptation and information flow through deep layers, and StackMemory for hierarchical pattern modeling. Updated to include: - Fourier Analysis Network (FAN) layer for effective periodicity modeling in attention (relational space) - FAN layer in FFN for featural periodicity modeling (complementary coverage) - SeeDNorm: Dynamic normalization with input-dependent scaling for better adaptability - Dropout regularization at strategic locations - ResFormer: Feature residual connections from first layer (applied before projections) - Learnable Multipliers: Frees weight matrix scale from WD-noise equilibrium for data-adaptive scaling - StackMemory: Differentiable hidden state stack for modeling Chomsky hierarchy grammars - Full Attention only (linear attention removed) """ import math from typing import Any, Callable, Optional, Union, Tuple, List import torch import torch.nn.functional as F from torch import nn from cut_cross_entropy import linear_cross_entropy import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from typing import Optional, Tuple from transformers.activations import ACT2FN from transformers.generation import GenerationMixin from transformers.masking_utils import create_causal_mask from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_layers import GradientCheckpointingLayer from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from transformers.processing_utils import Unpack from transformers.utils import TransformersKwargs, logging from transformers.utils.generic import check_model_inputs from configuration_neollm import NeoLLMConfig from transformers import AutoConfig, AutoModel, AutoModelForCausalLM logger = logging.get_logger(__name__) # ==================== LEARNABLE MULTIPLIERS ==================== class ScalarMultiplier(nn.Module): """ Scalar Learnable Multiplier: W̃ = s·W From "Learnable Multipliers: Freeing the Scale of Language Model Matrix Layers": Allows the effective matrix norm ||W̃|| = s·||W|| to adapt to data, escaping the WD-noise equilibrium that constrains ||W|| ∝ √(η/λ). Args: initial_value: Initial multiplier value (default: 1.0 for identity) """ def __init__(self, initial_value: float = 1.0): super().__init__() self.multiplier = nn.Parameter(torch.tensor(initial_value)) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.multiplier * x class VectorMultiplier(nn.Module): """ Vector Learnable Multipliers: W̃ = diag(r)·W·diag(c) From "Learnable Multipliers: Freeing the Scale of Language Model Matrix Layers": Frees not only the overall matrix norm but also individual row/column norms from the WD-noise equilibrium, enabling richer feature scale diversity. Args: dim: Dimension size for the multiplier vector multiplier_type: Either "row" or "column" initial_value: Initial multiplier value (default: 1.0) """ def __init__(self, dim: int, multiplier_type: str = "row", initial_value: float = 1.0): super().__init__() self.multiplier_type = multiplier_type self.multiplier = nn.Parameter(torch.ones(dim) * initial_value) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Apply row or column multiplier. For row multipliers: x shape is (batch, seq, out_features) or (batch, heads, seq, head_dim) For column multipliers: applied before matrix multiplication """ if self.multiplier_type == "row": # Broadcast along the last dimension (output features) return x * self.multiplier else: # column # For column multipliers, typically applied before linear layer return x * self.multiplier class LinearWithMultipliers(nn.Module): """ Linear layer with optional row and/or column learnable multipliers. Implements: y = (r ⊙ (W @ (c ⊙ x))) + b where r and c are learnable multipliers, W is the base weight matrix. From "Learnable Multipliers: Freeing the Scale of Language Model Matrix Layers": The base matrix W remains subject to WD-noise equilibrium with ||W|| ∝ √(η/λ), while multipliers r,c learn freely to adapt the effective scale to data. Args: in_features: Input feature dimension out_features: Output feature dimension bias: Whether to include bias term use_row_multiplier: Enable row (output) multipliers use_column_multiplier: Enable column (input) multipliers """ def __init__( self, in_features: int, out_features: int, bias: bool = True, use_row_multiplier: bool = False, use_column_multiplier: bool = False ): super().__init__() # Base weight matrix (subject to WD) self.linear = nn.Linear(in_features, out_features, bias=bias) # Learnable multipliers (NOT subject to WD) self.use_row_multiplier = use_row_multiplier self.use_column_multiplier = use_column_multiplier if use_row_multiplier: self.row_multiplier = VectorMultiplier(out_features, multiplier_type="row") if use_column_multiplier: self.column_multiplier = VectorMultiplier(in_features, multiplier_type="column") def forward(self, x: torch.Tensor) -> torch.Tensor: # Apply column multiplier before linear transformation if self.use_column_multiplier: x = self.column_multiplier(x) # Linear transformation with base weights x = self.linear(x) # Apply row multiplier after linear transformation if self.use_row_multiplier: x = self.row_multiplier(x) return x # ==================== ORIGINAL COMPONENTS ==================== class FANLayer(nn.Module): """ Fourier Analysis Network (FAN) layer for effective periodicity modeling. From "FANformer: Improving Large Language Models Through Effective Periodicity Modeling": FANLayer'(X) = [cos(WpX)||sin(WpX)||(Wp¯X + Bp¯)] This is the modified version (FANLayer') without activation function that gave the best results in the paper. """ def __init__(self, hidden_size: int, fan_ratio: float = 0.25): super().__init__() self.hidden_size = hidden_size self.fan_ratio = fan_ratio # Calculate dimensions following the paper's approach # Output will be: [cos(p) || sin(p) || g] where total = hidden_size + periodic_dim output_dim = hidden_size + int(hidden_size * fan_ratio) self.p_output_dim = int(output_dim * fan_ratio) self.g_output_dim = output_dim - self.p_output_dim * 2 # Single fused projection (more efficient than two separate projections) self.input_linear = nn.Linear( hidden_size, self.p_output_dim + self.g_output_dim, bias=True ) # Initialize parameters self._init_weights() def _init_weights(self): """Initialize weights following the paper's recommendations.""" nn.init.normal_(self.input_linear.weight, mean=0.0, std=0.02) if self.input_linear.bias is not None: nn.init.zeros_(self.input_linear.bias) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Apply Fourier transformation to input. Args: x: Input tensor of shape (batch, seq_len, hidden_size) Returns: Transformed tensor with Fourier components concatenated Shape: (batch, seq_len, hidden_size + periodic_dim) """ # Single projection followed by split (more efficient) pg = self.input_linear(x) p, g = torch.split(pg, [self.p_output_dim, self.g_output_dim], dim=-1) # Concatenate all components: [cos(WpX) || sin(WpX) || (Wp¯X + Bp¯)] x_fan = torch.cat([torch.cos(p), torch.sin(p), g], dim=-1) return x_fan class LNS(nn.Module): """ LayerNorm Scaling (LNS) - applies scaling factor 1/√ℓ as described in the paper. From "The Curse of Depth in Large Language Models": h^(ℓ) = LayerNorm(h^(ℓ)) × (1/√ℓ) This prevents exponential variance growth in deeper layers. """ def __init__(self, layer_idx: int): super().__init__() # Layer 1 gets index 1, layer 2 gets index 2, etc. # Avoid division by zero for layer 0 self.layer_idx = max(layer_idx + 1, 1) # +1 because layer_idx starts from 0 self.scale = 1.0 / math.sqrt(self.layer_idx) def forward(self, x: torch.Tensor) -> torch.Tensor: return x * self.scale class GPAS(nn.Module): """ Gradient-Preserving Activation Scaling (GPAS) Scales activations without penalizing gradients using stop-gradient. Applied in Pre-Norm style: after sub-layer output but before residual sum. """ def __init__(self, d_model: int): super().__init__() self.d_model = d_model self.alpha = nn.Parameter(torch.zeros(1)) def forward(self, x: torch.Tensor) -> torch.Tensor: x_detached = x.detach() scaled_component = F.silu(self.alpha) * x_detached x_scaled = x - scaled_component return x_scaled class SeeDNorm(nn.Module): """ Self-Rescaled Dynamic Normalization (SeeDNorm) with dual dropout regularization. SeeDNorm(x) = [σ(x·β^T)·α + γ] ⊙ x/RMS(x) Args: dim: Hidden dimension size eps: Small constant for numerical stability dropout_input: Dropout on input features for dynamic mechanism (default: 0.0) dropout_hidden: Dropout on normalized hidden states (default: 0.0) """ def __init__( self, dim: int, eps: float = 1e-6, dropout_input: float = 0.01, dropout_hidden: float = 0.01, ): super().__init__() self.dim = dim self.eps = eps self.dropout_input = dropout_input self.dropout_hidden = dropout_hidden # Learnable parameters self.gamma = nn.Parameter(torch.ones(dim)) # γ: static scaling self.beta = nn.Parameter(torch.zeros(dim)) # β: self-rescaling self.alpha = nn.Parameter(torch.ones(dim)) # α: dynamic modulation def _rms_norm(self, x: torch.Tensor) -> torch.Tensor: """Compute RMS normalization: x / RMS(x)""" return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Apply Self-Rescaled Dynamic Normalization with dual dropout. Args: x: Input tensor of shape (..., dim) Returns: Normalized and dynamically scaled tensor of same shape """ x_for_dynamic = F.dropout(x, p=self.dropout_input, training=self.training) rescale_factor = torch.tanh(torch.sum(x_for_dynamic * self.beta, dim=-1, keepdim=True)) # Compute dynamic scaling coefficient: σ(x·β^T)·α + γ dynamic_scale = rescale_factor * self.alpha + self.gamma # Apply RMS normalization on ORIGINAL input (not dropped version) x_normalized = self._rms_norm(x.float()) x_normalized = F.dropout(x_normalized, p=self.dropout_hidden, training=self.training) # Apply dynamic scaling output = x_normalized * dynamic_scale.float() return output.type_as(x) def extra_repr(self) -> str: return (f"dim={self.dim}, eps={self.eps}, " f"dropout_input={self.dropout_input}, dropout_hidden={self.dropout_hidden}") # ==================== STACK MEMORY MODULE ==================== class StackMemory(nn.Module): """ From "Improving Formal Reasoning of Transformer with State Stack": Implements a multi-head differentiable stack with soft push, pop, and no-op operations. Each head maintains its own stack and mask, which are updated based on learned action probabilities. Global reading is performed via query-over-stack attention. This module is inserted between Transformer layers to augment information flow with stack-like memory operations, enabling the model to better capture hierarchical and recursive patterns characteristic of regular expressions and context-free grammars. Note: StackMemory uses standard nn.Linear to maintain architectural independence and avoid introducing additional complexity in the memory operations. Args: config: Model configuration containing stack-related hyperparameters """ def __init__(self, config: NeoLLMConfig): super().__init__() self.config = config self.num_stack_heads = getattr(config, 'num_stack_heads', 4) self.stack_slots = getattr(config, 'stack_slots', 24) self.stack_d_model = getattr(config, 'stack_d_model', 128) self.head_dim = self.stack_d_model // self.num_stack_heads # Dimension reduction projections for efficiency # Uses standard nn.Linear self.down_proj = nn.Linear(config.hidden_size, self.stack_d_model, bias=True) self.up_proj = nn.Linear(self.stack_d_model, config.hidden_size, bias=True) # Action prediction: generates push/pop/no-op probabilities for each head self.action_head = nn.Linear(self.stack_d_model, 3 * self.num_stack_heads, bias=True) # Query projection for global reading (one per head) self.gate_proj = nn.Linear(self.head_dim, 1, bias=True) # Residual weight for gating stack contribution self.res_weight = nn.Parameter(torch.ones(1)) # Cache for autoregressive generation (matches OLMo reference) self.cache_size = getattr(config, "cache_size", 2048) # Initialization fix: Register buffers for cache # Default to batch_size=1 if forward_bs is not in config (standard inference) forward_bs = getattr(config, 'forward_bs', 1) self.register_buffer("k_cache", torch.zeros(forward_bs, self.cache_size, self.num_stack_heads, self.head_dim)) self.register_buffer("action_cache", torch.zeros(forward_bs, self.cache_size, self.num_stack_heads, 3)) self.cache_position = 0 self.enable_cache = False def reset_cache(self): self.cache_position = 0 def _vectorized_update( self, stack: torch.Tensor, mask: torch.Tensor, actions: torch.Tensor, k_values: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """ Vectorized stack update mechanism applying soft push/pop/no-op operations. Implements the differentiable stack operations from the paper: - Push: shifts all elements down and places k_values at top - Pop: shifts all elements up and removes top - No-op: maintains current stack state Args: stack: Current stack state [batch, seq, num_heads, stack_slots, head_dim] mask: Current stack mask [batch, seq, num_heads, stack_slots] actions: Action probabilities [batch, seq, num_heads, 3] (push/pop/no-op) k_values: New values to push [batch, seq, num_heads, head_dim] Returns: Tuple of (updated_stack, updated_mask) """ batch_size, seq_len = actions.shape[:2] # Expand stack and mask along sequence dimension for parallel processing # Only expand if checking against initial state dimensions (4D) if stack.dim() == 4: stack = stack.unsqueeze(1).expand(-1, seq_len, -1, -1, -1) mask = mask.unsqueeze(1).expand(-1, seq_len, -1, -1) # Generate pushed stack: new value at top, shift others down push_stack = torch.cat([ k_values.unsqueeze(3), # New value at position 0 stack[:, :, :, :-1] # Shift existing elements down ], dim=3) push_mask = torch.cat([ torch.ones_like(mask[:, :, :, :1]), mask[:, :, :, :-1] ], dim=3) # Generate popped stack: shift all up, zero at bottom pop_stack = torch.cat([ stack[:, :, :, 1:], torch.zeros_like(stack[:, :, :, :1]) ], dim=3) pop_mask = torch.cat([ mask[:, :, :, 1:], torch.zeros_like(mask[:, :, :, :1]) ], dim=3) # Combine operations weighted by action probabilities action_weights = actions.unsqueeze(-1).unsqueeze(-1) # [batch, seq, heads, 3, 1, 1] stacks = torch.stack([push_stack, pop_stack, stack], dim=3) # [batch, seq, heads, 3, slots, dim] masks = torch.stack([push_mask, pop_mask, mask], dim=3) # [batch, seq, heads, 3, slots] # Weighted combination of all operations new_stack = (stacks * action_weights).sum(dim=3) new_mask = (masks * action_weights.squeeze(-1)).sum(dim=3) return new_stack, new_mask def forward( self, hidden_states: torch.Tensor, stack: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Apply differentiable stack operations to hidden states. Args: hidden_states: Input hidden states [batch, seq, hidden_size] stack: Previous stack state [batch, num_heads, stack_slots, head_dim] or None mask: Previous stack mask [batch, num_heads, stack_slots] or None Returns: Tuple of (output_hidden_states, updated_stack, updated_mask) """ batch_size, seq_len, _ = hidden_states.shape device = hidden_states.device # Initialize stack and mask if not provided if stack is None: stack = torch.zeros( batch_size, self.num_stack_heads, self.stack_slots, self.head_dim, device=device, dtype=hidden_states.dtype ) if mask is None: mask = torch.zeros( batch_size, self.num_stack_heads, self.stack_slots, device=device, dtype=hidden_states.dtype ) # Project to lower dimension for efficiency new_hidden_states = self.down_proj(hidden_states) # Generate action probabilities: [batch, seq, num_heads, 3] action_logits = self.action_head(new_hidden_states) / math.sqrt(self.head_dim) actions = F.softmax( action_logits.view(batch_size, seq_len, self.num_stack_heads, 3), dim=-1 ) # Prepare values to push (split into heads) k_values = new_hidden_states.view(batch_size, seq_len, self.num_stack_heads, self.head_dim) # Update stack and mask using vectorized operations new_stack, new_mask = self._vectorized_update(stack, mask, actions, k_values) # Global reading via query-over-stack attention gate_scores = self.gate_proj(new_stack).squeeze(-1) # [batch, seq, heads, slots] gate_weights = F.softmax(gate_scores + (1 - new_mask) * -1e9, dim=-1) # Weighted sum over stack slots memory_output = (new_stack * gate_weights.unsqueeze(-1)).sum(dim=3) memory_output = memory_output.view(batch_size, seq_len, -1) memory_output = self.up_proj(memory_output) # Residual Connection output = memory_output * self.res_weight + hidden_states # Update Cache Logic if self.enable_cache: self._update_cache(k_values.detach(), actions.detach()) return output, new_stack[:, -1], new_mask[:, -1] def _update_cache(self, k_values: torch.Tensor, actions: torch.Tensor): seq_len = k_values.shape[1] if self.cache_position + seq_len <= self.cache_size: # Assumes standard batch processing for inference (usually batch_size=1) self.k_cache[:, self.cache_position:self.cache_position+seq_len] = k_values self.action_cache[:, self.cache_position:self.cache_position+seq_len] = actions self.cache_position += seq_len else: self.reset_cache() def step(self, hidden_state: torch.Tensor, stack: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if not self.enable_cache: return self.forward(hidden_state.unsqueeze(1), stack, mask) batch_size = hidden_state.shape[0] # Compute features for current token new_hidden_states = self.down_proj(hidden_state) action_logits = self.action_head(new_hidden_states) / math.sqrt(self.head_dim) current_actions = F.softmax( action_logits.view(batch_size, 1, self.num_stack_heads, 3), dim=-1 ) current_k = new_hidden_states.view(batch_size, 1, self.num_stack_heads, self.head_dim) # Reconstruct History if self.cache_position > 0: cached_k = self.k_cache[:, :self.cache_position] cached_actions = self.action_cache[:, :self.cache_position] k_values = torch.cat([cached_k, current_k], dim=1) actions = torch.cat([cached_actions, current_actions], dim=1) else: k_values = current_k actions = current_actions # Dimension Fix: Pass sequences directly without unsqueeze(0) # k_values is [batch, seq_len_total, heads, dim] # actions is [batch, seq_len_total, heads, 3] new_stack_seq, new_mask_seq = self._vectorized_update( stack, # Initial stack [batch, heads, slots, dim] mask, actions, k_values ) # Extract last step current_stack = new_stack_seq[:, -1] current_mask = new_mask_seq[:, -1] gate_scores = self.gate_proj(current_stack).squeeze(-1) gate_weights = F.softmax(gate_scores + (1 - current_mask) * -1e9, dim=-1) memory_output = (current_stack * gate_weights.unsqueeze(-1)).sum(dim=2) memory_output = memory_output.view(batch_size, -1) memory_output_proj = self.up_proj(memory_output) self._update_cache(current_k, current_actions) return ( memory_output_proj * self.res_weight + hidden_state, current_stack, current_mask ) # ==================== ROTARY EMBEDDING ==================== class NeoLLMRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` def __init__(self, config: NeoLLMConfig, device=None): super().__init__() self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config # Determine rope_type from rope_scaling config self.rope_type = "default" if hasattr(config, "rope_scaling") and config.rope_scaling is not None and isinstance(config.rope_scaling, dict): rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) if rope_type and rope_type in ROPE_INIT_FUNCTIONS: self.rope_type = rope_type # Initialize rope parameters rope_init_fn = self.compute_default_rope_parameters if self.rope_type != "default": rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) @staticmethod def compute_default_rope_parameters( config: NeoLLMConfig = None, device: Optional["torch.device"] = None, seq_len: int = None, ) -> tuple["torch.Tensor", float]: """ Computes the inverse frequencies according to the original RoPE implementation Args: config: The model configuration. device: The device to use for initialization of the inverse frequencies. seq_len: The current sequence length. Unused for this type of RoPE. Returns: Tuple of (torch.Tensor, float), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin. """ base = config.rope_theta dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0) dim = int(dim * partial_rotary_factor) attention_scaling = 1.0 # Unused in default RoPE # Compute the inverse frequencies inv_freq = 1.0 / ( base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) ) return inv_freq, attention_scaling @torch.no_grad() @dynamic_rope_update def forward(self, x, position_ids): # Asegura forma [B, S] if position_ids.dim() == 1: position_ids = position_ids.unsqueeze(0) # [1, S] B = x.shape[0] if position_ids.shape[0] != B: # Replica posiciones idénticas por batch (semántica correcta) position_ids = position_ids.expand(B, -1) # [B, S] device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" # inv_freq en float32 en el device correcto (sin expand con stride 0) inv_freq = self.inv_freq.to(device=x.device, dtype=torch.float32) # [d/2] with torch.autocast(device_type=device_type, enabled=False): # fuerza float32 # Θ[b,s,i] = position_ids[b,s] * inv_freq[i] freqs = position_ids.to(dtype=torch.float32).unsqueeze(-1) * inv_freq.unsqueeze(0).unsqueeze(0) # freqs: [B, S, d/2] emb = torch.cat((freqs, freqs), dim=-1) # [B, S, d] cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors.""" cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) # Keep half or full tensor for later concatenation rotary_dim = cos.shape[-1] q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] # Apply rotary embeddings on the first half or full tensor q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) # Concatenate back to full shape q_embed = torch.cat([q_embed, q_pass], dim=-1) k_embed = torch.cat([k_embed, k_pass], dim=-1) return q_embed, k_embed def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, **kwargs: Unpack[TransformersKwargs], ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights class NeoLLMAttention(nn.Module): """ Multi-headed attention with FANformer integration, SeeDNorm for Q/K normalization, ResFormer feature residual connections, and Learnable Multipliers for enhanced information flow and scale adaptation. ResFormer enhancement: Applies learnable feature residual connections from first layer BEFORE QKV projections: H'_fan_n = λ_1 * H_fan_1 + λ_2 * H_fan_n Learnable Multipliers placement (from "Learnable Multipliers" paper Appendix C): - Q projection: row multipliers only (enables per-head attention scaling in GQA) - K, V projections: no multipliers (avoids redundancy with Q multipliers) - Output projection: row + column multipliers (maximally expressive without symmetries) """ def __init__(self, config: NeoLLMConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.is_causal = True # FANformer integration: FAN layer before QKV projections self.fan_layer = FANLayer( hidden_size=config.hidden_size, fan_ratio=getattr(config, 'fan_ratio', 0.125) ) # Calculate the output dimension after FAN transformation fan_output_dim = config.hidden_size + int(config.hidden_size * getattr(config, 'fan_ratio', 0.125)) # Q projection with row multipliers (per-head scaling capability) self.q_proj = LinearWithMultipliers( fan_output_dim, config.num_attention_heads * self.head_dim * 2, bias=config.attention_bias, use_row_multiplier=True, use_column_multiplier=False ) # K, V projections without multipliers (avoids Q-K symmetry) self.k_proj = nn.Linear( fan_output_dim, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.v_proj = nn.Linear( fan_output_dim, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) # Output projection with row + column multipliers (maximally expressive) self.o_proj = LinearWithMultipliers( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias, use_row_multiplier=True, use_column_multiplier=True ) # SeeDNorm for Q/K normalization (replaces RMSNorm) self.q_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps) # Dropout for attention output self.dropout = nn.Dropout(config.dropout_rate) # ResFormer: learnable feature residual parameters (initialized to 0.5) self.lambda_1 = nn.Parameter(torch.tensor(0.5)) # Weight for H_fan_1 self.lambda_2 = nn.Parameter(torch.tensor(0.5)) # Weight for H_fan_n def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, first_layer_fan: Optional[torch.Tensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: """ Forward pass with ResFormer feature residual connections. Args: hidden_states: Current layer input [batch, seq, hidden_size] position_embeddings: Tuple of (cos, sin) for RoPE attention_mask: Causal attention mask first_layer_fan: First layer FAN features (for ResFormer) Returns: Tuple of (attn_output, attn_weights, current_layer_fan) """ input_shape = hidden_states.shape[:-1] # Apply FANformer transformation hidden_states_fan = self.fan_layer(hidden_states) # ResFormer: Apply feature residual connection BEFORE projections if first_layer_fan is not None: hidden_states_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * hidden_states_fan # Store current FAN features for ResFormer current_layer_fan = hidden_states_fan.clone() hidden_shape = (*input_shape, -1, self.head_dim) # Q projection with learnable row multipliers query_states, gate = torch.chunk( self.q_proj(hidden_states_fan).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1 ) gate = gate.reshape(*input_shape, -1) # Apply SeeDNorm to Q and K query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2) key_states = self.k_norm(self.k_proj(hidden_states_fan).view(hidden_shape)).transpose(1, 2) value_states = self.v_proj(hidden_states_fan).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = attn_output * torch.sigmoid(gate) # Output projection with learnable row + column multipliers attn_output = self.o_proj(attn_output) attn_output = self.dropout(attn_output) return attn_output, attn_weights, current_layer_fan class PolyNorm(torch.nn.Module): def __init__(self, eps=1e-6): super(PolyNorm, self).__init__() self.weight = torch.nn.Parameter(torch.ones(3) / 3) self.bias = torch.nn.Parameter(torch.zeros(1)) self.eps = eps def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x): return self.weight[0] * self._norm(x**3) + self.weight[1] * self._norm(x**2) + self.weight[2] * self._norm(x) + self.bias class NeoLLMMLP(nn.Module): """ MLP with FANformer integration for featural periodicity modeling and Learnable Multipliers for adaptive scale control. This captures periodicities in the feature space (semantic/embedding dimensions) complementary to the relational periodicities captured by attention mechanisms. Works in conjunction with ResFormer for comprehensive information flow. Learnable Multipliers placement (from "Learnable Multipliers" paper Appendix C): - gate_proj: row multipliers only (controls gating mechanism scale) - up_proj: no multipliers (avoids redundancy with down_proj) - down_proj: row + column multipliers (maximally expressive output scaling) """ def __init__(self, config): super().__init__() self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size # FANformer integration for featural space periodicity self.fan_layer = FANLayer( hidden_size=config.hidden_size, fan_ratio=getattr(config, 'fan_ratio_ffn', 0.0625) # Half of attention's fan_ratio ) # Calculate the output dimension after FAN transformation fan_output_dim = config.hidden_size + int(config.hidden_size * getattr(config, 'fan_ratio_ffn', 0.0625)) # SwiGLU/Gated architecture with learnable multipliers # gate_proj: row multipliers for gating scale control self.gate_proj = LinearWithMultipliers( fan_output_dim, self.intermediate_size, bias=False, use_row_multiplier=True, use_column_multiplier=False ) # up_proj: no multipliers (avoids redundancy) self.up_proj = nn.Linear(fan_output_dim, self.intermediate_size, bias=False) # down_proj: row + column multipliers (maximally expressive) self.down_proj = LinearWithMultipliers( self.intermediate_size, self.hidden_size, bias=False, use_row_multiplier=True, use_column_multiplier=True ) self.act_fn = PolyNorm() # Dropout for MLP hidden layer self.dropout = nn.Dropout(config.dropout_rate) def forward(self, x): # Apply FAN transformation before projections x_fan = self.fan_layer(x) # Use FAN-transformed features for gate and up projections gate_output = self.act_fn(self.gate_proj(x_fan)) up_output = self.up_proj(x_fan) hidden = gate_output * up_output hidden = self.dropout(hidden) return self.down_proj(hidden) class NeoLLMDecoderLayer(GradientCheckpointingLayer): """ Decoder layer with standard residual connections and optional StackMemory. Architecture (Updated Flow): 1. Optional: StackMemory module (Pre-processing context injection) 2. Pre-norm (SeeDNorm) → LNS scaling → Self-Attention with ResFormer and Learnable Multipliers 3. Standard Residual Connection 4. GPAS activation scaling 5. Pre-norm (SeeDNorm) → LNS scaling → MLP with FANformer and Learnable Multipliers 6. Standard Residual Connection 7. GPAS activation scaling """ def __init__(self, config: NeoLLMConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.layer_idx = layer_idx # Full attention with learnable multipliers self.self_attn = NeoLLMAttention(config, layer_idx) # MLP with FANformer integration and learnable multipliers self.mlp = NeoLLMMLP(config) # SeeDNorm for input and post-attention normalization self.input_layernorm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps) # LNS (LayerNorm Scaling) - applies 1/√ℓ scaling self.lns_attn = LNS(layer_idx) self.lns_mlp = LNS(layer_idx) # GPAS (Gradient-Preserving Activation Scaling) self.gpas_attn = GPAS(config.hidden_size) self.gpas_mlp = GPAS(config.hidden_size) # StackMemory: Differentiable hidden state stack self.use_stack = getattr(config, 'use_stack', False) if self.use_stack: self.stack_memory = StackMemory(config) # ResFormer: storage for current layer's FAN features self.current_layer_fan = None def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, first_layer_fan: Optional[torch.Tensor] = None, stack_state: Optional[torch.Tensor] = None, stack_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[torch.Tensor], Optional[torch.Tensor]]: """ Forward pass with ResFormer and optional StackMemory. Args: hidden_states: Current layer input [batch, seq, hidden_size] position_embeddings: Tuple of (cos, sin) for RoPE attention_mask: Causal attention mask first_layer_fan: First layer FAN features (for ResFormer) stack_state: StackMemory state (optional) stack_mask: StackMemory mask (optional) output_attentions: Whether to return attention weights Returns: Tuple of (hidden_states, attn_weights, stack_state, stack_mask) """ # ============================================================ # 1. Stack Memory Module (MOVED TO START) # ============================================================ # We process memory first so the Attention layer can "see" the # retrieved context. This eliminates the 1-layer lag. if self.use_stack: hidden_states, stack_state, stack_mask = self.stack_memory( hidden_states, stack_state, stack_mask ) # ============================================================ # 2. Attention Block with Standard Residual Connection # ============================================================ residual = hidden_states # Apply SeeDNorm normalization hidden_states = self.input_layernorm(hidden_states) # Apply LNS scaling after normalization hidden_states = self.lns_attn(hidden_states) # Self Attention with ResFormer attn_output, attn_weights, self.current_layer_fan = self.self_attn( hidden_states=hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, first_layer_fan=first_layer_fan, **kwargs, ) # Standard Residual Connection hidden_states = residual + attn_output # Apply GPAS after residual connection hidden_states = self.gpas_attn(hidden_states) # ============================================================ # 3. MLP Block with Standard Residual Connection # ============================================================ residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) # Apply LNS scaling after normalization hidden_states = self.lns_mlp(hidden_states) # MLP with FANformer mlp_output = self.mlp(hidden_states) # Standard Residual Connection hidden_states = residual + mlp_output # Apply GPAS after residual connection hidden_states = self.gpas_mlp(hidden_states) # Return tuple matching the expected signature if self.use_stack: return (hidden_states, attn_weights, stack_state, stack_mask) else: return (hidden_states, attn_weights, None, None) class NeoLLMPreTrainedModel(PreTrainedModel): """ Base class for NeoLLM models with custom weight initialization. Handles initialization for: - NeoLLMAttention (ResFormer lambda parameters) - GPAS (Gradient-Preserving Activation Scaling) - FANLayer (Fourier Analysis Network) - SeeDNorm (Self-Rescaled Dynamic Normalization) - Learnable Multipliers (ScalarMultiplier, VectorMultiplier) - StackMemory (Differentiable Hidden State Stack) """ config: NeoLLMConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["NeoLLMDecoderLayer"] _supports_flash_attn_2 = True _supports_sdpa = True _is_stateful = True def _init_weights(self, module): """ Initialize weights for all custom modules in NeoLLM. """ super()._init_weights(module) if isinstance(module, NeoLLMAttention): if hasattr(module, 'lambda_1'): module.lambda_1.data.fill_(0.5) if hasattr(module, 'lambda_2'): module.lambda_2.data.fill_(0.5) elif isinstance(module, GPAS): module.alpha.data.fill_(0.0) elif isinstance(module, (ScalarMultiplier, VectorMultiplier)): if hasattr(module, 'multiplier'): module.multiplier.data.fill_(1.0) elif isinstance(module, StackMemory): std = self.config.initializer_range if hasattr(self.config, 'initializer_range') else 0.02 if hasattr(module, 'down_proj'): module.down_proj.weight.data.normal_(mean=0.0, std=std) if hasattr(module, 'up_proj'): module.up_proj.weight.data.normal_(mean=0.0, std=std) if hasattr(module, 'action_head'): module.action_head.weight.data.normal_(mean=0.0, std=std) if module.action_head.bias is not None: module.action_head.bias.data.zero_() if hasattr(module, 'gate_proj'): module.gate_proj.weight.data.normal_(mean=0.0, std=std) if hasattr(module, 'res_weight'): module.res_weight.data.fill_(1.0) class NeoLLMModel(NeoLLMPreTrainedModel): """ NeoLLM base model with transformer decoder architecture. Uses ResFormer for first-layer feature propagation with standard residual connections and optional StackMemory for hierarchical pattern modeling. Note on embeddings and weight tying: This model uses weight tying between embed_tokens and lm_head (shared weights). Following "Learnable Multipliers" paper analysis, we do NOT add multipliers to embeddings because: 1. Weight tying creates conflicting gradient paths 2. The paper explicitly warns against multipliers in lm_head 3. Compensating mechanisms provide scale adaptation immediately after embedding """ def __init__(self, config: NeoLLMConfig): super().__init__(config) # Standard embedding without learnable multipliers self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) # Each layer creates its own components (no shared parameters) self.layers = nn.ModuleList( [NeoLLMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) # SeeDNorm for final output normalization (replaces RMSNorm) self.norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = NeoLLMRotaryEmbedding(config=config) self.gradient_checkpointing = False # Configuration self.use_stack = getattr(config, 'use_stack', False) # ResFormer: storage for first layer's FAN features self.first_layer_fan = None # Initialize weights and apply final processing self.post_init() def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, output_hidden_states: Optional[bool] = None, output_attentions: Optional[bool] = None, return_dict: Optional[bool] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, use_cache: Optional[bool] = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPast: output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) if position_ids is None: position_ids = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0) causal_mask = create_causal_mask( config=self.config, input_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=position_ids.squeeze(0), past_key_values=None, position_ids=position_ids, ) hidden_states = inputs_embeds next_decoder_cache = None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None # Create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) # ResFormer with first-layer feature propagation self.first_layer_fan = None # Initialize Stack states (always None at start of forward, rebuilt via cache step or vertical flow) stack_state = None stack_mask = None # Propagate use_cache and reset if starting a new sequence if self.use_stack: for layer in self.layers: if hasattr(layer, 'stack_memory'): layer.stack_memory.enable_cache = use_cache if use_cache is not None else False if past_key_values is None: layer.stack_memory.reset_cache() for decoder_layer in self.layers: if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_outputs = decoder_layer( hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask, first_layer_fan=self.first_layer_fan, stack_state=stack_state, stack_mask=stack_mask, output_attentions=output_attentions, **kwargs, ) hidden_states = layer_outputs[0] if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) if self.use_stack: # Vertical memory logic: # The layer returns updated stack for the next layer to use (Vertical passing) # But we do NOT persist it temporally here. The Module's internal cache handles temporal. stack_state = layer_outputs[2] stack_mask = layer_outputs[3] # ResFormer: capture H_fan_1 from the first layer # Dynamically capture for the current pass if self.first_layer_fan is None and hasattr(decoder_layer, 'current_layer_fan'): self.first_layer_fan = decoder_layer.current_layer_fan # Apply SeeDNorm for final normalization hidden_states = self.norm(hidden_states) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: return tuple(v for v in [hidden_states, next_decoder_cache, all_hidden_states, all_attentions] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_attentions, ) @torch.compiler.disable def compute_cce_loss(hidden_states, labels, lm_head_weight, lm_head_bias=None, pad_token_id=None): """ CCE loss computation excluded from compilation. Preprocesses labels to eliminate torch.compile warnings. """ # Ensure labels are on the correct device processed_labels = labels.to(hidden_states.device) # Handle pad tokens: convert pad_token_id to -100 for proper masking if pad_token_id is not None: processed_labels = torch.where( processed_labels == pad_token_id, torch.tensor(-100, dtype=processed_labels.dtype, device=processed_labels.device), processed_labels ) return linear_cross_entropy( hidden_states, lm_head_weight, processed_labels, bias=lm_head_bias, shift=1, impl="cce_kahan_full_c", reduction="mean" ) class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin): """ Causal Language Model with NeoLLM architecture. Supports ResFormer with standard residuals and optional StackMemory. Note on LM head: Following "Learnable Multipliers" paper recommendations, the output projection (lm_head) does NOT include learnable multipliers. """ _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} def __init__(self, config): super().__init__(config) self.model = NeoLLMModel(config) self.vocab_size = config.vocab_size # LM head without learnable multipliers (standard linear layer) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.post_init() def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): if past_key_values: past_length = past_key_values[0][0].shape[2] # If past_length > input_ids length, we are likely generating token by token if input_ids.shape[1] > past_length: remove_prefix_length = past_length else: # Default standard HF behavior remove_prefix_length = input_ids.shape[1] - 1 input_ids = input_ids[:, remove_prefix_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] return { "input_ids": input_ids, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "position_ids": position_ids, "attention_mask": attention_mask, "inputs_embeds": inputs_embeds, } def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs: Unpack[TransformersKwargs], ) -> CausalLMOutputWithPast: outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, inputs_embeds=inputs_embeds, output_hidden_states=output_hidden_states, return_dict=return_dict, **kwargs, ) hidden_states = outputs.last_hidden_state # CCE Loss computation for training if labels is not None: loss = compute_cce_loss( hidden_states, labels, self.lm_head.weight, getattr(self.lm_head, 'bias', None), self.config.pad_token_id ) logits = None else: # Inference mode - compute logits normally slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) # ==================== AUTOMODEL REGISTRATION ==================== __all__ = [ "NeoLLMForCausalLM", "NeoLLMModel", "NeoLLMPreTrainedModel", "NeoLLMConfig", "FANLayer", "SeeDNorm", "ScalarMultiplier", "VectorMultiplier", "LinearWithMultipliers", "StackMemory", ] # Register the configuration and model for AutoClass support AutoConfig.register("neollm", NeoLLMConfig) AutoModel.register(NeoLLMConfig, NeoLLMModel) AutoModelForCausalLM.register(NeoLLMConfig, NeoLLMForCausalLM)