Initial FanFormer checkpoint with architecture and README
- README.md +116 -3
- config.json +20 -0
- merges.txt +0 -0
- model.safetensors +3 -0
- model_architecture.py +867 -0
- special_tokens_map.json +34 -0
- tokenizer.json +0 -0
- tokenizer_config.json +155 -0
- vocab.json +0 -0
README.md
CHANGED
@@ -1,3 +1,116 @@
# FanConections: Advanced Neural Connections for Language Modeling

[FanConections on Hugging Face](https://huggingface.co/KitsuVp/FanConections)
[PyTorch](https://pytorch.org/)

FanConections is an advanced language model architecture that enhances traditional transformers with specialized neural connection mechanisms and efficient computational techniques. The model incorporates unique components, including Fourier-inspired analysis, to better capture complex patterns and periodicities within language.

## Model Description

FanConections introduces several key architectural innovations:

- **Fourier-Inspired Neural Processing (FAN Components)**: These components help the model represent repeating or cyclical patterns in language (e.g., common phrasings, structural recurrences) by transforming part of each projection with cosine and sine functions, as in Fourier analysis.
- **Compressed Linear Layers (CoLA)**: CoLA layers reduce the parameter count of linear projections by factoring large weight matrices into smaller, low-rank approximations, akin to summarizing a large dataset by its most essential components.
- **Hybrid Normalization**: Combines Pre-Normalization with Query-Key-Value (QKV) normalization to improve training stability and model performance.
- **HyperConnections**: Sophisticated residual connections that go beyond simple skips. Learnable static and input-dependent dynamic parameters decide how information from different parts of the network is combined, improving gradient flow and the modeling of long-range dependencies.
- **Optimized Flash Attention**: Highly efficient attention kernels, combined with adaptive normalization techniques, speed up computation and reduce memory usage.
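To make the CoLA point concrete, here is a minimal NumPy sketch of how a low-rank factorization shrinks a projection's parameter count. It is illustrative only, not the repository's `CoLA_Linear` implementation (which uses learned `nn.Linear` factors and GELU; `tanh` stands in for the activation here):

```python
import numpy as np

d_in, d_out = 768, 2048
rank = d_in // 4  # CoLA default: rank = in_features // 4

full_params = d_in * d_out               # one full-rank weight matrix
A = np.random.randn(rank, d_in) * 0.02   # down-projection: d_in -> rank
B = np.random.randn(d_out, rank) * 0.02  # up-projection:   rank -> d_out
lowrank_params = A.size + B.size         # far fewer parameters than full_params

# h' = B(sigma(A x)), with a nonlinearity between the two factors
x = np.random.randn(d_in)
h = B @ np.tanh(A @ x)
assert h.shape == (d_out,)
```

With these dimensions the factored form uses 540,672 parameters versus 1,572,864 for the full matrix, roughly a 3x reduction.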

### Key Features

- **Parameter Efficiency**: Thoughtful design choices, like CoLA layers, lead to a more compact model.
- **Enhanced Pattern Recognition**: FAN components are designed to improve the modeling of periodic or recurrent structures in text.
- **Improved Training Stability**: Advanced normalization and connection strategies contribute to a smoother training process.
- **High-Quality Outputs**: Aims to generate more coherent and contextually relevant text by better understanding underlying language patterns.
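A minimal NumPy sketch of the FAN-style projection behind the pattern-recognition claim: part of the output passes through cos/sin of a linear map (the periodic branch) and the rest through a standard linear map. This is illustrative; the repository's `CoLA_FAN` additionally factors each branch into low-rank pairs:

```python
import numpy as np

def fan_projection(x, W_p, W_np):
    """Periodic branch (cos/sin of a linear map) concatenated with a
    standard linear branch, mirroring the CoLA_FAN output layout."""
    p = x @ W_p   # periodic pre-activation
    r = x @ W_np  # regular (non-periodic) branch
    return np.concatenate([np.cos(p), np.sin(p), r], axis=-1)

d_in, d_out, p_frac = 16, 32, 0.15
p_dim = int(d_out * p_frac)  # dimensions fed through cos/sin
W_p = np.random.randn(d_in, p_dim) * 0.02
W_np = np.random.randn(d_in, d_out - 2 * p_dim) * 0.02

y = fan_projection(np.random.randn(4, d_in), W_p, W_np)
assert y.shape == (4, d_out)  # cos(p), sin(p), and r together fill d_out
```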

## Training Data

The FanConections model was pre-trained on a dataset of **900 million tokens**. The training corpus was a carefully curated mix:

- **90% FineWeb**: A large-scale, high-quality dataset of web content, focusing on educational material.
- **10% FineMath 4+**: A specialized dataset containing mathematical text and reasoning.

This blend provides the model with a broad understanding of general language as well as more structured, logical text.

## Usage

You can use this model with the Transformers library:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the tokenizer and model (trust_remote_code is required for the custom architecture)
tokenizer = AutoTokenizer.from_pretrained("KitsuVp/FanConections")
model = AutoModelForCausalLM.from_pretrained("KitsuVp/FanConections", trust_remote_code=True)
model.eval()  # set the model to evaluation mode

# Example input text
input_text = "The FanConections architecture is designed to"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids

# Move inputs to the same device as the model if using a CUDA-enabled GPU:
# model.to("cuda")
# input_ids = input_ids.to("cuda")

# Generate text with recommended parameters
outputs = model.generate(
    input_ids,
    max_length=120,          # maximum length of the generated sequence
    top_p=0.92,              # nucleus sampling: keep the smallest token set whose probability mass exceeds 0.92
    top_k=50,                # keep only the 50 most likely next tokens
    temperature=0.75,        # controls randomness: lower is less random
    num_return_sequences=1,  # number of sequences to generate
    do_sample=True,          # use sampling; set to False for greedy decoding
    pad_token_id=tokenizer.eos_token_id,  # important for open-ended generation
)

# Decode and print the generated text
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)
```

## Model Architecture Details

The FanConections model implements a decoder-only transformer architecture with several novel components:

1. **FAN Components (CoLA_FAN)**: Specialized layers that integrate Fourier-inspired transformations directly into the linear projections (particularly for Query, Key, and Value in attention), letting the model capture and exploit periodic or cyclical information in the input data.
2. **Low-Rank Matrix Factorization (CoLA_Linear & CoLA_FAN)**: Both `CoLA_Linear` (used in MLPs) and `CoLA_FAN` (used in attention) reduce computational cost and parameter count by approximating large weight matrices with the product of two smaller, lower-rank matrices.
3. **HyperConnections**: An advanced form of residual connection. Instead of a simple addition, HyperConnections use learnable parameters (both static and dynamically computed from the input) to combine outputs from previous layers with the current layer's computation more flexibly, which helps train deeper networks and manage information flow.
4. **RoPE Positional Embeddings**: Rotary Positional Embeddings inject positional information by rotating pairs of embedding dimensions, offering better relative-position awareness.
5. **Progressive Dropout**: A dropout strategy in which the drop probability increases with layer depth, providing stronger regularization for deeper parts of the model.
6. **Flash Attention with Unpadding**: Optimized attention computation (FlashAttention) combined with unpadding/padding of variable-length sequences to maximize GPU utilization.
7. **Muon Optimizer**: A custom optimizer used during pre-training, combining Newton-Schulz orthogonalization for matrix parameters with an AdamW-like update for the other parameters.
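The RoPE rotation in item 4 can be sketched in NumPy, mirroring the even/odd-pair rotation that `apply_rope_vectorized` performs in `model_architecture.py` (illustrative: single head, no batch dimension):

```python
import numpy as np

def rope(x, base=10000.0):
    """Rotate (even, odd) feature pairs by position-dependent angles."""
    T, d = x.shape
    inv_freq = 1.0 / base ** (np.arange(0, d, 2) / d)  # (d/2,) frequencies
    ang = np.arange(T)[:, None] * inv_freq[None, :]    # (T, d/2) angles
    cos, sin = np.cos(ang), np.sin(ang)
    out = np.empty_like(x)
    out[:, 0::2] = x[:, 0::2] * cos - x[:, 1::2] * sin  # rotated even dims
    out[:, 1::2] = x[:, 0::2] * sin + x[:, 1::2] * cos  # rotated odd dims
    return out

q = np.random.randn(6, 8)
r = rope(q)
assert r.shape == q.shape
# Position 0 has angle 0, so it is left unrotated
assert np.allclose(r[0], q[0])
```

Because each step is a 2-D rotation, the transform preserves vector norms while encoding position in the relative angles between tokens.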

## Training

The model's pre-training involved:

- Distributed training across multiple GPUs.
- The specialized **Muon optimizer**, which applies Newton-Schulz orthogonalization to matrix parameters and an AdamW-like update to the others.
- Progressive learning-rate scheduling.
- Mixed-precision (bfloat16) training for speed and memory efficiency.
- Strategic gradient checkpointing to manage memory consumption when training on long sequences.
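For intuition on the Newton-Schulz step mentioned above, here is the classical cubic iteration in NumPy, which drives a matrix toward the nearest orthogonal matrix. This is an illustration only: the actual Muon optimizer applies a tuned quintic variant to gradient updates, not this exact iteration:

```python
import numpy as np

def newton_schulz_orthogonalize(G, steps=20):
    """Classical Newton-Schulz iteration: repeatedly applies
    X <- 1.5*X - 0.5*X*X^T*X, which converges to an orthogonal matrix
    once the spectral norm of X is at most 1."""
    X = G / np.linalg.norm(G, 2)  # spectral-norm normalization ensures convergence
    for _ in range(steps):
        X = 1.5 * X - 0.5 * X @ X.T @ X
    return X

# A diagonal example: each singular value is pushed toward 1
Q = newton_schulz_orthogonalize(np.diag([0.9, 0.5, 0.3]))
assert np.allclose(Q @ Q.T, np.eye(3), atol=1e-4)
```

Orthogonalizing the update roughly equalizes its singular values, so no single direction in parameter space dominates a step.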

## Limitations

- **Context Window**: The model has a fixed context window (1024 tokens in the provided configuration) and cannot process information beyond this limit in a single pass.
- **Domain Specificity**: While trained on a diverse dataset, performance might be suboptimal on highly specialized or out-of-distribution content.
- **Potential for Hallucinations**: Like all language models, FanConections can generate text that is factually incorrect, nonsensical, or misleading.
- **Bias**: The model may reflect biases present in its training data.

## Citation

If you use FanConections or its architecture in your research, please cite:

```bibtex
@misc{fanconections2025,
  author       = {Kitsun},
  title        = {FanConections: Advanced Neural Connections for Language Modeling},
  year         = {2025},
  publisher    = {HuggingFace},
  howpublished = {\url{https://huggingface.co/KitsuVp/FanConections}}
}
```

## License

This model is released under the [Apache 2.0 License](https://www.apache.org/licenses/LICENSE-2.0).
config.json
ADDED
@@ -0,0 +1,20 @@
{
  "dropout": 0.101,
  "embed_dim": 768,
  "ff_dim": 2048,
  "max_seq_len": 1024,
  "num_decoder_layers": 12,
  "num_gqa_groups": 6,
  "num_heads": 12,
  "p": 0.14,
  "tie_weights": true,
  "vocab_size": 49152,
  "model_type": "fanformer",
  "architectures": [
    "MultiModalModel"
  ],
  "auto_map": {
    "AutoConfig": "model_architecture.FanConfig",
    "AutoModelForCausalLM": "model_architecture.MultiModalModel"
  }
}
merges.txt
ADDED
The diff for this file is too large to render. See raw diff.
model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:607ebeab78df2738e7039a379d2e0b022cdf42f7ff9b675f38be06d91f72a160
size 331514552
model_architecture.py
ADDED
@@ -0,0 +1,867 @@
import math
import os
import random
from typing import Any, Dict, List, Optional, cast

import numpy as np
import torch
import torch.nn as nn
from torch.nn import RMSNorm
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
# Add these imports at the top of the file
from huggingface_hub import PyTorchModelHubMixin
from transformers import PretrainedConfig

# Add this configuration class
class FanConfig(PretrainedConfig):
    model_type = "fanformer"

    def __init__(
        self,
        vocab_size=32000,
        embed_dim=768,
        max_seq_len=1024,
        num_heads=12,
        num_decoder_layers=12,
        ff_dim=2048,
        dropout=0.12,
        num_gqa_groups=6,
        p=0.15,
        tie_weights=True,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.num_decoder_layers = num_decoder_layers
        self.ff_dim = ff_dim
        self.dropout = dropout
        self.num_gqa_groups = num_gqa_groups
        self.p = p
        self.tie_weights = tie_weights

############################################
# LAYER INITIALIZATION FUNCTIONS
############################################
def init_linear(layer: nn.Linear, random_factor: float = 0.02):
    gain = nn.init.calculate_gain('linear') * (1.0 + random.uniform(-random_factor, random_factor))
    nn.init.xavier_uniform_(layer.weight, gain=gain)
    if layer.bias is not None:
        nn.init.zeros_(layer.bias)

def init_embedding(embedding: nn.Embedding):
    nn.init.normal_(embedding.weight, mean=0.0, std=0.02)

def init_gate_parameter(gate: torch.Tensor, a: float = -0.02, b: float = 0.02):
    nn.init.uniform_(gate, a=a, b=b)

############################################
# NEW LAYER: CoLA (LOW-RANK LINEAR LAYER)
############################################
class CoLA_Linear(nn.Module):
    """
    Linear layer following the CoLA proposal (standard form).
    Replaces the full-rank operation W*x with:
        h' = B(σ(Ax))
    where A and B are low-rank matrices and σ is a nonlinear activation.

    By default, rank = in_features // 4.
    """
    def __init__(self, in_features: int, out_features: int, rank: Optional[int] = None, activation=F.gelu):
        super().__init__()
        if rank is None:
            rank = in_features // 4
        self.rank = rank
        self.activation = activation
        # The two low-rank projections
        self.A = nn.Linear(in_features, rank, bias=False)
        self.B = nn.Linear(rank, out_features, bias=True)
        init_linear(self.A)
        init_linear(self.B)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.B(self.activation(self.A(x)))

############################################
# NEW LAYER: CoLA_FAN (LINEAR LAYER WITH FOURIER ANALYSIS FOR FANFORMER)
############################################
class CoLA_FAN(nn.Module):
    """
    CoLA layer with Fourier analysis for FANformer.
    Combines the efficiency of CoLA with FANformer's periodicity modeling.

    This implementation omits internal dropout because regularization is already
    applied in the layers above (FANformerMultiheadAttention and flash attention).
    This avoids excessive regularization that could limit the model's capacity
    to learn.

    Args:
        in_features: input dimension
        out_features: output dimension
        rank: rank for the CoLA compression (default: in_features // 4)
        p: fraction of the output dimension devoted to periodic modeling (default: 0.15)
        activation: activation function for the projections
        depth: depth of the layer in the network (kept for compatibility)
    """
    def __init__(self, in_features: int, out_features: int, rank: Optional[int] = None,
                 p: float = 0.15, activation=F.gelu, dropout: float = 0.0, depth: int = 1):
        super().__init__()
        if rank is None:
            rank = in_features // 4
        self.rank = rank
        self.activation = activation
        self.p = p

        # Dimensions of the periodic and non-periodic components
        p_dim = int(out_features * p)         # periodic component dimension (before cos/sin)
        non_p_dim = out_features - 2 * p_dim  # non-periodic component dimension

        # Projections for the periodic component
        self.A_p = nn.Linear(in_features, rank, bias=False)
        self.B_p = nn.Linear(rank, p_dim, bias=False)  # no bias for the periodic transform

        # Projections for the non-periodic component (standard CoLA)
        self.A_np = nn.Linear(in_features, rank, bias=False)
        self.B_np = nn.Linear(rank, non_p_dim, bias=True)

        # Internal dropout is removed to avoid excessive regularization,
        # since dropout is already applied in higher layers (FANformerMultiheadAttention)

        # Initialization
        init_linear(self.A_p)
        init_linear(self.B_p)
        init_linear(self.A_np)
        init_linear(self.B_np)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Periodic component (no dropout)
        p_activation = self.activation(self.A_p(x))
        p_proj = self.B_p(p_activation)

        # Non-periodic component (no dropout)
        np_activation = self.activation(self.A_np(x))
        np_proj = self.B_np(np_activation)

        # Combine the Fourier transforms (cos/sin) with the regular component
        return torch.cat([torch.cos(p_proj), torch.sin(p_proj), np_proj], dim=-1)

############################################
# UTILITY: PROGRESSIVE DROPOUT
############################################
def progressive_dropout(p: float, depth: int) -> nn.Dropout:
    """
    Progressive dropout whose probability grows logarithmically with depth.

    Args:
        p (float): base dropout probability
        depth (int): depth of the layer

    Returns:
        nn.Dropout: dropout module with the adjusted probability
    """
    if p == 0.0:
        return nn.Dropout(0.0)

    # Logarithm base (tunable as needed)
    base = 1.4

    # Use a logarithm for slower growth in deep layers
    return nn.Dropout(p * (1 + math.log(depth + 1, base) * 0.04))

############################################
# UTILITIES: UNIFIED ROPE WITH PRECOMPUTATION
############################################
def get_rope_buffer(seq_len: int, head_dim: int, device: torch.device, dtype: torch.dtype):
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    pos = torch.arange(seq_len, device=device).float().unsqueeze(1)
    sinusoid_inp = pos * inv_freq.unsqueeze(0)
    cos = torch.cos(sinusoid_inp).to(dtype)
    sin = torch.sin(sinusoid_inp).to(dtype)
    cos = cos.unsqueeze(0).unsqueeze(0)
    sin = sin.unsqueeze(0).unsqueeze(0)
    return cos, sin

def apply_rope_vectorized(x: torch.Tensor) -> torch.Tensor:
    B, num_heads, T, head_dim = x.shape
    if head_dim % 2 != 0:
        raise ValueError("head_dim must be even for RoPE")
    cos, sin = get_rope_buffer(T, head_dim, x.device, x.dtype)
    x_reshaped = x.view(B, num_heads, T, head_dim // 2, 2)
    x_even = x_reshaped[..., 0]
    x_odd = x_reshaped[..., 1]
    x_rotated_even = x_even * cos - x_odd * sin
    x_rotated_odd = x_even * sin + x_odd * cos
    x_rotated = torch.stack([x_rotated_even, x_rotated_odd], dim=-1)
    result = x_rotated.flatten(-2)
    return result

############################################
# GATED RESIDUALS
############################################
class HyperConnections(nn.Module):
    def __init__(self, d_model: int, expansion_rate: int = 4, dropout: float = 0.12, depth: int = 1):
        super().__init__()
        self.expansion_rate = expansion_rate

        # Use CUDA when available, otherwise fall back to CPU
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Static matrices, created directly on the device in bfloat16
        self.static_beta = nn.Parameter(torch.ones(expansion_rate, device=device, dtype=torch.bfloat16))

        # Alpha initialization following the paper, on the device in bfloat16
        init_alpha0 = torch.zeros((expansion_rate, 1), device=device, dtype=torch.bfloat16)
        init_alpha0[depth % expansion_rate, 0] = 1.

        self.static_alpha = nn.Parameter(torch.cat(
            [init_alpha0, torch.eye(expansion_rate, device=device, dtype=torch.bfloat16)], dim=1))

        # Parameters for the dynamic part, on the device in bfloat16
        self.dynamic_alpha_fn = nn.Parameter(torch.zeros((d_model, expansion_rate + 1), device=device, dtype=torch.bfloat16))
        self.dynamic_alpha_scale = nn.Parameter(torch.ones(1, device=device, dtype=torch.bfloat16) * 0.01)
        self.dynamic_beta_fn = nn.Parameter(torch.zeros((d_model), device=device, dtype=torch.bfloat16))
        self.dynamic_beta_scale = nn.Parameter(torch.ones(1, device=device, dtype=torch.bfloat16) * 0.01)

        # Normalization for stability
        self.layer_norm = nn.RMSNorm(d_model, eps=1e-5)

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # Precompute static buffers
        self.register_buffer(
            'static_alpha_expanded',
            self.static_alpha.unsqueeze(0).unsqueeze(0)
        )
        self.register_buffer(
            'static_beta_expanded',
            self.static_beta.unsqueeze(0).unsqueeze(0)
        )

    def _compute_dynamic_params(self, norm_x):
        """Compute the dynamic parameters (alpha and beta)."""
        dynamic_alpha = F.tanh(norm_x @ self.dynamic_alpha_fn) * self.dynamic_alpha_scale
        dynamic_beta = F.tanh(norm_x @ self.dynamic_beta_fn) * self.dynamic_beta_scale

        # Prepare for broadcasting
        dynamic_alpha = dynamic_alpha.unsqueeze(2)  # [B, T, 1, E+1]
        dynamic_beta = dynamic_beta.unsqueeze(2)    # [B, T, 1]

        # Combine static and dynamic parts
        alpha = self.static_alpha_expanded + dynamic_alpha  # [B, T, E, E+1]
        beta = self.static_beta_expanded + dynamic_beta    # [B, T, E]

        return alpha, beta

    def _compute_width_connection(self, x, alpha):
        """Compute the width connection."""
        alpha_t = alpha.transpose(2, 3)  # [B, T, E+1, E]
        x_expanded = x.unsqueeze(2).expand(-1, -1, self.expansion_rate, -1)  # [B, T, E, D]

        # Compute mix_h with a single einsum
        mix_h = torch.einsum('btij,btjd->btid', alpha_t, x_expanded)  # [B, T, E+1, D]
        return mix_h

    def _compute_depth_connection(self, residual, beta, mix_h):
        """Compute the depth connection and combine."""
        residual = self.dropout(residual)
        residual_expanded = residual.unsqueeze(2).expand(-1, -1, self.expansion_rate, -1)
        weighted_residual = residual_expanded * beta.unsqueeze(-1)  # [B, T, E, D]

        # Extract mix_h_rest (all rows except the first)
        mix_h_rest = mix_h[:, :, 1:, :]  # [B, T, E, D]

        # Combine and reduce
        h = weighted_residual + mix_h_rest  # [B, T, E, D]
        output = h.sum(dim=2)  # [B, T, D]

        return output

    def forward(self, x, residual):
        """Forward pass with checkpointing to save memory."""
        # Cast the inputs to bfloat16 if they are not already
        x = x.to(dtype=torch.bfloat16)
        residual = residual.to(dtype=torch.bfloat16)

        # Step 1: normalize the input (not checkpointed; low memory use)
        norm_x = self.layer_norm(x)

        # Helper that applies checkpointing and casts the return type
        def apply_checkpoint(func, *args):
            return cast(torch.Tensor, checkpoint.checkpoint(func, *args, use_reentrant=False))

        # Step 2: checkpoint the dynamic-parameter computation
        alpha, beta = apply_checkpoint(self._compute_dynamic_params, norm_x)

        # Step 3: checkpoint the width connection
        mix_h = apply_checkpoint(self._compute_width_connection, x, alpha)

        # Step 4: checkpoint the depth connection and final combination
        output = apply_checkpoint(self._compute_depth_connection, residual, beta, mix_h)

        return output

############################################
# HELPER MODULE: GQA FAN LINEAR
############################################
class GQAFANLinear(nn.Module):
    """
    GQA projection built on CoLA_FAN for FANformer.
    Splits the projection into groups, using a CoLA_FAN layer internally.

    out_features is expected to be divisible by num_heads.

    Args:
        in_features: input dimension
        out_features: output dimension
        num_heads: number of attention heads
        num_gqa_groups: number of GQA groups
        p: fraction of the dimension devoted to periodicity modelling
        divide_dim: whether to divide the dimension (default False)
    """
    def __init__(self, in_features: int, out_features: int, num_heads: int,
                 num_gqa_groups: int, p: float = 0.15, divide_dim: bool = False):
        super().__init__()
        if out_features % num_heads != 0:
            raise ValueError("out_features must be divisible by num_heads")
        self.num_heads = num_heads
        self.num_gqa_groups = num_gqa_groups
        self.rep_factor = num_heads // num_gqa_groups

        self.divide_factor = 1
        self.head_dim = (out_features // num_heads) // self.divide_factor

        self.inter_dim = num_gqa_groups * self.head_dim
        # Use CoLA_FAN instead of CoLA_Linear:
        self.linear = CoLA_FAN(in_features, self.inter_dim, rank=in_features // 4, p=p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, _ = x.shape
        out = self.linear(x)
        out = out.view(B, T, self.num_gqa_groups, self.head_dim)
        out = out.repeat(1, 1, self.rep_factor, 1)
        out = out.view(B, T, self.num_heads, self.head_dim)
        return out

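For reference, the group-to-head expansion done by `GQAFANLinear.forward` can be sketched in plain Python (no torch; the sizes below are hypothetical). Note that `Tensor.repeat` tiles the whole run of groups, unlike `repeat_interleave`:

```python
# Plain-Python sketch of how GQAFANLinear expands group projections to heads.
# Hypothetical sizes: num_heads=4, num_gqa_groups=2, so rep_factor=2.
num_heads, num_gqa_groups = 4, 2
rep_factor = num_heads // num_gqa_groups

groups = ["g0", "g1"]  # one projected vector per GQA group
# torch.Tensor.repeat along the group axis tiles the sequence (g0, g1, g0, g1);
# repeat_interleave would instead yield (g0, g0, g1, g1).
heads = groups * rep_factor
print(heads)  # ['g0', 'g1', 'g0', 'g1']
```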
############################################
# MODULE: MULTI-HEAD ATTENTION WITH FANFORMER
############################################
class FANformerMultiheadAttention(nn.Module):
    """
    Multi-head attention for FANformer.
    Normalizes Q, K and V individually and uses unpadding for better performance.
    Incorporates periodicity modelling through CoLA_FAN projections.
    """
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.12, use_rope: bool = True,
                 layer_index: int = 1, max_seq_len: int = 512, p: float = 0.15,
                 num_gqa_groups: Optional[int] = None, debug: bool = True,
                 use_pre_norm: bool = False):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.debug = debug
        self.layer_name = f"Layer_{layer_index}"
        self.layer_index = layer_index
        self.use_pre_norm = use_pre_norm
        self.p = p  # periodicity fraction

        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")

        self.head_dim = embed_dim // num_heads
        self.use_rope = use_rope

        if num_gqa_groups is None:
            num_gqa_groups = num_heads

        try:
            from flash_attn import flash_attn_func, flash_attn_varlen_func
            self.flash_attn_func = flash_attn_func
            self.flash_attn_varlen_func = flash_attn_varlen_func
        except ImportError as e:
            raise ImportError(f"Failed to initialize FlashAttention: {e}")

        # For unpadding
        try:
            from flash_attn.bert_padding import unpad_input, pad_input
            self.unpad_input = unpad_input
            self.pad_input = pad_input
        except ImportError as e:
            raise ImportError(f"Failed to import padding helpers: {e}")

        # Scale parameter initialization
        self.ssmax_scale = nn.Parameter(torch.ones(num_heads, dtype=torch.bfloat16) * 0.168)
        nn.init.uniform_(self.ssmax_scale, a=0.166, b=0.170)
        self.register_buffer('seq_scale', torch.log(torch.tensor(max_seq_len, dtype=torch.bfloat16)))

        # Input normalization (Pre-Norm in the first block, QKV-Norm in the rest)
        self.norm = nn.RMSNorm(embed_dim, eps=1e-5)

        # Dropout layers (simplified)
        self.attention_dropout = progressive_dropout(dropout, depth=1)
        # Removed: self.projection_dropout = progressive_dropout(dropout * 1.1, depth=1)
        self.output_dropout = progressive_dropout(dropout, depth=1)

        # Q, K, V projections via GQAFANLinear (FANformer implementation)
        self.Wq = GQAFANLinear(embed_dim, embed_dim, num_heads, num_gqa_groups, p=p)
        self.Wk = GQAFANLinear(embed_dim, embed_dim, num_heads, num_gqa_groups, p=p)
        self.Wv = GQAFANLinear(embed_dim, embed_dim, num_heads, num_gqa_groups, p=p)

        # Output projection (kept as CoLA_Linear)
        self.out_proj = CoLA_Linear(embed_dim, embed_dim, rank=embed_dim // 4)

    def scaled_dot_product_attention_flash_unpadded(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                                                    attention_mask: Optional[torch.Tensor] = None,
                                                    is_causal: bool = False) -> torch.Tensor:
        B, H, S, D = q.shape  # batch, heads, sequence length, head dimension

        if attention_mask is None:
            # Without an attention mask, fall back to the regular version
            return self.scaled_dot_product_attention_flash(q, k, v, mask=None, is_causal=is_causal)

        # Convert the tensors to [B, S, H, D] for unpad_input
        q_unpad = q.permute(0, 2, 1, 3)  # [B, S, H, D]
        k_unpad = k.permute(0, 2, 1, 3)  # [B, S, H, D]
        v_unpad = v.permute(0, 2, 1, 3)  # [B, S, H, D]

        # Prepare the mask: convert to bool if needed
        if attention_mask.dtype != torch.bool:
            attention_mask = attention_mask.bool()

        # Unpad the tensors
        q_unpadded, indices_q, cu_seqlens_q, max_seqlen_q, _ = self.unpad_input(q_unpad, attention_mask)
        k_unpadded, indices_k, cu_seqlens_k, max_seqlen_k, _ = self.unpad_input(k_unpad, attention_mask)
        v_unpadded, _, _, _, _ = self.unpad_input(v_unpad, attention_mask)

        # Reshape for flash_attn_varlen_func: [Total, H, D]
        q_unpadded = q_unpadded.reshape(-1, H, D)
        k_unpadded = k_unpadded.reshape(-1, H, D)
        v_unpadded = v_unpadded.reshape(-1, H, D)

        # L2-normalize Q and K for numerical stability
        q_norm = F.normalize(q_unpadded, p=2, dim=-1).to(torch.bfloat16)
        k_norm = F.normalize(k_unpadded, p=2, dim=-1).to(torch.bfloat16)

        # Adjust q with the learned scale factor
        s = self.ssmax_scale.view(1, H, 1)
        q_adjusted = q_norm * (self.seq_scale * s)

        # Softmax scale factor
        softmax_scale = 1.0 / math.sqrt(D)

        try:
            # Flash attention on the unpadded tensors
            output_unpadded = self.flash_attn_varlen_func(
                q_adjusted, k_norm, v_unpadded,
                cu_seqlens_q, cu_seqlens_k,
                max_seqlen_q, max_seqlen_k,
                dropout_p=self.attention_dropout.p,  # dropout is applied here
                softmax_scale=softmax_scale,
                causal=is_causal
            )

            # Re-apply padding
            output_padded = self.pad_input(output_unpadded, indices_q, B, S)

            # Rearrange to [B, H, S, D]
            output = output_padded.reshape(B, S, H, D).permute(0, 2, 1, 3)

            return output

        except Exception as e:
            raise RuntimeError(f"Error in flash_attn_varlen_func: {e}")

    def scaled_dot_product_attention_flash(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                                           mask: Optional[torch.Tensor] = None,
                                           is_causal: bool = False) -> torch.Tensor:
        # L2-normalize Q and K for numerical stability
        q_norm = F.normalize(q, p=2, dim=-1).to(torch.bfloat16)
        k_norm = F.normalize(k, p=2, dim=-1).to(torch.bfloat16)

        # Adjust q with the learned scale factor
        s = self.ssmax_scale.view(-1, 1, 1)
        q_adjusted = q_norm * (self.seq_scale * s)

        # Prepare tensors for Flash Attention (requires shape [B, S, H, D])
        q_trans = q_adjusted.permute(0, 2, 1, 3)
        k_trans = k_norm.permute(0, 2, 1, 3)
        v_trans = v.permute(0, 2, 1, 3)

        # Check dimensions
        if q_trans.size(-1) != k_trans.size(-1):
            raise ValueError(f"Head dimensions do not match: q={q_trans.size(-1)}, k={k_trans.size(-1)}")

        # Softmax scale factor
        softmax_scale = 1.0 / math.sqrt(q_trans.size(-1))

        try:
            # Apply Flash Attention
            output = self.flash_attn_func(
                q_trans, k_trans, v_trans,
                dropout_p=self.attention_dropout.p,  # dropout is applied here
                softmax_scale=softmax_scale,
                causal=is_causal
            )

            if output is None:
                raise ValueError("flash_attn_func returned None. Check the shapes and dtypes of the input tensors.")

            # Back to the original layout
            output = output.permute(0, 2, 1, 3)
            return output

        except Exception as e:
            raise RuntimeError(f"Error in flash_attn_func: {e}")

    def forward(self, X: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, causal: bool = True) -> torch.Tensor:
        B, T, _ = X.shape

        # HybridNorm* implementation
        if self.use_pre_norm:
            # First block: Pre-Norm before attention
            X_norm = self.norm(X)
            # Q, K, V projections with FANformer
            Q = self.Wq(X_norm)  # [B, T, num_heads, head_dim]
            K = self.Wk(X_norm)  # [B, T, num_heads, head_dim]
            V = self.Wv(X_norm)  # [B, T, num_heads, head_dim]
        else:
            # Remaining blocks: QKV-Norm
            Q = self.Wq(self.norm(X))  # [B, T, num_heads, head_dim]
            K = self.Wk(self.norm(X))  # [B, T, num_heads, head_dim]
            V = self.Wv(self.norm(X))  # [B, T, num_heads, head_dim]

        # Permute to [B, num_heads, T, head_dim]
        Q = Q.permute(0, 2, 1, 3)
        K = K.permute(0, 2, 1, 3)
        V = V.permute(0, 2, 1, 3)

        # Apply RoPE if enabled
        if self.use_rope:
            Q = apply_rope_vectorized(Q)
            K = apply_rope_vectorized(K)

        # Convert to bfloat16 for flash attention
        Q = Q.to(torch.bfloat16)
        K = K.to(torch.bfloat16)
        V = V.to(torch.bfloat16)

        # Use the unpadded path when an attention mask is given
        if attention_mask is not None:
            attn_output = self.scaled_dot_product_attention_flash_unpadded(
                Q, K, V,
                attention_mask=attention_mask,
                is_causal=causal
            )
        else:
            # Without a mask, use the regular version
            attn_output = self.scaled_dot_product_attention_flash(
                Q, K, V,
                mask=None,
                is_causal=causal
            )

        # Removed redundant dropout application:
        # attn_output = self.attention_dropout(attn_output)

        # Rearrange the output and apply the final projection
        out = attn_output.permute(0, 2, 1, 3).contiguous()
        out = out.reshape(B, T, self.embed_dim)
        out = self.output_dropout(self.out_proj(out))

        return out

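The unpadded path above hands `flash_attn_varlen_func` cumulative sequence offsets rather than a padded batch. As I understand the `unpad_input` contract, the bookkeeping amounts to the following plain-Python sketch (hypothetical two-sequence mask, no torch or flash_attn needed):

```python
# Sketch of the cu_seqlens/max_seqlen bookkeeping for variable-length attention.
mask = [[1, 1, 1, 0], [1, 1, 0, 0]]   # two sequences of real lengths 3 and 2
seqlens = [sum(row) for row in mask]

# cu_seqlens holds cumulative start offsets into the flattened token stream
cu_seqlens = [0]
for n in seqlens:
    cu_seqlens.append(cu_seqlens[-1] + n)
max_seqlen = max(seqlens)

print(cu_seqlens, max_seqlen)  # [0, 3, 5] 3
```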
############################################
# NEW MODULE: SWIGLU WITH COLA (MLP)
############################################
class SwiGLU(nn.Module):
    def __init__(self, in_features: int, hidden_features: int, dropout: float = 0.12, depth: int = 1):
        super().__init__()
        # fc1 and fc2 replaced with CoLA_Linear
        self.fc1 = CoLA_Linear(in_features, hidden_features * 2, rank=in_features // 4)
        self.fc2 = CoLA_Linear(hidden_features, in_features, rank=hidden_features // 4)
        self.dropout = progressive_dropout(dropout, depth)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_proj = self.fc1(x)
        x1, x2 = x_proj.chunk(2, dim=-1)
        x_out = x1 * F.silu(x2)
        x_out = self.dropout(x_out)
        return self.fc2(x_out)

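Element-wise, the gating in the SwiGLU block above reduces to scalar arithmetic; a minimal stdlib-only sketch:

```python
import math

def silu(x: float) -> float:
    # SiLU (a.k.a. swish): x * sigmoid(x)
    return x / (1.0 + math.exp(-x))

def swiglu(x1: float, x2: float) -> float:
    # fc1's output is split into (x1, x2); x2 gates x1 before fc2
    return x1 * silu(x2)

print(round(swiglu(2.0, 0.0), 4))  # silu(0) = 0, so the gate closes: 0.0
```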
############################################
# FANFORMER BLOCK: LAYER WITH ATTENTION AND MLP (Decoder-Only)
############################################
class FANformerLayer(nn.Module):
    """
    Transformer layer built on FANformer.
    Similar to RegularTransformerLayer but uses FANformerMultiheadAttention.
    """
    def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, dropout: float = 0.12,
                 layer_index: int = 1, num_gqa_groups: Optional[int] = None,
                 is_first_layer: bool = False, p: float = 0.15):
        super().__init__()
        self.is_first_layer = is_first_layer

        # In HybridNorm*, the first block uses Pre-Norm in MHA.
        # FANformerMultiheadAttention is used instead of RegularMultiheadAttention.
        self.attn = FANformerMultiheadAttention(
            embed_dim, num_heads, dropout=dropout, use_rope=True,
            layer_index=layer_index, num_gqa_groups=num_gqa_groups,
            use_pre_norm=is_first_layer, p=p
        )

        # GatedResidual replaced with HyperConnections for attention
        self.hyper_conn_attn = HyperConnections(
            embed_dim,
            expansion_rate=2,
            dropout=dropout,
            depth=layer_index
        )

        # Post-Norm for the FFN (HybridNorm)
        self.ffn_norm = nn.RMSNorm(embed_dim, eps=1e-5)
        self.mlp = SwiGLU(embed_dim, ff_dim, dropout, depth=1)

        # GatedResidual replaced with HyperConnections for the FFN
        self.hyper_conn_mlp = HyperConnections(
            embed_dim,
            expansion_rate=2,
            dropout=dropout,
            depth=layer_index
        )

        # Final Post-Norm (HybridNorm)
        self.post_ffn_norm = nn.RMSNorm(embed_dim, eps=1e-5)

    def _attn_forward(self, x: torch.Tensor, tgt_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Attention part, without HyperConnections"""
        return self.attn(x, tgt_mask)

    def _ffn_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Feed-forward part, without HyperConnections"""
        ffn_input = self.ffn_norm(x)
        return self.mlp(ffn_input)

    def _post_ffn_norm_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Final normalization"""
        return self.post_ffn_norm(x)

    def forward(self, x: torch.Tensor, tgt_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward pass with selective checkpointing"""
        # Helper to run a sub-computation under gradient checkpointing;
        # cast tells the type checker that checkpoint.checkpoint returns a tensor
        def apply_checkpoint(func, *args) -> torch.Tensor:
            return cast(torch.Tensor, checkpoint.checkpoint(func, *args, use_reentrant=False))

        # Attention block with HybridNorm. The first block additionally applies
        # Pre-Norm inside the attention module itself (use_pre_norm=True);
        # the remaining blocks use QKV-Norm. The flow here is identical.
        attention_output = apply_checkpoint(self._attn_forward, x, tgt_mask)
        attention_output = F.dropout(attention_output, p=self.hyper_conn_attn.dropout.p, training=self.training)
        hidden_states = self.hyper_conn_attn(x, attention_output)

        # Step 3: checkpoint the feed-forward
        ffn_output = apply_checkpoint(self._ffn_forward, hidden_states)

        # Dropout on the FFN output
        ffn_output = F.dropout(ffn_output, p=self.hyper_conn_mlp.dropout.p, training=self.training)

        # Step 4: apply HyperConnections
        hidden_states = self.hyper_conn_mlp(hidden_states, ffn_output)

        # Step 5: checkpoint the final normalization
        output = apply_checkpoint(self._post_ffn_norm_forward, hidden_states)

        return output

############################################
# FANFORMER DECODER WITH RECURRENT DEPTH (Decoder-Only)
############################################
class FANformerDecoder(nn.Module):
    """
    FANformer decoder with recurrent depth.
    Simplified version with direct, ungated skip connections.
    """
    def __init__(self, num_layers: int, embed_dim: int, num_heads: int, ff_dim: int, dropout: float = 0.12,
                 num_gqa_groups: Optional[int] = None, p: float = 0.15,
                 use_checkpoint: bool = True, skip_every: int = 3):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.skip_every = skip_every
        self.embed_dim = embed_dim

        # Build the FANformer layers; the first block gets special treatment (HybridNorm*)
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            is_first_layer = (i == 0)  # flag the first block for HybridNorm*
            self.layers.append(
                FANformerLayer(
                    embed_dim, num_heads, ff_dim,
                    dropout=dropout * (1 + i * 0.035),
                    layer_index=i + 1,
                    num_gqa_groups=num_gqa_groups,
                    is_first_layer=is_first_layer,
                    p=p
                )
            )

        num_skips = num_layers // skip_every

        # Keep the dropouts but drop the gates and normalizations
        self.skip_dropouts = nn.ModuleList([
            progressive_dropout(dropout * 0.8, depth=i + 1)
            for i in range(num_skips)
        ])

        # Final normalizations
        self.dropout = progressive_dropout(dropout, depth=1)
        self.layer_norm = nn.RMSNorm(embed_dim, eps=1e-5)

    def forward(self, tgt: torch.Tensor, tgt_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        output = tgt
        layer_states = []

        for i, layer in enumerate(self.layers):
            if i % self.skip_every == 0:
                layer_states.append(output)

            # Release cached CUDA memory every 4 layers
            if i > 0 and i % 4 == 0:
                torch.cuda.empty_cache()

            # Plain call to the layer's forward
            output = layer(output, tgt_mask)

            if (i + 1) % self.skip_every == 0 and i // self.skip_every < len(self.skip_dropouts):
                skip_idx = i // self.skip_every

                # Fetch the saved skip state
                skip_state = layer_states[skip_idx]

                # Apply dropout directly (no normalization or gates)
                skip_state_dropped = self.skip_dropouts[skip_idx](skip_state)

                # Combine directly, without gates
                output = output + skip_state_dropped

        # Final normalizations
        output = self.dropout(output)
        output = self.layer_norm(output)

        return output

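The skip bookkeeping in `FANformerDecoder.forward` snapshots the stream every `skip_every` layers and adds each snapshot back after the following `skip_every` layers have run. A plain-Python trace with hypothetical `num_layers=6`, `skip_every=3`:

```python
# Trace which snapshots are merged back after which layers.
num_layers, skip_every = 6, 3
saved, merges = [], []
for i in range(num_layers):
    if i % skip_every == 0:
        saved.append(i)  # snapshot the stream before layer i runs
    if (i + 1) % skip_every == 0:
        # after layer i, add back the snapshot taken at layer saved[i // skip_every]
        merges.append((i, saved[i // skip_every]))
print(merges)  # [(2, 0), (5, 3)]
```

So the state entering layer 0 is re-added after layer 2, and the state entering layer 3 after layer 5.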
############################################
# TEXT-ONLY MODEL (DECODER-ONLY)
############################################
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, GenerationConfig
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import CausalLMOutput

class MultiModalModel(nn.Module,
                      PyTorchModelHubMixin,
                      GenerationMixin):
    """
    FANformer compatible with generate() and PyTorchModelHubMixin.
    """
    config_class = FanConfig
    model_type = "fanformer"
    main_input_name = "input_ids"
    _supports_cache_class = False   # NEW ← avoids the current error
    _supports_static_cache = False  # NEW ← future-proofing

    def __init__(self,  # ← same signature
                 config: Optional[FanConfig] = None,
                 vocab_size: Optional[int] = None, embed_dim: Optional[int] = None,
                 max_seq_len: Optional[int] = None, num_heads: Optional[int] = None,
                 num_decoder_layers: Optional[int] = None, ff_dim: Optional[int] = None,
                 dropout: float = 0.12, num_gqa_groups: Optional[int] = None,
                 p: float = 0.15, tie_weights: bool = True, **kwargs):
        super().__init__()

        # --- Normalize the input (as before) ---
        if config is not None:
            self.config = config
            vocab_size, embed_dim = config.vocab_size, config.embed_dim
            max_seq_len, num_heads = config.max_seq_len, config.num_heads
            num_decoder_layers, ff_dim = config.num_decoder_layers, config.ff_dim
            dropout, num_gqa_groups = config.dropout, config.num_gqa_groups
            p, tie_weights = config.p, config.tie_weights
        else:
            self.config = FanConfig(
                vocab_size=vocab_size, embed_dim=embed_dim,
                max_seq_len=max_seq_len, num_heads=num_heads,
                num_decoder_layers=num_decoder_layers, ff_dim=ff_dim,
                dropout=dropout, num_gqa_groups=num_gqa_groups,
                p=p, tie_weights=tie_weights,
            )

        # --- NEW: default generation settings ---
        # GenerationConfig.from_model_config copies useful fields such as
        # eos_token_id and pad_token_id from the model config
        self.generation_config = GenerationConfig.from_model_config(self.config)

        # --- rest of the constructor unchanged ---
        self.embed_dim = embed_dim
        self.epsilon = 1e-5
        self.dropout_rate = dropout

        self.decoder_embedding = nn.Embedding(vocab_size, embed_dim)
        init_embedding(self.decoder_embedding)
        self.emb_dropout = progressive_dropout(dropout, depth=1)
        self.decoder_input_norm = nn.RMSNorm(embed_dim, eps=self.epsilon)

        self.decoder = FANformerDecoder(
            num_decoder_layers, embed_dim, num_heads, ff_dim,
            dropout=dropout, num_gqa_groups=num_gqa_groups, p=p
        )

        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
        if tie_weights:
            self.lm_head.weight = self.decoder_embedding.weight

    @property  # ← NEW (or re-add it)
    def device(self):
        # Same behaviour as PreTrainedModel
        return next(self.parameters()).device

    def can_generate(self) -> bool:
        """Tells GenerationMixin that the model is valid for .generate()"""
        return True

    # GenerationMixin hooks -------------
    def get_input_embeddings(self):
        return self.decoder_embedding

    def set_input_embeddings(self, value):
        self.decoder_embedding = value

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        x = self.decoder_embedding(input_ids).to(self.decoder_embedding.weight.dtype)
        x = self.emb_dropout(x)
        x = self.decoder_input_norm(x)
        hidden = self.decoder(x, tgt_mask=attention_mask)
        logits = self.lm_head(hidden)

        loss = None
        if labels is not None:
            # Shift logits and labels for causal LM
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100  # Transformers convention
            )

        return CausalLMOutput(loss=loss, logits=logits)
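The causal shift in the loss computation can be illustrated without torch: position t's prediction is scored against token t+1, and the last position has no target. A plain-Python sketch with hypothetical token ids:

```python
# Causal-LM label shifting: pair each position with the *next* token.
input_ids = [5, 9, 2, 7]          # hypothetical token ids
logits_positions = input_ids[:]   # stand-in: one logit row per position

shift_logits = logits_positions[:-1]  # predictions for steps 0..T-2
shift_labels = input_ids[1:]          # targets are the next tokens
pairs = list(zip(shift_logits, shift_labels))
print(pairs)  # [(5, 9), (9, 2), (2, 7)]
```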
special_tokens_map.json ADDED
@@ -0,0 +1,34 @@
{
  "additional_special_tokens": ["<|im_start|>", "<|im_end|>"],
  "bos_token": {"content": "<|im_start|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false},
  "eos_token": {"content": "<|im_end|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false},
  "pad_token": {"content": "<|im_end|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false},
  "unk_token": {"content": "<|endoftext|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false}
}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff.
tokenizer_config.json ADDED
@@ -0,0 +1,155 @@
{
  "add_prefix_space": false,
  "added_tokens_decoder": {
    "0":  {"content": "<|endoftext|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "1":  {"content": "<|im_start|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "2":  {"content": "<|im_end|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "3":  {"content": "<repo_name>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "4":  {"content": "<reponame>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "5":  {"content": "<file_sep>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "6":  {"content": "<filename>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "7":  {"content": "<gh_stars>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "8":  {"content": "<issue_start>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "9":  {"content": "<issue_comment>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "10": {"content": "<issue_closed>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "11": {"content": "<jupyter_start>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "12": {"content": "<jupyter_text>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "13": {"content": "<jupyter_code>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "14": {"content": "<jupyter_output>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "15": {"content": "<jupyter_script>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "16": {"content": "<empty_output>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}
  },
  "additional_special_tokens": ["<|im_start|>", "<|im_end|>"],
  "bos_token": "<|im_start|>",
  "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
  "clean_up_tokenization_spaces": false,
  "eos_token": "<|im_end|>",
  "extra_special_tokens": {},
  "model_max_length": 8192,
  "pad_token": "<|im_end|>",
  "tokenizer_class": "GPT2Tokenizer",
  "unk_token": "<|endoftext|>",
  "vocab_size": 49152
}
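The `chat_template` above is ChatML with the `<|im_start|>`/`<|im_end|>` tokens declared in this config. A minimal plain-Python rendering of what it produces for a single user message (no jinja2 or tokenizer required; equivalent logic, not the Hub's own renderer):

```python
# Render the ChatML prompt the template produces for one user turn,
# with add_generation_prompt=True.
messages = [{"role": "user", "content": "Hi"}]

out = ""
if messages[0]["role"] != "system":
    # The template injects a default system turn when none is provided
    out += "<|im_start|>system\nYou are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>\n"
for m in messages:
    out += "<|im_start|>" + m["role"] + "\n" + m["content"] + "<|im_end|>\n"
out += "<|im_start|>assistant\n"  # generation prompt for the assistant's reply
print(out)
```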
vocab.json ADDED
The diff for this file is too large to render. See raw diff.