Update modeling_neollm.py
Browse files- modeling_neollm.py +268 -85
modeling_neollm.py
CHANGED
|
@@ -10,6 +10,7 @@ Updated to include:
|
|
| 10 |
- SeeDNorm: Dynamic normalization with input-dependent scaling for better adaptability
|
| 11 |
- Dropout regularization at strategic locations
|
| 12 |
- ResFormer: Feature residual connections from first layer (applied before projections)
|
|
|
|
| 13 |
"""
|
| 14 |
|
| 15 |
import math
|
|
@@ -26,7 +27,6 @@ from transformers.masking_utils import create_causal_mask
|
|
| 26 |
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 27 |
from transformers.modeling_layers import GradientCheckpointingLayer
|
| 28 |
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
| 29 |
-
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 30 |
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 31 |
from transformers.processing_utils import Unpack
|
| 32 |
from transformers.utils import TransformersKwargs, logging
|
|
@@ -238,67 +238,178 @@ class NeoLLMRMSNormGated(nn.Module):
|
|
| 238 |
return hidden_states.to(input_dtype)
|
| 239 |
|
| 240 |
|
| 241 |
-
class
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
super().__init__()
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
self.
|
| 253 |
-
|
| 254 |
-
self.config = config
|
| 255 |
-
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 256 |
-
|
| 257 |
-
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
| 258 |
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
|
| 303 |
|
| 304 |
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
|
@@ -323,10 +434,18 @@ def eager_attention_forward(
|
|
| 323 |
dropout: float = 0.0,
|
| 324 |
**kwargs: Unpack[TransformersKwargs],
|
| 325 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 327 |
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 328 |
|
|
|
|
| 329 |
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
|
|
|
| 330 |
if attention_mask is not None:
|
| 331 |
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 332 |
attn_weights = attn_weights + causal_mask
|
|
@@ -342,10 +461,14 @@ def eager_attention_forward(
|
|
| 342 |
class NeoLLMAttention(nn.Module):
|
| 343 |
"""
|
| 344 |
Multi-headed attention with FANformer integration, SeeDNorm for Q/K normalization,
|
| 345 |
-
and ResFormer feature residual connections
|
| 346 |
|
| 347 |
ResFormer enhancement: Applies learnable feature residual connections from the first layer
|
| 348 |
BEFORE QKV projections: H'_fan_n = λ_1 * H_fan_1 + λ_2 * H_fan_n
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
"""
|
| 350 |
|
| 351 |
def __init__(self, config: NeoLLMConfig, layer_idx: int):
|
|
@@ -353,7 +476,11 @@ class NeoLLMAttention(nn.Module):
|
|
| 353 |
self.config = config
|
| 354 |
self.layer_idx = layer_idx
|
| 355 |
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 356 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 357 |
self.scaling = self.head_dim**-0.5
|
| 358 |
self.attention_dropout = config.attention_dropout
|
| 359 |
self.is_causal = True
|
|
@@ -369,22 +496,34 @@ class NeoLLMAttention(nn.Module):
|
|
| 369 |
|
| 370 |
# QKV projections operate on FAN-transformed features
|
| 371 |
self.q_proj = nn.Linear(
|
| 372 |
-
fan_output_dim,
|
| 373 |
)
|
| 374 |
self.k_proj = nn.Linear(
|
| 375 |
-
fan_output_dim,
|
| 376 |
)
|
| 377 |
self.v_proj = nn.Linear(
|
| 378 |
-
fan_output_dim,
|
| 379 |
)
|
| 380 |
self.o_proj = nn.Linear(
|
| 381 |
-
|
| 382 |
)
|
| 383 |
|
| 384 |
# SeeDNorm for Q/K normalization (replaces RMSNorm)
|
| 385 |
self.q_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
|
| 386 |
self.k_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
|
| 387 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 388 |
# Dropout for attention output
|
| 389 |
self.dropout = nn.Dropout(config.dropout_rate)
|
| 390 |
|
|
@@ -401,34 +540,61 @@ class NeoLLMAttention(nn.Module):
|
|
| 401 |
**kwargs: Unpack[FlashAttentionKwargs],
|
| 402 |
) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
|
| 403 |
input_shape = hidden_states.shape[:-1]
|
|
|
|
| 404 |
|
| 405 |
# Apply FANformer transformation first
|
| 406 |
hidden_states_fan = self.fan_layer(hidden_states)
|
| 407 |
|
| 408 |
# ResFormer: Apply feature residual connection BEFORE projections
|
| 409 |
-
# This ensures dimensional compatibility across all layer types
|
| 410 |
if first_layer_fan is not None:
|
| 411 |
hidden_states_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * hidden_states_fan
|
| 412 |
|
| 413 |
# Store current FAN features for potential use as first_layer_fan in subsequent layers
|
| 414 |
current_layer_fan = hidden_states_fan.clone()
|
| 415 |
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
# Use FAN-transformed features (with residual applied) for projections
|
| 419 |
query_states, gate = torch.chunk(
|
| 420 |
-
self.q_proj(hidden_states_fan).view(
|
|
|
|
| 421 |
)
|
| 422 |
-
gate = gate.reshape(
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
value_states = self.v_proj(hidden_states_fan).view(
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 432 |
attention_interface: Callable = eager_attention_forward
|
| 433 |
if self.config._attn_implementation != "eager":
|
| 434 |
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
|
@@ -444,7 +610,10 @@ class NeoLLMAttention(nn.Module):
|
|
| 444 |
**kwargs,
|
| 445 |
)
|
| 446 |
|
| 447 |
-
|
|
|
|
|
|
|
|
|
|
| 448 |
attn_output = attn_output * torch.sigmoid(gate)
|
| 449 |
|
| 450 |
attn_output = self.o_proj(attn_output)
|
|
@@ -998,6 +1167,7 @@ class NeoLLMPreTrainedModel(PreTrainedModel):
|
|
| 998 |
module.lambda_1.data.fill_(0.5)
|
| 999 |
if hasattr(module, 'lambda_2'):
|
| 1000 |
module.lambda_2.data.fill_(0.5)
|
|
|
|
| 1001 |
elif isinstance(module, GPAS):
|
| 1002 |
# Initialize GPAS alpha to 0 as per paper
|
| 1003 |
module.alpha.data.fill_(0.0)
|
|
@@ -1010,6 +1180,9 @@ class NeoLLMPreTrainedModel(PreTrainedModel):
|
|
| 1010 |
# beta (β) initialized to 0 (default in Parameter definition)
|
| 1011 |
# alpha (α) initialized to 1 (default in Parameter definition)
|
| 1012 |
pass
|
|
|
|
|
|
|
|
|
|
| 1013 |
|
| 1014 |
|
| 1015 |
class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
@@ -1023,7 +1196,15 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 1023 |
)
|
| 1024 |
# SeeDNorm for final output normalization (replaces RMSNorm)
|
| 1025 |
self.norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 1026 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1027 |
self.gradient_checkpointing = False
|
| 1028 |
|
| 1029 |
# ResFormer: storage for first layer's FAN features (H_fan_1)
|
|
@@ -1061,8 +1242,9 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 1061 |
|
| 1062 |
hidden_states = inputs_embeds
|
| 1063 |
|
| 1064 |
-
#
|
| 1065 |
-
position_embeddings
|
|
|
|
| 1066 |
|
| 1067 |
# ResFormer: reset first_layer_fan at the start of each forward pass
|
| 1068 |
self.first_layer_fan = None
|
|
@@ -1193,6 +1375,7 @@ __all__ = [
|
|
| 1193 |
"NeoLLMConfig",
|
| 1194 |
"FANLayer",
|
| 1195 |
"SeeDNorm",
|
|
|
|
| 1196 |
]
|
| 1197 |
|
| 1198 |
# Register the configuration and model for AutoClass support
|
|
|
|
| 10 |
- SeeDNorm: Dynamic normalization with input-dependent scaling for better adaptability
|
| 11 |
- Dropout regularization at strategic locations
|
| 12 |
- ResFormer: Feature residual connections from first layer (applied before projections)
|
| 13 |
+
- PoPE (Polar Coordinate Position Embedding): Decouples 'what' and 'where' for superior length extrapolation
|
| 14 |
"""
|
| 15 |
|
| 16 |
import math
|
|
|
|
| 27 |
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 28 |
from transformers.modeling_layers import GradientCheckpointingLayer
|
| 29 |
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
|
|
|
| 30 |
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 31 |
from transformers.processing_utils import Unpack
|
| 32 |
from transformers.utils import TransformersKwargs, logging
|
|
|
|
| 238 |
return hidden_states.to(input_dtype)
|
| 239 |
|
| 240 |
|
| 241 |
+
class PolarPositionalEmbedding(nn.Module):
    """Polar Coordinate Position Embedding (PoPE) — FlashAttention2-compatible.

    From "Decoupling the 'What' and 'Where' with Polar Coordinate Positional
    Embedding". Content and position are decoupled:

    - Magnitudes carry content only: mu_q = softplus(q), mu_k = softplus(k)
    - Phases carry position only:    phi_c = t * theta_c

    Instead of complex arithmetic, queries and keys are mapped to concatenated
    Cartesian coordinates [mu*cos(phi); mu*sin(phi)] of size 2*d, so that the
    ordinary dot product Q'.K' = sum(x_q*x_k + y_q*y_k) reproduces the PoPE
    attention score Re[q~^H k~] exactly. This doubles head_dim (~2x attention
    cost) but lets any stock attention kernel (incl. FlashAttention2) be used
    without custom kernels, while retaining PoPE's length-extrapolation
    benefits.

    Args:
        dim: Original per-head dimension d (outputs use 2*d).
        max_position_embeddings: Maximum sequence length (stored for reference).
        base: Base wavelength (theta) for the frequency ladder.
        device: Optional device on which to build the frequency buffer.
    """

    def __init__(
        self,
        dim: int,
        max_position_embeddings: int = 2048,
        base: float = 10000.0,
        device=None,
    ):
        super().__init__()
        self.dim = dim  # original head_dim (d)
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # theta_c = base^(-(c-1)/d) for c = 1..d; PoPE uses d frequencies
        # (not d/2 like RoPE). Fix: honor the `device` argument, which was
        # previously accepted but ignored.
        inv_freq = 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 1, dtype=torch.float32, device=device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: torch.LongTensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Map q/k into concatenated [real; imag] PoPE coordinates.

        Args:
            q: Query tensor of shape (batch, num_heads, seq_len, head_dim).
            k: Key tensor of shape (batch, num_kv_heads, seq_len, head_dim).
            position_ids: Position indices of shape (batch, seq_len).

        Returns:
            Tuple (Q', K') with the last dimension doubled to 2*head_dim,
            laid out as [x; y] = [mu*cos(phi); mu*sin(phi)].
        """
        # Step 1: magnitudes from content only (paper Eq. 3).
        mu_q = F.softplus(q)
        mu_k = F.softplus(k)

        # Step 2: phases from position only (paper Eq. 4).
        # freqs broadcasts as (batch, 1, seq_len, head_dim).
        inv_freq_expanded = self.inv_freq[None, None, None, :].to(q.device)
        position_ids_expanded = position_ids[:, None, :, None].float()
        freqs = position_ids_expanded * inv_freq_expanded

        # Step 3: Cartesian coordinates x = mu*cos(phi), y = mu*sin(phi)
        # (paper Eqs. 7-8).
        cos_freqs = torch.cos(freqs)
        sin_freqs = torch.sin(freqs)

        q_real = mu_q * cos_freqs  # x_q component
        q_imag = mu_q * sin_freqs  # y_q component
        k_real = mu_k * cos_freqs  # x_k component
        k_imag = mu_k * sin_freqs  # y_k component

        # Step 4: concatenate [real; imag] so a plain dot product computes
        # sum(x_q*x_k + y_q*y_k).
        q_pope = torch.cat([q_real, q_imag], dim=-1)  # (batch, num_heads, seq_len, 2*head_dim)
        k_pope = torch.cat([k_real, k_imag], dim=-1)  # (batch, num_kv_heads, seq_len, 2*head_dim)

        return q_pope, k_pope
|
| 342 |
+
|
| 343 |
+
def apply_pope_embedding(
    q_pope: torch.Tensor,
    k_pope: torch.Tensor,
    delta_bias: Optional[torch.Tensor] = None,
    num_key_value_groups: int = 1
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply the learnable phase bias delta_c to PoPE keys (paper Eq. 6).

    With the bias, the attention score becomes
    a^PoPE_ts = sum mu_q mu_k cos((s - t)*theta_c + delta_c),
    realized here by rotating the key by exp(i*delta) in the concatenated
    [real; imag] representation. Queries are never modified.

    Args:
        q_pope: Query after PoPE, shape (batch, num_heads, seq_len, 2*head_dim),
            laid out as [x_q; y_q].
        k_pope: Key after PoPE, shape (batch, num_kv_heads, seq_len, 2*head_dim),
            laid out as [x_k; y_k].
        delta_bias: Per-head/per-dim phase bias of shape
            (num_attention_heads, head_dim), clamped to [-2*pi, 0] per the paper.
        num_key_value_groups: Query heads per kv head (GQA); group biases are
            averaged down to one bias per kv head.

    Returns:
        (q_out, k_out): query unchanged, key rotated by the bias; both keep the
        2*head_dim layout.
    """
    if delta_bias is None:
        # No bias configured: both tensors pass through untouched.
        return q_pope, k_pope

    half = k_pope.shape[-1] // 2  # original head_dim (d)
    real_part = k_pope[..., :half]
    imag_part = k_pope[..., half:]

    # Paper Section 3 bounds the bias to [-2*pi, 0].
    bias = delta_bias.clamp(min=-2 * math.pi, max=0)

    # GQA: collapse per-attention-head biases to one per kv head by averaging
    # across each group of query heads.
    if num_key_value_groups > 1:
        kv_heads = bias.shape[0] // num_key_value_groups
        bias = bias.reshape(kv_heads, num_key_value_groups, half).mean(dim=1)

    # Broadcast over batch and sequence: (1, num_kv_heads, 1, head_dim).
    bias = bias[None, :, None, :]

    # Rotate k by exp(i*delta): (a + bi)(cos + i*sin)
    #   real: a*cos - b*sin;  imag: a*sin + b*cos
    cos_b = torch.cos(bias)
    sin_b = torch.sin(bias)
    k_out = torch.cat(
        [
            real_part * cos_b - imag_part * sin_b,
            real_part * sin_b + imag_part * cos_b,
        ],
        dim=-1,
    )

    # The phase bias acts on keys only; queries are returned as-is.
    return q_pope, k_out
|
| 413 |
|
| 414 |
|
| 415 |
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
|
|
|
| 434 |
dropout: float = 0.0,
|
| 435 |
**kwargs: Unpack[TransformersKwargs],
|
| 436 |
):
|
| 437 |
+
"""
|
| 438 |
+
Standard eager attention implementation for PoPE.
|
| 439 |
+
|
| 440 |
+
Note: query and key have 2*head_dim due to PoPE concatenation [real; imag].
|
| 441 |
+
Value is padded to match this dimension for kernel compatibility.
|
| 442 |
+
"""
|
| 443 |
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 444 |
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 445 |
|
| 446 |
+
# Standard attention computation
|
| 447 |
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 448 |
+
|
| 449 |
if attention_mask is not None:
|
| 450 |
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 451 |
attn_weights = attn_weights + causal_mask
|
|
|
|
| 461 |
class NeoLLMAttention(nn.Module):
|
| 462 |
"""
|
| 463 |
Multi-headed attention with FANformer integration, SeeDNorm for Q/K normalization,
|
| 464 |
+
PoPE for positional encoding, and ResFormer feature residual connections.
|
| 465 |
|
| 466 |
ResFormer enhancement: Applies learnable feature residual connections from the first layer
|
| 467 |
BEFORE QKV projections: H'_fan_n = λ_1 * H_fan_1 + λ_2 * H_fan_n
|
| 468 |
+
|
| 469 |
+
PoPE enhancement: Decouples 'what' and 'where' via polar coordinates for superior
|
| 470 |
+
length extrapolation and content/position independent matching. Uses concatenated
|
| 471 |
+
[real; imag] representation for FlashAttention2 compatibility (2× head_dim overhead).
|
| 472 |
"""
|
| 473 |
|
| 474 |
def __init__(self, config: NeoLLMConfig, layer_idx: int):
|
|
|
|
| 476 |
self.config = config
|
| 477 |
self.layer_idx = layer_idx
|
| 478 |
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 479 |
+
self.num_attention_heads = config.num_attention_heads
|
| 480 |
+
self.num_key_value_heads = config.num_key_value_heads
|
| 481 |
+
self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
|
| 482 |
+
|
| 483 |
+
# PoPE uses original head_dim for scaling (not 2*head_dim)
|
| 484 |
self.scaling = self.head_dim**-0.5
|
| 485 |
self.attention_dropout = config.attention_dropout
|
| 486 |
self.is_causal = True
|
|
|
|
| 496 |
|
| 497 |
# QKV projections operate on FAN-transformed features
|
| 498 |
self.q_proj = nn.Linear(
|
| 499 |
+
fan_output_dim, self.num_attention_heads * self.head_dim * 2, bias=config.attention_bias
|
| 500 |
)
|
| 501 |
self.k_proj = nn.Linear(
|
| 502 |
+
fan_output_dim, self.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
| 503 |
)
|
| 504 |
self.v_proj = nn.Linear(
|
| 505 |
+
fan_output_dim, self.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
| 506 |
)
|
| 507 |
self.o_proj = nn.Linear(
|
| 508 |
+
self.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
| 509 |
)
|
| 510 |
|
| 511 |
# SeeDNorm for Q/K normalization (replaces RMSNorm)
|
| 512 |
self.q_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
|
| 513 |
self.k_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
|
| 514 |
|
| 515 |
+
# PoPE: Learnable phase bias δc for each head and dimension
|
| 516 |
+
# Initialized based on pope_bias_init config: 'zero' or 'uniform'
|
| 517 |
+
pope_bias_init = getattr(config, 'pope_bias_init', 'zero')
|
| 518 |
+
if pope_bias_init == 'uniform':
|
| 519 |
+
# Uniform initialization in [-2π, 0]
|
| 520 |
+
delta_init = torch.empty(self.num_attention_heads, self.head_dim).uniform_(-2 * math.pi, 0)
|
| 521 |
+
else:
|
| 522 |
+
# Zero initialization (better for length extrapolation)
|
| 523 |
+
delta_init = torch.zeros(self.num_attention_heads, self.head_dim)
|
| 524 |
+
|
| 525 |
+
self.delta_bias = nn.Parameter(delta_init)
|
| 526 |
+
|
| 527 |
# Dropout for attention output
|
| 528 |
self.dropout = nn.Dropout(config.dropout_rate)
|
| 529 |
|
|
|
|
| 540 |
**kwargs: Unpack[FlashAttentionKwargs],
|
| 541 |
) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
|
| 542 |
input_shape = hidden_states.shape[:-1]
|
| 543 |
+
batch_size, seq_len = input_shape
|
| 544 |
|
| 545 |
# Apply FANformer transformation first
|
| 546 |
hidden_states_fan = self.fan_layer(hidden_states)
|
| 547 |
|
| 548 |
# ResFormer: Apply feature residual connection BEFORE projections
|
|
|
|
| 549 |
if first_layer_fan is not None:
|
| 550 |
hidden_states_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * hidden_states_fan
|
| 551 |
|
| 552 |
# Store current FAN features for potential use as first_layer_fan in subsequent layers
|
| 553 |
current_layer_fan = hidden_states_fan.clone()
|
| 554 |
|
| 555 |
+
# Project to Q, K, V
|
|
|
|
|
|
|
| 556 |
query_states, gate = torch.chunk(
|
| 557 |
+
self.q_proj(hidden_states_fan).view(batch_size, seq_len, self.num_attention_heads, self.head_dim * 2),
|
| 558 |
+
2, dim=-1
|
| 559 |
)
|
| 560 |
+
gate = gate.reshape(batch_size, seq_len, -1)
|
| 561 |
+
|
| 562 |
+
key_states = self.k_proj(hidden_states_fan).view(
|
| 563 |
+
batch_size, seq_len, self.num_key_value_heads, self.head_dim
|
| 564 |
+
)
|
| 565 |
+
value_states = self.v_proj(hidden_states_fan).view(
|
| 566 |
+
batch_size, seq_len, self.num_key_value_heads, self.head_dim
|
| 567 |
+
)
|
| 568 |
+
|
| 569 |
+
# Apply SeeDNorm to Q and K before PoPE
|
| 570 |
+
query_states = self.q_norm(query_states)
|
| 571 |
+
key_states = self.k_norm(key_states)
|
| 572 |
+
|
| 573 |
+
# Transpose to (batch, num_heads, seq_len, head_dim)
|
| 574 |
+
query_states = query_states.transpose(1, 2)
|
| 575 |
+
key_states = key_states.transpose(1, 2)
|
| 576 |
+
value_states = value_states.transpose(1, 2)
|
| 577 |
+
|
| 578 |
+
# Apply PoPE: position_embeddings is (pope_emb, position_ids)
|
| 579 |
+
pope_emb, position_ids = position_embeddings
|
| 580 |
+
|
| 581 |
+
# Get PoPE embeddings with concatenated [real; imag] representation
|
| 582 |
+
# Returns Q', K' with shape (..., 2*head_dim)
|
| 583 |
+
query_states, key_states = pope_emb(query_states, key_states, position_ids)
|
| 584 |
+
|
| 585 |
+
# Apply learnable phase bias δc
|
| 586 |
+
# (rotates keys by exp(i*δ); queries pass through unchanged)
|
| 587 |
+
query_states, key_states = apply_pope_embedding(
|
| 588 |
+
query_states,
|
| 589 |
+
key_states,
|
| 590 |
+
self.delta_bias,
|
| 591 |
+
num_key_value_groups=self.num_key_value_groups  # adapt per-head bias for grouped-query attention
|
| 592 |
+
)
|
| 593 |
+
# Pad value to 2*head_dim for dimension compatibility
|
| 594 |
+
# Only first head_dim components are used in output
|
| 595 |
+
value_states = F.pad(value_states, (0, self.head_dim), value=0.0)
|
| 596 |
+
|
| 597 |
+
# Call attention with doubled head_dim
|
| 598 |
attention_interface: Callable = eager_attention_forward
|
| 599 |
if self.config._attn_implementation != "eager":
|
| 600 |
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
|
|
|
| 610 |
**kwargs,
|
| 611 |
)
|
| 612 |
|
| 613 |
+
# Extract only the first head_dim components (discard padding)
|
| 614 |
+
attn_output = attn_output[..., :self.head_dim]
|
| 615 |
+
|
| 616 |
+
attn_output = attn_output.reshape(batch_size, seq_len, -1).contiguous()
|
| 617 |
attn_output = attn_output * torch.sigmoid(gate)
|
| 618 |
|
| 619 |
attn_output = self.o_proj(attn_output)
|
|
|
|
| 1167 |
module.lambda_1.data.fill_(0.5)
|
| 1168 |
if hasattr(module, 'lambda_2'):
|
| 1169 |
module.lambda_2.data.fill_(0.5)
|
| 1170 |
+
# PoPE delta_bias already initialized in __init__
|
| 1171 |
elif isinstance(module, GPAS):
|
| 1172 |
# Initialize GPAS alpha to 0 as per paper
|
| 1173 |
module.alpha.data.fill_(0.0)
|
|
|
|
| 1180 |
# beta (β) initialized to 0 (default in Parameter definition)
|
| 1181 |
# alpha (α) initialized to 1 (default in Parameter definition)
|
| 1182 |
pass
|
| 1183 |
+
elif isinstance(module, PolarPositionalEmbedding):
|
| 1184 |
+
# PoPE frequency initialization handled in __init__
|
| 1185 |
+
pass
|
| 1186 |
|
| 1187 |
|
| 1188 |
class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
|
|
| 1196 |
)
|
| 1197 |
# SeeDNorm for final output normalization (replaces RMSNorm)
|
| 1198 |
self.norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 1199 |
+
|
| 1200 |
+
# PoPE positional embedding (replaces RoPE)
|
| 1201 |
+
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 1202 |
+
self.pope_emb = PolarPositionalEmbedding(
|
| 1203 |
+
dim=head_dim,
|
| 1204 |
+
max_position_embeddings=config.max_position_embeddings,
|
| 1205 |
+
base=getattr(config, 'rope_theta', 10000.0), # Use rope_theta for backward compatibility
|
| 1206 |
+
)
|
| 1207 |
+
|
| 1208 |
self.gradient_checkpointing = False
|
| 1209 |
|
| 1210 |
# ResFormer: storage for first layer's FAN features (H_fan_1)
|
|
|
|
| 1242 |
|
| 1243 |
hidden_states = inputs_embeds
|
| 1244 |
|
| 1245 |
+
# Create position embeddings for PoPE
|
| 1246 |
+
# position_embeddings is a tuple of (pope_emb, position_ids)
|
| 1247 |
+
position_embeddings = (self.pope_emb, position_ids)
|
| 1248 |
|
| 1249 |
# ResFormer: reset first_layer_fan at the start of each forward pass
|
| 1250 |
self.first_layer_fan = None
|
|
|
|
| 1375 |
"NeoLLMConfig",
|
| 1376 |
"FANLayer",
|
| 1377 |
"SeeDNorm",
|
| 1378 |
+
"PolarPositionalEmbedding",
|
| 1379 |
]
|
| 1380 |
|
| 1381 |
# Register the configuration and model for AutoClass support
|