Commit 7e93e32 (verified) by YuvrajSingh9886 · Parent: 866ba64

Upload README.md

Files changed (1): README.md (+285, −3)
README.md CHANGED
@@ -1,3 +1,285 @@
- ---
- license: mit
- ---

# SmolMixtral - A Mixtral-Inspired Model

A PyTorch implementation of a Mixtral-inspired transformer with Mixture of Experts (MoE), designed for text generation and understanding tasks. The model follows the Mixtral architecture and adds Flash Attention, SWiGLU activation, and optional Liger kernels for optimized performance.

- I trained a 124M (8x12M) MoE-based architecture that I coded from the ground up.
- Trained on the TinyStories dataset from Hugging Face (roughly 1M texts) for a total of 14,000 steps.

## Examples

Sample outputs are provided in the `generated_data/` directory; they showcase the model's capabilities in text generation and understanding.

![Training loss curve](images/loss.png)

## 📊 Training Results & Model Weights

**📈 View Training Report**: [SmolMixtral Training Results on WandB](https://wandb.ai/rentio/Mixtral-DDP-Pretrain-10-billion-tokens/reports/SmolMixtral--VmlldzoxMzYyNzc0OQ?accessToken=nybd4lxybsbq5k5fh2dqjcucdawilt3fossn583wv6jiu8tbdzcybiihe7rhsqmq)

**💾 Download Pre-trained Weights**:
- **Hugging Face Model**: [YuvrajSingh9886/SmolMixtral](https://huggingface.co/YuvrajSingh9886/SmolMixtral)
- **WandB Checkpoints**: See the WandB report above for additional trained model checkpoints

## Features

- **Flash Attention**: Efficient attention mechanism with memory optimization
- **Mixture of Experts (MoE)**: 8 experts with top-2 routing and optional noisy top-k (see the routing sketch after this list)
- **SWiGLU Activation**: Gated activation function used in the expert layers
- **Rotary Positional Embeddings**: Position encoding for sequence understanding
- **Liger Kernels**: Optimized kernels for faster training (optional)
- **Distributed Training**: Multi-GPU training with DDP
- **Advanced Optimizer**: AdamW with custom learning rate scheduling
- **Gradio Interface**: Interactive web interface for text generation

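The MoE layer is the core architectural idea: a router scores all 8 experts for every token and only the top-2 are evaluated. The snippet below is a minimal sketch of that routing pattern in plain PyTorch, not the exact implementation in `model.py`; the expert module and dimension names are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TopKMoE(nn.Module):
    """Minimal top-k mixture-of-experts routing sketch (illustrative, not model.py)."""

    def __init__(self, dim=512, num_experts=8, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.router = nn.Linear(dim, num_experts, bias=False)
        # Each "expert" here is a plain MLP; the real model uses SWiGLU experts.
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, 4 * dim), nn.SiLU(), nn.Linear(4 * dim, dim))
            for _ in range(num_experts)
        )

    def forward(self, x):                      # x: (batch, seq, dim)
        logits = self.router(x)                # (batch, seq, num_experts)
        weights, idx = logits.topk(self.top_k, dim=-1)
        weights = F.softmax(weights, dim=-1)   # normalise over the chosen experts only
        out = torch.zeros_like(x)
        for k in range(self.top_k):
            for e, expert in enumerate(self.experts):
                mask = idx[..., k] == e        # tokens routed to expert e at rank k
                if mask.any():
                    out[mask] += weights[..., k][mask].unsqueeze(-1) * expert(x[mask])
        return out
```

The `noisy_topk` option corresponds to the standard noisy top-k gating idea: Gaussian noise is added to the router logits before the top-k selection to encourage better load balancing across experts.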
## Model Architecture

### Default Configuration
- **Embedding Dimensions**: 512
- **Decoder Layers**: 8
- **Attention Heads**: 8
- **MoE Experts**: 8 (top-2 routing)
- **Block Size**: 1024 tokens
- **Vocabulary Size**: Based on the Llama-2-7b tokenizer (~32,000 tokens)
- **Batch Size**: 16

### Full Parameter List

#### Model Architecture Parameters
- `epochs`: Number of training epochs (default: 4)
- `block_size`: Maximum sequence length (default: 1024)
- `batch_size`: Training batch size (default: 16)
- `embeddings_dims`: Model embedding dimensions (default: 512)
- `no_of_heads`: Number of attention heads (default: 8)
- `no_of_decoder_layers`: Number of decoder layers (default: 8)
- `attn_dropout`: Attention dropout rate (default: 0.1)
- `dropout`: General dropout rate (default: 0.1)

#### Mixture of Experts (MoE) Parameters
- `experts`: Number of MoE experts (default: 8)
- `top_experts`: Number of experts each token is routed to (default: 2)
- `noisy_topk`: Use noisy top-k routing (default: False)

#### Training Hyperparameters
- `max_lr`: Maximum learning rate (default: 6e-4)
- `weight_decay_optim`: Weight decay for the optimizer (default: 0.01)
- `beta_1`: Beta1 for the optimizer (default: 0.9)
- `beta_2`: Beta2 for the optimizer (default: 0.95)
- `eps`: Epsilon for the optimizer (default: 1e-8)
- `clip`: Gradient clipping value (default: 1.0)

#### System Configuration
- `device`: Device to use (default: 'cuda:9')
- `use_checkpointing`: Use gradient checkpointing (default: False)
- `use_liger`: Use Liger kernels for optimization (default: True)
- `use_flash_attention`: Use Flash Attention (default: True)
- `use_compile`: Use torch.compile (default: True)

#### Data Configuration
- `vocab_size`: Vocabulary size (default: tokenizer vocabulary size + 768)
- `val_epochs`: Validation frequency (default: 2)

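Taken together, these parameters correspond to the `ModelArgs` dataclass in `config.py` (shown in part under Advanced Usage). The following is a hedged sketch of how they might be assembled; field types, ordering, and the exact `vocab_size` expression are assumptions based on the defaults listed above, and `config.py` remains the source of truth.

```python
from dataclasses import dataclass

@dataclass
class ModelArgs:
    # Sketch assembled from the defaults listed above; not copied from config.py.
    epochs: int = 4
    block_size: int = 1024
    batch_size: int = 16
    embeddings_dims: int = 512
    no_of_heads: int = 8
    no_of_decoder_layers: int = 8
    attn_dropout: float = 0.1
    dropout: float = 0.1
    experts: int = 8
    top_experts: int = 2
    noisy_topk: bool = False
    max_lr: float = 6e-4
    weight_decay_optim: float = 0.01
    beta_1: float = 0.9
    beta_2: float = 0.95
    eps: float = 1e-8
    clip: float = 1.0
    device: str = 'cuda:9'
    use_checkpointing: bool = False
    use_liger: bool = True
    use_flash_attention: bool = True
    use_compile: bool = True
    vocab_size: int = 32000 + 768   # assumption: Llama-2 tokenizer size + 768
    val_epochs: int = 2
```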
## Quick Start

### Installation

```bash
chmod +x install.sh
./install.sh
```

### Important: Hugging Face Token Setup

Since this model uses the Llama-2 tokenizer, you'll need a Hugging Face token to access the gated Llama-2 repository.

1. **Get a Hugging Face token:**
   - Go to [Hugging Face Settings](https://huggingface.co/settings/tokens)
   - Create a new token with "Read" permissions
   - Accept the Llama-2 license at [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf)

2. **Set your token in `config.py`:**
   ```python
   TOKEN = 'your_token_here'
   ```

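To check that the token actually grants access to the gated tokenizer, you can try loading it directly. This is a small sanity check using the standard `transformers` API, not code from the repository; the `from config import TOKEN` line assumes the token is stored as shown above.

```python
from transformers import AutoTokenizer

from config import TOKEN  # assumes TOKEN is defined in config.py as shown above

# Raises a 401/403 error if the token is invalid or the Llama-2 license
# has not been accepted for this account.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", token=TOKEN)
print("Loaded tokenizer with vocab size:", tokenizer.vocab_size)
```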
### Using Pre-trained Weights

1. **Download model weights**:
   - **Option 1**: Download from [Hugging Face - YuvrajSingh9886/SmolMixtral](https://huggingface.co/YuvrajSingh9886/SmolMixtral)
   - **Option 2**: Visit the [WandB Training Report](https://wandb.ai/rentio/Mixtral-DDP-Pretrain-10-billion-tokens) for additional checkpoints
   - Place the downloaded files in the `checkpoints/` directory

2. **Load the pre-trained model for inference**:
   ```bash
   # Using the Gradio web interface
   cd gradio
   python app.py

   # Or use it in your own code
   python inference.py
   ```

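If you prefer to load a checkpoint in your own script rather than going through `inference.py`, the general pattern looks like the sketch below. The model class name, checkpoint file name, and constructor signature are assumptions, not the repository's actual API; adapt them to the definitions in `model.py` and `config.py`.

```python
import torch

from config import ModelArgs   # configuration dataclass from config.py
from model import Mixtral      # assumed name of the model class in model.py

args = ModelArgs()
model = Mixtral(args)          # constructor signature is an assumption
state = torch.load("checkpoints/model.pt", map_location="cpu")  # hypothetical file name
model.load_state_dict(state.get("model", state))  # handles raw or wrapped state dicts
model.eval()
```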
### Training Examples

#### Basic Training (Single GPU)
```bash
python trainer.py
```

#### Training with Custom Parameters
```bash
# Train with a larger model (modify config.py)
python trainer.py

# Train with a different dataset (modify data.py)
python trainer.py
```

#### Multi-GPU Distributed Training
```bash
# 2 GPUs
torchrun --nproc_per_node=2 trainer.py

# 4 GPUs
torchrun --nproc_per_node=4 trainer.py

# 8 GPUs
torchrun --nproc_per_node=8 trainer.py
```

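Under `torchrun`, each process receives the `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` environment variables, and the trainer wraps the model in `DistributedDataParallel`. The following is a minimal sketch of that standard DDP pattern, not the exact code in `trainer.py`; the stand-in module is purely illustrative.

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Initialise the default process group from torchrun's environment variables.
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Stand-in module; trainer.py builds the SmolMixtral model here instead.
model = torch.nn.Linear(512, 512).to(local_rank)
model = DDP(model, device_ids=[local_rank])

# ... forward / backward / optimizer step, same as single-GPU training ...

dist.destroy_process_group()
```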
### Inference with Gradio

Your Hugging Face token must be set in `config.py` (as `TOKEN`, see above) before using the Gradio interface. Alternatively, export it as the `HF_TOKEN` environment variable:

```bash
export HF_TOKEN=<TOKEN_HERE>
```

```bash
# Run the Gradio app
cd gradio
python app.py

# With a custom checkpoint (edit app.py to point to your checkpoint)
cd gradio
python app.py
```

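The Gradio app is essentially a thin wrapper that connects a text box to the model's generation routine. Below is a minimal sketch of such a wrapper with a placeholder `generate_text` function; the real `gradio/app.py` loads the SmolMixtral checkpoint and exposes its own set of generation parameters.

```python
import gradio as gr

def generate_text(prompt: str, max_new_tokens: int, temperature: float) -> str:
    # Placeholder: in gradio/app.py this calls the SmolMixtral generation routine.
    return prompt + " ..."

demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(1, 512, value=128, step=1, label="Max new tokens"),
        gr.Slider(0.1, 2.0, value=0.8, label="Temperature"),
    ],
    outputs=gr.Textbox(label="Generated text"),
    title="SmolMixtral",
)

if __name__ == "__main__":
    demo.launch()
```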
## File Structure

```
SmolMixtral/
├── config.py          # Model configuration and hyperparameters
├── model.py           # Model architecture (Mixtral, MoE, Attention, etc.)
├── data.py            # Data loading and preparation
├── inference.py       # Inference functions and text generation
├── trainer.py         # Main training loop with DDP support
├── install.sh         # Setup script
├── requirements.txt   # Python dependencies
├── model_summary.py   # Model architecture summary
├── gradio/
│   └── app.py         # Gradio web interface
├── checkpoints/       # Model checkpoints
├── generated_data/    # Generated text outputs
├── images/            # Project images
└── old/               # Original files
```

## Training Features

- **Gradient Accumulation**: Configurable effective batch size scaling
- **Learning Rate Scheduling**: Cosine decay with warmup (see the sketch after this list)
- **Gradient Clipping**: Prevents gradient explosion
- **WandB Integration**: Experiment tracking and logging
- **Checkpointing**: Regular model checkpoints during training
- **Loss Calculation**: Optimized cross-entropy with padding-token handling
- **Distributed Training**: Multi-GPU support with DDP
- **Memory Optimization**: Gradient checkpointing support

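The learning rate schedule follows the usual cosine-decay-with-warmup pattern: a linear ramp up to `max_lr` over a warmup period, then a cosine decay toward a minimum rate. Here is a self-contained sketch; the warmup length and minimum rate are illustrative values, not necessarily those used in `trainer.py`.

```python
import math

def get_lr(step, max_lr=6e-4, min_lr=6e-5, warmup_steps=1000, max_steps=14000):
    """Cosine decay with linear warmup (illustrative constants except max_lr)."""
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps          # linear warmup
    if step >= max_steps:
        return min_lr
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))    # decays from 1 to 0
    return min_lr + cosine * (max_lr - min_lr)
```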
## Generation Methods

1. **Top-k Sampling**: Traditional sampling with temperature control (see the sketch below)

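For reference, this is what top-k sampling with temperature does, as a small self-contained PyTorch sketch; the generation code in `inference.py` may differ in details.

```python
import torch
import torch.nn.functional as F

def sample_top_k(logits: torch.Tensor, k: int = 50, temperature: float = 1.0) -> torch.Tensor:
    """Sample one token id per row from the top-k logits, after temperature scaling."""
    logits = logits / max(temperature, 1e-6)          # higher temperature => flatter distribution
    top_vals, top_idx = torch.topk(logits, k, dim=-1)
    probs = F.softmax(top_vals, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)  # index into the top-k set
    return top_idx.gather(-1, choice)                 # map back to vocabulary ids

# Example: logits for a batch of 2 positions over a 32k-token vocabulary
next_ids = sample_top_k(torch.randn(2, 32000), k=50, temperature=0.8)
```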
## Advanced Usage

### Configuration
All parameters can be configured by modifying `config.py`:

```python
@dataclass
class ModelArgs:
    epochs: int = 4
    block_size: int = 1024
    batch_size: int = 16
    embeddings_dims: int = 512
    # ... other parameters
```

### Custom Dataset Training
Modify `data.py` to use a different dataset:
```python
# TinyStories (default)
tinystories = True
fw = False

# FineWeb
tinystories = False
fw = True
```

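Both flags map to datasets loaded through the `datasets` library. The snippet below shows the generic loading pattern; the exact dataset IDs, configs, and split handling in `data.py` may differ, so treat these identifiers as assumptions.

```python
from datasets import load_dataset

tinystories = True
fw = False

if tinystories:
    # Assumed dataset id for TinyStories on the Hugging Face Hub.
    dataset = load_dataset("roneneldan/TinyStories", split="train")
elif fw:
    # Assumed FineWeb id/config; streaming avoids downloading the full corpus.
    dataset = load_dataset("HuggingFaceFW/fineweb", name="sample-10BT",
                           split="train", streaming=True)

print(next(iter(dataset)))
```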
### Monitoring and Logging
Training runs are automatically logged to WandB under the project name "Mixtral-DDP-Pretrain-10-billion-tokens".

## Performance Tips

1. **Use Liger Kernels**: Keep `use_liger = True` for optimized operations
2. **Flash Attention**: Keep `use_flash_attention = True` for memory efficiency
3. **Gradient Checkpointing**: Use `use_checkpointing = True` for memory-constrained setups
4. **Batch Size Tuning**: Start with smaller batch sizes and increase gradually
5. **Block Size**: Larger block sizes improve quality but require more memory

## Troubleshooting

### Common Issues

#### Authentication Error (401)
```bash
# Make sure you have accepted the Llama-2 license and have a valid token
# Visit: https://huggingface.co/meta-llama/Llama-2-7b-hf
# Then set your token in config.py
```

#### Out of Memory (OOM)
```python
# Reduce the batch size and enable gradient checkpointing in config.py
batch_size = 8
use_checkpointing = True
```

#### Slow Training
```python
# Enable optimizations in config.py
use_liger = True
use_flash_attention = True
use_compile = True
```

## Contributing

Feel free to contribute improvements, bug fixes, or new features!

## Requirements

- Python 3.8+
- PyTorch 2.0+
- Transformers
- Datasets
- Gradio
- WandB
- Liger-kernel (optional)

## License

MIT License