feat: Sync training infrastructure from main repository
- app.py +960 -159
- requirements.txt +44 -19
- training/data_loader.py +112 -107
- training/evaluate_model.py +298 -311
- training/model.py +173 -158
- training/train_model.py +169 -189
- training/train_tokenizer.py +91 -91
app.py CHANGED
@@ -1,223 +1,1024 @@
|
|
[Removed side of this hunk (the previous app.py, 223 lines) is garbled in this rendering. Only fragments survive: the shebang, the "OpenLLM Training Space" module docstring with Author and License lines, imports (os, sys, gradio, pathlib), try/except import guards, helper functions returning error strings, notes on auth scripts (integrate_auth_into_training.py, setup_hf_space_auth.py, verify_space_auth.py), a model_size dropdown (small/medium/large), and the launch block (server_port=7860, share=False). The full replacement file follows on the added side.]
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
+
OpenLLM Training Space Application - Fixed with Uploaded Modules
|
| 4 |
|
| 5 |
+
This version imports OpenLLM modules from the uploaded files in the HF Space:
|
| 6 |
+
- Imports model.py and data_loader.py that were uploaded to the Space
|
| 7 |
+
- Uses OpenLLM's actual custom model architecture
|
| 8 |
+
- Compatible with OpenLLM's implementation
|
| 9 |
+
|
| 10 |
+
This application provides a complete training interface for OpenLLM models on Hugging Face Spaces.
|
| 11 |
+
It uses OpenLLM's custom GPTModel architecture instead of Hugging Face Transformers,
|
| 12 |
+
ensuring compatibility with the actual OpenLLM implementation.
|
| 13 |
+
|
| 14 |
+
Key Features:
|
| 15 |
+
- Real model training using OpenLLM's custom architecture
|
| 16 |
+
- SentencePiece tokenization for OpenLLM models
|
| 17 |
+
- Complete training pipeline with progress monitoring
|
| 18 |
+
- Automatic model saving and uploading to Hugging Face Hub
|
| 19 |
+
- Gradio 4.44.1 compatible user interface
|
| 20 |
+
|
| 21 |
+
Technical Architecture:
|
| 22 |
+
- Uses OpenLLM's GPTModel class (not Hugging Face Transformers)
|
| 23 |
+
- Imports custom modules from uploaded files in the Space
|
| 24 |
+
- Uses sentencepiece.SentencePieceProcessor() for tokenization
|
| 25 |
+
- Implements OpenLLM's training loop and optimization strategy
|
| 26 |
+
- Saves checkpoints in OpenLLM's format
|
| 27 |
|
| 28 |
Author: Louis Chua Bean Chong
|
| 29 |
+
License: GPL-3.0
|
| 30 |
+
Version: 2.1.1
|
| 31 |
+
Last Updated: 2024
|
| 32 |
"""
|
| 33 |
|
|
|
|
|
|
|
| 34 |
import gradio as gr
|
| 35 |
+
import torch
|
| 36 |
+
import torch.nn as nn
|
| 37 |
+
import os
|
| 38 |
+
import time
|
| 39 |
+
import math
|
| 40 |
+
import gc
|
| 41 |
+
from typing import Dict, Any, Optional
|
| 42 |
+
import threading
|
| 43 |
+
from dataclasses import dataclass
|
| 44 |
from pathlib import Path
|
| 45 |
|
| 46 |
+
# Import OpenLLM's custom model architecture from uploaded files
|
| 47 |
+
# These files were uploaded to the HF Space and contain OpenLLM's actual implementation
|
| 48 |
+
try:
|
| 49 |
+
# Import from the uploaded files in the HF Space
|
| 50 |
+
# model.py contains GPTModel, GPTConfig, and create_model factory function
|
| 51 |
+
from model import GPTModel, GPTConfig, create_model
|
| 52 |
+
# data_loader.py contains TextDataLoader for OpenLLM's data loading approach
|
| 53 |
+
from data_loader import TextDataLoader
|
| 54 |
+
OPENLLM_AVAILABLE = True
|
| 55 |
+
print("β
OpenLLM custom model architecture imported successfully from uploaded files")
|
| 56 |
+
print(" - GPTModel: Custom PyTorch model architecture")
|
| 57 |
+
print(" - GPTConfig: Model configuration dataclass")
|
| 58 |
+
print(" - create_model: Factory function for model creation")
|
| 59 |
+
print(" - TextDataLoader: Custom data loading implementation")
|
| 60 |
+
except ImportError as e:
|
| 61 |
+
print(f"β OpenLLM imports failed: {e}")
|
| 62 |
+
print(" This indicates the uploaded OpenLLM source files are not available")
|
| 63 |
+
print(" The training functionality will be disabled")
|
| 64 |
+
OPENLLM_AVAILABLE = False
|
| 65 |
+
|
| 66 |
+
# Try to import sentencepiece - CRITICAL for OpenLLM tokenization
|
| 67 |
+
# OpenLLM uses SentencePiece for tokenization, not Hugging Face tokenizers
|
| 68 |
+
try:
|
| 69 |
+
import sentencepiece as spm
|
| 70 |
+
SENTENCEPIECE_AVAILABLE = True
|
| 71 |
+
print(f"β
SentencePiece available: {spm.__version__}")
|
| 72 |
+
print(" - Required for OpenLLM tokenization")
|
| 73 |
+
print(" - Used for loading tokenizer.model files")
|
| 74 |
+
except ImportError:
|
| 75 |
+
SENTENCEPIECE_AVAILABLE = False
|
| 76 |
+
print("β SentencePiece not available")
|
| 77 |
+
print(" - This will prevent tokenizer loading")
|
| 78 |
+
print(" - Training functionality will be limited")
|
| 79 |
+
|
| 80 |
+
# Import other dependencies for the complete training pipeline
|
| 81 |
try:
|
| 82 |
+
from datasets import load_dataset # For loading training data from HF Hub
|
| 83 |
+
from huggingface_hub import HfApi, hf_hub_download # For model uploads and downloads
|
| 84 |
+
DEPENDENCIES_AVAILABLE = True
|
| 85 |
+
print("β
Training dependencies available")
|
| 86 |
+
print(" - datasets: For loading training data")
|
| 87 |
+
print(" - huggingface_hub: For model uploads/downloads")
|
| 88 |
except ImportError as e:
|
| 89 |
+
print(f"β Dependencies not available: {e}")
|
| 90 |
+
print(" - This will prevent dataset loading and model uploading")
|
| 91 |
+
DEPENDENCIES_AVAILABLE = False
|
| 92 |
|
| 93 |
+
@dataclass
|
| 94 |
+
class TrainingConfig:
|
| 95 |
+
"""
|
| 96 |
+
Configuration class for training parameters.
|
| 97 |
+
|
| 98 |
+
This dataclass encapsulates all the training hyperparameters and settings
|
| 99 |
+
that control the OpenLLM training process. It provides a clean interface
|
| 100 |
+
for passing configuration between different components of the training pipeline.
|
| 101 |
+
|
| 102 |
+
Attributes:
|
| 103 |
+
model_size: Size of the model to train ("small", "medium", "large")
|
| 104 |
+
max_steps: Maximum number of training iterations
|
| 105 |
+
learning_rate: Learning rate for the optimizer
|
| 106 |
+
batch_size: Number of samples per training batch
|
| 107 |
+
output_dir: Directory to save trained models and checkpoints
|
| 108 |
+
save_steps: Frequency of checkpoint saving (every N steps)
|
| 109 |
+
logging_steps: Frequency of progress logging (every N steps)
|
| 110 |
+
warmup_steps: Number of warmup steps for learning rate scheduling
|
| 111 |
+
gradient_accumulation_steps: Number of steps to accumulate gradients
|
| 112 |
+
"""
|
| 113 |
+
model_size: str
|
| 114 |
+
max_steps: int
|
| 115 |
+
learning_rate: float
|
| 116 |
+
batch_size: int
|
| 117 |
+
output_dir: str = "./openllm-trained"
|
| 118 |
+
save_steps: int = 100
|
| 119 |
+
logging_steps: int = 10
|
| 120 |
+
warmup_steps: int = 50
|
| 121 |
+
gradient_accumulation_steps: int = 4
|
| 122 |
|
| 123 |
+
class OpenLLMTrainer:
|
| 124 |
+
"""
|
| 125 |
+
Complete training implementation using OpenLLM's actual architecture.
|
| 126 |
+
|
| 127 |
+
This class handles the entire training pipeline including:
|
| 128 |
+
- Model loading using OpenLLM's custom GPTModel
|
| 129 |
+
- Tokenizer loading using sentencepiece.SentencePieceProcessor()
|
| 130 |
+
- Dataset preparation using OpenLLM's TextDataLoader
|
| 131 |
+
- Training execution using OpenLLM's approach
|
| 132 |
+
- Model saving and uploading to Hugging Face Hub
|
| 133 |
+
|
| 134 |
+
The trainer implements OpenLLM's actual training methodology rather than
|
| 135 |
+
using Hugging Face Transformers, ensuring compatibility with the real
|
| 136 |
+
OpenLLM implementation.
|
| 137 |
+
|
| 138 |
+
Key Features:
|
| 139 |
+
- Custom model architecture (GPTModel, not PreTrainedModel)
|
| 140 |
+
- SentencePiece tokenization (not Hugging Face tokenizers)
|
| 141 |
+
- OpenLLM's training loop and optimization strategy
|
| 142 |
+
- Gradient accumulation for memory efficiency
|
| 143 |
+
- Learning rate scheduling with warmup
|
| 144 |
+
- Automatic checkpoint saving and model uploading
|
| 145 |
+
"""
|
| 146 |
|
| 147 |
+
def __init__(self):
|
| 148 |
+
"""
|
| 149 |
+
Initialize the trainer with default settings.
|
| 150 |
+
|
| 151 |
+
Sets up the trainer with default values and initializes the Hugging Face
|
| 152 |
+
API for model uploading. All components start as None and are initialized
|
| 153 |
+
during the training process.
|
| 154 |
+
"""
|
| 155 |
+
# Core training components - initialized during training
|
| 156 |
+
self.model = None # OpenLLM's GPTModel instance
|
| 157 |
+
self.tokenizer = None # SentencePieceProcessor instance
|
| 158 |
+
self.data_loader = None # OpenLLM's TextDataLoader instance
|
| 159 |
+
self.optimizer = None # PyTorch optimizer (AdamW)
|
| 160 |
+
self.scheduler = None # Learning rate scheduler
|
| 161 |
+
|
| 162 |
+
# Training state management
|
| 163 |
+
self.is_training = False # Flag to track training status
|
| 164 |
+
self.tokenizer_path = None # Path to the tokenizer.model file
|
| 165 |
+
|
| 166 |
+
# Progress tracking for UI updates
|
| 167 |
+
self.training_progress = {
|
| 168 |
+
"status": "Ready", # Current training status
|
| 169 |
+
"current_step": 0, # Current training step
|
| 170 |
+
"total_steps": 0, # Total steps to complete
|
| 171 |
+
"loss": 0.0, # Current training loss
|
| 172 |
+
"learning_rate": 0.0 # Current learning rate
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
# Initialize Hugging Face API for model uploading
|
| 176 |
+
# This allows the trained model to be automatically uploaded to HF Hub
|
| 177 |
try:
|
| 178 |
+
self.hf_api = HfApi()
|
| 179 |
+
print("β
Hugging Face API initialized for model uploading")
|
| 180 |
+
except Exception as e:
|
| 181 |
+
print(f"Failed to initialize HF API: {e}")
|
| 182 |
+
print(" - Model uploading will be disabled")
|
| 183 |
+
self.hf_api = None
|
| 184 |
+
|
| 185 |
+
def load_model_and_tokenizer(self, model_size: str) -> str:
|
| 186 |
+
"""
|
| 187 |
+
Load the pre-trained OpenLLM model and tokenizer using OpenLLM's approach.
|
| 188 |
+
|
| 189 |
+
This method implements OpenLLM's actual model loading strategy:
|
| 190 |
+
1. Creates a new GPTModel using OpenLLM's factory function
|
| 191 |
+
2. Downloads the tokenizer.model file from Hugging Face Hub
|
| 192 |
+
3. Loads the tokenizer using SentencePieceProcessor
|
| 193 |
+
4. Stores both components for use in training
|
| 194 |
+
|
| 195 |
+
This approach differs from Hugging Face Transformers because:
|
| 196 |
+
- Uses OpenLLM's custom GPTModel (not AutoModelForCausalLM)
|
| 197 |
+
- Uses SentencePiece directly (not AutoTokenizer)
|
| 198 |
+
- Downloads specific files rather than using from_pretrained()
|
| 199 |
+
|
| 200 |
+
Args:
|
| 201 |
+
model_size: Size of the model to load ("small", "medium", "large")
|
| 202 |
+
Determines which pre-trained model to download
|
| 203 |
|
| 204 |
+
Returns:
|
| 205 |
+
Status message indicating success or failure
|
| 206 |
+
Success: "β
Successfully loaded OpenLLM {model_size} model with custom architecture"
|
| 207 |
+
Failure: "β Failed to load OpenLLM model and tokenizer: {error details}"
|
| 208 |
+
"""
|
| 209 |
+
try:
|
| 210 |
+
# Verify OpenLLM modules are available
|
| 211 |
+
if not OPENLLM_AVAILABLE:
|
| 212 |
+
return "β OpenLLM custom model architecture not available"
|
| 213 |
|
| 214 |
+
print(f"π Loading OpenLLM {model_size} model using custom architecture...")
|
| 215 |
+
print(f" - Using OpenLLM's create_model factory function")
|
| 216 |
+
print(f" - Not using Hugging Face Transformers")
|
| 217 |
|
| 218 |
+
# Step 1: Create model using OpenLLM's factory function
|
| 219 |
+
# This creates a fresh GPTModel instance with the specified size
|
| 220 |
+
try:
|
| 221 |
+
self.model = create_model(model_size)
|
| 222 |
+
print(f"β
OpenLLM {model_size} model created: {type(self.model).__name__}")
|
| 223 |
+
print(f" - Model type: {type(self.model).__name__}")
|
| 224 |
+
print(f" - Parameters: {self.model.get_num_params():,}")
|
| 225 |
+
print(f" - Architecture: Custom GPTModel (not PreTrainedModel)")
|
| 226 |
+
except Exception as e:
|
| 227 |
+
print(f"β Failed to create model: {e}")
|
| 228 |
+
return f"β Failed to create OpenLLM model: {str(e)}"
|
| 229 |
|
| 230 |
+
# Step 2: Load tokenizer using sentencepiece
|
| 231 |
+
# OpenLLM uses SentencePiece directly, not Hugging Face tokenizers
|
| 232 |
+
try:
|
| 233 |
+
print("π Loading tokenizer using sentencepiece.SentencePieceProcessor()...")
|
| 234 |
+
print(" - Using SentencePiece directly (not AutoTokenizer)")
|
| 235 |
+
print(" - Downloading tokenizer.model from Hugging Face Hub")
|
| 236 |
+
|
| 237 |
+
# Download tokenizer.model from HF Hub
|
| 238 |
+
# This is the actual tokenizer file used by OpenLLM models
|
| 239 |
+
model_name = f"lemms/openllm-{model_size}-extended-7k"
|
| 240 |
+
tokenizer_path = hf_hub_download(
|
| 241 |
+
repo_id=model_name,
|
| 242 |
+
filename="tokenizer.model" # Specific file name for OpenLLM
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
print(f"β
Tokenizer downloaded to: {tokenizer_path}")
|
| 246 |
+
print(f" - Source: {model_name}")
|
| 247 |
+
print(f" - File: tokenizer.model")
|
| 248 |
+
|
| 249 |
+
# Create SentencePieceProcessor and load the tokenizer
|
| 250 |
+
# This is OpenLLM's actual tokenization approach
|
| 251 |
+
sp_processor = spm.SentencePieceProcessor()
|
| 252 |
+
sp_processor.load(tokenizer_path)
|
| 253 |
|
| 254 |
+
# Store tokenizer and its path separately
|
| 255 |
+
# We need the path for the TextDataLoader later
|
| 256 |
+
self.tokenizer = sp_processor
|
| 257 |
+
self.tokenizer_path = tokenizer_path # Store the path separately
|
| 258 |
+
|
| 259 |
+
print(f"β
Tokenizer loaded successfully using SentencePieceProcessor")
|
| 260 |
+
print(f" - Vocabulary size: {sp_processor.vocab_size()}")
|
| 261 |
+
print(f" - Tokenizer path: {tokenizer_path}")
|
| 262 |
+
print(f" - Tokenizer type: {type(sp_processor).__name__}")
|
| 263 |
+
|
| 264 |
+
except Exception as e:
|
| 265 |
+
print(f"β Failed to load tokenizer: {e}")
|
| 266 |
+
return f"β Failed to load OpenLLM tokenizer: {str(e)}"
|
| 267 |
+
|
| 268 |
+
return f"β
Successfully loaded OpenLLM {model_size} model with custom architecture"
|
| 269 |
+
|
| 270 |
except Exception as e:
|
| 271 |
+
return f"β Failed to load OpenLLM model and tokenizer: {str(e)}"
|
| 272 |
|
| 273 |
+
def prepare_dataset(self) -> str:
|
| 274 |
+
"""
|
| 275 |
+
Load and prepare the training dataset using OpenLLM's approach.
|
| 276 |
+
|
| 277 |
+
This method implements OpenLLM's data preparation strategy:
|
| 278 |
+
1. Loads training data from Hugging Face Hub dataset
|
| 279 |
+
2. Creates a temporary text file for OpenLLM's TextDataLoader
|
| 280 |
+
3. Initializes OpenLLM's TextDataLoader with the tokenizer
|
| 281 |
+
4. Prepares the data for training
|
| 282 |
+
|
| 283 |
+
OpenLLM's approach differs from Hugging Face because:
|
| 284 |
+
- Uses a simple text file format (not tokenized datasets)
|
| 285 |
+
- Uses OpenLLM's TextDataLoader (not Hugging Face datasets)
|
| 286 |
+
- Tokenization happens on-the-fly during training
|
| 287 |
+
|
| 288 |
+
Returns:
|
| 289 |
+
Status message indicating success or failure
|
| 290 |
+
Success: "β
Successfully prepared dataset with {count} samples"
|
| 291 |
+
Failure: "β Failed to prepare dataset: {error details}"
|
| 292 |
+
"""
|
| 293 |
try:
|
| 294 |
+
# Verify dependencies are available
|
| 295 |
+
if not DEPENDENCIES_AVAILABLE:
|
| 296 |
+
return "β Required dependencies not available"
|
| 297 |
|
| 298 |
+
print("π Loading training dataset...")
|
| 299 |
+
print(" - Loading from Hugging Face Hub dataset")
|
| 300 |
+
print(" - Using OpenLLM's data preparation approach")
|
| 301 |
|
| 302 |
+
# Load dataset from HF Hub
|
| 303 |
+
# This contains the training text data for continuing model training
|
| 304 |
+
dataset = load_dataset("lemms/openllm-training-data")
|
| 305 |
+
print(f"β
Dataset loaded: {len(dataset['train'])} samples")
|
| 306 |
+
print(f" - Dataset: lemms/openllm-training-data")
|
| 307 |
+
print(f" - Samples: {len(dataset['train'])}")
|
| 308 |
+
|
| 309 |
+
# Create temporary data file for OpenLLM's TextDataLoader
|
| 310 |
+
# OpenLLM expects a simple text file with one text sample per line
|
| 311 |
+
temp_data_file = "temp_training_data.txt"
|
| 312 |
+
with open(temp_data_file, 'w', encoding='utf-8') as f:
|
| 313 |
+
for item in dataset['train']:
|
| 314 |
+
f.write(item['text'] + '\n')
|
| 315 |
+
|
| 316 |
+
print(f"β
Temporary data file created: {temp_data_file}")
|
| 317 |
+
print(f" - Format: One text sample per line")
|
| 318 |
+
print(f" - Encoding: UTF-8")
|
| 319 |
+
|
| 320 |
+
# Create OpenLLM's TextDataLoader
|
| 321 |
+
# This is OpenLLM's custom data loading implementation
|
| 322 |
+
try:
|
| 323 |
+
# Use the stored tokenizer path instead of trying to access model_file_path
|
| 324 |
+
# SentencePieceProcessor doesn't have a model_file_path attribute
|
| 325 |
+
tokenizer_path = self.tokenizer_path # Use the stored path
|
| 326 |
+
|
| 327 |
+
print(f"π Creating OpenLLM TextDataLoader...")
|
| 328 |
+
print(f" - Data file: {temp_data_file}")
|
| 329 |
+
print(f" - Tokenizer path: {tokenizer_path}")
|
| 330 |
+
print(f" - Sequence length: 512")
|
| 331 |
+
print(f" - Batch size: 4 (will be overridden by training config)")
|
| 332 |
+
|
| 333 |
+
self.data_loader = TextDataLoader(
|
| 334 |
+
data_file=temp_data_file,
|
| 335 |
+
tokenizer_path=tokenizer_path,
|
| 336 |
+
seq_len=512, # Maximum sequence length for training
|
| 337 |
+
batch_size=4, # Will be overridden by training config
|
| 338 |
+
shuffle=True # Shuffle data for better training
|
| 339 |
)
|
| 340 |
+
|
| 341 |
+
print(f"β
OpenLLM TextDataLoader created successfully")
|
| 342 |
+
print(f" - DataLoader type: {type(self.data_loader).__name__}")
|
| 343 |
+
print(f" - Uses OpenLLM's custom implementation")
|
| 344 |
+
|
| 345 |
+
except Exception as e:
|
| 346 |
+
print(f"β Failed to create TextDataLoader: {e}")
|
| 347 |
+
return f"β Failed to create data loader: {str(e)}"
|
| 348 |
+
|
| 349 |
+
return f"β
Successfully prepared dataset with {len(dataset['train'])} samples"
|
| 350 |
+
|
| 351 |
+
except Exception as e:
|
| 352 |
+
return f"β Failed to prepare dataset: {str(e)}"
|
| 353 |
+
|
| 354 |
+
def setup_training(self, config: TrainingConfig) -> str:
|
| 355 |
+
"""
|
| 356 |
+
Set up the training configuration using OpenLLM's approach.
|
| 357 |
+
|
| 358 |
+
This method configures the training environment with:
|
| 359 |
+
1. Output directory creation
|
| 360 |
+
2. Optimizer setup with weight decay groups
|
| 361 |
+
3. Learning rate scheduler with warmup
|
| 362 |
+
4. Training hyperparameters
|
| 363 |
+
|
| 364 |
+
The setup follows OpenLLM's training methodology:
|
| 365 |
+
- Uses AdamW optimizer with weight decay
|
| 366 |
+
- Implements learning rate warmup followed by cosine annealing
|
| 367 |
+
- Separates parameters for different weight decay rates
|
| 368 |
+
- Uses gradient clipping for stability
|
| 369 |
+
|
| 370 |
+
Args:
|
| 371 |
+
config: Training configuration object containing all hyperparameters
|
| 372 |
+
|
| 373 |
+
Returns:
|
| 374 |
+
Status message indicating success or failure
|
| 375 |
+
Success: "β
Training setup completed successfully"
|
| 376 |
+
Failure: "β Failed to setup training: {error details}"
|
| 377 |
+
"""
|
| 378 |
+
try:
|
| 379 |
+
print("π Setting up training configuration...")
|
| 380 |
+
print(f" - Output directory: {config.output_dir}")
|
| 381 |
+
print(f" - Learning rate: {config.learning_rate}")
|
| 382 |
+
print(f" - Max steps: {config.max_steps}")
|
| 383 |
+
|
| 384 |
+
# Create output directory for saving models and checkpoints
|
| 385 |
+
os.makedirs(config.output_dir, exist_ok=True)
|
| 386 |
+
print(f"β
Output directory created: {config.output_dir}")
|
| 387 |
|
| 388 |
+
# Set up optimizer (AdamW with weight decay)
|
| 389 |
+
# This follows OpenLLM's optimization strategy
|
| 390 |
+
print("π Setting up AdamW optimizer with weight decay...")
|
| 391 |
|
| 392 |
+
# Separate parameters for different weight decay rates
|
| 393 |
+
# This is a common practice for transformer training
|
| 394 |
+
decay_params = [] # Parameters that should have weight decay
|
| 395 |
+
no_decay_params = [] # Parameters that should not have weight decay
|
| 396 |
+
|
| 397 |
+
for name, param in self.model.named_parameters():
|
| 398 |
+
if not param.requires_grad:
|
| 399 |
+
continue
|
| 400 |
|
| 401 |
+
# Apply weight decay to all parameters except biases and layer norm weights
|
| 402 |
+
if len(param.shape) == 1 or name.endswith('.bias'):
|
| 403 |
+
no_decay_params.append(param)
|
| 404 |
+
else:
|
| 405 |
+
decay_params.append(param)
|
| 406 |
+
|
| 407 |
+
# Create parameter groups with different weight decay rates
|
| 408 |
+
param_groups = [
|
| 409 |
+
{'params': decay_params, 'weight_decay': 0.01}, # 1% weight decay
|
| 410 |
+
{'params': no_decay_params, 'weight_decay': 0.0} # No weight decay
|
| 411 |
+
]
|
| 412 |
+
|
| 413 |
+
print(f" - Decay parameters: {len(decay_params)}")
|
| 414 |
+
print(f" - No-decay parameters: {len(no_decay_params)}")
|
| 415 |
+
|
| 416 |
+
# Initialize AdamW optimizer with OpenLLM's recommended settings
|
| 417 |
+
self.optimizer = torch.optim.AdamW(
|
| 418 |
+
param_groups,
|
| 419 |
+
lr=config.learning_rate,
|
| 420 |
+
betas=(0.9, 0.95), # Beta values for momentum
|
| 421 |
+
eps=1e-8 # Epsilon for numerical stability
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
print(f"β
AdamW optimizer configured")
|
| 425 |
+
print(f" - Learning rate: {config.learning_rate}")
|
| 426 |
+
print(f" - Betas: (0.9, 0.95)")
|
| 427 |
+
print(f" - Epsilon: 1e-8")
|
| 428 |
+
|
| 429 |
+
# Set up learning rate scheduler
|
| 430 |
+
# OpenLLM uses a warmup followed by cosine annealing
|
| 431 |
+
print("π Setting up learning rate scheduler...")
|
| 432 |
+
|
| 433 |
+
# Warmup scheduler: linearly increase LR from 1% to 100%
|
| 434 |
+
warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
|
| 435 |
+
self.optimizer,
|
| 436 |
+
start_factor=0.01, # Start at 1% of target LR
|
| 437 |
+
end_factor=1.0, # End at 100% of target LR
|
| 438 |
+
total_iters=config.warmup_steps
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
# Main scheduler: cosine annealing after warmup
|
| 442 |
+
main_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
| 443 |
+
self.optimizer,
|
| 444 |
+
T_max=config.max_steps - config.warmup_steps # Duration of cosine annealing
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
# Combine warmup and main schedulers
|
| 448 |
+
self.scheduler = torch.optim.lr_scheduler.SequentialLR(
|
| 449 |
+
self.optimizer,
|
| 450 |
+
schedulers=[warmup_scheduler, main_scheduler],
|
| 451 |
+
milestones=[config.warmup_steps] # Switch to main scheduler after warmup
|
| 452 |
+
)
|
| 453 |
+
|
| 454 |
+
print(f"β
Learning rate scheduler configured")
|
| 455 |
+
print(f" - Warmup steps: {config.warmup_steps}")
|
| 456 |
+
print(f" - Total steps: {config.max_steps}")
|
| 457 |
+
print(f" - Schedule: Linear warmup β Cosine annealing")
|
| 458 |
+
|
| 459 |
+
print("β
Training setup completed successfully")
|
| 460 |
+
return f"β
Training setup completed successfully"
|
| 461 |
+
|
| 462 |
except Exception as e:
|
| 463 |
+
return f"β Failed to setup training: {str(e)}"
|
| 464 |
|
| 465 |
+
def train_model(self, config: TrainingConfig, progress_callback=None) -> str:
|
| 466 |
+
"""
|
| 467 |
+
Execute the actual model training using OpenLLM's approach.
|
| 468 |
+
|
| 469 |
+
This method implements OpenLLM's training loop:
|
| 470 |
+
1. Sets up training mode and progress tracking
|
| 471 |
+
2. Iterates through data batches using OpenLLM's TextDataLoader
|
| 472 |
+
3. Performs forward pass, loss computation, and backward pass
|
| 473 |
+
4. Implements gradient accumulation for memory efficiency
|
| 474 |
+
5. Updates model parameters and learning rate
|
| 475 |
+
6. Saves checkpoints and logs progress
|
| 476 |
+
|
| 477 |
+
The training loop follows OpenLLM's methodology:
|
| 478 |
+
- Uses OpenLLM's GPTModel forward pass (returns logits and loss)
|
| 479 |
+
- Implements gradient accumulation for effective larger batch sizes
|
| 480 |
+
- Uses gradient clipping for training stability
|
| 481 |
+
- Saves checkpoints in OpenLLM's format
|
| 482 |
+
- Updates progress for UI monitoring
|
| 483 |
+
|
| 484 |
+
Args:
|
| 485 |
+
config: Training configuration object containing hyperparameters
|
| 486 |
+
progress_callback: Optional callback function for progress updates
|
| 487 |
+
(Not used in current implementation)
|
| 488 |
+
|
| 489 |
+
Returns:
|
| 490 |
+
Status message indicating success or failure
|
| 491 |
+
Success: "β
Training completed successfully! Final step: {step}"
|
| 492 |
+
Failure: "β Training failed: {error details}"
|
| 493 |
+
"""
|
| 494 |
try:
|
| 495 |
+
# Set training state
|
| 496 |
+
self.is_training = True
|
| 497 |
+
self.training_progress["status"] = "Training"
|
| 498 |
+
self.training_progress["total_steps"] = config.max_steps
|
| 499 |
|
| 500 |
+
print(f"π Starting OpenLLM training for {config.max_steps} steps...")
|
| 501 |
+
print(f" - Model: {type(self.model).__name__}")
|
| 502 |
+
print(f" - DataLoader: {type(self.data_loader).__name__}")
|
| 503 |
+
print(f" - Optimizer: {type(self.optimizer).__name__}")
|
| 504 |
+
print(f" - Gradient accumulation: {config.gradient_accumulation_steps}")
|
| 505 |
|
| 506 |
+
# Training loop using OpenLLM's approach
|
| 507 |
+
self.model.train() # Set model to training mode
|
| 508 |
+
accumulated_loss = 0.0 # Track loss across accumulation steps
|
| 509 |
+
self.optimizer.zero_grad() # Clear gradients
|
| 510 |
|
| 511 |
+
step = 0 # Current training step
|
| 512 |
+
for batch_idx, (input_ids, target_ids) in enumerate(self.data_loader):
|
| 513 |
+
# Check if we've reached the maximum number of steps
|
| 514 |
+
if step >= config.max_steps:
|
| 515 |
+
break
|
| 516 |
+
|
| 517 |
+
# Forward pass (model computes loss internally when targets provided)
|
| 518 |
+
# OpenLLM's GPTModel returns both logits and loss
|
| 519 |
+
logits, loss = self.model(input_ids, target_ids)
|
| 520 |
+
|
| 521 |
+
# Scale loss for gradient accumulation
|
| 522 |
+
# This allows us to simulate larger batch sizes
|
| 523 |
+
loss = loss / config.gradient_accumulation_steps
|
| 524 |
+
accumulated_loss += loss.item()
|
| 525 |
+
|
| 526 |
+
# Backward pass - compute gradients
|
| 527 |
+
loss.backward()
|
| 528 |
+
|
| 529 |
+
# Update weights every gradient_accumulation_steps
|
| 530 |
+
if (batch_idx + 1) % config.gradient_accumulation_steps == 0:
|
| 531 |
+
# Clip gradients for training stability
|
| 532 |
+
# This prevents exploding gradients
|
| 533 |
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
|
| 534 |
+
|
| 535 |
+
# Update parameters using the optimizer
|
| 536 |
+
self.optimizer.step()
|
| 537 |
+
|
| 538 |
+
# Update learning rate using the scheduler
|
| 539 |
+
self.scheduler.step()
|
| 540 |
+
|
| 541 |
+
# Clear gradients for the next accumulation cycle
|
| 542 |
+
self.optimizer.zero_grad()
|
| 543 |
+
|
| 544 |
+
# Update step count
|
| 545 |
+
step += 1
|
| 546 |
+
|
| 547 |
+
# Update progress for UI monitoring
|
| 548 |
+
self.training_progress["current_step"] = step
|
| 549 |
+
self.training_progress["loss"] = accumulated_loss
|
| 550 |
+
self.training_progress["learning_rate"] = self.scheduler.get_last_lr()[0]
|
| 551 |
+
|
| 552 |
+
# Log progress at specified intervals
|
| 553 |
+
if step % config.logging_steps == 0:
|
| 554 |
+
current_lr = self.scheduler.get_last_lr()[0]
|
| 555 |
+
print(f"Step {step}/{config.max_steps} | Loss: {accumulated_loss:.4f} | LR: {current_lr:.2e}")
|
| 556 |
+
|
| 557 |
+
# Save checkpoint at specified intervals
|
| 558 |
+
if step % config.save_steps == 0:
|
| 559 |
+
self._save_checkpoint(config.output_dir, step)
|
| 560 |
+
print(f"πΎ Checkpoint saved at step {step}")
|
| 561 |
+
|
| 562 |
+
# Reset accumulated loss for the next accumulation cycle
|
| 563 |
+
accumulated_loss = 0.0
|
| 564 |
+
|
| 565 |
+
# Clean up memory periodically
|
| 566 |
+
if step % 100 == 0:
|
| 567 |
+
gc.collect()
|
| 568 |
+
print(f"π§Ή Memory cleanup at step {step}")
|
| 569 |
|
| 570 |
+
# Save final checkpoint
|
| 571 |
+
self._save_checkpoint(config.output_dir, step, is_best=True)
|
| 572 |
+
print(f"πΎ Final checkpoint saved at step {step}")
|
| 573 |
+
|
| 574 |
+
# Update final progress
|
| 575 |
+
self.training_progress["status"] = "Completed"
|
| 576 |
+
self.training_progress["current_step"] = step
|
| 577 |
|
| 578 |
+
print(f"β
Training completed! Final step: {step}")
|
| 579 |
+
print(f" - Total steps completed: {step}")
|
| 580 |
+
print(f" - Final loss: {self.training_progress['loss']:.4f}")
|
| 581 |
+
print(f" - Final learning rate: {self.training_progress['learning_rate']:.2e}")
|
| 582 |
|
| 583 |
+
return f"β
Training completed successfully! Final step: {step}"
|
| 584 |
|
| 585 |
except Exception as e:
|
| 586 |
+
self.training_progress["status"] = "Failed"
|
| 587 |
+
print(f"β Training failed: {e}")
|
| 588 |
+
print(f" - Error occurred during training")
|
| 589 |
+
print(f" - Training state: {self.training_progress['status']}")
|
| 590 |
+
return f"β Training failed: {str(e)}"
|
| 591 |
+
finally:
|
| 592 |
+
self.is_training = False
|
| 593 |
|
| 594 |
+
def _save_checkpoint(self, output_dir: str, step: int, is_best: bool = False) -> None:
|
| 595 |
+
"""
|
| 596 |
+
Save model checkpoint using OpenLLM's approach.
|
| 597 |
+
|
| 598 |
+
This method saves the model state in OpenLLM's checkpoint format:
|
| 599 |
+
- Model state dictionary
|
| 600 |
+
- Optimizer state dictionary
|
| 601 |
+
- Scheduler state dictionary
|
| 602 |
+
- Model configuration
|
| 603 |
+
- Training step information
|
| 604 |
+
|
| 605 |
+
The checkpoint format is compatible with OpenLLM's loading mechanism
|
| 606 |
+
and can be used to resume training or load the model for inference.
|
| 607 |
|
| 608 |
+
Args:
|
| 609 |
+
output_dir: Directory to save the checkpoint
|
| 610 |
+
step: Current training step number
|
| 611 |
+
is_best: Whether this is the best model so far
|
| 612 |
+
"""
|
| 613 |
+
try:
|
| 614 |
+
# Create checkpoint dictionary with all necessary components
|
| 615 |
+
checkpoint = {
|
| 616 |
+
'step': step, # Current training step
|
| 617 |
+
'model_state_dict': self.model.state_dict(), # Model parameters
|
| 618 |
+
'optimizer_state_dict': self.optimizer.state_dict(), # Optimizer state
|
| 619 |
+
'scheduler_state_dict': self.scheduler.state_dict(), # Scheduler state
|
| 620 |
+
'config': self.model.config.__dict__ # Model configuration
|
| 621 |
+
}
|
| 622 |
+
|
| 623 |
+
# Save latest checkpoint
|
| 624 |
+
checkpoint_path = os.path.join(output_dir, f"checkpoint_step_{step}.pt")
|
| 625 |
+
torch.save(checkpoint, checkpoint_path)
|
| 626 |
+
|
| 627 |
+
# Save best checkpoint if this is the best model
|
| 628 |
+
if is_best:
|
| 629 |
+
best_path = os.path.join(output_dir, "best_model.pt")
|
| 630 |
+
torch.save(checkpoint, best_path)
|
| 631 |
+
print(f"πΎ Best model saved: {best_path}")
|
| 632 |
+
|
| 633 |
+
print(f"πΎ Checkpoint saved: {checkpoint_path}")
|
| 634 |
+
|
| 635 |
+
except Exception as e:
|
| 636 |
+
print(f"β Failed to save checkpoint: {e}")
|
| 637 |
+
|
| 638 |
+
def save_and_upload_model(self, config: TrainingConfig) -> str:
|
| 639 |
+
"""
|
| 640 |
+
Save the trained model and upload it to Hugging Face Hub.
|
| 641 |
|
| 642 |
+
This method completes the training pipeline by:
|
| 643 |
+
1. Saving the final model checkpoint
|
| 644 |
+
2. Copying the tokenizer files
|
| 645 |
+
3. Uploading the complete model to Hugging Face Hub
|
| 646 |
+
4. Creating a new model repository for the trained model
|
| 647 |
|
| 648 |
+
The uploaded model will be available at:
|
| 649 |
+
https://huggingface.co/lemms/openllm-{size}-extended-8k
|
| 650 |
|
| 651 |
+
Args:
|
| 652 |
+
config: Training configuration object
|
| 653 |
+
|
| 654 |
+
Returns:
|
| 655 |
+
Status message indicating success or failure
|
| 656 |
+
Success: "β
Model saved and uploaded to https://huggingface.co/{repo_id}"
|
| 657 |
+
Failure: "β Failed to save/upload model: {error details}"
|
| 658 |
+
"""
|
| 659 |
+
try:
|
| 660 |
+
print("π Saving trained model...")
|
| 661 |
+
print(f" - Output directory: {config.output_dir}")
|
| 662 |
+
print(f" - Model size: {config.model_size}")
|
| 663 |
+
|
| 664 |
+
# Save the final model checkpoint
|
| 665 |
+
self._save_checkpoint(config.output_dir, config.max_steps, is_best=True)
|
| 666 |
+
|
| 667 |
+
# Save tokenizer files
|
| 668 |
+
# Create a tokenizer directory within the output directory
|
| 669 |
+
tokenizer_dir = os.path.join(config.output_dir, "tokenizer")
|
| 670 |
+
os.makedirs(tokenizer_dir, exist_ok=True)
|
| 671 |
+
|
| 672 |
+
# Copy the tokenizer.model file using the stored path
|
| 673 |
+
# This ensures the tokenizer is included with the model
|
| 674 |
+
import shutil
|
| 675 |
+
shutil.copy2(self.tokenizer_path, os.path.join(tokenizer_dir, "tokenizer.model"))
|
| 676 |
+
|
| 677 |
+
print("β
Model saved locally")
|
| 678 |
+
print(f" - Model checkpoint: {config.output_dir}/best_model.pt")
|
| 679 |
+
print(f" - Tokenizer: {tokenizer_dir}/tokenizer.model")
|
| 680 |
+
|
| 681 |
+
# Generate model name for upload
|
| 682 |
+
# The naming convention follows: openllm-{size}-extended-8k
|
| 683 |
+
model_name = f"openllm-{config.model_size}-extended-8k"
|
| 684 |
+
repo_id = f"lemms/{model_name}"
|
| 685 |
+
|
| 686 |
+
# Upload to Hugging Face Hub
|
| 687 |
+
if self.hf_api:
|
| 688 |
+
print(f"π Uploading model to {repo_id}...")
|
| 689 |
+
print(f" - Repository: {repo_id}")
|
| 690 |
+
print(f" - Type: model")
|
| 691 |
+
print(f" - Source: {config.output_dir}")
|
| 692 |
+
|
| 693 |
+
# Create the repository first if it doesn't exist
|
| 694 |
+
try:
|
| 695 |
+
from huggingface_hub import create_repo
|
| 696 |
+
create_repo(
|
| 697 |
+
repo_id=repo_id,
|
| 698 |
+
repo_type="model",
|
| 699 |
+
exist_ok=True,
|
| 700 |
+
private=False
|
| 701 |
+
)
|
| 702 |
+
print(f"β
Repository {repo_id} ready for upload")
|
| 703 |
+
except Exception as create_error:
|
| 704 |
+
print(f"β οΈ Repository creation warning: {create_error}")
|
| 705 |
+
print(" Continuing with upload attempt...")
|
| 706 |
+
|
| 707 |
+
# Upload model files to Hugging Face Hub
|
| 708 |
+
# This creates a new model repository with all the files
|
| 709 |
+
self.hf_api.upload_folder(
|
| 710 |
+
folder_path=config.output_dir,
|
| 711 |
+
repo_id=repo_id,
|
| 712 |
+
repo_type="model",
|
| 713 |
+
commit_message=f"Add trained OpenLLM {config.model_size} model (8k steps)"
|
| 714 |
+
)
|
| 715 |
+
|
| 716 |
+
print(f"β
Model uploaded successfully to {repo_id}")
|
| 717 |
+
print(f" - Available at: https://huggingface.co/{repo_id}")
|
| 718 |
+
return f"β
Model saved and uploaded to https://huggingface.co/{repo_id}"
|
| 719 |
+
else:
|
| 720 |
+
print("β οΈ Hugging Face API not available - model saved locally only")
|
| 721 |
+
return f"β
Model saved locally to {config.output_dir}"
|
| 722 |
+
|
| 723 |
+
except Exception as e:
|
| 724 |
+
print(f"β Failed to save/upload model: {e}")
|
| 725 |
+
return f"β Failed to save/upload model: {str(e)}"
|
| 726 |
+
|
| 727 |
+
def get_training_progress(self) -> Dict[str, Any]:
|
| 728 |
+
"""
|
| 729 |
+
Get current training progress information.
|
| 730 |
+
|
| 731 |
+
This method returns a copy of the current training progress
|
| 732 |
+
for display in the Gradio UI. The progress information includes:
|
| 733 |
+
- Current training status
|
| 734 |
+
- Current step and total steps
|
| 735 |
+
- Current loss value
|
| 736 |
+
- Current learning rate
|
| 737 |
|
| 738 |
+
Returns:
|
| 739 |
+
Dictionary containing current training progress information
|
| 740 |
+
"""
|
| 741 |
+
return self.training_progress.copy()
|
| 742 |
+
|
| 743 |
+
def main():
|
| 744 |
+
"""
|
| 745 |
+
Main function that creates the complete Gradio application interface.
|
| 746 |
+
|
| 747 |
+
This function sets up the entire Gradio application with:
|
| 748 |
+
1. Application header and status information
|
| 749 |
+
2. Training configuration controls
|
| 750 |
+
3. Training status and progress display
|
| 751 |
+
4. Training control buttons
|
| 752 |
+
5. Instructions and resource links
|
| 753 |
+
6. Training function implementation
|
| 754 |
+
|
| 755 |
+
The interface provides a complete training experience for OpenLLM models
|
| 756 |
+
with real-time progress monitoring and comprehensive configuration options.
|
| 757 |
+
|
| 758 |
+
Returns:
|
| 759 |
+
Gradio Blocks interface for the training application
|
| 760 |
+
"""
|
| 761 |
+
|
| 762 |
+
# Initialize the trainer
|
| 763 |
+
# This creates the OpenLLMTrainer instance that will handle all training operations
|
| 764 |
+
trainer = OpenLLMTrainer()
|
| 765 |
+
|
| 766 |
+
# Create the main Gradio application interface
|
| 767 |
+
# Using Gradio 4.44.1 with Soft theme for modern appearance
|
| 768 |
+
with gr.Blocks(
|
| 769 |
+
title="OpenLLM Training Space - Fixed with Uploaded Modules",
|
| 770 |
+
theme=gr.themes.Soft()
|
| 771 |
+
) as demo:
|
| 772 |
+
|
| 773 |
+
# Application Header
|
| 774 |
+
# Provides clear identification and description of the application
|
| 775 |
+
gr.Markdown("# π OpenLLM Training Space - Fixed with Uploaded Modules")
|
| 776 |
+
gr.Markdown("### *Uses OpenLLM's Custom Model Architecture from Uploaded Files*")
|
| 777 |
+
gr.Markdown("---")
|
| 778 |
+
|
| 779 |
+
# Status Information
|
| 780 |
+
# Shows the availability of key components and dependencies
|
| 781 |
+
gr.Markdown(f"**OpenLLM Available**: {'β
Yes' if OPENLLM_AVAILABLE else 'β No'}")
|
| 782 |
+
gr.Markdown(f"**SentencePiece Available**: {'β
Yes' if SENTENCEPIECE_AVAILABLE else 'β No'}")
|
| 783 |
+
gr.Markdown(f"**Dependencies Available**: {'β
Yes' if DEPENDENCIES_AVAILABLE else 'β No'}")
|
| 784 |
+
gr.Markdown("**Architecture**: β
OpenLLM Custom GPTModel (From Uploaded Files)")
|
| 785 |
+
|
| 786 |
+
# Main Content Area
|
| 787 |
+
# Two-column layout for configuration and status
|
| 788 |
+
with gr.Row():
|
| 789 |
+
|
| 790 |
+
# Left Column: Training Configuration
|
| 791 |
+
# Contains all the training hyperparameters and settings
|
| 792 |
+
with gr.Column(scale=1):
|
| 793 |
+
gr.Markdown("## π Training Configuration")
|
| 794 |
+
|
| 795 |
+
# Model Size Selection
|
| 796 |
+
# Allows users to choose which base model to train from
|
| 797 |
model_size = gr.Dropdown(
|
| 798 |
choices=["small", "medium", "large"],
|
| 799 |
value="small",
|
| 800 |
label="Model Size",
|
| 801 |
+
info="Select the base model size to train from"
|
| 802 |
)
|
| 803 |
+
|
| 804 |
+
# Training Steps Configuration
|
| 805 |
+
# Controls the number of training iterations
|
| 806 |
+
max_steps = gr.Slider(
|
| 807 |
+
minimum=100,
|
| 808 |
+
maximum=10000,
|
| 809 |
+
value=1000,
|
| 810 |
+
step=100,
|
| 811 |
+
label="Max Training Steps",
|
| 812 |
+
info="Number of training iterations (100-10,000)"
|
| 813 |
+
)
|
| 814 |
+
|
| 815 |
+
# Learning Rate Configuration
|
| 816 |
+
# Controls the learning rate for the optimizer
|
| 817 |
+
learning_rate = gr.Slider(
|
| 818 |
+
minimum=1e-5,
|
| 819 |
+
maximum=1e-3,
|
| 820 |
+
value=3e-4,
|
| 821 |
+
step=1e-5,
|
| 822 |
+
label="Learning Rate",
|
| 823 |
+
info="Training rate (0.00001-0.001)"
|
| 824 |
+
)
|
| 825 |
+
|
| 826 |
+
# Batch Size Configuration
|
| 827 |
+
# Controls the number of samples per training batch
|
| 828 |
+
batch_size = gr.Slider(
|
| 829 |
+
minimum=1,
|
| 830 |
+
maximum=16,
|
| 831 |
+
value=4,
|
| 832 |
+
step=1,
|
| 833 |
+
label="Batch Size",
|
| 834 |
+
info="Samples per training batch (1-16)"
|
| 835 |
)
|
| 836 |
|
| 837 |
+
# Right Column: Training Status and Controls
|
| 838 |
+
# Contains status display and control buttons
|
| 839 |
+
with gr.Column(scale=1):
|
| 840 |
+
gr.Markdown("## π― Training Status")
|
| 841 |
+
|
| 842 |
+
# Training Status Display
|
| 843 |
+
# Shows current training status and any error messages
|
| 844 |
+
status_text = gr.Textbox(
|
| 845 |
+
value="Ready to start training" if OPENLLM_AVAILABLE else "OpenLLM not available",
|
| 846 |
+
label="Current Status",
|
| 847 |
+
interactive=False,
|
| 848 |
+
lines=5,
|
| 849 |
+
info="Shows current training status and progress updates"
|
| 850 |
+
)
|
| 851 |
+
|
| 852 |
+
# Progress Information
|
| 853 |
+
# Displays detailed training progress in JSON format
|
| 854 |
+
progress_info = gr.JSON(
|
| 855 |
+
value=trainer.get_training_progress(),
|
| 856 |
+
label="Training Progress"
|
| 857 |
+
)
|
| 858 |
+
|
| 859 |
+
# Training Control Buttons
|
| 860 |
+
# Buttons to start and stop training
|
| 861 |
+
with gr.Row():
|
| 862 |
+
start_btn = gr.Button("π Start Training", variant="primary")
|
| 863 |
+
stop_btn = gr.Button("βΉοΈ Stop Training", variant="stop")
|
| 864 |
|
| 865 |
+
# Instructions Section
|
| 866 |
+
# Provides detailed instructions for using the training interface
|
| 867 |
+
gr.Markdown("## π OpenLLM Training Instructions")
|
| 868 |
+
gr.Markdown("""
|
| 869 |
+
This interface uses **OpenLLM's actual custom model architecture** from uploaded files:
|
| 870 |
+
|
| 871 |
+
### **Step 1: Configure Parameters**
|
| 872 |
+
- **Model Size**: Select the base model to train from (small, medium, large)
|
| 873 |
+
- **Max Steps**: Number of training iterations (100-10,000)
|
| 874 |
+
- **Learning Rate**: Training rate (0.00001-0.001)
|
| 875 |
+
- **Batch Size**: Samples per training batch (1-16)
|
| 876 |
+
|
| 877 |
+
### **Step 2: Start Training**
|
| 878 |
+
- Click "Start Training" to begin the actual training process
|
| 879 |
+
- Uses OpenLLM's custom GPTModel class from uploaded files
|
| 880 |
+
- Uses sentencepiece.SentencePieceProcessor() for tokenization
|
| 881 |
+
- Compatible with OpenLLM's actual implementation
|
| 882 |
+
|
| 883 |
+
### **Step 3: Monitor Progress**
|
| 884 |
+
- Watch the status updates and progress information
|
| 885 |
+
- Training may take several minutes depending on steps
|
| 886 |
+
- The final model will be uploaded to Hugging Face Hub
|
| 887 |
+
|
| 888 |
+
### **Step 4: Access Results**
|
| 889 |
+
- Trained models are automatically pushed to: `lemms/openllm-{size}-extended-8k`
|
| 890 |
+
- Check the model repository for your trained model
|
| 891 |
+
- Use the model for inference or further training
|
| 892 |
+
""")
|
| 893 |
+
|
| 894 |
+
# Resource Links Section
|
| 895 |
+
# Provides links to related models and resources
|
| 896 |
+
gr.Markdown("## π Model Resources")
|
| 897 |
+
gr.Markdown("""
|
| 898 |
+
- [π 7k Small Model](https://huggingface.co/lemms/openllm-small-extended-7k)
|
| 899 |
+
- [π― 8k Small Model](https://huggingface.co/lemms/openllm-small-extended-8k)
|
| 900 |
+
- [π Training Dataset](https://huggingface.co/datasets/lemms/openllm-training-data)
|
| 901 |
+
- [π Main Project](https://github.com/louischua/openllm)
|
| 902 |
+
""")
|
| 903 |
+
|
| 904 |
+
# Training Function Definition
|
| 905 |
+
# This function is called when the Start Training button is clicked
|
| 906 |
+
def start_complete_training(model_size, max_steps, learning_rate, batch_size):
|
| 907 |
+
"""
|
| 908 |
+
Execute the complete training process using OpenLLM's approach.
|
| 909 |
|
| 910 |
+
This function orchestrates the entire training pipeline:
|
| 911 |
+
1. Validates OpenLLM availability
|
| 912 |
+
2. Creates training configuration
|
| 913 |
+
3. Loads model and tokenizer
|
| 914 |
+
4. Prepares dataset
|
| 915 |
+
5. Sets up training environment
|
| 916 |
+
6. Executes training
|
| 917 |
+
7. Saves and uploads the trained model
|
| 918 |
|
| 919 |
+
The function provides comprehensive error handling and status updates
|
| 920 |
+
throughout the training process.
|
|
|
|
|
|
|
|
|
|
| 921 |
|
| 922 |
+
Args:
|
| 923 |
+
model_size: Size of the model to train ("small", "medium", "large")
|
| 924 |
+
max_steps: Maximum number of training steps
|
| 925 |
+
learning_rate: Learning rate for the optimizer
|
| 926 |
+
batch_size: Batch size for training
|
| 927 |
+
|
| 928 |
+
Returns:
|
| 929 |
+
Status message indicating the result of the training process
|
| 930 |
+
"""
|
| 931 |
+
# Validate OpenLLM availability
|
| 932 |
+
if not OPENLLM_AVAILABLE:
|
| 933 |
+
return "β OpenLLM custom model architecture not available. Please check the installation."
|
| 934 |
|
| 935 |
+
try:
|
| 936 |
+
print(f"π Starting complete training process...")
|
| 937 |
+
print(f" - Model size: {model_size}")
|
| 938 |
+
print(f" - Max steps: {max_steps}")
|
| 939 |
+
print(f" - Learning rate: {learning_rate}")
|
| 940 |
+
print(f" - Batch size: {batch_size}")
|
| 941 |
+
|
| 942 |
+
# Create training configuration
|
| 943 |
+
# This encapsulates all training parameters
|
| 944 |
+
config = TrainingConfig(
|
| 945 |
+
model_size=model_size,
|
| 946 |
+
max_steps=max_steps,
|
| 947 |
+
learning_rate=learning_rate,
|
| 948 |
+
batch_size=batch_size
|
| 949 |
+
)
|
| 950 |
+
|
| 951 |
+
# Step 1: Load model and tokenizer using OpenLLM's approach
|
| 952 |
+
print("π Step 1: Loading model and tokenizer...")
|
| 953 |
+
status = trainer.load_model_and_tokenizer(model_size)
|
| 954 |
+
if "β" in status:
|
| 955 |
+
return status
|
| 956 |
+
|
| 957 |
+
# Step 2: Prepare dataset
|
| 958 |
+
print("π Step 2: Preparing dataset...")
|
| 959 |
+
status = trainer.prepare_dataset()
|
| 960 |
+
if "β" in status:
|
| 961 |
+
return status
|
| 962 |
+
|
| 963 |
+
# Step 3: Setup training
|
| 964 |
+
print("π Step 3: Setting up training...")
|
| 965 |
+
status = trainer.setup_training(config)
|
| 966 |
+
if "β" in status:
|
| 967 |
+
return status
|
| 968 |
+
|
| 969 |
+
# Step 4: Execute training
|
| 970 |
+
print("π Step 4: Executing training...")
|
| 971 |
+
status = trainer.train_model(config)
|
| 972 |
+
if "β" in status:
|
| 973 |
+
return status
|
| 974 |
+
|
| 975 |
+
# Step 5: Save and upload model
|
| 976 |
+
print("π Step 5: Saving and uploading model...")
|
| 977 |
+
status = trainer.save_and_upload_model(config)
|
| 978 |
+
|
| 979 |
+
print("π Complete training process finished!")
|
| 980 |
+
return f"π Complete training process finished!\n{status}"
|
| 981 |
+
|
| 982 |
+
except Exception as e:
|
| 983 |
+
print(f"β Training process failed: {str(e)}")
|
| 984 |
+
return f"β Training process failed: {str(e)}"
|
| 985 |
+
|
| 986 |
+
def update_progress():
|
| 987 |
+
"""
|
| 988 |
+
Update the progress display.
|
| 989 |
|
| 990 |
+
This function is called periodically to update the progress
|
| 991 |
+
information displayed in the Gradio interface. It returns the
|
| 992 |
+
current training progress from the trainer.
|
| 993 |
|
| 994 |
+
Returns:
|
| 995 |
+
Current training progress dictionary
|
| 996 |
+
"""
|
| 997 |
+
return trainer.get_training_progress()
|
| 998 |
+
|
| 999 |
+
# Connect UI Components to Functions
|
| 1000 |
+
# This connects the Start Training button to the training function
|
| 1001 |
+
start_btn.click(
|
| 1002 |
+
fn=start_complete_training,
|
| 1003 |
+
inputs=[model_size, max_steps, learning_rate, batch_size],
|
| 1004 |
+
outputs=[status_text]
|
| 1005 |
+
)
|
| 1006 |
+
|
| 1007 |
+
# Auto-refresh progress every 5 seconds during training
|
| 1008 |
+
# This ensures the progress display stays up to date
|
| 1009 |
+
demo.load(update_progress, outputs=[progress_info])
|
| 1010 |
+
|
| 1011 |
+
# Application Footer
|
| 1012 |
+
# Provides attribution and technical information
|
| 1013 |
+
gr.Markdown("---")
|
| 1014 |
+
gr.Markdown("**Author**: Louis Chua Bean Chong | **Project**: OpenLLM | **License**: GPL-3.0")
|
| 1015 |
+
gr.Markdown("**Architecture**: OpenLLM Custom GPTModel (From Uploaded Files)")
|
| 1016 |
+
gr.Markdown("**Tokenizer**: sentencepiece.SentencePieceProcessor()")
|
| 1017 |
|
| 1018 |
+
return demo
|
|
|
|
| 1019 |
|
| 1020 |
if __name__ == "__main__":
|
| 1021 |
+
# Launch the Gradio application
|
| 1022 |
+
# This starts the web interface for the training application
|
| 1023 |
+
demo = main()
|
| 1024 |
+
demo.launch()
|
|
|
|
|
|
|
|
|
requirements.txt CHANGED
@@ -1,26 +1,51 @@
|
|
[Removed side of this hunk (the previous requirements.txt, 26 lines) is garbled in this rendering; only scattered comment markers and "torchvision>=0.15.0" survive. The full replacement file follows on the added side.]
| 1 |
+
# Complete Training Dependencies for OpenLLM Space - Updated for Gradio 4.44.1
|
| 2 |
+
# This file includes all necessary packages for real model training
|
| 3 |
|
| 4 |
+
# Core Machine Learning Framework
|
| 5 |
+
torch>=2.0.0 # PyTorch deep learning framework
|
| 6 |
+
torchvision>=0.15.0 # Computer vision utilities
|
| 7 |
+
torchaudio>=2.0.0 # Audio processing utilities
|
| 8 |
|
| 9 |
+
# Hugging Face Ecosystem - Complete Training Stack
|
| 10 |
+
transformers>=4.30.0 # Pre-trained models and training utilities
|
| 11 |
+
datasets>=2.12.0 # Dataset loading and processing
|
| 12 |
+
tokenizers>=0.13.0 # Fast tokenization library
|
| 13 |
+
sentencepiece>=0.1.99 # SentencePiece tokenization (CRITICAL for OpenLLM models)
|
| 14 |
+
huggingface_hub>=0.34.0 # Hugging Face Hub integration
|
| 15 |
+
accelerate>=0.20.0 # Distributed training acceleration
|
| 16 |
|
| 17 |
+
# User Interface Framework - Updated to 4.44.1
|
| 18 |
+
gradio==4.44.1 # Web UI framework for ML applications (fixed version)
|
|
|
|
| 19 |
|
| 20 |
+
# Data Processing and Scientific Computing
|
| 21 |
+
numpy>=1.24.0 # Numerical computing library
|
| 22 |
+
pandas>=2.0.0 # Data manipulation and analysis
|
| 23 |
+
scipy>=1.10.0 # Scientific computing utilities
|
| 24 |
|
| 25 |
+
# Progress and Monitoring
|
| 26 |
+
tqdm>=4.65.0 # Progress bars for long-running operations
|
| 27 |
+
psutil>=5.9.0 # System and process utilities
|
| 28 |
|
| 29 |
+
# Memory and Performance Optimization
|
| 30 |
+
bitsandbytes>=0.41.0 # Quantization utilities for memory efficiency
|
| 31 |
+
peft>=0.4.0 # Parameter-Efficient Fine-Tuning
|
| 32 |
|
| 33 |
+
# Logging and Debugging
|
| 34 |
+
wandb>=0.15.0 # Experiment tracking (optional)
|
| 35 |
+
tensorboard>=2.13.0 # Training visualization (optional)
|
| 36 |
+
|
| 37 |
+
# Additional Utilities
|
| 38 |
+
requests>=2.31.0 # HTTP library for API calls
|
| 39 |
+
pillow>=9.5.0 # Image processing (if needed)
|
| 40 |
+
matplotlib>=3.7.0 # Plotting and visualization
|
| 41 |
+
seaborn>=0.12.0 # Statistical data visualization
|
| 42 |
+
|
| 43 |
+
# Development and Testing (optional)
|
| 44 |
+
pytest>=7.4.0 # Testing framework
|
| 45 |
+
black>=23.0.0 # Code formatting
|
| 46 |
+
flake8>=6.0.0 # Code linting
|
| 47 |
+
|
| 48 |
+
# Note: These versions are compatible with Hugging Face Spaces
|
| 49 |
+
# and provide stable training performance for OpenLLM models
|
| 50 |
+
# Gradio 4.44.1 fixes compatibility issues with JSON components
|
| 51 |
+
# SentencePiece is CRITICAL for OpenLLM model tokenization
|
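
Because several of these pins (sentencepiece, gradio, huggingface_hub) are called out as critical, a small sanity check can confirm the environment before launching training. This is a sketch; the version thresholds simply mirror the pins above.

    # check_env.py - quick sanity check of the pinned training stack (illustrative)
    import importlib.metadata as md

    pins = {
        "torch": "2.0.0",
        "sentencepiece": "0.1.99",
        "gradio": "4.44.1",
        "huggingface_hub": "0.34.0",
    }

    for package, minimum in pins.items():
        try:
            installed = md.version(package)
            print(f"{package:>18}: {installed} (expected >= {minimum})")
        except md.PackageNotFoundError:
            print(f"{package:>18}: MISSING")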
training/data_loader.py
CHANGED
|
@@ -13,12 +13,12 @@
|
|
| 13 |
Training Data Loader for Language Model Training
|
| 14 |
|
| 15 |
This module provides efficient data loading and batching for training GPT-style
|
| 16 |
-
language models. It handles text preprocessing, tokenization, and creates
|
| 17 |
batches suitable for autoregressive language modeling.
|
| 18 |
|
| 19 |
FEATURES:
|
| 20 |
- Memory-efficient text loading with sliding window
|
| 21 |
-
- Automatic tokenization using trained SentencePiece model
|
| 22 |
- Configurable sequence length and batch size
|
| 23 |
- CPU-optimized data loading for limited hardware
|
| 24 |
- Support for training data validation and statistics
|
|
@@ -31,20 +31,20 @@ MEMORY OPTIMIZATION:
|
|
| 31 |
|
| 32 |
Usage:
|
| 33 |
from data_loader import TextDataLoader
|
| 34 |
-
|
| 35 |
loader = TextDataLoader(
|
| 36 |
data_file="data/clean/training_data.txt",
|
| 37 |
-
tokenizer_path="data/tokenizer/tokenizer.model",
|
| 38 |
seq_len=512,
|
| 39 |
batch_size=4
|
| 40 |
)
|
| 41 |
-
|
| 42 |
for batch in loader:
|
| 43 |
input_ids, targets = batch
|
| 44 |
# input_ids: (batch_size, seq_len)
|
| 45 |
# targets: (batch_size, seq_len) - shifted by 1 for next token prediction
|
| 46 |
|
| 47 |
-
Author: Louis Chua Bean Chong
|
| 48 |
License: GPLv3
|
| 49 |
"""
|
| 50 |
|
|
@@ -66,11 +66,11 @@ except ImportError:
|
|
| 66 |
class TextDataLoader:
|
| 67 |
"""
|
| 68 |
Efficient data loader for autoregressive language model training.
|
| 69 |
-
|
| 70 |
This class handles loading text data, tokenizing it using SentencePiece,
|
| 71 |
and creating batches suitable for next-token prediction training.
|
| 72 |
"""
|
| 73 |
-
|
| 74 |
def __init__(
|
| 75 |
self,
|
| 76 |
data_file: str,
|
|
@@ -79,11 +79,11 @@ class TextDataLoader:
|
|
| 79 |
batch_size: int = 4,
|
| 80 |
chunk_size: int = 1000000, # Lines to read at once
|
| 81 |
shuffle: bool = True,
|
| 82 |
-
seed: int = 42
|
| 83 |
):
|
| 84 |
"""
|
| 85 |
Initialize the data loader.
|
| 86 |
-
|
| 87 |
Args:
|
| 88 |
data_file: Path to training text file (one passage per line)
|
| 89 |
tokenizer_path: Path to trained SentencePiece model
|
|
@@ -100,44 +100,44 @@ class TextDataLoader:
|
|
| 100 |
self.chunk_size = chunk_size
|
| 101 |
self.shuffle = shuffle
|
| 102 |
self.seed = seed
|
| 103 |
-
|
| 104 |
# Validate inputs
|
| 105 |
self._validate_inputs()
|
| 106 |
-
|
| 107 |
# Load tokenizer
|
| 108 |
self.tokenizer = self._load_tokenizer()
|
| 109 |
-
|
| 110 |
# Get data statistics
|
| 111 |
self.total_lines = self._count_lines()
|
| 112 |
self.current_line = 0
|
| 113 |
-
|
| 114 |
# Set random seed for reproducibility
|
| 115 |
random.seed(seed)
|
| 116 |
-
|
| 117 |
print(f"π TextDataLoader initialized")
|
| 118 |
print(f" Data file: {data_file}")
|
| 119 |
print(f" Total passages: {self.total_lines:,}")
|
| 120 |
print(f" Sequence length: {seq_len}")
|
| 121 |
print(f" Batch size: {batch_size}")
|
| 122 |
print(f" Vocabulary size: {self.tokenizer.vocab_size():,}")
|
| 123 |
-
|
| 124 |
def _validate_inputs(self) -> None:
|
| 125 |
"""Validate input parameters and file paths."""
|
| 126 |
if not os.path.exists(self.data_file):
|
| 127 |
raise FileNotFoundError(f"Training data file not found: {self.data_file}")
|
| 128 |
-
|
| 129 |
if not os.path.exists(self.tokenizer_path):
|
| 130 |
raise FileNotFoundError(f"Tokenizer model not found: {self.tokenizer_path}")
|
| 131 |
-
|
| 132 |
if self.seq_len <= 0:
|
| 133 |
raise ValueError(f"Sequence length must be positive, got {self.seq_len}")
|
| 134 |
-
|
| 135 |
if self.batch_size <= 0:
|
| 136 |
raise ValueError(f"Batch size must be positive, got {self.batch_size}")
|
| 137 |
-
|
| 138 |
if self.chunk_size <= 0:
|
| 139 |
raise ValueError(f"Chunk size must be positive, got {self.chunk_size}")
|
| 140 |
-
|
| 141 |
def _load_tokenizer(self) -> spm.SentencePieceProcessor:
|
| 142 |
"""Load the trained SentencePiece tokenizer."""
|
| 143 |
try:
|
|
@@ -146,222 +146,226 @@ class TextDataLoader:
|
|
| 146 |
return tokenizer
|
| 147 |
except Exception as e:
|
| 148 |
raise RuntimeError(f"Failed to load tokenizer: {e}")
|
| 149 |
-
|
| 150 |
def _count_lines(self) -> int:
|
| 151 |
"""Count total number of lines in the data file."""
|
| 152 |
print("π Counting training passages...")
|
| 153 |
start_time = time.time()
|
| 154 |
-
|
| 155 |
line_count = 0
|
| 156 |
-
with open(self.data_file, "r", encoding="utf-8") as f:
|
| 157 |
for line in f:
|
| 158 |
if line.strip(): # Only count non-empty lines
|
| 159 |
line_count += 1
|
| 160 |
-
|
| 161 |
count_time = time.time() - start_time
|
| 162 |
print(f"β Found {line_count:,} passages in {count_time:.1f}s")
|
| 163 |
-
|
| 164 |
return line_count
|
| 165 |
-
|
| 166 |
def _read_chunk(self, start_line: int = 0) -> List[str]:
|
| 167 |
"""
|
| 168 |
Read a chunk of lines from the data file.
|
| 169 |
-
|
| 170 |
Args:
|
| 171 |
start_line: Line number to start reading from
|
| 172 |
-
|
| 173 |
Returns:
|
| 174 |
List of text passages
|
| 175 |
"""
|
| 176 |
chunk = []
|
| 177 |
current_line = 0
|
| 178 |
lines_read = 0
|
| 179 |
-
|
| 180 |
-
with open(self.data_file, "r", encoding="utf-8") as f:
|
| 181 |
for line in f:
|
| 182 |
if current_line < start_line:
|
| 183 |
current_line += 1
|
| 184 |
continue
|
| 185 |
-
|
| 186 |
text = line.strip()
|
| 187 |
if text: # Only include non-empty lines
|
| 188 |
chunk.append(text)
|
| 189 |
lines_read += 1
|
| 190 |
-
|
| 191 |
if lines_read >= self.chunk_size:
|
| 192 |
break
|
| 193 |
-
|
| 194 |
current_line += 1
|
| 195 |
-
|
| 196 |
return chunk
|
| 197 |
-
|
| 198 |
def _tokenize_texts(self, texts: List[str]) -> List[List[int]]:
|
| 199 |
"""
|
| 200 |
Tokenize a list of text passages using SentencePiece tokenizer.
|
| 201 |
-
|
| 202 |
This method converts raw text into token ID sequences suitable for language model training.
|
| 203 |
It handles special tokens (BOS/EOS) and length constraints for efficient training.
|
| 204 |
-
|
| 205 |
Text processing pipeline:
|
| 206 |
1. Add BOS (Beginning of Sequence) token to mark sequence start
|
| 207 |
2. Tokenize text using trained SentencePiece model (subword tokenization)
|
| 208 |
3. Truncate sequences that exceed maximum length
|
| 209 |
4. Add EOS (End of Sequence) token to mark sequence end
|
| 210 |
-
|
| 211 |
Special token handling:
|
| 212 |
- BOS token helps model learn to generate text from scratch
|
| 213 |
- EOS token signals natural sequence endings
|
| 214 |
- These tokens are crucial for proper autoregressive generation
|
| 215 |
-
|
| 216 |
Args:
|
| 217 |
texts: List of text passages (typically Wikipedia passages from SQUAD)
|
| 218 |
Each passage should be a complete, coherent text segment
|
| 219 |
-
|
| 220 |
Returns:
|
| 221 |
List of token ID sequences, where each sequence is a list of integers
|
| 222 |
representing subword tokens from the SentencePiece vocabulary
|
| 223 |
"""
|
| 224 |
tokenized = []
|
| 225 |
-
|
| 226 |
for text in texts:
|
| 227 |
try:
|
| 228 |
# Add BOS (Beginning of Sequence) token at the start
|
| 229 |
# BOS token ID=2 by default in SentencePiece, signals sequence start
|
| 230 |
# This helps the model learn proper sequence initialization during generation
|
| 231 |
tokens = [self.tokenizer.bos_id()] + self.tokenizer.encode(text)
|
| 232 |
-
|
| 233 |
# Truncate sequences that exceed maximum context length
|
| 234 |
# Reserve one position for EOS token by using (seq_len - 1)
|
| 235 |
# This ensures we never exceed the model's context window during training
|
| 236 |
if len(tokens) > self.seq_len - 1:
|
| 237 |
-
tokens = tokens[:self.seq_len - 1]
|
| 238 |
# NOTE: Truncation may cut off text mid-sentence, but this is acceptable
|
| 239 |
# for language modeling where the model learns from partial contexts
|
| 240 |
-
|
| 241 |
# Add EOS (End of Sequence) token at the end
|
| 242 |
# EOS token ID=1 by default in SentencePiece, signals sequence completion
|
| 243 |
# This teaches the model when to stop generating text naturally
|
| 244 |
tokens.append(self.tokenizer.eos_id())
|
| 245 |
-
|
| 246 |
# Validate tokenization result
|
| 247 |
if len(tokens) <= 2: # Only BOS + EOS tokens, no actual content
|
| 248 |
print(f"β οΈ Skipping very short text: {text[:50]}...")
|
| 249 |
continue
|
| 250 |
-
|
| 251 |
tokenized.append(tokens)
|
| 252 |
-
|
| 253 |
except Exception as e:
|
| 254 |
# Handle tokenization errors gracefully to avoid stopping training
|
| 255 |
# Common causes: encoding issues, very long texts, special characters
|
| 256 |
print(f"β οΈ Failed to tokenize passage: {text[:50]}... Error: {e}")
|
| 257 |
continue
|
| 258 |
-
|
| 259 |
# Log tokenization statistics for monitoring
|
| 260 |
if tokenized:
|
| 261 |
avg_length = sum(len(tokens) for tokens in tokenized) / len(tokenized)
|
| 262 |
print(f"π Tokenized {len(tokenized)} passages, avg length: {avg_length:.1f} tokens")
|
| 263 |
-
|
| 264 |
return tokenized
|
| 265 |
-
|
| 266 |
-
def _create_training_examples(self, token_sequences: List[List[int]]) -> List[Tuple[List[int], List[int]]]:
|
|
|
|
|
|
|
| 267 |
"""
|
| 268 |
Create training examples with input and target sequences.
|
| 269 |
-
|
| 270 |
For autoregressive training, targets are inputs shifted by one position.
|
| 271 |
-
|
| 272 |
Args:
|
| 273 |
token_sequences: List of tokenized sequences
|
| 274 |
-
|
| 275 |
Returns:
|
| 276 |
List of (input_ids, target_ids) tuples
|
| 277 |
"""
|
| 278 |
examples = []
|
| 279 |
-
|
| 280 |
for tokens in token_sequences:
|
| 281 |
if len(tokens) < 2: # Need at least 2 tokens for input/target pair
|
| 282 |
continue
|
| 283 |
-
|
| 284 |
# For sequences longer than seq_len, create multiple examples with sliding window
|
| 285 |
if len(tokens) > self.seq_len:
|
| 286 |
# Create overlapping windows (50% overlap for better learning)
|
| 287 |
stride = self.seq_len // 2
|
| 288 |
for i in range(0, len(tokens) - self.seq_len, stride):
|
| 289 |
-
input_ids = tokens[i:i + self.seq_len]
|
| 290 |
-
target_ids = tokens[i + 1:i + self.seq_len + 1]
|
| 291 |
examples.append((input_ids, target_ids))
|
| 292 |
else:
|
| 293 |
# Pad shorter sequences
|
| 294 |
input_ids = tokens[:-1] # All but last token
|
| 295 |
target_ids = tokens[1:] # All but first token
|
| 296 |
-
|
| 297 |
# Pad to seq_len if necessary
|
| 298 |
while len(input_ids) < self.seq_len:
|
| 299 |
input_ids.append(self.tokenizer.pad_id())
|
| 300 |
target_ids.append(-1) # Use -1 for padding in targets (ignored in loss)
|
| 301 |
-
|
| 302 |
# Truncate if still too long
|
| 303 |
-
input_ids = input_ids[:self.seq_len]
|
| 304 |
-
target_ids = target_ids[:self.seq_len]
|
| 305 |
-
|
| 306 |
examples.append((input_ids, target_ids))
|
| 307 |
-
|
| 308 |
return examples
|
| 309 |
-
|
| 310 |
-
def _create_batch(self, examples: List[Tuple[List[int], List[int]]]) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
|
|
|
| 311 |
"""
|
| 312 |
Create a batch tensor from training examples.
|
| 313 |
-
|
| 314 |
Args:
|
| 315 |
examples: List of (input_ids, target_ids) tuples
|
| 316 |
-
|
| 317 |
Returns:
|
| 318 |
Tuple of (input_tensor, target_tensor)
|
| 319 |
"""
|
| 320 |
if not examples:
|
| 321 |
raise ValueError("Cannot create batch from empty examples")
|
| 322 |
-
|
| 323 |
batch_size = len(examples)
|
| 324 |
-
|
| 325 |
# Initialize tensors
|
| 326 |
input_ids = torch.zeros((batch_size, self.seq_len), dtype=torch.long)
|
| 327 |
target_ids = torch.full((batch_size, self.seq_len), -1, dtype=torch.long)
|
| 328 |
-
|
| 329 |
# Fill tensors
|
| 330 |
for i, (inp, tgt) in enumerate(examples):
|
| 331 |
-
input_ids[i, :len(inp)] = torch.tensor(inp, dtype=torch.long)
|
| 332 |
-
target_ids[i, :len(tgt)] = torch.tensor(tgt, dtype=torch.long)
|
| 333 |
-
|
| 334 |
return input_ids, target_ids
|
| 335 |
-
|
| 336 |
def __iter__(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
|
| 337 |
"""
|
| 338 |
Iterate over training batches.
|
| 339 |
-
|
| 340 |
Yields:
|
| 341 |
Tuple of (input_ids, target_ids) tensors
|
| 342 |
"""
|
| 343 |
self.current_line = 0
|
| 344 |
-
|
| 345 |
while self.current_line < self.total_lines:
|
| 346 |
# Read chunk of text
|
| 347 |
texts = self._read_chunk(self.current_line)
|
| 348 |
if not texts:
|
| 349 |
break
|
| 350 |
-
|
| 351 |
# Tokenize texts
|
| 352 |
token_sequences = self._tokenize_texts(texts)
|
| 353 |
-
|
| 354 |
# Create training examples
|
| 355 |
examples = self._create_training_examples(token_sequences)
|
| 356 |
-
|
| 357 |
# Shuffle examples if requested
|
| 358 |
if self.shuffle:
|
| 359 |
random.shuffle(examples)
|
| 360 |
-
|
| 361 |
# Create batches
|
| 362 |
for i in range(0, len(examples), self.batch_size):
|
| 363 |
-
batch_examples = examples[i:i + self.batch_size]
|
| 364 |
-
|
| 365 |
if len(batch_examples) == self.batch_size: # Only yield full batches
|
| 366 |
try:
|
| 367 |
input_ids, target_ids = self._create_batch(batch_examples)
|
|
@@ -369,27 +373,27 @@ class TextDataLoader:
|
|
| 369 |
except Exception as e:
|
| 370 |
print(f"β οΈ Failed to create batch: {e}")
|
| 371 |
continue
|
| 372 |
-
|
| 373 |
# Update progress
|
| 374 |
self.current_line += len(texts)
|
| 375 |
-
|
| 376 |
# Clean up memory
|
| 377 |
del texts, token_sequences, examples
|
| 378 |
gc.collect()
|
| 379 |
-
|
| 380 |
def get_data_stats(self) -> dict:
|
| 381 |
"""
|
| 382 |
Get statistics about the training data.
|
| 383 |
-
|
| 384 |
Returns:
|
| 385 |
Dictionary with data statistics
|
| 386 |
"""
|
| 387 |
print("π Analyzing training data...")
|
| 388 |
-
|
| 389 |
# Sample some data to get statistics
|
| 390 |
sample_texts = self._read_chunk(0)[:100] # Sample first 100 passages
|
| 391 |
token_sequences = self._tokenize_texts(sample_texts)
|
| 392 |
-
|
| 393 |
if token_sequences:
|
| 394 |
sequence_lengths = [len(seq) for seq in token_sequences]
|
| 395 |
avg_length = sum(sequence_lengths) / len(sequence_lengths)
|
|
@@ -397,15 +401,15 @@ class TextDataLoader:
|
|
| 397 |
min_length = min(sequence_lengths)
|
| 398 |
else:
|
| 399 |
avg_length = max_length = min_length = 0
|
| 400 |
-
|
| 401 |
# Estimate total tokens
|
| 402 |
estimated_total_tokens = int(avg_length * self.total_lines)
|
| 403 |
-
|
| 404 |
# Estimate number of batches per epoch
|
| 405 |
examples_per_passage = max(1, avg_length // self.seq_len)
|
| 406 |
total_examples = int(self.total_lines * examples_per_passage)
|
| 407 |
batches_per_epoch = total_examples // self.batch_size
|
| 408 |
-
|
| 409 |
stats = {
|
| 410 |
"total_passages": self.total_lines,
|
| 411 |
"avg_tokens_per_passage": avg_length,
|
|
@@ -416,22 +420,22 @@ class TextDataLoader:
|
|
| 416 |
"estimated_batches_per_epoch": batches_per_epoch,
|
| 417 |
"sequence_length": self.seq_len,
|
| 418 |
"batch_size": self.batch_size,
|
| 419 |
-
"vocabulary_size": self.tokenizer.vocab_size()
|
| 420 |
}
|
| 421 |
-
|
| 422 |
print(f"β Data analysis complete:")
|
| 423 |
print(f" Total passages: {stats['total_passages']:,}")
|
| 424 |
print(f" Avg tokens per passage: {stats['avg_tokens_per_passage']:.1f}")
|
| 425 |
print(f" Estimated total tokens: {stats['estimated_total_tokens']:,}")
|
| 426 |
print(f" Estimated batches per epoch: {stats['estimated_batches_per_epoch']:,}")
|
| 427 |
-
|
| 428 |
return stats
|
| 429 |
|
| 430 |
|
| 431 |
def test_data_loader():
|
| 432 |
"""Test function for the data loader."""
|
| 433 |
print("π§ͺ Testing TextDataLoader...")
|
| 434 |
-
|
| 435 |
# Test with small parameters
|
| 436 |
try:
|
| 437 |
loader = TextDataLoader(
|
|
@@ -439,42 +443,43 @@ def test_data_loader():
|
|
| 439 |
tokenizer_path="data/tokenizer/tokenizer.model",
|
| 440 |
seq_len=128,
|
| 441 |
batch_size=2,
|
| 442 |
-
chunk_size=10 # Small for testing
|
| 443 |
)
|
| 444 |
-
|
| 445 |
# Get data statistics
|
| 446 |
stats = loader.get_data_stats()
|
| 447 |
-
|
| 448 |
# Test iteration
|
| 449 |
print("\nπ Testing batch iteration...")
|
| 450 |
start_time = time.time()
|
| 451 |
batch_count = 0
|
| 452 |
-
|
| 453 |
for batch_idx, (input_ids, target_ids) in enumerate(loader):
|
| 454 |
batch_count += 1
|
| 455 |
-
|
| 456 |
print(f"Batch {batch_idx + 1}:")
|
| 457 |
print(f" Input shape: {input_ids.shape}")
|
| 458 |
print(f" Target shape: {target_ids.shape}")
|
| 459 |
print(f" Sample input tokens: {input_ids[0][:10].tolist()}")
|
| 460 |
print(f" Sample target tokens: {target_ids[0][:10].tolist()}")
|
| 461 |
-
|
| 462 |
if batch_idx >= 2: # Only test first few batches
|
| 463 |
break
|
| 464 |
-
|
| 465 |
test_time = time.time() - start_time
|
| 466 |
print(f"\nβ Data loader test completed successfully!")
|
| 467 |
print(f" Processed {batch_count} batches in {test_time:.2f}s")
|
| 468 |
print(f" Average time per batch: {test_time/max(1, batch_count):.2f}s")
|
| 469 |
-
|
| 470 |
return True
|
| 471 |
-
|
| 472 |
except Exception as e:
|
| 473 |
print(f"β Data loader test failed: {e}")
|
| 474 |
import traceback
|
|
|
|
| 475 |
traceback.print_exc()
|
| 476 |
return False
|
| 477 |
|
| 478 |
|
| 479 |
if __name__ == "__main__":
|
| 480 |
-
test_data_loader()
|
|
|
|
| 13 |
Training Data Loader for Language Model Training
|
| 14 |
|
| 15 |
This module provides efficient data loading and batching for training GPT-style
|
| 16 |
+
language models. It handles text preprocessing, tokenization, and creates
|
| 17 |
batches suitable for autoregressive language modeling.
|
| 18 |
|
| 19 |
FEATURES:
|
| 20 |
- Memory-efficient text loading with sliding window
|
| 21 |
+
- Automatic tokenization using trained SentencePiece model
|
| 22 |
- Configurable sequence length and batch size
|
| 23 |
- CPU-optimized data loading for limited hardware
|
| 24 |
- Support for training data validation and statistics
|
|
|
|
| 31 |
|
| 32 |
Usage:
|
| 33 |
from data_loader import TextDataLoader
|
| 34 |
+
|
| 35 |
loader = TextDataLoader(
|
| 36 |
data_file="data/clean/training_data.txt",
|
| 37 |
+
tokenizer_path="data/tokenizer/tokenizer.model",
|
| 38 |
seq_len=512,
|
| 39 |
batch_size=4
|
| 40 |
)
|
| 41 |
+
|
| 42 |
for batch in loader:
|
| 43 |
input_ids, targets = batch
|
| 44 |
# input_ids: (batch_size, seq_len)
|
| 45 |
# targets: (batch_size, seq_len) - shifted by 1 for next token prediction
|
| 46 |
|
| 47 |
+
Author: Louis Chua Bean Chong
|
| 48 |
License: GPLv3
|
| 49 |
"""
|
| 50 |
|
|
|
|
| 66 |
class TextDataLoader:
|
| 67 |
"""
|
| 68 |
Efficient data loader for autoregressive language model training.
|
| 69 |
+
|
| 70 |
This class handles loading text data, tokenizing it using SentencePiece,
|
| 71 |
and creating batches suitable for next-token prediction training.
|
| 72 |
"""
|
| 73 |
+
|
| 74 |
def __init__(
|
| 75 |
self,
|
| 76 |
data_file: str,
|
|
|
|
| 79 |
batch_size: int = 4,
|
| 80 |
chunk_size: int = 1000000, # Lines to read at once
|
| 81 |
shuffle: bool = True,
|
| 82 |
+
seed: int = 42,
|
| 83 |
):
|
| 84 |
"""
|
| 85 |
Initialize the data loader.
|
| 86 |
+
|
| 87 |
Args:
|
| 88 |
data_file: Path to training text file (one passage per line)
|
| 89 |
tokenizer_path: Path to trained SentencePiece model
|
|
|
|
| 100 |
self.chunk_size = chunk_size
|
| 101 |
self.shuffle = shuffle
|
| 102 |
self.seed = seed
|
| 103 |
+
|
| 104 |
# Validate inputs
|
| 105 |
self._validate_inputs()
|
| 106 |
+
|
| 107 |
# Load tokenizer
|
| 108 |
self.tokenizer = self._load_tokenizer()
|
| 109 |
+
|
| 110 |
# Get data statistics
|
| 111 |
self.total_lines = self._count_lines()
|
| 112 |
self.current_line = 0
|
| 113 |
+
|
| 114 |
# Set random seed for reproducibility
|
| 115 |
random.seed(seed)
|
| 116 |
+
|
| 117 |
print(f"π TextDataLoader initialized")
|
| 118 |
print(f" Data file: {data_file}")
|
| 119 |
print(f" Total passages: {self.total_lines:,}")
|
| 120 |
print(f" Sequence length: {seq_len}")
|
| 121 |
print(f" Batch size: {batch_size}")
|
| 122 |
print(f" Vocabulary size: {self.tokenizer.vocab_size():,}")
|
| 123 |
+
|
| 124 |
def _validate_inputs(self) -> None:
|
| 125 |
"""Validate input parameters and file paths."""
|
| 126 |
if not os.path.exists(self.data_file):
|
| 127 |
raise FileNotFoundError(f"Training data file not found: {self.data_file}")
|
| 128 |
+
|
| 129 |
if not os.path.exists(self.tokenizer_path):
|
| 130 |
raise FileNotFoundError(f"Tokenizer model not found: {self.tokenizer_path}")
|
| 131 |
+
|
| 132 |
if self.seq_len <= 0:
|
| 133 |
raise ValueError(f"Sequence length must be positive, got {self.seq_len}")
|
| 134 |
+
|
| 135 |
if self.batch_size <= 0:
|
| 136 |
raise ValueError(f"Batch size must be positive, got {self.batch_size}")
|
| 137 |
+
|
| 138 |
if self.chunk_size <= 0:
|
| 139 |
raise ValueError(f"Chunk size must be positive, got {self.chunk_size}")
|
| 140 |
+
|
| 141 |
def _load_tokenizer(self) -> spm.SentencePieceProcessor:
|
| 142 |
"""Load the trained SentencePiece tokenizer."""
|
| 143 |
try:
|
|
|
|
| 146 |
return tokenizer
|
| 147 |
except Exception as e:
|
| 148 |
raise RuntimeError(f"Failed to load tokenizer: {e}")
|
| 149 |
+
|
| 150 |
def _count_lines(self) -> int:
|
| 151 |
"""Count total number of lines in the data file."""
|
| 152 |
print("π Counting training passages...")
|
| 153 |
start_time = time.time()
|
| 154 |
+
|
| 155 |
line_count = 0
|
| 156 |
+
with open(self.data_file, "r", encoding="utf-8") as f:
|
| 157 |
for line in f:
|
| 158 |
if line.strip(): # Only count non-empty lines
|
| 159 |
line_count += 1
|
| 160 |
+
|
| 161 |
count_time = time.time() - start_time
|
| 162 |
print(f"β Found {line_count:,} passages in {count_time:.1f}s")
|
| 163 |
+
|
| 164 |
return line_count
|
| 165 |
+
|
| 166 |
def _read_chunk(self, start_line: int = 0) -> List[str]:
|
| 167 |
"""
|
| 168 |
Read a chunk of lines from the data file.
|
| 169 |
+
|
| 170 |
Args:
|
| 171 |
start_line: Line number to start reading from
|
| 172 |
+
|
| 173 |
Returns:
|
| 174 |
List of text passages
|
| 175 |
"""
|
| 176 |
chunk = []
|
| 177 |
current_line = 0
|
| 178 |
lines_read = 0
|
| 179 |
+
|
| 180 |
+
with open(self.data_file, "r", encoding="utf-8") as f:
|
| 181 |
for line in f:
|
| 182 |
if current_line < start_line:
|
| 183 |
current_line += 1
|
| 184 |
continue
|
| 185 |
+
|
| 186 |
text = line.strip()
|
| 187 |
if text: # Only include non-empty lines
|
| 188 |
chunk.append(text)
|
| 189 |
lines_read += 1
|
| 190 |
+
|
| 191 |
if lines_read >= self.chunk_size:
|
| 192 |
break
|
| 193 |
+
|
| 194 |
current_line += 1
|
| 195 |
+
|
| 196 |
return chunk
|
| 197 |
+
|
| 198 |
def _tokenize_texts(self, texts: List[str]) -> List[List[int]]:
|
| 199 |
"""
|
| 200 |
Tokenize a list of text passages using SentencePiece tokenizer.
|
| 201 |
+
|
| 202 |
This method converts raw text into token ID sequences suitable for language model training.
|
| 203 |
It handles special tokens (BOS/EOS) and length constraints for efficient training.
|
| 204 |
+
|
| 205 |
Text processing pipeline:
|
| 206 |
1. Add BOS (Beginning of Sequence) token to mark sequence start
|
| 207 |
2. Tokenize text using trained SentencePiece model (subword tokenization)
|
| 208 |
3. Truncate sequences that exceed maximum length
|
| 209 |
4. Add EOS (End of Sequence) token to mark sequence end
|
| 210 |
+
|
| 211 |
Special token handling:
|
| 212 |
- BOS token helps model learn to generate text from scratch
|
| 213 |
- EOS token signals natural sequence endings
|
| 214 |
- These tokens are crucial for proper autoregressive generation
|
| 215 |
+
|
| 216 |
Args:
|
| 217 |
texts: List of text passages (typically Wikipedia passages from SQUAD)
|
| 218 |
Each passage should be a complete, coherent text segment
|
| 219 |
+
|
| 220 |
Returns:
|
| 221 |
List of token ID sequences, where each sequence is a list of integers
|
| 222 |
representing subword tokens from the SentencePiece vocabulary
|
| 223 |
"""
|
| 224 |
tokenized = []
|
| 225 |
+
|
| 226 |
for text in texts:
|
| 227 |
try:
|
| 228 |
# Add BOS (Beginning of Sequence) token at the start
|
| 229 |
# BOS token ID=2 by default in SentencePiece, signals sequence start
|
| 230 |
# This helps the model learn proper sequence initialization during generation
|
| 231 |
tokens = [self.tokenizer.bos_id()] + self.tokenizer.encode(text)
|
| 232 |
+
|
| 233 |
# Truncate sequences that exceed maximum context length
|
| 234 |
# Reserve one position for EOS token by using (seq_len - 1)
|
| 235 |
# This ensures we never exceed the model's context window during training
|
| 236 |
if len(tokens) > self.seq_len - 1:
|
| 237 |
+
tokens = tokens[: self.seq_len - 1]
|
| 238 |
# NOTE: Truncation may cut off text mid-sentence, but this is acceptable
|
| 239 |
# for language modeling where the model learns from partial contexts
|
| 240 |
+
|
| 241 |
# Add EOS (End of Sequence) token at the end
|
| 242 |
# EOS token ID=1 by default in SentencePiece, signals sequence completion
|
| 243 |
# This teaches the model when to stop generating text naturally
|
| 244 |
tokens.append(self.tokenizer.eos_id())
|
| 245 |
+
|
| 246 |
# Validate tokenization result
|
| 247 |
if len(tokens) <= 2: # Only BOS + EOS tokens, no actual content
|
| 248 |
print(f"β οΈ Skipping very short text: {text[:50]}...")
|
| 249 |
continue
|
| 250 |
+
|
| 251 |
tokenized.append(tokens)
|
| 252 |
+
|
| 253 |
except Exception as e:
|
| 254 |
# Handle tokenization errors gracefully to avoid stopping training
|
| 255 |
# Common causes: encoding issues, very long texts, special characters
|
| 256 |
print(f"β οΈ Failed to tokenize passage: {text[:50]}... Error: {e}")
|
| 257 |
continue
|
| 258 |
+
|
| 259 |
# Log tokenization statistics for monitoring
|
| 260 |
if tokenized:
|
| 261 |
avg_length = sum(len(tokens) for tokens in tokenized) / len(tokenized)
|
| 262 |
print(f"π Tokenized {len(tokenized)} passages, avg length: {avg_length:.1f} tokens")
|
| 263 |
+
|
| 264 |
return tokenized
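
The BOS / truncate / EOS pipeline described in the docstring above can be exercised on its own. A minimal sketch, assuming a trained SentencePiece model at an illustrative path:

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.load("data/tokenizer/tokenizer.model")  # path is illustrative

    seq_len = 512
    text = "The history of artificial intelligence began in antiquity."

    # 1. BOS marks the start of the sequence
    tokens = [sp.bos_id()] + sp.encode(text)
    # 2. Reserve one slot for EOS so the sequence never exceeds seq_len
    if len(tokens) > seq_len - 1:
        tokens = tokens[: seq_len - 1]
    # 3. EOS teaches the model where generation should stop
    tokens.append(sp.eos_id())

    print(len(tokens), tokens[:8])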
|
| 265 |
+
|
| 266 |
+
def _create_training_examples(
|
| 267 |
+
self, token_sequences: List[List[int]]
|
| 268 |
+
) -> List[Tuple[List[int], List[int]]]:
|
| 269 |
"""
|
| 270 |
Create training examples with input and target sequences.
|
| 271 |
+
|
| 272 |
For autoregressive training, targets are inputs shifted by one position.
|
| 273 |
+
|
| 274 |
Args:
|
| 275 |
token_sequences: List of tokenized sequences
|
| 276 |
+
|
| 277 |
Returns:
|
| 278 |
List of (input_ids, target_ids) tuples
|
| 279 |
"""
|
| 280 |
examples = []
|
| 281 |
+
|
| 282 |
for tokens in token_sequences:
|
| 283 |
if len(tokens) < 2: # Need at least 2 tokens for input/target pair
|
| 284 |
continue
|
| 285 |
+
|
| 286 |
# For sequences longer than seq_len, create multiple examples with sliding window
|
| 287 |
if len(tokens) > self.seq_len:
|
| 288 |
# Create overlapping windows (50% overlap for better learning)
|
| 289 |
stride = self.seq_len // 2
|
| 290 |
for i in range(0, len(tokens) - self.seq_len, stride):
|
| 291 |
+
input_ids = tokens[i : i + self.seq_len]
|
| 292 |
+
target_ids = tokens[i + 1 : i + self.seq_len + 1]
|
| 293 |
examples.append((input_ids, target_ids))
|
| 294 |
else:
|
| 295 |
# Pad shorter sequences
|
| 296 |
input_ids = tokens[:-1] # All but last token
|
| 297 |
target_ids = tokens[1:] # All but first token
|
| 298 |
+
|
| 299 |
# Pad to seq_len if necessary
|
| 300 |
while len(input_ids) < self.seq_len:
|
| 301 |
input_ids.append(self.tokenizer.pad_id())
|
| 302 |
target_ids.append(-1) # Use -1 for padding in targets (ignored in loss)
|
| 303 |
+
|
| 304 |
# Truncate if still too long
|
| 305 |
+
input_ids = input_ids[: self.seq_len]
|
| 306 |
+
target_ids = target_ids[: self.seq_len]
|
| 307 |
+
|
| 308 |
examples.append((input_ids, target_ids))
|
| 309 |
+
|
| 310 |
return examples
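
The 50%-overlap windowing above reduces to a few lines. This standalone sketch (not the class method itself) shows how one long token sequence becomes several (input, target) pairs shifted by one position:

    def sliding_window_examples(tokens, seq_len):
        """Split one token list into overlapping next-token-prediction examples."""
        stride = seq_len // 2  # 50% overlap, as in _create_training_examples
        examples = []
        for i in range(0, len(tokens) - seq_len, stride):
            input_ids = tokens[i : i + seq_len]
            target_ids = tokens[i + 1 : i + seq_len + 1]  # shifted by one token
            examples.append((input_ids, target_ids))
        return examples

    demo_tokens = list(range(20))
    for inp, tgt in sliding_window_examples(demo_tokens, seq_len=8):
        print(inp, "->", tgt)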
|
| 311 |
+
|
| 312 |
+
def _create_batch(
|
| 313 |
+
self, examples: List[Tuple[List[int], List[int]]]
|
| 314 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 315 |
"""
|
| 316 |
Create a batch tensor from training examples.
|
| 317 |
+
|
| 318 |
Args:
|
| 319 |
examples: List of (input_ids, target_ids) tuples
|
| 320 |
+
|
| 321 |
Returns:
|
| 322 |
Tuple of (input_tensor, target_tensor)
|
| 323 |
"""
|
| 324 |
if not examples:
|
| 325 |
raise ValueError("Cannot create batch from empty examples")
|
| 326 |
+
|
| 327 |
batch_size = len(examples)
|
| 328 |
+
|
| 329 |
# Initialize tensors
|
| 330 |
input_ids = torch.zeros((batch_size, self.seq_len), dtype=torch.long)
|
| 331 |
target_ids = torch.full((batch_size, self.seq_len), -1, dtype=torch.long)
|
| 332 |
+
|
| 333 |
# Fill tensors
|
| 334 |
for i, (inp, tgt) in enumerate(examples):
|
| 335 |
+
input_ids[i, : len(inp)] = torch.tensor(inp, dtype=torch.long)
|
| 336 |
+
target_ids[i, : len(tgt)] = torch.tensor(tgt, dtype=torch.long)
|
| 337 |
+
|
| 338 |
return input_ids, target_ids
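
The -1 fill value used for padded target positions only works because the loss skips that index. A small sketch of the interaction, assuming the model's loss uses cross-entropy with ignore_index=-1 as the comments above imply:

    import torch
    import torch.nn.functional as F

    vocab_size, seq_len = 10, 6
    logits = torch.randn(2, seq_len, vocab_size)           # (batch, seq, vocab)
    targets = torch.tensor([[4, 2, 7, -1, -1, -1],
                            [1, 3, 3, 5, 9, 0]])            # -1 marks padded positions

    loss = F.cross_entropy(
        logits.view(-1, vocab_size),
        targets.view(-1),
        ignore_index=-1,   # padded targets contribute nothing to the loss
    )
    print(loss.item())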
|
| 339 |
+
|
| 340 |
def __iter__(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
|
| 341 |
"""
|
| 342 |
Iterate over training batches.
|
| 343 |
+
|
| 344 |
Yields:
|
| 345 |
Tuple of (input_ids, target_ids) tensors
|
| 346 |
"""
|
| 347 |
self.current_line = 0
|
| 348 |
+
|
| 349 |
while self.current_line < self.total_lines:
|
| 350 |
# Read chunk of text
|
| 351 |
texts = self._read_chunk(self.current_line)
|
| 352 |
if not texts:
|
| 353 |
break
|
| 354 |
+
|
| 355 |
# Tokenize texts
|
| 356 |
token_sequences = self._tokenize_texts(texts)
|
| 357 |
+
|
| 358 |
# Create training examples
|
| 359 |
examples = self._create_training_examples(token_sequences)
|
| 360 |
+
|
| 361 |
# Shuffle examples if requested
|
| 362 |
if self.shuffle:
|
| 363 |
random.shuffle(examples)
|
| 364 |
+
|
| 365 |
# Create batches
|
| 366 |
for i in range(0, len(examples), self.batch_size):
|
| 367 |
+
batch_examples = examples[i : i + self.batch_size]
|
| 368 |
+
|
| 369 |
if len(batch_examples) == self.batch_size: # Only yield full batches
|
| 370 |
try:
|
| 371 |
input_ids, target_ids = self._create_batch(batch_examples)
|
|
|
|
| 373 |
except Exception as e:
|
| 374 |
print(f"β οΈ Failed to create batch: {e}")
|
| 375 |
continue
|
| 376 |
+
|
| 377 |
# Update progress
|
| 378 |
self.current_line += len(texts)
|
| 379 |
+
|
| 380 |
# Clean up memory
|
| 381 |
del texts, token_sequences, examples
|
| 382 |
gc.collect()
|
| 383 |
+
|
| 384 |
def get_data_stats(self) -> dict:
|
| 385 |
"""
|
| 386 |
Get statistics about the training data.
|
| 387 |
+
|
| 388 |
Returns:
|
| 389 |
Dictionary with data statistics
|
| 390 |
"""
|
| 391 |
print("π Analyzing training data...")
|
| 392 |
+
|
| 393 |
# Sample some data to get statistics
|
| 394 |
sample_texts = self._read_chunk(0)[:100] # Sample first 100 passages
|
| 395 |
token_sequences = self._tokenize_texts(sample_texts)
|
| 396 |
+
|
| 397 |
if token_sequences:
|
| 398 |
sequence_lengths = [len(seq) for seq in token_sequences]
|
| 399 |
avg_length = sum(sequence_lengths) / len(sequence_lengths)
|
|
|
|
| 401 |
min_length = min(sequence_lengths)
|
| 402 |
else:
|
| 403 |
avg_length = max_length = min_length = 0
|
| 404 |
+
|
| 405 |
# Estimate total tokens
|
| 406 |
estimated_total_tokens = int(avg_length * self.total_lines)
|
| 407 |
+
|
| 408 |
# Estimate number of batches per epoch
|
| 409 |
examples_per_passage = max(1, avg_length // self.seq_len)
|
| 410 |
total_examples = int(self.total_lines * examples_per_passage)
|
| 411 |
batches_per_epoch = total_examples // self.batch_size
|
| 412 |
+
|
| 413 |
stats = {
|
| 414 |
"total_passages": self.total_lines,
|
| 415 |
"avg_tokens_per_passage": avg_length,
|
|
|
|
| 420 |
"estimated_batches_per_epoch": batches_per_epoch,
|
| 421 |
"sequence_length": self.seq_len,
|
| 422 |
"batch_size": self.batch_size,
|
| 423 |
+
"vocabulary_size": self.tokenizer.vocab_size(),
|
| 424 |
}
|
| 425 |
+
|
| 426 |
print(f"β Data analysis complete:")
|
| 427 |
print(f" Total passages: {stats['total_passages']:,}")
|
| 428 |
print(f" Avg tokens per passage: {stats['avg_tokens_per_passage']:.1f}")
|
| 429 |
print(f" Estimated total tokens: {stats['estimated_total_tokens']:,}")
|
| 430 |
print(f" Estimated batches per epoch: {stats['estimated_batches_per_epoch']:,}")
|
| 431 |
+
|
| 432 |
return stats
|
| 433 |
|
| 434 |
|
| 435 |
def test_data_loader():
|
| 436 |
"""Test function for the data loader."""
|
| 437 |
print("π§ͺ Testing TextDataLoader...")
|
| 438 |
+
|
| 439 |
# Test with small parameters
|
| 440 |
try:
|
| 441 |
loader = TextDataLoader(
|
|
|
|
| 443 |
tokenizer_path="data/tokenizer/tokenizer.model",
|
| 444 |
seq_len=128,
|
| 445 |
batch_size=2,
|
| 446 |
+
chunk_size=10, # Small for testing
|
| 447 |
)
|
| 448 |
+
|
| 449 |
# Get data statistics
|
| 450 |
stats = loader.get_data_stats()
|
| 451 |
+
|
| 452 |
# Test iteration
|
| 453 |
print("\nπ Testing batch iteration...")
|
| 454 |
start_time = time.time()
|
| 455 |
batch_count = 0
|
| 456 |
+
|
| 457 |
for batch_idx, (input_ids, target_ids) in enumerate(loader):
|
| 458 |
batch_count += 1
|
| 459 |
+
|
| 460 |
print(f"Batch {batch_idx + 1}:")
|
| 461 |
print(f" Input shape: {input_ids.shape}")
|
| 462 |
print(f" Target shape: {target_ids.shape}")
|
| 463 |
print(f" Sample input tokens: {input_ids[0][:10].tolist()}")
|
| 464 |
print(f" Sample target tokens: {target_ids[0][:10].tolist()}")
|
| 465 |
+
|
| 466 |
if batch_idx >= 2: # Only test first few batches
|
| 467 |
break
|
| 468 |
+
|
| 469 |
test_time = time.time() - start_time
|
| 470 |
print(f"\nβ Data loader test completed successfully!")
|
| 471 |
print(f" Processed {batch_count} batches in {test_time:.2f}s")
|
| 472 |
print(f" Average time per batch: {test_time/max(1, batch_count):.2f}s")
|
| 473 |
+
|
| 474 |
return True
|
| 475 |
+
|
| 476 |
except Exception as e:
|
| 477 |
print(f"β Data loader test failed: {e}")
|
| 478 |
import traceback
|
| 479 |
+
|
| 480 |
traceback.print_exc()
|
| 481 |
return False
|
| 482 |
|
| 483 |
|
| 484 |
if __name__ == "__main__":
|
| 485 |
+
test_data_loader()
|
training/evaluate_model.py
CHANGED
|
@@ -56,20 +56,15 @@ from data_loader import TextDataLoader
|
|
| 56 |
class ModelEvaluator:
|
| 57 |
"""
|
| 58 |
Comprehensive evaluator for OpenLLM models.
|
| 59 |
-
|
| 60 |
Implements intrinsic evaluation metrics and text generation quality
|
| 61 |
assessment following the training pipeline specifications.
|
| 62 |
"""
|
| 63 |
-
|
| 64 |
-
def __init__(
|
| 65 |
-
self,
|
| 66 |
-
model: GPTModel,
|
| 67 |
-
tokenizer_path: str,
|
| 68 |
-
device: str = "cpu"
|
| 69 |
-
):
|
| 70 |
"""
|
| 71 |
Initialize the model evaluator.
|
| 72 |
-
|
| 73 |
Args:
|
| 74 |
model: Trained GPT model
|
| 75 |
tokenizer_path: Path to tokenizer model file
|
|
@@ -77,30 +72,27 @@ class ModelEvaluator:
|
|
| 77 |
"""
|
| 78 |
self.model = model.to(device)
|
| 79 |
self.device = device
|
| 80 |
-
|
| 81 |
# Load tokenizer
|
| 82 |
self.tokenizer = smp.SentencePieceProcessor()
|
| 83 |
self.tokenizer.load(tokenizer_path)
|
| 84 |
-
|
| 85 |
print(f"π§ ModelEvaluator initialized")
|
| 86 |
print(f" Device: {device}")
|
| 87 |
print(f" Model parameters: {model.get_num_params():,}")
|
| 88 |
print(f" Vocabulary size: {self.tokenizer.vocab_size():,}")
|
| 89 |
-
|
| 90 |
def evaluate_perplexity(
|
| 91 |
-
self,
|
| 92 |
-
eval_data: List[str],
|
| 93 |
-
max_seq_len: int = 512,
|
| 94 |
-
batch_size: int = 1
|
| 95 |
) -> Dict[str, float]:
|
| 96 |
"""
|
| 97 |
Calculate perplexity on evaluation data.
|
| 98 |
-
|
| 99 |
Args:
|
| 100 |
eval_data: List of text passages for evaluation
|
| 101 |
max_seq_len: Maximum sequence length for evaluation
|
| 102 |
batch_size: Batch size for evaluation
|
| 103 |
-
|
| 104 |
Returns:
|
| 105 |
Dictionary with loss and perplexity metrics
|
| 106 |
"""
|
|
@@ -108,400 +100,394 @@ class ModelEvaluator:
|
|
| 108 |
total_loss = 0.0
|
| 109 |
total_tokens = 0
|
| 110 |
num_sequences = 0
|
| 111 |
-
|
| 112 |
print(f"π Calculating perplexity on {len(eval_data)} passages...")
|
| 113 |
-
|
| 114 |
with torch.no_grad():
|
| 115 |
for i, text in enumerate(eval_data):
|
| 116 |
if i % 100 == 0:
|
| 117 |
print(f" Progress: {i}/{len(eval_data)} passages")
|
| 118 |
-
|
| 119 |
# Tokenize text
|
| 120 |
tokens = self.tokenizer.encode(text)
|
| 121 |
if len(tokens) < 2:
|
| 122 |
continue
|
| 123 |
-
|
| 124 |
# Truncate if too long
|
| 125 |
if len(tokens) > max_seq_len:
|
| 126 |
tokens = tokens[:max_seq_len]
|
| 127 |
-
|
| 128 |
# Create input and target tensors
|
| 129 |
input_ids = torch.tensor([tokens[:-1]], dtype=torch.long, device=self.device)
|
| 130 |
target_ids = torch.tensor([tokens[1:]], dtype=torch.long, device=self.device)
|
| 131 |
-
|
| 132 |
# Forward pass
|
| 133 |
logits, loss = self.model(input_ids, target_ids)
|
| 134 |
-
|
| 135 |
# Accumulate loss
|
| 136 |
seq_length = len(tokens) - 1
|
| 137 |
total_loss += loss.item() * seq_length
|
| 138 |
total_tokens += seq_length
|
| 139 |
num_sequences += 1
|
| 140 |
-
|
| 141 |
# Calculate metrics
|
| 142 |
-
avg_loss = total_loss / total_tokens if total_tokens > 0 else float("inf")
|
| 143 |
perplexity = math.exp(min(avg_loss, 10)) # Cap to prevent overflow
|
| 144 |
-
|
| 145 |
return {
    "loss": avg_loss,
    "perplexity": perplexity,
    "total_tokens": total_tokens,
    "num_sequences": num_sequences,
}
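
Put differently, the metric computed above is the exponential of the average per-token cross-entropy. A tiny standalone example of the same arithmetic (the cap at 10 mirrors the overflow guard above; the numbers are made up):

    import math

    # (loss, token_count) pairs accumulated over evaluation sequences
    per_sequence = [(2.1, 100), (1.8, 250), (2.4, 75)]

    total_loss = sum(loss * n for loss, n in per_sequence)
    total_tokens = sum(n for _, n in per_sequence)

    avg_loss = total_loss / total_tokens if total_tokens > 0 else float("inf")
    perplexity = math.exp(min(avg_loss, 10))  # cap avoids overflow on bad models
    print(f"avg loss {avg_loss:.3f} -> perplexity {perplexity:.2f}")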
|
| 151 |
-
|
| 152 |
def evaluate_text_generation(
|
| 153 |
self,
|
| 154 |
prompts: List[str],
|
| 155 |
max_length: int = 256,
|
| 156 |
temperature: float = 0.7,
|
| 157 |
top_k: Optional[int] = 40,
|
| 158 |
-
num_samples: int = 1
|
| 159 |
) -> List[Dict[str, Any]]:
|
| 160 |
"""
|
| 161 |
Evaluate text generation quality.
|
| 162 |
-
|
| 163 |
Args:
|
| 164 |
prompts: List of input prompts
|
| 165 |
max_length: Maximum generation length
|
| 166 |
temperature: Sampling temperature
|
| 167 |
top_k: Top-k sampling parameter
|
| 168 |
num_samples: Number of samples per prompt
|
| 169 |
-
|
| 170 |
Returns:
|
| 171 |
List of generation results with quality metrics
|
| 172 |
"""
|
| 173 |
self.model.eval()
|
| 174 |
results = []
|
| 175 |
-
|
| 176 |
print(f"βοΈ Evaluating text generation on {len(prompts)} prompts...")
|
| 177 |
-
|
| 178 |
with torch.no_grad():
|
| 179 |
for prompt in prompts:
|
| 180 |
prompt_results = []
|
| 181 |
-
|
| 182 |
for sample_idx in range(num_samples):
|
| 183 |
# Tokenize prompt
|
| 184 |
input_ids = self.tokenizer.encode(prompt)
|
| 185 |
input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
|
| 186 |
-
|
| 187 |
start_time = time.time()
|
| 188 |
-
|
| 189 |
# Generate text
|
| 190 |
output = self.model.generate(
|
| 191 |
input_tensor,
|
| 192 |
max_new_tokens=max_length,
|
| 193 |
temperature=temperature,
|
| 194 |
-
top_k=top_k
|
| 195 |
)
|
| 196 |
-
|
| 197 |
generation_time = time.time() - start_time
|
| 198 |
-
|
| 199 |
# Decode output
|
| 200 |
generated_ids = output[0].tolist()
|
| 201 |
full_text = self.tokenizer.decode(generated_ids)
|
| 202 |
-
generated_text = self.tokenizer.decode(generated_ids[len(input_ids):])
|
| 203 |
-
|
| 204 |
# Calculate quality metrics
|
| 205 |
quality_metrics = self._assess_generation_quality(generated_text)
|
| 206 |
-
|
| 207 |
-
prompt_results.append(
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
|
|
|
|
|
|
|
|
|
| 217 |
results.extend(prompt_results)
|
| 218 |
-
|
| 219 |
return results
|
| 220 |
-
|
| 221 |
def _assess_generation_quality(self, text: str) -> Dict[str, float]:
|
| 222 |
"""
|
| 223 |
Assess basic quality metrics for generated text.
|
| 224 |
-
|
| 225 |
Args:
|
| 226 |
text: Generated text to assess
|
| 227 |
-
|
| 228 |
Returns:
|
| 229 |
Dictionary of quality metrics
|
| 230 |
"""
|
| 231 |
if not text.strip():
|
| 232 |
return {
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
}
|
| 238 |
-
|
| 239 |
words = text.split()
|
| 240 |
-
|
| 241 |
# Basic metrics
|
| 242 |
length = len(words)
|
| 243 |
avg_word_length = sum(len(word) for word in words) / len(words) if words else 0
|
| 244 |
-
|
| 245 |
# Repetition rate (simple n-gram repetition)
|
| 246 |
-
bigrams = [f"{words[i]} {words[i+1]}" for i in range(len(words)-1)]
|
| 247 |
unique_bigrams = len(set(bigrams))
|
| 248 |
repetition_rate = 1 - (unique_bigrams / len(bigrams) if bigrams else 0)
|
| 249 |
-
|
| 250 |
# Simple coherence score (based on sentence structure)
|
| 251 |
-
sentences = text.split(".")
|
| 252 |
valid_sentences = [s for s in sentences if len(s.strip().split()) > 3]
|
| 253 |
coherence_score = len(valid_sentences) / len(sentences) if sentences else 0
|
| 254 |
-
|
| 255 |
return {
    "length": length,
    "avg_word_length": avg_word_length,
    "repetition_rate": repetition_rate,
    "coherence_score": coherence_score,
}
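
The bigram repetition rate and the sentence-length coherence heuristic used above can be reproduced in isolation. This sketch mirrors that logic on a toy string:

    text = "the model repeats itself. the model repeats itself. results vary a lot."

    words = text.split()
    bigrams = [f"{words[i]} {words[i+1]}" for i in range(len(words) - 1)]
    repetition_rate = 1 - (len(set(bigrams)) / len(bigrams) if bigrams else 0)

    sentences = text.split(".")
    valid = [s for s in sentences if len(s.strip().split()) > 3]
    coherence_score = len(valid) / len(sentences) if sentences else 0

    print(f"repetition {repetition_rate:.2f}, coherence {coherence_score:.2f}")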
|
| 261 |
-
|
| 262 |
def evaluate_downstream_tasks(self) -> Dict[str, Any]:
|
| 263 |
"""
|
| 264 |
Evaluate model performance on downstream tasks.
|
| 265 |
-
|
| 266 |
This function implements basic downstream task evaluation including:
|
| 267 |
- Reading comprehension (simplified SQUAD-style)
|
| 268 |
- Sentiment analysis (few-shot)
|
| 269 |
- Common sense reasoning
|
| 270 |
-
|
| 271 |
Returns:
|
| 272 |
Dictionary of downstream task results
|
| 273 |
"""
|
| 274 |
results = {}
|
| 275 |
-
|
| 276 |
# 1. Reading Comprehension (Simplified SQUAD-style)
|
| 277 |
-
results[
|
| 278 |
-
|
| 279 |
# 2. Sentiment Analysis (Few-shot learning)
|
| 280 |
-
results[
|
| 281 |
-
|
| 282 |
# 3. Common Sense Reasoning
|
| 283 |
-
results[
|
| 284 |
-
|
| 285 |
# 4. Text Completion Quality
|
| 286 |
-
results[
|
| 287 |
-
|
| 288 |
return results
|
| 289 |
-
|
| 290 |
def _evaluate_reading_comprehension(self) -> Dict[str, Any]:
|
| 291 |
"""Simplified reading comprehension evaluation."""
|
| 292 |
# Sample reading comprehension tasks
|
| 293 |
tasks = [
|
| 294 |
{
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
},
|
| 299 |
{
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
},
|
| 304 |
{
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
}
|
| 309 |
]
|
| 310 |
-
|
| 311 |
correct = 0
|
| 312 |
total = len(tasks)
|
| 313 |
-
|
| 314 |
for task in tasks:
|
| 315 |
prompt = f"Context: {task['context']}\nQuestion: {task['question']}\nAnswer:"
|
| 316 |
-
|
| 317 |
# Generate answer
|
| 318 |
input_ids = self.tokenizer.encode(prompt)
|
| 319 |
input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
|
| 320 |
-
|
| 321 |
with torch.no_grad():
|
| 322 |
output = self.model.generate(input_tensor, max_new_tokens=20, temperature=0.1)
|
| 323 |
-
|
| 324 |
generated_ids = output[0].tolist()
|
| 325 |
-
answer = self.tokenizer.decode(generated_ids[len(input_ids):]).strip().lower()
|
| 326 |
-
|
| 327 |
# Simple substring matching
|
| 328 |
-
if task[
|
| 329 |
correct += 1
|
| 330 |
-
|
| 331 |
return {
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
}
|
| 337 |
-
|
| 338 |
def _evaluate_sentiment_analysis(self) -> Dict[str, Any]:
|
| 339 |
"""Few-shot sentiment analysis evaluation."""
|
| 340 |
# Few-shot examples
|
| 341 |
examples = "Examples:\nText: 'I love this movie!' Sentiment: Positive\nText: 'This is terrible.' Sentiment: Negative\nText: 'It was okay.' Sentiment: Neutral\n\n"
|
| 342 |
-
|
| 343 |
# Test cases
|
| 344 |
test_cases = [
|
| 345 |
-
{
|
| 346 |
-
{
|
| 347 |
-
{
|
| 348 |
-
{
|
| 349 |
-
{
|
| 350 |
]
|
| 351 |
-
|
| 352 |
correct = 0
|
| 353 |
total = len(test_cases)
|
| 354 |
-
|
| 355 |
for case in test_cases:
|
| 356 |
prompt = f"{examples}Text: '{case['text']}' Sentiment:"
|
| 357 |
-
|
| 358 |
# Generate sentiment
|
| 359 |
input_ids = self.tokenizer.encode(prompt)
|
| 360 |
input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
|
| 361 |
-
|
| 362 |
with torch.no_grad():
|
| 363 |
output = self.model.generate(input_tensor, max_new_tokens=5, temperature=0.1)
|
| 364 |
-
|
| 365 |
generated_ids = output[0].tolist()
|
| 366 |
-
sentiment = self.tokenizer.decode(generated_ids[len(input_ids):]).strip().lower()
|
| 367 |
-
|
| 368 |
# Check if expected sentiment is in the generated response
|
| 369 |
-
if case[
|
| 370 |
correct += 1
|
| 371 |
-
|
| 372 |
return {
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
}
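
The few-shot pattern used here (worked examples, then the new input, then substring matching on the generation) is easy to see in isolation. The test text, the "expected" key, and the stubbed generation below are illustrative only:

    few_shot = (
        "Examples:\n"
        "Text: 'I love this movie!' Sentiment: Positive\n"
        "Text: 'This is terrible.' Sentiment: Negative\n"
        "Text: 'It was okay.' Sentiment: Neutral\n\n"
    )

    case = {"text": "Best purchase I have made all year.", "expected": "positive"}
    prompt = f"{few_shot}Text: '{case['text']}' Sentiment:"

    # In the evaluator, `generated` would come from model.generate(); here it is stubbed.
    generated = " Positive"
    correct = case["expected"] in generated.strip().lower()
    print(prompt)
    print("correct:", correct)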
|
| 378 |
-
|
| 379 |
def _evaluate_reasoning(self) -> Dict[str, Any]:
|
| 380 |
"""Simple reasoning evaluation."""
|
| 381 |
# Basic reasoning tasks
|
| 382 |
tasks = [
|
| 383 |
{
|
| 384 |
-
|
| 385 |
-
|
| 386 |
},
|
| 387 |
{
|
| 388 |
-
|
| 389 |
-
|
| 390 |
},
|
| 391 |
-
{
|
| 392 |
-
|
| 393 |
-
'expected': 'tuesday'
|
| 394 |
-
},
|
| 395 |
-
{
|
| 396 |
-
'question': 'Is the sun larger than the earth?',
|
| 397 |
-
'expected': 'yes'
|
| 398 |
-
}
|
| 399 |
]
|
| 400 |
-
|
| 401 |
correct = 0
|
| 402 |
total = len(tasks)
|
| 403 |
-
|
| 404 |
for task in tasks:
|
| 405 |
prompt = f"Question: {task['question']}\nAnswer:"
|
| 406 |
-
|
| 407 |
# Generate answer
|
| 408 |
input_ids = self.tokenizer.encode(prompt)
|
| 409 |
input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
|
| 410 |
-
|
| 411 |
with torch.no_grad():
|
| 412 |
output = self.model.generate(input_tensor, max_new_tokens=10, temperature=0.1)
|
| 413 |
-
|
| 414 |
generated_ids = output[0].tolist()
|
| 415 |
-
answer = self.tokenizer.decode(generated_ids[len(input_ids):]).strip().lower()
|
| 416 |
-
|
| 417 |
# Check if expected answer is in the response
|
| 418 |
-
if task["expected"] in answer:
|
| 419 |
correct += 1
|
| 420 |
-
|
| 421 |
return {
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
}
|
| 427 |
-
|
| 428 |
def _evaluate_text_completion(self) -> Dict[str, Any]:
|
| 429 |
"""Evaluate text completion quality."""
|
| 430 |
# Common phrases that should be completed predictably
|
| 431 |
completions = [
|
| 432 |
-
{
|
| 433 |
-
{
|
| 434 |
-
{
|
| 435 |
-
{
|
| 436 |
]
|
| 437 |
-
|
| 438 |
correct = 0
|
| 439 |
total = len(completions)
|
| 440 |
-
|
| 441 |
for completion in completions:
|
| 442 |
# Generate completion
|
| 443 |
-
input_ids = self.tokenizer.encode(completion[
|
| 444 |
input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
|
| 445 |
-
|
| 446 |
with torch.no_grad():
|
| 447 |
output = self.model.generate(input_tensor, max_new_tokens=5, temperature=0.1)
|
| 448 |
-
|
| 449 |
generated_ids = output[0].tolist()
|
| 450 |
-
generated_text = self.tokenizer.decode(generated_ids[len(input_ids):]).strip().lower()
|
| 451 |
-
|
| 452 |
# Check if expected word appears in completion
|
| 453 |
-
if completion[
|
| 454 |
correct += 1
|
| 455 |
-
|
| 456 |
return {
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
}
|
| 462 |
-
|
| 463 |
def run_comprehensive_evaluation(
|
| 464 |
-
self,
|
| 465 |
-
eval_data_path: str,
|
| 466 |
-
metrics: List[str] = None,
|
| 467 |
-
generation_prompts: List[str] = None
|
| 468 |
) -> Dict[str, Any]:
|
| 469 |
"""
|
| 470 |
Run comprehensive model evaluation.
|
| 471 |
-
|
| 472 |
Args:
|
| 473 |
eval_data_path: Path to evaluation text file
|
| 474 |
metrics: List of metrics to compute
|
| 475 |
generation_prompts: Prompts for text generation evaluation
|
| 476 |
-
|
| 477 |
Returns:
|
| 478 |
Complete evaluation results
|
| 479 |
"""
|
| 480 |
if metrics is None:
|
| 481 |
-
metrics = ["perplexity", "loss", "generation"]
|
| 482 |
-
|
| 483 |
if generation_prompts is None:
|
| 484 |
generation_prompts = [
|
| 485 |
"The history of artificial intelligence",
|
| 486 |
"Machine learning algorithms",
|
| 487 |
"The future of technology",
|
| 488 |
"In a world where",
|
| 489 |
-
"Scientists have discovered"
|
| 490 |
]
|
| 491 |
-
|
| 492 |
results = {
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
},
|
| 498 |
-
|
| 499 |
}
|
| 500 |
-
|
| 501 |
# Load evaluation data
|
| 502 |
print(f"π Loading evaluation data from {eval_data_path}")
|
| 503 |
if os.path.exists(eval_data_path):
|
| 504 |
-
with open(eval_data_path, "r", encoding="utf-8") as f:
|
| 505 |
eval_texts = [line.strip() for line in f if line.strip()]
|
| 506 |
else:
|
| 507 |
print(f"β οΈ Evaluation file not found, using sample texts")
|
|
@@ -510,96 +496,103 @@ class ModelEvaluator:
|
|
| 510 |
"Machine learning algorithms can learn patterns from data automatically.",
|
| 511 |
"Natural language processing helps computers understand human language.",
|
| 512 |
"Deep learning uses neural networks with multiple layers for complex tasks.",
|
| 513 |
-
"The development of large language models has transformed AI applications."
|
| 514 |
]
|
| 515 |
-
|
| 516 |
# Intrinsic evaluation
|
| 517 |
-
if
|
| 518 |
perplexity_results = self.evaluate_perplexity(eval_texts)
|
| 519 |
-
results[
|
| 520 |
-
|
| 521 |
# Text generation evaluation
|
| 522 |
-
if
|
| 523 |
generation_results = self.evaluate_text_generation(generation_prompts)
|
| 524 |
-
results[
|
| 525 |
-
|
| 526 |
-
|
| 527 |
}
|
| 528 |
-
|
| 529 |
# Downstream tasks (placeholder)
|
| 530 |
-
results[
|
| 531 |
-
|
| 532 |
# Overall quality assessment
|
| 533 |
-
results[
|
| 534 |
-
|
| 535 |
return results
|
| 536 |
-
|
| 537 |
def _summarize_generation_results(self, results: List[Dict[str, Any]]) -> Dict[str, float]:
|
| 538 |
"""Summarize text generation results."""
|
| 539 |
if not results:
|
| 540 |
return {}
|
| 541 |
-
|
| 542 |
-
total_time = sum(r[
|
| 543 |
-
total_tokens = sum(r[
|
| 544 |
-
|
| 545 |
-
quality_metrics = [r[
|
| 546 |
-
|
| 547 |
return {
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
|
|
|
|
|
|
| 553 |
}
|
| 554 |
-
|
| 555 |
def _assess_overall_quality(self, results: Dict[str, Any]) -> Dict[str, Any]:
|
| 556 |
"""Assess overall model quality based on evaluation results."""
|
| 557 |
-
assessment = {
|
| 558 |
-
|
| 559 |
-
'recommendations': []
|
| 560 |
-
}
|
| 561 |
-
|
| 562 |
# Check intrinsic metrics
|
| 563 |
-
if
|
| 564 |
-
perplexity = results[
|
| 565 |
-
|
| 566 |
if perplexity < 12:
|
| 567 |
-
assessment[
|
| 568 |
-
assessment[
|
| 569 |
elif perplexity < 50:
|
| 570 |
-
assessment[
|
| 571 |
-
assessment[
|
|
|
|
|
|
|
| 572 |
else:
|
| 573 |
-
assessment[
|
| 574 |
-
assessment[
|
| 575 |
-
|
|
|
|
|
|
|
| 576 |
# Check generation quality
|
| 577 |
-
if
|
| 578 |
-
summary = results[
|
| 579 |
-
repetition_rate = summary.get(
|
| 580 |
-
coherence_score = summary.get(
|
| 581 |
-
|
| 582 |
if repetition_rate > 0.7:
|
| 583 |
-
assessment[
|
|
|
|
|
|
|
| 584 |
if coherence_score < 0.3:
|
| 585 |
-
assessment[
|
| 586 |
-
|
|
|
|
|
|
|
| 587 |
return assessment
|
| 588 |
|
| 589 |
|
| 590 |
def load_model_from_directory(model_dir: str, device: str = "cpu") -> Tuple[GPTModel, str]:
|
| 591 |
"""
|
| 592 |
Load model from directory containing checkpoints.
|
| 593 |
-
|
| 594 |
Args:
|
| 595 |
model_dir: Directory containing model files
|
| 596 |
device: Device to load model on
|
| 597 |
-
|
| 598 |
Returns:
|
| 599 |
Tuple of (model, tokenizer_path)
|
| 600 |
"""
|
| 601 |
model_dir = Path(model_dir)
|
| 602 |
-
|
| 603 |
# Find best model checkpoint
|
| 604 |
best_model_path = model_dir / "best_model.pt"
|
| 605 |
if not best_model_path.exists():
|
|
@@ -607,41 +600,41 @@ def load_model_from_directory(model_dir: str, device: str = "cpu") -> Tuple[GPTM
|
|
| 607 |
checkpoints = list(model_dir.glob("checkpoint_step_*.pt"))
|
| 608 |
if not checkpoints:
|
| 609 |
raise FileNotFoundError(f"No model checkpoints found in {model_dir}")
|
| 610 |
-
|
| 611 |
# Get latest checkpoint
|
| 612 |
-
latest_checkpoint = max(checkpoints, key=lambda p: int(p.stem.split("_")[-1]))
|
| 613 |
best_model_path = latest_checkpoint
|
| 614 |
-
|
| 615 |
print(f"π Loading model from {best_model_path}")
|
| 616 |
-
|
| 617 |
# Load checkpoint
|
| 618 |
checkpoint = torch.load(best_model_path, map_location=device)
|
| 619 |
-
|
| 620 |
# Determine model size from config
|
| 621 |
-
config = checkpoint.get(
|
| 622 |
-
n_layer = config.get(
|
| 623 |
-
|
| 624 |
if n_layer <= 6:
|
| 625 |
model_size = "small"
|
| 626 |
elif n_layer <= 12:
|
| 627 |
model_size = "medium"
|
| 628 |
else:
|
| 629 |
model_size = "large"
|
| 630 |
-
|
| 631 |
# Create and load model
|
| 632 |
model = create_model(model_size)
|
| 633 |
-
model.load_state_dict(checkpoint[
|
| 634 |
-
|
| 635 |
print(f"β
Model loaded successfully ({model_size}, {model.get_num_params():,} parameters)")
|
| 636 |
-
|
| 637 |
# Find tokenizer
|
| 638 |
tokenizer_path = model_dir.parent / "tokenizer" / "tokenizer.model"
|
| 639 |
if not tokenizer_path.exists():
|
| 640 |
tokenizer_path = Path("data/tokenizer/tokenizer.model")
|
| 641 |
-
|
| 642 |
if not tokenizer_path.exists():
|
| 643 |
raise FileNotFoundError(f"Tokenizer not found at {tokenizer_path}")
|
| 644 |
-
|
| 645 |
return model, str(tokenizer_path)
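
A typical call into the loader above, paired with the evaluator, looks like this. The model directory and data path are examples, and the metrics list mirrors the CLI default shown further down:

    # Illustrative usage of the helpers above (paths are examples)
    model, tokenizer_path = load_model_from_directory("models/small-extended-4k", device="cpu")
    evaluator = ModelEvaluator(model, tokenizer_path, device="cpu")

    results = evaluator.run_comprehensive_evaluation(
        eval_data_path="data/clean/training_data.txt",
        metrics=["perplexity", "loss", "generation"],
    )
    print(sorted(results.keys()))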
|
| 646 |
|
| 647 |
|
|
@@ -662,121 +655,115 @@ Examples:
|
|
| 662 |
--model_dir models/small-extended-4k \\
|
| 663 |
--metrics perplexity,generation \\
|
| 664 |
--output results.json
|
| 665 |
-
"""
|
| 666 |
)
|
| 667 |
-
|
| 668 |
-
parser.add_argument(
|
| 669 |
-
|
| 670 |
-
required=True,
|
| 671 |
-
help="Directory containing trained model"
|
| 672 |
-
)
|
| 673 |
-
|
| 674 |
parser.add_argument(
|
| 675 |
-
"--eval_data",
|
| 676 |
-
help="Path to evaluation text file (default: use sample texts)"
|
| 677 |
)
|
| 678 |
-
|
| 679 |
parser.add_argument(
|
| 680 |
"--metrics",
|
| 681 |
default="perplexity,loss,generation",
|
| 682 |
-
help="Comma-separated list of metrics to evaluate (default: perplexity,loss,generation)"
|
| 683 |
)
|
| 684 |
-
|
| 685 |
-
parser.add_argument(
|
| 686 |
-
|
| 687 |
-
help="Output JSON file for results (default: print to console)"
|
| 688 |
-
)
|
| 689 |
-
|
| 690 |
parser.add_argument(
|
| 691 |
"--device",
|
| 692 |
choices=["cpu", "cuda", "auto"],
|
| 693 |
default="auto",
|
| 694 |
-
help="Device for evaluation (default: auto)"
|
| 695 |
)
|
| 696 |
-
|
| 697 |
parser.add_argument(
|
| 698 |
-
"--generation_prompts",
|
| 699 |
-
help="File containing prompts for text generation evaluation"
|
| 700 |
)
|
| 701 |
-
|
| 702 |
args = parser.parse_args()
|
| 703 |
-
|
| 704 |
print("π OpenLLM Model Evaluation")
|
| 705 |
print("=" * 50)
|
| 706 |
-
|
| 707 |
# Determine device
|
| 708 |
if args.device == "auto":
|
| 709 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 710 |
else:
|
| 711 |
device = args.device
|
| 712 |
-
|
| 713 |
print(f"Using device: {device}")
|
| 714 |
-
|
| 715 |
try:
|
| 716 |
# Load model
|
| 717 |
model, tokenizer_path = load_model_from_directory(args.model_dir, device)
|
| 718 |
-
|
| 719 |
# Create evaluator
|
| 720 |
evaluator = ModelEvaluator(model, tokenizer_path, device)
|
| 721 |
-
|
| 722 |
# Parse metrics
|
| 723 |
-
metrics = [m.strip() for m in args.metrics.split(",")]
|
| 724 |
-
|
| 725 |
# Load generation prompts if specified
|
| 726 |
generation_prompts = None
|
| 727 |
if args.generation_prompts and os.path.exists(args.generation_prompts):
|
| 728 |
-
with open(args.generation_prompts, "r", encoding="utf-8") as f:
|
| 729 |
generation_prompts = [line.strip() for line in f if line.strip()]
|
| 730 |
-
|
| 731 |
# Run evaluation
|
| 732 |
eval_data_path = args.eval_data or "data/clean/training_data.txt"
|
| 733 |
results = evaluator.run_comprehensive_evaluation(
|
| 734 |
eval_data_path, metrics, generation_prompts
|
| 735 |
)
|
| 736 |
-
|
| 737 |
# Output results
|
| 738 |
if args.output:
|
| 739 |
-
with open(args.output, "w", encoding="utf-8") as f:
|
| 740 |
json.dump(results, f, indent=2)
|
| 741 |
print(f"\nπΎ Results saved to {args.output}")
|
| 742 |
else:
|
| 743 |
print(f"\nπ Evaluation Results:")
|
| 744 |
print("=" * 50)
|
| 745 |
-
|
| 746 |
# Print key metrics
|
| 747 |
-
if
|
| 748 |
-
intrinsic = results[
|
| 749 |
print(f"π Intrinsic Metrics:")
|
| 750 |
print(f" Loss: {intrinsic['loss']:.4f}")
|
| 751 |
print(f" Perplexity: {intrinsic['perplexity']:.2f}")
|
| 752 |
print(f" Sequences evaluated: {intrinsic['num_sequences']:,}")
|
| 753 |
-
|
| 754 |
-
if
|
| 755 |
-
gen_summary = results[
|
| 756 |
print(f"\nβοΈ Generation Quality:")
|
| 757 |
-
print(
|
|
|
|
|
|
|
| 758 |
print(f" Avg text length: {gen_summary['avg_length']:.1f} words")
|
| 759 |
print(f" Repetition rate: {gen_summary['avg_repetition_rate']:.3f}")
|
| 760 |
print(f" Coherence score: {gen_summary['avg_coherence_score']:.3f}")
|
| 761 |
-
|
| 762 |
# Quality assessment
|
| 763 |
-
if 'quality_assessment' in results:
|
| 764 |
-
assessment = results['quality_assessment']
|
| 765 |
print(f"\nπ― Overall Assessment:")
|
| 766 |
print(f" Quality Level: {assessment['quality_level'].upper()}")
|
| 767 |
-
for rec in assessment['recommendations']:
|
| 768 |
print(f" β’ {rec}")
|
| 769 |
-
|
| 770 |
print(f"\nπ Evaluation completed successfully!")
|
| 771 |
-
|
| 772 |
except Exception as e:
|
| 773 |
print(f"\nβ Evaluation failed: {e}")
|
| 774 |
import traceback
|
| 775 |
traceback.print_exc()
|
| 776 |
return False
|
| 777 |
-
|
| 778 |
return True
|
| 779 |
|
| 780 |
|
| 781 |
if __name__ == "__main__":
|
| 782 |
-
main()
|
| 56 |
class ModelEvaluator:
|
| 57 |
"""
|
| 58 |
Comprehensive evaluator for OpenLLM models.
|
| 59 |
+
|
| 60 |
Implements intrinsic evaluation metrics and text generation quality
|
| 61 |
assessment following the training pipeline specifications.
|
| 62 |
"""
|
| 63 |
+
|
| 64 |
+
def __init__(self, model: GPTModel, tokenizer_path: str, device: str = "cpu"):
|
| 65 |
"""
|
| 66 |
Initialize the model evaluator.
|
| 67 |
+
|
| 68 |
Args:
|
| 69 |
model: Trained GPT model
|
| 70 |
tokenizer_path: Path to tokenizer model file
|
| 71 |
device: Device to run evaluation on
 |
| 72 |
"""
|
| 73 |
self.model = model.to(device)
|
| 74 |
self.device = device
|
| 75 |
+
|
| 76 |
# Load tokenizer
|
| 77 |
self.tokenizer = smp.SentencePieceProcessor()
|
| 78 |
self.tokenizer.load(tokenizer_path)
|
| 79 |
+
|
| 80 |
print(f"π§ ModelEvaluator initialized")
|
| 81 |
print(f" Device: {device}")
|
| 82 |
print(f" Model parameters: {model.get_num_params():,}")
|
| 83 |
print(f" Vocabulary size: {self.tokenizer.vocab_size():,}")
|
| 84 |
+
|
| 85 |
def evaluate_perplexity(
|
| 86 |
+
self, eval_data: List[str], max_seq_len: int = 512, batch_size: int = 1
|
| 87 |
) -> Dict[str, float]:
|
| 88 |
"""
|
| 89 |
Calculate perplexity on evaluation data.
|
| 90 |
+
|
| 91 |
Args:
|
| 92 |
eval_data: List of text passages for evaluation
|
| 93 |
max_seq_len: Maximum sequence length for evaluation
|
| 94 |
batch_size: Batch size for evaluation
|
| 95 |
+
|
| 96 |
Returns:
|
| 97 |
Dictionary with loss and perplexity metrics
|
| 98 |
"""
|
| 100 |
total_loss = 0.0
|
| 101 |
total_tokens = 0
|
| 102 |
num_sequences = 0
|
| 103 |
+
|
| 104 |
print(f"π Calculating perplexity on {len(eval_data)} passages...")
|
| 105 |
+
|
| 106 |
with torch.no_grad():
|
| 107 |
for i, text in enumerate(eval_data):
|
| 108 |
if i % 100 == 0:
|
| 109 |
print(f" Progress: {i}/{len(eval_data)} passages")
|
| 110 |
+
|
| 111 |
# Tokenize text
|
| 112 |
tokens = self.tokenizer.encode(text)
|
| 113 |
if len(tokens) < 2:
|
| 114 |
continue
|
| 115 |
+
|
| 116 |
# Truncate if too long
|
| 117 |
if len(tokens) > max_seq_len:
|
| 118 |
tokens = tokens[:max_seq_len]
|
| 119 |
+
|
| 120 |
# Create input and target tensors
|
| 121 |
input_ids = torch.tensor([tokens[:-1]], dtype=torch.long, device=self.device)
|
| 122 |
target_ids = torch.tensor([tokens[1:]], dtype=torch.long, device=self.device)
|
| 123 |
+
|
| 124 |
# Forward pass
|
| 125 |
logits, loss = self.model(input_ids, target_ids)
|
| 126 |
+
|
| 127 |
# Accumulate loss
|
| 128 |
seq_length = len(tokens) - 1
|
| 129 |
total_loss += loss.item() * seq_length
|
| 130 |
total_tokens += seq_length
|
| 131 |
num_sequences += 1
|
| 132 |
+
|
| 133 |
# Calculate metrics
|
| 134 |
+
avg_loss = total_loss / total_tokens if total_tokens > 0 else float("inf")
|
| 135 |
perplexity = math.exp(min(avg_loss, 10)) # Cap to prevent overflow
|
| 136 |
+
|
| 137 |
return {
|
| 138 |
+
"loss": avg_loss,
|
| 139 |
+
"perplexity": perplexity,
|
| 140 |
+
"total_tokens": total_tokens,
|
| 141 |
+
"num_sequences": num_sequences,
|
| 142 |
}
|
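As a sanity check on the numbers this method reports, here is a minimal standalone sketch of the same token-weighted aggregation (the per-sequence losses and lengths are made-up values for illustration):

import math

# Hypothetical (loss, token_count) pairs, one per evaluated sequence
seq_stats = [(3.2, 120), (2.9, 80), (3.5, 200)]

total_loss = sum(loss * n for loss, n in seq_stats)
total_tokens = sum(n for _, n in seq_stats)
avg_loss = total_loss / total_tokens            # length-weighted mean loss
perplexity = math.exp(min(avg_loss, 10))        # same overflow cap as above
print(f"loss={avg_loss:.4f}, perplexity={perplexity:.2f}")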
| 143 |
+
|
| 144 |
def evaluate_text_generation(
|
| 145 |
self,
|
| 146 |
prompts: List[str],
|
| 147 |
max_length: int = 256,
|
| 148 |
temperature: float = 0.7,
|
| 149 |
top_k: Optional[int] = 40,
|
| 150 |
+
num_samples: int = 1,
|
| 151 |
) -> List[Dict[str, Any]]:
|
| 152 |
"""
|
| 153 |
Evaluate text generation quality.
|
| 154 |
+
|
| 155 |
Args:
|
| 156 |
prompts: List of input prompts
|
| 157 |
max_length: Maximum generation length
|
| 158 |
temperature: Sampling temperature
|
| 159 |
top_k: Top-k sampling parameter
|
| 160 |
num_samples: Number of samples per prompt
|
| 161 |
+
|
| 162 |
Returns:
|
| 163 |
List of generation results with quality metrics
|
| 164 |
"""
|
| 165 |
self.model.eval()
|
| 166 |
results = []
|
| 167 |
+
|
| 168 |
print(f"βοΈ Evaluating text generation on {len(prompts)} prompts...")
|
| 169 |
+
|
| 170 |
with torch.no_grad():
|
| 171 |
for prompt in prompts:
|
| 172 |
prompt_results = []
|
| 173 |
+
|
| 174 |
for sample_idx in range(num_samples):
|
| 175 |
# Tokenize prompt
|
| 176 |
input_ids = self.tokenizer.encode(prompt)
|
| 177 |
input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
|
| 178 |
+
|
| 179 |
start_time = time.time()
|
| 180 |
+
|
| 181 |
# Generate text
|
| 182 |
output = self.model.generate(
|
| 183 |
input_tensor,
|
| 184 |
max_new_tokens=max_length,
|
| 185 |
temperature=temperature,
|
| 186 |
+
top_k=top_k,
|
| 187 |
)
|
| 188 |
+
|
| 189 |
generation_time = time.time() - start_time
|
| 190 |
+
|
| 191 |
# Decode output
|
| 192 |
generated_ids = output[0].tolist()
|
| 193 |
full_text = self.tokenizer.decode(generated_ids)
|
| 194 |
+
generated_text = self.tokenizer.decode(generated_ids[len(input_ids) :])
|
| 195 |
+
|
| 196 |
# Calculate quality metrics
|
| 197 |
quality_metrics = self._assess_generation_quality(generated_text)
|
| 198 |
+
|
| 199 |
+
prompt_results.append(
|
| 200 |
+
{
|
| 201 |
+
"prompt": prompt,
|
| 202 |
+
"generated_text": generated_text,
|
| 203 |
+
"full_text": full_text,
|
| 204 |
+
"generation_time": generation_time,
|
| 205 |
+
"tokens_generated": len(generated_ids) - len(input_ids),
|
| 206 |
+
"tokens_per_second": (len(generated_ids) - len(input_ids))
|
| 207 |
+
/ generation_time,
|
| 208 |
+
"quality_metrics": quality_metrics,
|
| 209 |
+
}
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
results.extend(prompt_results)
|
| 213 |
+
|
| 214 |
return results
|
| 215 |
+
|
| 216 |
def _assess_generation_quality(self, text: str) -> Dict[str, float]:
|
| 217 |
"""
|
| 218 |
Assess basic quality metrics for generated text.
|
| 219 |
+
|
| 220 |
Args:
|
| 221 |
text: Generated text to assess
|
| 222 |
+
|
| 223 |
Returns:
|
| 224 |
Dictionary of quality metrics
|
| 225 |
"""
|
| 226 |
if not text.strip():
|
| 227 |
return {
|
| 228 |
+
"length": 0,
|
| 229 |
+
"avg_word_length": 0,
|
| 230 |
+
"repetition_rate": 1.0,
|
| 231 |
+
"coherence_score": 0.0,
|
| 232 |
}
|
| 233 |
+
|
| 234 |
words = text.split()
|
| 235 |
+
|
| 236 |
# Basic metrics
|
| 237 |
length = len(words)
|
| 238 |
avg_word_length = sum(len(word) for word in words) / len(words) if words else 0
|
| 239 |
+
|
| 240 |
# Repetition rate (simple n-gram repetition)
|
| 241 |
+
bigrams = [f"{words[i]} {words[i+1]}" for i in range(len(words) - 1)]
|
| 242 |
unique_bigrams = len(set(bigrams))
|
| 243 |
repetition_rate = 1 - (unique_bigrams / len(bigrams) if bigrams else 0)
|
| 244 |
+
|
| 245 |
# Simple coherence score (based on sentence structure)
|
| 246 |
+
sentences = text.split(".")
|
| 247 |
valid_sentences = [s for s in sentences if len(s.strip().split()) > 3]
|
| 248 |
coherence_score = len(valid_sentences) / len(sentences) if sentences else 0
|
| 249 |
+
|
| 250 |
return {
|
| 251 |
+
"length": length,
|
| 252 |
+
"avg_word_length": avg_word_length,
|
| 253 |
+
"repetition_rate": repetition_rate,
|
| 254 |
+
"coherence_score": coherence_score,
|
| 255 |
}
|
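A quick illustration of the bigram-repetition and sentence-based coherence heuristics defined above, run on a toy string (the values are only meaningful relative to other generations):

text = "the cat sat on the mat. the cat sat on the mat. it purred quietly."
words = text.split()

bigrams = [f"{words[i]} {words[i+1]}" for i in range(len(words) - 1)]
repetition_rate = 1 - (len(set(bigrams)) / len(bigrams) if bigrams else 0)

sentences = text.split(".")
valid = [s for s in sentences if len(s.strip().split()) > 3]
coherence_score = len(valid) / len(sentences) if sentences else 0

print(f"repetition={repetition_rate:.3f}, coherence={coherence_score:.3f}")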
| 256 |
+
|
| 257 |
def evaluate_downstream_tasks(self) -> Dict[str, Any]:
|
| 258 |
"""
|
| 259 |
Evaluate model performance on downstream tasks.
|
| 260 |
+
|
| 261 |
This function implements basic downstream task evaluation including:
|
| 262 |
- Reading comprehension (simplified SQUAD-style)
|
| 263 |
- Sentiment analysis (few-shot)
|
| 264 |
- Common sense reasoning
|
| 265 |
+
|
| 266 |
Returns:
|
| 267 |
Dictionary of downstream task results
|
| 268 |
"""
|
| 269 |
results = {}
|
| 270 |
+
|
| 271 |
# 1. Reading Comprehension (Simplified SQUAD-style)
|
| 272 |
+
results["reading_comprehension"] = self._evaluate_reading_comprehension()
|
| 273 |
+
|
| 274 |
# 2. Sentiment Analysis (Few-shot learning)
|
| 275 |
+
results["sentiment_analysis"] = self._evaluate_sentiment_analysis()
|
| 276 |
+
|
| 277 |
# 3. Common Sense Reasoning
|
| 278 |
+
results["reasoning"] = self._evaluate_reasoning()
|
| 279 |
+
|
| 280 |
# 4. Text Completion Quality
|
| 281 |
+
results["text_completion"] = self._evaluate_text_completion()
|
| 282 |
+
|
| 283 |
return results
|
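The pipeline reports each task separately; if a single headline number is wanted, the per-task scores can be averaged, as in this sketch (the `evaluator` instance is assumed to already exist):

downstream = evaluator.evaluate_downstream_tasks()
macro_avg = sum(task["score"] for task in downstream.values()) / len(downstream)
print(f"Macro-average downstream score: {macro_avg:.2f}")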
| 284 |
+
|
| 285 |
def _evaluate_reading_comprehension(self) -> Dict[str, Any]:
|
| 286 |
"""Simplified reading comprehension evaluation."""
|
| 287 |
# Sample reading comprehension tasks
|
| 288 |
tasks = [
|
| 289 |
{
|
| 290 |
+
"context": "The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, France. It is named after the engineer Gustave Eiffel, whose company designed and built the tower.",
|
| 291 |
+
"question": "Who is the Eiffel Tower named after?",
|
| 292 |
+
"expected": "Gustave Eiffel",
|
| 293 |
},
|
| 294 |
{
|
| 295 |
+
"context": "Python is a high-level programming language. It was created by Guido van Rossum and first released in 1991.",
|
| 296 |
+
"question": "When was Python first released?",
|
| 297 |
+
"expected": "1991",
|
| 298 |
},
|
| 299 |
{
|
| 300 |
+
"context": "Machine learning is a subset of artificial intelligence that enables computers to learn without being explicitly programmed.",
|
| 301 |
+
"question": "What is machine learning a subset of?",
|
| 302 |
+
"expected": "artificial intelligence",
|
| 303 |
+
},
|
| 304 |
]
|
| 305 |
+
|
| 306 |
correct = 0
|
| 307 |
total = len(tasks)
|
| 308 |
+
|
| 309 |
for task in tasks:
|
| 310 |
prompt = f"Context: {task['context']}\nQuestion: {task['question']}\nAnswer:"
|
| 311 |
+
|
| 312 |
# Generate answer
|
| 313 |
input_ids = self.tokenizer.encode(prompt)
|
| 314 |
input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
|
| 315 |
+
|
| 316 |
with torch.no_grad():
|
| 317 |
output = self.model.generate(input_tensor, max_new_tokens=20, temperature=0.1)
|
| 318 |
+
|
| 319 |
generated_ids = output[0].tolist()
|
| 320 |
+
answer = self.tokenizer.decode(generated_ids[len(input_ids) :]).strip().lower()
|
| 321 |
+
|
| 322 |
# Simple substring matching
|
| 323 |
+
if task["expected"].lower() in answer:
|
| 324 |
correct += 1
|
| 325 |
+
|
| 326 |
return {
|
| 327 |
+
"accuracy": correct / total,
|
| 328 |
+
"correct": correct,
|
| 329 |
+
"total": total,
|
| 330 |
+
"score": correct / total,
|
| 331 |
}
|
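For clarity, this is how a single task above is turned into a prompt and scored by substring matching (the model answer shown is hypothetical):

task = {
    "context": "Python is a high-level programming language. It was created by Guido van Rossum and first released in 1991.",
    "question": "When was Python first released?",
    "expected": "1991",
}
prompt = f"Context: {task['context']}\nQuestion: {task['question']}\nAnswer:"

answer = " It was first released in 1991."        # hypothetical generation
is_correct = task["expected"].lower() in answer.lower()
print(prompt, is_correct)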
| 332 |
+
|
| 333 |
def _evaluate_sentiment_analysis(self) -> Dict[str, Any]:
|
| 334 |
"""Few-shot sentiment analysis evaluation."""
|
| 335 |
# Few-shot examples
|
| 336 |
examples = "Examples:\nText: 'I love this movie!' Sentiment: Positive\nText: 'This is terrible.' Sentiment: Negative\nText: 'It was okay.' Sentiment: Neutral\n\n"
|
| 337 |
+
|
| 338 |
# Test cases
|
| 339 |
test_cases = [
|
| 340 |
+
{"text": "This is amazing!", "expected": "positive"},
|
| 341 |
+
{"text": "I hate this.", "expected": "negative"},
|
| 342 |
+
{"text": "This is wonderful.", "expected": "positive"},
|
| 343 |
+
{"text": "This is awful.", "expected": "negative"},
|
| 344 |
+
{"text": "It was fine.", "expected": "neutral"},
|
| 345 |
]
|
| 346 |
+
|
| 347 |
correct = 0
|
| 348 |
total = len(test_cases)
|
| 349 |
+
|
| 350 |
for case in test_cases:
|
| 351 |
prompt = f"{examples}Text: '{case['text']}' Sentiment:"
|
| 352 |
+
|
| 353 |
# Generate sentiment
|
| 354 |
input_ids = self.tokenizer.encode(prompt)
|
| 355 |
input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
|
| 356 |
+
|
| 357 |
with torch.no_grad():
|
| 358 |
output = self.model.generate(input_tensor, max_new_tokens=5, temperature=0.1)
|
| 359 |
+
|
| 360 |
generated_ids = output[0].tolist()
|
| 361 |
+
sentiment = self.tokenizer.decode(generated_ids[len(input_ids) :]).strip().lower()
|
| 362 |
+
|
| 363 |
# Check if expected sentiment is in the generated response
|
| 364 |
+
if case["expected"] in sentiment:
|
| 365 |
correct += 1
|
| 366 |
+
|
| 367 |
return {
|
| 368 |
+
"accuracy": correct / total,
|
| 369 |
+
"correct": correct,
|
| 370 |
+
"total": total,
|
| 371 |
+
"score": correct / total,
|
| 372 |
}
|
| 373 |
+
|
| 374 |
def _evaluate_reasoning(self) -> Dict[str, Any]:
|
| 375 |
"""Simple reasoning evaluation."""
|
| 376 |
# Basic reasoning tasks
|
| 377 |
tasks = [
|
| 378 |
{
|
| 379 |
+
"question": "If all birds can fly and a penguin is a bird, can a penguin fly?",
|
| 380 |
+
"expected": "no", # This tests if model knows real-world facts
|
| 381 |
},
|
| 382 |
{
|
| 383 |
+
"question": "If it is raining outside, should you take an umbrella?",
|
| 384 |
+
"expected": "yes",
|
| 385 |
},
|
| 386 |
+
{"question": "What comes after Monday?", "expected": "tuesday"},
|
| 387 |
+
{"question": "Is the sun larger than the earth?", "expected": "yes"},
|
| 388 |
]
|
| 389 |
+
|
| 390 |
correct = 0
|
| 391 |
total = len(tasks)
|
| 392 |
+
|
| 393 |
for task in tasks:
|
| 394 |
prompt = f"Question: {task['question']}\nAnswer:"
|
| 395 |
+
|
| 396 |
# Generate answer
|
| 397 |
input_ids = self.tokenizer.encode(prompt)
|
| 398 |
input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
|
| 399 |
+
|
| 400 |
with torch.no_grad():
|
| 401 |
output = self.model.generate(input_tensor, max_new_tokens=10, temperature=0.1)
|
| 402 |
+
|
| 403 |
generated_ids = output[0].tolist()
|
| 404 |
+
answer = self.tokenizer.decode(generated_ids[len(input_ids) :]).strip().lower()
|
| 405 |
+
|
| 406 |
# Check if expected answer is in the response
|
| 407 |
+
if task["expected"] in answer:
|
| 408 |
correct += 1
|
| 409 |
+
|
| 410 |
return {
|
| 411 |
+
"accuracy": correct / total,
|
| 412 |
+
"correct": correct,
|
| 413 |
+
"total": total,
|
| 414 |
+
"score": correct / total,
|
| 415 |
}
|
| 416 |
+
|
| 417 |
def _evaluate_text_completion(self) -> Dict[str, Any]:
|
| 418 |
"""Evaluate text completion quality."""
|
| 419 |
# Common phrases that should be completed predictably
|
| 420 |
completions = [
|
| 421 |
+
{"prompt": "The capital of France is", "expected_word": "paris"},
|
| 422 |
+
{"prompt": "Two plus two equals", "expected_word": "four"},
|
| 423 |
+
{"prompt": "The largest planet in our solar system is", "expected_word": "jupiter"},
|
| 424 |
+
{"prompt": "Water boils at", "expected_word": "100"},
|
| 425 |
]
|
| 426 |
+
|
| 427 |
correct = 0
|
| 428 |
total = len(completions)
|
| 429 |
+
|
| 430 |
for completion in completions:
|
| 431 |
# Generate completion
|
| 432 |
+
input_ids = self.tokenizer.encode(completion["prompt"])
|
| 433 |
input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
|
| 434 |
+
|
| 435 |
with torch.no_grad():
|
| 436 |
output = self.model.generate(input_tensor, max_new_tokens=5, temperature=0.1)
|
| 437 |
+
|
| 438 |
generated_ids = output[0].tolist()
|
| 439 |
+
generated_text = self.tokenizer.decode(generated_ids[len(input_ids) :]).strip().lower()
|
| 440 |
+
|
| 441 |
# Check if expected word appears in completion
|
| 442 |
+
if completion["expected_word"] in generated_text:
|
| 443 |
correct += 1
|
| 444 |
+
|
| 445 |
return {
|
| 446 |
+
"accuracy": correct / total,
|
| 447 |
+
"correct": correct,
|
| 448 |
+
"total": total,
|
| 449 |
+
"score": correct / total,
|
| 450 |
}
|
| 451 |
+
|
| 452 |
def run_comprehensive_evaluation(
|
| 453 |
+
self, eval_data_path: str, metrics: List[str] = None, generation_prompts: List[str] = None
|
| 454 |
) -> Dict[str, Any]:
|
| 455 |
"""
|
| 456 |
Run comprehensive model evaluation.
|
| 457 |
+
|
| 458 |
Args:
|
| 459 |
eval_data_path: Path to evaluation text file
|
| 460 |
metrics: List of metrics to compute
|
| 461 |
generation_prompts: Prompts for text generation evaluation
|
| 462 |
+
|
| 463 |
Returns:
|
| 464 |
Complete evaluation results
|
| 465 |
"""
|
| 466 |
if metrics is None:
|
| 467 |
+
metrics = ["perplexity", "loss", "generation"]
|
| 468 |
+
|
| 469 |
if generation_prompts is None:
|
| 470 |
generation_prompts = [
|
| 471 |
"The history of artificial intelligence",
|
| 472 |
"Machine learning algorithms",
|
| 473 |
"The future of technology",
|
| 474 |
"In a world where",
|
| 475 |
+
"Scientists have discovered",
|
| 476 |
]
|
| 477 |
+
|
| 478 |
results = {
|
| 479 |
+
"model_info": {
|
| 480 |
+
"parameters": self.model.get_num_params(),
|
| 481 |
+
"device": self.device,
|
| 482 |
+
"vocab_size": self.tokenizer.vocab_size(),
|
| 483 |
},
|
| 484 |
+
"evaluation_timestamp": time.time(),
|
| 485 |
}
|
| 486 |
+
|
| 487 |
# Load evaluation data
|
| 488 |
print(f"π Loading evaluation data from {eval_data_path}")
|
| 489 |
if os.path.exists(eval_data_path):
|
| 490 |
+
with open(eval_data_path, "r", encoding="utf-8") as f:
|
| 491 |
eval_texts = [line.strip() for line in f if line.strip()]
|
| 492 |
else:
|
| 493 |
print(f"β οΈ Evaluation file not found, using sample texts")
|
| 494 |
eval_texts = [
 |
| 496 |
"Machine learning algorithms can learn patterns from data automatically.",
|
| 497 |
"Natural language processing helps computers understand human language.",
|
| 498 |
"Deep learning uses neural networks with multiple layers for complex tasks.",
|
| 499 |
+
"The development of large language models has transformed AI applications.",
|
| 500 |
]
|
| 501 |
+
|
| 502 |
# Intrinsic evaluation
|
| 503 |
+
if "perplexity" in metrics or "loss" in metrics:
|
| 504 |
perplexity_results = self.evaluate_perplexity(eval_texts)
|
| 505 |
+
results["intrinsic_evaluation"] = perplexity_results
|
| 506 |
+
|
| 507 |
# Text generation evaluation
|
| 508 |
+
if "generation" in metrics:
|
| 509 |
generation_results = self.evaluate_text_generation(generation_prompts)
|
| 510 |
+
results["generation_evaluation"] = {
|
| 511 |
+
"results": generation_results,
|
| 512 |
+
"summary": self._summarize_generation_results(generation_results),
|
| 513 |
}
|
| 514 |
+
|
| 515 |
# Downstream tasks
|
| 516 |
+
results["downstream_evaluation"] = self.evaluate_downstream_tasks()
|
| 517 |
+
|
| 518 |
# Overall quality assessment
|
| 519 |
+
results["quality_assessment"] = self._assess_overall_quality(results)
|
| 520 |
+
|
| 521 |
return results
|
| 522 |
+
|
| 523 |
def _summarize_generation_results(self, results: List[Dict[str, Any]]) -> Dict[str, float]:
|
| 524 |
"""Summarize text generation results."""
|
| 525 |
if not results:
|
| 526 |
return {}
|
| 527 |
+
|
| 528 |
+
total_time = sum(r["generation_time"] for r in results)
|
| 529 |
+
total_tokens = sum(r["tokens_generated"] for r in results)
|
| 530 |
+
|
| 531 |
+
quality_metrics = [r["quality_metrics"] for r in results]
|
| 532 |
+
|
| 533 |
return {
|
| 534 |
+
"avg_generation_time": total_time / len(results),
|
| 535 |
+
"avg_tokens_per_second": total_tokens / total_time if total_time > 0 else 0,
|
| 536 |
+
"avg_length": sum(q["length"] for q in quality_metrics) / len(quality_metrics),
|
| 537 |
+
"avg_repetition_rate": sum(q["repetition_rate"] for q in quality_metrics)
|
| 538 |
+
/ len(quality_metrics),
|
| 539 |
+
"avg_coherence_score": sum(q["coherence_score"] for q in quality_metrics)
|
| 540 |
+
/ len(quality_metrics),
|
| 541 |
}
|
| 542 |
+
|
| 543 |
def _assess_overall_quality(self, results: Dict[str, Any]) -> Dict[str, Any]:
|
| 544 |
"""Assess overall model quality based on evaluation results."""
|
| 545 |
+
assessment = {"quality_level": "unknown", "recommendations": []}
|
| 546 |
+
|
| 547 |
# Check intrinsic metrics
|
| 548 |
+
if "intrinsic_evaluation" in results:
|
| 549 |
+
perplexity = results["intrinsic_evaluation"].get("perplexity", float("inf"))
|
| 550 |
+
|
| 551 |
if perplexity < 12:
|
| 552 |
+
assessment["quality_level"] = "good"
|
| 553 |
+
assessment["recommendations"].append("Model shows good perplexity scores")
|
| 554 |
elif perplexity < 50:
|
| 555 |
+
assessment["quality_level"] = "fair"
|
| 556 |
+
assessment["recommendations"].append(
|
| 557 |
+
"Model shows fair performance, could benefit from more training"
|
| 558 |
+
)
|
| 559 |
else:
|
| 560 |
+
assessment["quality_level"] = "poor"
|
| 561 |
+
assessment["recommendations"].append(
|
| 562 |
+
"Model needs significant more training or data improvements"
|
| 563 |
+
)
|
| 564 |
+
|
| 565 |
# Check generation quality
|
| 566 |
+
if "generation_evaluation" in results:
|
| 567 |
+
summary = results["generation_evaluation"].get("summary", {})
|
| 568 |
+
repetition_rate = summary.get("avg_repetition_rate", 1.0)
|
| 569 |
+
coherence_score = summary.get("avg_coherence_score", 0.0)
|
| 570 |
+
|
| 571 |
if repetition_rate > 0.7:
|
| 572 |
+
assessment["recommendations"].append(
|
| 573 |
+
"High repetition rate - consider training longer or adjusting data"
|
| 574 |
+
)
|
| 575 |
if coherence_score < 0.3:
|
| 576 |
+
assessment["recommendations"].append(
|
| 577 |
+
"Low coherence - model may need more training steps"
|
| 578 |
+
)
|
| 579 |
+
|
| 580 |
return assessment
|
| 581 |
|
| 582 |
|
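Putting the pieces together, a minimal programmatic use of the evaluator might look like this (paths are examples; `load_model_from_directory` is defined just below):

model, tokenizer_path = load_model_from_directory("models/small-extended-4k", device="cpu")
evaluator = ModelEvaluator(model, tokenizer_path, device="cpu")

results = evaluator.run_comprehensive_evaluation(
    "data/clean/training_data.txt",
    metrics=["perplexity", "generation"],
)
print(results["quality_assessment"]["quality_level"])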
| 583 |
def load_model_from_directory(model_dir: str, device: str = "cpu") -> Tuple[GPTModel, str]:
|
| 584 |
"""
|
| 585 |
Load model from directory containing checkpoints.
|
| 586 |
+
|
| 587 |
Args:
|
| 588 |
model_dir: Directory containing model files
|
| 589 |
device: Device to load model on
|
| 590 |
+
|
| 591 |
Returns:
|
| 592 |
Tuple of (model, tokenizer_path)
|
| 593 |
"""
|
| 594 |
model_dir = Path(model_dir)
|
| 595 |
+
|
| 596 |
# Find best model checkpoint
|
| 597 |
best_model_path = model_dir / "best_model.pt"
|
| 598 |
if not best_model_path.exists():
|
| 600 |
checkpoints = list(model_dir.glob("checkpoint_step_*.pt"))
|
| 601 |
if not checkpoints:
|
| 602 |
raise FileNotFoundError(f"No model checkpoints found in {model_dir}")
|
| 603 |
+
|
| 604 |
# Get latest checkpoint
|
| 605 |
+
latest_checkpoint = max(checkpoints, key=lambda p: int(p.stem.split("_")[-1]))
|
| 606 |
best_model_path = latest_checkpoint
|
| 607 |
+
|
| 608 |
print(f"π Loading model from {best_model_path}")
|
| 609 |
+
|
| 610 |
# Load checkpoint
|
| 611 |
checkpoint = torch.load(best_model_path, map_location=device)
|
| 612 |
+
|
| 613 |
# Determine model size from config
|
| 614 |
+
config = checkpoint.get("config", {})
|
| 615 |
+
n_layer = config.get("n_layer", 12)
|
| 616 |
+
|
| 617 |
if n_layer <= 6:
|
| 618 |
model_size = "small"
|
| 619 |
elif n_layer <= 12:
|
| 620 |
model_size = "medium"
|
| 621 |
else:
|
| 622 |
model_size = "large"
|
| 623 |
+
|
| 624 |
# Create and load model
|
| 625 |
model = create_model(model_size)
|
| 626 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 627 |
+
|
| 628 |
+
print(f"β Model loaded successfully ({model_size}, {model.get_num_params():,} parameters)")
|
| 629 |
+
|
| 630 |
# Find tokenizer
|
| 631 |
tokenizer_path = model_dir.parent / "tokenizer" / "tokenizer.model"
|
| 632 |
if not tokenizer_path.exists():
|
| 633 |
tokenizer_path = Path("data/tokenizer/tokenizer.model")
|
| 634 |
+
|
| 635 |
if not tokenizer_path.exists():
|
| 636 |
raise FileNotFoundError(f"Tokenizer not found at {tokenizer_path}")
|
| 637 |
+
|
| 638 |
return model, str(tokenizer_path)
|
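The fallback branch above picks the newest checkpoint by parsing the step number out of the filename; for example:

from pathlib import Path

p = Path("models/small-extended-4k/checkpoint_step_4000.pt")   # example filename
step = int(p.stem.split("_")[-1])    # "checkpoint_step_4000" -> 4000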
| 639 |
|
| 640 |
|
|
|
|
| 655 |
--model_dir models/small-extended-4k \\
|
| 656 |
--metrics perplexity,generation \\
|
| 657 |
--output results.json
|
| 658 |
+
""",
|
| 659 |
)
|
| 660 |
+
|
| 661 |
+
parser.add_argument("--model_dir", required=True, help="Directory containing trained model")
|
| 662 |
+
|
| 663 |
parser.add_argument(
|
| 664 |
+
"--eval_data", help="Path to evaluation text file (default: use sample texts)"
|
| 665 |
)
|
| 666 |
+
|
| 667 |
parser.add_argument(
|
| 668 |
"--metrics",
|
| 669 |
default="perplexity,loss,generation",
|
| 670 |
+
help="Comma-separated list of metrics to evaluate (default: perplexity,loss,generation)",
|
| 671 |
)
|
| 672 |
+
|
| 673 |
+
parser.add_argument("--output", help="Output JSON file for results (default: print to console)")
|
| 674 |
+
|
| 675 |
parser.add_argument(
|
| 676 |
"--device",
|
| 677 |
choices=["cpu", "cuda", "auto"],
|
| 678 |
default="auto",
|
| 679 |
+
help="Device for evaluation (default: auto)",
|
| 680 |
)
|
| 681 |
+
|
| 682 |
parser.add_argument(
|
| 683 |
+
"--generation_prompts", help="File containing prompts for text generation evaluation"
|
| 684 |
)
|
| 685 |
+
|
| 686 |
args = parser.parse_args()
|
| 687 |
+
|
| 688 |
print("π OpenLLM Model Evaluation")
|
| 689 |
print("=" * 50)
|
| 690 |
+
|
| 691 |
# Determine device
|
| 692 |
if args.device == "auto":
|
| 693 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 694 |
else:
|
| 695 |
device = args.device
|
| 696 |
+
|
| 697 |
print(f"Using device: {device}")
|
| 698 |
+
|
| 699 |
try:
|
| 700 |
# Load model
|
| 701 |
model, tokenizer_path = load_model_from_directory(args.model_dir, device)
|
| 702 |
+
|
| 703 |
# Create evaluator
|
| 704 |
evaluator = ModelEvaluator(model, tokenizer_path, device)
|
| 705 |
+
|
| 706 |
# Parse metrics
|
| 707 |
+
metrics = [m.strip() for m in args.metrics.split(",")]
|
| 708 |
+
|
| 709 |
# Load generation prompts if specified
|
| 710 |
generation_prompts = None
|
| 711 |
if args.generation_prompts and os.path.exists(args.generation_prompts):
|
| 712 |
+
with open(args.generation_prompts, "r", encoding="utf-8") as f:
|
| 713 |
generation_prompts = [line.strip() for line in f if line.strip()]
|
| 714 |
+
|
| 715 |
# Run evaluation
|
| 716 |
eval_data_path = args.eval_data or "data/clean/training_data.txt"
|
| 717 |
results = evaluator.run_comprehensive_evaluation(
|
| 718 |
eval_data_path, metrics, generation_prompts
|
| 719 |
)
|
| 720 |
+
|
| 721 |
# Output results
|
| 722 |
if args.output:
|
| 723 |
+
with open(args.output, "w", encoding="utf-8") as f:
|
| 724 |
json.dump(results, f, indent=2)
|
| 725 |
print(f"\nπΎ Results saved to {args.output}")
|
| 726 |
else:
|
| 727 |
print(f"\nπ Evaluation Results:")
|
| 728 |
print("=" * 50)
|
| 729 |
+
|
| 730 |
# Print key metrics
|
| 731 |
+
if "intrinsic_evaluation" in results:
|
| 732 |
+
intrinsic = results["intrinsic_evaluation"]
|
| 733 |
print(f"π Intrinsic Metrics:")
|
| 734 |
print(f" Loss: {intrinsic['loss']:.4f}")
|
| 735 |
print(f" Perplexity: {intrinsic['perplexity']:.2f}")
|
| 736 |
print(f" Sequences evaluated: {intrinsic['num_sequences']:,}")
|
| 737 |
+
|
| 738 |
+
if "generation_evaluation" in results:
|
| 739 |
+
gen_summary = results["generation_evaluation"]["summary"]
|
| 740 |
print(f"\nβοΈ Generation Quality:")
|
| 741 |
+
print(
|
| 742 |
+
f" Avg generation speed: {gen_summary['avg_tokens_per_second']:.1f} tokens/sec"
|
| 743 |
+
)
|
| 744 |
print(f" Avg text length: {gen_summary['avg_length']:.1f} words")
|
| 745 |
print(f" Repetition rate: {gen_summary['avg_repetition_rate']:.3f}")
|
| 746 |
print(f" Coherence score: {gen_summary['avg_coherence_score']:.3f}")
|
| 747 |
+
|
| 748 |
# Quality assessment
|
| 749 |
+
if "quality_assessment" in results:
|
| 750 |
+
assessment = results["quality_assessment"]
|
| 751 |
print(f"\nπ― Overall Assessment:")
|
| 752 |
print(f" Quality Level: {assessment['quality_level'].upper()}")
|
| 753 |
+
for rec in assessment["recommendations"]:
|
| 754 |
print(f" β’ {rec}")
|
| 755 |
+
|
| 756 |
print(f"\nπ Evaluation completed successfully!")
|
| 757 |
+
|
| 758 |
except Exception as e:
|
| 759 |
print(f"\nβ Evaluation failed: {e}")
|
| 760 |
import traceback
|
| 761 |
+
|
| 762 |
traceback.print_exc()
|
| 763 |
return False
|
| 764 |
+
|
| 765 |
return True
|
| 766 |
|
| 767 |
|
| 768 |
if __name__ == "__main__":
|
| 769 |
+
main()
|
training/model.py
CHANGED
|
@@ -18,7 +18,7 @@ language modeling (next-token prediction).
|
|
| 18 |
|
| 19 |
ARCHITECTURE OVERVIEW:
|
| 20 |
- Token Embedding: Maps token IDs to dense vectors
|
| 21 |
-
- Positional Embedding: Adds position information to token embeddings
|
| 22 |
- Transformer Blocks: Stack of multi-head attention + feed-forward layers
|
| 23 |
- Layer Normalization: Pre-norm placement for training stability
|
| 24 |
- Output Head: Linear projection to vocabulary for next-token prediction
|
|
@@ -32,16 +32,16 @@ FEATURES:
|
|
| 32 |
|
| 33 |
Usage:
|
| 34 |
from model import GPTConfig, GPTModel
|
| 35 |
-
|
| 36 |
config = GPTConfig(vocab_size=32000, n_layer=12, n_head=12, n_embd=768)
|
| 37 |
model = GPTModel(config)
|
| 38 |
-
|
| 39 |
# Forward pass
|
| 40 |
logits = model(input_ids) # Shape: (batch_size, seq_len, vocab_size)
|
| 41 |
|
| 42 |
Hardware Requirements:
|
| 43 |
- Small Model (25M params): 4-8GB RAM, CPU/integrated GPU
|
| 44 |
-
- Medium Model (117M params): 8-16GB RAM, dedicated GPU recommended
|
| 45 |
- Large Model (350M params): 16GB+ RAM, high-end GPU required
|
| 46 |
|
| 47 |
Author: Louis Chua Bean Chong
|
|
@@ -60,43 +60,43 @@ from typing import Optional, Tuple
|
|
| 60 |
class GPTConfig:
|
| 61 |
"""
|
| 62 |
Configuration class for GPT model hyperparameters.
|
| 63 |
-
|
| 64 |
This class defines all the architectural parameters needed to instantiate
|
| 65 |
a GPT model. Use the provided class methods to get pre-configured setups
|
| 66 |
for different model sizes.
|
| 67 |
"""
|
| 68 |
-
|
| 69 |
# Model architecture
|
| 70 |
-
vocab_size: int = 32000
|
| 71 |
-
n_layer: int = 12
|
| 72 |
-
n_head: int = 12
|
| 73 |
-
n_embd: int = 768
|
| 74 |
-
|
| 75 |
# Sequence and context
|
| 76 |
-
block_size: int = 1024
|
| 77 |
-
|
| 78 |
# Training hyperparameters
|
| 79 |
-
dropout: float = 0.1
|
| 80 |
-
bias: bool = True
|
| 81 |
-
|
| 82 |
# Model size identifier
|
| 83 |
-
model_name: str = "gpt-medium"
|
| 84 |
-
|
| 85 |
@classmethod
|
| 86 |
-
def small(cls) -> 'GPTConfig':
|
| 87 |
"""Small model configuration (~25M parameters) - Good for CPU training"""
|
| 88 |
return cls(
|
| 89 |
vocab_size=32000,
|
| 90 |
n_layer=6,
|
| 91 |
-
n_head=8,
|
| 92 |
n_embd=512,
|
| 93 |
block_size=1024,
|
| 94 |
dropout=0.1,
|
| 95 |
-
model_name="gpt-small"
|
| 96 |
)
|
| 97 |
-
|
| 98 |
-
@classmethod
|
| 99 |
-
def medium(cls) -> 'GPTConfig':
|
| 100 |
"""Medium model configuration (~117M parameters) - Balanced performance"""
|
| 101 |
return cls(
|
| 102 |
vocab_size=32000,
|
|
@@ -105,11 +105,11 @@ class GPTConfig:
|
|
| 105 |
n_embd=768,
|
| 106 |
block_size=2048,
|
| 107 |
dropout=0.1,
|
| 108 |
-
model_name="gpt-medium"
|
| 109 |
)
|
| 110 |
-
|
| 111 |
@classmethod
|
| 112 |
-
def large(cls) -> 'GPTConfig':
|
| 113 |
"""Large model configuration (~350M parameters) - High performance"""
|
| 114 |
return cls(
|
| 115 |
vocab_size=32000,
|
|
@@ -118,29 +118,29 @@ class GPTConfig:
|
|
| 118 |
n_embd=1024,
|
| 119 |
block_size=2048,
|
| 120 |
dropout=0.1,
|
| 121 |
-
model_name="gpt-large"
|
| 122 |
)
|
| 123 |
-
|
| 124 |
def estimate_parameters(self) -> int:
|
| 125 |
"""
|
| 126 |
Estimate the total number of trainable parameters.
|
| 127 |
-
|
| 128 |
Returns:
|
| 129 |
int: Estimated parameter count
|
| 130 |
"""
|
| 131 |
# Token embeddings
|
| 132 |
token_emb = self.vocab_size * self.n_embd
|
| 133 |
-
|
| 134 |
-
# Position embeddings
|
| 135 |
pos_emb = self.block_size * self.n_embd
|
| 136 |
-
|
| 137 |
# Transformer layers
|
| 138 |
# Each layer: attention (4 * n_embd^2) + mlp (8 * n_embd^2) + layer_norms
|
| 139 |
layer_params = self.n_layer * (12 * self.n_embd**2 + 4 * self.n_embd)
|
| 140 |
-
|
| 141 |
# Output head
|
| 142 |
output_head = self.vocab_size * self.n_embd
|
| 143 |
-
|
| 144 |
total = token_emb + pos_emb + layer_params + output_head
|
| 145 |
return total
|
| 146 |
|
|
@@ -148,75 +148,78 @@ class GPTConfig:
|
|
| 148 |
class CausalSelfAttention(nn.Module):
|
| 149 |
"""
|
| 150 |
Multi-head causal self-attention mechanism.
|
| 151 |
-
|
| 152 |
This implements the core attention mechanism of the transformer, with causal
|
| 153 |
masking to ensure autoregressive behavior (tokens can only attend to previous
|
| 154 |
tokens, not future ones).
|
| 155 |
"""
|
| 156 |
-
|
| 157 |
def __init__(self, config: GPTConfig):
|
| 158 |
super().__init__()
|
| 159 |
-
assert config.n_embd % config.n_head == 0, "Embedding dim must be divisible by number of heads"
 |
| 160 |
-
 |
| 161 |
self.config = config
|
| 162 |
self.n_head = config.n_head
|
| 163 |
self.n_embd = config.n_embd
|
| 164 |
self.head_dim = self.n_embd // self.n_head
|
| 165 |
-
|
| 166 |
# Key, query, value projections for all heads (batched)
|
| 167 |
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
|
| 168 |
-
|
| 169 |
# Output projection
|
| 170 |
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
|
| 171 |
-
|
| 172 |
# Dropout
|
| 173 |
self.attn_dropout = nn.Dropout(config.dropout)
|
| 174 |
self.resid_dropout = nn.Dropout(config.dropout)
|
| 175 |
-
|
| 176 |
# Causal mask - lower triangular matrix
|
| 177 |
self.register_buffer(
|
| 178 |
"bias",
|
| 179 |
-
torch.tril(torch.ones(config.block_size, config.block_size))
|
| 180 |
-
.view(1, 1, config.block_size, config.block_size)
 |
| 181 |
)
|
| 182 |
-
|
| 183 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 184 |
"""
|
| 185 |
Forward pass of causal self-attention.
|
| 186 |
-
|
| 187 |
This method implements the scaled dot-product attention mechanism with causal masking.
|
| 188 |
The attention mechanism allows each token to attend to all previous tokens in the sequence,
|
| 189 |
but not to future tokens, maintaining the autoregressive property essential for language modeling.
|
| 190 |
-
|
| 191 |
Mathematical formulation:
|
| 192 |
Attention(Q, K, V) = softmax(QK^T / sqrt(d_k))V
|
| 193 |
where Q, K, V are query, key, value matrices derived from input x
|
| 194 |
-
|
| 195 |
Implementation details:
|
| 196 |
- Uses batch matrix multiplication for efficiency
|
| 197 |
- Applies causal mask to prevent future token attention
|
| 198 |
- Implements multi-head attention by reshaping and parallel processing
|
| 199 |
- Applies dropout for regularization during training
|
| 200 |
-
|
| 201 |
Args:
|
| 202 |
x: Input tensor of shape (batch_size, seq_len, n_embd)
|
| 203 |
Contains embedded token representations from previous layer
|
| 204 |
-
|
| 205 |
Returns:
|
| 206 |
torch.Tensor: Output tensor of shape (batch_size, seq_len, n_embd)
|
| 207 |
"""
|
| 208 |
# Extract tensor dimensions for clear variable naming and validation
|
| 209 |
# B = batch size (number of sequences processed in parallel)
|
| 210 |
-
# T = sequence length (number of tokens in each sequence)
|
| 211 |
# C = embedding dimensionality (n_embd from config)
|
| 212 |
B, T, C = x.size()
|
| 213 |
-
|
| 214 |
# Generate query, key, and value projections for all attention heads
|
| 215 |
# The c_attn linear layer outputs 3 * n_embd features, which we split into Q, K, V
|
| 216 |
# This batched approach is more efficient than separate linear layers
|
| 217 |
# Input shape: (B, T, C) -> Output shape: (B, T, 3*C) -> Split to 3x (B, T, C)
|
| 218 |
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
|
| 219 |
-
|
| 220 |
# Reshape tensors for multi-head attention computation
|
| 221 |
# Transform from (B, T, C) to (B, nh, T, hs) where:
|
| 222 |
# - nh = number of heads (self.n_head)
|
|
@@ -225,41 +228,41 @@ class CausalSelfAttention(nn.Module):
|
|
| 225 |
q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2) # (B, nh, T, hs)
|
| 226 |
k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2) # (B, nh, T, hs)
|
| 227 |
v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2) # (B, nh, T, hs)
|
| 228 |
-
|
| 229 |
# Compute scaled dot-product attention scores
|
| 230 |
# Matrix multiplication: Q @ K^T gives attention affinities between all token pairs
|
| 231 |
# Scaling by 1/sqrt(head_dim) prevents softmax saturation for large embedding dimensions
|
| 232 |
# Shape: (B, nh, T, hs) @ (B, nh, hs, T) -> (B, nh, T, T)
|
| 233 |
# The resulting (T, T) matrix represents attention weights from each token to every other token
|
| 234 |
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
|
| 235 |
-
|
| 236 |
# Apply causal masking to enforce autoregressive property
|
| 237 |
# The causal mask ensures that token i can only attend to tokens j where j <= i
|
| 238 |
# This prevents the model from "cheating" by looking at future tokens during training
|
| 239 |
# We use -inf for masked positions so they become 0 after softmax
|
| 240 |
-
att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
|
| 241 |
-
|
| 242 |
# Convert attention scores to probabilities using softmax
|
| 243 |
# Each row of the attention matrix now sums to 1, representing a probability distribution
|
| 244 |
# over which tokens to attend to for each query position
|
| 245 |
att = F.softmax(att, dim=-1)
|
| 246 |
-
|
| 247 |
# Apply dropout to attention weights for regularization
|
| 248 |
# This randomly zeros some attention connections during training to prevent overfitting
|
| 249 |
att = self.attn_dropout(att)
|
| 250 |
-
|
| 251 |
# Apply attention weights to value vectors
|
| 252 |
# This weighted combination produces the actual output of the attention mechanism
|
| 253 |
# Shape: (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)
|
| 254 |
# Each output position is a weighted sum of all value vectors, with weights from attention
|
| 255 |
y = att @ v
|
| 256 |
-
|
| 257 |
# Concatenate multi-head outputs back to original embedding dimension
|
| 258 |
# Transform from (B, nh, T, hs) back to (B, T, C) where C = nh * hs
|
| 259 |
# The transpose moves head dimension back, and contiguous() ensures memory layout efficiency
|
| 260 |
# This combines information from all attention heads into a single representation
|
| 261 |
y = y.transpose(1, 2).contiguous().view(B, T, C)
|
| 262 |
-
|
| 263 |
# Apply final output projection and residual dropout
|
| 264 |
# The output projection allows the model to learn how to best combine multi-head information
|
| 265 |
# Residual dropout provides additional regularization before the residual connection
|
|
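The comments above walk through the masked scaled-dot-product computation; the same steps in a tiny standalone form (random tensors and toy sizes, for illustration only):

import math
import torch
import torch.nn.functional as F

B, nh, T, hs = 1, 2, 4, 8                                 # toy dimensions
q, k, v = (torch.randn(B, nh, T, hs) for _ in range(3))

att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(hs))  # (B, nh, T, T)
mask = torch.tril(torch.ones(T, T)).view(1, 1, T, T)     # causal mask
att = att.masked_fill(mask == 0, float("-inf"))
att = F.softmax(att, dim=-1)                              # rows sum to 1
y = (att @ v).transpose(1, 2).contiguous().view(B, T, nh * hs)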
@@ -270,62 +273,62 @@ class CausalSelfAttention(nn.Module):
|
|
| 270 |
class MLP(nn.Module):
|
| 271 |
"""
|
| 272 |
Multi-Layer Perceptron (Feed-Forward Network) for Transformer.
|
| 273 |
-
|
| 274 |
This implements the position-wise feed-forward network that appears in each transformer layer.
|
| 275 |
The MLP provides additional non-linear transformation capacity beyond what attention provides.
|
| 276 |
-
|
| 277 |
Architecture:
|
| 278 |
Input -> Linear(n_embd -> 4*n_embd) -> GELU -> Linear(4*n_embd -> n_embd) -> Dropout -> Output
|
| 279 |
-
|
| 280 |
Design rationale:
|
| 281 |
- 4x expansion is standard in transformers (from "Attention Is All You Need")
|
| 282 |
- GELU activation provides smoother gradients than ReLU for language modeling
|
| 283 |
- Dropout prevents overfitting in the feed-forward layers
|
| 284 |
- Two linear layers allow complex non-linear transformations of attention outputs
|
| 285 |
-
|
| 286 |
Parameters:
|
| 287 |
- First linear layer: n_embd * 4*n_embd parameters (expansion)
|
| 288 |
- Second linear layer: 4*n_embd * n_embd parameters (projection back)
|
| 289 |
- Total: 8 * n_embd^2 parameters (significant portion of model size)
|
| 290 |
"""
|
| 291 |
-
|
| 292 |
def __init__(self, config: GPTConfig):
|
| 293 |
super().__init__()
|
| 294 |
-
|
| 295 |
# First linear layer: expand embedding dimension by 4x
|
| 296 |
# This expansion gives the network more representational capacity
|
| 297 |
# The 4x factor is a standard choice that balances capacity vs efficiency
|
| 298 |
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
|
| 299 |
-
|
| 300 |
# GELU (Gaussian Error Linear Unit) activation function
|
| 301 |
# GELU provides smoother gradients compared to ReLU and works better for language modeling
|
| 302 |
# It's approximately: GELU(x) = x * Ξ¦(x) where Ξ¦ is the CDF of standard normal distribution
|
| 303 |
self.gelu = nn.GELU()
|
| 304 |
-
|
| 305 |
# Second linear layer: project back to original embedding dimension
|
| 306 |
# This projection allows the network to combine information from the expanded representation
|
| 307 |
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
|
| 308 |
-
|
| 309 |
# Dropout for regularization in the feed-forward network
|
| 310 |
# Applied after the final projection to prevent overfitting
|
| 311 |
self.dropout = nn.Dropout(config.dropout)
|
| 312 |
-
|
| 313 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 314 |
"""
|
| 315 |
Forward pass of the feed-forward network.
|
| 316 |
-
|
| 317 |
This method applies a two-layer MLP with GELU activation to transform
|
| 318 |
the attention outputs. The MLP operates independently on each position
|
| 319 |
in the sequence, providing position-wise non-linear transformations.
|
| 320 |
-
|
| 321 |
Mathematical operation:
|
| 322 |
MLP(x) = Dropout(Linearβ(GELU(Linearβ(x))))
|
| 323 |
where Linearβ: R^n_embd -> R^4*n_embd and Linearβ: R^4*n_embd -> R^n_embd
|
| 324 |
-
|
| 325 |
Args:
|
| 326 |
x: Input tensor of shape (batch_size, seq_len, n_embd)
|
| 327 |
Contains attended representations from the attention layer
|
| 328 |
-
|
| 329 |
Returns:
|
| 330 |
torch.Tensor: Output tensor of shape (batch_size, seq_len, n_embd)
|
| 331 |
Contains transformed representations ready for residual connection
|
|
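The same feed-forward shape transformations described above, as a small standalone sketch:

import torch
import torch.nn as nn

n_embd = 512
mlp = nn.Sequential(
    nn.Linear(n_embd, 4 * n_embd),   # expand by 4x
    nn.GELU(),                       # smooth non-linearity
    nn.Linear(4 * n_embd, n_embd),   # project back
    nn.Dropout(0.1),
)
x = torch.randn(2, 16, n_embd)       # (batch, seq_len, n_embd)
print(mlp(x).shape)                  # torch.Size([2, 16, 512])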
@@ -334,112 +337,116 @@ class MLP(nn.Module):
|
|
| 334 |
# This expansion provides the network with a higher-dimensional space for computation
|
| 335 |
# Shape: (batch_size, seq_len, n_embd) -> (batch_size, seq_len, 4*n_embd)
|
| 336 |
x = self.c_fc(x)
|
| 337 |
-
|
| 338 |
# Apply GELU activation function for non-linearity
|
| 339 |
# GELU is smoother than ReLU and provides better gradients for language modeling
|
| 340 |
# It introduces non-linearity while maintaining differentiability everywhere
|
| 341 |
x = self.gelu(x)
|
| 342 |
-
|
| 343 |
# Second linear transformation: project back to original n_embd dimensions
|
| 344 |
# This projection combines information from the expanded representation
|
| 345 |
# Shape: (batch_size, seq_len, 4*n_embd) -> (batch_size, seq_len, n_embd)
|
| 346 |
x = self.c_proj(x)
|
| 347 |
-
|
| 348 |
# Apply dropout for regularization before residual connection
|
| 349 |
# Dropout randomly zeros some neurons during training to prevent overfitting
|
| 350 |
# This is particularly important in the feed-forward layers which have many parameters
|
| 351 |
x = self.dropout(x)
|
| 352 |
-
|
| 353 |
return x
|
| 354 |
|
| 355 |
|
| 356 |
class Block(nn.Module):
|
| 357 |
"""
|
| 358 |
Single Transformer block.
|
| 359 |
-
|
| 360 |
Consists of:
|
| 361 |
1. Layer normalization
|
| 362 |
-
2. Multi-head causal self-attention
|
| 363 |
3. Residual connection
|
| 364 |
4. Layer normalization
|
| 365 |
5. MLP (feed-forward network)
|
| 366 |
6. Residual connection
|
| 367 |
-
|
| 368 |
Uses pre-norm architecture for better training stability.
|
| 369 |
"""
|
| 370 |
-
|
| 371 |
def __init__(self, config: GPTConfig):
|
| 372 |
super().__init__()
|
| 373 |
self.ln_1 = nn.LayerNorm(config.n_embd)
|
| 374 |
self.attn = CausalSelfAttention(config)
|
| 375 |
self.ln_2 = nn.LayerNorm(config.n_embd)
|
| 376 |
self.mlp = MLP(config)
|
| 377 |
-
|
| 378 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 379 |
"""
|
| 380 |
Forward pass of transformer block.
|
| 381 |
-
|
| 382 |
Args:
|
| 383 |
x: Input tensor of shape (batch_size, seq_len, n_embd)
|
| 384 |
-
|
| 385 |
Returns:
|
| 386 |
torch.Tensor: Output tensor of shape (batch_size, seq_len, n_embd)
|
| 387 |
"""
|
| 388 |
# Pre-norm attention with residual connection
|
| 389 |
x = x + self.attn(self.ln_1(x))
|
| 390 |
-
|
| 391 |
-
# Pre-norm MLP with residual connection
|
| 392 |
x = x + self.mlp(self.ln_2(x))
|
| 393 |
-
|
| 394 |
return x
|
| 395 |
|
| 396 |
|
| 397 |
class GPTModel(nn.Module):
|
| 398 |
"""
|
| 399 |
Complete GPT Language Model.
|
| 400 |
-
|
| 401 |
This is the main model class that combines all components:
|
| 402 |
- Token and positional embeddings
|
| 403 |
- Stack of transformer blocks
|
| 404 |
- Final layer normalization
|
| 405 |
- Language modeling head
|
| 406 |
-
|
| 407 |
The model can be used for:
|
| 408 |
- Training from scratch on text data
|
| 409 |
- Fine-tuning on downstream tasks
|
| 410 |
- Text generation (inference)
|
| 411 |
"""
|
| 412 |
-
|
| 413 |
def __init__(self, config: GPTConfig):
|
| 414 |
super().__init__()
|
| 415 |
assert config.vocab_size is not None, "vocab_size must be specified"
|
| 416 |
assert config.block_size is not None, "block_size must be specified"
|
| 417 |
-
|
| 418 |
self.config = config
|
| 419 |
-
|
| 420 |
# Embeddings
|
| 421 |
-
self.transformer = nn.ModuleDict(
 |
| 422 |
-
dict(
 |
| 423 |
-
wte=nn.Embedding(config.vocab_size, config.n_embd),
 |
| 424 |
-
wpe=nn.Embedding(config.block_size, config.n_embd),
 |
| 425 |
-
drop=nn.Dropout(config.dropout),
 |
| 426 |
-
h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
 |
| 427 |
-
ln_f=nn.LayerNorm(config.n_embd),
 |
| 428 |
-
))
 |
| 429 |
# Language modeling head (maps hidden states to vocabulary)
|
| 430 |
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
| 431 |
-
|
| 432 |
# Tie weights between token embeddings and output head (common practice)
|
| 433 |
self.transformer.wte.weight = self.lm_head.weight
|
| 434 |
-
|
| 435 |
# Initialize weights
|
| 436 |
self.apply(self._init_weights)
|
| 437 |
-
|
| 438 |
# Report parameter count
|
| 439 |
print(f"Model initialized: {self.config.model_name}")
|
| 440 |
print(f"Parameters: {self.get_num_params():,}")
|
| 441 |
print(f"Estimated: {self.config.estimate_parameters():,}")
|
| 442 |
-
|
| 443 |
def _init_weights(self, module):
|
| 444 |
"""Initialize model weights using standard practices."""
|
| 445 |
if isinstance(module, nn.Linear):
|
|
@@ -448,14 +455,14 @@ class GPTModel(nn.Module):
|
|
| 448 |
torch.nn.init.zeros_(module.bias)
|
| 449 |
elif isinstance(module, nn.Embedding):
|
| 450 |
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 451 |
-
|
| 452 |
def get_num_params(self, non_embedding: bool = False) -> int:
|
| 453 |
"""
|
| 454 |
Count the number of parameters in the model.
|
| 455 |
-
|
| 456 |
Args:
|
| 457 |
non_embedding: If True, subtract embedding parameters
|
| 458 |
-
|
| 459 |
Returns:
|
| 460 |
int: Number of parameters
|
| 461 |
"""
|
|
@@ -464,19 +471,17 @@ class GPTModel(nn.Module):
|
|
| 464 |
n_params -= self.transformer.wpe.weight.numel()
|
| 465 |
n_params -= self.transformer.wte.weight.numel()
|
| 466 |
return n_params
|
| 467 |
-
|
| 468 |
def forward(
|
| 469 |
-
self,
|
| 470 |
-
idx: torch.Tensor,
|
| 471 |
-
targets: Optional[torch.Tensor] = None
|
| 472 |
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 473 |
"""
|
| 474 |
Forward pass of the GPT model.
|
| 475 |
-
|
| 476 |
Args:
|
| 477 |
idx: Input token indices of shape (batch_size, seq_len)
|
| 478 |
targets: Optional target tokens for loss calculation (batch_size, seq_len)
|
| 479 |
-
|
| 480 |
Returns:
|
| 481 |
Tuple containing:
|
| 482 |
- logits: Output logits of shape (batch_size, seq_len, vocab_size)
|
|
@@ -484,53 +489,57 @@ class GPTModel(nn.Module):
|
|
| 484 |
"""
|
| 485 |
device = idx.device
|
| 486 |
b, t = idx.size()
|
| 487 |
-
assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
 |
| 488 |
-
 |
| 489 |
# Token embeddings
|
| 490 |
tok_emb = self.transformer.wte(idx) # (b, t, n_embd)
|
| 491 |
-
|
| 492 |
# Position embeddings
|
| 493 |
pos = torch.arange(0, t, dtype=torch.long, device=device) # (t,)
|
| 494 |
pos_emb = self.transformer.wpe(pos) # (t, n_embd)
|
| 495 |
-
|
| 496 |
# Combine embeddings and apply dropout
|
| 497 |
x = self.transformer.drop(tok_emb + pos_emb)
|
| 498 |
-
|
| 499 |
# Pass through transformer blocks
|
| 500 |
for block in self.transformer.h:
|
| 501 |
x = block(x)
|
| 502 |
-
|
| 503 |
# Final layer normalization
|
| 504 |
x = self.transformer.ln_f(x)
|
| 505 |
-
|
| 506 |
# Language modeling head
|
| 507 |
if targets is not None:
|
| 508 |
# If we have targets, compute loss
|
| 509 |
logits = self.lm_head(x)
|
| 510 |
-
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
 |
| 511 |
else:
|
| 512 |
# If no targets, only compute logits for the last token (more efficient for generation)
|
| 513 |
logits = self.lm_head(x[:, [-1], :]) # Note: using list [-1] to preserve the time dim
|
| 514 |
loss = None
|
| 515 |
-
|
| 516 |
return logits, loss
|
| 517 |
-
|
| 518 |
def generate(
|
| 519 |
-
self,
|
| 520 |
-
idx: torch.Tensor,
|
| 521 |
max_new_tokens: int = 100,
|
| 522 |
temperature: float = 1.0,
|
| 523 |
-
top_k: Optional[int] = None
|
| 524 |
) -> torch.Tensor:
|
| 525 |
"""
|
| 526 |
Generate new tokens autoregressively.
|
| 527 |
-
|
| 528 |
Args:
|
| 529 |
idx: Starting token indices (batch_size, seq_len)
|
| 530 |
max_new_tokens: Maximum number of new tokens to generate
|
| 531 |
temperature: Sampling temperature (higher = more random)
|
| 532 |
top_k: If set, only sample from top-k most likely tokens
|
| 533 |
-
|
| 534 |
Returns:
|
| 535 |
torch.Tensor: Generated sequence (batch_size, seq_len + max_new_tokens)
|
| 536 |
"""
|
|
@@ -538,57 +547,61 @@ class GPTModel(nn.Module):
|
|
| 538 |
with torch.no_grad():
|
| 539 |
for _ in range(max_new_tokens):
|
| 540 |
# Crop sequence if it exceeds block size
|
| 541 |
-
idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
 |
| 542 |
-
 |
| 543 |
# Forward pass
|
| 544 |
logits, _ = self(idx_cond)
|
| 545 |
-
|
| 546 |
# Get logits for the last token and apply temperature
|
| 547 |
logits = logits[:, -1, :] / temperature
|
| 548 |
-
|
| 549 |
# Optionally crop to top-k most likely tokens
|
| 550 |
if top_k is not None:
|
| 551 |
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
| 552 |
-
logits[logits < v[:, [-1]]] = -float('inf')
|
| 553 |
-
|
| 554 |
# Apply softmax and sample
|
| 555 |
probs = F.softmax(logits, dim=-1)
|
| 556 |
idx_next = torch.multinomial(probs, num_samples=1)
|
| 557 |
-
|
| 558 |
# Append to sequence
|
| 559 |
idx = torch.cat((idx, idx_next), dim=1)
|
| 560 |
-
|
| 561 |
self.train() # Return to training mode
|
| 562 |
return idx
|
| 563 |
-
|
| 564 |
def estimate_memory_usage(self, batch_size: int = 1, seq_len: int = None) -> dict:
|
| 565 |
"""
|
| 566 |
Estimate memory usage for training and inference.
|
| 567 |
-
|
| 568 |
Args:
|
| 569 |
batch_size: Batch size for estimation
|
| 570 |
seq_len: Sequence length (defaults to block_size)
|
| 571 |
-
|
| 572 |
Returns:
|
| 573 |
dict: Memory usage estimates in MB
|
| 574 |
"""
|
| 575 |
if seq_len is None:
|
| 576 |
seq_len = self.config.block_size
|
| 577 |
-
|
| 578 |
# Model parameters (weights)
|
| 579 |
param_memory = self.get_num_params() * 4 / (1024**2) # 4 bytes per float32
|
| 580 |
-
|
| 581 |
# Activations (rough estimate)
|
| 582 |
activation_memory = (
|
| 583 |
batch_size * seq_len * self.config.n_embd * self.config.n_layer * 8 # Rough estimate
|
| 584 |
) / (1024**2)
|
| 585 |
-
|
| 586 |
# Gradients (same size as parameters during training)
|
| 587 |
gradient_memory = param_memory
|
| 588 |
-
|
| 589 |
return {
|
| 590 |
"parameters_mb": param_memory,
|
| 591 |
-
"activations_mb": activation_memory,
|
| 592 |
"gradients_mb": gradient_memory,
|
| 593 |
"total_training_mb": param_memory + activation_memory + gradient_memory,
|
| 594 |
"total_inference_mb": param_memory + activation_memory * 0.5, # No gradients needed
|
|
@@ -598,25 +611,25 @@ class GPTModel(nn.Module):
|
|
| 598 |
def create_model(model_size: str = "medium") -> GPTModel:
|
| 599 |
"""
|
| 600 |
Factory function to create a GPT model with predefined configurations.
|
| 601 |
-
|
| 602 |
Args:
|
| 603 |
model_size: Size of model to create ("small", "medium", "large")
|
| 604 |
-
|
| 605 |
Returns:
|
| 606 |
GPTModel: Initialized model
|
| 607 |
"""
|
| 608 |
configs = {
|
| 609 |
"small": GPTConfig.small(),
|
| 610 |
-
"medium": GPTConfig.medium(),
|
| 611 |
"large": GPTConfig.large(),
|
| 612 |
}
|
| 613 |
-
|
| 614 |
if model_size not in configs:
|
| 615 |
raise ValueError(f"Unknown model size: {model_size}. Choose from {list(configs.keys())}")
|
| 616 |
-
|
| 617 |
config = configs[model_size]
|
| 618 |
model = GPTModel(config)
|
| 619 |
-
|
| 620 |
return model
|
| 621 |
|
| 622 |
|
|
@@ -624,18 +637,20 @@ if __name__ == "__main__":
|
|
| 624 |
# Example usage
|
| 625 |
print("π§ GPT Model Architecture")
|
| 626 |
print("=" * 50)
|
| 627 |
-
|
| 628 |
# Create models of different sizes
|
| 629 |
for size in ["small", "medium", "large"]:
|
| 630 |
print(f"\n{size.upper()} MODEL:")
|
| 631 |
model = create_model(size)
|
| 632 |
-
|
| 633 |
# Show memory estimates
|
| 634 |
memory = model.estimate_memory_usage(batch_size=4, seq_len=512)
|
| 635 |
-
print(
|
| 636 |
-
|
| 637 |
# Test forward pass
|
| 638 |
x = torch.randint(0, 32000, (2, 64)) # Batch size 2, sequence length 64
|
| 639 |
with torch.no_grad():
|
| 640 |
logits, _ = model(x)
|
| 641 |
-
print(f"Test forward pass: {x.shape} -> {logits.shape} β")
|
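The generate() walkthrough above follows the usual temperature plus top-k recipe; one decoding step in isolation looks like this (random logits used purely for illustration):

import torch
import torch.nn.functional as F

logits = torch.randn(1, 32000)                 # logits for the last position
temperature, top_k = 0.7, 40

logits = logits / temperature
v, _ = torch.topk(logits, top_k)
logits[logits < v[:, [-1]]] = -float("inf")    # drop everything outside the top-k
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)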
|
| 18 |
|
| 19 |
ARCHITECTURE OVERVIEW:
|
| 20 |
- Token Embedding: Maps token IDs to dense vectors
|
| 21 |
+
- Positional Embedding: Adds position information to token embeddings
|
| 22 |
- Transformer Blocks: Stack of multi-head attention + feed-forward layers
|
| 23 |
- Layer Normalization: Pre-norm placement for training stability
|
| 24 |
- Output Head: Linear projection to vocabulary for next-token prediction
|
|
|
|
| 32 |
|
| 33 |
Usage:
|
| 34 |
from model import GPTConfig, GPTModel
|
| 35 |
+
|
| 36 |
config = GPTConfig(vocab_size=32000, n_layer=12, n_head=12, n_embd=768)
|
| 37 |
model = GPTModel(config)
|
| 38 |
+
|
| 39 |
# Forward pass
|
| 40 |
logits = model(input_ids) # Shape: (batch_size, seq_len, vocab_size)
|
| 41 |
|
| 42 |
Hardware Requirements:
|
| 43 |
- Small Model (25M params): 4-8GB RAM, CPU/integrated GPU
|
| 44 |
+
- Medium Model (117M params): 8-16GB RAM, dedicated GPU recommended
|
| 45 |
- Large Model (350M params): 16GB+ RAM, high-end GPU required
|
| 46 |
|
| 47 |
Author: Louis Chua Bean Chong
|
|
|
|
| 60 |
class GPTConfig:
|
| 61 |
"""
|
| 62 |
Configuration class for GPT model hyperparameters.
|
| 63 |
+
|
| 64 |
This class defines all the architectural parameters needed to instantiate
|
| 65 |
a GPT model. Use the provided class methods to get pre-configured setups
|
| 66 |
for different model sizes.
|
| 67 |
"""
|
| 68 |
+
|
| 69 |
# Model architecture
|
| 70 |
+
vocab_size: int = 32000 # Vocabulary size (from tokenizer)
|
| 71 |
+
n_layer: int = 12 # Number of transformer layers
|
| 72 |
+
n_head: int = 12 # Number of attention heads
|
| 73 |
+
n_embd: int = 768 # Embedding dimension
|
| 74 |
+
|
| 75 |
# Sequence and context
|
| 76 |
+
block_size: int = 1024 # Maximum sequence length
|
| 77 |
+
|
| 78 |
# Training hyperparameters
|
| 79 |
+
dropout: float = 0.1 # Dropout probability
|
| 80 |
+
bias: bool = True # Use bias in linear layers
|
| 81 |
+
|
| 82 |
# Model size identifier
|
| 83 |
+
model_name: str = "gpt-medium" # Human-readable model identifier
|
| 84 |
+
|
| 85 |
@classmethod
|
| 86 |
+
def small(cls) -> "GPTConfig":
|
| 87 |
"""Small model configuration (~25M parameters) - Good for CPU training"""
|
| 88 |
return cls(
|
| 89 |
vocab_size=32000,
|
| 90 |
n_layer=6,
|
| 91 |
+
n_head=8,
|
| 92 |
n_embd=512,
|
| 93 |
block_size=1024,
|
| 94 |
dropout=0.1,
|
| 95 |
+
model_name="gpt-small",
|
| 96 |
)
|
| 97 |
+
|
| 98 |
+
@classmethod
|
| 99 |
+
def medium(cls) -> "GPTConfig":
|
| 100 |
"""Medium model configuration (~117M parameters) - Balanced performance"""
|
| 101 |
return cls(
|
| 102 |
vocab_size=32000,
|
|
|
|
| 105 |
n_embd=768,
|
| 106 |
block_size=2048,
|
| 107 |
dropout=0.1,
|
| 108 |
+
model_name="gpt-medium",
|
| 109 |
)
|
| 110 |
+
|
| 111 |
@classmethod
|
| 112 |
+
def large(cls) -> "GPTConfig":
|
| 113 |
"""Large model configuration (~350M parameters) - High performance"""
|
| 114 |
return cls(
|
| 115 |
vocab_size=32000,
|
|
|
|
| 118 |
n_embd=1024,
|
| 119 |
block_size=2048,
|
| 120 |
dropout=0.1,
|
| 121 |
+
model_name="gpt-large",
|
| 122 |
)
|
| 123 |
+
|
| 124 |
def estimate_parameters(self) -> int:
|
| 125 |
"""
|
| 126 |
Estimate the total number of trainable parameters.
|
| 127 |
+
|
| 128 |
Returns:
|
| 129 |
int: Estimated parameter count
|
| 130 |
"""
|
| 131 |
# Token embeddings
|
| 132 |
token_emb = self.vocab_size * self.n_embd
|
| 133 |
+
|
| 134 |
+
# Position embeddings
|
| 135 |
pos_emb = self.block_size * self.n_embd
|
| 136 |
+
|
| 137 |
# Transformer layers
|
| 138 |
# Each layer: attention (4 * n_embd^2) + mlp (8 * n_embd^2) + layer_norms
|
| 139 |
layer_params = self.n_layer * (12 * self.n_embd**2 + 4 * self.n_embd)
|
| 140 |
+
|
| 141 |
# Output head
|
| 142 |
output_head = self.vocab_size * self.n_embd
|
| 143 |
+
|
| 144 |
total = token_emb + pos_emb + layer_params + output_head
|
| 145 |
return total
|
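A quick way to sanity-check the estimate against the model's measured count (a sketch; note the estimate counts the output head separately even though its weights are tied to the token embedding, so it runs a little high):

config = GPTConfig.small()
print(f"Estimated parameters: {config.estimate_parameters():,}")

model = GPTModel(config)             # constructor prints the measured count
print(f"Measured parameters:  {model.get_num_params():,}")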
| 146 |
|
| 148 |
class CausalSelfAttention(nn.Module):
|
| 149 |
"""
|
| 150 |
Multi-head causal self-attention mechanism.
|
| 151 |
+
|
| 152 |
This implements the core attention mechanism of the transformer, with causal
|
| 153 |
masking to ensure autoregressive behavior (tokens can only attend to previous
|
| 154 |
tokens, not future ones).
|
| 155 |
"""
|
| 156 |
+
|
| 157 |
def __init__(self, config: GPTConfig):
|
| 158 |
super().__init__()
|
| 159 |
+
assert (
|
| 160 |
+
config.n_embd % config.n_head == 0
|
| 161 |
+
), "Embedding dim must be divisible by number of heads"
|
| 162 |
+
|
| 163 |
self.config = config
|
| 164 |
self.n_head = config.n_head
|
| 165 |
self.n_embd = config.n_embd
|
| 166 |
self.head_dim = self.n_embd // self.n_head
|
| 167 |
+
|
| 168 |
# Key, query, value projections for all heads (batched)
|
| 169 |
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
|
| 170 |
+
|
| 171 |
# Output projection
|
| 172 |
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
|
| 173 |
+
|
| 174 |
# Dropout
|
| 175 |
self.attn_dropout = nn.Dropout(config.dropout)
|
| 176 |
self.resid_dropout = nn.Dropout(config.dropout)
|
| 177 |
+
|
| 178 |
# Causal mask - lower triangular matrix
|
| 179 |
self.register_buffer(
|
| 180 |
"bias",
|
| 181 |
+
torch.tril(torch.ones(config.block_size, config.block_size)).view(
|
| 182 |
+
1, 1, config.block_size, config.block_size
|
| 183 |
+
),
|
| 184 |
)
|
| 185 |
+
|
| 186 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 187 |
"""
|
| 188 |
Forward pass of causal self-attention.
|
| 189 |
+
|
| 190 |
This method implements the scaled dot-product attention mechanism with causal masking.
|
| 191 |
The attention mechanism allows each token to attend to all previous tokens in the sequence,
|
| 192 |
but not to future tokens, maintaining the autoregressive property essential for language modeling.
|
| 193 |
+
|
| 194 |
Mathematical formulation:
|
| 195 |
Attention(Q, K, V) = softmax(QK^T / sqrt(d_k))V
|
| 196 |
where Q, K, V are query, key, value matrices derived from input x
|
| 197 |
+
|
| 198 |
Implementation details:
|
| 199 |
- Uses batch matrix multiplication for efficiency
|
| 200 |
- Applies causal mask to prevent future token attention
|
| 201 |
- Implements multi-head attention by reshaping and parallel processing
|
| 202 |
- Applies dropout for regularization during training
|
| 203 |
+
|
| 204 |
Args:
|
| 205 |
x: Input tensor of shape (batch_size, seq_len, n_embd)
|
| 206 |
Contains embedded token representations from previous layer
|
| 207 |
+
|
| 208 |
Returns:
|
| 209 |
torch.Tensor: Output tensor of shape (batch_size, seq_len, n_embd)
|
| 210 |
"""
|
| 211 |
# Extract tensor dimensions for clear variable naming and validation
|
| 212 |
# B = batch size (number of sequences processed in parallel)
|
| 213 |
+
# T = sequence length (number of tokens in each sequence)
|
| 214 |
# C = embedding dimensionality (n_embd from config)
|
| 215 |
B, T, C = x.size()
|
| 216 |
+
|
| 217 |
# Generate query, key, and value projections for all attention heads
|
| 218 |
# The c_attn linear layer outputs 3 * n_embd features, which we split into Q, K, V
|
| 219 |
# This batched approach is more efficient than separate linear layers
|
| 220 |
# Input shape: (B, T, C) -> Output shape: (B, T, 3*C) -> Split to 3x (B, T, C)
|
| 221 |
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
|
| 222 |
+
|
| 223 |
# Reshape tensors for multi-head attention computation
|
| 224 |
# Transform from (B, T, C) to (B, nh, T, hs) where:
|
| 225 |
# - nh = number of heads (self.n_head)
|
|
|
|
| 228 |
q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2) # (B, nh, T, hs)
|
| 229 |
k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2) # (B, nh, T, hs)
|
| 230 |
v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2) # (B, nh, T, hs)
|
| 231 |
+
|
| 232 |
# Compute scaled dot-product attention scores
|
| 233 |
# Matrix multiplication: Q @ K^T gives attention affinities between all token pairs
|
| 234 |
# Scaling by 1/sqrt(head_dim) prevents softmax saturation for large embedding dimensions
|
| 235 |
# Shape: (B, nh, T, hs) @ (B, nh, hs, T) -> (B, nh, T, T)
|
| 236 |
# The resulting (T, T) matrix represents attention weights from each token to every other token
|
| 237 |
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
|
| 238 |
+
|
| 239 |
# Apply causal masking to enforce autoregressive property
|
| 240 |
# The causal mask ensures that token i can only attend to tokens j where j <= i
|
| 241 |
# This prevents the model from "cheating" by looking at future tokens during training
|
| 242 |
# We use -inf for masked positions so they become 0 after softmax
|
| 243 |
+
att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
|
| 244 |
+
|
| 245 |
# Convert attention scores to probabilities using softmax
|
| 246 |
# Each row of the attention matrix now sums to 1, representing a probability distribution
|
| 247 |
# over which tokens to attend to for each query position
|
| 248 |
att = F.softmax(att, dim=-1)
|
| 249 |
+
|
| 250 |
# Apply dropout to attention weights for regularization
|
| 251 |
# This randomly zeros some attention connections during training to prevent overfitting
|
| 252 |
att = self.attn_dropout(att)
|
| 253 |
+
|
| 254 |
# Apply attention weights to value vectors
|
| 255 |
# This weighted combination produces the actual output of the attention mechanism
|
| 256 |
# Shape: (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)
|
| 257 |
# Each output position is a weighted sum of all value vectors, with weights from attention
|
| 258 |
y = att @ v
|
| 259 |
+
|
| 260 |
# Concatenate multi-head outputs back to original embedding dimension
|
| 261 |
# Transform from (B, nh, T, hs) back to (B, T, C) where C = nh * hs
|
| 262 |
# The transpose moves head dimension back, and contiguous() ensures memory layout efficiency
|
| 263 |
# This combines information from all attention heads into a single representation
|
| 264 |
y = y.transpose(1, 2).contiguous().view(B, T, C)
|
| 265 |
+
|
| 266 |
# Apply final output projection and residual dropout
|
| 267 |
# The output projection allows the model to learn how to best combine multi-head information
|
| 268 |
# Residual dropout provides additional regularization before the residual connection
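# For reference, a self-contained sketch of the masked-attention math described in this
# docstring (toy sizes, illustrative only; the method above is the actual implementation):
import math
import torch
import torch.nn.functional as F

B, T, nh, hs = 2, 8, 4, 16                               # batch, seq len, heads, head dim
q, k, v = (torch.randn(B, nh, T, hs) for _ in range(3))
att = (q @ k.transpose(-2, -1)) / math.sqrt(hs)          # (B, nh, T, T) attention scores
mask = torch.tril(torch.ones(T, T)).view(1, 1, T, T)     # causal (lower-triangular) mask
att = att.masked_fill(mask == 0, float("-inf"))          # block attention to future tokens
y = F.softmax(att, dim=-1) @ v                           # (B, nh, T, hs) weighted values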
|
|
|
|
| 273 |
class MLP(nn.Module):
|
| 274 |
"""
|
| 275 |
Multi-Layer Perceptron (Feed-Forward Network) for Transformer.
|
| 276 |
+
|
| 277 |
This implements the position-wise feed-forward network that appears in each transformer layer.
|
| 278 |
The MLP provides additional non-linear transformation capacity beyond what attention provides.
|
| 279 |
+
|
| 280 |
Architecture:
|
| 281 |
Input -> Linear(n_embd -> 4*n_embd) -> GELU -> Linear(4*n_embd -> n_embd) -> Dropout -> Output
|
| 282 |
+
|
| 283 |
Design rationale:
|
| 284 |
- 4x expansion is standard in transformers (from "Attention Is All You Need")
|
| 285 |
- GELU activation provides smoother gradients than ReLU for language modeling
|
| 286 |
- Dropout prevents overfitting in the feed-forward layers
|
| 287 |
- Two linear layers allow complex non-linear transformations of attention outputs
|
| 288 |
+
|
| 289 |
Parameters:
|
| 290 |
- First linear layer: n_embd * 4*n_embd parameters (expansion)
|
| 291 |
- Second linear layer: 4*n_embd * n_embd parameters (projection back)
|
| 292 |
- Total: 8 * n_embd^2 parameters (significant portion of model size)
|
| 293 |
"""
|
| 294 |
+
|
| 295 |
def __init__(self, config: GPTConfig):
|
| 296 |
super().__init__()
|
| 297 |
+
|
| 298 |
# First linear layer: expand embedding dimension by 4x
|
| 299 |
# This expansion gives the network more representational capacity
|
| 300 |
# The 4x factor is a standard choice that balances capacity vs efficiency
|
| 301 |
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
|
| 302 |
+
|
| 303 |
# GELU (Gaussian Error Linear Unit) activation function
|
| 304 |
# GELU provides smoother gradients compared to ReLU and works better for language modeling
|
| 305 |
# It's approximately: GELU(x) = x * Φ(x), where Φ is the CDF of the standard normal distribution
|
| 306 |
self.gelu = nn.GELU()
|
| 307 |
+
|
| 308 |
# Second linear layer: project back to original embedding dimension
|
| 309 |
# This projection allows the network to combine information from the expanded representation
|
| 310 |
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
|
| 311 |
+
|
| 312 |
# Dropout for regularization in the feed-forward network
|
| 313 |
# Applied after the final projection to prevent overfitting
|
| 314 |
self.dropout = nn.Dropout(config.dropout)
|
| 315 |
+
|
| 316 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 317 |
"""
|
| 318 |
Forward pass of the feed-forward network.
|
| 319 |
+
|
| 320 |
This method applies a two-layer MLP with GELU activation to transform
|
| 321 |
the attention outputs. The MLP operates independently on each position
|
| 322 |
in the sequence, providing position-wise non-linear transformations.
|
| 323 |
+
|
| 324 |
Mathematical operation:
|
| 325 |
MLP(x) = Dropout(Linear₂(GELU(Linear₁(x))))
|
| 326 |
where Linear₁: R^n_embd -> R^(4*n_embd) and Linear₂: R^(4*n_embd) -> R^n_embd
|
| 327 |
+
|
| 328 |
Args:
|
| 329 |
x: Input tensor of shape (batch_size, seq_len, n_embd)
|
| 330 |
Contains attended representations from the attention layer
|
| 331 |
+
|
| 332 |
Returns:
|
| 333 |
torch.Tensor: Output tensor of shape (batch_size, seq_len, n_embd)
|
| 334 |
Contains transformed representations ready for residual connection
|
|
|
|
| 337 |
# This expansion provides the network with a higher-dimensional space for computation
|
| 338 |
# Shape: (batch_size, seq_len, n_embd) -> (batch_size, seq_len, 4*n_embd)
|
| 339 |
x = self.c_fc(x)
|
| 340 |
+
|
| 341 |
# Apply GELU activation function for non-linearity
|
| 342 |
# GELU is smoother than ReLU and provides better gradients for language modeling
|
| 343 |
# It introduces non-linearity while maintaining differentiability everywhere
|
| 344 |
x = self.gelu(x)
|
| 345 |
+
|
| 346 |
# Second linear transformation: project back to original n_embd dimensions
|
| 347 |
# This projection combines information from the expanded representation
|
| 348 |
# Shape: (batch_size, seq_len, 4*n_embd) -> (batch_size, seq_len, n_embd)
|
| 349 |
x = self.c_proj(x)
|
| 350 |
+
|
| 351 |
# Apply dropout for regularization before residual connection
|
| 352 |
# Dropout randomly zeros some neurons during training to prevent overfitting
|
| 353 |
# This is particularly important in the feed-forward layers which have many parameters
|
| 354 |
x = self.dropout(x)
|
| 355 |
+
|
| 356 |
return x
|
| 357 |
|
| 358 |
|
| 359 |
class Block(nn.Module):
|
| 360 |
"""
|
| 361 |
Single Transformer block.
|
| 362 |
+
|
| 363 |
Consists of:
|
| 364 |
1. Layer normalization
|
| 365 |
+
2. Multi-head causal self-attention
|
| 366 |
3. Residual connection
|
| 367 |
4. Layer normalization
|
| 368 |
5. MLP (feed-forward network)
|
| 369 |
6. Residual connection
|
| 370 |
+
|
| 371 |
Uses pre-norm architecture for better training stability.
|
| 372 |
"""
|
| 373 |
+
|
| 374 |
def __init__(self, config: GPTConfig):
|
| 375 |
super().__init__()
|
| 376 |
self.ln_1 = nn.LayerNorm(config.n_embd)
|
| 377 |
self.attn = CausalSelfAttention(config)
|
| 378 |
self.ln_2 = nn.LayerNorm(config.n_embd)
|
| 379 |
self.mlp = MLP(config)
|
| 380 |
+
|
| 381 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 382 |
"""
|
| 383 |
Forward pass of transformer block.
|
| 384 |
+
|
| 385 |
Args:
|
| 386 |
x: Input tensor of shape (batch_size, seq_len, n_embd)
|
| 387 |
+
|
| 388 |
Returns:
|
| 389 |
torch.Tensor: Output tensor of shape (batch_size, seq_len, n_embd)
|
| 390 |
"""
|
| 391 |
# Pre-norm attention with residual connection
|
| 392 |
x = x + self.attn(self.ln_1(x))
|
| 393 |
+
|
| 394 |
+
# Pre-norm MLP with residual connection
|
| 395 |
x = x + self.mlp(self.ln_2(x))
|
| 396 |
+
|
| 397 |
return x
|
| 398 |
|
| 399 |
|
| 400 |
class GPTModel(nn.Module):
|
| 401 |
"""
|
| 402 |
Complete GPT Language Model.
|
| 403 |
+
|
| 404 |
This is the main model class that combines all components:
|
| 405 |
- Token and positional embeddings
|
| 406 |
- Stack of transformer blocks
|
| 407 |
- Final layer normalization
|
| 408 |
- Language modeling head
|
| 409 |
+
|
| 410 |
The model can be used for:
|
| 411 |
- Training from scratch on text data
|
| 412 |
- Fine-tuning on downstream tasks
|
| 413 |
- Text generation (inference)
|
| 414 |
"""
|
| 415 |
+
|
| 416 |
def __init__(self, config: GPTConfig):
|
| 417 |
super().__init__()
|
| 418 |
assert config.vocab_size is not None, "vocab_size must be specified"
|
| 419 |
assert config.block_size is not None, "block_size must be specified"
|
| 420 |
+
|
| 421 |
self.config = config
|
| 422 |
+
|
| 423 |
# Embeddings
|
| 424 |
+
self.transformer = nn.ModuleDict(
|
| 425 |
+
dict(
|
| 426 |
+
wte=nn.Embedding(config.vocab_size, config.n_embd), # Token embeddings
|
| 427 |
+
wpe=nn.Embedding(config.block_size, config.n_embd), # Position embeddings
|
| 428 |
+
drop=nn.Dropout(config.dropout),
|
| 429 |
+
h=nn.ModuleList(
|
| 430 |
+
[Block(config) for _ in range(config.n_layer)]
|
| 431 |
+
), # Transformer blocks
|
| 432 |
+
ln_f=nn.LayerNorm(config.n_embd), # Final layer norm
|
| 433 |
+
)
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
# Language modeling head (maps hidden states to vocabulary)
|
| 437 |
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
| 438 |
+
|
| 439 |
# Tie weights between token embeddings and output head (common practice)
|
| 440 |
self.transformer.wte.weight = self.lm_head.weight
|
| 441 |
+
|
| 442 |
# Initialize weights
|
| 443 |
self.apply(self._init_weights)
|
| 444 |
+
|
| 445 |
# Report parameter count
|
| 446 |
print(f"Model initialized: {self.config.model_name}")
|
| 447 |
print(f"Parameters: {self.get_num_params():,}")
|
| 448 |
print(f"Estimated: {self.config.estimate_parameters():,}")
|
| 449 |
+
|
| 450 |
def _init_weights(self, module):
|
| 451 |
"""Initialize model weights using standard practices."""
|
| 452 |
if isinstance(module, nn.Linear):
|
|
|
|
| 455 |
torch.nn.init.zeros_(module.bias)
|
| 456 |
elif isinstance(module, nn.Embedding):
|
| 457 |
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 458 |
+
|
| 459 |
def get_num_params(self, non_embedding: bool = False) -> int:
|
| 460 |
"""
|
| 461 |
Count the number of parameters in the model.
|
| 462 |
+
|
| 463 |
Args:
|
| 464 |
non_embedding: If True, subtract embedding parameters
|
| 465 |
+
|
| 466 |
Returns:
|
| 467 |
int: Number of parameters
|
| 468 |
"""
|
|
|
|
| 471 |
n_params -= self.transformer.wpe.weight.numel()
|
| 472 |
n_params -= self.transformer.wte.weight.numel()
|
| 473 |
return n_params
|
| 474 |
+
|
| 475 |
def forward(
|
| 476 |
+
self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None
|
|
|
|
|
|
|
| 477 |
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 478 |
"""
|
| 479 |
Forward pass of the GPT model.
|
| 480 |
+
|
| 481 |
Args:
|
| 482 |
idx: Input token indices of shape (batch_size, seq_len)
|
| 483 |
targets: Optional target tokens for loss calculation (batch_size, seq_len)
|
| 484 |
+
|
| 485 |
Returns:
|
| 486 |
Tuple containing:
|
| 487 |
- logits: Output logits of shape (batch_size, seq_len, vocab_size)
|
|
|
|
| 489 |
"""
|
| 490 |
device = idx.device
|
| 491 |
b, t = idx.size()
|
| 492 |
+
assert (
|
| 493 |
+
t <= self.config.block_size
|
| 494 |
+
), f"Sequence length {t} exceeds block size {self.config.block_size}"
|
| 495 |
+
|
| 496 |
# Token embeddings
|
| 497 |
tok_emb = self.transformer.wte(idx) # (b, t, n_embd)
|
| 498 |
+
|
| 499 |
# Position embeddings
|
| 500 |
pos = torch.arange(0, t, dtype=torch.long, device=device) # (t,)
|
| 501 |
pos_emb = self.transformer.wpe(pos) # (t, n_embd)
|
| 502 |
+
|
| 503 |
# Combine embeddings and apply dropout
|
| 504 |
x = self.transformer.drop(tok_emb + pos_emb)
|
| 505 |
+
|
| 506 |
# Pass through transformer blocks
|
| 507 |
for block in self.transformer.h:
|
| 508 |
x = block(x)
|
| 509 |
+
|
| 510 |
# Final layer normalization
|
| 511 |
x = self.transformer.ln_f(x)
|
| 512 |
+
|
| 513 |
# Language modeling head
|
| 514 |
if targets is not None:
|
| 515 |
# If we have targets, compute loss
|
| 516 |
logits = self.lm_head(x)
|
| 517 |
+
loss = F.cross_entropy(
|
| 518 |
+
logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
|
| 519 |
+
)
|
| 520 |
else:
|
| 521 |
# If no targets, only compute logits for the last token (more efficient for generation)
|
| 522 |
logits = self.lm_head(x[:, [-1], :]) # Note: using list [-1] to preserve the time dim
|
| 523 |
loss = None
|
| 524 |
+
|
| 525 |
return logits, loss
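# Training-time call pattern in brief (toy tensors; assumes model = create_model("small")):
#   tokens = torch.randint(0, 32000, (1, 65))
#   input_ids, targets = tokens[:, :-1], tokens[:, 1:]   # targets = inputs shifted by one token
#   logits, loss = model(input_ids, targets)              # logits: (1, 64, 32000), loss: scalar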
|
| 526 |
+
|
| 527 |
def generate(
|
| 528 |
+
self,
|
| 529 |
+
idx: torch.Tensor,
|
| 530 |
max_new_tokens: int = 100,
|
| 531 |
temperature: float = 1.0,
|
| 532 |
+
top_k: Optional[int] = None,
|
| 533 |
) -> torch.Tensor:
|
| 534 |
"""
|
| 535 |
Generate new tokens autoregressively.
|
| 536 |
+
|
| 537 |
Args:
|
| 538 |
idx: Starting token indices (batch_size, seq_len)
|
| 539 |
max_new_tokens: Maximum number of new tokens to generate
|
| 540 |
temperature: Sampling temperature (higher = more random)
|
| 541 |
top_k: If set, only sample from top-k most likely tokens
|
| 542 |
+
|
| 543 |
Returns:
|
| 544 |
torch.Tensor: Generated sequence (batch_size, seq_len + max_new_tokens)
|
| 545 |
"""
|
|
|
|
| 547 |
with torch.no_grad():
|
| 548 |
for _ in range(max_new_tokens):
|
| 549 |
# Crop sequence if it exceeds block size
|
| 550 |
+
idx_cond = (
|
| 551 |
+
idx
|
| 552 |
+
if idx.size(1) <= self.config.block_size
|
| 553 |
+
else idx[:, -self.config.block_size :]
|
| 554 |
+
)
|
| 555 |
+
|
| 556 |
# Forward pass
|
| 557 |
logits, _ = self(idx_cond)
|
| 558 |
+
|
| 559 |
# Get logits for the last token and apply temperature
|
| 560 |
logits = logits[:, -1, :] / temperature
|
| 561 |
+
|
| 562 |
# Optionally crop to top-k most likely tokens
|
| 563 |
if top_k is not None:
|
| 564 |
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
| 565 |
+
logits[logits < v[:, [-1]]] = -float("Inf")
|
| 566 |
+
|
| 567 |
# Apply softmax and sample
|
| 568 |
probs = F.softmax(logits, dim=-1)
|
| 569 |
idx_next = torch.multinomial(probs, num_samples=1)
|
| 570 |
+
|
| 571 |
# Append to sequence
|
| 572 |
idx = torch.cat((idx, idx_next), dim=1)
|
| 573 |
+
|
| 574 |
self.train() # Return to training mode
|
| 575 |
return idx
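# Sampling sketch (the prompt IDs below are made up; any (1, T) LongTensor of token IDs works):
#   model = create_model("small")
#   prompt = torch.tensor([[1, 415, 3290]])
#   out = model.generate(prompt, max_new_tokens=20, temperature=0.8, top_k=50)
#   out.shape -> torch.Size([1, 23]); note that generate() leaves the model in train mode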
|
| 576 |
+
|
| 577 |
def estimate_memory_usage(self, batch_size: int = 1, seq_len: int = None) -> dict:
|
| 578 |
"""
|
| 579 |
Estimate memory usage for training and inference.
|
| 580 |
+
|
| 581 |
Args:
|
| 582 |
batch_size: Batch size for estimation
|
| 583 |
seq_len: Sequence length (defaults to block_size)
|
| 584 |
+
|
| 585 |
Returns:
|
| 586 |
dict: Memory usage estimates in MB
|
| 587 |
"""
|
| 588 |
if seq_len is None:
|
| 589 |
seq_len = self.config.block_size
|
| 590 |
+
|
| 591 |
# Model parameters (weights)
|
| 592 |
param_memory = self.get_num_params() * 4 / (1024**2) # 4 bytes per float32
|
| 593 |
+
|
| 594 |
# Activations (rough estimate)
|
| 595 |
activation_memory = (
|
| 596 |
batch_size * seq_len * self.config.n_embd * self.config.n_layer * 8 # Rough estimate
|
| 597 |
) / (1024**2)
|
| 598 |
+
|
| 599 |
# Gradients (same size as parameters during training)
|
| 600 |
gradient_memory = param_memory
|
| 601 |
+
|
| 602 |
return {
|
| 603 |
"parameters_mb": param_memory,
|
| 604 |
+
"activations_mb": activation_memory,
|
| 605 |
"gradients_mb": gradient_memory,
|
| 606 |
"total_training_mb": param_memory + activation_memory + gradient_memory,
|
| 607 |
"total_inference_mb": param_memory + activation_memory * 0.5, # No gradients needed
|
|
|
|
| 611 |
def create_model(model_size: str = "medium") -> GPTModel:
|
| 612 |
"""
|
| 613 |
Factory function to create a GPT model with predefined configurations.
|
| 614 |
+
|
| 615 |
Args:
|
| 616 |
model_size: Size of model to create ("small", "medium", "large")
|
| 617 |
+
|
| 618 |
Returns:
|
| 619 |
GPTModel: Initialized model
|
| 620 |
"""
|
| 621 |
configs = {
|
| 622 |
"small": GPTConfig.small(),
|
| 623 |
+
"medium": GPTConfig.medium(),
|
| 624 |
"large": GPTConfig.large(),
|
| 625 |
}
|
| 626 |
+
|
| 627 |
if model_size not in configs:
|
| 628 |
raise ValueError(f"Unknown model size: {model_size}. Choose from {list(configs.keys())}")
|
| 629 |
+
|
| 630 |
config = configs[model_size]
|
| 631 |
model = GPTModel(config)
|
| 632 |
+
|
| 633 |
return model
|
| 634 |
|
| 635 |
|
|
|
|
| 637 |
# Example usage
|
| 638 |
print("π§ GPT Model Architecture")
|
| 639 |
print("=" * 50)
|
| 640 |
+
|
| 641 |
# Create models of different sizes
|
| 642 |
for size in ["small", "medium", "large"]:
|
| 643 |
print(f"\n{size.upper()} MODEL:")
|
| 644 |
model = create_model(size)
|
| 645 |
+
|
| 646 |
# Show memory estimates
|
| 647 |
memory = model.estimate_memory_usage(batch_size=4, seq_len=512)
|
| 648 |
+
print(
|
| 649 |
+
f"Memory (4 batch, 512 seq): {memory['total_training_mb']:.1f}MB training, {memory['total_inference_mb']:.1f}MB inference"
|
| 650 |
+
)
|
| 651 |
+
|
| 652 |
# Test forward pass
|
| 653 |
x = torch.randint(0, 32000, (2, 64)) # Batch size 2, sequence length 64
|
| 654 |
with torch.no_grad():
|
| 655 |
logits, _ = model(x)
|
| 656 |
+
print(f"Test forward pass: {x.shape} -> {logits.shape} β")
|
training/train_model.py
CHANGED
|
@@ -69,6 +69,7 @@ try:
|
|
| 69 |
from data_loader import TextDataLoader
|
| 70 |
except ImportError:
|
| 71 |
import sys
|
|
|
|
| 72 |
sys.path.append(os.path.dirname(__file__))
|
| 73 |
from model import GPTModel, GPTConfig, create_model
|
| 74 |
from data_loader import TextDataLoader
|
|
@@ -77,11 +78,11 @@ except ImportError:
|
|
| 77 |
class ModelTrainer:
|
| 78 |
"""
|
| 79 |
Comprehensive trainer for GPT-style language models.
|
| 80 |
-
|
| 81 |
Handles the complete training pipeline including data loading, optimization,
|
| 82 |
checkpointing, and progress monitoring.
|
| 83 |
"""
|
| 84 |
-
|
| 85 |
def __init__(
|
| 86 |
self,
|
| 87 |
model: GPTModel,
|
|
@@ -96,11 +97,11 @@ class ModelTrainer:
|
|
| 96 |
gradient_clipping: float = 1.0,
|
| 97 |
save_every: int = 1000,
|
| 98 |
eval_every: int = 500,
|
| 99 |
-
log_every: int = 100
|
| 100 |
):
|
| 101 |
"""
|
| 102 |
Initialize the model trainer.
|
| 103 |
-
|
| 104 |
Args:
|
| 105 |
model: GPT model to train
|
| 106 |
data_loader: Data loader for training data
|
|
@@ -120,7 +121,7 @@ class ModelTrainer:
|
|
| 120 |
self.data_loader = data_loader
|
| 121 |
self.output_dir = Path(output_dir)
|
| 122 |
self.device = device
|
| 123 |
-
|
| 124 |
# Training hyperparameters
|
| 125 |
self.learning_rate = learning_rate
|
| 126 |
self.weight_decay = weight_decay
|
|
@@ -128,29 +129,29 @@ class ModelTrainer:
|
|
| 128 |
self.max_steps = max_steps
|
| 129 |
self.gradient_accumulation_steps = gradient_accumulation_steps
|
| 130 |
self.gradient_clipping = gradient_clipping
|
| 131 |
-
|
| 132 |
# Logging and saving
|
| 133 |
self.save_every = save_every
|
| 134 |
self.eval_every = eval_every
|
| 135 |
self.log_every = log_every
|
| 136 |
-
|
| 137 |
# Create output directory
|
| 138 |
self.output_dir.mkdir(parents=True, exist_ok=True)
|
| 139 |
-
|
| 140 |
# Initialize optimizer and scheduler
|
| 141 |
self.optimizer = self._create_optimizer()
|
| 142 |
self.scheduler = self._create_scheduler()
|
| 143 |
-
|
| 144 |
# Training state
|
| 145 |
self.step = 0
|
| 146 |
self.epoch = 0
|
| 147 |
-
self.best_loss = float(
|
| 148 |
self.training_log = []
|
| 149 |
-
|
| 150 |
# Performance tracking
|
| 151 |
self.start_time = None
|
| 152 |
self.step_times = []
|
| 153 |
-
|
| 154 |
print(f"π ModelTrainer initialized")
|
| 155 |
print(f" Device: {device}")
|
| 156 |
print(f" Model parameters: {model.get_num_params():,}")
|
|
@@ -158,38 +159,38 @@ class ModelTrainer:
|
|
| 158 |
print(f" Max steps: {max_steps:,}")
|
| 159 |
print(f" Gradient accumulation: {gradient_accumulation_steps}")
|
| 160 |
print(f" Output directory: {output_dir}")
|
| 161 |
-
|
| 162 |
def _create_optimizer(self) -> optim.Optimizer:
|
| 163 |
"""Create AdamW optimizer with weight decay."""
|
| 164 |
# Separate parameters for weight decay
|
| 165 |
decay_params = []
|
| 166 |
no_decay_params = []
|
| 167 |
-
|
| 168 |
for name, param in self.model.named_parameters():
|
| 169 |
if not param.requires_grad:
|
| 170 |
continue
|
| 171 |
-
|
| 172 |
# Don't apply weight decay to biases and layer norm parameters
|
| 173 |
-
if len(param.shape) == 1 or name.endswith(
|
| 174 |
no_decay_params.append(param)
|
| 175 |
else:
|
| 176 |
decay_params.append(param)
|
| 177 |
-
|
| 178 |
param_groups = [
|
| 179 |
-
{
|
| 180 |
-
{
|
| 181 |
]
|
| 182 |
-
|
| 183 |
# Use AdamW with lower memory usage for CPU
|
| 184 |
optimizer = optim.AdamW(
|
| 185 |
param_groups,
|
| 186 |
lr=self.learning_rate,
|
| 187 |
betas=(0.9, 0.95), # Slightly different from default for LLM training
|
| 188 |
-
eps=1e-8
|
| 189 |
)
|
| 190 |
-
|
| 191 |
return optimizer
|
| 192 |
-
|
| 193 |
def _create_scheduler(self) -> torch.optim.lr_scheduler._LRScheduler:
|
| 194 |
"""Create learning rate scheduler with warmup and cosine decay."""
|
| 195 |
if self.warmup_steps > 0:
|
|
@@ -201,59 +202,59 @@ class ModelTrainer:
|
|
| 201 |
self.max_steps = max_steps
|
| 202 |
self.min_lr_factor = min_lr_factor
|
| 203 |
super().__init__(optimizer)
|
| 204 |
-
|
| 205 |
def get_lr(self):
|
| 206 |
if self.last_epoch < self.warmup_steps:
|
| 207 |
-
|
| 208 |
factor = self.last_epoch / self.warmup_steps
|
| 209 |
return [base_lr * (0.01 + 0.99 * factor) for base_lr in self.base_lrs]
|
| 210 |
else:
|
| 211 |
# Cosine decay
|
| 212 |
-
progress = (self.last_epoch - self.warmup_steps) / (
|
|
|
|
|
|
|
| 213 |
progress = min(progress, 1.0) # Clamp to 1.0
|
| 214 |
factor = 0.5 * (1 + math.cos(math.pi * progress))
|
| 215 |
factor = self.min_lr_factor + (1 - self.min_lr_factor) * factor
|
| 216 |
return [base_lr * factor for base_lr in self.base_lrs]
|
| 217 |
-
|
| 218 |
scheduler = WarmupCosineScheduler(
|
| 219 |
self.optimizer,
|
| 220 |
warmup_steps=self.warmup_steps,
|
| 221 |
max_steps=self.max_steps,
|
| 222 |
-
min_lr_factor=0.1
|
| 223 |
)
|
| 224 |
else:
|
| 225 |
# Just cosine decay - this should not trigger warnings
|
| 226 |
scheduler = CosineAnnealingLR(
|
| 227 |
-
self.optimizer,
|
| 228 |
-
T_max=self.max_steps,
|
| 229 |
-
eta_min=self.learning_rate * 0.1
|
| 230 |
)
|
| 231 |
-
|
| 232 |
return scheduler
|
| 233 |
-
|
| 234 |
def _calculate_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
|
| 235 |
"""
|
| 236 |
Calculate cross-entropy loss for autoregressive language modeling.
|
| 237 |
-
|
| 238 |
This method computes the standard cross-entropy loss used in language model training.
|
| 239 |
The loss measures how well the model predicts the next token in the sequence.
|
| 240 |
-
|
| 241 |
Mathematical formulation:
|
| 242 |
Loss = -Σ log(P(target_token | context))
|
| 243 |
where P is the softmax probability distribution over vocabulary
|
| 244 |
-
|
| 245 |
Implementation details:
|
| 246 |
- Reshapes 3D tensors to 2D for efficient computation
|
| 247 |
- Uses PyTorch's optimized cross_entropy function
|
| 248 |
- Handles padding tokens by ignoring them in loss calculation
|
| 249 |
- Computes mean loss across all valid positions
|
| 250 |
-
|
| 251 |
Why cross-entropy for language modeling:
|
| 252 |
- Natural choice for multi-class classification (next token prediction)
|
| 253 |
- Provides strong gradient signal for correct token probabilities
|
| 254 |
- Mathematically equivalent to minimizing negative log-likelihood
|
| 255 |
- Well-studied optimization properties for neural language models
|
| 256 |
-
|
| 257 |
Args:
|
| 258 |
logits: Raw model predictions of shape (batch_size, seq_len, vocab_size)
|
| 259 |
Contains unnormalized scores for each token in vocabulary
|
|
@@ -261,7 +262,7 @@ class ModelTrainer:
|
|
| 261 |
targets: Ground truth next tokens of shape (batch_size, seq_len)
|
| 262 |
Contains token IDs representing the true next tokens
|
| 263 |
Should be input sequence shifted by one position
|
| 264 |
-
|
| 265 |
Returns:
|
| 266 |
torch.Tensor: Scalar loss value representing prediction error
|
| 267 |
Lower values indicate better next-token prediction accuracy
|
|
@@ -271,123 +272,126 @@ class ModelTrainer:
|
|
| 271 |
# where each row represents one prediction over the entire vocabulary
|
| 272 |
logits = logits.view(-1, logits.size(-1)) # (batch_size * seq_len, vocab_size)
|
| 273 |
targets = targets.view(-1) # (batch_size * seq_len,)
|
| 274 |
-
|
| 275 |
# Calculate cross-entropy loss with proper handling of special tokens
|
| 276 |
# ignore_index=-1 excludes padding tokens from loss calculation
|
| 277 |
# This prevents the model from learning to predict padding, which would skew training
|
| 278 |
# The function internally applies softmax to logits and computes negative log-likelihood
|
| 279 |
loss = nn.functional.cross_entropy(logits, targets, ignore_index=-1)
|
| 280 |
-
|
| 281 |
# Return scalar loss for backpropagation
|
| 282 |
# This loss will be used to compute gradients via automatic differentiation
|
| 283 |
return loss
|
| 284 |
-
|
| 285 |
def _get_memory_usage(self) -> Dict[str, float]:
|
| 286 |
"""Get current memory usage statistics."""
|
| 287 |
memory_stats = {}
|
| 288 |
-
|
| 289 |
-
if torch.cuda.is_available() and self.device.startswith(
|
| 290 |
-
memory_stats[
|
| 291 |
-
memory_stats[
|
| 292 |
-
|
| 293 |
# Estimate CPU memory (approximate)
|
| 294 |
import psutil
|
|
|
|
| 295 |
process = psutil.Process()
|
| 296 |
-
memory_stats[
|
| 297 |
-
|
| 298 |
return memory_stats
|
| 299 |
-
|
| 300 |
def _log_step(self, step: int, loss: float, lr: float, step_time: float) -> None:
|
| 301 |
"""Log training progress for a single step."""
|
| 302 |
perplexity = math.exp(min(loss, 10)) # Cap at exp(10) to avoid overflow
|
| 303 |
-
|
| 304 |
# Calculate tokens per second
|
| 305 |
tokens_per_batch = self.data_loader.batch_size * self.data_loader.seq_len
|
| 306 |
tokens_per_second = tokens_per_batch / step_time if step_time > 0 else 0
|
| 307 |
-
|
| 308 |
# Get memory usage
|
| 309 |
memory_stats = self._get_memory_usage()
|
| 310 |
-
|
| 311 |
# Create log entry
|
| 312 |
log_entry = {
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
}
|
| 321 |
-
|
| 322 |
self.training_log.append(log_entry)
|
| 323 |
-
|
| 324 |
# Print progress
|
| 325 |
elapsed_time = time.time() - self.start_time if self.start_time else 0
|
| 326 |
eta_seconds = (self.max_steps - step) * step_time if step_time > 0 else 0
|
| 327 |
eta_hours = eta_seconds / 3600
|
| 328 |
-
|
| 329 |
-
print(
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
|
|
|
|
|
|
| 338 |
def _save_checkpoint(self, step: int, is_best: bool = False) -> None:
|
| 339 |
"""Save model checkpoint."""
|
| 340 |
checkpoint = {
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
}
|
| 350 |
-
|
| 351 |
# Save latest checkpoint
|
| 352 |
checkpoint_path = self.output_dir / f"checkpoint_step_{step}.pt"
|
| 353 |
torch.save(checkpoint, checkpoint_path)
|
| 354 |
-
|
| 355 |
# Save best checkpoint
|
| 356 |
if is_best:
|
| 357 |
best_path = self.output_dir / "best_model.pt"
|
| 358 |
torch.save(checkpoint, best_path)
|
| 359 |
print(f"πΎ New best model saved: {best_path}")
|
| 360 |
-
|
| 361 |
# Save training log
|
| 362 |
log_path = self.output_dir / "training_log.json"
|
| 363 |
-
with open(log_path,
|
| 364 |
json.dump(self.training_log, f, indent=2)
|
| 365 |
-
|
| 366 |
print(f"πΎ Checkpoint saved: {checkpoint_path}")
|
| 367 |
-
|
| 368 |
def _load_checkpoint(self, checkpoint_path: str) -> None:
|
| 369 |
"""Load model checkpoint to resume training."""
|
| 370 |
if not os.path.exists(checkpoint_path):
|
| 371 |
print(f"β οΈ Checkpoint not found: {checkpoint_path}")
|
| 372 |
return
|
| 373 |
-
|
| 374 |
print(f"π Loading checkpoint: {checkpoint_path}")
|
| 375 |
-
|
| 376 |
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
| 377 |
-
|
| 378 |
-
self.model.load_state_dict(checkpoint[
|
| 379 |
-
self.optimizer.load_state_dict(checkpoint[
|
| 380 |
-
self.scheduler.load_state_dict(checkpoint[
|
| 381 |
-
|
| 382 |
-
self.step = checkpoint[
|
| 383 |
-
self.epoch = checkpoint[
|
| 384 |
-
self.best_loss = checkpoint[
|
| 385 |
-
self.training_log = checkpoint.get(
|
| 386 |
-
|
| 387 |
print(f"β Checkpoint loaded successfully")
|
| 388 |
print(f" Resuming from step: {self.step:,}")
|
| 389 |
print(f" Best loss so far: {self.best_loss:.4f}")
|
| 390 |
-
|
| 391 |
def train(self) -> None:
|
| 392 |
"""Main training loop."""
|
| 393 |
print(f"\nπ Starting training...")
|
|
@@ -396,85 +400,85 @@ class ModelTrainer:
|
|
| 396 |
print(f" Device: {self.device}")
|
| 397 |
print(f" Max steps: {self.max_steps:,}")
|
| 398 |
print("=" * 80)
|
| 399 |
-
|
| 400 |
self.model.train()
|
| 401 |
self.start_time = time.time()
|
| 402 |
-
|
| 403 |
# Initialize gradient accumulation
|
| 404 |
accumulated_loss = 0.0
|
| 405 |
self.optimizer.zero_grad()
|
| 406 |
-
|
| 407 |
for batch_idx, (input_ids, target_ids) in enumerate(self.data_loader):
|
| 408 |
if self.step >= self.max_steps:
|
| 409 |
break
|
| 410 |
-
|
| 411 |
step_start_time = time.time()
|
| 412 |
-
|
| 413 |
# Move batch to device
|
| 414 |
input_ids = input_ids.to(self.device)
|
| 415 |
target_ids = target_ids.to(self.device)
|
| 416 |
-
|
| 417 |
# Forward pass (model computes loss internally when targets provided)
|
| 418 |
logits, loss = self.model(input_ids, target_ids)
|
| 419 |
-
|
| 420 |
# Scale loss for gradient accumulation
|
| 421 |
loss = loss / self.gradient_accumulation_steps
|
| 422 |
accumulated_loss += loss.item()
|
| 423 |
-
|
| 424 |
# Backward pass
|
| 425 |
loss.backward()
|
| 426 |
-
|
| 427 |
# Update weights every gradient_accumulation_steps
|
| 428 |
if (batch_idx + 1) % self.gradient_accumulation_steps == 0:
|
| 429 |
# Clip gradients
|
| 430 |
if self.gradient_clipping > 0:
|
| 431 |
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clipping)
|
| 432 |
-
|
| 433 |
# Update parameters
|
| 434 |
self.optimizer.step()
|
| 435 |
self.scheduler.step()
|
| 436 |
self.optimizer.zero_grad()
|
| 437 |
-
|
| 438 |
# Update step count
|
| 439 |
self.step += 1
|
| 440 |
step_time = time.time() - step_start_time
|
| 441 |
self.step_times.append(step_time)
|
| 442 |
-
|
| 443 |
# Get current learning rate
|
| 444 |
current_lr = self.scheduler.get_last_lr()[0]
|
| 445 |
-
|
| 446 |
# Log progress
|
| 447 |
if self.step % self.log_every == 0:
|
| 448 |
avg_loss = accumulated_loss
|
| 449 |
self._log_step(self.step, avg_loss, current_lr, step_time)
|
| 450 |
-
|
| 451 |
# Save checkpoint
|
| 452 |
if self.step % self.save_every == 0:
|
| 453 |
is_best = accumulated_loss < self.best_loss
|
| 454 |
if is_best:
|
| 455 |
self.best_loss = accumulated_loss
|
| 456 |
-
|
| 457 |
self._save_checkpoint(self.step, is_best)
|
| 458 |
-
|
| 459 |
# Clean up memory periodically
|
| 460 |
if self.step % 100 == 0:
|
| 461 |
gc.collect()
|
| 462 |
-
|
| 463 |
# Reset accumulated loss
|
| 464 |
accumulated_loss = 0.0
|
| 465 |
-
|
| 466 |
# Check if training complete
|
| 467 |
if self.step >= self.max_steps:
|
| 468 |
break
|
| 469 |
-
|
| 470 |
# Final checkpoint
|
| 471 |
print(f"\nπ Training completed!")
|
| 472 |
self._save_checkpoint(self.step, is_best=True)
|
| 473 |
-
|
| 474 |
# Training summary
|
| 475 |
total_time = time.time() - self.start_time
|
| 476 |
avg_step_time = sum(self.step_times) / len(self.step_times) if self.step_times else 0
|
| 477 |
-
|
| 478 |
print(f"\nπ Training Summary:")
|
| 479 |
print(f" Steps completed: {self.step:,}")
|
| 480 |
print(f" Total time: {total_time/3600:.2f} hours")
|
|
@@ -504,130 +508,105 @@ Examples:
|
|
| 504 |
--batch-size 2 \\
|
| 505 |
--max-steps 50000 \\
|
| 506 |
--output-dir models/my-medium-model
|
| 507 |
-
"""
|
| 508 |
)
|
| 509 |
-
|
| 510 |
# Model and data arguments
|
| 511 |
parser.add_argument(
|
| 512 |
"--model-size",
|
| 513 |
choices=["small", "medium", "large"],
|
| 514 |
default="small",
|
| 515 |
-
help="Model size to train (default: small)"
|
| 516 |
)
|
| 517 |
-
|
| 518 |
parser.add_argument(
|
| 519 |
"--data-file",
|
| 520 |
default="data/clean/training_data.txt",
|
| 521 |
-
help="Path to training text file (default: data/clean/training_data.txt)"
|
| 522 |
)
|
| 523 |
-
|
| 524 |
parser.add_argument(
|
| 525 |
"--tokenizer-dir",
|
| 526 |
default="data/tokenizer/",
|
| 527 |
-
help="Path to tokenizer directory (default: data/tokenizer/)"
|
| 528 |
)
|
| 529 |
-
|
| 530 |
parser.add_argument(
|
| 531 |
-
"--output-dir",
|
| 532 |
-
required=True,
|
| 533 |
-
help="Output directory for model checkpoints"
|
| 534 |
)
|
| 535 |
-
|
| 536 |
# Training hyperparameters
|
| 537 |
parser.add_argument(
|
| 538 |
-
"--seq-len",
|
| 539 |
-
type=int,
|
| 540 |
-
default=512,
|
| 541 |
-
help="Sequence length for training (default: 512)"
|
| 542 |
-
)
|
| 543 |
-
|
| 544 |
-
parser.add_argument(
|
| 545 |
-
"--batch-size",
|
| 546 |
-
type=int,
|
| 547 |
-
default=4,
|
| 548 |
-
help="Batch size (default: 4)"
|
| 549 |
)
|
| 550 |
-
|
|
|
|
|
|
|
| 551 |
parser.add_argument(
|
| 552 |
-
"--learning-rate",
|
| 553 |
-
type=float,
|
| 554 |
-
default=3e-4,
|
| 555 |
-
help="Learning rate (default: 3e-4)"
|
| 556 |
)
|
| 557 |
-
|
| 558 |
parser.add_argument(
|
| 559 |
-
"--max-steps",
|
| 560 |
-
type=int,
|
| 561 |
-
default=10000,
|
| 562 |
-
help="Maximum training steps (default: 10000)"
|
| 563 |
)
|
| 564 |
-
|
| 565 |
parser.add_argument(
|
| 566 |
-
"--warmup-steps",
|
| 567 |
-
type=int,
|
| 568 |
-
default=1000,
|
| 569 |
-
help="Warmup steps (default: 1000)"
|
| 570 |
)
|
| 571 |
-
|
| 572 |
parser.add_argument(
|
| 573 |
"--gradient-accumulation-steps",
|
| 574 |
type=int,
|
| 575 |
default=4,
|
| 576 |
-
help="Gradient accumulation steps (default: 4)"
|
| 577 |
)
|
| 578 |
-
|
| 579 |
parser.add_argument(
|
| 580 |
"--device",
|
| 581 |
choices=["cpu", "cuda", "auto"],
|
| 582 |
default="auto",
|
| 583 |
-
help="Training device (default: auto)"
|
| 584 |
-
)
|
| 585 |
-
|
| 586 |
-
parser.add_argument(
|
| 587 |
-
"--resume",
|
| 588 |
-
help="Path to checkpoint to resume training from"
|
| 589 |
)
|
| 590 |
-
|
|
|
|
|
|
|
| 591 |
parser.add_argument(
|
| 592 |
-
"--save-every",
|
| 593 |
-
type=int,
|
| 594 |
-
default=1000,
|
| 595 |
-
help="Save checkpoint every N steps (default: 1000)"
|
| 596 |
)
|
| 597 |
-
|
| 598 |
args = parser.parse_args()
|
| 599 |
-
|
| 600 |
print("π OpenLLM Model Training")
|
| 601 |
print("=" * 60)
|
| 602 |
-
|
| 603 |
# Determine device
|
| 604 |
if args.device == "auto":
|
| 605 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 606 |
else:
|
| 607 |
device = args.device
|
| 608 |
-
|
| 609 |
print(f"Using device: {device}")
|
| 610 |
-
|
| 611 |
try:
|
| 612 |
# Create model
|
| 613 |
print(f"\nποΈ Creating {args.model_size} model...")
|
| 614 |
model = create_model(args.model_size)
|
| 615 |
-
|
| 616 |
# Create data loader
|
| 617 |
print(f"\nπ Setting up data loader...")
|
| 618 |
tokenizer_path = os.path.join(args.tokenizer_dir, "tokenizer.model")
|
| 619 |
-
|
| 620 |
data_loader = TextDataLoader(
|
| 621 |
data_file=args.data_file,
|
| 622 |
tokenizer_path=tokenizer_path,
|
| 623 |
seq_len=args.seq_len,
|
| 624 |
batch_size=args.batch_size,
|
| 625 |
-
shuffle=True
|
| 626 |
)
|
| 627 |
-
|
| 628 |
# Get data statistics
|
| 629 |
data_stats = data_loader.get_data_stats()
|
| 630 |
-
|
| 631 |
# Create trainer
|
| 632 |
print(f"\nπ― Setting up trainer...")
|
| 633 |
trainer = ModelTrainer(
|
|
@@ -639,26 +618,27 @@ Examples:
|
|
| 639 |
max_steps=args.max_steps,
|
| 640 |
warmup_steps=args.warmup_steps,
|
| 641 |
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 642 |
-
save_every=args.save_every
|
| 643 |
)
|
| 644 |
-
|
| 645 |
# Resume from checkpoint if specified
|
| 646 |
if args.resume:
|
| 647 |
trainer._load_checkpoint(args.resume)
|
| 648 |
-
|
| 649 |
# Start training
|
| 650 |
trainer.train()
|
| 651 |
-
|
| 652 |
print(f"\nπ Training completed successfully!")
|
| 653 |
-
|
| 654 |
except Exception as e:
|
| 655 |
print(f"\nβ Training failed: {e}")
|
| 656 |
import traceback
|
|
|
|
| 657 |
traceback.print_exc()
|
| 658 |
return False
|
| 659 |
-
|
| 660 |
return True
|
| 661 |
|
| 662 |
|
| 663 |
if __name__ == "__main__":
|
| 664 |
-
main()
|
|
|
|
| 69 |
from data_loader import TextDataLoader
|
| 70 |
except ImportError:
|
| 71 |
import sys
|
| 72 |
+
|
| 73 |
sys.path.append(os.path.dirname(__file__))
|
| 74 |
from model import GPTModel, GPTConfig, create_model
|
| 75 |
from data_loader import TextDataLoader
|
|
|
|
| 78 |
class ModelTrainer:
|
| 79 |
"""
|
| 80 |
Comprehensive trainer for GPT-style language models.
|
| 81 |
+
|
| 82 |
Handles the complete training pipeline including data loading, optimization,
|
| 83 |
checkpointing, and progress monitoring.
|
| 84 |
"""
|
| 85 |
+
|
| 86 |
def __init__(
|
| 87 |
self,
|
| 88 |
model: GPTModel,
|
|
|
|
| 97 |
gradient_clipping: float = 1.0,
|
| 98 |
save_every: int = 1000,
|
| 99 |
eval_every: int = 500,
|
| 100 |
+
log_every: int = 100,
|
| 101 |
):
|
| 102 |
"""
|
| 103 |
Initialize the model trainer.
|
| 104 |
+
|
| 105 |
Args:
|
| 106 |
model: GPT model to train
|
| 107 |
data_loader: Data loader for training data
|
|
|
|
| 121 |
self.data_loader = data_loader
|
| 122 |
self.output_dir = Path(output_dir)
|
| 123 |
self.device = device
|
| 124 |
+
|
| 125 |
# Training hyperparameters
|
| 126 |
self.learning_rate = learning_rate
|
| 127 |
self.weight_decay = weight_decay
|
|
|
|
| 129 |
self.max_steps = max_steps
|
| 130 |
self.gradient_accumulation_steps = gradient_accumulation_steps
|
| 131 |
self.gradient_clipping = gradient_clipping
|
| 132 |
+
|
| 133 |
# Logging and saving
|
| 134 |
self.save_every = save_every
|
| 135 |
self.eval_every = eval_every
|
| 136 |
self.log_every = log_every
|
| 137 |
+
|
| 138 |
# Create output directory
|
| 139 |
self.output_dir.mkdir(parents=True, exist_ok=True)
|
| 140 |
+
|
| 141 |
# Initialize optimizer and scheduler
|
| 142 |
self.optimizer = self._create_optimizer()
|
| 143 |
self.scheduler = self._create_scheduler()
|
| 144 |
+
|
| 145 |
# Training state
|
| 146 |
self.step = 0
|
| 147 |
self.epoch = 0
|
| 148 |
+
self.best_loss = float("inf")
|
| 149 |
self.training_log = []
|
| 150 |
+
|
| 151 |
# Performance tracking
|
| 152 |
self.start_time = None
|
| 153 |
self.step_times = []
|
| 154 |
+
|
| 155 |
print(f"π ModelTrainer initialized")
|
| 156 |
print(f" Device: {device}")
|
| 157 |
print(f" Model parameters: {model.get_num_params():,}")
|
|
|
|
| 159 |
print(f" Max steps: {max_steps:,}")
|
| 160 |
print(f" Gradient accumulation: {gradient_accumulation_steps}")
|
| 161 |
print(f" Output directory: {output_dir}")
|
| 162 |
+
|
| 163 |
def _create_optimizer(self) -> optim.Optimizer:
|
| 164 |
"""Create AdamW optimizer with weight decay."""
|
| 165 |
# Separate parameters for weight decay
|
| 166 |
decay_params = []
|
| 167 |
no_decay_params = []
|
| 168 |
+
|
| 169 |
for name, param in self.model.named_parameters():
|
| 170 |
if not param.requires_grad:
|
| 171 |
continue
|
| 172 |
+
|
| 173 |
# Don't apply weight decay to biases and layer norm parameters
|
| 174 |
+
if len(param.shape) == 1 or name.endswith(".bias"):
|
| 175 |
no_decay_params.append(param)
|
| 176 |
else:
|
| 177 |
decay_params.append(param)
|
| 178 |
+
|
| 179 |
param_groups = [
|
| 180 |
+
{"params": decay_params, "weight_decay": self.weight_decay},
|
| 181 |
+
{"params": no_decay_params, "weight_decay": 0.0},
|
| 182 |
]
|
| 183 |
+
|
| 184 |
# Use AdamW with lower memory usage for CPU
|
| 185 |
optimizer = optim.AdamW(
|
| 186 |
param_groups,
|
| 187 |
lr=self.learning_rate,
|
| 188 |
betas=(0.9, 0.95), # Slightly different from default for LLM training
|
| 189 |
+
eps=1e-8,
|
| 190 |
)
|
| 191 |
+
|
| 192 |
return optimizer
|
| 193 |
+
|
| 194 |
def _create_scheduler(self) -> torch.optim.lr_scheduler._LRScheduler:
|
| 195 |
"""Create learning rate scheduler with warmup and cosine decay."""
|
| 196 |
if self.warmup_steps > 0:
|
|
|
|
| 202 |
self.max_steps = max_steps
|
| 203 |
self.min_lr_factor = min_lr_factor
|
| 204 |
super().__init__(optimizer)
|
| 205 |
+
|
| 206 |
def get_lr(self):
|
| 207 |
if self.last_epoch < self.warmup_steps:
|
| 208 |
+
# Linear warmup
|
| 209 |
factor = self.last_epoch / self.warmup_steps
|
| 210 |
return [base_lr * (0.01 + 0.99 * factor) for base_lr in self.base_lrs]
|
| 211 |
else:
|
| 212 |
# Cosine decay
|
| 213 |
+
progress = (self.last_epoch - self.warmup_steps) / (
|
| 214 |
+
self.max_steps - self.warmup_steps
|
| 215 |
+
)
|
| 216 |
progress = min(progress, 1.0) # Clamp to 1.0
|
| 217 |
factor = 0.5 * (1 + math.cos(math.pi * progress))
|
| 218 |
factor = self.min_lr_factor + (1 - self.min_lr_factor) * factor
|
| 219 |
return [base_lr * factor for base_lr in self.base_lrs]
|
| 220 |
+
|
| 221 |
scheduler = WarmupCosineScheduler(
|
| 222 |
self.optimizer,
|
| 223 |
warmup_steps=self.warmup_steps,
|
| 224 |
max_steps=self.max_steps,
|
| 225 |
+
min_lr_factor=0.1,
|
| 226 |
)
|
| 227 |
else:
|
| 228 |
# Just cosine decay - this should not trigger warnings
|
| 229 |
scheduler = CosineAnnealingLR(
|
| 230 |
+
self.optimizer, T_max=self.max_steps, eta_min=self.learning_rate * 0.1
|
|
|
|
|
|
|
| 231 |
)
|
| 232 |
+
|
| 233 |
return scheduler
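# To see the schedule shape, the same warmup/cosine rule evaluated standalone
# (assumed values: warmup_steps=1000, max_steps=10000, min_lr_factor=0.1):
import math

def lr_factor(step, warmup=1000, max_steps=10000, min_factor=0.1):
    if step < warmup:
        return 0.01 + 0.99 * step / warmup                   # linear warmup from 1% of base LR
    progress = min((step - warmup) / (max_steps - warmup), 1.0)
    return min_factor + (1 - min_factor) * 0.5 * (1 + math.cos(math.pi * progress))

for s in (0, 500, 1000, 5000, 10000):
    print(s, round(lr_factor(s), 3))                         # 0.01, 0.505, 1.0, ~0.63, 0.1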
|
| 234 |
+
|
| 235 |
def _calculate_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
|
| 236 |
"""
|
| 237 |
Calculate cross-entropy loss for autoregressive language modeling.
|
| 238 |
+
|
| 239 |
This method computes the standard cross-entropy loss used in language model training.
|
| 240 |
The loss measures how well the model predicts the next token in the sequence.
|
| 241 |
+
|
| 242 |
Mathematical formulation:
|
| 243 |
Loss = -Σ log(P(target_token | context))
|
| 244 |
where P is the softmax probability distribution over vocabulary
|
| 245 |
+
|
| 246 |
Implementation details:
|
| 247 |
- Reshapes 3D tensors to 2D for efficient computation
|
| 248 |
- Uses PyTorch's optimized cross_entropy function
|
| 249 |
- Handles padding tokens by ignoring them in loss calculation
|
| 250 |
- Computes mean loss across all valid positions
|
| 251 |
+
|
| 252 |
Why cross-entropy for language modeling:
|
| 253 |
- Natural choice for multi-class classification (next token prediction)
|
| 254 |
- Provides strong gradient signal for correct token probabilities
|
| 255 |
- Mathematically equivalent to minimizing negative log-likelihood
|
| 256 |
- Well-studied optimization properties for neural language models
|
| 257 |
+
|
| 258 |
Args:
|
| 259 |
logits: Raw model predictions of shape (batch_size, seq_len, vocab_size)
|
| 260 |
Contains unnormalized scores for each token in vocabulary
|
|
|
|
| 262 |
targets: Ground truth next tokens of shape (batch_size, seq_len)
|
| 263 |
Contains token IDs representing the true next tokens
|
| 264 |
Should be input sequence shifted by one position
|
| 265 |
+
|
| 266 |
Returns:
|
| 267 |
torch.Tensor: Scalar loss value representing prediction error
|
| 268 |
Lower values indicate better next-token prediction accuracy
|
|
|
|
| 272 |
# where each row represents one prediction over the entire vocabulary
|
| 273 |
logits = logits.view(-1, logits.size(-1)) # (batch_size * seq_len, vocab_size)
|
| 274 |
targets = targets.view(-1) # (batch_size * seq_len,)
|
| 275 |
+
|
| 276 |
# Calculate cross-entropy loss with proper handling of special tokens
|
| 277 |
# ignore_index=-1 excludes padding tokens from loss calculation
|
| 278 |
# This prevents the model from learning to predict padding, which would skew training
|
| 279 |
# The function internally applies softmax to logits and computes negative log-likelihood
|
| 280 |
loss = nn.functional.cross_entropy(logits, targets, ignore_index=-1)
|
| 281 |
+
|
| 282 |
# Return scalar loss for backpropagation
|
| 283 |
# This loss will be used to compute gradients via automatic differentiation
|
| 284 |
return loss
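# A tiny illustration of ignore_index (made-up tensors): positions whose target is -1
# simply drop out of the mean loss:
import torch
import torch.nn.functional as F

logits = torch.randn(2, 4, 32000)                  # (batch, seq, vocab)
targets = torch.randint(0, 32000, (2, 4))
targets[0, -1] = -1                                # pretend this position is padding
loss = F.cross_entropy(logits.view(-1, 32000), targets.view(-1), ignore_index=-1)
# loss is the mean negative log-likelihood over the 7 positions that are not ignored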
|
| 285 |
+
|
| 286 |
def _get_memory_usage(self) -> Dict[str, float]:
|
| 287 |
"""Get current memory usage statistics."""
|
| 288 |
memory_stats = {}
|
| 289 |
+
|
| 290 |
+
if torch.cuda.is_available() and self.device.startswith("cuda"):
|
| 291 |
+
memory_stats["gpu_allocated_mb"] = torch.cuda.memory_allocated() / (1024**2)
|
| 292 |
+
memory_stats["gpu_cached_mb"] = torch.cuda.memory_reserved() / (1024**2)
|
| 293 |
+
|
| 294 |
# Estimate CPU memory (approximate)
|
| 295 |
import psutil
|
| 296 |
+
|
| 297 |
process = psutil.Process()
|
| 298 |
+
memory_stats["cpu_memory_mb"] = process.memory_info().rss / (1024**2)
|
| 299 |
+
|
| 300 |
return memory_stats
|
| 301 |
+
|
| 302 |
def _log_step(self, step: int, loss: float, lr: float, step_time: float) -> None:
|
| 303 |
"""Log training progress for a single step."""
|
| 304 |
perplexity = math.exp(min(loss, 10)) # Cap at exp(10) to avoid overflow
|
| 305 |
+
|
| 306 |
# Calculate tokens per second
|
| 307 |
tokens_per_batch = self.data_loader.batch_size * self.data_loader.seq_len
|
| 308 |
tokens_per_second = tokens_per_batch / step_time if step_time > 0 else 0
|
| 309 |
+
|
| 310 |
# Get memory usage
|
| 311 |
memory_stats = self._get_memory_usage()
|
| 312 |
+
|
| 313 |
# Create log entry
|
| 314 |
log_entry = {
|
| 315 |
+
"step": step,
|
| 316 |
+
"loss": loss,
|
| 317 |
+
"perplexity": perplexity,
|
| 318 |
+
"learning_rate": lr,
|
| 319 |
+
"step_time": step_time,
|
| 320 |
+
"tokens_per_second": tokens_per_second,
|
| 321 |
+
"memory_mb": memory_stats.get("cpu_memory_mb", 0),
|
| 322 |
}
|
| 323 |
+
|
| 324 |
self.training_log.append(log_entry)
|
| 325 |
+
|
| 326 |
# Print progress
|
| 327 |
elapsed_time = time.time() - self.start_time if self.start_time else 0
|
| 328 |
eta_seconds = (self.max_steps - step) * step_time if step_time > 0 else 0
|
| 329 |
eta_hours = eta_seconds / 3600
|
| 330 |
+
|
| 331 |
+
print(
|
| 332 |
+
f"Step {step:,}/{self.max_steps:,} | "
|
| 333 |
+
f"Loss: {loss:.4f} | "
|
| 334 |
+
f"PPL: {perplexity:.2f} | "
|
| 335 |
+
f"LR: {lr:.2e} | "
|
| 336 |
+
f"Time: {step_time:.2f}s | "
|
| 337 |
+
f"Tokens/s: {tokens_per_second:.1f} | "
|
| 338 |
+
f"Memory: {memory_stats.get('cpu_memory_mb', 0):.0f}MB | "
|
| 339 |
+
f"ETA: {eta_hours:.1f}h"
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
def _save_checkpoint(self, step: int, is_best: bool = False) -> None:
|
| 343 |
"""Save model checkpoint."""
|
| 344 |
checkpoint = {
|
| 345 |
+
"step": step,
|
| 346 |
+
"epoch": self.epoch,
|
| 347 |
+
"model_state_dict": self.model.state_dict(),
|
| 348 |
+
"optimizer_state_dict": self.optimizer.state_dict(),
|
| 349 |
+
"scheduler_state_dict": self.scheduler.state_dict(),
|
| 350 |
+
"best_loss": self.best_loss,
|
| 351 |
+
"training_log": self.training_log,
|
| 352 |
+
"config": self.model.config.__dict__,
|
| 353 |
}
|
| 354 |
+
|
| 355 |
# Save latest checkpoint
|
| 356 |
checkpoint_path = self.output_dir / f"checkpoint_step_{step}.pt"
|
| 357 |
torch.save(checkpoint, checkpoint_path)
|
| 358 |
+
|
| 359 |
# Save best checkpoint
|
| 360 |
if is_best:
|
| 361 |
best_path = self.output_dir / "best_model.pt"
|
| 362 |
torch.save(checkpoint, best_path)
|
| 363 |
print(f"πΎ New best model saved: {best_path}")
|
| 364 |
+
|
| 365 |
# Save training log
|
| 366 |
log_path = self.output_dir / "training_log.json"
|
| 367 |
+
with open(log_path, "w") as f:
|
| 368 |
json.dump(self.training_log, f, indent=2)
|
| 369 |
+
|
| 370 |
print(f"πΎ Checkpoint saved: {checkpoint_path}")
|
| 371 |
+
|
| 372 |
def _load_checkpoint(self, checkpoint_path: str) -> None:
|
| 373 |
"""Load model checkpoint to resume training."""
|
| 374 |
if not os.path.exists(checkpoint_path):
|
| 375 |
print(f"β οΈ Checkpoint not found: {checkpoint_path}")
|
| 376 |
return
|
| 377 |
+
|
| 378 |
print(f"π Loading checkpoint: {checkpoint_path}")
|
| 379 |
+
|
| 380 |
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
| 381 |
+
|
| 382 |
+
self.model.load_state_dict(checkpoint["model_state_dict"])
|
| 383 |
+
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
| 384 |
+
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
|
| 385 |
+
|
| 386 |
+
self.step = checkpoint["step"]
|
| 387 |
+
self.epoch = checkpoint["epoch"]
|
| 388 |
+
self.best_loss = checkpoint["best_loss"]
|
| 389 |
+
self.training_log = checkpoint.get("training_log", [])
|
| 390 |
+
|
| 391 |
print(f"β Checkpoint loaded successfully")
|
| 392 |
print(f" Resuming from step: {self.step:,}")
|
| 393 |
print(f" Best loss so far: {self.best_loss:.4f}")
|
| 394 |
+
|
| 395 |
def train(self) -> None:
|
| 396 |
"""Main training loop."""
|
| 397 |
print(f"\nπ Starting training...")
|
|
|
|
| 400 |
print(f" Device: {self.device}")
|
| 401 |
print(f" Max steps: {self.max_steps:,}")
|
| 402 |
print("=" * 80)
|
| 403 |
+
|
| 404 |
self.model.train()
|
| 405 |
self.start_time = time.time()
|
| 406 |
+
|
| 407 |
# Initialize gradient accumulation
|
| 408 |
accumulated_loss = 0.0
|
| 409 |
self.optimizer.zero_grad()
|
| 410 |
+
|
| 411 |
for batch_idx, (input_ids, target_ids) in enumerate(self.data_loader):
|
| 412 |
if self.step >= self.max_steps:
|
| 413 |
break
|
| 414 |
+
|
| 415 |
step_start_time = time.time()
|
| 416 |
+
|
| 417 |
# Move batch to device
|
| 418 |
input_ids = input_ids.to(self.device)
|
| 419 |
target_ids = target_ids.to(self.device)
|
| 420 |
+
|
| 421 |
# Forward pass (model computes loss internally when targets provided)
|
| 422 |
logits, loss = self.model(input_ids, target_ids)
|
| 423 |
+
|
| 424 |
# Scale loss for gradient accumulation
|
| 425 |
loss = loss / self.gradient_accumulation_steps
|
| 426 |
accumulated_loss += loss.item()
|
| 427 |
+
|
| 428 |
# Backward pass
|
| 429 |
loss.backward()
|
| 430 |
+
|
| 431 |
# Update weights every gradient_accumulation_steps
|
| 432 |
if (batch_idx + 1) % self.gradient_accumulation_steps == 0:
|
| 433 |
# Clip gradients
|
| 434 |
if self.gradient_clipping > 0:
|
| 435 |
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clipping)
|
| 436 |
+
|
| 437 |
# Update parameters
|
| 438 |
self.optimizer.step()
|
| 439 |
self.scheduler.step()
|
| 440 |
self.optimizer.zero_grad()
|
| 441 |
+
|
| 442 |
# Update step count
|
| 443 |
self.step += 1
|
| 444 |
step_time = time.time() - step_start_time
|
| 445 |
self.step_times.append(step_time)
|
| 446 |
+
|
| 447 |
# Get current learning rate
|
| 448 |
current_lr = self.scheduler.get_last_lr()[0]
|
| 449 |
+
|
| 450 |
# Log progress
|
| 451 |
if self.step % self.log_every == 0:
|
| 452 |
avg_loss = accumulated_loss
|
| 453 |
self._log_step(self.step, avg_loss, current_lr, step_time)
|
| 454 |
+
|
| 455 |
# Save checkpoint
|
| 456 |
if self.step % self.save_every == 0:
|
| 457 |
is_best = accumulated_loss < self.best_loss
|
| 458 |
if is_best:
|
| 459 |
self.best_loss = accumulated_loss
|
| 460 |
+
|
| 461 |
self._save_checkpoint(self.step, is_best)
|
| 462 |
+
|
| 463 |
# Clean up memory periodically
|
| 464 |
if self.step % 100 == 0:
|
| 465 |
gc.collect()
|
| 466 |
+
|
| 467 |
# Reset accumulated loss
|
| 468 |
accumulated_loss = 0.0
|
| 469 |
+
|
| 470 |
# Check if training complete
|
| 471 |
if self.step >= self.max_steps:
|
| 472 |
break
|
| 473 |
+
|
| 474 |
# Final checkpoint
|
| 475 |
print(f"\nπ Training completed!")
|
| 476 |
self._save_checkpoint(self.step, is_best=True)
|
| 477 |
+
|
| 478 |
# Training summary
|
| 479 |
total_time = time.time() - self.start_time
|
| 480 |
avg_step_time = sum(self.step_times) / len(self.step_times) if self.step_times else 0
|
| 481 |
+
|
| 482 |
print(f"\nπ Training Summary:")
|
| 483 |
print(f" Steps completed: {self.step:,}")
|
| 484 |
print(f" Total time: {total_time/3600:.2f} hours")
|
|
|
|
| 508 |
--batch-size 2 \\
|
| 509 |
--max-steps 50000 \\
|
| 510 |
--output-dir models/my-medium-model
|
| 511 |
+
""",
|
| 512 |
)
|
| 513 |
+
|
| 514 |
# Model and data arguments
|
| 515 |
parser.add_argument(
|
| 516 |
"--model-size",
|
| 517 |
choices=["small", "medium", "large"],
|
| 518 |
default="small",
|
| 519 |
+
help="Model size to train (default: small)",
|
| 520 |
)
|
| 521 |
+
|
| 522 |
parser.add_argument(
|
| 523 |
"--data-file",
|
| 524 |
default="data/clean/training_data.txt",
|
| 525 |
+
help="Path to training text file (default: data/clean/training_data.txt)",
|
| 526 |
)
|
| 527 |
+
|
| 528 |
parser.add_argument(
|
| 529 |
"--tokenizer-dir",
|
| 530 |
default="data/tokenizer/",
|
| 531 |
+
help="Path to tokenizer directory (default: data/tokenizer/)",
|
| 532 |
)
|
| 533 |
+
|
| 534 |
parser.add_argument(
|
| 535 |
+
"--output-dir", required=True, help="Output directory for model checkpoints"
|
|
|
|
|
|
|
| 536 |
)
|
| 537 |
+
|
| 538 |
# Training hyperparameters
|
| 539 |
parser.add_argument(
|
| 540 |
+
"--seq-len", type=int, default=512, help="Sequence length for training (default: 512)"
|
| 541 |
)
|
| 542 |
+
|
| 543 |
+
parser.add_argument("--batch-size", type=int, default=4, help="Batch size (default: 4)")
|
| 544 |
+
|
| 545 |
parser.add_argument(
|
| 546 |
+
"--learning-rate", type=float, default=3e-4, help="Learning rate (default: 3e-4)"
|
| 547 |
)
|
| 548 |
+
|
| 549 |
parser.add_argument(
|
| 550 |
+
"--max-steps", type=int, default=10000, help="Maximum training steps (default: 10000)"
|
| 551 |
)
|
| 552 |
+
|
| 553 |
parser.add_argument(
|
| 554 |
+
"--warmup-steps", type=int, default=1000, help="Warmup steps (default: 1000)"
|
| 555 |
)
|
| 556 |
+
|
| 557 |
parser.add_argument(
|
| 558 |
"--gradient-accumulation-steps",
|
| 559 |
type=int,
|
| 560 |
default=4,
|
| 561 |
+
help="Gradient accumulation steps (default: 4)",
|
| 562 |
)
|
| 563 |
+
|
| 564 |
parser.add_argument(
|
| 565 |
"--device",
|
| 566 |
choices=["cpu", "cuda", "auto"],
|
| 567 |
default="auto",
|
| 568 |
+
help="Training device (default: auto)",
|
| 569 |
)
|
| 570 |
+
|
| 571 |
+
parser.add_argument("--resume", help="Path to checkpoint to resume training from")
|
| 572 |
+
|
| 573 |
parser.add_argument(
|
| 574 |
+
"--save-every", type=int, default=1000, help="Save checkpoint every N steps (default: 1000)"
|
| 575 |
)
|
| 576 |
+
|
| 577 |
args = parser.parse_args()
|
| 578 |
+
|
| 579 |
print("π OpenLLM Model Training")
|
| 580 |
print("=" * 60)
|
| 581 |
+
|
| 582 |
# Determine device
|
| 583 |
if args.device == "auto":
|
| 584 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 585 |
else:
|
| 586 |
device = args.device
|
| 587 |
+
|
| 588 |
print(f"Using device: {device}")
|
| 589 |
+
|
| 590 |
try:
|
| 591 |
# Create model
|
| 592 |
print(f"\nποΈ Creating {args.model_size} model...")
|
| 593 |
model = create_model(args.model_size)
|
| 594 |
+
|
| 595 |
# Create data loader
|
| 596 |
print(f"\nπ Setting up data loader...")
|
| 597 |
tokenizer_path = os.path.join(args.tokenizer_dir, "tokenizer.model")
|
| 598 |
+
|
| 599 |
data_loader = TextDataLoader(
|
| 600 |
data_file=args.data_file,
|
| 601 |
tokenizer_path=tokenizer_path,
|
| 602 |
seq_len=args.seq_len,
|
| 603 |
batch_size=args.batch_size,
|
| 604 |
+
shuffle=True,
|
| 605 |
)
|
| 606 |
+
|
| 607 |
# Get data statistics
|
| 608 |
data_stats = data_loader.get_data_stats()
|
| 609 |
+
|
| 610 |
# Create trainer
|
| 611 |
print(f"\nπ― Setting up trainer...")
|
| 612 |
trainer = ModelTrainer(
|
|
|
|
| 618 |
max_steps=args.max_steps,
|
| 619 |
warmup_steps=args.warmup_steps,
|
| 620 |
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 621 |
+
save_every=args.save_every,
|
| 622 |
)
|
| 623 |
+
|
| 624 |
# Resume from checkpoint if specified
|
| 625 |
if args.resume:
|
| 626 |
trainer._load_checkpoint(args.resume)
|
| 627 |
+
|
| 628 |
# Start training
|
| 629 |
trainer.train()
|
| 630 |
+
|
| 631 |
print(f"\nπ Training completed successfully!")
|
| 632 |
+
|
| 633 |
except Exception as e:
|
| 634 |
print(f"\nβ Training failed: {e}")
|
| 635 |
import traceback
|
| 636 |
+
|
| 637 |
traceback.print_exc()
|
| 638 |
return False
|
| 639 |
+
|
| 640 |
return True
|
| 641 |
|
| 642 |
|
| 643 |
if __name__ == "__main__":
|
| 644 |
+
main()
|
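For reference, main() above wires the pieces together roughly as follows. This is a condensed sketch using only the calls visible in this diff (create_model, TextDataLoader, ModelTrainer) with the CLI defaults and illustrative paths; the first few ModelTrainer keyword names are assumptions, since those constructor lines are elided from the hunk.

import torch

model = create_model("small")

data_loader = TextDataLoader(
    data_file="data/clean/training_data.txt",
    tokenizer_path="data/tokenizer/tokenizer.model",
    seq_len=512,
    batch_size=4,
    shuffle=True,
)

trainer = ModelTrainer(
    model=model,                  # assumption: elided constructor args use
    data_loader=data_loader,      # these keyword names
    learning_rate=3e-4,
    device="cuda" if torch.cuda.is_available() else "cpu",
    max_steps=10000,
    warmup_steps=1000,
    gradient_accumulation_steps=4,
    save_every=1000,
)
trainer.train()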
training/train_tokenizer.py
CHANGED
|
@@ -76,46 +76,48 @@ except ImportError:
|
|
| 76 |
def validate_input_file(input_path: str) -> None:
|
| 77 |
"""
|
| 78 |
Validate that the input training file exists and is readable.
|
| 79 |
-
|
| 80 |
Args:
|
| 81 |
input_path (str): Path to the training text file
|
| 82 |
-
|
| 83 |
Raises:
|
| 84 |
FileNotFoundError: If input file doesn't exist
|
| 85 |
ValueError: If input file is empty or unreadable
|
| 86 |
"""
|
| 87 |
if not os.path.exists(input_path):
|
| 88 |
raise FileNotFoundError(f"Training data file not found: {input_path}")
|
| 89 |
-
|
| 90 |
# Check file size and readability
|
| 91 |
file_size = os.path.getsize(input_path)
|
| 92 |
if file_size == 0:
|
| 93 |
raise ValueError(f"Training data file is empty: {input_path}")
|
| 94 |
-
|
| 95 |
# Test that we can read the file
|
| 96 |
try:
|
| 97 |
-
with open(input_path,
|
| 98 |
first_line = f.readline()
|
| 99 |
if not first_line.strip():
|
| 100 |
-
raise ValueError(
|
| 101 |
except UnicodeDecodeError as e:
|
| 102 |
raise ValueError(f"Cannot read training data file as UTF-8: {e}")
|
| 103 |
-
|
| 104 |
print(f"β Input file validated: {input_path} ({file_size:,} bytes)")
|
| 105 |
|
| 106 |
|
| 107 |
def count_training_sentences(input_path: str) -> int:
|
| 108 |
"""
|
| 109 |
Count the number of training sentences/lines in the input file.
|
| 110 |
-
|
| 111 |
Args:
|
| 112 |
input_path (str): Path to the training text file
|
| 113 |
-
|
| 114 |
Returns:
|
| 115 |
int: Number of lines in the file
|
| 116 |
"""
|
| 117 |
print("Counting training sentences...")
|
| 118 |
-
with open(input_path,
|
| 119 |
count = sum(1 for line in f if line.strip())
|
| 120 |
print(f"β Found {count:,} training sentences")
|
| 121 |
return count
|
|
@@ -133,7 +135,7 @@ def train_sentencepiece_tokenizer(
|
|
| 133 |
) -> Dict[str, Any]:
|
| 134 |
"""
|
| 135 |
Train a SentencePiece tokenizer with the specified parameters.
|
| 136 |
-
|
| 137 |
Args:
|
| 138 |
input_path (str): Path to training text file
|
| 139 |
output_dir (str): Directory to save tokenizer files
|
|
@@ -143,16 +145,16 @@ def train_sentencepiece_tokenizer(
|
|
| 143 |
max_sentence_length (int): Maximum sentence length in characters
|
| 144 |
input_sentence_size (int): Maximum number of sentences to use for training
|
| 145 |
shuffle_input_sentence (bool): Whether to shuffle input sentences
|
| 146 |
-
|
| 147 |
Returns:
|
| 148 |
Dict[str, Any]: Training statistics and configuration
|
| 149 |
"""
|
| 150 |
# Ensure output directory exists
|
| 151 |
os.makedirs(output_dir, exist_ok=True)
|
| 152 |
-
|
| 153 |
# Define output paths
|
| 154 |
model_prefix = os.path.join(output_dir, "tokenizer")
|
| 155 |
-
|
| 156 |
# SentencePiece training parameters
|
| 157 |
train_params = [
|
| 158 |
f"--input={input_path}",
|
|
@@ -163,30 +165,28 @@ def train_sentencepiece_tokenizer(
|
|
| 163 |
f"--max_sentence_length={max_sentence_length}",
|
| 164 |
f"--input_sentence_size={input_sentence_size}",
|
| 165 |
f"--shuffle_input_sentence={shuffle_input_sentence}",
|
| 166 |
-
|
| 167 |
# Special tokens for language modeling
|
| 168 |
-
"--pad_id=0",
|
| 169 |
-
"--unk_id=1",
|
| 170 |
-
"--bos_id=2",
|
| 171 |
-
"--eos_id=3",
|
| 172 |
-
|
| 173 |
# Additional useful parameters
|
| 174 |
-
"--split_by_unicode_script=true",
|
| 175 |
-
"--split_by_whitespace=true",
|
| 176 |
-
"--remove_extra_whitespaces=true",
|
| 177 |
-
"--normalization_rule_name=identity",
|
| 178 |
]
|
| 179 |
-
|
| 180 |
print(f"\nTraining SentencePiece tokenizer...")
|
| 181 |
print(f" Algorithm: {model_type.upper()}")
|
| 182 |
print(f" Vocabulary size: {vocab_size:,}")
|
| 183 |
print(f" Character coverage: {character_coverage}")
|
| 184 |
print(f" Output directory: {output_dir}")
|
| 185 |
print(f" Model files: {model_prefix}.model, {model_prefix}.vocab")
|
| 186 |
-
|
| 187 |
# Record training start time
|
| 188 |
start_time = time.time()
|
| 189 |
-
|
| 190 |
# Train the tokenizer
|
| 191 |
try:
|
| 192 |
spm.SentencePieceTrainer.train(" ".join(train_params))
|
|
@@ -194,19 +194,19 @@ def train_sentencepiece_tokenizer(
|
|
| 194 |
print(f"β Tokenizer training completed in {training_time:.1f} seconds")
|
| 195 |
except Exception as e:
|
| 196 |
raise RuntimeError(f"SentencePiece training failed: {e}")
|
| 197 |
-
|
| 198 |
# Verify output files were created
|
| 199 |
model_file = f"{model_prefix}.model"
|
| 200 |
vocab_file = f"{model_prefix}.vocab"
|
| 201 |
-
|
| 202 |
if not os.path.exists(model_file):
|
| 203 |
raise RuntimeError(f"Expected model file not created: {model_file}")
|
| 204 |
if not os.path.exists(vocab_file):
|
| 205 |
raise RuntimeError(f"Expected vocab file not created: {vocab_file}")
|
| 206 |
-
|
| 207 |
print(f"β Model file created: {model_file} ({os.path.getsize(model_file):,} bytes)")
|
| 208 |
print(f"β Vocab file created: {vocab_file} ({os.path.getsize(vocab_file):,} bytes)")
|
| 209 |
-
|
| 210 |
# Return training configuration and statistics
|
| 211 |
config = {
|
| 212 |
"model_type": model_type,
|
|
@@ -219,24 +219,24 @@ def train_sentencepiece_tokenizer(
|
|
| 219 |
"model_file": model_file,
|
| 220 |
"vocab_file": vocab_file,
|
| 221 |
}
|
| 222 |
-
|
| 223 |
return config
|
| 224 |
|
| 225 |
|
| 226 |
def test_tokenizer(model_path: str, test_sentences: list = None) -> None:
|
| 227 |
"""
|
| 228 |
Test the trained tokenizer on sample sentences to verify it works correctly.
|
| 229 |
-
|
| 230 |
Args:
|
| 231 |
model_path (str): Path to the trained .model file
|
| 232 |
test_sentences (list): Optional list of test sentences
|
| 233 |
"""
|
| 234 |
print(f"\nTesting trained tokenizer...")
|
| 235 |
-
|
| 236 |
# Load the trained tokenizer
|
| 237 |
sp = spm.SentencePieceProcessor()
|
| 238 |
sp.load(model_path)
|
| 239 |
-
|
| 240 |
# Default test sentences if none provided
|
| 241 |
if test_sentences is None:
|
| 242 |
test_sentences = [
|
|
@@ -245,63 +245,65 @@ def test_tokenizer(model_path: str, test_sentences: list = None) -> None:
|
|
| 245 |
"Machine learning and artificial intelligence are transforming technology.",
|
| 246 |
"SentencePiece tokenization works well for language models.",
|
| 247 |
]
|
| 248 |
-
|
| 249 |
print(f"Vocabulary size: {sp.vocab_size():,}")
|
| 250 |
-
print(
|
| 251 |
-
|
| 252 |
print("\nTokenization examples:")
|
| 253 |
for i, sentence in enumerate(test_sentences, 1):
|
| 254 |
# Encode to token IDs and pieces
|
| 255 |
token_ids = sp.encode(sentence)
|
| 256 |
token_pieces = sp.encode(sentence, out_type=str)
|
| 257 |
-
|
| 258 |
print(f"\n{i}. Input: {sentence}")
|
| 259 |
print(f" Tokens ({len(token_pieces)}): {token_pieces}")
|
| 260 |
print(f" IDs: {token_ids[:10]}{'...' if len(token_ids) > 10 else ''}")
|
| 261 |
-
|
| 262 |
# Test decoding
|
| 263 |
decoded = sp.decode(token_ids)
|
| 264 |
print(f" Decoded: {decoded}")
|
| 265 |
-
|
| 266 |
# Verify round-trip encoding/decoding
|
| 267 |
if decoded.strip() != sentence.strip():
|
| 268 |
print(f" β οΈ Warning: Decode mismatch!")
|
| 269 |
-
|
| 270 |
print("β Tokenizer testing completed")
|
| 271 |
|
| 272 |
|
| 273 |
def save_huggingface_config(output_dir: str, config: Dict[str, Any]) -> None:
|
| 274 |
"""
|
| 275 |
Save a Hugging Face compatible tokenizer configuration file.
|
| 276 |
-
|
| 277 |
Args:
|
| 278 |
output_dir (str): Directory containing the tokenizer files
|
| 279 |
config (Dict[str, Any]): Tokenizer configuration
|
| 280 |
"""
|
| 281 |
# Create Hugging Face tokenizer config
|
| 282 |
hf_config = {
|
| 283 |
-
"tokenizer_class": "SentencePieceTokenizer",
|
| 284 |
"model_type": config["model_type"],
|
| 285 |
"vocab_size": config["vocab_size"],
|
| 286 |
"model_file": "tokenizer.model",
|
| 287 |
"special_tokens": {
|
| 288 |
"pad_token": "<pad>",
|
| 289 |
-
"unk_token": "<unk>",
|
| 290 |
"bos_token": "<s>",
|
| 291 |
"eos_token": "</s>",
|
| 292 |
},
|
| 293 |
"special_token_ids": {
|
| 294 |
"pad_token_id": 0,
|
| 295 |
"unk_token_id": 1,
|
| 296 |
-
"bos_token_id": 2,
|
| 297 |
"eos_token_id": 3,
|
| 298 |
-
}
|
| 299 |
}
|
| 300 |
-
|
| 301 |
config_path = os.path.join(output_dir, "tokenizer_config.json")
|
| 302 |
-
with open(config_path,
|
| 303 |
json.dump(hf_config, f, indent=2, ensure_ascii=False)
|
| 304 |
-
|
| 305 |
print(f"β Hugging Face config saved: {config_path}")
|
| 306 |
|
| 307 |
|
|
@@ -322,69 +324,67 @@ Examples:
|
|
| 322 |
--model_type bpe \\
|
| 323 |
--output_dir data/tokenizer/ \\
|
| 324 |
--character_coverage 0.9995
|
| 325 |
-
"""
|
| 326 |
)
|
| 327 |
-
|
| 328 |
# Required arguments
|
| 329 |
parser.add_argument(
|
| 330 |
-
"--input",
|
| 331 |
-
required=True,
|
| 332 |
-
help="Path to training text file (e.g., data/clean/training_data.txt)"
|
| 333 |
)
|
| 334 |
-
|
| 335 |
# Optional arguments with sensible defaults
|
| 336 |
parser.add_argument(
|
| 337 |
-
"--vocab_size",
|
| 338 |
-
type=int,
|
| 339 |
default=32000,
|
| 340 |
-
help="Vocabulary size (default: 32000, recommended: 8k-64k)"
|
| 341 |
)
|
| 342 |
-
|
| 343 |
parser.add_argument(
|
| 344 |
-
"--model_type",
|
| 345 |
-
choices=["bpe", "unigram"],
|
| 346 |
default="bpe",
|
| 347 |
-
help="Tokenization algorithm (default: bpe)"
|
| 348 |
)
|
| 349 |
-
|
| 350 |
parser.add_argument(
|
| 351 |
-
"--output_dir",
|
| 352 |
default="data/tokenizer/",
|
| 353 |
-
help="Output directory for tokenizer files (default: data/tokenizer/)"
|
| 354 |
)
|
| 355 |
-
|
| 356 |
parser.add_argument(
|
| 357 |
-
"--character_coverage",
|
| 358 |
-
type=float,
|
| 359 |
default=0.9995,
|
| 360 |
-
help="Character coverage (default: 0.9995 for English)"
|
| 361 |
)
|
| 362 |
-
|
| 363 |
parser.add_argument(
|
| 364 |
-
"--max_sentence_length",
|
| 365 |
-
type=int,
|
| 366 |
default=4192,
|
| 367 |
-
help="Maximum sentence length in characters (default: 4192)"
|
| 368 |
)
|
| 369 |
-
|
| 370 |
parser.add_argument(
|
| 371 |
-
"--no_test",
|
| 372 |
-
action="store_true",
|
| 373 |
-
help="Skip tokenizer testing after training"
|
| 374 |
)
|
| 375 |
-
|
| 376 |
args = parser.parse_args()
|
| 377 |
-
|
| 378 |
print("π€ SentencePiece Tokenizer Training")
|
| 379 |
print("=" * 50)
|
| 380 |
-
|
| 381 |
try:
|
| 382 |
# Step 1: Validate input file
|
| 383 |
validate_input_file(args.input)
|
| 384 |
-
|
| 385 |
# Step 2: Count training data
|
| 386 |
sentence_count = count_training_sentences(args.input)
|
| 387 |
-
|
| 388 |
# Step 3: Train tokenizer
|
| 389 |
config = train_sentencepiece_tokenizer(
|
| 390 |
input_path=args.input,
|
|
@@ -394,36 +394,36 @@ Examples:
|
|
| 394 |
character_coverage=args.character_coverage,
|
| 395 |
max_sentence_length=args.max_sentence_length,
|
| 396 |
)
|
| 397 |
-
|
| 398 |
# Step 4: Save Hugging Face compatible config
|
| 399 |
save_huggingface_config(args.output_dir, config)
|
| 400 |
-
|
| 401 |
# Step 5: Test tokenizer (unless skipped)
|
| 402 |
if not args.no_test:
|
| 403 |
model_path = os.path.join(args.output_dir, "tokenizer.model")
|
| 404 |
test_tokenizer(model_path)
|
| 405 |
-
|
| 406 |
# Step 6: Print summary
|
| 407 |
print(f"\nπ Tokenizer training completed successfully!")
|
| 408 |
print(f"π Output directory: {args.output_dir}")
|
| 409 |
print(f"π Vocabulary size: {config['vocab_size']:,}")
|
| 410 |
print(f"β±οΈ Training time: {config['training_time_seconds']:.1f}s")
|
| 411 |
print(f"π Training sentences: {sentence_count:,}")
|
| 412 |
-
|
| 413 |
print(f"\nFiles created:")
|
| 414 |
print(f" β’ {config['model_file']} - SentencePiece model")
|
| 415 |
-
print(f" β’ {config['vocab_file']} - Vocabulary file")
|
| 416 |
print(f" β’ {os.path.join(args.output_dir, 'tokenizer_config.json')} - Hugging Face config")
|
| 417 |
-
|
| 418 |
print(f"\nTo use this tokenizer in your language model:")
|
| 419 |
print(f" import sentencepiece as spm")
|
| 420 |
print(f" sp = spm.SentencePieceProcessor()")
|
| 421 |
print(f" sp.load('{config['model_file']}')")
|
| 422 |
-
|
| 423 |
except Exception as e:
|
| 424 |
print(f"\nβ Error: {e}")
|
| 425 |
exit(1)
|
| 426 |
|
| 427 |
|
| 428 |
if __name__ == "__main__":
|
| 429 |
-
main()
|
|
|
|
| 76 |
def validate_input_file(input_path: str) -> None:
|
| 77 |
"""
|
| 78 |
Validate that the input training file exists and is readable.
|
| 79 |
+
|
| 80 |
Args:
|
| 81 |
input_path (str): Path to the training text file
|
| 82 |
+
|
| 83 |
Raises:
|
| 84 |
FileNotFoundError: If input file doesn't exist
|
| 85 |
ValueError: If input file is empty or unreadable
|
| 86 |
"""
|
| 87 |
if not os.path.exists(input_path):
|
| 88 |
raise FileNotFoundError(f"Training data file not found: {input_path}")
|
| 89 |
+
|
| 90 |
# Check file size and readability
|
| 91 |
file_size = os.path.getsize(input_path)
|
| 92 |
if file_size == 0:
|
| 93 |
raise ValueError(f"Training data file is empty: {input_path}")
|
| 94 |
+
|
| 95 |
# Test that we can read the file
|
| 96 |
try:
|
| 97 |
+
with open(input_path, "r", encoding="utf-8") as f:
|
| 98 |
first_line = f.readline()
|
| 99 |
if not first_line.strip():
|
| 100 |
+
raise ValueError(
|
| 101 |
+
f"Training data file appears to be empty or contains only whitespace"
|
| 102 |
+
)
|
| 103 |
except UnicodeDecodeError as e:
|
| 104 |
raise ValueError(f"Cannot read training data file as UTF-8: {e}")
|
| 105 |
+
|
| 106 |
print(f"β Input file validated: {input_path} ({file_size:,} bytes)")
|
| 107 |
|
| 108 |
|
| 109 |
def count_training_sentences(input_path: str) -> int:
|
| 110 |
"""
|
| 111 |
Count the number of training sentences/lines in the input file.
|
| 112 |
+
|
| 113 |
Args:
|
| 114 |
input_path (str): Path to the training text file
|
| 115 |
+
|
| 116 |
Returns:
|
| 117 |
int: Number of lines in the file
|
| 118 |
"""
|
| 119 |
print("Counting training sentences...")
|
| 120 |
+
with open(input_path, "r", encoding="utf-8") as f:
|
| 121 |
count = sum(1 for line in f if line.strip())
|
| 122 |
print(f"β Found {count:,} training sentences")
|
| 123 |
return count
|
|
|
|
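Together these two helpers form the pre-flight check used in main() further below; a minimal usage sketch with an illustrative path:

path = "data/clean/training_data.txt"
validate_input_file(path)                     # raises if missing, empty, or not valid UTF-8
n_sentences = count_training_sentences(path)  # non-empty lines available for training
print(f"Ready to train a tokenizer on {n_sentences:,} sentences")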
| 135 |
) -> Dict[str, Any]:
|
| 136 |
"""
|
| 137 |
Train a SentencePiece tokenizer with the specified parameters.
|
| 138 |
+
|
| 139 |
Args:
|
| 140 |
input_path (str): Path to training text file
|
| 141 |
output_dir (str): Directory to save tokenizer files
|
|
|
|
| 145 |
max_sentence_length (int): Maximum sentence length in characters
|
| 146 |
input_sentence_size (int): Maximum number of sentences to use for training
|
| 147 |
shuffle_input_sentence (bool): Whether to shuffle input sentences
|
| 148 |
+
|
| 149 |
Returns:
|
| 150 |
Dict[str, Any]: Training statistics and configuration
|
| 151 |
"""
|
| 152 |
# Ensure output directory exists
|
| 153 |
os.makedirs(output_dir, exist_ok=True)
|
| 154 |
+
|
| 155 |
# Define output paths
|
| 156 |
model_prefix = os.path.join(output_dir, "tokenizer")
|
| 157 |
+
|
| 158 |
# SentencePiece training parameters
|
| 159 |
train_params = [
|
| 160 |
f"--input={input_path}",
|
|
|
|
| 165 |
f"--max_sentence_length={max_sentence_length}",
|
| 166 |
f"--input_sentence_size={input_sentence_size}",
|
| 167 |
f"--shuffle_input_sentence={shuffle_input_sentence}",
|
|
|
|
| 168 |
# Special tokens for language modeling
|
| 169 |
+
"--pad_id=0", # Padding token
|
| 170 |
+
"--unk_id=1", # Unknown token
|
| 171 |
+
"--bos_id=2", # Beginning of sequence
|
| 172 |
+
"--eos_id=3", # End of sequence
|
|
|
|
| 173 |
# Additional useful parameters
|
| 174 |
+
"--split_by_unicode_script=true", # Better handling of mixed scripts
|
| 175 |
+
"--split_by_whitespace=true", # Split on whitespace
|
| 176 |
+
"--remove_extra_whitespaces=true", # Clean up whitespace
|
| 177 |
+
"--normalization_rule_name=identity", # Keep original text as-is
|
| 178 |
]
|
| 179 |
+
|
| 180 |
print(f"\nTraining SentencePiece tokenizer...")
|
| 181 |
print(f" Algorithm: {model_type.upper()}")
|
| 182 |
print(f" Vocabulary size: {vocab_size:,}")
|
| 183 |
print(f" Character coverage: {character_coverage}")
|
| 184 |
print(f" Output directory: {output_dir}")
|
| 185 |
print(f" Model files: {model_prefix}.model, {model_prefix}.vocab")
|
| 186 |
+
|
| 187 |
# Record training start time
|
| 188 |
start_time = time.time()
|
| 189 |
+
|
| 190 |
# Train the tokenizer
|
| 191 |
try:
|
| 192 |
spm.SentencePieceTrainer.train(" ".join(train_params))
|
|
|
|
| 194 |
print(f"β Tokenizer training completed in {training_time:.1f} seconds")
|
| 195 |
except Exception as e:
|
| 196 |
raise RuntimeError(f"SentencePiece training failed: {e}")
|
| 197 |
+
|
| 198 |
# Verify output files were created
|
| 199 |
model_file = f"{model_prefix}.model"
|
| 200 |
vocab_file = f"{model_prefix}.vocab"
|
| 201 |
+
|
| 202 |
if not os.path.exists(model_file):
|
| 203 |
raise RuntimeError(f"Expected model file not created: {model_file}")
|
| 204 |
if not os.path.exists(vocab_file):
|
| 205 |
raise RuntimeError(f"Expected vocab file not created: {vocab_file}")
|
| 206 |
+
|
| 207 |
print(f"β Model file created: {model_file} ({os.path.getsize(model_file):,} bytes)")
|
| 208 |
print(f"β Vocab file created: {vocab_file} ({os.path.getsize(vocab_file):,} bytes)")
|
| 209 |
+
|
| 210 |
# Return training configuration and statistics
|
| 211 |
config = {
|
| 212 |
"model_type": model_type,
|
|
|
|
| 219 |
"model_file": model_file,
|
| 220 |
"vocab_file": vocab_file,
|
| 221 |
}
|
| 222 |
+
|
| 223 |
return config
|
| 224 |
|
| 225 |
|
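The parameter list above is joined into a single CLI-style string before being handed to SentencePiece. The same run can also be expressed with the trainer's keyword-argument form; a minimal sketch with illustrative paths and the defaults used in this script:

import sentencepiece as spm

spm.SentencePieceTrainer.train(
    input="data/clean/training_data.txt",
    model_prefix="data/tokenizer/tokenizer",
    vocab_size=32000,
    model_type="bpe",
    character_coverage=0.9995,
    pad_id=0, unk_id=1, bos_id=2, eos_id=3,   # same special-token ids as above
)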
| 226 |
def test_tokenizer(model_path: str, test_sentences: list = None) -> None:
|
| 227 |
"""
|
| 228 |
Test the trained tokenizer on sample sentences to verify it works correctly.
|
| 229 |
+
|
| 230 |
Args:
|
| 231 |
model_path (str): Path to the trained .model file
|
| 232 |
test_sentences (list): Optional list of test sentences
|
| 233 |
"""
|
| 234 |
print(f"\nTesting trained tokenizer...")
|
| 235 |
+
|
| 236 |
# Load the trained tokenizer
|
| 237 |
sp = spm.SentencePieceProcessor()
|
| 238 |
sp.load(model_path)
|
| 239 |
+
|
| 240 |
# Default test sentences if none provided
|
| 241 |
if test_sentences is None:
|
| 242 |
test_sentences = [
|
|
|
|
| 245 |
"Machine learning and artificial intelligence are transforming technology.",
|
| 246 |
"SentencePiece tokenization works well for language models.",
|
| 247 |
]
|
| 248 |
+
|
| 249 |
print(f"Vocabulary size: {sp.vocab_size():,}")
|
| 250 |
+
print(
|
| 251 |
+
f"Special tokens: PAD={sp.pad_id()}, UNK={sp.unk_id()}, BOS={sp.bos_id()}, EOS={sp.eos_id()}"
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
print("\nTokenization examples:")
|
| 255 |
for i, sentence in enumerate(test_sentences, 1):
|
| 256 |
# Encode to token IDs and pieces
|
| 257 |
token_ids = sp.encode(sentence)
|
| 258 |
token_pieces = sp.encode(sentence, out_type=str)
|
| 259 |
+
|
| 260 |
print(f"\n{i}. Input: {sentence}")
|
| 261 |
print(f" Tokens ({len(token_pieces)}): {token_pieces}")
|
| 262 |
print(f" IDs: {token_ids[:10]}{'...' if len(token_ids) > 10 else ''}")
|
| 263 |
+
|
| 264 |
# Test decoding
|
| 265 |
decoded = sp.decode(token_ids)
|
| 266 |
print(f" Decoded: {decoded}")
|
| 267 |
+
|
| 268 |
# Verify round-trip encoding/decoding
|
| 269 |
if decoded.strip() != sentence.strip():
|
| 270 |
print(f" β οΈ Warning: Decode mismatch!")
|
| 271 |
+
|
| 272 |
print("β Tokenizer testing completed")
|
| 273 |
|
| 274 |
|
| 275 |
def save_huggingface_config(output_dir: str, config: Dict[str, Any]) -> None:
|
| 276 |
"""
|
| 277 |
Save a Hugging Face compatible tokenizer configuration file.
|
| 278 |
+
|
| 279 |
Args:
|
| 280 |
output_dir (str): Directory containing the tokenizer files
|
| 281 |
config (Dict[str, Any]): Tokenizer configuration
|
| 282 |
"""
|
| 283 |
# Create Hugging Face tokenizer config
|
| 284 |
hf_config = {
|
| 285 |
+
"tokenizer_class": "SentencePieceTokenizer",
|
| 286 |
"model_type": config["model_type"],
|
| 287 |
"vocab_size": config["vocab_size"],
|
| 288 |
"model_file": "tokenizer.model",
|
| 289 |
"special_tokens": {
|
| 290 |
"pad_token": "<pad>",
|
| 291 |
+
"unk_token": "<unk>",
|
| 292 |
"bos_token": "<s>",
|
| 293 |
"eos_token": "</s>",
|
| 294 |
},
|
| 295 |
"special_token_ids": {
|
| 296 |
"pad_token_id": 0,
|
| 297 |
"unk_token_id": 1,
|
| 298 |
+
"bos_token_id": 2,
|
| 299 |
"eos_token_id": 3,
|
| 300 |
+
},
|
| 301 |
}
|
| 302 |
+
|
| 303 |
config_path = os.path.join(output_dir, "tokenizer_config.json")
|
| 304 |
+
with open(config_path, "w", encoding="utf-8") as f:
|
| 305 |
json.dump(hf_config, f, indent=2, ensure_ascii=False)
|
| 306 |
+
|
| 307 |
print(f"β Hugging Face config saved: {config_path}")
|
| 308 |
|
| 309 |
|
|
|
|
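A quick way to sanity-check the emitted files is to load the JSON config back alongside the matching SentencePiece model; a small sketch assuming the default data/tokenizer/ output directory:

import json
import os
import sentencepiece as spm

tokenizer_dir = "data/tokenizer"
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), encoding="utf-8") as f:
    cfg = json.load(f)

sp = spm.SentencePieceProcessor()
sp.load(os.path.join(tokenizer_dir, cfg["model_file"]))

# The ids baked into the trained model should match the config written above.
assert sp.pad_id() == cfg["special_token_ids"]["pad_token_id"]
assert sp.eos_id() == cfg["special_token_ids"]["eos_token_id"]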
| 324 |
--model_type bpe \\
|
| 325 |
--output_dir data/tokenizer/ \\
|
| 326 |
--character_coverage 0.9995
|
| 327 |
+
""",
|
| 328 |
)
|
| 329 |
+
|
| 330 |
# Required arguments
|
| 331 |
parser.add_argument(
|
| 332 |
+
"--input",
|
| 333 |
+
required=True,
|
| 334 |
+
help="Path to training text file (e.g., data/clean/training_data.txt)",
|
| 335 |
)
|
| 336 |
+
|
| 337 |
# Optional arguments with sensible defaults
|
| 338 |
parser.add_argument(
|
| 339 |
+
"--vocab_size",
|
| 340 |
+
type=int,
|
| 341 |
default=32000,
|
| 342 |
+
help="Vocabulary size (default: 32000, recommended: 8k-64k)",
|
| 343 |
)
|
| 344 |
+
|
| 345 |
parser.add_argument(
|
| 346 |
+
"--model_type",
|
| 347 |
+
choices=["bpe", "unigram"],
|
| 348 |
default="bpe",
|
| 349 |
+
help="Tokenization algorithm (default: bpe)",
|
| 350 |
)
|
| 351 |
+
|
| 352 |
parser.add_argument(
|
| 353 |
+
"--output_dir",
|
| 354 |
default="data/tokenizer/",
|
| 355 |
+
help="Output directory for tokenizer files (default: data/tokenizer/)",
|
| 356 |
)
|
| 357 |
+
|
| 358 |
parser.add_argument(
|
| 359 |
+
"--character_coverage",
|
| 360 |
+
type=float,
|
| 361 |
default=0.9995,
|
| 362 |
+
help="Character coverage (default: 0.9995 for English)",
|
| 363 |
)
|
| 364 |
+
|
| 365 |
parser.add_argument(
|
| 366 |
+
"--max_sentence_length",
|
| 367 |
+
type=int,
|
| 368 |
default=4192,
|
| 369 |
+
help="Maximum sentence length in characters (default: 4192)",
|
| 370 |
)
|
| 371 |
+
|
| 372 |
parser.add_argument(
|
| 373 |
+
"--no_test", action="store_true", help="Skip tokenizer testing after training"
|
| 374 |
)
|
| 375 |
+
|
| 376 |
args = parser.parse_args()
|
| 377 |
+
|
| 378 |
print("π€ SentencePiece Tokenizer Training")
|
| 379 |
print("=" * 50)
|
| 380 |
+
|
| 381 |
try:
|
| 382 |
# Step 1: Validate input file
|
| 383 |
validate_input_file(args.input)
|
| 384 |
+
|
| 385 |
# Step 2: Count training data
|
| 386 |
sentence_count = count_training_sentences(args.input)
|
| 387 |
+
|
| 388 |
# Step 3: Train tokenizer
|
| 389 |
config = train_sentencepiece_tokenizer(
|
| 390 |
input_path=args.input,
|
|
|
|
| 394 |
character_coverage=args.character_coverage,
|
| 395 |
max_sentence_length=args.max_sentence_length,
|
| 396 |
)
|
| 397 |
+
|
| 398 |
# Step 4: Save Hugging Face compatible config
|
| 399 |
save_huggingface_config(args.output_dir, config)
|
| 400 |
+
|
| 401 |
# Step 5: Test tokenizer (unless skipped)
|
| 402 |
if not args.no_test:
|
| 403 |
model_path = os.path.join(args.output_dir, "tokenizer.model")
|
| 404 |
test_tokenizer(model_path)
|
| 405 |
+
|
| 406 |
# Step 6: Print summary
|
| 407 |
print(f"\nπ Tokenizer training completed successfully!")
|
| 408 |
print(f"π Output directory: {args.output_dir}")
|
| 409 |
print(f"π Vocabulary size: {config['vocab_size']:,}")
|
| 410 |
print(f"β±οΈ Training time: {config['training_time_seconds']:.1f}s")
|
| 411 |
print(f"π Training sentences: {sentence_count:,}")
|
| 412 |
+
|
| 413 |
print(f"\nFiles created:")
|
| 414 |
print(f" β’ {config['model_file']} - SentencePiece model")
|
| 415 |
+
print(f" β’ {config['vocab_file']} - Vocabulary file")
|
| 416 |
print(f" β’ {os.path.join(args.output_dir, 'tokenizer_config.json')} - Hugging Face config")
|
| 417 |
+
|
| 418 |
print(f"\nTo use this tokenizer in your language model:")
|
| 419 |
print(f" import sentencepiece as spm")
|
| 420 |
print(f" sp = spm.SentencePieceProcessor()")
|
| 421 |
print(f" sp.load('{config['model_file']}')")
|
| 422 |
+
|
| 423 |
except Exception as e:
|
| 424 |
print(f"\nβ Error: {e}")
|
| 425 |
exit(1)
|
| 426 |
|
| 427 |
|
| 428 |
if __name__ == "__main__":
|
| 429 |
+
main()
|
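Following the usage hint the script prints at the end, a short example of feeding the trained tokenizer into a language-modeling pipeline (sentence and paths are illustrative):

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/tokenizer/tokenizer.model")

ids = sp.encode("SentencePiece tokenization works well for language models.")
framed = [sp.bos_id()] + ids + [sp.eos_id()]   # framing used for LM training
text = sp.decode(ids)                          # round-trips back to the sentence
print(len(framed), text)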