Commit · 486475d
Parent(s): 4d8d059
feat: Add FLAN-T5 compatibility with relative position bias
Major changes:
- Implement T5RelativePositionBias for encoder/decoder self-attention
- T5 uses unscaled attention (no sqrt(d_k) scaling)
- Add float32 softmax path for numerical stability
- Switch to aot_eager compile backend (inductor causes NaN in decoder backward)
- Add gated-gelu activation support for T5 FFN
- Fix vocab size handling (32100 vs 32128)
- Update model configs for T5-base architecture
- Add dev/medium training configs for faster iteration
- Optimize training for ~4 min dev runs on RTX 4070
The model now correctly loads FLAN-T5-base weights and generates
coherent summaries with proper encoder-decoder architecture.
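The attention changes listed above (no √d_k scaling, an additive relative-position bias, and a float32 softmax path) can be sketched as follows. This is an illustrative NumPy sketch, not the repository's `src/models/attention.py` code:

```python
import numpy as np

def t5_attention_probs(q, k, rel_bias):
    """q, k: (seq, d_k); rel_bias: (seq, seq). Returns (seq, seq) attention probs.

    Two T5 quirks relative to vanilla attention:
      1. scores are NOT divided by sqrt(d_k);
      2. a learned relative-position bias is added to the scores.
    The softmax is upcast to float32 for numerical stability under bf16/fp16.
    """
    scores = q @ k.T            # note: no / sqrt(d_k) scaling
    scores = scores + rel_bias  # T5 relative position bias
    scores = scores.astype(np.float32)            # float32 softmax path
    scores -= scores.max(axis=-1, keepdims=True)  # stabilize exp
    w = np.exp(scores)
    return w / w.sum(axis=-1, keepdims=True)

# With uniform inputs, every position attends equally (each weight is 1/3).
q = np.ones((3, 4), dtype=np.float16)
k = np.ones((3, 4), dtype=np.float16)
bias = np.zeros((3, 3), dtype=np.float16)
probs = t5_attention_probs(q, k, bias)
```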
- .gitignore +2 -0
- README.md +103 -60
- artifacts/hf_tokenizer/special_tokens_map.json +105 -31
- artifacts/hf_tokenizer/spiece.model +3 -0
- artifacts/hf_tokenizer/tokenizer.json +0 -0
- artifacts/hf_tokenizer/tokenizer_config.json +904 -22
- configs/data/datasets.yaml +10 -1
- configs/model/base.yaml +9 -5
- configs/model/large.yaml +10 -5
- configs/model/small.yaml +8 -4
- configs/training/default.yaml +0 -20
- configs/training/dev.yaml +35 -0
- configs/training/full.yaml +22 -4
- configs/training/medium.yaml +36 -0
- configs/training/quick_test.yaml +0 -9
- docs/architecture.md +50 -37
- docs/training.md +32 -11
- outputs/evaluation_report.json +31 -32
- outputs/training_history.json +13 -84
- pyproject.toml +1 -0
- scripts/evaluate.py +22 -4
- scripts/export_model.py +2 -2
- scripts/export_tokenizer.py +51 -0
- scripts/train.py +143 -6
- src/data/dataloader.py +30 -3
- src/data/preprocessing.py +1 -1
- src/data/tokenization.py +24 -10
- src/models/attention.py +214 -83
- src/models/decoder.py +253 -61
- src/models/encoder.py +76 -14
- src/models/factory.py +242 -92
- src/models/feedforward.py +18 -15
- src/models/heads.py +22 -2
- src/models/multitask.py +21 -6
- src/models/positional_encoding.py +37 -0
- src/training/trainer.py +95 -26
- src/utils/io.py +17 -2
- tests/test_models/test_attention.py +22 -17
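The gated-gelu activation mentioned in the commit message (touched in `src/models/feedforward.py` above) can be sketched as follows; a hypothetical NumPy illustration of T5's GEGLU feed-forward, not the repository's code:

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def gated_gelu_ffn(x, w_gate, w_up, w_down):
    """T5's gated-GELU ('geglu') FFN: a GELU branch gates a linear branch
    elementwise, then the result is projected back down; no bias terms.
    x: (seq, d_model); w_gate, w_up: (d_model, d_ff); w_down: (d_ff, d_model).
    """
    return (gelu(x @ w_gate) * (x @ w_up)) @ w_down

x = np.array([[1.0, -2.0]])
eye = np.eye(2)
out = gated_gelu_ffn(x, eye, eye, eye)  # with identity weights: gelu(x) * x
```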
.gitignore
CHANGED

@@ -40,6 +40,8 @@ checkpoints/*.pt
 logs/
 *.log
 runs/
+mlruns/
+outputs/

 # Outputs
 results/
README.md
CHANGED

Updated content (additions, by hunk):

@@ -10,21 +10,55 @@ pinned: false

# LexiMind: A Multi-Task NLP Model

LexiMind is a state-of-the-art Natural Language Processing model designed for complex document understanding. It features a **custom-built Transformer architecture** initialized with weights from Google's **FLAN-T5**, combining the flexibility of a from-scratch implementation with the power of modern pre-trained models.

The model performs three sophisticated tasks simultaneously: **text summarization**, **emotion classification**, and **topic clustering**.

This project is built with industry-standard MLOps practices, including configuration management with Hydra, experiment tracking with MLflow, and containerization with Docker, making it a reproducible and scalable solution.

## Core Features

* **Abstractive Summarization:** Generates concise, coherent summaries of long-form text using encoder-decoder attention.
* **Emotion Classification:** Identifies emotions (Joy, Sadness, Anger, Fear, Love, Surprise) conveyed in a document.
* **Topic Clustering:** Classifies documents into thematic categories (World, Sports, Business, Sci/Tech).

## Model Architecture

LexiMind implements a **from-scratch Transformer** with modern architectural choices:

### Custom Transformer Features

- **Pre-Layer Normalization (Pre-LN):** RMSNorm applied before each sublayer for stable training
- **FlashAttention:** Via PyTorch 2.0's `scaled_dot_product_attention` for efficient computation
- **Learned Positional Embeddings:** Trainable position representations
- **Multi-Head Attention:** 12 heads with 768-dimensional representations
- **RMSNorm:** Modern normalization without bias (more efficient than LayerNorm)

### Pre-trained Weight Initialization

The model loads weights from **Google's FLAN-T5-base**, which provides:

- Strong language understanding from instruction tuning
- Excellent performance on summarization and classification tasks
- An encoder-decoder architecture matching our custom implementation

### Multi-Task Learning

A shared encoder-decoder backbone with task-specific heads:

- **Summarization Head:** Language modeling head with weight tying
- **Emotion Head:** Mean-pooled classification with dropout
- **Topic Head:** Mean-pooled classification with dropout

## Technical Specifications

| Component | Specification |
|-----------|--------------|
| Architecture | Encoder-Decoder Transformer |
| Pre-trained Base | google/flan-t5-base |
| Hidden Dimension | 768 |
| Encoder Layers | 12 |
| Decoder Layers | 12 |
| Attention Heads | 12 |
| FFN Dimension | 2048 |
| Normalization | RMSNorm (Pre-LN) |
| Position Encoding | Learned Embeddings |
| Max Sequence Length | 512 tokens |

## Getting Started

@@ -39,24 +73,18 @@ The model employs a multi-task learning framework, with a shared encoder-decoder

1. **Clone the repository:**
   ```bash
   git clone https://github.com/OliverPerrin/LexiMind.git
   cd LexiMind
   ```

2. **Install dependencies:**
   ```bash
   poetry install
   ```

3. **Download and preprocess data:**
   ```bash
   poetry run python scripts/download_data.py
   poetry run python scripts/preprocess_data.py
   ```

@@ -64,84 +92,99 @@ The model employs a multi-task learning framework, with a shared encoder-decoder

### Configuration

All training and model parameters are managed via Hydra. Configurations are located in the `configs/` directory.

Available configurations:

- `model=base` - FLAN-T5-base (default, 12 layers)
- `model=small` - Smaller model for testing (no pretrained weights)
- `model=large` - FLAN-T5-large (24 layers, requires more VRAM)
- `training=dev` - Quick development run
- `training=medium` - Balanced training (~2-3 hours on an RTX 4070)
- `training=full` - Full training run

### Training

```bash
# Default training with FLAN-T5-base
poetry run python scripts/train.py

# Quick development run
poetry run python scripts/train.py training=dev

# Medium training run (recommended for RTX 4070)
poetry run python scripts/train.py training=medium

# Override parameters
poetry run python scripts/train.py training.optimizer.lr=5e-5
```

Experiments are automatically tracked with MLflow. View results with `mlflow ui`.

### Evaluation

```bash
poetry run python scripts/evaluate.py --checkpoint checkpoints/best.pt
```

### Inference & Demo

```bash
# Command-line inference
poetry run python scripts/inference.py "Your text to analyze"

# Gradio web demo
poetry run python scripts/demo_gradio.py
```

## Docker

```bash
# Build
docker build -t leximind .

# Run demo
docker run -p 7860:7860 leximind
```

## Project Structure

```
├── configs/              # Hydra configuration files
│   ├── model/            # Model architectures (base, small, large)
│   ├── training/         # Training configs (dev, medium, full)
│   └── data/             # Dataset configurations
├── src/
│   ├── models/           # Custom Transformer implementation
│   │   ├── encoder.py    # TransformerEncoder with Pre-LN RMSNorm
│   │   ├── decoder.py    # TransformerDecoder with KV-cache
│   │   ├── attention.py  # Multi-Head Attention with FlashAttention
│   │   └── factory.py    # Model building with FLAN-T5 weight loading
│   ├── data/             # Data loading and preprocessing
│   ├── training/         # Training loop with mixed precision
│   └── inference/        # Inference pipeline
├── scripts/              # Entry points
├── tests/                # Unit tests
└── notebooks/            # Analysis notebooks
```

## Code Quality

* **Ruff:** Fast linting and formatting
* **MyPy:** Static type checking
* **Pre-commit hooks:** Automated quality checks

```bash
poetry run pre-commit install
```

## Performance Optimizations

- **torch.compile:** JIT compilation with Inductor backend
- **Mixed Precision:** bfloat16 training on Ampere/Ada GPUs
- **TF32:** Enabled for RTX 30xx/40xx series
- **KV-Cache:** Efficient autoregressive decoding
- **FlashAttention:** Memory-efficient attention via SDPA

## License

MIT License - see [LICENSE](LICENSE) for details.
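The mean-pooled classification heads described in the README's Multi-Task Learning section can be sketched as follows. This is a hypothetical NumPy illustration (names, shapes, and the linear head are assumptions, not the repository's `src/models/heads.py` code):

```python
import numpy as np

rng = np.random.default_rng(0)

def mean_pool(hidden, mask):
    """hidden: (seq, d); mask: (seq,) with 1 for real tokens, 0 for padding.
    Averages encoder states over non-padding positions only."""
    mask = mask[:, None].astype(hidden.dtype)
    return (hidden * mask).sum(axis=0) / np.maximum(mask.sum(), 1.0)

d_model, n_classes = 768, 6                          # 6 emotion classes per the README
W = rng.standard_normal((d_model, n_classes)) * 0.02  # hypothetical head weights
b = np.zeros(n_classes)

hidden = rng.standard_normal((5, d_model))  # toy encoder output, seq length 5
mask = np.array([1, 1, 1, 0, 0])            # last two positions are padding
logits = mean_pool(hidden, mask) @ W + b    # per-class logits
```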
artifacts/hf_tokenizer/special_tokens_map.json
CHANGED

@@ -1,50 +1,124 @@

Summary of the change:

- Added `"additional_special_tokens"`: the 100 T5 sentinel tokens `<extra_id_0>` through `<extra_id_99>`.
- Removed the previous tokenizer's `mask_token` (`<mask>`) and `sep_token` (`</s>`) entries.
- `eos_token` (`</s>`), `pad_token` (`<pad>`), and `unk_token` (`<unk>`) remain, now with `"normalized": false`.
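The "32100 vs 32128" vocab-size fix in the commit message follows from these sentinel tokens. A quick illustration; the round-up-to-a-multiple rationale is a common convention for T5 checkpoints, stated here as an assumption rather than taken from this repository:

```python
# The T5 SentencePiece model has 32000 pieces; the 100 sentinel tokens above
# bring the tokenizer vocabulary to 32100. T5 checkpoints allocate 32128
# embedding rows (32100 rounded up to a multiple of 128 for efficiency),
# so code loading FLAN-T5 weights must handle both sizes.
sentinels = [f"<extra_id_{i}>" for i in range(100)]
tokenizer_vocab = 32000 + len(sentinels)  # 32100

def round_up(n, multiple=128):
    return ((n + multiple - 1) // multiple) * multiple

embedding_rows = round_up(tokenizer_vocab)  # 32128
```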
artifacts/hf_tokenizer/spiece.model
ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86
+size 791656
artifacts/hf_tokenizer/tokenizer.json
CHANGED

The diff for this file is too large to render. See raw diff.
artifacts/hf_tokenizer/tokenizer_config.json
CHANGED

@@ -1,58 +1,940 @@

Summary of the change (the full diff runs to ~940 lines):

- `"add_prefix_space"` is now `null`.
- In `"added_tokens_decoder"`, IDs 0, 1, and 2 now map to `<pad>`, `</s>`, and `<unk>` respectively, each with `"normalized": false`.
- Entries for the T5 sentinel tokens are added from ID 32000 upward in descending sentinel order: `"32000"` → `<extra_id_99>`, `"32001"` → `<extra_id_98>`, and so on (the excerpt is cut off at `"32030"` → `<extra_id_69>`).
- Removed settings from the previous tokenizer: `"cls_token": "<s>"`, `"model_max_length": 1000000000000000019884624838656`, `"trim_offsets": true`, and the old `"tokenizer_class"` value (truncated in the source).
- Kept: `"clean_up_tokenization_spaces": false`, `"eos_token": "</s>"`, `"extra_special_tokens": {}`, `"pad_token": "<pad>"`, `"unk_token": "<unk>"`.
|
| 273 |
+
"single_word": false,
|
| 274 |
+
"special": true
|
| 275 |
+
},
|
| 276 |
+
"32031": {
|
| 277 |
+
"content": "<extra_id_68>",
|
| 278 |
+
"lstrip": false,
|
| 279 |
+
"normalized": false,
|
| 280 |
+
"rstrip": false,
|
| 281 |
+
"single_word": false,
|
| 282 |
+
"special": true
|
| 283 |
+
},
|
| 284 |
+
"32032": {
|
| 285 |
+
"content": "<extra_id_67>",
|
| 286 |
+
"lstrip": false,
|
| 287 |
+
"normalized": false,
|
| 288 |
+
"rstrip": false,
|
| 289 |
+
"single_word": false,
|
| 290 |
+
"special": true
|
| 291 |
+
},
|
| 292 |
+
"32033": {
|
| 293 |
+
"content": "<extra_id_66>",
|
| 294 |
+
"lstrip": false,
|
| 295 |
+
"normalized": false,
|
| 296 |
+
"rstrip": false,
|
| 297 |
+
"single_word": false,
|
| 298 |
+
"special": true
|
| 299 |
+
},
|
| 300 |
+
"32034": {
|
| 301 |
+
"content": "<extra_id_65>",
|
| 302 |
+
"lstrip": false,
|
| 303 |
+
"normalized": false,
|
| 304 |
+
"rstrip": false,
|
| 305 |
+
"single_word": false,
|
| 306 |
+
"special": true
|
| 307 |
+
},
|
| 308 |
+
"32035": {
|
| 309 |
+
"content": "<extra_id_64>",
|
| 310 |
+
"lstrip": false,
|
| 311 |
+
"normalized": false,
|
| 312 |
+
"rstrip": false,
|
| 313 |
+
"single_word": false,
|
| 314 |
+
"special": true
|
| 315 |
+
},
|
| 316 |
+
"32036": {
|
| 317 |
+
"content": "<extra_id_63>",
|
| 318 |
+
"lstrip": false,
|
| 319 |
+
"normalized": false,
|
| 320 |
+
"rstrip": false,
|
| 321 |
+
"single_word": false,
|
| 322 |
+
"special": true
|
| 323 |
+
},
|
| 324 |
+
"32037": {
|
| 325 |
+
"content": "<extra_id_62>",
|
| 326 |
+
"lstrip": false,
|
| 327 |
+
"normalized": false,
|
| 328 |
+
"rstrip": false,
|
| 329 |
+
"single_word": false,
|
| 330 |
+
"special": true
|
| 331 |
+
},
|
| 332 |
+
"32038": {
|
| 333 |
+
"content": "<extra_id_61>",
|
| 334 |
+
"lstrip": false,
|
| 335 |
+
"normalized": false,
|
| 336 |
+
"rstrip": false,
|
| 337 |
+
"single_word": false,
|
| 338 |
+
"special": true
|
| 339 |
+
},
|
| 340 |
+
"32039": {
|
| 341 |
+
"content": "<extra_id_60>",
|
| 342 |
+
"lstrip": false,
|
| 343 |
+
"normalized": false,
|
| 344 |
+
"rstrip": false,
|
| 345 |
+
"single_word": false,
|
| 346 |
+
"special": true
|
| 347 |
+
},
|
| 348 |
+
"32040": {
|
| 349 |
+
"content": "<extra_id_59>",
|
| 350 |
+
"lstrip": false,
|
| 351 |
+
"normalized": false,
|
| 352 |
+
"rstrip": false,
|
| 353 |
+
"single_word": false,
|
| 354 |
+
"special": true
|
| 355 |
+
},
|
| 356 |
+
"32041": {
|
| 357 |
+
"content": "<extra_id_58>",
|
| 358 |
+
"lstrip": false,
|
| 359 |
+
"normalized": false,
|
| 360 |
+
"rstrip": false,
|
| 361 |
+
"single_word": false,
|
| 362 |
+
"special": true
|
| 363 |
+
},
|
| 364 |
+
"32042": {
|
| 365 |
+
"content": "<extra_id_57>",
|
| 366 |
+
"lstrip": false,
|
| 367 |
+
"normalized": false,
|
| 368 |
+
"rstrip": false,
|
| 369 |
+
"single_word": false,
|
| 370 |
+
"special": true
|
| 371 |
+
},
|
| 372 |
+
"32043": {
|
| 373 |
+
"content": "<extra_id_56>",
|
| 374 |
+
"lstrip": false,
|
| 375 |
+
"normalized": false,
|
| 376 |
+
"rstrip": false,
|
| 377 |
+
"single_word": false,
|
| 378 |
+
"special": true
|
| 379 |
+
},
|
| 380 |
+
"32044": {
|
| 381 |
+
"content": "<extra_id_55>",
|
| 382 |
+
"lstrip": false,
|
| 383 |
+
"normalized": false,
|
| 384 |
+
"rstrip": false,
|
| 385 |
+
"single_word": false,
|
| 386 |
+
"special": true
|
| 387 |
+
},
|
| 388 |
+
"32045": {
|
| 389 |
+
"content": "<extra_id_54>",
|
| 390 |
+
"lstrip": false,
|
| 391 |
+
"normalized": false,
|
| 392 |
+
"rstrip": false,
|
| 393 |
+
"single_word": false,
|
| 394 |
+
"special": true
|
| 395 |
+
},
|
| 396 |
+
"32046": {
|
| 397 |
+
"content": "<extra_id_53>",
|
| 398 |
+
"lstrip": false,
|
| 399 |
+
"normalized": false,
|
| 400 |
+
"rstrip": false,
|
| 401 |
+
"single_word": false,
|
| 402 |
+
"special": true
|
| 403 |
+
},
|
| 404 |
+
"32047": {
|
| 405 |
+
"content": "<extra_id_52>",
|
| 406 |
+
"lstrip": false,
|
| 407 |
+
"normalized": false,
|
| 408 |
+
"rstrip": false,
|
| 409 |
+
"single_word": false,
|
| 410 |
+
"special": true
|
| 411 |
+
},
|
| 412 |
+
"32048": {
|
| 413 |
+
"content": "<extra_id_51>",
|
| 414 |
+
"lstrip": false,
|
| 415 |
+
"normalized": false,
|
| 416 |
+
"rstrip": false,
|
| 417 |
+
"single_word": false,
|
| 418 |
+
"special": true
|
| 419 |
+
},
|
| 420 |
+
"32049": {
|
| 421 |
+
"content": "<extra_id_50>",
|
| 422 |
+
"lstrip": false,
|
| 423 |
+
"normalized": false,
|
| 424 |
+
"rstrip": false,
|
| 425 |
+
"single_word": false,
|
| 426 |
+
"special": true
|
| 427 |
+
},
|
| 428 |
+
"32050": {
|
| 429 |
+
"content": "<extra_id_49>",
|
| 430 |
+
"lstrip": false,
|
| 431 |
+
"normalized": false,
|
| 432 |
+
"rstrip": false,
|
| 433 |
+
"single_word": false,
|
| 434 |
+
"special": true
|
| 435 |
+
},
|
| 436 |
+
"32051": {
|
| 437 |
+
"content": "<extra_id_48>",
|
| 438 |
+
"lstrip": false,
|
| 439 |
+
"normalized": false,
|
| 440 |
+
"rstrip": false,
|
| 441 |
+
"single_word": false,
|
| 442 |
+
"special": true
|
| 443 |
+
},
|
| 444 |
+
"32052": {
|
| 445 |
+
"content": "<extra_id_47>",
|
| 446 |
+
"lstrip": false,
|
| 447 |
+
"normalized": false,
|
| 448 |
+
"rstrip": false,
|
| 449 |
+
"single_word": false,
|
| 450 |
+
"special": true
|
| 451 |
+
},
|
| 452 |
+
"32053": {
|
| 453 |
+
"content": "<extra_id_46>",
|
| 454 |
+
"lstrip": false,
|
| 455 |
+
"normalized": false,
|
| 456 |
+
"rstrip": false,
|
| 457 |
+
"single_word": false,
|
| 458 |
+
"special": true
|
| 459 |
+
},
|
| 460 |
+
"32054": {
|
| 461 |
+
"content": "<extra_id_45>",
|
| 462 |
+
"lstrip": false,
|
| 463 |
+
"normalized": false,
|
| 464 |
+
"rstrip": false,
|
| 465 |
+
"single_word": false,
|
| 466 |
+
"special": true
|
| 467 |
+
},
|
| 468 |
+
"32055": {
|
| 469 |
+
"content": "<extra_id_44>",
|
| 470 |
+
"lstrip": false,
|
| 471 |
+
"normalized": false,
|
| 472 |
+
"rstrip": false,
|
| 473 |
+
"single_word": false,
|
| 474 |
+
"special": true
|
| 475 |
+
},
|
| 476 |
+
"32056": {
|
| 477 |
+
"content": "<extra_id_43>",
|
| 478 |
+
"lstrip": false,
|
| 479 |
+
"normalized": false,
|
| 480 |
+
"rstrip": false,
|
| 481 |
+
"single_word": false,
|
| 482 |
+
"special": true
|
| 483 |
+
},
|
| 484 |
+
"32057": {
|
| 485 |
+
"content": "<extra_id_42>",
|
| 486 |
+
"lstrip": false,
|
| 487 |
+
"normalized": false,
|
| 488 |
+
"rstrip": false,
|
| 489 |
+
"single_word": false,
|
| 490 |
+
"special": true
|
| 491 |
+
},
|
| 492 |
+
"32058": {
|
| 493 |
+
"content": "<extra_id_41>",
|
| 494 |
+
"lstrip": false,
|
| 495 |
+
"normalized": false,
|
| 496 |
+
"rstrip": false,
|
| 497 |
+
"single_word": false,
|
| 498 |
+
"special": true
|
| 499 |
+
},
|
| 500 |
+
"32059": {
|
| 501 |
+
"content": "<extra_id_40>",
|
| 502 |
+
"lstrip": false,
|
| 503 |
+
"normalized": false,
|
| 504 |
+
"rstrip": false,
|
| 505 |
+
"single_word": false,
|
| 506 |
+
"special": true
|
| 507 |
+
},
|
| 508 |
+
"32060": {
|
| 509 |
+
"content": "<extra_id_39>",
|
| 510 |
+
"lstrip": false,
|
| 511 |
+
"normalized": false,
|
| 512 |
+
"rstrip": false,
|
| 513 |
+
"single_word": false,
|
| 514 |
+
"special": true
|
| 515 |
+
},
|
| 516 |
+
"32061": {
|
| 517 |
+
"content": "<extra_id_38>",
|
| 518 |
+
"lstrip": false,
|
| 519 |
+
"normalized": false,
|
| 520 |
+
"rstrip": false,
|
| 521 |
+
"single_word": false,
|
| 522 |
+
"special": true
|
| 523 |
+
},
|
| 524 |
+
"32062": {
|
| 525 |
+
"content": "<extra_id_37>",
|
| 526 |
+
"lstrip": false,
|
| 527 |
+
"normalized": false,
|
| 528 |
+
"rstrip": false,
|
| 529 |
+
"single_word": false,
|
| 530 |
+
"special": true
|
| 531 |
+
},
|
| 532 |
+
"32063": {
|
| 533 |
+
"content": "<extra_id_36>",
|
| 534 |
+
"lstrip": false,
|
| 535 |
+
"normalized": false,
|
| 536 |
+
"rstrip": false,
|
| 537 |
+
"single_word": false,
|
| 538 |
+
"special": true
|
| 539 |
+
},
|
| 540 |
+
"32064": {
|
| 541 |
+
"content": "<extra_id_35>",
|
| 542 |
+
"lstrip": false,
|
| 543 |
+
"normalized": false,
|
| 544 |
+
"rstrip": false,
|
| 545 |
+
"single_word": false,
|
| 546 |
+
"special": true
|
| 547 |
+
},
|
| 548 |
+
"32065": {
|
| 549 |
+
"content": "<extra_id_34>",
|
| 550 |
+
"lstrip": false,
|
| 551 |
+
"normalized": false,
|
| 552 |
+
"rstrip": false,
|
| 553 |
+
"single_word": false,
|
| 554 |
+
"special": true
|
| 555 |
+
},
|
| 556 |
+
"32066": {
|
| 557 |
+
"content": "<extra_id_33>",
|
| 558 |
+
"lstrip": false,
|
| 559 |
+
"normalized": false,
|
| 560 |
+
"rstrip": false,
|
| 561 |
+
"single_word": false,
|
| 562 |
+
"special": true
|
| 563 |
+
},
|
| 564 |
+
"32067": {
|
| 565 |
+
"content": "<extra_id_32>",
|
| 566 |
+
"lstrip": false,
|
| 567 |
+
"normalized": false,
|
| 568 |
+
"rstrip": false,
|
| 569 |
+
"single_word": false,
|
| 570 |
+
"special": true
|
| 571 |
+
},
|
| 572 |
+
"32068": {
|
| 573 |
+
"content": "<extra_id_31>",
|
| 574 |
+
"lstrip": false,
|
| 575 |
+
"normalized": false,
|
| 576 |
+
"rstrip": false,
|
| 577 |
+
"single_word": false,
|
| 578 |
+
"special": true
|
| 579 |
+
},
|
| 580 |
+
"32069": {
|
| 581 |
+
"content": "<extra_id_30>",
|
| 582 |
+
"lstrip": false,
|
| 583 |
+
"normalized": false,
|
| 584 |
+
"rstrip": false,
|
| 585 |
+
"single_word": false,
|
| 586 |
+
"special": true
|
| 587 |
+
},
|
| 588 |
+
"32070": {
|
| 589 |
+
"content": "<extra_id_29>",
|
| 590 |
+
"lstrip": false,
|
| 591 |
+
"normalized": false,
|
| 592 |
+
"rstrip": false,
|
| 593 |
+
"single_word": false,
|
| 594 |
+
"special": true
|
| 595 |
+
},
|
| 596 |
+
"32071": {
|
| 597 |
+
"content": "<extra_id_28>",
|
| 598 |
+
"lstrip": false,
|
| 599 |
+
"normalized": false,
|
| 600 |
+
"rstrip": false,
|
| 601 |
+
"single_word": false,
|
| 602 |
+
"special": true
|
| 603 |
+
},
|
| 604 |
+
"32072": {
|
| 605 |
+
"content": "<extra_id_27>",
|
| 606 |
+
"lstrip": false,
|
| 607 |
+
"normalized": false,
|
| 608 |
+
"rstrip": false,
|
| 609 |
+
"single_word": false,
|
| 610 |
+
"special": true
|
| 611 |
+
},
|
| 612 |
+
"32073": {
|
| 613 |
+
"content": "<extra_id_26>",
|
| 614 |
+
"lstrip": false,
|
| 615 |
+
"normalized": false,
|
| 616 |
+
"rstrip": false,
|
| 617 |
+
"single_word": false,
|
| 618 |
+
"special": true
|
| 619 |
+
},
|
| 620 |
+
"32074": {
|
| 621 |
+
"content": "<extra_id_25>",
|
| 622 |
+
"lstrip": false,
|
| 623 |
+
"normalized": false,
|
| 624 |
+
"rstrip": false,
|
| 625 |
+
"single_word": false,
|
| 626 |
+
"special": true
|
| 627 |
+
},
|
| 628 |
+
"32075": {
|
| 629 |
+
"content": "<extra_id_24>",
|
| 630 |
+
"lstrip": false,
|
| 631 |
+
"normalized": false,
|
| 632 |
+
"rstrip": false,
|
| 633 |
+
"single_word": false,
|
| 634 |
+
"special": true
|
| 635 |
+
},
|
| 636 |
+
"32076": {
|
| 637 |
+
"content": "<extra_id_23>",
|
| 638 |
+
"lstrip": false,
|
| 639 |
+
"normalized": false,
|
| 640 |
+
"rstrip": false,
|
| 641 |
+
"single_word": false,
|
| 642 |
+
"special": true
|
| 643 |
+
},
|
| 644 |
+
"32077": {
|
| 645 |
+
"content": "<extra_id_22>",
|
| 646 |
+
"lstrip": false,
|
| 647 |
+
"normalized": false,
|
| 648 |
+
"rstrip": false,
|
| 649 |
+
"single_word": false,
|
| 650 |
+
"special": true
|
| 651 |
+
},
|
| 652 |
+
"32078": {
|
| 653 |
+
"content": "<extra_id_21>",
|
| 654 |
+
"lstrip": false,
|
| 655 |
+
"normalized": false,
|
| 656 |
+
"rstrip": false,
|
| 657 |
+
"single_word": false,
|
| 658 |
+
"special": true
|
| 659 |
+
},
|
| 660 |
+
"32079": {
|
| 661 |
+
"content": "<extra_id_20>",
|
| 662 |
+
"lstrip": false,
|
| 663 |
+
"normalized": false,
|
| 664 |
+
"rstrip": false,
|
| 665 |
+
"single_word": false,
|
| 666 |
+
"special": true
|
| 667 |
+
},
|
| 668 |
+
"32080": {
|
| 669 |
+
"content": "<extra_id_19>",
|
| 670 |
+
"lstrip": false,
|
| 671 |
+
"normalized": false,
|
| 672 |
+
"rstrip": false,
|
| 673 |
+
"single_word": false,
|
| 674 |
+
"special": true
|
| 675 |
+
},
|
| 676 |
+
"32081": {
|
| 677 |
+
"content": "<extra_id_18>",
|
| 678 |
+
"lstrip": false,
|
| 679 |
+
"normalized": false,
|
| 680 |
+
"rstrip": false,
|
| 681 |
+
"single_word": false,
|
| 682 |
+
"special": true
|
| 683 |
+
},
|
| 684 |
+
"32082": {
|
| 685 |
+
"content": "<extra_id_17>",
|
| 686 |
+
"lstrip": false,
|
| 687 |
+
"normalized": false,
|
| 688 |
+
"rstrip": false,
|
| 689 |
+
"single_word": false,
|
| 690 |
+
"special": true
|
| 691 |
+
},
|
| 692 |
+
"32083": {
|
| 693 |
+
"content": "<extra_id_16>",
|
| 694 |
+
"lstrip": false,
|
| 695 |
+
"normalized": false,
|
| 696 |
+
"rstrip": false,
|
| 697 |
+
"single_word": false,
|
| 698 |
+
"special": true
|
| 699 |
+
},
|
| 700 |
+
"32084": {
|
| 701 |
+
"content": "<extra_id_15>",
|
| 702 |
+
"lstrip": false,
|
| 703 |
+
"normalized": false,
|
| 704 |
+
"rstrip": false,
|
| 705 |
+
"single_word": false,
|
| 706 |
+
"special": true
|
| 707 |
+
},
|
| 708 |
+
"32085": {
|
| 709 |
+
"content": "<extra_id_14>",
|
| 710 |
+
"lstrip": false,
|
| 711 |
+
"normalized": false,
|
| 712 |
+
"rstrip": false,
|
| 713 |
+
"single_word": false,
|
| 714 |
+
"special": true
|
| 715 |
+
},
|
| 716 |
+
"32086": {
|
| 717 |
+
"content": "<extra_id_13>",
|
| 718 |
+
"lstrip": false,
|
| 719 |
+
"normalized": false,
|
| 720 |
+
"rstrip": false,
|
| 721 |
+
"single_word": false,
|
| 722 |
+
"special": true
|
| 723 |
+
},
|
| 724 |
+
"32087": {
|
| 725 |
+
"content": "<extra_id_12>",
|
| 726 |
+
"lstrip": false,
|
| 727 |
+
"normalized": false,
|
| 728 |
+
"rstrip": false,
|
| 729 |
+
"single_word": false,
|
| 730 |
+
"special": true
|
| 731 |
+
},
|
| 732 |
+
"32088": {
|
| 733 |
+
"content": "<extra_id_11>",
|
| 734 |
+
"lstrip": false,
|
| 735 |
+
"normalized": false,
|
| 736 |
+
"rstrip": false,
|
| 737 |
+
"single_word": false,
|
| 738 |
+
"special": true
|
| 739 |
+
},
|
| 740 |
+
"32089": {
|
| 741 |
+
"content": "<extra_id_10>",
|
| 742 |
+
"lstrip": false,
|
| 743 |
+
"normalized": false,
|
| 744 |
+
"rstrip": false,
|
| 745 |
+
"single_word": false,
|
| 746 |
+
"special": true
|
| 747 |
+
},
|
| 748 |
+
"32090": {
|
| 749 |
+
"content": "<extra_id_9>",
|
| 750 |
+
"lstrip": false,
|
| 751 |
+
"normalized": false,
|
| 752 |
+
"rstrip": false,
|
| 753 |
+
"single_word": false,
|
| 754 |
+
"special": true
|
| 755 |
+
},
|
| 756 |
+
"32091": {
|
| 757 |
+
"content": "<extra_id_8>",
|
| 758 |
+
"lstrip": false,
|
| 759 |
+
"normalized": false,
|
| 760 |
+
"rstrip": false,
|
| 761 |
+
"single_word": false,
|
| 762 |
+
"special": true
|
| 763 |
+
},
|
| 764 |
+
"32092": {
|
| 765 |
+
"content": "<extra_id_7>",
|
| 766 |
+
"lstrip": false,
|
| 767 |
+
"normalized": false,
|
| 768 |
+
"rstrip": false,
|
| 769 |
+
"single_word": false,
|
| 770 |
+
"special": true
|
| 771 |
+
},
|
| 772 |
+
"32093": {
|
| 773 |
+
"content": "<extra_id_6>",
|
| 774 |
+
"lstrip": false,
|
| 775 |
+
"normalized": false,
|
| 776 |
+
"rstrip": false,
|
| 777 |
+
"single_word": false,
|
| 778 |
+
"special": true
|
| 779 |
+
},
|
| 780 |
+
"32094": {
|
| 781 |
+
"content": "<extra_id_5>",
|
| 782 |
+
"lstrip": false,
|
| 783 |
+
"normalized": false,
|
| 784 |
+
"rstrip": false,
|
| 785 |
+
"single_word": false,
|
| 786 |
+
"special": true
|
| 787 |
+
},
|
| 788 |
+
"32095": {
|
| 789 |
+
"content": "<extra_id_4>",
|
| 790 |
+
"lstrip": false,
|
| 791 |
+
"normalized": false,
|
| 792 |
+
"rstrip": false,
|
| 793 |
+
"single_word": false,
|
| 794 |
+
"special": true
|
| 795 |
+
},
|
| 796 |
+
"32096": {
|
| 797 |
+
"content": "<extra_id_3>",
|
| 798 |
+
"lstrip": false,
|
| 799 |
+
"normalized": false,
|
| 800 |
+
"rstrip": false,
|
| 801 |
+
"single_word": false,
|
| 802 |
+
"special": true
|
| 803 |
+
},
|
| 804 |
+
"32097": {
|
| 805 |
+
"content": "<extra_id_2>",
|
| 806 |
+
"lstrip": false,
|
| 807 |
+
"normalized": false,
|
| 808 |
+
"rstrip": false,
|
| 809 |
+
"single_word": false,
|
| 810 |
+
"special": true
|
| 811 |
+
},
|
| 812 |
+
"32098": {
|
| 813 |
+
"content": "<extra_id_1>",
|
| 814 |
+
"lstrip": false,
|
| 815 |
+
"normalized": false,
|
| 816 |
+
"rstrip": false,
|
| 817 |
+
"single_word": false,
|
| 818 |
+
"special": true
|
| 819 |
+
},
|
| 820 |
+
"32099": {
|
| 821 |
+
"content": "<extra_id_0>",
|
| 822 |
+
"lstrip": false,
|
| 823 |
+
"normalized": false,
|
| 824 |
"rstrip": false,
|
| 825 |
"single_word": false,
|
| 826 |
"special": true
|
| 827 |
}
|
| 828 |
},
|
| 829 |
+
"additional_special_tokens": [
|
| 830 |
+
"<extra_id_0>",
|
| 831 |
+
"<extra_id_1>",
|
| 832 |
+
"<extra_id_2>",
|
| 833 |
+
"<extra_id_3>",
|
| 834 |
+
"<extra_id_4>",
|
| 835 |
+
"<extra_id_5>",
|
| 836 |
+
"<extra_id_6>",
|
| 837 |
+
"<extra_id_7>",
|
| 838 |
+
"<extra_id_8>",
|
| 839 |
+
"<extra_id_9>",
|
| 840 |
+
"<extra_id_10>",
|
| 841 |
+
"<extra_id_11>",
|
| 842 |
+
"<extra_id_12>",
|
| 843 |
+
"<extra_id_13>",
|
| 844 |
+
"<extra_id_14>",
|
| 845 |
+
"<extra_id_15>",
|
| 846 |
+
"<extra_id_16>",
|
| 847 |
+
"<extra_id_17>",
|
| 848 |
+
"<extra_id_18>",
|
| 849 |
+
"<extra_id_19>",
|
| 850 |
+
"<extra_id_20>",
|
| 851 |
+
"<extra_id_21>",
|
| 852 |
+
"<extra_id_22>",
|
| 853 |
+
"<extra_id_23>",
|
| 854 |
+
"<extra_id_24>",
|
| 855 |
+
"<extra_id_25>",
|
| 856 |
+
"<extra_id_26>",
|
| 857 |
+
"<extra_id_27>",
|
| 858 |
+
"<extra_id_28>",
|
| 859 |
+
"<extra_id_29>",
|
| 860 |
+
"<extra_id_30>",
|
| 861 |
+
"<extra_id_31>",
|
| 862 |
+
"<extra_id_32>",
|
| 863 |
+
"<extra_id_33>",
|
| 864 |
+
"<extra_id_34>",
|
| 865 |
+
"<extra_id_35>",
|
| 866 |
+
"<extra_id_36>",
|
| 867 |
+
"<extra_id_37>",
|
| 868 |
+
"<extra_id_38>",
|
| 869 |
+
"<extra_id_39>",
|
| 870 |
+
"<extra_id_40>",
|
| 871 |
+
"<extra_id_41>",
|
| 872 |
+
"<extra_id_42>",
|
| 873 |
+
"<extra_id_43>",
|
| 874 |
+
"<extra_id_44>",
|
| 875 |
+
"<extra_id_45>",
|
| 876 |
+
"<extra_id_46>",
|
| 877 |
+
"<extra_id_47>",
|
| 878 |
+
"<extra_id_48>",
|
| 879 |
+
"<extra_id_49>",
|
| 880 |
+
"<extra_id_50>",
|
| 881 |
+
"<extra_id_51>",
|
| 882 |
+
"<extra_id_52>",
|
| 883 |
+
"<extra_id_53>",
|
| 884 |
+
"<extra_id_54>",
|
| 885 |
+
"<extra_id_55>",
|
| 886 |
+
"<extra_id_56>",
|
| 887 |
+
"<extra_id_57>",
|
| 888 |
+
"<extra_id_58>",
|
| 889 |
+
"<extra_id_59>",
|
| 890 |
+
"<extra_id_60>",
|
| 891 |
+
"<extra_id_61>",
|
| 892 |
+
"<extra_id_62>",
|
| 893 |
+
"<extra_id_63>",
|
| 894 |
+
"<extra_id_64>",
|
| 895 |
+
"<extra_id_65>",
|
| 896 |
+
"<extra_id_66>",
|
| 897 |
+
"<extra_id_67>",
|
| 898 |
+
"<extra_id_68>",
|
| 899 |
+
"<extra_id_69>",
|
| 900 |
+
"<extra_id_70>",
|
| 901 |
+
"<extra_id_71>",
|
| 902 |
+
"<extra_id_72>",
|
| 903 |
+
"<extra_id_73>",
|
| 904 |
+
"<extra_id_74>",
|
| 905 |
+
"<extra_id_75>",
|
| 906 |
+
"<extra_id_76>",
|
| 907 |
+
"<extra_id_77>",
|
| 908 |
+
"<extra_id_78>",
|
| 909 |
+
"<extra_id_79>",
|
| 910 |
+
"<extra_id_80>",
|
| 911 |
+
"<extra_id_81>",
|
| 912 |
+
"<extra_id_82>",
|
| 913 |
+
"<extra_id_83>",
|
| 914 |
+
"<extra_id_84>",
|
| 915 |
+
"<extra_id_85>",
|
| 916 |
+
"<extra_id_86>",
|
| 917 |
+
"<extra_id_87>",
|
| 918 |
+
"<extra_id_88>",
|
| 919 |
+
"<extra_id_89>",
|
| 920 |
+
"<extra_id_90>",
|
| 921 |
+
"<extra_id_91>",
|
| 922 |
+
"<extra_id_92>",
|
| 923 |
+
"<extra_id_93>",
|
| 924 |
+
"<extra_id_94>",
|
| 925 |
+
"<extra_id_95>",
|
| 926 |
+
"<extra_id_96>",
|
| 927 |
+
"<extra_id_97>",
|
| 928 |
+
"<extra_id_98>",
|
| 929 |
+
"<extra_id_99>"
|
| 930 |
+
],
|
| 931 |
"clean_up_tokenization_spaces": false,
|
|
|
|
| 932 |
"eos_token": "</s>",
|
| 933 |
+
"extra_ids": 100,
|
| 934 |
"extra_special_tokens": {},
|
| 935 |
+
"model_max_length": 512,
|
|
|
|
| 936 |
"pad_token": "<pad>",
|
| 937 |
+
"sp_model_kwargs": {},
|
| 938 |
+
"tokenizer_class": "T5Tokenizer",
|
|
|
|
| 939 |
"unk_token": "<unk>"
|
| 940 |
}
|
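The tokenizer diff above adds T5's 100 sentinel tokens `<extra_id_0>`…`<extra_id_99>` at the top of the 32100-token vocabulary in reverse order, which is why `"32099"` maps to `<extra_id_0>` and `"32014"` maps to `<extra_id_85>`. A small hypothetical helper (not part of this repo) makes the layout explicit:

```python
def sentinel_token_id(k: int, vocab_size: int = 32100, extra_ids: int = 100) -> int:
    """Return the token id of T5 sentinel <extra_id_k>.

    Sentinels occupy the top `extra_ids` slots of the vocabulary in
    reverse order: <extra_id_0> gets the highest id, <extra_id_99> the lowest.
    """
    if not 0 <= k < extra_ids:
        raise ValueError(f"k must be in [0, {extra_ids})")
    return vocab_size - 1 - k

print(sentinel_token_id(0))   # 32099, matching "32099": "<extra_id_0>" in the diff
print(sentinel_token_id(85))  # 32014, matching "32014": "<extra_id_85>" in the diff
```

This also explains the commit's "vocab size handling (32100 vs 32128)" note: the tokenizer defines 32100 tokens, while the model's embedding matrix is commonly padded to 32128 (a multiple of 128) for hardware efficiency.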
configs/data/datasets.yaml
CHANGED
@@ -9,7 +9,7 @@ processed:
   topic: data/processed/topic
   books: data/processed/books
 tokenizer:
-  pretrained_model_name:
+  pretrained_model_name: google/flan-t5-base
   max_length: 512
   lower: false
 downloads:
@@ -20,6 +20,15 @@ downloads:
   - name: pride_and_prejudice
     url: https://www.gutenberg.org/cache/epub/1342/pg1342.txt
     output: data/raw/books/pride_and_prejudice.txt
+  - name: frankenstein
+    url: https://www.gutenberg.org/cache/epub/84/pg84.txt
+    output: data/raw/books/frankenstein.txt
+  - name: sherlock_holmes
+    url: https://www.gutenberg.org/cache/epub/1661/pg1661.txt
+    output: data/raw/books/sherlock_holmes.txt
+  - name: moby_dick
+    url: https://www.gutenberg.org/cache/epub/2701/pg2701.txt
+    output: data/raw/books/moby_dick.txt
 emotion:
   dataset: dair-ai/emotion
 topic:
configs/model/base.yaml
CHANGED
@@ -1,8 +1,12 @@
+# FLAN-T5-base architecture
+# 12 encoder layers, 12 decoder layers, 768 hidden dim
 d_model: 768
-num_encoder_layers:
+num_encoder_layers: 12
-num_decoder_layers:
+num_decoder_layers: 12
 num_attention_heads: 12
-ffn_dim:
+ffn_dim: 2048  # T5 uses d_ff = 2048 for base model
-dropout: 0.
+dropout: 0.1
+activation: gated-gelu  # T5/FLAN-T5 uses gated-gelu (GELU activation with gating, not SwiGLU)
 use_pretrained: true
-pretrained_model_name:
+pretrained_model_name: google/flan-t5-base
+use_relative_position_bias: true  # T5 uses relative position bias instead of absolute embeddings
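The `use_relative_position_bias` flag above enables the T5-style scheme from the commit message ("T5RelativePositionBias for encoder/decoder self-attention"). The core of that scheme is a bucketing function that maps a signed relative distance to one of a small number of learned bias slots, exact for short distances and logarithmic beyond. A minimal scalar sketch of the standard T5 bucketing (the real implementation vectorizes this over a full position matrix in PyTorch):

```python
import math

def relative_position_bucket(relative_position: int,
                             bidirectional: bool = True,
                             num_buckets: int = 32,
                             max_distance: int = 128) -> int:
    """Map a signed relative position to a bucket index, T5-style.

    Bidirectional (encoder) attention splits the buckets between negative
    and positive offsets; half of each side's buckets are exact, the rest
    cover larger distances on a log scale up to max_distance.
    """
    bucket = 0
    if bidirectional:
        num_buckets //= 2
        if relative_position > 0:
            bucket += num_buckets          # positive offsets use the upper half
        relative_position = abs(relative_position)
    else:
        relative_position = -min(relative_position, 0)  # causal: only look back
    max_exact = num_buckets // 2
    if relative_position < max_exact:
        return bucket + relative_position  # small distances get exact buckets
    # larger distances: logarithmically spaced buckets, clamped at the end
    log_bucket = max_exact + int(
        math.log(relative_position / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    )
    return bucket + min(log_bucket, num_buckets - 1)

print(relative_position_bucket(0))   # 0
print(relative_position_bucket(-1))  # 1
print(relative_position_bucket(1))   # 17
```

Each bucket indexes a learned per-head bias added to the attention logits, which is also why T5 drops the usual 1/sqrt(d_k) scaling noted in the commit message.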
configs/model/large.yaml
CHANGED
@@ -1,6 +1,11 @@
-
-
-
-
-
+# FLAN-T5-large architecture
+# 24 encoder layers, 24 decoder layers, 1024 hidden dim
+d_model: 1024
+num_encoder_layers: 24
+num_decoder_layers: 24
+num_attention_heads: 16
+ffn_dim: 2816  # T5-large uses 2816
 dropout: 0.1
+activation: gated-gelu  # T5/FLAN-T5 uses gated-gelu (GELU with gating)
+use_pretrained: true
+pretrained_model_name: google/flan-t5-large
configs/model/small.yaml
CHANGED
@@ -1,6 +1,10 @@
-
-
-
-
+# Small config for quick testing (no pretrained weights)
+d_model: 512
+num_encoder_layers: 6
+num_decoder_layers: 6
+num_attention_heads: 8
 ffn_dim: 1024
 dropout: 0.1
+activation: gated-gelu  # Use gated-gelu for T5 compatibility
+use_pretrained: false
+pretrained_model_name: google/flan-t5-small
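All three model configs select `activation: gated-gelu`, the T5 v1.1/FLAN-T5 feed-forward variant: a GELU-activated gate multiplied elementwise with a linear "up" projection, then projected back down. A minimal scalar sketch of the gating pattern (the real FFN applies this per hidden unit with weight matrices `wi_0`, `wi_1`, `wo` and no biases; the weight names here are illustrative):

```python
import math

def gelu(x: float) -> float:
    # tanh approximation of GELU, commonly used in T5's gated-gelu FFN
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def gated_gelu_ffn(x: float, w_gate: float, w_up: float, w_down: float) -> float:
    """Gated-GELU on a single scalar feature: gelu(x*w_gate) * (x*w_up), projected by w_down."""
    return gelu(x * w_gate) * (x * w_up) * w_down

v = gated_gelu_ffn(1.0, 1.0, 1.0, 1.0)  # reduces to gelu(1.0), roughly 0.841
```

The gate lets the network learn which hidden units pass information through, which is the same idea as SwiGLU but with GELU as the gating nonlinearity, as the config comments note.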
configs/training/default.yaml
DELETED
@@ -1,20 +0,0 @@
-dataloader:
-  batch_size: 8
-  shuffle: true
-optimizer:
-  name: adamw
-  lr: 3.0e-5
-  weight_decay: 0.01  # L2 regularization to prevent overfitting
-scheduler:
-  name: cosine
-  warmup_steps: 500
-trainer:
-  max_epochs: 4  # Reduced from 5 to prevent overfitting
-  gradient_clip_norm: 1.0
-  validation_samples: 3
-  validation_max_length: 128
-  label_smoothing: 0.1  # Smooths target distribution for better generalization
-  task_weights:
-    summarization: 1.0
-    emotion: 1.0
-    topic: 1.0
configs/training/dev.yaml
ADDED
@@ -0,0 +1,35 @@
+# Development/Testing Configuration for FLAN-T5-base
+# Fast iteration for debugging and testing changes
+# Training time: ~10 minutes on RTX 4070 with aot_eager backend
+# Use: python scripts/train.py training=dev
+
+dataloader:
+  batch_size: 8
+  shuffle: true
+  num_workers: 4  # Reduced to avoid overhead
+  pin_memory: true
+
+optimizer:
+  name: adamw
+  lr: 5.0e-5  # Higher LR for faster convergence on small dataset
+  weight_decay: 0.01
+
+scheduler:
+  name: cosine
+  warmup_steps: 50  # Fewer warmup steps for short training
+
+trainer:
+  max_epochs: 1  # Single epoch for quick testing
+  gradient_clip_norm: 1.0
+  gradient_accumulation_steps: 1  # No accumulation for speed
+  validation_max_length: 64  # Shorter for faster validation
+  label_smoothing: 0.1
+  task_weights:
+    summarization: 1.0
+    emotion: 1.0
+    topic: 1.0
+
+# Development-specific settings - optimized for ~10 min total
+max_train_samples: 2000  # Reduced for faster iteration
+max_val_samples: 200
+validation_frequency: 1000  # Validate once during training
configs/training/full.yaml
CHANGED
@@ -1,12 +1,30 @@
+# Full Training Configuration for FLAN-T5-base
+# Complete training run on all data
+# Training time: ~6-8 hours on RTX 4070
+# Use: python scripts/train.py training=full
+
 dataloader:
-  batch_size:
+  batch_size: 11  # Reduced for FLAN-T5-base (12 layers)
   shuffle: true
+  num_workers: 8
+  pin_memory: true
+
 optimizer:
   name: adamw
   lr: 2.0e-5
+  weight_decay: 0.01
+
 scheduler:
   name: cosine
-  warmup_steps: 1000
+  warmup_steps: 1000  # More warmup for full training
+
 trainer:
-  max_epochs:
+  max_epochs: 4
-  gradient_clip_norm:
+  gradient_clip_norm: 0.5
+  gradient_accumulation_steps: 6  # Effective batch size = 11 * 6 = 66
+  validation_max_length: 128
+  label_smoothing: 0.1
+  task_weights:
+    summarization: 1.0
+    emotion: 1.0
+    topic: 1.0
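The `gradient_accumulation_steps` settings in these training configs trade memory for effective batch size: the trainer runs several small forward/backward passes, averages their gradients, and only then takes one optimizer step, so the effective batch size is `batch_size * gradient_accumulation_steps`. A framework-agnostic sketch of the accumulation bookkeeping (plain Python standing in for per-micro-batch gradients; the actual trainer would scale losses and call `optimizer.step()`):

```python
def accumulate_and_step(micro_batch_grads, accumulation_steps):
    """Average gradients over each window of micro-batches.

    Returns the averaged gradient used at each optimizer step, i.e. the
    gradient an optimizer would see with the larger effective batch.
    """
    steps = []
    running_sum = 0.0
    for i, grad in enumerate(micro_batch_grads, start=1):
        running_sum += grad                    # backward() accumulates into .grad
        if i % accumulation_steps == 0:
            steps.append(running_sum / accumulation_steps)  # optimizer.step()
            running_sum = 0.0                  # optimizer.zero_grad()
    return steps

print(accumulate_and_step([1, 2, 3, 4, 5, 6], 3))  # [2.0, 5.0]: two optimizer steps
```

So with `batch_size: 11` and `gradient_accumulation_steps: 6`, the optimizer sees gradients averaged over 66 examples per step.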
configs/training/medium.yaml
ADDED
@@ -0,0 +1,36 @@
+# Medium Configuration for FLAN-T5-base
+# Balanced approach - good results in reasonable time
+# Training time: ~2-3 hours on RTX 4070
+# Use: python scripts/train.py training=medium
+# Note: FLAN-T5-base has 12 layers (vs BART's 6), may need smaller batch
+
+dataloader:
+  batch_size: 11  # Reduced for FLAN-T5-base (12 layers uses more VRAM)
+  shuffle: true
+  num_workers: 8
+  pin_memory: true
+
+optimizer:
+  name: adamw
+  lr: 2.0e-5  # Slightly lower for larger model
+  weight_decay: 0.01
+
+scheduler:
+  name: cosine
+  warmup_steps: 500  # More warmup for larger model
+
+trainer:
+  max_epochs: 3
+  gradient_clip_norm: 0.5
+  gradient_accumulation_steps: 4  # Effective batch size = 11 * 4 = 44
+  validation_max_length: 128
+  label_smoothing: 0.1
+  task_weights:
+    summarization: 1.0
+    emotion: 1.0
+    topic: 1.0
+
+# Medium dataset - good representative sample
+max_train_samples: 50000
+max_val_samples: 5000
+validation_frequency: 5000
configs/training/quick_test.yaml
DELETED
@@ -1,9 +0,0 @@
-dataloader:
-  batch_size: 2
-  shuffle: false
-optimizer:
-  name: adamw
-  lr: 1.0e-4
-trainer:
-  max_epochs: 1
-  gradient_clip_norm: 0.5
docs/architecture.md
CHANGED

```diff
@@ -8,50 +8,63 @@ LexiMind couples a from-scratch Transformer implementation with a modern data an
 2. **Model Composition** – the bespoke encoder/decoder stack with task heads assembled via
    `MultiTaskModel`, plus `models.factory.build_multitask_model` to rebuild the network from
    configuration files.
+3. **Inference & Serving** – a multi-task pipeline capable of summarization, emotion, and topic classification; surfaced through a CLI and FastAPI service with a Gradio UI.
 
 ## Custom Transformer Stack
+
+The custom Transformer is designed with **modern architectural choices** while maintaining compatibility with pre-trained weights from Google's **FLAN-T5**.
+
+### Architecture Highlights
+- **Pre-Layer Normalization (Pre-LN):** RMSNorm applied *before* each sublayer for stable training
+- **RMSNorm:** More efficient than LayerNorm (no mean computation, no bias parameters)
+- **FlashAttention:** Via PyTorch 2.0's `F.scaled_dot_product_attention` for O(N) memory
+- **Learned Positional Embeddings:** Trainable position representations (randomly initialized)
+- **Multi-Head Attention:** 12 heads with optional LoRA adapters and RoPE support
+
+### Weight Loading from FLAN-T5
+The `factory.py` module loads weights from FLAN-T5-base, which uses a compatible Pre-LN architecture:
+- **Token embeddings:** Shared between encoder and decoder
+- **Attention projections:** Q, K, V, O weights (bias initialized to zero since T5 has no attention bias)
+- **FFN weights:** `wi_1` → `linear1`, `wo` → `linear2` (T5 uses gated FFN; we use the up/down projections)
+- **RMSNorm weights:** Direct transfer (both use RMSNorm without bias)
+- **LM head:** Loaded from T5's `lm_head`
+
+**Note:** T5 uses *relative position bias* computed in attention, not absolute embeddings. Our learned positional embeddings are randomly initialized and train quickly during fine-tuning.
+
+### File Structure
+- `src/models/encoder.py` – TransformerEncoder with Pre-LN RMSNorm blocks
+- `src/models/decoder.py` – TransformerDecoder with KV-cache for efficient generation
+- `src/models/attention.py` – Multi-Head Attention with FlashAttention, LoRA, and RoPE support
+- `src/models/heads.py` – ClassificationHead (mean pooling) and LMHead (with weight tying)
+- `src/models/multitask.py` – Routes inputs to task-specific heads
+- `src/models/factory.py` – Builds models and loads FLAN-T5 weights
 
 ## Data, Tokenization, and Preprocessing
+- `src/data/tokenization.py` wraps `AutoTokenizer` (configured for FLAN-T5) to provide tensor-aware batching and helper utilities for decoder input shifting.
+- `src/data/preprocessing.py` introduces `TextPreprocessor`, layering a `BasicTextCleaner` with optional scikit-learn transformers.
+- `src/data/dataset.py` and `src/data/dataloader.py` define strongly typed dataset containers and collators.
+
+### T5 Tokenizer Differences
+- **Vocab size:** 32,128 tokens (SentencePiece)
+- **Special tokens:** pad=0, eos=1 (no explicit BOS; decoder starts with pad token)
+- **Subword tokenization:** Unigram-based (vs BART's BPE)
 
 ## Training Pipeline
+- `src/training/trainer.py` coordinates multi-task optimization with:
+  - Mixed precision training (bfloat16 on Ampere/Ada GPUs)
+  - Gradient accumulation for larger effective batch sizes
+  - Per-task loss weighting and label smoothing
+  - **torch.compile:** JIT compilation with Inductor backend for 20-40% speedup
+- Metrics in `src/training/metrics.py` include accuracy, multi-label F1, and ROUGE-like overlap
 
 ## Inference & Serving
+- `src/inference/pipeline.py` exposes summarization, emotion, and topic predictions with shared pre-processing, generation, and thresholding logic.
+- `src/inference/factory.py` rebuilds the full pipeline using the exported tokenizer artifact
+- The CLI (`scripts/inference.py`) drives the pipeline from the command line
+- Gradio demo (`scripts/demo_gradio.py`) provides a web interface
-
-## Gradio UI Roadmap
-- The inference pipeline returns structured outputs that are already suitable for a web UI.
-- Planned steps for a Gradio demo:
-  1. Wrap `InferencePipeline.batch_predict` inside Gradio callbacks for text input.
-  2. Display summaries alongside emotion tag chips and topic confidence bars.
-  3. Surface token-level attention visualizations by extending the pipeline to emit decoder attention maps (hooks already exist in the decoder).
-- Documentation and code paths were structured to keep the Gradio integration isolated in a future `src/ui/gradio_app.py` module without altering core logic.
 
 ## Key Decisions
+- **Custom Transformer + Pre-trained Weights:** Building from scratch demonstrates deep understanding while leveraging FLAN-T5's language knowledge
+- **Pre-LN RMSNorm:** Modern architecture used by LLaMA, T5 v1.1, and other 2023-2025 models
+- **Tokenizer Artifact Preference:** Inference favors `artifacts/hf_tokenizer` for reproducibility
+- **Sklearn-friendly Preprocessing:** Optional `TransformerMixin` injection for custom cleaning
-- **Documentation Alignment** – the `docs/` folder mirrors the structure requested, capturing design reasoning and paving the way for future diagrams in `docs/images`.
```
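The RMSNorm and Pre-LN points above can be illustrated framework-free. This is a minimal sketch of the two ideas, not the project's `src/models` code; the function names here are illustrative:

```python
import math

def rms_norm(x, weight, eps=1e-6):
    """RMSNorm: rescale by the root-mean-square of x.

    Unlike LayerNorm there is no mean subtraction and no bias parameter,
    which is why the T5/LLaMA-style norm is cheaper.
    """
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for w, v in zip(weight, x)]

def pre_ln_block(x, weight, sublayer):
    """Pre-LN residual block: normalize *before* the sublayer, then add the residual."""
    return [xi + yi for xi, yi in zip(x, sublayer(rms_norm(x, weight)))]

hidden = [3.0, 4.0]
normed = rms_norm(hidden, [1.0, 1.0])
# RMS of [3, 4] is sqrt(12.5) ≈ 3.5355, so normed ≈ [0.8485, 1.1314]
```

Because the residual path bypasses the norm, gradients flow through the identity branch unimpeded, which is the stability argument for Pre-LN made above.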
docs/training.md
CHANGED

````diff
@@ -7,10 +7,10 @@
 `text` and `emotions` arrays. The dataset owns a `MultiLabelBinarizer` for consistent encoding.
 - **Topic Classification** – single-label categorical samples with `text` and `topic` fields, encoded via `LabelEncoder`.
 
+Paths and tokenizer defaults are configured in `configs/data/datasets.yaml`. The tokenizer section chooses the Hugging Face backbone (`google/flan-t5-base` by default) and maximum length. Gutenberg book downloads are controlled via the `downloads.books` list (each entry includes `name`, `url`, and `output`).
 
 ## Dataloaders & Collators
+- `SummarizationCollator` encodes encoder/decoder inputs, prepares decoder input IDs via `Tokenizer.prepare_decoder_inputs`, and masks padding tokens with `-100` for loss computation. Note: FLAN-T5 uses `pad_token_id=0` and `decoder_start_token_id=0`.
 - `EmotionCollator` applies the dataset's `MultiLabelBinarizer`, returning dense float tensors suitable for `BCEWithLogitsLoss`.
 - `TopicCollator` emits integer class IDs via the dataset's `LabelEncoder` for `CrossEntropyLoss`.
@@ -18,8 +18,13 @@
 
 ## Model Assembly
 - `src/models/factory.build_multitask_model` rebuilds the encoder, decoder, and heads from the tokenizer metadata and YAML config. This factory is used both during training and inference to eliminate drift between environments.
+- Pretrained weights are loaded from FLAN-T5 using `_load_t5_weights()`, which transfers:
+  - Shared token embeddings (with proper scaling)
+  - Attention projections (q, k, v, o) for all encoder/decoder layers
+  - FFN weights (wi_0, wi_1 for gated activation, wo for output)
+  - Layer normalization parameters (mapped from T5's RMSNorm)
 - The model wraps:
+  - Transformer encoder/decoder stacks with **Pre-LN RMSNorm** architecture.
   - LM head tied to decoder embeddings for summarization.
   - Mean-pooled classification heads for emotion and topic tasks.
@@ -39,21 +44,37 @@
 - `src/utils/io.save_state` stores model weights; checkpoints live under `checkpoints/`.
 - `artifacts/labels.json` captures the ordered emotion/topic vocabularies immediately after
   training. This file is required for inference so class indices map back to human-readable labels.
+- The tokenizer is exported to `artifacts/hf_tokenizer/` for reproducible vocabularies using `scripts/export_tokenizer.py`.
 
 ## Running Training
 1. Ensure processed datasets are available (see `data/processed/` structure).
+2. Export the FLAN-T5 tokenizer: `python scripts/export_tokenizer.py`
+3. Choose a configuration (e.g., `configs/training/dev.yaml`) for hyperparameters and data splits.
+4. Instantiate the tokenizer via `TokenizerConfig` and build datasets/dataloaders.
+5. Use `build_multitask_model` to construct the model with FLAN-T5 weights, create an optimizer, and run
    `Trainer.fit(train_loaders, val_loaders)`.
+6. Save checkpoints and update `artifacts/labels.json` with the dataset label order.
+
+```bash
+# Quick start
+python scripts/export_tokenizer.py       # Export FLAN-T5 tokenizer
+python scripts/train.py training=dev     # Run dev training (2 epochs)
+python scripts/train.py training=medium  # Run medium training (5 epochs)
+python scripts/train.py training=full    # Run full training (10 epochs)
+```
+
+## Why FLAN-T5?
+LexiMind's custom Transformer uses **Pre-LN (normalization before sublayers)** with **RMSNorm**. This modern architectural choice provides:
+- Better gradient flow during training
+- Improved training stability
+- Faster convergence
+
+FLAN-T5 uses the same Pre-LN RMSNorm architecture, making weight transfer straightforward. The previously used BART (Post-LN LayerNorm) had a fundamental architectural mismatch that caused training issues.
+
+> **Note:** T5's relative position bias is NOT transferred. The model uses learned positional encodings which train from scratch. This is fine since positional information is task-specific.
 
 ## Future Enhancements
 - Integrate curriculum scheduling or task-balanced sampling once empirical results dictate.
 - Capture attention maps during training to support visualization in the planned Gradio UI.
 - Leverage the optional `sklearn_transformer` hook in `TextPreprocessor` for lemmatization or domain-specific normalization when datasets require it.
+- Experiment with FLAN-T5-large for improved performance on longer sequences.
````
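The collator behavior described above — T5's pad token doubling as the decoder start token, and `-100` masking padding out of the loss — can be sketched with stand-alone helpers. The helper names are illustrative, not the project's actual `prepare_decoder_inputs` API:

```python
PAD_ID = 0           # FLAN-T5: pad_token_id == decoder_start_token_id == 0
EOS_ID = 1           # FLAN-T5 end-of-sequence token
IGNORE_INDEX = -100  # label value that cross-entropy losses skip

def shift_right(target_ids):
    """Build decoder inputs: prepend the start token, drop the last target token."""
    return [PAD_ID] + target_ids[:-1]

def mask_padding(target_ids):
    """Replace padding in the labels with -100 so it contributes no loss."""
    return [IGNORE_INDEX if t == PAD_ID else t for t in target_ids]

# A padded target sequence ending with EOS and two pad tokens:
labels = [37, 1712, 5, EOS_ID, PAD_ID, PAD_ID]
decoder_inputs = shift_right(labels)  # [0, 37, 1712, 5, 1, 0]
loss_labels = mask_padding(labels)    # [37, 1712, 5, 1, -100, -100]
```

The decoder therefore sees the target sequence delayed by one position (teacher forcing), while the loss only scores the real tokens.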
outputs/evaluation_report.json
CHANGED

New contents (the previous report held placeholder values from a `validation_dummy` split):

```json
{
  "split": "test",
  "summarization": {
    "rouge_like": 0.031742493938280825,
    "bleu": 0.0008530696741094626
  },
  "emotion": {
    "f1_macro": 0.42543327808380127
  },
  "topic": {
    "accuracy": 0.3325,
    "classification_report": {
      "Business": {
        "precision": 0.24772065955383124,
        "recall": 0.6721052631578948,
        "f1-score": 0.3620127569099929,
        "support": 1900
      },
      "Sci/Tech": {
        "precision": 0.4942170818505338,
        "recall": 0.5847368421052631,
        "f1-score": 0.5356798457087754,
        "support": 1900
      },
      "Sports": {
        "precision": 0.9473684210526315,
        "recall": 0.018947368421052633,
        "f1-score": 0.03715170278637771,
        "support": 1900
      },
      "World": {
        "precision": 0.6477987421383647,
        "recall": 0.05421052631578947,
        "f1-score": 0.10004856726566294,
        "support": 1900
      },
      "macro avg": {
        "precision": 0.5842762261488403,
        "recall": 0.3325,
        "f1-score": 0.2587232181677022,
        "support": 7600
      }
    }
  }
}
```
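A quick sanity check on the report above: the `macro avg` F1 is the unweighted mean of the four per-class F1 scores. This snippet only re-derives numbers already present in the JSON:

```python
# Per-class f1-scores copied from the evaluation report above.
per_class_f1 = {
    "Business": 0.3620127569099929,
    "Sci/Tech": 0.5356798457087754,
    "Sports": 0.03715170278637771,
    "World": 0.10004856726566294,
}

# Macro averaging weights every class equally, regardless of support.
macro_f1 = sum(per_class_f1.values()) / len(per_class_f1)
assert abs(macro_f1 - 0.2587232181677022) < 1e-9  # matches the reported "macro avg" f1-score
```

The near-zero Sports recall drags the macro F1 well below the Sci/Tech score, which is why macro averaging is the harsher summary here.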
outputs/training_history.json
CHANGED

```diff
 {
   "train_epoch_1": {
+    "summarization_loss": 3.6738915424346925,
+    "summarization_rouge_like": 0.3936604625654161,
+    "emotion_loss": 0.5655887125730514,
+    "emotion_f1": 0.02088333384692669,
+    "topic_loss": 1.2472841796875,
+    "topic_accuracy": 0.5795,
+    "total_loss": 5.486764434695244,
     "epoch": 1.0
   },
   "val_epoch_1": {
+    "summarization_loss": 3.24564736366272,
+    "summarization_rouge_like": 0.4398922732261946,
+    "emotion_loss": 0.4284175229072571,
+    "emotion_f1": 0.0,
+    "topic_loss": 0.814755859375,
+    "topic_accuracy": 0.835,
     "epoch": 1.0
-  },
-  "train_epoch_2": {
-    "summarization_loss": 3.398382334982861,
-    "summarization_rouge_like": 0.31421210196164595,
-    "emotion_loss": 0.008744604070504772,
-    "emotion_f1": 0.9922616565848632,
-    "topic_loss": 0.12368396144345378,
-    "topic_accuracy": 0.9631060183895236,
-    "epoch": 2.0
-  },
-  "val_epoch_2": {
-    "summarization_loss": 2.728874285017067,
-    "summarization_rouge_like": 0.3867885960963845,
-    "emotion_loss": 0.20949344621063382,
-    "emotion_f1": 0.9095850804121747,
-    "topic_loss": 0.2887416907434674,
-    "topic_accuracy": 0.9329742669060442,
-    "epoch": 2.0
-  },
-  "train_epoch_3": {
-    "summarization_loss": 2.699047506134568,
-    "summarization_rouge_like": 0.38349341261349945,
-    "emotion_loss": 0.005096756787117961,
-    "emotion_f1": 0.9953213525834805,
-    "topic_loss": 0.07009015341349616,
-    "topic_accuracy": 0.9802800222903316,
-    "epoch": 3.0
-  },
-  "val_epoch_3": {
-    "summarization_loss": 2.354555403451446,
-    "summarization_rouge_like": 0.4275408038759501,
-    "emotion_loss": 0.20089952317384335,
-    "emotion_f1": 0.9075279304326329,
-    "topic_loss": 0.4845805834182202,
-    "topic_accuracy": 0.9298324356672651,
-    "epoch": 3.0
-  },
-  "train_epoch_4": {
-    "summarization_loss": 2.3750830047009015,
-    "summarization_rouge_like": 0.4200744394095619,
-    "emotion_loss": 0.0037049090056492364,
-    "emotion_f1": 0.9962315410599798,
-    "topic_loss": 0.042221361385891144,
-    "topic_accuracy": 0.9888652828085818,
-    "epoch": 4.0
-  },
-  "val_epoch_4": {
-    "summarization_loss": 2.198225014299636,
-    "summarization_rouge_like": 0.444635960654823,
-    "emotion_loss": 0.20359252842952202,
-    "emotion_f1": 0.9163175773506461,
-    "topic_loss": 0.5501026207833392,
-    "topic_accuracy": 0.9272890484739676,
-    "epoch": 4.0
-  },
-  "train_epoch_5": {
-    "summarization_loss": 2.186419085976007,
-    "summarization_rouge_like": 0.4416556068282783,
-    "emotion_loss": 0.0030099891204739266,
-    "emotion_f1": 0.9964672148443591,
-    "topic_loss": 0.03006078401232904,
-    "topic_accuracy": 0.9925606018389523,
-    "epoch": 5.0
-  },
-  "val_epoch_5": {
-    "summarization_loss": 2.114973693461849,
-    "summarization_rouge_like": 0.4553148986859889,
-    "emotion_loss": 0.2197709748711572,
-    "emotion_f1": 0.9121534032496345,
-    "topic_loss": 0.6607796598369469,
-    "topic_accuracy": 0.931178934769599,
-    "epoch": 5.0
   }
 }
```
pyproject.toml
CHANGED

```diff
@@ -35,6 +35,7 @@ bitsandbytes = ">=0.41.0"
 accelerate = ">=0.21.0"
 fastapi = ">=0.110.0"
 mlflow = ">=2.0.0"
+triton = { version = "*", markers = "sys_platform == 'linux'" }
 
 [tool.poetry.group.dev.dependencies]
 pytest = "^7.4.0"
```
scripts/evaluate.py
CHANGED

```diff
@@ -13,6 +13,7 @@ from typing import Any, List, cast
 
 import torch
 from sklearn.preprocessing import MultiLabelBinarizer
+from tqdm import tqdm
 
 PROJECT_ROOT = Path(__file__).resolve().parents[1]
 if str(PROJECT_ROOT) not in sys.path:
@@ -135,7 +136,13 @@ def main() -> None:
     print("Evaluating Summarization...")
     summaries_pred = []
     summaries_ref = []
+    total_batches = (len(summary_examples) + args.batch_size - 1) // args.batch_size
+    for batch in tqdm(
+        chunks(summary_examples, args.batch_size),
+        total=total_batches,
+        desc="Summarization",
+        unit="batch",
+    ):
         inputs = [example.source for example in batch]
         summaries_pred.extend(pipeline.summarize(inputs))
         summaries_ref.extend([example.summary for example in batch])
@@ -148,9 +155,17 @@ def main() -> None:
     emotion_preds_tensor = []
     emotion_target_tensor = []
     label_to_index = {label: idx for idx, label in enumerate(metadata.emotion)}
+    total_batches = (len(emotion_examples) + args.batch_size - 1) // args.batch_size
+
+    # Lower threshold to 0.3 to catch weak signals, or use argmax if appropriate
+    # For now, we'll stick to thresholding but lower it.
+    inference_threshold = 0.3
+
+    for batch in tqdm(
+        chunks(emotion_examples, args.batch_size), total=total_batches, desc="Emotion", unit="batch"
+    ):
         inputs = [example.text for example in batch]
-        predictions = pipeline.predict_emotions(inputs)
+        predictions = pipeline.predict_emotions(inputs, threshold=inference_threshold)
         target_matrix = emotion_binarizer.transform([list(example.emotions) for example in batch])
         for pred, target_row in zip(predictions, target_matrix, strict=False):
             vector = torch.zeros(len(metadata.emotion), dtype=torch.float32)
@@ -169,7 +184,10 @@ def main() -> None:
     print("Evaluating Topic Classification...")
     topic_preds = []
     topic_targets = []
+    total_batches = (len(topic_examples) + args.batch_size - 1) // args.batch_size
+    for batch in tqdm(
+        chunks(topic_examples, args.batch_size), total=total_batches, desc="Topic", unit="batch"
+    ):
         inputs = [example.text for example in batch]
         topic_predictions = pipeline.predict_topics(inputs)
         topic_preds.extend([pred.label for pred in topic_predictions])
```
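The evaluation loops above pair a `chunks` helper with ceiling division to give tqdm an accurate `total`. A minimal sketch of both (the real helper lives in `scripts/evaluate.py` and may differ):

```python
def chunks(items, size):
    """Yield successive fixed-size slices of a list; the last slice may be short."""
    for start in range(0, len(items), size):
        yield items[start:start + size]

def ceil_div(n, d):
    """(n + d - 1) // d rounds up, so a partial final batch is still counted."""
    return (n + d - 1) // d

examples = list(range(10))
batches = list(chunks(examples, 4))
# 10 examples with batch_size=4 -> 3 batches of sizes 4, 4, 2
assert [len(b) for b in batches] == [4, 4, 2]
assert ceil_div(len(examples), 4) == 3
```

Plain `len(examples) // batch_size` would under-report the bar total whenever the dataset size is not a multiple of the batch size.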
scripts/export_model.py
CHANGED

```diff
@@ -51,7 +51,7 @@ def main() -> None:
     data_cfg = load_yaml(args.data_config).data
     tokenizer_section = data_cfg.get("tokenizer", {})
     tokenizer_config = TokenizerConfig(
+        pretrained_model_name=tokenizer_section.get("pretrained_model_name", "google/flan-t5-base"),
         max_length=int(tokenizer_section.get("max_length", 512)),
         lower=bool(tokenizer_section.get("lower", False)),
     )
@@ -64,7 +64,7 @@ def main() -> None:
         config=load_model_config(args.model_config),
     )
 
+    raw_state = torch.load(checkpoint, map_location="cuda")
     if isinstance(raw_state, dict):
         if "model_state_dict" in raw_state and isinstance(raw_state["model_state_dict"], dict):
             state_dict = raw_state["model_state_dict"]
```
scripts/export_tokenizer.py
ADDED

```python
"""Export the FLAN-T5 tokenizer to the artifacts directory for reproducible inference."""

from __future__ import annotations

import argparse
from pathlib import Path

from transformers import AutoTokenizer


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Export tokenizer to artifacts directory")
    parser.add_argument(
        "--model-name",
        default="google/flan-t5-base",
        help="HuggingFace model name for the tokenizer.",
    )
    parser.add_argument(
        "--output-dir",
        default="artifacts/hf_tokenizer",
        help="Output directory for tokenizer files.",
    )
    return parser.parse_args()


def main() -> None:
    args = parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Downloading tokenizer from {args.model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    print(f"Saving tokenizer to {output_dir}...")
    tokenizer.save_pretrained(str(output_dir))

    # Print tokenizer info
    print("\nTokenizer saved successfully!")
    print(f"  Vocab size: {tokenizer.vocab_size}")
    print(f"  Pad token: {tokenizer.pad_token} (id={tokenizer.pad_token_id})")
    print(f"  EOS token: {tokenizer.eos_token} (id={tokenizer.eos_token_id})")
    print(f"  BOS token: {tokenizer.bos_token} (id={getattr(tokenizer, 'bos_token_id', 'N/A')})")

    print("\nFiles created:")
    for file in sorted(output_dir.iterdir()):
        print(f"  - {file.name}")


if __name__ == "__main__":
    main()
```
scripts/train.py
CHANGED

```diff
@@ -3,9 +3,11 @@
 from __future__ import annotations
 
 import json
 import sys
 from pathlib import Path
-from typing import Dict, Sequence, cast
 
 import hydra
 import torch
@@ -63,11 +65,86 @@ def _read_examples(data_dir: Path, loader) -> SplitExamples:
     return splits
 
 
 @hydra.main(version_base=None, config_path="../configs", config_name="config")
 def main(cfg: DictConfig) -> None:
     print(OmegaConf.to_yaml(cfg))
     set_seed(cfg.seed)
 
     # Access configs directly from Hydra cfg object
     data_cfg = cfg.data
     training_cfg = cfg.training
@@ -82,6 +159,8 @@ def main(cfg: DictConfig) -> None:
         dropout=cfg.model.dropout,
         use_pretrained=cfg.model.use_pretrained,
         pretrained_model_name=cfg.model.pretrained_model_name,
     )
 
     summarization_dir = Path(data_cfg.processed.summarization)
@@ -92,9 +171,17 @@ def main(cfg: DictConfig) -> None:
     emotion_splits = _read_examples(emotion_dir, load_emotion_jsonl)
     topic_splits = _read_examples(topic_dir, load_topic_jsonl)
 
     tokenizer_section = data_cfg.get("tokenizer", {})
     tokenizer_config = TokenizerConfig(
         max_length=int(tokenizer_section.get("max_length", 512)),
         lower=bool(tokenizer_section.get("lower", False)),
     )
@@ -112,6 +199,9 @@ def main(cfg: DictConfig) -> None:
     dataloader_args = training_cfg.get("dataloader", {})
     batch_size = int(dataloader_args.get("batch_size", 8))
     shuffle = bool(dataloader_args.get("shuffle", True))
     max_length = tokenizer.config.max_length
 
     train_loaders = {
@@ -122,6 +212,8 @@ def main(cfg: DictConfig) -> None:
             shuffle=shuffle,
             max_source_length=max_length,
             max_target_length=max_length,
         ),
         "emotion": build_emotion_dataloader(
             emotion_train,
@@ -129,6 +221,8 @@ def main(cfg: DictConfig) -> None:
             batch_size=batch_size,
             shuffle=shuffle,
             max_length=max_length,
         ),
         "topic": build_topic_dataloader(
             topic_train,
@@ -136,6 +230,8 @@ def main(cfg: DictConfig) -> None:
             batch_size=batch_size,
             shuffle=shuffle,
             max_length=max_length,
         ),
     }
 
@@ -147,6 +243,8 @@ def main(cfg: DictConfig) -> None:
             shuffle=False,
             max_source_length=max_length,
             max_target_length=max_length,
         ),
         "emotion": build_emotion_dataloader(
             emotion_val,
@@ -154,6 +252,8 @@ def main(cfg: DictConfig) -> None:
             batch_size=batch_size,
             shuffle=False,
             max_length=max_length,
         ),
         "topic": build_topic_dataloader(
             topic_val,
@@ -161,6 +261,8 @@ def main(cfg: DictConfig) -> None:
             batch_size=batch_size,
             shuffle=False,
             max_length=max_length,
         ),
     }
 
@@ -179,9 +281,43 @@ def main(cfg: DictConfig) -> None:
     optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
 
     # Optimize model execution graph with torch.compile (PyTorch 2.0+)
-    # This fuses kernels and reduces overhead for faster training
 
     trainer_cfg = training_cfg.get("trainer", {})
     trainer = Trainer(
@@ -193,6 +329,7 @@ def main(cfg: DictConfig) -> None:
         logging_interval=int(trainer_cfg.get("logging_interval", 50)),
         task_weights=trainer_cfg.get("task_weights"),
         label_smoothing=float(trainer_cfg.get("label_smoothing", 0.0)),
         ),
         device=device,
```
|
| 198 |
tokenizer=tokenizer,
|
|
@@ -200,7 +337,7 @@ def main(cfg: DictConfig) -> None:
|
|
| 200 |
|
| 201 |
# Save checkpoint after every epoch to avoid losing good early checkpoints
|
| 202 |
# Previous training showed overfitting at epoch 5 but good results at epoch 3
|
| 203 |
-
def save_epoch_checkpoint(epoch: int) -> None:
|
| 204 |
epoch_path = Path(cfg.checkpoint_out).parent / f"epoch_{epoch}.pt"
|
| 205 |
epoch_path.parent.mkdir(parents=True, exist_ok=True)
|
| 206 |
save_state(model, str(epoch_path))
|
|
|
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
import json
|
| 6 |
+
import platform
|
| 7 |
import sys
|
| 8 |
+
import warnings
|
| 9 |
from pathlib import Path
|
| 10 |
+
from typing import Any, Dict, Sequence, Tuple, cast
|
| 11 |
|
| 12 |
import hydra
|
| 13 |
import torch
|
|
|
|
| 65 |
return splits
|
| 66 |
|
| 67 |
|
| 68 |
+
def _limit_samples(splits: SplitExamples, trainer_cfg: DictConfig) -> None:
|
| 69 |
+
"""Limit the number of samples in train/val splits if configured."""
|
| 70 |
+
max_train = trainer_cfg.get("max_train_samples")
|
| 71 |
+
max_val = trainer_cfg.get("max_val_samples")
|
| 72 |
+
|
| 73 |
+
if max_train is not None and "train" in splits:
|
| 74 |
+
original_len = len(splits["train"])
|
| 75 |
+
limit = int(max_train)
|
| 76 |
+
if original_len > limit:
|
| 77 |
+
splits["train"] = splits["train"][:limit]
|
| 78 |
+
print(f"Limited 'train' split from {original_len} to {limit} samples")
|
| 79 |
+
|
| 80 |
+
if max_val is not None and "val" in splits:
|
| 81 |
+
original_len = len(splits["val"])
|
| 82 |
+
limit = int(max_val)
|
| 83 |
+
if original_len > limit:
|
| 84 |
+
splits["val"] = splits["val"][:limit]
|
| 85 |
+
print(f"Limited 'val' split from {original_len} to {limit} samples")
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def compile_model_safe(model: torch.nn.Module) -> Tuple[Any, str]:
|
| 89 |
+
"""
|
| 90 |
+
Safely compile model with best available backend.
|
| 91 |
+
|
| 92 |
+
Returns:
|
| 93 |
+
Compiled model and backend name used
|
| 94 |
+
"""
|
| 95 |
+
system = platform.system()
|
| 96 |
+
|
| 97 |
+
# NOTE: The 'inductor' backend causes NaN gradients during backward pass with
|
| 98 |
+
# bfloat16 autocast on the decoder (seq2seq tasks). This is a known issue.
|
| 99 |
+
# Use 'aot_eager' which provides graph optimization without inductor's codegen.
|
| 100 |
+
# See: debug_compile_config.py and test_compile_modes.py for investigation.
|
| 101 |
+
|
| 102 |
+
# Try aot_eager first - it's stable and provides good speedup
|
| 103 |
+
try:
|
| 104 |
+
print("Attempting to compile with 'aot_eager' backend...")
|
| 105 |
+
compiled_model = torch.compile(model, backend="aot_eager")
|
| 106 |
+
print("✓ Successfully compiled with 'aot_eager' backend")
|
| 107 |
+
return cast(torch.nn.Module, compiled_model), "aot_eager"
|
| 108 |
+
except Exception as e:
|
| 109 |
+
warnings.warn(f"aot_eager backend failed: {e}", stacklevel=2)
|
| 110 |
+
|
| 111 |
+
# Fallback: Try other backends (inductor may work for encoder-only tasks)
|
| 112 |
+
backends_to_try = ["eager"]
|
| 113 |
+
if system != "Windows":
|
| 114 |
+
# On Linux, inductor might work for some configurations
|
| 115 |
+
backends_to_try = ["eager", "inductor"]
|
| 116 |
+
|
| 117 |
+
for backend in backends_to_try:
|
| 118 |
+
try:
|
| 119 |
+
print(f"Attempting to compile with '{backend}' backend...")
|
| 120 |
+
compiled_model = torch.compile(model, backend=backend)
|
| 121 |
+
# Trigger a dummy run or just return? torch.compile is lazy.
|
| 122 |
+
# I assume it works if the call succeeds, runtime errors handled later.
|
| 123 |
+
print(f"✓ Successfully compiled with '{backend}' backend")
|
| 124 |
+
return cast(torch.nn.Module, compiled_model), backend
|
| 125 |
+
except Exception as e:
|
| 126 |
+
print(f"✗ '{backend}' backend failed: {e}")
|
| 127 |
+
continue
|
| 128 |
+
|
| 129 |
+
# No compilation worked, return original model
|
| 130 |
+
warnings.warn("All torch.compile backends failed, using uncompiled model", stacklevel=2)
|
| 131 |
+
return model, "none"
|
| 132 |
+
|
| 133 |
+
|
| 134 |
@hydra.main(version_base=None, config_path="../configs", config_name="config")
|
| 135 |
def main(cfg: DictConfig) -> None:
|
| 136 |
print(OmegaConf.to_yaml(cfg))
|
| 137 |
set_seed(cfg.seed)
|
| 138 |
|
| 139 |
+
# Enable TF32 for Ampere/Ada GPUs (RTX 30xx/40xx)
|
| 140 |
+
# This provides significant speedup on RTX 4070
|
| 141 |
+
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
|
| 142 |
+
print("Enabling TF32 for Ampere/Ada GPU...")
|
| 143 |
+
torch.set_float32_matmul_precision("high")
|
| 144 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 145 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 146 |
+
torch.backends.cudnn.benchmark = True # Auto-tunes convolution algorithms
|
| 147 |
+
|
| 148 |
# Access configs directly from Hydra cfg object
|
| 149 |
data_cfg = cfg.data
|
| 150 |
training_cfg = cfg.training
|
|
|
|
| 159 |
dropout=cfg.model.dropout,
|
| 160 |
use_pretrained=cfg.model.use_pretrained,
|
| 161 |
pretrained_model_name=cfg.model.pretrained_model_name,
|
| 162 |
+
activation=getattr(cfg.model, "activation", "gelu"),
|
| 163 |
+
use_relative_position_bias=getattr(cfg.model, "use_relative_position_bias", False),
|
| 164 |
)
|
| 165 |
|
| 166 |
summarization_dir = Path(data_cfg.processed.summarization)
|
|
|
|
| 171 |
emotion_splits = _read_examples(emotion_dir, load_emotion_jsonl)
|
| 172 |
topic_splits = _read_examples(topic_dir, load_topic_jsonl)
|
| 173 |
|
| 174 |
+
# Apply sample limits if configured (e.g. for dev/medium runs)
|
| 175 |
+
trainer_cfg = training_cfg.get("trainer", {})
|
| 176 |
+
print("\nApplying dataset limits...")
|
| 177 |
+
_limit_samples(summarization_splits, trainer_cfg)
|
| 178 |
+
_limit_samples(emotion_splits, trainer_cfg)
|
| 179 |
+
_limit_samples(topic_splits, trainer_cfg)
|
| 180 |
+
print("Dataset limits applied.\n")
|
| 181 |
+
|
| 182 |
tokenizer_section = data_cfg.get("tokenizer", {})
|
| 183 |
tokenizer_config = TokenizerConfig(
|
| 184 |
+
pretrained_model_name=tokenizer_section.get("pretrained_model_name", "google/flan-t5-base"),
|
| 185 |
max_length=int(tokenizer_section.get("max_length", 512)),
|
| 186 |
lower=bool(tokenizer_section.get("lower", False)),
|
| 187 |
)
|
|
|
|
| 199 |
dataloader_args = training_cfg.get("dataloader", {})
|
| 200 |
batch_size = int(dataloader_args.get("batch_size", 8))
|
| 201 |
shuffle = bool(dataloader_args.get("shuffle", True))
|
| 202 |
+
# Optimization: Use multiple workers and pinned memory for faster data transfer
|
| 203 |
+
num_workers = int(dataloader_args.get("num_workers", 4))
|
| 204 |
+
pin_memory = bool(dataloader_args.get("pin_memory", True))
|
| 205 |
max_length = tokenizer.config.max_length
|
| 206 |
|
| 207 |
train_loaders = {
|
|
|
|
| 212 |
shuffle=shuffle,
|
| 213 |
max_source_length=max_length,
|
| 214 |
max_target_length=max_length,
|
| 215 |
+
num_workers=num_workers,
|
| 216 |
+
pin_memory=pin_memory,
|
| 217 |
),
|
| 218 |
"emotion": build_emotion_dataloader(
|
| 219 |
emotion_train,
|
|
|
|
| 221 |
batch_size=batch_size,
|
| 222 |
shuffle=shuffle,
|
| 223 |
max_length=max_length,
|
| 224 |
+
num_workers=num_workers,
|
| 225 |
+
pin_memory=pin_memory,
|
| 226 |
),
|
| 227 |
"topic": build_topic_dataloader(
|
| 228 |
topic_train,
|
|
|
|
| 230 |
batch_size=batch_size,
|
| 231 |
shuffle=shuffle,
|
| 232 |
max_length=max_length,
|
| 233 |
+
num_workers=num_workers,
|
| 234 |
+
pin_memory=pin_memory,
|
| 235 |
),
|
| 236 |
}
|
| 237 |
|
|
|
|
| 243 |
shuffle=False,
|
| 244 |
max_source_length=max_length,
|
| 245 |
max_target_length=max_length,
|
| 246 |
+
num_workers=num_workers,
|
| 247 |
+
pin_memory=pin_memory,
|
| 248 |
),
|
| 249 |
"emotion": build_emotion_dataloader(
|
| 250 |
emotion_val,
|
|
|
|
| 252 |
batch_size=batch_size,
|
| 253 |
shuffle=False,
|
| 254 |
max_length=max_length,
|
| 255 |
+
num_workers=num_workers,
|
| 256 |
+
pin_memory=pin_memory,
|
| 257 |
),
|
| 258 |
"topic": build_topic_dataloader(
|
| 259 |
topic_val,
|
|
|
|
| 261 |
batch_size=batch_size,
|
| 262 |
shuffle=False,
|
| 263 |
max_length=max_length,
|
| 264 |
+
num_workers=num_workers,
|
| 265 |
+
pin_memory=pin_memory,
|
| 266 |
),
|
| 267 |
}
|
| 268 |
|
|
|
|
| 281 |
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
|
| 282 |
|
| 283 |
# Optimize model execution graph with torch.compile (PyTorch 2.0+)
|
| 284 |
+
# This fuses kernels and reduces overhead for faster training
|
| 285 |
+
# Note: We only compile encoder/decoder for training, not the step() method used in generation
|
| 286 |
+
# Compile encoder and decoder separately to avoid control flow issues in MultiTaskModel.forward
|
| 287 |
+
# Compiling the top-level model causes excessive recompilation due to task switching
|
| 288 |
+
use_compile = True # torch.compile for faster training
|
| 289 |
+
|
| 290 |
+
if use_compile and model.encoder is not None:
|
| 291 |
+
model.encoder, backend_used = compile_model_safe(model.encoder)
|
| 292 |
+
else:
|
| 293 |
+
backend_used = "disabled"
|
| 294 |
+
if use_compile and model.decoder is not None:
|
| 295 |
+
# Compile decoder.forward but keep step/greedy_decode uncompiled for generation
|
| 296 |
+
model.decoder, _ = compile_model_safe(model.decoder)
|
| 297 |
+
|
| 298 |
+
# Compile heads
|
| 299 |
+
if use_compile:
|
| 300 |
+
for name, head in model.heads.items():
|
| 301 |
+
compiled_head, _ = compile_model_safe(head)
|
| 302 |
+
model.heads[name] = compiled_head
|
| 303 |
+
# Update the registered module as well to ensure parameters are tracked correctly
|
| 304 |
+
setattr(model, f"head_{name}", compiled_head)
|
| 305 |
+
|
| 306 |
+
print(f"Using compilation backend: {backend_used}")
|
| 307 |
+
|
| 308 |
+
# Verify weights loaded correctly (check for NaNs/Infs)
|
| 309 |
+
print("\n=== Weight Loading Verification ===")
|
| 310 |
+
has_issues = False
|
| 311 |
+
for name, param in model.named_parameters():
|
| 312 |
+
if torch.isnan(param).any():
|
| 313 |
+
print(f"WARNING: NaN in {name}")
|
| 314 |
+
has_issues = True
|
| 315 |
+
if torch.isinf(param).any():
|
| 316 |
+
print(f"WARNING: Inf in {name}")
|
| 317 |
+
has_issues = True
|
| 318 |
+
if not has_issues:
|
| 319 |
+
print("✓ No NaNs or Infs found in model parameters.")
|
| 320 |
+
print("=== Verification Complete ===\n")
|
| 321 |
|
| 322 |
trainer_cfg = training_cfg.get("trainer", {})
|
| 323 |
trainer = Trainer(
|
|
|
|
| 329 |
logging_interval=int(trainer_cfg.get("logging_interval", 50)),
|
| 330 |
task_weights=trainer_cfg.get("task_weights"),
|
| 331 |
label_smoothing=float(trainer_cfg.get("label_smoothing", 0.0)),
|
| 332 |
+
gradient_accumulation_steps=int(trainer_cfg.get("gradient_accumulation_steps", 1)),
|
| 333 |
),
|
| 334 |
device=device,
|
| 335 |
tokenizer=tokenizer,
|
|
|
|
| 337 |
|
| 338 |
# Save checkpoint after every epoch to avoid losing good early checkpoints
|
| 339 |
# Previous training showed overfitting at epoch 5 but good results at epoch 3
|
| 340 |
+
def save_epoch_checkpoint(epoch: int, model: torch.nn.Module, history: Dict) -> None:
|
| 341 |
epoch_path = Path(cfg.checkpoint_out).parent / f"epoch_{epoch}.pt"
|
| 342 |
epoch_path.parent.mkdir(parents=True, exist_ok=True)
|
| 343 |
save_state(model, str(epoch_path))
|
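The backend-fallback strategy used by `compile_model_safe` above can be sketched independently of PyTorch. In this hedged sketch, `compile_fn` is a hypothetical stand-in for `torch.compile(model, backend=...)`; the point is that compilation failure is treated as non-fatal and training can always proceed with the uncompiled model:

```python
def compile_with_fallback(model, compile_fn, backends):
    """Try each compilation backend in order; fall back to the uncompiled model.

    compile_fn(model, backend) is a stand-in for torch.compile(model, backend=...).
    Returns (possibly-compiled model, name of the backend that worked, or "none").
    """
    for backend in backends:
        try:
            # A successful call only means the wrapper was created; with a lazy
            # compiler, runtime errors can still surface on the first forward pass.
            return compile_fn(model, backend), backend
        except Exception:
            # Compilation failure is non-fatal: try the next backend.
            continue
    return model, "none"
```

This mirrors the script's ordering (`aot_eager` first, then `eager`/`inductor`), with each failure swallowed so a bad backend never aborts the run.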
src/data/dataloader.py
CHANGED
@@ -120,13 +120,22 @@ def build_summarization_dataloader(
     shuffle: bool = True,
     max_source_length: int | None = None,
     max_target_length: int | None = None,
+    num_workers: int = 0,
+    pin_memory: bool = False,
 ) -> DataLoader:
     collator = SummarizationCollator(
         tokenizer,
         max_source_length=max_source_length,
         max_target_length=max_target_length,
     )
-    return DataLoader(
+    return DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        collate_fn=collator,
+        num_workers=num_workers,
+        pin_memory=pin_memory,
+    )
 
 
 def build_emotion_dataloader(
@@ -136,9 +145,18 @@ def build_emotion_dataloader(
     batch_size: int,
     shuffle: bool = True,
     max_length: int | None = None,
+    num_workers: int = 0,
+    pin_memory: bool = False,
 ) -> DataLoader:
     collator = EmotionCollator(tokenizer, dataset, max_length=max_length)
-    return DataLoader(
+    return DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        collate_fn=collator,
+        num_workers=num_workers,
+        pin_memory=pin_memory,
+    )
 
 
 def build_topic_dataloader(
@@ -148,6 +166,15 @@ def build_topic_dataloader(
     batch_size: int,
     shuffle: bool = True,
     max_length: int | None = None,
+    num_workers: int = 0,
+    pin_memory: bool = False,
 ) -> DataLoader:
     collator = TopicCollator(tokenizer, dataset, max_length=max_length)
-    return DataLoader(
+    return DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        collate_fn=collator,
+        num_workers=num_workers,
+        pin_memory=pin_memory,
+    )
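The collators passed as `collate_fn` above turn variable-length examples into rectangular batches by padding. A minimal pure-Python sketch of just the padding step (the real `SummarizationCollator`/`EmotionCollator`/`TopicCollator` also tokenize and build target tensors; `pad_batch` is an illustrative name, not a function from this repo):

```python
def pad_batch(sequences, pad_id=0):
    """Pad token-id lists to the batch max length and build an attention mask.

    Returns (input_ids, attention_mask); 1 marks real tokens, 0 marks padding.
    """
    max_len = max(len(seq) for seq in sequences)
    input_ids = [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]
    attention_mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in sequences]
    return input_ids, attention_mask

ids, mask = pad_batch([[5, 6, 7], [8]], pad_id=0)
# ids  == [[5, 6, 7], [8, 0, 0]]
# mask == [[1, 1, 1], [1, 0, 0]]
```

Note that for T5, `pad_id=0` is the tokenizer's actual pad token, which is why the mask matters: padding positions must be excluded from attention.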
src/data/preprocessing.py
CHANGED
@@ -53,7 +53,7 @@ class TextPreprocessor:
         tokenizer: Tokenizer | None = None,
         *,
         tokenizer_config: TokenizerConfig | None = None,
-        tokenizer_name: str = "
+        tokenizer_name: str = "google/flan-t5-base",
         max_length: int | None = None,
         lowercase: bool = True,
         remove_stopwords: bool = False,
src/data/tokenization.py
CHANGED
@@ -11,9 +11,9 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase
 
 @dataclass
 class TokenizerConfig:
-    pretrained_model_name: str = "
+    pretrained_model_name: str = "google/flan-t5-base"
     max_length: int = 512
-    padding: str = "
+    padding: str = "max_length"
     truncation: bool = True
     lower: bool = False
 
@@ -28,15 +28,29 @@ class Tokenizer:
             cfg.pretrained_model_name
         )
         self._pad_token_id = self._resolve_id(self._tokenizer.pad_token_id)
-
-
-
-
-        )
+
+        # T5 uses different special tokens than BART:
+        #   T5:   pad=0, eos=1, no explicit bos (uses pad or eos as decoder start)
+        #   BART: bos=0, pad=1, eos=2
+        # We use eos_token_id as bos for the T5 decoder start (common practice)
+        eos_id = self._tokenizer.eos_token_id
+        bos_id = self._tokenizer.bos_token_id
+
+        # For T5, decoder_start_token_id is typically pad_token_id (0),
+        # but we fall back to a sensible default based on what's available
+        if bos_id is not None:
+            self._bos_token_id = self._resolve_id(bos_id)
+        elif (
+            hasattr(self._tokenizer, "decoder_start_token_id")
+            and self._tokenizer.decoder_start_token_id is not None
+        ):
+            self._bos_token_id = self._resolve_id(self._tokenizer.decoder_start_token_id)
+        else:
+            # T5 convention: use pad_token_id as decoder start
+            self._bos_token_id = self._pad_token_id
+
         self._eos_token_id = self._resolve_id(
-            self._tokenizer.
-            if self._tokenizer.eos_token_id is not None
-            else self._tokenizer.sep_token_id
+            eos_id if eos_id is not None else self._tokenizer.sep_token_id
         )
 
     @property
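The BOS-resolution logic added above reduces to a small fallback chain. This sketch isolates just that decision (`resolve_decoder_start` is an illustrative name, not a function in this repo; the example ids follow the T5 and BART conventions quoted in the diff's comments):

```python
def resolve_decoder_start(bos_id, decoder_start_id, pad_id):
    """Pick the decoder start token: an explicit BOS if the tokenizer defines one,
    else the tokenizer's decoder_start_token_id, else the T5 convention of pad (id 0)."""
    if bos_id is not None:
        return bos_id
    if decoder_start_id is not None:
        return decoder_start_id
    return pad_id

# T5/FLAN-T5: no BOS, pad=0 -> decoder starts with 0 (the pad token)
assert resolve_decoder_start(None, None, 0) == 0
# BART: bos=0 is defined -> decoder start is the BOS token
assert resolve_decoder_start(0, None, 1) == 0
```

Getting this wrong is a classic seq2seq bug: a decoder started with the wrong token produces shifted or degenerate output even when the weights loaded correctly.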
src/models/attention.py
CHANGED
|
@@ -4,6 +4,7 @@ Attention mechanisms for Transformer architecture.
|
|
| 4 |
This module implements the core attention mechanisms used in the Transformer model:
|
| 5 |
- ScaledDotProductAttention: Fundamental attention operation
|
| 6 |
- MultiHeadAttention: Parallel attention with learned projections
|
|
|
|
| 7 |
|
| 8 |
Doing this first for Bottom-Up implementation of the Transformer
|
| 9 |
|
|
@@ -19,6 +20,130 @@ import torch.nn as nn
|
|
| 19 |
import torch.nn.functional as F
|
| 20 |
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
class ScaledDotProductAttention(nn.Module):
|
| 23 |
"""
|
| 24 |
Scaled Dot-Product Attention using PyTorch's optimized backend.
|
|
@@ -31,10 +156,15 @@ class ScaledDotProductAttention(nn.Module):
|
|
| 31 |
See: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
| 32 |
"""
|
| 33 |
|
| 34 |
-
def __init__(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
super().__init__()
|
| 36 |
-
|
| 37 |
-
pass
|
| 38 |
|
| 39 |
def forward(
|
| 40 |
self,
|
|
@@ -43,90 +173,86 @@ class ScaledDotProductAttention(nn.Module):
|
|
| 43 |
value: torch.Tensor,
|
| 44 |
mask: Optional[torch.Tensor] = None,
|
| 45 |
return_attn_weights: bool = False,
|
|
|
|
| 46 |
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 47 |
"""
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
6. Return both output and attention_weights
|
| 55 |
-
"""
|
| 56 |
-
# NEW: FlashAttention implementation using PyTorch 2.0+ SDPA
|
| 57 |
-
# This automatically selects the best kernel (FlashAttention, EfficientAttention, etc.)
|
| 58 |
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
attn_mask = None
|
| 64 |
-
if mask is not None:
|
| 65 |
-
attn_mask = ~mask.to(dtype=torch.bool, device=query.device)
|
| 66 |
-
|
| 67 |
-
# Call SDPA
|
| 68 |
-
# Note: I don't apply dropout here as my original implementation doesn't
|
| 69 |
-
# If we wanted to, I'd pass dropout_p to this method
|
| 70 |
-
if not return_attn_weights:
|
| 71 |
-
output = F.scaled_dot_product_attention(
|
| 72 |
-
query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
|
| 73 |
-
)
|
| 74 |
-
# SDPA doesn't return attention weights by default for efficiency
|
| 75 |
-
# I return None for weights when using the optimized kernel
|
| 76 |
-
return output, None
|
| 77 |
-
|
| 78 |
-
# --------- OLD: Manual implementation (Fallback when weights are needed) ---------------
|
| 79 |
-
# Scaled Dot-Product Attention as described in "Attention Is All You Need" 2017.
|
| 80 |
-
# Computes: Attention(Q, K, V) = softmax(QK^T / sqrt(d_k))V
|
| 81 |
-
# The scaling factor (1/sqrt(d_k)) prevents the dot products from growing too large,
|
| 82 |
-
# which would push the softmax into regions with extremely small gradients.
|
| 83 |
-
# Args:
|
| 84 |
-
# None - this module has no learnable parameters
|
| 85 |
-
# Forward Args:
|
| 86 |
-
# query: Query tensor of shape (batch, seq_len, d_k)
|
| 87 |
-
# key: Key tensor of shape (batch, seq_len, d_k)
|
| 88 |
-
# value: Value tensor of shape (batch, seq_len, d_v)
|
| 89 |
-
# mask: Optional mask tensor of shape (batch, seq_len, seq_len)
|
| 90 |
-
# True/1 values indicate positions to attend to, False/0 to mask
|
| 91 |
-
# Returns:
|
| 92 |
-
# output: Attention output of shape (batch, seq_len, d_v)
|
| 93 |
-
# attention_weights: Attention probability matrix (batch, seq_len, seq_len)
|
| 94 |
-
# Getting Dimension for Scaling
|
| 95 |
d_k = query.size(-1)
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
# Convert mask to same dtype as p_attn for multiplication
|
| 113 |
-
mask_float = mask.to(dtype=p_attn.dtype, device=p_attn.device)
|
| 114 |
-
# Broadcast-multiply (zero out masked key positions)
|
| 115 |
-
p_attn = p_attn * mask_float
|
| 116 |
-
# Replace any NaNs (can occur when a row was entirely -inf prior to softmax) with 0.0
|
| 117 |
-
# torch.nan_to_num is efficient and handles negative/positive inf as well
|
| 118 |
p_attn = torch.nan_to_num(p_attn, nan=0.0, posinf=0.0, neginf=0.0)
|
|
|
|
|
|
|
| 119 |
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
p_attn = torch.where(nonzero_rows, p_attn / (row_sums + 1e-12), p_attn)
|
| 126 |
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
|
| 132 |
# --------------- Rotary Positional Embeddings ---------------
|
|
@@ -186,6 +312,7 @@ class MultiHeadAttention(nn.Module):
|
|
| 186 |
lora_rank: Rank of LoRA matrices (default: 8)
|
| 187 |
lora_alpha: Scaling factor for LoRA (default: 16)
|
| 188 |
lora_dropout: Dropout probability for LoRA (default: 0.1)
|
|
|
|
| 189 |
"""
|
| 190 |
|
| 191 |
def __init__(
|
|
@@ -200,6 +327,7 @@ class MultiHeadAttention(nn.Module):
|
|
| 200 |
lora_alpha: int = 16,
|
| 201 |
lora_dropout: float = 0.1,
|
| 202 |
quantization: Optional[str] = None,
|
|
|
|
| 203 |
):
|
| 204 |
super().__init__()
|
| 205 |
|
|
@@ -238,7 +366,8 @@ class MultiHeadAttention(nn.Module):
|
|
| 238 |
self.W_V = Linear(d_model, d_model, **kwargs)
|
| 239 |
self.W_O = Linear(d_model, d_model, **kwargs)
|
| 240 |
# Create ScaledDotProductAttention instance
|
| 241 |
-
|
|
|
|
| 242 |
# Create dropout layer
|
| 243 |
self.dropout = nn.Dropout(p=dropout)
|
| 244 |
|
|
@@ -277,6 +406,7 @@ class MultiHeadAttention(nn.Module):
|
|
| 277 |
value: torch.Tensor,
|
| 278 |
mask: Optional[torch.Tensor] = None,
|
| 279 |
return_attn_weights: bool = False,
|
|
|
|
| 280 |
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 281 |
"""
|
| 282 |
Args:
|
|
@@ -284,6 +414,7 @@ class MultiHeadAttention(nn.Module):
|
|
| 284 |
key: (batch, seq_len, d_model)
|
| 285 |
value: (batch, seq_len, d_model)
|
| 286 |
mask: Optional (batch, seq_len, seq_len) or (batch, 1, seq_len, seq_len)
|
|
|
|
| 287 |
|
| 288 |
Returns:
|
| 289 |
output: (batch, seq_len, d_model)
|
|
@@ -329,9 +460,9 @@ class MultiHeadAttention(nn.Module):
|
|
| 329 |
mask = mask.unsqueeze(1) # (batch, 1, seq, seq)
|
| 330 |
# Now mask broadcasts across all heads: (batch, 1, seq, seq) → (batch, 8, seq, seq)
|
| 331 |
|
| 332 |
-
# Apply attention
|
| 333 |
output, attn_weights = self.attention(
|
| 334 |
-
Q, K, V, mask, return_attn_weights=return_attn_weights
|
| 335 |
)
|
| 336 |
# output: (batch, num_heads, seq_len, d_k)
|
| 337 |
# attn_weights: (batch, num_heads, seq_len, seq_len)
|
|
|
|
| 4 |
This module implements the core attention mechanisms used in the Transformer model:
|
| 5 |
- ScaledDotProductAttention: Fundamental attention operation
|
| 6 |
- MultiHeadAttention: Parallel attention with learned projections
|
| 7 |
+
- T5RelativePositionBias: Relative position bias for T5-style attention
|
| 8 |
|
| 9 |
Doing this first for Bottom-Up implementation of the Transformer
|
| 10 |
|
|
|
|
| 20 |
import torch.nn.functional as F
|
| 21 |
|
| 22 |
|
| 23 |
+
class T5RelativePositionBias(nn.Module):
|
| 24 |
+
"""
|
| 25 |
+
T5-style relative position bias for attention.
|
| 26 |
+
|
| 27 |
+
T5 uses a learned embedding table to encode relative positions between tokens.
|
| 28 |
+
Positions are bucketed to handle arbitrary sequence lengths efficiently.
|
| 29 |
+
|
| 30 |
+
This is added to attention scores BEFORE softmax, not to embeddings.
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(
|
| 34 |
+
self,
|
| 35 |
+
num_heads: int,
|
| 36 |
+
num_buckets: int = 32,
|
| 37 |
+
max_distance: int = 128,
|
| 38 |
+
is_decoder: bool = False,
|
| 39 |
+
):
|
| 40 |
+
super().__init__()
|
| 41 |
+
self.num_heads = num_heads
|
| 42 |
+
self.num_buckets = num_buckets
|
| 43 |
+
self.max_distance = max_distance
|
| 44 |
+
self.is_decoder = is_decoder
|
| 45 |
+
|
| 46 |
+
# Learned embedding table: (num_buckets, num_heads)
|
| 47 |
+
self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
|
| 48 |
+
|
| 49 |
+
@staticmethod
|
| 50 |
+
def _relative_position_bucket(
|
| 51 |
+
relative_position: torch.Tensor,
|
| 52 |
+
bidirectional: bool = True,
|
| 53 |
+
num_buckets: int = 32,
|
| 54 |
+
max_distance: int = 128,
|
| 55 |
+
) -> torch.Tensor:
|
| 56 |
+
"""
|
| 57 |
+
Translate relative position to a bucket index.
|
| 58 |
+
|
| 59 |
+
T5 uses a combination of exact positions (for nearby tokens) and
|
| 60 |
+
logarithmically-spaced buckets (for distant tokens).
|
| 61 |
+
"""
|
| 62 |
+
relative_buckets = torch.zeros_like(relative_position, dtype=torch.long)
|
| 63 |
+
|
| 64 |
+
if bidirectional:
|
| 65 |
+
num_buckets //= 2
|
| 66 |
+
relative_buckets += (relative_position > 0).long() * num_buckets
|
| 67 |
+
relative_position = torch.abs(relative_position)
|
| 68 |
+
else:
|
| 69 |
+
relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
|
| 70 |
+
|
| 71 |
+
# Half buckets for exact positions
|
| 72 |
+
max_exact = num_buckets // 2
|
| 73 |
+
is_small = relative_position < max_exact
|
| 74 |
+
|
| 75 |
+
# Other half for logarithmically-spaced buckets
|
| 76 |
+
relative_position_if_large = (
|
| 77 |
+
max_exact
|
| 78 |
+
+ (
|
| 79 |
+
torch.log(relative_position.float() / max_exact)
|
| 80 |
+
/ math.log(max_distance / max_exact)
|
| 81 |
+
* (num_buckets - max_exact)
|
| 82 |
+
).long()
|
| 83 |
+
)
|
| 84 |
+
relative_position_if_large = torch.min(
|
| 85 |
+
relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
+        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
+        return relative_buckets
+
+    def compute_bias(
+        self,
+        query_length: int,
+        key_length: int,
+        device: torch.device,
+        query_position_offset: int = 0,
+    ) -> torch.Tensor:
+        """
+        Compute relative position bias for attention.
+
+        Args:
+            query_length: Number of query positions
+            key_length: Number of key positions
+            device: Device to create tensors on
+            query_position_offset: Offset for query positions (for incremental decoding).
+                When decoding step-by-step, query_length=1 but the actual
+                position is past_len, so query_position_offset=past_len.
+
+        Returns: (1, num_heads, query_length, key_length)
+        """
+        # Create position indices
+        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
+        context_position = (
+            context_position + query_position_offset
+        )  # Apply offset for incremental decoding
+        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
+
+        # Relative position: (query_length, key_length)
+        relative_position = memory_position - context_position
+
+        # Convert to bucket indices
+        relative_position_bucket = self._relative_position_bucket(
+            relative_position,
+            bidirectional=(not self.is_decoder),
+            num_buckets=self.num_buckets,
+            max_distance=self.max_distance,
+        )
+
+        # Look up bias values: (query_length, key_length, num_heads)
+        values = self.relative_attention_bias(relative_position_bucket)
+
+        # Reshape to (1, num_heads, query_length, key_length)
+        values = values.permute([2, 0, 1]).unsqueeze(0)
+
+        return values
+
+    def forward(
+        self,
+        query_length: int,
+        key_length: int,
+        device: torch.device,
+        query_position_offset: int = 0,
+    ) -> torch.Tensor:
+        return self.compute_bias(query_length, key_length, device, query_position_offset)

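The index arithmetic behind `compute_bias` can be checked in isolation. A minimal sketch in plain PyTorch (independent of the class above; the sizes are arbitrary illustrations): during incremental decoding, a single query at offset `past_len` must see the same relative positions as the corresponding row of the full matrix.

```python
import torch

def relative_positions(query_length: int, key_length: int, offset: int = 0) -> torch.Tensor:
    # Mirrors the arithmetic in compute_bias: rows are query positions
    # (shifted by `offset` during incremental decoding), columns are key positions.
    context = torch.arange(query_length, dtype=torch.long)[:, None] + offset
    memory = torch.arange(key_length, dtype=torch.long)[None, :]
    return memory - context  # (query_length, key_length)

# Full-sequence case: entry (q, k) is k - q.
full = relative_positions(3, 3)

# Incremental case: decoding the token at position 2 against 3 cached keys
# reproduces the last row of the full matrix, which is what the
# query_position_offset argument is for.
step = relative_positions(1, 3, offset=2)
```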
class ScaledDotProductAttention(nn.Module):
    """
    Scaled Dot-Product Attention using PyTorch's optimized backend.

    See: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
    """

+    def __init__(self, scale_scores: bool = True):
+        """
+        Args:
+            scale_scores: Whether to scale attention scores by sqrt(d_k).
+                T5 does NOT scale scores, so set this to False for T5.
+                Standard transformers (BERT, GPT, etc.) use scaling.
+        """
        super().__init__()
+        self.scale_scores = scale_scores

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        return_attn_weights: bool = False,
+        position_bias: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
+        Args:
+            query: (batch, num_heads, seq_q, d_k)
+            key: (batch, num_heads, seq_k, d_k)
+            value: (batch, num_heads, seq_k, d_v)
+            mask: Optional boolean mask, True = attend, False = mask
+            position_bias: Optional (1, num_heads, seq_q, seq_k) T5-style relative position bias
+
+        Returns:
+            output: (batch, num_heads, seq_q, d_v)
+            attention_weights: Optional (batch, num_heads, seq_q, seq_k)
+        """
        d_k = query.size(-1)
+        scale_factor = 1.0 / math.sqrt(d_k) if self.scale_scores else 1.0
+
+        # If we need attention weights, we must use the manual path
+        if return_attn_weights:
+            # Manual implementation with float32 softmax for numerical stability
+            scores = torch.matmul(query, key.transpose(-2, -1)) * scale_factor
+            if position_bias is not None:
+                scores = scores + position_bias
+            if mask is not None:
+                mask_bool = mask.to(dtype=torch.bool, device=scores.device)
+                if mask_bool.dim() == 2:
+                    mask_bool = mask_bool.unsqueeze(1).unsqueeze(2)
+                elif mask_bool.dim() == 3:
+                    mask_bool = mask_bool.unsqueeze(1)
+                scores = scores.masked_fill(~mask_bool, -1e4)
+            p_attn = F.softmax(scores.float(), dim=-1).type_as(scores)
            p_attn = torch.nan_to_num(p_attn, nan=0.0, posinf=0.0, neginf=0.0)
+            output = torch.matmul(p_attn, value)
+            return output, p_attn

+        # Use the optimized SDPA path - torch.compile friendly version.
+        # Pre-scale the query instead of using SDPA's scale parameter for better
+        # compile compatibility; this avoids issues with inductor and custom scale values.
+        if self.scale_scores:
+            query = query * scale_factor

+        # Build combined attention mask (float tensor added to scores)
+        attn_mask = None
+
+        if position_bias is not None or mask is not None:
+            # Start with position bias if provided
+            if position_bias is not None:
+                # Clamp position bias to prevent overflow
+                attn_mask = position_bias.to(dtype=query.dtype).clamp(-100, 100)
+
+            # Add mask (convert bool mask to additive float mask)
+            if mask is not None:
+                mask_bool = mask.to(dtype=torch.bool, device=query.device)
+                if mask_bool.dim() == 2:
+                    mask_bool = mask_bool.unsqueeze(1).unsqueeze(2)
+                elif mask_bool.dim() == 3:
+                    mask_bool = mask_bool.unsqueeze(1)
+
+                mask_float = torch.zeros(mask_bool.shape, dtype=query.dtype, device=query.device)
+                mask_float = mask_float.masked_fill(~mask_bool, -1e4)
+
+                if attn_mask is not None:
+                    attn_mask = attn_mask + mask_float
+                else:
+                    attn_mask = mask_float
+
+        # Pass scale=1.0: the query is already pre-scaled above (or deliberately
+        # left unscaled for T5), so SDPA must not apply its default 1/sqrt(d_k) again.
+        output = F.scaled_dot_product_attention(
+            query,
+            key,
+            value,
+            attn_mask=attn_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            scale=1.0,  # We handle scaling manually above
+        )
+        return output, None

# --------------- Rotary Positional Embeddings ---------------
        lora_rank: Rank of LoRA matrices (default: 8)
        lora_alpha: Scaling factor for LoRA (default: 16)
        lora_dropout: Dropout probability for LoRA (default: 0.1)
+        scale_scores: Whether to scale attention scores by sqrt(d_k). T5 does NOT scale.
    """

    def __init__(
    ...
        lora_alpha: int = 16,
        lora_dropout: float = 0.1,
        quantization: Optional[str] = None,
+        scale_scores: bool = True,  # T5 uses scale_scores=False
    ):
        super().__init__()
    ...
        self.W_V = Linear(d_model, d_model, **kwargs)
        self.W_O = Linear(d_model, d_model, **kwargs)
        # Create ScaledDotProductAttention instance
+        # Note: T5 does NOT scale attention scores by sqrt(d_k)
+        self.attention = ScaledDotProductAttention(scale_scores=scale_scores)
        # Create dropout layer
        self.dropout = nn.Dropout(p=dropout)
    ...
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        return_attn_weights: bool = False,
+        position_bias: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
    ...
            key: (batch, seq_len, d_model)
            value: (batch, seq_len, d_model)
            mask: Optional (batch, seq_len, seq_len) or (batch, 1, seq_len, seq_len)
+            position_bias: Optional (1, num_heads, seq_q, seq_k) T5-style relative position bias

        Returns:
            output: (batch, seq_len, d_model)
    ...
            mask = mask.unsqueeze(1)  # (batch, 1, seq, seq)
        # Now mask broadcasts across all heads: (batch, 1, seq, seq) → (batch, 8, seq, seq)

+        # Apply attention with optional position bias
        output, attn_weights = self.attention(
+            Q, K, V, mask, return_attn_weights=return_attn_weights, position_bias=position_bias
        )
        # output: (batch, num_heads, seq_len, d_k)
        # attn_weights: (batch, num_heads, seq_len, seq_len)
src/models/decoder.py
CHANGED
|
@@ -13,15 +13,14 @@ Conventions:
|
|
| 13 |
- RMSNorm is just simpler than LayerNorm and more computationally efficient, it's become the modern convention. These reasons are why I used it here.
|
| 14 |
"""
|
| 15 |
|
| 16 |
-
import
|
| 17 |
-
from typing import Dict, List, Optional, Tuple, Union
|
| 18 |
|
| 19 |
import torch
|
| 20 |
import torch.nn as nn
|
| 21 |
|
| 22 |
-
from .attention import MultiHeadAttention
|
| 23 |
from .feedforward import FeedForward
|
| 24 |
-
from .positional_encoding import PositionalEncoding
|
| 25 |
|
| 26 |
|
| 27 |
def create_causal_mask(seq_len: int, device: Optional[torch.device] = None) -> torch.Tensor:
|
|
@@ -50,17 +49,31 @@ class TransformerDecoderLayer(nn.Module):
|
|
| 50 |
d_ff: int,
|
| 51 |
dropout: float = 0.1,
|
| 52 |
quantization: Optional[str] = None,
|
|
|
|
|
|
|
| 53 |
):
|
| 54 |
super().__init__()
|
| 55 |
# use internal MHA dropout = 0.0; the layer handles dropout after sublayers
|
| 56 |
self.self_attn = MultiHeadAttention(
|
| 57 |
-
d_model=d_model,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
)
|
| 59 |
self.cross_attn = MultiHeadAttention(
|
| 60 |
-
d_model=d_model,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
)
|
| 62 |
self.ffn = FeedForward(
|
| 63 |
-
d_model=d_model,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
)
|
| 65 |
|
| 66 |
self.norm1 = nn.RMSNorm(d_model)
|
|
@@ -78,6 +91,8 @@ class TransformerDecoderLayer(nn.Module):
|
|
| 78 |
tgt_mask: Optional[torch.Tensor] = None,
|
| 79 |
memory_mask: Optional[torch.Tensor] = None,
|
| 80 |
collect_attn: bool = False,
|
|
|
|
|
|
|
| 81 |
) -> Tuple[torch.Tensor, Dict[str, Optional[torch.Tensor]]]:
|
| 82 |
"""
|
| 83 |
Args:
|
|
@@ -86,6 +101,8 @@ class TransformerDecoderLayer(nn.Module):
|
|
| 86 |
tgt_mask: optional mask for self-attn - shape (B, T, T) or (B, 1, T, T)
|
| 87 |
memory_mask: optional mask for cross-attn - shape (B, S) or (B, 1, S) or (B, 1, T, S)
|
| 88 |
collect_attn: whether to return attention weights
|
|
|
|
|
|
|
| 89 |
|
| 90 |
Returns:
|
| 91 |
(tgt_out, {"self": self_attn_weights, "cross": cross_attn_weights})
|
|
@@ -106,22 +123,47 @@ class TransformerDecoderLayer(nn.Module):
|
|
| 106 |
# --- Masked self-attention (Pre-LN) ---
|
| 107 |
x_norm = self.norm1(tgt)
|
| 108 |
self_out, self_attn = self.self_attn(
|
| 109 |
-
x_norm,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
)
|
| 111 |
tgt = tgt + self.dropout1(self_out)
|
| 112 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
# --- Cross-attention (Pre-LN) ---
|
| 114 |
x_norm = self.norm2(tgt)
|
| 115 |
cross_out, cross_attn = self.cross_attn(
|
| 116 |
-
x_norm,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
)
|
| 118 |
tgt = tgt + self.dropout2(cross_out)
|
| 119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
# --- Feed-forward (Pre-LN) ---
|
| 121 |
x_norm = self.norm3(tgt)
|
| 122 |
ffn_out = self.ffn(x_norm)
|
| 123 |
tgt = tgt + self.dropout3(ffn_out)
|
| 124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
return tgt, {"self": self_attn, "cross": cross_attn}
|
| 126 |
|
| 127 |
|
|
@@ -143,14 +185,42 @@ class TransformerDecoder(nn.Module):
|
|
| 143 |
max_len: int = 512,
|
| 144 |
pad_token_id: Optional[int] = None,
|
| 145 |
quantization: Optional[str] = None,
|
|
|
|
|
|
|
|
|
|
| 146 |
):
|
| 147 |
super().__init__()
|
| 148 |
self.vocab_size = vocab_size
|
| 149 |
self.d_model = d_model
|
| 150 |
self.pad_token_id = pad_token_id
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
-
|
| 153 |
-
|
| 154 |
|
| 155 |
self.layers = nn.ModuleList(
|
| 156 |
[
|
|
@@ -160,6 +230,8 @@ class TransformerDecoder(nn.Module):
|
|
| 160 |
d_ff=d_ff,
|
| 161 |
dropout=dropout,
|
| 162 |
quantization=quantization,
|
|
|
|
|
|
|
| 163 |
)
|
| 164 |
for _ in range(num_layers)
|
| 165 |
]
|
|
@@ -172,6 +244,10 @@ class TransformerDecoder(nn.Module):
|
|
| 172 |
def _build_padding_mask_from_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 173 |
"""
|
| 174 |
Convert input ids to (B, T, T) boolean mask where True = allowed.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
"""
|
| 176 |
assert self.pad_token_id is not None, "pad_token_id must be set to build mask from ids"
|
| 177 |
pad_mask = input_ids != self.pad_token_id # (B, T)
|
|
@@ -185,6 +261,7 @@ class TransformerDecoder(nn.Module):
|
|
| 185 |
tgt_mask: Optional[torch.Tensor] = None,
|
| 186 |
memory_mask: Optional[torch.Tensor] = None,
|
| 187 |
collect_attn: bool = False,
|
|
|
|
| 188 |
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[Dict[str, torch.Tensor]]]]:
|
| 189 |
"""
|
| 190 |
Args:
|
|
@@ -192,16 +269,21 @@ class TransformerDecoder(nn.Module):
|
|
| 192 |
memory: (B, S, d_model)
|
| 193 |
tgt_mask: optional; if None, will create (causal [+ padding if ids available])
|
| 194 |
memory_mask: optional; if provided as (B, S) will be expanded to (B, 1, 1, S)
|
|
|
|
| 195 |
"""
|
| 196 |
# Prepare embeddings
|
| 197 |
if inputs.dim() == 2: # token ids
|
| 198 |
-
|
|
|
|
| 199 |
elif inputs.dim() == 3:
|
| 200 |
x = inputs
|
| 201 |
else:
|
| 202 |
raise ValueError("inputs must be (B, T) token ids or (B, T, d_model) embeddings")
|
| 203 |
|
| 204 |
-
|
|
|
|
|
|
|
|
|
|
| 205 |
x = self.input_dropout(x)
|
| 206 |
|
| 207 |
B, T, _ = x.shape
|
|
@@ -209,12 +291,14 @@ class TransformerDecoder(nn.Module):
|
|
| 209 |
# Build target mask if not provided: combine causal + padding (if available)
|
| 210 |
if tgt_mask is None:
|
| 211 |
causal = create_causal_mask(T, device=x.device) # (T, T)
|
| 212 |
-
if inputs.dim() == 2 and self.pad_token_id is not None:
|
|
|
|
| 213 |
pad_pairwise = self._build_padding_mask_from_ids(inputs) # (B, T, T)
|
| 214 |
combined = pad_pairwise & causal.unsqueeze(0) # (B, T, T)
|
| 215 |
tgt_mask = combined.unsqueeze(1) # (B, 1, T, T) -> broadcast to heads
|
| 216 |
else:
|
| 217 |
-
#
|
|
|
|
| 218 |
tgt_mask = causal.unsqueeze(0).unsqueeze(1) # (1, 1, T, T)
|
| 219 |
else:
|
| 220 |
# Ensure boolean and device alignment; accept (B, T, T) or (B,1,T,T) or (1,1,T,T)
|
|
@@ -230,10 +314,27 @@ class TransformerDecoder(nn.Module):
|
|
| 230 |
|
| 231 |
attn_list: List[Dict[str, torch.Tensor]] = []
|
| 232 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
# Pass through decoder layers
|
| 234 |
for layer in self.layers:
|
| 235 |
x, attn = layer(
|
| 236 |
-
x,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
)
|
| 238 |
if collect_attn:
|
| 239 |
attn_list.append(attn)
|
|
@@ -245,6 +346,51 @@ class TransformerDecoder(nn.Module):
|
|
| 245 |
return logits, attn_list
|
| 246 |
return logits
|
| 247 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
def greedy_decode(
|
| 249 |
self,
|
| 250 |
memory: torch.Tensor,
|
|
@@ -256,50 +402,65 @@ class TransformerDecoder(nn.Module):
|
|
| 256 |
min_len: Optional[int] = None,
|
| 257 |
ban_token_ids: Optional[List[int]] = None,
|
| 258 |
no_repeat_ngram_size: int = 0,
|
|
|
|
| 259 |
memory_mask: Optional[torch.Tensor] = None,
|
| 260 |
) -> torch.Tensor:
|
| 261 |
"""
|
| 262 |
-
|
| 263 |
-
Not optimized (recomputes full decoder each step) but simple and correct.
|
| 264 |
"""
|
| 265 |
if device is None:
|
| 266 |
device = memory.device
|
| 267 |
B = memory.size(0)
|
|
|
|
|
|
|
| 268 |
generated = torch.full((B, 1), start_token_id, dtype=torch.long, device=device)
|
| 269 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
min_len = 0 if min_len is None else max(0, min_len)
|
| 271 |
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
generated, memory, collect_attn=False, memory_mask=memory_mask
|
| 275 |
-
) # (B, L, V)
|
| 276 |
-
assert isinstance(logits, torch.Tensor) # type narrowing
|
| 277 |
-
next_step_logits = logits[:, -1, :]
|
| 278 |
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
should_clone = True
|
| 283 |
-
if ban_token_ids:
|
| 284 |
-
should_clone = True
|
| 285 |
|
| 286 |
-
#
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
pass
|
| 290 |
|
| 291 |
-
|
| 292 |
-
next_step_logits = next_step_logits.clone()
|
| 293 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
if end_token_id is not None and generated.size(1) < max(1, min_len):
|
| 295 |
next_step_logits[:, end_token_id] = float("-inf")
|
| 296 |
|
| 297 |
if ban_token_ids:
|
| 298 |
next_step_logits[:, ban_token_ids] = float("-inf")
|
| 299 |
|
|
|
|
| 300 |
if no_repeat_ngram_size > 0:
|
| 301 |
-
# Calculate banned tokens based on n-grams
|
| 302 |
for b in range(B):
|
|
|
|
|
|
|
| 303 |
gen_seq = generated[b].tolist()
|
| 304 |
if len(gen_seq) < no_repeat_ngram_size - 1:
|
| 305 |
continue
|
|
@@ -307,28 +468,27 @@ class TransformerDecoder(nn.Module):
|
|
| 307 |
prefix = tuple(gen_seq[-(no_repeat_ngram_size - 1) :])
|
| 308 |
banned_for_this_batch = set()
|
| 309 |
|
| 310 |
-
# Scan history for prefix
|
| 311 |
for i in range(len(gen_seq) - no_repeat_ngram_size + 1):
|
| 312 |
window = tuple(gen_seq[i : i + no_repeat_ngram_size - 1])
|
| 313 |
if window == prefix:
|
| 314 |
-
# The token that followed this instance of prefix
|
| 315 |
if i + no_repeat_ngram_size - 1 < len(gen_seq):
|
| 316 |
banned_for_this_batch.add(gen_seq[i + no_repeat_ngram_size - 1])
|
| 317 |
|
| 318 |
if banned_for_this_batch:
|
| 319 |
-
if not should_clone:
|
| 320 |
-
next_step_logits = next_step_logits.clone()
|
| 321 |
-
should_clone = True
|
| 322 |
next_step_logits[b, list(banned_for_this_batch)] = float("-inf")
|
| 323 |
|
|
|
|
| 324 |
next_token = next_step_logits.argmax(dim=-1, keepdim=True) # (B, 1)
|
|
|
|
|
|
|
| 325 |
generated = torch.cat([generated, next_token], dim=1)
|
| 326 |
|
|
|
|
| 327 |
if end_token_id is not None:
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
|
| 333 |
return generated
|
| 334 |
|
|
@@ -337,7 +497,7 @@ class TransformerDecoder(nn.Module):
|
|
| 337 |
# -----------------------------
|
| 338 |
def step(
|
| 339 |
self,
|
| 340 |
-
last_token_ids: torch.
|
| 341 |
memory: torch.Tensor,
|
| 342 |
cache: Optional[Dict] = None,
|
| 343 |
) -> Tuple[torch.Tensor, Dict]:
|
|
@@ -361,18 +521,33 @@ class TransformerDecoder(nn.Module):
|
|
| 361 |
past_len = int(cache.get("past_length", 0))
|
| 362 |
|
| 363 |
# 1) Embed last token and add positional encoding for position `past_len`
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
if
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
|
| 377 |
# We will update new_cache incrementally
|
| 378 |
new_cache = dict(cache) # shallow copy
|
|
@@ -388,6 +563,23 @@ class TransformerDecoder(nn.Module):
|
|
| 388 |
elif memory_mask.dim() == 3:
|
| 389 |
memory_mask = memory_mask.unsqueeze(1)
|
| 390 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 391 |
# Iterate layers, updating caches and computing output for current token only
|
| 392 |
layer_input = x # (B,1,d_model)
|
| 393 |
for i, layer in enumerate(self.layers):
|
|
@@ -430,7 +622,7 @@ class TransformerDecoder(nn.Module):
|
|
| 430 |
# mask=True means attend.
|
| 431 |
step_mask = torch.ones(B_, 1, 1, K_all.size(2), dtype=torch.bool, device=device)
|
| 432 |
attn_out_heads, self_attn_w = layer.self_attn.attention(
|
| 433 |
-
Qh, K_all, V_all, mask=step_mask
|
| 434 |
)
|
| 435 |
# attn_out_heads: (B, H, 1, d_k)
|
| 436 |
# concat heads, project out
|
|
@@ -472,7 +664,7 @@ class TransformerDecoder(nn.Module):
|
|
| 472 |
) # (B,H,1,d_k)
|
| 473 |
|
| 474 |
cross_out_heads, cross_attn_w = layer.cross_attn.attention(
|
| 475 |
-
Qch, mem_k, mem_v, mask=memory_mask
|
| 476 |
)
|
| 477 |
cross_out = (
|
| 478 |
cross_out_heads.transpose(1, 2)
|
|
|
|
| 13 |
- RMSNorm is just simpler than LayerNorm and more computationally efficient, it's become the modern convention. These reasons are why I used it here.
|
| 14 |
"""
|
| 15 |
|
| 16 |
+
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
|
|
|
| 17 |
|
| 18 |
import torch
|
| 19 |
import torch.nn as nn
|
| 20 |
|
| 21 |
+
from .attention import MultiHeadAttention, T5RelativePositionBias
|
| 22 |
from .feedforward import FeedForward
|
| 23 |
+
from .positional_encoding import LearnedPositionalEncoding, PositionalEncoding
|
| 24 |
|
| 25 |
|
| 26 |
def create_causal_mask(seq_len: int, device: Optional[torch.device] = None) -> torch.Tensor:
|
|
|
|
| 49 |
d_ff: int,
|
| 50 |
dropout: float = 0.1,
|
| 51 |
quantization: Optional[str] = None,
|
| 52 |
+
activation: Literal["gelu", "relu", "swiglu", "gated-gelu"] = "gated-gelu",
|
| 53 |
+
scale_attn_scores: bool = True, # T5 uses False
|
| 54 |
):
|
| 55 |
super().__init__()
|
| 56 |
# use internal MHA dropout = 0.0; the layer handles dropout after sublayers
|
| 57 |
self.self_attn = MultiHeadAttention(
|
| 58 |
+
d_model=d_model,
|
| 59 |
+
num_heads=num_heads,
|
| 60 |
+
dropout=0.0,
|
| 61 |
+
quantization=quantization,
|
| 62 |
+
scale_scores=scale_attn_scores,
|
| 63 |
)
|
| 64 |
self.cross_attn = MultiHeadAttention(
|
| 65 |
+
d_model=d_model,
|
| 66 |
+
num_heads=num_heads,
|
| 67 |
+
dropout=0.0,
|
| 68 |
+
quantization=quantization,
|
| 69 |
+
scale_scores=scale_attn_scores,
|
| 70 |
)
|
| 71 |
self.ffn = FeedForward(
|
| 72 |
+
d_model=d_model,
|
| 73 |
+
d_ff=d_ff,
|
| 74 |
+
dropout=dropout,
|
| 75 |
+
activation=activation,
|
| 76 |
+
quantization=quantization,
|
| 77 |
)
|
| 78 |
|
| 79 |
self.norm1 = nn.RMSNorm(d_model)
|
|
|
|
| 91 |
tgt_mask: Optional[torch.Tensor] = None,
|
| 92 |
memory_mask: Optional[torch.Tensor] = None,
|
| 93 |
collect_attn: bool = False,
|
| 94 |
+
self_attn_position_bias: Optional[torch.Tensor] = None,
|
| 95 |
+
cross_attn_position_bias: Optional[torch.Tensor] = None,
|
| 96 |
) -> Tuple[torch.Tensor, Dict[str, Optional[torch.Tensor]]]:
|
| 97 |
"""
|
| 98 |
Args:
|
|
|
|
| 101 |
tgt_mask: optional mask for self-attn - shape (B, T, T) or (B, 1, T, T)
|
| 102 |
memory_mask: optional mask for cross-attn - shape (B, S) or (B, 1, S) or (B, 1, T, S)
|
| 103 |
collect_attn: whether to return attention weights
|
| 104 |
+
self_attn_position_bias: optional T5 relative position bias for self-attention
|
| 105 |
+
cross_attn_position_bias: optional T5 relative position bias for cross-attention
|
| 106 |
|
| 107 |
Returns:
|
| 108 |
(tgt_out, {"self": self_attn_weights, "cross": cross_attn_weights})
|
|
|
|
| 123 |
# --- Masked self-attention (Pre-LN) ---
|
| 124 |
x_norm = self.norm1(tgt)
|
| 125 |
self_out, self_attn = self.self_attn(
|
| 126 |
+
x_norm,
|
| 127 |
+
x_norm,
|
| 128 |
+
x_norm,
|
| 129 |
+
tgt_mask,
|
| 130 |
+
return_attn_weights=collect_attn,
|
| 131 |
+
position_bias=self_attn_position_bias,
|
| 132 |
)
|
| 133 |
tgt = tgt + self.dropout1(self_out)
|
| 134 |
|
| 135 |
+
# Clamp inf values for fp16/bf16 training stability (like HuggingFace T5)
|
| 136 |
+
if tgt.dtype == torch.float16 or tgt.dtype == torch.bfloat16:
|
| 137 |
+
clamp_value = torch.finfo(tgt.dtype).max - 1000
|
| 138 |
+
tgt = torch.clamp(tgt, min=-clamp_value, max=clamp_value)
|
| 139 |
+
|
| 140 |
# --- Cross-attention (Pre-LN) ---
|
| 141 |
x_norm = self.norm2(tgt)
|
| 142 |
cross_out, cross_attn = self.cross_attn(
|
| 143 |
+
x_norm,
|
| 144 |
+
memory,
|
| 145 |
+
memory,
|
| 146 |
+
memory_mask,
|
| 147 |
+
return_attn_weights=collect_attn,
|
| 148 |
+
position_bias=cross_attn_position_bias,
|
| 149 |
)
|
| 150 |
tgt = tgt + self.dropout2(cross_out)
|
| 151 |
|
| 152 |
+
# Clamp inf values for fp16/bf16 training stability
|
| 153 |
+
if tgt.dtype == torch.float16 or tgt.dtype == torch.bfloat16:
|
| 154 |
+
clamp_value = torch.finfo(tgt.dtype).max - 1000
|
| 155 |
+
tgt = torch.clamp(tgt, min=-clamp_value, max=clamp_value)
|
| 156 |
+
|
| 157 |
# --- Feed-forward (Pre-LN) ---
|
| 158 |
x_norm = self.norm3(tgt)
|
| 159 |
ffn_out = self.ffn(x_norm)
|
| 160 |
tgt = tgt + self.dropout3(ffn_out)
|
| 161 |
|
| 162 |
+
# Clamp inf values for fp16/bf16 training stability
|
| 163 |
+
if tgt.dtype == torch.float16 or tgt.dtype == torch.bfloat16:
|
| 164 |
+
clamp_value = torch.finfo(tgt.dtype).max - 1000
|
| 165 |
+
tgt = torch.clamp(tgt, min=-clamp_value, max=clamp_value)
|
| 166 |
+
|
| 167 |
return tgt, {"self": self_attn, "cross": cross_attn}
|
| 168 |
|
| 169 |
|
|
|
|
| 185 |
max_len: int = 512,
|
| 186 |
pad_token_id: Optional[int] = None,
|
| 187 |
quantization: Optional[str] = None,
|
| 188 |
+
use_learned_pos_enc: bool = False,
|
| 189 |
+
activation: Literal["gelu", "relu", "swiglu", "gated-gelu"] = "gated-gelu",
|
| 190 |
+
use_relative_position_bias: bool = False, # T5-style relative position bias
|
| 191 |
):
|
| 192 |
super().__init__()
|
| 193 |
self.vocab_size = vocab_size
|
| 194 |
self.d_model = d_model
|
| 195 |
self.pad_token_id = pad_token_id
|
| 196 |
+
self.num_heads = num_heads
|
| 197 |
+
self.use_relative_position_bias = use_relative_position_bias
|
| 198 |
+
|
| 199 |
+
self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
|
| 200 |
+
|
| 201 |
+
# Positional encoding (disabled when using relative position bias for T5)
|
| 202 |
+
self.self_relative_position_bias: Optional[T5RelativePositionBias] = None
|
| 203 |
+
self.cross_relative_position_bias: Optional[T5RelativePositionBias] = None
|
| 204 |
+
if use_relative_position_bias:
|
| 205 |
+
# T5 uses relative position bias instead of absolute positional embeddings
|
| 206 |
+
self.pos_encoder = None
|
| 207 |
+
# Self-attention position bias (decoder is causal, so is_decoder=True)
|
| 208 |
+
self.self_relative_position_bias = T5RelativePositionBias(
|
| 209 |
+
num_heads=num_heads,
|
| 210 |
+
num_buckets=32,
|
| 211 |
+
max_distance=128,
|
| 212 |
+
is_decoder=True,
|
| 213 |
+
)
|
| 214 |
+
# T5 cross-attention does NOT use position bias
|
| 215 |
+
elif use_learned_pos_enc:
|
| 216 |
+
self.pos_encoder = LearnedPositionalEncoding(
|
| 217 |
+
d_model=d_model, max_len=max_len + 2, dropout=dropout
|
| 218 |
+
)
|
| 219 |
+
else:
|
| 220 |
+
self.pos_encoder = PositionalEncoding(d_model=d_model, max_len=max_len, dropout=dropout)
|
| 221 |
|
| 222 |
+
# T5 does NOT scale attention scores by sqrt(d_k), others do
|
| 223 |
+
scale_attn_scores = not use_relative_position_bias
|
| 224 |
|
| 225 |
self.layers = nn.ModuleList(
|
| 226 |
[
|
|
|
|
| 230 |
d_ff=d_ff,
|
| 231 |
dropout=dropout,
|
| 232 |
quantization=quantization,
|
| 233 |
+
activation=activation,
|
| 234 |
+
scale_attn_scores=scale_attn_scores,
|
| 235 |
)
|
| 236 |
for _ in range(num_layers)
|
| 237 |
]
|
|
|
|
| 244 |
def _build_padding_mask_from_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 245 |
"""
|
| 246 |
Convert input ids to (B, T, T) boolean mask where True = allowed.
|
| 247 |
+
|
| 248 |
+
Note: For T5, pad_token_id=0 is also used as decoder_start_token_id.
|
| 249 |
+
During generation, we should NOT mask the start token. The caller should
|
| 250 |
+
provide an explicit mask or set tgt_mask to avoid this issue.
|
| 251 |
"""
|
| 252 |
assert self.pad_token_id is not None, "pad_token_id must be set to build mask from ids"
|
| 253 |
pad_mask = input_ids != self.pad_token_id # (B, T)
|
|
|
|
| 261 |
tgt_mask: Optional[torch.Tensor] = None,
|
| 262 |
memory_mask: Optional[torch.Tensor] = None,
|
| 263 |
collect_attn: bool = False,
|
| 264 |
+
skip_padding_mask: bool = False, # Set True during generation to avoid masking start token
|
| 265 |
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[Dict[str, torch.Tensor]]]]:
|
| 266 |
"""
|
| 267 |
Args:
|
|
|
|
| 269 |
memory: (B, S, d_model)
|
| 270 |
tgt_mask: optional; if None, will create (causal [+ padding if ids available])
|
| 271 |
memory_mask: optional; if provided as (B, S) will be expanded to (B, 1, 1, S)
|
| 272 |
+
skip_padding_mask: if True, only use causal mask (for generation where start_token=pad_token)
|
| 273 |
"""
|
| 274 |
# Prepare embeddings
|
| 275 |
if inputs.dim() == 2: # token ids
|
| 276 |
+
# T5/FLAN-T5 does NOT scale embeddings by sqrt(d_model)
|
| 277 |
+
x = self.embedding(inputs)
|
| 278 |
elif inputs.dim() == 3:
|
| 279 |
x = inputs
|
| 280 |
else:
|
| 281 |
raise ValueError("inputs must be (B, T) token ids or (B, T, d_model) embeddings")
|
| 282 |
|
| 283 |
+
# Apply positional encoding if not using relative position bias
|
| 284 |
+
# (T5 uses relative position bias in attention instead of absolute positional embeddings)
|
| 285 |
+
if self.pos_encoder is not None:
|
| 286 |
+
x = self.pos_encoder(x)
|
| 287 |
x = self.input_dropout(x)
|
| 288 |
|
| 289 |
B, T, _ = x.shape
|
|
|
|
| 291 |
# Build target mask if not provided: combine causal + padding (if available)
|
| 292 |
if tgt_mask is None:
|
| 293 |
causal = create_causal_mask(T, device=x.device) # (T, T)
|
| 294 |
+
if inputs.dim() == 2 and self.pad_token_id is not None and not skip_padding_mask:
|
| 295 |
+
# During training: combine causal mask with padding mask
|
| 296 |
pad_pairwise = self._build_padding_mask_from_ids(inputs) # (B, T, T)
|
| 297 |
combined = pad_pairwise & causal.unsqueeze(0) # (B, T, T)
|
| 298 |
tgt_mask = combined.unsqueeze(1) # (B, 1, T, T) -> broadcast to heads
|
| 299 |
else:
|
| 300 |
+
# During generation (skip_padding_mask=True) or no padding info:
|
| 301 |
+
# Use only causal mask - don't mask based on token values
|
| 302 |
tgt_mask = causal.unsqueeze(0).unsqueeze(1) # (1, 1, T, T)
|
| 303 |
else:
|
| 304 |
# Ensure boolean and device alignment; accept (B, T, T) or (B,1,T,T) or (1,1,T,T)
|
|
|
|
| 314 |
|
| 315 |
attn_list: List[Dict[str, torch.Tensor]] = []
|
| 316 |
|
| 317 |
+
# Compute relative position biases (T5-style)
|
| 318 |
+
# Note: T5 uses relative position bias for self-attention but NOT for cross-attention
|
| 319 |
+
if self.use_relative_position_bias and self.self_relative_position_bias is not None:
|
| 320 |
+
self_position_bias = self.self_relative_position_bias(
|
| 321 |
+
T, T, x.device
|
| 322 |
+
) # (1, num_heads, T, T)
|
| 323 |
+
else:
|
| 324 |
+
self_position_bias = None
|
| 325 |
+
# Cross-attention position bias is None for T5 (see T5 paper/implementation)
|
| 326 |
+
cross_position_bias = None
|
| 327 |
+
|
| 328 |
# Pass through decoder layers
|
| 329 |
for layer in self.layers:
|
| 330 |
x, attn = layer(
|
| 331 |
+
x,
|
| 332 |
+
memory,
|
| 333 |
+
tgt_mask=tgt_mask,
|
| 334 |
+
memory_mask=memory_mask,
|
| 335 |
+
collect_attn=collect_attn,
|
| 336 |
+
self_attn_position_bias=self_position_bias,
|
| 337 |
+
cross_attn_position_bias=cross_position_bias,
|
| 338 |
)
|
| 339 |
if collect_attn:
|
| 340 |
attn_list.append(attn)
|
|
|
|
| 346 |
return logits, attn_list
|
| 347 |
return logits
|
| 348 |
|
| 349 |
+
def greedy_decode_naive(
|
| 350 |
+
self,
|
| 351 |
+
memory: torch.Tensor,
|
| 352 |
+
max_len: int,
|
| 353 |
+
start_token_id: int,
|
| 354 |
+
end_token_id: Optional[int] = None,
|
| 355 |
+
device: Optional[torch.device] = None,
|
| 356 |
+
memory_mask: Optional[torch.Tensor] = None,
|
| 357 |
+
) -> torch.Tensor:
|
| 358 |
+
"""
|
| 359 |
+
Naive greedy decoding using full forward pass (O(N^2) but simpler).
|
| 360 |
+
Used for debugging to verify step() correctness.
|
| 361 |
+
"""
|
| 362 |
+
if device is None:
|
| 363 |
+
device = memory.device
|
| 364 |
+
B = memory.size(0)
|
| 365 |
+
|
| 366 |
+
# Initialize with start token
|
| 367 |
+
generated = torch.full((B, 1), start_token_id, dtype=torch.long, device=device)
|
| 368 |
+
|
| 369 |
+
for _ in range(max_len - 1):
|
| 370 |
+
# Full forward pass on entire generated sequence
|
| 371 |
+
# skip_padding_mask=True because start_token=pad_token for T5
|
| 372 |
+
logits = self.forward(
|
| 373 |
+
generated, memory, memory_mask=memory_mask, skip_padding_mask=True
|
| 374 |
+
)
|
| 375 |
+
if isinstance(logits, tuple):
|
| 376 |
+
logits = logits[0]
|
| 377 |
+
# logits: (B, T, vocab)
|
| 378 |
+
|
| 379 |
+
# Get logits for last position
|
| 380 |
+
next_logits = logits[:, -1, :] # (B, vocab)
|
| 381 |
+
|
| 382 |
+
# Greedy: pick highest probability token
|
| 383 |
+
next_token = next_logits.argmax(dim=-1, keepdim=True) # (B, 1)
|
| 384 |
+
|
| 385 |
+
# Append to generated
|
| 386 |
+
generated = torch.cat([generated, next_token], dim=1)
|
| 387 |
+
|
| 388 |
+
# Check for EOS
|
| 389 |
+
if end_token_id is not None and (next_token == end_token_id).all():
|
| 390 |
+
break
|
| 391 |
+
|
| 392 |
+
return generated
|
| 393 |
+
|
     def greedy_decode(
         self,
         memory: torch.Tensor,
...
         min_len: Optional[int] = None,
         ban_token_ids: Optional[List[int]] = None,
         no_repeat_ngram_size: int = 0,
+        repetition_penalty: float = 1.0,
         memory_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """
+        Greedy decoding with KV caching for O(N) complexity.
         """
         if device is None:
             device = memory.device
         B = memory.size(0)
+
+        # Initialize generated sequence with the start token
         generated = torch.full((B, 1), start_token_id, dtype=torch.long, device=device)

+        # Initialize cache
+        cache: Dict[str, Any] = {"past_length": 0}
+        if memory_mask is not None:
+            cache["memory_mask"] = memory_mask
+
         min_len = 0 if min_len is None else max(0, min_len)

+        # Keep track of finished sequences
+        finished = torch.zeros(B, dtype=torch.bool, device=device)

+        for _ in range(max_len - 1):
+            # Use the last generated token for the next step
+            last_token = generated[:, -1:]  # (B, 1)

+            # Run one step of the decoder
+            logits, cache = self.step(last_token, memory, cache)
+            # logits: (B, vocab_size)

+            next_step_logits = logits.clone()

+            # Apply repetition penalty
+            if repetition_penalty != 1.0:
+                for b in range(B):
+                    if finished[b]:
+                        continue
+                    gen_seq = generated[b]
+                    unique_tokens = torch.unique(gen_seq)
+                    current_logits = next_step_logits[b, unique_tokens]
+                    next_step_logits[b, unique_tokens] = torch.where(
+                        current_logits < 0,
+                        current_logits * repetition_penalty,
+                        current_logits / repetition_penalty,
+                    )
+
+            # Apply constraints
             if end_token_id is not None and generated.size(1) < max(1, min_len):
                 next_step_logits[:, end_token_id] = float("-inf")

             if ban_token_ids:
                 next_step_logits[:, ban_token_ids] = float("-inf")

+            # N-gram repetition blocking
             if no_repeat_ngram_size > 0:
                 for b in range(B):
+                    if finished[b]:
+                        continue
                     gen_seq = generated[b].tolist()
                     if len(gen_seq) < no_repeat_ngram_size - 1:
                         continue

                     prefix = tuple(gen_seq[-(no_repeat_ngram_size - 1) :])
                     banned_for_this_batch = set()

                     for i in range(len(gen_seq) - no_repeat_ngram_size + 1):
                         window = tuple(gen_seq[i : i + no_repeat_ngram_size - 1])
                         if window == prefix:
                             if i + no_repeat_ngram_size - 1 < len(gen_seq):
                                 banned_for_this_batch.add(gen_seq[i + no_repeat_ngram_size - 1])

                     if banned_for_this_batch:
                         next_step_logits[b, list(banned_for_this_batch)] = float("-inf")

+            # Greedy selection
             next_token = next_step_logits.argmax(dim=-1, keepdim=True)  # (B, 1)
+
+            # Update generated sequence
             generated = torch.cat([generated, next_token], dim=1)

+            # Check for completion
             if end_token_id is not None:
+                is_end = next_token.squeeze(-1) == end_token_id
+                finished = finished | is_end
+                if finished.all() and generated.size(1) >= max(1, min_len):
+                    break

         return generated

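A note on the repetition penalty added above: it is sign-aware, dividing positive logits and multiplying negative ones, so a token already in the output always becomes less likely regardless of its logit's sign. A list-based sketch of that rule (plain floats stand in for the logits tensor; `apply_repetition_penalty` is illustrative, not the project's API):

```python
def apply_repetition_penalty(logits, seen_tokens, penalty):
    """Sign-aware penalty: logit / penalty if positive, logit * penalty if negative."""
    out = list(logits)
    for tok in set(seen_tokens):
        out[tok] = out[tok] * penalty if out[tok] < 0 else out[tok] / penalty
    return out

# token 0: 4.0 -> 2.0 (less likely); token 1: -2.0 -> -4.0 (also less likely)
print(apply_repetition_penalty([4.0, -2.0, 1.0], seen_tokens=[0, 1], penalty=2.0))
```

Naively dividing by the penalty in both cases would make negative logits *more* likely, which is exactly what the `torch.where` branch in `greedy_decode` avoids.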
     # -----------------------------
     def step(
         self,
+        last_token_ids: torch.Tensor,
         memory: torch.Tensor,
         cache: Optional[Dict] = None,
     ) -> Tuple[torch.Tensor, Dict]:
...
         past_len = int(cache.get("past_length", 0))

         # 1) Embed the last token and add positional encoding for position `past_len`
+        # T5/FLAN-T5 does NOT scale embeddings by sqrt(d_model)
+        x = self.embedding(last_token_ids)  # (B, 1, d)
+
+        # Handle positional encoding for a single step
+        # Note: when using relative position bias (T5-style), pos_encoder is None
+        if self.pos_encoder is not None:
+            if hasattr(self.pos_encoder, "pe"):
+                # Sinusoidal: use the buffer directly
+                pe = self.pos_encoder.pe  # (1, max_len, d_model)
+                pos_idx = past_len
+                if pos_idx >= pe.size(1):
+                    raise RuntimeError(f"pos_idx {pos_idx} exceeds max_len {pe.size(1)}")
+                x = x + pe[:, pos_idx : pos_idx + 1, :].to(device)
+            elif hasattr(self.pos_encoder, "embeddings"):
+                # Learned: look up the specific position
+                # Create position ids: [past_len]
+                pos_idx = torch.tensor([past_len], dtype=torch.long, device=device)
+                # Look up the embedding: (1, d_model)
+                pos_emb = self.pos_encoder.embeddings(pos_idx)
+                # Add to input: (B, 1, d_model) + (1, 1, d_model) broadcast
+                x = x + pos_emb.unsqueeze(0)
+                x = self.pos_encoder.dropout(x)
+            else:
+                # Fallback: call pos_encoder (likely incorrect for step-by-step if it assumes pos 0)
+                x = self.pos_encoder(x)
+        # When pos_encoder is None (relative position bias mode), we skip positional encoding;
+        # the position information is provided via relative_position_bias in attention.

         # We will update new_cache incrementally
         new_cache = dict(cache)  # shallow copy
...
         elif memory_mask.dim() == 3:
             memory_mask = memory_mask.unsqueeze(1)

+        # Compute position biases for the incremental step (T5-style).
+        # For step mode: query_length=1, but the actual position is past_len.
+        # Self-attention: the query at position past_len attends to keys at positions 0..past_len.
+        # Note: T5 uses relative position bias for self-attention but NOT for cross-attention.
+        if self.use_relative_position_bias and self.self_relative_position_bias is not None:
+            # Self-attention bias: query_length=1, key_length=past_len+1, offset=past_len
+            self_position_bias = self.self_relative_position_bias(
+                query_length=1,
+                key_length=past_len + 1,
+                device=device,
+                query_position_offset=past_len,
+            )  # (1, num_heads, 1, past_len+1)
+        else:
+            self_position_bias = None
+        # Cross-attention position bias is None for T5 (see the T5 paper/implementation)
+        cross_position_bias = None
+
         # Iterate layers, updating caches and computing output for the current token only
         layer_input = x  # (B, 1, d_model)
         for i, layer in enumerate(self.layers):
...
             # mask=True means attend.
             step_mask = torch.ones(B_, 1, 1, K_all.size(2), dtype=torch.bool, device=device)
             attn_out_heads, self_attn_w = layer.self_attn.attention(
+                Qh, K_all, V_all, mask=step_mask, position_bias=self_position_bias
             )
             # attn_out_heads: (B, H, 1, d_k)
             # concat heads, project out
...
             )  # (B, H, 1, d_k)

             cross_out_heads, cross_attn_w = layer.cross_attn.attention(
+                Qch, mem_k, mem_v, mask=memory_mask, position_bias=cross_position_bias
             )
             cross_out = (
                 cross_out_heads.transpose(1, 2)
src/models/encoder.py
CHANGED

@@ -14,16 +14,15 @@ Design choices:
 - Optionally collect attention weights by passing collect_attn=True to forward().
 """

-import
-from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn as nn

 # Encoder implementation
-from .attention import MultiHeadAttention
 from .feedforward import FeedForward
-from .positional_encoding import PositionalEncoding


 class TransformerEncoderLayer(nn.Module):

@@ -36,6 +35,8 @@ class TransformerEncoderLayer(nn.Module):
         d_ff: hidden dimension of the position-wise feed-forward network
         dropout: dropout probability applied to sublayer outputs
         quantization: optional quantization mode ("4bit", "8bit")
     """

     def __init__(

@@ -45,14 +46,24 @@ class TransformerEncoderLayer(nn.Module):
         d_ff: int,
         dropout: float = 0.1,
         quantization: Optional[str] = None,
     ):
         super().__init__()
         self.self_attn = MultiHeadAttention(
-            d_model=d_model,
         )
         # set MHA internal dropout to 0.0 and use dropout1/dropout2 in the layer
         self.ffn = FeedForward(
-            d_model=d_model,
         )

         self.norm1 = nn.RMSNorm(d_model)

@@ -66,6 +77,7 @@ class TransformerEncoderLayer(nn.Module):
         x: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
         collect_attn: bool = False,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
         """
         Forward pass for the encoder layer.

@@ -74,6 +86,7 @@ class TransformerEncoderLayer(nn.Module):
             x: (batch, seq_len, d_model) - input embeddings / representations
             mask: optional attention mask, shape either (batch, seq_q, seq_k) or (batch, 1, seq_q, seq_k)
             collect_attn: whether to return attention weights

         Returns:
             x: (batch, seq_len, d_model)

@@ -83,15 +96,30 @@ class TransformerEncoderLayer(nn.Module):
         x_norm = self.norm1(x)  # Pre-LN
         # self_attn expects query, key, value; for the encoder they are the same
         attn_out, attn_weights = self.self_attn(
-            x_norm,
         )
         x = x + self.dropout1(attn_out)

         # Feed-forward sublayer (Pre-LN)
         x_norm = self.norm2(x)
         ffn_out = self.ffn(x_norm)
         x = x + self.dropout2(ffn_out)

         # Return output (and optionally attn_weights if the caller wants to collect them)
         return x, attn_weights

@@ -123,17 +151,40 @@ class TransformerEncoder(nn.Module):
         max_len: int = 512,
         pad_token_id: Optional[int] = None,
         quantization: Optional[str] = None,
     ):
         super().__init__()
         self.vocab_size = vocab_size
         self.d_model = d_model
         self.pad_token_id = pad_token_id

         # Token embedding (only used if forward receives token ids)
-        self.embedding = nn.Embedding(vocab_size, d_model)

         # Encoder layers stack
         self.layers = nn.ModuleList(

@@ -144,6 +195,8 @@ class TransformerEncoder(nn.Module):
                     d_ff=d_ff,
                     dropout=dropout,
                     quantization=quantization,
                 )
                 for _ in range(num_layers)
             ]

@@ -197,16 +250,20 @@ class TransformerEncoder(nn.Module):
         if inputs.dim() == 2:  # token ids
             if self.embedding is None:
                 raise ValueError("Encoder was not constructed with an embedding layer.")
         elif inputs.dim() == 3:  # already embeddings
             x = inputs
         else:
             raise ValueError(
                 "inputs must be (batch, seq) token ids or (batch, seq, d_model) embeddings"
             )

-        # Positional encoding + dropout
         x = self.input_dropout(x)

         # Build mask if needed

@@ -217,11 +274,16 @@ class TransformerEncoder(nn.Module):
         if mask is not None:
             mask = mask.to(dtype=torch.bool, device=x.device)

         attn_weights_per_layer: List[torch.Tensor] = []

         # Pass through each encoder layer (optionally collect attn)
         for layer in self.layers:
-            x, attn = layer(x, mask=mask, collect_attn=collect_attn)
             if collect_attn:
                 attn_weights_per_layer.append(attn)

Updated file:

 - Optionally collect attention weights by passing collect_attn=True to forward().
 """

+from typing import List, Literal, Optional, Tuple, Union

 import torch
 import torch.nn as nn

 # Encoder implementation
+from .attention import MultiHeadAttention, T5RelativePositionBias
 from .feedforward import FeedForward
+from .positional_encoding import LearnedPositionalEncoding, PositionalEncoding


 class TransformerEncoderLayer(nn.Module):
...
         d_ff: hidden dimension of the position-wise feed-forward network
         dropout: dropout probability applied to sublayer outputs
         quantization: optional quantization mode ("4bit", "8bit")
+        activation: activation function for the FFN ("gelu", "relu", "swiglu", or "gated-gelu")
+        scale_attn_scores: whether to scale attention scores by sqrt(d_k). T5 does NOT scale.
     """

     def __init__(
...
         d_ff: int,
         dropout: float = 0.1,
         quantization: Optional[str] = None,
+        activation: Literal["gelu", "relu", "swiglu", "gated-gelu"] = "gated-gelu",
+        scale_attn_scores: bool = True,  # T5 uses False
     ):
         super().__init__()
         self.self_attn = MultiHeadAttention(
+            d_model=d_model,
+            num_heads=num_heads,
+            dropout=0.0,
+            quantization=quantization,
+            scale_scores=scale_attn_scores,
         )
         # set MHA internal dropout to 0.0 and use dropout1/dropout2 in the layer
         self.ffn = FeedForward(
+            d_model=d_model,
+            d_ff=d_ff,
+            dropout=dropout,
+            activation=activation,
+            quantization=quantization,
         )

         self.norm1 = nn.RMSNorm(d_model)
...
         x: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
         collect_attn: bool = False,
+        position_bias: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
         """
         Forward pass for the encoder layer.
...
             x: (batch, seq_len, d_model) - input embeddings / representations
             mask: optional attention mask, shape either (batch, seq_q, seq_k) or (batch, 1, seq_q, seq_k)
             collect_attn: whether to return attention weights
+            position_bias: optional (1, num_heads, seq_q, seq_k) T5-style relative position bias

         Returns:
             x: (batch, seq_len, d_model)
...
         x_norm = self.norm1(x)  # Pre-LN
         # self_attn expects query, key, value; for the encoder they are the same
         attn_out, attn_weights = self.self_attn(
+            x_norm,
+            x_norm,
+            x_norm,
+            mask,
+            return_attn_weights=collect_attn,
+            position_bias=position_bias,
         )
         x = x + self.dropout1(attn_out)

+        # Clamp inf values for fp16/bf16 training stability (like HuggingFace T5)
+        if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
+            clamp_value = torch.finfo(x.dtype).max - 1000
+            x = torch.clamp(x, min=-clamp_value, max=clamp_value)
+
         # Feed-forward sublayer (Pre-LN)
         x_norm = self.norm2(x)
         ffn_out = self.ffn(x_norm)
         x = x + self.dropout2(ffn_out)

+        # Clamp inf values for fp16/bf16 training stability
+        if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
+            clamp_value = torch.finfo(x.dtype).max - 1000
+            x = torch.clamp(x, min=-clamp_value, max=clamp_value)
+
         # Return output (and optionally attn_weights if the caller wants to collect them)
         return x, attn_weights
...
         max_len: int = 512,
         pad_token_id: Optional[int] = None,
         quantization: Optional[str] = None,
+        use_learned_pos_enc: bool = False,
+        activation: Literal["gelu", "relu", "swiglu", "gated-gelu"] = "gated-gelu",
+        use_relative_position_bias: bool = False,  # T5-style relative position bias
     ):
         super().__init__()
         self.vocab_size = vocab_size
         self.d_model = d_model
         self.pad_token_id = pad_token_id
+        self.use_relative_position_bias = use_relative_position_bias

         # Token embedding (only used if forward receives token ids)
+        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
+
+        # Positional encoding (disabled when using relative position bias for T5)
+        self.relative_position_bias: Optional[T5RelativePositionBias] = None
+        if use_relative_position_bias:
+            # T5 uses relative position bias instead of absolute positional embeddings
+            self.pos_encoder = None
+            self.relative_position_bias = T5RelativePositionBias(
+                num_heads=num_heads,
+                num_buckets=32,
+                max_distance=128,
+                is_decoder=False,
+            )
+        elif use_learned_pos_enc:
+            # T5 uses max_len=512 by default; we add a buffer for special tokens
+            self.pos_encoder = LearnedPositionalEncoding(
+                d_model=d_model, max_len=max_len + 2, dropout=dropout
+            )
+        else:
+            self.pos_encoder = PositionalEncoding(d_model=d_model, max_len=max_len, dropout=dropout)

+        # T5 does NOT scale attention scores by sqrt(d_k); other models do
+        scale_attn_scores = not use_relative_position_bias

         # Encoder layers stack
         self.layers = nn.ModuleList(
...
                     d_ff=d_ff,
                     dropout=dropout,
                     quantization=quantization,
+                    activation=activation,
+                    scale_attn_scores=scale_attn_scores,
                 )
                 for _ in range(num_layers)
             ]
...
         if inputs.dim() == 2:  # token ids
             if self.embedding is None:
                 raise ValueError("Encoder was not constructed with an embedding layer.")
+            # T5/FLAN-T5 does NOT scale embeddings by sqrt(d_model)
+            x = self.embedding(inputs)
+            seq_len = inputs.size(1)
         elif inputs.dim() == 3:  # already embeddings
             x = inputs
+            seq_len = inputs.size(1)
         else:
             raise ValueError(
                 "inputs must be (batch, seq) token ids or (batch, seq, d_model) embeddings"
             )

+        # Positional encoding + dropout (only if not using relative position bias)
+        if self.pos_encoder is not None:
+            x = self.pos_encoder(x)
         x = self.input_dropout(x)

         # Build mask if needed
...
         if mask is not None:
             mask = mask.to(dtype=torch.bool, device=x.device)

+        # Compute the relative position bias if using T5-style
+        position_bias = None
+        if self.relative_position_bias is not None:
+            position_bias = self.relative_position_bias(seq_len, seq_len, x.device)
+
         attn_weights_per_layer: List[torch.Tensor] = []

         # Pass through each encoder layer (optionally collect attn)
         for layer in self.layers:
+            x, attn = layer(x, mask=mask, collect_attn=collect_attn, position_bias=position_bias)
             if collect_attn:
                 attn_weights_per_layer.append(attn)

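For reference, the `T5RelativePositionBias` module the encoder now instantiates relies on T5's bucketing of signed relative distances: half the buckets cover exact small offsets, the rest grow logarithmically up to `max_distance`, and (bidirectionally) past and future keys use separate halves. A scalar sketch of that standard T5 scheme (the real module works on tensors and adds a learned per-head embedding per bucket; this is an assumption-level port, not this project's code):

```python
import math

def relative_position_bucket(relative_position, bidirectional=True,
                             num_buckets=32, max_distance=128):
    """Scalar version of T5-style relative position bucketing."""
    bucket = 0
    if bidirectional:
        num_buckets //= 2
        if relative_position > 0:   # "future" keys use the upper half of buckets
            bucket += num_buckets
        relative_position = abs(relative_position)
    else:
        relative_position = -min(relative_position, 0)
    max_exact = num_buckets // 2    # small distances each get their own bucket
    if relative_position < max_exact:
        bucket += relative_position
    else:                           # larger distances share log-spaced buckets
        log_ratio = math.log(relative_position / max_exact) / math.log(max_distance / max_exact)
        bucket += min(max_exact + int(log_ratio * (num_buckets - max_exact)),
                      num_buckets - 1)
    return bucket

print([relative_position_bucket(d) for d in (-2, -1, 0, 1, 2, 16)])
```

Because any distance beyond `max_distance` maps to the last bucket, the bias generalizes to sequences longer than those seen in training, which is why T5 needs no absolute positional embeddings.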
src/models/factory.py
CHANGED

@@ -4,10 +4,10 @@ from __future__ import annotations
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional

 import torch
-from transformers import

 from ..data.tokenization import Tokenizer
 from ..utils.config import load_yaml

@@ -16,20 +16,30 @@ from .encoder import TransformerEncoder
 from .heads import ClassificationHead, LMHead
 from .multitask import MultiTaskModel


 @dataclass
 class ModelConfig:
     """Configuration describing the transformer architecture."""

-    d_model: int =
-    num_encoder_layers: int =
-    num_decoder_layers: int =
-    num_attention_heads: int =
-    ffn_dim: int =
     dropout: float = 0.1
     use_pretrained: bool = False
-    pretrained_model_name: str = "
     quantization: Optional[str] = None  # "4bit" or "8bit"

     def __post_init__(self):
         if self.d_model % self.num_attention_heads != 0:

@@ -63,103 +73,226 @@ def load_model_config(path: Optional[str | Path]) -> ModelConfig:
         ffn_dim=int(data.get("ffn_dim", 2048)),
         dropout=float(data.get("dropout", 0.1)),
         use_pretrained=bool(data.get("use_pretrained", False)),
-        pretrained_model_name=str(data.get("pretrained_model_name", "
         quantization=data.get("quantization", None),
     )


 def _load_pretrained_weights(
     encoder: TransformerEncoder, decoder: TransformerDecoder, model_name: str
 ) -> None:
-    """
     print(f"Loading pretrained weights from {model_name}...")

     # Load encoder weights
     print("Transferring encoder weights...")

     # Load decoder weights
     print("Transferring decoder weights...")

-    for _i, (custom_layer, bart_layer) in enumerate(
-        zip(decoder.layers, bart.decoder.layers, strict=False)
-    ):
         # Self-attention
-        custom_layer.self_attn.W_Q.weight.data.copy_(

         # Cross-attention
-        custom_layer.cross_attn.W_Q.weight.data.copy_(

         # Layer norms
-        custom_layer.norm1.weight.data.copy_(

-    decoder.final_norm.weight.data.copy_(bart.decoder.layernorm_embedding.weight.data)
-    decoder.final_norm.bias.data.copy_(bart.decoder.layernorm_embedding.bias.data)

-    print("Pretrained weights loaded successfully!")


 def _load_llama_weights(

@@ -313,6 +446,17 @@ def build_multitask_model(
     if not isinstance(num_topics, int) or num_topics <= 0:
         raise ValueError("num_topics must be a positive integer")

     encoder = TransformerEncoder(
         vocab_size=tokenizer.vocab_size,
         d_model=cfg.d_model,

@@ -320,9 +464,12 @@ def build_multitask_model(
         num_heads=cfg.num_attention_heads,
         d_ff=cfg.ffn_dim,
         dropout=cfg.dropout,
-        max_len=
         pad_token_id=tokenizer.pad_token_id,
         quantization=cfg.quantization,
     )
     decoder = TransformerDecoder(
         vocab_size=tokenizer.vocab_size,

@@ -331,28 +478,31 @@ def build_multitask_model(
         num_heads=cfg.num_attention_heads,
         d_ff=cfg.ffn_dim,
         dropout=cfg.dropout,
-        max_len=
         pad_token_id=tokenizer.pad_token_id,
         quantization=cfg.quantization,
     )

     # Load pretrained weights if requested (but allow override for inference)
     should_load = cfg.use_pretrained if load_pretrained is None else load_pretrained
     if should_load:
         _load_llama_weights(
             encoder, decoder, cfg.pretrained_model_name, quantization=cfg.quantization
         )
     else:
         _load_pretrained_weights(encoder, decoder, cfg.pretrained_model_name)

-    # NOTE: Weight tying disabled because the current checkpoint was trained without it
-    # For NEW training runs, uncomment this line to enable proper weight tying:
-    # decoder.output_projection.weight = decoder.embedding.weight
-
     model = MultiTaskModel(encoder=encoder, decoder=decoder, decoder_outputs_logits=True)
     model.add_head(
         "summarization",

Updated file:

 from dataclasses import dataclass
 from pathlib import Path
+from typing import Literal, Optional, cast

 import torch
+from transformers import T5ForConditionalGeneration

 from ..data.tokenization import Tokenizer
 from ..utils.config import load_yaml
...
 from .heads import ClassificationHead, LMHead
 from .multitask import MultiTaskModel

+# Type alias for activation functions
+ActivationType = Literal["gelu", "relu", "swiglu", "gated-gelu"]
+

 @dataclass
 class ModelConfig:
     """Configuration describing the transformer architecture."""

+    d_model: int = 768
+    num_encoder_layers: int = 12
+    num_decoder_layers: int = 12
+    num_attention_heads: int = 12
+    ffn_dim: int = 3072
     dropout: float = 0.1
     use_pretrained: bool = False
+    pretrained_model_name: str = "google/flan-t5-base"
     quantization: Optional[str] = None  # "4bit" or "8bit"
+    use_learned_pos_enc: bool = True  # Use learned positional embeddings
+    activation: str = (
+        "gated-gelu"  # "gelu", "relu", "swiglu", or "gated-gelu" (use gated-gelu for T5/FLAN-T5)
+    )
+    use_relative_position_bias: bool = (
+        False  # T5-style relative position bias (use True for T5/FLAN-T5)
+    )

     def __post_init__(self):
         if self.d_model % self.num_attention_heads != 0:
...
         ffn_dim=int(data.get("ffn_dim", 2048)),
         dropout=float(data.get("dropout", 0.1)),
         use_pretrained=bool(data.get("use_pretrained", False)),
+        pretrained_model_name=str(data.get("pretrained_model_name", "google/flan-t5-base")),
         quantization=data.get("quantization", None),
+        use_learned_pos_enc=bool(data.get("use_learned_pos_enc", True)),
+        activation=str(data.get("activation", "gelu")),
+        use_relative_position_bias=bool(data.get("use_relative_position_bias", False)),
     )


 def _load_pretrained_weights(
     encoder: TransformerEncoder, decoder: TransformerDecoder, model_name: str
 ) -> None:
+    """
+    Load pretrained T5/FLAN-T5 weights into the custom encoder/decoder.
+
+    T5 architecture compatibility with our custom Transformer:
+    - T5 uses Pre-LN (RMSNorm before sublayers) ✓ matches our design
+    - T5 uses relative position bias instead of absolute embeddings
+      -> We now load T5's relative position bias weights into our T5RelativePositionBias modules
+      -> This allows exact weight transfer without requiring fine-tuning
+    - T5 uses a gated FFN (wi_0, wi_1, wo) - we use a gated-gelu FFN matching this
+    - T5 attention has no bias, our attention has bias
+      -> We zero-initialize the bias terms
+    """
     print(f"Loading pretrained weights from {model_name}...")
+    t5 = T5ForConditionalGeneration.from_pretrained(model_name)
+
+    # Load shared embeddings (T5 uses shared embeddings for the encoder and decoder)
+    # Note: T5's vocab is padded to a multiple of 128 for efficiency (32100 -> 32128)
+    # Our model uses the tokenizer's actual vocab size, so we only copy the valid tokens
+    print("Transferring shared token embeddings...")
+    shared_embeddings = t5.shared.weight.data
+    our_vocab_size = encoder.embedding.weight.size(0)
+    t5_vocab_size = shared_embeddings.size(0)
+
+    if our_vocab_size != t5_vocab_size:
+        print(f"  Vocab size mismatch: our model={our_vocab_size}, T5={t5_vocab_size}")
+        # Copy only the tokens that exist in both (T5 pads its vocab to a multiple of 128)
+        min_vocab = min(our_vocab_size, t5_vocab_size)
+        print(f"  Copying first {min_vocab} token embeddings...")
+        encoder.embedding.weight.data[:min_vocab].copy_(shared_embeddings[:min_vocab])
+        decoder.embedding.weight.data[:min_vocab].copy_(shared_embeddings[:min_vocab])
+    else:
+        encoder.embedding.weight.data.copy_(shared_embeddings)
+        decoder.embedding.weight.data.copy_(shared_embeddings)
+
+    # Note: T5 uses relative position bias (computed in attention, not absolute embeddings).
+    # We now use T5RelativePositionBias, which is loaded below. The pos_encoder in our model
+    # is still present but adds zero/minimal contribution when relative_position_bias is used.

     # Load encoder weights
     print("Transferring encoder weights...")
+    t5_encoder = t5.encoder
+
+    for custom_layer, t5_layer in zip(encoder.layers, t5_encoder.block, strict=False):
+        t5_self_attn = t5_layer.layer[0].SelfAttention
+        t5_ffn = t5_layer.layer[1].DenseReluDense
+        t5_norm1 = t5_layer.layer[0].layer_norm
+        t5_norm2 = t5_layer.layer[1].layer_norm
+
+        # Self-attention (T5 has no bias in its attention projections)
+        custom_layer.self_attn.W_Q.weight.data.copy_(t5_self_attn.q.weight.data)
+        custom_layer.self_attn.W_K.weight.data.copy_(t5_self_attn.k.weight.data)
+        custom_layer.self_attn.W_V.weight.data.copy_(t5_self_attn.v.weight.data)
+        custom_layer.self_attn.W_O.weight.data.copy_(t5_self_attn.o.weight.data)
+
+        # Zero-initialize bias (T5 doesn't have attention bias)
+        if custom_layer.self_attn.W_Q.bias is not None:
+            custom_layer.self_attn.W_Q.bias.data.zero_()
+            custom_layer.self_attn.W_K.bias.data.zero_()
+            custom_layer.self_attn.W_V.bias.data.zero_()
+            custom_layer.self_attn.W_O.bias.data.zero_()
+
+        # Layer norms (T5 uses RMSNorm like us: just a weight, no bias)
+        custom_layer.norm1.weight.data.copy_(t5_norm1.weight.data)
+        custom_layer.norm2.weight.data.copy_(t5_norm2.weight.data)
+
+        # FFN - T5 uses a gated FFN: wi_0 (gate), wi_1 (up), wo (down)
+        # If our model uses the swiglu activation: linear_gate (gate), linear1 (up), linear2 (down)
+        # If our model uses a standard activation: linear1 (up), linear2 (down) - partial transfer
+        if hasattr(t5_ffn, "wi_0") and hasattr(custom_layer.ffn, "linear_gate"):
+            # Full gated FFN transfer (swiglu mode)
+            custom_layer.ffn.linear_gate.weight.data.copy_(t5_ffn.wi_0.weight.data)
+            custom_layer.ffn.linear1.weight.data.copy_(t5_ffn.wi_1.weight.data)
+            custom_layer.ffn.linear2.weight.data.copy_(t5_ffn.wo.weight.data)
+            if custom_layer.ffn.linear_gate.bias is not None:
+                custom_layer.ffn.linear_gate.bias.data.zero_()
+        elif hasattr(t5_ffn, "wi_1"):
+            # T5 v1.1 / FLAN-T5 gated FFN -> standard FFN (partial transfer)
+            custom_layer.ffn.linear1.weight.data.copy_(t5_ffn.wi_1.weight.data)
+            custom_layer.ffn.linear2.weight.data.copy_(t5_ffn.wo.weight.data)
| 166 |
+
elif hasattr(t5_ffn, "wi"):
|
| 167 |
+
# Original T5 v1.0
|
| 168 |
+
custom_layer.ffn.linear1.weight.data.copy_(t5_ffn.wi.weight.data)
|
| 169 |
+
custom_layer.ffn.linear2.weight.data.copy_(t5_ffn.wo.weight.data)
|
| 170 |
+
|
| 171 |
+
# Zero-initialize FFN bias (T5 doesn't have FFN bias)
|
| 172 |
+
if custom_layer.ffn.linear1.bias is not None:
|
| 173 |
+
custom_layer.ffn.linear1.bias.data.zero_()
|
| 174 |
+
custom_layer.ffn.linear2.bias.data.zero_()
|
| 175 |
+
|
| 176 |
+
# Encoder final norm
|
| 177 |
+
encoder.final_norm.weight.data.copy_(t5_encoder.final_layer_norm.weight.data)
|
| 178 |
+
|
| 179 |
+
# Load encoder relative position bias (T5 stores it only in first layer, shared across all layers)
|
| 180 |
+
if hasattr(encoder, "relative_position_bias") and encoder.relative_position_bias is not None:
|
| 181 |
+
print("Transferring encoder relative position bias...")
|
| 182 |
+
t5_enc_rel_bias = (
|
| 183 |
+
t5_encoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight.data
|
| 184 |
+
)
|
| 185 |
+
encoder.relative_position_bias.relative_attention_bias.weight.data.copy_(t5_enc_rel_bias)
|
| 186 |
|
| 187 |
# Load decoder weights
|
| 188 |
print("Transferring decoder weights...")
|
| 189 |
+
t5_decoder = t5.decoder
|
| 190 |
+
|
| 191 |
+
for custom_layer, t5_layer in zip(decoder.layers, t5_decoder.block, strict=False):
|
| 192 |
+
t5_self_attn = t5_layer.layer[0].SelfAttention
|
| 193 |
+
t5_cross_attn = t5_layer.layer[1].EncDecAttention
|
| 194 |
+
t5_ffn = t5_layer.layer[2].DenseReluDense
|
| 195 |
+
t5_norm1 = t5_layer.layer[0].layer_norm
|
| 196 |
+
t5_norm2 = t5_layer.layer[1].layer_norm
|
| 197 |
+
t5_norm3 = t5_layer.layer[2].layer_norm
|
| 198 |
|
|
|
|
|
|
|
|
|
|
| 199 |
# Self-attention
|
| 200 |
+
custom_layer.self_attn.W_Q.weight.data.copy_(t5_self_attn.q.weight.data)
|
| 201 |
+
custom_layer.self_attn.W_K.weight.data.copy_(t5_self_attn.k.weight.data)
|
| 202 |
+
custom_layer.self_attn.W_V.weight.data.copy_(t5_self_attn.v.weight.data)
|
| 203 |
+
custom_layer.self_attn.W_O.weight.data.copy_(t5_self_attn.o.weight.data)
|
| 204 |
+
|
| 205 |
+
if custom_layer.self_attn.W_Q.bias is not None:
|
| 206 |
+
custom_layer.self_attn.W_Q.bias.data.zero_()
|
| 207 |
+
custom_layer.self_attn.W_K.bias.data.zero_()
|
| 208 |
+
custom_layer.self_attn.W_V.bias.data.zero_()
|
| 209 |
+
custom_layer.self_attn.W_O.bias.data.zero_()
|
| 210 |
|
| 211 |
# Cross-attention
|
| 212 |
+
custom_layer.cross_attn.W_Q.weight.data.copy_(t5_cross_attn.q.weight.data)
|
| 213 |
+
custom_layer.cross_attn.W_K.weight.data.copy_(t5_cross_attn.k.weight.data)
|
| 214 |
+
custom_layer.cross_attn.W_V.weight.data.copy_(t5_cross_attn.v.weight.data)
|
| 215 |
+
custom_layer.cross_attn.W_O.weight.data.copy_(t5_cross_attn.o.weight.data)
|
| 216 |
+
|
| 217 |
+
if custom_layer.cross_attn.W_Q.bias is not None:
|
| 218 |
+
custom_layer.cross_attn.W_Q.bias.data.zero_()
|
| 219 |
+
custom_layer.cross_attn.W_K.bias.data.zero_()
|
| 220 |
+
custom_layer.cross_attn.W_V.bias.data.zero_()
|
| 221 |
+
custom_layer.cross_attn.W_O.bias.data.zero_()
|
| 222 |
|
| 223 |
# Layer norms
|
| 224 |
+
custom_layer.norm1.weight.data.copy_(t5_norm1.weight.data)
|
| 225 |
+
custom_layer.norm2.weight.data.copy_(t5_norm2.weight.data)
|
| 226 |
+
custom_layer.norm3.weight.data.copy_(t5_norm3.weight.data)
|
| 227 |
+
|
| 228 |
+
# FFN - same gated logic as encoder
|
| 229 |
+
if hasattr(t5_ffn, "wi_0") and hasattr(custom_layer.ffn, "linear_gate"):
|
| 230 |
+
# Full gated FFN transfer (swiglu mode)
|
| 231 |
+
custom_layer.ffn.linear_gate.weight.data.copy_(t5_ffn.wi_0.weight.data)
|
| 232 |
+
custom_layer.ffn.linear1.weight.data.copy_(t5_ffn.wi_1.weight.data)
|
| 233 |
+
custom_layer.ffn.linear2.weight.data.copy_(t5_ffn.wo.weight.data)
|
| 234 |
+
if custom_layer.ffn.linear_gate.bias is not None:
|
| 235 |
+
custom_layer.ffn.linear_gate.bias.data.zero_()
|
| 236 |
+
elif hasattr(t5_ffn, "wi_1"):
|
| 237 |
+
custom_layer.ffn.linear1.weight.data.copy_(t5_ffn.wi_1.weight.data)
|
| 238 |
+
custom_layer.ffn.linear2.weight.data.copy_(t5_ffn.wo.weight.data)
|
| 239 |
+
elif hasattr(t5_ffn, "wi"):
|
| 240 |
+
custom_layer.ffn.linear1.weight.data.copy_(t5_ffn.wi.weight.data)
|
| 241 |
+
custom_layer.ffn.linear2.weight.data.copy_(t5_ffn.wo.weight.data)
|
| 242 |
+
|
| 243 |
+
if custom_layer.ffn.linear1.bias is not None:
|
| 244 |
+
custom_layer.ffn.linear1.bias.data.zero_()
|
| 245 |
+
custom_layer.ffn.linear2.bias.data.zero_()
|
| 246 |
+
|
| 247 |
+
# Decoder final norm
|
| 248 |
+
decoder.final_norm.weight.data.copy_(t5_decoder.final_layer_norm.weight.data)
|
| 249 |
+
|
| 250 |
+
# Load decoder relative position biases (T5 stores them in first layer, shared across all layers)
|
| 251 |
+
# Decoder has both self-attention bias and cross-attention bias
|
| 252 |
+
if (
|
| 253 |
+
hasattr(decoder, "self_relative_position_bias")
|
| 254 |
+
and decoder.self_relative_position_bias is not None
|
| 255 |
+
):
|
| 256 |
+
print("Transferring decoder self-attention relative position bias...")
|
| 257 |
+
t5_dec_self_rel_bias = (
|
| 258 |
+
t5_decoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight.data
|
| 259 |
+
)
|
| 260 |
+
decoder.self_relative_position_bias.relative_attention_bias.weight.data.copy_(
|
| 261 |
+
t5_dec_self_rel_bias
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
if (
|
| 265 |
+
hasattr(decoder, "cross_relative_position_bias")
|
| 266 |
+
and decoder.cross_relative_position_bias is not None
|
| 267 |
+
):
|
| 268 |
+
print("Transferring decoder cross-attention relative position bias...")
|
| 269 |
+
# Cross-attention relative position bias is in EncDecAttention of first block
|
| 270 |
+
t5_dec_cross_rel_bias = (
|
| 271 |
+
t5_decoder.block[0].layer[1].EncDecAttention.relative_attention_bias.weight.data
|
| 272 |
+
)
|
| 273 |
+
decoder.cross_relative_position_bias.relative_attention_bias.weight.data.copy_(
|
| 274 |
+
t5_dec_cross_rel_bias
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
# Load LM head weights (T5's lm_head)
|
| 278 |
+
# Handle vocab size mismatch (T5 pads to multiple of 128)
|
| 279 |
+
print("Transferring LM head weights...")
|
| 280 |
+
lm_head_weights = t5.lm_head.weight.data
|
| 281 |
+
our_vocab_size = decoder.output_projection.weight.size(0)
|
| 282 |
+
t5_vocab_size = lm_head_weights.size(0)
|
| 283 |
|
| 284 |
+
if our_vocab_size != t5_vocab_size:
|
| 285 |
+
print(f" LM head vocab mismatch: our model={our_vocab_size}, T5={t5_vocab_size}")
|
| 286 |
+
min_vocab = min(our_vocab_size, t5_vocab_size)
|
| 287 |
+
print(f" Copying first {min_vocab} LM head weights...")
|
| 288 |
+
decoder.output_projection.weight.data[:min_vocab].copy_(lm_head_weights[:min_vocab])
|
| 289 |
+
else:
|
| 290 |
+
decoder.output_projection.weight.data.copy_(lm_head_weights)
|
| 291 |
|
| 292 |
+
if decoder.output_projection.bias is not None:
|
| 293 |
+
decoder.output_projection.bias.data.zero_()
|
|
|
|
|
|
|
| 294 |
|
| 295 |
+
print("Pretrained FLAN-T5 weights loaded successfully!")
|
| 296 |
|
| 297 |
|
| 298 |
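The vocab-overlap transfer above reduces to one idea: when the checkpoint's (padded) vocabulary and the model's vocabulary disagree in size, copy only the shared leading rows. A framework-agnostic sketch with plain Python lists; `copy_overlap` is an illustrative helper, not a function in this repo:

```python
def copy_overlap(dst_rows, src_rows):
    """Copy as many leading rows as both tables share. T5 pads its vocab
    from 32100 to 32128, so the trailing rows have no tokenizer entry and
    are safe to drop."""
    min_vocab = min(len(dst_rows), len(src_rows))
    dst_rows[:min_vocab] = [list(row) for row in src_rows[:min_vocab]]
    return min_vocab

# Toy example: a 4-row "model" embedding table receiving a 6-row "checkpoint" table.
ours = [[0.0, 0.0] for _ in range(4)]
theirs = [[float(i), float(i)] for i in range(6)]
copied = copy_overlap(ours, theirs)
print(copied)   # 4
print(ours[3])  # [3.0, 3.0]
```

The same slice-and-copy applies to both the shared embeddings and the LM head, which is why the loader repeats it.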
def _load_llama_weights(
    ...

    if not isinstance(num_topics, int) or num_topics <= 0:
        raise ValueError("num_topics must be a positive integer")

    # Get max_length from the tokenizer (handle both custom and HF tokenizers)
    if hasattr(tokenizer, "config") and hasattr(tokenizer.config, "max_length"):
        max_len = tokenizer.config.max_length
    elif hasattr(tokenizer, "model_max_length"):
        max_len = tokenizer.model_max_length
    else:
        max_len = 512  # Default fallback

    # Cast activation to the literal type for mypy
    activation = cast(ActivationType, cfg.activation)

    encoder = TransformerEncoder(
        vocab_size=tokenizer.vocab_size,
        d_model=cfg.d_model,
        ...
        num_heads=cfg.num_attention_heads,
        d_ff=cfg.ffn_dim,
        dropout=cfg.dropout,
        max_len=max_len,
        pad_token_id=tokenizer.pad_token_id,
        quantization=cfg.quantization,
        use_learned_pos_enc=cfg.use_learned_pos_enc,
        activation=activation,
        use_relative_position_bias=cfg.use_relative_position_bias,
    )
    decoder = TransformerDecoder(
        vocab_size=tokenizer.vocab_size,
        ...
        num_heads=cfg.num_attention_heads,
        d_ff=cfg.ffn_dim,
        dropout=cfg.dropout,
        max_len=max_len,
        pad_token_id=tokenizer.pad_token_id,
        quantization=cfg.quantization,
        use_learned_pos_enc=cfg.use_learned_pos_enc,
        activation=activation,
        use_relative_position_bias=cfg.use_relative_position_bias,
    )

    # Load pretrained weights if requested (but allow override for inference)
    should_load = cfg.use_pretrained if load_pretrained is None else load_pretrained
    if should_load:
        model_name_lower = cfg.pretrained_model_name.lower()
        if "t5" in model_name_lower or "flan" in model_name_lower:
            _load_pretrained_weights(encoder, decoder, cfg.pretrained_model_name)
        elif "llama" in model_name_lower or "gemma" in model_name_lower:
            _load_llama_weights(
                encoder, decoder, cfg.pretrained_model_name, quantization=cfg.quantization
            )
        else:
            # Default to T5-style loading for unknown models
            print(
                f"Warning: Unknown model type '{cfg.pretrained_model_name}', attempting T5-style loading..."
            )
            _load_pretrained_weights(encoder, decoder, cfg.pretrained_model_name)

    model = MultiTaskModel(encoder=encoder, decoder=decoder, decoder_outputs_logits=True)
    model.add_head(
        "summarization",
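The loader dispatch above keys purely off substrings of the pretrained model name, with T5-style loading as the fallback. A minimal sketch of that routing logic; `pick_loader` is a hypothetical stand-in that returns a label instead of calling the actual loaders:

```python
def pick_loader(pretrained_model_name: str) -> str:
    """Mirror of the substring dispatch: T5/FLAN checkpoints go to the T5
    loader, LLaMA/Gemma to the decoder-only loader, and anything else falls
    back to T5-style loading (with a warning in the real code)."""
    name = pretrained_model_name.lower()
    if "t5" in name or "flan" in name:
        return "t5"
    if "llama" in name or "gemma" in name:
        return "llama"
    return "t5-fallback"

print(pick_loader("google/flan-t5-base"))    # t5
print(pick_loader("meta-llama/Llama-3-8B"))  # llama
print(pick_loader("mistral-7b"))             # t5-fallback
```

Substring matching is deliberately loose: it catches "flan-t5", "t5-v1_1", and org-prefixed names without maintaining an explicit registry.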
src/models/feedforward.py
CHANGED
@@ -15,6 +15,7 @@ class FeedForward(nn.Module):

    Or with GELU: FFN(x) = GELU(xW₁ + b₁)W₂ + b₂
    Or with SwiGLU: FFN(x) = (Swish(xW_gate) * xW_up)W_down
    Or with gated-gelu: FFN(x) = (GELU(xW_gate) * xW_up)W_down (T5/FLAN-T5 style)
    """

    def __init__(

@@ -22,7 +23,7 @@ class FeedForward(nn.Module):
        d_model: int,
        d_ff: int,
        dropout: float = 0.1,
        activation: Literal["gelu", "relu", "swiglu", "gated-gelu"] = "gelu",
        quantization: Optional[str] = None,
    ):
        super().__init__()

@@ -47,20 +48,22 @@ class FeedForward(nn.Module):
        except (ImportError, AttributeError):
            print("bitsandbytes not installed or incompatible, falling back to nn.Linear")

        if activation in ("swiglu", "gated-gelu"):
            # Gated FFN requires three linear layers: gate, up, and down.
            # - swiglu uses SiLU (Swish) activation (LLaMA style)
            # - gated-gelu uses GELU activation (T5/FLAN-T5 style)
            self.linear_gate = Linear(d_model, d_ff, **kwargs)  # Gate projection (wi_0)
            self.linear1 = Linear(d_model, d_ff, **kwargs)  # Up projection (wi_1)
            self.linear2 = Linear(d_ff, d_model, **kwargs)  # Down projection (wo)

            if activation == "swiglu":
                self.activation = nn.SiLU()  # Swish activation
            else:  # gated-gelu
                self.activation = nn.GELU()  # T5 uses gelu_new, which is very close

        # Init gate
        if not quantization:
            init.xavier_uniform_(self.linear_gate.weight)
            init.zeros_(self.linear_gate.bias)

@@ -83,8 +86,8 @@ class FeedForward(nn.Module):
        x: (batch, seq_len, d_model)
        returns: (batch, seq_len, d_model)
        """
        if self.activation_type in ("swiglu", "gated-gelu"):
            # Gated FFN: (activation(xW_gate) * xW_up)W_down
            gate = self.activation(self.linear_gate(x))
            up = self.linear1(x)
            x = gate * up
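The only difference between swiglu and gated-gelu is the gate nonlinearity; the three-projection data flow is identical. A scalar sketch of that flow (exact GELU via the error function here, whereas T5's gelu_new is a tanh approximation of it; the weights are arbitrary illustrative numbers):

```python
import math

def gelu(v: float) -> float:
    # Exact GELU: v * Phi(v), with Phi the standard normal CDF
    return v * 0.5 * (1.0 + math.erf(v / math.sqrt(2.0)))

def silu(v: float) -> float:
    # SiLU / Swish: v * sigmoid(v)
    return v * (1.0 / (1.0 + math.exp(-v)))

def gated_ffn(x, w_gate, w_up, w_down, act):
    # (act(x * w_gate) * (x * w_up)) * w_down, written for scalar "features"
    # so the three projections are easy to follow.
    gate = act(x * w_gate)     # gate path (wi_0 / linear_gate)
    up = x * w_up              # up path (wi_1 / linear1)
    return gate * up * w_down  # down projection (wo / linear2)

# Same weights, two activations: only the gate nonlinearity differs.
print(gated_ffn(1.0, 0.5, 2.0, 1.0, act=gelu))
print(gated_ffn(1.0, 0.5, 2.0, 1.0, act=silu))
```

Because both activations are below the identity for small positive inputs, the gated output is attenuated relative to a plain linear gate, which is the point: the gate learns how much of the up-projection to pass through.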
src/models/heads.py
CHANGED
@@ -40,16 +40,36 @@ class ClassificationHead(nn.Module):
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(d_model, num_labels)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        x: (batch, seq_len, d_model)
        mask: (batch, seq_len) - True for valid tokens, False for padding
        returns: (batch, num_labels)
        """
        if self.pooler == "mean":
            if mask is not None:
                # mask is (B, S), x is (B, S, D): expand mask to (B, S, 1)
                mask_expanded = mask.unsqueeze(-1).float()
                # Zero out padding
                x = x * mask_expanded
                # Sum over the sequence
                sum_embeddings = x.sum(dim=1)
                # Count valid tokens
                sum_mask = mask_expanded.sum(dim=1)
                # Avoid division by zero
                sum_mask = torch.clamp(sum_mask, min=1e-9)
                pooled = sum_embeddings / sum_mask
            else:
                pooled = x.mean(dim=1)
        elif self.pooler == "cls":
            pooled = x[:, 0, :]
        else:  # max
            if mask is not None:
                # Mask padding with -inf so it can never win the max
                mask_expanded = mask.unsqueeze(-1)
                x = x.masked_fill(~mask_expanded, float("-inf"))
            pooled, _ = x.max(dim=1)
        pooled = self.dropout(pooled)
        return self.out_proj(pooled)
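The reason mean pooling needs the mask is easy to see with plain numbers: on a right-padded batch, an unmasked mean divides by the full padded length and drags the pooled value toward the pad positions. A dependency-free sketch (`masked_mean` is illustrative, operating on one scalar feature per position):

```python
def masked_mean(seq, mask):
    """Mean over valid positions only."""
    total = 0.0
    count = 0
    for value, keep in zip(seq, mask):
        if keep:
            total += value
            count += 1
    return total / max(count, 1)  # clamp, like the 1e-9 guard above

seq = [2.0, 4.0, 0.0, 0.0]           # last two positions are padding
mask = [True, True, False, False]
print(masked_mean(seq, mask))        # 3.0
print(sum(seq) / len(seq))           # 1.5 - padding skews the naive mean
```

The max-pooling branch has the analogous problem in the other direction: a pad embedding can win the max, which is why it gets filled with -inf before pooling.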
src/models/multitask.py
CHANGED
@@ -104,10 +104,15 @@ class MultiTaskModel(nn.Module):
            raise KeyError(f"Unknown task/head '{task}'")

        head = self.heads[task]
        # Unwrap for type checking if compiled (torch.compile hides the
        # original class behind _orig_mod)
        check_head = head
        if hasattr(head, "_orig_mod"):
            check_head = head._orig_mod

        loss_kwargs = loss_kwargs or {}

        # Encoder-only heads expect encoder outputs
        if isinstance(check_head, (ClassificationHead, TokenClassificationHead)):
            if self.encoder is None:
                raise RuntimeError("Encoder is required for encoder-side heads")
            # accept either input_ids or embeddings

@@ -129,18 +134,23 @@ class MultiTaskModel(nn.Module):
                raise ValueError(
                    "inputs must contain 'input_ids' or 'embeddings' for encoder tasks"
                )

            # Pass the attention mask to the head if available (needed for mean pooling to ignore padding)
            if isinstance(check_head, ClassificationHead):
                logits = head(enc_out, mask=inputs.get("attention_mask"))
            else:
                logits = head(enc_out)

            if return_loss:
                labels = inputs.get("labels", None)
                if labels is None:
                    raise ValueError("return_loss=True requires 'labels' in inputs")
                loss = self.compute_loss_for_head(check_head, logits, labels, **loss_kwargs)
                return loss, logits
            return logits

        # LM/seq2seq head: run encoder -> decoder -> lm head
        if isinstance(check_head, LMHead):
            if self.encoder is None or self.decoder is None:
                raise RuntimeError("Both encoder and decoder are required for LM-style heads")

@@ -164,6 +174,11 @@ class MultiTaskModel(nn.Module):
                    "inputs must contain 'src_ids' or 'src_embeddings' for seq2seq tasks"
                )

            # Clone memory to prevent CUDA Graph buffer overwrites when passing between compiled graphs.
            # This fixes the "accessing tensor output of CUDAGraphs that has been overwritten" error.
            if isinstance(memory, torch.Tensor):
                memory = memory.clone()

            # If training / teacher forcing: expect tgt_ids (shifted by the caller) or embeddings
            if "tgt_ids" in inputs:
                decoder_inputs = inputs["tgt_ids"]

@@ -191,12 +206,12 @@ class MultiTaskModel(nn.Module):
                labels = inputs.get("labels", None)
                if labels is None:
                    raise ValueError("return_loss=True requires 'labels' in inputs for seq2seq")
                loss = self.compute_loss_for_head(check_head, logits, labels, **loss_kwargs)
                return loss, logits
            return logits

        # Otherwise unsupported head type
        raise RuntimeError(f"Unsupported head type: {type(check_head)}")

    def compute_loss_for_head(
        self,
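The `_orig_mod` unwrapping matters because torch.compile returns a wrapper object whose class is not the original module's class, so `isinstance` dispatch on the wrapper silently falls through. A dependency-free sketch with stand-in classes (`Wrapped`, `unwrap`, and this toy `LMHead` are illustrative, not repo code):

```python
class Wrapped:
    """Stand-in for a torch.compile-wrapped module: the wrapper keeps the
    original object on _orig_mod, mirroring torch.compile's behavior."""
    def __init__(self, module):
        self._orig_mod = module

class LMHead:  # stand-in class used only for the isinstance check
    pass

def unwrap(head):
    # Same guard as in forward(): look through the wrapper when present
    return head._orig_mod if hasattr(head, "_orig_mod") else head

head = Wrapped(LMHead())
print(isinstance(head, LMHead))          # False - the wrapper hides the type
print(isinstance(unwrap(head), LMHead))  # True
```

Note the model still calls `head(...)` (the compiled wrapper) for the actual forward pass; only the type checks and loss dispatch use the unwrapped `check_head`.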
src/models/positional_encoding.py
CHANGED
@@ -76,3 +76,40 @@ class PositionalEncoding(nn.Module):
        # self.pe contains pre-computed encodings for all positions;
        # we just add the first seq_len positions to x
        return self.dropout(x)


class LearnedPositionalEncoding(nn.Module):
    """
    Learned positional embeddings (used by BERT, GPT, etc.).

    Note: T5/FLAN-T5 uses relative position bias instead of absolute positional embeddings.
    When loading from T5, the model uses learned positional encodings that train from scratch.

    Args:
        d_model: Dimension of the model embeddings
        max_len: Maximum sequence length
        dropout: Dropout probability
        padding_idx: Index of the padding token (used to mask out padding positions if needed)
    """

    def __init__(
        self, d_model: int, max_len: int = 1024, dropout: float = 0.1, padding_idx: int = 1
    ):
        super().__init__()
        # Standard learned positional embeddings.
        # Note: T5's relative position bias is NOT transferred - we train these from scratch.
        self.embeddings = nn.Embedding(max_len, d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input embeddings (batch, seq_len, d_model)
        """
        seq_len = x.size(1)
        positions = torch.arange(seq_len, dtype=torch.long, device=x.device)
        # Broadcast across the batch
        positions = positions.unsqueeze(0).expand(x.size(0), -1)

        pos_embeds = self.embeddings(positions)
        return self.dropout(x + pos_embeds)
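For contrast with the absolute embeddings above: T5RelativePositionBias indexes a learned table by a *bucketed relative distance* between query and key. A pure-Python sketch of that bucketing, written to follow the scheme in Hugging Face's `_relative_position_bucket` (the constants 32 and 128 are T5's defaults; treat the exact formula as our reading of that implementation, not repo code):

```python
import math

def relative_position_bucket(relative_position: int, bidirectional: bool = True,
                             num_buckets: int = 32, max_distance: int = 128) -> int:
    """Map relative_position = key_pos - query_pos to a bucket index."""
    bucket = 0
    if bidirectional:
        # Encoder case: split buckets between "key before query" and "key after"
        num_buckets //= 2
        if relative_position > 0:
            bucket += num_buckets
        relative_position = abs(relative_position)
    else:
        # Decoder self-attention: only look backwards
        relative_position = -min(relative_position, 0)
    max_exact = num_buckets // 2
    if relative_position < max_exact:
        # Small distances each get their own bucket
        return bucket + relative_position
    # Larger distances share logarithmically-spaced buckets up to max_distance
    log_bucket = max_exact + int(
        math.log(relative_position / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    )
    return bucket + min(log_bucket, num_buckets - 1)

print(relative_position_bucket(0))    # 0
print(relative_position_bucket(-3))   # 3 - small distances are exact
print(relative_position_bucket(3))    # 19 - same distance, "after" half
```

Because nearby offsets get exact buckets and distant ones are binned logarithmically, a 32-entry table per head covers arbitrarily long sequences, which is why the bias table is tiny and shared across layers.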
src/training/trainer.py
CHANGED
|
@@ -28,6 +28,7 @@ class TrainerConfig:
|
|
| 28 |
label_smoothing: float = 0.0 # Label smoothing for regularization (e.g., 0.1)
|
| 29 |
experiment_name: str = "LexiMind"
|
| 30 |
run_name: str | None = None
|
|
|
|
| 31 |
|
| 32 |
|
| 33 |
class Trainer:
|
|
@@ -51,10 +52,13 @@ class Trainer:
|
|
| 51 |
# Apply label smoothing to summarization task if configured
|
| 52 |
self.label_smoothing = config.label_smoothing
|
| 53 |
self._progress_last_len = 0
|
|
|
|
|
|
|
| 54 |
|
| 55 |
# Mixed Precision Training
|
| 56 |
# Initialize GradScaler for float16/bfloat16 training
|
| 57 |
# This scales gradients to prevent underflow during backward pass
|
|
|
|
| 58 |
self.scaler = torch.GradScaler("cuda", enabled=(device.type == "cuda"))
|
| 59 |
|
| 60 |
# Initialize MLflow
|
|
@@ -181,24 +185,53 @@ class Trainer:
|
|
| 181 |
context = torch.enable_grad() if train else torch.no_grad()
|
| 182 |
with context:
|
| 183 |
for step in range(max_batches):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
backward_performed = False
|
| 185 |
step_total_loss = 0.0
|
| 186 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
for task, loader in loaders.items():
|
| 188 |
batch = self._next_batch(iterator_map, loader, task)
|
| 189 |
if batch is None:
|
| 190 |
continue
|
| 191 |
|
| 192 |
-
# Mixed Precision Context
|
| 193 |
-
# Using bfloat16 for my RTX 4070 (Ampere/Ada) - better stability than float16
|
| 194 |
with torch.autocast(
|
| 195 |
-
"cuda",
|
|
|
|
|
|
|
| 196 |
):
|
| 197 |
loss, task_metrics = self._forward_task(task, batch, train)
|
| 198 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
weight = self._task_weight(task)
|
| 200 |
-
|
| 201 |
-
|
|
|
|
| 202 |
|
| 203 |
metrics_accumulator[f"{task}_loss"].append(loss.item())
|
| 204 |
for metric_name, metric_value in task_metrics.items():
|
|
@@ -208,23 +241,39 @@ class Trainer:
|
|
| 208 |
# Scale loss before backward to prevent underflow
|
| 209 |
# We accumulate gradients from all tasks before stepping the optimizer
|
| 210 |
# This effectively minimizes the weighted sum of losses: L_total = w1*L1 + w2*L2 + ...
|
| 211 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
backward_performed = True
|
| 213 |
|
| 214 |
if backward_performed:
|
| 215 |
metrics_accumulator["total_loss"].append(step_total_loss)
|
| 216 |
|
| 217 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
# Unscale gradients before clipping
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
|
| 229 |
if (
|
| 230 |
train
|
|
@@ -360,6 +409,21 @@ class Trainer:
|
|
| 360 |
encoder_mask = src_mask.unsqueeze(1) & src_mask.unsqueeze(2)
|
| 361 |
memory = self.model.encoder(src_ids, mask=encoder_mask)
|
| 362 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
# Ban special tokens from generation
|
| 364 |
ban_token_ids = [self.tokenizer.bos_token_id, self.tokenizer.pad_token_id]
|
| 365 |
unk_id = getattr(self.tokenizer._tokenizer, "unk_token_id", None)
|
|
@@ -367,16 +431,13 @@ class Trainer:
|
|
| 367 |
ban_token_ids.append(unk_id)
|
| 368 |
ban_token_ids = [tid for tid in ban_token_ids if tid is not None]
|
| 369 |
|
| 370 |
-
# Generate
|
| 371 |
-
generated = self.model.decoder.
|
| 372 |
memory=memory,
|
| 373 |
max_len=self.config.validation_max_length,
|
| 374 |
start_token_id=self.tokenizer.bos_token_id,
|
| 375 |
end_token_id=self.tokenizer.eos_token_id,
|
| 376 |
device=self.device,
|
| 377 |
-
min_len=10,
|
| 378 |
-
ban_token_ids=ban_token_ids,
|
| 379 |
-
no_repeat_ngram_size=3,
|
| 380 |
memory_mask=src_mask,
|
| 381 |
)
|
| 382 |
|
|
@@ -386,6 +447,9 @@ class Trainer:
|
|
| 386 |
reference_text = self._decode_labels(labels)[0]
|
| 387 |
|
| 388 |
print(f"\nSample {samples_generated + 1}:")
|
|
|
|
|
|
|
|
|
|
| 389 |
print(
|
| 390 |
f"Source: {source_text[:200]}..."
|
| 391 |
if len(source_text) > 200
|
|
@@ -451,19 +515,24 @@ class Trainer:
|
|
| 451 |
total_elapsed = time.perf_counter() - global_start
|
| 452 |
if epochs_completed > 0:
|
| 453 |
remaining_epochs = max(total_epochs - epochs_completed, 0.0)
|
| 454 |
-
|
| 455 |
(total_elapsed / epochs_completed) * remaining_epochs if total_elapsed > 0 else 0.0
|
| 456 |
)
|
| 457 |
else:
|
| 458 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 459 |
bar = self._format_progress_bar(overall_progress, width=self._progress_bar_width())
|
| 460 |
message = (
|
| 461 |
f"[progress] {bar} {percent:5.1f}% "
|
| 462 |
f"e {epoch}/{total_epochs} "
|
| 463 |
f"s {bounded_step}/{total_steps} "
|
| 464 |
-
f"
|
| 465 |
-
f"
|
| 466 |
-
f"eta {self._format_duration(eta)}"
|
| 467 |
)
|
| 468 |
display = self._truncate_to_terminal(message)
|
| 469 |
padding = " " * max(self._progress_last_len - len(display), 0)
|
|
|
|
| 28 |
label_smoothing: float = 0.0 # Label smoothing for regularization (e.g., 0.1)
|
| 29 |
experiment_name: str = "LexiMind"
|
| 30 |
run_name: str | None = None
|
| 31 |
+
gradient_accumulation_steps: int = 1
|
| 32 |
|
| 33 |
|
| 34 |
class Trainer:
|
|
|
|
| 52 |
# Apply label smoothing to summarization task if configured
|
| 53 |
self.label_smoothing = config.label_smoothing
|
| 54 |
self._progress_last_len = 0
|
| 55 |
+
self.gradient_accumulation_steps = max(1, config.gradient_accumulation_steps)
|
| 56 |
+
self._nan_counter = 0 # Track consecutive NaNs
|
| 57 |
|
| 58 |
# Mixed Precision Training
|
| 59 |
# Initialize GradScaler for float16/bfloat16 training
|
| 60 |
# This scales gradients to prevent underflow during backward pass
|
| 61 |
+
# Note: bfloat16 generally doesn't need scaling, but we keep it for safety unless it causes NaNs
|
| 62 |
self.scaler = torch.GradScaler("cuda", enabled=(device.type == "cuda"))
|
| 63 |
|
| 64 |
# Initialize MLflow
|
|
|
|
         context = torch.enable_grad() if train else torch.no_grad()
         with context:
             for step in range(max_batches):
+                # Mark step begin for CUDA Graphs (inductor) to handle memory reuse correctly
+                if (
+                    train
+                    and self.device.type == "cuda"
+                    and hasattr(torch.compiler, "cudagraph_mark_step_begin")
+                ):
+                    torch.compiler.cudagraph_mark_step_begin()
+
                 backward_performed = False
                 step_total_loss = 0.0
 
+                # Mixed Precision Context
+                # Using bfloat16 on my RTX 4070 (Ada) - better numerical stability than float16
+                # The scaler is skipped for bfloat16 to prevent NaNs
+                use_bfloat16 = self.device.type == "cuda" and torch.cuda.is_bf16_supported()
+
                 for task, loader in loaders.items():
                     batch = self._next_batch(iterator_map, loader, task)
                     if batch is None:
                         continue
 
                     with torch.autocast(
+                        "cuda",
+                        dtype=torch.bfloat16 if use_bfloat16 else torch.float16,
+                        enabled=(self.device.type == "cuda"),
                     ):
                         loss, task_metrics = self._forward_task(task, batch, train)
 
+                    if torch.isnan(loss):
+                        if train:
+                            self._nan_counter += 1
+                            print(
+                                f"Warning: NaN loss detected for task '{task}'. Skipping update for this task. (Consecutive NaNs: {self._nan_counter})"
+                            )
+                            if self._nan_counter > 10:
+                                raise RuntimeError(
+                                    "Too many consecutive NaN losses. Training is diverging."
+                                )
+                        continue
+                    else:
+                        if train:
+                            self._nan_counter = 0
+
                     weight = self._task_weight(task)
+                    # Scale loss by gradient accumulation steps
+                    weighted_loss = (loss * weight) / self.gradient_accumulation_steps
+                    step_total_loss += weighted_loss.item() * self.gradient_accumulation_steps
 
                     metrics_accumulator[f"{task}_loss"].append(loss.item())
                     for metric_name, metric_value in task_metrics.items():
 
                     # Scale loss before backward to prevent underflow
                     # We accumulate gradients from all tasks before stepping the optimizer
                     # This effectively minimizes the weighted sum of losses: L_total = w1*L1 + w2*L2 + ...
+                    if use_bfloat16:
+                        # bfloat16 doesn't need scaling, and scaling can cause NaNs
+                        weighted_loss.backward()
+                    else:
+                        self.scaler.scale(weighted_loss).backward()
                     backward_performed = True
 
                 if backward_performed:
                     metrics_accumulator["total_loss"].append(step_total_loss)
 
+                # Perform the optimizer step only after accumulating enough gradients
+                if (
+                    train
+                    and backward_performed
+                    and (step + 1) % self.gradient_accumulation_steps == 0
+                ):
                     # Unscale gradients before clipping
+                    if use_bfloat16:
+                        torch.nn.utils.clip_grad_norm_(
+                            self.model.parameters(), self.config.gradient_clip_norm
+                        )
+                        self.optimizer.step()
+                        self.optimizer.zero_grad()
+                    else:
+                        self.scaler.unscale_(self.optimizer)
+                        torch.nn.utils.clip_grad_norm_(
+                            self.model.parameters(), self.config.gradient_clip_norm
+                        )
+
+                        # Step the optimizer through the scaler
+                        self.scaler.step(self.optimizer)
+                        self.scaler.update()
+                        self.optimizer.zero_grad()
 
                 if (
                     train
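The `1 / gradient_accumulation_steps` scaling above is what makes K accumulated micro-batch backward passes equivalent to one large-batch update; a small pure-Python check of that arithmetic (illustrative stand-in numbers, not values from this trainer):

```python
K = 4                                # gradient_accumulation_steps
micro_grads = [0.2, -0.1, 0.4, 0.3]  # stand-in per-micro-batch gradients dL_i/dw

# What K calls to backward() on (loss / K) leave in .grad: a running sum of g/K
accumulated = sum(g / K for g in micro_grads)

# The gradient of the mean loss over the same K micro-batches
large_batch = sum(micro_grads) / K

assert abs(accumulated - large_batch) < 1e-12
```

This is also why `step_total_loss` multiplies back by `gradient_accumulation_steps`: the logged value should be the unscaled loss, not the per-micro-batch fraction of it.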
             encoder_mask = src_mask.unsqueeze(1) & src_mask.unsqueeze(2)
             memory = self.model.encoder(src_ids, mask=encoder_mask)
 
+            # DEBUG: Check encoder output statistics
+            if samples_generated == 0:
+                print("\n[DEBUG] Encoder output stats:")
+                print(f"  Shape: {memory.shape}")
+                print(f"  Mean: {memory.mean().item():.6f}")
+                print(f"  Std: {memory.std().item():.6f}")
+                print(f"  Min: {memory.min().item():.6f}")
+                print(f"  Max: {memory.max().item():.6f}")
+                print(f"  Has NaN: {torch.isnan(memory).any().item()}")
+                print(f"  Has Inf: {torch.isinf(memory).any().item()}")
+
+                # Check the first and last positions
+                print(f"  First position norm: {memory[0, 0].norm().item():.4f}")
+                print(f"  Last position norm: {memory[0, -1].norm().item():.4f}")
+
             # Ban special tokens from generation
             ban_token_ids = [self.tokenizer.bos_token_id, self.tokenizer.pad_token_id]
             unk_id = getattr(self.tokenizer._tokenizer, "unk_token_id", None)
             if unk_id is not None:
                 ban_token_ids.append(unk_id)
             ban_token_ids = [tid for tid in ban_token_ids if tid is not None]
 
+            # Generate using the naive method (full forward each step, O(N^2)) for debugging
+            generated = self.model.decoder.greedy_decode_naive(
                 memory=memory,
                 max_len=self.config.validation_max_length,
                 start_token_id=self.tokenizer.bos_token_id,
                 end_token_id=self.tokenizer.eos_token_id,
                 device=self.device,
                 memory_mask=src_mask,
             )
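The "naive" greedy decode above re-runs the decoder on the entire prefix at every step (hence O(N^2) in sequence length, versus O(N) with a KV cache). A toy sketch of that loop over a dummy scoring function; all names here are illustrative, not the project's API:

```python
def greedy_decode_sketch(step_logits, start_id, end_id, max_len):
    """Minimal greedy loop: rescore the full prefix each step (O(N^2) overall)."""
    seq = [start_id]
    for _ in range(max_len):
        logits = step_logits(seq)  # full forward pass over the whole prefix
        next_id = max(range(len(logits)), key=logits.__getitem__)  # argmax token
        seq.append(next_id)
        if next_id == end_id:      # stop as soon as EOS is emitted
            break
    return seq

# Dummy "model": prefers token 2 until it has appeared twice, then EOS (id 3)
def dummy(seq):
    if seq.count(2) >= 2:
        return [0.0, 0.0, 0.0, 1.0]  # prefer EOS
    return [0.0, 0.0, 1.0, 0.0]      # prefer token 2

print(greedy_decode_sketch(dummy, start_id=1, end_id=3, max_len=10))  # [1, 2, 2, 3]
```

The real `greedy_decode_naive` does the same thing with the decoder's logits, plus the `ban_token_ids` masking set up above.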
             reference_text = self._decode_labels(labels)[0]
 
             print(f"\nSample {samples_generated + 1}:")
+            print(
+                f"Raw token IDs: {generated[0][:20].tolist()}..."
+            )  # Debug: show the first 20 tokens
             print(
                 f"Source: {source_text[:200]}..."
                 if len(source_text) > 200
         total_elapsed = time.perf_counter() - global_start
         if epochs_completed > 0:
             remaining_epochs = max(total_epochs - epochs_completed, 0.0)
+            total_eta = (
                 (total_elapsed / epochs_completed) * remaining_epochs if total_elapsed > 0 else 0.0
             )
         else:
+            total_eta = 0.0
+
+        if step > 0:
+            epoch_eta = (epoch_elapsed / step) * (total_steps - step)
+        else:
+            epoch_eta = 0.0
+
         bar = self._format_progress_bar(overall_progress, width=self._progress_bar_width())
         message = (
             f"[progress] {bar} {percent:5.1f}% "
             f"e {epoch}/{total_epochs} "
             f"s {bounded_step}/{total_steps} "
+            f"ep_eta {self._format_duration(epoch_eta)} "
+            f"tot_eta {self._format_duration(total_eta)}"
         )
         display = self._truncate_to_terminal(message)
         padding = " " * max(self._progress_last_len - len(display), 0)
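Both ETAs above are the same linear extrapolation, `(elapsed / units_done) * units_remaining`, applied once at step granularity and once at epoch granularity. A tiny sketch of that arithmetic (hypothetical helper name and numbers):

```python
def eta_seconds(elapsed: float, done: int, total: int) -> float:
    """Linear-extrapolation ETA, mirroring the epoch/total ETA math above."""
    return (elapsed / done) * (total - done) if done > 0 else 0.0

assert eta_seconds(30.0, 10, 40) == 90.0  # 3 s per step, 30 steps remaining
assert eta_seconds(0.0, 0, 40) == 0.0     # nothing completed yet: report 0
```

The `done > 0` guard matches the `if step > 0` / `if epochs_completed > 0` branches above, which avoid a division by zero on the first step of an epoch.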
|
src/utils/io.py
CHANGED
@@ -8,9 +8,24 @@ import torch
 def save_state(model: torch.nn.Module, path: str) -> None:
     destination = Path(path)
     destination.parent.mkdir(parents=True, exist_ok=True)
-    torch.save(model.state_dict(), destination)
+
+    # Handle torch.compile artifacts: strip '_orig_mod.' prefix
+    state_dict = model.state_dict()
+    clean_state_dict = {}
+    for k, v in state_dict.items():
+        new_k = k.replace("_orig_mod.", "")
+        clean_state_dict[new_k] = v
+
+    torch.save(clean_state_dict, destination)
 
 
 def load_state(model: torch.nn.Module, path: str) -> None:
     state = torch.load(path, map_location="cpu", weights_only=True)
-    model.load_state_dict(state)
+
+    # Handle torch.compile artifacts in loaded checkpoints
+    clean_state = {}
+    for k, v in state.items():
+        new_k = k.replace("_orig_mod.", "")
+        clean_state[new_k] = v
+
+    model.load_state_dict(clean_state)
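`torch.compile` wraps the module, so every parameter key in a compiled model's state dict gains an `_orig_mod.` prefix. Stripping it, as `save_state`/`load_state` do above, keeps checkpoints interchangeable between compiled and uncompiled models. The key transformation in isolation, with made-up keys:

```python
# Keys as a compiled model reports them (illustrative names)
state = {
    "_orig_mod.encoder.weight": [1.0],
    "_orig_mod.decoder.bias": [0.0],
}

# The same stripping logic as save_state/load_state, as a dict comprehension
clean = {k.replace("_orig_mod.", ""): v for k, v in state.items()}

assert set(clean) == {"encoder.weight", "decoder.bias"}
```

Note that `str.replace` also normalizes keys from nested compiled submodules, and is a no-op for checkpoints saved from uncompiled models, so the functions stay backward compatible.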
tests/test_models/test_attention.py
CHANGED
@@ -11,49 +11,54 @@ from src.models.attention import MultiHeadAttention, ScaledDotProductAttention
 
 
 class TestScaledDotProductAttention:
-    """Test suite for ScaledDotProductAttention.
+    """Test suite for ScaledDotProductAttention.
+
+    Note: ScaledDotProductAttention expects 4D inputs: (batch, num_heads, seq, d_k)
+    """
 
     def test_output_shape(self):
         """Test that output shapes are correct."""
         attention = ScaledDotProductAttention()
-        batch_size, seq_len, d_k = 2, 10, 64
+        batch_size, num_heads, seq_len, d_k = 2, 8, 10, 64
 
-        Q = torch.randn(batch_size, seq_len, d_k)
-        K = torch.randn(batch_size, seq_len, d_k)
-        V = torch.randn(batch_size, seq_len, d_k)
+        Q = torch.randn(batch_size, num_heads, seq_len, d_k)
+        K = torch.randn(batch_size, num_heads, seq_len, d_k)
+        V = torch.randn(batch_size, num_heads, seq_len, d_k)
 
         output, weights = attention(Q, K, V, return_attn_weights=True)
 
-        assert output.shape == (batch_size, seq_len, d_k)
-        assert weights.shape == (batch_size, seq_len, seq_len)
+        assert output.shape == (batch_size, num_heads, seq_len, d_k)
+        assert weights.shape == (batch_size, num_heads, seq_len, seq_len)
 
     def test_attention_weights_sum_to_one(self):
         """Test that attention weights are a valid probability distribution."""
         attention = ScaledDotProductAttention()
-        batch_size, seq_len, d_k = 2, 10, 64
+        batch_size, num_heads, seq_len, d_k = 2, 4, 10, 64
 
-        Q = K = V = torch.randn(batch_size, seq_len, d_k)
+        Q = K = V = torch.randn(batch_size, num_heads, seq_len, d_k)
         _, weights = attention(Q, K, V, return_attn_weights=True)
 
         # Each row should sum to 1 (probability distribution over keys)
         row_sums = weights.sum(dim=-1)
-        assert torch.allclose(row_sums, torch.ones(batch_size, seq_len), atol=1e-6)
+        assert torch.allclose(row_sums, torch.ones(batch_size, num_heads, seq_len), atol=1e-6)
 
     def test_masking(self):
         """Test that masking properly zeros out attention to masked positions."""
         attention = ScaledDotProductAttention()
-        batch_size, seq_len, d_k = 1, 5, 64
+        batch_size, num_heads, seq_len, d_k = 1, 4, 5, 64
 
-        Q = K = V = torch.randn(batch_size, seq_len, d_k)
+        Q = K = V = torch.randn(batch_size, num_heads, seq_len, d_k)
 
-        # Create mask: only attend to first 3 positions
-        mask = torch.zeros(batch_size, seq_len, seq_len, dtype=torch.bool)
-        mask[:, :, :3] = True
+        # Create mask: only attend to first 3 positions (4D mask)
+        mask = torch.zeros(batch_size, 1, seq_len, seq_len, dtype=torch.bool)
+        mask[:, :, :, :3] = True  # Attend to first 3 key positions
 
         _, weights = attention(Q, K, V, mask, return_attn_weights=True)
 
-        #
-        assert torch.allclose(
+        # Key positions 3 and 4 should have zero attention weight
+        assert torch.allclose(
+            weights[:, :, :, 3:], torch.zeros(batch_size, num_heads, seq_len, 2), atol=1e-6
+        )
 
     # TODO: Add more tests as you understand the mechanism better
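The two properties these tests check (attention rows sum to 1; masked keys get exactly zero weight) both fall straight out of a masked softmax, where masked positions are set to negative infinity before normalization. A dependency-free sketch of a single attention row:

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of scores."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

# One attention row: scores for 5 keys, with the last two masked out via -inf
scores = [0.3, -1.2, 0.8, float("-inf"), float("-inf")]
weights = softmax(scores)

assert abs(sum(weights) - 1.0) < 1e-9           # row is a probability distribution
assert weights[3] == 0.0 and weights[4] == 0.0  # masked keys get zero weight
```

Since `exp(-inf)` is exactly 0.0, masked keys contribute nothing to the normalizer, and the surviving weights renormalize over the unmasked keys only; the tensor version in `ScaledDotProductAttention` applies the same idea per head and per query position.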