KitsuVp committed on
Commit
4879fe2
·
verified ·
1 Parent(s): 2fa9200

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +268 -85
modeling_neollm.py CHANGED
@@ -10,6 +10,7 @@ Updated to include:
10
  - SeeDNorm: Dynamic normalization with input-dependent scaling for better adaptability
11
  - Dropout regularization at strategic locations
12
  - ResFormer: Feature residual connections from first layer (applied before projections)
 
13
  """
14
 
15
  import math
@@ -26,7 +27,6 @@ from transformers.masking_utils import create_causal_mask
26
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
27
  from transformers.modeling_layers import GradientCheckpointingLayer
28
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
29
- from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
30
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
31
  from transformers.processing_utils import Unpack
32
  from transformers.utils import TransformersKwargs, logging
@@ -238,67 +238,178 @@ class NeoLLMRMSNormGated(nn.Module):
238
  return hidden_states.to(input_dtype)
239
 
240
 
241
- class NeoLLMRotaryEmbedding(nn.Module):
242
- inv_freq: torch.Tensor # fix linting for `register_buffer`
243
-
244
- def __init__(self, config: NeoLLMConfig, device=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
  super().__init__()
246
- # BC: "rope_type" was originally "type"
247
- if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
248
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
249
- else:
250
- self.rope_type = "default"
251
- self.max_seq_len_cached = config.max_position_embeddings
252
- self.original_max_seq_len = config.max_position_embeddings
253
-
254
- self.config = config
255
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
256
-
257
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
258
  self.register_buffer("inv_freq", inv_freq, persistent=False)
259
- self.original_inv_freq = self.inv_freq
260
-
261
- @torch.no_grad()
262
- @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
263
- def forward(self, x, position_ids):
264
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
265
- position_ids_expanded = position_ids[:, None, :].float()
266
-
267
- device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
268
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
269
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
270
- emb = torch.cat((freqs, freqs), dim=-1)
271
- cos = emb.cos() * self.attention_scaling
272
- sin = emb.sin() * self.attention_scaling
273
-
274
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
275
-
276
-
277
- def rotate_half(x):
278
- """Rotates half the hidden dims of the input."""
279
- x1 = x[..., : x.shape[-1] // 2]
280
- x2 = x[..., x.shape[-1] // 2 :]
281
- return torch.cat((-x2, x1), dim=-1)
282
-
283
-
284
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
285
- """Applies Rotary Position Embedding to the query and key tensors."""
286
- cos = cos.unsqueeze(unsqueeze_dim)
287
- sin = sin.unsqueeze(unsqueeze_dim)
288
-
289
- # Keep half or full tensor for later concatenation
290
- rotary_dim = cos.shape[-1]
291
- q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
292
- k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
293
-
294
- # Apply rotary embeddings on the first half or full tensor
295
- q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
296
- k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
297
-
298
- # Concatenate back to full shape
299
- q_embed = torch.cat([q_embed, q_pass], dim=-1)
300
- k_embed = torch.cat([k_embed, k_pass], dim=-1)
301
- return q_embed, k_embed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
302
 
303
 
304
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -323,10 +434,18 @@ def eager_attention_forward(
323
  dropout: float = 0.0,
324
  **kwargs: Unpack[TransformersKwargs],
325
  ):
 
 
 
 
 
 
326
  key_states = repeat_kv(key, module.num_key_value_groups)
327
  value_states = repeat_kv(value, module.num_key_value_groups)
328
 
 
329
  attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
 
330
  if attention_mask is not None:
331
  causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
332
  attn_weights = attn_weights + causal_mask
@@ -342,10 +461,14 @@ def eager_attention_forward(
342
  class NeoLLMAttention(nn.Module):
343
  """
344
  Multi-headed attention with FANformer integration, SeeDNorm for Q/K normalization,
345
- and ResFormer feature residual connections for enhanced information flow.
346
 
347
  ResFormer enhancement: Applies learnable feature residual connections from the first layer
348
  BEFORE QKV projections: H'_fan_n = λ_1 * H_fan_1 + λ_2 * H_fan_n
 
 
 
 
349
  """
350
 
351
  def __init__(self, config: NeoLLMConfig, layer_idx: int):
@@ -353,7 +476,11 @@ class NeoLLMAttention(nn.Module):
353
  self.config = config
354
  self.layer_idx = layer_idx
355
  self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
356
- self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
 
 
 
 
357
  self.scaling = self.head_dim**-0.5
358
  self.attention_dropout = config.attention_dropout
359
  self.is_causal = True
@@ -369,22 +496,34 @@ class NeoLLMAttention(nn.Module):
369
 
370
  # QKV projections operate on FAN-transformed features
371
  self.q_proj = nn.Linear(
372
- fan_output_dim, config.num_attention_heads * self.head_dim * 2, bias=config.attention_bias
373
  )
374
  self.k_proj = nn.Linear(
375
- fan_output_dim, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
376
  )
377
  self.v_proj = nn.Linear(
378
- fan_output_dim, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
379
  )
380
  self.o_proj = nn.Linear(
381
- config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
382
  )
383
 
384
  # SeeDNorm for Q/K normalization (replaces RMSNorm)
385
  self.q_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
386
  self.k_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
387
 
 
 
 
 
 
 
 
 
 
 
 
 
388
  # Dropout for attention output
389
  self.dropout = nn.Dropout(config.dropout_rate)
390
 
@@ -401,34 +540,61 @@ class NeoLLMAttention(nn.Module):
401
  **kwargs: Unpack[FlashAttentionKwargs],
402
  ) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
403
  input_shape = hidden_states.shape[:-1]
 
404
 
405
  # Apply FANformer transformation first
406
  hidden_states_fan = self.fan_layer(hidden_states)
407
 
408
  # ResFormer: Apply feature residual connection BEFORE projections
409
- # This ensures dimensional compatibility across all layer types
410
  if first_layer_fan is not None:
411
  hidden_states_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * hidden_states_fan
412
 
413
  # Store current FAN features for potential use as first_layer_fan in subsequent layers
414
  current_layer_fan = hidden_states_fan.clone()
415
 
416
- hidden_shape = (*input_shape, -1, self.head_dim)
417
-
418
- # Use FAN-transformed features (with residual applied) for projections
419
  query_states, gate = torch.chunk(
420
- self.q_proj(hidden_states_fan).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
 
421
  )
422
- gate = gate.reshape(*input_shape, -1)
423
-
424
- # Apply SeeDNorm to Q and K
425
- query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
426
- key_states = self.k_norm(self.k_proj(hidden_states_fan).view(hidden_shape)).transpose(1, 2)
427
- value_states = self.v_proj(hidden_states_fan).view(hidden_shape).transpose(1, 2)
428
-
429
- cos, sin = position_embeddings
430
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
431
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
432
  attention_interface: Callable = eager_attention_forward
433
  if self.config._attn_implementation != "eager":
434
  attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
@@ -444,7 +610,10 @@ class NeoLLMAttention(nn.Module):
444
  **kwargs,
445
  )
446
 
447
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
 
 
 
448
  attn_output = attn_output * torch.sigmoid(gate)
449
 
450
  attn_output = self.o_proj(attn_output)
@@ -998,6 +1167,7 @@ class NeoLLMPreTrainedModel(PreTrainedModel):
998
  module.lambda_1.data.fill_(0.5)
999
  if hasattr(module, 'lambda_2'):
1000
  module.lambda_2.data.fill_(0.5)
 
1001
  elif isinstance(module, GPAS):
1002
  # Initialize GPAS alpha to 0 as per paper
1003
  module.alpha.data.fill_(0.0)
@@ -1010,6 +1180,9 @@ class NeoLLMPreTrainedModel(PreTrainedModel):
1010
  # beta (β) initialized to 0 (default in Parameter definition)
1011
  # alpha (α) initialized to 1 (default in Parameter definition)
1012
  pass
 
 
 
1013
 
1014
 
1015
  class NeoLLMModel(NeoLLMPreTrainedModel):
@@ -1023,7 +1196,15 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
1023
  )
1024
  # SeeDNorm for final output normalization (replaces RMSNorm)
1025
  self.norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
1026
- self.rotary_emb = NeoLLMRotaryEmbedding(config=config)
 
 
 
 
 
 
 
 
1027
  self.gradient_checkpointing = False
1028
 
1029
  # ResFormer: storage for first layer's FAN features (H_fan_1)
@@ -1061,8 +1242,9 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
1061
 
1062
  hidden_states = inputs_embeds
1063
 
1064
- # create position embeddings to be shared across the decoder layers
1065
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
1066
 
1067
  # ResFormer: reset first_layer_fan at the start of each forward pass
1068
  self.first_layer_fan = None
@@ -1193,6 +1375,7 @@ __all__ = [
1193
  "NeoLLMConfig",
1194
  "FANLayer",
1195
  "SeeDNorm",
 
1196
  ]
1197
 
1198
  # Register the configuration and model for AutoClass support
 
10
  - SeeDNorm: Dynamic normalization with input-dependent scaling for better adaptability
11
  - Dropout regularization at strategic locations
12
  - ResFormer: Feature residual connections from first layer (applied before projections)
13
+ - PoPE (Polar Coordinate Position Embedding): Decouples 'what' and 'where' for superior length extrapolation
14
  """
15
 
16
  import math
 
27
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
28
  from transformers.modeling_layers import GradientCheckpointingLayer
29
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 
30
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
31
  from transformers.processing_utils import Unpack
32
  from transformers.utils import TransformersKwargs, logging
 
238
  return hidden_states.to(input_dtype)
239
 
240
 
241
class PolarPositionalEmbedding(nn.Module):
    """
    Polar Coordinate Position Embedding (PoPE) - FlashAttention2-compatible implementation.

    From "Decoupling the 'What' and 'Where' with Polar Coordinate Positional Embedding":

    THEORETICAL FORMULATION (from paper):
    - Magnitudes: mu_q_tc = softplus(q_tc), mu_k_sc = softplus(k_sc)  (content only)
    - Phases:     phi_q_tc = t*theta_c,     phi_k_sc = s*theta_c      (position only)
    - Attention score: a^PoPE_ts = Re[q~^H @ k~] = sum(x_q * x_k + y_q * y_k)

    where x = mu*cos(phi), y = mu*sin(phi) are Cartesian coordinates.

    PRACTICAL IMPLEMENTATION (this code):
    To enable FlashAttention2 compatibility without custom kernels, we use the
    mathematically equivalent concatenated form

        Q' = [x_q; y_q] in R^(2d),  K' = [x_k; y_k] in R^(2d)

    This doubles head_dim (d -> 2d) but allows standard attention kernels, since
    Q'.K' = sum(x_q*x_k + y_q*y_k) = a^PoPE_ts exactly (~2x attention overhead).

    Benefits retained: superior length extrapolation without fine-tuning, decoupled
    'what'/'where' information, better content/position independent matching.

    Args:
        dim: Original dimension per attention head (doubled to 2d internally).
        max_position_embeddings: Maximum sequence length (kept for config parity;
            not used in the computation, which is position-id driven).
        base: Base wavelength (theta) for frequency components.
        device: Optional device for the frequency buffer.
    """

    def __init__(
        self,
        dim: int,
        max_position_embeddings: int = 2048,
        base: float = 10000.0,
        device=None,
    ):
        super().__init__()
        self.dim = dim  # Original head_dim (d)
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # Frequency components: theta_c = base^(-(c-1)/d) for c = 1, ..., d.
        # PoPE uses d frequencies (not d/2 like RoPE).
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, dtype=torch.float32, device=device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: torch.LongTensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Apply PoPE with the concatenated [real; imag] representation.

        Args:
            q: Query tensor of shape (batch, num_heads, seq_len, head_dim)
            k: Key tensor of shape (batch, num_kv_heads, seq_len, head_dim)
            position_ids: Position indices of shape (batch, seq_len)

        Returns:
            Tuple (Q', K') with doubled last dimension:
            - Q': (batch, num_heads, seq_len, 2*head_dim) = [x_q; y_q]
            - K': (batch, num_kv_heads, seq_len, 2*head_dim) = [x_k; y_k]
            Both are returned in q's dtype so downstream value tensors match.
        """
        # Step 1: magnitudes from content only (Eq. 3): mu = softplus(.)
        mu_q = F.softplus(q)
        mu_k = F.softplus(k)

        # Step 2: phase angles from position only (Eq. 4): phi = t * theta_c.
        # Computed in float32 for numerical stability at large positions;
        # freqs broadcasts to (batch, 1, seq_len, head_dim).
        inv_freq_expanded = self.inv_freq[None, None, None, :].to(q.device)
        position_ids_expanded = position_ids[:, None, :, None].float()
        freqs = position_ids_expanded * inv_freq_expanded

        # Step 3: Cartesian coordinates (Eqs. 7-8): x = mu*cos(phi), y = mu*sin(phi).
        # Cast cos/sin back to q.dtype so fp16/bf16 inputs are NOT silently
        # promoted to float32 (which would break dtype parity with value states).
        cos_freqs = torch.cos(freqs).to(q.dtype)
        sin_freqs = torch.sin(freqs).to(q.dtype)

        # Step 4: concatenate [real; imag] so a plain dot product realizes
        # Q'.K' = sum(x_q*x_k + y_q*y_k).
        q_pope = torch.cat([mu_q * cos_freqs, mu_q * sin_freqs], dim=-1)
        k_pope = torch.cat([mu_k * cos_freqs, mu_k * sin_freqs], dim=-1)

        return q_pope, k_pope
342
+
343
def apply_pope_embedding(
    q_pope: torch.Tensor,
    k_pope: torch.Tensor,
    delta_bias: Optional[torch.Tensor] = None,
    num_key_value_groups: int = 1
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Apply the learnable phase bias delta_c to PoPE embeddings (Eq. 6 from paper).

    With the bias the score becomes a^PoPE_ts = sum mu_q mu_k cos((s-t)*theta_c + delta_c),
    implemented here as a rotation of k by exp(i*delta) in the [real; imag] layout.

    Args:
        q_pope: Query with PoPE applied, (batch, num_heads, seq_len, 2*head_dim);
            first head_dim entries are the real part, the rest the imaginary part.
        k_pope: Key with PoPE applied, (batch, num_kv_heads, seq_len, 2*head_dim),
            same [real; imag] layout.
        delta_bias: Learnable per-head/per-dim phase bias, (num_attention_heads, head_dim),
            clamped to [-2*pi, 0] as in the paper. Applied only to keys.
        num_key_value_groups: Query heads per key/value head (GQA grouping factor).

    Returns:
        (q_out, k_out): the query untouched, the key rotated by delta_bias; both
        keep their 2*head_dim last dimension.
    """
    # No bias -> nothing to rotate; the query is never modified either way.
    if delta_bias is None:
        return q_pope, k_pope

    # Recover the original per-head dimension and split the complex layout.
    head_dim = k_pope.shape[-1] // 2
    k_real = k_pope[..., :head_dim]
    k_imag = k_pope[..., head_dim:]

    # Bound the bias to [-2*pi, 0] per paper Section 3.
    delta = delta_bias.clamp(min=-2 * math.pi, max=0)

    # GQA: collapse (num_attention_heads, head_dim) down to one bias per kv head
    # by averaging each group of query-head biases.
    if num_key_value_groups > 1:
        num_kv_heads = delta.shape[0] // num_key_value_groups
        delta = delta.view(num_kv_heads, num_key_value_groups, head_dim).mean(dim=1)

    # Broadcast shape: (1, num_kv_heads, 1, head_dim).
    delta = delta[None, :, None, :]

    # Complex multiply k * exp(i*delta):
    #   real' = real*cos(d) - imag*sin(d),  imag' = real*sin(d) + imag*cos(d)
    cos_d = delta.cos()
    sin_d = delta.sin()
    k_out = torch.cat(
        [k_real * cos_d - k_imag * sin_d, k_real * sin_d + k_imag * cos_d],
        dim=-1,
    )

    return q_pope, k_out
413
 
414
 
415
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 
434
  dropout: float = 0.0,
435
  **kwargs: Unpack[TransformersKwargs],
436
  ):
437
+ """
438
+ Standard eager attention implementation for PoPE.
439
+
440
+ Note: query and key have 2*head_dim due to PoPE concatenation [real; imag].
441
+ Value is padded to match this dimension for kernel compatibility.
442
+ """
443
  key_states = repeat_kv(key, module.num_key_value_groups)
444
  value_states = repeat_kv(value, module.num_key_value_groups)
445
 
446
+ # Standard attention computation
447
  attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
448
+
449
  if attention_mask is not None:
450
  causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
451
  attn_weights = attn_weights + causal_mask
 
461
  class NeoLLMAttention(nn.Module):
462
  """
463
  Multi-headed attention with FANformer integration, SeeDNorm for Q/K normalization,
464
+ PoPE for positional encoding, and ResFormer feature residual connections.
465
 
466
  ResFormer enhancement: Applies learnable feature residual connections from the first layer
467
  BEFORE QKV projections: H'_fan_n = λ_1 * H_fan_1 + λ_2 * H_fan_n
468
+
469
+ PoPE enhancement: Decouples 'what' and 'where' via polar coordinates for superior
470
+ length extrapolation and content/position independent matching. Uses concatenated
471
+ [real; imag] representation for FlashAttention2 compatibility (2× head_dim overhead).
472
  """
473
 
474
  def __init__(self, config: NeoLLMConfig, layer_idx: int):
 
476
  self.config = config
477
  self.layer_idx = layer_idx
478
  self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
479
+ self.num_attention_heads = config.num_attention_heads
480
+ self.num_key_value_heads = config.num_key_value_heads
481
+ self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
482
+
483
+ # PoPE uses original head_dim for scaling (not 2*head_dim)
484
  self.scaling = self.head_dim**-0.5
485
  self.attention_dropout = config.attention_dropout
486
  self.is_causal = True
 
496
 
497
  # QKV projections operate on FAN-transformed features
498
  self.q_proj = nn.Linear(
499
+ fan_output_dim, self.num_attention_heads * self.head_dim * 2, bias=config.attention_bias
500
  )
501
  self.k_proj = nn.Linear(
502
+ fan_output_dim, self.num_key_value_heads * self.head_dim, bias=config.attention_bias
503
  )
504
  self.v_proj = nn.Linear(
505
+ fan_output_dim, self.num_key_value_heads * self.head_dim, bias=config.attention_bias
506
  )
507
  self.o_proj = nn.Linear(
508
+ self.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
509
  )
510
 
511
  # SeeDNorm for Q/K normalization (replaces RMSNorm)
512
  self.q_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
513
  self.k_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
514
 
515
+ # PoPE: Learnable phase bias δc for each head and dimension
516
+ # Initialized based on pope_bias_init config: 'zero' or 'uniform'
517
+ pope_bias_init = getattr(config, 'pope_bias_init', 'zero')
518
+ if pope_bias_init == 'uniform':
519
+ # Uniform initialization in [-2π, 0]
520
+ delta_init = torch.empty(self.num_attention_heads, self.head_dim).uniform_(-2 * math.pi, 0)
521
+ else:
522
+ # Zero initialization (better for length extrapolation)
523
+ delta_init = torch.zeros(self.num_attention_heads, self.head_dim)
524
+
525
+ self.delta_bias = nn.Parameter(delta_init)
526
+
527
  # Dropout for attention output
528
  self.dropout = nn.Dropout(config.dropout_rate)
529
 
 
540
  **kwargs: Unpack[FlashAttentionKwargs],
541
  ) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
542
  input_shape = hidden_states.shape[:-1]
543
+ batch_size, seq_len = input_shape
544
 
545
  # Apply FANformer transformation first
546
  hidden_states_fan = self.fan_layer(hidden_states)
547
 
548
  # ResFormer: Apply feature residual connection BEFORE projections
 
549
  if first_layer_fan is not None:
550
  hidden_states_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * hidden_states_fan
551
 
552
  # Store current FAN features for potential use as first_layer_fan in subsequent layers
553
  current_layer_fan = hidden_states_fan.clone()
554
 
555
+ # Project to Q, K, V
 
 
556
  query_states, gate = torch.chunk(
557
+ self.q_proj(hidden_states_fan).view(batch_size, seq_len, self.num_attention_heads, self.head_dim * 2),
558
+ 2, dim=-1
559
  )
560
+ gate = gate.reshape(batch_size, seq_len, -1)
561
+
562
+ key_states = self.k_proj(hidden_states_fan).view(
563
+ batch_size, seq_len, self.num_key_value_heads, self.head_dim
564
+ )
565
+ value_states = self.v_proj(hidden_states_fan).view(
566
+ batch_size, seq_len, self.num_key_value_heads, self.head_dim
567
+ )
568
+
569
+ # Apply SeeDNorm to Q and K before PoPE
570
+ query_states = self.q_norm(query_states)
571
+ key_states = self.k_norm(key_states)
572
+
573
+ # Transpose to (batch, num_heads, seq_len, head_dim)
574
+ query_states = query_states.transpose(1, 2)
575
+ key_states = key_states.transpose(1, 2)
576
+ value_states = value_states.transpose(1, 2)
577
+
578
+ # Apply PoPE: position_embeddings is (pope_emb, position_ids)
579
+ pope_emb, position_ids = position_embeddings
580
+
581
+ # Get PoPE embeddings with concatenated [real; imag] representation
582
+ # Returns Q', K' with shape (..., 2*head_dim)
583
+ query_states, key_states = pope_emb(query_states, key_states, position_ids)
584
+
585
+ # Apply learnable phase bias δc
586
+ # (the bias rotates the keys only; queries pass through unchanged)
587
+ query_states, key_states = apply_pope_embedding(
588
+ query_states,
589
+ key_states,
590
+ self.delta_bias,
591
+ num_key_value_groups=self.num_key_value_groups  # required so delta_bias is reduced to the kv heads under GQA
592
+ )
593
+ # Pad value to 2*head_dim for dimension compatibility
594
+ # Only first head_dim components are used in output
595
+ value_states = F.pad(value_states, (0, self.head_dim), value=0.0)
596
+
597
+ # Call attention with doubled head_dim
598
  attention_interface: Callable = eager_attention_forward
599
  if self.config._attn_implementation != "eager":
600
  attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
610
  **kwargs,
611
  )
612
 
613
+ # Extract only the first head_dim components (discard padding)
614
+ attn_output = attn_output[..., :self.head_dim]
615
+
616
+ attn_output = attn_output.reshape(batch_size, seq_len, -1).contiguous()
617
  attn_output = attn_output * torch.sigmoid(gate)
618
 
619
  attn_output = self.o_proj(attn_output)
 
1167
  module.lambda_1.data.fill_(0.5)
1168
  if hasattr(module, 'lambda_2'):
1169
  module.lambda_2.data.fill_(0.5)
1170
+ # PoPE delta_bias already initialized in __init__
1171
  elif isinstance(module, GPAS):
1172
  # Initialize GPAS alpha to 0 as per paper
1173
  module.alpha.data.fill_(0.0)
 
1180
  # beta (β) initialized to 0 (default in Parameter definition)
1181
  # alpha (α) initialized to 1 (default in Parameter definition)
1182
  pass
1183
+ elif isinstance(module, PolarPositionalEmbedding):
1184
+ # PoPE frequency initialization handled in __init__
1185
+ pass
1186
 
1187
 
1188
  class NeoLLMModel(NeoLLMPreTrainedModel):
 
1196
  )
1197
  # SeeDNorm for final output normalization (replaces RMSNorm)
1198
  self.norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
1199
+
1200
+ # PoPE positional embedding (replaces RoPE)
1201
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
1202
+ self.pope_emb = PolarPositionalEmbedding(
1203
+ dim=head_dim,
1204
+ max_position_embeddings=config.max_position_embeddings,
1205
+ base=getattr(config, 'rope_theta', 10000.0), # Use rope_theta for backward compatibility
1206
+ )
1207
+
1208
  self.gradient_checkpointing = False
1209
 
1210
  # ResFormer: storage for first layer's FAN features (H_fan_1)
 
1242
 
1243
  hidden_states = inputs_embeds
1244
 
1245
+ # Create position embeddings for PoPE
1246
+ # position_embeddings is a tuple of (pope_emb, position_ids)
1247
+ position_embeddings = (self.pope_emb, position_ids)
1248
 
1249
  # ResFormer: reset first_layer_fan at the start of each forward pass
1250
  self.first_layer_fan = None
 
1375
  "NeoLLMConfig",
1376
  "FANLayer",
1377
  "SeeDNorm",
1378
+ "PolarPositionalEmbedding",
1379
  ]
1380
 
1381
  # Register the configuration and model for AutoClass support