# NeoLLM / modeling_neollm.py
# NOTE: the lines below are Hugging Face Hub page residue (uploader / commit
# info), commented out so the module parses:
# "KitsuVp's picture — Update modeling_neollm.py — ac61dcd verified"
#!/usr/bin/env python3
"""
NeoLLM Model with FANformer Integration in both Attention and FFN, Dropout Regularization,
SeeDNorm (Self-Rescaled Dynamic Normalization), ResFormer Value Residual Learning,
Learnable Multipliers for enhanced scale adaptation and information flow through deep layers,
and StackMemory for hierarchical pattern modeling.
Updated to include:
- Fourier Analysis Network (FAN) layer for effective periodicity modeling in attention (relational space)
- FAN layer in FFN for featural periodicity modeling (complementary coverage)
- SeeDNorm: Dynamic normalization with input-dependent scaling for better adaptability
- Dropout regularization at strategic locations
- ResFormer: Feature residual connections from first layer (applied before projections)
- Learnable Multipliers: Frees weight matrix scale from WD-noise equilibrium for data-adaptive scaling
- StackMemory: Differentiable hidden state stack for modeling Chomsky hierarchy grammars
- Full Attention only (linear attention removed)
"""
import math
from typing import Any, Callable, Optional, Union, Tuple, List
import torch
import torch.nn.functional as F
from torch import nn
from cut_cross_entropy import linear_cross_entropy
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from typing import Optional, Tuple
from transformers.activations import ACT2FN
from transformers.generation import GenerationMixin
from transformers.masking_utils import create_causal_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, logging
from transformers.utils.generic import check_model_inputs
from configuration_neollm import NeoLLMConfig
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
logger = logging.get_logger(__name__)
# ==================== LEARNABLE MULTIPLIERS ====================
class ScalarMultiplier(nn.Module):
    """
    Scalar Learnable Multiplier: W-tilde = s * W.

    From "Learnable Multipliers: Freeing the Scale of Language Model Matrix Layers":
    lets the effective matrix norm ||W-tilde|| = s * ||W|| adapt to the data,
    escaping the WD-noise equilibrium that constrains ||W|| to sqrt(lr/wd).

    Args:
        initial_value: Starting value of the scalar (1.0 gives the identity map).
    """
    def __init__(self, initial_value: float = 1.0):
        super().__init__()
        # A single learnable scalar shared by every element of the input.
        self.multiplier = nn.Parameter(torch.tensor(initial_value))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Elementwise rescale by the learned scalar.
        return x * self.multiplier
class VectorMultiplier(nn.Module):
    """
    Vector Learnable Multipliers: W-tilde = diag(r) * W * diag(c).

    From "Learnable Multipliers: Freeing the Scale of Language Model Matrix Layers":
    frees not only the overall matrix norm but also individual row/column norms
    from the WD-noise equilibrium, enabling richer feature scale diversity.

    Args:
        dim: Size of the multiplier vector.
        multiplier_type: Either "row" or "column" (informational; see forward).
        initial_value: Starting value for every component (default: 1.0).
    """
    def __init__(self, dim: int, multiplier_type: str = "row", initial_value: float = 1.0):
        super().__init__()
        self.multiplier_type = multiplier_type
        self.multiplier = nn.Parameter(torch.ones(dim) * initial_value)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Broadcast-multiply the input by the learned vector along the last dim.

        The row/column distinction lives in WHERE the caller applies this module
        (after vs. before the linear layer); the elementwise operation itself is
        identical for both types, so no branching is needed here.
        """
        return x * self.multiplier
class LinearWithMultipliers(nn.Module):
    """
    Linear layer with optional row and/or column learnable multipliers.

    Implements: y = (r * (W @ (c * x))) + b
    where r and c are learnable multiplier vectors and W is the base weight.

    From "Learnable Multipliers: Freeing the Scale of Language Model Matrix Layers":
    the base matrix W remains subject to the WD-noise equilibrium, while the
    multipliers r, c learn freely to adapt the effective scale to the data.

    Args:
        in_features: Input feature dimension.
        out_features: Output feature dimension.
        bias: Whether the base linear layer carries a bias term.
        use_row_multiplier: Enable row (output-side) multipliers.
        use_column_multiplier: Enable column (input-side) multipliers.
    """
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        use_row_multiplier: bool = False,
        use_column_multiplier: bool = False
    ):
        super().__init__()
        # Base projection — its weights stay under weight decay.
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.use_row_multiplier = use_row_multiplier
        self.use_column_multiplier = use_column_multiplier
        # Multipliers are created only when enabled (kept out of WD by the optimizer setup).
        if use_row_multiplier:
            self.row_multiplier = VectorMultiplier(out_features, multiplier_type="row")
        if use_column_multiplier:
            self.column_multiplier = VectorMultiplier(in_features, multiplier_type="column")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Order matters: column multiplier scales inputs, row multiplier scales outputs.
        out = self.column_multiplier(x) if self.use_column_multiplier else x
        out = self.linear(out)
        if self.use_row_multiplier:
            out = self.row_multiplier(out)
        return out
# ==================== ORIGINAL COMPONENTS ====================
class FANLayer(nn.Module):
    """
    Fourier Analysis Network (FAN) layer for effective periodicity modeling.

    From "FANformer: Improving Large Language Models Through Effective Periodicity
    Modeling": FANLayer'(X) = [cos(Wp X) || sin(Wp X) || (Wg X + Bg)].
    This is the modified version (FANLayer') without an activation function,
    which gave the best results in the paper.

    Args:
        hidden_size: Input feature width.
        fan_ratio: Fraction controlling the width of the periodic branch.
    """
    def __init__(self, hidden_size: int, fan_ratio: float = 0.25):
        super().__init__()
        self.hidden_size = hidden_size
        self.fan_ratio = fan_ratio
        # Total output width: hidden_size plus the extra periodic channels.
        total_dim = hidden_size + int(hidden_size * fan_ratio)
        # Periodic branch width; cos and sin each contribute p_output_dim channels.
        self.p_output_dim = int(total_dim * fan_ratio)
        self.g_output_dim = total_dim - 2 * self.p_output_dim
        # One fused projection yields both the periodic and the linear parts
        # (cheaper than two separate matmuls).
        self.input_linear = nn.Linear(
            hidden_size,
            self.p_output_dim + self.g_output_dim,
            bias=True
        )
        self._init_weights()

    def _init_weights(self):
        """Initialize weights following the paper's recommendations."""
        nn.init.normal_(self.input_linear.weight, mean=0.0, std=0.02)
        if self.input_linear.bias is not None:
            nn.init.zeros_(self.input_linear.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the Fourier transformation to the input.

        Args:
            x: Input tensor of shape (batch, seq_len, hidden_size).
        Returns:
            Tensor of shape (batch, seq_len, hidden_size + periodic_dim), laid out
            as [cos(Wp X) || sin(Wp X) || (Wg X + Bg)].
        """
        projected = self.input_linear(x)
        periodic = projected[..., : self.p_output_dim]
        gated = projected[..., self.p_output_dim :]
        return torch.cat((periodic.cos(), periodic.sin(), gated), dim=-1)
class LNS(nn.Module):
    """
    LayerNorm Scaling (LNS): multiplies activations by 1/sqrt(layer).

    From "The Curse of Depth in Large Language Models":
        h^(l) = LayerNorm(h^(l)) * (1 / sqrt(l))
    which prevents exponential variance growth in deeper layers.

    Args:
        layer_idx: Zero-based layer index; converted to a one-based depth.
    """
    def __init__(self, layer_idx: int):
        super().__init__()
        # One-based depth; the max() guard keeps the index >= 1.
        self.layer_idx = max(layer_idx + 1, 1)
        # Precomputed, since the scale is constant for a given layer.
        self.scale = 1.0 / math.sqrt(self.layer_idx)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.scale * x
class GPAS(nn.Module):
    """
    Gradient-Preserving Activation Scaling (GPAS).

    Scales forward activations without shrinking gradients by subtracting a
    detached (stop-gradient) copy. Applied Pre-Norm style: after the sub-layer
    output, before the residual sum.

    Args:
        d_model: Model width (stored for reference; the gate itself is scalar).
    """
    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        # alpha starts at 0, so silu(alpha) == 0 and the module is an identity at init.
        self.alpha = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Subtract silu(alpha) * sg(x): the forward value is scaled by
        # (1 - silu(alpha)), while the gradient flows untouched through `x`.
        gate = F.silu(self.alpha)
        return x - gate * x.detach()
class SeeDNorm(nn.Module):
    """
    Self-Rescaled Dynamic Normalization (SeeDNorm) with dual dropout regularization.

    SeeDNorm(x) = [tanh(x . beta) * alpha + gamma] * x / RMS(x)

    The gate input uses tanh (not sigmoid) in this implementation, bounding the
    dynamic term in [-1, 1] around the static gain gamma.

    Args:
        dim: Hidden dimension size.
        eps: Small constant for numerical stability.
        dropout_input: Dropout on the copy of x feeding the dynamic gate (default: 0.01).
        dropout_hidden: Dropout on the normalized hidden states (default: 0.01).
    """
    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        dropout_input: float = 0.01,
        dropout_hidden: float = 0.01,
    ):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.dropout_input = dropout_input
        self.dropout_hidden = dropout_hidden
        self.gamma = nn.Parameter(torch.ones(dim))   # static gain
        self.beta = nn.Parameter(torch.zeros(dim))   # self-rescaling direction
        self.alpha = nn.Parameter(torch.ones(dim))   # dynamic modulation gain

    def _rms_norm(self, x: torch.Tensor) -> torch.Tensor:
        """Return x / RMS(x), with eps added inside the root for stability."""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Normalize and dynamically rescale the input.

        Args:
            x: Input tensor of shape (..., dim).
        Returns:
            Tensor of the same shape and dtype as x.
        """
        # The dropped copy feeds only the gate; normalization sees the raw input.
        gate_input = F.dropout(x, p=self.dropout_input, training=self.training)
        rescale = torch.tanh((gate_input * self.beta).sum(dim=-1, keepdim=True))
        # Dynamic scaling coefficient: tanh(x . beta) * alpha + gamma.
        scale = rescale * self.alpha + self.gamma
        # RMS-normalize in float32 for numerical stability.
        normalized = self._rms_norm(x.float())
        normalized = F.dropout(normalized, p=self.dropout_hidden, training=self.training)
        return (normalized * scale.float()).type_as(x)

    def extra_repr(self) -> str:
        return (f"dim={self.dim}, eps={self.eps}, "
                f"dropout_input={self.dropout_input}, dropout_hidden={self.dropout_hidden}")
# ==================== STACK MEMORY MODULE ====================
class StackMemory(nn.Module):
    """
    From "Improving Formal Reasoning of Transformer with State Stack":
    Implements a multi-head differentiable stack with soft push, pop, and no-op operations.
    Each head maintains its own stack and mask, which are updated based on learned action
    probabilities. Global reading is performed via query-over-stack attention.
    This module is inserted between Transformer layers to augment information flow with
    stack-like memory operations, enabling the model to better capture hierarchical and
    recursive patterns characteristic of regular expressions and context-free grammars.
    Note: StackMemory uses standard nn.Linear to maintain architectural
    independence and avoid introducing additional complexity in the memory operations.
    Args:
        config: Model configuration containing stack-related hyperparameters
    """
    def __init__(self, config: NeoLLMConfig):
        super().__init__()
        self.config = config
        # Stack hyperparameters, with fallbacks so configs lacking the
        # stack-specific fields still work.
        self.num_stack_heads = getattr(config, 'num_stack_heads', 4)
        self.stack_slots = getattr(config, 'stack_slots', 24)
        self.stack_d_model = getattr(config, 'stack_d_model', 128)
        # Per-head channel width; assumes stack_d_model is divisible by
        # num_stack_heads — TODO confirm upstream validation.
        self.head_dim = self.stack_d_model // self.num_stack_heads
        # Dimension reduction projections for efficiency
        # Uses standard nn.Linear
        self.down_proj = nn.Linear(config.hidden_size, self.stack_d_model, bias=True)
        self.up_proj = nn.Linear(self.stack_d_model, config.hidden_size, bias=True)
        # Action prediction: generates push/pop/no-op probabilities for each head
        self.action_head = nn.Linear(self.stack_d_model, 3 * self.num_stack_heads, bias=True)
        # Query projection for global reading (one per head)
        self.gate_proj = nn.Linear(self.head_dim, 1, bias=True)
        # Residual weight for gating stack contribution
        self.res_weight = nn.Parameter(torch.ones(1))
        # Cache for autoregressive generation (matches OLMo reference)
        self.cache_size = getattr(config, "cache_size", 2048)
        # Initialization fix: Register buffers for cache
        # Default to batch_size=1 if forward_bs is not in config (standard inference)
        forward_bs = getattr(config, 'forward_bs', 1)
        self.register_buffer("k_cache", torch.zeros(forward_bs, self.cache_size, self.num_stack_heads, self.head_dim))
        self.register_buffer("action_cache", torch.zeros(forward_bs, self.cache_size, self.num_stack_heads, 3))
        # Write pointer into the caches. Plain int (not a buffer), so it is
        # neither saved in state_dict nor moved by .to(); reset it manually.
        self.cache_position = 0
        self.enable_cache = False
    def reset_cache(self):
        # Rewinds the write pointer; stale cache contents are overwritten lazily.
        self.cache_position = 0
    def _vectorized_update(
        self,
        stack: torch.Tensor,
        mask: torch.Tensor,
        actions: torch.Tensor,
        k_values: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Vectorized stack update mechanism applying soft push/pop/no-op operations.
        Implements the differentiable stack operations from the paper:
        - Push: shifts all elements down and places k_values at top
        - Pop: shifts all elements up and removes top
        - No-op: maintains current stack state
        Args:
            stack: Current stack state [batch, seq, num_heads, stack_slots, head_dim]
            mask: Current stack mask [batch, seq, num_heads, stack_slots]
            actions: Action probabilities [batch, seq, num_heads, 3] (push/pop/no-op)
            k_values: New values to push [batch, seq, num_heads, head_dim]
        Returns:
            Tuple of (updated_stack, updated_mask)
        """
        batch_size, seq_len = actions.shape[:2]
        # Expand stack and mask along sequence dimension for parallel processing
        # Only expand if checking against initial state dimensions (4D)
        # NOTE(review): after this expand, every sequence position applies its
        # action to the SAME initial stack rather than to the previous
        # position's result — the update is parallel-in-time, not recurrent.
        # Presumably intentional for speed; confirm against the paper.
        if stack.dim() == 4:
            stack = stack.unsqueeze(1).expand(-1, seq_len, -1, -1, -1)
            mask = mask.unsqueeze(1).expand(-1, seq_len, -1, -1)
        # Generate pushed stack: new value at top, shift others down
        push_stack = torch.cat([
            k_values.unsqueeze(3),  # New value at position 0
            stack[:, :, :, :-1]  # Shift existing elements down
        ], dim=3)
        push_mask = torch.cat([
            torch.ones_like(mask[:, :, :, :1]),
            mask[:, :, :, :-1]
        ], dim=3)
        # Generate popped stack: shift all up, zero at bottom
        pop_stack = torch.cat([
            stack[:, :, :, 1:],
            torch.zeros_like(stack[:, :, :, :1])
        ], dim=3)
        pop_mask = torch.cat([
            mask[:, :, :, 1:],
            torch.zeros_like(mask[:, :, :, :1])
        ], dim=3)
        # Combine operations weighted by action probabilities
        action_weights = actions.unsqueeze(-1).unsqueeze(-1)  # [batch, seq, heads, 3, 1, 1]
        stacks = torch.stack([push_stack, pop_stack, stack], dim=3)  # [batch, seq, heads, 3, slots, dim]
        masks = torch.stack([push_mask, pop_mask, mask], dim=3)  # [batch, seq, heads, 3, slots]
        # Weighted combination of all operations
        new_stack = (stacks * action_weights).sum(dim=3)
        new_mask = (masks * action_weights.squeeze(-1)).sum(dim=3)
        return new_stack, new_mask
    def forward(
        self,
        hidden_states: torch.Tensor,
        stack: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Apply differentiable stack operations to hidden states.
        Args:
            hidden_states: Input hidden states [batch, seq, hidden_size]
            stack: Previous stack state [batch, num_heads, stack_slots, head_dim] or None
            mask: Previous stack mask [batch, num_heads, stack_slots] or None
        Returns:
            Tuple of (output_hidden_states, updated_stack, updated_mask)
        """
        batch_size, seq_len, _ = hidden_states.shape
        device = hidden_states.device
        # Initialize stack and mask if not provided
        if stack is None:
            stack = torch.zeros(
                batch_size, self.num_stack_heads, self.stack_slots, self.head_dim,
                device=device, dtype=hidden_states.dtype
            )
        if mask is None:
            # All-zero mask == empty stack; slots become occupied via soft pushes.
            mask = torch.zeros(
                batch_size, self.num_stack_heads, self.stack_slots,
                device=device, dtype=hidden_states.dtype
            )
        # Project to lower dimension for efficiency
        new_hidden_states = self.down_proj(hidden_states)
        # Generate action probabilities: [batch, seq, num_heads, 3]
        # Logits are temperature-scaled by sqrt(head_dim) before the softmax.
        action_logits = self.action_head(new_hidden_states) / math.sqrt(self.head_dim)
        actions = F.softmax(
            action_logits.view(batch_size, seq_len, self.num_stack_heads, 3),
            dim=-1
        )
        # Prepare values to push (split into heads)
        k_values = new_hidden_states.view(batch_size, seq_len, self.num_stack_heads, self.head_dim)
        # Update stack and mask using vectorized operations
        new_stack, new_mask = self._vectorized_update(stack, mask, actions, k_values)
        # Global reading via query-over-stack attention
        gate_scores = self.gate_proj(new_stack).squeeze(-1)  # [batch, seq, heads, slots]
        # Empty slots (mask ~0) are pushed to -1e9 so they receive ~zero attention.
        gate_weights = F.softmax(gate_scores + (1 - new_mask) * -1e9, dim=-1)
        # Weighted sum over stack slots
        memory_output = (new_stack * gate_weights.unsqueeze(-1)).sum(dim=3)
        memory_output = memory_output.view(batch_size, seq_len, -1)
        memory_output = self.up_proj(memory_output)
        # Residual Connection
        output = memory_output * self.res_weight + hidden_states
        # Update Cache Logic
        if self.enable_cache:
            # Detach so the cached history does not keep the autograd graph alive.
            self._update_cache(k_values.detach(), actions.detach())
        # Carry only the final position's stack state forward.
        return output, new_stack[:, -1], new_mask[:, -1]
    def _update_cache(self, k_values: torch.Tensor, actions: torch.Tensor):
        # Append the chunk's per-token stack inputs into the rolling caches.
        seq_len = k_values.shape[1]
        if self.cache_position + seq_len <= self.cache_size:
            # Assumes standard batch processing for inference (usually batch_size=1)
            self.k_cache[:, self.cache_position:self.cache_position+seq_len] = k_values
            self.action_cache[:, self.cache_position:self.cache_position+seq_len] = actions
            self.cache_position += seq_len
        else:
            # NOTE(review): on overflow the cache is rewound and the current
            # chunk is silently dropped (never written) — confirm intended.
            self.reset_cache()
    def step(self, hidden_state: torch.Tensor, stack: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Single-token decoding step: replays cached history plus the current
        token through the stack update and returns the new carried state.

        Args:
            hidden_state: Current token hidden state [batch, hidden_size] (no seq dim).
            stack: Carried stack state [batch, num_heads, stack_slots, head_dim].
            mask: Carried stack mask [batch, num_heads, stack_slots].
        Returns:
            Tuple of (output_hidden_state, updated_stack, updated_mask).
        """
        if not self.enable_cache:
            # Without the cache, fall back to the full forward on a length-1 sequence.
            return self.forward(hidden_state.unsqueeze(1), stack, mask)
        batch_size = hidden_state.shape[0]
        # Compute features for current token
        new_hidden_states = self.down_proj(hidden_state)
        action_logits = self.action_head(new_hidden_states) / math.sqrt(self.head_dim)
        current_actions = F.softmax(
            action_logits.view(batch_size, 1, self.num_stack_heads, 3),
            dim=-1
        )
        current_k = new_hidden_states.view(batch_size, 1, self.num_stack_heads, self.head_dim)
        # Reconstruct History
        if self.cache_position > 0:
            cached_k = self.k_cache[:, :self.cache_position]
            cached_actions = self.action_cache[:, :self.cache_position]
            k_values = torch.cat([cached_k, current_k], dim=1)
            actions = torch.cat([cached_actions, current_actions], dim=1)
        else:
            k_values = current_k
            actions = current_actions
        # Dimension Fix: Pass sequences directly without unsqueeze(0)
        # k_values is [batch, seq_len_total, heads, dim]
        # actions is [batch, seq_len_total, heads, 3]
        new_stack_seq, new_mask_seq = self._vectorized_update(
            stack,  # Initial stack [batch, heads, slots, dim]
            mask,
            actions,
            k_values
        )
        # Extract last step
        current_stack = new_stack_seq[:, -1]
        current_mask = new_mask_seq[:, -1]
        # Read from the final stack state (same gating as forward, minus seq dim).
        gate_scores = self.gate_proj(current_stack).squeeze(-1)
        gate_weights = F.softmax(gate_scores + (1 - current_mask) * -1e9, dim=-1)
        memory_output = (current_stack * gate_weights.unsqueeze(-1)).sum(dim=2)
        memory_output = memory_output.view(batch_size, -1)
        memory_output_proj = self.up_proj(memory_output)
        self._update_cache(current_k, current_actions)
        return (
            memory_output_proj * self.res_weight + hidden_state,
            current_stack,
            current_mask
        )
# ==================== ROTARY EMBEDDING ====================
class NeoLLMRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) with optional scaled variants.

    Selects the inverse-frequency initializer from `config.rope_scaling` when a
    known `rope_type` is present, otherwise falls back to the default RoPE.
    """
    inv_freq: torch.Tensor  # fix linting for `register_buffer`
    def __init__(self, config: NeoLLMConfig, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        # Determine rope_type from rope_scaling config
        self.rope_type = "default"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None and isinstance(config.rope_scaling, dict):
            # Accept both the new "rope_type" and the legacy "type" key.
            rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
            if rope_type and rope_type in ROPE_INIT_FUNCTIONS:
                self.rope_type = rope_type
        # Initialize rope parameters
        rope_init_fn = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Keep a pristine copy so dynamic RoPE variants can restore it.
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
    @staticmethod
    def compute_default_rope_parameters(
        config: NeoLLMConfig = None,
        device: Optional["torch.device"] = None,
        seq_len: int = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config: The model configuration.
            device: The device to use for initialization of the inverse frequencies.
            seq_len: The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (torch.Tensor, float), containing the inverse frequencies for the RoPE
            embeddings and the post-processing scaling factor applied to the computed cos/sin.
        """
        base = config.rope_theta
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        # Only rotate the first `partial_rotary_factor` fraction of head channels.
        partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
        dim = int(dim * partial_rotary_factor)
        attention_scaling = 1.0  # Unused in default RoPE
        # Compute the inverse frequencies
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_scaling
    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        # Ensure shape [B, S]
        if position_ids.dim() == 1:
            position_ids = position_ids.unsqueeze(0)  # [1, S]
        B = x.shape[0]
        if position_ids.shape[0] != B:
            # Replicate the same positions for every batch element (correct semantics)
            position_ids = position_ids.expand(B, -1)  # [B, S]
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        # inv_freq in float32 on the right device (avoids an expand with stride 0)
        inv_freq = self.inv_freq.to(device=x.device, dtype=torch.float32)  # [d/2]
        with torch.autocast(device_type=device_type, enabled=False):  # force float32
            # Θ[b,s,i] = position_ids[b,s] * inv_freq[i]
            freqs = position_ids.to(dtype=torch.float32).unsqueeze(-1) * inv_freq.unsqueeze(0).unsqueeze(0)
            # freqs: [B, S, d/2]
            emb = torch.cat((freqs, freqs), dim=-1)  # [B, S, d]
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling
        # Cast back to the activation dtype for downstream matmuls.
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(x):
    """Rotates half the hidden dims of the input: (x1, x2) -> (-x2, x1)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Supports partial rotary: only the first `cos.shape[-1]` channels of each
    head are rotated; any remaining channels pass through unchanged.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    # Channels beyond rotary_dim are left untouched (partial rotary support).
    rotary_dim = cos.shape[-1]

    def _rotate(t):
        # Inlined rotate-half: (t1, t2) -> (-t2, t1) along the last dim.
        half = t.shape[-1] // 2
        return torch.cat((-t[..., half:], t[..., :half]), dim=-1)

    def _embed(t):
        rot, passthrough = t[..., :rotary_dim], t[..., rotary_dim:]
        rotated = (rot * cos) + (_rotate(rot) * sin)
        return torch.cat([rotated, passthrough], dim=-1)

    return _embed(q), _embed(k)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep): expands
    (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_attention_heads, seqlen, head_dim) for grouped-query attention.
    """
    if n_rep == 1:
        # Nothing to repeat — return the input unchanged (no copy).
        return hidden_states
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    # Insert a repeat axis, broadcast it, then fold it into the head axis.
    expanded = hidden_states[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Reference (eager) scaled-dot-product attention with GQA support.

    Repeats KV heads to match query heads, computes masked softmax attention
    in float32, and returns (attn_output [B, S, H, D], attn_weights).
    """
    # Expand KV heads so each query head has a matching key/value head.
    keys = repeat_kv(key, module.num_key_value_groups)
    values = repeat_kv(value, module.num_key_value_groups)
    scores = torch.matmul(query, keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Additive mask, truncated to the key length (handles cached decoding).
        scores = scores + attention_mask[:, :, :, : keys.shape[-2]]
    # Softmax in float32 for stability, then cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)
    context = torch.matmul(probs, values)
    # [B, H, S, D] -> [B, S, H, D] for the downstream reshape/projection.
    context = context.transpose(1, 2).contiguous()
    return context, probs
class NeoLLMAttention(nn.Module):
    """
    Multi-headed attention with FANformer integration, SeeDNorm for Q/K normalization,
    ResFormer feature residual connections, and Learnable Multipliers for enhanced
    information flow and scale adaptation.
    ResFormer enhancement: Applies learnable feature residual connections from first layer
    BEFORE QKV projections: H'_fan_n = λ_1 * H_fan_1 + λ_2 * H_fan_n
    Learnable Multipliers placement (from "Learnable Multipliers" paper Appendix C):
    - Q projection: row multipliers only (enables per-head attention scaling in GQA)
    - K, V projections: no multipliers (avoids redundancy with Q multipliers)
    - Output projection: row + column multipliers (maximally expressive without symmetries)
    """
    def __init__(self, config: NeoLLMConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
        # FANformer integration: FAN layer before QKV projections
        self.fan_layer = FANLayer(
            hidden_size=config.hidden_size,
            fan_ratio=getattr(config, 'fan_ratio', 0.125)
        )
        # Calculate the output dimension after FAN transformation
        fan_output_dim = config.hidden_size + int(config.hidden_size * getattr(config, 'fan_ratio', 0.125))
        # Q projection with row multipliers (per-head scaling capability).
        # Output is 2x head_dim per head: one half becomes the query, the other
        # half a per-channel sigmoid gate applied to the attention output.
        self.q_proj = LinearWithMultipliers(
            fan_output_dim,
            config.num_attention_heads * self.head_dim * 2,
            bias=config.attention_bias,
            use_row_multiplier=True,
            use_column_multiplier=False
        )
        # K, V projections without multipliers (avoids Q-K symmetry)
        self.k_proj = nn.Linear(
            fan_output_dim, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            fan_output_dim, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        # Output projection with row + column multipliers (maximally expressive)
        self.o_proj = LinearWithMultipliers(
            config.num_attention_heads * self.head_dim,
            config.hidden_size,
            bias=config.attention_bias,
            use_row_multiplier=True,
            use_column_multiplier=True
        )
        # SeeDNorm for Q/K normalization (replaces RMSNorm)
        self.q_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
        # Dropout for attention output
        self.dropout = nn.Dropout(config.dropout_rate)
        # ResFormer: learnable feature residual parameters (initialized to 0.5)
        self.lambda_1 = nn.Parameter(torch.tensor(0.5))  # Weight for H_fan_1
        self.lambda_2 = nn.Parameter(torch.tensor(0.5))  # Weight for H_fan_n
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        first_layer_fan: Optional[torch.Tensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Forward pass with ResFormer feature residual connections.
        Args:
            hidden_states: Current layer input [batch, seq, hidden_size]
            position_embeddings: Tuple of (cos, sin) for RoPE
            attention_mask: Causal attention mask
            first_layer_fan: First layer FAN features (for ResFormer)
        Returns:
            Tuple of (attn_output, attn_weights, current_layer_fan)
        """
        input_shape = hidden_states.shape[:-1]
        # Apply FANformer transformation
        hidden_states_fan = self.fan_layer(hidden_states)
        # ResFormer: Apply feature residual connection BEFORE projections
        if first_layer_fan is not None:
            hidden_states_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * hidden_states_fan
        # Store current FAN features for ResFormer
        # (cloned so later in-place ops cannot alias the stored features)
        current_layer_fan = hidden_states_fan.clone()
        hidden_shape = (*input_shape, -1, self.head_dim)
        # Q projection with learnable row multipliers; chunk splits the doubled
        # head channels into (query, gate) halves.
        query_states, gate = torch.chunk(
            self.q_proj(hidden_states_fan).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
        )
        # Flatten the gate back to [batch, seq, num_heads * head_dim].
        gate = gate.reshape(*input_shape, -1)
        # Apply SeeDNorm to Q and K
        query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states_fan).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states_fan).view(hidden_shape).transpose(1, 2)
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        # Dispatch to the configured attention backend (eager / sdpa / flash ...).
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        # Gated attention: sigmoid gate from q_proj modulates the output channels.
        attn_output = attn_output * torch.sigmoid(gate)
        # Output projection with learnable row + column multipliers
        attn_output = self.o_proj(attn_output)
        attn_output = self.dropout(attn_output)
        return attn_output, attn_weights, current_layer_fan
class PolyNorm(torch.nn.Module):
    """Polynomial-norm activation: a learnable mix of RMS-normalized x, x^2 and
    x^3 plus a learnable bias. The three mixing weights start at 1/3 each."""
    def __init__(self, eps=1e-6):
        super().__init__()
        # weight[0] scales norm(x^3), weight[1] norm(x^2), weight[2] norm(x).
        self.weight = torch.nn.Parameter(torch.ones(3) / 3)
        self.bias = torch.nn.Parameter(torch.zeros(1))
        self.eps = eps

    def _norm(self, x):
        # RMS normalization along the last dimension.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        w_cubic, w_square, w_linear = self.weight[0], self.weight[1], self.weight[2]
        mixed = (
            w_cubic * self._norm(x ** 3)
            + w_square * self._norm(x ** 2)
            + w_linear * self._norm(x)
        )
        return mixed + self.bias
class NeoLLMMLP(nn.Module):
    """
    Gated MLP with FANformer integration for featural periodicity modeling and
    Learnable Multipliers for adaptive scale control.

    The FAN layer captures periodicities in the feature (semantic/embedding)
    space, complementary to the relational periodicities captured by attention.
    Works in conjunction with ResFormer for comprehensive information flow.

    Learnable Multipliers placement (from "Learnable Multipliers" paper Appendix C):
    - gate_proj: row multipliers only (controls gating mechanism scale)
    - up_proj: no multipliers (avoids redundancy with down_proj)
    - down_proj: row + column multipliers (maximally expressive output scaling)
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # FFN-specific FAN ratio (half of attention's default fan_ratio).
        ffn_fan_ratio = getattr(config, 'fan_ratio_ffn', 0.0625)
        # FANformer integration for featural-space periodicity.
        self.fan_layer = FANLayer(
            hidden_size=config.hidden_size,
            fan_ratio=ffn_fan_ratio
        )
        # Width of the FAN output feeding the gate/up projections.
        fan_output_dim = config.hidden_size + int(config.hidden_size * ffn_fan_ratio)
        # gate_proj: row multipliers for gating scale control.
        self.gate_proj = LinearWithMultipliers(
            fan_output_dim,
            self.intermediate_size,
            bias=False,
            use_row_multiplier=True,
            use_column_multiplier=False
        )
        # up_proj: plain linear — no multipliers (avoids redundancy).
        self.up_proj = nn.Linear(fan_output_dim, self.intermediate_size, bias=False)
        # down_proj: row + column multipliers (maximally expressive).
        self.down_proj = LinearWithMultipliers(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            use_row_multiplier=True,
            use_column_multiplier=True
        )
        # PolyNorm replaces the usual SiLU on the gate branch.
        self.act_fn = PolyNorm()
        # Dropout on the gated intermediate activations.
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x):
        # FAN transformation feeds both the gate and the up branch.
        fan_features = self.fan_layer(x)
        gated = self.act_fn(self.gate_proj(fan_features)) * self.up_proj(fan_features)
        gated = self.dropout(gated)
        return self.down_proj(gated)
class NeoLLMDecoderLayer(GradientCheckpointingLayer):
    """
    Decoder layer with standard residual connections and optional StackMemory.
    Architecture (Updated Flow):
    1. Optional: StackMemory module (Pre-processing context injection)
    2. Pre-norm (SeeDNorm) → LNS scaling → Self-Attention with ResFormer and Learnable Multipliers
    3. Standard Residual Connection
    4. GPAS activation scaling
    5. Pre-norm (SeeDNorm) → LNS scaling → MLP with FANformer and Learnable Multipliers
    6. Standard Residual Connection
    7. GPAS activation scaling

    Statefulness: after each call, ``self.current_layer_fan`` holds this
    layer's FAN features; NeoLLMModel reads it to implement ResFormer, so a
    layer instance is not safe for concurrent forward passes.
    """
    def __init__(self, config: NeoLLMConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx
        # Full attention with learnable multipliers
        self.self_attn = NeoLLMAttention(config, layer_idx)
        # MLP with FANformer integration and learnable multipliers
        self.mlp = NeoLLMMLP(config)
        # SeeDNorm for input and post-attention normalization
        self.input_layernorm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
        # LNS (LayerNorm Scaling) - applies 1/√ℓ scaling; parameterized by depth (layer_idx)
        self.lns_attn = LNS(layer_idx)
        self.lns_mlp = LNS(layer_idx)
        # GPAS (Gradient-Preserving Activation Scaling)
        self.gpas_attn = GPAS(config.hidden_size)
        self.gpas_mlp = GPAS(config.hidden_size)
        # StackMemory: Differentiable hidden state stack (opt-in via config.use_stack)
        self.use_stack = getattr(config, 'use_stack', False)
        if self.use_stack:
            self.stack_memory = StackMemory(config)
        # ResFormer: storage for current layer's FAN features (read externally
        # by NeoLLMModel.forward after the layer call).
        self.current_layer_fan = None
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        first_layer_fan: Optional[torch.Tensor] = None,
        stack_state: Optional[torch.Tensor] = None,
        stack_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Forward pass with ResFormer and optional StackMemory.
        Args:
            hidden_states: Current layer input [batch, seq, hidden_size]
            position_embeddings: Tuple of (cos, sin) for RoPE
            attention_mask: Causal attention mask
            first_layer_fan: First layer FAN features (for ResFormer)
            stack_state: StackMemory state (optional)
            stack_mask: StackMemory mask (optional)
            output_attentions: Accepted for interface compatibility.
                NOTE(review): this flag is not forwarded to self_attn below;
                attention weights are returned by self_attn unconditionally —
                confirm whether the flag should be propagated.
        Returns:
            Tuple of (hidden_states, attn_weights, stack_state, stack_mask)
        """
        # ============================================================
        # 1. Stack Memory Module (MOVED TO START)
        # ============================================================
        # We process memory first so the Attention layer can "see" the
        # retrieved context. This eliminates the 1-layer lag.
        if self.use_stack:
            hidden_states, stack_state, stack_mask = self.stack_memory(
                hidden_states, stack_state, stack_mask
            )
        # ============================================================
        # 2. Attention Block with Standard Residual Connection
        # ============================================================
        residual = hidden_states
        # Apply SeeDNorm normalization (pre-norm)
        hidden_states = self.input_layernorm(hidden_states)
        # Apply LNS scaling after normalization
        hidden_states = self.lns_attn(hidden_states)
        # Self Attention with ResFormer; also captures this layer's FAN
        # features on self so the outer model can propagate layer 1's.
        attn_output, attn_weights, self.current_layer_fan = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            first_layer_fan=first_layer_fan,
            **kwargs,
        )
        # Standard Residual Connection
        hidden_states = residual + attn_output
        # Apply GPAS after residual connection
        hidden_states = self.gpas_attn(hidden_states)
        # ============================================================
        # 3. MLP Block with Standard Residual Connection
        # ============================================================
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        # Apply LNS scaling after normalization
        hidden_states = self.lns_mlp(hidden_states)
        # MLP with FANformer
        mlp_output = self.mlp(hidden_states)
        # Standard Residual Connection
        hidden_states = residual + mlp_output
        # Apply GPAS after residual connection
        hidden_states = self.gpas_mlp(hidden_states)
        # Return tuple matching the expected signature; stack tensors are
        # only meaningful when the stack is enabled.
        if self.use_stack:
            return (hidden_states, attn_weights, stack_state, stack_mask)
        else:
            return (hidden_states, attn_weights, None, None)
class NeoLLMPreTrainedModel(PreTrainedModel):
    """
    Base class for NeoLLM models providing custom weight initialization.

    Covers the custom submodules on top of the standard transformers init:
    - NeoLLMAttention: ResFormer lambda parameters → 0.5
    - GPAS: alpha → 0.0
    - ScalarMultiplier / VectorMultiplier: multiplier → 1.0
    - StackMemory: projections ~ N(0, initializer_range), biases → 0,
      residual weight → 1.0
    """
    config: NeoLLMConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["NeoLLMDecoderLayer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _is_stateful = True

    def _init_weights(self, module):
        """Initialize weights for all custom modules in NeoLLM."""
        super()._init_weights(module)
        if isinstance(module, NeoLLMAttention):
            # ResFormer value-residual mixing coefficients start balanced.
            for lam_name in ('lambda_1', 'lambda_2'):
                if hasattr(module, lam_name):
                    getattr(module, lam_name).data.fill_(0.5)
        elif isinstance(module, GPAS):
            # Zero alpha means GPAS starts as an identity mapping.
            module.alpha.data.fill_(0.0)
        elif isinstance(module, (ScalarMultiplier, VectorMultiplier)):
            # Multipliers start at 1.0 so initial scales are untouched.
            if hasattr(module, 'multiplier'):
                module.multiplier.data.fill_(1.0)
        elif isinstance(module, StackMemory):
            std = getattr(self.config, 'initializer_range', 0.02)
            # Order of normal_() calls matches the original init so RNG
            # consumption (and thus reproducibility) is unchanged.
            for proj_name in ('down_proj', 'up_proj'):
                if hasattr(module, proj_name):
                    getattr(module, proj_name).weight.data.normal_(mean=0.0, std=std)
            if hasattr(module, 'action_head'):
                module.action_head.weight.data.normal_(mean=0.0, std=std)
                if module.action_head.bias is not None:
                    module.action_head.bias.data.zero_()
            if hasattr(module, 'gate_proj'):
                module.gate_proj.weight.data.normal_(mean=0.0, std=std)
            if hasattr(module, 'res_weight'):
                module.res_weight.data.fill_(1.0)
class NeoLLMModel(NeoLLMPreTrainedModel):
    """
    NeoLLM base model with transformer decoder architecture.
    Uses ResFormer for first-layer feature propagation with standard residual connections
    and optional StackMemory for hierarchical pattern modeling.

    Caching: no KV cache is maintained here — ``past_key_values`` is only used
    to decide when to reset the StackMemory caches, and the returned
    ``past_key_values`` is always None.

    Note on embeddings and weight tying: This model uses weight tying between
    embed_tokens and lm_head (shared weights). Following "Learnable Multipliers"
    paper analysis, we do NOT add multipliers to embeddings because:
    1. Weight tying creates conflicting gradient paths
    2. The paper explicitly warns against multipliers in lm_head
    3. Compensating mechanisms provide scale adaptation immediately after embedding
    """
    def __init__(self, config: NeoLLMConfig):
        super().__init__(config)
        # Standard embedding without learnable multipliers
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
        # Each layer creates its own components (no shared parameters)
        self.layers = nn.ModuleList(
            [NeoLLMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # SeeDNorm for final output normalization (replaces RMSNorm)
        self.norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = NeoLLMRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        # Configuration
        self.use_stack = getattr(config, 'use_stack', False)
        # ResFormer: storage for first layer's FAN features.
        # NOTE: mutated during forward, so a single model instance is not
        # safe for concurrent forward passes.
        self.first_layer_fan = None
        # Initialize weights and apply final processing
        self.post_init()
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        # NOTE(review): TransformersKwargs is not in the visible import block
        # at the top of this file — confirm it is imported elsewhere.
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        # Resolve output options from config when not given explicitly.
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        output_attentions = (
            output_attentions if output_attentions is not None
            else self.config.output_attentions
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Exactly one of input_ids / inputs_embeds must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        if position_ids is None:
            # Default positions start at 0 — consistent with the no-KV-cache design.
            position_ids = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)
        # NOTE(review): squeeze(0) assumes position_ids has a leading batch
        # dim of 1 (positions shared across the batch) — verify for batched
        # generation with left padding.
        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=position_ids.squeeze(0),
            past_key_values=None,
            position_ids=position_ids,
        )
        hidden_states = inputs_embeds
        # No KV cache is maintained; this stays None and is returned as-is.
        next_decoder_cache = None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        # Create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        # ResFormer: clear the captured first-layer FAN features per pass.
        self.first_layer_fan = None
        # Initialize Stack states (always None at start of forward, rebuilt via cache step or vertical flow)
        stack_state = None
        stack_mask = None
        # Propagate use_cache and reset if starting a new sequence
        if self.use_stack:
            for layer in self.layers:
                if hasattr(layer, 'stack_memory'):
                    layer.stack_memory.enable_cache = use_cache if use_cache is not None else False
                    # past_key_values is None ⇒ fresh sequence ⇒ drop stack cache.
                    if past_key_values is None:
                        layer.stack_memory.reset_cache()
        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            layer_outputs = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=causal_mask,
                first_layer_fan=self.first_layer_fan,
                stack_state=stack_state,
                stack_mask=stack_mask,
                output_attentions=output_attentions,
                **kwargs,
            )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)
            if self.use_stack:
                # Vertical memory logic:
                # The layer returns updated stack for the next layer to use (Vertical passing)
                # But we do NOT persist it temporally here. The Module's internal cache handles temporal.
                stack_state = layer_outputs[2]
                stack_mask = layer_outputs[3]
            # ResFormer: capture H_fan_1 from the first layer
            # Dynamically capture for the current pass
            if self.first_layer_fan is None and hasattr(decoder_layer, 'current_layer_fan'):
                self.first_layer_fan = decoder_layer.current_layer_fan
        # Apply SeeDNorm for final normalization
        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        if not return_dict:
            return tuple(v for v in [hidden_states, next_decoder_cache, all_hidden_states, all_attentions] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
@torch.compiler.disable
def compute_cce_loss(hidden_states, labels, lm_head_weight, lm_head_bias=None, pad_token_id=None):
    """
    Cut-cross-entropy loss, excluded from torch.compile.

    Moves the labels onto the hidden-state device and remaps pad tokens to the
    ignore index (-100) before delegating to ``linear_cross_entropy``, which
    fuses the lm_head projection with the shifted cross-entropy.

    Args:
        hidden_states: final hidden states from the model.
        labels: integer target token ids.
        lm_head_weight: output-projection weight matrix.
        lm_head_bias: optional output-projection bias.
        pad_token_id: token id to exclude from the loss, if any.
    Returns:
        Scalar mean loss tensor.
    """
    target = labels.to(hidden_states.device)
    if pad_token_id is not None:
        # Padding positions become the ignore index so they never contribute.
        target = target.masked_fill(target == pad_token_id, -100)
    return linear_cross_entropy(
        hidden_states,
        lm_head_weight,
        target,
        bias=lm_head_bias,
        shift=1,
        impl="cce_kahan_full_c",
        reduction="mean"
    )
class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
    """
    Causal Language Model with NeoLLM architecture.
    Supports ResFormer with standard residuals and optional StackMemory.

    Training uses a fused cut-cross-entropy loss (no logits materialized);
    inference computes logits through lm_head as usual.

    Note on LM head: Following "Learnable Multipliers" paper recommendations,
    the output projection (lm_head) does NOT include learnable multipliers.
    """
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    def __init__(self, config):
        super().__init__(config)
        self.model = NeoLLMModel(config)
        self.vocab_size = config.vocab_size
        # LM head without learnable multipliers (standard linear layer);
        # weight is tied to embed_tokens via _tied_weights_keys.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()
    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        # Trim already-processed tokens from input_ids when a cache is passed.
        if past_key_values:
            # NOTE(review): assumes the legacy tuple cache layout
            # (layer → (key, value) with shape [b, heads, seq, dim]) — confirm,
            # since the base model currently returns past_key_values=None.
            past_length = past_key_values[0][0].shape[2]
            # If the prompt is longer than the cache, drop only the cached prefix
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default standard HF behavior: keep only the last token
                remove_prefix_length = input_ids.shape[1] - 1
            input_ids = input_ids[:, remove_prefix_length:]
        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            # (cumulative sum over the mask, padding positions pinned to 1)
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "inputs_embeds": inputs_embeds,
        }
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state
        # CCE Loss computation for training: the fused kernel avoids
        # materializing the [batch, seq, vocab] logits tensor, so logits
        # are intentionally None when labels are provided.
        if labels is not None:
            loss = compute_cce_loss(
                hidden_states,
                labels,
                self.lm_head.weight,
                getattr(self.lm_head, 'bias', None),
                self.config.pad_token_id
            )
            logits = None
        else:
            # Inference mode - compute logits normally.
            # logits_to_keep == 0 yields slice(0, None), i.e. all positions.
            slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
            logits = self.lm_head(hidden_states[:, slice_indices, :])
            loss = None
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# ==================== AUTOMODEL REGISTRATION ====================
# Public API of this module (used by star-imports, e.g. the dynamic module
# loader for trust_remote_code checkpoints).
__all__ = [
    "NeoLLMForCausalLM",
    "NeoLLMModel",
    "NeoLLMPreTrainedModel",
    "NeoLLMConfig",
    "FANLayer",
    "SeeDNorm",
    "ScalarMultiplier",
    "VectorMultiplier",
    "LinearWithMultipliers",
    "StackMemory",
]
# Register the configuration and model for AutoClass support so that
# Auto*.from_pretrained(...) can resolve the "neollm" model_type.
# NOTE(review): AutoConfig/AutoModel/AutoModelForCausalLM are not in the
# visible import block at the top of this file — confirm they are imported
# elsewhere before these statements run.
AutoConfig.register("neollm", NeoLLMConfig)
AutoModel.register(NeoLLMConfig, NeoLLMModel)
AutoModelForCausalLM.register(NeoLLMConfig, NeoLLMForCausalLM)