# SmolMixtral - Mixtral Inspired Model

A PyTorch implementation of a Mixtral-inspired transformer model with Mixture of Experts (MoE), designed for text generation and understanding tasks. The model is built on the Mixtral architecture with enhancements like Flash Attention, SWiGLU activation, and Liger kernels for optimized performance.

- I trained a 124M-parameter MoE (8x12M) architecture that I coded from the ground up.
- It was trained on the TinyStories dataset from Hugging Face, consisting of about 1M texts, for a total of 14,000 steps.

## Examples

Provided under the `generated_data/` directory, these examples showcase the model's capabilities in text generation and understanding.

![a sample text generation](images/sample.jpg "a sample text generation")

## Training Results & Model Weights

**View Training Report**: [SmolMixtral Training Results on WandB](https://wandb.ai/rentio/Mixtral-DDP-Pretrain-10-billion-tokens/reports/SmolMixtral--VmlldzoxMzYyNzc0OQ?accessToken=nybd4lxybsbq5k5fh2dqjcucdawilt3fossn583wv6jiu8tbdzcybiihe7rhsqmq)

**Download Pre-trained Weights**:
- **Hugging Face Model**: [YuvrajSingh9886/SmolMixtral](https://huggingface.co/YuvrajSingh9886/SmolMixtral)
- **WandB Checkpoints**: Check the WandB report above for additional trained model checkpoints

## Features

- **Flash Attention**: Efficient attention mechanism with memory optimization
- **Mixture of Experts (MoE)**: 8 experts with top-2 routing and noisy top-k support (see the routing sketch below)
- **SWiGLU Activation**: Advanced activation function in expert layers
- **Rotary Positional Embeddings**: Position encoding for sequence understanding
- **Liger Kernels**: Optimized kernels for faster training (optional)
- **Distributed Training**: Support for multi-GPU training with DDP
- **Advanced Optimizer**: AdamW optimizer with custom learning rate scheduling
- **Gradio Interface**: Interactive web interface for text generation
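
The MoE layer is the main architectural difference from a dense transformer. The following is a minimal, illustrative sketch of top-2 routing over SWiGLU experts in plain PyTorch; it shows the idea only and is not the exact code in `model.py` (class names, layer names, and sizes here are assumptions).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SWiGLUExpert(nn.Module):
    """One expert: a SWiGLU feed-forward block (illustrative sizes)."""
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.w_gate = nn.Linear(dim, hidden, bias=False)
        self.w_up = nn.Linear(dim, hidden, bias=False)
        self.w_down = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))

class Top2MoE(nn.Module):
    """Route each token to its top-2 experts and mix their outputs."""
    def __init__(self, dim: int, hidden: int, num_experts: int = 8, top_k: int = 2):
        super().__init__()
        self.router = nn.Linear(dim, num_experts, bias=False)
        self.experts = nn.ModuleList([SWiGLUExpert(dim, hidden) for _ in range(num_experts)])
        self.top_k = top_k

    def forward(self, x):                        # x: (batch, seq, dim)
        logits = self.router(x)                  # (batch, seq, num_experts)
        weights, idx = logits.topk(self.top_k, dim=-1)
        weights = weights.softmax(dim=-1)        # normalise over the chosen experts
        out = torch.zeros_like(x)
        for k in range(self.top_k):
            for e, expert in enumerate(self.experts):
                mask = idx[..., k] == e          # tokens routed to expert e at rank k
                if mask.any():
                    out[mask] = out[mask] + weights[..., k][mask].unsqueeze(-1) * expert(x[mask])
        return out
```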

## Model Architecture

### Default Configuration

- **Embedding Dimensions**: 512
- **Decoder Layers**: 8
- **Attention Heads**: 8
- **MoE Experts**: 8 (top-2 routing)
- **Block Size**: 1024 tokens
- **Vocabulary Size**: Based on the Llama-2-7b tokenizer (~32,000 tokens)
- **Batch Size**: 16

### Full Parameter List

#### Model Architecture Parameters

- `epochs`: Number of training epochs (default: 4)
- `block_size`: Maximum sequence length (default: 1024)
- `batch_size`: Training batch size (default: 16)
- `embeddings_dims`: Model embedding dimensions (default: 512)
- `no_of_heads`: Number of attention heads (default: 8)
- `no_of_decoder_layers`: Number of decoder layers (default: 8)
- `attn_dropout`: Attention dropout rate (default: 0.1)
- `dropout`: General dropout rate (default: 0.1)

#### Mixture of Experts (MoE) Parameters

- `experts`: Number of MoE experts (default: 8)
- `top_experts`: Number of experts to route to (default: 2)
- `noisy_topk`: Use noisy top-k routing (default: False)

#### Training Hyperparameters

- `max_lr`: Maximum learning rate (default: 6e-4)
- `weight_decay_optim`: Weight decay for optimizer (default: 0.01)
- `beta_1`: Beta1 for optimizer (default: 0.9)
- `beta_2`: Beta2 for optimizer (default: 0.95)
- `eps`: Epsilon for optimizer (default: 1e-8)
- `clip`: Gradient clipping value (default: 1.0)

#### System Configuration

- `device`: Device to use (default: 'cuda:9')
- `use_checkpointing`: Use gradient checkpointing (default: False)
- `use_liger`: Use Liger kernels for optimization (default: True)
- `use_flash_attention`: Use Flash Attention (default: True)
- `use_compile`: Use torch.compile (default: True)

#### Data Configuration

- `vocab_size`: Vocabulary size (default: based on tokenizer + 768)
- `val_epochs`: Validation frequency (default: 2)

## Quick Start

### Installation

```bash
chmod +x install.sh
./install.sh
```

### Important: Hugging Face Token Setup

Since this model uses the Llama-2 tokenizer, you'll need a Hugging Face token to access the gated model.

1. **Get a Hugging Face Token:**
   - Go to [Hugging Face Settings](https://huggingface.co/settings/tokens)
   - Create a new token with "Read" permissions
   - Accept the Llama-2 license at [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf)

2. **Set your token in `config.py`:**
   ```python
   TOKEN = 'your_token_here'
   ```
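
As a quick sanity check that the token works, you can try loading the tokenizer directly. This is an illustrative snippet using the `transformers` library; `data.py` may load the tokenizer differently, and the `token=` keyword assumes a reasonably recent `transformers` release.

```python
from transformers import AutoTokenizer

# Use the same access token you placed in config.py
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", token="your_token_here")
print(tokenizer.vocab_size)  # ~32,000 for the Llama-2 tokenizer
```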

### Using Pre-trained Weights

1. **Download Model Weights**:
   - **Option 1**: Download from [Hugging Face - YuvrajSingh9886/SmolMixtral](https://huggingface.co/YuvrajSingh9886/SmolMixtral)
   - **Option 2**: Visit the [WandB Training Report](https://wandb.ai/rentio/Mixtral-DDP-Pretrain-10-billion-tokens) for additional checkpoints
   - Place downloaded files in the `checkpoints/` directory

2. **Load Pre-trained Model for Inference**:
   ```bash
   # Using the Gradio web interface
   cd gradio
   python app.py

   # Or use in your own code
   python inference.py
   ```
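
If you prefer to load the weights in your own script, the usual PyTorch checkpoint pattern applies. The sketch below is an assumption-heavy example: the class name `Mixtral`, its constructor, and the checkpoint filename are placeholders, so check `model.py` and the files you downloaded for the actual names.

```python
import torch

from config import ModelArgs
from model import Mixtral  # hypothetical class name; see model.py for the real one

model = Mixtral(ModelArgs())  # hypothetical constructor signature
state = torch.load("checkpoints/smolmixtral.pt", map_location="cpu")  # filename is a placeholder
model.load_state_dict(state)  # adjust if the checkpoint nests the weights under a key
model.eval()
```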

### Training Examples

#### Basic Training (Single GPU)

```bash
python trainer.py
```

#### Training with Custom Parameters

```bash
# Train with larger model (modify config.py)
python trainer.py

# Train with different dataset (modify data.py)
python trainer.py
```

#### Multi-GPU Distributed Training

```bash
# 2 GPUs
torchrun --nproc_per_node=2 trainer.py

# 4 GPUs
torchrun --nproc_per_node=4 trainer.py

# 8 GPUs
torchrun --nproc_per_node=8 trainer.py
```
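
For reference, `torchrun` launches one process per GPU, and `trainer.py` uses DistributedDataParallel (DDP) across those processes. A minimal sketch of that pattern (not the exact trainer code) looks like this:

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_model_for_ddp(model: torch.nn.Module) -> DDP:
    # torchrun sets RANK, LOCAL_RANK, and WORLD_SIZE for every process it launches
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    return DDP(model.to(local_rank), device_ids=[local_rank])
```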

### Inference with Gradio

**HF_TOKEN** should be set in `config.py` to use the Gradio interface; alternatively, export it as an environment variable:

```bash
export HF_TOKEN=<TOKEN_HERE>
```

```bash
# Run the Gradio app
cd gradio
python app.py

# With custom checkpoint (edit app.py to point to your checkpoint)
cd gradio
python app.py
```
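
If you want to adapt the interface, the core of `gradio/app.py` amounts to wiring a generation function into a Gradio UI. The following is a minimal sketch under that assumption; the real app's controls, function names, and defaults may differ.

```python
import gradio as gr

def generate(prompt: str, max_new_tokens: int, temperature: float, top_k: int) -> str:
    # Placeholder: call the model's sampling loop here (e.g. the one in inference.py)
    return prompt + " ..."

demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(1, 1024, value=256, step=1, label="Max new tokens"),
        gr.Slider(0.1, 2.0, value=0.8, label="Temperature"),
        gr.Slider(1, 200, value=50, step=1, label="Top-k"),
    ],
    outputs=gr.Textbox(label="Generated text"),
)

if __name__ == "__main__":
    demo.launch()
```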

## File Structure

```
SmolMixtral/
├── config.py            # Model configuration and hyperparameters
├── model.py             # Model architecture (Mixtral, MoE, Attention, etc.)
├── data.py              # Data loading and preparation
├── inference.py         # Inference functions and text generation
├── trainer.py           # Main training loop with DDP support
├── install.sh           # Setup script
├── requirements.txt     # Python dependencies
├── model_summary.py     # Model architecture summary
├── gradio/
│   └── app.py           # Gradio web interface
├── checkpoints/         # Model checkpoints
├── generated_data/      # Generated text outputs
├── images/              # Project images
└── old/                 # Original files
```

## Training Features

- **Gradient Accumulation**: Configurable batch size scaling
- **Learning Rate Scheduling**: Cosine decay with warmup (see the sketch below)
- **Gradient Clipping**: Prevents gradient explosion
- **WandB Integration**: Experiment tracking and logging
- **Checkpointing**: Regular model checkpoints during training
- **Loss Calculation**: Optimized cross-entropy with padding token handling
- **Distributed Training**: Multi-GPU support with DDP
- **Memory Optimization**: Gradient checkpointing support
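
The "cosine decay with warmup" schedule typically has the shape sketched below. The maximum learning rate matches the default `max_lr` above; the warmup length and minimum learning rate are illustrative assumptions rather than the exact values in `trainer.py`.

```python
import math

def get_lr(step: int, max_lr: float = 6e-4, min_lr: float = 6e-5,
           warmup_steps: int = 1000, max_steps: int = 14000) -> float:
    """Linear warmup to max_lr, then cosine decay down to min_lr."""
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    if step > max_steps:
        return min_lr
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * progress))  # goes from 1 to 0
    return min_lr + coeff * (max_lr - min_lr)
```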

## Generation Methods

1. **Top-k Sampling**: Traditional sampling with temperature control (see the sketch below)
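
Top-k sampling keeps only the `k` most likely next tokens, re-normalises their probabilities, and samples from that reduced set; temperature scales the logits before the cut. A minimal sketch of one decoding step (simplified relative to `inference.py`):

```python
import torch
import torch.nn.functional as F

def sample_next_token(logits: torch.Tensor, temperature: float = 0.8, top_k: int = 50) -> torch.Tensor:
    """logits: (vocab_size,) unnormalised scores for the next token."""
    logits = logits / temperature
    top_values, top_indices = torch.topk(logits, top_k)
    probs = F.softmax(top_values, dim=-1)         # renormalise over the top-k tokens only
    choice = torch.multinomial(probs, num_samples=1)
    return top_indices[choice]                    # id of the sampled token
```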

## Advanced Usage

### Configuration

All parameters can be configured by modifying `config.py`:

```python
from dataclasses import dataclass

@dataclass
class ModelArgs:
    epochs = 4
    block_size = 1024
    batch_size = 16
    embeddings_dims = 512
    # ... other parameters
```

### Custom Dataset Training

Modify `data.py` to use different datasets:

```python
# TinyStories (default)
tinystories = True
fw = False

# FineWeb
tinystories = False
fw = True
```
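
Conceptually, these flags just select which Hugging Face dataset is loaded for training. The snippet below is an indicative sketch; the dataset identifiers and loading options actually used in `data.py` may differ.

```python
from datasets import load_dataset

tinystories, fw = True, False  # mirrors the flags above

if tinystories:
    dataset = load_dataset("roneneldan/TinyStories", split="train")
elif fw:
    # FineWeb is large, so streaming avoids downloading the whole corpus up front
    dataset = load_dataset("HuggingFaceFW/fineweb", split="train", streaming=True)
```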

### Monitoring and Logging

Training automatically logs to WandB under the project name "Mixtral-DDP-Pretrain-10-billion-tokens".

## Performance Tips

1. **Use Liger Kernels**: Keep `use_liger = True` for optimized operations
2. **Flash Attention**: Keep `use_flash_attention = True` for memory efficiency
3. **Gradient Checkpointing**: Use `use_checkpointing = True` for memory-constrained setups
4. **Batch Size Tuning**: Start with smaller batch sizes and increase gradually
5. **Block Size**: Larger block sizes improve quality but require more memory

## Troubleshooting

### Common Issues

#### Authentication Error (401)

```bash
# Make sure you have accepted the Llama-2 license and have a valid token
# Visit: https://huggingface.co/meta-llama/Llama-2-7b-hf
# Then set your token in config.py
```

#### Out of Memory (OOM)

```python
# Reduce batch size and enable checkpointing in config.py
batch_size = 8
use_checkpointing = True
```

#### Slow Training

```python
# Enable optimizations in config.py
use_liger = True
use_flash_attention = True
use_compile = True
```

## Contributing

Feel free to contribute improvements, bug fixes, or new features!

## Requirements

- Python 3.8+
- PyTorch 2.0+
- Transformers
- Datasets
- Gradio
- WandB
- Liger-kernel (optional)

## License

MIT License