KitsuVp committed
Commit 34c6d09 · verified · 1 Parent(s): bb6b821

Initial FanFormer checkpoint with architecture and README

README.md CHANGED
@@ -1,3 +1,116 @@
- ---
- license: apache-2.0
- ---
+ # FanConections: Advanced Neural Connections for Language Modeling
+
+ [![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97-Hugging%20Face-blue)](https://huggingface.co/KitsuVp/FanConections)
+ [![PyTorch](https://img.shields.io/badge/PyTorch-%23EE4C2C.svg?style=flat&logo=pytorch&logoColor=white)](https://pytorch.org/)
+
+ FanConections is a language model architecture that augments the standard transformer with specialized neural connection mechanisms and efficient computational techniques. The model incorporates Fourier-inspired components to better capture complex patterns and periodicities within language.
+
+ ## Model Description
+
+ FanConections introduces several key architectural innovations:
+
+ - **Fourier-Inspired Neural Processing (FAN Components)**: These components help the model represent repeating or cyclical patterns in language (e.g., common phrasings, structural recurrences) by transforming part of each projection with functions drawn from Fourier analysis (cosine and sine).
+ - **Compressed Linear Layers (CoLA)**: To improve efficiency, CoLA layers reduce the parameter count of linear projections by factoring large weight matrices into smaller low-rank approximations, akin to summarizing a large dataset with its most essential components.
+ - **Hybrid Normalization**: Combines Pre-Normalization with Query-Key-Value (QKV) Normalization to improve training stability and model performance.
+ - **HyperConnections**: Residual connections that go beyond simple skips. They use static and dynamically computed parameters, letting the model decide how to combine information from different parts of the network, improving gradient flow and the modeling of long-range dependencies.
+ - **Optimized Flash Attention**: Leverages highly efficient attention kernels, including adaptive normalization techniques, to speed up computation and reduce memory usage.
+
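The FAN idea above can be sketched as a toy layer: part of the output passes through a cos/sin pair (the periodic component) while the rest is an ordinary linear map. The class name, dimensions, and `p` split here are illustrative assumptions, not the checkpoint's actual code:

```python
import torch
import torch.nn as nn

class ToyFANProjection(nn.Module):
    """Toy Fourier-inspired projection: a slice of the output is produced
    by cos/sin of a linear map (periodic), the rest by a regular linear map."""
    def __init__(self, in_features: int, out_features: int, p: float = 0.15):
        super().__init__()
        p_dim = int(out_features * p)  # periodic slice (before cos/sin doubling)
        self.periodic = nn.Linear(in_features, p_dim, bias=False)
        self.regular = nn.Linear(in_features, out_features - 2 * p_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.periodic(x)
        # the cos/sin pair doubles the periodic slice back to 2 * p_dim,
        # so the concatenation has exactly out_features channels
        return torch.cat([torch.cos(z), torch.sin(z), self.regular(x)], dim=-1)

layer = ToyFANProjection(768, 768)
out = layer(torch.randn(2, 10, 768))
print(out.shape)  # torch.Size([2, 10, 768])
```

The cos/sin pair gives the layer an explicit basis for periodic structure that a plain linear map would have to learn implicitly.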
+ ### Key Features
+
+ - **Parameter Efficiency**: Thoughtful design choices, like CoLA layers, lead to a more compact model.
+ - **Enhanced Pattern Recognition**: FAN components are designed to improve the modeling of periodic or recurrent structures in text.
+ - **Improved Training Stability**: Advanced normalization and connection strategies contribute to a smoother training process.
+ - **High-Quality Outputs**: Aims to generate more coherent and contextually relevant text by better understanding underlying language patterns.
+
+ ## Training Data
+
+ The FanConections model was pre-trained on a substantial dataset of **900 million tokens**. The training corpus was a carefully curated mix:
+
+ - **90% FineWeb**: A large-scale, high-quality dataset of web content, focusing on educational material.
+ - **10% FineMath 4+**: A specialized dataset containing mathematical text and reasoning.
+
+ This blend provides the model with a broad understanding of general language as well as more structured, logical text.
+
+ ## Usage
+
+ You can use this model with the Transformers library:
+
+ ```python
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ # Load tokenizer and model
+ tokenizer = AutoTokenizer.from_pretrained("KitsuVp/FanConections")
+ model = AutoModelForCausalLM.from_pretrained("KitsuVp/FanConections", trust_remote_code=True)
+ model.eval()  # Set the model to evaluation mode
+
+ # Example input text
+ input_text = "The FanConections architecture is designed to"
+ input_ids = tokenizer(input_text, return_tensors="pt").input_ids
+
+ # Generate text with recommended parameters
+ # Move input_ids to the same device as the model if using a GPU:
+ # model.to('cuda')
+ # input_ids = input_ids.to('cuda')
+
+ outputs = model.generate(
+     input_ids,
+     max_length=120,          # Maximum length of the generated sequence
+     top_p=0.92,              # Nucleus sampling: smallest token set whose cumulative probability exceeds top_p
+     top_k=50,                # Keep only the top k most likely next tokens
+     temperature=0.75,        # Controls randomness: lower is less random
+     num_return_sequences=1,  # Number of sequences to generate
+     do_sample=True,          # Use sampling; set to False for greedy decoding
+     pad_token_id=tokenizer.eos_token_id  # Important for open-ended generation
+ )
+
+ # Decode and print the generated text
+ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+ print(generated_text)
+ ```
+
+ ## Model Architecture Details
+
+ The FanConections model implements a decoder-only transformer architecture with several novel components:
+
+ 1. **FAN Components (CoLA_FAN)**: These specialized layers integrate Fourier-inspired transformations directly into the linear projections (particularly for Query, Key, and Value in attention). This allows the model to more effectively capture and utilize periodic or cyclical information present in the input data.
+ 2. **Low-Rank Matrix Factorization (CoLA_Linear & CoLA_FAN)**: Both `CoLA_Linear` (used in MLPs) and `CoLA_FAN` (used in attention) reduce computational cost and parameter count by approximating large weight matrices with the product of two smaller, lower-rank matrices.
+ 3. **HyperConnections**: An advanced form of residual connection. Instead of a simple addition, HyperConnections use learnable parameters (both static and dynamically computed from the input) to create a more flexible and expressive way of combining outputs from previous layers with the current layer's computation. This helps in training deeper networks and managing information flow.
+ 4. **RoPE Positional Embeddings**: Implements Rotary Positional Embeddings, which inject positional information by rotating parts of the embedding vectors, offering better relative position awareness.
+ 5. **Progressive Dropout**: A dropout strategy where the probability of dropping units increases with the depth of the network layer, providing stronger regularization for deeper parts of the model.
+ 6. **Flash Attention with Unpadding**: Utilizes optimized attention computations (FlashAttention) combined with techniques to handle variable-length sequences efficiently (unpadding/padding), maximizing GPU utilization.
+ 7. **Muon Optimizer**: A custom optimizer used during pre-training, which combines Newton-Schulz orthogonalization for matrix parameters with an AdamW-like update for other parameters.
+
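To make the low-rank factorization concrete, here is a quick parameter-count comparison, a sketch under assumed dimensions (d = 768, rank = in_features // 4 as described above), not the repository's exact implementation:

```python
import torch.nn as nn

d, r = 768, 192  # assumed model width and low rank (in_features // 4)

full = nn.Linear(d, d)            # full-rank projection: d*d weights + d biases
low_rank = nn.Sequential(         # low-rank factorization: B(sigma(A x))
    nn.Linear(d, r, bias=False),  # A: d -> r
    nn.GELU(),                    # non-linearity between the two factors
    nn.Linear(r, d),              # B: r -> d
)

full_params = sum(p.numel() for p in full.parameters())
lr_params = sum(p.numel() for p in low_rank.parameters())
print(full_params, lr_params)  # the low-rank pair uses roughly half the parameters
```

With rank d/4 the two factors cost about d*d/2 parameters versus d*d for the full matrix, which is where the compactness claimed above comes from.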
+ ## Training
+
+ The model's pre-training involved:
+
+ - Distributed training across multiple GPUs.
+ - The specialized **Muon optimizer**, which incorporates Newton-Schulz orthogonalization for certain parameters and an AdamW-like mechanism for others.
+ - Progressive learning rate scheduling.
+ - Mixed precision (bfloat16) training for speed and memory efficiency.
+ - Strategic gradient checkpointing to manage memory consumption during the training of large sequences.
+
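The Newton-Schulz orthogonalization step mentioned above can be sketched as follows. This is a simplified illustration using the classic cubic iteration with an assumed step count, not the exact Muon implementation used for this checkpoint:

```python
import torch

def newton_schulz_orthogonalize(g: torch.Tensor, steps: int = 25) -> torch.Tensor:
    """Approximate the nearest (semi-)orthogonal matrix to g via the
    Newton-Schulz iteration X <- 1.5*X - 0.5*(X X^T) X."""
    x = g / (g.norm() + 1e-7)  # scale down so the iteration converges
    for _ in range(steps):
        x = 1.5 * x - 0.5 * (x @ x.T) @ x
    return x

# A diagonal matrix makes the effect easy to see: every singular value
# is driven toward 1, so the result approaches the identity matrix.
g = torch.diag(torch.tensor([3.0, 1.0, 0.25]))
q = newton_schulz_orthogonalize(g)
print(q)
```

Applied to a gradient matrix, this replaces the raw update direction with an orthogonalized one, which is the core idea behind Muon-style updates for 2-D weight matrices.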
+ ## Limitations
+
+ - **Context Window**: The model has a fixed context window (e.g., 1024 tokens in the provided code). It cannot process information beyond this limit in a single pass.
+ - **Domain Specificity**: While trained on a diverse dataset, performance might be suboptimal on highly specialized or out-of-distribution content.
+ - **Potential for Hallucinations**: Like all language models, FanConections can generate text that is factually incorrect, nonsensical, or misleading.
+ - **Bias**: The model may reflect biases present in its extensive training data.
+
+ ## Citation
+
+ If you use FanConections or its architecture in your research, please cite:
+
+ ```bibtex
+ @misc{fanconections2025,
+   author       = {Kitsun},
+   title        = {FanConections: Advanced Neural Connections for Language Modeling},
+   year         = {2025},
+   publisher    = {HuggingFace},
+   howpublished = {\url{https://huggingface.co/KitsuVp/FanConections}}
+ }
+ ```
+
+ ## License
+
+ This model is released under the [Apache 2.0 License](https://www.apache.org/licenses/LICENSE-2.0).
config.json ADDED
@@ -0,0 +1,20 @@
+ {
+   "dropout": 0.101,
+   "embed_dim": 768,
+   "ff_dim": 2048,
+   "max_seq_len": 1024,
+   "num_decoder_layers": 12,
+   "num_gqa_groups": 6,
+   "num_heads": 12,
+   "p": 0.14,
+   "tie_weights": true,
+   "vocab_size": 49152,
+   "model_type": "fanformer",
+   "architectures": [
+     "MultiModalModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "model_architecture.FanConfig",
+     "AutoModelForCausalLM": "model_architecture.MultiModalModel"
+   }
+ }
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:607ebeab78df2738e7039a379d2e0b022cdf42f7ff9b675f38be06d91f72a160
+ size 331514552
model_architecture.py ADDED
@@ -0,0 +1,867 @@
+ import math
+ import os
+ import random
+ from typing import Any, Dict, List, Optional, cast
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from torch.nn import RMSNorm
+ import torch.nn.functional as F
+ import torch.utils.checkpoint as checkpoint
+ # Hugging Face Hub integration
+ from huggingface_hub import PyTorchModelHubMixin
+ from transformers import PretrainedConfig
+
+ # Configuration class for the FanFormer architecture
+ class FanConfig(PretrainedConfig):
+     model_type = "fanformer"
+
+     def __init__(
+         self,
+         vocab_size=32000,
+         embed_dim=768,
+         max_seq_len=1024,
+         num_heads=12,
+         num_decoder_layers=12,
+         ff_dim=2048,
+         dropout=0.12,
+         num_gqa_groups=6,
+         p=0.15,
+         tie_weights=True,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.vocab_size = vocab_size
+         self.embed_dim = embed_dim
+         self.max_seq_len = max_seq_len
+         self.num_heads = num_heads
+         self.num_decoder_layers = num_decoder_layers
+         self.ff_dim = ff_dim
+         self.dropout = dropout
+         self.num_gqa_groups = num_gqa_groups
+         self.p = p
+         self.tie_weights = tie_weights
+
+ ############################################
+ # LAYER INITIALIZATION FUNCTIONS
+ ############################################
+ def init_linear(layer: nn.Linear, random_factor: float = 0.02):
+     gain = nn.init.calculate_gain('linear') * (1.0 + random.uniform(-random_factor, random_factor))
+     nn.init.xavier_uniform_(layer.weight, gain=gain)
+     if layer.bias is not None:
+         nn.init.zeros_(layer.bias)
+
+ def init_embedding(embedding: nn.Embedding):
+     nn.init.normal_(embedding.weight, mean=0.0, std=0.02)
+
+ def init_gate_parameter(gate: torch.Tensor, a: float = -0.02, b: float = 0.02):
+     nn.init.uniform_(gate, a=a, b=b)
+
+ ############################################
+ # NEW LAYER: CoLA – LOW-RANK LINEAR LAYER
+ ############################################
+ class CoLA_Linear(nn.Module):
+     """
+     Linear layer following the CoLA proposal (standard variant).
+     Replaces the full-rank operation W*x with:
+         h' = B(σ(Ax))
+     where A and B are low-rank matrices and σ is a non-linear activation.
+
+     By default, rank = in_features // 4.
+     """
+     def __init__(self, in_features: int, out_features: int, rank: Optional[int] = None, activation=F.gelu):
+         super().__init__()
+         if rank is None:
+             rank = in_features // 4
+         self.rank = rank
+         self.activation = activation
+         # The two low-rank projections
+         self.A = nn.Linear(in_features, rank, bias=False)
+         self.B = nn.Linear(rank, out_features, bias=True)
+         init_linear(self.A)
+         init_linear(self.B)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.B(self.activation(self.A(x)))
+
+ ############################################
+ # NEW LAYER: CoLA_FAN – LINEAR LAYER WITH FOURIER ANALYSIS FOR FANFORMER
+ ############################################
+ class CoLA_FAN(nn.Module):
+     """
+     CoLA layer with Fourier analysis for FANformer.
+     Combines the efficiency of CoLA with FANformer's periodicity modeling.
+
+     This implementation omits internal dropout because regularization is already
+     applied in the layers above (FANformerMultiheadAttention and flash attention).
+     This avoids over-regularization that could limit the model's capacity to learn.
+
+     Parameters:
+         in_features: Input dimension
+         out_features: Output dimension
+         rank: Rank for the CoLA compression (defaults to in_features // 4)
+         p: Fraction of the output dimension devoted to periodic modeling (default 0.15)
+         activation: Activation function for the projections
+         depth: Depth of the layer in the network (kept for compatibility)
+     """
+     def __init__(self, in_features: int, out_features: int, rank: Optional[int] = None,
+                  p: float = 0.15, activation=F.gelu, dropout: float = 0.0, depth: int = 1):
+         super().__init__()
+         if rank is None:
+             rank = in_features // 4
+         self.rank = rank
+         self.activation = activation
+         self.p = p
+
+         # Compute dimensions for the periodic and non-periodic components
+         p_dim = int(out_features * p)         # Periodic component dimension (before cos/sin)
+         non_p_dim = out_features - 2 * p_dim  # Non-periodic component dimension
+
+         # Projections for the periodic component
+         self.A_p = nn.Linear(in_features, rank, bias=False)
+         self.B_p = nn.Linear(rank, p_dim, bias=False)  # No bias for the periodic transform
+
+         # Projections for the non-periodic component (standard CoLA)
+         self.A_np = nn.Linear(in_features, rank, bias=False)
+         self.B_np = nn.Linear(rank, non_p_dim, bias=True)
+
+         # Internal dropout is removed to avoid over-regularization,
+         # since dropout is applied in higher layers (FANformerMultiheadAttention)
+
+         # Initialization
+         init_linear(self.A_p)
+         init_linear(self.B_p)
+         init_linear(self.A_np)
+         init_linear(self.B_np)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # Periodic component (no dropout)
+         p_activation = self.activation(self.A_p(x))
+         p_proj = self.B_p(p_activation)
+
+         # Non-periodic component (no dropout)
+         np_activation = self.activation(self.A_np(x))
+         np_proj = self.B_np(np_activation)
+
+         # Combine using Fourier transforms (cos/sin) plus the regular component
+         return torch.cat([torch.cos(p_proj), torch.sin(p_proj), np_proj], dim=-1)
+
+ ############################################
+ # UTILITY: PROGRESSIVE DROPOUT
+ ############################################
+ def progressive_dropout(p: float, depth: int) -> nn.Dropout:
+     """
+     Progressive dropout whose probability grows logarithmically with depth.
+
+     Args:
+         p (float): Base dropout probability
+         depth (int): Depth of the layer
+
+     Returns:
+         nn.Dropout: Dropout module with the adjusted probability
+     """
+     if p == 0.0:
+         return nn.Dropout(0.0)
+
+     # Logarithm base (tunable as needed)
+     base = 1.4
+
+     # Use a logarithm for slower growth in deeper layers
+     return nn.Dropout(p * (1 + math.log(depth + 1, base) * 0.04))
+
+ ############################################
+ # UTILITIES: UNIFIED RoPE WITH PRECOMPUTATION
+ ############################################
+ def get_rope_buffer(seq_len: int, head_dim: int, device: torch.device, dtype: torch.dtype):
+     inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
+     pos = torch.arange(seq_len, device=device).float().unsqueeze(1)
+     sinusoid_inp = pos * inv_freq.unsqueeze(0)
+     cos = torch.cos(sinusoid_inp).to(dtype)
+     sin = torch.sin(sinusoid_inp).to(dtype)
+     cos = cos.unsqueeze(0).unsqueeze(0)
+     sin = sin.unsqueeze(0).unsqueeze(0)
+     return cos, sin
+
+ def apply_rope_vectorized(x: torch.Tensor) -> torch.Tensor:
+     B, num_heads, T, head_dim = x.shape
+     if head_dim % 2 != 0:
+         raise ValueError("head_dim must be even for RoPE")
+     cos, sin = get_rope_buffer(T, head_dim, x.device, x.dtype)
+     x_reshaped = x.view(B, num_heads, T, head_dim // 2, 2)
+     x_even = x_reshaped[..., 0]
+     x_odd = x_reshaped[..., 1]
+     x_rotated_even = x_even * cos - x_odd * sin
+     x_rotated_odd = x_even * sin + x_odd * cos
+     x_rotated = torch.stack([x_rotated_even, x_rotated_odd], dim=-1)
+     result = x_rotated.flatten(-2)
+     return result
+
+ ############################################
+ # GATED RESIDUALS
+ ############################################
+ class HyperConnections(nn.Module):
+     def __init__(self, d_model: int, expansion_rate: int = 4, dropout: float = 0.12, depth: int = 1):
+         super().__init__()
+         self.expansion_rate = expansion_rate
+
+         # Use CUDA when available
+         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+         # Static matrices, created directly on the target device in bfloat16
+         self.static_beta = nn.Parameter(torch.ones(expansion_rate, device=device, dtype=torch.bfloat16))
+
+         # Alpha initialization following the paper
+         init_alpha0 = torch.zeros((expansion_rate, 1), device=device, dtype=torch.bfloat16)
+         init_alpha0[depth % expansion_rate, 0] = 1.
+
+         self.static_alpha = nn.Parameter(torch.cat(
+             [init_alpha0, torch.eye(expansion_rate, device=device, dtype=torch.bfloat16)], dim=1))
+
+         # Parameters for the dynamic part
+         self.dynamic_alpha_fn = nn.Parameter(torch.zeros((d_model, expansion_rate+1), device=device, dtype=torch.bfloat16))
+         self.dynamic_alpha_scale = nn.Parameter(torch.ones(1, device=device, dtype=torch.bfloat16) * 0.01)
+         self.dynamic_beta_fn = nn.Parameter(torch.zeros((d_model), device=device, dtype=torch.bfloat16))
+         self.dynamic_beta_scale = nn.Parameter(torch.ones(1, device=device, dtype=torch.bfloat16) * 0.01)
+
+         # Normalization for stability
+         self.layer_norm = nn.RMSNorm(d_model, eps=1e-5)
+
+         # Dropout
+         self.dropout = nn.Dropout(dropout)
+
+         # Precompute expanded static buffers
+         self.register_buffer(
+             'static_alpha_expanded',
+             self.static_alpha.unsqueeze(0).unsqueeze(0)
+         )
+         self.register_buffer(
+             'static_beta_expanded',
+             self.static_beta.unsqueeze(0).unsqueeze(0)
+         )
+
+     def _compute_dynamic_params(self, norm_x):
+         """Compute the dynamic parameters (alpha and beta)"""
+         dynamic_alpha = F.tanh(norm_x @ self.dynamic_alpha_fn) * self.dynamic_alpha_scale
+         dynamic_beta = F.tanh(norm_x @ self.dynamic_beta_fn) * self.dynamic_beta_scale
+
+         # Prepare for broadcasting
+         dynamic_alpha = dynamic_alpha.unsqueeze(2)  # [B, T, 1, E+1]
+         dynamic_beta = dynamic_beta.unsqueeze(2)    # [B, T, 1]
+
+         # Combine static and dynamic
+         alpha = self.static_alpha_expanded + dynamic_alpha  # [B, T, E, E+1]
+         beta = self.static_beta_expanded + dynamic_beta     # [B, T, E]
+
+         return alpha, beta
+
+     def _compute_width_connection(self, x, alpha):
+         """Compute the width connection"""
+         alpha_t = alpha.transpose(2, 3)  # [B, T, E+1, E]
+         x_expanded = x.unsqueeze(2).expand(-1, -1, self.expansion_rate, -1)  # [B, T, E, D]
+
+         # Compute mix_h with a single einsum
+         mix_h = torch.einsum('btij,btjd->btid', alpha_t, x_expanded)  # [B, T, E+1, D]
+         return mix_h
+
+     def _compute_depth_connection(self, residual, beta, mix_h):
+         """Compute the depth connection and combine"""
+         residual = self.dropout(residual)
+         residual_expanded = residual.unsqueeze(2).expand(-1, -1, self.expansion_rate, -1)
+         weighted_residual = residual_expanded * beta.unsqueeze(-1)  # [B, T, E, D]
+
+         # Extract mix_h_rest (everything except the first slot)
+         mix_h_rest = mix_h[:, :, 1:, :]  # [B, T, E, D]
+
+         # Combine and reduce
+         h = weighted_residual + mix_h_rest  # [B, T, E, D]
+         output = h.sum(dim=2)  # [B, T, D]
+
+         return output
+
+     def forward(self, x, residual):
+         """Forward pass with checkpointing to save memory"""
+         # Cast the inputs to bfloat16 if they are not already
+         x = x.to(dtype=torch.bfloat16)
+         residual = residual.to(dtype=torch.bfloat16)
+
+         # Step 1: normalize the input (not checkpointed - low memory use)
+         norm_x = self.layer_norm(x)
+
+         # Helper to apply checkpointing and pin down the return type
+         def apply_checkpoint(func, *args):
+             return cast(torch.Tensor, checkpoint.checkpoint(func, *args, use_reentrant=False))
+
+         # Step 2: checkpoint the dynamic parameter computation
+         alpha, beta = apply_checkpoint(self._compute_dynamic_params, norm_x)
+
+         # Step 3: checkpoint the width connection
+         mix_h = apply_checkpoint(self._compute_width_connection, x, alpha)
+
+         # Step 4: checkpoint the depth connection and final combination
+         output = apply_checkpoint(self._compute_depth_connection, residual, beta, mix_h)
+
+         return output
+
+ ############################################
+ # HELPER MODULE: GQA FAN LINEAR
+ ############################################
+ class GQAFANLinear(nn.Module):
+     """
+     GQA projection built on CoLA_FAN for FANformer.
+     Splits the projection into groups, using a CoLA_FAN layer internally.
+
+     out_features is expected to be divisible by num_heads.
+
+     Parameters:
+         in_features: Input dimension
+         out_features: Output dimension
+         num_heads: Number of attention heads
+         num_gqa_groups: Number of GQA groups
+         p: Fraction of the dimension devoted to periodic modeling
+         divide_dim: Whether to divide the dimension (default False)
+     """
+     def __init__(self, in_features: int, out_features: int, num_heads: int,
+                  num_gqa_groups: int, p: float = 0.15, divide_dim: bool = False):
+         super().__init__()
+         if out_features % num_heads != 0:
+             raise ValueError("out_features must be divisible by num_heads")
+         self.num_heads = num_heads
+         self.num_gqa_groups = num_gqa_groups
+         self.rep_factor = num_heads // num_gqa_groups
+
+         self.divide_factor = 1
+         self.head_dim = (out_features // num_heads) // self.divide_factor
+
+         self.inter_dim = num_gqa_groups * self.head_dim
+         # Use CoLA_FAN instead of CoLA_Linear:
+         self.linear = CoLA_FAN(in_features, self.inter_dim, rank=in_features // 4, p=p)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         B, T, _ = x.shape
+         out = self.linear(x)
+         out = out.view(B, T, self.num_gqa_groups, self.head_dim)
+         out = out.repeat(1, 1, self.rep_factor, 1)
+         out = out.view(B, T, self.num_heads, self.head_dim)
+         return out
+
+ ############################################
+ # MODULE: MULTI-HEAD ATTENTION WITH FANFORMER
+ ############################################
+ class FANformerMultiheadAttention(nn.Module):
+     """
+     Multi-head attention with FANformer.
+     Normalizes Q, K, V individually and uses unpadding to improve performance.
+     Incorporates periodicity modeling through CoLA_FAN projections.
+     """
+     def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.12, use_rope: bool = True,
+                  layer_index: int = 1, max_seq_len: int = 512, p: float = 0.15,
+                  num_gqa_groups: Optional[int] = None, debug: bool = True,
+                  use_pre_norm: bool = False):
+         super().__init__()
+         self.embed_dim = embed_dim
+         self.num_heads = num_heads
+         self.debug = debug
+         self.layer_name = f"Layer_{layer_index}"
+         self.layer_index = layer_index
+         self.use_pre_norm = use_pre_norm
+         self.p = p  # Periodicity ratio
+
+         if embed_dim % num_heads != 0:
+             raise ValueError("embed_dim must be divisible by num_heads")
+
+         self.head_dim = embed_dim // num_heads
+         self.use_rope = use_rope
+
+         if num_gqa_groups is None:
+             num_gqa_groups = num_heads
+
+         try:
+             from flash_attn import flash_attn_func, flash_attn_varlen_func
+             self.flash_attn_func = flash_attn_func
+             self.flash_attn_varlen_func = flash_attn_varlen_func
+         except ImportError as e:
+             raise ImportError(f"Failed to initialize FlashAttention: {e}")
+
+         # For unpadding
+         try:
+             from flash_attn.bert_padding import unpad_input, pad_input
+             self.unpad_input = unpad_input
+             self.pad_input = pad_input
+         except ImportError as e:
+             raise ImportError(f"Failed to import padding utilities: {e}")
+
+         # Scale parameter initialization
+         self.ssmax_scale = nn.Parameter(torch.ones(num_heads, dtype=torch.bfloat16) * 0.168)
+         nn.init.uniform_(self.ssmax_scale, a=0.166, b=0.170)
+         self.register_buffer('seq_scale', torch.log(torch.tensor(max_seq_len, dtype=torch.bfloat16)))
+
+         # Input normalization (Pre-Norm in the first block, QKV-Norm elsewhere)
+         self.norm = nn.RMSNorm(embed_dim, eps=1e-5)
+
+         # Dropout layers (simplified)
+         self.attention_dropout = progressive_dropout(dropout, depth=1)
+         # Removed: self.projection_dropout = progressive_dropout(dropout * 1.1, depth=1)
+         self.output_dropout = progressive_dropout(dropout, depth=1)
+
+         # Q, K, V projections using GQAFANLinear (FANformer implementation)
+         self.Wq = GQAFANLinear(embed_dim, embed_dim, num_heads, num_gqa_groups, p=p)
+         self.Wk = GQAFANLinear(embed_dim, embed_dim, num_heads, num_gqa_groups, p=p)
+         self.Wv = GQAFANLinear(embed_dim, embed_dim, num_heads, num_gqa_groups, p=p)
+
+         # Output projection (kept as CoLA_Linear)
+         self.out_proj = CoLA_Linear(embed_dim, embed_dim, rank=embed_dim // 4)
+
+     def scaled_dot_product_attention_flash_unpadded(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                                                     attention_mask: Optional[torch.Tensor] = None,
+                                                     is_causal: bool = False) -> torch.Tensor:
+         B, H, S, D = q.shape  # batch, heads, sequence length, head dimension
+
+         if attention_mask is None:
+             # Without an attention mask, use the regular version
+             return self.scaled_dot_product_attention_flash(q, k, v, mask=None, is_causal=is_causal)
+
+         # Convert the tensors to [B, S, H, D] for unpad_input
+         q_unpad = q.permute(0, 2, 1, 3)  # [B, S, H, D]
+         k_unpad = k.permute(0, 2, 1, 3)  # [B, S, H, D]
+         v_unpad = v.permute(0, 2, 1, 3)  # [B, S, H, D]
+
+         # Prepare the mask: convert to bool if necessary
+         if attention_mask.dtype != torch.bool:
+             attention_mask = attention_mask.bool()
+
+         # Unpad the tensors
+         q_unpadded, indices_q, cu_seqlens_q, max_seqlen_q, _ = self.unpad_input(q_unpad, attention_mask)
+         k_unpadded, indices_k, cu_seqlens_k, max_seqlen_k, _ = self.unpad_input(k_unpad, attention_mask)
+         v_unpadded, _, _, _, _ = self.unpad_input(v_unpad, attention_mask)
+
+         # Reshape for flash_attn_varlen_func: [Total, H, D]
+         q_unpadded = q_unpadded.reshape(-1, H, D)
+         k_unpadded = k_unpadded.reshape(-1, H, D)
+         v_unpadded = v_unpadded.reshape(-1, H, D)
+
+         # L2-normalize Q and K to improve numerical stability
+         q_norm = F.normalize(q_unpadded, p=2, dim=-1).to(torch.bfloat16)
+         k_norm = F.normalize(k_unpadded, p=2, dim=-1).to(torch.bfloat16)
+
+         # Adjust q with the scale factor
+         s = self.ssmax_scale.view(1, H, 1)
+         q_adjusted = q_norm * (self.seq_scale * s)
+
+         # Softmax scale factor
+         softmax_scale = 1.0 / math.sqrt(D)
+
+         try:
+             # Use flash attention without padding
+             output_unpadded = self.flash_attn_varlen_func(
+                 q_adjusted, k_norm, v_unpadded,
+                 cu_seqlens_q, cu_seqlens_k,
+                 max_seqlen_q, max_seqlen_k,
+                 dropout_p=self.attention_dropout.p,  # Dropout is applied here
+                 softmax_scale=softmax_scale,
+                 causal=is_causal
+             )
+
+             # Re-apply padding
+             output_padded = self.pad_input(output_unpadded, indices_q, B, S)
+
+             # Rearrange to [B, H, S, D]
+             output = output_padded.reshape(B, S, H, D).permute(0, 2, 1, 3)
+
+             return output
+
+         except Exception as e:
+             raise RuntimeError(f"Error in flash_attn_varlen_func: {e}")
+
476
+ def scaled_dot_product_attention_flash(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
477
+ mask: Optional[torch.Tensor] = None,
478
+ is_causal: bool = False) -> torch.Tensor:
479
+ # Normalizar vectores Q y K para mejorar estabilidad numérica
480
+ q_norm = F.normalize(q, p=2, dim=-1).to(torch.bfloat16)
481
+ k_norm = F.normalize(k, p=2, dim=-1).to(torch.bfloat16)
482
+
483
+ # Ajustar q con factor de escala
484
+ s = self.ssmax_scale.view(-1, 1, 1)
485
+ q_adjusted = q_norm * (self.seq_scale * s)
486
+
487
+ # Preparar tensores para Flash Attention (requiere shape [B, S, H, D])
488
+ q_trans = q_adjusted.permute(0, 2, 1, 3)
489
+ k_trans = k_norm.permute(0, 2, 1, 3)
490
+ v_trans = v.permute(0, 2, 1, 3)
491
+
492
+ # Verificar dimensiones
493
+ if q_trans.size(-1) != k_trans.size(-1):
494
+ raise ValueError(f"Las dimensiones de head no coinciden: q={q_trans.size(-1)}, k={k_trans.size(-1)}")
495
+
496
+ # Factor de escala para softmax
497
+ softmax_scale = 1.0 / math.sqrt(q_trans.size(-1))
498
+
499
+ try:
500
+ # Aplicar Flash Attention
501
+ output = self.flash_attn_func(
502
+ q_trans, k_trans, v_trans,
503
+ dropout_p=self.attention_dropout.p, # Aplicamos dropout aquí
504
+ softmax_scale=softmax_scale,
505
+ causal=is_causal
506
+ )
507
+
508
+ if output is None:
509
+ raise ValueError("flash_attn_func devolvió None. Verifica las dimensiones y tipos de los tensores de entrada.")
510
+
511
+ # Volver a la forma original
512
+ output = output.permute(0, 2, 1, 3)
513
+ return output
514
+
515
+ except Exception as e:
516
+ raise RuntimeError(f"Error en flash_attn_func: {e}")
+
+     def forward(self, X: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, causal: bool = True) -> torch.Tensor:
+         B, T, _ = X.shape
+
+         # HybridNorm* implementation
+         if self.use_pre_norm:
+             # First block: Pre-Norm on attention
+             X_norm = self.norm(X)
+             # Q, K, V projections with FANformer
+             Q = self.Wq(X_norm)  # [B, T, num_heads, head_dim]
+             K = self.Wk(X_norm)  # [B, T, num_heads, head_dim]
+             V = self.Wv(X_norm)  # [B, T, num_heads, head_dim]
+         else:
+             # Remaining blocks: QKV-Norm
+             Q = self.Wq(self.norm(X))  # [B, T, num_heads, head_dim]
+             K = self.Wk(self.norm(X))  # [B, T, num_heads, head_dim]
+             V = self.Wv(self.norm(X))  # [B, T, num_heads, head_dim]
+
+         # Permute to [B, num_heads, T, head_dim]
+         Q = Q.permute(0, 2, 1, 3)
+         K = K.permute(0, 2, 1, 3)
+         V = V.permute(0, 2, 1, 3)
+
+         # Apply RoPE if enabled
+         if self.use_rope:
+             Q = apply_rope_vectorized(Q)
+             K = apply_rope_vectorized(K)
+
+         # Convert to bfloat16 for flash attention
+         Q = Q.to(torch.bfloat16)
+         K = K.to(torch.bfloat16)
+         V = V.to(torch.bfloat16)
+
+         # Use the unpadded path when an attention mask is provided
+         if attention_mask is not None:
+             attn_output = self.scaled_dot_product_attention_flash_unpadded(
+                 Q, K, V,
+                 attention_mask=attention_mask,
+                 is_causal=causal
+             )
+         else:
+             # Without a mask, use the regular version
+             attn_output = self.scaled_dot_product_attention_flash(
+                 Q, K, V,
+                 mask=None,
+                 is_causal=causal
+             )
+
+         # Redundant dropout application removed:
+         # attn_output = self.attention_dropout(attn_output)
+
+         # Rearrange the output and apply the final projection
+         out = attn_output.permute(0, 2, 1, 3).contiguous()
+         out = out.reshape(B, T, self.embed_dim)
+         out = self.output_dropout(self.out_proj(out))
+
+         return out
+
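Because Q and K are L2-normalized before the dot product, each raw attention score is a cosine similarity bounded in [-1, 1]; the extra `seq_scale * ssmax_scale` factor then restores a usable dynamic range before the softmax. A minimal torch-free sketch of that bound (`l2_normalize` is a stand-in for `F.normalize(..., p=2, dim=-1)`):

```python
import math

def l2_normalize(v):
    # Divide a vector by its Euclidean norm, like F.normalize(v, p=2, dim=-1)
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

q = l2_normalize([3.0, 4.0])   # -> [0.6, 0.8]
k = l2_normalize([4.0, 3.0])   # -> [0.8, 0.6]

# Dot product of unit vectors = cosine similarity, always within [-1, 1]
score = sum(a * b for a, b in zip(q, k))
```

Here `score` is 0.96, and no pair of unit vectors can produce a score outside [-1, 1], which is why an additional learned scale is needed to sharpen the softmax.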
+ ############################################
+ # NEW MODULE: SWIGLU WITH COLA (MLP)
+ ############################################
+ class SwiGLU(nn.Module):
+     def __init__(self, in_features: int, hidden_features: int, dropout: float = 0.12, depth: int = 1):
+         super().__init__()
+         # fc1 and fc2 are replaced by CoLA_Linear
+         self.fc1 = CoLA_Linear(in_features, hidden_features * 2, rank=in_features // 4)
+         self.fc2 = CoLA_Linear(hidden_features, in_features, rank=hidden_features // 4)
+         self.dropout = progressive_dropout(dropout, depth)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x_proj = self.fc1(x)
+         x1, x2 = x_proj.chunk(2, dim=-1)
+         x_out = x1 * F.silu(x2)
+         x_out = self.dropout(x_out)
+         return self.fc2(x_out)
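A minimal, torch-free sketch of the gating computed in `SwiGLU.forward` on a toy vector (the `silu` helper reimplements `F.silu`; the CoLA projections and dropout are omitted):

```python
import math

def silu(x):
    # SiLU (a.k.a. swish): x * sigmoid(x), same as F.silu
    return x * (1.0 / (1.0 + math.exp(-x)))

def swiglu_gate(x_proj):
    # Split the doubled projection into value (x1) and gate (x2) halves,
    # then gate elementwise: x1 * silu(x2), mirroring
    # `x1, x2 = x_proj.chunk(2, dim=-1)` followed by `x1 * F.silu(x2)`
    half = len(x_proj) // 2
    x1, x2 = x_proj[:half], x_proj[half:]
    return [a * silu(b) for a, b in zip(x1, x2)]

out = swiglu_gate([1.0, 2.0, 0.0, -1.0])
# silu(0) == 0, so the first value is fully suppressed;
# silu(-1) < 0, so the second value is scaled and sign-flipped.
```

The gate half can zero out or even negate the value half, which is what makes SwiGLU more expressive than a plain ReLU MLP of the same width.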
+
+ ############################################
+ # FANFORMER BLOCK: LAYER WITH ATTENTION AND MLP (Decoder-Only)
+ ############################################
+ class FANformerLayer(nn.Module):
+     """
+     Transformer layer implementation with FANformer.
+     Similar to RegularTransformerLayer but uses FANformerMultiheadAttention.
+     """
+     def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, dropout: float = 0.12,
+                  layer_index: int = 1, num_gqa_groups: Optional[int] = None,
+                  is_first_layer: bool = False, p: float = 0.15):
+         super().__init__()
+         self.is_first_layer = is_first_layer
+
+         # In HybridNorm*, the first block uses Pre-Norm in MHA.
+         # FANformerMultiheadAttention is used instead of RegularMultiheadAttention.
+         self.attn = FANformerMultiheadAttention(
+             embed_dim, num_heads, dropout=dropout, use_rope=True,
+             layer_index=layer_index, num_gqa_groups=num_gqa_groups,
+             use_pre_norm=is_first_layer, p=p
+         )
+
+         # GatedResidual replaced with HyperConnections for attention
+         self.hyper_conn_attn = HyperConnections(
+             embed_dim,
+             expansion_rate=2,
+             dropout=dropout,
+             depth=layer_index
+         )
+
+         # Post-Norm for the FFN (HybridNorm)
+         self.ffn_norm = nn.RMSNorm(embed_dim, eps=1e-5)
+         self.mlp = SwiGLU(embed_dim, ff_dim, dropout, depth=1)
+
+         # GatedResidual replaced with HyperConnections for the FFN
+         self.hyper_conn_mlp = HyperConnections(
+             embed_dim,
+             expansion_rate=2,
+             dropout=dropout,
+             depth=layer_index
+         )
+
+         # Final Post-Norm (HybridNorm)
+         self.post_ffn_norm = nn.RMSNorm(embed_dim, eps=1e-5)
+
+     def _attn_forward(self, x: torch.Tensor, tgt_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+         """Attention part without HyperConnections"""
+         return self.attn(x, tgt_mask)
+
+     def _ffn_forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Feed-forward part without HyperConnections"""
+         ffn_input = self.ffn_norm(x)
+         return self.mlp(ffn_input)
+
+     def _post_ffn_norm_forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Final normalization"""
+         return self.post_ffn_norm(x)
+
+     def forward(self, x: torch.Tensor, tgt_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+         """Forward pass with selective checkpointing"""
+         # Helper to apply checkpointing while forcing the return type
+         def apply_checkpoint(func, *args) -> torch.Tensor:
+             # cast tells the type checker explicitly that
+             # checkpoint.checkpoint returns a tensor
+             return cast(torch.Tensor, checkpoint.checkpoint(func, *args, use_reentrant=False))
+
+         # Attention block with HybridNorm. The first block (Pre-Norm + QKV-Norm)
+         # and the remaining blocks (QKV-Norm) share this code path, since the
+         # norm placement is selected inside the attention module via use_pre_norm.
+         attention_output = apply_checkpoint(self._attn_forward, x, tgt_mask)
+         attention_output = F.dropout(attention_output, p=self.hyper_conn_attn.dropout.p, training=self.training)
+         hidden_states = self.hyper_conn_attn(x, attention_output)
+
+         # Apply checkpointing to the feed-forward
+         ffn_output = apply_checkpoint(self._ffn_forward, hidden_states)
+
+         # Apply dropout to the FFN output
+         ffn_output = F.dropout(ffn_output, p=self.hyper_conn_mlp.dropout.p, training=self.training)
+
+         # Apply HyperConnections
+         hidden_states = self.hyper_conn_mlp(hidden_states, ffn_output)
+
+         # Apply checkpointing to the final normalization
+         output = apply_checkpoint(self._post_ffn_norm_forward, hidden_states)
+
+         return output
+
+ ############################################
+ # FANFORMER DECODER WITH RECURRENT DEPTH (Decoder-Only)
+ ############################################
+ class FANformerDecoder(nn.Module):
+     """
+     FANformer decoder implementation with recurrent depth.
+     Simplified version with direct skip connections and no gates.
+     """
+     def __init__(self, num_layers: int, embed_dim: int, num_heads: int, ff_dim: int, dropout: float = 0.12,
+                  num_gqa_groups: Optional[int] = None, p: float = 0.15,
+                  use_checkpoint: bool = True, skip_every: int = 3):
+         super().__init__()
+         self.use_checkpoint = use_checkpoint
+         self.skip_every = skip_every
+         self.embed_dim = embed_dim
+
+         # Create the FANformer layers, treating the first block specially (HybridNorm*)
+         self.layers = nn.ModuleList()
+         for i in range(num_layers):
+             is_first_layer = (i == 0)  # Identify the first block for HybridNorm*
+             self.layers.append(
+                 FANformerLayer(
+                     embed_dim, num_heads, ff_dim,
+                     dropout=dropout * (1 + i * 0.035),
+                     layer_index=i + 1,
+                     num_gqa_groups=num_gqa_groups,
+                     is_first_layer=is_first_layer,
+                     p=p
+                 )
+             )
+
+         num_skips = num_layers // skip_every
+
+         # Keep the dropouts but remove the gates and normalizations
+         self.skip_dropouts = nn.ModuleList([
+             progressive_dropout(dropout * 0.8, depth=i + 1)
+             for i in range(num_skips)
+         ])
+
+         # Keep the final normalizations
+         self.dropout = progressive_dropout(dropout, depth=1)
+         self.layer_norm = nn.RMSNorm(embed_dim, eps=1e-5)
+
+     def forward(self, tgt: torch.Tensor, tgt_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+         output = tgt
+         layer_states = []
+
+         for i, layer in enumerate(self.layers):
+             if i % self.skip_every == 0:
+                 layer_states.append(output)
+
+             # Release cached CUDA memory every 4 layers
+             if i > 0 and i % 4 == 0:
+                 torch.cuda.empty_cache()
+
+             # Simply call the standard forward method
+             output = layer(output, tgt_mask)
+
+             if (i + 1) % self.skip_every == 0 and i // self.skip_every < len(self.skip_dropouts):
+                 skip_idx = i // self.skip_every
+
+                 # Get the saved skip state
+                 skip_state = layer_states[skip_idx]
+
+                 # Apply dropout directly (no normalization or gates)
+                 skip_state_dropped = self.skip_dropouts[skip_idx](skip_state)
+
+                 # Combine directly, without gates
+                 output = output + skip_state_dropped
+
+         # Final normalizations
+         output = self.dropout(output)
+         output = self.layer_norm(output)
+
+         return output
+
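The save/add pattern in the decoder loop above can be enumerated in plain Python. This hypothetical `skip_schedule` helper (not part of the model) just mirrors the two modulo tests in `FANformerDecoder.forward`: the running state is saved before layers where `i % skip_every == 0` and a saved state is added back after layers where `(i + 1) % skip_every == 0`:

```python
def skip_schedule(num_layers, skip_every):
    # Layers whose input state is saved into layer_states
    saves = [i for i in range(num_layers) if i % skip_every == 0]
    # Layers after which a saved state is re-added (bounded by num_skips,
    # matching `i // self.skip_every < len(self.skip_dropouts)`)
    num_skips = num_layers // skip_every
    adds = [i for i in range(num_layers)
            if (i + 1) % skip_every == 0 and i // skip_every < num_skips]
    return saves, adds

saves, adds = skip_schedule(12, 3)
# saves -> [0, 3, 6, 9]: state captured before layers 0, 3, 6, 9
# adds  -> [2, 5, 8, 11]: each captured state re-added two layers later
```

So every skip connection spans exactly `skip_every` layers: the state saved before layer 3k is added back after layer 3k+2.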
+ ############################################
+ # TEXT-ONLY MODEL (DECODER-ONLY)
+ ############################################
+ from typing import Optional
+ from transformers import PretrainedConfig
+ from transformers import GenerationConfig  # NEW import
+ from transformers.generation.utils import GenerationMixin
+ from transformers.modeling_outputs import CausalLMOutput
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class MultiModalModel(nn.Module,
+                       PyTorchModelHubMixin,
+                       GenerationMixin):
+     """
+     FANformer compatible with generate() and PyTorchModelHubMixin.
+     """
+     config_class = FanConfig
+     model_type = "fanformer"
+     main_input_name = "input_ids"
+     _supports_cache_class = False   # NEW: avoids the current error
+     _supports_static_cache = False  # NEW: future check
+
+     def __init__(self,  # signature unchanged
+                  config: Optional[FanConfig] = None,
+                  vocab_size: Optional[int] = None, embed_dim: Optional[int] = None,
+                  max_seq_len: Optional[int] = None, num_heads: Optional[int] = None,
+                  num_decoder_layers: Optional[int] = None, ff_dim: Optional[int] = None,
+                  dropout: float = 0.12, num_gqa_groups: Optional[int] = None,
+                  p: float = 0.15, tie_weights: bool = True, **kwargs):
+         super().__init__()
+
+         # --- Normalize the input (as before) ---
+         if config is not None:
+             self.config = config
+             vocab_size, embed_dim = config.vocab_size, config.embed_dim
+             max_seq_len, num_heads = config.max_seq_len, config.num_heads
+             num_decoder_layers, ff_dim = config.num_decoder_layers, config.ff_dim
+             dropout, num_gqa_groups = config.dropout, config.num_gqa_groups
+             p, tie_weights = config.p, config.tie_weights
+         else:
+             self.config = FanConfig(
+                 vocab_size=vocab_size, embed_dim=embed_dim,
+                 max_seq_len=max_seq_len, num_heads=num_heads,
+                 num_decoder_layers=num_decoder_layers, ff_dim=ff_dim,
+                 dropout=dropout, num_gqa_groups=num_gqa_groups,
+                 p=p, tie_weights=tie_weights,
+             )
+
+         # --- NEW line: default generation settings ---
+         # (GenerationConfig() would also work, but from_model_config copies
+         # useful parameters such as eos_token_id, pad_token_id, etc.)
+         self.generation_config = GenerationConfig.from_model_config(self.config)
+
+         # --- rest of the constructor unchanged ---
+         self.embed_dim = embed_dim
+         self.epsilon = 1e-5
+         self.dropout_rate = dropout
+
+         self.decoder_embedding = nn.Embedding(vocab_size, embed_dim)
+         init_embedding(self.decoder_embedding)
+         self.emb_dropout = progressive_dropout(dropout, depth=1)
+         self.decoder_input_norm = nn.RMSNorm(embed_dim, eps=self.epsilon)
+
+         self.decoder = FANformerDecoder(
+             num_decoder_layers, embed_dim, num_heads, ff_dim,
+             dropout=dropout, num_gqa_groups=num_gqa_groups, p=p
+         )
+
+         self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
+         if tie_weights:
+             self.lm_head.weight = self.decoder_embedding.weight
+
+     @property  # NEW (re-added)
+     def device(self):
+         # Same behavior as in PreTrainedModel
+         return next(self.parameters()).device
+
+     def can_generate(self) -> bool:
+         """Tells GenerationMixin that the model is valid for .generate()"""
+         return True
+
+     # GenerationMixin hooks -------------
+     def get_input_embeddings(self):
+         return self.decoder_embedding
+
+     def set_input_embeddings(self, value):
+         self.decoder_embedding = value
+
+     def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
+         x = self.decoder_embedding(input_ids).to(self.decoder_embedding.weight.dtype)
+         x = self.emb_dropout(x)
+         x = self.decoder_input_norm(x)
+         hidden = self.decoder(x, tgt_mask=attention_mask)
+         logits = self.lm_head(hidden)
+
+         loss = None
+         if labels is not None:
+             # Shift logits and labels for causal LM
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             loss = F.cross_entropy(
+                 shift_logits.view(-1, shift_logits.size(-1)),
+                 shift_labels.view(-1),
+                 ignore_index=-100  # standard in Transformers
+             )
+
+         return CausalLMOutput(loss=loss, logits=logits)
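The shift in the loss above aligns position t's logits with token t+1: drop the last logit position and the first label so the pairs line up, exactly as `logits[..., :-1, :]` versus `labels[..., 1:]`. A torch-free sketch of that index alignment on toy data:

```python
def shift_for_causal_lm(logit_positions, labels):
    # Position t's logits must predict token t+1, so pair
    # logit_positions[:-1] with labels[1:]
    return list(zip(logit_positions[:-1], labels[1:]))

pairs = shift_for_causal_lm([0, 1, 2, 3], ["the", "cat", "sat", "down"])
# -> [(0, "cat"), (1, "sat"), (2, "down")]
```

Note the last position produces no training pair (there is no next token), and the first token is never a target, which is why the shifted tensors are one element shorter than the sequence.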
special_tokens_map.json ADDED
@@ -0,0 +1,34 @@
+ {
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>"
+   ],
+   "bos_token": {
+     "content": "<|im_start|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eos_token": {
+     "content": "<|im_end|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<|im_end|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,155 @@
+ {
+   "add_prefix_space": false,
+   "added_tokens_decoder": {
+     "0": {
+       "content": "<|endoftext|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "<|im_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "<|im_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "3": {
+       "content": "<repo_name>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "4": {
+       "content": "<reponame>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "5": {
+       "content": "<file_sep>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "6": {
+       "content": "<filename>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "7": {
+       "content": "<gh_stars>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "8": {
+       "content": "<issue_start>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "9": {
+       "content": "<issue_comment>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "10": {
+       "content": "<issue_closed>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "11": {
+       "content": "<jupyter_start>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "12": {
+       "content": "<jupyter_text>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "13": {
+       "content": "<jupyter_code>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "14": {
+       "content": "<jupyter_output>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "15": {
+       "content": "<jupyter_script>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "16": {
+       "content": "<empty_output>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>"
+   ],
+   "bos_token": "<|im_start|>",
+   "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "<|im_end|>",
+   "extra_special_tokens": {},
+   "model_max_length": 8192,
+   "pad_token": "<|im_end|>",
+   "tokenizer_class": "GPT2Tokenizer",
+   "unk_token": "<|endoftext|>",
+   "vocab_size": 49152
+ }
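For reference, the `chat_template` above is ChatML-style. This small pure-Python `render_chat` (a hypothetical helper, not part of the repository) re-implements the Jinja logic: a default SmolLM system prompt is prepended when the first message is not a system turn, each turn is wrapped in `<|im_start|>`/`<|im_end|>`, and an assistant prompt is optionally appended:

```python
def render_chat(messages, add_generation_prompt=False):
    # Mirrors the Jinja chat_template from tokenizer_config.json
    out = ""
    if messages and messages[0]["role"] != "system":
        out += ("<|im_start|>system\nYou are a helpful AI assistant named "
                "SmolLM, trained by Hugging Face<|im_end|>\n")
    for m in messages:
        out += f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n"
    if add_generation_prompt:
        out += "<|im_start|>assistant\n"
    return out

text = render_chat([{"role": "user", "content": "Hi"}], add_generation_prompt=True)
```

In practice the same string is produced by `tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)` once this tokenizer config is loaded.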
vocab.json ADDED
The diff for this file is too large to render. See raw diff