KitsuVp committed on
Commit
b9eef49
·
verified ·
1 Parent(s): 4e7638f

Upload 2 files

Files changed (2)
  1. configuration_unified.py +129 -0
  2. modeling_unified.py +824 -0
configuration_unified.py ADDED
@@ -0,0 +1,129 @@
1
+ # ====================================================================
2
+ # configuration_unified.py
3
+ # ====================================================================
4
+
5
+ """
6
+ Configuration class for Unified Language Model
7
+ HuggingFace Transformers compatible configuration with AutoClass support
8
+ """
9
+
10
+ from transformers import PretrainedConfig
11
+ from typing import Optional
12
+
13
+ class UnifiedModelConfig(PretrainedConfig):
14
+ """
15
+ Configuration class for UnifiedModel.
16
+ Inherits from PretrainedConfig for full HuggingFace compatibility.
17
+ """
18
+ model_type = "unified_model"
19
+
20
+ def __init__(
21
+ self,
22
+ vocab_size: int = None,
23
+ hidden_size: int = 256,
24
+ intermediate_size: int = 1024,
25
+ num_hidden_layers: int = 6,
26
+ num_attention_heads: int = 8,
27
+ num_key_value_heads: int = 4,
28
+ max_position_embeddings: int = 2048,
29
+ rms_norm_eps: float = 1e-6,
30
+ rope_theta: float = 10000.0,
31
+
32
+ attention_dropout: float = 0.1,
33
+ mlp_dropout: float = 0.1,
34
+ embedding_dropout: float = 0.1,
35
+
36
+ xielu_alpha_p_init: float = 0.8,
37
+ xielu_alpha_n_init: float = 0.8,
38
+ xielu_beta: float = 0.5,
39
+
40
+ tie_word_embeddings: bool = True, # HuggingFace standard parameter name
41
+
42
+ # LaX configuration (Linear only)
43
+ lax_enabled: bool = True,
44
+ lax_gate_type: str = "linear", # Only "linear" supported now
45
+
46
+ # Canon Layers configuration (A+C only)
47
+ canon_enabled: bool = True,
48
+ canon_kernel_size: int = 4,
49
+ canon_a_enabled: bool = True, # Before attention
50
+ canon_c_enabled: bool = True, # Before MLP
51
+ # Canon B and D are permanently disabled
52
+
53
+ # FANFormer configuration
54
+ fanformer_p: float = 0.15,
55
+
56
+ # HuggingFace standard parameters
57
+ pad_token_id: int = None,
58
+ bos_token_id: int = None,
59
+ eos_token_id: int = None,
60
+
61
+ **kwargs
62
+ ):
63
+ super().__init__(
64
+ pad_token_id=pad_token_id,
65
+ bos_token_id=bos_token_id,
66
+ eos_token_id=eos_token_id,
67
+ tie_word_embeddings=tie_word_embeddings,
68
+ **kwargs
69
+ )
70
+
71
+ self.vocab_size = vocab_size
72
+ self.hidden_size = hidden_size
73
+ self.intermediate_size = intermediate_size
74
+ self.num_hidden_layers = num_hidden_layers
75
+ self.num_attention_heads = num_attention_heads
76
+ self.num_key_value_heads = num_key_value_heads
77
+ self.max_position_embeddings = max_position_embeddings
78
+ self.rms_norm_eps = rms_norm_eps
79
+ self.rope_theta = rope_theta
80
+
81
+ self.attention_dropout = attention_dropout
82
+ self.mlp_dropout = mlp_dropout
83
+ self.embedding_dropout = embedding_dropout
84
+
85
+ self.xielu_alpha_p_init = xielu_alpha_p_init
86
+ self.xielu_alpha_n_init = xielu_alpha_n_init
87
+ self.xielu_beta = xielu_beta
88
+ self.tie_word_embeddings = tie_word_embeddings
89
+
90
+ # LaX configuration
91
+ self.lax_enabled = lax_enabled
92
+ self.lax_gate_type = lax_gate_type
93
+
94
+ # Canon Layers configuration
95
+ self.canon_enabled = canon_enabled
96
+ self.canon_kernel_size = canon_kernel_size
97
+ self.canon_a_enabled = canon_a_enabled
98
+ self.canon_c_enabled = canon_c_enabled
99
+
100
+ # FANFormer
101
+ self.fanformer_p = fanformer_p
102
+
103
+ # ✅ FIXED: Force complete auto_map in config.json
104
+ self.auto_map = {
105
+ "AutoConfig": "configuration_unified.UnifiedModelConfig",
106
+ "AutoModel": "modeling_unified.UnifiedModel",
107
+ "AutoModelForCausalLM": "modeling_unified.UnifiedModel"
108
+ }
109
+
110
+ def to_diff_dict(self):
111
+ """
112
+ ✅ FIXED: Forces serialization of tie_word_embeddings into config.json
113
+
114
+ Overrides to_diff_dict() to make sure that tie_word_embeddings
115
+ always appears in config.json, avoiding loading problems
116
+ where HuggingFace does not recognize the weight tying.
117
+
118
+ Returns:
119
+ Dict: Configuration dict with tie_word_embeddings forced in
120
+ """
121
+ # Get the normal serialization (differences only)
122
+ output = super().to_diff_dict()
123
+
124
+ # ✅ FORCE the inclusion of tie_word_embeddings
125
+ # This ensures it appears in config.json regardless of whether HF
126
+ # considers it a "default" value or not
127
+ output["tie_word_embeddings"] = self.tie_word_embeddings
128
+
129
+ return output
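For reference, a minimal usage sketch (not part of the commit): it assumes transformers is installed and the file is importable from the working directory, and it uses placeholder token IDs and vocab_size, since the real values must come from the tokenizer.

    from configuration_unified import UnifiedModelConfig

    # Hypothetical values, for illustration only
    config = UnifiedModelConfig(vocab_size=32000, pad_token_id=0, eos_token_id=1)
    config.save_pretrained("./unified-model")  # writes config.json, including auto_map
    print(config.to_diff_dict()["tie_word_embeddings"])  # forced into the serialized dict

save_pretrained() persists the auto_map set in __init__, which is what lets AutoConfig/AutoModel resolve these remote-code classes later.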
modeling_unified.py ADDED
@@ -0,0 +1,824 @@
1
+ # ====================================================================
2
+ # modeling_unified.py
3
+ # ====================================================================
4
+
5
+ """
6
+ Unified Language Model with GPAS + LNS Integration + xIELU Activation + CoLA (Linear Only) + LaX + Weight Tying + Canon Layers (A+C Only)
7
+ MIGRATED TO HUGGINGFACE TRANSFORMERS - FINAL VERSION WITH ALL FIXES + CORRECTED LaX IMPLEMENTATION
8
+ UPDATED: Standard Transformer with advanced variance control, parameter efficiency, Canon horizontal information flow, and WORKING LaX Inter-Layer
9
+ Combines advanced Transformer architecture with CORRECTED variance control mechanisms,
10
+ advanced variance control via GPAS and LNS, xIELU activation function, FIXED LaX integration, and Canon Layers (A+C only)
11
+ Based on LLaMA 3 architecture with 30M parameters
12
+
13
+ MIGRATION TO HUGGINGFACE - FINAL FIXED VERSION + LaX CORRECTION:
14
+ ==============================================================
15
+
16
+ 1. **HUGGINGFACE INTEGRATION**: Migrated from PyTorch Lightning to Transformers v4.53.3
17
+ 2. **UPDATED API**: processing_class instead of tokenizer (deprecated)
18
+ 3. **UPDATED COMPUTE_LOSS**: Method updated with the num_items_in_batch parameter
19
+ 4. **FIXED LOGGING**: Corrected self.log() syntax per the official HF documentation
20
+ 5. **RESTORED PAD HANDLING**: pad_token_id → -100 conversion for CrossEntropyLoss (from original code)
21
+ 6. **NATIVE TORCH COMPILE**: Moved to TrainingArguments (torch_compile=True)
22
+ 7. **FIXED WEIGHT TYING**: Corrected _tied_weights_keys as class attribute (HF standard)
23
+ 8. **VALIDATION DIAGNOSTIC**: Added simple method to diagnose validation loss issues
24
+ 9. **CUSTOM CONFIGURATION**: Custom PretrainedConfig with all parameters
25
+ 10. **PRETRAINED MODEL**: Inherits from PreTrainedModel for full compatibility
26
+ 11. **MAINTAINED OPTIMIZERS**: Hybrid Muon + AdamW preserved
27
+ 12. **MAINTAINED PRECISION**: bf16-true preserved
28
+ 13. **MAINTAINED TRAINING**: Custom Trainer with all metrics and logging
29
+ 14. **MAINTAINED ARCHITECTURE**: All custom architecture preserved
30
+ 15. **AUTO TOKENIZER**: Full integration with dynamic AutoTokenizer
31
+ 16. **AUTOCLASS SUPPORT**: Full registration for AutoConfig and AutoModel
32
+ 17. **✅ FIXED LaX**: Correct Inter-Layer implementation with a working Linear Gate
33
+ """
34
+
35
+ import torch
36
+ import torch.nn as nn
37
+ import torch.nn.functional as F
38
+ from torch.utils.checkpoint import checkpoint
39
+ from transformers import (
40
+ AutoTokenizer,
41
+ AutoConfig,
42
+ AutoModel,
43
+ AutoModelForCausalLM,
44
+ PreTrainedModel,
45
+ )
46
+ import math
47
+ import os
48
+ from typing import Optional, Tuple, Dict, Any, cast, List
49
+ from flash_attn import flash_attn_func
50
+ import numpy as np
51
+
52
+ # ✅ ABSOLUTE IMPORT - No relative imports for Hub compatibility
53
+ from configuration_unified import UnifiedModelConfig
54
+
55
+ # Fix tokenizer parallelism warnings
56
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
57
+ torch.set_float32_matmul_precision('high')
58
+
59
+ def init_cola_components(A: nn.Linear, B: nn.Linear):
60
+ nn.init.kaiming_normal_(A.weight, mode='fan_in', nonlinearity='relu')
61
+ nn.init.xavier_normal_(B.weight, gain=0.8)
62
+ if B.bias is not None:
63
+ nn.init.zeros_(B.bias)
64
+
65
+ def init_embedding(embedding: nn.Embedding):
66
+ nn.init.normal_(embedding.weight, mean=0.0, std=0.02)
67
+
68
+ class CanonLayer(nn.Module):
69
+ def __init__(self, hidden_dim: int, kernel_size: int = 4):
70
+ """
71
+ Canon layer using a 1D causal convolution with residual connection.
72
+ """
73
+ super().__init__()
74
+ self.hidden_dim = hidden_dim
75
+ self.kernel_size = kernel_size
76
+
77
+ # Use causal convolution with explicit initialization
78
+ self.causal_conv1d = nn.Conv1d(
79
+ in_channels=hidden_dim,
80
+ out_channels=hidden_dim,
81
+ kernel_size=kernel_size,
82
+ groups=hidden_dim, # Depthwise convolution
83
+ padding=0, # No automatic padding
84
+ bias=True
85
+ )
86
+
87
+ # Initialize weights more conservatively (as per paper)
88
+ nn.init.zeros_(self.causal_conv1d.weight)
89
+ nn.init.zeros_(self.causal_conv1d.bias)
90
+
91
+ def forward(self, h: torch.Tensor) -> torch.Tensor:
92
+ """
93
+ Applies the Canon layer transformation with causal masking.
94
+ """
95
+ # Conv1d expects input shape (batch_size, channels, sequence_length)
96
+ h_permuted = h.permute(0, 2, 1) # (batch, hidden_dim, seq_len)
97
+
98
+ # Add padding of (kernel_size - 1) only to the left side
99
+ padding = self.kernel_size - 1
100
+ h_padded = F.pad(h_permuted, (padding, 0))
101
+
102
+ # Apply causal convolution
103
+ conv_out = self.causal_conv1d(h_padded)
104
+
105
+ # Permute back to the original shape
106
+ conv_out_permuted = conv_out.permute(0, 2, 1)
107
+
108
+ # Add the residual connection
109
+ output = h + conv_out_permuted
110
+
111
+ return output
112
+
113
+ class CoLA_Linear(nn.Module):
114
+ def __init__(self, in_features: int, out_features: int, rank: Optional[int] = None, activation=F.gelu, bias: bool = True):
115
+ super().__init__()
116
+ if rank is None:
117
+ rank = in_features // 4
118
+ self.rank = rank
119
+ self.activation = activation
120
+
121
+ self.A = nn.Linear(in_features, rank, bias=False)
122
+ self.B = nn.Linear(rank, out_features, bias=bias)
123
+
124
+ init_cola_components(self.A, self.B)
125
+
126
+ def forward(self, x: torch.Tensor, prev_latent: Optional[torch.Tensor] = None, lax_beta: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
127
+ """
128
+ Forward pass with optional LaX Inter-Layer integration.
129
+
130
+ Args:
131
+ x: Input tensor
132
+ prev_latent: Previous latent from same module type in previous layer (for LaX)
133
+ lax_beta: Linear gate parameter (scalar) for LaX
134
+
135
+ Returns:
136
+ Tuple of (output, current_latent) where current_latent can be used for next layer
137
+ """
138
+ # Standard CoLA forward: A -> activation
139
+ latent = self.A(x)
140
+ latent_activated = self.activation(latent)
141
+
142
+ # Apply LaX Inter-Layer if previous latent exists
143
+ if prev_latent is not None and lax_beta is not None and prev_latent.shape == latent_activated.shape:
144
+ # Linear Gate: h_i = h_i + β * h_{i-1}
145
+ latent_activated = latent_activated + lax_beta * prev_latent
146
+
147
+ # B projection
148
+ output = self.B(latent_activated)
149
+
150
+ return output, latent_activated
151
+
152
+ class LayerNormScaling(nn.Module):
153
+ def __init__(self, layer_depth: int):
154
+ super().__init__()
155
+
156
+ if layer_depth < 1:
157
+ raise ValueError(f"layer_depth must be >= 1, got {layer_depth}")
158
+
159
+ self.layer_depth = layer_depth
160
+ self.scaling_factor = 1.0 / math.sqrt(float(layer_depth))
161
+
162
+ def forward(self, normalized_input: torch.Tensor) -> torch.Tensor:
163
+ return normalized_input * self.scaling_factor
164
+
165
+ class GPAS(nn.Module):
166
+ def __init__(self, d_model: int):
167
+ super().__init__()
168
+
169
+ self.d_model = d_model
170
+ self.alpha = nn.Parameter(torch.zeros(1))
171
+
172
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
173
+ x_detached = x.detach()
174
+ scaled_component = F.silu(self.alpha) * x_detached
175
+ x_scaled = x - scaled_component
176
+
177
+ return x_scaled
178
+
179
+ class RotaryEmbedding(nn.Module):
180
+ def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 10000):
181
+ super().__init__()
182
+ self.dim = dim
183
+ self.max_position_embeddings = max_position_embeddings
184
+ self.base = base
185
+
186
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
187
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
188
+
189
+ def forward(self, x, seq_len=None):
190
+ if seq_len is None:
191
+ seq_len = x.shape[-2]
192
+ t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
193
+ freqs = torch.outer(t, self.inv_freq)
194
+ emb = torch.cat((freqs, freqs), dim=-1)
195
+ return emb.cos().to(x.dtype), emb.sin().to(x.dtype)
196
+
197
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
198
+ def rotate_half(x):
199
+ x1 = x[..., : x.shape[-1] // 2]
200
+ x2 = x[..., x.shape[-1] // 2 :]
201
+ return torch.cat((-x2, x1), dim=-1)
202
+
203
+ q_embed = (q * cos) + (rotate_half(q) * sin)
204
+ k_embed = (k * cos) + (rotate_half(k) * sin)
205
+ return q_embed, k_embed
206
+
207
+ class XIELU(nn.Module):
208
+ def __init__(self, alpha_p_init: float = 0.8, alpha_n_init: float = 0.8, beta: float = 0.5):
209
+ super().__init__()
210
+
211
+ self.beta = beta
212
+
213
+ self.alpha_p = nn.Parameter(torch.log(torch.exp(torch.tensor(alpha_p_init)) - 1))
214
+ self.alpha_n = nn.Parameter(torch.log(torch.exp(torch.tensor(alpha_n_init - self.beta)) - 1))
215
+
216
+ self.register_buffer('eps', torch.tensor(-1e-6))
217
+
218
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
219
+ alpha_p = F.softplus(self.alpha_p)
220
+ alpha_n = self.beta + F.softplus(self.alpha_n)
221
+
222
+ return torch.where(
223
+ x > 0,
224
+ alpha_p * x * x + self.beta * x,
225
+ alpha_n * torch.expm1(torch.clamp(x, max=self.eps)) - alpha_n * x + self.beta * x
226
+ )
227
+
228
+ class StandardMLP(nn.Module):
229
+ def __init__(self, hidden_size: int, intermediate_size: int, dropout: float = 0.0, config=None, layer_idx: int = 0):
230
+ super().__init__()
231
+
232
+ self.hidden_size = hidden_size
233
+ self.intermediate_size = intermediate_size
234
+ self.config = config
235
+ self.layer_idx = layer_idx
236
+
237
+ self.up_proj = CoLA_Linear(hidden_size, intermediate_size, bias=False)
238
+ self.down_proj = CoLA_Linear(intermediate_size, hidden_size, bias=False)
239
+
240
+ if config is not None:
241
+ self.activation = XIELU(
242
+ alpha_p_init=config.xielu_alpha_p_init,
243
+ alpha_n_init=config.xielu_alpha_n_init,
244
+ beta=config.xielu_beta
245
+ )
246
+ else:
247
+ self.activation = XIELU(alpha_p_init=0.8, alpha_n_init=0.8, beta=0.5)
248
+
249
+ self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
250
+
251
+ # LaX Linear Gate parameters (β scalars)
252
+ if config is not None and config.lax_enabled:
253
+ self.lax_beta_up = nn.Parameter(torch.full((1,), 0.2)) # 0.0 → 0.2
254
+ self.lax_beta_down = nn.Parameter(torch.full((1,), 0.2)) # 0.0 → 0.2
255
+ else:
256
+ self.lax_beta_up = None
257
+ self.lax_beta_down = None
258
+
259
+ def forward(self, x: torch.Tensor, lax_buffer: Optional[Dict] = None) -> torch.Tensor:
260
+ # LaX: Get previous latents from buffer
261
+ prev_up_latent = None
262
+ prev_down_latent = None
263
+ if lax_buffer is not None and self.lax_beta_up is not None:
264
+ prev_up_latent = lax_buffer.get(('mlp_up', self.layer_idx - 1))
265
+ prev_down_latent = lax_buffer.get(('mlp_down', self.layer_idx - 1))
266
+
267
+ # Up projection with LaX
268
+ intermediate, up_latent = self.up_proj(x, prev_up_latent, self.lax_beta_up)
269
+
270
+ # Store current up latent for next layer
271
+ if lax_buffer is not None:
272
+ lax_buffer[('mlp_up', self.layer_idx)] = up_latent.clone()
273
+
274
+ # Activation and dropout
275
+ activated = self.activation(intermediate)
276
+ activated = self.dropout(activated)
277
+
278
+ # Down projection with LaX
279
+ output, down_latent = self.down_proj(activated, prev_down_latent, self.lax_beta_down)
280
+
281
+ # Store current down latent for next layer
282
+ if lax_buffer is not None:
283
+ lax_buffer[('mlp_down', self.layer_idx)] = down_latent.clone()
284
+
285
+ return output
286
+
287
+ class GroupedQueryAttention(nn.Module):
288
+ def __init__(self, config, layer_idx: int = 0):
289
+ super().__init__()
290
+ self.config = config
291
+ self.layer_idx = layer_idx
292
+ self.hidden_size = config.hidden_size
293
+ self.num_heads = config.num_attention_heads
294
+ self.num_key_value_heads = config.num_key_value_heads
295
+ self.head_dim = self.hidden_size // self.num_heads
296
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
297
+
298
+ # FANFormer components
299
+ self.fanformer_p = getattr(config, 'fanformer_p', 0.15)
300
+
301
+ self.d_periodic = int(self.hidden_size * self.fanformer_p)
302
+ self.d_standard = self.hidden_size - 2 * self.d_periodic
303
+
304
+ assert self.d_standard > 0, \
305
+ f"fanformer_p={self.fanformer_p} is too high. d_standard={self.d_standard} must be > 0"
306
+
307
+ self.fan_w_p = CoLA_Linear(self.hidden_size, self.d_periodic, bias=False)
308
+ self.fan_w_p_bar = CoLA_Linear(self.hidden_size, self.d_standard, bias=False)
309
+
310
+ self.q_proj = CoLA_Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
311
+ self.k_proj = CoLA_Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
312
+ self.v_proj = CoLA_Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
313
+ self.o_proj = CoLA_Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
314
+
315
+ self.q_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
316
+ self.k_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
317
+ self.v_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
318
+
319
+ self.rotary_emb = RotaryEmbedding(
320
+ self.head_dim,
321
+ max_position_embeddings=config.max_position_embeddings,
322
+ base=config.rope_theta
323
+ )
324
+
325
+ # LaX Linear Gate parameters (β scalars) - no o_proj, per the design plan
326
+ if config.lax_enabled:
327
+ self.lax_beta_q = nn.Parameter(torch.full((1,), 0.2)) # 0.0 → 0.2
328
+ self.lax_beta_k = nn.Parameter(torch.full((1,), 0.2)) # 0.0 → 0.2
329
+ self.lax_beta_v = nn.Parameter(torch.full((1,), 0.2)) # 0.0 → 0.2
330
+ else:
331
+ self.lax_beta_q = None
332
+ self.lax_beta_k = None
333
+ self.lax_beta_v = None
334
+
335
+ def _fan_layer_prime(self, x: torch.Tensor) -> torch.Tensor:
336
+ periodic_proj, _ = self.fan_w_p(x)
337
+ standard_proj, _ = self.fan_w_p_bar(x)
338
+
339
+ cos_component = torch.cos(periodic_proj)
340
+ sin_component = torch.sin(periodic_proj)
341
+
342
+ x_f = torch.cat([cos_component, sin_component, standard_proj], dim=-1)
343
+
344
+ return x_f
345
+
346
+ def _compute_flash_attention(
347
+ self,
348
+ query_states: torch.Tensor,
349
+ key_states: torch.Tensor,
350
+ value_states: torch.Tensor,
351
+ seq_len: int,
352
+ position_ids: Optional[torch.Tensor] = None
353
+ ) -> torch.Tensor:
354
+ batch_size = query_states.shape[0]
355
+
356
+ q_rope = query_states.transpose(1, 2)
357
+ k_rope = key_states.transpose(1, 2)
358
+
359
+ cos, sin = self.rotary_emb(value_states, seq_len=seq_len)
360
+ q_rope, k_rope = apply_rotary_pos_emb(q_rope, k_rope, cos, sin, position_ids)
361
+
362
+ query_states = q_rope.transpose(1, 2)
363
+ key_states = k_rope.transpose(1, 2)
364
+
365
+ from flash_attn import flash_attn_func
366
+
367
+ attn_output = flash_attn_func(
368
+ query_states,
369
+ key_states,
370
+ value_states,
371
+ dropout_p=self.config.attention_dropout if self.training else 0.0,
372
+ causal=True,
373
+ )
374
+
375
+ return attn_output
376
+
377
+ def forward(self, hidden_states, position_ids=None, attention_mask=None, lax_buffer: Optional[Dict] = None):
378
+ batch_size, seq_len, _ = hidden_states.shape
379
+
380
+ enhanced_input = self._fan_layer_prime(hidden_states)
381
+
382
+ # LaX: Get previous latents from buffer
383
+ prev_q_latent = None
384
+ prev_k_latent = None
385
+ prev_v_latent = None
386
+ if lax_buffer is not None and self.lax_beta_q is not None:
387
+ prev_q_latent = lax_buffer.get(('attn_q', self.layer_idx - 1))
388
+ prev_k_latent = lax_buffer.get(('attn_k', self.layer_idx - 1))
389
+ prev_v_latent = lax_buffer.get(('attn_v', self.layer_idx - 1))
390
+
391
+ # Q/K/V projections with LaX
392
+ query_states, q_latent = self.q_proj(enhanced_input, prev_q_latent, self.lax_beta_q)
393
+ key_states, k_latent = self.k_proj(enhanced_input, prev_k_latent, self.lax_beta_k)
394
+ value_states, v_latent = self.v_proj(enhanced_input, prev_v_latent, self.lax_beta_v)
395
+
396
+ # Store current latents for next layer
397
+ if lax_buffer is not None:
398
+ lax_buffer[('attn_q', self.layer_idx)] = q_latent.clone()
399
+ lax_buffer[('attn_k', self.layer_idx)] = k_latent.clone()
400
+ lax_buffer[('attn_v', self.layer_idx)] = v_latent.clone()
401
+
402
+ query_states = query_states.view(batch_size, seq_len, self.num_heads, self.head_dim)
403
+ key_states = key_states.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim)
404
+ value_states = value_states.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim)
405
+
406
+ q_flat = query_states.reshape(-1, self.head_dim)
407
+ k_flat = key_states.reshape(-1, self.head_dim)
408
+ v_flat = value_states.reshape(-1, self.head_dim)
409
+
410
+ q_normalized = self.q_norm(q_flat)
411
+ k_normalized = self.k_norm(k_flat)
412
+ v_normalized = self.v_norm(v_flat)
413
+
414
+ query_states = q_normalized.view(batch_size, seq_len, self.num_heads, self.head_dim)
415
+ key_states = k_normalized.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim)
416
+ value_states = v_normalized.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim)
417
+
418
+ attn_output = self._compute_flash_attention(
419
+ query_states=query_states,
420
+ key_states=key_states,
421
+ value_states=value_states,
422
+ seq_len=seq_len,
423
+ position_ids=position_ids
424
+ )
425
+
426
+ attn_output = attn_output.reshape(batch_size, seq_len, self.hidden_size)
427
+
428
+ # O projection WITHOUT LaX (per the design plan)
429
+ output, _ = self.o_proj(attn_output)
430
+ return output
431
+
432
+ class DecoderLayer(nn.Module):
433
+ def __init__(self, config, layer_idx: int):
434
+ super().__init__()
435
+ self.config = config
436
+ self.layer_idx = layer_idx
437
+
438
+ if layer_idx < 0:
439
+ raise ValueError(f"layer_idx must be >= 0, got {layer_idx}")
440
+
441
+ self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
442
+ self.self_attn = GroupedQueryAttention(config, layer_idx)
443
+ self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
444
+
445
+ self.mlp = StandardMLP(
446
+ config.hidden_size,
447
+ config.intermediate_size,
448
+ config.mlp_dropout,
449
+ config,
450
+ layer_idx
451
+ )
452
+
453
+ self.dropout_output = nn.Dropout(0.01)
454
+
455
+ self.lns_attention = LayerNormScaling(layer_depth=layer_idx + 1)
456
+ self.lns_mlp = LayerNormScaling(layer_depth=layer_idx + 1)
457
+
458
+ self.gpas_attention = GPAS(config.hidden_size)
459
+ self.gpas_mlp = GPAS(config.hidden_size)
460
+
461
+ # Canon layers (A+C only)
462
+ # Canon-A: Before attention block
463
+ if config.canon_enabled and config.canon_a_enabled:
464
+ self.canon_a = CanonLayer(config.hidden_size, config.canon_kernel_size)
465
+ else:
466
+ self.canon_a = None
467
+
468
+ # Canon-C: Before MLP block
469
+ if config.canon_enabled and config.canon_c_enabled:
470
+ self.canon_c = CanonLayer(config.hidden_size, config.canon_kernel_size)
471
+ else:
472
+ self.canon_c = None
473
+
474
+ def forward(self, hidden_states: torch.Tensor, position_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, lax_buffer: Optional[Dict] = None) -> torch.Tensor:
475
+ residual = hidden_states
476
+
477
+ # Apply Canon-A before attention
478
+ if self.canon_a is not None:
479
+ hidden_states = self.canon_a(hidden_states)
480
+
481
+ attention_input = self.input_layernorm(hidden_states)
482
+ attention_input = self.lns_attention(attention_input)
483
+ attention_output = self.self_attn(attention_input, position_ids, attention_mask, lax_buffer)
484
+ hidden_states = residual + attention_output
485
+ hidden_states = self.gpas_attention(hidden_states)
486
+ hidden_states = self.dropout_output(hidden_states)
487
+
488
+ residual = hidden_states
489
+
490
+ # Apply Canon-C before MLP
491
+ if self.canon_c is not None:
492
+ hidden_states = self.canon_c(hidden_states)
493
+
494
+ mlp_input = self.post_attention_layernorm(hidden_states)
495
+ mlp_input = self.lns_mlp(mlp_input)
496
+ mlp_output = self.mlp(mlp_input, lax_buffer)
497
+ hidden_states = residual + mlp_output
498
+ hidden_states = self.gpas_mlp(hidden_states)
499
+ hidden_states = self.dropout_output(hidden_states)
500
+
501
+ return hidden_states
502
+
503
+ class UnifiedModel(PreTrainedModel):
504
+ """
505
+ UnifiedModel that inherits from PreTrainedModel for full HuggingFace compatibility.
506
+ With AutoClass support for seamless Hub integration.
507
+ """
508
+ config_class = UnifiedModelConfig
509
+
510
+ # ✅ FIXED: _tied_weights_keys as class attribute (HuggingFace standard)
511
+ _tied_weights_keys = ["lm_head.weight"]
512
+
513
+ def __init__(self, config: UnifiedModelConfig):
514
+ super().__init__(config)
515
+ self.config = config
516
+
517
+ if config.vocab_size is None:
518
+ raise ValueError("config.vocab_size must be set from tokenizer before model initialization")
519
+
520
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
521
+ self.embedding_dropout = nn.Dropout(config.embedding_dropout)
522
+ self.output_dropout = nn.Dropout(0.05)
523
+
524
+ # Create lm_head for output projections
525
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
526
+
527
+ self.layers = nn.ModuleList()
528
+ for i in range(config.num_hidden_layers):
529
+ self.layers.append(DecoderLayer(config, i))
530
+
531
+ self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
532
+
533
+ # Initialize weights
534
+ self.post_init()
535
+
536
+ self._print_configuration()
537
+
538
+ def tie_weights(self):
539
+ """
540
+ ✅ FIXED: Simplified tie_weights method following HuggingFace standard.
541
+ Tie the word embeddings and the output layer.
542
+ This is called automatically if config.tie_word_embeddings is True.
543
+ """
544
+ if self.config.tie_word_embeddings:
545
+ print("🔗 Applying weight tying: lm_head.weight = embed_tokens.weight")
546
+ self.lm_head.weight = self.embed_tokens.weight
547
+ print("✅ Weight tying successful: Parameters are properly shared")
548
+
549
+ def _init_weights(self, module):
550
+ """Initialize weights following the custom initialization scheme."""
551
+ if isinstance(module, nn.Linear):
552
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
553
+ if module.bias is not None:
554
+ nn.init.zeros_(module.bias)
555
+ elif isinstance(module, nn.Embedding):
556
+ nn.init.trunc_normal_(module.weight, mean=0.0, std=0.02, a=-0.04, b=0.04)
557
+ elif isinstance(module, CoLA_Linear):
558
+ pass # CoLA_Linear has its own initialization
559
+
560
+ def _print_configuration(self):
561
+ # Naive count of all registered parameters
562
+ total_params_naive = sum(p.numel() for p in self.parameters())
563
+
564
+ # Smarter count that accounts for weight tying
565
+ total_params_actual = total_params_naive
566
+ vocab_params = self.config.vocab_size * self.config.hidden_size
567
+ tied_savings = 0
568
+
569
+ # ✅ FIX: Detect real weight tying and adjust the count for it
570
+ if self.config.tie_word_embeddings:
571
+ # Check whether the tensors are actually tied in memory
572
+ embed_weight = self.embed_tokens.weight
573
+ lm_head_weight = self.lm_head.weight
574
+
575
+ if embed_weight is lm_head_weight:
576
+ # The tensors are identical - subtract the duplication
577
+ tied_savings = vocab_params
578
+ total_params_actual = total_params_naive - tied_savings
579
+ else:
580
+ # Weight tying configured but not applied yet
581
+ tied_savings = 0
582
+
583
+ # Existing optimization calculations
584
+ total_linear_params = 0
585
+ total_cola_params = 0
586
+ canon_params = 0
587
+ lax_params = 0
588
+
589
+ for name, module in self.named_modules():
590
+ if isinstance(module, CoLA_Linear):
591
+ in_features = module.A.in_features
592
+ out_features = module.B.out_features
593
+ rank = module.rank
594
+
595
+ standard_params = in_features * out_features
596
+ cola_params = (in_features * rank) + (rank * out_features)
597
+
598
+ total_linear_params += standard_params
599
+ total_cola_params += cola_params
600
+ elif isinstance(module, CanonLayer):
601
+ # Canon layer parameters: depthwise conv1d + bias
602
+ canon_layer_params = module.hidden_dim * module.kernel_size + module.hidden_dim
603
+ canon_params += canon_layer_params
604
+ elif hasattr(module, 'lax_beta_q') and module.lax_beta_q is not None:
605
+ # Count LaX β parameters
606
+ lax_params += 3 # q, k, v
607
+ elif hasattr(module, 'lax_beta_up') and module.lax_beta_up is not None:
608
+ # Count LaX β parameters
609
+ lax_params += 2 # up, down
610
+
611
+ cola_reduction = ((total_linear_params - total_cola_params) / total_linear_params) * 100 if total_linear_params > 0 else 0
612
+ canon_overhead = (canon_params / total_params_actual) * 100 if total_params_actual > 0 else 0
613
+ lax_overhead = (lax_params / total_params_actual) * 100 if total_params_actual > 0 else 0
614
+
615
+ print(f"\n📊 UNIFIED Model + GPAS + LNS + xIELU + CoLA (Linear Only) + LaX + Canon (A+C) + Weight Tying:")
616
+
617
+ # ✅ IMPROVED: Show real vs. naive count for transparency
618
+ if self.config.tie_word_embeddings and tied_savings > 0:
619
+ print(f"🎯 Total Parameters: {total_params_actual/1e6:.2f}M (effective)")
620
+ print(f"📊 Parameter Breakdown:")
621
+ print(f" • Naive count: {total_params_naive/1e6:.2f}M (all registered params)")
622
+ print(f" • Actual count: {total_params_actual/1e6:.2f}M (after weight tying)")
623
+ print(f" • Weight tying savings: {tied_savings/1e6:.2f}M ({tied_savings/total_params_naive*100:.1f}%)")
624
+ else:
625
+ print(f"🎯 Total Parameters: {total_params_actual/1e6:.2f}M")
626
+
627
+ print(f"📚 DYNAMIC Vocabulary Size: {self.config.vocab_size} (from tokenizer)")
628
+ print(f"🔗 ✅ PROPER Weight Tying: {'ENABLED' if self.config.tie_word_embeddings else 'DISABLED'}")
629
+
630
+ # ✅ FIX: Report the real weight-tying status
631
+ if self.config.tie_word_embeddings:
632
+ if tied_savings > 0:
633
+ print(f"💾 Weight Tying Status: ✅ ACTIVE (tensors are shared in memory)")
634
+ else:
635
+ print(f"💾 Weight Tying Status: ⏳ CONFIGURED (will be applied during post_init)")
636
+
637
+ print(f"🚀 ACTIVATION: xIELU (αp_init={self.config.xielu_alpha_p_init}, αn_init={self.config.xielu_alpha_n_init}, β={self.config.xielu_beta})")
638
+ print(f"🔄 UPGRADE: SwiGLU → StandardMLP + xIELU (better efficiency & adaptability)")
639
+ print(f"🗜️ CoLA Integration: {cola_reduction:.1f}% parameter reduction in internal projections")
640
+ print(f"🔀 LaX Enabled: {'YES' if self.config.lax_enabled else 'NO'} ✅ WORKING Inter-Layer (Linear Gate)")
641
+ if self.config.lax_enabled:
642
+ print(f" • LaX Method: Inter-Layer with Linear Gate (β scalars)")
643
+ print(f" • LaX Applied to: q_proj, k_proj, v_proj, up_proj, down_proj (NOT o_proj)")
644
+ print(f" • LaX Parameters: {lax_params} β scalars ({lax_overhead:.6f}% overhead)")
645
+ print(f" • LaX Initialization: β=0.2 (conservative start)")
646
+ print(f"🎼 Canon Layers Enabled: {'YES' if self.config.canon_enabled else 'NO'} (A+C ONLY)")
647
+ if self.config.canon_enabled:
648
+ print(f" • Canon-A (Before Attention): {'✅' if self.config.canon_a_enabled else '❌'}")
649
+ print(f" • Canon-B (Inside Attention): ❌ PERMANENTLY DISABLED")
650
+ print(f" • Canon-C (Before MLP): {'✅' if self.config.canon_c_enabled else '❌'}")
651
+ print(f" • Canon-D (Inside MLP): ❌ PERMANENTLY DISABLED")
652
+ print(f" • Canon Kernel Size: {self.config.canon_kernel_size}")
653
+ print(f" • Canon Parameters Overhead: {canon_overhead:.3f}% ({canon_params/1e3:.1f}K params)")
654
+ print(f"⚡ GPAS Enabled: ALWAYS (Dynamic variance control)")
655
+ print(f"📏 LNS Enabled: ALWAYS (Static depth scaling)")
656
+ print(f"🔧 Variance Control: Triple-level (LNS + GPAS + Canon A+C) ALWAYS")
657
+ print(f"🔗 Residual Connections: STANDARD + HORIZONTAL (Canon A+C only)")
658
+ print(f"🧹 CLEAN: Standard transformer architecture - CrossEntropyLoss manages PAD naturally")
659
+ print(f"⚡ FlashAttention: Scaled Dot-Product Attention with GQA + automatic causal masking")
660
+ print(f"🎯 TOKENIZER AGNOSTIC: Dynamic vocab_size and pad_token_id")
661
+ print(f"🎯 SIMPLIFIED: CoLA Linear Only + Canon A+C Only = Better performance & less overhead")
662
+ print(f"🔗 ✅ FIXED Weight Tying: _tied_weights_keys as class attribute (HF standard)")
663
+ print(f"🎼 Canon A+C BENEFITS: Strategic horizontal information flow with minimal parameters")
664
+ print(f"🔀 ✅ FIXED LaX: Functional Inter-Layer with ephemeral buffer (no broken reset)")
665
+ print(f"🤗 HUGGINGFACE COMPATIBLE: Full PreTrainedModel integration v4.53.3")
666
+ print(f"⚡ ✅ NATIVE TORCH COMPILE: Will be handled by TrainingArguments")
667
+ print(f"🚀 ✅ AUTOCLASS SUPPORT: Compatible with AutoConfig.from_pretrained() and AutoModel.from_pretrained()")
668
+
669
+ def forward(
670
+ self,
671
+ input_ids: torch.Tensor,
672
+ attention_mask: Optional[torch.Tensor] = None,
673
+ position_ids: Optional[torch.Tensor] = None,
674
+ labels: Optional[torch.Tensor] = None,
675
+ **kwargs
676
+ ):
677
+ batch_size, seq_len = input_ids.shape
678
+
679
+ # ✅ LaX: Create ephemeral buffer for this forward pass
680
+ lax_buffer = {} if self.config.lax_enabled else None
681
+
682
+ hidden_states = self.embed_tokens(input_ids)
683
+ hidden_states = self.embedding_dropout(hidden_states)
684
+
685
+ for layer in self.layers:
686
+ hidden_states = layer(hidden_states, position_ids=position_ids, attention_mask=attention_mask, lax_buffer=lax_buffer)
687
+
688
+ hidden_states = self.norm(hidden_states)
689
+ hidden_states = self.output_dropout(hidden_states)
690
+
691
+ logits = self.lm_head(hidden_states)
692
+
693
+ loss = None
694
+ if labels is not None:
695
+ # Shift so that tokens < n predict n
696
+ shift_logits = logits[..., :-1, :].contiguous()
697
+ shift_labels = labels[..., 1:].contiguous()
698
+ # Flatten the tokens
699
+ loss_fct = nn.CrossEntropyLoss()
700
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
701
+ shift_labels = shift_labels.view(-1)
702
+ # Enable model parallelism
703
+ shift_labels = shift_labels.to(shift_logits.device)
704
+
705
+ # ✅ RESTORED: Change pad tokens to -100 so CrossEntropyLoss ignores them (from original code)
706
+ if self.config.pad_token_id is not None:
707
+ shift_labels[shift_labels == self.config.pad_token_id] = -100
708
+
709
+ loss = loss_fct(shift_logits, shift_labels)
710
+
711
+ # ✅ LaX buffer is automatically cleaned up (ephemeral, goes out of scope)
712
+
713
+ # Return in HuggingFace format
714
+ from transformers.modeling_outputs import CausalLMOutputWithPast
715
+ return CausalLMOutputWithPast(
716
+ loss=loss,
717
+ logits=logits,
718
+ past_key_values=None,
719
+ hidden_states=None,
720
+ attentions=None,
721
+ )
722
+
723
+ def get_input_embeddings(self):
724
+ return self.embed_tokens
725
+
726
+ def set_input_embeddings(self, value):
727
+ self.embed_tokens = value
728
+
729
+ def get_output_embeddings(self):
730
+ return self.lm_head
731
+
732
+ def set_output_embeddings(self, new_embeddings):
733
+ self.lm_head = new_embeddings
734
+
735
+ @torch.no_grad()
736
+ def generate(
737
+ self,
738
+ input_ids: torch.Tensor,
739
+ max_new_tokens: int = 50,
740
+ temperature: float = 1.0,
741
+ top_p: float = 0.9,
742
+ top_k: Optional[int] = None,
743
+ do_sample: bool = True,
744
+ pad_token_id: Optional[int] = None,
745
+ eos_token_id: Optional[int] = None,
746
+ **kwargs
747
+ ) -> torch.Tensor:
748
+ """
749
+ Generate sequences using the model.
750
+ Compatible with AutoModelForCausalLM interface.
751
+ """
752
+ # Set default token IDs
753
+ if pad_token_id is None:
754
+ pad_token_id = self.config.pad_token_id
755
+ if eos_token_id is None:
756
+ eos_token_id = self.config.eos_token_id
757
+
758
+ batch_size = input_ids.shape[0]
759
+ device = input_ids.device
760
+
761
+ generated = input_ids.clone()
762
+
763
+ for _ in range(max_new_tokens):
764
+ # Forward pass (LaX buffer is created fresh each time)
765
+ outputs = self.forward(generated)
766
+ logits = outputs.logits
767
+
768
+ # Get the logits for the last token
769
+ next_token_logits = logits[:, -1, :]
770
+
771
+ if do_sample:
772
+ # Apply temperature
773
+ if temperature != 1.0:
774
+ next_token_logits = next_token_logits / temperature
775
+
776
+ # Apply top-k filtering
777
+ if top_k is not None:
778
+ values, indices = torch.topk(next_token_logits, top_k)
779
+ next_token_logits[next_token_logits < values[:, [-1]]] = -float('inf')
780
+
781
+ # Apply top-p (nucleus) filtering
782
+ if top_p < 1.0:
783
+ sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
784
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
785
+
786
+ # Remove tokens with cumulative probability above the threshold
787
+ sorted_indices_to_remove = cumulative_probs > top_p
788
+ # Shift the indices to the right to keep also the first token above the threshold
789
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
790
+ sorted_indices_to_remove[..., 0] = 0
791
+
792
+ # Scatter sorted tensors to original indexing
793
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
794
+ next_token_logits[indices_to_remove] = -float('inf')
795
+
796
+ # Sample from the filtered distribution
797
+ probs = F.softmax(next_token_logits, dim=-1)
798
+ next_token = torch.multinomial(probs, num_samples=1)
799
+ else:
800
+ # Greedy decoding
801
+ next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
802
+
803
+ # Append the new token
804
+ generated = torch.cat([generated, next_token], dim=1)
805
+
806
+ # Check for EOS token
807
+ if eos_token_id is not None and (next_token == eos_token_id).all():
808
+ break
809
+
810
+ return generated
811
+
812
+
813
+
814
+ # ✅ AUTOCLASS REGISTRATION - Required for Hub compatibility
815
+ # Register the configuration and model for AutoClass support
816
+ AutoConfig.register("unified_model", UnifiedModelConfig)
817
+ AutoModel.register(UnifiedModelConfig, UnifiedModel)
818
+ AutoModelForCausalLM.register(UnifiedModelConfig, UnifiedModel)
819
+
820
+ print("🚀 ✅ AUTOCLASS REGISTRATION COMPLETE:")
821
+ print(" • AutoConfig.register('unified_model', UnifiedModelConfig)")
822
+ print(" • AutoModel.register(UnifiedModelConfig, UnifiedModel)")
823
+ print(" • AutoModelForCausalLM.register(UnifiedModelConfig, UnifiedModel)")
824
+ print(" • Users can now load with: AutoModel.from_pretrained('your-repo', trust_remote_code=True)")
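As a rough usage sketch (not part of the uploaded files): loading from the Hub could look like the following, where "your-username/your-repo" is a placeholder, and the environment needs flash-attn and a GPU since modeling_unified.py imports flash_attn_func at module level.

    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    repo = "your-username/your-repo"  # hypothetical repo id
    tokenizer = AutoTokenizer.from_pretrained(repo)
    model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)
    model.eval()

    inputs = tokenizer("Hello", return_tensors="pt")
    with torch.no_grad():
        # Uses the custom generate() defined in modeling_unified.py
        output_ids = model.generate(inputs["input_ids"], max_new_tokens=20, do_sample=False)
    print(tokenizer.decode(output_ids[0]))

trust_remote_code=True is what makes AutoModelForCausalLM follow the auto_map entries in config.json to these two files.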