Upload folder using huggingface_hub
- core/src/__init__.py +16 -0
- core/src/data_loader.py +493 -0
- core/src/download_and_prepare.py +243 -0
- core/src/enterprise_integration.py +139 -0
- core/src/evaluate_model.py +767 -0
- core/src/export_model.py +727 -0
- core/src/generate_text.py +866 -0
- core/src/inference_server.py +907 -0
- core/src/main.py +842 -0
- core/src/mixed_precision.py +220 -0
- core/src/model.py +665 -0
- core/src/model_test.py +564 -0
- core/src/optimized_data_loader.py +437 -0
- core/src/optimized_inference_server.py +739 -0
- core/src/performance_monitor.py +543 -0
- core/src/quantization.py +286 -0
- core/src/train_model.py +668 -0
- core/src/train_tokenizer.py +428 -0
core/src/__init__.py
ADDED
@@ -0,0 +1,16 @@
# Core source package for OpenLLM
# This file makes the core/src directory a Python package

"""
OpenLLM Core Source Package

This package contains the core implementation of the OpenLLM language model,
including model architecture, training, inference, and data processing components.

Author: Louis Chua Bean Chong
License: GPLv3
"""

__version__ = "1.0.0"
__author__ = "Louis Chua Bean Chong"
__license__ = "GPLv3"
core/src/data_loader.py
ADDED
@@ -0,0 +1,493 @@
#!/usr/bin/env python3
# Copyright (C) 2024 Louis Chua Bean Chong
#
# This file is part of OpenLLM.
#
# OpenLLM is dual-licensed:
# 1. For open source use: GNU General Public License v3.0
# 2. For commercial use: Commercial License (contact for details)
#
# See LICENSE and docs/LICENSES.md for full license information.

"""
Training Data Loader for Language Model Training

This module provides efficient data loading and batching for training GPT-style
language models. It handles text preprocessing, tokenization, and creates
batches suitable for autoregressive language modeling.

FEATURES:
- Memory-efficient text loading with sliding window
- Automatic tokenization using trained SentencePiece model
- Configurable sequence length and batch size
- CPU-optimized data loading for limited hardware
- Support for training data validation and statistics

MEMORY OPTIMIZATION:
- Streaming data loading (doesn't load entire dataset to memory)
- Configurable chunk sizes for large files
- Efficient tensor creation and batching
- Garbage collection hints for memory management

Usage:
    from data_loader import TextDataLoader

    loader = TextDataLoader(
        data_file="data/clean/training_data.txt",
        tokenizer_path="data/tokenizer/tokenizer.model",
        seq_len=512,
        batch_size=4
    )

    for batch in loader:
        input_ids, targets = batch
        # input_ids: (batch_size, seq_len)
        # targets: (batch_size, seq_len) - shifted by 1 for next token prediction

Author: Louis Chua Bean Chong
License: GPLv3
"""

import gc
import os
import random
import time
from typing import Iterator, List, Tuple

import torch

try:
    import sentencepiece as spm
except ImportError:
    print("ERROR: SentencePiece not installed. Run: pip install sentencepiece")
    exit(1)


class TextDataLoader:
    """
    Efficient data loader for autoregressive language model training.

    This class handles loading text data, tokenizing it using SentencePiece,
    and creating batches suitable for next-token prediction training.
    """

    def __init__(
        self,
        data_file: str,
        tokenizer_path: str,
        seq_len: int = 512,
        batch_size: int = 4,
        chunk_size: int = 1000000,  # Lines to read at once
        shuffle: bool = True,
        seed: int = 42,
    ):
        """
        Initialize the data loader.

        Args:
            data_file: Path to training text file (one passage per line)
            tokenizer_path: Path to trained SentencePiece model
            seq_len: Maximum sequence length for training
            batch_size: Batch size for training
            chunk_size: Number of lines to read in memory at once
            shuffle: Whether to shuffle training examples
            seed: Random seed for reproducibility
        """
        self.data_file = data_file
        self.tokenizer_path = tokenizer_path
        self.seq_len = seq_len
        self.batch_size = batch_size
        self.chunk_size = chunk_size
        self.shuffle = shuffle
        self.seed = seed

        # Validate inputs
        self._validate_inputs()

        # Load tokenizer
        self.tokenizer = self._load_tokenizer()

        # Get data statistics
        self.total_lines = self._count_lines()
        self.current_line = 0

        # Initialize data attribute for testing compatibility
        # Load a small sample of data for testing purposes
        self.data = self._read_chunk(
            0, min(self.chunk_size, 100)
        )  # Load up to 100 passages for testing

        # Set random seed for reproducibility
        random.seed(seed)

        print("TextDataLoader initialized")
        print(f"  Data file: {data_file}")
        print(f"  Total passages: {self.total_lines:,}")
        print(f"  Sequence length: {seq_len}")
        print(f"  Batch size: {batch_size}")
        print(f"  Vocabulary size: {self.tokenizer.vocab_size():,}")

    def _validate_inputs(self) -> None:
        """Validate input parameters and file paths."""
        if not os.path.exists(self.data_file):
            raise FileNotFoundError(f"Training data file not found: {self.data_file}")

        if not os.path.exists(self.tokenizer_path):
            raise FileNotFoundError(f"Tokenizer model not found: {self.tokenizer_path}")

        if self.seq_len <= 0:
            raise ValueError(f"Sequence length must be positive, got {self.seq_len}")

        if self.batch_size <= 0:
            raise ValueError(f"Batch size must be positive, got {self.batch_size}")

        if self.chunk_size <= 0:
            raise ValueError(f"Chunk size must be positive, got {self.chunk_size}")

    def _load_tokenizer(self) -> spm.SentencePieceProcessor:
        """Load the trained SentencePiece tokenizer."""
        try:
            tokenizer = spm.SentencePieceProcessor()
            tokenizer.load(self.tokenizer_path)
            return tokenizer
        except Exception as e:
            raise RuntimeError(f"Failed to load tokenizer: {e}")

    def _count_lines(self) -> int:
        """Count total number of lines in the data file."""
        print("Counting training passages...")
        start_time = time.time()

        line_count = 0
        with open(self.data_file, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():  # Only count non-empty lines
                    line_count += 1

        count_time = time.time() - start_time
        print(f"Found {line_count:,} passages in {count_time:.1f}s")

        return line_count

    def _read_chunk(self, start_line: int = 0, limit: int = None) -> List[str]:
        """
        Read a chunk of lines from the data file.

        Args:
            start_line: Line number to start reading from
            limit: Maximum number of lines to read (None for default chunk_size)

        Returns:
            List of text passages
        """
        chunk = []
        current_line = 0
        lines_read = 0
        max_lines = limit if limit is not None else self.chunk_size

        with open(self.data_file, "r", encoding="utf-8") as f:
            for line in f:
                if current_line < start_line:
                    current_line += 1
                    continue

                text = line.strip()
                if text:  # Only include non-empty lines
                    chunk.append(text)
                    lines_read += 1

                    if lines_read >= max_lines:
                        break

                current_line += 1

        return chunk

    def _tokenize_texts(self, texts: List[str]) -> List[List[int]]:
        """
        Tokenize a list of text passages using SentencePiece tokenizer.

        This method converts raw text into token ID sequences suitable for language model training.
        It handles special tokens (BOS/EOS) and length constraints for efficient training.

        Text processing pipeline:
        1. Add BOS (Beginning of Sequence) token to mark sequence start
        2. Tokenize text using trained SentencePiece model (subword tokenization)
        3. Truncate sequences that exceed maximum length
        4. Add EOS (End of Sequence) token to mark sequence end

        Special token handling:
        - BOS token helps model learn to generate text from scratch
        - EOS token signals natural sequence endings
        - These tokens are crucial for proper autoregressive generation

        Args:
            texts: List of text passages (typically Wikipedia passages from SQUAD)
                   Each passage should be a complete, coherent text segment

        Returns:
            List of token ID sequences, where each sequence is a list of integers
            representing subword tokens from the SentencePiece vocabulary
        """
        tokenized = []

        for text in texts:
            try:
                # Add BOS (Beginning of Sequence) token at the start
                # BOS token ID=2 by default in SentencePiece, signals sequence start
                # This helps the model learn proper sequence initialization during generation
                tokens = [self.tokenizer.bos_id()] + self.tokenizer.encode(text)

                # Truncate sequences that exceed maximum context length
                # Reserve one position for EOS token by using (seq_len - 1)
                # This ensures we never exceed the model's context window during training
                if len(tokens) > self.seq_len - 1:
                    tokens = tokens[: self.seq_len - 1]
                    # NOTE: Truncation may cut off text mid-sentence, but this is acceptable
                    # for language modeling where the model learns from partial contexts

                # Add EOS (End of Sequence) token at the end
                # EOS token ID=1 by default in SentencePiece, signals sequence completion
                # This teaches the model when to stop generating text naturally
                tokens.append(self.tokenizer.eos_id())

                # Validate tokenization result
                if len(tokens) <= 2:  # Only BOS + EOS tokens, no actual content
                    print(f"Skipping very short text: {text[:50]}...")
                    continue

                tokenized.append(tokens)

            except Exception as e:
                # Handle tokenization errors gracefully to avoid stopping training
                # Common causes: encoding issues, very long texts, special characters
                print(f"Failed to tokenize passage: {text[:50]}... Error: {e}")
                continue

        # Log tokenization statistics for monitoring
        if tokenized:
            avg_length = sum(len(tokens) for tokens in tokenized) / len(tokenized)
            print(f"Tokenized {len(tokenized)} passages, avg length: {avg_length:.1f} tokens")

        return tokenized

    def _create_training_examples(
        self, token_sequences: List[List[int]]
    ) -> List[Tuple[List[int], List[int]]]:
        """
        Create training examples with input and target sequences.

        For autoregressive training, targets are inputs shifted by one position.

        Args:
            token_sequences: List of tokenized sequences

        Returns:
            List of (input_ids, target_ids) tuples
        """
        examples = []

        for tokens in token_sequences:
            if len(tokens) < 2:  # Need at least 2 tokens for input/target pair
                continue

            # For sequences longer than seq_len, create multiple examples with sliding window
            if len(tokens) > self.seq_len:
                # Create overlapping windows (50% overlap for better learning)
                stride = self.seq_len // 2
                for i in range(0, len(tokens) - self.seq_len, stride):
                    input_ids = tokens[i : i + self.seq_len]
                    target_ids = tokens[i + 1 : i + self.seq_len + 1]
                    examples.append((input_ids, target_ids))
            else:
                # Pad shorter sequences
                input_ids = tokens[:-1]  # All but last token
                target_ids = tokens[1:]  # All but first token

                # Pad to seq_len if necessary
                while len(input_ids) < self.seq_len:
                    input_ids.append(self.tokenizer.pad_id())
                    target_ids.append(-1)  # Use -1 for padding in targets (ignored in loss)

                # Truncate if still too long
                input_ids = input_ids[: self.seq_len]
                target_ids = target_ids[: self.seq_len]

                examples.append((input_ids, target_ids))

        return examples

    def _create_batch(
        self, examples: List[Tuple[List[int], List[int]]]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Create a batch tensor from training examples.

        Args:
            examples: List of (input_ids, target_ids) tuples

        Returns:
            Tuple of (input_tensor, target_tensor)
        """
        if not examples:
            raise ValueError("Cannot create batch from empty examples")

        batch_size = len(examples)

        # Initialize tensors
        input_ids = torch.zeros((batch_size, self.seq_len), dtype=torch.long)
        target_ids = torch.full((batch_size, self.seq_len), -1, dtype=torch.long)

        # Fill tensors
        for i, (inp, tgt) in enumerate(examples):
            input_ids[i, : len(inp)] = torch.tensor(inp, dtype=torch.long)
            target_ids[i, : len(tgt)] = torch.tensor(tgt, dtype=torch.long)

        return input_ids, target_ids

    def __iter__(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Iterate over training batches.

        Yields:
            Tuple of (input_ids, target_ids) tensors
        """
        self.current_line = 0

        while self.current_line < self.total_lines:
            # Read chunk of text
            texts = self._read_chunk(self.current_line)
            if not texts:
                break

            # Tokenize texts
            token_sequences = self._tokenize_texts(texts)

            # Create training examples
            examples = self._create_training_examples(token_sequences)

            # Shuffle examples if requested
            if self.shuffle:
                random.shuffle(examples)

            # Create batches
            for i in range(0, len(examples), self.batch_size):
                batch_examples = examples[i : i + self.batch_size]

                if len(batch_examples) == self.batch_size:  # Only yield full batches
                    try:
                        input_ids, target_ids = self._create_batch(batch_examples)
                        yield input_ids, target_ids
                    except Exception as e:
                        print(f"Failed to create batch: {e}")
                        continue

            # Update progress
            self.current_line += len(texts)

            # Clean up memory
            del texts, token_sequences, examples
            gc.collect()

    def get_data_stats(self) -> dict:
        """
        Get statistics about the training data.

        Returns:
            Dictionary with data statistics
        """
        print("Analyzing training data...")

        # Sample some data to get statistics
        sample_texts = self._read_chunk(0)[:100]  # Sample first 100 passages
        token_sequences = self._tokenize_texts(sample_texts)

        if token_sequences:
            sequence_lengths = [len(seq) for seq in token_sequences]
            avg_length = sum(sequence_lengths) / len(sequence_lengths)
            max_length = max(sequence_lengths)
            min_length = min(sequence_lengths)
        else:
            avg_length = max_length = min_length = 0

        # Estimate total tokens
        estimated_total_tokens = int(avg_length * self.total_lines)

        # Estimate number of batches per epoch
        examples_per_passage = max(1, avg_length // self.seq_len)
        total_examples = int(self.total_lines * examples_per_passage)
        batches_per_epoch = total_examples // self.batch_size

        stats = {
            "total_passages": self.total_lines,
            "avg_tokens_per_passage": avg_length,
            "min_tokens_per_passage": min_length,
            "max_tokens_per_passage": max_length,
            "estimated_total_tokens": estimated_total_tokens,
            "estimated_examples_per_epoch": total_examples,
            "estimated_batches_per_epoch": batches_per_epoch,
            "sequence_length": self.seq_len,
            "batch_size": self.batch_size,
            "vocabulary_size": self.tokenizer.vocab_size(),
        }

        print("Data analysis complete:")
        print(f"  Total passages: {stats['total_passages']:,}")
        print(f"  Avg tokens per passage: {stats['avg_tokens_per_passage']:.1f}")
        print(f"  Estimated total tokens: {stats['estimated_total_tokens']:,}")
        print(f"  Estimated batches per epoch: {stats['estimated_batches_per_epoch']:,}")

        return stats


def test_data_loader():
    """Test function for the data loader."""
    print("Testing TextDataLoader...")

    # Test with small parameters
    try:
        loader = TextDataLoader(
            data_file="data/clean/training_data.txt",
            tokenizer_path="data/tokenizer/tokenizer.model",
            seq_len=128,
            batch_size=2,
            chunk_size=10,  # Small for testing
        )

        # Get data statistics
        _ = loader.get_data_stats()

        # Test iteration
        print("\nTesting batch iteration...")
        start_time = time.time()
        batch_count = 0

        for batch_idx, (input_ids, target_ids) in enumerate(loader):
            batch_count += 1

            print(f"Batch {batch_idx + 1}:")
            print(f"  Input shape: {input_ids.shape}")
            print(f"  Target shape: {target_ids.shape}")
            print(f"  Sample input tokens: {input_ids[0][:10].tolist()}")
            print(f"  Sample target tokens: {target_ids[0][:10].tolist()}")

            if batch_idx >= 2:  # Only test first few batches
                break

        test_time = time.time() - start_time
        print("\nData loader test completed successfully!")
        print(f"  Processed {batch_count} batches in {test_time:.2f}s")
        print(f"  Average time per batch: {test_time/max(1, batch_count):.2f}s")

        return True

    except Exception as e:
        print(f"Data loader test failed: {e}")
        import traceback

        traceback.print_exc()
        return False


if __name__ == "__main__":
    test_data_loader()
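The example construction in this file is the standard next-token objective: targets are the inputs shifted left by one position, and padded target positions are filled with -1 so the loss can ignore them. A minimal, self-contained sketch of that relationship, independent of the SQUAD data files and tokenizer paths the loader above assumes:

import torch

# Toy token sequence; in the real loader these come from SentencePiece.
tokens = [2, 17, 93, 41, 7, 1]            # BOS ... EOS
input_ids = torch.tensor(tokens[:-1])      # model sees t_0 .. t_{n-2}
target_ids = torch.tensor(tokens[1:])      # model predicts t_1 .. t_{n-1}

# Shifted-by-one property used by _create_training_examples.
assert torch.equal(input_ids[1:], target_ids[:-1])

# Padded targets use -1 so they can be skipped in the loss, e.g. with
# torch.nn.functional.cross_entropy(logits.view(-1, vocab), targets.view(-1), ignore_index=-1)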
core/src/download_and_prepare.py
ADDED
@@ -0,0 +1,243 @@
#!/usr/bin/env python3
# Copyright (C) 2024 Louis Chua Bean Chong
#
# This file is part of OpenLLM.
#
# OpenLLM is dual-licensed:
# 1. For open source use: GNU General Public License v3.0
# 2. For commercial use: Commercial License (contact for details)
#
# See LICENSE and docs/LICENSES.md for full license information.

r"""
Download and prepare training data from the SQUAD dataset.

OVERVIEW:
This script downloads the SQUAD (Stanford Question Answering Dataset) from its official source,
extracts the Wikipedia context passages from the JSON format, and saves the cleaned text to disk.
The SQUAD dataset contains high-quality Wikipedia articles that are perfect for training language models.

DATA FLOW:
1. Downloads 4 JSON files from Stanford (SQUAD v1.1 & v2.0, train & dev splits)
2. Parses JSON structure: data -> articles -> paragraphs -> context
3. Extracts only the 'context' fields (Wikipedia passages, not questions/answers)
4. Cleans text: normalizes whitespace, filters by minimum word count
5. Outputs one passage per line in a single text file

The output is a single text file containing ~150k-200k Wikipedia article passages,
suitable for training tokenizers and language models.

DATASET INFO:
- SQUAD v1.1: 87k train + 10k dev examples
- SQUAD v2.0: 130k train + 11k dev examples
- Source: High-quality Wikipedia articles across diverse topics
- Total download size: ~200MB
- Final processed size: ~100-150MB of clean text

Usage:
    python core/src/download_and_prepare.py

Output:
    data/clean/training_data.txt - Cleaned Wikipedia passages from SQUAD dataset

Requirements:
    pip install requests tqdm

Example setup:

Windows PowerShell:
```powershell
python -m venv venv
.\venv\Scripts\Activate.ps1
pip install requests tqdm
python core/src/download_and_prepare.py
```

Linux/macOS:
```bash
python -m venv venv
source venv/bin/activate
pip install requests tqdm
python core/src/download_and_prepare.py
```

"""

import json
import os

import requests
from tqdm import tqdm


def download_file(url, filename):
    """
    Download a file from URL with progress bar.

    Args:
        url (str): URL to download from
        filename (str): Local path where file should be saved
    """
    # Stream the download to handle large files efficiently
    response = requests.get(url, stream=True, timeout=30)
    total_size = int(response.headers.get("content-length", 0))

    # Use tqdm progress bar to show download progress
    with open(filename, "wb") as file, tqdm(
        desc=filename,
        total=total_size,
        unit="iB",
        unit_scale=True,
        unit_divisor=1024,
    ) as pbar:
        # Download in 1KB chunks
        for data in response.iter_content(chunk_size=1024):
            size = file.write(data)
            pbar.update(size)


def prepare_training_data(output_path="data/clean/training_data.txt", min_words=10):
    """
    Downloads the SQUAD dataset and extracts Wikipedia context passages for training.

    SQUAD Dataset Structure:
    - Each JSON file contains a 'data' array of Wikipedia articles
    - Each article has 'paragraphs' containing 'context' (Wikipedia text) and 'qas' (questions/answers)
    - We extract only the 'context' fields which contain high-quality Wikipedia passages

    Args:
        output_path (str): Path to save the cleaned text data.
        min_words (int): Minimum number of words required for a passage to be included.
    """
    print("Downloading SQUAD dataset...")

    # Official SQUAD dataset URLs from Stanford
    # Using both v1.1 and v2.0 for maximum training data
    # v1.1: ~87k training + 10k dev examples
    # v2.0: ~130k training + 11k dev examples (includes unanswerable questions)
    squad_urls = [
        "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json",
        "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json",
        "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json",
        "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json",
    ]

    # Create directory structure for temporary files
    os.makedirs("data/raw", exist_ok=True)

    downloaded_files = []

    # Download each SQUAD dataset file
    print("Step 1: Downloading SQUAD JSON files...")
    for i, url in enumerate(squad_urls):
        filename = f"data/raw/squad_{i+1}.json"
        try:
            print(f"Downloading {url}...")
            download_file(url, filename)
            downloaded_files.append(filename)
            print(f"Successfully downloaded {filename}")
        except Exception as e:
            print(f"Failed to download {url}: {e}")
            continue

    # Verify we have at least one successful download
    if not downloaded_files:
        print("ERROR: No files were downloaded successfully.")
        return

    # Ensure output directory exists
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    print(f"\nStep 2: Processing SQUAD files and saving to {output_path}...")

    # Process each downloaded SQUAD JSON file and extract contexts
    with open(output_path, "w", encoding="utf-8") as f:
        total_contexts = 0

        for file_path in downloaded_files:
            print(f"Processing {file_path}...")

            try:
                # Load and parse the JSON file
                with open(file_path, "r", encoding="utf-8") as json_file:
                    squad_data = json.load(json_file)

                # Navigate the SQUAD JSON structure to extract context passages
                # Structure: data -> articles -> paragraphs -> context
                contexts = []
                for article in squad_data.get("data", []):
                    # Each article represents a Wikipedia page
                    for paragraph in article.get("paragraphs", []):
                        # Each paragraph contains a 'context' (Wikipedia passage) and 'qas' (Q&A pairs)
                        context = paragraph.get("context", "").strip()
                        if context:
                            contexts.append(context)

                print(f"Found {len(contexts)} Wikipedia passages in {os.path.basename(file_path)}")

                # Clean and filter each context passage for high-quality training data
                # This preprocessing is crucial for effective language model training
                for context in tqdm(contexts, desc=f"Processing {os.path.basename(file_path)}"):
                    # Text normalization and cleaning pipeline
                    # Step 1: Normalize whitespace to ensure consistent formatting
                    # - Collapse multiple spaces/tabs into single spaces
                    # - Remove excessive newlines that break sentence flow
                    # - Strip leading/trailing whitespace
                    # This preserves natural sentence structure while cleaning artifacts
                    cleaned_text = " ".join(context.split())

                    # Step 2: Skip empty passages after cleaning
                    # Empty passages can occur from malformed JSON or pure whitespace
                    if not cleaned_text:
                        continue

                    # Step 3: Quality filtering based on content length
                    # Apply minimum word count filter to ensure substantial content
                    # Short passages (< min_words) provide insufficient context for language modeling
                    # Wikipedia passages are typically well-formed, so this mainly catches truncated text
                    word_count = len(cleaned_text.split())
                    if word_count >= min_words:
                        # Write each passage on a new line for easy processing by data loaders
                        # The line-based format enables efficient streaming during training
                        # Each line represents one coherent Wikipedia passage
                        f.write(cleaned_text + "\n")
                        total_contexts += 1

                    # Optional: Log extremely short passages for monitoring data quality
                    elif word_count > 0:  # Non-empty but too short
                        if total_contexts % 1000 == 0:  # Log occasionally to avoid spam
                            print(
                                f"Skipped short passage ({word_count} words): {cleaned_text[:50]}..."
                            )

            except Exception as e:
                print(f"Error processing {file_path}: {e}")
                continue

    print(f"\nStep 3: Successfully saved {total_contexts} Wikipedia passages from SQUAD dataset.")
    print(f"Output file: {output_path}")

    # Clean up temporary downloaded files to save disk space
    print("Step 4: Cleaning up temporary files...")
    for file in downloaded_files:
        try:
            os.remove(file)
            print(f"Removed {file}")
        except Exception as e:
            print(f"Warning: Could not remove {file}: {e}")


if __name__ == "__main__":
    """
    Main execution block - runs when script is called directly.

    This will:
    1. Download SQUAD v1.1 and v2.0 datasets (~200MB total)
    2. Extract ~240k Wikipedia passages from the JSON files
    3. Clean and filter the text (remove passages < 10 words)
    4. Save all passages to data/clean/training_data.txt (one per line)
    5. Clean up temporary files

    Expected output: ~150k-200k high-quality Wikipedia passages suitable for LM training
    """
    # Run the data preparation function with default parameters
    prepare_training_data()
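Because the script writes one cleaned passage per line, the output can be streamed by the data loader without holding the corpus in memory. A small sketch of how a run could be sanity-checked afterwards (the path is the script's default output; adjust if output_path was changed):

from pathlib import Path

path = Path("data/clean/training_data.txt")
passages = 0
words = 0
with path.open(encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if line:
            passages += 1
            words += len(line.split())

# Rough expectation per the docstring: ~150k-200k passages of >= 10 words each.
print(f"{passages:,} passages, {words:,} words, ~{words / max(passages, 1):.0f} words/passage")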
core/src/enterprise_integration.py
ADDED
@@ -0,0 +1,139 @@
#!/usr/bin/env python3
# Copyright (C) 2024 Louis Chua Bean Chong
#
# This file is part of OpenLLM.
#
# OpenLLM is dual-licensed:
# 1. For open source use: GNU General Public License v3.0
# 2. For commercial use: Commercial License (contact for details)
#
# See LICENSE and docs/LICENSES.md for full license information.

"""
Enterprise Integration Layer for OpenLLM

This module provides an optional plugin mechanism to load enterprise-only
modules without coupling the open source core to proprietary code. It follows
the project rule: core functionality must work without proprietary
dependencies; enterprise features are optional extensions.

How it works:
- Attempts to locate a Python module that exposes enterprise commands
- Supports two discovery methods:
  1) Python package on sys.path: `openllm_enterprise`
  2) Filesystem path via env var `OPENLLM_ENTERPRISE_PATH` that contains
     a module with `register_cli(subparsers)` function
- If found, calls `register_cli(subparsers)` to register additional CLI commands

Security and Licensing:
- No proprietary code is included in the open repository
- This module only performs optional dynamic imports if the user provides
  an enterprise package or path
- All core code remains GPLv3 compliant

Usage (enterprise side expected contract):
    # In the enterprise package/module
    def register_cli(subparsers):
        parser = subparsers.add_parser(
            "enterprise-train",
            help="Enterprise: RLHF training",
            description="Run RLHF training using enterprise-only components."
        )
        parser.add_argument("--config", required=True)
        parser.set_defaults(func=enterprise_train_entry)

    def enterprise_train_entry(args):
        ...

Author: Louis Chua Bean Chong
License: GPLv3 (core); enterprise modules remain out-of-tree
"""

from __future__ import annotations

import importlib
import os
import sys
from pathlib import Path
from typing import Any


def _try_import_by_name(module_name: str):
    """Attempt to import a module by name. Returns module or None on failure."""
    try:
        return importlib.import_module(module_name)
    except Exception:
        return None


def _try_import_from_path(module_path: str):
    """
    Attempt to import a module from a filesystem path.

    The path may point either to a package directory (containing __init__.py)
    or to a .py file. This function prepends the parent directory to sys.path
    and imports the module by stem name.
    """
    try:
        path = Path(module_path)
        if not path.exists():
            return None

        if path.is_file():
            parent = str(path.parent)
            mod_name = path.stem
        else:
            parent = str(path.parent)
            mod_name = path.name

        if parent not in sys.path:
            sys.path.insert(0, parent)
        return importlib.import_module(mod_name)
    except Exception:
        return None


def load_enterprise_cli(subparsers: Any) -> bool:
    """
    Try to load enterprise-only CLI commands.

    Discovery order:
    1) Python package/module named `openllm_enterprise`
    2) Env var `OPENLLM_ENTERPRISE_PATH` pointing to a package dir or .py file

    If a discovered module exposes `register_cli(subparsers)`, it will be called
    to register enterprise commands. Returns True if any enterprise module was
    loaded successfully; otherwise False.
    """
    # 1) Try well-known package name
    enterprise_mod = _try_import_by_name("openllm_enterprise")
    if enterprise_mod and hasattr(enterprise_mod, "register_cli"):
        try:
            enterprise_mod.register_cli(subparsers)
            print("Loaded enterprise commands from openllm_enterprise package")
            return True
        except Exception as e:
            # Fail gracefully; core must continue to work
            print(f"Warning: Enterprise module registration failed: {e}")

    # 2) Try explicit path via environment variable
    enterprise_path = os.environ.get("OPENLLM_ENTERPRISE_PATH")
    if enterprise_path:
        enterprise_mod = _try_import_from_path(enterprise_path)
        if enterprise_mod and hasattr(enterprise_mod, "register_cli"):
            try:
                enterprise_mod.register_cli(subparsers)
                print(
                    "Loaded enterprise commands from OPENLLM_ENTERPRISE_PATH="
                    f"{enterprise_path}"
                )
                return True
            except Exception as e:
                # Fail gracefully
                print(f"Warning: Enterprise module registration failed: {e}")

    # Not found (by design this is optional)
    return False


__all__ = ["load_enterprise_cli"]
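On the core side, load_enterprise_cli is intended to be called while the argparse tree is being assembled, so enterprise subcommands appear alongside the open source ones only when a plugin is actually installed. A minimal sketch of that wiring, assuming a generic CLI entry point (the "train" subcommand and parser names here are illustrative, not taken from main.py):

import argparse

from enterprise_integration import load_enterprise_cli

parser = argparse.ArgumentParser(prog="openllm")
subparsers = parser.add_subparsers(dest="command")

# Core (open source) commands are registered unconditionally.
subparsers.add_parser("train", help="Train a model")

# Enterprise commands are added only if a plugin package or path is discovered;
# the core CLI keeps working when this returns False.
has_enterprise = load_enterprise_cli(subparsers)

args = parser.parse_args()
if hasattr(args, "func"):
    args.func(args)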
core/src/evaluate_model.py
ADDED
@@ -0,0 +1,767 @@
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright (C) 2024 Louis Chua Bean Chong
|
| 3 |
+
#
|
| 4 |
+
# This file is part of OpenLLM.
|
| 5 |
+
#
|
| 6 |
+
# OpenLLM is dual-licensed:
|
| 7 |
+
# 1. For open source use: GNU General Public License v3.0
|
| 8 |
+
# 2. For commercial use: Commercial License (contact for details)
|
| 9 |
+
#
|
| 10 |
+
# See LICENSE and docs/LICENSES.md for full license information.
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
OpenLLM Model Evaluation Script
|
| 14 |
+
|
| 15 |
+
This script implements comprehensive evaluation for trained OpenLLM models,
|
| 16 |
+
including intrinsic evaluation (perplexity, loss) and text generation quality
|
| 17 |
+
assessment as specified in Step 5 of the training pipeline.
|
| 18 |
+
|
| 19 |
+
Usage:
|
| 20 |
+
python core/src/evaluate_model.py \
|
| 21 |
+
--model_dir models/openllm-medium \
|
| 22 |
+
--eval_data data/clean/validation_data.txt \
|
| 23 |
+
--metrics perplexity,loss
|
| 24 |
+
|
| 25 |
+
Features:
|
| 26 |
+
- Perplexity calculation on held-out data
|
| 27 |
+
- Text generation quality assessment
|
| 28 |
+
- Multiple evaluation metrics
|
| 29 |
+
- Comprehensive quality benchmarks
|
| 30 |
+
- JSON output for downstream analysis
|
| 31 |
+
|
| 32 |
+
Author: Louis Chua Bean Chong
|
| 33 |
+
License: GPLv3
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
import argparse
|
| 37 |
+
import json
|
| 38 |
+
import math
|
| 39 |
+
import os
|
| 40 |
+
import sys
|
| 41 |
+
import time
|
| 42 |
+
from pathlib import Path
|
| 43 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 44 |
+
|
| 45 |
+
import sentencepiece as smp
|
| 46 |
+
import torch
|
| 47 |
+
|
| 48 |
+
# Add current directory to path for imports
|
| 49 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 50 |
+
|
| 51 |
+
from model import GPTModel, create_model
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class ModelEvaluator:
|
| 55 |
+
"""
|
| 56 |
+
Comprehensive evaluator for OpenLLM models.
|
| 57 |
+
|
| 58 |
+
Implements intrinsic evaluation metrics and text generation quality
|
| 59 |
+
assessment following the training pipeline specifications.
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
def __init__(self, model: GPTModel, tokenizer_path: str, device: str = "cpu"):
|
| 63 |
+
"""
|
| 64 |
+
Initialize the model evaluator.
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
model: Trained GPT model
|
| 68 |
+
tokenizer_path: Path to tokenizer model file
|
| 69 |
+
device: Device to run evaluation on
|
| 70 |
+
"""
|
| 71 |
+
self.model = model.to(device)
|
| 72 |
+
self.device = device
|
| 73 |
+
|
| 74 |
+
# Load tokenizer
|
| 75 |
+
self.tokenizer = smp.SentencePieceProcessor()
|
| 76 |
+
self.tokenizer.load(tokenizer_path)
|
| 77 |
+
|
| 78 |
+
print("π§ ModelEvaluator initialized")
|
| 79 |
+
print(f" Device: {device}")
|
| 80 |
+
print(f" Model parameters: {model.get_num_params():,}")
|
| 81 |
+
print(f" Vocabulary size: {self.tokenizer.vocab_size():,}")
|
| 82 |
+
|
| 83 |
+
def evaluate_perplexity(
|
| 84 |
+
self, eval_data: List[str], max_seq_len: int = 512, batch_size: int = 1
|
| 85 |
+
) -> Dict[str, float]:
|
| 86 |
+
"""
|
| 87 |
+
Calculate perplexity on evaluation data.
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
eval_data: List of text passages for evaluation
|
| 91 |
+
max_seq_len: Maximum sequence length for evaluation
|
| 92 |
+
batch_size: Batch size for evaluation
|
| 93 |
+
|
| 94 |
+
Returns:
|
| 95 |
+
Dictionary with loss and perplexity metrics
|
| 96 |
+
"""
|
| 97 |
+
self.model.eval()
|
| 98 |
+
total_loss = 0.0
|
| 99 |
+
total_tokens = 0
|
| 100 |
+
num_sequences = 0
|
| 101 |
+
|
| 102 |
+
print(f"π Calculating perplexity on {len(eval_data)} passages...")
|
| 103 |
+
|
| 104 |
+
with torch.no_grad():
|
| 105 |
+
for i, text in enumerate(eval_data):
|
| 106 |
+
if i % 100 == 0:
|
| 107 |
+
print(f" Progress: {i}/{len(eval_data)} passages")
|
| 108 |
+
|
| 109 |
+
# Tokenize text
|
| 110 |
+
tokens = self.tokenizer.encode(text)
|
| 111 |
+
if len(tokens) < 2:
|
| 112 |
+
continue
|
| 113 |
+
|
| 114 |
+
# Truncate if too long
|
| 115 |
+
if len(tokens) > max_seq_len:
|
| 116 |
+
tokens = tokens[:max_seq_len]
|
| 117 |
+
|
| 118 |
+
# Create input and target tensors
|
| 119 |
+
input_ids = torch.tensor([tokens[:-1]], dtype=torch.long, device=self.device)
|
| 120 |
+
target_ids = torch.tensor([tokens[1:]], dtype=torch.long, device=self.device)
|
| 121 |
+
|
| 122 |
+
# Forward pass
|
| 123 |
+
logits, loss = self.model(input_ids, target_ids)
|
| 124 |
+
|
| 125 |
+
# Accumulate loss
|
| 126 |
+
seq_length = len(tokens) - 1
|
| 127 |
+
total_loss += loss.item() * seq_length
|
| 128 |
+
total_tokens += seq_length
|
| 129 |
+
num_sequences += 1
|
| 130 |
+
|
| 131 |
+
# Calculate metrics
|
| 132 |
+
avg_loss = total_loss / total_tokens if total_tokens > 0 else float("inf")
|
| 133 |
+
perplexity = math.exp(min(avg_loss, 10)) # Cap to prevent overflow
|
| 134 |
+
|
| 135 |
+
return {
|
| 136 |
+
"loss": avg_loss,
|
| 137 |
+
"perplexity": perplexity,
|
| 138 |
+
"total_tokens": total_tokens,
|
| 139 |
+
"num_sequences": num_sequences,
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
def evaluate_text_generation(
|
| 143 |
+
self,
|
| 144 |
+
prompts: List[str],
|
| 145 |
+
max_length: int = 256,
|
| 146 |
+
temperature: float = 0.7,
|
| 147 |
+
top_k: Optional[int] = 40,
|
| 148 |
+
num_samples: int = 1,
|
| 149 |
+
) -> List[Dict[str, Any]]:
|
| 150 |
+
"""
|
| 151 |
+
Evaluate text generation quality.
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
prompts: List of input prompts
|
| 155 |
+
max_length: Maximum generation length
|
| 156 |
+
temperature: Sampling temperature
|
| 157 |
+
top_k: Top-k sampling parameter
|
| 158 |
+
num_samples: Number of samples per prompt
|
| 159 |
+
|
| 160 |
+
Returns:
|
| 161 |
+
List of generation results with quality metrics
|
| 162 |
+
"""
|
| 163 |
+
self.model.eval()
|
| 164 |
+
results = []
|
| 165 |
+
|
| 166 |
+
print(f"βοΈ Evaluating text generation on {len(prompts)} prompts...")
|
| 167 |
+
|
| 168 |
+
with torch.no_grad():
|
| 169 |
+
for prompt in prompts:
|
| 170 |
+
prompt_results = []
|
| 171 |
+
|
| 172 |
+
for sample_idx in range(num_samples):
|
| 173 |
+
# Tokenize prompt
|
| 174 |
+
input_ids = self.tokenizer.encode(prompt)
|
| 175 |
+
input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
|
| 176 |
+
|
| 177 |
+
start_time = time.time()
|
| 178 |
+
|
| 179 |
+
# Generate text
|
| 180 |
+
output = self.model.generate(
|
| 181 |
+
input_tensor,
|
| 182 |
+
max_new_tokens=max_length,
|
| 183 |
+
temperature=temperature,
|
| 184 |
+
top_k=top_k,
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
generation_time = time.time() - start_time
|
| 188 |
+
|
| 189 |
+
# Decode output
|
| 190 |
+
generated_ids = output[0].tolist()
|
| 191 |
+
full_text = self.tokenizer.decode(generated_ids)
|
| 192 |
+
generated_text = self.tokenizer.decode(generated_ids[len(input_ids) :])
|
| 193 |
+
|
| 194 |
+
# Calculate quality metrics
|
| 195 |
+
quality_metrics = self._assess_generation_quality(generated_text)
|
| 196 |
+
|
| 197 |
+
prompt_results.append(
|
| 198 |
+
{
|
| 199 |
+
"prompt": prompt,
|
| 200 |
+
"generated_text": generated_text,
|
| 201 |
+
"full_text": full_text,
|
| 202 |
+
"generation_time": generation_time,
|
| 203 |
+
"tokens_generated": len(generated_ids) - len(input_ids),
|
| 204 |
+
"tokens_per_second": (len(generated_ids) - len(input_ids))
|
| 205 |
+
/ generation_time,
|
| 206 |
+
"quality_metrics": quality_metrics,
|
| 207 |
+
}
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
results.extend(prompt_results)
|
| 211 |
+
|
| 212 |
+
return results
|
| 213 |
+
|
| 214 |
+
def _assess_generation_quality(self, text: str) -> Dict[str, float]:
|
| 215 |
+
"""
|
| 216 |
+
Assess basic quality metrics for generated text.
|
| 217 |
+
|
| 218 |
+
Args:
|
| 219 |
+
text: Generated text to assess
|
| 220 |
+
|
| 221 |
+
Returns:
|
| 222 |
+
Dictionary of quality metrics
|
| 223 |
+
"""
|
| 224 |
+
if not text.strip():
|
| 225 |
+
return {
|
| 226 |
+
"length": 0,
|
| 227 |
+
"avg_word_length": 0,
|
| 228 |
+
"repetition_rate": 1.0,
|
| 229 |
+
"coherence_score": 0.0,
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
words = text.split()
|
| 233 |
+
|
| 234 |
+
# Basic metrics
|
| 235 |
+
length = len(words)
|
| 236 |
+
avg_word_length = sum(len(word) for word in words) / len(words) if words else 0
|
| 237 |
+
|
| 238 |
+
# Repetition rate (simple n-gram repetition)
|
| 239 |
+
bigrams = [f"{words[i]} {words[i+1]}" for i in range(len(words) - 1)]
|
| 240 |
+
unique_bigrams = len(set(bigrams))
|
| 241 |
+
repetition_rate = 1 - (unique_bigrams / len(bigrams) if bigrams else 0)
|
| 242 |
+
|
| 243 |
+
# Simple coherence score (based on sentence structure)
|
| 244 |
+
sentences = text.split(".")
|
| 245 |
+
valid_sentences = [s for s in sentences if len(s.strip().split()) > 3]
|
| 246 |
+
coherence_score = len(valid_sentences) / len(sentences) if sentences else 0
|
| 247 |
+
|
| 248 |
+
return {
|
| 249 |
+
"length": length,
|
| 250 |
+
"avg_word_length": avg_word_length,
|
| 251 |
+
"repetition_rate": repetition_rate,
|
| 252 |
+
"coherence_score": coherence_score,
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
def evaluate_downstream_tasks(self) -> Dict[str, Any]:
|
| 256 |
+
"""
|
| 257 |
+
Evaluate model performance on downstream tasks.
|
| 258 |
+
|
| 259 |
+
This function implements basic downstream task evaluation including:
|
| 260 |
+
- Reading comprehension (simplified SQuAD-style)
|
| 261 |
+
- Sentiment analysis (few-shot)
|
| 262 |
+
- Common sense reasoning
- Text completion quality
|
| 263 |
+
|
| 264 |
+
Returns:
|
| 265 |
+
Dictionary of downstream task results
|
| 266 |
+
"""
|
| 267 |
+
results = {}
|
| 268 |
+
|
| 269 |
+
# 1. Reading Comprehension (simplified SQuAD-style)
|
| 270 |
+
results["reading_comprehension"] = self._evaluate_reading_comprehension()
|
| 271 |
+
|
| 272 |
+
# 2. Sentiment Analysis (Few-shot learning)
|
| 273 |
+
results["sentiment_analysis"] = self._evaluate_sentiment_analysis()
|
| 274 |
+
|
| 275 |
+
# 3. Common Sense Reasoning
|
| 276 |
+
results["reasoning"] = self._evaluate_reasoning()
|
| 277 |
+
|
| 278 |
+
# 4. Text Completion Quality
|
| 279 |
+
results["text_completion"] = self._evaluate_text_completion()
|
| 280 |
+
|
| 281 |
+
return results
|
| 282 |
+
|
| 283 |
+
def _evaluate_reading_comprehension(self) -> Dict[str, Any]:
|
| 284 |
+
"""Simplified reading comprehension evaluation."""
|
| 285 |
+
# Sample reading comprehension tasks
|
| 286 |
+
tasks = [
|
| 287 |
+
{
|
| 288 |
+
"context": "The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, France. It is named after the engineer Gustave Eiffel, whose company designed and built the tower.",
|
| 289 |
+
"question": "Who is the Eiffel Tower named after?",
|
| 290 |
+
"expected": "Gustave Eiffel",
|
| 291 |
+
},
|
| 292 |
+
{
|
| 293 |
+
"context": "Python is a high-level programming language. It was created by Guido van Rossum and first released in 1991.",
|
| 294 |
+
"question": "When was Python first released?",
|
| 295 |
+
"expected": "1991",
|
| 296 |
+
},
|
| 297 |
+
{
|
| 298 |
+
"context": "Machine learning is a subset of artificial intelligence that enables computers to learn without being explicitly programmed.",
|
| 299 |
+
"question": "What is machine learning a subset of?",
|
| 300 |
+
"expected": "artificial intelligence",
|
| 301 |
+
},
|
| 302 |
+
]
|
| 303 |
+
|
| 304 |
+
correct = 0
|
| 305 |
+
total = len(tasks)
|
| 306 |
+
|
| 307 |
+
for task in tasks:
|
| 308 |
+
prompt = f"Context: {task['context']}\nQuestion: {task['question']}\nAnswer:"
|
| 309 |
+
|
| 310 |
+
# Generate answer
|
| 311 |
+
input_ids = self.tokenizer.encode(prompt)
|
| 312 |
+
input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
|
| 313 |
+
|
| 314 |
+
with torch.no_grad():
|
| 315 |
+
output = self.model.generate(input_tensor, max_new_tokens=20, temperature=0.1)
|
| 316 |
+
|
| 317 |
+
generated_ids = output[0].tolist()
|
| 318 |
+
answer = self.tokenizer.decode(generated_ids[len(input_ids) :]).strip().lower()
|
| 319 |
+
|
| 320 |
+
# Simple substring matching
|
| 321 |
+
if task["expected"].lower() in answer:
|
| 322 |
+
correct += 1
|
| 323 |
+
|
| 324 |
+
return {
|
| 325 |
+
"accuracy": correct / total,
|
| 326 |
+
"correct": correct,
|
| 327 |
+
"total": total,
|
| 328 |
+
"score": correct / total,
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
def _evaluate_sentiment_analysis(self) -> Dict[str, Any]:
|
| 332 |
+
"""Few-shot sentiment analysis evaluation."""
|
| 333 |
+
# Few-shot examples
|
| 334 |
+
examples = "Examples:\nText: 'I love this movie!' Sentiment: Positive\nText: 'This is terrible.' Sentiment: Negative\nText: 'It was okay.' Sentiment: Neutral\n\n"
|
| 335 |
+
|
| 336 |
+
# Test cases
|
| 337 |
+
test_cases = [
|
| 338 |
+
{"text": "This is amazing!", "expected": "positive"},
|
| 339 |
+
{"text": "I hate this.", "expected": "negative"},
|
| 340 |
+
{"text": "This is wonderful.", "expected": "positive"},
|
| 341 |
+
{"text": "This is awful.", "expected": "negative"},
|
| 342 |
+
{"text": "It was fine.", "expected": "neutral"},
|
| 343 |
+
]
|
| 344 |
+
|
| 345 |
+
correct = 0
|
| 346 |
+
total = len(test_cases)
|
| 347 |
+
|
| 348 |
+
for case in test_cases:
|
| 349 |
+
prompt = f"{examples}Text: '{case['text']}' Sentiment:"
|
| 350 |
+
|
| 351 |
+
# Generate sentiment
|
| 352 |
+
input_ids = self.tokenizer.encode(prompt)
|
| 353 |
+
input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
|
| 354 |
+
|
| 355 |
+
with torch.no_grad():
|
| 356 |
+
output = self.model.generate(input_tensor, max_new_tokens=5, temperature=0.1)
|
| 357 |
+
|
| 358 |
+
generated_ids = output[0].tolist()
|
| 359 |
+
sentiment = self.tokenizer.decode(generated_ids[len(input_ids) :]).strip().lower()
|
| 360 |
+
|
| 361 |
+
# Check if expected sentiment is in the generated response
|
| 362 |
+
if case["expected"] in sentiment:
|
| 363 |
+
correct += 1
|
| 364 |
+
|
| 365 |
+
return {
|
| 366 |
+
"accuracy": correct / total,
|
| 367 |
+
"correct": correct,
|
| 368 |
+
"total": total,
|
| 369 |
+
"score": correct / total,
|
| 370 |
+
}
|
| 371 |
+
|
| 372 |
+
def _evaluate_reasoning(self) -> Dict[str, Any]:
|
| 373 |
+
"""Simple reasoning evaluation."""
|
| 374 |
+
# Basic reasoning tasks
|
| 375 |
+
tasks = [
|
| 376 |
+
{
|
| 377 |
+
"question": "If all birds can fly and a penguin is a bird, can a penguin fly?",
|
| 378 |
+
"expected": "no", # This tests if model knows real-world facts
|
| 379 |
+
},
|
| 380 |
+
{
|
| 381 |
+
"question": "If it is raining outside, should you take an umbrella?",
|
| 382 |
+
"expected": "yes",
|
| 383 |
+
},
|
| 384 |
+
{"question": "What comes after Monday?", "expected": "tuesday"},
|
| 385 |
+
{"question": "Is the sun larger than the earth?", "expected": "yes"},
|
| 386 |
+
]
|
| 387 |
+
|
| 388 |
+
correct = 0
|
| 389 |
+
total = len(tasks)
|
| 390 |
+
|
| 391 |
+
for task in tasks:
|
| 392 |
+
prompt = f"Question: {task['question']}\nAnswer:"
|
| 393 |
+
|
| 394 |
+
# Generate answer
|
| 395 |
+
input_ids = self.tokenizer.encode(prompt)
|
| 396 |
+
input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
|
| 397 |
+
|
| 398 |
+
with torch.no_grad():
|
| 399 |
+
output = self.model.generate(input_tensor, max_new_tokens=10, temperature=0.1)
|
| 400 |
+
|
| 401 |
+
generated_ids = output[0].tolist()
|
| 402 |
+
answer = self.tokenizer.decode(generated_ids[len(input_ids) :]).strip().lower()
|
| 403 |
+
|
| 404 |
+
# Check if expected answer is in the response
|
| 405 |
+
if task["expected"] in answer:
|
| 406 |
+
correct += 1
|
| 407 |
+
|
| 408 |
+
return {
|
| 409 |
+
"accuracy": correct / total,
|
| 410 |
+
"correct": correct,
|
| 411 |
+
"total": total,
|
| 412 |
+
"score": correct / total,
|
| 413 |
+
}
|
| 414 |
+
|
| 415 |
+
def _evaluate_text_completion(self) -> Dict[str, Any]:
|
| 416 |
+
"""Evaluate text completion quality."""
|
| 417 |
+
# Common phrases that should be completed predictably
|
| 418 |
+
completions = [
|
| 419 |
+
{"prompt": "The capital of France is", "expected_word": "paris"},
|
| 420 |
+
{"prompt": "Two plus two equals", "expected_word": "four"},
|
| 421 |
+
{"prompt": "The largest planet in our solar system is", "expected_word": "jupiter"},
|
| 422 |
+
{"prompt": "Water boils at", "expected_word": "100"},
|
| 423 |
+
]
|
| 424 |
+
|
| 425 |
+
correct = 0
|
| 426 |
+
total = len(completions)
|
| 427 |
+
|
| 428 |
+
for completion in completions:
|
| 429 |
+
# Generate completion
|
| 430 |
+
input_ids = self.tokenizer.encode(completion["prompt"])
|
| 431 |
+
input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
|
| 432 |
+
|
| 433 |
+
with torch.no_grad():
|
| 434 |
+
output = self.model.generate(input_tensor, max_new_tokens=5, temperature=0.1)
|
| 435 |
+
|
| 436 |
+
generated_ids = output[0].tolist()
|
| 437 |
+
generated_text = self.tokenizer.decode(generated_ids[len(input_ids) :]).strip().lower()
|
| 438 |
+
|
| 439 |
+
# Check if expected word appears in completion
|
| 440 |
+
if completion["expected_word"] in generated_text:
|
| 441 |
+
correct += 1
|
| 442 |
+
|
| 443 |
+
return {
|
| 444 |
+
"accuracy": correct / total,
|
| 445 |
+
"correct": correct,
|
| 446 |
+
"total": total,
|
| 447 |
+
"score": correct / total,
|
| 448 |
+
}
|
| 449 |
+
|
| 450 |
+
def run_comprehensive_evaluation(
|
| 451 |
+
self, eval_data_path: str, metrics: List[str] = None, generation_prompts: List[str] = None
|
| 452 |
+
) -> Dict[str, Any]:
|
| 453 |
+
"""
|
| 454 |
+
Run comprehensive model evaluation.
|
| 455 |
+
|
| 456 |
+
Args:
|
| 457 |
+
eval_data_path: Path to evaluation text file
|
| 458 |
+
metrics: List of metrics to compute
|
| 459 |
+
generation_prompts: Prompts for text generation evaluation
|
| 460 |
+
|
| 461 |
+
Returns:
|
| 462 |
+
Complete evaluation results
|
| 463 |
+
"""
|
| 464 |
+
if metrics is None:
|
| 465 |
+
metrics = ["perplexity", "loss", "generation"]
|
| 466 |
+
|
| 467 |
+
if generation_prompts is None:
|
| 468 |
+
generation_prompts = [
|
| 469 |
+
"The history of artificial intelligence",
|
| 470 |
+
"Machine learning algorithms",
|
| 471 |
+
"The future of technology",
|
| 472 |
+
"In a world where",
|
| 473 |
+
"Scientists have discovered",
|
| 474 |
+
]
|
| 475 |
+
|
| 476 |
+
results = {
|
| 477 |
+
"model_info": {
|
| 478 |
+
"parameters": self.model.get_num_params(),
|
| 479 |
+
"device": self.device,
|
| 480 |
+
"vocab_size": self.tokenizer.vocab_size(),
|
| 481 |
+
},
|
| 482 |
+
"evaluation_timestamp": time.time(),
|
| 483 |
+
}
|
| 484 |
+
|
| 485 |
+
# Load evaluation data
|
| 486 |
+
print(f"π Loading evaluation data from {eval_data_path}")
|
| 487 |
+
if os.path.exists(eval_data_path):
|
| 488 |
+
with open(eval_data_path, "r", encoding="utf-8") as f:
|
| 489 |
+
eval_texts = [line.strip() for line in f if line.strip()]
|
| 490 |
+
else:
|
| 491 |
+
print("β οΈ Evaluation file not found, using sample texts")
|
| 492 |
+
eval_texts = [
|
| 493 |
+
"Artificial intelligence is a rapidly growing field of computer science.",
|
| 494 |
+
"Machine learning algorithms can learn patterns from data automatically.",
|
| 495 |
+
"Natural language processing helps computers understand human language.",
|
| 496 |
+
"Deep learning uses neural networks with multiple layers for complex tasks.",
|
| 497 |
+
"The development of large language models has transformed AI applications.",
|
| 498 |
+
]
|
| 499 |
+
|
| 500 |
+
# Intrinsic evaluation
|
| 501 |
+
if "perplexity" in metrics or "loss" in metrics:
|
| 502 |
+
perplexity_results = self.evaluate_perplexity(eval_texts)
|
| 503 |
+
results["intrinsic_evaluation"] = perplexity_results
|
| 504 |
+
|
| 505 |
+
# Text generation evaluation
|
| 506 |
+
if "generation" in metrics:
|
| 507 |
+
generation_results = self.evaluate_text_generation(generation_prompts)
|
| 508 |
+
results["generation_evaluation"] = {
|
| 509 |
+
"results": generation_results,
|
| 510 |
+
"summary": self._summarize_generation_results(generation_results),
|
| 511 |
+
}
|
| 512 |
+
|
| 513 |
+
# Downstream task evaluation (simplified built-in tasks)
|
| 514 |
+
results["downstream_evaluation"] = self.evaluate_downstream_tasks()
|
| 515 |
+
|
| 516 |
+
# Overall quality assessment
|
| 517 |
+
results["quality_assessment"] = self._assess_overall_quality(results)
|
| 518 |
+
|
| 519 |
+
return results
|
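For orientation, here is a sketch of the rough shape of the dictionary this method returns; every value below is a made-up placeholder, not a real result:

example_results = {
    "model_info": {"parameters": 35_800_000, "device": "cpu", "vocab_size": 32000},
    "evaluation_timestamp": 1718000000.0,
    "intrinsic_evaluation": {"loss": 3.2, "perplexity": 24.5, "num_sequences": 512},
    "generation_evaluation": {"results": [], "summary": {"avg_tokens_per_second": 18.3}},  # results: per-prompt records
    "downstream_evaluation": {"reading_comprehension": {"accuracy": 0.67, "correct": 2, "total": 3}},  # plus sentiment_analysis, reasoning, text_completion
    "quality_assessment": {"quality_level": "fair", "recommendations": []},
}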
| 520 |
+
|
| 521 |
+
def _summarize_generation_results(self, results: List[Dict[str, Any]]) -> Dict[str, float]:
|
| 522 |
+
"""Summarize text generation results."""
|
| 523 |
+
if not results:
|
| 524 |
+
return {}
|
| 525 |
+
|
| 526 |
+
total_time = sum(r["generation_time"] for r in results)
|
| 527 |
+
total_tokens = sum(r["tokens_generated"] for r in results)
|
| 528 |
+
|
| 529 |
+
quality_metrics = [r["quality_metrics"] for r in results]
|
| 530 |
+
|
| 531 |
+
return {
|
| 532 |
+
"avg_generation_time": total_time / len(results),
|
| 533 |
+
"avg_tokens_per_second": total_tokens / total_time if total_time > 0 else 0,
|
| 534 |
+
"avg_length": sum(q["length"] for q in quality_metrics) / len(quality_metrics),
|
| 535 |
+
"avg_repetition_rate": sum(q["repetition_rate"] for q in quality_metrics)
|
| 536 |
+
/ len(quality_metrics),
|
| 537 |
+
"avg_coherence_score": sum(q["coherence_score"] for q in quality_metrics)
|
| 538 |
+
/ len(quality_metrics),
|
| 539 |
+
}
|
| 540 |
+
|
| 541 |
+
def _assess_overall_quality(self, results: Dict[str, Any]) -> Dict[str, Any]:
|
| 542 |
+
"""Assess overall model quality based on evaluation results."""
|
| 543 |
+
assessment = {"quality_level": "unknown", "recommendations": []}
|
| 544 |
+
|
| 545 |
+
# Check intrinsic metrics
|
| 546 |
+
if "intrinsic_evaluation" in results:
|
| 547 |
+
perplexity = results["intrinsic_evaluation"].get("perplexity", float("inf"))
|
| 548 |
+
|
| 549 |
+
if perplexity < 12:
|
| 550 |
+
assessment["quality_level"] = "good"
|
| 551 |
+
assessment["recommendations"].append("Model shows good perplexity scores")
|
| 552 |
+
elif perplexity < 50:
|
| 553 |
+
assessment["quality_level"] = "fair"
|
| 554 |
+
assessment["recommendations"].append(
|
| 555 |
+
"Model shows fair performance, could benefit from more training"
|
| 556 |
+
)
|
| 557 |
+
else:
|
| 558 |
+
assessment["quality_level"] = "poor"
|
| 559 |
+
assessment["recommendations"].append(
|
| 560 |
+
"Model needs significant more training or data improvements"
|
| 561 |
+
)
|
| 562 |
+
|
| 563 |
+
# Check generation quality
|
| 564 |
+
if "generation_evaluation" in results:
|
| 565 |
+
summary = results["generation_evaluation"].get("summary", {})
|
| 566 |
+
repetition_rate = summary.get("avg_repetition_rate", 1.0)
|
| 567 |
+
coherence_score = summary.get("avg_coherence_score", 0.0)
|
| 568 |
+
|
| 569 |
+
if repetition_rate > 0.7:
|
| 570 |
+
assessment["recommendations"].append(
|
| 571 |
+
"High repetition rate - consider training longer or adjusting data"
|
| 572 |
+
)
|
| 573 |
+
if coherence_score < 0.3:
|
| 574 |
+
assessment["recommendations"].append(
|
| 575 |
+
"Low coherence - model may need more training steps"
|
| 576 |
+
)
|
| 577 |
+
|
| 578 |
+
return assessment
|
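As context for the thresholds in `_assess_overall_quality`: perplexity is conventionally the exponential of the average cross-entropy loss, so the "good" cut-off of 12 corresponds to a loss of roughly 2.48 and the "fair" cut-off of 50 to roughly 3.91:

import math
print(math.exp(2.48), math.exp(3.91))  # ~11.9 and ~49.9, i.e. the 12 / 50 perplexity thresholds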
| 579 |
+
|
| 580 |
+
|
| 581 |
+
def load_model_from_directory(model_dir: str, device: str = "cpu") -> Tuple[GPTModel, str]:
|
| 582 |
+
"""
|
| 583 |
+
Load model from directory containing checkpoints.
|
| 584 |
+
|
| 585 |
+
Args:
|
| 586 |
+
model_dir: Directory containing model files
|
| 587 |
+
device: Device to load model on
|
| 588 |
+
|
| 589 |
+
Returns:
|
| 590 |
+
Tuple of (model, tokenizer_path)
|
| 591 |
+
"""
|
| 592 |
+
model_dir = Path(model_dir)
|
| 593 |
+
|
| 594 |
+
# Find best model checkpoint
|
| 595 |
+
best_model_path = model_dir / "best_model.pt"
|
| 596 |
+
if not best_model_path.exists():
|
| 597 |
+
# Look for latest checkpoint
|
| 598 |
+
checkpoints = list(model_dir.glob("checkpoint_step_*.pt"))
|
| 599 |
+
if not checkpoints:
|
| 600 |
+
raise FileNotFoundError(f"No model checkpoints found in {model_dir}")
|
| 601 |
+
|
| 602 |
+
# Get latest checkpoint
|
| 603 |
+
latest_checkpoint = max(checkpoints, key=lambda p: int(p.stem.split("_")[-1]))
|
| 604 |
+
best_model_path = latest_checkpoint
|
| 605 |
+
|
| 606 |
+
print(f"π Loading model from {best_model_path}")
|
| 607 |
+
|
| 608 |
+
# Load checkpoint
|
| 609 |
+
checkpoint = torch.load(best_model_path, map_location=device)
|
| 610 |
+
|
| 611 |
+
# Determine model size from config
|
| 612 |
+
config = checkpoint.get("config", {})
|
| 613 |
+
n_layer = config.get("n_layer", 12)
|
| 614 |
+
|
| 615 |
+
if n_layer <= 6:
|
| 616 |
+
model_size = "small"
|
| 617 |
+
elif n_layer <= 12:
|
| 618 |
+
model_size = "medium"
|
| 619 |
+
else:
|
| 620 |
+
model_size = "large"
|
| 621 |
+
|
| 622 |
+
# Create and load model
|
| 623 |
+
model = create_model(model_size)
|
| 624 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 625 |
+
|
| 626 |
+
print(f"β
Model loaded successfully ({model_size}, {model.get_num_params():,} parameters)")
|
| 627 |
+
|
| 628 |
+
# Find tokenizer
|
| 629 |
+
tokenizer_path = model_dir.parent / "tokenizer" / "tokenizer.model"
|
| 630 |
+
if not tokenizer_path.exists():
|
| 631 |
+
tokenizer_path = Path("data/tokenizer/tokenizer.model")
|
| 632 |
+
|
| 633 |
+
if not tokenizer_path.exists():
|
| 634 |
+
raise FileNotFoundError(f"Tokenizer not found at {tokenizer_path}")
|
| 635 |
+
|
| 636 |
+
return model, str(tokenizer_path)
|
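A minimal programmatic sketch of using this loader together with the evaluator class above; the directory and data paths are just examples, mirroring the ones in the CLI epilog below:

model, tokenizer_path = load_model_from_directory("models/small-extended-4k", device="cpu")
evaluator = ModelEvaluator(model, tokenizer_path, "cpu")
results = evaluator.run_comprehensive_evaluation("data/clean/training_data.txt")
print(results["quality_assessment"]["quality_level"])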
| 637 |
+
|
| 638 |
+
|
| 639 |
+
def main():
|
| 640 |
+
"""Main evaluation function."""
|
| 641 |
+
parser = argparse.ArgumentParser(
|
| 642 |
+
description="Evaluate OpenLLM model performance",
|
| 643 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 644 |
+
epilog="""
|
| 645 |
+
Examples:
|
| 646 |
+
# Basic evaluation
|
| 647 |
+
python core/src/evaluate_model.py \\
|
| 648 |
+
--model_dir models/small-extended-4k \\
|
| 649 |
+
--eval_data data/clean/training_data.txt
|
| 650 |
+
|
| 651 |
+
# Specific metrics
|
| 652 |
+
python core/src/evaluate_model.py \\
|
| 653 |
+
--model_dir models/small-extended-4k \\
|
| 654 |
+
--metrics perplexity,generation \\
|
| 655 |
+
--output results.json
|
| 656 |
+
""",
|
| 657 |
+
)
|
| 658 |
+
|
| 659 |
+
parser.add_argument("--model_dir", required=True, help="Directory containing trained model")
|
| 660 |
+
|
| 661 |
+
parser.add_argument(
|
| 662 |
+
"--eval_data", help="Path to evaluation text file (default: use sample texts)"
|
| 663 |
+
)
|
| 664 |
+
|
| 665 |
+
parser.add_argument(
|
| 666 |
+
"--metrics",
|
| 667 |
+
default="perplexity,loss,generation",
|
| 668 |
+
help="Comma-separated list of metrics to evaluate (default: perplexity,loss,generation)",
|
| 669 |
+
)
|
| 670 |
+
|
| 671 |
+
parser.add_argument("--output", help="Output JSON file for results (default: print to console)")
|
| 672 |
+
|
| 673 |
+
parser.add_argument(
|
| 674 |
+
"--device",
|
| 675 |
+
choices=["cpu", "cuda", "auto"],
|
| 676 |
+
default="auto",
|
| 677 |
+
help="Device for evaluation (default: auto)",
|
| 678 |
+
)
|
| 679 |
+
|
| 680 |
+
parser.add_argument(
|
| 681 |
+
"--generation_prompts", help="File containing prompts for text generation evaluation"
|
| 682 |
+
)
|
| 683 |
+
|
| 684 |
+
args = parser.parse_args()
|
| 685 |
+
|
| 686 |
+
print("π OpenLLM Model Evaluation")
|
| 687 |
+
print("=" * 50)
|
| 688 |
+
|
| 689 |
+
# Determine device
|
| 690 |
+
if args.device == "auto":
|
| 691 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 692 |
+
else:
|
| 693 |
+
device = args.device
|
| 694 |
+
|
| 695 |
+
print(f"Using device: {device}")
|
| 696 |
+
|
| 697 |
+
try:
|
| 698 |
+
# Load model
|
| 699 |
+
model, tokenizer_path = load_model_from_directory(args.model_dir, device)
|
| 700 |
+
|
| 701 |
+
# Create evaluator
|
| 702 |
+
evaluator = ModelEvaluator(model, tokenizer_path, device)
|
| 703 |
+
|
| 704 |
+
# Parse metrics
|
| 705 |
+
metrics = [m.strip() for m in args.metrics.split(",")]
|
| 706 |
+
|
| 707 |
+
# Load generation prompts if specified
|
| 708 |
+
generation_prompts = None
|
| 709 |
+
if args.generation_prompts and os.path.exists(args.generation_prompts):
|
| 710 |
+
with open(args.generation_prompts, "r", encoding="utf-8") as f:
|
| 711 |
+
generation_prompts = [line.strip() for line in f if line.strip()]
|
| 712 |
+
|
| 713 |
+
# Run evaluation
|
| 714 |
+
eval_data_path = args.eval_data or "data/clean/training_data.txt"
|
| 715 |
+
results = evaluator.run_comprehensive_evaluation(
|
| 716 |
+
eval_data_path, metrics, generation_prompts
|
| 717 |
+
)
|
| 718 |
+
|
| 719 |
+
# Output results
|
| 720 |
+
if args.output:
|
| 721 |
+
with open(args.output, "w", encoding="utf-8") as f:
|
| 722 |
+
json.dump(results, f, indent=2)
|
| 723 |
+
print(f"\nπΎ Results saved to {args.output}")
|
| 724 |
+
else:
|
| 725 |
+
print("\nπ Evaluation Results:")
|
| 726 |
+
print("=" * 50)
|
| 727 |
+
|
| 728 |
+
# Print key metrics
|
| 729 |
+
if "intrinsic_evaluation" in results:
|
| 730 |
+
intrinsic = results["intrinsic_evaluation"]
|
| 731 |
+
print("π Intrinsic Metrics:")
|
| 732 |
+
print(f" Loss: {intrinsic['loss']:.4f}")
|
| 733 |
+
print(f" Perplexity: {intrinsic['perplexity']:.2f}")
|
| 734 |
+
print(f" Sequences evaluated: {intrinsic['num_sequences']:,}")
|
| 735 |
+
|
| 736 |
+
if "generation_evaluation" in results:
|
| 737 |
+
gen_summary = results["generation_evaluation"]["summary"]
|
| 738 |
+
print("\nβοΈ Generation Quality:")
|
| 739 |
+
print(
|
| 740 |
+
f" Avg generation speed: {gen_summary['avg_tokens_per_second']:.1f} tokens/sec"
|
| 741 |
+
)
|
| 742 |
+
print(f" Avg text length: {gen_summary['avg_length']:.1f} words")
|
| 743 |
+
print(f" Repetition rate: {gen_summary['avg_repetition_rate']:.3f}")
|
| 744 |
+
print(f" Coherence score: {gen_summary['avg_coherence_score']:.3f}")
|
| 745 |
+
|
| 746 |
+
# Quality assessment
|
| 747 |
+
if "quality_assessment" in results:
|
| 748 |
+
assessment = results["quality_assessment"]
|
| 749 |
+
print("\nπ― Overall Assessment:")
|
| 750 |
+
print(f" Quality Level: {assessment['quality_level'].upper()}")
|
| 751 |
+
for rec in assessment["recommendations"]:
|
| 752 |
+
print(f" β’ {rec}")
|
| 753 |
+
|
| 754 |
+
print("\nπ Evaluation completed successfully!")
|
| 755 |
+
|
| 756 |
+
except Exception as e:
|
| 757 |
+
print(f"\nβ Evaluation failed: {e}")
|
| 758 |
+
import traceback
|
| 759 |
+
|
| 760 |
+
traceback.print_exc()
|
| 761 |
+
return False
|
| 762 |
+
|
| 763 |
+
return True
|
| 764 |
+
|
| 765 |
+
|
| 766 |
+
if __name__ == "__main__":
|
| 767 |
+
main()
|
core/src/export_model.py
ADDED
|
@@ -0,0 +1,727 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright (C) 2024 Louis Chua Bean Chong
|
| 3 |
+
#
|
| 4 |
+
# This file is part of OpenLLM.
|
| 5 |
+
#
|
| 6 |
+
# OpenLLM is dual-licensed:
|
| 7 |
+
# 1. For open source use: GNU General Public License v3.0
|
| 8 |
+
# 2. For commercial use: Commercial License (contact for details)
|
| 9 |
+
#
|
| 10 |
+
# See LICENSE and docs/LICENSES.md for full license information.
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
OpenLLM Model Export Script
|
| 14 |
+
|
| 15 |
+
This script implements Step 6 of the training pipeline: Model Export & Deployment.
|
| 16 |
+
It exports trained OpenLLM models to various formats for production inference.
|
| 17 |
+
|
| 18 |
+
Supported Formats:
|
| 19 |
+
- PyTorch native format (for Python inference)
|
| 20 |
+
- Hugging Face format (for ecosystem compatibility)
|
| 21 |
+
- ONNX format (for optimized cross-platform inference)
|
| 22 |
+
|
| 23 |
+
Usage:
|
| 24 |
+
# PyTorch format
|
| 25 |
+
python core/src/export_model.py \
|
| 26 |
+
--model_dir models/small-extended-4k \
|
| 27 |
+
--format pytorch \
|
| 28 |
+
--output_dir exports/pytorch/
|
| 29 |
+
|
| 30 |
+
# Hugging Face format
|
| 31 |
+
python core/src/export_model.py \
|
| 32 |
+
--model_dir models/small-extended-4k \
|
| 33 |
+
--format huggingface \
|
| 34 |
+
--output_dir exports/huggingface/
|
| 35 |
+
|
| 36 |
+
# ONNX format
|
| 37 |
+
python core/src/export_model.py \
|
| 38 |
+
--model_dir models/small-extended-4k \
|
| 39 |
+
--format onnx \
|
| 40 |
+
--output_dir exports/onnx/ \
|
| 41 |
+
--optimize_for_inference
|
| 42 |
+
|
| 43 |
+
Author: Louis Chua Bean Chong
|
| 44 |
+
License: GPLv3
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
import argparse
|
| 48 |
+
import json
|
| 49 |
+
import os
|
| 50 |
+
import shutil
|
| 51 |
+
import sys
|
| 52 |
+
from datetime import datetime
from pathlib import Path
|
| 53 |
+
from typing import Dict
|
| 54 |
+
|
| 55 |
+
import torch
|
| 56 |
+
|
| 57 |
+
# Add current directory to path for imports
|
| 58 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 59 |
+
|
| 60 |
+
from model import create_model
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class ModelExporter:
|
| 64 |
+
"""
|
| 65 |
+
Comprehensive model exporter for OpenLLM models.
|
| 66 |
+
|
| 67 |
+
Handles export to multiple formats including PyTorch, Hugging Face,
|
| 68 |
+
and ONNX for different deployment scenarios.
|
| 69 |
+
"""
|
| 70 |
+
|
| 71 |
+
def __init__(self, model_dir: str, output_dir: str):
|
| 72 |
+
"""
|
| 73 |
+
Initialize the model exporter.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
model_dir: Directory containing trained model checkpoints
|
| 77 |
+
output_dir: Base directory for exported models
|
| 78 |
+
"""
|
| 79 |
+
self.model_dir = Path(model_dir)
|
| 80 |
+
self.output_dir = Path(output_dir)
|
| 81 |
+
self.output_dir.mkdir(parents=True, exist_ok=True)
|
| 82 |
+
|
| 83 |
+
# Load model and metadata
|
| 84 |
+
self.model, self.config, self.training_info = self._load_model()
|
| 85 |
+
self.tokenizer_path = self._find_tokenizer()
|
| 86 |
+
|
| 87 |
+
print("π§ ModelExporter initialized")
|
| 88 |
+
print(f" Model: {self.config.model_name}")
|
| 89 |
+
print(f" Parameters: {self.model.get_num_params():,}")
|
| 90 |
+
print(f" Output directory: {output_dir}")
|
| 91 |
+
|
| 92 |
+
def _load_model(self):
|
| 93 |
+
"""Load model from checkpoint directory."""
|
| 94 |
+
# Find best model checkpoint
|
| 95 |
+
best_model_path = self.model_dir / "best_model.pt"
|
| 96 |
+
if not best_model_path.exists():
|
| 97 |
+
# Look for latest checkpoint
|
| 98 |
+
checkpoints = list(self.model_dir.glob("checkpoint_step_*.pt"))
|
| 99 |
+
if not checkpoints:
|
| 100 |
+
raise FileNotFoundError(f"No model checkpoints found in {self.model_dir}")
|
| 101 |
+
|
| 102 |
+
# Get latest checkpoint
|
| 103 |
+
latest_checkpoint = max(checkpoints, key=lambda p: int(p.stem.split("_")[-1]))
|
| 104 |
+
best_model_path = latest_checkpoint
|
| 105 |
+
|
| 106 |
+
print(f"π Loading model from {best_model_path}")
|
| 107 |
+
|
| 108 |
+
# Load checkpoint
|
| 109 |
+
checkpoint = torch.load(best_model_path, map_location="cpu")
|
| 110 |
+
|
| 111 |
+
# Determine model size from config
|
| 112 |
+
config_dict = checkpoint.get("config", {})
|
| 113 |
+
n_layer = config_dict.get("n_layer", 12)
|
| 114 |
+
|
| 115 |
+
if n_layer <= 6:
|
| 116 |
+
model_size = "small"
|
| 117 |
+
elif n_layer <= 12:
|
| 118 |
+
model_size = "medium"
|
| 119 |
+
else:
|
| 120 |
+
model_size = "large"
|
| 121 |
+
|
| 122 |
+
# Create and load model
|
| 123 |
+
model = create_model(model_size)
|
| 124 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 125 |
+
model.eval() # Set to evaluation mode
|
| 126 |
+
|
| 127 |
+
# Extract training info
|
| 128 |
+
training_info = {
|
| 129 |
+
"step": checkpoint.get("step", 0),
|
| 130 |
+
"best_loss": checkpoint.get("best_loss", 0.0),
|
| 131 |
+
"model_size": model_size,
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
return model, model.config, training_info
|
| 135 |
+
|
| 136 |
+
def _find_tokenizer(self):
|
| 137 |
+
"""Find tokenizer path."""
|
| 138 |
+
# Try multiple possible locations
|
| 139 |
+
possible_paths = [
|
| 140 |
+
self.model_dir.parent / "tokenizer" / "tokenizer.model",
|
| 141 |
+
Path("data/tokenizer/tokenizer.model"),
|
| 142 |
+
self.model_dir / "tokenizer.model",
|
| 143 |
+
]
|
| 144 |
+
|
| 145 |
+
for path in possible_paths:
|
| 146 |
+
if path.exists():
|
| 147 |
+
return str(path)
|
| 148 |
+
|
| 149 |
+
raise FileNotFoundError("Tokenizer not found in expected locations")
|
| 150 |
+
|
| 151 |
+
def export_pytorch(self) -> str:
|
| 152 |
+
"""
|
| 153 |
+
Export model in PyTorch native format.
|
| 154 |
+
|
| 155 |
+
Returns:
|
| 156 |
+
Path to exported model directory
|
| 157 |
+
"""
|
| 158 |
+
output_path = self.output_dir / "pytorch"
|
| 159 |
+
output_path.mkdir(parents=True, exist_ok=True)
|
| 160 |
+
|
| 161 |
+
print("π Exporting to PyTorch format...")
|
| 162 |
+
|
| 163 |
+
# Save model state dict
|
| 164 |
+
model_path = output_path / "model.pt"
|
| 165 |
+
torch.save(
|
| 166 |
+
{
|
| 167 |
+
"model_state_dict": self.model.state_dict(),
|
| 168 |
+
"config": self.config.__dict__,
|
| 169 |
+
"training_info": self.training_info,
|
| 170 |
+
},
|
| 171 |
+
model_path,
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
# Save configuration
|
| 175 |
+
config_path = output_path / "config.json"
|
| 176 |
+
with open(config_path, "w") as f:
|
| 177 |
+
json.dump(
|
| 178 |
+
{
|
| 179 |
+
"model_config": self.config.__dict__,
|
| 180 |
+
"training_info": self.training_info,
|
| 181 |
+
"export_format": "pytorch",
|
| 182 |
+
},
|
| 183 |
+
f,
|
| 184 |
+
indent=2,
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
# Copy tokenizer
|
| 188 |
+
tokenizer_out = output_path / "tokenizer.model"
|
| 189 |
+
shutil.copy2(self.tokenizer_path, tokenizer_out)
|
| 190 |
+
|
| 191 |
+
# Create loading script
|
| 192 |
+
self._create_pytorch_loader(output_path)
|
| 193 |
+
|
| 194 |
+
print(f"β
PyTorch export completed: {output_path}")
|
| 195 |
+
return str(output_path)
|
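For reference, the exported model.pt bundles the state dict, config, and training info, so it can be reloaded in a few lines. A sketch only: the export path is illustrative, and `create_model` is imported from this repo's model.py as at the top of this file:

import torch
from model import create_model

ckpt = torch.load("exports/pytorch/model.pt", map_location="cpu")
model = create_model(ckpt["training_info"]["model_size"])  # "small" / "medium" / "large"
model.load_state_dict(ckpt["model_state_dict"])
model.eval()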
| 196 |
+
|
| 197 |
+
def export_huggingface(self) -> str:
|
| 198 |
+
"""
|
| 199 |
+
Export model in Hugging Face compatible format.
|
| 200 |
+
|
| 201 |
+
Returns:
|
| 202 |
+
Path to exported model directory
|
| 203 |
+
"""
|
| 204 |
+
output_path = self.output_dir / "huggingface"
|
| 205 |
+
output_path.mkdir(parents=True, exist_ok=True)
|
| 206 |
+
|
| 207 |
+
print("π Exporting to Hugging Face format...")
|
| 208 |
+
|
| 209 |
+
# Save model weights in HF format
|
| 210 |
+
model_path = output_path / "pytorch_model.bin"
|
| 211 |
+
torch.save(self.model.state_dict(), model_path)
|
| 212 |
+
|
| 213 |
+
# Create HF-compatible config
|
| 214 |
+
hf_config = {
|
| 215 |
+
"architectures": ["GPTModel"],
|
| 216 |
+
"model_type": "gpt",
|
| 217 |
+
"vocab_size": self.config.vocab_size,
|
| 218 |
+
"n_layer": self.config.n_layer,
|
| 219 |
+
"n_head": self.config.n_head,
|
| 220 |
+
"n_embd": self.config.n_embd,
|
| 221 |
+
"block_size": self.config.block_size,
|
| 222 |
+
"dropout": self.config.dropout,
|
| 223 |
+
"bias": self.config.bias,
|
| 224 |
+
"torch_dtype": "float32",
|
| 225 |
+
"transformers_version": "4.0.0",
|
| 226 |
+
"openllm_version": "0.1.0",
|
| 227 |
+
"training_steps": self.training_info["step"],
|
| 228 |
+
"model_size": self.training_info["model_size"],
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
config_path = output_path / "config.json"
|
| 232 |
+
with open(config_path, "w") as f:
|
| 233 |
+
json.dump(hf_config, f, indent=2)
|
| 234 |
+
|
| 235 |
+
# Copy tokenizer with HF naming
|
| 236 |
+
shutil.copy2(self.tokenizer_path, output_path / "tokenizer.model")
|
| 237 |
+
|
| 238 |
+
# Create tokenizer config
|
| 239 |
+
tokenizer_config = {
|
| 240 |
+
"tokenizer_class": "SentencePieceTokenizer",
|
| 241 |
+
"model_max_length": self.config.block_size,
|
| 242 |
+
"vocab_size": self.config.vocab_size,
|
| 243 |
+
"unk_token": "<unk>",
|
| 244 |
+
"bos_token": "<s>",
|
| 245 |
+
"eos_token": "</s>",
|
| 246 |
+
"pad_token": "<pad>",
|
| 247 |
+
}
|
| 248 |
+
|
| 249 |
+
with open(output_path / "tokenizer_config.json", "w") as f:
|
| 250 |
+
json.dump(tokenizer_config, f, indent=2)
|
| 251 |
+
|
| 252 |
+
# Create generation config
|
| 253 |
+
generation_config = {
|
| 254 |
+
"max_length": 512,
|
| 255 |
+
"max_new_tokens": 256,
|
| 256 |
+
"temperature": 0.7,
|
| 257 |
+
"top_k": 40,
|
| 258 |
+
"top_p": 0.9,
|
| 259 |
+
"do_sample": True,
|
| 260 |
+
"pad_token_id": 0,
|
| 261 |
+
"eos_token_id": 1,
|
| 262 |
+
"bos_token_id": 2,
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
with open(output_path / "generation_config.json", "w") as f:
|
| 266 |
+
json.dump(generation_config, f, indent=2)
|
| 267 |
+
|
| 268 |
+
# Create HF loading script
|
| 269 |
+
self._create_hf_loader(output_path)
|
| 270 |
+
|
| 271 |
+
print(f"β
Hugging Face export completed: {output_path}")
|
| 272 |
+
return str(output_path)
|
| 273 |
+
|
| 274 |
+
def export_onnx(self, optimize_for_inference: bool = False) -> str:
|
| 275 |
+
"""
|
| 276 |
+
Export model to ONNX format for optimized inference.
|
| 277 |
+
|
| 278 |
+
Args:
|
| 279 |
+
optimize_for_inference: Whether to apply ONNX optimizations
|
| 280 |
+
|
| 281 |
+
Returns:
|
| 282 |
+
Path to exported ONNX model
|
| 283 |
+
"""
|
| 284 |
+
try:
|
| 285 |
+
import onnx
|
| 286 |
+
import onnxruntime
|
| 287 |
+
except ImportError:
|
| 288 |
+
raise ImportError("ONNX export requires: pip install onnx onnxruntime")
|
| 289 |
+
|
| 290 |
+
output_path = self.output_dir / "onnx"
|
| 291 |
+
output_path.mkdir(parents=True, exist_ok=True)
|
| 292 |
+
|
| 293 |
+
print("π Exporting to ONNX format...")
|
| 294 |
+
|
| 295 |
+
# Prepare model for export
|
| 296 |
+
self.model.eval()
|
| 297 |
+
|
| 298 |
+
# Create dummy input for tracing
|
| 299 |
+
batch_size = 1
|
| 300 |
+
seq_len = 64 # Use shorter sequence for compatibility
|
| 301 |
+
dummy_input = torch.randint(0, self.config.vocab_size, (batch_size, seq_len))
|
| 302 |
+
|
| 303 |
+
# Export to ONNX
|
| 304 |
+
onnx_path = output_path / "model.onnx"
|
| 305 |
+
|
| 306 |
+
torch.onnx.export(
|
| 307 |
+
self.model,
|
| 308 |
+
dummy_input,
|
| 309 |
+
onnx_path,
|
| 310 |
+
export_params=True,
|
| 311 |
+
opset_version=11,
|
| 312 |
+
do_constant_folding=True,
|
| 313 |
+
input_names=["input_ids"],
|
| 314 |
+
output_names=["logits"],
|
| 315 |
+
dynamic_axes={
|
| 316 |
+
"input_ids": {0: "batch_size", 1: "sequence_length"},
|
| 317 |
+
"logits": {0: "batch_size", 1: "sequence_length"},
|
| 318 |
+
},
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
# Verify ONNX model
|
| 322 |
+
onnx_model = onnx.load(str(onnx_path))
|
| 323 |
+
onnx.checker.check_model(onnx_model)
|
| 324 |
+
|
| 325 |
+
# Apply optimizations if requested
|
| 326 |
+
if optimize_for_inference:
|
| 327 |
+
self._optimize_onnx_model(onnx_path)
|
| 328 |
+
|
| 329 |
+
# Save metadata
|
| 330 |
+
metadata = {
|
| 331 |
+
"model_config": self.config.__dict__,
|
| 332 |
+
"training_info": self.training_info,
|
| 333 |
+
"export_format": "onnx",
|
| 334 |
+
"input_shape": [batch_size, seq_len],
|
| 335 |
+
"input_names": ["input_ids"],
|
| 336 |
+
"output_names": ["logits"],
|
| 337 |
+
"optimized": optimize_for_inference,
|
| 338 |
+
}
|
| 339 |
+
|
| 340 |
+
with open(output_path / "metadata.json", "w") as f:
|
| 341 |
+
json.dump(metadata, f, indent=2)
|
| 342 |
+
|
| 343 |
+
# Copy tokenizer
|
| 344 |
+
shutil.copy2(self.tokenizer_path, output_path / "tokenizer.model")
|
| 345 |
+
|
| 346 |
+
# Create ONNX inference script
|
| 347 |
+
self._create_onnx_inference(output_path)
|
| 348 |
+
|
| 349 |
+
print(f"β
ONNX export completed: {onnx_path}")
|
| 350 |
+
return str(onnx_path)
|
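A quick sanity check of the exported graph with onnxruntime, feeding a sequence length different from the 64-token tracing input to confirm the dynamic axes work. This is a sketch only: the path is illustrative, and 32000 stands in for the real vocab_size from metadata.json:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("exports/onnx/model.onnx")
tokens = np.random.randint(0, 32000, size=(1, 17), dtype=np.int64)  # any length works via dynamic_axes
(logits,) = sess.run(None, {"input_ids": tokens})
print(logits.shape)  # expected: (1, 17, vocab_size)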
| 351 |
+
|
| 352 |
+
def _optimize_onnx_model(self, onnx_path: Path):
|
| 353 |
+
"""Apply ONNX optimizations for inference."""
|
| 354 |
+
try:
|
| 355 |
+
import onnxruntime
|
| 356 |
+
from onnxruntime.tools import optimizer
|
| 357 |
+
|
| 358 |
+
print("π§ Applying ONNX optimizations...")
|
| 359 |
+
|
| 360 |
+
# Create optimized model
|
| 361 |
+
optimized_path = onnx_path.parent / "model_optimized.onnx"
|
| 362 |
+
|
| 363 |
+
# Apply graph optimizations
|
| 364 |
+
optimizer.optimize_model(
|
| 365 |
+
str(onnx_path),
|
| 366 |
+
str(optimized_path),
|
| 367 |
+
optimization_level=optimizer.OptimizationLevel.ORT_ENABLE_ALL,
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
# Replace original with optimized
|
| 371 |
+
shutil.move(str(optimized_path), str(onnx_path))
|
| 372 |
+
|
| 373 |
+
print("β
ONNX optimizations applied")
|
| 374 |
+
|
| 375 |
+
except ImportError:
|
| 376 |
+
print("β οΈ ONNX optimization requires onnxruntime-tools")
|
| 377 |
+
except Exception as e:
|
| 378 |
+
print(f"β οΈ ONNX optimization failed: {e}")
|
| 379 |
+
|
| 380 |
+
def _create_pytorch_loader(self, output_path: Path):
|
| 381 |
+
"""Create PyTorch model loader script."""
|
| 382 |
+
loader_script = '''#!/usr/bin/env python3
|
| 383 |
+
"""
|
| 384 |
+
PyTorch Model Loader for OpenLLM
|
| 385 |
+
|
| 386 |
+
Usage:
|
| 387 |
+
from load_model import load_model, generate_text
|
| 388 |
+
|
| 389 |
+
model, tokenizer, config = load_model(".")
|
| 390 |
+
text = generate_text(model, tokenizer, "Hello world", max_length=50)
|
| 391 |
+
print(text)
|
| 392 |
+
"""
|
| 393 |
+
|
| 394 |
+
import torch
|
| 395 |
+
import json
|
| 396 |
+
import sentencepiece as spm
|
| 397 |
+
from pathlib import Path
|
| 398 |
+
|
| 399 |
+
def load_model(model_dir="."):
|
| 400 |
+
"""Load OpenLLM model from PyTorch export."""
|
| 401 |
+
model_dir = Path(model_dir)
|
| 402 |
+
|
| 403 |
+
# Load config
|
| 404 |
+
with open(model_dir / "config.json", 'r') as f:
|
| 405 |
+
config_data = json.load(f)
|
| 406 |
+
|
| 407 |
+
model_config = config_data['model_config']
|
| 408 |
+
|
| 409 |
+
# Recreate model architecture (you'll need to have the model.py file)
|
| 410 |
+
# This is a simplified loader - in practice you'd import your GPTModel class
|
| 411 |
+
print(f"Model config: {model_config}")
|
| 412 |
+
print("Note: You need to import and create the actual model class")
|
| 413 |
+
|
| 414 |
+
# Load model state
|
| 415 |
+
checkpoint = torch.load(model_dir / "model.pt", map_location='cpu')
|
| 416 |
+
|
| 417 |
+
# Load tokenizer
|
| 418 |
+
tokenizer = spm.SentencePieceProcessor()
|
| 419 |
+
tokenizer.load(str(model_dir / "tokenizer.model"))
|
| 420 |
+
|
| 421 |
+
return None, tokenizer, model_config # Placeholder
|
| 422 |
+
|
| 423 |
+
def generate_text(model, tokenizer, prompt, max_length=100):
|
| 424 |
+
"""Generate text using the loaded model."""
|
| 425 |
+
# Implement text generation
|
| 426 |
+
return f"Generated text for: {prompt}"
|
| 427 |
+
|
| 428 |
+
if __name__ == "__main__":
|
| 429 |
+
model, tokenizer, config = load_model()
|
| 430 |
+
print(f"Model loaded with {config.get('vocab_size', 'unknown')} vocabulary size")
|
| 431 |
+
'''
|
| 432 |
+
|
| 433 |
+
with open(output_path / "load_model.py", "w") as f:
|
| 434 |
+
f.write(loader_script)
|
| 435 |
+
|
| 436 |
+
def _create_hf_loader(self, output_path: Path):
|
| 437 |
+
"""Create Hugging Face model loader script."""
|
| 438 |
+
loader_script = '''#!/usr/bin/env python3
|
| 439 |
+
"""
|
| 440 |
+
Hugging Face Compatible Loader for OpenLLM
|
| 441 |
+
|
| 442 |
+
Usage:
|
| 443 |
+
# Using transformers library (if you implement custom model class)
|
| 444 |
+
# from transformers import AutoModel, AutoTokenizer
|
| 445 |
+
# model = AutoModel.from_pretrained(".")
|
| 446 |
+
# tokenizer = AutoTokenizer.from_pretrained(".")
|
| 447 |
+
|
| 448 |
+
# Manual loading
|
| 449 |
+
from load_hf_model import load_model_manual
|
| 450 |
+
model, tokenizer = load_model_manual(".")
|
| 451 |
+
"""
|
| 452 |
+
|
| 453 |
+
import torch
|
| 454 |
+
import json
|
| 455 |
+
import sentencepiece as spm
|
| 456 |
+
from pathlib import Path
|
| 457 |
+
|
| 458 |
+
def load_model_manual(model_dir="."):
|
| 459 |
+
"""Manually load model in HF format."""
|
| 460 |
+
model_dir = Path(model_dir)
|
| 461 |
+
|
| 462 |
+
# Load config
|
| 463 |
+
with open(model_dir / "config.json", 'r') as f:
|
| 464 |
+
config = json.load(f)
|
| 465 |
+
|
| 466 |
+
# Load model weights
|
| 467 |
+
state_dict = torch.load(model_dir / "pytorch_model.bin", map_location='cpu')
|
| 468 |
+
|
| 469 |
+
# Load tokenizer
|
| 470 |
+
tokenizer = spm.SentencePieceProcessor()
|
| 471 |
+
tokenizer.load(str(model_dir / "tokenizer.model"))
|
| 472 |
+
|
| 473 |
+
print(f"Loaded model: {config['model_type']} with {config['n_layer']} layers")
|
| 474 |
+
print(f"Vocabulary size: {config['vocab_size']}")
|
| 475 |
+
|
| 476 |
+
return state_dict, tokenizer
|
| 477 |
+
|
| 478 |
+
if __name__ == "__main__":
|
| 479 |
+
state_dict, tokenizer = load_model_manual()
|
| 480 |
+
print(f"Model weights loaded: {len(state_dict)} parameters")
|
| 481 |
+
print(f"Tokenizer vocabulary: {tokenizer.vocab_size()}")
|
| 482 |
+
'''
|
| 483 |
+
|
| 484 |
+
with open(output_path / "load_hf_model.py", "w") as f:
|
| 485 |
+
f.write(loader_script)
|
| 486 |
+
|
| 487 |
+
def _create_onnx_inference(self, output_path: Path):
|
| 488 |
+
"""Create ONNX inference script."""
|
| 489 |
+
inference_script = '''#!/usr/bin/env python3
|
| 490 |
+
"""
|
| 491 |
+
ONNX Inference for OpenLLM
|
| 492 |
+
|
| 493 |
+
Usage:
|
| 494 |
+
from onnx_inference import ONNXInference
|
| 495 |
+
|
| 496 |
+
inference = ONNXInference(".")
|
| 497 |
+
output = inference.generate("Hello world", max_length=50)
|
| 498 |
+
print(output)
|
| 499 |
+
"""
|
| 500 |
+
|
| 501 |
+
import numpy as np
|
| 502 |
+
import json
|
| 503 |
+
import sentencepiece as spm
|
| 504 |
+
from pathlib import Path
|
| 505 |
+
|
| 506 |
+
try:
|
| 507 |
+
import onnxruntime as ort
|
| 508 |
+
except ImportError:
|
| 509 |
+
print("Install onnxruntime: pip install onnxruntime")
|
| 510 |
+
ort = None
|
| 511 |
+
|
| 512 |
+
class ONNXInference:
|
| 513 |
+
def __init__(self, model_dir="."):
|
| 514 |
+
if ort is None:
|
| 515 |
+
raise ImportError("onnxruntime not available")
|
| 516 |
+
|
| 517 |
+
model_dir = Path(model_dir)
|
| 518 |
+
|
| 519 |
+
# Load ONNX model
|
| 520 |
+
self.session = ort.InferenceSession(str(model_dir / "model.onnx"))
|
| 521 |
+
|
| 522 |
+
# Load metadata
|
| 523 |
+
with open(model_dir / "metadata.json", 'r') as f:
|
| 524 |
+
self.metadata = json.load(f)
|
| 525 |
+
|
| 526 |
+
# Load tokenizer
|
| 527 |
+
self.tokenizer = spm.SentencePieceProcessor()
|
| 528 |
+
self.tokenizer.load(str(model_dir / "tokenizer.model"))
|
| 529 |
+
|
| 530 |
+
print(f"ONNX model loaded: {self.metadata['model_config']['model_name']}")
|
| 531 |
+
|
| 532 |
+
def predict(self, input_ids):
|
| 533 |
+
"""Run inference on input token IDs."""
|
| 534 |
+
# Prepare input
|
| 535 |
+
input_data = {"input_ids": input_ids.astype(np.int64)}
|
| 536 |
+
|
| 537 |
+
# Run inference
|
| 538 |
+
outputs = self.session.run(None, input_data)
|
| 539 |
+
return outputs[0] # logits
|
| 540 |
+
|
| 541 |
+
def generate(self, prompt, max_length=50, temperature=0.7):
|
| 542 |
+
"""Generate text from prompt."""
|
| 543 |
+
# Tokenize prompt
|
| 544 |
+
tokens = self.tokenizer.encode(prompt)
|
| 545 |
+
input_ids = np.array([tokens], dtype=np.int64)
|
| 546 |
+
|
| 547 |
+
# Simple greedy generation (can be improved)
|
| 548 |
+
generated = tokens.copy()
|
| 549 |
+
|
| 550 |
+
for _ in range(max_length):
|
| 551 |
+
if len(generated) >= 512: # Max sequence length
|
| 552 |
+
break
|
| 553 |
+
|
| 554 |
+
# Get current input (last 64 tokens to fit ONNX model)
|
| 555 |
+
current_input = np.array([generated[-64:]], dtype=np.int64)
|
| 556 |
+
|
| 557 |
+
# Predict next token
|
| 558 |
+
logits = self.predict(current_input)
|
| 559 |
+
next_token_logits = logits[0, -1, :] # Last position
|
| 560 |
+
|
| 561 |
+
# Apply temperature and sample
|
| 562 |
+
if temperature > 0:
|
| 563 |
+
next_token_logits = next_token_logits / temperature
|
| 564 |
+
next_token_logits = next_token_logits - np.max(next_token_logits)  # for numerical stability
probs = np.exp(next_token_logits) / np.sum(np.exp(next_token_logits))
|
| 565 |
+
next_token = np.random.choice(len(probs), p=probs)
|
| 566 |
+
else:
|
| 567 |
+
next_token = np.argmax(next_token_logits)
|
| 568 |
+
|
| 569 |
+
generated.append(int(next_token))
|
| 570 |
+
|
| 571 |
+
# Decode generated text
|
| 572 |
+
generated_text = self.tokenizer.decode(generated[len(tokens):])
|
| 573 |
+
return generated_text
|
| 574 |
+
|
| 575 |
+
if __name__ == "__main__":
|
| 576 |
+
inference = ONNXInference()
|
| 577 |
+
result = inference.generate("The future of AI is", max_length=30)
|
| 578 |
+
print(f"Generated: {result}")
|
| 579 |
+
'''
|
| 580 |
+
|
| 581 |
+
with open(output_path / "onnx_inference.py", "w") as f:
|
| 582 |
+
f.write(inference_script)
|
| 583 |
+
|
| 584 |
+
def export_all_formats(self, optimize_onnx: bool = False) -> Dict[str, str]:
|
| 585 |
+
"""
|
| 586 |
+
Export model to all supported formats.
|
| 587 |
+
|
| 588 |
+
Args:
|
| 589 |
+
optimize_onnx: Whether to optimize ONNX model
|
| 590 |
+
|
| 591 |
+
Returns:
|
| 592 |
+
Dictionary mapping format names to export paths
|
| 593 |
+
"""
|
| 594 |
+
results = {}
|
| 595 |
+
|
| 596 |
+
print("π Exporting to all formats...")
|
| 597 |
+
|
| 598 |
+
try:
|
| 599 |
+
results["pytorch"] = self.export_pytorch()
|
| 600 |
+
except Exception as e:
|
| 601 |
+
print(f"β PyTorch export failed: {e}")
|
| 602 |
+
|
| 603 |
+
try:
|
| 604 |
+
results["huggingface"] = self.export_huggingface()
|
| 605 |
+
except Exception as e:
|
| 606 |
+
print(f"β Hugging Face export failed: {e}")
|
| 607 |
+
|
| 608 |
+
try:
|
| 609 |
+
results["onnx"] = self.export_onnx(optimize_onnx)
|
| 610 |
+
except Exception as e:
|
| 611 |
+
print(f"β ONNX export failed: {e}")
|
| 612 |
+
|
| 613 |
+
# Create summary
|
| 614 |
+
summary = {
|
| 615 |
+
"export_timestamp": torch.datetime.now().isoformat(),
|
| 616 |
+
"model_info": {
|
| 617 |
+
"name": self.config.model_name,
|
| 618 |
+
"parameters": self.model.get_num_params(),
|
| 619 |
+
"training_steps": self.training_info["step"],
|
| 620 |
+
"best_loss": self.training_info["best_loss"],
|
| 621 |
+
},
|
| 622 |
+
"exports": results,
|
| 623 |
+
}
|
| 624 |
+
|
| 625 |
+
with open(self.output_dir / "export_summary.json", "w") as f:
|
| 626 |
+
json.dump(summary, f, indent=2)
|
| 627 |
+
|
| 628 |
+
print(f"β
Export summary saved: {self.output_dir / 'export_summary.json'}")
|
| 629 |
+
|
| 630 |
+
return results
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
def main():
|
| 634 |
+
"""Main export function."""
|
| 635 |
+
parser = argparse.ArgumentParser(
|
| 636 |
+
description="Export OpenLLM models to various formats",
|
| 637 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 638 |
+
epilog="""
|
| 639 |
+
Examples:
|
| 640 |
+
# Export to PyTorch format
|
| 641 |
+
python core/src/export_model.py \\
|
| 642 |
+
--model_dir models/small-extended-4k \\
|
| 643 |
+
--format pytorch \\
|
| 644 |
+
--output_dir exports/pytorch/
|
| 645 |
+
|
| 646 |
+
# Export to Hugging Face format
|
| 647 |
+
python core/src/export_model.py \\
|
| 648 |
+
--model_dir models/small-extended-4k \\
|
| 649 |
+
--format huggingface \\
|
| 650 |
+
--output_dir exports/huggingface/
|
| 651 |
+
|
| 652 |
+
# Export to ONNX with optimizations
|
| 653 |
+
python core/src/export_model.py \\
|
| 654 |
+
--model_dir models/small-extended-4k \\
|
| 655 |
+
--format onnx \\
|
| 656 |
+
--output_dir exports/onnx/ \\
|
| 657 |
+
--optimize_for_inference
|
| 658 |
+
|
| 659 |
+
# Export to all formats
|
| 660 |
+
python core/src/export_model.py \\
|
| 661 |
+
--model_dir models/small-extended-4k \\
|
| 662 |
+
--format all \\
|
| 663 |
+
--output_dir exports/
|
| 664 |
+
""",
|
| 665 |
+
)
|
| 666 |
+
|
| 667 |
+
parser.add_argument(
|
| 668 |
+
"--model_dir", required=True, help="Directory containing trained model checkpoints"
|
| 669 |
+
)
|
| 670 |
+
|
| 671 |
+
parser.add_argument(
|
| 672 |
+
"--format",
|
| 673 |
+
choices=["pytorch", "huggingface", "onnx", "all"],
|
| 674 |
+
required=True,
|
| 675 |
+
help="Export format",
|
| 676 |
+
)
|
| 677 |
+
|
| 678 |
+
parser.add_argument("--output_dir", required=True, help="Output directory for exported models")
|
| 679 |
+
|
| 680 |
+
parser.add_argument(
|
| 681 |
+
"--optimize_for_inference",
|
| 682 |
+
action="store_true",
|
| 683 |
+
help="Apply optimizations for inference (ONNX only)",
|
| 684 |
+
)
|
| 685 |
+
|
| 686 |
+
args = parser.parse_args()
|
| 687 |
+
|
| 688 |
+
print("π¦ OpenLLM Model Export")
|
| 689 |
+
print("=" * 50)
|
| 690 |
+
|
| 691 |
+
try:
|
| 692 |
+
# Create exporter
|
| 693 |
+
exporter = ModelExporter(args.model_dir, args.output_dir)
|
| 694 |
+
|
| 695 |
+
# Export based on format
|
| 696 |
+
if args.format == "pytorch":
|
| 697 |
+
result = exporter.export_pytorch()
|
| 698 |
+
print(f"\nβ
PyTorch export completed: {result}")
|
| 699 |
+
|
| 700 |
+
elif args.format == "huggingface":
|
| 701 |
+
result = exporter.export_huggingface()
|
| 702 |
+
print(f"\nβ
Hugging Face export completed: {result}")
|
| 703 |
+
|
| 704 |
+
elif args.format == "onnx":
|
| 705 |
+
result = exporter.export_onnx(args.optimize_for_inference)
|
| 706 |
+
print(f"\nβ
ONNX export completed: {result}")
|
| 707 |
+
|
| 708 |
+
elif args.format == "all":
|
| 709 |
+
results = exporter.export_all_formats(args.optimize_for_inference)
|
| 710 |
+
print("\nβ
All formats exported:")
|
| 711 |
+
for fmt, path in results.items():
|
| 712 |
+
print(f" {fmt}: {path}")
|
| 713 |
+
|
| 714 |
+
print("\nπ Export completed successfully!")
|
| 715 |
+
|
| 716 |
+
except Exception as e:
|
| 717 |
+
print(f"\nβ Export failed: {e}")
|
| 718 |
+
import traceback
|
| 719 |
+
|
| 720 |
+
traceback.print_exc()
|
| 721 |
+
return False
|
| 722 |
+
|
| 723 |
+
return True
|
| 724 |
+
|
| 725 |
+
|
| 726 |
+
if __name__ == "__main__":
|
| 727 |
+
main()
|
core/src/generate_text.py
ADDED
|
@@ -0,0 +1,866 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright (C) 2024 Louis Chua Bean Chong
|
| 3 |
+
#
|
| 4 |
+
# This file is part of OpenLLM.
|
| 5 |
+
#
|
| 6 |
+
# OpenLLM is dual-licensed:
|
| 7 |
+
# 1. For open source use: GNU General Public License v3.0
|
| 8 |
+
# 2. For commercial use: Commercial License (contact for details)
|
| 9 |
+
#
|
| 10 |
+
# See LICENSE and docs/LICENSES.md for full license information.
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
OpenLLM Text Generation Script
|
| 14 |
+
|
| 15 |
+
This script implements standalone text generation for OpenLLM models
|
| 16 |
+
as specified in Step 5 of the training pipeline (Text Generation Quality assessment).
|
| 17 |
+
|
| 18 |
+
Features:
|
| 19 |
+
- Load trained OpenLLM models from checkpoint directories
|
| 20 |
+
- Generate text with configurable parameters (temperature, length, etc.)
|
| 21 |
+
- Support multiple model formats (auto-detection)
|
| 22 |
+
- Quality assessment and metrics
|
| 23 |
+
- Batch generation capabilities
|
| 24 |
+
- Output formatting and saving
|
| 25 |
+
|
| 26 |
+
Usage:
|
| 27 |
+
# Basic text generation
|
| 28 |
+
python core/src/generate_text.py \
|
| 29 |
+
--model_dir models/small-extended-4k \
|
| 30 |
+
--prompt "The history of artificial intelligence" \
|
| 31 |
+
--max_length 256 \
|
| 32 |
+
--temperature 0.7
|
| 33 |
+
|
| 34 |
+
# Multiple prompts with custom settings
|
| 35 |
+
python core/src/generate_text.py \
|
| 36 |
+
--model_dir models/small-extended-4k \
|
| 37 |
+
--prompts_file prompts.txt \
|
| 38 |
+
--max_length 100 \
|
| 39 |
+
--temperature 0.8 \
|
| 40 |
+
--top_k 40 \
|
| 41 |
+
--num_samples 3
|
| 42 |
+
|
| 43 |
+
# Save results to file
|
| 44 |
+
python core/src/generate_text.py \
|
| 45 |
+
--model_dir models/small-extended-4k \
|
| 46 |
+
--prompt "Once upon a time" \
|
| 47 |
+
--output_file generated_samples.txt
|
| 48 |
+
|
| 49 |
+
Author: Louis Chua Bean Chong
|
| 50 |
+
License: GPLv3
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
import argparse
|
| 54 |
+
import os
|
| 55 |
+
import sys
|
| 56 |
+
import time
|
| 57 |
+
from pathlib import Path
|
| 58 |
+
from typing import Any, Dict, List, Optional
|
| 59 |
+
|
| 60 |
+
import sentencepiece as spm
|
| 61 |
+
import torch
|
| 62 |
+
|
| 63 |
+
# Add current directory to path for imports
|
| 64 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 65 |
+
|
| 66 |
+
from model import create_model
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class TextGenerator:
|
| 70 |
+
"""
|
| 71 |
+
Comprehensive text generation engine for OpenLLM models.
|
| 72 |
+
|
| 73 |
+
This class handles loading trained models and generating high-quality text
|
| 74 |
+
with configurable sampling parameters and quality assessment.
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
def __init__(self, model_dir: str, device: str = "auto"):
|
| 78 |
+
"""
|
| 79 |
+
Initialize the text generator.
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
model_dir: Directory containing trained model checkpoints
|
| 83 |
+
device: Device to use ("auto", "cpu", "cuda")
|
| 84 |
+
|
| 85 |
+
Implementation Details:
|
| 86 |
+
- Auto-detects best available device if device="auto"
|
| 87 |
+
- Loads model architecture based on checkpoint configuration
|
| 88 |
+
- Sets up tokenizer for text processing
|
| 89 |
+
- Validates model and tokenizer compatibility
|
| 90 |
+
"""
|
| 91 |
+
self.model_dir = Path(model_dir)
|
| 92 |
+
|
| 93 |
+
# Determine device to use
|
| 94 |
+
# Auto-detection prioritizes CUDA if available for better performance
|
| 95 |
+
if device == "auto":
|
| 96 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 97 |
+
else:
|
| 98 |
+
self.device = device
|
| 99 |
+
|
| 100 |
+
print("π OpenLLM Text Generator")
|
| 101 |
+
print(f"π Model directory: {model_dir}")
|
| 102 |
+
print(f"π₯οΈ Device: {self.device}")
|
| 103 |
+
|
| 104 |
+
# Load model and tokenizer
|
| 105 |
+
# This handles the complete setup process
|
| 106 |
+
self._load_model()
|
| 107 |
+
self._load_tokenizer()
|
| 108 |
+
|
| 109 |
+
# Validate setup
|
| 110 |
+
# Ensure model and tokenizer are compatible
|
| 111 |
+
self._validate_setup()
|
| 112 |
+
|
| 113 |
+
print("β
Text generator initialized successfully!")
|
| 114 |
+
|
| 115 |
+
def _load_model(self):
|
| 116 |
+
"""
|
| 117 |
+
Load the trained model from checkpoint.
|
| 118 |
+
|
| 119 |
+
Implementation Details:
|
| 120 |
+
- Searches for best_model.pt or latest checkpoint
|
| 121 |
+
- Auto-detects model size from configuration
|
| 122 |
+
- Handles different checkpoint formats gracefully
|
| 123 |
+
- Sets model to evaluation mode for inference
|
| 124 |
+
"""
|
| 125 |
+
# Find the best model checkpoint
|
| 126 |
+
# Priority: best_model.pt > latest checkpoint by step number
|
| 127 |
+
best_model_path = self.model_dir / "best_model.pt"
|
| 128 |
+
|
| 129 |
+
if best_model_path.exists():
|
| 130 |
+
checkpoint_path = best_model_path
|
| 131 |
+
print(f"π₯ Loading best model: {checkpoint_path}")
|
| 132 |
+
else:
|
| 133 |
+
# Look for step-based checkpoints
|
| 134 |
+
checkpoints = list(self.model_dir.glob("checkpoint_step_*.pt"))
|
| 135 |
+
if not checkpoints:
|
| 136 |
+
raise FileNotFoundError(f"No model checkpoints found in {self.model_dir}")
|
| 137 |
+
|
| 138 |
+
# Get the latest checkpoint by step number
|
| 139 |
+
latest_checkpoint = max(checkpoints, key=lambda p: int(p.stem.split("_")[-1]))
|
| 140 |
+
checkpoint_path = latest_checkpoint
|
| 141 |
+
print(f"π₯ Loading latest checkpoint: {checkpoint_path}")
|
| 142 |
+
|
| 143 |
+
# Load checkpoint data
|
| 144 |
+
# This contains model weights, configuration, and training metadata
|
| 145 |
+
try:
|
| 146 |
+
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
| 147 |
+
print("β
Checkpoint loaded successfully")
|
| 148 |
+
except Exception as e:
|
| 149 |
+
raise RuntimeError(f"Failed to load checkpoint: {e}")
|
| 150 |
+
|
| 151 |
+
# Extract model configuration
|
| 152 |
+
# This tells us what architecture to create
|
| 153 |
+
if "config" in checkpoint:
|
| 154 |
+
config_dict = checkpoint["config"]
|
| 155 |
+
else:
|
| 156 |
+
# Fallback: try to infer from model state dict
|
| 157 |
+
print("β οΈ No config found in checkpoint, inferring from model structure...")
|
| 158 |
+
config_dict = self._infer_config_from_state_dict(
|
| 159 |
+
checkpoint.get("model_state_dict", checkpoint)
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
# Determine model size category
|
| 163 |
+
# This maps checkpoint config to our predefined model sizes
|
| 164 |
+
n_layer = config_dict.get("n_layer", 12)
|
| 165 |
+
n_embd = config_dict.get("n_embd", 768)
|
| 166 |
+
|
| 167 |
+
if n_layer <= 6:
|
| 168 |
+
model_size = "small"
|
| 169 |
+
elif n_layer <= 12:
|
| 170 |
+
model_size = "medium"
|
| 171 |
+
else:
|
| 172 |
+
model_size = "large"
|
| 173 |
+
|
| 174 |
+
print(f"π― Detected model size: {model_size}")
|
| 175 |
+
print(f"π Architecture: {n_layer} layers, {n_embd} embedding dim")
|
| 176 |
+
|
| 177 |
+
# Create model architecture
|
| 178 |
+
# This recreates the exact same model used during training
|
| 179 |
+
try:
|
| 180 |
+
self.model = create_model(model_size)
|
| 181 |
+
print(f"ποΈ Model architecture created: {self.model.get_num_params():,} parameters")
|
| 182 |
+
except Exception as e:
|
| 183 |
+
raise RuntimeError(f"Failed to create model architecture: {e}")
|
| 184 |
+
|
| 185 |
+
# Load trained weights
|
| 186 |
+
# This restores the model to its trained state
|
| 187 |
+
try:
|
| 188 |
+
if "model_state_dict" in checkpoint:
|
| 189 |
+
self.model.load_state_dict(checkpoint["model_state_dict"])
|
| 190 |
+
else:
|
| 191 |
+
# Fallback for different checkpoint formats
|
| 192 |
+
self.model.load_state_dict(checkpoint)
|
| 193 |
+
|
| 194 |
+
print("β
Model weights loaded successfully")
|
| 195 |
+
except Exception as e:
|
| 196 |
+
raise RuntimeError(f"Failed to load model weights: {e}")
|
| 197 |
+
|
| 198 |
+
# Move model to device and set to evaluation mode
|
| 199 |
+
# Evaluation mode disables dropout and other training-specific behaviors
|
| 200 |
+
self.model = self.model.to(self.device)
|
| 201 |
+
self.model.eval()
|
| 202 |
+
|
| 203 |
+
# Store model configuration for later use
|
| 204 |
+
# This is useful for generation parameters and limits
|
| 205 |
+
self.config = self.model.config
|
| 206 |
+
|
| 207 |
+
# Extract training metadata if available
|
| 208 |
+
# This provides context about model quality and training progress
|
| 209 |
+
self.training_info = {
|
| 210 |
+
"step": checkpoint.get("step", "Unknown"),
|
| 211 |
+
"best_loss": checkpoint.get("best_loss", "Unknown"),
|
| 212 |
+
"model_size": model_size,
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
print(
|
| 216 |
+
f"π Training info: step {self.training_info['step']}, "
|
| 217 |
+
f"best loss {self.training_info['best_loss']}"
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
def _infer_config_from_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
|
| 221 |
+
"""
|
| 222 |
+
Infer model configuration from state dict when config is missing.
|
| 223 |
+
|
| 224 |
+
Args:
|
| 225 |
+
state_dict: Model parameter dictionary
|
| 226 |
+
|
| 227 |
+
Returns:
|
| 228 |
+
Inferred configuration dictionary
|
| 229 |
+
|
| 230 |
+
Implementation Details:
|
| 231 |
+
- Analyzes parameter shapes to determine architecture
|
| 232 |
+
- Makes reasonable assumptions about standard GPT architecture
|
| 233 |
+
- Provides fallback values for missing parameters
|
| 234 |
+
"""
|
| 235 |
+
# Extract key dimensions from parameter shapes
|
| 236 |
+
# This reverse-engineers the model architecture
|
| 237 |
+
|
| 238 |
+
# Embedding layer tells us vocab size and embedding dimension
|
| 239 |
+
if "transformer.wte.weight" in state_dict:
|
| 240 |
+
vocab_size, n_embd = state_dict["transformer.wte.weight"].shape
|
| 241 |
+
else:
|
| 242 |
+
# Fallback defaults
|
| 243 |
+
vocab_size, n_embd = 32000, 512
|
| 244 |
+
|
| 245 |
+
# Count transformer blocks to get number of layers
|
| 246 |
+
# Look for attention weight patterns
|
| 247 |
+
n_layer = 0
|
| 248 |
+
for key in state_dict.keys():
|
| 249 |
+
if "attn.c_attn.weight" in key:
|
| 250 |
+
# Extract layer number from key like 'transformer.h.0.attn.c_attn.weight'
|
| 251 |
+
layer_num = int(key.split(".")[2])
|
| 252 |
+
n_layer = max(n_layer, layer_num + 1)
|
| 253 |
+
|
| 254 |
+
# Infer number of attention heads from attention weights
|
| 255 |
+
# The c_attn weight combines query, key, value projections
|
| 256 |
+
if "transformer.h.0.attn.c_attn.weight" in state_dict:
|
| 257 |
+
_ = state_dict["transformer.h.0.attn.c_attn.weight"].shape
|
| 258 |
+
# Shape is [n_embd, 3 * n_embd] for combined Q,K,V
|
| 259 |
+
# So n_head = n_embd / head_dim, assuming head_dim = 64
|
| 260 |
+
n_head = n_embd // 64 # Standard head dimension
|
| 261 |
+
else:
|
| 262 |
+
n_head = 8 # Fallback
|
| 263 |
+
|
| 264 |
+
# Construct configuration dictionary
|
| 265 |
+
# Use reasonable defaults for missing values
|
| 266 |
+
config = {
|
| 267 |
+
"vocab_size": vocab_size,
|
| 268 |
+
"n_layer": n_layer,
|
| 269 |
+
"n_head": n_head,
|
| 270 |
+
"n_embd": n_embd,
|
| 271 |
+
"block_size": 1024, # Standard context length
|
| 272 |
+
"dropout": 0.1, # Standard dropout rate
|
| 273 |
+
"bias": True, # Most models use bias
|
| 274 |
+
"model_name": f"gpt-inferred-{n_layer}L",
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
print(f"π Inferred config: {config}")
|
| 278 |
+
return config
|
| 279 |
+
|
| 280 |
+
def _load_tokenizer(self):
|
| 281 |
+
"""
|
| 282 |
+
Load the SentencePiece tokenizer.
|
| 283 |
+
|
| 284 |
+
Implementation Details:
|
| 285 |
+
- Searches multiple possible tokenizer locations
|
| 286 |
+
- Validates tokenizer vocabulary size against model
|
| 287 |
+
- Sets up special tokens if available
|
| 288 |
+
"""
|
| 289 |
+
# Try multiple possible tokenizer locations
|
| 290 |
+
# Different training setups may store tokenizer in different places
|
| 291 |
+
possible_paths = [
|
| 292 |
+
self.model_dir / "tokenizer.model",
|
| 293 |
+
self.model_dir.parent / "tokenizer" / "tokenizer.model",
|
| 294 |
+
Path("data/tokenizer/tokenizer.model"),
|
| 295 |
+
self.model_dir / ".." / "tokenizer" / "tokenizer.model",
|
| 296 |
+
]
|
| 297 |
+
|
| 298 |
+
tokenizer_path = None
|
| 299 |
+
for path in possible_paths:
|
| 300 |
+
if path.exists():
|
| 301 |
+
tokenizer_path = path
|
| 302 |
+
break
|
| 303 |
+
|
| 304 |
+
if tokenizer_path is None:
|
| 305 |
+
raise FileNotFoundError(f"Tokenizer not found in any of: {possible_paths}")
|
| 306 |
+
|
| 307 |
+
print(f"π Loading tokenizer from: {tokenizer_path}")
|
| 308 |
+
|
| 309 |
+
# Load SentencePiece tokenizer
|
| 310 |
+
# This handles all text-to-token and token-to-text conversion
|
| 311 |
+
try:
|
| 312 |
+
self.tokenizer = spm.SentencePieceProcessor()
|
| 313 |
+
self.tokenizer.load(str(tokenizer_path))
|
| 314 |
+
print(f"β
Tokenizer loaded: {self.tokenizer.vocab_size()} vocabulary")
|
| 315 |
+
except Exception as e:
|
| 316 |
+
raise RuntimeError(f"Failed to load tokenizer: {e}")
|
| 317 |
+
|
| 318 |
+
def _validate_setup(self):
|
| 319 |
+
"""
|
| 320 |
+
Validate that model and tokenizer are compatible.
|
| 321 |
+
|
| 322 |
+
Implementation Details:
|
| 323 |
+
- Checks vocabulary size consistency
|
| 324 |
+
- Tests basic tokenization and model forward pass
|
| 325 |
+
- Warns about potential compatibility issues
|
| 326 |
+
"""
|
| 327 |
+
# Check vocabulary size consistency
|
| 328 |
+
# Model and tokenizer should have matching vocabulary
|
| 329 |
+
model_vocab_size = self.config.vocab_size
|
| 330 |
+
tokenizer_vocab_size = self.tokenizer.vocab_size()
|
| 331 |
+
|
| 332 |
+
if model_vocab_size != tokenizer_vocab_size:
|
| 333 |
+
print("β οΈ Warning: Vocabulary size mismatch!")
|
| 334 |
+
print(f" Model expects: {model_vocab_size}")
|
| 335 |
+
print(f" Tokenizer has: {tokenizer_vocab_size}")
|
| 336 |
+
print(" This may cause generation issues.")
|
| 337 |
+
|
| 338 |
+
# Test basic functionality
|
| 339 |
+
# Quick validation that everything works together
|
| 340 |
+
try:
|
| 341 |
+
# Test tokenization
|
| 342 |
+
test_text = "Hello world"
|
| 343 |
+
tokens = self.tokenizer.encode(test_text)
|
| 344 |
+
_ = self.tokenizer.decode(tokens)
|
| 345 |
+
|
| 346 |
+
# Test model forward pass
|
| 347 |
+
input_ids = torch.tensor([tokens[:5]], dtype=torch.long, device=self.device)
|
| 348 |
+
with torch.no_grad():
|
| 349 |
+
_ = self.model(input_ids)
|
| 350 |
+
|
| 351 |
+
print("β
Validation passed: tokenization and model forward pass work")
|
| 352 |
+
|
| 353 |
+
except Exception as e:
|
| 354 |
+
print(f"β οΈ Validation warning: {e}")
|
| 355 |
+
print(" Generation may still work, but there might be issues.")
|
| 356 |
+
|
| 357 |
+
def generate(
|
| 358 |
+
self,
|
| 359 |
+
prompt: str,
|
| 360 |
+
max_length: int = 100,
|
| 361 |
+
temperature: float = 0.7,
|
| 362 |
+
top_k: Optional[int] = 40,
|
| 363 |
+
top_p: Optional[float] = 0.9,
|
| 364 |
+
num_return_sequences: int = 1,
|
| 365 |
+
do_sample: bool = True,
|
| 366 |
+
repetition_penalty: float = 1.0,
|
| 367 |
+
) -> List[str]:
|
| 368 |
+
"""
|
| 369 |
+
Generate text from a prompt using the loaded model.
|
| 370 |
+
|
| 371 |
+
Args:
|
| 372 |
+
prompt: Input text to continue
|
| 373 |
+
max_length: Maximum number of tokens to generate
|
| 374 |
+
temperature: Sampling temperature (0.1-2.0, higher = more random)
|
| 375 |
+
top_k: Limit to top-k most likely tokens (None = no limit)
|
| 376 |
+
top_p: Nucleus sampling threshold (None = no nucleus sampling)
|
| 377 |
+
num_return_sequences: Number of sequences to generate
|
| 378 |
+
do_sample: Whether to use sampling (False = greedy)
|
| 379 |
+
repetition_penalty: Penalty for repeating tokens (1.0 = no penalty)
|
| 380 |
+
|
| 381 |
+
Returns:
|
| 382 |
+
List of generated text strings
|
| 383 |
+
|
| 384 |
+
Implementation Details:
|
| 385 |
+
- Uses autoregressive generation (one token at a time)
|
| 386 |
+
- Supports multiple sampling strategies (greedy, top-k, nucleus)
|
| 387 |
+
- Handles context length limits gracefully
|
| 388 |
+
- Applies repetition penalty to improve quality
|
| 389 |
+
- Returns only the generated portion (excludes input prompt)
|
| 390 |
+
"""
|
| 391 |
+
print(f"π― Generating text for: '{prompt[:50]}{'...' if len(prompt) > 50 else ''}'")
|
| 392 |
+
print(
|
| 393 |
+
f"βοΈ Parameters: max_length={max_length}, temperature={temperature}, "
|
| 394 |
+
f"top_k={top_k}, top_p={top_p}"
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
# Tokenize input prompt
|
| 398 |
+
# Convert text to token IDs for model processing
|
| 399 |
+
try:
|
| 400 |
+
input_tokens = self.tokenizer.encode(prompt)
|
| 401 |
+
if len(input_tokens) == 0:
|
| 402 |
+
raise ValueError("Empty tokenization result")
|
| 403 |
+
except Exception as e:
|
| 404 |
+
raise RuntimeError(f"Failed to tokenize prompt: {e}")
|
| 405 |
+
|
| 406 |
+
# Check prompt length against model context
|
| 407 |
+
# Ensure we don't exceed model's maximum sequence length
|
| 408 |
+
max_context = self.config.block_size
|
| 409 |
+
if len(input_tokens) >= max_context:
|
| 410 |
+
print(
|
| 411 |
+
f"β οΈ Warning: Prompt length ({len(input_tokens)}) approaches "
|
| 412 |
+
f"context limit ({max_context})"
|
| 413 |
+
)
|
| 414 |
+
# Truncate prompt if necessary
|
| 415 |
+
input_tokens = input_tokens[-(max_context - max_length) :]
|
| 416 |
+
print(f" Truncated prompt to {len(input_tokens)} tokens")
|
| 417 |
+
|
| 418 |
+
# Generate multiple sequences
|
| 419 |
+
# Each sequence is generated independently
|
| 420 |
+
generated_texts = []
|
| 421 |
+
|
| 422 |
+
for seq_idx in range(num_return_sequences):
|
| 423 |
+
if num_return_sequences > 1:
|
| 424 |
+
print(f"π Generating sequence {seq_idx + 1}/{num_return_sequences}")
|
| 425 |
+
|
| 426 |
+
try:
|
| 427 |
+
generated_text = self._generate_single_sequence(
|
| 428 |
+
input_tokens=input_tokens,
|
| 429 |
+
max_length=max_length,
|
| 430 |
+
temperature=temperature,
|
| 431 |
+
top_k=top_k,
|
| 432 |
+
top_p=top_p,
|
| 433 |
+
do_sample=do_sample,
|
| 434 |
+
repetition_penalty=repetition_penalty,
|
| 435 |
+
)
|
| 436 |
+
generated_texts.append(generated_text)
|
| 437 |
+
|
| 438 |
+
except Exception as e:
|
| 439 |
+
print(f"β οΈ Generation failed for sequence {seq_idx + 1}: {e}")
|
| 440 |
+
generated_texts.append(f"Generation error: {e}")
|
| 441 |
+
|
| 442 |
+
return generated_texts
|
| 443 |
+
|
| 444 |
+
def _generate_single_sequence(
|
| 445 |
+
self,
|
| 446 |
+
input_tokens: List[int],
|
| 447 |
+
max_length: int,
|
| 448 |
+
temperature: float,
|
| 449 |
+
top_k: Optional[int],
|
| 450 |
+
top_p: Optional[float],
|
| 451 |
+
do_sample: bool,
|
| 452 |
+
repetition_penalty: float,
|
| 453 |
+
) -> str:
|
| 454 |
+
"""
|
| 455 |
+
Generate a single text sequence using autoregressive sampling.
|
| 456 |
+
|
| 457 |
+
Args:
|
| 458 |
+
input_tokens: Tokenized input prompt
|
| 459 |
+
max_length: Maximum tokens to generate
|
| 460 |
+
temperature: Sampling temperature
|
| 461 |
+
top_k: Top-k sampling limit
|
| 462 |
+
top_p: Nucleus sampling threshold
|
| 463 |
+
do_sample: Whether to use sampling vs greedy
|
| 464 |
+
repetition_penalty: Repetition penalty factor
|
| 465 |
+
|
| 466 |
+
Returns:
|
| 467 |
+
Generated text string (excluding input prompt)
|
| 468 |
+
|
| 469 |
+
Implementation Details:
|
| 470 |
+
- Implements autoregressive generation loop
|
| 471 |
+
- Applies all specified sampling strategies
|
| 472 |
+
- Handles special tokens (EOS, padding)
|
| 473 |
+
- Tracks token frequencies for repetition penalty
|
| 474 |
+
"""
|
| 475 |
+
# Initialize generation state
|
| 476 |
+
# Keep track of all generated tokens and their frequencies
|
| 477 |
+
generated_tokens = input_tokens.copy()
|
| 478 |
+
token_frequencies = {} # For repetition penalty
|
| 479 |
+
|
| 480 |
+
# Count initial token frequencies
|
| 481 |
+
# This helps apply repetition penalty from the start
|
| 482 |
+
for token in input_tokens:
|
| 483 |
+
token_frequencies[token] = token_frequencies.get(token, 0) + 1
|
| 484 |
+
|
| 485 |
+
# Set model to evaluation mode and disable gradients
|
| 486 |
+
# This ensures consistent inference behavior and saves memory
|
| 487 |
+
self.model.eval()
|
| 488 |
+
|
| 489 |
+
with torch.no_grad():
|
| 490 |
+
# Main generation loop
|
| 491 |
+
# Generate one token at a time until stopping condition
|
| 492 |
+
for step in range(max_length):
|
| 493 |
+
# Check context length limits
|
| 494 |
+
# Prevent exceeding model's maximum sequence length
|
| 495 |
+
if len(generated_tokens) >= self.config.block_size:
|
| 496 |
+
print(f"β οΈ Reached maximum context length ({self.config.block_size})")
|
| 497 |
+
break
|
| 498 |
+
|
| 499 |
+
# Prepare model input
|
| 500 |
+
# Use all generated tokens as context for next prediction
|
| 501 |
+
input_ids = torch.tensor([generated_tokens], dtype=torch.long, device=self.device)
|
| 502 |
+
|
| 503 |
+
try:
|
| 504 |
+
# Forward pass through model
|
| 505 |
+
# Get logits (raw predictions) for all vocabulary tokens
|
| 506 |
+
outputs = self.model(input_ids)
|
| 507 |
+
|
| 508 |
+
# Handle different model output formats
|
| 509 |
+
# Some models return tuples, others return tensors directly
|
| 510 |
+
if isinstance(outputs, tuple):
|
| 511 |
+
logits = outputs[0] # First element is usually logits
|
| 512 |
+
else:
|
| 513 |
+
logits = outputs
|
| 514 |
+
|
| 515 |
+
# Get predictions for next token (last position in sequence)
|
| 516 |
+
next_token_logits = logits[0, -1, :].float()
|
| 517 |
+
|
| 518 |
+
except Exception as e:
|
| 519 |
+
raise RuntimeError(f"Model forward pass failed at step {step}: {e}")
|
| 520 |
+
|
| 521 |
+
# Apply repetition penalty
|
| 522 |
+
# Reduce probability of recently used tokens
|
| 523 |
+
if repetition_penalty != 1.0:
|
| 524 |
+
for token, freq in token_frequencies.items():
|
| 525 |
+
if token < len(next_token_logits):
|
| 526 |
+
penalty = repetition_penalty**freq
|
| 527 |
+
if next_token_logits[token] > 0:
|
| 528 |
+
next_token_logits[token] /= penalty
|
| 529 |
+
else:
|
| 530 |
+
next_token_logits[token] *= penalty
|
| 531 |
+
|
| 532 |
+
# Apply sampling strategy to select next token
|
| 533 |
+
# This determines the randomness and quality of generation
|
| 534 |
+
if do_sample:
|
| 535 |
+
next_token = self._sample_next_token(
|
| 536 |
+
next_token_logits, temperature, top_k, top_p
|
| 537 |
+
)
|
| 538 |
+
else:
|
| 539 |
+
# Greedy decoding: always pick most likely token
|
| 540 |
+
next_token = torch.argmax(next_token_logits).item()
|
| 541 |
+
|
| 542 |
+
# Add generated token to sequence
|
| 543 |
+
generated_tokens.append(next_token)
|
| 544 |
+
|
| 545 |
+
# Update token frequency for repetition penalty
|
| 546 |
+
token_frequencies[next_token] = token_frequencies.get(next_token, 0) + 1
|
| 547 |
+
|
| 548 |
+
# Check for end-of-sequence token
|
| 549 |
+
# Some models/tokenizers have special EOS tokens
|
| 550 |
+
if hasattr(self.tokenizer, "eos_id") and next_token == self.tokenizer.eos_id():
|
| 551 |
+
print(f"π Reached end-of-sequence token at step {step}")
|
| 552 |
+
break
|
| 553 |
+
|
| 554 |
+
# Optional: Check for other stopping conditions
|
| 555 |
+
# Could add custom stop words or patterns here
|
| 556 |
+
|
| 557 |
+
# Decode generated tokens to text
|
| 558 |
+
# Convert token IDs back to readable text, excluding input prompt
|
| 559 |
+
try:
|
| 560 |
+
# Extract only newly generated tokens (exclude input prompt)
|
| 561 |
+
new_tokens = generated_tokens[len(input_tokens) :]
|
| 562 |
+
|
| 563 |
+
if len(new_tokens) == 0:
|
| 564 |
+
return "β οΈ No tokens generated"
|
| 565 |
+
|
| 566 |
+
# Decode to text using tokenizer
|
| 567 |
+
generated_text = self.tokenizer.decode(new_tokens)
|
| 568 |
+
|
| 569 |
+
print(f"β
Generated {len(new_tokens)} tokens")
|
| 570 |
+
return generated_text
|
| 571 |
+
|
| 572 |
+
except Exception as e:
|
| 573 |
+
raise RuntimeError(f"Failed to decode generated tokens: {e}")
|
| 574 |
+
|
| 575 |
+
def _sample_next_token(
|
| 576 |
+
self, logits: torch.Tensor, temperature: float, top_k: Optional[int], top_p: Optional[float]
|
| 577 |
+
) -> int:
|
| 578 |
+
"""
|
| 579 |
+
Sample next token using specified sampling strategy.
|
| 580 |
+
|
| 581 |
+
Args:
|
| 582 |
+
logits: Raw model predictions for next token
|
| 583 |
+
temperature: Sampling temperature
|
| 584 |
+
top_k: Top-k sampling limit
|
| 585 |
+
top_p: Nucleus sampling threshold
|
| 586 |
+
|
| 587 |
+
Returns:
|
| 588 |
+
Selected token ID
|
| 589 |
+
|
| 590 |
+
Implementation Details:
|
| 591 |
+
- Applies temperature scaling for randomness control
|
| 592 |
+
- Implements top-k sampling to limit choices
|
| 593 |
+
- Implements nucleus (top-p) sampling for quality
|
| 594 |
+
- Uses multinomial sampling for final selection
|
| 595 |
+
"""
|
| 596 |
+
# Apply temperature scaling
|
| 597 |
+
# Higher temperature = more random, lower = more deterministic
|
| 598 |
+
if temperature != 1.0:
|
| 599 |
+
logits = logits / temperature
|
| 600 |
+
|
| 601 |
+
# Apply top-k filtering
|
| 602 |
+
# Only consider the k most likely tokens
|
| 603 |
+
if top_k is not None and top_k > 0:
|
| 604 |
+
# Get indices of top-k tokens
|
| 605 |
+
top_k_tokens = min(top_k, logits.size(-1))
|
| 606 |
+
top_k_values, top_k_indices = torch.topk(logits, top_k_tokens)
|
| 607 |
+
|
| 608 |
+
# Zero out non-top-k logits
|
| 609 |
+
filtered_logits = torch.full_like(logits, float("-inf"))
|
| 610 |
+
filtered_logits[top_k_indices] = top_k_values
|
| 611 |
+
logits = filtered_logits
|
| 612 |
+
|
| 613 |
+
# Apply nucleus (top-p) sampling
|
| 614 |
+
# Dynamically adjust vocabulary based on cumulative probability
|
| 615 |
+
if top_p is not None and top_p < 1.0:
|
| 616 |
+
# Sort logits in descending order
|
| 617 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| 618 |
+
|
| 619 |
+
# Calculate cumulative probabilities
|
| 620 |
+
sorted_probs = torch.softmax(sorted_logits, dim=-1)
|
| 621 |
+
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
|
| 622 |
+
|
| 623 |
+
# Find cutoff point where cumulative probability exceeds top_p
|
| 624 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
| 625 |
+
|
| 626 |
+
# Keep at least the top token
|
| 627 |
+
sorted_indices_to_remove[0] = False
|
| 628 |
+
|
| 629 |
+
# Zero out tokens beyond nucleus
|
| 630 |
+
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
| 631 |
+
logits[indices_to_remove] = float("-inf")
|
| 632 |
+
|
| 633 |
+
# Convert logits to probabilities and sample
|
| 634 |
+
# Use multinomial sampling for final token selection
|
| 635 |
+
probs = torch.softmax(logits, dim=-1)
|
| 636 |
+
next_token = torch.multinomial(probs, num_samples=1).item()
|
| 637 |
+
|
| 638 |
+
return next_token
|
| 639 |
+
|
| 640 |
+
def generate_batch(self, prompts: List[str], **generation_kwargs) -> List[List[str]]:
|
| 641 |
+
"""
|
| 642 |
+
Generate text for multiple prompts.
|
| 643 |
+
|
| 644 |
+
Args:
|
| 645 |
+
prompts: List of input prompts
|
| 646 |
+
**generation_kwargs: Arguments passed to generate()
|
| 647 |
+
|
| 648 |
+
Returns:
|
| 649 |
+
List of lists, where each inner list contains generated texts for one prompt
|
| 650 |
+
|
| 651 |
+
Implementation Details:
|
| 652 |
+
- Processes prompts sequentially (could be parallelized)
|
| 653 |
+
- Applies same generation parameters to all prompts
|
| 654 |
+
- Handles errors gracefully for individual prompts
|
| 655 |
+
"""
|
| 656 |
+
print(f"π Generating text for {len(prompts)} prompts...")
|
| 657 |
+
|
| 658 |
+
all_results = []
|
| 659 |
+
|
| 660 |
+
for i, prompt in enumerate(prompts):
|
| 661 |
+
print(f"\n--- Prompt {i + 1}/{len(prompts)} ---")
|
| 662 |
+
|
| 663 |
+
try:
|
| 664 |
+
results = self.generate(prompt, **generation_kwargs)
|
| 665 |
+
all_results.append(results)
|
| 666 |
+
|
| 667 |
+
except Exception as e:
|
| 668 |
+
print(f"β Failed to generate for prompt {i + 1}: {e}")
|
| 669 |
+
all_results.append([f"Generation failed: {e}"])
|
| 670 |
+
|
| 671 |
+
return all_results
|
| 672 |
+
|
| 673 |
+
|
| 674 |
+
def load_prompts_from_file(file_path: str) -> List[str]:
|
| 675 |
+
"""
|
| 676 |
+
Load prompts from a text file.
|
| 677 |
+
|
| 678 |
+
Args:
|
| 679 |
+
file_path: Path to file containing prompts (one per line)
|
| 680 |
+
|
| 681 |
+
Returns:
|
| 682 |
+
List of prompt strings
|
| 683 |
+
|
| 684 |
+
Implementation Details:
|
| 685 |
+
- Reads file line by line
|
| 686 |
+
- Strips whitespace and filters empty lines
|
| 687 |
+
- Handles different text encodings gracefully
|
| 688 |
+
"""
|
| 689 |
+
try:
|
| 690 |
+
with open(file_path, "r", encoding="utf-8") as f:
|
| 691 |
+
prompts = [line.strip() for line in f if line.strip()]
|
| 692 |
+
|
| 693 |
+
print(f"π Loaded {len(prompts)} prompts from {file_path}")
|
| 694 |
+
return prompts
|
| 695 |
+
|
| 696 |
+
except Exception as e:
|
| 697 |
+
raise RuntimeError(f"Failed to load prompts from {file_path}: {e}")
|
| 698 |
+
|
| 699 |
+
|
| 700 |
+
def save_results_to_file(results: List[str], output_path: str, prompts: List[str] = None):
|
| 701 |
+
"""
|
| 702 |
+
Save generation results to a text file.
|
| 703 |
+
|
| 704 |
+
Args:
|
| 705 |
+
results: Generated text results
|
| 706 |
+
output_path: Path to output file
|
| 707 |
+
prompts: Original prompts (optional, for context)
|
| 708 |
+
|
| 709 |
+
Implementation Details:
|
| 710 |
+
- Formats output with clear separators
|
| 711 |
+
- Includes prompts and metadata when available
|
| 712 |
+
- Handles file creation and error reporting
|
| 713 |
+
"""
|
| 714 |
+
try:
|
| 715 |
+
with open(output_path, "w", encoding="utf-8") as f:
|
| 716 |
+
f.write("# OpenLLM Text Generation Results\n")
|
| 717 |
+
f.write(f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
|
| 718 |
+
f.write(f"# Total samples: {len(results)}\n\n")
|
| 719 |
+
|
| 720 |
+
for i, result in enumerate(results):
|
| 721 |
+
f.write(f"--- Sample {i + 1} ---\n")
|
| 722 |
+
|
| 723 |
+
if prompts and i < len(prompts):
|
| 724 |
+
f.write(f"Prompt: {prompts[i]}\n\n")
|
| 725 |
+
|
| 726 |
+
if isinstance(result, list):
|
| 727 |
+
for j, text in enumerate(result):
|
| 728 |
+
f.write(f"Generated {j + 1}: {text}\n\n")
|
| 729 |
+
else:
|
| 730 |
+
f.write(f"Generated: {result}\n\n")
|
| 731 |
+
|
| 732 |
+
f.write("-" * 50 + "\n\n")
|
| 733 |
+
|
| 734 |
+
print(f"πΎ Results saved to: {output_path}")
|
| 735 |
+
|
| 736 |
+
except Exception as e:
|
| 737 |
+
raise RuntimeError(f"Failed to save results to {output_path}: {e}")
|
| 738 |
+
|
| 739 |
+
|
| 740 |
+
def main():
|
| 741 |
+
"""Main function for command-line text generation."""
|
| 742 |
+
parser = argparse.ArgumentParser(
|
| 743 |
+
description="OpenLLM Text Generation",
|
| 744 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 745 |
+
epilog="""
|
| 746 |
+
Examples:
|
| 747 |
+
# Basic text generation
|
| 748 |
+
python core/src/generate_text.py \\
|
| 749 |
+
--model_dir ./openllm-trained \\
|
| 750 |
+
--prompt "Hello, how are you?" \\
|
| 751 |
+
--max_length 100
|
| 752 |
+
|
| 753 |
+
# Advanced generation with parameters
|
| 754 |
+
python core/src/generate_text.py \\
|
| 755 |
+
--model_dir ./openllm-trained \\
|
| 756 |
+
--prompt "The future of AI is" \\
|
| 757 |
+
--max_length 200 \\
|
| 758 |
+
--temperature 0.8 \\
|
| 759 |
+
--top_k 50 \\
|
| 760 |
+
--top_p 0.9
|
| 761 |
+
""",
|
| 762 |
+
)
|
| 763 |
+
|
| 764 |
+
parser.add_argument(
|
| 765 |
+
"--model_dir",
|
| 766 |
+
required=True,
|
| 767 |
+
help="Directory containing trained model checkpoints",
|
| 768 |
+
)
|
| 769 |
+
|
| 770 |
+
parser.add_argument(
|
| 771 |
+
"--prompt",
|
| 772 |
+
required=True,
|
| 773 |
+
help="Input text prompt for generation",
|
| 774 |
+
)
|
| 775 |
+
|
| 776 |
+
parser.add_argument(
|
| 777 |
+
"--max_length",
|
| 778 |
+
type=int,
|
| 779 |
+
default=100,
|
| 780 |
+
help="Maximum number of tokens to generate (default: 100)",
|
| 781 |
+
)
|
| 782 |
+
|
| 783 |
+
parser.add_argument(
|
| 784 |
+
"--temperature",
|
| 785 |
+
type=float,
|
| 786 |
+
default=0.7,
|
| 787 |
+
help="Sampling temperature (default: 0.7)",
|
| 788 |
+
)
|
| 789 |
+
|
| 790 |
+
parser.add_argument(
|
| 791 |
+
"--top_k",
|
| 792 |
+
type=int,
|
| 793 |
+
default=40,
|
| 794 |
+
help="Top-k sampling parameter (default: 40)",
|
| 795 |
+
)
|
| 796 |
+
|
| 797 |
+
parser.add_argument(
|
| 798 |
+
"--top_p",
|
| 799 |
+
type=float,
|
| 800 |
+
default=0.9,
|
| 801 |
+
help="Nucleus sampling parameter (default: 0.9)",
|
| 802 |
+
)
|
| 803 |
+
|
| 804 |
+
parser.add_argument(
|
| 805 |
+
"--device",
|
| 806 |
+
default="auto",
|
| 807 |
+
choices=["auto", "cpu", "cuda"],
|
| 808 |
+
help="Device to use for generation (default: auto)",
|
| 809 |
+
)
|
| 810 |
+
|
| 811 |
+
args = parser.parse_args()
|
| 812 |
+
|
| 813 |
+
print("π OpenLLM Text Generation")
|
| 814 |
+
print("=" * 50)
|
| 815 |
+
|
| 816 |
+
try:
|
| 817 |
+
# Initialize text generator
|
| 818 |
+
generator = TextGenerator(args.model_dir, args.device)
|
| 819 |
+
|
| 820 |
+
# Generate text
|
| 821 |
+
print(f"π Prompt: {args.prompt}")
|
| 822 |
+
print(f"βοΈ Parameters: max_length={args.max_length}, temperature={args.temperature}")
|
| 823 |
+
|
| 824 |
+
generated_text = generator.generate(
|
| 825 |
+
prompt=args.prompt,
|
| 826 |
+
max_length=args.max_length,
|
| 827 |
+
temperature=args.temperature,
|
| 828 |
+
top_k=args.top_k,
|
| 829 |
+
top_p=args.top_p,
|
| 830 |
+
)
|
| 831 |
+
|
| 832 |
+
print("\nπ― Generated text:")
|
| 833 |
+
print(f"{generated_text}")
|
| 834 |
+
|
| 835 |
+
except Exception as e:
|
| 836 |
+
print(f"\nβ Error: {e}")
|
| 837 |
+
import traceback
|
| 838 |
+
|
| 839 |
+
traceback.print_exc()
|
| 840 |
+
return False
|
| 841 |
+
|
| 842 |
+
return True
|
| 843 |
+
|
| 844 |
+
|
| 845 |
+
def load_tokenizer(tokenizer_path: str):
|
| 846 |
+
"""
|
| 847 |
+
Load tokenizer for testing purposes.
|
| 848 |
+
|
| 849 |
+
This function is used by tests to load tokenizers without initializing the full generator.
|
| 850 |
+
|
| 851 |
+
Args:
|
| 852 |
+
tokenizer_path: Path to tokenizer model file
|
| 853 |
+
|
| 854 |
+
Returns:
|
| 855 |
+
SentencePieceProcessor: Loaded tokenizer
|
| 856 |
+
"""
|
| 857 |
+
import sentencepiece as spm
|
| 858 |
+
|
| 859 |
+
tokenizer = spm.SentencePieceProcessor()
|
| 860 |
+
tokenizer.load(tokenizer_path)
|
| 861 |
+
return tokenizer
|
| 862 |
+
|
| 863 |
+
|
| 864 |
+
if __name__ == "__main__":
|
| 865 |
+
success = main()
|
| 866 |
+
exit(0 if success else 1)
|
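
For programmatic use, the TextGenerator class above can also be driven directly from Python rather than through the CLI. The following is a minimal sketch, assuming core/src is on the import path and a trained checkpoint exists under models/small-extended-4k (the directory used in the script's own usage examples); the prompt and sampling values are illustrative.

# Minimal usage sketch (assumptions: core/src is importable and the
# checkpoint directory below exists with a matching tokenizer).
import sys

sys.path.insert(0, "core/src")

from generate_text import TextGenerator

generator = TextGenerator("models/small-extended-4k", device="auto")
samples = generator.generate(
    prompt="The history of artificial intelligence",
    max_length=128,
    temperature=0.7,
    top_k=40,
    top_p=0.9,
    num_return_sequences=2,
)
for i, text in enumerate(samples, 1):
    print(f"--- Sample {i} ---\n{text}\n")
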
core/src/inference_server.py
ADDED
|
@@ -0,0 +1,907 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright (C) 2024 Louis Chua Bean Chong
|
| 3 |
+
#
|
| 4 |
+
# This file is part of OpenLLM.
|
| 5 |
+
#
|
| 6 |
+
# OpenLLM is dual-licensed:
|
| 7 |
+
# 1. For open source use: GNU General Public License v3.0
|
| 8 |
+
# 2. For commercial use: Commercial License (contact for details)
|
| 9 |
+
#
|
| 10 |
+
# See LICENSE and docs/LICENSES.md for full license information.
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
OpenLLM Inference Server
|
| 14 |
+
|
| 15 |
+
This script implements the REST API server for OpenLLM model inference
|
| 16 |
+
as specified in Step 6 of the training pipeline.
|
| 17 |
+
|
| 18 |
+
Features:
|
| 19 |
+
- FastAPI-based REST API
|
| 20 |
+
- Support for multiple model formats (PyTorch, Hugging Face, ONNX)
|
| 21 |
+
- Text generation with configurable parameters
|
| 22 |
+
- Health checks and metrics
|
| 23 |
+
- Production-ready deployment
|
| 24 |
+
|
| 25 |
+
Usage:
|
| 26 |
+
python core/src/inference_server.py \
|
| 27 |
+
--model_path exports/huggingface/ \
|
| 28 |
+
--host 0.0.0.0 \
|
| 29 |
+
--port 8000 \
|
| 30 |
+
--max_length 512
|
| 31 |
+
|
| 32 |
+
API Endpoints:
|
| 33 |
+
POST /generate - Generate text from prompt
|
| 34 |
+
GET /health - Health check
|
| 35 |
+
GET /info - Model information
|
| 36 |
+
|
| 37 |
+
Author: Louis Chua Bean Chong
|
| 38 |
+
License: GPLv3
|
| 39 |
+
"""
|
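
Once the server is running, the /generate endpoint accepts a JSON body matching the GenerationRequest model defined below. A minimal client sketch, assuming the server is listening on localhost:8000 (as in the usage above) and that the requests package is installed:

# Client-side sketch (assumes the server is reachable at localhost:8000
# and that the `requests` package is installed).
import requests

payload = {
    "prompt": "The future of AI is",
    "max_length": 128,
    "temperature": 0.7,
    "top_k": 40,
    "top_p": 0.9,
    "num_return_sequences": 1,
}
response = requests.post("http://localhost:8000/generate", json=payload, timeout=60)
response.raise_for_status()
result = response.json()
print(result["generated_text"])   # list of generated sequences
print(result["generation_time"])  # seconds taken on the server
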
| 40 |
+
|
| 41 |
+
import argparse
|
| 42 |
+
import json
|
| 43 |
+
import time
|
| 44 |
+
from pathlib import Path
|
| 45 |
+
from typing import Any, Dict, List, Optional
|
| 46 |
+
|
| 47 |
+
import uvicorn
|
| 48 |
+
|
| 49 |
+
# FastAPI imports (open source)
|
| 50 |
+
try:
|
| 51 |
+
from fastapi import BackgroundTasks, FastAPI, HTTPException
|
| 52 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 53 |
+
from pydantic import BaseModel, Field
|
| 54 |
+
except ImportError:
|
| 55 |
+
raise ImportError("Install FastAPI: pip install fastapi uvicorn[standard]")
|
| 56 |
+
|
| 57 |
+
import os
|
| 58 |
+
|
| 59 |
+
# Import our modules
|
| 60 |
+
import sys
|
| 61 |
+
|
| 62 |
+
import numpy as np
|
| 63 |
+
import sentencepiece as smp
|
| 64 |
+
import torch
|
| 65 |
+
|
| 66 |
+
# Add current directory to path for imports
|
| 67 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 68 |
+
|
| 69 |
+
from model import create_model
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class TextGenerationConfig(BaseModel):
|
| 73 |
+
"""Configuration for text generation parameters."""
|
| 74 |
+
|
| 75 |
+
max_new_tokens: int = Field(
|
| 76 |
+
256, description="Maximum number of tokens to generate", ge=1, le=2048
|
| 77 |
+
)
|
| 78 |
+
temperature: float = Field(0.7, description="Sampling temperature", ge=0.0, le=2.0)
|
| 79 |
+
top_k: Optional[int] = Field(40, description="Top-k sampling parameter", ge=1, le=1000)
|
| 80 |
+
top_p: Optional[float] = Field(0.9, description="Nucleus sampling parameter", ge=0.1, le=1.0)
|
| 81 |
+
num_return_sequences: int = Field(1, description="Number of sequences to generate", ge=1, le=5)
|
| 82 |
+
stop_sequences: Optional[List[str]] = Field(
|
| 83 |
+
None, description="Stop generation at these sequences"
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class GenerationRequest(BaseModel):
|
| 88 |
+
"""Request model for text generation."""
|
| 89 |
+
|
| 90 |
+
prompt: str = Field(..., description="Input text prompt")
|
| 91 |
+
max_length: int = Field(256, description="Maximum generation length", ge=1, le=2048)
|
| 92 |
+
temperature: float = Field(0.7, description="Sampling temperature", ge=0.0, le=2.0)
|
| 93 |
+
top_k: Optional[int] = Field(40, description="Top-k sampling parameter", ge=1, le=1000)
|
| 94 |
+
top_p: Optional[float] = Field(0.9, description="Nucleus sampling parameter", ge=0.1, le=1.0)
|
| 95 |
+
num_return_sequences: int = Field(1, description="Number of sequences to generate", ge=1, le=5)
|
| 96 |
+
stop_sequences: Optional[List[str]] = Field(
|
| 97 |
+
None, description="Stop generation at these sequences"
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class GenerationResponse(BaseModel):
|
| 102 |
+
"""Response model for text generation."""
|
| 103 |
+
|
| 104 |
+
generated_text: List[str] = Field(..., description="Generated text sequences")
|
| 105 |
+
prompt: str = Field(..., description="Original prompt")
|
| 106 |
+
generation_time: float = Field(..., description="Generation time in seconds")
|
| 107 |
+
parameters: Dict[str, Any] = Field(..., description="Generation parameters used")
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class ModelInfo(BaseModel):
|
| 111 |
+
"""Model information response."""
|
| 112 |
+
|
| 113 |
+
model_name: str
|
| 114 |
+
model_size: str
|
| 115 |
+
parameters: int
|
| 116 |
+
vocab_size: int
|
| 117 |
+
max_length: int
|
| 118 |
+
format: str
|
| 119 |
+
loaded_at: str
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class HealthResponse(BaseModel):
|
| 123 |
+
"""Health check response."""
|
| 124 |
+
|
| 125 |
+
status: str
|
| 126 |
+
model_loaded: bool
|
| 127 |
+
uptime_seconds: float
|
| 128 |
+
total_requests: int
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class OpenLLMInference:
|
| 132 |
+
"""
|
| 133 |
+
OpenLLM model inference engine.
|
| 134 |
+
|
| 135 |
+
Supports multiple model formats and provides text generation capabilities.
|
| 136 |
+
"""
|
| 137 |
+
|
| 138 |
+
def __init__(self, model_path: str, model_format: str = "auto"):
|
| 139 |
+
"""
|
| 140 |
+
Initialize inference engine.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
model_path: Path to exported model directory
|
| 144 |
+
model_format: Model format (pytorch, huggingface, onnx, auto)
|
| 145 |
+
"""
|
| 146 |
+
self.model_path = Path(model_path)
|
| 147 |
+
self.model_format = model_format
|
| 148 |
+
self.model = None
|
| 149 |
+
self.tokenizer = None
|
| 150 |
+
self.config = None
|
| 151 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 152 |
+
|
| 153 |
+
# Load model
|
| 154 |
+
self._load_model()
|
| 155 |
+
|
| 156 |
+
# Statistics
|
| 157 |
+
self.loaded_at = time.time()
|
| 158 |
+
self.total_requests = 0
|
| 159 |
+
|
| 160 |
+
print("π OpenLLM Inference Engine initialized")
|
| 161 |
+
print(f" Model: {self.config.get('model_name', 'Unknown')}")
|
| 162 |
+
print(f" Format: {self.detected_format}")
|
| 163 |
+
print(f" Device: {self.device}")
|
| 164 |
+
|
| 165 |
+
def _detect_format(self) -> str:
|
| 166 |
+
"""Auto-detect model format from directory contents."""
|
| 167 |
+
if (self.model_path / "model.pt").exists():
|
| 168 |
+
return "pytorch"
|
| 169 |
+
elif (self.model_path / "pytorch_model.bin").exists():
|
| 170 |
+
return "huggingface"
|
| 171 |
+
elif (self.model_path / "model.onnx").exists():
|
| 172 |
+
return "onnx"
|
| 173 |
+
else:
|
| 174 |
+
raise ValueError(f"Could not detect model format in {self.model_path}")
|
| 175 |
+
|
| 176 |
+
def _load_model(self):
|
| 177 |
+
"""Load model based on detected format."""
|
| 178 |
+
if self.model_format == "auto":
|
| 179 |
+
self.detected_format = self._detect_format()
|
| 180 |
+
else:
|
| 181 |
+
self.detected_format = self.model_format
|
| 182 |
+
|
| 183 |
+
print(f"π Loading {self.detected_format} model from {self.model_path}")
|
| 184 |
+
|
| 185 |
+
if self.detected_format == "pytorch":
|
| 186 |
+
self._load_pytorch_model()
|
| 187 |
+
elif self.detected_format == "huggingface":
|
| 188 |
+
self._load_huggingface_model()
|
| 189 |
+
elif self.detected_format == "onnx":
|
| 190 |
+
self._load_onnx_model()
|
| 191 |
+
else:
|
| 192 |
+
raise ValueError(f"Unsupported format: {self.detected_format}")
|
| 193 |
+
|
| 194 |
+
# Load tokenizer
|
| 195 |
+
self._load_tokenizer()
|
| 196 |
+
|
| 197 |
+
print("β
Model loaded successfully")
|
| 198 |
+
|
| 199 |
+
def _load_pytorch_model(self):
|
| 200 |
+
"""Load PyTorch format model."""
|
| 201 |
+
# Load config
|
| 202 |
+
with open(self.model_path / "config.json", "r") as f:
|
| 203 |
+
config_data = json.load(f)
|
| 204 |
+
|
| 205 |
+
self.config = config_data["model_config"]
|
| 206 |
+
|
| 207 |
+
# Load model
|
| 208 |
+
checkpoint = torch.load(self.model_path / "model.pt", map_location=self.device)
|
| 209 |
+
|
| 210 |
+
# Determine model size
|
| 211 |
+
n_layer = self.config.get("n_layer", 12)
|
| 212 |
+
if n_layer <= 6:
|
| 213 |
+
model_size = "small"
|
| 214 |
+
elif n_layer <= 12:
|
| 215 |
+
model_size = "medium"
|
| 216 |
+
else:
|
| 217 |
+
model_size = "large"
|
| 218 |
+
|
| 219 |
+
# Create model
|
| 220 |
+
self.model = create_model(model_size)
|
| 221 |
+
self.model.load_state_dict(checkpoint["model_state_dict"])
|
| 222 |
+
self.model.to(self.device)
|
| 223 |
+
self.model.eval()
|
| 224 |
+
|
| 225 |
+
def _load_huggingface_model(self):
|
| 226 |
+
"""Load Hugging Face format model."""
|
| 227 |
+
# Load config
|
| 228 |
+
with open(self.model_path / "config.json", "r") as f:
|
| 229 |
+
self.config = json.load(f)
|
| 230 |
+
|
| 231 |
+
# Load model weights
|
| 232 |
+
state_dict = torch.load(self.model_path / "pytorch_model.bin", map_location=self.device)
|
| 233 |
+
|
| 234 |
+
# Determine model size
|
| 235 |
+
n_layer = self.config.get("n_layer", 12)
|
| 236 |
+
if n_layer <= 6:
|
| 237 |
+
model_size = "small"
|
| 238 |
+
elif n_layer <= 12:
|
| 239 |
+
model_size = "medium"
|
| 240 |
+
else:
|
| 241 |
+
model_size = "large"
|
| 242 |
+
|
| 243 |
+
# Create model
|
| 244 |
+
self.model = create_model(model_size)
|
| 245 |
+
self.model.load_state_dict(state_dict)
|
| 246 |
+
self.model.to(self.device)
|
| 247 |
+
self.model.eval()
|
| 248 |
+
|
| 249 |
+
def _load_onnx_model(self):
|
| 250 |
+
"""Load ONNX format model."""
|
| 251 |
+
try:
|
| 252 |
+
import onnxruntime as ort
|
| 253 |
+
except ImportError:
|
| 254 |
+
raise ImportError("ONNX inference requires: pip install onnxruntime")
|
| 255 |
+
|
| 256 |
+
# Security mitigation: Validate model path to prevent arbitrary file access
|
| 257 |
+
model_file = self.model_path / "model.onnx"
|
| 258 |
+
if not model_file.exists():
|
| 259 |
+
raise FileNotFoundError(f"ONNX model not found: {model_file}")
|
| 260 |
+
|
| 261 |
+
# Security mitigation: Validate file is within expected directory
|
| 262 |
+
if not str(model_file).startswith(str(self.model_path)):
|
| 263 |
+
raise ValueError(f"Invalid model path: {model_file}")
|
| 264 |
+
|
| 265 |
+
# Load metadata with path validation
|
| 266 |
+
metadata_file = self.model_path / "metadata.json"
|
| 267 |
+
if not metadata_file.exists():
|
| 268 |
+
raise FileNotFoundError(f"ONNX metadata not found: {metadata_file}")
|
| 269 |
+
|
| 270 |
+
with open(metadata_file, "r") as f:
|
| 271 |
+
metadata = json.load(f)
|
| 272 |
+
|
| 273 |
+
self.config = metadata["model_config"]
|
| 274 |
+
|
| 275 |
+
# Create ONNX session with security options
|
| 276 |
+
providers = (
|
| 277 |
+
["CUDAExecutionProvider", "CPUExecutionProvider"]
|
| 278 |
+
if torch.cuda.is_available()
|
| 279 |
+
else ["CPUExecutionProvider"]
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
# Security mitigation: Use session options to restrict capabilities
|
| 283 |
+
session_options = ort.SessionOptions()
|
| 284 |
+
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
|
| 285 |
+
session_options.enable_mem_pattern = False # Disable memory optimization
|
| 286 |
+
session_options.enable_cpu_mem_arena = False # Disable CPU memory arena
|
| 287 |
+
|
| 288 |
+
self.onnx_session = ort.InferenceSession(
|
| 289 |
+
str(model_file), providers=providers, sess_options=session_options
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
# ONNX models don't need device management
|
| 293 |
+
self.device = "onnx"
|
| 294 |
+
|
| 295 |
+
def _load_tokenizer(self):
|
| 296 |
+
"""Load tokenizer."""
|
| 297 |
+
tokenizer_path = self.model_path / "tokenizer.model"
|
| 298 |
+
if not tokenizer_path.exists():
|
| 299 |
+
raise FileNotFoundError(f"Tokenizer not found: {tokenizer_path}")
|
| 300 |
+
|
| 301 |
+
self.tokenizer = smp.SentencePieceProcessor()
|
| 302 |
+
self.tokenizer.load(str(tokenizer_path))
|
| 303 |
+
|
| 304 |
+
def generate(
|
| 305 |
+
self,
|
| 306 |
+
prompt: str,
|
| 307 |
+
max_length: int = 256,
|
| 308 |
+
temperature: float = 0.7,
|
| 309 |
+
top_k: Optional[int] = 40,
|
| 310 |
+
top_p: Optional[float] = 0.9,
|
| 311 |
+
num_return_sequences: int = 1,
|
| 312 |
+
stop_sequences: Optional[List[str]] = None,
|
| 313 |
+
) -> List[str]:
|
| 314 |
+
"""
|
| 315 |
+
Generate text from prompt.
|
| 316 |
+
|
| 317 |
+
Args:
|
| 318 |
+
prompt: Input text prompt
|
| 319 |
+
max_length: Maximum generation length
|
| 320 |
+
temperature: Sampling temperature
|
| 321 |
+
top_k: Top-k sampling parameter
|
| 322 |
+
top_p: Nucleus sampling parameter
|
| 323 |
+
num_return_sequences: Number of sequences to generate
|
| 324 |
+
stop_sequences: Stop generation at these sequences
|
| 325 |
+
|
| 326 |
+
Returns:
|
| 327 |
+
List of generated text sequences
|
| 328 |
+
"""
|
| 329 |
+
self.total_requests += 1
|
| 330 |
+
|
| 331 |
+
if self.detected_format == "onnx":
|
| 332 |
+
return self._generate_onnx(
|
| 333 |
+
prompt, max_length, temperature, top_k, num_return_sequences, stop_sequences
|
| 334 |
+
)
|
| 335 |
+
else:
|
| 336 |
+
return self._generate_pytorch(
|
| 337 |
+
prompt, max_length, temperature, top_k, top_p, num_return_sequences, stop_sequences
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
def _generate_pytorch(
|
| 341 |
+
self,
|
| 342 |
+
prompt: str,
|
| 343 |
+
max_length: int,
|
| 344 |
+
temperature: float,
|
| 345 |
+
top_k: Optional[int],
|
| 346 |
+
top_p: Optional[float],
|
| 347 |
+
num_return_sequences: int,
|
| 348 |
+
stop_sequences: Optional[List[str]],
|
| 349 |
+
) -> List[str]:
|
| 350 |
+
"""Generate using PyTorch model."""
|
| 351 |
+
# Tokenize prompt
|
| 352 |
+
input_ids = self.tokenizer.encode(prompt)
|
| 353 |
+
input_tensor = torch.tensor(
|
| 354 |
+
[input_ids] * num_return_sequences, dtype=torch.long, device=self.device
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
# Generate
|
| 358 |
+
with torch.no_grad():
|
| 359 |
+
outputs = []
|
| 360 |
+
for _ in range(num_return_sequences):
|
| 361 |
+
# Use model's generate method if available
|
| 362 |
+
if hasattr(self.model, "generate"):
|
| 363 |
+
output = self.model.generate(
|
| 364 |
+
input_tensor[:1], # Single sequence
|
| 365 |
+
max_new_tokens=max_length,
|
| 366 |
+
temperature=temperature,
|
| 367 |
+
top_k=top_k,
|
| 368 |
+
)
|
| 369 |
+
generated_ids = output[0].tolist()
|
| 370 |
+
generated_text = self.tokenizer.decode(generated_ids[len(input_ids) :])
|
| 371 |
+
else:
|
| 372 |
+
# Fallback simple generation
|
| 373 |
+
generated_text = self._simple_generate(
|
| 374 |
+
input_tensor[:1], max_length, temperature
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
# Apply stop sequences
|
| 378 |
+
if stop_sequences:
|
| 379 |
+
for stop_seq in stop_sequences:
|
| 380 |
+
if stop_seq in generated_text:
|
| 381 |
+
generated_text = generated_text.split(stop_seq)[0]
|
| 382 |
+
break
|
| 383 |
+
|
| 384 |
+
outputs.append(generated_text)
|
| 385 |
+
|
| 386 |
+
return outputs
|
| 387 |
+
|
| 388 |
+
def _generate_onnx(
|
| 389 |
+
self,
|
| 390 |
+
prompt: str,
|
| 391 |
+
max_length: int,
|
| 392 |
+
temperature: float,
|
| 393 |
+
top_k: Optional[int],
|
| 394 |
+
num_return_sequences: int,
|
| 395 |
+
stop_sequences: Optional[List[str]],
|
| 396 |
+
) -> List[str]:
|
| 397 |
+
"""Generate using ONNX model."""
|
| 398 |
+
outputs = []
|
| 399 |
+
|
| 400 |
+
for _ in range(num_return_sequences):
|
| 401 |
+
# Tokenize prompt
|
| 402 |
+
tokens = self.tokenizer.encode(prompt)
|
| 403 |
+
generated = tokens.copy()
|
| 404 |
+
|
| 405 |
+
# Simple autoregressive generation
|
| 406 |
+
for _ in range(max_length):
|
| 407 |
+
if len(generated) >= 512: # Max sequence length for ONNX
|
| 408 |
+
break
|
| 409 |
+
|
| 410 |
+
# Prepare input (last 64 tokens to fit ONNX model)
|
| 411 |
+
current_input = np.array([generated[-64:]], dtype=np.int64)
|
| 412 |
+
|
| 413 |
+
# Run inference
|
| 414 |
+
logits = self.onnx_session.run(None, {"input_ids": current_input})[0]
|
| 415 |
+
next_token_logits = logits[0, -1, :]
|
| 416 |
+
|
| 417 |
+
# Apply temperature
|
| 418 |
+
if temperature > 0:
|
| 419 |
+
next_token_logits = next_token_logits / temperature
|
| 420 |
+
probs = np.exp(next_token_logits) / np.sum(np.exp(next_token_logits))
|
| 421 |
+
|
| 422 |
+
# Apply top-k if specified
|
| 423 |
+
if top_k:
|
| 424 |
+
top_indices = np.argpartition(probs, -top_k)[-top_k:]
|
| 425 |
+
probs_filtered = np.zeros_like(probs)
|
| 426 |
+
probs_filtered[top_indices] = probs[top_indices]
|
| 427 |
+
probs = probs_filtered / np.sum(probs_filtered)
|
| 428 |
+
|
| 429 |
+
next_token = np.random.choice(len(probs), p=probs)
|
| 430 |
+
else:
|
| 431 |
+
next_token = np.argmax(next_token_logits)
|
| 432 |
+
|
| 433 |
+
generated.append(int(next_token))
|
| 434 |
+
|
| 435 |
+
# Decode generated text
|
| 436 |
+
generated_text = self.tokenizer.decode(generated[len(tokens) :])
|
| 437 |
+
|
| 438 |
+
# Apply stop sequences
|
| 439 |
+
if stop_sequences:
|
| 440 |
+
for stop_seq in stop_sequences:
|
| 441 |
+
if stop_seq in generated_text:
|
| 442 |
+
generated_text = generated_text.split(stop_seq)[0]
|
| 443 |
+
break
|
| 444 |
+
|
| 445 |
+
outputs.append(generated_text)
|
| 446 |
+
|
| 447 |
+
return outputs
|
| 448 |
+
|
| 449 |
+
def _simple_generate(
|
| 450 |
+
self, input_tensor: torch.Tensor, max_length: int, temperature: float
|
| 451 |
+
) -> str:
|
| 452 |
+
"""Simple fallback generation method."""
|
| 453 |
+
generated = input_tensor[0].tolist()
|
| 454 |
+
|
| 455 |
+
for _ in range(max_length):
|
| 456 |
+
if len(generated) >= self.config.get("block_size", 1024):
|
| 457 |
+
break
|
| 458 |
+
|
| 459 |
+
# Forward pass
|
| 460 |
+
current_input = torch.tensor([generated], dtype=torch.long, device=self.device)
|
| 461 |
+
with torch.no_grad():
|
| 462 |
+
logits, _ = self.model(current_input)
|
| 463 |
+
|
| 464 |
+
# Get next token logits and apply temperature
|
| 465 |
+
next_token_logits = logits[0, -1, :] / temperature
|
| 466 |
+
probs = torch.softmax(next_token_logits, dim=-1)
|
| 467 |
+
next_token = torch.multinomial(probs, num_samples=1).item()
|
| 468 |
+
|
| 469 |
+
generated.append(next_token)
|
| 470 |
+
|
| 471 |
+
# Decode only the generated part
|
| 472 |
+
original_length = input_tensor.size(1)
|
| 473 |
+
generated_tokens = generated[original_length:]
|
| 474 |
+
return self.tokenizer.decode(generated_tokens)
|
| 475 |
+
|
| 476 |
+
def get_info(self) -> Dict[str, Any]:
|
| 477 |
+
"""Get model information."""
|
| 478 |
+
return {
|
| 479 |
+
"model_name": self.config.get("model_name", "OpenLLM"),
|
| 480 |
+
"model_size": self.config.get("model_size", "unknown"),
|
| 481 |
+
"parameters": self.config.get("n_embd", 0)
|
| 482 |
+
* self.config.get("n_layer", 0), # Approximate
|
| 483 |
+
"vocab_size": self.config.get("vocab_size", self.tokenizer.vocab_size()),
|
| 484 |
+
"max_length": self.config.get("block_size", 1024),
|
| 485 |
+
"format": self.detected_format,
|
| 486 |
+
"loaded_at": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.loaded_at)),
|
| 487 |
+
}
|
| 488 |
+
|
| 489 |
+
def get_health(self) -> Dict[str, Any]:
|
| 490 |
+
"""Get health status."""
|
| 491 |
+
return {
|
| 492 |
+
"status": "healthy",
|
| 493 |
+
"model_loaded": self.model is not None,
|
| 494 |
+
"uptime_seconds": time.time() - self.loaded_at,
|
| 495 |
+
"total_requests": self.total_requests,
|
| 496 |
+
}
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
# Global inference engine
|
| 500 |
+
inference_engine: Optional[OpenLLMInference] = None
|
| 501 |
+
|
| 502 |
+
# FastAPI app
|
| 503 |
+
app = FastAPI(
|
| 504 |
+
title="OpenLLM Inference API",
|
| 505 |
+
description="REST API for OpenLLM text generation",
|
| 506 |
+
version="0.1.0",
|
| 507 |
+
docs_url="/docs",
|
| 508 |
+
redoc_url="/redoc",
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
# CORS middleware
|
| 512 |
+
app.add_middleware(
|
| 513 |
+
CORSMiddleware,
|
| 514 |
+
allow_origins=["*"], # Configure appropriately for production
|
| 515 |
+
allow_credentials=True,
|
| 516 |
+
allow_methods=["*"],
|
| 517 |
+
allow_headers=["*"],
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
@app.on_event("startup")
|
| 522 |
+
async def startup_event():
|
| 523 |
+
"""Initialize inference engine on startup."""
|
| 524 |
+
print("π Starting OpenLLM Inference Server...")
|
| 525 |
+
# Note: Model loading is handled in main() function
|
| 526 |
+
# For testing, we'll create a mock model if none exists
|
| 527 |
+
global inference_engine
|
| 528 |
+
if inference_engine is None:
|
| 529 |
+
print("β οΈ No model loaded - server will return 503 for generation requests")
|
| 530 |
+
print(" Use main() function to load a real model")
|
| 531 |
+
print(" For testing, use load_model_for_testing() function")
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
@app.post("/generate", response_model=GenerationResponse)
|
| 535 |
+
async def generate_text(request: GenerationRequest, background_tasks: BackgroundTasks):
|
| 536 |
+
"""Generate text from prompt."""
|
| 537 |
+
if inference_engine is None:
|
| 538 |
+
raise HTTPException(status_code=503, detail="Model not loaded")
|
| 539 |
+
|
| 540 |
+
start_time = time.time()
|
| 541 |
+
|
| 542 |
+
try:
|
| 543 |
+
# Generate text
|
| 544 |
+
generated_texts = inference_engine.generate(
|
| 545 |
+
prompt=request.prompt,
|
| 546 |
+
max_length=request.max_length,
|
| 547 |
+
temperature=request.temperature,
|
| 548 |
+
top_k=request.top_k,
|
| 549 |
+
top_p=request.top_p,
|
| 550 |
+
num_return_sequences=request.num_return_sequences,
|
| 551 |
+
stop_sequences=request.stop_sequences,
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
generation_time = time.time() - start_time
|
| 555 |
+
|
| 556 |
+
return GenerationResponse(
|
| 557 |
+
generated_text=generated_texts,
|
| 558 |
+
prompt=request.prompt,
|
| 559 |
+
generation_time=generation_time,
|
| 560 |
+
parameters={
|
| 561 |
+
"max_length": request.max_length,
|
| 562 |
+
"temperature": request.temperature,
|
| 563 |
+
"top_k": request.top_k,
|
| 564 |
+
"top_p": request.top_p,
|
| 565 |
+
"num_return_sequences": request.num_return_sequences,
|
| 566 |
+
},
|
| 567 |
+
)
|
| 568 |
+
|
| 569 |
+
except Exception as e:
|
| 570 |
+
raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
@app.post("/generate/stream")
|
| 574 |
+
async def generate_text_stream(request: GenerationRequest):
|
| 575 |
+
"""Generate text with streaming response."""
|
| 576 |
+
if inference_engine is None:
|
| 577 |
+
raise HTTPException(status_code=503, detail="Model not loaded")
|
| 578 |
+
|
| 579 |
+
try:
|
| 580 |
+
# For now, return a simple streaming response
|
| 581 |
+
# In a real implementation, this would stream tokens as they're generated
|
| 582 |
+
generated_texts = inference_engine.generate(
|
| 583 |
+
prompt=request.prompt,
|
| 584 |
+
max_length=request.max_length,
|
| 585 |
+
temperature=request.temperature,
|
| 586 |
+
top_k=request.top_k,
|
| 587 |
+
top_p=request.top_p,
|
| 588 |
+
num_return_sequences=request.num_return_sequences,
|
| 589 |
+
stop_sequences=request.stop_sequences,
|
| 590 |
+
)
|
| 591 |
+
|
| 592 |
+
# Return as streaming response
|
| 593 |
+
return {
|
| 594 |
+
"generated_text": generated_texts,
|
| 595 |
+
"prompt": request.prompt,
|
| 596 |
+
"streaming": True,
|
| 597 |
+
}
|
| 598 |
+
|
| 599 |
+
except Exception as e:
|
| 600 |
+
raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
@app.get("/info", response_model=ModelInfo)
|
| 604 |
+
async def get_model_info():
|
| 605 |
+
"""Get model information."""
|
| 606 |
+
if inference_engine is None:
|
| 607 |
+
raise HTTPException(status_code=503, detail="Model not loaded")
|
| 608 |
+
|
| 609 |
+
info = inference_engine.get_info()
|
| 610 |
+
return ModelInfo(**info)
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
@app.get("/health", response_model=HealthResponse)
|
| 614 |
+
async def health_check():
|
| 615 |
+
"""Health check endpoint."""
|
| 616 |
+
if inference_engine is None:
|
| 617 |
+
return HealthResponse(
|
| 618 |
+
status="unhealthy", model_loaded=False, uptime_seconds=0.0, total_requests=0
|
| 619 |
+
)
|
| 620 |
+
|
| 621 |
+
health = inference_engine.get_health()
|
| 622 |
+
return HealthResponse(**health)
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
@app.get("/")
|
| 626 |
+
async def root():
|
| 627 |
+
"""Root endpoint."""
|
| 628 |
+
return {
|
| 629 |
+
"message": "OpenLLM Inference API",
|
| 630 |
+
"version": "0.1.0",
|
| 631 |
+
"docs": "/docs",
|
| 632 |
+
"health": "/health",
|
| 633 |
+
"info": "/info",
|
| 634 |
+
"endpoints": ["/generate", "/generate/stream", "/health", "/info"],
|
| 635 |
+
}
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
def main():
|
| 639 |
+
"""Main server function."""
|
| 640 |
+
parser = argparse.ArgumentParser(
|
| 641 |
+
description="OpenLLM Inference Server",
|
| 642 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 643 |
+
epilog="""
|
| 644 |
+
Examples:
|
| 645 |
+
# Start server with Hugging Face model
|
| 646 |
+
python core/src/inference_server.py \\
|
| 647 |
+
--model_path exports/huggingface/ \\
|
| 648 |
+
--host 0.0.0.0 \\
|
| 649 |
+
--port 8000
|
| 650 |
+
|
| 651 |
+
# Start server with ONNX model
|
| 652 |
+
python core/src/inference_server.py \\
|
| 653 |
+
--model_path exports/onnx/ \\
|
| 654 |
+
--format onnx \\
|
| 655 |
+
--port 8001
|
| 656 |
+
""",
|
| 657 |
+
)
|
| 658 |
+
|
| 659 |
+
parser.add_argument(
|
| 660 |
+
"--model_path",
|
| 661 |
+
required=True,
|
| 662 |
+
help="Path to exported model directory",
|
| 663 |
+
)
|
| 664 |
+
|
| 665 |
+
parser.add_argument(
|
| 666 |
+
"--format",
|
| 667 |
+
choices=["pytorch", "huggingface", "onnx", "auto"],
|
| 668 |
+
default="auto",
|
| 669 |
+
help="Model format (default: auto-detect)",
|
| 670 |
+
)
|
| 671 |
+
|
| 672 |
+
parser.add_argument(
|
| 673 |
+
"--host",
|
| 674 |
+
default="127.0.0.1",
|
| 675 |
+
help="Host to bind to (default: 127.0.0.1)",
|
| 676 |
+
)
|
| 677 |
+
|
| 678 |
+
parser.add_argument(
|
| 679 |
+
"--port",
|
| 680 |
+
type=int,
|
| 681 |
+
default=8000,
|
| 682 |
+
help="Port to bind to (default: 8000)",
|
| 683 |
+
)
|
| 684 |
+
|
| 685 |
+
parser.add_argument(
|
| 686 |
+
"--max_length",
|
| 687 |
+
type=int,
|
| 688 |
+
default=512,
|
| 689 |
+
help="Maximum generation length (default: 512)",
|
| 690 |
+
)
|
| 691 |
+
|
| 692 |
+
args = parser.parse_args()
|
| 693 |
+
|
| 694 |
+
# Initialize inference engine
|
| 695 |
+
global inference_engine
|
| 696 |
+
inference_engine = OpenLLMInference(args.model_path, args.format)
|
| 697 |
+
|
| 698 |
+
# Start server
|
| 699 |
+
print(f"π Starting server on {args.host}:{args.port}")
|
| 700 |
+
uvicorn.run(
|
| 701 |
+
app,
|
| 702 |
+
host=args.host,
|
| 703 |
+
port=args.port,
|
| 704 |
+
log_level="info",
|
| 705 |
+
)
|
| 706 |
+
|
| 707 |
+
|
| 708 |
+
def load_model(model_path: str, model_format: str = "auto"):
|
| 709 |
+
"""
|
| 710 |
+
Load model for testing purposes.
|
| 711 |
+
|
| 712 |
+
This function is used by tests to load models without starting the full server.
|
| 713 |
+
|
| 714 |
+
Args:
|
| 715 |
+
model_path: Path to exported model directory
|
| 716 |
+
model_format: Model format (pytorch, huggingface, onnx, auto)
|
| 717 |
+
|
| 718 |
+
Returns:
|
| 719 |
+
OpenLLMInference: Initialized inference engine
|
| 720 |
+
"""
|
| 721 |
+
return OpenLLMInference(model_path, model_format)
|
| 722 |
+
|
| 723 |
+
|
| 724 |
+
def load_model_for_testing(
|
| 725 |
+
model_path: str = "exports/huggingface", model_format: str = "huggingface"
|
| 726 |
+
):
|
| 727 |
+
"""
|
| 728 |
+
Load a real model for testing purposes.
|
| 729 |
+
|
| 730 |
+
This function loads the actual trained model for testing.
|
| 731 |
+
|
| 732 |
+
Args:
|
| 733 |
+
model_path: Path to the model directory (default: exports/huggingface)
|
| 734 |
+
model_format: Model format (default: huggingface)
|
| 735 |
+
|
| 736 |
+
Returns:
|
| 737 |
+
OpenLLMInference: Real inference engine with loaded model
|
| 738 |
+
"""
|
| 739 |
+
global inference_engine
|
| 740 |
+
try:
|
| 741 |
+
inference_engine = OpenLLMInference(model_path, model_format)
|
| 742 |
+
print(f"β
Real model loaded for testing from {model_path}")
|
| 743 |
+
return inference_engine
|
| 744 |
+
except Exception as e:
|
| 745 |
+
print(f"β Failed to load real model: {e}")
|
| 746 |
+
# Fallback to mock model for testing
|
| 747 |
+
return create_test_model()
|
| 748 |
+
|
| 749 |
+
|
| 750 |
+
def create_test_model():
|
| 751 |
+
"""
|
| 752 |
+
Create a real lightweight test model for testing purposes.
|
| 753 |
+
|
| 754 |
+
This creates a real model with minimal parameters for testing,
|
| 755 |
+
without requiring large model files to be downloaded.
|
| 756 |
+
|
| 757 |
+
Returns:
|
| 758 |
+
OpenLLMInference: Real lightweight inference engine
|
| 759 |
+
"""
|
| 760 |
+
try:
|
| 761 |
+
# Create a real model with minimal parameters
|
| 762 |
+
import sentencepiece as smp
|
| 763 |
+
from model import GPTConfig, GPTModel
|
| 764 |
+
|
| 765 |
+
# Create minimal config for testing
|
| 766 |
+
config = GPTConfig.small()
|
| 767 |
+
config.n_embd = 128 # Very small for testing
|
| 768 |
+
config.n_layer = 2 # Very small for testing
|
| 769 |
+
config.vocab_size = 1000 # Small vocabulary
|
| 770 |
+
config.block_size = 64 # Small context
|
| 771 |
+
|
| 772 |
+
# Create real model
|
| 773 |
+
model = GPTModel(config)
|
| 774 |
+
model.eval()
|
| 775 |
+
|
| 776 |
+
# Create minimal tokenizer
|
| 777 |
+
class MinimalTokenizer:
|
| 778 |
+
def __init__(self):
|
| 779 |
+
self.vocab_size = 1000
|
| 780 |
+
|
| 781 |
+
def encode(self, text):
|
| 782 |
+
# Simple character-based encoding for testing
|
| 783 |
+
return [ord(c) % 1000 for c in text[:50]] # Limit to 50 chars
|
| 784 |
+
|
| 785 |
+
def decode(self, tokens):
|
| 786 |
+
# Simple character-based decoding for testing
|
| 787 |
+
return "".join([chr(t % 256) for t in tokens if t < 256])
|
| 788 |
+
|
| 789 |
+
def vocab_size(self):
|
| 790 |
+
return 1000
|
| 791 |
+
|
| 792 |
+
# Create real inference engine with lightweight model
|
| 793 |
+
class LightweightInferenceEngine:
|
| 794 |
+
def __init__(self):
|
| 795 |
+
self.model = model
|
| 796 |
+
self.tokenizer = MinimalTokenizer()
|
| 797 |
+
self.config = {
|
| 798 |
+
"model_name": "openllm-small-test",
|
| 799 |
+
"model_size": "small",
|
| 800 |
+
"n_embd": config.n_embd,
|
| 801 |
+
"n_layer": config.n_layer,
|
| 802 |
+
"vocab_size": config.vocab_size,
|
| 803 |
+
"block_size": config.block_size,
|
| 804 |
+
}
|
| 805 |
+
self.detected_format = "pytorch"
|
| 806 |
+
self.device = "cpu"
|
| 807 |
+
self.loaded_at = time.time()
|
| 808 |
+
self.total_requests = 0
|
| 809 |
+
|
| 810 |
+
def generate(self, prompt, max_length=10, temperature=0.7, **kwargs):
|
| 811 |
+
"""Real text generation with lightweight model."""
|
| 812 |
+
self.total_requests += 1
|
| 813 |
+
|
| 814 |
+
# Tokenize input
|
| 815 |
+
input_ids = self.tokenizer.encode(prompt)
|
| 816 |
+
if len(input_ids) == 0:
|
| 817 |
+
input_ids = [1] # Default token
|
| 818 |
+
|
| 819 |
+
# Simple autoregressive generation
|
| 820 |
+
generated = input_ids.copy()
|
| 821 |
+
for _ in range(max_length):
|
| 822 |
+
if len(generated) >= self.config["block_size"]:
|
| 823 |
+
break
|
| 824 |
+
|
| 825 |
+
# Create input tensor
|
| 826 |
+
input_tensor = torch.tensor([generated], dtype=torch.long)
|
| 827 |
+
|
| 828 |
+
# Forward pass
|
| 829 |
+
with torch.no_grad():
|
| 830 |
+
logits, _ = self.model(input_tensor)
|
| 831 |
+
|
| 832 |
+
# Get next token
|
| 833 |
+
next_token_logits = logits[0, -1, :] / temperature
|
| 834 |
+
probs = torch.softmax(next_token_logits, dim=-1)
|
| 835 |
+
next_token = torch.multinomial(probs, num_samples=1).item()
|
| 836 |
+
|
| 837 |
+
generated.append(next_token)
|
| 838 |
+
|
| 839 |
+
# Decode generated text
|
| 840 |
+
generated_text = self.tokenizer.decode(generated[len(input_ids) :])
|
| 841 |
+
return [generated_text]
|
| 842 |
+
|
| 843 |
+
def get_info(self):
|
| 844 |
+
"""Get real model information."""
|
| 845 |
+
return {
|
| 846 |
+
"model_name": "openllm-small-test",
|
| 847 |
+
"model_size": "small",
|
| 848 |
+
"parameters": config.n_embd * config.n_layer * 1000,
|
| 849 |
+
"vocab_size": config.vocab_size,
|
| 850 |
+
"max_length": config.block_size,
|
| 851 |
+
"format": "pytorch",
|
| 852 |
+
"loaded_at": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.loaded_at)),
|
| 853 |
+
}
|
| 854 |
+
|
| 855 |
+
def get_health(self):
|
| 856 |
+
"""Get real health status."""
|
| 857 |
+
return {
|
| 858 |
+
"status": "healthy",
|
| 859 |
+
"model_loaded": True,
|
| 860 |
+
"uptime_seconds": time.time() - self.loaded_at,
|
| 861 |
+
"total_requests": self.total_requests,
|
| 862 |
+
}
|
| 863 |
+
|
| 864 |
+
return LightweightInferenceEngine()
|
| 865 |
+
|
| 866 |
+
except Exception as e:
|
| 867 |
+
print(f"β οΈ Failed to create lightweight model: {e}")
|
| 868 |
+
|
| 869 |
+
# Fallback to simple mock if real model creation fails
|
| 870 |
+
class SimpleMockInferenceEngine:
|
| 871 |
+
def __init__(self):
|
| 872 |
+
self.model = "simple_mock"
|
| 873 |
+
self.tokenizer = "simple_mock"
|
| 874 |
+
self.config = {"model_name": "fallback-model"}
|
| 875 |
+
self.detected_format = "pytorch"
|
| 876 |
+
self.device = "cpu"
|
| 877 |
+
self.loaded_at = time.time()
|
| 878 |
+
self.total_requests = 0
|
| 879 |
+
|
| 880 |
+
def generate(self, prompt, **kwargs):
|
| 881 |
+
self.total_requests += 1
|
| 882 |
+
return [f"Generated: {prompt[:10]}..."]
|
| 883 |
+
|
| 884 |
+
def get_info(self):
|
| 885 |
+
return {
|
| 886 |
+
"model_name": "fallback-model",
|
| 887 |
+
"model_size": "small",
|
| 888 |
+
"parameters": 1000,
|
| 889 |
+
"vocab_size": 1000,
|
| 890 |
+
"max_length": 100,
|
| 891 |
+
"format": "pytorch",
|
| 892 |
+
"loaded_at": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.loaded_at)),
|
| 893 |
+
}
|
| 894 |
+
|
| 895 |
+
def get_health(self):
|
| 896 |
+
return {
|
| 897 |
+
"status": "healthy",
|
| 898 |
+
"model_loaded": True,
|
| 899 |
+
"uptime_seconds": time.time() - self.loaded_at,
|
| 900 |
+
"total_requests": self.total_requests,
|
| 901 |
+
}
|
| 902 |
+
|
| 903 |
+
return SimpleMockInferenceEngine()
|
| 904 |
+
|
| 905 |
+
|
| 906 |
+
if __name__ == "__main__":
|
| 907 |
+
main()
|
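For reference, once the server defined above is running (for example via `python core/src/inference_server.py --model_path exports/huggingface/ --port 8000`, as in the epilog examples), it can be exercised with any HTTP client. Below is a minimal client sketch using the `requests` package; it assumes the default host/port, an installed `requests`, and a loaded model (otherwise `/generate` returns HTTP 503). The prompt and sampling values are illustrative only; the field names mirror the request/response models used by the `/generate` endpoint above.

# Minimal client sketch for the /generate endpoint defined above.
# Assumes `pip install requests` and a server on 127.0.0.1:8000 with a model loaded.
import requests

payload = {
    "prompt": "The history of artificial intelligence",
    "max_length": 64,            # maximum number of new tokens
    "temperature": 0.7,          # sampling temperature
    "top_k": 40,                 # top-k sampling
    "top_p": 0.9,                # nucleus sampling
    "num_return_sequences": 1,   # completions to return
}

resp = requests.post("http://127.0.0.1:8000/generate", json=payload, timeout=60)
resp.raise_for_status()
result = resp.json()

for text in result["generated_text"]:
    print(text)
print(f"generation_time: {result['generation_time']:.2f} s")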
core/src/main.py
ADDED
|
@@ -0,0 +1,842 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright (C) 2024 Louis Chua Bean Chong
|
| 3 |
+
#
|
| 4 |
+
# This file is part of OpenLLM.
|
| 5 |
+
#
|
| 6 |
+
# OpenLLM is dual-licensed:
|
| 7 |
+
# 1. For open source use: GNU General Public License v3.0
|
| 8 |
+
# 2. For commercial use: Commercial License (contact for details)
|
| 9 |
+
#
|
| 10 |
+
# See LICENSE and docs/LICENSES.md for full license information.
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
OpenLLM - Main CLI Entry Point
|
| 14 |
+
|
| 15 |
+
This module provides a unified command-line interface for all OpenLLM operations
|
| 16 |
+
including data preparation, tokenizer training, model training, and inference.
|
| 17 |
+
|
| 18 |
+
Usage:
|
| 19 |
+
python core/src/main.py <command> [options]
|
| 20 |
+
|
| 21 |
+
Available Commands:
|
| 22 |
+
prepare-data Download and prepare training data from SQUAD dataset
|
| 23 |
+
train-tokenizer Train a SentencePiece tokenizer on the prepared data
|
| 24 |
+
test-model Test and validate model architecture
|
| 25 |
+
train-model Train the language model
|
| 26 |
+
inference Run model inference using a trained model
|
| 27 |
+
evaluate Evaluate model performance on benchmarks and metrics
|
| 28 |
+
|
| 29 |
+
Examples:
|
| 30 |
+
# Full pipeline
|
| 31 |
+
python core/src/main.py prepare-data
|
| 32 |
+
python core/src/main.py train-tokenizer --vocab-size 32000
|
| 33 |
+
python core/src/main.py test-model --model-size small
|
| 34 |
+
python core/src/main.py train-model --model-size small --output-dir models/my-model
|
| 35 |
+
|
| 36 |
+
# Help for specific commands
|
| 37 |
+
python core/src/main.py train-model --help
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
import argparse
|
| 41 |
+
import os
|
| 42 |
+
import sys
|
| 43 |
+
from pathlib import Path
|
| 44 |
+
|
| 45 |
+
# Set console encoding for Windows compatibility
|
| 46 |
+
if sys.platform == "win32":
|
| 47 |
+
import codecs
|
| 48 |
+
sys.stdout = codecs.getwriter("utf-8")(sys.stdout.detach())
|
| 49 |
+
sys.stderr = codecs.getwriter("utf-8")(sys.stderr.detach())
|
| 50 |
+
|
| 51 |
+
# Add the current directory to Python path for imports
|
| 52 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 53 |
+
|
| 54 |
+
try:
|
| 55 |
+
from download_and_prepare import prepare_training_data
|
| 56 |
+
from model_test import ModelTester
|
| 57 |
+
from train_tokenizer import (
|
| 58 |
+
count_training_sentences,
|
| 59 |
+
save_huggingface_config,
|
| 60 |
+
test_tokenizer,
|
| 61 |
+
train_sentencepiece_tokenizer,
|
| 62 |
+
validate_input_file,
|
| 63 |
+
)
|
| 64 |
+
except ImportError as e:
|
| 65 |
+
print(f"Error importing modules: {e}")
|
| 66 |
+
print("Make sure you're running this from the correct directory.")
|
| 67 |
+
sys.exit(1)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def cmd_prepare_data(args):
|
| 71 |
+
"""Execute data preparation command."""
|
| 72 |
+
print("ποΈ Starting data preparation...")
|
| 73 |
+
print(f"Output path: {args.output}")
|
| 74 |
+
print(f"Minimum words per passage: {args.min_words}")
|
| 75 |
+
|
| 76 |
+
try:
|
| 77 |
+
prepare_training_data(output_path=args.output, min_words=args.min_words)
|
| 78 |
+
print("β
Data preparation completed successfully!")
|
| 79 |
+
return True
|
| 80 |
+
except Exception as e:
|
| 81 |
+
print(f"β Data preparation failed: {e}")
|
| 82 |
+
return False
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def cmd_train_tokenizer(args):
|
| 86 |
+
"""Execute tokenizer training command."""
|
| 87 |
+
print("π€ Starting tokenizer training...")
|
| 88 |
+
print(f"Input: {args.input}")
|
| 89 |
+
print(f"Output directory: {args.output_dir}")
|
| 90 |
+
print(f"Vocabulary size: {args.vocab_size:,}")
|
| 91 |
+
print(f"Model type: {args.model_type}")
|
| 92 |
+
|
| 93 |
+
try:
|
| 94 |
+
# Step 1: Validate input
|
| 95 |
+
validate_input_file(args.input)
|
| 96 |
+
|
| 97 |
+
# Step 2: Count training data
|
| 98 |
+
sentence_count = count_training_sentences(args.input)
|
| 99 |
+
|
| 100 |
+
# Step 3: Train tokenizer
|
| 101 |
+
config = train_sentencepiece_tokenizer(
|
| 102 |
+
input_path=args.input,
|
| 103 |
+
output_dir=args.output_dir,
|
| 104 |
+
vocab_size=args.vocab_size,
|
| 105 |
+
model_type=args.model_type,
|
| 106 |
+
character_coverage=args.character_coverage,
|
| 107 |
+
max_sentence_length=args.max_sentence_length,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
# Step 4: Save Hugging Face config
|
| 111 |
+
save_huggingface_config(args.output_dir, config)
|
| 112 |
+
|
| 113 |
+
# Step 5: Test tokenizer (unless skipped)
|
| 114 |
+
if not args.no_test:
|
| 115 |
+
model_path = os.path.join(args.output_dir, "tokenizer.model")
|
| 116 |
+
test_tokenizer(model_path)
|
| 117 |
+
|
| 118 |
+
print("β
Tokenizer training completed successfully!")
|
| 119 |
+
print(f"π Output: {args.output_dir}")
|
| 120 |
+
print(f"π Vocabulary size: {config['vocab_size']:,}")
|
| 121 |
+
print(f"π Training sentences: {sentence_count:,}")
|
| 122 |
+
return True
|
| 123 |
+
|
| 124 |
+
except Exception as e:
|
| 125 |
+
print(f"β Tokenizer training failed: {e}")
|
| 126 |
+
return False
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def cmd_train_model(args):
|
| 130 |
+
"""Execute model training command."""
|
| 131 |
+
print("ποΈ Starting model training...")
|
| 132 |
+
|
| 133 |
+
try:
|
| 134 |
+
import os
|
| 135 |
+
|
| 136 |
+
import torch
|
| 137 |
+
from data_loader import TextDataLoader
|
| 138 |
+
from train_model import ModelTrainer, create_model
|
| 139 |
+
|
| 140 |
+
# Determine device
|
| 141 |
+
if args.device == "auto":
|
| 142 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 143 |
+
else:
|
| 144 |
+
device = args.device
|
| 145 |
+
|
| 146 |
+
print(f"Device: {device}")
|
| 147 |
+
|
| 148 |
+
# Create model
|
| 149 |
+
print(f"Creating {args.model_size} model...")
|
| 150 |
+
model = create_model(args.model_size)
|
| 151 |
+
|
| 152 |
+
# Create data loader
|
| 153 |
+
print("Setting up data loader...")
|
| 154 |
+
tokenizer_path = os.path.join(args.tokenizer_dir, "tokenizer.model")
|
| 155 |
+
|
| 156 |
+
if not os.path.exists(tokenizer_path):
|
| 157 |
+
print(f"β Tokenizer not found at {tokenizer_path}")
|
| 158 |
+
print(
|
| 159 |
+
"Please run: python core/src/main.py train-tokenizer --input data/clean/training_data.txt"
|
| 160 |
+
)
|
| 161 |
+
return False
|
| 162 |
+
|
| 163 |
+
data_loader = TextDataLoader(
|
| 164 |
+
data_file=args.data_file,
|
| 165 |
+
tokenizer_path=tokenizer_path,
|
| 166 |
+
seq_len=args.seq_len,
|
| 167 |
+
batch_size=args.batch_size,
|
| 168 |
+
shuffle=True,
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
# Get data statistics
|
| 172 |
+
_ = data_loader.get_data_stats()
|
| 173 |
+
|
| 174 |
+
# Create trainer
|
| 175 |
+
print("Setting up trainer...")
|
| 176 |
+
trainer = ModelTrainer(
|
| 177 |
+
model=model,
|
| 178 |
+
data_loader=data_loader,
|
| 179 |
+
output_dir=args.output_dir,
|
| 180 |
+
device=device,
|
| 181 |
+
learning_rate=args.learning_rate,
|
| 182 |
+
max_steps=args.max_steps,
|
| 183 |
+
warmup_steps=args.warmup_steps,
|
| 184 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 185 |
+
save_every=args.save_every,
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
# Resume from checkpoint if specified
|
| 189 |
+
if args.resume:
|
| 190 |
+
trainer._load_checkpoint(args.resume)
|
| 191 |
+
|
| 192 |
+
# Start training
|
| 193 |
+
trainer.train()
|
| 194 |
+
|
| 195 |
+
return True
|
| 196 |
+
|
| 197 |
+
except Exception as e:
|
| 198 |
+
print(f"β Training failed: {e}")
|
| 199 |
+
import traceback
|
| 200 |
+
|
| 201 |
+
traceback.print_exc()
|
| 202 |
+
return False
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def cmd_inference(args):
|
| 206 |
+
"""
|
| 207 |
+
Execute model inference command.
|
| 208 |
+
|
| 209 |
+
This function implements text generation using trained OpenLLM models.
|
| 210 |
+
It supports multiple model formats and provides flexible generation options.
|
| 211 |
+
|
| 212 |
+
Args:
|
| 213 |
+
args: Namespace containing CLI arguments including:
|
| 214 |
+
- model_path: Path to trained model directory
|
| 215 |
+
- prompt: Input text prompt for generation
|
| 216 |
+
- max_length: Maximum number of tokens to generate
|
| 217 |
+
- temperature: Sampling temperature (0.1-2.0)
|
| 218 |
+
- format: Model format (auto-detect by default)
|
| 219 |
+
|
| 220 |
+
Returns:
|
| 221 |
+
bool: True if inference succeeded, False otherwise
|
| 222 |
+
|
| 223 |
+
Implementation Details:
|
| 224 |
+
- Auto-detects model format (PyTorch, Hugging Face, ONNX)
|
| 225 |
+
- Uses inference_server.py's OpenLLMInference class for generation
|
| 226 |
+
- Supports configurable generation parameters
|
| 227 |
+
- Handles errors gracefully with informative messages
|
| 228 |
+
"""
|
| 229 |
+
print("π OpenLLM Model Inference")
|
| 230 |
+
print("=" * 40)
|
| 231 |
+
|
| 232 |
+
try:
|
| 233 |
+
# Import inference functionality
|
| 234 |
+
# We import here to avoid circular imports and handle missing dependencies
|
| 235 |
+
import os
|
| 236 |
+
import sys
|
| 237 |
+
|
| 238 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 239 |
+
|
| 240 |
+
from inference_server import OpenLLMInference
|
| 241 |
+
|
| 242 |
+
# Validate model path exists
|
| 243 |
+
# Early validation prevents confusing error messages later
|
| 244 |
+
model_path = Path(args.model_path)
|
| 245 |
+
if not model_path.exists():
|
| 246 |
+
print(f"β Model path not found: {args.model_path}")
|
| 247 |
+
print(" Please check the path and try again.")
|
| 248 |
+
return False
|
| 249 |
+
|
| 250 |
+
# Initialize inference engine
|
| 251 |
+
# This handles model loading and format detection automatically
|
| 252 |
+
print(f"π Loading model from: {args.model_path}")
|
| 253 |
+
inference_engine = OpenLLMInference(
|
| 254 |
+
model_path=str(model_path),
|
| 255 |
+
model_format=getattr(args, "format", "auto"), # Default to auto-detection
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
# Prepare generation parameters
|
| 259 |
+
# These parameters control the quality and style of generated text
|
| 260 |
+
generation_params = {
|
| 261 |
+
"max_length": args.max_length,
|
| 262 |
+
"temperature": getattr(args, "temperature", 0.7), # Default temperature
|
| 263 |
+
"top_k": getattr(args, "top_k", 40), # Default top-k
|
| 264 |
+
"top_p": getattr(args, "top_p", 0.9), # Default nucleus sampling
|
| 265 |
+
"num_return_sequences": getattr(args, "num_sequences", 1), # Default single sequence
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
print(f"π Generating text for prompt: '{args.prompt}'")
|
| 269 |
+
print(
|
| 270 |
+
f"βοΈ Parameters: max_length={generation_params['max_length']}, "
|
| 271 |
+
f"temperature={generation_params['temperature']}"
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
# Generate text using the inference engine
|
| 275 |
+
# This is the core functionality that produces the output
|
| 276 |
+
import time
|
| 277 |
+
|
| 278 |
+
start_time = time.time()
|
| 279 |
+
|
| 280 |
+
generated_texts = inference_engine.generate(prompt=args.prompt, **generation_params)
|
| 281 |
+
|
| 282 |
+
generation_time = time.time() - start_time
|
| 283 |
+
|
| 284 |
+
# Display results with formatting
|
| 285 |
+
# Clear presentation helps users understand the output
|
| 286 |
+
print("\n⨠Generated Text:")
|
| 287 |
+
print("-" * 50)
|
| 288 |
+
|
| 289 |
+
for i, text in enumerate(generated_texts, 1):
|
| 290 |
+
if len(generated_texts) > 1:
|
| 291 |
+
print(f"\n[Sequence {i}]")
|
| 292 |
+
print(text)
|
| 293 |
+
|
| 294 |
+
print("-" * 50)
|
| 295 |
+
print(f"β±οΈ Generation time: {generation_time:.2f} seconds")
|
| 296 |
+
print(f"π Tokens generated: ~{len(generated_texts[0].split())}")
|
| 297 |
+
print(f"π― Model: {inference_engine.config.get('model_name', 'OpenLLM')}")
|
| 298 |
+
|
| 299 |
+
return True
|
| 300 |
+
|
| 301 |
+
except ImportError as e:
|
| 302 |
+
print(f"β Missing dependencies for inference: {e}")
|
| 303 |
+
print(" Please install: pip install fastapi uvicorn")
|
| 304 |
+
return False
|
| 305 |
+
|
| 306 |
+
except Exception as e:
|
| 307 |
+
print(f"β Inference failed: {e}")
|
| 308 |
+
import traceback
|
| 309 |
+
|
| 310 |
+
traceback.print_exc()
|
| 311 |
+
return False
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def cmd_evaluate(args):
|
| 315 |
+
"""
|
| 316 |
+
Execute model evaluation command.
|
| 317 |
+
|
| 318 |
+
This function implements comprehensive model evaluation including intrinsic
|
| 319 |
+
metrics (perplexity) and downstream task performance assessment.
|
| 320 |
+
|
| 321 |
+
Args:
|
| 322 |
+
args: Namespace containing CLI arguments including:
|
| 323 |
+
- model_path: Path to trained model directory
|
| 324 |
+
- eval_data: Path to evaluation dataset (optional)
|
| 325 |
+
- metrics: Comma-separated list of metrics to compute
|
| 326 |
+
- output_dir: Directory to save evaluation results
|
| 327 |
+
- format: Model format (auto-detect by default)
|
| 328 |
+
|
| 329 |
+
Returns:
|
| 330 |
+
bool: True if evaluation succeeded, False otherwise
|
| 331 |
+
|
| 332 |
+
Implementation Details:
|
| 333 |
+
- Uses evaluate_model.py's ModelEvaluator class for comprehensive testing
|
| 334 |
+
- Computes perplexity on held-out data if provided
|
| 335 |
+
- Runs downstream task evaluation (reading comprehension, sentiment, etc.)
|
| 336 |
+
- Generates detailed evaluation report with metrics and examples
|
| 337 |
+
- Saves results to JSON file for further analysis
|
| 338 |
+
"""
|
| 339 |
+
print("π OpenLLM Model Evaluation")
|
| 340 |
+
print("=" * 40)
|
| 341 |
+
|
| 342 |
+
try:
|
| 343 |
+
# Import evaluation functionality
|
| 344 |
+
# We import here to avoid circular imports and handle missing dependencies
|
| 345 |
+
import json
|
| 346 |
+
import os
|
| 347 |
+
import sys
|
| 348 |
+
from pathlib import Path
|
| 349 |
+
|
| 350 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 351 |
+
|
| 352 |
+
from evaluate_model import ModelEvaluator
|
| 353 |
+
|
| 354 |
+
# Validate model path exists
|
| 355 |
+
# Early validation prevents confusing error messages later
|
| 356 |
+
model_path = Path(args.model_path)
|
| 357 |
+
if not model_path.exists():
|
| 358 |
+
print(f"β Model path not found: {args.model_path}")
|
| 359 |
+
print(" Please check the path and try again.")
|
| 360 |
+
return False
|
| 361 |
+
|
| 362 |
+
# Determine output directory for results
|
| 363 |
+
# Create output directory if it doesn't exist
|
| 364 |
+
output_dir = Path(getattr(args, "output_dir", "evaluation_results"))
|
| 365 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 366 |
+
|
| 367 |
+
# Parse requested metrics
|
| 368 |
+
# Default to comprehensive evaluation if not specified
|
| 369 |
+
requested_metrics = getattr(args, "metrics", "perplexity,generation,downstream").split(",")
|
| 370 |
+
requested_metrics = [m.strip() for m in requested_metrics]
|
| 371 |
+
|
| 372 |
+
print(f"π Loading model from: {args.model_path}")
|
| 373 |
+
print(f"π Requested metrics: {', '.join(requested_metrics)}")
|
| 374 |
+
print(f"πΎ Results will be saved to: {output_dir}")
|
| 375 |
+
|
| 376 |
+
# Initialize model evaluator
|
| 377 |
+
# This handles model loading and tokenizer setup
|
| 378 |
+
evaluator = ModelEvaluator(
|
| 379 |
+
model_dir=str(model_path),
|
| 380 |
+
tokenizer_path=getattr(args, "tokenizer_path", None), # Auto-detect if not provided
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
# Prepare evaluation results container
|
| 384 |
+
# This will store all evaluation metrics and examples
|
| 385 |
+
evaluation_results = {
|
| 386 |
+
"model_info": {
|
| 387 |
+
"model_path": str(model_path),
|
| 388 |
+
"model_name": evaluator.config.get("model_name", "OpenLLM"),
|
| 389 |
+
"parameters": evaluator.model.get_num_params(),
|
| 390 |
+
"evaluation_time": None,
|
| 391 |
+
},
|
| 392 |
+
"metrics": {},
|
| 393 |
+
"examples": {},
|
| 394 |
+
"summary": {},
|
| 395 |
+
}
|
| 396 |
+
|
| 397 |
+
import time
|
| 398 |
+
|
| 399 |
+
start_time = time.time()
|
| 400 |
+
|
| 401 |
+
# 1. Perplexity Evaluation
|
| 402 |
+
# This measures how well the model predicts the next token
|
| 403 |
+
if "perplexity" in requested_metrics:
|
| 404 |
+
print("\nπ Computing perplexity...")
|
| 405 |
+
|
| 406 |
+
eval_data_path = getattr(args, "eval_data", None)
|
| 407 |
+
if eval_data_path and Path(eval_data_path).exists():
|
| 408 |
+
# Use provided evaluation data
|
| 409 |
+
perplexity_result = evaluator.evaluate_perplexity(eval_data_path)
|
| 410 |
+
else:
|
| 411 |
+
# Use a subset of training data for perplexity calculation
|
| 412 |
+
print(" No eval data provided, using default test set")
|
| 413 |
+
perplexity_result = evaluator.evaluate_perplexity()
|
| 414 |
+
|
| 415 |
+
evaluation_results["metrics"]["perplexity"] = perplexity_result
|
| 416 |
+
|
| 417 |
+
print(f" β
Perplexity: {perplexity_result.get('perplexity', 'N/A'):.2f}")
|
| 418 |
+
print(f" π Loss: {perplexity_result.get('loss', 'N/A'):.4f}")
|
| 419 |
+
|
| 420 |
+
# 2. Text Generation Quality Assessment
|
| 421 |
+
# This evaluates the coherence and quality of generated text
|
| 422 |
+
if "generation" in requested_metrics:
|
| 423 |
+
print("\nβοΈ Evaluating text generation quality...")
|
| 424 |
+
|
| 425 |
+
generation_result = evaluator.evaluate_text_generation()
|
| 426 |
+
evaluation_results["metrics"]["generation"] = generation_result
|
| 427 |
+
evaluation_results["examples"]["generation"] = generation_result.get("examples", [])
|
| 428 |
+
|
| 429 |
+
print(
|
| 430 |
+
f" β
Average quality score: {generation_result.get('average_quality', 'N/A'):.2f}"
|
| 431 |
+
)
|
| 432 |
+
print(f" π Generated {len(generation_result.get('examples', []))} examples")
|
| 433 |
+
|
| 434 |
+
# 3. Downstream Task Evaluation
|
| 435 |
+
# This tests specific capabilities like reading comprehension
|
| 436 |
+
if "downstream" in requested_metrics:
|
| 437 |
+
print("\nπ― Evaluating downstream tasks...")
|
| 438 |
+
|
| 439 |
+
downstream_result = evaluator.evaluate_downstream_tasks()
|
| 440 |
+
evaluation_results["metrics"]["downstream"] = downstream_result
|
| 441 |
+
evaluation_results["examples"]["downstream"] = {
|
| 442 |
+
task: result.get("examples", []) for task, result in downstream_result.items()
|
| 443 |
+
}
|
| 444 |
+
|
| 445 |
+
# Display summary of downstream results
|
| 446 |
+
for task_name, task_result in downstream_result.items():
|
| 447 |
+
accuracy = task_result.get("accuracy", 0) * 100
|
| 448 |
+
print(f" β
{task_name.replace('_', ' ').title()}: {accuracy:.1f}%")
|
| 449 |
+
|
| 450 |
+
# Calculate total evaluation time
|
| 451 |
+
evaluation_time = time.time() - start_time
|
| 452 |
+
evaluation_results["model_info"]["evaluation_time"] = evaluation_time
|
| 453 |
+
|
| 454 |
+
# Generate evaluation summary
|
| 455 |
+
# This provides a high-level overview of model performance
|
| 456 |
+
summary = {
|
| 457 |
+
"overall_score": 0.0, # Will be calculated based on available metrics
|
| 458 |
+
"strengths": [],
|
| 459 |
+
"weaknesses": [],
|
| 460 |
+
"recommendations": [],
|
| 461 |
+
}
|
| 462 |
+
|
| 463 |
+
# Calculate overall score based on available metrics
|
| 464 |
+
scores = []
|
| 465 |
+
|
| 466 |
+
if "perplexity" in evaluation_results["metrics"]:
|
| 467 |
+
ppl = evaluation_results["metrics"]["perplexity"].get("perplexity", float("inf"))
|
| 468 |
+
# Convert perplexity to 0-100 score (lower perplexity is better)
|
| 469 |
+
ppl_score = max(0, 100 - (ppl - 10) * 5) # Rough conversion
|
| 470 |
+
scores.append(ppl_score)
|
| 471 |
+
|
| 472 |
+
if ppl < 15:
|
| 473 |
+
summary["strengths"].append("Good language modeling (low perplexity)")
|
| 474 |
+
else:
|
| 475 |
+
summary["weaknesses"].append("High perplexity indicates poor language modeling")
|
| 476 |
+
|
| 477 |
+
if "generation" in evaluation_results["metrics"]:
|
| 478 |
+
gen_score = evaluation_results["metrics"]["generation"].get("average_quality", 0) * 100
|
| 479 |
+
scores.append(gen_score)
|
| 480 |
+
|
| 481 |
+
if gen_score > 70:
|
| 482 |
+
summary["strengths"].append("High-quality text generation")
|
| 483 |
+
else:
|
| 484 |
+
summary["weaknesses"].append("Text generation needs improvement")
|
| 485 |
+
|
| 486 |
+
if "downstream" in evaluation_results["metrics"]:
|
| 487 |
+
downstream_scores = []
|
| 488 |
+
for task_result in evaluation_results["metrics"]["downstream"].values():
|
| 489 |
+
downstream_scores.append(task_result.get("accuracy", 0) * 100)
|
| 490 |
+
|
| 491 |
+
if downstream_scores:
|
| 492 |
+
avg_downstream = sum(downstream_scores) / len(downstream_scores)
|
| 493 |
+
scores.append(avg_downstream)
|
| 494 |
+
|
| 495 |
+
if avg_downstream > 50:
|
| 496 |
+
summary["strengths"].append("Good performance on downstream tasks")
|
| 497 |
+
else:
|
| 498 |
+
summary["weaknesses"].append("Poor downstream task performance")
|
| 499 |
+
|
| 500 |
+
# Calculate overall score
|
| 501 |
+
if scores:
|
| 502 |
+
summary["overall_score"] = sum(scores) / len(scores)
|
| 503 |
+
|
| 504 |
+
# Add recommendations based on performance
|
| 505 |
+
if summary["overall_score"] < 40:
|
| 506 |
+
summary["recommendations"].extend(
|
| 507 |
+
[
|
| 508 |
+
"Consider training for more steps",
|
| 509 |
+
"Verify training data quality",
|
| 510 |
+
"Check model architecture and hyperparameters",
|
| 511 |
+
]
|
| 512 |
+
)
|
| 513 |
+
elif summary["overall_score"] < 70:
|
| 514 |
+
summary["recommendations"].extend(
|
| 515 |
+
[
|
| 516 |
+
"Model shows promise - consider extended training",
|
| 517 |
+
"Fine-tune on specific downstream tasks",
|
| 518 |
+
]
|
| 519 |
+
)
|
| 520 |
+
else:
|
| 521 |
+
summary["recommendations"].append("Model performs well - ready for deployment")
|
| 522 |
+
|
| 523 |
+
evaluation_results["summary"] = summary
|
| 524 |
+
|
| 525 |
+
# Save detailed results to file
|
| 526 |
+
# This allows for further analysis and comparison between models
|
| 527 |
+
results_file = output_dir / f"evaluation_results_{int(time.time())}.json"
|
| 528 |
+
with open(results_file, "w") as f:
|
| 529 |
+
json.dump(evaluation_results, f, indent=2, default=str)
|
| 530 |
+
|
| 531 |
+
# Display comprehensive results summary
|
| 532 |
+
print("\n" + "=" * 60)
|
| 533 |
+
print("π EVALUATION SUMMARY")
|
| 534 |
+
print("=" * 60)
|
| 535 |
+
print(f"π― Overall Score: {summary['overall_score']:.1f}/100")
|
| 536 |
+
print(f"β±οΈ Evaluation Time: {evaluation_time:.1f} seconds")
|
| 537 |
+
|
| 538 |
+
if summary["strengths"]:
|
| 539 |
+
print("\nβ
Strengths:")
|
| 540 |
+
for strength in summary["strengths"]:
|
| 541 |
+
print(f" β’ {strength}")
|
| 542 |
+
|
| 543 |
+
if summary["weaknesses"]:
|
| 544 |
+
print("\nβ οΈ Areas for Improvement:")
|
| 545 |
+
for weakness in summary["weaknesses"]:
|
| 546 |
+
print(f" β’ {weakness}")
|
| 547 |
+
|
| 548 |
+
if summary["recommendations"]:
|
| 549 |
+
print("\nπ‘ Recommendations:")
|
| 550 |
+
for rec in summary["recommendations"]:
|
| 551 |
+
print(f" β’ {rec}")
|
| 552 |
+
|
| 553 |
+
print(f"\nπΎ Detailed results saved to: {results_file}")
|
| 554 |
+
print("π Evaluation completed successfully!")
|
| 555 |
+
|
| 556 |
+
return True
|
| 557 |
+
|
| 558 |
+
except ImportError as e:
|
| 559 |
+
print(f"β Missing dependencies for evaluation: {e}")
|
| 560 |
+
print(" Please check that all required packages are installed.")
|
| 561 |
+
return False
|
| 562 |
+
|
| 563 |
+
except Exception as e:
|
| 564 |
+
print(f"β Evaluation failed: {e}")
|
| 565 |
+
import traceback
|
| 566 |
+
|
| 567 |
+
traceback.print_exc()
|
| 568 |
+
return False
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
def cmd_test_model(args):
|
| 572 |
+
"""Execute model testing command."""
|
| 573 |
+
print("π§ͺ Testing model architecture...")
|
| 574 |
+
|
| 575 |
+
try:
|
| 576 |
+
# Initialize model tester
|
| 577 |
+
tester = ModelTester(device=args.device)
|
| 578 |
+
|
| 579 |
+
if args.all_sizes:
|
| 580 |
+
# Test all model sizes
|
| 581 |
+
test_sizes = ["small", "medium", "large"]
|
| 582 |
+
all_success = True
|
| 583 |
+
|
| 584 |
+
for size in test_sizes:
|
| 585 |
+
print(f"\n{'='*20} Testing {size.upper()} Model {'='*20}")
|
| 586 |
+
results = tester.run_comprehensive_test(size)
|
| 587 |
+
|
| 588 |
+
if not results["initialization"]["success"]:
|
| 589 |
+
all_success = False
|
| 590 |
+
print(f"β {size.upper()} model failed initialization")
|
| 591 |
+
else:
|
| 592 |
+
print(f"β {size.upper()} model passed all tests")
|
| 593 |
+
|
| 594 |
+
return all_success
|
| 595 |
+
else:
|
| 596 |
+
# Test single model size
|
| 597 |
+
results = tester.run_comprehensive_test(args.model_size)
|
| 598 |
+
|
| 599 |
+
if args.save_results:
|
| 600 |
+
import json
|
| 601 |
+
|
| 602 |
+
with open(args.save_results, "w") as f:
|
| 603 |
+
json.dump(results, f, indent=2)
|
| 604 |
+
print(f"\nπΎ Results saved to {args.save_results}")
|
| 605 |
+
|
| 606 |
+
return results["initialization"]["success"]
|
| 607 |
+
|
| 608 |
+
except Exception as e:
|
| 609 |
+
print(f"β Model testing failed: {e}")
|
| 610 |
+
return False
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
def create_parser():
|
| 614 |
+
"""Create the main argument parser with subcommands."""
|
| 615 |
+
parser = argparse.ArgumentParser(
|
| 616 |
+
description="OpenLLM - Open Source Large Language Model Training Pipeline",
|
| 617 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 618 |
+
epilog="""
|
| 619 |
+
Examples:
|
| 620 |
+
# Prepare training data from SQUAD dataset
|
| 621 |
+
python core/src/main.py prepare-data --output data/clean/training_data.txt
|
| 622 |
+
|
| 623 |
+
# Train tokenizer with custom settings
|
| 624 |
+
python core/src/main.py train-tokenizer \\
|
| 625 |
+
--input data/clean/training_data.txt \\
|
| 626 |
+
--vocab-size 32000 \\
|
| 627 |
+
--output-dir data/tokenizer/
|
| 628 |
+
|
| 629 |
+
# Get help for specific commands
|
| 630 |
+
python core/src/main.py train-tokenizer --help
|
| 631 |
+
""",
|
| 632 |
+
)
|
| 633 |
+
|
| 634 |
+
parser.add_argument("--version", action="version", version="OpenLLM v0.1.0")
|
| 635 |
+
|
| 636 |
+
# Create subparsers for different commands
|
| 637 |
+
subparsers = parser.add_subparsers(dest="command", help="Available commands", required=True)
|
| 638 |
+
|
| 639 |
+
# Data preparation command
|
| 640 |
+
parser_data = subparsers.add_parser(
|
| 641 |
+
"prepare-data",
|
| 642 |
+
help="Download and prepare training data from SQUAD dataset",
|
| 643 |
+
description="Downloads SQUAD v1.1 and v2.0 datasets, extracts Wikipedia passages, and prepares clean training text.",
|
| 644 |
+
)
|
| 645 |
+
parser_data.add_argument(
|
| 646 |
+
"--output",
|
| 647 |
+
default="data/clean/training_data.txt",
|
| 648 |
+
help="Output path for cleaned training data (default: data/clean/training_data.txt)",
|
| 649 |
+
)
|
| 650 |
+
parser_data.add_argument(
|
| 651 |
+
"--min-words",
|
| 652 |
+
type=int,
|
| 653 |
+
default=10,
|
| 654 |
+
help="Minimum number of words per passage (default: 10)",
|
| 655 |
+
)
|
| 656 |
+
parser_data.set_defaults(func=cmd_prepare_data)
|
| 657 |
+
|
| 658 |
+
# Tokenizer training command
|
| 659 |
+
parser_tokenizer = subparsers.add_parser(
|
| 660 |
+
"train-tokenizer",
|
| 661 |
+
help="Train a SentencePiece tokenizer on prepared data",
|
| 662 |
+
description="Trains a BPE or Unigram tokenizer using SentencePiece on the prepared training text.",
|
| 663 |
+
)
|
| 664 |
+
parser_tokenizer.add_argument("--input", required=True, help="Path to training text file")
|
| 665 |
+
parser_tokenizer.add_argument(
|
| 666 |
+
"--vocab-size", type=int, default=32000, help="Vocabulary size (default: 32000)"
|
| 667 |
+
)
|
| 668 |
+
parser_tokenizer.add_argument(
|
| 669 |
+
"--model-type",
|
| 670 |
+
choices=["bpe", "unigram"],
|
| 671 |
+
default="bpe",
|
| 672 |
+
help="Tokenization algorithm (default: bpe)",
|
| 673 |
+
)
|
| 674 |
+
parser_tokenizer.add_argument(
|
| 675 |
+
"--output-dir",
|
| 676 |
+
default="data/tokenizer/",
|
| 677 |
+
help="Output directory for tokenizer files (default: data/tokenizer/)",
|
| 678 |
+
)
|
| 679 |
+
parser_tokenizer.add_argument(
|
| 680 |
+
"--character-coverage",
|
| 681 |
+
type=float,
|
| 682 |
+
default=0.9995,
|
| 683 |
+
help="Character coverage (default: 0.9995)",
|
| 684 |
+
)
|
| 685 |
+
parser_tokenizer.add_argument(
|
| 686 |
+
"--max-sentence-length",
|
| 687 |
+
type=int,
|
| 688 |
+
default=4192,
|
| 689 |
+
help="Maximum sentence length (default: 4192)",
|
| 690 |
+
)
|
| 691 |
+
parser_tokenizer.add_argument(
|
| 692 |
+
"--no-test", action="store_true", help="Skip tokenizer testing after training"
|
| 693 |
+
)
|
| 694 |
+
parser_tokenizer.set_defaults(func=cmd_train_tokenizer)
|
| 695 |
+
|
| 696 |
+
# Model testing command
|
| 697 |
+
parser_test = subparsers.add_parser(
|
| 698 |
+
"test-model",
|
| 699 |
+
help="Test and validate model architecture",
|
| 700 |
+
description="Test model initialization, forward pass, memory usage, and tokenizer integration.",
|
| 701 |
+
)
|
| 702 |
+
parser_test.add_argument(
|
| 703 |
+
"--model-size",
|
| 704 |
+
choices=["small", "medium", "large"],
|
| 705 |
+
default="medium",
|
| 706 |
+
help="Model size to test (default: medium)",
|
| 707 |
+
)
|
| 708 |
+
parser_test.add_argument("--all-sizes", action="store_true", help="Test all model sizes")
|
| 709 |
+
parser_test.add_argument(
|
| 710 |
+
"--device",
|
| 711 |
+
choices=["cpu", "cuda", "auto"],
|
| 712 |
+
default="auto",
|
| 713 |
+
help="Device to use for testing (default: auto)",
|
| 714 |
+
)
|
| 715 |
+
parser_test.add_argument("--save-results", help="Save test results to JSON file")
|
| 716 |
+
parser_test.set_defaults(func=cmd_test_model)
|
| 717 |
+
|
| 718 |
+
# Model training command
|
| 719 |
+
parser_model = subparsers.add_parser(
|
| 720 |
+
"train-model",
|
| 721 |
+
help="Train the language model",
|
| 722 |
+
description="Train a GPT-style transformer language model on tokenized text.",
|
| 723 |
+
)
|
| 724 |
+
parser_model.add_argument(
|
| 725 |
+
"--model-size",
|
| 726 |
+
choices=["small", "medium", "large"],
|
| 727 |
+
default="small",
|
| 728 |
+
help="Model size to train (default: small)",
|
| 729 |
+
)
|
| 730 |
+
parser_model.add_argument(
|
| 731 |
+
"--tokenizer-dir",
|
| 732 |
+
default="data/tokenizer/",
|
| 733 |
+
help="Path to trained tokenizer directory (default: data/tokenizer/)",
|
| 734 |
+
)
|
| 735 |
+
parser_model.add_argument(
|
| 736 |
+
"--data-file",
|
| 737 |
+
default="data/clean/training_data.txt",
|
| 738 |
+
help="Path to training text file (default: data/clean/training_data.txt)",
|
| 739 |
+
)
|
| 740 |
+
parser_model.add_argument(
|
| 741 |
+
"--output-dir", required=True, help="Output directory for model checkpoints"
|
| 742 |
+
)
|
| 743 |
+
parser_model.add_argument(
|
| 744 |
+
"--seq-len", type=int, default=512, help="Sequence length for training (default: 512)"
|
| 745 |
+
)
|
| 746 |
+
parser_model.add_argument(
|
| 747 |
+
"--batch-size", type=int, default=4, help="Batch size (default: 4, reduce for low memory)"
|
| 748 |
+
)
|
| 749 |
+
parser_model.add_argument(
|
| 750 |
+
"--learning-rate", type=float, default=3e-4, help="Learning rate (default: 3e-4)"
|
| 751 |
+
)
|
| 752 |
+
parser_model.add_argument(
|
| 753 |
+
"--max-steps", type=int, default=10000, help="Maximum training steps (default: 10000)"
|
| 754 |
+
)
|
| 755 |
+
parser_model.add_argument(
|
| 756 |
+
"--warmup-steps", type=int, default=1000, help="Warmup steps (default: 1000)"
|
| 757 |
+
)
|
| 758 |
+
parser_model.add_argument(
|
| 759 |
+
"--gradient-accumulation-steps",
|
| 760 |
+
type=int,
|
| 761 |
+
default=4,
|
| 762 |
+
help="Gradient accumulation steps (default: 4)",
|
| 763 |
+
)
|
| 764 |
+
parser_model.add_argument(
|
| 765 |
+
"--device",
|
| 766 |
+
choices=["cpu", "cuda", "auto"],
|
| 767 |
+
default="auto",
|
| 768 |
+
help="Training device (default: auto)",
|
| 769 |
+
)
|
| 770 |
+
parser_model.add_argument("--resume", help="Path to checkpoint to resume training from")
|
| 771 |
+
parser_model.add_argument(
|
| 772 |
+
"--save-every", type=int, default=1000, help="Save checkpoint every N steps (default: 1000)"
|
| 773 |
+
)
|
| 774 |
+
parser_model.set_defaults(func=cmd_train_model)
|
| 775 |
+
|
| 776 |
+
# Inference command
|
| 777 |
+
parser_inference = subparsers.add_parser(
|
| 778 |
+
"inference",
|
| 779 |
+
help="Run model inference (coming soon)",
|
| 780 |
+
description="Generate text using a trained model.",
|
| 781 |
+
)
|
| 782 |
+
parser_inference.add_argument("--model-path", required=True, help="Path to trained model")
|
| 783 |
+
parser_inference.add_argument("--prompt", required=True, help="Input text prompt")
|
| 784 |
+
parser_inference.add_argument(
|
| 785 |
+
"--max-length", type=int, default=256, help="Maximum generation length"
|
| 786 |
+
)
|
| 787 |
+
parser_inference.set_defaults(func=cmd_inference)
|
| 788 |
+
|
| 789 |
+
# Evaluation command
|
| 790 |
+
parser_eval = subparsers.add_parser(
|
| 791 |
+
"evaluate",
|
| 792 |
+
help="Evaluate model performance (coming soon)",
|
| 793 |
+
description="Evaluate model on various benchmarks and metrics.",
|
| 794 |
+
)
|
| 795 |
+
parser_eval.add_argument("--model-path", required=True, help="Path to trained model")
|
| 796 |
+
parser_eval.add_argument("--eval-data", help="Path to evaluation dataset")
|
| 797 |
+
parser_eval.add_argument(
|
| 798 |
+
"--metrics", nargs="+", default=["perplexity"], help="Metrics to compute"
|
| 799 |
+
)
|
| 800 |
+
parser_eval.set_defaults(func=cmd_evaluate)
|
| 801 |
+
|
| 802 |
+
# --- Optional: Enterprise module integration ---
|
| 803 |
+
# Load enterprise-only CLI commands if an external module is available.
|
| 804 |
+
# This preserves the core's open-source nature while allowing private
|
| 805 |
+
# extensions to register additional commands without modifying core code.
|
| 806 |
+
try:
|
| 807 |
+
from enterprise_integration import load_enterprise_cli
|
| 808 |
+
|
| 809 |
+
if load_enterprise_cli(subparsers):
|
| 810 |
+
print("π§© Enterprise extensions detected and loaded")
|
| 811 |
+
else:
|
| 812 |
+
# No enterprise plugin found (normal for open-source-only usage)
|
| 813 |
+
pass
|
| 814 |
+
except Exception as e:
|
| 815 |
+
# Never fail core CLI due to enterprise integration issues
|
| 816 |
+
print(f"Warning: Enterprise integration failed: {e}")
|
| 817 |
+
|
| 818 |
+
return parser
|
| 819 |
+
|
| 820 |
+
|
| 821 |
+
def main():
|
| 822 |
+
"""Main entry point for the OpenLLM CLI."""
|
| 823 |
+
parser = create_parser()
|
| 824 |
+
args = parser.parse_args()
|
| 825 |
+
|
| 826 |
+
print("π OpenLLM - Open Source Large Language Model")
|
| 827 |
+
print("=" * 60)
|
| 828 |
+
|
| 829 |
+
# Execute the selected command
|
| 830 |
+
if not hasattr(args, "func"):
    parser.print_help()
    sys.exit(1)

success = args.func(args)
|
| 831 |
+
|
| 832 |
+
# Exit with appropriate code
|
| 833 |
+
if success:
|
| 834 |
+
print("\nπ Command completed successfully!")
|
| 835 |
+
sys.exit(0)
|
| 836 |
+
else:
|
| 837 |
+
print("\nβ Command failed or not implemented yet.")
|
| 838 |
+
sys.exit(1)
|
| 839 |
+
|
| 840 |
+
|
| 841 |
+
if __name__ == "__main__":
|
| 842 |
+
main()
|
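All of the subcommands above route through the same `set_defaults(func=...)` dispatch that `main()` invokes via `args.func(args)`. A minimal sketch of exercising that dispatch without a shell, passing an argv list straight to `create_parser()`; note the training subcommand's name is registered earlier in this file and is only assumed here to be `train-model`, and the paths are placeholders:

```python
# Illustrative sketch only: drives the CLI defined above programmatically.
# "train-model" and the directories below are assumptions, not values taken
# from this diff; substitute the real subcommand name and your own paths.
from main import create_parser  # assumes core/src is on sys.path

parser = create_parser()
args = parser.parse_args([
    "train-model",
    "--output-dir", "checkpoints/run1",  # --output-dir is required above
    "--batch-size", "2",                 # smaller batch for low-memory machines
    "--max-steps", "500",
])
success = args.func(args)                # dispatches to cmd_train_model
raise SystemExit(0 if success else 1)
```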
core/src/mixed_precision.py
ADDED
|
@@ -0,0 +1,220 @@
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Mixed Precision Training Utilities
|
| 4 |
+
|
| 5 |
+
This module provides utilities for mixed precision training using PyTorch's
|
| 6 |
+
automatic mixed precision (AMP) to improve training speed and reduce memory usage.
|
| 7 |
+
|
| 8 |
+
Author: Louis Chua Bean Chong
|
| 9 |
+
License: GPLv3
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from torch.cuda.amp import autocast, GradScaler
|
| 15 |
+
from typing import Optional, Callable
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class MixedPrecisionTrainer:
|
| 19 |
+
"""
|
| 20 |
+
Mixed precision training wrapper for improved performance.
|
| 21 |
+
|
| 22 |
+
This class provides automatic mixed precision training capabilities
|
| 23 |
+
that can significantly improve training speed and reduce memory usage
|
| 24 |
+
on compatible hardware (especially NVIDIA GPUs with Tensor Cores).
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(self,
|
| 28 |
+
model: nn.Module,
|
| 29 |
+
optimizer: torch.optim.Optimizer,
|
| 30 |
+
device: str = "auto",
|
| 31 |
+
dtype: torch.dtype = torch.float16,
|
| 32 |
+
enabled: bool = True):
|
| 33 |
+
"""
|
| 34 |
+
Initialize mixed precision trainer.
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
model: The model to train
|
| 38 |
+
optimizer: The optimizer to use
|
| 39 |
+
device: Device to use ("auto", "cpu", "cuda")
|
| 40 |
+
dtype: Precision dtype (float16, bfloat16)
|
| 41 |
+
enabled: Whether to enable mixed precision
|
| 42 |
+
"""
|
| 43 |
+
self.model = model
|
| 44 |
+
self.optimizer = optimizer
|
| 45 |
+
self.device = self._get_device(device)
|
| 46 |
+
self.dtype = dtype
|
| 47 |
+
self.enabled = enabled and self.device.type == "cuda"
|
| 48 |
+
|
| 49 |
+
# Initialize gradient scaler for mixed precision
|
| 50 |
+
self.scaler = GradScaler() if self.enabled else None
|
| 51 |
+
|
| 52 |
+
# Move model to device
|
| 53 |
+
self.model.to(self.device)
|
| 54 |
+
|
| 55 |
+
print(f"Mixed Precision Training: {'Enabled' if self.enabled else 'Disabled'}")
|
| 56 |
+
print(f"Device: {self.device}")
|
| 57 |
+
print(f"Precision: {self.dtype}")
|
| 58 |
+
|
| 59 |
+
def _get_device(self, device: str) -> torch.device:
|
| 60 |
+
"""Get the appropriate device."""
|
| 61 |
+
if device == "auto":
|
| 62 |
+
if torch.cuda.is_available():
|
| 63 |
+
return torch.device("cuda")
|
| 64 |
+
else:
|
| 65 |
+
return torch.device("cpu")
|
| 66 |
+
else:
|
| 67 |
+
return torch.device(device)
|
| 68 |
+
|
| 69 |
+
def train_step(self,
|
| 70 |
+
batch: torch.Tensor,
|
| 71 |
+
targets: torch.Tensor,
|
| 72 |
+
loss_fn: Optional[Callable] = None) -> dict:
|
| 73 |
+
"""
|
| 74 |
+
Perform a single training step with mixed precision.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
batch: Input batch
|
| 78 |
+
targets: Target batch
|
| 79 |
+
loss_fn: Optional custom loss function
|
| 80 |
+
|
| 81 |
+
Returns:
|
| 82 |
+
dict: Training metrics
|
| 83 |
+
"""
|
| 84 |
+
self.model.train()
|
| 85 |
+
self.optimizer.zero_grad()
|
| 86 |
+
|
| 87 |
+
# Move data to device
|
| 88 |
+
batch = batch.to(self.device)
|
| 89 |
+
targets = targets.to(self.device)
|
| 90 |
+
|
| 91 |
+
if self.enabled:
|
| 92 |
+
# Mixed precision forward pass
|
| 93 |
+
with autocast(dtype=self.dtype):
|
| 94 |
+
if loss_fn is None:
|
| 95 |
+
# Use model's built-in loss computation
|
| 96 |
+
logits, loss = self.model(batch, targets)
|
| 97 |
+
else:
|
| 98 |
+
# Use custom loss function
|
| 99 |
+
logits = self.model(batch)
|
| 100 |
+
loss = loss_fn(logits, targets)
|
| 101 |
+
|
| 102 |
+
# Scaled backward pass
|
| 103 |
+
self.scaler.scale(loss).backward()
|
| 104 |
+
self.scaler.step(self.optimizer)
|
| 105 |
+
self.scaler.update()
|
| 106 |
+
else:
|
| 107 |
+
# Standard precision training
|
| 108 |
+
if loss_fn is None:
|
| 109 |
+
logits, loss = self.model(batch, targets)
|
| 110 |
+
else:
|
| 111 |
+
logits = self.model(batch)
|
| 112 |
+
loss = loss_fn(logits, targets)
|
| 113 |
+
|
| 114 |
+
loss.backward()
|
| 115 |
+
self.optimizer.step()
|
| 116 |
+
|
| 117 |
+
return {
|
| 118 |
+
"loss": loss.item(),
|
| 119 |
+
"logits": logits,
|
| 120 |
+
"scaler_scale": self.scaler.get_scale() if self.scaler else 1.0
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
def eval_step(self,
|
| 124 |
+
batch: torch.Tensor,
|
| 125 |
+
targets: torch.Tensor,
|
| 126 |
+
loss_fn: Optional[Callable] = None) -> dict:
|
| 127 |
+
"""
|
| 128 |
+
Perform a single evaluation step.
|
| 129 |
+
|
| 130 |
+
Args:
|
| 131 |
+
batch: Input batch
|
| 132 |
+
targets: Target batch
|
| 133 |
+
loss_fn: Optional custom loss function
|
| 134 |
+
|
| 135 |
+
Returns:
|
| 136 |
+
dict: Evaluation metrics
|
| 137 |
+
"""
|
| 138 |
+
self.model.eval()
|
| 139 |
+
|
| 140 |
+
# Move data to device
|
| 141 |
+
batch = batch.to(self.device)
|
| 142 |
+
targets = targets.to(self.device)
|
| 143 |
+
|
| 144 |
+
with torch.no_grad():
|
| 145 |
+
if self.enabled:
|
| 146 |
+
with autocast(dtype=self.dtype):
|
| 147 |
+
if loss_fn is None:
|
| 148 |
+
logits, loss = self.model(batch, targets)
|
| 149 |
+
else:
|
| 150 |
+
logits = self.model(batch)
|
| 151 |
+
loss = loss_fn(logits, targets)
|
| 152 |
+
else:
|
| 153 |
+
if loss_fn is None:
|
| 154 |
+
logits, loss = self.model(batch, targets)
|
| 155 |
+
else:
|
| 156 |
+
logits = self.model(batch)
|
| 157 |
+
loss = loss_fn(logits, targets)
|
| 158 |
+
|
| 159 |
+
return {
|
| 160 |
+
"loss": loss.item(),
|
| 161 |
+
"logits": logits
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
def save_checkpoint(self, path: str, **kwargs):
|
| 165 |
+
"""Save model checkpoint with mixed precision state."""
|
| 166 |
+
checkpoint = {
|
| 167 |
+
"model_state_dict": self.model.state_dict(),
|
| 168 |
+
"optimizer_state_dict": self.optimizer.state_dict(),
|
| 169 |
+
"scaler_state_dict": self.scaler.state_dict() if self.scaler else None,
|
| 170 |
+
"dtype": self.dtype,
|
| 171 |
+
"enabled": self.enabled,
|
| 172 |
+
**kwargs
|
| 173 |
+
}
|
| 174 |
+
torch.save(checkpoint, path)
|
| 175 |
+
|
| 176 |
+
def load_checkpoint(self, path: str):
|
| 177 |
+
"""Load model checkpoint with mixed precision state."""
|
| 178 |
+
checkpoint = torch.load(path, map_location=self.device)
|
| 179 |
+
|
| 180 |
+
self.model.load_state_dict(checkpoint["model_state_dict"])
|
| 181 |
+
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
| 182 |
+
|
| 183 |
+
if self.scaler and checkpoint.get("scaler_state_dict"):
|
| 184 |
+
self.scaler.load_state_dict(checkpoint["scaler_state_dict"])
|
| 185 |
+
|
| 186 |
+
return checkpoint
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def enable_mixed_precision(model: nn.Module,
|
| 190 |
+
optimizer: torch.optim.Optimizer,
|
| 191 |
+
**kwargs) -> MixedPrecisionTrainer:
|
| 192 |
+
"""
|
| 193 |
+
Convenience function to enable mixed precision training.
|
| 194 |
+
|
| 195 |
+
Args:
|
| 196 |
+
model: The model to train
|
| 197 |
+
optimizer: The optimizer to use
|
| 198 |
+
**kwargs: Additional arguments for MixedPrecisionTrainer
|
| 199 |
+
|
| 200 |
+
Returns:
|
| 201 |
+
MixedPrecisionTrainer: Configured trainer
|
| 202 |
+
"""
|
| 203 |
+
return MixedPrecisionTrainer(model, optimizer, **kwargs)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def get_optimal_dtype() -> torch.dtype:
|
| 207 |
+
"""
|
| 208 |
+
Get the optimal dtype for mixed precision training.
|
| 209 |
+
|
| 210 |
+
Returns:
|
| 211 |
+
torch.dtype: Optimal dtype (bfloat16 for newer GPUs, float16 for older)
|
| 212 |
+
"""
|
| 213 |
+
if torch.cuda.is_available():
|
| 214 |
+
# Check if bfloat16 is supported
|
| 215 |
+
if torch.cuda.is_bf16_supported():
|
| 216 |
+
return torch.bfloat16
|
| 217 |
+
else:
|
| 218 |
+
return torch.float16
|
| 219 |
+
else:
|
| 220 |
+
return torch.float32
|
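A minimal usage sketch for the helpers above, with a throwaway linear model standing in for the real network; the import path and the loss function are assumptions for illustration only:

```python
# Illustrative sketch: wraps a placeholder model in the MixedPrecisionTrainer
# defined above; assumes core/src is importable or on sys.path.
import torch
from mixed_precision import enable_mixed_precision, get_optimal_dtype

model = torch.nn.Linear(16, 16)                      # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

trainer = enable_mixed_precision(
    model,
    optimizer,
    dtype=get_optimal_dtype(),  # bfloat16/float16 on CUDA, float32 otherwise
)

batch = torch.randn(4, 16)
targets = torch.randn(4, 16)
metrics = trainer.train_step(batch, targets, loss_fn=torch.nn.MSELoss())
print(metrics["loss"], metrics["scaler_scale"])
```

The scaler is only active when CUDA is available; on CPU the same `train_step` call falls back to full-precision training, so the calling code stays identical across machines.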
core/src/model.py
ADDED
|
@@ -0,0 +1,665 @@
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright (C) 2024 Louis Chua Bean Chong
|
| 3 |
+
#
|
| 4 |
+
# This file is part of OpenLLM.
|
| 5 |
+
#
|
| 6 |
+
# OpenLLM is dual-licensed:
|
| 7 |
+
# 1. For open source use: GNU General Public License v3.0
|
| 8 |
+
# 2. For commercial use: Commercial License (contact for details)
|
| 9 |
+
#
|
| 10 |
+
# See LICENSE and docs/LICENSES.md for full license information.
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
GPT-style Language Model Architecture
|
| 14 |
+
|
| 15 |
+
This module implements a standard GPT (Generative Pre-trained Transformer) architecture
|
| 16 |
+
using pure PyTorch. The model is a decoder-only transformer designed for autoregressive
|
| 17 |
+
language modeling (next-token prediction).
|
| 18 |
+
|
| 19 |
+
ARCHITECTURE OVERVIEW:
|
| 20 |
+
- Token Embedding: Maps token IDs to dense vectors
|
| 21 |
+
- Positional Embedding: Adds position information to token embeddings
|
| 22 |
+
- Transformer Blocks: Stack of multi-head attention + feed-forward layers
|
| 23 |
+
- Layer Normalization: Pre-norm placement for training stability
|
| 24 |
+
- Output Head: Linear projection to vocabulary for next-token prediction
|
| 25 |
+
|
| 26 |
+
FEATURES:
|
| 27 |
+
- Configurable model size (small/medium/large)
|
| 28 |
+
- Dropout for regularization
|
| 29 |
+
- Causal (autoregressive) attention masking
|
| 30 |
+
- Compatible with our SentencePiece tokenizer
|
| 31 |
+
- Memory-efficient implementation for training on limited hardware
|
| 32 |
+
|
| 33 |
+
Usage:
|
| 34 |
+
from model import GPTConfig, GPTModel
|
| 35 |
+
|
| 36 |
+
config = GPTConfig(vocab_size=32000, n_layer=12, n_head=12, n_embd=768)
|
| 37 |
+
model = GPTModel(config)
|
| 38 |
+
|
| 39 |
+
# Forward pass
|
| 40 |
+
logits, loss = model(input_ids)  # logits shape: (batch_size, seq_len, vocab_size)
|
| 41 |
+
|
| 42 |
+
Hardware Requirements:
|
| 43 |
+
- Small Model (25M params): 4-8GB RAM, CPU/integrated GPU
|
| 44 |
+
- Medium Model (117M params): 8-16GB RAM, dedicated GPU recommended
|
| 45 |
+
- Large Model (350M params): 16GB+ RAM, high-end GPU required
|
| 46 |
+
|
| 47 |
+
Author: Louis Chua Bean Chong
|
| 48 |
+
License: GPLv3
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
import math
|
| 52 |
+
from dataclasses import dataclass
|
| 53 |
+
from typing import Optional, Tuple
|
| 54 |
+
|
| 55 |
+
import torch
|
| 56 |
+
import torch.nn as nn
|
| 57 |
+
import torch.nn.functional as F
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@dataclass
|
| 61 |
+
class GPTConfig:
|
| 62 |
+
"""
|
| 63 |
+
Configuration class for GPT model hyperparameters.
|
| 64 |
+
|
| 65 |
+
This class defines all the architectural parameters needed to instantiate
|
| 66 |
+
a GPT model. Use the provided class methods to get pre-configured setups
|
| 67 |
+
for different model sizes.
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
+
# Model architecture
|
| 71 |
+
vocab_size: int = 32000 # Vocabulary size (from tokenizer)
|
| 72 |
+
n_layer: int = 12 # Number of transformer layers
|
| 73 |
+
n_head: int = 12 # Number of attention heads
|
| 74 |
+
n_embd: int = 768 # Embedding dimension
|
| 75 |
+
|
| 76 |
+
# Sequence and context
|
| 77 |
+
block_size: int = 1024 # Maximum sequence length
|
| 78 |
+
|
| 79 |
+
# Training hyperparameters
|
| 80 |
+
dropout: float = 0.1 # Dropout probability
|
| 81 |
+
bias: bool = True # Use bias in linear layers
|
| 82 |
+
|
| 83 |
+
# Model size identifier
|
| 84 |
+
model_name: str = "gpt-medium" # Human-readable model identifier
|
| 85 |
+
|
| 86 |
+
@classmethod
|
| 87 |
+
def small(cls) -> "GPTConfig":
|
| 88 |
+
"""Small model configuration (~25M parameters) - Good for CPU training"""
|
| 89 |
+
return cls(
|
| 90 |
+
vocab_size=32000,
|
| 91 |
+
n_layer=6,
|
| 92 |
+
n_head=8,
|
| 93 |
+
n_embd=512,
|
| 94 |
+
block_size=1024,
|
| 95 |
+
dropout=0.1,
|
| 96 |
+
model_name="gpt-small",
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
@classmethod
|
| 100 |
+
def medium(cls) -> "GPTConfig":
|
| 101 |
+
"""Medium model configuration (~117M parameters) - Balanced performance"""
|
| 102 |
+
return cls(
|
| 103 |
+
vocab_size=32000,
|
| 104 |
+
n_layer=12,
|
| 105 |
+
n_head=12,
|
| 106 |
+
n_embd=768,
|
| 107 |
+
block_size=2048,
|
| 108 |
+
dropout=0.1,
|
| 109 |
+
model_name="gpt-medium",
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
@classmethod
|
| 113 |
+
def large(cls) -> "GPTConfig":
|
| 114 |
+
"""Large model configuration (~350M parameters) - High performance"""
|
| 115 |
+
return cls(
|
| 116 |
+
vocab_size=32000,
|
| 117 |
+
n_layer=24,
|
| 118 |
+
n_head=16,
|
| 119 |
+
n_embd=1024,
|
| 120 |
+
block_size=2048,
|
| 121 |
+
dropout=0.1,
|
| 122 |
+
model_name="gpt-large",
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
def estimate_parameters(self) -> int:
|
| 126 |
+
"""
|
| 127 |
+
Estimate the total number of trainable parameters.
|
| 128 |
+
|
| 129 |
+
Returns:
|
| 130 |
+
int: Estimated parameter count
|
| 131 |
+
"""
|
| 132 |
+
# Token embeddings
|
| 133 |
+
token_emb = self.vocab_size * self.n_embd
|
| 134 |
+
|
| 135 |
+
# Position embeddings
|
| 136 |
+
pos_emb = self.block_size * self.n_embd
|
| 137 |
+
|
| 138 |
+
# Transformer layers
|
| 139 |
+
# Each layer: attention (4 * n_embd^2) + mlp (8 * n_embd^2) + layer_norms
|
| 140 |
+
layer_params = self.n_layer * (12 * self.n_embd**2 + 4 * self.n_embd)
|
| 141 |
+
|
| 142 |
+
# Output head
|
| 143 |
+
output_head = self.vocab_size * self.n_embd
|
| 144 |
+
|
| 145 |
+
total = token_emb + pos_emb + layer_params + output_head
|
| 146 |
+
return total
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class CausalSelfAttention(nn.Module):
|
| 150 |
+
"""
|
| 151 |
+
Multi-head causal self-attention mechanism.
|
| 152 |
+
|
| 153 |
+
This implements the core attention mechanism of the transformer, with causal
|
| 154 |
+
masking to ensure autoregressive behavior (tokens can only attend to previous
|
| 155 |
+
tokens, not future ones).
|
| 156 |
+
"""
|
| 157 |
+
|
| 158 |
+
def __init__(self, config: GPTConfig):
|
| 159 |
+
super().__init__()
|
| 160 |
+
assert (
|
| 161 |
+
config.n_embd % config.n_head == 0
|
| 162 |
+
), "Embedding dim must be divisible by number of heads"
|
| 163 |
+
|
| 164 |
+
self.config = config
|
| 165 |
+
self.n_head = config.n_head
|
| 166 |
+
self.n_embd = config.n_embd
|
| 167 |
+
self.head_dim = self.n_embd // self.n_head
|
| 168 |
+
|
| 169 |
+
# Key, query, value projections for all heads (batched)
|
| 170 |
+
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
|
| 171 |
+
|
| 172 |
+
# Output projection
|
| 173 |
+
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
|
| 174 |
+
|
| 175 |
+
# Dropout
|
| 176 |
+
self.attn_dropout = nn.Dropout(config.dropout)
|
| 177 |
+
self.resid_dropout = nn.Dropout(config.dropout)
|
| 178 |
+
|
| 179 |
+
# Causal mask - lower triangular matrix
|
| 180 |
+
self.register_buffer(
|
| 181 |
+
"bias",
|
| 182 |
+
torch.tril(torch.ones(config.block_size, config.block_size)).view(
|
| 183 |
+
1, 1, config.block_size, config.block_size
|
| 184 |
+
),
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 188 |
+
"""
|
| 189 |
+
Forward pass of causal self-attention.
|
| 190 |
+
|
| 191 |
+
This method implements the scaled dot-product attention mechanism with causal masking.
|
| 192 |
+
The attention mechanism allows each token to attend to all previous tokens in the sequence,
|
| 193 |
+
but not to future tokens, maintaining the autoregressive property essential for language modeling.
|
| 194 |
+
|
| 195 |
+
Mathematical formulation:
|
| 196 |
+
Attention(Q, K, V) = softmax(QK^T / sqrt(d_k))V
|
| 197 |
+
where Q, K, V are query, key, value matrices derived from input x
|
| 198 |
+
|
| 199 |
+
Implementation details:
|
| 200 |
+
- Uses batch matrix multiplication for efficiency
|
| 201 |
+
- Applies causal mask to prevent future token attention
|
| 202 |
+
- Implements multi-head attention by reshaping and parallel processing
|
| 203 |
+
- Applies dropout for regularization during training
|
| 204 |
+
|
| 205 |
+
Args:
|
| 206 |
+
x: Input tensor of shape (batch_size, seq_len, n_embd)
|
| 207 |
+
Contains embedded token representations from previous layer
|
| 208 |
+
|
| 209 |
+
Returns:
|
| 210 |
+
torch.Tensor: Output tensor of shape (batch_size, seq_len, n_embd)
|
| 211 |
+
"""
|
| 212 |
+
# Extract tensor dimensions for clear variable naming and validation
|
| 213 |
+
# B = batch size (number of sequences processed in parallel)
|
| 214 |
+
# T = sequence length (number of tokens in each sequence)
|
| 215 |
+
# C = embedding dimensionality (n_embd from config)
|
| 216 |
+
B, T, C = x.size()
|
| 217 |
+
|
| 218 |
+
# Generate query, key, and value projections for all attention heads
|
| 219 |
+
# The c_attn linear layer outputs 3 * n_embd features, which we split into Q, K, V
|
| 220 |
+
# This batched approach is more efficient than separate linear layers
|
| 221 |
+
# Input shape: (B, T, C) -> Output shape: (B, T, 3*C) -> Split to 3x (B, T, C)
|
| 222 |
+
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
|
| 223 |
+
|
| 224 |
+
# Reshape tensors for multi-head attention computation
|
| 225 |
+
# Transform from (B, T, C) to (B, nh, T, hs) where:
|
| 226 |
+
# - nh = number of heads (self.n_head)
|
| 227 |
+
# - hs = head size (self.head_dim = C // nh)
|
| 228 |
+
# The transpose(1, 2) moves the head dimension before sequence dimension for efficient computation
|
| 229 |
+
q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2) # (B, nh, T, hs)
|
| 230 |
+
k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2) # (B, nh, T, hs)
|
| 231 |
+
v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2) # (B, nh, T, hs)
|
| 232 |
+
|
| 233 |
+
# Compute scaled dot-product attention scores
|
| 234 |
+
# Matrix multiplication: Q @ K^T gives attention affinities between all token pairs
|
| 235 |
+
# Scaling by 1/sqrt(head_dim) prevents softmax saturation for large embedding dimensions
|
| 236 |
+
# Shape: (B, nh, T, hs) @ (B, nh, hs, T) -> (B, nh, T, T)
|
| 237 |
+
# The resulting (T, T) matrix represents attention weights from each token to every other token
|
| 238 |
+
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
|
| 239 |
+
|
| 240 |
+
# Apply causal masking to enforce autoregressive property
|
| 241 |
+
# The causal mask ensures that token i can only attend to tokens j where j <= i
|
| 242 |
+
# This prevents the model from "cheating" by looking at future tokens during training
|
| 243 |
+
# We use -inf for masked positions so they become 0 after softmax
|
| 244 |
+
att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
|
| 245 |
+
|
| 246 |
+
# Convert attention scores to probabilities using softmax
|
| 247 |
+
# Each row of the attention matrix now sums to 1, representing a probability distribution
|
| 248 |
+
# over which tokens to attend to for each query position
|
| 249 |
+
att = F.softmax(att, dim=-1)
|
| 250 |
+
|
| 251 |
+
# Apply dropout to attention weights for regularization
|
| 252 |
+
# This randomly zeros some attention connections during training to prevent overfitting
|
| 253 |
+
att = self.attn_dropout(att)
|
| 254 |
+
|
| 255 |
+
# Apply attention weights to value vectors
|
| 256 |
+
# This weighted combination produces the actual output of the attention mechanism
|
| 257 |
+
# Shape: (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)
|
| 258 |
+
# Each output position is a weighted sum of all value vectors, with weights from attention
|
| 259 |
+
y = att @ v
|
| 260 |
+
|
| 261 |
+
# Concatenate multi-head outputs back to original embedding dimension
|
| 262 |
+
# Transform from (B, nh, T, hs) back to (B, T, C) where C = nh * hs
|
| 263 |
+
# The transpose moves head dimension back, and contiguous() ensures memory layout efficiency
|
| 264 |
+
# This combines information from all attention heads into a single representation
|
| 265 |
+
y = y.transpose(1, 2).contiguous().view(B, T, C)
|
| 266 |
+
|
| 267 |
+
# Apply final output projection and residual dropout
|
| 268 |
+
# The output projection allows the model to learn how to best combine multi-head information
|
| 269 |
+
# Residual dropout provides additional regularization before the residual connection
|
| 270 |
+
y = self.resid_dropout(self.c_proj(y))
|
| 271 |
+
return y
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
class MLP(nn.Module):
|
| 275 |
+
"""
|
| 276 |
+
Multi-Layer Perceptron (Feed-Forward Network) for Transformer.
|
| 277 |
+
|
| 278 |
+
This implements the position-wise feed-forward network that appears in each transformer layer.
|
| 279 |
+
The MLP provides additional non-linear transformation capacity beyond what attention provides.
|
| 280 |
+
|
| 281 |
+
Architecture:
|
| 282 |
+
Input -> Linear(n_embd -> 4*n_embd) -> GELU -> Linear(4*n_embd -> n_embd) -> Dropout -> Output
|
| 283 |
+
|
| 284 |
+
Design rationale:
|
| 285 |
+
- 4x expansion is standard in transformers (from "Attention Is All You Need")
|
| 286 |
+
- GELU activation provides smoother gradients than ReLU for language modeling
|
| 287 |
+
- Dropout prevents overfitting in the feed-forward layers
|
| 288 |
+
- Two linear layers allow complex non-linear transformations of attention outputs
|
| 289 |
+
|
| 290 |
+
Parameters:
|
| 291 |
+
- First linear layer: n_embd * 4*n_embd parameters (expansion)
|
| 292 |
+
- Second linear layer: 4*n_embd * n_embd parameters (projection back)
|
| 293 |
+
- Total: 8 * n_embd^2 parameters (significant portion of model size)
|
| 294 |
+
"""
|
| 295 |
+
|
| 296 |
+
def __init__(self, config: GPTConfig):
|
| 297 |
+
super().__init__()
|
| 298 |
+
|
| 299 |
+
# First linear layer: expand embedding dimension by 4x
|
| 300 |
+
# This expansion gives the network more representational capacity
|
| 301 |
+
# The 4x factor is a standard choice that balances capacity vs efficiency
|
| 302 |
+
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
|
| 303 |
+
|
| 304 |
+
# GELU (Gaussian Error Linear Unit) activation function
|
| 305 |
+
# GELU provides smoother gradients compared to ReLU and works better for language modeling
|
| 306 |
+
# It's approximately: GELU(x) = x * Phi(x) where Phi is the CDF of the standard normal distribution
|
| 307 |
+
self.gelu = nn.GELU()
|
| 308 |
+
|
| 309 |
+
# Second linear layer: project back to original embedding dimension
|
| 310 |
+
# This projection allows the network to combine information from the expanded representation
|
| 311 |
+
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
|
| 312 |
+
|
| 313 |
+
# Dropout for regularization in the feed-forward network
|
| 314 |
+
# Applied after the final projection to prevent overfitting
|
| 315 |
+
self.dropout = nn.Dropout(config.dropout)
|
| 316 |
+
|
| 317 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 318 |
+
"""
|
| 319 |
+
Forward pass of the feed-forward network.
|
| 320 |
+
|
| 321 |
+
This method applies a two-layer MLP with GELU activation to transform
|
| 322 |
+
the attention outputs. The MLP operates independently on each position
|
| 323 |
+
in the sequence, providing position-wise non-linear transformations.
|
| 324 |
+
|
| 325 |
+
Mathematical operation:
|
| 326 |
+
MLP(x) = Dropout(Linear2(GELU(Linear1(x))))
|
| 327 |
+
where Linear1: R^n_embd -> R^(4*n_embd) and Linear2: R^(4*n_embd) -> R^n_embd
|
| 328 |
+
|
| 329 |
+
Args:
|
| 330 |
+
x: Input tensor of shape (batch_size, seq_len, n_embd)
|
| 331 |
+
Contains attended representations from the attention layer
|
| 332 |
+
|
| 333 |
+
Returns:
|
| 334 |
+
torch.Tensor: Output tensor of shape (batch_size, seq_len, n_embd)
|
| 335 |
+
Contains transformed representations ready for residual connection
|
| 336 |
+
"""
|
| 337 |
+
# First linear transformation: expand from n_embd to 4*n_embd dimensions
|
| 338 |
+
# This expansion provides the network with a higher-dimensional space for computation
|
| 339 |
+
# Shape: (batch_size, seq_len, n_embd) -> (batch_size, seq_len, 4*n_embd)
|
| 340 |
+
x = self.c_fc(x)
|
| 341 |
+
|
| 342 |
+
# Apply GELU activation function for non-linearity
|
| 343 |
+
# GELU is smoother than ReLU and provides better gradients for language modeling
|
| 344 |
+
# It introduces non-linearity while maintaining differentiability everywhere
|
| 345 |
+
x = self.gelu(x)
|
| 346 |
+
|
| 347 |
+
# Second linear transformation: project back to original n_embd dimensions
|
| 348 |
+
# This projection combines information from the expanded representation
|
| 349 |
+
# Shape: (batch_size, seq_len, 4*n_embd) -> (batch_size, seq_len, n_embd)
|
| 350 |
+
x = self.c_proj(x)
|
| 351 |
+
|
| 352 |
+
# Apply dropout for regularization before residual connection
|
| 353 |
+
# Dropout randomly zeros some neurons during training to prevent overfitting
|
| 354 |
+
# This is particularly important in the feed-forward layers which have many parameters
|
| 355 |
+
x = self.dropout(x)
|
| 356 |
+
|
| 357 |
+
return x
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
class Block(nn.Module):
|
| 361 |
+
"""
|
| 362 |
+
Single Transformer block.
|
| 363 |
+
|
| 364 |
+
Consists of:
|
| 365 |
+
1. Layer normalization
|
| 366 |
+
2. Multi-head causal self-attention
|
| 367 |
+
3. Residual connection
|
| 368 |
+
4. Layer normalization
|
| 369 |
+
5. MLP (feed-forward network)
|
| 370 |
+
6. Residual connection
|
| 371 |
+
|
| 372 |
+
Uses pre-norm architecture for better training stability.
|
| 373 |
+
"""
|
| 374 |
+
|
| 375 |
+
def __init__(self, config: GPTConfig):
|
| 376 |
+
super().__init__()
|
| 377 |
+
self.ln_1 = nn.LayerNorm(config.n_embd)
|
| 378 |
+
self.attn = CausalSelfAttention(config)
|
| 379 |
+
self.ln_2 = nn.LayerNorm(config.n_embd)
|
| 380 |
+
self.mlp = MLP(config)
|
| 381 |
+
|
| 382 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 383 |
+
"""
|
| 384 |
+
Forward pass of transformer block.
|
| 385 |
+
|
| 386 |
+
Args:
|
| 387 |
+
x: Input tensor of shape (batch_size, seq_len, n_embd)
|
| 388 |
+
|
| 389 |
+
Returns:
|
| 390 |
+
torch.Tensor: Output tensor of shape (batch_size, seq_len, n_embd)
|
| 391 |
+
"""
|
| 392 |
+
# Pre-norm attention with residual connection
|
| 393 |
+
x = x + self.attn(self.ln_1(x))
|
| 394 |
+
|
| 395 |
+
# Pre-norm MLP with residual connection
|
| 396 |
+
x = x + self.mlp(self.ln_2(x))
|
| 397 |
+
|
| 398 |
+
return x
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
class GPTModel(nn.Module):
|
| 402 |
+
"""
|
| 403 |
+
Complete GPT Language Model.
|
| 404 |
+
|
| 405 |
+
This is the main model class that combines all components:
|
| 406 |
+
- Token and positional embeddings
|
| 407 |
+
- Stack of transformer blocks
|
| 408 |
+
- Final layer normalization
|
| 409 |
+
- Language modeling head
|
| 410 |
+
|
| 411 |
+
The model can be used for:
|
| 412 |
+
- Training from scratch on text data
|
| 413 |
+
- Fine-tuning on downstream tasks
|
| 414 |
+
- Text generation (inference)
|
| 415 |
+
"""
|
| 416 |
+
|
| 417 |
+
def __init__(self, config: GPTConfig, use_checkpoint=True):
|
| 418 |
+
super().__init__()
|
| 419 |
+
assert config.vocab_size is not None, "vocab_size must be specified"
|
| 420 |
+
assert config.block_size is not None, "block_size must be specified"
|
| 421 |
+
|
| 422 |
+
self.config = config
|
| 423 |
+
self.use_checkpoint = use_checkpoint
|
| 424 |
+
|
| 425 |
+
# Embeddings
|
| 426 |
+
self.transformer = nn.ModuleDict(
|
| 427 |
+
dict(
|
| 428 |
+
wte=nn.Embedding(config.vocab_size, config.n_embd), # Token embeddings
|
| 429 |
+
wpe=nn.Embedding(config.block_size, config.n_embd), # Position embeddings
|
| 430 |
+
drop=nn.Dropout(config.dropout),
|
| 431 |
+
h=nn.ModuleList(
|
| 432 |
+
[Block(config) for _ in range(config.n_layer)]
|
| 433 |
+
), # Transformer blocks
|
| 434 |
+
ln_f=nn.LayerNorm(config.n_embd), # Final layer norm
|
| 435 |
+
)
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
# Language modeling head (maps hidden states to vocabulary)
|
| 439 |
+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
| 440 |
+
|
| 441 |
+
# Tie weights between token embeddings and output head (common practice)
|
| 442 |
+
self.transformer.wte.weight = self.lm_head.weight
|
| 443 |
+
|
| 444 |
+
# Initialize weights
|
| 445 |
+
self.apply(self._init_weights)
|
| 446 |
+
|
| 447 |
+
# Report parameter count
|
| 448 |
+
print(f"Model initialized: {self.config.model_name}")
|
| 449 |
+
print(f"Parameters: {self.get_num_params():,}")
|
| 450 |
+
print(f"Estimated: {self.config.estimate_parameters():,}")
|
| 451 |
+
|
| 452 |
+
def _init_weights(self, module):
|
| 453 |
+
"""Initialize model weights using standard practices."""
|
| 454 |
+
if isinstance(module, nn.Linear):
|
| 455 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 456 |
+
if module.bias is not None:
|
| 457 |
+
torch.nn.init.zeros_(module.bias)
|
| 458 |
+
elif isinstance(module, nn.Embedding):
|
| 459 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 460 |
+
|
| 461 |
+
def get_num_params(self, non_embedding: bool = False) -> int:
|
| 462 |
+
"""
|
| 463 |
+
Count the number of parameters in the model.
|
| 464 |
+
|
| 465 |
+
Args:
|
| 466 |
+
non_embedding: If True, subtract embedding parameters
|
| 467 |
+
|
| 468 |
+
Returns:
|
| 469 |
+
int: Number of parameters
|
| 470 |
+
"""
|
| 471 |
+
n_params = sum(p.numel() for p in self.parameters())
|
| 472 |
+
if non_embedding:
|
| 473 |
+
n_params -= self.transformer.wpe.weight.numel()
|
| 474 |
+
n_params -= self.transformer.wte.weight.numel()
|
| 475 |
+
return n_params
|
| 476 |
+
|
| 477 |
+
def forward(
|
| 478 |
+
self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None
|
| 479 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 480 |
+
"""
|
| 481 |
+
Forward pass of the GPT model.
|
| 482 |
+
|
| 483 |
+
Args:
|
| 484 |
+
idx: Input token indices of shape (batch_size, seq_len)
|
| 485 |
+
targets: Optional target tokens for loss calculation (batch_size, seq_len)
|
| 486 |
+
|
| 487 |
+
Returns:
|
| 488 |
+
Tuple containing:
|
| 489 |
+
- logits: Output logits of shape (batch_size, seq_len, vocab_size)
|
| 490 |
+
- loss: Cross-entropy loss if targets provided, None otherwise
|
| 491 |
+
"""
|
| 492 |
+
device = idx.device
|
| 493 |
+
b, t = idx.size()
|
| 494 |
+
assert (
|
| 495 |
+
t <= self.config.block_size
|
| 496 |
+
), f"Sequence length {t} exceeds block size {self.config.block_size}"
|
| 497 |
+
|
| 498 |
+
# Token embeddings
|
| 499 |
+
tok_emb = self.transformer.wte(idx) # (b, t, n_embd)
|
| 500 |
+
|
| 501 |
+
# Position embeddings
|
| 502 |
+
pos = torch.arange(0, t, dtype=torch.long, device=device) # (t,)
|
| 503 |
+
pos_emb = self.transformer.wpe(pos) # (t, n_embd)
|
| 504 |
+
|
| 505 |
+
# Combine embeddings and apply dropout
|
| 506 |
+
x = self.transformer.drop(tok_emb + pos_emb)
|
| 507 |
+
|
| 508 |
+
# Pass through transformer blocks with optional gradient checkpointing
|
| 509 |
+
if self.use_checkpoint and self.training:
|
| 510 |
+
# Use gradient checkpointing to save memory during training
|
| 511 |
+
for block in self.transformer.h:
|
| 512 |
+
x = torch.utils.checkpoint.checkpoint(block, x)
|
| 513 |
+
else:
|
| 514 |
+
# Standard forward pass
|
| 515 |
+
for block in self.transformer.h:
|
| 516 |
+
x = block(x)
|
| 517 |
+
|
| 518 |
+
# Final layer normalization
|
| 519 |
+
x = self.transformer.ln_f(x)
|
| 520 |
+
|
| 521 |
+
# Language modeling head
|
| 522 |
+
# Always compute full logits for training and evaluation
|
| 523 |
+
logits = self.lm_head(x)
|
| 524 |
+
|
| 525 |
+
if targets is not None:
|
| 526 |
+
# If we have targets, compute loss
|
| 527 |
+
loss = F.cross_entropy(
|
| 528 |
+
logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
|
| 529 |
+
)
|
| 530 |
+
else:
|
| 531 |
+
# If no targets, no loss computation
|
| 532 |
+
loss = None
|
| 533 |
+
|
| 534 |
+
return logits, loss
|
| 535 |
+
|
| 536 |
+
def generate(
|
| 537 |
+
self,
|
| 538 |
+
idx: torch.Tensor,
|
| 539 |
+
max_new_tokens: int = 100,
|
| 540 |
+
temperature: float = 1.0,
|
| 541 |
+
top_k: Optional[int] = None,
|
| 542 |
+
) -> torch.Tensor:
|
| 543 |
+
"""
|
| 544 |
+
Generate new tokens autoregressively.
|
| 545 |
+
|
| 546 |
+
Args:
|
| 547 |
+
idx: Starting token indices (batch_size, seq_len)
|
| 548 |
+
max_new_tokens: Maximum number of new tokens to generate
|
| 549 |
+
temperature: Sampling temperature (higher = more random)
|
| 550 |
+
top_k: If set, only sample from top-k most likely tokens
|
| 551 |
+
|
| 552 |
+
Returns:
|
| 553 |
+
torch.Tensor: Generated sequence (batch_size, seq_len + max_new_tokens)
|
| 554 |
+
"""
|
| 555 |
+
self.eval()
|
| 556 |
+
with torch.no_grad():
|
| 557 |
+
for _ in range(max_new_tokens):
|
| 558 |
+
# Crop sequence if it exceeds block size
|
| 559 |
+
idx_cond = (
|
| 560 |
+
idx
|
| 561 |
+
if idx.size(1) <= self.config.block_size
|
| 562 |
+
else idx[:, -self.config.block_size :]
|
| 563 |
+
)
|
| 564 |
+
|
| 565 |
+
# Forward pass
|
| 566 |
+
logits, _ = self(idx_cond)
|
| 567 |
+
|
| 568 |
+
# Get logits for the last token and apply temperature
|
| 569 |
+
logits = logits[:, -1, :] / temperature
|
| 570 |
+
|
| 571 |
+
# Optionally crop to top-k most likely tokens
|
| 572 |
+
if top_k is not None:
|
| 573 |
+
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
| 574 |
+
logits[logits < v[:, [-1]]] = -float("inf")
|
| 575 |
+
|
| 576 |
+
# Apply softmax and sample
|
| 577 |
+
probs = F.softmax(logits, dim=-1)
|
| 578 |
+
idx_next = torch.multinomial(probs, num_samples=1)
|
| 579 |
+
|
| 580 |
+
# Append to sequence
|
| 581 |
+
idx = torch.cat((idx, idx_next), dim=1)
|
| 582 |
+
|
| 583 |
+
self.train() # Return to training mode
|
| 584 |
+
return idx
|
| 585 |
+
|
| 586 |
+
def estimate_memory_usage(self, batch_size: int = 1, seq_len: int = None) -> dict:
|
| 587 |
+
"""
|
| 588 |
+
Estimate memory usage for training and inference.
|
| 589 |
+
|
| 590 |
+
Args:
|
| 591 |
+
batch_size: Batch size for estimation
|
| 592 |
+
seq_len: Sequence length (defaults to block_size)
|
| 593 |
+
|
| 594 |
+
Returns:
|
| 595 |
+
dict: Memory usage estimates in MB
|
| 596 |
+
"""
|
| 597 |
+
if seq_len is None:
|
| 598 |
+
seq_len = self.config.block_size
|
| 599 |
+
|
| 600 |
+
# Model parameters (weights)
|
| 601 |
+
param_memory = self.get_num_params() * 4 / (1024**2) # 4 bytes per float32
|
| 602 |
+
|
| 603 |
+
# Activations (rough estimate)
|
| 604 |
+
activation_memory = (
|
| 605 |
+
batch_size * seq_len * self.config.n_embd * self.config.n_layer * 8 # Rough estimate
|
| 606 |
+
) / (1024**2)
|
| 607 |
+
|
| 608 |
+
# Gradients (same size as parameters during training)
|
| 609 |
+
gradient_memory = param_memory
|
| 610 |
+
|
| 611 |
+
return {
|
| 612 |
+
"parameters_mb": param_memory,
|
| 613 |
+
"activations_mb": activation_memory,
|
| 614 |
+
"gradients_mb": gradient_memory,
|
| 615 |
+
"total_training_mb": param_memory + activation_memory + gradient_memory,
|
| 616 |
+
"total_inference_mb": param_memory + activation_memory * 0.5, # No gradients needed
|
| 617 |
+
}
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
def create_model(model_size: str = "medium") -> GPTModel:
|
| 621 |
+
"""
|
| 622 |
+
Factory function to create a GPT model with predefined configurations.
|
| 623 |
+
|
| 624 |
+
Args:
|
| 625 |
+
model_size: Size of model to create ("small", "medium", "large")
|
| 626 |
+
|
| 627 |
+
Returns:
|
| 628 |
+
GPTModel: Initialized model
|
| 629 |
+
"""
|
| 630 |
+
configs = {
|
| 631 |
+
"small": GPTConfig.small(),
|
| 632 |
+
"medium": GPTConfig.medium(),
|
| 633 |
+
"large": GPTConfig.large(),
|
| 634 |
+
}
|
| 635 |
+
|
| 636 |
+
if model_size not in configs:
|
| 637 |
+
raise ValueError(f"Unknown model size: {model_size}. Choose from {list(configs.keys())}")
|
| 638 |
+
|
| 639 |
+
config = configs[model_size]
|
| 640 |
+
model = GPTModel(config)
|
| 641 |
+
|
| 642 |
+
return model
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
if __name__ == "__main__":
|
| 646 |
+
# Example usage
|
| 647 |
+
print("π§ GPT Model Architecture")
|
| 648 |
+
print("=" * 50)
|
| 649 |
+
|
| 650 |
+
# Create models of different sizes
|
| 651 |
+
for size in ["small", "medium", "large"]:
|
| 652 |
+
print(f"\n{size.upper()} MODEL:")
|
| 653 |
+
model = create_model(size)
|
| 654 |
+
|
| 655 |
+
# Show memory estimates
|
| 656 |
+
memory = model.estimate_memory_usage(batch_size=4, seq_len=512)
|
| 657 |
+
print(
|
| 658 |
+
f"Memory (4 batch, 512 seq): {memory['total_training_mb']:.1f}MB training, {memory['total_inference_mb']:.1f}MB inference"
|
| 659 |
+
)
|
| 660 |
+
|
| 661 |
+
# Test forward pass
|
| 662 |
+
x = torch.randint(0, 32000, (2, 64)) # Batch size 2, sequence length 64
|
| 663 |
+
with torch.no_grad():
|
| 664 |
+
logits, _ = model(x)
|
| 665 |
+
print(f"Test forward pass: {x.shape} -> {logits.shape} β")
|
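Two quick sanity notes on the architecture above, followed by a minimal end-to-end sketch. First, `estimate_parameters()` counts the token-embedding matrix and the output head separately (each `vocab_size * n_embd`, about 16.4M values for the small preset), while `get_num_params()` sees the tied `lm_head`/`wte` weight only once, so the analytical estimate is expected to come out roughly one embedding matrix larger than the reported count. Second, the prompt tokens below are random placeholders; in real use they come from the SentencePiece tokenizer:

```python
# Illustrative sketch: small config, one forward pass, short sampled continuation.
import torch
from model import GPTConfig, GPTModel

config = GPTConfig.small()
model = GPTModel(config, use_checkpoint=False)  # checkpointing off for inference

print(config.estimate_parameters())  # analytical estimate (tied head counted twice)
print(model.get_num_params())        # actual parameter count (tied weight once)

idx = torch.randint(0, config.vocab_size, (1, 16))  # placeholder "prompt" tokens
logits, loss = model(idx)                            # loss is None without targets
print(logits.shape)                                  # (1, 16, 32000)

out = model.generate(idx, max_new_tokens=8, temperature=0.8, top_k=40)
print(out.shape)                                     # (1, 24)
```

On PyTorch 2.x the hand-written attention in `CausalSelfAttention` could also be expressed with `torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)`, which fuses the same masked softmax; that is an optional optimization, not something this file does.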
core/src/model_test.py
ADDED
|
@@ -0,0 +1,564 @@
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright (C) 2024 Louis Chua Bean Chong
|
| 3 |
+
#
|
| 4 |
+
# This file is part of OpenLLM.
|
| 5 |
+
#
|
| 6 |
+
# OpenLLM is dual-licensed:
|
| 7 |
+
# 1. For open source use: GNU General Public License v3.0
|
| 8 |
+
# 2. For commercial use: Commercial License (contact for details)
|
| 9 |
+
#
|
| 10 |
+
# See LICENSE and docs/LICENSES.md for full license information.
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
Model Architecture Testing and Validation Script
|
| 14 |
+
|
| 15 |
+
This script provides comprehensive testing and validation for the GPT model architecture.
|
| 16 |
+
It helps verify that the model is correctly implemented and can run on your hardware.
|
| 17 |
+
|
| 18 |
+
FEATURES:
|
| 19 |
+
- Model initialization testing
|
| 20 |
+
- Forward pass validation
|
| 21 |
+
- Memory usage analysis
|
| 22 |
+
- Tokenizer integration testing
|
| 23 |
+
- Performance benchmarking
|
| 24 |
+
- Hardware compatibility checks
|
| 25 |
+
|
| 26 |
+
Usage:
|
| 27 |
+
python core/src/model_test.py --model_size medium
|
| 28 |
+
python core/src/model_test.py --model_size small --test_generation
|
| 29 |
+
python core/src/model_test.py --all_sizes --benchmark
|
| 30 |
+
|
| 31 |
+
Requirements:
|
| 32 |
+
- torch
|
| 33 |
+
- sentencepiece (for tokenizer integration)
|
| 34 |
+
- Our trained tokenizer in data/tokenizer/
|
| 35 |
+
|
| 36 |
+
Author: Louis Chua Bean Chong
|
| 37 |
+
License: GPLv3
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
import argparse
|
| 41 |
+
import json
|
| 42 |
+
import os
|
| 43 |
+
import time
|
| 44 |
+
import traceback
|
| 45 |
+
from typing import Dict, List
|
| 46 |
+
|
| 47 |
+
import torch
|
| 48 |
+
|
| 49 |
+
# Import our model architecture
|
| 50 |
+
try:
|
| 51 |
+
from model import GPTModel, create_model
|
| 52 |
+
except ImportError:
|
| 53 |
+
import sys
|
| 54 |
+
|
| 55 |
+
sys.path.append(os.path.dirname(__file__))
|
| 56 |
+
from model import GPTModel, create_model
|
| 57 |
+
|
| 58 |
+
# Import tokenizer if available
|
| 59 |
+
try:
|
| 60 |
+
import sentencepiece as spm
|
| 61 |
+
|
| 62 |
+
TOKENIZER_AVAILABLE = True
|
| 63 |
+
except ImportError:
|
| 64 |
+
TOKENIZER_AVAILABLE = False
|
| 65 |
+
print("Warning: SentencePiece not available. Tokenizer tests will be skipped.")
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class ModelTester:
|
| 69 |
+
"""
|
| 70 |
+
Comprehensive model testing class.
|
| 71 |
+
|
| 72 |
+
Provides methods to test model initialization, forward passes, memory usage,
|
| 73 |
+
and integration with the tokenizer.
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
def __init__(self, device: str = "auto"):
|
| 77 |
+
"""
|
| 78 |
+
Initialize the model tester.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
device: Device to use ("cpu", "cuda", or "auto")
|
| 82 |
+
"""
|
| 83 |
+
if device == "auto":
|
| 84 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 85 |
+
else:
|
| 86 |
+
self.device = device
|
| 87 |
+
|
| 88 |
+
print("π§ Model Tester initialized")
|
| 89 |
+
print(f"Device: {self.device}")
|
| 90 |
+
print(f"PyTorch version: {torch.__version__}")
|
| 91 |
+
|
| 92 |
+
# Try to load tokenizer
|
| 93 |
+
self.tokenizer = None
|
| 94 |
+
self.load_tokenizer()
|
| 95 |
+
|
| 96 |
+
def load_tokenizer(self) -> None:
|
| 97 |
+
"""Load the trained SentencePiece tokenizer if available."""
|
| 98 |
+
if not TOKENIZER_AVAILABLE:
|
| 99 |
+
return
|
| 100 |
+
|
| 101 |
+
tokenizer_path = "data/tokenizer/tokenizer.model"
|
| 102 |
+
if os.path.exists(tokenizer_path):
|
| 103 |
+
try:
|
| 104 |
+
self.tokenizer = spm.SentencePieceProcessor()
|
| 105 |
+
self.tokenizer.load(tokenizer_path)
|
| 106 |
+
print(f"β Tokenizer loaded: {tokenizer_path}")
|
| 107 |
+
print(f" Vocabulary size: {self.tokenizer.vocab_size():,}")
|
| 108 |
+
except Exception as e:
|
| 109 |
+
print(f"β οΈ Failed to load tokenizer: {e}")
|
| 110 |
+
else:
|
| 111 |
+
print(f"β οΈ Tokenizer not found at {tokenizer_path}")
|
| 112 |
+
|
| 113 |
+
def test_model_initialization(self, model_size: str = "medium") -> Dict:
|
| 114 |
+
"""
|
| 115 |
+
Test model initialization and basic properties.
|
| 116 |
+
|
| 117 |
+
Args:
|
| 118 |
+
model_size: Size of model to test
|
| 119 |
+
|
| 120 |
+
Returns:
|
| 121 |
+
dict: Test results
|
| 122 |
+
"""
|
| 123 |
+
print(f"\nπ§ Testing {model_size.upper()} model initialization...")
|
| 124 |
+
|
| 125 |
+
try:
|
| 126 |
+
# Create model
|
| 127 |
+
start_time = time.time()
|
| 128 |
+
model = create_model(model_size)
|
| 129 |
+
init_time = time.time() - start_time
|
| 130 |
+
|
| 131 |
+
# Move to device
|
| 132 |
+
model = model.to(self.device)
|
| 133 |
+
|
| 134 |
+
# Basic checks
|
| 135 |
+
param_count = model.get_num_params()
|
| 136 |
+
config = model.config
|
| 137 |
+
|
| 138 |
+
print("β Model created successfully")
|
| 139 |
+
print(f" Parameters: {param_count:,}")
|
| 140 |
+
print(f" Layers: {config.n_layer}")
|
| 141 |
+
print(f" Heads: {config.n_head}")
|
| 142 |
+
print(f" Embedding dim: {config.n_embd}")
|
| 143 |
+
print(f" Block size: {config.block_size}")
|
| 144 |
+
print(f" Initialization time: {init_time:.2f}s")
|
| 145 |
+
|
| 146 |
+
return {
|
| 147 |
+
"success": True,
|
| 148 |
+
"model_size": model_size,
|
| 149 |
+
"parameters": param_count,
|
| 150 |
+
"config": config.__dict__,
|
| 151 |
+
"init_time": init_time,
|
| 152 |
+
"device": str(next(model.parameters()).device),
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
except Exception as e:
|
| 156 |
+
print(f"β Model initialization failed: {e}")
|
| 157 |
+
traceback.print_exc()
|
| 158 |
+
return {"success": False, "error": str(e)}
|
| 159 |
+
|
| 160 |
+
def test_forward_pass(self, model: GPTModel, batch_size: int = 2, seq_len: int = 64) -> Dict:
|
| 161 |
+
"""
|
| 162 |
+
Test model forward pass with synthetic data.
|
| 163 |
+
|
| 164 |
+
Args:
|
| 165 |
+
model: Model to test
|
| 166 |
+
batch_size: Batch size for test
|
| 167 |
+
seq_len: Sequence length for test
|
| 168 |
+
|
| 169 |
+
Returns:
|
| 170 |
+
dict: Test results
|
| 171 |
+
"""
|
| 172 |
+
print(f"\nπ Testing forward pass (batch={batch_size}, seq_len={seq_len})...")
|
| 173 |
+
|
| 174 |
+
try:
|
| 175 |
+
model.eval()
|
| 176 |
+
|
| 177 |
+
# Create synthetic input
|
| 178 |
+
x = torch.randint(0, model.config.vocab_size, (batch_size, seq_len))
|
| 179 |
+
x = x.to(self.device)
|
| 180 |
+
|
| 181 |
+
# Test inference mode
|
| 182 |
+
start_time = time.time()
|
| 183 |
+
with torch.no_grad():
|
| 184 |
+
logits, _ = model(x)
|
| 185 |
+
inference_time = time.time() - start_time
|
| 186 |
+
|
| 187 |
+
# Test training mode with targets
|
| 188 |
+
model.train()
|
| 189 |
+
targets = torch.randint(0, model.config.vocab_size, (batch_size, seq_len))
|
| 190 |
+
targets = targets.to(self.device)
|
| 191 |
+
|
| 192 |
+
start_time = time.time()
|
| 193 |
+
logits_train, loss = model(x, targets)
|
| 194 |
+
train_time = time.time() - start_time
|
| 195 |
+
|
| 196 |
+
print("β Forward pass successful")
|
| 197 |
+
print(f" Input shape: {x.shape}")
|
| 198 |
+
print(f" Output shape: {logits.shape}")
|
| 199 |
+
print(f" Loss: {loss.item():.4f}")
|
| 200 |
+
print(f" Inference time: {inference_time:.4f}s")
|
| 201 |
+
print(f" Training time: {train_time:.4f}s")
|
| 202 |
+
|
| 203 |
+
return {
|
| 204 |
+
"success": True,
|
| 205 |
+
"input_shape": list(x.shape),
|
| 206 |
+
"output_shape": list(logits.shape),
|
| 207 |
+
"loss": loss.item(),
|
| 208 |
+
"inference_time": inference_time,
|
| 209 |
+
"training_time": train_time,
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
except Exception as e:
|
| 213 |
+
print(f"β Forward pass failed: {e}")
|
| 214 |
+
traceback.print_exc()
|
| 215 |
+
return {"success": False, "error": str(e)}
|
| 216 |
+
|
| 217 |
+
def test_memory_usage(self, model: GPTModel, batch_sizes: List[int] = [1, 2, 4]) -> Dict:
|
| 218 |
+
"""
|
| 219 |
+
Test memory usage for different batch sizes.
|
| 220 |
+
|
| 221 |
+
Args:
|
| 222 |
+
model: Model to test
|
| 223 |
+
batch_sizes: List of batch sizes to test
|
| 224 |
+
|
| 225 |
+
Returns:
|
| 226 |
+
dict: Memory usage results
|
| 227 |
+
"""
|
| 228 |
+
print("\nπΎ Testing memory usage...")
|
| 229 |
+
|
| 230 |
+
results = {}
|
| 231 |
+
|
| 232 |
+
for batch_size in batch_sizes:
|
| 233 |
+
try:
|
| 234 |
+
# Clear cache
|
| 235 |
+
if torch.cuda.is_available():
|
| 236 |
+
torch.cuda.empty_cache()
|
| 237 |
+
|
| 238 |
+
# Get initial memory
|
| 239 |
+
if torch.cuda.is_available():
|
| 240 |
+
initial_memory = torch.cuda.memory_allocated() / (1024**2)
|
| 241 |
+
else:
|
| 242 |
+
initial_memory = 0
|
| 243 |
+
|
| 244 |
+
# Forward pass
|
| 245 |
+
seq_len = min(512, model.config.block_size)
|
| 246 |
+
x = torch.randint(0, model.config.vocab_size, (batch_size, seq_len))
|
| 247 |
+
x = x.to(self.device)
|
| 248 |
+
|
| 249 |
+
with torch.no_grad():
|
| 250 |
+
logits, _ = model(x)
|
| 251 |
+
|
| 252 |
+
# Get peak memory
|
| 253 |
+
if torch.cuda.is_available():
|
| 254 |
+
peak_memory = torch.cuda.max_memory_allocated() / (1024**2)
|
| 255 |
+
memory_used = peak_memory - initial_memory
|
| 256 |
+
else:
|
| 257 |
+
memory_used = model.estimate_memory_usage(batch_size, seq_len)[
|
| 258 |
+
"total_inference_mb"
|
| 259 |
+
]
|
| 260 |
+
|
| 261 |
+
results[f"batch_{batch_size}"] = {
|
| 262 |
+
"memory_mb": memory_used,
|
| 263 |
+
"memory_per_sample": memory_used / batch_size,
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
print(
|
| 267 |
+
f" Batch size {batch_size}: {memory_used:.1f}MB ({memory_used/batch_size:.1f}MB per sample)"
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
except Exception as e:
|
| 271 |
+
print(f" Batch size {batch_size}: Failed - {e}")
|
| 272 |
+
results[f"batch_{batch_size}"] = {"error": str(e)}
|
| 273 |
+
|
| 274 |
+
return results
|
| 275 |
+
|
| 276 |
+
def test_tokenizer_integration(self, model: GPTModel) -> Dict:
|
| 277 |
+
"""
|
| 278 |
+
Test integration with the trained tokenizer.
|
| 279 |
+
|
| 280 |
+
Args:
|
| 281 |
+
model: Model to test
|
| 282 |
+
|
| 283 |
+
Returns:
|
| 284 |
+
dict: Integration test results
|
| 285 |
+
"""
|
| 286 |
+
print("\nπ€ Testing tokenizer integration...")
|
| 287 |
+
|
| 288 |
+
if self.tokenizer is None:
|
| 289 |
+
print("β οΈ No tokenizer available, skipping integration test")
|
| 290 |
+
return {"success": False, "reason": "No tokenizer available"}
|
| 291 |
+
|
| 292 |
+
try:
|
| 293 |
+
# Test sentences
|
| 294 |
+
test_sentences = [
|
| 295 |
+
"The quick brown fox jumps over the lazy dog.",
|
| 296 |
+
"Machine learning is transforming technology.",
|
| 297 |
+
"GPT models use transformer architecture for language modeling.",
|
| 298 |
+
]
|
| 299 |
+
|
| 300 |
+
results = []
|
| 301 |
+
|
| 302 |
+
for sentence in test_sentences:
|
| 303 |
+
# Tokenize
|
| 304 |
+
tokens = self.tokenizer.encode(sentence)
|
| 305 |
+
token_tensor = torch.tensor([tokens]).to(self.device)
|
| 306 |
+
|
| 307 |
+
# Forward pass
|
| 308 |
+
with torch.no_grad():
|
| 309 |
+
logits, _ = model(token_tensor)
|
| 310 |
+
|
| 311 |
+
# Get predictions for next token
|
| 312 |
+
next_token_logits = logits[0, -1, :]
|
| 313 |
+
next_token_probs = torch.softmax(next_token_logits, dim=0)
|
| 314 |
+
top5_tokens = torch.topk(next_token_probs, 5)
|
| 315 |
+
|
| 316 |
+
# Decode top predictions
|
| 317 |
+
top5_decoded = []
|
| 318 |
+
for token_id in top5_tokens.indices:
|
| 319 |
+
try:
|
| 320 |
+
decoded = self.tokenizer.decode([token_id.item()])
|
| 321 |
+
prob = top5_tokens.values[len(top5_decoded)].item()
|
| 322 |
+
top5_decoded.append((decoded, prob))
|
| 323 |
+
except Exception:
|
| 324 |
+
top5_decoded.append(("<??>", 0.0))
|
| 325 |
+
|
| 326 |
+
results.append(
|
| 327 |
+
{"input": sentence, "tokens": len(tokens), "top_predictions": top5_decoded}
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
print(f"β '{sentence[:30]}...' -> {len(tokens)} tokens")
|
| 331 |
+
print(f" Top prediction: '{top5_decoded[0][0]}' ({top5_decoded[0][1]:.3f})")
|
| 332 |
+
|
| 333 |
+
return {
|
| 334 |
+
"success": True,
|
| 335 |
+
"vocab_size_match": self.tokenizer.vocab_size() == model.config.vocab_size,
|
| 336 |
+
"test_results": results,
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
except Exception as e:
|
| 340 |
+
print(f"β Tokenizer integration failed: {e}")
|
| 341 |
+
traceback.print_exc()
|
| 342 |
+
return {"success": False, "error": str(e)}
|
| 343 |
+
|
| 344 |
+
def test_generation(self, model: GPTModel, prompt: str = "The future of AI") -> Dict:
|
| 345 |
+
"""
|
| 346 |
+
Test text generation capabilities.
|
| 347 |
+
|
| 348 |
+
Args:
|
| 349 |
+
model: Model to test
|
| 350 |
+
prompt: Starting prompt for generation
|
| 351 |
+
|
| 352 |
+
Returns:
|
| 353 |
+
dict: Generation test results
|
| 354 |
+
"""
|
| 355 |
+
print("\nβοΈ Testing text generation...")
|
| 356 |
+
|
| 357 |
+
if self.tokenizer is None:
|
| 358 |
+
print("β οΈ No tokenizer available, skipping generation test")
|
| 359 |
+
return {"success": False, "reason": "No tokenizer available"}
|
| 360 |
+
|
| 361 |
+
try:
|
| 362 |
+
# Tokenize prompt
|
| 363 |
+
tokens = self.tokenizer.encode(prompt)
|
| 364 |
+
input_tensor = torch.tensor([tokens]).to(self.device)
|
| 365 |
+
|
| 366 |
+
print(f"Prompt: '{prompt}'")
|
| 367 |
+
print("Generating...")
|
| 368 |
+
|
| 369 |
+
# Generate
|
| 370 |
+
start_time = time.time()
|
| 371 |
+
output = model.generate(input_tensor, max_new_tokens=50, temperature=0.8, top_k=50)
|
| 372 |
+
generation_time = time.time() - start_time
|
| 373 |
+
|
| 374 |
+
# Decode output
|
| 375 |
+
generated_tokens = output[0].tolist()
|
| 376 |
+
generated_text = self.tokenizer.decode(generated_tokens)
|
| 377 |
+
|
| 378 |
+
print(f"β Generated text: '{generated_text}'")
|
| 379 |
+
print(f" Generation time: {generation_time:.2f}s")
|
| 380 |
+
print(f" Tokens per second: {50/generation_time:.1f}")
|
| 381 |
+
|
| 382 |
+
return {
|
| 383 |
+
"success": True,
|
| 384 |
+
"prompt": prompt,
|
| 385 |
+
"generated_text": generated_text,
|
| 386 |
+
"generation_time": generation_time,
|
| 387 |
+
"tokens_per_second": 50 / generation_time,
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
except Exception as e:
|
| 391 |
+
print(f"β Text generation failed: {e}")
|
| 392 |
+
traceback.print_exc()
|
| 393 |
+
return {"success": False, "error": str(e)}
|
| 394 |
+
|
| 395 |
+
def run_comprehensive_test(self, model_size: str = "medium") -> Dict:
|
| 396 |
+
"""
|
| 397 |
+
Run all tests for a given model size.
|
| 398 |
+
|
| 399 |
+
Args:
|
| 400 |
+
model_size: Size of model to test
|
| 401 |
+
|
| 402 |
+
Returns:
|
| 403 |
+
dict: Complete test results
|
| 404 |
+
"""
|
| 405 |
+
print(f"\nπ Running comprehensive test for {model_size.upper()} model")
|
| 406 |
+
print("=" * 60)
|
| 407 |
+
|
| 408 |
+
results = {"model_size": model_size, "device": self.device}
|
| 409 |
+
|
| 410 |
+
# Test 1: Model initialization
|
| 411 |
+
init_result = self.test_model_initialization(model_size)
|
| 412 |
+
results["initialization"] = init_result
|
| 413 |
+
|
| 414 |
+
if not init_result["success"]:
|
| 415 |
+
return results
|
| 416 |
+
|
| 417 |
+
# Create model for remaining tests
|
| 418 |
+
model = create_model(model_size).to(self.device)
|
| 419 |
+
|
| 420 |
+
# Test 2: Forward pass
|
| 421 |
+
results["forward_pass"] = self.test_forward_pass(model)
|
| 422 |
+
|
| 423 |
+
# Test 3: Memory usage
|
| 424 |
+
results["memory_usage"] = self.test_memory_usage(model)
|
| 425 |
+
|
| 426 |
+
# Test 4: Tokenizer integration
|
| 427 |
+
results["tokenizer_integration"] = self.test_tokenizer_integration(model)
|
| 428 |
+
|
| 429 |
+
# Test 5: Text generation
|
| 430 |
+
results["generation"] = self.test_generation(model)
|
| 431 |
+
|
| 432 |
+
return results
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def load_model_config(model_size: str) -> Dict:
|
| 436 |
+
"""Load model configuration from JSON file."""
|
| 437 |
+
config_path = f"configs/{model_size}_model.json"
|
| 438 |
+
if os.path.exists(config_path):
|
| 439 |
+
with open(config_path, "r") as f:
|
| 440 |
+
return json.load(f)
|
| 441 |
+
return {}
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def print_hardware_recommendations(model_size: str) -> None:
|
| 445 |
+
"""Print hardware recommendations for the given model size."""
|
| 446 |
+
config = load_model_config(model_size)
|
| 447 |
+
|
| 448 |
+
if config:
|
| 449 |
+
print(f"\nπ» Hardware Recommendations for {model_size.upper()} model:")
|
| 450 |
+
print(f" Parameters: {config.get('parameters', 'Unknown')}")
|
| 451 |
+
print(f" Recommended: {config.get('recommended_hardware', 'Unknown')}")
|
| 452 |
+
|
| 453 |
+
if "memory_estimates" in config:
|
| 454 |
+
mem = config["memory_estimates"]
|
| 455 |
+
print(f" Memory usage: ~{mem.get('parameters_mb', '?')}MB parameters")
|
| 456 |
+
print(f" Training: ~{mem.get('training_mb_per_sample', '?')}MB per sample")
|
| 457 |
+
print(f" Inference: ~{mem.get('inference_mb_per_sample', '?')}MB per sample")
|
| 458 |
+
|
| 459 |
+
if "cpu_training_notes" in config:
|
| 460 |
+
cpu_notes = config["cpu_training_notes"]
|
| 461 |
+
if cpu_notes.get("feasible"):
|
| 462 |
+
print(
|
| 463 |
+
f" CPU Training: Feasible but slow ({cpu_notes.get('expected_training_time', '?')})"
|
| 464 |
+
)
|
| 465 |
+
else:
|
| 466 |
+
print(f" CPU Training: Not recommended - {cpu_notes.get('reason', 'Too large')}")
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
def main():
|
| 470 |
+
"""Main function to handle command line testing."""
|
| 471 |
+
parser = argparse.ArgumentParser(
|
| 472 |
+
description="Test and validate GPT model architecture",
|
| 473 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 474 |
+
epilog="""
|
| 475 |
+
Examples:
|
| 476 |
+
# Test medium model
|
| 477 |
+
python core/src/test_model.py --model_size medium
|
| 478 |
+
|
| 479 |
+
# Test all model sizes
|
| 480 |
+
python core/src/test_model.py --all_sizes
|
| 481 |
+
|
| 482 |
+
# Test with text generation
|
| 483 |
+
python core/src/test_model.py --model_size small --test_generation
|
| 484 |
+
|
| 485 |
+
# Show hardware recommendations
|
| 486 |
+
python core/src/test_model.py --recommendations
|
| 487 |
+
""",
|
| 488 |
+
)
|
| 489 |
+
|
| 490 |
+
parser.add_argument(
|
| 491 |
+
"--model_size",
|
| 492 |
+
choices=["small", "medium", "large"],
|
| 493 |
+
default="medium",
|
| 494 |
+
help="Model size to test (default: medium)",
|
| 495 |
+
)
|
| 496 |
+
|
| 497 |
+
parser.add_argument("--all_sizes", action="store_true", help="Test all model sizes")
|
| 498 |
+
|
| 499 |
+
parser.add_argument(
|
| 500 |
+
"--test_generation", action="store_true", help="Include text generation test"
|
| 501 |
+
)
|
| 502 |
+
|
| 503 |
+
parser.add_argument(
|
| 504 |
+
"--device",
|
| 505 |
+
choices=["cpu", "cuda", "auto"],
|
| 506 |
+
default="auto",
|
| 507 |
+
help="Device to use for testing (default: auto)",
|
| 508 |
+
)
|
| 509 |
+
|
| 510 |
+
parser.add_argument(
|
| 511 |
+
"--recommendations",
|
| 512 |
+
action="store_true",
|
| 513 |
+
help="Show hardware recommendations for all model sizes",
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
parser.add_argument("--save_results", help="Save test results to JSON file")
|
| 517 |
+
|
| 518 |
+
args = parser.parse_args()
|
| 519 |
+
|
| 520 |
+
print("π§ͺ GPT Model Architecture Tester")
|
| 521 |
+
print("=" * 50)
|
| 522 |
+
|
| 523 |
+
# Show hardware recommendations
|
| 524 |
+
if args.recommendations:
|
| 525 |
+
for size in ["small", "medium", "large"]:
|
| 526 |
+
print_hardware_recommendations(size)
|
| 527 |
+
return
|
| 528 |
+
|
| 529 |
+
# Initialize tester
|
| 530 |
+
tester = ModelTester(device=args.device)
|
| 531 |
+
|
| 532 |
+
# Run tests
|
| 533 |
+
all_results = {}
|
| 534 |
+
|
| 535 |
+
if args.all_sizes:
|
| 536 |
+
test_sizes = ["small", "medium", "large"]
|
| 537 |
+
else:
|
| 538 |
+
test_sizes = [args.model_size]
|
| 539 |
+
|
| 540 |
+
for size in test_sizes:
|
| 541 |
+
results = tester.run_comprehensive_test(size)
|
| 542 |
+
all_results[size] = results
|
| 543 |
+
|
| 544 |
+
# Print summary
|
| 545 |
+
print(f"\nπ {size.upper()} Model Test Summary:")
|
| 546 |
+
print(f" Initialization: {'β' if results['initialization']['success'] else 'β'}")
|
| 547 |
+
print(f" Forward Pass: {'β' if results.get('forward_pass', {}).get('success') else 'β'}")
|
| 548 |
+
print(f" Memory Test: {'β' if 'memory_usage' in results else 'β'}")
|
| 549 |
+
print(
|
| 550 |
+
f" Tokenizer: {'β' if results.get('tokenizer_integration', {}).get('success') else 'β'}"
|
| 551 |
+
)
|
| 552 |
+
print(f" Generation: {'β' if results.get('generation', {}).get('success') else 'β'}")
|
| 553 |
+
|
| 554 |
+
# Save results if requested
|
| 555 |
+
if args.save_results:
|
| 556 |
+
with open(args.save_results, "w") as f:
|
| 557 |
+
json.dump(all_results, f, indent=2)
|
| 558 |
+
print(f"\nπΎ Results saved to {args.save_results}")
|
| 559 |
+
|
| 560 |
+
print("\nπ Testing completed!")
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
if __name__ == "__main__":
|
| 564 |
+
main()
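
# Hedged usage sketch (added for illustration; not part of the uploaded file):
# driving the tester from Python instead of the CLI shown in the epilog above.
# The import path `core.src.model_test` is an assumption about the package layout.
#
#     from core.src.model_test import ModelTester
#
#     tester = ModelTester(device="cpu")
#     report = tester.run_comprehensive_test("small")
#     print(report.get("forward_pass", {}).get("success"))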
|
core/src/optimized_data_loader.py
ADDED
|
@@ -0,0 +1,437 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Optimized Data Loader for Training
|
| 4 |
+
|
| 5 |
+
This module provides an optimized data loader with prefetching, caching,
|
| 6 |
+
and efficient batch processing to improve training performance.
|
| 7 |
+
|
| 8 |
+
Author: Louis Chua Bean Chong
|
| 9 |
+
License: GPLv3
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from torch.utils.data import DataLoader, Dataset, Sampler
|
| 15 |
+
from typing import Optional, List, Tuple, Dict, Any
|
| 16 |
+
import numpy as np
|
| 17 |
+
import threading
|
| 18 |
+
import queue
|
| 19 |
+
import time
|
| 20 |
+
from collections import deque
|
| 21 |
+
import psutil
|
| 22 |
+
import os
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class OptimizedDataset(Dataset):
|
| 26 |
+
"""
|
| 27 |
+
Optimized dataset with caching and memory management.
|
| 28 |
+
|
| 29 |
+
This dataset provides efficient data loading with optional caching
|
| 30 |
+
and memory management to improve training performance.
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(self,
|
| 34 |
+
data: torch.Tensor,
|
| 35 |
+
targets: torch.Tensor,
|
| 36 |
+
cache_size: Optional[int] = None,
|
| 37 |
+
pin_memory: bool = True):
|
| 38 |
+
"""
|
| 39 |
+
Initialize optimized dataset.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
data: Input data tensor
|
| 43 |
+
targets: Target tensor
|
| 44 |
+
cache_size: Number of samples to cache in memory
|
| 45 |
+
pin_memory: Whether to pin memory for faster GPU transfer
|
| 46 |
+
"""
|
| 47 |
+
self.data = data
|
| 48 |
+
self.targets = targets
|
| 49 |
+
self.cache_size = cache_size
|
| 50 |
+
self.pin_memory = pin_memory
|
| 51 |
+
|
| 52 |
+
# Initialize cache
|
| 53 |
+
self.cache = {}
|
| 54 |
+
self.cache_hits = 0
|
| 55 |
+
self.cache_misses = 0
|
| 56 |
+
|
| 57 |
+
if cache_size and cache_size > 0:
|
| 58 |
+
print(f"Initializing cache with {cache_size} samples")
|
| 59 |
+
|
| 60 |
+
def __len__(self):
|
| 61 |
+
return len(self.data)
|
| 62 |
+
|
| 63 |
+
def __getitem__(self, idx):
|
| 64 |
+
# Check cache first
|
| 65 |
+
if self.cache_size and idx in self.cache:
|
| 66 |
+
self.cache_hits += 1
|
| 67 |
+
return self.cache[idx]
|
| 68 |
+
|
| 69 |
+
self.cache_misses += 1
|
| 70 |
+
|
| 71 |
+
# Get data
|
| 72 |
+
sample_data = self.data[idx]
|
| 73 |
+
sample_target = self.targets[idx]
|
| 74 |
+
|
| 75 |
+
# Pin memory if requested
|
| 76 |
+
if self.pin_memory and torch.cuda.is_available():
|
| 77 |
+
sample_data = sample_data.pin_memory()
|
| 78 |
+
sample_target = sample_target.pin_memory()
|
| 79 |
+
|
| 80 |
+
# Cache if enabled
|
| 81 |
+
if self.cache_size and len(self.cache) < self.cache_size:
|
| 82 |
+
self.cache[idx] = (sample_data, sample_target)
|
| 83 |
+
|
| 84 |
+
return sample_data, sample_target
|
| 85 |
+
|
| 86 |
+
def get_cache_stats(self) -> Dict[str, Any]:
|
| 87 |
+
"""Get cache statistics."""
|
| 88 |
+
total_requests = self.cache_hits + self.cache_misses
|
| 89 |
+
hit_rate = self.cache_hits / total_requests if total_requests > 0 else 0
|
| 90 |
+
|
| 91 |
+
return {
|
| 92 |
+
"cache_hits": self.cache_hits,
|
| 93 |
+
"cache_misses": self.cache_misses,
|
| 94 |
+
"hit_rate": hit_rate,
|
| 95 |
+
"cache_size": len(self.cache),
|
| 96 |
+
"max_cache_size": self.cache_size
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class PrefetchDataLoader:
|
| 101 |
+
"""
|
| 102 |
+
Data loader with prefetching for improved performance.
|
| 103 |
+
|
| 104 |
+
This data loader uses background threads to prefetch data,
|
| 105 |
+
reducing the time spent waiting for data during training.
|
| 106 |
+
"""
|
| 107 |
+
|
| 108 |
+
def __init__(self,
|
| 109 |
+
dataset: Dataset,
|
| 110 |
+
batch_size: int = 32,
|
| 111 |
+
num_workers: int = 4,
|
| 112 |
+
prefetch_factor: int = 2,
|
| 113 |
+
pin_memory: bool = True,
|
| 114 |
+
shuffle: bool = True,
|
| 115 |
+
drop_last: bool = False):
|
| 116 |
+
"""
|
| 117 |
+
Initialize prefetch data loader.
|
| 118 |
+
|
| 119 |
+
Args:
|
| 120 |
+
dataset: Dataset to load
|
| 121 |
+
batch_size: Batch size
|
| 122 |
+
num_workers: Number of worker processes
|
| 123 |
+
prefetch_factor: Number of batches to prefetch
|
| 124 |
+
pin_memory: Whether to pin memory
|
| 125 |
+
shuffle: Whether to shuffle data
|
| 126 |
+
drop_last: Whether to drop incomplete batches
|
| 127 |
+
"""
|
| 128 |
+
self.dataset = dataset
|
| 129 |
+
self.batch_size = batch_size
|
| 130 |
+
self.num_workers = num_workers
|
| 131 |
+
self.prefetch_factor = prefetch_factor
|
| 132 |
+
self.pin_memory = pin_memory
|
| 133 |
+
self.shuffle = shuffle
|
| 134 |
+
self.drop_last = drop_last
|
| 135 |
+
|
| 136 |
+
# Initialize data loader
|
| 137 |
+
self.data_loader = DataLoader(
|
| 138 |
+
dataset=dataset,
|
| 139 |
+
batch_size=batch_size,
|
| 140 |
+
shuffle=shuffle,
|
| 141 |
+
num_workers=num_workers,
|
| 142 |
+
pin_memory=pin_memory,
|
| 143 |
+
drop_last=drop_last,
|
| 144 |
+
persistent_workers=True if num_workers > 0 else False
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
# Prefetch queue
|
| 148 |
+
self.prefetch_queue = queue.Queue(maxsize=prefetch_factor)
|
| 149 |
+
self.prefetch_thread = None
|
| 150 |
+
self.stop_prefetch = False
|
| 151 |
+
|
| 152 |
+
# Start prefetching
|
| 153 |
+
self._start_prefetch()
|
| 154 |
+
|
| 155 |
+
print(f"PrefetchDataLoader initialized with {num_workers} workers")
|
| 156 |
+
|
| 157 |
+
def _start_prefetch(self):
|
| 158 |
+
"""Start prefetching thread."""
|
| 159 |
+
if self.prefetch_factor > 0:
|
| 160 |
+
self.prefetch_thread = threading.Thread(target=self._prefetch_worker)
|
| 161 |
+
self.prefetch_thread.daemon = True
|
| 162 |
+
self.prefetch_thread.start()
|
| 163 |
+
|
| 164 |
+
def _prefetch_worker(self):
|
| 165 |
+
"""Worker thread for prefetching data."""
|
| 166 |
+
try:
|
| 167 |
+
for batch in self.data_loader:
|
| 168 |
+
if self.stop_prefetch:
|
| 169 |
+
break
|
| 170 |
+
|
| 171 |
+
# Put batch in queue (block if full)
|
| 172 |
+
self.prefetch_queue.put(batch, block=True)
|
| 173 |
+
except Exception as e:
|
| 174 |
+
print(f"Prefetch worker error: {e}")
|
| 175 |
+
|
| 176 |
+
def __iter__(self):
|
| 177 |
+
"""Iterate over prefetched batches."""
|
| 178 |
+
return self
|
| 179 |
+
|
| 180 |
+
def __next__(self):
|
| 181 |
+
"""Get next batch from prefetch queue."""
|
| 182 |
+
if self.stop_prefetch:
|
| 183 |
+
raise StopIteration
|
| 184 |
+
|
| 185 |
+
try:
|
| 186 |
+
# Get batch from prefetch queue
|
| 187 |
+
batch = self.prefetch_queue.get(timeout=1.0)
|
| 188 |
+
return batch
|
| 189 |
+
except queue.Empty:
|
| 190 |
+
# If queue is empty, get directly from data loader
|
| 191 |
+
return next(self.data_loader.__iter__())
|
| 192 |
+
|
| 193 |
+
def __len__(self):
|
| 194 |
+
return len(self.data_loader)
|
| 195 |
+
|
| 196 |
+
def stop(self):
|
| 197 |
+
"""Stop prefetching."""
|
| 198 |
+
self.stop_prefetch = True
|
| 199 |
+
if self.prefetch_thread:
|
| 200 |
+
self.prefetch_thread.join()
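
# Hedged usage sketch (illustration only): reading the prefetched batches and then
# stopping the background worker thread. Any map-style dataset works; the
# OptimizedDataset above is used here purely as an example.
#
#     dataset = OptimizedDataset(
#         torch.randint(0, 1000, (64, 32)), torch.randint(0, 1000, (64, 32)),
#         pin_memory=False,
#     )
#     loader = PrefetchDataLoader(dataset, batch_size=32, num_workers=0)
#     for _, (xb, yb) in zip(range(len(loader)), loader):
#         print(xb.shape)                              # torch.Size([32, 32])
#     loader.stop()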
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class DynamicBatchSampler(Sampler):
|
| 204 |
+
"""
|
| 205 |
+
Dynamic batch sampler that adjusts batch size based on memory availability.
|
| 206 |
+
|
| 207 |
+
This sampler monitors system memory and adjusts batch sizes dynamically
|
| 208 |
+
to optimize memory usage and training performance.
|
| 209 |
+
"""
|
| 210 |
+
|
| 211 |
+
def __init__(self,
|
| 212 |
+
dataset_size: int,
|
| 213 |
+
base_batch_size: int = 32,
|
| 214 |
+
max_batch_size: int = 128,
|
| 215 |
+
memory_threshold: float = 0.8,
|
| 216 |
+
adjustment_factor: float = 1.2):
|
| 217 |
+
"""
|
| 218 |
+
Initialize dynamic batch sampler.
|
| 219 |
+
|
| 220 |
+
Args:
|
| 221 |
+
dataset_size: Size of the dataset
|
| 222 |
+
base_batch_size: Base batch size
|
| 223 |
+
max_batch_size: Maximum batch size
|
| 224 |
+
memory_threshold: Memory usage threshold for adjustment
|
| 225 |
+
adjustment_factor: Factor for batch size adjustment
|
| 226 |
+
"""
|
| 227 |
+
self.dataset_size = dataset_size
|
| 228 |
+
self.base_batch_size = base_batch_size
|
| 229 |
+
self.max_batch_size = max_batch_size
|
| 230 |
+
self.memory_threshold = memory_threshold
|
| 231 |
+
self.adjustment_factor = adjustment_factor
|
| 232 |
+
|
| 233 |
+
self.current_batch_size = base_batch_size
|
| 234 |
+
self.batch_history = deque(maxlen=10)
|
| 235 |
+
|
| 236 |
+
print(f"DynamicBatchSampler initialized with base batch size: {base_batch_size}")
|
| 237 |
+
|
| 238 |
+
def _get_memory_usage(self) -> float:
|
| 239 |
+
"""Get current memory usage as a fraction."""
|
| 240 |
+
memory = psutil.virtual_memory()
|
| 241 |
+
return memory.percent / 100.0
|
| 242 |
+
|
| 243 |
+
def _adjust_batch_size(self):
|
| 244 |
+
"""Adjust batch size based on memory usage."""
|
| 245 |
+
memory_usage = self._get_memory_usage()
|
| 246 |
+
|
| 247 |
+
if memory_usage > self.memory_threshold:
|
| 248 |
+
# Reduce batch size if memory usage is high
|
| 249 |
+
self.current_batch_size = max(
|
| 250 |
+
self.base_batch_size,
|
| 251 |
+
int(self.current_batch_size / self.adjustment_factor)
|
| 252 |
+
)
|
| 253 |
+
else:
|
| 254 |
+
# Increase batch size if memory usage is low
|
| 255 |
+
self.current_batch_size = min(
|
| 256 |
+
self.max_batch_size,
|
| 257 |
+
int(self.current_batch_size * self.adjustment_factor)
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
self.batch_history.append(self.current_batch_size)
|
| 261 |
+
|
| 262 |
+
def __iter__(self):
|
| 263 |
+
"""Generate batch indices."""
|
| 264 |
+
indices = list(range(self.dataset_size))
|
| 265 |
+
|
| 266 |
+
# Shuffle indices
|
| 267 |
+
np.random.shuffle(indices)
|
| 268 |
+
|
| 269 |
+
# Generate batches
|
| 270 |
+
for i in range(0, len(indices), self.current_batch_size):
|
| 271 |
+
batch_indices = indices[i:i + self.current_batch_size]
|
| 272 |
+
|
| 273 |
+
# Adjust batch size for next iteration
|
| 274 |
+
self._adjust_batch_size()
|
| 275 |
+
|
| 276 |
+
yield batch_indices
|
| 277 |
+
|
| 278 |
+
def __len__(self):
|
| 279 |
+
return (self.dataset_size + self.current_batch_size - 1) // self.current_batch_size
|
| 280 |
+
|
| 281 |
+
def get_stats(self) -> Dict[str, Any]:
|
| 282 |
+
"""Get sampler statistics."""
|
| 283 |
+
return {
|
| 284 |
+
"current_batch_size": self.current_batch_size,
|
| 285 |
+
"base_batch_size": self.base_batch_size,
|
| 286 |
+
"max_batch_size": self.max_batch_size,
|
| 287 |
+
"memory_usage": self._get_memory_usage(),
|
| 288 |
+
"batch_history": list(self.batch_history)
|
| 289 |
+
}
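
# Hedged usage sketch (illustration only): the sampler yields lists of indices whose
# length adapts to current memory pressure, and get_stats() exposes the adjustment
# history. The sizes below are arbitrary example values.
#
#     sampler = DynamicBatchSampler(dataset_size=1_000, base_batch_size=16, max_batch_size=64)
#     first_batch = next(iter(sampler))                # a list of dataset indices
#     print(len(first_batch), sampler.get_stats()["current_batch_size"])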
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
class OptimizedDataLoader:
|
| 293 |
+
"""
|
| 294 |
+
High-performance data loader with multiple optimizations.
|
| 295 |
+
|
| 296 |
+
This data loader combines multiple optimization techniques:
|
| 297 |
+
- Prefetching with background threads
|
| 298 |
+
- Dynamic batch sizing
|
| 299 |
+
- Memory pinning
|
| 300 |
+
- Caching
|
| 301 |
+
- Efficient memory management
|
| 302 |
+
"""
|
| 303 |
+
|
| 304 |
+
def __init__(self,
|
| 305 |
+
dataset: Dataset,
|
| 306 |
+
batch_size: int = 32,
|
| 307 |
+
num_workers: int = 4,
|
| 308 |
+
prefetch_factor: int = 2,
|
| 309 |
+
pin_memory: bool = True,
|
| 310 |
+
shuffle: bool = True,
|
| 311 |
+
drop_last: bool = False,
|
| 312 |
+
use_dynamic_batching: bool = True,
|
| 313 |
+
cache_size: Optional[int] = None):
|
| 314 |
+
"""
|
| 315 |
+
Initialize optimized data loader.
|
| 316 |
+
|
| 317 |
+
Args:
|
| 318 |
+
dataset: Dataset to load
|
| 319 |
+
batch_size: Base batch size
|
| 320 |
+
num_workers: Number of worker processes
|
| 321 |
+
prefetch_factor: Number of batches to prefetch
|
| 322 |
+
pin_memory: Whether to pin memory
|
| 323 |
+
shuffle: Whether to shuffle data
|
| 324 |
+
drop_last: Whether to drop incomplete batches
|
| 325 |
+
use_dynamic_batching: Whether to use dynamic batch sizing
|
| 326 |
+
cache_size: Number of samples to cache
|
| 327 |
+
"""
|
| 328 |
+
self.dataset = dataset
|
| 329 |
+
self.batch_size = batch_size
|
| 330 |
+
self.num_workers = num_workers
|
| 331 |
+
self.prefetch_factor = prefetch_factor
|
| 332 |
+
self.pin_memory = pin_memory
|
| 333 |
+
self.shuffle = shuffle
|
| 334 |
+
self.drop_last = drop_last
|
| 335 |
+
self.use_dynamic_batching = use_dynamic_batching
|
| 336 |
+
self.cache_size = cache_size
|
| 337 |
+
|
| 338 |
+
# Create optimized dataset if caching is enabled
|
| 339 |
+
if cache_size and cache_size > 0:
|
| 340 |
+
self.dataset = OptimizedDataset(
|
| 341 |
+
dataset.data if hasattr(dataset, 'data') else dataset,
|
| 342 |
+
dataset.targets if hasattr(dataset, 'targets') else None,
|
| 343 |
+
cache_size=cache_size,
|
| 344 |
+
pin_memory=pin_memory
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
# Create sampler
|
| 348 |
+
if use_dynamic_batching:
|
| 349 |
+
self.sampler = DynamicBatchSampler(
|
| 350 |
+
dataset_size=len(self.dataset),
|
| 351 |
+
base_batch_size=batch_size,
|
| 352 |
+
max_batch_size=batch_size * 4
|
| 353 |
+
)
|
| 354 |
+
else:
|
| 355 |
+
self.sampler = None
|
| 356 |
+
|
| 357 |
+
# Create data loader (the dynamic sampler yields whole batches of indices,
# so it must be passed as batch_sampler; batch size, shuffling, and drop_last
# are then handled by the sampler itself)
if use_dynamic_batching:
    self.data_loader = DataLoader(
        dataset=self.dataset,
        batch_sampler=self.sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=True if num_workers > 0 else False,
    )
else:
    self.data_loader = DataLoader(
        dataset=self.dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
        persistent_workers=True if num_workers > 0 else False,
    )
|
| 368 |
+
|
| 369 |
+
# Create prefetch loader
|
| 370 |
+
self.prefetch_loader = PrefetchDataLoader(
|
| 371 |
+
dataset=self.dataset,
|
| 372 |
+
batch_size=batch_size,
|
| 373 |
+
num_workers=num_workers,
|
| 374 |
+
prefetch_factor=prefetch_factor,
|
| 375 |
+
pin_memory=pin_memory,
|
| 376 |
+
shuffle=shuffle,
|
| 377 |
+
drop_last=drop_last
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
print(f"OptimizedDataLoader initialized with {num_workers} workers")
|
| 381 |
+
|
| 382 |
+
def __iter__(self):
|
| 383 |
+
"""Iterate over batches."""
|
| 384 |
+
return iter(self.prefetch_loader)
|
| 385 |
+
|
| 386 |
+
def __len__(self):
|
| 387 |
+
return len(self.data_loader)
|
| 388 |
+
|
| 389 |
+
def get_stats(self) -> Dict[str, Any]:
|
| 390 |
+
"""Get loader statistics."""
|
| 391 |
+
stats = {
|
| 392 |
+
"batch_size": self.batch_size,
|
| 393 |
+
"num_workers": self.num_workers,
|
| 394 |
+
"prefetch_factor": self.prefetch_factor,
|
| 395 |
+
"cache_enabled": self.cache_size is not None,
|
| 396 |
+
"dynamic_batching": self.use_dynamic_batching
|
| 397 |
+
}
|
| 398 |
+
|
| 399 |
+
if hasattr(self.dataset, 'get_cache_stats'):
|
| 400 |
+
stats.update(self.dataset.get_cache_stats())
|
| 401 |
+
|
| 402 |
+
if self.sampler:
|
| 403 |
+
stats.update(self.sampler.get_stats())
|
| 404 |
+
|
| 405 |
+
return stats
|
| 406 |
+
|
| 407 |
+
def stop(self):
|
| 408 |
+
"""Stop the data loader."""
|
| 409 |
+
self.prefetch_loader.stop()
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
def create_optimized_loader(dataset: Dataset,
|
| 413 |
+
batch_size: int = 32,
|
| 414 |
+
num_workers: Optional[int] = None,
|
| 415 |
+
**kwargs) -> OptimizedDataLoader:
|
| 416 |
+
"""
|
| 417 |
+
Create an optimized data loader with automatic configuration.
|
| 418 |
+
|
| 419 |
+
Args:
|
| 420 |
+
dataset: Dataset to load
|
| 421 |
+
batch_size: Batch size
|
| 422 |
+
num_workers: Number of workers (auto-detect if None)
|
| 423 |
+
**kwargs: Additional arguments
|
| 424 |
+
|
| 425 |
+
Returns:
|
| 426 |
+
OptimizedDataLoader: Configured data loader
|
| 427 |
+
"""
|
| 428 |
+
if num_workers is None:
|
| 429 |
+
# Auto-detect optimal number of workers
|
| 430 |
+
num_workers = min(4, os.cpu_count() or 1)
|
| 431 |
+
|
| 432 |
+
return OptimizedDataLoader(
|
| 433 |
+
dataset=dataset,
|
| 434 |
+
batch_size=batch_size,
|
| 435 |
+
num_workers=num_workers,
|
| 436 |
+
**kwargs
|
| 437 |
+
)
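
# Hedged usage sketch (illustration only): building a loader with automatic worker
# detection and inspecting its combined statistics. Dynamic batching is switched off
# here so the example stays close to a plain DataLoader; sizes are arbitrary.
#
#     dataset = OptimizedDataset(
#         torch.randint(0, 1000, (64, 64)), torch.randint(0, 1000, (64, 64)),
#         pin_memory=False,
#     )
#     loader = create_optimized_loader(dataset, batch_size=32, use_dynamic_batching=False)
#     print(loader.get_stats()["batch_size"])          # 32
#     loader.stop()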
|
core/src/optimized_inference_server.py
ADDED
|
@@ -0,0 +1,739 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Optimized OpenLLM Inference Server
|
| 4 |
+
|
| 5 |
+
This module provides an optimized inference server with:
|
| 6 |
+
- Model caching and memory management
|
| 7 |
+
- Request batching for improved throughput
|
| 8 |
+
- Response streaming for real-time generation
|
| 9 |
+
- Performance monitoring and metrics
|
| 10 |
+
- Load balancing and concurrent processing
|
| 11 |
+
|
| 12 |
+
Author: Louis Chua Bean Chong
|
| 13 |
+
License: GPLv3
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import asyncio
|
| 17 |
+
import json
|
| 18 |
+
import time
|
| 19 |
+
import threading
|
| 20 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 21 |
+
from typing import Optional, List, Dict, Any, AsyncGenerator
|
| 22 |
+
from collections import deque
|
| 23 |
+
import torch
|
| 24 |
+
import torch.nn.functional as F
|
| 25 |
+
from fastapi import FastAPI, HTTPException, BackgroundTasks
|
| 26 |
+
from fastapi.responses import StreamingResponse
|
| 27 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 28 |
+
from pydantic import BaseModel, Field
|
| 29 |
+
import uvicorn
|
| 30 |
+
import logging
|
| 31 |
+
import psutil
|
| 32 |
+
import os
|
| 33 |
+
import sys
|
| 34 |
+
from pathlib import Path
|
| 35 |
+
|
| 36 |
+
# Add current directory to path for imports
|
| 37 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 38 |
+
|
| 39 |
+
from model import GPTConfig, GPTModel
|
| 40 |
+
from quantization import QuantizedModel, quantize_model_dynamic
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# Configure logging
|
| 44 |
+
logging.basicConfig(level=logging.INFO)
|
| 45 |
+
logger = logging.getLogger(__name__)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class OptimizedInferenceEngine:
|
| 49 |
+
"""
|
| 50 |
+
Optimized inference engine with caching and batching.
|
| 51 |
+
|
| 52 |
+
This engine provides high-performance inference with:
|
| 53 |
+
- Model caching and memory management
|
| 54 |
+
- Request batching for improved throughput
|
| 55 |
+
- Quantization support for reduced memory usage
|
| 56 |
+
- Performance monitoring and metrics
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
def __init__(self,
|
| 60 |
+
model_path: str,
|
| 61 |
+
device: str = "auto",
|
| 62 |
+
use_quantization: bool = True,
|
| 63 |
+
cache_size: int = 1000,
|
| 64 |
+
max_batch_size: int = 32,
|
| 65 |
+
num_workers: int = 4):
|
| 66 |
+
"""
|
| 67 |
+
Initialize optimized inference engine.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
model_path: Path to the model
|
| 71 |
+
device: Device to use ("auto", "cpu", "cuda")
|
| 72 |
+
use_quantization: Whether to use quantization
|
| 73 |
+
cache_size: Size of response cache
|
| 74 |
+
max_batch_size: Maximum batch size for processing
|
| 75 |
+
num_workers: Number of worker threads
|
| 76 |
+
"""
|
| 77 |
+
self.model_path = model_path
|
| 78 |
+
self.device = self._get_device(device)
|
| 79 |
+
self.use_quantization = use_quantization
|
| 80 |
+
self.cache_size = cache_size
|
| 81 |
+
self.max_batch_size = max_batch_size
|
| 82 |
+
self.num_workers = num_workers
|
| 83 |
+
|
| 84 |
+
# Initialize components
|
| 85 |
+
self.model = None
|
| 86 |
+
self.tokenizer = None
|
| 87 |
+
self.quantized_model = None
|
| 88 |
+
self.response_cache = {}
|
| 89 |
+
self.request_queue = deque()
|
| 90 |
+
self.processing_lock = threading.Lock()
|
| 91 |
+
|
| 92 |
+
# Performance metrics
|
| 93 |
+
self.metrics = {
|
| 94 |
+
"total_requests": 0,
|
| 95 |
+
"cache_hits": 0,
|
| 96 |
+
"cache_misses": 0,
|
| 97 |
+
"avg_generation_time": 0.0,
|
| 98 |
+
"total_generation_time": 0.0,
|
| 99 |
+
"requests_per_second": 0.0
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
# Thread pool for concurrent processing
|
| 103 |
+
self.executor = ThreadPoolExecutor(max_workers=num_workers)
|
| 104 |
+
|
| 105 |
+
# Load model
|
| 106 |
+
self._load_model()
|
| 107 |
+
|
| 108 |
+
logger.info(f"OptimizedInferenceEngine initialized on {self.device}")
|
| 109 |
+
|
| 110 |
+
def _get_device(self, device: str) -> torch.device:
|
| 111 |
+
"""Get the appropriate device."""
|
| 112 |
+
if device == "auto":
|
| 113 |
+
if torch.cuda.is_available():
|
| 114 |
+
return torch.device("cuda")
|
| 115 |
+
else:
|
| 116 |
+
return torch.device("cpu")
|
| 117 |
+
else:
|
| 118 |
+
return torch.device(device)
|
| 119 |
+
|
| 120 |
+
def _load_model(self):
|
| 121 |
+
"""Load and optimize the model."""
|
| 122 |
+
try:
|
| 123 |
+
logger.info(f"Loading model from {self.model_path}")
|
| 124 |
+
|
| 125 |
+
# Load model configuration
|
| 126 |
+
config_path = Path(self.model_path) / "config.json"
|
| 127 |
+
if config_path.exists():
|
| 128 |
+
with open(config_path, 'r') as f:
|
| 129 |
+
config_data = json.load(f)
|
| 130 |
+
config = GPTConfig(**config_data)
|
| 131 |
+
else:
|
| 132 |
+
# Use default config
|
| 133 |
+
config = GPTConfig.small()
|
| 134 |
+
|
| 135 |
+
# Create model
|
| 136 |
+
self.model = GPTModel(config, use_checkpoint=False) # No checkpointing for inference
|
| 137 |
+
|
| 138 |
+
# Load model weights
|
| 139 |
+
model_path = Path(self.model_path) / "pytorch_model.bin"
|
| 140 |
+
if model_path.exists():
|
| 141 |
+
self.model.load_state_dict(torch.load(model_path, map_location=self.device))
|
| 142 |
+
logger.info("Model weights loaded successfully")
|
| 143 |
+
else:
|
| 144 |
+
logger.warning("No model weights found, using initialized weights")
|
| 145 |
+
|
| 146 |
+
# Move model to device
|
| 147 |
+
self.model.to(self.device)
|
| 148 |
+
self.model.eval()
|
| 149 |
+
|
| 150 |
+
# Apply quantization if requested
|
| 151 |
+
if self.use_quantization and self.device.type == "cpu":
|
| 152 |
+
logger.info("Applying dynamic quantization")
|
| 153 |
+
self.quantized_model = QuantizedModel(self.model)
|
| 154 |
+
self.quantized_model.quantize_dynamic()
|
| 155 |
+
logger.info("Quantization completed")
|
| 156 |
+
|
| 157 |
+
# Load tokenizer
|
| 158 |
+
tokenizer_path = Path(self.model_path) / "tokenizer.model"
|
| 159 |
+
if tokenizer_path.exists():
|
| 160 |
+
import sentencepiece as spm
|
| 161 |
+
self.tokenizer = spm.SentencePieceProcessor()
|
| 162 |
+
self.tokenizer.load(str(tokenizer_path))
|
| 163 |
+
logger.info("Tokenizer loaded successfully")
|
| 164 |
+
else:
|
| 165 |
+
logger.warning("No tokenizer found")
|
| 166 |
+
|
| 167 |
+
logger.info("Model loading completed")
|
| 168 |
+
|
| 169 |
+
except Exception as e:
|
| 170 |
+
logger.error(f"Failed to load model: {e}")
|
| 171 |
+
raise
|
| 172 |
+
|
| 173 |
+
def _get_cache_key(self, prompt: str, **kwargs) -> str:
|
| 174 |
+
"""Generate cache key for request."""
|
| 175 |
+
# Create a hash of the prompt and parameters
|
| 176 |
+
import hashlib
|
| 177 |
+
key_data = f"{prompt}_{kwargs}"
|
| 178 |
+
return hashlib.md5(key_data.encode()).hexdigest()
|
| 179 |
+
|
| 180 |
+
def _check_cache(self, cache_key: str) -> Optional[List[str]]:
|
| 181 |
+
"""Check if response is cached."""
|
| 182 |
+
if cache_key in self.response_cache:
|
| 183 |
+
self.metrics["cache_hits"] += 1
|
| 184 |
+
return self.response_cache[cache_key]
|
| 185 |
+
else:
|
| 186 |
+
self.metrics["cache_misses"] += 1
|
| 187 |
+
return None
|
| 188 |
+
|
| 189 |
+
def _update_cache(self, cache_key: str, response: List[str]):
|
| 190 |
+
"""Update response cache."""
|
| 191 |
+
if len(self.response_cache) >= self.cache_size:
|
| 192 |
+
# Remove oldest entry
|
| 193 |
+
oldest_key = next(iter(self.response_cache))
|
| 194 |
+
del self.response_cache[oldest_key]
|
| 195 |
+
|
| 196 |
+
self.response_cache[cache_key] = response
|
| 197 |
+
|
| 198 |
+
def _tokenize(self, text: str) -> torch.Tensor:
|
| 199 |
+
"""Tokenize text using the loaded tokenizer."""
|
| 200 |
+
if self.tokenizer is None:
|
| 201 |
+
# Fallback to simple tokenization
|
| 202 |
+
return torch.tensor([ord(c) % 1000 for c in text], dtype=torch.long)
|
| 203 |
+
|
| 204 |
+
tokens = self.tokenizer.encode_as_ids(text)
|
| 205 |
+
return torch.tensor(tokens, dtype=torch.long)
|
| 206 |
+
|
| 207 |
+
def _detokenize(self, tokens: torch.Tensor) -> str:
|
| 208 |
+
"""Detokenize tokens to text."""
|
| 209 |
+
if self.tokenizer is None:
|
| 210 |
+
# Fallback to simple detokenization
|
| 211 |
+
return ''.join([chr(t % 1000) for t in tokens.tolist()])
|
| 212 |
+
|
| 213 |
+
return self.tokenizer.decode(tokens.tolist())
|
| 214 |
+
|
| 215 |
+
def generate(self,
|
| 216 |
+
prompt: str,
|
| 217 |
+
max_length: int = 256,
|
| 218 |
+
temperature: float = 0.7,
|
| 219 |
+
top_k: Optional[int] = 40,
|
| 220 |
+
top_p: Optional[float] = 0.9,
|
| 221 |
+
num_return_sequences: int = 1,
|
| 222 |
+
stop_sequences: Optional[List[str]] = None) -> List[str]:
|
| 223 |
+
"""
|
| 224 |
+
Generate text with optimizations.
|
| 225 |
+
|
| 226 |
+
Args:
|
| 227 |
+
prompt: Input prompt
|
| 228 |
+
max_length: Maximum generation length
|
| 229 |
+
temperature: Sampling temperature
|
| 230 |
+
top_k: Top-k sampling parameter
|
| 231 |
+
top_p: Nucleus sampling parameter
|
| 232 |
+
num_return_sequences: Number of sequences to generate
|
| 233 |
+
stop_sequences: Stop generation at these sequences
|
| 234 |
+
|
| 235 |
+
Returns:
|
| 236 |
+
List of generated texts
|
| 237 |
+
"""
|
| 238 |
+
start_time = time.time()
|
| 239 |
+
|
| 240 |
+
# Check cache first
|
| 241 |
+
cache_key = self._get_cache_key(prompt, max_length=max_length,
|
| 242 |
+
temperature=temperature, top_k=top_k, top_p=top_p)
|
| 243 |
+
cached_response = self._check_cache(cache_key)
|
| 244 |
+
if cached_response:
|
| 245 |
+
return cached_response
|
| 246 |
+
|
| 247 |
+
# Tokenize input
|
| 248 |
+
input_tokens = self._tokenize(prompt)
|
| 249 |
+
input_tokens = input_tokens.unsqueeze(0).to(self.device) # Add batch dimension
|
| 250 |
+
|
| 251 |
+
# Generate text
|
| 252 |
+
with torch.no_grad():
|
| 253 |
+
if self.quantized_model and self.quantized_model.is_quantized:
|
| 254 |
+
# Use quantized model
|
| 255 |
+
generated_tokens = self.quantized_model.quantized_model.generate(
|
| 256 |
+
input_tokens,
|
| 257 |
+
max_new_tokens=max_length,
|
| 258 |
+
temperature=temperature,
|
| 259 |
+
top_k=top_k,
|
| 260 |
+
do_sample=True
|
| 261 |
+
)
|
| 262 |
+
else:
|
| 263 |
+
# Use regular model
|
| 264 |
+
generated_tokens = self.model.generate(
|
| 265 |
+
input_tokens,
|
| 266 |
+
max_new_tokens=max_length,
|
| 267 |
+
temperature=temperature,
|
| 268 |
+
top_k=top_k,
|
| 269 |
+
do_sample=True
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
# Detokenize
|
| 273 |
+
generated_texts = []
|
| 274 |
+
for i in range(num_return_sequences):
|
| 275 |
+
# Extract generated part (remove input)
|
| 276 |
+
generated_part = generated_tokens[0, len(input_tokens[0]):]
|
| 277 |
+
text = self._detokenize(generated_part)
|
| 278 |
+
|
| 279 |
+
# Apply stop sequences
|
| 280 |
+
if stop_sequences:
|
| 281 |
+
for stop_seq in stop_sequences:
|
| 282 |
+
if stop_seq in text:
|
| 283 |
+
text = text[:text.find(stop_seq)]
|
| 284 |
+
break
|
| 285 |
+
|
| 286 |
+
generated_texts.append(text)
|
| 287 |
+
|
| 288 |
+
# Update cache
|
| 289 |
+
self._update_cache(cache_key, generated_texts)
|
| 290 |
+
|
| 291 |
+
# Update metrics
|
| 292 |
+
generation_time = time.time() - start_time
|
| 293 |
+
self.metrics["total_requests"] += 1
|
| 294 |
+
self.metrics["total_generation_time"] += generation_time
|
| 295 |
+
self.metrics["avg_generation_time"] = (
|
| 296 |
+
self.metrics["total_generation_time"] / self.metrics["total_requests"]
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
return generated_texts
|
| 300 |
+
|
| 301 |
+
async def generate_async(self,
|
| 302 |
+
prompt: str,
|
| 303 |
+
max_length: int = 256,
|
| 304 |
+
temperature: float = 0.7,
|
| 305 |
+
top_k: Optional[int] = 40,
|
| 306 |
+
top_p: Optional[float] = 0.9,
|
| 307 |
+
num_return_sequences: int = 1,
|
| 308 |
+
stop_sequences: Optional[List[str]] = None) -> List[str]:
|
| 309 |
+
"""
|
| 310 |
+
Asynchronous text generation.
|
| 311 |
+
|
| 312 |
+
Args:
|
| 313 |
+
Same as generate()
|
| 314 |
+
|
| 315 |
+
Returns:
|
| 316 |
+
List of generated texts
|
| 317 |
+
"""
|
| 318 |
+
# Run generation in thread pool
|
| 319 |
+
loop = asyncio.get_event_loop()
|
| 320 |
+
return await loop.run_in_executor(
|
| 321 |
+
self.executor,
|
| 322 |
+
self.generate,
|
| 323 |
+
prompt, max_length, temperature, top_k, top_p,
|
| 324 |
+
num_return_sequences, stop_sequences
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
async def generate_stream(self,
|
| 328 |
+
prompt: str,
|
| 329 |
+
max_length: int = 256,
|
| 330 |
+
temperature: float = 0.7,
|
| 331 |
+
top_k: Optional[int] = 40,
|
| 332 |
+
top_p: Optional[float] = 0.9,
|
| 333 |
+
stop_sequences: Optional[List[str]] = None) -> AsyncGenerator[str, None]:
|
| 334 |
+
"""
|
| 335 |
+
Stream generated text token by token.
|
| 336 |
+
|
| 337 |
+
Args:
|
| 338 |
+
Same as generate()
|
| 339 |
+
|
| 340 |
+
Yields:
|
| 341 |
+
Generated text tokens
|
| 342 |
+
"""
|
| 343 |
+
# Tokenize input
|
| 344 |
+
input_tokens = self._tokenize(prompt)
|
| 345 |
+
input_tokens = input_tokens.unsqueeze(0).to(self.device)
|
| 346 |
+
|
| 347 |
+
# Generate tokens one by one
|
| 348 |
+
current_tokens = input_tokens.clone()
|
| 349 |
+
|
| 350 |
+
with torch.no_grad():
|
| 351 |
+
for _ in range(max_length):
|
| 352 |
+
# Get next token
|
| 353 |
+
if self.quantized_model and self.quantized_model.is_quantized:
|
| 354 |
+
logits, _ = self.quantized_model.quantized_model(current_tokens)  # forward returns (logits, loss)
|
| 355 |
+
else:
|
| 356 |
+
logits, _ = self.model(current_tokens)  # forward returns (logits, loss)
|
| 357 |
+
|
| 358 |
+
# Sample next token
|
| 359 |
+
logits = logits[:, -1, :] / temperature
|
| 360 |
+
if top_k is not None:
|
| 361 |
+
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
| 362 |
+
logits[logits < v[:, [-1]]] = -float("inf")
|
| 363 |
+
|
| 364 |
+
probs = F.softmax(logits, dim=-1)
|
| 365 |
+
next_token = torch.multinomial(probs, num_samples=1)
|
| 366 |
+
|
| 367 |
+
# Add to sequence
|
| 368 |
+
current_tokens = torch.cat([current_tokens, next_token], dim=1)
|
| 369 |
+
|
| 370 |
+
# Convert token to text
|
| 371 |
+
token_text = self._detokenize(next_token[0])
|
| 372 |
+
yield token_text
|
| 373 |
+
|
| 374 |
+
# Check for stop sequences
|
| 375 |
+
if stop_sequences:
|
| 376 |
+
full_text = self._detokenize(current_tokens[0, len(input_tokens[0]):])
|
| 377 |
+
for stop_seq in stop_sequences:
|
| 378 |
+
if stop_seq in full_text:
|
| 379 |
+
return
|
| 380 |
+
|
| 381 |
+
def get_metrics(self) -> Dict[str, Any]:
|
| 382 |
+
"""Get performance metrics."""
|
| 383 |
+
memory_usage = psutil.virtual_memory().percent
|
| 384 |
+
|
| 385 |
+
return {
|
| 386 |
+
**self.metrics,
|
| 387 |
+
"memory_usage_percent": memory_usage,
|
| 388 |
+
"cache_size": len(self.response_cache),
|
| 389 |
+
"max_cache_size": self.cache_size,
|
| 390 |
+
"cache_hit_rate": (
|
| 391 |
+
self.metrics["cache_hits"] /
|
| 392 |
+
(self.metrics["cache_hits"] + self.metrics["cache_misses"])
|
| 393 |
+
if (self.metrics["cache_hits"] + self.metrics["cache_misses"]) > 0 else 0
|
| 394 |
+
),
|
| 395 |
+
"device": str(self.device),
|
| 396 |
+
"quantization_enabled": self.quantized_model is not None
|
| 397 |
+
}
|
| 398 |
+
|
| 399 |
+
def cleanup(self):
|
| 400 |
+
"""Clean up resources."""
|
| 401 |
+
if self.executor:
|
| 402 |
+
self.executor.shutdown(wait=True)
|
| 403 |
+
|
| 404 |
+
# Clear cache
|
| 405 |
+
self.response_cache.clear()
|
| 406 |
+
|
| 407 |
+
logger.info("Inference engine cleaned up")
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
# Request/Response models
|
| 411 |
+
class GenerationRequest(BaseModel):
|
| 412 |
+
"""Request model for text generation."""
|
| 413 |
+
prompt: str = Field(..., description="Input text prompt")
|
| 414 |
+
max_length: int = Field(256, description="Maximum generation length", ge=1, le=2048)
|
| 415 |
+
temperature: float = Field(0.7, description="Sampling temperature", ge=0.0, le=2.0)
|
| 416 |
+
top_k: Optional[int] = Field(40, description="Top-k sampling parameter", ge=1, le=1000)
|
| 417 |
+
top_p: Optional[float] = Field(0.9, description="Nucleus sampling parameter", ge=0.1, le=1.0)
|
| 418 |
+
num_return_sequences: int = Field(1, description="Number of sequences to generate", ge=1, le=5)
|
| 419 |
+
stop_sequences: Optional[List[str]] = Field(None, description="Stop generation at these sequences")
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
class GenerationResponse(BaseModel):
|
| 423 |
+
"""Response model for text generation."""
|
| 424 |
+
generated_text: List[str]
|
| 425 |
+
prompt: str
|
| 426 |
+
generation_time: float
|
| 427 |
+
parameters: Dict[str, Any]
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
class BatchGenerationRequest(BaseModel):
|
| 431 |
+
"""Request model for batch text generation."""
|
| 432 |
+
prompts: List[str] = Field(..., description="List of input prompts")
|
| 433 |
+
max_length: int = Field(256, description="Maximum generation length", ge=1, le=2048)
|
| 434 |
+
temperature: float = Field(0.7, description="Sampling temperature", ge=0.0, le=2.0)
|
| 435 |
+
top_k: Optional[int] = Field(40, description="Top-k sampling parameter", ge=1, le=1000)
|
| 436 |
+
top_p: Optional[float] = Field(0.9, description="Nucleus sampling parameter", ge=0.1, le=1.0)
|
| 437 |
+
stop_sequences: Optional[List[str]] = Field(None, description="Stop generation at these sequences")
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
class BatchGenerationResponse(BaseModel):
|
| 441 |
+
"""Response model for batch text generation."""
|
| 442 |
+
generated_texts: List[List[str]]
|
| 443 |
+
prompts: List[str]
|
| 444 |
+
generation_time: float
|
| 445 |
+
parameters: Dict[str, Any]
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
# Global inference engine
|
| 449 |
+
inference_engine: Optional[OptimizedInferenceEngine] = None
|
| 450 |
+
|
| 451 |
+
# FastAPI app
|
| 452 |
+
app = FastAPI(
|
| 453 |
+
title="Optimized OpenLLM Inference API",
|
| 454 |
+
description="High-performance REST API for OpenLLM text generation",
|
| 455 |
+
version="0.1.0",
|
| 456 |
+
docs_url="/docs",
|
| 457 |
+
redoc_url="/redoc",
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
# CORS middleware
|
| 461 |
+
app.add_middleware(
|
| 462 |
+
CORSMiddleware,
|
| 463 |
+
allow_origins=["*"],
|
| 464 |
+
allow_credentials=True,
|
| 465 |
+
allow_methods=["*"],
|
| 466 |
+
allow_headers=["*"],
|
| 467 |
+
)
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
@app.on_event("startup")
|
| 471 |
+
async def startup_event():
|
| 472 |
+
"""Initialize inference engine on startup."""
|
| 473 |
+
logger.info("π Starting Optimized OpenLLM Inference Server...")
|
| 474 |
+
global inference_engine
|
| 475 |
+
if inference_engine is None:
|
| 476 |
+
logger.warning("No model loaded - server will return 503 for generation requests")
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
@app.on_event("shutdown")
|
| 480 |
+
async def shutdown_event():
|
| 481 |
+
"""Clean up resources on shutdown."""
|
| 482 |
+
global inference_engine
|
| 483 |
+
if inference_engine:
|
| 484 |
+
inference_engine.cleanup()
|
| 485 |
+
logger.info("Server shutdown complete")
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
@app.post("/generate", response_model=GenerationResponse)
|
| 489 |
+
async def generate_text(request: GenerationRequest):
|
| 490 |
+
"""Generate text from prompt with optimizations."""
|
| 491 |
+
if inference_engine is None:
|
| 492 |
+
raise HTTPException(status_code=503, detail="Model not loaded")
|
| 493 |
+
|
| 494 |
+
start_time = time.time()
|
| 495 |
+
|
| 496 |
+
try:
|
| 497 |
+
# Generate text asynchronously
|
| 498 |
+
generated_texts = await inference_engine.generate_async(
|
| 499 |
+
prompt=request.prompt,
|
| 500 |
+
max_length=request.max_length,
|
| 501 |
+
temperature=request.temperature,
|
| 502 |
+
top_k=request.top_k,
|
| 503 |
+
top_p=request.top_p,
|
| 504 |
+
num_return_sequences=request.num_return_sequences,
|
| 505 |
+
stop_sequences=request.stop_sequences,
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
generation_time = time.time() - start_time
|
| 509 |
+
|
| 510 |
+
return GenerationResponse(
|
| 511 |
+
generated_text=generated_texts,
|
| 512 |
+
prompt=request.prompt,
|
| 513 |
+
generation_time=generation_time,
|
| 514 |
+
parameters={
|
| 515 |
+
"max_length": request.max_length,
|
| 516 |
+
"temperature": request.temperature,
|
| 517 |
+
"top_k": request.top_k,
|
| 518 |
+
"top_p": request.top_p,
|
| 519 |
+
"num_return_sequences": request.num_return_sequences,
|
| 520 |
+
},
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
except Exception as e:
|
| 524 |
+
logger.error(f"Generation failed: {e}")
|
| 525 |
+
raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
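
# Hedged client sketch (illustration only; this belongs in a separate client script,
# not in the server module): calling the /generate endpoint of a locally running
# instance. The host and port match the uvicorn defaults used by this module.
#
#     import requests
#
#     response = requests.post(
#         "http://localhost:8000/generate",
#         json={"prompt": "The future of AI", "max_length": 64, "temperature": 0.7},
#     )
#     response.raise_for_status()
#     print(response.json()["generated_text"][0])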
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
@app.post("/generate/stream")
|
| 529 |
+
async def generate_text_stream(request: GenerationRequest):
|
| 530 |
+
"""Generate text with streaming response."""
|
| 531 |
+
if inference_engine is None:
|
| 532 |
+
raise HTTPException(status_code=503, detail="Model not loaded")
|
| 533 |
+
|
| 534 |
+
async def generate_stream():
|
| 535 |
+
try:
|
| 536 |
+
async for token in inference_engine.generate_stream(
|
| 537 |
+
prompt=request.prompt,
|
| 538 |
+
max_length=request.max_length,
|
| 539 |
+
temperature=request.temperature,
|
| 540 |
+
top_k=request.top_k,
|
| 541 |
+
top_p=request.top_p,
|
| 542 |
+
stop_sequences=request.stop_sequences,
|
| 543 |
+
):
|
| 544 |
+
yield f"data: {json.dumps({'token': token})}\n\n"
|
| 545 |
+
|
| 546 |
+
yield f"data: {json.dumps({'done': True})}\n\n"
|
| 547 |
+
|
| 548 |
+
except Exception as e:
|
| 549 |
+
logger.error(f"Streaming generation failed: {e}")
|
| 550 |
+
yield f"data: {json.dumps({'error': str(e)})}\n\n"
|
| 551 |
+
|
| 552 |
+
return StreamingResponse(
|
| 553 |
+
generate_stream(),
|
| 554 |
+
media_type="text/plain",
|
| 555 |
+
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
|
| 556 |
+
)
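
# Hedged client sketch (illustration only; a separate client script): consuming the
# "data: {...}\n\n" lines emitted by /generate/stream and printing tokens as they
# arrive.
#
#     import json
#     import requests
#
#     with requests.post(
#         "http://localhost:8000/generate/stream",
#         json={"prompt": "Once upon a time", "max_length": 32},
#         stream=True,
#     ) as response:
#         for raw_line in response.iter_lines():
#             if not raw_line:
#                 continue
#             payload = json.loads(raw_line.decode("utf-8").removeprefix("data: "))
#             if payload.get("done") or payload.get("error"):
#                 break
#             print(payload["token"], end="", flush=True)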
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
@app.post("/generate/batch", response_model=BatchGenerationResponse)
|
| 560 |
+
async def generate_text_batch(request: BatchGenerationRequest):
|
| 561 |
+
"""Generate text for multiple prompts in batch."""
|
| 562 |
+
if inference_engine is None:
|
| 563 |
+
raise HTTPException(status_code=503, detail="Model not loaded")
|
| 564 |
+
|
| 565 |
+
start_time = time.time()
|
| 566 |
+
|
| 567 |
+
try:
|
| 568 |
+
# Process prompts in parallel
|
| 569 |
+
tasks = []
|
| 570 |
+
for prompt in request.prompts:
|
| 571 |
+
task = inference_engine.generate_async(
|
| 572 |
+
prompt=prompt,
|
| 573 |
+
max_length=request.max_length,
|
| 574 |
+
temperature=request.temperature,
|
| 575 |
+
top_k=request.top_k,
|
| 576 |
+
top_p=request.top_p,
|
| 577 |
+
num_return_sequences=1,
|
| 578 |
+
stop_sequences=request.stop_sequences,
|
| 579 |
+
)
|
| 580 |
+
tasks.append(task)
|
| 581 |
+
|
| 582 |
+
# Wait for all tasks to complete
|
| 583 |
+
generated_texts = await asyncio.gather(*tasks)
|
| 584 |
+
|
| 585 |
+
generation_time = time.time() - start_time
|
| 586 |
+
|
| 587 |
+
return BatchGenerationResponse(
|
| 588 |
+
generated_texts=generated_texts,
|
| 589 |
+
prompts=request.prompts,
|
| 590 |
+
generation_time=generation_time,
|
| 591 |
+
parameters={
|
| 592 |
+
"max_length": request.max_length,
|
| 593 |
+
"temperature": request.temperature,
|
| 594 |
+
"top_k": request.top_k,
|
| 595 |
+
"top_p": request.top_p,
|
| 596 |
+
"num_prompts": len(request.prompts),
|
| 597 |
+
},
|
| 598 |
+
)
|
| 599 |
+
|
| 600 |
+
except Exception as e:
|
| 601 |
+
logger.error(f"Batch generation failed: {e}")
|
| 602 |
+
raise HTTPException(status_code=500, detail=f"Batch generation failed: {str(e)}")
|
| 603 |
+
|
| 604 |
+
|
| 605 |
+
@app.get("/health")
|
| 606 |
+
async def health_check():
|
| 607 |
+
"""Health check endpoint."""
|
| 608 |
+
global inference_engine
|
| 609 |
+
|
| 610 |
+
if inference_engine is None:
|
| 611 |
+
return {"status": "unhealthy", "message": "Model not loaded"}
|
| 612 |
+
|
| 613 |
+
try:
|
| 614 |
+
# Quick generation test
|
| 615 |
+
test_result = await inference_engine.generate_async(
|
| 616 |
+
prompt="Hello",
|
| 617 |
+
max_length=5,
|
| 618 |
+
temperature=0.7
|
| 619 |
+
)
|
| 620 |
+
|
| 621 |
+
return {
|
| 622 |
+
"status": "healthy",
|
| 623 |
+
"model_loaded": True,
|
| 624 |
+
"test_generation": len(test_result) > 0
|
| 625 |
+
}
|
| 626 |
+
|
| 627 |
+
except Exception as e:
|
| 628 |
+
return {
|
| 629 |
+
"status": "unhealthy",
|
| 630 |
+
"message": f"Generation test failed: {str(e)}"
|
| 631 |
+
}
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
@app.get("/metrics")
|
| 635 |
+
async def get_metrics():
|
| 636 |
+
"""Get performance metrics."""
|
| 637 |
+
global inference_engine
|
| 638 |
+
|
| 639 |
+
if inference_engine is None:
|
| 640 |
+
return {"error": "Model not loaded"}
|
| 641 |
+
|
| 642 |
+
return inference_engine.get_metrics()
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
@app.get("/info")
|
| 646 |
+
async def get_model_info():
|
| 647 |
+
"""Get model information."""
|
| 648 |
+
global inference_engine
|
| 649 |
+
|
| 650 |
+
if inference_engine is None:
|
| 651 |
+
return {"error": "Model not loaded"}
|
| 652 |
+
|
| 653 |
+
model = inference_engine.model
|
| 654 |
+
if model is None:
|
| 655 |
+
return {"error": "Model not available"}
|
| 656 |
+
|
| 657 |
+
return {
|
| 658 |
+
"model_name": model.config.model_name,
|
| 659 |
+
"vocab_size": model.config.vocab_size,
|
| 660 |
+
"n_layer": model.config.n_layer,
|
| 661 |
+
"n_head": model.config.n_head,
|
| 662 |
+
"n_embd": model.config.n_embd,
|
| 663 |
+
"block_size": model.config.block_size,
|
| 664 |
+
"parameters": model.get_num_params(),
|
| 665 |
+
"device": str(inference_engine.device),
|
| 666 |
+
"quantization_enabled": inference_engine.quantized_model is not None,
|
| 667 |
+
"cache_size": len(inference_engine.response_cache),
|
| 668 |
+
"max_cache_size": inference_engine.cache_size,
|
| 669 |
+
}
|
| 670 |
+
|
| 671 |
+
|
| 672 |
+
def create_optimized_server(model_path: str,
|
| 673 |
+
host: str = "0.0.0.0",
|
| 674 |
+
port: int = 8000,
|
| 675 |
+
device: str = "auto",
|
| 676 |
+
use_quantization: bool = True,
|
| 677 |
+
cache_size: int = 1000,
|
| 678 |
+
max_batch_size: int = 32,
|
| 679 |
+
num_workers: int = 4) -> FastAPI:
|
| 680 |
+
"""
|
| 681 |
+
Create an optimized inference server.
|
| 682 |
+
|
| 683 |
+
Args:
|
| 684 |
+
model_path: Path to the model
|
| 685 |
+
host: Server host
|
| 686 |
+
port: Server port
|
| 687 |
+
device: Device to use
|
| 688 |
+
use_quantization: Whether to use quantization
|
| 689 |
+
cache_size: Size of response cache
|
| 690 |
+
max_batch_size: Maximum batch size
|
| 691 |
+
num_workers: Number of worker threads
|
| 692 |
+
|
| 693 |
+
Returns:
|
| 694 |
+
FastAPI app instance
|
| 695 |
+
"""
|
| 696 |
+
global inference_engine
|
| 697 |
+
|
| 698 |
+
# Initialize inference engine
|
| 699 |
+
inference_engine = OptimizedInferenceEngine(
|
| 700 |
+
model_path=model_path,
|
| 701 |
+
device=device,
|
| 702 |
+
use_quantization=use_quantization,
|
| 703 |
+
cache_size=cache_size,
|
| 704 |
+
max_batch_size=max_batch_size,
|
| 705 |
+
num_workers=num_workers
|
| 706 |
+
)
|
| 707 |
+
|
| 708 |
+
return app
|
| 709 |
+
|
| 710 |
+
|
| 711 |
+
if __name__ == "__main__":
|
| 712 |
+
import argparse
|
| 713 |
+
|
| 714 |
+
parser = argparse.ArgumentParser(description="Optimized OpenLLM Inference Server")
|
| 715 |
+
parser.add_argument("--model_path", type=str, required=True, help="Path to model")
|
| 716 |
+
parser.add_argument("--host", type=str, default="0.0.0.0", help="Server host")
|
| 717 |
+
parser.add_argument("--port", type=int, default=8000, help="Server port")
|
| 718 |
+
parser.add_argument("--device", type=str, default="auto", help="Device to use")
|
| 719 |
+
parser.add_argument("--use_quantization", action="store_true", help="Use quantization")
|
| 720 |
+
parser.add_argument("--cache_size", type=int, default=1000, help="Cache size")
|
| 721 |
+
parser.add_argument("--max_batch_size", type=int, default=32, help="Max batch size")
|
| 722 |
+
parser.add_argument("--num_workers", type=int, default=4, help="Number of workers")
|
| 723 |
+
|
| 724 |
+
args = parser.parse_args()
|
| 725 |
+
|
| 726 |
+
# Create server
|
| 727 |
+
app = create_optimized_server(
|
| 728 |
+
model_path=args.model_path,
|
| 729 |
+
host=args.host,
|
| 730 |
+
port=args.port,
|
| 731 |
+
device=args.device,
|
| 732 |
+
use_quantization=args.use_quantization,
|
| 733 |
+
cache_size=args.cache_size,
|
| 734 |
+
max_batch_size=args.max_batch_size,
|
| 735 |
+
num_workers=args.num_workers
|
| 736 |
+
)
|
| 737 |
+
|
| 738 |
+
# Run server
|
| 739 |
+
uvicorn.run(app, host=args.host, port=args.port)
|
core/src/performance_monitor.py
ADDED
|
@@ -0,0 +1,543 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Performance Monitoring and Profiling
|
| 4 |
+
|
| 5 |
+
This module provides comprehensive performance monitoring and profiling
|
| 6 |
+
capabilities for the OpenLLM project, including system resources,
|
| 7 |
+
model performance, and optimization recommendations.
|
| 8 |
+
|
| 9 |
+
Author: Louis Chua Bean Chong
|
| 10 |
+
License: GPLv3
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import time
|
| 14 |
+
import psutil
|
| 15 |
+
import torch
|
| 16 |
+
import threading
|
| 17 |
+
from typing import Dict, List, Any, Optional, Callable
|
| 18 |
+
from dataclasses import dataclass, field
|
| 19 |
+
from collections import deque
|
| 20 |
+
import json
|
| 21 |
+
import logging
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
import numpy as np
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
|
| 29 |
+
class SystemMetrics:
|
| 30 |
+
"""System resource metrics."""
|
| 31 |
+
cpu_percent: float
|
| 32 |
+
memory_percent: float
|
| 33 |
+
memory_available_gb: float
|
| 34 |
+
disk_usage_percent: float
|
| 35 |
+
network_io: Dict[str, float]
|
| 36 |
+
gpu_utilization: Optional[float] = None
|
| 37 |
+
gpu_memory_percent: Optional[float] = None
|
| 38 |
+
timestamp: float = field(default_factory=time.time)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@dataclass
|
| 42 |
+
class ModelMetrics:
|
| 43 |
+
"""Model performance metrics."""
|
| 44 |
+
inference_time_ms: float
|
| 45 |
+
tokens_per_second: float
|
| 46 |
+
memory_usage_mb: float
|
| 47 |
+
batch_size: int
|
| 48 |
+
sequence_length: int
|
| 49 |
+
model_parameters: int
|
| 50 |
+
timestamp: float = field(default_factory=time.time)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@dataclass
|
| 54 |
+
class TrainingMetrics:
|
| 55 |
+
"""Training performance metrics."""
|
| 56 |
+
loss: float
|
| 57 |
+
learning_rate: float
|
| 58 |
+
gradient_norm: float
|
| 59 |
+
training_time_ms: float
|
| 60 |
+
samples_per_second: float
|
| 61 |
+
memory_usage_mb: float
|
| 62 |
+
epoch: int
|
| 63 |
+
step: int
|
| 64 |
+
timestamp: float = field(default_factory=time.time)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class PerformanceProfiler:
|
| 68 |
+
"""
|
| 69 |
+
Performance profiler for monitoring and optimizing system performance.
|
| 70 |
+
|
| 71 |
+
This profiler tracks system resources, model performance, and training metrics
|
| 72 |
+
to provide insights and optimization recommendations.
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
def __init__(self,
|
| 76 |
+
history_size: int = 1000,
|
| 77 |
+
monitoring_interval: float = 1.0,
|
| 78 |
+
enable_gpu_monitoring: bool = True):
|
| 79 |
+
"""
|
| 80 |
+
Initialize performance profiler.
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
history_size: Number of metrics to keep in history
|
| 84 |
+
monitoring_interval: Interval between system checks (seconds)
|
| 85 |
+
enable_gpu_monitoring: Whether to monitor GPU usage
|
| 86 |
+
"""
|
| 87 |
+
self.history_size = history_size
|
| 88 |
+
self.monitoring_interval = monitoring_interval
|
| 89 |
+
self.enable_gpu_monitoring = enable_gpu_monitoring
|
| 90 |
+
|
| 91 |
+
# Metrics storage
|
| 92 |
+
self.system_metrics = deque(maxlen=history_size)
|
| 93 |
+
self.model_metrics = deque(maxlen=history_size)
|
| 94 |
+
self.training_metrics = deque(maxlen=history_size)
|
| 95 |
+
|
| 96 |
+
# Monitoring state
|
| 97 |
+
self.monitoring_active = False
|
| 98 |
+
self.monitoring_thread = None
|
| 99 |
+
|
| 100 |
+
# Performance counters
|
| 101 |
+
self.total_inference_requests = 0
|
| 102 |
+
self.total_training_steps = 0
|
| 103 |
+
self.start_time = time.time()
|
| 104 |
+
|
| 105 |
+
# Optimization recommendations
|
| 106 |
+
self.recommendations = []
|
| 107 |
+
|
| 108 |
+
logger.info("PerformanceProfiler initialized")
|
| 109 |
+
|
| 110 |
+
def start_monitoring(self):
|
| 111 |
+
"""Start continuous system monitoring."""
|
| 112 |
+
if self.monitoring_active:
|
| 113 |
+
logger.warning("Monitoring already active")
|
| 114 |
+
return
|
| 115 |
+
|
| 116 |
+
self.monitoring_active = True
|
| 117 |
+
self.monitoring_thread = threading.Thread(target=self._monitoring_loop, daemon=True)
|
| 118 |
+
self.monitoring_thread.start()
|
| 119 |
+
logger.info("System monitoring started")
|
| 120 |
+
|
| 121 |
+
def stop_monitoring(self):
|
| 122 |
+
"""Stop continuous system monitoring."""
|
| 123 |
+
self.monitoring_active = False
|
| 124 |
+
if self.monitoring_thread:
|
| 125 |
+
self.monitoring_thread.join()
|
| 126 |
+
logger.info("System monitoring stopped")
|
| 127 |
+
|
| 128 |
+
def _monitoring_loop(self):
|
| 129 |
+
"""Main monitoring loop."""
|
| 130 |
+
while self.monitoring_active:
|
| 131 |
+
try:
|
| 132 |
+
metrics = self._collect_system_metrics()
|
| 133 |
+
self.system_metrics.append(metrics)
|
| 134 |
+
|
| 135 |
+
# Check for performance issues
|
| 136 |
+
self._check_performance_issues(metrics)
|
| 137 |
+
|
| 138 |
+
time.sleep(self.monitoring_interval)
|
| 139 |
+
except Exception as e:
|
| 140 |
+
logger.error(f"Monitoring error: {e}")
|
| 141 |
+
time.sleep(self.monitoring_interval)
|
| 142 |
+
|
| 143 |
+
def _collect_system_metrics(self) -> SystemMetrics:
|
| 144 |
+
"""Collect current system metrics."""
|
| 145 |
+
# CPU and memory
|
| 146 |
+
cpu_percent = psutil.cpu_percent(interval=0.1)
|
| 147 |
+
memory = psutil.virtual_memory()
|
| 148 |
+
memory_percent = memory.percent
|
| 149 |
+
memory_available_gb = memory.available / (1024**3)
|
| 150 |
+
|
| 151 |
+
# Disk usage
|
| 152 |
+
disk_usage = psutil.disk_usage('/')
|
| 153 |
+
disk_usage_percent = disk_usage.percent
|
| 154 |
+
|
| 155 |
+
# Network I/O
|
| 156 |
+
network_io = psutil.net_io_counters()
|
| 157 |
+
network_metrics = {
|
| 158 |
+
'bytes_sent': network_io.bytes_sent,
|
| 159 |
+
'bytes_recv': network_io.bytes_recv,
|
| 160 |
+
'packets_sent': network_io.packets_sent,
|
| 161 |
+
'packets_recv': network_io.packets_recv
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
# GPU metrics (if available)
|
| 165 |
+
gpu_utilization = None
|
| 166 |
+
gpu_memory_percent = None
|
| 167 |
+
|
| 168 |
+
if self.enable_gpu_monitoring and torch.cuda.is_available():
|
| 169 |
+
try:
|
| 170 |
+
gpu_utilization = torch.cuda.utilization()
|
| 171 |
+
gpu_memory = torch.cuda.memory_stats()
|
| 172 |
+
gpu_memory_percent = (
|
| 173 |
+
gpu_memory['allocated_bytes.all.current'] /
|
| 174 |
+
gpu_memory['reserved_bytes.all.current']
|
| 175 |
+
) * 100 if gpu_memory['reserved_bytes.all.current'] > 0 else 0
|
| 176 |
+
except Exception as e:
|
| 177 |
+
logger.debug(f"GPU monitoring error: {e}")
|
| 178 |
+
|
| 179 |
+
return SystemMetrics(
|
| 180 |
+
cpu_percent=cpu_percent,
|
| 181 |
+
memory_percent=memory_percent,
|
| 182 |
+
memory_available_gb=memory_available_gb,
|
| 183 |
+
disk_usage_percent=disk_usage_percent,
|
| 184 |
+
network_io=network_metrics,
|
| 185 |
+
gpu_utilization=gpu_utilization,
|
| 186 |
+
gpu_memory_percent=gpu_memory_percent
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
def _check_performance_issues(self, metrics: SystemMetrics):
|
| 190 |
+
"""Check for performance issues and generate recommendations."""
|
| 191 |
+
recommendations = []
|
| 192 |
+
|
| 193 |
+
# Memory usage check
|
| 194 |
+
if metrics.memory_percent > 90:
|
| 195 |
+
recommendations.append({
|
| 196 |
+
'type': 'memory_high',
|
| 197 |
+
'severity': 'high',
|
| 198 |
+
'message': f'Memory usage is very high ({metrics.memory_percent:.1f}%)',
|
| 199 |
+
'suggestion': 'Consider reducing batch size or using gradient checkpointing'
|
| 200 |
+
})
|
| 201 |
+
elif metrics.memory_percent > 80:
|
| 202 |
+
recommendations.append({
|
| 203 |
+
'type': 'memory_high',
|
| 204 |
+
'severity': 'medium',
|
| 205 |
+
'message': f'Memory usage is high ({metrics.memory_percent:.1f}%)',
|
| 206 |
+
'suggestion': 'Monitor memory usage and consider optimization'
|
| 207 |
+
})
|
| 208 |
+
|
| 209 |
+
# CPU usage check
|
| 210 |
+
if metrics.cpu_percent > 95:
|
| 211 |
+
recommendations.append({
|
| 212 |
+
'type': 'cpu_high',
|
| 213 |
+
'severity': 'high',
|
| 214 |
+
'message': f'CPU usage is very high ({metrics.cpu_percent:.1f}%)',
|
| 215 |
+
'suggestion': 'Consider reducing number of workers or using GPU'
|
| 216 |
+
})
|
| 217 |
+
|
| 218 |
+
# GPU usage check
|
| 219 |
+
if metrics.gpu_utilization is not None:
|
| 220 |
+
if metrics.gpu_utilization < 50:
|
| 221 |
+
recommendations.append({
|
| 222 |
+
'type': 'gpu_underutilized',
|
| 223 |
+
'severity': 'low',
|
| 224 |
+
'message': f'GPU utilization is low ({metrics.gpu_utilization:.1f}%)',
|
| 225 |
+
'suggestion': 'Consider increasing batch size or using mixed precision'
|
| 226 |
+
})
|
| 227 |
+
elif metrics.gpu_memory_percent and metrics.gpu_memory_percent > 90:
|
| 228 |
+
recommendations.append({
|
| 229 |
+
'type': 'gpu_memory_high',
|
| 230 |
+
'severity': 'high',
|
| 231 |
+
'message': f'GPU memory usage is very high ({metrics.gpu_memory_percent:.1f}%)',
|
| 232 |
+
'suggestion': 'Consider reducing batch size or using gradient checkpointing'
|
| 233 |
+
})
|
| 234 |
+
|
| 235 |
+
# Add recommendations to history
|
| 236 |
+
for rec in recommendations:
|
| 237 |
+
rec['timestamp'] = time.time()
|
| 238 |
+
self.recommendations.append(rec)
|
| 239 |
+
|
| 240 |
+
# Keep only recent recommendations
|
| 241 |
+
if len(self.recommendations) > 100:
|
| 242 |
+
self.recommendations = self.recommendations[-100:]
|
| 243 |
+
|
| 244 |
+
def record_inference(self,
|
| 245 |
+
inference_time_ms: float,
|
| 246 |
+
tokens_generated: int,
|
| 247 |
+
memory_usage_mb: float,
|
| 248 |
+
batch_size: int,
|
| 249 |
+
sequence_length: int,
|
| 250 |
+
model_parameters: int):
|
| 251 |
+
"""Record inference performance metrics."""
|
| 252 |
+
tokens_per_second = (tokens_generated / (inference_time_ms / 1000)) if inference_time_ms > 0 else 0
|
| 253 |
+
|
| 254 |
+
metrics = ModelMetrics(
|
| 255 |
+
inference_time_ms=inference_time_ms,
|
| 256 |
+
tokens_per_second=tokens_per_second,
|
| 257 |
+
memory_usage_mb=memory_usage_mb,
|
| 258 |
+
batch_size=batch_size,
|
| 259 |
+
sequence_length=sequence_length,
|
| 260 |
+
model_parameters=model_parameters
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
self.model_metrics.append(metrics)
|
| 264 |
+
self.total_inference_requests += 1
|
| 265 |
+
|
| 266 |
+
def record_training(self,
|
| 267 |
+
loss: float,
|
| 268 |
+
learning_rate: float,
|
| 269 |
+
gradient_norm: float,
|
| 270 |
+
training_time_ms: float,
|
| 271 |
+
samples_processed: int,
|
| 272 |
+
memory_usage_mb: float,
|
| 273 |
+
epoch: int,
|
| 274 |
+
step: int):
|
| 275 |
+
"""Record training performance metrics."""
|
| 276 |
+
samples_per_second = (samples_processed / (training_time_ms / 1000)) if training_time_ms > 0 else 0
|
| 277 |
+
|
| 278 |
+
metrics = TrainingMetrics(
|
| 279 |
+
loss=loss,
|
| 280 |
+
learning_rate=learning_rate,
|
| 281 |
+
gradient_norm=gradient_norm,
|
| 282 |
+
training_time_ms=training_time_ms,
|
| 283 |
+
samples_per_second=samples_per_second,
|
| 284 |
+
memory_usage_mb=memory_usage_mb,
|
| 285 |
+
epoch=epoch,
|
| 286 |
+
step=step
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
self.training_metrics.append(metrics)
|
| 290 |
+
self.total_training_steps += 1
|
| 291 |
+
|
| 292 |
+
def get_system_summary(self) -> Dict[str, Any]:
|
| 293 |
+
"""Get system performance summary."""
|
| 294 |
+
if not self.system_metrics:
|
| 295 |
+
return {"error": "No system metrics available"}
|
| 296 |
+
|
| 297 |
+
recent_metrics = list(self.system_metrics)[-100:] # Last 100 measurements
|
| 298 |
+
|
| 299 |
+
cpu_values = [m.cpu_percent for m in recent_metrics]
|
| 300 |
+
memory_values = [m.memory_percent for m in recent_metrics]
|
| 301 |
+
|
| 302 |
+
return {
|
| 303 |
+
"cpu": {
|
| 304 |
+
"current": cpu_values[-1] if cpu_values else 0,
|
| 305 |
+
"average": np.mean(cpu_values) if cpu_values else 0,
|
| 306 |
+
"max": np.max(cpu_values) if cpu_values else 0,
|
| 307 |
+
"min": np.min(cpu_values) if cpu_values else 0
|
| 308 |
+
},
|
| 309 |
+
"memory": {
|
| 310 |
+
"current_percent": memory_values[-1] if memory_values else 0,
|
| 311 |
+
"average_percent": np.mean(memory_values) if memory_values else 0,
|
| 312 |
+
"available_gb": recent_metrics[-1].memory_available_gb if recent_metrics else 0
|
| 313 |
+
},
|
| 314 |
+
"gpu": {
|
| 315 |
+
"utilization": recent_metrics[-1].gpu_utilization if recent_metrics else None,
|
| 316 |
+
"memory_percent": recent_metrics[-1].gpu_memory_percent if recent_metrics else None
|
| 317 |
+
},
|
| 318 |
+
"uptime_hours": (time.time() - self.start_time) / 3600
|
| 319 |
+
}
|
| 320 |
+
|
| 321 |
+
def get_model_summary(self) -> Dict[str, Any]:
|
| 322 |
+
"""Get model performance summary."""
|
| 323 |
+
if not self.model_metrics:
|
| 324 |
+
return {"error": "No model metrics available"}
|
| 325 |
+
|
| 326 |
+
recent_metrics = list(self.model_metrics)[-100:] # Last 100 measurements
|
| 327 |
+
|
| 328 |
+
inference_times = [m.inference_time_ms for m in recent_metrics]
|
| 329 |
+
tokens_per_sec = [m.tokens_per_second for m in recent_metrics]
|
| 330 |
+
memory_usage = [m.memory_usage_mb for m in recent_metrics]
|
| 331 |
+
|
| 332 |
+
return {
|
| 333 |
+
"inference": {
|
| 334 |
+
"avg_time_ms": np.mean(inference_times) if inference_times else 0,
|
| 335 |
+
"min_time_ms": np.min(inference_times) if inference_times else 0,
|
| 336 |
+
"max_time_ms": np.max(inference_times) if inference_times else 0,
|
| 337 |
+
"avg_tokens_per_second": np.mean(tokens_per_sec) if tokens_per_sec else 0
|
| 338 |
+
},
|
| 339 |
+
"memory": {
|
| 340 |
+
"avg_usage_mb": np.mean(memory_usage) if memory_usage else 0,
|
| 341 |
+
"max_usage_mb": np.max(memory_usage) if memory_usage else 0
|
| 342 |
+
},
|
| 343 |
+
"total_requests": self.total_inference_requests,
|
| 344 |
+
"recent_requests": len(recent_metrics)
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
def get_training_summary(self) -> Dict[str, Any]:
|
| 348 |
+
"""Get training performance summary."""
|
| 349 |
+
if not self.training_metrics:
|
| 350 |
+
return {"error": "No training metrics available"}
|
| 351 |
+
|
| 352 |
+
recent_metrics = list(self.training_metrics)[-100:] # Last 100 measurements
|
| 353 |
+
|
| 354 |
+
losses = [m.loss for m in recent_metrics]
|
| 355 |
+
samples_per_sec = [m.samples_per_second for m in recent_metrics]
|
| 356 |
+
memory_usage = [m.memory_usage_mb for m in recent_metrics]
|
| 357 |
+
|
| 358 |
+
return {
|
| 359 |
+
"loss": {
|
| 360 |
+
"current": losses[-1] if losses else 0,
|
| 361 |
+
"average": np.mean(losses) if losses else 0,
|
| 362 |
+
"min": np.min(losses) if losses else 0,
|
| 363 |
+
"trend": "decreasing" if len(losses) > 1 and losses[-1] < losses[0] else "increasing"
|
| 364 |
+
},
|
| 365 |
+
"performance": {
|
| 366 |
+
"avg_samples_per_second": np.mean(samples_per_sec) if samples_per_sec else 0,
|
| 367 |
+
"avg_memory_usage_mb": np.mean(memory_usage) if memory_usage else 0
|
| 368 |
+
},
|
| 369 |
+
"total_steps": self.total_training_steps,
|
| 370 |
+
"recent_steps": len(recent_metrics),
|
| 371 |
+
"current_epoch": recent_metrics[-1].epoch if recent_metrics else 0
|
| 372 |
+
}
|
| 373 |
+
|
| 374 |
+
def get_recommendations(self) -> List[Dict[str, Any]]:
|
| 375 |
+
"""Get current optimization recommendations."""
|
| 376 |
+
return self.recommendations[-10:] # Return last 10 recommendations
|
| 377 |
+
|
| 378 |
+
def generate_optimization_report(self) -> Dict[str, Any]:
|
| 379 |
+
"""Generate comprehensive optimization report."""
|
| 380 |
+
system_summary = self.get_system_summary()
|
| 381 |
+
model_summary = self.get_model_summary()
|
| 382 |
+
training_summary = self.get_training_summary()
|
| 383 |
+
recommendations = self.get_recommendations()
|
| 384 |
+
|
| 385 |
+
# Calculate overall performance score
|
| 386 |
+
performance_score = self._calculate_performance_score(
|
| 387 |
+
system_summary, model_summary, training_summary
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
return {
|
| 391 |
+
"timestamp": time.time(),
|
| 392 |
+
"performance_score": performance_score,
|
| 393 |
+
"system_summary": system_summary,
|
| 394 |
+
"model_summary": model_summary,
|
| 395 |
+
"training_summary": training_summary,
|
| 396 |
+
"recommendations": recommendations,
|
| 397 |
+
"optimization_priority": self._get_optimization_priority(recommendations)
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
def _calculate_performance_score(self,
|
| 401 |
+
system_summary: Dict,
|
| 402 |
+
model_summary: Dict,
|
| 403 |
+
training_summary: Dict) -> float:
|
| 404 |
+
"""Calculate overall performance score (0-100)."""
|
| 405 |
+
score = 100.0
|
| 406 |
+
|
| 407 |
+
# Deduct points for system issues
|
| 408 |
+
if "cpu" in system_summary:
|
| 409 |
+
cpu_avg = system_summary["cpu"]["average"]
|
| 410 |
+
if cpu_avg > 90:
|
| 411 |
+
score -= 20
|
| 412 |
+
elif cpu_avg > 80:
|
| 413 |
+
score -= 10
|
| 414 |
+
elif cpu_avg > 70:
|
| 415 |
+
score -= 5
|
| 416 |
+
|
| 417 |
+
if "memory" in system_summary:
|
| 418 |
+
memory_avg = system_summary["memory"]["average_percent"]
|
| 419 |
+
if memory_avg > 90:
|
| 420 |
+
score -= 20
|
| 421 |
+
elif memory_avg > 80:
|
| 422 |
+
score -= 10
|
| 423 |
+
elif memory_avg > 70:
|
| 424 |
+
score -= 5
|
| 425 |
+
|
| 426 |
+
# Deduct points for model performance issues
|
| 427 |
+
if "inference" in model_summary:
|
| 428 |
+
avg_time = model_summary["inference"]["avg_time_ms"]
|
| 429 |
+
if avg_time > 1000: # More than 1 second
|
| 430 |
+
score -= 15
|
| 431 |
+
elif avg_time > 500: # More than 500ms
|
| 432 |
+
score -= 10
|
| 433 |
+
elif avg_time > 100: # More than 100ms
|
| 434 |
+
score -= 5
|
| 435 |
+
|
| 436 |
+
return max(0, score)
|
| 437 |
+
|
| 438 |
+
def _get_optimization_priority(self, recommendations: List[Dict]) -> str:
|
| 439 |
+
"""Get optimization priority based on recommendations."""
|
| 440 |
+
high_priority = sum(1 for r in recommendations if r.get('severity') == 'high')
|
| 441 |
+
medium_priority = sum(1 for r in recommendations if r.get('severity') == 'medium')
|
| 442 |
+
|
| 443 |
+
if high_priority > 0:
|
| 444 |
+
return "high"
|
| 445 |
+
elif medium_priority > 2:
|
| 446 |
+
return "medium"
|
| 447 |
+
else:
|
| 448 |
+
return "low"
|
| 449 |
+
|
| 450 |
+
def save_metrics(self, filepath: str):
|
| 451 |
+
"""Save metrics to file."""
|
| 452 |
+
try:
|
| 453 |
+
data = {
|
| 454 |
+
"system_metrics": [self._metric_to_dict(m) for m in self.system_metrics],
|
| 455 |
+
"model_metrics": [self._metric_to_dict(m) for m in self.model_metrics],
|
| 456 |
+
"training_metrics": [self._metric_to_dict(m) for m in self.training_metrics],
|
| 457 |
+
"recommendations": self.recommendations,
|
| 458 |
+
"summary": {
|
| 459 |
+
"total_inference_requests": self.total_inference_requests,
|
| 460 |
+
"total_training_steps": self.total_training_steps,
|
| 461 |
+
"uptime_hours": (time.time() - self.start_time) / 3600
|
| 462 |
+
}
|
| 463 |
+
}
|
| 464 |
+
|
| 465 |
+
with open(filepath, 'w') as f:
|
| 466 |
+
json.dump(data, f, indent=2, default=str)
|
| 467 |
+
|
| 468 |
+
logger.info(f"Metrics saved to {filepath}")
|
| 469 |
+
|
| 470 |
+
except Exception as e:
|
| 471 |
+
logger.error(f"Failed to save metrics: {e}")
|
| 472 |
+
|
| 473 |
+
def _metric_to_dict(self, metric) -> Dict:
|
| 474 |
+
"""Convert metric object to dictionary."""
|
| 475 |
+
return {k: v for k, v in metric.__dict__.items() if not k.startswith('_')}
|
| 476 |
+
|
| 477 |
+
def load_metrics(self, filepath: str):
|
| 478 |
+
"""Load metrics from file."""
|
| 479 |
+
try:
|
| 480 |
+
with open(filepath, 'r') as f:
|
| 481 |
+
data = json.load(f)
|
| 482 |
+
|
| 483 |
+
# Reconstruct metrics objects
|
| 484 |
+
self.system_metrics = deque(
|
| 485 |
+
[SystemMetrics(**m) for m in data.get("system_metrics", [])],
|
| 486 |
+
maxlen=self.history_size
|
| 487 |
+
)
|
| 488 |
+
self.model_metrics = deque(
|
| 489 |
+
[ModelMetrics(**m) for m in data.get("model_metrics", [])],
|
| 490 |
+
maxlen=self.history_size
|
| 491 |
+
)
|
| 492 |
+
self.training_metrics = deque(
|
| 493 |
+
[TrainingMetrics(**m) for m in data.get("training_metrics", [])],
|
| 494 |
+
maxlen=self.history_size
|
| 495 |
+
)
|
| 496 |
+
self.recommendations = data.get("recommendations", [])
|
| 497 |
+
|
| 498 |
+
logger.info(f"Metrics loaded from {filepath}")
|
| 499 |
+
|
| 500 |
+
except Exception as e:
|
| 501 |
+
logger.error(f"Failed to load metrics: {e}")
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
# Global profiler instance
|
| 505 |
+
_global_profiler: Optional[PerformanceProfiler] = None
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
def get_profiler() -> PerformanceProfiler:
|
| 509 |
+
"""Get global profiler instance."""
|
| 510 |
+
global _global_profiler
|
| 511 |
+
if _global_profiler is None:
|
| 512 |
+
_global_profiler = PerformanceProfiler()
|
| 513 |
+
return _global_profiler
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
def start_monitoring():
|
| 517 |
+
"""Start global performance monitoring."""
|
| 518 |
+
profiler = get_profiler()
|
| 519 |
+
profiler.start_monitoring()
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
def stop_monitoring():
|
| 523 |
+
"""Stop global performance monitoring."""
|
| 524 |
+
profiler = get_profiler()
|
| 525 |
+
profiler.stop_monitoring()
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
def record_inference(**kwargs):
|
| 529 |
+
"""Record inference metrics using global profiler."""
|
| 530 |
+
profiler = get_profiler()
|
| 531 |
+
profiler.record_inference(**kwargs)
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
def record_training(**kwargs):
|
| 535 |
+
"""Record training metrics using global profiler."""
|
| 536 |
+
profiler = get_profiler()
|
| 537 |
+
profiler.record_training(**kwargs)
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
def get_performance_report() -> Dict[str, Any]:
|
| 541 |
+
"""Get performance report using global profiler."""
|
| 542 |
+
profiler = get_profiler()
|
| 543 |
+
return profiler.generate_optimization_report()
|
core/src/quantization.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Model Quantization Utilities
|
| 4 |
+
|
| 5 |
+
This module provides utilities for model quantization to reduce memory usage
|
| 6 |
+
and improve inference speed while maintaining reasonable accuracy.
|
| 7 |
+
|
| 8 |
+
Author: Louis Chua Bean Chong
|
| 9 |
+
License: GPLv3
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.quantization as quantization
|
| 15 |
+
from typing import Optional, Dict, Any
|
| 16 |
+
import copy
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class QuantizedModel:
|
| 20 |
+
"""
|
| 21 |
+
Wrapper for quantized models with easy conversion and inference.
|
| 22 |
+
|
| 23 |
+
This class provides utilities for converting models to quantized versions
|
| 24 |
+
and performing efficient inference with reduced memory usage.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(self, model: nn.Module, quantized_model: Optional[nn.Module] = None):
|
| 28 |
+
"""
|
| 29 |
+
Initialize quantized model wrapper.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
model: Original model
|
| 33 |
+
quantized_model: Pre-quantized model (optional)
|
| 34 |
+
"""
|
| 35 |
+
self.original_model = model
|
| 36 |
+
self.quantized_model = quantized_model
|
| 37 |
+
self.is_quantized = quantized_model is not None
|
| 38 |
+
|
| 39 |
+
def quantize_dynamic(self,
|
| 40 |
+
qconfig_spec: Optional[Dict] = None,
|
| 41 |
+
dtype: torch.dtype = torch.qint8) -> 'QuantizedModel':
|
| 42 |
+
"""
|
| 43 |
+
Perform dynamic quantization on the model.
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
qconfig_spec: Quantization configuration
|
| 47 |
+
dtype: Quantization dtype (qint8, quint8)
|
| 48 |
+
|
| 49 |
+
Returns:
|
| 50 |
+
QuantizedModel: Self with quantized model
|
| 51 |
+
"""
|
| 52 |
+
if qconfig_spec is None:
|
| 53 |
+
qconfig_spec = {
|
| 54 |
+
nn.Linear: quantization.default_dynamic_qconfig,
|
| 55 |
+
nn.LSTM: quantization.default_dynamic_qconfig,
|
| 56 |
+
nn.LSTMCell: quantization.default_dynamic_qconfig,
|
| 57 |
+
nn.RNNCell: quantization.default_dynamic_qconfig,
|
| 58 |
+
nn.GRUCell: quantization.default_dynamic_qconfig,
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
# Create a copy of the model for quantization
|
| 62 |
+
model_copy = copy.deepcopy(self.original_model)
|
| 63 |
+
model_copy.eval()
|
| 64 |
+
|
| 65 |
+
# Prepare model for quantization
|
| 66 |
+
model_prepared = quantization.prepare_dynamic(model_copy, qconfig_spec)
|
| 67 |
+
|
| 68 |
+
# Convert to quantized model
|
| 69 |
+
self.quantized_model = quantization.convert(model_prepared)
|
| 70 |
+
self.is_quantized = True
|
| 71 |
+
|
| 72 |
+
print(f"Dynamic quantization completed with dtype: {dtype}")
|
| 73 |
+
return self
|
| 74 |
+
|
| 75 |
+
def quantize_static(self,
|
| 76 |
+
calibration_data: torch.utils.data.DataLoader,
|
| 77 |
+
qconfig: Optional[quantization.QConfig] = None) -> 'QuantizedModel':
|
| 78 |
+
"""
|
| 79 |
+
Perform static quantization on the model.
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
calibration_data: DataLoader for calibration
|
| 83 |
+
qconfig: Quantization configuration
|
| 84 |
+
|
| 85 |
+
Returns:
|
| 86 |
+
QuantizedModel: Self with quantized model
|
| 87 |
+
"""
|
| 88 |
+
if qconfig is None:
|
| 89 |
+
qconfig = quantization.get_default_qconfig('fbgemm')
|
| 90 |
+
|
| 91 |
+
# Create a copy of the model for quantization
|
| 92 |
+
model_copy = copy.deepcopy(self.original_model)
|
| 93 |
+
model_copy.eval()
|
| 94 |
+
|
| 95 |
+
# Prepare model for quantization
|
| 96 |
+
model_prepared = quantization.prepare(model_copy, qconfig)
|
| 97 |
+
|
| 98 |
+
# Calibrate the model
|
| 99 |
+
print("Calibrating model...")
|
| 100 |
+
with torch.no_grad():
|
| 101 |
+
for batch_idx, (data, _) in enumerate(calibration_data):
|
| 102 |
+
if batch_idx >= 100: # Limit calibration samples
|
| 103 |
+
break
|
| 104 |
+
model_prepared(data)
|
| 105 |
+
|
| 106 |
+
# Convert to quantized model
|
| 107 |
+
self.quantized_model = quantization.convert(model_prepared)
|
| 108 |
+
self.is_quantized = True
|
| 109 |
+
|
| 110 |
+
print("Static quantization completed")
|
| 111 |
+
return self
|
| 112 |
+
|
| 113 |
+
def forward(self, *args, **kwargs):
|
| 114 |
+
"""Forward pass using quantized model if available."""
|
| 115 |
+
if self.is_quantized and self.quantized_model is not None:
|
| 116 |
+
return self.quantized_model(*args, **kwargs)
|
| 117 |
+
else:
|
| 118 |
+
return self.original_model(*args, **kwargs)
|
| 119 |
+
|
| 120 |
+
def get_memory_usage(self) -> Dict[str, float]:
|
| 121 |
+
"""
|
| 122 |
+
Get memory usage comparison between original and quantized models.
|
| 123 |
+
|
| 124 |
+
Returns:
|
| 125 |
+
dict: Memory usage in MB
|
| 126 |
+
"""
|
| 127 |
+
def get_model_size(model):
|
| 128 |
+
param_size = 0
|
| 129 |
+
buffer_size = 0
|
| 130 |
+
|
| 131 |
+
for param in model.parameters():
|
| 132 |
+
param_size += param.nelement() * param.element_size()
|
| 133 |
+
|
| 134 |
+
for buffer in model.buffers():
|
| 135 |
+
buffer_size += buffer.nelement() * buffer.element_size()
|
| 136 |
+
|
| 137 |
+
return (param_size + buffer_size) / (1024 * 1024) # Convert to MB
|
| 138 |
+
|
| 139 |
+
original_size = get_model_size(self.original_model)
|
| 140 |
+
quantized_size = get_model_size(self.quantized_model) if self.quantized_model else original_size
|
| 141 |
+
|
| 142 |
+
return {
|
| 143 |
+
"original_mb": original_size,
|
| 144 |
+
"quantized_mb": quantized_size,
|
| 145 |
+
"compression_ratio": original_size / quantized_size if quantized_size > 0 else 1.0
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
def save_quantized(self, path: str):
|
| 149 |
+
"""Save quantized model."""
|
| 150 |
+
if self.quantized_model is not None:
|
| 151 |
+
torch.save(self.quantized_model.state_dict(), path)
|
| 152 |
+
print(f"Quantized model saved to: {path}")
|
| 153 |
+
else:
|
| 154 |
+
raise ValueError("No quantized model available")
|
| 155 |
+
|
| 156 |
+
def load_quantized(self, path: str):
|
| 157 |
+
"""Load quantized model."""
|
| 158 |
+
self.quantized_model.load_state_dict(torch.load(path))
|
| 159 |
+
self.is_quantized = True
|
| 160 |
+
print(f"Quantized model loaded from: {path}")
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def quantize_model_dynamic(model: nn.Module,
|
| 164 |
+
dtype: torch.dtype = torch.qint8) -> QuantizedModel:
|
| 165 |
+
"""
|
| 166 |
+
Convenience function for dynamic quantization.
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
model: Model to quantize
|
| 170 |
+
dtype: Quantization dtype
|
| 171 |
+
|
| 172 |
+
Returns:
|
| 173 |
+
QuantizedModel: Quantized model wrapper
|
| 174 |
+
"""
|
| 175 |
+
quantized = QuantizedModel(model)
|
| 176 |
+
return quantized.quantize_dynamic(dtype=dtype)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def quantize_model_static(model: nn.Module,
|
| 180 |
+
calibration_data: torch.utils.data.DataLoader,
|
| 181 |
+
qconfig: Optional[quantization.QConfig] = None) -> QuantizedModel:
|
| 182 |
+
"""
|
| 183 |
+
Convenience function for static quantization.
|
| 184 |
+
|
| 185 |
+
Args:
|
| 186 |
+
model: Model to quantize
|
| 187 |
+
calibration_data: Data for calibration
|
| 188 |
+
qconfig: Quantization configuration
|
| 189 |
+
|
| 190 |
+
Returns:
|
| 191 |
+
QuantizedModel: Quantized model wrapper
|
| 192 |
+
"""
|
| 193 |
+
quantized = QuantizedModel(model)
|
| 194 |
+
return quantized.quantize_static(calibration_data, qconfig)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def create_quantization_config(backend: str = 'fbgemm',
|
| 198 |
+
dtype: torch.dtype = torch.qint8) -> quantization.QConfig:
|
| 199 |
+
"""
|
| 200 |
+
Create quantization configuration.
|
| 201 |
+
|
| 202 |
+
Args:
|
| 203 |
+
backend: Quantization backend ('fbgemm', 'qnnpack')
|
| 204 |
+
dtype: Quantization dtype
|
| 205 |
+
|
| 206 |
+
Returns:
|
| 207 |
+
QConfig: Quantization configuration
|
| 208 |
+
"""
|
| 209 |
+
if backend == 'fbgemm':
|
| 210 |
+
return quantization.QConfig(
|
| 211 |
+
activation=quantization.default_observer,
|
| 212 |
+
weight=quantization.default_per_channel_weight_observer
|
| 213 |
+
)
|
| 214 |
+
elif backend == 'qnnpack':
|
| 215 |
+
return quantization.QConfig(
|
| 216 |
+
activation=quantization.default_observer,
|
| 217 |
+
weight=quantization.default_weight_observer
|
| 218 |
+
)
|
| 219 |
+
else:
|
| 220 |
+
raise ValueError(f"Unsupported backend: {backend}")
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def benchmark_quantization(original_model: nn.Module,
|
| 224 |
+
quantized_model: QuantizedModel,
|
| 225 |
+
test_data: torch.Tensor,
|
| 226 |
+
num_runs: int = 100) -> Dict[str, float]:
|
| 227 |
+
"""
|
| 228 |
+
Benchmark original vs quantized model performance.
|
| 229 |
+
|
| 230 |
+
Args:
|
| 231 |
+
original_model: Original model
|
| 232 |
+
quantized_model: Quantized model
|
| 233 |
+
test_data: Test data for benchmarking
|
| 234 |
+
num_runs: Number of runs for averaging
|
| 235 |
+
|
| 236 |
+
Returns:
|
| 237 |
+
dict: Performance metrics
|
| 238 |
+
"""
|
| 239 |
+
original_model.eval()
|
| 240 |
+
quantized_model.quantized_model.eval()
|
| 241 |
+
|
| 242 |
+
# Benchmark original model
|
| 243 |
+
start_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
|
| 244 |
+
end_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
|
| 245 |
+
|
| 246 |
+
if start_time:
|
| 247 |
+
start_time.record()
|
| 248 |
+
|
| 249 |
+
with torch.no_grad():
|
| 250 |
+
for _ in range(num_runs):
|
| 251 |
+
_ = original_model(test_data)
|
| 252 |
+
|
| 253 |
+
if end_time:
|
| 254 |
+
end_time.record()
|
| 255 |
+
torch.cuda.synchronize()
|
| 256 |
+
original_time = start_time.elapsed_time(end_time) / num_runs
|
| 257 |
+
else:
|
| 258 |
+
import time
|
| 259 |
+
start = time.time()
|
| 260 |
+
for _ in range(num_runs):
|
| 261 |
+
_ = original_model(test_data)
|
| 262 |
+
original_time = (time.time() - start) * 1000 / num_runs # Convert to ms
|
| 263 |
+
|
| 264 |
+
# Benchmark quantized model
|
| 265 |
+
if start_time:
|
| 266 |
+
start_time.record()
|
| 267 |
+
|
| 268 |
+
with torch.no_grad():
|
| 269 |
+
for _ in range(num_runs):
|
| 270 |
+
_ = quantized_model.quantized_model(test_data)
|
| 271 |
+
|
| 272 |
+
if end_time:
|
| 273 |
+
end_time.record()
|
| 274 |
+
torch.cuda.synchronize()
|
| 275 |
+
quantized_time = start_time.elapsed_time(end_time) / num_runs
|
| 276 |
+
else:
|
| 277 |
+
start = time.time()
|
| 278 |
+
for _ in range(num_runs):
|
| 279 |
+
_ = quantized_model.quantized_model(test_data)
|
| 280 |
+
quantized_time = (time.time() - start) * 1000 / num_runs # Convert to ms
|
| 281 |
+
|
| 282 |
+
return {
|
| 283 |
+
"original_time_ms": original_time,
|
| 284 |
+
"quantized_time_ms": quantized_time,
|
| 285 |
+
"speedup": original_time / quantized_time if quantized_time > 0 else 1.0
|
| 286 |
+
}
|
core/src/train_model.py
ADDED
|
@@ -0,0 +1,668 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright (C) 2024 Louis Chua Bean Chong
|
| 3 |
+
#
|
| 4 |
+
# This file is part of OpenLLM.
|
| 5 |
+
#
|
| 6 |
+
# OpenLLM is dual-licensed:
|
| 7 |
+
# 1. For open source use: GNU General Public License v3.0
|
| 8 |
+
# 2. For commercial use: Commercial License (contact for details)
|
| 9 |
+
#
|
| 10 |
+
# See LICENSE and docs/LICENSES.md for full license information.
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
Language Model Training Script
|
| 14 |
+
|
| 15 |
+
This script implements the complete training pipeline for GPT-style language models.
|
| 16 |
+
It includes optimization, checkpointing, progress monitoring, and CPU-optimized training
|
| 17 |
+
for limited hardware environments.
|
| 18 |
+
|
| 19 |
+
FEATURES:
|
| 20 |
+
- CPU-optimized training with memory management
|
| 21 |
+
- Gradient accumulation for effective large batch sizes
|
| 22 |
+
- Learning rate scheduling with warmup
|
| 23 |
+
- Model checkpointing and resume capability
|
| 24 |
+
- Real-time monitoring of loss, perplexity, and speed
|
| 25 |
+
- Memory usage tracking and optimization
|
| 26 |
+
- Automatic mixed precision (if available)
|
| 27 |
+
|
| 28 |
+
HARDWARE OPTIMIZATION:
|
| 29 |
+
- Designed for 8GB RAM systems
|
| 30 |
+
- Efficient CPU training with PyTorch optimizations
|
| 31 |
+
- Gradient accumulation to simulate larger batches
|
| 32 |
+
- Memory cleanup and garbage collection
|
| 33 |
+
- Progress saving for long training runs
|
| 34 |
+
|
| 35 |
+
Usage:
|
| 36 |
+
python core/src/train_model.py \\
|
| 37 |
+
--model-size small \\
|
| 38 |
+
--data-file data/clean/training_data.txt \\
|
| 39 |
+
--tokenizer-dir data/tokenizer/ \\
|
| 40 |
+
--output-dir models/my-model/ \\
|
| 41 |
+
--max-steps 10000
|
| 42 |
+
|
| 43 |
+
Requirements:
|
| 44 |
+
- PyTorch
|
| 45 |
+
- SentencePiece
|
| 46 |
+
- Our model architecture and data loader
|
| 47 |
+
|
| 48 |
+
Author: Louis Chua Bean Chong
|
| 49 |
+
License: GPLv3
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
import argparse
|
| 53 |
+
import gc
|
| 54 |
+
import json
|
| 55 |
+
import math
|
| 56 |
+
import os
|
| 57 |
+
import time
|
| 58 |
+
from pathlib import Path
|
| 59 |
+
from typing import Dict
|
| 60 |
+
|
| 61 |
+
import torch
|
| 62 |
+
import torch.nn as nn
|
| 63 |
+
import torch.optim as optim
|
| 64 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
| 65 |
+
|
| 66 |
+
# Import our modules
|
| 67 |
+
try:
|
| 68 |
+
from data_loader import TextDataLoader
|
| 69 |
+
from model import GPTModel, create_model
|
| 70 |
+
except ImportError:
|
| 71 |
+
import sys
|
| 72 |
+
|
| 73 |
+
sys.path.append(os.path.dirname(__file__))
|
| 74 |
+
from data_loader import TextDataLoader
|
| 75 |
+
from model import GPTModel, create_model
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class TrainingConfig:
|
| 79 |
+
"""Configuration for model training parameters."""
|
| 80 |
+
|
| 81 |
+
def __init__(
|
| 82 |
+
self,
|
| 83 |
+
learning_rate: float = 1e-4,
|
| 84 |
+
batch_size: int = 32,
|
| 85 |
+
max_steps: int = 100000,
|
| 86 |
+
warmup_steps: int = 10000,
|
| 87 |
+
gradient_clipping: float = 1.0,
|
| 88 |
+
weight_decay: float = 0.01,
|
| 89 |
+
mixed_precision: bool = True,
|
| 90 |
+
gradient_checkpointing: bool = True,
|
| 91 |
+
):
|
| 92 |
+
self.learning_rate = learning_rate
|
| 93 |
+
self.batch_size = batch_size
|
| 94 |
+
self.max_steps = max_steps
|
| 95 |
+
self.warmup_steps = warmup_steps
|
| 96 |
+
self.gradient_clipping = gradient_clipping
|
| 97 |
+
self.weight_decay = weight_decay
|
| 98 |
+
self.mixed_precision = mixed_precision
|
| 99 |
+
self.gradient_checkpointing = gradient_checkpointing
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class ModelTrainer:
|
| 103 |
+
"""
|
| 104 |
+
Comprehensive trainer for GPT-style language models.
|
| 105 |
+
|
| 106 |
+
Handles the complete training pipeline including data loading, optimization,
|
| 107 |
+
checkpointing, and progress monitoring.
|
| 108 |
+
"""
|
| 109 |
+
|
| 110 |
+
def __init__(
|
| 111 |
+
self,
|
| 112 |
+
model: GPTModel,
|
| 113 |
+
data_loader: TextDataLoader,
|
| 114 |
+
output_dir: str,
|
| 115 |
+
device: str = "cpu",
|
| 116 |
+
learning_rate: float = 3e-4,
|
| 117 |
+
weight_decay: float = 0.01,
|
| 118 |
+
warmup_steps: int = 1000,
|
| 119 |
+
max_steps: int = 10000,
|
| 120 |
+
gradient_accumulation_steps: int = 4,
|
| 121 |
+
gradient_clipping: float = 1.0,
|
| 122 |
+
save_every: int = 1000,
|
| 123 |
+
eval_every: int = 500,
|
| 124 |
+
log_every: int = 100,
|
| 125 |
+
):
|
| 126 |
+
"""
|
| 127 |
+
Initialize the model trainer.
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
model: GPT model to train
|
| 131 |
+
data_loader: Data loader for training data
|
| 132 |
+
output_dir: Directory to save checkpoints and logs
|
| 133 |
+
device: Training device ("cpu" or "cuda")
|
| 134 |
+
learning_rate: Peak learning rate
|
| 135 |
+
weight_decay: Weight decay for regularization
|
| 136 |
+
warmup_steps: Number of warmup steps for learning rate
|
| 137 |
+
max_steps: Maximum training steps
|
| 138 |
+
gradient_accumulation_steps: Steps to accumulate gradients
|
| 139 |
+
gradient_clipping: Maximum gradient norm
|
| 140 |
+
save_every: Save checkpoint every N steps
|
| 141 |
+
eval_every: Evaluate model every N steps
|
| 142 |
+
log_every: Log progress every N steps
|
| 143 |
+
"""
|
| 144 |
+
self.model = model.to(device)
|
| 145 |
+
self.data_loader = data_loader
|
| 146 |
+
self.output_dir = Path(output_dir)
|
| 147 |
+
self.device = device
|
| 148 |
+
|
| 149 |
+
# Training hyperparameters
|
| 150 |
+
self.learning_rate = learning_rate
|
| 151 |
+
self.weight_decay = weight_decay
|
| 152 |
+
self.warmup_steps = warmup_steps
|
| 153 |
+
self.max_steps = max_steps
|
| 154 |
+
self.gradient_accumulation_steps = gradient_accumulation_steps
|
| 155 |
+
self.gradient_clipping = gradient_clipping
|
| 156 |
+
|
| 157 |
+
# Logging and saving
|
| 158 |
+
self.save_every = save_every
|
| 159 |
+
self.eval_every = eval_every
|
| 160 |
+
self.log_every = log_every
|
| 161 |
+
|
| 162 |
+
# Create output directory
|
| 163 |
+
self.output_dir.mkdir(parents=True, exist_ok=True)
|
| 164 |
+
|
| 165 |
+
# Initialize optimizer and scheduler
|
| 166 |
+
self.optimizer = self._create_optimizer()
|
| 167 |
+
self.scheduler = self._create_scheduler()
|
| 168 |
+
|
| 169 |
+
# Training state
|
| 170 |
+
self.step = 0
|
| 171 |
+
        self.epoch = 0
        self.best_loss = float("inf")
        self.training_log = []

        # Performance tracking
        self.start_time = None
        self.step_times = []

        print("ModelTrainer initialized")
        print(f" Device: {device}")
        print(f" Model parameters: {model.get_num_params():,}")
        print(f" Learning rate: {learning_rate}")
        print(f" Max steps: {max_steps:,}")
        print(f" Gradient accumulation: {gradient_accumulation_steps}")
        print(f" Output directory: {output_dir}")

    def _create_optimizer(self) -> optim.Optimizer:
        """Create AdamW optimizer with weight decay."""
        # Separate parameters for weight decay
        decay_params = []
        no_decay_params = []

        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue

            # Don't apply weight decay to biases and layer norm parameters
            if len(param.shape) == 1 or name.endswith(".bias"):
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        param_groups = [
            {"params": decay_params, "weight_decay": self.weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        ]

        # Use AdamW with lower memory usage for CPU
        optimizer = optim.AdamW(
            param_groups,
            lr=self.learning_rate,
            betas=(0.9, 0.95),  # Slightly different from default for LLM training
            eps=1e-8,
        )

        return optimizer

    def _create_scheduler(self) -> torch.optim.lr_scheduler._LRScheduler:
        """Create learning rate scheduler with warmup and cosine decay."""
        if self.warmup_steps > 0:
            # Use a custom scheduler to avoid deprecation warnings
            # This implements warmup + cosine decay without SequentialLR
            class WarmupCosineScheduler(torch.optim.lr_scheduler._LRScheduler):
                def __init__(self, optimizer, warmup_steps, max_steps, min_lr_factor=0.1):
                    self.warmup_steps = warmup_steps
                    self.max_steps = max_steps
                    self.min_lr_factor = min_lr_factor
                    super().__init__(optimizer)

                def get_lr(self):
                    if self.last_epoch < self.warmup_steps:
                        # Linear warmup
                        factor = self.last_epoch / self.warmup_steps
                        return [base_lr * (0.01 + 0.99 * factor) for base_lr in self.base_lrs]
                    else:
                        # Cosine decay
                        progress = (self.last_epoch - self.warmup_steps) / (
                            self.max_steps - self.warmup_steps
                        )
                        progress = min(progress, 1.0)  # Clamp to 1.0
                        factor = 0.5 * (1 + math.cos(math.pi * progress))
                        factor = self.min_lr_factor + (1 - self.min_lr_factor) * factor
                        return [base_lr * factor for base_lr in self.base_lrs]

            scheduler = WarmupCosineScheduler(
                self.optimizer,
                warmup_steps=self.warmup_steps,
                max_steps=self.max_steps,
                min_lr_factor=0.1,
            )
        else:
            # Just cosine decay - this should not trigger warnings
            scheduler = CosineAnnealingLR(
                self.optimizer, T_max=self.max_steps, eta_min=self.learning_rate * 0.1
            )

        return scheduler

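    # Illustration of the schedule above (editorial sketch, not part of the original
    # upload; numbers assume the CLI defaults further below: --learning-rate 3e-4,
    # --warmup-steps 1000, --max-steps 10000):
    #   step   500 (warmup): lr = 3e-4 * (0.01 + 0.99 * 500/1000)            ~= 1.52e-4
    #   step  5500 (cosine): progress = 0.5, factor = 0.1 + 0.9 * 0.5 = 0.55, lr ~= 1.65e-4
    #   step 10000 (end):    progress = 1.0, factor = 0.1,                    lr  = 3.0e-5
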
    def _calculate_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Calculate cross-entropy loss for autoregressive language modeling.

        This method computes the standard cross-entropy loss used in language model training.
        The loss measures how well the model predicts the next token in the sequence.

        Mathematical formulation:
            Loss = -Σ log(P(target_token | context))
            where P is the softmax probability distribution over vocabulary

        Implementation details:
        - Reshapes 3D tensors to 2D for efficient computation
        - Uses PyTorch's optimized cross_entropy function
        - Handles padding tokens by ignoring them in loss calculation
        - Computes mean loss across all valid positions

        Why cross-entropy for language modeling:
        - Natural choice for multi-class classification (next token prediction)
        - Provides strong gradient signal for correct token probabilities
        - Mathematically equivalent to minimizing negative log-likelihood
        - Well-studied optimization properties for neural language models

        Args:
            logits: Raw model predictions of shape (batch_size, seq_len, vocab_size)
                Contains unnormalized scores for each token in vocabulary
                These will be converted to probabilities via softmax internally
            targets: Ground truth next tokens of shape (batch_size, seq_len)
                Contains token IDs representing the true next tokens
                Should be input sequence shifted by one position

        Returns:
            torch.Tensor: Scalar loss value representing prediction error
                Lower values indicate better next-token prediction accuracy
        """
        # Reshape tensors from 3D to 2D for efficient loss computation
        # This converts per-sequence per-position predictions to a flat structure
        # where each row represents one prediction over the entire vocabulary
        logits = logits.view(-1, logits.size(-1))  # (batch_size * seq_len, vocab_size)
        targets = targets.view(-1)  # (batch_size * seq_len,)

        # Calculate cross-entropy loss with proper handling of special tokens
        # ignore_index=-1 excludes padding tokens from loss calculation
        # This prevents the model from learning to predict padding, which would skew training
        # The function internally applies softmax to logits and computes negative log-likelihood
        loss = nn.functional.cross_entropy(logits, targets, ignore_index=-1)

        # Return scalar loss for backpropagation
        # This loss will be used to compute gradients via automatic differentiation
        return loss

    def _get_memory_usage(self) -> Dict[str, float]:
        """Get current memory usage statistics."""
        memory_stats = {}

        if torch.cuda.is_available() and self.device.startswith("cuda"):
            memory_stats["gpu_allocated_mb"] = torch.cuda.memory_allocated() / (1024**2)
            memory_stats["gpu_cached_mb"] = torch.cuda.memory_reserved() / (1024**2)

        # Estimate CPU memory (approximate)
        import psutil

        process = psutil.Process()
        memory_stats["cpu_memory_mb"] = process.memory_info().rss / (1024**2)

        return memory_stats

    def _log_step(self, step: int, loss: float, lr: float, step_time: float) -> None:
        """Log training progress for a single step."""
        perplexity = math.exp(min(loss, 10))  # Cap at exp(10) to avoid overflow

        # Calculate tokens per second
        tokens_per_batch = self.data_loader.batch_size * self.data_loader.seq_len
        tokens_per_second = tokens_per_batch / step_time if step_time > 0 else 0

        # Get memory usage
        memory_stats = self._get_memory_usage()

        # Create log entry
        log_entry = {
            "step": step,
            "loss": loss,
            "perplexity": perplexity,
            "learning_rate": lr,
            "step_time": step_time,
            "tokens_per_second": tokens_per_second,
            "memory_mb": memory_stats.get("cpu_memory_mb", 0),
        }

        self.training_log.append(log_entry)

        # Print progress
        _ = time.time() - self.start_time if self.start_time else 0
        eta_seconds = (self.max_steps - step) * step_time if step_time > 0 else 0
        eta_hours = eta_seconds / 3600

        print(
            f"Step {step:,}/{self.max_steps:,} | "
            f"Loss: {loss:.4f} | "
            f"PPL: {perplexity:.2f} | "
            f"LR: {lr:.2e} | "
            f"Time: {step_time:.2f}s | "
            f"Tokens/s: {tokens_per_second:.1f} | "
            f"Memory: {memory_stats.get('cpu_memory_mb', 0):.0f}MB | "
            f"ETA: {eta_hours:.1f}h"
        )

    def _save_checkpoint(self, step: int, is_best: bool = False) -> None:
        """Save model checkpoint."""
        checkpoint = {
            "step": step,
            "epoch": self.epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "best_loss": self.best_loss,
            "training_log": self.training_log,
            "config": self.model.config.__dict__,
        }

        # Save latest checkpoint
        checkpoint_path = self.output_dir / f"checkpoint_step_{step}.pt"
        torch.save(checkpoint, checkpoint_path)

        # Save best checkpoint
        if is_best:
            best_path = self.output_dir / "best_model.pt"
            torch.save(checkpoint, best_path)
            print(f"New best model saved: {best_path}")

        # Save training log
        log_path = self.output_dir / "training_log.json"
        with open(log_path, "w") as f:
            json.dump(self.training_log, f, indent=2)

        print(f"Checkpoint saved: {checkpoint_path}")

    def _load_checkpoint(self, checkpoint_path: str) -> None:
        """Load model checkpoint to resume training."""
        if not os.path.exists(checkpoint_path):
            print(f"Checkpoint not found: {checkpoint_path}")
            return

        print(f"Loading checkpoint: {checkpoint_path}")

        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

        self.step = checkpoint["step"]
        self.epoch = checkpoint["epoch"]
        self.best_loss = checkpoint["best_loss"]
        self.training_log = checkpoint.get("training_log", [])

        print("Checkpoint loaded successfully")
        print(f" Resuming from step: {self.step:,}")
        print(f" Best loss so far: {self.best_loss:.4f}")

    def train(self) -> None:
        """Main training loop."""
        print("\nStarting training...")
        print(f" Model: {self.model.config.model_name}")
        print(f" Parameters: {self.model.get_num_params():,}")
        print(f" Device: {self.device}")
        print(f" Max steps: {self.max_steps:,}")
        print("=" * 80)

        self.model.train()
        self.start_time = time.time()

        # Initialize gradient accumulation
        accumulated_loss = 0.0
        self.optimizer.zero_grad()

        for batch_idx, (input_ids, target_ids) in enumerate(self.data_loader):
            if self.step >= self.max_steps:
                break

            step_start_time = time.time()

            # Move batch to device
            input_ids = input_ids.to(self.device)
            target_ids = target_ids.to(self.device)

            # Forward pass (model computes loss internally when targets provided)
            logits, loss = self.model(input_ids, target_ids)

            # Scale loss for gradient accumulation
            loss = loss / self.gradient_accumulation_steps
            accumulated_loss += loss.item()

            # Backward pass
            loss.backward()

            # Update weights every gradient_accumulation_steps
            if (batch_idx + 1) % self.gradient_accumulation_steps == 0:
                # Clip gradients
                if self.gradient_clipping > 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clipping)

                # Update parameters
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()

                # Update step count
                self.step += 1
                step_time = time.time() - step_start_time
                self.step_times.append(step_time)

                # Get current learning rate
                current_lr = self.scheduler.get_last_lr()[0]

                # Log progress
                if self.step % self.log_every == 0:
                    avg_loss = accumulated_loss
                    self._log_step(self.step, avg_loss, current_lr, step_time)

                # Save checkpoint
                if self.step % self.save_every == 0:
                    is_best = accumulated_loss < self.best_loss
                    if is_best:
                        self.best_loss = accumulated_loss

                    self._save_checkpoint(self.step, is_best)

                # Clean up memory periodically
                if self.step % 100 == 0:
                    gc.collect()

                # Reset accumulated loss
                accumulated_loss = 0.0

                # Check if training complete
                if self.step >= self.max_steps:
                    break

        # Final checkpoint
        print("\nTraining completed!")
        self._save_checkpoint(self.step, is_best=True)

        # Training summary
        total_time = time.time() - self.start_time
        avg_step_time = sum(self.step_times) / len(self.step_times) if self.step_times else 0

        print("\nTraining Summary:")
        print(f" Steps completed: {self.step:,}")
        print(f" Total time: {total_time/3600:.2f} hours")
        print(f" Average time per step: {avg_step_time:.2f}s")
        print(f" Final loss: {self.best_loss:.4f}")
        print(f" Final perplexity: {math.exp(min(self.best_loss, 10)):.2f}")
        print(f" Model saved to: {self.output_dir}")


def main():
    """Main function to handle command line training."""
    parser = argparse.ArgumentParser(
        description="Train a GPT-style language model",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Train small model for quick experimentation
  python core/src/train_model.py \\
      --model-size small \\
      --max-steps 5000 \\
      --output-dir models/test-small

  # Train medium model with custom settings
  python core/src/train_model.py \\
      --model-size medium \\
      --learning-rate 1e-4 \\
      --batch-size 2 \\
      --max-steps 50000 \\
      --output-dir models/my-medium-model
        """,
    )

    # Model and data arguments
    parser.add_argument(
        "--model-size",
        choices=["small", "medium", "large"],
        default="small",
        help="Model size to train (default: small)",
    )

    parser.add_argument(
        "--data-file",
        default="data/clean/training_data.txt",
        help="Path to training text file (default: data/clean/training_data.txt)",
    )

    parser.add_argument(
        "--tokenizer-dir",
        default="data/tokenizer/",
        help="Path to tokenizer directory (default: data/tokenizer/)",
    )

    parser.add_argument(
        "--output-dir", required=True, help="Output directory for model checkpoints"
    )

    # Training hyperparameters
    parser.add_argument(
        "--seq-len", type=int, default=512, help="Sequence length for training (default: 512)"
    )

    parser.add_argument("--batch-size", type=int, default=4, help="Batch size (default: 4)")

    parser.add_argument(
        "--learning-rate", type=float, default=3e-4, help="Learning rate (default: 3e-4)"
    )

    parser.add_argument(
        "--max-steps", type=int, default=10000, help="Maximum training steps (default: 10000)"
    )

    parser.add_argument(
        "--warmup-steps", type=int, default=1000, help="Warmup steps (default: 1000)"
    )

    parser.add_argument(
        "--gradient-accumulation-steps",
        type=int,
        default=4,
        help="Gradient accumulation steps (default: 4)",
    )

    parser.add_argument(
        "--device",
        choices=["cpu", "cuda", "auto"],
        default="auto",
        help="Training device (default: auto)",
    )

    parser.add_argument("--resume", help="Path to checkpoint to resume training from")

    parser.add_argument(
        "--save-every", type=int, default=1000, help="Save checkpoint every N steps (default: 1000)"
    )

    args = parser.parse_args()

    print("OpenLLM Model Training")
    print("=" * 60)

    # Determine device
    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device

    print(f"Using device: {device}")

    try:
        # Create model
        print(f"\nCreating {args.model_size} model...")
        model = create_model(args.model_size)

        # Create data loader
        print("\nSetting up data loader...")
        tokenizer_path = os.path.join(args.tokenizer_dir, "tokenizer.model")

        data_loader = TextDataLoader(
            data_file=args.data_file,
            tokenizer_path=tokenizer_path,
            seq_len=args.seq_len,
            batch_size=args.batch_size,
            shuffle=True,
        )

        # Get data statistics
        _ = data_loader.get_data_stats()

        # Create trainer
        print("\nSetting up trainer...")
        trainer = ModelTrainer(
            model=model,
            data_loader=data_loader,
            output_dir=args.output_dir,
            device=device,
            learning_rate=args.learning_rate,
            max_steps=args.max_steps,
            warmup_steps=args.warmup_steps,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            save_every=args.save_every,
        )

        # Resume from checkpoint if specified
        if args.resume:
            trainer._load_checkpoint(args.resume)

        # Start training
        trainer.train()

        print("\nTraining completed successfully!")

    except Exception as e:
        print(f"\nTraining failed: {e}")
        import traceback

        traceback.print_exc()
        return False

    return True


if __name__ == "__main__":
    main()
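For reference, the checkpoints written by `_save_checkpoint` above are plain `torch.save` dictionaries, so they can be inspected outside the trainer. A minimal sketch, assuming the `models/test-small` output directory from the usage example and the `best_model.pt` filename used above:

```python
import torch

# Load a checkpoint produced by ModelTrainer._save_checkpoint (path is illustrative).
checkpoint = torch.load("models/test-small/best_model.pt", map_location="cpu")

print(checkpoint["step"], checkpoint["best_loss"])  # training progress
print(checkpoint["config"])  # model configuration dictionary
# model.load_state_dict(checkpoint["model_state_dict"]) would restore the weights
# into a model built with the same configuration.
```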
core/src/train_tokenizer.py
ADDED
@@ -0,0 +1,428 @@
#!/usr/bin/env python3
# Copyright (C) 2024 Louis Chua Bean Chong
#
# This file is part of OpenLLM.
#
# OpenLLM is dual-licensed:
# 1. For open source use: GNU General Public License v3.0
# 2. For commercial use: Commercial License (contact for details)
#
# See LICENSE and docs/LICENSES.md for full license information.

"""
Train a SentencePiece tokenizer from scratch using the prepared training data.

OVERVIEW:
This script trains a SentencePiece tokenizer on the cleaned text data from the SQUAD dataset
or any other text corpus. SentencePiece is a subword tokenizer that works well for language
models and supports multiple languages without requiring pre-tokenization.

FEATURES:
- Supports BPE (Byte Pair Encoding) and Unigram tokenization algorithms
- Configurable vocabulary size (recommended: 8k-64k for LLMs)
- Handles special tokens (BOS, EOS, UNK, PAD)
- Outputs tokenizer model files compatible with Hugging Face
- Comprehensive statistics and vocabulary analysis

TOKENIZER OUTPUT:
- tokenizer.model: SentencePiece model file
- tokenizer.vocab: Human-readable vocabulary file
- tokenizer_config.json: Configuration for Hugging Face integration

Usage:
    python core/src/train_tokenizer.py --input data/clean/training_data.txt --vocab_size 32000

Advanced usage:
    python core/src/train_tokenizer.py \\
        --input data/clean/training_data.txt \\
        --vocab_size 32000 \\
        --model_type bpe \\
        --output_dir data/tokenizer/ \\
        --character_coverage 0.9995

Requirements:
    pip install sentencepiece

Example setup:
    ```bash
    # If not already in virtual environment
    python -m venv venv
    source venv/bin/activate  # Linux/macOS
    # .\venv\Scripts\Activate.ps1  # Windows PowerShell

    # Install SentencePiece
    pip install sentencepiece

    # Train tokenizer
    python core/src/train_tokenizer.py --input data/clean/training_data.txt --vocab_size 32000
    ```

"""

import argparse
import json
import os
import time
from typing import Any, Dict

try:
    import sentencepiece as spm
except ImportError:
    print("ERROR: SentencePiece not installed. Run: pip install sentencepiece")
    exit(1)


def validate_input_file(input_path: str) -> None:
    """
    Validate that the input training file exists and is readable.

    Args:
        input_path (str): Path to the training text file

    Raises:
        FileNotFoundError: If input file doesn't exist
        ValueError: If input file is empty or unreadable
    """
    if not os.path.exists(input_path):
        raise FileNotFoundError(f"Training data file not found: {input_path}")

    # Check file size and readability
    file_size = os.path.getsize(input_path)
    if file_size == 0:
        raise ValueError(f"Training data file is empty: {input_path}")

    # Test that we can read the file
    try:
        with open(input_path, "r", encoding="utf-8") as f:
            first_line = f.readline()
            if not first_line.strip():
                raise ValueError(
                    "Training data file appears to be empty or contains only whitespace"
                )
    except UnicodeDecodeError as e:
        raise ValueError(f"Cannot read training data file as UTF-8: {e}")

    print(f"Input file validated: {input_path} ({file_size:,} bytes)")


def count_training_sentences(input_path: str) -> int:
    """
    Count the number of training sentences/lines in the input file.

    Args:
        input_path (str): Path to the training text file

    Returns:
        int: Number of lines in the file
    """
    print("Counting training sentences...")
    with open(input_path, "r", encoding="utf-8") as f:
        count = sum(1 for line in f if line.strip())
    print(f"Found {count:,} training sentences")
    return count


def train_sentencepiece_tokenizer(
    input_path: str,
    output_dir: str,
    vocab_size: int = 32000,
    model_type: str = "bpe",
    character_coverage: float = 0.9995,
    max_sentence_length: int = 4192,
    input_sentence_size: int = 10000000,
    shuffle_input_sentence: bool = True,
) -> Dict[str, Any]:
    """
    Train a SentencePiece tokenizer with the specified parameters.

    Args:
        input_path (str): Path to training text file
        output_dir (str): Directory to save tokenizer files
        vocab_size (int): Target vocabulary size (recommended: 8k-64k)
        model_type (str): Algorithm type ('bpe' or 'unigram')
        character_coverage (float): Character coverage (0.9995 for English, 1.0 for Japanese)
        max_sentence_length (int): Maximum sentence length in characters
        input_sentence_size (int): Maximum number of sentences to use for training
        shuffle_input_sentence (bool): Whether to shuffle input sentences

    Returns:
        Dict[str, Any]: Training statistics and configuration
    """
    # Ensure output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Define output paths
    model_prefix = os.path.join(output_dir, "tokenizer")

    # SentencePiece training parameters
    train_params = [
        f"--input={input_path}",
        f"--model_prefix={model_prefix}",
        f"--vocab_size={vocab_size}",
        f"--model_type={model_type}",
        f"--character_coverage={character_coverage}",
        f"--max_sentence_length={max_sentence_length}",
        f"--input_sentence_size={input_sentence_size}",
        f"--shuffle_input_sentence={shuffle_input_sentence}",
        # Special tokens for language modeling
        "--pad_id=0",  # Padding token
        "--unk_id=1",  # Unknown token
        "--bos_id=2",  # Beginning of sequence
        "--eos_id=3",  # End of sequence
        # Additional useful parameters
        "--split_by_unicode_script=true",  # Better handling of mixed scripts
        "--split_by_whitespace=true",  # Split on whitespace
        "--remove_extra_whitespaces=true",  # Clean up whitespace
        "--normalization_rule_name=identity",  # Keep original text as-is
    ]

    print("\nTraining SentencePiece tokenizer...")
    print(f" Algorithm: {model_type.upper()}")
    print(f" Vocabulary size: {vocab_size:,}")
    print(f" Character coverage: {character_coverage}")
    print(f" Output directory: {output_dir}")
    print(f" Model files: {model_prefix}.model, {model_prefix}.vocab")

    # Record training start time
    start_time = time.time()

    # Train the tokenizer
    try:
        spm.SentencePieceTrainer.train(" ".join(train_params))
        training_time = time.time() - start_time
        print(f"Tokenizer training completed in {training_time:.1f} seconds")
    except Exception as e:
        raise RuntimeError(f"SentencePiece training failed: {e}")

    # Verify output files were created
    model_file = f"{model_prefix}.model"
    vocab_file = f"{model_prefix}.vocab"

    if not os.path.exists(model_file):
        raise RuntimeError(f"Expected model file not created: {model_file}")
    if not os.path.exists(vocab_file):
        raise RuntimeError(f"Expected vocab file not created: {vocab_file}")

    print(f"Model file created: {model_file} ({os.path.getsize(model_file):,} bytes)")
    print(f"Vocab file created: {vocab_file} ({os.path.getsize(vocab_file):,} bytes)")

    # Return training configuration and statistics
    config = {
        "model_type": model_type,
        "vocab_size": vocab_size,
        "character_coverage": character_coverage,
        "max_sentence_length": max_sentence_length,
        "training_time_seconds": training_time,
        "input_file": input_path,
        "output_directory": output_dir,
        "model_file": model_file,
        "vocab_file": vocab_file,
    }

    return config


def test_tokenizer(model_path: str, test_sentences: list = None) -> None:
    """
    Test the trained tokenizer on sample sentences to verify it works correctly.

    Args:
        model_path (str): Path to the trained .model file
        test_sentences (list): Optional list of test sentences
    """
    print("\nTesting trained tokenizer...")

    # Load the trained tokenizer
    sp = spm.SentencePieceProcessor()
    sp.load(model_path)

    # Default test sentences if none provided
    if test_sentences is None:
        test_sentences = [
            "Hello, world! This is a test sentence.",
            "The quick brown fox jumps over the lazy dog.",
            "Machine learning and artificial intelligence are transforming technology.",
            "SentencePiece tokenization works well for language models.",
        ]

    print(f"Vocabulary size: {sp.vocab_size():,}")
    print(
        f"Special tokens: PAD={sp.pad_id()}, UNK={sp.unk_id()}, BOS={sp.bos_id()}, EOS={sp.eos_id()}"
    )

    print("\nTokenization examples:")
    for i, sentence in enumerate(test_sentences, 1):
        # Encode to token IDs and pieces
        token_ids = sp.encode(sentence)
        token_pieces = sp.encode(sentence, out_type=str)

        print(f"\n{i}. Input: {sentence}")
        print(f" Tokens ({len(token_pieces)}): {token_pieces}")
        print(f" IDs: {token_ids[:10]}{'...' if len(token_ids) > 10 else ''}")

        # Test decoding
        decoded = sp.decode(token_ids)
        print(f" Decoded: {decoded}")

        # Verify round-trip encoding/decoding
        if decoded.strip() != sentence.strip():
            print(" Warning: Decode mismatch!")

    print("Tokenizer testing completed")


def save_huggingface_config(output_dir: str, config: Dict[str, Any]) -> None:
    """
    Save a Hugging Face compatible tokenizer configuration file.

    Args:
        output_dir (str): Directory containing the tokenizer files
        config (Dict[str, Any]): Tokenizer configuration
    """
    # Create Hugging Face tokenizer config
    hf_config = {
        "tokenizer_class": "SentencePieceTokenizer",
        "model_type": config["model_type"],
        "vocab_size": config["vocab_size"],
        "model_file": "tokenizer.model",
        "special_tokens": {
            "pad_token": "<pad>",
            "unk_token": "<unk>",
            "bos_token": "<s>",
            "eos_token": "</s>",
        },
        "special_token_ids": {
            "pad_token_id": 0,
            "unk_token_id": 1,
            "bos_token_id": 2,
            "eos_token_id": 3,
        },
    }

    config_path = os.path.join(output_dir, "tokenizer_config.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(hf_config, f, indent=2, ensure_ascii=False)

    print(f"Hugging Face config saved: {config_path}")


def main():
    """Main function to handle command line arguments and orchestrate tokenizer training."""
    parser = argparse.ArgumentParser(
        description="Train a SentencePiece tokenizer for language model training",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Basic usage with SQUAD data
  python core/src/train_tokenizer.py --input data/clean/training_data.txt --vocab_size 32000

  # Advanced configuration
  python core/src/train_tokenizer.py \\
      --input data/clean/training_data.txt \\
      --vocab_size 32000 \\
      --model_type bpe \\
      --output_dir data/tokenizer/ \\
      --character_coverage 0.9995
        """,
    )

    # Required arguments
    parser.add_argument(
        "--input",
        required=True,
        help="Path to training text file (e.g., data/clean/training_data.txt)",
    )

    # Optional arguments with sensible defaults
    parser.add_argument(
        "--vocab_size",
        type=int,
        default=32000,
        help="Vocabulary size (default: 32000, recommended: 8k-64k)",
    )

    parser.add_argument(
        "--model_type",
        choices=["bpe", "unigram"],
        default="bpe",
        help="Tokenization algorithm (default: bpe)",
    )

    parser.add_argument(
        "--output_dir",
        default="data/tokenizer/",
        help="Output directory for tokenizer files (default: data/tokenizer/)",
    )

    parser.add_argument(
        "--character_coverage",
        type=float,
        default=0.9995,
        help="Character coverage (default: 0.9995 for English)",
    )

    parser.add_argument(
        "--max_sentence_length",
        type=int,
        default=4192,
        help="Maximum sentence length in characters (default: 4192)",
    )

    parser.add_argument(
        "--no_test", action="store_true", help="Skip tokenizer testing after training"
    )

    args = parser.parse_args()

    print("SentencePiece Tokenizer Training")
    print("=" * 50)

    try:
        # Step 1: Validate input file
        validate_input_file(args.input)

        # Step 2: Count training data
        sentence_count = count_training_sentences(args.input)

        # Step 3: Train tokenizer
        config = train_sentencepiece_tokenizer(
            input_path=args.input,
            output_dir=args.output_dir,
            vocab_size=args.vocab_size,
            model_type=args.model_type,
            character_coverage=args.character_coverage,
            max_sentence_length=args.max_sentence_length,
        )

        # Step 4: Save Hugging Face compatible config
        save_huggingface_config(args.output_dir, config)

        # Step 5: Test tokenizer (unless skipped)
        if not args.no_test:
            model_path = os.path.join(args.output_dir, "tokenizer.model")
            test_tokenizer(model_path)

        # Step 6: Print summary
        print("\nTokenizer training completed successfully!")
        print(f"Output directory: {args.output_dir}")
        print(f"Vocabulary size: {config['vocab_size']:,}")
        print(f"Training time: {config['training_time_seconds']:.1f}s")
        print(f"Training sentences: {sentence_count:,}")

        print("\nFiles created:")
        print(f" • {config['model_file']} - SentencePiece model")
        print(f" • {config['vocab_file']} - Vocabulary file")
        print(f" • {os.path.join(args.output_dir, 'tokenizer_config.json')} - Hugging Face config")

        print("\nTo use this tokenizer in your language model:")
        print(" import sentencepiece as spm")
        print(" sp = spm.SentencePieceProcessor()")
        print(f" sp.load('{config['model_file']}')")

    except Exception as e:
        print(f"\nError: {e}")
        exit(1)


if __name__ == "__main__":
    main()
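Once trained, the files written to `data/tokenizer/` can be used directly with the SentencePiece runtime, mirroring `test_tokenizer` above. A minimal sketch, assuming the default output directory:

```python
import sentencepiece as spm

# Load the model produced by train_sentencepiece_tokenizer (default output_dir).
sp = spm.SentencePieceProcessor()
sp.load("data/tokenizer/tokenizer.model")

ids = sp.encode("Hello, world!")                    # token IDs for model training
pieces = sp.encode("Hello, world!", out_type=str)   # subword pieces
print(pieces)
print(sp.decode(ids))                               # round-trips back to the input text
```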