karthick committed
Commit fb67af8 · 1 Parent(s): d99ca15

Upload TinyStories 24.5M model - article generation success

src/__init__.py ADDED
@@ -0,0 +1,3 @@
+ """TinyStories Language Model - 24.5M Parameters"""
+
+ __version__ = "1.0.0"
src/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (265 Bytes).
 
src/data/__init__.py ADDED
@@ -0,0 +1,15 @@
+ """Data processing modules for TinyStories training."""
+
+ from .tokenizer import load_tokenizer, train_tokenizer, test_tokenizer
+ from .dataset import TinyStoriesDataset, create_dataloaders
+ from .quality_checker import check_dataset_quality, DataQualityChecker
+
+ __all__ = [
+     'load_tokenizer',
+     'train_tokenizer',
+     'test_tokenizer',
+     'TinyStoriesDataset',
+     'create_dataloaders',
+     'check_dataset_quality',
+     'DataQualityChecker',
+ ]
src/data/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (584 Bytes).
 
src/data/__pycache__/dataset.cpython-313.pyc ADDED
Binary file (11.6 kB).
 
src/data/__pycache__/quality_checker.cpython-313.pyc ADDED
Binary file (17.8 kB).
 
src/data/__pycache__/tokenizer.cpython-313.pyc ADDED
Binary file (11.1 kB).
 
src/data/dataset.py ADDED
@@ -0,0 +1,302 @@
+ """Dataset and DataLoader utilities for TinyStories training.
+
+ This module provides:
+ 1. TinyStoriesDataset class for loading and processing TinyStories
+ 2. create_dataloaders function for creating train/val DataLoaders
+ 3. Sequence packing for efficient training
+
+ TinyStories is a synthetic dataset of short stories generated by GPT-3.5/4
+ using a limited vocabulary suitable for children, which makes it ideal for
+ fast training and for testing language models.
+ """
+
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+ from datasets import load_dataset
+ from pathlib import Path
+ import pickle
+ import logging
+ from typing import Dict, List, Tuple, Optional
+ from tqdm import tqdm
+
+ logger = logging.getLogger(__name__)
+
+
+ class TinyStoriesDataset(Dataset):
+     """TinyStories dataset with sequence packing for efficient training.
+
+     TinyStories is a synthetic dataset of short stories generated by GPT-3.5/4
+     using a limited vocabulary suitable for children. The dataset contains
+     ~2.1M stories and is well suited for:
+     - Fast training (only ~1GB)
+     - Clean, well-formed English
+     - Testing model architecture
+     - Educational purposes
+
+     This dataset:
+     1. Loads TinyStories from HuggingFace datasets
+     2. Tokenizes the text
+     3. Packs sequences to max_seq_len for efficiency
+     4. Caches processed data for fast subsequent loading
+     """
+
+     def __init__(
+         self,
+         tokenizer,
+         split: str = "train",
+         max_seq_len: int = 512,
+         cache_dir: Optional[str] = None,
+     ):
+         """Initialize TinyStories dataset.
+
+         Args:
+             tokenizer: Tokenizer instance (must have an encode method)
+             split: Dataset split ("train" or "validation")
+             max_seq_len: Maximum sequence length (default: 512, matching the official paper)
+             cache_dir: Directory for caching processed data
+         """
+         self.tokenizer = tokenizer
+         self.split = split
+         self.max_seq_len = max_seq_len
+         self.cache_dir = Path(cache_dir) if cache_dir else Path("./data/cache")
+         self.cache_dir.mkdir(parents=True, exist_ok=True)
+
+         # Cache file path
+         cache_file = self.cache_dir / f"tinystories_{split}_{max_seq_len}.pkl"
+
+         # Try to load from cache
+         if cache_file.exists():
+             logger.info(f"Loading cached dataset from {cache_file}")
+             with open(cache_file, "rb") as f:
+                 cache_data = pickle.load(f)
+             self.input_ids = cache_data["input_ids"]
+             self.labels = cache_data["labels"]
+             logger.info(f"Loaded {len(self.input_ids)} sequences from cache")
+         else:
+             # Process dataset
+             logger.info(f"Processing TinyStories {split} split...")
+             self.input_ids, self.labels = self._process_dataset()
+
+             # Save to cache
+             logger.info(f"Saving processed dataset to {cache_file}")
+             cache_data = {
+                 "input_ids": self.input_ids,
+                 "labels": self.labels,
+             }
+             with open(cache_file, "wb") as f:
+                 pickle.dump(cache_data, f)
+
+         logger.info(f"Dataset ready: {len(self.input_ids)} sequences")
+
+     def _process_dataset(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+         """Process TinyStories dataset into packed sequences.
+
+         Returns:
+             Tuple of (input_ids, labels) lists
+         """
+         # Load dataset
+         dataset = load_dataset(
+             "roneneldan/TinyStories",
+             split=self.split,
+         )
+
+         # Tokenize all text
+         logger.info("Tokenizing dataset...")
+         all_token_ids = []
+
+         for example in tqdm(dataset, desc="Tokenizing"):
+             text = example["text"].strip()
+             if len(text) > 0:  # Skip empty stories
+                 # Encode text
+                 if hasattr(self.tokenizer, 'encode'):
+                     token_ids = self.tokenizer.encode(text, add_special_tokens=False)
+                 else:
+                     # Fallback for tokenizers.Tokenizer
+                     token_ids = self.tokenizer.tokenizer.encode(text).ids
+
+                 all_token_ids.extend(token_ids)
+
+         logger.info(f"Total tokens: {len(all_token_ids):,}")
+
+         # Pack into sequences
+         logger.info("Packing sequences...")
+         input_ids_list = []
+         labels_list = []
+
+         # Cut the token stream into non-overlapping windows of max_seq_len tokens
+         for i in range(0, len(all_token_ids) - 1, self.max_seq_len):
+             # Get sequence
+             seq = all_token_ids[i : i + self.max_seq_len]
+
+             # Skip if too short
+             if len(seq) < 2:
+                 continue
+
+             # Create input_ids and labels (next-token shift)
+             # input_ids: [0, 1, 2, ..., n-1]
+             # labels:    [1, 2, 3, ..., n]
+             input_ids = torch.tensor(seq[:-1], dtype=torch.long)
+             labels = torch.tensor(seq[1:], dtype=torch.long)
+
+             # Pad if necessary
+             if len(input_ids) < self.max_seq_len:
+                 pad_len = self.max_seq_len - len(input_ids)
+                 input_ids = torch.cat([
+                     input_ids,
+                     torch.full((pad_len,), self.tokenizer.pad_token_id, dtype=torch.long)
+                 ])
+                 labels = torch.cat([
+                     labels,
+                     torch.full((pad_len,), -100, dtype=torch.long)  # -100 is ignored in the loss
+                 ])
+
+             input_ids_list.append(input_ids)
+             labels_list.append(labels)
+
+         logger.info(f"Created {len(input_ids_list)} packed sequences")
+
+         return input_ids_list, labels_list
+
+     def __len__(self) -> int:
+         """Return number of sequences."""
+         return len(self.input_ids)
+
+     def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
+         """Get a single sequence.
+
+         Args:
+             idx: Sequence index
+
+         Returns:
+             Dictionary with 'input_ids' and 'labels'
+         """
+         return {
+             "input_ids": self.input_ids[idx],
+             "labels": self.labels[idx],
+         }
+
+
+ def collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
+     """Collate function for DataLoader.
+
+     Args:
+         batch: List of dictionaries with 'input_ids' and 'labels'
+
+     Returns:
+         Batched dictionary
+     """
+     input_ids = torch.stack([item["input_ids"] for item in batch])
+     labels = torch.stack([item["labels"] for item in batch])
+
+     return {
+         "input_ids": input_ids,
+         "labels": labels,
+     }
+
+
+ def create_dataloaders(
+     tokenizer,
+     batch_size: int,
+     max_seq_len: int,
+     cache_dir: str,
+     dataset_name: str = "tinystories",
+     num_workers: int = 0,
+     pin_memory: bool = True,
+     drop_last: bool = True,
+ ) -> Tuple[DataLoader, DataLoader]:
+     """Create train and validation DataLoaders for TinyStories.
+
+     Args:
+         tokenizer: Tokenizer instance
+         batch_size: Batch size per device
+         max_seq_len: Maximum sequence length (512 recommended for TinyStories)
+         cache_dir: Directory for caching processed data
+         dataset_name: Dataset to use (default: "tinystories")
+         num_workers: Number of data loading workers (use 0 on Windows)
+         pin_memory: Whether to pin memory for faster GPU transfer
+         drop_last: Whether to drop the last incomplete batch
+
+     Returns:
+         Tuple of (train_loader, val_loader)
+     """
+     logger.info("Using TinyStories dataset")
+
+     logger.info("Creating train dataset...")
+     train_dataset = TinyStoriesDataset(
+         tokenizer=tokenizer,
+         split="train",
+         max_seq_len=max_seq_len,
+         cache_dir=cache_dir,
+     )
+
+     logger.info("Creating validation dataset...")
+     val_dataset = TinyStoriesDataset(
+         tokenizer=tokenizer,
+         split="validation",
+         max_seq_len=max_seq_len,
+         cache_dir=cache_dir,
+     )
+
+     # Create DataLoaders
+     train_loader = DataLoader(
+         train_dataset,
+         batch_size=batch_size,
+         shuffle=True,
+         num_workers=num_workers,
+         pin_memory=pin_memory,
+         drop_last=drop_last,
+         collate_fn=collate_fn,
+     )
+
+     val_loader = DataLoader(
+         val_dataset,
+         batch_size=batch_size,
+         shuffle=False,
+         num_workers=num_workers,
+         pin_memory=pin_memory,
+         drop_last=False,
+         collate_fn=collate_fn,
+     )
+
+     logger.info(f"Train batches: {len(train_loader)}")
+     logger.info(f"Validation batches: {len(val_loader)}")
+
+     return train_loader, val_loader
+
+
+ # Test the dataset
+ if __name__ == "__main__":
+     from .tokenizer import load_tokenizer
+
+     print("Testing TinyStoriesDataset...")
+
+     # Load tokenizer (assumes it exists)
+     tokenizer_path = "./tokenizer/wikimini_32k"
+     if Path(tokenizer_path).exists():
+         tokenizer = load_tokenizer(tokenizer_path)
+
+         # Create a small dataset for testing
+         dataset = TinyStoriesDataset(
+             tokenizer=tokenizer,
+             split="validation",  # Use the smaller split for testing
+             max_seq_len=128,
+             cache_dir="./data/cache_test",
+         )
+
+         print(f"\nDataset size: {len(dataset)}")
+         print("Sample batch:")
+         sample = dataset[0]
+         print(f"  Input IDs shape: {sample['input_ids'].shape}")
+         print(f"  Labels shape: {sample['labels'].shape}")
+         print(f"  First 10 input IDs: {sample['input_ids'][:10]}")
+         print(f"  First 10 labels: {sample['labels'][:10]}")
+
+         # Test DataLoader
+         loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
+         batch = next(iter(loader))
+         print("\nDataLoader batch:")
+         print(f"  Input IDs shape: {batch['input_ids'].shape}")
+         print(f"  Labels shape: {batch['labels'].shape}")
+     else:
+         print(f"Tokenizer not found at {tokenizer_path}")
+         print("Please train the tokenizer first: python scripts/train_tokenizer.py")
src/data/quality_checker.py ADDED
@@ -0,0 +1,343 @@
+ """Data Quality Checker for training datasets.
+
+ This module provides tools to validate dataset quality before training:
+ - Detects artifacts (HTML tags, URLs, special tokens)
+ - Checks for malformed text
+ - Validates text statistics
+ - Reports quality issues
+
+ Prevents training on corrupted or low-quality data.
+ """
+
+ import re
+ import logging
+ from typing import Dict, List, Tuple, Optional
+ from datasets import load_dataset
+ from tqdm import tqdm
+
+ logger = logging.getLogger(__name__)
+
+
+ class DataQualityChecker:
+     """Check dataset quality before training."""
+
+     def __init__(
+         self,
+         dataset_name: str,
+         split: str = "train",
+         sample_size: Optional[int] = 10000,
+         strict: bool = False,
+     ):
+         """Initialize quality checker.
+
+         Args:
+             dataset_name: Name of dataset (e.g., "roneneldan/TinyStories")
+             split: Dataset split to check ("train" or "validation")
+             sample_size: Number of samples to check (None for all)
+             strict: If True, borderline ("ACCEPTABLE") quality also fails the check
+         """
+         self.dataset_name = dataset_name
+         self.split = split
+         self.sample_size = sample_size
+         self.strict = strict
+
+         # Quality metrics
+         self.issues: Dict[str, List[Tuple[int, str]]] = {
+             "html_tags": [],
+             "urls": [],
+             "emails": [],
+             "excessive_punctuation": [],
+             "malformed_unicode": [],
+             "empty_text": [],
+             "extremely_short": [],
+             "extremely_long": [],
+             "suspicious_patterns": [],
+             "special_tokens": [],
+         }
+
+         self.stats = {
+             "total_samples": 0,
+             "total_chars": 0,
+             "total_words": 0,
+             "avg_length": 0,
+             "vocabulary_size": 0,
+         }
+
+     def check_quality(self) -> Dict:
+         """Run all quality checks and return the results.
+
+         Returns:
+             Dictionary with quality report and pass/fail status
+         """
+         logger.info(f"Loading dataset {self.dataset_name} ({self.split} split)...")
+
+         # Load dataset
+         if "tinystories" in self.dataset_name.lower():
+             dataset = load_dataset("roneneldan/TinyStories", split=self.split)
+         elif "wikitext" in self.dataset_name.lower():
+             dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split=self.split, trust_remote_code=True)
+         else:
+             dataset = load_dataset(self.dataset_name, split=self.split)
+
+         # Limit sample size if requested
+         if self.sample_size and len(dataset) > self.sample_size:
+             logger.info(f"Sampling {self.sample_size} examples from {len(dataset)} total")
+             indices = range(0, len(dataset), len(dataset) // self.sample_size)
+             dataset = dataset.select(list(indices)[:self.sample_size])
+
+         logger.info(f"Checking quality of {len(dataset)} examples...")
+
+         # Run checks
+         vocabulary = set()
+
+         for idx, example in enumerate(tqdm(dataset, desc="Quality Check")):
+             text = example.get("text", "")
+
+             # Update stats
+             self.stats["total_samples"] += 1
+             self.stats["total_chars"] += len(text)
+             words = text.split()
+             self.stats["total_words"] += len(words)
+             vocabulary.update(words)
+
+             # Run individual checks
+             self._check_html_tags(idx, text)
+             self._check_urls(idx, text)
+             self._check_emails(idx, text)
+             self._check_excessive_punctuation(idx, text)
+             self._check_malformed_unicode(idx, text)
+             self._check_empty_text(idx, text)
+             self._check_length_extremes(idx, text)
+             self._check_suspicious_patterns(idx, text)
+             self._check_special_tokens(idx, text)
+
+         # Calculate final stats
+         if self.stats["total_samples"] > 0:
+             self.stats["avg_length"] = self.stats["total_chars"] / self.stats["total_samples"]
+             self.stats["avg_words"] = self.stats["total_words"] / self.stats["total_samples"]
+             self.stats["vocabulary_size"] = len(vocabulary)
+
+         # Generate report
+         report = self._generate_report()
+
+         return report
+
+     def _check_html_tags(self, idx: int, text: str):
+         """Check for HTML tags."""
+         html_pattern = r'<[^>]+>'
+         if re.search(html_pattern, text):
+             self.issues["html_tags"].append((idx, text[:100]))
+
+     def _check_urls(self, idx: int, text: str):
+         """Check for URLs."""
+         url_pattern = r'https?://\S+'
+         if re.search(url_pattern, text):
+             self.issues["urls"].append((idx, text[:100]))
+
+     def _check_emails(self, idx: int, text: str):
+         """Check for email addresses."""
+         email_pattern = r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
+         if re.search(email_pattern, text):
+             self.issues["emails"].append((idx, text[:100]))
+
+     def _check_excessive_punctuation(self, idx: int, text: str):
+         """Check for excessive punctuation (possible artifacts)."""
+         # Five or more consecutive punctuation marks
+         if re.search(r'[!?.,;:]{5,}', text):
+             self.issues["excessive_punctuation"].append((idx, text[:100]))
+
+         # More than 20% punctuation
+         if len(text) > 0:
+             punct_count = sum(1 for c in text if c in '!?.,;:')
+             if punct_count / len(text) > 0.2:
+                 self.issues["excessive_punctuation"].append((idx, text[:100]))
+
+     def _check_malformed_unicode(self, idx: int, text: str):
+         """Check for malformed Unicode characters."""
+         # Look for the Unicode replacement character
+         if '\ufffd' in text:
+             self.issues["malformed_unicode"].append((idx, text[:100]))
+
+         # Control characters (excluding whitespace)
+         if re.search(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]', text):
+             self.issues["malformed_unicode"].append((idx, text[:100]))
+
+     def _check_empty_text(self, idx: int, text: str):
+         """Check for empty or whitespace-only text."""
+         if not text or not text.strip():
+             self.issues["empty_text"].append((idx, text))
+
+     def _check_length_extremes(self, idx: int, text: str):
+         """Check for extremely short or long text."""
+         if len(text.strip()) < 10:
+             self.issues["extremely_short"].append((idx, text))
+         elif len(text) > 50000:  # Suspiciously long
+             self.issues["extremely_long"].append((idx, text[:100]))
+
+     def _check_suspicious_patterns(self, idx: int, text: str):
+         """Check for suspicious patterns."""
+         # The same character repeated more than 10 times in a row (e.g., "aaaaaaaaaaa")
+         if re.search(r'(.)\1{10,}', text):
+             self.issues["suspicious_patterns"].append((idx, text[:100]))
+
+         # Excessive whitespace
+         if re.search(r'\s{10,}', text):
+             self.issues["suspicious_patterns"].append((idx, text[:100]))
+
+     def _check_special_tokens(self, idx: int, text: str):
+         """Check for special tokens that shouldn't appear in raw text."""
+         # Common tokenizer special tokens
+         special_tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]', '<|endoftext|>', '<pad>', '<unk>']
+         for token in special_tokens:
+             if token in text:
+                 self.issues["special_tokens"].append((idx, text[:100]))
+                 break
+
+     def _generate_report(self) -> Dict:
+         """Generate quality report.
+
+         Returns:
+             Dictionary with quality metrics and pass/fail status
+         """
+         total_issues = sum(len(issues) for issues in self.issues.values())
+         issue_percentage = (total_issues / self.stats["total_samples"] * 100) if self.stats["total_samples"] > 0 else 0
+
+         # Determine quality level
+         if issue_percentage == 0:
+             quality_level = "EXCELLENT"
+             passed = True
+         elif issue_percentage < 1:
+             quality_level = "GOOD"
+             passed = True
+         elif issue_percentage < 5:
+             quality_level = "ACCEPTABLE"
+             passed = not self.strict
+         elif issue_percentage < 10:
+             quality_level = "POOR"
+             passed = False
+         else:
+             quality_level = "CRITICAL"
+             passed = False
+
+         report = {
+             "dataset": self.dataset_name,
+             "split": self.split,
+             "quality_level": quality_level,
+             "passed": passed,
+             "stats": self.stats,
+             "issues": {
+                 key: {
+                     "count": len(value),
+                     "percentage": (len(value) / self.stats["total_samples"] * 100) if self.stats["total_samples"] > 0 else 0,
+                     "samples": value[:3],  # First 3 examples
+                 }
+                 for key, value in self.issues.items() if len(value) > 0
+             },
+             "total_issues": total_issues,
+             "issue_percentage": issue_percentage,
+         }
+
+         return report
+
+     def print_report(self, report: Dict):
+         """Print formatted quality report.
+
+         Args:
+             report: Report dictionary from check_quality()
+         """
+         logger.info("\n" + "=" * 70)
+         logger.info("DATA QUALITY REPORT")
+         logger.info("=" * 70)
+         logger.info(f"Dataset: {report['dataset']} ({report['split']} split)")
+         logger.info(f"Quality Level: {report['quality_level']}")
+         logger.info(f"Status: {'✅ PASSED' if report['passed'] else '❌ FAILED'}")
+         logger.info("")
+
+         # Statistics
+         logger.info("Statistics:")
+         logger.info(f"  Total Samples: {report['stats']['total_samples']:,}")
+         logger.info(f"  Avg Length: {report['stats']['avg_length']:.1f} chars")
+         logger.info(f"  Avg Words: {report['stats'].get('avg_words', 0):.1f} words")
+         logger.info(f"  Vocabulary Size: {report['stats']['vocabulary_size']:,}")
+         logger.info("")
+
+         # Issues
+         if report['issues']:
+             logger.warning(f"Found {report['total_issues']} issues ({report['issue_percentage']:.2f}% of samples)")
+             logger.warning("")
+             for issue_type, details in report['issues'].items():
+                 logger.warning(f"  {issue_type.replace('_', ' ').title()}:")
+                 logger.warning(f"    Count: {details['count']} ({details['percentage']:.2f}%)")
+                 if details['samples']:
+                     logger.warning(f"    Example: {details['samples'][0][1][:80]}...")
+                 logger.warning("")
+         else:
+             logger.info("✅ No quality issues found!")
+
+         logger.info("=" * 70)
+
+         # Recommendations
+         if not report['passed']:
+             logger.error("\n⚠️ DATA HAS QUALITY ISSUES - Training not recommended!")
+             logger.error("Recommendations:")
+             if report['issues'].get('html_tags'):
+                 logger.error("  - Remove HTML tags from text")
+             if report['issues'].get('urls'):
+                 logger.error("  - Remove or mask URLs")
+             if report['issues'].get('malformed_unicode'):
+                 logger.error("  - Fix Unicode encoding issues")
+             if report['issues'].get('empty_text'):
+                 logger.error("  - Remove empty samples")
+             logger.error("")
+
+
+ def check_dataset_quality(
+     dataset_name: str,
+     split: str = "train",
+     sample_size: Optional[int] = 10000,
+     strict: bool = False,
+ ) -> bool:
+     """Convenience function to check dataset quality.
+
+     Args:
+         dataset_name: Dataset name or HuggingFace ID
+         split: Split to check
+         sample_size: Number of samples to check (None for all)
+         strict: If True, apply the stricter pass threshold
+
+     Returns:
+         True if quality is acceptable, False otherwise
+     """
+     checker = DataQualityChecker(
+         dataset_name=dataset_name,
+         split=split,
+         sample_size=sample_size,
+         strict=strict,
+     )
+
+     report = checker.check_quality()
+     checker.print_report(report)
+
+     return report["passed"]
+
+
+ if __name__ == "__main__":
+     import argparse
+     import sys
+
+     parser = argparse.ArgumentParser(description="Check dataset quality")
+     parser.add_argument("--dataset", type=str, required=True, help="Dataset name")
+     parser.add_argument("--split", type=str, default="train", help="Dataset split")
+     parser.add_argument("--sample-size", type=int, default=10000, help="Number of samples to check")
+     parser.add_argument("--strict", action="store_true", help="Apply the stricter pass threshold")
+
+     args = parser.parse_args()
+
+     passed = check_dataset_quality(
+         dataset_name=args.dataset,
+         split=args.split,
+         sample_size=args.sample_size,
+         strict=args.strict,
+     )
+
+     sys.exit(0 if passed else 1)
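A minimal sketch of using this checker as a pre-training gate from Python rather than the CLI (the import path assumes the repo root is on `sys.path`; the sample size is illustrative):

```python
# Hypothetical quality gate before kicking off a training run
from src.data.quality_checker import check_dataset_quality

ok = check_dataset_quality(
    dataset_name="roneneldan/TinyStories",
    split="validation",
    sample_size=1000,  # illustrative; use None to scan the whole split
)
if not ok:
    raise SystemExit("Dataset failed the quality gate - fix the data before training.")
```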
src/data/tokenizer.py ADDED
@@ -0,0 +1,272 @@
+ """Tokenizer training and loading utilities for the WikiMini model.
+
+ This module provides functions to:
+ 1. Train a BPE tokenizer on WikiText-103
+ 2. Load a trained tokenizer from disk
+ 3. Test tokenizer functionality
+ """
+
+ import json
+ import logging
+ from pathlib import Path
+ from typing import Optional
+ from tokenizers import Tokenizer, models, trainers, pre_tokenizers, decoders, processors
+ from datasets import load_dataset
+
+ logger = logging.getLogger(__name__)
+
+
+ def train_tokenizer(
+     vocab_size: int = 32000,
+     min_frequency: int = 2,
+     output_dir: str = "./tokenizer/wikimini_32k",
+     show_progress: bool = True,
+ ) -> Tokenizer:
+     """Train a BPE tokenizer on the WikiText-103 dataset.
+
+     Args:
+         vocab_size: Size of the vocabulary
+         min_frequency: Minimum frequency for merged tokens
+         output_dir: Directory to save the trained tokenizer
+         show_progress: Whether to show progress during training
+
+     Returns:
+         Trained tokenizer
+     """
+     logger.info(f"Training BPE tokenizer with vocab_size={vocab_size}")
+
+     # Initialize BPE tokenizer
+     tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))
+
+     # Byte-level pre-tokenization (GPT-2 style, operates on raw bytes)
+     tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
+
+     # Decoder
+     tokenizer.decoder = decoders.ByteLevel()
+
+     # Configure trainer
+     special_tokens = [
+         "<unk>",   # Unknown token
+         "<s>",     # Beginning of sequence
+         "</s>",    # End of sequence
+         "<pad>",   # Padding token
+     ]
+
+     trainer = trainers.BpeTrainer(
+         vocab_size=vocab_size,
+         min_frequency=min_frequency,
+         special_tokens=special_tokens,
+         show_progress=show_progress,
+     )
+
+     # Load WikiText-103 dataset
+     logger.info("Loading WikiText-103 dataset...")
+     dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")
+
+     # Create iterator for training
+     def batch_iterator(batch_size: int = 1000):
+         """Yield batches of text for training."""
+         for i in range(0, len(dataset), batch_size):
+             batch = dataset[i : i + batch_size]
+             yield batch["text"]
+
+     # Train tokenizer
+     logger.info("Training tokenizer...")
+     tokenizer.train_from_iterator(batch_iterator(), trainer=trainer)
+
+     # Add byte-level post-processor
+     tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
+
+     # Enable padding
+     tokenizer.enable_padding(
+         pad_id=tokenizer.token_to_id("<pad>"),
+         pad_token="<pad>",
+     )
+
+     # Enable truncation
+     tokenizer.enable_truncation(max_length=2048)
+
+     # Save tokenizer
+     output_path = Path(output_dir)
+     output_path.mkdir(parents=True, exist_ok=True)
+
+     tokenizer_file = output_path / "tokenizer.json"
+     tokenizer.save(str(tokenizer_file))
+     logger.info(f"Tokenizer saved to {tokenizer_file}")
+
+     # Save config
+     config = {
+         "vocab_size": vocab_size,
+         "model_type": "BPE",
+         "unk_token": "<unk>",
+         "bos_token": "<s>",
+         "eos_token": "</s>",
+         "pad_token": "<pad>",
+     }
+
+     config_file = output_path / "config.json"
+     with open(config_file, 'w') as f:
+         json.dump(config, f, indent=2)
+     logger.info(f"Config saved to {config_file}")
+
+     return tokenizer
+
+
+ def load_tokenizer(tokenizer_path: str, return_wrapper: bool = True):
+     """Load a trained tokenizer from disk.
+
+     Args:
+         tokenizer_path: Path to the tokenizer directory or file
+         return_wrapper: If True, return a TokenizerWrapper (default); else the raw Tokenizer
+
+     Returns:
+         Loaded tokenizer (wrapped by default for compatibility)
+     """
+     tokenizer_path = Path(tokenizer_path)
+
+     # Handle both directory and file paths
+     if tokenizer_path.is_dir():
+         tokenizer_file = tokenizer_path / "tokenizer.json"
+     else:
+         tokenizer_file = tokenizer_path
+
+     if not tokenizer_file.exists():
+         raise FileNotFoundError(f"Tokenizer file not found: {tokenizer_file}")
+
+     logger.info(f"Loading tokenizer from {tokenizer_file}")
+     tokenizer = Tokenizer.from_file(str(tokenizer_file))
+
+     # Return the wrapped version for easier use (supports len(), etc.)
+     if return_wrapper:
+         return TokenizerWrapper(tokenizer)
+
+     return tokenizer
+
+
+ def test_tokenizer(tokenizer: Tokenizer) -> None:
+     """Test tokenizer with sample text.
+
+     Args:
+         tokenizer: Tokenizer to test
+     """
+     print("\n" + "=" * 70)
+     print(" " * 25 + "Tokenizer Test")
+     print("=" * 70)
+
+     # Get vocab info
+     vocab_size = tokenizer.get_vocab_size()
+     print(f"\nVocabulary size: {vocab_size:,}")
+
+     # Test special tokens
+     print("\nSpecial tokens:")
+     special_tokens = ["<unk>", "<s>", "</s>", "<pad>"]
+     for token in special_tokens:
+         token_id = tokenizer.token_to_id(token)
+         print(f"  {token:8s} -> ID {token_id}")
+
+     # Test encoding/decoding
+     test_texts = [
+         "The quick brown fox jumps over the lazy dog.",
+         "Machine learning is a subset of artificial intelligence.",
+         "WikiText-103 is a large-scale language modeling benchmark.",
+     ]
+
+     print("\nEncoding/Decoding tests:")
+     print("-" * 70)
+
+     for i, text in enumerate(test_texts, 1):
+         # Encode
+         encoding = tokenizer.encode(text)
+         tokens = encoding.tokens
+         ids = encoding.ids
+
+         # Decode
+         decoded = tokenizer.decode(ids)
+
+         print(f"\nTest {i}:")
+         print(f"  Original: {text}")
+         print(f"  Tokens: {len(tokens)}")
+         print(f"  IDs: {ids[:10]}..." if len(ids) > 10 else f"  IDs: {ids}")
+         print(f"  Decoded: {decoded}")
+
+         # Check round-trip
+         if decoded.strip() == text.strip():
+             print("  ✅ Round-trip successful")
+         else:
+             print("  ⚠️ Round-trip differs slightly (common with BPE)")
+
+     # Test batch encoding
+     print("\n\nBatch encoding test:")
+     print("-" * 70)
+     encodings = tokenizer.encode_batch(test_texts)
+     print(f"  Batch size: {len(encodings)}")
+     print(f"  Token counts: {[len(enc.ids) for enc in encodings]}")
+
+     print("\n" + "=" * 70)
+     print(" " * 25 + "✅ Test Complete")
+     print("=" * 70 + "\n")
+
+
+ # Wrapper class for compatibility with a HuggingFace-style interface
+ class TokenizerWrapper:
+     """Wrapper to make tokenizers.Tokenizer compatible with the expected interface."""
+
+     def __init__(self, tokenizer: Tokenizer):
+         self.tokenizer = tokenizer
+         self._vocab_size = tokenizer.get_vocab_size()
+
+         def first_id(*names: str) -> Optional[int]:
+             """Return the first token ID that exists.
+
+             An `or`-chain would misread a valid ID of 0 as missing, so the
+             lookup has to compare against None explicitly.
+             """
+             for name in names:
+                 token_id = tokenizer.token_to_id(name)
+                 if token_id is not None:
+                     return token_id
+             return None
+
+         # Get special token IDs - try the standard format first,
+         # then the TinyStories custom format
+         pad_id = first_id("<pad>", "<|padding|>")
+         self.pad_token_id = pad_id if pad_id is not None else 0  # Fall back to 0 if not found
+         self.bos_token_id = first_id("<s>", "<|startoftext|>")
+         self.eos_token_id = first_id("</s>", "<|endoftext|>")
+         self.unk_token_id = tokenizer.token_to_id("<unk>")
+
+     def __call__(self, text, **kwargs):
+         """Encode text (callable interface)."""
+         if isinstance(text, str):
+             return self.tokenizer.encode(text).ids
+         elif isinstance(text, list):
+             return [self.tokenizer.encode(t).ids for t in text]
+
+     def encode(self, text, add_special_tokens=True):
+         """Encode text to token IDs."""
+         encoding = self.tokenizer.encode(text)
+         return encoding.ids
+
+     def decode(self, token_ids, skip_special_tokens=True):
+         """Decode token IDs to text."""
+         return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
+
+     def __len__(self):
+         """Return vocabulary size."""
+         return self._vocab_size
+
+     @property
+     def vocab_size(self):
+         """Vocabulary size property."""
+         return self._vocab_size
+
+
+ def create_tokenizer_wrapper(tokenizer_path: str) -> TokenizerWrapper:
+     """Create a wrapped tokenizer for easier use.
+
+     Args:
+         tokenizer_path: Path to tokenizer directory or file
+
+     Returns:
+         TokenizerWrapper instance
+     """
+     tokenizer = load_tokenizer(tokenizer_path, return_wrapper=False)
+     return TokenizerWrapper(tokenizer)
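A minimal end-to-end sketch of this API (the vocab size and output directory are illustrative, and training on the full WikiText-103 split takes a while):

```python
# Hypothetical round trip: train a small tokenizer, reload it, encode/decode
from src.data.tokenizer import train_tokenizer, load_tokenizer

train_tokenizer(vocab_size=8000, output_dir="./tokenizer/demo_8k")  # illustrative size/path

tok = load_tokenizer("./tokenizer/demo_8k")  # returns a TokenizerWrapper
ids = tok.encode("Once upon a time", add_special_tokens=False)
print(ids, "->", tok.decode(ids))
print("vocab:", len(tok), "pad id:", tok.pad_token_id)
```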
src/model/__init__.py ADDED
@@ -0,0 +1,20 @@
+ """Model components for WikiMini 95M."""
+
+ from .rmsnorm import RMSNorm, RMSNormOptimized
+ from .rope import RotaryPositionEmbeddings, RotaryPositionEmbeddingsComplex
+ from .swiglu import SwiGLU, SwiGLUParallel, GeGLU
+ from .attention import MultiHeadAttention
+ from .transformer_block import TransformerBlock, WikiMiniModel
+
+ __all__ = [
+     "RMSNorm",
+     "RMSNormOptimized",
+     "RotaryPositionEmbeddings",
+     "RotaryPositionEmbeddingsComplex",
+     "SwiGLU",
+     "SwiGLUParallel",
+     "GeGLU",
+     "MultiHeadAttention",
+     "TransformerBlock",
+     "WikiMiniModel",
+ ]
src/model/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (692 Bytes).
 
src/model/__pycache__/attention.cpython-313.pyc ADDED
Binary file (12.3 kB).
 
src/model/__pycache__/rmsnorm.cpython-313.pyc ADDED
Binary file (7.81 kB).
 
src/model/__pycache__/rope.cpython-313.pyc ADDED
Binary file (9.95 kB).
 
src/model/__pycache__/swiglu.cpython-313.pyc ADDED
Binary file (8.97 kB).
 
src/model/__pycache__/transformer_block.cpython-313.pyc ADDED
Binary file (18.7 kB).
 
src/model/attention.py ADDED
@@ -0,0 +1,301 @@
+ """Multi-Head Attention with RoPE integration and memory optimizations.
+
+ Critical implementation details:
+ 1. Apply RoPE only to Q and K, never to V
+ 2. Use SDPA for Flash Attention 2 support
+ 3. Pre-normalization architecture
+ 4. Memory-efficient implementation
+ """
+
+ import logging
+ import math
+ import sys
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from .rope import RotaryPositionEmbeddings
+
+ logger = logging.getLogger(__name__)
+
+
+ class MultiHeadAttention(nn.Module):
+     """Multi-Head Attention with RoPE and Flash Attention support.
+
+     This implementation:
+     - Uses Rotary Position Embeddings (RoPE) on Q and K only
+     - Supports Flash Attention 2 via torch.nn.functional.scaled_dot_product_attention
+     - Uses no bias terms (the modern approach)
+     - Includes proper causal masking
+     - Is memory-efficient
+     """
+
+     def __init__(
+         self,
+         d_model: int = 768,
+         n_heads: int = 12,
+         dropout: float = 0.1,
+         max_seq_len: int = 2048,
+         rope_base: int = 10000,
+         rope_percentage: float = 0.5,
+         use_flash_attention: bool = True,
+     ):
+         super().__init__()
+
+         assert d_model % n_heads == 0, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
+
+         self.d_model = d_model
+         self.n_heads = n_heads
+         self.head_dim = d_model // n_heads
+
+         # Windows Flash Attention: test with PyTorch 2.10+ nightly.
+         # Older versions had freezing issues, but newer versions may work.
+         if sys.platform == 'win32' and use_flash_attention:
+             # Allow Flash Attention on Windows with PyTorch 2.10+.
+             # If freezing occurs, set use_flash_attention: false in the config.
+             self.use_flash_attention = use_flash_attention
+             logger.info("[Windows] Attempting Flash Attention with PyTorch 2.10+ - if freezing occurs, disable in config")
+         elif sys.platform == 'win32':
+             self.use_flash_attention = False
+             logger.info("[Windows] Flash Attention disabled - using manual attention")
+         else:
+             self.use_flash_attention = use_flash_attention
+
+         self.dropout = dropout
+         self.scale = 1.0 / math.sqrt(self.head_dim)
+
+         # Q, K, V projections (no bias)
+         self.q_proj = nn.Linear(d_model, d_model, bias=False)
+         self.k_proj = nn.Linear(d_model, d_model, bias=False)
+         self.v_proj = nn.Linear(d_model, d_model, bias=False)
+         self.o_proj = nn.Linear(d_model, d_model, bias=False)
+
+         # RoPE for positional encoding,
+         # applied to only part of the head dimensions (typically 50%)
+         rope_dim = int(self.head_dim * rope_percentage)
+         self.rope_dim = rope_dim
+         self.rope = RotaryPositionEmbeddings(
+             head_dim=rope_dim,
+             max_seq_len=max_seq_len,
+             base=rope_base,
+         )
+
+         # Dropout
+         self.attn_dropout = nn.Dropout(dropout)
+         self.resid_dropout = nn.Dropout(dropout)
+
+         # The causal mask is created on demand based on sequence length and cached
+         self.register_buffer('cached_mask', None, persistent=False)
+         self.register_buffer('cached_mask_size', torch.tensor(0), persistent=False)
+
+     def _get_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
+         """Get or create a causal mask for the given sequence length.
+
+         CRITICAL: Always returns the mask on the specified device to prevent CPU OOM errors.
+         """
+         if self.cached_mask is None or self.cached_mask_size < seq_len:
+             # Create a new mask directly on the target device
+             mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
+             mask = mask.masked_fill(mask == 1, float('-inf'))
+             self.cached_mask = mask
+             self.cached_mask_size = torch.tensor(seq_len)
+
+         # CRITICAL: Ensure the returned mask is on the correct device.
+         # This prevents CPU OOM when broadcasting during attn_scores + causal_mask.
+         return self.cached_mask[:seq_len, :seq_len].to(device)
+
+     def _apply_rope(
+         self,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         position_ids: Optional[torch.Tensor] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Apply RoPE to partial dimensions of Q and K.
+
+         Args:
+             q: Query tensor [batch, seq_len, n_heads, head_dim]
+             k: Key tensor [batch, seq_len, n_heads, head_dim]
+             position_ids: Optional custom position IDs
+
+         Returns:
+             Rotated Q and K tensors
+         """
+         # Split into RoPE and pass-through dimensions
+         if self.rope_dim > 0:
+             q_rope, q_pass = q[..., :self.rope_dim], q[..., self.rope_dim:]
+             k_rope, k_pass = k[..., :self.rope_dim], k[..., self.rope_dim:]
+
+             # Apply RoPE to the first part
+             q_rope, k_rope = self.rope(q_rope, k_rope, position_ids)
+
+             # Concatenate back
+             q = torch.cat([q_rope, q_pass], dim=-1)
+             k = torch.cat([k_rope, k_pass], dim=-1)
+
+         return q, k
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         use_cache: bool = False,
+         past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+         """Forward pass of multi-head attention.
+
+         Args:
+             x: Input tensor [batch, seq_len, d_model]
+             attention_mask: Optional attention mask
+             position_ids: Optional position IDs for RoPE
+             use_cache: Whether to return the KV cache for inference
+             past_kv: Past key-value cache for inference
+
+         Returns:
+             Output tensor and optional KV cache
+         """
+         batch_size, seq_len, _ = x.size()
+
+         # Project to Q, K, V
+         q = self.q_proj(x)  # [batch, seq_len, d_model]
+         k = self.k_proj(x)  # [batch, seq_len, d_model]
+         v = self.v_proj(x)  # [batch, seq_len, d_model]
+
+         # Reshape for multi-head attention:
+         # [batch, seq_len, d_model] -> [batch, seq_len, n_heads, head_dim]
+         q = q.view(batch_size, seq_len, self.n_heads, self.head_dim)
+         k = k.view(batch_size, seq_len, self.n_heads, self.head_dim)
+         v = v.view(batch_size, seq_len, self.n_heads, self.head_dim)
+
+         # Apply RoPE to Q and K only (not V!)
+         q, k = self._apply_rope(q, k, position_ids)
+
+         # Handle KV cache for inference
+         if use_cache and past_kv is not None:
+             past_k, past_v = past_kv
+             k = torch.cat([past_k, k], dim=1)
+             v = torch.cat([past_v, v], dim=1)
+
+         kv_cache = (k, v) if use_cache else None
+
+         # Transpose for attention computation:
+         # [batch, seq_len, n_heads, head_dim] -> [batch, n_heads, seq_len, head_dim]
+         q = q.transpose(1, 2).contiguous()
+         k = k.transpose(1, 2).contiguous()
+         v = v.transpose(1, 2).contiguous()
+
+         # Use Flash Attention 2 via SDPA when available.
+         # This is MUCH more memory-efficient than manual attention.
+         if self.use_flash_attention and hasattr(F, 'scaled_dot_product_attention'):
+             # Flash Attention 2 is used automatically when available.
+             # SDPA handles the causal mask internally when is_causal=True.
+             # NOTE: Windows compatibility - skip the context manager to avoid freezing.
+             if sys.platform == 'win32':
+                 # On Windows, use SDPA without explicit kernel selection
+                 attn_output = F.scaled_dot_product_attention(
+                     q, k, v,
+                     attn_mask=attention_mask,
+                     dropout_p=self.dropout if self.training else 0.0,
+                     is_causal=attention_mask is None,
+                     scale=self.scale,
+                 )
+             else:
+                 # Elsewhere, select kernels explicitly for best performance.
+                 # (torch.backends.cuda.sdp_kernel is deprecated in newer PyTorch
+                 # releases in favor of torch.nn.attention.sdpa_kernel.)
+                 with torch.backends.cuda.sdp_kernel(
+                     enable_flash=True,          # Use Flash Attention when possible
+                     enable_math=True,           # Fall back to the math implementation
+                     enable_mem_efficient=True,  # Use memory-efficient attention
+                 ):
+                     attn_output = F.scaled_dot_product_attention(
+                         q, k, v,
+                         attn_mask=attention_mask,
+                         dropout_p=self.dropout if self.training else 0.0,
+                         is_causal=attention_mask is None,
+                         scale=self.scale,
+                     )
+         else:
+             # Manual attention computation (fallback).
+             # This is memory-intensive and should only be used for small sequences.
+             attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
+
+             # Apply causal mask
+             if attention_mask is None:
+                 causal_mask = self._get_causal_mask(seq_len, x.device)
+                 # Expand mask for batch and heads
+                 causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
+                 attn_scores = attn_scores + causal_mask
+             else:
+                 attn_scores = attn_scores + attention_mask
+
+             # Apply softmax
+             attn_weights = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
+             attn_weights = self.attn_dropout(attn_weights)
+
+             # Compute output
+             attn_output = torch.matmul(attn_weights, v)
+
+         # Reshape back:
+         # [batch, n_heads, seq_len, head_dim] -> [batch, seq_len, d_model]
+         attn_output = attn_output.transpose(1, 2).contiguous()
+         attn_output = attn_output.view(batch_size, seq_len, self.d_model)
+
+         # Output projection
+         output = self.o_proj(attn_output)
+         output = self.resid_dropout(output)
+
+         return output, kv_cache
+
+
+ # Test the attention implementation
+ def test_attention():
+     """Test multi-head attention with a typical configuration."""
+     print("Testing Multi-Head Attention...")
+
+     # Test configuration
+     batch_size = 2
+     seq_len = 128
+     d_model = 768
+     n_heads = 12
+
+     # Create attention module
+     attention = MultiHeadAttention(
+         d_model=d_model,
+         n_heads=n_heads,
+         dropout=0.1,
+         max_seq_len=2048,
+         rope_percentage=0.5,
+         use_flash_attention=True,  # Enable Flash Attention
+     )
+
+     # Move to GPU if available; cast inputs and weights to the same dtype
+     # (bfloat16 on GPU, float32 on CPU) so the projections don't mismatch
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     dtype = torch.bfloat16 if device.type == 'cuda' else torch.float32
+     attention = attention.to(device=device, dtype=dtype)
+     attention.eval()  # Set to eval mode for testing
+
+     # Create dummy input
+     x = torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype)
+
+     # Forward pass
+     with torch.no_grad():
+         output, _ = attention(x)
+
+     # Check output shape
+     assert output.shape == (batch_size, seq_len, d_model), \
+         f"Expected shape {(batch_size, seq_len, d_model)}, got {output.shape}"
+
+     # Check for NaN
+     assert not torch.isnan(output).any(), "Output contains NaN values!"
+
+     print("✓ Multi-Head Attention test passed!")
+     print(f"  Input shape: {x.shape}")
+     print(f"  Output shape: {output.shape}")
+     print(f"  Device: {device}")
+     if device.type == 'cuda':
+         print(f"  Memory allocated: {torch.cuda.memory_allocated(device) / 1024**3:.2f} GB")
+
+     return True
+
+
+ if __name__ == "__main__":
+     test_attention()
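As a quick sanity check of the causal masking described above, here is a minimal sketch (CPU, eval mode so dropout is off; the import path assumes the repo root is on `sys.path`): the outputs at positions 0..7 must not change when the future tokens are removed.

```python
# Minimal causality sanity check for MultiHeadAttention
import torch
from src.model.attention import MultiHeadAttention

attn = MultiHeadAttention(d_model=768, n_heads=12, dropout=0.0)
attn.eval()

x = torch.randn(1, 16, 768)
with torch.no_grad():
    out_full, _ = attn(x)           # attend over all 16 positions
    out_prefix, _ = attn(x[:, :8])  # attend over the first 8 only

# With correct causal masking, positions 0..7 never see positions 8..15
assert torch.allclose(out_full[:, :8], out_prefix, atol=1e-5)
print("causal masking OK")
```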
src/model/rmsnorm.py ADDED
@@ -0,0 +1,185 @@
+ """Root Mean Square Layer Normalization (RMSNorm) implementation.
+
+ Critical implementation details:
+ 1. Use multiplication with rsqrt, NOT division
+ 2. No mean subtraction (unlike LayerNorm)
+ 3. Compute in FP32 for numerical stability even when using BF16/FP16
+ """
+
+ import torch
+ import torch.nn as nn
+ from typing import Optional, Tuple, Union
+
+
+ class RMSNorm(nn.Module):
+     """Root Mean Square Layer Normalization.
+
+     RMSNorm is a simplification of LayerNorm that removes the mean subtraction
+     and only performs re-scaling via the root mean square.
+
+     Based on the paper: 'Root Mean Square Layer Normalization'
+     https://arxiv.org/abs/1910.07467
+     """
+
+     def __init__(self, hidden_size: int, eps: float = 1e-6):
+         """
+         Args:
+             hidden_size: Size of the hidden dimension
+             eps: Small constant for numerical stability (1e-6 for BF16, 1e-5 for FP16)
+         """
+         super().__init__()
+         self.hidden_size = hidden_size
+         # CRITICAL FIX: Ensure eps is stored as a float, not a string
+         self.eps = float(eps)
+
+         # Learnable scale parameter (gamma)
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Apply RMSNorm to the input tensor.
+
+         CRITICAL BUG TO AVOID:
+         The most common bug is using division with torch.rsqrt:
+         WRONG: x / torch.rsqrt(variance + eps)  # This is x * sqrt(variance)
+         RIGHT: x * torch.rsqrt(variance + eps)  # This is x / sqrt(variance)
+
+         Args:
+             x: Input tensor of shape [..., hidden_size]
+
+         Returns:
+             Normalized tensor of the same shape as the input
+         """
+         # Store the original dtype (for mixed precision training)
+         input_dtype = x.dtype
+
+         # CRITICAL: Compute in float32 for numerical stability
+         x_float32 = x.float()
+
+         # Compute the mean square (RMS = sqrt(mean(x^2)))
+         variance = x_float32.pow(2).mean(dim=-1, keepdim=True)
+
+         # CRITICAL: Use rsqrt (reciprocal square root) with multiplication:
+         # rsqrt(x) = 1/sqrt(x), so x * rsqrt(variance) = x / sqrt(variance).
+         # PERFORMANCE: PyTorch broadcasts the scalar eps automatically.
+         x_normalized = x_float32 * torch.rsqrt(variance + self.eps)
+
+         # Apply the learned scale and cast back to the original dtype
+         return self.weight * x_normalized.to(input_dtype)
+
+     def extra_repr(self) -> str:
+         return f'hidden_size={self.hidden_size}, eps={self.eps}'
+
+
+ class RMSNormOptimized(nn.Module):
+     """Optimized RMSNorm with optional fused operations.
+
+     This version includes optimizations for better performance:
+     1. Option for in-place operations
+     2. Support for sequence parallelism
+     3. Optional residual connection fusion
+     """
+
+     def __init__(
+         self,
+         hidden_size: int,
+         eps: float = 1e-6,
+         elementwise_affine: bool = True,
+         memory_efficient: bool = False,
+     ):
+         super().__init__()
+         self.hidden_size = hidden_size
+         # CRITICAL FIX: Ensure eps is stored as a float, not a string
+         self.eps = float(eps)
+         self.elementwise_affine = elementwise_affine
+         self.memory_efficient = memory_efficient
+
+         if self.elementwise_affine:
+             self.weight = nn.Parameter(torch.ones(hidden_size))
+         else:
+             self.register_parameter('weight', None)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         residual: Optional[torch.Tensor] = None,
+     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+         """Apply RMSNorm with an optional residual connection.
+
+         Args:
+             x: Input tensor
+             residual: Optional residual to add before normalization
+
+         Returns:
+             Normalized tensor (and the updated residual if one was provided)
+         """
+         # Add residual if provided (pre-norm architecture)
+         if residual is not None:
+             x = x + residual
+             residual = x  # Save for the skip connection
+
+         # Original dtype for mixed precision
+         input_dtype = x.dtype
+
+         # Compute in FP32
+         if self.memory_efficient:
+             # Rescale the FP32 copy in place to save memory.
+             # NOTE: the variance must be computed with an out-of-place pow;
+             # an in-place pow_ would leave x already squared when it is rescaled.
+             x = x.float()
+             variance = x.pow(2).mean(dim=-1, keepdim=True)
+             x.mul_(torch.rsqrt(variance + self.eps))
+         else:
+             # Standard computation
+             x_float32 = x.float()
+             variance = x_float32.pow(2).mean(dim=-1, keepdim=True)
+             x = x_float32 * torch.rsqrt(variance + self.eps)
+
+         # Apply the weight and cast back
+         if self.elementwise_affine:
+             x = self.weight * x
+
+         x = x.to(input_dtype)
+
+         if residual is not None:
+             return x, residual
+         return x
+
+
+ def rmsnorm_func(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
+     """Functional version of RMSNorm for use with torch.compile or custom kernels."""
+     input_dtype = x.dtype
+     x = x.float()
+     variance = x.pow(2).mean(dim=-1, keepdim=True)
+     # Ensure eps is handled as a float
+     x = x * torch.rsqrt(variance + float(eps))
+     return (weight * x).to(input_dtype)
+
+
+ # Comparison with LayerNorm for reference
+ def compare_normalization():
+     """Compare RMSNorm with LayerNorm to understand the differences."""
+     batch_size, seq_len, hidden = 2, 10, 768
+     x = torch.randn(batch_size, seq_len, hidden)
+
+     # LayerNorm: normalizes by mean and variance
+     layer_norm = nn.LayerNorm(hidden)
+     ln_out = layer_norm(x)
+
+     # RMSNorm: normalizes by RMS only (no mean subtraction)
+     rms_norm = RMSNorm(hidden)
+     rms_out = rms_norm(x)
+
+     print(f"Input shape: {x.shape}")
+     print(f"LayerNorm output shape: {ln_out.shape}")
+     print(f"RMSNorm output shape: {rms_out.shape}")
+     print(f"Mean difference: {(ln_out - rms_out).abs().mean().item():.6f}")
+
+     # RMSNorm is typically faster than LayerNorm due to the simpler computation
+     return ln_out, rms_out
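A minimal numerical check of the rsqrt formulation above (the import path assumes the repo root is on `sys.path`): with the weight at its initial value of ones, RMSNorm should match the explicit x / sqrt(mean(x²) + eps).

```python
# Verify RMSNorm against the explicit formula x / sqrt(mean(x^2) + eps)
import torch
from src.model.rmsnorm import RMSNorm

torch.manual_seed(0)
x = torch.randn(2, 4, 8)
eps = 1e-6

norm = RMSNorm(8, eps=eps)  # weight is initialized to ones
manual = x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

assert torch.allclose(norm(x), manual, atol=1e-6)
print("RMSNorm matches the explicit formula")
```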
src/model/rope.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Rotary Position Embeddings (RoPE) implementation.
2
+
3
+ Critical implementation details:
4
+ 1. Apply RoPE only to Q and K, never to V
5
+ 2. Use head_dim, not full model dimension
6
+ 3. Ensure proper dimension pairing for rotation
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import math
12
+ from typing import Optional, Tuple
13
+
14
+
15
+ class RotaryPositionEmbeddings(nn.Module):
16
+ """Rotary Position Embeddings (RoPE) for transformer models.
17
+
18
+ Based on the paper: 'RoFormer: Enhanced Transformer with Rotary Position Embedding'
19
+ https://arxiv.org/abs/2104.09864
20
+ """
21
+
22
+ def __init__(
23
+ self,
24
+ head_dim: int,
25
+ max_seq_len: int = 2048,
26
+ base: int = 10000,
27
+ device: Optional[torch.device] = None,
28
+ ):
29
+ super().__init__()
30
+ self.head_dim = head_dim
31
+ self.max_seq_len = max_seq_len
32
+ self.base = base
33
+
34
+ # CRITICAL: head_dim must be even for proper pairing
35
+ assert head_dim % 2 == 0, f"head_dim must be even, got {head_dim}"
36
+
37
+ # Precompute frequencies
38
+ self._precompute_freqs(device)
39
+
40
+ def _precompute_freqs(self, device: Optional[torch.device] = None):
41
+ """Precompute the frequency tensor for RoPE."""
42
+ # Calculate theta frequencies
43
+ # theta_i = base^(-2i/d) for i in [0, 1, ..., d/2-1]
44
+ theta = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
45
+
46
+ # Create position indices
47
+ positions = torch.arange(self.max_seq_len).float()
48
+
49
+ # Compute outer product: [seq_len, head_dim/2]
50
+ freqs = torch.einsum('i,j->ij', positions, theta)
51
+
52
+ # Convert to cos and sin for rotation
53
+ freqs_cos = torch.cos(freqs) # [seq_len, head_dim/2]
54
+ freqs_sin = torch.sin(freqs) # [seq_len, head_dim/2]
55
+
56
+ # Duplicate for full dimension coverage
57
+ # [seq_len, head_dim/2] -> [seq_len, head_dim]
58
+ freqs_cos = torch.cat([freqs_cos, freqs_cos], dim=-1)
59
+ freqs_sin = torch.cat([freqs_sin, freqs_sin], dim=-1)
60
+
61
+ # Register as buffers (not trainable, moves with model to device)
62
+ self.register_buffer('freqs_cos', freqs_cos, persistent=False)
63
+ self.register_buffer('freqs_sin', freqs_sin, persistent=False)
64
+
65
+ def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
66
+ """Rotate half the hidden dims of the input.
67
+
68
+ CRITICAL: This is the most common bug - incorrect dimension pairing.
69
+ For input [1, 2, 3, 4], output should be [-3, -4, 1, 2].
70
+ """
71
+ x1 = x[..., :x.shape[-1] // 2]
72
+ x2 = x[..., x.shape[-1] // 2:]
73
+ return torch.cat([-x2, x1], dim=-1)
74
+
75
+ def apply_rotary_pos_emb(
76
+ self,
77
+ q: torch.Tensor,
78
+ k: torch.Tensor,
79
+ position_ids: Optional[torch.Tensor] = None,
80
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
81
+ """Apply rotary position embeddings to query and key tensors.
82
+
83
+ Args:
84
+ q: Query tensor of shape [batch, seq_len, num_heads, head_dim]
85
+ k: Key tensor of shape [batch, seq_len, num_heads, head_dim]
86
+ position_ids: Optional custom position IDs
87
+
88
+ Returns:
89
+ Tuple of rotated (q, k) tensors
90
+ """
91
+ seq_len = q.shape[1]
92
+
93
+ # Get the frequency tensors for current sequence length
94
+ if position_ids is not None:
95
+ # Batched gather: [batch, seq_len] -> [batch, seq_len, 1, head_dim]
96
+ freqs_cos = self.freqs_cos[position_ids][:, :, None, :]
97
+ freqs_sin = self.freqs_sin[position_ids][:, :, None, :]
98
+ else:
99
+ # Contiguous prefix: [seq_len, head_dim] -> [1, seq_len, 1, head_dim]
100
+ freqs_cos = self.freqs_cos[:seq_len][None, :, None, :]
101
+ freqs_sin = self.freqs_sin[:seq_len][None, :, None, :]
102
+
103
+ # Both branches now broadcast against q/k of
104
+ # shape [batch, seq_len, n_heads, head_dim]
105
+
106
+ # Apply rotation using the formula:
107
+ # x_rotated = x * cos + rotate_half(x) * sin
108
+ q_rotated = q * freqs_cos + self.rotate_half(q) * freqs_sin
109
+ k_rotated = k * freqs_cos + self.rotate_half(k) * freqs_sin
110
+
111
+ return q_rotated, k_rotated
112
+
113
+ def forward(
114
+ self,
115
+ q: torch.Tensor,
116
+ k: torch.Tensor,
117
+ position_ids: Optional[torch.Tensor] = None,
118
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
119
+ """Forward pass - apply RoPE to Q and K only.
120
+
121
+ CRITICAL: Never apply RoPE to V (value) tensor!
122
+ """
123
+ return self.apply_rotary_pos_emb(q, k, position_ids)
124
+
125
+
126
+ # Alternative implementation derived from the complex-number view of RoPE
127
+ class RotaryPositionEmbeddingsComplex(nn.Module):
128
+ """Alternative RoPE implementation derived from the complex-number form.
129
+
130
+ Despite the name, it applies the rotation via precomputed cos/sin caches, which can be more efficient on some hardware.
131
+ """
132
+
133
+ def __init__(
134
+ self,
135
+ head_dim: int,
136
+ max_seq_len: int = 2048,
137
+ base: int = 10000,
138
+ device: Optional[torch.device] = None,
139
+ ):
140
+ super().__init__()
141
+ self.head_dim = head_dim
142
+ self.max_seq_len = max_seq_len
143
+ self.base = base
144
+
145
+ assert head_dim % 2 == 0, f"head_dim must be even, got {head_dim}"
146
+
147
+ # Precompute complex exponentials
148
+ inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
149
+ t = torch.arange(max_seq_len, dtype=inv_freq.dtype)
150
+ freqs = torch.einsum('i,j->ij', t, inv_freq)
151
+
152
+ # Store as cos/sin values
153
+ emb = torch.cat([freqs, freqs], dim=-1)
154
+ self.register_buffer('cos_cached', emb.cos()[None, :, None, :], persistent=False)
155
+ self.register_buffer('sin_cached', emb.sin()[None, :, None, :], persistent=False)
156
+
157
+ def forward(
158
+ self,
159
+ q: torch.Tensor,
160
+ k: torch.Tensor,
161
+ seq_len: Optional[int] = None,
162
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
163
+ """Apply RoPE using cached cos/sin values."""
164
+ if seq_len is None:
165
+ seq_len = q.shape[1]
166
+
167
+ # Apply rotation
168
+ q_embed = (q * self.cos_cached[:, :seq_len]) + \
169
+ (self.rotate_half(q) * self.sin_cached[:, :seq_len])
170
+ k_embed = (k * self.cos_cached[:, :seq_len]) + \
171
+ (self.rotate_half(k) * self.sin_cached[:, :seq_len])
172
+
173
+ return q_embed, k_embed
174
+
175
+ def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
176
+ """Rotate half the hidden dims."""
177
+ x1, x2 = x.chunk(2, dim=-1)
178
+ return torch.cat([-x2, x1], dim=-1)
179
+
180
+
181
+ # Test function for RoPE
182
+ def test_rope():
183
+ """Test RoPE implementation."""
184
+ print("Testing RoPE implementation...")
185
+
186
+ batch_size = 2
187
+ seq_len = 128
188
+ n_heads = 12
189
+ head_dim = 64
190
+
191
+ # Create RoPE module
192
+ rope = RotaryPositionEmbeddings(head_dim=head_dim, max_seq_len=2048)
193
+
194
+ # Create dummy Q and K tensors
195
+ q = torch.randn(batch_size, seq_len, n_heads, head_dim)
196
+ k = torch.randn(batch_size, seq_len, n_heads, head_dim)
197
+
198
+ # Apply RoPE
199
+ q_rot, k_rot = rope(q, k)
200
+
201
+ # Check shapes
202
+ assert q_rot.shape == q.shape, f"Q shape mismatch: {q_rot.shape} != {q.shape}"
203
+ assert k_rot.shape == k.shape, f"K shape mismatch: {k_rot.shape} != {k.shape}"
204
+
205
+ # Check for NaN
206
+ assert not torch.isnan(q_rot).any(), "Q contains NaN after RoPE"
207
+ assert not torch.isnan(k_rot).any(), "K contains NaN after RoPE"
208
+
209
+ print("✓ RoPE test passed!")
210
+ print(f" Input shape: {q.shape}")
211
+ print(f" Output shape: {q_rot.shape}")
212
+
213
+ return True
214
+
215
+
216
+ if __name__ == "__main__":
217
+ test_rope()
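
Beyond the shape checks above, RoPE's defining property can be tested directly: after rotation, the query-key dot product depends only on the relative offset between positions, never on their absolute values. A minimal sketch of such a check (assuming the module above is importable as src.model.rope; the tolerance is a loose float32 allowance):

    import torch
    from src.model.rope import RotaryPositionEmbeddings

    rope = RotaryPositionEmbeddings(head_dim=64, max_seq_len=256)
    q = torch.randn(1, 1, 1, 64)  # [batch, seq_len, n_heads, head_dim]
    k = torch.randn(1, 1, 1, 64)

    def score(m: int, n: int) -> float:
        """Dot product of q rotated to position m with k rotated to position n."""
        q_m, _ = rope(q, q, position_ids=torch.tensor([[m]]))
        _, k_n = rope(k, k, position_ids=torch.tensor([[n]]))
        return (q_m * k_n).sum().item()

    # Shifting both positions by the same offset must leave the score unchanged.
    assert abs(score(3, 10) - score(53, 60)) < 1e-3
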
src/model/swiglu.py ADDED
@@ -0,0 +1,231 @@
1
+ """SwiGLU (Swish-Gated Linear Unit) activation function implementation.
2
+
3
+ Critical implementation details:
4
+ 1. Requires THREE weight matrices (gate, value, down-projection)
5
+ 2. Hidden dimension should be adjusted to ~8/3 * d_model for parameter parity
6
+ 3. No bias terms in modern implementations
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from typing import Optional
13
+
14
+
15
+ class SwiGLU(nn.Module):
16
+ """Swish-Gated Linear Unit activation function.
17
+
18
+ SwiGLU combines the Swish activation (SiLU) with a gating mechanism
19
+ for improved gradient flow in deep networks.
20
+
21
+ Based on the paper: 'GLU Variants Improve Transformer'
22
+ https://arxiv.org/abs/2002.05202
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ input_dim: int,
28
+ hidden_dim: Optional[int] = None,
29
+ output_dim: Optional[int] = None,
30
+ multiple_of: int = 256,
31
+ bias: bool = False,
32
+ ):
33
+ """
34
+ Args:
35
+ input_dim: Input dimension (d_model)
36
+ hidden_dim: Hidden dimension for FFN. If None, uses 8/3 * input_dim
37
+ output_dim: Output dimension. If None, uses input_dim
38
+ multiple_of: Round hidden_dim to nearest multiple for hardware efficiency
39
+ bias: Whether to use bias terms (modern LLMs use False)
40
+ """
41
+ super().__init__()
42
+
43
+ self.input_dim = input_dim
44
+ self.output_dim = output_dim or input_dim
45
+
46
+ # CRITICAL: Adjust hidden dimension for parameter parity
47
+ # Standard FFN with ReLU/GELU uses 4 * d_model
48
+ # SwiGLU needs 3 matrices, so use (8/3) * d_model for same param count
49
+ if hidden_dim is None:
50
+ hidden_dim = int(8 * input_dim / 3)
51
+
52
+ # Round up to the next multiple for better hardware utilization
53
+ if multiple_of > 1:
54
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
55
+
56
+ self.hidden_dim = hidden_dim
57
+
58
+ # Three linear projections required for SwiGLU
59
+ self.w_gate = nn.Linear(input_dim, hidden_dim, bias=bias) # Gate projection
60
+ self.w_up = nn.Linear(input_dim, hidden_dim, bias=bias) # Value/up projection
61
+ self.w_down = nn.Linear(hidden_dim, self.output_dim, bias=bias) # Down projection
62
+
63
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
64
+ """Apply SwiGLU activation.
65
+
66
+ Formula: SwiGLU(x) = (Swish(xW_gate) ⊗ xW_up) W_down
67
+ where Swish(x) = x * sigmoid(x) = SiLU(x)
68
+
69
+ Args:
70
+ x: Input tensor of shape [..., input_dim]
71
+
72
+ Returns:
73
+ Output tensor of shape [..., output_dim]
74
+ """
75
+ # Gate path with Swish/SiLU activation
76
+ gate = F.silu(self.w_gate(x))
77
+
78
+ # Value path (no activation)
79
+ value = self.w_up(x)
80
+
81
+ # Element-wise multiplication (gating)
82
+ hidden = gate * value
83
+
84
+ # Down projection to output dimension
85
+ output = self.w_down(hidden)
86
+
87
+ return output
88
+
89
+ def extra_repr(self) -> str:
90
+ return (
91
+ f'input_dim={self.input_dim}, '
92
+ f'hidden_dim={self.hidden_dim}, '
93
+ f'output_dim={self.output_dim}'
94
+ )
95
+
96
+
97
+ class SwiGLUParallel(nn.Module):
98
+ """Parallel version of SwiGLU that combines gate and up projections.
99
+
100
+ This is more efficient as it reduces the number of separate matmuls.
101
+ Used in models like LLaMA and Mistral.
102
+ """
103
+
104
+ def __init__(
105
+ self,
106
+ input_dim: int,
107
+ hidden_dim: Optional[int] = None,
108
+ output_dim: Optional[int] = None,
109
+ multiple_of: int = 256,
110
+ bias: bool = False,
111
+ ):
112
+ super().__init__()
113
+
114
+ self.input_dim = input_dim
115
+ self.output_dim = output_dim or input_dim
116
+
117
+ if hidden_dim is None:
118
+ hidden_dim = int(8 * input_dim / 3)
119
+
120
+ if multiple_of > 1:
121
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
122
+
123
+ self.hidden_dim = hidden_dim
124
+
125
+ # Combined gate and up projection for efficiency
126
+ # Output shape: [batch, seq, 2 * hidden_dim]
127
+ self.w_gate_up = nn.Linear(input_dim, 2 * hidden_dim, bias=bias)
128
+ self.w_down = nn.Linear(hidden_dim, self.output_dim, bias=bias)
129
+
130
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
131
+ """Apply SwiGLU with parallel projections."""
132
+ # Single matmul for both gate and up projections
133
+ gate_up = self.w_gate_up(x)
134
+
135
+ # Split into gate and up components
136
+ gate, up = gate_up.chunk(2, dim=-1)
137
+
138
+ # Apply SwiGLU
139
+ hidden = F.silu(gate) * up
140
+ output = self.w_down(hidden)
141
+
142
+ return output
143
+
144
+
145
+ class GeGLU(nn.Module):
146
+ """GELU-Gated Linear Unit - alternative to SwiGLU.
147
+
148
+ Some models use GeGLU instead of SwiGLU. The difference is using
149
+ GELU instead of SiLU for the gating activation.
150
+ """
151
+
152
+ def __init__(
153
+ self,
154
+ input_dim: int,
155
+ hidden_dim: Optional[int] = None,
156
+ output_dim: Optional[int] = None,
157
+ bias: bool = False,
158
+ ):
159
+ super().__init__()
160
+
161
+ self.input_dim = input_dim
162
+ self.output_dim = output_dim or input_dim
163
+
164
+ if hidden_dim is None:
165
+ hidden_dim = int(8 * input_dim / 3)
166
+
167
+ self.hidden_dim = hidden_dim
168
+
169
+ self.w_gate = nn.Linear(input_dim, hidden_dim, bias=bias)
170
+ self.w_up = nn.Linear(input_dim, hidden_dim, bias=bias)
171
+ self.w_down = nn.Linear(hidden_dim, self.output_dim, bias=bias)
172
+
173
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
174
+ """Apply GeGLU activation."""
175
+ gate = F.gelu(self.w_gate(x))
176
+ value = self.w_up(x)
177
+ hidden = gate * value
178
+ output = self.w_down(hidden)
179
+ return output
180
+
181
+
182
+ def calculate_ffn_params(d_model: int, activation: str = "swiglu") -> dict:
183
+ """Calculate FFN parameters for different activation functions.
184
+
185
+ This helper ensures parameter parity across different activation types.
186
+ """
187
+ if activation == "relu" or activation == "gelu":
188
+ # Standard FFN: 2 matrices
189
+ hidden_dim = 4 * d_model
190
+ num_params = 2 * d_model * hidden_dim
191
+ elif activation in ["swiglu", "geglu"]:
192
+ # Gated FFN: 3 matrices, adjust hidden dimension
193
+ hidden_dim = int(8 * d_model / 3)
194
+ # Round up to a multiple of 256 for hardware efficiency
195
+ hidden_dim = 256 * ((hidden_dim + 255) // 256)
196
+ num_params = d_model * hidden_dim * 2 + hidden_dim * d_model
197
+ else:
198
+ raise ValueError(f"Unknown activation: {activation}")
199
+
200
+ return {
201
+ "activation": activation,
202
+ "d_model": d_model,
203
+ "hidden_dim": hidden_dim,
204
+ "num_params": num_params,
205
+ "params_millions": num_params / 1e6,
206
+ }
207
+
208
+
209
+ # Example usage and parameter comparison
210
+ if __name__ == "__main__":
211
+ d_model = 768
212
+
213
+ # Compare parameter counts
214
+ print("FFN Parameter Comparison:")
215
+ for act in ["relu", "gelu", "swiglu"]:
216
+ params = calculate_ffn_params(d_model, act)
217
+ print(f"{act.upper()}:")
218
+ print(f" Hidden dim: {params['hidden_dim']}")
219
+ print(f" Parameters: {params['params_millions']:.2f}M")
220
+
221
+ # Test SwiGLU
222
+ batch_size, seq_len = 2, 512
223
+ x = torch.randn(batch_size, seq_len, d_model)
224
+
225
+ swiglu = SwiGLU(d_model)
226
+ output = swiglu(x)
227
+
228
+ print(f"\nSwiGLU Test:")
229
+ print(f"Input shape: {x.shape}")
230
+ print(f"Output shape: {output.shape}")
231
+ print(f"SwiGLU parameters: {sum(p.numel() for p in swiglu.parameters()) / 1e6:.2f}M")
src/model/transformer_block.py ADDED
@@ -0,0 +1,454 @@
1
+ """Transformer block with pre-normalization architecture and memory optimizations.
2
+
3
+ Critical implementation details:
4
+ 1. Pre-normalization: RMSNorm BEFORE attention and FFN
5
+ 2. Residual connections after each sub-layer
6
+ 3. Modern component stack: RoPE + RMSNorm + SwiGLU
7
+ 4. Gradient checkpointing support for memory efficiency
8
+ """
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from typing import Optional, Tuple, Dict, Any
13
+ from torch.utils.checkpoint import checkpoint
14
+
15
+ from .rmsnorm import RMSNorm
16
+ from .attention import MultiHeadAttention
17
+ from .swiglu import SwiGLU
18
+
19
+
20
+ class TransformerBlock(nn.Module):
21
+ """Single transformer block with pre-normalization.
22
+
23
+ This follows the modern architecture used in LLaMA, Mistral, etc:
24
+ - Pre-normalization with RMSNorm
25
+ - Multi-head attention with RoPE
26
+ - SwiGLU activation in FFN
27
+ - Residual connections
28
+ - Gradient checkpointing support
29
+ """
30
+
31
+ def __init__(
32
+ self,
33
+ d_model: int = 768,
34
+ n_heads: int = 12,
35
+ d_ffn: Optional[int] = None,
36
+ dropout: float = 0.1,
37
+ max_seq_len: int = 2048,
38
+ rope_base: int = 10000,
39
+ rope_percentage: float = 0.5,
40
+ rms_norm_eps: float = 1e-6,
41
+ use_flash_attention: bool = True,
42
+ use_gradient_checkpointing: bool = False,
43
+ ):
44
+ super().__init__()
45
+
46
+ self.d_model = d_model
47
+ self.n_heads = n_heads
48
+ self.use_gradient_checkpointing = use_gradient_checkpointing
49
+
50
+ # Pre-normalization layers
51
+ self.attn_norm = RMSNorm(d_model, eps=rms_norm_eps)
52
+ self.ffn_norm = RMSNorm(d_model, eps=rms_norm_eps)
53
+
54
+ # Multi-head attention with RoPE
55
+ self.attention = MultiHeadAttention(
56
+ d_model=d_model,
57
+ n_heads=n_heads,
58
+ dropout=dropout,
59
+ max_seq_len=max_seq_len,
60
+ rope_base=rope_base,
61
+ rope_percentage=rope_percentage,
62
+ use_flash_attention=use_flash_attention,
63
+ )
64
+
65
+ # SwiGLU FFN
66
+ # Default hidden dimension: 8/3 * d_model for parameter parity
67
+ if d_ffn is None:
68
+ d_ffn = int(8 * d_model / 3)
69
+ # Round to multiple of 256 for hardware efficiency
70
+ d_ffn = 256 * ((d_ffn + 255) // 256)
71
+
72
+ self.ffn = SwiGLU(
73
+ input_dim=d_model,
74
+ hidden_dim=d_ffn,
75
+ output_dim=d_model,
76
+ bias=False,
77
+ )
78
+
79
+ def _attention_block(
80
+ self,
81
+ x: torch.Tensor,
82
+ attention_mask: Optional[torch.Tensor] = None,
83
+ position_ids: Optional[torch.Tensor] = None,
84
+ use_cache: bool = False,
85
+ past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
86
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
87
+ """Attention sub-block with pre-norm."""
88
+ # Pre-normalization
89
+ x_norm = self.attn_norm(x)
90
+
91
+ # Multi-head attention
92
+ attn_output, kv_cache = self.attention(
93
+ x_norm,
94
+ attention_mask=attention_mask,
95
+ position_ids=position_ids,
96
+ use_cache=use_cache,
97
+ past_kv=past_kv,
98
+ )
99
+
100
+ # Residual connection is added by the caller in forward()
101
+ return attn_output, kv_cache
102
+
103
+ def _ffn_block(self, x: torch.Tensor) -> torch.Tensor:
104
+ """Feed-forward sub-block with pre-norm."""
105
+ # Pre-normalization
106
+ x_norm = self.ffn_norm(x)
107
+
108
+ # Feed-forward
109
+ ffn_output = self.ffn(x_norm)
110
+
111
+ return ffn_output
112
+
113
+ def forward(
114
+ self,
115
+ x: torch.Tensor,
116
+ attention_mask: Optional[torch.Tensor] = None,
117
+ position_ids: Optional[torch.Tensor] = None,
118
+ use_cache: bool = False,
119
+ past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
120
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
121
+ """Forward pass of transformer block.
122
+
123
+ Args:
124
+ x: Input tensor [batch, seq_len, d_model]
125
+ attention_mask: Optional attention mask
126
+ position_ids: Optional position IDs for RoPE
127
+ use_cache: Whether to return KV cache
128
+ past_kv: Past key-value cache
129
+
130
+ Returns:
131
+ Output tensor and optional KV cache
132
+ """
133
+ # Attention block with residual
134
+ if self.use_gradient_checkpointing and self.training:
135
+ # Use gradient checkpointing to save memory during training
136
+ def attention_fn(x_in):
137
+ attn_out, _ = self._attention_block(
138
+ x_in,
139
+ attention_mask=attention_mask,
140
+ position_ids=position_ids,
141
+ use_cache=False, # Can't use cache with checkpointing
142
+ past_kv=None,
143
+ )
144
+ return attn_out
145
+
146
+ attn_output = checkpoint(attention_fn, x, use_reentrant=False)
147
+ kv_cache = None
148
+ else:
149
+ attn_output, kv_cache = self._attention_block(
150
+ x,
151
+ attention_mask=attention_mask,
152
+ position_ids=position_ids,
153
+ use_cache=use_cache,
154
+ past_kv=past_kv,
155
+ )
156
+
157
+ # Add residual for attention
158
+ x = x + attn_output
159
+
160
+ # FFN block with residual
161
+ if self.use_gradient_checkpointing and self.training:
162
+ # Use gradient checkpointing for FFN as well
163
+ ffn_output = checkpoint(self._ffn_block, x, use_reentrant=False)
164
+ else:
165
+ ffn_output = self._ffn_block(x)
166
+
167
+ # Add residual for FFN
168
+ x = x + ffn_output
169
+
170
+ return x, kv_cache
171
+
172
+
173
+ class WikiMiniModel(nn.Module):
174
+ """Complete WikiMini 95M language model.
175
+
176
+ Architecture:
177
+ - Token embeddings with weight tying
178
+ - Stack of transformer blocks
179
+ - Final RMSNorm
180
+ - LM head (tied with embeddings)
181
+ """
182
+
183
+ def __init__(self, config: Dict[str, Any]):
184
+ super().__init__()
185
+
186
+ # Extract config values with defaults
187
+ self.vocab_size = config.get('vocab_size', 32000)
188
+ self.d_model = config.get('d_model', 768)
189
+ self.n_layers = config.get('n_layers', 12)
190
+ self.n_heads = config.get('n_heads', 12)
191
+ self.d_ffn = config.get('d_ffn', None)
192
+ self.max_seq_len = config.get('max_seq_len', 2048)
193
+ self.dropout = config.get('dropout', 0.1)
194
+ self.rope_percentage = config.get('rope_percentage', 0.5)
195
+ self.rope_base = config.get('rope_base', 10000)
196
+ self.rms_norm_eps = config.get('rms_norm_eps', 1e-6)
197
+ self.tie_embeddings = config.get('tie_embeddings', True)
198
+ self.use_flash_attention = config.get('use_flash_attention', True)
199
+ self.use_gradient_checkpointing = config.get('gradient_checkpointing', False)
200
+
201
+ # Token embeddings
202
+ self.token_embedding = nn.Embedding(self.vocab_size, self.d_model)
203
+
204
+ # Transformer blocks
205
+ self.blocks = nn.ModuleList([
206
+ TransformerBlock(
207
+ d_model=self.d_model,
208
+ n_heads=self.n_heads,
209
+ d_ffn=self.d_ffn,
210
+ dropout=self.dropout,
211
+ max_seq_len=self.max_seq_len,
212
+ rope_base=self.rope_base,
213
+ rope_percentage=self.rope_percentage,
214
+ rms_norm_eps=self.rms_norm_eps,
215
+ use_flash_attention=self.use_flash_attention,
216
+ use_gradient_checkpointing=self.use_gradient_checkpointing,
217
+ )
218
+ for _ in range(self.n_layers)
219
+ ])
220
+
221
+ # Final normalization
222
+ self.final_norm = RMSNorm(self.d_model, eps=self.rms_norm_eps)
223
+
224
+ # Language modeling head
225
+ self.lm_head = nn.Linear(self.d_model, self.vocab_size, bias=False)
226
+
227
+ # Weight tying
228
+ if self.tie_embeddings:
229
+ self.lm_head.weight = self.token_embedding.weight
230
+
231
+ # Initialize weights
232
+ self._init_weights()
233
+
234
+ def _init_weights(self):
235
+ """Initialize weights with scaled normal distribution."""
236
+ for module in self.modules():
237
+ if isinstance(module, nn.Linear):
238
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
239
+ if module.bias is not None:
240
+ torch.nn.init.zeros_(module.bias)
241
+ elif isinstance(module, nn.Embedding):
242
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
243
+
244
+ def enable_gradient_checkpointing(self):
245
+ """Enable gradient checkpointing for all transformer blocks."""
246
+ self.use_gradient_checkpointing = True
247
+ for block in self.blocks:
248
+ block.use_gradient_checkpointing = True
249
+
250
+ def disable_gradient_checkpointing(self):
251
+ """Disable gradient checkpointing for all transformer blocks."""
252
+ self.use_gradient_checkpointing = False
253
+ for block in self.blocks:
254
+ block.use_gradient_checkpointing = False
255
+
256
+ def count_parameters(self) -> dict:
257
+ """Count model parameters by component.
258
+
259
+ Returns:
260
+ Dictionary with parameter counts for each component
261
+ """
262
+ # Count by component type
263
+ embedding_params = sum(p.numel() for p in self.token_embedding.parameters())
264
+
265
+ attention_params = 0
266
+ ffn_params = 0
267
+ norm_params = 0
268
+
269
+ for block in self.blocks:
270
+ # Attention parameters
271
+ attention_params += sum(p.numel() for p in block.attention.parameters())
272
+ # FFN parameters
273
+ ffn_params += sum(p.numel() for p in block.ffn.parameters())
274
+ # Norm parameters (attention + ffn norms)
275
+ norm_params += sum(p.numel() for p in block.attn_norm.parameters())
276
+ norm_params += sum(p.numel() for p in block.ffn_norm.parameters())
277
+
278
+ # Final norm
279
+ norm_params += sum(p.numel() for p in self.final_norm.parameters())
280
+
281
+ # LM head (only if not tied)
282
+ if not self.tie_embeddings:
283
+ lm_head_params = sum(p.numel() for p in self.lm_head.parameters())
284
+ else:
285
+ lm_head_params = 0 # Shared with embeddings
286
+
287
+ total_params = sum(p.numel() for p in self.parameters())
288
+
289
+ return {
290
+ 'total': total_params,
291
+ 'total_millions': total_params / 1e6,
292
+ 'embedding': embedding_params,
293
+ 'attention': attention_params,
294
+ 'ffn': ffn_params,
295
+ 'norm': norm_params,
296
+ 'lm_head': lm_head_params,
297
+ }
298
+
299
+ def forward(
300
+ self,
301
+ input_ids: torch.Tensor,
302
+ attention_mask: Optional[torch.Tensor] = None,
303
+ position_ids: Optional[torch.Tensor] = None,
304
+ labels: Optional[torch.Tensor] = None,
305
+ use_cache: bool = False,
306
+ past_key_values: Optional[list] = None,
307
+ ) -> Dict[str, torch.Tensor]:
308
+ """Forward pass of the model.
309
+
310
+ Args:
311
+ input_ids: Token IDs [batch, seq_len]
312
+ attention_mask: Optional attention mask
313
+ position_ids: Optional position IDs
314
+ labels: Optional labels for language modeling loss
315
+ use_cache: Whether to return KV cache
316
+ past_key_values: Past KV cache for inference
317
+
318
+ Returns:
319
+ Dictionary with 'logits' and optionally 'loss' and 'past_key_values'
320
+ """
321
+ batch_size, seq_len = input_ids.shape
322
+
323
+ # Token embeddings
324
+ x = self.token_embedding(input_ids)
325
+
326
+ # Apply dropout to embeddings
327
+ x = nn.functional.dropout(x, p=self.dropout, training=self.training)
328
+
329
+ # Process through transformer blocks
330
+ past_key_values_out = [] if use_cache else None
331
+
332
+ for i, block in enumerate(self.blocks):
333
+ # Get past KV for this layer if available
334
+ past_kv = past_key_values[i] if past_key_values is not None else None
335
+
336
+ # Process through block
337
+ x, kv_cache = block(
338
+ x,
339
+ attention_mask=attention_mask,
340
+ position_ids=position_ids,
341
+ use_cache=use_cache,
342
+ past_kv=past_kv,
343
+ )
344
+
345
+ # Store KV cache if needed
346
+ if use_cache:
347
+ past_key_values_out.append(kv_cache)
348
+
349
+ # Final normalization
350
+ x = self.final_norm(x)
351
+
352
+ # Language modeling head
353
+ logits = self.lm_head(x)
354
+
355
+ # Prepare output
356
+ output = {'logits': logits}
357
+
358
+ # Calculate loss if labels provided
359
+ if labels is not None:
360
+ # Shift for next-token prediction
361
+ shift_logits = logits[..., :-1, :].contiguous()
362
+ shift_labels = labels[..., 1:].contiguous()
363
+
364
+ # Flatten for cross-entropy
365
+ shift_logits = shift_logits.view(-1, self.vocab_size)
366
+ shift_labels = shift_labels.view(-1)
367
+
368
+ # Calculate cross-entropy loss
369
+ loss = nn.functional.cross_entropy(
370
+ shift_logits,
371
+ shift_labels,
372
+ ignore_index=-100, # Standard ignore index
373
+ )
374
+
375
+ output['loss'] = loss
376
+
377
+ # Add KV cache to output if requested
378
+ if use_cache:
379
+ output['past_key_values'] = past_key_values_out
380
+
381
+ return output
382
+
383
+
384
+ def create_model(config: Dict[str, Any]) -> WikiMiniModel:
385
+ """Create a WikiMini model from configuration.
386
+
387
+ Args:
388
+ config: Model configuration dictionary
389
+
390
+ Returns:
391
+ WikiMiniModel instance
392
+ """
393
+ return WikiMiniModel(config)
394
+
395
+
396
+ # Test the complete model
397
+ if __name__ == "__main__":
398
+ # Test configuration for ~95M parameters
399
+ config = {
400
+ 'vocab_size': 32000,
401
+ 'd_model': 768,
402
+ 'n_layers': 12,
403
+ 'n_heads': 12,
404
+ 'd_ffn': 2048, # Adjusted for SwiGLU
405
+ 'max_seq_len': 2048,
406
+ 'dropout': 0.1,
407
+ 'rope_percentage': 0.5,
408
+ 'rope_base': 10000,
409
+ 'rms_norm_eps': 1e-6,
410
+ 'tie_embeddings': True,
411
+ 'use_flash_attention': True,
412
+ 'gradient_checkpointing': True, # Enable for memory efficiency
413
+ }
414
+
415
+ # Create model
416
+ model = WikiMiniModel(config)
417
+
418
+ # Count parameters
419
+ total_params = sum(p.numel() for p in model.parameters())
420
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
421
+
422
+ print(f"WikiMini Model:")
423
+ print(f" Total parameters: {total_params:,} ({total_params/1e6:.2f}M)")
424
+ print(f" Trainable parameters: {trainable_params:,} ({trainable_params/1e6:.2f}M)")
425
+ print(f" Layers: {model.n_layers}")
426
+ print(f" Hidden size: {model.d_model}")
427
+ print(f" Attention heads: {model.n_heads}")
428
+
429
+ # Test forward pass
430
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
431
+ model = model.to(device)
432
+ model.eval()
433
+
434
+ # Small test batch
435
+ batch_size = 2
436
+ seq_len = 128
437
+ input_ids = torch.randint(0, config['vocab_size'], (batch_size, seq_len), device=device)
438
+
439
+ # Enable gradient checkpointing
440
+ model.enable_gradient_checkpointing()
441
+
442
+ # Forward pass
443
+ with torch.no_grad():
444
+ outputs = model(input_ids=input_ids)
445
+
446
+ print(f"\nTest forward pass:")
447
+ print(f" Input shape: {input_ids.shape}")
448
+ print(f" Output logits shape: {outputs['logits'].shape}")
449
+ print(f" Device: {device}")
450
+
451
+ if torch.cuda.is_available():
452
+ print(f" Memory allocated: {torch.cuda.memory_allocated(device) / 1024**3:.2f} GB")
453
+
454
+ print("\n✓ Model test passed!")