StarMist0012 committed on
Commit
3270dae
·
verified ·
1 Parent(s): e2bfccc

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. code/TaoTrain/src/taoTrain.egg-info/PKG-INFO +451 -0
  3. code/TaoTrain/src/taoTrain.egg-info/SOURCES.txt +65 -0
  4. code/TaoTrain/src/taoTrain.egg-info/requires.txt +20 -0
  5. code/TaoTrain/src/taoTrain.egg-info/top_level.txt +1 -0
  6. code/TaoTrain/src/taoTrain/benchmarks/__init__.py +5 -0
  7. code/TaoTrain/src/taoTrain/benchmarks/runner.py +221 -0
  8. code/TaoTrain/src/taoTrain/checkpointing/__init__.py +5 -0
  9. code/TaoTrain/src/taoTrain/checkpointing/checkpoint.py +194 -0
  10. code/TaoTrain/src/taoTrain/core/__init__.py +5 -0
  11. code/TaoTrain/src/taoTrain/core/base.py +271 -0
  12. code/TaoTrain/src/taoTrain/data/__init__.py +56 -0
  13. code/TaoTrain/src/taoTrain/data/async_loader.py +204 -0
  14. code/TaoTrain/src/taoTrain/data/chunk_manager.py +452 -0
  15. code/TaoTrain/src/taoTrain/data/factory.py +108 -0
  16. code/TaoTrain/src/taoTrain/data/hf_base.py +82 -0
  17. code/TaoTrain/src/taoTrain/data/hf_pretrain.py +78 -0
  18. code/TaoTrain/src/taoTrain/data/hf_rl.py +73 -0
  19. code/TaoTrain/src/taoTrain/data/hf_sft.py +81 -0
  20. code/TaoTrain/src/taoTrain/data/jsonl_base.py +220 -0
  21. code/TaoTrain/src/taoTrain/data/loaders.py +85 -0
  22. code/TaoTrain/src/taoTrain/data/pretrain_jsonl.py +65 -0
  23. code/TaoTrain/src/taoTrain/data/rl_jsonl.py +58 -0
  24. code/TaoTrain/src/taoTrain/data/sft_jsonl.py +156 -0
  25. code/TaoTrain/src/taoTrain/data/sft_utils.py +161 -0
  26. code/TaoTrain/src/taoTrain/data/tokenization_queue.py +410 -0
  27. code/TaoTrain/src/taoTrain/data/tokenizer.py +118 -0
  28. code/TaoTrain/src/taoTrain/inference/__init__.py +5 -0
  29. code/TaoTrain/src/taoTrain/inference/inferencer.py +301 -0
  30. code/TaoTrain/src/taoTrain/inference/tui.py +161 -0
  31. code/TaoTrain/src/taoTrain/logging/__init__.py +5 -0
  32. code/TaoTrain/src/taoTrain/logging/aim_logger.py +153 -0
  33. code/TaoTrain/src/taoTrain/models/__init__.py +5 -0
  34. code/TaoTrain/src/taoTrain/models/embeddings.py +51 -0
  35. code/TaoTrain/src/taoTrain/models/mla_components.py +370 -0
  36. code/TaoTrain/src/taoTrain/models/registry.py +73 -0
  37. code/TaoTrain/src/taoTrain/models/taonet.py +248 -0
  38. code/TaoTrain/src/taoTrain/models/taonet_ssm.py +654 -0
  39. code/TaoTrain/src/taoTrain/models/transformer.py +315 -0
  40. code/TaoTrain/src/taoTrain/optimizers/__init__.py +13 -0
  41. code/TaoTrain/src/taoTrain/optimizers/adam.py +64 -0
  42. code/TaoTrain/src/taoTrain/optimizers/adamw.py +64 -0
  43. code/TaoTrain/src/taoTrain/optimizers/hybrid_muon_adamw.py +243 -0
  44. code/TaoTrain/src/taoTrain/optimizers/registry.py +77 -0
  45. code/TaoTrain/src/taoTrain/optimizers/sgd.py +63 -0
  46. code/TaoTrain/src/taoTrain/schedulers/__init__.py +13 -0
  47. code/TaoTrain/src/taoTrain/schedulers/constant.py +44 -0
  48. code/TaoTrain/src/taoTrain/schedulers/cosine_warmup.py +77 -0
  49. code/TaoTrain/src/taoTrain/schedulers/linear_warmup.py +43 -0
  50. code/TaoTrain/src/taoTrain/schedulers/registry.py +78 -0
.gitattributes CHANGED
@@ -3,3 +3,5 @@
3
  *.vocab filter=lfs diff=lfs merge=lfs -text
4
  *.csv filter=lfs diff=lfs merge=lfs -text
5
 
 
 
 
3
  *.vocab filter=lfs diff=lfs merge=lfs -text
4
  *.csv filter=lfs diff=lfs merge=lfs -text
5
 
6
+ code/Taotern_SSM/Gamma[[:space:]]Distributed[[:space:]]Ternary[[:space:]]HiPPO.pdf filter=lfs diff=lfs merge=lfs -text
7
+ code/Taotern_LLM_Experiments/docs/Taotern_Documentation_AI_Architecture.zip filter=lfs diff=lfs merge=lfs -text
code/TaoTrain/src/taoTrain.egg-info/PKG-INFO ADDED
@@ -0,0 +1,451 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.4
2
+ Name: taoTrain
3
+ Version: 0.1.0
4
+ Summary: Clean, modular PyTorch LLM training framework with pluggable architectures, AimStack logging, and TUI inference
5
+ Author-email: Felix <felix@example.com>
6
+ License: MIT
7
+ Requires-Python: >=3.10
8
+ Description-Content-Type: text/markdown
9
+ Requires-Dist: torch>=2.0.0
10
+ Requires-Dist: transformers>=4.30.0
11
+ Requires-Dist: datasets>=2.10.0
12
+ Requires-Dist: pydantic>=2.0.0
13
+ Requires-Dist: pydantic-settings>=2.0.0
14
+ Requires-Dist: aim>=3.15.0
15
+ Requires-Dist: click>=8.1.0
16
+ Requires-Dist: rich>=13.0.0
17
+ Requires-Dist: textual>=0.30.0
18
+ Requires-Dist: numpy>=1.24.0
19
+ Requires-Dist: tqdm>=4.65.0
20
+ Requires-Dist: sentencepiece>=0.1.99
21
+ Provides-Extra: dev
22
+ Requires-Dist: pytest>=7.4.0; extra == "dev"
23
+ Requires-Dist: pytest-cov>=4.1.0; extra == "dev"
24
+ Requires-Dist: pytest-xdist>=3.3.0; extra == "dev"
25
+ Requires-Dist: black>=23.7.0; extra == "dev"
26
+ Requires-Dist: ruff>=0.0.280; extra == "dev"
27
+ Requires-Dist: typing-extensions>=4.7.0; extra == "dev"
28
+
29
+ # TaoTrain: Production-Grade LLM Training Framework
30
+
31
+ **TaoTrain** is a sophisticated PyTorch framework for training large language models at every stage—from experimental pretraining through supervised fine-tuning to reinforcement learning. Unlike fragmented training scripts or heavyweight frameworks, TaoTrain unifies the **entire training pipeline** in a clean, modular codebase that appeals to both ML engineers and software engineers.
32
+
33
+ ## Current Taotern Work
34
+
35
+ TaoTrain now includes the Taotern comparison architectures used by the current SSM LLM work:
36
+
37
+ - `taonet`: the attention/MLA baseline.
38
+ - `taonet_ssm`: the TaoNet shell with the attention mixer replaced by the Gamma Space Model DPLR SSM.
39
+ - `taonet_hybrid`: an alternating attention/SSM TaoNet used for the current best 200M-class candidate.
40
+
41
+ The current selected deployment-oriented run is `hybrid_ssm_first_199m`, a `199,480,928` parameter model with 16 layers: SSM layers at `0,2,4,6,8,10,12,14` and attention layers at `1,3,5,7,9,11,13,15`. It uses the DPLR SSM core with split two-lane mixing, channel gates, per-channel local shift, and the faster convolution path for long-sequence training.
42
+
43
+ Remote run `taotern-200m-hybrid-chat-20260512` trains this model on TaoData for a 4B-token base stage and then runs SFT so the final artifact can be loaded as a chat model. The data-pipeline fixes added for this run are:
44
+
45
+ - Async JSONL iteration keeps polling while tokenization workers are alive instead of ending early after a temporary empty queue.
46
+ - Cached JSONL scan metadata is reused safely while recomputing chunk ranges for the active `samples_per_chunk` and `max_samples` settings.
47
+
48
+ ## Why TaoTrain?
49
+
50
+ - **Complete Unified Pipeline**: Pretraining → SFT → RL in a single, consistent framework. No context switching between different codebases or architectures.
51
+ - **Production-Grade Engineering**: Type-safe Pydantic configs, comprehensive checkpointing, AimStack integration, and proper gradient handling—not research code, but a framework you can deploy.
52
+ - **Extensibility Without Modification**: Register custom models, optimizers, schedulers, and datasets via decorators. Experiment freely without forking the framework.
53
+ - **Developer Experience First**: Interactive TUI for inference, intuitive YAML configurations, async data loading that eliminates I/O bottlenecks, and clear abstractions that make the codebase a pleasure to work with.
54
+
55
+ ## Key Capabilities
56
+
57
+ | Capability | Details |
58
+ |---|---|
59
+ | **Multi-Stage Training** | Unified infrastructure for pretraining, SFT, and RL. Share model checkpoints, logging, and evaluation across stages. |
60
+ | **Advanced Optimization** | Hybrid Muon + AdamW optimizer: efficient 2D weight updates via SVD-based methods + adaptive learning for 1D parameters. |
61
+ | **Modern Architectures** | DeepSeek MLA with grouped query attention (GQA), YaRN context extension, and factorized embeddings—all configurable via YAML. |
62
+ | **Production Features** | BF16 mixed precision training, gradient accumulation, proper gradient clipping, checkpoint resumption, and validation loops. |
63
+ | **Async Data Pipeline** | Background tokenization with multi-threaded workers. Stream billion-token datasets from JSONL without loading into memory. |
64
+ | **Interactive Inference** | TUI chat interface with real-time generation speed metrics and multi-model comparison. |
65
+ | **Logging & Monitoring** | AimStack integration tracks loss, metrics, hyperparameters, and git hashes for reproducibility. Visualize training runs in your browser. |
66
+
67
+ ## Getting Started
68
+
69
+ ### Installation
70
+
71
+ ```bash
72
+ git clone https://github.com/lobakkang/taoTrain.git
73
+ cd taoTrain
74
+ pip install -e .
75
+ ```
76
+
77
+ ### Training Examples
78
+
79
+ **Pretraining on a custom dataset:**
80
+ ```bash
81
+ train pretrain --config configs/pretrain.yaml
82
+ ```
83
+ Starts from scratch, learns representations from raw text via next-token prediction.
84
+
85
+ **Supervised Fine-tuning:**
86
+ ```bash
87
+ train sft --config configs/sft.yaml
88
+ ```
89
+ Fine-tune a pretrained model on instruction-response pairs for improved task performance.
90
+
91
+ **Reinforcement Learning (DPO):**
92
+ ```bash
93
+ train rl --config configs/rl_dpo.yaml
94
+ ```
95
+ Align models with human preferences using Direct Preference Optimization.
96
+
97
+ **Interactive Chat:**
98
+ ```bash
99
+ tui-chat --model checkpoints/model.pt
100
+ ```
101
+ Launch an interactive TUI to chat with your model and monitor generation metrics in real-time.
102
+
103
+ ### Configuration
104
+
105
+ All training is configured via YAML with Pydantic validation. Configs are type-safe and automatically validated:
106
+
107
+ ```yaml
108
+ # configs/sft.yaml
109
+ model:
110
+ architecture_type: "mla" # DeepSeek MLA with GQA
111
+ hidden_dim: 2048
112
+ num_layers: 24
113
+ num_heads: 32
114
+ d_latent_kv: 1536 # KV compression factor
115
+
116
+ training:
117
+ num_epochs: 3
118
+ batch_size: 32
119
+ learning_rate: 1e-4
120
+ warmup_ratio: 0.1
121
+ max_grad_norm: 1.0
122
+
123
+ optimizer:
124
+ optimizer_type: "muon_adamw" # Hybrid Muon + AdamW
125
+ muon_momentum: 0.95
126
+
127
+ data:
128
+ dataset_type: "sft_jsonl" # or "sft_hf" for HuggingFace
129
+ path: "data/sft_training.jsonl"
130
+
131
+ logging:
132
+ log_to_aim: true
133
+ aim_repo: "/tmp/aim_logs"
134
+ ```
135
+
136
+ See `configs/` for complete examples.
137
+
138
+ ## Project Architecture
139
+
140
+ ```
141
+ src/taoTrain/
142
+ ├── cli.py # Main CLI entry point
143
+ ├── config.py # Pydantic configuration schemas
144
+
145
+ ├── core/ # Base abstractions
146
+ │ └── base.py # BaseModel, BaseDataset, BaseTrainer
147
+
148
+ ├── models/ # Pluggable architecture system
149
+ │ ├── registry.py # Architecture factory with @register_architecture
150
+ │ ├── taonet.py # SimpleLLM with DeepSeek MLA
151
+ │ ├── mla_components.py # KV compression, GQA, YaRN
152
+ │ ├── embeddings.py # Factorized embeddings
153
+ │ └── transformer.py # Standard Transformer reference
154
+
155
+ ├── data/ # Advanced data pipeline
156
+ │ ├── factory.py # Dataset factory (HF + JSONL backends)
157
+ │ ├── async_loader.py # Async batch iteration (no I/O bottleneck)
158
+ │ ├── tokenization_queue.py # Background multi-threaded tokenization
159
+ │ ├── chunk_manager.py # Stream billion-token JSONL files
160
+ │ ├── hf_pretrain.py # HuggingFace pretraining datasets
161
+ │ ├── hf_sft.py # HuggingFace SFT datasets
162
+ │ ├── hf_rl.py # HuggingFace RL datasets
163
+ │ ├── pretrain_jsonl.py # JSONL pretraining
164
+ │ ├── sft_jsonl.py # JSONL SFT with instructions
165
+ │ └── rl_jsonl.py # JSONL RL with preferences
166
+
167
+ ├── training/ # Unified training infrastructure
168
+ │ └── trainer.py # Trainer + PretrainTrainer, SFTTrainer, RLTrainer
169
+
170
+ ├── optimizers/ # Pluggable optimizer system
171
+ │ ├── registry.py # Optimizer factory with @register_optimizer
172
+ │ ├── hybrid_muon_adamw.py # Composite: Muon (2D) + AdamW (1D)
173
+ │ ├── adamw.py # AdamW with weight decay
174
+ │ ├── adam.py # Standard Adam
175
+ │ └── sgd.py # SGD variants
176
+
177
+ ├── schedulers/ # Learning rate schedules
178
+ │ ├── registry.py # LR scheduler factory
179
+ │ ├── cosine_warmup.py # 3-phase: linear warmup → plateau → cosine decay
180
+ │ ├── linear_warmup.py # Linear warmup + constant
181
+ │ └── constant.py # Constant learning rate
182
+
183
+ ├── inference/ # Inference & interaction
184
+ │ ├── inferencer.py # Load & run inference from checkpoints
185
+ │ └── tui.py # Interactive chat with metrics display
186
+
187
+ ├── checkpointing/ # State management
188
+ │ └── checkpoint.py # Save/load model + optimizer + config + metrics
189
+
190
+ ├── logging/ # Experiment tracking
191
+ │ └── aim_logger.py # AimStack integration (loss, metrics, hyperparams)
192
+
193
+ ├── benchmarks/ # Evaluation tools
194
+ │ └── runner.py # Perplexity, speed, and task-specific benchmarks
195
+
196
+ └── utils/
197
+ └── helpers.py # Utility functions
198
+
199
+ configs/ # Example YAML configurations
200
+ ├── pretrain.yaml # Pretraining config
201
+ ├── sft.yaml # SFT config
202
+ ├── rl_dpo.yaml # RL/DPO config
203
+ └── tokenizer.yaml # Tokenizer config
204
+
205
+ tests/ # Unit & integration tests
206
+ └── test_dataset.py
207
+ ```
208
+
209
+ ## Extensible Architecture: The Registry Pattern
210
+
211
+ TaoTrain's power lies in its **pluggable design**. Add custom models, optimizers, schedulers, and datasets without modifying the framework.
212
+
213
+ ### Custom Model Architecture
214
+
215
+ ```python
216
+ from taoTrain.models import register_architecture, BaseModel
217
+ import torch.nn as nn
218
+
219
+ @register_architecture("custom_moe")
220
+ class MixtureOfExperts(BaseModel):
221
+ """Your custom MoE architecture"""
222
+ def __init__(self, config):
223
+ super().__init__(config)
224
+ self.experts = nn.ModuleList([
225
+ nn.Linear(config.hidden_dim, config.hidden_dim)
226
+ for _ in range(config.num_experts)
227
+ ])
228
+ self.router = nn.Linear(config.hidden_dim, config.num_experts)
229
+
230
+ def forward(self, input_ids, attention_mask=None, labels=None):
231
+ # Your implementation
232
+ logits = self.compute_logits(input_ids)
233
+ loss = self.compute_loss(logits, labels) if labels is not None else None
234
+ return {"logits": logits, "loss": loss}
235
+ ```
236
+
237
+ Then use it in your config:
238
+
239
+ ```yaml
240
+ model:
241
+ architecture_type: "custom_moe"
242
+ hidden_dim: 2048
243
+ num_experts: 8
244
+ ```
245
+
246
+ ### Custom Optimizers & Schedulers
247
+
248
+ The same pattern works for optimizers and learning rate schedules:
249
+
250
+ ```python
251
+ from taoTrain.optimizers import register_optimizer
252
+ from torch.optim import Optimizer
253
+
254
+ @register_optimizer("my_adaptive_optimizer")
255
+ class MyAdaptiveOptimizer(Optimizer):
256
+ def step(self, closure=None):
257
+ # Your optimization logic
258
+ pass
259
+ ```
260
+
261
+ ```python
262
+ from taoTrain.schedulers import register_scheduler
263
+
264
+ @register_scheduler("my_schedule")
265
+ def my_schedule(initial_lr, step, total_steps, **kwargs):
266
+ return initial_lr * (1.0 - step / total_steps) # Linear decay
267
+ ```
268
+
269
+ **The key principle**: No framework code needs to change. You register once, it's available everywhere.
270
+
271
+ ### Dataset Backend Flexibility
272
+
273
+ Define custom datasets (JSONL, HF, streaming, etc.) and let the factory route to them:
274
+
275
+ ```python
276
+ from taoTrain.data import register_dataset
277
+
278
+ @register_dataset("pretrain", "my_backend")
279
+ class MyPretrainDataset(BaseDataset):
280
+ def __init__(self, config):
281
+ # Load from your custom backend
282
+ pass
283
+
284
+ def __getitem__(self, idx):
285
+ return {"input_ids": ..., "attention_mask": ...}
286
+ ```
287
+
288
+ Use in config:
289
+
290
+ ```yaml
291
+ data:
292
+ dataset_type: "pretrain"
293
+ backend_type: "my_backend" # Routes to MyPretrainDataset
294
+ ```
295
+
296
+ ## Why the TaoTrain Framework?
297
+
298
+ ### Async Data Loading: No I/O Bottleneck
299
+
300
+ Most training frameworks load and tokenize data on the main training thread, blocking compute. TaoTrain's **multi-threaded tokenization pipeline**:
301
+
302
+ - Tokenizes data in background workers while your GPU trains
303
+ - Supports streaming billion-token JSONL files without loading into memory
304
+ - Intelligent chunking (by file size or sample count)
305
+ - Metadata caching to avoid rescanning
306
+
307
+ **Result**: 10-100x faster data iteration on large datasets.
308
+
309
+ ### Type-Safe Configuration
310
+
311
+ Forget YAML parsing errors or mysterious config bugs. TaoTrain uses **Pydantic dataclasses** for configuration:
312
+
313
+ - Automatic type validation: a mistyped `learning_rate: "1e-4"` becomes an error, not a silent failure
314
+ - Serialization: configs are part of checkpoints, ensuring reproducibility
315
+ - IDE support: autocomplete and type hints for all config fields
316
+ - Defaults: sensible defaults for all parameters
317
+
318
+ ### Benchmarking & Metrics
319
+
320
+ Track what matters:
321
+
322
+ - **Perplexity**: Language modeling quality on held-out data
323
+ - **Generation Speed**: Tokens-per-second (useful for TUI or deployment)
324
+ - **Task-Specific Accuracy**: Evaluate on downstream tasks
325
+ - **Training Metrics**: Loss curves, gradient norms, effective batch size
326
+
327
+ All logged to AimStack with git hashes for reproducibility.
328
+
329
+ ## Logging with AimStack
330
+
331
+ Automatically track and visualize experiments:
332
+
333
+ ```bash
334
+ aim up --host 0.0.0.0
335
+ ```
336
+
337
+ Then open `http://localhost:43800` to see:
338
+
339
+ - **Loss curves** per training step
340
+ - **Hyperparameters** (learning rate, batch size, model architecture)
341
+ - **Git hashes** for reproducibility
342
+ - **Custom metrics** (perplexity, validation accuracy, generation speed)
343
+ - **Compare runs**: Side-by-side experiment comparison
344
+
345
+ ## Advanced Features
346
+
347
+ ### Checkpointing with Resumption
348
+
349
+ TaoTrain saves complete training state:
350
+
351
+ ```python
352
+ checkpoint = {
353
+ "step": 12500,
354
+ "model_state": model.state_dict(),
355
+ "optimizer_state": optimizer.state_dict(),
356
+ "config": config, # Full config as Pydantic object
357
+ "metrics": metrics_tracker.to_dict(),
358
+ }
359
+ ```
360
+
361
+ Resume training from any checkpoint without loss of state. Keep last N checkpoints automatically.
362
+
363
+ ### Mixed Precision Training (BF16)
364
+
365
+ ```yaml
366
+ training:
367
+ use_bfloat16: true
368
+ gradient_accumulation_steps: 4
369
+ ```
370
+
371
+ - BF16 via `torch.autocast` for ~2x speedup with minimal accuracy loss
372
+ - Proper gradient scaling and clipping
373
+ - Compatible with all optimizers and architectures
374
+
375
+ ### 3-Phase Learning Rate Schedule
376
+
377
+ ```yaml
378
+ scheduler:
379
+ scheduler_type: "cosine_warmup"
380
+ warmup_ratio: 0.1 # 10% of training steps
381
+ steady_ratio: 0.5 # 50% at steady rate
382
+ min_lr_ratio: 0.1 # Final LR = 0.1 × initial_lr
383
+ num_cycles: 1
384
+ ```
385
+
386
+ This schedule:
387
+ 1. **Linear warmup** (0 → 1) over 10% of steps
388
+ 2. **Steady plateau** at full LR over 50% of steps
389
+ 3. **Cosine decay** (1 → 0.1) over remaining 40% of steps
390
+
391
+ Better convergence than simple cosine or linear decay.
392
+
393
+ ### Gradient Accumulation & Clipping
394
+
395
+ Simulate larger batch sizes with gradient accumulation:
396
+
397
+ ```yaml
398
+ training:
399
+ batch_size: 32
400
+ gradient_accumulation_steps: 4 # Effective batch = 128
401
+ max_grad_norm: 1.0 # Gradient clipping
402
+ ```
403
+
404
+ ## Contributing
405
+
406
+ Contributions are welcome! TaoTrain is designed to make contributions easy:
407
+
408
+ 1. **Add a model**: Subclass `BaseModel` and decorate it with `@register_architecture("name")`
409
+ 2. **Add an optimizer**: Subclass `torch.optim.Optimizer` and decorate it with `@register_optimizer("name")`
410
+ 3. **Add a dataset**: Subclass `BaseDataset` and decorate it with `@register_dataset(mode, backend_type)`
411
+ 4. **Improve the core**: Submit PRs to `training/`, `data/`, `logging/`, etc.
412
+
413
+ Ensure new code includes:
414
+ - Type hints throughout
415
+ - Pydantic configs for new parameters
416
+ - Unit tests in `tests/`
417
+ - Documentation in docstrings and README
418
+
419
+ ## Current Scope & Roadmap
420
+
421
+ ### ✅ Currently Supported
422
+
423
+ - **Single GPU / single node** training
424
+ - **Pretraining, SFT, and RL training** stages
425
+ - **HuggingFace and JSONL** data backends
426
+ - **BF16 mixed precision** training
427
+ - **Checkpoint saving/loading** with resumption
428
+ - **Interactive inference** via TUI
429
+ - **Benchmarking** (perplexity, speed)
430
+ - **Pluggable architectures, optimizers, schedulers, datasets**
431
+
432
+ ### 🚀 Roadmap (Future)
433
+
434
+ - **Distributed training** (DDP, FSDP) for multi-GPU/multi-node scaling
435
+ - **Quantization** support (INT8, QLoRA)
436
+ - **Advanced evaluation** (BLEU, ROUGE, custom tasks)
437
+ - **Streaming inference** with KV cache
438
+ - **Speculative decoding** for faster generation
439
+ - **Integration with popular model hubs** (Hugging Face Hub upload/download)
440
+
441
+ ---
442
+
443
+ ## Getting Help
444
+
445
+ - **Questions?** Open an issue on GitHub
446
+ - **Want to contribute?** See `CONTRIBUTING.md` (coming soon)
447
+ - **Found a bug?** Report it with a minimal reproduction script
448
+
449
+ ## License
450
+
451
+ MIT
code/TaoTrain/src/taoTrain.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ README.md
2
+ pyproject.toml
3
+ src/taoTrain/__init__.py
4
+ src/taoTrain/cli.py
5
+ src/taoTrain/config.py
6
+ src/taoTrain.egg-info/PKG-INFO
7
+ src/taoTrain.egg-info/SOURCES.txt
8
+ src/taoTrain.egg-info/dependency_links.txt
9
+ src/taoTrain.egg-info/entry_points.txt
10
+ src/taoTrain.egg-info/requires.txt
11
+ src/taoTrain.egg-info/top_level.txt
12
+ src/taoTrain/benchmarks/__init__.py
13
+ src/taoTrain/benchmarks/runner.py
14
+ src/taoTrain/checkpointing/__init__.py
15
+ src/taoTrain/checkpointing/checkpoint.py
16
+ src/taoTrain/core/__init__.py
17
+ src/taoTrain/core/base.py
18
+ src/taoTrain/data/__init__.py
19
+ src/taoTrain/data/async_loader.py
20
+ src/taoTrain/data/chunk_manager.py
21
+ src/taoTrain/data/factory.py
22
+ src/taoTrain/data/hf_base.py
23
+ src/taoTrain/data/hf_pretrain.py
24
+ src/taoTrain/data/hf_rl.py
25
+ src/taoTrain/data/hf_sft.py
26
+ src/taoTrain/data/jsonl_base.py
27
+ src/taoTrain/data/loaders.py
28
+ src/taoTrain/data/pretrain_jsonl.py
29
+ src/taoTrain/data/rl_jsonl.py
30
+ src/taoTrain/data/sft_jsonl.py
31
+ src/taoTrain/data/sft_utils.py
32
+ src/taoTrain/data/tokenization_queue.py
33
+ src/taoTrain/data/tokenizer.py
34
+ src/taoTrain/inference/__init__.py
35
+ src/taoTrain/inference/inferencer.py
36
+ src/taoTrain/inference/tui.py
37
+ src/taoTrain/logging/__init__.py
38
+ src/taoTrain/logging/aim_logger.py
39
+ src/taoTrain/models/__init__.py
40
+ src/taoTrain/models/embeddings.py
41
+ src/taoTrain/models/mla_components.py
42
+ src/taoTrain/models/registry.py
43
+ src/taoTrain/models/taonet.py
44
+ src/taoTrain/models/taonet_ssm.py
45
+ src/taoTrain/models/transformer.py
46
+ src/taoTrain/optimizers/__init__.py
47
+ src/taoTrain/optimizers/adam.py
48
+ src/taoTrain/optimizers/adamw.py
49
+ src/taoTrain/optimizers/hybrid_muon_adamw.py
50
+ src/taoTrain/optimizers/registry.py
51
+ src/taoTrain/optimizers/sgd.py
52
+ src/taoTrain/schedulers/__init__.py
53
+ src/taoTrain/schedulers/constant.py
54
+ src/taoTrain/schedulers/cosine_warmup.py
55
+ src/taoTrain/schedulers/linear_warmup.py
56
+ src/taoTrain/schedulers/registry.py
57
+ src/taoTrain/tokenizers/__init__.py
58
+ src/taoTrain/tokenizers/trainer.py
59
+ src/taoTrain/training/__init__.py
60
+ src/taoTrain/training/trainer.py
61
+ src/taoTrain/utils/__init__.py
62
+ src/taoTrain/utils/helpers.py
63
+ tests/test_dataset.py
64
+ tests/test_sft_masking.py
65
+ tests/test_taonet_ssm.py
code/TaoTrain/src/taoTrain.egg-info/requires.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ transformers>=4.30.0
3
+ datasets>=2.10.0
4
+ pydantic>=2.0.0
5
+ pydantic-settings>=2.0.0
6
+ aim>=3.15.0
7
+ click>=8.1.0
8
+ rich>=13.0.0
9
+ textual>=0.30.0
10
+ numpy>=1.24.0
11
+ tqdm>=4.65.0
12
+ sentencepiece>=0.1.99
13
+
14
+ [dev]
15
+ pytest>=7.4.0
16
+ pytest-cov>=4.1.0
17
+ pytest-xdist>=3.3.0
18
+ black>=23.7.0
19
+ ruff>=0.0.280
20
+ typing-extensions>=4.7.0
code/TaoTrain/src/taoTrain.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ taoTrain
code/TaoTrain/src/taoTrain/benchmarks/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Benchmarking suite."""
2
+
3
+ from .runner import BenchmarkRunner
4
+
5
+ __all__ = ["BenchmarkRunner"]
code/TaoTrain/src/taoTrain/benchmarks/runner.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Benchmarking suite for evaluating trained models."""
2
+
3
+ import time
4
+ from pathlib import Path
5
+ from typing import Optional, Dict
6
+ import torch
7
+ from torch.utils.data import DataLoader
8
+
9
+ from taoTrain.core import BaseModel
10
+ from taoTrain.config import TrainingConfig
11
+ from taoTrain.data.loaders import get_dataloader
12
+ from taoTrain.inference import Inferencer
13
+
14
+
15
+ class BenchmarkRunner:
16
+ """Run benchmarks on a trained model."""
17
+
18
+ def __init__(
19
+ self,
20
+ model: BaseModel,
21
+ device: torch.device,
22
+ dtype: torch.dtype = torch.float32,
23
+ ):
24
+ """
25
+ Initialize benchmark runner.
26
+
27
+ Args:
28
+ model: Trained model
29
+ device: Device for inference
30
+ dtype: Data type
31
+ """
32
+ self.model = model.to(device)
33
+ self.model.eval()
34
+ self.device = device
35
+ self.dtype = dtype
36
+
37
+ @staticmethod
38
+ def load_from_checkpoint(
39
+ checkpoint_path: str | Path,
40
+ device: Optional[torch.device] = None,
41
+ ) -> "BenchmarkRunner":
42
+ """Load model from checkpoint."""
43
+ if device is None:
44
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45
+
46
+ checkpoint = torch.load(checkpoint_path, map_location=device)
47
+
48
+ # Reconstruct model config
49
+ from taoTrain.config import ModelConfig
50
+ from taoTrain.models import get_model
51
+
52
+ model_config = ModelConfig(**checkpoint.get("config", {}).get("model", {}))
53
+ model = get_model(model_config, device=device)
54
+ model.load_state_dict(checkpoint["model_state_dict"])
55
+
56
+ return BenchmarkRunner(model, device)
57
+
58
+ def benchmark_perplexity(
59
+ self,
60
+ dataset: "DataLoader",
61
+ num_batches: Optional[int] = None,
62
+ ) -> float:
63
+ """
64
+ Compute perplexity on a dataset.
65
+
66
+ Args:
67
+ dataset: DataLoader for evaluation
68
+ num_batches: Limit evaluation to N batches
69
+
70
+ Returns:
71
+ Perplexity (exp of average loss)
72
+ """
73
+ total_loss = 0.0
74
+ total_tokens = 0
75
+
76
+ with torch.no_grad():
77
+ for batch_idx, batch in enumerate(dataset):
78
+ if num_batches and batch_idx >= num_batches:
79
+ break
80
+
81
+ # Move to device
82
+ input_ids = batch["input_ids"].to(self.device)
83
+ attention_mask = batch.get("attention_mask")
84
+ if attention_mask is not None:
85
+ attention_mask = attention_mask.to(self.device)
86
+ labels = batch.get("labels")
87
+ if labels is not None:
88
+ labels = labels.to(self.device)
89
+
90
+ # Forward pass
91
+ with torch.autocast(
92
+ device_type="cuda" if self.device.type == "cuda" else "cpu",
93
+ dtype=torch.bfloat16 if self.dtype == torch.bfloat16 else torch.float32,
94
+ ):
95
+ outputs = self.model(
96
+ input_ids=input_ids,
97
+ attention_mask=attention_mask,
98
+ labels=labels,
99
+ )
100
+ loss = outputs.get("loss")
101
+
102
+ if loss is not None:
103
+ total_loss += loss.item() * input_ids.shape[0]
104
+ total_tokens += input_ids.shape[0]
105
+
106
+ avg_loss = total_loss / total_tokens if total_tokens > 0 else float('inf')
107
+ perplexity = torch.exp(torch.tensor(avg_loss)).item()
108
+
109
+ return perplexity
110
+
111
+ def benchmark_throughput(
112
+ self,
113
+ batch_size: int = 32,
114
+ seq_length: int = 1024,
115
+ num_iters: int = 10,
116
+ ) -> Dict[str, float]:
117
+ """
118
+ Benchmark forward pass throughput.
119
+
120
+ Args:
121
+ batch_size: Batch size
122
+ seq_length: Sequence length
123
+ num_iters: Number of iterations
124
+
125
+ Returns:
126
+ Dict with throughput metrics
127
+ """
128
+ # Create dummy batch
129
+ dummy_input = torch.randint(
130
+ 0, self.model.config.vocab_size,
131
+ (batch_size, seq_length)
132
+ ).to(self.device)
133
+
134
+ # Warmup
135
+ with torch.no_grad():
136
+ for _ in range(2):
137
+ _ = self.model(dummy_input)
138
+
139
+ torch.cuda.synchronize() if torch.cuda.is_available() else None
140
+
141
+ # Benchmark forward pass
142
+ start = time.time()
143
+
144
+ with torch.no_grad():
145
+ for _ in range(num_iters):
146
+ _ = self.model(dummy_input)
147
+
148
+ torch.cuda.synchronize() if torch.cuda.is_available() else None
149
+
150
+ elapsed = time.time() - start
151
+
152
+ total_tokens = batch_size * seq_length * num_iters
153
+ tokens_per_sec = total_tokens / elapsed
154
+
155
+ return {
156
+ "throughput_tokens_per_sec": tokens_per_sec,
157
+ "throughput_samples_per_sec": (batch_size * num_iters) / elapsed,
158
+ "avg_time_per_iter_ms": (elapsed / num_iters) * 1000,
159
+ }
160
+
161
+ def benchmark_memory(self) -> Dict[str, float]:
162
+ """
163
+ Benchmark peak GPU memory usage.
164
+
165
+ Returns:
166
+ Dict with memory stats
167
+ """
168
+ if not torch.cuda.is_available():
169
+ return {"peak_memory_gb": 0.0}
170
+
171
+ torch.cuda.reset_peak_memory_stats()
172
+ torch.cuda.synchronize()
173
+
174
+ # Create dummy batch
175
+ dummy_input = torch.randint(
176
+ 0, self.model.config.vocab_size,
177
+ (16, 1024)
178
+ ).to(self.device)
179
+
180
+ with torch.no_grad():
181
+ _ = self.model(dummy_input)
182
+
183
+ torch.cuda.synchronize()
184
+
185
+ peak_memory = torch.cuda.max_memory_allocated() / (1024 ** 3) # GB
186
+
187
+ return {"peak_memory_gb": peak_memory}
188
+
189
+ def run_all_benchmarks(
190
+ self,
191
+ dataset: Optional["DataLoader"] = None,
192
+ batch_size: int = 32,
193
+ seq_length: int = 1024,
194
+ ) -> Dict[str, float]:
195
+ """
196
+ Run all benchmarks.
197
+
198
+ Args:
199
+ dataset: DataLoader for perplexity benchmark
200
+ batch_size: Batch size for throughput benchmark
201
+ seq_length: Sequence length for throughput benchmark
202
+
203
+ Returns:
204
+ Dict with all benchmark results
205
+ """
206
+ results = {}
207
+
208
+ if dataset is not None:
209
+ print("Running perplexity benchmark...")
210
+ ppl = self.benchmark_perplexity(dataset, num_batches=10)
211
+ results["perplexity"] = ppl
212
+
213
+ print("Running throughput benchmark...")
214
+ throughput = self.benchmark_throughput(batch_size, seq_length)
215
+ results.update(throughput)
216
+
217
+ print("Running memory benchmark...")
218
+ memory = self.benchmark_memory()
219
+ results.update(memory)
220
+
221
+ return results
code/TaoTrain/src/taoTrain/checkpointing/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Checkpoint management."""
2
+
3
+ from .checkpoint import CheckpointManager
4
+
5
+ __all__ = ["CheckpointManager"]
code/TaoTrain/src/taoTrain/checkpointing/checkpoint.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Checkpoint management utilities.
2
+
3
+ Canonical Checkpoint Format (new):
4
+ {
5
+ 'step': int, # Training step number
6
+ 'model_state': Dict[str, Tensor], # Model state dict
7
+ 'optimizer_state': Dict, # Optimizer state dict (optional)
8
+ 'config': Dict, # TrainingConfig as dict
9
+ 'metrics': Dict[str, float], # Training metrics
10
+ 'global_step': int, # (deprecated, kept for compat) same as step
11
+ 'current_epoch': int, # (optional) current epoch number
12
+ 'best_loss': float, # (optional) best validation loss
13
+ }
14
+
15
+ Legacy Checkpoint Format (old, from BaseTrainer):
16
+ {
17
+ 'global_step': int,
18
+ 'current_epoch': int,
19
+ 'best_loss': float,
20
+ 'model_state_dict': Dict[str, Tensor], # ← Note: uses '_dict' suffix
21
+ 'optimizer_state_dict': Dict,
22
+ 'config': Dict,
23
+ }
24
+
25
+ The load() function auto-detects and migrates legacy format to canonical format.
26
+ """
27
+
28
+ from pathlib import Path
29
+ from typing import Dict, Any, Optional
30
+ import torch
31
+ from taoTrain.config import TrainingConfig
32
+
33
+
34
+ class CheckpointManager:
35
+ """Manage model checkpoints with versioning."""
36
+
37
+ def __init__(
38
+ self,
39
+ checkpoint_dir: str | Path,
40
+ keep_last_n: int = 3,
41
+ track_best: bool = True,
42
+ ):
43
+ """
44
+ Initialize checkpoint manager.
45
+
46
+ Args:
47
+ checkpoint_dir: Directory to save checkpoints
48
+ keep_last_n: Number of recent checkpoints to keep
49
+ track_best: Whether to track best model
50
+ """
51
+ self.checkpoint_dir = Path(checkpoint_dir)
52
+ self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
53
+
54
+ self.keep_last_n = keep_last_n
55
+ self.track_best = track_best
56
+
57
+ self.best_metric = None
58
+ self.best_metric_name = None
59
+ self.saved_checkpoints = []
60
+
61
+ def save(
62
+ self,
63
+ step: int,
64
+ model_state: Dict[str, Any],
65
+ optimizer_state: Optional[Dict[str, Any]] = None,
66
+ config: Optional[TrainingConfig] = None,
67
+ metrics: Optional[Dict[str, float]] = None,
68
+ is_best: bool = False,
69
+ ) -> Path:
70
+ """
71
+ Save a checkpoint.
72
+
73
+ Args:
74
+ step: Training step
75
+ model_state: Model state dict
76
+ optimizer_state: Optimizer state dict
77
+ config: Training config
78
+ metrics: Metrics dict
79
+ is_best: Whether this is the best model so far
80
+
81
+ Returns:
82
+ Path to saved checkpoint
83
+ """
84
+ checkpoint = {
85
+ "step": step,
86
+ "model_state": model_state,
87
+ "optimizer_state": optimizer_state,
88
+ "config": config.to_dict() if config else None,
89
+ "metrics": metrics or {},
90
+ }
91
+
92
+ filename = f"checkpoint_step_{step:06d}.pt"
93
+ if is_best:
94
+ filename = "best_model.pt"
95
+
96
+ path = self.checkpoint_dir / filename
97
+ torch.save(checkpoint, path)
98
+
99
+ # Track saved checkpoints
100
+ if not is_best:
101
+ self.saved_checkpoints.append((step, path))
102
+
103
+ # Clean up old checkpoints
104
+ if len(self.saved_checkpoints) > self.keep_last_n:
105
+ _, old_path = self.saved_checkpoints.pop(0)
106
+ if old_path.exists():
107
+ old_path.unlink()
108
+
109
+ return path
110
+
111
+ def load(
112
+ self,
113
+ checkpoint_path: str | Path,
114
+ device: Optional[torch.device] = None,
115
+ ) -> Dict[str, Any]:
116
+ """
117
+ Load a checkpoint with backward-compatible format handling.
118
+
119
+ Auto-detects checkpoint format (canonical or legacy) and normalizes
120
+ to canonical format in-memory. Legacy checkpoints are migrated without
121
+ modifying the file.
122
+
123
+ Args:
124
+ checkpoint_path: Path to checkpoint
125
+ device: Device to load to
126
+
127
+ Returns:
128
+ Checkpoint dict in canonical format with 'model_state' key
129
+ """
130
+ if device is None:
131
+ device = torch.device("cpu")
132
+
133
+ checkpoint = torch.load(checkpoint_path, map_location=device)
134
+
135
+ # Auto-detect and migrate legacy format to canonical format
136
+ checkpoint = self._normalize_checkpoint_format(checkpoint)
137
+
138
+ return checkpoint
139
+
140
+ def _normalize_checkpoint_format(self, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
141
+ """
142
+ Normalize checkpoint to canonical format.
143
+
144
+ Detects if checkpoint is in legacy format (from BaseTrainer with 'model_state_dict')
145
+ and migrates it to canonical format (with 'model_state').
146
+
147
+ Args:
148
+ checkpoint: Raw checkpoint dict
149
+
150
+ Returns:
151
+ Normalized checkpoint dict with canonical keys
152
+ """
153
+ # Check if this is a legacy checkpoint (has 'model_state_dict' but not 'model_state')
154
+ if "model_state_dict" in checkpoint and "model_state" not in checkpoint:
155
+ # Migrate legacy format to canonical
156
+ migrated = {
157
+ "step": checkpoint.get("global_step", 0),
158
+ "model_state": checkpoint["model_state_dict"],
159
+ "optimizer_state": checkpoint.get("optimizer_state_dict"),
160
+ "config": checkpoint.get("config"),
161
+ "metrics": {},
162
+ # Keep legacy keys for backward compatibility in code that uses them
163
+ "global_step": checkpoint.get("global_step", 0),
164
+ "current_epoch": checkpoint.get("current_epoch", 0),
165
+ "best_loss": checkpoint.get("best_loss", float('inf')),
166
+ }
167
+ print(f"\n✓ [CheckpointManager] Detected legacy checkpoint format. Auto-migrated to canonical format.")
168
+ return migrated
169
+
170
+ # Already in canonical format or unknown format
171
+ if "model_state" not in checkpoint:
172
+ # If neither format detected, ensure model_state is accessible
173
+ # (might be a raw state_dict)
174
+ print(f"\n⚠ [CheckpointManager] Checkpoint format unclear. Assuming raw state_dict format.")
175
+ checkpoint["model_state"] = checkpoint
176
+
177
+ return checkpoint
178
+
179
+ def get_latest(self) -> Optional[Path]:
180
+ """Get path to latest checkpoint."""
181
+ if not self.saved_checkpoints:
182
+ return None
183
+ return self.saved_checkpoints[-1][1]
184
+
185
+ def get_best(self) -> Optional[Path]:
186
+ """Get path to best checkpoint."""
187
+ best_path = self.checkpoint_dir / "best_model.pt"
188
+ if best_path.exists():
189
+ return best_path
190
+ return None
191
+
192
+ def list_checkpoints(self) -> list[Path]:
193
+ """List all saved checkpoints."""
194
+ return sorted(self.checkpoint_dir.glob("checkpoint_step_*.pt"))
code/TaoTrain/src/taoTrain/core/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Base classes for models, trainers, and datasets."""
2
+
3
+ from .base import BaseModel, BaseTrainer, BaseDataset, create_model, create_datasets
4
+
5
+ __all__ = ["BaseModel", "BaseTrainer", "BaseDataset", "create_model", "create_datasets"]
code/TaoTrain/src/taoTrain/core/base.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Base classes for models, trainers, and datasets."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from pathlib import Path
5
+ from typing import Optional, Any, Iterator
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.utils.data import Dataset as TorchDataset
9
+ from taoTrain.config import TrainingConfig, ModelConfig
10
+
11
+
12
+ # ============================================================================
13
+ # Base Model
14
+ # ============================================================================
15
+
16
+
17
class BaseModel(nn.Module, ABC):
    """Abstract base class that every language model implementation extends."""

    def __init__(self, config: ModelConfig):
        """Store the model config after initializing the nn.Module machinery."""
        super().__init__()
        self.config = config

    @abstractmethod
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> dict[str, torch.Tensor]:
        """
        Run the forward pass (must be implemented by subclasses).

        Args:
            input_ids: Shape (batch_size, seq_length)
            attention_mask: Shape (batch_size, seq_length), optional
            labels: Shape (batch_size, seq_length), optional (for loss computation)

        Returns:
            Dict with keys:
                - 'logits': Shape (batch_size, seq_length, vocab_size)
                - 'loss': Scalar (if labels provided)
        """
        ...

    def count_parameters(self) -> int:
        """Return the total number of elements across trainable parameters."""
        trainable = 0
        for param in self.parameters():
            if param.requires_grad:
                trainable += param.numel()
        return trainable

    def get_num_layers(self) -> int:
        """Return the layer count as declared in the model config."""
        num_layers = self.config.num_layers
        return num_layers
54
+
55
+
56
+ # ============================================================================
57
+ # Base Dataset
58
+ # ============================================================================
59
+
60
+
61
class BaseDataset(TorchDataset, ABC):
    """Abstract base class that every dataset implementation extends."""

    def __init__(self, config: "TrainingConfig"):
        """Remember the config; the data itself starts out unloaded (None)."""
        self.config, self.data = config, None

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        ...

    @abstractmethod
    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """
        Fetch one sample by index.

        Returns:
            Dict with keys:
                - 'input_ids': 1D tensor of token IDs
                - 'attention_mask': 1D tensor of attention mask
                - 'labels': 1D tensor of labels (optional)
        """
        ...

    def load_dataset(self) -> None:
        """Hook for loading data from HuggingFace or another source (no-op here)."""
        return None

    def preprocess(self) -> None:
        """Hook for preprocessing such as tokenization (no-op here)."""
        return None
94
+
95
+
96
+ # ============================================================================
97
+ # Base Trainer
98
+ # ============================================================================
99
+
100
+
101
class BaseTrainer(ABC):
    """Abstract base class for trainers.

    Holds the model, datasets, config and training state, and provides
    checkpoint save/load. Subclasses implement the step/epoch/validation
    logic and are expected to set ``self.optimizer`` / ``self.scheduler``.
    """

    def __init__(
        self,
        model: BaseModel,
        train_dataset: BaseDataset,
        val_dataset: Optional[BaseDataset],
        config: TrainingConfig,
        device: torch.device,
    ):
        """Initialize trainer and move the model to ``device``."""
        self.model = model.to(device)
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.config = config
        self.device = device

        # Training state
        self.global_step = 0
        self.current_epoch = 0
        self.best_loss = float('inf')

        # Logging
        self.logger = None

        # Optimizer and scheduler (to be set up by subclass)
        self.optimizer = None
        self.scheduler = None

    @abstractmethod
    def training_step(self, batch: dict[str, torch.Tensor]) -> dict[str, float]:
        """
        Single training step.

        Args:
            batch: Training batch with input_ids, attention_mask, labels, etc.

        Returns:
            Dict with metrics (e.g., {'loss': 0.5, 'accuracy': 0.8})
        """
        pass

    @abstractmethod
    def validation_step(self, batch: dict[str, torch.Tensor]) -> dict[str, float]:
        """
        Single validation step.

        Args:
            batch: Validation batch

        Returns:
            Dict with validation metrics
        """
        pass

    @abstractmethod
    def train_epoch(self) -> dict[str, float]:
        """
        Train for one epoch.

        Returns:
            Dict with epoch-level metrics
        """
        pass

    @abstractmethod
    def validate(self) -> dict[str, float]:
        """
        Run validation on the entire validation set.

        Returns:
            Dict with validation metrics
        """
        pass

    def save_checkpoint(self, path: str | Path) -> None:
        """
        Save checkpoint in canonical format.

        Uses canonical checkpoint format:
        {
            'step': int,
            'model_state': state_dict,
            'optimizer_state': state_dict,
            'config': dict,
            'metrics': dict,
            'global_step': int,  # Legacy compat
            'current_epoch': int,  # Legacy compat
            'best_loss': float,  # Legacy compat
        }

        Args:
            path: Path to save checkpoint (parent directories are created)
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)

        # Save in canonical format
        checkpoint = {
            # Canonical format keys
            'step': self.global_step,
            'model_state': self.model.state_dict(),
            'optimizer_state': self.optimizer.state_dict() if self.optimizer else None,
            'config': self.config.to_dict(),
            'metrics': {},
            # Legacy format keys (for backward compatibility with code that reads them)
            'global_step': self.global_step,
            'current_epoch': self.current_epoch,
            'best_loss': self.best_loss,
        }

        torch.save(checkpoint, path)

    def load_checkpoint(self, path: str | Path) -> None:
        """
        Load checkpoint (handles both canonical and legacy formats).

        Restores model weights, optimizer state (only when an optimizer is
        set and the checkpoint has a non-empty optimizer state) and the
        step/epoch/best-loss counters.

        Args:
            path: Path to checkpoint
        """
        path = Path(path)
        checkpoint = torch.load(path, map_location=self.device)

        # Try canonical keys first, fall back to legacy keys
        model_state_key = 'model_state' if 'model_state' in checkpoint else 'model_state_dict'
        optimizer_state_key = 'optimizer_state' if 'optimizer_state' in checkpoint else 'optimizer_state_dict'

        self.model.load_state_dict(checkpoint[model_state_key])
        if self.optimizer and checkpoint.get(optimizer_state_key):
            self.optimizer.load_state_dict(checkpoint[optimizer_state_key])

        # Try canonical 'step' first, fall back to legacy 'global_step'
        self.global_step = checkpoint.get('step', checkpoint.get('global_step', 0))
        self.current_epoch = checkpoint.get('current_epoch', 0)
        self.best_loss = checkpoint.get('best_loss', float('inf'))

    def _get_lr(self) -> float:
        """Get current learning rate from the optimizer's first param group.

        Returns 0.0 when no optimizer has been set up yet or when it has no
        param groups, instead of raising on a None optimizer.
        """
        if self.optimizer is None:
            return 0.0
        for param_group in self.optimizer.param_groups:
            return param_group['lr']
        return 0.0
243
+
244
+
245
+ # ============================================================================
246
+ # Utility functions
247
+ # ============================================================================
248
+
249
+
250
def create_model(config: TrainingConfig, device: torch.device) -> BaseModel:
    """Build the model described by ``config.model`` via the model registry."""
    # Imported lazily, at call time rather than module import time.
    from taoTrain.models import get_model

    model = get_model(config.model, device=device)
    return model
254
+
255
+
256
def create_datasets(
    config: TrainingConfig,
) -> tuple[BaseDataset, Optional[BaseDataset]]:
    """Build the train dataset and, when applicable, the validation dataset.

    The validation dataset is only created for non-local datasets whose
    config object has a ``validation_split`` attribute; otherwise the second
    element of the returned tuple is None.
    """
    # Import here to avoid circular imports
    from taoTrain.data import DatasetFactory

    train_dataset = DatasetFactory.create_dataset(config, split="train")

    wants_validation = (not config.dataset.local) and hasattr(config.dataset, "validation_split")
    if wants_validation:
        val_dataset = DatasetFactory.create_dataset(config, split="validation")
    else:
        val_dataset = None

    return train_dataset, val_dataset
code/TaoTrain/src/taoTrain/data/__init__.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Dataset implementations and loaders."""
2
+
3
+ # HuggingFace-based datasets are optional for JSONL-only deployments.
4
+ try:
5
+ from .hf_base import BaseHFDataset
6
+ from .hf_pretrain import PretrainDataset
7
+ from .hf_sft import SFTDataset
8
+ from .hf_rl import RLDataset
9
+ except ImportError:
10
+ BaseHFDataset = None
11
+ PretrainDataset = None
12
+ SFTDataset = None
13
+ RLDataset = None
14
+
15
+ # JSONL-based datasets (async-only)
16
+ from .jsonl_base import BaseJSONLDataset
17
+ from .pretrain_jsonl import PretrainJSONLDataset
18
+ from .sft_jsonl import SFTJSONLDataset
19
+ from .rl_jsonl import RLJSONLDataset
20
+
21
+ # Utilities
22
+ from .tokenizer import SentencePieceTokenizerWrapper
23
+ from .sft_utils import (
24
+ parse_sft_record,
25
+ build_sft_sequence_tokens,
26
+ apply_response_masking,
27
+ build_response_only_next_token_labels,
28
+ )
29
+ from .loaders import get_dataloader
30
+ from .async_loader import AsyncBatchIterator
31
+ from .tokenization_queue import TokenizationQueue
32
+ from .factory import DatasetFactory
33
+
34
+ __all__ = [
35
+ # HuggingFace datasets
36
+ "BaseHFDataset",
37
+ "PretrainDataset",
38
+ "SFTDataset",
39
+ "RLDataset",
40
+ # JSONL datasets
41
+ "BaseJSONLDataset",
42
+ "PretrainJSONLDataset",
43
+ "SFTJSONLDataset",
44
+ "RLJSONLDataset",
45
+ # Utilities
46
+ "SentencePieceTokenizerWrapper",
47
+ "parse_sft_record",
48
+ "build_sft_sequence_tokens",
49
+ "apply_response_masking",
50
+ "build_response_only_next_token_labels",
51
+ # Data loading
52
+ "get_dataloader",
53
+ "AsyncBatchIterator",
54
+ "TokenizationQueue",
55
+ "DatasetFactory",
56
+ ]
code/TaoTrain/src/taoTrain/data/async_loader.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Async batch iterator for training with background tokenization."""
2
+
3
+ from typing import Dict, List, Optional, Any, Iterator
4
+ import torch
5
+
6
+ from taoTrain.data.tokenization_queue import TokenizationQueue
7
+ from taoTrain.data.sft_utils import build_response_only_next_token_labels
8
+
9
+
10
class AsyncBatchIterator:
    """
    Iterator that yields batches from a tokenization queue.

    This allows batches to be consumed directly from the background tokenization
    thread without waiting for all chunks to be tokenized upfront.

    The iterator:
    1. Pulls pre-tokenized chunks from the TokenizationQueue
    2. Collates samples into batches
    3. Builds tensors directly on the target device (GPU/CPU)
    4. Keeps raising StopIteration once exhausted, until ``iter()`` is
       called again for the next epoch
    """

    def __init__(
        self,
        tokenization_queue: "TokenizationQueue",
        batch_size: int,
        device: torch.device,
        drop_last: bool = True,
        gradient_accumulation_steps: int = 1,
    ):
        """
        Initialize async batch iterator.

        Args:
            tokenization_queue: TokenizationQueue instance
            batch_size: Batch size for yielding batches
            device: torch.device to move batches to
            drop_last: If True, drop last incomplete batch
            gradient_accumulation_steps: For logging purposes (not used here)
        """
        self.queue = tokenization_queue
        self.batch_size = batch_size
        self.device = device
        self.drop_last = drop_last
        self.gradient_accumulation_steps = gradient_accumulation_steps

        # State for iteration
        self._current_chunk: Optional[Dict[str, List]] = None
        self._current_idx = 0
        self._samples_yielded = 0
        self._finished = False

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        """Reset per-epoch state, (re)start the queue and return self."""
        # Reset state for new epoch
        self._current_chunk = None
        self._current_idx = 0
        self._samples_yielded = 0
        self._finished = False

        # Reset tokenization queue for epochs 2+
        if self.queue._next_chunk_idx > 0:
            print(f"\n✓ Resetting TokenizationQueue for next epoch (cur_idx={self.queue._next_chunk_idx})")
            self.queue.reset_for_next_epoch()

        # Start tokenization threads once per iterator creation
        if not self.queue._threads:
            print("\n✓ Starting TokenizationQueue worker threads...")
            self.queue.start()
        else:
            print(f"\n⚠ TokenizationQueue threads already running: {len(self.queue._threads)} active")

        return self

    def __next__(self) -> Dict[str, torch.Tensor]:
        """
        Get next batch.

        Returns:
            Dict with 'input_ids', 'attention_mask', 'labels' (all as torch tensors on device)

        Raises:
            StopIteration: When no more batches available
        """
        batch = self._get_next_batch()

        if batch is None:
            print("AsyncBatchIterator: No more batches available, stopping iteration.")
            raise StopIteration

        return batch

    def _build_labels(
        self,
        input_ids: List[int],
        attention_mask: List[int],
        response_mask: Optional[List[int]],
    ) -> List[int]:
        """Build next-token labels for one sample; ignored positions are -100."""
        if response_mask is not None:
            # SFT mode: label construction is delegated to the SFT helper.
            # NOTE(review): presumably it emits -100 where the mask is 0 —
            # confirm against sft_utils.
            labels = build_response_only_next_token_labels(input_ids, response_mask)
            pretrain = False
        else:
            # Pretrain mode: shift labels by 1 for next-token prediction.
            # Position i predicts token at position i+1; the final position
            # has no target.
            labels = input_ids[1:] + [-100]
            pretrain = True

        last = len(attention_mask) - 1
        for i, mask_val in enumerate(attention_mask):
            if mask_val == 0:
                # Padding position: never contributes to the loss.
                labels[i] = -100
            elif pretrain and i < last and attention_mask[i + 1] == 0:
                # The shifted target at this position is a padding token;
                # do not train the model to predict padding.
                labels[i] = -100

        return labels

    def _get_next_batch(self) -> Optional[Dict[str, torch.Tensor]]:
        """
        Fetch and collate the next batch.

        Returns:
            Dict with batch tensors, or None if iteration exhausted
        """
        # Once exhausted, stay exhausted until __iter__ resets the state;
        # otherwise every extra next() would poll the queue again.
        if self._finished:
            return None

        batch_input_ids = []
        batch_attention_masks = []
        batch_labels = []

        while len(batch_input_ids) < self.batch_size:
            # Try to get next sample from current chunk
            if self._current_chunk is None or self._current_idx >= len(self._current_chunk["input_ids"]):
                # Need new chunk
                self._current_chunk = self.queue.get_next_chunk(timeout=30.0)  # 30s polling timeout

                if self._current_chunk is None:
                    if not self.queue.is_exhausted:
                        continue
                    # Queue exhausted
                    chunk_count = getattr(self.queue, '_next_chunk_idx', 'unknown')
                    print(f"AsyncBatchIterator: No more chunks (processed {chunk_count}/{len(self.queue._chunk_order)})")
                    print(f"AsyncBatchIterator: Samples yielded so far: {self._samples_yielded}")
                    self._finished = True
                    break

                self._current_idx = 0
                # Re-check the loop condition: the new chunk may be empty,
                # in which case another one has to be fetched.
                continue

            # Get sample from current chunk
            input_ids = self._current_chunk["input_ids"][self._current_idx]
            attention_mask = self._current_chunk["attention_mask"][self._current_idx]

            # A "mask" column marks an SFT chunk; otherwise pretrain mode.
            if "mask" in self._current_chunk:
                response_mask = self._current_chunk["mask"][self._current_idx]
            else:
                response_mask = None
            labels = self._build_labels(input_ids, attention_mask, response_mask)

            batch_input_ids.append(input_ids)
            batch_attention_masks.append(attention_mask)
            batch_labels.append(labels)

            self._current_idx += 1
            self._samples_yielded += 1

        # Return batch if we have any samples, respecting drop_last
        if len(batch_input_ids) == 0:
            print(f"AsyncBatchIterator: No samples collected for batch. Finished={self._finished}, returning None.")
            return None

        if len(batch_input_ids) < self.batch_size and self.drop_last:
            incomplete_pct = (len(batch_input_ids) / self.batch_size) * 100
            print(f"AsyncBatchIterator: Batch incomplete ({len(batch_input_ids)}/{self.batch_size} = {incomplete_pct:.1f}%) and drop_last=True, returning None.")
            return None

        return self._collate_batch(batch_input_ids, batch_attention_masks, batch_labels)

    def _collate_batch(
        self,
        batch_input_ids: List[List[int]],
        batch_attention_masks: List[List[int]],
        batch_labels: List[List[int]],
    ) -> Dict[str, torch.Tensor]:
        """
        Collate batch samples into long tensors created on ``self.device``.

        All samples must have the same length; no padding is applied here.

        Args:
            batch_input_ids: List of token ID lists
            batch_attention_masks: List of attention mask lists
            batch_labels: List of label lists

        Returns:
            Collated batch as torch tensors on device
        """
        # Convert to tensors
        input_ids_tensor = torch.tensor(batch_input_ids, dtype=torch.long, device=self.device)
        attention_mask_tensor = torch.tensor(batch_attention_masks, dtype=torch.long, device=self.device)
        labels_tensor = torch.tensor(batch_labels, dtype=torch.long, device=self.device)

        return {
            "input_ids": input_ids_tensor,
            "attention_mask": attention_mask_tensor,
            "labels": labels_tensor,
        }

    def __len__(self) -> int:
        """Return approximate number of batches (based on ``len(self.queue)``)."""
        total_samples = len(self.queue)
        if self.drop_last:
            return total_samples // self.batch_size
        else:
            return (total_samples + self.batch_size - 1) // self.batch_size

    def shutdown(self):
        """Shutdown the async iterator and background thread."""
        self.queue.shutdown(wait=True)

    def __del__(self):
        """Cleanup on deletion."""
        # __init__ may have failed before `queue` was assigned; in that case
        # there is nothing to shut down.
        if getattr(self, "queue", None) is not None:
            self.shutdown()
code/TaoTrain/src/taoTrain/data/chunk_manager.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Chunk manager for streaming large JSONL datasets."""
2
+
3
+ import os
4
+ import json
5
+ import hashlib
6
+ from typing import Tuple, Optional, Dict, Any
7
+ from pathlib import Path
8
+ from tqdm import tqdm
9
+
10
+
11
+ class ChunkManager:
12
+ """
13
+ Manages chunked reading of large JSONL files.
14
+
15
+ This class handles:
16
+ - File scanning to count total lines without loading all text
17
+ - Estimating chunk boundaries based on file size
18
+ - Tracking which line ranges belong to each chunk
19
+ """
20
+
21
+ def __init__(self, jsonl_path: str, chunk_size_gb: float = 5.0,
22
+ samples_per_chunk: Optional[int] = None,
23
+ enable_metadata_cache: bool = True, chunk_cache_dir: str = ".cache/chunks",
24
+ max_samples: Optional[int] = None):
25
+ """
26
+ Initialize ChunkManager.
27
+
28
+ Args:
29
+ jsonl_path: Path to JSONL file
30
+ chunk_size_gb: Approximate chunk size in GB (ignored if samples_per_chunk is set)
31
+ samples_per_chunk: Number of samples per chunk (takes precedence over chunk_size_gb)
32
+ enable_metadata_cache: Enable caching of file scan metadata
33
+ chunk_cache_dir: Directory to store cache files
34
+ max_samples: Limit total samples to at most this many (if total_lines > max_samples)
35
+
36
+ Raises:
37
+ FileNotFoundError: If JSONL file doesn't exist
38
+ ValueError: If file is empty
39
+ """
40
+ self.jsonl_path = Path(jsonl_path)
41
+ self.chunk_size_bytes = int(chunk_size_gb * 1024 ** 3) # Convert GB to bytes
42
+ self.max_samples = max_samples # Limit total samples if specified
43
+ print (f"Initializing ChunkManager for {self.jsonl_path} with target chunk size {chunk_size_gb} GB")
44
+ if samples_per_chunk is not None:
45
+ print(f" Overriding chunk size with {samples_per_chunk} samples per chunk")
46
+ if max_samples is not None:
47
+ print(f" Limiting dataset to {max_samples} samples")
48
+ self.samples_per_chunk = samples_per_chunk # If set, overrides GB-based chunking
49
+ self.enable_metadata_cache = enable_metadata_cache
50
+ self.chunk_cache_dir = Path(chunk_cache_dir)
51
+
52
+ if not self.jsonl_path.exists():
53
+ raise FileNotFoundError(f"JSONL file not found: {self.jsonl_path}")
54
+
55
+ self.file_size_bytes = os.path.getsize(self.jsonl_path)
56
+ self.file_mtime = os.path.getmtime(self.jsonl_path)
57
+
58
+ if self.file_size_bytes == 0:
59
+ raise ValueError("JSONL file is empty")
60
+
61
+ # Will be populated by _scan_file()
62
+ self.total_lines = 0
63
+ self.effective_lines = 0
64
+ self.line_sizes = [] # bytes per line
65
+ self.valid_line_offsets = [] # byte offset of each VALID JSON line (for seeking)
66
+ self.chunk_line_ranges = [] # [(start_line, end_line), ...]
67
+
68
+ # Try to load from cache first
69
+ cache_loaded = False
70
+ if self.enable_metadata_cache:
71
+ cache_loaded = self._load_metadata_cache()
72
+
73
+ # If cache not used, scan the file
74
+ if not cache_loaded:
75
+ self._scan_file()
76
+ self._compute_chunk_ranges()
77
+
78
+ # Save metadata cache for future runs
79
+ if self.enable_metadata_cache:
80
+ self._save_metadata_cache()
81
+ else:
82
+ # Cache stores file scan metadata. Recompute chunk ranges for the
83
+ # current training config so samples_per_chunk/max_samples changes
84
+ # are honored without rescanning the large JSONL file.
85
+ self._compute_chunk_ranges()
86
+
87
+ def _get_cache_path(self) -> Path:
88
+ """Get the metadata cache file path for this JSONL file."""
89
+ # Create a hash of the file path to use as cache filename
90
+ file_hash = hashlib.md5(str(self.jsonl_path.absolute()).encode()).hexdigest()[:8]
91
+ cache_file = self.chunk_cache_dir / f"{file_hash}.metadata.json"
92
+ return cache_file
93
+
94
+ def _load_metadata_cache(self) -> bool:
95
+ """
96
+ Load metadata from cache if it exists and is valid.
97
+
98
+ Returns:
99
+ True if cache was loaded successfully, False otherwise
100
+ """
101
+ cache_file = self._get_cache_path()
102
+
103
+ if not cache_file.exists():
104
+ return False
105
+
106
+ try:
107
+ with open(cache_file, 'r', encoding='utf-8') as f:
108
+ cache_data = json.load(f)
109
+
110
+ # Validate cache: check file hasn't changed
111
+ if (cache_data.get('file_size') != self.file_size_bytes or
112
+ cache_data.get('file_mtime') != self.file_mtime or
113
+ cache_data.get('jsonl_path') != str(self.jsonl_path.absolute())):
114
+ return False
115
+
116
+ # Load cached data
117
+ self.total_lines = cache_data.get('total_lines', 0)
118
+ self.line_sizes = cache_data.get('line_sizes', [])
119
+ self.valid_line_offsets = cache_data.get('valid_line_offsets', [])
120
+ # Convert loaded lists back to tuples for chunk_line_ranges
121
+ chunk_ranges = cache_data.get('chunk_line_ranges', [])
122
+ self.chunk_line_ranges = [tuple(r) for r in chunk_ranges]
123
+ self.chunk_size_bytes = cache_data.get('chunk_size_bytes', self.chunk_size_bytes)
124
+
125
+ print(f"✓ Loaded scan metadata from cache: {cache_file.name}")
126
+ print(f" Found {self.total_lines:,} valid JSON lines in {len(self.chunk_line_ranges)} chunks")
127
+ return True
128
+
129
+ except Exception as e:
130
+ # If cache loading fails, fall back to scanning
131
+ return False
132
+
133
def _save_metadata_cache(self) -> None:
    """Persist the current scan metadata next to the other cache files.

    The payload is first written to a sibling ``.tmp`` file and then renamed
    over the real cache file. Failures while writing are reported with a
    warning and otherwise ignored.
    """
    target = self._get_cache_path()
    target.parent.mkdir(parents=True, exist_ok=True)

    payload = {}
    payload['jsonl_path'] = str(self.jsonl_path.absolute())
    payload['file_size'] = self.file_size_bytes
    payload['file_mtime'] = self.file_mtime
    payload['total_lines'] = self.total_lines
    payload['line_sizes'] = self.line_sizes
    payload['valid_line_offsets'] = self.valid_line_offsets
    payload['chunk_line_ranges'] = self.chunk_line_ranges
    payload['chunk_size_bytes'] = self.chunk_size_bytes

    try:
        # Stage the write, then swap it into place with a rename
        staging = target.with_suffix('.tmp')
        with open(staging, 'w', encoding='utf-8') as handle:
            json.dump(payload, handle, indent=2)
        staging.replace(target)
        print(f" Saved scan metadata to cache: {target.name}")
    except Exception as err:
        print(f" ⚠ Warning: failed to save cache: {err}")
158
+
159
def _get_chunk_cache_dir(self) -> Path:
    """Return the directory holding the extracted chunk files for this JSONL file.

    The directory is ``<chunk_cache_dir>/chunks/<hash>`` where ``<hash>`` is
    the first 8 hex digits of the MD5 digest of the absolute JSONL path.
    """
    absolute_path = str(self.jsonl_path.absolute())
    digest = hashlib.md5(absolute_path.encode()).hexdigest()
    return self.chunk_cache_dir / "chunks" / digest[:8]
164
+
165
def _get_chunk_cache_file(self, chunk_num: int) -> Path:
    """Return the cache file path for chunk ``chunk_num``.

    Files are named ``chunk_NNNNNN.jsonl`` (zero-padded to six digits) inside
    the directory returned by ``_get_chunk_cache_dir``.
    """
    file_name = f"chunk_{chunk_num:06d}.jsonl"
    return self._get_chunk_cache_dir() / file_name
169
+
170
def _get_chunk_index_file(self) -> Path:
    """Return the path of ``index.json``, the listing of all cached chunks."""
    return self._get_chunk_cache_dir() / "index.json"
174
+
175
def extract_and_cache_chunks(self) -> Dict[str, Any]:
    """
    Extract chunks from the original JSONL file and save them as separate cached files.

    This is optional and should be called manually if you want to pre-cache chunks
    for faster repeated access. It can significantly speed up training but uses more disk space.

    Each chunk is written to a temporary file and renamed into place only once
    it is complete. Because an existing chunk file is treated as "already
    cached" and skipped, a partially written file from an interrupted or
    failed run must never be left under the final name.

    Returns:
        Dictionary with cache information:
        - 'cache_dir': path to cache directory
        - 'num_chunks': number of chunks cached
        - 'total_size_gb': total size of cached chunks
    """
    chunk_dir = self._get_chunk_cache_dir()
    chunk_dir.mkdir(parents=True, exist_ok=True)

    num_chunks = len(self.chunk_line_ranges)
    progress_step = max(1, num_chunks // 10)  # report roughly every 10%

    print(f"💾 Extracting {num_chunks} chunks to cache...")
    total_size = 0

    for chunk_num in range(num_chunks):
        cache_file = self._get_chunk_cache_file(chunk_num)

        if not cache_file.exists():
            # Always read from the original file here, never from the cache
            chunk_examples = self.read_chunk(chunk_num, _from_cache=False)

            temp_file = cache_file.with_name(cache_file.name + '.tmp')
            try:
                with open(temp_file, 'w', encoding='utf-8') as f:
                    f.writelines(json.dumps(obj) + '\n' for obj in chunk_examples)
                temp_file.replace(cache_file)
            except BaseException:
                # Do not leave a truncated file behind, then re-raise
                temp_file.unlink(missing_ok=True)
                raise

        total_size += os.path.getsize(cache_file)
        if (chunk_num + 1) % progress_step == 0:
            print(f" - Cached {chunk_num + 1}/{num_chunks} chunks...")

    # Write index file
    index_data = {
        'jsonl_path': str(self.jsonl_path.absolute()),
        'num_chunks': num_chunks,
        'chunk_ranges': self.chunk_line_ranges,
    }
    with open(self._get_chunk_index_file(), 'w', encoding='utf-8') as f:
        json.dump(index_data, f, indent=2)

    total_size_gb = total_size / (1024**3)
    print(f"✓ Cached {num_chunks} chunks ({total_size_gb:.2f} GB)")

    return {
        'cache_dir': str(chunk_dir),
        'num_chunks': num_chunks,
        'total_size_gb': total_size_gb,
    }
229
+
230
def clear_chunk_cache(self, keep_metadata: bool = False) -> None:
    """
    Delete the cached chunk files and, unless asked otherwise, the metadata cache.

    Args:
        keep_metadata: When True only the chunk directory is removed and the
            scan-metadata cache file is left in place.
    """
    chunk_dir = self._get_chunk_cache_dir()
    if chunk_dir.exists():
        import shutil
        shutil.rmtree(chunk_dir)
        print(f"✓ Cleared chunk cache: {chunk_dir}")

    if keep_metadata:
        return

    cache_file = self._get_cache_path()
    if cache_file.exists():
        cache_file.unlink()
        print(f"✓ Cleared metadata cache: {cache_file}")
249
+
250
def _scan_file(self) -> None:
    """
    Scan JSONL file to count lines and track offsets.

    This reads the file once to:
    - Count total valid JSON lines
    - Record byte offset of each VALID line for seeking
    - Record the byte size of each valid line

    The file is read in binary mode so that the recorded offsets are the
    real on-disk byte positions. In text mode Python translates ``\\r\\n`` to
    ``\\n``, so re-encoding the decoded line under-counts every CRLF line and
    all following offsets drift.

    Raises:
        ValueError: If the file cannot be read/decoded, or contains no valid
            JSON lines.
    """
    print(f"📖 Scanning JSONL file: {self.jsonl_path}")
    print(f" File size: {self.file_size_bytes / (1024**3):.2f} GB")

    # Reset both per-line tables so a rescan never appends to stale data
    self.valid_line_offsets = []
    self.line_sizes = []
    current_offset = 0

    try:
        with open(self.jsonl_path, 'rb') as f:
            for raw_line in tqdm(f, desc="Scanning JSONL", unit=" lines"):
                line_bytes = len(raw_line)
                line = raw_line.decode('utf-8')

                # Empty and invalid lines advance the offset but are not counted
                if line.strip():
                    try:
                        json.loads(line)
                    except json.JSONDecodeError:
                        pass
                    else:
                        # Valid JSON line - record its starting byte offset
                        self.valid_line_offsets.append(current_offset)
                        self.line_sizes.append(line_bytes)

                current_offset += line_bytes

    except Exception as e:
        raise ValueError(f"Error scanning JSONL file: {e}") from e

    self.total_lines = len(self.valid_line_offsets)

    if self.total_lines == 0:
        raise ValueError("No valid JSON lines found in JSONL file")

    print(f"✓ Found {self.total_lines:,} valid JSON lines")

    # Calculate average line size
    avg_line_size = sum(self.line_sizes) / len(self.line_sizes) if self.line_sizes else 0
    print(f" Average line size: {avg_line_size:.2f} bytes")
    print(f" Chunk size target: {self.chunk_size_bytes / (1024**3):.2f} GB")
303
+
304
def _compute_chunk_ranges(self) -> None:
    """
    Compute line ranges for each chunk based on target chunk size.

    If samples_per_chunk is set, uses that. Otherwise, divides file
    based on chunk_size_bytes. If max_samples is set, limits chunks to cover
    at most max_samples lines.

    Always sets ``self.effective_lines`` and ``self.chunk_line_ranges``,
    including when there are no lines at all, so that ``__repr__`` and
    length queries never hit a missing attribute.

    Raises:
        ValueError: If samples_per_chunk is set but not positive (the range
            loop could never advance).
    """
    if self.total_lines == 0:
        self.effective_lines = 0
        self.chunk_line_ranges = []
        return

    # Apply max_samples limit to effective line count (never below zero)
    effective_lines = self.total_lines
    if self.max_samples is not None:
        effective_lines = max(0, min(self.total_lines, self.max_samples))
    self.effective_lines = effective_lines

    # Determine lines per chunk
    if self.samples_per_chunk is not None:
        if self.samples_per_chunk <= 0:
            raise ValueError(
                f"samples_per_chunk must be a positive integer, got {self.samples_per_chunk}"
            )
        lines_per_chunk = self.samples_per_chunk
    else:
        # Use GB-based calculation
        avg_line_size = sum(self.line_sizes) / len(self.line_sizes) if self.line_sizes else 1
        if avg_line_size <= 0:
            avg_line_size = 1
        lines_per_chunk = max(1, int(self.chunk_size_bytes / avg_line_size))

    # Half-open (start, end) ranges covering [0, effective_lines)
    chunk_ranges = [
        (start_line, min(start_line + lines_per_chunk, effective_lines))
        for start_line in range(0, effective_lines, lines_per_chunk)
    ]

    self.chunk_line_ranges = chunk_ranges
    self.num_chunks = len(chunk_ranges)

    print(f" Divided into {self.num_chunks} chunks (covering {self.effective_lines:,} lines)")
343
+
344
def get_chunk_indices(self, chunk_num: int) -> Tuple[int, int]:
    """
    Look up the line range covered by one chunk.

    Args:
        chunk_num: Zero-based chunk number.

    Returns:
        ``(start_line, end_line)`` for the chunk; ``end_line`` is exclusive.

    Raises:
        IndexError: If ``chunk_num`` is negative or not smaller than the
            number of chunks.
    """
    total_chunks = len(self.chunk_line_ranges)
    if not 0 <= chunk_num < total_chunks:
        raise IndexError(f"Chunk {chunk_num} out of range [0, {total_chunks-1}]")

    return self.chunk_line_ranges[chunk_num]
361
+
362
def read_chunk(self, chunk_num: int, _from_cache: bool = True) -> list[dict]:
    """
    Read a specific chunk and return parsed JSON objects.

    If a cached chunk file exists it is read first. Otherwise the chunk is
    read from the original JSONL file, seeking straight to the recorded byte
    offset of the chunk's first valid line. Blank lines and lines that are
    not valid JSON are skipped in both paths.

    Args:
        chunk_num: Chunk number (0-indexed)
        _from_cache: Internal parameter; pass False to force reading from the
            original file (used during cache extraction)

    Returns:
        List of parsed JSON objects from that chunk

    Raises:
        IndexError: If chunk_num is out of range (when reading the original file)
    """
    # Try to read from cache first (if it exists)
    if _from_cache:
        cache_file = self._get_chunk_cache_file(chunk_num)
        if cache_file.exists():
            examples = []
            try:
                with open(cache_file, 'r', encoding='utf-8') as f:
                    for line in f:
                        if not line.strip():
                            continue
                        try:
                            examples.append(json.loads(line))
                        except json.JSONDecodeError:
                            pass
                return examples
            except (OSError, ValueError) as e:
                print(f" ⚠ Warning: failed to read chunk from cache, falling back to original: {e}")

    # Read from original JSONL file using seek optimization
    start_line, end_line = self.get_chunk_indices(chunk_num)
    wanted = end_line - start_line

    examples = []

    # Binary mode: the stored offsets are byte positions, and seeking to an
    # arbitrary integer is only well-defined on a binary file object.
    with open(self.jsonl_path, 'rb') as f:
        if start_line < len(self.valid_line_offsets):
            # Jump directly to the first valid line of the chunk
            f.seek(self.valid_line_offsets[start_line])
            lines_to_skip = 0
        else:
            # No offset recorded (shouldn't happen): start at the beginning
            # and skip the valid lines that belong to earlier chunks.
            lines_to_skip = start_line

        for raw_line in f:
            # Stop when we've read enough lines
            if len(examples) >= wanted:
                break

            line = raw_line.decode('utf-8')
            if not line.strip():
                continue

            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                # Invalid lines were not counted by the scan either, so
                # skipping them keeps line numbering aligned.
                continue

            if lines_to_skip:
                lines_to_skip -= 1
                continue

            examples.append(obj)

    return examples
434
+
435
@property
def num_chunks(self) -> int:
    """Number of chunks, taken from the length of ``chunk_line_ranges``."""
    chunk_ranges = self.chunk_line_ranges
    return len(chunk_ranges)

@num_chunks.setter
def num_chunks(self, value: int) -> None:
    """Accept assignments to ``num_chunks`` (internal use).

    The value is stored in ``_num_chunks``; the getter does not read it and
    always reports the length of ``chunk_line_ranges``.
    """
    self._num_chunks = value
444
+
445
def __repr__(self) -> str:
    """Summarise the manager: file name, size in GB, line count and chunk count."""
    size_gb = self.file_size_bytes / (1024**3)
    fields = [
        f"file={self.jsonl_path.name}",
        f"size={size_gb:.2f}GB",
        f"lines={self.effective_lines:,}",
        f"chunks={self.num_chunks}",
    ]
    return "ChunkManager(" + ", ".join(fields) + ")"
code/TaoTrain/src/taoTrain/data/factory.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Factory for creating datasets based on configuration."""
2
+
3
+ from taoTrain.config import TrainingConfig, TrainingModeEnum
4
+ from taoTrain.data.pretrain_jsonl import PretrainJSONLDataset
5
+ from taoTrain.data.sft_jsonl import SFTJSONLDataset
6
+ from taoTrain.data.rl_jsonl import RLJSONLDataset
7
+
8
+ try:
9
+ from taoTrain.data.hf_pretrain import PretrainDataset
10
+ from taoTrain.data.hf_sft import SFTDataset
11
+ from taoTrain.data.hf_rl import RLDataset
12
+ except ImportError:
13
+ PretrainDataset = None
14
+ SFTDataset = None
15
+ RLDataset = None
16
+
17
+
18
class DatasetFactory:
    """Factory for creating datasets based on configuration."""

    # Registry of dataset classes keyed by (training mode, backend name)
    DATASETS = {
        (TrainingModeEnum.PRETRAIN, "jsonl"): PretrainJSONLDataset,
        (TrainingModeEnum.SFT, "jsonl"): SFTJSONLDataset,
        (TrainingModeEnum.RL, "jsonl"): RLJSONLDataset,
    }

    # HuggingFace-backed datasets are registered only when their optional
    # imports succeeded at module load time.
    if PretrainDataset is not None:
        DATASETS.update({
            (TrainingModeEnum.PRETRAIN, "huggingface"): PretrainDataset,
            (TrainingModeEnum.SFT, "huggingface"): SFTDataset,
            (TrainingModeEnum.RL, "huggingface"): RLDataset,
        })

    @staticmethod
    def create_dataset(
        config: TrainingConfig,
        split: str = "train",
    ):
        """
        Create dataset instance based on configuration.

        Args:
            config: Training configuration
            split: Dataset split (train, validation, test) - only passed to
                non-JSONL backends

        Returns:
            Dataset instance matching the configured mode and backend

        Raises:
            ImportError: If a HuggingFace dataset is requested but the
                optional HuggingFace imports were not available
            ValueError: If the mode/backend combination is not registered
        """
        # Determine backend: JSONL or HuggingFace
        backend = "jsonl" if config.dataset.local else "huggingface"
        mode = config.mode

        # Look up dataset class
        dataset_class = DatasetFactory.DATASETS.get((mode, backend))
        if dataset_class is None:
            # Blame the missing dependency only when it really is missing;
            # otherwise the combination is simply not registered.
            if backend == "huggingface" and PretrainDataset is None:
                raise ImportError(
                    "HuggingFace dataset support requires the optional 'datasets' dependency. "
                    "Install project dependencies before using dataset.local=false."
                )
            raise ValueError(
                f"Unsupported dataset configuration: mode={getattr(mode, 'value', mode)}, backend={backend}. "
                f"Supported: {list(DatasetFactory.DATASETS.keys())}"
            )

        # Instantiate dataset
        if backend == "jsonl":
            # JSONL datasets don't use split parameter
            return dataset_class(config)
        # HuggingFace datasets use split parameter
        return dataset_class(config, split=split)

    @staticmethod
    def register_dataset(mode: TrainingModeEnum, backend: str, dataset_class):
        """
        Register a custom dataset class, replacing any existing entry.

        Args:
            mode: Training mode (e.g., TrainingModeEnum.PRETRAIN)
            backend: Backend name (e.g., "jsonl", "huggingface")
            dataset_class: Dataset class to register
        """
        DatasetFactory.DATASETS[(mode, backend)] = dataset_class

    @staticmethod
    def list_available_datasets():
        """Return a dict describing every registered (mode, backend) pair.

        Keys have the form ``"<mode value>_<backend>"``; each value holds the
        mode value, the backend name and the dataset class name.
        """
        return {
            f"{mode.value}_{backend}": {
                "mode": mode.value,
                "backend": backend,
                "class": dataset_class.__name__,
            }
            for (mode, backend), dataset_class in DatasetFactory.DATASETS.items()
        }
code/TaoTrain/src/taoTrain/data/hf_base.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Base class for HuggingFace-based datasets."""
2
+
3
+ from typing import Optional, Dict
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ from datasets import load_dataset
7
+ from transformers import AutoTokenizer
8
+ from taoTrain.config import TrainingConfig
9
+
10
+
11
class BaseHFDataset(Dataset):
    """Base class for HuggingFace-based datasets.

    Construction loads the tokenizer, loads the requested dataset split and
    then calls ``_preprocess``. Subclasses implement ``_preprocess`` and
    ``__getitem__``.
    """

    def __init__(self, config: TrainingConfig, split: str = "train"):
        """
        Initialize dataset.

        Args:
            config: Training configuration
            split: Dataset split (train, validation, test)
        """
        self.config = config
        self.split = split
        self.data = None
        self.tokenizer = None

        # Load tokenizer
        self._load_tokenizer()

        # Load and preprocess dataset
        self._load_dataset()
        self._preprocess()

    def _load_tokenizer(self):
        """Load tokenizer from HuggingFace."""
        # Default to GPT-2 tokenizer if not specified
        tokenizer_name = getattr(self.config, 'tokenizer_name', 'gpt2')
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # Set pad token if not set
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _load_dataset(self):
        """Load dataset from HuggingFace and apply the ``max_samples`` limit.

        Raises:
            ValueError: If ``load_dataset`` fails for any reason; the original
                exception is chained as the cause.
        """
        dataset_config = self.config.dataset

        # The dataset configuration name is an optional second positional argument
        load_args = [dataset_config.dataset_name]
        if dataset_config.config:
            load_args.append(dataset_config.config)

        # SECURITY: trust_remote_code=True lets the dataset repository run its
        # own loading script on this machine. The historical default (True) is
        # kept for compatibility, but it can be switched off by setting
        # ``trust_remote_code`` on the dataset config.
        trust_remote_code = getattr(dataset_config, 'trust_remote_code', True)

        try:
            self.data = load_dataset(
                *load_args,
                split=self.split,
                cache_dir=dataset_config.cache_dir,
                trust_remote_code=trust_remote_code,
            )
        except Exception as e:
            raise ValueError(f"Failed to load dataset {dataset_config.dataset_name}: {e}") from e

        # Limit samples if specified
        if dataset_config.max_samples:
            limit = min(dataset_config.max_samples, len(self.data))
            self.data = self.data.select(range(limit))

    def _preprocess(self):
        """Preprocess dataset (to be implemented by subclasses)."""
        pass

    def __len__(self) -> int:
        """Return dataset length."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get item (to be implemented by subclasses).

        Raises:
            NotImplementedError: Always; the base class has no item format.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement __getitem__"
        )
code/TaoTrain/src/taoTrain/data/hf_pretrain.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pretrain dataset for HuggingFace datasets."""
2
+
3
+ from typing import Dict
4
+ import torch
5
+ from taoTrain.config import TrainingConfig
6
+ from taoTrain.data.hf_base import BaseHFDataset
7
+
8
+
9
class PretrainDataset(BaseHFDataset):
    """Dataset for pretraining with raw text.

    Texts are tokenized, the token streams of each batch are concatenated and
    the result is cut into fixed blocks of ``config.model.max_seq_length``
    tokens. The trailing remainder of each batch is dropped.
    """

    def _preprocess(self):
        """Tokenize the text column and pack it into fixed-length blocks."""
        dataset_config = self.config.dataset
        text_column = dataset_config.text_column
        block_size = self.config.model.max_seq_length

        def tokenize_function(examples):
            # Tokenize first: the raw column is a list of strings, which
            # cannot be concatenated with sum(..., []).
            tokenized = self.tokenizer(
                examples[text_column],
                truncation=False,  # We'll chunk below
                return_special_tokens_mask=False,
            )

            # Flatten the per-document token lists into one stream per field
            input_ids = [tok for ids in tokenized["input_ids"] for tok in ids]
            attention_mask = [m for mask in tokenized["attention_mask"] for m in mask]

            # Keep only whole blocks
            total_length = (len(input_ids) // block_size) * block_size

            return {
                "input_ids": [
                    input_ids[i:i + block_size]
                    for i in range(0, total_length, block_size)
                ],
                "attention_mask": [
                    attention_mask[i:i + block_size]
                    for i in range(0, total_length, block_size)
                ],
            }

        # Preprocess in batches; the output row count differs from the input,
        # so every original column has to be removed.
        self.data = self.data.map(
            tokenize_function,
            batched=True,
            batch_size=100,
            remove_columns=self.data.column_names,
            desc="Tokenizing...",
        )

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get preprocessed sample with ``input_ids``, ``attention_mask`` and ``labels``."""
        item = self.data[idx]

        input_ids = torch.tensor(item["input_ids"], dtype=torch.long)
        attention_mask = torch.tensor(item["attention_mask"], dtype=torch.long)

        # For pretrain, labels = input_ids shifted by 1 (next token prediction)
        # Position i predicts token at position i+1; the last position has no
        # target and is filled with -100.
        labels = torch.cat([input_ids[1:].clone(), torch.tensor([-100])], dim=0)

        # Mark padding tokens as -100 to ignore in loss computation
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
code/TaoTrain/src/taoTrain/data/hf_rl.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """RL dataset for HuggingFace datasets."""
2
+
3
+ from typing import Dict
4
+ import torch
5
+ from taoTrain.config import TrainingConfig
6
+ from taoTrain.data.hf_base import BaseHFDataset
7
+
8
+
9
class RLDataset(BaseHFDataset):
    """Dataset for RL training with prompts."""

    def _preprocess(self):
        """Reduce every example to a ``prompt`` field and tokenize it.

        Only prompts are prepared here; no response or label columns are
        produced by this class.
        """
        dataset_config = self.config.dataset

        if dataset_config.prompt_column:
            # A dedicated prompt column is configured: read it directly
            def to_prompt(example):
                return {"prompt": example[dataset_config.prompt_column]}

            progress_label = "Extracting prompts..."
        else:
            # Otherwise fall back to the text column (empty string if absent)
            def to_prompt(example):
                return {"prompt": example.get(dataset_config.text_column, "")}

            progress_label = "Preparing prompts..."

        self.data = self.data.map(
            to_prompt,
            remove_columns=self.data.column_names,
            desc=progress_label,
        )

        def tokenize_function(examples):
            # Truncate/pad every prompt to exactly max_seq_length tokens
            return self.tokenizer(
                examples["prompt"],
                truncation=True,
                max_length=self.config.model.max_seq_length,
                padding="max_length",
                return_attention_mask=True,
            )

        self.data = self.data.map(
            tokenize_function,
            batched=True,
            batch_size=100,
            remove_columns=self.data.column_names,
            desc="Tokenizing prompts...",
        )

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return the tokenized prompt at ``idx`` (no ``labels`` entry)."""
        row = self.data[idx]
        return {
            "input_ids": torch.tensor(row["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(row["attention_mask"], dtype=torch.long),
        }
code/TaoTrain/src/taoTrain/data/hf_sft.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SFT dataset for HuggingFace datasets."""
2
+
3
+ from typing import Dict
4
+ import torch
5
+ from taoTrain.config import TrainingConfig
6
+ from taoTrain.data.hf_base import BaseHFDataset
7
+
8
+
9
class SFTDataset(BaseHFDataset):
    """Dataset for supervised fine-tuning with instruction-response pairs."""

    def _preprocess(self):
        """Render each example to a single ``text`` field, then tokenize it."""
        dataset_config = self.config.dataset

        def format_example(example):
            """Combine instruction and response into one training string."""
            instruction = example.get(dataset_config.instruction_column, "")
            response = example.get(dataset_config.response_column, "")

            template = dataset_config.instruction_template
            if not template:
                # No custom template: instruction and response on separate lines
                return {"text": f"{instruction}\n{response}"}

            return {"text": template.format(instruction=instruction, response=response)}

        # When the dataset already has a "text" column nothing is removed
        # here; otherwise all original columns are dropped.
        existing_columns = self.data.column_names
        columns_to_drop = [] if "text" in existing_columns else list(existing_columns)

        self.data = self.data.map(
            format_example,
            remove_columns=columns_to_drop,
            desc="Formatting examples...",
        )

        def tokenize_function(examples):
            # Truncate/pad every sample to exactly max_seq_length tokens
            return self.tokenizer(
                examples["text"],
                truncation=True,
                max_length=self.config.model.max_seq_length,
                padding="max_length",
                return_attention_mask=True,
            )

        self.data = self.data.map(
            tokenize_function,
            batched=True,
            batch_size=100,
            remove_columns=self.data.column_names,
            desc="Tokenizing...",
        )

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return ``input_ids``, ``attention_mask`` and shifted ``labels`` for ``idx``."""
        row = self.data[idx]

        input_ids = torch.tensor(row["input_ids"], dtype=torch.long)
        attention_mask = torch.tensor(row["attention_mask"], dtype=torch.long)

        # Next-token targets: position i gets token i+1, the last position
        # gets -100 because it has no successor.
        labels = torch.cat([input_ids[1:].clone(), torch.tensor([-100])], dim=0)

        # Positions whose attention mask is 0 are excluded from the loss
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
code/TaoTrain/src/taoTrain/data/jsonl_base.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Base class for local JSONL-based datasets (async-only)."""
2
+
3
+ import json
4
+ from typing import Optional, Dict, Any
5
+ import torch
6
+ from torch.utils.data import Dataset
7
+ from taoTrain.config import TrainingConfig
8
+ from taoTrain.data.chunk_manager import ChunkManager
9
+ from taoTrain.data.tokenizer import SentencePieceTokenizerWrapper
10
+
11
+
12
class BaseJSONLDataset(Dataset):
    """
    Base class for local JSONL-based datasets with async-only streaming.

    Designed for use with AsyncBatchIterator and TokenizationQueue.
    All data loading and preprocessing happens asynchronously in background threads.
    """

    def __init__(self, config: TrainingConfig, split: str = "train"):
        """
        Initialize JSONL dataset with chunked loading.

        Args:
            config: Training configuration
            split: Dataset split (train, validation, test) - not used for JSONL but kept for compatibility

        Raises:
            ValueError: If ``config.dataset.jsonl_path`` is empty.

        Note:
            Requires AsyncBatchIterator and TokenizationQueue for data loading.
            See taoTrain/data/async_loader.py for usage.
        """
        self.config = config
        self.split = split
        self.tokenizer = None

        # Initialize chunk manager for streaming
        dataset_config = self.config.dataset
        jsonl_path = dataset_config.jsonl_path

        if not jsonl_path:
            raise ValueError("jsonl_path must be provided for local JSONL datasets")

        if dataset_config.enable_streaming:
            self.chunk_manager = ChunkManager(
                jsonl_path,
                chunk_size_gb=dataset_config.chunk_size_gb,
                samples_per_chunk=dataset_config.samples_per_chunk,
                enable_metadata_cache=dataset_config.enable_chunk_metadata_cache,
                chunk_cache_dir=dataset_config.chunk_cache_dir,
                max_samples=dataset_config.max_samples,
            )
            print(f"✓ {self.chunk_manager}")
        else:
            # NOTE(review): without a chunk manager nothing in this class
            # loads data, so __len__ stays 0 - confirm this mode is intended.
            self.chunk_manager = None

        # Current chunk data
        self._current_chunk_num = None
        self._current_chunk_data = None  # {"text": [...]} or preprocessed data
        self._text_field = dataset_config.text_field

        # Load tokenizer
        print("✓ Loading tokenizer...")
        self._load_tokenizer()

        print("✓ Dataset initialization complete (async mode - chunks loaded on-demand).")

    def _load_tokenizer(self):
        """Load tokenizer (from local SentencePiece or HuggingFace).

        Raises:
            ImportError: If the library needed for the selected tokenizer is
                not installed.
            ValueError: If a tokenizer at ``tokenizer_path`` cannot be loaded.
        """
        dataset_config = self.config.dataset

        # Check if tokenizer_path is specified
        if dataset_config.tokenizer_path:
            tokenizer_type = dataset_config.tokenizer_type

            # Auto-detect tokenizer type based on file extension
            if tokenizer_type is None:
                if dataset_config.tokenizer_path.endswith('.model'):
                    tokenizer_type = 'sentencepiece'
                else:
                    tokenizer_type = 'huggingface'

            if tokenizer_type == 'sentencepiece':
                # Load SentencePiece tokenizer
                try:
                    import sentencepiece as spm
                    sp = spm.SentencePieceProcessor()
                    sp.Load(dataset_config.tokenizer_path)
                    # Wrap SentencePiece in a compatible interface
                    self.tokenizer = SentencePieceTokenizerWrapper(sp)
                except ImportError as e:
                    raise ImportError("SentencePiece not installed. Install with: pip install sentencepiece") from e
                except Exception as e:
                    raise ValueError(f"Failed to load SentencePiece tokenizer from {dataset_config.tokenizer_path}: {e}") from e
            else:
                # Load HuggingFace tokenizer from path
                try:
                    from transformers import AutoTokenizer
                    self.tokenizer = AutoTokenizer.from_pretrained(dataset_config.tokenizer_path)
                except ImportError as e:
                    raise ImportError("HuggingFace tokenizers require the optional 'transformers' dependency") from e
                except Exception as e:
                    raise ValueError(f"Failed to load HuggingFace tokenizer from {dataset_config.tokenizer_path}: {e}") from e
        else:
            # Default to GPT-2 tokenizer
            try:
                from transformers import AutoTokenizer
            except ImportError as e:
                raise ImportError("Default GPT-2 tokenizer requires the optional 'transformers' dependency") from e
            tokenizer_name = getattr(self.config, 'tokenizer_name', 'gpt2')
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # Set pad token if not set (for HuggingFace tokenizers)
        if hasattr(self.tokenizer, 'pad_token') and self.tokenizer.pad_token is None:
            if hasattr(self.tokenizer, 'eos_token'):
                self.tokenizer.pad_token = self.tokenizer.eos_token

    def _load_chunk(self, chunk_num: int):
        """
        Load a specific chunk from JSONL file.

        Does nothing when there is no chunk manager or when the requested
        chunk is already the current one.

        Args:
            chunk_num: Chunk number to load (0-indexed)
        """
        if not self.chunk_manager:
            return

        if chunk_num == self._current_chunk_num and self._current_chunk_data is not None:
            # Already loaded
            return

        # Read chunk
        chunk_examples = self.chunk_manager.read_chunk(chunk_num)

        # Keep only the configured text field; objects without it are dropped
        texts = [
            obj[self._text_field]
            for obj in chunk_examples
            if self._text_field in obj
        ]

        self._current_chunk_data = {"text": texts}
        self._current_chunk_num = chunk_num

        # Hook for subclasses to post-process the freshly loaded chunk
        self._preprocess_chunk()

    def _locate_index(self, idx: int):
        """Map a global sample index to ``(chunk_num, local_idx)``.

        Requires ``self.chunk_manager`` to be set.

        Raises:
            IndexError: If ``idx`` is negative or beyond the last chunk.
        """
        remaining = idx
        if remaining >= 0:
            for chunk_num, (start_line, end_line) in enumerate(self.chunk_manager.chunk_line_ranges):
                chunk_len = end_line - start_line
                if remaining < chunk_len:
                    return chunk_num, remaining
                remaining -= chunk_len

        raise IndexError(f"Sample index {idx} out of range for dataset of length {len(self)}")

    def _get_chunk_for_idx(self, idx: int) -> int:
        """
        Determine which chunk contains the given global index.

        Args:
            idx: Global index

        Returns:
            Chunk number (0-indexed); always 0 when there is no chunk manager

        Raises:
            IndexError: If ``idx`` does not fall inside any chunk
        """
        if not self.chunk_manager:
            return 0

        chunk_num, _ = self._locate_index(idx)
        return chunk_num

    def _get_local_idx_in_chunk(self, global_idx: int) -> int:
        """
        Convert global index to local index within the chunk.

        Args:
            global_idx: Global index

        Returns:
            Local index within the chunk; ``global_idx`` itself when there is
            no chunk manager

        Raises:
            IndexError: If ``global_idx`` does not fall inside any chunk
        """
        if not self.chunk_manager:
            return global_idx

        _, local_idx = self._locate_index(global_idx)
        return local_idx

    def _preprocess(self):
        """Preprocess dataset (to be implemented by subclasses)."""
        pass

    def _preprocess_chunk(self):
        """
        Preprocess current chunk (to be implemented by subclasses).

        Called at the end of ``_load_chunk`` once ``_current_chunk_data`` holds
        the new chunk.
        """
        pass

    def __len__(self) -> int:
        """Return dataset length."""
        if self.chunk_manager:
            return self.chunk_manager.effective_lines
        elif self._current_chunk_data and "text" in self._current_chunk_data:
            return len(self._current_chunk_data.get("text", []))
        return 0

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get item (to be implemented by subclasses).

        Raises:
            NotImplementedError: Always; the base class has no item format.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement __getitem__"
        )
code/TaoTrain/src/taoTrain/data/loaders.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DataLoader utilities."""
2
+
3
+ from typing import Optional
4
+ import torch
5
+ from torch.utils.data import DataLoader, Dataset
6
+ from taoTrain.config import TrainingConfig
7
+
8
+
9
def get_dataloader(
    dataset: Dataset,
    config: TrainingConfig,
    shuffle: bool = True,
    drop_last: bool = True,
) -> DataLoader:
    """
    Create a DataLoader from a dataset.

    **NOTE**: For JSONL-based datasets (PretrainJSONLDataset, SFTJSONLDataset, etc.),
    this function is now deprecated in favor of AsyncBatchIterator for better performance.
    AsyncBatchIterator enables tokenization to happen in parallel with training,
    avoiding the startup bottleneck of tokenizing all data upfront.

    See: taoTrain/data/async_loader.py for the new async loading approach.

    Batches are built by a collate function that right-pads the 1D tensors
    stored under ``input_ids`` / ``attention_mask`` (with 0) and ``labels``
    (with -100) to the longest item in the batch, stacks every other tensor
    as-is, and passes non-tensor values through as plain lists.

    Args:
        dataset: PyTorch Dataset instance
        config: Training configuration; ``batch_size``, ``num_workers`` and
            ``pin_memory`` are read from it
        shuffle: Whether to shuffle data
        drop_last: Whether to drop last incomplete batch

    Returns:
        DataLoader instance
    """
    # Fill value used when right-padding each variable-length key.
    pad_values = {"input_ids": 0, "attention_mask": 0, "labels": -100}

    def pad_and_stack(tensors, pad_value):
        """Right-pad 1D tensors to a common length and stack them."""
        # Only 1D sequences are padded; anything else is stacked untouched
        # instead of being silently dropped from the batch.
        if any(t.dim() != 1 for t in tensors):
            return torch.stack(tensors)

        max_len = max(t.shape[0] for t in tensors)
        padded = []
        for t in tensors:
            pad_len = max_len - t.shape[0]
            if pad_len > 0:
                t = torch.nn.functional.pad(t, (0, pad_len), value=pad_value)
            padded.append(t)
        return torch.stack(padded)

    def collate_fn(batch):
        """Collate a list of sample dicts into a dict of batched values."""
        collated = {}
        for key in batch[0].keys():
            items = [item[key] for item in batch]

            if not isinstance(items[0], torch.Tensor):
                # Non-tensor payloads are kept as a plain list.
                collated[key] = items
            elif key in pad_values:
                collated[key] = pad_and_stack(items, pad_values[key])
            else:
                collated[key] = torch.stack(items)

        return collated

    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory,
        collate_fn=collate_fn,
    )
code/TaoTrain/src/taoTrain/data/pretrain_jsonl.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pretrain JSONL dataset with async-only streaming."""
2
+
3
+ from typing import Dict
4
+ import torch
5
+ from taoTrain.config import TrainingConfig
6
+ from taoTrain.data.jsonl_base import BaseJSONLDataset
7
+
8
+
9
class PretrainJSONLDataset(BaseJSONLDataset):
    """Dataset for pretraining with local JSONL files with chunked loading."""

    def _preprocess_chunk(self):
        """
        Tokenize current chunk of text data.

        Replaces ``self._current_chunk_data`` (expected to hold a ``"text"``
        list) with ``"input_ids"`` / ``"attention_mask"`` lists, each entry
        truncated and padded to ``config.model.max_seq_length``. Does nothing
        when the chunk is empty or has no ``"text"`` entry.
        """
        if not self._current_chunk_data or "text" not in self._current_chunk_data:
            return

        max_seq_length = self.config.model.max_seq_length

        all_token_ids = []
        all_attention_masks = []
        for text in self._current_chunk_data["text"]:
            tokenized = self.tokenizer(
                text,
                truncation=True,
                max_length=max_seq_length,
                padding="max_length",
                return_attention_mask=True,
            )
            all_token_ids.append(tokenized["input_ids"])
            all_attention_masks.append(tokenized["attention_mask"])

        self._current_chunk_data = {
            "input_ids": all_token_ids,
            "attention_mask": all_attention_masks,
        }

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Get preprocessed sample, loading chunk if needed.

        Args:
            idx: Global sample index

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``labels``.
            ``labels[i]`` is the next token ``input_ids[i + 1]``, or -100
            where the loss must be ignored.
        """
        # Load appropriate chunk if using streaming
        if self.chunk_manager:
            chunk_num = self._get_chunk_for_idx(idx)
            if chunk_num != self._current_chunk_num:
                self._load_chunk(chunk_num)
            local_idx = self._get_local_idx_in_chunk(idx)
        else:
            local_idx = idx

        input_ids = torch.tensor(self._current_chunk_data["input_ids"][local_idx], dtype=torch.long)
        attention_mask = torch.tensor(self._current_chunk_data["attention_mask"][local_idx], dtype=torch.long)

        # Next-token labels: position i predicts token i + 1; the final
        # position has no target.
        labels = torch.full_like(input_ids, -100)
        labels[:-1] = input_ids[1:]

        # The mask must follow the *target* token: the last real token would
        # otherwise be trained to predict the first padding token.
        target_is_pad = torch.ones_like(attention_mask, dtype=torch.bool)
        target_is_pad[:-1] = attention_mask[1:] == 0
        labels[target_is_pad] = -100

        # Positions whose own input is padding are ignored as well.
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
code/TaoTrain/src/taoTrain/data/rl_jsonl.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """RL JSONL dataset with async-only streaming."""
2
+
3
+ from typing import Dict
4
+ import torch
5
+ from taoTrain.config import TrainingConfig
6
+ from taoTrain.data.jsonl_base import BaseJSONLDataset
7
+
8
+
9
class RLJSONLDataset(BaseJSONLDataset):
    """Dataset for RL training with local JSONL files with chunked loading."""

    def _preprocess_chunk(self):
        """
        Tokenize the prompts of the current chunk.

        Swaps ``self._current_chunk_data`` (holding a ``"text"`` list) for
        ``"input_ids"`` / ``"attention_mask"`` lists truncated and padded to
        ``config.model.max_seq_length``. A missing or text-less chunk is left
        untouched.
        """
        chunk = self._current_chunk_data
        if not chunk or "text" not in chunk:
            return

        seq_len = self.config.model.max_seq_length
        encoded = [
            self.tokenizer(
                prompt,
                truncation=True,
                max_length=seq_len,
                padding="max_length",
                return_attention_mask=True,
            )
            for prompt in chunk["text"]
        ]

        self._current_chunk_data = {
            "input_ids": [enc["input_ids"] for enc in encoded],
            "attention_mask": [enc["attention_mask"] for enc in encoded],
        }

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Return the tokenized prompt at ``idx``, switching chunks if required.

        No labels are produced here; only ``input_ids`` and
        ``attention_mask`` are returned.
        """
        position = idx
        if self.chunk_manager:
            wanted_chunk = self._get_chunk_for_idx(idx)
            if wanted_chunk != self._current_chunk_num:
                self._load_chunk(wanted_chunk)
            position = self._get_local_idx_in_chunk(idx)

        data = self._current_chunk_data
        sample = {}
        sample["input_ids"] = torch.tensor(data["input_ids"][position], dtype=torch.long)
        sample["attention_mask"] = torch.tensor(data["attention_mask"][position], dtype=torch.long)
        return sample
code/TaoTrain/src/taoTrain/data/sft_jsonl.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SFT JSONL dataset with async-only streaming and response-masking."""
2
+
3
+ from typing import Dict
4
+ import torch
5
+ from taoTrain.config import TrainingConfig
6
+ from taoTrain.data.jsonl_base import BaseJSONLDataset
7
+ from taoTrain.data.sft_utils import (
8
+ parse_sft_record,
9
+ build_sft_sequence_tokens,
10
+ build_response_only_next_token_labels,
11
+ )
12
+
13
+
14
class SFTJSONLDataset(BaseJSONLDataset):
    """
    Supervised fine-tuning dataset backed by local JSONL files, loaded chunk by chunk.

    Accepted record layouts:
    - Single-turn: {"input": "...", "output": "..."}
    - Multi-turn: {"turns": [{"user": "...", "assistant": "..."}, ...]}

    Labels are produced with response-only masking, so only assistant tokens
    contribute to the loss.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the base class, then read the SFT settings."""
        super().__init__(*args, **kwargs)
        # Raw records of the current chunk (whole dicts, not only a text field).
        self._current_chunk_records = None

        # The config is only consulted for SFT options when it exposes `mode`.
        self.sft_config = self.config if hasattr(self.config, 'mode') else None
        if self.sft_config:
            self.user_token = getattr(self.sft_config, 'user_token', '<user>')
            self.assistant_token = getattr(self.sft_config, 'assistant_token', '<assistant>')
            self.response_loss_only = getattr(self.sft_config, 'response_loss_only', True)
        else:
            self.user_token = '<user>'
            self.assistant_token = '<assistant>'
            self.response_loss_only = True

    def _load_chunk(self, chunk_num: int):
        """
        Read chunk ``chunk_num`` from the chunk manager and tokenize it.

        The full record dicts are kept so that SFT parsing can see every field.
        Nothing happens without a chunk manager or when the chunk is already
        the current one.

        Args:
            chunk_num: Chunk number to load (0-indexed)
        """
        if not self.chunk_manager:
            return

        already_loaded = chunk_num == self._current_chunk_num and self._current_chunk_data is not None
        if already_loaded:
            return

        self._current_chunk_records = self.chunk_manager.read_chunk(chunk_num)

        # Start from empty containers; _preprocess_chunk fills them in.
        self._current_chunk_data = {
            "input_ids": [],
            "attention_mask": [],
            "mask": [],
        }
        self._current_chunk_num = chunk_num

        self._preprocess_chunk()

    def _preprocess_chunk(self):
        """
        Turn the current chunk's records into tokenized, loss-masked sequences.

        Every record is parsed into (user, assistant) turns and converted to
        token ids, an attention mask and a loss mask (0=ignore, 1=train).
        Records that cannot be parsed or that raise during tokenization are
        skipped with a printed warning.

        NOTE(review): skipped records make the stored lists shorter than the
        chunk's line range, so local indices may not line up with the source
        lines — confirm against the chunk manager's accounting.
        """
        if not self._current_chunk_records:
            return

        seq_len = self.config.model.max_seq_length
        collected = {"input_ids": [], "attention_mask": [], "mask": []}

        for record in self._current_chunk_records:
            try:
                turns, _ = parse_sft_record(record, self.config)

                if not turns:
                    if "text" not in record:
                        # Nothing usable in this record.
                        continue
                    # Treat a bare "text" field as a prompt with an empty reply.
                    turns = [(record["text"], "")]

                token_ids, attn, loss_mask = build_sft_sequence_tokens(
                    turns=turns,
                    tokenizer=self.tokenizer,
                    user_token=self.user_token,
                    assistant_token=self.assistant_token,
                    max_seq_length=seq_len,
                )
            except Exception as e:
                print(f"Warning: Failed to process SFT record: {e}")
                continue

            collected["input_ids"].append(token_ids)
            collected["attention_mask"].append(attn)
            collected["mask"].append(loss_mask)

        self._current_chunk_data = collected

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Return the sample at ``idx`` with response-only next-token labels.

        Args:
            idx: Sample index

        Returns:
            Dict with input_ids, attention_mask, and labels (-100 marks ignored positions)
        """
        position = idx
        if self.chunk_manager:
            wanted_chunk = self._get_chunk_for_idx(idx)
            if wanted_chunk != self._current_chunk_num:
                self._load_chunk(wanted_chunk)
            position = self._get_local_idx_in_chunk(idx)

        data = self._current_chunk_data
        input_ids = torch.tensor(data["input_ids"][position], dtype=torch.long)
        attention_mask = torch.tensor(data["attention_mask"][position], dtype=torch.long)
        loss_mask = data["mask"][position]

        label_ids = build_response_only_next_token_labels(input_ids.tolist(), loss_mask)
        labels = torch.tensor(label_ids, dtype=torch.long)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
code/TaoTrain/src/taoTrain/data/sft_utils.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SFT utility functions for parsing and masking."""
2
+
3
+ from typing import Dict, Any, List, Tuple
4
+ from taoTrain.config import TrainingConfig
5
+
6
+
7
def parse_sft_record(record: Dict[str, Any], config: TrainingConfig) -> Tuple[List[Tuple[str, str]], bool]:
    """
    Extract (user, assistant) turns from a JSONL record.

    Formats are tried in this order:
    1. Multi-turn: {"turns": [{"user": "...", "assistant": "..."}, ...]}
       (entries that are not dicts or lack either key are dropped)
    2. Single-turn: {"input": "...", "output": "..."}
    3. The instruction/response columns named in ``config.dataset``
       (defaulting to "instruction" / "response")
    4. A pre-formatted "text" field, paired with an empty reply

    Args:
        record: JSONL record (dict)
        config: Training configuration

    Returns:
        (turns_list, is_multi_turn) where:
        - turns_list: List of (user_text, assistant_text) tuples, empty when
          no format matched
        - is_multi_turn: True only for a multi-turn record with valid turns
    """
    if "turns" in record:
        pairs = [
            (entry["user"], entry["assistant"])
            for entry in record["turns"]
            if isinstance(entry, dict) and "user" in entry and "assistant" in entry
        ]
        # A "turns" list without any valid entry falls through to the other formats.
        if pairs:
            return pairs, True

    if "input" in record and "output" in record:
        return [(record["input"], record["output"])], False

    ds_cfg = config.dataset
    prompt_key = ds_cfg.instruction_column or "instruction"
    answer_key = ds_cfg.response_column or "response"
    if prompt_key in record and answer_key in record:
        return [(record[prompt_key], record[answer_key])], False

    if "text" in record:
        return [(record["text"], "")], False

    return [], False
50
+
51
+
52
def build_sft_sequence_tokens(
    turns: List[Tuple[str, str]],
    tokenizer,
    user_token: str = "<user>",
    assistant_token: str = "<assistant>",
    max_seq_length: int = 1024,
) -> Tuple[List[int], List[int], List[int]]:
    """
    Build token sequence for SFT with role tokens and generate masking info.

    Sequence format:
        [user_token_ids] user_tokens [assistant_token_ids] assistant_tokens ... [eos_token_id]

    Mask values:
        - 0 (ignore): role tokens, user text, EOS and padding
        - 1 (train): assistant text

    The sequence is truncated to ``max_seq_length`` and then right-padded to
    exactly that length.

    Args:
        turns: List of (user_text, assistant_text) tuples
        tokenizer: Callable tokenizer returning a mapping with "input_ids";
            its ``eos_token_id`` / ``pad_token_id`` attributes are used when
            present
        user_token: Role token for user (e.g., "<user>")
        assistant_token: Role token for assistant (e.g., "<assistant>")
        max_seq_length: Maximum sequence length

    Returns:
        (input_ids, attention_mask, mask) where:
        - input_ids: Token IDs for the full sequence
        - attention_mask: Attention mask (1 for real tokens, 0 for padding)
        - mask: Loss mask (0=ignore, 1=train loss)
    """

    def encode(text):
        return tokenizer(text, add_special_tokens=False)["input_ids"]

    # Role markers are identical for every turn, so tokenize them once.
    user_token_ids = encode(user_token)
    assistant_token_ids = encode(assistant_token)

    input_ids = []
    mask = []

    def append_segment(token_ids, train):
        input_ids.extend(token_ids)
        mask.extend([train] * len(token_ids))

    for user_text, assistant_text in turns:
        append_segment(user_token_ids, 0)
        append_segment(encode(user_text), 0)
        append_segment(assistant_token_ids, 0)
        append_segment(encode(assistant_text), 1)

    # Some tokenizer wrappers report a missing special token as a negative
    # id; such an id must never be emitted into the sequence.
    eos_token_id = getattr(tokenizer, "eos_token_id", None)
    if eos_token_id is not None and eos_token_id >= 0:
        input_ids.append(eos_token_id)
        mask.append(0)

    # Truncate if too long
    if len(input_ids) > max_seq_length:
        input_ids = input_ids[:max_seq_length]
        mask = mask[:max_seq_length]

    real_len = len(input_ids)

    # Pad to max_seq_length, falling back to id 0 when no usable pad id exists.
    padding_len = max_seq_length - real_len
    if padding_len > 0:
        pad_token_id = getattr(tokenizer, "pad_token_id", None)
        if pad_token_id is None or pad_token_id < 0:
            pad_token_id = 0
        input_ids.extend([pad_token_id] * padding_len)
        mask.extend([0] * padding_len)

    attention_mask = [1] * real_len + [0] * (len(input_ids) - real_len)

    return input_ids, attention_mask, mask
129
+
130
+
131
def apply_response_masking(input_ids: List[int], mask: List[int]) -> List[int]:
    """
    Convert a 0/1 loss mask into label form.

    Args:
        input_ids: Token IDs
        mask: Mask array (0=ignore, 1=train)

    Returns:
        A copy of ``input_ids`` in which every position whose mask value is 0
        is replaced by -100 (the ignore value for the loss); all other
        positions keep their token id.
    """
    labels = input_ids.copy()
    ignored_positions = [pos for pos, flag in enumerate(mask) if flag == 0]
    for pos in ignored_positions:
        labels[pos] = -100
    return labels
147
+
148
+
149
def build_response_only_next_token_labels(input_ids: List[int], mask: List[int]) -> List[int]:
    """
    Produce next-token labels for response-only SFT training.

    The label at position i is the token at position i+1, kept only when that
    *target* token is marked trainable (mask value non-zero); otherwise it is
    -100. The last position never has a target and is always -100.

    Args:
        input_ids: Token IDs of the sequence
        mask: Loss mask aligned with ``input_ids`` (0=ignore, 1=train)

    Returns:
        List of labels with the same length as ``input_ids``.

    Raises:
        ValueError: If ``input_ids`` and ``mask`` differ in length.
    """
    if len(input_ids) != len(mask):
        raise ValueError(f"input_ids and mask must have the same length: {len(input_ids)} != {len(mask)}")

    # Skip the first token: nothing predicts it.
    shifted = [
        -100 if flag == 0 else token
        for token, flag in zip(input_ids[1:], mask[1:])
    ]
    shifted.append(-100)
    return shifted
code/TaoTrain/src/taoTrain/data/tokenization_queue.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Background tokenization queue for streaming large JSONL datasets."""
2
+
3
+ import queue
4
+ import threading
5
+ import time
6
+ from typing import Dict, List, Optional, Any, Callable
7
+ import torch
8
+
9
+ from taoTrain.data.chunk_manager import ChunkManager
10
+
11
+
12
class TokenizationQueue:
    """
    Background worker threads that continuously tokenize chunks and store them in a queue.

    This allows tokenization to happen in parallel with training, avoiding the bottleneck
    of tokenizing all data upfront before training starts.

    Supports multiple worker threads for faster throughput. Each thread greedily
    grabs the next available chunk index under a lock.

    Attributes:
        total_items: Total number of samples across all chunks
        max_queue_size: Maximum number of chunks to buffer in memory
        num_threads: Number of worker threads for tokenization
    """

    # Seconds between wake-ups of blocking queue operations, so that blocked
    # producers/consumers notice stop, error and exhaustion signals.
    _POLL_INTERVAL = 0.1

    def __init__(
        self,
        chunk_manager: ChunkManager,
        tokenizer: Any,
        config: "TrainingConfig",  # type: ignore
        max_queue_size: int = 2,
        shuffle_chunks: bool = True,
        num_threads: int = 1,
    ):
        """
        Initialize tokenization queue with multithreading support.

        Args:
            chunk_manager: ChunkManager instance loaded with chunks
            tokenizer: Tokenizer instance (HuggingFace or SentencePiece wrapper)
            config: Training configuration with model and dataset settings
            max_queue_size: Maximum chunks to buffer in queue (memory constraint)
            shuffle_chunks: Whether to shuffle chunk order at initialization
            num_threads: Number of worker threads for tokenization (default: 1)

        Raises:
            ValueError: If chunk_manager has no chunks or num_threads < 1
        """
        # Created before validation so that __del__ -> shutdown() stays safe
        # when one of the checks below raises.
        self._threads: List[threading.Thread] = []

        if chunk_manager.num_chunks == 0:
            raise ValueError("ChunkManager must have at least one chunk")
        if num_threads < 1:
            raise ValueError(f"num_threads must be >= 1, got {num_threads}")

        self.chunk_manager = chunk_manager
        self.tokenizer = tokenizer
        self.config = config
        self.max_queue_size = max_queue_size
        self.shuffle_chunks = shuffle_chunks
        self.num_threads = num_threads

        # Detect SFT mode: check for response_loss_only flag
        self.is_sft_mode = hasattr(config, 'response_loss_only') and config.response_loss_only

        # Calculate total items across all chunks
        self.total_items = chunk_manager.effective_lines

        # Thread-safe queue for tokenized chunks
        self._queue: queue.Queue[Dict[str, List]] = queue.Queue(maxsize=max_queue_size)

        # Control signals
        self._stop_event = threading.Event()
        self._error_event = threading.Event()
        self._error_messages: List[str] = []

        # Thread-safe chunk distribution
        self._next_chunk_idx = 0
        self._chunk_idx_lock = threading.Lock()
        self._active_threads = 0
        self._active_threads_lock = threading.Lock()

        # Chunk ordering
        self._chunk_order = list(range(chunk_manager.num_chunks))
        print(f"TokenizationQueue initialized with {chunk_manager.num_chunks} chunks, total {chunk_manager.effective_lines} samples")
        print(f"Using {num_threads} tokenization worker thread{'s' if num_threads != 1 else ''}")
        print(f"Max queue size: {max_queue_size} chunks (memory constraint)")
        if self.shuffle_chunks:
            import random
            random.shuffle(self._chunk_order)

    def _get_next_chunk_idx(self) -> Optional[int]:
        """
        Atomically get the next chunk index for processing.

        Returns:
            Chunk index to process, or None if all chunks have been assigned
        """
        with self._chunk_idx_lock:
            if self._next_chunk_idx < len(self._chunk_order):
                chunk_idx = self._chunk_order[self._next_chunk_idx]
                self._next_chunk_idx += 1
                return chunk_idx
            return None

    def _resolve_pad_token_id(self) -> int:
        """Return the tokenizer's pad id, or 0 when it is missing or negative."""
        pad_token_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_token_id is None or pad_token_id < 0:
            return 0
        return pad_token_id

    def _raise_if_worker_failed(self):
        """Raise RuntimeError when any worker thread has recorded an error."""
        if self._error_event.is_set():
            error_summary = "; ".join(self._error_messages) if self._error_messages else "Unknown error"
            raise RuntimeError(f"Tokenization thread error: {error_summary}")

    def _no_more_chunks_coming(self) -> bool:
        """Return True when no worker is running and none will produce more data."""
        with self._active_threads_lock:
            if self._active_threads != 0:
                return False
            all_assigned = self._next_chunk_idx >= len(self._chunk_order)
            return all_assigned or self._stop_event.is_set()

    def start(self):
        """
        Start the tokenization background worker threads.

        Raises:
            RuntimeError: If worker threads have already been started
        """
        if self._threads:
            raise RuntimeError(f"Tokenization threads already started ({len(self._threads)} active)")

        # Create and start N worker threads
        for thread_id in range(self.num_threads):
            thread = threading.Thread(target=self._worker, args=(thread_id,), daemon=False)
            self._threads.append(thread)
            thread.start()

    def _worker(self, thread_id: int):
        """
        Worker thread target: greedy chunk processing with thread-safe distribution.

        Repeatedly claims the next chunk, tokenizes it and puts the result in
        the queue until all chunks are assigned or a stop is requested. Any
        exception is recorded and surfaced to the consumer via
        ``get_next_chunk``.

        Args:
            thread_id: Identifier for this worker thread
        """
        with self._active_threads_lock:
            self._active_threads += 1

        try:
            while not self._stop_event.is_set():
                # Get next chunk to process (atomic operation)
                chunk_num = self._get_next_chunk_idx()
                if chunk_num is None:
                    # All chunks assigned
                    break

                # Load chunk
                chunk_examples = self.chunk_manager.read_chunk(chunk_num)

                # Tokenize chunk based on mode
                if self.is_sft_mode:
                    tokenized_chunk = self._tokenize_batch_sft(chunk_examples)
                else:
                    # Extract texts for pretrain
                    text_field = self.config.dataset.text_field
                    texts = [obj.get(text_field, "") for obj in chunk_examples]
                    tokenized_chunk = self._tokenize_batch(texts)

                # Wait for a free slot, but wake up regularly so a stop
                # request is honoured even while the queue stays full.
                delivered = False
                while not self._stop_event.is_set():
                    try:
                        self._queue.put(tokenized_chunk, timeout=self._POLL_INTERVAL)
                        delivered = True
                        break
                    except queue.Full:
                        continue
                if not delivered:
                    break
                print(f"[Worker-{thread_id}] Processed chunk {chunk_num}, put {len(tokenized_chunk['input_ids'])} samples in queue")
        except Exception as e:
            error_msg = f"[Worker-{thread_id}] {str(e)}"
            print(f"Worker-{thread_id} encountered an error: {error_msg}")
            self._error_messages.append(error_msg)
            self._error_event.set()
        finally:
            with self._active_threads_lock:
                self._active_threads -= 1
                remaining = self._active_threads
            print(f"[Worker-{thread_id}] Finished processing. Active threads remaining: {remaining}")

    def _tokenize_batch(self, texts: List[str]) -> Dict[str, List]:
        """
        Tokenize a batch of texts, join with EOS, and split into fixed-size sequences.

        This packs multiple documents into longer sequences separated by EOS tokens,
        then splits the concatenated tokens into N fixed-size chunks of max_seq_length.
        UNK token ids are removed from each document before packing, and the
        final sequence is right-padded when it is short.

        Args:
            texts: List of text strings

        Returns:
            Dict with 'input_ids' and 'attention_mask' lists, where each element
            is a fixed-size sequence of length max_seq_length

        Raises:
            ValueError: If the tokenizer has no EOS or no UNK token id
        """
        max_seq_length = self.config.model.max_seq_length

        eos_token_id = self.tokenizer.eos_token_id
        unk_token_id = self.tokenizer.unk_token_id
        if eos_token_id is None:
            raise ValueError("Tokenizer does not have an EOS token defined")
        if unk_token_id is None:
            raise ValueError("Tokenizer does not have an UNK token defined")

        # Tokenize all texts without truncation
        all_token_ids = []
        last_index = len(texts) - 1
        for i, text in enumerate(texts):
            tokenized = self.tokenizer(
                text,
                truncation=False,
                return_attention_mask=False,
            )
            all_token_ids.extend(tid for tid in tokenized["input_ids"] if tid != unk_token_id)
            # Add EOS token between documents (except after the last one)
            if i < last_index:
                all_token_ids.append(eos_token_id)

        pad_token_id = self._resolve_pad_token_id()

        # Split into N fixed-size sequences
        sequences_input_ids = []
        sequences_attention_masks = []
        for start in range(0, len(all_token_ids), max_seq_length):
            seq = all_token_ids[start : start + max_seq_length]
            missing = max_seq_length - len(seq)
            sequences_attention_masks.append([1] * len(seq) + [0] * missing)
            sequences_input_ids.append(seq + [pad_token_id] * missing)

        return {
            "input_ids": sequences_input_ids,
            "attention_mask": sequences_attention_masks,
        }

    def _tokenize_batch_sft(self, records: List[Dict[str, Any]]) -> Dict[str, List]:
        """
        Tokenize a batch of SFT records with role tokens and response masking.

        Processes each record (single-turn or multi-turn) and generates sequences
        with role markers and masking (0=ignore user, 1=train on assistant).
        Records that cannot be parsed, or that raise while being tokenized, are
        skipped (the latter with a printed warning).

        Args:
            records: List of JSONL record dicts with various SFT formats

        Returns:
            Dict with 'input_ids', 'attention_mask', and 'mask' lists
        """
        # Import here to avoid circular imports
        from taoTrain.data.sft_utils import parse_sft_record, build_sft_sequence_tokens

        max_seq_length = self.config.model.max_seq_length
        user_token = getattr(self.config, 'user_token', '<user>')
        assistant_token = getattr(self.config, 'assistant_token', '<assistant>')

        sequences_input_ids = []
        sequences_attention_masks = []
        sequences_masks = []

        for record in records:
            try:
                # Parse SFT record (supports multiple formats)
                turns, _ = parse_sft_record(record, self.config)

                if not turns:
                    # Skip records that couldn't be parsed
                    continue

                # Build token sequence with role tokens and response masking
                input_ids, attention_mask, mask = build_sft_sequence_tokens(
                    turns=turns,
                    tokenizer=self.tokenizer,
                    user_token=user_token,
                    assistant_token=assistant_token,
                    max_seq_length=max_seq_length,
                )
            except Exception as e:
                # Log error but continue processing
                print(f"Warning: Failed to tokenize SFT record: {e}")
                continue

            sequences_input_ids.append(input_ids)
            sequences_attention_masks.append(attention_mask)
            sequences_masks.append(mask)

        return {
            "input_ids": sequences_input_ids,
            "attention_mask": sequences_attention_masks,
            "mask": sequences_masks,
        }

    def get_next_chunk(self, timeout: Optional[float] = None) -> Optional[Dict[str, List]]:
        """
        Get the next tokenized chunk from the queue.

        Blocks until a chunk is available, the timeout elapses, or no further
        chunk can arrive. Buffered chunks are always handed out before None is
        returned for exhaustion.

        Args:
            timeout: Timeout in seconds (None = wait until a chunk arrives or
                the queue is exhausted)

        Returns:
            Dict with tokenized chunk, or None if the timeout elapsed or the
            queue is exhausted (use ``is_exhausted`` to tell the two apart)

        Raises:
            RuntimeError: If an error occurred in any worker thread
        """
        self._raise_if_worker_failed()

        deadline = None if timeout is None else time.monotonic() + timeout

        while True:
            if deadline is None:
                wait = self._POLL_INTERVAL
            else:
                wait = min(self._POLL_INTERVAL, max(0.0, deadline - time.monotonic()))

            try:
                return self._queue.get(timeout=wait)
            except queue.Empty:
                pass

            self._raise_if_worker_failed()

            if self._no_more_chunks_coming():
                # A worker may have delivered its last chunk between the
                # timed-out get() above and its own exit; pick that up.
                try:
                    return self._queue.get_nowait()
                except queue.Empty:
                    return None

            if deadline is not None and time.monotonic() >= deadline:
                return None

    @property
    def is_exhausted(self) -> bool:
        """Return True only when all chunks are assigned and all workers are idle."""
        with self._active_threads_lock:
            return self._active_threads == 0 and self._next_chunk_idx >= len(self._chunk_order)

    def shutdown(self, wait: bool = True):
        """
        Shutdown the tokenization worker threads gracefully.

        Sets the stop signal, discards buffered chunks and optionally joins the
        workers. Does nothing when no threads were started.

        Args:
            wait: If True, wait for all threads to finish; otherwise return immediately
        """
        if not self._threads:
            return

        # Signal threads to stop
        self._stop_event.set()

        # Drain queue to unblock threads if they're waiting to put
        try:
            while True:
                self._queue.get_nowait()
        except queue.Empty:
            pass

        # Wait for all threads to finish
        if wait:
            for thread in self._threads:
                thread.join(timeout=5.0)
                if thread.is_alive():
                    print(f"⚠ Tokenization thread {thread.name} did not terminate cleanly")

        # Clear thread list to allow fresh start in next epoch
        self._threads.clear()
        print("✓ TokenizationQueue shutdown complete, thread list cleared")

    def reset_for_next_epoch(self):
        """
        Reset queue state for the next epoch.

        This allows the same TokenizationQueue to be reused across multiple epochs.
        Resets the chunk index counter, reshuffles chunks (if enabled), and clears
        any buffered items, the stop signal and the error state.
        """
        # Reset iteration counter
        with self._chunk_idx_lock:
            self._next_chunk_idx = 0

        # Reshuffle chunk order if enabled
        if self.shuffle_chunks:
            import random
            random.shuffle(self._chunk_order)
            print(f"✓ Reshuffled chunk order for next epoch: {self._chunk_order}")

        # Drain any remaining items from queue
        items_drained = 0
        try:
            while True:
                self._queue.get_nowait()
                items_drained += 1
        except queue.Empty:
            pass

        if items_drained > 0:
            print(f"⚠ Drained {items_drained} items from queue before epoch reset")

        # A previous shutdown() leaves the stop signal set; without clearing
        # it, workers started for the next epoch would exit immediately.
        self._stop_event.clear()

        # Clear error state
        self._error_event.clear()
        self._error_messages.clear()

        # Clear threads list so new threads will be started in next epoch
        self._threads.clear()

        print(f"✓ TokenizationQueue reset for next epoch. Ready to process {len(self._chunk_order)} chunks")

    def __len__(self) -> int:
        """Return total number of samples."""
        return self.total_items

    def __del__(self):
        """Cleanup on deletion."""
        # Guarded so a partially constructed instance never raises here.
        if getattr(self, "_threads", None):
            self.shutdown(wait=False)
code/TaoTrain/src/taoTrain/data/tokenizer.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SentencePiece tokenizer wrapper for HuggingFace compatibility."""
2
+
3
+ from typing import Optional, List, Union
4
+
5
+
6
class SentencePieceTokenizerWrapper:
    """
    Wrapper to make SentencePiece tokenizer compatible with HuggingFace interface.

    Only the subset of the HuggingFace tokenizer API implemented below is
    provided: calling the object, ``encode``, ``encode_plus`` and ``decode``,
    plus the ``*_token_id`` / ``vocab_size`` attributes.
    """

    def __init__(self, sp_processor):
        """
        Initialize wrapper.

        Args:
            sp_processor: sentencepiece.SentencePieceProcessor instance
        """
        self.sp = sp_processor
        self.vocab_size = self.sp.vocab_size()
        # NOTE(review): these ids are taken verbatim from the processor; a
        # model trained without a pad piece presumably reports a negative
        # pad id, which would then be used as the padding value below.
        self.pad_token_id = self.sp.pad_id()
        self.eos_token_id = self.sp.eos_id()
        self.bos_token_id = self.sp.bos_id()
        self.unk_token_id = self.sp.unk_id()

    def __call__(self, text, **kwargs):
        """
        Tokenize text.

        Args:
            text: Input text or list of texts
            **kwargs: Additional arguments (truncation, max_length, padding, return_attention_mask)

        Returns:
            Dict with input_ids and (unless return_attention_mask is falsy)
            attention_mask. For a single string the values are flat lists;
            for a list of strings they are lists of lists.

        Note:
            Whenever ``padding`` or ``max_length`` is given, every sequence is
            brought to exactly the target length (``max_length`` if set,
            otherwise the longest sequence): shorter ones are padded and
            longer ones are clipped, even when ``truncation`` is False.
        """
        # Handle both single string and list of strings
        is_single = isinstance(text, str)
        texts = [text] if is_single else list(text)

        max_length = kwargs.get('max_length', None)
        padding = kwargs.get('padding', None)
        truncation = kwargs.get('truncation', False)
        return_attention_mask = kwargs.get('return_attention_mask', True)

        # Tokenize all texts
        all_input_ids = []
        for t in texts:
            tokens = self.sp.encode(t, out_type=int)

            # Truncate if needed
            if truncation and max_length and len(tokens) > max_length:
                tokens = tokens[:max_length]

            all_input_ids.append(tokens)

        if not (padding or max_length):
            # No length normalisation requested: return ragged sequences.
            result = {
                "input_ids": all_input_ids[0] if is_single else all_input_ids,
            }
            if return_attention_mask:
                attention_masks = [[1] * len(ids) for ids in all_input_ids]
                result["attention_mask"] = attention_masks[0] if is_single else attention_masks
            return result

        # Fixed-length output.
        if max_length:
            target_length = max_length
        else:
            target_length = max((len(ids) for ids in all_input_ids), default=0)

        padded_input_ids = []
        padded_attention_masks = []
        for ids in all_input_ids:
            kept = ids[:target_length]
            pad_length = target_length - len(kept)
            padded_input_ids.append(kept + [self.pad_token_id] * pad_length)
            # The mask is built from the ids that are actually returned so
            # it always has the same length as input_ids, including when a
            # sequence was clipped to target_length.
            padded_attention_masks.append([1] * len(kept) + [0] * pad_length)

        result = {
            "input_ids": padded_input_ids if not is_single else padded_input_ids[0],
        }
        if return_attention_mask:
            result["attention_mask"] = padded_attention_masks if not is_single else padded_attention_masks[0]
        return result

    def encode(self, text, return_tensors=None, **kwargs):
        """
        Encode text to token IDs.

        Args:
            text: Input text or list of texts
            return_tensors: "pt" to get a torch.long tensor of shape
                [1, seq_len]; anything else returns plain Python lists
            **kwargs: Forwarded to ``__call__``

        Returns:
            Token ids as list(s), or a [1, seq_len] tensor. With
            return_tensors="pt" and a list of texts, only the first
            sequence is returned.
        """
        result = self(text, **kwargs)
        input_ids = result["input_ids"]

        if return_tensors == "pt":
            import torch
            # Ensure input_ids is a 1D list of ints. The emptiness check
            # keeps an empty encoding from raising IndexError.
            if input_ids and isinstance(input_ids[0], list):
                input_ids = input_ids[0]
            return torch.tensor(input_ids, dtype=torch.long).unsqueeze(0)

        return input_ids

    def encode_plus(self, text, **kwargs):
        """Encode text with additional information (HuggingFace compatibility)."""
        return self(text, **kwargs)

    def decode(self, token_ids, skip_special_tokens=False, **kwargs):
        """
        Decode token IDs to text.

        Args:
            token_ids: A single id, a sequence of ids, a nested sequence
                (only the first row is decoded), or any object with
                ``tolist()`` such as a torch tensor
            skip_special_tokens: Drop pad/bos/eos ids before decoding
            **kwargs: Accepted for HuggingFace compatibility and ignored

        Returns:
            Decoded text
        """
        if hasattr(token_ids, 'tolist'):  # Handle torch tensors
            token_ids = token_ids.tolist()

        if isinstance(token_ids, int):
            # A 0-d tensor's tolist() (or a caller) can hand us a bare id.
            token_ids = [token_ids]
        elif isinstance(token_ids, (list, tuple)):
            if len(token_ids) > 0 and isinstance(token_ids[0], (list, tuple)):
                token_ids = token_ids[0]

        # Ensure it's a list of ints
        token_ids = [int(t) for t in token_ids]

        if skip_special_tokens:
            special_ids = {self.pad_token_id, self.bos_token_id, self.eos_token_id}
            token_ids = [t for t in token_ids if t not in special_ids]

        return self.sp.decode(token_ids)
code/TaoTrain/src/taoTrain/inference/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Inference engines."""
2
+
3
+ from .inferencer import Inferencer
4
+
5
+ __all__ = ["Inferencer"]
code/TaoTrain/src/taoTrain/inference/inferencer.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Inference engine for model generation."""
2
+
3
+ from pathlib import Path
4
+ from typing import Optional, Iterator, Any
5
+ import torch
6
+ from transformers import AutoTokenizer
7
+ from rich.console import Console
8
+ from rich.table import Table
9
+
10
+ from taoTrain.core import BaseModel
11
+ from taoTrain.config import ModelConfig
12
+
13
+
14
class Inferencer:
    """
    Inference engine for text generation.

    Wraps a model and a tokenizer and runs a simple autoregressive decoding
    loop (no KV cache: the full sequence is re-fed to the model each step).
    """

    def __init__(
        self,
        model: BaseModel,
        tokenizer: Any,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Initialize inferencer.

        Args:
            model: Trained model
            tokenizer: Tokenizer instance (HuggingFace or SentencePiece wrapped)
            device: Device for inference (defaults to CUDA when available)
            dtype: Data type for inference. Stored on the instance only;
                the model is not cast here.
        """
        self.model = model
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = dtype or torch.float32
        self.tokenizer = tokenizer

        # Move model to device and set eval mode
        self.model = self.model.to(self.device)
        self.model.eval()

        # Set pad token if needed (for HuggingFace tokenizers)
        if hasattr(self.tokenizer, 'pad_token') and self.tokenizer.pad_token is None:
            if hasattr(self.tokenizer, 'eos_token'):
                self.tokenizer.pad_token = self.tokenizer.eos_token

    @staticmethod
    def _load_tokenizer(tokenizer_path: str | Path) -> Any:
        """
        Load tokenizer from path (SentencePiece or HuggingFace).

        A path ending in ``.model`` is loaded with SentencePiece; anything
        else is handed to ``AutoTokenizer.from_pretrained``.

        Args:
            tokenizer_path: Path to tokenizer file or HuggingFace model name

        Returns:
            Tokenizer instance

        Raises:
            ImportError: If sentencepiece (or the wrapper module) cannot be imported
            ValueError: If tokenizer cannot be loaded
        """
        tokenizer_path = str(tokenizer_path)

        # Auto-detect tokenizer type based on file extension
        if not tokenizer_path.endswith('.model'):
            # Load HuggingFace tokenizer
            try:
                return AutoTokenizer.from_pretrained(tokenizer_path)
            except Exception as e:
                raise ValueError(f"Failed to load HuggingFace tokenizer from {tokenizer_path}: {e}") from e

        # Only a failed sentencepiece import gets the "not installed" hint;
        # a failure importing the wrapper surfaces with its own message.
        try:
            import sentencepiece as spm
        except ImportError as e:
            raise ImportError("SentencePiece not installed. Install with: pip install sentencepiece") from e
        from taoTrain.data import SentencePieceTokenizerWrapper

        try:
            sp = spm.SentencePieceProcessor()
            sp.Load(tokenizer_path)
            # Wrap SentencePiece in a compatible interface
            return SentencePieceTokenizerWrapper(sp)
        except Exception as e:
            raise ValueError(f"Failed to load SentencePiece tokenizer from {tokenizer_path}: {e}") from e

    @staticmethod
    def _print_tokenizer_info(tokenizer: Any, tokenizer_path: str) -> None:
        """Print a small table describing the tokenizer (type, path, vocab size)."""
        console = Console()
        table = Table(title="Tokenizer Information")
        table.add_column("Property", style="cyan")
        table.add_column("Value", style="green")

        table.add_row("Type", "SentencePiece" if tokenizer_path.endswith('.model') else "HuggingFace")
        table.add_row("Path", str(tokenizer_path))

        if hasattr(tokenizer, 'vocab_size'):
            table.add_row("Vocab Size", str(tokenizer.vocab_size))

        console.print(table)

    @staticmethod
    def load_from_checkpoint(
        checkpoint_path: str | Path,
        tokenizer_path: Optional[str | Path] = None,
        device: Optional[torch.device] = None,
    ) -> "Inferencer":
        """
        Load model from checkpoint and create inferencer.

        Handles both canonical and legacy checkpoint formats:
        - Canonical: uses 'model_state' key
        - Legacy: uses 'model_state_dict' key

        Args:
            checkpoint_path: Path to checkpoint file
            tokenizer_path: Optional path to tokenizer (overrides checkpoint's tokenizer_path)
            device: Device for inference

        Returns:
            Inferencer instance

        Raises:
            ValueError: If no tokenizer path found in checkpoint or arguments
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load checkpoint using CheckpointManager for automatic format normalization
        from taoTrain.checkpointing.checkpoint import CheckpointManager
        checkpoint_manager = CheckpointManager(Path(checkpoint_path).parent)
        checkpoint = checkpoint_manager.load(checkpoint_path, device=device)

        config_dict = checkpoint.get("config", {})

        # Extract tokenizer path from checkpoint config or use provided override
        if tokenizer_path is None:
            dataset_config = config_dict.get("dataset", {})
            tokenizer_path = dataset_config.get("tokenizer_path")

        if not tokenizer_path:
            raise ValueError(
                f"No tokenizer path found in checkpoint config at {checkpoint_path}. "
                "Please provide --tokenizer argument with path to tokenizer file."
            )

        # Load tokenizer
        console = Console()
        console.print("\n[bold cyan]Loading tokenizer...[/bold cyan]")
        tokenizer = Inferencer._load_tokenizer(tokenizer_path)
        Inferencer._print_tokenizer_info(tokenizer, str(tokenizer_path))

        # Reconstruct model config
        model_config = ModelConfig(**config_dict.get("model", {}))

        # Create and load model
        # CheckpointManager.load() normalizes to 'model_state' key
        from taoTrain.models import get_model
        model = get_model(model_config, device=device)
        model.load_state_dict(checkpoint["model_state"])

        return Inferencer(model, tokenizer, device)

    def generate(
        self,
        prompt: str,
        max_length: int = 256,
        temperature: float = 0.7,
        top_p: float = 0.95,
        top_k: Optional[int] = None,
        repetition_penalty: float = 1.0,
        do_sample: bool = True,
        stream: bool = False,
    ) -> str | Iterator[str]:
        """
        Generate text from a prompt.

        This method is a plain function (not a generator), so with
        ``stream=False`` it really returns a ``str``; with ``stream=True`` it
        returns a lazy iterator of text pieces.

        Args:
            prompt: Input prompt
            max_length: Maximum number of new tokens to generate
            temperature: Temperature for sampling; a value <= 0 falls back to
                greedy decoding instead of dividing by zero
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter (None or <= 0 disables it)
            repetition_penalty: Penalty for repeated tokens (1.0 = no penalty, >1.0 = penalize)
            do_sample: Whether to sample or use greedy decoding
            stream: Whether to stream tokens

        Returns:
            Generated text (or an iterator of text pieces if stream=True)
        """
        token_ids = self._generate_token_ids(
            prompt,
            max_length=max_length,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            do_sample=do_sample,
        )
        if stream:
            return self._stream_text(token_ids)
        return self.tokenizer.decode(list(token_ids), skip_special_tokens=True)

    def _stream_text(self, token_ids: Iterator[int]) -> Iterator[str]:
        """Turn a stream of token ids into a stream of newly decoded text."""
        generated_token_ids = []
        last_decoded_full = ""
        for token_id in token_ids:
            generated_token_ids.append(token_id)
            # Decode the whole generated sequence each time so the tokenizer
            # sees full context (keeps inter-token spaces), then emit only
            # the part that was not emitted before.
            full_decoded_text = self.tokenizer.decode(generated_token_ids, skip_special_tokens=True)
            new_text = full_decoded_text[len(last_decoded_full):]
            if new_text:
                yield new_text
            last_decoded_full = full_decoded_text

    def _generate_token_ids(
        self,
        prompt: str,
        max_length: int,
        temperature: float,
        top_p: float,
        top_k: Optional[int],
        repetition_penalty: float,
        do_sample: bool,
    ) -> Iterator[int]:
        """Run the decoding loop, yielding one generated token id per step.

        Stops after ``max_length`` tokens or right after yielding the
        tokenizer's EOS id. Uses ``.item()``-style scalar extraction, so it
        expects a single prompt (batch size 1).
        """
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        prompt_length = input_ids.shape[1]
        eos_token_id = getattr(self.tokenizer, "eos_token_id", None)
        greedy = (not do_sample) or temperature <= 0

        with torch.no_grad():
            for _ in range(max_length):
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=None,
                    labels=None,
                )

                # Logits for the next token. The division also yields a new
                # tensor, so the in-place edits below never touch the model
                # output. Greedy decoding is temperature-invariant.
                next_logits = outputs["logits"][:, -1, :] / (1.0 if greedy else temperature)

                # Repetition penalty on tokens generated so far. Positive
                # logits are divided and negative ones multiplied so that a
                # penalty > 1 always makes the token less likely.
                if repetition_penalty != 1.0 and input_ids.shape[1] > prompt_length:
                    seen_ids = torch.unique(input_ids[0, prompt_length:])
                    seen_logits = next_logits[0, seen_ids]
                    next_logits[0, seen_ids] = torch.where(
                        seen_logits < 0,
                        seen_logits * repetition_penalty,
                        seen_logits / repetition_penalty,
                    )

                # Top-k filtering (k is clamped to the vocabulary size).
                if top_k is not None and top_k > 0:
                    k = min(top_k, next_logits.size(-1))
                    kth_value = torch.topk(next_logits, k)[0][..., -1, None]
                    next_logits = next_logits.masked_fill(next_logits < kth_value, float('-inf'))

                # Top-p (nucleus) filtering.
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
                    cumsum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

                    sorted_indices_to_remove = cumsum_probs > top_p
                    # Shift right so the first token crossing the threshold
                    # is kept, and always keep the most likely token.
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = False

                    # Map the mask from sorted order back to vocabulary
                    # order row by row.
                    indices_to_remove = sorted_indices_to_remove.scatter(
                        -1, sorted_indices, sorted_indices_to_remove
                    )
                    next_logits = next_logits.masked_fill(indices_to_remove, float('-inf'))

                # Sample or greedy
                if greedy:
                    next_token = torch.argmax(next_logits, dim=-1, keepdim=True)
                else:
                    probs = torch.softmax(next_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)

                # Append to input
                input_ids = torch.cat([input_ids, next_token], dim=-1)

                token_id = next_token.item()
                yield token_id

                # Stop on EOS
                if token_id == eos_token_id:
                    break

    def count_tokens_generated(
        self,
        prompt: str,
        max_length: int = 256,
    ) -> float:
        """
        Measure forward-pass speed (tokens per second).

        Times tokenization plus a single forward pass over the prompt; no
        actual generation loop is run.

        Args:
            prompt: Input prompt
            max_length: Unused; kept for interface compatibility

        Returns:
            (prompt tokens + 1) divided by the elapsed seconds, or 0.0 if no
            time elapsed
        """
        import time

        start = time.perf_counter()

        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            self.model(
                input_ids=input_ids,
                attention_mask=None,
                labels=None,
            )

        # CUDA kernels run asynchronously; wait for them before reading the clock.
        if torch.device(self.device).type == "cuda":
            torch.cuda.synchronize()

        elapsed = time.perf_counter() - start
        token_count = input_ids.shape[1] + 1

        return token_count / elapsed if elapsed > 0 else 0.0
code/TaoTrain/src/taoTrain/inference/tui.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """TUI (Terminal User Interface) for interactive chat."""
2
+
3
+ import sys
4
+ import time
5
+ from pathlib import Path
6
+ from typing import Optional
7
+ import click
8
+ from rich.console import Console
9
+ from rich.markdown import Markdown
10
+ from rich.panel import Panel
11
+ from rich.text import Text
12
+ from rich.table import Table
13
+ from textual.app import ComposeResult, RenderableType
14
+ from textual.containers import Container, Horizontal, Vertical
15
+ from textual.widgets import TextArea, Static, Button
16
+ from textual.binding import Binding
17
+
18
+ from taoTrain.inference import Inferencer
19
+
20
+
21
class TokensPerSecDisplay(Static):
    """Single-line widget that shows a tokens-per-second figure."""

    DEFAULT_CSS = """
    TokensPerSecDisplay {
        width: 100%;
        height: 1;
        background: $panel;
        border: solid $accent;
    }
    """

    def __init__(self, tps: float = 0.0):
        """Create the widget with an initial tokens/sec value."""
        super().__init__()
        self.tps = tps

    def render(self) -> RenderableType:
        """Build the styled "Tokens/sec" label from the stored value."""
        return Text(f"Tokens/sec: {self.tps:.2f}", style="bold cyan")

    def update_tps(self, tps: float):
        """Store a new tokens/sec value and ask the widget to update."""
        self.tps = tps
        self.update()
47
+
48
+
49
class SimpleChat:
    """Simple CLI-based chat interface (fallback for testing)."""

    def __init__(self, checkpoint_path: str | Path, tokenizer_path: Optional[str | Path] = None):
        """
        Load the model and print a short summary table.

        Args:
            checkpoint_path: Path to the model checkpoint
            tokenizer_path: Optional tokenizer override passed through to
                Inferencer.load_from_checkpoint
        """
        self.checkpoint_path = Path(checkpoint_path)
        self.tokenizer_path = tokenizer_path

        print("\nLoading model...")
        self.inferencer = Inferencer.load_from_checkpoint(
            self.checkpoint_path,
            tokenizer_path=self.tokenizer_path,
        )

        # Print model info
        console = Console()
        info_table = Table(title="Model Information")
        info_table.add_column("Property", style="cyan")
        info_table.add_column("Value", style="green")

        info_table.add_row("Checkpoint", str(self.checkpoint_path))
        if self.tokenizer_path:
            info_table.add_row("Tokenizer (override)", str(self.tokenizer_path))

        console.print(info_table)

    def run(self):
        """
        Run the chat loop until the user exits.

        The loop ends on 'exit'/'quit', Ctrl-C, or end of input (Ctrl-D /
        closed stdin). Any other error during a turn is reported and the
        loop continues with the next prompt.
        """
        console = Console()

        console.print("\n[bold cyan]Chat Interface[/bold cyan]")
        console.print("[dim]Type 'exit' or 'quit' to exit[/dim]\n")

        while True:
            try:
                # Get user input
                prompt = input("You: ").strip()

                if prompt.lower() in ["exit", "quit"]:
                    console.print("\n[yellow]Goodbye![/yellow]")
                    break

                if not prompt:
                    continue

                # Generate response
                console.print("\n[bold cyan]Assistant:[/bold cyan] ", end="")

                start_time = time.time()
                # Counts streamed pieces; a piece is emitted per generation
                # step that produced new text.
                token_count = 0

                # Stream generation
                for token in self.inferencer.generate(
                    prompt,
                    max_length=256,
                    temperature=0.7,
                    top_p=0.95,
                    repetition_penalty=10,
                    stream=True,
                ):
                    # Model output is arbitrary text: print it literally so
                    # square brackets are not parsed as rich markup and
                    # numbers/URLs are not re-coloured by the highlighter.
                    console.print(token, end="", soft_wrap=True, markup=False, highlight=False)
                    token_count += 1

                elapsed = time.time() - start_time
                tps = token_count / elapsed if elapsed > 0 else 0

                console.print(f"\n\n[dim]({tps:.1f} tokens/sec, {token_count} tokens)[/dim]\n")

            except (KeyboardInterrupt, EOFError):
                # EOFError must end the loop: once stdin is exhausted every
                # further input() raises it again, which would otherwise
                # spin forever in the generic handler below.
                console.print("\n\n[yellow]Chat interrupted.[/yellow]")
                break
            except Exception as e:
                # Exception text may contain brackets; print it literally
                # and colour it via style instead of inline markup.
                console.print(f"\nError: {e}\n", style="red", markup=False, highlight=False)
122
+
123
+
124
@click.command()
@click.option(
    "--model",
    type=click.Path(exists=True),
    required=True,
    help="Path to model checkpoint (.pt file)",
)
@click.option(
    "--tokenizer",
    # No exists=True here: the value may be a HuggingFace model name rather
    # than a file on disk, as the help text below promises.
    type=click.Path(),
    required=False,
    default=None,
    help="Path to tokenizer file (.model or HuggingFace path). If not provided, uses tokenizer_path from checkpoint config.",
)
def main(model: str, tokenizer: Optional[str]):
    """
    Interactive TUI chat with a trained model.

    Example:
    tui-chat --model checkpoints/best_model.pt
    tui-chat --model checkpoints/best_model.pt --tokenizer path/to/tokenizer.model
    """
    # Top-level CLI boundary: report any failure (missing file, bad
    # checkpoint, missing tokenizer path, ...) with its own message and exit
    # non-zero instead of showing a traceback.
    try:
        chat = SimpleChat(model, tokenizer_path=tokenizer)
        chat.run()
    except Exception as e:
        click.echo(f"Error: {e}", err=True)
        sys.exit(1)
158
+
159
+
160
+ if __name__ == "__main__":
161
+ main() # type: ignore
code/TaoTrain/src/taoTrain/logging/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Logging integrations."""
2
+
3
+ from .aim_logger import AimLogger
4
+
5
+ __all__ = ["AimLogger"]
code/TaoTrain/src/taoTrain/logging/aim_logger.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """AimStack logging integration."""
2
+
3
+ from pathlib import Path
4
+ from typing import Dict, Any, Optional
5
+ import subprocess
6
+ import json
7
+ from datetime import datetime
8
+
9
+ try:
10
+ from aim import Run
11
+ HAS_AIM = True
12
+ except ImportError:
13
+ HAS_AIM = False
14
+
15
+ from taoTrain.config import TrainingConfig
16
+
17
+
18
class AimLogger:
    """
    AimStack logger for tracking training metrics and hyperparameters.

    When the ``aim`` package is not installed, ``self.run`` stays ``None``
    and every logging method is a silent no-op.
    """

    def __init__(self, config: TrainingConfig):
        """
        Initialize AimStack logger.

        Creates the repository directory if needed, opens a run and logs the
        hyperparameters. Prints a warning instead when aim is unavailable.

        Args:
            config: Training configuration
        """
        self.config = config
        self.run: Optional[Run] = None

        if HAS_AIM:
            # Initialize AimStack run
            repo_path = Path(config.aim_repo)
            repo_path.mkdir(parents=True, exist_ok=True)

            self.run = Run(repo=str(repo_path))

            # Log hyperparameters
            self._log_hyperparameters()
        else:
            print("Warning: AimStack not installed. Install with: pip install aim")

    def _log_hyperparameters(self):
        """Log hyperparameters to AimStack."""
        if self.run is None:
            return

        # Log model config
        self.run["hparams/model"] = {
            "architecture": self.config.model.architecture_type.value,
            "vocab_size": self.config.model.vocab_size,
            "hidden_dim": self.config.model.hidden_dim,
            "num_layers": self.config.model.num_layers,
            "num_heads": self.config.model.num_heads,
            "dropout": self.config.model.dropout,
            "max_seq_length": self.config.model.max_seq_length,
        }

        # Log training config
        self.run["hparams/training"] = {
            "batch_size": self.config.batch_size,
            "num_epochs": self.config.num_epochs,
            "learning_rate": self.config.optimizer.learning_rate,
            "weight_decay": self.config.optimizer.weight_decay,
            "gradient_accumulation_steps": self.config.gradient_accumulation_steps,
            "max_grad_norm": self.config.max_grad_norm,
            "dtype": self.config.dtype.value,
            "seed": self.config.seed,
        }

        # Log optimizer and scheduler config
        self.run["hparams/optimizer"] = {
            "optimizer_type": self.config.optimizer.optimizer_type.value,
            "learning_rate": self.config.optimizer.learning_rate,
            "weight_decay": self.config.optimizer.weight_decay,
        }

        self.run["hparams/scheduler"] = {
            "scheduler_type": self.config.scheduler.scheduler_type.value,
            "warmup_steps": self.config.scheduler.warmup_steps,
            "warmup_ratio": self.config.scheduler.warmup_ratio,
        }

        # Log dataset config
        self.run["hparams/dataset"] = {
            "dataset_name": self.config.dataset.dataset_name,
            "split": self.config.dataset.split,
            "max_samples": self.config.dataset.max_samples,
        }

        # Log mode
        self.run["hparams/mode"] = self.config.mode.value

        # Log git hash if available. Only the expected failures are
        # tolerated: git missing (OSError) or not inside a repository
        # (CalledProcessError). A bare except would also swallow
        # KeyboardInterrupt/SystemExit.
        try:
            git_hash = subprocess.check_output(
                ["git", "rev-parse", "HEAD"],
                stderr=subprocess.DEVNULL
            ).decode().strip()
        except (subprocess.CalledProcessError, OSError):
            pass
        else:
            self.run["hparams/git_hash"] = git_hash

        # Log timestamp
        self.run["hparams/timestamp"] = datetime.now().isoformat()

    def _track_scalar(self, value: Any, name: str, step: Optional[int]) -> None:
        """Track one value as a float; values that are not numeric are skipped."""
        try:
            self.run.track(
                float(value),
                name=name,
                step=step,
            )
        except (ValueError, TypeError):
            # Skip non-numeric metrics
            pass

    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None):
        """
        Log metrics to AimStack.

        A "step" entry inside ``metrics`` takes precedence over the ``step``
        argument and is not tracked as a metric itself. One level of nested
        dicts is flattened to ``"outer/inner"`` names. The caller's dict is
        left unmodified.

        Args:
            metrics: Dict of metric names to values
            step: Global step (optional)
        """
        if self.run is None:
            return

        # Work on a shallow copy so popping "step" does not alter the
        # caller's dict.
        metrics = dict(metrics)
        step = metrics.pop("step", step)

        for metric_name, metric_value in metrics.items():
            if isinstance(metric_value, dict):
                # Flatten nested dicts
                for nested_key, nested_val in metric_value.items():
                    self._track_scalar(nested_val, f"{metric_name}/{nested_key}", step)
            else:
                self._track_scalar(metric_value, metric_name, step)

    def log_text(self, name: str, value: str, step: Optional[int] = None):
        """
        Log text content.

        Tracks the text as an ``aim.Text`` object when that type is
        available; otherwise stores it as a run parameter under
        ``text/<name>`` (without step information).

        Args:
            name: Name of the text sequence
            value: Text to log
            step: Global step (optional)
        """
        if self.run is None:
            return

        try:
            # Imported here because aim is an optional dependency.
            from aim import Text
        except ImportError:
            self.run[f"text/{name}"] = value
            return

        self.run.track(Text(value), name=name, step=step)

    def finish(self):
        """Finish the run. Safe to call more than once."""
        if self.run is not None:
            self.run.close()
            # Drop the handle so later log calls become no-ops instead of
            # writing to a closed run.
            self.run = None
code/TaoTrain/src/taoTrain/models/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Model architectures and registry."""
2
+
3
+ from .registry import get_model, register_architecture
4
+
5
+ __all__ = ["get_model", "register_architecture"]
code/TaoTrain/src/taoTrain/models/embeddings.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Low-Rank Factorized Embedding.
3
+
4
+ Uses standard nn.Linear for projection (NOT ternary quantization).
5
+ Embeddings should use full precision for good token representations.
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+
12
class FactorizedEmbedding(nn.Module):
    """
    Token embedding factorized through a low-rank bottleneck.

    Tokens are looked up in a narrow table (vocab_size x d_embed_rank) and
    then projected up to d_model by a bias-free Linear layer. Both parts are
    ordinary full-precision layers (no quantization). Parameter count is
        vocab_size * d_embed_rank + d_embed_rank * d_model
    instead of vocab_size * d_model for a plain embedding.
    """

    def __init__(self, vocab_size, d_model, d_embed_rank=96):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_embed_rank = d_embed_rank

        # Narrow lookup table followed by the up-projection to model width.
        self.embed = nn.Embedding(vocab_size, d_embed_rank)
        self.proj = nn.Linear(d_embed_rank, d_model, bias=False)

        # Small normal init on both weight matrices.
        for weight in (self.embed.weight, self.proj.weight):
            nn.init.normal_(weight, mean=0.0, std=0.02)

    def forward(self, input_ids):
        """
        Embed token ids.

        Args:
            input_ids: [batch_size, seq_len] tensor of token IDs

        Returns:
            embeddings: [batch_size, seq_len, d_model]
        """
        low_rank = self.embed(input_ids)  # [..., d_embed_rank]
        return self.proj(low_rank)        # [..., d_model]

    def get_num_params(self):
        """Return the parameter count implied by the stored dimensions."""
        table_params = self.vocab_size * self.d_embed_rank
        proj_params = self.d_embed_rank * self.d_model
        return table_params + proj_params
code/TaoTrain/src/taoTrain/models/mla_components.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DeepSeek-style Multi-head Latent Attention (MLA) with RoPE.
3
+
4
+ Key innovations:
5
+ 1. KV compression to latent space (reduce KV memory)
6
+ 2. Q stays in full dimension for expressive query space
7
+ 3. RoPE positional embeddings on Q and K
8
+ 4. Grouped Query Attention (GQA) for efficiency
9
+ 5. Learnable head combination weights
10
+ 6. Numerical stability via pre-norm and scaling
11
+ """
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ import math
17
+
18
+
19
def _residual_rms_norm(x, enabled=False, target=1.0, eps=1e-6, cap=None):
    """Rescale ``x`` along its last dimension based on its RMS.

    - ``enabled=True``: scale so the RMS becomes ``target`` (``cap`` is ignored).
    - ``enabled=False`` with ``cap`` set: only shrink, so the RMS does not
      exceed ``cap``; inputs already below the cap are left at scale 1.
    - neither: ``x`` is returned untouched (same object).

    The RMS is computed in float32 with ``eps`` added under the square root;
    the scale is cast back to ``x``'s dtype before multiplying.
    """
    if cap is None and not enabled:
        return x

    mean_square = x.float().square().mean(dim=-1, keepdim=True)
    rms = mean_square.add(eps).sqrt()

    if not enabled:
        limit = torch.tensor(float(cap), dtype=rms.dtype, device=rms.device)
        factor = torch.minimum(torch.ones_like(rms), limit / rms)
    else:
        factor = target / rms

    return x * factor.to(dtype=x.dtype)
29
+
30
+
31
class RotaryEmbedding(nn.Module):
    """Rotary position angle generator with a length-dependent frequency scaling.

    Produces the per-position rotation angles used by RoPE. Angles can be
    scaled down by ``(seq_len / max_seq_length) ** (1 / yarn_alpha)``
    (described in this file as YaRN-style context extension).

    Parameters:
        dim: Embedding dimension (must be even)
        rope_scale: Positions are divided by this value (default: 40)
        max_seq_length: Reference sequence length for the scaling (default: 1024)
        yarn_alpha: Exponent control for the scaling (default: 1.0).
            Scaling is skipped only when yarn_alpha == 1.0 AND
            seq_len <= max_seq_length; in every other case it is applied,
            including for short sequences when yarn_alpha != 1.0.
    """

    def __init__(self, dim, rope_scale=40.0, max_seq_length=1024, yarn_alpha=1.0):
        super().__init__()
        assert dim % 2 == 0, "Dimension must be even for rotary embeddings"
        self.dim = dim
        self.rope_scale = rope_scale
        self.max_seq_length = max_seq_length
        self.yarn_alpha = yarn_alpha

        # One inverse frequency per pair of dimensions: 10000^(-2i/dim).
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))

    def _apply_yarn_scaling(self, freqs, seq_len):
        """Divide the angles by the length-dependent scale factor.

        Args:
            freqs: [seq_len, dim//2] angle tensor
            seq_len: Current sequence length

        Returns:
            ``freqs`` unchanged when yarn_alpha == 1.0 and
            seq_len <= max_seq_length; otherwise
            ``freqs / (seq_len / max_seq_length) ** (1 / yarn_alpha)``.
        """
        if self.yarn_alpha == 1.0 and seq_len <= self.max_seq_length:
            return freqs

        length_ratio = seq_len / self.max_seq_length
        return freqs / (length_ratio ** (1.0 / self.yarn_alpha))

    def forward(self, seq_len, device):
        """Generate rotation angles for positions 0..seq_len-1.

        Args:
            seq_len: Current sequence length
            device: Device on which the position index is created

        Returns:
            [seq_len, dim] tensor: the [seq_len, dim//2] angles concatenated
            with themselves along the last dimension.
        """
        positions = torch.arange(seq_len, device=device).type_as(self.inv_freq) / self.rope_scale
        # Outer product of positions and inverse frequencies: [seq_len, dim//2]
        angles = positions[:, None] * self.inv_freq[None, :]
        angles = self._apply_yarn_scaling(angles, seq_len)
        return torch.cat((angles, angles), dim=-1)
96
+
97
+
98
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    For ``x = [a, b]`` split along the last axis this returns ``[-b, a]``.
    """
    front, back = torch.chunk(x, 2, dim=-1)
    return torch.cat([back.neg(), front], dim=-1)
102
+
103
+
104
def apply_rotary(x, cos, sin):
    """Rotate the leading channels of ``x`` with the given cos/sin tables.

    Args:
        x: [B, n_heads, seq_len, head_dim] or similar.
        cos: [seq_len, rot_dim] or [1, 1, seq_len, rot_dim].
        sin: Same shape as ``cos``.

    Returns:
        Tensor shaped like ``x``. The first ``min(rot_dim, head_dim)`` channels
        are rotated; any remaining channels of ``x`` are passed through.
    """
    # 2-D tables get two leading singleton axes so they broadcast over
    # batch and heads.
    if cos.dim() == 2:
        cos, sin = cos[None, None], sin[None, None]

    # Tables wider than x are cut down to x's channel count.
    # NOTE(review): if the table is laid out as cat(angles, angles), this cut
    # no longer lines up with the halves rotate_half pairs together; callers
    # should pass tables no wider than x.
    width = x.shape[-1]
    cos, sin = cos[..., :width], sin[..., :width]

    rot_dim = cos.shape[-1]
    rotated = x[..., :rot_dim]
    passthrough = x[..., rot_dim:]

    rotated = rotated * cos + rotate_half(rotated) * sin

    if passthrough.shape[-1] > 0:
        return torch.cat([rotated, passthrough], dim=-1)
    return rotated
130
+
131
+
132
class DeepSeekMLA(nn.Module):
    """
    DeepSeek-style Multi-head Latent Attention (MLA).

    Architecture:
        1. Layer-normalise the input.
        2. Project to Q with d_model channels and to K/V with d_latent_kv channels.
        3. Split Q into heads of d_model // n_heads channels and K/V into heads
           of d_latent_kv // n_heads channels.
        4. Apply RoPE to the leading channels of Q and K.
        5. Attend with the first d_head_latent channels of every Q head,
           scaled by 1 / sqrt(d_head_latent).
        6. Concatenate heads and project d_latent_kv -> d_model.

    Parameters:
        d_model: Model dimension
        d_latent_kv: Latent dimension for KV compression
        n_heads: Number of attention heads
        d_rope: Requested RoPE dimension per head (must be even)
        dropout: Dropout probability
        gqa_groups: Stored on the module; not used by this implementation
        rope_scale, max_seq_length, yarn_alpha: Forwarded to RotaryEmbedding
    """

    def __init__(self, d_model, d_latent_kv, n_heads, d_rope, dropout=0.1, gqa_groups=1,
                 rope_scale=40.0, max_seq_length=1024, yarn_alpha=1.0):
        super().__init__()
        self.d_model = d_model
        self.d_latent_kv = d_latent_kv
        self.n_heads = n_heads
        self.d_rope = d_rope
        self.gqa_groups = gqa_groups

        assert d_model % n_heads == 0, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        assert d_latent_kv % n_heads == 0, f"d_latent_kv ({d_latent_kv}) must be divisible by n_heads ({n_heads})"

        self.d_head_full = d_model // n_heads  # Full head dimension for Q
        self.d_head_latent = d_latent_kv // n_heads  # Latent head dimension for K/V

        # Scaling factor for attention scores (used by the manual fallback path;
        # the fused path derives the same value from Q's last dimension).
        self.scale = 1.0 / math.sqrt(self.d_head_latent)

        # Layer norm before attention for stability
        self.norm = nn.LayerNorm(d_model)

        # Q projection: d_model -> d_model (full dimension)
        self.q_proj = nn.Linear(d_model, d_model, bias=False)

        # K/V projections: d_model -> d_latent_kv (compressed)
        self.k_proj = nn.Linear(d_model, d_latent_kv, bias=False)
        self.v_proj = nn.Linear(d_model, d_latent_kv, bias=False)

        # RoPE for position encoding with YaRN support
        self.rotary = RotaryEmbedding(
            d_rope,
            rope_scale=rope_scale,
            max_seq_length=max_seq_length,
            yarn_alpha=yarn_alpha
        )

        # Output projection: d_latent_kv -> d_model
        self.out_proj = nn.Linear(d_latent_kv, d_model, bias=False)

        # Per-head weights. Registered (and therefore present in checkpoints)
        # but not read anywhere in forward().
        self.head_weights = nn.Parameter(torch.ones(n_heads))

        # Dropout
        self.attn_dropout = nn.Dropout(dropout)
        self.proj_dropout = nn.Dropout(dropout)

    def forward(self, x, attention_mask=None):
        """
        Args:
            x: [B, seq_len, d_model]
            attention_mask: [B, seq_len] (1 = keep, 0 = mask; applied to keys
                only, no causal masking is added here) or
                [B, 1, seq_len, seq_len] (full mask, 1 = attend)

        Returns:
            out: [B, seq_len, d_model]
        """
        B, seq_len, _ = x.shape
        device = x.device

        # Pre-norm
        x_norm = self.norm(x)

        # Project and reshape to [B, n_heads, seq_len, head_dim]
        q = self.q_proj(x_norm).view(B, seq_len, self.n_heads, self.d_head_full).transpose(1, 2)
        k = self.k_proj(x_norm).view(B, seq_len, self.n_heads, self.d_head_latent).transpose(1, 2)
        v = self.v_proj(x_norm).view(B, seq_len, self.n_heads, self.d_head_latent).transpose(1, 2)

        # ------------------------------------------------------------------
        # Apply RoPE to Q and K
        # ------------------------------------------------------------------
        # Q and K must be rotated over the SAME channels with the SAME angles,
        # otherwise their dot product stops depending on relative position.
        # K heads only have d_head_latent channels and only that many Q
        # channels reach the attention product, so the usable rotary width is
        # capped there (and kept even, since channels rotate in pairs).
        rope_dim = min(self.d_rope, self.d_head_latent)
        rope_dim -= rope_dim % 2
        if rope_dim > 0:
            # [seq_len, d_rope], laid out as cat(angles, angles)
            rotary_emb = self.rotary(seq_len, device)
            if rope_dim != self.d_rope:
                # Rebuild a narrower table that keeps the cat(angles, angles)
                # layout rotate_half relies on.
                pair_angles = rotary_emb[..., :rope_dim // 2]
                rotary_emb = torch.cat((pair_angles, pair_angles), dim=-1)
            cos = torch.cos(rotary_emb).unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, rope_dim]
            sin = torch.sin(rotary_emb).unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, rope_dim]

            q_rope = apply_rotary(q[..., :rope_dim], cos, sin)
            q = torch.cat([q_rope, q[..., rope_dim:]], dim=-1)

            k_rope = apply_rotary(k[..., :rope_dim], cos, sin)
            k = torch.cat([k_rope, k[..., rope_dim:]], dim=-1)

        # ------------------------------------------------------------------
        # Compute attention
        # ------------------------------------------------------------------
        # Only the first d_head_latent channels of Q take part in attention;
        # K and V already have that width.
        q_for_attn = q[..., :self.d_head_latent]  # [B, n_heads, seq_len, d_head_latent]

        # Convert the {0, 1} mask to boolean: True = attend, False = mask.
        attn_mask_bool = None
        if attention_mask is not None:
            if attention_mask.dim() == 2:
                # [B, seq_len] -> [B, 1, 1, seq_len]
                attn_mask_bool = attention_mask.bool().unsqueeze(1).unsqueeze(1)
            else:
                # Already 4D, just convert to bool
                attn_mask_bool = attention_mask.bool()

        # Attention dropout is active only in training mode.
        dropout_p = self.attn_dropout.p if self.training else 0.0

        if hasattr(F, "scaled_dot_product_attention"):
            # Fused path. The default scale is 1 / sqrt(q.size(-1)), which
            # equals self.scale; the keyword is omitted because it is not
            # needed for that default.
            out_heads = F.scaled_dot_product_attention(
                q_for_attn, k, v,
                attn_mask=attn_mask_bool,
                dropout_p=dropout_p,
            )  # [B, n_heads, seq_len, d_head_latent]
        else:
            scores = torch.matmul(q_for_attn, k.transpose(-2, -1)) * self.scale
            if attn_mask_bool is not None:
                scores = scores.masked_fill(~attn_mask_bool, torch.finfo(scores.dtype).min)
            attn_weights = F.softmax(scores, dim=-1)
            if dropout_p > 0.0:
                attn_weights = F.dropout(attn_weights, p=dropout_p, training=True)
            out_heads = torch.matmul(attn_weights, v)

        # ------------------------------------------------------------------
        # Concatenate heads: [B, n_heads, seq_len, d_head_latent]
        #                 -> [B, seq_len, d_latent_kv]
        # ------------------------------------------------------------------
        out_concat = out_heads.transpose(1, 2).reshape(B, seq_len, self.d_latent_kv)

        # Project back to d_model
        out = self.out_proj(out_concat)  # [B, seq_len, d_model]
        out = self.proj_dropout(out)

        return out
300
+
301
+
302
class AttentionBlock(nn.Module):
    """
    Residual block: MLA attention sub-layer followed by a SwiGLU feed-forward.

    Data flow:
        x -> x + dropout(MLA(x))               (MLA normalises its own input)
          -> optional residual RMS rescale
          -> x + dropout(FFN(LayerNorm(x)))    (FFN = out(value * silu(gate)))
          -> optional residual RMS rescale
    """

    def __init__(self, d_model, d_latent_kv, n_heads, d_rope, d_ff, dropout=0.1, gqa_groups=1,
                 rope_scale=40.0, max_seq_length=1024, yarn_alpha=1.0,
                 residual_rms_norm=False, residual_rms_target=1.0, residual_rms_cap=None,
                 residual_rms_eps=1e-6):
        super().__init__()

        # Settings forwarded to _residual_rms_norm after each residual add.
        self.residual_rms_norm = residual_rms_norm
        self.residual_rms_target = residual_rms_target
        self.residual_rms_cap = residual_rms_cap
        self.residual_rms_eps = residual_rms_eps

        # Attention sub-layer.
        self.mla = DeepSeekMLA(
            d_model,
            d_latent_kv,
            n_heads,
            d_rope,
            dropout,
            gqa_groups,
            rope_scale=rope_scale,
            max_seq_length=max_seq_length,
            yarn_alpha=yarn_alpha,
        )

        # SwiGLU feed-forward sub-layer.
        self.ff_norm = nn.LayerNorm(d_model)
        self.ff_gate = nn.Linear(d_model, d_ff, bias=False)
        self.ff_value = nn.Linear(d_model, d_ff, bias=False)
        self.ff_out = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def _rescale_residual(self, x):
        """Apply the configured residual-stream RMS handling to ``x``."""
        return _residual_rms_norm(
            x,
            self.residual_rms_norm,
            self.residual_rms_target,
            self.residual_rms_eps,
            self.residual_rms_cap,
        )

    def forward(self, x, attention_mask=None):
        """
        Args:
            x: [B, seq_len, d_model]
            attention_mask: [B, seq_len] or [B, 1, seq_len, seq_len]

        Returns:
            out: [B, seq_len, d_model]
        """
        # Attention branch. Note the branch output passes through this block's
        # dropout in addition to the projection dropout inside the MLA module.
        attended = self.mla(x, attention_mask)
        x = self._rescale_residual(x + self.dropout(attended))

        # Feed-forward branch (SwiGLU: value path gated by silu(gate path)).
        normed = self.ff_norm(x)
        gate = self.ff_gate(normed)
        value = self.ff_value(normed)
        projected = self.ff_out(value * F.silu(gate))
        x = self._rescale_residual(x + self.dropout(projected))

        return x
code/TaoTrain/src/taoTrain/models/registry.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Model architecture registry and factory."""
2
+
3
+ from typing import Dict, Type, Optional
4
+ import torch
5
+ from taoTrain.core import BaseModel
6
+ from taoTrain.config import ModelConfig
7
+
8
+
9
+ # Global registry for model architectures
10
+ _ARCHITECTURE_REGISTRY: Dict[str, Type[BaseModel]] = {}
11
+
12
+
13
def register_architecture(name: str):
    """Return a class decorator that records a model class under ``name``.

    The decorated class is stored in the module-level registry and returned
    unchanged.

    Raises:
        ValueError: At decoration time, if ``name`` is already registered.
    """
    def _register(model_cls: Type[BaseModel]):
        already_taken = name in _ARCHITECTURE_REGISTRY
        if already_taken:
            raise ValueError(f"Architecture '{name}' is already registered")
        _ARCHITECTURE_REGISTRY[name] = model_cls
        return model_cls

    return _register
21
+
22
+
23
def get_registered_architectures() -> Dict[str, Type[BaseModel]]:
    """Return a shallow copy of the name -> model class registry.

    Mutating the returned dict does not affect the registry itself.
    """
    return dict(_ARCHITECTURE_REGISTRY)
26
+
27
+
28
def get_model(
    config: ModelConfig,
    device: Optional[torch.device] = None,
) -> BaseModel:
    """
    Instantiate the registered architecture named by ``config``.

    Args:
        config: ModelConfig instance; its ``architecture_type`` is either a
            string or an object exposing the name as ``.value``.
        device: Device to move the model to (defaults to CPU).

    Returns:
        Model instance on ``device``.

    Raises:
        ValueError: If the architecture name is not registered.
    """
    target_device = torch.device('cpu') if device is None else device

    # Accept both plain strings and enum-like values.
    arch_type = config.architecture_type
    arch_name = arch_type if isinstance(arch_type, str) else arch_type.value

    if arch_name not in _ARCHITECTURE_REGISTRY:
        raise ValueError(
            f"Unknown architecture: {arch_name}. "
            f"Available: {list(_ARCHITECTURE_REGISTRY.keys())}"
        )

    return _ARCHITECTURE_REGISTRY[arch_name](config).to(target_device)
62
+
63
+
64
def register_builtin_architectures():
    """Import the built-in model modules for their registration side effects.

    The imports are local to this function to avoid circular imports: the
    model modules themselves import ``register_architecture`` from here.
    """
    from . import transformer, taonet, taonet_ssm  # noqa: F401
70
+
71
+
72
+ # Auto-register built-in architectures when module is imported
73
+ register_builtin_architectures()
code/TaoTrain/src/taoTrain/models/taonet.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SimpleLLM - Pure Attention-based Language Model with DeepSeek MLA + RoPE.
3
+
4
+ Architecture:
5
+ - Token Embedding → Attention Blocks → Output Head
6
+ - Attention Blocks: Multi-head Latent Attention with RoPE positional embeddings
7
+ - Feed-forward: SwiGLU gates
8
+ - No state-space models (SSM), pure transformer architecture
9
+ - Full BF16 precision (no quantization)
10
+ """
11
+
12
+ import math
13
+ from typing import Optional
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+ from taoTrain.core import BaseModel
19
+ from taoTrain.config import ModelConfig
20
+ from .registry import register_architecture
21
+ from .mla_components import AttentionBlock
22
+ from .embeddings import FactorizedEmbedding
23
+
24
+
25
@register_architecture("taonet")
class SimpleLLM(BaseModel):
    """
    Pure attention-based language model with DeepSeek MLA + RoPE.

    Stateless architecture - no internal state management needed.

    Args:
        config: ModelConfig with:
            - vocab_size: Vocabulary size
            - hidden_dim: Model dimension (d_model)
            - hidden_dim_ff: Feed-forward dimension (default: 4 * hidden_dim)
            - num_layers: Number of attention blocks (n_layers)
            - num_heads: Number of attention heads (n_attn_heads)
            - d_latent_kv: KV compression dimension (default: 3/4 * hidden_dim)
            - d_rope: RoPE dimension per head (default: hidden_dim // num_heads)
            - max_seq_length: Maximum sequence length
            - dropout: Dropout rate
            - gqa_groups: Grouped Query Attention groups (default: 1)
            - use_factorized_embedding: Use low-rank embedding (default: False)
    """

    def __init__(self, config: ModelConfig):
        super().__init__(config)

        # Parse config - use defaults if not specified
        self.vocab_size = config.vocab_size
        self.d_model = config.hidden_dim
        self.n_layers = config.num_layers
        self.n_heads = config.num_heads
        self.dropout = config.dropout

        # Optional parameters with smart defaults
        self.d_latent_kv = config.d_latent_kv if config.d_latent_kv is not None else int(self.d_model * 0.75)
        self.d_rope = config.d_rope if config.d_rope is not None else (self.d_model // self.n_heads)
        self.d_ff = config.hidden_dim_ff if config.hidden_dim_ff is not None else (self.d_model * 4)
        self.gqa_groups = getattr(config, 'gqa_groups', 1)
        self.use_factorized_embedding = getattr(config, 'use_factorized_embedding', False)
        self.d_embed_rank = getattr(config, 'd_embed_rank', 96)

        # YaRN parameters for context length extension
        # NOTE(review): yarn_enabled is stored but not consulted below;
        # yarn_alpha is always forwarded to the attention blocks.
        self.rope_scale = getattr(config, 'rope_scale', 40.0)
        self.yarn_enabled = getattr(config, 'yarn_enabled', False)
        self.yarn_alpha = getattr(config, 'yarn_alpha', 1.0)
        self.max_seq_length = config.max_seq_length

        # Validate dimensions
        assert self.d_model % self.n_heads == 0, \
            f"hidden_dim ({self.d_model}) must be divisible by num_heads ({self.n_heads})"
        assert self.d_latent_kv % self.n_heads == 0, \
            f"d_latent_kv ({self.d_latent_kv}) must be divisible by num_heads ({self.n_heads})"

        # Token embedding
        if self.use_factorized_embedding:
            self.token_embedding = FactorizedEmbedding(
                self.vocab_size,
                self.d_model,
                self.d_embed_rank
            )
        else:
            self.token_embedding = nn.Embedding(self.vocab_size, self.d_model)

        # Embedding dropout
        self.embedding_dropout = nn.Dropout(self.dropout)

        # Attention blocks with MLA + SwiGLU FFN
        self.blocks = nn.ModuleList()
        for _ in range(self.n_layers):
            self.blocks.append(
                AttentionBlock(
                    d_model=self.d_model,
                    d_latent_kv=self.d_latent_kv,
                    n_heads=self.n_heads,
                    d_rope=self.d_rope,
                    d_ff=int(self.d_ff),
                    dropout=self.dropout,
                    gqa_groups=self.gqa_groups,
                    rope_scale=self.rope_scale,
                    max_seq_length=self.max_seq_length,
                    yarn_alpha=self.yarn_alpha,
                )
            )

        # Final layer norm
        self.final_norm = nn.LayerNorm(self.d_model)

        # Output projection to vocabulary
        self.output_head = nn.Linear(self.d_model, self.vocab_size, bias=False)

        # Initialize weights
        self.apply(self._init_weights)

        # Cache for causal mask (non-persistent: never written to checkpoints)
        self.register_buffer("causal_mask_cache", None, persistent=False)

        self._print_architecture()

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights with N(0, 0.02) and zero biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _print_architecture(self):
        """Print model architecture summary."""
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

        print(f"\n{'='*70}")
        print("MODEL ARCHITECTURE - TAONET (DeepSeek MLA + RoPE)")
        print(f"{'='*70}")
        print("Embedding:")
        if self.use_factorized_embedding:
            embed_rank_params = self.vocab_size * self.d_embed_rank
            embed_proj_params = self.d_embed_rank * self.d_model
            print(f" Type: Factorized (rank={self.d_embed_rank})")
            print(f" Rank layer: {embed_rank_params/1e6:>8.2f}M")
            print(f" Projection: {embed_proj_params/1e6:>8.2f}M")
        else:
            embed_params = self.vocab_size * self.d_model
            print(" Type: Standard")
            print(f" Params: {embed_params/1e6:>8.2f}M")

        output_params = self.d_model * self.vocab_size
        print(f"Output Head: {output_params/1e6:>8.2f}M")
        print(f"Attention Blocks: {len(self.blocks):>10} layers × AttentionBlock")
        print(f"{'─'*70}")
        print(f"Total Parameters: {total_params/1e6:>8.2f}M (trainable: {trainable_params/1e6:.2f}M)")
        print(f"{'─'*70}")
        print("Configuration:")
        print(f" Model dimension (d_model): {self.d_model}")
        print(f" KV latent dimension (d_latent_kv): {self.d_latent_kv}")
        print(f" Attention heads: {self.n_heads}")
        print(f" Head dimension: {self.d_model // self.n_heads}")
        print(f" RoPE dimension: {self.d_rope}")
        print(f" Feed-forward dimension: {int(self.d_ff)}")
        print(f" Number of layers: {self.n_layers}")
        print(f" Max sequence length: {self.config.max_seq_length}")
        print(f" Dropout: {self.dropout}")
        print(f" GQA groups: {self.gqa_groups}")
        print(f"{'='*70}\n")

    def _get_causal_mask(self, seq_len, device):
        """Return a [seq_len, seq_len] boolean lower-triangular mask.

        The mask is cached and only rebuilt when a longer sequence is seen.
        """
        if self.causal_mask_cache is None or self.causal_mask_cache.size(-1) < seq_len:
            # True = attend, False = mask
            mask = torch.tril(torch.ones(seq_len, seq_len, device=device, dtype=torch.bool))
            self.register_buffer("causal_mask_cache", mask, persistent=False)
        return self.causal_mask_cache[:seq_len, :seq_len]

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> dict:
        """
        Forward pass through the model.

        Args:
            input_ids: [batch_size, seq_len] tensor of token IDs
            attention_mask: [batch_size, seq_len] tensor where 1 = valid, 0 = padding
            labels: [batch_size, seq_len] target token IDs for loss computation;
                they are compared position-by-position with the logits (no
                shifting happens here) and -100 entries are ignored

        Returns:
            Dictionary with:
                - 'logits': [batch_size, seq_len, vocab_size] output logits
                - 'loss': scalar loss (if labels provided, else None)
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Causal mask: [1, 1, seq_len, seq_len], True = attend
        combined_mask = self._get_causal_mask(seq_len, device).unsqueeze(0).unsqueeze(0)

        if attention_mask is not None:
            # Key-side padding mask: [batch, 1, 1, seq_len]
            padding_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
            combined_mask = combined_mask & padding_mask

            # Always let a position attend to itself. With left padding the
            # first rows would otherwise have no permitted key at all, and a
            # softmax over an empty set yields NaN that spreads to valid
            # positions in later layers. Valid queries are unaffected: their
            # own key was already permitted.
            diagonal = torch.eye(seq_len, dtype=torch.bool, device=device)
            combined_mask = combined_mask | diagonal

        # For MLA: convert to {0, 1} format
        combined_mask = combined_mask.float()

        # Embed tokens: [batch_size, seq_len] -> [batch_size, seq_len, d_model]
        x = self.token_embedding(input_ids)
        x = self.embedding_dropout(x)

        # Pass through attention blocks
        for block in self.blocks:
            x = block(x, attention_mask=combined_mask)

        # Final layer norm
        x = self.final_norm(x)

        # Output projection to vocabulary
        logits = self.output_head(x)  # [batch_size, seq_len, vocab_size]

        # Compute loss if labels are provided
        loss = None
        if labels is not None:
            # Flatten to (batch * seq_len, vocab_size) / (batch * seq_len,)
            logits_flat = logits.view(-1, logits.size(-1))
            labels_flat = labels.view(-1)

            # Targets equal to -100 do not contribute to the mean
            loss = F.cross_entropy(
                logits_flat,
                labels_flat,
                reduction='mean',
                ignore_index=-100
            )

        return {
            'logits': logits,
            'loss': loss,
        }
code/TaoTrain/src/taoTrain/models/taonet_ssm.py ADDED
@@ -0,0 +1,654 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """TaoNet variant that replaces MLA attention with an SSM mixer."""
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from taoTrain.config import ModelConfig
10
+ from taoTrain.core import BaseModel
11
+
12
+ from .embeddings import FactorizedEmbedding
13
+ from .mla_components import AttentionBlock
14
+ from .registry import register_architecture
15
+
16
+
17
def _load_ssm_core(core: str):
    """Return the SSM class selected by ``core`` ("gamma_s4" or "dplr").

    Both classes are imported lazily from the optional ``gamma_space_model``
    package before ``core`` is inspected.

    Raises:
        ImportError: If the ``gamma_space_model`` package cannot be imported.
        ValueError: If ``core`` is not a supported name.
    """
    try:
        from gamma_space_model.modules.s4_ternary_dplr_ssm import S4TernaryDPLRSSM
        from gamma_space_model.modules.ssm_gamma_s4 import SSMGammaS4
    except ImportError as exc:
        message = (
            "taonet_ssm requires the Gamma Space Model package. Install the SSM repo "
            "with `pip install -e /path/to/Taotern_SSM`, or put it on PYTHONPATH."
        )
        raise ImportError(message) from exc

    for core_name, core_cls in (("gamma_s4", SSMGammaS4), ("dplr", S4TernaryDPLRSSM)):
        if core == core_name:
            return core_cls
    raise ValueError(f"Unsupported ssm_core '{core}'.")
31
+
32
+
33
def _padding_mask_from_attention_mask(attention_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Reduce an attention mask to a per-position [batch, seq_len] mask.

    ``None`` and 2-D masks are returned as given. A 4-D mask is collapsed by
    marking a key position as kept when any query row attends to it; the
    result keeps the input dtype.

    Raises:
        ValueError: If the mask is neither 2-D nor 4-D.
    """
    if attention_mask is None:
        return None

    rank = attention_mask.dim()
    if rank == 2:
        return attention_mask
    if rank != 4:
        raise ValueError(
            "Expected attention_mask with shape [batch, seq_len] or "
            f"[batch, 1, seq_len, seq_len], got {tuple(attention_mask.shape)}."
        )

    # Reduce over the query axis, then drop the singleton head axis.
    key_is_used = attention_mask.bool().any(dim=-2)
    return key_is_used.squeeze(1).to(dtype=attention_mask.dtype)
44
+
45
+
46
def _hybrid_ssm_layer_indices(config: ModelConfig, num_layers: int) -> set[int]:
    """Return the set of layer indices selected for SSM mixing.

    A non-empty ``config.hybrid_ssm_layers`` (comma-separated integers, blank
    entries ignored) takes precedence; otherwise ``config.hybrid_pattern``
    chooses one of the predefined layouts.

    Raises:
        ValueError: For an out-of-range or non-integer index, an explicit
            list without any index, or an unknown pattern name.
    """
    explicit = config.hybrid_ssm_layers
    if explicit:
        chosen = set()
        for token in (part.strip() for part in explicit.split(",")):
            if not token:
                continue
            layer = int(token)
            if not 0 <= layer < num_layers:
                raise ValueError(
                    f"hybrid_ssm_layers index {layer} is outside [0, {num_layers - 1}]."
                )
            chosen.add(layer)
        if not chosen:
            raise ValueError("hybrid_ssm_layers was set but did not contain any valid layer indices.")
        return chosen

    pattern = config.hybrid_pattern
    if pattern == "attention_first":
        # Odd-numbered layers.
        return set(range(1, num_layers, 2))
    if pattern == "ssm_first":
        # Even-numbered layers.
        return set(range(0, num_layers, 2))
    if pattern == "single_ssm_middle":
        return {num_layers // 2}
    if pattern == "single_ssm_late":
        return {num_layers - 1}
    raise ValueError(f"Unsupported hybrid_pattern '{config.hybrid_pattern}'.")
72
+
73
+
74
class ChannelGate(nn.Module):
    """Per-channel affine map ``x * weight + bias`` over the last dimension.

    Both parameters have ``d_model`` entries. The weight starts at 0 and the
    bias at 2.0, so the initial output is 2.0 for every input.
    """

    def __init__(self, d_model: int) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.empty(d_model))
        self.bias = nn.Parameter(torch.empty(d_model))
        # Single source of truth for the initial values.
        self.reset_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scaled = x * self.weight
        return scaled + self.bias

    def reset_parameters(self) -> None:
        """Restore the initial state: weight = 0, bias = 2.0."""
        nn.init.constant_(self.bias, 2.0)
        nn.init.zeros_(self.weight)
88
+
89
+
90
+ def _build_gate(enabled: bool, gate_type: str, d_model: int) -> nn.Module | None:
91
+ if not enabled:
92
+ return None
93
+ if gate_type == "dense":
94
+ return nn.Linear(d_model, d_model)
95
+ if gate_type == "channel":
96
+ return ChannelGate(d_model)
97
+ raise ValueError(f"Unsupported ssm_gate_type '{gate_type}'.")
98
+
99
+
100
def _residual_rms_norm(
    x: torch.Tensor,
    enabled: bool,
    target: float,
    eps: float,
    cap: Optional[float] = None,
) -> torch.Tensor:
    """Rescale ``x`` along its last dimension based on its RMS.

    * ``enabled``: every vector is rescaled to RMS ``target`` (``cap`` is
      not consulted).
    * not ``enabled`` but ``cap`` given: vectors whose RMS exceeds ``cap`` are
      shrunk to it; others are left at scale 1.
    * neither: ``x`` is returned as is.

    The RMS is computed in float32 with ``eps`` added under the square root;
    the result keeps the dtype of ``x``.
    """
    if cap is None and not enabled:
        return x

    mean_square = x.float().pow(2).mean(dim=-1, keepdim=True)
    rms = torch.sqrt(mean_square + eps)

    if enabled:
        scale = target / rms
    else:
        # Never scale up: the factor is limited to at most 1.
        scale = (float(cap) / rms).clamp(max=1.0)
    return x * scale.to(dtype=x.dtype)
116
+
117
+
118
+ class SSMMixer(nn.Module):
119
+ """Causal sequence mixer with the same residual-branch contract as MLA."""
120
+
121
def __init__(self, config: ModelConfig) -> None:
    """Build the mixer from ``config``.

    Resolves dimensions and flags, validates the lane settings, then creates
    (in this order) the input norm, optional input gate, input projection,
    SSM lanes, optional lane weights, activation, optional output gate,
    output projection, layer scale, optional local-shift scale and dropout,
    and finally calls ``_reset_parameters``.

    Raises:
        ValueError: For unsupported or inconsistent lane, combine or
            activation settings.
    """
    super().__init__()
    SSMCore = _load_ssm_core(config.ssm_core)

    # ---- dimensions and behaviour flags ---------------------------------
    self.d_model = config.hidden_dim
    self.ssm_core = config.ssm_core
    d_latent_kv = config.d_latent_kv
    if d_latent_kv is None:
        d_latent_kv = int(self.d_model * 0.75)
    self.ssm_hidden_dim = d_latent_kv if config.ssm_hidden_dim is None else config.ssm_hidden_dim
    self.ssm_mixer_dim = self.d_model if config.ssm_mixer_dim is None else config.ssm_mixer_dim
    self.ssm_num_lanes = config.ssm_num_lanes
    self.ssm_lane_combine = config.ssm_lane_combine
    self.ssm_lane_mode = config.ssm_lane_mode
    self.ssm_split_mix = config.ssm_split_mix
    self.use_padding_mask = config.ssm_use_padding_mask
    self.branch_rms_norm = config.ssm_branch_rms_norm
    self.branch_rms_eps = config.ssm_branch_rms_eps
    self.branch_clip_value = config.ssm_branch_clip_value

    # ---- lane configuration checks --------------------------------------
    split_lanes = self.ssm_lane_mode == "split"
    if self.ssm_num_lanes < 1:
        raise ValueError("ssm_num_lanes must be at least 1.")
    if self.ssm_lane_mode not in {"full", "split"}:
        raise ValueError(f"Unsupported ssm_lane_mode '{self.ssm_lane_mode}'.")
    if self.ssm_split_mix not in {"none", "hadamard"}:
        raise ValueError(f"Unsupported ssm_split_mix '{self.ssm_split_mix}'.")
    if self.ssm_split_mix != "none" and not split_lanes:
        raise ValueError("ssm_split_mix is only supported when ssm_lane_mode='split'.")
    if self.ssm_split_mix == "hadamard" and self.ssm_num_lanes != 2:
        raise ValueError("ssm_split_mix='hadamard' currently requires exactly two SSM lanes.")
    if split_lanes and self.ssm_mixer_dim % self.ssm_num_lanes != 0:
        raise ValueError(
            "ssm_mixer_dim must be divisible by ssm_num_lanes when ssm_lane_mode='split'."
        )
    # In split mode each lane handles an equal slice of the mixer channels;
    # in full mode every lane sees all of them.
    if split_lanes:
        self.ssm_lane_dim = self.ssm_mixer_dim // self.ssm_num_lanes
    else:
        self.ssm_lane_dim = self.ssm_mixer_dim

    # ---- input side ------------------------------------------------------
    self.norm = nn.LayerNorm(self.d_model)
    self.gate_type = config.ssm_gate_type
    self.input_gate = _build_gate(config.ssm_input_gate, self.gate_type, self.d_model)
    if self.ssm_mixer_dim != self.d_model:
        self.input_proj = nn.Linear(self.d_model, self.ssm_mixer_dim, bias=False)
    else:
        self.input_proj = nn.Identity()

    # ---- SSM lanes -------------------------------------------------------
    common_kwargs = dict(
        state_dim=self.ssm_lane_dim,
        hidden_dim=self.ssm_hidden_dim,
        dt_min=config.ssm_dt_min,
        dt_max=config.ssm_dt_max,
        dt_init=config.ssm_dt_init,
        use_D=config.ssm_use_d,
        kernel_mode=config.ssm_kernel_mode,
        kernel_threshold=config.ssm_kernel_threshold,
    )
    lanes = []
    for _ in range(self.ssm_num_lanes):
        lanes.append(self._build_ssm_lane(SSMCore, common_kwargs, config))
    self.ssm_lanes = nn.ModuleList(lanes)
    # Alias for the first lane (the same module object as ssm_lanes[0]).
    self.ssm = self.ssm_lanes[0]

    self.lane_weights = None
    if self.ssm_lane_combine not in {"mean", "channel"}:
        raise ValueError(f"Unsupported ssm_lane_combine '{self.ssm_lane_combine}'.")
    wants_lane_weights = (
        self.ssm_lane_mode == "full"
        and self.ssm_num_lanes > 1
        and self.ssm_lane_combine == "channel"
    )
    if wants_lane_weights:
        # Start from a uniform average over lanes for every channel.
        uniform = torch.full((self.ssm_num_lanes, self.ssm_mixer_dim), 1.0 / self.ssm_num_lanes)
        self.lane_weights = nn.Parameter(uniform)

    # ---- activation ------------------------------------------------------
    activation_classes = {
        "gelu": nn.GELU,
        "silu": nn.SiLU,
        "identity": nn.Identity,
        "linear": nn.Identity,
    }
    activation_cls = activation_classes.get(config.ssm_activation)
    if activation_cls is None:
        raise ValueError(f"Unsupported ssm_activation '{config.ssm_activation}'.")
    self.activation = activation_cls()

    # ---- output side -----------------------------------------------------
    self.output_gate = _build_gate(config.ssm_gate, self.gate_type, self.d_model)
    self.out_proj = nn.Linear(self.ssm_mixer_dim, self.d_model, bias=False)
    self.layer_scale = nn.Parameter(torch.full((self.d_model,), config.ssm_layer_scale_init))

    self.local_shift_scale = None
    if config.ssm_local_shift:
        shift_init = float(config.ssm_local_shift_init)
        if config.ssm_local_shift_per_channel:
            self.local_shift_scale = nn.Parameter(torch.full((self.d_model,), shift_init))
        else:
            self.local_shift_scale = nn.Parameter(torch.tensor(shift_init))
    self.proj_dropout = nn.Dropout(config.dropout)

    self._reset_parameters()
215
+
216
+ def _normalize_branch(self, ssm_out: torch.Tensor) -> torch.Tensor:
217
+ if not self.branch_rms_norm:
218
+ return ssm_out
219
+ rms = ssm_out.float().square().mean(dim=-1, keepdim=True).add(self.branch_rms_eps).rsqrt()
220
+ return ssm_out * rms.to(dtype=ssm_out.dtype)
221
+
222
def _build_ssm_lane(self, SSMCore, common_kwargs: dict, config: ModelConfig) -> nn.Module:
    """Instantiate a single SSM lane from the selected core class.

    Args:
        SSMCore: Callable/class used to construct the lane.
        common_kwargs: Keyword arguments shared by every core type.
        config: Model configuration supplying the core-specific options.

    Returns:
        The constructed lane module.
    """
    if config.ssm_core == "gamma_s4":
        # The "gamma_s4" core only takes a discretization option on top of
        # the shared arguments.
        core_kwargs = {"discretization": config.ssm_discretization}
    else:
        # Every other core value receives the low-rank options.
        core_kwargs = {
            "rank": config.ssm_rank,
            "max_low_rank_scale": config.ssm_max_low_rank_scale,
            "finite_tail_correction": config.ssm_finite_tail_correction,
        }
    return SSMCore(**common_kwargs, **core_kwargs)
234
+
235
+ def _reset_parameters(self) -> None:
236
+ if isinstance(self.input_gate, nn.Linear):
237
+ nn.init.zeros_(self.input_gate.weight)
238
+ nn.init.constant_(self.input_gate.bias, 2.0)
239
+ elif isinstance(self.input_gate, ChannelGate):
240
+ self.input_gate.reset_parameters()
241
+ if isinstance(self.output_gate, nn.Linear):
242
+ nn.init.zeros_(self.output_gate.weight)
243
+ nn.init.constant_(self.output_gate.bias, 2.0)
244
+ elif isinstance(self.output_gate, ChannelGate):
245
+ self.output_gate.reset_parameters()
246
+ if isinstance(self.input_proj, nn.Linear):
247
+ nn.init.xavier_uniform_(self.input_proj.weight)
248
+ nn.init.xavier_uniform_(self.out_proj.weight)
249
+ else:
250
+ nn.init.eye_(self.out_proj.weight)
251
+
252
def forward(
    self,
    x: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the SSM mixing branch and return its (non-residual) output.

    Steps: LayerNorm, optional sigmoid input gate, input projection, one or
    more SSM lanes, lane combination, activation, output projection,
    optional sigmoid output gate, optional branch RMS norm, layer scale,
    optional one-step local shift, optional clamp, dropout.

    Args:
        x: Input activations; indexing below assumes the sequence axis is
            dim 1 and features are last (presumably (batch, seq, d_model)).
        attention_mask: Only consulted when ``self.use_padding_mask`` is
            set, in which case it is converted by
            ``_padding_mask_from_attention_mask``.

    Returns:
        The branch output after ``self.proj_dropout``.
    """
    x_norm = self.norm(x)

    if self.input_gate is None:
        gated = x_norm
    else:
        gated = x_norm * torch.sigmoid(self.input_gate(x_norm))
    ssm_in = self.input_proj(gated)

    padding_mask = None
    if self.use_padding_mask:
        padding_mask = _padding_mask_from_attention_mask(attention_mask)

    split_mode = self.ssm_lane_mode == "split"
    if split_mode:
        # Each lane sees its own feature slice of width ssm_lane_dim.
        lane_inputs = torch.split(ssm_in, self.ssm_lane_dim, dim=-1)
    else:
        # "full" mode: every lane receives the complete projected input.
        lane_inputs = [ssm_in] * self.ssm_num_lanes

    lane_outputs = []
    for lane, lane_input in zip(self.ssm_lanes, lane_inputs):
        # Lanes return a pair; the second element is discarded here.
        lane_out, _ = lane(lane_input, mask=padding_mask, return_state=False)
        lane_outputs.append(lane_out)

    if split_mode:
        if self.ssm_split_mix == "hadamard":
            # Two-lane sum/difference mix scaled by 1/sqrt(2).
            left, right = lane_outputs
            mixed = torch.cat((left + right, left - right), dim=-1) * 0.7071067811865476
        else:
            mixed = torch.cat(lane_outputs, dim=-1)
    elif len(lane_outputs) == 1:
        mixed = lane_outputs[0]
    elif self.lane_weights is None:
        # Unweighted average across lanes.
        mixed = torch.stack(lane_outputs, dim=0).mean(dim=0)
    else:
        # Learned per-lane, per-channel weights.
        reference = lane_outputs[0]
        weights = self.lane_weights.to(dtype=reference.dtype, device=reference.device)
        stacked = torch.stack(lane_outputs, dim=2)
        mixed = (stacked * weights.view(1, 1, self.ssm_num_lanes, self.ssm_mixer_dim)).sum(dim=2)

    branch = self.out_proj(self.activation(mixed))

    if self.output_gate is not None:
        branch = branch * torch.sigmoid(self.output_gate(x_norm))

    branch = self._normalize_branch(branch)
    branch = branch * self.layer_scale

    if self.local_shift_scale is not None:
        # Add the normalised input delayed by one position along dim 1;
        # position 0 receives zeros.
        shifted = torch.zeros_like(x_norm)
        shifted[:, 1:] = x_norm[:, :-1]
        branch = branch + shifted * self.local_shift_scale

    clip = self.branch_clip_value
    if clip is not None:
        branch = torch.clamp(branch, -clip, clip)
    return self.proj_dropout(branch)
306
+
307
+
308
class SSMAttentionBlock(nn.Module):
    """TaoNet block with Gamma SSM sequence mixing and the original SwiGLU FFN.

    Two residual sub-layers are applied in order: an ``SSMMixer`` and a
    bias-free SwiGLU feed-forward network evaluated on a LayerNorm of the
    stream. After each residual addition the stream goes through
    ``_residual_rms_norm`` with this block's residual-RMS settings.
    """

    def __init__(self, config: ModelConfig) -> None:
        """Build the mixer and feed-forward sub-layers from ``config``.

        The feed-forward width is ``config.hidden_dim_ff`` when it is set,
        otherwise four times ``config.hidden_dim``; it is passed through
        ``int`` before sizing the linear layers.
        """
        super().__init__()
        width = config.hidden_dim
        if config.hidden_dim_ff is None:
            ff_width = int(width * 4)
        else:
            ff_width = int(config.hidden_dim_ff)

        self.mixer = SSMMixer(config)

        # Stored as-is and forwarded to _residual_rms_norm in forward().
        self.residual_rms_norm = config.block_residual_rms_norm
        self.residual_rms_target = config.block_residual_rms_target
        self.residual_rms_cap = config.block_residual_rms_cap
        self.residual_rms_eps = config.block_residual_rms_eps

        self.ff_norm = nn.LayerNorm(width)
        self.ff_gate = nn.Linear(width, ff_width, bias=False)
        self.ff_value = nn.Linear(width, ff_width, bias=False)
        self.ff_out = nn.Linear(ff_width, width, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def _renormalize(self, stream: torch.Tensor) -> torch.Tensor:
        """Apply ``_residual_rms_norm`` with this block's stored settings."""
        return _residual_rms_norm(
            stream,
            self.residual_rms_norm,
            self.residual_rms_target,
            self.residual_rms_eps,
            self.residual_rms_cap,
        )

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the mixer and feed-forward residual sub-layers.

        Args:
            x: Residual stream entering the block.
            attention_mask: Passed through unchanged to the mixer.

        Returns:
            The residual stream after both sub-layers.
        """
        # The block-level dropout is applied on top of whatever the mixer
        # returns.
        mixed = self.mixer(x, attention_mask=attention_mask)
        x = self._renormalize(x + self.dropout(mixed))

        normed = self.ff_norm(x)
        gate = self.ff_gate(normed)
        value = self.ff_value(normed)
        # SwiGLU: the value path is modulated by SiLU of the gate path.
        ff_out = self.ff_out(value * F.silu(gate))
        return self._renormalize(x + self.dropout(ff_out))
354
+
355
+
356
@register_architecture("taonet_ssm")
class TaoNetSSMLLM(BaseModel):
    """TaoNet language model with SSM blocks replacing MLA attention.

    The model is: token embedding -> dropout -> ``num_layers`` x
    ``SSMAttentionBlock`` -> LayerNorm -> bias-free output head.
    """

    def __init__(self, config: ModelConfig):
        """Build the model from ``config`` and print an architecture summary."""
        super().__init__(config)

        width = config.hidden_dim
        self.vocab_size = config.vocab_size
        self.d_model = width
        self.n_layers = config.num_layers
        self.n_heads = config.num_heads
        self.dropout = config.dropout
        # Fallbacks: 0.75 * d_model for the latent width, 4 * d_model for the FFN.
        self.d_latent_kv = int(width * 0.75) if config.d_latent_kv is None else config.d_latent_kv
        self.d_ff = width * 4 if config.hidden_dim_ff is None else config.hidden_dim_ff
        self.use_factorized_embedding = getattr(config, "use_factorized_embedding", False)
        self.d_embed_rank = getattr(config, "d_embed_rank", 96)
        self.max_seq_length = config.max_seq_length

        if not self.use_factorized_embedding:
            self.token_embedding = nn.Embedding(self.vocab_size, self.d_model)
        else:
            self.token_embedding = FactorizedEmbedding(
                self.vocab_size,
                self.d_model,
                self.d_embed_rank,
            )

        self.embedding_dropout = nn.Dropout(self.dropout)
        self.blocks = nn.ModuleList(SSMAttentionBlock(config) for _ in range(self.n_layers))
        self.final_norm = nn.LayerNorm(self.d_model)
        self.output_head = nn.Linear(self.d_model, self.vocab_size, bias=False)

        # The generic init overwrites every nn.Linear, including the mixer's
        # gates and projections, so the mixer-specific init is re-applied.
        self.apply(self._init_weights)
        for block in self.blocks:
            block.mixer._reset_parameters()

        self._print_architecture(config)

    def _init_weights(self, module):
        """Normal(0, ``config.init_std``) init for Linear/Embedding weights; zero biases."""
        std = self.config.init_std
        if isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=std)
        elif isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def _print_architecture(self, config: ModelConfig):
        """Print a parameter-count and configuration summary to stdout."""
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        ssm_hidden_dim = self.d_latent_kv if config.ssm_hidden_dim is None else config.ssm_hidden_dim
        heavy_rule = "=" * 70
        light_rule = "-" * 70

        report = [
            f"\n{heavy_rule}",
            f"MODEL ARCHITECTURE - TAONET-SSM ({config.ssm_core} + SwiGLU)",
            f"{heavy_rule}",
            f"Embedding vocab: {self.vocab_size}",
            f"Output Head: {(self.d_model * self.vocab_size) / 1e6:>8.2f}M",
            f"SSM Blocks: {len(self.blocks):>8} layers x SSMMixer",
            f"{light_rule}",
            f"Total Parameters: {total_params / 1e6:>8.2f}M (trainable: {trainable_params / 1e6:.2f}M)",
            f"{light_rule}",
            "Configuration:",
            f" Model dimension (d_model): {self.d_model}",
            f" SSM core: {config.ssm_core}",
            f" SSM hidden dimension: {ssm_hidden_dim}",
            f" SSM mixer dimension: {config.ssm_mixer_dim or self.d_model}",
            f" SSM lanes: {config.ssm_num_lanes}",
            f" SSM lane mode: {config.ssm_lane_mode}",
            f" SSM split mix: {config.ssm_split_mix}",
            f" SSM lane combine: {config.ssm_lane_combine}",
        ]
        # NOTE(review): only the rank row is treated as DPLR-specific here;
        # confirm against the upstream file.
        if config.ssm_core == "dplr":
            report.append(f" SSM DPLR rank: {config.ssm_rank}")
        report += [
            f" SSM discretization: {config.ssm_discretization}",
            f" SSM kernel mode: {config.ssm_kernel_mode}",
            f" SSM kernel threshold: {config.ssm_kernel_threshold}",
            f" SSM padding mask enabled: {config.ssm_use_padding_mask}",
            f" SSM gate type: {config.ssm_gate_type}",
            f" SSM branch RMS norm: {config.ssm_branch_rms_norm}",
            f" SSM branch clip value: {config.ssm_branch_clip_value}",
            f" Block residual RMS norm: {config.block_residual_rms_norm}",
            f" Block residual RMS cap: {config.block_residual_rms_cap}",
            f" SSM local shift enabled: {config.ssm_local_shift}",
            f" SSM local shift per channel: {config.ssm_local_shift_per_channel}",
            f" Feed-forward dimension: {int(self.d_ff)}",
            f" Number of layers: {self.n_layers}",
            f" Max sequence length: {self.max_seq_length}",
            f" Dropout: {self.dropout}",
            f"{heavy_rule}\n",
        ]
        for line in report:
            print(line)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> dict:
        """Compute logits and, when ``labels`` is given, the cross-entropy loss.

        Args:
            input_ids: Token ids fed to the embedding.
            attention_mask: Forwarded unchanged to every block.
            labels: Optional targets; entries equal to -100 are ignored.
                Labels are used as given (no shifting happens here).

        Returns:
            ``{"logits": ..., "loss": ...}`` with ``loss`` set to ``None``
            when no labels were supplied.
        """
        hidden = self.embedding_dropout(self.token_embedding(input_ids))
        for block in self.blocks:
            hidden = block(hidden, attention_mask=attention_mask)
        logits = self.output_head(self.final_norm(hidden))

        if labels is None:
            loss = None
        else:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                labels.view(-1),
                reduction="mean",
                ignore_index=-100,
            )

        return {"logits": logits, "loss": loss}
472
+
473
+
474
@register_architecture("taonet_hybrid")
class TaoNetHybridLLM(BaseModel):
    """TaoNet language model with alternating MLA attention and SSM mixer blocks.

    Which layers are SSM blocks is decided by ``_hybrid_ssm_layer_indices``;
    every other layer is an ``AttentionBlock``.
    """

    def __init__(self, config: ModelConfig):
        """Build the hybrid stack from ``config`` and print a summary."""
        super().__init__(config)

        width = config.hidden_dim
        heads = config.num_heads
        self.vocab_size = config.vocab_size
        self.d_model = width
        self.n_layers = config.num_layers
        self.n_heads = heads
        self.dropout = config.dropout
        # Fallbacks used when the config leaves a dimension unset.
        self.d_latent_kv = int(width * 0.75) if config.d_latent_kv is None else config.d_latent_kv
        self.d_rope = width // heads if config.d_rope is None else config.d_rope
        self.d_ff = width * 4 if config.hidden_dim_ff is None else config.hidden_dim_ff
        self.gqa_groups = getattr(config, "gqa_groups", 1)
        self.use_factorized_embedding = getattr(config, "use_factorized_embedding", False)
        self.d_embed_rank = getattr(config, "d_embed_rank", 96)
        self.rope_scale = getattr(config, "rope_scale", 40.0)
        self.yarn_alpha = getattr(config, "yarn_alpha", 1.0)
        self.max_seq_length = config.max_seq_length

        assert self.d_model % self.n_heads == 0, (
            f"hidden_dim ({self.d_model}) must be divisible by num_heads ({self.n_heads})"
        )
        assert self.d_latent_kv % self.n_heads == 0, (
            f"d_latent_kv ({self.d_latent_kv}) must be divisible by num_heads ({self.n_heads})"
        )

        if not self.use_factorized_embedding:
            self.token_embedding = nn.Embedding(self.vocab_size, self.d_model)
        else:
            self.token_embedding = FactorizedEmbedding(
                self.vocab_size,
                self.d_model,
                self.d_embed_rank,
            )

        self.embedding_dropout = nn.Dropout(self.dropout)
        self.blocks = nn.ModuleList()
        self.block_kinds: list[str] = []
        self.ssm_layer_indices = _hybrid_ssm_layer_indices(config, self.n_layers)

        # Identical for every attention layer, so assembled once.
        attention_kwargs = {
            "d_model": self.d_model,
            "d_latent_kv": self.d_latent_kv,
            "n_heads": self.n_heads,
            "d_rope": self.d_rope,
            "d_ff": int(self.d_ff),
            "dropout": self.dropout,
            "gqa_groups": self.gqa_groups,
            "rope_scale": self.rope_scale,
            "max_seq_length": self.max_seq_length,
            "yarn_alpha": self.yarn_alpha,
            "residual_rms_norm": config.block_residual_rms_norm,
            "residual_rms_target": config.block_residual_rms_target,
            "residual_rms_cap": config.block_residual_rms_cap,
            "residual_rms_eps": config.block_residual_rms_eps,
        }
        for layer_idx in range(self.n_layers):
            if layer_idx in self.ssm_layer_indices:
                kind, block = "ssm", SSMAttentionBlock(config)
            else:
                kind, block = "attention", AttentionBlock(**attention_kwargs)
            self.blocks.append(block)
            self.block_kinds.append(kind)

        self.final_norm = nn.LayerNorm(self.d_model)
        self.output_head = nn.Linear(self.d_model, self.vocab_size, bias=False)

        # The generic init overwrites every nn.Linear; blocks exposing a
        # ``mixer`` attribute get their mixer-specific init re-applied.
        self.apply(self._init_weights)
        for block in self.blocks:
            mixer = getattr(block, "mixer", None)
            if mixer is None:
                continue
            mixer._reset_parameters()

        self.register_buffer("causal_mask_cache", None, persistent=False)
        self._print_architecture(config)

    def _init_weights(self, module):
        """Normal(0, ``config.init_std``) init for Linear/Embedding weights; zero biases."""
        std = self.config.init_std
        if isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=std)
        elif isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def _get_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Return a ``(seq_len, seq_len)`` lower-triangular boolean mask.

        The mask is cached in the non-persistent ``causal_mask_cache`` buffer
        and rebuilt only when the cache is empty or smaller than ``seq_len``.
        """
        cached = self.causal_mask_cache
        if cached is None or cached.size(-1) < seq_len:
            lower = torch.tril(torch.ones(seq_len, seq_len, device=device, dtype=torch.bool))
            self.register_buffer("causal_mask_cache", lower, persistent=False)
        return self.causal_mask_cache[:seq_len, :seq_len]

    def _get_combined_mask(
        self,
        attention_mask: Optional[torch.Tensor],
        seq_len: int,
        device: torch.device,
    ) -> torch.Tensor:
        """Combine the causal mask with an optional padding mask.

        Returns a float mask with two leading broadcast axes added to the
        causal mask; when ``attention_mask`` is given it is expanded the same
        way, cast to bool and AND-ed in.
        """
        combined = self._get_causal_mask(seq_len, device).unsqueeze(0).unsqueeze(0)
        if attention_mask is not None:
            keep = attention_mask.unsqueeze(1).unsqueeze(1).bool()
            combined = combined & keep
        return combined.float()

    def _print_architecture(self, config: ModelConfig) -> None:
        """Print a parameter-count and configuration summary to stdout."""
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        attention_blocks = self.block_kinds.count("attention")
        ssm_blocks = self.block_kinds.count("ssm")
        ssm_hidden_dim = self.d_latent_kv if config.ssm_hidden_dim is None else config.ssm_hidden_dim
        heavy_rule = "=" * 70
        light_rule = "-" * 70

        report = [
            f"\n{heavy_rule}",
            f"MODEL ARCHITECTURE - TAONET-HYBRID (MLA + {config.ssm_core} SSM)",
            f"{heavy_rule}",
            f"Embedding vocab: {self.vocab_size}",
            f"Output Head: {(self.d_model * self.vocab_size) / 1e6:>8.2f}M",
            f"Attention Blocks: {attention_blocks:>8} layers",
            f"SSM Blocks: {ssm_blocks:>8} layers",
            f"{light_rule}",
            f"Total Parameters: {total_params / 1e6:>8.2f}M (trainable: {trainable_params / 1e6:.2f}M)",
            f"{light_rule}",
            "Configuration:",
            f" Model dimension (d_model): {self.d_model}",
            f" KV latent dimension (d_latent_kv): {self.d_latent_kv}",
            f" Attention heads: {self.n_heads}",
            f" SSM core: {config.ssm_core}",
            f" SSM hidden dimension: {ssm_hidden_dim}",
            f" SSM mixer dimension: {config.ssm_mixer_dim or self.d_model}",
            f" SSM lanes: {config.ssm_num_lanes}",
            f" SSM lane mode: {config.ssm_lane_mode}",
            f" SSM split mix: {config.ssm_split_mix}",
            f" SSM lane combine: {config.ssm_lane_combine}",
        ]
        # NOTE(review): the rank and finite-tail rows are both treated as
        # DPLR-specific here; confirm against the upstream file.
        if config.ssm_core == "dplr":
            report.append(f" SSM DPLR rank: {config.ssm_rank}")
            report.append(f" SSM finite-tail correction: {config.ssm_finite_tail_correction}")
        report += [
            f" SSM branch RMS norm: {config.ssm_branch_rms_norm}",
            f" SSM branch clip value: {config.ssm_branch_clip_value}",
            f" Block residual RMS norm: {config.block_residual_rms_norm}",
            f" Block residual RMS cap: {config.block_residual_rms_cap}",
            f" SSM local shift enabled: {config.ssm_local_shift}",
            f" SSM gate type: {config.ssm_gate_type}",
            f" Hybrid pattern: {config.hybrid_pattern}",
            f" Hybrid SSM layers: {','.join(str(i) for i in sorted(self.ssm_layer_indices))}",
            f" Feed-forward dimension: {int(self.d_ff)}",
            f" Number of layers: {self.n_layers}",
            f" Max sequence length: {self.max_seq_length}",
            f" Dropout: {self.dropout}",
            f"{heavy_rule}\n",
        ]
        for line in report:
            print(line)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> dict:
        """Compute logits and, when ``labels`` is given, the cross-entropy loss.

        Args:
            input_ids: Token ids; unpacked as a 2-D ``(batch, seq_len)`` shape.
            attention_mask: Optional padding mask merged with the causal mask.
            labels: Optional targets; entries equal to -100 are ignored.
                Labels are used as given (no shifting happens here).

        Returns:
            ``{"logits": ..., "loss": ...}`` with ``loss`` set to ``None``
            when no labels were supplied.
        """
        _, seq_len = input_ids.shape
        combined_mask = self._get_combined_mask(attention_mask, seq_len, input_ids.device)

        hidden = self.embedding_dropout(self.token_embedding(input_ids))

        # NOTE(review): the same 4-D combined mask is handed to attention and
        # SSM blocks alike; confirm the SSM padding-mask helper accepts it.
        for block in self.blocks:
            hidden = block(hidden, attention_mask=combined_mask)

        logits = self.output_head(self.final_norm(hidden))

        if labels is None:
            loss = None
        else:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                labels.view(-1),
                reduction="mean",
                ignore_index=-100,
            )

        return {"logits": logits, "loss": loss}
code/TaoTrain/src/taoTrain/models/transformer.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Standard Transformer language model implementation."""
2
+
3
+ import math
4
+ from typing import Optional
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from taoTrain.core import BaseModel
10
+ from taoTrain.config import ModelConfig
11
+ from .registry import register_architecture
12
+
13
+
14
+ # ============================================================================
15
+ # Components
16
+ # ============================================================================
17
+
18
+
19
class PositionalEmbedding(nn.Module):
    """Sinusoidal positional embeddings.

    A ``(max_seq_length, dim)`` table is precomputed once and stored as a
    non-persistent buffer (it is not written to ``state_dict``). Even
    feature indices hold sines and odd indices hold cosines of
    ``position * 10000 ** (-i / dim)`` for ``i = 0, 2, 4, ...``.
    """

    def __init__(self, dim: int, max_seq_length: int = 2048):
        """Precompute the positional table.

        Args:
            dim: Feature dimension of the embeddings (odd values are
                supported; the last sine column then has no cosine partner).
            max_seq_length: Number of positions in the table.
        """
        super().__init__()
        self.dim = dim
        self.max_seq_length = max_seq_length

        positions = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        angles = positions * div_term  # (max_seq_length, ceil(dim / 2))

        pe = torch.zeros(max_seq_length, dim)
        pe[:, 0::2] = torch.sin(angles)
        # There are floor(dim / 2) odd columns, so for odd ``dim`` the last
        # frequency is dropped; for even ``dim`` the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : dim // 2])

        self.register_buffer("pe", pe, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional embeddings to the input.

        Args:
            x: Input tensor (batch, seq_len, hidden_dim).

        Returns:
            Input + positional embeddings.

        Raises:
            ValueError: If ``seq_len`` exceeds the number of precomputed
                positions. Without this check the addition would fail with
                an opaque broadcasting error, or broadcast silently when
                the table has a single row.
        """
        seq_len = x.shape[1]
        available = self.pe.size(0)
        if seq_len > available:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_length {available} "
                "of the positional embedding table."
            )
        return x + self.pe[:seq_len]
53
+
54
+
55
class Attention(nn.Module):
    """Multi-head causal self-attention built on scaled dot-product attention."""

    def __init__(self, config: ModelConfig):
        """Create the Q/K/V/output projections from ``config``."""
        super().__init__()
        self.hidden_dim = config.hidden_dim
        self.num_heads = config.num_heads
        self.head_dim = config.head_dim

        assert self.hidden_dim % self.num_heads == 0

        # Linear projections (all hidden_dim -> hidden_dim, with bias).
        self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim)

        self.dropout_p = config.dropout

    def _split_heads(self, projected: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
        """Reshape ``(batch, seq, hidden)`` to ``(batch, heads, seq, head_dim)``."""
        per_head = projected.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass using scaled_dot_product_attention.

        Args:
            x: Shape (batch, seq_len, hidden_dim)
            attention_mask: Shape (batch, seq_len). Accepted for interface
                compatibility but not used: only causal masking is applied.

        Returns:
            Output: Shape (batch, seq_len, hidden_dim)
        """
        batch_size, seq_len, _ = x.shape

        q = self._split_heads(self.q_proj(x), batch_size, seq_len)
        k = self._split_heads(self.k_proj(x), batch_size, seq_len)
        v = self._split_heads(self.v_proj(x), batch_size, seq_len)

        # NOTE: scaled_dot_product_attention does not accept an explicit
        # attn_mask together with is_causal=True, so only the built-in causal
        # mask is used. Padding positions are expected to be neutralised in
        # the loss (labels == -100).
        # See: https://github.com/pytorch/pytorch/issues/96099
        attended = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,  # must stay None while is_causal=True
            dropout_p=self.dropout_p if self.training else 0.0,
            is_causal=True,
            scale=None,  # default scale of 1/sqrt(head_dim)
        )  # (batch, num_heads, seq_len, head_dim)

        # Merge heads back: (batch, seq_len, num_heads, head_dim) -> hidden_dim.
        merged = attended.transpose(1, 2).contiguous()
        merged = merged.reshape(batch_size, seq_len, self.hidden_dim)

        return self.out_proj(merged)
127
+
128
+
129
class SwiGLU(nn.Module):
    """Swish-gated linear unit feed-forward layer.

    One linear layer produces both the value and the gate halves, the value
    is multiplied by ``SiLU(gate)``, dropout is applied, and a second linear
    layer maps back to the input width.
    """

    def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.0):
        """
        Initialize SwiGLU.

        Args:
            in_dim: Input (and output) dimension
            out_dim: Intermediate/hidden dimension
            dropout: Dropout rate applied to the gated activation
        """
        super().__init__()
        # fc1 emits value and gate side by side, hence twice the width.
        self.fc1 = nn.Linear(in_dim, 2 * out_dim)
        self.fc2 = nn.Linear(out_dim, in_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with SwiGLU activation.

        Args:
            x: Input tensor whose last dimension is ``in_dim``

        Returns:
            Gated activation output (same dimension as input)
        """
        # First half of fc1's output is the value, second half the gate.
        value, gate = self.fc1(x).chunk(2, dim=-1)
        # SiLU is Swish: gate * sigmoid(gate).
        gated = value * F.silu(gate)
        return self.fc2(self.dropout(gated))
170
+
171
+
172
class FeedForward(nn.Module):
    """Feed-forward network with SwiGLU activation.

    Thin wrapper that sizes a ``SwiGLU`` layer from the model configuration.
    """

    def __init__(self, config: ModelConfig):
        """Build the SwiGLU layer from ``hidden_dim``, ``intermediate_dim`` and ``dropout``."""
        super().__init__()
        swiglu_kwargs = {
            "in_dim": config.hidden_dim,
            "out_dim": config.intermediate_dim,
            "dropout": config.dropout,
        }
        self.swiglu = SwiGLU(**swiglu_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped SwiGLU layer to ``x``."""
        out = self.swiglu(x)
        return out
187
+
188
+
189
class TransformerBlock(nn.Module):
    """Single pre-norm transformer block: attention then feed-forward."""

    def __init__(self, config: ModelConfig):
        """Create the two LayerNorms, the attention layer and the FFN."""
        super().__init__()
        width = config.hidden_dim
        self.norm1 = nn.LayerNorm(width)
        self.attn = Attention(config)
        self.norm2 = nn.LayerNorm(width)
        self.ffn = FeedForward(config)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply both sub-layers, each on a LayerNorm of the stream with a residual add.

        Args:
            x: Residual stream entering the block.
            attention_mask: Forwarded to the attention layer.

        Returns:
            The residual stream after attention and feed-forward.
        """
        attn_out = self.attn(self.norm1(x), attention_mask=attention_mask)
        x = x + attn_out

        ffn_out = self.ffn(self.norm2(x))
        return x + ffn_out
213
+
214
+
215
+ # ============================================================================
216
+ # Transformer LM
217
+ # ============================================================================
218
+
219
+
220
@register_architecture("transformer")
class TransformerLM(BaseModel):
    """Standard Transformer language model.

    Token embedding + sinusoidal positions -> dropout -> ``num_layers``
    transformer blocks -> LayerNorm -> LM head whose weight is tied to the
    token embedding.
    """

    def __init__(self, config: ModelConfig):
        """Initialize Transformer LM."""
        super().__init__(config)
        width = config.hidden_dim

        # Embeddings
        self.embed_tokens = nn.Embedding(config.vocab_size, width)
        self.pos_embed = PositionalEmbedding(width, max_seq_length=config.max_seq_length)
        self.dropout = nn.Dropout(config.dropout)

        # Transformer blocks
        self.blocks = nn.ModuleList(TransformerBlock(config) for _ in range(config.num_layers))

        # Final layer norm
        self.final_norm = nn.LayerNorm(width)

        # Output projection; its weight is always tied to the input
        # embedding matrix, so both share one parameter tensor.
        self.lm_head = nn.Linear(width, config.vocab_size, bias=False)
        self.lm_head.weight = self.embed_tokens.weight

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Normal(0, ``config.init_std``) init for Linear/Embedding weights; zero biases."""
        std = self.config.init_std
        for module in self.modules():
            if isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, std=std)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=std)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> dict[str, torch.Tensor]:
        """
        Forward pass.

        Args:
            input_ids: (batch_size, seq_len)
            attention_mask: (batch_size, seq_len), forwarded to every block
            labels: (batch_size, seq_len) for loss computation; entries equal
                to -100 are ignored. Labels are used as given (no shifting
                happens here).

        Returns:
            Dict with 'logits' and 'loss' ('loss' is None without labels)
        """
        # Unpacking also rejects inputs that are not 2-D.
        batch_size, seq_len = input_ids.shape

        # Token embedding plus positional embedding, then dropout.
        hidden = self.dropout(self.pos_embed(self.embed_tokens(input_ids)))

        # Transformer blocks
        for block in self.blocks:
            hidden = block(hidden, attention_mask=attention_mask)

        # Final normalization and LM head -> (batch, seq_len, vocab_size)
        logits = self.lm_head(self.final_norm(hidden))

        if labels is None:
            loss = None
        else:
            # Flatten to (batch * seq_len, vocab_size) / (batch * seq_len,).
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                labels.view(-1),
                reduction='mean',
                ignore_index=-100,
            )

        return {
            'logits': logits,
            'loss': loss,
        }
code/TaoTrain/src/taoTrain/optimizers/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Optimizer registry and factories."""
2
+
3
+ from .registry import (
4
+ register_optimizer,
5
+ get_optimizer,
6
+ get_registered_optimizers,
7
+ )
8
+
9
+ __all__ = [
10
+ "register_optimizer",
11
+ "get_optimizer",
12
+ "get_registered_optimizers",
13
+ ]
code/TaoTrain/src/taoTrain/optimizers/adam.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Adam optimizer factory."""
2
+
3
+ import torch.optim as optim
4
+ from taoTrain.core.base import BaseModel
5
+ from taoTrain.config import TrainingConfig
6
+ from .registry import register_optimizer
7
+
8
+
9
def _separate_parameters(model: BaseModel) -> tuple[list, list]:
    """
    Split the trainable parameters into decay and no-decay groups.

    A parameter goes to the no-decay group when its qualified name contains
    the substring 'bias' or 'norm'; frozen parameters are skipped.

    Args:
        model: Model instance

    Returns:
        Tuple of (decay_params, no_decay_params)
    """
    trainable = [
        (name, param)
        for name, param in model.named_parameters()
        if param.requires_grad
    ]
    # Name-based rule: biases and anything under a "*norm*" attribute.
    exempt = [('bias' in name or 'norm' in name) for name, _ in trainable]

    decay_params = [param for (_, param), skip in zip(trainable, exempt) if not skip]
    no_decay_params = [param for (_, param), skip in zip(trainable, exempt) if skip]
    return decay_params, no_decay_params
33
+
34
+
35
@register_optimizer("adam")
def create_adam(model: BaseModel, config: TrainingConfig) -> optim.Adam:
    """
    Create Adam optimizer with weight decay applied selectively.

    Parameters returned in the no-decay group by ``_separate_parameters``
    get ``weight_decay=0.0``; the rest use ``config.optimizer.weight_decay``.

    Args:
        model: Model instance
        config: TrainingConfig (its ``optimizer`` section supplies
            learning_rate, betas, eps and weight_decay)

    Returns:
        Adam optimizer instance
    """
    settings = config.optimizer
    with_decay, without_decay = _separate_parameters(model)

    return optim.Adam(
        [
            {"params": with_decay, "weight_decay": settings.weight_decay},
            {"params": without_decay, "weight_decay": 0.0},
        ],
        lr=settings.learning_rate,
        betas=settings.betas,
        eps=settings.eps,
    )
code/TaoTrain/src/taoTrain/optimizers/adamw.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """AdamW optimizer factory."""
2
+
3
+ import torch.optim as optim
4
+ from taoTrain.core.base import BaseModel
5
+ from taoTrain.config import TrainingConfig
6
+ from .registry import register_optimizer
7
+
8
+
9
+ def _separate_parameters(model: BaseModel) -> tuple[list, list]:
10
+ """
11
+ Separate model parameters into decay and no-decay groups.
12
+
13
+ Args:
14
+ model: Model instance
15
+
16
+ Returns:
17
+ Tuple of (decay_params, no_decay_params)
18
+ """
19
+ decay_params = []
20
+ no_decay_params = []
21
+
22
+ for name, param in model.named_parameters():
23
+ if not param.requires_grad:
24
+ continue
25
+
26
+ # Apply weight decay to all params except biases and layer norms
27
+ if 'bias' in name or 'norm' in name:
28
+ no_decay_params.append(param)
29
+ else:
30
+ decay_params.append(param)
31
+
32
+ return decay_params, no_decay_params
33
+
34
+
35
@register_optimizer("adamw")
def create_adamw(model: BaseModel, config: TrainingConfig) -> optim.AdamW:
    """
    Build an AdamW optimizer whose weight decay only touches the decay group.

    Args:
        model: Model instance whose parameters are optimized
        config: TrainingConfig; ``config.optimizer`` supplies weight_decay,
            learning_rate, betas and eps

    Returns:
        AdamW optimizer instance with two parameter groups
    """
    settings = config.optimizer

    # Split once; the no-decay half is pinned to a decay of exactly 0.0
    decayed, undecayed = _separate_parameters(model)
    groups = [
        {"params": decayed, "weight_decay": settings.weight_decay},
        {"params": undecayed, "weight_decay": 0.0},
    ]

    hyperparams = {
        "lr": settings.learning_rate,
        "betas": settings.betas,
        "eps": settings.eps,
    }
    return optim.AdamW(groups, **hyperparams)
code/TaoTrain/src/taoTrain/optimizers/hybrid_muon_adamw.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hybrid Muon + AdamW Optimizer for TaoNet models.
3
+
4
+ Combines:
5
+ - Muon: Specialized optimization for 2D weight matrices (linear layers)
6
+ Leverages orthogonal/SVD-based updates for better convergence on matrix weights
7
+ - AdamW: Adaptive moment estimation for 1D parameters (biases, norms, embeddings)
8
+
9
+ Key Design:
10
+ - 2D weight matrices use Muon optimizer with separate LRs for different layer types
11
+ - 1D parameters use AdamW with lower learning rate
12
+ - Inherits from torch.optim.Optimizer for LR scheduler compatibility
13
+ """
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ from typing import Dict, List, Any
18
+
19
+ from .registry import register_optimizer
20
+ from taoTrain.config import TrainingConfig
21
+ from taoTrain.core.base import BaseModel
22
+
23
+
24
+ def _get_param_dimensionality(param: torch.Tensor) -> str:
25
+ """
26
+ Determine if a parameter is 2D (weight matrix) or 1D (bias/embedding/norm).
27
+
28
+ Returns:
29
+ 'weight_2d': Parameter has 2+ dimensions (for Muon)
30
+ '1d_other': Parameter is 1D (for AdamW)
31
+ """
32
+ if param.dim() >= 2:
33
+ return 'weight_2d'
34
+ return '1d_other'
35
+
36
+
37
class HybridMuonAdamW(torch.optim.Optimizer):
    """
    Composite optimizer that drives a Muon and an AdamW optimizer together.

    The Muon instance receives the parameter groups passed as
    ``muon_params_groups``; the AdamW instance receives the single group
    passed as ``adamw_params_group``.

    Inherits from torch.optim.Optimizer so LR schedulers accept it. The
    ``param_groups`` attribute is the concatenation of both inner optimizers'
    group lists (the very same dict objects), so a scheduler that writes
    ``lr`` into ``param_groups`` updates the inner optimizers directly.

    Public interface:
    - step(closure=None): evaluates the closure at most once, then steps both
    - zero_grad(set_to_none=False): delegates to both
    - state_dict(): returns {'muon': ..., 'adamw': ...}
    - load_state_dict(state): restores combined or legacy AdamW-only state
    """

    def __init__(
        self,
        muon_params_groups: List[Dict[str, Any]],
        adamw_params_group: Dict[str, Any],
        muon_kwargs: Dict[str, Any],
        adamw_kwargs: Dict[str, Any]
    ):
        """
        Initialize HybridMuonAdamW optimizer.

        Args:
            muon_params_groups: List of param groups for Muon optimizer
                Each group should have 'params' and 'lr' keys
            adamw_params_group: Dict param group for AdamW optimizer
                Should have 'params' and 'lr' keys
            muon_kwargs: Additional kwargs for torch.optim.Muon init
            adamw_kwargs: Additional kwargs for torch.optim.AdamW init

        Raises:
            RuntimeError: If the installed PyTorch has no torch.optim.Muon
        """
        # Look the class up explicitly instead of catching AttributeError
        # around the constructor call, so an AttributeError raised inside
        # Muon itself is not misreported as "Muon not available".
        muon_cls = getattr(torch.optim, "Muon", None)
        if muon_cls is None:
            raise RuntimeError(
                "torch.optim.Muon not available. "
                "Muon optimizer requires PyTorch 2.9+. "
                "Please upgrade PyTorch: pip install --upgrade torch"
            )

        # Dummy params list for parent Optimizer class (required for registration)
        # Real params are managed by internal optimizers
        dummy_param = torch.nn.Parameter(torch.zeros(1))
        super().__init__([dummy_param], {})

        # Create internal optimizers with their parameter groups
        self.muon = muon_cls(muon_params_groups, **muon_kwargs)
        self.adamw = torch.optim.AdamW([adamw_params_group], **adamw_kwargs)

        # LR schedulers will update these merged groups
        self._sync_param_groups()

    def _sync_param_groups(self) -> None:
        """Re-point ``param_groups`` at the inner optimizers' current group dicts."""
        self.param_groups = self.muon.param_groups + self.adamw.param_groups

    def step(self, closure=None):
        """
        Execute one optimization step for both Muon and AdamW.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is invoked exactly once, with gradients enabled.

        Returns:
            The closure's return value, or None when no closure is given
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # The closure is deliberately not forwarded: each inner optimizer
        # would call it again, repeating the forward/backward pass.
        self.muon.step()
        self.adamw.step()

        return loss

    def zero_grad(self, set_to_none: bool = False):
        """Zero gradients in both optimizers."""
        self.muon.zero_grad(set_to_none=set_to_none)
        self.adamw.zero_grad(set_to_none=set_to_none)

    def state_dict(self) -> Dict[str, Any]:
        """Return combined state dict for both optimizers."""
        return {
            'muon': self.muon.state_dict(),
            'adamw': self.adamw.state_dict(),
        }

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """
        Restore state from combined state dict.

        Supports both new format (composite with Muon+AdamW) and legacy format
        (AdamW-only checkpoints) for backward compatibility. A legacy
        checkpoint that fails to load, or a dict in neither format, is
        reported on stdout and otherwise ignored.

        Raises:
            ValueError: If ``state_dict`` is not a dict
        """
        if not isinstance(state_dict, dict):
            raise ValueError(f"Expected dict, got {type(state_dict)}")

        if 'muon' in state_dict and 'adamw' in state_dict:
            # New format: composite optimizer with both Muon and AdamW
            self.muon.load_state_dict(state_dict['muon'])
            self.adamw.load_state_dict(state_dict['adamw'])
        elif 'state' in state_dict or 'param_groups' in state_dict:
            # Legacy format: old AdamW-only checkpoint
            # Load into AdamW optimizer only, Muon starts fresh
            try:
                self.adamw.load_state_dict(state_dict)
                print(" ⚠️ Loaded legacy AdamW-only checkpoint (Muon state initialized fresh)")
            except Exception as e:
                print(f" ⚠️ Failed to load optimizer state: {e}")
        else:
            print(f" ⚠️ Unknown checkpoint format")

        # Loading replaces the inner optimizers' param_group dicts with new
        # objects; without re-merging, schedulers would keep writing to the
        # old dicts and LR changes would no longer reach the optimizers.
        self._sync_param_groups()
141
+
142
+
143
@register_optimizer("hybrid_muon_adamw")
def create_hybrid_muon_adamw(model: BaseModel, training_config: TrainingConfig) -> HybridMuonAdamW:
    """
    Factory function to create HybridMuonAdamW optimizer from model and config.

    Parameter grouping strategy:
    - Muon group: weights with 2+ dimensions that are owned by an nn.Linear
      module → learning_rate
    - AdamW group: every other trainable parameter (biases, and any tensor,
      of any rank, not owned by an nn.Linear) → adamw_lr

    A parameter object shared by several modules (weight tying) is assigned
    once, to the group chosen for the first module that exposes it in
    ``model.named_modules()`` order.

    Args:
        model: PyTorch model to optimize
        training_config: TrainingConfig with optimizer hyperparameters:
            - learning_rate: LR for 2D Linear weights (Muon)
            - adamw_lr: LR for the remaining parameters (AdamW)
            - weight_decay: weight decay passed to AdamW
            - betas: (beta1, beta2) for AdamW
            - eps: epsilon for AdamW

    Returns:
        HybridMuonAdamW optimizer instance

    Raises:
        AssertionError: If the collected parameters do not add up to the
            model's trainable parameter count
    """
    optimizer_config = training_config.optimizer

    # Separate parameters by dimensionality
    linear_2d_weights = []
    params_1d = []
    # Identity set so a tied parameter is never handed to both optimizers
    seen_param_ids = set()

    # Classify all parameters
    for _module_name, module in model.named_modules():
        for _param_name, param in module.named_parameters(recurse=False):
            if not param.requires_grad or id(param) in seen_param_ids:
                continue
            seen_param_ids.add(id(param))

            param_dim = _get_param_dimensionality(param)

            if param_dim == 'weight_2d' and isinstance(module, nn.Linear):
                # 2D Linear weights → Muon
                linear_2d_weights.append(param)
            else:
                # Everything else → AdamW (1D params + other 2D tensors)
                params_1d.append(param)

    # Verify we got all parameters. Raised explicitly rather than asserted so
    # the check still runs when Python is started with -O.
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    muon_params = sum(p.numel() for p in linear_2d_weights)
    adamw_params = sum(p.numel() for p in params_1d)
    if total_params != muon_params + adamw_params:
        raise AssertionError(
            f"Parameter accounting error: {total_params} != {muon_params} + {adamw_params}"
        )

    # Prepare Muon parameter groups (one group with single LR for all Linear 2D weights)
    muon_params_groups = [
        {
            'params': linear_2d_weights,
            'lr': optimizer_config.learning_rate,  # Use main learning_rate for Muon
        }
    ]

    # Prepare AdamW parameter group (uses its own learning rate)
    adamw_params_group = {
        'params': params_1d,
        'lr': optimizer_config.adamw_lr,  # Use adamw_lr for the AdamW group
        'weight_decay': optimizer_config.weight_decay,
    }

    # Extract Muon kwargs (settings common to all Muon param groups)
    # NOTE(review): the configured weight_decay is not forwarded here, so Muon
    # runs with its own default decay while the summary below prints the
    # configured value — confirm whether that is intended.
    muon_kwargs = {
        'lr': optimizer_config.learning_rate,  # Same value as the group-level lr above
    }

    # Extract AdamW kwargs
    adamw_kwargs = {
        'betas': optimizer_config.betas,
        'eps': optimizer_config.eps,
        'weight_decay': optimizer_config.weight_decay,
    }

    # Print optimizer setup details
    print(f"\n{'='*70}")
    print("OPTIMIZER SETUP - HYBRID MUON + ADAMW")
    print(f"{'='*70}")
    print("\n[MUON - 2D Weight Matrices (Orthogonal Optimization)]")
    print(f"Linear 2D weights: {muon_params/1e6:>8.2f}M")
    print(f" Learning Rate: {optimizer_config.learning_rate}")
    print(f"\n[ADAMW - 1D Parameters (Adaptive Moments)]")
    print(f"Biases, embeddings, norms: {adamw_params/1e6:>8.2f}M")
    print(f" Learning Rate: {optimizer_config.adamw_lr}")
    print(f"{'─'*70}")
    print(f"Total (Muon): {muon_params/1e6:>8.2f}M")
    print(f"Total (AdamW): {adamw_params/1e6:>8.2f}M")
    print(f"Total (All): {total_params/1e6:>8.2f}M")
    print(f"{'─'*70}")
    print(f"Hyperparameters:")
    print(f" Weight Decay: {optimizer_config.weight_decay}")
    print(f" Betas (AdamW): {optimizer_config.betas}")
    print(f" Epsilon: {optimizer_config.eps}")
    print(f"{'='*70}\n")

    # Create and return optimizer
    return HybridMuonAdamW(muon_params_groups, adamw_params_group, muon_kwargs, adamw_kwargs)
code/TaoTrain/src/taoTrain/optimizers/registry.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Optimizer registry and factory for instantiating optimizers."""
2
+
3
+ from typing import Dict, Type, Callable, Any
4
+ import torch.optim as optim
5
+ from taoTrain.core.base import BaseModel
6
+ from taoTrain.config import TrainingConfig, OptimizerEnum
7
+
8
+
9
+ # Global registry for optimizers
10
+ _OPTIMIZER_REGISTRY: Dict[str, Callable] = {}
11
+
12
+
13
def register_optimizer(name: str):
    """
    Build a decorator that files an optimizer factory under ``name``.

    Args:
        name: Name of the optimizer (e.g., 'adamw', 'adam', 'sgd')

    Returns:
        Decorator that stores the function in the registry and returns it
        unchanged. The decorator raises ValueError when ``name`` is taken.
    """
    def _bind(factory: Callable) -> Callable:
        name_is_free = name not in _OPTIMIZER_REGISTRY
        if name_is_free:
            _OPTIMIZER_REGISTRY[name] = factory
            return factory
        raise ValueError(f"Optimizer '{name}' is already registered")

    return _bind
26
+
27
+
28
def get_registered_optimizers() -> Dict[str, Callable]:
    """Return a shallow copy of the name → factory mapping of registered optimizers."""
    snapshot = dict(_OPTIMIZER_REGISTRY)
    return snapshot
31
+
32
+
33
def get_optimizer(
    model: BaseModel,
    config: TrainingConfig,
) -> optim.Optimizer:
    """
    Look up the configured optimizer factory and call it.

    Args:
        model: Model to optimize
        config: TrainingConfig; ``config.optimizer.optimizer_type`` selects
            the factory and may be a plain string or an object with ``.value``

    Returns:
        Optimizer instance produced by the registered factory

    Raises:
        ValueError: If optimizer type is not registered
    """
    requested = config.optimizer.optimizer_type
    # Strings are used as-is; anything else is treated as an enum member
    optimizer_name = requested if isinstance(requested, str) else requested.value

    if optimizer_name in _OPTIMIZER_REGISTRY:
        return _OPTIMIZER_REGISTRY[optimizer_name](model, config)

    raise ValueError(
        f"Unknown optimizer: {optimizer_name}. "
        f"Available: {list(_OPTIMIZER_REGISTRY.keys())}"
    )
65
+
66
+
67
def register_builtin_optimizers():
    """Import each built-in optimizer module so its decorator registers the factory."""
    # Imported inside the function rather than at module top to avoid
    # circular imports; the import itself triggers decorator registration.
    from . import adamw, adam, sgd, hybrid_muon_adamw  # noqa: F401


# Populate the registry as a side effect of importing this module
register_builtin_optimizers()
code/TaoTrain/src/taoTrain/optimizers/sgd.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SGD optimizer factory."""
2
+
3
+ import torch.optim as optim
4
+ from taoTrain.core.base import BaseModel
5
+ from taoTrain.config import TrainingConfig
6
+ from .registry import register_optimizer
7
+
8
+
9
+ def _separate_parameters(model: BaseModel) -> tuple[list, list]:
10
+ """
11
+ Separate model parameters into decay and no-decay groups.
12
+
13
+ Args:
14
+ model: Model instance
15
+
16
+ Returns:
17
+ Tuple of (decay_params, no_decay_params)
18
+ """
19
+ decay_params = []
20
+ no_decay_params = []
21
+
22
+ for name, param in model.named_parameters():
23
+ if not param.requires_grad:
24
+ continue
25
+
26
+ # Apply weight decay to all params except biases and layer norms
27
+ if 'bias' in name or 'norm' in name:
28
+ no_decay_params.append(param)
29
+ else:
30
+ decay_params.append(param)
31
+
32
+ return decay_params, no_decay_params
33
+
34
+
35
@register_optimizer("sgd")
def create_sgd(model: BaseModel, config: TrainingConfig) -> optim.SGD:
    """
    Build an SGD optimizer whose weight decay only touches the decay group.

    Args:
        model: Model instance whose parameters are optimized
        config: TrainingConfig; ``config.optimizer`` supplies weight_decay,
            learning_rate and betas

    Returns:
        SGD optimizer instance with two parameter groups
    """
    sgd_cfg = config.optimizer

    regularized, unregularized = _separate_parameters(model)
    groups = [
        {"params": regularized, "weight_decay": sgd_cfg.weight_decay},
        {"params": unregularized, "weight_decay": 0.0},
    ]

    # SGD has no betas of its own; the first beta doubles as the momentum
    return optim.SGD(
        groups,
        lr=sgd_cfg.learning_rate,
        momentum=sgd_cfg.betas[0],
    )
code/TaoTrain/src/taoTrain/schedulers/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Learning rate scheduler registry and factories."""
2
+
3
+ from .registry import (
4
+ register_scheduler,
5
+ get_scheduler,
6
+ get_registered_schedulers,
7
+ )
8
+
9
+ __all__ = [
10
+ "register_scheduler",
11
+ "get_scheduler",
12
+ "get_registered_schedulers",
13
+ ]
code/TaoTrain/src/taoTrain/schedulers/constant.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Constant learning rate scheduler with optional warmup."""
2
+
3
+ import torch.optim as optim
4
+ from torch.optim.lr_scheduler import LambdaLR
5
+ from taoTrain.config import TrainingConfig
6
+ from .registry import register_scheduler
7
+
8
+
9
@register_scheduler("constant")
def create_constant(
    optimizer: optim.Optimizer,
    config: TrainingConfig,
    num_training_steps: int,
) -> LambdaLR:
    """
    Build a LambdaLR that ramps linearly from 0 to 1.0, then stays at 1.0.

    The ramp length is ``config.scheduler.warmup_steps`` when that is
    positive, otherwise ``int(num_training_steps * warmup_ratio)``.

    Args:
        optimizer: Optimizer instance
        config: TrainingConfig with scheduler configuration
        num_training_steps: Total number of training steps

    Returns:
        LambdaLR scheduler instance
    """
    sched_cfg = config.scheduler

    warmup_steps = (
        sched_cfg.warmup_steps
        if sched_cfg.warmup_steps > 0
        else int(num_training_steps * sched_cfg.warmup_ratio)
    )

    def lr_lambda(step):
        """Multiplier applied to the base LR at ``step``."""
        in_warmup = step < warmup_steps
        # max(1, ...) keeps the division defined when there is no warmup
        return float(step) / float(max(1, warmup_steps)) if in_warmup else 1.0

    return LambdaLR(optimizer, lr_lambda, last_epoch=sched_cfg.last_epoch)
code/TaoTrain/src/taoTrain/schedulers/cosine_warmup.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Cosine annealing with warmup learning rate scheduler."""
2
+
3
+ import math
4
+ import torch.optim as optim
5
+ from torch.optim.lr_scheduler import LambdaLR
6
+ from taoTrain.config import TrainingConfig
7
+ from .registry import register_scheduler
8
+
9
+
10
@register_scheduler("cosineWarmup")
def create_cosine_warmup(
    optimizer: optim.Optimizer,
    config: TrainingConfig,
    num_training_steps: int,
) -> LambdaLR:
    """
    Create a cosine annealing scheduler with optional linear warmup, steady phase, and decay.

    Three-phase schedule:
    1. Linear warmup: 0 → 1.0 (warmup_steps)
    2. Steady phase: 1.0 (plateau at peak LR)
    3. Cosine decay: 1.0 → min_lr_ratio

    Once the decay phase is complete the multiplier stays at min_lr_ratio,
    including for steps beyond ``num_training_steps``.

    Args:
        optimizer: Optimizer instance
        config: TrainingConfig with scheduler configuration:
            - warmup_steps: linear warmup duration (overrides warmup_ratio if > 0)
            - warmup_ratio: warmup as fraction of total steps
            - steady_ratio: steady phase as fraction of total steps
            - min_lr_ratio: minimum LR at end as fraction of peak
            - num_cycles: printed in the summary only
        num_training_steps: Total number of training steps

    Returns:
        LambdaLR scheduler instance
    """
    scheduler_config = config.scheduler

    # Determine warmup steps
    if scheduler_config.warmup_steps > 0:
        warmup_steps = scheduler_config.warmup_steps
    else:
        warmup_steps = int(num_training_steps * scheduler_config.warmup_ratio)

    # Determine steady phase steps
    steady_steps = int(num_training_steps * scheduler_config.steady_ratio)

    # Remaining steps for cosine decay (can be <= 0 when warmup + steady
    # already cover the whole run)
    decay_steps = num_training_steps - warmup_steps - steady_steps

    min_lr_ratio = scheduler_config.min_lr_ratio
    # NOTE(review): num_cycles is read and reported, but the schedule below
    # always performs a single half-cosine — confirm whether it should be used.
    num_cycles = scheduler_config.num_cycles

    print(f"✓ CosineWarmup scheduler: warmup={warmup_steps}, steady={steady_steps}, decay={decay_steps} (total={num_training_steps})")
    print(f" min_lr_ratio={min_lr_ratio}, num_cycles={num_cycles}")

    decay_start = warmup_steps + steady_steps

    def lr_lambda(step):
        """Three-phase LR schedule: warmup → steady → cosine decay."""
        if step < warmup_steps:
            # Phase 1: Linear warmup from 0 to 1.0
            return float(step) / float(max(1, warmup_steps))

        if step < decay_start:
            # Phase 2: Steady at peak LR (1.0)
            return 1.0

        # Phase 3: Cosine decay from 1.0 to min_lr_ratio
        progress = float(step - decay_start) / float(max(1, decay_steps))
        # Clamp so the cosine cannot wrap around and raise the LR again when
        # training runs past the planned number of steps.
        progress = min(1.0, progress)

        # Cosine annealing: 0.5 * (1 + cos(π * progress))
        cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))

        # Scale to reach min_lr_ratio at the end
        return cosine_decay * (1.0 - min_lr_ratio) + min_lr_ratio

    return LambdaLR(optimizer, lr_lambda, last_epoch=scheduler_config.last_epoch)
code/TaoTrain/src/taoTrain/schedulers/linear_warmup.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Linear warmup learning rate scheduler."""
2
+
3
+ import torch.optim as optim
4
+ from torch.optim.lr_scheduler import LambdaLR
5
+ from taoTrain.config import TrainingConfig
6
+ from .registry import register_scheduler
7
+
8
+
9
@register_scheduler("linearWarmup")
def create_linear_warmup(
    optimizer: optim.Optimizer,
    config: TrainingConfig,
    num_training_steps: int,
) -> LambdaLR:
    """
    Build a LambdaLR with a linear ramp from 0 to the peak LR, flat afterwards.

    Args:
        optimizer: Optimizer instance
        config: TrainingConfig with scheduler configuration
        num_training_steps: Total number of training steps; only used to
            derive the warmup length when ``warmup_steps`` is not positive

    Returns:
        LambdaLR scheduler instance
    """
    cfg = config.scheduler

    # An explicit positive step count wins over the ratio-based length
    warmup_steps = int(num_training_steps * cfg.warmup_ratio)
    if cfg.warmup_steps > 0:
        warmup_steps = cfg.warmup_steps

    def lr_lambda(step):
        """Fraction of the peak LR to use at ``step``."""
        if step >= warmup_steps:
            return 1.0
        return float(step) / float(max(1, warmup_steps))

    return LambdaLR(optimizer, lr_lambda, last_epoch=cfg.last_epoch)
code/TaoTrain/src/taoTrain/schedulers/registry.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Scheduler registry and factory for instantiating learning rate schedulers."""
2
+
3
+ from typing import Dict, Callable, Optional
4
+ import torch.optim as optim
5
+ from torch.optim.lr_scheduler import LambdaLR
6
+ from taoTrain.config import TrainingConfig, SchedulerEnum
7
+
8
+
9
+ # Global registry for schedulers
10
+ _SCHEDULER_REGISTRY: Dict[str, Callable] = {}
11
+
12
+
13
def register_scheduler(name: str):
    """
    Build a decorator that files a scheduler factory under ``name``.

    Args:
        name: Name of the scheduler (e.g., 'linearWarmup', 'cosineWarmup', 'constant')

    Returns:
        Decorator that stores the function in the registry and returns it
        unchanged. The decorator raises ValueError when ``name`` is taken.
    """
    def _bind(factory: Callable) -> Callable:
        name_is_free = name not in _SCHEDULER_REGISTRY
        if name_is_free:
            _SCHEDULER_REGISTRY[name] = factory
            return factory
        raise ValueError(f"Scheduler '{name}' is already registered")

    return _bind
26
+
27
+
28
def get_registered_schedulers() -> Dict[str, Callable]:
    """Return a shallow copy of the name → factory mapping of registered schedulers."""
    snapshot = dict(_SCHEDULER_REGISTRY)
    return snapshot
31
+
32
+
33
def get_scheduler(
    optimizer: optim.Optimizer,
    config: TrainingConfig,
    num_training_steps: int,
) -> LambdaLR:
    """
    Look up the configured scheduler factory and call it.

    Args:
        optimizer: Optimizer to schedule learning rate for
        config: TrainingConfig; ``config.scheduler.scheduler_type`` selects
            the factory and may be a plain string or an object with ``.value``
        num_training_steps: Total number of training steps

    Returns:
        Learning rate scheduler instance produced by the registered factory

    Raises:
        ValueError: If scheduler type is not registered
    """
    requested = config.scheduler.scheduler_type
    # Strings are used as-is; anything else is treated as an enum member
    scheduler_name = requested if isinstance(requested, str) else requested.value

    if scheduler_name in _SCHEDULER_REGISTRY:
        return _SCHEDULER_REGISTRY[scheduler_name](optimizer, config, num_training_steps)

    raise ValueError(
        f"Unknown scheduler: {scheduler_name}. "
        f"Available: {list(_SCHEDULER_REGISTRY.keys())}"
    )
67
+
68
+
69
def register_builtin_schedulers():
    """Import each built-in scheduler module so its decorator registers the factory."""
    # Imported inside the function rather than at module top to avoid
    # circular imports; the import itself triggers decorator registration.
    from . import linear_warmup, cosine_warmup, constant  # noqa: F401


# Populate the registry as a side effect of importing this module
register_builtin_schedulers()