Update modeling_neollm.py
Browse files- modeling_neollm.py +22 -57
modeling_neollm.py
CHANGED
|
@@ -1954,8 +1954,8 @@ def _apply_repo_rope(
|
|
| 1954 |
n_groups = H // H_kv
|
| 1955 |
rotary_dim = inv_freq.shape[0] * 2 # inv_freq covers half the rotary dim
|
| 1956 |
|
| 1957 |
-
# inv_freq
|
| 1958 |
-
#
|
| 1959 |
# No autocast barrier: explicit .float() casts on z_q/z_k are sufficient
|
| 1960 |
# to maintain float32 precision for the trig ops. Removing the context
|
| 1961 |
# manager lets Inductor plan all intermediate tensors as part of a single
|
|
@@ -2225,37 +2225,9 @@ class NeoLLMAttention(nn.Module):
|
|
| 2225 |
d_p=_d_p,
|
| 2226 |
num_heads=config.num_attention_heads,
|
| 2227 |
)
|
| 2228 |
-
# _repo_inv_freq is registered as a non-persistent buffer by
|
| 2229 |
-
# set_repo_inv_freq(), called from NeoLLMModel.__init__ after
|
| 2230 |
-
# rotary_emb is built. Declaring it here would conflict.
|
| 2231 |
-
self._repo_attn_scaling: float = 1.0
|
| 2232 |
else:
|
| 2233 |
self.repo_module = None
|
| 2234 |
|
| 2235 |
-
def set_repo_inv_freq(
|
| 2236 |
-
self,
|
| 2237 |
-
inv_freq: torch.Tensor,
|
| 2238 |
-
attention_scaling: float,
|
| 2239 |
-
) -> None:
|
| 2240 |
-
"""
|
| 2241 |
-
Inject the rotary frequency vector from NeoLLMRotaryEmbedding so that
|
| 2242 |
-
REPO can build cos/sin inline from continuous positions.
|
| 2243 |
-
|
| 2244 |
-
Called once by NeoLLMModel.__init__ after rotary_emb is constructed.
|
| 2245 |
-
Only has effect when use_repo=True for this layer.
|
| 2246 |
-
|
| 2247 |
-
Args:
|
| 2248 |
-
inv_freq: [rotary_dim/2] β frozen inv_freq buffer from
|
| 2249 |
-
NeoLLMRotaryEmbedding.
|
| 2250 |
-
attention_scaling: float β attention_scaling from the same module.
|
| 2251 |
-
"""
|
| 2252 |
-
if self.use_repo:
|
| 2253 |
-
# Register as non-persistent buffer so .to(device) / .cuda() moves
|
| 2254 |
-
# it automatically β eliminates the DeviceCopy op that splits the
|
| 2255 |
-
# CUDAGraph into 2 partitions when _apply_repo_rope runs.
|
| 2256 |
-
self.register_buffer("_repo_inv_freq", inv_freq.float(), persistent=False)
|
| 2257 |
-
self._repo_attn_scaling = attention_scaling
|
| 2258 |
-
|
| 2259 |
def _apply_momentum_attention(
|
| 2260 |
self,
|
| 2261 |
q: torch.Tensor,
|
|
@@ -2391,6 +2363,7 @@ class NeoLLMAttention(nn.Module):
|
|
| 2391 |
attention_mask: Optional[torch.Tensor] = None,
|
| 2392 |
first_layer_fan: Optional[torch.Tensor] = None,
|
| 2393 |
attn_analysis: Optional[AttentionAnalysis] = None,
|
|
|
|
| 2394 |
**kwargs: Unpack[FlashAttentionKwargs],
|
| 2395 |
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
| 2396 |
input_shape = hidden_states.shape[:-1]
|
|
@@ -2428,25 +2401,14 @@ class NeoLLMAttention(nn.Module):
|
|
| 2428 |
# REPO path: f_Ο predicts continuous per-head positions from the
|
| 2429 |
# residual stream, then cos/sin are built inline from those positions
|
| 2430 |
# so the rotation is differentiable w.r.t. REPOModule parameters.
|
|
|
|
|
|
|
|
|
|
| 2431 |
# (Li et al., 2026, Β§3.2 β Eq. 6β7)
|
| 2432 |
repo_a = attn_analysis.repo if attn_analysis is not None else None
|
| 2433 |
z = self.repo_module(hidden_states, repo_analysis=repo_a) # [B, H, S]
|
| 2434 |
-
|
| 2435 |
-
|
| 2436 |
-
# rotary_emb.inv_freq si el modelo fue cargado con from_pretrained.
|
| 2437 |
-
# Se materializa una sola vez; los forwards siguientes toman el
|
| 2438 |
-
# path normal sin overhead adicional.
|
| 2439 |
-
if self._repo_inv_freq.device.type == "meta":
|
| 2440 |
-
inv_freq_data, _ = NeoLLMRotaryEmbedding.compute_default_rope_parameters(
|
| 2441 |
-
self.config, device=hidden_states.device
|
| 2442 |
-
)
|
| 2443 |
-
self.register_buffer("_repo_inv_freq", inv_freq_data.float(), persistent=False)
|
| 2444 |
-
|
| 2445 |
-
q, k = _apply_repo_rope(
|
| 2446 |
-
q, k, z,
|
| 2447 |
-
self._repo_inv_freq,
|
| 2448 |
-
self._repo_attn_scaling,
|
| 2449 |
-
)
|
| 2450 |
else:
|
| 2451 |
# Standard path: integer positions pre-computed by NeoLLMModel.
|
| 2452 |
q, k = apply_rotary_pos_emb(q, k, cos, sin)
|
|
@@ -3213,6 +3175,7 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 3213 |
attn_res_partial: Optional[torch.Tensor] = None,
|
| 3214 |
layer_analysis: Optional[LayerAnalysis] = None,
|
| 3215 |
output_attentions: Optional[bool] = False,
|
|
|
|
| 3216 |
**kwargs: Unpack[FlashAttentionKwargs],
|
| 3217 |
) -> Tuple:
|
| 3218 |
# ββ Snapshot input ββββββββββββββββββββββββββββββββββββοΏ½οΏ½οΏ½βββββββββββ
|
|
@@ -3250,6 +3213,7 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 3250 |
position_embeddings=position_embeddings,
|
| 3251 |
first_layer_fan=first_layer_fan,
|
| 3252 |
attn_analysis=layer_analysis.attention if layer_analysis is not None else None,
|
|
|
|
| 3253 |
**kwargs,
|
| 3254 |
)
|
| 3255 |
|
|
@@ -3775,17 +3739,6 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 3775 |
|
| 3776 |
self.post_init()
|
| 3777 |
|
| 3778 |
-
# ββ REPO: inject inv_freq into every attention layer that uses it βββββ
|
| 3779 |
-
# Done after post_init so rotary_emb.inv_freq is already initialized.
|
| 3780 |
-
# Layers below repo_start_layer never call set_repo_inv_freq (their
|
| 3781 |
-
# use_repo flag is False) so the call is harmless for those layers.
|
| 3782 |
-
if getattr(config, "use_repo", False):
|
| 3783 |
-
for layer in self.layers:
|
| 3784 |
-
layer.self_attn.set_repo_inv_freq(
|
| 3785 |
-
self.rotary_emb.inv_freq,
|
| 3786 |
-
self.rotary_emb.attention_scaling,
|
| 3787 |
-
)
|
| 3788 |
-
|
| 3789 |
def get_input_embeddings(self):
|
| 3790 |
if self.config.use_token_generator:
|
| 3791 |
return self.token_generator
|
|
@@ -3919,6 +3872,17 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 3919 |
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 3920 |
self.first_layer_fan = None
|
| 3921 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3922 |
# ββ Attention Residuals state ββββββββββββββββββββββββββββββββββββββ
|
| 3923 |
# Full AttnRes (attn_res_num_blocks=0): sources grows by one entry per
|
| 3924 |
# decoder layer β all previous outputs are kept, max N=num_layers+1.
|
|
@@ -3979,6 +3943,7 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 3979 |
attn_res_partial=attn_res_partial if use_attn_res else None,
|
| 3980 |
layer_analysis=layer_analysis,
|
| 3981 |
output_attentions=output_attentions,
|
|
|
|
| 3982 |
**kwargs,
|
| 3983 |
)
|
| 3984 |
hidden_states = layer_outputs[0]
|
|
|
|
| 1954 |
n_groups = H // H_kv
|
| 1955 |
rotary_dim = inv_freq.shape[0] * 2 # inv_freq covers half the rotary dim
|
| 1956 |
|
| 1957 |
+
# inv_freq arrives from rotary_emb at forward time via repo_rope_args β
|
| 1958 |
+
# already float32 on the correct device, no .to() needed, no DeviceCopy op.
|
| 1959 |
# No autocast barrier: explicit .float() casts on z_q/z_k are sufficient
|
| 1960 |
# to maintain float32 precision for the trig ops. Removing the context
|
| 1961 |
# manager lets Inductor plan all intermediate tensors as part of a single
|
|
|
|
| 2225 |
d_p=_d_p,
|
| 2226 |
num_heads=config.num_attention_heads,
|
| 2227 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2228 |
else:
|
| 2229 |
self.repo_module = None
|
| 2230 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2231 |
def _apply_momentum_attention(
|
| 2232 |
self,
|
| 2233 |
q: torch.Tensor,
|
|
|
|
| 2363 |
attention_mask: Optional[torch.Tensor] = None,
|
| 2364 |
first_layer_fan: Optional[torch.Tensor] = None,
|
| 2365 |
attn_analysis: Optional[AttentionAnalysis] = None,
|
| 2366 |
+
repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None,
|
| 2367 |
**kwargs: Unpack[FlashAttentionKwargs],
|
| 2368 |
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
| 2369 |
input_shape = hidden_states.shape[:-1]
|
|
|
|
| 2401 |
# REPO path: f_Ο predicts continuous per-head positions from the
|
| 2402 |
# residual stream, then cos/sin are built inline from those positions
|
| 2403 |
# so the rotation is differentiable w.r.t. REPOModule parameters.
|
| 2404 |
+
# inv_freq and attention_scaling arrive via repo_rope_args, sourced
|
| 2405 |
+
# directly from rotary_emb at forward time β no buffer on this module,
|
| 2406 |
+
# no meta-tensor issue on lm_eval / to(device) paths.
|
| 2407 |
# (Li et al., 2026, Β§3.2 β Eq. 6β7)
|
| 2408 |
repo_a = attn_analysis.repo if attn_analysis is not None else None
|
| 2409 |
z = self.repo_module(hidden_states, repo_analysis=repo_a) # [B, H, S]
|
| 2410 |
+
inv_freq, attn_scaling = repo_rope_args
|
| 2411 |
+
q, k = _apply_repo_rope(q, k, z, inv_freq, attn_scaling)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2412 |
else:
|
| 2413 |
# Standard path: integer positions pre-computed by NeoLLMModel.
|
| 2414 |
q, k = apply_rotary_pos_emb(q, k, cos, sin)
|
|
|
|
| 3175 |
attn_res_partial: Optional[torch.Tensor] = None,
|
| 3176 |
layer_analysis: Optional[LayerAnalysis] = None,
|
| 3177 |
output_attentions: Optional[bool] = False,
|
| 3178 |
+
repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None,
|
| 3179 |
**kwargs: Unpack[FlashAttentionKwargs],
|
| 3180 |
) -> Tuple:
|
| 3181 |
# ββ Snapshot input ββββββββββββββββββββββββββββββββββββοΏ½οΏ½οΏ½βββββββββββ
|
|
|
|
| 3213 |
position_embeddings=position_embeddings,
|
| 3214 |
first_layer_fan=first_layer_fan,
|
| 3215 |
attn_analysis=layer_analysis.attention if layer_analysis is not None else None,
|
| 3216 |
+
repo_rope_args=repo_rope_args,
|
| 3217 |
**kwargs,
|
| 3218 |
)
|
| 3219 |
|
|
|
|
| 3739 |
|
| 3740 |
self.post_init()
|
| 3741 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3742 |
def get_input_embeddings(self):
|
| 3743 |
if self.config.use_token_generator:
|
| 3744 |
return self.token_generator
|
|
|
|
| 3872 |
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 3873 |
self.first_layer_fan = None
|
| 3874 |
|
| 3875 |
+
# ββ REPO: pass inv_freq by reference at forward time ββββββββββββββββββ
|
| 3876 |
+
# rotary_emb.inv_freq is already on the correct device (managed by
|
| 3877 |
+
# NeoLLMRotaryEmbedding as a buffer) β no .to(), no DeviceCopy op.
|
| 3878 |
+
# Computed once here and passed through the decoder layer chain so
|
| 3879 |
+
# NeoLLMAttention never needs to store it as a buffer itself, avoiding
|
| 3880 |
+
# the meta-tensor issue that occurs when lm_eval calls .to(device).
|
| 3881 |
+
repo_rope_args = (
|
| 3882 |
+
(self.rotary_emb.inv_freq, self.rotary_emb.attention_scaling)
|
| 3883 |
+
if getattr(self.config, "use_repo", False) else None
|
| 3884 |
+
)
|
| 3885 |
+
|
| 3886 |
# ββ Attention Residuals state ββββββββββββββββββββββββββββββββββββββ
|
| 3887 |
# Full AttnRes (attn_res_num_blocks=0): sources grows by one entry per
|
| 3888 |
# decoder layer β all previous outputs are kept, max N=num_layers+1.
|
|
|
|
| 3943 |
attn_res_partial=attn_res_partial if use_attn_res else None,
|
| 3944 |
layer_analysis=layer_analysis,
|
| 3945 |
output_attentions=output_attentions,
|
| 3946 |
+
repo_rope_args=repo_rope_args,
|
| 3947 |
**kwargs,
|
| 3948 |
)
|
| 3949 |
hidden_states = layer_outputs[0]
|