feat: Sync training infrastructure from main repository
Browse files- README.md +40 -75
- app.py +959 -158
- requirements.txt +44 -19
- training/data_loader.py +480 -0
- training/evaluate_model.py +782 -0
- training/model.py +641 -0
- training/train_model.py +657 -0
- training/train_tokenizer.py +429 -0
README.md
CHANGED
|
@@ -1,75 +1,40 @@
|
|
| 1 |
-
---
|
| 2 |
-
title: OpenLLM Training Space
|
| 3 |
-
emoji: π
|
| 4 |
-
colorFrom: blue
|
| 5 |
-
colorTo: purple
|
| 6 |
-
sdk: gradio
|
| 7 |
-
sdk_version: 4.44.
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
-
license: gpl-3.0
|
| 11 |
-
---
|
| 12 |
-
|
| 13 |
-
# OpenLLM Training Space
|
| 14 |
-
|
| 15 |
-
This space provides
|
| 16 |
-
|
| 17 |
-
## Features
|
| 18 |
-
|
| 19 |
-
- π―
|
| 20 |
-
- π
|
| 21 |
-
- π
|
| 22 |
-
- π
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
-
|
| 41 |
-
|
| 42 |
-
### Expected Results:
|
| 43 |
-
|
| 44 |
-
- **Training Time**: 10-30 minutes for 1000 steps (depending on HF Space resources)
|
| 45 |
-
- **Output Model**: `lemms/openllm-small-extended-8k` (or other sizes)
|
| 46 |
-
- **Model Files**: Complete PyTorch model with tokenizer and configuration
|
| 47 |
-
|
| 48 |
-
## Model Repositories
|
| 49 |
-
|
| 50 |
-
- [π 7k Small Model](https://huggingface.co/lemms/openllm-small-extended-7k)
|
| 51 |
-
- [π― 8k Small Model](https://huggingface.co/lemms/openllm-small-extended-8k)
|
| 52 |
-
- [π Training Dataset](https://huggingface.co/datasets/lemms/openllm-training-data)
|
| 53 |
-
|
| 54 |
-
## Technical Details
|
| 55 |
-
|
| 56 |
-
- **Framework**: PyTorch with Transformers
|
| 57 |
-
- **UI**: Gradio 4.44.1 (latest stable version)
|
| 58 |
-
- **Training**: Mixed precision (FP16) for efficiency
|
| 59 |
-
- **Memory**: Optimized for HF Spaces with gradient accumulation
|
| 60 |
-
- **Dependencies**: Complete ML stack with all training utilities
|
| 61 |
-
|
| 62 |
-
## Usage
|
| 63 |
-
|
| 64 |
-
1. **Configure Parameters**: Set model size, steps, learning rate, and batch size
|
| 65 |
-
2. **Start Training**: Click "Start Training" to begin the complete pipeline
|
| 66 |
-
3. **Monitor Progress**: Watch real-time status updates and training progress
|
| 67 |
-
4. **Access Results**: Find your trained model in the HF Hub repository
|
| 68 |
-
|
| 69 |
-
## License
|
| 70 |
-
|
| 71 |
-
GPL-3.0 - See [LICENSE](LICENSE) for details.
|
| 72 |
-
|
| 73 |
-
## Author
|
| 74 |
-
|
| 75 |
-
**Louis Chua Bean Chong** - OpenLLM Project Maintainer
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: OpenLLM Training Space
|
| 3 |
+
emoji: π
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 4.44.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: gpl-3.0
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
# OpenLLM Training Space
|
| 14 |
+
|
| 15 |
+
This space provides training infrastructure for OpenLLM models.
|
| 16 |
+
|
| 17 |
+
## Features
|
| 18 |
+
|
| 19 |
+
- π― Model training pipeline
|
| 20 |
+
- π Training monitoring
|
| 21 |
+
- π Model versioning
|
| 22 |
+
- π Performance tracking
|
| 23 |
+
|
| 24 |
+
## Usage
|
| 25 |
+
|
| 26 |
+
1. Upload your training data
|
| 27 |
+
2. Configure training parameters
|
| 28 |
+
3. Start training
|
| 29 |
+
4. Monitor progress
|
| 30 |
+
5. Download trained models
|
| 31 |
+
|
| 32 |
+
## Model Repositories
|
| 33 |
+
|
| 34 |
+
- [openllm-small-extended-7k](https://huggingface.co/lemms/openllm-small-extended-7k)
|
| 35 |
+
- [openllm-small-extended-8k](https://huggingface.co/lemms/openllm-small-extended-8k)
|
| 36 |
+
- [openllm-training-data](https://huggingface.co/datasets/lemms/openllm-training-data)
|
| 37 |
+
|
| 38 |
+
## License
|
| 39 |
+
|
| 40 |
+
GPL-3.0 - See [LICENSE](LICENSE) for details.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
CHANGED
|
@@ -1,223 +1,1024 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
-
OpenLLM Training Space -
|
| 4 |
|
| 5 |
-
This
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
Author: Louis Chua Bean Chong
|
| 9 |
-
License:
|
|
|
|
|
|
|
| 10 |
"""
|
| 11 |
|
| 12 |
-
import os
|
| 13 |
-
import sys
|
| 14 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
from pathlib import Path
|
| 16 |
|
| 17 |
-
# Import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
try:
|
| 19 |
-
from
|
| 20 |
-
from
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
| 22 |
except ImportError as e:
|
| 23 |
-
|
| 24 |
-
print(
|
|
|
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
-
|
| 28 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
-
def
|
| 31 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
try:
|
| 33 |
-
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
except Exception as e:
|
| 52 |
-
return f"β
|
| 53 |
|
| 54 |
-
def
|
| 55 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
try:
|
| 57 |
-
|
| 58 |
-
|
|
|
|
| 59 |
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
-
|
|
|
|
|
|
|
| 73 |
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
except Exception as e:
|
| 77 |
-
return f"β
|
| 78 |
|
| 79 |
-
def
|
| 80 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
try:
|
| 82 |
-
#
|
| 83 |
-
|
| 84 |
-
|
|
|
|
| 85 |
|
| 86 |
-
|
| 87 |
-
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
|
|
|
| 106 |
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
-
return
|
| 110 |
|
| 111 |
except Exception as e:
|
| 112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
# π OpenLLM Training Space
|
| 118 |
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
-
This
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
-
|
|
|
|
| 126 |
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
with gr.Tab("π Run Training"):
|
| 145 |
-
gr.Markdown("""
|
| 146 |
-
Start OpenLLM training with automatic model upload.
|
| 147 |
-
|
| 148 |
-
**Training Parameters:**
|
| 149 |
-
- **Model Size**: Choose the model size (small, medium, large)
|
| 150 |
-
- **Training Steps**: Number of training steps (default: 8000)
|
| 151 |
-
|
| 152 |
-
**Expected Results:**
|
| 153 |
-
- Training will complete successfully
|
| 154 |
-
- Model will be uploaded to Hugging Face Hub
|
| 155 |
-
- Repository will be created with proper model files
|
| 156 |
-
""")
|
| 157 |
-
|
| 158 |
-
with gr.Row():
|
| 159 |
model_size = gr.Dropdown(
|
| 160 |
choices=["small", "medium", "large"],
|
| 161 |
value="small",
|
| 162 |
label="Model Size",
|
| 163 |
-
info="
|
| 164 |
)
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
minimum=
|
| 170 |
-
maximum=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
)
|
| 172 |
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
|
| 190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
- **integrate_auth_into_training.py**: Integration guide
|
| 195 |
-
- **setup_hf_space_auth.py**: Space authentication setup
|
| 196 |
-
- **verify_space_auth.py**: Space verification script
|
| 197 |
|
| 198 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
|
| 205 |
-
|
|
|
|
|
|
|
| 206 |
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
|
| 213 |
-
return
|
| 214 |
-
|
| 215 |
|
| 216 |
if __name__ == "__main__":
|
| 217 |
-
#
|
| 218 |
-
interface
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
server_port=7860,
|
| 222 |
-
share=False
|
| 223 |
-
)
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
+
OpenLLM Training Space Application - Fixed with Uploaded Modules
|
| 4 |
|
| 5 |
+
This version imports OpenLLM modules from the uploaded files in the HF Space:
|
| 6 |
+
- Imports model.py and data_loader.py that were uploaded to the Space
|
| 7 |
+
- Uses OpenLLM's actual custom model architecture
|
| 8 |
+
- Compatible with OpenLLM's implementation
|
| 9 |
+
|
| 10 |
+
This application provides a complete training interface for OpenLLM models on Hugging Face Spaces.
|
| 11 |
+
It uses OpenLLM's custom GPTModel architecture instead of Hugging Face Transformers,
|
| 12 |
+
ensuring compatibility with the actual OpenLLM implementation.
|
| 13 |
+
|
| 14 |
+
Key Features:
|
| 15 |
+
- Real model training using OpenLLM's custom architecture
|
| 16 |
+
- SentencePiece tokenization for OpenLLM models
|
| 17 |
+
- Complete training pipeline with progress monitoring
|
| 18 |
+
- Automatic model saving and uploading to Hugging Face Hub
|
| 19 |
+
- Gradio 4.44.1 compatible user interface
|
| 20 |
+
|
| 21 |
+
Technical Architecture:
|
| 22 |
+
- Uses OpenLLM's GPTModel class (not Hugging Face Transformers)
|
| 23 |
+
- Imports custom modules from uploaded files in the Space
|
| 24 |
+
- Uses sentencepiece.SentencePieceProcessor() for tokenization
|
| 25 |
+
- Implements OpenLLM's training loop and optimization strategy
|
| 26 |
+
- Saves checkpoints in OpenLLM's format
|
| 27 |
|
| 28 |
Author: Louis Chua Bean Chong
|
| 29 |
+
License: GPL-3.0
|
| 30 |
+
Version: 2.1.1
|
| 31 |
+
Last Updated: 2024
|
| 32 |
"""
|
| 33 |
|
|
|
|
|
|
|
| 34 |
import gradio as gr
|
| 35 |
+
import torch
|
| 36 |
+
import torch.nn as nn
|
| 37 |
+
import os
|
| 38 |
+
import time
|
| 39 |
+
import math
|
| 40 |
+
import gc
|
| 41 |
+
from typing import Dict, Any, Optional
|
| 42 |
+
import threading
|
| 43 |
+
from dataclasses import dataclass
|
| 44 |
from pathlib import Path
|
| 45 |
|
| 46 |
+
# Import OpenLLM's custom model architecture from uploaded files
|
| 47 |
+
# These files were uploaded to the HF Space and contain OpenLLM's actual implementation
|
| 48 |
+
try:
|
| 49 |
+
# Import from the uploaded files in the HF Space
|
| 50 |
+
# model.py contains GPTModel, GPTConfig, and create_model factory function
|
| 51 |
+
from model import GPTModel, GPTConfig, create_model
|
| 52 |
+
# data_loader.py contains TextDataLoader for OpenLLM's data loading approach
|
| 53 |
+
from data_loader import TextDataLoader
|
| 54 |
+
OPENLLM_AVAILABLE = True
|
| 55 |
+
print("β
OpenLLM custom model architecture imported successfully from uploaded files")
|
| 56 |
+
print(" - GPTModel: Custom PyTorch model architecture")
|
| 57 |
+
print(" - GPTConfig: Model configuration dataclass")
|
| 58 |
+
print(" - create_model: Factory function for model creation")
|
| 59 |
+
print(" - TextDataLoader: Custom data loading implementation")
|
| 60 |
+
except ImportError as e:
|
| 61 |
+
print(f"β OpenLLM imports failed: {e}")
|
| 62 |
+
print(" This indicates the uploaded OpenLLM source files are not available")
|
| 63 |
+
print(" The training functionality will be disabled")
|
| 64 |
+
OPENLLM_AVAILABLE = False
|
| 65 |
+
|
| 66 |
+
# Try to import sentencepiece - CRITICAL for OpenLLM tokenization
|
| 67 |
+
# OpenLLM uses SentencePiece for tokenization, not Hugging Face tokenizers
|
| 68 |
+
try:
|
| 69 |
+
import sentencepiece as spm
|
| 70 |
+
SENTENCEPIECE_AVAILABLE = True
|
| 71 |
+
print(f"β
SentencePiece available: {spm.__version__}")
|
| 72 |
+
print(" - Required for OpenLLM tokenization")
|
| 73 |
+
print(" - Used for loading tokenizer.model files")
|
| 74 |
+
except ImportError:
|
| 75 |
+
SENTENCEPIECE_AVAILABLE = False
|
| 76 |
+
print("β SentencePiece not available")
|
| 77 |
+
print(" - This will prevent tokenizer loading")
|
| 78 |
+
print(" - Training functionality will be limited")
|
| 79 |
+
|
| 80 |
+
# Import other dependencies for the complete training pipeline
|
| 81 |
try:
|
| 82 |
+
from datasets import load_dataset # For loading training data from HF Hub
|
| 83 |
+
from huggingface_hub import HfApi, hf_hub_download # For model uploads and downloads
|
| 84 |
+
DEPENDENCIES_AVAILABLE = True
|
| 85 |
+
print("β
Training dependencies available")
|
| 86 |
+
print(" - datasets: For loading training data")
|
| 87 |
+
print(" - huggingface_hub: For model uploads/downloads")
|
| 88 |
except ImportError as e:
|
| 89 |
+
print(f"β Dependencies not available: {e}")
|
| 90 |
+
print(" - This will prevent dataset loading and model uploading")
|
| 91 |
+
DEPENDENCIES_AVAILABLE = False
|
| 92 |
|
| 93 |
+
@dataclass
|
| 94 |
+
class TrainingConfig:
|
| 95 |
+
"""
|
| 96 |
+
Configuration class for training parameters.
|
| 97 |
+
|
| 98 |
+
This dataclass encapsulates all the training hyperparameters and settings
|
| 99 |
+
that control the OpenLLM training process. It provides a clean interface
|
| 100 |
+
for passing configuration between different components of the training pipeline.
|
| 101 |
+
|
| 102 |
+
Attributes:
|
| 103 |
+
model_size: Size of the model to train ("small", "medium", "large")
|
| 104 |
+
max_steps: Maximum number of training iterations
|
| 105 |
+
learning_rate: Learning rate for the optimizer
|
| 106 |
+
batch_size: Number of samples per training batch
|
| 107 |
+
output_dir: Directory to save trained models and checkpoints
|
| 108 |
+
save_steps: Frequency of checkpoint saving (every N steps)
|
| 109 |
+
logging_steps: Frequency of progress logging (every N steps)
|
| 110 |
+
warmup_steps: Number of warmup steps for learning rate scheduling
|
| 111 |
+
gradient_accumulation_steps: Number of steps to accumulate gradients
|
| 112 |
+
"""
|
| 113 |
+
model_size: str
|
| 114 |
+
max_steps: int
|
| 115 |
+
learning_rate: float
|
| 116 |
+
batch_size: int
|
| 117 |
+
output_dir: str = "./openllm-trained"
|
| 118 |
+
save_steps: int = 100
|
| 119 |
+
logging_steps: int = 10
|
| 120 |
+
warmup_steps: int = 50
|
| 121 |
+
gradient_accumulation_steps: int = 4
|
| 122 |
|
| 123 |
+
class OpenLLMTrainer:
|
| 124 |
+
"""
|
| 125 |
+
Complete training implementation using OpenLLM's actual architecture.
|
| 126 |
+
|
| 127 |
+
This class handles the entire training pipeline including:
|
| 128 |
+
- Model loading using OpenLLM's custom GPTModel
|
| 129 |
+
- Tokenizer loading using sentencepiece.SentencePieceProcessor()
|
| 130 |
+
- Dataset preparation using OpenLLM's TextDataLoader
|
| 131 |
+
- Training execution using OpenLLM's approach
|
| 132 |
+
- Model saving and uploading to Hugging Face Hub
|
| 133 |
+
|
| 134 |
+
The trainer implements OpenLLM's actual training methodology rather than
|
| 135 |
+
using Hugging Face Transformers, ensuring compatibility with the real
|
| 136 |
+
OpenLLM implementation.
|
| 137 |
+
|
| 138 |
+
Key Features:
|
| 139 |
+
- Custom model architecture (GPTModel, not PreTrainedModel)
|
| 140 |
+
- SentencePiece tokenization (not Hugging Face tokenizers)
|
| 141 |
+
- OpenLLM's training loop and optimization strategy
|
| 142 |
+
- Gradient accumulation for memory efficiency
|
| 143 |
+
- Learning rate scheduling with warmup
|
| 144 |
+
- Automatic checkpoint saving and model uploading
|
| 145 |
+
"""
|
| 146 |
|
| 147 |
+
def __init__(self):
|
| 148 |
+
"""
|
| 149 |
+
Initialize the trainer with default settings.
|
| 150 |
+
|
| 151 |
+
Sets up the trainer with default values and initializes the Hugging Face
|
| 152 |
+
API for model uploading. All components start as None and are initialized
|
| 153 |
+
during the training process.
|
| 154 |
+
"""
|
| 155 |
+
# Core training components - initialized during training
|
| 156 |
+
self.model = None # OpenLLM's GPTModel instance
|
| 157 |
+
self.tokenizer = None # SentencePieceProcessor instance
|
| 158 |
+
self.data_loader = None # OpenLLM's TextDataLoader instance
|
| 159 |
+
self.optimizer = None # PyTorch optimizer (AdamW)
|
| 160 |
+
self.scheduler = None # Learning rate scheduler
|
| 161 |
+
|
| 162 |
+
# Training state management
|
| 163 |
+
self.is_training = False # Flag to track training status
|
| 164 |
+
self.tokenizer_path = None # Path to the tokenizer.model file
|
| 165 |
+
|
| 166 |
+
# Progress tracking for UI updates
|
| 167 |
+
self.training_progress = {
|
| 168 |
+
"status": "Ready", # Current training status
|
| 169 |
+
"current_step": 0, # Current training step
|
| 170 |
+
"total_steps": 0, # Total steps to complete
|
| 171 |
+
"loss": 0.0, # Current training loss
|
| 172 |
+
"learning_rate": 0.0 # Current learning rate
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
# Initialize Hugging Face API for model uploading
|
| 176 |
+
# This allows the trained model to be automatically uploaded to HF Hub
|
| 177 |
try:
|
| 178 |
+
self.hf_api = HfApi()
|
| 179 |
+
print("β
Hugging Face API initialized for model uploading")
|
| 180 |
+
except Exception as e:
|
| 181 |
+
print(f"Failed to initialize HF API: {e}")
|
| 182 |
+
print(" - Model uploading will be disabled")
|
| 183 |
+
self.hf_api = None
|
| 184 |
+
|
| 185 |
+
def load_model_and_tokenizer(self, model_size: str) -> str:
|
| 186 |
+
"""
|
| 187 |
+
Load the pre-trained OpenLLM model and tokenizer using OpenLLM's approach.
|
| 188 |
+
|
| 189 |
+
This method implements OpenLLM's actual model loading strategy:
|
| 190 |
+
1. Creates a new GPTModel using OpenLLM's factory function
|
| 191 |
+
2. Downloads the tokenizer.model file from Hugging Face Hub
|
| 192 |
+
3. Loads the tokenizer using SentencePieceProcessor
|
| 193 |
+
4. Stores both components for use in training
|
| 194 |
+
|
| 195 |
+
This approach differs from Hugging Face Transformers because:
|
| 196 |
+
- Uses OpenLLM's custom GPTModel (not AutoModelForCausalLM)
|
| 197 |
+
- Uses SentencePiece directly (not AutoTokenizer)
|
| 198 |
+
- Downloads specific files rather than using from_pretrained()
|
| 199 |
+
|
| 200 |
+
Args:
|
| 201 |
+
model_size: Size of the model to load ("small", "medium", "large")
|
| 202 |
+
Determines which pre-trained model to download
|
| 203 |
|
| 204 |
+
Returns:
|
| 205 |
+
Status message indicating success or failure
|
| 206 |
+
Success: "β
Successfully loaded OpenLLM {model_size} model with custom architecture"
|
| 207 |
+
Failure: "β Failed to load OpenLLM model and tokenizer: {error details}"
|
| 208 |
+
"""
|
| 209 |
+
try:
|
| 210 |
+
# Verify OpenLLM modules are available
|
| 211 |
+
if not OPENLLM_AVAILABLE:
|
| 212 |
+
return "β OpenLLM custom model architecture not available"
|
| 213 |
|
| 214 |
+
print(f"π Loading OpenLLM {model_size} model using custom architecture...")
|
| 215 |
+
print(f" - Using OpenLLM's create_model factory function")
|
| 216 |
+
print(f" - Not using Hugging Face Transformers")
|
| 217 |
|
| 218 |
+
# Step 1: Create model using OpenLLM's factory function
|
| 219 |
+
# This creates a fresh GPTModel instance with the specified size
|
| 220 |
+
try:
|
| 221 |
+
self.model = create_model(model_size)
|
| 222 |
+
print(f"β
OpenLLM {model_size} model created: {type(self.model).__name__}")
|
| 223 |
+
print(f" - Model type: {type(self.model).__name__}")
|
| 224 |
+
print(f" - Parameters: {self.model.get_num_params():,}")
|
| 225 |
+
print(f" - Architecture: Custom GPTModel (not PreTrainedModel)")
|
| 226 |
+
except Exception as e:
|
| 227 |
+
print(f"β Failed to create model: {e}")
|
| 228 |
+
return f"β Failed to create OpenLLM model: {str(e)}"
|
| 229 |
|
| 230 |
+
# Step 2: Load tokenizer using sentencepiece
|
| 231 |
+
# OpenLLM uses SentencePiece directly, not Hugging Face tokenizers
|
| 232 |
+
try:
|
| 233 |
+
print("π Loading tokenizer using sentencepiece.SentencePieceProcessor()...")
|
| 234 |
+
print(" - Using SentencePiece directly (not AutoTokenizer)")
|
| 235 |
+
print(" - Downloading tokenizer.model from Hugging Face Hub")
|
| 236 |
+
|
| 237 |
+
# Download tokenizer.model from HF Hub
|
| 238 |
+
# This is the actual tokenizer file used by OpenLLM models
|
| 239 |
+
model_name = f"lemms/openllm-{model_size}-extended-7k"
|
| 240 |
+
tokenizer_path = hf_hub_download(
|
| 241 |
+
repo_id=model_name,
|
| 242 |
+
filename="tokenizer.model" # Specific file name for OpenLLM
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
print(f"β
Tokenizer downloaded to: {tokenizer_path}")
|
| 246 |
+
print(f" - Source: {model_name}")
|
| 247 |
+
print(f" - File: tokenizer.model")
|
| 248 |
+
|
| 249 |
+
# Create SentencePieceProcessor and load the tokenizer
|
| 250 |
+
# This is OpenLLM's actual tokenization approach
|
| 251 |
+
sp_processor = spm.SentencePieceProcessor()
|
| 252 |
+
sp_processor.load(tokenizer_path)
|
| 253 |
|
| 254 |
+
# Store tokenizer and its path separately
|
| 255 |
+
# We need the path for the TextDataLoader later
|
| 256 |
+
self.tokenizer = sp_processor
|
| 257 |
+
self.tokenizer_path = tokenizer_path # Store the path separately
|
| 258 |
+
|
| 259 |
+
print(f"β
Tokenizer loaded successfully using SentencePieceProcessor")
|
| 260 |
+
print(f" - Vocabulary size: {sp_processor.vocab_size()}")
|
| 261 |
+
print(f" - Tokenizer path: {tokenizer_path}")
|
| 262 |
+
print(f" - Tokenizer type: {type(sp_processor).__name__}")
|
| 263 |
+
|
| 264 |
+
except Exception as e:
|
| 265 |
+
print(f"β Failed to load tokenizer: {e}")
|
| 266 |
+
return f"β Failed to load OpenLLM tokenizer: {str(e)}"
|
| 267 |
+
|
| 268 |
+
return f"β
Successfully loaded OpenLLM {model_size} model with custom architecture"
|
| 269 |
+
|
| 270 |
except Exception as e:
|
| 271 |
+
return f"β Failed to load OpenLLM model and tokenizer: {str(e)}"
|
| 272 |
|
| 273 |
+
def prepare_dataset(self) -> str:
|
| 274 |
+
"""
|
| 275 |
+
Load and prepare the training dataset using OpenLLM's approach.
|
| 276 |
+
|
| 277 |
+
This method implements OpenLLM's data preparation strategy:
|
| 278 |
+
1. Loads training data from Hugging Face Hub dataset
|
| 279 |
+
2. Creates a temporary text file for OpenLLM's TextDataLoader
|
| 280 |
+
3. Initializes OpenLLM's TextDataLoader with the tokenizer
|
| 281 |
+
4. Prepares the data for training
|
| 282 |
+
|
| 283 |
+
OpenLLM's approach differs from Hugging Face because:
|
| 284 |
+
- Uses a simple text file format (not tokenized datasets)
|
| 285 |
+
- Uses OpenLLM's TextDataLoader (not Hugging Face datasets)
|
| 286 |
+
- Tokenization happens on-the-fly during training
|
| 287 |
+
|
| 288 |
+
Returns:
|
| 289 |
+
Status message indicating success or failure
|
| 290 |
+
Success: "β
Successfully prepared dataset with {count} samples"
|
| 291 |
+
Failure: "β Failed to prepare dataset: {error details}"
|
| 292 |
+
"""
|
| 293 |
try:
|
| 294 |
+
# Verify dependencies are available
|
| 295 |
+
if not DEPENDENCIES_AVAILABLE:
|
| 296 |
+
return "β Required dependencies not available"
|
| 297 |
|
| 298 |
+
print("π Loading training dataset...")
|
| 299 |
+
print(" - Loading from Hugging Face Hub dataset")
|
| 300 |
+
print(" - Using OpenLLM's data preparation approach")
|
| 301 |
|
| 302 |
+
# Load dataset from HF Hub
|
| 303 |
+
# This contains the training text data for continuing model training
|
| 304 |
+
dataset = load_dataset("lemms/openllm-training-data")
|
| 305 |
+
print(f"β
Dataset loaded: {len(dataset['train'])} samples")
|
| 306 |
+
print(f" - Dataset: lemms/openllm-training-data")
|
| 307 |
+
print(f" - Samples: {len(dataset['train'])}")
|
| 308 |
+
|
| 309 |
+
# Create temporary data file for OpenLLM's TextDataLoader
|
| 310 |
+
# OpenLLM expects a simple text file with one text sample per line
|
| 311 |
+
temp_data_file = "temp_training_data.txt"
|
| 312 |
+
with open(temp_data_file, 'w', encoding='utf-8') as f:
|
| 313 |
+
for item in dataset['train']:
|
| 314 |
+
f.write(item['text'] + '\n')
|
| 315 |
+
|
| 316 |
+
print(f"β
Temporary data file created: {temp_data_file}")
|
| 317 |
+
print(f" - Format: One text sample per line")
|
| 318 |
+
print(f" - Encoding: UTF-8")
|
| 319 |
+
|
| 320 |
+
# Create OpenLLM's TextDataLoader
|
| 321 |
+
# This is OpenLLM's custom data loading implementation
|
| 322 |
+
try:
|
| 323 |
+
# Use the stored tokenizer path instead of trying to access model_file_path
|
| 324 |
+
# SentencePieceProcessor doesn't have a model_file_path attribute
|
| 325 |
+
tokenizer_path = self.tokenizer_path # Use the stored path
|
| 326 |
+
|
| 327 |
+
print(f"π Creating OpenLLM TextDataLoader...")
|
| 328 |
+
print(f" - Data file: {temp_data_file}")
|
| 329 |
+
print(f" - Tokenizer path: {tokenizer_path}")
|
| 330 |
+
print(f" - Sequence length: 512")
|
| 331 |
+
print(f" - Batch size: 4 (will be overridden by training config)")
|
| 332 |
+
|
| 333 |
+
self.data_loader = TextDataLoader(
|
| 334 |
+
data_file=temp_data_file,
|
| 335 |
+
tokenizer_path=tokenizer_path,
|
| 336 |
+
seq_len=512, # Maximum sequence length for training
|
| 337 |
+
batch_size=4, # Will be overridden by training config
|
| 338 |
+
shuffle=True # Shuffle data for better training
|
| 339 |
)
|
| 340 |
+
|
| 341 |
+
print(f"β
OpenLLM TextDataLoader created successfully")
|
| 342 |
+
print(f" - DataLoader type: {type(self.data_loader).__name__}")
|
| 343 |
+
print(f" - Uses OpenLLM's custom implementation")
|
| 344 |
+
|
| 345 |
+
except Exception as e:
|
| 346 |
+
print(f"β Failed to create TextDataLoader: {e}")
|
| 347 |
+
return f"β Failed to create data loader: {str(e)}"
|
| 348 |
+
|
| 349 |
+
return f"β
Successfully prepared dataset with {len(dataset['train'])} samples"
|
| 350 |
+
|
| 351 |
+
except Exception as e:
|
| 352 |
+
return f"β Failed to prepare dataset: {str(e)}"
|
| 353 |
+
|
| 354 |
+
def setup_training(self, config: TrainingConfig) -> str:
|
| 355 |
+
"""
|
| 356 |
+
Set up the training configuration using OpenLLM's approach.
|
| 357 |
+
|
| 358 |
+
This method configures the training environment with:
|
| 359 |
+
1. Output directory creation
|
| 360 |
+
2. Optimizer setup with weight decay groups
|
| 361 |
+
3. Learning rate scheduler with warmup
|
| 362 |
+
4. Training hyperparameters
|
| 363 |
+
|
| 364 |
+
The setup follows OpenLLM's training methodology:
|
| 365 |
+
- Uses AdamW optimizer with weight decay
|
| 366 |
+
- Implements learning rate warmup followed by cosine annealing
|
| 367 |
+
- Separates parameters for different weight decay rates
|
| 368 |
+
- Uses gradient clipping for stability
|
| 369 |
+
|
| 370 |
+
Args:
|
| 371 |
+
config: Training configuration object containing all hyperparameters
|
| 372 |
+
|
| 373 |
+
Returns:
|
| 374 |
+
Status message indicating success or failure
|
| 375 |
+
Success: "β
Training setup completed successfully"
|
| 376 |
+
Failure: "β Failed to setup training: {error details}"
|
| 377 |
+
"""
|
| 378 |
+
try:
|
| 379 |
+
print("π Setting up training configuration...")
|
| 380 |
+
print(f" - Output directory: {config.output_dir}")
|
| 381 |
+
print(f" - Learning rate: {config.learning_rate}")
|
| 382 |
+
print(f" - Max steps: {config.max_steps}")
|
| 383 |
+
|
| 384 |
+
# Create output directory for saving models and checkpoints
|
| 385 |
+
os.makedirs(config.output_dir, exist_ok=True)
|
| 386 |
+
print(f"β
Output directory created: {config.output_dir}")
|
| 387 |
|
| 388 |
+
# Set up optimizer (AdamW with weight decay)
|
| 389 |
+
# This follows OpenLLM's optimization strategy
|
| 390 |
+
print("π Setting up AdamW optimizer with weight decay...")
|
| 391 |
|
| 392 |
+
# Separate parameters for different weight decay rates
|
| 393 |
+
# This is a common practice for transformer training
|
| 394 |
+
decay_params = [] # Parameters that should have weight decay
|
| 395 |
+
no_decay_params = [] # Parameters that should not have weight decay
|
| 396 |
+
|
| 397 |
+
for name, param in self.model.named_parameters():
|
| 398 |
+
if not param.requires_grad:
|
| 399 |
+
continue
|
| 400 |
|
| 401 |
+
# Apply weight decay to all parameters except biases and layer norm weights
|
| 402 |
+
if len(param.shape) == 1 or name.endswith('.bias'):
|
| 403 |
+
no_decay_params.append(param)
|
| 404 |
+
else:
|
| 405 |
+
decay_params.append(param)
|
| 406 |
+
|
| 407 |
+
# Create parameter groups with different weight decay rates
|
| 408 |
+
param_groups = [
|
| 409 |
+
{'params': decay_params, 'weight_decay': 0.01}, # 1% weight decay
|
| 410 |
+
{'params': no_decay_params, 'weight_decay': 0.0} # No weight decay
|
| 411 |
+
]
|
| 412 |
+
|
| 413 |
+
print(f" - Decay parameters: {len(decay_params)}")
|
| 414 |
+
print(f" - No-decay parameters: {len(no_decay_params)}")
|
| 415 |
+
|
| 416 |
+
# Initialize AdamW optimizer with OpenLLM's recommended settings
|
| 417 |
+
self.optimizer = torch.optim.AdamW(
|
| 418 |
+
param_groups,
|
| 419 |
+
lr=config.learning_rate,
|
| 420 |
+
betas=(0.9, 0.95), # Beta values for momentum
|
| 421 |
+
eps=1e-8 # Epsilon for numerical stability
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
print(f"β
AdamW optimizer configured")
|
| 425 |
+
print(f" - Learning rate: {config.learning_rate}")
|
| 426 |
+
print(f" - Betas: (0.9, 0.95)")
|
| 427 |
+
print(f" - Epsilon: 1e-8")
|
| 428 |
+
|
| 429 |
+
# Set up learning rate scheduler
|
| 430 |
+
# OpenLLM uses a warmup followed by cosine annealing
|
| 431 |
+
print("π Setting up learning rate scheduler...")
|
| 432 |
+
|
| 433 |
+
# Warmup scheduler: linearly increase LR from 1% to 100%
|
| 434 |
+
warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
|
| 435 |
+
self.optimizer,
|
| 436 |
+
start_factor=0.01, # Start at 1% of target LR
|
| 437 |
+
end_factor=1.0, # End at 100% of target LR
|
| 438 |
+
total_iters=config.warmup_steps
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
# Main scheduler: cosine annealing after warmup
|
| 442 |
+
main_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
| 443 |
+
self.optimizer,
|
| 444 |
+
T_max=config.max_steps - config.warmup_steps # Duration of cosine annealing
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
# Combine warmup and main schedulers
|
| 448 |
+
self.scheduler = torch.optim.lr_scheduler.SequentialLR(
|
| 449 |
+
self.optimizer,
|
| 450 |
+
schedulers=[warmup_scheduler, main_scheduler],
|
| 451 |
+
milestones=[config.warmup_steps] # Switch to main scheduler after warmup
|
| 452 |
+
)
|
| 453 |
+
|
| 454 |
+
print(f"β
Learning rate scheduler configured")
|
| 455 |
+
print(f" - Warmup steps: {config.warmup_steps}")
|
| 456 |
+
print(f" - Total steps: {config.max_steps}")
|
| 457 |
+
print(f" - Schedule: Linear warmup β Cosine annealing")
|
| 458 |
+
|
| 459 |
+
print("β
Training setup completed successfully")
|
| 460 |
+
return f"β
Training setup completed successfully"
|
| 461 |
+
|
| 462 |
except Exception as e:
|
| 463 |
+
return f"β Failed to setup training: {str(e)}"
|
| 464 |
|
| 465 |
+
def train_model(self, config: TrainingConfig, progress_callback=None) -> str:
|
| 466 |
+
"""
|
| 467 |
+
Execute the actual model training using OpenLLM's approach.
|
| 468 |
+
|
| 469 |
+
This method implements OpenLLM's training loop:
|
| 470 |
+
1. Sets up training mode and progress tracking
|
| 471 |
+
2. Iterates through data batches using OpenLLM's TextDataLoader
|
| 472 |
+
3. Performs forward pass, loss computation, and backward pass
|
| 473 |
+
4. Implements gradient accumulation for memory efficiency
|
| 474 |
+
5. Updates model parameters and learning rate
|
| 475 |
+
6. Saves checkpoints and logs progress
|
| 476 |
+
|
| 477 |
+
The training loop follows OpenLLM's methodology:
|
| 478 |
+
- Uses OpenLLM's GPTModel forward pass (returns logits and loss)
|
| 479 |
+
- Implements gradient accumulation for effective larger batch sizes
|
| 480 |
+
- Uses gradient clipping for training stability
|
| 481 |
+
- Saves checkpoints in OpenLLM's format
|
| 482 |
+
- Updates progress for UI monitoring
|
| 483 |
+
|
| 484 |
+
Args:
|
| 485 |
+
config: Training configuration object containing hyperparameters
|
| 486 |
+
progress_callback: Optional callback function for progress updates
|
| 487 |
+
(Not used in current implementation)
|
| 488 |
+
|
| 489 |
+
Returns:
|
| 490 |
+
Status message indicating success or failure
|
| 491 |
+
Success: "β
Training completed successfully! Final step: {step}"
|
| 492 |
+
Failure: "β Training failed: {error details}"
|
| 493 |
+
"""
|
| 494 |
try:
|
| 495 |
+
# Set training state
|
| 496 |
+
self.is_training = True
|
| 497 |
+
self.training_progress["status"] = "Training"
|
| 498 |
+
self.training_progress["total_steps"] = config.max_steps
|
| 499 |
|
| 500 |
+
print(f"π Starting OpenLLM training for {config.max_steps} steps...")
|
| 501 |
+
print(f" - Model: {type(self.model).__name__}")
|
| 502 |
+
print(f" - DataLoader: {type(self.data_loader).__name__}")
|
| 503 |
+
print(f" - Optimizer: {type(self.optimizer).__name__}")
|
| 504 |
+
print(f" - Gradient accumulation: {config.gradient_accumulation_steps}")
|
| 505 |
|
| 506 |
+
# Training loop using OpenLLM's approach
|
| 507 |
+
self.model.train() # Set model to training mode
|
| 508 |
+
accumulated_loss = 0.0 # Track loss across accumulation steps
|
| 509 |
+
self.optimizer.zero_grad() # Clear gradients
|
| 510 |
|
| 511 |
+
step = 0 # Current training step
|
| 512 |
+
for batch_idx, (input_ids, target_ids) in enumerate(self.data_loader):
|
| 513 |
+
# Check if we've reached the maximum number of steps
|
| 514 |
+
if step >= config.max_steps:
|
| 515 |
+
break
|
| 516 |
+
|
| 517 |
+
# Forward pass (model computes loss internally when targets provided)
|
| 518 |
+
# OpenLLM's GPTModel returns both logits and loss
|
| 519 |
+
logits, loss = self.model(input_ids, target_ids)
|
| 520 |
+
|
| 521 |
+
# Scale loss for gradient accumulation
|
| 522 |
+
# This allows us to simulate larger batch sizes
|
| 523 |
+
loss = loss / config.gradient_accumulation_steps
|
| 524 |
+
accumulated_loss += loss.item()
|
| 525 |
+
|
| 526 |
+
# Backward pass - compute gradients
|
| 527 |
+
loss.backward()
|
| 528 |
+
|
| 529 |
+
# Update weights every gradient_accumulation_steps
|
| 530 |
+
if (batch_idx + 1) % config.gradient_accumulation_steps == 0:
|
| 531 |
+
# Clip gradients for training stability
|
| 532 |
+
# This prevents exploding gradients
|
| 533 |
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
|
| 534 |
+
|
| 535 |
+
# Update parameters using the optimizer
|
| 536 |
+
self.optimizer.step()
|
| 537 |
+
|
| 538 |
+
# Update learning rate using the scheduler
|
| 539 |
+
self.scheduler.step()
|
| 540 |
+
|
| 541 |
+
# Clear gradients for the next accumulation cycle
|
| 542 |
+
self.optimizer.zero_grad()
|
| 543 |
+
|
| 544 |
+
# Update step count
|
| 545 |
+
step += 1
|
| 546 |
+
|
| 547 |
+
# Update progress for UI monitoring
|
| 548 |
+
self.training_progress["current_step"] = step
|
| 549 |
+
self.training_progress["loss"] = accumulated_loss
|
| 550 |
+
self.training_progress["learning_rate"] = self.scheduler.get_last_lr()[0]
|
| 551 |
+
|
| 552 |
+
# Log progress at specified intervals
|
| 553 |
+
if step % config.logging_steps == 0:
|
| 554 |
+
current_lr = self.scheduler.get_last_lr()[0]
|
| 555 |
+
print(f"Step {step}/{config.max_steps} | Loss: {accumulated_loss:.4f} | LR: {current_lr:.2e}")
|
| 556 |
+
|
| 557 |
+
# Save checkpoint at specified intervals
|
| 558 |
+
if step % config.save_steps == 0:
|
| 559 |
+
self._save_checkpoint(config.output_dir, step)
|
| 560 |
+
print(f"πΎ Checkpoint saved at step {step}")
|
| 561 |
+
|
| 562 |
+
# Reset accumulated loss for the next accumulation cycle
|
| 563 |
+
accumulated_loss = 0.0
|
| 564 |
+
|
| 565 |
+
# Clean up memory periodically
|
| 566 |
+
if step % 100 == 0:
|
| 567 |
+
gc.collect()
|
| 568 |
+
print(f"π§Ή Memory cleanup at step {step}")
|
| 569 |
|
| 570 |
+
# Save final checkpoint
|
| 571 |
+
self._save_checkpoint(config.output_dir, step, is_best=True)
|
| 572 |
+
print(f"πΎ Final checkpoint saved at step {step}")
|
| 573 |
+
|
| 574 |
+
# Update final progress
|
| 575 |
+
self.training_progress["status"] = "Completed"
|
| 576 |
+
self.training_progress["current_step"] = step
|
| 577 |
|
| 578 |
+
print(f"β
Training completed! Final step: {step}")
|
| 579 |
+
print(f" - Total steps completed: {step}")
|
| 580 |
+
print(f" - Final loss: {self.training_progress['loss']:.4f}")
|
| 581 |
+
print(f" - Final learning rate: {self.training_progress['learning_rate']:.2e}")
|
| 582 |
|
| 583 |
+
return f"β
Training completed successfully! Final step: {step}"
|
| 584 |
|
| 585 |
except Exception as e:
|
| 586 |
+
self.training_progress["status"] = "Failed"
|
| 587 |
+
print(f"β Training failed: {e}")
|
| 588 |
+
print(f" - Error occurred during training")
|
| 589 |
+
print(f" - Training state: {self.training_progress['status']}")
|
| 590 |
+
return f"β Training failed: {str(e)}"
|
| 591 |
+
finally:
|
| 592 |
+
self.is_training = False
|
| 593 |
|
| 594 |
+
def _save_checkpoint(self, output_dir: str, step: int, is_best: bool = False) -> None:
|
| 595 |
+
"""
|
| 596 |
+
Save model checkpoint using OpenLLM's approach.
|
|
|
|
| 597 |
|
| 598 |
+
This method saves the model state in OpenLLM's checkpoint format:
|
| 599 |
+
- Model state dictionary
|
| 600 |
+
- Optimizer state dictionary
|
| 601 |
+
- Scheduler state dictionary
|
| 602 |
+
- Model configuration
|
| 603 |
+
- Training step information
|
| 604 |
|
| 605 |
+
The checkpoint format is compatible with OpenLLM's loading mechanism
|
| 606 |
+
and can be used to resume training or load the model for inference.
|
| 607 |
+
|
| 608 |
+
Args:
|
| 609 |
+
output_dir: Directory to save the checkpoint
|
| 610 |
+
step: Current training step number
|
| 611 |
+
is_best: Whether this is the best model so far
|
| 612 |
+
"""
|
| 613 |
+
try:
|
| 614 |
+
# Create checkpoint dictionary with all necessary components
|
| 615 |
+
checkpoint = {
|
| 616 |
+
'step': step, # Current training step
|
| 617 |
+
'model_state_dict': self.model.state_dict(), # Model parameters
|
| 618 |
+
'optimizer_state_dict': self.optimizer.state_dict(), # Optimizer state
|
| 619 |
+
'scheduler_state_dict': self.scheduler.state_dict(), # Scheduler state
|
| 620 |
+
'config': self.model.config.__dict__ # Model configuration
|
| 621 |
+
}
|
| 622 |
+
|
| 623 |
+
# Save latest checkpoint
|
| 624 |
+
checkpoint_path = os.path.join(output_dir, f"checkpoint_step_{step}.pt")
|
| 625 |
+
torch.save(checkpoint, checkpoint_path)
|
| 626 |
+
|
| 627 |
+
# Save best checkpoint if this is the best model
|
| 628 |
+
if is_best:
|
| 629 |
+
best_path = os.path.join(output_dir, "best_model.pt")
|
| 630 |
+
torch.save(checkpoint, best_path)
|
| 631 |
+
print(f"πΎ Best model saved: {best_path}")
|
| 632 |
+
|
| 633 |
+
print(f"πΎ Checkpoint saved: {checkpoint_path}")
|
| 634 |
+
|
| 635 |
+
except Exception as e:
|
| 636 |
+
print(f"β Failed to save checkpoint: {e}")
|
| 637 |
+
|
| 638 |
+
def save_and_upload_model(self, config: TrainingConfig) -> str:
|
| 639 |
+
"""
|
| 640 |
+
Save the trained model and upload it to Hugging Face Hub.
|
| 641 |
|
| 642 |
+
This method completes the training pipeline by:
|
| 643 |
+
1. Saving the final model checkpoint
|
| 644 |
+
2. Copying the tokenizer files
|
| 645 |
+
3. Uploading the complete model to Hugging Face Hub
|
| 646 |
+
4. Creating a new model repository for the trained model
|
| 647 |
|
| 648 |
+
The uploaded model will be available at:
|
| 649 |
+
https://huggingface.co/lemms/openllm-{size}-extended-8k
|
| 650 |
|
| 651 |
+
Args:
|
| 652 |
+
config: Training configuration object
|
| 653 |
+
|
| 654 |
+
Returns:
|
| 655 |
+
Status message indicating success or failure
|
| 656 |
+
Success: "β
Model saved and uploaded to https://huggingface.co/{repo_id}"
|
| 657 |
+
Failure: "β Failed to save/upload model: {error details}"
|
| 658 |
+
"""
|
| 659 |
+
try:
|
| 660 |
+
print("π Saving trained model...")
|
| 661 |
+
print(f" - Output directory: {config.output_dir}")
|
| 662 |
+
print(f" - Model size: {config.model_size}")
|
| 663 |
+
|
| 664 |
+
# Save the final model checkpoint
|
| 665 |
+
self._save_checkpoint(config.output_dir, config.max_steps, is_best=True)
|
| 666 |
+
|
| 667 |
+
# Save tokenizer files
|
| 668 |
+
# Create a tokenizer directory within the output directory
|
| 669 |
+
tokenizer_dir = os.path.join(config.output_dir, "tokenizer")
|
| 670 |
+
os.makedirs(tokenizer_dir, exist_ok=True)
|
| 671 |
+
|
| 672 |
+
# Copy the tokenizer.model file using the stored path
|
| 673 |
+
# This ensures the tokenizer is included with the model
|
| 674 |
+
import shutil
|
| 675 |
+
shutil.copy2(self.tokenizer_path, os.path.join(tokenizer_dir, "tokenizer.model"))
|
| 676 |
+
|
| 677 |
+
print("β
Model saved locally")
|
| 678 |
+
print(f" - Model checkpoint: {config.output_dir}/best_model.pt")
|
| 679 |
+
print(f" - Tokenizer: {tokenizer_dir}/tokenizer.model")
|
| 680 |
+
|
| 681 |
+
# Generate model name for upload
|
| 682 |
+
# The naming convention follows: openllm-{size}-extended-8k
|
| 683 |
+
model_name = f"openllm-{config.model_size}-extended-8k"
|
| 684 |
+
repo_id = f"lemms/{model_name}"
|
| 685 |
+
|
| 686 |
+
# Upload to Hugging Face Hub
|
| 687 |
+
if self.hf_api:
|
| 688 |
+
print(f"π Uploading model to {repo_id}...")
|
| 689 |
+
print(f" - Repository: {repo_id}")
|
| 690 |
+
print(f" - Type: model")
|
| 691 |
+
print(f" - Source: {config.output_dir}")
|
| 692 |
+
|
| 693 |
+
# Create the repository first if it doesn't exist
|
| 694 |
+
try:
|
| 695 |
+
from huggingface_hub import create_repo
|
| 696 |
+
create_repo(
|
| 697 |
+
repo_id=repo_id,
|
| 698 |
+
repo_type="model",
|
| 699 |
+
exist_ok=True,
|
| 700 |
+
private=False
|
| 701 |
+
)
|
| 702 |
+
print(f"β
Repository {repo_id} ready for upload")
|
| 703 |
+
except Exception as create_error:
|
| 704 |
+
print(f"β οΈ Repository creation warning: {create_error}")
|
| 705 |
+
print(" Continuing with upload attempt...")
|
| 706 |
+
|
| 707 |
+
# Upload model files to Hugging Face Hub
|
| 708 |
+
# This creates a new model repository with all the files
|
| 709 |
+
self.hf_api.upload_folder(
|
| 710 |
+
folder_path=config.output_dir,
|
| 711 |
+
repo_id=repo_id,
|
| 712 |
+
repo_type="model",
|
| 713 |
+
commit_message=f"Add trained OpenLLM {config.model_size} model (8k steps)"
|
| 714 |
+
)
|
| 715 |
+
|
| 716 |
+
print(f"β
Model uploaded successfully to {repo_id}")
|
| 717 |
+
print(f" - Available at: https://huggingface.co/{repo_id}")
|
| 718 |
+
return f"β
Model saved and uploaded to https://huggingface.co/{repo_id}"
|
| 719 |
+
else:
|
| 720 |
+
print("β οΈ Hugging Face API not available - model saved locally only")
|
| 721 |
+
return f"β
Model saved locally to {config.output_dir}"
|
| 722 |
+
|
| 723 |
+
except Exception as e:
|
| 724 |
+
print(f"β Failed to save/upload model: {e}")
|
| 725 |
+
return f"β Failed to save/upload model: {str(e)}"
|
| 726 |
+
|
| 727 |
+
def get_training_progress(self) -> Dict[str, Any]:
|
| 728 |
+
"""
|
| 729 |
+
Get current training progress information.
|
| 730 |
+
|
| 731 |
+
This method returns a copy of the current training progress
|
| 732 |
+
for display in the Gradio UI. The progress information includes:
|
| 733 |
+
- Current training status
|
| 734 |
+
- Current step and total steps
|
| 735 |
+
- Current loss value
|
| 736 |
+
- Current learning rate
|
| 737 |
+
|
| 738 |
+
Returns:
|
| 739 |
+
Dictionary containing current training progress information
|
| 740 |
+
"""
|
| 741 |
+
return self.training_progress.copy()
|
| 742 |
+
|
| 743 |
+
def main():
|
| 744 |
+
"""
|
| 745 |
+
Main function that creates the complete Gradio application interface.
|
| 746 |
+
|
| 747 |
+
This function sets up the entire Gradio application with:
|
| 748 |
+
1. Application header and status information
|
| 749 |
+
2. Training configuration controls
|
| 750 |
+
3. Training status and progress display
|
| 751 |
+
4. Training control buttons
|
| 752 |
+
5. Instructions and resource links
|
| 753 |
+
6. Training function implementation
|
| 754 |
+
|
| 755 |
+
The interface provides a complete training experience for OpenLLM models
|
| 756 |
+
with real-time progress monitoring and comprehensive configuration options.
|
| 757 |
+
|
| 758 |
+
Returns:
|
| 759 |
+
Gradio Blocks interface for the training application
|
| 760 |
+
"""
|
| 761 |
+
|
| 762 |
+
# Initialize the trainer
|
| 763 |
+
# This creates the OpenLLMTrainer instance that will handle all training operations
|
| 764 |
+
trainer = OpenLLMTrainer()
|
| 765 |
+
|
| 766 |
+
# Create the main Gradio application interface
|
| 767 |
+
# Using Gradio 4.44.1 with Soft theme for modern appearance
|
| 768 |
+
with gr.Blocks(
|
| 769 |
+
title="OpenLLM Training Space - Fixed with Uploaded Modules",
|
| 770 |
+
theme=gr.themes.Soft()
|
| 771 |
+
) as demo:
|
| 772 |
+
|
| 773 |
+
# Application Header
|
| 774 |
+
# Provides clear identification and description of the application
|
| 775 |
+
gr.Markdown("# π OpenLLM Training Space - Fixed with Uploaded Modules")
|
| 776 |
+
gr.Markdown("### *Uses OpenLLM's Custom Model Architecture from Uploaded Files*")
|
| 777 |
+
gr.Markdown("---")
|
| 778 |
+
|
| 779 |
+
# Status Information
|
| 780 |
+
# Shows the availability of key components and dependencies
|
| 781 |
+
gr.Markdown(f"**OpenLLM Available**: {'β
Yes' if OPENLLM_AVAILABLE else 'β No'}")
|
| 782 |
+
gr.Markdown(f"**SentencePiece Available**: {'β
Yes' if SENTENCEPIECE_AVAILABLE else 'β No'}")
|
| 783 |
+
gr.Markdown(f"**Dependencies Available**: {'β
Yes' if DEPENDENCIES_AVAILABLE else 'β No'}")
|
| 784 |
+
gr.Markdown("**Architecture**: β
OpenLLM Custom GPTModel (From Uploaded Files)")
|
| 785 |
|
| 786 |
+
# Main Content Area
|
| 787 |
+
# Two-column layout for configuration and status
|
| 788 |
+
with gr.Row():
|
| 789 |
+
|
| 790 |
+
# Left Column: Training Configuration
|
| 791 |
+
# Contains all the training hyperparameters and settings
|
| 792 |
+
with gr.Column(scale=1):
|
| 793 |
+
gr.Markdown("## π Training Configuration")
|
| 794 |
+
|
| 795 |
+
# Model Size Selection
|
| 796 |
+
# Allows users to choose which base model to train from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 797 |
model_size = gr.Dropdown(
|
| 798 |
choices=["small", "medium", "large"],
|
| 799 |
value="small",
|
| 800 |
label="Model Size",
|
| 801 |
+
info="Select the base model size to train from"
|
| 802 |
)
|
| 803 |
+
|
| 804 |
+
# Training Steps Configuration
|
| 805 |
+
# Controls the number of training iterations
|
| 806 |
+
max_steps = gr.Slider(
|
| 807 |
+
minimum=100,
|
| 808 |
+
maximum=10000,
|
| 809 |
+
value=1000,
|
| 810 |
+
step=100,
|
| 811 |
+
label="Max Training Steps",
|
| 812 |
+
info="Number of training iterations (100-10,000)"
|
| 813 |
+
)
|
| 814 |
+
|
| 815 |
+
# Learning Rate Configuration
|
| 816 |
+
# Controls the learning rate for the optimizer
|
| 817 |
+
learning_rate = gr.Slider(
|
| 818 |
+
minimum=1e-5,
|
| 819 |
+
maximum=1e-3,
|
| 820 |
+
value=3e-4,
|
| 821 |
+
step=1e-5,
|
| 822 |
+
label="Learning Rate",
|
| 823 |
+
info="Training rate (0.00001-0.001)"
|
| 824 |
+
)
|
| 825 |
+
|
| 826 |
+
# Batch Size Configuration
|
| 827 |
+
# Controls the number of samples per training batch
|
| 828 |
+
batch_size = gr.Slider(
|
| 829 |
+
minimum=1,
|
| 830 |
+
maximum=16,
|
| 831 |
+
value=4,
|
| 832 |
+
step=1,
|
| 833 |
+
label="Batch Size",
|
| 834 |
+
info="Samples per training batch (1-16)"
|
| 835 |
)
|
| 836 |
|
| 837 |
+
# Right Column: Training Status and Controls
|
| 838 |
+
# Contains status display and control buttons
|
| 839 |
+
with gr.Column(scale=1):
|
| 840 |
+
gr.Markdown("## π― Training Status")
|
| 841 |
+
|
| 842 |
+
# Training Status Display
|
| 843 |
+
# Shows current training status and any error messages
|
| 844 |
+
status_text = gr.Textbox(
|
| 845 |
+
value="Ready to start training" if OPENLLM_AVAILABLE else "OpenLLM not available",
|
| 846 |
+
label="Current Status",
|
| 847 |
+
interactive=False,
|
| 848 |
+
lines=5,
|
| 849 |
+
info="Shows current training status and progress updates"
|
| 850 |
+
)
|
| 851 |
+
|
| 852 |
+
# Progress Information
|
| 853 |
+
# Displays detailed training progress in JSON format
|
| 854 |
+
progress_info = gr.JSON(
|
| 855 |
+
value=trainer.get_training_progress(),
|
| 856 |
+
label="Training Progress"
|
| 857 |
+
)
|
| 858 |
+
|
| 859 |
+
# Training Control Buttons
|
| 860 |
+
# Buttons to start and stop training
|
| 861 |
+
with gr.Row():
|
| 862 |
+
start_btn = gr.Button("π Start Training", variant="primary")
|
| 863 |
+
stop_btn = gr.Button("βΉοΈ Stop Training", variant="stop")
|
| 864 |
|
| 865 |
+
# Instructions Section
|
| 866 |
+
# Provides detailed instructions for using the training interface
|
| 867 |
+
gr.Markdown("## π OpenLLM Training Instructions")
|
| 868 |
+
gr.Markdown("""
|
| 869 |
+
This interface uses **OpenLLM's actual custom model architecture** from uploaded files:
|
| 870 |
+
|
| 871 |
+
### **Step 1: Configure Parameters**
|
| 872 |
+
- **Model Size**: Select the base model to train from (small, medium, large)
|
| 873 |
+
- **Max Steps**: Number of training iterations (100-10,000)
|
| 874 |
+
- **Learning Rate**: Training rate (0.00001-0.001)
|
| 875 |
+
- **Batch Size**: Samples per training batch (1-16)
|
| 876 |
+
|
| 877 |
+
### **Step 2: Start Training**
|
| 878 |
+
- Click "Start Training" to begin the actual training process
|
| 879 |
+
- Uses OpenLLM's custom GPTModel class from uploaded files
|
| 880 |
+
- Uses sentencepiece.SentencePieceProcessor() for tokenization
|
| 881 |
+
- Compatible with OpenLLM's actual implementation
|
| 882 |
+
|
| 883 |
+
### **Step 3: Monitor Progress**
|
| 884 |
+
- Watch the status updates and progress information
|
| 885 |
+
- Training may take several minutes depending on steps
|
| 886 |
+
- The final model will be uploaded to Hugging Face Hub
|
| 887 |
+
|
| 888 |
+
### **Step 4: Access Results**
|
| 889 |
+
- Trained models are automatically pushed to: `lemms/openllm-{size}-extended-8k`
|
| 890 |
+
- Check the model repository for your trained model
|
| 891 |
+
- Use the model for inference or further training
|
| 892 |
+
""")
|
| 893 |
+
|
| 894 |
+
# Resource Links Section
|
| 895 |
+
# Provides links to related models and resources
|
| 896 |
+
gr.Markdown("## π Model Resources")
|
| 897 |
+
gr.Markdown("""
|
| 898 |
+
- [π 7k Small Model](https://huggingface.co/lemms/openllm-small-extended-7k)
|
| 899 |
+
- [π― 8k Small Model](https://huggingface.co/lemms/openllm-small-extended-8k)
|
| 900 |
+
- [π Training Dataset](https://huggingface.co/datasets/lemms/openllm-training-data)
|
| 901 |
+
- [π Main Project](https://github.com/louischua/openllm)
|
| 902 |
+
""")
|
| 903 |
+
|
| 904 |
+
# Training Function Definition
|
| 905 |
+
# This function is called when the Start Training button is clicked
|
| 906 |
+
def start_complete_training(model_size, max_steps, learning_rate, batch_size):
|
| 907 |
+
"""
|
| 908 |
+
Execute the complete training process using OpenLLM's approach.
|
| 909 |
|
| 910 |
+
This function orchestrates the entire training pipeline:
|
| 911 |
+
1. Validates OpenLLM availability
|
| 912 |
+
2. Creates training configuration
|
| 913 |
+
3. Loads model and tokenizer
|
| 914 |
+
4. Prepares dataset
|
| 915 |
+
5. Sets up training environment
|
| 916 |
+
6. Executes training
|
| 917 |
+
7. Saves and uploads the trained model
|
| 918 |
|
| 919 |
+
The function provides comprehensive error handling and status updates
|
| 920 |
+
throughout the training process.
|
|
|
|
|
|
|
|
|
|
| 921 |
|
| 922 |
+
Args:
|
| 923 |
+
model_size: Size of the model to train ("small", "medium", "large")
|
| 924 |
+
max_steps: Maximum number of training steps
|
| 925 |
+
learning_rate: Learning rate for the optimizer
|
| 926 |
+
batch_size: Batch size for training
|
| 927 |
+
|
| 928 |
+
Returns:
|
| 929 |
+
Status message indicating the result of the training process
|
| 930 |
+
"""
|
| 931 |
+
# Validate OpenLLM availability
|
| 932 |
+
if not OPENLLM_AVAILABLE:
|
| 933 |
+
return "β OpenLLM custom model architecture not available. Please check the installation."
|
| 934 |
|
| 935 |
+
try:
|
| 936 |
+
print(f"π Starting complete training process...")
|
| 937 |
+
print(f" - Model size: {model_size}")
|
| 938 |
+
print(f" - Max steps: {max_steps}")
|
| 939 |
+
print(f" - Learning rate: {learning_rate}")
|
| 940 |
+
print(f" - Batch size: {batch_size}")
|
| 941 |
+
|
| 942 |
+
# Create training configuration
|
| 943 |
+
# This encapsulates all training parameters
|
| 944 |
+
config = TrainingConfig(
|
| 945 |
+
model_size=model_size,
|
| 946 |
+
max_steps=max_steps,
|
| 947 |
+
learning_rate=learning_rate,
|
| 948 |
+
batch_size=batch_size
|
| 949 |
+
)
|
| 950 |
+
|
| 951 |
+
# Step 1: Load model and tokenizer using OpenLLM's approach
|
| 952 |
+
print("π Step 1: Loading model and tokenizer...")
|
| 953 |
+
status = trainer.load_model_and_tokenizer(model_size)
|
| 954 |
+
if "β" in status:
|
| 955 |
+
return status
|
| 956 |
+
|
| 957 |
+
# Step 2: Prepare dataset
|
| 958 |
+
print("π Step 2: Preparing dataset...")
|
| 959 |
+
status = trainer.prepare_dataset()
|
| 960 |
+
if "β" in status:
|
| 961 |
+
return status
|
| 962 |
+
|
| 963 |
+
# Step 3: Setup training
|
| 964 |
+
print("π Step 3: Setting up training...")
|
| 965 |
+
status = trainer.setup_training(config)
|
| 966 |
+
if "β" in status:
|
| 967 |
+
return status
|
| 968 |
+
|
| 969 |
+
# Step 4: Execute training
|
| 970 |
+
print("π Step 4: Executing training...")
|
| 971 |
+
status = trainer.train_model(config)
|
| 972 |
+
if "β" in status:
|
| 973 |
+
return status
|
| 974 |
+
|
| 975 |
+
# Step 5: Save and upload model
|
| 976 |
+
print("π Step 5: Saving and uploading model...")
|
| 977 |
+
status = trainer.save_and_upload_model(config)
|
| 978 |
+
|
| 979 |
+
print("π Complete training process finished!")
|
| 980 |
+
return f"π Complete training process finished!\n{status}"
|
| 981 |
+
|
| 982 |
+
except Exception as e:
|
| 983 |
+
print(f"β Training process failed: {str(e)}")
|
| 984 |
+
return f"β Training process failed: {str(e)}"
|
| 985 |
+
|
| 986 |
+
def update_progress():
|
| 987 |
+
"""
|
| 988 |
+
Update the progress display.
|
| 989 |
|
| 990 |
+
This function is called periodically to update the progress
|
| 991 |
+
information displayed in the Gradio interface. It returns the
|
| 992 |
+
current training progress from the trainer.
|
| 993 |
|
| 994 |
+
Returns:
|
| 995 |
+
Current training progress dictionary
|
| 996 |
+
"""
|
| 997 |
+
return trainer.get_training_progress()
|
| 998 |
+
|
| 999 |
+
# Connect UI Components to Functions
|
| 1000 |
+
# This connects the Start Training button to the training function
|
| 1001 |
+
start_btn.click(
|
| 1002 |
+
fn=start_complete_training,
|
| 1003 |
+
inputs=[model_size, max_steps, learning_rate, batch_size],
|
| 1004 |
+
outputs=[status_text]
|
| 1005 |
+
)
|
| 1006 |
+
|
| 1007 |
+
# Auto-refresh progress every 5 seconds during training
|
| 1008 |
+
# This ensures the progress display stays up to date
|
| 1009 |
+
demo.load(update_progress, outputs=[progress_info])
|
| 1010 |
+
|
| 1011 |
+
# Application Footer
|
| 1012 |
+
# Provides attribution and technical information
|
| 1013 |
+
gr.Markdown("---")
|
| 1014 |
+
gr.Markdown("**Author**: Louis Chua Bean Chong | **Project**: OpenLLM | **License**: GPL-3.0")
|
| 1015 |
+
gr.Markdown("**Architecture**: OpenLLM Custom GPTModel (From Uploaded Files)")
|
| 1016 |
+
gr.Markdown("**Tokenizer**: sentencepiece.SentencePieceProcessor()")
|
| 1017 |
|
| 1018 |
+
return demo
|
|
|
|
| 1019 |
|
| 1020 |
if __name__ == "__main__":
|
| 1021 |
+
# Launch the Gradio application
|
| 1022 |
+
# This starts the web interface for the training application
|
| 1023 |
+
demo = main()
|
| 1024 |
+
demo.launch()
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -1,26 +1,51 @@
|
|
| 1 |
-
#
|
| 2 |
-
#
|
| 3 |
|
| 4 |
-
#
|
| 5 |
-
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
#
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
-
#
|
| 11 |
-
|
| 12 |
-
torchvision>=0.15.0
|
| 13 |
|
| 14 |
-
#
|
| 15 |
-
|
|
|
|
|
|
|
| 16 |
|
| 17 |
-
#
|
| 18 |
-
|
|
|
|
| 19 |
|
| 20 |
-
#
|
| 21 |
-
|
| 22 |
-
|
| 23 |
|
| 24 |
-
#
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Complete Training Dependencies for OpenLLM Space - Updated for Gradio 4.44.1
|
| 2 |
+
# This file includes all necessary packages for real model training
|
| 3 |
|
| 4 |
+
# Core Machine Learning Framework
|
| 5 |
+
torch>=2.0.0 # PyTorch deep learning framework
|
| 6 |
+
torchvision>=0.15.0 # Computer vision utilities
|
| 7 |
+
torchaudio>=2.0.0 # Audio processing utilities
|
| 8 |
|
| 9 |
+
# Hugging Face Ecosystem - Complete Training Stack
|
| 10 |
+
transformers>=4.30.0 # Pre-trained models and training utilities
|
| 11 |
+
datasets>=2.12.0 # Dataset loading and processing
|
| 12 |
+
tokenizers>=0.13.0 # Fast tokenization library
|
| 13 |
+
sentencepiece>=0.1.99 # SentencePiece tokenization (CRITICAL for OpenLLM models)
|
| 14 |
+
huggingface_hub>=0.34.0 # Hugging Face Hub integration
|
| 15 |
+
accelerate>=0.20.0 # Distributed training acceleration
|
| 16 |
|
| 17 |
+
# User Interface Framework - Updated to 4.44.1
|
| 18 |
+
gradio==4.44.1 # Web UI framework for ML applications (fixed version)
|
|
|
|
| 19 |
|
| 20 |
+
# Data Processing and Scientific Computing
|
| 21 |
+
numpy>=1.24.0 # Numerical computing library
|
| 22 |
+
pandas>=2.0.0 # Data manipulation and analysis
|
| 23 |
+
scipy>=1.10.0 # Scientific computing utilities
|
| 24 |
|
| 25 |
+
# Progress and Monitoring
|
| 26 |
+
tqdm>=4.65.0 # Progress bars for long-running operations
|
| 27 |
+
psutil>=5.9.0 # System and process utilities
|
| 28 |
|
| 29 |
+
# Memory and Performance Optimization
|
| 30 |
+
bitsandbytes>=0.41.0 # Quantization utilities for memory efficiency
|
| 31 |
+
peft>=0.4.0 # Parameter-Efficient Fine-Tuning
|
| 32 |
|
| 33 |
+
# Logging and Debugging
|
| 34 |
+
wandb>=0.15.0 # Experiment tracking (optional)
|
| 35 |
+
tensorboard>=2.13.0 # Training visualization (optional)
|
| 36 |
+
|
| 37 |
+
# Additional Utilities
|
| 38 |
+
requests>=2.31.0 # HTTP library for API calls
|
| 39 |
+
pillow>=9.5.0 # Image processing (if needed)
|
| 40 |
+
matplotlib>=3.7.0 # Plotting and visualization
|
| 41 |
+
seaborn>=0.12.0 # Statistical data visualization
|
| 42 |
+
|
| 43 |
+
# Development and Testing (optional)
|
| 44 |
+
pytest>=7.4.0 # Testing framework
|
| 45 |
+
black>=23.0.0 # Code formatting
|
| 46 |
+
flake8>=6.0.0 # Code linting
|
| 47 |
+
|
| 48 |
+
# Note: These versions are compatible with Hugging Face Spaces
|
| 49 |
+
# and provide stable training performance for OpenLLM models
|
| 50 |
+
# Gradio 4.44.1 fixes compatibility issues with JSON components
|
| 51 |
+
# SentencePiece is CRITICAL for OpenLLM model tokenization
|
training/data_loader.py
ADDED
|
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright (C) 2024 Louis Chua Bean Chong
|
| 3 |
+
#
|
| 4 |
+
# This file is part of OpenLLM.
|
| 5 |
+
#
|
| 6 |
+
# OpenLLM is dual-licensed:
|
| 7 |
+
# 1. For open source use: GNU General Public License v3.0
|
| 8 |
+
# 2. For commercial use: Commercial License (contact for details)
|
| 9 |
+
#
|
| 10 |
+
# See LICENSE and docs/LICENSES.md for full license information.
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
Training Data Loader for Language Model Training
|
| 14 |
+
|
| 15 |
+
This module provides efficient data loading and batching for training GPT-style
|
| 16 |
+
language models. It handles text preprocessing, tokenization, and creates
|
| 17 |
+
batches suitable for autoregressive language modeling.
|
| 18 |
+
|
| 19 |
+
FEATURES:
|
| 20 |
+
- Memory-efficient text loading with sliding window
|
| 21 |
+
- Automatic tokenization using trained SentencePiece model
|
| 22 |
+
- Configurable sequence length and batch size
|
| 23 |
+
- CPU-optimized data loading for limited hardware
|
| 24 |
+
- Support for training data validation and statistics
|
| 25 |
+
|
| 26 |
+
MEMORY OPTIMIZATION:
|
| 27 |
+
- Streaming data loading (doesn't load entire dataset to memory)
|
| 28 |
+
- Configurable chunk sizes for large files
|
| 29 |
+
- Efficient tensor creation and batching
|
| 30 |
+
- Garbage collection hints for memory management
|
| 31 |
+
|
| 32 |
+
Usage:
|
| 33 |
+
from data_loader import TextDataLoader
|
| 34 |
+
|
| 35 |
+
loader = TextDataLoader(
|
| 36 |
+
data_file="data/clean/training_data.txt",
|
| 37 |
+
tokenizer_path="data/tokenizer/tokenizer.model",
|
| 38 |
+
seq_len=512,
|
| 39 |
+
batch_size=4
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
for batch in loader:
|
| 43 |
+
input_ids, targets = batch
|
| 44 |
+
# input_ids: (batch_size, seq_len)
|
| 45 |
+
# targets: (batch_size, seq_len) - shifted by 1 for next token prediction
|
| 46 |
+
|
| 47 |
+
Author: Louis Chua Bean Chong
|
| 48 |
+
License: GPLv3
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
import os
|
| 52 |
+
import gc
|
| 53 |
+
import random
|
| 54 |
+
import torch
|
| 55 |
+
import time
|
| 56 |
+
from typing import Iterator, Tuple, List, Optional
|
| 57 |
+
from pathlib import Path
|
| 58 |
+
|
| 59 |
+
try:
|
| 60 |
+
import sentencepiece as spm
|
| 61 |
+
except ImportError:
|
| 62 |
+
print("ERROR: SentencePiece not installed. Run: pip install sentencepiece")
|
| 63 |
+
exit(1)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class TextDataLoader:
|
| 67 |
+
"""
|
| 68 |
+
Efficient data loader for autoregressive language model training.
|
| 69 |
+
|
| 70 |
+
This class handles loading text data, tokenizing it using SentencePiece,
|
| 71 |
+
and creating batches suitable for next-token prediction training.
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
def __init__(
|
| 75 |
+
self,
|
| 76 |
+
data_file: str,
|
| 77 |
+
tokenizer_path: str,
|
| 78 |
+
seq_len: int = 512,
|
| 79 |
+
batch_size: int = 4,
|
| 80 |
+
chunk_size: int = 1000000, # Lines to read at once
|
| 81 |
+
shuffle: bool = True,
|
| 82 |
+
seed: int = 42
|
| 83 |
+
):
|
| 84 |
+
"""
|
| 85 |
+
Initialize the data loader.
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
data_file: Path to training text file (one passage per line)
|
| 89 |
+
tokenizer_path: Path to trained SentencePiece model
|
| 90 |
+
seq_len: Maximum sequence length for training
|
| 91 |
+
batch_size: Batch size for training
|
| 92 |
+
chunk_size: Number of lines to read in memory at once
|
| 93 |
+
shuffle: Whether to shuffle training examples
|
| 94 |
+
seed: Random seed for reproducibility
|
| 95 |
+
"""
|
| 96 |
+
self.data_file = data_file
|
| 97 |
+
self.tokenizer_path = tokenizer_path
|
| 98 |
+
self.seq_len = seq_len
|
| 99 |
+
self.batch_size = batch_size
|
| 100 |
+
self.chunk_size = chunk_size
|
| 101 |
+
self.shuffle = shuffle
|
| 102 |
+
self.seed = seed
|
| 103 |
+
|
| 104 |
+
# Validate inputs
|
| 105 |
+
self._validate_inputs()
|
| 106 |
+
|
| 107 |
+
# Load tokenizer
|
| 108 |
+
self.tokenizer = self._load_tokenizer()
|
| 109 |
+
|
| 110 |
+
# Get data statistics
|
| 111 |
+
self.total_lines = self._count_lines()
|
| 112 |
+
self.current_line = 0
|
| 113 |
+
|
| 114 |
+
# Set random seed for reproducibility
|
| 115 |
+
random.seed(seed)
|
| 116 |
+
|
| 117 |
+
print(f"π TextDataLoader initialized")
|
| 118 |
+
print(f" Data file: {data_file}")
|
| 119 |
+
print(f" Total passages: {self.total_lines:,}")
|
| 120 |
+
print(f" Sequence length: {seq_len}")
|
| 121 |
+
print(f" Batch size: {batch_size}")
|
| 122 |
+
print(f" Vocabulary size: {self.tokenizer.vocab_size():,}")
|
| 123 |
+
|
| 124 |
+
def _validate_inputs(self) -> None:
|
| 125 |
+
"""Validate input parameters and file paths."""
|
| 126 |
+
if not os.path.exists(self.data_file):
|
| 127 |
+
raise FileNotFoundError(f"Training data file not found: {self.data_file}")
|
| 128 |
+
|
| 129 |
+
if not os.path.exists(self.tokenizer_path):
|
| 130 |
+
raise FileNotFoundError(f"Tokenizer model not found: {self.tokenizer_path}")
|
| 131 |
+
|
| 132 |
+
if self.seq_len <= 0:
|
| 133 |
+
raise ValueError(f"Sequence length must be positive, got {self.seq_len}")
|
| 134 |
+
|
| 135 |
+
if self.batch_size <= 0:
|
| 136 |
+
raise ValueError(f"Batch size must be positive, got {self.batch_size}")
|
| 137 |
+
|
| 138 |
+
if self.chunk_size <= 0:
|
| 139 |
+
raise ValueError(f"Chunk size must be positive, got {self.chunk_size}")
|
| 140 |
+
|
| 141 |
+
def _load_tokenizer(self) -> spm.SentencePieceProcessor:
|
| 142 |
+
"""Load the trained SentencePiece tokenizer."""
|
| 143 |
+
try:
|
| 144 |
+
tokenizer = spm.SentencePieceProcessor()
|
| 145 |
+
tokenizer.load(self.tokenizer_path)
|
| 146 |
+
return tokenizer
|
| 147 |
+
except Exception as e:
|
| 148 |
+
raise RuntimeError(f"Failed to load tokenizer: {e}")
|
| 149 |
+
|
| 150 |
+
def _count_lines(self) -> int:
|
| 151 |
+
"""Count total number of lines in the data file."""
|
| 152 |
+
print("π Counting training passages...")
|
| 153 |
+
start_time = time.time()
|
| 154 |
+
|
| 155 |
+
line_count = 0
|
| 156 |
+
with open(self.data_file, 'r', encoding='utf-8') as f:
|
| 157 |
+
for line in f:
|
| 158 |
+
if line.strip(): # Only count non-empty lines
|
| 159 |
+
line_count += 1
|
| 160 |
+
|
| 161 |
+
count_time = time.time() - start_time
|
| 162 |
+
print(f"β Found {line_count:,} passages in {count_time:.1f}s")
|
| 163 |
+
|
| 164 |
+
return line_count
|
| 165 |
+
|
| 166 |
+
def _read_chunk(self, start_line: int = 0) -> List[str]:
|
| 167 |
+
"""
|
| 168 |
+
Read a chunk of lines from the data file.
|
| 169 |
+
|
| 170 |
+
Args:
|
| 171 |
+
start_line: Line number to start reading from
|
| 172 |
+
|
| 173 |
+
Returns:
|
| 174 |
+
List of text passages
|
| 175 |
+
"""
|
| 176 |
+
chunk = []
|
| 177 |
+
current_line = 0
|
| 178 |
+
lines_read = 0
|
| 179 |
+
|
| 180 |
+
with open(self.data_file, 'r', encoding='utf-8') as f:
|
| 181 |
+
for line in f:
|
| 182 |
+
if current_line < start_line:
|
| 183 |
+
current_line += 1
|
| 184 |
+
continue
|
| 185 |
+
|
| 186 |
+
text = line.strip()
|
| 187 |
+
if text: # Only include non-empty lines
|
| 188 |
+
chunk.append(text)
|
| 189 |
+
lines_read += 1
|
| 190 |
+
|
| 191 |
+
if lines_read >= self.chunk_size:
|
| 192 |
+
break
|
| 193 |
+
|
| 194 |
+
current_line += 1
|
| 195 |
+
|
| 196 |
+
return chunk
|
| 197 |
+
|
| 198 |
+
def _tokenize_texts(self, texts: List[str]) -> List[List[int]]:
|
| 199 |
+
"""
|
| 200 |
+
Tokenize a list of text passages using SentencePiece tokenizer.
|
| 201 |
+
|
| 202 |
+
This method converts raw text into token ID sequences suitable for language model training.
|
| 203 |
+
It handles special tokens (BOS/EOS) and length constraints for efficient training.
|
| 204 |
+
|
| 205 |
+
Text processing pipeline:
|
| 206 |
+
1. Add BOS (Beginning of Sequence) token to mark sequence start
|
| 207 |
+
2. Tokenize text using trained SentencePiece model (subword tokenization)
|
| 208 |
+
3. Truncate sequences that exceed maximum length
|
| 209 |
+
4. Add EOS (End of Sequence) token to mark sequence end
|
| 210 |
+
|
| 211 |
+
Special token handling:
|
| 212 |
+
- BOS token helps model learn to generate text from scratch
|
| 213 |
+
- EOS token signals natural sequence endings
|
| 214 |
+
- These tokens are crucial for proper autoregressive generation
|
| 215 |
+
|
| 216 |
+
Args:
|
| 217 |
+
texts: List of text passages (typically Wikipedia passages from SQUAD)
|
| 218 |
+
Each passage should be a complete, coherent text segment
|
| 219 |
+
|
| 220 |
+
Returns:
|
| 221 |
+
List of token ID sequences, where each sequence is a list of integers
|
| 222 |
+
representing subword tokens from the SentencePiece vocabulary
|
| 223 |
+
"""
|
| 224 |
+
tokenized = []
|
| 225 |
+
|
| 226 |
+
for text in texts:
|
| 227 |
+
try:
|
| 228 |
+
# Add BOS (Beginning of Sequence) token at the start
|
| 229 |
+
# BOS token ID=2 by default in SentencePiece, signals sequence start
|
| 230 |
+
# This helps the model learn proper sequence initialization during generation
|
| 231 |
+
tokens = [self.tokenizer.bos_id()] + self.tokenizer.encode(text)
|
| 232 |
+
|
| 233 |
+
# Truncate sequences that exceed maximum context length
|
| 234 |
+
# Reserve one position for EOS token by using (seq_len - 1)
|
| 235 |
+
# This ensures we never exceed the model's context window during training
|
| 236 |
+
if len(tokens) > self.seq_len - 1:
|
| 237 |
+
tokens = tokens[:self.seq_len - 1]
|
| 238 |
+
# NOTE: Truncation may cut off text mid-sentence, but this is acceptable
|
| 239 |
+
# for language modeling where the model learns from partial contexts
|
| 240 |
+
|
| 241 |
+
# Add EOS (End of Sequence) token at the end
|
| 242 |
+
# EOS token ID=1 by default in SentencePiece, signals sequence completion
|
| 243 |
+
# This teaches the model when to stop generating text naturally
|
| 244 |
+
tokens.append(self.tokenizer.eos_id())
|
| 245 |
+
|
| 246 |
+
# Validate tokenization result
|
| 247 |
+
if len(tokens) <= 2: # Only BOS + EOS tokens, no actual content
|
| 248 |
+
print(f"β οΈ Skipping very short text: {text[:50]}...")
|
| 249 |
+
continue
|
| 250 |
+
|
| 251 |
+
tokenized.append(tokens)
|
| 252 |
+
|
| 253 |
+
except Exception as e:
|
| 254 |
+
# Handle tokenization errors gracefully to avoid stopping training
|
| 255 |
+
# Common causes: encoding issues, very long texts, special characters
|
| 256 |
+
print(f"β οΈ Failed to tokenize passage: {text[:50]}... Error: {e}")
|
| 257 |
+
continue
|
| 258 |
+
|
| 259 |
+
# Log tokenization statistics for monitoring
|
| 260 |
+
if tokenized:
|
| 261 |
+
avg_length = sum(len(tokens) for tokens in tokenized) / len(tokenized)
|
| 262 |
+
print(f"π Tokenized {len(tokenized)} passages, avg length: {avg_length:.1f} tokens")
|
| 263 |
+
|
| 264 |
+
return tokenized
|
| 265 |
+
|
| 266 |
+
def _create_training_examples(self, token_sequences: List[List[int]]) -> List[Tuple[List[int], List[int]]]:
|
| 267 |
+
"""
|
| 268 |
+
Create training examples with input and target sequences.
|
| 269 |
+
|
| 270 |
+
For autoregressive training, targets are inputs shifted by one position.
|
| 271 |
+
|
| 272 |
+
Args:
|
| 273 |
+
token_sequences: List of tokenized sequences
|
| 274 |
+
|
| 275 |
+
Returns:
|
| 276 |
+
List of (input_ids, target_ids) tuples
|
| 277 |
+
"""
|
| 278 |
+
examples = []
|
| 279 |
+
|
| 280 |
+
for tokens in token_sequences:
|
| 281 |
+
if len(tokens) < 2: # Need at least 2 tokens for input/target pair
|
| 282 |
+
continue
|
| 283 |
+
|
| 284 |
+
# For sequences longer than seq_len, create multiple examples with sliding window
|
| 285 |
+
if len(tokens) > self.seq_len:
|
| 286 |
+
# Create overlapping windows (50% overlap for better learning)
|
| 287 |
+
stride = self.seq_len // 2
|
| 288 |
+
for i in range(0, len(tokens) - self.seq_len, stride):
|
| 289 |
+
input_ids = tokens[i:i + self.seq_len]
|
| 290 |
+
target_ids = tokens[i + 1:i + self.seq_len + 1]
|
| 291 |
+
examples.append((input_ids, target_ids))
|
| 292 |
+
else:
|
| 293 |
+
# Pad shorter sequences
|
| 294 |
+
input_ids = tokens[:-1] # All but last token
|
| 295 |
+
target_ids = tokens[1:] # All but first token
|
| 296 |
+
|
| 297 |
+
# Pad to seq_len if necessary
|
| 298 |
+
while len(input_ids) < self.seq_len:
|
| 299 |
+
input_ids.append(self.tokenizer.pad_id())
|
| 300 |
+
target_ids.append(-1) # Use -1 for padding in targets (ignored in loss)
|
| 301 |
+
|
| 302 |
+
# Truncate if still too long
|
| 303 |
+
input_ids = input_ids[:self.seq_len]
|
| 304 |
+
target_ids = target_ids[:self.seq_len]
|
| 305 |
+
|
| 306 |
+
examples.append((input_ids, target_ids))
|
| 307 |
+
|
| 308 |
+
return examples
|
| 309 |
+
|
| 310 |
+
def _create_batch(self, examples: List[Tuple[List[int], List[int]]]) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 311 |
+
"""
|
| 312 |
+
Create a batch tensor from training examples.
|
| 313 |
+
|
| 314 |
+
Args:
|
| 315 |
+
examples: List of (input_ids, target_ids) tuples
|
| 316 |
+
|
| 317 |
+
Returns:
|
| 318 |
+
Tuple of (input_tensor, target_tensor)
|
| 319 |
+
"""
|
| 320 |
+
if not examples:
|
| 321 |
+
raise ValueError("Cannot create batch from empty examples")
|
| 322 |
+
|
| 323 |
+
batch_size = len(examples)
|
| 324 |
+
|
| 325 |
+
# Initialize tensors
|
| 326 |
+
input_ids = torch.zeros((batch_size, self.seq_len), dtype=torch.long)
|
| 327 |
+
target_ids = torch.full((batch_size, self.seq_len), -1, dtype=torch.long)
|
| 328 |
+
|
| 329 |
+
# Fill tensors
|
| 330 |
+
for i, (inp, tgt) in enumerate(examples):
|
| 331 |
+
input_ids[i, :len(inp)] = torch.tensor(inp, dtype=torch.long)
|
| 332 |
+
target_ids[i, :len(tgt)] = torch.tensor(tgt, dtype=torch.long)
|
| 333 |
+
|
| 334 |
+
return input_ids, target_ids
|
| 335 |
+
|
| 336 |
+
def __iter__(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
|
| 337 |
+
"""
|
| 338 |
+
Iterate over training batches.
|
| 339 |
+
|
| 340 |
+
Yields:
|
| 341 |
+
Tuple of (input_ids, target_ids) tensors
|
| 342 |
+
"""
|
| 343 |
+
self.current_line = 0
|
| 344 |
+
|
| 345 |
+
while self.current_line < self.total_lines:
|
| 346 |
+
# Read chunk of text
|
| 347 |
+
texts = self._read_chunk(self.current_line)
|
| 348 |
+
if not texts:
|
| 349 |
+
break
|
| 350 |
+
|
| 351 |
+
# Tokenize texts
|
| 352 |
+
token_sequences = self._tokenize_texts(texts)
|
| 353 |
+
|
| 354 |
+
# Create training examples
|
| 355 |
+
examples = self._create_training_examples(token_sequences)
|
| 356 |
+
|
| 357 |
+
# Shuffle examples if requested
|
| 358 |
+
if self.shuffle:
|
| 359 |
+
random.shuffle(examples)
|
| 360 |
+
|
| 361 |
+
# Create batches
|
| 362 |
+
for i in range(0, len(examples), self.batch_size):
|
| 363 |
+
batch_examples = examples[i:i + self.batch_size]
|
| 364 |
+
|
| 365 |
+
if len(batch_examples) == self.batch_size: # Only yield full batches
|
| 366 |
+
try:
|
| 367 |
+
input_ids, target_ids = self._create_batch(batch_examples)
|
| 368 |
+
yield input_ids, target_ids
|
| 369 |
+
except Exception as e:
|
| 370 |
+
print(f"β οΈ Failed to create batch: {e}")
|
| 371 |
+
continue
|
| 372 |
+
|
| 373 |
+
# Update progress
|
| 374 |
+
self.current_line += len(texts)
|
| 375 |
+
|
| 376 |
+
# Clean up memory
|
| 377 |
+
del texts, token_sequences, examples
|
| 378 |
+
gc.collect()
|
| 379 |
+
|
| 380 |
+
def get_data_stats(self) -> dict:
|
| 381 |
+
"""
|
| 382 |
+
Get statistics about the training data.
|
| 383 |
+
|
| 384 |
+
Returns:
|
| 385 |
+
Dictionary with data statistics
|
| 386 |
+
"""
|
| 387 |
+
print("π Analyzing training data...")
|
| 388 |
+
|
| 389 |
+
# Sample some data to get statistics
|
| 390 |
+
sample_texts = self._read_chunk(0)[:100] # Sample first 100 passages
|
| 391 |
+
token_sequences = self._tokenize_texts(sample_texts)
|
| 392 |
+
|
| 393 |
+
if token_sequences:
|
| 394 |
+
sequence_lengths = [len(seq) for seq in token_sequences]
|
| 395 |
+
avg_length = sum(sequence_lengths) / len(sequence_lengths)
|
| 396 |
+
max_length = max(sequence_lengths)
|
| 397 |
+
min_length = min(sequence_lengths)
|
| 398 |
+
else:
|
| 399 |
+
avg_length = max_length = min_length = 0
|
| 400 |
+
|
| 401 |
+
# Estimate total tokens
|
| 402 |
+
estimated_total_tokens = int(avg_length * self.total_lines)
|
| 403 |
+
|
| 404 |
+
# Estimate number of batches per epoch
|
| 405 |
+
examples_per_passage = max(1, avg_length // self.seq_len)
|
| 406 |
+
total_examples = int(self.total_lines * examples_per_passage)
|
| 407 |
+
batches_per_epoch = total_examples // self.batch_size
|
| 408 |
+
|
| 409 |
+
stats = {
|
| 410 |
+
"total_passages": self.total_lines,
|
| 411 |
+
"avg_tokens_per_passage": avg_length,
|
| 412 |
+
"min_tokens_per_passage": min_length,
|
| 413 |
+
"max_tokens_per_passage": max_length,
|
| 414 |
+
"estimated_total_tokens": estimated_total_tokens,
|
| 415 |
+
"estimated_examples_per_epoch": total_examples,
|
| 416 |
+
"estimated_batches_per_epoch": batches_per_epoch,
|
| 417 |
+
"sequence_length": self.seq_len,
|
| 418 |
+
"batch_size": self.batch_size,
|
| 419 |
+
"vocabulary_size": self.tokenizer.vocab_size()
|
| 420 |
+
}
|
| 421 |
+
|
| 422 |
+
print(f"β Data analysis complete:")
|
| 423 |
+
print(f" Total passages: {stats['total_passages']:,}")
|
| 424 |
+
print(f" Avg tokens per passage: {stats['avg_tokens_per_passage']:.1f}")
|
| 425 |
+
print(f" Estimated total tokens: {stats['estimated_total_tokens']:,}")
|
| 426 |
+
print(f" Estimated batches per epoch: {stats['estimated_batches_per_epoch']:,}")
|
| 427 |
+
|
| 428 |
+
return stats
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
def test_data_loader():
|
| 432 |
+
"""Test function for the data loader."""
|
| 433 |
+
print("π§ͺ Testing TextDataLoader...")
|
| 434 |
+
|
| 435 |
+
# Test with small parameters
|
| 436 |
+
try:
|
| 437 |
+
loader = TextDataLoader(
|
| 438 |
+
data_file="data/clean/training_data.txt",
|
| 439 |
+
tokenizer_path="data/tokenizer/tokenizer.model",
|
| 440 |
+
seq_len=128,
|
| 441 |
+
batch_size=2,
|
| 442 |
+
chunk_size=10 # Small for testing
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
# Get data statistics
|
| 446 |
+
stats = loader.get_data_stats()
|
| 447 |
+
|
| 448 |
+
# Test iteration
|
| 449 |
+
print("\nπ Testing batch iteration...")
|
| 450 |
+
start_time = time.time()
|
| 451 |
+
batch_count = 0
|
| 452 |
+
|
| 453 |
+
for batch_idx, (input_ids, target_ids) in enumerate(loader):
|
| 454 |
+
batch_count += 1
|
| 455 |
+
|
| 456 |
+
print(f"Batch {batch_idx + 1}:")
|
| 457 |
+
print(f" Input shape: {input_ids.shape}")
|
| 458 |
+
print(f" Target shape: {target_ids.shape}")
|
| 459 |
+
print(f" Sample input tokens: {input_ids[0][:10].tolist()}")
|
| 460 |
+
print(f" Sample target tokens: {target_ids[0][:10].tolist()}")
|
| 461 |
+
|
| 462 |
+
if batch_idx >= 2: # Only test first few batches
|
| 463 |
+
break
|
| 464 |
+
|
| 465 |
+
test_time = time.time() - start_time
|
| 466 |
+
print(f"\nβ Data loader test completed successfully!")
|
| 467 |
+
print(f" Processed {batch_count} batches in {test_time:.2f}s")
|
| 468 |
+
print(f" Average time per batch: {test_time/max(1, batch_count):.2f}s")
|
| 469 |
+
|
| 470 |
+
return True
|
| 471 |
+
|
| 472 |
+
except Exception as e:
|
| 473 |
+
print(f"β Data loader test failed: {e}")
|
| 474 |
+
import traceback
|
| 475 |
+
traceback.print_exc()
|
| 476 |
+
return False
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
if __name__ == "__main__":
|
| 480 |
+
test_data_loader()
|
training/evaluate_model.py
ADDED
|
@@ -0,0 +1,782 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright (C) 2024 Louis Chua Bean Chong
|
| 3 |
+
#
|
| 4 |
+
# This file is part of OpenLLM.
|
| 5 |
+
#
|
| 6 |
+
# OpenLLM is dual-licensed:
|
| 7 |
+
# 1. For open source use: GNU General Public License v3.0
|
| 8 |
+
# 2. For commercial use: Commercial License (contact for details)
|
| 9 |
+
#
|
| 10 |
+
# See LICENSE and docs/LICENSES.md for full license information.
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
OpenLLM Model Evaluation Script
|
| 14 |
+
|
| 15 |
+
This script implements comprehensive evaluation for trained OpenLLM models,
|
| 16 |
+
including intrinsic evaluation (perplexity, loss) and text generation quality
|
| 17 |
+
assessment as specified in Step 5 of the training pipeline.
|
| 18 |
+
|
| 19 |
+
Usage:
|
| 20 |
+
python core/src/evaluate_model.py \
|
| 21 |
+
--model_dir models/openllm-medium \
|
| 22 |
+
--eval_data data/clean/validation_data.txt \
|
| 23 |
+
--metrics perplexity,loss
|
| 24 |
+
|
| 25 |
+
Features:
|
| 26 |
+
- Perplexity calculation on held-out data
|
| 27 |
+
- Text generation quality assessment
|
| 28 |
+
- Multiple evaluation metrics
|
| 29 |
+
- Comprehensive quality benchmarks
|
| 30 |
+
- JSON output for downstream analysis
|
| 31 |
+
|
| 32 |
+
Author: Louis Chua Bean Chong
|
| 33 |
+
License: GPLv3
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
import argparse
|
| 37 |
+
import json
|
| 38 |
+
import os
|
| 39 |
+
import sys
|
| 40 |
+
import time
|
| 41 |
+
import math
|
| 42 |
+
from pathlib import Path
|
| 43 |
+
from typing import Dict, List, Optional, Tuple, Any
|
| 44 |
+
|
| 45 |
+
import torch
|
| 46 |
+
import torch.nn.functional as F
|
| 47 |
+
import sentencepiece as smp
|
| 48 |
+
|
| 49 |
+
# Add current directory to path for imports
|
| 50 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 51 |
+
|
| 52 |
+
from model import GPTModel, create_model
|
| 53 |
+
from data_loader import TextDataLoader
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class ModelEvaluator:
|
| 57 |
+
"""
|
| 58 |
+
Comprehensive evaluator for OpenLLM models.
|
| 59 |
+
|
| 60 |
+
Implements intrinsic evaluation metrics and text generation quality
|
| 61 |
+
assessment following the training pipeline specifications.
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
def __init__(
|
| 65 |
+
self,
|
| 66 |
+
model: GPTModel,
|
| 67 |
+
tokenizer_path: str,
|
| 68 |
+
device: str = "cpu"
|
| 69 |
+
):
|
| 70 |
+
"""
|
| 71 |
+
Initialize the model evaluator.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
model: Trained GPT model
|
| 75 |
+
tokenizer_path: Path to tokenizer model file
|
| 76 |
+
device: Device to run evaluation on
|
| 77 |
+
"""
|
| 78 |
+
self.model = model.to(device)
|
| 79 |
+
self.device = device
|
| 80 |
+
|
| 81 |
+
# Load tokenizer
|
| 82 |
+
self.tokenizer = smp.SentencePieceProcessor()
|
| 83 |
+
self.tokenizer.load(tokenizer_path)
|
| 84 |
+
|
| 85 |
+
print(f"π§ ModelEvaluator initialized")
|
| 86 |
+
print(f" Device: {device}")
|
| 87 |
+
print(f" Model parameters: {model.get_num_params():,}")
|
| 88 |
+
print(f" Vocabulary size: {self.tokenizer.vocab_size():,}")
|
| 89 |
+
|
| 90 |
+
def evaluate_perplexity(
|
| 91 |
+
self,
|
| 92 |
+
eval_data: List[str],
|
| 93 |
+
max_seq_len: int = 512,
|
| 94 |
+
batch_size: int = 1
|
| 95 |
+
) -> Dict[str, float]:
|
| 96 |
+
"""
|
| 97 |
+
Calculate perplexity on evaluation data.
|
| 98 |
+
|
| 99 |
+
Args:
|
| 100 |
+
eval_data: List of text passages for evaluation
|
| 101 |
+
max_seq_len: Maximum sequence length for evaluation
|
| 102 |
+
batch_size: Batch size for evaluation
|
| 103 |
+
|
| 104 |
+
Returns:
|
| 105 |
+
Dictionary with loss and perplexity metrics
|
| 106 |
+
"""
|
| 107 |
+
self.model.eval()
|
| 108 |
+
total_loss = 0.0
|
| 109 |
+
total_tokens = 0
|
| 110 |
+
num_sequences = 0
|
| 111 |
+
|
| 112 |
+
print(f"π Calculating perplexity on {len(eval_data)} passages...")
|
| 113 |
+
|
| 114 |
+
with torch.no_grad():
|
| 115 |
+
for i, text in enumerate(eval_data):
|
| 116 |
+
if i % 100 == 0:
|
| 117 |
+
print(f" Progress: {i}/{len(eval_data)} passages")
|
| 118 |
+
|
| 119 |
+
# Tokenize text
|
| 120 |
+
tokens = self.tokenizer.encode(text)
|
| 121 |
+
if len(tokens) < 2:
|
| 122 |
+
continue
|
| 123 |
+
|
| 124 |
+
# Truncate if too long
|
| 125 |
+
if len(tokens) > max_seq_len:
|
| 126 |
+
tokens = tokens[:max_seq_len]
|
| 127 |
+
|
| 128 |
+
# Create input and target tensors
|
| 129 |
+
input_ids = torch.tensor([tokens[:-1]], dtype=torch.long, device=self.device)
|
| 130 |
+
target_ids = torch.tensor([tokens[1:]], dtype=torch.long, device=self.device)
|
| 131 |
+
|
| 132 |
+
# Forward pass
|
| 133 |
+
logits, loss = self.model(input_ids, target_ids)
|
| 134 |
+
|
| 135 |
+
# Accumulate loss
|
| 136 |
+
seq_length = len(tokens) - 1
|
| 137 |
+
total_loss += loss.item() * seq_length
|
| 138 |
+
total_tokens += seq_length
|
| 139 |
+
num_sequences += 1
|
| 140 |
+
|
| 141 |
+
# Calculate metrics
|
| 142 |
+
avg_loss = total_loss / total_tokens if total_tokens > 0 else float('inf')
|
| 143 |
+
perplexity = math.exp(min(avg_loss, 10)) # Cap to prevent overflow
|
| 144 |
+
|
| 145 |
+
return {
|
| 146 |
+
'loss': avg_loss,
|
| 147 |
+
'perplexity': perplexity,
|
| 148 |
+
'total_tokens': total_tokens,
|
| 149 |
+
'num_sequences': num_sequences
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
def evaluate_text_generation(
|
| 153 |
+
self,
|
| 154 |
+
prompts: List[str],
|
| 155 |
+
max_length: int = 256,
|
| 156 |
+
temperature: float = 0.7,
|
| 157 |
+
top_k: Optional[int] = 40,
|
| 158 |
+
num_samples: int = 1
|
| 159 |
+
) -> List[Dict[str, Any]]:
|
| 160 |
+
"""
|
| 161 |
+
Evaluate text generation quality.
|
| 162 |
+
|
| 163 |
+
Args:
|
| 164 |
+
prompts: List of input prompts
|
| 165 |
+
max_length: Maximum generation length
|
| 166 |
+
temperature: Sampling temperature
|
| 167 |
+
top_k: Top-k sampling parameter
|
| 168 |
+
num_samples: Number of samples per prompt
|
| 169 |
+
|
| 170 |
+
Returns:
|
| 171 |
+
List of generation results with quality metrics
|
| 172 |
+
"""
|
| 173 |
+
self.model.eval()
|
| 174 |
+
results = []
|
| 175 |
+
|
| 176 |
+
print(f"βοΈ Evaluating text generation on {len(prompts)} prompts...")
|
| 177 |
+
|
| 178 |
+
with torch.no_grad():
|
| 179 |
+
for prompt in prompts:
|
| 180 |
+
prompt_results = []
|
| 181 |
+
|
| 182 |
+
for sample_idx in range(num_samples):
|
| 183 |
+
# Tokenize prompt
|
| 184 |
+
input_ids = self.tokenizer.encode(prompt)
|
| 185 |
+
input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
|
| 186 |
+
|
| 187 |
+
start_time = time.time()
|
| 188 |
+
|
| 189 |
+
# Generate text
|
| 190 |
+
output = self.model.generate(
|
| 191 |
+
input_tensor,
|
| 192 |
+
max_new_tokens=max_length,
|
| 193 |
+
temperature=temperature,
|
| 194 |
+
top_k=top_k
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
generation_time = time.time() - start_time
|
| 198 |
+
|
| 199 |
+
# Decode output
|
| 200 |
+
generated_ids = output[0].tolist()
|
| 201 |
+
full_text = self.tokenizer.decode(generated_ids)
|
| 202 |
+
generated_text = self.tokenizer.decode(generated_ids[len(input_ids):])
|
| 203 |
+
|
| 204 |
+
# Calculate quality metrics
|
| 205 |
+
quality_metrics = self._assess_generation_quality(generated_text)
|
| 206 |
+
|
| 207 |
+
prompt_results.append({
|
| 208 |
+
'prompt': prompt,
|
| 209 |
+
'generated_text': generated_text,
|
| 210 |
+
'full_text': full_text,
|
| 211 |
+
'generation_time': generation_time,
|
| 212 |
+
'tokens_generated': len(generated_ids) - len(input_ids),
|
| 213 |
+
'tokens_per_second': (len(generated_ids) - len(input_ids)) / generation_time,
|
| 214 |
+
'quality_metrics': quality_metrics
|
| 215 |
+
})
|
| 216 |
+
|
| 217 |
+
results.extend(prompt_results)
|
| 218 |
+
|
| 219 |
+
return results
|
| 220 |
+
|
| 221 |
+
def _assess_generation_quality(self, text: str) -> Dict[str, float]:
|
| 222 |
+
"""
|
| 223 |
+
Assess basic quality metrics for generated text.
|
| 224 |
+
|
| 225 |
+
Args:
|
| 226 |
+
text: Generated text to assess
|
| 227 |
+
|
| 228 |
+
Returns:
|
| 229 |
+
Dictionary of quality metrics
|
| 230 |
+
"""
|
| 231 |
+
if not text.strip():
|
| 232 |
+
return {
|
| 233 |
+
'length': 0,
|
| 234 |
+
'avg_word_length': 0,
|
| 235 |
+
'repetition_rate': 1.0,
|
| 236 |
+
'coherence_score': 0.0
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
words = text.split()
|
| 240 |
+
|
| 241 |
+
# Basic metrics
|
| 242 |
+
length = len(words)
|
| 243 |
+
avg_word_length = sum(len(word) for word in words) / len(words) if words else 0
|
| 244 |
+
|
| 245 |
+
# Repetition rate (simple n-gram repetition)
|
| 246 |
+
bigrams = [f"{words[i]} {words[i+1]}" for i in range(len(words)-1)]
|
| 247 |
+
unique_bigrams = len(set(bigrams))
|
| 248 |
+
repetition_rate = 1 - (unique_bigrams / len(bigrams) if bigrams else 0)
|
| 249 |
+
|
| 250 |
+
# Simple coherence score (based on sentence structure)
|
| 251 |
+
sentences = text.split('.')
|
| 252 |
+
valid_sentences = [s for s in sentences if len(s.strip().split()) > 3]
|
| 253 |
+
coherence_score = len(valid_sentences) / len(sentences) if sentences else 0
|
| 254 |
+
|
| 255 |
+
return {
|
| 256 |
+
'length': length,
|
| 257 |
+
'avg_word_length': avg_word_length,
|
| 258 |
+
'repetition_rate': repetition_rate,
|
| 259 |
+
'coherence_score': coherence_score
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
def evaluate_downstream_tasks(self) -> Dict[str, Any]:
|
| 263 |
+
"""
|
| 264 |
+
Evaluate model performance on downstream tasks.
|
| 265 |
+
|
| 266 |
+
This function implements basic downstream task evaluation including:
|
| 267 |
+
- Reading comprehension (simplified SQUAD-style)
|
| 268 |
+
- Sentiment analysis (few-shot)
|
| 269 |
+
- Common sense reasoning
|
| 270 |
+
|
| 271 |
+
Returns:
|
| 272 |
+
Dictionary of downstream task results
|
| 273 |
+
"""
|
| 274 |
+
results = {}
|
| 275 |
+
|
| 276 |
+
# 1. Reading Comprehension (Simplified SQUAD-style)
|
| 277 |
+
results['reading_comprehension'] = self._evaluate_reading_comprehension()
|
| 278 |
+
|
| 279 |
+
# 2. Sentiment Analysis (Few-shot learning)
|
| 280 |
+
results['sentiment_analysis'] = self._evaluate_sentiment_analysis()
|
| 281 |
+
|
| 282 |
+
# 3. Common Sense Reasoning
|
| 283 |
+
results['reasoning'] = self._evaluate_reasoning()
|
| 284 |
+
|
| 285 |
+
# 4. Text Completion Quality
|
| 286 |
+
results['text_completion'] = self._evaluate_text_completion()
|
| 287 |
+
|
| 288 |
+
return results
|
| 289 |
+
|
| 290 |
+
def _evaluate_reading_comprehension(self) -> Dict[str, Any]:
|
| 291 |
+
"""Simplified reading comprehension evaluation."""
|
| 292 |
+
# Sample reading comprehension tasks
|
| 293 |
+
tasks = [
|
| 294 |
+
{
|
| 295 |
+
'context': 'The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, France. It is named after the engineer Gustave Eiffel, whose company designed and built the tower.',
|
| 296 |
+
'question': 'Who is the Eiffel Tower named after?',
|
| 297 |
+
'expected': 'Gustave Eiffel'
|
| 298 |
+
},
|
| 299 |
+
{
|
| 300 |
+
'context': 'Python is a high-level programming language. It was created by Guido van Rossum and first released in 1991.',
|
| 301 |
+
'question': 'When was Python first released?',
|
| 302 |
+
'expected': '1991'
|
| 303 |
+
},
|
| 304 |
+
{
|
| 305 |
+
'context': 'Machine learning is a subset of artificial intelligence that enables computers to learn without being explicitly programmed.',
|
| 306 |
+
'question': 'What is machine learning a subset of?',
|
| 307 |
+
'expected': 'artificial intelligence'
|
| 308 |
+
}
|
| 309 |
+
]
|
| 310 |
+
|
| 311 |
+
correct = 0
|
| 312 |
+
total = len(tasks)
|
| 313 |
+
|
| 314 |
+
for task in tasks:
|
| 315 |
+
prompt = f"Context: {task['context']}\nQuestion: {task['question']}\nAnswer:"
|
| 316 |
+
|
| 317 |
+
# Generate answer
|
| 318 |
+
input_ids = self.tokenizer.encode(prompt)
|
| 319 |
+
input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
|
| 320 |
+
|
| 321 |
+
with torch.no_grad():
|
| 322 |
+
output = self.model.generate(input_tensor, max_new_tokens=20, temperature=0.1)
|
| 323 |
+
|
| 324 |
+
generated_ids = output[0].tolist()
|
| 325 |
+
answer = self.tokenizer.decode(generated_ids[len(input_ids):]).strip().lower()
|
| 326 |
+
|
| 327 |
+
# Simple substring matching
|
| 328 |
+
if task['expected'].lower() in answer:
|
| 329 |
+
correct += 1
|
| 330 |
+
|
| 331 |
+
return {
|
| 332 |
+
'accuracy': correct / total,
|
| 333 |
+
'correct': correct,
|
| 334 |
+
'total': total,
|
| 335 |
+
'score': correct / total
|
| 336 |
+
}
|
| 337 |
+
|
| 338 |
+
def _evaluate_sentiment_analysis(self) -> Dict[str, Any]:
|
| 339 |
+
"""Few-shot sentiment analysis evaluation."""
|
| 340 |
+
# Few-shot examples
|
| 341 |
+
examples = "Examples:\nText: 'I love this movie!' Sentiment: Positive\nText: 'This is terrible.' Sentiment: Negative\nText: 'It was okay.' Sentiment: Neutral\n\n"
|
| 342 |
+
|
| 343 |
+
# Test cases
|
| 344 |
+
test_cases = [
|
| 345 |
+
{'text': 'This is amazing!', 'expected': 'positive'},
|
| 346 |
+
{'text': 'I hate this.', 'expected': 'negative'},
|
| 347 |
+
{'text': 'This is wonderful.', 'expected': 'positive'},
|
| 348 |
+
{'text': 'This is awful.', 'expected': 'negative'},
|
| 349 |
+
{'text': 'It was fine.', 'expected': 'neutral'}
|
| 350 |
+
]
|
| 351 |
+
|
| 352 |
+
correct = 0
|
| 353 |
+
total = len(test_cases)
|
| 354 |
+
|
| 355 |
+
for case in test_cases:
|
| 356 |
+
prompt = f"{examples}Text: '{case['text']}' Sentiment:"
|
| 357 |
+
|
| 358 |
+
# Generate sentiment
|
| 359 |
+
input_ids = self.tokenizer.encode(prompt)
|
| 360 |
+
input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
|
| 361 |
+
|
| 362 |
+
with torch.no_grad():
|
| 363 |
+
output = self.model.generate(input_tensor, max_new_tokens=5, temperature=0.1)
|
| 364 |
+
|
| 365 |
+
generated_ids = output[0].tolist()
|
| 366 |
+
sentiment = self.tokenizer.decode(generated_ids[len(input_ids):]).strip().lower()
|
| 367 |
+
|
| 368 |
+
# Check if expected sentiment is in the generated response
|
| 369 |
+
if case['expected'] in sentiment:
|
| 370 |
+
correct += 1
|
| 371 |
+
|
| 372 |
+
return {
|
| 373 |
+
'accuracy': correct / total,
|
| 374 |
+
'correct': correct,
|
| 375 |
+
'total': total,
|
| 376 |
+
'score': correct / total
|
| 377 |
+
}
|
| 378 |
+
|
| 379 |
+
def _evaluate_reasoning(self) -> Dict[str, Any]:
|
| 380 |
+
"""Simple reasoning evaluation."""
|
| 381 |
+
# Basic reasoning tasks
|
| 382 |
+
tasks = [
|
| 383 |
+
{
|
| 384 |
+
'question': 'If all birds can fly and a penguin is a bird, can a penguin fly?',
|
| 385 |
+
'expected': 'no' # This tests if model knows real-world facts
|
| 386 |
+
},
|
| 387 |
+
{
|
| 388 |
+
'question': 'If it is raining outside, should you take an umbrella?',
|
| 389 |
+
'expected': 'yes'
|
| 390 |
+
},
|
| 391 |
+
{
|
| 392 |
+
'question': 'What comes after Monday?',
|
| 393 |
+
'expected': 'tuesday'
|
| 394 |
+
},
|
| 395 |
+
{
|
| 396 |
+
'question': 'Is the sun larger than the earth?',
|
| 397 |
+
'expected': 'yes'
|
| 398 |
+
}
|
| 399 |
+
]
|
| 400 |
+
|
| 401 |
+
correct = 0
|
| 402 |
+
total = len(tasks)
|
| 403 |
+
|
| 404 |
+
for task in tasks:
|
| 405 |
+
prompt = f"Question: {task['question']}\nAnswer:"
|
| 406 |
+
|
| 407 |
+
# Generate answer
|
| 408 |
+
input_ids = self.tokenizer.encode(prompt)
|
| 409 |
+
input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
|
| 410 |
+
|
| 411 |
+
with torch.no_grad():
|
| 412 |
+
output = self.model.generate(input_tensor, max_new_tokens=10, temperature=0.1)
|
| 413 |
+
|
| 414 |
+
generated_ids = output[0].tolist()
|
| 415 |
+
answer = self.tokenizer.decode(generated_ids[len(input_ids):]).strip().lower()
|
| 416 |
+
|
| 417 |
+
# Check if expected answer is in the response
|
| 418 |
+
if task['expected'] in answer:
|
| 419 |
+
correct += 1
|
| 420 |
+
|
| 421 |
+
return {
|
| 422 |
+
'accuracy': correct / total,
|
| 423 |
+
'correct': correct,
|
| 424 |
+
'total': total,
|
| 425 |
+
'score': correct / total
|
| 426 |
+
}
|
| 427 |
+
|
| 428 |
+
def _evaluate_text_completion(self) -> Dict[str, Any]:
|
| 429 |
+
"""Evaluate text completion quality."""
|
| 430 |
+
# Common phrases that should be completed predictably
|
| 431 |
+
completions = [
|
| 432 |
+
{'prompt': 'The capital of France is', 'expected_word': 'paris'},
|
| 433 |
+
{'prompt': 'Two plus two equals', 'expected_word': 'four'},
|
| 434 |
+
{'prompt': 'The largest planet in our solar system is', 'expected_word': 'jupiter'},
|
| 435 |
+
{'prompt': 'Water boils at', 'expected_word': '100'}
|
| 436 |
+
]
|
| 437 |
+
|
| 438 |
+
correct = 0
|
| 439 |
+
total = len(completions)
|
| 440 |
+
|
| 441 |
+
for completion in completions:
|
| 442 |
+
# Generate completion
|
| 443 |
+
input_ids = self.tokenizer.encode(completion['prompt'])
|
| 444 |
+
input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
|
| 445 |
+
|
| 446 |
+
with torch.no_grad():
|
| 447 |
+
output = self.model.generate(input_tensor, max_new_tokens=5, temperature=0.1)
|
| 448 |
+
|
| 449 |
+
generated_ids = output[0].tolist()
|
| 450 |
+
generated_text = self.tokenizer.decode(generated_ids[len(input_ids):]).strip().lower()
|
| 451 |
+
|
| 452 |
+
# Check if expected word appears in completion
|
| 453 |
+
if completion['expected_word'] in generated_text:
|
| 454 |
+
correct += 1
|
| 455 |
+
|
| 456 |
+
return {
|
| 457 |
+
'accuracy': correct / total,
|
| 458 |
+
'correct': correct,
|
| 459 |
+
'total': total,
|
| 460 |
+
'score': correct / total
|
| 461 |
+
}
|
| 462 |
+
|
| 463 |
+
def run_comprehensive_evaluation(
|
| 464 |
+
self,
|
| 465 |
+
eval_data_path: str,
|
| 466 |
+
metrics: List[str] = None,
|
| 467 |
+
generation_prompts: List[str] = None
|
| 468 |
+
) -> Dict[str, Any]:
|
| 469 |
+
"""
|
| 470 |
+
Run comprehensive model evaluation.
|
| 471 |
+
|
| 472 |
+
Args:
|
| 473 |
+
eval_data_path: Path to evaluation text file
|
| 474 |
+
metrics: List of metrics to compute
|
| 475 |
+
generation_prompts: Prompts for text generation evaluation
|
| 476 |
+
|
| 477 |
+
Returns:
|
| 478 |
+
Complete evaluation results
|
| 479 |
+
"""
|
| 480 |
+
if metrics is None:
|
| 481 |
+
metrics = ['perplexity', 'loss', 'generation']
|
| 482 |
+
|
| 483 |
+
if generation_prompts is None:
|
| 484 |
+
generation_prompts = [
|
| 485 |
+
"The history of artificial intelligence",
|
| 486 |
+
"Machine learning algorithms",
|
| 487 |
+
"The future of technology",
|
| 488 |
+
"In a world where",
|
| 489 |
+
"Scientists have discovered"
|
| 490 |
+
]
|
| 491 |
+
|
| 492 |
+
results = {
|
| 493 |
+
'model_info': {
|
| 494 |
+
'parameters': self.model.get_num_params(),
|
| 495 |
+
'device': self.device,
|
| 496 |
+
'vocab_size': self.tokenizer.vocab_size()
|
| 497 |
+
},
|
| 498 |
+
'evaluation_timestamp': time.time()
|
| 499 |
+
}
|
| 500 |
+
|
| 501 |
+
# Load evaluation data
|
| 502 |
+
print(f"π Loading evaluation data from {eval_data_path}")
|
| 503 |
+
if os.path.exists(eval_data_path):
|
| 504 |
+
with open(eval_data_path, 'r', encoding='utf-8') as f:
|
| 505 |
+
eval_texts = [line.strip() for line in f if line.strip()]
|
| 506 |
+
else:
|
| 507 |
+
print(f"β οΈ Evaluation file not found, using sample texts")
|
| 508 |
+
eval_texts = [
|
| 509 |
+
"Artificial intelligence is a rapidly growing field of computer science.",
|
| 510 |
+
"Machine learning algorithms can learn patterns from data automatically.",
|
| 511 |
+
"Natural language processing helps computers understand human language.",
|
| 512 |
+
"Deep learning uses neural networks with multiple layers for complex tasks.",
|
| 513 |
+
"The development of large language models has transformed AI applications."
|
| 514 |
+
]
|
| 515 |
+
|
| 516 |
+
# Intrinsic evaluation
|
| 517 |
+
if 'perplexity' in metrics or 'loss' in metrics:
|
| 518 |
+
perplexity_results = self.evaluate_perplexity(eval_texts)
|
| 519 |
+
results['intrinsic_evaluation'] = perplexity_results
|
| 520 |
+
|
| 521 |
+
# Text generation evaluation
|
| 522 |
+
if 'generation' in metrics:
|
| 523 |
+
generation_results = self.evaluate_text_generation(generation_prompts)
|
| 524 |
+
results['generation_evaluation'] = {
|
| 525 |
+
'results': generation_results,
|
| 526 |
+
'summary': self._summarize_generation_results(generation_results)
|
| 527 |
+
}
|
| 528 |
+
|
| 529 |
+
# Downstream tasks (placeholder)
|
| 530 |
+
results['downstream_evaluation'] = self.evaluate_downstream_tasks()
|
| 531 |
+
|
| 532 |
+
# Overall quality assessment
|
| 533 |
+
results['quality_assessment'] = self._assess_overall_quality(results)
|
| 534 |
+
|
| 535 |
+
return results
|
| 536 |
+
|
| 537 |
+
def _summarize_generation_results(self, results: List[Dict[str, Any]]) -> Dict[str, float]:
|
| 538 |
+
"""Summarize text generation results."""
|
| 539 |
+
if not results:
|
| 540 |
+
return {}
|
| 541 |
+
|
| 542 |
+
total_time = sum(r['generation_time'] for r in results)
|
| 543 |
+
total_tokens = sum(r['tokens_generated'] for r in results)
|
| 544 |
+
|
| 545 |
+
quality_metrics = [r['quality_metrics'] for r in results]
|
| 546 |
+
|
| 547 |
+
return {
|
| 548 |
+
'avg_generation_time': total_time / len(results),
|
| 549 |
+
'avg_tokens_per_second': total_tokens / total_time if total_time > 0 else 0,
|
| 550 |
+
'avg_length': sum(q['length'] for q in quality_metrics) / len(quality_metrics),
|
| 551 |
+
'avg_repetition_rate': sum(q['repetition_rate'] for q in quality_metrics) / len(quality_metrics),
|
| 552 |
+
'avg_coherence_score': sum(q['coherence_score'] for q in quality_metrics) / len(quality_metrics)
|
| 553 |
+
}
|
| 554 |
+
|
| 555 |
+
def _assess_overall_quality(self, results: Dict[str, Any]) -> Dict[str, Any]:
|
| 556 |
+
"""Assess overall model quality based on evaluation results."""
|
| 557 |
+
assessment = {
|
| 558 |
+
'quality_level': 'unknown',
|
| 559 |
+
'recommendations': []
|
| 560 |
+
}
|
| 561 |
+
|
| 562 |
+
# Check intrinsic metrics
|
| 563 |
+
if 'intrinsic_evaluation' in results:
|
| 564 |
+
perplexity = results['intrinsic_evaluation'].get('perplexity', float('inf'))
|
| 565 |
+
|
| 566 |
+
if perplexity < 12:
|
| 567 |
+
assessment['quality_level'] = 'good'
|
| 568 |
+
assessment['recommendations'].append('Model shows good perplexity scores')
|
| 569 |
+
elif perplexity < 50:
|
| 570 |
+
assessment['quality_level'] = 'fair'
|
| 571 |
+
assessment['recommendations'].append('Model shows fair performance, could benefit from more training')
|
| 572 |
+
else:
|
| 573 |
+
assessment['quality_level'] = 'poor'
|
| 574 |
+
assessment['recommendations'].append('Model needs significant more training or data improvements')
|
| 575 |
+
|
| 576 |
+
# Check generation quality
|
| 577 |
+
if 'generation_evaluation' in results:
|
| 578 |
+
summary = results['generation_evaluation'].get('summary', {})
|
| 579 |
+
repetition_rate = summary.get('avg_repetition_rate', 1.0)
|
| 580 |
+
coherence_score = summary.get('avg_coherence_score', 0.0)
|
| 581 |
+
|
| 582 |
+
if repetition_rate > 0.7:
|
| 583 |
+
assessment['recommendations'].append('High repetition rate - consider training longer or adjusting data')
|
| 584 |
+
if coherence_score < 0.3:
|
| 585 |
+
assessment['recommendations'].append('Low coherence - model may need more training steps')
|
| 586 |
+
|
| 587 |
+
return assessment
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
def load_model_from_directory(model_dir: str, device: str = "cpu") -> Tuple[GPTModel, str]:
|
| 591 |
+
"""
|
| 592 |
+
Load model from directory containing checkpoints.
|
| 593 |
+
|
| 594 |
+
Args:
|
| 595 |
+
model_dir: Directory containing model files
|
| 596 |
+
device: Device to load model on
|
| 597 |
+
|
| 598 |
+
Returns:
|
| 599 |
+
Tuple of (model, tokenizer_path)
|
| 600 |
+
"""
|
| 601 |
+
model_dir = Path(model_dir)
|
| 602 |
+
|
| 603 |
+
# Find best model checkpoint
|
| 604 |
+
best_model_path = model_dir / "best_model.pt"
|
| 605 |
+
if not best_model_path.exists():
|
| 606 |
+
# Look for latest checkpoint
|
| 607 |
+
checkpoints = list(model_dir.glob("checkpoint_step_*.pt"))
|
| 608 |
+
if not checkpoints:
|
| 609 |
+
raise FileNotFoundError(f"No model checkpoints found in {model_dir}")
|
| 610 |
+
|
| 611 |
+
# Get latest checkpoint
|
| 612 |
+
latest_checkpoint = max(checkpoints, key=lambda p: int(p.stem.split('_')[-1]))
|
| 613 |
+
best_model_path = latest_checkpoint
|
| 614 |
+
|
| 615 |
+
print(f"π Loading model from {best_model_path}")
|
| 616 |
+
|
| 617 |
+
# Load checkpoint
|
| 618 |
+
checkpoint = torch.load(best_model_path, map_location=device)
|
| 619 |
+
|
| 620 |
+
# Determine model size from config
|
| 621 |
+
config = checkpoint.get('config', {})
|
| 622 |
+
n_layer = config.get('n_layer', 12)
|
| 623 |
+
|
| 624 |
+
if n_layer <= 6:
|
| 625 |
+
model_size = "small"
|
| 626 |
+
elif n_layer <= 12:
|
| 627 |
+
model_size = "medium"
|
| 628 |
+
else:
|
| 629 |
+
model_size = "large"
|
| 630 |
+
|
| 631 |
+
# Create and load model
|
| 632 |
+
model = create_model(model_size)
|
| 633 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
| 634 |
+
|
| 635 |
+
print(f"β
Model loaded successfully ({model_size}, {model.get_num_params():,} parameters)")
|
| 636 |
+
|
| 637 |
+
# Find tokenizer
|
| 638 |
+
tokenizer_path = model_dir.parent / "tokenizer" / "tokenizer.model"
|
| 639 |
+
if not tokenizer_path.exists():
|
| 640 |
+
tokenizer_path = Path("data/tokenizer/tokenizer.model")
|
| 641 |
+
|
| 642 |
+
if not tokenizer_path.exists():
|
| 643 |
+
raise FileNotFoundError(f"Tokenizer not found at {tokenizer_path}")
|
| 644 |
+
|
| 645 |
+
return model, str(tokenizer_path)
|
| 646 |
+
|
| 647 |
+
|
| 648 |
+
def main():
|
| 649 |
+
"""Main evaluation function."""
|
| 650 |
+
parser = argparse.ArgumentParser(
|
| 651 |
+
description="Evaluate OpenLLM model performance",
|
| 652 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 653 |
+
epilog="""
|
| 654 |
+
Examples:
|
| 655 |
+
# Basic evaluation
|
| 656 |
+
python core/src/evaluate_model.py \\
|
| 657 |
+
--model_dir models/small-extended-4k \\
|
| 658 |
+
--eval_data data/clean/training_data.txt
|
| 659 |
+
|
| 660 |
+
# Specific metrics
|
| 661 |
+
python core/src/evaluate_model.py \\
|
| 662 |
+
--model_dir models/small-extended-4k \\
|
| 663 |
+
--metrics perplexity,generation \\
|
| 664 |
+
--output results.json
|
| 665 |
+
"""
|
| 666 |
+
)
|
| 667 |
+
|
| 668 |
+
parser.add_argument(
|
| 669 |
+
"--model_dir",
|
| 670 |
+
required=True,
|
| 671 |
+
help="Directory containing trained model"
|
| 672 |
+
)
|
| 673 |
+
|
| 674 |
+
parser.add_argument(
|
| 675 |
+
"--eval_data",
|
| 676 |
+
help="Path to evaluation text file (default: use sample texts)"
|
| 677 |
+
)
|
| 678 |
+
|
| 679 |
+
parser.add_argument(
|
| 680 |
+
"--metrics",
|
| 681 |
+
default="perplexity,loss,generation",
|
| 682 |
+
help="Comma-separated list of metrics to evaluate (default: perplexity,loss,generation)"
|
| 683 |
+
)
|
| 684 |
+
|
| 685 |
+
parser.add_argument(
|
| 686 |
+
"--output",
|
| 687 |
+
help="Output JSON file for results (default: print to console)"
|
| 688 |
+
)
|
| 689 |
+
|
| 690 |
+
parser.add_argument(
|
| 691 |
+
"--device",
|
| 692 |
+
choices=["cpu", "cuda", "auto"],
|
| 693 |
+
default="auto",
|
| 694 |
+
help="Device for evaluation (default: auto)"
|
| 695 |
+
)
|
| 696 |
+
|
| 697 |
+
parser.add_argument(
|
| 698 |
+
"--generation_prompts",
|
| 699 |
+
help="File containing prompts for text generation evaluation"
|
| 700 |
+
)
|
| 701 |
+
|
| 702 |
+
args = parser.parse_args()
|
| 703 |
+
|
| 704 |
+
print("π OpenLLM Model Evaluation")
|
| 705 |
+
print("=" * 50)
|
| 706 |
+
|
| 707 |
+
# Determine device
|
| 708 |
+
if args.device == "auto":
|
| 709 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 710 |
+
else:
|
| 711 |
+
device = args.device
|
| 712 |
+
|
| 713 |
+
print(f"Using device: {device}")
|
| 714 |
+
|
| 715 |
+
try:
|
| 716 |
+
# Load model
|
| 717 |
+
model, tokenizer_path = load_model_from_directory(args.model_dir, device)
|
| 718 |
+
|
| 719 |
+
# Create evaluator
|
| 720 |
+
evaluator = ModelEvaluator(model, tokenizer_path, device)
|
| 721 |
+
|
| 722 |
+
# Parse metrics
|
| 723 |
+
metrics = [m.strip() for m in args.metrics.split(',')]
|
| 724 |
+
|
| 725 |
+
# Load generation prompts if specified
|
| 726 |
+
generation_prompts = None
|
| 727 |
+
if args.generation_prompts and os.path.exists(args.generation_prompts):
|
| 728 |
+
with open(args.generation_prompts, 'r', encoding='utf-8') as f:
|
| 729 |
+
generation_prompts = [line.strip() for line in f if line.strip()]
|
| 730 |
+
|
| 731 |
+
# Run evaluation
|
| 732 |
+
eval_data_path = args.eval_data or "data/clean/training_data.txt"
|
| 733 |
+
results = evaluator.run_comprehensive_evaluation(
|
| 734 |
+
eval_data_path, metrics, generation_prompts
|
| 735 |
+
)
|
| 736 |
+
|
| 737 |
+
# Output results
|
| 738 |
+
if args.output:
|
| 739 |
+
with open(args.output, 'w', encoding='utf-8') as f:
|
| 740 |
+
json.dump(results, f, indent=2)
|
| 741 |
+
print(f"\nπΎ Results saved to {args.output}")
|
| 742 |
+
else:
|
| 743 |
+
print(f"\nπ Evaluation Results:")
|
| 744 |
+
print("=" * 50)
|
| 745 |
+
|
| 746 |
+
# Print key metrics
|
| 747 |
+
if 'intrinsic_evaluation' in results:
|
| 748 |
+
intrinsic = results['intrinsic_evaluation']
|
| 749 |
+
print(f"π Intrinsic Metrics:")
|
| 750 |
+
print(f" Loss: {intrinsic['loss']:.4f}")
|
| 751 |
+
print(f" Perplexity: {intrinsic['perplexity']:.2f}")
|
| 752 |
+
print(f" Sequences evaluated: {intrinsic['num_sequences']:,}")
|
| 753 |
+
|
| 754 |
+
if 'generation_evaluation' in results:
|
| 755 |
+
gen_summary = results['generation_evaluation']['summary']
|
| 756 |
+
print(f"\nβοΈ Generation Quality:")
|
| 757 |
+
print(f" Avg generation speed: {gen_summary['avg_tokens_per_second']:.1f} tokens/sec")
|
| 758 |
+
print(f" Avg text length: {gen_summary['avg_length']:.1f} words")
|
| 759 |
+
print(f" Repetition rate: {gen_summary['avg_repetition_rate']:.3f}")
|
| 760 |
+
print(f" Coherence score: {gen_summary['avg_coherence_score']:.3f}")
|
| 761 |
+
|
| 762 |
+
# Quality assessment
|
| 763 |
+
if 'quality_assessment' in results:
|
| 764 |
+
assessment = results['quality_assessment']
|
| 765 |
+
print(f"\nπ― Overall Assessment:")
|
| 766 |
+
print(f" Quality Level: {assessment['quality_level'].upper()}")
|
| 767 |
+
for rec in assessment['recommendations']:
|
| 768 |
+
print(f" β’ {rec}")
|
| 769 |
+
|
| 770 |
+
print(f"\nπ Evaluation completed successfully!")
|
| 771 |
+
|
| 772 |
+
except Exception as e:
|
| 773 |
+
print(f"\nβ Evaluation failed: {e}")
|
| 774 |
+
import traceback
|
| 775 |
+
traceback.print_exc()
|
| 776 |
+
return False
|
| 777 |
+
|
| 778 |
+
return True
|
| 779 |
+
|
| 780 |
+
|
| 781 |
+
if __name__ == "__main__":
|
| 782 |
+
main()
|
training/model.py
ADDED
|
@@ -0,0 +1,641 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright (C) 2024 Louis Chua Bean Chong
|
| 3 |
+
#
|
| 4 |
+
# This file is part of OpenLLM.
|
| 5 |
+
#
|
| 6 |
+
# OpenLLM is dual-licensed:
|
| 7 |
+
# 1. For open source use: GNU General Public License v3.0
|
| 8 |
+
# 2. For commercial use: Commercial License (contact for details)
|
| 9 |
+
#
|
| 10 |
+
# See LICENSE and docs/LICENSES.md for full license information.
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
GPT-style Language Model Architecture
|
| 14 |
+
|
| 15 |
+
This module implements a standard GPT (Generative Pre-trained Transformer) architecture
|
| 16 |
+
using pure PyTorch. The model is a decoder-only transformer designed for autoregressive
|
| 17 |
+
language modeling (next-token prediction).
|
| 18 |
+
|
| 19 |
+
ARCHITECTURE OVERVIEW:
|
| 20 |
+
- Token Embedding: Maps token IDs to dense vectors
|
| 21 |
+
- Positional Embedding: Adds position information to token embeddings
|
| 22 |
+
- Transformer Blocks: Stack of multi-head attention + feed-forward layers
|
| 23 |
+
- Layer Normalization: Pre-norm placement for training stability
|
| 24 |
+
- Output Head: Linear projection to vocabulary for next-token prediction
|
| 25 |
+
|
| 26 |
+
FEATURES:
|
| 27 |
+
- Configurable model size (small/medium/large)
|
| 28 |
+
- Dropout for regularization
|
| 29 |
+
- Causal (autoregressive) attention masking
|
| 30 |
+
- Compatible with our SentencePiece tokenizer
|
| 31 |
+
- Memory-efficient implementation for training on limited hardware
|
| 32 |
+
|
| 33 |
+
Usage:
|
| 34 |
+
from model import GPTConfig, GPTModel
|
| 35 |
+
|
| 36 |
+
config = GPTConfig(vocab_size=32000, n_layer=12, n_head=12, n_embd=768)
|
| 37 |
+
model = GPTModel(config)
|
| 38 |
+
|
| 39 |
+
# Forward pass
|
| 40 |
+
logits = model(input_ids) # Shape: (batch_size, seq_len, vocab_size)
|
| 41 |
+
|
| 42 |
+
Hardware Requirements:
|
| 43 |
+
- Small Model (25M params): 4-8GB RAM, CPU/integrated GPU
|
| 44 |
+
- Medium Model (117M params): 8-16GB RAM, dedicated GPU recommended
|
| 45 |
+
- Large Model (350M params): 16GB+ RAM, high-end GPU required
|
| 46 |
+
|
| 47 |
+
Author: Louis Chua Bean Chong
|
| 48 |
+
License: GPLv3
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
import math
|
| 52 |
+
import torch
|
| 53 |
+
import torch.nn as nn
|
| 54 |
+
import torch.nn.functional as F
|
| 55 |
+
from dataclasses import dataclass
|
| 56 |
+
from typing import Optional, Tuple
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@dataclass
|
| 60 |
+
class GPTConfig:
|
| 61 |
+
"""
|
| 62 |
+
Configuration class for GPT model hyperparameters.
|
| 63 |
+
|
| 64 |
+
This class defines all the architectural parameters needed to instantiate
|
| 65 |
+
a GPT model. Use the provided class methods to get pre-configured setups
|
| 66 |
+
for different model sizes.
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
# Model architecture
|
| 70 |
+
vocab_size: int = 32000 # Vocabulary size (from tokenizer)
|
| 71 |
+
n_layer: int = 12 # Number of transformer layers
|
| 72 |
+
n_head: int = 12 # Number of attention heads
|
| 73 |
+
n_embd: int = 768 # Embedding dimension
|
| 74 |
+
|
| 75 |
+
# Sequence and context
|
| 76 |
+
block_size: int = 1024 # Maximum sequence length
|
| 77 |
+
|
| 78 |
+
# Training hyperparameters
|
| 79 |
+
dropout: float = 0.1 # Dropout probability
|
| 80 |
+
bias: bool = True # Use bias in linear layers
|
| 81 |
+
|
| 82 |
+
# Model size identifier
|
| 83 |
+
model_name: str = "gpt-medium" # Human-readable model identifier
|
| 84 |
+
|
| 85 |
+
@classmethod
|
| 86 |
+
def small(cls) -> 'GPTConfig':
|
| 87 |
+
"""Small model configuration (~25M parameters) - Good for CPU training"""
|
| 88 |
+
return cls(
|
| 89 |
+
vocab_size=32000,
|
| 90 |
+
n_layer=6,
|
| 91 |
+
n_head=8,
|
| 92 |
+
n_embd=512,
|
| 93 |
+
block_size=1024,
|
| 94 |
+
dropout=0.1,
|
| 95 |
+
model_name="gpt-small"
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
@classmethod
|
| 99 |
+
def medium(cls) -> 'GPTConfig':
|
| 100 |
+
"""Medium model configuration (~117M parameters) - Balanced performance"""
|
| 101 |
+
return cls(
|
| 102 |
+
vocab_size=32000,
|
| 103 |
+
n_layer=12,
|
| 104 |
+
n_head=12,
|
| 105 |
+
n_embd=768,
|
| 106 |
+
block_size=2048,
|
| 107 |
+
dropout=0.1,
|
| 108 |
+
model_name="gpt-medium"
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
@classmethod
|
| 112 |
+
def large(cls) -> 'GPTConfig':
|
| 113 |
+
"""Large model configuration (~350M parameters) - High performance"""
|
| 114 |
+
return cls(
|
| 115 |
+
vocab_size=32000,
|
| 116 |
+
n_layer=24,
|
| 117 |
+
n_head=16,
|
| 118 |
+
n_embd=1024,
|
| 119 |
+
block_size=2048,
|
| 120 |
+
dropout=0.1,
|
| 121 |
+
model_name="gpt-large"
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
def estimate_parameters(self) -> int:
|
| 125 |
+
"""
|
| 126 |
+
Estimate the total number of trainable parameters.
|
| 127 |
+
|
| 128 |
+
Returns:
|
| 129 |
+
int: Estimated parameter count
|
| 130 |
+
"""
|
| 131 |
+
# Token embeddings
|
| 132 |
+
token_emb = self.vocab_size * self.n_embd
|
| 133 |
+
|
| 134 |
+
# Position embeddings
|
| 135 |
+
pos_emb = self.block_size * self.n_embd
|
| 136 |
+
|
| 137 |
+
# Transformer layers
|
| 138 |
+
# Each layer: attention (4 * n_embd^2) + mlp (8 * n_embd^2) + layer_norms
|
| 139 |
+
layer_params = self.n_layer * (12 * self.n_embd**2 + 4 * self.n_embd)
|
| 140 |
+
|
| 141 |
+
# Output head
|
| 142 |
+
output_head = self.vocab_size * self.n_embd
|
| 143 |
+
|
| 144 |
+
total = token_emb + pos_emb + layer_params + output_head
|
| 145 |
+
return total
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class CausalSelfAttention(nn.Module):
|
| 149 |
+
"""
|
| 150 |
+
Multi-head causal self-attention mechanism.
|
| 151 |
+
|
| 152 |
+
This implements the core attention mechanism of the transformer, with causal
|
| 153 |
+
masking to ensure autoregressive behavior (tokens can only attend to previous
|
| 154 |
+
tokens, not future ones).
|
| 155 |
+
"""
|
| 156 |
+
|
| 157 |
+
def __init__(self, config: GPTConfig):
|
| 158 |
+
super().__init__()
|
| 159 |
+
assert config.n_embd % config.n_head == 0, "Embedding dim must be divisible by number of heads"
|
| 160 |
+
|
| 161 |
+
self.config = config
|
| 162 |
+
self.n_head = config.n_head
|
| 163 |
+
self.n_embd = config.n_embd
|
| 164 |
+
self.head_dim = self.n_embd // self.n_head
|
| 165 |
+
|
| 166 |
+
# Key, query, value projections for all heads (batched)
|
| 167 |
+
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
|
| 168 |
+
|
| 169 |
+
# Output projection
|
| 170 |
+
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
|
| 171 |
+
|
| 172 |
+
# Dropout
|
| 173 |
+
self.attn_dropout = nn.Dropout(config.dropout)
|
| 174 |
+
self.resid_dropout = nn.Dropout(config.dropout)
|
| 175 |
+
|
| 176 |
+
# Causal mask - lower triangular matrix
|
| 177 |
+
self.register_buffer(
|
| 178 |
+
"bias",
|
| 179 |
+
torch.tril(torch.ones(config.block_size, config.block_size))
|
| 180 |
+
.view(1, 1, config.block_size, config.block_size)
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 184 |
+
"""
|
| 185 |
+
Forward pass of causal self-attention.
|
| 186 |
+
|
| 187 |
+
This method implements the scaled dot-product attention mechanism with causal masking.
|
| 188 |
+
The attention mechanism allows each token to attend to all previous tokens in the sequence,
|
| 189 |
+
but not to future tokens, maintaining the autoregressive property essential for language modeling.
|
| 190 |
+
|
| 191 |
+
Mathematical formulation:
|
| 192 |
+
Attention(Q, K, V) = softmax(QK^T / sqrt(d_k))V
|
| 193 |
+
where Q, K, V are query, key, value matrices derived from input x
|
| 194 |
+
|
| 195 |
+
Implementation details:
|
| 196 |
+
- Uses batch matrix multiplication for efficiency
|
| 197 |
+
- Applies causal mask to prevent future token attention
|
| 198 |
+
- Implements multi-head attention by reshaping and parallel processing
|
| 199 |
+
- Applies dropout for regularization during training
|
| 200 |
+
|
| 201 |
+
Args:
|
| 202 |
+
x: Input tensor of shape (batch_size, seq_len, n_embd)
|
| 203 |
+
Contains embedded token representations from previous layer
|
| 204 |
+
|
| 205 |
+
Returns:
|
| 206 |
+
torch.Tensor: Output tensor of shape (batch_size, seq_len, n_embd)
|
| 207 |
+
"""
|
| 208 |
+
# Extract tensor dimensions for clear variable naming and validation
|
| 209 |
+
# B = batch size (number of sequences processed in parallel)
|
| 210 |
+
# T = sequence length (number of tokens in each sequence)
|
| 211 |
+
# C = embedding dimensionality (n_embd from config)
|
| 212 |
+
B, T, C = x.size()
|
| 213 |
+
|
| 214 |
+
# Generate query, key, and value projections for all attention heads
|
| 215 |
+
# The c_attn linear layer outputs 3 * n_embd features, which we split into Q, K, V
|
| 216 |
+
# This batched approach is more efficient than separate linear layers
|
| 217 |
+
# Input shape: (B, T, C) -> Output shape: (B, T, 3*C) -> Split to 3x (B, T, C)
|
| 218 |
+
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
|
| 219 |
+
|
| 220 |
+
# Reshape tensors for multi-head attention computation
|
| 221 |
+
# Transform from (B, T, C) to (B, nh, T, hs) where:
|
| 222 |
+
# - nh = number of heads (self.n_head)
|
| 223 |
+
# - hs = head size (self.head_dim = C // nh)
|
| 224 |
+
# The transpose(1, 2) moves the head dimension before sequence dimension for efficient computation
|
| 225 |
+
q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2) # (B, nh, T, hs)
|
| 226 |
+
k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2) # (B, nh, T, hs)
|
| 227 |
+
v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2) # (B, nh, T, hs)
|
| 228 |
+
|
| 229 |
+
# Compute scaled dot-product attention scores
|
| 230 |
+
# Matrix multiplication: Q @ K^T gives attention affinities between all token pairs
|
| 231 |
+
# Scaling by 1/sqrt(head_dim) prevents softmax saturation for large embedding dimensions
|
| 232 |
+
# Shape: (B, nh, T, hs) @ (B, nh, hs, T) -> (B, nh, T, T)
|
| 233 |
+
# The resulting (T, T) matrix represents attention weights from each token to every other token
|
| 234 |
+
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
|
| 235 |
+
|
| 236 |
+
# Apply causal masking to enforce autoregressive property
|
| 237 |
+
# The causal mask ensures that token i can only attend to tokens j where j <= i
|
| 238 |
+
# This prevents the model from "cheating" by looking at future tokens during training
|
| 239 |
+
# We use -inf for masked positions so they become 0 after softmax
|
| 240 |
+
att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
|
| 241 |
+
|
| 242 |
+
# Convert attention scores to probabilities using softmax
|
| 243 |
+
# Each row of the attention matrix now sums to 1, representing a probability distribution
|
| 244 |
+
# over which tokens to attend to for each query position
|
| 245 |
+
att = F.softmax(att, dim=-1)
|
| 246 |
+
|
| 247 |
+
# Apply dropout to attention weights for regularization
|
| 248 |
+
# This randomly zeros some attention connections during training to prevent overfitting
|
| 249 |
+
att = self.attn_dropout(att)
|
| 250 |
+
|
| 251 |
+
# Apply attention weights to value vectors
|
| 252 |
+
# This weighted combination produces the actual output of the attention mechanism
|
| 253 |
+
# Shape: (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)
|
| 254 |
+
# Each output position is a weighted sum of all value vectors, with weights from attention
|
| 255 |
+
y = att @ v
|
| 256 |
+
|
| 257 |
+
# Concatenate multi-head outputs back to original embedding dimension
|
| 258 |
+
# Transform from (B, nh, T, hs) back to (B, T, C) where C = nh * hs
|
| 259 |
+
# The transpose moves head dimension back, and contiguous() ensures memory layout efficiency
|
| 260 |
+
# This combines information from all attention heads into a single representation
|
| 261 |
+
y = y.transpose(1, 2).contiguous().view(B, T, C)
|
| 262 |
+
|
| 263 |
+
# Apply final output projection and residual dropout
|
| 264 |
+
# The output projection allows the model to learn how to best combine multi-head information
|
| 265 |
+
# Residual dropout provides additional regularization before the residual connection
|
| 266 |
+
y = self.resid_dropout(self.c_proj(y))
|
| 267 |
+
return y
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
class MLP(nn.Module):
|
| 271 |
+
"""
|
| 272 |
+
Multi-Layer Perceptron (Feed-Forward Network) for Transformer.
|
| 273 |
+
|
| 274 |
+
This implements the position-wise feed-forward network that appears in each transformer layer.
|
| 275 |
+
The MLP provides additional non-linear transformation capacity beyond what attention provides.
|
| 276 |
+
|
| 277 |
+
Architecture:
|
| 278 |
+
Input -> Linear(n_embd -> 4*n_embd) -> GELU -> Linear(4*n_embd -> n_embd) -> Dropout -> Output
|
| 279 |
+
|
| 280 |
+
Design rationale:
|
| 281 |
+
- 4x expansion is standard in transformers (from "Attention Is All You Need")
|
| 282 |
+
- GELU activation provides smoother gradients than ReLU for language modeling
|
| 283 |
+
- Dropout prevents overfitting in the feed-forward layers
|
| 284 |
+
- Two linear layers allow complex non-linear transformations of attention outputs
|
| 285 |
+
|
| 286 |
+
Parameters:
|
| 287 |
+
- First linear layer: n_embd * 4*n_embd parameters (expansion)
|
| 288 |
+
- Second linear layer: 4*n_embd * n_embd parameters (projection back)
|
| 289 |
+
- Total: 8 * n_embd^2 parameters (significant portion of model size)
|
| 290 |
+
"""
|
| 291 |
+
|
| 292 |
+
def __init__(self, config: GPTConfig):
|
| 293 |
+
super().__init__()
|
| 294 |
+
|
| 295 |
+
# First linear layer: expand embedding dimension by 4x
|
| 296 |
+
# This expansion gives the network more representational capacity
|
| 297 |
+
# The 4x factor is a standard choice that balances capacity vs efficiency
|
| 298 |
+
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
|
| 299 |
+
|
| 300 |
+
# GELU (Gaussian Error Linear Unit) activation function
|
| 301 |
+
# GELU provides smoother gradients compared to ReLU and works better for language modeling
|
| 302 |
+
# It's approximately: GELU(x) = x * Ξ¦(x) where Ξ¦ is the CDF of standard normal distribution
|
| 303 |
+
self.gelu = nn.GELU()
|
| 304 |
+
|
| 305 |
+
# Second linear layer: project back to original embedding dimension
|
| 306 |
+
# This projection allows the network to combine information from the expanded representation
|
| 307 |
+
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
|
| 308 |
+
|
| 309 |
+
# Dropout for regularization in the feed-forward network
|
| 310 |
+
# Applied after the final projection to prevent overfitting
|
| 311 |
+
self.dropout = nn.Dropout(config.dropout)
|
| 312 |
+
|
| 313 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 314 |
+
"""
|
| 315 |
+
Forward pass of the feed-forward network.
|
| 316 |
+
|
| 317 |
+
This method applies a two-layer MLP with GELU activation to transform
|
| 318 |
+
the attention outputs. The MLP operates independently on each position
|
| 319 |
+
in the sequence, providing position-wise non-linear transformations.
|
| 320 |
+
|
| 321 |
+
Mathematical operation:
|
| 322 |
+
MLP(x) = Dropout(Linearβ(GELU(Linearβ(x))))
|
| 323 |
+
where Linearβ: R^n_embd -> R^4*n_embd and Linearβ: R^4*n_embd -> R^n_embd
|
| 324 |
+
|
| 325 |
+
Args:
|
| 326 |
+
x: Input tensor of shape (batch_size, seq_len, n_embd)
|
| 327 |
+
Contains attended representations from the attention layer
|
| 328 |
+
|
| 329 |
+
Returns:
|
| 330 |
+
torch.Tensor: Output tensor of shape (batch_size, seq_len, n_embd)
|
| 331 |
+
Contains transformed representations ready for residual connection
|
| 332 |
+
"""
|
| 333 |
+
# First linear transformation: expand from n_embd to 4*n_embd dimensions
|
| 334 |
+
# This expansion provides the network with a higher-dimensional space for computation
|
| 335 |
+
# Shape: (batch_size, seq_len, n_embd) -> (batch_size, seq_len, 4*n_embd)
|
| 336 |
+
x = self.c_fc(x)
|
| 337 |
+
|
| 338 |
+
# Apply GELU activation function for non-linearity
|
| 339 |
+
# GELU is smoother than ReLU and provides better gradients for language modeling
|
| 340 |
+
# It introduces non-linearity while maintaining differentiability everywhere
|
| 341 |
+
x = self.gelu(x)
|
| 342 |
+
|
| 343 |
+
# Second linear transformation: project back to original n_embd dimensions
|
| 344 |
+
# This projection combines information from the expanded representation
|
| 345 |
+
# Shape: (batch_size, seq_len, 4*n_embd) -> (batch_size, seq_len, n_embd)
|
| 346 |
+
x = self.c_proj(x)
|
| 347 |
+
|
| 348 |
+
# Apply dropout for regularization before residual connection
|
| 349 |
+
# Dropout randomly zeros some neurons during training to prevent overfitting
|
| 350 |
+
# This is particularly important in the feed-forward layers which have many parameters
|
| 351 |
+
x = self.dropout(x)
|
| 352 |
+
|
| 353 |
+
return x
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
class Block(nn.Module):
|
| 357 |
+
"""
|
| 358 |
+
Single Transformer block.
|
| 359 |
+
|
| 360 |
+
Consists of:
|
| 361 |
+
1. Layer normalization
|
| 362 |
+
2. Multi-head causal self-attention
|
| 363 |
+
3. Residual connection
|
| 364 |
+
4. Layer normalization
|
| 365 |
+
5. MLP (feed-forward network)
|
| 366 |
+
6. Residual connection
|
| 367 |
+
|
| 368 |
+
Uses pre-norm architecture for better training stability.
|
| 369 |
+
"""
|
| 370 |
+
|
| 371 |
+
def __init__(self, config: GPTConfig):
|
| 372 |
+
super().__init__()
|
| 373 |
+
self.ln_1 = nn.LayerNorm(config.n_embd)
|
| 374 |
+
self.attn = CausalSelfAttention(config)
|
| 375 |
+
self.ln_2 = nn.LayerNorm(config.n_embd)
|
| 376 |
+
self.mlp = MLP(config)
|
| 377 |
+
|
| 378 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 379 |
+
"""
|
| 380 |
+
Forward pass of transformer block.
|
| 381 |
+
|
| 382 |
+
Args:
|
| 383 |
+
x: Input tensor of shape (batch_size, seq_len, n_embd)
|
| 384 |
+
|
| 385 |
+
Returns:
|
| 386 |
+
torch.Tensor: Output tensor of shape (batch_size, seq_len, n_embd)
|
| 387 |
+
"""
|
| 388 |
+
# Pre-norm attention with residual connection
|
| 389 |
+
x = x + self.attn(self.ln_1(x))
|
| 390 |
+
|
| 391 |
+
# Pre-norm MLP with residual connection
|
| 392 |
+
x = x + self.mlp(self.ln_2(x))
|
| 393 |
+
|
| 394 |
+
return x
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
class GPTModel(nn.Module):
|
| 398 |
+
"""
|
| 399 |
+
Complete GPT Language Model.
|
| 400 |
+
|
| 401 |
+
This is the main model class that combines all components:
|
| 402 |
+
- Token and positional embeddings
|
| 403 |
+
- Stack of transformer blocks
|
| 404 |
+
- Final layer normalization
|
| 405 |
+
- Language modeling head
|
| 406 |
+
|
| 407 |
+
The model can be used for:
|
| 408 |
+
- Training from scratch on text data
|
| 409 |
+
- Fine-tuning on downstream tasks
|
| 410 |
+
- Text generation (inference)
|
| 411 |
+
"""
|
| 412 |
+
|
| 413 |
+
def __init__(self, config: GPTConfig):
|
| 414 |
+
super().__init__()
|
| 415 |
+
assert config.vocab_size is not None, "vocab_size must be specified"
|
| 416 |
+
assert config.block_size is not None, "block_size must be specified"
|
| 417 |
+
|
| 418 |
+
self.config = config
|
| 419 |
+
|
| 420 |
+
# Embeddings
|
| 421 |
+
self.transformer = nn.ModuleDict(dict(
|
| 422 |
+
wte = nn.Embedding(config.vocab_size, config.n_embd), # Token embeddings
|
| 423 |
+
wpe = nn.Embedding(config.block_size, config.n_embd), # Position embeddings
|
| 424 |
+
drop = nn.Dropout(config.dropout),
|
| 425 |
+
h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), # Transformer blocks
|
| 426 |
+
ln_f = nn.LayerNorm(config.n_embd), # Final layer norm
|
| 427 |
+
))
|
| 428 |
+
|
| 429 |
+
# Language modeling head (maps hidden states to vocabulary)
|
| 430 |
+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
| 431 |
+
|
| 432 |
+
# Tie weights between token embeddings and output head (common practice)
|
| 433 |
+
self.transformer.wte.weight = self.lm_head.weight
|
| 434 |
+
|
| 435 |
+
# Initialize weights
|
| 436 |
+
self.apply(self._init_weights)
|
| 437 |
+
|
| 438 |
+
# Report parameter count
|
| 439 |
+
print(f"Model initialized: {self.config.model_name}")
|
| 440 |
+
print(f"Parameters: {self.get_num_params():,}")
|
| 441 |
+
print(f"Estimated: {self.config.estimate_parameters():,}")
|
| 442 |
+
|
| 443 |
+
def _init_weights(self, module):
|
| 444 |
+
"""Initialize model weights using standard practices."""
|
| 445 |
+
if isinstance(module, nn.Linear):
|
| 446 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 447 |
+
if module.bias is not None:
|
| 448 |
+
torch.nn.init.zeros_(module.bias)
|
| 449 |
+
elif isinstance(module, nn.Embedding):
|
| 450 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 451 |
+
|
| 452 |
+
def get_num_params(self, non_embedding: bool = False) -> int:
|
| 453 |
+
"""
|
| 454 |
+
Count the number of parameters in the model.
|
| 455 |
+
|
| 456 |
+
Args:
|
| 457 |
+
non_embedding: If True, subtract embedding parameters
|
| 458 |
+
|
| 459 |
+
Returns:
|
| 460 |
+
int: Number of parameters
|
| 461 |
+
"""
|
| 462 |
+
n_params = sum(p.numel() for p in self.parameters())
|
| 463 |
+
if non_embedding:
|
| 464 |
+
n_params -= self.transformer.wpe.weight.numel()
|
| 465 |
+
n_params -= self.transformer.wte.weight.numel()
|
| 466 |
+
return n_params
|
| 467 |
+
|
| 468 |
+
def forward(
|
| 469 |
+
self,
|
| 470 |
+
idx: torch.Tensor,
|
| 471 |
+
targets: Optional[torch.Tensor] = None
|
| 472 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 473 |
+
"""
|
| 474 |
+
Forward pass of the GPT model.
|
| 475 |
+
|
| 476 |
+
Args:
|
| 477 |
+
idx: Input token indices of shape (batch_size, seq_len)
|
| 478 |
+
targets: Optional target tokens for loss calculation (batch_size, seq_len)
|
| 479 |
+
|
| 480 |
+
Returns:
|
| 481 |
+
Tuple containing:
|
| 482 |
+
- logits: Output logits of shape (batch_size, seq_len, vocab_size)
|
| 483 |
+
- loss: Cross-entropy loss if targets provided, None otherwise
|
| 484 |
+
"""
|
| 485 |
+
device = idx.device
|
| 486 |
+
b, t = idx.size()
|
| 487 |
+
assert t <= self.config.block_size, f"Sequence length {t} exceeds block size {self.config.block_size}"
|
| 488 |
+
|
| 489 |
+
# Token embeddings
|
| 490 |
+
tok_emb = self.transformer.wte(idx) # (b, t, n_embd)
|
| 491 |
+
|
| 492 |
+
# Position embeddings
|
| 493 |
+
pos = torch.arange(0, t, dtype=torch.long, device=device) # (t,)
|
| 494 |
+
pos_emb = self.transformer.wpe(pos) # (t, n_embd)
|
| 495 |
+
|
| 496 |
+
# Combine embeddings and apply dropout
|
| 497 |
+
x = self.transformer.drop(tok_emb + pos_emb)
|
| 498 |
+
|
| 499 |
+
# Pass through transformer blocks
|
| 500 |
+
for block in self.transformer.h:
|
| 501 |
+
x = block(x)
|
| 502 |
+
|
| 503 |
+
# Final layer normalization
|
| 504 |
+
x = self.transformer.ln_f(x)
|
| 505 |
+
|
| 506 |
+
# Language modeling head
|
| 507 |
+
if targets is not None:
|
| 508 |
+
# If we have targets, compute loss
|
| 509 |
+
logits = self.lm_head(x)
|
| 510 |
+
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
|
| 511 |
+
else:
|
| 512 |
+
# If no targets, only compute logits for the last token (more efficient for generation)
|
| 513 |
+
logits = self.lm_head(x[:, [-1], :]) # Note: using list [-1] to preserve the time dim
|
| 514 |
+
loss = None
|
| 515 |
+
|
| 516 |
+
return logits, loss
|
| 517 |
+
|
| 518 |
+
def generate(
|
| 519 |
+
self,
|
| 520 |
+
idx: torch.Tensor,
|
| 521 |
+
max_new_tokens: int = 100,
|
| 522 |
+
temperature: float = 1.0,
|
| 523 |
+
top_k: Optional[int] = None
|
| 524 |
+
) -> torch.Tensor:
|
| 525 |
+
"""
|
| 526 |
+
Generate new tokens autoregressively.
|
| 527 |
+
|
| 528 |
+
Args:
|
| 529 |
+
idx: Starting token indices (batch_size, seq_len)
|
| 530 |
+
max_new_tokens: Maximum number of new tokens to generate
|
| 531 |
+
temperature: Sampling temperature (higher = more random)
|
| 532 |
+
top_k: If set, only sample from top-k most likely tokens
|
| 533 |
+
|
| 534 |
+
Returns:
|
| 535 |
+
torch.Tensor: Generated sequence (batch_size, seq_len + max_new_tokens)
|
| 536 |
+
"""
|
| 537 |
+
self.eval()
|
| 538 |
+
with torch.no_grad():
|
| 539 |
+
for _ in range(max_new_tokens):
|
| 540 |
+
# Crop sequence if it exceeds block size
|
| 541 |
+
idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
|
| 542 |
+
|
| 543 |
+
# Forward pass
|
| 544 |
+
logits, _ = self(idx_cond)
|
| 545 |
+
|
| 546 |
+
# Get logits for the last token and apply temperature
|
| 547 |
+
logits = logits[:, -1, :] / temperature
|
| 548 |
+
|
| 549 |
+
# Optionally crop to top-k most likely tokens
|
| 550 |
+
if top_k is not None:
|
| 551 |
+
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
| 552 |
+
logits[logits < v[:, [-1]]] = -float('Inf')
|
| 553 |
+
|
| 554 |
+
# Apply softmax and sample
|
| 555 |
+
probs = F.softmax(logits, dim=-1)
|
| 556 |
+
idx_next = torch.multinomial(probs, num_samples=1)
|
| 557 |
+
|
| 558 |
+
# Append to sequence
|
| 559 |
+
idx = torch.cat((idx, idx_next), dim=1)
|
| 560 |
+
|
| 561 |
+
self.train() # Return to training mode
|
| 562 |
+
return idx
|
| 563 |
+
|
| 564 |
+
def estimate_memory_usage(self, batch_size: int = 1, seq_len: int = None) -> dict:
|
| 565 |
+
"""
|
| 566 |
+
Estimate memory usage for training and inference.
|
| 567 |
+
|
| 568 |
+
Args:
|
| 569 |
+
batch_size: Batch size for estimation
|
| 570 |
+
seq_len: Sequence length (defaults to block_size)
|
| 571 |
+
|
| 572 |
+
Returns:
|
| 573 |
+
dict: Memory usage estimates in MB
|
| 574 |
+
"""
|
| 575 |
+
if seq_len is None:
|
| 576 |
+
seq_len = self.config.block_size
|
| 577 |
+
|
| 578 |
+
# Model parameters (weights)
|
| 579 |
+
param_memory = self.get_num_params() * 4 / (1024**2) # 4 bytes per float32
|
| 580 |
+
|
| 581 |
+
# Activations (rough estimate)
|
| 582 |
+
activation_memory = (
|
| 583 |
+
batch_size * seq_len * self.config.n_embd * self.config.n_layer * 8 # Rough estimate
|
| 584 |
+
) / (1024**2)
|
| 585 |
+
|
| 586 |
+
# Gradients (same size as parameters during training)
|
| 587 |
+
gradient_memory = param_memory
|
| 588 |
+
|
| 589 |
+
return {
|
| 590 |
+
"parameters_mb": param_memory,
|
| 591 |
+
"activations_mb": activation_memory,
|
| 592 |
+
"gradients_mb": gradient_memory,
|
| 593 |
+
"total_training_mb": param_memory + activation_memory + gradient_memory,
|
| 594 |
+
"total_inference_mb": param_memory + activation_memory * 0.5, # No gradients needed
|
| 595 |
+
}
|
| 596 |
+
|
| 597 |
+
|
| 598 |
+
def create_model(model_size: str = "medium") -> GPTModel:
|
| 599 |
+
"""
|
| 600 |
+
Factory function to create a GPT model with predefined configurations.
|
| 601 |
+
|
| 602 |
+
Args:
|
| 603 |
+
model_size: Size of model to create ("small", "medium", "large")
|
| 604 |
+
|
| 605 |
+
Returns:
|
| 606 |
+
GPTModel: Initialized model
|
| 607 |
+
"""
|
| 608 |
+
configs = {
|
| 609 |
+
"small": GPTConfig.small(),
|
| 610 |
+
"medium": GPTConfig.medium(),
|
| 611 |
+
"large": GPTConfig.large(),
|
| 612 |
+
}
|
| 613 |
+
|
| 614 |
+
if model_size not in configs:
|
| 615 |
+
raise ValueError(f"Unknown model size: {model_size}. Choose from {list(configs.keys())}")
|
| 616 |
+
|
| 617 |
+
config = configs[model_size]
|
| 618 |
+
model = GPTModel(config)
|
| 619 |
+
|
| 620 |
+
return model
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
if __name__ == "__main__":
|
| 624 |
+
# Example usage
|
| 625 |
+
print("π§ GPT Model Architecture")
|
| 626 |
+
print("=" * 50)
|
| 627 |
+
|
| 628 |
+
# Create models of different sizes
|
| 629 |
+
for size in ["small", "medium", "large"]:
|
| 630 |
+
print(f"\n{size.upper()} MODEL:")
|
| 631 |
+
model = create_model(size)
|
| 632 |
+
|
| 633 |
+
# Show memory estimates
|
| 634 |
+
memory = model.estimate_memory_usage(batch_size=4, seq_len=512)
|
| 635 |
+
print(f"Memory (4 batch, 512 seq): {memory['total_training_mb']:.1f}MB training, {memory['total_inference_mb']:.1f}MB inference")
|
| 636 |
+
|
| 637 |
+
# Test forward pass
|
| 638 |
+
x = torch.randint(0, 32000, (2, 64)) # Batch size 2, sequence length 64
|
| 639 |
+
with torch.no_grad():
|
| 640 |
+
logits, _ = model(x)
|
| 641 |
+
print(f"Test forward pass: {x.shape} -> {logits.shape} β")
|
training/train_model.py
ADDED
|
@@ -0,0 +1,657 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright (C) 2024 Louis Chua Bean Chong
|
| 3 |
+
#
|
| 4 |
+
# This file is part of OpenLLM.
|
| 5 |
+
#
|
| 6 |
+
# OpenLLM is dual-licensed:
|
| 7 |
+
# 1. For open source use: GNU General Public License v3.0
|
| 8 |
+
# 2. For commercial use: Commercial License (contact for details)
|
| 9 |
+
#
|
| 10 |
+
# See LICENSE and docs/LICENSES.md for full license information.
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
Language Model Training Script
|
| 14 |
+
|
| 15 |
+
This script implements the complete training pipeline for GPT-style language models.
|
| 16 |
+
It includes optimization, checkpointing, progress monitoring, and CPU-optimized training
|
| 17 |
+
for limited hardware environments.
|
| 18 |
+
|
| 19 |
+
FEATURES:
|
| 20 |
+
- CPU-optimized training with memory management
|
| 21 |
+
- Gradient accumulation for effective large batch sizes
|
| 22 |
+
- Learning rate scheduling with warmup
|
| 23 |
+
- Model checkpointing and resume capability
|
| 24 |
+
- Real-time monitoring of loss, perplexity, and speed
|
| 25 |
+
- Memory usage tracking and optimization
|
| 26 |
+
- Automatic mixed precision (if available)
|
| 27 |
+
|
| 28 |
+
HARDWARE OPTIMIZATION:
|
| 29 |
+
- Designed for 8GB RAM systems
|
| 30 |
+
- Efficient CPU training with PyTorch optimizations
|
| 31 |
+
- Gradient accumulation to simulate larger batches
|
| 32 |
+
- Memory cleanup and garbage collection
|
| 33 |
+
- Progress saving for long training runs
|
| 34 |
+
|
| 35 |
+
Usage:
|
| 36 |
+
python core/src/train_model.py \\
|
| 37 |
+
--model-size small \\
|
| 38 |
+
--data-file data/clean/training_data.txt \\
|
| 39 |
+
--tokenizer-dir data/tokenizer/ \\
|
| 40 |
+
--output-dir models/my-model/ \\
|
| 41 |
+
--max-steps 10000
|
| 42 |
+
|
| 43 |
+
Requirements:
|
| 44 |
+
- PyTorch
|
| 45 |
+
- SentencePiece
|
| 46 |
+
- Our model architecture and data loader
|
| 47 |
+
|
| 48 |
+
Author: Louis Chua Bean Chong
|
| 49 |
+
License: GPLv3
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
import argparse
|
| 53 |
+
import json
|
| 54 |
+
import os
|
| 55 |
+
import time
|
| 56 |
+
import math
|
| 57 |
+
import gc
|
| 58 |
+
from pathlib import Path
|
| 59 |
+
from typing import Dict, Any, Optional, Tuple
|
| 60 |
+
|
| 61 |
+
import torch
|
| 62 |
+
import torch.nn as nn
|
| 63 |
+
import torch.optim as optim
|
| 64 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
|
| 65 |
+
|
| 66 |
+
# Import our modules
|
| 67 |
+
try:
|
| 68 |
+
from model import GPTModel, GPTConfig, create_model
|
| 69 |
+
from data_loader import TextDataLoader
|
| 70 |
+
except ImportError:
|
| 71 |
+
import sys
|
| 72 |
+
sys.path.append(os.path.dirname(__file__))
|
| 73 |
+
from model import GPTModel, GPTConfig, create_model
|
| 74 |
+
from data_loader import TextDataLoader
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class ModelTrainer:
|
| 78 |
+
"""
|
| 79 |
+
Comprehensive trainer for GPT-style language models.
|
| 80 |
+
|
| 81 |
+
Handles the complete training pipeline including data loading, optimization,
|
| 82 |
+
checkpointing, and progress monitoring.
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
def __init__(
|
| 86 |
+
self,
|
| 87 |
+
model: GPTModel,
|
| 88 |
+
data_loader: TextDataLoader,
|
| 89 |
+
output_dir: str,
|
| 90 |
+
device: str = "cpu",
|
| 91 |
+
learning_rate: float = 3e-4,
|
| 92 |
+
weight_decay: float = 0.01,
|
| 93 |
+
warmup_steps: int = 1000,
|
| 94 |
+
max_steps: int = 10000,
|
| 95 |
+
gradient_accumulation_steps: int = 4,
|
| 96 |
+
gradient_clipping: float = 1.0,
|
| 97 |
+
save_every: int = 1000,
|
| 98 |
+
eval_every: int = 500,
|
| 99 |
+
log_every: int = 100
|
| 100 |
+
):
|
| 101 |
+
"""
|
| 102 |
+
Initialize the model trainer.
|
| 103 |
+
|
| 104 |
+
Args:
|
| 105 |
+
model: GPT model to train
|
| 106 |
+
data_loader: Data loader for training data
|
| 107 |
+
output_dir: Directory to save checkpoints and logs
|
| 108 |
+
device: Training device ("cpu" or "cuda")
|
| 109 |
+
learning_rate: Peak learning rate
|
| 110 |
+
weight_decay: Weight decay for regularization
|
| 111 |
+
warmup_steps: Number of warmup steps for learning rate
|
| 112 |
+
max_steps: Maximum training steps
|
| 113 |
+
gradient_accumulation_steps: Steps to accumulate gradients
|
| 114 |
+
gradient_clipping: Maximum gradient norm
|
| 115 |
+
save_every: Save checkpoint every N steps
|
| 116 |
+
eval_every: Evaluate model every N steps
|
| 117 |
+
log_every: Log progress every N steps
|
| 118 |
+
"""
|
| 119 |
+
self.model = model.to(device)
|
| 120 |
+
self.data_loader = data_loader
|
| 121 |
+
self.output_dir = Path(output_dir)
|
| 122 |
+
self.device = device
|
| 123 |
+
|
| 124 |
+
# Training hyperparameters
|
| 125 |
+
self.learning_rate = learning_rate
|
| 126 |
+
self.weight_decay = weight_decay
|
| 127 |
+
self.warmup_steps = warmup_steps
|
| 128 |
+
self.max_steps = max_steps
|
| 129 |
+
self.gradient_accumulation_steps = gradient_accumulation_steps
|
| 130 |
+
self.gradient_clipping = gradient_clipping
|
| 131 |
+
|
| 132 |
+
# Logging and saving
|
| 133 |
+
self.save_every = save_every
|
| 134 |
+
self.eval_every = eval_every
|
| 135 |
+
self.log_every = log_every
|
| 136 |
+
|
| 137 |
+
# Create output directory
|
| 138 |
+
self.output_dir.mkdir(parents=True, exist_ok=True)
|
| 139 |
+
|
| 140 |
+
# Initialize optimizer and scheduler
|
| 141 |
+
self.optimizer = self._create_optimizer()
|
| 142 |
+
self.scheduler = self._create_scheduler()
|
| 143 |
+
|
| 144 |
+
# Training state
|
| 145 |
+
self.step = 0
|
| 146 |
+
self.epoch = 0
|
| 147 |
+
self.best_loss = float('inf')
|
| 148 |
+
self.training_log = []
|
| 149 |
+
|
| 150 |
+
# Performance tracking
|
| 151 |
+
self.start_time = None
|
| 152 |
+
self.step_times = []
|
| 153 |
+
|
| 154 |
+
print(f"π ModelTrainer initialized")
|
| 155 |
+
print(f" Device: {device}")
|
| 156 |
+
print(f" Model parameters: {model.get_num_params():,}")
|
| 157 |
+
print(f" Learning rate: {learning_rate}")
|
| 158 |
+
print(f" Max steps: {max_steps:,}")
|
| 159 |
+
print(f" Gradient accumulation: {gradient_accumulation_steps}")
|
| 160 |
+
print(f" Output directory: {output_dir}")
|
| 161 |
+
|
| 162 |
+
def _create_optimizer(self) -> optim.Optimizer:
|
| 163 |
+
"""Create AdamW optimizer with weight decay."""
|
| 164 |
+
# Separate parameters for weight decay
|
| 165 |
+
decay_params = []
|
| 166 |
+
no_decay_params = []
|
| 167 |
+
|
| 168 |
+
for name, param in self.model.named_parameters():
|
| 169 |
+
if not param.requires_grad:
|
| 170 |
+
continue
|
| 171 |
+
|
| 172 |
+
# Don't apply weight decay to biases and layer norm parameters
|
| 173 |
+
if len(param.shape) == 1 or name.endswith('.bias'):
|
| 174 |
+
no_decay_params.append(param)
|
| 175 |
+
else:
|
| 176 |
+
decay_params.append(param)
|
| 177 |
+
|
| 178 |
+
param_groups = [
|
| 179 |
+
{'params': decay_params, 'weight_decay': self.weight_decay},
|
| 180 |
+
{'params': no_decay_params, 'weight_decay': 0.0}
|
| 181 |
+
]
|
| 182 |
+
|
| 183 |
+
# Use AdamW with lower memory usage for CPU
|
| 184 |
+
optimizer = optim.AdamW(
|
| 185 |
+
param_groups,
|
| 186 |
+
lr=self.learning_rate,
|
| 187 |
+
betas=(0.9, 0.95), # Slightly different from default for LLM training
|
| 188 |
+
eps=1e-8
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
return optimizer
|
| 192 |
+
|
| 193 |
+
def _create_scheduler(self) -> torch.optim.lr_scheduler._LRScheduler:
|
| 194 |
+
"""Create learning rate scheduler with warmup and cosine decay."""
|
| 195 |
+
if self.warmup_steps > 0:
|
| 196 |
+
# Linear warmup
|
| 197 |
+
warmup_scheduler = LinearLR(
|
| 198 |
+
self.optimizer,
|
| 199 |
+
start_factor=0.01, # Start at 1% of learning rate
|
| 200 |
+
end_factor=1.0,
|
| 201 |
+
total_iters=self.warmup_steps
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
# Cosine decay after warmup
|
| 205 |
+
cosine_scheduler = CosineAnnealingLR(
|
| 206 |
+
self.optimizer,
|
| 207 |
+
T_max=self.max_steps - self.warmup_steps,
|
| 208 |
+
eta_min=self.learning_rate * 0.1 # Minimum 10% of peak LR
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
# Combine warmup and cosine decay
|
| 212 |
+
scheduler = SequentialLR(
|
| 213 |
+
self.optimizer,
|
| 214 |
+
schedulers=[warmup_scheduler, cosine_scheduler],
|
| 215 |
+
milestones=[self.warmup_steps]
|
| 216 |
+
)
|
| 217 |
+
else:
|
| 218 |
+
# Just cosine decay
|
| 219 |
+
scheduler = CosineAnnealingLR(
|
| 220 |
+
self.optimizer,
|
| 221 |
+
T_max=self.max_steps,
|
| 222 |
+
eta_min=self.learning_rate * 0.1
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
return scheduler
|
| 226 |
+
|
| 227 |
+
def _calculate_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
|
| 228 |
+
"""
|
| 229 |
+
Calculate cross-entropy loss for autoregressive language modeling.
|
| 230 |
+
|
| 231 |
+
This method computes the standard cross-entropy loss used in language model training.
|
| 232 |
+
The loss measures how well the model predicts the next token in the sequence.
|
| 233 |
+
|
| 234 |
+
Mathematical formulation:
|
| 235 |
+
Loss = -β log(P(target_token | context))
|
| 236 |
+
where P is the softmax probability distribution over vocabulary
|
| 237 |
+
|
| 238 |
+
Implementation details:
|
| 239 |
+
- Reshapes 3D tensors to 2D for efficient computation
|
| 240 |
+
- Uses PyTorch's optimized cross_entropy function
|
| 241 |
+
- Handles padding tokens by ignoring them in loss calculation
|
| 242 |
+
- Computes mean loss across all valid positions
|
| 243 |
+
|
| 244 |
+
Why cross-entropy for language modeling:
|
| 245 |
+
- Natural choice for multi-class classification (next token prediction)
|
| 246 |
+
- Provides strong gradient signal for correct token probabilities
|
| 247 |
+
- Mathematically equivalent to minimizing negative log-likelihood
|
| 248 |
+
- Well-studied optimization properties for neural language models
|
| 249 |
+
|
| 250 |
+
Args:
|
| 251 |
+
logits: Raw model predictions of shape (batch_size, seq_len, vocab_size)
|
| 252 |
+
Contains unnormalized scores for each token in vocabulary
|
| 253 |
+
These will be converted to probabilities via softmax internally
|
| 254 |
+
targets: Ground truth next tokens of shape (batch_size, seq_len)
|
| 255 |
+
Contains token IDs representing the true next tokens
|
| 256 |
+
Should be input sequence shifted by one position
|
| 257 |
+
|
| 258 |
+
Returns:
|
| 259 |
+
torch.Tensor: Scalar loss value representing prediction error
|
| 260 |
+
Lower values indicate better next-token prediction accuracy
|
| 261 |
+
"""
|
| 262 |
+
# Reshape tensors from 3D to 2D for efficient loss computation
|
| 263 |
+
# This converts per-sequence per-position predictions to a flat structure
|
| 264 |
+
# where each row represents one prediction over the entire vocabulary
|
| 265 |
+
logits = logits.view(-1, logits.size(-1)) # (batch_size * seq_len, vocab_size)
|
| 266 |
+
targets = targets.view(-1) # (batch_size * seq_len,)
|
| 267 |
+
|
| 268 |
+
# Calculate cross-entropy loss with proper handling of special tokens
|
| 269 |
+
# ignore_index=-1 excludes padding tokens from loss calculation
|
| 270 |
+
# This prevents the model from learning to predict padding, which would skew training
|
| 271 |
+
# The function internally applies softmax to logits and computes negative log-likelihood
|
| 272 |
+
loss = nn.functional.cross_entropy(logits, targets, ignore_index=-1)
|
| 273 |
+
|
| 274 |
+
# Return scalar loss for backpropagation
|
| 275 |
+
# This loss will be used to compute gradients via automatic differentiation
|
| 276 |
+
return loss
|
| 277 |
+
|
| 278 |
+
def _get_memory_usage(self) -> Dict[str, float]:
|
| 279 |
+
"""Get current memory usage statistics."""
|
| 280 |
+
memory_stats = {}
|
| 281 |
+
|
| 282 |
+
if torch.cuda.is_available() and self.device.startswith('cuda'):
|
| 283 |
+
memory_stats['gpu_allocated_mb'] = torch.cuda.memory_allocated() / (1024**2)
|
| 284 |
+
memory_stats['gpu_cached_mb'] = torch.cuda.memory_reserved() / (1024**2)
|
| 285 |
+
|
| 286 |
+
# Estimate CPU memory (approximate)
|
| 287 |
+
import psutil
|
| 288 |
+
process = psutil.Process()
|
| 289 |
+
memory_stats['cpu_memory_mb'] = process.memory_info().rss / (1024**2)
|
| 290 |
+
|
| 291 |
+
return memory_stats
|
| 292 |
+
|
| 293 |
+
def _log_step(self, step: int, loss: float, lr: float, step_time: float) -> None:
|
| 294 |
+
"""Log training progress for a single step."""
|
| 295 |
+
perplexity = math.exp(min(loss, 10)) # Cap at exp(10) to avoid overflow
|
| 296 |
+
|
| 297 |
+
# Calculate tokens per second
|
| 298 |
+
tokens_per_batch = self.data_loader.batch_size * self.data_loader.seq_len
|
| 299 |
+
tokens_per_second = tokens_per_batch / step_time if step_time > 0 else 0
|
| 300 |
+
|
| 301 |
+
# Get memory usage
|
| 302 |
+
memory_stats = self._get_memory_usage()
|
| 303 |
+
|
| 304 |
+
# Create log entry
|
| 305 |
+
log_entry = {
|
| 306 |
+
'step': step,
|
| 307 |
+
'loss': loss,
|
| 308 |
+
'perplexity': perplexity,
|
| 309 |
+
'learning_rate': lr,
|
| 310 |
+
'step_time': step_time,
|
| 311 |
+
'tokens_per_second': tokens_per_second,
|
| 312 |
+
'memory_mb': memory_stats.get('cpu_memory_mb', 0)
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
self.training_log.append(log_entry)
|
| 316 |
+
|
| 317 |
+
# Print progress
|
| 318 |
+
elapsed_time = time.time() - self.start_time if self.start_time else 0
|
| 319 |
+
eta_seconds = (self.max_steps - step) * step_time if step_time > 0 else 0
|
| 320 |
+
eta_hours = eta_seconds / 3600
|
| 321 |
+
|
| 322 |
+
print(f"Step {step:,}/{self.max_steps:,} | "
|
| 323 |
+
f"Loss: {loss:.4f} | "
|
| 324 |
+
f"PPL: {perplexity:.2f} | "
|
| 325 |
+
f"LR: {lr:.2e} | "
|
| 326 |
+
f"Time: {step_time:.2f}s | "
|
| 327 |
+
f"Tokens/s: {tokens_per_second:.1f} | "
|
| 328 |
+
f"Memory: {memory_stats.get('cpu_memory_mb', 0):.0f}MB | "
|
| 329 |
+
f"ETA: {eta_hours:.1f}h")
|
| 330 |
+
|
| 331 |
+
def _save_checkpoint(self, step: int, is_best: bool = False) -> None:
|
| 332 |
+
"""Save model checkpoint."""
|
| 333 |
+
checkpoint = {
|
| 334 |
+
'step': step,
|
| 335 |
+
'epoch': self.epoch,
|
| 336 |
+
'model_state_dict': self.model.state_dict(),
|
| 337 |
+
'optimizer_state_dict': self.optimizer.state_dict(),
|
| 338 |
+
'scheduler_state_dict': self.scheduler.state_dict(),
|
| 339 |
+
'best_loss': self.best_loss,
|
| 340 |
+
'training_log': self.training_log,
|
| 341 |
+
'config': self.model.config.__dict__
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
# Save latest checkpoint
|
| 345 |
+
checkpoint_path = self.output_dir / f"checkpoint_step_{step}.pt"
|
| 346 |
+
torch.save(checkpoint, checkpoint_path)
|
| 347 |
+
|
| 348 |
+
# Save best checkpoint
|
| 349 |
+
if is_best:
|
| 350 |
+
best_path = self.output_dir / "best_model.pt"
|
| 351 |
+
torch.save(checkpoint, best_path)
|
| 352 |
+
print(f"πΎ New best model saved: {best_path}")
|
| 353 |
+
|
| 354 |
+
# Save training log
|
| 355 |
+
log_path = self.output_dir / "training_log.json"
|
| 356 |
+
with open(log_path, 'w') as f:
|
| 357 |
+
json.dump(self.training_log, f, indent=2)
|
| 358 |
+
|
| 359 |
+
print(f"πΎ Checkpoint saved: {checkpoint_path}")
|
| 360 |
+
|
| 361 |
+
def _load_checkpoint(self, checkpoint_path: str) -> None:
|
| 362 |
+
"""Load model checkpoint to resume training."""
|
| 363 |
+
if not os.path.exists(checkpoint_path):
|
| 364 |
+
print(f"β οΈ Checkpoint not found: {checkpoint_path}")
|
| 365 |
+
return
|
| 366 |
+
|
| 367 |
+
print(f"π Loading checkpoint: {checkpoint_path}")
|
| 368 |
+
|
| 369 |
+
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
| 370 |
+
|
| 371 |
+
self.model.load_state_dict(checkpoint['model_state_dict'])
|
| 372 |
+
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 373 |
+
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
| 374 |
+
|
| 375 |
+
self.step = checkpoint['step']
|
| 376 |
+
self.epoch = checkpoint['epoch']
|
| 377 |
+
self.best_loss = checkpoint['best_loss']
|
| 378 |
+
self.training_log = checkpoint.get('training_log', [])
|
| 379 |
+
|
| 380 |
+
print(f"β Checkpoint loaded successfully")
|
| 381 |
+
print(f" Resuming from step: {self.step:,}")
|
| 382 |
+
print(f" Best loss so far: {self.best_loss:.4f}")
|
| 383 |
+
|
| 384 |
+
def train(self) -> None:
|
| 385 |
+
"""Main training loop."""
|
| 386 |
+
print(f"\nπ Starting training...")
|
| 387 |
+
print(f" Model: {self.model.config.model_name}")
|
| 388 |
+
print(f" Parameters: {self.model.get_num_params():,}")
|
| 389 |
+
print(f" Device: {self.device}")
|
| 390 |
+
print(f" Max steps: {self.max_steps:,}")
|
| 391 |
+
print("=" * 80)
|
| 392 |
+
|
| 393 |
+
self.model.train()
|
| 394 |
+
self.start_time = time.time()
|
| 395 |
+
|
| 396 |
+
# Initialize gradient accumulation
|
| 397 |
+
accumulated_loss = 0.0
|
| 398 |
+
self.optimizer.zero_grad()
|
| 399 |
+
|
| 400 |
+
for batch_idx, (input_ids, target_ids) in enumerate(self.data_loader):
|
| 401 |
+
if self.step >= self.max_steps:
|
| 402 |
+
break
|
| 403 |
+
|
| 404 |
+
step_start_time = time.time()
|
| 405 |
+
|
| 406 |
+
# Move batch to device
|
| 407 |
+
input_ids = input_ids.to(self.device)
|
| 408 |
+
target_ids = target_ids.to(self.device)
|
| 409 |
+
|
| 410 |
+
# Forward pass (model computes loss internally when targets provided)
|
| 411 |
+
logits, loss = self.model(input_ids, target_ids)
|
| 412 |
+
|
| 413 |
+
# Scale loss for gradient accumulation
|
| 414 |
+
loss = loss / self.gradient_accumulation_steps
|
| 415 |
+
accumulated_loss += loss.item()
|
| 416 |
+
|
| 417 |
+
# Backward pass
|
| 418 |
+
loss.backward()
|
| 419 |
+
|
| 420 |
+
# Update weights every gradient_accumulation_steps
|
| 421 |
+
if (batch_idx + 1) % self.gradient_accumulation_steps == 0:
|
| 422 |
+
# Clip gradients
|
| 423 |
+
if self.gradient_clipping > 0:
|
| 424 |
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clipping)
|
| 425 |
+
|
| 426 |
+
# Update parameters
|
| 427 |
+
self.optimizer.step()
|
| 428 |
+
self.scheduler.step()
|
| 429 |
+
self.optimizer.zero_grad()
|
| 430 |
+
|
| 431 |
+
# Update step count
|
| 432 |
+
self.step += 1
|
| 433 |
+
step_time = time.time() - step_start_time
|
| 434 |
+
self.step_times.append(step_time)
|
| 435 |
+
|
| 436 |
+
# Get current learning rate
|
| 437 |
+
current_lr = self.scheduler.get_last_lr()[0]
|
| 438 |
+
|
| 439 |
+
# Log progress
|
| 440 |
+
if self.step % self.log_every == 0:
|
| 441 |
+
avg_loss = accumulated_loss
|
| 442 |
+
self._log_step(self.step, avg_loss, current_lr, step_time)
|
| 443 |
+
|
| 444 |
+
# Save checkpoint
|
| 445 |
+
if self.step % self.save_every == 0:
|
| 446 |
+
is_best = accumulated_loss < self.best_loss
|
| 447 |
+
if is_best:
|
| 448 |
+
self.best_loss = accumulated_loss
|
| 449 |
+
|
| 450 |
+
self._save_checkpoint(self.step, is_best)
|
| 451 |
+
|
| 452 |
+
# Clean up memory periodically
|
| 453 |
+
if self.step % 100 == 0:
|
| 454 |
+
gc.collect()
|
| 455 |
+
|
| 456 |
+
# Reset accumulated loss
|
| 457 |
+
accumulated_loss = 0.0
|
| 458 |
+
|
| 459 |
+
# Check if training complete
|
| 460 |
+
if self.step >= self.max_steps:
|
| 461 |
+
break
|
| 462 |
+
|
| 463 |
+
# Final checkpoint
|
| 464 |
+
print(f"\nπ Training completed!")
|
| 465 |
+
self._save_checkpoint(self.step, is_best=True)
|
| 466 |
+
|
| 467 |
+
# Training summary
|
| 468 |
+
total_time = time.time() - self.start_time
|
| 469 |
+
avg_step_time = sum(self.step_times) / len(self.step_times) if self.step_times else 0
|
| 470 |
+
|
| 471 |
+
print(f"\nπ Training Summary:")
|
| 472 |
+
print(f" Steps completed: {self.step:,}")
|
| 473 |
+
print(f" Total time: {total_time/3600:.2f} hours")
|
| 474 |
+
print(f" Average time per step: {avg_step_time:.2f}s")
|
| 475 |
+
print(f" Final loss: {self.best_loss:.4f}")
|
| 476 |
+
print(f" Final perplexity: {math.exp(min(self.best_loss, 10)):.2f}")
|
| 477 |
+
print(f" Model saved to: {self.output_dir}")
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
def main():
|
| 481 |
+
"""Main function to handle command line training."""
|
| 482 |
+
parser = argparse.ArgumentParser(
|
| 483 |
+
description="Train a GPT-style language model",
|
| 484 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 485 |
+
epilog="""
|
| 486 |
+
Examples:
|
| 487 |
+
# Train small model for quick experimentation
|
| 488 |
+
python core/src/train_model.py \\
|
| 489 |
+
--model-size small \\
|
| 490 |
+
--max-steps 5000 \\
|
| 491 |
+
--output-dir models/test-small
|
| 492 |
+
|
| 493 |
+
# Train medium model with custom settings
|
| 494 |
+
python core/src/train_model.py \\
|
| 495 |
+
--model-size medium \\
|
| 496 |
+
--learning-rate 1e-4 \\
|
| 497 |
+
--batch-size 2 \\
|
| 498 |
+
--max-steps 50000 \\
|
| 499 |
+
--output-dir models/my-medium-model
|
| 500 |
+
"""
|
| 501 |
+
)
|
| 502 |
+
|
| 503 |
+
# Model and data arguments
|
| 504 |
+
parser.add_argument(
|
| 505 |
+
"--model-size",
|
| 506 |
+
choices=["small", "medium", "large"],
|
| 507 |
+
default="small",
|
| 508 |
+
help="Model size to train (default: small)"
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
parser.add_argument(
|
| 512 |
+
"--data-file",
|
| 513 |
+
default="data/clean/training_data.txt",
|
| 514 |
+
help="Path to training text file (default: data/clean/training_data.txt)"
|
| 515 |
+
)
|
| 516 |
+
|
| 517 |
+
parser.add_argument(
|
| 518 |
+
"--tokenizer-dir",
|
| 519 |
+
default="data/tokenizer/",
|
| 520 |
+
help="Path to tokenizer directory (default: data/tokenizer/)"
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
parser.add_argument(
|
| 524 |
+
"--output-dir",
|
| 525 |
+
required=True,
|
| 526 |
+
help="Output directory for model checkpoints"
|
| 527 |
+
)
|
| 528 |
+
|
| 529 |
+
# Training hyperparameters
|
| 530 |
+
parser.add_argument(
|
| 531 |
+
"--seq-len",
|
| 532 |
+
type=int,
|
| 533 |
+
default=512,
|
| 534 |
+
help="Sequence length for training (default: 512)"
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
parser.add_argument(
|
| 538 |
+
"--batch-size",
|
| 539 |
+
type=int,
|
| 540 |
+
default=4,
|
| 541 |
+
help="Batch size (default: 4)"
|
| 542 |
+
)
|
| 543 |
+
|
| 544 |
+
parser.add_argument(
|
| 545 |
+
"--learning-rate",
|
| 546 |
+
type=float,
|
| 547 |
+
default=3e-4,
|
| 548 |
+
help="Learning rate (default: 3e-4)"
|
| 549 |
+
)
|
| 550 |
+
|
| 551 |
+
parser.add_argument(
|
| 552 |
+
"--max-steps",
|
| 553 |
+
type=int,
|
| 554 |
+
default=10000,
|
| 555 |
+
help="Maximum training steps (default: 10000)"
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
parser.add_argument(
|
| 559 |
+
"--warmup-steps",
|
| 560 |
+
type=int,
|
| 561 |
+
default=1000,
|
| 562 |
+
help="Warmup steps (default: 1000)"
|
| 563 |
+
)
|
| 564 |
+
|
| 565 |
+
parser.add_argument(
|
| 566 |
+
"--gradient-accumulation-steps",
|
| 567 |
+
type=int,
|
| 568 |
+
default=4,
|
| 569 |
+
help="Gradient accumulation steps (default: 4)"
|
| 570 |
+
)
|
| 571 |
+
|
| 572 |
+
parser.add_argument(
|
| 573 |
+
"--device",
|
| 574 |
+
choices=["cpu", "cuda", "auto"],
|
| 575 |
+
default="auto",
|
| 576 |
+
help="Training device (default: auto)"
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
parser.add_argument(
|
| 580 |
+
"--resume",
|
| 581 |
+
help="Path to checkpoint to resume training from"
|
| 582 |
+
)
|
| 583 |
+
|
| 584 |
+
parser.add_argument(
|
| 585 |
+
"--save-every",
|
| 586 |
+
type=int,
|
| 587 |
+
default=1000,
|
| 588 |
+
help="Save checkpoint every N steps (default: 1000)"
|
| 589 |
+
)
|
| 590 |
+
|
| 591 |
+
args = parser.parse_args()
|
| 592 |
+
|
| 593 |
+
print("π OpenLLM Model Training")
|
| 594 |
+
print("=" * 60)
|
| 595 |
+
|
| 596 |
+
# Determine device
|
| 597 |
+
if args.device == "auto":
|
| 598 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 599 |
+
else:
|
| 600 |
+
device = args.device
|
| 601 |
+
|
| 602 |
+
print(f"Using device: {device}")
|
| 603 |
+
|
| 604 |
+
try:
|
| 605 |
+
# Create model
|
| 606 |
+
print(f"\nποΈ Creating {args.model_size} model...")
|
| 607 |
+
model = create_model(args.model_size)
|
| 608 |
+
|
| 609 |
+
# Create data loader
|
| 610 |
+
print(f"\nπ Setting up data loader...")
|
| 611 |
+
tokenizer_path = os.path.join(args.tokenizer_dir, "tokenizer.model")
|
| 612 |
+
|
| 613 |
+
data_loader = TextDataLoader(
|
| 614 |
+
data_file=args.data_file,
|
| 615 |
+
tokenizer_path=tokenizer_path,
|
| 616 |
+
seq_len=args.seq_len,
|
| 617 |
+
batch_size=args.batch_size,
|
| 618 |
+
shuffle=True
|
| 619 |
+
)
|
| 620 |
+
|
| 621 |
+
# Get data statistics
|
| 622 |
+
data_stats = data_loader.get_data_stats()
|
| 623 |
+
|
| 624 |
+
# Create trainer
|
| 625 |
+
print(f"\nπ― Setting up trainer...")
|
| 626 |
+
trainer = ModelTrainer(
|
| 627 |
+
model=model,
|
| 628 |
+
data_loader=data_loader,
|
| 629 |
+
output_dir=args.output_dir,
|
| 630 |
+
device=device,
|
| 631 |
+
learning_rate=args.learning_rate,
|
| 632 |
+
max_steps=args.max_steps,
|
| 633 |
+
warmup_steps=args.warmup_steps,
|
| 634 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 635 |
+
save_every=args.save_every
|
| 636 |
+
)
|
| 637 |
+
|
| 638 |
+
# Resume from checkpoint if specified
|
| 639 |
+
if args.resume:
|
| 640 |
+
trainer._load_checkpoint(args.resume)
|
| 641 |
+
|
| 642 |
+
# Start training
|
| 643 |
+
trainer.train()
|
| 644 |
+
|
| 645 |
+
print(f"\nπ Training completed successfully!")
|
| 646 |
+
|
| 647 |
+
except Exception as e:
|
| 648 |
+
print(f"\nβ Training failed: {e}")
|
| 649 |
+
import traceback
|
| 650 |
+
traceback.print_exc()
|
| 651 |
+
return False
|
| 652 |
+
|
| 653 |
+
return True
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
if __name__ == "__main__":
|
| 657 |
+
main()
|
training/train_tokenizer.py
ADDED
|
@@ -0,0 +1,429 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright (C) 2024 Louis Chua Bean Chong
|
| 3 |
+
#
|
| 4 |
+
# This file is part of OpenLLM.
|
| 5 |
+
#
|
| 6 |
+
# OpenLLM is dual-licensed:
|
| 7 |
+
# 1. For open source use: GNU General Public License v3.0
|
| 8 |
+
# 2. For commercial use: Commercial License (contact for details)
|
| 9 |
+
#
|
| 10 |
+
# See LICENSE and docs/LICENSES.md for full license information.
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
Train a SentencePiece tokenizer from scratch using the prepared training data.
|
| 14 |
+
|
| 15 |
+
OVERVIEW:
|
| 16 |
+
This script trains a SentencePiece tokenizer on the cleaned text data from the SQUAD dataset
|
| 17 |
+
or any other text corpus. SentencePiece is a subword tokenizer that works well for language
|
| 18 |
+
models and supports multiple languages without requiring pre-tokenization.
|
| 19 |
+
|
| 20 |
+
FEATURES:
|
| 21 |
+
- Supports BPE (Byte Pair Encoding) and Unigram tokenization algorithms
|
| 22 |
+
- Configurable vocabulary size (recommended: 8k-64k for LLMs)
|
| 23 |
+
- Handles special tokens (BOS, EOS, UNK, PAD)
|
| 24 |
+
- Outputs tokenizer model files compatible with Hugging Face
|
| 25 |
+
- Comprehensive statistics and vocabulary analysis
|
| 26 |
+
|
| 27 |
+
TOKENIZER OUTPUT:
|
| 28 |
+
- tokenizer.model: SentencePiece model file
|
| 29 |
+
- tokenizer.vocab: Human-readable vocabulary file
|
| 30 |
+
- tokenizer_config.json: Configuration for Hugging Face integration
|
| 31 |
+
|
| 32 |
+
Usage:
|
| 33 |
+
python core/src/train_tokenizer.py --input data/clean/training_data.txt --vocab_size 32000
|
| 34 |
+
|
| 35 |
+
Advanced usage:
|
| 36 |
+
python core/src/train_tokenizer.py \\
|
| 37 |
+
--input data/clean/training_data.txt \\
|
| 38 |
+
--vocab_size 32000 \\
|
| 39 |
+
--model_type bpe \\
|
| 40 |
+
--output_dir data/tokenizer/ \\
|
| 41 |
+
--character_coverage 0.9995
|
| 42 |
+
|
| 43 |
+
Requirements:
|
| 44 |
+
pip install sentencepiece
|
| 45 |
+
|
| 46 |
+
Example setup:
|
| 47 |
+
```bash
|
| 48 |
+
# If not already in virtual environment
|
| 49 |
+
python -m venv venv
|
| 50 |
+
source venv/bin/activate # Linux/macOS
|
| 51 |
+
# .\venv\Scripts\Activate.ps1 # Windows PowerShell
|
| 52 |
+
|
| 53 |
+
# Install SentencePiece
|
| 54 |
+
pip install sentencepiece
|
| 55 |
+
|
| 56 |
+
# Train tokenizer
|
| 57 |
+
python core/src/train_tokenizer.py --input data/clean/training_data.txt --vocab_size 32000
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
import argparse
|
| 63 |
+
import json
|
| 64 |
+
import os
|
| 65 |
+
import time
|
| 66 |
+
from pathlib import Path
|
| 67 |
+
from typing import Dict, Any
|
| 68 |
+
|
| 69 |
+
try:
|
| 70 |
+
import sentencepiece as spm
|
| 71 |
+
except ImportError:
|
| 72 |
+
print("ERROR: SentencePiece not installed. Run: pip install sentencepiece")
|
| 73 |
+
exit(1)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def validate_input_file(input_path: str) -> None:
|
| 77 |
+
"""
|
| 78 |
+
Validate that the input training file exists and is readable.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
input_path (str): Path to the training text file
|
| 82 |
+
|
| 83 |
+
Raises:
|
| 84 |
+
FileNotFoundError: If input file doesn't exist
|
| 85 |
+
ValueError: If input file is empty or unreadable
|
| 86 |
+
"""
|
| 87 |
+
if not os.path.exists(input_path):
|
| 88 |
+
raise FileNotFoundError(f"Training data file not found: {input_path}")
|
| 89 |
+
|
| 90 |
+
# Check file size and readability
|
| 91 |
+
file_size = os.path.getsize(input_path)
|
| 92 |
+
if file_size == 0:
|
| 93 |
+
raise ValueError(f"Training data file is empty: {input_path}")
|
| 94 |
+
|
| 95 |
+
# Test that we can read the file
|
| 96 |
+
try:
|
| 97 |
+
with open(input_path, 'r', encoding='utf-8') as f:
|
| 98 |
+
first_line = f.readline()
|
| 99 |
+
if not first_line.strip():
|
| 100 |
+
raise ValueError(f"Training data file appears to be empty or contains only whitespace")
|
| 101 |
+
except UnicodeDecodeError as e:
|
| 102 |
+
raise ValueError(f"Cannot read training data file as UTF-8: {e}")
|
| 103 |
+
|
| 104 |
+
print(f"β Input file validated: {input_path} ({file_size:,} bytes)")
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def count_training_sentences(input_path: str) -> int:
|
| 108 |
+
"""
|
| 109 |
+
Count the number of training sentences/lines in the input file.
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
input_path (str): Path to the training text file
|
| 113 |
+
|
| 114 |
+
Returns:
|
| 115 |
+
int: Number of lines in the file
|
| 116 |
+
"""
|
| 117 |
+
print("Counting training sentences...")
|
| 118 |
+
with open(input_path, 'r', encoding='utf-8') as f:
|
| 119 |
+
count = sum(1 for line in f if line.strip())
|
| 120 |
+
print(f"β Found {count:,} training sentences")
|
| 121 |
+
return count
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def train_sentencepiece_tokenizer(
|
| 125 |
+
input_path: str,
|
| 126 |
+
output_dir: str,
|
| 127 |
+
vocab_size: int = 32000,
|
| 128 |
+
model_type: str = "bpe",
|
| 129 |
+
character_coverage: float = 0.9995,
|
| 130 |
+
max_sentence_length: int = 4192,
|
| 131 |
+
input_sentence_size: int = 10000000,
|
| 132 |
+
shuffle_input_sentence: bool = True,
|
| 133 |
+
) -> Dict[str, Any]:
|
| 134 |
+
"""
|
| 135 |
+
Train a SentencePiece tokenizer with the specified parameters.
|
| 136 |
+
|
| 137 |
+
Args:
|
| 138 |
+
input_path (str): Path to training text file
|
| 139 |
+
output_dir (str): Directory to save tokenizer files
|
| 140 |
+
vocab_size (int): Target vocabulary size (recommended: 8k-64k)
|
| 141 |
+
model_type (str): Algorithm type ('bpe' or 'unigram')
|
| 142 |
+
character_coverage (float): Character coverage (0.9995 for English, 1.0 for Japanese)
|
| 143 |
+
max_sentence_length (int): Maximum sentence length in characters
|
| 144 |
+
input_sentence_size (int): Maximum number of sentences to use for training
|
| 145 |
+
shuffle_input_sentence (bool): Whether to shuffle input sentences
|
| 146 |
+
|
| 147 |
+
Returns:
|
| 148 |
+
Dict[str, Any]: Training statistics and configuration
|
| 149 |
+
"""
|
| 150 |
+
# Ensure output directory exists
|
| 151 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 152 |
+
|
| 153 |
+
# Define output paths
|
| 154 |
+
model_prefix = os.path.join(output_dir, "tokenizer")
|
| 155 |
+
|
| 156 |
+
# SentencePiece training parameters
|
| 157 |
+
train_params = [
|
| 158 |
+
f"--input={input_path}",
|
| 159 |
+
f"--model_prefix={model_prefix}",
|
| 160 |
+
f"--vocab_size={vocab_size}",
|
| 161 |
+
f"--model_type={model_type}",
|
| 162 |
+
f"--character_coverage={character_coverage}",
|
| 163 |
+
f"--max_sentence_length={max_sentence_length}",
|
| 164 |
+
f"--input_sentence_size={input_sentence_size}",
|
| 165 |
+
f"--shuffle_input_sentence={shuffle_input_sentence}",
|
| 166 |
+
|
| 167 |
+
# Special tokens for language modeling
|
| 168 |
+
"--pad_id=0", # Padding token
|
| 169 |
+
"--unk_id=1", # Unknown token
|
| 170 |
+
"--bos_id=2", # Beginning of sequence
|
| 171 |
+
"--eos_id=3", # End of sequence
|
| 172 |
+
|
| 173 |
+
# Additional useful parameters
|
| 174 |
+
"--split_by_unicode_script=true", # Better handling of mixed scripts
|
| 175 |
+
"--split_by_whitespace=true", # Split on whitespace
|
| 176 |
+
"--remove_extra_whitespaces=true", # Clean up whitespace
|
| 177 |
+
"--normalization_rule_name=identity", # Keep original text as-is
|
| 178 |
+
]
|
| 179 |
+
|
| 180 |
+
print(f"\nTraining SentencePiece tokenizer...")
|
| 181 |
+
print(f" Algorithm: {model_type.upper()}")
|
| 182 |
+
print(f" Vocabulary size: {vocab_size:,}")
|
| 183 |
+
print(f" Character coverage: {character_coverage}")
|
| 184 |
+
print(f" Output directory: {output_dir}")
|
| 185 |
+
print(f" Model files: {model_prefix}.model, {model_prefix}.vocab")
|
| 186 |
+
|
| 187 |
+
# Record training start time
|
| 188 |
+
start_time = time.time()
|
| 189 |
+
|
| 190 |
+
# Train the tokenizer
|
| 191 |
+
try:
|
| 192 |
+
spm.SentencePieceTrainer.train(" ".join(train_params))
|
| 193 |
+
training_time = time.time() - start_time
|
| 194 |
+
print(f"β Tokenizer training completed in {training_time:.1f} seconds")
|
| 195 |
+
except Exception as e:
|
| 196 |
+
raise RuntimeError(f"SentencePiece training failed: {e}")
|
| 197 |
+
|
| 198 |
+
# Verify output files were created
|
| 199 |
+
model_file = f"{model_prefix}.model"
|
| 200 |
+
vocab_file = f"{model_prefix}.vocab"
|
| 201 |
+
|
| 202 |
+
if not os.path.exists(model_file):
|
| 203 |
+
raise RuntimeError(f"Expected model file not created: {model_file}")
|
| 204 |
+
if not os.path.exists(vocab_file):
|
| 205 |
+
raise RuntimeError(f"Expected vocab file not created: {vocab_file}")
|
| 206 |
+
|
| 207 |
+
print(f"β Model file created: {model_file} ({os.path.getsize(model_file):,} bytes)")
|
| 208 |
+
print(f"β Vocab file created: {vocab_file} ({os.path.getsize(vocab_file):,} bytes)")
|
| 209 |
+
|
| 210 |
+
# Return training configuration and statistics
|
| 211 |
+
config = {
|
| 212 |
+
"model_type": model_type,
|
| 213 |
+
"vocab_size": vocab_size,
|
| 214 |
+
"character_coverage": character_coverage,
|
| 215 |
+
"max_sentence_length": max_sentence_length,
|
| 216 |
+
"training_time_seconds": training_time,
|
| 217 |
+
"input_file": input_path,
|
| 218 |
+
"output_directory": output_dir,
|
| 219 |
+
"model_file": model_file,
|
| 220 |
+
"vocab_file": vocab_file,
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
return config
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def test_tokenizer(model_path: str, test_sentences: list = None) -> None:
|
| 227 |
+
"""
|
| 228 |
+
Test the trained tokenizer on sample sentences to verify it works correctly.
|
| 229 |
+
|
| 230 |
+
Args:
|
| 231 |
+
model_path (str): Path to the trained .model file
|
| 232 |
+
test_sentences (list): Optional list of test sentences
|
| 233 |
+
"""
|
| 234 |
+
print(f"\nTesting trained tokenizer...")
|
| 235 |
+
|
| 236 |
+
# Load the trained tokenizer
|
| 237 |
+
sp = spm.SentencePieceProcessor()
|
| 238 |
+
sp.load(model_path)
|
| 239 |
+
|
| 240 |
+
# Default test sentences if none provided
|
| 241 |
+
if test_sentences is None:
|
| 242 |
+
test_sentences = [
|
| 243 |
+
"Hello, world! This is a test sentence.",
|
| 244 |
+
"The quick brown fox jumps over the lazy dog.",
|
| 245 |
+
"Machine learning and artificial intelligence are transforming technology.",
|
| 246 |
+
"SentencePiece tokenization works well for language models.",
|
| 247 |
+
]
|
| 248 |
+
|
| 249 |
+
print(f"Vocabulary size: {sp.vocab_size():,}")
|
| 250 |
+
print(f"Special tokens: PAD={sp.pad_id()}, UNK={sp.unk_id()}, BOS={sp.bos_id()}, EOS={sp.eos_id()}")
|
| 251 |
+
|
| 252 |
+
print("\nTokenization examples:")
|
| 253 |
+
for i, sentence in enumerate(test_sentences, 1):
|
| 254 |
+
# Encode to token IDs and pieces
|
| 255 |
+
token_ids = sp.encode(sentence)
|
| 256 |
+
token_pieces = sp.encode(sentence, out_type=str)
|
| 257 |
+
|
| 258 |
+
print(f"\n{i}. Input: {sentence}")
|
| 259 |
+
print(f" Tokens ({len(token_pieces)}): {token_pieces}")
|
| 260 |
+
print(f" IDs: {token_ids[:10]}{'...' if len(token_ids) > 10 else ''}")
|
| 261 |
+
|
| 262 |
+
# Test decoding
|
| 263 |
+
decoded = sp.decode(token_ids)
|
| 264 |
+
print(f" Decoded: {decoded}")
|
| 265 |
+
|
| 266 |
+
# Verify round-trip encoding/decoding
|
| 267 |
+
if decoded.strip() != sentence.strip():
|
| 268 |
+
print(f" β οΈ Warning: Decode mismatch!")
|
| 269 |
+
|
| 270 |
+
print("β Tokenizer testing completed")
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def save_huggingface_config(output_dir: str, config: Dict[str, Any]) -> None:
|
| 274 |
+
"""
|
| 275 |
+
Save a Hugging Face compatible tokenizer configuration file.
|
| 276 |
+
|
| 277 |
+
Args:
|
| 278 |
+
output_dir (str): Directory containing the tokenizer files
|
| 279 |
+
config (Dict[str, Any]): Tokenizer configuration
|
| 280 |
+
"""
|
| 281 |
+
# Create Hugging Face tokenizer config
|
| 282 |
+
hf_config = {
|
| 283 |
+
"tokenizer_class": "SentencePieceTokenizer",
|
| 284 |
+
"model_type": config["model_type"],
|
| 285 |
+
"vocab_size": config["vocab_size"],
|
| 286 |
+
"model_file": "tokenizer.model",
|
| 287 |
+
"special_tokens": {
|
| 288 |
+
"pad_token": "<pad>",
|
| 289 |
+
"unk_token": "<unk>",
|
| 290 |
+
"bos_token": "<s>",
|
| 291 |
+
"eos_token": "</s>",
|
| 292 |
+
},
|
| 293 |
+
"special_token_ids": {
|
| 294 |
+
"pad_token_id": 0,
|
| 295 |
+
"unk_token_id": 1,
|
| 296 |
+
"bos_token_id": 2,
|
| 297 |
+
"eos_token_id": 3,
|
| 298 |
+
}
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
config_path = os.path.join(output_dir, "tokenizer_config.json")
|
| 302 |
+
with open(config_path, 'w', encoding='utf-8') as f:
|
| 303 |
+
json.dump(hf_config, f, indent=2, ensure_ascii=False)
|
| 304 |
+
|
| 305 |
+
print(f"β Hugging Face config saved: {config_path}")
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def main():
|
| 309 |
+
"""Main function to handle command line arguments and orchestrate tokenizer training."""
|
| 310 |
+
parser = argparse.ArgumentParser(
|
| 311 |
+
description="Train a SentencePiece tokenizer for language model training",
|
| 312 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 313 |
+
epilog="""
|
| 314 |
+
Examples:
|
| 315 |
+
# Basic usage with SQUAD data
|
| 316 |
+
python core/src/train_tokenizer.py --input data/clean/training_data.txt --vocab_size 32000
|
| 317 |
+
|
| 318 |
+
# Advanced configuration
|
| 319 |
+
python core/src/train_tokenizer.py \\
|
| 320 |
+
--input data/clean/training_data.txt \\
|
| 321 |
+
--vocab_size 32000 \\
|
| 322 |
+
--model_type bpe \\
|
| 323 |
+
--output_dir data/tokenizer/ \\
|
| 324 |
+
--character_coverage 0.9995
|
| 325 |
+
"""
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
# Required arguments
|
| 329 |
+
parser.add_argument(
|
| 330 |
+
"--input",
|
| 331 |
+
required=True,
|
| 332 |
+
help="Path to training text file (e.g., data/clean/training_data.txt)"
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
# Optional arguments with sensible defaults
|
| 336 |
+
parser.add_argument(
|
| 337 |
+
"--vocab_size",
|
| 338 |
+
type=int,
|
| 339 |
+
default=32000,
|
| 340 |
+
help="Vocabulary size (default: 32000, recommended: 8k-64k)"
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
parser.add_argument(
|
| 344 |
+
"--model_type",
|
| 345 |
+
choices=["bpe", "unigram"],
|
| 346 |
+
default="bpe",
|
| 347 |
+
help="Tokenization algorithm (default: bpe)"
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
parser.add_argument(
|
| 351 |
+
"--output_dir",
|
| 352 |
+
default="data/tokenizer/",
|
| 353 |
+
help="Output directory for tokenizer files (default: data/tokenizer/)"
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
parser.add_argument(
|
| 357 |
+
"--character_coverage",
|
| 358 |
+
type=float,
|
| 359 |
+
default=0.9995,
|
| 360 |
+
help="Character coverage (default: 0.9995 for English)"
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
parser.add_argument(
|
| 364 |
+
"--max_sentence_length",
|
| 365 |
+
type=int,
|
| 366 |
+
default=4192,
|
| 367 |
+
help="Maximum sentence length in characters (default: 4192)"
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
parser.add_argument(
|
| 371 |
+
"--no_test",
|
| 372 |
+
action="store_true",
|
| 373 |
+
help="Skip tokenizer testing after training"
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
args = parser.parse_args()
|
| 377 |
+
|
| 378 |
+
print("π€ SentencePiece Tokenizer Training")
|
| 379 |
+
print("=" * 50)
|
| 380 |
+
|
| 381 |
+
try:
|
| 382 |
+
# Step 1: Validate input file
|
| 383 |
+
validate_input_file(args.input)
|
| 384 |
+
|
| 385 |
+
# Step 2: Count training data
|
| 386 |
+
sentence_count = count_training_sentences(args.input)
|
| 387 |
+
|
| 388 |
+
# Step 3: Train tokenizer
|
| 389 |
+
config = train_sentencepiece_tokenizer(
|
| 390 |
+
input_path=args.input,
|
| 391 |
+
output_dir=args.output_dir,
|
| 392 |
+
vocab_size=args.vocab_size,
|
| 393 |
+
model_type=args.model_type,
|
| 394 |
+
character_coverage=args.character_coverage,
|
| 395 |
+
max_sentence_length=args.max_sentence_length,
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
# Step 4: Save Hugging Face compatible config
|
| 399 |
+
save_huggingface_config(args.output_dir, config)
|
| 400 |
+
|
| 401 |
+
# Step 5: Test tokenizer (unless skipped)
|
| 402 |
+
if not args.no_test:
|
| 403 |
+
model_path = os.path.join(args.output_dir, "tokenizer.model")
|
| 404 |
+
test_tokenizer(model_path)
|
| 405 |
+
|
| 406 |
+
# Step 6: Print summary
|
| 407 |
+
print(f"\nπ Tokenizer training completed successfully!")
|
| 408 |
+
print(f"π Output directory: {args.output_dir}")
|
| 409 |
+
print(f"π Vocabulary size: {config['vocab_size']:,}")
|
| 410 |
+
print(f"β±οΈ Training time: {config['training_time_seconds']:.1f}s")
|
| 411 |
+
print(f"π Training sentences: {sentence_count:,}")
|
| 412 |
+
|
| 413 |
+
print(f"\nFiles created:")
|
| 414 |
+
print(f" β’ {config['model_file']} - SentencePiece model")
|
| 415 |
+
print(f" β’ {config['vocab_file']} - Vocabulary file")
|
| 416 |
+
print(f" β’ {os.path.join(args.output_dir, 'tokenizer_config.json')} - Hugging Face config")
|
| 417 |
+
|
| 418 |
+
print(f"\nTo use this tokenizer in your language model:")
|
| 419 |
+
print(f" import sentencepiece as spm")
|
| 420 |
+
print(f" sp = spm.SentencePieceProcessor()")
|
| 421 |
+
print(f" sp.load('{config['model_file']}')")
|
| 422 |
+
|
| 423 |
+
except Exception as e:
|
| 424 |
+
print(f"\nβ Error: {e}")
|
| 425 |
+
exit(1)
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
if __name__ == "__main__":
|
| 429 |
+
main()
|