Update modeling_neollm.py
Browse files- modeling_neollm.py +11 -0
modeling_neollm.py
CHANGED
|
@@ -2431,6 +2431,17 @@ class NeoLLMAttention(nn.Module):
|
|
| 2431 |
# (Li et al., 2026, §3.2 — Eq. 6–7)
|
| 2432 |
repo_a = attn_analysis.repo if attn_analysis is not None else None
|
| 2433 |
z = self.repo_module(hidden_states, repo_analysis=repo_a) # [B, H, S]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2434 |
q, k = _apply_repo_rope(
|
| 2435 |
q, k, z,
|
| 2436 |
self._repo_inv_freq,
|
|
|
|
| 2431 |
# (Li et al., 2026, §3.2 — Eq. 6–7)
|
| 2432 |
repo_a = attn_analysis.repo if attn_analysis is not None else None
|
| 2433 |
z = self.repo_module(hidden_states, repo_analysis=repo_a) # [B, H, S]
|
| 2434 |
+
|
| 2435 |
+
# Meta-device guard: _repo_inv_freq heredó el meta device de
|
| 2436 |
+
# rotary_emb.inv_freq si el modelo fue cargado con from_pretrained.
|
| 2437 |
+
# Se materializa una sola vez; los forwards siguientes toman el
|
| 2438 |
+
# path normal sin overhead adicional.
|
| 2439 |
+
if self._repo_inv_freq.device.type == "meta":
|
| 2440 |
+
inv_freq_data, _ = NeoLLMRotaryEmbedding.compute_default_rope_parameters(
|
| 2441 |
+
self.config, device=hidden_states.device
|
| 2442 |
+
)
|
| 2443 |
+
self.register_buffer("_repo_inv_freq", inv_freq_data.float(), persistent=False)
|
| 2444 |
+
|
| 2445 |
q, k = _apply_repo_rope(
|
| 2446 |
q, k, z,
|
| 2447 |
self._repo_inv_freq,
|