| |
| """ |
| NeoLLM Model with FANformer Integration and Dropout Regularization |
| Updated to include Fourier Analysis Network (FAN) layer for effective periodicity modeling |
| and dropout regularization at strategic locations |
| """ |
|
|
| import math |
| from typing import Any, Callable, Optional, Union |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
| from cut_cross_entropy import linear_cross_entropy |
|
|
| from transformers.activations import ACT2FN |
| from transformers.generation import GenerationMixin |
| from transformers.masking_utils import create_causal_mask |
| from transformers.modeling_flash_attention_utils import FlashAttentionKwargs |
| from transformers.modeling_layers import GradientCheckpointingLayer |
| from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
| from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
| from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel |
| from transformers.processing_utils import Unpack |
| from transformers.utils import TransformersKwargs, logging |
| from transformers.utils.generic import check_model_inputs |
| from transformers.utils.import_utils import ( |
| is_causal_conv1d_available, |
| is_flash_linear_attention_available, |
| ) |
| from configuration_neollm import NeoLLMConfig |
|
|
| from transformers import AutoConfig, AutoModel, AutoModelForCausalLM |
|
|
# Optional fused kernels. When a library is not installed the corresponding
# names are bound to None, and consumers in this module check for None and
# use a pure-PyTorch path instead.
if is_causal_conv1d_available():
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
    causal_conv1d_update, causal_conv1d_fn = None, None


# flash-linear-attention provides the fused gated RMSNorm and the chunked /
# recurrent gated delta rule kernels.
if is_flash_linear_attention_available():
    from fla.modules import FusedRMSNormGated
    from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
else:
    chunk_gated_delta_rule, fused_recurrent_gated_delta_rule = None, None
    FusedRMSNormGated = None


# Module-level logger (transformers logging wrapper).
logger = logging.get_logger(__name__)
|
|
|
|
class FANLayer(nn.Module):
    """
    Fourier Analysis Network (FAN) layer for effective periodicity modeling.

    From "FANformer: Improving Large Language Models Through Effective
    Periodicity Modeling":

        FANLayer'(X) = [cos(W_p X) || sin(W_p X) || (W_pbar X + B_pbar)]

    where ``||`` denotes concatenation along the feature axis. This is the
    modified variant (FANLayer') without an activation on the non-periodic
    branch.

    With ``periodic_dim = int(hidden_size * fan_ratio)`` and
    ``non_periodic_dim = hidden_size - periodic_dim``, the output width is
    ``2 * periodic_dim + non_periodic_dim == hidden_size + periodic_dim``,
    available as :attr:`output_dim`.

    Args:
        hidden_size: Size of the last (feature) dimension of the input.
        fan_ratio: Fraction of ``hidden_size`` used for the periodic
            (cos/sin) branch. Must lie in ``[0, 1]``.

    Raises:
        ValueError: If ``fan_ratio`` is outside ``[0, 1]`` (or NaN).
    """

    def __init__(self, hidden_size: int, fan_ratio: float = 0.25):
        super().__init__()
        # Reject ratios that would produce a negative or meaningless split
        # instead of failing later inside nn.Linear with a cryptic message.
        if not 0.0 <= fan_ratio <= 1.0:
            raise ValueError(f"fan_ratio must be in [0, 1], got {fan_ratio!r}")

        self.hidden_size = hidden_size
        self.fan_ratio = fan_ratio

        # Split the feature budget between the periodic and non-periodic branches.
        self.periodic_dim = int(hidden_size * fan_ratio)
        self.non_periodic_dim = hidden_size - self.periodic_dim
        # cos and sin each contribute periodic_dim features.
        self.output_dim = 2 * self.periodic_dim + self.non_periodic_dim

        # W_p has no bias (it feeds cos/sin); W_pbar carries the bias B_pbar.
        self.Wp = nn.Linear(hidden_size, self.periodic_dim, bias=False)
        self.Wp_bar = nn.Linear(hidden_size, self.non_periodic_dim, bias=True)

        self._init_weights()

    def _init_weights(self):
        """Initialize both projections with N(0, 0.02) weights and a zero bias."""
        nn.init.normal_(self.Wp.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.Wp_bar.weight, mean=0.0, std=0.02)
        if self.Wp_bar.bias is not None:
            nn.init.zeros_(self.Wp_bar.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the FAN transformation along the last dimension.

        Args:
            x: Tensor whose last dimension is ``hidden_size``.

        Returns:
            Tensor with the same leading dimensions and last dimension
            ``output_dim``, laid out as ``[cos | sin | non-periodic]``.
        """
        x_periodic = self.Wp(x)
        cos_component = torch.cos(x_periodic)
        sin_component = torch.sin(x_periodic)
        x_non_periodic = self.Wp_bar(x)
        return torch.cat([cos_component, sin_component, x_non_periodic], dim=-1)
|
|
|
|
class LNS(nn.Module):
    """
    LayerNorm Scaling (LNS): multiply normalized activations by 1/sqrt(l).

    From "The Curse of Depth in Large Language Models":
        h^(l) = LayerNorm(h^(l)) * (1 / sqrt(l))

    ``l`` is the 1-based depth, derived from the 0-based ``layer_idx`` and
    clamped to at least 1. The aim is to damp variance growth in deeper layers.

    Attributes:
        layer_idx: The 1-based, clamped depth actually used.
        scale: The constant multiplier ``1 / sqrt(layer_idx)``.
    """

    def __init__(self, layer_idx: int):
        super().__init__()
        depth = layer_idx + 1
        if depth < 1:
            # Never divide by sqrt of anything below 1.
            depth = 1
        self.layer_idx = depth
        self.scale = 1.0 / math.sqrt(depth)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` multiplied by the fixed depth scale."""
        scaled = x * self.scale
        return scaled
|
|
|
|
class GPAS(nn.Module):
    """
    Gradient-Preserving Activation Scaling (GPAS).

    Computes ``x - SiLU(alpha) * stop_grad(x)``. In the forward pass this
    scales activations by ``1 - SiLU(alpha)``. The subtracted term uses a
    detached copy of ``x``, so the gradient with respect to ``x`` is the
    identity, while ``alpha`` still receives gradients.

    ``alpha`` starts at 0, so SiLU(0) = 0 and the module begins as an identity.
    In NeoLLMDecoderLayer it is applied to the hidden states right after each
    residual addition.

    Args:
        d_model: Model width. It is stored only; ``alpha`` is a single scalar.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        self.alpha = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shrink = F.silu(self.alpha)
        frozen = x.detach()  # stop-gradient branch
        return x - shrink * frozen
|
|
|
|
class NeoLLMRMSNormGated(nn.Module):
    """
    Pure-PyTorch gated RMSNorm: ``weight * x / rms(x) * silu(gate)``.

    This is the fallback used when the FLA fused kernel is unavailable.
    Normalization is computed in float32 and the result is cast back to the
    input dtype. When ``gate`` is None only the RMSNorm is applied. Extra
    keyword arguments are accepted for signature compatibility and ignored.

    Args:
        hidden_size: Size of the normalized (last) dimension.
        eps: Epsilon added to the mean square before the reciprocal sqrt.
    """

    def __init__(self, hidden_size, eps=1e-6, **kwargs):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states, gate=None):
        """
        Args:
            hidden_states: Tensor to normalize over its last dimension.
            gate: Optional tensor broadcastable to ``hidden_states``. When given,
                the normalized output is multiplied by ``silu(gate)``.

        Returns:
            Tensor with the same shape and dtype as ``hidden_states``.
        """
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = self.weight * hidden_states.to(input_dtype)
        # `gate` defaults to None; previously this line dereferenced it unconditionally.
        if gate is not None:
            hidden_states = hidden_states * F.silu(gate.to(torch.float32))
        return hidden_states.to(input_dtype)
|
|
|
|
class NeoLLMRotaryEmbedding(nn.Module):
    """
    Produce RoPE cos/sin tables for given position ids.

    The inverse frequencies come from ``ROPE_INIT_FUNCTIONS[rope_type]``.
    ``rope_type`` is read from ``config.rope_scaling`` when it is a dict (key
    ``"rope_type"``, falling back to ``"type"``) and is ``"default"`` otherwise.
    The returned tables are multiplied by the init function's attention scaling.
    """

    inv_freq: torch.Tensor  # non-persistent buffer registered in __init__

    def __init__(self, config: NeoLLMConfig, device=None):
        super().__init__()
        scaling_cfg = getattr(config, "rope_scaling", None)
        if isinstance(scaling_cfg, dict):
            self.rope_type = scaling_cfg.get("rope_type", scaling_cfg.get("type"))
        else:
            self.rope_type = "default"

        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Kept so dynamic RoPE updates can restore the original frequencies.
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        """
        Args:
            x: Any tensor; only its device and dtype are used.
            position_ids: Integer positions; indexed as ``position_ids[:, None, :]`` below.

        Returns:
            ``(cos, sin)`` in ``x.dtype``, each with the rotary frequencies duplicated
            along the last dimension.
        """
        batch = position_ids.shape[0]
        freq_col = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        pos_row = position_ids[:, None, :].float()

        dev_type = x.device.type
        if isinstance(dev_type, str) and dev_type != "mps":
            autocast_device = dev_type
        else:
            autocast_device = "cpu"

        # Build the tables in full float32 precision regardless of any outer autocast.
        with torch.autocast(device_type=autocast_device, enabled=False):
            angles = (freq_col.float() @ pos_row.float()).transpose(1, 2)
            doubled = torch.cat((angles, angles), dim=-1)
            cos = doubled.cos() * self.attention_scaling
            sin = doubled.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
|
|
class NeoLLMRMSNorm(nn.Module):
    """
    RMSNorm with a 1-centred gain: ``y = x / rms(x) * (1 + weight)``.

    ``weight`` is initialised to zeros, so the layer starts as a plain RMS
    normalization. It is computed in float32 and cast back to the input dtype.

    Args:
        dim: Size of the normalized (last) dimension.
        eps: Epsilon added to the mean square before the reciprocal sqrt.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        normalized = self._norm(x.float())
        gain = 1.0 + self.weight.float()
        return (normalized * gain).type_as(x)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"
|
|
|
|
def rotate_half(x):
    """
    Rotate the last dimension by half: ``[a, b] -> [-b, a]``.

    ``a`` is the first ``floor(n / 2)`` features and ``b`` is the rest.
    """
    half = x.shape[-1] // 2
    lower = x[..., :half]
    upper = x[..., half:]
    return torch.cat((upper.neg(), lower), dim=-1)
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """
    Apply rotary position embedding to query and key tensors.

    Only the first ``cos.shape[-1]`` features of ``q``/``k`` are rotated (partial
    RoPE). The remaining features are passed through unchanged.
    ``position_ids`` is accepted for API compatibility and unused.

    Args:
        q, k: Query/key tensors.
        cos, sin: RoPE tables. They are unsqueezed at ``unsqueeze_dim`` so they
            broadcast against ``q``/``k``.

    Returns:
        ``(q_embed, k_embed)`` with the same shapes as ``q`` and ``k``.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    rotary_dim = cos.shape[-1]

    def _rotate(t):
        rot_part, pass_part = t[..., :rotary_dim], t[..., rotary_dim:]
        rotated = (rot_part * cos) + (rotate_half(rot_part) * sin)
        return torch.cat([rotated, pass_part], dim=-1)

    return _rotate(q), _rotate(k)
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat each key/value head ``n_rep`` times along dim 1.

    Equivalent to ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``:
    (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). When ``n_rep == 1``
    the input tensor itself is returned.
    """
    bsz, kv_heads, seq, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    widened = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq, dim)
    return widened.reshape(bsz, kv_heads * n_rep, seq, dim)
|
|
|
|
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Reference (non-fused) scaled dot-product attention.

    Key/value heads are expanded to match the query heads using
    ``module.num_key_value_groups``. ``attention_mask`` is an additive mask,
    sliced to the key length. Softmax is computed in float32, then dropout is
    applied when ``module.training``.

    Returns:
        ``(attn_output, attn_weights)``. ``attn_output`` is transposed on dims 1
        and 2 and made contiguous.
    """
    groups = module.num_key_value_groups
    key_states = repeat_kv(key, groups)
    value_states = repeat_kv(value, groups)

    scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        kv_len = key_states.shape[-2]
        scores = scores + attention_mask[:, :, :, :kv_len]

    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)
    attn_output = torch.matmul(probs, value_states).transpose(1, 2).contiguous()
    return attn_output, probs
|
|
|
|
class NeoLLMAttention(nn.Module):
    """
    Gated multi-head self-attention whose Q/K/V projections read FAN-expanded inputs.

    The input goes through a FANLayer. ``q_proj`` then emits, per head, a query
    and an equally sized output gate. Queries and keys are RMS-normalized per
    head and rotated with RoPE. The attention output is multiplied by
    ``sigmoid(gate)`` before ``o_proj`` and dropout.
    """

    def __init__(self, config: NeoLLMConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        fan_ratio = getattr(config, "fan_ratio", 0.25)
        self.fan_layer = FANLayer(hidden_size=config.hidden_size, fan_ratio=fan_ratio)
        # Width produced by FANLayer: hidden_size + int(hidden_size * fan_ratio).
        fan_output_dim = config.hidden_size + int(config.hidden_size * fan_ratio)

        q_width = config.num_attention_heads * self.head_dim
        kv_width = config.num_key_value_heads * self.head_dim
        # Doubled width: query and gate are interleaved per head.
        self.q_proj = nn.Linear(fan_output_dim, q_width * 2, bias=config.attention_bias)
        self.k_proj = nn.Linear(fan_output_dim, kv_width, bias=config.attention_bias)
        self.v_proj = nn.Linear(fan_output_dim, kv_width, bias=config.attention_bias)
        self.o_proj = nn.Linear(q_width, config.hidden_size, bias=config.attention_bias)
        self.q_norm = NeoLLMRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = NeoLLMRMSNorm(self.head_dim, eps=config.rms_norm_eps)

        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            hidden_states: Input whose last dimension is ``hidden_size``.
            position_embeddings: ``(cos, sin)`` RoPE tables.
            attention_mask: Mask forwarded unchanged to the attention backend.

        Returns:
            ``(output, attn_weights)``. ``output`` has the input's leading
            dimensions and a last dimension of ``hidden_size``.
        """
        lead_dims = hidden_states.shape[:-1]
        per_head = (*lead_dims, -1, self.head_dim)

        fan_states = self.fan_layer(hidden_states)

        q_and_gate = self.q_proj(fan_states).view(*lead_dims, -1, self.head_dim * 2)
        query_states, gate = torch.chunk(q_and_gate, 2, dim=-1)
        gate = gate.reshape(*lead_dims, -1)

        query_states = self.q_norm(query_states.view(per_head)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(fan_states).view(per_head)).transpose(1, 2)
        value_states = self.v_proj(fan_states).view(per_head).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        impl_name = self.config._attn_implementation
        if impl_name == "eager":
            attention_interface = eager_attention_forward
        else:
            attention_interface = ALL_ATTENTION_FUNCTIONS[impl_name]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*lead_dims, -1).contiguous()
        gated = attn_output * torch.sigmoid(gate)
        return self.dropout(self.o_proj(gated)), attn_weights
|
|
|
|
def apply_mask_to_padding_states(hidden_states, attention_mask):
    """
    Zero out hidden states at padded positions.

    Masking is applied only when ``attention_mask`` is given and has both more
    than one row (dim 0) and more than one column (dim 1). Otherwise the input
    is returned unchanged. The mask is broadcast over the last dimension and
    the result keeps ``hidden_states.dtype``.
    """
    should_mask = (
        attention_mask is not None
        and attention_mask.shape[1] > 1
        and attention_mask.shape[0] > 1
    )
    if not should_mask:
        return hidden_states
    original_dtype = hidden_states.dtype
    return (hidden_states * attention_mask[:, :, None]).to(original_dtype)
|
|
|
|
# True only when every optional fused kernel (causal-conv1d and the FLA gated
# delta rule ops) was imported; each is None when its library is missing.
is_fast_path_available = all(
    kernel is not None
    for kernel in (
        causal_conv1d_fn,
        causal_conv1d_update,
        chunk_gated_delta_rule,
        fused_recurrent_gated_delta_rule,
    )
)
|
|
|
|
def torch_causal_conv1d_update(
    hidden_states,
    conv_state,
    weight,
    bias=None,
    activation=None,
):
    """
    Pure-PyTorch incremental depthwise causal conv1d step.

    ``hidden_states`` is unpacked as (batch, channels, new_len). ``conv_state``
    holds the previous window along its last dimension and is updated IN PLACE
    with the most recent ``conv_state.shape[-1]`` steps. ``weight`` is
    (channels, kernel). SiLU is always applied; ``activation`` is accepted but
    ignored.

    Returns:
        The last ``new_len`` conv outputs, cast to ``hidden_states.dtype``.
    """
    _, channels, new_len = hidden_states.shape
    window = conv_state.shape[-1]

    history = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype)
    conv_state.copy_(history[:, :, -window:])

    conv_out = F.conv1d(history, weight.unsqueeze(1), bias, padding=0, groups=channels)
    activated = F.silu(conv_out[:, :, -new_len:])
    return activated.to(hidden_states.dtype)
|
|
|
|
def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
    """
    Scale ``x`` to unit L2 norm along ``dim``, i.e. ``x / sqrt(sum(x^2) + eps)``.

    Intended to match the l2norm used by the FLA library. ``eps`` keeps
    all-zero vectors finite, so they stay zero.
    """
    squared_norm = (x * x).sum(dim=dim, keepdim=True)
    reciprocal = 1 / torch.sqrt(squared_norm + eps)
    return x * reciprocal
|
|
|
|
def torch_chunk_gated_delta_rule(
    query,
    key,
    value,
    g,
    beta,
    chunk_size=64,
    initial_state=None,
    output_final_state=False,
    use_qk_l2norm_in_kernel=False,
):
    """
    Chunked pure-PyTorch gated delta rule (fallback for the FLA kernel).

    Inputs are laid out as (batch, seq_len, heads, dim) for query/key/value and
    (batch, seq_len, heads) for ``g`` (log decay) and ``beta`` (write strength).
    This is how NeoLLMGatedDeltaNet calls it. The time axis is right-padded to
    a multiple of ``chunk_size``. Within each chunk the delta-rule updates are
    solved in closed form by forward substitution; the recurrent state is then
    carried across chunks.

    Returns:
        ``(output, final_state)``. ``output`` is (batch, seq_len, heads, v_dim) in
        the query's original dtype. ``final_state`` is
        (batch, heads, k_dim, v_dim) in float32, or None unless
        ``output_final_state``.
    """
    out_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)
    # Move heads ahead of time: (batch, heads, seq_len, ...), computed in float32.
    query, key, value, beta, g = (
        t.transpose(1, 2).contiguous().to(torch.float32) for t in (query, key, value, beta, g)
    )

    batch_size, num_heads, seq_len, k_head_dim = key.shape
    v_head_dim = value.shape[-1]

    # Right-pad the time axis so it splits evenly into chunks.
    pad_len = (chunk_size - seq_len % chunk_size) % chunk_size
    query = F.pad(query, (0, 0, 0, pad_len))
    key = F.pad(key, (0, 0, 0, pad_len))
    value = F.pad(value, (0, 0, 0, pad_len))
    beta = F.pad(beta, (0, pad_len))
    g = F.pad(g, (0, pad_len))
    padded_len = seq_len + pad_len
    scale = 1 / (query.shape[-1] ** 0.5)
    query = query * scale

    v_beta = value * beta.unsqueeze(-1)
    k_beta = key * beta.unsqueeze(-1)

    # Split time into (num_chunks, chunk_size).
    query, key, value, k_beta, v_beta = (
        t.reshape(t.shape[0], t.shape[1], -1, chunk_size, t.shape[-1])
        for t in (query, key, value, k_beta, v_beta)
    )
    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
    upper_with_diag = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0
    )

    # Intra-chunk cumulative decay and the pairwise decay matrix exp(G_i - G_j), j <= i.
    g = g.cumsum(dim=-1)
    decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
    # Forward substitution turns the strictly-lower matrix into (I + A)^-1.
    solve = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(upper_with_diag, 0)
    for row_idx in range(1, chunk_size):
        row = solve[..., row_idx, :row_idx].clone()
        block = solve[..., :row_idx, :row_idx].clone()
        solve[..., row_idx, :row_idx] = row + (row.unsqueeze(-1) * block).sum(-2)
    solve = solve + torch.eye(chunk_size, dtype=solve.dtype, device=solve.device)
    value = solve @ v_beta
    k_cumdecay = solve @ (k_beta * g.exp().unsqueeze(-1))

    state = (
        torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
        if initial_state is None
        else initial_state.to(value)
    )
    core_attn_out = torch.zeros_like(value)
    strictly_upper = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1
    )

    # Sequential pass over chunks, carrying the recurrent state.
    for c in range(0, padded_len // chunk_size):
        q_c, k_c, v_c = query[:, :, c], key[:, :, c], value[:, :, c]
        intra = (q_c @ k_c.transpose(-1, -2) * decay_mask[:, :, c]).masked_fill_(strictly_upper, 0)
        v_prime = (k_cumdecay[:, :, c]) @ state
        v_new = v_c - v_prime
        inter = (q_c * g[:, :, c, :, None].exp()) @ state
        core_attn_out[:, :, c] = inter + intra @ v_new
        state = (
            state * g[:, :, c, -1, None, None].exp()
            + (k_c * (g[:, :, c, -1, None] - g[:, :, c]).exp()[..., None]).transpose(-1, -2) @ v_new
        )

    if not output_final_state:
        state = None
    core_attn_out = core_attn_out.reshape(
        core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1]
    )
    core_attn_out = core_attn_out[:, :, :seq_len]  # drop the padded steps
    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(out_dtype)
    return core_attn_out, state
|
|
|
|
def torch_recurrent_gated_delta_rule(
    query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False
):
    """
    Step-by-step pure-PyTorch gated delta rule (reference / fallback).

    Inputs use the same (batch, seq_len, heads, ...) layout as
    ``torch_chunk_gated_delta_rule``. At each step the state is decayed by
    ``exp(g_t)``, corrected towards ``v_t`` along ``k_t`` with strength
    ``beta_t``, and then read out with the scaled query.

    Returns:
        ``(output, final_state)``. ``output`` is (batch, seq_len, heads, v_dim) in
        the query's original dtype. ``final_state`` is
        (batch, heads, k_dim, v_dim), or None unless ``output_final_state``.
    """
    out_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)
    query, key, value, beta, g = (
        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
    )

    batch_size, num_heads, seq_len, k_head_dim = key.shape
    v_head_dim = value.shape[-1]
    scale = 1 / (query.shape[-1] ** 0.5)
    query = query * scale

    outputs = torch.zeros(batch_size, num_heads, seq_len, v_head_dim).to(value)
    state = (
        torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
        if initial_state is None
        else initial_state.to(value)
    )

    for step in range(seq_len):
        q_t = query[:, :, step]
        k_t = key[:, :, step]
        v_t = value[:, :, step]
        decay = g[:, :, step].exp().unsqueeze(-1).unsqueeze(-1)
        beta_t = beta[:, :, step].unsqueeze(-1)

        state = state * decay
        retrieved = (state * k_t.unsqueeze(-1)).sum(dim=-2)
        correction = (v_t - retrieved) * beta_t
        state = state + k_t.unsqueeze(-1) * correction.unsqueeze(-2)
        outputs[:, :, step] = (state * q_t.unsqueeze(-1)).sum(dim=-2)

    if not output_final_state:
        state = None
    outputs = outputs.transpose(1, 2).contiguous().to(out_dtype)
    return outputs, state
|
|
class NeoLLMGatedDeltaNet(nn.Module):
    """
    Gated DeltaNet linear-attention token mixer fed by a FANLayer.

    Pipeline (see ``forward``):

    1. Padding positions are zeroed.
    2. The input is FAN-expanded, then projected to q/k/v/z and b/a streams.
    3. Q/K/V pass through a depthwise causal conv1d with SiLU.
    4. ``beta = sigmoid(b)`` and ``g = -exp(A_log) * softplus(a + dt_bias)``
       drive the chunked gated delta rule.
    5. The result is normalized by a gated RMSNorm using ``z`` as the gate,
       then ``out_proj`` and dropout are applied.

    Fused FLA / causal-conv1d kernels are used when available; otherwise the
    pure-PyTorch fallbacks defined in this module are used.
    """

    def __init__(self, config: NeoLLMConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_v_heads = config.linear_num_value_heads
        self.num_k_heads = config.linear_num_key_heads
        self.head_k_dim = config.linear_key_head_dim
        self.head_v_dim = config.linear_value_head_dim
        self.key_dim = self.head_k_dim * self.num_k_heads
        self.value_dim = self.head_v_dim * self.num_v_heads

        self.conv_kernel_size = config.linear_conv_kernel_dim
        self.layer_idx = layer_idx
        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]
        self.layer_norm_epsilon = config.rms_norm_eps

        self.fan_layer = FANLayer(
            hidden_size=config.hidden_size,
            fan_ratio=getattr(config, "fan_ratio", 0.25),
        )
        # Width produced by FANLayer: hidden_size + int(hidden_size * fan_ratio).
        fan_output_dim = config.hidden_size + int(config.hidden_size * getattr(config, "fan_ratio", 0.25))

        # Depthwise (groups == channels) causal conv over concatenated Q, K, V channels.
        # padding = kernel - 1 on both sides; forward keeps only the first seq_len outputs.
        self.conv_dim = self.key_dim * 2 + self.value_dim
        self.conv1d = nn.Conv1d(
            in_channels=self.conv_dim,
            out_channels=self.conv_dim,
            bias=False,
            kernel_size=self.conv_kernel_size,
            groups=self.conv_dim,
            padding=self.conv_kernel_size - 1,
        )

        projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2
        projection_size_ba = self.num_v_heads * 2
        self.in_proj_qkvz = nn.Linear(fan_output_dim, projection_size_qkvz, bias=False)
        self.in_proj_ba = nn.Linear(fan_output_dim, projection_size_ba, bias=False)

        self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads))

        A = torch.empty(self.num_v_heads).uniform_(0, 16)
        self.A_log = nn.Parameter(torch.log(A))

        # The FLA fused norm accepts only these activations; anything else maps to silu.
        fla_compatible_activation = "silu" if self.activation not in ["swish", "silu", "sigmoid"] else self.activation

        if FusedRMSNormGated is None:
            self.norm = NeoLLMRMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon)
        else:
            # torch.cuda.current_device() raises on machines without CUDA, which
            # previously made construction fail whenever FLA was merely installed.
            norm_device = torch.cuda.current_device() if torch.cuda.is_available() else None
            self.norm = FusedRMSNormGated(
                self.head_v_dim,
                eps=self.layer_norm_epsilon,
                activation=fla_compatible_activation,
                device=norm_device,
                dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(),
            )

        self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)

        self.dropout = nn.Dropout(config.dropout_rate)

        # Kernel selection: fused kernels when imported, pure-PyTorch fallbacks otherwise.
        self.causal_conv1d_fn = causal_conv1d_fn
        self.causal_conv1d_update = causal_conv1d_update or torch_causal_conv1d_update
        self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule
        self.recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule or torch_recurrent_gated_delta_rule

        if not is_fast_path_available:
            logger.warning_once(
                "The fast path is not available because one of the required library is not installed. Falling back to "
                "torch implementation. To install follow https://github.com/fla-org/flash-linear-attention#installation and"
                " https://github.com/Dao-AILab/causal-conv1d"
            )

    def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba):
        """
        Split the fused projections into ``query, key, value, z, b, a``.

        The projections are viewed as grouped by key head: each key-head group
        holds its q, k, and the v/z slices for the
        ``num_v_heads // num_k_heads`` value heads it serves.

        Returns:
            ``query``/``key``: (batch, seq, num_k_heads, head_k_dim).
            ``value``/``z``: (batch, seq, num_v_heads, head_v_dim).
            ``b``/``a``: (batch, seq, num_v_heads).
        """
        new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + (
            self.num_k_heads,
            2 * self.head_k_dim + 2 * self.head_v_dim * self.num_v_heads // self.num_k_heads,
        )
        new_tensor_shape_ba = mixed_ba.size()[:-1] + (self.num_k_heads, 2 * self.num_v_heads // self.num_k_heads)

        mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz)
        mixed_ba = mixed_ba.view(*new_tensor_shape_ba)
        split_arg_list_qkvz = [
            self.head_k_dim,
            self.head_k_dim,
            (self.num_v_heads // self.num_k_heads * self.head_v_dim),
            (self.num_v_heads // self.num_k_heads * self.head_v_dim),
        ]
        split_arg_list_ba = [self.num_v_heads // self.num_k_heads, self.num_v_heads // self.num_k_heads]
        query, key, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=3)
        b, a = torch.split(mixed_ba, split_arg_list_ba, dim=3)

        value = value.reshape(value.size(0), value.size(1), -1, self.head_v_dim)
        z = z.reshape(z.size(0), z.size(1), -1, self.head_v_dim)
        b = b.reshape(b.size(0), b.size(1), self.num_v_heads)
        a = a.reshape(a.size(0), a.size(1), self.num_v_heads)
        return query, key, value, z, b, a

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            hidden_states: (batch, seq_len, hidden_size) input.
            attention_mask: Optional 2-D padding mask. Padded positions are zeroed
                before mixing; see apply_mask_to_padding_states.

        Returns:
            (batch, seq_len, hidden_size) tensor after ``out_proj`` and dropout.
        """
        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
        _, seq_len, _ = hidden_states.shape

        hidden_states_fan = self.fan_layer(hidden_states)

        projected_states_qkvz = self.in_proj_qkvz(hidden_states_fan)
        projected_states_ba = self.in_proj_ba(hidden_states_fan)
        query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba)
        query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value))

        # Channels-first layout for the depthwise conv: (batch, conv_dim, seq_len).
        mixed_qkv = torch.cat((query, key, value), dim=-1)
        mixed_qkv = mixed_qkv.transpose(1, 2)

        if self.causal_conv1d_fn is not None:
            mixed_qkv = self.causal_conv1d_fn(
                x=mixed_qkv,
                weight=self.conv1d.weight.squeeze(1),
                bias=self.conv1d.bias,
                activation="silu",
                seq_idx=None,
            )
        else:
            # Keep only the first seq_len outputs so the conv stays causal.
            mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])

        mixed_qkv = mixed_qkv.transpose(1, 2)
        query, key, value = torch.split(
            mixed_qkv,
            [
                self.key_dim,
                self.key_dim,
                self.value_dim,
            ],
            dim=-1,
        )
        query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim)
        key = key.reshape(key.shape[0], key.shape[1], -1, self.head_k_dim)
        value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim)

        beta = b.sigmoid()
        # Negative log-decay per value head.
        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
        if self.num_v_heads // self.num_k_heads > 1:
            # Share each key head across its group of value heads.
            query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
            key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)

        core_attn_out, _ = self.chunk_gated_delta_rule(
            query,
            key,
            value,
            g=g,
            beta=beta,
            initial_state=None,
            output_final_state=False,
            use_qk_l2norm_in_kernel=True,
        )

        # Gated norm operates on flattened (tokens * heads, head_v_dim) rows.
        z_shape_og = z.shape
        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
        z = z.reshape(-1, z.shape[-1])
        core_attn_out = self.norm(core_attn_out, z)
        core_attn_out = core_attn_out.reshape(z_shape_og)
        core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1)

        output = self.out_proj(core_attn_out)
        output = self.dropout(output)
        return output
|
|
class PolyNorm(torch.nn.Module):
    """
    PolyNorm activation:
    ``w0 * norm(x^3) + w1 * norm(x^2) + w2 * norm(x) + b``.

    ``norm`` is an RMS normalization over the last dimension without a learned
    gain. The weights start at 1/3 each and the bias at 0.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(3) / 3)
        self.bias = torch.nn.Parameter(torch.zeros(1))
        self.eps = eps

    def _norm(self, x):
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        result = self.weight[0] * self._norm(x**3)
        result = result + self.weight[1] * self._norm(x**2)
        result = result + self.weight[2] * self._norm(x)
        return result + self.bias
| |
class NeoLLMMLP(nn.Module):
    """
    Two-layer feed-forward block:
    ``linear2(dropout(PolyNorm(linear1(x))))``.

    Both projections are bias-free. ``linear1`` maps ``hidden_size`` to
    ``intermediate_size`` and ``linear2`` maps it back.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.linear1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.linear2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = PolyNorm()
        # Dropout on the intermediate activations.
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x):
        expanded = self.linear1(x)
        activated = self.dropout(self.act_fn(expanded))
        return self.linear2(activated)
|
|
|
|
class NeoLLMDecoderLayer(GradientCheckpointingLayer):
    """
    One pre-norm decoder block. The token mixer is selected per layer.

    For each sub-layer: ``out = GPAS(residual + sublayer(LNS(RMSNorm(residual))))``.
    The token mixer is a NeoLLMGatedDeltaNet for ``"linear_attention"`` layers
    and a NeoLLMAttention for ``"full_attention"`` layers, as set by
    ``config.layer_types``. The second sub-layer is the MLP.
    """

    def __init__(self, config: NeoLLMConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx

        self.layer_type = config.layer_types[layer_idx]
        if self.layer_type == "linear_attention":
            self.linear_attn = NeoLLMGatedDeltaNet(config, layer_idx)
        elif self.layer_type == "full_attention":
            self.self_attn = NeoLLMAttention(config, layer_idx)

        self.mlp = NeoLLMMLP(config)

        self.input_layernorm = NeoLLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = NeoLLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Depth-dependent 1/sqrt(layer) scaling applied after each norm.
        self.lns_attn = LNS(layer_idx)
        self.lns_mlp = LNS(layer_idx)

        # Activation scaling applied after each residual addition.
        self.gpas_attn = GPAS(config.hidden_size)
        self.gpas_mlp = GPAS(config.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.FloatTensor:
        # Token-mixing sub-layer.
        skip = hidden_states
        mixed = self.lns_attn(self.input_layernorm(hidden_states))
        if self.layer_type == "linear_attention":
            mixed = self.linear_attn(
                hidden_states=mixed,
                attention_mask=attention_mask,
            )
        elif self.layer_type == "full_attention":
            mixed, _ = self.self_attn(
                hidden_states=mixed,
                attention_mask=attention_mask,
                position_embeddings=position_embeddings,
                **kwargs,
            )
        hidden_states = self.gpas_attn(skip + mixed)

        # Feed-forward sub-layer.
        skip = hidden_states
        ffn_out = self.mlp(self.lns_mlp(self.post_attention_layernorm(hidden_states)))
        return self.gpas_mlp(skip + ffn_out)
|
|
|
|
class NeoLLMPreTrainedModel(PreTrainedModel):
    """Base class that plugs NeoLLM modules into the transformers ``PreTrainedModel`` machinery."""

    config: NeoLLMConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["NeoLLMDecoderLayer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _is_stateful = True

    def _init_weights(self, module):
        """
        Initialize ``module`` with the generic transformers rules, then apply
        NeoLLM-specific overrides:

        * GatedDeltaNet: ``dt_bias = 1`` and ``A_log = log(U(0, 16))``.
        * GPAS: ``alpha = 0``, so the module starts as an identity.
        * NeoLLMRMSNorm: ``weight = 0``. That norm scales by ``(1 + weight)``,
          while the generic initializer may reset norm weights to 1.0, which
          would silently double the effective gain.
        """
        super()._init_weights(module)
        if isinstance(module, NeoLLMGatedDeltaNet):
            module.dt_bias.data.fill_(1.0)
            module.A_log.data.uniform_(0, 16).log_()
        elif isinstance(module, GPAS):
            module.alpha.data.fill_(0.0)
        elif isinstance(module, NeoLLMRMSNorm):
            # Keep the 1-centred gain centred at 1: (1 + 0) == 1.
            module.weight.data.zero_()
|
|
|
|
class NeoLLMModel(NeoLLMPreTrainedModel):
    """
    NeoLLM decoder backbone: token embeddings, a stack of hybrid decoder layers,
    and a final RMSNorm. No KV cache is kept (``past_key_values`` is always None).
    """

    def __init__(self, config: NeoLLMConfig):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
        self.layers = nn.ModuleList(
            [NeoLLMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = NeoLLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = NeoLLMRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        """
        Run the decoder stack.

        Args:
            input_ids / inputs_embeds: Exactly one must be provided.
            attention_mask: Optional padding mask. It is turned into a causal
                mask for full-attention layers and passed as-is (or None) to
                linear-attention layers.
            position_ids: Optional RoPE positions. Defaults to ``0..seq_len-1``
                with a leading batch dimension of 1.

        Raises:
            ValueError: If both or neither of ``input_ids`` and ``inputs_embeds`` are given.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Without a KV cache the current tokens always occupy cache slots
        # 0..seq_len-1. This must not be derived from position_ids, which can be
        # batched (2-D after squeeze) or packed/offset and would corrupt the mask.
        cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=None,
            position_ids=position_ids,
        )
        linear_attn_mask = self._update_linear_attn_mask(attention_mask, cache_position)

        hidden_states = inputs_embeds

        # RoPE tables are computed once and shared by every full-attention layer.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask

            hidden_states = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=layer_mask,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=None,
        )

    def _update_linear_attn_mask(self, attention_mask, cache_position):
        """
        Return the padding mask to use for linear-attention layers.

        The mask is dropped (None) when every position is attended, so no
        padding states need zeroing. ``cache_position`` is accepted for API
        parity and unused. NOTE: assumes left-padding.
        """
        linear_attn_mask = attention_mask
        if attention_mask is not None and torch.all(attention_mask == 1):
            linear_attn_mask = None
        return linear_attn_mask
|
|
class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
    """
    NeoLLM backbone plus a bias-free LM head.

    When ``labels`` are given, the loss is computed with Cut Cross-Entropy
    directly from the hidden states and no logits are materialized.
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = NeoLLMModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    @torch.compiler.disable
    def _compute_cce_loss(self, hidden_states, labels):
        """
        Mean next-token cross-entropy via ``linear_cross_entropy`` (``shift=1``).

        The call is excluded from torch.compile. Labels equal to
        ``config.pad_token_id`` are replaced by -100 before the loss.
        NOTE(review): if pad_token_id equals the EOS id, EOS targets are ignored
        too — confirm this is intended.
        """
        targets = labels.to(hidden_states.device)

        pad_id = self.config.pad_token_id
        if pad_id is not None:
            targets = targets.masked_fill(targets == pad_id, -100)

        return linear_cross_entropy(
            hidden_states,
            self.lm_head.weight,
            targets,
            bias=getattr(self.lm_head, 'bias', None),
            shift=1,
            impl="cce",
            reduction="mean"
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. Indices should be in `[0, ...,
            config.vocab_size]` or -100. Tokens with index -100 (and, here, tokens equal to
            `config.pad_token_id`) are ignored. When labels are given, `logits` is None.
        logits_to_keep (`int` or `torch.Tensor`):
            An int keeps the last `logits_to_keep` positions (0 keeps all); a tensor is used
            directly as the index along the sequence dimension.
        """
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state

        loss = None
        logits = None
        if labels is not None:
            loss = self._compute_cce_loss(hidden_states, labels)
        else:
            if isinstance(logits_to_keep, int):
                kept = slice(-logits_to_keep, None)
            else:
                kept = logits_to_keep
            logits = self.lm_head(hidden_states[:, kept, :])

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
| |
|
|
| |
# Import-time registration so the "neollm" model type resolves through the
# transformers Auto* factories (AutoConfig / AutoModel / AutoModelForCausalLM).
AutoConfig.register("neollm", NeoLLMConfig)
AutoModel.register(NeoLLMConfig, NeoLLMModel)
AutoModelForCausalLM.register(NeoLLMConfig, NeoLLMForCausalLM)


# Public API of this module. NeoLLMConfig is re-exported from configuration_neollm.
__all__ = [
    "NeoLLMForCausalLM",
    "NeoLLMModel",
    "NeoLLMPreTrainedModel",
    "NeoLLMConfig",
    "FANLayer",
]
|
|
|
|