SmolMixtral - Mixtral Inspired Model
A PyTorch implementation of a Mixtral-inspired transformer model with Mixture of Experts (MoE), designed for text generation and understanding tasks. The model builds on the Mixtral architecture with enhancements such as Flash Attention, SwiGLU activation, and Liger kernels for optimized performance.
- I trained a 124M-parameter (8x12M) MoE-based architecture that I coded from the ground up.
- It was trained on the TinyStories dataset from Hugging Face (1M texts) for a total of 14,000 steps.
Examples
Provided under the generated_data/ directory, these examples showcase the model's capabilities in text generation and understanding.
Training Results & Model Weights
View Training Report: SmolMixtral Training Results on WandB
Download Pre-trained Weights:
- Hugging Face Model: YuvrajSingh9886/SmolMixtral
- WandB Checkpoints: Check the WandB report above for additional trained model checkpoints
Features
- Flash Attention: Efficient attention mechanism with memory optimization
- Mixture of Experts (MoE): 8 experts with top-2 routing and noisy top-k support
- SwiGLU Activation: Gated activation function used in the expert layers
- Rotary Positional Embeddings (RoPE): Position information encoded by rotating query and key vectors
- Liger Kernels: Optimized kernels for faster training (optional)
- Distributed Training: Support for multi-GPU training with DDP
- Optimizer: AdamW with a cosine learning-rate schedule and warmup
- Gradio Interface: Interactive web interface for text generation
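To make the RoPE feature concrete, here is a minimal sketch of standard rotary embeddings. It assumes the common base of 10000 and the "split the head dimension in half" pairing; model.py may pair dimensions differently or cache the tables in another form, so treat the helper names rope_cache and apply_rope as hypothetical.

```python
import torch

def rope_cache(seq_len: int, head_dim: int, base: float = 10000.0):
    """Precompute cos/sin tables of shape (seq_len, head_dim // 2).

    base=10000 is the usual RoPE choice and an assumption here.
    """
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(seq_len).float()
    angles = torch.outer(positions, inv_freq)  # (seq_len, head_dim // 2)
    return angles.cos(), angles.sin()

def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate x of shape (batch, heads, seq_len, head_dim) in dimension pairs."""
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    # Each (x1[i], x2[i]) pair is rotated by the angle for its position and frequency.
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

# Default config: 512 embedding dims / 8 heads = 64 dims per head.
head_dim = 512 // 8
q = torch.randn(1, 8, 16, head_dim)
cos, sin = rope_cache(seq_len=16, head_dim=head_dim)
q_rot = apply_rope(q, cos, sin)
```

Because each pair is rotated rather than scaled, vector norms are preserved, and the dot product between a rotated query and key depends only on their relative offset.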
Model Architecture
Default Configuration
- Embedding Dimensions: 512
- Decoder Layers: 8
- Attention Heads: 8
- MoE Experts: 8 (top-2 routing)
- Block Size: 1024 tokens
- Vocabulary Size: Based on the Llama-2-7b tokenizer (~32,000 tokens), plus 768 extra entries in vocab_size
- Batch Size: 16
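As a rough sense of scale, each optimizer step sees at most batch_size × block_size tokens. The arithmetic below assumes every sequence is filled to the full 1024 tokens, a single GPU, and no gradient accumulation; none of these is stated in this README, so the result is an upper bound for that setup, not a measured figure.

```python
# Default configuration values listed above.
batch_size = 16
block_size = 1024
train_steps = 14_000  # from the training summary at the top of this README

# Assumes full-length sequences, one GPU, and no gradient accumulation.
tokens_per_step = batch_size * block_size
upper_bound_tokens = tokens_per_step * train_steps

print(tokens_per_step)     # 16384
print(upper_bound_tokens)  # 229376000
```

With DDP, multiply by the number of GPUs; with gradient accumulation, multiply by the number of micro-batches per step.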
Full Parameter List
Model Architecture Parameters
- epochs: Number of training epochs (default: 4)
- block_size: Maximum sequence length (default: 1024)
- batch_size: Training batch size (default: 16)
- embeddings_dims: Model embedding dimensions (default: 512)
- no_of_heads: Number of attention heads (default: 8)
- no_of_decoder_layers: Number of decoder layers (default: 8)
- attn_dropout: Attention dropout rate (default: 0.1)
- dropout: General dropout rate (default: 0.1)
Mixture of Experts (MoE) Parameters
- experts: Number of MoE experts (default: 8)
- top_experts: Number of experts to route to (default: 2)
- noisy_topk: Use noisy top-k routing (default: False)
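The three MoE parameters fit together roughly as follows: a linear gate scores every expert per token, the top_experts highest scores are kept and renormalized with a softmax, and the token's output is the weighted sum of those experts' SwiGLU feed-forward outputs. This is a minimal sketch, not the code in model.py: the class names, the expert hidden size of 1024, and the softplus-scaled noise (the usual noisy top-k gating recipe) are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUExpert(nn.Module):
    """One expert: SwiGLU feed-forward, w2(SiLU(w1 x) * w3 x)."""
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w3 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

class TopKMoE(nn.Module):
    # hidden=1024 is an assumed expert width, not taken from config.py.
    def __init__(self, dim=512, hidden=1024, experts=8, top_experts=2, noisy_topk=False):
        super().__init__()
        self.top_k = top_experts
        self.noisy_topk = noisy_topk
        self.gate = nn.Linear(dim, experts, bias=False)
        self.noise = nn.Linear(dim, experts, bias=False)  # learned per-expert noise scale
        self.experts = nn.ModuleList(SwiGLUExpert(dim, hidden) for _ in range(experts))

    def forward(self, x):  # x: (batch, seq, dim)
        flat = x.reshape(-1, x.shape[-1])
        logits = self.gate(flat)
        if self.noisy_topk and self.training:
            # Noisy top-k gating: Gaussian noise scaled by softplus(W_noise x).
            logits = logits + torch.randn_like(logits) * F.softplus(self.noise(flat))
        top_vals, top_idx = logits.topk(self.top_k, dim=-1)
        weights = F.softmax(top_vals, dim=-1)  # renormalize over the chosen experts only
        out = torch.zeros_like(flat)
        for e, expert in enumerate(self.experts):
            # Rows (tokens) and slots where expert e was selected.
            token_ids, slot = (top_idx == e).nonzero(as_tuple=True)
            if token_ids.numel():
                w = weights[token_ids, slot].unsqueeze(-1)
                out.index_add_(0, token_ids, w * expert(flat[token_ids]))
        return out.reshape_as(x)
```

Only top_experts of the experts run for each token, which is why an 8-expert model costs far less per token than its total parameter count suggests.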
Training Hyperparameters
- max_lr: Maximum learning rate (default: 6e-4)
- weight_decay_optim: Weight decay for optimizer (default: 0.01)
- beta_1: Beta1 for optimizer (default: 0.9)
- beta_2: Beta2 for optimizer (default: 0.95)
- eps: Epsilon for optimizer (default: 1e-8)
- clip: Gradient clipping value (default: 1.0)
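A sketch of how these hyperparameters map onto PyTorch's AdamW and gradient clipping. It assumes clip is a maximum global L2 gradient norm (the list only says "gradient clipping value"), and it uses a stand-in nn.Linear instead of the real model.

```python
import torch
import torch.nn as nn

# Defaults from the list above.
max_lr, weight_decay_optim = 6e-4, 0.01
beta_1, beta_2, eps, clip = 0.9, 0.95, 1e-8, 1.0

model = nn.Linear(512, 512)  # stand-in for the real model
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=max_lr,
    betas=(beta_1, beta_2),
    eps=eps,
    weight_decay=weight_decay_optim,
)

# One training step; clipping assumed to be by global gradient norm.
loss = model(torch.randn(4, 512)).pow(2).mean()
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip)
optimizer.step()
optimizer.zero_grad(set_to_none=True)
```

clip_grad_norm_ returns the norm measured before clipping, which is a useful value to log when watching for loss spikes.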
System Configuration
- device: Device to use (default: 'cuda:9'; change this to a GPU index that exists on your machine)
- use_checkpointing: Use gradient checkpointing (default: False)
- use_liger: Use Liger kernels for optimization (default: True)
- use_flash_attention: Use Flash Attention (default: True)
- use_compile: Use torch.compile (default: True)
Data Configuration
- vocab_size: Vocabulary size (default: based on tokenizer + 768)
- val_epochs: Validation frequency (default: 2)
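A worked example of the vocab_size default, assuming the Llama-2 tokenizer's 32,000 tokens. The README does not say why 768 is added; with these numbers the total lands exactly on 2^15, and rounding the vocabulary to a hardware-friendly size is a common reason for such padding, though that motive is an assumption here.

```python
llama2_vocab = 32_000  # Llama-2 tokenizer vocabulary size (assumed tokenizer length)
vocab_size = llama2_vocab + 768
embeddings_dims = 512

# Parameters in one vocab_size x embeddings_dims table (the token embedding;
# an untied output projection, if the model has one, would be the same size).
embedding_params = vocab_size * embeddings_dims

print(vocab_size)        # 32768
print(embedding_params)  # 16777216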
Quick Start
Installation
chmod +x install.sh
./install.sh
Important: Hugging Face Token Setup
Since this model uses the Llama-2 tokenizer, you'll need a Hugging Face token to access the gated model.
Get a Hugging Face Token:
- Go to Hugging Face Settings
- Create a new token with "Read" permissions
- Accept the Llama-2 license at meta-llama/Llama-2-7b-hf
Set your token in config.py:
TOKEN = 'your_token_here'
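Hard-coding a secret in config.py makes it easy to commit by accident. A minimal sketch of an alternative, assuming you are free to edit config.py: read the same HF_TOKEN environment variable used for the Gradio app and fall back to the placeholder. The helper load_token is hypothetical.

```python
import os

def load_token(default: str = "your_token_here") -> str:
    """Hypothetical helper: prefer the HF_TOKEN environment variable."""
    return os.environ.get("HF_TOKEN", default)

TOKEN = load_token()
# Recent versions of transformers accept the token directly, e.g.
# AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", token=TOKEN)
```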
Using Pre-trained Weights
Download Model Weights:
- Option 1: Download from Hugging Face - YuvrajSingh9886/SmolMixtral
- Option 2: Visit the WandB Training Report for additional checkpoints
- Place the downloaded files in the checkpoints/ directory
Load Pre-trained Model for Inference:
# Using the Gradio web interface
cd gradio
python app.py
# Or run inference from the command line (inference.py also provides functions you can reuse in your own code)
python inference.py
Training Examples
Basic Training (Single GPU)
python trainer.py
Training with Custom Parameters
# Train a larger model (modify config.py first)
python trainer.py
# Train on a different dataset (modify data.py first)
python trainer.py
Multi-GPU Distributed Training
# 2 GPUs
torchrun --nproc_per_node=2 trainer.py
# 4 GPUs
torchrun --nproc_per_node=4 trainer.py
# 8 GPUs
torchrun --nproc_per_node=8 trainer.py
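torchrun starts one process per GPU and tells each process its role through the environment variables WORLD_SIZE, RANK, and LOCAL_RANK. The sketch below shows how a trainer can read them and fall back to single-GPU mode under plain python trainer.py; the function name ddp_settings is hypothetical and this is not necessarily how trainer.py is structured.

```python
import os

def ddp_settings():
    """Read the rank information torchrun exports for each worker process.

    Hypothetical helper; falls back to a single process when launched
    with plain `python trainer.py`.
    """
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    return {
        "distributed": world_size > 1,
        "rank": rank,                # global process index (rank 0 usually logs/saves)
        "local_rank": local_rank,    # GPU index on this machine
        "world_size": world_size,
        "device": f"cuda:{local_rank}",
    }

# In a real trainer, when settings["distributed"] is True you would call
# torch.distributed.init_process_group(backend="nccl"), move the model to
# settings["device"], and wrap it in torch.nn.parallel.DistributedDataParallel.
```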
Inference with Gradio
To use the Gradio interface, your Hugging Face token must be set in config.py. Also export it as an environment variable:
export HF_TOKEN=<TOKEN_HERE>
# Run the Gradio app
cd gradio
python app.py
# With a custom checkpoint (edit app.py to point to your checkpoint first)
cd gradio
python app.py
File Structure
SmolMixtral/
├── config.py # Model configuration and hyperparameters
├── model.py # Model architecture (Mixtral, MoE, Attention, etc.)
├── data.py # Data loading and preparation
├── inference.py # Inference functions and text generation
├── trainer.py # Main training loop with DDP support
├── install.sh # Setup script
├── requirements.txt # Python dependencies
├── model_summary.py # Model architecture summary
├── gradio/
│   └── app.py # Gradio web interface
├── checkpoints/ # Model checkpoints
├── generated_data/ # Generated text outputs
├── images/ # Project images
└── old/ # Original files
Training Features
- Gradient Accumulation: Configurable batch size scaling
- Learning Rate Scheduling: Cosine decay with warmup
- Gradient Clipping: Prevents gradient explosion
- Wandb Integration: Experiment tracking and logging
- Checkpointing: Regular model checkpoints during training
- Loss Calculation: Optimized cross-entropy with padding token handling
- Distributed Training: Multi-GPU support with DDP
- Memory Optimization: Gradient checkpointing support
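A minimal sketch of "cosine decay with warmup": the learning rate rises linearly to max_lr, then follows half a cosine down to a floor. The warmup length (700 steps), the floor (min_lr = max_lr / 10), and total_steps = 14,000 are assumptions for illustration; trainer.py may use different values or shape.

```python
import math

def lr_at(step: int, max_lr: float = 6e-4, warmup_steps: int = 700,
          total_steps: int = 14_000, min_lr: float = 6e-5) -> float:
    """Linear warmup to max_lr, then cosine decay to min_lr at total_steps.

    warmup_steps, total_steps, and min_lr are assumed values.
    """
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    if step >= total_steps:
        return min_lr
    progress = (step - warmup_steps) / (total_steps - warmup_steps)  # 0 -> 1
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

To apply it, set the rate on every parameter group before each optimizer step: for g in optimizer.param_groups: g["lr"] = lr_at(step).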
Generation Methods
- Top-k Sampling: Samples the next token from the k most likely candidates, with temperature scaling
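Top-k sampling with temperature can be sketched in a few lines: divide the logits by the temperature, keep the k largest, softmax over those, and draw one. This is an illustrative version; the function name sample_top_k and the default k=50 are assumptions, not taken from inference.py.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def sample_top_k(logits: torch.Tensor, k: int = 50, temperature: float = 1.0) -> torch.Tensor:
    """Sample one token id per row from the k most likely tokens.

    logits: (batch, vocab_size) next-token scores. k=50 is an assumed default.
    """
    logits = logits / temperature               # <1 sharpens, >1 flattens the distribution
    top_vals, top_idx = logits.topk(k, dim=-1)  # keep only the k best candidates
    probs = F.softmax(top_vals, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)  # position within the top-k list
    return top_idx.gather(-1, choice)           # vocabulary ids, shape (batch, 1)
```

With k=1 this reduces to greedy decoding; larger k and higher temperature trade coherence for diversity.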
Advanced Usage
Configuration
All parameters can be configured by modifying config.py:
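from dataclasses import dataclass

@dataclass
class ModelArgs:
    # Type annotations are required: without them, @dataclass treats these as
    # plain class attributes rather than fields.
    epochs: int = 4
    block_size: int = 1024
    batch_size: int = 16
    embeddings_dims: int = 512
    # ... other parameters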
Custom Dataset Training
Modify data.py to use different datasets:
# TinyStories (default)
tinystories = True
fw = False
# FineWeb
tinystories = False
fw = True
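The two flags are mutually exclusive switches. A small sketch of how they might map to Hugging Face dataset ids, assuming the public TinyStories and FineWeb repositories are the ones used (data.py is not shown here, so the ids and the helper dataset_id are assumptions):

```python
def dataset_id(tinystories: bool, fw: bool) -> str:
    """Hypothetical helper mapping the data.py flags to a Hugging Face dataset id."""
    if tinystories == fw:
        raise ValueError("Set exactly one of tinystories / fw to True")
    # Assumed repository ids for the two datasets.
    return "roneneldan/TinyStories" if tinystories else "HuggingFaceFW/fineweb"

# With the `datasets` library installed, the chosen id could then be loaded, e.g.
# from datasets import load_dataset
# ds = load_dataset(dataset_id(False, True), split="train", streaming=True)
# Streaming avoids downloading all of FineWeb up front.
```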
Monitoring and Logging
Training automatically logs to WandB under the project name "Mixtral-DDP-Pretrain-10-billion-tokens".
Performance Tips
- Liger Kernels: Keep use_liger = True for optimized operations
- Flash Attention: Keep use_flash_attention = True for memory efficiency
- Gradient Checkpointing: Set use_checkpointing = True on memory-constrained setups
- Batch Size Tuning: Start with smaller batch sizes and increase gradually
- Block Size: Larger block sizes improve quality but require more memory
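Gradient checkpointing trades compute for memory: a checkpointed layer discards its intermediate activations in the forward pass and recomputes them during backward. The sketch below shows how a flag like use_checkpointing can gate torch.utils.checkpoint around each layer; Block and Stack are hypothetical stand-ins, not the decoder layers in model.py.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    """Hypothetical stand-in for a decoder layer."""
    def __init__(self, dim: int = 512):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        return x + self.ff(x)

class Stack(nn.Module):
    def __init__(self, layers: int = 8, use_checkpointing: bool = False):
        super().__init__()
        self.blocks = nn.ModuleList(Block() for _ in range(layers))
        self.use_checkpointing = use_checkpointing

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpointing and self.training:
                # Store only this block's input; recompute its activations in backward.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x
```

Gradients are identical either way; only peak memory and step time change.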
Troubleshooting
Common Issues
Authentication Error (401)
# Make sure you have accepted the Llama-2 license and have a valid token
# Visit: https://huggingface.co/meta-llama/Llama-2-7b-hf
# Then set your token in config.py
Out of Memory (OOM)
# Reduce batch size and enable checkpointing in config.py
batch_size = 8
use_checkpointing = True
Slow Training
# Enable optimizations in config.py
use_liger = True
use_flash_attention = True
use_compile = True
Contributing
Feel free to contribute improvements, bug fixes, or new features!
Requirements
- Python 3.8+
- PyTorch 2.0+
- Transformers
- Datasets
- Gradio
- Wandb
- Liger-kernel (optional)
License
MIT License
