KitsuVp commited on
Commit
f406785
·
verified ·
1 Parent(s): bfcc64d

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +11 -0
modeling_neollm.py CHANGED
@@ -2431,6 +2431,17 @@ class NeoLLMAttention(nn.Module):
2431
  # (Li et al., 2026, §3.2 — Eq. 6–7)
2432
  repo_a = attn_analysis.repo if attn_analysis is not None else None
2433
  z = self.repo_module(hidden_states, repo_analysis=repo_a) # [B, H, S]
 
 
 
 
 
 
 
 
 
 
 
2434
  q, k = _apply_repo_rope(
2435
  q, k, z,
2436
  self._repo_inv_freq,
 
2431
  # (Li et al., 2026, §3.2 — Eq. 6–7)
2432
  repo_a = attn_analysis.repo if attn_analysis is not None else None
2433
  z = self.repo_module(hidden_states, repo_analysis=repo_a) # [B, H, S]
2434
+
2435
+ # Meta-device guard: _repo_inv_freq heredó el meta device de
2436
+ # rotary_emb.inv_freq si el modelo fue cargado con from_pretrained.
2437
+ # Se materializa una sola vez; los forwards siguientes toman el
2438
+ # path normal sin overhead adicional.
2439
+ if self._repo_inv_freq.device.type == "meta":
2440
+ inv_freq_data, _ = NeoLLMRotaryEmbedding.compute_default_rope_parameters(
2441
+ self.config, device=hidden_states.device
2442
+ )
2443
+ self.register_buffer("_repo_inv_freq", inv_freq_data.float(), persistent=False)
2444
+
2445
  q, k = _apply_repo_rope(
2446
  q, k, z,
2447
  self._repo_inv_freq,