lemms committed on
Commit
6c43565
·
verified ·
1 Parent(s): 483881c

feat: Sync training infrastructure from main repository

app.py CHANGED
@@ -1,223 +1,1024 @@
  #!/usr/bin/env python3
  """
- OpenLLM Training Space - Main Application

- This is the main entry point for the Hugging Face Space.
- It provides a web interface for running OpenLLM training with authentication.

  Author: Louis Chua Bean Chong
- License: GPLv3
  """

- import os
- import sys
  import gradio as gr
  from pathlib import Path

- # Import our authentication and training modules
  try:
-     from space_auth_test import test_space_authentication
-     from openllm_training_with_auth import OpenLLMTrainingManager
-     MODULES_AVAILABLE = True
  except ImportError as e:
-     MODULES_AVAILABLE = False
-     print(f"❌ Required modules not available: {e}")


- def create_space_interface():
-     """Create the Gradio interface for the Space."""

-     def run_authentication_test():
-         """Run the authentication test and return results."""
          try:
-             if not MODULES_AVAILABLE:
-                 return "❌ Required modules not available. Please check deployment."

-             # Capture output from authentication test
-             import io
-             import contextlib

-             output = io.StringIO()
-             with contextlib.redirect_stdout(output):
-                 success = test_space_authentication()

-             result = output.getvalue()

-             if success:
-                 return f"✅ Authentication Test Results:\n\n{result}"
-             else:
-                 return f"❌ Authentication Test Failed:\n\n{result}"

          except Exception as e:
-             return f"❌ Error running authentication test: {e}"

-     def run_training(model_size, training_steps):
-         """Run the OpenLLM training with authentication."""
          try:
-             if not MODULES_AVAILABLE:
-                 return "❌ Required modules not available. Please check deployment."

-             # Capture output from training
-             import io
-             import contextlib

-             output = io.StringIO()
-             with contextlib.redirect_stdout(output):
-                 training_manager = OpenLLMTrainingManager()
-                 repo_id = training_manager.run_training(
-                     model_size=model_size,
-                     steps=int(training_steps)
                  )

-             result = output.getvalue()

-             return f"✅ Training Results:\n\n{result}\n\n🎉 Model available at: https://huggingface.co/{repo_id}"

          except Exception as e:
-             return f"❌ Error running training: {e}"

-     def check_space_environment():
-         """Check the Space environment and configuration."""
          try:
-             # Check if we're in a Space
-             space_vars = ["SPACE_ID", "SPACE_HOST", "SPACE_REPO_ID"]
-             is_space = any(os.getenv(var) for var in space_vars)

-             # Check HF_TOKEN
-             hf_token = os.getenv("HF_TOKEN")

-             result = "🔍 Space Environment Check:\n\n"

-             if is_space:
-                 result += "✅ Running in Hugging Face Space environment\n"
-                 for var in space_vars:
-                     value = os.getenv(var)
-                     if value:
-                         result += f" - {var}: {value}\n"
-             else:
-                 result += "ℹ️ Running in local environment\n"

-             if hf_token:
-                 result += f"✅ HF access token found: {hf_token[:8]}...{hf_token[-4:]}\n"
-                 result += " - Source: HF access token in Space settings\n"
-             else:
-                 result += "❌ HF access token not found\n"
-                 result += " - Please set HF_TOKEN in Space settings with HF access token\n"

-             result += f"\n📁 Available modules: {'✅' if MODULES_AVAILABLE else '❌'}"

-             return result

          except Exception as e:
-             return f"❌ Error checking environment: {e}"

-     # Create the Gradio interface
-     with gr.Blocks(title="OpenLLM Training Space", theme=gr.themes.Soft()) as interface:
-         gr.Markdown("""
-         # 🚀 OpenLLM Training Space

-         Welcome to the OpenLLM Training Space! This Space provides a complete environment for training OpenLLM models with automatic Hugging Face authentication and model upload.

-         ## 🔐 Authentication
-
-         This Space uses an HF access token for secure authentication. The HF_TOKEN is automatically available from your Space settings.

-         ## 📋 Available Actions

-         1. **Environment Check**: Verify Space configuration and authentication
-         2. **Authentication Test**: Test Hugging Face authentication
-         3. **Run Training**: Start OpenLLM training with automatic upload
-         """)

-         with gr.Tab("🔍 Environment Check"):
-             gr.Markdown("Check the Space environment and configuration.")
-             env_check_btn = gr.Button("Check Environment", variant="primary")
-             env_output = gr.Textbox(label="Environment Status", lines=10, interactive=False)
-             env_check_btn.click(check_space_environment, outputs=env_output)
-
-         with gr.Tab("🔐 Authentication Test"):
-             gr.Markdown("Test Hugging Face authentication using an HF access token.")
-             auth_test_btn = gr.Button("Run Authentication Test", variant="primary")
-             auth_output = gr.Textbox(label="Authentication Results", lines=15, interactive=False)
-             auth_test_btn.click(run_authentication_test, outputs=auth_output)
-
-         with gr.Tab("🚀 Run Training"):
-             gr.Markdown("""
-             Start OpenLLM training with automatic model upload.
-
-             **Training Parameters:**
-             - **Model Size**: Choose the model size (small, medium, large)
-             - **Training Steps**: Number of training steps (default: 8000)
-
-             **Expected Results:**
-             - Training will complete successfully
-             - Model will be uploaded to Hugging Face Hub
-             - Repository will be created with proper model files
-             """)
-
-             with gr.Row():
                  model_size = gr.Dropdown(
                      choices=["small", "medium", "large"],
                      value="small",
                      label="Model Size",
-                     info="Choose the model size for training"
                  )
-                 training_steps = gr.Number(
-                     value=8000,
-                     label="Training Steps",
-                     info="Number of training steps",
-                     minimum=1000,
-                     maximum=50000
                  )

-             train_btn = gr.Button("Start Training", variant="primary", size="lg")
-             train_output = gr.Textbox(label="Training Results", lines=20, interactive=False)
-
-             train_btn.click(
-                 run_training,
-                 inputs=[model_size, training_steps],
-                 outputs=train_output
-             )

-         with gr.Tab("📚 Documentation"):
-             gr.Markdown("""
-             ## 📖 Available Documentation
-
-             - **HUGGINGFACE_SPACE_SETUP_GUIDE.md**: Complete setup guide
-             - **SPACE_AUTHENTICATION_SUMMARY.md**: Authentication summary
-             - **SPACE_READY_SUMMARY.md**: Deployment summary

-             ## 🔧 Available Scripts

-             - **space_auth_test.py**: Authentication verification
-             - **openllm_training_with_auth.py**: Complete training script
-             - **integrate_auth_into_training.py**: Integration guide
-             - **setup_hf_space_auth.py**: Space authentication setup
-             - **verify_space_auth.py**: Space verification script

-             ## 🎯 Quick Start

-             1. Check the environment to verify configuration
-             2. Run the authentication test to ensure GitHub secrets are working
-             3. Start training with your desired parameters
-             4. Monitor the training progress and model upload

-             ## 🔒 Security

-             - HF_TOKEN is securely stored in GitHub repository secrets
-             - No hardcoded tokens in any scripts
-             - Automatic cleanup of test repositories
-             - Proper error handling and logging
-             """)

-     return interface
-

  if __name__ == "__main__":
-     # Create and launch the interface
-     interface = create_space_interface()
-     interface.launch(
-         server_name="0.0.0.0",
-         server_port=7860,
-         share=False
-     )
  #!/usr/bin/env python3
  """
+ OpenLLM Training Space Application - Fixed with Uploaded Modules

+ This version imports OpenLLM modules from the uploaded files in the HF Space:
+ - Imports the model.py and data_loader.py that were uploaded to the Space
+ - Uses OpenLLM's actual custom model architecture
+ - Compatible with OpenLLM's implementation
+
+ This application provides a complete training interface for OpenLLM models on Hugging Face Spaces.
+ It uses OpenLLM's custom GPTModel architecture instead of Hugging Face Transformers,
+ ensuring compatibility with the actual OpenLLM implementation.
+
+ Key Features:
+ - Real model training using OpenLLM's custom architecture
+ - SentencePiece tokenization for OpenLLM models
+ - Complete training pipeline with progress monitoring
+ - Automatic model saving and uploading to Hugging Face Hub
+ - Gradio 4.44.1 compatible user interface
+
+ Technical Architecture:
+ - Uses OpenLLM's GPTModel class (not Hugging Face Transformers)
+ - Imports custom modules from uploaded files in the Space
+ - Uses sentencepiece.SentencePieceProcessor() for tokenization
+ - Implements OpenLLM's training loop and optimization strategy
+ - Saves checkpoints in OpenLLM's format

  Author: Louis Chua Bean Chong
+ License: GPL-3.0
+ Version: 2.1.1
+ Last Updated: 2024
  """

  import gradio as gr
+ import torch
+ import torch.nn as nn
+ import os
+ import time
+ import math
+ import gc
+ from typing import Dict, Any, Optional
+ import threading
+ from dataclasses import dataclass
  from pathlib import Path

+ # Import OpenLLM's custom model architecture from uploaded files
+ # These files were uploaded to the HF Space and contain OpenLLM's actual implementation
+ try:
+     # Import from the uploaded files in the HF Space
+     # model.py contains GPTModel, GPTConfig, and the create_model factory function
+     from model import GPTModel, GPTConfig, create_model
+     # data_loader.py contains TextDataLoader for OpenLLM's data loading approach
+     from data_loader import TextDataLoader
+     OPENLLM_AVAILABLE = True
+     print("✅ OpenLLM custom model architecture imported successfully from uploaded files")
+     print("   - GPTModel: Custom PyTorch model architecture")
+     print("   - GPTConfig: Model configuration dataclass")
+     print("   - create_model: Factory function for model creation")
+     print("   - TextDataLoader: Custom data loading implementation")
+ except ImportError as e:
+     print(f"❌ OpenLLM imports failed: {e}")
+     print("   This indicates the uploaded OpenLLM source files are not available")
+     print("   The training functionality will be disabled")
+     OPENLLM_AVAILABLE = False
+
+ # Try to import sentencepiece - CRITICAL for OpenLLM tokenization
+ # OpenLLM uses SentencePiece for tokenization, not Hugging Face tokenizers
+ try:
+     import sentencepiece as spm
+     SENTENCEPIECE_AVAILABLE = True
+     print(f"✅ SentencePiece available: {spm.__version__}")
+     print("   - Required for OpenLLM tokenization")
+     print("   - Used for loading tokenizer.model files")
+ except ImportError:
+     SENTENCEPIECE_AVAILABLE = False
+     print("❌ SentencePiece not available")
+     print("   - This will prevent tokenizer loading")
+     print("   - Training functionality will be limited")
+
+ # Import other dependencies for the complete training pipeline
  try:
+     from datasets import load_dataset  # For loading training data from the HF Hub
+     from huggingface_hub import HfApi, hf_hub_download  # For model uploads and downloads
+     DEPENDENCIES_AVAILABLE = True
+     print("✅ Training dependencies available")
+     print("   - datasets: For loading training data")
+     print("   - huggingface_hub: For model uploads/downloads")
  except ImportError as e:
+     print(f"❌ Dependencies not available: {e}")
+     print("   - This will prevent dataset loading and model uploading")
+     DEPENDENCIES_AVAILABLE = False

+ @dataclass
+ class TrainingConfig:
+     """
+     Configuration class for training parameters.
+
+     This dataclass encapsulates all the training hyperparameters and settings
+     that control the OpenLLM training process. It provides a clean interface
+     for passing configuration between different components of the training pipeline.
+
+     Attributes:
+         model_size: Size of the model to train ("small", "medium", "large")
+         max_steps: Maximum number of training iterations
+         learning_rate: Learning rate for the optimizer
+         batch_size: Number of samples per training batch
+         output_dir: Directory to save trained models and checkpoints
+         save_steps: Frequency of checkpoint saving (every N steps)
+         logging_steps: Frequency of progress logging (every N steps)
+         warmup_steps: Number of warmup steps for learning rate scheduling
+         gradient_accumulation_steps: Number of steps to accumulate gradients
+     """
+     model_size: str
+     max_steps: int
+     learning_rate: float
+     batch_size: int
+     output_dir: str = "./openllm-trained"
+     save_steps: int = 100
+     logging_steps: int = 10
+     warmup_steps: int = 50
+     gradient_accumulation_steps: int = 4
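
For orientation, a minimal sketch of how this dataclass is consumed by the pipeline below; the values mirror the UI defaults later in this file (1000 steps, 3e-4, batch size 4) and are illustrative, not part of the commit:

    config = TrainingConfig(
        model_size="small",   # one of "small", "medium", "large"
        max_steps=1000,
        learning_rate=3e-4,
        batch_size=4,
    )
    # With the defaults above, each optimizer step effectively sees
    # batch_size * gradient_accumulation_steps = 4 * 4 = 16 samples.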
 
+ class OpenLLMTrainer:
+     """
+     Complete training implementation using OpenLLM's actual architecture.
+
+     This class handles the entire training pipeline including:
+     - Model loading using OpenLLM's custom GPTModel
+     - Tokenizer loading using sentencepiece.SentencePieceProcessor()
+     - Dataset preparation using OpenLLM's TextDataLoader
+     - Training execution using OpenLLM's approach
+     - Model saving and uploading to Hugging Face Hub
+
+     The trainer implements OpenLLM's actual training methodology rather than
+     using Hugging Face Transformers, ensuring compatibility with the real
+     OpenLLM implementation.
+
+     Key Features:
+     - Custom model architecture (GPTModel, not PreTrainedModel)
+     - SentencePiece tokenization (not Hugging Face tokenizers)
+     - OpenLLM's training loop and optimization strategy
+     - Gradient accumulation for memory efficiency
+     - Learning rate scheduling with warmup
+     - Automatic checkpoint saving and model uploading
+     """

+     def __init__(self):
+         """
+         Initialize the trainer with default settings.
+
+         Sets up the trainer with default values and initializes the Hugging Face
+         API for model uploading. All components start as None and are initialized
+         during the training process.
+         """
+         # Core training components - initialized during training
+         self.model = None           # OpenLLM's GPTModel instance
+         self.tokenizer = None       # SentencePieceProcessor instance
+         self.data_loader = None     # OpenLLM's TextDataLoader instance
+         self.optimizer = None       # PyTorch optimizer (AdamW)
+         self.scheduler = None       # Learning rate scheduler
+
+         # Training state management
+         self.is_training = False    # Flag to track training status
+         self.tokenizer_path = None  # Path to the tokenizer.model file
+
+         # Progress tracking for UI updates
+         self.training_progress = {
+             "status": "Ready",      # Current training status
+             "current_step": 0,      # Current training step
+             "total_steps": 0,       # Total steps to complete
+             "loss": 0.0,            # Current training loss
+             "learning_rate": 0.0    # Current learning rate
+         }
+
+         # Initialize Hugging Face API for model uploading
+         # This allows the trained model to be automatically uploaded to HF Hub
          try:
+             self.hf_api = HfApi()
+             print("✅ Hugging Face API initialized for model uploading")
+         except Exception as e:
+             print(f"Failed to initialize HF API: {e}")
+             print("   - Model uploading will be disabled")
+             self.hf_api = None
+
+     def load_model_and_tokenizer(self, model_size: str) -> str:
+         """
+         Load the pre-trained OpenLLM model and tokenizer using OpenLLM's approach.
+
+         This method implements OpenLLM's actual model loading strategy:
+         1. Creates a new GPTModel using OpenLLM's factory function
+         2. Downloads the tokenizer.model file from Hugging Face Hub
+         3. Loads the tokenizer using SentencePieceProcessor
+         4. Stores both components for use in training
+
+         This approach differs from Hugging Face Transformers because:
+         - Uses OpenLLM's custom GPTModel (not AutoModelForCausalLM)
+         - Uses SentencePiece directly (not AutoTokenizer)
+         - Downloads specific files rather than using from_pretrained()
+
+         Args:
+             model_size: Size of the model to load ("small", "medium", "large")
+                 Determines which pre-trained model to download

+         Returns:
+             Status message indicating success or failure
+             Success: "✅ Successfully loaded OpenLLM {model_size} model with custom architecture"
+             Failure: "❌ Failed to load OpenLLM model and tokenizer: {error details}"
+         """
+         try:
+             # Verify OpenLLM modules are available
+             if not OPENLLM_AVAILABLE:
+                 return "❌ OpenLLM custom model architecture not available"

+             print(f"🔄 Loading OpenLLM {model_size} model using custom architecture...")
+             print(f"   - Using OpenLLM's create_model factory function")
+             print(f"   - Not using Hugging Face Transformers")

+             # Step 1: Create model using OpenLLM's factory function
+             # This creates a fresh GPTModel instance with the specified size
+             try:
+                 self.model = create_model(model_size)
+                 print(f"✅ OpenLLM {model_size} model created: {type(self.model).__name__}")
+                 print(f"   - Model type: {type(self.model).__name__}")
+                 print(f"   - Parameters: {self.model.get_num_params():,}")
+                 print(f"   - Architecture: Custom GPTModel (not PreTrainedModel)")
+             except Exception as e:
+                 print(f"❌ Failed to create model: {e}")
+                 return f"❌ Failed to create OpenLLM model: {str(e)}"

+             # Step 2: Load tokenizer using sentencepiece
+             # OpenLLM uses SentencePiece directly, not Hugging Face tokenizers
+             try:
+                 print("🔄 Loading tokenizer using sentencepiece.SentencePieceProcessor()...")
+                 print("   - Using SentencePiece directly (not AutoTokenizer)")
+                 print("   - Downloading tokenizer.model from Hugging Face Hub")
+
+                 # Download tokenizer.model from HF Hub
+                 # This is the actual tokenizer file used by OpenLLM models
+                 model_name = f"lemms/openllm-{model_size}-extended-7k"
+                 tokenizer_path = hf_hub_download(
+                     repo_id=model_name,
+                     filename="tokenizer.model"  # Specific file name for OpenLLM
+                 )
+
+                 print(f"✅ Tokenizer downloaded to: {tokenizer_path}")
+                 print(f"   - Source: {model_name}")
+                 print(f"   - File: tokenizer.model")
+
+                 # Create SentencePieceProcessor and load the tokenizer
+                 # This is OpenLLM's actual tokenization approach
+                 sp_processor = spm.SentencePieceProcessor()
+                 sp_processor.load(tokenizer_path)

+                 # Store tokenizer and its path separately
+                 # We need the path for the TextDataLoader later
+                 self.tokenizer = sp_processor
+                 self.tokenizer_path = tokenizer_path  # Store the path separately
+
+                 print(f"✅ Tokenizer loaded successfully using SentencePieceProcessor")
+                 print(f"   - Vocabulary size: {sp_processor.vocab_size()}")
+                 print(f"   - Tokenizer path: {tokenizer_path}")
+                 print(f"   - Tokenizer type: {type(sp_processor).__name__}")
+
+             except Exception as e:
+                 print(f"❌ Failed to load tokenizer: {e}")
+                 return f"❌ Failed to load OpenLLM tokenizer: {str(e)}"
+
+             return f"✅ Successfully loaded OpenLLM {model_size} model with custom architecture"
+
          except Exception as e:
+             return f"❌ Failed to load OpenLLM model and tokenizer: {str(e)}"
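
As a quick sanity check on the tokenizer loading above, a hedged round-trip sketch using the standard SentencePiece API (load, encode, decode, and vocab_size are real SentencePieceProcessor methods; the sample text is made up):

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.load("tokenizer.model")  # the path hf_hub_download returned above

    ids = sp.encode("OpenLLM is a small open language model.")  # list of subword IDs
    print(sp.vocab_size())   # should match the model's embedding table size
    print(sp.decode(ids))    # round-trips back to (approximately) the input string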
 
+     def prepare_dataset(self) -> str:
+         """
+         Load and prepare the training dataset using OpenLLM's approach.
+
+         This method implements OpenLLM's data preparation strategy:
+         1. Loads training data from Hugging Face Hub dataset
+         2. Creates a temporary text file for OpenLLM's TextDataLoader
+         3. Initializes OpenLLM's TextDataLoader with the tokenizer
+         4. Prepares the data for training
+
+         OpenLLM's approach differs from Hugging Face because:
+         - Uses a simple text file format (not tokenized datasets)
+         - Uses OpenLLM's TextDataLoader (not Hugging Face datasets)
+         - Tokenization happens on-the-fly during training
+
+         Returns:
+             Status message indicating success or failure
+             Success: "✅ Successfully prepared dataset with {count} samples"
+             Failure: "❌ Failed to prepare dataset: {error details}"
+         """
          try:
+             # Verify dependencies are available
+             if not DEPENDENCIES_AVAILABLE:
+                 return "❌ Required dependencies not available"

+             print("🔄 Loading training dataset...")
+             print("   - Loading from Hugging Face Hub dataset")
+             print("   - Using OpenLLM's data preparation approach")

+             # Load dataset from HF Hub
+             # This contains the training text data for continuing model training
+             dataset = load_dataset("lemms/openllm-training-data")
+             print(f"✅ Dataset loaded: {len(dataset['train'])} samples")
+             print(f"   - Dataset: lemms/openllm-training-data")
+             print(f"   - Samples: {len(dataset['train'])}")
+
+             # Create temporary data file for OpenLLM's TextDataLoader
+             # OpenLLM expects a simple text file with one text sample per line
+             temp_data_file = "temp_training_data.txt"
+             with open(temp_data_file, 'w', encoding='utf-8') as f:
+                 for item in dataset['train']:
+                     f.write(item['text'] + '\n')
+
+             print(f"✅ Temporary data file created: {temp_data_file}")
+             print(f"   - Format: One text sample per line")
+             print(f"   - Encoding: UTF-8")
+
+             # Create OpenLLM's TextDataLoader
+             # This is OpenLLM's custom data loading implementation
+             try:
+                 # Use the stored tokenizer path instead of trying to access model_file_path
+                 # SentencePieceProcessor doesn't have a model_file_path attribute
+                 tokenizer_path = self.tokenizer_path  # Use the stored path
+
+                 print(f"🔄 Creating OpenLLM TextDataLoader...")
+                 print(f"   - Data file: {temp_data_file}")
+                 print(f"   - Tokenizer path: {tokenizer_path}")
+                 print(f"   - Sequence length: 512")
+                 print(f"   - Batch size: 4 (will be overridden by training config)")
+
+                 self.data_loader = TextDataLoader(
+                     data_file=temp_data_file,
+                     tokenizer_path=tokenizer_path,
+                     seq_len=512,    # Maximum sequence length for training
+                     batch_size=4,   # Will be overridden by training config
+                     shuffle=True    # Shuffle data for better training
                  )
+
+                 print(f"✅ OpenLLM TextDataLoader created successfully")
+                 print(f"   - DataLoader type: {type(self.data_loader).__name__}")
+                 print(f"   - Uses OpenLLM's custom implementation")
+
+             except Exception as e:
+                 print(f"❌ Failed to create TextDataLoader: {e}")
+                 return f"❌ Failed to create data loader: {str(e)}"
+
+             return f"✅ Successfully prepared dataset with {len(dataset['train'])} samples"
+
+         except Exception as e:
+             return f"❌ Failed to prepare dataset: {str(e)}"
+
+     def setup_training(self, config: TrainingConfig) -> str:
+         """
+         Set up the training configuration using OpenLLM's approach.
+
+         This method configures the training environment with:
+         1. Output directory creation
+         2. Optimizer setup with weight decay groups
+         3. Learning rate scheduler with warmup
+         4. Training hyperparameters
+
+         The setup follows OpenLLM's training methodology:
+         - Uses AdamW optimizer with weight decay
+         - Implements learning rate warmup followed by cosine annealing
+         - Separates parameters for different weight decay rates
+         - Uses gradient clipping for stability
+
+         Args:
+             config: Training configuration object containing all hyperparameters
+
+         Returns:
+             Status message indicating success or failure
+             Success: "✅ Training setup completed successfully"
+             Failure: "❌ Failed to setup training: {error details}"
+         """
+         try:
+             print("🔄 Setting up training configuration...")
+             print(f"   - Output directory: {config.output_dir}")
+             print(f"   - Learning rate: {config.learning_rate}")
+             print(f"   - Max steps: {config.max_steps}")
+
+             # Create output directory for saving models and checkpoints
+             os.makedirs(config.output_dir, exist_ok=True)
+             print(f"✅ Output directory created: {config.output_dir}")

+             # Set up optimizer (AdamW with weight decay)
+             # This follows OpenLLM's optimization strategy
+             print("🔄 Setting up AdamW optimizer with weight decay...")

+             # Separate parameters for different weight decay rates
+             # This is a common practice for transformer training
+             decay_params = []     # Parameters that should have weight decay
+             no_decay_params = []  # Parameters that should not have weight decay
+
+             for name, param in self.model.named_parameters():
+                 if not param.requires_grad:
+                     continue

+                 # Apply weight decay to all parameters except biases and layer norm weights
+                 if len(param.shape) == 1 or name.endswith('.bias'):
+                     no_decay_params.append(param)
+                 else:
+                     decay_params.append(param)
+
+             # Create parameter groups with different weight decay rates
+             param_groups = [
+                 {'params': decay_params, 'weight_decay': 0.01},   # 1% weight decay
+                 {'params': no_decay_params, 'weight_decay': 0.0}  # No weight decay
+             ]
+
+             print(f"   - Decay parameters: {len(decay_params)}")
+             print(f"   - No-decay parameters: {len(no_decay_params)}")
+
+             # Initialize AdamW optimizer with OpenLLM's recommended settings
+             self.optimizer = torch.optim.AdamW(
+                 param_groups,
+                 lr=config.learning_rate,
+                 betas=(0.9, 0.95),  # Beta values for momentum
+                 eps=1e-8            # Epsilon for numerical stability
+             )
+
+             print(f"✅ AdamW optimizer configured")
+             print(f"   - Learning rate: {config.learning_rate}")
+             print(f"   - Betas: (0.9, 0.95)")
+             print(f"   - Epsilon: 1e-8")
+
+             # Set up learning rate scheduler
+             # OpenLLM uses a warmup followed by cosine annealing
+             print("🔄 Setting up learning rate scheduler...")
+
+             # Warmup scheduler: linearly increase LR from 1% to 100%
+             warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
+                 self.optimizer,
+                 start_factor=0.01,  # Start at 1% of target LR
+                 end_factor=1.0,     # End at 100% of target LR
+                 total_iters=config.warmup_steps
+             )
+
+             # Main scheduler: cosine annealing after warmup
+             main_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+                 self.optimizer,
+                 T_max=config.max_steps - config.warmup_steps  # Duration of cosine annealing
+             )
+
+             # Combine warmup and main schedulers
+             self.scheduler = torch.optim.lr_scheduler.SequentialLR(
+                 self.optimizer,
+                 schedulers=[warmup_scheduler, main_scheduler],
+                 milestones=[config.warmup_steps]  # Switch to main scheduler after warmup
+             )
+
+             print(f"✅ Learning rate scheduler configured")
+             print(f"   - Warmup steps: {config.warmup_steps}")
+             print(f"   - Total steps: {config.max_steps}")
+             print(f"   - Schedule: Linear warmup → Cosine annealing")
+
+             print("✅ Training setup completed successfully")
+             return f"✅ Training setup completed successfully"
+
          except Exception as e:
+             return f"❌ Failed to setup training: {str(e)}"
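
To see the shape of the schedule this method builds, a small self-contained sketch with a dummy parameter; it assumes the defaults above (warmup_steps=50, max_steps=1000, lr=3e-4) and is illustrative only:

    import torch

    p = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.AdamW([p], lr=3e-4)
    warmup = torch.optim.lr_scheduler.LinearLR(opt, start_factor=0.01, end_factor=1.0, total_iters=50)
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=950)
    sched = torch.optim.lr_scheduler.SequentialLR(opt, schedulers=[warmup, cosine], milestones=[50])

    for step in range(1000):
        opt.step()
        sched.step()
        if step in (0, 49, 500, 999):
            # Ramps from 1% of 3e-4 up to 3e-4 over the first 50 steps,
            # then cosine-decays toward 0 by the final step.
            print(step, sched.get_last_lr()[0])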
 
+     def train_model(self, config: TrainingConfig, progress_callback=None) -> str:
+         """
+         Execute the actual model training using OpenLLM's approach.
+
+         This method implements OpenLLM's training loop:
+         1. Sets up training mode and progress tracking
+         2. Iterates through data batches using OpenLLM's TextDataLoader
+         3. Performs forward pass, loss computation, and backward pass
+         4. Implements gradient accumulation for memory efficiency
+         5. Updates model parameters and learning rate
+         6. Saves checkpoints and logs progress
+
+         The training loop follows OpenLLM's methodology:
+         - Uses OpenLLM's GPTModel forward pass (returns logits and loss)
+         - Implements gradient accumulation for effective larger batch sizes
+         - Uses gradient clipping for training stability
+         - Saves checkpoints in OpenLLM's format
+         - Updates progress for UI monitoring
+
+         Args:
+             config: Training configuration object containing hyperparameters
+             progress_callback: Optional callback function for progress updates
+                 (Not used in current implementation)
+
+         Returns:
+             Status message indicating success or failure
+             Success: "✅ Training completed successfully! Final step: {step}"
+             Failure: "❌ Training failed: {error details}"
+         """
          try:
+             # Set training state
+             self.is_training = True
+             self.training_progress["status"] = "Training"
+             self.training_progress["total_steps"] = config.max_steps

+             print(f"🚀 Starting OpenLLM training for {config.max_steps} steps...")
+             print(f"   - Model: {type(self.model).__name__}")
+             print(f"   - DataLoader: {type(self.data_loader).__name__}")
+             print(f"   - Optimizer: {type(self.optimizer).__name__}")
+             print(f"   - Gradient accumulation: {config.gradient_accumulation_steps}")

+             # Training loop using OpenLLM's approach
+             self.model.train()          # Set model to training mode
+             accumulated_loss = 0.0      # Track loss across accumulation steps
+             self.optimizer.zero_grad()  # Clear gradients

+             step = 0  # Current training step
+             for batch_idx, (input_ids, target_ids) in enumerate(self.data_loader):
+                 # Check if we've reached the maximum number of steps
+                 if step >= config.max_steps:
+                     break
+
+                 # Forward pass (model computes loss internally when targets provided)
+                 # OpenLLM's GPTModel returns both logits and loss
+                 logits, loss = self.model(input_ids, target_ids)
+
+                 # Scale loss for gradient accumulation
+                 # This allows us to simulate larger batch sizes
+                 loss = loss / config.gradient_accumulation_steps
+                 accumulated_loss += loss.item()
+
+                 # Backward pass - compute gradients
+                 loss.backward()
+
+                 # Update weights every gradient_accumulation_steps
+                 if (batch_idx + 1) % config.gradient_accumulation_steps == 0:
+                     # Clip gradients for training stability
+                     # This prevents exploding gradients
+                     torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
+
+                     # Update parameters using the optimizer
+                     self.optimizer.step()
+
+                     # Update learning rate using the scheduler
+                     self.scheduler.step()
+
+                     # Clear gradients for the next accumulation cycle
+                     self.optimizer.zero_grad()
+
+                     # Update step count
+                     step += 1
+
+                     # Update progress for UI monitoring
+                     self.training_progress["current_step"] = step
+                     self.training_progress["loss"] = accumulated_loss
+                     self.training_progress["learning_rate"] = self.scheduler.get_last_lr()[0]
+
+                     # Log progress at specified intervals
+                     if step % config.logging_steps == 0:
+                         current_lr = self.scheduler.get_last_lr()[0]
+                         print(f"Step {step}/{config.max_steps} | Loss: {accumulated_loss:.4f} | LR: {current_lr:.2e}")
+
+                     # Save checkpoint at specified intervals
+                     if step % config.save_steps == 0:
+                         self._save_checkpoint(config.output_dir, step)
+                         print(f"💾 Checkpoint saved at step {step}")
+
+                     # Reset accumulated loss for the next accumulation cycle
+                     accumulated_loss = 0.0
+
+                     # Clean up memory periodically
+                     if step % 100 == 0:
+                         gc.collect()
+                         print(f"🧹 Memory cleanup at step {step}")

+             # Save final checkpoint
+             self._save_checkpoint(config.output_dir, step, is_best=True)
+             print(f"💾 Final checkpoint saved at step {step}")
+
+             # Update final progress
+             self.training_progress["status"] = "Completed"
+             self.training_progress["current_step"] = step

+             print(f"✅ Training completed! Final step: {step}")
+             print(f"   - Total steps completed: {step}")
+             print(f"   - Final loss: {self.training_progress['loss']:.4f}")
+             print(f"   - Final learning rate: {self.training_progress['learning_rate']:.2e}")

+             return f"✅ Training completed successfully! Final step: {step}"

          except Exception as e:
+             self.training_progress["status"] = "Failed"
+             print(f"❌ Training failed: {e}")
+             print(f"   - Error occurred during training")
+             print(f"   - Training state: {self.training_progress['status']}")
+             return f"❌ Training failed: {str(e)}"
+         finally:
+             self.is_training = False
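
The net effect of the accumulation logic above: with batch_size=4 and gradient_accumulation_steps=4, each optimizer step averages gradients over 16 samples. The loop reduces to this pattern (a sketch; model, loader, optimizer, and scheduler stand in for the objects built earlier in this file):

    ACCUM = 4  # config.gradient_accumulation_steps

    optimizer.zero_grad()
    for batch_idx, (x, y) in enumerate(loader):
        _, loss = model(x, y)        # OpenLLM-style forward returns (logits, loss)
        (loss / ACCUM).backward()    # scale so accumulated gradients average rather than sum
        if (batch_idx + 1) % ACCUM == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()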
 
+     def _save_checkpoint(self, output_dir: str, step: int, is_best: bool = False) -> None:
+         """
+         Save model checkpoint using OpenLLM's approach.
+
+         This method saves the model state in OpenLLM's checkpoint format:
+         - Model state dictionary
+         - Optimizer state dictionary
+         - Scheduler state dictionary
+         - Model configuration
+         - Training step information
+
+         The checkpoint format is compatible with OpenLLM's loading mechanism
+         and can be used to resume training or load the model for inference.

+         Args:
+             output_dir: Directory to save the checkpoint
+             step: Current training step number
+             is_best: Whether this is the best model so far
+         """
+         try:
+             # Create checkpoint dictionary with all necessary components
+             checkpoint = {
+                 'step': step,                                          # Current training step
+                 'model_state_dict': self.model.state_dict(),           # Model parameters
+                 'optimizer_state_dict': self.optimizer.state_dict(),   # Optimizer state
+                 'scheduler_state_dict': self.scheduler.state_dict(),   # Scheduler state
+                 'config': self.model.config.__dict__                   # Model configuration
+             }
+
+             # Save latest checkpoint
+             checkpoint_path = os.path.join(output_dir, f"checkpoint_step_{step}.pt")
+             torch.save(checkpoint, checkpoint_path)
+
+             # Save best checkpoint if this is the best model
+             if is_best:
+                 best_path = os.path.join(output_dir, "best_model.pt")
+                 torch.save(checkpoint, best_path)
+                 print(f"💾 Best model saved: {best_path}")
+
+             print(f"💾 Checkpoint saved: {checkpoint_path}")
+
+         except Exception as e:
+             print(f"❌ Failed to save checkpoint: {e}")
+
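Because the checkpoint above is a plain torch.save dictionary, loading it back is symmetric. A hedged sketch for resuming or inference (create_model is the factory imported from the uploaded model.py; the size must match what was trained):

    import torch
    from model import create_model

    ckpt = torch.load("./openllm-trained/best_model.pt", map_location="cpu")
    model = create_model("small")
    model.load_state_dict(ckpt["model_state_dict"])  # restore trained weights
    model.eval()                                     # switch to inference mode
    print(f"checkpoint was written at step {ckpt['step']}")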
+     def save_and_upload_model(self, config: TrainingConfig) -> str:
+         """
+         Save the trained model and upload it to Hugging Face Hub.

+         This method completes the training pipeline by:
+         1. Saving the final model checkpoint
+         2. Copying the tokenizer files
+         3. Uploading the complete model to Hugging Face Hub
+         4. Creating a new model repository for the trained model

+         The uploaded model will be available at:
+         https://huggingface.co/lemms/openllm-{size}-extended-8k

+         Args:
+             config: Training configuration object
+
+         Returns:
+             Status message indicating success or failure
+             Success: "✅ Model saved and uploaded to https://huggingface.co/{repo_id}"
+             Failure: "❌ Failed to save/upload model: {error details}"
+         """
+         try:
+             print("🔄 Saving trained model...")
+             print(f"   - Output directory: {config.output_dir}")
+             print(f"   - Model size: {config.model_size}")
+
+             # Save the final model checkpoint
+             self._save_checkpoint(config.output_dir, config.max_steps, is_best=True)
+
+             # Save tokenizer files
+             # Create a tokenizer directory within the output directory
+             tokenizer_dir = os.path.join(config.output_dir, "tokenizer")
+             os.makedirs(tokenizer_dir, exist_ok=True)
+
+             # Copy the tokenizer.model file using the stored path
+             # This ensures the tokenizer is included with the model
+             import shutil
+             shutil.copy2(self.tokenizer_path, os.path.join(tokenizer_dir, "tokenizer.model"))
+
+             print("✅ Model saved locally")
+             print(f"   - Model checkpoint: {config.output_dir}/best_model.pt")
+             print(f"   - Tokenizer: {tokenizer_dir}/tokenizer.model")
+
+             # Generate model name for upload
+             # The naming convention follows: openllm-{size}-extended-8k
+             model_name = f"openllm-{config.model_size}-extended-8k"
+             repo_id = f"lemms/{model_name}"
+
+             # Upload to Hugging Face Hub
+             if self.hf_api:
+                 print(f"🔄 Uploading model to {repo_id}...")
+                 print(f"   - Repository: {repo_id}")
+                 print(f"   - Type: model")
+                 print(f"   - Source: {config.output_dir}")
+
+                 # Create the repository first if it doesn't exist
+                 try:
+                     from huggingface_hub import create_repo
+                     create_repo(
+                         repo_id=repo_id,
+                         repo_type="model",
+                         exist_ok=True,
+                         private=False
+                     )
+                     print(f"✅ Repository {repo_id} ready for upload")
+                 except Exception as create_error:
+                     print(f"⚠️ Repository creation warning: {create_error}")
+                     print("   Continuing with upload attempt...")
+
+                 # Upload model files to Hugging Face Hub
+                 # This creates a new model repository with all the files
+                 self.hf_api.upload_folder(
+                     folder_path=config.output_dir,
+                     repo_id=repo_id,
+                     repo_type="model",
+                     commit_message=f"Add trained OpenLLM {config.model_size} model (8k steps)"
+                 )
+
+                 print(f"✅ Model uploaded successfully to {repo_id}")
+                 print(f"   - Available at: https://huggingface.co/{repo_id}")
+                 return f"✅ Model saved and uploaded to https://huggingface.co/{repo_id}"
+             else:
+                 print("⚠️ Hugging Face API not available - model saved locally only")
+                 return f"✅ Model saved locally to {config.output_dir}"
+
+         except Exception as e:
+             print(f"❌ Failed to save/upload model: {e}")
+             return f"❌ Failed to save/upload model: {str(e)}"
+
+     def get_training_progress(self) -> Dict[str, Any]:
+         """
+         Get current training progress information.
+
+         This method returns a copy of the current training progress
+         for display in the Gradio UI. The progress information includes:
+         - Current training status
+         - Current step and total steps
+         - Current loss value
+         - Current learning rate

+         Returns:
+             Dictionary containing current training progress information
+         """
+         return self.training_progress.copy()
+
+ def main():
+     """
+     Main function that creates the complete Gradio application interface.
+
+     This function sets up the entire Gradio application with:
+     1. Application header and status information
+     2. Training configuration controls
+     3. Training status and progress display
+     4. Training control buttons
+     5. Instructions and resource links
+     6. Training function implementation
+
+     The interface provides a complete training experience for OpenLLM models
+     with real-time progress monitoring and comprehensive configuration options.
+
+     Returns:
+         Gradio Blocks interface for the training application
+     """
+
+     # Initialize the trainer
+     # This creates the OpenLLMTrainer instance that will handle all training operations
+     trainer = OpenLLMTrainer()
+
+     # Create the main Gradio application interface
+     # Using Gradio 4.44.1 with Soft theme for modern appearance
+     with gr.Blocks(
+         title="OpenLLM Training Space - Fixed with Uploaded Modules",
+         theme=gr.themes.Soft()
+     ) as demo:
+
+         # Application Header
+         # Provides clear identification and description of the application
+         gr.Markdown("# 🚀 OpenLLM Training Space - Fixed with Uploaded Modules")
+         gr.Markdown("### *Uses OpenLLM's Custom Model Architecture from Uploaded Files*")
+         gr.Markdown("---")
+
+         # Status Information
+         # Shows the availability of key components and dependencies
+         gr.Markdown(f"**OpenLLM Available**: {'✅ Yes' if OPENLLM_AVAILABLE else '❌ No'}")
+         gr.Markdown(f"**SentencePiece Available**: {'✅ Yes' if SENTENCEPIECE_AVAILABLE else '❌ No'}")
+         gr.Markdown(f"**Dependencies Available**: {'✅ Yes' if DEPENDENCIES_AVAILABLE else '❌ No'}")
+         gr.Markdown("**Architecture**: ✅ OpenLLM Custom GPTModel (From Uploaded Files)")
+
+         # Main Content Area
+         # Two-column layout for configuration and status
+         with gr.Row():
+
+             # Left Column: Training Configuration
+             # Contains all the training hyperparameters and settings
+             with gr.Column(scale=1):
+                 gr.Markdown("## 📊 Training Configuration")
+
+                 # Model Size Selection
+                 # Allows users to choose which base model to train from
                  model_size = gr.Dropdown(
                      choices=["small", "medium", "large"],
                      value="small",
                      label="Model Size",
+                     info="Select the base model size to train from"
                  )
+
+                 # Training Steps Configuration
+                 # Controls the number of training iterations
+                 max_steps = gr.Slider(
+                     minimum=100,
+                     maximum=10000,
+                     value=1000,
+                     step=100,
+                     label="Max Training Steps",
+                     info="Number of training iterations (100-10,000)"
+                 )
+
+                 # Learning Rate Configuration
+                 # Controls the learning rate for the optimizer
+                 learning_rate = gr.Slider(
+                     minimum=1e-5,
+                     maximum=1e-3,
+                     value=3e-4,
+                     step=1e-5,
+                     label="Learning Rate",
+                     info="Training rate (0.00001-0.001)"
+                 )
+
+                 # Batch Size Configuration
+                 # Controls the number of samples per training batch
+                 batch_size = gr.Slider(
+                     minimum=1,
+                     maximum=16,
+                     value=4,
+                     step=1,
+                     label="Batch Size",
+                     info="Samples per training batch (1-16)"
                  )

+             # Right Column: Training Status and Controls
+             # Contains status display and control buttons
+             with gr.Column(scale=1):
+                 gr.Markdown("## 🎯 Training Status")
+
+                 # Training Status Display
+                 # Shows current training status and any error messages
+                 status_text = gr.Textbox(
+                     value="Ready to start training" if OPENLLM_AVAILABLE else "OpenLLM not available",
+                     label="Current Status",
+                     interactive=False,
+                     lines=5,
+                     info="Shows current training status and progress updates"
+                 )
+
+                 # Progress Information
+                 # Displays detailed training progress in JSON format
+                 progress_info = gr.JSON(
+                     value=trainer.get_training_progress(),
+                     label="Training Progress"
+                 )
+
+                 # Training Control Buttons
+                 # Buttons to start and stop training
+                 with gr.Row():
+                     start_btn = gr.Button("🚀 Start Training", variant="primary")
+                     stop_btn = gr.Button("⏹️ Stop Training", variant="stop")

+         # Instructions Section
+         # Provides detailed instructions for using the training interface
+         gr.Markdown("## 📋 OpenLLM Training Instructions")
+         gr.Markdown("""
+         This interface uses **OpenLLM's actual custom model architecture** from uploaded files:
+
+         ### **Step 1: Configure Parameters**
+         - **Model Size**: Select the base model to train from (small, medium, large)
+         - **Max Steps**: Number of training iterations (100-10,000)
+         - **Learning Rate**: Training rate (0.00001-0.001)
+         - **Batch Size**: Samples per training batch (1-16)
+
+         ### **Step 2: Start Training**
+         - Click "Start Training" to begin the actual training process
+         - Uses OpenLLM's custom GPTModel class from uploaded files
+         - Uses sentencepiece.SentencePieceProcessor() for tokenization
+         - Compatible with OpenLLM's actual implementation
+
+         ### **Step 3: Monitor Progress**
+         - Watch the status updates and progress information
+         - Training may take several minutes depending on steps
+         - The final model will be uploaded to Hugging Face Hub
+
+         ### **Step 4: Access Results**
+         - Trained models are automatically pushed to: `lemms/openllm-{size}-extended-8k`
+         - Check the model repository for your trained model
+         - Use the model for inference or further training
+         """)
+
+         # Resource Links Section
+         # Provides links to related models and resources
+         gr.Markdown("## 🔗 Model Resources")
+         gr.Markdown("""
+         - [📚 7k Small Model](https://huggingface.co/lemms/openllm-small-extended-7k)
+         - [🎯 8k Small Model](https://huggingface.co/lemms/openllm-small-extended-8k)
+         - [📊 Training Dataset](https://huggingface.co/datasets/lemms/openllm-training-data)
+         - [📖 Main Project](https://github.com/louischua/openllm)
+         """)
+
+         # Training Function Definition
+         # This function is called when the Start Training button is clicked
+         def start_complete_training(model_size, max_steps, learning_rate, batch_size):
+             """
+             Execute the complete training process using OpenLLM's approach.
+
+             This function orchestrates the entire training pipeline:
+             1. Validates OpenLLM availability
+             2. Creates training configuration
+             3. Loads model and tokenizer
+             4. Prepares dataset
+             5. Sets up training environment
+             6. Executes training
+             7. Saves and uploads the trained model
+
+             The function provides comprehensive error handling and status updates
+             throughout the training process.
+
+             Args:
+                 model_size: Size of the model to train ("small", "medium", "large")
+                 max_steps: Maximum number of training steps
+                 learning_rate: Learning rate for the optimizer
+                 batch_size: Batch size for training
+
+             Returns:
+                 Status message indicating the result of the training process
+             """
+             # Validate OpenLLM availability
+             if not OPENLLM_AVAILABLE:
+                 return "❌ OpenLLM custom model architecture not available. Please check the installation."

+             try:
+                 print(f"🚀 Starting complete training process...")
+                 print(f"   - Model size: {model_size}")
+                 print(f"   - Max steps: {max_steps}")
+                 print(f"   - Learning rate: {learning_rate}")
+                 print(f"   - Batch size: {batch_size}")
+
+                 # Create training configuration
+                 # This encapsulates all training parameters
+                 config = TrainingConfig(
+                     model_size=model_size,
+                     max_steps=max_steps,
+                     learning_rate=learning_rate,
+                     batch_size=batch_size
+                 )
+
+                 # Step 1: Load model and tokenizer using OpenLLM's approach
+                 print("🔄 Step 1: Loading model and tokenizer...")
+                 status = trainer.load_model_and_tokenizer(model_size)
+                 if "❌" in status:
+                     return status
+
+                 # Step 2: Prepare dataset
+                 print("🔄 Step 2: Preparing dataset...")
+                 status = trainer.prepare_dataset()
+                 if "❌" in status:
+                     return status
+
+                 # Step 3: Setup training
+                 print("🔄 Step 3: Setting up training...")
+                 status = trainer.setup_training(config)
+                 if "❌" in status:
+                     return status
+
+                 # Step 4: Execute training
+                 print("🔄 Step 4: Executing training...")
+                 status = trainer.train_model(config)
+                 if "❌" in status:
+                     return status
+
+                 # Step 5: Save and upload model
+                 print("🔄 Step 5: Saving and uploading model...")
+                 status = trainer.save_and_upload_model(config)
+
+                 print("🎉 Complete training process finished!")
+                 return f"🚀 Complete training process finished!\n{status}"
+
+             except Exception as e:
+                 print(f"❌ Training process failed: {str(e)}")
+                 return f"❌ Training process failed: {str(e)}"
+
+         def update_progress():
+             """
+             Update the progress display.
+
+             This function is called to refresh the progress information
+             displayed in the Gradio interface. It returns the
+             current training progress from the trainer.
+
+             Returns:
+                 Current training progress dictionary
+             """
+             return trainer.get_training_progress()
+
+         # Connect UI Components to Functions
+         # This connects the Start Training button to the training function
+         start_btn.click(
+             fn=start_complete_training,
+             inputs=[model_size, max_steps, learning_rate, batch_size],
+             outputs=[status_text]
+         )
+
+         # Refresh the progress display when a browser session loads
+         # (see the note below for a periodic-refresh variant)
+         demo.load(update_progress, outputs=[progress_info])
1011
+ # Application Footer
1012
+ # Provides attribution and technical information
1013
+ gr.Markdown("---")
1014
+ gr.Markdown("**Author**: Louis Chua Bean Chong | **Project**: OpenLLM | **License**: GPL-3.0")
1015
+ gr.Markdown("**Architecture**: OpenLLM Custom GPTModel (From Uploaded Files)")
1016
+ gr.Markdown("**Tokenizer**: sentencepiece.SentencePieceProcessor()")
1017
 
1018
+ return demo
 
1019
 
1020
  if __name__ == "__main__":
1021
+ # Launch the Gradio application
1022
+ # This starts the web interface for the training application
1023
+ demo = main()
1024
+ demo.launch()
 
 
 
requirements.txt CHANGED
@@ -1,26 +1,51 @@
- # OpenLLM Training Space Requirements
- # Core dependencies for Space deployment

- # Hugging Face Hub for authentication and model upload
- huggingface_hub>=0.19.0

- # Gradio for web interface (latest stable version)
- gradio==4.44.1

- # PyTorch for model training
- torch>=2.0.0
- torchvision>=0.15.0

- # Transformers for model handling
- transformers>=4.35.0

- # SentencePiece for tokenization
- sentencepiece>=0.1.99

- # NumPy and other utilities
- numpy>=1.24.0
- pandas>=2.0.0

- # Additional utilities
- requests>=2.31.0
- tqdm>=4.65.0

+ # Complete Training Dependencies for OpenLLM Space - Updated for Gradio 4.44.1
+ # This file includes all necessary packages for real model training

+ # Core Machine Learning Framework
+ torch>=2.0.0           # PyTorch deep learning framework
+ torchvision>=0.15.0    # Computer vision utilities
+ torchaudio>=2.0.0      # Audio processing utilities

+ # Hugging Face Ecosystem - Complete Training Stack
+ transformers>=4.30.0   # Pre-trained models and training utilities
+ datasets>=2.12.0       # Dataset loading and processing
+ tokenizers>=0.13.0     # Fast tokenization library
+ sentencepiece>=0.1.99  # SentencePiece tokenization (CRITICAL for OpenLLM models)
+ huggingface_hub>=0.34.0  # Hugging Face Hub integration
+ accelerate>=0.20.0     # Distributed training acceleration

+ # User Interface Framework - Updated to 4.44.1
+ gradio==4.44.1         # Web UI framework for ML applications (pinned version)

+ # Data Processing and Scientific Computing
+ numpy>=1.24.0          # Numerical computing library
+ pandas>=2.0.0          # Data manipulation and analysis
+ scipy>=1.10.0          # Scientific computing utilities

+ # Progress and Monitoring
+ tqdm>=4.65.0           # Progress bars for long-running operations
+ psutil>=5.9.0          # System and process utilities

+ # Memory and Performance Optimization
+ bitsandbytes>=0.41.0   # Quantization utilities for memory efficiency
+ peft>=0.4.0            # Parameter-Efficient Fine-Tuning

+ # Logging and Debugging
+ wandb>=0.15.0          # Experiment tracking (optional)
+ tensorboard>=2.13.0    # Training visualization (optional)
+
+ # Additional Utilities
+ requests>=2.31.0       # HTTP library for API calls
+ pillow>=9.5.0          # Image processing (if needed)
+ matplotlib>=3.7.0      # Plotting and visualization
+ seaborn>=0.12.0        # Statistical data visualization
+
+ # Development and Testing (optional)
+ pytest>=7.4.0          # Testing framework
+ black>=23.0.0          # Code formatting
+ flake8>=6.0.0          # Code linting
+
+ # Note: These versions are compatible with Hugging Face Spaces
+ # and provide stable training performance for OpenLLM models
+ # Gradio 4.44.1 fixes compatibility issues with JSON components
+ # SentencePiece is CRITICAL for OpenLLM model tokenization
training/data_loader.py CHANGED
@@ -13,12 +13,12 @@
  Training Data Loader for Language Model Training

  This module provides efficient data loading and batching for training GPT-style
- language models. It handles text preprocessing, tokenization, and creates
  batches suitable for autoregressive language modeling.

  FEATURES:
  - Memory-efficient text loading with sliding window
- - Automatic tokenization using trained SentencePiece model
  - Configurable sequence length and batch size
  - CPU-optimized data loading for limited hardware
  - Support for training data validation and statistics
@@ -31,20 +31,20 @@ MEMORY OPTIMIZATION:

  Usage:
      from data_loader import TextDataLoader
-
      loader = TextDataLoader(
          data_file="data/clean/training_data.txt",
-         tokenizer_path="data/tokenizer/tokenizer.model",
          seq_len=512,
          batch_size=4
      )
-
      for batch in loader:
          input_ids, targets = batch
          # input_ids: (batch_size, seq_len)
          # targets: (batch_size, seq_len) - shifted by 1 for next token prediction

- Author: Louis Chua Bean Chong
  License: GPLv3
  """
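
"Shifted by 1" in the usage example above is the standard next-token objective. A tiny illustration with made-up token IDs:

    tokens = [2, 15, 87, 30, 9, 3]   # [BOS, ...subword IDs..., EOS]
    input_ids = tokens[:-1]          # [2, 15, 87, 30, 9]
    targets   = tokens[1:]           # [15, 87, 30, 9, 3]
    # At position i the model sees input_ids[:i+1] and is trained
    # to predict targets[i], i.e. the next token in the sequence.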
 
@@ -66,11 +66,11 @@ except ImportError:
  class TextDataLoader:
      """
      Efficient data loader for autoregressive language model training.
-
      This class handles loading text data, tokenizing it using SentencePiece,
      and creating batches suitable for next-token prediction training.
      """
-
      def __init__(
          self,
          data_file: str,
@@ -79,11 +79,11 @@ class TextDataLoader:
          batch_size: int = 4,
          chunk_size: int = 1000000,  # Lines to read at once
          shuffle: bool = True,
-         seed: int = 42
      ):
          """
          Initialize the data loader.
-
          Args:
              data_file: Path to training text file (one passage per line)
              tokenizer_path: Path to trained SentencePiece model
@@ -100,44 +100,44 @@ class TextDataLoader:
          self.chunk_size = chunk_size
          self.shuffle = shuffle
          self.seed = seed
-
          # Validate inputs
          self._validate_inputs()
-
          # Load tokenizer
          self.tokenizer = self._load_tokenizer()
-
          # Get data statistics
          self.total_lines = self._count_lines()
          self.current_line = 0
-
          # Set random seed for reproducibility
          random.seed(seed)
-
          print(f"📊 TextDataLoader initialized")
          print(f"   Data file: {data_file}")
          print(f"   Total passages: {self.total_lines:,}")
          print(f"   Sequence length: {seq_len}")
          print(f"   Batch size: {batch_size}")
          print(f"   Vocabulary size: {self.tokenizer.vocab_size():,}")
-
      def _validate_inputs(self) -> None:
          """Validate input parameters and file paths."""
          if not os.path.exists(self.data_file):
              raise FileNotFoundError(f"Training data file not found: {self.data_file}")
-
          if not os.path.exists(self.tokenizer_path):
              raise FileNotFoundError(f"Tokenizer model not found: {self.tokenizer_path}")
-
          if self.seq_len <= 0:
              raise ValueError(f"Sequence length must be positive, got {self.seq_len}")
-
          if self.batch_size <= 0:
              raise ValueError(f"Batch size must be positive, got {self.batch_size}")
-
          if self.chunk_size <= 0:
              raise ValueError(f"Chunk size must be positive, got {self.chunk_size}")
-
      def _load_tokenizer(self) -> spm.SentencePieceProcessor:
          """Load the trained SentencePiece tokenizer."""
          try:
@@ -146,222 +146,226 @@ class TextDataLoader:
              return tokenizer
          except Exception as e:
              raise RuntimeError(f"Failed to load tokenizer: {e}")
-
      def _count_lines(self) -> int:
          """Count total number of lines in the data file."""
          print("📝 Counting training passages...")
          start_time = time.time()
-
          line_count = 0
-         with open(self.data_file, 'r', encoding='utf-8') as f:
              for line in f:
                  if line.strip():  # Only count non-empty lines
                      line_count += 1
-
          count_time = time.time() - start_time
          print(f"✓ Found {line_count:,} passages in {count_time:.1f}s")
-
          return line_count
-
      def _read_chunk(self, start_line: int = 0) -> List[str]:
          """
          Read a chunk of lines from the data file.
-
          Args:
              start_line: Line number to start reading from
-
          Returns:
              List of text passages
          """
          chunk = []
          current_line = 0
          lines_read = 0
-
-         with open(self.data_file, 'r', encoding='utf-8') as f:
              for line in f:
                  if current_line < start_line:
                      current_line += 1
                      continue
-
                  text = line.strip()
                  if text:  # Only include non-empty lines
                      chunk.append(text)
                      lines_read += 1
-
                      if lines_read >= self.chunk_size:
                          break
-
                  current_line += 1
-
          return chunk
-
      def _tokenize_texts(self, texts: List[str]) -> List[List[int]]:
          """
          Tokenize a list of text passages using SentencePiece tokenizer.
-
          This method converts raw text into token ID sequences suitable for language model training.
          It handles special tokens (BOS/EOS) and length constraints for efficient training.
-
          Text processing pipeline:
          1. Add BOS (Beginning of Sequence) token to mark sequence start
          2. Tokenize text using trained SentencePiece model (subword tokenization)
          3. Truncate sequences that exceed maximum length
          4. Add EOS (End of Sequence) token to mark sequence end
-
          Special token handling:
          - BOS token helps model learn to generate text from scratch
          - EOS token signals natural sequence endings
          - These tokens are crucial for proper autoregressive generation
-
          Args:
              texts: List of text passages (typically Wikipedia passages from SQUAD)
                     Each passage should be a complete, coherent text segment
-
          Returns:
              List of token ID sequences, where each sequence is a list of integers
              representing subword tokens from the SentencePiece vocabulary
          """
          tokenized = []
-
          for text in texts:
              try:
                  # Add BOS (Beginning of Sequence) token at the start
                  # BOS token ID=2 by default in SentencePiece, signals sequence start
                  # This helps the model learn proper sequence initialization during generation
                  tokens = [self.tokenizer.bos_id()] + self.tokenizer.encode(text)
-
                  # Truncate sequences that exceed maximum context length
                  # Reserve one position for EOS token by using (seq_len - 1)
                  # This ensures we never exceed the model's context window during training
                  if len(tokens) > self.seq_len - 1:
-                     tokens = tokens[:self.seq_len - 1]
                  # NOTE: Truncation may cut off text mid-sentence, but this is acceptable
239
  # for language modeling where the model learns from partial contexts
240
-
241
  # Add EOS (End of Sequence) token at the end
242
  # EOS token ID=1 by default in SentencePiece, signals sequence completion
243
  # This teaches the model when to stop generating text naturally
244
  tokens.append(self.tokenizer.eos_id())
245
-
246
  # Validate tokenization result
247
  if len(tokens) <= 2: # Only BOS + EOS tokens, no actual content
248
  print(f"⚠️ Skipping very short text: {text[:50]}...")
249
  continue
250
-
251
  tokenized.append(tokens)
252
-
253
  except Exception as e:
254
  # Handle tokenization errors gracefully to avoid stopping training
255
  # Common causes: encoding issues, very long texts, special characters
256
  print(f"⚠️ Failed to tokenize passage: {text[:50]}... Error: {e}")
257
  continue
258
-
259
  # Log tokenization statistics for monitoring
260
  if tokenized:
261
  avg_length = sum(len(tokens) for tokens in tokenized) / len(tokenized)
262
  print(f"πŸ“Š Tokenized {len(tokenized)} passages, avg length: {avg_length:.1f} tokens")
263
-
264
  return tokenized
265
-
266
- def _create_training_examples(self, token_sequences: List[List[int]]) -> List[Tuple[List[int], List[int]]]:
267
  """
268
  Create training examples with input and target sequences.
269
-
270
  For autoregressive training, targets are inputs shifted by one position.
271
-
272
  Args:
273
  token_sequences: List of tokenized sequences
274
-
275
  Returns:
276
  List of (input_ids, target_ids) tuples
277
  """
278
  examples = []
279
-
280
  for tokens in token_sequences:
281
  if len(tokens) < 2: # Need at least 2 tokens for input/target pair
282
  continue
283
-
284
  # For sequences longer than seq_len, create multiple examples with sliding window
285
  if len(tokens) > self.seq_len:
286
  # Create overlapping windows (50% overlap for better learning)
287
  stride = self.seq_len // 2
288
  for i in range(0, len(tokens) - self.seq_len, stride):
289
- input_ids = tokens[i:i + self.seq_len]
290
- target_ids = tokens[i + 1:i + self.seq_len + 1]
291
  examples.append((input_ids, target_ids))
292
  else:
293
  # Pad shorter sequences
294
  input_ids = tokens[:-1] # All but last token
295
  target_ids = tokens[1:] # All but first token
296
-
297
  # Pad to seq_len if necessary
298
  while len(input_ids) < self.seq_len:
299
  input_ids.append(self.tokenizer.pad_id())
300
  target_ids.append(-1) # Use -1 for padding in targets (ignored in loss)
301
-
302
  # Truncate if still too long
303
- input_ids = input_ids[:self.seq_len]
304
- target_ids = target_ids[:self.seq_len]
305
-
306
  examples.append((input_ids, target_ids))
307
-
308
  return examples
309
-
310
- def _create_batch(self, examples: List[Tuple[List[int], List[int]]]) -> Tuple[torch.Tensor, torch.Tensor]:
311
  """
312
  Create a batch tensor from training examples.
313
-
314
  Args:
315
  examples: List of (input_ids, target_ids) tuples
316
-
317
  Returns:
318
  Tuple of (input_tensor, target_tensor)
319
  """
320
  if not examples:
321
  raise ValueError("Cannot create batch from empty examples")
322
-
323
  batch_size = len(examples)
324
-
325
  # Initialize tensors
326
  input_ids = torch.zeros((batch_size, self.seq_len), dtype=torch.long)
327
  target_ids = torch.full((batch_size, self.seq_len), -1, dtype=torch.long)
328
-
329
  # Fill tensors
330
  for i, (inp, tgt) in enumerate(examples):
331
- input_ids[i, :len(inp)] = torch.tensor(inp, dtype=torch.long)
332
- target_ids[i, :len(tgt)] = torch.tensor(tgt, dtype=torch.long)
333
-
334
  return input_ids, target_ids
335
-
336
  def __iter__(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
337
  """
338
  Iterate over training batches.
339
-
340
  Yields:
341
  Tuple of (input_ids, target_ids) tensors
342
  """
343
  self.current_line = 0
344
-
345
  while self.current_line < self.total_lines:
346
  # Read chunk of text
347
  texts = self._read_chunk(self.current_line)
348
  if not texts:
349
  break
350
-
351
  # Tokenize texts
352
  token_sequences = self._tokenize_texts(texts)
353
-
354
  # Create training examples
355
  examples = self._create_training_examples(token_sequences)
356
-
357
  # Shuffle examples if requested
358
  if self.shuffle:
359
  random.shuffle(examples)
360
-
361
  # Create batches
362
  for i in range(0, len(examples), self.batch_size):
363
- batch_examples = examples[i:i + self.batch_size]
364
-
365
  if len(batch_examples) == self.batch_size: # Only yield full batches
366
  try:
367
  input_ids, target_ids = self._create_batch(batch_examples)
@@ -369,27 +373,27 @@ class TextDataLoader:
369
  except Exception as e:
370
  print(f"⚠️ Failed to create batch: {e}")
371
  continue
372
-
373
  # Update progress
374
  self.current_line += len(texts)
375
-
376
  # Clean up memory
377
  del texts, token_sequences, examples
378
  gc.collect()
379
-
380
  def get_data_stats(self) -> dict:
381
  """
382
  Get statistics about the training data.
383
-
384
  Returns:
385
  Dictionary with data statistics
386
  """
387
  print("πŸ“Š Analyzing training data...")
388
-
389
  # Sample some data to get statistics
390
  sample_texts = self._read_chunk(0)[:100] # Sample first 100 passages
391
  token_sequences = self._tokenize_texts(sample_texts)
392
-
393
  if token_sequences:
394
  sequence_lengths = [len(seq) for seq in token_sequences]
395
  avg_length = sum(sequence_lengths) / len(sequence_lengths)
@@ -397,15 +401,15 @@ class TextDataLoader:
397
  min_length = min(sequence_lengths)
398
  else:
399
  avg_length = max_length = min_length = 0
400
-
401
  # Estimate total tokens
402
  estimated_total_tokens = int(avg_length * self.total_lines)
403
-
404
  # Estimate number of batches per epoch
405
  examples_per_passage = max(1, avg_length // self.seq_len)
406
  total_examples = int(self.total_lines * examples_per_passage)
407
  batches_per_epoch = total_examples // self.batch_size
408
-
409
  stats = {
410
  "total_passages": self.total_lines,
411
  "avg_tokens_per_passage": avg_length,
@@ -416,22 +420,22 @@ class TextDataLoader:
416
  "estimated_batches_per_epoch": batches_per_epoch,
417
  "sequence_length": self.seq_len,
418
  "batch_size": self.batch_size,
419
- "vocabulary_size": self.tokenizer.vocab_size()
420
  }
421
-
422
  print(f"βœ“ Data analysis complete:")
423
  print(f" Total passages: {stats['total_passages']:,}")
424
  print(f" Avg tokens per passage: {stats['avg_tokens_per_passage']:.1f}")
425
  print(f" Estimated total tokens: {stats['estimated_total_tokens']:,}")
426
  print(f" Estimated batches per epoch: {stats['estimated_batches_per_epoch']:,}")
427
-
428
  return stats
429
 
430
 
431
  def test_data_loader():
432
  """Test function for the data loader."""
433
  print("πŸ§ͺ Testing TextDataLoader...")
434
-
435
  # Test with small parameters
436
  try:
437
  loader = TextDataLoader(
@@ -439,42 +443,43 @@ def test_data_loader():
439
  tokenizer_path="data/tokenizer/tokenizer.model",
440
  seq_len=128,
441
  batch_size=2,
442
- chunk_size=10 # Small for testing
443
  )
444
-
445
  # Get data statistics
446
  stats = loader.get_data_stats()
447
-
448
  # Test iteration
449
  print("\nπŸ”„ Testing batch iteration...")
450
  start_time = time.time()
451
  batch_count = 0
452
-
453
  for batch_idx, (input_ids, target_ids) in enumerate(loader):
454
  batch_count += 1
455
-
456
  print(f"Batch {batch_idx + 1}:")
457
  print(f" Input shape: {input_ids.shape}")
458
  print(f" Target shape: {target_ids.shape}")
459
  print(f" Sample input tokens: {input_ids[0][:10].tolist()}")
460
  print(f" Sample target tokens: {target_ids[0][:10].tolist()}")
461
-
462
  if batch_idx >= 2: # Only test first few batches
463
  break
464
-
465
  test_time = time.time() - start_time
466
  print(f"\nβœ“ Data loader test completed successfully!")
467
  print(f" Processed {batch_count} batches in {test_time:.2f}s")
468
  print(f" Average time per batch: {test_time/max(1, batch_count):.2f}s")
469
-
470
  return True
471
-
472
  except Exception as e:
473
  print(f"❌ Data loader test failed: {e}")
474
  import traceback
475
  traceback.print_exc()
476
  return False
477
 
478
 
479
  if __name__ == "__main__":
480
- test_data_loader()
 
13
  Training Data Loader for Language Model Training
14
 
15
  This module provides efficient data loading and batching for training GPT-style
16
+ language models. It handles text preprocessing, tokenization, and creates
17
  batches suitable for autoregressive language modeling.
18
 
19
  FEATURES:
20
  - Memory-efficient text loading with sliding window
21
+ - Automatic tokenization using trained SentencePiece model
22
  - Configurable sequence length and batch size
23
  - CPU-optimized data loading for limited hardware
24
  - Support for training data validation and statistics
 
31
 
32
  Usage:
33
  from data_loader import TextDataLoader
34
+
35
  loader = TextDataLoader(
36
  data_file="data/clean/training_data.txt",
37
+ tokenizer_path="data/tokenizer/tokenizer.model",
38
  seq_len=512,
39
  batch_size=4
40
  )
41
+
42
  for batch in loader:
43
  input_ids, targets = batch
44
  # input_ids: (batch_size, seq_len)
45
  # targets: (batch_size, seq_len) - shifted by 1 for next token prediction
46
 
47
+ Author: Louis Chua Bean Chong
48
  License: GPLv3
49
  """
50
 
 
66
  class TextDataLoader:
67
  """
68
  Efficient data loader for autoregressive language model training.
69
+
70
  This class handles loading text data, tokenizing it using SentencePiece,
71
  and creating batches suitable for next-token prediction training.
72
  """
73
+
74
  def __init__(
75
  self,
76
  data_file: str,
 
79
  batch_size: int = 4,
80
  chunk_size: int = 1000000, # Lines to read at once
81
  shuffle: bool = True,
82
+ seed: int = 42,
83
  ):
84
  """
85
  Initialize the data loader.
86
+
87
  Args:
88
  data_file: Path to training text file (one passage per line)
89
  tokenizer_path: Path to trained SentencePiece model
 
100
  self.chunk_size = chunk_size
101
  self.shuffle = shuffle
102
  self.seed = seed
103
+
104
  # Validate inputs
105
  self._validate_inputs()
106
+
107
  # Load tokenizer
108
  self.tokenizer = self._load_tokenizer()
109
+
110
  # Get data statistics
111
  self.total_lines = self._count_lines()
112
  self.current_line = 0
113
+
114
  # Set random seed for reproducibility
115
  random.seed(seed)
116
+
117
  print(f"πŸ“Š TextDataLoader initialized")
118
  print(f" Data file: {data_file}")
119
  print(f" Total passages: {self.total_lines:,}")
120
  print(f" Sequence length: {seq_len}")
121
  print(f" Batch size: {batch_size}")
122
  print(f" Vocabulary size: {self.tokenizer.vocab_size():,}")
123
+
124
  def _validate_inputs(self) -> None:
125
  """Validate input parameters and file paths."""
126
  if not os.path.exists(self.data_file):
127
  raise FileNotFoundError(f"Training data file not found: {self.data_file}")
128
+
129
  if not os.path.exists(self.tokenizer_path):
130
  raise FileNotFoundError(f"Tokenizer model not found: {self.tokenizer_path}")
131
+
132
  if self.seq_len <= 0:
133
  raise ValueError(f"Sequence length must be positive, got {self.seq_len}")
134
+
135
  if self.batch_size <= 0:
136
  raise ValueError(f"Batch size must be positive, got {self.batch_size}")
137
+
138
  if self.chunk_size <= 0:
139
  raise ValueError(f"Chunk size must be positive, got {self.chunk_size}")
140
+
141
  def _load_tokenizer(self) -> spm.SentencePieceProcessor:
142
  """Load the trained SentencePiece tokenizer."""
143
  try:
 
146
  return tokenizer
147
  except Exception as e:
148
  raise RuntimeError(f"Failed to load tokenizer: {e}")
149
+
150
  def _count_lines(self) -> int:
151
  """Count total number of lines in the data file."""
152
  print("πŸ“ Counting training passages...")
153
  start_time = time.time()
154
+
155
  line_count = 0
156
+ with open(self.data_file, "r", encoding="utf-8") as f:
157
  for line in f:
158
  if line.strip(): # Only count non-empty lines
159
  line_count += 1
160
+
161
  count_time = time.time() - start_time
162
  print(f"βœ“ Found {line_count:,} passages in {count_time:.1f}s")
163
+
164
  return line_count
165
+
166
  def _read_chunk(self, start_line: int = 0) -> List[str]:
167
  """
168
  Read a chunk of lines from the data file.
169
+
170
  Args:
171
  start_line: Line number to start reading from
172
+
173
  Returns:
174
  List of text passages
175
  """
176
  chunk = []
177
  current_line = 0
178
  lines_read = 0
179
+
180
+ with open(self.data_file, "r", encoding="utf-8") as f:
181
  for line in f:
182
  if current_line < start_line:
183
  current_line += 1
184
  continue
185
+
186
  text = line.strip()
187
  if text: # Only include non-empty lines
188
  chunk.append(text)
189
  lines_read += 1
190
+
191
  if lines_read >= self.chunk_size:
192
  break
193
+
194
  current_line += 1
195
+
196
  return chunk
197
+
198
  def _tokenize_texts(self, texts: List[str]) -> List[List[int]]:
199
  """
200
  Tokenize a list of text passages using SentencePiece tokenizer.
201
+
202
  This method converts raw text into token ID sequences suitable for language model training.
203
  It handles special tokens (BOS/EOS) and length constraints for efficient training.
204
+
205
  Text processing pipeline:
206
  1. Add BOS (Beginning of Sequence) token to mark sequence start
207
  2. Tokenize text using trained SentencePiece model (subword tokenization)
208
  3. Truncate sequences that exceed maximum length
209
  4. Add EOS (End of Sequence) token to mark sequence end
210
+
211
  Special token handling:
212
  - BOS token helps model learn to generate text from scratch
213
  - EOS token signals natural sequence endings
214
  - These tokens are crucial for proper autoregressive generation
215
+
216
  Args:
217
  texts: List of text passages (typically Wikipedia passages from SQUAD)
218
  Each passage should be a complete, coherent text segment
219
+
220
  Returns:
221
  List of token ID sequences, where each sequence is a list of integers
222
  representing subword tokens from the SentencePiece vocabulary
223
  """
224
  tokenized = []
225
+
226
  for text in texts:
227
  try:
228
  # Add BOS (Beginning of Sequence) token at the start
229
  # BOS token ID=2 by default in SentencePiece, signals sequence start
230
  # This helps the model learn proper sequence initialization during generation
231
  tokens = [self.tokenizer.bos_id()] + self.tokenizer.encode(text)
232
+
233
  # Truncate sequences that exceed maximum context length
234
  # Reserve one position for EOS token by using (seq_len - 1)
235
  # This ensures we never exceed the model's context window during training
236
  if len(tokens) > self.seq_len - 1:
237
+ tokens = tokens[: self.seq_len - 1]
238
  # NOTE: Truncation may cut off text mid-sentence, but this is acceptable
239
  # for language modeling where the model learns from partial contexts
240
+
241
  # Add EOS (End of Sequence) token at the end
242
  # EOS token ID=1 by default in SentencePiece, signals sequence completion
243
  # This teaches the model when to stop generating text naturally
244
  tokens.append(self.tokenizer.eos_id())
245
+
246
  # Validate tokenization result
247
  if len(tokens) <= 2: # Only BOS + EOS tokens, no actual content
248
  print(f"⚠️ Skipping very short text: {text[:50]}...")
249
  continue
250
+
251
  tokenized.append(tokens)
252
+
253
  except Exception as e:
254
  # Handle tokenization errors gracefully to avoid stopping training
255
  # Common causes: encoding issues, very long texts, special characters
256
  print(f"⚠️ Failed to tokenize passage: {text[:50]}... Error: {e}")
257
  continue
258
+
259
  # Log tokenization statistics for monitoring
260
  if tokenized:
261
  avg_length = sum(len(tokens) for tokens in tokenized) / len(tokenized)
262
  print(f"πŸ“Š Tokenized {len(tokenized)} passages, avg length: {avg_length:.1f} tokens")
263
+
264
  return tokenized
265
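The BOS/truncate/EOS pipeline above can be exercised without a trained model; a minimal sketch with a stub encoder and hypothetical token IDs (the real IDs come from the trained SentencePiece model):

# Sketch of the tokenization pipeline: BOS + encode + truncate + EOS.
BOS, EOS, SEQ_LEN = 2, 1, 8

def fake_encode(text):
    # Stub standing in for SentencePieceProcessor.encode(): one ID per word.
    return [100 + i for i in range(len(text.split()))]

tokens = [BOS] + fake_encode("a fairly long toy passage that will be cut short")
if len(tokens) > SEQ_LEN - 1:        # reserve one slot for EOS
    tokens = tokens[: SEQ_LEN - 1]
tokens.append(EOS)
print(tokens)  # fits in SEQ_LEN, starts with BOS, ends with EOS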
+
266
+ def _create_training_examples(
267
+ self, token_sequences: List[List[int]]
268
+ ) -> List[Tuple[List[int], List[int]]]:
269
  """
270
  Create training examples with input and target sequences.
271
+
272
  For autoregressive training, targets are inputs shifted by one position.
273
+
274
  Args:
275
  token_sequences: List of tokenized sequences
276
+
277
  Returns:
278
  List of (input_ids, target_ids) tuples
279
  """
280
  examples = []
281
+
282
  for tokens in token_sequences:
283
  if len(tokens) < 2: # Need at least 2 tokens for input/target pair
284
  continue
285
+
286
  # For sequences longer than seq_len, create multiple examples with sliding window
287
  if len(tokens) > self.seq_len:
288
  # Create overlapping windows (50% overlap for better learning)
289
  stride = self.seq_len // 2
290
  for i in range(0, len(tokens) - self.seq_len, stride):
291
+ input_ids = tokens[i : i + self.seq_len]
292
+ target_ids = tokens[i + 1 : i + self.seq_len + 1]
293
  examples.append((input_ids, target_ids))
294
  else:
295
  # Pad shorter sequences
296
  input_ids = tokens[:-1] # All but last token
297
  target_ids = tokens[1:] # All but first token
298
+
299
  # Pad to seq_len if necessary
300
  while len(input_ids) < self.seq_len:
301
  input_ids.append(self.tokenizer.pad_id())
302
  target_ids.append(-1) # Use -1 for padding in targets (ignored in loss)
303
+
304
  # Truncate if still too long
305
+ input_ids = input_ids[: self.seq_len]
306
+ target_ids = target_ids[: self.seq_len]
307
+
308
  examples.append((input_ids, target_ids))
309
+
310
  return examples
311
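The 50% overlap is easiest to see on plain integers; a quick sketch of the windowing loop above:

# Sliding windows with stride = seq_len // 2 over a stand-in token list.
tokens = list(range(12))
seq_len = 4
stride = seq_len // 2
for i in range(0, len(tokens) - seq_len, stride):
    print(tokens[i : i + seq_len], "->", tokens[i + 1 : i + seq_len + 1])
# Consecutive windows share seq_len // 2 tokens, so most tokens appear
# in two training examples.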
+
312
+ def _create_batch(
313
+ self, examples: List[Tuple[List[int], List[int]]]
314
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
315
  """
316
  Create a batch tensor from training examples.
317
+
318
  Args:
319
  examples: List of (input_ids, target_ids) tuples
320
+
321
  Returns:
322
  Tuple of (input_tensor, target_tensor)
323
  """
324
  if not examples:
325
  raise ValueError("Cannot create batch from empty examples")
326
+
327
  batch_size = len(examples)
328
+
329
  # Initialize tensors
330
  input_ids = torch.zeros((batch_size, self.seq_len), dtype=torch.long)
331
  target_ids = torch.full((batch_size, self.seq_len), -1, dtype=torch.long)
332
+
333
  # Fill tensors
334
  for i, (inp, tgt) in enumerate(examples):
335
+ input_ids[i, : len(inp)] = torch.tensor(inp, dtype=torch.long)
336
+ target_ids[i, : len(tgt)] = torch.tensor(tgt, dtype=torch.long)
337
+
338
  return input_ids, target_ids
339
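The -1 target padding only pays off if the loss skips it; a minimal sketch, assuming the training loop uses cross-entropy with ignore_index=-1 (the loader itself never computes the loss):

import torch
import torch.nn.functional as F

logits = torch.randn(1, 4, 10)            # (batch, seq_len, vocab)
targets = torch.tensor([[3, 7, -1, -1]])  # last two positions are padding
loss = F.cross_entropy(logits.view(-1, 10), targets.view(-1), ignore_index=-1)
print(loss)  # averaged over the two real positions only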
+
340
  def __iter__(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
341
  """
342
  Iterate over training batches.
343
+
344
  Yields:
345
  Tuple of (input_ids, target_ids) tensors
346
  """
347
  self.current_line = 0
348
+
349
  while self.current_line < self.total_lines:
350
  # Read chunk of text
351
  texts = self._read_chunk(self.current_line)
352
  if not texts:
353
  break
354
+
355
  # Tokenize texts
356
  token_sequences = self._tokenize_texts(texts)
357
+
358
  # Create training examples
359
  examples = self._create_training_examples(token_sequences)
360
+
361
  # Shuffle examples if requested
362
  if self.shuffle:
363
  random.shuffle(examples)
364
+
365
  # Create batches
366
  for i in range(0, len(examples), self.batch_size):
367
+ batch_examples = examples[i : i + self.batch_size]
368
+
369
  if len(batch_examples) == self.batch_size: # Only yield full batches
370
  try:
371
  input_ids, target_ids = self._create_batch(batch_examples)
 
373
  except Exception as e:
374
  print(f"⚠️ Failed to create batch: {e}")
375
  continue
376
+
377
  # Update progress
378
  self.current_line += len(texts)
379
+
380
  # Clean up memory
381
  del texts, token_sequences, examples
382
  gc.collect()
383
+
384
  def get_data_stats(self) -> dict:
385
  """
386
  Get statistics about the training data.
387
+
388
  Returns:
389
  Dictionary with data statistics
390
  """
391
  print("πŸ“Š Analyzing training data...")
392
+
393
  # Sample some data to get statistics
394
  sample_texts = self._read_chunk(0)[:100] # Sample first 100 passages
395
  token_sequences = self._tokenize_texts(sample_texts)
396
+
397
  if token_sequences:
398
  sequence_lengths = [len(seq) for seq in token_sequences]
399
  avg_length = sum(sequence_lengths) / len(sequence_lengths)
 
401
  min_length = min(sequence_lengths)
402
  else:
403
  avg_length = max_length = min_length = 0
404
+
405
  # Estimate total tokens
406
  estimated_total_tokens = int(avg_length * self.total_lines)
407
+
408
  # Estimate number of batches per epoch
409
  examples_per_passage = max(1, avg_length // self.seq_len)
410
  total_examples = int(self.total_lines * examples_per_passage)
411
  batches_per_epoch = total_examples // self.batch_size
412
+
413
  stats = {
414
  "total_passages": self.total_lines,
415
  "avg_tokens_per_passage": avg_length,
 
420
  "estimated_batches_per_epoch": batches_per_epoch,
421
  "sequence_length": self.seq_len,
422
  "batch_size": self.batch_size,
423
+ "vocabulary_size": self.tokenizer.vocab_size(),
424
  }
425
+
426
  print(f"βœ“ Data analysis complete:")
427
  print(f" Total passages: {stats['total_passages']:,}")
428
  print(f" Avg tokens per passage: {stats['avg_tokens_per_passage']:.1f}")
429
  print(f" Estimated total tokens: {stats['estimated_total_tokens']:,}")
430
  print(f" Estimated batches per epoch: {stats['estimated_batches_per_epoch']:,}")
431
+
432
  return stats
433
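The epoch-size estimate above is deliberately rough; a worked example with made-up numbers:

# Back-of-envelope version of the estimate in get_data_stats().
total_passages = 50_000
avg_tokens_per_passage = 180.0
seq_len, batch_size = 512, 4

examples_per_passage = max(1, avg_tokens_per_passage // seq_len)  # 1
total_examples = int(total_passages * examples_per_passage)       # 50,000
batches_per_epoch = total_examples // batch_size                  # 12,500
print(batches_per_epoch)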
 
434
 
435
  def test_data_loader():
436
  """Test function for the data loader."""
437
  print("πŸ§ͺ Testing TextDataLoader...")
438
+
439
  # Test with small parameters
440
  try:
441
  loader = TextDataLoader(
 
443
  tokenizer_path="data/tokenizer/tokenizer.model",
444
  seq_len=128,
445
  batch_size=2,
446
+ chunk_size=10, # Small for testing
447
  )
448
+
449
  # Get data statistics
450
  stats = loader.get_data_stats()
451
+
452
  # Test iteration
453
  print("\nπŸ”„ Testing batch iteration...")
454
  start_time = time.time()
455
  batch_count = 0
456
+
457
  for batch_idx, (input_ids, target_ids) in enumerate(loader):
458
  batch_count += 1
459
+
460
  print(f"Batch {batch_idx + 1}:")
461
  print(f" Input shape: {input_ids.shape}")
462
  print(f" Target shape: {target_ids.shape}")
463
  print(f" Sample input tokens: {input_ids[0][:10].tolist()}")
464
  print(f" Sample target tokens: {target_ids[0][:10].tolist()}")
465
+
466
  if batch_idx >= 2: # Only test first few batches
467
  break
468
+
469
  test_time = time.time() - start_time
470
  print(f"\nβœ“ Data loader test completed successfully!")
471
  print(f" Processed {batch_count} batches in {test_time:.2f}s")
472
  print(f" Average time per batch: {test_time/max(1, batch_count):.2f}s")
473
+
474
  return True
475
+
476
  except Exception as e:
477
  print(f"❌ Data loader test failed: {e}")
478
  import traceback
479
+
480
  traceback.print_exc()
481
  return False
482
 
483
 
484
  if __name__ == "__main__":
485
+ test_data_loader()
training/evaluate_model.py CHANGED
@@ -56,20 +56,15 @@ from data_loader import TextDataLoader
56
  class ModelEvaluator:
57
  """
58
  Comprehensive evaluator for OpenLLM models.
59
-
60
  Implements intrinsic evaluation metrics and text generation quality
61
  assessment following the training pipeline specifications.
62
  """
63
-
64
- def __init__(
65
- self,
66
- model: GPTModel,
67
- tokenizer_path: str,
68
- device: str = "cpu"
69
- ):
70
  """
71
  Initialize the model evaluator.
72
-
73
  Args:
74
  model: Trained GPT model
75
  tokenizer_path: Path to tokenizer model file
@@ -77,30 +72,27 @@ class ModelEvaluator:
77
  """
78
  self.model = model.to(device)
79
  self.device = device
80
-
81
  # Load tokenizer
82
  self.tokenizer = smp.SentencePieceProcessor()
83
  self.tokenizer.load(tokenizer_path)
84
-
85
  print(f"πŸ”§ ModelEvaluator initialized")
86
  print(f" Device: {device}")
87
  print(f" Model parameters: {model.get_num_params():,}")
88
  print(f" Vocabulary size: {self.tokenizer.vocab_size():,}")
89
-
90
  def evaluate_perplexity(
91
- self,
92
- eval_data: List[str],
93
- max_seq_len: int = 512,
94
- batch_size: int = 1
95
  ) -> Dict[str, float]:
96
  """
97
  Calculate perplexity on evaluation data.
98
-
99
  Args:
100
  eval_data: List of text passages for evaluation
101
  max_seq_len: Maximum sequence length for evaluation
102
  batch_size: Batch size for evaluation
103
-
104
  Returns:
105
  Dictionary with loss and perplexity metrics
106
  """
@@ -108,400 +100,394 @@ class ModelEvaluator:
108
  total_loss = 0.0
109
  total_tokens = 0
110
  num_sequences = 0
111
-
112
  print(f"πŸ“Š Calculating perplexity on {len(eval_data)} passages...")
113
-
114
  with torch.no_grad():
115
  for i, text in enumerate(eval_data):
116
  if i % 100 == 0:
117
  print(f" Progress: {i}/{len(eval_data)} passages")
118
-
119
  # Tokenize text
120
  tokens = self.tokenizer.encode(text)
121
  if len(tokens) < 2:
122
  continue
123
-
124
  # Truncate if too long
125
  if len(tokens) > max_seq_len:
126
  tokens = tokens[:max_seq_len]
127
-
128
  # Create input and target tensors
129
  input_ids = torch.tensor([tokens[:-1]], dtype=torch.long, device=self.device)
130
  target_ids = torch.tensor([tokens[1:]], dtype=torch.long, device=self.device)
131
-
132
  # Forward pass
133
  logits, loss = self.model(input_ids, target_ids)
134
-
135
  # Accumulate loss
136
  seq_length = len(tokens) - 1
137
  total_loss += loss.item() * seq_length
138
  total_tokens += seq_length
139
  num_sequences += 1
140
-
141
  # Calculate metrics
142
- avg_loss = total_loss / total_tokens if total_tokens > 0 else float('inf')
143
  perplexity = math.exp(min(avg_loss, 10)) # Cap to prevent overflow
144
-
145
  return {
146
- 'loss': avg_loss,
147
- 'perplexity': perplexity,
148
- 'total_tokens': total_tokens,
149
- 'num_sequences': num_sequences
150
  }
151
-
152
  def evaluate_text_generation(
153
  self,
154
  prompts: List[str],
155
  max_length: int = 256,
156
  temperature: float = 0.7,
157
  top_k: Optional[int] = 40,
158
- num_samples: int = 1
159
  ) -> List[Dict[str, Any]]:
160
  """
161
  Evaluate text generation quality.
162
-
163
  Args:
164
  prompts: List of input prompts
165
  max_length: Maximum generation length
166
  temperature: Sampling temperature
167
  top_k: Top-k sampling parameter
168
  num_samples: Number of samples per prompt
169
-
170
  Returns:
171
  List of generation results with quality metrics
172
  """
173
  self.model.eval()
174
  results = []
175
-
176
  print(f"✍️ Evaluating text generation on {len(prompts)} prompts...")
177
-
178
  with torch.no_grad():
179
  for prompt in prompts:
180
  prompt_results = []
181
-
182
  for sample_idx in range(num_samples):
183
  # Tokenize prompt
184
  input_ids = self.tokenizer.encode(prompt)
185
  input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
186
-
187
  start_time = time.time()
188
-
189
  # Generate text
190
  output = self.model.generate(
191
  input_tensor,
192
  max_new_tokens=max_length,
193
  temperature=temperature,
194
- top_k=top_k
195
  )
196
-
197
  generation_time = time.time() - start_time
198
-
199
  # Decode output
200
  generated_ids = output[0].tolist()
201
  full_text = self.tokenizer.decode(generated_ids)
202
- generated_text = self.tokenizer.decode(generated_ids[len(input_ids):])
203
-
204
  # Calculate quality metrics
205
  quality_metrics = self._assess_generation_quality(generated_text)
206
-
207
- prompt_results.append({
208
- 'prompt': prompt,
209
- 'generated_text': generated_text,
210
- 'full_text': full_text,
211
- 'generation_time': generation_time,
212
- 'tokens_generated': len(generated_ids) - len(input_ids),
213
- 'tokens_per_second': (len(generated_ids) - len(input_ids)) / generation_time,
214
- 'quality_metrics': quality_metrics
215
- })
216
-
217
  results.extend(prompt_results)
218
-
219
  return results
220
-
221
  def _assess_generation_quality(self, text: str) -> Dict[str, float]:
222
  """
223
  Assess basic quality metrics for generated text.
224
-
225
  Args:
226
  text: Generated text to assess
227
-
228
  Returns:
229
  Dictionary of quality metrics
230
  """
231
  if not text.strip():
232
  return {
233
- 'length': 0,
234
- 'avg_word_length': 0,
235
- 'repetition_rate': 1.0,
236
- 'coherence_score': 0.0
237
  }
238
-
239
  words = text.split()
240
-
241
  # Basic metrics
242
  length = len(words)
243
  avg_word_length = sum(len(word) for word in words) / len(words) if words else 0
244
-
245
  # Repetition rate (simple n-gram repetition)
246
- bigrams = [f"{words[i]} {words[i+1]}" for i in range(len(words)-1)]
247
  unique_bigrams = len(set(bigrams))
248
  repetition_rate = 1 - (unique_bigrams / len(bigrams) if bigrams else 0)
249
-
250
  # Simple coherence score (based on sentence structure)
251
- sentences = text.split('.')
252
  valid_sentences = [s for s in sentences if len(s.strip().split()) > 3]
253
  coherence_score = len(valid_sentences) / len(sentences) if sentences else 0
254
-
255
  return {
256
- 'length': length,
257
- 'avg_word_length': avg_word_length,
258
- 'repetition_rate': repetition_rate,
259
- 'coherence_score': coherence_score
260
  }
261
-
262
  def evaluate_downstream_tasks(self) -> Dict[str, Any]:
263
  """
264
  Evaluate model performance on downstream tasks.
265
-
266
  This function implements basic downstream task evaluation including:
267
  - Reading comprehension (simplified SQUAD-style)
268
  - Sentiment analysis (few-shot)
269
  - Common sense reasoning
270
-
271
  Returns:
272
  Dictionary of downstream task results
273
  """
274
  results = {}
275
-
276
  # 1. Reading Comprehension (Simplified SQUAD-style)
277
- results['reading_comprehension'] = self._evaluate_reading_comprehension()
278
-
279
  # 2. Sentiment Analysis (Few-shot learning)
280
- results['sentiment_analysis'] = self._evaluate_sentiment_analysis()
281
-
282
  # 3. Common Sense Reasoning
283
- results['reasoning'] = self._evaluate_reasoning()
284
-
285
  # 4. Text Completion Quality
286
- results['text_completion'] = self._evaluate_text_completion()
287
-
288
  return results
289
-
290
  def _evaluate_reading_comprehension(self) -> Dict[str, Any]:
291
  """Simplified reading comprehension evaluation."""
292
  # Sample reading comprehension tasks
293
  tasks = [
294
  {
295
- 'context': 'The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, France. It is named after the engineer Gustave Eiffel, whose company designed and built the tower.',
296
- 'question': 'Who is the Eiffel Tower named after?',
297
- 'expected': 'Gustave Eiffel'
298
  },
299
  {
300
- 'context': 'Python is a high-level programming language. It was created by Guido van Rossum and first released in 1991.',
301
- 'question': 'When was Python first released?',
302
- 'expected': '1991'
303
  },
304
  {
305
- 'context': 'Machine learning is a subset of artificial intelligence that enables computers to learn without being explicitly programmed.',
306
- 'question': 'What is machine learning a subset of?',
307
- 'expected': 'artificial intelligence'
308
- }
309
  ]
310
-
311
  correct = 0
312
  total = len(tasks)
313
-
314
  for task in tasks:
315
  prompt = f"Context: {task['context']}\nQuestion: {task['question']}\nAnswer:"
316
-
317
  # Generate answer
318
  input_ids = self.tokenizer.encode(prompt)
319
  input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
320
-
321
  with torch.no_grad():
322
  output = self.model.generate(input_tensor, max_new_tokens=20, temperature=0.1)
323
-
324
  generated_ids = output[0].tolist()
325
- answer = self.tokenizer.decode(generated_ids[len(input_ids):]).strip().lower()
326
-
327
  # Simple substring matching
328
- if task['expected'].lower() in answer:
329
  correct += 1
330
-
331
  return {
332
- 'accuracy': correct / total,
333
- 'correct': correct,
334
- 'total': total,
335
- 'score': correct / total
336
  }
337
-
338
  def _evaluate_sentiment_analysis(self) -> Dict[str, Any]:
339
  """Few-shot sentiment analysis evaluation."""
340
  # Few-shot examples
341
  examples = "Examples:\nText: 'I love this movie!' Sentiment: Positive\nText: 'This is terrible.' Sentiment: Negative\nText: 'It was okay.' Sentiment: Neutral\n\n"
342
-
343
  # Test cases
344
  test_cases = [
345
- {'text': 'This is amazing!', 'expected': 'positive'},
346
- {'text': 'I hate this.', 'expected': 'negative'},
347
- {'text': 'This is wonderful.', 'expected': 'positive'},
348
- {'text': 'This is awful.', 'expected': 'negative'},
349
- {'text': 'It was fine.', 'expected': 'neutral'}
350
  ]
351
-
352
  correct = 0
353
  total = len(test_cases)
354
-
355
  for case in test_cases:
356
  prompt = f"{examples}Text: '{case['text']}' Sentiment:"
357
-
358
  # Generate sentiment
359
  input_ids = self.tokenizer.encode(prompt)
360
  input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
361
-
362
  with torch.no_grad():
363
  output = self.model.generate(input_tensor, max_new_tokens=5, temperature=0.1)
364
-
365
  generated_ids = output[0].tolist()
366
- sentiment = self.tokenizer.decode(generated_ids[len(input_ids):]).strip().lower()
367
-
368
  # Check if expected sentiment is in the generated response
369
- if case['expected'] in sentiment:
370
  correct += 1
371
-
372
  return {
373
- 'accuracy': correct / total,
374
- 'correct': correct,
375
- 'total': total,
376
- 'score': correct / total
377
  }
378
-
379
  def _evaluate_reasoning(self) -> Dict[str, Any]:
380
  """Simple reasoning evaluation."""
381
  # Basic reasoning tasks
382
  tasks = [
383
  {
384
- 'question': 'If all birds can fly and a penguin is a bird, can a penguin fly?',
385
- 'expected': 'no' # This tests if model knows real-world facts
386
  },
387
  {
388
- 'question': 'If it is raining outside, should you take an umbrella?',
389
- 'expected': 'yes'
390
  },
391
- {
392
- 'question': 'What comes after Monday?',
393
- 'expected': 'tuesday'
394
- },
395
- {
396
- 'question': 'Is the sun larger than the earth?',
397
- 'expected': 'yes'
398
- }
399
  ]
400
-
401
  correct = 0
402
  total = len(tasks)
403
-
404
  for task in tasks:
405
  prompt = f"Question: {task['question']}\nAnswer:"
406
-
407
  # Generate answer
408
  input_ids = self.tokenizer.encode(prompt)
409
  input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
410
-
411
  with torch.no_grad():
412
  output = self.model.generate(input_tensor, max_new_tokens=10, temperature=0.1)
413
-
414
  generated_ids = output[0].tolist()
415
- answer = self.tokenizer.decode(generated_ids[len(input_ids):]).strip().lower()
416
-
417
  # Check if expected answer is in the response
418
- if task['expected'] in answer:
419
  correct += 1
420
-
421
  return {
422
- 'accuracy': correct / total,
423
- 'correct': correct,
424
- 'total': total,
425
- 'score': correct / total
426
  }
427
-
428
  def _evaluate_text_completion(self) -> Dict[str, Any]:
429
  """Evaluate text completion quality."""
430
  # Common phrases that should be completed predictably
431
  completions = [
432
- {'prompt': 'The capital of France is', 'expected_word': 'paris'},
433
- {'prompt': 'Two plus two equals', 'expected_word': 'four'},
434
- {'prompt': 'The largest planet in our solar system is', 'expected_word': 'jupiter'},
435
- {'prompt': 'Water boils at', 'expected_word': '100'}
436
  ]
437
-
438
  correct = 0
439
  total = len(completions)
440
-
441
  for completion in completions:
442
  # Generate completion
443
- input_ids = self.tokenizer.encode(completion['prompt'])
444
  input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
445
-
446
  with torch.no_grad():
447
  output = self.model.generate(input_tensor, max_new_tokens=5, temperature=0.1)
448
-
449
  generated_ids = output[0].tolist()
450
- generated_text = self.tokenizer.decode(generated_ids[len(input_ids):]).strip().lower()
451
-
452
  # Check if expected word appears in completion
453
- if completion['expected_word'] in generated_text:
454
  correct += 1
455
-
456
  return {
457
- 'accuracy': correct / total,
458
- 'correct': correct,
459
- 'total': total,
460
- 'score': correct / total
461
  }
462
-
463
  def run_comprehensive_evaluation(
464
- self,
465
- eval_data_path: str,
466
- metrics: List[str] = None,
467
- generation_prompts: List[str] = None
468
  ) -> Dict[str, Any]:
469
  """
470
  Run comprehensive model evaluation.
471
-
472
  Args:
473
  eval_data_path: Path to evaluation text file
474
  metrics: List of metrics to compute
475
  generation_prompts: Prompts for text generation evaluation
476
-
477
  Returns:
478
  Complete evaluation results
479
  """
480
  if metrics is None:
481
- metrics = ['perplexity', 'loss', 'generation']
482
-
483
  if generation_prompts is None:
484
  generation_prompts = [
485
  "The history of artificial intelligence",
486
  "Machine learning algorithms",
487
  "The future of technology",
488
  "In a world where",
489
- "Scientists have discovered"
490
  ]
491
-
492
  results = {
493
- 'model_info': {
494
- 'parameters': self.model.get_num_params(),
495
- 'device': self.device,
496
- 'vocab_size': self.tokenizer.vocab_size()
497
  },
498
- 'evaluation_timestamp': time.time()
499
  }
500
-
501
  # Load evaluation data
502
  print(f"πŸ“‚ Loading evaluation data from {eval_data_path}")
503
  if os.path.exists(eval_data_path):
504
- with open(eval_data_path, 'r', encoding='utf-8') as f:
505
  eval_texts = [line.strip() for line in f if line.strip()]
506
  else:
507
  print(f"⚠️ Evaluation file not found, using sample texts")
@@ -510,96 +496,103 @@ class ModelEvaluator:
510
  "Machine learning algorithms can learn patterns from data automatically.",
511
  "Natural language processing helps computers understand human language.",
512
  "Deep learning uses neural networks with multiple layers for complex tasks.",
513
- "The development of large language models has transformed AI applications."
514
  ]
515
-
516
  # Intrinsic evaluation
517
- if 'perplexity' in metrics or 'loss' in metrics:
518
  perplexity_results = self.evaluate_perplexity(eval_texts)
519
- results['intrinsic_evaluation'] = perplexity_results
520
-
521
  # Text generation evaluation
522
- if 'generation' in metrics:
523
  generation_results = self.evaluate_text_generation(generation_prompts)
524
- results['generation_evaluation'] = {
525
- 'results': generation_results,
526
- 'summary': self._summarize_generation_results(generation_results)
527
  }
528
-
529
  # Downstream tasks (placeholder)
530
- results['downstream_evaluation'] = self.evaluate_downstream_tasks()
531
-
532
  # Overall quality assessment
533
- results['quality_assessment'] = self._assess_overall_quality(results)
534
-
535
  return results
536
-
537
  def _summarize_generation_results(self, results: List[Dict[str, Any]]) -> Dict[str, float]:
538
  """Summarize text generation results."""
539
  if not results:
540
  return {}
541
-
542
- total_time = sum(r['generation_time'] for r in results)
543
- total_tokens = sum(r['tokens_generated'] for r in results)
544
-
545
- quality_metrics = [r['quality_metrics'] for r in results]
546
-
547
  return {
548
- 'avg_generation_time': total_time / len(results),
549
- 'avg_tokens_per_second': total_tokens / total_time if total_time > 0 else 0,
550
- 'avg_length': sum(q['length'] for q in quality_metrics) / len(quality_metrics),
551
- 'avg_repetition_rate': sum(q['repetition_rate'] for q in quality_metrics) / len(quality_metrics),
552
- 'avg_coherence_score': sum(q['coherence_score'] for q in quality_metrics) / len(quality_metrics)
553
  }
554
-
555
  def _assess_overall_quality(self, results: Dict[str, Any]) -> Dict[str, Any]:
556
  """Assess overall model quality based on evaluation results."""
557
- assessment = {
558
- 'quality_level': 'unknown',
559
- 'recommendations': []
560
- }
561
-
562
  # Check intrinsic metrics
563
- if 'intrinsic_evaluation' in results:
564
- perplexity = results['intrinsic_evaluation'].get('perplexity', float('inf'))
565
-
566
  if perplexity < 12:
567
- assessment['quality_level'] = 'good'
568
- assessment['recommendations'].append('Model shows good perplexity scores')
569
  elif perplexity < 50:
570
- assessment['quality_level'] = 'fair'
571
- assessment['recommendations'].append('Model shows fair performance, could benefit from more training')
572
  else:
573
- assessment['quality_level'] = 'poor'
574
- assessment['recommendations'].append('Model needs significantly more training or data improvements')
575
-
576
  # Check generation quality
577
- if 'generation_evaluation' in results:
578
- summary = results['generation_evaluation'].get('summary', {})
579
- repetition_rate = summary.get('avg_repetition_rate', 1.0)
580
- coherence_score = summary.get('avg_coherence_score', 0.0)
581
-
582
  if repetition_rate > 0.7:
583
- assessment['recommendations'].append('High repetition rate - consider training longer or adjusting data')
 
 
584
  if coherence_score < 0.3:
585
- assessment['recommendations'].append('Low coherence - model may need more training steps')
586
-
587
  return assessment
588
 
589
 
590
  def load_model_from_directory(model_dir: str, device: str = "cpu") -> Tuple[GPTModel, str]:
591
  """
592
  Load model from directory containing checkpoints.
593
-
594
  Args:
595
  model_dir: Directory containing model files
596
  device: Device to load model on
597
-
598
  Returns:
599
  Tuple of (model, tokenizer_path)
600
  """
601
  model_dir = Path(model_dir)
602
-
603
  # Find best model checkpoint
604
  best_model_path = model_dir / "best_model.pt"
605
  if not best_model_path.exists():
@@ -607,41 +600,41 @@ def load_model_from_directory(model_dir: str, device: str = "cpu") -> Tuple[GPTM
607
  checkpoints = list(model_dir.glob("checkpoint_step_*.pt"))
608
  if not checkpoints:
609
  raise FileNotFoundError(f"No model checkpoints found in {model_dir}")
610
-
611
  # Get latest checkpoint
612
- latest_checkpoint = max(checkpoints, key=lambda p: int(p.stem.split('_')[-1]))
613
  best_model_path = latest_checkpoint
614
-
615
  print(f"πŸ“‚ Loading model from {best_model_path}")
616
-
617
  # Load checkpoint
618
  checkpoint = torch.load(best_model_path, map_location=device)
619
-
620
  # Determine model size from config
621
- config = checkpoint.get('config', {})
622
- n_layer = config.get('n_layer', 12)
623
-
624
  if n_layer <= 6:
625
  model_size = "small"
626
  elif n_layer <= 12:
627
  model_size = "medium"
628
  else:
629
  model_size = "large"
630
-
631
  # Create and load model
632
  model = create_model(model_size)
633
- model.load_state_dict(checkpoint['model_state_dict'])
634
-
635
  print(f"βœ… Model loaded successfully ({model_size}, {model.get_num_params():,} parameters)")
636
-
637
  # Find tokenizer
638
  tokenizer_path = model_dir.parent / "tokenizer" / "tokenizer.model"
639
  if not tokenizer_path.exists():
640
  tokenizer_path = Path("data/tokenizer/tokenizer.model")
641
-
642
  if not tokenizer_path.exists():
643
  raise FileNotFoundError(f"Tokenizer not found at {tokenizer_path}")
644
-
645
  return model, str(tokenizer_path)
646
 
647
 
@@ -662,121 +655,115 @@ Examples:
662
  --model_dir models/small-extended-4k \\
663
  --metrics perplexity,generation \\
664
  --output results.json
665
- """
666
  )
667
-
668
- parser.add_argument(
669
- "--model_dir",
670
- required=True,
671
- help="Directory containing trained model"
672
- )
673
-
674
  parser.add_argument(
675
- "--eval_data",
676
- help="Path to evaluation text file (default: use sample texts)"
677
  )
678
-
679
  parser.add_argument(
680
  "--metrics",
681
  default="perplexity,loss,generation",
682
- help="Comma-separated list of metrics to evaluate (default: perplexity,loss,generation)"
683
  )
684
-
685
- parser.add_argument(
686
- "--output",
687
- help="Output JSON file for results (default: print to console)"
688
- )
689
-
690
  parser.add_argument(
691
  "--device",
692
  choices=["cpu", "cuda", "auto"],
693
  default="auto",
694
- help="Device for evaluation (default: auto)"
695
  )
696
-
697
  parser.add_argument(
698
- "--generation_prompts",
699
- help="File containing prompts for text generation evaluation"
700
  )
701
-
702
  args = parser.parse_args()
703
-
704
  print("πŸ“Š OpenLLM Model Evaluation")
705
  print("=" * 50)
706
-
707
  # Determine device
708
  if args.device == "auto":
709
  device = "cuda" if torch.cuda.is_available() else "cpu"
710
  else:
711
  device = args.device
712
-
713
  print(f"Using device: {device}")
714
-
715
  try:
716
  # Load model
717
  model, tokenizer_path = load_model_from_directory(args.model_dir, device)
718
-
719
  # Create evaluator
720
  evaluator = ModelEvaluator(model, tokenizer_path, device)
721
-
722
  # Parse metrics
723
- metrics = [m.strip() for m in args.metrics.split(',')]
724
-
725
  # Load generation prompts if specified
726
  generation_prompts = None
727
  if args.generation_prompts and os.path.exists(args.generation_prompts):
728
- with open(args.generation_prompts, 'r', encoding='utf-8') as f:
729
  generation_prompts = [line.strip() for line in f if line.strip()]
730
-
731
  # Run evaluation
732
  eval_data_path = args.eval_data or "data/clean/training_data.txt"
733
  results = evaluator.run_comprehensive_evaluation(
734
  eval_data_path, metrics, generation_prompts
735
  )
736
-
737
  # Output results
738
  if args.output:
739
- with open(args.output, 'w', encoding='utf-8') as f:
740
  json.dump(results, f, indent=2)
741
  print(f"\nπŸ’Ύ Results saved to {args.output}")
742
  else:
743
  print(f"\nπŸ“Š Evaluation Results:")
744
  print("=" * 50)
745
-
746
  # Print key metrics
747
- if 'intrinsic_evaluation' in results:
748
- intrinsic = results['intrinsic_evaluation']
749
  print(f"πŸ“ˆ Intrinsic Metrics:")
750
  print(f" Loss: {intrinsic['loss']:.4f}")
751
  print(f" Perplexity: {intrinsic['perplexity']:.2f}")
752
  print(f" Sequences evaluated: {intrinsic['num_sequences']:,}")
753
-
754
- if 'generation_evaluation' in results:
755
- gen_summary = results['generation_evaluation']['summary']
756
  print(f"\n✍️ Generation Quality:")
757
- print(f" Avg generation speed: {gen_summary['avg_tokens_per_second']:.1f} tokens/sec")
 
 
758
  print(f" Avg text length: {gen_summary['avg_length']:.1f} words")
759
  print(f" Repetition rate: {gen_summary['avg_repetition_rate']:.3f}")
760
  print(f" Coherence score: {gen_summary['avg_coherence_score']:.3f}")
761
-
762
  # Quality assessment
763
- if 'quality_assessment' in results:
764
- assessment = results['quality_assessment']
765
  print(f"\n🎯 Overall Assessment:")
766
  print(f" Quality Level: {assessment['quality_level'].upper()}")
767
- for rec in assessment['recommendations']:
768
  print(f" β€’ {rec}")
769
-
770
  print(f"\nπŸŽ‰ Evaluation completed successfully!")
771
-
772
  except Exception as e:
773
  print(f"\n❌ Evaluation failed: {e}")
774
  import traceback
775
  traceback.print_exc()
776
  return False
777
-
778
  return True
779
 
780
 
781
  if __name__ == "__main__":
782
- main()
 
56
  class ModelEvaluator:
57
  """
58
  Comprehensive evaluator for OpenLLM models.
59
+
60
  Implements intrinsic evaluation metrics and text generation quality
61
  assessment following the training pipeline specifications.
62
  """
63
+
64
+ def __init__(self, model: GPTModel, tokenizer_path: str, device: str = "cpu"):
65
  """
66
  Initialize the model evaluator.
67
+
68
  Args:
69
  model: Trained GPT model
70
  tokenizer_path: Path to tokenizer model file
 
72
  """
73
  self.model = model.to(device)
74
  self.device = device
75
+
76
  # Load tokenizer
77
  self.tokenizer = smp.SentencePieceProcessor()
78
  self.tokenizer.load(tokenizer_path)
79
+
80
  print(f"πŸ”§ ModelEvaluator initialized")
81
  print(f" Device: {device}")
82
  print(f" Model parameters: {model.get_num_params():,}")
83
  print(f" Vocabulary size: {self.tokenizer.vocab_size():,}")
84
+
85
  def evaluate_perplexity(
86
+ self, eval_data: List[str], max_seq_len: int = 512, batch_size: int = 1
87
  ) -> Dict[str, float]:
88
  """
89
  Calculate perplexity on evaluation data.
90
+
91
  Args:
92
  eval_data: List of text passages for evaluation
93
  max_seq_len: Maximum sequence length for evaluation
94
  batch_size: Batch size for evaluation
95
+
96
  Returns:
97
  Dictionary with loss and perplexity metrics
98
  """
 
100
  total_loss = 0.0
101
  total_tokens = 0
102
  num_sequences = 0
103
+
104
  print(f"πŸ“Š Calculating perplexity on {len(eval_data)} passages...")
105
+
106
  with torch.no_grad():
107
  for i, text in enumerate(eval_data):
108
  if i % 100 == 0:
109
  print(f" Progress: {i}/{len(eval_data)} passages")
110
+
111
  # Tokenize text
112
  tokens = self.tokenizer.encode(text)
113
  if len(tokens) < 2:
114
  continue
115
+
116
  # Truncate if too long
117
  if len(tokens) > max_seq_len:
118
  tokens = tokens[:max_seq_len]
119
+
120
  # Create input and target tensors
121
  input_ids = torch.tensor([tokens[:-1]], dtype=torch.long, device=self.device)
122
  target_ids = torch.tensor([tokens[1:]], dtype=torch.long, device=self.device)
123
+
124
  # Forward pass
125
  logits, loss = self.model(input_ids, target_ids)
126
+
127
  # Accumulate loss
128
  seq_length = len(tokens) - 1
129
  total_loss += loss.item() * seq_length
130
  total_tokens += seq_length
131
  num_sequences += 1
132
+
133
  # Calculate metrics
134
+ avg_loss = total_loss / total_tokens if total_tokens > 0 else float("inf")
135
  perplexity = math.exp(min(avg_loss, 10)) # Cap to prevent overflow
136
+
137
  return {
138
+ "loss": avg_loss,
139
+ "perplexity": perplexity,
140
+ "total_tokens": total_tokens,
141
+ "num_sequences": num_sequences,
142
  }
143
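Perplexity here is the exponential of the token-weighted mean loss, and the exponent cap keeps untrained models from overflowing; a standalone sketch with made-up per-sequence losses:

import math

per_sequence = [(3.2, 120), (2.9, 95), (3.5, 210)]  # (loss, target tokens)
total_loss = sum(loss * n for loss, n in per_sequence)
total_tokens = sum(n for _, n in per_sequence)
avg_loss = total_loss / total_tokens if total_tokens > 0 else float("inf")
perplexity = math.exp(min(avg_loss, 10))  # exp(10) is roughly 22,026
print(f"avg loss {avg_loss:.4f} -> perplexity {perplexity:.2f}")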
+
144
  def evaluate_text_generation(
145
  self,
146
  prompts: List[str],
147
  max_length: int = 256,
148
  temperature: float = 0.7,
149
  top_k: Optional[int] = 40,
150
+ num_samples: int = 1,
151
  ) -> List[Dict[str, Any]]:
152
  """
153
  Evaluate text generation quality.
154
+
155
  Args:
156
  prompts: List of input prompts
157
  max_length: Maximum generation length
158
  temperature: Sampling temperature
159
  top_k: Top-k sampling parameter
160
  num_samples: Number of samples per prompt
161
+
162
  Returns:
163
  List of generation results with quality metrics
164
  """
165
  self.model.eval()
166
  results = []
167
+
168
  print(f"✍️ Evaluating text generation on {len(prompts)} prompts...")
169
+
170
  with torch.no_grad():
171
  for prompt in prompts:
172
  prompt_results = []
173
+
174
  for sample_idx in range(num_samples):
175
  # Tokenize prompt
176
  input_ids = self.tokenizer.encode(prompt)
177
  input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
178
+
179
  start_time = time.time()
180
+
181
  # Generate text
182
  output = self.model.generate(
183
  input_tensor,
184
  max_new_tokens=max_length,
185
  temperature=temperature,
186
+ top_k=top_k,
187
  )
188
+
189
  generation_time = time.time() - start_time
190
+
191
  # Decode output
192
  generated_ids = output[0].tolist()
193
  full_text = self.tokenizer.decode(generated_ids)
194
+ generated_text = self.tokenizer.decode(generated_ids[len(input_ids) :])
195
+
196
  # Calculate quality metrics
197
  quality_metrics = self._assess_generation_quality(generated_text)
198
+
199
+ prompt_results.append(
200
+ {
201
+ "prompt": prompt,
202
+ "generated_text": generated_text,
203
+ "full_text": full_text,
204
+ "generation_time": generation_time,
205
+ "tokens_generated": len(generated_ids) - len(input_ids),
206
+ "tokens_per_second": (len(generated_ids) - len(input_ids))
207
+ / generation_time,
208
+ "quality_metrics": quality_metrics,
209
+ }
210
+ )
211
+
212
  results.extend(prompt_results)
213
+
214
  return results
215
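The tokens-per-second figure recorded above is raw token count over wall-clock time; note that the division can fail if generation returns instantly, so a small guard is worth considering. A sketch with a stubbed generator:

import time

def fake_generate(n_tokens):
    # Stub standing in for model.generate(); sleep simulates decode work.
    time.sleep(0.001 * n_tokens)
    return list(range(n_tokens))

start = time.time()
out = fake_generate(32)
elapsed = max(time.time() - start, 1e-9)  # guard against zero elapsed time
print(f"{len(out) / elapsed:.1f} tokens/sec")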
+
216
  def _assess_generation_quality(self, text: str) -> Dict[str, float]:
217
  """
218
  Assess basic quality metrics for generated text.
219
+
220
  Args:
221
  text: Generated text to assess
222
+
223
  Returns:
224
  Dictionary of quality metrics
225
  """
226
  if not text.strip():
227
  return {
228
+ "length": 0,
229
+ "avg_word_length": 0,
230
+ "repetition_rate": 1.0,
231
+ "coherence_score": 0.0,
232
  }
233
+
234
  words = text.split()
235
+
236
  # Basic metrics
237
  length = len(words)
238
  avg_word_length = sum(len(word) for word in words) / len(words) if words else 0
239
+
240
  # Repetition rate (simple n-gram repetition)
241
+ bigrams = [f"{words[i]} {words[i+1]}" for i in range(len(words) - 1)]
242
  unique_bigrams = len(set(bigrams))
243
  repetition_rate = 1 - (unique_bigrams / len(bigrams) if bigrams else 0)
244
+
245
  # Simple coherence score (based on sentence structure)
246
+ sentences = text.split(".")
247
  valid_sentences = [s for s in sentences if len(s.strip().split()) > 3]
248
  coherence_score = len(valid_sentences) / len(sentences) if sentences else 0
249
+
250
  return {
251
+ "length": length,
252
+ "avg_word_length": avg_word_length,
253
+ "repetition_rate": repetition_rate,
254
+ "coherence_score": coherence_score,
255
  }
256
+
257
  def evaluate_downstream_tasks(self) -> Dict[str, Any]:
258
  """
259
  Evaluate model performance on downstream tasks.
260
+
261
  This function implements basic downstream task evaluation including:
262
  - Reading comprehension (simplified SQUAD-style)
263
  - Sentiment analysis (few-shot)
264
  - Common sense reasoning
265
+
266
  Returns:
267
  Dictionary of downstream task results
268
  """
269
  results = {}
270
+
271
  # 1. Reading Comprehension (Simplified SQUAD-style)
272
+ results["reading_comprehension"] = self._evaluate_reading_comprehension()
273
+
274
  # 2. Sentiment Analysis (Few-shot learning)
275
+ results["sentiment_analysis"] = self._evaluate_sentiment_analysis()
276
+
277
  # 3. Common Sense Reasoning
278
+ results["reasoning"] = self._evaluate_reasoning()
279
+
280
  # 4. Text Completion Quality
281
+ results["text_completion"] = self._evaluate_text_completion()
282
+
283
  return results
284
+
285
  def _evaluate_reading_comprehension(self) -> Dict[str, Any]:
286
  """Simplified reading comprehension evaluation."""
287
  # Sample reading comprehension tasks
288
  tasks = [
289
  {
290
+ "context": "The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, France. It is named after the engineer Gustave Eiffel, whose company designed and built the tower.",
291
+ "question": "Who is the Eiffel Tower named after?",
292
+ "expected": "Gustave Eiffel",
293
  },
294
  {
295
+ "context": "Python is a high-level programming language. It was created by Guido van Rossum and first released in 1991.",
296
+ "question": "When was Python first released?",
297
+ "expected": "1991",
298
  },
299
  {
300
+ "context": "Machine learning is a subset of artificial intelligence that enables computers to learn without being explicitly programmed.",
301
+ "question": "What is machine learning a subset of?",
302
+ "expected": "artificial intelligence",
303
+ },
304
  ]
305
+
306
  correct = 0
307
  total = len(tasks)
308
+
309
  for task in tasks:
310
  prompt = f"Context: {task['context']}\nQuestion: {task['question']}\nAnswer:"
311
+
312
  # Generate answer
313
  input_ids = self.tokenizer.encode(prompt)
314
  input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
315
+
316
  with torch.no_grad():
317
  output = self.model.generate(input_tensor, max_new_tokens=20, temperature=0.1)
318
+
319
  generated_ids = output[0].tolist()
320
+ answer = self.tokenizer.decode(generated_ids[len(input_ids) :]).strip().lower()
321
+
322
  # Simple substring matching
323
+ if task["expected"].lower() in answer:
324
  correct += 1
325
+
326
  return {
327
+ "accuracy": correct / total,
328
+ "correct": correct,
329
+ "total": total,
330
+ "score": correct / total,
331
  }
332
+
333
  def _evaluate_sentiment_analysis(self) -> Dict[str, Any]:
334
  """Few-shot sentiment analysis evaluation."""
335
  # Few-shot examples
336
  examples = "Examples:\nText: 'I love this movie!' Sentiment: Positive\nText: 'This is terrible.' Sentiment: Negative\nText: 'It was okay.' Sentiment: Neutral\n\n"
337
+
338
  # Test cases
339
  test_cases = [
340
+ {"text": "This is amazing!", "expected": "positive"},
341
+ {"text": "I hate this.", "expected": "negative"},
342
+ {"text": "This is wonderful.", "expected": "positive"},
343
+ {"text": "This is awful.", "expected": "negative"},
344
+ {"text": "It was fine.", "expected": "neutral"},
345
  ]
346
+
347
  correct = 0
348
  total = len(test_cases)
349
+
350
  for case in test_cases:
351
  prompt = f"{examples}Text: '{case['text']}' Sentiment:"
352
+
353
  # Generate sentiment
354
  input_ids = self.tokenizer.encode(prompt)
355
  input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
356
+
357
  with torch.no_grad():
358
  output = self.model.generate(input_tensor, max_new_tokens=5, temperature=0.1)
359
+
360
  generated_ids = output[0].tolist()
361
+ sentiment = self.tokenizer.decode(generated_ids[len(input_ids) :]).strip().lower()
362
+
363
  # Check if expected sentiment is in the generated response
364
+ if case["expected"] in sentiment:
365
  correct += 1
366
+
367
  return {
368
+ "accuracy": correct / total,
369
+ "correct": correct,
370
+ "total": total,
371
+ "score": correct / total,
372
  }
373
+
374
  def _evaluate_reasoning(self) -> Dict[str, Any]:
375
  """Simple reasoning evaluation."""
376
  # Basic reasoning tasks
377
  tasks = [
378
  {
379
+ "question": "If all birds can fly and a penguin is a bird, can a penguin fly?",
380
+ "expected": "no", # Tests whether the model overrides the faulty premise with real-world knowledge
381
  },
382
  {
383
+ "question": "If it is raining outside, should you take an umbrella?",
384
+ "expected": "yes",
385
  },
386
+ {"question": "What comes after Monday?", "expected": "tuesday"},
387
+ {"question": "Is the sun larger than the earth?", "expected": "yes"},
388
  ]
389
+
390
  correct = 0
391
  total = len(tasks)
392
+
393
  for task in tasks:
394
  prompt = f"Question: {task['question']}\nAnswer:"
395
+
396
  # Generate answer
397
  input_ids = self.tokenizer.encode(prompt)
398
  input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
399
+
400
  with torch.no_grad():
401
  output = self.model.generate(input_tensor, max_new_tokens=10, temperature=0.1)
402
+
403
  generated_ids = output[0].tolist()
404
+ answer = self.tokenizer.decode(generated_ids[len(input_ids) :]).strip().lower()
405
+
406
  # Check if expected answer is in the response
407
+ if task["expected"] in answer:
408
  correct += 1
409
+
410
  return {
411
+ "accuracy": correct / total,
412
+ "correct": correct,
413
+ "total": total,
414
+ "score": correct / total,
415
  }
416
+
417
  def _evaluate_text_completion(self) -> Dict[str, Any]:
418
  """Evaluate text completion quality."""
419
  # Common phrases that should be completed predictably
420
  completions = [
421
+ {"prompt": "The capital of France is", "expected_word": "paris"},
422
+ {"prompt": "Two plus two equals", "expected_word": "four"},
423
+ {"prompt": "The largest planet in our solar system is", "expected_word": "jupiter"},
424
+ {"prompt": "Water boils at", "expected_word": "100"},
425
  ]
426
+
427
  correct = 0
428
  total = len(completions)
429
+
430
  for completion in completions:
431
  # Generate completion
432
+ input_ids = self.tokenizer.encode(completion["prompt"])
433
  input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
434
+
435
  with torch.no_grad():
436
  output = self.model.generate(input_tensor, max_new_tokens=5, temperature=0.1)
437
+
438
  generated_ids = output[0].tolist()
439
+ generated_text = self.tokenizer.decode(generated_ids[len(input_ids) :]).strip().lower()
440
+
441
  # Check if expected word appears in completion
442
+ if completion["expected_word"] in generated_text:
443
  correct += 1
444
+
445
  return {
446
+ "accuracy": correct / total,
447
+ "correct": correct,
448
+ "total": total,
449
+ "score": correct / total,
450
  }
451
+
452
  def run_comprehensive_evaluation(
453
+ self, eval_data_path: str, metrics: List[str] = None, generation_prompts: List[str] = None
454
  ) -> Dict[str, Any]:
455
  """
456
  Run comprehensive model evaluation.
457
+
458
  Args:
459
  eval_data_path: Path to evaluation text file
460
  metrics: List of metrics to compute
461
  generation_prompts: Prompts for text generation evaluation
462
+
463
  Returns:
464
  Complete evaluation results
465
  """
466
  if metrics is None:
467
+ metrics = ["perplexity", "loss", "generation"]
468
+
469
  if generation_prompts is None:
470
  generation_prompts = [
471
  "The history of artificial intelligence",
472
  "Machine learning algorithms",
473
  "The future of technology",
474
  "In a world where",
475
+ "Scientists have discovered",
476
  ]
477
+
478
  results = {
479
+ "model_info": {
480
+ "parameters": self.model.get_num_params(),
481
+ "device": self.device,
482
+ "vocab_size": self.tokenizer.vocab_size(),
483
  },
484
+ "evaluation_timestamp": time.time(),
485
  }
486
+
487
  # Load evaluation data
488
  print(f"πŸ“‚ Loading evaluation data from {eval_data_path}")
489
  if os.path.exists(eval_data_path):
490
+ with open(eval_data_path, "r", encoding="utf-8") as f:
491
  eval_texts = [line.strip() for line in f if line.strip()]
492
  else:
493
  print(f"⚠️ Evaluation file not found, using sample texts")
 
496
  "Machine learning algorithms can learn patterns from data automatically.",
497
  "Natural language processing helps computers understand human language.",
498
  "Deep learning uses neural networks with multiple layers for complex tasks.",
499
+ "The development of large language models has transformed AI applications.",
500
  ]
501
+
502
  # Intrinsic evaluation
503
+ if "perplexity" in metrics or "loss" in metrics:
504
  perplexity_results = self.evaluate_perplexity(eval_texts)
505
+ results["intrinsic_evaluation"] = perplexity_results
506
+
507
  # Text generation evaluation
508
+ if "generation" in metrics:
509
  generation_results = self.evaluate_text_generation(generation_prompts)
510
+ results["generation_evaluation"] = {
511
+ "results": generation_results,
512
+ "summary": self._summarize_generation_results(generation_results),
513
  }
514
+
515
  # Downstream task evaluation
516
+ results["downstream_evaluation"] = self.evaluate_downstream_tasks()
517
+
518
  # Overall quality assessment
519
+ results["quality_assessment"] = self._assess_overall_quality(results)
520
+
521
  return results
522
+
523
  def _summarize_generation_results(self, results: List[Dict[str, Any]]) -> Dict[str, float]:
524
  """Summarize text generation results."""
525
  if not results:
526
  return {}
527
+
528
+ total_time = sum(r["generation_time"] for r in results)
529
+ total_tokens = sum(r["tokens_generated"] for r in results)
530
+
531
+ quality_metrics = [r["quality_metrics"] for r in results]
532
+
533
  return {
534
+ "avg_generation_time": total_time / len(results),
535
+ "avg_tokens_per_second": total_tokens / total_time if total_time > 0 else 0,
536
+ "avg_length": sum(q["length"] for q in quality_metrics) / len(quality_metrics),
537
+ "avg_repetition_rate": sum(q["repetition_rate"] for q in quality_metrics)
538
+ / len(quality_metrics),
539
+ "avg_coherence_score": sum(q["coherence_score"] for q in quality_metrics)
540
+ / len(quality_metrics),
541
  }
542
+
543
  def _assess_overall_quality(self, results: Dict[str, Any]) -> Dict[str, Any]:
544
  """Assess overall model quality based on evaluation results."""
545
+ assessment = {"quality_level": "unknown", "recommendations": []}
546
+
547
  # Check intrinsic metrics
548
+ if "intrinsic_evaluation" in results:
549
+ perplexity = results["intrinsic_evaluation"].get("perplexity", float("inf"))
550
+
551
  if perplexity < 12:
552
+ assessment["quality_level"] = "good"
553
+ assessment["recommendations"].append("Model shows good perplexity scores")
554
  elif perplexity < 50:
555
+ assessment["quality_level"] = "fair"
556
+ assessment["recommendations"].append(
557
+ "Model shows fair performance and could benefit from more training"
558
+ )
559
  else:
560
+ assessment["quality_level"] = "poor"
561
+ assessment["recommendations"].append(
562
+ "Model needs significantly more training or improved training data"
563
+ )
564
+
565
  # Check generation quality
566
+ if "generation_evaluation" in results:
567
+ summary = results["generation_evaluation"].get("summary", {})
568
+ repetition_rate = summary.get("avg_repetition_rate", 1.0)
569
+ coherence_score = summary.get("avg_coherence_score", 0.0)
570
+
571
  if repetition_rate > 0.7:
572
+ assessment["recommendations"].append(
573
+ "High repetition rate - consider training longer or adjusting data"
574
+ )
575
  if coherence_score < 0.3:
576
+ assessment["recommendations"].append(
577
+ "Low coherence - model may need more training steps"
578
+ )
579
+
580
  return assessment
581
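Since perplexity is the exponential of the mean cross-entropy loss, the thresholds above map directly onto loss values; a quick illustrative check:

import math

# Illustrative only: perplexity < 12 corresponds to loss < ln(12), about 2.48,
# and perplexity < 50 to loss < ln(50), about 3.91.
for loss in (2.0, 2.5, 3.5, 4.5):
    print(f"loss={loss:.2f} -> perplexity={math.exp(loss):.1f}")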
 
582
 
583
  def load_model_from_directory(model_dir: str, device: str = "cpu") -> Tuple[GPTModel, str]:
584
  """
585
  Load model from directory containing checkpoints.
586
+
587
  Args:
588
  model_dir: Directory containing model files
589
  device: Device to load model on
590
+
591
  Returns:
592
  Tuple of (model, tokenizer_path)
593
  """
594
  model_dir = Path(model_dir)
595
+
596
  # Find best model checkpoint
597
  best_model_path = model_dir / "best_model.pt"
598
  if not best_model_path.exists():
 
600
  checkpoints = list(model_dir.glob("checkpoint_step_*.pt"))
601
  if not checkpoints:
602
  raise FileNotFoundError(f"No model checkpoints found in {model_dir}")
603
+
604
  # Get latest checkpoint
605
+ latest_checkpoint = max(checkpoints, key=lambda p: int(p.stem.split("_")[-1]))
606
  best_model_path = latest_checkpoint
607
+
608
  print(f"πŸ“‚ Loading model from {best_model_path}")
609
+
610
  # Load checkpoint
611
  checkpoint = torch.load(best_model_path, map_location=device)
612
+
613
  # Determine model size from config
614
+ config = checkpoint.get("config", {})
615
+ n_layer = config.get("n_layer", 12)
616
+
617
  if n_layer <= 6:
618
  model_size = "small"
619
  elif n_layer <= 12:
620
  model_size = "medium"
621
  else:
622
  model_size = "large"
623
+
624
  # Create and load model
625
  model = create_model(model_size)
626
+ model.load_state_dict(checkpoint["model_state_dict"])
627
+
628
  print(f"βœ… Model loaded successfully ({model_size}, {model.get_num_params():,} parameters)")
629
+
630
  # Find tokenizer
631
  tokenizer_path = model_dir.parent / "tokenizer" / "tokenizer.model"
632
  if not tokenizer_path.exists():
633
  tokenizer_path = Path("data/tokenizer/tokenizer.model")
634
+
635
  if not tokenizer_path.exists():
636
  raise FileNotFoundError(f"Tokenizer not found at {tokenizer_path}")
637
+
638
  return model, str(tokenizer_path)
639
 
640
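Putting the pieces together programmatically, a minimal sketch; the model directory below is a placeholder for wherever your checkpoints actually live:

# Hedged end-to-end sketch; "models/small-extended-4k" is a placeholder path.
model, tokenizer_path = load_model_from_directory("models/small-extended-4k", device="cpu")
evaluator = ModelEvaluator(model, tokenizer_path, "cpu")
results = evaluator.run_comprehensive_evaluation(
    eval_data_path="data/clean/training_data.txt",  # falls back to sample texts if missing
    metrics=["perplexity", "generation"],
)
print(results["quality_assessment"]["quality_level"])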
 
 
655
  --model_dir models/small-extended-4k \\
656
  --metrics perplexity,generation \\
657
  --output results.json
658
+ """,
659
  )
660
+
661
+ parser.add_argument("--model_dir", required=True, help="Directory containing trained model")
662
+
663
  parser.add_argument(
664
+ "--eval_data", help="Path to evaluation text file (default: use sample texts)"
665
  )
666
+
667
  parser.add_argument(
668
  "--metrics",
669
  default="perplexity,loss,generation",
670
+ help="Comma-separated list of metrics to evaluate (default: perplexity,loss,generation)",
671
  )
672
+
673
+ parser.add_argument("--output", help="Output JSON file for results (default: print to console)")
674
+
675
  parser.add_argument(
676
  "--device",
677
  choices=["cpu", "cuda", "auto"],
678
  default="auto",
679
+ help="Device for evaluation (default: auto)",
680
  )
681
+
682
  parser.add_argument(
683
+ "--generation_prompts", help="File containing prompts for text generation evaluation"
684
  )
685
+
686
  args = parser.parse_args()
687
+
688
  print("πŸ“Š OpenLLM Model Evaluation")
689
  print("=" * 50)
690
+
691
  # Determine device
692
  if args.device == "auto":
693
  device = "cuda" if torch.cuda.is_available() else "cpu"
694
  else:
695
  device = args.device
696
+
697
  print(f"Using device: {device}")
698
+
699
  try:
700
  # Load model
701
  model, tokenizer_path = load_model_from_directory(args.model_dir, device)
702
+
703
  # Create evaluator
704
  evaluator = ModelEvaluator(model, tokenizer_path, device)
705
+
706
  # Parse metrics
707
+ metrics = [m.strip() for m in args.metrics.split(",")]
708
+
709
  # Load generation prompts if specified
710
  generation_prompts = None
711
  if args.generation_prompts and os.path.exists(args.generation_prompts):
712
+ with open(args.generation_prompts, "r", encoding="utf-8") as f:
713
  generation_prompts = [line.strip() for line in f if line.strip()]
714
+
715
  # Run evaluation
716
  eval_data_path = args.eval_data or "data/clean/training_data.txt"
717
  results = evaluator.run_comprehensive_evaluation(
718
  eval_data_path, metrics, generation_prompts
719
  )
720
+
721
  # Output results
722
  if args.output:
723
+ with open(args.output, "w", encoding="utf-8") as f:
724
  json.dump(results, f, indent=2)
725
  print(f"\nπŸ’Ύ Results saved to {args.output}")
726
  else:
727
  print(f"\nπŸ“Š Evaluation Results:")
728
  print("=" * 50)
729
+
730
  # Print key metrics
731
+ if "intrinsic_evaluation" in results:
732
+ intrinsic = results["intrinsic_evaluation"]
733
  print(f"πŸ“ˆ Intrinsic Metrics:")
734
  print(f" Loss: {intrinsic['loss']:.4f}")
735
  print(f" Perplexity: {intrinsic['perplexity']:.2f}")
736
  print(f" Sequences evaluated: {intrinsic['num_sequences']:,}")
737
+
738
+ if "generation_evaluation" in results:
739
+ gen_summary = results["generation_evaluation"]["summary"]
740
  print(f"\n✍️ Generation Quality:")
741
+ print(
742
+ f" Avg generation speed: {gen_summary['avg_tokens_per_second']:.1f} tokens/sec"
743
+ )
744
  print(f" Avg text length: {gen_summary['avg_length']:.1f} words")
745
  print(f" Repetition rate: {gen_summary['avg_repetition_rate']:.3f}")
746
  print(f" Coherence score: {gen_summary['avg_coherence_score']:.3f}")
747
+
748
  # Quality assessment
749
+ if "quality_assessment" in results:
750
+ assessment = results["quality_assessment"]
751
  print(f"\n🎯 Overall Assessment:")
752
  print(f" Quality Level: {assessment['quality_level'].upper()}")
753
+ for rec in assessment["recommendations"]:
754
  print(f" β€’ {rec}")
755
+
756
  print(f"\nπŸŽ‰ Evaluation completed successfully!")
757
+
758
  except Exception as e:
759
  print(f"\n❌ Evaluation failed: {e}")
760
  import traceback
761
+
762
  traceback.print_exc()
763
  return False
764
+
765
  return True
766
 
767
 
768
  if __name__ == "__main__":
769
+ main()
training/model.py CHANGED
@@ -18,7 +18,7 @@ language modeling (next-token prediction).
18
 
19
  ARCHITECTURE OVERVIEW:
20
  - Token Embedding: Maps token IDs to dense vectors
21
- - Positional Embedding: Adds position information to token embeddings
22
  - Transformer Blocks: Stack of multi-head attention + feed-forward layers
23
  - Layer Normalization: Pre-norm placement for training stability
24
  - Output Head: Linear projection to vocabulary for next-token prediction
@@ -32,16 +32,16 @@ FEATURES:
32
 
33
  Usage:
34
  from model import GPTConfig, GPTModel
35
-
36
  config = GPTConfig(vocab_size=32000, n_layer=12, n_head=12, n_embd=768)
37
  model = GPTModel(config)
38
-
39
  # Forward pass
40
  logits = model(input_ids) # Shape: (batch_size, seq_len, vocab_size)
41
 
42
  Hardware Requirements:
43
  - Small Model (25M params): 4-8GB RAM, CPU/integrated GPU
44
- - Medium Model (117M params): 8-16GB RAM, dedicated GPU recommended
45
  - Large Model (350M params): 16GB+ RAM, high-end GPU required
46
 
47
  Author: Louis Chua Bean Chong
@@ -60,43 +60,43 @@ from typing import Optional, Tuple
60
  class GPTConfig:
61
  """
62
  Configuration class for GPT model hyperparameters.
63
-
64
  This class defines all the architectural parameters needed to instantiate
65
  a GPT model. Use the provided class methods to get pre-configured setups
66
  for different model sizes.
67
  """
68
-
69
  # Model architecture
70
- vocab_size: int = 32000 # Vocabulary size (from tokenizer)
71
- n_layer: int = 12 # Number of transformer layers
72
- n_head: int = 12 # Number of attention heads
73
- n_embd: int = 768 # Embedding dimension
74
-
75
  # Sequence and context
76
- block_size: int = 1024 # Maximum sequence length
77
-
78
  # Training hyperparameters
79
- dropout: float = 0.1 # Dropout probability
80
- bias: bool = True # Use bias in linear layers
81
-
82
  # Model size identifier
83
- model_name: str = "gpt-medium" # Human-readable model identifier
84
-
85
  @classmethod
86
- def small(cls) -> 'GPTConfig':
87
  """Small model configuration (~25M parameters) - Good for CPU training"""
88
  return cls(
89
  vocab_size=32000,
90
  n_layer=6,
91
- n_head=8,
92
  n_embd=512,
93
  block_size=1024,
94
  dropout=0.1,
95
- model_name="gpt-small"
96
  )
97
-
98
- @classmethod
99
- def medium(cls) -> 'GPTConfig':
100
  """Medium model configuration (~117M parameters) - Balanced performance"""
101
  return cls(
102
  vocab_size=32000,
@@ -105,11 +105,11 @@ class GPTConfig:
105
  n_embd=768,
106
  block_size=2048,
107
  dropout=0.1,
108
- model_name="gpt-medium"
109
  )
110
-
111
  @classmethod
112
- def large(cls) -> 'GPTConfig':
113
  """Large model configuration (~350M parameters) - High performance"""
114
  return cls(
115
  vocab_size=32000,
@@ -118,29 +118,29 @@ class GPTConfig:
118
  n_embd=1024,
119
  block_size=2048,
120
  dropout=0.1,
121
- model_name="gpt-large"
122
  )
123
-
124
  def estimate_parameters(self) -> int:
125
  """
126
  Estimate the total number of trainable parameters.
127
-
128
  Returns:
129
  int: Estimated parameter count
130
  """
131
  # Token embeddings
132
  token_emb = self.vocab_size * self.n_embd
133
-
134
- # Position embeddings
135
  pos_emb = self.block_size * self.n_embd
136
-
137
  # Transformer layers
138
  # Each layer: attention (4 * n_embd^2) + mlp (8 * n_embd^2) + layer_norms
139
  layer_params = self.n_layer * (12 * self.n_embd**2 + 4 * self.n_embd)
140
-
141
  # Output head
142
  output_head = self.vocab_size * self.n_embd
143
-
144
  total = token_emb + pos_emb + layer_params + output_head
145
  return total
146
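As a sanity check on the formula, here is the arithmetic for the small configuration (vocab_size=32000, n_layer=6, n_embd=512, block_size=1024). Note the estimate counts the output head separately even though GPTModel ties it to the token embedding, so the trainable count is lower in practice:

# Worked example of estimate_parameters() for GPTConfig.small().
token_emb = 32000 * 512                     # 16,384,000
pos_emb = 1024 * 512                        #    524,288
layer_params = 6 * (12 * 512**2 + 4 * 512)  # 18,886,656
output_head = 32000 * 512                   # 16,384,000 (weight-tied in the model)
print(token_emb + pos_emb + layer_params + output_head)  # 52,178,944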
 
@@ -148,75 +148,78 @@ class GPTConfig:
148
  class CausalSelfAttention(nn.Module):
149
  """
150
  Multi-head causal self-attention mechanism.
151
-
152
  This implements the core attention mechanism of the transformer, with causal
153
  masking to ensure autoregressive behavior (tokens can only attend to previous
154
  tokens, not future ones).
155
  """
156
-
157
  def __init__(self, config: GPTConfig):
158
  super().__init__()
159
- assert config.n_embd % config.n_head == 0, "Embedding dim must be divisible by number of heads"
160
-
161
  self.config = config
162
  self.n_head = config.n_head
163
  self.n_embd = config.n_embd
164
  self.head_dim = self.n_embd // self.n_head
165
-
166
  # Key, query, value projections for all heads (batched)
167
  self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
168
-
169
  # Output projection
170
  self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
171
-
172
  # Dropout
173
  self.attn_dropout = nn.Dropout(config.dropout)
174
  self.resid_dropout = nn.Dropout(config.dropout)
175
-
176
  # Causal mask - lower triangular matrix
177
  self.register_buffer(
178
  "bias",
179
- torch.tril(torch.ones(config.block_size, config.block_size))
180
- .view(1, 1, config.block_size, config.block_size)
181
  )
182
-
183
  def forward(self, x: torch.Tensor) -> torch.Tensor:
184
  """
185
  Forward pass of causal self-attention.
186
-
187
  This method implements the scaled dot-product attention mechanism with causal masking.
188
  The attention mechanism allows each token to attend to all previous tokens in the sequence,
189
  but not to future tokens, maintaining the autoregressive property essential for language modeling.
190
-
191
  Mathematical formulation:
192
  Attention(Q, K, V) = softmax(QK^T / sqrt(d_k))V
193
  where Q, K, V are query, key, value matrices derived from input x
194
-
195
  Implementation details:
196
  - Uses batch matrix multiplication for efficiency
197
  - Applies causal mask to prevent future token attention
198
  - Implements multi-head attention by reshaping and parallel processing
199
  - Applies dropout for regularization during training
200
-
201
  Args:
202
  x: Input tensor of shape (batch_size, seq_len, n_embd)
203
  Contains embedded token representations from previous layer
204
-
205
  Returns:
206
  torch.Tensor: Output tensor of shape (batch_size, seq_len, n_embd)
207
  """
208
  # Extract tensor dimensions for clear variable naming and validation
209
  # B = batch size (number of sequences processed in parallel)
210
- # T = sequence length (number of tokens in each sequence)
211
  # C = embedding dimensionality (n_embd from config)
212
  B, T, C = x.size()
213
-
214
  # Generate query, key, and value projections for all attention heads
215
  # The c_attn linear layer outputs 3 * n_embd features, which we split into Q, K, V
216
  # This batched approach is more efficient than separate linear layers
217
  # Input shape: (B, T, C) -> Output shape: (B, T, 3*C) -> Split to 3x (B, T, C)
218
  q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
219
-
220
  # Reshape tensors for multi-head attention computation
221
  # Transform from (B, T, C) to (B, nh, T, hs) where:
222
  # - nh = number of heads (self.n_head)
@@ -225,41 +228,41 @@ class CausalSelfAttention(nn.Module):
225
  q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2) # (B, nh, T, hs)
226
  k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2) # (B, nh, T, hs)
227
  v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2) # (B, nh, T, hs)
228
-
229
  # Compute scaled dot-product attention scores
230
  # Matrix multiplication: Q @ K^T gives attention affinities between all token pairs
231
  # Scaling by 1/sqrt(head_dim) prevents softmax saturation for large embedding dimensions
232
  # Shape: (B, nh, T, hs) @ (B, nh, hs, T) -> (B, nh, T, T)
233
  # The resulting (T, T) matrix represents attention weights from each token to every other token
234
  att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
235
-
236
  # Apply causal masking to enforce autoregressive property
237
  # The causal mask ensures that token i can only attend to tokens j where j <= i
238
  # This prevents the model from "cheating" by looking at future tokens during training
239
  # We use -inf for masked positions so they become 0 after softmax
240
- att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
241
-
242
  # Convert attention scores to probabilities using softmax
243
  # Each row of the attention matrix now sums to 1, representing a probability distribution
244
  # over which tokens to attend to for each query position
245
  att = F.softmax(att, dim=-1)
246
-
247
  # Apply dropout to attention weights for regularization
248
  # This randomly zeros some attention connections during training to prevent overfitting
249
  att = self.attn_dropout(att)
250
-
251
  # Apply attention weights to value vectors
252
  # This weighted combination produces the actual output of the attention mechanism
253
  # Shape: (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)
254
  # Each output position is a weighted sum of all value vectors, with weights from attention
255
  y = att @ v
256
-
257
  # Concatenate multi-head outputs back to original embedding dimension
258
  # Transform from (B, nh, T, hs) back to (B, T, C) where C = nh * hs
259
  # The transpose moves head dimension back, and contiguous() ensures memory layout efficiency
260
  # This combines information from all attention heads into a single representation
261
  y = y.transpose(1, 2).contiguous().view(B, T, C)
262
-
263
  # Apply final output projection and residual dropout
264
  # The output projection allows the model to learn how to best combine multi-head information
265
  # Residual dropout provides additional regularization before the residual connection
@@ -270,62 +273,62 @@ class CausalSelfAttention(nn.Module):
270
  class MLP(nn.Module):
271
  """
272
  Multi-Layer Perceptron (Feed-Forward Network) for Transformer.
273
-
274
  This implements the position-wise feed-forward network that appears in each transformer layer.
275
  The MLP provides additional non-linear transformation capacity beyond what attention provides.
276
-
277
  Architecture:
278
  Input -> Linear(n_embd -> 4*n_embd) -> GELU -> Linear(4*n_embd -> n_embd) -> Dropout -> Output
279
-
280
  Design rationale:
281
  - 4x expansion is standard in transformers (from "Attention Is All You Need")
282
  - GELU activation provides smoother gradients than ReLU for language modeling
283
  - Dropout prevents overfitting in the feed-forward layers
284
  - Two linear layers allow complex non-linear transformations of attention outputs
285
-
286
  Parameters:
287
  - First linear layer: n_embd * 4*n_embd parameters (expansion)
288
  - Second linear layer: 4*n_embd * n_embd parameters (projection back)
289
  - Total: 8 * n_embd^2 parameters (significant portion of model size)
290
  """
291
-
292
  def __init__(self, config: GPTConfig):
293
  super().__init__()
294
-
295
  # First linear layer: expand embedding dimension by 4x
296
  # This expansion gives the network more representational capacity
297
  # The 4x factor is a standard choice that balances capacity vs efficiency
298
  self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
299
-
300
  # GELU (Gaussian Error Linear Unit) activation function
301
  # GELU provides smoother gradients compared to ReLU and works better for language modeling
302
  # It's approximately: GELU(x) = x * Ξ¦(x) where Ξ¦ is the CDF of standard normal distribution
303
  self.gelu = nn.GELU()
304
-
305
  # Second linear layer: project back to original embedding dimension
306
  # This projection allows the network to combine information from the expanded representation
307
  self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
308
-
309
  # Dropout for regularization in the feed-forward network
310
  # Applied after the final projection to prevent overfitting
311
  self.dropout = nn.Dropout(config.dropout)
312
-
313
  def forward(self, x: torch.Tensor) -> torch.Tensor:
314
  """
315
  Forward pass of the feed-forward network.
316
-
317
  This method applies a two-layer MLP with GELU activation to transform
318
  the attention outputs. The MLP operates independently on each position
319
  in the sequence, providing position-wise non-linear transformations.
320
-
321
  Mathematical operation:
322
  MLP(x) = Dropout(Linearβ‚‚(GELU(Linear₁(x))))
323
  where Linear₁: R^n_embd -> R^4*n_embd and Linearβ‚‚: R^4*n_embd -> R^n_embd
324
-
325
  Args:
326
  x: Input tensor of shape (batch_size, seq_len, n_embd)
327
  Contains attended representations from the attention layer
328
-
329
  Returns:
330
  torch.Tensor: Output tensor of shape (batch_size, seq_len, n_embd)
331
  Contains transformed representations ready for residual connection
@@ -334,112 +337,116 @@ class MLP(nn.Module):
334
  # This expansion provides the network with a higher-dimensional space for computation
335
  # Shape: (batch_size, seq_len, n_embd) -> (batch_size, seq_len, 4*n_embd)
336
  x = self.c_fc(x)
337
-
338
  # Apply GELU activation function for non-linearity
339
  # GELU is smoother than ReLU and provides better gradients for language modeling
340
  # It introduces non-linearity while maintaining differentiability everywhere
341
  x = self.gelu(x)
342
-
343
  # Second linear transformation: project back to original n_embd dimensions
344
  # This projection combines information from the expanded representation
345
  # Shape: (batch_size, seq_len, 4*n_embd) -> (batch_size, seq_len, n_embd)
346
  x = self.c_proj(x)
347
-
348
  # Apply dropout for regularization before residual connection
349
  # Dropout randomly zeros some neurons during training to prevent overfitting
350
  # This is particularly important in the feed-forward layers which have many parameters
351
  x = self.dropout(x)
352
-
353
  return x
354
 
355
 
356
  class Block(nn.Module):
357
  """
358
  Single Transformer block.
359
-
360
  Consists of:
361
  1. Layer normalization
362
- 2. Multi-head causal self-attention
363
  3. Residual connection
364
  4. Layer normalization
365
  5. MLP (feed-forward network)
366
  6. Residual connection
367
-
368
  Uses pre-norm architecture for better training stability.
369
  """
370
-
371
  def __init__(self, config: GPTConfig):
372
  super().__init__()
373
  self.ln_1 = nn.LayerNorm(config.n_embd)
374
  self.attn = CausalSelfAttention(config)
375
  self.ln_2 = nn.LayerNorm(config.n_embd)
376
  self.mlp = MLP(config)
377
-
378
  def forward(self, x: torch.Tensor) -> torch.Tensor:
379
  """
380
  Forward pass of transformer block.
381
-
382
  Args:
383
  x: Input tensor of shape (batch_size, seq_len, n_embd)
384
-
385
  Returns:
386
  torch.Tensor: Output tensor of shape (batch_size, seq_len, n_embd)
387
  """
388
  # Pre-norm attention with residual connection
389
  x = x + self.attn(self.ln_1(x))
390
-
391
- # Pre-norm MLP with residual connection
392
  x = x + self.mlp(self.ln_2(x))
393
-
394
  return x
395
 
396
 
397
  class GPTModel(nn.Module):
398
  """
399
  Complete GPT Language Model.
400
-
401
  This is the main model class that combines all components:
402
  - Token and positional embeddings
403
  - Stack of transformer blocks
404
  - Final layer normalization
405
  - Language modeling head
406
-
407
  The model can be used for:
408
  - Training from scratch on text data
409
  - Fine-tuning on downstream tasks
410
  - Text generation (inference)
411
  """
412
-
413
  def __init__(self, config: GPTConfig):
414
  super().__init__()
415
  assert config.vocab_size is not None, "vocab_size must be specified"
416
  assert config.block_size is not None, "block_size must be specified"
417
-
418
  self.config = config
419
-
420
  # Embeddings
421
- self.transformer = nn.ModuleDict(dict(
422
- wte = nn.Embedding(config.vocab_size, config.n_embd), # Token embeddings
423
- wpe = nn.Embedding(config.block_size, config.n_embd), # Position embeddings
424
- drop = nn.Dropout(config.dropout),
425
- h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), # Transformer blocks
426
- ln_f = nn.LayerNorm(config.n_embd), # Final layer norm
427
- ))
428
-
429
  # Language modeling head (maps hidden states to vocabulary)
430
  self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
431
-
432
  # Tie weights between token embeddings and output head (common practice)
433
  self.transformer.wte.weight = self.lm_head.weight
434
-
435
  # Initialize weights
436
  self.apply(self._init_weights)
437
-
438
  # Report parameter count
439
  print(f"Model initialized: {self.config.model_name}")
440
  print(f"Parameters: {self.get_num_params():,}")
441
  print(f"Estimated: {self.config.estimate_parameters():,}")
442
-
443
  def _init_weights(self, module):
444
  """Initialize model weights using standard practices."""
445
  if isinstance(module, nn.Linear):
@@ -448,14 +455,14 @@ class GPTModel(nn.Module):
448
  torch.nn.init.zeros_(module.bias)
449
  elif isinstance(module, nn.Embedding):
450
  torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
451
-
452
  def get_num_params(self, non_embedding: bool = False) -> int:
453
  """
454
  Count the number of parameters in the model.
455
-
456
  Args:
457
  non_embedding: If True, subtract embedding parameters
458
-
459
  Returns:
460
  int: Number of parameters
461
  """
@@ -464,19 +471,17 @@ class GPTModel(nn.Module):
464
  n_params -= self.transformer.wpe.weight.numel()
465
  n_params -= self.transformer.wte.weight.numel()
466
  return n_params
467
-
468
  def forward(
469
- self,
470
- idx: torch.Tensor,
471
- targets: Optional[torch.Tensor] = None
472
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
473
  """
474
  Forward pass of the GPT model.
475
-
476
  Args:
477
  idx: Input token indices of shape (batch_size, seq_len)
478
  targets: Optional target tokens for loss calculation (batch_size, seq_len)
479
-
480
  Returns:
481
  Tuple containing:
482
  - logits: Output logits of shape (batch_size, seq_len, vocab_size)
@@ -484,53 +489,57 @@ class GPTModel(nn.Module):
484
  """
485
  device = idx.device
486
  b, t = idx.size()
487
- assert t <= self.config.block_size, f"Sequence length {t} exceeds block size {self.config.block_size}"
488
-
489
  # Token embeddings
490
  tok_emb = self.transformer.wte(idx) # (b, t, n_embd)
491
-
492
  # Position embeddings
493
  pos = torch.arange(0, t, dtype=torch.long, device=device) # (t,)
494
  pos_emb = self.transformer.wpe(pos) # (t, n_embd)
495
-
496
  # Combine embeddings and apply dropout
497
  x = self.transformer.drop(tok_emb + pos_emb)
498
-
499
  # Pass through transformer blocks
500
  for block in self.transformer.h:
501
  x = block(x)
502
-
503
  # Final layer normalization
504
  x = self.transformer.ln_f(x)
505
-
506
  # Language modeling head
507
  if targets is not None:
508
  # If we have targets, compute loss
509
  logits = self.lm_head(x)
510
- loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
511
  else:
512
  # If no targets, only compute logits for the last token (more efficient for generation)
513
  logits = self.lm_head(x[:, [-1], :]) # Note: using list [-1] to preserve the time dim
514
  loss = None
515
-
516
  return logits, loss
517
-
518
  def generate(
519
- self,
520
- idx: torch.Tensor,
521
  max_new_tokens: int = 100,
522
  temperature: float = 1.0,
523
- top_k: Optional[int] = None
524
  ) -> torch.Tensor:
525
  """
526
  Generate new tokens autoregressively.
527
-
528
  Args:
529
  idx: Starting token indices (batch_size, seq_len)
530
  max_new_tokens: Maximum number of new tokens to generate
531
  temperature: Sampling temperature (higher = more random)
532
  top_k: If set, only sample from top-k most likely tokens
533
-
534
  Returns:
535
  torch.Tensor: Generated sequence (batch_size, seq_len + max_new_tokens)
536
  """
@@ -538,57 +547,61 @@ class GPTModel(nn.Module):
538
  with torch.no_grad():
539
  for _ in range(max_new_tokens):
540
  # Crop sequence if it exceeds block size
541
- idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
542
-
543
  # Forward pass
544
  logits, _ = self(idx_cond)
545
-
546
  # Get logits for the last token and apply temperature
547
  logits = logits[:, -1, :] / temperature
548
-
549
  # Optionally crop to top-k most likely tokens
550
  if top_k is not None:
551
  v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
552
- logits[logits < v[:, [-1]]] = -float('Inf')
553
-
554
  # Apply softmax and sample
555
  probs = F.softmax(logits, dim=-1)
556
  idx_next = torch.multinomial(probs, num_samples=1)
557
-
558
  # Append to sequence
559
  idx = torch.cat((idx, idx_next), dim=1)
560
-
561
  self.train() # Return to training mode
562
  return idx
563
-
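A minimal generation sketch, assuming a tokenizer object with encode/decode methods like the one used by the evaluation script:

# Hedged usage sketch for generate(); the tokenizer object is an assumption here.
model = create_model("small")
model.eval()
prompt_ids = torch.tensor([tokenizer.encode("The future of technology")], dtype=torch.long)
out = model.generate(prompt_ids, max_new_tokens=50, temperature=0.8, top_k=40)
print(tokenizer.decode(out[0].tolist()))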
564
  def estimate_memory_usage(self, batch_size: int = 1, seq_len: int = None) -> dict:
565
  """
566
  Estimate memory usage for training and inference.
567
-
568
  Args:
569
  batch_size: Batch size for estimation
570
  seq_len: Sequence length (defaults to block_size)
571
-
572
  Returns:
573
  dict: Memory usage estimates in MB
574
  """
575
  if seq_len is None:
576
  seq_len = self.config.block_size
577
-
578
  # Model parameters (weights)
579
  param_memory = self.get_num_params() * 4 / (1024**2) # 4 bytes per float32
580
-
581
  # Activations (rough estimate)
582
  activation_memory = (
583
  batch_size * seq_len * self.config.n_embd * self.config.n_layer * 8 # Rough estimate
584
  ) / (1024**2)
585
-
586
  # Gradients (same size as parameters during training)
587
  gradient_memory = param_memory
588
-
589
  return {
590
  "parameters_mb": param_memory,
591
- "activations_mb": activation_memory,
592
  "gradients_mb": gradient_memory,
593
  "total_training_mb": param_memory + activation_memory + gradient_memory,
594
  "total_inference_mb": param_memory + activation_memory * 0.5, # No gradients needed
@@ -598,25 +611,25 @@ class GPTModel(nn.Module):
598
  def create_model(model_size: str = "medium") -> GPTModel:
599
  """
600
  Factory function to create a GPT model with predefined configurations.
601
-
602
  Args:
603
  model_size: Size of model to create ("small", "medium", "large")
604
-
605
  Returns:
606
  GPTModel: Initialized model
607
  """
608
  configs = {
609
  "small": GPTConfig.small(),
610
- "medium": GPTConfig.medium(),
611
  "large": GPTConfig.large(),
612
  }
613
-
614
  if model_size not in configs:
615
  raise ValueError(f"Unknown model size: {model_size}. Choose from {list(configs.keys())}")
616
-
617
  config = configs[model_size]
618
  model = GPTModel(config)
619
-
620
  return model
621
 
622
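For example, picking a configuration and checking its estimated footprint before committing to a training run:

# Quick capacity check using the factory and the memory estimator above.
model = create_model("small")  # ~25M parameter configuration
mem = model.estimate_memory_usage(batch_size=8, seq_len=1024)
print(f"Estimated training memory: {mem['total_training_mb']:.0f} MB")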
 
@@ -624,18 +637,20 @@ if __name__ == "__main__":
624
  # Example usage
625
  print("🧠 GPT Model Architecture")
626
  print("=" * 50)
627
-
628
  # Create models of different sizes
629
  for size in ["small", "medium", "large"]:
630
  print(f"\n{size.upper()} MODEL:")
631
  model = create_model(size)
632
-
633
  # Show memory estimates
634
  memory = model.estimate_memory_usage(batch_size=4, seq_len=512)
635
- print(f"Memory (4 batch, 512 seq): {memory['total_training_mb']:.1f}MB training, {memory['total_inference_mb']:.1f}MB inference")
636
-
637
  # Test forward pass
638
  x = torch.randint(0, 32000, (2, 64)) # Batch size 2, sequence length 64
639
  with torch.no_grad():
640
  logits, _ = model(x)
641
- print(f"Test forward pass: {x.shape} -> {logits.shape} βœ“")
 
18
 
19
  ARCHITECTURE OVERVIEW:
20
  - Token Embedding: Maps token IDs to dense vectors
21
+ - Positional Embedding: Adds position information to token embeddings
22
  - Transformer Blocks: Stack of multi-head attention + feed-forward layers
23
  - Layer Normalization: Pre-norm placement for training stability
24
  - Output Head: Linear projection to vocabulary for next-token prediction
 
32
 
33
  Usage:
34
  from model import GPTConfig, GPTModel
35
+
36
  config = GPTConfig(vocab_size=32000, n_layer=12, n_head=12, n_embd=768)
37
  model = GPTModel(config)
38
+
39
  # Forward pass
40
  logits = model(input_ids) # Shape: (batch_size, seq_len, vocab_size)
41
 
42
  Hardware Requirements:
43
  - Small Model (25M params): 4-8GB RAM, CPU/integrated GPU
44
+ - Medium Model (117M params): 8-16GB RAM, dedicated GPU recommended
45
  - Large Model (350M params): 16GB+ RAM, high-end GPU required
46
 
47
  Author: Louis Chua Bean Chong
 
60
  class GPTConfig:
61
  """
62
  Configuration class for GPT model hyperparameters.
63
+
64
  This class defines all the architectural parameters needed to instantiate
65
  a GPT model. Use the provided class methods to get pre-configured setups
66
  for different model sizes.
67
  """
68
+
69
  # Model architecture
70
+ vocab_size: int = 32000 # Vocabulary size (from tokenizer)
71
+ n_layer: int = 12 # Number of transformer layers
72
+ n_head: int = 12 # Number of attention heads
73
+ n_embd: int = 768 # Embedding dimension
74
+
75
  # Sequence and context
76
+ block_size: int = 1024 # Maximum sequence length
77
+
78
  # Training hyperparameters
79
+ dropout: float = 0.1 # Dropout probability
80
+ bias: bool = True # Use bias in linear layers
81
+
82
  # Model size identifier
83
+ model_name: str = "gpt-medium" # Human-readable model identifier
84
+
85
  @classmethod
86
+ def small(cls) -> "GPTConfig":
87
  """Small model configuration (~25M parameters) - Good for CPU training"""
88
  return cls(
89
  vocab_size=32000,
90
  n_layer=6,
91
+ n_head=8,
92
  n_embd=512,
93
  block_size=1024,
94
  dropout=0.1,
95
+ model_name="gpt-small",
96
  )
97
+
98
+ @classmethod
99
+ def medium(cls) -> "GPTConfig":
100
  """Medium model configuration (~117M parameters) - Balanced performance"""
101
  return cls(
102
  vocab_size=32000,
 
105
  n_embd=768,
106
  block_size=2048,
107
  dropout=0.1,
108
+ model_name="gpt-medium",
109
  )
110
+
111
  @classmethod
112
+ def large(cls) -> "GPTConfig":
113
  """Large model configuration (~350M parameters) - High performance"""
114
  return cls(
115
  vocab_size=32000,
 
118
  n_embd=1024,
119
  block_size=2048,
120
  dropout=0.1,
121
+ model_name="gpt-large",
122
  )
123
+
124
  def estimate_parameters(self) -> int:
125
  """
126
  Estimate the total number of trainable parameters.
127
+
128
  Returns:
129
  int: Estimated parameter count
130
  """
131
  # Token embeddings
132
  token_emb = self.vocab_size * self.n_embd
133
+
134
+ # Position embeddings
135
  pos_emb = self.block_size * self.n_embd
136
+
137
  # Transformer layers
138
  # Each layer: attention (4 * n_embd^2) + mlp (8 * n_embd^2) + layer_norms
139
  layer_params = self.n_layer * (12 * self.n_embd**2 + 4 * self.n_embd)
140
+
141
  # Output head
142
  output_head = self.vocab_size * self.n_embd
143
+
144
  total = token_emb + pos_emb + layer_params + output_head
145
  return total
146
 
 
148
  class CausalSelfAttention(nn.Module):
149
  """
150
  Multi-head causal self-attention mechanism.
151
+
152
  This implements the core attention mechanism of the transformer, with causal
153
  masking to ensure autoregressive behavior (tokens can only attend to previous
154
  tokens, not future ones).
155
  """
156
+
157
  def __init__(self, config: GPTConfig):
158
  super().__init__()
159
+ assert (
160
+ config.n_embd % config.n_head == 0
161
+ ), "Embedding dim must be divisible by number of heads"
162
+
163
  self.config = config
164
  self.n_head = config.n_head
165
  self.n_embd = config.n_embd
166
  self.head_dim = self.n_embd // self.n_head
167
+
168
  # Key, query, value projections for all heads (batched)
169
  self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
170
+
171
  # Output projection
172
  self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
173
+
174
  # Dropout
175
  self.attn_dropout = nn.Dropout(config.dropout)
176
  self.resid_dropout = nn.Dropout(config.dropout)
177
+
178
  # Causal mask - lower triangular matrix
179
  self.register_buffer(
180
  "bias",
181
+ torch.tril(torch.ones(config.block_size, config.block_size)).view(
182
+ 1, 1, config.block_size, config.block_size
183
+ ),
184
  )
185
+
186
  def forward(self, x: torch.Tensor) -> torch.Tensor:
187
  """
188
  Forward pass of causal self-attention.
189
+
190
  This method implements the scaled dot-product attention mechanism with causal masking.
191
  The attention mechanism allows each token to attend to all previous tokens in the sequence,
192
  but not to future tokens, maintaining the autoregressive property essential for language modeling.
193
+
194
  Mathematical formulation:
195
  Attention(Q, K, V) = softmax(QK^T / sqrt(d_k))V
196
  where Q, K, V are query, key, value matrices derived from input x
197
+
198
  Implementation details:
199
  - Uses batch matrix multiplication for efficiency
200
  - Applies causal mask to prevent future token attention
201
  - Implements multi-head attention by reshaping and parallel processing
202
  - Applies dropout for regularization during training
203
+
204
  Args:
205
  x: Input tensor of shape (batch_size, seq_len, n_embd)
206
  Contains embedded token representations from previous layer
207
+
208
  Returns:
209
  torch.Tensor: Output tensor of shape (batch_size, seq_len, n_embd)
210
  """
211
  # Extract tensor dimensions for clear variable naming and validation
212
  # B = batch size (number of sequences processed in parallel)
213
+ # T = sequence length (number of tokens in each sequence)
214
  # C = embedding dimensionality (n_embd from config)
215
  B, T, C = x.size()
216
+
217
  # Generate query, key, and value projections for all attention heads
218
  # The c_attn linear layer outputs 3 * n_embd features, which we split into Q, K, V
219
  # This batched approach is more efficient than separate linear layers
220
  # Input shape: (B, T, C) -> Output shape: (B, T, 3*C) -> Split to 3x (B, T, C)
221
  q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
222
+
223
  # Reshape tensors for multi-head attention computation
224
  # Transform from (B, T, C) to (B, nh, T, hs) where:
225
  # - nh = number of heads (self.n_head)
 
228
  q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2) # (B, nh, T, hs)
229
  k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2) # (B, nh, T, hs)
230
  v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2) # (B, nh, T, hs)
231
+
232
  # Compute scaled dot-product attention scores
233
  # Matrix multiplication: Q @ K^T gives attention affinities between all token pairs
234
  # Scaling by 1/sqrt(head_dim) prevents softmax saturation for large embedding dimensions
235
  # Shape: (B, nh, T, hs) @ (B, nh, hs, T) -> (B, nh, T, T)
236
  # The resulting (T, T) matrix represents attention weights from each token to every other token
237
  att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
238
+
239
  # Apply causal masking to enforce autoregressive property
240
  # The causal mask ensures that token i can only attend to tokens j where j <= i
241
  # This prevents the model from "cheating" by looking at future tokens during training
242
  # We use -inf for masked positions so they become 0 after softmax
243
+ att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
244
+
245
  # Convert attention scores to probabilities using softmax
246
  # Each row of the attention matrix now sums to 1, representing a probability distribution
247
  # over which tokens to attend to for each query position
248
  att = F.softmax(att, dim=-1)
249
+
250
  # Apply dropout to attention weights for regularization
251
  # This randomly zeros some attention connections during training to prevent overfitting
252
  att = self.attn_dropout(att)
253
+
254
  # Apply attention weights to value vectors
255
  # This weighted combination produces the actual output of the attention mechanism
256
  # Shape: (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)
257
  # Each output position is a weighted sum of all value vectors, with weights from attention
258
  y = att @ v
259
+
260
  # Concatenate multi-head outputs back to original embedding dimension
261
  # Transform from (B, nh, T, hs) back to (B, T, C) where C = nh * hs
262
  # The transpose moves head dimension back, and contiguous() ensures memory layout efficiency
263
  # This combines information from all attention heads into a single representation
264
  y = y.transpose(1, 2).contiguous().view(B, T, C)
265
+
266
  # Apply final output projection and residual dropout
267
  # The output projection allows the model to learn how to best combine multi-head information
268
  # Residual dropout provides additional regularization before the residual connection
 
273
  class MLP(nn.Module):
274
  """
275
  Multi-Layer Perceptron (Feed-Forward Network) for Transformer.
276
+
277
  This implements the position-wise feed-forward network that appears in each transformer layer.
278
  The MLP provides additional non-linear transformation capacity beyond what attention provides.
279
+
280
  Architecture:
281
  Input -> Linear(n_embd -> 4*n_embd) -> GELU -> Linear(4*n_embd -> n_embd) -> Dropout -> Output
282
+
283
  Design rationale:
284
  - 4x expansion is standard in transformers (from "Attention Is All You Need")
285
  - GELU activation provides smoother gradients than ReLU for language modeling
286
  - Dropout prevents overfitting in the feed-forward layers
287
  - Two linear layers allow complex non-linear transformations of attention outputs
288
+
289
  Parameters:
290
  - First linear layer: n_embd * 4*n_embd parameters (expansion)
291
  - Second linear layer: 4*n_embd * n_embd parameters (projection back)
292
  - Total: 8 * n_embd^2 parameters (significant portion of model size)
293
  """
294
+
295
  def __init__(self, config: GPTConfig):
296
  super().__init__()
297
+
298
  # First linear layer: expand embedding dimension by 4x
299
  # This expansion gives the network more representational capacity
300
  # The 4x factor is a standard choice that balances capacity vs efficiency
301
  self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
302
+
303
  # GELU (Gaussian Error Linear Unit) activation function
304
  # GELU provides smoother gradients compared to ReLU and works better for language modeling
305
  # It's approximately: GELU(x) = x * Ξ¦(x) where Ξ¦ is the CDF of standard normal distribution
306
  self.gelu = nn.GELU()
307
+
308
  # Second linear layer: project back to original embedding dimension
309
  # This projection allows the network to combine information from the expanded representation
310
  self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
311
+
312
  # Dropout for regularization in the feed-forward network
313
  # Applied after the final projection to prevent overfitting
314
  self.dropout = nn.Dropout(config.dropout)
315
+
316
  def forward(self, x: torch.Tensor) -> torch.Tensor:
317
  """
318
  Forward pass of the feed-forward network.
319
+
320
  This method applies a two-layer MLP with GELU activation to transform
321
  the attention outputs. The MLP operates independently on each position
322
  in the sequence, providing position-wise non-linear transformations.
323
+
324
  Mathematical operation:
325
  MLP(x) = Dropout(Linearβ‚‚(GELU(Linear₁(x))))
326
  where Linear₁: R^n_embd -> R^4*n_embd and Linearβ‚‚: R^4*n_embd -> R^n_embd
327
+
328
  Args:
329
  x: Input tensor of shape (batch_size, seq_len, n_embd)
330
  Contains attended representations from the attention layer
331
+
332
  Returns:
333
  torch.Tensor: Output tensor of shape (batch_size, seq_len, n_embd)
334
  Contains transformed representations ready for residual connection
 
337
  # This expansion provides the network with a higher-dimensional space for computation
338
  # Shape: (batch_size, seq_len, n_embd) -> (batch_size, seq_len, 4*n_embd)
339
  x = self.c_fc(x)
340
+
341
  # Apply GELU activation function for non-linearity
342
  # GELU is smoother than ReLU and provides better gradients for language modeling
343
  # It introduces non-linearity while maintaining differentiability everywhere
344
  x = self.gelu(x)
345
+
346
  # Second linear transformation: project back to original n_embd dimensions
347
  # This projection combines information from the expanded representation
348
  # Shape: (batch_size, seq_len, 4*n_embd) -> (batch_size, seq_len, n_embd)
349
  x = self.c_proj(x)
350
+
351
  # Apply dropout for regularization before residual connection
352
  # Dropout randomly zeros some neurons during training to prevent overfitting
353
  # This is particularly important in the feed-forward layers which have many parameters
354
  x = self.dropout(x)
355
+
356
  return x
357
 
358
 
359
  class Block(nn.Module):
360
  """
361
  Single Transformer block.
362
+
363
  Consists of:
364
  1. Layer normalization
365
+ 2. Multi-head causal self-attention
366
  3. Residual connection
367
  4. Layer normalization
368
  5. MLP (feed-forward network)
369
  6. Residual connection
370
+
371
  Uses pre-norm architecture for better training stability.
372
  """
373
+
374
  def __init__(self, config: GPTConfig):
375
  super().__init__()
376
  self.ln_1 = nn.LayerNorm(config.n_embd)
377
  self.attn = CausalSelfAttention(config)
378
  self.ln_2 = nn.LayerNorm(config.n_embd)
379
  self.mlp = MLP(config)
380
+
381
  def forward(self, x: torch.Tensor) -> torch.Tensor:
382
  """
383
  Forward pass of transformer block.
384
+
385
  Args:
386
  x: Input tensor of shape (batch_size, seq_len, n_embd)
387
+
388
  Returns:
389
  torch.Tensor: Output tensor of shape (batch_size, seq_len, n_embd)
390
  """
391
  # Pre-norm attention with residual connection
392
  x = x + self.attn(self.ln_1(x))
393
+
394
+ # Pre-norm MLP with residual connection
395
  x = x + self.mlp(self.ln_2(x))
396
+
397
  return x
398
 
399
 
400
  class GPTModel(nn.Module):
401
  """
402
  Complete GPT Language Model.
403
+
404
  This is the main model class that combines all components:
405
  - Token and positional embeddings
406
  - Stack of transformer blocks
407
  - Final layer normalization
408
  - Language modeling head
409
+
410
  The model can be used for:
411
  - Training from scratch on text data
412
  - Fine-tuning on downstream tasks
413
  - Text generation (inference)
414
  """
415
+
416
  def __init__(self, config: GPTConfig):
417
  super().__init__()
418
  assert config.vocab_size is not None, "vocab_size must be specified"
419
  assert config.block_size is not None, "block_size must be specified"
420
+
421
  self.config = config
422
+
423
  # Embeddings
424
+ self.transformer = nn.ModuleDict(
425
+ dict(
426
+ wte=nn.Embedding(config.vocab_size, config.n_embd), # Token embeddings
427
+ wpe=nn.Embedding(config.block_size, config.n_embd), # Position embeddings
428
+ drop=nn.Dropout(config.dropout),
429
+ h=nn.ModuleList(
430
+ [Block(config) for _ in range(config.n_layer)]
431
+ ), # Transformer blocks
432
+ ln_f=nn.LayerNorm(config.n_embd), # Final layer norm
433
+ )
434
+ )
435
+
436
  # Language modeling head (maps hidden states to vocabulary)
437
  self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
438
+
439
  # Tie weights between token embeddings and output head (common practice)
440
  self.transformer.wte.weight = self.lm_head.weight
441
+
442
  # Initialize weights
443
  self.apply(self._init_weights)
444
+
445
  # Report parameter count
446
  print(f"Model initialized: {self.config.model_name}")
447
  print(f"Parameters: {self.get_num_params():,}")
448
  print(f"Estimated: {self.config.estimate_parameters():,}")
449
+
450
  def _init_weights(self, module):
451
  """Initialize model weights using standard practices."""
452
  if isinstance(module, nn.Linear):
 
455
  torch.nn.init.zeros_(module.bias)
456
  elif isinstance(module, nn.Embedding):
457
  torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
458
+
459
  def get_num_params(self, non_embedding: bool = False) -> int:
460
  """
461
  Count the number of parameters in the model.
462
+
463
  Args:
464
  non_embedding: If True, subtract embedding parameters
465
+
466
  Returns:
467
  int: Number of parameters
468
  """
 
471
  n_params -= self.transformer.wpe.weight.numel()
472
  n_params -= self.transformer.wte.weight.numel()
473
  return n_params
474
+
475
  def forward(
476
+ self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None
 
 
477
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
478
  """
479
  Forward pass of the GPT model.
480
+
481
  Args:
482
  idx: Input token indices of shape (batch_size, seq_len)
483
  targets: Optional target tokens for loss calculation (batch_size, seq_len)
484
+
485
  Returns:
486
  Tuple containing:
487
  - logits: Output logits of shape (batch_size, seq_len, vocab_size)
 
489
  """
490
  device = idx.device
491
  b, t = idx.size()
492
+ assert (
493
+ t <= self.config.block_size
494
+ ), f"Sequence length {t} exceeds block size {self.config.block_size}"
495
+
496
  # Token embeddings
497
  tok_emb = self.transformer.wte(idx) # (b, t, n_embd)
498
+
499
  # Position embeddings
500
  pos = torch.arange(0, t, dtype=torch.long, device=device) # (t,)
501
  pos_emb = self.transformer.wpe(pos) # (t, n_embd)
502
+
503
  # Combine embeddings and apply dropout
504
  x = self.transformer.drop(tok_emb + pos_emb)
505
+
506
  # Pass through transformer blocks
507
  for block in self.transformer.h:
508
  x = block(x)
509
+
510
  # Final layer normalization
511
  x = self.transformer.ln_f(x)
512
+
513
  # Language modeling head
514
  if targets is not None:
515
  # If we have targets, compute loss
516
  logits = self.lm_head(x)
517
+ loss = F.cross_entropy(
518
+ logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
519
+ )
520
  else:
521
  # If no targets, only compute logits for the last token (more efficient for generation)
522
  logits = self.lm_head(x[:, [-1], :]) # Note: using list [-1] to preserve the time dim
523
  loss = None
524
+
525
  return logits, loss
526
+
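A minimal sketch of how inputs and targets line up for the loss computed in forward (the token IDs are made up, and model stands for any GPTModel instance):

tokens = [5, 17, 42, 9, 3]
input_ids = torch.tensor([tokens[:-1]])   # model reads  [5, 17, 42, 9]
target_ids = torch.tensor([tokens[1:]])   # and predicts [17, 42, 9, 3]
logits, loss = model(input_ids, target_ids)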
    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Generate new tokens autoregressively.

        Args:
            idx: Starting token indices (batch_size, seq_len)
            max_new_tokens: Maximum number of new tokens to generate
            temperature: Sampling temperature (higher = more random)
            top_k: If set, only sample from top-k most likely tokens

        Returns:
            torch.Tensor: Generated sequence (batch_size, seq_len + max_new_tokens)
        """
        self.eval()  # Switch to evaluation mode for generation
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Crop sequence if it exceeds block size
                idx_cond = (
                    idx
                    if idx.size(1) <= self.config.block_size
                    else idx[:, -self.config.block_size :]
                )

                # Forward pass
                logits, _ = self(idx_cond)

                # Get logits for the last token and apply temperature
                logits = logits[:, -1, :] / temperature

                # Optionally crop to top-k most likely tokens
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float("Inf")

                # Apply softmax and sample
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)

                # Append to sequence
                idx = torch.cat((idx, idx_next), dim=1)

        self.train()  # Return to training mode
        return idx
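A usage sketch for generate, assuming a model built with create_model below and a prompt supplied directly as token IDs (the IDs are placeholders):

model = create_model("small")
prompt = torch.tensor([[2, 145, 89, 7]])  # (batch=1, seq_len=4)
out = model.generate(prompt, max_new_tokens=20, temperature=0.8, top_k=50)
print(out.shape)  # torch.Size([1, 24])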
    def estimate_memory_usage(self, batch_size: int = 1, seq_len: int = None) -> dict:
        """
        Estimate memory usage for training and inference.

        Args:
            batch_size: Batch size for estimation
            seq_len: Sequence length (defaults to block_size)

        Returns:
            dict: Memory usage estimates in MB
        """
        if seq_len is None:
            seq_len = self.config.block_size

        # Model parameters (weights)
        param_memory = self.get_num_params() * 4 / (1024**2)  # 4 bytes per float32

        # Activations (rough estimate)
        activation_memory = (
            batch_size * seq_len * self.config.n_embd * self.config.n_layer * 8  # Rough estimate
        ) / (1024**2)

        # Gradients (same size as parameters during training)
        gradient_memory = param_memory

        return {
            "parameters_mb": param_memory,
            "activations_mb": activation_memory,
            "gradients_mb": gradient_memory,
            "total_training_mb": param_memory + activation_memory + gradient_memory,
            "total_inference_mb": param_memory + activation_memory * 0.5,  # No gradients needed
        }
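As a rough worked example of the estimate above (the parameter count is hypothetical): a 25M-parameter model stored in float32 needs about 25_000_000 * 4 / 1024**2 β‰ˆ 95 MB for weights and the same again for gradients, so the training footprint is roughly twice the weight memory plus activations, while inference drops the gradient term.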
def create_model(model_size: str = "medium") -> GPTModel:
    """
    Factory function to create a GPT model with predefined configurations.

    Args:
        model_size: Size of model to create ("small", "medium", "large")

    Returns:
        GPTModel: Initialized model
    """
    configs = {
        "small": GPTConfig.small(),
        "medium": GPTConfig.medium(),
        "large": GPTConfig.large(),
    }

    if model_size not in configs:
        raise ValueError(f"Unknown model size: {model_size}. Choose from {list(configs.keys())}")

    config = configs[model_size]
    model = GPTModel(config)

    return model


if __name__ == "__main__":
    # Example usage
    print("🧠 GPT Model Architecture")
    print("=" * 50)

    # Create models of different sizes
    for size in ["small", "medium", "large"]:
        print(f"\n{size.upper()} MODEL:")
        model = create_model(size)

        # Show memory estimates
        memory = model.estimate_memory_usage(batch_size=4, seq_len=512)
        print(
            f"Memory (4 batch, 512 seq): {memory['total_training_mb']:.1f}MB training, {memory['total_inference_mb']:.1f}MB inference"
        )

        # Test forward pass
        x = torch.randint(0, 32000, (2, 64))  # Batch size 2, sequence length 64
        with torch.no_grad():
            logits, _ = model(x)
        print(f"Test forward pass: {x.shape} -> {logits.shape} βœ“")
training/train_model.py CHANGED
@@ -69,6 +69,7 @@ try:
69
  from data_loader import TextDataLoader
70
  except ImportError:
71
  import sys
 
72
  sys.path.append(os.path.dirname(__file__))
73
  from model import GPTModel, GPTConfig, create_model
74
  from data_loader import TextDataLoader
@@ -77,11 +78,11 @@ except ImportError:
77
  class ModelTrainer:
78
  """
79
  Comprehensive trainer for GPT-style language models.
80
-
81
  Handles the complete training pipeline including data loading, optimization,
82
  checkpointing, and progress monitoring.
83
  """
84
-
85
  def __init__(
86
  self,
87
  model: GPTModel,
@@ -96,11 +97,11 @@ class ModelTrainer:
96
  gradient_clipping: float = 1.0,
97
  save_every: int = 1000,
98
  eval_every: int = 500,
99
- log_every: int = 100
100
  ):
101
  """
102
  Initialize the model trainer.
103
-
104
  Args:
105
  model: GPT model to train
106
  data_loader: Data loader for training data
@@ -120,7 +121,7 @@ class ModelTrainer:
120
  self.data_loader = data_loader
121
  self.output_dir = Path(output_dir)
122
  self.device = device
123
-
124
  # Training hyperparameters
125
  self.learning_rate = learning_rate
126
  self.weight_decay = weight_decay
@@ -128,29 +129,29 @@ class ModelTrainer:
128
  self.max_steps = max_steps
129
  self.gradient_accumulation_steps = gradient_accumulation_steps
130
  self.gradient_clipping = gradient_clipping
131
-
132
  # Logging and saving
133
  self.save_every = save_every
134
  self.eval_every = eval_every
135
  self.log_every = log_every
136
-
137
  # Create output directory
138
  self.output_dir.mkdir(parents=True, exist_ok=True)
139
-
140
  # Initialize optimizer and scheduler
141
  self.optimizer = self._create_optimizer()
142
  self.scheduler = self._create_scheduler()
143
-
144
  # Training state
145
  self.step = 0
146
  self.epoch = 0
147
- self.best_loss = float('inf')
148
  self.training_log = []
149
-
150
  # Performance tracking
151
  self.start_time = None
152
  self.step_times = []
153
-
154
  print(f"πŸš€ ModelTrainer initialized")
155
  print(f" Device: {device}")
156
  print(f" Model parameters: {model.get_num_params():,}")
@@ -158,38 +159,38 @@ class ModelTrainer:
158
  print(f" Max steps: {max_steps:,}")
159
  print(f" Gradient accumulation: {gradient_accumulation_steps}")
160
  print(f" Output directory: {output_dir}")
161
-
162
  def _create_optimizer(self) -> optim.Optimizer:
163
  """Create AdamW optimizer with weight decay."""
164
  # Separate parameters for weight decay
165
  decay_params = []
166
  no_decay_params = []
167
-
168
  for name, param in self.model.named_parameters():
169
  if not param.requires_grad:
170
  continue
171
-
172
  # Don't apply weight decay to biases and layer norm parameters
173
- if len(param.shape) == 1 or name.endswith('.bias'):
174
  no_decay_params.append(param)
175
  else:
176
  decay_params.append(param)
177
-
178
  param_groups = [
179
- {'params': decay_params, 'weight_decay': self.weight_decay},
180
- {'params': no_decay_params, 'weight_decay': 0.0}
181
  ]
182
-
183
  # Use AdamW with lower memory usage for CPU
184
  optimizer = optim.AdamW(
185
  param_groups,
186
  lr=self.learning_rate,
187
  betas=(0.9, 0.95), # Slightly different from default for LLM training
188
- eps=1e-8
189
  )
190
-
191
  return optimizer
192
-
193
  def _create_scheduler(self) -> torch.optim.lr_scheduler._LRScheduler:
194
  """Create learning rate scheduler with warmup and cosine decay."""
195
  if self.warmup_steps > 0:
@@ -201,59 +202,59 @@ class ModelTrainer:
201
  self.max_steps = max_steps
202
  self.min_lr_factor = min_lr_factor
203
  super().__init__(optimizer)
204
-
205
  def get_lr(self):
206
  if self.last_epoch < self.warmup_steps:
207
- # Linear warmup
208
  factor = self.last_epoch / self.warmup_steps
209
  return [base_lr * (0.01 + 0.99 * factor) for base_lr in self.base_lrs]
210
  else:
211
  # Cosine decay
212
- progress = (self.last_epoch - self.warmup_steps) / (self.max_steps - self.warmup_steps)
 
 
213
  progress = min(progress, 1.0) # Clamp to 1.0
214
  factor = 0.5 * (1 + math.cos(math.pi * progress))
215
  factor = self.min_lr_factor + (1 - self.min_lr_factor) * factor
216
  return [base_lr * factor for base_lr in self.base_lrs]
217
-
218
  scheduler = WarmupCosineScheduler(
219
  self.optimizer,
220
  warmup_steps=self.warmup_steps,
221
  max_steps=self.max_steps,
222
- min_lr_factor=0.1
223
  )
224
  else:
225
  # Just cosine decay - this should not trigger warnings
226
  scheduler = CosineAnnealingLR(
227
- self.optimizer,
228
- T_max=self.max_steps,
229
- eta_min=self.learning_rate * 0.1
230
  )
231
-
232
  return scheduler
233
-
234
  def _calculate_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
235
  """
236
  Calculate cross-entropy loss for autoregressive language modeling.
237
-
238
  This method computes the standard cross-entropy loss used in language model training.
239
  The loss measures how well the model predicts the next token in the sequence.
240
-
241
  Mathematical formulation:
242
  Loss = -βˆ‘ log(P(target_token | context))
243
  where P is the softmax probability distribution over vocabulary
244
-
245
  Implementation details:
246
  - Reshapes 3D tensors to 2D for efficient computation
247
  - Uses PyTorch's optimized cross_entropy function
248
  - Handles padding tokens by ignoring them in loss calculation
249
  - Computes mean loss across all valid positions
250
-
251
  Why cross-entropy for language modeling:
252
  - Natural choice for multi-class classification (next token prediction)
253
  - Provides strong gradient signal for correct token probabilities
254
  - Mathematically equivalent to minimizing negative log-likelihood
255
  - Well-studied optimization properties for neural language models
256
-
257
  Args:
258
  logits: Raw model predictions of shape (batch_size, seq_len, vocab_size)
259
  Contains unnormalized scores for each token in vocabulary
@@ -261,7 +262,7 @@ class ModelTrainer:
261
  targets: Ground truth next tokens of shape (batch_size, seq_len)
262
  Contains token IDs representing the true next tokens
263
  Should be input sequence shifted by one position
264
-
265
  Returns:
266
  torch.Tensor: Scalar loss value representing prediction error
267
  Lower values indicate better next-token prediction accuracy
@@ -271,123 +272,126 @@ class ModelTrainer:
271
  # where each row represents one prediction over the entire vocabulary
272
  logits = logits.view(-1, logits.size(-1)) # (batch_size * seq_len, vocab_size)
273
  targets = targets.view(-1) # (batch_size * seq_len,)
274
-
275
  # Calculate cross-entropy loss with proper handling of special tokens
276
  # ignore_index=-1 excludes padding tokens from loss calculation
277
  # This prevents the model from learning to predict padding, which would skew training
278
  # The function internally applies softmax to logits and computes negative log-likelihood
279
  loss = nn.functional.cross_entropy(logits, targets, ignore_index=-1)
280
-
281
  # Return scalar loss for backpropagation
282
  # This loss will be used to compute gradients via automatic differentiation
283
  return loss
284
-
285
  def _get_memory_usage(self) -> Dict[str, float]:
286
  """Get current memory usage statistics."""
287
  memory_stats = {}
288
-
289
- if torch.cuda.is_available() and self.device.startswith('cuda'):
290
- memory_stats['gpu_allocated_mb'] = torch.cuda.memory_allocated() / (1024**2)
291
- memory_stats['gpu_cached_mb'] = torch.cuda.memory_reserved() / (1024**2)
292
-
293
  # Estimate CPU memory (approximate)
294
  import psutil
 
295
  process = psutil.Process()
296
- memory_stats['cpu_memory_mb'] = process.memory_info().rss / (1024**2)
297
-
298
  return memory_stats
299
-
300
  def _log_step(self, step: int, loss: float, lr: float, step_time: float) -> None:
301
  """Log training progress for a single step."""
302
  perplexity = math.exp(min(loss, 10)) # Cap at exp(10) to avoid overflow
303
-
304
  # Calculate tokens per second
305
  tokens_per_batch = self.data_loader.batch_size * self.data_loader.seq_len
306
  tokens_per_second = tokens_per_batch / step_time if step_time > 0 else 0
307
-
308
  # Get memory usage
309
  memory_stats = self._get_memory_usage()
310
-
311
  # Create log entry
312
  log_entry = {
313
- 'step': step,
314
- 'loss': loss,
315
- 'perplexity': perplexity,
316
- 'learning_rate': lr,
317
- 'step_time': step_time,
318
- 'tokens_per_second': tokens_per_second,
319
- 'memory_mb': memory_stats.get('cpu_memory_mb', 0)
320
  }
321
-
322
  self.training_log.append(log_entry)
323
-
324
  # Print progress
325
  elapsed_time = time.time() - self.start_time if self.start_time else 0
326
  eta_seconds = (self.max_steps - step) * step_time if step_time > 0 else 0
327
  eta_hours = eta_seconds / 3600
328
-
329
- print(f"Step {step:,}/{self.max_steps:,} | "
330
- f"Loss: {loss:.4f} | "
331
- f"PPL: {perplexity:.2f} | "
332
- f"LR: {lr:.2e} | "
333
- f"Time: {step_time:.2f}s | "
334
- f"Tokens/s: {tokens_per_second:.1f} | "
335
- f"Memory: {memory_stats.get('cpu_memory_mb', 0):.0f}MB | "
336
- f"ETA: {eta_hours:.1f}h")
337
-
 
 
338
  def _save_checkpoint(self, step: int, is_best: bool = False) -> None:
339
  """Save model checkpoint."""
340
  checkpoint = {
341
- 'step': step,
342
- 'epoch': self.epoch,
343
- 'model_state_dict': self.model.state_dict(),
344
- 'optimizer_state_dict': self.optimizer.state_dict(),
345
- 'scheduler_state_dict': self.scheduler.state_dict(),
346
- 'best_loss': self.best_loss,
347
- 'training_log': self.training_log,
348
- 'config': self.model.config.__dict__
349
  }
350
-
351
  # Save latest checkpoint
352
  checkpoint_path = self.output_dir / f"checkpoint_step_{step}.pt"
353
  torch.save(checkpoint, checkpoint_path)
354
-
355
  # Save best checkpoint
356
  if is_best:
357
  best_path = self.output_dir / "best_model.pt"
358
  torch.save(checkpoint, best_path)
359
  print(f"πŸ’Ύ New best model saved: {best_path}")
360
-
361
  # Save training log
362
  log_path = self.output_dir / "training_log.json"
363
- with open(log_path, 'w') as f:
364
  json.dump(self.training_log, f, indent=2)
365
-
366
  print(f"πŸ’Ύ Checkpoint saved: {checkpoint_path}")
367
-
368
  def _load_checkpoint(self, checkpoint_path: str) -> None:
369
  """Load model checkpoint to resume training."""
370
  if not os.path.exists(checkpoint_path):
371
  print(f"⚠️ Checkpoint not found: {checkpoint_path}")
372
  return
373
-
374
  print(f"πŸ“‚ Loading checkpoint: {checkpoint_path}")
375
-
376
  checkpoint = torch.load(checkpoint_path, map_location=self.device)
377
-
378
- self.model.load_state_dict(checkpoint['model_state_dict'])
379
- self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
380
- self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
381
-
382
- self.step = checkpoint['step']
383
- self.epoch = checkpoint['epoch']
384
- self.best_loss = checkpoint['best_loss']
385
- self.training_log = checkpoint.get('training_log', [])
386
-
387
  print(f"βœ“ Checkpoint loaded successfully")
388
  print(f" Resuming from step: {self.step:,}")
389
  print(f" Best loss so far: {self.best_loss:.4f}")
390
-
391
  def train(self) -> None:
392
  """Main training loop."""
393
  print(f"\nπŸš€ Starting training...")
@@ -396,85 +400,85 @@ class ModelTrainer:
396
  print(f" Device: {self.device}")
397
  print(f" Max steps: {self.max_steps:,}")
398
  print("=" * 80)
399
-
400
  self.model.train()
401
  self.start_time = time.time()
402
-
403
  # Initialize gradient accumulation
404
  accumulated_loss = 0.0
405
  self.optimizer.zero_grad()
406
-
407
  for batch_idx, (input_ids, target_ids) in enumerate(self.data_loader):
408
  if self.step >= self.max_steps:
409
  break
410
-
411
  step_start_time = time.time()
412
-
413
  # Move batch to device
414
  input_ids = input_ids.to(self.device)
415
  target_ids = target_ids.to(self.device)
416
-
417
  # Forward pass (model computes loss internally when targets provided)
418
  logits, loss = self.model(input_ids, target_ids)
419
-
420
  # Scale loss for gradient accumulation
421
  loss = loss / self.gradient_accumulation_steps
422
  accumulated_loss += loss.item()
423
-
424
  # Backward pass
425
  loss.backward()
426
-
427
  # Update weights every gradient_accumulation_steps
428
  if (batch_idx + 1) % self.gradient_accumulation_steps == 0:
429
  # Clip gradients
430
  if self.gradient_clipping > 0:
431
  torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clipping)
432
-
433
  # Update parameters
434
  self.optimizer.step()
435
  self.scheduler.step()
436
  self.optimizer.zero_grad()
437
-
438
  # Update step count
439
  self.step += 1
440
  step_time = time.time() - step_start_time
441
  self.step_times.append(step_time)
442
-
443
  # Get current learning rate
444
  current_lr = self.scheduler.get_last_lr()[0]
445
-
446
  # Log progress
447
  if self.step % self.log_every == 0:
448
  avg_loss = accumulated_loss
449
  self._log_step(self.step, avg_loss, current_lr, step_time)
450
-
451
  # Save checkpoint
452
  if self.step % self.save_every == 0:
453
  is_best = accumulated_loss < self.best_loss
454
  if is_best:
455
  self.best_loss = accumulated_loss
456
-
457
  self._save_checkpoint(self.step, is_best)
458
-
459
  # Clean up memory periodically
460
  if self.step % 100 == 0:
461
  gc.collect()
462
-
463
  # Reset accumulated loss
464
  accumulated_loss = 0.0
465
-
466
  # Check if training complete
467
  if self.step >= self.max_steps:
468
  break
469
-
470
  # Final checkpoint
471
  print(f"\nπŸŽ‰ Training completed!")
472
  self._save_checkpoint(self.step, is_best=True)
473
-
474
  # Training summary
475
  total_time = time.time() - self.start_time
476
  avg_step_time = sum(self.step_times) / len(self.step_times) if self.step_times else 0
477
-
478
  print(f"\nπŸ“Š Training Summary:")
479
  print(f" Steps completed: {self.step:,}")
480
  print(f" Total time: {total_time/3600:.2f} hours")
@@ -504,130 +508,105 @@ Examples:
504
  --batch-size 2 \\
505
  --max-steps 50000 \\
506
  --output-dir models/my-medium-model
507
- """
508
  )
509
-
510
  # Model and data arguments
511
  parser.add_argument(
512
  "--model-size",
513
  choices=["small", "medium", "large"],
514
  default="small",
515
- help="Model size to train (default: small)"
516
  )
517
-
518
  parser.add_argument(
519
  "--data-file",
520
  default="data/clean/training_data.txt",
521
- help="Path to training text file (default: data/clean/training_data.txt)"
522
  )
523
-
524
  parser.add_argument(
525
  "--tokenizer-dir",
526
  default="data/tokenizer/",
527
- help="Path to tokenizer directory (default: data/tokenizer/)"
528
  )
529
-
530
  parser.add_argument(
531
- "--output-dir",
532
- required=True,
533
- help="Output directory for model checkpoints"
534
  )
535
-
536
  # Training hyperparameters
537
  parser.add_argument(
538
- "--seq-len",
539
- type=int,
540
- default=512,
541
- help="Sequence length for training (default: 512)"
542
- )
543
-
544
- parser.add_argument(
545
- "--batch-size",
546
- type=int,
547
- default=4,
548
- help="Batch size (default: 4)"
549
  )
550
-
 
 
551
  parser.add_argument(
552
- "--learning-rate",
553
- type=float,
554
- default=3e-4,
555
- help="Learning rate (default: 3e-4)"
556
  )
557
-
558
  parser.add_argument(
559
- "--max-steps",
560
- type=int,
561
- default=10000,
562
- help="Maximum training steps (default: 10000)"
563
  )
564
-
565
  parser.add_argument(
566
- "--warmup-steps",
567
- type=int,
568
- default=1000,
569
- help="Warmup steps (default: 1000)"
570
  )
571
-
572
  parser.add_argument(
573
  "--gradient-accumulation-steps",
574
  type=int,
575
  default=4,
576
- help="Gradient accumulation steps (default: 4)"
577
  )
578
-
579
  parser.add_argument(
580
  "--device",
581
  choices=["cpu", "cuda", "auto"],
582
  default="auto",
583
- help="Training device (default: auto)"
584
- )
585
-
586
- parser.add_argument(
587
- "--resume",
588
- help="Path to checkpoint to resume training from"
589
  )
590
-
 
 
591
  parser.add_argument(
592
- "--save-every",
593
- type=int,
594
- default=1000,
595
- help="Save checkpoint every N steps (default: 1000)"
596
  )
597
-
598
  args = parser.parse_args()
599
-
600
  print("πŸš€ OpenLLM Model Training")
601
  print("=" * 60)
602
-
603
  # Determine device
604
  if args.device == "auto":
605
  device = "cuda" if torch.cuda.is_available() else "cpu"
606
  else:
607
  device = args.device
608
-
609
  print(f"Using device: {device}")
610
-
611
  try:
612
  # Create model
613
  print(f"\nπŸ—οΈ Creating {args.model_size} model...")
614
  model = create_model(args.model_size)
615
-
616
  # Create data loader
617
  print(f"\nπŸ“Š Setting up data loader...")
618
  tokenizer_path = os.path.join(args.tokenizer_dir, "tokenizer.model")
619
-
620
  data_loader = TextDataLoader(
621
  data_file=args.data_file,
622
  tokenizer_path=tokenizer_path,
623
  seq_len=args.seq_len,
624
  batch_size=args.batch_size,
625
- shuffle=True
626
  )
627
-
628
  # Get data statistics
629
  data_stats = data_loader.get_data_stats()
630
-
631
  # Create trainer
632
  print(f"\n🎯 Setting up trainer...")
633
  trainer = ModelTrainer(
@@ -639,26 +618,27 @@ Examples:
639
  max_steps=args.max_steps,
640
  warmup_steps=args.warmup_steps,
641
  gradient_accumulation_steps=args.gradient_accumulation_steps,
642
- save_every=args.save_every
643
  )
644
-
645
  # Resume from checkpoint if specified
646
  if args.resume:
647
  trainer._load_checkpoint(args.resume)
648
-
649
  # Start training
650
  trainer.train()
651
-
652
  print(f"\nπŸŽ‰ Training completed successfully!")
653
-
654
  except Exception as e:
655
  print(f"\n❌ Training failed: {e}")
656
  import traceback
 
657
  traceback.print_exc()
658
  return False
659
-
660
  return True
661
 
662
 
663
  if __name__ == "__main__":
664
- main()
 
69
  from data_loader import TextDataLoader
70
  except ImportError:
71
  import sys
72
+
73
  sys.path.append(os.path.dirname(__file__))
74
  from model import GPTModel, GPTConfig, create_model
75
  from data_loader import TextDataLoader
 
78
  class ModelTrainer:
79
  """
80
  Comprehensive trainer for GPT-style language models.
81
+
82
  Handles the complete training pipeline including data loading, optimization,
83
  checkpointing, and progress monitoring.
84
  """
85
+
86
  def __init__(
87
  self,
88
  model: GPTModel,
 
97
  gradient_clipping: float = 1.0,
98
  save_every: int = 1000,
99
  eval_every: int = 500,
100
+ log_every: int = 100,
101
  ):
102
  """
103
  Initialize the model trainer.
104
+
105
  Args:
106
  model: GPT model to train
107
  data_loader: Data loader for training data
 
121
  self.data_loader = data_loader
122
  self.output_dir = Path(output_dir)
123
  self.device = device
124
+
125
  # Training hyperparameters
126
  self.learning_rate = learning_rate
127
  self.weight_decay = weight_decay
 
129
  self.max_steps = max_steps
130
  self.gradient_accumulation_steps = gradient_accumulation_steps
131
  self.gradient_clipping = gradient_clipping
132
+
133
  # Logging and saving
134
  self.save_every = save_every
135
  self.eval_every = eval_every
136
  self.log_every = log_every
137
+
138
  # Create output directory
139
  self.output_dir.mkdir(parents=True, exist_ok=True)
140
+
141
  # Initialize optimizer and scheduler
142
  self.optimizer = self._create_optimizer()
143
  self.scheduler = self._create_scheduler()
144
+
145
  # Training state
146
  self.step = 0
147
  self.epoch = 0
148
+ self.best_loss = float("inf")
149
  self.training_log = []
150
+
151
  # Performance tracking
152
  self.start_time = None
153
  self.step_times = []
154
+
155
  print(f"πŸš€ ModelTrainer initialized")
156
  print(f" Device: {device}")
157
  print(f" Model parameters: {model.get_num_params():,}")
 
159
  print(f" Max steps: {max_steps:,}")
160
  print(f" Gradient accumulation: {gradient_accumulation_steps}")
161
  print(f" Output directory: {output_dir}")
162
+
163
  def _create_optimizer(self) -> optim.Optimizer:
164
  """Create AdamW optimizer with weight decay."""
165
  # Separate parameters for weight decay
166
  decay_params = []
167
  no_decay_params = []
168
+
169
  for name, param in self.model.named_parameters():
170
  if not param.requires_grad:
171
  continue
172
+
173
  # Don't apply weight decay to biases and layer norm parameters
174
+ if len(param.shape) == 1 or name.endswith(".bias"):
175
  no_decay_params.append(param)
176
  else:
177
  decay_params.append(param)
178
+
179
  param_groups = [
180
+ {"params": decay_params, "weight_decay": self.weight_decay},
181
+ {"params": no_decay_params, "weight_decay": 0.0},
182
  ]
183
+
184
  # Use AdamW with lower memory usage for CPU
185
  optimizer = optim.AdamW(
186
  param_groups,
187
  lr=self.learning_rate,
188
  betas=(0.9, 0.95), # Slightly different from default for LLM training
189
+ eps=1e-8,
190
  )
191
+
192
  return optimizer
193
+
194
  def _create_scheduler(self) -> torch.optim.lr_scheduler._LRScheduler:
195
  """Create learning rate scheduler with warmup and cosine decay."""
196
  if self.warmup_steps > 0:
 
202
  self.max_steps = max_steps
203
  self.min_lr_factor = min_lr_factor
204
  super().__init__(optimizer)
205
+
206
  def get_lr(self):
207
  if self.last_epoch < self.warmup_steps:
208
+ # Linear warmup
209
  factor = self.last_epoch / self.warmup_steps
210
  return [base_lr * (0.01 + 0.99 * factor) for base_lr in self.base_lrs]
211
  else:
212
  # Cosine decay
213
+ progress = (self.last_epoch - self.warmup_steps) / (
214
+ self.max_steps - self.warmup_steps
215
+ )
216
  progress = min(progress, 1.0) # Clamp to 1.0
217
  factor = 0.5 * (1 + math.cos(math.pi * progress))
218
  factor = self.min_lr_factor + (1 - self.min_lr_factor) * factor
219
  return [base_lr * factor for base_lr in self.base_lrs]
220
+
221
  scheduler = WarmupCosineScheduler(
222
  self.optimizer,
223
  warmup_steps=self.warmup_steps,
224
  max_steps=self.max_steps,
225
+ min_lr_factor=0.1,
226
  )
227
  else:
228
  # Just cosine decay - this should not trigger warnings
229
  scheduler = CosineAnnealingLR(
230
+ self.optimizer, T_max=self.max_steps, eta_min=self.learning_rate * 0.1
 
 
231
  )
232
+
233
  return scheduler
234
+
235
  def _calculate_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
236
  """
237
  Calculate cross-entropy loss for autoregressive language modeling.
238
+
239
  This method computes the standard cross-entropy loss used in language model training.
240
  The loss measures how well the model predicts the next token in the sequence.
241
+
242
  Mathematical formulation:
243
  Loss = -βˆ‘ log(P(target_token | context))
244
  where P is the softmax probability distribution over vocabulary
245
+
246
  Implementation details:
247
  - Reshapes 3D tensors to 2D for efficient computation
248
  - Uses PyTorch's optimized cross_entropy function
249
  - Handles padding tokens by ignoring them in loss calculation
250
  - Computes mean loss across all valid positions
251
+
252
  Why cross-entropy for language modeling:
253
  - Natural choice for multi-class classification (next token prediction)
254
  - Provides strong gradient signal for correct token probabilities
255
  - Mathematically equivalent to minimizing negative log-likelihood
256
  - Well-studied optimization properties for neural language models
257
+
258
  Args:
259
  logits: Raw model predictions of shape (batch_size, seq_len, vocab_size)
260
  Contains unnormalized scores for each token in vocabulary
 
262
  targets: Ground truth next tokens of shape (batch_size, seq_len)
263
  Contains token IDs representing the true next tokens
264
  Should be input sequence shifted by one position
265
+
266
  Returns:
267
  torch.Tensor: Scalar loss value representing prediction error
268
  Lower values indicate better next-token prediction accuracy
 
272
  # where each row represents one prediction over the entire vocabulary
273
  logits = logits.view(-1, logits.size(-1)) # (batch_size * seq_len, vocab_size)
274
  targets = targets.view(-1) # (batch_size * seq_len,)
275
+
276
  # Calculate cross-entropy loss with proper handling of special tokens
277
  # ignore_index=-1 excludes padding tokens from loss calculation
278
  # This prevents the model from learning to predict padding, which would skew training
279
  # The function internally applies softmax to logits and computes negative log-likelihood
280
  loss = nn.functional.cross_entropy(logits, targets, ignore_index=-1)
281
+
282
  # Return scalar loss for backpropagation
283
  # This loss will be used to compute gradients via automatic differentiation
284
  return loss
285
+
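A toy illustration of the loss call above (vocabulary of 5, two positions; the numbers are made up):

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.1, 0.1, 0.1, 0.1],   # position 0 strongly predicts token 0
                       [0.1, 0.1, 2.0, 0.1, 0.1]])  # position 1 strongly predicts token 2
targets = torch.tensor([0, -1])                      # -1 is padding and is ignored
print(F.cross_entropy(logits, targets, ignore_index=-1))  # loss over position 0 only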
286
  def _get_memory_usage(self) -> Dict[str, float]:
287
  """Get current memory usage statistics."""
288
  memory_stats = {}
289
+
290
+ if torch.cuda.is_available() and self.device.startswith("cuda"):
291
+ memory_stats["gpu_allocated_mb"] = torch.cuda.memory_allocated() / (1024**2)
292
+ memory_stats["gpu_cached_mb"] = torch.cuda.memory_reserved() / (1024**2)
293
+
294
  # Estimate CPU memory (approximate)
295
  import psutil
296
+
297
  process = psutil.Process()
298
+ memory_stats["cpu_memory_mb"] = process.memory_info().rss / (1024**2)
299
+
300
  return memory_stats
301
+
302
  def _log_step(self, step: int, loss: float, lr: float, step_time: float) -> None:
303
  """Log training progress for a single step."""
304
  perplexity = math.exp(min(loss, 10)) # Cap at exp(10) to avoid overflow
305
+
306
  # Calculate tokens per second
307
  tokens_per_batch = self.data_loader.batch_size * self.data_loader.seq_len
308
  tokens_per_second = tokens_per_batch / step_time if step_time > 0 else 0
309
+
310
  # Get memory usage
311
  memory_stats = self._get_memory_usage()
312
+
313
  # Create log entry
314
  log_entry = {
315
+ "step": step,
316
+ "loss": loss,
317
+ "perplexity": perplexity,
318
+ "learning_rate": lr,
319
+ "step_time": step_time,
320
+ "tokens_per_second": tokens_per_second,
321
+ "memory_mb": memory_stats.get("cpu_memory_mb", 0),
322
  }
323
+
324
  self.training_log.append(log_entry)
325
+
326
  # Print progress
327
  elapsed_time = time.time() - self.start_time if self.start_time else 0
328
  eta_seconds = (self.max_steps - step) * step_time if step_time > 0 else 0
329
  eta_hours = eta_seconds / 3600
330
+
331
+ print(
332
+ f"Step {step:,}/{self.max_steps:,} | "
333
+ f"Loss: {loss:.4f} | "
334
+ f"PPL: {perplexity:.2f} | "
335
+ f"LR: {lr:.2e} | "
336
+ f"Time: {step_time:.2f}s | "
337
+ f"Tokens/s: {tokens_per_second:.1f} | "
338
+ f"Memory: {memory_stats.get('cpu_memory_mb', 0):.0f}MB | "
339
+ f"ETA: {eta_hours:.1f}h"
340
+ )
341
+
342
  def _save_checkpoint(self, step: int, is_best: bool = False) -> None:
343
  """Save model checkpoint."""
344
  checkpoint = {
345
+ "step": step,
346
+ "epoch": self.epoch,
347
+ "model_state_dict": self.model.state_dict(),
348
+ "optimizer_state_dict": self.optimizer.state_dict(),
349
+ "scheduler_state_dict": self.scheduler.state_dict(),
350
+ "best_loss": self.best_loss,
351
+ "training_log": self.training_log,
352
+ "config": self.model.config.__dict__,
353
  }
354
+
355
  # Save latest checkpoint
356
  checkpoint_path = self.output_dir / f"checkpoint_step_{step}.pt"
357
  torch.save(checkpoint, checkpoint_path)
358
+
359
  # Save best checkpoint
360
  if is_best:
361
  best_path = self.output_dir / "best_model.pt"
362
  torch.save(checkpoint, best_path)
363
  print(f"πŸ’Ύ New best model saved: {best_path}")
364
+
365
  # Save training log
366
  log_path = self.output_dir / "training_log.json"
367
+ with open(log_path, "w") as f:
368
  json.dump(self.training_log, f, indent=2)
369
+
370
  print(f"πŸ’Ύ Checkpoint saved: {checkpoint_path}")
371
+
372
  def _load_checkpoint(self, checkpoint_path: str) -> None:
373
  """Load model checkpoint to resume training."""
374
  if not os.path.exists(checkpoint_path):
375
  print(f"⚠️ Checkpoint not found: {checkpoint_path}")
376
  return
377
+
378
  print(f"πŸ“‚ Loading checkpoint: {checkpoint_path}")
379
+
380
  checkpoint = torch.load(checkpoint_path, map_location=self.device)
381
+
382
+ self.model.load_state_dict(checkpoint["model_state_dict"])
383
+ self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
384
+ self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
385
+
386
+ self.step = checkpoint["step"]
387
+ self.epoch = checkpoint["epoch"]
388
+ self.best_loss = checkpoint["best_loss"]
389
+ self.training_log = checkpoint.get("training_log", [])
390
+
391
  print(f"βœ“ Checkpoint loaded successfully")
392
  print(f" Resuming from step: {self.step:,}")
393
  print(f" Best loss so far: {self.best_loss:.4f}")
394
+
395
  def train(self) -> None:
396
  """Main training loop."""
397
  print(f"\nπŸš€ Starting training...")
 
400
  print(f" Device: {self.device}")
401
  print(f" Max steps: {self.max_steps:,}")
402
  print("=" * 80)
403
+
404
  self.model.train()
405
  self.start_time = time.time()
406
+
407
  # Initialize gradient accumulation
408
  accumulated_loss = 0.0
409
  self.optimizer.zero_grad()
410
+
411
  for batch_idx, (input_ids, target_ids) in enumerate(self.data_loader):
412
  if self.step >= self.max_steps:
413
  break
414
+
415
  step_start_time = time.time()
416
+
417
  # Move batch to device
418
  input_ids = input_ids.to(self.device)
419
  target_ids = target_ids.to(self.device)
420
+
421
  # Forward pass (model computes loss internally when targets provided)
422
  logits, loss = self.model(input_ids, target_ids)
423
+
424
  # Scale loss for gradient accumulation
425
  loss = loss / self.gradient_accumulation_steps
426
  accumulated_loss += loss.item()
427
+
428
  # Backward pass
429
  loss.backward()
430
+
431
  # Update weights every gradient_accumulation_steps
432
  if (batch_idx + 1) % self.gradient_accumulation_steps == 0:
433
  # Clip gradients
434
  if self.gradient_clipping > 0:
435
  torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clipping)
436
+
437
  # Update parameters
438
  self.optimizer.step()
439
  self.scheduler.step()
440
  self.optimizer.zero_grad()
441
+
442
  # Update step count
443
  self.step += 1
444
  step_time = time.time() - step_start_time
445
  self.step_times.append(step_time)
446
+
447
  # Get current learning rate
448
  current_lr = self.scheduler.get_last_lr()[0]
449
+
450
  # Log progress
451
  if self.step % self.log_every == 0:
452
  avg_loss = accumulated_loss
453
  self._log_step(self.step, avg_loss, current_lr, step_time)
454
+
455
  # Save checkpoint
456
  if self.step % self.save_every == 0:
457
  is_best = accumulated_loss < self.best_loss
458
  if is_best:
459
  self.best_loss = accumulated_loss
460
+
461
  self._save_checkpoint(self.step, is_best)
462
+
463
  # Clean up memory periodically
464
  if self.step % 100 == 0:
465
  gc.collect()
466
+
467
  # Reset accumulated loss
468
  accumulated_loss = 0.0
469
+
470
  # Check if training complete
471
  if self.step >= self.max_steps:
472
  break
473
+
474
  # Final checkpoint
475
  print(f"\nπŸŽ‰ Training completed!")
476
  self._save_checkpoint(self.step, is_best=True)
477
+
478
  # Training summary
479
  total_time = time.time() - self.start_time
480
  avg_step_time = sum(self.step_times) / len(self.step_times) if self.step_times else 0
481
+
482
  print(f"\nπŸ“Š Training Summary:")
483
  print(f" Steps completed: {self.step:,}")
484
  print(f" Total time: {total_time/3600:.2f} hours")
 
508
  --batch-size 2 \\
509
  --max-steps 50000 \\
510
  --output-dir models/my-medium-model
511
+ """,
512
  )
513
+
514
  # Model and data arguments
515
  parser.add_argument(
516
  "--model-size",
517
  choices=["small", "medium", "large"],
518
  default="small",
519
+ help="Model size to train (default: small)",
520
  )
521
+
522
  parser.add_argument(
523
  "--data-file",
524
  default="data/clean/training_data.txt",
525
+ help="Path to training text file (default: data/clean/training_data.txt)",
526
  )
527
+
528
  parser.add_argument(
529
  "--tokenizer-dir",
530
  default="data/tokenizer/",
531
+ help="Path to tokenizer directory (default: data/tokenizer/)",
532
  )
533
+
534
  parser.add_argument(
535
+ "--output-dir", required=True, help="Output directory for model checkpoints"
 
 
536
  )
537
+
538
  # Training hyperparameters
539
  parser.add_argument(
540
+ "--seq-len", type=int, default=512, help="Sequence length for training (default: 512)"
 
 
 
 
 
 
 
 
 
 
541
  )
542
+
543
+ parser.add_argument("--batch-size", type=int, default=4, help="Batch size (default: 4)")
544
+
545
  parser.add_argument(
546
+ "--learning-rate", type=float, default=3e-4, help="Learning rate (default: 3e-4)"
 
 
 
547
  )
548
+
549
  parser.add_argument(
550
+ "--max-steps", type=int, default=10000, help="Maximum training steps (default: 10000)"
 
 
 
551
  )
552
+
553
  parser.add_argument(
554
+ "--warmup-steps", type=int, default=1000, help="Warmup steps (default: 1000)"
 
 
 
555
  )
556
+
557
  parser.add_argument(
558
  "--gradient-accumulation-steps",
559
  type=int,
560
  default=4,
561
+ help="Gradient accumulation steps (default: 4)",
562
  )
563
+
564
  parser.add_argument(
565
  "--device",
566
  choices=["cpu", "cuda", "auto"],
567
  default="auto",
568
+ help="Training device (default: auto)",
 
 
 
 
 
569
  )
570
+
571
+ parser.add_argument("--resume", help="Path to checkpoint to resume training from")
572
+
573
  parser.add_argument(
574
+ "--save-every", type=int, default=1000, help="Save checkpoint every N steps (default: 1000)"
 
 
 
575
  )
576
+
577
  args = parser.parse_args()
578
+
579
  print("πŸš€ OpenLLM Model Training")
580
  print("=" * 60)
581
+
582
  # Determine device
583
  if args.device == "auto":
584
  device = "cuda" if torch.cuda.is_available() else "cpu"
585
  else:
586
  device = args.device
587
+
588
  print(f"Using device: {device}")
589
+
590
  try:
591
  # Create model
592
  print(f"\nπŸ—οΈ Creating {args.model_size} model...")
593
  model = create_model(args.model_size)
594
+
595
  # Create data loader
596
  print(f"\nπŸ“Š Setting up data loader...")
597
  tokenizer_path = os.path.join(args.tokenizer_dir, "tokenizer.model")
598
+
599
  data_loader = TextDataLoader(
600
  data_file=args.data_file,
601
  tokenizer_path=tokenizer_path,
602
  seq_len=args.seq_len,
603
  batch_size=args.batch_size,
604
+ shuffle=True,
605
  )
606
+
607
  # Get data statistics
608
  data_stats = data_loader.get_data_stats()
609
+
610
  # Create trainer
611
  print(f"\n🎯 Setting up trainer...")
612
  trainer = ModelTrainer(
 
618
  max_steps=args.max_steps,
619
  warmup_steps=args.warmup_steps,
620
  gradient_accumulation_steps=args.gradient_accumulation_steps,
621
+ save_every=args.save_every,
622
  )
623
+
624
  # Resume from checkpoint if specified
625
  if args.resume:
626
  trainer._load_checkpoint(args.resume)
627
+
628
  # Start training
629
  trainer.train()
630
+
631
  print(f"\nπŸŽ‰ Training completed successfully!")
632
+
633
  except Exception as e:
634
  print(f"\n❌ Training failed: {e}")
635
  import traceback
636
+
637
  traceback.print_exc()
638
  return False
639
+
640
  return True
641
 
642
 
643
  if __name__ == "__main__":
644
+ main()
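A command-line sketch using the flags defined in main() above (paths, output directory, and step counts are illustrative):

python training/train_model.py \
    --model-size small \
    --data-file data/clean/training_data.txt \
    --tokenizer-dir data/tokenizer/ \
    --output-dir models/small-test \
    --batch-size 4 \
    --gradient-accumulation-steps 4 \
    --max-steps 10000 \
    --device auto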
training/train_tokenizer.py CHANGED
@@ -76,46 +76,48 @@ except ImportError:
76
  def validate_input_file(input_path: str) -> None:
77
  """
78
  Validate that the input training file exists and is readable.
79
-
80
  Args:
81
  input_path (str): Path to the training text file
82
-
83
  Raises:
84
  FileNotFoundError: If input file doesn't exist
85
  ValueError: If input file is empty or unreadable
86
  """
87
  if not os.path.exists(input_path):
88
  raise FileNotFoundError(f"Training data file not found: {input_path}")
89
-
90
  # Check file size and readability
91
  file_size = os.path.getsize(input_path)
92
  if file_size == 0:
93
  raise ValueError(f"Training data file is empty: {input_path}")
94
-
95
  # Test that we can read the file
96
  try:
97
- with open(input_path, 'r', encoding='utf-8') as f:
98
  first_line = f.readline()
99
  if not first_line.strip():
100
- raise ValueError(f"Training data file appears to be empty or contains only whitespace")
 
 
101
  except UnicodeDecodeError as e:
102
  raise ValueError(f"Cannot read training data file as UTF-8: {e}")
103
-
104
  print(f"βœ“ Input file validated: {input_path} ({file_size:,} bytes)")
105
 
106
 
107
  def count_training_sentences(input_path: str) -> int:
108
  """
109
  Count the number of training sentences/lines in the input file.
110
-
111
  Args:
112
  input_path (str): Path to the training text file
113
-
114
  Returns:
115
  int: Number of lines in the file
116
  """
117
  print("Counting training sentences...")
118
- with open(input_path, 'r', encoding='utf-8') as f:
119
  count = sum(1 for line in f if line.strip())
120
  print(f"βœ“ Found {count:,} training sentences")
121
  return count
@@ -133,7 +135,7 @@ def train_sentencepiece_tokenizer(
133
  ) -> Dict[str, Any]:
134
  """
135
  Train a SentencePiece tokenizer with the specified parameters.
136
-
137
  Args:
138
  input_path (str): Path to training text file
139
  output_dir (str): Directory to save tokenizer files
@@ -143,16 +145,16 @@ def train_sentencepiece_tokenizer(
143
  max_sentence_length (int): Maximum sentence length in characters
144
  input_sentence_size (int): Maximum number of sentences to use for training
145
  shuffle_input_sentence (bool): Whether to shuffle input sentences
146
-
147
  Returns:
148
  Dict[str, Any]: Training statistics and configuration
149
  """
150
  # Ensure output directory exists
151
  os.makedirs(output_dir, exist_ok=True)
152
-
153
  # Define output paths
154
  model_prefix = os.path.join(output_dir, "tokenizer")
155
-
156
  # SentencePiece training parameters
157
  train_params = [
158
  f"--input={input_path}",
@@ -163,30 +165,28 @@ def train_sentencepiece_tokenizer(
163
  f"--max_sentence_length={max_sentence_length}",
164
  f"--input_sentence_size={input_sentence_size}",
165
  f"--shuffle_input_sentence={shuffle_input_sentence}",
166
-
167
  # Special tokens for language modeling
168
- "--pad_id=0", # Padding token
169
- "--unk_id=1", # Unknown token
170
- "--bos_id=2", # Beginning of sequence
171
- "--eos_id=3", # End of sequence
172
-
173
  # Additional useful parameters
174
- "--split_by_unicode_script=true", # Better handling of mixed scripts
175
- "--split_by_whitespace=true", # Split on whitespace
176
- "--remove_extra_whitespaces=true", # Clean up whitespace
177
- "--normalization_rule_name=identity", # Keep original text as-is
178
  ]
179
-
180
  print(f"\nTraining SentencePiece tokenizer...")
181
  print(f" Algorithm: {model_type.upper()}")
182
  print(f" Vocabulary size: {vocab_size:,}")
183
  print(f" Character coverage: {character_coverage}")
184
  print(f" Output directory: {output_dir}")
185
  print(f" Model files: {model_prefix}.model, {model_prefix}.vocab")
186
-
187
  # Record training start time
188
  start_time = time.time()
189
-
190
  # Train the tokenizer
191
  try:
192
  spm.SentencePieceTrainer.train(" ".join(train_params))
@@ -194,19 +194,19 @@ def train_sentencepiece_tokenizer(
194
  print(f"βœ“ Tokenizer training completed in {training_time:.1f} seconds")
195
  except Exception as e:
196
  raise RuntimeError(f"SentencePiece training failed: {e}")
197
-
198
  # Verify output files were created
199
  model_file = f"{model_prefix}.model"
200
  vocab_file = f"{model_prefix}.vocab"
201
-
202
  if not os.path.exists(model_file):
203
  raise RuntimeError(f"Expected model file not created: {model_file}")
204
  if not os.path.exists(vocab_file):
205
  raise RuntimeError(f"Expected vocab file not created: {vocab_file}")
206
-
207
  print(f"βœ“ Model file created: {model_file} ({os.path.getsize(model_file):,} bytes)")
208
  print(f"βœ“ Vocab file created: {vocab_file} ({os.path.getsize(vocab_file):,} bytes)")
209
-
210
  # Return training configuration and statistics
211
  config = {
212
  "model_type": model_type,
@@ -219,24 +219,24 @@ def train_sentencepiece_tokenizer(
219
  "model_file": model_file,
220
  "vocab_file": vocab_file,
221
  }
222
-
223
  return config
224
 
225
 
226
  def test_tokenizer(model_path: str, test_sentences: list = None) -> None:
227
  """
228
  Test the trained tokenizer on sample sentences to verify it works correctly.
229
-
230
  Args:
231
  model_path (str): Path to the trained .model file
232
  test_sentences (list): Optional list of test sentences
233
  """
234
  print(f"\nTesting trained tokenizer...")
235
-
236
  # Load the trained tokenizer
237
  sp = spm.SentencePieceProcessor()
238
  sp.load(model_path)
239
-
240
  # Default test sentences if none provided
241
  if test_sentences is None:
242
  test_sentences = [
@@ -245,63 +245,65 @@ def test_tokenizer(model_path: str, test_sentences: list = None) -> None:
245
  "Machine learning and artificial intelligence are transforming technology.",
246
  "SentencePiece tokenization works well for language models.",
247
  ]
248
-
249
  print(f"Vocabulary size: {sp.vocab_size():,}")
250
- print(f"Special tokens: PAD={sp.pad_id()}, UNK={sp.unk_id()}, BOS={sp.bos_id()}, EOS={sp.eos_id()}")
251
-
 
 
252
  print("\nTokenization examples:")
253
  for i, sentence in enumerate(test_sentences, 1):
254
  # Encode to token IDs and pieces
255
  token_ids = sp.encode(sentence)
256
  token_pieces = sp.encode(sentence, out_type=str)
257
-
258
  print(f"\n{i}. Input: {sentence}")
259
  print(f" Tokens ({len(token_pieces)}): {token_pieces}")
260
  print(f" IDs: {token_ids[:10]}{'...' if len(token_ids) > 10 else ''}")
261
-
262
  # Test decoding
263
  decoded = sp.decode(token_ids)
264
  print(f" Decoded: {decoded}")
265
-
266
  # Verify round-trip encoding/decoding
267
  if decoded.strip() != sentence.strip():
268
  print(f" ⚠️ Warning: Decode mismatch!")
269
-
270
  print("βœ“ Tokenizer testing completed")
271
 
272
 
273
  def save_huggingface_config(output_dir: str, config: Dict[str, Any]) -> None:
274
  """
275
  Save a Hugging Face compatible tokenizer configuration file.
276
-
277
  Args:
278
  output_dir (str): Directory containing the tokenizer files
279
  config (Dict[str, Any]): Tokenizer configuration
280
  """
281
  # Create Hugging Face tokenizer config
282
  hf_config = {
283
- "tokenizer_class": "SentencePieceTokenizer",
284
  "model_type": config["model_type"],
285
  "vocab_size": config["vocab_size"],
286
  "model_file": "tokenizer.model",
287
  "special_tokens": {
288
  "pad_token": "<pad>",
289
- "unk_token": "<unk>",
290
  "bos_token": "<s>",
291
  "eos_token": "</s>",
292
  },
293
  "special_token_ids": {
294
  "pad_token_id": 0,
295
  "unk_token_id": 1,
296
- "bos_token_id": 2,
297
  "eos_token_id": 3,
298
- }
299
  }
300
-
301
  config_path = os.path.join(output_dir, "tokenizer_config.json")
302
- with open(config_path, 'w', encoding='utf-8') as f:
303
  json.dump(hf_config, f, indent=2, ensure_ascii=False)
304
-
305
  print(f"βœ“ Hugging Face config saved: {config_path}")
306
 
307
 
@@ -322,69 +324,67 @@ Examples:
322
  --model_type bpe \\
323
  --output_dir data/tokenizer/ \\
324
  --character_coverage 0.9995
325
- """
326
  )
327
-
328
  # Required arguments
329
  parser.add_argument(
330
- "--input",
331
- required=True,
332
- help="Path to training text file (e.g., data/clean/training_data.txt)"
333
  )
334
-
335
  # Optional arguments with sensible defaults
336
  parser.add_argument(
337
- "--vocab_size",
338
- type=int,
339
  default=32000,
340
- help="Vocabulary size (default: 32000, recommended: 8k-64k)"
341
  )
342
-
343
  parser.add_argument(
344
- "--model_type",
345
- choices=["bpe", "unigram"],
346
  default="bpe",
347
- help="Tokenization algorithm (default: bpe)"
348
  )
349
-
350
  parser.add_argument(
351
- "--output_dir",
352
  default="data/tokenizer/",
353
- help="Output directory for tokenizer files (default: data/tokenizer/)"
354
  )
355
-
356
  parser.add_argument(
357
- "--character_coverage",
358
- type=float,
359
  default=0.9995,
360
- help="Character coverage (default: 0.9995 for English)"
361
  )
362
-
363
  parser.add_argument(
364
- "--max_sentence_length",
365
- type=int,
366
  default=4192,
367
- help="Maximum sentence length in characters (default: 4192)"
368
  )
369
-
370
  parser.add_argument(
371
- "--no_test",
372
- action="store_true",
373
- help="Skip tokenizer testing after training"
374
  )
375
-
376
  args = parser.parse_args()
377
-
378
  print("πŸ”€ SentencePiece Tokenizer Training")
379
  print("=" * 50)
380
-
381
  try:
382
  # Step 1: Validate input file
383
  validate_input_file(args.input)
384
-
385
  # Step 2: Count training data
386
  sentence_count = count_training_sentences(args.input)
387
-
388
  # Step 3: Train tokenizer
389
  config = train_sentencepiece_tokenizer(
390
  input_path=args.input,
@@ -394,36 +394,36 @@ Examples:
394
  character_coverage=args.character_coverage,
395
  max_sentence_length=args.max_sentence_length,
396
  )
397
-
398
  # Step 4: Save Hugging Face compatible config
399
  save_huggingface_config(args.output_dir, config)
400
-
401
  # Step 5: Test tokenizer (unless skipped)
402
  if not args.no_test:
403
  model_path = os.path.join(args.output_dir, "tokenizer.model")
404
  test_tokenizer(model_path)
405
-
406
  # Step 6: Print summary
407
  print(f"\nπŸŽ‰ Tokenizer training completed successfully!")
408
  print(f"πŸ“ Output directory: {args.output_dir}")
409
  print(f"πŸ“Š Vocabulary size: {config['vocab_size']:,}")
410
  print(f"⏱️ Training time: {config['training_time_seconds']:.1f}s")
411
  print(f"πŸ“„ Training sentences: {sentence_count:,}")
412
-
413
  print(f"\nFiles created:")
414
  print(f" β€’ {config['model_file']} - SentencePiece model")
415
- print(f" β€’ {config['vocab_file']} - Vocabulary file")
416
  print(f" β€’ {os.path.join(args.output_dir, 'tokenizer_config.json')} - Hugging Face config")
417
-
418
  print(f"\nTo use this tokenizer in your language model:")
419
  print(f" import sentencepiece as spm")
420
  print(f" sp = spm.SentencePieceProcessor()")
421
  print(f" sp.load('{config['model_file']}')")
422
-
423
  except Exception as e:
424
  print(f"\n❌ Error: {e}")
425
  exit(1)
426
 
427
 
428
  if __name__ == "__main__":
429
- main()
 
76
  def validate_input_file(input_path: str) -> None:
77
  """
78
  Validate that the input training file exists and is readable.
79
+
80
  Args:
81
  input_path (str): Path to the training text file
82
+
83
  Raises:
84
  FileNotFoundError: If input file doesn't exist
85
  ValueError: If input file is empty or unreadable
86
  """
87
  if not os.path.exists(input_path):
88
  raise FileNotFoundError(f"Training data file not found: {input_path}")
89
+
90
  # Check file size and readability
91
  file_size = os.path.getsize(input_path)
92
  if file_size == 0:
93
  raise ValueError(f"Training data file is empty: {input_path}")
94
+
95
  # Test that we can read the file
96
  try:
97
+ with open(input_path, "r", encoding="utf-8") as f:
98
  first_line = f.readline()
99
  if not first_line.strip():
100
+ raise ValueError(
101
+ f"Training data file appears to be empty or contains only whitespace"
102
+ )
103
  except UnicodeDecodeError as e:
104
  raise ValueError(f"Cannot read training data file as UTF-8: {e}")
105
+
106
  print(f"βœ“ Input file validated: {input_path} ({file_size:,} bytes)")
107
 
108
 
109
  def count_training_sentences(input_path: str) -> int:
110
  """
111
  Count the number of training sentences/lines in the input file.
112
+
113
  Args:
114
  input_path (str): Path to the training text file
115
+
116
  Returns:
117
  int: Number of lines in the file
118
  """
119
  print("Counting training sentences...")
120
+ with open(input_path, "r", encoding="utf-8") as f:
121
  count = sum(1 for line in f if line.strip())
122
  print(f"βœ“ Found {count:,} training sentences")
123
  return count
 
135
  ) -> Dict[str, Any]:
136
  """
137
  Train a SentencePiece tokenizer with the specified parameters.
138
+
139
  Args:
140
  input_path (str): Path to training text file
141
  output_dir (str): Directory to save tokenizer files
 
145
  max_sentence_length (int): Maximum sentence length in characters
146
  input_sentence_size (int): Maximum number of sentences to use for training
147
  shuffle_input_sentence (bool): Whether to shuffle input sentences
148
+
149
  Returns:
150
  Dict[str, Any]: Training statistics and configuration
151
  """
152
  # Ensure output directory exists
153
  os.makedirs(output_dir, exist_ok=True)
154
+
155
  # Define output paths
156
  model_prefix = os.path.join(output_dir, "tokenizer")
157
+
158
  # SentencePiece training parameters
159
  train_params = [
160
  f"--input={input_path}",
 
165
  f"--max_sentence_length={max_sentence_length}",
166
  f"--input_sentence_size={input_sentence_size}",
167
  f"--shuffle_input_sentence={shuffle_input_sentence}",
 
168
  # Special tokens for language modeling
169
+ "--pad_id=0", # Padding token
170
+ "--unk_id=1", # Unknown token
171
+ "--bos_id=2", # Beginning of sequence
172
+ "--eos_id=3", # End of sequence
 
173
  # Additional useful parameters
174
+ "--split_by_unicode_script=true", # Better handling of mixed scripts
175
+ "--split_by_whitespace=true", # Split on whitespace
176
+ "--remove_extra_whitespaces=true", # Clean up whitespace
177
+ "--normalization_rule_name=identity", # Keep original text as-is
178
  ]
179
+
180
  print(f"\nTraining SentencePiece tokenizer...")
181
  print(f" Algorithm: {model_type.upper()}")
182
  print(f" Vocabulary size: {vocab_size:,}")
183
  print(f" Character coverage: {character_coverage}")
184
  print(f" Output directory: {output_dir}")
185
  print(f" Model files: {model_prefix}.model, {model_prefix}.vocab")
186
+
187
  # Record training start time
188
  start_time = time.time()
189
+
190
  # Train the tokenizer
191
  try:
192
  spm.SentencePieceTrainer.train(" ".join(train_params))
 
194
  print(f"βœ“ Tokenizer training completed in {training_time:.1f} seconds")
195
  except Exception as e:
196
  raise RuntimeError(f"SentencePiece training failed: {e}")
197
+
198
  # Verify output files were created
199
  model_file = f"{model_prefix}.model"
200
  vocab_file = f"{model_prefix}.vocab"
201
+
202
  if not os.path.exists(model_file):
203
  raise RuntimeError(f"Expected model file not created: {model_file}")
204
  if not os.path.exists(vocab_file):
205
  raise RuntimeError(f"Expected vocab file not created: {vocab_file}")
206
+
207
  print(f"βœ“ Model file created: {model_file} ({os.path.getsize(model_file):,} bytes)")
208
  print(f"βœ“ Vocab file created: {vocab_file} ({os.path.getsize(vocab_file):,} bytes)")
209
+
210
  # Return training configuration and statistics
211
  config = {
212
  "model_type": model_type,
 
219
  "model_file": model_file,
220
  "vocab_file": vocab_file,
221
  }
222
+
223
  return config


  def test_tokenizer(model_path: str, test_sentences: list = None) -> None:
      """
      Test the trained tokenizer on sample sentences to verify it works correctly.
+
      Args:
          model_path (str): Path to the trained .model file
          test_sentences (list): Optional list of test sentences
      """
      print("\nTesting trained tokenizer...")
+
      # Load the trained tokenizer
      sp = spm.SentencePieceProcessor()
      sp.load(model_path)
+
      # Default test sentences if none provided
      if test_sentences is None:
          test_sentences = [

              "Machine learning and artificial intelligence are transforming technology.",
              "SentencePiece tokenization works well for language models.",
          ]
+
      print(f"Vocabulary size: {sp.vocab_size():,}")
+     print(
+         f"Special tokens: PAD={sp.pad_id()}, UNK={sp.unk_id()}, BOS={sp.bos_id()}, EOS={sp.eos_id()}"
+     )
+
      print("\nTokenization examples:")
      for i, sentence in enumerate(test_sentences, 1):
          # Encode to token IDs and pieces
          token_ids = sp.encode(sentence)
          token_pieces = sp.encode(sentence, out_type=str)
+
          print(f"\n{i}. Input: {sentence}")
          print(f" Tokens ({len(token_pieces)}): {token_pieces}")
          print(f" IDs: {token_ids[:10]}{'...' if len(token_ids) > 10 else ''}")
+
          # Test decoding
          decoded = sp.decode(token_ids)
          print(f" Decoded: {decoded}")
+
          # Verify round-trip encoding/decoding
          if decoded.strip() != sentence.strip():
              print(" ⚠️ Warning: Decode mismatch!")
+
      print("✓ Tokenizer testing completed")


  def save_huggingface_config(output_dir: str, config: Dict[str, Any]) -> None:
      """
      Save a Hugging Face compatible tokenizer configuration file.
+
      Args:
          output_dir (str): Directory containing the tokenizer files
          config (Dict[str, Any]): Tokenizer configuration
      """
      # Create Hugging Face tokenizer config
      hf_config = {
+         "tokenizer_class": "SentencePieceTokenizer",
          "model_type": config["model_type"],
          "vocab_size": config["vocab_size"],
          "model_file": "tokenizer.model",
          "special_tokens": {
              "pad_token": "<pad>",
+             "unk_token": "<unk>",
              "bos_token": "<s>",
              "eos_token": "</s>",
          },
          "special_token_ids": {
              "pad_token_id": 0,
              "unk_token_id": 1,
+             "bos_token_id": 2,
              "eos_token_id": 3,
+         },
      }
+
      config_path = os.path.join(output_dir, "tokenizer_config.json")
+     with open(config_path, "w", encoding="utf-8") as f:
          json.dump(hf_config, f, indent=2, ensure_ascii=False)
+
      print(f"✓ Hugging Face config saved: {config_path}")


          --model_type bpe \\
          --output_dir data/tokenizer/ \\
          --character_coverage 0.9995
+     """,
      )
+
      # Required arguments
      parser.add_argument(
+         "--input",
+         required=True,
+         help="Path to training text file (e.g., data/clean/training_data.txt)",
      )
+
      # Optional arguments with sensible defaults
      parser.add_argument(
+         "--vocab_size",
+         type=int,
          default=32000,
+         help="Vocabulary size (default: 32000, recommended: 8k-64k)",
      )
+
      parser.add_argument(
+         "--model_type",
+         choices=["bpe", "unigram"],
          default="bpe",
+         help="Tokenization algorithm (default: bpe)",
      )
+
      parser.add_argument(
+         "--output_dir",
          default="data/tokenizer/",
+         help="Output directory for tokenizer files (default: data/tokenizer/)",
      )
+
      parser.add_argument(
+         "--character_coverage",
+         type=float,
          default=0.9995,
+         help="Character coverage (default: 0.9995 for English)",
      )
+
      parser.add_argument(
+         "--max_sentence_length",
+         type=int,
          default=4192,
+         help="Maximum sentence length in characters (default: 4192)",
      )
+
      parser.add_argument(
+         "--no_test", action="store_true", help="Skip tokenizer testing after training"
      )
+
      args = parser.parse_args()
+
      print("🔀 SentencePiece Tokenizer Training")
      print("=" * 50)
+
      try:
          # Step 1: Validate input file
          validate_input_file(args.input)
+
          # Step 2: Count training data
          sentence_count = count_training_sentences(args.input)
+
          # Step 3: Train tokenizer
          config = train_sentencepiece_tokenizer(
              input_path=args.input,

              character_coverage=args.character_coverage,
              max_sentence_length=args.max_sentence_length,
          )
+
          # Step 4: Save Hugging Face compatible config
          save_huggingface_config(args.output_dir, config)
+
          # Step 5: Test tokenizer (unless skipped)
          if not args.no_test:
              model_path = os.path.join(args.output_dir, "tokenizer.model")
              test_tokenizer(model_path)
+
          # Step 6: Print summary
          print("\n🎉 Tokenizer training completed successfully!")
          print(f"📁 Output directory: {args.output_dir}")
          print(f"📊 Vocabulary size: {config['vocab_size']:,}")
          print(f"⏱️ Training time: {config['training_time_seconds']:.1f}s")
          print(f"📄 Training sentences: {sentence_count:,}")
+
          print("\nFiles created:")
          print(f" • {config['model_file']} - SentencePiece model")
+         print(f" • {config['vocab_file']} - Vocabulary file")
          print(f" • {os.path.join(args.output_dir, 'tokenizer_config.json')} - Hugging Face config")
+
          print("\nTo use this tokenizer in your language model:")
          print(" import sentencepiece as spm")
          print(" sp = spm.SentencePieceProcessor()")
          print(f" sp.load('{config['model_file']}')")
+
      except Exception as e:
          print(f"\n❌ Error: {e}")
          exit(1)


  if __name__ == "__main__":
+     main()
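Closing note for consumers of these artifacts: language-model training pipelines typically wrap each example with the BOS/EOS ids configured above before batching. A hedged sketch (model path is illustrative, not part of the diff):

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/tokenizer/tokenizer.model")  # illustrative path

# Frame one training example with the ids set via --bos_id=2 / --eos_id=3
text = "Machine learning and artificial intelligence are transforming technology."
ids = [sp.bos_id()] + sp.encode(text) + [sp.eos_id()]
assert ids[0] == 2 and ids[-1] == 3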