nameissakthi committed on
Commit
c27df58
·
1 Parent(s): e3ef0ba

Remove pycache, add gitignore

.gitignore ADDED
@@ -0,0 +1,5 @@
1
+ __pycache__/
2
+ *.py[cod]
3
+ *$py.class
4
+ *.so
5
+ .DS_Store
README.md CHANGED
@@ -111,10 +111,6 @@ Training Time: ~40 minutes
 
 ```bash
 pip install torch tokenizers huggingface_hub
-
- # Clone model architecture code
- git clone https://github.com/nameissakthi/slm-qualcomm
- cd slm-qualcomm
 ```
 
 ### Download Model
@@ -132,6 +128,8 @@ tokenizer_path = hf_hub_download(repo_id="nameissakthi/PebbleLM-117M-Chat", file
 ```python
 import torch
 from tokenizers import Tokenizer
+
+ # Model architecture is included in this repo
 from src.model.transformer import SLMForCausalLM
 from src.model.config import SLMConfig
 
src/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # SLM Qualcomm - Conversational Small Language Model
2
+ __version__ = "1.0.0"
src/data/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ # Data loading and tokenizer components
2
+
3
+ from .tokenizer import SLMTokenizer
4
+ from .dataset import (
5
+ ConversationalDataset,
6
+ StreamingTextDataset,
7
+ PackedDataset,
8
+ create_train_val_split,
9
+ load_jsonl,
10
+ save_jsonl,
11
+ )
12
+ from .dataloader import (
13
+ DataModule,
14
+ StreamingDataModule,
15
+ create_dataloader,
16
+ estimate_dataset_tokens,
17
+ )
18
+
19
+ __all__ = [
20
+ "SLMTokenizer",
21
+ "ConversationalDataset",
22
+ "StreamingTextDataset",
23
+ "PackedDataset",
24
+ "create_train_val_split",
25
+ "load_jsonl",
26
+ "save_jsonl",
27
+ "DataModule",
28
+ "StreamingDataModule",
29
+ "create_dataloader",
30
+ "estimate_dataset_tokens",
31
+ ]
src/data/dataloader.py ADDED
@@ -0,0 +1,300 @@
1
+ """
2
+ DataLoader utilities for SLM training.
3
+
4
+ Provides efficient batching and data loading for training.
5
+ """
6
+
7
+ import os
8
+ from typing import Dict, Optional, List
9
+
10
+ import torch
11
+ from torch.utils.data import DataLoader, Dataset, DistributedSampler
12
+
13
+ from .dataset import ConversationalDataset, StreamingTextDataset, PackedDataset
14
+ from .tokenizer import SLMTokenizer
15
+
16
+
17
+ def create_dataloader(
18
+ dataset: Dataset,
19
+ batch_size: int,
20
+ shuffle: bool = True,
21
+ num_workers: int = 4,
22
+ pin_memory: bool = None, # Auto-detect based on device
23
+ drop_last: bool = True,
24
+ distributed: bool = False,
25
+ world_size: int = 1,
26
+ rank: int = 0,
27
+ ) -> DataLoader:
28
+ """Create a DataLoader with optimal settings.
29
+
30
+ Args:
31
+ dataset: The dataset to load from
32
+ batch_size: Batch size per device
33
+ shuffle: Whether to shuffle data
34
+ num_workers: Number of data loading workers
35
+ pin_memory: Pin memory for faster GPU transfer
36
+ drop_last: Drop last incomplete batch
37
+ distributed: Whether using distributed training
38
+ world_size: Number of distributed processes
39
+ rank: Current process rank
40
+
41
+ Returns:
42
+ Configured DataLoader
43
+ """
44
+ sampler = None
45
+ if distributed:
46
+ sampler = DistributedSampler(
47
+ dataset,
48
+ num_replicas=world_size,
49
+ rank=rank,
50
+ shuffle=shuffle,
51
+ )
52
+ shuffle = False # Sampler handles shuffling
53
+
54
+ # Auto-detect pin_memory: disable for MPS (not supported)
55
+ if pin_memory is None:
56
+ import torch
57
+ pin_memory = torch.cuda.is_available() # Only True for CUDA
58
+
59
+ return DataLoader(
60
+ dataset,
61
+ batch_size=batch_size,
62
+ shuffle=shuffle if sampler is None else False,
63
+ sampler=sampler,
64
+ num_workers=num_workers,
65
+ pin_memory=pin_memory,
66
+ drop_last=drop_last,
67
+ collate_fn=default_collate_fn,
68
+ )
69
+
70
+
71
+ def default_collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
72
+ """Collate function for batching samples.
73
+
74
+ Args:
75
+ batch: List of sample dictionaries
76
+
77
+ Returns:
78
+ Batched dictionary with stacked tensors
79
+ """
80
+ return {
81
+ "input_ids": torch.stack([s["input_ids"] for s in batch]),
82
+ "attention_mask": torch.stack([s["attention_mask"] for s in batch]),
83
+ "labels": torch.stack([s["labels"] for s in batch]),
84
+ }
85
+
86
+
87
+ class DataModule:
88
+ """Data module for managing train/val dataloaders.
89
+
90
+ Provides a unified interface for data loading during training.
91
+ """
92
+
93
+ def __init__(
94
+ self,
95
+ data_dir: str,
96
+ tokenizer_path: str,
97
+ max_length: int = 1024,
98
+ batch_size: int = 32,
99
+ num_workers: int = 4,
100
+ val_batch_size: Optional[int] = None,
101
+ ):
102
+ """Initialize data module.
103
+
104
+ Args:
105
+ data_dir: Directory containing processed data
106
+ tokenizer_path: Path to tokenizer.json
107
+ max_length: Maximum sequence length
108
+ batch_size: Training batch size
109
+ num_workers: Number of data loading workers
110
+ val_batch_size: Validation batch size (defaults to batch_size)
111
+ """
112
+ self.data_dir = data_dir
113
+ self.max_length = max_length
114
+ self.batch_size = batch_size
115
+ self.val_batch_size = val_batch_size or batch_size
116
+ self.num_workers = num_workers
117
+
118
+ # Load tokenizer
119
+ self.tokenizer = SLMTokenizer.from_file(tokenizer_path)
120
+
121
+ # Datasets (created on first access)
122
+ self._train_dataset = None
123
+ self._val_dataset = None
124
+
125
+ @property
126
+ def train_dataset(self) -> Dataset:
127
+ """Get or create training dataset."""
128
+ if self._train_dataset is None:
129
+ self._train_dataset = ConversationalDataset(
130
+ data_path=self.data_dir,
131
+ tokenizer=self.tokenizer,
132
+ max_length=self.max_length,
133
+ split="train",
134
+ )
135
+ return self._train_dataset
136
+
137
+ @property
138
+ def val_dataset(self) -> Dataset:
139
+ """Get or create validation dataset."""
140
+ if self._val_dataset is None:
141
+ self._val_dataset = ConversationalDataset(
142
+ data_path=self.data_dir,
143
+ tokenizer=self.tokenizer,
144
+ max_length=self.max_length,
145
+ split="val",
146
+ )
147
+ return self._val_dataset
148
+
149
+ def train_dataloader(
150
+ self,
151
+ distributed: bool = False,
152
+ world_size: int = 1,
153
+ rank: int = 0,
154
+ ) -> DataLoader:
155
+ """Get training dataloader."""
156
+ return create_dataloader(
157
+ self.train_dataset,
158
+ batch_size=self.batch_size,
159
+ shuffle=True,
160
+ num_workers=self.num_workers,
161
+ drop_last=True,
162
+ distributed=distributed,
163
+ world_size=world_size,
164
+ rank=rank,
165
+ )
166
+
167
+ def val_dataloader(self) -> DataLoader:
168
+ """Get validation dataloader."""
169
+ return create_dataloader(
170
+ self.val_dataset,
171
+ batch_size=self.val_batch_size,
172
+ shuffle=False,
173
+ num_workers=self.num_workers,
174
+ drop_last=False,
175
+ )
176
+
177
+
178
+ class StreamingDataModule:
179
+ """Data module for streaming large datasets.
180
+
181
+ Memory-efficient loading for large text corpora.
182
+ """
183
+
184
+ def __init__(
185
+ self,
186
+ data_files: List[str],
187
+ tokenizer_path: str,
188
+ max_length: int = 1024,
189
+ batch_size: int = 32,
190
+ num_workers: int = 4,
191
+ ):
192
+ """Initialize streaming data module.
193
+
194
+ Args:
195
+ data_files: List of text file paths
196
+ tokenizer_path: Path to tokenizer.json
197
+ max_length: Maximum sequence length
198
+ batch_size: Batch size
199
+ num_workers: Number of data loading workers
200
+ """
201
+ self.data_files = data_files
202
+ self.max_length = max_length
203
+ self.batch_size = batch_size
204
+ self.num_workers = num_workers
205
+
206
+ # Load tokenizer
207
+ self.tokenizer = SLMTokenizer.from_file(tokenizer_path)
208
+
209
+ def train_dataloader(self) -> DataLoader:
210
+ """Get training dataloader for streaming data."""
211
+ dataset = StreamingTextDataset(
212
+ data_files=self.data_files,
213
+ tokenizer=self.tokenizer,
214
+ max_length=self.max_length,
215
+ shuffle=True,
216
+ )
217
+
218
+ return DataLoader(
219
+ dataset,
220
+ batch_size=self.batch_size,
221
+ num_workers=self.num_workers,
222
+ pin_memory=True,
223
+ collate_fn=default_collate_fn,
224
+ )
225
+
226
+
227
+ def estimate_dataset_tokens(data_dir: str, tokenizer_path: str) -> Dict[str, int]:
228
+ """Estimate total tokens in a dataset.
229
+
230
+ Args:
231
+ data_dir: Directory containing data files
232
+ tokenizer_path: Path to tokenizer
233
+
234
+ Returns:
235
+ Dictionary with token counts
236
+ """
237
+ import json
238
+ from pathlib import Path
239
+
240
+ tokenizer = SLMTokenizer.from_file(tokenizer_path)
241
+
242
+ total_tokens = 0
243
+ total_samples = 0
244
+
245
+ for file_path in Path(data_dir).glob("*.json*"):
246
+ with open(file_path, "r") as f:
247
+ if file_path.suffix == ".jsonl":
248
+ samples = [json.loads(line) for line in f if line.strip()]
249
+ else:
250
+ samples = json.load(f)
251
+ if not isinstance(samples, list):
252
+ samples = [samples]
253
+
254
+ for sample in samples:
255
+ if "user" in sample and "assistant" in sample:
256
+ tokens = tokenizer.encode_conversation(
257
+ sample["user"], sample["assistant"]
258
+ )
259
+ elif "text" in sample:
260
+ tokens = tokenizer.encode(sample["text"])
261
+ else:
262
+ continue
263
+
264
+ total_tokens += len(tokens)
265
+ total_samples += 1
266
+
267
+ return {
268
+ "total_tokens": total_tokens,
269
+ "total_samples": total_samples,
270
+ "avg_tokens_per_sample": total_tokens / max(total_samples, 1),
271
+ }
272
+
273
+
274
+ def get_dataloader_stats(dataloader: DataLoader) -> Dict[str, float]:
275
+ """Get statistics from a dataloader.
276
+
277
+ Args:
278
+ dataloader: The dataloader to analyze
279
+
280
+ Returns:
281
+ Dictionary with statistics
282
+ """
283
+ total_batches = 0
284
+ total_tokens = 0
285
+ total_non_pad_tokens = 0
286
+
287
+ for batch in dataloader:
288
+ total_batches += 1
289
+ total_tokens += batch["input_ids"].numel()
290
+ total_non_pad_tokens += batch["attention_mask"].sum().item()
291
+
292
+ # Only sample first 100 batches
293
+ if total_batches >= 100:
294
+ break
295
+
296
+ return {
297
+ "batches_sampled": total_batches,
298
+ "tokens_per_batch": total_tokens / max(total_batches, 1),
299
+ "non_pad_ratio": total_non_pad_tokens / max(total_tokens, 1),
300
+ }
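
For orientation, a minimal usage sketch of the `DataModule` defined above. The directory and tokenizer paths are hypothetical placeholders, not files shipped in this commit; it assumes `train.jsonl`/`val.jsonl` exist under the data directory as `ConversationalDataset` expects.

```python
# Minimal sketch (paths are hypothetical placeholders).
from src.data.dataloader import DataModule

dm = DataModule(
    data_dir="data/processed",                  # expects train.jsonl / val.jsonl inside
    tokenizer_path="tokenizer/tokenizer.json",  # trained tokenizer file
    max_length=1024,
    batch_size=32,
    num_workers=4,
)

train_loader = dm.train_dataloader()
val_loader = dm.val_dataloader()

for batch in train_loader:
    # Each batch holds stacked [batch_size, max_length] LongTensors.
    print(batch["input_ids"].shape, batch["attention_mask"].shape, batch["labels"].shape)
    break
```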
src/data/dataset.py ADDED
@@ -0,0 +1,434 @@
1
+ """
2
+ Dataset classes for SLM training.
3
+
4
+ Handles loading, preprocessing, and tokenization of conversational data.
5
+ """
6
+
7
+ import os
8
+ import json
9
+ import random
10
+ from typing import List, Dict, Optional, Iterator, Tuple
11
+ from pathlib import Path
12
+
13
+ import torch
14
+ from torch.utils.data import Dataset, IterableDataset
15
+
16
+ from .tokenizer import SLMTokenizer
17
+
18
+
19
+ class ConversationalDataset(Dataset):
20
+ """Dataset for conversational/instruction-following data.
21
+
22
+ Loads pre-tokenized data from disk for efficient training.
23
+ Format: Each sample is a tokenized conversation with user/assistant turns.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ data_path: str,
29
+ tokenizer: SLMTokenizer,
30
+ max_length: int = 1024,
31
+ split: str = "train",
32
+ ):
33
+ """Initialize the dataset.
34
+
35
+ Args:
36
+ data_path: Path to the processed data directory
37
+ tokenizer: Tokenizer instance
38
+ max_length: Maximum sequence length
39
+ split: "train" or "val"
40
+ """
41
+ self.tokenizer = tokenizer
42
+ self.max_length = max_length
43
+ self.split = split
44
+
45
+ # Load data
46
+ self.samples = self._load_data(data_path)
47
+ print(f"Loaded {len(self.samples)} samples for {split} split")
48
+
49
+ def _load_data(self, data_path: str) -> List[Dict]:
50
+ """Load data from JSON or JSONL files."""
51
+ samples = []
52
+
53
+ # Check for split-specific JSONL file first (preferred for large datasets)
54
+ split_jsonl = os.path.join(data_path, f"{self.split}.jsonl")
55
+ if os.path.exists(split_jsonl):
56
+ with open(split_jsonl, "r", encoding="utf-8") as f:
57
+ for line in f:
58
+ line = line.strip()
59
+ if line:
60
+ samples.append(json.loads(line))
61
+ return samples
62
+
63
+ # Check for split-specific JSON file
64
+ split_file = os.path.join(data_path, f"{self.split}.json")
65
+ if os.path.exists(split_file):
66
+ with open(split_file, "r", encoding="utf-8") as f:
67
+ # Detect format: try a single JSON array first, then fall back to JSONL
68
+ content = f.read()
69
+ f.seek(0)
70
+ try:
71
+ # Try loading as single JSON array
72
+ samples = json.loads(content)
73
+ if isinstance(samples, list):
74
+ return samples
75
+ except json.JSONDecodeError:
76
+ pass
77
+
78
+ # Load as JSONL (one JSON per line)
79
+ for line in f:
80
+ line = line.strip()
81
+ if line:
82
+ samples.append(json.loads(line))
83
+ return samples
84
+
85
+ # Check for combined file with splits
86
+ combined_file = os.path.join(data_path, "data.json")
87
+ if os.path.exists(combined_file):
88
+ with open(combined_file, "r") as f:
89
+ all_data = json.load(f)
90
+ if isinstance(all_data, dict) and self.split in all_data:
91
+ return all_data[self.split]
92
+ return all_data
93
+
94
+ # Load all .json and .jsonl files in directory
95
+ for ext in ["*.jsonl", "*.json"]:
96
+ for file in sorted(Path(data_path).glob(ext)):
97
+ with open(file, "r", encoding="utf-8") as f:
98
+ if file.suffix == ".jsonl":
99
+ for line in f:
100
+ line = line.strip()
101
+ if line:
102
+ samples.append(json.loads(line))
103
+ else:
104
+ data = json.load(f)
105
+ if isinstance(data, list):
106
+ samples.extend(data)
107
+ else:
108
+ samples.append(data)
109
+
110
+ return samples
111
+
112
+ def __len__(self) -> int:
113
+ return len(self.samples)
114
+
115
+ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
116
+ """Get a single sample.
117
+
118
+ Returns:
119
+ Dictionary with:
120
+ - input_ids: Token IDs for the full sequence
121
+ - attention_mask: 1 for real tokens, 0 for padding
122
+ - labels: Same as input_ids but with -100 for padding (for loss)
123
+ """
124
+ sample = self.samples[idx]
125
+
126
+ # Handle different data formats
127
+ if "input_ids" in sample:
128
+ # Pre-tokenized data
129
+ input_ids = sample["input_ids"]
130
+ elif "user" in sample and "assistant" in sample:
131
+ # Raw conversation format
132
+ input_ids = self.tokenizer.encode_conversation(
133
+ user_message=sample["user"],
134
+ assistant_message=sample["assistant"],
135
+ max_length=self.max_length,
136
+ )
137
+ elif "text" in sample:
138
+ # Raw text format
139
+ input_ids = self.tokenizer.encode(
140
+ sample["text"],
141
+ add_special_tokens=True,
142
+ max_length=self.max_length,
143
+ truncation=True,
144
+ )
145
+ elif "question" in sample and "answer" in sample:
146
+ # Q&A format
147
+ input_ids = self.tokenizer.encode_conversation(
148
+ user_message=sample["question"],
149
+ assistant_message=sample["answer"],
150
+ max_length=self.max_length,
151
+ )
152
+ else:
153
+ raise ValueError(f"Unknown sample format: {list(sample.keys())}")
154
+
155
+ # Pad or truncate
156
+ if len(input_ids) > self.max_length:
157
+ input_ids = input_ids[:self.max_length]
158
+ # Ensure EOS at the end
159
+ if input_ids[-1] != self.tokenizer.eos_token_id:
160
+ input_ids[-1] = self.tokenizer.eos_token_id
161
+
162
+ # Create attention mask (before padding)
163
+ attention_mask = [1] * len(input_ids)
164
+
165
+ # Pad if needed
166
+ padding_length = self.max_length - len(input_ids)
167
+ if padding_length > 0:
168
+ input_ids = input_ids + [self.tokenizer.pad_token_id] * padding_length
169
+ attention_mask = attention_mask + [0] * padding_length
170
+
171
+ # Labels for language modeling (shift happens in loss function)
172
+ # Use -100 for padding tokens so they're ignored in loss
173
+ labels = [
174
+ id if mask == 1 else -100
175
+ for id, mask in zip(input_ids, attention_mask)
176
+ ]
177
+
178
+ return {
179
+ "input_ids": torch.tensor(input_ids, dtype=torch.long),
180
+ "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
181
+ "labels": torch.tensor(labels, dtype=torch.long),
182
+ }
183
+
184
+
185
+ class StreamingTextDataset(IterableDataset):
186
+ """Streaming dataset for large text files.
187
+
188
+ Memory-efficient dataset that streams data from disk.
189
+ Useful for training on large text corpora.
190
+ """
191
+
192
+ def __init__(
193
+ self,
194
+ data_files: List[str],
195
+ tokenizer: SLMTokenizer,
196
+ max_length: int = 1024,
197
+ shuffle: bool = True,
198
+ seed: int = 42,
199
+ ):
200
+ """Initialize streaming dataset.
201
+
202
+ Args:
203
+ data_files: List of text file paths
204
+ tokenizer: Tokenizer instance
205
+ max_length: Maximum sequence length
206
+ shuffle: Whether to shuffle files and lines
207
+ seed: Random seed for shuffling
208
+ """
209
+ self.data_files = data_files
210
+ self.tokenizer = tokenizer
211
+ self.max_length = max_length
212
+ self.shuffle = shuffle
213
+ self.seed = seed
214
+
215
+ # Verify files exist
216
+ for f in data_files:
217
+ if not os.path.exists(f):
218
+ raise FileNotFoundError(f"Data file not found: {f}")
219
+
220
+ def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
221
+ """Iterate over all samples in all files."""
222
+ worker_info = torch.utils.data.get_worker_info()
223
+
224
+ # Handle multi-worker data loading
225
+ if worker_info is None:
226
+ files_to_process = self.data_files
227
+ else:
228
+ # Split files among workers
229
+ per_worker = len(self.data_files) // worker_info.num_workers
230
+ worker_id = worker_info.id
231
+ start = worker_id * per_worker
232
+ end = start + per_worker if worker_id < worker_info.num_workers - 1 else len(self.data_files)
233
+ files_to_process = self.data_files[start:end]
234
+
235
+ # Shuffle files if needed
236
+ if self.shuffle:
237
+ rng = random.Random(self.seed)
238
+ files_to_process = list(files_to_process)
239
+ rng.shuffle(files_to_process)
240
+
241
+ # Buffer for accumulating text
242
+ buffer = []
243
+ buffer_tokens = 0
244
+
245
+ for file_path in files_to_process:
246
+ with open(file_path, "r", encoding="utf-8") as f:
247
+ for line in f:
248
+ line = line.strip()
249
+ if not line:
250
+ continue
251
+
252
+ # Try to parse as JSON (for conversational data)
253
+ try:
254
+ data = json.loads(line)
255
+ if "user" in data and "assistant" in data:
256
+ tokens = self.tokenizer.encode_conversation(
257
+ data["user"], data["assistant"]
258
+ )
259
+ elif "text" in data:
260
+ tokens = self.tokenizer.encode(
261
+ data["text"], add_special_tokens=True
262
+ )
263
+ else:
264
+ tokens = self.tokenizer.encode(
265
+ line, add_special_tokens=True
266
+ )
267
+ except json.JSONDecodeError:
268
+ # Plain text line
269
+ tokens = self.tokenizer.encode(
270
+ line, add_special_tokens=True
271
+ )
272
+
273
+ buffer.extend(tokens)
274
+
275
+ # Yield chunks of max_length
276
+ while len(buffer) >= self.max_length:
277
+ chunk = buffer[:self.max_length]
278
+ buffer = buffer[self.max_length:]
279
+
280
+ yield self._create_sample(chunk)
281
+
282
+ # Handle remaining buffer (pad to max_length)
283
+ if len(buffer) > 0:
284
+ yield self._create_sample(buffer)
285
+
286
+ def _create_sample(self, tokens: List[int]) -> Dict[str, torch.Tensor]:
287
+ """Create a training sample from tokens."""
288
+ input_ids = tokens[:self.max_length]
289
+
290
+ # Pad if needed
291
+ attention_mask = [1] * len(input_ids)
292
+ padding_length = self.max_length - len(input_ids)
293
+ if padding_length > 0:
294
+ input_ids = input_ids + [self.tokenizer.pad_token_id] * padding_length
295
+ attention_mask = attention_mask + [0] * padding_length
296
+
297
+ labels = [
298
+ id if mask == 1 else -100
299
+ for id, mask in zip(input_ids, attention_mask)
300
+ ]
301
+
302
+ return {
303
+ "input_ids": torch.tensor(input_ids, dtype=torch.long),
304
+ "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
305
+ "labels": torch.tensor(labels, dtype=torch.long),
306
+ }
307
+
308
+
309
+ class PackedDataset(Dataset):
310
+ """Dataset that packs multiple short sequences into one.
311
+
312
+ Efficient for training when samples are shorter than max_length.
313
+ Concatenates samples with separator tokens to fill sequences.
314
+ """
315
+
316
+ def __init__(
317
+ self,
318
+ samples: List[Dict],
319
+ tokenizer: SLMTokenizer,
320
+ max_length: int = 1024,
321
+ ):
322
+ """Initialize packed dataset.
323
+
324
+ Args:
325
+ samples: List of samples with "user" and "assistant" keys
326
+ tokenizer: Tokenizer instance
327
+ max_length: Maximum sequence length
328
+ """
329
+ self.tokenizer = tokenizer
330
+ self.max_length = max_length
331
+
332
+ # Pack sequences
333
+ self.packed_samples = self._pack_sequences(samples)
334
+ print(f"Packed {len(samples)} samples into {len(self.packed_samples)} sequences")
335
+
336
+ def _pack_sequences(self, samples: List[Dict]) -> List[List[int]]:
337
+ """Pack short sequences together."""
338
+ packed = []
339
+ current_sequence = []
340
+
341
+ for sample in samples:
342
+ # Tokenize
343
+ if "user" in sample and "assistant" in sample:
344
+ tokens = self.tokenizer.encode_conversation(
345
+ sample["user"], sample["assistant"]
346
+ )
347
+ elif "text" in sample:
348
+ tokens = self.tokenizer.encode(sample["text"], add_special_tokens=True)
349
+ else:
350
+ continue
351
+
352
+ # Check if we can add to current sequence
353
+ if len(current_sequence) + len(tokens) <= self.max_length:
354
+ current_sequence.extend(tokens)
355
+ else:
356
+ # Save current and start new
357
+ if current_sequence:
358
+ packed.append(current_sequence)
359
+ current_sequence = tokens[:self.max_length]
360
+
361
+ # Don't forget the last sequence
362
+ if current_sequence:
363
+ packed.append(current_sequence)
364
+
365
+ return packed
366
+
367
+ def __len__(self) -> int:
368
+ return len(self.packed_samples)
369
+
370
+ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
371
+ """Get a packed sample."""
372
+ tokens = self.packed_samples[idx]
373
+
374
+ # Pad if needed
375
+ attention_mask = [1] * len(tokens)
376
+ padding_length = self.max_length - len(tokens)
377
+ if padding_length > 0:
378
+ tokens = tokens + [self.tokenizer.pad_token_id] * padding_length
379
+ attention_mask = attention_mask + [0] * padding_length
380
+
381
+ labels = [
382
+ id if mask == 1 else -100
383
+ for id, mask in zip(tokens, attention_mask)
384
+ ]
385
+
386
+ return {
387
+ "input_ids": torch.tensor(tokens, dtype=torch.long),
388
+ "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
389
+ "labels": torch.tensor(labels, dtype=torch.long),
390
+ }
391
+
392
+
393
+ def create_train_val_split(
394
+ samples: List[Dict],
395
+ val_ratio: float = 0.01,
396
+ seed: int = 42,
397
+ ) -> Tuple[List[Dict], List[Dict]]:
398
+ """Split samples into train and validation sets.
399
+
400
+ Args:
401
+ samples: List of all samples
402
+ val_ratio: Ratio for validation set
403
+ seed: Random seed
404
+
405
+ Returns:
406
+ Tuple of (train_samples, val_samples)
407
+ """
408
+ random.seed(seed)
409
+ shuffled = list(samples)
410
+ random.shuffle(shuffled)
411
+
412
+ val_size = int(len(shuffled) * val_ratio)
413
+ val_samples = shuffled[:val_size]
414
+ train_samples = shuffled[val_size:]
415
+
416
+ return train_samples, val_samples
417
+
418
+
419
+ def load_jsonl(file_path: str) -> List[Dict]:
420
+ """Load data from a JSONL file."""
421
+ samples = []
422
+ with open(file_path, "r", encoding="utf-8") as f:
423
+ for line in f:
424
+ line = line.strip()
425
+ if line:
426
+ samples.append(json.loads(line))
427
+ return samples
428
+
429
+
430
+ def save_jsonl(samples: List[Dict], file_path: str):
431
+ """Save data to a JSONL file."""
432
+ with open(file_path, "w", encoding="utf-8") as f:
433
+ for sample in samples:
434
+ f.write(json.dumps(sample) + "\n")
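
As a quick illustration of the record shapes `ConversationalDataset` accepts ("user"/"assistant", "question"/"answer", or "text") and of the split/IO helpers above, a small sketch; the file names are illustrative only:

```python
# Sketch of accepted JSONL record shapes (file names are illustrative).
from src.data.dataset import create_train_val_split, save_jsonl, load_jsonl

samples = [
    {"user": "What is RMSNorm?", "assistant": "A layer-norm variant without mean centering."},
    {"question": "2+2?", "answer": "4"},
    {"text": "Plain text samples are also supported."},
]

train, val = create_train_val_split(samples, val_ratio=0.34)
save_jsonl(train, "train.jsonl")
save_jsonl(val, "val.jsonl")
assert load_jsonl("train.jsonl") == train  # JSONL round-trip preserves the records
```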
src/data/tokenizer.py ADDED
@@ -0,0 +1,300 @@
1
+ """
2
+ Custom BPE Tokenizer for SLM v1.
3
+ 16,384 vocabulary size optimized for conversational use.
4
+ """
5
+
6
+ import os
7
+ import json
8
+ from typing import List, Optional, Union
9
+ from tokenizers import Tokenizer, models, trainers, pre_tokenizers, processors, decoders
10
+ from tokenizers.normalizers import NFKC, Lowercase, Sequence
11
+
12
+
13
+ class SLMTokenizer:
14
+ """Custom BPE tokenizer for the SLM model.
15
+
16
+ Features:
17
+ - 16,384 token vocabulary (memory efficient)
18
+ - Special tokens for conversation format
19
+ - Compatible with the model's embedding layer
20
+ """
21
+
22
+ # Special tokens
23
+ PAD_TOKEN = "<|pad|>"
24
+ BOS_TOKEN = "<|bos|>"
25
+ EOS_TOKEN = "<|eos|>"
26
+ UNK_TOKEN = "<|unk|>"
27
+ USER_TOKEN = "<|user|>"
28
+ ASSISTANT_TOKEN = "<|assistant|>"
29
+
30
+ SPECIAL_TOKENS = [PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN, USER_TOKEN, ASSISTANT_TOKEN]
31
+
32
+ def __init__(self, tokenizer: Optional[Tokenizer] = None):
33
+ """Initialize tokenizer.
34
+
35
+ Args:
36
+ tokenizer: Pre-trained HuggingFace tokenizer object
37
+ """
38
+ self.tokenizer = tokenizer
39
+ self._setup_special_token_ids()
40
+
41
+ def _setup_special_token_ids(self):
42
+ """Setup special token IDs for easy access."""
43
+ if self.tokenizer is not None:
44
+ self.pad_token_id = self.tokenizer.token_to_id(self.PAD_TOKEN)
45
+ self.bos_token_id = self.tokenizer.token_to_id(self.BOS_TOKEN)
46
+ self.eos_token_id = self.tokenizer.token_to_id(self.EOS_TOKEN)
47
+ self.unk_token_id = self.tokenizer.token_to_id(self.UNK_TOKEN)
48
+ self.user_token_id = self.tokenizer.token_to_id(self.USER_TOKEN)
49
+ self.assistant_token_id = self.tokenizer.token_to_id(self.ASSISTANT_TOKEN)
50
+
51
+ @classmethod
52
+ def train(
53
+ cls,
54
+ files: List[str],
55
+ vocab_size: int = 16384,
56
+ min_frequency: int = 2,
57
+ save_path: Optional[str] = None,
58
+ ) -> "SLMTokenizer":
59
+ """Train a new BPE tokenizer on the given files.
60
+
61
+ Args:
62
+ files: List of text file paths to train on
63
+ vocab_size: Size of vocabulary (default 16,384)
64
+ min_frequency: Minimum token frequency to include
65
+ save_path: Optional path to save the trained tokenizer
66
+
67
+ Returns:
68
+ Trained SLMTokenizer instance
69
+ """
70
+ print(f"Training BPE tokenizer with vocab_size={vocab_size}...")
71
+ print(f"Training files: {files}")
72
+
73
+ # Initialize a BPE tokenizer
74
+ tokenizer = Tokenizer(models.BPE(unk_token=cls.UNK_TOKEN))
75
+
76
+ # Set up normalizer (optional - keeps text mostly as-is)
77
+ # We use NFKC normalization to standardize unicode
78
+ tokenizer.normalizer = NFKC()
79
+
80
+ # Set up pre-tokenizer (splits on whitespace and punctuation)
81
+ tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
82
+
83
+ # Set up decoder
84
+ tokenizer.decoder = decoders.ByteLevel()
85
+
86
+ # Set up trainer
87
+ trainer = trainers.BpeTrainer(
88
+ vocab_size=vocab_size,
89
+ min_frequency=min_frequency,
90
+ special_tokens=cls.SPECIAL_TOKENS,
91
+ show_progress=True,
92
+ )
93
+
94
+ # Train the tokenizer
95
+ tokenizer.train(files, trainer)
96
+
97
+ # Set up post-processor for adding special tokens
98
+ tokenizer.post_processor = processors.TemplateProcessing(
99
+ single=f"{cls.BOS_TOKEN} $A {cls.EOS_TOKEN}",
100
+ pair=f"{cls.BOS_TOKEN} $A {cls.EOS_TOKEN} {cls.BOS_TOKEN} $B {cls.EOS_TOKEN}",
101
+ special_tokens=[
102
+ (cls.BOS_TOKEN, tokenizer.token_to_id(cls.BOS_TOKEN)),
103
+ (cls.EOS_TOKEN, tokenizer.token_to_id(cls.EOS_TOKEN)),
104
+ ],
105
+ )
106
+
107
+ print(f"Tokenizer trained! Vocabulary size: {tokenizer.get_vocab_size()}")
108
+
109
+ # Create instance
110
+ instance = cls(tokenizer)
111
+
112
+ # Save if path provided
113
+ if save_path:
114
+ instance.save(save_path)
115
+
116
+ return instance
117
+
118
+ @classmethod
119
+ def from_file(cls, path: str) -> "SLMTokenizer":
120
+ """Load a tokenizer from a saved file.
121
+
122
+ Args:
123
+ path: Path to the tokenizer.json file
124
+
125
+ Returns:
126
+ Loaded SLMTokenizer instance
127
+ """
128
+ tokenizer = Tokenizer.from_file(path)
129
+ return cls(tokenizer)
130
+
131
+ def save(self, path: str):
132
+ """Save the tokenizer to a file.
133
+
134
+ Args:
135
+ path: Path to save the tokenizer (directory or file)
136
+ """
137
+ if os.path.isdir(path):
138
+ save_path = os.path.join(path, "tokenizer.json")
139
+ else:
140
+ save_path = path
141
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
142
+
143
+ self.tokenizer.save(save_path)
144
+ print(f"Tokenizer saved to: {save_path}")
145
+
146
+ # Also save config
147
+ config_path = save_path.replace("tokenizer.json", "tokenizer_config.json")
148
+ config = {
149
+ "vocab_size": self.vocab_size,
150
+ "pad_token": self.PAD_TOKEN,
151
+ "bos_token": self.BOS_TOKEN,
152
+ "eos_token": self.EOS_TOKEN,
153
+ "unk_token": self.UNK_TOKEN,
154
+ "user_token": self.USER_TOKEN,
155
+ "assistant_token": self.ASSISTANT_TOKEN,
156
+ }
157
+ with open(config_path, "w") as f:
158
+ json.dump(config, f, indent=2)
159
+ print(f"Tokenizer config saved to: {config_path}")
160
+
161
+ def encode(
162
+ self,
163
+ text: str,
164
+ add_special_tokens: bool = True,
165
+ max_length: Optional[int] = None,
166
+ padding: bool = False,
167
+ truncation: bool = False,
168
+ ) -> List[int]:
169
+ """Encode text to token IDs.
170
+
171
+ Args:
172
+ text: Input text string
173
+ add_special_tokens: Whether to add BOS/EOS tokens
174
+ max_length: Maximum sequence length
175
+ padding: Whether to pad to max_length
176
+ truncation: Whether to truncate to max_length
177
+
178
+ Returns:
179
+ List of token IDs
180
+ """
181
+ # Encode
182
+ if add_special_tokens:
183
+ encoding = self.tokenizer.encode(text)
184
+ else:
185
+ encoding = self.tokenizer.encode(text, add_special_tokens=False)
186
+
187
+ ids = encoding.ids
188
+
189
+ # Truncation
190
+ if truncation and max_length and len(ids) > max_length:
191
+ ids = ids[:max_length]
192
+ # Ensure EOS at end if we had special tokens
193
+ if add_special_tokens and ids[-1] != self.eos_token_id:
194
+ ids[-1] = self.eos_token_id
195
+
196
+ # Padding
197
+ if padding and max_length and len(ids) < max_length:
198
+ ids = ids + [self.pad_token_id] * (max_length - len(ids))
199
+
200
+ return ids
201
+
202
+ def decode(self, ids: List[int], skip_special_tokens: bool = True) -> str:
203
+ """Decode token IDs to text.
204
+
205
+ Args:
206
+ ids: List of token IDs
207
+ skip_special_tokens: Whether to remove special tokens
208
+
209
+ Returns:
210
+ Decoded text string
211
+ """
212
+ return self.tokenizer.decode(ids, skip_special_tokens=skip_special_tokens)
213
+
214
+ def encode_conversation(
215
+ self,
216
+ user_message: str,
217
+ assistant_message: Optional[str] = None,
218
+ max_length: Optional[int] = None,
219
+ ) -> List[int]:
220
+ """Encode a conversation turn.
221
+
222
+ Format: <|bos|><|user|>message<|assistant|>response<|eos|>
223
+
224
+ Args:
225
+ user_message: The user's message
226
+ assistant_message: Optional assistant response
227
+ max_length: Maximum sequence length
228
+
229
+ Returns:
230
+ List of token IDs
231
+ """
232
+ # Build conversation string
233
+ if assistant_message:
234
+ text = f"{self.USER_TOKEN}{user_message}{self.ASSISTANT_TOKEN}{assistant_message}"
235
+ else:
236
+ # For inference - no response yet
237
+ text = f"{self.USER_TOKEN}{user_message}{self.ASSISTANT_TOKEN}"
238
+
239
+ return self.encode(text, add_special_tokens=True, max_length=max_length, truncation=True)
240
+
241
+ @property
242
+ def vocab_size(self) -> int:
243
+ """Get vocabulary size."""
244
+ return self.tokenizer.get_vocab_size()
245
+
246
+ def get_vocab(self) -> dict:
247
+ """Get the vocabulary as a dictionary."""
248
+ return self.tokenizer.get_vocab()
249
+
250
+ def __len__(self) -> int:
251
+ """Return vocabulary size."""
252
+ return self.vocab_size
253
+
254
+ def __call__(
255
+ self,
256
+ text: Union[str, List[str]],
257
+ max_length: Optional[int] = None,
258
+ padding: bool = False,
259
+ truncation: bool = False,
260
+ return_tensors: Optional[str] = None,
261
+ ) -> dict:
262
+ """Tokenize text (HuggingFace-style interface).
263
+
264
+ Args:
265
+ text: Input text or list of texts
266
+ max_length: Maximum sequence length
267
+ padding: Whether to pad sequences
268
+ truncation: Whether to truncate sequences
269
+ return_tensors: If "pt", return PyTorch tensors
270
+
271
+ Returns:
272
+ Dictionary with input_ids and attention_mask
273
+ """
274
+ if isinstance(text, str):
275
+ text = [text]
276
+
277
+ all_ids = []
278
+ for t in text:
279
+ ids = self.encode(
280
+ t,
281
+ max_length=max_length,
282
+ padding=padding,
283
+ truncation=truncation,
284
+ )
285
+ all_ids.append(ids)
286
+
287
+ # Create attention mask (1 for real tokens, 0 for padding)
288
+ attention_mask = [[1 if id != self.pad_token_id else 0 for id in ids] for ids in all_ids]
289
+
290
+ result = {
291
+ "input_ids": all_ids,
292
+ "attention_mask": attention_mask,
293
+ }
294
+
295
+ if return_tensors == "pt":
296
+ import torch
297
+ result["input_ids"] = torch.tensor(all_ids)
298
+ result["attention_mask"] = torch.tensor(attention_mask)
299
+
300
+ return result
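
A short usage sketch of the conversation encoding described above, assuming a trained `tokenizer.json` already exists at the (hypothetical) path shown:

```python
# Sketch; assumes a trained tokenizer.json exists at this placeholder path.
from src.data.tokenizer import SLMTokenizer

tok = SLMTokenizer.from_file("tokenizer/tokenizer.json")

# Inference-style prompt: <|bos|><|user|>...<|assistant|> (no response yet)
prompt_ids = tok.encode_conversation("What is RoPE?")

# Training-style pair: <|bos|><|user|>...<|assistant|>...<|eos|>
pair_ids = tok.encode_conversation("What is RoPE?", "Rotary position embedding.")

print(len(prompt_ids))
print(tok.decode(pair_ids, skip_special_tokens=True))
```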
src/export/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # ONNX export components
src/inference/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Inference and generation components
src/model/__init__.py ADDED
@@ -0,0 +1,24 @@
1
+ """SLM Model Components."""
2
+
3
+ from .config import SLMConfig
4
+ from .transformer import SLMForCausalLM, SLMModel, SLMOutput
5
+ from .kv_cache import KVCache
6
+ from .normalization import RMSNorm
7
+ from .rope import RotaryEmbedding
8
+ from .attention import MultiHeadAttention, create_causal_mask
9
+ from .ffn import FeedForward
10
+ from .decoder import DecoderBlock
11
+
12
+ __all__ = [
13
+ "SLMConfig",
14
+ "SLMForCausalLM",
15
+ "SLMModel",
16
+ "SLMOutput",
17
+ "KVCache",
18
+ "RMSNorm",
19
+ "RotaryEmbedding",
20
+ "MultiHeadAttention",
21
+ "create_causal_mask",
22
+ "FeedForward",
23
+ "DecoderBlock",
24
+ ]
src/model/attention.py ADDED
@@ -0,0 +1,172 @@
1
+ """
2
+ Multi-Head Attention with explicit KV cache for SLM.
3
+ Qualcomm-safe: No FlashAttention, no fused ops, clean ONNX export.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from typing import Optional, Tuple
10
+
11
+ from .config import SLMConfig
12
+ from .rope import RotaryEmbedding
13
+ from .kv_cache import KVCache
14
+
15
+
16
+ class MultiHeadAttention(nn.Module):
17
+ """Multi-Head Self-Attention with RoPE and explicit KV cache.
18
+
19
+ Design choices for Qualcomm compatibility:
20
+ - Standard attention (no FlashAttention)
21
+ - No grouped/multi-query attention (simpler, v1.1 will add GQA)
22
+ - Explicit KV cache management
23
+ - Clean tensor operations for ONNX export
24
+ """
25
+
26
+ def __init__(self, config: SLMConfig, layer_idx: int):
27
+ """Initialize attention layer.
28
+
29
+ Args:
30
+ config: Model configuration
31
+ layer_idx: Index of this layer (for KV cache)
32
+ """
33
+ super().__init__()
34
+ self.config = config
35
+ self.layer_idx = layer_idx
36
+
37
+ self.hidden_size = config.hidden_size
38
+ self.num_heads = config.num_heads
39
+ self.head_dim = config.head_dim
40
+ self.dropout = config.attention_dropout
41
+
42
+ # Q, K, V projections
43
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
44
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
45
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
46
+
47
+ # Output projection
48
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
49
+
50
+ # Rotary embeddings
51
+ self.rotary_emb = RotaryEmbedding(
52
+ dim=self.head_dim,
53
+ max_position_embeddings=config.max_position_embeddings,
54
+ base=config.rope_theta,
55
+ )
56
+
57
+ def forward(
58
+ self,
59
+ hidden_states: torch.Tensor,
60
+ position_ids: torch.Tensor,
61
+ attention_mask: Optional[torch.Tensor] = None,
62
+ kv_cache: Optional[KVCache] = None,
63
+ use_cache: bool = False,
64
+ ) -> Tuple[torch.Tensor, Optional[KVCache]]:
65
+ """Forward pass for attention.
66
+
67
+ Args:
68
+ hidden_states: Input tensor [batch, seq_len, hidden_size]
69
+ position_ids: Position indices [batch, seq_len]
70
+ attention_mask: Causal mask [batch, 1, seq_len, kv_seq_len]
71
+ kv_cache: Optional KV cache for inference
72
+ use_cache: Whether to use/update KV cache
73
+
74
+ Returns:
75
+ Tuple of (output, kv_cache)
76
+ """
77
+ batch_size, seq_len, _ = hidden_states.shape
78
+
79
+ # Project to Q, K, V
80
+ query = self.q_proj(hidden_states)
81
+ key = self.k_proj(hidden_states)
82
+ value = self.v_proj(hidden_states)
83
+
84
+ # Reshape: [batch, seq, hidden] -> [batch, seq, heads, head_dim]
85
+ query = query.view(batch_size, seq_len, self.num_heads, self.head_dim)
86
+ key = key.view(batch_size, seq_len, self.num_heads, self.head_dim)
87
+ value = value.view(batch_size, seq_len, self.num_heads, self.head_dim)
88
+
89
+ # Transpose for attention: [batch, heads, seq, head_dim]
90
+ query = query.transpose(1, 2)
91
+ key = key.transpose(1, 2)
92
+ value = value.transpose(1, 2)
93
+
94
+ # Apply rotary embeddings to Q and K
95
+ query, key = self.rotary_emb(query, key, position_ids)
96
+
97
+ # Handle KV cache
98
+ if use_cache and kv_cache is not None:
99
+ # Get the position to write to cache
100
+ cache_position = position_ids[0, 0].item()
101
+
102
+ # Update cache and get full K, V
103
+ key, value = kv_cache.update(
104
+ layer_idx=self.layer_idx,
105
+ key=key,
106
+ value=value,
107
+ position=cache_position,
108
+ )
109
+
110
+ # Compute attention scores
111
+ # [batch, heads, seq, head_dim] @ [batch, heads, head_dim, kv_seq]
112
+ # -> [batch, heads, seq, kv_seq]
113
+ scale = 1.0 / (self.head_dim ** 0.5)
114
+ attn_weights = torch.matmul(query, key.transpose(-2, -1)) * scale
115
+
116
+ # Apply causal mask
117
+ if attention_mask is not None:
118
+ attn_weights = attn_weights + attention_mask
119
+
120
+ # Softmax and dropout
121
+ attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
122
+
123
+ if self.training and self.dropout > 0:
124
+ attn_weights = F.dropout(attn_weights, p=self.dropout)
125
+
126
+ # Apply attention to values
127
+ # [batch, heads, seq, kv_seq] @ [batch, heads, kv_seq, head_dim]
128
+ # -> [batch, heads, seq, head_dim]
129
+ attn_output = torch.matmul(attn_weights, value)
130
+
131
+ # Reshape back: [batch, heads, seq, head_dim] -> [batch, seq, hidden]
132
+ attn_output = attn_output.transpose(1, 2).contiguous()
133
+ attn_output = attn_output.view(batch_size, seq_len, self.hidden_size)
134
+
135
+ # Output projection
136
+ output = self.o_proj(attn_output)
137
+
138
+ return output, kv_cache
139
+
140
+
141
+ def create_causal_mask(
142
+ seq_len: int,
143
+ kv_seq_len: int,
144
+ dtype: torch.dtype,
145
+ device: torch.device,
146
+ ) -> torch.Tensor:
147
+ """Create a causal attention mask.
148
+
149
+ Args:
150
+ seq_len: Query sequence length
151
+ kv_seq_len: Key/value sequence length
152
+ dtype: Data type for mask
153
+ device: Device for mask
154
+
155
+ Returns:
156
+ Causal mask tensor [1, 1, seq_len, kv_seq_len]
157
+ """
158
+ # Create lower triangular mask
159
+ mask = torch.full((seq_len, kv_seq_len), float("-inf"), dtype=dtype, device=device)
160
+
161
+ # For decode (seq_len=1), we can attend to all previous tokens
162
+ if seq_len == 1:
163
+ mask = torch.zeros((seq_len, kv_seq_len), dtype=dtype, device=device)
164
+ else:
165
+ # For prefill, create standard causal mask
166
+ # Position i can attend to positions 0..i
167
+ for i in range(seq_len):
168
+ # Offset for KV cache
169
+ offset = kv_seq_len - seq_len
170
+ mask[i, : offset + i + 1] = 0.0
171
+
172
+ return mask.unsqueeze(0).unsqueeze(0)
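
A small sketch of what `create_causal_mask` produces for a 3-token prefill with no cached tokens (so `kv_seq_len == seq_len`); row i is zero up to position i and `-inf` beyond it:

```python
# Sketch: inspect the causal mask for a 3-token prefill.
import torch
from src.model.attention import create_causal_mask

mask = create_causal_mask(seq_len=3, kv_seq_len=3, dtype=torch.float32, device=torch.device("cpu"))
print(mask.shape)   # torch.Size([1, 1, 3, 3])
print(mask[0, 0])
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])
```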
src/model/config.py ADDED
@@ -0,0 +1,116 @@
1
+ """
2
+ Model configuration for SLM v1.
3
+ Defines all hyperparameters based on architecture specification.
4
+ """
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Optional
8
+ import yaml
9
+
10
+
11
+ @dataclass
12
+ class SLMConfig:
13
+ """Configuration class for the SLM model.
14
+
15
+ Architecture: 120M parameter decoder-only transformer
16
+ - 8 layers, 1024 hidden size, 16 attention heads
17
+ - RMSNorm (pre-norm), GELU FFN, RoPE positions
18
+ - Explicit KV cache for efficient inference
19
+ """
20
+
21
+ # Model architecture
22
+ vocab_size: int = 16384
23
+ hidden_size: int = 1024
24
+ num_layers: int = 8
25
+ num_heads: int = 16
26
+ head_dim: int = 64
27
+ intermediate_size: int = 4096 # 4 * hidden_size
28
+
29
+ # Position encoding
30
+ max_position_embeddings: int = 1024
31
+ rope_theta: float = 10000.0
32
+
33
+ # Normalization
34
+ rms_norm_eps: float = 1e-6
35
+
36
+ # Embeddings
37
+ tie_word_embeddings: bool = True
38
+
39
+ # Dropout (disabled for inference, optional for training)
40
+ dropout: float = 0.0
41
+ attention_dropout: float = 0.0
42
+
43
+ # Precision
44
+ torch_dtype: str = "float16"
45
+
46
+ def __post_init__(self):
47
+ """Validate configuration after initialization."""
48
+ assert self.hidden_size % self.num_heads == 0, \
49
+ f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads})"
50
+ assert self.head_dim == self.hidden_size // self.num_heads, \
51
+ f"head_dim ({self.head_dim}) must equal hidden_size // num_heads ({self.hidden_size // self.num_heads})"
52
+
53
+ @classmethod
54
+ def from_yaml(cls, path: str) -> "SLMConfig":
55
+ """Load configuration from YAML file."""
56
+ with open(path, "r") as f:
57
+ config_dict = yaml.safe_load(f)
58
+
59
+ model_config = config_dict.get("model", {})
60
+ return cls(**model_config)
61
+
62
+ def to_dict(self) -> dict:
63
+ """Convert configuration to dictionary."""
64
+ return {
65
+ "vocab_size": self.vocab_size,
66
+ "hidden_size": self.hidden_size,
67
+ "num_layers": self.num_layers,
68
+ "num_heads": self.num_heads,
69
+ "head_dim": self.head_dim,
70
+ "intermediate_size": self.intermediate_size,
71
+ "max_position_embeddings": self.max_position_embeddings,
72
+ "rope_theta": self.rope_theta,
73
+ "rms_norm_eps": self.rms_norm_eps,
74
+ "tie_word_embeddings": self.tie_word_embeddings,
75
+ "dropout": self.dropout,
76
+ "attention_dropout": self.attention_dropout,
77
+ "torch_dtype": self.torch_dtype,
78
+ }
79
+
80
+ @property
81
+ def num_parameters(self) -> int:
82
+ """Estimate total number of parameters."""
83
+ # Embedding: vocab_size * hidden_size
84
+ embedding_params = self.vocab_size * self.hidden_size
85
+
86
+ # Per layer:
87
+ # - Attention: 4 * hidden_size^2 (Q, K, V, O projections)
88
+ # - FFN: 2 * hidden_size * intermediate_size
89
+ # - Norms: 2 * hidden_size
90
+ attention_params = 4 * self.hidden_size * self.hidden_size
91
+ ffn_params = 2 * self.hidden_size * self.intermediate_size
92
+ norm_params = 2 * self.hidden_size
93
+
94
+ layer_params = attention_params + ffn_params + norm_params
95
+ total_layer_params = self.num_layers * layer_params
96
+
97
+ # Output head (tied with embedding if enabled)
98
+ output_params = 0 if self.tie_word_embeddings else self.vocab_size * self.hidden_size
99
+
100
+ # Final norm
101
+ final_norm_params = self.hidden_size
102
+
103
+ return embedding_params + total_layer_params + output_params + final_norm_params
104
+
105
+ def __repr__(self) -> str:
106
+ params_m = self.num_parameters / 1e6
107
+ return (
108
+ f"SLMConfig(\n"
109
+ f" vocab_size={self.vocab_size},\n"
110
+ f" hidden_size={self.hidden_size},\n"
111
+ f" num_layers={self.num_layers},\n"
112
+ f" num_heads={self.num_heads},\n"
113
+ f" max_position_embeddings={self.max_position_embeddings},\n"
114
+ f" estimated_params={params_m:.1f}M\n"
115
+ f")"
116
+ )
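
As a worked check of the `num_parameters` estimate with the defaults above: embeddings contribute 16384 × 1024 ≈ 16.8M, each layer 4 × 1024² ≈ 4.2M (attention) plus 2 × 1024 × 4096 ≈ 8.4M (FFN), so 8 layers ≈ 100.7M; with tied embeddings the total is about 117.5M, matching the PebbleLM-117M name.

```python
# Sketch: the default config's parameter estimate works out to ~117M.
from src.model.config import SLMConfig

cfg = SLMConfig()
print(cfg.num_parameters)  # 117457920 (~117.5M)
print(cfg)                 # __repr__ reports estimated_params=117.5M
```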
src/model/decoder.py ADDED
@@ -0,0 +1,106 @@
1
+ """
2
+ Decoder Block for SLM.
3
+ Pre-norm architecture with residual connections.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from typing import Optional, Tuple
9
+
10
+ from .config import SLMConfig
11
+ from .normalization import RMSNorm
12
+ from .attention import MultiHeadAttention
13
+ from .ffn import FeedForward
14
+ from .kv_cache import KVCache
15
+
16
+
17
+ class DecoderBlock(nn.Module):
18
+ """Single decoder block with pre-norm architecture.
19
+
20
+ Structure (Pre-Norm):
21
+ ```
22
+ x
23
+ ├─ RMSNorm
24
+ ├─ Multi-Head Attention
25
+ ├─ Residual Add
26
+ ├─ RMSNorm
27
+ ├─ Feed-Forward Network
28
+ └─ Residual Add
29
+ ```
30
+
31
+ Why pre-norm:
32
+ - More stable gradients in FP16 training
33
+ - Better quantization behavior
34
+ - Easier ONNX export (no layer-crossing dependencies)
35
+ """
36
+
37
+ def __init__(self, config: SLMConfig, layer_idx: int):
38
+ """Initialize decoder block.
39
+
40
+ Args:
41
+ config: Model configuration
42
+ layer_idx: Index of this layer
43
+ """
44
+ super().__init__()
45
+ self.config = config
46
+ self.layer_idx = layer_idx
47
+
48
+ # Pre-attention norm
49
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
50
+
51
+ # Self-attention
52
+ self.self_attn = MultiHeadAttention(config, layer_idx)
53
+
54
+ # Pre-FFN norm
55
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
56
+
57
+ # Feed-forward network
58
+ self.mlp = FeedForward(config)
59
+
60
+ def forward(
61
+ self,
62
+ hidden_states: torch.Tensor,
63
+ position_ids: torch.Tensor,
64
+ attention_mask: Optional[torch.Tensor] = None,
65
+ kv_cache: Optional[KVCache] = None,
66
+ use_cache: bool = False,
67
+ ) -> Tuple[torch.Tensor, Optional[KVCache]]:
68
+ """Forward pass through decoder block.
69
+
70
+ Args:
71
+ hidden_states: Input tensor [batch, seq, hidden_size]
72
+ position_ids: Position indices [batch, seq]
73
+ attention_mask: Causal attention mask
74
+ kv_cache: Optional KV cache
75
+ use_cache: Whether to use/update cache
76
+
77
+ Returns:
78
+ Tuple of (output, kv_cache)
79
+ """
80
+ # Store residual
81
+ residual = hidden_states
82
+
83
+ # Pre-norm -> Attention
84
+ hidden_states = self.input_layernorm(hidden_states)
85
+ hidden_states, kv_cache = self.self_attn(
86
+ hidden_states=hidden_states,
87
+ position_ids=position_ids,
88
+ attention_mask=attention_mask,
89
+ kv_cache=kv_cache,
90
+ use_cache=use_cache,
91
+ )
92
+
93
+ # Residual connection
94
+ hidden_states = residual + hidden_states
95
+
96
+ # Store residual
97
+ residual = hidden_states
98
+
99
+ # Pre-norm -> FFN
100
+ hidden_states = self.post_attention_layernorm(hidden_states)
101
+ hidden_states = self.mlp(hidden_states)
102
+
103
+ # Residual connection
104
+ hidden_states = residual + hidden_states
105
+
106
+ return hidden_states, kv_cache
src/model/ffn.py ADDED
@@ -0,0 +1,67 @@
1
+ """
2
+ Feed-Forward Network for SLM.
3
+ Uses GELU activation (not SwiGLU) for better INT8 quantization.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from .config import SLMConfig
11
+
12
+
13
+ class FeedForward(nn.Module):
14
+ """Feed-Forward Network with GELU activation.
15
+
16
+ Architecture: Linear -> GELU -> Linear
17
+ - Input: [batch, seq, hidden_size=1024]
18
+ - Hidden: [batch, seq, intermediate_size=4096]
19
+ - Output: [batch, seq, hidden_size=1024]
20
+
21
+ Why GELU over SwiGLU:
22
+ - Fewer operations (2 matmuls vs 3)
23
+ - Better INT8 quantization behavior
24
+ - Full QNN support without decomposition
25
+ - SwiGLU benefits mainly appear at >1B parameters
26
+ """
27
+
28
+ def __init__(self, config: SLMConfig):
29
+ """Initialize FFN.
30
+
31
+ Args:
32
+ config: Model configuration
33
+ """
34
+ super().__init__()
35
+
36
+ self.hidden_size = config.hidden_size
37
+ self.intermediate_size = config.intermediate_size
38
+
39
+ # Up projection: hidden -> intermediate
40
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
41
+
42
+ # Down projection: intermediate -> hidden
43
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
44
+
45
+ self.dropout = config.dropout
46
+
47
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
48
+ """Forward pass through FFN.
49
+
50
+ Args:
51
+ x: Input tensor [batch, seq, hidden_size]
52
+
53
+ Returns:
54
+ Output tensor [batch, seq, hidden_size]
55
+ """
56
+ # Up project and apply GELU
57
+ hidden = self.up_proj(x)
58
+ hidden = F.gelu(hidden, approximate="tanh")
59
+
60
+ # Down project
61
+ output = self.down_proj(hidden)
62
+
63
+ # Apply dropout during training
64
+ if self.training and self.dropout > 0:
65
+ output = F.dropout(output, p=self.dropout)
66
+
67
+ return output
src/model/kv_cache.py ADDED
@@ -0,0 +1,127 @@
1
+ """
2
+ Explicit KV Cache management for efficient inference.
3
+ This is critical for Qualcomm deployment and agent control loops.
4
+ """
5
+
6
+ import torch
7
+ from typing import Optional, Tuple
8
+ from dataclasses import dataclass
9
+
10
+
11
+ @dataclass
12
+ class KVCache:
13
+ """Key-Value cache for transformer inference.
14
+
15
+ Layout: [num_layers, batch_size, num_heads, max_seq_len, head_dim]
16
+
17
+ This explicit cache enables:
18
+ - Efficient autoregressive decoding
19
+ - Cache offloading for memory management
20
+ - Sliding window attention (future)
21
+ - Agent control loops with cache manipulation
22
+ """
23
+
24
+ key_cache: torch.Tensor # [num_layers, batch, heads, max_len, head_dim]
25
+ value_cache: torch.Tensor # [num_layers, batch, heads, max_len, head_dim]
26
+ seq_len: int # Current sequence length in cache
27
+
28
+ @classmethod
29
+ def create(
30
+ cls,
31
+ num_layers: int,
32
+ batch_size: int,
33
+ num_heads: int,
34
+ max_seq_len: int,
35
+ head_dim: int,
36
+ dtype: torch.dtype = torch.float16,
37
+ device: torch.device = None,
38
+ ) -> "KVCache":
39
+ """Create an empty KV cache.
40
+
41
+ Args:
42
+ num_layers: Number of transformer layers
43
+ batch_size: Batch size
44
+ num_heads: Number of attention heads
45
+ max_seq_len: Maximum sequence length
46
+ head_dim: Dimension per attention head
47
+ dtype: Data type for cache tensors
48
+ device: Device to create cache on
49
+
50
+ Returns:
51
+ Initialized KVCache with zero tensors
52
+ """
53
+ shape = (num_layers, batch_size, num_heads, max_seq_len, head_dim)
54
+
55
+ key_cache = torch.zeros(shape, dtype=dtype, device=device)
56
+ value_cache = torch.zeros(shape, dtype=dtype, device=device)
57
+
58
+ return cls(key_cache=key_cache, value_cache=value_cache, seq_len=0)
59
+
60
+ def update(
61
+ self,
62
+ layer_idx: int,
63
+ key: torch.Tensor,
64
+ value: torch.Tensor,
65
+ position: int,
66
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
67
+ """Update cache for a specific layer and return full K, V.
68
+
69
+ Args:
70
+ layer_idx: Index of the transformer layer
71
+ key: New key tensor [batch, heads, seq_len, head_dim]
72
+ value: New value tensor [batch, heads, seq_len, head_dim]
73
+ position: Starting position for the new tokens
74
+
75
+ Returns:
76
+ Tuple of (full_key, full_value) including cached values
77
+ """
78
+ seq_len = key.shape[2]
79
+ end_pos = position + seq_len
80
+
81
+ # Store new keys and values
82
+ self.key_cache[layer_idx, :, :, position:end_pos, :] = key
83
+ self.value_cache[layer_idx, :, :, position:end_pos, :] = value
84
+
85
+ # Update sequence length
86
+ self.seq_len = max(self.seq_len, end_pos)
87
+
88
+ # Return full K, V up to current position
89
+ return (
90
+ self.key_cache[layer_idx, :, :, :end_pos, :],
91
+ self.value_cache[layer_idx, :, :, :end_pos, :],
92
+ )
93
+
94
+ def get(
95
+ self,
96
+ layer_idx: int,
97
+ end_pos: Optional[int] = None,
98
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
99
+ """Get cached K, V for a specific layer.
100
+
101
+ Args:
102
+ layer_idx: Index of the transformer layer
103
+ end_pos: End position (defaults to current seq_len)
104
+
105
+ Returns:
106
+ Tuple of (key, value) tensors
107
+ """
108
+ if end_pos is None:
109
+ end_pos = self.seq_len
110
+
111
+ return (
112
+ self.key_cache[layer_idx, :, :, :end_pos, :],
113
+ self.value_cache[layer_idx, :, :, :end_pos, :],
114
+ )
115
+
116
+ def reset(self):
117
+ """Reset the cache to empty state."""
118
+ self.key_cache.zero_()
119
+ self.value_cache.zero_()
120
+ self.seq_len = 0
121
+
122
+ @property
123
+ def memory_usage_mb(self) -> float:
124
+ """Calculate memory usage in megabytes."""
125
+ total_bytes = self.key_cache.numel() * self.key_cache.element_size()
126
+ total_bytes += self.value_cache.numel() * self.value_cache.element_size()
127
+ return total_bytes / (1024 * 1024)
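
A sketch of allocating a cache sized for the default configuration (8 layers, 16 heads, 1024 max positions, head_dim 64, FP16) and checking its footprint: keys plus values hold 2 × 8 × 1 × 16 × 1024 × 64 half-precision values, i.e. 32 MiB.

```python
# Sketch: allocate a cache for the default config and check its footprint.
import torch
from src.model.kv_cache import KVCache

cache = KVCache.create(
    num_layers=8,
    batch_size=1,
    num_heads=16,
    max_seq_len=1024,
    head_dim=64,
    dtype=torch.float16,
    device=torch.device("cpu"),
)
print(cache.memory_usage_mb)  # 32.0
```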
src/model/normalization.py ADDED
@@ -0,0 +1,50 @@
1
+ """
2
+ RMSNorm implementation for SLM.
3
+ Pre-norm architecture for stable FP16 training and better quantization.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
+ class RMSNorm(nn.Module):
11
+ """Root Mean Square Layer Normalization.
12
+
13
+ RMSNorm is computationally simpler than LayerNorm as it doesn't
14
+ compute mean statistics. This makes it:
15
+ - Faster to compute
16
+ - More stable in FP16
17
+ - Better for quantization
18
+
19
+ Reference: https://arxiv.org/abs/1910.07467
20
+ """
21
+
22
+ def __init__(self, hidden_size: int, eps: float = 1e-6):
23
+ """Initialize RMSNorm.
24
+
25
+ Args:
26
+ hidden_size: The size of the hidden dimension
27
+ eps: Small constant for numerical stability
28
+ """
29
+ super().__init__()
30
+ self.weight = nn.Parameter(torch.ones(hidden_size))
31
+ self.eps = eps
32
+
33
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
34
+ """Apply RMS normalization.
35
+
36
+ Args:
37
+ x: Input tensor of shape [..., hidden_size]
38
+
39
+ Returns:
40
+ Normalized tensor of same shape
41
+ """
42
+ # Compute RMS: sqrt(mean(x^2))
43
+ # Use float32 for numerical stability, then cast back
44
+ input_dtype = x.dtype
45
+ x = x.float()
46
+
47
+ variance = x.pow(2).mean(dim=-1, keepdim=True)
48
+ x = x * torch.rsqrt(variance + self.eps)
49
+
50
+ return (self.weight * x).to(input_dtype)
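For illustration (not part of the commit): a quick sanity check of the layer on random activations, with an arbitrary hidden size. With the weight at its initial value of ones, each position comes out with roughly unit RMS.

```python
import torch

from src.model.normalization import RMSNorm  # assumed import path

norm = RMSNorm(hidden_size=512)
x = torch.randn(2, 10, 512)            # [batch, seq, hidden]
y = norm(x)

print(y.shape)                         # torch.Size([2, 10, 512])
rms = y.pow(2).mean(dim=-1).sqrt()     # per-position RMS of the output
print(round(rms.mean().item(), 3))     # ~1.0 at initialization
```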
src/model/rope.py ADDED
@@ -0,0 +1,110 @@
1
+ """
2
+ Rotary Position Embedding (RoPE) implementation.
3
+ Applied to Q and K only, with fixed base (no dynamic scaling).
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from typing import Tuple
9
+
10
+
11
+ class RotaryEmbedding(nn.Module):
12
+ """Rotary Position Embedding (RoPE).
13
+
14
+ RoPE encodes position information by rotating the query and key vectors.
15
+ Key properties:
16
+ - Parameter-free (no learnable embeddings)
17
+ - Naturally encodes relative positions
18
+ - Extrapolates well to longer sequences
19
+
20
+ Reference: https://arxiv.org/abs/2104.09864
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ dim: int,
26
+ max_position_embeddings: int = 1024,
27
+ base: float = 10000.0,
28
+ ):
29
+ """Initialize RoPE.
30
+
31
+ Args:
32
+ dim: Dimension of the rotary embedding (usually head_dim)
33
+ max_position_embeddings: Maximum sequence length
34
+ base: Base for the frequency computation
35
+ """
36
+ super().__init__()
37
+ self.dim = dim
38
+ self.max_position_embeddings = max_position_embeddings
39
+ self.base = base
40
+
41
+ # Precompute inverse frequencies
42
+ inv_freq = 1.0 / (
43
+ self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)
44
+ )
45
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
46
+
47
+ # Precompute cos and sin for all positions
48
+ self._set_cos_sin_cache(max_position_embeddings)
49
+
50
+ def _set_cos_sin_cache(self, seq_len: int):
51
+ """Precompute cos and sin values for positions."""
52
+ self.max_seq_len_cached = seq_len
53
+ t = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
54
+
55
+ # Outer product: [seq_len] x [dim/2] -> [seq_len, dim/2]
56
+ freqs = torch.outer(t, self.inv_freq)
57
+
58
+ # Concatenate to get [seq_len, dim]
59
+ emb = torch.cat((freqs, freqs), dim=-1)
60
+
61
+ self.register_buffer("cos_cached", emb.cos(), persistent=False)
62
+ self.register_buffer("sin_cached", emb.sin(), persistent=False)
63
+
64
+ def forward(
65
+ self,
66
+ q: torch.Tensor,
67
+ k: torch.Tensor,
68
+ position_ids: torch.Tensor,
69
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
70
+ """Apply rotary embeddings to query and key tensors.
71
+
72
+ Args:
73
+ q: Query tensor of shape [batch, num_heads, seq_len, head_dim]
74
+ k: Key tensor of shape [batch, num_heads, seq_len, head_dim]
75
+ position_ids: Position indices of shape [batch, seq_len]
76
+
77
+ Returns:
78
+ Tuple of (rotated_q, rotated_k) with same shapes as inputs
79
+ """
80
+ seq_len = position_ids.max() + 1
81
+
82
+ # Extend cache if needed
83
+ if seq_len > self.max_seq_len_cached:
84
+ self._set_cos_sin_cache(seq_len)
85
+
86
+ # Get cos and sin for the positions
87
+ # Shape: [batch, seq_len, dim]
88
+ cos = self.cos_cached[position_ids]
89
+ sin = self.sin_cached[position_ids]
90
+
91
+ # Add head dimension: [batch, 1, seq_len, dim]
92
+ cos = cos.unsqueeze(1)
93
+ sin = sin.unsqueeze(1)
94
+
95
+ # Apply rotation
96
+ q_embed = (q * cos) + (self._rotate_half(q) * sin)
97
+ k_embed = (k * cos) + (self._rotate_half(k) * sin)
98
+
99
+ return q_embed, k_embed
100
+
101
+ @staticmethod
102
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
103
+ """Rotate half the hidden dims of the input.
104
+
105
+ Splits the input into two halves and rotates:
106
+ [x1, x2, x3, x4] -> [-x3, -x4, x1, x2]
107
+ """
108
+ x1 = x[..., : x.shape[-1] // 2]
109
+ x2 = x[..., x.shape[-1] // 2 :]
110
+ return torch.cat((-x2, x1), dim=-1)
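For illustration (not part of the commit): applying the rotary embedding to random Q/K tensors with arbitrary sizes. Because RoPE is a pure rotation, it leaves per-head vector norms unchanged while injecting position information.

```python
import torch

from src.model.rope import RotaryEmbedding  # assumed import path

rope = RotaryEmbedding(dim=64, max_position_embeddings=1024)

batch, heads, seq, head_dim = 1, 8, 16, 64
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
position_ids = torch.arange(seq).unsqueeze(0)    # [1, seq]

q_rot, k_rot = rope(q, k, position_ids)
print(q_rot.shape, k_rot.shape)                  # shapes unchanged

# Rotations are norm-preserving.
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))  # True
```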
src/model/transformer.py ADDED
@@ -0,0 +1,323 @@
1
+ """
2
+ Full Transformer model for SLM.
3
+ Implements the mandatory prefill/decode API for Qualcomm deployment.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from typing import Optional, Tuple, Union
9
+ from dataclasses import dataclass
10
+
11
+ from .config import SLMConfig
12
+ from .normalization import RMSNorm
13
+ from .decoder import DecoderBlock
14
+ from .attention import create_causal_mask
15
+ from .kv_cache import KVCache
16
+
17
+
18
+ @dataclass
19
+ class SLMOutput:
20
+ """Output from SLM forward pass."""
21
+
22
+ logits: torch.Tensor # [batch, seq, vocab_size]
23
+ kv_cache: Optional[KVCache] = None
24
+ hidden_states: Optional[torch.Tensor] = None
25
+
26
+
27
+ class SLMModel(nn.Module):
28
+ """Core transformer model (without LM head).
29
+
30
+ This is the decoder stack:
31
+ - Token embedding
32
+ - N decoder blocks
33
+ - Final RMSNorm
34
+ """
35
+
36
+ def __init__(self, config: SLMConfig):
37
+ """Initialize transformer model.
38
+
39
+ Args:
40
+ config: Model configuration
41
+ """
42
+ super().__init__()
43
+ self.config = config
44
+
45
+ # Token embeddings
46
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
47
+
48
+ # Decoder layers
49
+ self.layers = nn.ModuleList([
50
+ DecoderBlock(config, layer_idx=i)
51
+ for i in range(config.num_layers)
52
+ ])
53
+
54
+ # Final normalization
55
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
56
+
57
+ def forward(
58
+ self,
59
+ input_ids: torch.Tensor,
60
+ position_ids: Optional[torch.Tensor] = None,
61
+ attention_mask: Optional[torch.Tensor] = None,
62
+ kv_cache: Optional[KVCache] = None,
63
+ use_cache: bool = False,
64
+ ) -> Tuple[torch.Tensor, Optional[KVCache]]:
65
+ """Forward pass through transformer.
66
+
67
+ Args:
68
+ input_ids: Token IDs [batch, seq]
69
+ position_ids: Position indices [batch, seq]
70
+ attention_mask: Causal mask
71
+ kv_cache: Optional KV cache
72
+ use_cache: Whether to use/update cache
73
+
74
+ Returns:
75
+ Tuple of (hidden_states, kv_cache)
76
+ """
77
+ batch_size, seq_len = input_ids.shape
78
+
79
+ # Create position IDs if not provided
80
+ if position_ids is None:
81
+ if kv_cache is not None and kv_cache.seq_len > 0:
82
+ # For decode: position is the current cache length
83
+ position_ids = torch.arange(
84
+ kv_cache.seq_len, kv_cache.seq_len + seq_len,
85
+ device=input_ids.device
86
+ ).unsqueeze(0).expand(batch_size, -1)
87
+ else:
88
+ # For prefill: positions are 0..seq_len-1
89
+ position_ids = torch.arange(
90
+ seq_len, device=input_ids.device
91
+ ).unsqueeze(0).expand(batch_size, -1)
92
+
93
+ # Create attention mask if not provided
94
+ if attention_mask is None:
95
+ kv_seq_len = seq_len
96
+ if kv_cache is not None and kv_cache.seq_len > 0:
97
+ kv_seq_len = kv_cache.seq_len + seq_len
98
+
99
+ attention_mask = create_causal_mask(
100
+ seq_len=seq_len,
101
+ kv_seq_len=kv_seq_len,
102
+ dtype=self.embed_tokens.weight.dtype,
103
+ device=input_ids.device,
104
+ )
105
+
106
+ # Token embeddings
107
+ hidden_states = self.embed_tokens(input_ids)
108
+
109
+ # Pass through decoder layers
110
+ for layer in self.layers:
111
+ hidden_states, kv_cache = layer(
112
+ hidden_states=hidden_states,
113
+ position_ids=position_ids,
114
+ attention_mask=attention_mask,
115
+ kv_cache=kv_cache,
116
+ use_cache=use_cache,
117
+ )
118
+
119
+ # Final normalization
120
+ hidden_states = self.norm(hidden_states)
121
+
122
+ return hidden_states, kv_cache
123
+
124
+
125
+ class SLMForCausalLM(nn.Module):
126
+ """SLM with language modeling head.
127
+
128
+ This is the full model with:
129
+ - Transformer backbone
130
+ - LM head (tied with embeddings)
131
+ - Prefill/Decode API for Qualcomm deployment
132
+ """
133
+
134
+ def __init__(self, config: SLMConfig):
135
+ """Initialize causal LM.
136
+
137
+ Args:
138
+ config: Model configuration
139
+ """
140
+ super().__init__()
141
+ self.config = config
142
+
143
+ # Transformer backbone
144
+ self.model = SLMModel(config)
145
+
146
+ # LM head
147
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
148
+
149
+ # Tie weights if configured
150
+ if config.tie_word_embeddings:
151
+ self.lm_head.weight = self.model.embed_tokens.weight
152
+
153
+ # Initialize weights
154
+ self.apply(self._init_weights)
155
+
156
+ def _init_weights(self, module: nn.Module):
157
+ """Initialize model weights."""
158
+ std = 0.02
159
+ if isinstance(module, nn.Linear):
160
+ module.weight.data.normal_(mean=0.0, std=std)
161
+ if module.bias is not None:
162
+ module.bias.data.zero_()
163
+ elif isinstance(module, nn.Embedding):
164
+ module.weight.data.normal_(mean=0.0, std=std)
165
+
166
+ def forward(
167
+ self,
168
+ input_ids: torch.Tensor,
169
+ position_ids: Optional[torch.Tensor] = None,
170
+ attention_mask: Optional[torch.Tensor] = None,
171
+ kv_cache: Optional[KVCache] = None,
172
+ use_cache: bool = False,
173
+ labels: Optional[torch.Tensor] = None,
174
+ ) -> SLMOutput:
175
+ """Forward pass for causal LM.
176
+
177
+ Args:
178
+ input_ids: Token IDs [batch, seq]
179
+ position_ids: Position indices [batch, seq]
180
+ attention_mask: Causal mask
181
+ kv_cache: Optional KV cache
182
+ use_cache: Whether to use/update cache
183
+ labels: Optional labels, accepted for API compatibility (the loss is computed externally by the training loss function)
184
+
185
+ Returns:
186
+ SLMOutput with logits, updated KV cache, and hidden states
187
+ """
188
+ # Get hidden states from transformer
189
+ hidden_states, kv_cache = self.model(
190
+ input_ids=input_ids,
191
+ position_ids=position_ids,
192
+ attention_mask=attention_mask,
193
+ kv_cache=kv_cache,
194
+ use_cache=use_cache,
195
+ )
196
+
197
+ # Compute logits
198
+ logits = self.lm_head(hidden_states)
199
+
200
+ return SLMOutput(
201
+ logits=logits,
202
+ kv_cache=kv_cache,
203
+ hidden_states=hidden_states,
204
+ )
205
+
206
+ # =========================================================================
207
+ # MANDATORY KV CACHE API (from architecture.txt)
208
+ # =========================================================================
209
+
210
+ def prefill(
211
+ self,
212
+ input_ids: torch.Tensor,
213
+ kv_cache: Optional[KVCache] = None,
214
+ ) -> Tuple[torch.Tensor, KVCache]:
215
+ """Prefill: Process full prompt and populate KV cache.
216
+
217
+ This is Graph 1 for Qualcomm deployment.
218
+
219
+ Args:
220
+ input_ids: Token IDs [batch, seq]
221
+ kv_cache: Empty or existing KV cache
222
+
223
+ Returns:
224
+ Tuple of (logits [batch, seq, vocab], populated_kv_cache)
225
+ """
226
+ batch_size = input_ids.shape[0]
227
+
228
+ # Create empty cache if not provided
229
+ if kv_cache is None:
230
+ kv_cache = KVCache.create(
231
+ num_layers=self.config.num_layers,
232
+ batch_size=batch_size,
233
+ num_heads=self.config.num_heads,
234
+ max_seq_len=self.config.max_position_embeddings,
235
+ head_dim=self.config.head_dim,
236
+ dtype=self.model.embed_tokens.weight.dtype,
237
+ device=input_ids.device,
238
+ )
239
+
240
+ # Forward pass with cache
241
+ output = self.forward(
242
+ input_ids=input_ids,
243
+ kv_cache=kv_cache,
244
+ use_cache=True,
245
+ )
246
+
247
+ return output.logits, output.kv_cache
248
+
249
+ def decode(
250
+ self,
251
+ input_id: torch.Tensor,
252
+ kv_cache: KVCache,
253
+ position: Optional[int] = None,
254
+ ) -> Tuple[torch.Tensor, KVCache]:
255
+ """Decode: Generate single token using KV cache.
256
+
257
+ This is Graph 2 for Qualcomm deployment.
258
+
259
+ Args:
260
+ input_id: Single token ID [batch, 1]
261
+ kv_cache: Populated KV cache from prefill or previous decode
262
+ position: Position index (defaults to cache.seq_len)
263
+
264
+ Returns:
265
+ Tuple of (logits [batch, 1, vocab], updated_kv_cache)
266
+ """
267
+ batch_size = input_id.shape[0]
268
+
269
+ # Get position from cache if not provided
270
+ if position is None:
271
+ position = kv_cache.seq_len
272
+
273
+ # Create position IDs
274
+ position_ids = torch.tensor(
275
+ [[position]], device=input_id.device
276
+ ).expand(batch_size, -1)
277
+
278
+ # Forward pass with cache
279
+ output = self.forward(
280
+ input_ids=input_id,
281
+ position_ids=position_ids,
282
+ kv_cache=kv_cache,
283
+ use_cache=True,
284
+ )
285
+
286
+ return output.logits, output.kv_cache
287
+
288
+ def create_empty_cache(
289
+ self,
290
+ batch_size: int = 1,
291
+ device: Optional[torch.device] = None,
292
+ ) -> KVCache:
293
+ """Create an empty KV cache for inference.
294
+
295
+ Args:
296
+ batch_size: Batch size
297
+ device: Device for cache tensors
298
+
299
+ Returns:
300
+ Empty KVCache ready for prefill
301
+ """
302
+ if device is None:
303
+ device = self.model.embed_tokens.weight.device
304
+
305
+ return KVCache.create(
306
+ num_layers=self.config.num_layers,
307
+ batch_size=batch_size,
308
+ num_heads=self.config.num_heads,
309
+ max_seq_len=self.config.max_position_embeddings,
310
+ head_dim=self.config.head_dim,
311
+ dtype=self.model.embed_tokens.weight.dtype,
312
+ device=device,
313
+ )
314
+
315
+ @property
316
+ def num_parameters(self) -> int:
317
+ """Count total trainable parameters."""
318
+ return sum(p.numel() for p in self.parameters() if p.requires_grad)
319
+
320
+ @property
321
+ def device(self) -> torch.device:
322
+ """Get model device."""
323
+ return self.model.embed_tokens.weight.device
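For illustration (not part of the commit): a greedy generation loop built only on the `prefill`/`decode` API above. The helper is a sketch; it assumes an already constructed `SLMForCausalLM` instance and does not handle EOS stopping or sampling.

```python
import torch


@torch.no_grad()
def greedy_generate(model, input_ids: torch.Tensor, max_new_tokens: int = 32) -> torch.Tensor:
    """Greedy decoding on top of the two-graph prefill/decode API."""
    model.eval()

    # Graph 1: run the whole prompt once and populate the KV cache.
    logits, cache = model.prefill(input_ids)
    next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)   # [batch, 1]
    generated = [next_token]

    # Graph 2: feed one token at a time, reusing the cache.
    for _ in range(max_new_tokens - 1):
        logits, cache = model.decode(next_token, cache)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated.append(next_token)

    return torch.cat([input_ids] + generated, dim=-1)


# Usage, assuming `model = SLMForCausalLM(config)` has been built and weights loaded:
# prompt_ids = torch.tensor([[1, 42, 7]])   # placeholder token IDs
# output_ids = greedy_generate(model, prompt_ids, max_new_tokens=16)
```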
src/training/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ # Training components
2
+
3
+ from .loss import LanguageModelingLoss, compute_perplexity, compute_accuracy
4
+ from .optimizer import create_optimizer, create_scheduler, clip_grad_norm
5
+ from .trainer import Trainer, TrainingConfig
6
+
7
+ __all__ = [
8
+ "LanguageModelingLoss",
9
+ "compute_perplexity",
10
+ "compute_accuracy",
11
+ "create_optimizer",
12
+ "create_scheduler",
13
+ "clip_grad_norm",
14
+ "Trainer",
15
+ "TrainingConfig",
16
+ ]
src/training/loss.py ADDED
@@ -0,0 +1,123 @@
1
+ """
2
+ Loss functions for SLM training.
3
+
4
+ Cross-entropy loss with optional label smoothing.
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from typing import Optional
11
+
12
+
13
+ class LanguageModelingLoss(nn.Module):
14
+ """Cross-entropy loss for language modeling.
15
+
16
+ Handles:
17
+ - Automatic shifting of labels
18
+ - Ignoring padding tokens (-100)
19
+ - Optional label smoothing
20
+ """
21
+
22
+ def __init__(
23
+ self,
24
+ vocab_size: int,
25
+ label_smoothing: float = 0.0,
26
+ ignore_index: int = -100,
27
+ ):
28
+ """Initialize loss function.
29
+
30
+ Args:
31
+ vocab_size: Size of vocabulary
32
+ label_smoothing: Label smoothing factor (0.0 = no smoothing)
33
+ ignore_index: Index to ignore in loss calculation (padding)
34
+ """
35
+ super().__init__()
36
+ self.vocab_size = vocab_size
37
+ self.label_smoothing = label_smoothing
38
+ self.ignore_index = ignore_index
39
+
40
+ self.ce_loss = nn.CrossEntropyLoss(
41
+ ignore_index=ignore_index,
42
+ label_smoothing=label_smoothing,
43
+ )
44
+
45
+ def forward(
46
+ self,
47
+ logits: torch.Tensor,
48
+ labels: torch.Tensor,
49
+ shift_labels: bool = True,
50
+ ) -> torch.Tensor:
51
+ """Compute loss.
52
+
53
+ Args:
54
+ logits: Model output logits [batch_size, seq_len, vocab_size]
55
+ labels: Target token IDs [batch_size, seq_len]
56
+ shift_labels: Whether to shift labels (for autoregressive LM)
57
+
58
+ Returns:
59
+ Scalar loss tensor
60
+ """
61
+ if shift_labels:
62
+ # Shift so we predict next token
63
+ # logits: predict tokens 1..n
64
+ # labels: actual tokens 1..n
65
+ shift_logits = logits[..., :-1, :].contiguous()
66
+ shift_labels = labels[..., 1:].contiguous()
67
+ else:
68
+ shift_logits = logits
69
+ shift_labels = labels
70
+
71
+ # Flatten for cross-entropy
72
+ # [batch * seq_len, vocab_size]
73
+ flat_logits = shift_logits.view(-1, self.vocab_size)
74
+ # [batch * seq_len]
75
+ flat_labels = shift_labels.view(-1)
76
+
77
+ loss = self.ce_loss(flat_logits, flat_labels)
78
+
79
+ return loss
80
+
81
+
82
+ def compute_perplexity(loss: torch.Tensor) -> torch.Tensor:
83
+ """Compute perplexity from cross-entropy loss.
84
+
85
+ Args:
86
+ loss: Cross-entropy loss value
87
+
88
+ Returns:
89
+ Perplexity (exp of loss)
90
+ """
91
+ return torch.exp(loss)
92
+
93
+
94
+ def compute_accuracy(
95
+ logits: torch.Tensor,
96
+ labels: torch.Tensor,
97
+ ignore_index: int = -100,
98
+ ) -> torch.Tensor:
99
+ """Compute token prediction accuracy.
100
+
101
+ Args:
102
+ logits: Model output logits [batch_size, seq_len, vocab_size]
103
+ labels: Target token IDs [batch_size, seq_len]
104
+ ignore_index: Index to ignore in accuracy calculation
105
+
106
+ Returns:
107
+ Accuracy as a scalar tensor
108
+ """
109
+ # Shift for autoregressive prediction
110
+ shift_logits = logits[..., :-1, :].contiguous()
111
+ shift_labels = labels[..., 1:].contiguous()
112
+
113
+ # Get predictions
114
+ predictions = shift_logits.argmax(dim=-1)
115
+
116
+ # Mask for valid positions
117
+ mask = shift_labels != ignore_index
118
+
119
+ # Compute accuracy on valid positions
120
+ correct = (predictions == shift_labels) & mask
121
+ accuracy = correct.sum().float() / mask.sum().float()
122
+
123
+ return accuracy
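For illustration (not part of the commit): the helpers above on random data, showing that positions labelled -100 are excluded from both loss and accuracy. With purely random logits, the loss sits near ln(vocab_size) and the perplexity near vocab_size.

```python
import torch

from src.training.loss import LanguageModelingLoss, compute_perplexity, compute_accuracy  # assumed path

vocab_size = 100
loss_fn = LanguageModelingLoss(vocab_size=vocab_size)

logits = torch.randn(2, 8, vocab_size)           # [batch, seq, vocab]
labels = torch.randint(0, vocab_size, (2, 8))    # [batch, seq]
labels[:, -2:] = -100                            # mark the last two positions as padding

loss = loss_fn(logits, labels)
print(f"loss       : {loss.item():.3f}")                              # ~4.6 (= ln 100)
print(f"perplexity : {compute_perplexity(loss).item():.1f}")          # ~100
print(f"accuracy   : {compute_accuracy(logits, labels).item():.3f}")  # ~1/vocab_size
```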
src/training/optimizer.py ADDED
@@ -0,0 +1,197 @@
1
+ """
2
+ Optimizer and learning rate scheduler for SLM training.
3
+
4
+ Uses AdamW with cosine annealing and warmup.
5
+ """
6
+
7
+ import math
8
+ from typing import Optional, Tuple, List
9
+
10
+ import torch
11
+ from torch.optim import AdamW
12
+ from torch.optim.lr_scheduler import LambdaLR
13
+
14
+
15
+ def create_optimizer(
16
+ model: torch.nn.Module,
17
+ learning_rate: float = 3e-4,
18
+ weight_decay: float = 0.1,
19
+ betas: Tuple[float, float] = (0.9, 0.95),
20
+ eps: float = 1e-8,
21
+ ) -> AdamW:
22
+ """Create AdamW optimizer with weight decay.
23
+
24
+ Applies weight decay only to 2D parameters (weights, not biases/norms).
25
+
26
+ Args:
27
+ model: The model to optimize
28
+ learning_rate: Base learning rate
29
+ weight_decay: Weight decay coefficient
30
+ betas: Adam beta parameters
31
+ eps: Adam epsilon for numerical stability
32
+
33
+ Returns:
34
+ Configured AdamW optimizer
35
+ """
36
+ # Separate parameters into decay and no-decay groups
37
+ decay_params = []
38
+ no_decay_params = []
39
+
40
+ for name, param in model.named_parameters():
41
+ if not param.requires_grad:
42
+ continue
43
+
44
+ # No weight decay for:
45
+ # - 1D parameters (biases, layer norms)
46
+ # - Embedding layers
47
+ if param.dim() == 1 or "embedding" in name.lower():
48
+ no_decay_params.append(param)
49
+ else:
50
+ decay_params.append(param)
51
+
52
+ param_groups = [
53
+ {"params": decay_params, "weight_decay": weight_decay},
54
+ {"params": no_decay_params, "weight_decay": 0.0},
55
+ ]
56
+
57
+ optimizer = AdamW(
58
+ param_groups,
59
+ lr=learning_rate,
60
+ betas=betas,
61
+ eps=eps,
62
+ )
63
+
64
+ return optimizer
65
+
66
+
67
+ def create_scheduler(
68
+ optimizer: torch.optim.Optimizer,
69
+ num_training_steps: int,
70
+ warmup_ratio: float = 0.1,
71
+ min_lr_ratio: float = 0.1,
72
+ scheduler_type: str = "cosine",
73
+ ) -> LambdaLR:
74
+ """Create learning rate scheduler.
75
+
76
+ Args:
77
+ optimizer: The optimizer to schedule
78
+ num_training_steps: Total number of training steps
79
+ warmup_ratio: Ratio of warmup steps (e.g., 0.1 = 10%)
80
+ min_lr_ratio: Minimum LR as ratio of max (e.g., 0.1 = 10% of peak LR)
81
+ scheduler_type: Type of scheduler ("cosine", "linear", "constant")
82
+
83
+ Returns:
84
+ LambdaLR scheduler
85
+ """
86
+ num_warmup_steps = int(num_training_steps * warmup_ratio)
87
+
88
+ if scheduler_type == "cosine":
89
+ def lr_lambda(current_step: int) -> float:
90
+ # Warmup phase
91
+ if current_step < num_warmup_steps:
92
+ return float(current_step) / float(max(1, num_warmup_steps))
93
+
94
+ # Cosine annealing phase
95
+ progress = float(current_step - num_warmup_steps) / float(
96
+ max(1, num_training_steps - num_warmup_steps)
97
+ )
98
+ cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
99
+
100
+ # Scale between min_lr_ratio and 1.0
101
+ return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay
102
+
103
+ elif scheduler_type == "linear":
104
+ def lr_lambda(current_step: int) -> float:
105
+ if current_step < num_warmup_steps:
106
+ return float(current_step) / float(max(1, num_warmup_steps))
107
+
108
+ progress = float(current_step - num_warmup_steps) / float(
109
+ max(1, num_training_steps - num_warmup_steps)
110
+ )
111
+ return max(min_lr_ratio, 1.0 - progress * (1.0 - min_lr_ratio))
112
+
113
+ elif scheduler_type == "constant":
114
+ def lr_lambda(current_step: int) -> float:
115
+ if current_step < num_warmup_steps:
116
+ return float(current_step) / float(max(1, num_warmup_steps))
117
+ return 1.0
118
+
119
+ else:
120
+ raise ValueError(f"Unknown scheduler type: {scheduler_type}")
121
+
122
+ return LambdaLR(optimizer, lr_lambda)
123
+
124
+
125
+ def get_parameter_count(model: torch.nn.Module) -> dict:
126
+ """Get detailed parameter count for a model.
127
+
128
+ Args:
129
+ model: The model to analyze
130
+
131
+ Returns:
132
+ Dictionary with parameter counts
133
+ """
134
+ total_params = 0
135
+ trainable_params = 0
136
+ embedding_params = 0
137
+
138
+ for name, param in model.named_parameters():
139
+ num_params = param.numel()
140
+ total_params += num_params
141
+
142
+ if param.requires_grad:
143
+ trainable_params += num_params
144
+
145
+ if "embedding" in name.lower():
146
+ embedding_params += num_params
147
+
148
+ return {
149
+ "total": total_params,
150
+ "trainable": trainable_params,
151
+ "embedding": embedding_params,
152
+ "non_embedding": total_params - embedding_params,
153
+ }
154
+
155
+
156
+ def get_optimizer_state(optimizer: torch.optim.Optimizer) -> dict:
157
+ """Get optimizer state statistics.
158
+
159
+ Args:
160
+ optimizer: The optimizer to analyze
161
+
162
+ Returns:
163
+ Dictionary with optimizer state info
164
+ """
165
+ num_params = sum(
166
+ sum(p.numel() for p in group["params"])
167
+ for group in optimizer.param_groups
168
+ )
169
+
170
+ current_lrs = [group["lr"] for group in optimizer.param_groups]
171
+
172
+ return {
173
+ "num_param_groups": len(optimizer.param_groups),
174
+ "total_params": num_params,
175
+ "learning_rates": current_lrs,
176
+ }
177
+
178
+
179
+ def clip_grad_norm(
180
+ model: torch.nn.Module,
181
+ max_norm: float = 1.0,
182
+ ) -> float:
183
+ """Clip gradient norm and return the norm value.
184
+
185
+ Args:
186
+ model: The model with gradients
187
+ max_norm: Maximum gradient norm
188
+
189
+ Returns:
190
+ The gradient norm before clipping
191
+ """
192
+ parameters = [p for p in model.parameters() if p.grad is not None]
193
+ if len(parameters) == 0:
194
+ return 0.0
195
+
196
+ total_norm = torch.nn.utils.clip_grad_norm_(parameters, max_norm)
197
+ return total_norm.item()
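For illustration (not part of the commit): the warmup-then-cosine schedule these helpers produce on a dummy model, with arbitrary step counts. The LR climbs linearly over the first 10% of steps, peaks at the configured learning rate, and decays to 10% of it.

```python
import torch

from src.training.optimizer import create_optimizer, create_scheduler  # assumed path

model = torch.nn.Linear(16, 16)
optimizer = create_optimizer(model, learning_rate=3e-4, weight_decay=0.1)
scheduler = create_scheduler(optimizer, num_training_steps=1000, warmup_ratio=0.1, min_lr_ratio=0.1)

lrs = []
for _ in range(1000):
    optimizer.step()    # normally preceded by loss.backward()
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])

print(f"peak LR  : {max(lrs):.2e}")   # ~3.0e-04, reached at the end of warmup
print(f"final LR : {lrs[-1]:.2e}")    # ~3.0e-05, i.e. min_lr_ratio * peak
```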
src/training/trainer.py ADDED
@@ -0,0 +1,511 @@
1
+ """
2
+ Training loop for SLM.
3
+
4
+ Handles the complete training process including:
5
+ - Mixed precision training
6
+ - Gradient accumulation
7
+ - Checkpointing
8
+ - Logging
9
+ """
10
+
11
+ import os
12
+ import time
13
+ import json
14
+ from dataclasses import dataclass, asdict
15
+ from typing import Optional, Dict, Any, Callable
16
+ from pathlib import Path
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.utils.data import DataLoader
21
+ from torch.cuda.amp import autocast, GradScaler
22
+ from tqdm import tqdm
23
+
24
+ from .loss import LanguageModelingLoss, compute_perplexity, compute_accuracy
25
+ from .optimizer import create_optimizer, create_scheduler, clip_grad_norm
26
+
27
+
28
+ @dataclass
29
+ class TrainingConfig:
30
+ """Configuration for training."""
31
+
32
+ # Optimization
33
+ learning_rate: float = 3e-4
34
+ weight_decay: float = 0.1
35
+ warmup_ratio: float = 0.1
36
+ min_lr_ratio: float = 0.1
37
+ max_grad_norm: float = 1.0
38
+ label_smoothing: float = 0.0
39
+
40
+ # Training
41
+ num_epochs: int = 5
42
+ gradient_accumulation_steps: int = 4
43
+ fp16: bool = True
44
+
45
+ # Checkpointing
46
+ checkpoint_dir: str = "checkpoints"
47
+ save_steps: int = 1000
48
+ save_total_limit: int = 3
49
+
50
+ # Evaluation
51
+ eval_steps: int = 500
52
+ logging_steps: int = 10
53
+
54
+ # Early stopping
55
+ early_stopping_patience: int = 5 # Stop after N evals without improvement
56
+ early_stopping_threshold: float = 0.01 # Minimum improvement to reset patience
57
+
58
+ # Device
59
+ device: str = "auto"
60
+
61
+ # Compile model (torch.compile)
62
+ compile_model: bool = False
63
+
64
+ def to_dict(self) -> Dict[str, Any]:
65
+ return asdict(self)
66
+
67
+
68
+ class Trainer:
69
+ """Training loop for SLM model."""
70
+
71
+ def __init__(
72
+ self,
73
+ model: nn.Module,
74
+ config: TrainingConfig,
75
+ train_dataloader: DataLoader,
76
+ val_dataloader: Optional[DataLoader] = None,
77
+ wandb_project: Optional[str] = None,
78
+ ):
79
+ """Initialize trainer.
80
+
81
+ Args:
82
+ model: The model to train
83
+ config: Training configuration
84
+ train_dataloader: Training data loader
85
+ val_dataloader: Optional validation data loader
86
+ wandb_project: Optional W&B project name for logging
87
+ """
88
+ self.config = config
89
+ self.train_dataloader = train_dataloader
90
+ self.val_dataloader = val_dataloader
91
+
92
+ # Setup device
93
+ if config.device == "auto":
94
+ if torch.cuda.is_available():
95
+ self.device = torch.device("cuda")
96
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
97
+ self.device = torch.device("mps")
98
+ else:
99
+ self.device = torch.device("cpu")
100
+ else:
101
+ self.device = torch.device(config.device)
102
+
103
+ print(f"Training on device: {self.device}")
104
+
105
+ # Move model to device
106
+ self.model = model.to(self.device)
107
+
108
+ # Get vocab size from model
109
+ if hasattr(model, "config"):
110
+ self.vocab_size = model.config.vocab_size
111
+ else:
112
+ self.vocab_size = model.embed_tokens.num_embeddings
113
+
114
+ # Setup loss function
115
+ self.loss_fn = LanguageModelingLoss(
116
+ vocab_size=self.vocab_size,
117
+ label_smoothing=config.label_smoothing,
118
+ )
119
+
120
+ # Calculate total steps
121
+ self.steps_per_epoch = len(train_dataloader)
122
+ self.total_steps = self.steps_per_epoch * config.num_epochs
123
+ self.total_steps = self.total_steps // config.gradient_accumulation_steps
124
+
125
+ # Setup optimizer and scheduler
126
+ self.optimizer = create_optimizer(
127
+ model,
128
+ learning_rate=config.learning_rate,
129
+ weight_decay=config.weight_decay,
130
+ )
131
+
132
+ self.scheduler = create_scheduler(
133
+ self.optimizer,
134
+ num_training_steps=self.total_steps,
135
+ warmup_ratio=config.warmup_ratio,
136
+ min_lr_ratio=config.min_lr_ratio,
137
+ )
138
+
139
+ # Setup mixed precision
140
+ self.use_amp = config.fp16 and self.device.type == "cuda"
141
+ self.scaler = GradScaler() if self.use_amp else None
142
+
143
+ # Tracking
144
+ self.global_step = 0
145
+ self.epoch = 0
146
+ self.best_val_loss = float("inf")
147
+
148
+ # Early stopping tracking
149
+ self.early_stopping_counter = 0
150
+ self.should_stop = False
151
+
152
+ # Checkpoint directory
153
+ os.makedirs(config.checkpoint_dir, exist_ok=True)
154
+
155
+ # W&B logging
156
+ self.wandb = None
157
+ if wandb_project:
158
+ try:
159
+ import wandb
160
+ wandb.init(project=wandb_project, config=config.to_dict())
161
+ self.wandb = wandb
162
+ except ImportError:
163
+ print("wandb not installed, skipping logging")
164
+
165
+ def train(self) -> Dict[str, Any]:
166
+ """Run the full training loop.
167
+
168
+ Returns:
169
+ Dictionary with training results
170
+ """
171
+ print(f"\n{'='*60}")
172
+ print("STARTING TRAINING")
173
+ print(f"{'='*60}")
174
+ print(f"Total epochs: {self.config.num_epochs}")
175
+ print(f"Steps per epoch: {self.steps_per_epoch}")
176
+ print(f"Total optimization steps: {self.total_steps}")
177
+ print(f"Gradient accumulation: {self.config.gradient_accumulation_steps}")
178
+ print(f"Mixed precision: {self.use_amp}")
179
+ if self.config.early_stopping_patience > 0:
180
+ print(f"Early stopping: patience={self.config.early_stopping_patience}")
181
+ print(f"{'='*60}\n")
182
+
183
+ training_start = time.time()
184
+
185
+ # FIX: Start from loaded epoch (for resume), not always from 0
186
+ start_epoch = self.epoch
187
+ if start_epoch > 0:
188
+ print(f"Resuming from epoch {start_epoch + 1}")
189
+
190
+ for epoch in range(start_epoch, self.config.num_epochs):
191
+ self.epoch = epoch
192
+ epoch_loss = self._train_epoch()
193
+
194
+ print(f"\nEpoch {epoch + 1}/{self.config.num_epochs} - Loss: {epoch_loss:.4f}")
195
+
196
+ # Validation
197
+ if self.val_dataloader is not None:
198
+ val_metrics = self.evaluate()
199
+ print(f"Validation - Loss: {val_metrics['loss']:.4f}, PPL: {val_metrics['perplexity']:.2f}")
200
+
201
+ # Early stopping check
202
+ if val_metrics["loss"] < self.best_val_loss - self.config.early_stopping_threshold:
203
+ self.best_val_loss = val_metrics["loss"]
204
+ self.early_stopping_counter = 0
205
+ self.save_checkpoint("best")
206
+ print(f" New best model saved!")
207
+ else:
208
+ self.early_stopping_counter += 1
209
+ print(f" No improvement. Early stopping: {self.early_stopping_counter}/{self.config.early_stopping_patience}")
210
+
211
+ if self.config.early_stopping_patience > 0 and self.early_stopping_counter >= self.config.early_stopping_patience:
212
+ print(f"\nEarly stopping triggered after {self.early_stopping_counter} evaluations without improvement.")
213
+ self.should_stop = True
214
+
215
+ # Save epoch checkpoint
216
+ self.save_checkpoint(f"epoch_{epoch + 1}")
217
+
218
+ # Check early stopping
219
+ if self.should_stop:
220
+ print("Stopping training early.")
221
+ break
222
+
223
+ training_time = time.time() - training_start
224
+ print(f"\n{'='*60}")
225
+ print(f"TRAINING COMPLETE")
226
+ print(f"Total time: {training_time / 3600:.2f} hours")
227
+ print(f"Best validation loss: {self.best_val_loss:.4f}")
228
+ if self.should_stop:
229
+ print(f"Stopped early at epoch {self.epoch + 1}")
230
+ print(f"{'='*60}")
231
+
232
+ return {
233
+ "total_steps": self.global_step,
234
+ "training_time": training_time,
235
+ "best_val_loss": self.best_val_loss,
236
+ }
237
+
238
+ def _train_epoch(self) -> float:
239
+ """Train for one epoch.
240
+
241
+ Returns:
242
+ Average training loss for the epoch
243
+ """
244
+ self.model.train()
245
+ total_loss = 0.0
246
+ num_batches = 0
247
+ accumulated_loss = 0.0
248
+ num_accumulated_batches = 0 # FIX: Track actual number of batches for correct averaging
249
+
250
+ # Create progress bar
251
+ pbar = tqdm(
252
+ enumerate(self.train_dataloader),
253
+ total=len(self.train_dataloader),
254
+ desc=f"Epoch {self.epoch + 1}",
255
+ ncols=100,
256
+ )
257
+
258
+ for step, batch in pbar:
259
+ # Move batch to device
260
+ input_ids = batch["input_ids"].to(self.device)
261
+ labels = batch["labels"].to(self.device)
262
+ # Note: attention_mask from dataloader is padding mask (1/0)
263
+ # The model creates its own causal mask internally
264
+ # We handle padding via -100 labels in the loss function
265
+
266
+ # Forward pass with optional mixed precision
267
+ with autocast(enabled=self.use_amp):
268
+ outputs = self.model(input_ids)
269
+ # Handle different output types (tensor, tuple, or dataclass)
270
+ if isinstance(outputs, torch.Tensor):
271
+ logits = outputs
272
+ elif hasattr(outputs, 'logits'):
273
+ logits = outputs.logits
274
+ else:
275
+ logits = outputs[0]
276
+ loss = self.loss_fn(logits, labels)
277
+ loss = loss / self.config.gradient_accumulation_steps
278
+
279
+ # Backward pass
280
+ if self.use_amp:
281
+ self.scaler.scale(loss).backward()
282
+ else:
283
+ loss.backward()
284
+
285
+ # FIX: Track unscaled loss correctly
286
+ unscaled_loss = loss.item() * self.config.gradient_accumulation_steps
287
+ accumulated_loss += unscaled_loss
288
+ num_accumulated_batches += 1
289
+ total_loss += unscaled_loss
290
+ num_batches += 1
291
+
292
+ # Gradient accumulation
293
+ if (step + 1) % self.config.gradient_accumulation_steps == 0:
294
+ # Gradient clipping
295
+ if self.use_amp:
296
+ self.scaler.unscale_(self.optimizer)
297
+
298
+ grad_norm = clip_grad_norm(self.model, self.config.max_grad_norm)
299
+
300
+ # Optimizer step
301
+ if self.use_amp:
302
+ self.scaler.step(self.optimizer)
303
+ self.scaler.update()
304
+ else:
305
+ self.optimizer.step()
306
+
307
+ self.scheduler.step()
308
+ self.optimizer.zero_grad()
309
+
310
+ self.global_step += 1
311
+
312
+ # Logging
313
+ if self.global_step % self.config.logging_steps == 0:
314
+ # FIX: Divide by actual number of accumulated batches, not gradient_accumulation_steps
315
+ avg_loss = accumulated_loss / max(num_accumulated_batches, 1)
316
+ lr = self.scheduler.get_last_lr()[0]
317
+
318
+ # Update progress bar
319
+ pbar.set_postfix({
320
+ 'loss': f'{avg_loss:.4f}',
321
+ 'lr': f'{lr:.2e}',
322
+ 'step': f'{self.global_step}/{self.total_steps}'
323
+ })
324
+
325
+ tqdm.write(
326
+ f"Step {self.global_step}/{self.total_steps} | "
327
+ f"Loss: {avg_loss:.4f} | "
328
+ f"LR: {lr:.2e} | "
329
+ f"Grad: {grad_norm:.2f}"
330
+ )
331
+
332
+ if self.wandb:
333
+ self.wandb.log({
334
+ "train/loss": avg_loss,
335
+ "train/learning_rate": lr,
336
+ "train/grad_norm": grad_norm,
337
+ "train/epoch": self.epoch,
338
+ }, step=self.global_step)
339
+
340
+ # Reset accumulators
341
+ accumulated_loss = 0.0
342
+ num_accumulated_batches = 0
343
+
344
+ # Evaluation
345
+ if self.config.eval_steps > 0 and self.global_step % self.config.eval_steps == 0:
346
+ if self.val_dataloader is not None:
347
+ val_metrics = self.evaluate()
348
+ print(f" Eval - Loss: {val_metrics['loss']:.4f}, PPL: {val_metrics['perplexity']:.2f}")
349
+
350
+ if self.wandb:
351
+ self.wandb.log({
352
+ "eval/loss": val_metrics["loss"],
353
+ "eval/perplexity": val_metrics["perplexity"],
354
+ }, step=self.global_step)
355
+
356
+ # Early stopping check during training
357
+ if val_metrics["loss"] < self.best_val_loss - self.config.early_stopping_threshold:
358
+ self.best_val_loss = val_metrics["loss"]
359
+ self.early_stopping_counter = 0
360
+ self.save_checkpoint("best")
361
+ print(f" New best model! Loss: {self.best_val_loss:.4f}")
362
+ else:
363
+ self.early_stopping_counter += 1
364
+ if self.config.early_stopping_patience > 0:
365
+ print(f" No improvement ({self.early_stopping_counter}/{self.config.early_stopping_patience})")
366
+ if self.early_stopping_counter >= self.config.early_stopping_patience:
367
+ print(f"\n Early stopping triggered!")
368
+ self.should_stop = True
369
+ break # Exit the training loop
370
+
371
+ # Checkpointing
372
+ if self.config.save_steps > 0 and self.global_step % self.config.save_steps == 0:
373
+ self.save_checkpoint(f"step_{self.global_step}")
374
+
375
+ # Check if early stopping was triggered
376
+ if self.should_stop:
377
+ break
378
+
379
+ return total_loss / max(num_batches, 1)
380
+
381
+ @torch.no_grad()
382
+ def evaluate(self) -> Dict[str, float]:
383
+ """Evaluate the model on validation data.
384
+
385
+ Returns:
386
+ Dictionary with evaluation metrics
387
+ """
388
+ self.model.eval()
389
+ total_loss = 0.0
390
+ total_accuracy = 0.0
391
+ num_batches = 0
392
+
393
+ for batch in self.val_dataloader:
394
+ input_ids = batch["input_ids"].to(self.device)
395
+ labels = batch["labels"].to(self.device)
396
+
397
+ with autocast(enabled=self.use_amp):
398
+ outputs = self.model(input_ids)
399
+ # Handle different output types (tensor, tuple, or dataclass)
400
+ if isinstance(outputs, torch.Tensor):
401
+ logits = outputs
402
+ elif hasattr(outputs, 'logits'):
403
+ logits = outputs.logits
404
+ else:
405
+ logits = outputs[0]
406
+ loss = self.loss_fn(logits, labels)
407
+
408
+ total_loss += loss.item()
409
+ total_accuracy += compute_accuracy(logits, labels).item()
410
+ num_batches += 1
411
+
412
+ self.model.train()
413
+
414
+ avg_loss = total_loss / max(num_batches, 1)
415
+ avg_accuracy = total_accuracy / max(num_batches, 1)
416
+
417
+ return {
418
+ "loss": avg_loss,
419
+ "perplexity": compute_perplexity(torch.tensor(avg_loss)).item(),
420
+ "accuracy": avg_accuracy,
421
+ }
422
+
423
+ def save_checkpoint(self, name: str):
424
+ """Save a checkpoint.
425
+
426
+ Args:
427
+ name: Checkpoint name (e.g., "best", "epoch_1", "step_1000")
428
+ """
429
+ checkpoint_path = os.path.join(self.config.checkpoint_dir, name)
430
+ os.makedirs(checkpoint_path, exist_ok=True)
431
+
432
+ # Save model
433
+ model_path = os.path.join(checkpoint_path, "model.pt")
434
+ torch.save(self.model.state_dict(), model_path)
435
+
436
+ # Save optimizer and scheduler
437
+ optimizer_path = os.path.join(checkpoint_path, "optimizer.pt")
438
+ torch.save({
439
+ "optimizer": self.optimizer.state_dict(),
440
+ "scheduler": self.scheduler.state_dict(),
441
+ "global_step": self.global_step,
442
+ "epoch": self.epoch,
443
+ "best_val_loss": self.best_val_loss,
444
+ "early_stopping_counter": self.early_stopping_counter,
445
+ }, optimizer_path)
446
+
447
+ # Save config
448
+ config_path = os.path.join(checkpoint_path, "config.json")
449
+ with open(config_path, "w") as f:
450
+ json.dump(self.config.to_dict(), f, indent=2)
451
+
452
+ print(f"Saved checkpoint: {checkpoint_path}")
453
+
454
+ # Cleanup old checkpoints
455
+ self._cleanup_checkpoints()
456
+
457
+ def load_checkpoint(self, checkpoint_path: str):
458
+ """Load a checkpoint.
459
+
460
+ Args:
461
+ checkpoint_path: Path to checkpoint directory
462
+ """
463
+ # Load model
464
+ model_path = os.path.join(checkpoint_path, "model.pt")
465
+ state_dict = torch.load(model_path, map_location=self.device)
466
+
467
+ # FIX: Handle torch.compile prefix (_orig_mod.) if present
468
+ if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
469
+ print(" Detected compiled model checkpoint, removing _orig_mod. prefix...")
470
+ state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
471
+
472
+ self.model.load_state_dict(state_dict)
473
+
474
+ # Load optimizer and scheduler
475
+ optimizer_path = os.path.join(checkpoint_path, "optimizer.pt")
476
+ if os.path.exists(optimizer_path):
477
+ state = torch.load(optimizer_path, map_location=self.device)
478
+ self.optimizer.load_state_dict(state["optimizer"])
479
+ self.scheduler.load_state_dict(state["scheduler"])
480
+ self.global_step = state["global_step"]
481
+ self.epoch = state["epoch"]
482
+ self.best_val_loss = state.get("best_val_loss", float("inf"))
483
+ self.early_stopping_counter = state.get("early_stopping_counter", 0)
484
+
485
+ # FIX: Increment epoch to start from next epoch (we saved after completing this epoch)
486
+ # Only if checkpoint was saved at end of epoch (epoch_* checkpoints)
487
+ if "epoch_" in checkpoint_path:
488
+ self.epoch += 1
489
+ print(f" Checkpoint was end-of-epoch, will start from epoch {self.epoch + 1}")
490
+
491
+ print(f"Loaded checkpoint: {checkpoint_path}")
492
+ print(f" Resuming from step {self.global_step}, epoch {self.epoch}")
493
+ print(f" Best val loss so far: {self.best_val_loss:.4f}")
494
+
495
+ def _cleanup_checkpoints(self):
496
+ """Remove old checkpoints to save disk space."""
497
+ if self.config.save_total_limit <= 0:
498
+ return
499
+
500
+ checkpoint_dir = Path(self.config.checkpoint_dir)
501
+ checkpoints = sorted(
502
+ [d for d in checkpoint_dir.iterdir() if d.is_dir() and d.name.startswith("step_")],
503
+ key=lambda x: int(x.name.split("_")[1]),
504
+ )
505
+
506
+ # Keep only the most recent checkpoints (plus "best" and "epoch_*")
507
+ while len(checkpoints) > self.config.save_total_limit:
508
+ old_checkpoint = checkpoints.pop(0)
509
+ print(f"Removing old checkpoint: {old_checkpoint}")
510
+ import shutil
511
+ shutil.rmtree(old_checkpoint)
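For illustration (not part of the commit): a minimal end-to-end sketch that wires the model and this trainer together on toy random batches. Treating `SLMConfig()` defaults as the released architecture is an assumption here; real training uses the DataModule in `src/data` and the repo's actual configuration.

```python
import torch
from torch.utils.data import DataLoader

from src.model.config import SLMConfig            # assumed import paths within this repo
from src.model.transformer import SLMForCausalLM
from src.training.trainer import Trainer, TrainingConfig

config = SLMConfig()                 # assumption: defaults describe the released architecture
model = SLMForCausalLM(config)


def toy_loader(num_samples: int = 32, seq_len: int = 64) -> DataLoader:
    """Random next-token-prediction batches, just to exercise the loop."""
    ids = torch.randint(0, config.vocab_size, (num_samples, seq_len))
    samples = [{"input_ids": x, "labels": x.clone()} for x in ids]
    return DataLoader(samples, batch_size=4)


training_config = TrainingConfig(
    num_epochs=1,
    gradient_accumulation_steps=1,
    fp16=False,                      # keep the demo on CPU-friendly settings
    checkpoint_dir="checkpoints_demo",
    eval_steps=0,
    save_steps=0,
    logging_steps=1,
)

trainer = Trainer(
    model=model,
    config=training_config,
    train_dataloader=toy_loader(),
    val_dataloader=toy_loader(8),
)
results = trainer.train()
print(results)   # {"total_steps": ..., "training_time": ..., "best_val_loss": ...}
```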