KitsuVp committed on
Commit
7a75741
·
verified ·
1 Parent(s): f406785

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +22 -57
modeling_neollm.py CHANGED
@@ -1954,8 +1954,8 @@ def _apply_repo_rope(
1954
  n_groups = H // H_kv
1955
  rotary_dim = inv_freq.shape[0] * 2 # inv_freq covers half the rotary dim
1956
 
1957
- # inv_freq is already float32 on the correct device (registered as buffer
1958
- # via set_repo_inv_freq) — no .to() needed, no DeviceCopy op.
1959
  # No autocast barrier: explicit .float() casts on z_q/z_k are sufficient
1960
  # to maintain float32 precision for the trig ops. Removing the context
1961
  # manager lets Inductor plan all intermediate tensors as part of a single
@@ -2225,37 +2225,9 @@ class NeoLLMAttention(nn.Module):
2225
  d_p=_d_p,
2226
  num_heads=config.num_attention_heads,
2227
  )
2228
- # _repo_inv_freq is registered as a non-persistent buffer by
2229
- # set_repo_inv_freq(), called from NeoLLMModel.__init__ after
2230
- # rotary_emb is built. Declaring it here would conflict.
2231
- self._repo_attn_scaling: float = 1.0
2232
  else:
2233
  self.repo_module = None
2234
 
2235
- def set_repo_inv_freq(
2236
- self,
2237
- inv_freq: torch.Tensor,
2238
- attention_scaling: float,
2239
- ) -> None:
2240
- """
2241
- Inject the rotary frequency vector from NeoLLMRotaryEmbedding so that
2242
- REPO can build cos/sin inline from continuous positions.
2243
-
2244
- Called once by NeoLLMModel.__init__ after rotary_emb is constructed.
2245
- Only has effect when use_repo=True for this layer.
2246
-
2247
- Args:
2248
- inv_freq: [rotary_dim/2] — frozen inv_freq buffer from
2249
- NeoLLMRotaryEmbedding.
2250
- attention_scaling: float — attention_scaling from the same module.
2251
- """
2252
- if self.use_repo:
2253
- # Register as non-persistent buffer so .to(device) / .cuda() moves
2254
- it automatically — eliminates the DeviceCopy op that splits the
2255
- # CUDAGraph into 2 partitions when _apply_repo_rope runs.
2256
- self.register_buffer("_repo_inv_freq", inv_freq.float(), persistent=False)
2257
- self._repo_attn_scaling = attention_scaling
2258
-
2259
  def _apply_momentum_attention(
2260
  self,
2261
  q: torch.Tensor,
@@ -2391,6 +2363,7 @@ class NeoLLMAttention(nn.Module):
2391
  attention_mask: Optional[torch.Tensor] = None,
2392
  first_layer_fan: Optional[torch.Tensor] = None,
2393
  attn_analysis: Optional[AttentionAnalysis] = None,
 
2394
  **kwargs: Unpack[FlashAttentionKwargs],
2395
  ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
2396
  input_shape = hidden_states.shape[:-1]
@@ -2428,25 +2401,14 @@ class NeoLLMAttention(nn.Module):
2428
  # REPO path: f_ϕ predicts continuous per-head positions from the
2429
  # residual stream, then cos/sin are built inline from those positions
2430
  # so the rotation is differentiable w.r.t. REPOModule parameters.
 
 
 
2431
  # (Li et al., 2026, §3.2 — Eq. 6–7)
2432
  repo_a = attn_analysis.repo if attn_analysis is not None else None
2433
  z = self.repo_module(hidden_states, repo_analysis=repo_a) # [B, H, S]
2434
-
2435
- # Meta-device guard: _repo_inv_freq heredó el meta device de
2436
- # rotary_emb.inv_freq si el modelo fue cargado con from_pretrained.
2437
- # Se materializa una sola vez; los forwards siguientes toman el
2438
- # path normal sin overhead adicional.
2439
- if self._repo_inv_freq.device.type == "meta":
2440
- inv_freq_data, _ = NeoLLMRotaryEmbedding.compute_default_rope_parameters(
2441
- self.config, device=hidden_states.device
2442
- )
2443
- self.register_buffer("_repo_inv_freq", inv_freq_data.float(), persistent=False)
2444
-
2445
- q, k = _apply_repo_rope(
2446
- q, k, z,
2447
- self._repo_inv_freq,
2448
- self._repo_attn_scaling,
2449
- )
2450
  else:
2451
  # Standard path: integer positions pre-computed by NeoLLMModel.
2452
  q, k = apply_rotary_pos_emb(q, k, cos, sin)
@@ -3213,6 +3175,7 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
3213
  attn_res_partial: Optional[torch.Tensor] = None,
3214
  layer_analysis: Optional[LayerAnalysis] = None,
3215
  output_attentions: Optional[bool] = False,
 
3216
  **kwargs: Unpack[FlashAttentionKwargs],
3217
  ) -> Tuple:
3218
  # ── Snapshot input ────────────────────────────────────────────────
@@ -3250,6 +3213,7 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
3250
  position_embeddings=position_embeddings,
3251
  first_layer_fan=first_layer_fan,
3252
  attn_analysis=layer_analysis.attention if layer_analysis is not None else None,
 
3253
  **kwargs,
3254
  )
3255
 
@@ -3775,17 +3739,6 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
3775
 
3776
  self.post_init()
3777
 
3778
- # ── REPO: inject inv_freq into every attention layer that uses it ─────
3779
- # Done after post_init so rotary_emb.inv_freq is already initialized.
3780
- # Layers below repo_start_layer never call set_repo_inv_freq (their
3781
- # use_repo flag is False) so the call is harmless for those layers.
3782
- if getattr(config, "use_repo", False):
3783
- for layer in self.layers:
3784
- layer.self_attn.set_repo_inv_freq(
3785
- self.rotary_emb.inv_freq,
3786
- self.rotary_emb.attention_scaling,
3787
- )
3788
-
3789
  def get_input_embeddings(self):
3790
  if self.config.use_token_generator:
3791
  return self.token_generator
@@ -3919,6 +3872,17 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
3919
  position_embeddings = self.rotary_emb(hidden_states, position_ids)
3920
  self.first_layer_fan = None
3921
 
 
 
 
 
 
 
 
 
 
 
 
3922
  # ── Attention Residuals state ──────────────────────────────────────
3923
  # Full AttnRes (attn_res_num_blocks=0): sources grows by one entry per
3924
  # decoder layer — all previous outputs are kept, max N=num_layers+1.
@@ -3979,6 +3943,7 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
3979
  attn_res_partial=attn_res_partial if use_attn_res else None,
3980
  layer_analysis=layer_analysis,
3981
  output_attentions=output_attentions,
 
3982
  **kwargs,
3983
  )
3984
  hidden_states = layer_outputs[0]
 
1954
  n_groups = H // H_kv
1955
  rotary_dim = inv_freq.shape[0] * 2 # inv_freq covers half the rotary dim
1956
 
1957
+ # inv_freq arrives from rotary_emb at forward time via repo_rope_args —
1958
+ # already float32 on the correct device, no .to() needed, no DeviceCopy op.
1959
  # No autocast barrier: explicit .float() casts on z_q/z_k are sufficient
1960
  # to maintain float32 precision for the trig ops. Removing the context
1961
  # manager lets Inductor plan all intermediate tensors as part of a single
 
2225
  d_p=_d_p,
2226
  num_heads=config.num_attention_heads,
2227
  )
 
 
 
 
2228
  else:
2229
  self.repo_module = None
2230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2231
  def _apply_momentum_attention(
2232
  self,
2233
  q: torch.Tensor,
 
2363
  attention_mask: Optional[torch.Tensor] = None,
2364
  first_layer_fan: Optional[torch.Tensor] = None,
2365
  attn_analysis: Optional[AttentionAnalysis] = None,
2366
+ repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None,
2367
  **kwargs: Unpack[FlashAttentionKwargs],
2368
  ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
2369
  input_shape = hidden_states.shape[:-1]
 
2401
  # REPO path: f_ϕ predicts continuous per-head positions from the
2402
  # residual stream, then cos/sin are built inline from those positions
2403
  # so the rotation is differentiable w.r.t. REPOModule parameters.
2404
+ # inv_freq and attention_scaling arrive via repo_rope_args, sourced
2405
+ # directly from rotary_emb at forward time — no buffer on this module,
2406
+ # no meta-tensor issue on lm_eval / to(device) paths.
2407
  # (Li et al., 2026, §3.2 — Eq. 6–7)
2408
  repo_a = attn_analysis.repo if attn_analysis is not None else None
2409
  z = self.repo_module(hidden_states, repo_analysis=repo_a) # [B, H, S]
2410
+ inv_freq, attn_scaling = repo_rope_args
2411
+ q, k = _apply_repo_rope(q, k, z, inv_freq, attn_scaling)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2412
  else:
2413
  # Standard path: integer positions pre-computed by NeoLLMModel.
2414
  q, k = apply_rotary_pos_emb(q, k, cos, sin)
 
3175
  attn_res_partial: Optional[torch.Tensor] = None,
3176
  layer_analysis: Optional[LayerAnalysis] = None,
3177
  output_attentions: Optional[bool] = False,
3178
+ repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None,
3179
  **kwargs: Unpack[FlashAttentionKwargs],
3180
  ) -> Tuple:
3181
  # ── Snapshot input ────────────────────────────────────────────────
 
3213
  position_embeddings=position_embeddings,
3214
  first_layer_fan=first_layer_fan,
3215
  attn_analysis=layer_analysis.attention if layer_analysis is not None else None,
3216
+ repo_rope_args=repo_rope_args,
3217
  **kwargs,
3218
  )
3219
 
 
3739
 
3740
  self.post_init()
3741
 
 
 
 
 
 
 
 
 
 
 
 
3742
  def get_input_embeddings(self):
3743
  if self.config.use_token_generator:
3744
  return self.token_generator
 
3872
  position_embeddings = self.rotary_emb(hidden_states, position_ids)
3873
  self.first_layer_fan = None
3874
 
3875
+ # ── REPO: pass inv_freq by reference at forward time ──────────────────
3876
+ # rotary_emb.inv_freq is already on the correct device (managed by
3877
+ # NeoLLMRotaryEmbedding as a buffer) — no .to(), no DeviceCopy op.
3878
+ # Computed once here and passed through the decoder layer chain so
3879
+ # NeoLLMAttention never needs to store it as a buffer itself, avoiding
3880
+ # the meta-tensor issue that occurs when lm_eval calls .to(device).
3881
+ repo_rope_args = (
3882
+ (self.rotary_emb.inv_freq, self.rotary_emb.attention_scaling)
3883
+ if getattr(self.config, "use_repo", False) else None
3884
+ )
3885
+
3886
  # ── Attention Residuals state ──────────────────────────────────────
3887
  # Full AttnRes (attn_res_num_blocks=0): sources grows by one entry per
3888
  # decoder layer — all previous outputs are kept, max N=num_layers+1.
 
3943
  attn_res_partial=attn_res_partial if use_attn_res else None,
3944
  layer_analysis=layer_analysis,
3945
  output_attentions=output_attentions,
3946
+ repo_rope_args=repo_rope_args,
3947
  **kwargs,
3948
  )
3949
  hidden_states = layer_outputs[0]