|
|
|
|
|
""" |
|
|
NeoLLM Model with FANformer Integration in both Attention and FFN, Dropout Regularization, |
|
|
SeeDNorm (Self-Rescaled Dynamic Normalization), ResFormer Value Residual Learning, |
|
|
Learnable Multipliers for enhanced scale adaptation and information flow through deep layers, |
|
|
and StackMemory for hierarchical pattern modeling. |
|
|
Updated to include: |
|
|
- Fourier Analysis Network (FAN) layer for effective periodicity modeling in attention (relational space) |
|
|
- FAN layer in FFN for featural periodicity modeling (complementary coverage) |
|
|
- SeeDNorm: Dynamic normalization with input-dependent scaling for better adaptability |
|
|
- Dropout regularization at strategic locations |
|
|
- ResFormer: Feature residual connections from first layer (applied before projections) |
|
|
- Learnable Multipliers: Frees weight matrix scale from WD-noise equilibrium for data-adaptive scaling |
|
|
- StackMemory: Differentiable hidden state stack for modeling Chomsky hierarchy grammars |
|
|
- Full Attention only (linear attention removed) |
|
|
""" |
|
|
|
|
|
import math |
|
|
from typing import Any, Callable, Optional, Union, Tuple, List |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torch import nn |
|
|
from cut_cross_entropy import linear_cross_entropy |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.checkpoint import checkpoint |
|
|
from typing import Optional, Tuple |
|
|
|
|
|
from transformers.activations import ACT2FN |
|
|
from transformers.generation import GenerationMixin |
|
|
from transformers.masking_utils import create_causal_mask |
|
|
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs |
|
|
from transformers.modeling_layers import GradientCheckpointingLayer |
|
|
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
|
|
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
|
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel |
|
|
from transformers.processing_utils import Unpack |
|
|
from transformers.utils import TransformersKwargs, logging |
|
|
from transformers.utils.generic import check_model_inputs |
|
|
from configuration_neollm import NeoLLMConfig |
|
|
|
|
|
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ScalarMultiplier(nn.Module):
    """
    Scalar Learnable Multiplier: W̃ = s·W

    From "Learnable Multipliers: Freeing the Scale of Language Model Matrix Layers":
    a single learnable scalar lets the effective matrix norm ||W̃|| = s·||W||
    adapt to the data, escaping the WD-noise equilibrium that constrains
    ||W|| ∝ √(η/λ).

    Args:
        initial_value: Starting value of the scalar (1.0 keeps the identity map).
    """

    def __init__(self, initial_value: float = 1.0):
        super().__init__()
        self.multiplier = nn.Parameter(torch.tensor(initial_value))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Broadcast the learned scalar over every element of the input.
        return x * self.multiplier
|
|
|
|
|
|
|
|
class VectorMultiplier(nn.Module):
    """
    Vector Learnable Multipliers: W̃ = diag(r)·W·diag(c)

    From "Learnable Multipliers: Freeing the Scale of Language Model Matrix Layers":
    Frees not only the overall matrix norm but also individual row/column norms from
    the WD-noise equilibrium, enabling richer feature scale diversity.

    Note: numerically, both "row" and "column" multipliers are the same
    elementwise scale broadcast over the last dimension; the distinction is
    semantic (a row multiplier scales a layer's output features, a column
    multiplier its input features — see LinearWithMultipliers). The original
    implementation branched on ``multiplier_type`` with two identical bodies;
    that dead branch has been removed, with behavior unchanged.

    Args:
        dim: Dimension size for the multiplier vector
        multiplier_type: Either "row" or "column" (kept for introspection)
        initial_value: Initial multiplier value (default: 1.0)
    """

    def __init__(self, dim: int, multiplier_type: str = "row", initial_value: float = 1.0):
        super().__init__()
        self.multiplier_type = multiplier_type
        self.multiplier = nn.Parameter(torch.ones(dim) * initial_value)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Scale the trailing dimension of ``x`` by the learned vector.

        For row multipliers x is the layer's output, e.g. (batch, seq, out_features)
        or (batch, heads, seq, head_dim); for column multipliers x is the layer's
        input, applied before the matrix multiplication.
        """
        return x * self.multiplier
|
|
|
|
|
|
|
|
class LinearWithMultipliers(nn.Module):
    """
    Linear layer with optional row and/or column learnable multipliers.

    Computes: y = (r ⊙ (W @ (c ⊙ x))) + b
    where r (row) and c (column) are learnable per-feature multipliers and W is
    the base weight matrix.

    From "Learnable Multipliers: Freeing the Scale of Language Model Matrix Layers":
    the base matrix W remains subject to the WD-noise equilibrium with
    ||W|| ∝ √(η/λ), while the multipliers r, c learn freely to adapt the
    effective scale to the data.

    Args:
        in_features: Input feature dimension
        out_features: Output feature dimension
        bias: Whether to include a bias term
        use_row_multiplier: Enable row (output-side) multipliers
        use_column_multiplier: Enable column (input-side) multipliers
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        use_row_multiplier: bool = False,
        use_column_multiplier: bool = False
    ):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.use_row_multiplier = use_row_multiplier
        self.use_column_multiplier = use_column_multiplier

        # Multiplier submodules are only created when enabled, so disabled
        # configurations add no extra parameters.
        if use_row_multiplier:
            self.row_multiplier = VectorMultiplier(out_features, multiplier_type="row")
        if use_column_multiplier:
            self.column_multiplier = VectorMultiplier(in_features, multiplier_type="column")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Column multiplier rescales input features before the matmul.
        out = self.column_multiplier(x) if self.use_column_multiplier else x
        out = self.linear(out)
        # Row multiplier rescales output features after the matmul.
        if self.use_row_multiplier:
            out = self.row_multiplier(out)
        return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FANLayer(nn.Module):
    """
    Fourier Analysis Network (FAN) layer for effective periodicity modeling.

    From "FANformer: Improving Large Language Models Through Effective Periodicity Modeling":
    FANLayer'(X) = [cos(WpX)||sin(WpX)||(Wp¯X + Bp¯)]

    This is the modified variant (FANLayer') without an activation function,
    which gave the best results in the paper.
    """

    def __init__(self, hidden_size: int, fan_ratio: float = 0.25):
        super().__init__()
        self.hidden_size = hidden_size
        self.fan_ratio = fan_ratio

        # Total output width after concatenating [cos(p) || sin(p) || g]:
        # hidden_size expanded by the periodic fraction.
        output_dim = hidden_size + int(hidden_size * fan_ratio)
        self.p_output_dim = int(output_dim * fan_ratio)
        # Periodic part contributes twice (cos and sin), the rest is the
        # plain linear ("gated") component.
        self.g_output_dim = output_dim - 2 * self.p_output_dim

        # One projection produces both the periodic and linear components.
        self.input_linear = nn.Linear(
            hidden_size,
            self.p_output_dim + self.g_output_dim,
            bias=True
        )

        self._init_weights()

    def _init_weights(self):
        """Initialize weights following the paper's recommendations."""
        nn.init.normal_(self.input_linear.weight, mean=0.0, std=0.02)
        if self.input_linear.bias is not None:
            nn.init.zeros_(self.input_linear.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the Fourier transformation to the input.

        Args:
            x: Input tensor of shape (batch, seq_len, hidden_size)

        Returns:
            Tensor with Fourier components concatenated,
            shape (batch, seq_len, hidden_size + periodic_dim)
        """
        projected = self.input_linear(x)
        periodic, plain = torch.split(
            projected, [self.p_output_dim, self.g_output_dim], dim=-1
        )
        # [cos(p) || sin(p) || g] — no extra nonlinearity (FANLayer' variant).
        return torch.cat((periodic.cos(), periodic.sin(), plain), dim=-1)
|
|
|
|
|
|
|
|
class LNS(nn.Module):
    """
    LayerNorm Scaling (LNS): multiplies activations by 1/√ℓ, where ℓ is the
    1-indexed layer depth.

    From "The Curse of Depth in Large Language Models":
        h^(ℓ) = LayerNorm(h^(ℓ)) × (1/√ℓ)

    The fixed per-layer damping prevents exponential variance growth in deeper
    layers.
    """

    def __init__(self, layer_idx: int):
        super().__init__()
        # Convert the 0-indexed layer_idx to a 1-indexed depth, clamped to >= 1.
        self.layer_idx = max(layer_idx + 1, 1)
        self.scale = 1.0 / math.sqrt(self.layer_idx)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.scale * x
|
|
|
|
|
|
|
|
class GPAS(nn.Module):
    """
    Gradient-Preserving Activation Scaling (GPAS).

    Down-scales forward activations by a learned amount while leaving the
    backward path untouched: the subtracted term is built from a detached copy
    of the input, so only the identity branch carries gradient.
    Applied Pre-Norm style: after the sub-layer output, before the residual sum.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        # alpha starts at 0, so SiLU(alpha) = 0 and the module is initially
        # the identity.
        self.alpha = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x - silu(alpha) * stopgrad(x): the scaling affects values only,
        # never the gradient flowing through x.
        return x - F.silu(self.alpha) * x.detach()
|
|
|
|
|
class SeeDNorm(nn.Module):
    """
    Self-Rescaled Dynamic Normalization (SeeDNorm) with dual dropout regularization.

    As implemented below (using tanh as the squashing function):

        SeeDNorm(x) = [tanh(Σ_d x_d·β_d)·α + γ] ⊙ x / RMS(x)

    i.e. an RMS normalization whose per-channel gain is shifted per token by an
    input-dependent scalar, squashed to (-1, 1) with tanh, scaled by the
    learnable vector ``alpha`` and offset by the static gain ``gamma``.

    Args:
        dim: Hidden dimension size
        eps: Small constant for numerical stability
        dropout_input: Dropout on the input copy feeding the dynamic mechanism
            (default: 0.01)
        dropout_hidden: Dropout on the RMS-normalized hidden states
            (default: 0.01)
    """

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        dropout_input: float = 0.01,
        dropout_hidden: float = 0.01,
    ):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.dropout_input = dropout_input
        self.dropout_hidden = dropout_hidden

        # gamma: static RMSNorm-style gain (identity init).
        # beta:  projection vector for the dynamic rescale factor; zero init
        #        means tanh(<x, beta>) = 0 at start, so the module begins as a
        #        plain RMSNorm with gain gamma.
        # alpha: per-channel scale of the dynamic term.
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))
        self.alpha = nn.Parameter(torch.ones(dim))

    def _rms_norm(self, x: torch.Tensor) -> torch.Tensor:
        """Compute RMS normalization: x / RMS(x)"""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply Self-Rescaled Dynamic Normalization with dual dropout.

        Args:
            x: Input tensor of shape (..., dim)

        Returns:
            Normalized and dynamically scaled tensor of same shape
        """
        # Dropout is applied only to the copy feeding the dynamic gate; the
        # normalized signal path is not perturbed at this stage.
        x_for_dynamic = F.dropout(x, p=self.dropout_input, training=self.training)
        # Per-position scalar rescale factor in (-1, 1): tanh(<x, beta>).
        rescale_factor = torch.tanh(torch.sum(x_for_dynamic * self.beta,
                                              dim=-1, keepdim=True))

        # Effective gain = input-dependent component + static gamma.
        dynamic_scale = rescale_factor * self.alpha + self.gamma

        # RMS-normalize in float32 for numerical stability; cast back at the end.
        x_normalized = self._rms_norm(x.float())

        x_normalized = F.dropout(x_normalized, p=self.dropout_hidden, training=self.training)

        output = x_normalized * dynamic_scale.float()

        return output.type_as(x)

    def extra_repr(self) -> str:
        return (f"dim={self.dim}, eps={self.eps}, "
                f"dropout_input={self.dropout_input}, dropout_hidden={self.dropout_hidden}")
|
|
|
|
|
|
|
|
|
|
|
class StackMemory(nn.Module):
    """
    From "Improving Formal Reasoning of Transformer with State Stack":
    Implements a multi-head differentiable stack with soft push, pop, and no-op operations.
    Each head maintains its own stack and mask, which are updated based on learned action
    probabilities. Global reading is performed via query-over-stack attention.

    This module is inserted between Transformer layers to augment information flow with
    stack-like memory operations, enabling the model to better capture hierarchical and
    recursive patterns characteristic of regular expressions and context-free grammars.

    Note: StackMemory uses standard nn.Linear to maintain architectural
    independence and avoid introducing additional complexity in the memory operations.

    Args:
        config: Model configuration containing stack-related hyperparameters
    """

    def __init__(self, config: NeoLLMConfig):
        super().__init__()
        self.config = config
        # Stack hyperparameters fall back to defaults when absent from config.
        self.num_stack_heads = getattr(config, 'num_stack_heads', 4)
        self.stack_slots = getattr(config, 'stack_slots', 24)
        self.stack_d_model = getattr(config, 'stack_d_model', 128)

        # Per-head slot width; assumes stack_d_model is divisible by
        # num_stack_heads — TODO confirm this is validated in the config.
        self.head_dim = self.stack_d_model // self.num_stack_heads

        # Project between the model width and the (smaller) stack width.
        self.down_proj = nn.Linear(config.hidden_size, self.stack_d_model, bias=True)
        self.up_proj = nn.Linear(self.stack_d_model, config.hidden_size, bias=True)

        # Produces 3 action logits (push / pop / no-op) per stack head.
        self.action_head = nn.Linear(self.stack_d_model, 3 * self.num_stack_heads, bias=True)

        # Scores each stack slot for the soft read (attention over slots).
        self.gate_proj = nn.Linear(self.head_dim, 1, bias=True)

        # Learnable scale on the memory branch of the residual connection.
        self.res_weight = nn.Parameter(torch.ones(1))

        self.cache_size = getattr(config, "cache_size", 2048)

        # Fixed-size decode caches sized for a fixed batch size.
        # NOTE(review): batches larger than `forward_bs` will not fit these
        # buffers when caching is enabled — confirm callers guarantee this.
        forward_bs = getattr(config, 'forward_bs', 1)
        self.register_buffer("k_cache", torch.zeros(forward_bs, self.cache_size, self.num_stack_heads, self.head_dim))
        self.register_buffer("action_cache", torch.zeros(forward_bs, self.cache_size, self.num_stack_heads, 3))

        # Write cursor into the caches; caching is off unless enabled externally.
        self.cache_position = 0
        self.enable_cache = False

    def reset_cache(self):
        # Rewind the cache cursor; stale contents are simply overwritten later.
        self.cache_position = 0

    def _vectorized_update(
        self,
        stack: torch.Tensor,
        mask: torch.Tensor,
        actions: torch.Tensor,
        k_values: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Vectorized stack update mechanism applying soft push/pop/no-op operations.

        Implements the differentiable stack operations from the paper:
        - Push: shifts all elements down and places k_values at top
        - Pop: shifts all elements up and removes top
        - No-op: maintains current stack state

        NOTE(review): each sequence position applies its single soft action to
        the SAME incoming stack — updates are not composed across positions
        within one call. Verify this matches the intended recurrence.

        Args:
            stack: Current stack state [batch, seq, num_heads, stack_slots, head_dim]
            mask: Current stack mask [batch, seq, num_heads, stack_slots]
            actions: Action probabilities [batch, seq, num_heads, 3] (push/pop/no-op)
            k_values: New values to push [batch, seq, num_heads, head_dim]

        Returns:
            Tuple of (updated_stack, updated_mask)
        """
        batch_size, seq_len = actions.shape[:2]

        # A per-batch stack ([B, H, S, D]) is broadcast across all positions.
        if stack.dim() == 4:
            stack = stack.unsqueeze(1).expand(-1, seq_len, -1, -1, -1)
            mask = mask.unsqueeze(1).expand(-1, seq_len, -1, -1)

        # Push candidate: new value enters slot 0, existing slots shift down
        # (the bottom slot falls off).
        push_stack = torch.cat([
            k_values.unsqueeze(3),
            stack[:, :, :, :-1]
        ], dim=3)
        push_mask = torch.cat([
            torch.ones_like(mask[:, :, :, :1]),
            mask[:, :, :, :-1]
        ], dim=3)

        # Pop candidate: drop slot 0, shift up, pad the bottom with zeros.
        pop_stack = torch.cat([
            stack[:, :, :, 1:],
            torch.zeros_like(stack[:, :, :, :1])
        ], dim=3)
        pop_mask = torch.cat([
            mask[:, :, :, 1:],
            torch.zeros_like(mask[:, :, :, :1])
        ], dim=3)

        # Soft-select among [push, pop, no-op] with the action probabilities.
        action_weights = actions.unsqueeze(-1).unsqueeze(-1)
        stacks = torch.stack([push_stack, pop_stack, stack], dim=3)
        masks = torch.stack([push_mask, pop_mask, mask], dim=3)

        # Convex combination over the three candidate states.
        new_stack = (stacks * action_weights).sum(dim=3)
        new_mask = (masks * action_weights.squeeze(-1)).sum(dim=3)

        return new_stack, new_mask

    def forward(
        self,
        hidden_states: torch.Tensor,
        stack: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Apply differentiable stack operations to hidden states.

        Args:
            hidden_states: Input hidden states [batch, seq, hidden_size]
            stack: Previous stack state [batch, num_heads, stack_slots, head_dim] or None
            mask: Previous stack mask [batch, num_heads, stack_slots] or None

        Returns:
            Tuple of (output_hidden_states, updated_stack, updated_mask)
        """
        batch_size, seq_len, _ = hidden_states.shape
        device = hidden_states.device

        # Lazily initialize an empty stack/mask on the first call.
        if stack is None:
            stack = torch.zeros(
                batch_size, self.num_stack_heads, self.stack_slots, self.head_dim,
                device=device, dtype=hidden_states.dtype
            )
        if mask is None:
            mask = torch.zeros(
                batch_size, self.num_stack_heads, self.stack_slots,
                device=device, dtype=hidden_states.dtype
            )

        # Work in the reduced stack width.
        new_hidden_states = self.down_proj(hidden_states)

        # Action distribution per position/head; logits are temperature-scaled
        # by sqrt(head_dim) before the softmax.
        action_logits = self.action_head(new_hidden_states) / math.sqrt(self.head_dim)
        actions = F.softmax(
            action_logits.view(batch_size, seq_len, self.num_stack_heads, 3),
            dim=-1
        )

        # The down-projected hidden state doubles as the value to push.
        k_values = new_hidden_states.view(batch_size, seq_len, self.num_stack_heads, self.head_dim)

        new_stack, new_mask = self._vectorized_update(stack, mask, actions, k_values)

        # Soft read: attention over slots, with (softly) empty slots masked out
        # by a large negative additive term.
        gate_scores = self.gate_proj(new_stack).squeeze(-1)

        gate_weights = F.softmax(gate_scores + (1 - new_mask) * -1e9, dim=-1)

        # Weighted sum over slots, then merge heads back to stack_d_model.
        memory_output = (new_stack * gate_weights.unsqueeze(-1)).sum(dim=3)
        memory_output = memory_output.view(batch_size, seq_len, -1)

        memory_output = self.up_proj(memory_output)

        # Scaled memory branch plus the untouched input (residual).
        output = memory_output * self.res_weight + hidden_states

        # Record pushes/actions for later incremental decoding.
        if self.enable_cache:
            self._update_cache(k_values.detach(), actions.detach())

        # Only the last position's stack/mask is carried forward as state.
        return output, new_stack[:, -1], new_mask[:, -1]

    def _update_cache(self, k_values: torch.Tensor, actions: torch.Tensor):
        # Append to the caches. On overflow the cache is reset and the new
        # chunk is NOT written — history is silently discarded.
        seq_len = k_values.shape[1]
        if self.cache_position + seq_len <= self.cache_size:
            self.k_cache[:, self.cache_position:self.cache_position+seq_len] = k_values
            self.action_cache[:, self.cache_position:self.cache_position+seq_len] = actions
            self.cache_position += seq_len
        else:
            self.reset_cache()

    def step(self, hidden_state: torch.Tensor, stack: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Single-token decode step ([batch, hidden_size] input). Without
        # caching, fall back to the regular forward on a length-1 sequence.
        if not self.enable_cache:
            return self.forward(hidden_state.unsqueeze(1), stack, mask)

        batch_size = hidden_state.shape[0]

        new_hidden_states = self.down_proj(hidden_state)

        # Same action computation as forward(), for a single position.
        action_logits = self.action_head(new_hidden_states) / math.sqrt(self.head_dim)
        current_actions = F.softmax(
            action_logits.view(batch_size, 1, self.num_stack_heads, 3),
            dim=-1
        )

        current_k = new_hidden_states.view(batch_size, 1, self.num_stack_heads, self.head_dim)

        # Replay cached history together with the current token.
        # NOTE(review): _vectorized_update applies each action independently to
        # the incoming stack and only the last position's result is kept below,
        # so replaying earlier cached actions does not compound — confirm this
        # is intended before relying on the cached path.
        if self.cache_position > 0:
            cached_k = self.k_cache[:, :self.cache_position]
            cached_actions = self.action_cache[:, :self.cache_position]

            k_values = torch.cat([cached_k, current_k], dim=1)
            actions = torch.cat([cached_actions, current_actions], dim=1)
        else:
            k_values = current_k
            actions = current_actions

        new_stack_seq, new_mask_seq = self._vectorized_update(
            stack,
            mask,
            actions,
            k_values
        )

        # Keep only the state produced at the newest position.
        current_stack = new_stack_seq[:, -1]
        current_mask = new_mask_seq[:, -1]

        # Soft read over slots, as in forward() (dims shifted by the missing
        # seq axis).
        gate_scores = self.gate_proj(current_stack).squeeze(-1)
        gate_weights = F.softmax(gate_scores + (1 - current_mask) * -1e9, dim=-1)

        memory_output = (current_stack * gate_weights.unsqueeze(-1)).sum(dim=2)
        memory_output = memory_output.view(batch_size, -1)

        memory_output_proj = self.up_proj(memory_output)

        self._update_cache(current_k, current_actions)

        return (
            memory_output_proj * self.res_weight + hidden_state,
            current_stack,
            current_mask
        )
|
|
|
|
|
class NeoLLMRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) provider.

    Builds per-position cos/sin tables from inverse frequencies. Supports the
    HF rope-scaling variants via ROPE_INIT_FUNCTIONS when configured in
    ``config.rope_scaling``, falling back to the standard RoPE otherwise.
    """

    # Registered as a non-persistent buffer in __init__.
    inv_freq: torch.Tensor

    def __init__(self, config: NeoLLMConfig, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config

        # Resolve the rope type from config.rope_scaling, accepting both the
        # new "rope_type" key and the legacy "type" key; unknown or missing
        # types fall back to "default".
        self.rope_type = "default"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None and isinstance(config.rope_scaling, dict):
            rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
            if rope_type and rope_type in ROPE_INIT_FUNCTIONS:
                self.rope_type = rope_type

        rope_init_fn = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # Non-persistent buffers: recomputed on load rather than checkpointed.
        # original_inv_freq keeps an untouched copy for dynamic rope updates.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: NeoLLMConfig = None,
        device: Optional["torch.device"] = None,
        seq_len: int = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation

        Args:
            config: The model configuration.
            device: The device to use for initialization of the inverse frequencies.
            seq_len: The current sequence length. Unused for this type of RoPE.

        Returns:
            Tuple of (torch.Tensor, float), containing the inverse frequencies for the RoPE
            embeddings and the post-processing scaling factor applied to the computed cos/sin.
        """
        base = config.rope_theta
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        # Partial rotary: only a fraction of each head may be rotated.
        partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
        dim = int(dim * partial_rotary_factor)

        # Default RoPE applies no extra cos/sin scaling.
        attention_scaling = 1.0

        # inv_freq[i] = 1 / base^(2i/dim), one entry per rotated channel pair.
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_scaling

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        """Return (cos, sin) tables, shaped [batch, seq, rotary_dim], in x.dtype."""
        # Normalize position_ids to [batch, seq].
        if position_ids.dim() == 1:
            position_ids = position_ids.unsqueeze(0)

        # Broadcast a single position row across the batch if needed.
        B = x.shape[0]
        if position_ids.shape[0] != B:
            position_ids = position_ids.expand(B, -1)

        # torch.autocast does not support mps; treat it as cpu here.
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"

        inv_freq = self.inv_freq.to(device=x.device, dtype=torch.float32)

        # Compute angles in float32 with autocast disabled for precision.
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = position_ids.to(dtype=torch.float32).unsqueeze(-1) * inv_freq.unsqueeze(0).unsqueeze(0)

            # Duplicate the angles so cos/sin span both halves of the rotated
            # dimension (rotate_half convention).
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
|
|
|
|
|
|
|
|
|
def rotate_half(x):
    """Rotates half the hidden dims of the input: (a, b) -> (-b, a)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
|
|
|
|
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Supports partial rotary: only the leading cos.shape[-1] channels are
    rotated; any trailing channels pass through unchanged.
    """
    # Insert a broadcast axis (normally the head axis) into cos/sin.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    rotary_dim = cos.shape[-1]

    def _rotate(t):
        # Split into the rotated prefix and the untouched remainder.
        rotated_part, passthrough = t[..., :rotary_dim], t[..., rotary_dim:]
        rotated = rotated_part * cos + rotate_half(rotated_part) * sin
        return torch.cat([rotated, passthrough], dim=-1)

    return _rotate(q), _rotate(k)
|
|
|
|
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep): expands the
    hidden states from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_attention_heads, seqlen, head_dim) for grouped-query attention.
    """
    if n_rep == 1:
        # Nothing to expand; return the input untouched.
        return hidden_states
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    # expand() creates a broadcast view; reshape() materializes the copies.
    expanded = hidden_states[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
|
|
|
|
|
|
|
|
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Reference (eager) scaled-dot-product attention with GQA support.

    Args:
        module: The attention module; supplies num_key_value_groups and training.
        query/key/value: (batch, heads, seq, head_dim) tensors (KV may have
            fewer heads; they are expanded to match the query heads).
        attention_mask: Optional additive mask broadcast onto the score matrix
            (expected to hold large negative values at masked positions —
            HF convention; verify against the mask builder).
        scaling: Multiplier applied to the raw QK^T scores (typically 1/√d).
        dropout: Attention-probability dropout rate.

    Returns:
        Tuple of (attn_output with shape (batch, seq, heads, head_dim),
        attn_weights with shape (batch, heads, seq_q, seq_k)).
    """
    # Expand KV heads so each query head has a matching key/value head (GQA).
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Slice the mask to the actual key length, then add it to the scores.
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # Softmax in float32 for numerical stability, then cast back.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    # (batch, heads, seq, head_dim) -> (batch, seq, heads, head_dim).
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
|
|
|
|
|
|
|
|
class NeoLLMAttention(nn.Module):
    """
    Multi-headed attention with FANformer integration, SeeDNorm for Q/K normalization,
    ResFormer feature residual connections, and Learnable Multipliers for enhanced
    information flow and scale adaptation.

    ResFormer enhancement: Applies learnable feature residual connections from first layer
    BEFORE QKV projections: H'_fan_n = λ_1 * H_fan_1 + λ_2 * H_fan_n

    Learnable Multipliers placement (from "Learnable Multipliers" paper Appendix C):
    - Q projection: row multipliers only (enables per-head attention scaling in GQA)
    - K, V projections: no multipliers (avoids redundancy with Q multipliers)
    - Output projection: row + column multipliers (maximally expressive without symmetries)

    Note: this module takes no KV cache argument — NOTE(review): incremental
    decoding support, if needed, must come from elsewhere in the model.
    """

    def __init__(self, config: NeoLLMConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        # FAN expansion applied to the input before all QKV projections.
        self.fan_layer = FANLayer(
            hidden_size=config.hidden_size,
            fan_ratio=getattr(config, 'fan_ratio', 0.125)
        )

        # Width of the FAN output (hidden_size + periodic components); must
        # match FANLayer's output width, which uses the same formula.
        fan_output_dim = config.hidden_size + int(config.hidden_size * getattr(config, 'fan_ratio', 0.125))

        # Q projection outputs 2x the head dims: one half is the query, the
        # other half is a per-channel sigmoid gate applied after attention.
        self.q_proj = LinearWithMultipliers(
            fan_output_dim,
            config.num_attention_heads * self.head_dim * 2,
            bias=config.attention_bias,
            use_row_multiplier=True,
            use_column_multiplier=False
        )

        # K/V stay plain nn.Linear (no multipliers by design, see docstring).
        self.k_proj = nn.Linear(
            fan_output_dim, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            fan_output_dim, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )

        self.o_proj = LinearWithMultipliers(
            config.num_attention_heads * self.head_dim,
            config.hidden_size,
            bias=config.attention_bias,
            use_row_multiplier=True,
            use_column_multiplier=True
        )

        # Per-head-dim dynamic normalization of queries and keys (QK-norm).
        self.q_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)

        # Dropout on the attention block's final output.
        self.dropout = nn.Dropout(config.dropout_rate)

        # ResFormer mixing coefficients for first-layer vs current FAN features.
        self.lambda_1 = nn.Parameter(torch.tensor(0.5))
        self.lambda_2 = nn.Parameter(torch.tensor(0.5))

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        first_layer_fan: Optional[torch.Tensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Forward pass with ResFormer feature residual connections.

        Args:
            hidden_states: Current layer input [batch, seq, hidden_size]
            position_embeddings: Tuple of (cos, sin) for RoPE
            attention_mask: Causal attention mask
            first_layer_fan: First layer FAN features (for ResFormer)

        Returns:
            Tuple of (attn_output, attn_weights, current_layer_fan)
        """
        input_shape = hidden_states.shape[:-1]

        # Expand the input with Fourier components.
        hidden_states_fan = self.fan_layer(hidden_states)

        # ResFormer: blend in the first layer's FAN features when available.
        if first_layer_fan is not None:
            hidden_states_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * hidden_states_fan

        # Snapshot returned to the caller so later layers can reuse it.
        current_layer_fan = hidden_states_fan.clone()

        hidden_shape = (*input_shape, -1, self.head_dim)

        # Split the doubled Q projection into query and gate halves.
        query_states, gate = torch.chunk(
            self.q_proj(hidden_states_fan).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
        )
        # Flatten the gate back to (batch, seq, heads*head_dim).
        gate = gate.reshape(*input_shape, -1)

        # QK-norm, then move the head axis: (B, S, H, D) -> (B, H, S, D).
        query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states_fan).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states_fan).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Dispatch to the configured attention backend (eager by default).
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge heads, then apply the elementwise sigmoid output gate.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = attn_output * torch.sigmoid(gate)

        attn_output = self.o_proj(attn_output)
        attn_output = self.dropout(attn_output)

        return attn_output, attn_weights, current_layer_fan
|
|
|
|
|
|
|
|
class PolyNorm(torch.nn.Module):
    """
    Polynomial normalization activation: a learned combination of
    RMS-normalized powers of the input,

        PolyNorm(x) = w0·norm(x³) + w1·norm(x²) + w2·norm(x) + b

    with weights initialized to 1/3 each and bias to 0.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(3) / 3)
        self.bias = torch.nn.Parameter(torch.zeros(1))
        self.eps = eps

    def _norm(self, x):
        # Standard RMS normalization over the last dimension.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        w = self.weight
        out = w[0] * self._norm(x ** 3)
        out = out + w[1] * self._norm(x ** 2)
        out = out + w[2] * self._norm(x)
        return out + self.bias
|
|
|
|
|
|
|
|
class NeoLLMMLP(nn.Module):
    """
    MLP with FANformer integration for featural periodicity modeling and
    Learnable Multipliers for adaptive scale control.

    This captures periodicities in the feature space (semantic/embedding dimensions)
    complementary to the relational periodicities captured by attention mechanisms.
    Works in conjunction with ResFormer for comprehensive information flow.

    Learnable Multipliers placement (from "Learnable Multipliers" paper Appendix C):
    - gate_proj: row multipliers only (controls gating mechanism scale)
    - up_proj: no multipliers (avoids redundancy with down_proj)
    - down_proj: row + column multipliers (maximally expressive output scaling)
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        # FAN expansion for the FFN uses its own (smaller) ratio.
        self.fan_layer = FANLayer(
            hidden_size=config.hidden_size,
            fan_ratio=getattr(config, 'fan_ratio_ffn', 0.0625)
        )

        # Width of the FAN output: hidden_size plus the periodic components;
        # matches FANLayer's output width (same formula).
        fan_output_dim = config.hidden_size + int(config.hidden_size * getattr(config, 'fan_ratio_ffn', 0.0625))

        # Gate branch carries row multipliers (see class docstring).
        self.gate_proj = LinearWithMultipliers(
            fan_output_dim,
            self.intermediate_size,
            bias=False,
            use_row_multiplier=True,
            use_column_multiplier=False
        )

        # Plain linear for the up branch (no multipliers by design).
        self.up_proj = nn.Linear(fan_output_dim, self.intermediate_size, bias=False)

        # Output projection gets both row and column multipliers.
        self.down_proj = LinearWithMultipliers(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            use_row_multiplier=True,
            use_column_multiplier=True
        )

        # PolyNorm replaces the usual SiLU/GELU as the gate activation.
        self.act_fn = PolyNorm()

        # Dropout on the gated intermediate activations.
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x):
        # Expand features with Fourier components before the gated MLP.
        x_fan = self.fan_layer(x)

        # Standard gated-MLP structure: act(gate) ⊙ up, then down-projection.
        gate_output = self.act_fn(self.gate_proj(x_fan))
        up_output = self.up_proj(x_fan)
        hidden = gate_output * up_output
        hidden = self.dropout(hidden)
        return self.down_proj(hidden)
|
|
|
|
|
class NeoLLMDecoderLayer(GradientCheckpointingLayer):
    """
    Decoder layer with standard residual connections and optional StackMemory.

    Architecture (Updated Flow):
    1. Optional: StackMemory module (Pre-processing context injection)
    2. Pre-norm (SeeDNorm) → LNS scaling → Self-Attention with ResFormer and Learnable Multipliers
    3. Standard Residual Connection
    4. GPAS activation scaling
    5. Pre-norm (SeeDNorm) → LNS scaling → MLP with FANformer and Learnable Multipliers
    6. Standard Residual Connection
    7. GPAS activation scaling
    """

    def __init__(self, config: NeoLLMConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx

        # Attention with FANformer features and ResFormer value-residual support.
        self.self_attn = NeoLLMAttention(config, layer_idx)

        # Gated MLP operating on FAN-augmented features.
        self.mlp = NeoLLMMLP(config)

        # Pre-norms (SeeDNorm = self-rescaled dynamic normalization).
        self.input_layernorm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Layer-index-dependent scaling applied right after each pre-norm.
        self.lns_attn = LNS(layer_idx)
        self.lns_mlp = LNS(layer_idx)

        # Gradient-preserving activation scaling after each residual add.
        self.gpas_attn = GPAS(config.hidden_size)
        self.gpas_mlp = GPAS(config.hidden_size)

        # Optional differentiable stack for hierarchical pattern modeling.
        self.use_stack = getattr(config, 'use_stack', False)
        if self.use_stack:
            self.stack_memory = StackMemory(config)

        # Side channel: NeoLLMModel reads this after calling forward() to
        # capture the FIRST layer's FAN features for ResFormer reuse.
        self.current_layer_fan = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        first_layer_fan: Optional[torch.Tensor] = None,
        stack_state: Optional[torch.Tensor] = None,
        stack_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Forward pass with ResFormer and optional StackMemory.

        Args:
            hidden_states: Current layer input [batch, seq, hidden_size]
            position_embeddings: Tuple of (cos, sin) for RoPE
            attention_mask: Causal attention mask
            first_layer_fan: First layer FAN features (for ResFormer)
            stack_state: StackMemory state (optional)
            stack_mask: StackMemory mask (optional)
            output_attentions: Whether to return attention weights.
                NOTE(review): not forwarded to self_attn — whether weights
                are produced depends on the attention implementation; confirm.

        Returns:
            Tuple of (hidden_states, attn_weights, stack_state, stack_mask);
            the stack entries are None when use_stack is disabled.
        """
        # --- Optional stack-memory context injection (before attention) ---
        if self.use_stack:
            hidden_states, stack_state, stack_mask = self.stack_memory(
                hidden_states, stack_state, stack_mask
            )

        # --- Attention sub-block: pre-norm → LNS → attention → residual → GPAS ---
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        hidden_states = self.lns_attn(hidden_states)

        # The third return value (this layer's FAN features) is stashed on the
        # instance so the enclosing model can pick up the first layer's copy.
        attn_output, attn_weights, self.current_layer_fan = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            first_layer_fan=first_layer_fan,
            **kwargs,
        )

        hidden_states = residual + attn_output

        hidden_states = self.gpas_attn(hidden_states)

        # --- MLP sub-block: pre-norm → LNS → MLP → residual → GPAS ---
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)

        hidden_states = self.lns_mlp(hidden_states)

        mlp_output = self.mlp(hidden_states)

        hidden_states = residual + mlp_output

        hidden_states = self.gpas_mlp(hidden_states)

        # Always a 4-tuple so callers can unpack uniformly.
        if self.use_stack:
            return (hidden_states, attn_weights, stack_state, stack_mask)
        else:
            return (hidden_states, attn_weights, None, None)
|
|
|
|
|
|
|
|
class NeoLLMPreTrainedModel(PreTrainedModel):
    """
    Base class for NeoLLM models with custom weight initialization.

    Handles initialization for:
    - NeoLLMAttention (ResFormer lambda parameters)
    - GPAS (Gradient-Preserving Activation Scaling)
    - FANLayer (Fourier Analysis Network)
    - SeeDNorm (Self-Rescaled Dynamic Normalization)
    - Learnable Multipliers (ScalarMultiplier, VectorMultiplier)
    - StackMemory (Differentiable Hidden State Stack)
    """
    config: NeoLLMConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["NeoLLMDecoderLayer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _is_stateful = True

    def _init_weights(self, module):
        """
        Initialize weights for all custom modules in NeoLLM.

        Falls back to the transformers default via super(), then applies
        module-specific initialization for NeoLLM's custom layers.
        """
        super()._init_weights(module)

        if isinstance(module, NeoLLMAttention):
            # ResFormer value-residual mixing weights start as an even blend.
            if hasattr(module, 'lambda_1'):
                module.lambda_1.data.fill_(0.5)
            if hasattr(module, 'lambda_2'):
                module.lambda_2.data.fill_(0.5)

        elif isinstance(module, GPAS):
            # alpha = 0 makes GPAS start as the identity mapping.
            module.alpha.data.fill_(0.0)

        elif isinstance(module, (ScalarMultiplier, VectorMultiplier)):
            # Learnable multipliers start neutral (scale of 1).
            if hasattr(module, 'multiplier'):
                module.multiplier.data.fill_(1.0)

        elif isinstance(module, StackMemory):
            # Idiomatic default lookup (was: hasattr + ternary).
            std = getattr(self.config, 'initializer_range', 0.02)
            # Keep the original attribute order so RNG draws are reproducible.
            if hasattr(module, 'down_proj'):
                module.down_proj.weight.data.normal_(mean=0.0, std=std)
            if hasattr(module, 'up_proj'):
                module.up_proj.weight.data.normal_(mean=0.0, std=std)
            if hasattr(module, 'action_head'):
                module.action_head.weight.data.normal_(mean=0.0, std=std)
                if module.action_head.bias is not None:
                    module.action_head.bias.data.zero_()
            if hasattr(module, 'gate_proj'):
                module.gate_proj.weight.data.normal_(mean=0.0, std=std)
            if hasattr(module, 'res_weight'):
                # Residual mixing weight starts at identity scale.
                module.res_weight.data.fill_(1.0)
|
|
|
|
|
|
|
|
class NeoLLMModel(NeoLLMPreTrainedModel):
    """
    NeoLLM base model with transformer decoder architecture.

    Uses ResFormer for first-layer feature propagation with standard residual connections
    and optional StackMemory for hierarchical pattern modeling.

    Note on embeddings and weight tying: This model uses weight tying between
    embed_tokens and lm_head (shared weights). Following "Learnable Multipliers"
    paper analysis, we do NOT add multipliers to embeddings because:

    1. Weight tying creates conflicting gradient paths
    2. The paper explicitly warns against multipliers in lm_head
    3. Compensating mechanisms provide scale adaptation immediately after embedding
    """

    def __init__(self, config: NeoLLMConfig):
        super().__init__(config)

        # Token embeddings; the pad_token_id row is the padding_idx (kept at
        # zero and excluded from gradient updates by nn.Embedding).
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)

        # Stack of decoder layers; layer_idx feeds LNS and attention wiring.
        self.layers = nn.ModuleList(
            [NeoLLMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

        # Final norm and a single rotary embedding shared by all layers.
        self.norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = NeoLLMRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Whether decoder layers carry StackMemory state.
        self.use_stack = getattr(config, 'use_stack', False)

        # ResFormer side channel: first layer's FAN features, reset each forward.
        # NOTE(review): instance-level mutable state — not safe for concurrent
        # forward passes on the same module.
        self.first_layer_fan = None

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        """
        Run the decoder stack.

        NOTE(review): no attention KV cache is maintained here —
        past_key_values/use_cache only influence StackMemory caching, and the
        returned past_key_values is always None.
        """
        # Resolve output flags from config defaults when not explicitly given.
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        output_attentions = (
            output_attentions if output_attentions is not None
            else self.config.output_attentions
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Exactly one of input_ids / inputs_embeds must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if position_ids is None:
            # Default positions 0..seq_len-1, broadcast over the batch.
            position_ids = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)

        # NOTE(review): cache_position via squeeze(0) assumes position_ids has
        # a leading dim of 1 — confirm for batched, padding-aware calls.
        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=position_ids.squeeze(0),
            past_key_values=None,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        # Always None — see class-level note on the missing KV cache.
        next_decoder_cache = None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # RoPE (cos, sin) computed once and shared across layers.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # Reset the ResFormer side channel for this forward pass.
        self.first_layer_fan = None

        # StackMemory state threaded layer-to-layer.
        stack_state = None
        stack_mask = None

        # Configure per-layer stack caches before the main loop; a fresh
        # prompt (past_key_values is None) clears any previous cache.
        if self.use_stack:
            for layer in self.layers:
                if hasattr(layer, 'stack_memory'):
                    layer.stack_memory.enable_cache = use_cache if use_cache is not None else False
                    if past_key_values is None:
                        layer.stack_memory.reset_cache()

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=causal_mask,
                first_layer_fan=self.first_layer_fan,
                stack_state=stack_state,
                stack_mask=stack_mask,
                output_attentions=output_attentions,
                **kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

            if self.use_stack:
                # Carry stack state/mask into the next layer.
                stack_state = layer_outputs[2]
                stack_mask = layer_outputs[3]

            # Capture the FIRST layer's FAN features exactly once; subsequent
            # layers receive them via first_layer_fan (ResFormer).
            if self.first_layer_fan is None and hasattr(decoder_layer, 'current_layer_fan'):
                self.first_layer_fan = decoder_layer.current_layer_fan

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, next_decoder_cache, all_hidden_states, all_attentions] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
|
|
|
|
|
|
|
|
@torch.compiler.disable
def compute_cce_loss(hidden_states, labels, lm_head_weight, lm_head_bias=None, pad_token_id=None):
    """
    Compute the Cut Cross-Entropy (CCE) loss outside of torch.compile.

    Labels equal to pad_token_id are remapped to -100 (the ignore index)
    before the loss call, which avoids torch.compile warnings from doing
    the preprocessing inside a compiled region.
    """
    # Move labels onto the same device as the activations.
    target = labels.to(hidden_states.device)

    # Remap pad positions to the ignore index (-100).
    if pad_token_id is not None:
        ignore = torch.full_like(target, -100)
        target = torch.where(target == pad_token_id, ignore, target)

    # shift=1 aligns each position's hidden state with the next token.
    return linear_cross_entropy(
        hidden_states,
        lm_head_weight,
        target,
        bias=lm_head_bias,
        shift=1,
        impl="cce_kahan_full_c",
        reduction="mean",
    )
|
|
|
|
|
|
|
|
class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
    """
    Causal Language Model with NeoLLM architecture.

    Supports ResFormer with standard residuals and optional StackMemory.

    Note on LM head: Following "Learnable Multipliers" paper recommendations,
    the output projection (lm_head) does NOT include learnable multipliers.
    """
    # lm_head shares its weight matrix with the input embedding table.
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

    def __init__(self, config):
        super().__init__(config)
        self.model = NeoLLMModel(config)
        self.vocab_size = config.vocab_size

        # Plain linear head, no multipliers (weight-tied with embeddings).
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Trim input_ids for incremental decoding and derive position_ids."""
        if past_key_values:
            # NOTE(review): assumes a legacy tuple-format cache where
            # past_key_values[0][0] is [batch, heads, seq, head_dim] — confirm,
            # since NeoLLMModel itself never returns a cache.
            past_length = past_key_values[0][0].shape[2]

            if input_ids.shape[1] > past_length:
                # Keep only the tokens not yet processed by the cache.
                remove_prefix_length = past_length
            else:
                # Default: keep only the newest token.
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # Padding-aware positions: cumulative count of real tokens minus 1;
            # padded slots are filled with a dummy position of 1.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "inputs_embeds": inputs_embeds,
        }

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        """
        Run the base model, then compute either loss or logits.

        With labels: loss is computed via Cut Cross-Entropy directly from the
        hidden states — the full [batch, seq, vocab] logits tensor is never
        materialized and `logits` is returned as None.
        Without labels: logits are computed, optionally only for the last
        `logits_to_keep` positions (0 keeps all, since slice(-0) == slice(0)).
        """
        # NOTE(review): attribute access below assumes the model returns a
        # BaseModelOutputWithPast; return_dict=False would yield a tuple.
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state

        if labels is not None:
            # Memory-efficient CCE loss path (no logits materialized).
            # lm_head has bias=False, so .bias is None; getattr is defensive.
            loss = compute_cce_loss(
                hidden_states,
                labels,
                self.lm_head.weight,
                getattr(self.lm_head, 'bias', None),
                self.config.pad_token_id
            )
            logits = None
        else:
            # int n -> keep the last n positions; a tensor selects explicitly.
            slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
            logits = self.lm_head(hidden_states[:, slice_indices, :])
            loss = None

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
|
|
|
|
|
|
|
# Public API of this module (some names are defined earlier in the file).
__all__ = [
    "NeoLLMForCausalLM",
    "NeoLLMModel",
    "NeoLLMPreTrainedModel",
    "NeoLLMConfig",
    "FANLayer",
    "SeeDNorm",
    "ScalarMultiplier",
    "VectorMultiplier",
    "LinearWithMultipliers",
    "StackMemory",
]
|
|
|
|
|
|
|
|
# Register NeoLLM with the Hugging Face Auto* factories so the "neollm"
# model_type resolves via AutoConfig / AutoModel / AutoModelForCausalLM.
AutoConfig.register("neollm", NeoLLMConfig)
AutoModel.register(NeoLLMConfig, NeoLLMModel)
AutoModelForCausalLM.register(NeoLLMConfig, NeoLLMForCausalLM)