karthick committed
Commit: fb67af8
Parent(s): d99ca15

Upload TinyStories 24.5M model - article generation success
Browse files:
- src/__init__.py (+3 -0)
- src/__pycache__/__init__.cpython-313.pyc (binary)
- src/data/__init__.py (+15 -0)
- src/data/__pycache__/__init__.cpython-313.pyc (binary)
- src/data/__pycache__/dataset.cpython-313.pyc (binary)
- src/data/__pycache__/quality_checker.cpython-313.pyc (binary)
- src/data/__pycache__/tokenizer.cpython-313.pyc (binary)
- src/data/dataset.py (+302 -0)
- src/data/quality_checker.py (+343 -0)
- src/data/tokenizer.py (+272 -0)
- src/model/__init__.py (+20 -0)
- src/model/__pycache__/__init__.cpython-313.pyc (binary)
- src/model/__pycache__/attention.cpython-313.pyc (binary)
- src/model/__pycache__/rmsnorm.cpython-313.pyc (binary)
- src/model/__pycache__/rope.cpython-313.pyc (binary)
- src/model/__pycache__/swiglu.cpython-313.pyc (binary)
- src/model/__pycache__/transformer_block.cpython-313.pyc (binary)
- src/model/attention.py (+301 -0)
- src/model/rmsnorm.py (+185 -0)
- src/model/rope.py (+217 -0)
- src/model/swiglu.py (+231 -0)
- src/model/transformer_block.py (+454 -0)
src/__init__.py ADDED
@@ -0,0 +1,3 @@
"""TinyStories Language Model - 24.5M Parameters"""

__version__ = "1.0.0"
src/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (265 Bytes)
src/data/__init__.py ADDED
@@ -0,0 +1,15 @@
"""Data processing modules for TinyStories training."""

from .tokenizer import load_tokenizer, train_tokenizer, test_tokenizer
from .dataset import TinyStoriesDataset, create_dataloaders
from .quality_checker import check_dataset_quality, DataQualityChecker

__all__ = [
    'load_tokenizer',
    'train_tokenizer',
    'test_tokenizer',
    'TinyStoriesDataset',
    'create_dataloaders',
    'check_dataset_quality',
    'DataQualityChecker',
]
src/data/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (584 Bytes)

src/data/__pycache__/dataset.cpython-313.pyc ADDED
Binary file (11.6 kB)

src/data/__pycache__/quality_checker.cpython-313.pyc ADDED
Binary file (17.8 kB)

src/data/__pycache__/tokenizer.cpython-313.pyc ADDED
Binary file (11.1 kB)
src/data/dataset.py ADDED
@@ -0,0 +1,302 @@
"""Dataset and DataLoader utilities for TinyStories training.

This module provides:
1. TinyStoriesDataset class for loading and processing TinyStories
2. create_dataloaders function for creating train/val DataLoaders
3. Sequence packing for efficient training

TinyStories is a synthetic dataset of short stories generated by GPT-3.5/4
using a limited vocabulary suitable for children. Perfect for fast training
and testing language models.
"""

import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from pathlib import Path
import pickle
import logging
from typing import Dict, List, Tuple, Optional
from tqdm import tqdm

logger = logging.getLogger(__name__)


class TinyStoriesDataset(Dataset):
    """TinyStories dataset with sequence packing for efficient training.

    TinyStories is a synthetic dataset of short stories generated by GPT-3.5/4
    using a limited vocabulary suitable for children. The dataset contains
    ~2.1M stories and is excellent for:
    - Fast training (only ~1GB)
    - Clean, well-formed English
    - Testing model architecture
    - Educational purposes

    This dataset:
    1. Loads TinyStories from HuggingFace datasets
    2. Tokenizes the text
    3. Packs sequences to max_seq_len for efficiency
    4. Caches processed data for fast subsequent loading
    """

    def __init__(
        self,
        tokenizer,
        split: str = "train",
        max_seq_len: int = 512,
        cache_dir: Optional[str] = None,
    ):
        """Initialize TinyStories dataset.

        Args:
            tokenizer: Tokenizer instance (must have encode method)
            split: Dataset split ("train" or "validation")
            max_seq_len: Maximum sequence length (default: 512, matches official paper)
            cache_dir: Directory for caching processed data
        """
        self.tokenizer = tokenizer
        self.split = split
        self.max_seq_len = max_seq_len
        self.cache_dir = Path(cache_dir) if cache_dir else Path("./data/cache")
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Cache file path
        cache_file = self.cache_dir / f"tinystories_{split}_{max_seq_len}.pkl"

        # Try to load from cache
        if cache_file.exists():
            logger.info(f"Loading cached dataset from {cache_file}")
            with open(cache_file, "rb") as f:
                cache_data = pickle.load(f)
            self.input_ids = cache_data["input_ids"]
            self.labels = cache_data["labels"]
            logger.info(f"Loaded {len(self.input_ids)} sequences from cache")
        else:
            # Process dataset
            logger.info(f"Processing TinyStories {split} split...")
            self.input_ids, self.labels = self._process_dataset()

            # Save to cache
            logger.info(f"Saving processed dataset to {cache_file}")
            cache_data = {
                "input_ids": self.input_ids,
                "labels": self.labels,
            }
            with open(cache_file, "wb") as f:
                pickle.dump(cache_data, f)

        logger.info(f"Dataset ready: {len(self.input_ids)} sequences")

    def _process_dataset(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Process TinyStories dataset into packed sequences.

        Returns:
            Tuple of (input_ids, labels) lists
        """
        # Load dataset
        dataset = load_dataset(
            "roneneldan/TinyStories",
            split=self.split,
        )

        # Tokenize all text
        logger.info("Tokenizing dataset...")
        all_token_ids = []

        for example in tqdm(dataset, desc="Tokenizing"):
            text = example["text"].strip()
            if len(text) > 0:  # Skip empty stories
                # Encode text
                if hasattr(self.tokenizer, 'encode'):
                    token_ids = self.tokenizer.encode(text, add_special_tokens=False)
                else:
                    # Fallback for tokenizers.Tokenizer
                    token_ids = self.tokenizer.tokenizer.encode(text).ids

                all_token_ids.extend(token_ids)

        logger.info(f"Total tokens: {len(all_token_ids):,}")

        # Pack into sequences
        logger.info("Packing sequences...")
        input_ids_list = []
        labels_list = []

        # Pack the token stream into consecutive, non-overlapping chunks of
        # max_seq_len tokens so no padding is wasted on story boundaries
        for i in range(0, len(all_token_ids) - 1, self.max_seq_len):
            # Get sequence
            seq = all_token_ids[i : i + self.max_seq_len]

            # Skip if too short
            if len(seq) < 2:
                continue

            # Create input_ids and labels (shifted by one for next-token prediction)
            # input_ids: [0, 1, 2, ..., n-1]
            # labels:    [1, 2, 3, ..., n]
            input_ids = torch.tensor(seq[:-1], dtype=torch.long)
            labels = torch.tensor(seq[1:], dtype=torch.long)

            # Pad if necessary
            if len(input_ids) < self.max_seq_len:
                pad_len = self.max_seq_len - len(input_ids)
                input_ids = torch.cat([
                    input_ids,
                    torch.full((pad_len,), self.tokenizer.pad_token_id, dtype=torch.long)
                ])
                labels = torch.cat([
                    labels,
                    torch.full((pad_len,), -100, dtype=torch.long)  # -100 is ignored in loss
                ])

            input_ids_list.append(input_ids)
            labels_list.append(labels)

        logger.info(f"Created {len(input_ids_list)} packed sequences")

        return input_ids_list, labels_list

    def __len__(self) -> int:
        """Return number of sequences."""
        return len(self.input_ids)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get a single sequence.

        Args:
            idx: Sequence index

        Returns:
            Dictionary with 'input_ids' and 'labels'
        """
        return {
            "input_ids": self.input_ids[idx],
            "labels": self.labels[idx],
        }


def collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Collate function for DataLoader.

    Args:
        batch: List of dictionaries with 'input_ids' and 'labels'

    Returns:
        Batched dictionary
    """
    input_ids = torch.stack([item["input_ids"] for item in batch])
    labels = torch.stack([item["labels"] for item in batch])

    return {
        "input_ids": input_ids,
        "labels": labels,
    }


def create_dataloaders(
    tokenizer,
    batch_size: int,
    max_seq_len: int,
    cache_dir: str,
    dataset_name: str = "tinystories",
    num_workers: int = 0,
    pin_memory: bool = True,
    drop_last: bool = True,
) -> Tuple[DataLoader, DataLoader]:
    """Create train and validation DataLoaders for TinyStories.

    Args:
        tokenizer: Tokenizer instance
        batch_size: Batch size per device
        max_seq_len: Maximum sequence length (512 recommended for TinyStories)
        cache_dir: Directory for caching processed data
        dataset_name: Dataset to use (default: "tinystories")
        num_workers: Number of data loading workers (use 0 for Windows)
        pin_memory: Whether to pin memory for faster GPU transfer
        drop_last: Whether to drop last incomplete batch

    Returns:
        Tuple of (train_loader, val_loader)
    """
    logger.info("Using TinyStories dataset")

    logger.info("Creating train dataset...")
    train_dataset = TinyStoriesDataset(
        tokenizer=tokenizer,
        split="train",
        max_seq_len=max_seq_len,
        cache_dir=cache_dir,
    )

    logger.info("Creating validation dataset...")
    val_dataset = TinyStoriesDataset(
        tokenizer=tokenizer,
        split="validation",
        max_seq_len=max_seq_len,
        cache_dir=cache_dir,
    )

    # Create DataLoaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
        collate_fn=collate_fn,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=False,
        collate_fn=collate_fn,
    )

    logger.info(f"Train batches: {len(train_loader)}")
    logger.info(f"Validation batches: {len(val_loader)}")

    return train_loader, val_loader


# Test the dataset
if __name__ == "__main__":
    from .tokenizer import load_tokenizer

    print("Testing TinyStoriesDataset...")

    # Load tokenizer (assumes it exists)
    tokenizer_path = "./tokenizer/wikimini_32k"
    if Path(tokenizer_path).exists():
        tokenizer = load_tokenizer(tokenizer_path)

        # Create small dataset for testing
        dataset = TinyStoriesDataset(
            tokenizer=tokenizer,
            split="validation",  # Use smaller split for testing
            max_seq_len=128,
            cache_dir="./data/cache_test",
        )

        print(f"\nDataset size: {len(dataset)}")
        print("Sample batch:")
        sample = dataset[0]
        print(f"  Input IDs shape: {sample['input_ids'].shape}")
        print(f"  Labels shape: {sample['labels'].shape}")
        print(f"  First 10 input IDs: {sample['input_ids'][:10]}")
        print(f"  First 10 labels: {sample['labels'][:10]}")

        # Test DataLoader
        loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
        batch = next(iter(loader))
        print("\nDataLoader batch:")
        print(f"  Input IDs shape: {batch['input_ids'].shape}")
        print(f"  Labels shape: {batch['labels'].shape}")
    else:
        print(f"Tokenizer not found at {tokenizer_path}")
        print("Please train tokenizer first: python scripts/train_tokenizer.py")
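The test block above exercises the dataset directly; below is a minimal sketch of the training-side wiring. The batch size is an illustrative assumption (the commit does not include a training script), the tokenizer path is the module's own default, and the first run downloads and tokenizes the full TinyStories split before the pickle cache takes over.

import logging

from src.data.dataset import create_dataloaders
from src.data.tokenizer import load_tokenizer

logging.basicConfig(level=logging.INFO)

# Assumes a tokenizer was already trained (see src/data/tokenizer.py below).
tokenizer = load_tokenizer("./tokenizer/wikimini_32k")

train_loader, val_loader = create_dataloaders(
    tokenizer=tokenizer,
    batch_size=32,           # illustrative
    max_seq_len=512,         # the dataset default
    cache_dir="./data/cache",
    num_workers=0,           # keep 0 on Windows, per the docstring
)

batch = next(iter(train_loader))
print(batch["input_ids"].shape)  # torch.Size([32, 512]) after padding
print(batch["labels"].shape)     # same shape; pad positions are -100, ignored by the loss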
src/data/quality_checker.py ADDED
@@ -0,0 +1,343 @@
"""Data Quality Checker for training datasets.

This module provides tools to validate dataset quality before training:
- Detects artifacts (HTML tags, URLs, special tokens)
- Checks for malformed text
- Validates text statistics
- Reports quality issues

Prevents training on corrupted or low-quality data.
"""

import re
import logging
from typing import Dict, List, Tuple, Optional
from collections import Counter
from datasets import load_dataset
from tqdm import tqdm

logger = logging.getLogger(__name__)


class DataQualityChecker:
    """Check dataset quality before training."""

    def __init__(
        self,
        dataset_name: str,
        split: str = "train",
        sample_size: Optional[int] = 10000,
        strict: bool = False,
    ):
        """Initialize quality checker.

        Args:
            dataset_name: Name of dataset (e.g., "roneneldan/TinyStories")
            split: Dataset split to check ("train" or "validation")
            sample_size: Number of samples to check (None for all)
            strict: If True, raise errors on issues; if False, only warn
        """
        self.dataset_name = dataset_name
        self.split = split
        self.sample_size = sample_size
        self.strict = strict

        # Quality metrics
        self.issues: Dict[str, List[Tuple[int, str]]] = {
            "html_tags": [],
            "urls": [],
            "emails": [],
            "excessive_punctuation": [],
            "malformed_unicode": [],
            "empty_text": [],
            "extremely_short": [],
            "extremely_long": [],
            "suspicious_patterns": [],
            "special_tokens": [],
        }

        self.stats = {
            "total_samples": 0,
            "total_chars": 0,
            "total_words": 0,
            "avg_length": 0,
            "vocabulary_size": 0,
        }

    def check_quality(self) -> Dict:
        """Run all quality checks and return results.

        Returns:
            Dictionary with quality report and pass/fail status
        """
        logger.info(f"Loading dataset {self.dataset_name} ({self.split} split)...")

        # Load dataset
        if "tinystories" in self.dataset_name.lower():
            dataset = load_dataset("roneneldan/TinyStories", split=self.split)
        elif "wikitext" in self.dataset_name.lower():
            dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split=self.split, trust_remote_code=True)
        else:
            dataset = load_dataset(self.dataset_name, split=self.split)

        # Limit sample size if requested
        if self.sample_size and len(dataset) > self.sample_size:
            logger.info(f"Sampling {self.sample_size} examples from {len(dataset)} total")
            indices = range(0, len(dataset), len(dataset) // self.sample_size)
            dataset = dataset.select(list(indices)[:self.sample_size])

        logger.info(f"Checking quality of {len(dataset)} examples...")

        # Run checks
        vocabulary = set()

        for idx, example in enumerate(tqdm(dataset, desc="Quality Check")):
            text = example.get("text", "")

            # Update stats
            self.stats["total_samples"] += 1
            self.stats["total_chars"] += len(text)
            words = text.split()
            self.stats["total_words"] += len(words)
            vocabulary.update(words)

            # Run individual checks
            self._check_html_tags(idx, text)
            self._check_urls(idx, text)
            self._check_emails(idx, text)
            self._check_excessive_punctuation(idx, text)
            self._check_malformed_unicode(idx, text)
            self._check_empty_text(idx, text)
            self._check_length_extremes(idx, text)
            self._check_suspicious_patterns(idx, text)
            self._check_special_tokens(idx, text)

        # Calculate final stats
        if self.stats["total_samples"] > 0:
            self.stats["avg_length"] = self.stats["total_chars"] / self.stats["total_samples"]
            self.stats["avg_words"] = self.stats["total_words"] / self.stats["total_samples"]
            self.stats["vocabulary_size"] = len(vocabulary)

        # Generate report
        report = self._generate_report()

        return report

    def _check_html_tags(self, idx: int, text: str):
        """Check for HTML tags."""
        html_pattern = r'<[^>]+>'
        if re.search(html_pattern, text):
            self.issues["html_tags"].append((idx, text[:100]))

    def _check_urls(self, idx: int, text: str):
        """Check for URLs."""
        url_pattern = r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
        if re.search(url_pattern, text):
            self.issues["urls"].append((idx, text[:100]))

    def _check_emails(self, idx: int, text: str):
        """Check for email addresses."""
        email_pattern = r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
        if re.search(email_pattern, text):
            self.issues["emails"].append((idx, text[:100]))

    def _check_excessive_punctuation(self, idx: int, text: str):
        """Check for excessive punctuation (possible artifacts)."""
        # More than 5 consecutive punctuation marks
        if re.search(r'[!?.,;:]{5,}', text):
            self.issues["excessive_punctuation"].append((idx, text[:100]))

        # More than 20% punctuation
        if len(text) > 0:
            punct_count = sum(1 for c in text if c in '!?.,;:')
            if punct_count / len(text) > 0.2:
                self.issues["excessive_punctuation"].append((idx, text[:100]))

    def _check_malformed_unicode(self, idx: int, text: str):
        """Check for malformed Unicode characters."""
        # Look for replacement characters or control characters
        if '�' in text or '\ufffd' in text:
            self.issues["malformed_unicode"].append((idx, text[:100]))

        # Control characters (excluding whitespace)
        if re.search(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]', text):
            self.issues["malformed_unicode"].append((idx, text[:100]))

    def _check_empty_text(self, idx: int, text: str):
        """Check for empty or whitespace-only text."""
        if not text or not text.strip():
            self.issues["empty_text"].append((idx, text))

    def _check_length_extremes(self, idx: int, text: str):
        """Check for extremely short or long text."""
        if len(text.strip()) < 10:
            self.issues["extremely_short"].append((idx, text))
        elif len(text) > 50000:  # Suspiciously long
            self.issues["extremely_long"].append((idx, text[:100]))

    def _check_suspicious_patterns(self, idx: int, text: str):
        """Check for suspicious patterns."""
        # Repeated characters (e.g., "aaaaaa" more than 10 times)
        if re.search(r'(.)\1{10,}', text):
            self.issues["suspicious_patterns"].append((idx, text[:100]))

        # Excessive whitespace
        if re.search(r'\s{10,}', text):
            self.issues["suspicious_patterns"].append((idx, text[:100]))

    def _check_special_tokens(self, idx: int, text: str):
        """Check for special tokens that shouldn't be in raw text."""
        # Common tokenizer special tokens
        special_tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]', '<|endoftext|>', '<pad>', '<unk>']
        for token in special_tokens:
            if token in text:
                self.issues["special_tokens"].append((idx, text[:100]))
                break

    def _generate_report(self) -> Dict:
        """Generate quality report.

        Returns:
            Dictionary with quality metrics and pass/fail status
        """
        total_issues = sum(len(issues) for issues in self.issues.values())
        issue_percentage = (total_issues / self.stats["total_samples"] * 100) if self.stats["total_samples"] > 0 else 0

        # Determine quality level
        if issue_percentage == 0:
            quality_level = "EXCELLENT"
            passed = True
        elif issue_percentage < 1:
            quality_level = "GOOD"
            passed = True
        elif issue_percentage < 5:
            quality_level = "ACCEPTABLE"
            passed = not self.strict
        elif issue_percentage < 10:
            quality_level = "POOR"
            passed = False
        else:
            quality_level = "CRITICAL"
            passed = False

        report = {
            "dataset": self.dataset_name,
            "split": self.split,
            "quality_level": quality_level,
            "passed": passed,
            "stats": self.stats,
            "issues": {
                key: {
                    "count": len(value),
                    "percentage": (len(value) / self.stats["total_samples"] * 100) if self.stats["total_samples"] > 0 else 0,
                    "samples": value[:3]  # First 3 examples
                }
                for key, value in self.issues.items() if len(value) > 0
            },
            "total_issues": total_issues,
            "issue_percentage": issue_percentage,
        }

        return report

    def print_report(self, report: Dict):
        """Print formatted quality report.

        Args:
            report: Report dictionary from check_quality()
        """
        logger.info("\n" + "=" * 70)
        logger.info("DATA QUALITY REPORT")
        logger.info("=" * 70)
        logger.info(f"Dataset: {report['dataset']} ({report['split']} split)")
        logger.info(f"Quality Level: {report['quality_level']}")
        logger.info(f"Status: {'✅ PASSED' if report['passed'] else '❌ FAILED'}")
        logger.info("")

        # Statistics
        logger.info("Statistics:")
        logger.info(f"  Total Samples: {report['stats']['total_samples']:,}")
        logger.info(f"  Avg Length: {report['stats']['avg_length']:.1f} chars")
        logger.info(f"  Avg Words: {report['stats'].get('avg_words', 0):.1f} words")
        logger.info(f"  Vocabulary Size: {report['stats']['vocabulary_size']:,}")
        logger.info("")

        # Issues
        if report['issues']:
            logger.warning(f"Found {report['total_issues']} issues ({report['issue_percentage']:.2f}% of samples)")
            logger.warning("")
            for issue_type, details in report['issues'].items():
                logger.warning(f"  {issue_type.replace('_', ' ').title()}:")
                logger.warning(f"    Count: {details['count']} ({details['percentage']:.2f}%)")
                if details['samples']:
                    logger.warning(f"    Example: {details['samples'][0][1][:80]}...")
                logger.warning("")
        else:
            logger.info("✅ No quality issues found!")

        logger.info("=" * 70)

        # Recommendations
        if not report['passed']:
            logger.error("\n⚠️ DATA HAS QUALITY ISSUES - Training not recommended!")
            logger.error("Recommendations:")
            if report['issues'].get('html_tags'):
                logger.error("  - Remove HTML tags from text")
            if report['issues'].get('urls'):
                logger.error("  - Remove or mask URLs")
            if report['issues'].get('malformed_unicode'):
                logger.error("  - Fix Unicode encoding issues")
            if report['issues'].get('empty_text'):
                logger.error("  - Remove empty samples")
            logger.error("")


def check_dataset_quality(
    dataset_name: str,
    split: str = "train",
    sample_size: Optional[int] = 10000,
    strict: bool = False,
) -> bool:
    """Quick function to check dataset quality.

    Args:
        dataset_name: Dataset name or HuggingFace ID
        split: Split to check
        sample_size: Number of samples to check (None for all)
        strict: If True, fail on any issues

    Returns:
        True if quality is acceptable, False otherwise
    """
    checker = DataQualityChecker(
        dataset_name=dataset_name,
        split=split,
        sample_size=sample_size,
        strict=strict,
    )

    report = checker.check_quality()
    checker.print_report(report)

    return report["passed"]


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Check dataset quality")
    parser.add_argument("--dataset", type=str, required=True, help="Dataset name")
    parser.add_argument("--split", type=str, default="train", help="Dataset split")
    parser.add_argument("--sample-size", type=int, default=10000, help="Number of samples to check")
    parser.add_argument("--strict", action="store_true", help="Fail on any issues")

    args = parser.parse_args()

    passed = check_dataset_quality(
        dataset_name=args.dataset,
        split=args.split,
        sample_size=args.sample_size,
        strict=args.strict,
    )

    exit(0 if passed else 1)
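The thresholds in _generate_report grade a run by the share of flagged samples: for example, 80 flagged samples out of 10,000 checked is 0.8%, which lands in GOOD and passes even with strict=True, while 3% is ACCEPTABLE and passes only in non-strict mode. A short sketch of gating a run on the checker follows; the dataset name and sample size are illustrative.

from src.data.quality_checker import check_dataset_quality

ok = check_dataset_quality(
    dataset_name="roneneldan/TinyStories",
    split="validation",
    sample_size=1000,   # spot-check; use None to scan the full split
    strict=False,
)
if not ok:
    raise SystemExit("Dataset failed quality checks - see report above")

The module's __main__ block offers the same gate from the shell, e.g. python -m src.data.quality_checker --dataset roneneldan/TinyStories --sample-size 5000, exiting non-zero on failure.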
src/data/tokenizer.py ADDED
@@ -0,0 +1,272 @@
"""Tokenizer training and loading utilities for WikiMini model.

This module provides functions to:
1. Train a BPE tokenizer on WikiText-103
2. Load a trained tokenizer from disk
3. Test tokenizer functionality
"""

import os
from pathlib import Path
from typing import Optional, List
from tokenizers import Tokenizer, models, trainers, pre_tokenizers, decoders, processors
from datasets import load_dataset
import logging

logger = logging.getLogger(__name__)


def train_tokenizer(
    vocab_size: int = 32000,
    min_frequency: int = 2,
    output_dir: str = "./tokenizer/wikimini_32k",
    show_progress: bool = True,
) -> Tokenizer:
    """Train a BPE tokenizer on WikiText-103 dataset.

    Args:
        vocab_size: Size of the vocabulary
        min_frequency: Minimum frequency for tokens
        output_dir: Directory to save the trained tokenizer
        show_progress: Whether to show progress during training

    Returns:
        Trained tokenizer
    """
    logger.info(f"Training BPE tokenizer with vocab_size={vocab_size}")

    # Initialize BPE tokenizer
    tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))

    # Pre-tokenization (byte-level, GPT-2 style)
    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)

    # Decoder
    tokenizer.decoder = decoders.ByteLevel()

    # Configure trainer
    special_tokens = [
        "<unk>",   # Unknown token
        "<s>",     # Beginning of sequence
        "</s>",    # End of sequence
        "<pad>",   # Padding token
    ]

    trainer = trainers.BpeTrainer(
        vocab_size=vocab_size,
        min_frequency=min_frequency,
        special_tokens=special_tokens,
        show_progress=show_progress,
    )

    # Load WikiText-103 dataset
    logger.info("Loading WikiText-103 dataset...")
    dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")

    # Create iterator for training
    def batch_iterator(batch_size: int = 1000):
        """Yield batches of text for training."""
        for i in range(0, len(dataset), batch_size):
            batch = dataset[i : i + batch_size]
            yield batch["text"]

    # Train tokenizer
    logger.info("Training tokenizer...")
    tokenizer.train_from_iterator(batch_iterator(), trainer=trainer)

    # Add post-processor for special tokens
    tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)

    # Enable padding
    tokenizer.enable_padding(
        pad_id=tokenizer.token_to_id("<pad>"),
        pad_token="<pad>",
    )

    # Enable truncation
    tokenizer.enable_truncation(max_length=2048)

    # Save tokenizer
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    tokenizer_file = output_path / "tokenizer.json"
    tokenizer.save(str(tokenizer_file))
    logger.info(f"Tokenizer saved to {tokenizer_file}")

    # Save config
    config = {
        "vocab_size": vocab_size,
        "model_type": "BPE",
        "unk_token": "<unk>",
        "bos_token": "<s>",
        "eos_token": "</s>",
        "pad_token": "<pad>",
    }

    import json
    config_file = output_path / "config.json"
    with open(config_file, 'w') as f:
        json.dump(config, f, indent=2)
    logger.info(f"Config saved to {config_file}")

    return tokenizer


def load_tokenizer(tokenizer_path: str, return_wrapper: bool = True):
    """Load a trained tokenizer from disk.

    Args:
        tokenizer_path: Path to the tokenizer directory or file
        return_wrapper: If True, returns TokenizerWrapper (default), else raw Tokenizer

    Returns:
        Loaded tokenizer (wrapped by default for compatibility)
    """
    tokenizer_path = Path(tokenizer_path)

    # Handle both directory and file paths
    if tokenizer_path.is_dir():
        tokenizer_file = tokenizer_path / "tokenizer.json"
    else:
        tokenizer_file = tokenizer_path

    if not tokenizer_file.exists():
        raise FileNotFoundError(f"Tokenizer file not found: {tokenizer_file}")

    logger.info(f"Loading tokenizer from {tokenizer_file}")
    tokenizer = Tokenizer.from_file(str(tokenizer_file))

    # Return wrapped version for easier use (supports len(), etc.)
    if return_wrapper:
        return TokenizerWrapper(tokenizer)

    return tokenizer


def test_tokenizer(tokenizer: Tokenizer) -> None:
    """Test tokenizer with sample text.

    Args:
        tokenizer: Tokenizer to test
    """
    print("\n" + "=" * 70)
    print(" " * 25 + "Tokenizer Test")
    print("=" * 70)

    # Get vocab info
    vocab_size = tokenizer.get_vocab_size()
    print(f"\nVocabulary size: {vocab_size:,}")

    # Test special tokens
    print("\nSpecial tokens:")
    special_tokens = ["<unk>", "<s>", "</s>", "<pad>"]
    for token in special_tokens:
        token_id = tokenizer.token_to_id(token)
        print(f"  {token:8s} -> ID {token_id}")

    # Test encoding/decoding
    test_texts = [
        "The quick brown fox jumps over the lazy dog.",
        "Machine learning is a subset of artificial intelligence.",
        "WikiText-103 is a large-scale language modeling benchmark.",
    ]

    print("\nEncoding/Decoding tests:")
    print("-" * 70)

    for i, text in enumerate(test_texts, 1):
        # Encode
        encoding = tokenizer.encode(text)
        tokens = encoding.tokens
        ids = encoding.ids

        # Decode
        decoded = tokenizer.decode(ids)

        print(f"\nTest {i}:")
        print(f"  Original: {text}")
        print(f"  Tokens:   {len(tokens)}")
        print(f"  IDs: {ids[:10]}..." if len(ids) > 10 else f"  IDs: {ids}")
        print(f"  Decoded:  {decoded}")

        # Check round-trip
        if decoded.strip() == text.strip():
            print("  ✅ Round-trip successful")
        else:
            print("  ⚠️ Round-trip differs slightly (common with BPE)")

    # Test batch encoding
    print("\n\nBatch encoding test:")
    print("-" * 70)
    encodings = tokenizer.encode_batch(test_texts)
    print(f"  Batch size: {len(encodings)}")
    print(f"  Token counts: {[len(enc.ids) for enc in encodings]}")

    print("\n" + "=" * 70)
    print(" " * 25 + "✅ Test Complete")
    print("=" * 70 + "\n")


# Wrapper class for compatibility with HuggingFace-style interface
class TokenizerWrapper:
    """Wrapper to make tokenizers.Tokenizer compatible with expected interface."""

    def __init__(self, tokenizer: Tokenizer):
        self.tokenizer = tokenizer
        self._vocab_size = tokenizer.get_vocab_size()

        # Get special token IDs - support multiple formats
        # Try standard format first, then TinyStories custom format
        self.pad_token_id = (
            tokenizer.token_to_id("<pad>") or
            tokenizer.token_to_id("<|padding|>") or
            0  # Fallback to 0 if not found
        )
        self.bos_token_id = (
            tokenizer.token_to_id("<s>") or
            tokenizer.token_to_id("<|startoftext|>")
        )
        self.eos_token_id = (
            tokenizer.token_to_id("</s>") or
            tokenizer.token_to_id("<|endoftext|>")
        )
        self.unk_token_id = tokenizer.token_to_id("<unk>")

    def __call__(self, text, **kwargs):
        """Encode text (callable interface)."""
        if isinstance(text, str):
            return self.tokenizer.encode(text).ids
        elif isinstance(text, list):
            return [self.tokenizer.encode(t).ids for t in text]

    def encode(self, text, add_special_tokens=True):
        """Encode text to token IDs."""
        encoding = self.tokenizer.encode(text)
        return encoding.ids

    def decode(self, token_ids, skip_special_tokens=True):
        """Decode token IDs to text."""
        return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)

    def __len__(self):
        """Return vocabulary size."""
        return self._vocab_size

    @property
    def vocab_size(self):
        """Vocabulary size property."""
        return self._vocab_size


def create_tokenizer_wrapper(tokenizer_path: str) -> TokenizerWrapper:
    """Create a wrapped tokenizer for easier use.

    Args:
        tokenizer_path: Path to tokenizer directory or file

    Returns:
        TokenizerWrapper instance
    """
    tokenizer = load_tokenizer(tokenizer_path, return_wrapper=False)
    return TokenizerWrapper(tokenizer)
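A minimal end-to-end sketch of this module: train once, sanity-check, then reload the wrapper that dataset.py expects. The output directory is the module's default; training streams all of WikiText-103 through the BPE trainer, so the first run is slow.

from src.data.tokenizer import train_tokenizer, load_tokenizer, test_tokenizer

# One-time training (downloads and iterates WikiText-103).
raw = train_tokenizer(vocab_size=32000, output_dir="./tokenizer/wikimini_32k")
test_tokenizer(raw)  # operates on the raw tokenizers.Tokenizer interface

# Later runs just reload; the default return is the TokenizerWrapper,
# which exposes pad_token_id and len() as TinyStoriesDataset expects.
tok = load_tokenizer("./tokenizer/wikimini_32k")
ids = tok.encode("Once upon a time")
print(len(tok), ids[:5])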
src/model/__init__.py ADDED
@@ -0,0 +1,20 @@
"""Model components for WikiMini 95M."""

from .rmsnorm import RMSNorm, RMSNormOptimized
from .rope import RotaryPositionEmbeddings, RotaryPositionEmbeddingsComplex
from .swiglu import SwiGLU, SwiGLUParallel, GeGLU
from .attention import MultiHeadAttention
from .transformer_block import TransformerBlock, WikiMiniModel

__all__ = [
    "RMSNorm",
    "RMSNormOptimized",
    "RotaryPositionEmbeddings",
    "RotaryPositionEmbeddingsComplex",
    "SwiGLU",
    "SwiGLUParallel",
    "GeGLU",
    "MultiHeadAttention",
    "TransformerBlock",
    "WikiMiniModel",
]
src/model/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (692 Bytes)

src/model/__pycache__/attention.cpython-313.pyc ADDED
Binary file (12.3 kB)

src/model/__pycache__/rmsnorm.cpython-313.pyc ADDED
Binary file (7.81 kB)

src/model/__pycache__/rope.cpython-313.pyc ADDED
Binary file (9.95 kB)

src/model/__pycache__/swiglu.cpython-313.pyc ADDED
Binary file (8.97 kB)

src/model/__pycache__/transformer_block.cpython-313.pyc ADDED
Binary file (18.7 kB)
src/model/attention.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Multi-Head Attention with RoPE integration and memory optimizations.
|
| 2 |
+
|
| 3 |
+
Critical implementation details:
|
| 4 |
+
1. Apply RoPE only to Q and K, never to V
|
| 5 |
+
2. Use SDPA for Flash Attention 2 support
|
| 6 |
+
3. Pre-normalization architecture
|
| 7 |
+
4. Memory-efficient implementation
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
import math
|
| 14 |
+
from typing import Optional, Tuple
|
| 15 |
+
from .rope import RotaryPositionEmbeddings
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class MultiHeadAttention(nn.Module):
|
| 19 |
+
"""Multi-Head Attention with RoPE and Flash Attention support.
|
| 20 |
+
|
| 21 |
+
This implementation:
|
| 22 |
+
- Uses Rotary Position Embeddings (RoPE) on Q and K only
|
| 23 |
+
- Supports Flash Attention 2 via torch.nn.functional.scaled_dot_product_attention
|
| 24 |
+
- Uses no bias terms (modern approach)
|
| 25 |
+
- Includes proper causal masking
|
| 26 |
+
- Memory-efficient implementation
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
d_model: int = 768,
|
| 32 |
+
n_heads: int = 12,
|
| 33 |
+
dropout: float = 0.1,
|
| 34 |
+
max_seq_len: int = 2048,
|
| 35 |
+
rope_base: int = 10000,
|
| 36 |
+
rope_percentage: float = 0.5,
|
| 37 |
+
use_flash_attention: bool = True,
|
| 38 |
+
):
|
| 39 |
+
super().__init__()
|
| 40 |
+
|
| 41 |
+
assert d_model % n_heads == 0, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
|
| 42 |
+
|
| 43 |
+
self.d_model = d_model
|
| 44 |
+
self.n_heads = n_heads
|
| 45 |
+
self.head_dim = d_model // n_heads
|
| 46 |
+
|
| 47 |
+
# Windows Flash Attention: Test with PyTorch 2.10+ nightly
|
| 48 |
+
# Older versions had freezing issues, but newer versions may work
|
| 49 |
+
import sys
|
| 50 |
+
import logging
|
| 51 |
+
logger = logging.getLogger(__name__)
|
| 52 |
+
|
| 53 |
+
if sys.platform == 'win32' and use_flash_attention:
|
| 54 |
+
# Allow Flash Attention on Windows with PyTorch 2.10+
|
| 55 |
+
# If freezing occurs, set use_flash_attention: false in config
|
| 56 |
+
self.use_flash_attention = use_flash_attention
|
| 57 |
+
logger.info("[Windows] Attempting Flash Attention with PyTorch 2.10+ - if freezing occurs, disable in config")
|
| 58 |
+
elif sys.platform == 'win32':
|
| 59 |
+
self.use_flash_attention = False
|
| 60 |
+
logger.info("[Windows] Flash Attention disabled - using manual attention")
|
| 61 |
+
else:
|
| 62 |
+
self.use_flash_attention = use_flash_attention
|
| 63 |
+
|
| 64 |
+
self.dropout = dropout
|
| 65 |
+
self.scale = 1.0 / math.sqrt(self.head_dim)
|
| 66 |
+
|
| 67 |
+
# Q, K, V projections (no bias)
|
| 68 |
+
self.q_proj = nn.Linear(d_model, d_model, bias=False)
|
| 69 |
+
self.k_proj = nn.Linear(d_model, d_model, bias=False)
|
| 70 |
+
self.v_proj = nn.Linear(d_model, d_model, bias=False)
|
| 71 |
+
self.o_proj = nn.Linear(d_model, d_model, bias=False)
|
| 72 |
+
|
| 73 |
+
# RoPE for positional encoding
|
| 74 |
+
# Apply to only part of head dimensions (typically 50%)
|
| 75 |
+
rope_dim = int(self.head_dim * rope_percentage)
|
| 76 |
+
self.rope_dim = rope_dim
|
| 77 |
+
self.rope = RotaryPositionEmbeddings(
|
| 78 |
+
head_dim=rope_dim,
|
| 79 |
+
max_seq_len=max_seq_len,
|
| 80 |
+
base=rope_base
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
# Dropout
|
| 84 |
+
self.attn_dropout = nn.Dropout(dropout)
|
| 85 |
+
self.resid_dropout = nn.Dropout(dropout)
|
| 86 |
+
|
| 87 |
+
# Pre-allocate causal mask more efficiently
|
| 88 |
+
# We'll create it on-demand based on sequence length
|
| 89 |
+
self.register_buffer('cached_mask', None, persistent=False)
|
| 90 |
+
self.register_buffer('cached_mask_size', torch.tensor(0), persistent=False)
|
| 91 |
+
|
| 92 |
+
def _get_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
|
| 93 |
+
"""Get or create causal mask for the given sequence length.
|
| 94 |
+
|
| 95 |
+
CRITICAL: Always returns mask on the specified device to prevent CPU OOM errors.
|
| 96 |
+
"""
|
| 97 |
+
if self.cached_mask is None or self.cached_mask_size < seq_len:
|
| 98 |
+
# Create a new mask directly on the target device
|
| 99 |
+
mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
|
| 100 |
+
mask = mask.masked_fill(mask == 1, float('-inf'))
|
| 101 |
+
self.cached_mask = mask
|
| 102 |
+
self.cached_mask_size = torch.tensor(seq_len)
|
| 103 |
+
|
| 104 |
+
# CRITICAL: Ensure the returned mask is on the correct device
|
| 105 |
+
# This prevents CPU OOM when broadcasting during attn_scores + causal_mask
|
| 106 |
+
return self.cached_mask[:seq_len, :seq_len].to(device)
|
| 107 |
+
|
| 108 |
+
def _apply_rope(
|
| 109 |
+
self,
|
| 110 |
+
q: torch.Tensor,
|
| 111 |
+
k: torch.Tensor,
|
| 112 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 113 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 114 |
+
"""Apply RoPE to partial dimensions of Q and K.
|
| 115 |
+
|
| 116 |
+
Args:
|
| 117 |
+
q: Query tensor [batch, seq_len, n_heads, head_dim]
|
| 118 |
+
k: Key tensor [batch, seq_len, n_heads, head_dim]
|
| 119 |
+
position_ids: Optional custom position IDs
|
| 120 |
+
|
| 121 |
+
Returns:
|
| 122 |
+
Rotated Q and K tensors
|
| 123 |
+
"""
|
| 124 |
+
# Split into RoPE and pass-through dimensions
|
| 125 |
+
if self.rope_dim > 0:
|
| 126 |
+
q_rope, q_pass = q[..., :self.rope_dim], q[..., self.rope_dim:]
|
| 127 |
+
k_rope, k_pass = k[..., :self.rope_dim], k[..., self.rope_dim:]
|
| 128 |
+
|
| 129 |
+
# Apply RoPE to the first part
|
| 130 |
+
q_rope, k_rope = self.rope(q_rope, k_rope, position_ids)
|
| 131 |
+
|
| 132 |
+
# Concatenate back
|
| 133 |
+
q = torch.cat([q_rope, q_pass], dim=-1)
|
| 134 |
+
k = torch.cat([k_rope, k_pass], dim=-1)
|
| 135 |
+
|
| 136 |
+
return q, k
|
| 137 |
+
|
| 138 |
+
def forward(
|
| 139 |
+
self,
|
| 140 |
+
x: torch.Tensor,
|
| 141 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 142 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 143 |
+
        use_cache: bool = False,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Forward pass of multi-head attention.

        Args:
            x: Input tensor [batch, seq_len, d_model]
            attention_mask: Optional attention mask
            position_ids: Optional position IDs for RoPE
            use_cache: Whether to return the KV cache for inference
            past_kv: Past key-value cache for inference

        Returns:
            Output tensor and optional KV cache
        """
        batch_size, seq_len, _ = x.size()

        # Project to Q, K, V
        q = self.q_proj(x)  # [batch, seq_len, d_model]
        k = self.k_proj(x)  # [batch, seq_len, d_model]
        v = self.v_proj(x)  # [batch, seq_len, d_model]

        # Reshape for multi-head attention
        # [batch, seq_len, d_model] -> [batch, seq_len, n_heads, head_dim]
        q = q.view(batch_size, seq_len, self.n_heads, self.head_dim)
        k = k.view(batch_size, seq_len, self.n_heads, self.head_dim)
        v = v.view(batch_size, seq_len, self.n_heads, self.head_dim)

        # Apply RoPE to Q and K only (not V!)
        q, k = self._apply_rope(q, k, position_ids)

        # Handle the KV cache for inference
        if use_cache and past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=1)
            v = torch.cat([past_v, v], dim=1)

        kv_cache = (k, v) if use_cache else None

        # Transpose for attention computation
        # [batch, seq_len, n_heads, head_dim] -> [batch, n_heads, seq_len, head_dim]
        q = q.transpose(1, 2).contiguous()
        k = k.transpose(1, 2).contiguous()
        v = v.transpose(1, 2).contiguous()

        # Use Flash Attention 2 via SDPA when available.
        # This is much more memory-efficient than manual attention.
        if self.use_flash_attention and hasattr(F, 'scaled_dot_product_attention'):
            # Flash Attention 2 is selected automatically when available;
            # SDPA handles the causal mask internally when is_causal=True.
            # NOTE: Windows compatibility - skip the kernel-selection context
            # manager to avoid freezing.
            import sys
            if sys.platform == 'win32':
                # On Windows, use SDPA without explicit kernel selection
                attn_output = F.scaled_dot_product_attention(
                    q, k, v,
                    attn_mask=attention_mask,
                    dropout_p=self.dropout if self.training else 0.0,
                    is_causal=attention_mask is None,
                    scale=self.scale,
                )
            else:
                # On Linux, select kernels explicitly for best performance
                with torch.backends.cuda.sdp_kernel(
                    enable_flash=True,          # Use Flash Attention when possible
                    enable_math=True,           # Fall back to the math implementation
                    enable_mem_efficient=True,  # Use memory-efficient attention
                ):
                    attn_output = F.scaled_dot_product_attention(
                        q, k, v,
                        attn_mask=attention_mask,
                        dropout_p=self.dropout if self.training else 0.0,
                        is_causal=attention_mask is None,
                        scale=self.scale,
                    )
        else:
            # Manual attention computation (fallback).
            # This is memory-intensive and should only be used for short sequences.
            attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

            # Apply the causal mask
            if attention_mask is None:
                causal_mask = self._get_causal_mask(seq_len, x.device)
                # Expand the mask for batch and heads
                causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
                attn_scores = attn_scores + causal_mask
            else:
                attn_scores = attn_scores + attention_mask

            # Apply softmax in FP32 for stability, then cast back
            attn_weights = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
            attn_weights = self.attn_dropout(attn_weights)

            # Compute the output
            attn_output = torch.matmul(attn_weights, v)

        # Reshape back
        # [batch, n_heads, seq_len, head_dim] -> [batch, seq_len, d_model]
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.d_model)

        # Output projection
        output = self.o_proj(attn_output)
        output = self.resid_dropout(output)

        return output, kv_cache


# Test the attention implementation
def test_attention():
    """Test multi-head attention with various configurations."""
    print("Testing Multi-Head Attention...")

    # Test configuration
    batch_size = 2
    seq_len = 128
    d_model = 768
    n_heads = 12

    # Create the attention module
    attention = MultiHeadAttention(
        d_model=d_model,
        n_heads=n_heads,
        dropout=0.1,
        max_seq_len=2048,
        rope_percentage=0.5,
        use_flash_attention=True,  # Enable Flash Attention
    )

    # Move to GPU if available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    attention = attention.to(device)
    attention.eval()  # Set to eval mode for testing

    # Create a dummy input
    x = torch.randn(batch_size, seq_len, d_model, device=device, dtype=torch.bfloat16)

    # Forward pass
    with torch.no_grad():
        output, _ = attention(x)

    # Check the output shape
    assert output.shape == (batch_size, seq_len, d_model), \
        f"Expected shape {(batch_size, seq_len, d_model)}, got {output.shape}"

    # Check for NaN
    assert not torch.isnan(output).any(), "Output contains NaN values!"

    print("✓ Multi-Head Attention test passed!")
    print(f"  Input shape: {x.shape}")
    print(f"  Output shape: {output.shape}")
    print(f"  Device: {device}")
    if torch.cuda.is_available():
        print(f"  Memory allocated: {torch.cuda.memory_allocated(device) / 1024**3:.2f} GB")

    return True


if __name__ == "__main__":
    test_attention()
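As a quick consistency check, the SDPA fast path and the manual fallback above should produce the same numbers. A minimal sketch, assuming the constructor arguments shown in test_attention(); the small dimensions are arbitrary:

import copy
import torch

torch.manual_seed(0)
base = MultiHeadAttention(
    d_model=256, n_heads=4, dropout=0.0,
    max_seq_len=512, rope_percentage=0.5,
    use_flash_attention=True,
)
base.eval()
manual = copy.deepcopy(base)
manual.use_flash_attention = False  # force the matmul + softmax fallback

x = torch.randn(2, 64, 256)
with torch.no_grad():
    y_sdpa, _ = base(x)
    y_manual, _ = manual(x)

# Both paths apply the same causal masking, so any difference is float noise
print((y_sdpa - y_manual).abs().max().item())  # expected ~1e-6 in FP32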
src/model/rmsnorm.py
ADDED
@@ -0,0 +1,185 @@
"""Root Mean Square Layer Normalization (RMSNorm) implementation.

Critical implementation details:
1. Use multiplication with rsqrt, NOT division
2. No mean subtraction (unlike LayerNorm)
3. Compute in FP32 for numerical stability even when using BF16/FP16
"""

import torch
import torch.nn as nn
from typing import Optional, Tuple, Union


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    RMSNorm is a simplification of LayerNorm that removes the mean subtraction
    and only performs re-scaling via the root mean square.

    Based on the paper: 'Root Mean Square Layer Normalization'
    https://arxiv.org/abs/1910.07467
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        """
        Args:
            hidden_size: Size of the hidden dimension
            eps: Small constant for numerical stability (1e-6 for BF16, 1e-5 for FP16)
        """
        super().__init__()
        self.hidden_size = hidden_size
        # CRITICAL FIX: Ensure eps is stored as a float, not a string
        self.eps = float(eps) if isinstance(eps, str) else eps

        # Learnable scale parameter (gamma)
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply RMSNorm to the input tensor.

        CRITICAL BUG TO AVOID:
        The most common bug is using division with torch.rsqrt:
        WRONG: x / torch.rsqrt(variance + eps)  # This is x * sqrt(variance)
        RIGHT: x * torch.rsqrt(variance + eps)  # This is x / sqrt(variance)

        Args:
            x: Input tensor of shape [..., hidden_size]

        Returns:
            Normalized tensor of the same shape as the input
        """
        # Store the original dtype (for mixed precision training)
        input_dtype = x.dtype

        # CRITICAL: Compute in float32 for numerical stability
        x_float32 = x.float()

        # Compute the mean of squares (the RMS is its square root)
        variance = x_float32.pow(2).mean(dim=-1, keepdim=True)

        # CRITICAL: Use rsqrt (reciprocal square root) with multiplication:
        # rsqrt(x) = 1/sqrt(x), so x * rsqrt(variance) = x / sqrt(variance)
        # PERFORMANCE: PyTorch broadcasts Python scalars, no torch.tensor() needed
        x_normalized = x_float32 * torch.rsqrt(variance + self.eps)

        # Apply the learned scale and cast back to the original dtype
        return self.weight * x_normalized.to(input_dtype)

    def extra_repr(self) -> str:
        return f'hidden_size={self.hidden_size}, eps={self.eps}'


class RMSNormOptimized(nn.Module):
    """Optimized RMSNorm with optional fused operations.

    This version includes optimizations for better performance:
    1. Option for in-place operations
    2. Support for sequence parallelism
    3. Optional residual connection fusion
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        elementwise_affine: bool = True,
        memory_efficient: bool = False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        # CRITICAL FIX: Ensure eps is stored as a float, not a string
        self.eps = float(eps) if isinstance(eps, str) else eps
        self.elementwise_affine = elementwise_affine
        self.memory_efficient = memory_efficient

        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.ones(hidden_size))
        else:
            self.register_parameter('weight', None)

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Apply RMSNorm with an optional residual connection.

        Args:
            x: Input tensor
            residual: Optional residual to add before normalization

        Returns:
            Normalized tensor (and the residual if one was provided)
        """
        # Add the residual if provided (pre-norm architecture)
        if residual is not None:
            x = x + residual
            residual = x  # Save for the skip connection

        # Original dtype for mixed precision
        input_dtype = x.dtype

        # Compute in FP32
        if self.memory_efficient:
            # In-place scaling to save memory.
            # BUG FIX: the variance must be computed out-of-place; an in-place
            # pow_(2) would overwrite x with its square before the scaling.
            x = x.float()
            variance = x.pow(2).mean(dim=-1, keepdim=True)
            x.mul_(torch.rsqrt(variance + self.eps))
        else:
            # Standard computation
            x_float32 = x.float()
            variance = x_float32.pow(2).mean(dim=-1, keepdim=True)
            x = x_float32 * torch.rsqrt(variance + self.eps)

        # Apply the weight and cast back
        if self.elementwise_affine:
            x = self.weight * x

        x = x.to(input_dtype)

        if residual is not None:
            return x, residual
        return x


def rmsnorm_func(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Functional version of RMSNorm for use with torch.compile or custom kernels."""
    input_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    # Ensure eps is properly handled
    eps_val = float(eps) if isinstance(eps, str) else eps
    x = x * torch.rsqrt(variance + eps_val)
    return (weight * x).to(input_dtype)


# Comparison with LayerNorm for reference
def compare_normalization():
    """Compare RMSNorm with LayerNorm to understand the differences."""
    batch_size, seq_len, hidden = 2, 10, 768
    x = torch.randn(batch_size, seq_len, hidden)

    # LayerNorm: normalizes by mean and variance
    layer_norm = nn.LayerNorm(hidden)
    ln_out = layer_norm(x)

    # RMSNorm: normalizes by RMS only (no mean subtraction)
    rms_norm = RMSNorm(hidden)
    rms_out = rms_norm(x)

    print(f"Input shape: {x.shape}")
    print(f"LayerNorm output shape: {ln_out.shape}")
    print(f"RMSNorm output shape: {rms_out.shape}")
    print(f"Mean difference: {(ln_out - rms_out).abs().mean().item():.6f}")

    # RMSNorm is roughly 15-20% faster due to the simpler computation
    return ln_out, rms_out
src/model/rope.py
ADDED
@@ -0,0 +1,217 @@
"""Rotary Position Embeddings (RoPE) implementation.

Critical implementation details:
1. Apply RoPE only to Q and K, never to V
2. Use head_dim, not the full model dimension
3. Ensure proper dimension pairing for rotation
"""

import torch
import torch.nn as nn
from typing import Optional, Tuple


class RotaryPositionEmbeddings(nn.Module):
    """Rotary Position Embeddings (RoPE) for transformer models.

    Based on the paper: 'RoFormer: Enhanced Transformer with Rotary Position Embedding'
    https://arxiv.org/abs/2104.09864
    """

    def __init__(
        self,
        head_dim: int,
        max_seq_len: int = 2048,
        base: int = 10000,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.base = base

        # CRITICAL: head_dim must be even for proper pairing
        assert head_dim % 2 == 0, f"head_dim must be even, got {head_dim}"

        # Precompute frequencies
        self._precompute_freqs(device)

    def _precompute_freqs(self, device: Optional[torch.device] = None):
        """Precompute the frequency tensor for RoPE."""
        # Calculate theta frequencies:
        # theta_i = base^(-2i/d) for i in [0, 1, ..., d/2-1]
        theta = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))

        # Create position indices
        positions = torch.arange(self.max_seq_len).float()

        # Compute the outer product: [seq_len, head_dim/2]
        freqs = torch.einsum('i,j->ij', positions, theta)

        # Convert to cos and sin for the rotation
        freqs_cos = torch.cos(freqs)  # [seq_len, head_dim/2]
        freqs_sin = torch.sin(freqs)  # [seq_len, head_dim/2]

        # Duplicate for full dimension coverage
        # [seq_len, head_dim/2] -> [seq_len, head_dim]
        freqs_cos = torch.cat([freqs_cos, freqs_cos], dim=-1)
        freqs_sin = torch.cat([freqs_sin, freqs_sin], dim=-1)

        # Register as buffers (not trainable, moves with the model to device)
        self.register_buffer('freqs_cos', freqs_cos, persistent=False)
        self.register_buffer('freqs_sin', freqs_sin, persistent=False)

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        """Rotate half the hidden dims of the input.

        CRITICAL: This is the most common bug - incorrect dimension pairing.
        For input [1, 2, 3, 4], the output should be [-3, -4, 1, 2].
        """
        x1 = x[..., :x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat([-x2, x1], dim=-1)

    def apply_rotary_pos_emb(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary position embeddings to the query and key tensors.

        Args:
            q: Query tensor of shape [batch, seq_len, num_heads, head_dim]
            k: Key tensor of shape [batch, seq_len, num_heads, head_dim]
            position_ids: Optional custom position IDs

        Returns:
            Tuple of rotated (q, k) tensors
        """
        seq_len = q.shape[1]

        # Get the frequency tensors for the current sequence length
        if position_ids is not None:
            freqs_cos = self.freqs_cos[position_ids]
            freqs_sin = self.freqs_sin[position_ids]
        else:
            freqs_cos = self.freqs_cos[:seq_len]
            freqs_sin = self.freqs_sin[:seq_len]

        # Reshape for broadcasting
        # [seq_len, head_dim] -> [1, seq_len, 1, head_dim]
        freqs_cos = freqs_cos[None, :, None, :]
        freqs_sin = freqs_sin[None, :, None, :]

        # Apply the rotation using the formula:
        # x_rotated = x * cos + rotate_half(x) * sin
        q_rotated = q * freqs_cos + self.rotate_half(q) * freqs_sin
        k_rotated = k * freqs_cos + self.rotate_half(k) * freqs_sin

        return q_rotated, k_rotated

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass - apply RoPE to Q and K only.

        CRITICAL: Never apply RoPE to the V (value) tensor!
        """
        return self.apply_rotary_pos_emb(q, k, position_ids)


# Alternative implementation in the complex-rotation style
class RotaryPositionEmbeddingsComplex(nn.Module):
    """Alternative RoPE implementation using complex-number-style operations.

    This can be more efficient on some hardware but requires careful handling.
    """

    def __init__(
        self,
        head_dim: int,
        max_seq_len: int = 2048,
        base: int = 10000,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.base = base

        assert head_dim % 2 == 0, f"head_dim must be even, got {head_dim}"

        # Precompute the inverse frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        t = torch.arange(max_seq_len, dtype=inv_freq.dtype)
        freqs = torch.einsum('i,j->ij', t, inv_freq)

        # Store as cos/sin values
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer('cos_cached', emb.cos()[None, :, None, :])
        self.register_buffer('sin_cached', emb.sin()[None, :, None, :])

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        seq_len: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply RoPE using the cached cos/sin values."""
        if seq_len is None:
            seq_len = q.shape[1]

        # Apply the rotation
        q_embed = (q * self.cos_cached[:, :seq_len]) + \
                  (self.rotate_half(q) * self.sin_cached[:, :seq_len])
        k_embed = (k * self.cos_cached[:, :seq_len]) + \
                  (self.rotate_half(k) * self.sin_cached[:, :seq_len])

        return q_embed, k_embed

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        """Rotate half the hidden dims."""
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([-x2, x1], dim=-1)


# Test function for RoPE
def test_rope():
    """Test the RoPE implementation."""
    print("Testing RoPE implementation...")

    batch_size = 2
    seq_len = 128
    n_heads = 12
    head_dim = 64

    # Create the RoPE module
    rope = RotaryPositionEmbeddings(head_dim=head_dim, max_seq_len=2048)

    # Create dummy Q and K tensors
    q = torch.randn(batch_size, seq_len, n_heads, head_dim)
    k = torch.randn(batch_size, seq_len, n_heads, head_dim)

    # Apply RoPE
    q_rot, k_rot = rope(q, k)

    # Check shapes
    assert q_rot.shape == q.shape, f"Q shape mismatch: {q_rot.shape} != {q.shape}"
    assert k_rot.shape == k.shape, f"K shape mismatch: {k_rot.shape} != {k.shape}"

    # Check for NaN
    assert not torch.isnan(q_rot).any(), "Q contains NaN after RoPE"
    assert not torch.isnan(k_rot).any(), "K contains NaN after RoPE"

    print("✓ RoPE test passed!")
    print(f"  Input shape: {q.shape}")
    print(f"  Output shape: {q_rot.shape}")

    return True


if __name__ == "__main__":
    test_rope()
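Beyond shapes and NaNs, the property worth checking is the one RoPE is built for: after rotation, the q·k dot product depends only on the relative distance between positions, not their absolute values. A minimal sketch using the class above:

import torch

torch.manual_seed(0)
rope = RotaryPositionEmbeddings(head_dim=64, max_seq_len=256)

# Place identical q and k vectors at every position (1 batch, 1 head)
q = torch.randn(1, 1, 1, 64).expand(1, 256, 1, 64).contiguous()
k = torch.randn(1, 1, 1, 64).expand(1, 256, 1, 64).contiguous()
q_rot, k_rot = rope(q, k)

# Scores for the same offset (5) at different absolute positions
s1 = (q_rot[0, 10, 0] * k_rot[0, 5, 0]).sum()
s2 = (q_rot[0, 105, 0] * k_rot[0, 100, 0]).sum()
print(s1.item(), s2.item())  # approximately equal, up to float error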
src/model/swiglu.py
ADDED
@@ -0,0 +1,231 @@
"""SwiGLU (Swish-Gated Linear Unit) activation function implementation.

Critical implementation details:
1. Requires THREE weight matrices (gate, value, down-projection)
2. Hidden dimension should be adjusted to ~8/3 * d_model for parameter parity
3. No bias terms in modern implementations
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


class SwiGLU(nn.Module):
    """Swish-Gated Linear Unit activation function.

    SwiGLU combines the Swish activation (SiLU) with a gating mechanism
    for improved gradient flow in deep networks.

    Based on the paper: 'GLU Variants Improve Transformer'
    https://arxiv.org/abs/2002.05202
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: Optional[int] = None,
        output_dim: Optional[int] = None,
        multiple_of: int = 256,
        bias: bool = False,
    ):
        """
        Args:
            input_dim: Input dimension (d_model)
            hidden_dim: Hidden dimension for the FFN. If None, uses 8/3 * input_dim
            output_dim: Output dimension. If None, uses input_dim
            multiple_of: Round hidden_dim up to this multiple for hardware efficiency
            bias: Whether to use bias terms (modern LLMs use False)
        """
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim or input_dim

        # CRITICAL: Adjust the hidden dimension for parameter parity.
        # A standard FFN with ReLU/GELU uses 4 * d_model; SwiGLU needs
        # 3 matrices, so use (8/3) * d_model for the same parameter count.
        if hidden_dim is None:
            hidden_dim = int(8 * input_dim / 3)

        # Round up to the nearest multiple for better hardware utilization
        if multiple_of > 1:
            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.hidden_dim = hidden_dim

        # Three linear projections required for SwiGLU
        self.w_gate = nn.Linear(input_dim, hidden_dim, bias=bias)        # Gate projection
        self.w_up = nn.Linear(input_dim, hidden_dim, bias=bias)          # Value/up projection
        self.w_down = nn.Linear(hidden_dim, self.output_dim, bias=bias)  # Down projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the SwiGLU activation.

        Formula: SwiGLU(x) = (Swish(x W_gate) ⊗ x W_up) W_down
        where Swish(x) = x * sigmoid(x) = SiLU(x)

        Args:
            x: Input tensor of shape [..., input_dim]

        Returns:
            Output tensor of shape [..., output_dim]
        """
        # Gate path with Swish/SiLU activation
        gate = F.silu(self.w_gate(x))

        # Value path (no activation)
        value = self.w_up(x)

        # Element-wise multiplication (gating)
        hidden = gate * value

        # Down projection to the output dimension
        output = self.w_down(hidden)

        return output

    def extra_repr(self) -> str:
        return (
            f'input_dim={self.input_dim}, '
            f'hidden_dim={self.hidden_dim}, '
            f'output_dim={self.output_dim}'
        )


class SwiGLUParallel(nn.Module):
    """Parallel version of SwiGLU that combines the gate and up projections.

    This is more efficient as it reduces the number of separate matmuls.
    Used in models like LLaMA and Mistral.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: Optional[int] = None,
        output_dim: Optional[int] = None,
        multiple_of: int = 256,
        bias: bool = False,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim or input_dim

        if hidden_dim is None:
            hidden_dim = int(8 * input_dim / 3)

        if multiple_of > 1:
            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.hidden_dim = hidden_dim

        # Combined gate and up projection for efficiency
        # Output shape: [batch, seq, 2 * hidden_dim]
        self.w_gate_up = nn.Linear(input_dim, 2 * hidden_dim, bias=bias)
        self.w_down = nn.Linear(hidden_dim, self.output_dim, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply SwiGLU with parallel projections."""
        # Single matmul for both gate and up projections
        gate_up = self.w_gate_up(x)

        # Split into the gate and up components
        gate, up = gate_up.chunk(2, dim=-1)

        # Apply SwiGLU
        hidden = F.silu(gate) * up
        output = self.w_down(hidden)

        return output


class GeGLU(nn.Module):
    """GELU-Gated Linear Unit - an alternative to SwiGLU.

    Some models use GeGLU instead of SwiGLU. The difference is using
    GELU instead of SiLU for the gating activation.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: Optional[int] = None,
        output_dim: Optional[int] = None,
        bias: bool = False,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim or input_dim

        if hidden_dim is None:
            hidden_dim = int(8 * input_dim / 3)

        self.hidden_dim = hidden_dim

        self.w_gate = nn.Linear(input_dim, hidden_dim, bias=bias)
        self.w_up = nn.Linear(input_dim, hidden_dim, bias=bias)
        self.w_down = nn.Linear(hidden_dim, self.output_dim, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the GeGLU activation."""
        gate = F.gelu(self.w_gate(x))
        value = self.w_up(x)
        hidden = gate * value
        output = self.w_down(hidden)
        return output


def calculate_ffn_params(d_model: int, activation: str = "swiglu") -> dict:
    """Calculate FFN parameters for different activation functions.

    This helper ensures parameter parity across different activation types.
    """
    if activation in ("relu", "gelu"):
        # Standard FFN: 2 matrices
        hidden_dim = 4 * d_model
        num_params = 2 * d_model * hidden_dim
    elif activation in ("swiglu", "geglu"):
        # Gated FFN: 3 matrices, adjusted hidden dimension
        hidden_dim = int(8 * d_model / 3)
        # Round up to a multiple of 256 for hardware efficiency
        hidden_dim = 256 * ((hidden_dim + 255) // 256)
        num_params = d_model * hidden_dim * 2 + hidden_dim * d_model
    else:
        raise ValueError(f"Unknown activation: {activation}")

    return {
        "activation": activation,
        "d_model": d_model,
        "hidden_dim": hidden_dim,
        "num_params": num_params,
        "params_millions": num_params / 1e6,
    }


# Example usage and parameter comparison
if __name__ == "__main__":
    d_model = 768

    # Compare parameter counts
    print("FFN Parameter Comparison:")
    for act in ["relu", "gelu", "swiglu"]:
        params = calculate_ffn_params(d_model, act)
        print(f"{act.upper()}:")
        print(f"  Hidden dim: {params['hidden_dim']}")
        print(f"  Parameters: {params['params_millions']:.2f}M")

    # Test SwiGLU
    batch_size, seq_len = 2, 512
    x = torch.randn(batch_size, seq_len, d_model)

    swiglu = SwiGLU(d_model)
    output = swiglu(x)

    print(f"\nSwiGLU Test:")
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"SwiGLU parameters: {sum(p.numel() for p in swiglu.parameters()) / 1e6:.2f}M")
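The parity arithmetic behind the 8/3 rule: three matrices of size d x (8d/3) contribute 3 * d * (8d/3) = 8d^2 parameters, matching the 2 * (d * 4d) = 8d^2 of a standard two-matrix FFN. The fused variant can also be checked against the three-matrix version directly; a sketch, copying weights between the two classes above:

import torch

torch.manual_seed(0)
d = 128
three_mat = SwiGLU(d)
fused = SwiGLUParallel(d)

with torch.no_grad():
    # The fused projection stacks [gate; up] along the output dimension,
    # matching the chunk(2, dim=-1) split in SwiGLUParallel.forward
    fused.w_gate_up.weight.copy_(
        torch.cat([three_mat.w_gate.weight, three_mat.w_up.weight], dim=0)
    )
    fused.w_down.weight.copy_(three_mat.w_down.weight)

x = torch.randn(2, 16, d)
print((three_mat(x) - fused(x)).abs().max().item())  # ~0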
src/model/transformer_block.py
ADDED
@@ -0,0 +1,454 @@
"""Transformer block with pre-normalization architecture and memory optimizations.

Critical implementation details:
1. Pre-normalization: RMSNorm BEFORE attention and FFN
2. Residual connections after each sub-layer
3. Modern component stack: RoPE + RMSNorm + SwiGLU
4. Gradient checkpointing support for memory efficiency
"""

import torch
import torch.nn as nn
from typing import Optional, Tuple, Dict, Any
from torch.utils.checkpoint import checkpoint

from .rmsnorm import RMSNorm
from .attention import MultiHeadAttention
from .swiglu import SwiGLU


class TransformerBlock(nn.Module):
    """Single transformer block with pre-normalization.

    This follows the modern architecture used in LLaMA, Mistral, etc.:
    - Pre-normalization with RMSNorm
    - Multi-head attention with RoPE
    - SwiGLU activation in the FFN
    - Residual connections
    - Gradient checkpointing support
    """

    def __init__(
        self,
        d_model: int = 768,
        n_heads: int = 12,
        d_ffn: Optional[int] = None,
        dropout: float = 0.1,
        max_seq_len: int = 2048,
        rope_base: int = 10000,
        rope_percentage: float = 0.5,
        rms_norm_eps: float = 1e-6,
        use_flash_attention: bool = True,
        use_gradient_checkpointing: bool = False,
    ):
        super().__init__()

        self.d_model = d_model
        self.n_heads = n_heads
        self.use_gradient_checkpointing = use_gradient_checkpointing

        # Pre-normalization layers
        self.attn_norm = RMSNorm(d_model, eps=rms_norm_eps)
        self.ffn_norm = RMSNorm(d_model, eps=rms_norm_eps)

        # Multi-head attention with RoPE
        self.attention = MultiHeadAttention(
            d_model=d_model,
            n_heads=n_heads,
            dropout=dropout,
            max_seq_len=max_seq_len,
            rope_base=rope_base,
            rope_percentage=rope_percentage,
            use_flash_attention=use_flash_attention,
        )

        # SwiGLU FFN
        # Default hidden dimension: 8/3 * d_model for parameter parity
        if d_ffn is None:
            d_ffn = int(8 * d_model / 3)
            # Round to a multiple of 256 for hardware efficiency
            d_ffn = 256 * ((d_ffn + 255) // 256)

        self.ffn = SwiGLU(
            input_dim=d_model,
            hidden_dim=d_ffn,
            output_dim=d_model,
            bias=False,
        )

    def _attention_block(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Attention sub-block with pre-norm."""
        # Pre-normalization
        x_norm = self.attn_norm(x)

        # Multi-head attention
        attn_output, kv_cache = self.attention(
            x_norm,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=use_cache,
            past_kv=past_kv,
        )

        # The residual connection is added by the caller (forward)
        return attn_output, kv_cache

    def _ffn_block(self, x: torch.Tensor) -> torch.Tensor:
        """Feed-forward sub-block with pre-norm."""
        # Pre-normalization
        x_norm = self.ffn_norm(x)

        # Feed-forward
        ffn_output = self.ffn(x_norm)

        return ffn_output

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Forward pass of the transformer block.

        Args:
            x: Input tensor [batch, seq_len, d_model]
            attention_mask: Optional attention mask
            position_ids: Optional position IDs for RoPE
            use_cache: Whether to return the KV cache
            past_kv: Past key-value cache

        Returns:
            Output tensor and optional KV cache
        """
        # Attention block with residual
        if self.use_gradient_checkpointing and self.training:
            # Use gradient checkpointing to save memory during training
            def attention_fn(x_in):
                attn_out, _ = self._attention_block(
                    x_in,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    use_cache=False,  # Can't use the cache with checkpointing
                    past_kv=None,
                )
                return attn_out

            attn_output = checkpoint(attention_fn, x, use_reentrant=False)
            kv_cache = None
        else:
            attn_output, kv_cache = self._attention_block(
                x,
                attention_mask=attention_mask,
                position_ids=position_ids,
                use_cache=use_cache,
                past_kv=past_kv,
            )

        # Add the residual for attention
        x = x + attn_output

        # FFN block with residual
        if self.use_gradient_checkpointing and self.training:
            # Use gradient checkpointing for the FFN as well
            ffn_output = checkpoint(self._ffn_block, x, use_reentrant=False)
        else:
            ffn_output = self._ffn_block(x)

        # Add the residual for the FFN
        x = x + ffn_output

        return x, kv_cache


class WikiMiniModel(nn.Module):
    """Complete WikiMini 95M language model.

    Architecture:
    - Token embeddings with weight tying
    - Stack of transformer blocks
    - Final RMSNorm
    - LM head (tied with the embeddings)
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__()

        # Extract config values with defaults
        self.vocab_size = config.get('vocab_size', 32000)
        self.d_model = config.get('d_model', 768)
        self.n_layers = config.get('n_layers', 12)
        self.n_heads = config.get('n_heads', 12)
        self.d_ffn = config.get('d_ffn', None)
        self.max_seq_len = config.get('max_seq_len', 2048)
        self.dropout = config.get('dropout', 0.1)
        self.rope_percentage = config.get('rope_percentage', 0.5)
        self.rope_base = config.get('rope_base', 10000)
        self.rms_norm_eps = config.get('rms_norm_eps', 1e-6)
        self.tie_embeddings = config.get('tie_embeddings', True)
        self.use_flash_attention = config.get('use_flash_attention', True)
        self.use_gradient_checkpointing = config.get('gradient_checkpointing', False)

        # Token embeddings
        self.token_embedding = nn.Embedding(self.vocab_size, self.d_model)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(
                d_model=self.d_model,
                n_heads=self.n_heads,
                d_ffn=self.d_ffn,
                dropout=self.dropout,
                max_seq_len=self.max_seq_len,
                rope_base=self.rope_base,
                rope_percentage=self.rope_percentage,
                rms_norm_eps=self.rms_norm_eps,
                use_flash_attention=self.use_flash_attention,
                use_gradient_checkpointing=self.use_gradient_checkpointing,
            )
            for _ in range(self.n_layers)
        ])

        # Final normalization
        self.final_norm = RMSNorm(self.d_model, eps=self.rms_norm_eps)

        # Language modeling head
        self.lm_head = nn.Linear(self.d_model, self.vocab_size, bias=False)

        # Weight tying
        if self.tie_embeddings:
            self.lm_head.weight = self.token_embedding.weight

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize weights with a scaled normal distribution."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def enable_gradient_checkpointing(self):
        """Enable gradient checkpointing for all transformer blocks."""
        self.use_gradient_checkpointing = True
        for block in self.blocks:
            block.use_gradient_checkpointing = True

    def disable_gradient_checkpointing(self):
        """Disable gradient checkpointing for all transformer blocks."""
        self.use_gradient_checkpointing = False
        for block in self.blocks:
            block.use_gradient_checkpointing = False

    def count_parameters(self) -> dict:
        """Count model parameters by component.

        Returns:
            Dictionary with parameter counts for each component
        """
        # Count by component type
        embedding_params = sum(p.numel() for p in self.token_embedding.parameters())

        attention_params = 0
        ffn_params = 0
        norm_params = 0

        for block in self.blocks:
            # Attention parameters
            attention_params += sum(p.numel() for p in block.attention.parameters())
            # FFN parameters
            ffn_params += sum(p.numel() for p in block.ffn.parameters())
            # Norm parameters (attention + FFN norms)
            norm_params += sum(p.numel() for p in block.attn_norm.parameters())
            norm_params += sum(p.numel() for p in block.ffn_norm.parameters())

        # Final norm
        norm_params += sum(p.numel() for p in self.final_norm.parameters())

        # LM head (only if not tied)
        if not self.tie_embeddings:
            lm_head_params = sum(p.numel() for p in self.lm_head.parameters())
        else:
            lm_head_params = 0  # Shared with the embeddings

        total_params = sum(p.numel() for p in self.parameters())

        return {
            'total': total_params,
            'total_millions': total_params / 1e6,
            'embedding': embedding_params,
            'attention': attention_params,
            'ffn': ffn_params,
            'norm': norm_params,
            'lm_head': lm_head_params,
        }

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_key_values: Optional[list] = None,
    ) -> Dict[str, torch.Tensor]:
        """Forward pass of the model.

        Args:
            input_ids: Token IDs [batch, seq_len]
            attention_mask: Optional attention mask
            position_ids: Optional position IDs
            labels: Optional labels for the language modeling loss
            use_cache: Whether to return the KV cache
            past_key_values: Past KV cache for inference

        Returns:
            Dictionary with 'logits' and optionally 'loss' and 'past_key_values'
        """
        batch_size, seq_len = input_ids.shape

        # Token embeddings
        x = self.token_embedding(input_ids)

        # Apply dropout to the embeddings
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)

        # Process through the transformer blocks
        past_key_values_out = [] if use_cache else None

        for i, block in enumerate(self.blocks):
            # Get the past KV for this layer if available
            past_kv = past_key_values[i] if past_key_values is not None else None

            # Process through the block
            x, kv_cache = block(
                x,
                attention_mask=attention_mask,
                position_ids=position_ids,
                use_cache=use_cache,
                past_kv=past_kv,
            )

            # Store the KV cache if needed
            if use_cache:
                past_key_values_out.append(kv_cache)

        # Final normalization
        x = self.final_norm(x)

        # Language modeling head
        logits = self.lm_head(x)

        # Prepare the output
        output = {'logits': logits}

        # Calculate the loss if labels are provided
        if labels is not None:
            # Shift for next-token prediction
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # Flatten for cross-entropy
            shift_logits = shift_logits.view(-1, self.vocab_size)
            shift_labels = shift_labels.view(-1)

            # Calculate the cross-entropy loss
            loss = nn.functional.cross_entropy(
                shift_logits,
                shift_labels,
                ignore_index=-100,  # Standard ignore index
            )

            output['loss'] = loss

        # Add the KV cache to the output if requested
        if use_cache:
            output['past_key_values'] = past_key_values_out

        return output


def create_model(config: Dict[str, Any]) -> WikiMiniModel:
    """Create a WikiMini model from a configuration.

    Args:
        config: Model configuration dictionary

    Returns:
        WikiMiniModel instance
    """
    return WikiMiniModel(config)


# Test the complete model
if __name__ == "__main__":
    # Test configuration for ~95M parameters
    config = {
        'vocab_size': 32000,
        'd_model': 768,
        'n_layers': 12,
        'n_heads': 12,
        'd_ffn': 2048,  # Adjusted for SwiGLU
        'max_seq_len': 2048,
        'dropout': 0.1,
        'rope_percentage': 0.5,
        'rope_base': 10000,
        'rms_norm_eps': 1e-6,
        'tie_embeddings': True,
        'use_flash_attention': True,
        'gradient_checkpointing': True,  # Enable for memory efficiency
    }

    # Create the model
    model = WikiMiniModel(config)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"WikiMini Model:")
    print(f"  Total parameters: {total_params:,} ({total_params/1e6:.2f}M)")
    print(f"  Trainable parameters: {trainable_params:,} ({trainable_params/1e6:.2f}M)")
    print(f"  Layers: {model.n_layers}")
    print(f"  Hidden size: {model.d_model}")
    print(f"  Attention heads: {model.n_heads}")

    # Test forward pass
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    # Small test batch
    batch_size = 2
    seq_len = 128
    input_ids = torch.randint(0, config['vocab_size'], (batch_size, seq_len), device=device)

    # Enable gradient checkpointing
    model.enable_gradient_checkpointing()

    # Forward pass
    with torch.no_grad():
        outputs = model(input_ids=input_ids)

    print(f"\nTest forward pass:")
    print(f"  Input shape: {input_ids.shape}")
    print(f"  Output logits shape: {outputs['logits'].shape}")
    print(f"  Device: {device}")

    if torch.cuda.is_available():
        print(f"  Memory allocated: {torch.cuda.memory_allocated(device) / 1024**3:.2f} GB")

    print("\n✓ Model test passed!")
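A minimal greedy-decoding sketch on top of the forward pass above. It re-encodes the full context each step rather than threading past_key_values, which is slower but simple and always correct; the tiny config and random prompt IDs are stand-ins, not the trained model's settings:

import torch

config = {
    'vocab_size': 1000, 'd_model': 128, 'n_layers': 2,
    'n_heads': 4, 'max_seq_len': 256, 'dropout': 0.0,
}
model = create_model(config).eval()

ids = torch.randint(0, config['vocab_size'], (1, 8))  # stand-in prompt tokens
with torch.no_grad():
    for _ in range(16):
        logits = model(input_ids=ids)['logits']             # [1, seq, vocab]
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=1)

print(ids.shape)  # torch.Size([1, 24]); sampling and EOS handling omitted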