Commit c27df58
Parent(s): e3ef0ba

Remove pycache, add gitignore
Files changed:
- .gitignore +5 -0
- README.md +2 -4
- src/__init__.py +2 -0
- src/data/__init__.py +31 -0
- src/data/dataloader.py +300 -0
- src/data/dataset.py +434 -0
- src/data/tokenizer.py +300 -0
- src/export/__init__.py +1 -0
- src/inference/__init__.py +1 -0
- src/model/__init__.py +24 -0
- src/model/attention.py +172 -0
- src/model/config.py +116 -0
- src/model/decoder.py +106 -0
- src/model/ffn.py +67 -0
- src/model/kv_cache.py +127 -0
- src/model/normalization.py +50 -0
- src/model/rope.py +110 -0
- src/model/transformer.py +323 -0
- src/training/__init__.py +16 -0
- src/training/loss.py +123 -0
- src/training/optimizer.py +197 -0
- src/training/trainer.py +511 -0
.gitignore
ADDED
@@ -0,0 +1,5 @@
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.DS_Store
README.md
CHANGED
@@ -111,10 +111,6 @@ Training Time: ~40 minutes
 
 ```bash
 pip install torch tokenizers huggingface_hub
-
-# Clone model architecture code
-git clone https://github.com/nameissakthi/slm-qualcomm
-cd slm-qualcomm
 ```
 
 ### Download Model
@@ -132,6 +128,8 @@ tokenizer_path = hf_hub_download(repo_id="nameissakthi/PebbleLM-117M-Chat", file
 ```python
 import torch
 from tokenizers import Tokenizer
+
+# Model architecture is included in this repo
 from src.model.transformer import SLMForCausalLM
 from src.model.config import SLMConfig
 
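For context, a minimal end-to-end load along the lines of the README snippet might look like the sketch below. The weight filename and the `SLMForCausalLM(config)` constructor signature are assumptions (the README excerpt above is truncated and `transformer.py` is not shown in this section), not confirmed by this commit.

```python
# Sketch only: the weight filename and SLMForCausalLM(config) constructor
# are assumptions, not confirmed by the diff above.
import torch
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer

from src.model.transformer import SLMForCausalLM
from src.model.config import SLMConfig

tokenizer_path = hf_hub_download(
    repo_id="nameissakthi/PebbleLM-117M-Chat", filename="tokenizer.json"
)
weights_path = hf_hub_download(
    repo_id="nameissakthi/PebbleLM-117M-Chat",
    filename="pytorch_model.bin",  # assumed filename
)

tokenizer = Tokenizer.from_file(tokenizer_path)
config = SLMConfig()  # defaults match the architecture in src/model/config.py
model = SLMForCausalLM(config)  # assumed constructor
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
model.eval()
```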
src/__init__.py
ADDED
@@ -0,0 +1,2 @@
# SLM Qualcomm - Conversational Small Language Model
__version__ = "1.0.0"
src/data/__init__.py
ADDED
@@ -0,0 +1,31 @@
# Data loading and tokenizer components

from .tokenizer import SLMTokenizer
from .dataset import (
    ConversationalDataset,
    StreamingTextDataset,
    PackedDataset,
    create_train_val_split,
    load_jsonl,
    save_jsonl,
)
from .dataloader import (
    DataModule,
    StreamingDataModule,
    create_dataloader,
    estimate_dataset_tokens,
)

__all__ = [
    "SLMTokenizer",
    "ConversationalDataset",
    "StreamingTextDataset",
    "PackedDataset",
    "create_train_val_split",
    "load_jsonl",
    "save_jsonl",
    "DataModule",
    "StreamingDataModule",
    "create_dataloader",
    "estimate_dataset_tokens",
]
src/data/dataloader.py
ADDED
@@ -0,0 +1,300 @@
"""
DataLoader utilities for SLM training.

Provides efficient batching and data loading for training.
"""

from typing import Dict, Optional, List

import torch
from torch.utils.data import DataLoader, Dataset, DistributedSampler

from .dataset import ConversationalDataset, StreamingTextDataset, PackedDataset
from .tokenizer import SLMTokenizer


def create_dataloader(
    dataset: Dataset,
    batch_size: int,
    shuffle: bool = True,
    num_workers: int = 4,
    pin_memory: Optional[bool] = None,  # Auto-detect based on device
    drop_last: bool = True,
    distributed: bool = False,
    world_size: int = 1,
    rank: int = 0,
) -> DataLoader:
    """Create a DataLoader with optimal settings.

    Args:
        dataset: The dataset to load from
        batch_size: Batch size per device
        shuffle: Whether to shuffle data
        num_workers: Number of data loading workers
        pin_memory: Pin memory for faster GPU transfer
        drop_last: Drop last incomplete batch
        distributed: Whether using distributed training
        world_size: Number of distributed processes
        rank: Current process rank

    Returns:
        Configured DataLoader
    """
    sampler = None
    if distributed:
        sampler = DistributedSampler(
            dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=shuffle,
        )
        shuffle = False  # Sampler handles shuffling

    # Auto-detect pin_memory: disable for MPS (not supported)
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()  # Only True for CUDA

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle if sampler is None else False,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
        collate_fn=default_collate_fn,
    )


def default_collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Collate function for batching samples.

    Args:
        batch: List of sample dictionaries

    Returns:
        Batched dictionary with stacked tensors
    """
    return {
        "input_ids": torch.stack([s["input_ids"] for s in batch]),
        "attention_mask": torch.stack([s["attention_mask"] for s in batch]),
        "labels": torch.stack([s["labels"] for s in batch]),
    }


class DataModule:
    """Data module for managing train/val dataloaders.

    Provides a unified interface for data loading during training.
    """

    def __init__(
        self,
        data_dir: str,
        tokenizer_path: str,
        max_length: int = 1024,
        batch_size: int = 32,
        num_workers: int = 4,
        val_batch_size: Optional[int] = None,
    ):
        """Initialize data module.

        Args:
            data_dir: Directory containing processed data
            tokenizer_path: Path to tokenizer.json
            max_length: Maximum sequence length
            batch_size: Training batch size
            num_workers: Number of data loading workers
            val_batch_size: Validation batch size (defaults to batch_size)
        """
        self.data_dir = data_dir
        self.max_length = max_length
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size or batch_size
        self.num_workers = num_workers

        # Load tokenizer
        self.tokenizer = SLMTokenizer.from_file(tokenizer_path)

        # Datasets (created on first access)
        self._train_dataset = None
        self._val_dataset = None

    @property
    def train_dataset(self) -> Dataset:
        """Get or create training dataset."""
        if self._train_dataset is None:
            self._train_dataset = ConversationalDataset(
                data_path=self.data_dir,
                tokenizer=self.tokenizer,
                max_length=self.max_length,
                split="train",
            )
        return self._train_dataset

    @property
    def val_dataset(self) -> Dataset:
        """Get or create validation dataset."""
        if self._val_dataset is None:
            self._val_dataset = ConversationalDataset(
                data_path=self.data_dir,
                tokenizer=self.tokenizer,
                max_length=self.max_length,
                split="val",
            )
        return self._val_dataset

    def train_dataloader(
        self,
        distributed: bool = False,
        world_size: int = 1,
        rank: int = 0,
    ) -> DataLoader:
        """Get training dataloader."""
        return create_dataloader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=True,
            distributed=distributed,
            world_size=world_size,
            rank=rank,
        )

    def val_dataloader(self) -> DataLoader:
        """Get validation dataloader."""
        return create_dataloader(
            self.val_dataset,
            batch_size=self.val_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False,
        )


class StreamingDataModule:
    """Data module for streaming large datasets.

    Memory-efficient loading for large text corpora.
    """

    def __init__(
        self,
        data_files: List[str],
        tokenizer_path: str,
        max_length: int = 1024,
        batch_size: int = 32,
        num_workers: int = 4,
    ):
        """Initialize streaming data module.

        Args:
            data_files: List of text file paths
            tokenizer_path: Path to tokenizer.json
            max_length: Maximum sequence length
            batch_size: Batch size
            num_workers: Number of data loading workers
        """
        self.data_files = data_files
        self.max_length = max_length
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Load tokenizer
        self.tokenizer = SLMTokenizer.from_file(tokenizer_path)

    def train_dataloader(self) -> DataLoader:
        """Get training dataloader for streaming data."""
        dataset = StreamingTextDataset(
            data_files=self.data_files,
            tokenizer=self.tokenizer,
            max_length=self.max_length,
            shuffle=True,
        )

        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=default_collate_fn,
        )


def estimate_dataset_tokens(data_dir: str, tokenizer_path: str) -> Dict[str, int]:
    """Estimate total tokens in a dataset.

    Args:
        data_dir: Directory containing data files
        tokenizer_path: Path to tokenizer

    Returns:
        Dictionary with token counts
    """
    import json
    from pathlib import Path

    tokenizer = SLMTokenizer.from_file(tokenizer_path)

    total_tokens = 0
    total_samples = 0

    for file_path in Path(data_dir).glob("*.json*"):
        with open(file_path, "r") as f:
            if file_path.suffix == ".jsonl":
                samples = [json.loads(line) for line in f if line.strip()]
            else:
                samples = json.load(f)
                if not isinstance(samples, list):
                    samples = [samples]

        for sample in samples:
            if "user" in sample and "assistant" in sample:
                tokens = tokenizer.encode_conversation(
                    sample["user"], sample["assistant"]
                )
            elif "text" in sample:
                tokens = tokenizer.encode(sample["text"])
            else:
                continue

            total_tokens += len(tokens)
            total_samples += 1

    return {
        "total_tokens": total_tokens,
        "total_samples": total_samples,
        "avg_tokens_per_sample": total_tokens / max(total_samples, 1),
    }


def get_dataloader_stats(dataloader: DataLoader) -> Dict[str, float]:
    """Get statistics from a dataloader.

    Args:
        dataloader: The dataloader to analyze

    Returns:
        Dictionary with statistics
    """
    total_batches = 0
    total_tokens = 0
    total_non_pad_tokens = 0

    for batch in dataloader:
        total_batches += 1
        total_tokens += batch["input_ids"].numel()
        total_non_pad_tokens += batch["attention_mask"].sum().item()

        # Only sample first 100 batches
        if total_batches >= 100:
            break

    return {
        "batches_sampled": total_batches,
        "tokens_per_batch": total_tokens / max(total_batches, 1),
        "non_pad_ratio": total_non_pad_tokens / max(total_tokens, 1),
    }
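As a quick smoke test of `create_dataloader` and `default_collate_fn` above, the sketch below feeds a hypothetical in-memory dataset (not part of this commit) shaped like `ConversationalDataset` output; it assumes you run from the repository root so `src` is importable.

```python
# Minimal sketch: ToyDataset is hypothetical, used only to exercise
# create_dataloader's batching and collation.
import torch
from torch.utils.data import Dataset

from src.data.dataloader import create_dataloader

class ToyDataset(Dataset):
    """Yields fixed-length samples shaped like ConversationalDataset output."""
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        ids = torch.arange(16, dtype=torch.long)
        return {
            "input_ids": ids,
            "attention_mask": torch.ones(16, dtype=torch.long),
            "labels": ids.clone(),
        }

loader = create_dataloader(ToyDataset(), batch_size=4, num_workers=0)
batch = next(iter(loader))
print(batch["input_ids"].shape)  # torch.Size([4, 16])
```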
src/data/dataset.py
ADDED
@@ -0,0 +1,434 @@
"""
Dataset classes for SLM training.

Handles loading, preprocessing, and tokenization of conversational data.
"""

import os
import json
import random
from typing import List, Dict, Optional, Iterator, Tuple
from pathlib import Path

import torch
from torch.utils.data import Dataset, IterableDataset

from .tokenizer import SLMTokenizer


class ConversationalDataset(Dataset):
    """Dataset for conversational/instruction-following data.

    Loads pre-tokenized data from disk for efficient training.
    Format: Each sample is a tokenized conversation with user/assistant turns.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: SLMTokenizer,
        max_length: int = 1024,
        split: str = "train",
    ):
        """Initialize the dataset.

        Args:
            data_path: Path to the processed data directory
            tokenizer: Tokenizer instance
            max_length: Maximum sequence length
            split: "train" or "val"
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.split = split

        # Load data
        self.samples = self._load_data(data_path)
        print(f"Loaded {len(self.samples)} samples for {split} split")

    def _load_data(self, data_path: str) -> List[Dict]:
        """Load data from JSON or JSONL files."""
        samples = []

        # Check for split-specific JSONL file first (preferred for large datasets)
        split_jsonl = os.path.join(data_path, f"{self.split}.jsonl")
        if os.path.exists(split_jsonl):
            with open(split_jsonl, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if line:
                        samples.append(json.loads(line))
            return samples

        # Check for split-specific JSON file
        split_file = os.path.join(data_path, f"{self.split}.json")
        if os.path.exists(split_file):
            with open(split_file, "r", encoding="utf-8") as f:
                content = f.read()
                f.seek(0)
                try:
                    # Try loading as a single JSON array
                    samples = json.loads(content)
                    if isinstance(samples, list):
                        return samples
                except json.JSONDecodeError:
                    pass

                # Fall back to JSONL (one JSON object per line)
                samples = []  # reset: json.loads may have succeeded with a non-list
                for line in f:
                    line = line.strip()
                    if line:
                        samples.append(json.loads(line))
                return samples

        # Check for combined file with splits
        combined_file = os.path.join(data_path, "data.json")
        if os.path.exists(combined_file):
            with open(combined_file, "r") as f:
                all_data = json.load(f)
            if isinstance(all_data, dict) and self.split in all_data:
                return all_data[self.split]
            return all_data

        # Load all .json and .jsonl files in directory
        for ext in ["*.jsonl", "*.json"]:
            for file in sorted(Path(data_path).glob(ext)):
                with open(file, "r", encoding="utf-8") as f:
                    if file.suffix == ".jsonl":
                        for line in f:
                            line = line.strip()
                            if line:
                                samples.append(json.loads(line))
                    else:
                        data = json.load(f)
                        if isinstance(data, list):
                            samples.extend(data)
                        else:
                            samples.append(data)

        return samples

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get a single sample.

        Returns:
            Dictionary with:
                - input_ids: Token IDs for the full sequence
                - attention_mask: 1 for real tokens, 0 for padding
                - labels: Same as input_ids but with -100 for padding (for loss)
        """
        sample = self.samples[idx]

        # Handle different data formats
        if "input_ids" in sample:
            # Pre-tokenized data
            input_ids = sample["input_ids"]
        elif "user" in sample and "assistant" in sample:
            # Raw conversation format
            input_ids = self.tokenizer.encode_conversation(
                user_message=sample["user"],
                assistant_message=sample["assistant"],
                max_length=self.max_length,
            )
        elif "text" in sample:
            # Raw text format
            input_ids = self.tokenizer.encode(
                sample["text"],
                add_special_tokens=True,
                max_length=self.max_length,
                truncation=True,
            )
        elif "question" in sample and "answer" in sample:
            # Q&A format
            input_ids = self.tokenizer.encode_conversation(
                user_message=sample["question"],
                assistant_message=sample["answer"],
                max_length=self.max_length,
            )
        else:
            raise ValueError(f"Unknown sample format: {list(sample.keys())}")

        # Pad or truncate
        if len(input_ids) > self.max_length:
            input_ids = input_ids[:self.max_length]
            # Ensure EOS at the end
            if input_ids[-1] != self.tokenizer.eos_token_id:
                input_ids[-1] = self.tokenizer.eos_token_id

        # Create attention mask (before padding)
        attention_mask = [1] * len(input_ids)

        # Pad if needed
        padding_length = self.max_length - len(input_ids)
        if padding_length > 0:
            input_ids = input_ids + [self.tokenizer.pad_token_id] * padding_length
            attention_mask = attention_mask + [0] * padding_length

        # Labels for language modeling (shift happens in loss function)
        # Use -100 for padding tokens so they're ignored in loss
        labels = [
            id if mask == 1 else -100
            for id, mask in zip(input_ids, attention_mask)
        ]

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }


class StreamingTextDataset(IterableDataset):
    """Streaming dataset for large text files.

    Memory-efficient dataset that streams data from disk.
    Useful for training on large text corpora.
    """

    def __init__(
        self,
        data_files: List[str],
        tokenizer: SLMTokenizer,
        max_length: int = 1024,
        shuffle: bool = True,
        seed: int = 42,
    ):
        """Initialize streaming dataset.

        Args:
            data_files: List of text file paths
            tokenizer: Tokenizer instance
            max_length: Maximum sequence length
            shuffle: Whether to shuffle files and lines
            seed: Random seed for shuffling
        """
        self.data_files = data_files
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.shuffle = shuffle
        self.seed = seed

        # Verify files exist
        for f in data_files:
            if not os.path.exists(f):
                raise FileNotFoundError(f"Data file not found: {f}")

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        """Iterate over all samples in all files."""
        worker_info = torch.utils.data.get_worker_info()

        # Handle multi-worker data loading
        if worker_info is None:
            files_to_process = self.data_files
        else:
            # Split files among workers
            per_worker = len(self.data_files) // worker_info.num_workers
            worker_id = worker_info.id
            start = worker_id * per_worker
            end = start + per_worker if worker_id < worker_info.num_workers - 1 else len(self.data_files)
            files_to_process = self.data_files[start:end]

        # Shuffle files if needed
        if self.shuffle:
            rng = random.Random(self.seed)
            files_to_process = list(files_to_process)
            rng.shuffle(files_to_process)

        # Buffer for accumulating tokens
        buffer = []

        for file_path in files_to_process:
            with open(file_path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue

                    # Try to parse as JSON (for conversational data)
                    try:
                        data = json.loads(line)
                        if "user" in data and "assistant" in data:
                            tokens = self.tokenizer.encode_conversation(
                                data["user"], data["assistant"]
                            )
                        elif "text" in data:
                            tokens = self.tokenizer.encode(
                                data["text"], add_special_tokens=True
                            )
                        else:
                            tokens = self.tokenizer.encode(
                                line, add_special_tokens=True
                            )
                    except json.JSONDecodeError:
                        # Plain text line
                        tokens = self.tokenizer.encode(
                            line, add_special_tokens=True
                        )

                    buffer.extend(tokens)

                    # Yield chunks of max_length
                    while len(buffer) >= self.max_length:
                        chunk = buffer[:self.max_length]
                        buffer = buffer[self.max_length:]

                        yield self._create_sample(chunk)

        # Handle remaining buffer (pad to max_length)
        if len(buffer) > 0:
            yield self._create_sample(buffer)

    def _create_sample(self, tokens: List[int]) -> Dict[str, torch.Tensor]:
        """Create a training sample from tokens."""
        input_ids = tokens[:self.max_length]

        # Pad if needed
        attention_mask = [1] * len(input_ids)
        padding_length = self.max_length - len(input_ids)
        if padding_length > 0:
            input_ids = input_ids + [self.tokenizer.pad_token_id] * padding_length
            attention_mask = attention_mask + [0] * padding_length

        labels = [
            id if mask == 1 else -100
            for id, mask in zip(input_ids, attention_mask)
        ]

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }


class PackedDataset(Dataset):
    """Dataset that packs multiple short sequences into one.

    Efficient for training when samples are shorter than max_length.
    Concatenates samples with separator tokens to fill sequences.
    """

    def __init__(
        self,
        samples: List[Dict],
        tokenizer: SLMTokenizer,
        max_length: int = 1024,
    ):
        """Initialize packed dataset.

        Args:
            samples: List of samples with "user" and "assistant" keys
            tokenizer: Tokenizer instance
            max_length: Maximum sequence length
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Pack sequences
        self.packed_samples = self._pack_sequences(samples)
        print(f"Packed {len(samples)} samples into {len(self.packed_samples)} sequences")

    def _pack_sequences(self, samples: List[Dict]) -> List[List[int]]:
        """Pack short sequences together."""
        packed = []
        current_sequence = []

        for sample in samples:
            # Tokenize
            if "user" in sample and "assistant" in sample:
                tokens = self.tokenizer.encode_conversation(
                    sample["user"], sample["assistant"]
                )
            elif "text" in sample:
                tokens = self.tokenizer.encode(sample["text"], add_special_tokens=True)
            else:
                continue

            # Check if we can add to current sequence
            if len(current_sequence) + len(tokens) <= self.max_length:
                current_sequence.extend(tokens)
            else:
                # Save current and start new
                if current_sequence:
                    packed.append(current_sequence)
                current_sequence = tokens[:self.max_length]

        # Don't forget the last sequence
        if current_sequence:
            packed.append(current_sequence)

        return packed

    def __len__(self) -> int:
        return len(self.packed_samples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get a packed sample."""
        tokens = self.packed_samples[idx]

        # Pad if needed
        attention_mask = [1] * len(tokens)
        padding_length = self.max_length - len(tokens)
        if padding_length > 0:
            tokens = tokens + [self.tokenizer.pad_token_id] * padding_length
            attention_mask = attention_mask + [0] * padding_length

        labels = [
            id if mask == 1 else -100
            for id, mask in zip(tokens, attention_mask)
        ]

        return {
            "input_ids": torch.tensor(tokens, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }


def create_train_val_split(
    samples: List[Dict],
    val_ratio: float = 0.01,
    seed: int = 42,
) -> Tuple[List[Dict], List[Dict]]:
    """Split samples into train and validation sets.

    Args:
        samples: List of all samples
        val_ratio: Ratio for validation set
        seed: Random seed

    Returns:
        Tuple of (train_samples, val_samples)
    """
    random.seed(seed)
    shuffled = list(samples)
    random.shuffle(shuffled)

    val_size = int(len(shuffled) * val_ratio)
    val_samples = shuffled[:val_size]
    train_samples = shuffled[val_size:]

    return train_samples, val_samples


def load_jsonl(file_path: str) -> List[Dict]:
    """Load data from a JSONL file."""
    samples = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                samples.append(json.loads(line))
    return samples


def save_jsonl(samples: List[Dict], file_path: str):
    """Save data to a JSONL file."""
    with open(file_path, "w", encoding="utf-8") as f:
        for sample in samples:
            f.write(json.dumps(sample) + "\n")
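The split/serialization helpers at the end of this file compose directly; a minimal round-trip sketch (toy samples, assuming the repo root is on `sys.path`):

```python
# Sketch of create_train_val_split + save_jsonl/load_jsonl; the sample
# contents here are illustrative, not from the repo.
from src.data.dataset import create_train_val_split, save_jsonl, load_jsonl

samples = [{"user": f"question {i}", "assistant": f"answer {i}"} for i in range(100)]
train, val = create_train_val_split(samples, val_ratio=0.1, seed=42)
print(len(train), len(val))  # 90 10

save_jsonl(train, "train.jsonl")
assert load_jsonl("train.jsonl") == train  # lossless round-trip
```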
src/data/tokenizer.py
ADDED
@@ -0,0 +1,300 @@
"""
Custom BPE Tokenizer for SLM v1.
16,384 vocabulary size optimized for conversational use.
"""

import os
import json
from typing import List, Optional, Union
from tokenizers import Tokenizer, models, trainers, pre_tokenizers, processors, decoders
from tokenizers.normalizers import NFKC


class SLMTokenizer:
    """Custom BPE tokenizer for the SLM model.

    Features:
    - 16,384 token vocabulary (memory efficient)
    - Special tokens for conversation format
    - Compatible with the model's embedding layer
    """

    # Special tokens
    PAD_TOKEN = "<|pad|>"
    BOS_TOKEN = "<|bos|>"
    EOS_TOKEN = "<|eos|>"
    UNK_TOKEN = "<|unk|>"
    USER_TOKEN = "<|user|>"
    ASSISTANT_TOKEN = "<|assistant|>"

    SPECIAL_TOKENS = [PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN, USER_TOKEN, ASSISTANT_TOKEN]

    def __init__(self, tokenizer: Optional[Tokenizer] = None):
        """Initialize tokenizer.

        Args:
            tokenizer: Pre-trained HuggingFace tokenizer object
        """
        self.tokenizer = tokenizer
        self._setup_special_token_ids()

    def _setup_special_token_ids(self):
        """Setup special token IDs for easy access."""
        if self.tokenizer is not None:
            self.pad_token_id = self.tokenizer.token_to_id(self.PAD_TOKEN)
            self.bos_token_id = self.tokenizer.token_to_id(self.BOS_TOKEN)
            self.eos_token_id = self.tokenizer.token_to_id(self.EOS_TOKEN)
            self.unk_token_id = self.tokenizer.token_to_id(self.UNK_TOKEN)
            self.user_token_id = self.tokenizer.token_to_id(self.USER_TOKEN)
            self.assistant_token_id = self.tokenizer.token_to_id(self.ASSISTANT_TOKEN)

    @classmethod
    def train(
        cls,
        files: List[str],
        vocab_size: int = 16384,
        min_frequency: int = 2,
        save_path: Optional[str] = None,
    ) -> "SLMTokenizer":
        """Train a new BPE tokenizer on the given files.

        Args:
            files: List of text file paths to train on
            vocab_size: Size of vocabulary (default 16,384)
            min_frequency: Minimum token frequency to include
            save_path: Optional path to save the trained tokenizer

        Returns:
            Trained SLMTokenizer instance
        """
        print(f"Training BPE tokenizer with vocab_size={vocab_size}...")
        print(f"Training files: {files}")

        # Initialize a BPE tokenizer
        tokenizer = Tokenizer(models.BPE(unk_token=cls.UNK_TOKEN))

        # Set up normalizer (optional - keeps text mostly as-is)
        # We use NFKC normalization to standardize unicode
        tokenizer.normalizer = NFKC()

        # Set up pre-tokenizer (splits on whitespace and punctuation)
        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)

        # Set up decoder
        tokenizer.decoder = decoders.ByteLevel()

        # Set up trainer
        trainer = trainers.BpeTrainer(
            vocab_size=vocab_size,
            min_frequency=min_frequency,
            special_tokens=cls.SPECIAL_TOKENS,
            show_progress=True,
        )

        # Train the tokenizer
        tokenizer.train(files, trainer)

        # Set up post-processor for adding special tokens
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls.BOS_TOKEN} $A {cls.EOS_TOKEN}",
            pair=f"{cls.BOS_TOKEN} $A {cls.EOS_TOKEN} {cls.BOS_TOKEN} $B {cls.EOS_TOKEN}",
            special_tokens=[
                (cls.BOS_TOKEN, tokenizer.token_to_id(cls.BOS_TOKEN)),
                (cls.EOS_TOKEN, tokenizer.token_to_id(cls.EOS_TOKEN)),
            ],
        )

        print(f"Tokenizer trained! Vocabulary size: {tokenizer.get_vocab_size()}")

        # Create instance
        instance = cls(tokenizer)

        # Save if path provided
        if save_path:
            instance.save(save_path)

        return instance

    @classmethod
    def from_file(cls, path: str) -> "SLMTokenizer":
        """Load a tokenizer from a saved file.

        Args:
            path: Path to the tokenizer.json file

        Returns:
            Loaded SLMTokenizer instance
        """
        tokenizer = Tokenizer.from_file(path)
        return cls(tokenizer)

    def save(self, path: str):
        """Save the tokenizer to a file.

        Args:
            path: Path to save the tokenizer (directory or file)
        """
        if os.path.isdir(path):
            save_path = os.path.join(path, "tokenizer.json")
        else:
            save_path = path
            parent = os.path.dirname(save_path)
            if parent:  # guard against bare filenames, where makedirs("") would fail
                os.makedirs(parent, exist_ok=True)

        self.tokenizer.save(save_path)
        print(f"Tokenizer saved to: {save_path}")

        # Also save config
        config_path = save_path.replace("tokenizer.json", "tokenizer_config.json")
        config = {
            "vocab_size": self.vocab_size,
            "pad_token": self.PAD_TOKEN,
            "bos_token": self.BOS_TOKEN,
            "eos_token": self.EOS_TOKEN,
            "unk_token": self.UNK_TOKEN,
            "user_token": self.USER_TOKEN,
            "assistant_token": self.ASSISTANT_TOKEN,
        }
        with open(config_path, "w") as f:
            json.dump(config, f, indent=2)
        print(f"Tokenizer config saved to: {config_path}")

    def encode(
        self,
        text: str,
        add_special_tokens: bool = True,
        max_length: Optional[int] = None,
        padding: bool = False,
        truncation: bool = False,
    ) -> List[int]:
        """Encode text to token IDs.

        Args:
            text: Input text string
            add_special_tokens: Whether to add BOS/EOS tokens
            max_length: Maximum sequence length
            padding: Whether to pad to max_length
            truncation: Whether to truncate to max_length

        Returns:
            List of token IDs
        """
        # Encode
        if add_special_tokens:
            encoding = self.tokenizer.encode(text)
        else:
            encoding = self.tokenizer.encode(text, add_special_tokens=False)

        ids = encoding.ids

        # Truncation
        if truncation and max_length and len(ids) > max_length:
            ids = ids[:max_length]
            # Ensure EOS at end if we had special tokens
            if add_special_tokens and ids[-1] != self.eos_token_id:
                ids[-1] = self.eos_token_id

        # Padding
        if padding and max_length and len(ids) < max_length:
            ids = ids + [self.pad_token_id] * (max_length - len(ids))

        return ids

    def decode(self, ids: List[int], skip_special_tokens: bool = True) -> str:
        """Decode token IDs to text.

        Args:
            ids: List of token IDs
            skip_special_tokens: Whether to remove special tokens

        Returns:
            Decoded text string
        """
        return self.tokenizer.decode(ids, skip_special_tokens=skip_special_tokens)

    def encode_conversation(
        self,
        user_message: str,
        assistant_message: Optional[str] = None,
        max_length: Optional[int] = None,
    ) -> List[int]:
        """Encode a conversation turn.

        Format: <|bos|><|user|>message<|assistant|>response<|eos|>

        Args:
            user_message: The user's message
            assistant_message: Optional assistant response
            max_length: Maximum sequence length

        Returns:
            List of token IDs
        """
        # Build conversation string
        if assistant_message:
            text = f"{self.USER_TOKEN}{user_message}{self.ASSISTANT_TOKEN}{assistant_message}"
        else:
            # For inference - no response yet
            text = f"{self.USER_TOKEN}{user_message}{self.ASSISTANT_TOKEN}"

        return self.encode(text, add_special_tokens=True, max_length=max_length, truncation=True)

    @property
    def vocab_size(self) -> int:
        """Get vocabulary size."""
        return self.tokenizer.get_vocab_size()

    def get_vocab(self) -> dict:
        """Get the vocabulary as a dictionary."""
        return self.tokenizer.get_vocab()

    def __len__(self) -> int:
        """Return vocabulary size."""
        return self.vocab_size

    def __call__(
        self,
        text: Union[str, List[str]],
        max_length: Optional[int] = None,
        padding: bool = False,
        truncation: bool = False,
        return_tensors: Optional[str] = None,
    ) -> dict:
        """Tokenize text (HuggingFace-style interface).

        Args:
            text: Input text or list of texts
            max_length: Maximum sequence length
            padding: Whether to pad sequences
            truncation: Whether to truncate sequences
            return_tensors: If "pt", return PyTorch tensors

        Returns:
            Dictionary with input_ids and attention_mask
        """
        if isinstance(text, str):
            text = [text]

        all_ids = []
        for t in text:
            ids = self.encode(
                t,
                max_length=max_length,
                padding=padding,
                truncation=truncation,
            )
            all_ids.append(ids)

        # Create attention mask (1 for real tokens, 0 for padding)
        attention_mask = [[1 if id != self.pad_token_id else 0 for id in ids] for ids in all_ids]

        result = {
            "input_ids": all_ids,
            "attention_mask": attention_mask,
        }

        if return_tensors == "pt":
            import torch
            result["input_ids"] = torch.tensor(all_ids)
            result["attention_mask"] = torch.tensor(attention_mask)

        return result
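A minimal usage sketch for `SLMTokenizer`, assuming a trained `tokenizer.json` already exists on disk (e.g., produced by `SLMTokenizer.train`) and the repo root is importable:

```python
# Sketch only: "tokenizer.json" is assumed to exist from a prior
# SLMTokenizer.train(...) run.
from src.data.tokenizer import SLMTokenizer

tok = SLMTokenizer.from_file("tokenizer.json")

# Inference-style prompt: ends with <|assistant|> so generation continues there
ids = tok.encode_conversation("What is RoPE?")
print(len(ids), ids[:5])

# Round-trip a plain string; special tokens are stripped on decode
text_ids = tok.encode("hello world", add_special_tokens=True)
print(tok.decode(text_ids))  # "hello world"
```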
src/export/__init__.py
ADDED
@@ -0,0 +1 @@
# ONNX export components
src/inference/__init__.py
ADDED
@@ -0,0 +1 @@
# Inference and generation components
src/model/__init__.py
ADDED
@@ -0,0 +1,24 @@
"""SLM Model Components."""

from .config import SLMConfig
from .transformer import SLMForCausalLM, SLMModel, SLMOutput
from .kv_cache import KVCache
from .normalization import RMSNorm
from .rope import RotaryEmbedding
from .attention import MultiHeadAttention, create_causal_mask
from .ffn import FeedForward
from .decoder import DecoderBlock

__all__ = [
    "SLMConfig",
    "SLMForCausalLM",
    "SLMModel",
    "SLMOutput",
    "KVCache",
    "RMSNorm",
    "RotaryEmbedding",
    "MultiHeadAttention",
    "create_causal_mask",
    "FeedForward",
    "DecoderBlock",
]
src/model/attention.py
ADDED
@@ -0,0 +1,172 @@
"""
Multi-Head Attention with explicit KV cache for SLM.
Qualcomm-safe: No FlashAttention, no fused ops, clean ONNX export.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple

from .config import SLMConfig
from .rope import RotaryEmbedding
from .kv_cache import KVCache


class MultiHeadAttention(nn.Module):
    """Multi-Head Self-Attention with RoPE and explicit KV cache.

    Design choices for Qualcomm compatibility:
    - Standard attention (no FlashAttention)
    - No grouped/multi-query attention (simpler, v1.1 will add GQA)
    - Explicit KV cache management
    - Clean tensor operations for ONNX export
    """

    def __init__(self, config: SLMConfig, layer_idx: int):
        """Initialize attention layer.

        Args:
            config: Model configuration
            layer_idx: Index of this layer (for KV cache)
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = config.head_dim
        self.dropout = config.attention_dropout

        # Q, K, V projections
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)

        # Output projection
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        # Rotary embeddings
        self.rotary_emb = RotaryEmbedding(
            dim=self.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[KVCache]]:
        """Forward pass for attention.

        Args:
            hidden_states: Input tensor [batch, seq_len, hidden_size]
            position_ids: Position indices [batch, seq_len]
            attention_mask: Causal mask [batch, 1, seq_len, kv_seq_len]
            kv_cache: Optional KV cache for inference
            use_cache: Whether to use/update KV cache

        Returns:
            Tuple of (output, kv_cache)
        """
        batch_size, seq_len, _ = hidden_states.shape

        # Project to Q, K, V
        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        # Reshape: [batch, seq, hidden] -> [batch, seq, heads, head_dim]
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim)
        key = key.view(batch_size, seq_len, self.num_heads, self.head_dim)
        value = value.view(batch_size, seq_len, self.num_heads, self.head_dim)

        # Transpose for attention: [batch, heads, seq, head_dim]
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        # Apply rotary embeddings to Q and K
        query, key = self.rotary_emb(query, key, position_ids)

        # Handle KV cache
        if use_cache and kv_cache is not None:
            # Get the position to write to cache
            cache_position = position_ids[0, 0].item()

            # Update cache and get full K, V
            key, value = kv_cache.update(
                layer_idx=self.layer_idx,
                key=key,
                value=value,
                position=cache_position,
            )

        # Compute attention scores
        # [batch, heads, seq, head_dim] @ [batch, heads, head_dim, kv_seq]
        # -> [batch, heads, seq, kv_seq]
        scale = 1.0 / (self.head_dim ** 0.5)
        attn_weights = torch.matmul(query, key.transpose(-2, -1)) * scale

        # Apply causal mask
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Softmax and dropout
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)

        if self.training and self.dropout > 0:
            attn_weights = F.dropout(attn_weights, p=self.dropout)

        # Apply attention to values
        # [batch, heads, seq, kv_seq] @ [batch, heads, kv_seq, head_dim]
        # -> [batch, heads, seq, head_dim]
        attn_output = torch.matmul(attn_weights, value)

        # Reshape back: [batch, heads, seq, head_dim] -> [batch, seq, hidden]
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.hidden_size)

        # Output projection
        output = self.o_proj(attn_output)

        return output, kv_cache


def create_causal_mask(
    seq_len: int,
    kv_seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """Create a causal attention mask.

    Args:
        seq_len: Query sequence length
        kv_seq_len: Key/value sequence length
        dtype: Data type for mask
        device: Device for mask

    Returns:
        Causal mask tensor [1, 1, seq_len, kv_seq_len]
    """
    # Create lower triangular mask
    mask = torch.full((seq_len, kv_seq_len), float("-inf"), dtype=dtype, device=device)

    # For decode (seq_len=1), we can attend to all previous tokens
    if seq_len == 1:
        mask = torch.zeros((seq_len, kv_seq_len), dtype=dtype, device=device)
    else:
        # For prefill, create standard causal mask
        # Position i can attend to positions 0..i
        for i in range(seq_len):
            # Offset for KV cache
            offset = kv_seq_len - seq_len
            mask[i, : offset + i + 1] = 0.0

    return mask.unsqueeze(0).unsqueeze(0)
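The mask semantics above are easy to check directly; this sketch exercises `create_causal_mask` in both the prefill and single-token decode regimes (shapes per its docstring):

```python
# Small check of create_causal_mask behavior.
import torch
from src.model.attention import create_causal_mask

# Prefill: 4 query tokens over 4 positions -> lower-triangular mask
mask = create_causal_mask(seq_len=4, kv_seq_len=4, dtype=torch.float32,
                          device=torch.device("cpu"))
print(mask.shape)        # torch.Size([1, 1, 4, 4])
print(mask[0, 0])        # 0.0 on/below the diagonal, -inf above

# Decode: one query token may attend to the whole cache -> all zeros
step = create_causal_mask(seq_len=1, kv_seq_len=8, dtype=torch.float32,
                          device=torch.device("cpu"))
print(step.abs().sum())  # tensor(0.)
```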
src/model/config.py
ADDED
|
@@ -0,0 +1,116 @@
"""
Model configuration for SLM v1.
Defines all hyperparameters based on the architecture specification.
"""

from dataclasses import dataclass

import yaml


@dataclass
class SLMConfig:
    """Configuration class for the SLM model.

    Architecture: ~117M parameter decoder-only transformer
    - 8 layers, 1024 hidden size, 16 attention heads
    - RMSNorm (pre-norm), GELU FFN, RoPE positions
    - Explicit KV cache for efficient inference
    """

    # Model architecture
    vocab_size: int = 16384
    hidden_size: int = 1024
    num_layers: int = 8
    num_heads: int = 16
    head_dim: int = 64
    intermediate_size: int = 4096  # 4 * hidden_size

    # Position encoding
    max_position_embeddings: int = 1024
    rope_theta: float = 10000.0

    # Normalization
    rms_norm_eps: float = 1e-6

    # Embeddings
    tie_word_embeddings: bool = True

    # Dropout (disabled for inference, optional for training)
    dropout: float = 0.0
    attention_dropout: float = 0.0

    # Precision
    torch_dtype: str = "float16"

    def __post_init__(self):
        """Validate configuration after initialization."""
        assert self.hidden_size % self.num_heads == 0, \
            f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads})"
        assert self.head_dim == self.hidden_size // self.num_heads, \
            f"head_dim ({self.head_dim}) must equal hidden_size // num_heads ({self.hidden_size // self.num_heads})"

    @classmethod
    def from_yaml(cls, path: str) -> "SLMConfig":
        """Load configuration from a YAML file."""
        with open(path, "r") as f:
            config_dict = yaml.safe_load(f)

        model_config = config_dict.get("model", {})
        return cls(**model_config)

    def to_dict(self) -> dict:
        """Convert configuration to a dictionary."""
        return {
            "vocab_size": self.vocab_size,
            "hidden_size": self.hidden_size,
            "num_layers": self.num_layers,
            "num_heads": self.num_heads,
            "head_dim": self.head_dim,
            "intermediate_size": self.intermediate_size,
            "max_position_embeddings": self.max_position_embeddings,
            "rope_theta": self.rope_theta,
            "rms_norm_eps": self.rms_norm_eps,
            "tie_word_embeddings": self.tie_word_embeddings,
            "dropout": self.dropout,
            "attention_dropout": self.attention_dropout,
            "torch_dtype": self.torch_dtype,
        }

    @property
    def num_parameters(self) -> int:
        """Estimate the total number of parameters."""
        # Embedding: vocab_size * hidden_size
        embedding_params = self.vocab_size * self.hidden_size

        # Per layer:
        # - Attention: 4 * hidden_size^2 (Q, K, V, O projections)
        # - FFN: 2 * hidden_size * intermediate_size
        # - Norms: 2 * hidden_size
        attention_params = 4 * self.hidden_size * self.hidden_size
        ffn_params = 2 * self.hidden_size * self.intermediate_size
        norm_params = 2 * self.hidden_size

        layer_params = attention_params + ffn_params + norm_params
        total_layer_params = self.num_layers * layer_params

        # Output head (tied with embedding if enabled)
        output_params = 0 if self.tie_word_embeddings else self.vocab_size * self.hidden_size

        # Final norm
        final_norm_params = self.hidden_size

        return embedding_params + total_layer_params + output_params + final_norm_params

    def __repr__(self) -> str:
        params_m = self.num_parameters / 1e6
        return (
            f"SLMConfig(\n"
            f"  vocab_size={self.vocab_size},\n"
            f"  hidden_size={self.hidden_size},\n"
            f"  num_layers={self.num_layers},\n"
            f"  num_heads={self.num_heads},\n"
            f"  max_position_embeddings={self.max_position_embeddings},\n"
            f"  estimated_params={params_m:.1f}M\n"
            f")"
        )
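As a sanity check on the estimate (a sketch under the same import assumption), the defaults work out to about 117.5M parameters with tied embeddings, matching the model size in the README:

```python
from src.model.config import SLMConfig

config = SLMConfig()
print(config.num_parameters)  # 117457920  (~117.5M)
# = 16384*1024 (embedding) + 8 * (4*1024**2 + 2*1024*4096 + 2*1024) + 1024 (final norm)
```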
src/model/decoder.py
ADDED
@@ -0,0 +1,106 @@
"""
Decoder Block for SLM.
Pre-norm architecture with residual connections.
"""

import torch
import torch.nn as nn
from typing import Optional, Tuple

from .config import SLMConfig
from .normalization import RMSNorm
from .attention import MultiHeadAttention
from .ffn import FeedForward
from .kv_cache import KVCache


class DecoderBlock(nn.Module):
    """Single decoder block with pre-norm architecture.

    Structure (Pre-Norm):
    ```
    x
    ├─ RMSNorm
    ├─ Multi-Head Attention
    ├─ Residual Add
    ├─ RMSNorm
    ├─ Feed-Forward Network
    └─ Residual Add
    ```

    Why pre-norm:
    - More stable gradients in FP16 training
    - Better quantization behavior
    - Easier ONNX export (no layer-crossing dependencies)
    """

    def __init__(self, config: SLMConfig, layer_idx: int):
        """Initialize decoder block.

        Args:
            config: Model configuration
            layer_idx: Index of this layer
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        # Pre-attention norm
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Self-attention
        self.self_attn = MultiHeadAttention(config, layer_idx)

        # Pre-FFN norm
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Feed-forward network
        self.mlp = FeedForward(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[KVCache]]:
        """Forward pass through decoder block.

        Args:
            hidden_states: Input tensor [batch, seq, hidden_size]
            position_ids: Position indices [batch, seq]
            attention_mask: Causal attention mask
            kv_cache: Optional KV cache
            use_cache: Whether to use/update cache

        Returns:
            Tuple of (output, kv_cache)
        """
        # Store residual
        residual = hidden_states

        # Pre-norm -> Attention
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, kv_cache = self.self_attn(
            hidden_states=hidden_states,
            position_ids=position_ids,
            attention_mask=attention_mask,
            kv_cache=kv_cache,
            use_cache=use_cache,
        )

        # Residual connection
        hidden_states = residual + hidden_states

        # Store residual
        residual = hidden_states

        # Pre-norm -> FFN
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)

        # Residual connection
        hidden_states = residual + hidden_states

        return hidden_states, kv_cache
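A minimal shape check for one block (a sketch; it assumes MultiHeadAttention skips the cache path when kv_cache is None, as its Optional signature suggests):

```python
import torch
from src.model.attention import create_causal_mask
from src.model.config import SLMConfig
from src.model.decoder import DecoderBlock

config = SLMConfig()
block = DecoderBlock(config, layer_idx=0)

x = torch.randn(2, 16, config.hidden_size)                  # [batch, seq, hidden]
position_ids = torch.arange(16).unsqueeze(0).expand(2, -1)  # [batch, seq]
mask = create_causal_mask(16, 16, dtype=x.dtype, device=x.device)

out, _ = block(x, position_ids=position_ids, attention_mask=mask)
print(out.shape)  # torch.Size([2, 16, 1024])
```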
src/model/ffn.py
ADDED
@@ -0,0 +1,67 @@
"""
Feed-Forward Network for SLM.
Uses GELU activation (not SwiGLU) for better INT8 quantization.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import SLMConfig


class FeedForward(nn.Module):
    """Feed-Forward Network with GELU activation.

    Architecture: Linear -> GELU -> Linear
    - Input: [batch, seq, hidden_size=1024]
    - Hidden: [batch, seq, intermediate_size=4096]
    - Output: [batch, seq, hidden_size=1024]

    Why GELU over SwiGLU:
    - Fewer operations (2 matmuls vs 3)
    - Better INT8 quantization behavior
    - Full QNN support without decomposition
    - SwiGLU benefits mainly appear at >1B parameters
    """

    def __init__(self, config: SLMConfig):
        """Initialize FFN.

        Args:
            config: Model configuration
        """
        super().__init__()

        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        # Up projection: hidden -> intermediate
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)

        # Down projection: intermediate -> hidden
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

        self.dropout = config.dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through FFN.

        Args:
            x: Input tensor [batch, seq, hidden_size]

        Returns:
            Output tensor [batch, seq, hidden_size]
        """
        # Up project and apply GELU
        hidden = self.up_proj(x)
        hidden = F.gelu(hidden, approximate="tanh")

        # Down project
        output = self.down_proj(hidden)

        # Apply dropout during training
        if self.training and self.dropout > 0:
            output = F.dropout(output, p=self.dropout)

        return output
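The FFN dominates the per-layer parameter budget; a quick check with the default config (sketch):

```python
import torch
from src.model.config import SLMConfig
from src.model.ffn import FeedForward

ffn = FeedForward(SLMConfig())
print(sum(p.numel() for p in ffn.parameters()))  # 8388608 = 2 * 1024 * 4096 (no biases)

out = ffn(torch.randn(1, 8, 1024))
print(out.shape)  # torch.Size([1, 8, 1024])
```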
src/model/kv_cache.py
ADDED
@@ -0,0 +1,127 @@
"""
Explicit KV Cache management for efficient inference.
This is critical for Qualcomm deployment and agent control loops.
"""

import torch
from typing import Optional, Tuple
from dataclasses import dataclass


@dataclass
class KVCache:
    """Key-Value cache for transformer inference.

    Layout: [num_layers, batch_size, num_heads, max_seq_len, head_dim]

    This explicit cache enables:
    - Efficient autoregressive decoding
    - Cache offloading for memory management
    - Sliding window attention (future)
    - Agent control loops with cache manipulation
    """

    key_cache: torch.Tensor    # [num_layers, batch, heads, max_len, head_dim]
    value_cache: torch.Tensor  # [num_layers, batch, heads, max_len, head_dim]
    seq_len: int               # Current sequence length in cache

    @classmethod
    def create(
        cls,
        num_layers: int,
        batch_size: int,
        num_heads: int,
        max_seq_len: int,
        head_dim: int,
        dtype: torch.dtype = torch.float16,
        device: Optional[torch.device] = None,
    ) -> "KVCache":
        """Create an empty KV cache.

        Args:
            num_layers: Number of transformer layers
            batch_size: Batch size
            num_heads: Number of attention heads
            max_seq_len: Maximum sequence length
            head_dim: Dimension per attention head
            dtype: Data type for cache tensors
            device: Device to create cache on

        Returns:
            Initialized KVCache with zero tensors
        """
        shape = (num_layers, batch_size, num_heads, max_seq_len, head_dim)

        key_cache = torch.zeros(shape, dtype=dtype, device=device)
        value_cache = torch.zeros(shape, dtype=dtype, device=device)

        return cls(key_cache=key_cache, value_cache=value_cache, seq_len=0)

    def update(
        self,
        layer_idx: int,
        key: torch.Tensor,
        value: torch.Tensor,
        position: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update cache for a specific layer and return full K, V.

        Args:
            layer_idx: Index of the transformer layer
            key: New key tensor [batch, heads, seq_len, head_dim]
            value: New value tensor [batch, heads, seq_len, head_dim]
            position: Starting position for the new tokens

        Returns:
            Tuple of (full_key, full_value) including cached values
        """
        seq_len = key.shape[2]
        end_pos = position + seq_len

        # Store new keys and values
        self.key_cache[layer_idx, :, :, position:end_pos, :] = key
        self.value_cache[layer_idx, :, :, position:end_pos, :] = value

        # Update sequence length
        self.seq_len = max(self.seq_len, end_pos)

        # Return full K, V up to current position
        return (
            self.key_cache[layer_idx, :, :, :end_pos, :],
            self.value_cache[layer_idx, :, :, :end_pos, :],
        )

    def get(
        self,
        layer_idx: int,
        end_pos: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get cached K, V for a specific layer.

        Args:
            layer_idx: Index of the transformer layer
            end_pos: End position (defaults to current seq_len)

        Returns:
            Tuple of (key, value) tensors
        """
        if end_pos is None:
            end_pos = self.seq_len

        return (
            self.key_cache[layer_idx, :, :, :end_pos, :],
            self.value_cache[layer_idx, :, :, :end_pos, :],
        )

    def reset(self):
        """Reset the cache to an empty state."""
        self.key_cache.zero_()
        self.value_cache.zero_()
        self.seq_len = 0

    @property
    def memory_usage_mb(self) -> float:
        """Calculate memory usage in megabytes."""
        total_bytes = self.key_cache.numel() * self.key_cache.element_size()
        total_bytes += self.value_cache.numel() * self.value_cache.element_size()
        return total_bytes / (1024 * 1024)
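With the default architecture, a full-length FP16 cache is modest; a quick check (sketch):

```python
import torch
from src.model.kv_cache import KVCache

cache = KVCache.create(
    num_layers=8, batch_size=1, num_heads=16,
    max_seq_len=1024, head_dim=64, dtype=torch.float16,
)
# 2 tensors * (8*1*16*1024*64) elements * 2 bytes = 32 MiB
print(cache.memory_usage_mb)  # 32.0
```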
src/model/normalization.py
ADDED
@@ -0,0 +1,50 @@
"""
RMSNorm implementation for SLM.
Pre-norm architecture for stable FP16 training and better quantization.
"""

import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    RMSNorm is computationally simpler than LayerNorm as it doesn't
    compute mean statistics. This makes it:
    - Faster to compute
    - More stable in FP16
    - Better for quantization

    Reference: https://arxiv.org/abs/1910.07467
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        """Initialize RMSNorm.

        Args:
            hidden_size: The size of the hidden dimension
            eps: Small constant for numerical stability
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply RMS normalization.

        Args:
            x: Input tensor of shape [..., hidden_size]

        Returns:
            Normalized tensor of same shape
        """
        # Compute RMS: sqrt(mean(x^2))
        # Use float32 for numerical stability, then cast back
        input_dtype = x.dtype
        x = x.float()

        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)

        return (self.weight * x).to(input_dtype)
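Since the learned weight starts at ones, the output initially has unit RMS regardless of input scale; a numeric check (sketch):

```python
import torch
from src.model.normalization import RMSNorm

norm = RMSNorm(hidden_size=1024)
x = torch.randn(2, 5, 1024) * 3.0    # arbitrary input scale
y = norm(x)
print(y.pow(2).mean(dim=-1).sqrt())  # ~1.0 everywhere
```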
src/model/rope.py
ADDED
@@ -0,0 +1,110 @@
"""
Rotary Position Embedding (RoPE) implementation.
Applied to Q and K only, with a fixed base (no dynamic scaling).
"""

import torch
import torch.nn as nn
from typing import Tuple


class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE).

    RoPE encodes position information by rotating the query and key vectors.
    Key properties:
    - Parameter-free (no learnable embeddings)
    - Naturally encodes relative positions
    - Extrapolates well to longer sequences

    Reference: https://arxiv.org/abs/2104.09864
    """

    def __init__(
        self,
        dim: int,
        max_position_embeddings: int = 1024,
        base: float = 10000.0,
    ):
        """Initialize RoPE.

        Args:
            dim: Dimension of the rotary embedding (usually head_dim)
            max_position_embeddings: Maximum sequence length
            base: Base for the frequency computation
        """
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # Precompute inverse frequencies
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Precompute cos and sin for all positions
        self._set_cos_sin_cache(max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len: int):
        """Precompute cos and sin values for positions."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)

        # Outer product: [seq_len] x [dim/2] -> [seq_len, dim/2]
        freqs = torch.outer(t, self.inv_freq)

        # Concatenate to get [seq_len, dim]
        emb = torch.cat((freqs, freqs), dim=-1)

        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to query and key tensors.

        Args:
            q: Query tensor of shape [batch, num_heads, seq_len, head_dim]
            k: Key tensor of shape [batch, num_heads, seq_len, head_dim]
            position_ids: Position indices of shape [batch, seq_len]

        Returns:
            Tuple of (rotated_q, rotated_k) with same shapes as inputs
        """
        # Convert to a plain int so it can be passed to torch.arange when
        # the cache needs extending
        seq_len = int(position_ids.max().item()) + 1

        # Extend cache if needed
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len)

        # Get cos and sin for the positions
        # Shape: [batch, seq_len, dim]
        cos = self.cos_cached[position_ids]
        sin = self.sin_cached[position_ids]

        # Add head dimension: [batch, 1, seq_len, dim]
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)

        # Apply rotation
        q_embed = (q * cos) + (self._rotate_half(q) * sin)
        k_embed = (k * cos) + (self._rotate_half(k) * sin)

        return q_embed, k_embed

    @staticmethod
    def _rotate_half(x: torch.Tensor) -> torch.Tensor:
        """Rotate half the hidden dims of the input.

        Splits the input into two halves and rotates:
        [x1, x2, x3, x4] -> [-x3, -x4, x1, x2]
        """
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)
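Because RoPE is a pure per-pair rotation, it preserves vector norms; a quick check (sketch):

```python
import torch
from src.model.rope import RotaryEmbedding

rope = RotaryEmbedding(dim=64)
q = torch.randn(1, 16, 10, 64)  # [batch, heads, seq, head_dim]
k = torch.randn(1, 16, 10, 64)
position_ids = torch.arange(10).unsqueeze(0)

q_rot, k_rot = rope(q, k, position_ids)
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))  # True
```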
src/model/transformer.py
ADDED
@@ -0,0 +1,323 @@
"""
Full Transformer model for SLM.
Implements the mandatory prefill/decode API for Qualcomm deployment.
"""

import torch
import torch.nn as nn
from typing import Optional, Tuple
from dataclasses import dataclass

from .config import SLMConfig
from .normalization import RMSNorm
from .decoder import DecoderBlock
from .attention import create_causal_mask
from .kv_cache import KVCache


@dataclass
class SLMOutput:
    """Output from SLM forward pass."""

    logits: torch.Tensor  # [batch, seq, vocab_size]
    kv_cache: Optional[KVCache] = None
    hidden_states: Optional[torch.Tensor] = None


class SLMModel(nn.Module):
    """Core transformer model (without LM head).

    This is the decoder stack:
    - Token embedding
    - N decoder blocks
    - Final RMSNorm
    """

    def __init__(self, config: SLMConfig):
        """Initialize transformer model.

        Args:
            config: Model configuration
        """
        super().__init__()
        self.config = config

        # Token embeddings
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

        # Decoder layers
        self.layers = nn.ModuleList([
            DecoderBlock(config, layer_idx=i)
            for i in range(config.num_layers)
        ])

        # Final normalization
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[KVCache]]:
        """Forward pass through transformer.

        Args:
            input_ids: Token IDs [batch, seq]
            position_ids: Position indices [batch, seq]
            attention_mask: Causal mask
            kv_cache: Optional KV cache
            use_cache: Whether to use/update cache

        Returns:
            Tuple of (hidden_states, kv_cache)
        """
        batch_size, seq_len = input_ids.shape

        # Create position IDs if not provided
        if position_ids is None:
            if kv_cache is not None and kv_cache.seq_len > 0:
                # For decode: positions continue from the current cache length
                position_ids = torch.arange(
                    kv_cache.seq_len, kv_cache.seq_len + seq_len,
                    device=input_ids.device
                ).unsqueeze(0).expand(batch_size, -1)
            else:
                # For prefill: positions are 0..seq_len-1
                position_ids = torch.arange(
                    seq_len, device=input_ids.device
                ).unsqueeze(0).expand(batch_size, -1)

        # Create attention mask if not provided
        if attention_mask is None:
            kv_seq_len = seq_len
            if kv_cache is not None and kv_cache.seq_len > 0:
                kv_seq_len = kv_cache.seq_len + seq_len

            attention_mask = create_causal_mask(
                seq_len=seq_len,
                kv_seq_len=kv_seq_len,
                dtype=self.embed_tokens.weight.dtype,
                device=input_ids.device,
            )

        # Token embeddings
        hidden_states = self.embed_tokens(input_ids)

        # Pass through decoder layers
        for layer in self.layers:
            hidden_states, kv_cache = layer(
                hidden_states=hidden_states,
                position_ids=position_ids,
                attention_mask=attention_mask,
                kv_cache=kv_cache,
                use_cache=use_cache,
            )

        # Final normalization
        hidden_states = self.norm(hidden_states)

        return hidden_states, kv_cache


class SLMForCausalLM(nn.Module):
    """SLM with language modeling head.

    This is the full model with:
    - Transformer backbone
    - LM head (tied with embeddings)
    - Prefill/Decode API for Qualcomm deployment
    """

    def __init__(self, config: SLMConfig):
        """Initialize causal LM.

        Args:
            config: Model configuration
        """
        super().__init__()
        self.config = config

        # Transformer backbone
        self.model = SLMModel(config)

        # LM head
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Tie weights if configured
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module):
        """Initialize model weights."""
        std = 0.02
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
        use_cache: bool = False,
        labels: Optional[torch.Tensor] = None,
    ) -> SLMOutput:
        """Forward pass for causal LM.

        Args:
            input_ids: Token IDs [batch, seq]
            position_ids: Position indices [batch, seq]
            attention_mask: Causal mask
            kv_cache: Optional KV cache
            use_cache: Whether to use/update cache
            labels: Optional labels (accepted for API parity; the loss
                itself is computed by the training loop)

        Returns:
            SLMOutput with logits, updated cache, and hidden states
        """
        # Get hidden states from transformer
        hidden_states, kv_cache = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            kv_cache=kv_cache,
            use_cache=use_cache,
        )

        # Compute logits
        logits = self.lm_head(hidden_states)

        return SLMOutput(
            logits=logits,
            kv_cache=kv_cache,
            hidden_states=hidden_states,
        )

    # =========================================================================
    # MANDATORY KV CACHE API (from architecture.txt)
    # =========================================================================

    def prefill(
        self,
        input_ids: torch.Tensor,
        kv_cache: Optional[KVCache] = None,
    ) -> Tuple[torch.Tensor, KVCache]:
        """Prefill: Process the full prompt and populate the KV cache.

        This is Graph 1 for Qualcomm deployment.

        Args:
            input_ids: Token IDs [batch, seq]
            kv_cache: Empty or existing KV cache

        Returns:
            Tuple of (logits [batch, seq, vocab], populated_kv_cache)
        """
        batch_size = input_ids.shape[0]

        # Create an empty cache if not provided
        if kv_cache is None:
            kv_cache = KVCache.create(
                num_layers=self.config.num_layers,
                batch_size=batch_size,
                num_heads=self.config.num_heads,
                max_seq_len=self.config.max_position_embeddings,
                head_dim=self.config.head_dim,
                dtype=self.model.embed_tokens.weight.dtype,
                device=input_ids.device,
            )

        # Forward pass with cache
        output = self.forward(
            input_ids=input_ids,
            kv_cache=kv_cache,
            use_cache=True,
        )

        return output.logits, output.kv_cache

    def decode(
        self,
        input_id: torch.Tensor,
        kv_cache: KVCache,
        position: Optional[int] = None,
    ) -> Tuple[torch.Tensor, KVCache]:
        """Decode: Generate a single token using the KV cache.

        This is Graph 2 for Qualcomm deployment.

        Args:
            input_id: Single token ID [batch, 1]
            kv_cache: Populated KV cache from prefill or a previous decode
            position: Position index (defaults to cache.seq_len)

        Returns:
            Tuple of (logits [batch, 1, vocab], updated_kv_cache)
        """
        batch_size = input_id.shape[0]

        # Get position from cache if not provided
        if position is None:
            position = kv_cache.seq_len

        # Create position IDs
        position_ids = torch.tensor(
            [[position]], device=input_id.device
        ).expand(batch_size, -1)

        # Forward pass with cache
        output = self.forward(
            input_ids=input_id,
            position_ids=position_ids,
            kv_cache=kv_cache,
            use_cache=True,
        )

        return output.logits, output.kv_cache

    def create_empty_cache(
        self,
        batch_size: int = 1,
        device: Optional[torch.device] = None,
    ) -> KVCache:
        """Create an empty KV cache for inference.

        Args:
            batch_size: Batch size
            device: Device for cache tensors

        Returns:
            Empty KVCache ready for prefill
        """
        if device is None:
            device = self.model.embed_tokens.weight.device

        return KVCache.create(
            num_layers=self.config.num_layers,
            batch_size=batch_size,
            num_heads=self.config.num_heads,
            max_seq_len=self.config.max_position_embeddings,
            head_dim=self.config.head_dim,
            dtype=self.model.embed_tokens.weight.dtype,
            device=device,
        )

    @property
    def num_parameters(self) -> int:
        """Count total trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    @property
    def device(self) -> torch.device:
        """Get model device."""
        return self.model.embed_tokens.weight.device
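Greedy generation ties the two graphs together: one prefill, then repeated single-token decodes. A usage sketch (freshly initialized weights, so the output tokens are meaningless; sampling and stop conditions omitted):

```python
import torch
from src.model.config import SLMConfig
from src.model.transformer import SLMForCausalLM

model = SLMForCausalLM(SLMConfig()).eval()
prompt_ids = torch.tensor([[1, 42, 7, 99]])  # hypothetical token IDs

with torch.no_grad():
    # Graph 1: prefill processes the prompt and populates the cache
    logits, cache = model.prefill(prompt_ids)
    next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # [1, 1]

    # Graph 2: decode reuses the cache, one token per step
    for _ in range(10):
        logits, cache = model.decode(next_id, cache)
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
```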
src/training/__init__.py
ADDED
@@ -0,0 +1,16 @@
# Training components

from .loss import LanguageModelingLoss, compute_perplexity, compute_accuracy
from .optimizer import create_optimizer, create_scheduler, clip_grad_norm
from .trainer import Trainer, TrainingConfig

__all__ = [
    "LanguageModelingLoss",
    "compute_perplexity",
    "compute_accuracy",
    "create_optimizer",
    "create_scheduler",
    "clip_grad_norm",
    "Trainer",
    "TrainingConfig",
]
src/training/loss.py
ADDED
@@ -0,0 +1,123 @@
"""
Loss functions for SLM training.

Cross-entropy loss with optional label smoothing.
"""

import torch
import torch.nn as nn


class LanguageModelingLoss(nn.Module):
    """Cross-entropy loss for language modeling.

    Handles:
    - Automatic shifting of labels
    - Ignoring padding tokens (-100)
    - Optional label smoothing
    """

    def __init__(
        self,
        vocab_size: int,
        label_smoothing: float = 0.0,
        ignore_index: int = -100,
    ):
        """Initialize loss function.

        Args:
            vocab_size: Size of vocabulary
            label_smoothing: Label smoothing factor (0.0 = no smoothing)
            ignore_index: Index to ignore in loss calculation (padding)
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.label_smoothing = label_smoothing
        self.ignore_index = ignore_index

        self.ce_loss = nn.CrossEntropyLoss(
            ignore_index=ignore_index,
            label_smoothing=label_smoothing,
        )

    def forward(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        shift_labels: bool = True,
    ) -> torch.Tensor:
        """Compute loss.

        Args:
            logits: Model output logits [batch_size, seq_len, vocab_size]
            labels: Target token IDs [batch_size, seq_len]
            shift_labels: Whether to shift labels (for autoregressive LM)

        Returns:
            Scalar loss tensor
        """
        if shift_labels:
            # Shift for next-token prediction: the logit at position i
            # predicts token i+1, so drop the last logit and the first label
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
        else:
            shift_logits = logits
            shift_labels = labels

        # Flatten for cross-entropy
        # [batch * seq_len, vocab_size]
        flat_logits = shift_logits.view(-1, self.vocab_size)
        # [batch * seq_len]
        flat_labels = shift_labels.view(-1)

        loss = self.ce_loss(flat_logits, flat_labels)

        return loss


def compute_perplexity(loss: torch.Tensor) -> torch.Tensor:
    """Compute perplexity from cross-entropy loss.

    Args:
        loss: Cross-entropy loss value

    Returns:
        Perplexity (exp of loss)
    """
    return torch.exp(loss)


def compute_accuracy(
    logits: torch.Tensor,
    labels: torch.Tensor,
    ignore_index: int = -100,
) -> torch.Tensor:
    """Compute token prediction accuracy.

    Args:
        logits: Model output logits [batch_size, seq_len, vocab_size]
        labels: Target token IDs [batch_size, seq_len]
        ignore_index: Index to ignore in accuracy calculation

    Returns:
        Accuracy as a scalar tensor
    """
    # Shift for autoregressive prediction
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Get predictions
    predictions = shift_logits.argmax(dim=-1)

    # Mask for valid positions
    mask = shift_labels != ignore_index

    # Compute accuracy on valid positions
    correct = (predictions == shift_labels) & mask
    accuracy = correct.sum().float() / mask.sum().float()

    return accuracy
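With uniform (all-zero) logits the cross-entropy is ln(vocab_size), so perplexity equals the vocabulary size; a quick check (sketch):

```python
import torch
from src.training.loss import LanguageModelingLoss, compute_perplexity

loss_fn = LanguageModelingLoss(vocab_size=16384)
logits = torch.zeros(2, 8, 16384)      # uniform distribution over tokens
labels = torch.randint(0, 16384, (2, 8))

loss = loss_fn(logits, labels)         # shifts internally
print(loss)                            # ln(16384) ~ 9.70
print(compute_perplexity(loss))        # ~16384.0
```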
src/training/optimizer.py
ADDED
@@ -0,0 +1,197 @@
"""
Optimizer and learning rate scheduler for SLM training.

Uses AdamW with cosine annealing and warmup.
"""

import math
from typing import Tuple

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR


def create_optimizer(
    model: torch.nn.Module,
    learning_rate: float = 3e-4,
    weight_decay: float = 0.1,
    betas: Tuple[float, float] = (0.9, 0.95),
    eps: float = 1e-8,
) -> AdamW:
    """Create AdamW optimizer with weight decay.

    Applies weight decay only to 2D parameters (weights, not biases/norms).

    Args:
        model: The model to optimize
        learning_rate: Base learning rate
        weight_decay: Weight decay coefficient
        betas: Adam beta parameters
        eps: Adam epsilon for numerical stability

    Returns:
        Configured AdamW optimizer
    """
    # Separate parameters into decay and no-decay groups
    decay_params = []
    no_decay_params = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        # No weight decay for:
        # - 1D parameters (biases, layer norms)
        # - Embedding layers ("embed" also matches `embed_tokens` in SLMModel)
        if param.dim() == 1 or "embed" in name.lower():
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    param_groups = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]

    optimizer = AdamW(
        param_groups,
        lr=learning_rate,
        betas=betas,
        eps=eps,
    )

    return optimizer


def create_scheduler(
    optimizer: torch.optim.Optimizer,
    num_training_steps: int,
    warmup_ratio: float = 0.1,
    min_lr_ratio: float = 0.1,
    scheduler_type: str = "cosine",
) -> LambdaLR:
    """Create learning rate scheduler.

    Args:
        optimizer: The optimizer to schedule
        num_training_steps: Total number of training steps
        warmup_ratio: Ratio of warmup steps (e.g., 0.1 = 10%)
        min_lr_ratio: Minimum LR as ratio of max (e.g., 0.1 = 10% of peak LR)
        scheduler_type: Type of scheduler ("cosine", "linear", "constant")

    Returns:
        LambdaLR scheduler
    """
    num_warmup_steps = int(num_training_steps * warmup_ratio)

    if scheduler_type == "cosine":
        def lr_lambda(current_step: int) -> float:
            # Warmup phase
            if current_step < num_warmup_steps:
                return float(current_step) / float(max(1, num_warmup_steps))

            # Cosine annealing phase
            progress = float(current_step - num_warmup_steps) / float(
                max(1, num_training_steps - num_warmup_steps)
            )
            cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))

            # Scale between min_lr_ratio and 1.0
            return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay

    elif scheduler_type == "linear":
        def lr_lambda(current_step: int) -> float:
            if current_step < num_warmup_steps:
                return float(current_step) / float(max(1, num_warmup_steps))

            progress = float(current_step - num_warmup_steps) / float(
                max(1, num_training_steps - num_warmup_steps)
            )
            return max(min_lr_ratio, 1.0 - progress * (1.0 - min_lr_ratio))

    elif scheduler_type == "constant":
        def lr_lambda(current_step: int) -> float:
            if current_step < num_warmup_steps:
                return float(current_step) / float(max(1, num_warmup_steps))
            return 1.0

    else:
        raise ValueError(f"Unknown scheduler type: {scheduler_type}")

    return LambdaLR(optimizer, lr_lambda)


def get_parameter_count(model: torch.nn.Module) -> dict:
    """Get a detailed parameter count for a model.

    Args:
        model: The model to analyze

    Returns:
        Dictionary with parameter counts
    """
    total_params = 0
    trainable_params = 0
    embedding_params = 0

    for name, param in model.named_parameters():
        num_params = param.numel()
        total_params += num_params

        if param.requires_grad:
            trainable_params += num_params

        if "embed" in name.lower():
            embedding_params += num_params

    return {
        "total": total_params,
        "trainable": trainable_params,
        "embedding": embedding_params,
        "non_embedding": total_params - embedding_params,
    }


def get_optimizer_state(optimizer: torch.optim.Optimizer) -> dict:
    """Get optimizer state statistics.

    Args:
        optimizer: The optimizer to analyze

    Returns:
        Dictionary with optimizer state info
    """
    num_params = sum(
        sum(p.numel() for p in group["params"])
        for group in optimizer.param_groups
    )

    current_lrs = [group["lr"] for group in optimizer.param_groups]

    return {
        "num_param_groups": len(optimizer.param_groups),
        "total_params": num_params,
        "learning_rates": current_lrs,
    }


def clip_grad_norm(
    model: torch.nn.Module,
    max_norm: float = 1.0,
) -> float:
    """Clip gradient norm and return the norm value.

    Args:
        model: The model with gradients
        max_norm: Maximum gradient norm

    Returns:
        The gradient norm before clipping
    """
    parameters = [p for p in model.parameters() if p.grad is not None]
    if len(parameters) == 0:
        return 0.0

    total_norm = torch.nn.utils.clip_grad_norm_(parameters, max_norm)
    return total_norm.item()
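The cosine schedule ramps linearly over the warmup steps, then decays to min_lr_ratio of the peak LR; a quick check of the multiplier (sketch; lr_lambdas is the list LambdaLR stores internally):

```python
import torch
from src.training.optimizer import create_optimizer, create_scheduler

model = torch.nn.Linear(8, 8)
opt = create_optimizer(model, learning_rate=3e-4)
sched = create_scheduler(opt, num_training_steps=1000, warmup_ratio=0.1, min_lr_ratio=0.1)

for step in [0, 50, 100, 550, 1000]:
    print(step, round(sched.lr_lambdas[0](step), 3))
# 0 0.0 | 50 0.5 | 100 1.0 | 550 0.55 | 1000 0.1
```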
src/training/trainer.py
ADDED
@@ -0,0 +1,511 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Training loop for SLM.
|
| 3 |
+
|
| 4 |
+
Handles the complete training process including:
|
| 5 |
+
- Mixed precision training
|
| 6 |
+
- Gradient accumulation
|
| 7 |
+
- Checkpointing
|
| 8 |
+
- Logging
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
import time
|
| 13 |
+
import json
|
| 14 |
+
from dataclasses import dataclass, asdict
|
| 15 |
+
from typing import Optional, Dict, Any, Callable
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
from torch.utils.data import DataLoader
|
| 21 |
+
from torch.cuda.amp import autocast, GradScaler
|
| 22 |
+
from tqdm import tqdm
|
| 23 |
+
|
| 24 |
+
from .loss import LanguageModelingLoss, compute_perplexity, compute_accuracy
|
| 25 |
+
from .optimizer import create_optimizer, create_scheduler, clip_grad_norm
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
|
| 29 |
+
class TrainingConfig:
|
| 30 |
+
"""Configuration for training."""
|
| 31 |
+
|
| 32 |
+
# Optimization
|
| 33 |
+
learning_rate: float = 3e-4
|
| 34 |
+
weight_decay: float = 0.1
|
| 35 |
+
warmup_ratio: float = 0.1
|
| 36 |
+
min_lr_ratio: float = 0.1
|
| 37 |
+
max_grad_norm: float = 1.0
|
| 38 |
+
label_smoothing: float = 0.0
|
| 39 |
+
|
| 40 |
+
# Training
|
| 41 |
+
num_epochs: int = 5
|
| 42 |
+
gradient_accumulation_steps: int = 4
|
| 43 |
+
fp16: bool = True
|
| 44 |
+
|
| 45 |
+
# Checkpointing
|
| 46 |
+
checkpoint_dir: str = "checkpoints"
|
| 47 |
+
save_steps: int = 1000
|
| 48 |
+
save_total_limit: int = 3
|
| 49 |
+
|
| 50 |
+
# Evaluation
|
| 51 |
+
eval_steps: int = 500
|
| 52 |
+
logging_steps: int = 10
|
| 53 |
+
|
| 54 |
+
# Early stopping
|
| 55 |
+
early_stopping_patience: int = 5 # Stop after N evals without improvement
|
| 56 |
+
early_stopping_threshold: float = 0.01 # Minimum improvement to reset patience
|
| 57 |
+
|
| 58 |
+
# Device
|
| 59 |
+
device: str = "auto"
|
| 60 |
+
|
| 61 |
+
# Compile model (torch.compile)
|
| 62 |
+
compile_model: bool = False
|
| 63 |
+
|
| 64 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 65 |
+
return asdict(self)
|
| 66 |
+
|

class Trainer:
    """Training loop for SLM model."""

    def __init__(
        self,
        model: nn.Module,
        config: TrainingConfig,
        train_dataloader: DataLoader,
        val_dataloader: Optional[DataLoader] = None,
        wandb_project: Optional[str] = None,
    ):
        """Initialize trainer.

        Args:
            model: The model to train
            config: Training configuration
            train_dataloader: Training data loader
            val_dataloader: Optional validation data loader
            wandb_project: Optional W&B project name for logging
        """
        self.config = config
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader

        # Setup device
        if config.device == "auto":
            if torch.cuda.is_available():
                self.device = torch.device("cuda")
            elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                self.device = torch.device("mps")
            else:
                self.device = torch.device("cpu")
        else:
            self.device = torch.device(config.device)

        print(f"Training on device: {self.device}")

        # Move model to device
        self.model = model.to(self.device)

        # Get vocab size from model
        if hasattr(model, "config"):
            self.vocab_size = model.config.vocab_size
        else:
            self.vocab_size = model.embed_tokens.num_embeddings

        # Setup loss function
        self.loss_fn = LanguageModelingLoss(
            vocab_size=self.vocab_size,
            label_smoothing=config.label_smoothing,
        )

        # Calculate total steps
        self.steps_per_epoch = len(train_dataloader)
        self.total_steps = self.steps_per_epoch * config.num_epochs
        self.total_steps = self.total_steps // config.gradient_accumulation_steps
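        # Example with hypothetical numbers: 10,000 batches/epoch * 5 epochs
        # // 4 accumulation steps = 12,500 optimizer steps in total.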

        # Setup optimizer and scheduler
        self.optimizer = create_optimizer(
            model,
            learning_rate=config.learning_rate,
            weight_decay=config.weight_decay,
        )

        self.scheduler = create_scheduler(
            self.optimizer,
            num_training_steps=self.total_steps,
            warmup_ratio=config.warmup_ratio,
            min_lr_ratio=config.min_lr_ratio,
        )

        # Setup mixed precision
        self.use_amp = config.fp16 and self.device.type == "cuda"
        self.scaler = GradScaler() if self.use_amp else None
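        # GradScaler is only used on CUDA: it rescales the fp16 loss so that
        # small gradients do not underflow to zero during backward.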

        # Tracking
        self.global_step = 0
        self.epoch = 0
        self.best_val_loss = float("inf")

        # Early stopping tracking
        self.early_stopping_counter = 0
        self.should_stop = False

        # Checkpoint directory
        os.makedirs(config.checkpoint_dir, exist_ok=True)

        # W&B logging
        self.wandb = None
        if wandb_project:
            try:
                import wandb
                wandb.init(project=wandb_project, config=config.to_dict())
                self.wandb = wandb
            except ImportError:
                print("wandb not installed, skipping logging")

    def train(self) -> Dict[str, Any]:
        """Run the full training loop.

        Returns:
            Dictionary with training results
        """
        print(f"\n{'='*60}")
        print("STARTING TRAINING")
        print(f"{'='*60}")
        print(f"Total epochs: {self.config.num_epochs}")
        print(f"Steps per epoch: {self.steps_per_epoch}")
        print(f"Total optimization steps: {self.total_steps}")
        print(f"Gradient accumulation: {self.config.gradient_accumulation_steps}")
        print(f"Mixed precision: {self.use_amp}")
        if self.config.early_stopping_patience > 0:
            print(f"Early stopping: patience={self.config.early_stopping_patience}")
        print(f"{'='*60}\n")

        training_start = time.time()

        # FIX: Start from loaded epoch (for resume), not always from 0
        start_epoch = self.epoch
        if start_epoch > 0:
            print(f"Resuming from epoch {start_epoch + 1}")

        for epoch in range(start_epoch, self.config.num_epochs):
            self.epoch = epoch
            epoch_loss = self._train_epoch()

            print(f"\nEpoch {epoch + 1}/{self.config.num_epochs} - Loss: {epoch_loss:.4f}")

            # Validation
            if self.val_dataloader is not None:
                val_metrics = self.evaluate()
                print(f"Validation - Loss: {val_metrics['loss']:.4f}, PPL: {val_metrics['perplexity']:.2f}")

                # Early stopping check
                if val_metrics["loss"] < self.best_val_loss - self.config.early_stopping_threshold:
                    self.best_val_loss = val_metrics["loss"]
                    self.early_stopping_counter = 0
                    self.save_checkpoint("best")
                    print(" New best model saved!")
                else:
                    self.early_stopping_counter += 1
                    print(f" No improvement. Early stopping: {self.early_stopping_counter}/{self.config.early_stopping_patience}")

                    if self.config.early_stopping_patience > 0 and self.early_stopping_counter >= self.config.early_stopping_patience:
                        print(f"\nEarly stopping triggered after {self.early_stopping_counter} evaluations without improvement.")
                        self.should_stop = True

            # Save epoch checkpoint
            self.save_checkpoint(f"epoch_{epoch + 1}")

            # Check early stopping
            if self.should_stop:
                print("Stopping training early.")
                break

        training_time = time.time() - training_start
        print(f"\n{'='*60}")
        print("TRAINING COMPLETE")
        print(f"Total time: {training_time / 3600:.2f} hours")
        print(f"Best validation loss: {self.best_val_loss:.4f}")
        if self.should_stop:
            print(f"Stopped early at epoch {self.epoch + 1}")
        print(f"{'='*60}")

        return {
            "total_steps": self.global_step,
            "training_time": training_time,
            "best_val_loss": self.best_val_loss,
        }

    def _train_epoch(self) -> float:
        """Train for one epoch.

        Returns:
            Average training loss for the epoch
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0
        accumulated_loss = 0.0
        num_accumulated_batches = 0  # FIX: Track actual number of batches for correct averaging

        # Create progress bar
        pbar = tqdm(
            enumerate(self.train_dataloader),
            total=len(self.train_dataloader),
            desc=f"Epoch {self.epoch + 1}",
            ncols=100,
        )

        for step, batch in pbar:
            # Move batch to device
            input_ids = batch["input_ids"].to(self.device)
            labels = batch["labels"].to(self.device)
            # Note: attention_mask from dataloader is padding mask (1/0)
            # The model creates its own causal mask internally
            # We handle padding via -100 labels in the loss function

            # Forward pass with optional mixed precision
            with autocast(enabled=self.use_amp):
                outputs = self.model(input_ids)
                # Handle different output types (tensor, tuple, or dataclass)
                if isinstance(outputs, torch.Tensor):
                    logits = outputs
                elif hasattr(outputs, 'logits'):
                    logits = outputs.logits
                else:
                    logits = outputs[0]
                loss = self.loss_fn(logits, labels)
                loss = loss / self.config.gradient_accumulation_steps

            # Backward pass
            if self.use_amp:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()

            # FIX: Track unscaled loss correctly
            unscaled_loss = loss.item() * self.config.gradient_accumulation_steps
            accumulated_loss += unscaled_loss
            num_accumulated_batches += 1
            total_loss += unscaled_loss
            num_batches += 1

            # Gradient accumulation
            if (step + 1) % self.config.gradient_accumulation_steps == 0:
                # Gradient clipping
                if self.use_amp:
                    self.scaler.unscale_(self.optimizer)

                grad_norm = clip_grad_norm(self.model, self.config.max_grad_norm)

                # Optimizer step
                if self.use_amp:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()

                self.scheduler.step()
                self.optimizer.zero_grad()

                self.global_step += 1

                # Logging
                if self.global_step % self.config.logging_steps == 0:
                    # FIX: Divide by actual number of accumulated batches, not gradient_accumulation_steps
                    avg_loss = accumulated_loss / max(num_accumulated_batches, 1)
                    lr = self.scheduler.get_last_lr()[0]

                    # Update progress bar
                    pbar.set_postfix({
                        'loss': f'{avg_loss:.4f}',
                        'lr': f'{lr:.2e}',
                        'step': f'{self.global_step}/{self.total_steps}'
                    })

                    tqdm.write(
                        f"Step {self.global_step}/{self.total_steps} | "
                        f"Loss: {avg_loss:.4f} | "
                        f"LR: {lr:.2e} | "
                        f"Grad: {grad_norm:.2f}"
                    )

                    if self.wandb:
                        self.wandb.log({
                            "train/loss": avg_loss,
                            "train/learning_rate": lr,
                            "train/grad_norm": grad_norm,
                            "train/epoch": self.epoch,
                        }, step=self.global_step)

                    # Reset accumulators
                    accumulated_loss = 0.0
                    num_accumulated_batches = 0

                # Evaluation
                if self.config.eval_steps > 0 and self.global_step % self.config.eval_steps == 0:
                    if self.val_dataloader is not None:
                        val_metrics = self.evaluate()
                        print(f" Eval - Loss: {val_metrics['loss']:.4f}, PPL: {val_metrics['perplexity']:.2f}")

                        if self.wandb:
                            self.wandb.log({
                                "eval/loss": val_metrics["loss"],
                                "eval/perplexity": val_metrics["perplexity"],
                            }, step=self.global_step)

                        # Early stopping check during training
                        if val_metrics["loss"] < self.best_val_loss - self.config.early_stopping_threshold:
                            self.best_val_loss = val_metrics["loss"]
                            self.early_stopping_counter = 0
                            self.save_checkpoint("best")
                            print(f" New best model! Loss: {self.best_val_loss:.4f}")
                        else:
                            self.early_stopping_counter += 1
                            if self.config.early_stopping_patience > 0:
                                print(f" No improvement ({self.early_stopping_counter}/{self.config.early_stopping_patience})")
                                if self.early_stopping_counter >= self.config.early_stopping_patience:
                                    print("\n Early stopping triggered!")
                                    self.should_stop = True
                                    break  # Exit the training loop

                # Checkpointing
                if self.config.save_steps > 0 and self.global_step % self.config.save_steps == 0:
                    self.save_checkpoint(f"step_{self.global_step}")

            # Check if early stopping was triggered
            if self.should_stop:
                break

        return total_loss / max(num_batches, 1)

    @torch.no_grad()
    def evaluate(self) -> Dict[str, float]:
        """Evaluate the model on validation data.

        Returns:
            Dictionary with evaluation metrics
        """
        self.model.eval()
        total_loss = 0.0
        total_accuracy = 0.0
        num_batches = 0

        for batch in self.val_dataloader:
            input_ids = batch["input_ids"].to(self.device)
            labels = batch["labels"].to(self.device)

            with autocast(enabled=self.use_amp):
                outputs = self.model(input_ids)
                # Handle different output types (tensor, tuple, or dataclass)
                if isinstance(outputs, torch.Tensor):
                    logits = outputs
                elif hasattr(outputs, 'logits'):
                    logits = outputs.logits
                else:
                    logits = outputs[0]
                loss = self.loss_fn(logits, labels)

            total_loss += loss.item()
            total_accuracy += compute_accuracy(logits, labels).item()
            num_batches += 1

        self.model.train()

        avg_loss = total_loss / max(num_batches, 1)
        avg_accuracy = total_accuracy / max(num_batches, 1)

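        # compute_perplexity (see .loss) maps the average cross-entropy loss
        # to perplexity, typically exp(loss): a loss of 2.0 gives PPL ~= 7.4.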
        return {
            "loss": avg_loss,
            "perplexity": compute_perplexity(torch.tensor(avg_loss)).item(),
            "accuracy": avg_accuracy,
        }

    def save_checkpoint(self, name: str):
        """Save a checkpoint.

        Args:
            name: Checkpoint name (e.g., "best", "epoch_1", "step_1000")
        """
        checkpoint_path = os.path.join(self.config.checkpoint_dir, name)
        os.makedirs(checkpoint_path, exist_ok=True)

        # Save model
        model_path = os.path.join(checkpoint_path, "model.pt")
        torch.save(self.model.state_dict(), model_path)

        # Save optimizer and scheduler
        optimizer_path = os.path.join(checkpoint_path, "optimizer.pt")
        torch.save({
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.scheduler.state_dict(),
            "global_step": self.global_step,
            "epoch": self.epoch,
            "best_val_loss": self.best_val_loss,
            "early_stopping_counter": self.early_stopping_counter,
        }, optimizer_path)

        # Save config
        config_path = os.path.join(checkpoint_path, "config.json")
        with open(config_path, "w") as f:
            json.dump(self.config.to_dict(), f, indent=2)

        print(f"Saved checkpoint: {checkpoint_path}")

        # Cleanup old checkpoints
        self._cleanup_checkpoints()

    def load_checkpoint(self, checkpoint_path: str):
        """Load a checkpoint.

        Args:
            checkpoint_path: Path to checkpoint directory
        """
        # Load model
        model_path = os.path.join(checkpoint_path, "model.pt")
        state_dict = torch.load(model_path, map_location=self.device)

        # FIX: Handle torch.compile prefix (_orig_mod.) if present
        if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
            print(" Detected compiled model checkpoint, removing _orig_mod. prefix...")
            state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

        self.model.load_state_dict(state_dict)

        # Load optimizer and scheduler
        optimizer_path = os.path.join(checkpoint_path, "optimizer.pt")
        if os.path.exists(optimizer_path):
            state = torch.load(optimizer_path, map_location=self.device)
            self.optimizer.load_state_dict(state["optimizer"])
            self.scheduler.load_state_dict(state["scheduler"])
            self.global_step = state["global_step"]
            self.epoch = state["epoch"]
            self.best_val_loss = state.get("best_val_loss", float("inf"))
            self.early_stopping_counter = state.get("early_stopping_counter", 0)

            # FIX: Increment epoch to start from next epoch (we saved after completing this epoch)
            # Only if checkpoint was saved at end of epoch (epoch_* checkpoints)
            if "epoch_" in checkpoint_path:
                self.epoch += 1
                print(f" Checkpoint was end-of-epoch, will start from epoch {self.epoch + 1}")

        print(f"Loaded checkpoint: {checkpoint_path}")
        print(f" Resuming from step {self.global_step}, epoch {self.epoch}")
        print(f" Best val loss so far: {self.best_val_loss:.4f}")

    def _cleanup_checkpoints(self):
        """Remove old checkpoints to save disk space."""
        if self.config.save_total_limit <= 0:
            return

        checkpoint_dir = Path(self.config.checkpoint_dir)
        checkpoints = sorted(
            [d for d in checkpoint_dir.iterdir() if d.is_dir() and d.name.startswith("step_")],
            key=lambda x: int(x.name.split("_")[1]),
        )

        # Keep only the most recent checkpoints (plus "best" and "epoch_*")
        while len(checkpoints) > self.config.save_total_limit:
            old_checkpoint = checkpoints.pop(0)
            print(f"Removing old checkpoint: {old_checkpoint}")
            import shutil
            shutil.rmtree(old_checkpoint)
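
A minimal usage sketch of the trainer added above, for orientation only. The model and dataloader objects are placeholders standing in for instances built from the src/model and src/data modules in this commit, and the constructor calls are assumptions rather than part of the committed code.

# Sketch: wiring up the Trainer (placeholder model/dataloaders, assumed constructors)
from src.model.config import SLMConfig
from src.model.transformer import SLMForCausalLM
from src.training.trainer import Trainer, TrainingConfig

model = SLMForCausalLM(SLMConfig())  # assumed constructor; see src/model
train_loader = ...  # DataLoader yielding batches with "input_ids" and "labels"
val_loader = ...    # optional validation DataLoader

config = TrainingConfig(
    num_epochs=5,
    gradient_accumulation_steps=4,
    fp16=True,
    checkpoint_dir="checkpoints",
)

trainer = Trainer(model, config, train_loader, val_dataloader=val_loader)
results = trainer.train()  # returns total_steps, training_time, best_val_loss

# Resuming later from an end-of-epoch checkpoint:
# trainer.load_checkpoint("checkpoints/epoch_3")
# trainer.train()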