Update modeling_neollm.py
Browse files- modeling_neollm.py +94 -33
modeling_neollm.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
| 1 |
-
# ==================== modeling_neollm.py ====================
|
| 2 |
#!/usr/bin/env python3
|
| 3 |
"""
|
| 4 |
NeoLLM Model with FANformer Integration in both Attention and FFN, Dropout Regularization,
|
| 5 |
-
and ResFormer Value Residual Learning
|
| 6 |
-
for enhanced information flow through deep layers.
|
| 7 |
|
| 8 |
Updated to include:
|
| 9 |
- Fourier Analysis Network (FAN) layer for effective periodicity modeling in attention (relational space)
|
| 10 |
-
-
|
|
|
|
| 11 |
- Dropout regularization at strategic locations
|
| 12 |
- ResFormer: Feature residual connections from first layer (applied before projections)
|
| 13 |
"""
|
|
@@ -35,7 +35,7 @@ from transformers.utils.import_utils import (
|
|
| 35 |
is_causal_conv1d_available,
|
| 36 |
is_flash_linear_attention_available,
|
| 37 |
)
|
| 38 |
-
from
|
| 39 |
|
| 40 |
|
| 41 |
if is_causal_conv1d_available():
|
|
@@ -153,7 +153,74 @@ class GPAS(nn.Module):
|
|
| 153 |
return x_scaled
|
| 154 |
|
| 155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
class NeoLLMRMSNormGated(nn.Module):
|
|
|
|
|
|
|
|
|
|
| 157 |
def __init__(self, hidden_size, eps=1e-6, **kwargs):
|
| 158 |
super().__init__()
|
| 159 |
self.weight = nn.Parameter(torch.ones(hidden_size))
|
|
@@ -207,25 +274,6 @@ class NeoLLMRotaryEmbedding(nn.Module):
|
|
| 207 |
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 208 |
|
| 209 |
|
| 210 |
-
class NeoLLMRMSNorm(nn.Module):
    """RMS normalization with a zero-initialized, (1 + weight) scale.

    Statistics are computed in float32 for stability and the result is cast
    back to the input dtype after scaling. With the default zero init the
    effective scale is exactly 1, i.e. plain RMS normalization.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Zero init: effective multiplier is (1 + weight), identity at start.
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        # x / RMS(x), with eps inside the rsqrt for numerical stability.
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps)

    def forward(self, x):
        # Normalize in float32, scale by (1 + weight), then cast back.
        # Llama does x.to(float16) * w whilst NeoLLM is (x * w).to(float16)
        normalized = self._norm(x.float()) * (1.0 + self.weight.float())
        return normalized.type_as(x)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"
|
| 227 |
-
|
| 228 |
-
|
| 229 |
def rotate_half(x):
|
| 230 |
"""Rotates half the hidden dims of the input."""
|
| 231 |
x1 = x[..., : x.shape[-1] // 2]
|
|
@@ -293,7 +341,7 @@ def eager_attention_forward(
|
|
| 293 |
|
| 294 |
class NeoLLMAttention(nn.Module):
|
| 295 |
"""
|
| 296 |
-
Multi-headed attention with FANformer integration,
|
| 297 |
and ResFormer feature residual connections for enhanced information flow.
|
| 298 |
|
| 299 |
ResFormer enhancement: Applies learnable feature residual connections from the first layer
|
|
@@ -332,8 +380,10 @@ class NeoLLMAttention(nn.Module):
|
|
| 332 |
self.o_proj = nn.Linear(
|
| 333 |
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
| 334 |
)
|
| 335 |
-
|
| 336 |
-
|
|
|
|
|
|
|
| 337 |
|
| 338 |
# Dropout for attention output
|
| 339 |
self.dropout = nn.Dropout(config.dropout_rate)
|
|
@@ -371,6 +421,7 @@ class NeoLLMAttention(nn.Module):
|
|
| 371 |
)
|
| 372 |
gate = gate.reshape(*input_shape, -1)
|
| 373 |
|
|
|
|
| 374 |
query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
|
| 375 |
key_states = self.k_norm(self.k_proj(hidden_states_fan).view(hidden_shape)).transpose(1, 2)
|
| 376 |
value_states = self.v_proj(hidden_states_fan).view(hidden_shape).transpose(1, 2)
|
|
@@ -566,7 +617,7 @@ def torch_recurrent_gated_delta_rule(
|
|
| 566 |
|
| 567 |
class NeoLLMGatedDeltaNet(nn.Module):
|
| 568 |
"""
|
| 569 |
-
Linear attention with FANformer integration,
|
| 570 |
and ResFormer feature residual connections for enhanced information flow.
|
| 571 |
|
| 572 |
ResFormer enhancement: Applies learnable feature residual connections from the first layer
|
|
@@ -630,7 +681,7 @@ class NeoLLMGatedDeltaNet(nn.Module):
|
|
| 630 |
else FusedRMSNormGated(
|
| 631 |
self.head_v_dim,
|
| 632 |
eps=self.layer_norm_epsilon,
|
| 633 |
-
activation=fla_compatible_activation,
|
| 634 |
device=torch.cuda.current_device(),
|
| 635 |
dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(),
|
| 636 |
)
|
|
@@ -849,8 +900,9 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 849 |
# MLP with FANformer integration
|
| 850 |
self.mlp = NeoLLMMLP(config)
|
| 851 |
|
| 852 |
-
|
| 853 |
-
self.
|
|
|
|
| 854 |
|
| 855 |
# LNS (LayerNorm Scaling) - applies 1/√ℓ scaling
|
| 856 |
self.lns_attn = LNS(layer_idx)
|
|
@@ -873,7 +925,7 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 873 |
) -> torch.FloatTensor:
|
| 874 |
residual = hidden_states
|
| 875 |
|
| 876 |
-
# Apply
|
| 877 |
hidden_states = self.input_layernorm(hidden_states)
|
| 878 |
|
| 879 |
# Apply LNS scaling after normalization
|
|
@@ -952,6 +1004,12 @@ class NeoLLMPreTrainedModel(PreTrainedModel):
|
|
| 952 |
elif isinstance(module, FANLayer):
|
| 953 |
# FANLayer initialization is handled within the class
|
| 954 |
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 955 |
|
| 956 |
|
| 957 |
class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
@@ -963,7 +1021,8 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 963 |
self.layers = nn.ModuleList(
|
| 964 |
[NeoLLMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 965 |
)
|
| 966 |
-
|
|
|
|
| 967 |
self.rotary_emb = NeoLLMRotaryEmbedding(config=config)
|
| 968 |
self.gradient_checkpointing = False
|
| 969 |
|
|
@@ -1023,6 +1082,7 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 1023 |
if self.first_layer_fan is None and hasattr(decoder_layer, 'current_layer_fan'):
|
| 1024 |
self.first_layer_fan = decoder_layer.current_layer_fan
|
| 1025 |
|
|
|
|
| 1026 |
hidden_states = self.norm(hidden_states)
|
| 1027 |
|
| 1028 |
return BaseModelOutputWithPast(
|
|
@@ -1132,6 +1192,7 @@ __all__ = [
|
|
| 1132 |
"NeoLLMPreTrainedModel",
|
| 1133 |
"NeoLLMConfig",
|
| 1134 |
"FANLayer",
|
|
|
|
| 1135 |
]
|
| 1136 |
|
| 1137 |
# Register the configuration and model for AutoClass support
|
|
|
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
NeoLLM Model with FANformer Integration in both Attention and FFN, Dropout Regularization,
|
| 4 |
+
SeeDNorm (Self-Rescaled Dynamic Normalization), and ResFormer Value Residual Learning
|
| 5 |
+
for enhanced information flow through deep layers.
|
| 6 |
|
| 7 |
Updated to include:
|
| 8 |
- Fourier Analysis Network (FAN) layer for effective periodicity modeling in attention (relational space)
|
| 9 |
+
- FAN layer in FFN for featural periodicity modeling (complementary coverage)
|
| 10 |
+
- SeeDNorm: Dynamic normalization with input-dependent scaling for better adaptability
|
| 11 |
- Dropout regularization at strategic locations
|
| 12 |
- ResFormer: Feature residual connections from first layer (applied before projections)
|
| 13 |
"""
|
|
|
|
| 35 |
is_causal_conv1d_available,
|
| 36 |
is_flash_linear_attention_available,
|
| 37 |
)
|
| 38 |
+
from configuration_neollm import NeoLLMConfig
|
| 39 |
|
| 40 |
|
| 41 |
if is_causal_conv1d_available():
|
|
|
|
| 153 |
return x_scaled
|
| 154 |
|
| 155 |
|
| 156 |
+
class SeeDNorm(nn.Module):
    """
    Self-Rescaled Dynamic Normalization (SeeDNorm).

    From "SeeDNorm: Self-Rescaled Dynamic Normalization":
        SeeDNorm(x) = [sigma(x . beta^T) * alpha + gamma] (*) x / RMS(x)

    Dynamically adjusts the scaling coefficient based on the current input,
    preserving input norm information and enabling data-dependent normalization.

    Key features:
        - gamma: static scaling factor (like RMSNorm), initialized to 1
        - beta:  self-rescaling parameter, initialized to 0
        - alpha: dynamic modulation parameter, initialized to 1
        - sigma: tanh activation constraining the dynamic rescale to [-1, 1]

    With the default initialization (beta = 0) the module reduces exactly to
    RMSNorm, so it is a drop-in replacement at the start of training.

    Args:
        dim: Hidden dimension size (size of the last axis of the input).
        eps: Small constant for numerical stability.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps

        # Learnable parameters (roles / init values in the class docstring).
        self.gamma = nn.Parameter(torch.ones(dim))   # static scaling (RMSNorm-like)
        self.beta = nn.Parameter(torch.zeros(dim))   # self-rescaling parameter
        self.alpha = nn.Parameter(torch.ones(dim))   # dynamic modulation parameter

    def _rms_norm(self, x: torch.Tensor) -> torch.Tensor:
        """Compute RMS normalization: x / RMS(x)."""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply Self-Rescaled Dynamic Normalization.

        Args:
            x: Input tensor of shape (..., dim).

        Returns:
            Normalized and dynamically scaled tensor of the same shape and dtype.
        """
        # Do the whole computation in float32: the token-wise reduction
        # x . beta^T can overflow / lose precision in fp16/bf16, and the RMS
        # statistics were already upcast — keep the dynamic scale consistent.
        x_fp32 = x.float()

        # Input-dependent rescaling: sigma(x . beta^T), one scalar per token.
        rescale_factor = torch.tanh(
            torch.sum(x_fp32 * self.beta.float(), dim=-1, keepdim=True)
        )

        # Dynamic scaling coefficient: sigma(x . beta^T) * alpha + gamma.
        dynamic_scale = rescale_factor * self.alpha.float() + self.gamma.float()

        # RMS-normalize, apply the dynamic scale, and cast back to input dtype.
        output = self._rms_norm(x_fp32) * dynamic_scale
        return output.type_as(x)

    def extra_repr(self) -> str:
        return f"dim={self.dim}, eps={self.eps}"
|
| 218 |
+
|
| 219 |
+
|
| 220 |
class NeoLLMRMSNormGated(nn.Module):
|
| 221 |
+
"""
|
| 222 |
+
Gated RMSNorm variant used in specific contexts.
|
| 223 |
+
"""
|
| 224 |
def __init__(self, hidden_size, eps=1e-6, **kwargs):
|
| 225 |
super().__init__()
|
| 226 |
self.weight = nn.Parameter(torch.ones(hidden_size))
|
|
|
|
| 274 |
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 275 |
|
| 276 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
def rotate_half(x):
|
| 278 |
"""Rotates half the hidden dims of the input."""
|
| 279 |
x1 = x[..., : x.shape[-1] // 2]
|
|
|
|
| 341 |
|
| 342 |
class NeoLLMAttention(nn.Module):
|
| 343 |
"""
|
| 344 |
+
Multi-headed attention with FANformer integration, SeeDNorm for Q/K normalization,
|
| 345 |
and ResFormer feature residual connections for enhanced information flow.
|
| 346 |
|
| 347 |
ResFormer enhancement: Applies learnable feature residual connections from the first layer
|
|
|
|
| 380 |
self.o_proj = nn.Linear(
|
| 381 |
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
| 382 |
)
|
| 383 |
+
|
| 384 |
+
# SeeDNorm for Q/K normalization (replaces RMSNorm)
|
| 385 |
+
self.q_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
|
| 386 |
+
self.k_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
|
| 387 |
|
| 388 |
# Dropout for attention output
|
| 389 |
self.dropout = nn.Dropout(config.dropout_rate)
|
|
|
|
| 421 |
)
|
| 422 |
gate = gate.reshape(*input_shape, -1)
|
| 423 |
|
| 424 |
+
# Apply SeeDNorm to Q and K
|
| 425 |
query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
|
| 426 |
key_states = self.k_norm(self.k_proj(hidden_states_fan).view(hidden_shape)).transpose(1, 2)
|
| 427 |
value_states = self.v_proj(hidden_states_fan).view(hidden_shape).transpose(1, 2)
|
|
|
|
| 617 |
|
| 618 |
class NeoLLMGatedDeltaNet(nn.Module):
|
| 619 |
"""
|
| 620 |
+
Linear attention with FANformer integration, SeeDNorm for normalization,
|
| 621 |
and ResFormer feature residual connections for enhanced information flow.
|
| 622 |
|
| 623 |
ResFormer enhancement: Applies learnable feature residual connections from the first layer
|
|
|
|
| 681 |
else FusedRMSNormGated(
|
| 682 |
self.head_v_dim,
|
| 683 |
eps=self.layer_norm_epsilon,
|
| 684 |
+
activation=fla_compatible_activation,
|
| 685 |
device=torch.cuda.current_device(),
|
| 686 |
dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(),
|
| 687 |
)
|
|
|
|
| 900 |
# MLP with FANformer integration
|
| 901 |
self.mlp = NeoLLMMLP(config)
|
| 902 |
|
| 903 |
+
# SeeDNorm for input and post-attention normalization (replaces RMSNorm)
|
| 904 |
+
self.input_layernorm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 905 |
+
self.post_attention_layernorm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 906 |
|
| 907 |
# LNS (LayerNorm Scaling) - applies 1/√ℓ scaling
|
| 908 |
self.lns_attn = LNS(layer_idx)
|
|
|
|
| 925 |
) -> torch.FloatTensor:
|
| 926 |
residual = hidden_states
|
| 927 |
|
| 928 |
+
# Apply SeeDNorm normalization
|
| 929 |
hidden_states = self.input_layernorm(hidden_states)
|
| 930 |
|
| 931 |
# Apply LNS scaling after normalization
|
|
|
|
| 1004 |
elif isinstance(module, FANLayer):
|
| 1005 |
# FANLayer initialization is handled within the class
|
| 1006 |
pass
|
| 1007 |
+
elif isinstance(module, SeeDNorm):
|
| 1008 |
+
# SeeDNorm initialization:
|
| 1009 |
+
# gamma (γ) initialized to 1 (default in Parameter definition)
|
| 1010 |
+
# beta (β) initialized to 0 (default in Parameter definition)
|
| 1011 |
+
# alpha (α) initialized to 1 (default in Parameter definition)
|
| 1012 |
+
pass
|
| 1013 |
|
| 1014 |
|
| 1015 |
class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
|
|
| 1021 |
self.layers = nn.ModuleList(
|
| 1022 |
[NeoLLMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 1023 |
)
|
| 1024 |
+
# SeeDNorm for final output normalization (replaces RMSNorm)
|
| 1025 |
+
self.norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 1026 |
self.rotary_emb = NeoLLMRotaryEmbedding(config=config)
|
| 1027 |
self.gradient_checkpointing = False
|
| 1028 |
|
|
|
|
| 1082 |
if self.first_layer_fan is None and hasattr(decoder_layer, 'current_layer_fan'):
|
| 1083 |
self.first_layer_fan = decoder_layer.current_layer_fan
|
| 1084 |
|
| 1085 |
+
# Apply SeeDNorm for final normalization
|
| 1086 |
hidden_states = self.norm(hidden_states)
|
| 1087 |
|
| 1088 |
return BaseModelOutputWithPast(
|
|
|
|
| 1192 |
"NeoLLMPreTrainedModel",
|
| 1193 |
"NeoLLMConfig",
|
| 1194 |
"FANLayer",
|
| 1195 |
+
"SeeDNorm",
|
| 1196 |
]
|
| 1197 |
|
| 1198 |
# Register the configuration and model for AutoClass support
|