Text Generation
Transformers
PyTorch
English
taonet_mini_t2
taonet
taotern
ssm
state-space-model
dplr
custom_code
experimental
Instructions to use TaoTern/TaoNet-mini-T2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use TaoTern/TaoNet-mini-T2 with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("text-generation", model="TaoTern/TaoNet-mini-T2", trust_remote_code=True)# Load model directly from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained("TaoTern/TaoNet-mini-T2", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use TaoTern/TaoNet-mini-T2 with vLLM:
Install from pip and serve model
# Install vLLM from pip: pip install vllm # Start the vLLM server: vllm serve "TaoTern/TaoNet-mini-T2" # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:8000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "TaoTern/TaoNet-mini-T2", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'Use Docker
docker model run hf.co/TaoTern/TaoNet-mini-T2
- SGLang
How to use TaoTern/TaoNet-mini-T2 with SGLang:
Install from pip and serve model
# Install SGLang from pip: pip install sglang # Start the SGLang server: python3 -m sglang.launch_server \ --model-path "TaoTern/TaoNet-mini-T2" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "TaoTern/TaoNet-mini-T2", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'Use Docker images
docker run --gpus all \ --shm-size 32g \ -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=<secret>" \ --ipc=host \ lmsysorg/sglang:latest \ python3 -m sglang.launch_server \ --model-path "TaoTern/TaoNet-mini-T2" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "TaoTern/TaoNet-mini-T2", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }' - Docker Model Runner
How to use TaoTern/TaoNet-mini-T2 with Docker Model Runner:
docker model run hf.co/TaoTern/TaoNet-mini-T2
Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +2 -0
- code/TaoTrain/src/taoTrain.egg-info/PKG-INFO +451 -0
- code/TaoTrain/src/taoTrain.egg-info/SOURCES.txt +65 -0
- code/TaoTrain/src/taoTrain.egg-info/requires.txt +20 -0
- code/TaoTrain/src/taoTrain.egg-info/top_level.txt +1 -0
- code/TaoTrain/src/taoTrain/benchmarks/__init__.py +5 -0
- code/TaoTrain/src/taoTrain/benchmarks/runner.py +221 -0
- code/TaoTrain/src/taoTrain/checkpointing/__init__.py +5 -0
- code/TaoTrain/src/taoTrain/checkpointing/checkpoint.py +194 -0
- code/TaoTrain/src/taoTrain/core/__init__.py +5 -0
- code/TaoTrain/src/taoTrain/core/base.py +271 -0
- code/TaoTrain/src/taoTrain/data/__init__.py +56 -0
- code/TaoTrain/src/taoTrain/data/async_loader.py +204 -0
- code/TaoTrain/src/taoTrain/data/chunk_manager.py +452 -0
- code/TaoTrain/src/taoTrain/data/factory.py +108 -0
- code/TaoTrain/src/taoTrain/data/hf_base.py +82 -0
- code/TaoTrain/src/taoTrain/data/hf_pretrain.py +78 -0
- code/TaoTrain/src/taoTrain/data/hf_rl.py +73 -0
- code/TaoTrain/src/taoTrain/data/hf_sft.py +81 -0
- code/TaoTrain/src/taoTrain/data/jsonl_base.py +220 -0
- code/TaoTrain/src/taoTrain/data/loaders.py +85 -0
- code/TaoTrain/src/taoTrain/data/pretrain_jsonl.py +65 -0
- code/TaoTrain/src/taoTrain/data/rl_jsonl.py +58 -0
- code/TaoTrain/src/taoTrain/data/sft_jsonl.py +156 -0
- code/TaoTrain/src/taoTrain/data/sft_utils.py +161 -0
- code/TaoTrain/src/taoTrain/data/tokenization_queue.py +410 -0
- code/TaoTrain/src/taoTrain/data/tokenizer.py +118 -0
- code/TaoTrain/src/taoTrain/inference/__init__.py +5 -0
- code/TaoTrain/src/taoTrain/inference/inferencer.py +301 -0
- code/TaoTrain/src/taoTrain/inference/tui.py +161 -0
- code/TaoTrain/src/taoTrain/logging/__init__.py +5 -0
- code/TaoTrain/src/taoTrain/logging/aim_logger.py +153 -0
- code/TaoTrain/src/taoTrain/models/__init__.py +5 -0
- code/TaoTrain/src/taoTrain/models/embeddings.py +51 -0
- code/TaoTrain/src/taoTrain/models/mla_components.py +370 -0
- code/TaoTrain/src/taoTrain/models/registry.py +73 -0
- code/TaoTrain/src/taoTrain/models/taonet.py +248 -0
- code/TaoTrain/src/taoTrain/models/taonet_ssm.py +654 -0
- code/TaoTrain/src/taoTrain/models/transformer.py +315 -0
- code/TaoTrain/src/taoTrain/optimizers/__init__.py +13 -0
- code/TaoTrain/src/taoTrain/optimizers/adam.py +64 -0
- code/TaoTrain/src/taoTrain/optimizers/adamw.py +64 -0
- code/TaoTrain/src/taoTrain/optimizers/hybrid_muon_adamw.py +243 -0
- code/TaoTrain/src/taoTrain/optimizers/registry.py +77 -0
- code/TaoTrain/src/taoTrain/optimizers/sgd.py +63 -0
- code/TaoTrain/src/taoTrain/schedulers/__init__.py +13 -0
- code/TaoTrain/src/taoTrain/schedulers/constant.py +44 -0
- code/TaoTrain/src/taoTrain/schedulers/cosine_warmup.py +77 -0
- code/TaoTrain/src/taoTrain/schedulers/linear_warmup.py +43 -0
- code/TaoTrain/src/taoTrain/schedulers/registry.py +78 -0
.gitattributes
CHANGED
|
@@ -3,3 +3,5 @@
|
|
| 3 |
*.vocab filter=lfs diff=lfs merge=lfs -text
|
| 4 |
*.csv filter=lfs diff=lfs merge=lfs -text
|
| 5 |
|
|
|
|
|
|
|
|
|
| 3 |
*.vocab filter=lfs diff=lfs merge=lfs -text
|
| 4 |
*.csv filter=lfs diff=lfs merge=lfs -text
|
| 5 |
|
| 6 |
+
code/Taotern_SSM/Gamma[[:space:]]Distributed[[:space:]]Ternary[[:space:]]HiPPO.pdf filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
code/Taotern_LLM_Experiments/docs/Taotern_Documentation_AI_Architecture.zip filter=lfs diff=lfs merge=lfs -text
|
code/TaoTrain/src/taoTrain.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,451 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: taoTrain
|
| 3 |
+
Version: 0.1.0
|
| 4 |
+
Summary: Clean, modular PyTorch LLM training framework with pluggable architectures, AimStack logging, and TUI inference
|
| 5 |
+
Author-email: Felix <felix@example.com>
|
| 6 |
+
License: MIT
|
| 7 |
+
Requires-Python: >=3.10
|
| 8 |
+
Description-Content-Type: text/markdown
|
| 9 |
+
Requires-Dist: torch>=2.0.0
|
| 10 |
+
Requires-Dist: transformers>=4.30.0
|
| 11 |
+
Requires-Dist: datasets>=2.10.0
|
| 12 |
+
Requires-Dist: pydantic>=2.0.0
|
| 13 |
+
Requires-Dist: pydantic-settings>=2.0.0
|
| 14 |
+
Requires-Dist: aim>=3.15.0
|
| 15 |
+
Requires-Dist: click>=8.1.0
|
| 16 |
+
Requires-Dist: rich>=13.0.0
|
| 17 |
+
Requires-Dist: textual>=0.30.0
|
| 18 |
+
Requires-Dist: numpy>=1.24.0
|
| 19 |
+
Requires-Dist: tqdm>=4.65.0
|
| 20 |
+
Requires-Dist: sentencepiece>=0.1.99
|
| 21 |
+
Provides-Extra: dev
|
| 22 |
+
Requires-Dist: pytest>=7.4.0; extra == "dev"
|
| 23 |
+
Requires-Dist: pytest-cov>=4.1.0; extra == "dev"
|
| 24 |
+
Requires-Dist: pytest-xdist>=3.3.0; extra == "dev"
|
| 25 |
+
Requires-Dist: black>=23.7.0; extra == "dev"
|
| 26 |
+
Requires-Dist: ruff>=0.0.280; extra == "dev"
|
| 27 |
+
Requires-Dist: typing-extensions>=4.7.0; extra == "dev"
|
| 28 |
+
|
| 29 |
+
# TaoTrain: Production-Grade LLM Training Framework
|
| 30 |
+
|
| 31 |
+
**TaoTrain** is a sophisticated PyTorch framework for training large language models at every scale—from experimental pretraining through supervised fine-tuning to reinforcement learning. Unlike fragmented training scripts or heavyweight frameworks, TaoTrain unifies the **entire training pipeline** in a clean, modular codebase that appeals to both ML engineers and software engineers.
|
| 32 |
+
|
| 33 |
+
## Current Taotern Work
|
| 34 |
+
|
| 35 |
+
TaoTrain now includes the Taotern comparison architectures used by the current SSM LLM work:
|
| 36 |
+
|
| 37 |
+
- `taonet`: the attention/MLA baseline.
|
| 38 |
+
- `taonet_ssm`: the TaoNet shell with the attention mixer replaced by the Gamma Space Model DPLR SSM.
|
| 39 |
+
- `taonet_hybrid`: an alternating attention/SSM TaoNet used for the current best 200M-class candidate.
|
| 40 |
+
|
| 41 |
+
The current selected deployment-oriented run is `hybrid_ssm_first_199m`, a `199,480,928` parameter model with 16 layers: SSM layers at `0,2,4,6,8,10,12,14` and attention layers at `1,3,5,7,9,11,13,15`. It uses the DPLR SSM core with split two-lane mixing, channel gates, per-channel local shift, and the faster convolution path for long-sequence training.
|
| 42 |
+
|
| 43 |
+
Remote run `taotern-200m-hybrid-chat-20260512` trains this model on TaoData for a 4B-token base stage and then runs SFT so the final artifact can be loaded as a chat model. The trainable fixes added for this run are:
|
| 44 |
+
|
| 45 |
+
- Async JSONL iteration keeps polling while tokenization workers are alive instead of ending early after a temporary empty queue.
|
| 46 |
+
- Cached JSONL scan metadata is reused safely while recomputing chunk ranges for the active `samples_per_chunk` and `max_samples` settings.
|
| 47 |
+
|
| 48 |
+
## Why TaoTrain?
|
| 49 |
+
|
| 50 |
+
- **Complete Unified Pipeline**: Pretraining → SFT → RL in a single, consistent framework. No context switching between different codebases or architectures.
|
| 51 |
+
- **Production-Grade Engineering**: Type-safe Pydantic configs, comprehensive checkpointing, AimStack integration, and proper gradient handling—not research code, but a framework you can deploy.
|
| 52 |
+
- **Extensibility Without Modification**: Register custom models, optimizers, schedulers, and datasets via decorators. Experiment freely without forking the framework.
|
| 53 |
+
- **Developer Experience First**: Interactive TUI for inference, intuitive YAML configurations, async data loading that eliminates I/O bottlenecks, and clear abstractions that make the codebase a pleasure to work with.
|
| 54 |
+
|
| 55 |
+
## Key Capabilities
|
| 56 |
+
|
| 57 |
+
| Capability | Details |
|
| 58 |
+
|---|---|
|
| 59 |
+
| **Multi-Stage Training** | Unified infrastructure for pretraining, SFT, and RL. Share model checkpoints, logging, and evaluation across stages. |
|
| 60 |
+
| **Advanced Optimization** | Hybrid Muon + AdamW optimizer: efficient 2D weight updates via SVD-based methods + adaptive learning for 1D parameters. |
|
| 61 |
+
| **Modern Architectures** | DeepSeek MLA with grouped query attention (GQA), YaRN context extension, and factorized embeddings—all configurable via YAML. |
|
| 62 |
+
| **Production Features** | BF16 mixed precision training, gradient accumulation, proper gradient clipping, checkpoint resumption, and validation loops. |
|
| 63 |
+
| **Async Data Pipeline** | Background tokenization with multi-threaded workers. Stream billion-token datasets from JSONL without loading into memory. |
|
| 64 |
+
| **Interactive Inference** | TUI chat interface with real-time generation speed metrics and multi-model comparison. |
|
| 65 |
+
| **Logging & Monitoring** | AimStack integration tracks loss, metrics, hyperparameters, and git hashes for reproducibility. Visualize training runs in your browser. |
|
| 66 |
+
|
| 67 |
+
## Getting Started
|
| 68 |
+
|
| 69 |
+
### Installation
|
| 70 |
+
|
| 71 |
+
```bash
|
| 72 |
+
git clone https://github.com/lobakkang/taoTrain.git
|
| 73 |
+
cd taoTrain
|
| 74 |
+
pip install -e .
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
### Training Examples
|
| 78 |
+
|
| 79 |
+
**Pretraining on a custom dataset:**
|
| 80 |
+
```bash
|
| 81 |
+
train pretrain --config configs/pretrain.yaml
|
| 82 |
+
```
|
| 83 |
+
Starts from scratch, learns representations from raw text via next-token prediction.
|
| 84 |
+
|
| 85 |
+
**Supervised Fine-tuning:**
|
| 86 |
+
```bash
|
| 87 |
+
train sft --config configs/sft.yaml
|
| 88 |
+
```
|
| 89 |
+
Fine-tune a pretrained model on instruction-response pairs for improved task performance.
|
| 90 |
+
|
| 91 |
+
**Reinforcement Learning (DPO):**
|
| 92 |
+
```bash
|
| 93 |
+
train rl --config configs/rl_dpo.yaml
|
| 94 |
+
```
|
| 95 |
+
Align models with human preferences using Direct Preference Optimization.
|
| 96 |
+
|
| 97 |
+
**Interactive Chat:**
|
| 98 |
+
```bash
|
| 99 |
+
tui-chat --model checkpoints/model.pt
|
| 100 |
+
```
|
| 101 |
+
Launch an interactive TUI to chat with your model and monitor generation metrics in real-time.
|
| 102 |
+
|
| 103 |
+
### Configuration
|
| 104 |
+
|
| 105 |
+
All training is configured via YAML with Pydantic validation. Configs are type-safe and automatically validated:
|
| 106 |
+
|
| 107 |
+
```yaml
|
| 108 |
+
# configs/sft.yaml
|
| 109 |
+
model:
|
| 110 |
+
architecture_type: "mla" # DeepSeek MLA with GQA
|
| 111 |
+
hidden_dim: 2048
|
| 112 |
+
num_layers: 24
|
| 113 |
+
num_heads: 32
|
| 114 |
+
d_latent_kv: 1536 # KV compression factor
|
| 115 |
+
|
| 116 |
+
training:
|
| 117 |
+
num_epochs: 3
|
| 118 |
+
batch_size: 32
|
| 119 |
+
learning_rate: 1e-4
|
| 120 |
+
warmup_ratio: 0.1
|
| 121 |
+
max_grad_norm: 1.0
|
| 122 |
+
|
| 123 |
+
optimizer:
|
| 124 |
+
optimizer_type: "muon_adamw" # Hybrid Muon + AdamW
|
| 125 |
+
muon_momentum: 0.95
|
| 126 |
+
|
| 127 |
+
data:
|
| 128 |
+
dataset_type: "sft_jsonl" # or "sft_hf" for HuggingFace
|
| 129 |
+
path: "data/sft_training.jsonl"
|
| 130 |
+
|
| 131 |
+
logging:
|
| 132 |
+
log_to_aim: true
|
| 133 |
+
aim_repo: "/tmp/aim_logs"
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
See `configs/` for complete examples.
|
| 137 |
+
|
| 138 |
+
## Project Architecture
|
| 139 |
+
|
| 140 |
+
```
|
| 141 |
+
src/taoTrain/
|
| 142 |
+
├── cli.py # Main CLI entry point
|
| 143 |
+
├── config.py # Pydantic configuration schemas
|
| 144 |
+
│
|
| 145 |
+
├── core/ # Base abstractions
|
| 146 |
+
│ └── base.py # BaseModel, BaseDataset, BaseTrainer
|
| 147 |
+
│
|
| 148 |
+
├── models/ # Pluggable architecture system
|
| 149 |
+
│ ├── registry.py # Architecture factory with @register_architecture
|
| 150 |
+
│ ├── taonet.py # SimpleLLM with DeepSeek MLA
|
| 151 |
+
│ ├── mla_components.py # KV compression, GQA, YaRN
|
| 152 |
+
│ ├── embeddings.py # Factorized embeddings
|
| 153 |
+
│ └── transformer.py # Standard Transformer reference
|
| 154 |
+
│
|
| 155 |
+
├── data/ # Advanced data pipeline
|
| 156 |
+
│ ├── factory.py # Dataset factory (HF + JSONL backends)
|
| 157 |
+
│ ├── async_loader.py # Async batch iteration (no I/O bottleneck)
|
| 158 |
+
│ ├── tokenization_queue.py # Background multi-threaded tokenization
|
| 159 |
+
│ ├── chunk_manager.py # Stream billion-token JSONL files
|
| 160 |
+
│ ├── hf_pretrain.py # HuggingFace pretraining datasets
|
| 161 |
+
│ ├── hf_sft.py # HuggingFace SFT datasets
|
| 162 |
+
│ ├── hf_rl.py # HuggingFace RL datasets
|
| 163 |
+
│ ├── pretrain_jsonl.py # JSONL pretraining
|
| 164 |
+
│ ├── sft_jsonl.py # JSONL SFT with instructions
|
| 165 |
+
│ └── rl_jsonl.py # JSONL RL with preferences
|
| 166 |
+
│
|
| 167 |
+
├── training/ # Unified training infrastructure
|
| 168 |
+
│ └── trainer.py # Trainer + PretrainTrainer, SFTTrainer, RLTrainer
|
| 169 |
+
│
|
| 170 |
+
├── optimizers/ # Pluggable optimizer system
|
| 171 |
+
│ ├── registry.py # Optimizer factory with @register_optimizer
|
| 172 |
+
│ ├── hybrid_muon_adamw.py # Composite: Muon (2D) + AdamW (1D)
|
| 173 |
+
│ ├── adamw.py # AdamW with weight decay
|
| 174 |
+
│ ├── adam.py # Standard Adam
|
| 175 |
+
│ └── sgd.py # SGD variants
|
| 176 |
+
│
|
| 177 |
+
├── schedulers/ # Learning rate schedules
|
| 178 |
+
│ ├── registry.py # LR scheduler factory
|
| 179 |
+
│ ├── cosine_warmup.py # 3-phase: linear warmup → plateau → cosine decay
|
| 180 |
+
│ ├── linear_warmup.py # Linear warmup + constant
|
| 181 |
+
│ └── constant.py # Constant learning rate
|
| 182 |
+
│
|
| 183 |
+
├── inference/ # Inference & interaction
|
| 184 |
+
│ ├── inferencer.py # Load & run inference from checkpoints
|
| 185 |
+
│ └── tui.py # Interactive chat with metrics display
|
| 186 |
+
│
|
| 187 |
+
├── checkpointing/ # State management
|
| 188 |
+
│ └── checkpoint.py # Save/load model + optimizer + config + metrics
|
| 189 |
+
│
|
| 190 |
+
├── logging/ # Experiment tracking
|
| 191 |
+
│ └── aim_logger.py # AimStack integration (loss, metrics, hyperparams)
|
| 192 |
+
│
|
| 193 |
+
├── benchmarks/ # Evaluation tools
|
| 194 |
+
│ └── runner.py # Perplexity, speed, and task-specific benchmarks
|
| 195 |
+
│
|
| 196 |
+
└── utils/
|
| 197 |
+
└── helpers.py # Utility functions
|
| 198 |
+
|
| 199 |
+
configs/ # Example YAML configurations
|
| 200 |
+
├── pretrain.yaml # Pretraining config
|
| 201 |
+
├── sft.yaml # SFT config
|
| 202 |
+
├── rl_dpo.yaml # RL/DPO config
|
| 203 |
+
└── tokenizer.yaml # Tokenizer config
|
| 204 |
+
|
| 205 |
+
tests/ # Unit & integration tests
|
| 206 |
+
└── test_dataset.py
|
| 207 |
+
```
|
| 208 |
+
|
| 209 |
+
## Extensible Architecture: The Registry Pattern
|
| 210 |
+
|
| 211 |
+
TaoTrain's power lies in its **pluggable design**. Add custom models, optimizers, schedulers, and datasets without modifying the framework.
|
| 212 |
+
|
| 213 |
+
### Custom Model Architecture
|
| 214 |
+
|
| 215 |
+
```python
|
| 216 |
+
from taoTrain.models import register_architecture, BaseModel
|
| 217 |
+
import torch.nn as nn
|
| 218 |
+
|
| 219 |
+
@register_architecture("custom_moe")
|
| 220 |
+
class MixtureOfExperts(BaseModel):
|
| 221 |
+
"""Your custom MoE architecture"""
|
| 222 |
+
def __init__(self, config):
|
| 223 |
+
super().__init__(config)
|
| 224 |
+
self.experts = nn.ModuleList([
|
| 225 |
+
nn.Linear(config.hidden_dim, config.hidden_dim)
|
| 226 |
+
for _ in range(config.num_experts)
|
| 227 |
+
])
|
| 228 |
+
self.router = nn.Linear(config.hidden_dim, config.num_experts)
|
| 229 |
+
|
| 230 |
+
def forward(self, input_ids, attention_mask=None):
|
| 231 |
+
# Your implementation
|
| 232 |
+
logits = self.compute_logits(input_ids)
|
| 233 |
+
loss = self.compute_loss(logits, labels) if labels is not None else None
|
| 234 |
+
return {"logits": logits, "loss": loss}
|
| 235 |
+
```
|
| 236 |
+
|
| 237 |
+
Then use it in your config:
|
| 238 |
+
|
| 239 |
+
```yaml
|
| 240 |
+
model:
|
| 241 |
+
architecture_type: "custom_moe"
|
| 242 |
+
hidden_dim: 2048
|
| 243 |
+
num_experts: 8
|
| 244 |
+
```
|
| 245 |
+
|
| 246 |
+
### Custom Optimizers & Schedulers
|
| 247 |
+
|
| 248 |
+
The same pattern works for optimizers and learning rate schedules:
|
| 249 |
+
|
| 250 |
+
```python
|
| 251 |
+
from taoTrain.optimizers import register_optimizer
|
| 252 |
+
from torch.optim import Optimizer
|
| 253 |
+
|
| 254 |
+
@register_optimizer("my_adaptive_optimizer")
|
| 255 |
+
class MyAdaptiveOptimizer(Optimizer):
|
| 256 |
+
def step(self, closure=None):
|
| 257 |
+
# Your optimization logic
|
| 258 |
+
pass
|
| 259 |
+
```
|
| 260 |
+
|
| 261 |
+
```python
|
| 262 |
+
from taoTrain.schedulers import register_scheduler
|
| 263 |
+
|
| 264 |
+
@register_scheduler("my_schedule")
|
| 265 |
+
def my_schedule(initial_lr, step, total_steps, **kwargs):
|
| 266 |
+
return initial_lr * (1.0 - step / total_steps) # Linear decay
|
| 267 |
+
```
|
| 268 |
+
|
| 269 |
+
**The key principle**: No framework code needs to change. You register once, it's available everywhere.
|
| 270 |
+
|
| 271 |
+
### Dataset Backend Flexibility
|
| 272 |
+
|
| 273 |
+
Define custom datasets (JSONL, HF, streaming, etc.) and let the factory route to them:
|
| 274 |
+
|
| 275 |
+
```python
|
| 276 |
+
from taoTrain.data import register_dataset
|
| 277 |
+
|
| 278 |
+
@register_dataset("pretrain", "my_backend")
|
| 279 |
+
class MyPretrainDataset(BaseDataset):
|
| 280 |
+
def __init__(self, config):
|
| 281 |
+
# Load from your custom backend
|
| 282 |
+
pass
|
| 283 |
+
|
| 284 |
+
def __getitem__(self, idx):
|
| 285 |
+
return {"input_ids": ..., "attention_mask": ...}
|
| 286 |
+
```
|
| 287 |
+
|
| 288 |
+
Use in config:
|
| 289 |
+
|
| 290 |
+
```yaml
|
| 291 |
+
data:
|
| 292 |
+
dataset_type: "pretrain"
|
| 293 |
+
backend_type: "my_backend" # Routes to MyPretrainDataset
|
| 294 |
+
```
|
| 295 |
+
|
| 296 |
+
## Why TaoTrain Framework?
|
| 297 |
+
|
| 298 |
+
### Async Data Loading: No I/O Bottleneck
|
| 299 |
+
|
| 300 |
+
Most training frameworks load and tokenize data on the main training thread, blocking compute. TaoTrain's **multi-threaded tokenization pipeline**:
|
| 301 |
+
|
| 302 |
+
- Tokenizes data in background workers while your GPU trains
|
| 303 |
+
- Supports streaming billion-token JSONL files without loading into memory
|
| 304 |
+
- Intelligent chunking (by file size or sample count)
|
| 305 |
+
- Metadata caching to avoid rescanning
|
| 306 |
+
|
| 307 |
+
**Result**: 10-100x faster data iteration on large datasets.
|
| 308 |
+
|
| 309 |
+
### Type-Safe Configuration
|
| 310 |
+
|
| 311 |
+
Forget YAML parsing errors or mysterious config bugs. TaoTrain uses **Pydantic dataclasses** for configuration:
|
| 312 |
+
|
| 313 |
+
- Automatic type validation: mistyped `learning_rate: "1e-4"` becomes an error, not silent failure
|
| 314 |
+
- Serialization: configs are part of checkpoints, ensuring reproducibility
|
| 315 |
+
- IDE support: autocomplete and type hints for all config fields
|
| 316 |
+
- Defaults: sensible defaults for all parameters
|
| 317 |
+
|
| 318 |
+
### Benchmarking & Metrics
|
| 319 |
+
|
| 320 |
+
Track what matters:
|
| 321 |
+
|
| 322 |
+
- **Perplexity**: Language modeling quality on held-out data
|
| 323 |
+
- **Generation Speed**: Tokens-per-second (useful for TUI or deployment)
|
| 324 |
+
- **Task-Specific Accuracy**: Evaluate on downstream tasks
|
| 325 |
+
- **Training Metrics**: Loss curves, gradient norms, effective batch size
|
| 326 |
+
|
| 327 |
+
All logged to AimStack with git hashes for reproducibility.
|
| 328 |
+
|
| 329 |
+
## Logging with AimStack
|
| 330 |
+
|
| 331 |
+
Automatically track and visualize experiments:
|
| 332 |
+
|
| 333 |
+
```bash
|
| 334 |
+
aim up --host 0.0.0.0
|
| 335 |
+
```
|
| 336 |
+
|
| 337 |
+
Then open `http://localhost:43800` to see:
|
| 338 |
+
|
| 339 |
+
- **Loss curves** per training step
|
| 340 |
+
- **Hyperparameters** (learning rate, batch size, model architecture)
|
| 341 |
+
- **Git hashes** for reproducibility
|
| 342 |
+
- **Custom metrics** (perplexity, validation accuracy, generation speed)
|
| 343 |
+
- **Compare runs**: Side-by-side experiment comparison
|
| 344 |
+
|
| 345 |
+
## Advanced Features
|
| 346 |
+
|
| 347 |
+
### Checkpointing with Resumption
|
| 348 |
+
|
| 349 |
+
TaoTrain saves complete training state:
|
| 350 |
+
|
| 351 |
+
```python
|
| 352 |
+
checkpoint = {
|
| 353 |
+
"step": 12500,
|
| 354 |
+
"model_state": model.state_dict(),
|
| 355 |
+
"optimizer_state": optimizer.state_dict(),
|
| 356 |
+
"config": config, # Full config as Pydantic object
|
| 357 |
+
"metrics": metrics_tracker.to_dict(),
|
| 358 |
+
}
|
| 359 |
+
```
|
| 360 |
+
|
| 361 |
+
Resume training from any checkpoint without loss of state. Keep last N checkpoints automatically.
|
| 362 |
+
|
| 363 |
+
### Mixed Precision Training (BF16)
|
| 364 |
+
|
| 365 |
+
```yaml
|
| 366 |
+
training:
|
| 367 |
+
use_bfloat16: true
|
| 368 |
+
gradient_accumulation_steps: 4
|
| 369 |
+
```
|
| 370 |
+
|
| 371 |
+
- BF16 via `torch.autocast` for ~2x speedup with minimal accuracy loss
|
| 372 |
+
- Proper gradient scaling and clipping
|
| 373 |
+
- Compatible with all optimizers and architectures
|
| 374 |
+
|
| 375 |
+
### 3-Phase Learning Rate Schedule
|
| 376 |
+
|
| 377 |
+
```yaml
|
| 378 |
+
scheduler:
|
| 379 |
+
scheduler_type: "cosine_warmup"
|
| 380 |
+
warmup_ratio: 0.1 # 10% of training steps
|
| 381 |
+
steady_ratio: 0.5 # 50% at steady rate
|
| 382 |
+
min_lr_ratio: 0.1 # Final LR = 0.1 × initial_lr
|
| 383 |
+
num_cycles: 1
|
| 384 |
+
```
|
| 385 |
+
|
| 386 |
+
This schedule:
|
| 387 |
+
1. **Linear warmup** (0 → 1) over 10% of steps
|
| 388 |
+
2. **Steady plateau** at full LR over 50% of steps
|
| 389 |
+
3. **Cosine decay** (1 → 0.1) over remaining 40% of steps
|
| 390 |
+
|
| 391 |
+
Better convergence than simple cosine or linear decay.
|
| 392 |
+
|
| 393 |
+
### Gradient Accumulation & Clipping
|
| 394 |
+
|
| 395 |
+
Simulate larger batch sizes with gradient accumulation:
|
| 396 |
+
|
| 397 |
+
```yaml
|
| 398 |
+
training:
|
| 399 |
+
batch_size: 32
|
| 400 |
+
gradient_accumulation_steps: 4 # Effective batch = 128
|
| 401 |
+
max_grad_norm: 1.0 # Gradient clipping
|
| 402 |
+
```
|
| 403 |
+
|
| 404 |
+
## Contributing
|
| 405 |
+
|
| 406 |
+
Contributions are welcome! TaoTrain is designed to make contributions easy:
|
| 407 |
+
|
| 408 |
+
1. **Add a model**: Implement `BaseModel` and `@register_architecture("name")`
|
| 409 |
+
2. **Add an optimizer**: Implement `torch.optim.Optimizer` and `@register_optimizer("name")`
|
| 410 |
+
3. **Add a dataset**: Implement `BaseDataset` and `@register_dataset(mode, backend_type)`
|
| 411 |
+
4. **Improve the core**: Submit PRs to `training/`, `data/`, `logging/`, etc.
|
| 412 |
+
|
| 413 |
+
Ensure new code includes:
|
| 414 |
+
- Type hints throughout
|
| 415 |
+
- Pydantic configs for new parameters
|
| 416 |
+
- Unit tests in `tests/`
|
| 417 |
+
- Documentation in docstrings and README
|
| 418 |
+
|
| 419 |
+
## Current Scope & Roadmap
|
| 420 |
+
|
| 421 |
+
### ✅ Currently Supported
|
| 422 |
+
|
| 423 |
+
- **Single GPU / single node** training
|
| 424 |
+
- **Pretraining, SFT, and RL training** stages
|
| 425 |
+
- **HuggingFace and JSONL** data backends
|
| 426 |
+
- **BF16 mixed precision** training
|
| 427 |
+
- **Checkpoint saving/loading** with resumption
|
| 428 |
+
- **Interactive inference** via TUI
|
| 429 |
+
- **Benchmarking** (perplexity, speed)
|
| 430 |
+
- **Pluggable architectures, optimizers, schedulers, datasets**
|
| 431 |
+
|
| 432 |
+
### 🚀 Roadmap (Future)
|
| 433 |
+
|
| 434 |
+
- **Distributed training** (DDP, FSDP) for multi-GPU/multi-node scaling
|
| 435 |
+
- **Quantization** support (INT8, QLoRA)
|
| 436 |
+
- **Advanced evaluation** (BLEU, ROUGE, custom tasks)
|
| 437 |
+
- **Streaming inference** with KV cache
|
| 438 |
+
- **Speculative decoding** for faster generation
|
| 439 |
+
- **Integration with popular model hubs** (Hugging Face Hub upload/download)
|
| 440 |
+
|
| 441 |
+
---
|
| 442 |
+
|
| 443 |
+
## Getting Help
|
| 444 |
+
|
| 445 |
+
- **Questions?** Open an issue on GitHub
|
| 446 |
+
- **Want to contribute?** See `CONTRIBUTING.md` (coming soon)
|
| 447 |
+
- **Found a bug?** Report it with a minimal reproduction script
|
| 448 |
+
|
| 449 |
+
## License
|
| 450 |
+
|
| 451 |
+
MIT
|
code/TaoTrain/src/taoTrain.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
README.md
|
| 2 |
+
pyproject.toml
|
| 3 |
+
src/taoTrain/__init__.py
|
| 4 |
+
src/taoTrain/cli.py
|
| 5 |
+
src/taoTrain/config.py
|
| 6 |
+
src/taoTrain.egg-info/PKG-INFO
|
| 7 |
+
src/taoTrain.egg-info/SOURCES.txt
|
| 8 |
+
src/taoTrain.egg-info/dependency_links.txt
|
| 9 |
+
src/taoTrain.egg-info/entry_points.txt
|
| 10 |
+
src/taoTrain.egg-info/requires.txt
|
| 11 |
+
src/taoTrain.egg-info/top_level.txt
|
| 12 |
+
src/taoTrain/benchmarks/__init__.py
|
| 13 |
+
src/taoTrain/benchmarks/runner.py
|
| 14 |
+
src/taoTrain/checkpointing/__init__.py
|
| 15 |
+
src/taoTrain/checkpointing/checkpoint.py
|
| 16 |
+
src/taoTrain/core/__init__.py
|
| 17 |
+
src/taoTrain/core/base.py
|
| 18 |
+
src/taoTrain/data/__init__.py
|
| 19 |
+
src/taoTrain/data/async_loader.py
|
| 20 |
+
src/taoTrain/data/chunk_manager.py
|
| 21 |
+
src/taoTrain/data/factory.py
|
| 22 |
+
src/taoTrain/data/hf_base.py
|
| 23 |
+
src/taoTrain/data/hf_pretrain.py
|
| 24 |
+
src/taoTrain/data/hf_rl.py
|
| 25 |
+
src/taoTrain/data/hf_sft.py
|
| 26 |
+
src/taoTrain/data/jsonl_base.py
|
| 27 |
+
src/taoTrain/data/loaders.py
|
| 28 |
+
src/taoTrain/data/pretrain_jsonl.py
|
| 29 |
+
src/taoTrain/data/rl_jsonl.py
|
| 30 |
+
src/taoTrain/data/sft_jsonl.py
|
| 31 |
+
src/taoTrain/data/sft_utils.py
|
| 32 |
+
src/taoTrain/data/tokenization_queue.py
|
| 33 |
+
src/taoTrain/data/tokenizer.py
|
| 34 |
+
src/taoTrain/inference/__init__.py
|
| 35 |
+
src/taoTrain/inference/inferencer.py
|
| 36 |
+
src/taoTrain/inference/tui.py
|
| 37 |
+
src/taoTrain/logging/__init__.py
|
| 38 |
+
src/taoTrain/logging/aim_logger.py
|
| 39 |
+
src/taoTrain/models/__init__.py
|
| 40 |
+
src/taoTrain/models/embeddings.py
|
| 41 |
+
src/taoTrain/models/mla_components.py
|
| 42 |
+
src/taoTrain/models/registry.py
|
| 43 |
+
src/taoTrain/models/taonet.py
|
| 44 |
+
src/taoTrain/models/taonet_ssm.py
|
| 45 |
+
src/taoTrain/models/transformer.py
|
| 46 |
+
src/taoTrain/optimizers/__init__.py
|
| 47 |
+
src/taoTrain/optimizers/adam.py
|
| 48 |
+
src/taoTrain/optimizers/adamw.py
|
| 49 |
+
src/taoTrain/optimizers/hybrid_muon_adamw.py
|
| 50 |
+
src/taoTrain/optimizers/registry.py
|
| 51 |
+
src/taoTrain/optimizers/sgd.py
|
| 52 |
+
src/taoTrain/schedulers/__init__.py
|
| 53 |
+
src/taoTrain/schedulers/constant.py
|
| 54 |
+
src/taoTrain/schedulers/cosine_warmup.py
|
| 55 |
+
src/taoTrain/schedulers/linear_warmup.py
|
| 56 |
+
src/taoTrain/schedulers/registry.py
|
| 57 |
+
src/taoTrain/tokenizers/__init__.py
|
| 58 |
+
src/taoTrain/tokenizers/trainer.py
|
| 59 |
+
src/taoTrain/training/__init__.py
|
| 60 |
+
src/taoTrain/training/trainer.py
|
| 61 |
+
src/taoTrain/utils/__init__.py
|
| 62 |
+
src/taoTrain/utils/helpers.py
|
| 63 |
+
tests/test_dataset.py
|
| 64 |
+
tests/test_sft_masking.py
|
| 65 |
+
tests/test_taonet_ssm.py
|
code/TaoTrain/src/taoTrain.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0.0
|
| 2 |
+
transformers>=4.30.0
|
| 3 |
+
datasets>=2.10.0
|
| 4 |
+
pydantic>=2.0.0
|
| 5 |
+
pydantic-settings>=2.0.0
|
| 6 |
+
aim>=3.15.0
|
| 7 |
+
click>=8.1.0
|
| 8 |
+
rich>=13.0.0
|
| 9 |
+
textual>=0.30.0
|
| 10 |
+
numpy>=1.24.0
|
| 11 |
+
tqdm>=4.65.0
|
| 12 |
+
sentencepiece>=0.1.99
|
| 13 |
+
|
| 14 |
+
[dev]
|
| 15 |
+
pytest>=7.4.0
|
| 16 |
+
pytest-cov>=4.1.0
|
| 17 |
+
pytest-xdist>=3.3.0
|
| 18 |
+
black>=23.7.0
|
| 19 |
+
ruff>=0.0.280
|
| 20 |
+
typing-extensions>=4.7.0
|
code/TaoTrain/src/taoTrain.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
taoTrain
|
code/TaoTrain/src/taoTrain/benchmarks/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Benchmarking suite."""
|
| 2 |
+
|
| 3 |
+
from .runner import BenchmarkRunner
|
| 4 |
+
|
| 5 |
+
__all__ = ["BenchmarkRunner"]
|
code/TaoTrain/src/taoTrain/benchmarks/runner.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Benchmarking suite for evaluating trained models."""
|
| 2 |
+
|
| 3 |
+
import time
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Optional, Dict
|
| 6 |
+
import torch
|
| 7 |
+
from torch.utils.data import DataLoader
|
| 8 |
+
|
| 9 |
+
from taoTrain.core import BaseModel
|
| 10 |
+
from taoTrain.config import TrainingConfig
|
| 11 |
+
from taoTrain.data.loaders import get_dataloader
|
| 12 |
+
from taoTrain.inference import Inferencer
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class BenchmarkRunner:
|
| 16 |
+
"""Run benchmarks on a trained model."""
|
| 17 |
+
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
model: BaseModel,
|
| 21 |
+
device: torch.device,
|
| 22 |
+
dtype: torch.dtype = torch.float32,
|
| 23 |
+
):
|
| 24 |
+
"""
|
| 25 |
+
Initialize benchmark runner.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
model: Trained model
|
| 29 |
+
device: Device for inference
|
| 30 |
+
dtype: Data type
|
| 31 |
+
"""
|
| 32 |
+
self.model = model.to(device)
|
| 33 |
+
self.model.eval()
|
| 34 |
+
self.device = device
|
| 35 |
+
self.dtype = dtype
|
| 36 |
+
|
| 37 |
+
@staticmethod
|
| 38 |
+
def load_from_checkpoint(
|
| 39 |
+
checkpoint_path: str | Path,
|
| 40 |
+
device: Optional[torch.device] = None,
|
| 41 |
+
) -> "BenchmarkRunner":
|
| 42 |
+
"""Load model from checkpoint."""
|
| 43 |
+
if device is None:
|
| 44 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 45 |
+
|
| 46 |
+
checkpoint = torch.load(checkpoint_path, map_location=device)
|
| 47 |
+
|
| 48 |
+
# Reconstruct model config
|
| 49 |
+
from taoTrain.config import ModelConfig
|
| 50 |
+
from taoTrain.models import get_model
|
| 51 |
+
|
| 52 |
+
model_config = ModelConfig(**checkpoint.get("config", {}).get("model", {}))
|
| 53 |
+
model = get_model(model_config, device=device)
|
| 54 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 55 |
+
|
| 56 |
+
return BenchmarkRunner(model, device)
|
| 57 |
+
|
| 58 |
+
def benchmark_perplexity(
|
| 59 |
+
self,
|
| 60 |
+
dataset: "DataLoader",
|
| 61 |
+
num_batches: Optional[int] = None,
|
| 62 |
+
) -> float:
|
| 63 |
+
"""
|
| 64 |
+
Compute perplexity on a dataset.
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
dataset: DataLoader for evaluation
|
| 68 |
+
num_batches: Limit evaluation to N batches
|
| 69 |
+
|
| 70 |
+
Returns:
|
| 71 |
+
Perplexity (exp of average loss)
|
| 72 |
+
"""
|
| 73 |
+
total_loss = 0.0
|
| 74 |
+
total_tokens = 0
|
| 75 |
+
|
| 76 |
+
with torch.no_grad():
|
| 77 |
+
for batch_idx, batch in enumerate(dataset):
|
| 78 |
+
if num_batches and batch_idx >= num_batches:
|
| 79 |
+
break
|
| 80 |
+
|
| 81 |
+
# Move to device
|
| 82 |
+
input_ids = batch["input_ids"].to(self.device)
|
| 83 |
+
attention_mask = batch.get("attention_mask")
|
| 84 |
+
if attention_mask is not None:
|
| 85 |
+
attention_mask = attention_mask.to(self.device)
|
| 86 |
+
labels = batch.get("labels")
|
| 87 |
+
if labels is not None:
|
| 88 |
+
labels = labels.to(self.device)
|
| 89 |
+
|
| 90 |
+
# Forward pass
|
| 91 |
+
with torch.autocast(
|
| 92 |
+
device_type="cuda" if self.device.type == "cuda" else "cpu",
|
| 93 |
+
dtype=torch.bfloat16 if self.dtype == torch.bfloat16 else torch.float32,
|
| 94 |
+
):
|
| 95 |
+
outputs = self.model(
|
| 96 |
+
input_ids=input_ids,
|
| 97 |
+
attention_mask=attention_mask,
|
| 98 |
+
labels=labels,
|
| 99 |
+
)
|
| 100 |
+
loss = outputs.get("loss")
|
| 101 |
+
|
| 102 |
+
if loss is not None:
|
| 103 |
+
total_loss += loss.item() * input_ids.shape[0]
|
| 104 |
+
total_tokens += input_ids.shape[0]
|
| 105 |
+
|
| 106 |
+
avg_loss = total_loss / total_tokens if total_tokens > 0 else float('inf')
|
| 107 |
+
perplexity = torch.exp(torch.tensor(avg_loss)).item()
|
| 108 |
+
|
| 109 |
+
return perplexity
|
| 110 |
+
|
| 111 |
+
def benchmark_throughput(
|
| 112 |
+
self,
|
| 113 |
+
batch_size: int = 32,
|
| 114 |
+
seq_length: int = 1024,
|
| 115 |
+
num_iters: int = 10,
|
| 116 |
+
) -> Dict[str, float]:
|
| 117 |
+
"""
|
| 118 |
+
Benchmark forward pass throughput.
|
| 119 |
+
|
| 120 |
+
Args:
|
| 121 |
+
batch_size: Batch size
|
| 122 |
+
seq_length: Sequence length
|
| 123 |
+
num_iters: Number of iterations
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
Dict with throughput metrics
|
| 127 |
+
"""
|
| 128 |
+
# Create dummy batch
|
| 129 |
+
dummy_input = torch.randint(
|
| 130 |
+
0, self.model.config.vocab_size,
|
| 131 |
+
(batch_size, seq_length)
|
| 132 |
+
).to(self.device)
|
| 133 |
+
|
| 134 |
+
# Warmup
|
| 135 |
+
with torch.no_grad():
|
| 136 |
+
for _ in range(2):
|
| 137 |
+
_ = self.model(dummy_input)
|
| 138 |
+
|
| 139 |
+
torch.cuda.synchronize() if torch.cuda.is_available() else None
|
| 140 |
+
|
| 141 |
+
# Benchmark forward pass
|
| 142 |
+
start = time.time()
|
| 143 |
+
|
| 144 |
+
with torch.no_grad():
|
| 145 |
+
for _ in range(num_iters):
|
| 146 |
+
_ = self.model(dummy_input)
|
| 147 |
+
|
| 148 |
+
torch.cuda.synchronize() if torch.cuda.is_available() else None
|
| 149 |
+
|
| 150 |
+
elapsed = time.time() - start
|
| 151 |
+
|
| 152 |
+
total_tokens = batch_size * seq_length * num_iters
|
| 153 |
+
tokens_per_sec = total_tokens / elapsed
|
| 154 |
+
|
| 155 |
+
return {
|
| 156 |
+
"throughput_tokens_per_sec": tokens_per_sec,
|
| 157 |
+
"throughput_samples_per_sec": (batch_size * num_iters) / elapsed,
|
| 158 |
+
"avg_time_per_iter_ms": (elapsed / num_iters) * 1000,
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
def benchmark_memory(self) -> Dict[str, float]:
|
| 162 |
+
"""
|
| 163 |
+
Benchmark peak GPU memory usage.
|
| 164 |
+
|
| 165 |
+
Returns:
|
| 166 |
+
Dict with memory stats
|
| 167 |
+
"""
|
| 168 |
+
if not torch.cuda.is_available():
|
| 169 |
+
return {"peak_memory_gb": 0.0}
|
| 170 |
+
|
| 171 |
+
torch.cuda.reset_peak_memory_stats()
|
| 172 |
+
torch.cuda.synchronize()
|
| 173 |
+
|
| 174 |
+
# Create dummy batch
|
| 175 |
+
dummy_input = torch.randint(
|
| 176 |
+
0, self.model.config.vocab_size,
|
| 177 |
+
(16, 1024)
|
| 178 |
+
).to(self.device)
|
| 179 |
+
|
| 180 |
+
with torch.no_grad():
|
| 181 |
+
_ = self.model(dummy_input)
|
| 182 |
+
|
| 183 |
+
torch.cuda.synchronize()
|
| 184 |
+
|
| 185 |
+
peak_memory = torch.cuda.max_memory_allocated() / (1024 ** 3) # GB
|
| 186 |
+
|
| 187 |
+
return {"peak_memory_gb": peak_memory}
|
| 188 |
+
|
| 189 |
+
def run_all_benchmarks(
|
| 190 |
+
self,
|
| 191 |
+
dataset: Optional["DataLoader"] = None,
|
| 192 |
+
batch_size: int = 32,
|
| 193 |
+
seq_length: int = 1024,
|
| 194 |
+
) -> Dict[str, float]:
|
| 195 |
+
"""
|
| 196 |
+
Run all benchmarks.
|
| 197 |
+
|
| 198 |
+
Args:
|
| 199 |
+
dataset: DataLoader for perplexity benchmark
|
| 200 |
+
batch_size: Batch size for throughput benchmark
|
| 201 |
+
seq_length: Sequence length for throughput benchmark
|
| 202 |
+
|
| 203 |
+
Returns:
|
| 204 |
+
Dict with all benchmark results
|
| 205 |
+
"""
|
| 206 |
+
results = {}
|
| 207 |
+
|
| 208 |
+
if dataset is not None:
|
| 209 |
+
print("Running perplexity benchmark...")
|
| 210 |
+
ppl = self.benchmark_perplexity(dataset, num_batches=10)
|
| 211 |
+
results["perplexity"] = ppl
|
| 212 |
+
|
| 213 |
+
print("Running throughput benchmark...")
|
| 214 |
+
throughput = self.benchmark_throughput(batch_size, seq_length)
|
| 215 |
+
results.update(throughput)
|
| 216 |
+
|
| 217 |
+
print("Running memory benchmark...")
|
| 218 |
+
memory = self.benchmark_memory()
|
| 219 |
+
results.update(memory)
|
| 220 |
+
|
| 221 |
+
return results
|
code/TaoTrain/src/taoTrain/checkpointing/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Checkpoint management."""
|
| 2 |
+
|
| 3 |
+
from .checkpoint import CheckpointManager
|
| 4 |
+
|
| 5 |
+
__all__ = ["CheckpointManager"]
|
code/TaoTrain/src/taoTrain/checkpointing/checkpoint.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Checkpoint management utilities.
|
| 2 |
+
|
| 3 |
+
Canonical Checkpoint Format (new):
|
| 4 |
+
{
|
| 5 |
+
'step': int, # Training step number
|
| 6 |
+
'model_state': Dict[str, Tensor], # Model state dict
|
| 7 |
+
'optimizer_state': Dict, # Optimizer state dict (optional)
|
| 8 |
+
'config': Dict, # TrainingConfig as dict
|
| 9 |
+
'metrics': Dict[str, float], # Training metrics
|
| 10 |
+
'global_step': int, # (deprecated, kept for compat) same as step
|
| 11 |
+
'current_epoch': int, # (optional) current epoch number
|
| 12 |
+
'best_loss': float, # (optional) best validation loss
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
Legacy Checkpoint Format (old, from BaseTrainer):
|
| 16 |
+
{
|
| 17 |
+
'global_step': int,
|
| 18 |
+
'current_epoch': int,
|
| 19 |
+
'best_loss': float,
|
| 20 |
+
'model_state_dict': Dict[str, Tensor], # ← Note: uses '_dict' suffix
|
| 21 |
+
'optimizer_state_dict': Dict,
|
| 22 |
+
'config': Dict,
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
The load() function auto-detects and migrates legacy format to canonical format.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
from pathlib import Path
|
| 29 |
+
from typing import Dict, Any, Optional
|
| 30 |
+
import torch
|
| 31 |
+
from taoTrain.config import TrainingConfig
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class CheckpointManager:
|
| 35 |
+
"""Manage model checkpoints with versioning."""
|
| 36 |
+
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
checkpoint_dir: str | Path,
|
| 40 |
+
keep_last_n: int = 3,
|
| 41 |
+
track_best: bool = True,
|
| 42 |
+
):
|
| 43 |
+
"""
|
| 44 |
+
Initialize checkpoint manager.
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
checkpoint_dir: Directory to save checkpoints
|
| 48 |
+
keep_last_n: Number of recent checkpoints to keep
|
| 49 |
+
track_best: Whether to track best model
|
| 50 |
+
"""
|
| 51 |
+
self.checkpoint_dir = Path(checkpoint_dir)
|
| 52 |
+
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
| 53 |
+
|
| 54 |
+
self.keep_last_n = keep_last_n
|
| 55 |
+
self.track_best = track_best
|
| 56 |
+
|
| 57 |
+
self.best_metric = None
|
| 58 |
+
self.best_metric_name = None
|
| 59 |
+
self.saved_checkpoints = []
|
| 60 |
+
|
| 61 |
+
def save(
|
| 62 |
+
self,
|
| 63 |
+
step: int,
|
| 64 |
+
model_state: Dict[str, Any],
|
| 65 |
+
optimizer_state: Optional[Dict[str, Any]] = None,
|
| 66 |
+
config: Optional[TrainingConfig] = None,
|
| 67 |
+
metrics: Optional[Dict[str, float]] = None,
|
| 68 |
+
is_best: bool = False,
|
| 69 |
+
) -> Path:
|
| 70 |
+
"""
|
| 71 |
+
Save a checkpoint.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
step: Training step
|
| 75 |
+
model_state: Model state dict
|
| 76 |
+
optimizer_state: Optimizer state dict
|
| 77 |
+
config: Training config
|
| 78 |
+
metrics: Metrics dict
|
| 79 |
+
is_best: Whether this is the best model so far
|
| 80 |
+
|
| 81 |
+
Returns:
|
| 82 |
+
Path to saved checkpoint
|
| 83 |
+
"""
|
| 84 |
+
checkpoint = {
|
| 85 |
+
"step": step,
|
| 86 |
+
"model_state": model_state,
|
| 87 |
+
"optimizer_state": optimizer_state,
|
| 88 |
+
"config": config.to_dict() if config else None,
|
| 89 |
+
"metrics": metrics or {},
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
filename = f"checkpoint_step_{step:06d}.pt"
|
| 93 |
+
if is_best:
|
| 94 |
+
filename = "best_model.pt"
|
| 95 |
+
|
| 96 |
+
path = self.checkpoint_dir / filename
|
| 97 |
+
torch.save(checkpoint, path)
|
| 98 |
+
|
| 99 |
+
# Track saved checkpoints
|
| 100 |
+
if not is_best:
|
| 101 |
+
self.saved_checkpoints.append((step, path))
|
| 102 |
+
|
| 103 |
+
# Clean up old checkpoints
|
| 104 |
+
if len(self.saved_checkpoints) > self.keep_last_n:
|
| 105 |
+
_, old_path = self.saved_checkpoints.pop(0)
|
| 106 |
+
if old_path.exists():
|
| 107 |
+
old_path.unlink()
|
| 108 |
+
|
| 109 |
+
return path
|
| 110 |
+
|
| 111 |
+
def load(
|
| 112 |
+
self,
|
| 113 |
+
checkpoint_path: str | Path,
|
| 114 |
+
device: Optional[torch.device] = None,
|
| 115 |
+
) -> Dict[str, Any]:
|
| 116 |
+
"""
|
| 117 |
+
Load a checkpoint with backward-compatible format handling.
|
| 118 |
+
|
| 119 |
+
Auto-detects checkpoint format (canonical or legacy) and normalizes
|
| 120 |
+
to canonical format in-memory. Legacy checkpoints are migrated without
|
| 121 |
+
modifying the file.
|
| 122 |
+
|
| 123 |
+
Args:
|
| 124 |
+
checkpoint_path: Path to checkpoint
|
| 125 |
+
device: Device to load to
|
| 126 |
+
|
| 127 |
+
Returns:
|
| 128 |
+
Checkpoint dict in canonical format with 'model_state' key
|
| 129 |
+
"""
|
| 130 |
+
if device is None:
|
| 131 |
+
device = torch.device("cpu")
|
| 132 |
+
|
| 133 |
+
checkpoint = torch.load(checkpoint_path, map_location=device)
|
| 134 |
+
|
| 135 |
+
# Auto-detect and migrate legacy format to canonical format
|
| 136 |
+
checkpoint = self._normalize_checkpoint_format(checkpoint)
|
| 137 |
+
|
| 138 |
+
return checkpoint
|
| 139 |
+
|
| 140 |
+
def _normalize_checkpoint_format(self, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
|
| 141 |
+
"""
|
| 142 |
+
Normalize checkpoint to canonical format.
|
| 143 |
+
|
| 144 |
+
Detects if checkpoint is in legacy format (from BaseTrainer with 'model_state_dict')
|
| 145 |
+
and migrates it to canonical format (with 'model_state').
|
| 146 |
+
|
| 147 |
+
Args:
|
| 148 |
+
checkpoint: Raw checkpoint dict
|
| 149 |
+
|
| 150 |
+
Returns:
|
| 151 |
+
Normalized checkpoint dict with canonical keys
|
| 152 |
+
"""
|
| 153 |
+
# Check if this is a legacy checkpoint (has 'model_state_dict' but not 'model_state')
|
| 154 |
+
if "model_state_dict" in checkpoint and "model_state" not in checkpoint:
|
| 155 |
+
# Migrate legacy format to canonical
|
| 156 |
+
migrated = {
|
| 157 |
+
"step": checkpoint.get("global_step", 0),
|
| 158 |
+
"model_state": checkpoint["model_state_dict"],
|
| 159 |
+
"optimizer_state": checkpoint.get("optimizer_state_dict"),
|
| 160 |
+
"config": checkpoint.get("config"),
|
| 161 |
+
"metrics": {},
|
| 162 |
+
# Keep legacy keys for backward compatibility in code that uses them
|
| 163 |
+
"global_step": checkpoint.get("global_step", 0),
|
| 164 |
+
"current_epoch": checkpoint.get("current_epoch", 0),
|
| 165 |
+
"best_loss": checkpoint.get("best_loss", float('inf')),
|
| 166 |
+
}
|
| 167 |
+
print(f"\n✓ [CheckpointManager] Detected legacy checkpoint format. Auto-migrated to canonical format.")
|
| 168 |
+
return migrated
|
| 169 |
+
|
| 170 |
+
# Already in canonical format or unknown format
|
| 171 |
+
if "model_state" not in checkpoint:
|
| 172 |
+
# If neither format detected, ensure model_state is accessible
|
| 173 |
+
# (might be a raw state_dict)
|
| 174 |
+
print(f"\n⚠ [CheckpointManager] Checkpoint format unclear. Assuming raw state_dict format.")
|
| 175 |
+
checkpoint["model_state"] = checkpoint
|
| 176 |
+
|
| 177 |
+
return checkpoint
|
| 178 |
+
|
| 179 |
+
def get_latest(self) -> Optional[Path]:
|
| 180 |
+
"""Get path to latest checkpoint."""
|
| 181 |
+
if not self.saved_checkpoints:
|
| 182 |
+
return None
|
| 183 |
+
return self.saved_checkpoints[-1][1]
|
| 184 |
+
|
| 185 |
+
def get_best(self) -> Optional[Path]:
|
| 186 |
+
"""Get path to best checkpoint."""
|
| 187 |
+
best_path = self.checkpoint_dir / "best_model.pt"
|
| 188 |
+
if best_path.exists():
|
| 189 |
+
return best_path
|
| 190 |
+
return None
|
| 191 |
+
|
| 192 |
+
def list_checkpoints(self) -> list[Path]:
|
| 193 |
+
"""List all saved checkpoints."""
|
| 194 |
+
return sorted(self.checkpoint_dir.glob("checkpoint_step_*.pt"))
|
code/TaoTrain/src/taoTrain/core/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Base classes for models, trainers, and datasets."""
|
| 2 |
+
|
| 3 |
+
from .base import BaseModel, BaseTrainer, BaseDataset, create_model, create_datasets
|
| 4 |
+
|
| 5 |
+
__all__ = ["BaseModel", "BaseTrainer", "BaseDataset", "create_model", "create_datasets"]
|
code/TaoTrain/src/taoTrain/core/base.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Base classes for models, trainers, and datasets."""
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Optional, Any, Iterator
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from torch.utils.data import Dataset as TorchDataset
|
| 9 |
+
from taoTrain.config import TrainingConfig, ModelConfig
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# ============================================================================
|
| 13 |
+
# Base Model
|
| 14 |
+
# ============================================================================
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class BaseModel(nn.Module, ABC):
|
| 18 |
+
"""Abstract base class for language models."""
|
| 19 |
+
|
| 20 |
+
def __init__(self, config: ModelConfig):
|
| 21 |
+
"""Initialize model with config."""
|
| 22 |
+
super().__init__()
|
| 23 |
+
self.config = config
|
| 24 |
+
|
| 25 |
+
@abstractmethod
|
| 26 |
+
def forward(
|
| 27 |
+
self,
|
| 28 |
+
input_ids: torch.Tensor,
|
| 29 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 30 |
+
labels: Optional[torch.Tensor] = None,
|
| 31 |
+
) -> dict[str, torch.Tensor]:
|
| 32 |
+
"""
|
| 33 |
+
Forward pass.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
input_ids: Shape (batch_size, seq_length)
|
| 37 |
+
attention_mask: Shape (batch_size, seq_length), optional
|
| 38 |
+
labels: Shape (batch_size, seq_length), optional (for loss computation)
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
Dict with keys:
|
| 42 |
+
- 'logits': Shape (batch_size, seq_length, vocab_size)
|
| 43 |
+
- 'loss': Scalar (if labels provided)
|
| 44 |
+
"""
|
| 45 |
+
pass
|
| 46 |
+
|
| 47 |
+
def count_parameters(self) -> int:
|
| 48 |
+
"""Count total trainable parameters."""
|
| 49 |
+
return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| 50 |
+
|
| 51 |
+
def get_num_layers(self) -> int:
|
| 52 |
+
"""Get number of layers (for model architecture)."""
|
| 53 |
+
return self.config.num_layers
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# ============================================================================
|
| 57 |
+
# Base Dataset
|
| 58 |
+
# ============================================================================
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class BaseDataset(TorchDataset, ABC):
|
| 62 |
+
"""Abstract base class for datasets."""
|
| 63 |
+
|
| 64 |
+
def __init__(self, config: "TrainingConfig"):
|
| 65 |
+
"""Initialize dataset."""
|
| 66 |
+
self.config = config
|
| 67 |
+
self.data = None
|
| 68 |
+
|
| 69 |
+
@abstractmethod
|
| 70 |
+
def __len__(self) -> int:
|
| 71 |
+
"""Return dataset size."""
|
| 72 |
+
pass
|
| 73 |
+
|
| 74 |
+
@abstractmethod
|
| 75 |
+
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
| 76 |
+
"""
|
| 77 |
+
Get a single sample.
|
| 78 |
+
|
| 79 |
+
Returns:
|
| 80 |
+
Dict with keys:
|
| 81 |
+
- 'input_ids': 1D tensor of token IDs
|
| 82 |
+
- 'attention_mask': 1D tensor of attention mask
|
| 83 |
+
- 'labels': 1D tensor of labels (optional)
|
| 84 |
+
"""
|
| 85 |
+
pass
|
| 86 |
+
|
| 87 |
+
def load_dataset(self) -> None:
|
| 88 |
+
"""Load dataset from HuggingFace or other source."""
|
| 89 |
+
pass
|
| 90 |
+
|
| 91 |
+
def preprocess(self) -> None:
|
| 92 |
+
"""Preprocess dataset (tokenization, etc)."""
|
| 93 |
+
pass
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# ============================================================================
|
| 97 |
+
# Base Trainer
|
| 98 |
+
# ============================================================================
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class BaseTrainer(ABC):
|
| 102 |
+
"""Abstract base class for trainers."""
|
| 103 |
+
|
| 104 |
+
def __init__(
|
| 105 |
+
self,
|
| 106 |
+
model: BaseModel,
|
| 107 |
+
train_dataset: BaseDataset,
|
| 108 |
+
val_dataset: Optional[BaseDataset],
|
| 109 |
+
config: TrainingConfig,
|
| 110 |
+
device: torch.device,
|
| 111 |
+
):
|
| 112 |
+
"""Initialize trainer."""
|
| 113 |
+
self.model = model.to(device)
|
| 114 |
+
self.train_dataset = train_dataset
|
| 115 |
+
self.val_dataset = val_dataset
|
| 116 |
+
self.config = config
|
| 117 |
+
self.device = device
|
| 118 |
+
|
| 119 |
+
# Training state
|
| 120 |
+
self.global_step = 0
|
| 121 |
+
self.current_epoch = 0
|
| 122 |
+
self.best_loss = float('inf')
|
| 123 |
+
|
| 124 |
+
# Logging
|
| 125 |
+
self.logger = None
|
| 126 |
+
|
| 127 |
+
# Optimizer and scheduler (to be set up by subclass)
|
| 128 |
+
self.optimizer = None
|
| 129 |
+
self.scheduler = None
|
| 130 |
+
|
| 131 |
+
@abstractmethod
|
| 132 |
+
def training_step(self, batch: dict[str, torch.Tensor]) -> dict[str, float]:
|
| 133 |
+
"""
|
| 134 |
+
Single training step.
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
batch: Training batch with input_ids, attention_mask, labels, etc.
|
| 138 |
+
|
| 139 |
+
Returns:
|
| 140 |
+
Dict with metrics (e.g., {'loss': 0.5, 'accuracy': 0.8})
|
| 141 |
+
"""
|
| 142 |
+
pass
|
| 143 |
+
|
| 144 |
+
@abstractmethod
|
| 145 |
+
def validation_step(self, batch: dict[str, torch.Tensor]) -> dict[str, float]:
|
| 146 |
+
"""
|
| 147 |
+
Single validation step.
|
| 148 |
+
|
| 149 |
+
Args:
|
| 150 |
+
batch: Validation batch
|
| 151 |
+
|
| 152 |
+
Returns:
|
| 153 |
+
Dict with validation metrics
|
| 154 |
+
"""
|
| 155 |
+
pass
|
| 156 |
+
|
| 157 |
+
@abstractmethod
|
| 158 |
+
def train_epoch(self) -> dict[str, float]:
|
| 159 |
+
"""
|
| 160 |
+
Train for one epoch.
|
| 161 |
+
|
| 162 |
+
Returns:
|
| 163 |
+
Dict with epoch-level metrics
|
| 164 |
+
"""
|
| 165 |
+
pass
|
| 166 |
+
|
| 167 |
+
@abstractmethod
|
| 168 |
+
def validate(self) -> dict[str, float]:
|
| 169 |
+
"""
|
| 170 |
+
Run validation on the entire validation set.
|
| 171 |
+
|
| 172 |
+
Returns:
|
| 173 |
+
Dict with validation metrics
|
| 174 |
+
"""
|
| 175 |
+
pass
|
| 176 |
+
|
| 177 |
+
def save_checkpoint(self, path: str | Path) -> None:
|
| 178 |
+
"""
|
| 179 |
+
Save checkpoint in canonical format.
|
| 180 |
+
|
| 181 |
+
Uses canonical checkpoint format:
|
| 182 |
+
{
|
| 183 |
+
'step': int,
|
| 184 |
+
'model_state': state_dict,
|
| 185 |
+
'optimizer_state': state_dict,
|
| 186 |
+
'config': dict,
|
| 187 |
+
'metrics': dict,
|
| 188 |
+
'global_step': int, # Legacy compat
|
| 189 |
+
'current_epoch': int, # Legacy compat
|
| 190 |
+
'best_loss': float, # Legacy compat
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
Args:
|
| 194 |
+
path: Path to save checkpoint
|
| 195 |
+
"""
|
| 196 |
+
path = Path(path)
|
| 197 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 198 |
+
|
| 199 |
+
# Save in canonical format
|
| 200 |
+
checkpoint = {
|
| 201 |
+
# Canonical format keys
|
| 202 |
+
'step': self.global_step,
|
| 203 |
+
'model_state': self.model.state_dict(),
|
| 204 |
+
'optimizer_state': self.optimizer.state_dict() if self.optimizer else None,
|
| 205 |
+
'config': self.config.to_dict(),
|
| 206 |
+
'metrics': {},
|
| 207 |
+
# Legacy format keys (for backward compatibility with code that reads them)
|
| 208 |
+
'global_step': self.global_step,
|
| 209 |
+
'current_epoch': self.current_epoch,
|
| 210 |
+
'best_loss': self.best_loss,
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
torch.save(checkpoint, path)
|
| 214 |
+
|
| 215 |
+
def load_checkpoint(self, path: str | Path) -> None:
|
| 216 |
+
"""
|
| 217 |
+
Load checkpoint (handles both canonical and legacy formats).
|
| 218 |
+
|
| 219 |
+
Args:
|
| 220 |
+
path: Path to checkpoint
|
| 221 |
+
"""
|
| 222 |
+
path = Path(path)
|
| 223 |
+
checkpoint = torch.load(path, map_location=self.device)
|
| 224 |
+
|
| 225 |
+
# Try canonical keys first, fall back to legacy keys
|
| 226 |
+
model_state_key = 'model_state' if 'model_state' in checkpoint else 'model_state_dict'
|
| 227 |
+
optimizer_state_key = 'optimizer_state' if 'optimizer_state' in checkpoint else 'optimizer_state_dict'
|
| 228 |
+
|
| 229 |
+
self.model.load_state_dict(checkpoint[model_state_key])
|
| 230 |
+
if self.optimizer and checkpoint.get(optimizer_state_key):
|
| 231 |
+
self.optimizer.load_state_dict(checkpoint[optimizer_state_key])
|
| 232 |
+
|
| 233 |
+
# Try canonical 'step' first, fall back to legacy 'global_step'
|
| 234 |
+
self.global_step = checkpoint.get('step', checkpoint.get('global_step', 0))
|
| 235 |
+
self.current_epoch = checkpoint.get('current_epoch', 0)
|
| 236 |
+
self.best_loss = checkpoint.get('best_loss', float('inf'))
|
| 237 |
+
|
| 238 |
+
def _get_lr(self) -> float:
|
| 239 |
+
"""Get current learning rate from optimizer."""
|
| 240 |
+
for param_group in self.optimizer.param_groups:
|
| 241 |
+
return param_group['lr']
|
| 242 |
+
return 0.0
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
# ============================================================================
|
| 246 |
+
# Utility functions
|
| 247 |
+
# ============================================================================
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def create_model(config: TrainingConfig, device: torch.device) -> BaseModel:
|
| 251 |
+
"""Create model from config (calls registry)."""
|
| 252 |
+
from taoTrain.models import get_model
|
| 253 |
+
return get_model(config.model, device=device)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def create_datasets(
|
| 257 |
+
config: TrainingConfig,
|
| 258 |
+
) -> tuple[BaseDataset, Optional[BaseDataset]]:
|
| 259 |
+
"""Create train and validation datasets using factory pattern."""
|
| 260 |
+
# Import here to avoid circular imports
|
| 261 |
+
from taoTrain.data import DatasetFactory
|
| 262 |
+
|
| 263 |
+
# Create train dataset
|
| 264 |
+
train_dataset = DatasetFactory.create_dataset(config, split="train")
|
| 265 |
+
|
| 266 |
+
# Create validation dataset (only for HuggingFace datasets with explicit validation split)
|
| 267 |
+
val_dataset = None
|
| 268 |
+
if not config.dataset.local and hasattr(config.dataset, "validation_split"):
|
| 269 |
+
val_dataset = DatasetFactory.create_dataset(config, split="validation")
|
| 270 |
+
|
| 271 |
+
return train_dataset, val_dataset
|
code/TaoTrain/src/taoTrain/data/__init__.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Dataset implementations and loaders."""
|
| 2 |
+
|
| 3 |
+
# HuggingFace-based datasets are optional for JSONL-only deployments.
|
| 4 |
+
try:
|
| 5 |
+
from .hf_base import BaseHFDataset
|
| 6 |
+
from .hf_pretrain import PretrainDataset
|
| 7 |
+
from .hf_sft import SFTDataset
|
| 8 |
+
from .hf_rl import RLDataset
|
| 9 |
+
except ImportError:
|
| 10 |
+
BaseHFDataset = None
|
| 11 |
+
PretrainDataset = None
|
| 12 |
+
SFTDataset = None
|
| 13 |
+
RLDataset = None
|
| 14 |
+
|
| 15 |
+
# JSONL-based datasets (async-only)
|
| 16 |
+
from .jsonl_base import BaseJSONLDataset
|
| 17 |
+
from .pretrain_jsonl import PretrainJSONLDataset
|
| 18 |
+
from .sft_jsonl import SFTJSONLDataset
|
| 19 |
+
from .rl_jsonl import RLJSONLDataset
|
| 20 |
+
|
| 21 |
+
# Utilities
|
| 22 |
+
from .tokenizer import SentencePieceTokenizerWrapper
|
| 23 |
+
from .sft_utils import (
|
| 24 |
+
parse_sft_record,
|
| 25 |
+
build_sft_sequence_tokens,
|
| 26 |
+
apply_response_masking,
|
| 27 |
+
build_response_only_next_token_labels,
|
| 28 |
+
)
|
| 29 |
+
from .loaders import get_dataloader
|
| 30 |
+
from .async_loader import AsyncBatchIterator
|
| 31 |
+
from .tokenization_queue import TokenizationQueue
|
| 32 |
+
from .factory import DatasetFactory
|
| 33 |
+
|
| 34 |
+
__all__ = [
|
| 35 |
+
# HuggingFace datasets
|
| 36 |
+
"BaseHFDataset",
|
| 37 |
+
"PretrainDataset",
|
| 38 |
+
"SFTDataset",
|
| 39 |
+
"RLDataset",
|
| 40 |
+
# JSONL datasets
|
| 41 |
+
"BaseJSONLDataset",
|
| 42 |
+
"PretrainJSONLDataset",
|
| 43 |
+
"SFTJSONLDataset",
|
| 44 |
+
"RLJSONLDataset",
|
| 45 |
+
# Utilities
|
| 46 |
+
"SentencePieceTokenizerWrapper",
|
| 47 |
+
"parse_sft_record",
|
| 48 |
+
"build_sft_sequence_tokens",
|
| 49 |
+
"apply_response_masking",
|
| 50 |
+
"build_response_only_next_token_labels",
|
| 51 |
+
# Data loading
|
| 52 |
+
"get_dataloader",
|
| 53 |
+
"AsyncBatchIterator",
|
| 54 |
+
"TokenizationQueue",
|
| 55 |
+
"DatasetFactory",
|
| 56 |
+
]
|
code/TaoTrain/src/taoTrain/data/async_loader.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Async batch iterator for training with background tokenization."""
|
| 2 |
+
|
| 3 |
+
from typing import Dict, List, Optional, Any, Iterator
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from taoTrain.data.tokenization_queue import TokenizationQueue
|
| 7 |
+
from taoTrain.data.sft_utils import build_response_only_next_token_labels
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class AsyncBatchIterator:
|
| 11 |
+
"""
|
| 12 |
+
Iterator that yields batches from a tokenization queue.
|
| 13 |
+
|
| 14 |
+
This allows batches to be consumed directly from the background tokenization
|
| 15 |
+
thread without waiting for all chunks to be tokenized upfront.
|
| 16 |
+
|
| 17 |
+
The iterator:
|
| 18 |
+
1. Pulls pre-tokenized chunks from the TokenizationQueue
|
| 19 |
+
2. Yields individual samples or batches
|
| 20 |
+
3. Handles movement to device (GPU/CPU) at batch level
|
| 21 |
+
4. Supports gradient accumulation
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
tokenization_queue: TokenizationQueue,
|
| 27 |
+
batch_size: int,
|
| 28 |
+
device: torch.device,
|
| 29 |
+
drop_last: bool = True,
|
| 30 |
+
gradient_accumulation_steps: int = 1,
|
| 31 |
+
):
|
| 32 |
+
"""
|
| 33 |
+
Initialize async batch iterator.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
tokenization_queue: TokenizationQueue instance
|
| 37 |
+
batch_size: Batch size for yielding batches
|
| 38 |
+
device: torch.device to move batches to
|
| 39 |
+
drop_last: If True, drop last incomplete batch
|
| 40 |
+
gradient_accumulation_steps: For logging purposes (not used here)
|
| 41 |
+
"""
|
| 42 |
+
self.queue = tokenization_queue
|
| 43 |
+
self.batch_size = batch_size
|
| 44 |
+
self.device = device
|
| 45 |
+
self.drop_last = drop_last
|
| 46 |
+
self.gradient_accumulation_steps = gradient_accumulation_steps
|
| 47 |
+
|
| 48 |
+
# State for iteration
|
| 49 |
+
self._current_chunk: Optional[Dict[str, List]] = None
|
| 50 |
+
self._current_idx = 0
|
| 51 |
+
self._samples_yielded = 0
|
| 52 |
+
self._finished = False
|
| 53 |
+
|
| 54 |
+
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
|
| 55 |
+
"""Return iterator (self)."""
|
| 56 |
+
# Reset state for new epoch
|
| 57 |
+
self._current_chunk = None
|
| 58 |
+
self._current_idx = 0
|
| 59 |
+
self._samples_yielded = 0
|
| 60 |
+
self._finished = False
|
| 61 |
+
|
| 62 |
+
# Reset tokenization queue for epochs 2+
|
| 63 |
+
if self.queue._next_chunk_idx > 0:
|
| 64 |
+
print(f"\n✓ Resetting TokenizationQueue for next epoch (cur_idx={self.queue._next_chunk_idx})")
|
| 65 |
+
self.queue.reset_for_next_epoch()
|
| 66 |
+
|
| 67 |
+
# Start tokenization threads once per iterator creation
|
| 68 |
+
if not self.queue._threads:
|
| 69 |
+
print("\n✓ Starting TokenizationQueue worker threads...")
|
| 70 |
+
self.queue.start()
|
| 71 |
+
else:
|
| 72 |
+
print(f"\n⚠ TokenizationQueue threads already running: {len(self.queue._threads)} active")
|
| 73 |
+
|
| 74 |
+
return self
|
| 75 |
+
|
| 76 |
+
def __next__(self) -> Dict[str, torch.Tensor]:
|
| 77 |
+
"""
|
| 78 |
+
Get next batch.
|
| 79 |
+
|
| 80 |
+
Yields:
|
| 81 |
+
Dict with 'input_ids', 'attention_mask', 'labels' (all as torch tensors on device)
|
| 82 |
+
|
| 83 |
+
Raises:
|
| 84 |
+
StopIteration: When no more batches available
|
| 85 |
+
"""
|
| 86 |
+
batch = self._get_next_batch()
|
| 87 |
+
|
| 88 |
+
if batch is None:
|
| 89 |
+
print("AsyncBatchIterator: No more batches available, stopping iteration.")
|
| 90 |
+
raise StopIteration
|
| 91 |
+
|
| 92 |
+
return batch
|
| 93 |
+
|
| 94 |
+
def _get_next_batch(self) -> Optional[Dict[str, torch.Tensor]]:
|
| 95 |
+
"""
|
| 96 |
+
Fetch and collate the next batch.
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
Dict with batch tensors, or None if iteration exhausted
|
| 100 |
+
"""
|
| 101 |
+
batch_input_ids = []
|
| 102 |
+
batch_attention_masks = []
|
| 103 |
+
batch_labels = []
|
| 104 |
+
|
| 105 |
+
while len(batch_input_ids) < self.batch_size:
|
| 106 |
+
# Try to get next sample from current chunk
|
| 107 |
+
if self._current_chunk is None or self._current_idx >= len(self._current_chunk["input_ids"]):
|
| 108 |
+
# Need new chunk
|
| 109 |
+
self._current_chunk = self.queue.get_next_chunk(timeout=30.0) # 30s polling timeout
|
| 110 |
+
|
| 111 |
+
if self._current_chunk is None:
|
| 112 |
+
if not self.queue.is_exhausted:
|
| 113 |
+
continue
|
| 114 |
+
# Queue exhausted
|
| 115 |
+
chunk_count = self.queue._next_chunk_idx if hasattr(self.queue, '_next_chunk_idx') else 'unknown'
|
| 116 |
+
print(f"AsyncBatchIterator: No more chunks (processed {chunk_count}/{len(self.queue._chunk_order)})")
|
| 117 |
+
print(f"AsyncBatchIterator: Samples yielded so far: {self._samples_yielded}")
|
| 118 |
+
self._finished = True
|
| 119 |
+
break
|
| 120 |
+
|
| 121 |
+
self._current_idx = 0
|
| 122 |
+
|
| 123 |
+
# Get sample from current chunk
|
| 124 |
+
input_ids = self._current_chunk["input_ids"][self._current_idx]
|
| 125 |
+
attention_mask = self._current_chunk["attention_mask"][self._current_idx]
|
| 126 |
+
|
| 127 |
+
# Generate labels based on SFT or pretrain mode
|
| 128 |
+
if "mask" in self._current_chunk:
|
| 129 |
+
# SFT mode: use mask to determine which tokens to train on
|
| 130 |
+
# mask=0 → label=-100 (ignore), mask=1 → label=input_id (train on)
|
| 131 |
+
mask = self._current_chunk["mask"][self._current_idx]
|
| 132 |
+
labels = build_response_only_next_token_labels(input_ids, mask)
|
| 133 |
+
else:
|
| 134 |
+
# Pretrain mode: shift labels by 1 for next-token prediction
|
| 135 |
+
# Position i predicts token at position i+1
|
| 136 |
+
labels = input_ids[1:] + [-100] # Append -100 as final position
|
| 137 |
+
|
| 138 |
+
# Mark padding tokens as -100 to ignore in loss computation
|
| 139 |
+
for i, mask_val in enumerate(attention_mask):
|
| 140 |
+
if mask_val == 0:
|
| 141 |
+
labels[i] = -100
|
| 142 |
+
|
| 143 |
+
batch_input_ids.append(input_ids)
|
| 144 |
+
batch_attention_masks.append(attention_mask)
|
| 145 |
+
batch_labels.append(labels)
|
| 146 |
+
|
| 147 |
+
self._current_idx += 1
|
| 148 |
+
self._samples_yielded += 1
|
| 149 |
+
|
| 150 |
+
# Return batch if we have any samples, respecting drop_last
|
| 151 |
+
if len(batch_input_ids) == 0:
|
| 152 |
+
print(f"AsyncBatchIterator: No samples collected for batch. Finished={self._finished}, returning None.")
|
| 153 |
+
return None
|
| 154 |
+
|
| 155 |
+
if len(batch_input_ids) < self.batch_size and self.drop_last:
|
| 156 |
+
incomplete_pct = (len(batch_input_ids) / self.batch_size) * 100
|
| 157 |
+
print(f"AsyncBatchIterator: Batch incomplete ({len(batch_input_ids)}/{self.batch_size} = {incomplete_pct:.1f}%) and drop_last=True, returning None.")
|
| 158 |
+
return None
|
| 159 |
+
|
| 160 |
+
return self._collate_batch(batch_input_ids, batch_attention_masks, batch_labels)
|
| 161 |
+
|
| 162 |
+
def _collate_batch(
|
| 163 |
+
self,
|
| 164 |
+
batch_input_ids: List[List[int]],
|
| 165 |
+
batch_attention_masks: List[List[int]],
|
| 166 |
+
batch_labels: List[List[int]],
|
| 167 |
+
) -> Dict[str, torch.Tensor]:
|
| 168 |
+
"""
|
| 169 |
+
Collate batch samples and move to device.
|
| 170 |
+
|
| 171 |
+
Args:
|
| 172 |
+
batch_input_ids: List of token ID lists
|
| 173 |
+
batch_attention_masks: List of attention mask lists
|
| 174 |
+
batch_labels: List of label lists
|
| 175 |
+
|
| 176 |
+
Returns:
|
| 177 |
+
Collated batch as torch tensors on device
|
| 178 |
+
"""
|
| 179 |
+
# Convert to tensors
|
| 180 |
+
input_ids_tensor = torch.tensor(batch_input_ids, dtype=torch.long, device=self.device)
|
| 181 |
+
attention_mask_tensor = torch.tensor(batch_attention_masks, dtype=torch.long, device=self.device)
|
| 182 |
+
labels_tensor = torch.tensor(batch_labels, dtype=torch.long, device=self.device)
|
| 183 |
+
|
| 184 |
+
return {
|
| 185 |
+
"input_ids": input_ids_tensor,
|
| 186 |
+
"attention_mask": attention_mask_tensor,
|
| 187 |
+
"labels": labels_tensor,
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
def __len__(self) -> int:
|
| 191 |
+
"""Return approximate number of batches."""
|
| 192 |
+
total_samples = len(self.queue)
|
| 193 |
+
if self.drop_last:
|
| 194 |
+
return total_samples // self.batch_size
|
| 195 |
+
else:
|
| 196 |
+
return (total_samples + self.batch_size - 1) // self.batch_size
|
| 197 |
+
|
| 198 |
+
def shutdown(self):
|
| 199 |
+
"""Shutdown the async iterator and background thread."""
|
| 200 |
+
self.queue.shutdown(wait=True)
|
| 201 |
+
|
| 202 |
+
def __del__(self):
|
| 203 |
+
"""Cleanup on deletion."""
|
| 204 |
+
self.shutdown()
|
code/TaoTrain/src/taoTrain/data/chunk_manager.py
ADDED
|
@@ -0,0 +1,452 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Chunk manager for streaming large JSONL datasets."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import json
|
| 5 |
+
import hashlib
|
| 6 |
+
from typing import Tuple, Optional, Dict, Any
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ChunkManager:
|
| 12 |
+
"""
|
| 13 |
+
Manages chunked reading of large JSONL files.
|
| 14 |
+
|
| 15 |
+
This class handles:
|
| 16 |
+
- File scanning to count total lines without loading all text
|
| 17 |
+
- Estimating chunk boundaries based on file size
|
| 18 |
+
- Tracking which line ranges belong to each chunk
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(self, jsonl_path: str, chunk_size_gb: float = 5.0,
|
| 22 |
+
samples_per_chunk: Optional[int] = None,
|
| 23 |
+
enable_metadata_cache: bool = True, chunk_cache_dir: str = ".cache/chunks",
|
| 24 |
+
max_samples: Optional[int] = None):
|
| 25 |
+
"""
|
| 26 |
+
Initialize ChunkManager.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
jsonl_path: Path to JSONL file
|
| 30 |
+
chunk_size_gb: Approximate chunk size in GB (ignored if samples_per_chunk is set)
|
| 31 |
+
samples_per_chunk: Number of samples per chunk (takes precedence over chunk_size_gb)
|
| 32 |
+
enable_metadata_cache: Enable caching of file scan metadata
|
| 33 |
+
chunk_cache_dir: Directory to store cache files
|
| 34 |
+
max_samples: Limit total samples to at most this many (if total_lines > max_samples)
|
| 35 |
+
|
| 36 |
+
Raises:
|
| 37 |
+
FileNotFoundError: If JSONL file doesn't exist
|
| 38 |
+
ValueError: If file is empty
|
| 39 |
+
"""
|
| 40 |
+
self.jsonl_path = Path(jsonl_path)
|
| 41 |
+
self.chunk_size_bytes = int(chunk_size_gb * 1024 ** 3) # Convert GB to bytes
|
| 42 |
+
self.max_samples = max_samples # Limit total samples if specified
|
| 43 |
+
print (f"Initializing ChunkManager for {self.jsonl_path} with target chunk size {chunk_size_gb} GB")
|
| 44 |
+
if samples_per_chunk is not None:
|
| 45 |
+
print(f" Overriding chunk size with {samples_per_chunk} samples per chunk")
|
| 46 |
+
if max_samples is not None:
|
| 47 |
+
print(f" Limiting dataset to {max_samples} samples")
|
| 48 |
+
self.samples_per_chunk = samples_per_chunk # If set, overrides GB-based chunking
|
| 49 |
+
self.enable_metadata_cache = enable_metadata_cache
|
| 50 |
+
self.chunk_cache_dir = Path(chunk_cache_dir)
|
| 51 |
+
|
| 52 |
+
if not self.jsonl_path.exists():
|
| 53 |
+
raise FileNotFoundError(f"JSONL file not found: {self.jsonl_path}")
|
| 54 |
+
|
| 55 |
+
self.file_size_bytes = os.path.getsize(self.jsonl_path)
|
| 56 |
+
self.file_mtime = os.path.getmtime(self.jsonl_path)
|
| 57 |
+
|
| 58 |
+
if self.file_size_bytes == 0:
|
| 59 |
+
raise ValueError("JSONL file is empty")
|
| 60 |
+
|
| 61 |
+
# Will be populated by _scan_file()
|
| 62 |
+
self.total_lines = 0
|
| 63 |
+
self.effective_lines = 0
|
| 64 |
+
self.line_sizes = [] # bytes per line
|
| 65 |
+
self.valid_line_offsets = [] # byte offset of each VALID JSON line (for seeking)
|
| 66 |
+
self.chunk_line_ranges = [] # [(start_line, end_line), ...]
|
| 67 |
+
|
| 68 |
+
# Try to load from cache first
|
| 69 |
+
cache_loaded = False
|
| 70 |
+
if self.enable_metadata_cache:
|
| 71 |
+
cache_loaded = self._load_metadata_cache()
|
| 72 |
+
|
| 73 |
+
# If cache not used, scan the file
|
| 74 |
+
if not cache_loaded:
|
| 75 |
+
self._scan_file()
|
| 76 |
+
self._compute_chunk_ranges()
|
| 77 |
+
|
| 78 |
+
# Save metadata cache for future runs
|
| 79 |
+
if self.enable_metadata_cache:
|
| 80 |
+
self._save_metadata_cache()
|
| 81 |
+
else:
|
| 82 |
+
# Cache stores file scan metadata. Recompute chunk ranges for the
|
| 83 |
+
# current training config so samples_per_chunk/max_samples changes
|
| 84 |
+
# are honored without rescanning the large JSONL file.
|
| 85 |
+
self._compute_chunk_ranges()
|
| 86 |
+
|
| 87 |
+
def _get_cache_path(self) -> Path:
|
| 88 |
+
"""Get the metadata cache file path for this JSONL file."""
|
| 89 |
+
# Create a hash of the file path to use as cache filename
|
| 90 |
+
file_hash = hashlib.md5(str(self.jsonl_path.absolute()).encode()).hexdigest()[:8]
|
| 91 |
+
cache_file = self.chunk_cache_dir / f"{file_hash}.metadata.json"
|
| 92 |
+
return cache_file
|
| 93 |
+
|
| 94 |
+
def _load_metadata_cache(self) -> bool:
|
| 95 |
+
"""
|
| 96 |
+
Load metadata from cache if it exists and is valid.
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
True if cache was loaded successfully, False otherwise
|
| 100 |
+
"""
|
| 101 |
+
cache_file = self._get_cache_path()
|
| 102 |
+
|
| 103 |
+
if not cache_file.exists():
|
| 104 |
+
return False
|
| 105 |
+
|
| 106 |
+
try:
|
| 107 |
+
with open(cache_file, 'r', encoding='utf-8') as f:
|
| 108 |
+
cache_data = json.load(f)
|
| 109 |
+
|
| 110 |
+
# Validate cache: check file hasn't changed
|
| 111 |
+
if (cache_data.get('file_size') != self.file_size_bytes or
|
| 112 |
+
cache_data.get('file_mtime') != self.file_mtime or
|
| 113 |
+
cache_data.get('jsonl_path') != str(self.jsonl_path.absolute())):
|
| 114 |
+
return False
|
| 115 |
+
|
| 116 |
+
# Load cached data
|
| 117 |
+
self.total_lines = cache_data.get('total_lines', 0)
|
| 118 |
+
self.line_sizes = cache_data.get('line_sizes', [])
|
| 119 |
+
self.valid_line_offsets = cache_data.get('valid_line_offsets', [])
|
| 120 |
+
# Convert loaded lists back to tuples for chunk_line_ranges
|
| 121 |
+
chunk_ranges = cache_data.get('chunk_line_ranges', [])
|
| 122 |
+
self.chunk_line_ranges = [tuple(r) for r in chunk_ranges]
|
| 123 |
+
self.chunk_size_bytes = cache_data.get('chunk_size_bytes', self.chunk_size_bytes)
|
| 124 |
+
|
| 125 |
+
print(f"✓ Loaded scan metadata from cache: {cache_file.name}")
|
| 126 |
+
print(f" Found {self.total_lines:,} valid JSON lines in {len(self.chunk_line_ranges)} chunks")
|
| 127 |
+
return True
|
| 128 |
+
|
| 129 |
+
except Exception as e:
|
| 130 |
+
# If cache loading fails, fall back to scanning
|
| 131 |
+
return False
|
| 132 |
+
|
| 133 |
+
def _save_metadata_cache(self) -> None:
|
| 134 |
+
"""Save metadata cache to file."""
|
| 135 |
+
cache_file = self._get_cache_path()
|
| 136 |
+
cache_file.parent.mkdir(parents=True, exist_ok=True)
|
| 137 |
+
|
| 138 |
+
cache_data = {
|
| 139 |
+
'jsonl_path': str(self.jsonl_path.absolute()),
|
| 140 |
+
'file_size': self.file_size_bytes,
|
| 141 |
+
'file_mtime': self.file_mtime,
|
| 142 |
+
'total_lines': self.total_lines,
|
| 143 |
+
'line_sizes': self.line_sizes,
|
| 144 |
+
'valid_line_offsets': self.valid_line_offsets,
|
| 145 |
+
'chunk_line_ranges': self.chunk_line_ranges,
|
| 146 |
+
'chunk_size_bytes': self.chunk_size_bytes,
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
try:
|
| 150 |
+
# Write atomically using a temp file + rename
|
| 151 |
+
temp_file = cache_file.with_suffix('.tmp')
|
| 152 |
+
with open(temp_file, 'w', encoding='utf-8') as f:
|
| 153 |
+
json.dump(cache_data, f, indent=2)
|
| 154 |
+
temp_file.replace(cache_file)
|
| 155 |
+
print(f" Saved scan metadata to cache: {cache_file.name}")
|
| 156 |
+
except Exception as e:
|
| 157 |
+
print(f" ⚠ Warning: failed to save cache: {e}")
|
| 158 |
+
|
| 159 |
+
def _get_chunk_cache_dir(self) -> Path:
|
| 160 |
+
"""Get the directory for storing cached chunk data for this JSONL file."""
|
| 161 |
+
file_hash = hashlib.md5(str(self.jsonl_path.absolute()).encode()).hexdigest()[:8]
|
| 162 |
+
chunk_dir = self.chunk_cache_dir / "chunks" / file_hash
|
| 163 |
+
return chunk_dir
|
| 164 |
+
|
| 165 |
+
def _get_chunk_cache_file(self, chunk_num: int) -> Path:
|
| 166 |
+
"""Get the cache file path for a specific chunk."""
|
| 167 |
+
chunk_dir = self._get_chunk_cache_dir()
|
| 168 |
+
return chunk_dir / f"chunk_{chunk_num:06d}.jsonl"
|
| 169 |
+
|
| 170 |
+
def _get_chunk_index_file(self) -> Path:
|
| 171 |
+
"""Get the index file that lists all cached chunks."""
|
| 172 |
+
chunk_dir = self._get_chunk_cache_dir()
|
| 173 |
+
return chunk_dir / "index.json"
|
| 174 |
+
|
| 175 |
+
def extract_and_cache_chunks(self) -> Dict[str, Any]:
|
| 176 |
+
"""
|
| 177 |
+
Extract chunks from the original JSONL file and save them as separate cached files.
|
| 178 |
+
|
| 179 |
+
This is optional and should be called manually if you want to pre-cache chunks
|
| 180 |
+
for faster repeated access. It can significantly speed up training but uses more disk space.
|
| 181 |
+
|
| 182 |
+
Returns:
|
| 183 |
+
Dictionary with cache information:
|
| 184 |
+
- 'cache_dir': path to cache directory
|
| 185 |
+
- 'num_chunks': number of chunks cached
|
| 186 |
+
- 'total_size_gb': total size of cached chunks
|
| 187 |
+
"""
|
| 188 |
+
chunk_dir = self._get_chunk_cache_dir()
|
| 189 |
+
chunk_dir.mkdir(parents=True, exist_ok=True)
|
| 190 |
+
|
| 191 |
+
print(f"💾 Extracting {len(self.chunk_line_ranges)} chunks to cache...")
|
| 192 |
+
total_size = 0
|
| 193 |
+
|
| 194 |
+
for chunk_num in range(len(self.chunk_line_ranges)):
|
| 195 |
+
cache_file = self._get_chunk_cache_file(chunk_num)
|
| 196 |
+
|
| 197 |
+
# Skip if already cached
|
| 198 |
+
if cache_file.exists():
|
| 199 |
+
total_size += os.path.getsize(cache_file)
|
| 200 |
+
continue
|
| 201 |
+
|
| 202 |
+
# Read chunk and save to cache file
|
| 203 |
+
chunk_examples = self.read_chunk(chunk_num, _from_cache=False)
|
| 204 |
+
|
| 205 |
+
with open(cache_file, 'w', encoding='utf-8') as f:
|
| 206 |
+
for obj in chunk_examples:
|
| 207 |
+
f.write(json.dumps(obj) + '\n')
|
| 208 |
+
|
| 209 |
+
total_size += os.path.getsize(cache_file)
|
| 210 |
+
if (chunk_num + 1) % max(1, len(self.chunk_line_ranges) // 10) == 0:
|
| 211 |
+
print(f" - Cached {chunk_num + 1}/{len(self.chunk_line_ranges)} chunks...")
|
| 212 |
+
|
| 213 |
+
# Write index file
|
| 214 |
+
index_data = {
|
| 215 |
+
'jsonl_path': str(self.jsonl_path.absolute()),
|
| 216 |
+
'num_chunks': len(self.chunk_line_ranges),
|
| 217 |
+
'chunk_ranges': self.chunk_line_ranges,
|
| 218 |
+
}
|
| 219 |
+
with open(self._get_chunk_index_file(), 'w', encoding='utf-8') as f:
|
| 220 |
+
json.dump(index_data, f, indent=2)
|
| 221 |
+
|
| 222 |
+
print(f"✓ Cached {len(self.chunk_line_ranges)} chunks ({total_size / (1024**3):.2f} GB)")
|
| 223 |
+
|
| 224 |
+
return {
|
| 225 |
+
'cache_dir': str(chunk_dir),
|
| 226 |
+
'num_chunks': len(self.chunk_line_ranges),
|
| 227 |
+
'total_size_gb': total_size / (1024**3),
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
def clear_chunk_cache(self, keep_metadata: bool = False) -> None:
|
| 231 |
+
"""
|
| 232 |
+
Clear cached chunk data.
|
| 233 |
+
|
| 234 |
+
Args:
|
| 235 |
+
keep_metadata: If True, only remove chunk files, keep the metadata cache
|
| 236 |
+
"""
|
| 237 |
+
chunk_dir = self._get_chunk_cache_dir()
|
| 238 |
+
|
| 239 |
+
if chunk_dir.exists():
|
| 240 |
+
import shutil
|
| 241 |
+
shutil.rmtree(chunk_dir)
|
| 242 |
+
print(f"✓ Cleared chunk cache: {chunk_dir}")
|
| 243 |
+
|
| 244 |
+
if not keep_metadata:
|
| 245 |
+
cache_file = self._get_cache_path()
|
| 246 |
+
if cache_file.exists():
|
| 247 |
+
cache_file.unlink()
|
| 248 |
+
print(f"✓ Cleared metadata cache: {cache_file}")
|
| 249 |
+
|
| 250 |
+
def _scan_file(self) -> None:
|
| 251 |
+
"""
|
| 252 |
+
Scan JSONL file to count lines and track offsets.
|
| 253 |
+
|
| 254 |
+
This reads the file once to:
|
| 255 |
+
- Count total valid JSON lines
|
| 256 |
+
- Record byte offset of each VALID line for seeking
|
| 257 |
+
- Estimate size per line
|
| 258 |
+
"""
|
| 259 |
+
print(f"📖 Scanning JSONL file: {self.jsonl_path}")
|
| 260 |
+
print(f" File size: {self.file_size_bytes / (1024**3):.2f} GB")
|
| 261 |
+
|
| 262 |
+
self.valid_line_offsets = []
|
| 263 |
+
current_offset = 0
|
| 264 |
+
valid_lines = 0
|
| 265 |
+
|
| 266 |
+
try:
|
| 267 |
+
with open(self.jsonl_path, 'r', encoding='utf-8') as f:
|
| 268 |
+
for line in tqdm(f, desc="Scanning JSONL", unit=" lines"):
|
| 269 |
+
# Skip empty lines - don't count toward line numbers
|
| 270 |
+
if not line.strip():
|
| 271 |
+
current_offset += len(line.encode('utf-8'))
|
| 272 |
+
continue
|
| 273 |
+
|
| 274 |
+
try:
|
| 275 |
+
json.loads(line)
|
| 276 |
+
# Valid JSON line - record its starting byte offset
|
| 277 |
+
self.valid_line_offsets.append(current_offset)
|
| 278 |
+
valid_lines += 1
|
| 279 |
+
|
| 280 |
+
line_bytes = len(line.encode('utf-8'))
|
| 281 |
+
self.line_sizes.append(line_bytes)
|
| 282 |
+
|
| 283 |
+
except json.JSONDecodeError:
|
| 284 |
+
# Skip invalid JSON lines - don't count toward line numbers
|
| 285 |
+
pass
|
| 286 |
+
|
| 287 |
+
current_offset += len(line.encode('utf-8'))
|
| 288 |
+
|
| 289 |
+
except Exception as e:
|
| 290 |
+
raise ValueError(f"Error scanning JSONL file: {e}")
|
| 291 |
+
|
| 292 |
+
self.total_lines = valid_lines
|
| 293 |
+
|
| 294 |
+
if self.total_lines == 0:
|
| 295 |
+
raise ValueError("No valid JSON lines found in JSONL file")
|
| 296 |
+
|
| 297 |
+
print(f"✓ Found {self.total_lines:,} valid JSON lines")
|
| 298 |
+
|
| 299 |
+
# Calculate average line size
|
| 300 |
+
avg_line_size = sum(self.line_sizes) / len(self.line_sizes) if self.line_sizes else 0
|
| 301 |
+
print(f" Average line size: {avg_line_size:.2f} bytes")
|
| 302 |
+
print(f" Chunk size target: {self.chunk_size_bytes / (1024**3):.2f} GB")
|
| 303 |
+
|
| 304 |
+
def _compute_chunk_ranges(self) -> None:
|
| 305 |
+
"""
|
| 306 |
+
Compute line ranges for each chunk based on target chunk size.
|
| 307 |
+
|
| 308 |
+
If samples_per_chunk is set, uses that. Otherwise, divides file
|
| 309 |
+
based on chunk_size_bytes. If max_samples is set, limits chunks to cover
|
| 310 |
+
at most max_samples lines.
|
| 311 |
+
"""
|
| 312 |
+
if self.total_lines == 0:
|
| 313 |
+
self.chunk_line_ranges = []
|
| 314 |
+
return
|
| 315 |
+
|
| 316 |
+
# Apply max_samples limit to effective line count
|
| 317 |
+
self.effective_lines = self.total_lines
|
| 318 |
+
if self.max_samples is not None:
|
| 319 |
+
self.effective_lines = min(self.total_lines, self.max_samples)
|
| 320 |
+
|
| 321 |
+
# Determine lines per chunk
|
| 322 |
+
if self.samples_per_chunk is not None:
|
| 323 |
+
# Use explicit sample count
|
| 324 |
+
lines_per_chunk = self.samples_per_chunk
|
| 325 |
+
else:
|
| 326 |
+
# Use GB-based calculation
|
| 327 |
+
avg_line_size = sum(self.line_sizes) / len(self.line_sizes) if self.line_sizes else 1
|
| 328 |
+
lines_per_chunk = max(1, int(self.chunk_size_bytes / avg_line_size))
|
| 329 |
+
|
| 330 |
+
chunk_ranges = []
|
| 331 |
+
start_line = 0
|
| 332 |
+
|
| 333 |
+
# Create chunks up to self.effective_lines (honors max_samples)
|
| 334 |
+
while start_line < self.effective_lines:
|
| 335 |
+
end_line = min(start_line + lines_per_chunk, self.effective_lines)
|
| 336 |
+
chunk_ranges.append((start_line, end_line))
|
| 337 |
+
start_line = end_line
|
| 338 |
+
|
| 339 |
+
self.chunk_line_ranges = chunk_ranges
|
| 340 |
+
self.num_chunks = len(chunk_ranges)
|
| 341 |
+
|
| 342 |
+
print(f" Divided into {self.num_chunks} chunks (covering {self.effective_lines:,} lines)")
|
| 343 |
+
|
| 344 |
+
def get_chunk_indices(self, chunk_num: int) -> Tuple[int, int]:
|
| 345 |
+
"""
|
| 346 |
+
Get (start_line, end_line) for a given chunk number.
|
| 347 |
+
|
| 348 |
+
Args:
|
| 349 |
+
chunk_num: Chunk number (0-indexed)
|
| 350 |
+
|
| 351 |
+
Returns:
|
| 352 |
+
Tuple of (start_line, end_line) where end_line is exclusive
|
| 353 |
+
|
| 354 |
+
Raises:
|
| 355 |
+
IndexError: If chunk_num is out of range
|
| 356 |
+
"""
|
| 357 |
+
if chunk_num < 0 or chunk_num >= len(self.chunk_line_ranges):
|
| 358 |
+
raise IndexError(f"Chunk {chunk_num} out of range [0, {len(self.chunk_line_ranges)-1}]")
|
| 359 |
+
|
| 360 |
+
return self.chunk_line_ranges[chunk_num]
|
| 361 |
+
|
| 362 |
+
def read_chunk(self, chunk_num: int, _from_cache: bool = True) -> list[dict]:
|
| 363 |
+
"""
|
| 364 |
+
Read a specific chunk and return parsed JSON objects.
|
| 365 |
+
|
| 366 |
+
If chunk cache is available, reads from cache. Otherwise reads from original JSONL
|
| 367 |
+
using file.seek() for O(1) lookup instead of O(n) scanning.
|
| 368 |
+
|
| 369 |
+
Args:
|
| 370 |
+
chunk_num: Chunk number (0-indexed)
|
| 371 |
+
_from_cache: Internal parameter to force reading from original (used during cache extraction)
|
| 372 |
+
|
| 373 |
+
Returns:
|
| 374 |
+
List of parsed JSON objects from that chunk
|
| 375 |
+
|
| 376 |
+
Raises:
|
| 377 |
+
IndexError: If chunk_num is out of range
|
| 378 |
+
ValueError: If JSON parsing fails
|
| 379 |
+
"""
|
| 380 |
+
# Try to read from cache first (if it exists)
|
| 381 |
+
if _from_cache:
|
| 382 |
+
cache_file = self._get_chunk_cache_file(chunk_num)
|
| 383 |
+
if cache_file.exists():
|
| 384 |
+
examples = []
|
| 385 |
+
try:
|
| 386 |
+
with open(cache_file, 'r', encoding='utf-8') as f:
|
| 387 |
+
for line in f:
|
| 388 |
+
if line.strip():
|
| 389 |
+
try:
|
| 390 |
+
obj = json.loads(line)
|
| 391 |
+
examples.append(obj)
|
| 392 |
+
except json.JSONDecodeError:
|
| 393 |
+
pass
|
| 394 |
+
return examples
|
| 395 |
+
except Exception as e:
|
| 396 |
+
print(f" ⚠ Warning: failed to read chunk from cache, falling back to original: {e}")
|
| 397 |
+
|
| 398 |
+
# Read from original JSONL file using seek optimization
|
| 399 |
+
start_line, end_line = self.get_chunk_indices(chunk_num)
|
| 400 |
+
|
| 401 |
+
examples = []
|
| 402 |
+
|
| 403 |
+
with open(self.jsonl_path, 'r', encoding='utf-8') as f:
|
| 404 |
+
# Seek to the byte offset of the start line
|
| 405 |
+
# This is O(1) instead of O(start_line) iteration
|
| 406 |
+
if start_line < len(self.valid_line_offsets):
|
| 407 |
+
f.seek(self.valid_line_offsets[start_line])
|
| 408 |
+
else:
|
| 409 |
+
# Fallback if valid_line_offsets not available (shouldn't happen)
|
| 410 |
+
f.seek(0)
|
| 411 |
+
|
| 412 |
+
current_line = start_line
|
| 413 |
+
|
| 414 |
+
# Read lines from start_line to end_line
|
| 415 |
+
for line in f:
|
| 416 |
+
# Skip empty lines
|
| 417 |
+
if not line.strip():
|
| 418 |
+
continue
|
| 419 |
+
|
| 420 |
+
# Stop when we've read enough lines
|
| 421 |
+
if current_line >= end_line:
|
| 422 |
+
break
|
| 423 |
+
|
| 424 |
+
try:
|
| 425 |
+
obj = json.loads(line)
|
| 426 |
+
examples.append(obj)
|
| 427 |
+
current_line += 1
|
| 428 |
+
except json.JSONDecodeError:
|
| 429 |
+
# Skip invalid JSON lines, but don't increment line counter
|
| 430 |
+
# This maintains alignment with line numbering from scan
|
| 431 |
+
pass
|
| 432 |
+
|
| 433 |
+
return examples
|
| 434 |
+
|
| 435 |
+
@property
|
| 436 |
+
def num_chunks(self) -> int:
|
| 437 |
+
"""Return number of chunks."""
|
| 438 |
+
return len(self.chunk_line_ranges)
|
| 439 |
+
|
| 440 |
+
@num_chunks.setter
|
| 441 |
+
def num_chunks(self, value: int) -> None:
|
| 442 |
+
"""Set number of chunks (internal use)."""
|
| 443 |
+
self._num_chunks = value
|
| 444 |
+
|
| 445 |
+
def __repr__(self) -> str:
|
| 446 |
+
"""String representation."""
|
| 447 |
+
return (
|
| 448 |
+
f"ChunkManager(file={self.jsonl_path.name}, "
|
| 449 |
+
f"size={self.file_size_bytes/(1024**3):.2f}GB, "
|
| 450 |
+
f"lines={self.effective_lines:,}, "
|
| 451 |
+
f"chunks={self.num_chunks})"
|
| 452 |
+
)
|
code/TaoTrain/src/taoTrain/data/factory.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Factory for creating datasets based on configuration."""
|
| 2 |
+
|
| 3 |
+
from taoTrain.config import TrainingConfig, TrainingModeEnum
|
| 4 |
+
from taoTrain.data.pretrain_jsonl import PretrainJSONLDataset
|
| 5 |
+
from taoTrain.data.sft_jsonl import SFTJSONLDataset
|
| 6 |
+
from taoTrain.data.rl_jsonl import RLJSONLDataset
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
from taoTrain.data.hf_pretrain import PretrainDataset
|
| 10 |
+
from taoTrain.data.hf_sft import SFTDataset
|
| 11 |
+
from taoTrain.data.hf_rl import RLDataset
|
| 12 |
+
except ImportError:
|
| 13 |
+
PretrainDataset = None
|
| 14 |
+
SFTDataset = None
|
| 15 |
+
RLDataset = None
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class DatasetFactory:
|
| 19 |
+
"""Factory for creating datasets based on configuration."""
|
| 20 |
+
|
| 21 |
+
# Registry of dataset classes by mode and backend
|
| 22 |
+
DATASETS = {
|
| 23 |
+
(TrainingModeEnum.PRETRAIN, "jsonl"): PretrainJSONLDataset,
|
| 24 |
+
(TrainingModeEnum.SFT, "jsonl"): SFTJSONLDataset,
|
| 25 |
+
(TrainingModeEnum.RL, "jsonl"): RLJSONLDataset,
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
if PretrainDataset is not None:
|
| 29 |
+
DATASETS.update({
|
| 30 |
+
(TrainingModeEnum.PRETRAIN, "huggingface"): PretrainDataset,
|
| 31 |
+
(TrainingModeEnum.SFT, "huggingface"): SFTDataset,
|
| 32 |
+
(TrainingModeEnum.RL, "huggingface"): RLDataset,
|
| 33 |
+
})
|
| 34 |
+
|
| 35 |
+
@staticmethod
|
| 36 |
+
def create_dataset(
|
| 37 |
+
config: TrainingConfig,
|
| 38 |
+
split: str = "train",
|
| 39 |
+
):
|
| 40 |
+
"""
|
| 41 |
+
Create dataset instance based on configuration.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
config: Training configuration
|
| 45 |
+
split: Dataset split (train, validation, test) - primarily for HuggingFace datasets
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
Dataset instance matching the configured mode and backend
|
| 49 |
+
|
| 50 |
+
Raises:
|
| 51 |
+
ValueError: If configuration is invalid or unsupported mode/backend combination
|
| 52 |
+
"""
|
| 53 |
+
# Determine backend: JSONL or HuggingFace
|
| 54 |
+
if config.dataset.local:
|
| 55 |
+
backend = "jsonl"
|
| 56 |
+
else:
|
| 57 |
+
backend = "huggingface"
|
| 58 |
+
|
| 59 |
+
# Get mode
|
| 60 |
+
mode = config.mode
|
| 61 |
+
|
| 62 |
+
# Look up dataset class
|
| 63 |
+
key = (mode, backend)
|
| 64 |
+
if key not in DatasetFactory.DATASETS:
|
| 65 |
+
if backend == "huggingface":
|
| 66 |
+
raise ImportError(
|
| 67 |
+
"HuggingFace dataset support requires the optional 'datasets' dependency. "
|
| 68 |
+
"Install project dependencies before using dataset.local=false."
|
| 69 |
+
)
|
| 70 |
+
raise ValueError(
|
| 71 |
+
f"Unsupported dataset configuration: mode={mode.value}, backend={backend}. "
|
| 72 |
+
f"Supported: {list(DatasetFactory.DATASETS.keys())}"
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
dataset_class = DatasetFactory.DATASETS[key]
|
| 76 |
+
|
| 77 |
+
# Instantiate dataset
|
| 78 |
+
if backend == "jsonl":
|
| 79 |
+
# JSONL datasets don't use split parameter
|
| 80 |
+
return dataset_class(config)
|
| 81 |
+
else:
|
| 82 |
+
# HuggingFace datasets use split parameter
|
| 83 |
+
return dataset_class(config, split=split)
|
| 84 |
+
|
| 85 |
+
@staticmethod
|
| 86 |
+
def register_dataset(mode: TrainingModeEnum, backend: str, dataset_class):
|
| 87 |
+
"""
|
| 88 |
+
Register a custom dataset class.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
mode: Training mode (e.g., TrainingModeEnum.PRETRAIN)
|
| 92 |
+
backend: Backend name (e.g., "jsonl", "huggingface")
|
| 93 |
+
dataset_class: Dataset class to register
|
| 94 |
+
"""
|
| 95 |
+
DatasetFactory.DATASETS[(mode, backend)] = dataset_class
|
| 96 |
+
|
| 97 |
+
@staticmethod
|
| 98 |
+
def list_available_datasets():
|
| 99 |
+
"""List all available dataset configurations."""
|
| 100 |
+
configs = {}
|
| 101 |
+
for (mode, backend), dataset_class in DatasetFactory.DATASETS.items():
|
| 102 |
+
key = f"{mode.value}_{backend}"
|
| 103 |
+
configs[key] = {
|
| 104 |
+
"mode": mode.value,
|
| 105 |
+
"backend": backend,
|
| 106 |
+
"class": dataset_class.__name__,
|
| 107 |
+
}
|
| 108 |
+
return configs
|
code/TaoTrain/src/taoTrain/data/hf_base.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Base class for HuggingFace-based datasets."""
|
| 2 |
+
|
| 3 |
+
from typing import Optional, Dict
|
| 4 |
+
import torch
|
| 5 |
+
from torch.utils.data import Dataset
|
| 6 |
+
from datasets import load_dataset
|
| 7 |
+
from transformers import AutoTokenizer
|
| 8 |
+
from taoTrain.config import TrainingConfig
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class BaseHFDataset(Dataset):
|
| 12 |
+
"""Base class for HuggingFace-based datasets."""
|
| 13 |
+
|
| 14 |
+
def __init__(self, config: TrainingConfig, split: str = "train"):
|
| 15 |
+
"""
|
| 16 |
+
Initialize dataset.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
config: Training configuration
|
| 20 |
+
split: Dataset split (train, validation, test)
|
| 21 |
+
"""
|
| 22 |
+
self.config = config
|
| 23 |
+
self.split = split
|
| 24 |
+
self.data = None
|
| 25 |
+
self.tokenizer = None
|
| 26 |
+
|
| 27 |
+
# Load tokenizer
|
| 28 |
+
self._load_tokenizer()
|
| 29 |
+
|
| 30 |
+
# Load and preprocess dataset
|
| 31 |
+
self._load_dataset()
|
| 32 |
+
self._preprocess()
|
| 33 |
+
|
| 34 |
+
def _load_tokenizer(self):
|
| 35 |
+
"""Load tokenizer from HuggingFace."""
|
| 36 |
+
# Default to GPT-2 tokenizer if not specified
|
| 37 |
+
tokenizer_name = getattr(self.config, 'tokenizer_name', 'gpt2')
|
| 38 |
+
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
| 39 |
+
|
| 40 |
+
# Set pad token if not set
|
| 41 |
+
if self.tokenizer.pad_token is None:
|
| 42 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 43 |
+
|
| 44 |
+
def _load_dataset(self):
|
| 45 |
+
"""Load dataset from HuggingFace."""
|
| 46 |
+
dataset_config = self.config.dataset
|
| 47 |
+
|
| 48 |
+
try:
|
| 49 |
+
# Load dataset
|
| 50 |
+
if dataset_config.config:
|
| 51 |
+
self.data = load_dataset(
|
| 52 |
+
dataset_config.dataset_name,
|
| 53 |
+
dataset_config.config,
|
| 54 |
+
split=self.split,
|
| 55 |
+
cache_dir=dataset_config.cache_dir,
|
| 56 |
+
trust_remote_code=True,
|
| 57 |
+
)
|
| 58 |
+
else:
|
| 59 |
+
self.data = load_dataset(
|
| 60 |
+
dataset_config.dataset_name,
|
| 61 |
+
split=self.split,
|
| 62 |
+
cache_dir=dataset_config.cache_dir,
|
| 63 |
+
trust_remote_code=True,
|
| 64 |
+
)
|
| 65 |
+
except Exception as e:
|
| 66 |
+
raise ValueError(f"Failed to load dataset {dataset_config.dataset_name}: {e}")
|
| 67 |
+
|
| 68 |
+
# Limit samples if specified
|
| 69 |
+
if dataset_config.max_samples:
|
| 70 |
+
self.data = self.data.select(range(min(dataset_config.max_samples, len(self.data))))
|
| 71 |
+
|
| 72 |
+
def _preprocess(self):
|
| 73 |
+
"""Preprocess dataset (to be implemented by subclasses)."""
|
| 74 |
+
pass
|
| 75 |
+
|
| 76 |
+
def __len__(self) -> int:
|
| 77 |
+
"""Return dataset length."""
|
| 78 |
+
return len(self.data)
|
| 79 |
+
|
| 80 |
+
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
| 81 |
+
"""Get item (to be implemented by subclasses)."""
|
| 82 |
+
pass
|
code/TaoTrain/src/taoTrain/data/hf_pretrain.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pretrain dataset for HuggingFace datasets."""
|
| 2 |
+
|
| 3 |
+
from typing import Dict
|
| 4 |
+
import torch
|
| 5 |
+
from taoTrain.config import TrainingConfig
|
| 6 |
+
from taoTrain.data.hf_base import BaseHFDataset
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class PretrainDataset(BaseHFDataset):
|
| 10 |
+
"""Dataset for pretraining with raw text."""
|
| 11 |
+
|
| 12 |
+
def _preprocess(self):
|
| 13 |
+
"""Tokenize text data."""
|
| 14 |
+
dataset_config = self.config.dataset
|
| 15 |
+
text_column = dataset_config.text_column
|
| 16 |
+
|
| 17 |
+
def tokenize_function(examples):
|
| 18 |
+
# Concatenate all texts
|
| 19 |
+
concatenated_examples = {
|
| 20 |
+
k: sum(examples[k], []) for k in examples.keys()
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
total_length = len(concatenated_examples[text_column])
|
| 24 |
+
# We'll use max_seq_length for training
|
| 25 |
+
total_length = (total_length // self.config.model.max_seq_length) * self.config.model.max_seq_length
|
| 26 |
+
|
| 27 |
+
# Tokenize
|
| 28 |
+
tokenized = self.tokenizer(
|
| 29 |
+
concatenated_examples[text_column],
|
| 30 |
+
truncation=False, # We'll chunk below
|
| 31 |
+
return_special_tokens_mask=False,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
# Chunk tokenized text
|
| 35 |
+
result = {
|
| 36 |
+
"input_ids": [],
|
| 37 |
+
"attention_mask": [],
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
for i in range(0, total_length, self.config.model.max_seq_length):
|
| 41 |
+
result["input_ids"].append(
|
| 42 |
+
tokenized["input_ids"][i:i + self.config.model.max_seq_length]
|
| 43 |
+
)
|
| 44 |
+
result["attention_mask"].append(
|
| 45 |
+
tokenized["attention_mask"][i:i + self.config.model.max_seq_length]
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
return result
|
| 49 |
+
|
| 50 |
+
# Preprocess in batches
|
| 51 |
+
self.data = self.data.map(
|
| 52 |
+
tokenize_function,
|
| 53 |
+
batched=True,
|
| 54 |
+
batch_size=100,
|
| 55 |
+
remove_columns=self.data.column_names,
|
| 56 |
+
desc="Tokenizing...",
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
| 60 |
+
"""Get preprocessed sample."""
|
| 61 |
+
item = self.data[idx]
|
| 62 |
+
|
| 63 |
+
input_ids = torch.tensor(item["input_ids"], dtype=torch.long)
|
| 64 |
+
attention_mask = torch.tensor(item["attention_mask"], dtype=torch.long)
|
| 65 |
+
|
| 66 |
+
# For pretrain, labels = input_ids shifted by 1 (next token prediction)
|
| 67 |
+
# Position i predicts token at position i+1
|
| 68 |
+
labels = input_ids[1:].clone()
|
| 69 |
+
labels = torch.cat([labels, torch.tensor([-100])], dim=0)
|
| 70 |
+
|
| 71 |
+
# Mark padding tokens as -100 to ignore in loss computation
|
| 72 |
+
labels[attention_mask == 0] = -100
|
| 73 |
+
|
| 74 |
+
return {
|
| 75 |
+
"input_ids": input_ids,
|
| 76 |
+
"attention_mask": attention_mask,
|
| 77 |
+
"labels": labels,
|
| 78 |
+
}
|
code/TaoTrain/src/taoTrain/data/hf_rl.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""RL dataset for HuggingFace datasets."""
|
| 2 |
+
|
| 3 |
+
from typing import Dict
|
| 4 |
+
import torch
|
| 5 |
+
from taoTrain.config import TrainingConfig
|
| 6 |
+
from taoTrain.data.hf_base import BaseHFDataset
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class RLDataset(BaseHFDataset):
|
| 10 |
+
"""Dataset for RL training with prompts."""
|
| 11 |
+
|
| 12 |
+
def _preprocess(self):
|
| 13 |
+
"""Prepare prompts for RL."""
|
| 14 |
+
dataset_config = self.config.dataset
|
| 15 |
+
|
| 16 |
+
# For RL, we typically just need prompts (no responses)
|
| 17 |
+
# The responses will be generated by the model during training
|
| 18 |
+
|
| 19 |
+
if dataset_config.prompt_column:
|
| 20 |
+
# Use existing prompt column
|
| 21 |
+
def extract_prompt(example):
|
| 22 |
+
return {"prompt": example[dataset_config.prompt_column]}
|
| 23 |
+
|
| 24 |
+
self.data = self.data.map(
|
| 25 |
+
extract_prompt,
|
| 26 |
+
remove_columns=self.data.column_names,
|
| 27 |
+
desc="Extracting prompts...",
|
| 28 |
+
)
|
| 29 |
+
else:
|
| 30 |
+
# For general datasets, just use the text column as prompt
|
| 31 |
+
def identity(example):
|
| 32 |
+
return {"prompt": example.get(dataset_config.text_column, "")}
|
| 33 |
+
|
| 34 |
+
self.data = self.data.map(
|
| 35 |
+
identity,
|
| 36 |
+
remove_columns=self.data.column_names,
|
| 37 |
+
desc="Preparing prompts...",
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
# Tokenize prompts
|
| 41 |
+
def tokenize_function(examples):
|
| 42 |
+
tokenized = self.tokenizer(
|
| 43 |
+
examples["prompt"],
|
| 44 |
+
truncation=True,
|
| 45 |
+
max_length=self.config.model.max_seq_length,
|
| 46 |
+
padding="max_length",
|
| 47 |
+
return_attention_mask=True,
|
| 48 |
+
)
|
| 49 |
+
return tokenized
|
| 50 |
+
|
| 51 |
+
self.data = self.data.map(
|
| 52 |
+
tokenize_function,
|
| 53 |
+
batched=True,
|
| 54 |
+
batch_size=100,
|
| 55 |
+
remove_columns=self.data.column_names,
|
| 56 |
+
desc="Tokenizing prompts...",
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
| 60 |
+
"""Get preprocessed prompt."""
|
| 61 |
+
item = self.data[idx]
|
| 62 |
+
|
| 63 |
+
input_ids = torch.tensor(item["input_ids"], dtype=torch.long)
|
| 64 |
+
attention_mask = torch.tensor(item["attention_mask"], dtype=torch.long)
|
| 65 |
+
|
| 66 |
+
# For RL, we don't have labels yet
|
| 67 |
+
# They're generated during training
|
| 68 |
+
|
| 69 |
+
return {
|
| 70 |
+
"input_ids": input_ids,
|
| 71 |
+
"attention_mask": attention_mask,
|
| 72 |
+
# "labels" will be None or set by the trainer
|
| 73 |
+
}
|
code/TaoTrain/src/taoTrain/data/hf_sft.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SFT dataset for HuggingFace datasets."""
|
| 2 |
+
|
| 3 |
+
from typing import Dict
|
| 4 |
+
import torch
|
| 5 |
+
from taoTrain.config import TrainingConfig
|
| 6 |
+
from taoTrain.data.hf_base import BaseHFDataset
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class SFTDataset(BaseHFDataset):
|
| 10 |
+
"""Dataset for supervised fine-tuning with instruction-response pairs."""
|
| 11 |
+
|
| 12 |
+
def _preprocess(self):
|
| 13 |
+
"""Process instruction-response pairs."""
|
| 14 |
+
dataset_config = self.config.dataset
|
| 15 |
+
|
| 16 |
+
def format_example(example):
|
| 17 |
+
"""Format instruction and response."""
|
| 18 |
+
instruction = example.get(dataset_config.instruction_column, "")
|
| 19 |
+
response = example.get(dataset_config.response_column, "")
|
| 20 |
+
|
| 21 |
+
if dataset_config.instruction_template:
|
| 22 |
+
# Use custom template
|
| 23 |
+
text = dataset_config.instruction_template.format(
|
| 24 |
+
instruction=instruction,
|
| 25 |
+
response=response
|
| 26 |
+
)
|
| 27 |
+
else:
|
| 28 |
+
# Default template
|
| 29 |
+
text = f"{instruction}\n{response}"
|
| 30 |
+
|
| 31 |
+
return {"text": text}
|
| 32 |
+
|
| 33 |
+
# Format examples
|
| 34 |
+
self.data = self.data.map(
|
| 35 |
+
format_example,
|
| 36 |
+
remove_columns=[
|
| 37 |
+
col for col in self.data.column_names
|
| 38 |
+
if col not in ["text"]
|
| 39 |
+
] if "text" not in self.data.column_names else [],
|
| 40 |
+
desc="Formatting examples...",
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
# Tokenize
|
| 44 |
+
def tokenize_function(examples):
|
| 45 |
+
tokenized = self.tokenizer(
|
| 46 |
+
examples["text"],
|
| 47 |
+
truncation=True,
|
| 48 |
+
max_length=self.config.model.max_seq_length,
|
| 49 |
+
padding="max_length",
|
| 50 |
+
return_attention_mask=True,
|
| 51 |
+
)
|
| 52 |
+
return tokenized
|
| 53 |
+
|
| 54 |
+
self.data = self.data.map(
|
| 55 |
+
tokenize_function,
|
| 56 |
+
batched=True,
|
| 57 |
+
batch_size=100,
|
| 58 |
+
remove_columns=self.data.column_names,
|
| 59 |
+
desc="Tokenizing...",
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
| 63 |
+
"""Get preprocessed sample."""
|
| 64 |
+
item = self.data[idx]
|
| 65 |
+
|
| 66 |
+
input_ids = torch.tensor(item["input_ids"], dtype=torch.long)
|
| 67 |
+
attention_mask = torch.tensor(item["attention_mask"], dtype=torch.long)
|
| 68 |
+
|
| 69 |
+
# For SFT, labels = input_ids shifted by 1 (next token prediction)
|
| 70 |
+
# Position i predicts token at position i+1
|
| 71 |
+
labels = input_ids[1:].clone()
|
| 72 |
+
labels = torch.cat([labels, torch.tensor([-100])], dim=0)
|
| 73 |
+
|
| 74 |
+
# Mark padding tokens as -100 to ignore in loss computation
|
| 75 |
+
labels[attention_mask == 0] = -100
|
| 76 |
+
|
| 77 |
+
return {
|
| 78 |
+
"input_ids": input_ids,
|
| 79 |
+
"attention_mask": attention_mask,
|
| 80 |
+
"labels": labels,
|
| 81 |
+
}
|
code/TaoTrain/src/taoTrain/data/jsonl_base.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Base class for local JSONL-based datasets (async-only)."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from typing import Optional, Dict, Any
|
| 5 |
+
import torch
|
| 6 |
+
from torch.utils.data import Dataset
|
| 7 |
+
from taoTrain.config import TrainingConfig
|
| 8 |
+
from taoTrain.data.chunk_manager import ChunkManager
|
| 9 |
+
from taoTrain.data.tokenizer import SentencePieceTokenizerWrapper
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class BaseJSONLDataset(Dataset):
|
| 13 |
+
"""
|
| 14 |
+
Base class for local JSONL-based datasets with async-only streaming.
|
| 15 |
+
|
| 16 |
+
Designed for use with AsyncBatchIterator and TokenizationQueue.
|
| 17 |
+
All data loading and preprocessing happens asynchronously in background threads.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
def __init__(self, config: TrainingConfig, split: str = "train"):
|
| 21 |
+
"""
|
| 22 |
+
Initialize JSONL dataset with chunked loading.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
config: Training configuration
|
| 26 |
+
split: Dataset split (train, validation, test) - not used for JSONL but kept for compatibility
|
| 27 |
+
|
| 28 |
+
Note:
|
| 29 |
+
Requires AsyncBatchIterator and TokenizationQueue for data loading.
|
| 30 |
+
See taoTrain/data/async_loader.py for usage.
|
| 31 |
+
"""
|
| 32 |
+
self.config = config
|
| 33 |
+
self.split = split
|
| 34 |
+
self.tokenizer = None
|
| 35 |
+
|
| 36 |
+
# Initialize chunk manager for streaming
|
| 37 |
+
dataset_config = self.config.dataset
|
| 38 |
+
jsonl_path = dataset_config.jsonl_path
|
| 39 |
+
|
| 40 |
+
if not jsonl_path:
|
| 41 |
+
raise ValueError("jsonl_path must be provided for local JSONL datasets")
|
| 42 |
+
|
| 43 |
+
# Create chunk manager
|
| 44 |
+
enable_streaming = dataset_config.enable_streaming
|
| 45 |
+
chunk_size_gb = dataset_config.chunk_size_gb
|
| 46 |
+
samples_per_chunk = dataset_config.samples_per_chunk
|
| 47 |
+
enable_metadata_cache = dataset_config.enable_chunk_metadata_cache
|
| 48 |
+
chunk_cache_dir = dataset_config.chunk_cache_dir
|
| 49 |
+
max_samples = dataset_config.max_samples
|
| 50 |
+
|
| 51 |
+
if enable_streaming:
|
| 52 |
+
self.chunk_manager = ChunkManager(
|
| 53 |
+
jsonl_path,
|
| 54 |
+
chunk_size_gb=chunk_size_gb,
|
| 55 |
+
samples_per_chunk=samples_per_chunk,
|
| 56 |
+
enable_metadata_cache=enable_metadata_cache,
|
| 57 |
+
chunk_cache_dir=chunk_cache_dir,
|
| 58 |
+
max_samples=max_samples
|
| 59 |
+
)
|
| 60 |
+
print(f"✓ {self.chunk_manager}")
|
| 61 |
+
else:
|
| 62 |
+
self.chunk_manager = None
|
| 63 |
+
|
| 64 |
+
# Current chunk data
|
| 65 |
+
self._current_chunk_num = None
|
| 66 |
+
self._current_chunk_data = None # {"text": [...]} or preprocessed data
|
| 67 |
+
self._text_field = dataset_config.text_field
|
| 68 |
+
|
| 69 |
+
# Load tokenizer
|
| 70 |
+
print("✓ Loading tokenizer...")
|
| 71 |
+
self._load_tokenizer()
|
| 72 |
+
|
| 73 |
+
print("✓ Dataset initialization complete (async mode - chunks loaded on-demand).")
|
| 74 |
+
|
| 75 |
+
def _load_tokenizer(self):
|
| 76 |
+
"""Load tokenizer (from local SentencePiece or HuggingFace)."""
|
| 77 |
+
dataset_config = self.config.dataset
|
| 78 |
+
|
| 79 |
+
# Check if tokenizer_path is specified
|
| 80 |
+
if dataset_config.tokenizer_path:
|
| 81 |
+
tokenizer_type = dataset_config.tokenizer_type
|
| 82 |
+
|
| 83 |
+
# Auto-detect tokenizer type based on file extension
|
| 84 |
+
if tokenizer_type is None:
|
| 85 |
+
if dataset_config.tokenizer_path.endswith('.model'):
|
| 86 |
+
tokenizer_type = 'sentencepiece'
|
| 87 |
+
else:
|
| 88 |
+
tokenizer_type = 'huggingface'
|
| 89 |
+
|
| 90 |
+
if tokenizer_type == 'sentencepiece':
|
| 91 |
+
# Load SentencePiece tokenizer
|
| 92 |
+
try:
|
| 93 |
+
import sentencepiece as spm
|
| 94 |
+
sp = spm.SentencePieceProcessor()
|
| 95 |
+
sp.Load(dataset_config.tokenizer_path)
|
| 96 |
+
# Wrap SentencePiece in a compatible interface
|
| 97 |
+
self.tokenizer = SentencePieceTokenizerWrapper(sp)
|
| 98 |
+
except ImportError:
|
| 99 |
+
raise ImportError("SentencePiece not installed. Install with: pip install sentencepiece")
|
| 100 |
+
except Exception as e:
|
| 101 |
+
raise ValueError(f"Failed to load SentencePiece tokenizer from {dataset_config.tokenizer_path}: {e}")
|
| 102 |
+
else:
|
| 103 |
+
# Load HuggingFace tokenizer from path
|
| 104 |
+
try:
|
| 105 |
+
from transformers import AutoTokenizer
|
| 106 |
+
self.tokenizer = AutoTokenizer.from_pretrained(dataset_config.tokenizer_path)
|
| 107 |
+
except ImportError as e:
|
| 108 |
+
raise ImportError("HuggingFace tokenizers require the optional 'transformers' dependency") from e
|
| 109 |
+
except Exception as e:
|
| 110 |
+
raise ValueError(f"Failed to load HuggingFace tokenizer from {dataset_config.tokenizer_path}: {e}")
|
| 111 |
+
else:
|
| 112 |
+
# Default to GPT-2 tokenizer
|
| 113 |
+
try:
|
| 114 |
+
from transformers import AutoTokenizer
|
| 115 |
+
except ImportError as e:
|
| 116 |
+
raise ImportError("Default GPT-2 tokenizer requires the optional 'transformers' dependency") from e
|
| 117 |
+
tokenizer_name = getattr(self.config, 'tokenizer_name', 'gpt2')
|
| 118 |
+
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
| 119 |
+
|
| 120 |
+
# Set pad token if not set (for HuggingFace tokenizers)
|
| 121 |
+
if hasattr(self.tokenizer, 'pad_token') and self.tokenizer.pad_token is None:
|
| 122 |
+
if hasattr(self.tokenizer, 'eos_token'):
|
| 123 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 124 |
+
|
| 125 |
+
def _load_chunk(self, chunk_num: int):
|
| 126 |
+
"""
|
| 127 |
+
Load a specific chunk from JSONL file.
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
chunk_num: Chunk number to load (0-indexed)
|
| 131 |
+
"""
|
| 132 |
+
if not self.chunk_manager:
|
| 133 |
+
return
|
| 134 |
+
|
| 135 |
+
if chunk_num == self._current_chunk_num and self._current_chunk_data is not None:
|
| 136 |
+
# Already loaded
|
| 137 |
+
return
|
| 138 |
+
|
| 139 |
+
# Read chunk
|
| 140 |
+
chunk_examples = self.chunk_manager.read_chunk(chunk_num)
|
| 141 |
+
|
| 142 |
+
# Convert to text data
|
| 143 |
+
texts = []
|
| 144 |
+
for obj in chunk_examples:
|
| 145 |
+
if self._text_field in obj:
|
| 146 |
+
texts.append(obj[self._text_field])
|
| 147 |
+
|
| 148 |
+
self._current_chunk_data = {"text": texts}
|
| 149 |
+
self._current_chunk_num = chunk_num
|
| 150 |
+
|
| 151 |
+
# Preprocess chunk (tokenization happens in background via AsyncBatchIterator)
|
| 152 |
+
self._preprocess_chunk()
|
| 153 |
+
|
| 154 |
+
def _get_chunk_for_idx(self, idx: int) -> int:
|
| 155 |
+
"""
|
| 156 |
+
Determine which chunk contains the given global index.
|
| 157 |
+
|
| 158 |
+
Args:
|
| 159 |
+
idx: Global index
|
| 160 |
+
|
| 161 |
+
Returns:
|
| 162 |
+
Chunk number (0-indexed)
|
| 163 |
+
"""
|
| 164 |
+
if not self.chunk_manager:
|
| 165 |
+
return 0
|
| 166 |
+
|
| 167 |
+
current_line = 0
|
| 168 |
+
for chunk_num, (start_line, end_line) in enumerate(self.chunk_manager.chunk_line_ranges):
|
| 169 |
+
if idx < (end_line - start_line):
|
| 170 |
+
return chunk_num
|
| 171 |
+
idx -= (end_line - start_line)
|
| 172 |
+
|
| 173 |
+
# Shouldn't reach here
|
| 174 |
+
return 0
|
| 175 |
+
|
| 176 |
+
def _get_local_idx_in_chunk(self, global_idx: int) -> int:
|
| 177 |
+
"""
|
| 178 |
+
Convert global index to local index within the chunk.
|
| 179 |
+
|
| 180 |
+
Args:
|
| 181 |
+
global_idx: Global index
|
| 182 |
+
|
| 183 |
+
Returns:
|
| 184 |
+
Local index within the chunk
|
| 185 |
+
"""
|
| 186 |
+
if not self.chunk_manager:
|
| 187 |
+
return global_idx
|
| 188 |
+
|
| 189 |
+
current_line = 0
|
| 190 |
+
for chunk_num, (start_line, end_line) in enumerate(self.chunk_manager.chunk_line_ranges):
|
| 191 |
+
chunk_size = end_line - start_line
|
| 192 |
+
if global_idx < chunk_size:
|
| 193 |
+
return global_idx
|
| 194 |
+
global_idx -= chunk_size
|
| 195 |
+
|
| 196 |
+
return 0
|
| 197 |
+
|
| 198 |
+
def _preprocess(self):
|
| 199 |
+
"""Preprocess dataset (to be implemented by subclasses)."""
|
| 200 |
+
pass
|
| 201 |
+
|
| 202 |
+
def _preprocess_chunk(self):
|
| 203 |
+
"""
|
| 204 |
+
Preprocess current chunk (to be implemented by subclasses).
|
| 205 |
+
|
| 206 |
+
This is called after a chunk is loaded by AsyncBatchIterator.
|
| 207 |
+
"""
|
| 208 |
+
pass
|
| 209 |
+
|
| 210 |
+
def __len__(self) -> int:
|
| 211 |
+
"""Return dataset length."""
|
| 212 |
+
if self.chunk_manager:
|
| 213 |
+
return self.chunk_manager.effective_lines
|
| 214 |
+
elif self._current_chunk_data and "text" in self._current_chunk_data:
|
| 215 |
+
return len(self._current_chunk_data.get("text", []))
|
| 216 |
+
return 0
|
| 217 |
+
|
| 218 |
+
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
| 219 |
+
"""Get item (to be implemented by subclasses)."""
|
| 220 |
+
pass
|
code/TaoTrain/src/taoTrain/data/loaders.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""DataLoader utilities."""
|
| 2 |
+
|
| 3 |
+
from typing import Optional
|
| 4 |
+
import torch
|
| 5 |
+
from torch.utils.data import DataLoader, Dataset
|
| 6 |
+
from taoTrain.config import TrainingConfig
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def get_dataloader(
|
| 10 |
+
dataset: Dataset,
|
| 11 |
+
config: TrainingConfig,
|
| 12 |
+
shuffle: bool = True,
|
| 13 |
+
drop_last: bool = True,
|
| 14 |
+
) -> DataLoader:
|
| 15 |
+
"""
|
| 16 |
+
Create a DataLoader from a dataset.
|
| 17 |
+
|
| 18 |
+
**NOTE**: For JSONL-based datasets (PretrainJSONLDataset, SFTJSONLDataset, etc.),
|
| 19 |
+
this function is now deprecated in favor of AsyncBatchIterator for better performance.
|
| 20 |
+
AsyncBatchIterator enables tokenization to happen in parallel with training,
|
| 21 |
+
avoiding the startup bottleneck of tokenizing all data upfront.
|
| 22 |
+
|
| 23 |
+
See: taoTrain/data/async_loader.py for the new async loading approach.
|
| 24 |
+
The trainer automatically uses AsyncBatchIterator for JSONL datasets.
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
dataset: PyTorch Dataset instance
|
| 28 |
+
config: Training configuration
|
| 29 |
+
shuffle: Whether to shuffle data
|
| 30 |
+
drop_last: Whether to drop last incomplete batch
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
DataLoader instance
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
def collate_fn(batch):
|
| 37 |
+
"""Collate function for padding sequences."""
|
| 38 |
+
# Batch is a list of dicts
|
| 39 |
+
collated = {}
|
| 40 |
+
keys = batch[0].keys()
|
| 41 |
+
|
| 42 |
+
for key in keys:
|
| 43 |
+
items = [item[key] for item in batch]
|
| 44 |
+
|
| 45 |
+
# Stack tensors
|
| 46 |
+
if isinstance(items[0], torch.Tensor):
|
| 47 |
+
if key in ["input_ids", "labels"]:
|
| 48 |
+
# Pad sequences
|
| 49 |
+
max_len = max(item.shape[0] for item in items)
|
| 50 |
+
padded = []
|
| 51 |
+
for item in items:
|
| 52 |
+
if len(item.shape) == 1:
|
| 53 |
+
# 1D tensor - pad it
|
| 54 |
+
pad_len = max_len - item.shape[0]
|
| 55 |
+
if pad_len > 0:
|
| 56 |
+
item = torch.nn.functional.pad(item, (0, pad_len), value=-100 if key == "labels" else 0)
|
| 57 |
+
padded.append(item)
|
| 58 |
+
collated[key] = torch.stack(padded)
|
| 59 |
+
elif key == "attention_mask":
|
| 60 |
+
# Also pad attention mask
|
| 61 |
+
max_len = max(item.shape[0] for item in items)
|
| 62 |
+
padded = []
|
| 63 |
+
for item in items:
|
| 64 |
+
if len(item.shape) == 1:
|
| 65 |
+
pad_len = max_len - item.shape[0]
|
| 66 |
+
if pad_len > 0:
|
| 67 |
+
item = torch.nn.functional.pad(item, (0, pad_len), value=0)
|
| 68 |
+
padded.append(item)
|
| 69 |
+
collated[key] = torch.stack(padded)
|
| 70 |
+
else:
|
| 71 |
+
collated[key] = torch.stack(items)
|
| 72 |
+
else:
|
| 73 |
+
collated[key] = items
|
| 74 |
+
|
| 75 |
+
return collated
|
| 76 |
+
|
| 77 |
+
return DataLoader(
|
| 78 |
+
dataset,
|
| 79 |
+
batch_size=config.batch_size,
|
| 80 |
+
shuffle=shuffle,
|
| 81 |
+
drop_last=drop_last,
|
| 82 |
+
num_workers=config.num_workers,
|
| 83 |
+
pin_memory=config.pin_memory,
|
| 84 |
+
collate_fn=collate_fn,
|
| 85 |
+
)
|
code/TaoTrain/src/taoTrain/data/pretrain_jsonl.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pretrain JSONL dataset with async-only streaming."""
|
| 2 |
+
|
| 3 |
+
from typing import Dict
|
| 4 |
+
import torch
|
| 5 |
+
from taoTrain.config import TrainingConfig
|
| 6 |
+
from taoTrain.data.jsonl_base import BaseJSONLDataset
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class PretrainJSONLDataset(BaseJSONLDataset):
|
| 10 |
+
"""Dataset for pretraining with local JSONL files with chunked loading."""
|
| 11 |
+
|
| 12 |
+
def _preprocess_chunk(self):
|
| 13 |
+
"""Tokenize current chunk of text data."""
|
| 14 |
+
if not self._current_chunk_data or "text" not in self._current_chunk_data:
|
| 15 |
+
return
|
| 16 |
+
|
| 17 |
+
max_seq_length = self.config.model.max_seq_length
|
| 18 |
+
texts = self._current_chunk_data["text"]
|
| 19 |
+
|
| 20 |
+
# Tokenize all texts in this chunk
|
| 21 |
+
all_token_ids = []
|
| 22 |
+
all_attention_masks = []
|
| 23 |
+
|
| 24 |
+
for text in texts:
|
| 25 |
+
tokenized = self.tokenizer(
|
| 26 |
+
text,
|
| 27 |
+
truncation=True,
|
| 28 |
+
max_length=max_seq_length,
|
| 29 |
+
padding="max_length",
|
| 30 |
+
return_attention_mask=True,
|
| 31 |
+
)
|
| 32 |
+
all_token_ids.append(tokenized["input_ids"])
|
| 33 |
+
all_attention_masks.append(tokenized["attention_mask"])
|
| 34 |
+
|
| 35 |
+
self._current_chunk_data = {
|
| 36 |
+
"input_ids": all_token_ids,
|
| 37 |
+
"attention_mask": all_attention_masks,
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
| 41 |
+
"""Get preprocessed sample, loading chunk if needed."""
|
| 42 |
+
# Load appropriate chunk if using streaming
|
| 43 |
+
if self.chunk_manager:
|
| 44 |
+
chunk_num = self._get_chunk_for_idx(idx)
|
| 45 |
+
if chunk_num != self._current_chunk_num:
|
| 46 |
+
self._load_chunk(chunk_num)
|
| 47 |
+
local_idx = self._get_local_idx_in_chunk(idx)
|
| 48 |
+
else:
|
| 49 |
+
local_idx = idx
|
| 50 |
+
|
| 51 |
+
input_ids = torch.tensor(self._current_chunk_data["input_ids"][local_idx], dtype=torch.long)
|
| 52 |
+
attention_mask = torch.tensor(self._current_chunk_data["attention_mask"][local_idx], dtype=torch.long)
|
| 53 |
+
|
| 54 |
+
# For pretrain, labels = input_ids shifted
|
| 55 |
+
labels = input_ids[1:].clone()
|
| 56 |
+
labels = torch.cat([labels, torch.tensor([-100])], dim=0)
|
| 57 |
+
|
| 58 |
+
# Replace padding token labels with -100 to ignore in labels
|
| 59 |
+
labels[attention_mask == 0] = -100
|
| 60 |
+
|
| 61 |
+
return {
|
| 62 |
+
"input_ids": input_ids,
|
| 63 |
+
"attention_mask": attention_mask,
|
| 64 |
+
"labels": labels,
|
| 65 |
+
}
|
code/TaoTrain/src/taoTrain/data/rl_jsonl.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""RL JSONL dataset with async-only streaming."""
|
| 2 |
+
|
| 3 |
+
from typing import Dict
|
| 4 |
+
import torch
|
| 5 |
+
from taoTrain.config import TrainingConfig
|
| 6 |
+
from taoTrain.data.jsonl_base import BaseJSONLDataset
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class RLJSONLDataset(BaseJSONLDataset):
|
| 10 |
+
"""Dataset for RL training with local JSONL files with chunked loading."""
|
| 11 |
+
|
| 12 |
+
def _preprocess_chunk(self):
|
| 13 |
+
"""Prepare prompts for RL from current chunk."""
|
| 14 |
+
if not self._current_chunk_data or "text" not in self._current_chunk_data:
|
| 15 |
+
return
|
| 16 |
+
|
| 17 |
+
max_seq_length = self.config.model.max_seq_length
|
| 18 |
+
texts = self._current_chunk_data["text"]
|
| 19 |
+
|
| 20 |
+
# Tokenize all prompts in this chunk
|
| 21 |
+
all_token_ids = []
|
| 22 |
+
all_attention_masks = []
|
| 23 |
+
|
| 24 |
+
for text in texts:
|
| 25 |
+
tokenized = self.tokenizer(
|
| 26 |
+
text,
|
| 27 |
+
truncation=True,
|
| 28 |
+
max_length=max_seq_length,
|
| 29 |
+
padding="max_length",
|
| 30 |
+
return_attention_mask=True,
|
| 31 |
+
)
|
| 32 |
+
all_token_ids.append(tokenized["input_ids"])
|
| 33 |
+
all_attention_masks.append(tokenized["attention_mask"])
|
| 34 |
+
|
| 35 |
+
self._current_chunk_data = {
|
| 36 |
+
"input_ids": all_token_ids,
|
| 37 |
+
"attention_mask": all_attention_masks,
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
| 41 |
+
"""Get preprocessed prompt, loading chunk if needed."""
|
| 42 |
+
# Load appropriate chunk if using streaming
|
| 43 |
+
if self.chunk_manager:
|
| 44 |
+
chunk_num = self._get_chunk_for_idx(idx)
|
| 45 |
+
if chunk_num != self._current_chunk_num:
|
| 46 |
+
self._load_chunk(chunk_num)
|
| 47 |
+
local_idx = self._get_local_idx_in_chunk(idx)
|
| 48 |
+
else:
|
| 49 |
+
local_idx = idx
|
| 50 |
+
|
| 51 |
+
input_ids = torch.tensor(self._current_chunk_data["input_ids"][local_idx], dtype=torch.long)
|
| 52 |
+
attention_mask = torch.tensor(self._current_chunk_data["attention_mask"][local_idx], dtype=torch.long)
|
| 53 |
+
|
| 54 |
+
# For RL, no labels yet (generated during training)
|
| 55 |
+
return {
|
| 56 |
+
"input_ids": input_ids,
|
| 57 |
+
"attention_mask": attention_mask,
|
| 58 |
+
}
|
code/TaoTrain/src/taoTrain/data/sft_jsonl.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SFT JSONL dataset with async-only streaming and response-masking."""
|
| 2 |
+
|
| 3 |
+
from typing import Dict
|
| 4 |
+
import torch
|
| 5 |
+
from taoTrain.config import TrainingConfig
|
| 6 |
+
from taoTrain.data.jsonl_base import BaseJSONLDataset
|
| 7 |
+
from taoTrain.data.sft_utils import (
|
| 8 |
+
parse_sft_record,
|
| 9 |
+
build_sft_sequence_tokens,
|
| 10 |
+
build_response_only_next_token_labels,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SFTJSONLDataset(BaseJSONLDataset):
|
| 15 |
+
"""
|
| 16 |
+
Dataset for supervised fine-tuning with local JSONL files with chunked loading.
|
| 17 |
+
|
| 18 |
+
Supports both single-turn and multi-turn SFT data:
|
| 19 |
+
- Single-turn: {"input": "...", "output": "..."}
|
| 20 |
+
- Multi-turn: {"turns": [{"user": "...", "assistant": "..."}, ...]}
|
| 21 |
+
|
| 22 |
+
With response-only loss masking: only trains on assistant/response tokens.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
def __init__(self, *args, **kwargs):
|
| 26 |
+
"""Initialize dataset."""
|
| 27 |
+
super().__init__(*args, **kwargs)
|
| 28 |
+
# Store full records for parsing (not just text field)
|
| 29 |
+
self._current_chunk_records = None
|
| 30 |
+
|
| 31 |
+
# Get SFT-specific config
|
| 32 |
+
self.sft_config = self.config if hasattr(self.config, 'mode') else None
|
| 33 |
+
self.user_token = getattr(self.sft_config, 'user_token', '<user>') if self.sft_config else '<user>'
|
| 34 |
+
self.assistant_token = getattr(self.sft_config, 'assistant_token', '<assistant>') if self.sft_config else '<assistant>'
|
| 35 |
+
self.response_loss_only = getattr(self.sft_config, 'response_loss_only', True) if self.sft_config else True
|
| 36 |
+
|
| 37 |
+
def _load_chunk(self, chunk_num: int):
|
| 38 |
+
"""
|
| 39 |
+
Load a specific chunk from JSONL file, preserving full records for SFT parsing.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
chunk_num: Chunk number to load (0-indexed)
|
| 43 |
+
"""
|
| 44 |
+
if not self.chunk_manager:
|
| 45 |
+
return
|
| 46 |
+
|
| 47 |
+
if chunk_num == self._current_chunk_num and self._current_chunk_data is not None:
|
| 48 |
+
# Already loaded
|
| 49 |
+
return
|
| 50 |
+
|
| 51 |
+
# Read chunk - get full record objects
|
| 52 |
+
chunk_examples = self.chunk_manager.read_chunk(chunk_num)
|
| 53 |
+
|
| 54 |
+
# Store full records for SFT parsing (not just text field)
|
| 55 |
+
self._current_chunk_records = chunk_examples
|
| 56 |
+
|
| 57 |
+
# Initialize data structures
|
| 58 |
+
self._current_chunk_data = {
|
| 59 |
+
"input_ids": [],
|
| 60 |
+
"attention_mask": [],
|
| 61 |
+
"mask": [],
|
| 62 |
+
}
|
| 63 |
+
self._current_chunk_num = chunk_num
|
| 64 |
+
|
| 65 |
+
# Preprocess this chunk (tokenize and mask)
|
| 66 |
+
self._preprocess_chunk()
|
| 67 |
+
|
| 68 |
+
def _preprocess_chunk(self):
|
| 69 |
+
"""
|
| 70 |
+
Process SFT records from current chunk into tokenized sequences with masking.
|
| 71 |
+
|
| 72 |
+
Parses each record (single-turn or multi-turn) and generates:
|
| 73 |
+
- Token sequences with role markers
|
| 74 |
+
- Masking info (0=ignore, 1=train)
|
| 75 |
+
- Labels with -100 for ignored tokens
|
| 76 |
+
"""
|
| 77 |
+
if not self._current_chunk_records:
|
| 78 |
+
return
|
| 79 |
+
|
| 80 |
+
max_seq_length = self.config.model.max_seq_length
|
| 81 |
+
|
| 82 |
+
all_input_ids = []
|
| 83 |
+
all_attention_masks = []
|
| 84 |
+
all_masks = []
|
| 85 |
+
|
| 86 |
+
for record in self._current_chunk_records:
|
| 87 |
+
try:
|
| 88 |
+
# Parse record into (user, assistant) turns
|
| 89 |
+
turns, is_multi_turn = parse_sft_record(record, self.config)
|
| 90 |
+
|
| 91 |
+
if not turns:
|
| 92 |
+
# Fallback: try to use "text" field if present
|
| 93 |
+
if "text" in record:
|
| 94 |
+
turns = [(record["text"], "")]
|
| 95 |
+
else:
|
| 96 |
+
continue # Skip invalid records
|
| 97 |
+
|
| 98 |
+
# Build token sequence with role tokens and masking
|
| 99 |
+
input_ids, attention_mask, mask = build_sft_sequence_tokens(
|
| 100 |
+
turns=turns,
|
| 101 |
+
tokenizer=self.tokenizer,
|
| 102 |
+
user_token=self.user_token,
|
| 103 |
+
assistant_token=self.assistant_token,
|
| 104 |
+
max_seq_length=max_seq_length,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
all_input_ids.append(input_ids)
|
| 108 |
+
all_attention_masks.append(attention_mask)
|
| 109 |
+
all_masks.append(mask)
|
| 110 |
+
|
| 111 |
+
except Exception as e:
|
| 112 |
+
# Log and skip problematic records
|
| 113 |
+
print(f"Warning: Failed to process SFT record: {e}")
|
| 114 |
+
continue
|
| 115 |
+
|
| 116 |
+
# Update chunk data with tokenized sequences and masks
|
| 117 |
+
self._current_chunk_data = {
|
| 118 |
+
"input_ids": all_input_ids,
|
| 119 |
+
"attention_mask": all_attention_masks,
|
| 120 |
+
"mask": all_masks,
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
| 124 |
+
"""
|
| 125 |
+
Get preprocessed sample with response-only loss masking.
|
| 126 |
+
|
| 127 |
+
Args:
|
| 128 |
+
idx: Sample index
|
| 129 |
+
|
| 130 |
+
Returns:
|
| 131 |
+
Dict with input_ids, attention_mask, and labels (with -100 for ignored tokens)
|
| 132 |
+
"""
|
| 133 |
+
# Load appropriate chunk if using streaming
|
| 134 |
+
if self.chunk_manager:
|
| 135 |
+
chunk_num = self._get_chunk_for_idx(idx)
|
| 136 |
+
if chunk_num != self._current_chunk_num:
|
| 137 |
+
self._load_chunk(chunk_num)
|
| 138 |
+
local_idx = self._get_local_idx_in_chunk(idx)
|
| 139 |
+
else:
|
| 140 |
+
local_idx = idx
|
| 141 |
+
|
| 142 |
+
# Get tokenized data
|
| 143 |
+
input_ids = torch.tensor(self._current_chunk_data["input_ids"][local_idx], dtype=torch.long)
|
| 144 |
+
attention_mask = torch.tensor(self._current_chunk_data["attention_mask"][local_idx], dtype=torch.long)
|
| 145 |
+
mask = self._current_chunk_data["mask"][local_idx]
|
| 146 |
+
|
| 147 |
+
labels = torch.tensor(
|
| 148 |
+
build_response_only_next_token_labels(input_ids.tolist(), mask),
|
| 149 |
+
dtype=torch.long,
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
return {
|
| 153 |
+
"input_ids": input_ids,
|
| 154 |
+
"attention_mask": attention_mask,
|
| 155 |
+
"labels": labels,
|
| 156 |
+
}
|
code/TaoTrain/src/taoTrain/data/sft_utils.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SFT utility functions for parsing and masking."""
|
| 2 |
+
|
| 3 |
+
from typing import Dict, Any, List, Tuple
|
| 4 |
+
from taoTrain.config import TrainingConfig
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def parse_sft_record(record: Dict[str, Any], config: TrainingConfig) -> Tuple[List[Tuple[str, str]], bool]:
|
| 8 |
+
"""
|
| 9 |
+
Parse JSONL record into list of (user, assistant) turns.
|
| 10 |
+
|
| 11 |
+
Supports two formats:
|
| 12 |
+
1. Single-turn: {"input": "...", "output": "..."}
|
| 13 |
+
2. Multi-turn: {"turns": [{"user": "...", "assistant": "..."}, ...]}
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
record: JSONL record (dict)
|
| 17 |
+
config: Training configuration
|
| 18 |
+
|
| 19 |
+
Returns:
|
| 20 |
+
(turns_list, is_multi_turn) where:
|
| 21 |
+
- turns_list: List of (user_text, assistant_text) tuples
|
| 22 |
+
- is_multi_turn: Whether this is a multi-turn record
|
| 23 |
+
"""
|
| 24 |
+
# Check for multi-turn format
|
| 25 |
+
if "turns" in record:
|
| 26 |
+
turns = []
|
| 27 |
+
for turn in record["turns"]:
|
| 28 |
+
if isinstance(turn, dict) and "user" in turn and "assistant" in turn:
|
| 29 |
+
turns.append((turn["user"], turn["assistant"]))
|
| 30 |
+
if turns:
|
| 31 |
+
return turns, True
|
| 32 |
+
|
| 33 |
+
# Check for single-turn format with input/output fields
|
| 34 |
+
if "input" in record and "output" in record:
|
| 35 |
+
return [(record["input"], record["output"])], False
|
| 36 |
+
|
| 37 |
+
# Fallback: check for instruction/response fields (from config)
|
| 38 |
+
dataset_config = config.dataset
|
| 39 |
+
instruction_col = dataset_config.instruction_column or "instruction"
|
| 40 |
+
response_col = dataset_config.response_column or "response"
|
| 41 |
+
|
| 42 |
+
if instruction_col in record and response_col in record:
|
| 43 |
+
return [(record[instruction_col], record[response_col])], False
|
| 44 |
+
|
| 45 |
+
# Fallback: assume pre-formatted "text" field (old format)
|
| 46 |
+
if "text" in record:
|
| 47 |
+
return [(record["text"], "")], False
|
| 48 |
+
|
| 49 |
+
return [], False
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def build_sft_sequence_tokens(
|
| 53 |
+
turns: List[Tuple[str, str]],
|
| 54 |
+
tokenizer,
|
| 55 |
+
user_token: str = "<user>",
|
| 56 |
+
assistant_token: str = "<assistant>",
|
| 57 |
+
max_seq_length: int = 1024,
|
| 58 |
+
) -> Tuple[List[int], List[int], List[int]]:
|
| 59 |
+
"""
|
| 60 |
+
Build token sequence for SFT with role tokens and generate masking info.
|
| 61 |
+
|
| 62 |
+
Sequence format:
|
| 63 |
+
[user_token_id] user_tokens [assistant_token_id] assistant_tokens ... [eos_token_id]
|
| 64 |
+
|
| 65 |
+
Mask values:
|
| 66 |
+
- 0 (ignore): user input regions and role tokens → loss=-100
|
| 67 |
+
- 1 (train): assistant output regions → compute loss
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
turns: List of (user_text, assistant_text) tuples
|
| 71 |
+
tokenizer: Tokenizer instance
|
| 72 |
+
user_token: Role token for user (e.g., "<user>")
|
| 73 |
+
assistant_token: Role token for assistant (e.g., "<assistant>")
|
| 74 |
+
max_seq_length: Maximum sequence length
|
| 75 |
+
|
| 76 |
+
Returns:
|
| 77 |
+
(input_ids, attention_mask, mask) where:
|
| 78 |
+
- input_ids: Token IDs for the full sequence
|
| 79 |
+
- attention_mask: Attention mask (1 for real tokens, 0 for padding)
|
| 80 |
+
- mask: Loss mask (0=ignore, 1=train loss)
|
| 81 |
+
"""
|
| 82 |
+
input_ids = []
|
| 83 |
+
mask = []
|
| 84 |
+
|
| 85 |
+
# Get token IDs for special tokens
|
| 86 |
+
user_token_ids = tokenizer(user_token, add_special_tokens=False)["input_ids"]
|
| 87 |
+
assistant_token_ids = tokenizer(assistant_token, add_special_tokens=False)["input_ids"]
|
| 88 |
+
|
| 89 |
+
# Process each turn
|
| 90 |
+
for user_text, assistant_text in turns:
|
| 91 |
+
# User role marker
|
| 92 |
+
input_ids.extend(user_token_ids)
|
| 93 |
+
mask.extend([0] * len(user_token_ids)) # Mask role token
|
| 94 |
+
|
| 95 |
+
# User message tokens
|
| 96 |
+
user_tokens = tokenizer(user_text, add_special_tokens=False)["input_ids"]
|
| 97 |
+
input_ids.extend(user_tokens)
|
| 98 |
+
mask.extend([0] * len(user_tokens)) # Mask user input
|
| 99 |
+
|
| 100 |
+
# Assistant role marker
|
| 101 |
+
input_ids.extend(assistant_token_ids)
|
| 102 |
+
mask.extend([0] * len(assistant_token_ids)) # Mask role token
|
| 103 |
+
|
| 104 |
+
# Assistant message tokens
|
| 105 |
+
assistant_tokens = tokenizer(assistant_text, add_special_tokens=False)["input_ids"]
|
| 106 |
+
input_ids.extend(assistant_tokens)
|
| 107 |
+
mask.extend([1] * len(assistant_tokens)) # Train on assistant output
|
| 108 |
+
|
| 109 |
+
# Add EOS token if exists
|
| 110 |
+
if hasattr(tokenizer, 'eos_token_id') and tokenizer.eos_token_id is not None:
|
| 111 |
+
input_ids.append(tokenizer.eos_token_id)
|
| 112 |
+
mask.append(0) # Mask EOS token
|
| 113 |
+
|
| 114 |
+
# Truncate if too long
|
| 115 |
+
if len(input_ids) > max_seq_length:
|
| 116 |
+
input_ids = input_ids[:max_seq_length]
|
| 117 |
+
mask = mask[:max_seq_length]
|
| 118 |
+
|
| 119 |
+
# Pad to max_seq_length
|
| 120 |
+
padding_len = max_seq_length - len(input_ids)
|
| 121 |
+
if padding_len > 0:
|
| 122 |
+
input_ids.extend([tokenizer.pad_token_id or 0] * padding_len)
|
| 123 |
+
mask.extend([0] * padding_len) # Mask padding tokens
|
| 124 |
+
|
| 125 |
+
# Create attention mask (1 for real tokens, 0 for padding)
|
| 126 |
+
attention_mask = [1 if i < len(input_ids) - padding_len else 0 for i in range(len(input_ids))]
|
| 127 |
+
|
| 128 |
+
return input_ids, attention_mask, mask
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def apply_response_masking(input_ids: List[int], mask: List[int]) -> List[int]:
|
| 132 |
+
"""
|
| 133 |
+
Apply response-only loss masking by converting mask values to label format.
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
input_ids: Token IDs
|
| 137 |
+
mask: Mask array (0=ignore, 1=train)
|
| 138 |
+
|
| 139 |
+
Returns:
|
| 140 |
+
labels: Where mask=0 tokens have label=-100 (ignore in loss), mask=1 tokens have label=input_id
|
| 141 |
+
"""
|
| 142 |
+
labels = input_ids.copy()
|
| 143 |
+
for i, m in enumerate(mask):
|
| 144 |
+
if m == 0:
|
| 145 |
+
labels[i] = -100 # CrossEntropyLoss will ignore this token
|
| 146 |
+
return labels
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def build_response_only_next_token_labels(input_ids: List[int], mask: List[int]) -> List[int]:
|
| 150 |
+
"""
|
| 151 |
+
Build next-token labels for SFT response-only training.
|
| 152 |
+
|
| 153 |
+
Position i predicts token i+1, so the loss mask must be applied to the target
|
| 154 |
+
token, not the current input token. This trains the first assistant token from
|
| 155 |
+
the assistant role marker and avoids training on masked EOS/padding targets.
|
| 156 |
+
"""
|
| 157 |
+
if len(input_ids) != len(mask):
|
| 158 |
+
raise ValueError(f"input_ids and mask must have the same length: {len(input_ids)} != {len(mask)}")
|
| 159 |
+
|
| 160 |
+
labels = apply_response_masking(input_ids, mask)
|
| 161 |
+
return labels[1:] + [-100]
|
code/TaoTrain/src/taoTrain/data/tokenization_queue.py
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Background tokenization queue for streaming large JSONL datasets."""
|
| 2 |
+
|
| 3 |
+
import queue
|
| 4 |
+
import threading
|
| 5 |
+
import time
|
| 6 |
+
from typing import Dict, List, Optional, Any, Callable
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from taoTrain.data.chunk_manager import ChunkManager
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TokenizationQueue:
|
| 13 |
+
"""
|
| 14 |
+
Background threads that continuously tokenize chunks and stores them in a queue.
|
| 15 |
+
|
| 16 |
+
This allows tokenization to happen in parallel with training, avoiding the bottleneck
|
| 17 |
+
of tokenizing all data upfront before training starts.
|
| 18 |
+
|
| 19 |
+
Supports multiple worker threads for faster throughput. Each thread greedily
|
| 20 |
+
grabs the next available chunk using an atomic counter.
|
| 21 |
+
|
| 22 |
+
Attributes:
|
| 23 |
+
total_items: Total number of samples across all chunks
|
| 24 |
+
queue_size: Maximum number of chunks to buffer in memory
|
| 25 |
+
num_threads: Number of worker threads for tokenization
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(
|
| 29 |
+
self,
|
| 30 |
+
chunk_manager: ChunkManager,
|
| 31 |
+
tokenizer: Any,
|
| 32 |
+
config: "TrainingConfig", # type: ignore
|
| 33 |
+
max_queue_size: int = 2,
|
| 34 |
+
shuffle_chunks: bool = True,
|
| 35 |
+
num_threads: int = 1,
|
| 36 |
+
):
|
| 37 |
+
"""
|
| 38 |
+
Initialize tokenization queue with multithreading support.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
chunk_manager: ChunkManager instance loaded with chunks
|
| 42 |
+
tokenizer: Tokenizer instance (HuggingFace or SentencePiece wrapper)
|
| 43 |
+
config: Training configuration with model and dataset settings
|
| 44 |
+
max_queue_size: Maximum chunks to buffer in queue (memory constraint)
|
| 45 |
+
shuffle_chunks: Whether to shuffle chunk order at initialization
|
| 46 |
+
num_threads: Number of worker threads for tokenization (default: 1)
|
| 47 |
+
|
| 48 |
+
Raises:
|
| 49 |
+
ValueError: If chunk_manager has no chunks or num_threads < 1
|
| 50 |
+
"""
|
| 51 |
+
if chunk_manager.num_chunks == 0:
|
| 52 |
+
raise ValueError("ChunkManager must have at least one chunk")
|
| 53 |
+
if num_threads < 1:
|
| 54 |
+
raise ValueError(f"num_threads must be >= 1, got {num_threads}")
|
| 55 |
+
|
| 56 |
+
self.chunk_manager = chunk_manager
|
| 57 |
+
self.tokenizer = tokenizer
|
| 58 |
+
self.config = config
|
| 59 |
+
self.max_queue_size = max_queue_size
|
| 60 |
+
self.shuffle_chunks = shuffle_chunks
|
| 61 |
+
self.num_threads = num_threads
|
| 62 |
+
|
| 63 |
+
# Detect SFT mode: check for response_loss_only flag
|
| 64 |
+
self.is_sft_mode = hasattr(config, 'response_loss_only') and config.response_loss_only
|
| 65 |
+
|
| 66 |
+
# Calculate total items across all chunks
|
| 67 |
+
self.total_items = chunk_manager.effective_lines
|
| 68 |
+
|
| 69 |
+
# Thread-safe queue for tokenized chunks
|
| 70 |
+
self._queue: queue.Queue[Dict[str, List]] = queue.Queue(maxsize=max_queue_size)
|
| 71 |
+
|
| 72 |
+
# Control signals
|
| 73 |
+
self._stop_event = threading.Event()
|
| 74 |
+
self._error_event = threading.Event()
|
| 75 |
+
self._error_messages: List[str] = []
|
| 76 |
+
self._threads: List[threading.Thread] = []
|
| 77 |
+
|
| 78 |
+
# Thread-safe chunk distribution
|
| 79 |
+
self._next_chunk_idx = 0
|
| 80 |
+
self._chunk_idx_lock = threading.Lock()
|
| 81 |
+
self._active_threads = 0
|
| 82 |
+
self._active_threads_lock = threading.Lock()
|
| 83 |
+
|
| 84 |
+
# Chunk ordering
|
| 85 |
+
self._chunk_order = list(range(chunk_manager.num_chunks))
|
| 86 |
+
print(f"TokenizationQueue initialized with {chunk_manager.num_chunks} chunks, total {chunk_manager.effective_lines} samples")
|
| 87 |
+
print(f"Using {num_threads} tokenization worker thread{'s' if num_threads != 1 else ''}")
|
| 88 |
+
print(f"Max queue size: {max_queue_size} chunks (memory constraint)")
|
| 89 |
+
if self.shuffle_chunks:
|
| 90 |
+
import random
|
| 91 |
+
random.shuffle(self._chunk_order)
|
| 92 |
+
|
| 93 |
+
def _get_next_chunk_idx(self) -> Optional[int]:
|
| 94 |
+
"""
|
| 95 |
+
Atomically get the next chunk index for processing.
|
| 96 |
+
|
| 97 |
+
Returns:
|
| 98 |
+
Chunk index to process, or None if all chunks have been assigned
|
| 99 |
+
"""
|
| 100 |
+
with self._chunk_idx_lock:
|
| 101 |
+
if self._next_chunk_idx < len(self._chunk_order):
|
| 102 |
+
chunk_idx = self._chunk_order[self._next_chunk_idx]
|
| 103 |
+
self._next_chunk_idx += 1
|
| 104 |
+
return chunk_idx
|
| 105 |
+
return None
|
| 106 |
+
|
| 107 |
+
def start(self):
|
| 108 |
+
"""Start the tokenization background worker threads."""
|
| 109 |
+
if self._threads:
|
| 110 |
+
raise RuntimeError(f"Tokenization threads already started ({len(self._threads)} active)")
|
| 111 |
+
|
| 112 |
+
# Create and start N worker threads
|
| 113 |
+
for thread_id in range(self.num_threads):
|
| 114 |
+
thread = threading.Thread(target=self._worker, args=(thread_id,), daemon=False)
|
| 115 |
+
self._threads.append(thread)
|
| 116 |
+
thread.start()
|
| 117 |
+
|
| 118 |
+
def _worker(self, thread_id: int):
|
| 119 |
+
"""
|
| 120 |
+
Worker thread target: greedy chunk processing with thread-safe distribution.
|
| 121 |
+
|
| 122 |
+
Args:
|
| 123 |
+
thread_id: Identifier for this worker thread
|
| 124 |
+
"""
|
| 125 |
+
with self._active_threads_lock:
|
| 126 |
+
self._active_threads += 1
|
| 127 |
+
|
| 128 |
+
try:
|
| 129 |
+
while True:
|
| 130 |
+
# Check for stop signal
|
| 131 |
+
if self._stop_event.is_set():
|
| 132 |
+
break
|
| 133 |
+
|
| 134 |
+
# Get next chunk to process (atomic operation)
|
| 135 |
+
chunk_num = self._get_next_chunk_idx()
|
| 136 |
+
if chunk_num is None:
|
| 137 |
+
# All chunks assigned
|
| 138 |
+
break
|
| 139 |
+
|
| 140 |
+
# Load chunk
|
| 141 |
+
chunk_examples = self.chunk_manager.read_chunk(chunk_num)
|
| 142 |
+
|
| 143 |
+
# Tokenize chunk based on mode
|
| 144 |
+
if self.is_sft_mode:
|
| 145 |
+
tokenized_chunk = self._tokenize_batch_sft(chunk_examples)
|
| 146 |
+
else:
|
| 147 |
+
# Extract texts for pretrain
|
| 148 |
+
text_field = self.config.dataset.text_field
|
| 149 |
+
texts = [obj.get(text_field, "") for obj in chunk_examples]
|
| 150 |
+
tokenized_chunk = self._tokenize_batch(texts)
|
| 151 |
+
|
| 152 |
+
# Put in queue (blocks if queue is full)
|
| 153 |
+
self._queue.put(tokenized_chunk)
|
| 154 |
+
print(f"[Worker-{thread_id}] Processed chunk {chunk_num}, put {len(tokenized_chunk['input_ids'])} samples in queue")
|
| 155 |
+
except Exception as e:
|
| 156 |
+
error_msg = f"[Worker-{thread_id}] {str(e)}"
|
| 157 |
+
print(f"Worker-{thread_id} encountered an error: {error_msg}")
|
| 158 |
+
# Thread-safe append to error list
|
| 159 |
+
self._error_messages.append(error_msg)
|
| 160 |
+
self._error_event.set()
|
| 161 |
+
finally:
|
| 162 |
+
with self._active_threads_lock:
|
| 163 |
+
self._active_threads -= 1
|
| 164 |
+
remaining = self._active_threads
|
| 165 |
+
print(f"[Worker-{thread_id}] Finished processing. Active threads remaining: {remaining}")
|
| 166 |
+
def _tokenize_batch(self, texts: List[str]) -> Dict[str, List]:
|
| 167 |
+
"""
|
| 168 |
+
Tokenize a batch of texts, join with EOS, and split into fixed-size sequences.
|
| 169 |
+
|
| 170 |
+
This packs multiple documents into longer sequences separated by EOS tokens,
|
| 171 |
+
then splits the concatenated tokens into N fixed-size chunks of max_seq_length.
|
| 172 |
+
|
| 173 |
+
Args:
|
| 174 |
+
texts: List of text strings
|
| 175 |
+
|
| 176 |
+
Returns:
|
| 177 |
+
Dict with 'input_ids' and 'attention_mask' lists, where each element
|
| 178 |
+
is a fixed-size sequence of length max_seq_length
|
| 179 |
+
"""
|
| 180 |
+
max_seq_length = self.config.model.max_seq_length
|
| 181 |
+
|
| 182 |
+
# Get EOS token ID
|
| 183 |
+
eos_token_id = self.tokenizer.eos_token_id
|
| 184 |
+
unk_token_id = self.tokenizer.unk_token_id
|
| 185 |
+
if eos_token_id is None:
|
| 186 |
+
raise ValueError("Tokenizer does not have an EOS token defined")
|
| 187 |
+
if unk_token_id is None:
|
| 188 |
+
raise ValueError("Tokenizer does not have an UNK token defined")
|
| 189 |
+
|
| 190 |
+
# Tokenize all texts without truncation
|
| 191 |
+
all_token_ids = []
|
| 192 |
+
|
| 193 |
+
for i, text in enumerate(texts):
|
| 194 |
+
tokenized = self.tokenizer(
|
| 195 |
+
text,
|
| 196 |
+
truncation=False,
|
| 197 |
+
return_attention_mask=False,
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
# Remove UNK tokens from tokenized output (if any)
|
| 201 |
+
tokenized["input_ids"] = [tid for tid in tokenized["input_ids"] if tid != unk_token_id]
|
| 202 |
+
|
| 203 |
+
all_token_ids.extend(tokenized["input_ids"])
|
| 204 |
+
# Add EOS token between documents (except after the last one)
|
| 205 |
+
if i < len(texts) - 1:
|
| 206 |
+
all_token_ids.append(eos_token_id)
|
| 207 |
+
|
| 208 |
+
# Split into N fixed-size sequences
|
| 209 |
+
sequences_input_ids = []
|
| 210 |
+
sequences_attention_masks = []
|
| 211 |
+
|
| 212 |
+
for i in range(0, len(all_token_ids), max_seq_length):
|
| 213 |
+
seq = all_token_ids[i : i + max_seq_length]
|
| 214 |
+
|
| 215 |
+
# Pad sequence if it's shorter than max_seq_length
|
| 216 |
+
if len(seq) < max_seq_length:
|
| 217 |
+
# Create attention mask before padding
|
| 218 |
+
attention_mask = [1] * len(seq) + [0] * (max_seq_length - len(seq))
|
| 219 |
+
# Pad with 0 (assuming 0 is the pad token, or use tokenizer.pad_token_id)
|
| 220 |
+
pad_token_id = self.tokenizer.pad_token_id or 0
|
| 221 |
+
seq = seq + [pad_token_id] * (max_seq_length - len(seq))
|
| 222 |
+
else:
|
| 223 |
+
attention_mask = [1] * max_seq_length
|
| 224 |
+
|
| 225 |
+
sequences_input_ids.append(seq)
|
| 226 |
+
sequences_attention_masks.append(attention_mask)
|
| 227 |
+
|
| 228 |
+
return {
|
| 229 |
+
"input_ids": sequences_input_ids,
|
| 230 |
+
"attention_mask": sequences_attention_masks,
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
def _tokenize_batch_sft(self, records: List[Dict[str, Any]]) -> Dict[str, List]:
|
| 234 |
+
"""
|
| 235 |
+
Tokenize a batch of SFT records with role tokens and response masking.
|
| 236 |
+
|
| 237 |
+
Processes each record (single-turn or multi-turn) and generates sequences
|
| 238 |
+
with role markers and masking (0=ignore user, 1=train on assistant).
|
| 239 |
+
|
| 240 |
+
Args:
|
| 241 |
+
records: List of JSONL record dicts with various SFT formats
|
| 242 |
+
|
| 243 |
+
Returns:
|
| 244 |
+
Dict with 'input_ids', 'attention_mask', and 'mask' lists, where each
|
| 245 |
+
element is a fixed-size sequence of length max_seq_length with masking info
|
| 246 |
+
"""
|
| 247 |
+
# Import here to avoid circular imports
|
| 248 |
+
from taoTrain.data.sft_utils import parse_sft_record, build_sft_sequence_tokens
|
| 249 |
+
|
| 250 |
+
max_seq_length = self.config.model.max_seq_length
|
| 251 |
+
user_token = getattr(self.config, 'user_token', '<user>')
|
| 252 |
+
assistant_token = getattr(self.config, 'assistant_token', '<assistant>')
|
| 253 |
+
|
| 254 |
+
sequences_input_ids = []
|
| 255 |
+
sequences_attention_masks = []
|
| 256 |
+
sequences_masks = []
|
| 257 |
+
|
| 258 |
+
for record in records:
|
| 259 |
+
try:
|
| 260 |
+
# Parse SFT record (supports multiple formats)
|
| 261 |
+
turns, is_multi_turn = parse_sft_record(record, self.config)
|
| 262 |
+
|
| 263 |
+
if not turns:
|
| 264 |
+
# Skip records that couldn't be parsed
|
| 265 |
+
continue
|
| 266 |
+
|
| 267 |
+
# Build token sequence with role tokens and response masking
|
| 268 |
+
input_ids, attention_mask, mask = build_sft_sequence_tokens(
|
| 269 |
+
turns=turns,
|
| 270 |
+
tokenizer=self.tokenizer,
|
| 271 |
+
user_token=user_token,
|
| 272 |
+
assistant_token=assistant_token,
|
| 273 |
+
max_seq_length=max_seq_length,
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
sequences_input_ids.append(input_ids)
|
| 277 |
+
sequences_attention_masks.append(attention_mask)
|
| 278 |
+
sequences_masks.append(mask)
|
| 279 |
+
|
| 280 |
+
except Exception as e:
|
| 281 |
+
# Log error but continue processing
|
| 282 |
+
print(f"Warning: Failed to tokenize SFT record: {e}")
|
| 283 |
+
continue
|
| 284 |
+
|
| 285 |
+
return {
|
| 286 |
+
"input_ids": sequences_input_ids,
|
| 287 |
+
"attention_mask": sequences_attention_masks,
|
| 288 |
+
"mask": sequences_masks,
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
def get_next_chunk(self, timeout: Optional[float] = None) -> Optional[Dict[str, List]]:
|
| 292 |
+
"""
|
| 293 |
+
Get the next tokenized chunk from the queue.
|
| 294 |
+
|
| 295 |
+
This is a blocking call that waits for the next chunk to be tokenized.
|
| 296 |
+
Returns None if queue is closed or all chunks have been processed.
|
| 297 |
+
|
| 298 |
+
CRITICAL: Always attempts to drain the queue first before returning None.
|
| 299 |
+
This prevents abandoning buffered chunks when threads finish.
|
| 300 |
+
|
| 301 |
+
Args:
|
| 302 |
+
timeout: Timeout in seconds (None = wait indefinitely)
|
| 303 |
+
|
| 304 |
+
Returns:
|
| 305 |
+
Dict with tokenized chunk, or None if queue is exhausted
|
| 306 |
+
|
| 307 |
+
Raises:
|
| 308 |
+
RuntimeError: If an error occurred in any worker thread
|
| 309 |
+
"""
|
| 310 |
+
if self._error_event.is_set():
|
| 311 |
+
error_summary = "; ".join(self._error_messages) if self._error_messages else "Unknown error"
|
| 312 |
+
raise RuntimeError(f"Tokenization thread error: {error_summary}")
|
| 313 |
+
|
| 314 |
+
# PRIORITY: Try to get from queue first (may have buffered items)
|
| 315 |
+
try:
|
| 316 |
+
chunk = self._queue.get(timeout=timeout)
|
| 317 |
+
return chunk
|
| 318 |
+
except queue.Empty:
|
| 319 |
+
# Queue is empty - check if threads are still working
|
| 320 |
+
with self._active_threads_lock:
|
| 321 |
+
if self._active_threads == 0 and self._next_chunk_idx >= len(self._chunk_order):
|
| 322 |
+
# All chunks assigned AND no active threads = true exhaustion
|
| 323 |
+
return None
|
| 324 |
+
# Queue temporarily empty but threads still working - signal to wait
|
| 325 |
+
return None
|
| 326 |
+
|
| 327 |
+
@property
|
| 328 |
+
def is_exhausted(self) -> bool:
|
| 329 |
+
"""Return True only when all chunks are assigned and all workers are idle."""
|
| 330 |
+
with self._active_threads_lock:
|
| 331 |
+
return self._active_threads == 0 and self._next_chunk_idx >= len(self._chunk_order)
|
| 332 |
+
|
| 333 |
+
def shutdown(self, wait: bool = True):
|
| 334 |
+
"""
|
| 335 |
+
Shutdown the tokenization worker threads gracefully.
|
| 336 |
+
|
| 337 |
+
Args:
|
| 338 |
+
wait: If True, wait for all threads to finish; otherwise return immediately
|
| 339 |
+
"""
|
| 340 |
+
if not self._threads:
|
| 341 |
+
return
|
| 342 |
+
|
| 343 |
+
# Signal threads to stop
|
| 344 |
+
self._stop_event.set()
|
| 345 |
+
|
| 346 |
+
# Drain queue to unblock threads if they're waiting to put
|
| 347 |
+
try:
|
| 348 |
+
while True:
|
| 349 |
+
self._queue.get_nowait()
|
| 350 |
+
except queue.Empty:
|
| 351 |
+
pass
|
| 352 |
+
|
| 353 |
+
# Wait for all threads to finish
|
| 354 |
+
if wait:
|
| 355 |
+
for thread in self._threads:
|
| 356 |
+
thread.join(timeout=5.0)
|
| 357 |
+
if thread.is_alive():
|
| 358 |
+
print(f"⚠ Tokenization thread {thread.name} did not terminate cleanly")
|
| 359 |
+
|
| 360 |
+
# Clear thread list to allow fresh start in next epoch
|
| 361 |
+
self._threads.clear()
|
| 362 |
+
print("✓ TokenizationQueue shutdown complete, thread list cleared")
|
| 363 |
+
|
| 364 |
+
def reset_for_next_epoch(self):
|
| 365 |
+
"""
|
| 366 |
+
Reset queue state for the next epoch.
|
| 367 |
+
|
| 368 |
+
This allows the same TokenizationQueue to be reused across multiple epochs.
|
| 369 |
+
Resets the chunk index counter, reshuffles chunks (if enabled), and clears
|
| 370 |
+
any buffered items and error state.
|
| 371 |
+
|
| 372 |
+
Called by AsyncBatchIterator at the start of epoch 2+.
|
| 373 |
+
"""
|
| 374 |
+
# Reset iteration counter
|
| 375 |
+
self._next_chunk_idx = 0
|
| 376 |
+
|
| 377 |
+
# Reshuffle chunk order if enabled
|
| 378 |
+
if self.shuffle_chunks:
|
| 379 |
+
import random
|
| 380 |
+
random.shuffle(self._chunk_order)
|
| 381 |
+
print(f"✓ Reshuffled chunk order for next epoch: {self._chunk_order}")
|
| 382 |
+
|
| 383 |
+
# Drain any remaining items from queue
|
| 384 |
+
items_drained = 0
|
| 385 |
+
try:
|
| 386 |
+
while True:
|
| 387 |
+
self._queue.get_nowait()
|
| 388 |
+
items_drained += 1
|
| 389 |
+
except queue.Empty:
|
| 390 |
+
pass
|
| 391 |
+
|
| 392 |
+
if items_drained > 0:
|
| 393 |
+
print(f"⚠ Drained {items_drained} items from queue before epoch reset")
|
| 394 |
+
|
| 395 |
+
# Clear error state
|
| 396 |
+
self._error_event.clear()
|
| 397 |
+
self._error_messages.clear()
|
| 398 |
+
|
| 399 |
+
# Clear threads list so new threads will be started in next epoch
|
| 400 |
+
self._threads.clear()
|
| 401 |
+
|
| 402 |
+
print(f"✓ TokenizationQueue reset for next epoch. Ready to process {len(self._chunk_order)} chunks")
|
| 403 |
+
|
| 404 |
+
def __len__(self) -> int:
|
| 405 |
+
"""Return total number of samples."""
|
| 406 |
+
return self.total_items
|
| 407 |
+
|
| 408 |
+
def __del__(self):
|
| 409 |
+
"""Cleanup on deletion."""
|
| 410 |
+
self.shutdown(wait=False)
|
code/TaoTrain/src/taoTrain/data/tokenizer.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SentencePiece tokenizer wrapper for HuggingFace compatibility."""
|
| 2 |
+
|
| 3 |
+
from typing import Optional, List, Union
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class SentencePieceTokenizerWrapper:
|
| 7 |
+
"""Wrapper to make SentencePiece tokenizer compatible with HuggingFace interface."""
|
| 8 |
+
|
| 9 |
+
def __init__(self, sp_processor):
|
| 10 |
+
"""
|
| 11 |
+
Initialize wrapper.
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
sp_processor: sentencepiece.SentencePieceProcessor instance
|
| 15 |
+
"""
|
| 16 |
+
self.sp = sp_processor
|
| 17 |
+
self.vocab_size = self.sp.vocab_size()
|
| 18 |
+
self.pad_token_id = self.sp.pad_id()
|
| 19 |
+
self.eos_token_id = self.sp.eos_id()
|
| 20 |
+
self.bos_token_id = self.sp.bos_id()
|
| 21 |
+
self.unk_token_id = self.sp.unk_id()
|
| 22 |
+
|
| 23 |
+
def __call__(self, text, **kwargs):
|
| 24 |
+
"""
|
| 25 |
+
Tokenize text.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
text: Input text or list of texts
|
| 29 |
+
**kwargs: Additional arguments (truncation, max_length, padding, return_attention_mask)
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
Dict with input_ids and attention_mask
|
| 33 |
+
"""
|
| 34 |
+
# Handle both single string and list of strings
|
| 35 |
+
is_single = isinstance(text, str)
|
| 36 |
+
texts = [text] if is_single else text
|
| 37 |
+
|
| 38 |
+
max_length = kwargs.get('max_length', None)
|
| 39 |
+
padding = kwargs.get('padding', None)
|
| 40 |
+
truncation = kwargs.get('truncation', False)
|
| 41 |
+
return_attention_mask = kwargs.get('return_attention_mask', True)
|
| 42 |
+
|
| 43 |
+
# Tokenize all texts
|
| 44 |
+
all_input_ids = []
|
| 45 |
+
for t in texts:
|
| 46 |
+
tokens = self.sp.encode(t, out_type=int)
|
| 47 |
+
|
| 48 |
+
# Truncate if needed
|
| 49 |
+
if truncation and max_length and len(tokens) > max_length:
|
| 50 |
+
tokens = tokens[:max_length]
|
| 51 |
+
|
| 52 |
+
all_input_ids.append(tokens)
|
| 53 |
+
|
| 54 |
+
# Padding
|
| 55 |
+
if padding or max_length:
|
| 56 |
+
target_length = max_length or max(len(ids) for ids in all_input_ids) if all_input_ids else 0
|
| 57 |
+
padded_input_ids = []
|
| 58 |
+
padded_attention_masks = []
|
| 59 |
+
|
| 60 |
+
for ids in all_input_ids:
|
| 61 |
+
pad_length = target_length - len(ids)
|
| 62 |
+
if pad_length > 0:
|
| 63 |
+
padded_ids = ids + [self.pad_token_id] * pad_length
|
| 64 |
+
else:
|
| 65 |
+
padded_ids = ids[:target_length]
|
| 66 |
+
|
| 67 |
+
padded_input_ids.append(padded_ids)
|
| 68 |
+
attention_mask = [1] * len(ids) + [0] * (target_length - len(ids))
|
| 69 |
+
padded_attention_masks.append(attention_mask)
|
| 70 |
+
|
| 71 |
+
result = {
|
| 72 |
+
"input_ids": padded_input_ids if not is_single else padded_input_ids[0],
|
| 73 |
+
}
|
| 74 |
+
if return_attention_mask:
|
| 75 |
+
result["attention_mask"] = padded_attention_masks if not is_single else padded_attention_masks[0]
|
| 76 |
+
else:
|
| 77 |
+
result = {
|
| 78 |
+
"input_ids": all_input_ids[0] if is_single else all_input_ids,
|
| 79 |
+
}
|
| 80 |
+
if return_attention_mask:
|
| 81 |
+
attention_masks = [[1] * len(ids) for ids in all_input_ids]
|
| 82 |
+
result["attention_mask"] = attention_masks[0] if is_single else attention_masks
|
| 83 |
+
|
| 84 |
+
return result
|
| 85 |
+
|
| 86 |
+
def encode(self, text, return_tensors=None, **kwargs):
|
| 87 |
+
"""Encode text to token IDs."""
|
| 88 |
+
result = self(text, **kwargs)
|
| 89 |
+
input_ids = result["input_ids"]
|
| 90 |
+
|
| 91 |
+
if return_tensors == "pt":
|
| 92 |
+
import torch
|
| 93 |
+
# Ensure input_ids is a 1D list of ints
|
| 94 |
+
if isinstance(input_ids[0], list):
|
| 95 |
+
input_ids = input_ids[0]
|
| 96 |
+
return torch.tensor(input_ids, dtype=torch.long).unsqueeze(0)
|
| 97 |
+
|
| 98 |
+
return input_ids
|
| 99 |
+
|
| 100 |
+
def encode_plus(self, text, **kwargs):
|
| 101 |
+
"""Encode text with additional information (HuggingFace compatibility)."""
|
| 102 |
+
return self(text, **kwargs)
|
| 103 |
+
|
| 104 |
+
def decode(self, token_ids, skip_special_tokens=False, **kwargs):
|
| 105 |
+
"""Decode token IDs to text."""
|
| 106 |
+
if hasattr(token_ids, 'tolist'): # Handle torch tensors
|
| 107 |
+
token_ids = token_ids.tolist()
|
| 108 |
+
|
| 109 |
+
# Handle various input formats
|
| 110 |
+
if isinstance(token_ids, (list, tuple)):
|
| 111 |
+
if len(token_ids) > 0 and isinstance(token_ids[0], (list, tuple)):
|
| 112 |
+
token_ids = token_ids[0]
|
| 113 |
+
|
| 114 |
+
# Ensure it's a list of ints
|
| 115 |
+
if not isinstance(token_ids, list):
|
| 116 |
+
token_ids = [int(t) for t in token_ids]
|
| 117 |
+
|
| 118 |
+
return self.sp.decode(token_ids)
|
code/TaoTrain/src/taoTrain/inference/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Inference engines."""
|
| 2 |
+
|
| 3 |
+
from .inferencer import Inferencer
|
| 4 |
+
|
| 5 |
+
__all__ = ["Inferencer"]
|
code/TaoTrain/src/taoTrain/inference/inferencer.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Inference engine for model generation."""
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Optional, Iterator, Any
|
| 5 |
+
import torch
|
| 6 |
+
from transformers import AutoTokenizer
|
| 7 |
+
from rich.console import Console
|
| 8 |
+
from rich.table import Table
|
| 9 |
+
|
| 10 |
+
from taoTrain.core import BaseModel
|
| 11 |
+
from taoTrain.config import ModelConfig
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Inferencer:
|
| 15 |
+
"""Inference engine for text generation."""
|
| 16 |
+
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
model: BaseModel,
|
| 20 |
+
tokenizer: Any,
|
| 21 |
+
device: Optional[torch.device] = None,
|
| 22 |
+
dtype: Optional[torch.dtype] = None,
|
| 23 |
+
):
|
| 24 |
+
"""
|
| 25 |
+
Initialize inferencer.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
model: Trained model
|
| 29 |
+
tokenizer: Tokenizer instance (HuggingFace or SentencePiece wrapped)
|
| 30 |
+
device: Device for inference
|
| 31 |
+
dtype: Data type for inference
|
| 32 |
+
"""
|
| 33 |
+
self.model = model
|
| 34 |
+
self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 35 |
+
self.dtype = dtype or torch.float32
|
| 36 |
+
self.tokenizer = tokenizer
|
| 37 |
+
|
| 38 |
+
# Move model to device and set eval mode
|
| 39 |
+
self.model = self.model.to(self.device)
|
| 40 |
+
self.model.eval()
|
| 41 |
+
|
| 42 |
+
# Set pad token if needed (for HuggingFace tokenizers)
|
| 43 |
+
if hasattr(self.tokenizer, 'pad_token') and self.tokenizer.pad_token is None:
|
| 44 |
+
if hasattr(self.tokenizer, 'eos_token'):
|
| 45 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 46 |
+
|
| 47 |
+
@staticmethod
|
| 48 |
+
def _load_tokenizer(tokenizer_path: str | Path) -> Any:
|
| 49 |
+
"""
|
| 50 |
+
Load tokenizer from path (SentencePiece or HuggingFace).
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
tokenizer_path: Path to tokenizer file or HuggingFace model name
|
| 54 |
+
|
| 55 |
+
Returns:
|
| 56 |
+
Tokenizer instance
|
| 57 |
+
|
| 58 |
+
Raises:
|
| 59 |
+
ValueError: If tokenizer cannot be loaded
|
| 60 |
+
"""
|
| 61 |
+
tokenizer_path = str(tokenizer_path)
|
| 62 |
+
|
| 63 |
+
# Auto-detect tokenizer type based on file extension
|
| 64 |
+
if tokenizer_path.endswith('.model'):
|
| 65 |
+
# Load SentencePiece tokenizer
|
| 66 |
+
try:
|
| 67 |
+
import sentencepiece as spm
|
| 68 |
+
sp = spm.SentencePieceProcessor()
|
| 69 |
+
sp.Load(tokenizer_path)
|
| 70 |
+
# Wrap SentencePiece in a compatible interface
|
| 71 |
+
from taoTrain.data import SentencePieceTokenizerWrapper
|
| 72 |
+
return SentencePieceTokenizerWrapper(sp)
|
| 73 |
+
except ImportError:
|
| 74 |
+
raise ImportError("SentencePiece not installed. Install with: pip install sentencepiece")
|
| 75 |
+
except Exception as e:
|
| 76 |
+
raise ValueError(f"Failed to load SentencePiece tokenizer from {tokenizer_path}: {e}")
|
| 77 |
+
else:
|
| 78 |
+
# Load HuggingFace tokenizer
|
| 79 |
+
try:
|
| 80 |
+
return AutoTokenizer.from_pretrained(tokenizer_path)
|
| 81 |
+
except Exception as e:
|
| 82 |
+
raise ValueError(f"Failed to load HuggingFace tokenizer from {tokenizer_path}: {e}")
|
| 83 |
+
|
| 84 |
+
@staticmethod
|
| 85 |
+
def _print_tokenizer_info(tokenizer: Any, tokenizer_path: str) -> None:
|
| 86 |
+
"""Print tokenizer information."""
|
| 87 |
+
console = Console()
|
| 88 |
+
table = Table(title="Tokenizer Information")
|
| 89 |
+
table.add_column("Property", style="cyan")
|
| 90 |
+
table.add_column("Value", style="green")
|
| 91 |
+
|
| 92 |
+
table.add_row("Type", "SentencePiece" if tokenizer_path.endswith('.model') else "HuggingFace")
|
| 93 |
+
table.add_row("Path", str(tokenizer_path))
|
| 94 |
+
|
| 95 |
+
if hasattr(tokenizer, 'vocab_size'):
|
| 96 |
+
table.add_row("Vocab Size", str(tokenizer.vocab_size))
|
| 97 |
+
|
| 98 |
+
console.print(table)
|
| 99 |
+
|
| 100 |
+
@staticmethod
|
| 101 |
+
def load_from_checkpoint(
|
| 102 |
+
checkpoint_path: str | Path,
|
| 103 |
+
tokenizer_path: Optional[str | Path] = None,
|
| 104 |
+
device: Optional[torch.device] = None,
|
| 105 |
+
) -> "Inferencer":
|
| 106 |
+
"""
|
| 107 |
+
Load model from checkpoint and create inferencer.
|
| 108 |
+
|
| 109 |
+
Handles both canonical and legacy checkpoint formats:
|
| 110 |
+
- Canonical: uses 'model_state' key
|
| 111 |
+
- Legacy: uses 'model_state_dict' key
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
checkpoint_path: Path to checkpoint file
|
| 115 |
+
tokenizer_path: Optional path to tokenizer (overrides checkpoint's tokenizer_path)
|
| 116 |
+
device: Device for inference
|
| 117 |
+
|
| 118 |
+
Returns:
|
| 119 |
+
Inferencer instance
|
| 120 |
+
|
| 121 |
+
Raises:
|
| 122 |
+
ValueError: If no tokenizer path found in checkpoint or arguments
|
| 123 |
+
"""
|
| 124 |
+
if device is None:
|
| 125 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 126 |
+
|
| 127 |
+
# Load checkpoint using CheckpointManager for automatic format normalization
|
| 128 |
+
from taoTrain.checkpointing.checkpoint import CheckpointManager
|
| 129 |
+
checkpoint_manager = CheckpointManager(checkpoint_path.parent if isinstance(checkpoint_path, Path) else Path(checkpoint_path).parent)
|
| 130 |
+
checkpoint = checkpoint_manager.load(checkpoint_path, device=device)
|
| 131 |
+
|
| 132 |
+
config_dict = checkpoint.get("config", {})
|
| 133 |
+
|
| 134 |
+
# Extract tokenizer path from checkpoint config or use provided override
|
| 135 |
+
if tokenizer_path is None:
|
| 136 |
+
# Try to get tokenizer_path from checkpoint config
|
| 137 |
+
dataset_config = config_dict.get("dataset", {})
|
| 138 |
+
tokenizer_path = dataset_config.get("tokenizer_path")
|
| 139 |
+
|
| 140 |
+
if not tokenizer_path:
|
| 141 |
+
raise ValueError(
|
| 142 |
+
f"No tokenizer path found in checkpoint config at {checkpoint_path}. "
|
| 143 |
+
"Please provide --tokenizer argument with path to tokenizer file."
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
# Load tokenizer
|
| 147 |
+
console = Console()
|
| 148 |
+
console.print("\n[bold cyan]Loading tokenizer...[/bold cyan]")
|
| 149 |
+
tokenizer = Inferencer._load_tokenizer(tokenizer_path)
|
| 150 |
+
Inferencer._print_tokenizer_info(tokenizer, str(tokenizer_path))
|
| 151 |
+
|
| 152 |
+
# Reconstruct model config
|
| 153 |
+
from taoTrain.config import ModelConfig
|
| 154 |
+
model_config = ModelConfig(**config_dict.get("model", {}))
|
| 155 |
+
|
| 156 |
+
# Create and load model
|
| 157 |
+
# CheckpointManager.load() normalizes to 'model_state' key
|
| 158 |
+
from taoTrain.models import get_model
|
| 159 |
+
model = get_model(model_config, device=device)
|
| 160 |
+
model.load_state_dict(checkpoint["model_state"])
|
| 161 |
+
|
| 162 |
+
return Inferencer(model, tokenizer, device)
|
| 163 |
+
|
| 164 |
+
def generate(
|
| 165 |
+
self,
|
| 166 |
+
prompt: str,
|
| 167 |
+
max_length: int = 256,
|
| 168 |
+
temperature: float = 0.7,
|
| 169 |
+
top_p: float = 0.95,
|
| 170 |
+
top_k: Optional[int] = None,
|
| 171 |
+
repetition_penalty: float = 1.0,
|
| 172 |
+
do_sample: bool = True,
|
| 173 |
+
stream: bool = False,
|
| 174 |
+
) -> str | Iterator[str]:
|
| 175 |
+
"""
|
| 176 |
+
Generate text from a prompt.
|
| 177 |
+
|
| 178 |
+
Args:
|
| 179 |
+
prompt: Input prompt
|
| 180 |
+
max_length: Maximum generation length
|
| 181 |
+
temperature: Temperature for sampling
|
| 182 |
+
top_p: Nucleus sampling parameter
|
| 183 |
+
top_k: Top-k sampling parameter
|
| 184 |
+
repetition_penalty: Penalty for repeated tokens (1.0 = no penalty, >1.0 = penalize)
|
| 185 |
+
do_sample: Whether to sample or use greedy decoding
|
| 186 |
+
stream: Whether to stream tokens
|
| 187 |
+
|
| 188 |
+
Yields/Returns:
|
| 189 |
+
Generated text (or stream of tokens if stream=True)
|
| 190 |
+
"""
|
| 191 |
+
# Tokenize prompt
|
| 192 |
+
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
|
| 193 |
+
prompt_length = input_ids.shape[1]
|
| 194 |
+
|
| 195 |
+
# For streaming with full context decoding
|
| 196 |
+
generated_token_ids = [] # Accumulate all generated tokens
|
| 197 |
+
last_decoded_full = "" # Cache full decoded text from previous step
|
| 198 |
+
|
| 199 |
+
with torch.no_grad():
|
| 200 |
+
for step in range(max_length):
|
| 201 |
+
# Forward pass
|
| 202 |
+
outputs = self.model(
|
| 203 |
+
input_ids=input_ids,
|
| 204 |
+
attention_mask=None,
|
| 205 |
+
labels=None,
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
logits = outputs["logits"]
|
| 209 |
+
|
| 210 |
+
# Get logits for next token
|
| 211 |
+
next_logits = logits[:, -1, :] / temperature
|
| 212 |
+
|
| 213 |
+
# Apply repetition penalty to previously generated tokens
|
| 214 |
+
if repetition_penalty != 1.0:
|
| 215 |
+
generated_ids = input_ids[0, prompt_length:]
|
| 216 |
+
unique_ids = torch.unique(generated_ids)
|
| 217 |
+
for token_id in unique_ids:
|
| 218 |
+
next_logits[0, token_id] /= repetition_penalty
|
| 219 |
+
|
| 220 |
+
# Apply top-k and top-p sampling
|
| 221 |
+
if top_k is not None:
|
| 222 |
+
indices_to_remove = next_logits < torch.topk(next_logits, top_k)[0][..., -1, None]
|
| 223 |
+
next_logits[indices_to_remove] = float('-inf')
|
| 224 |
+
|
| 225 |
+
if top_p < 1.0:
|
| 226 |
+
sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
|
| 227 |
+
probs = torch.softmax(sorted_logits, dim=-1)
|
| 228 |
+
cumsum_probs = torch.cumsum(probs, dim=-1)
|
| 229 |
+
|
| 230 |
+
sorted_indices_to_remove = cumsum_probs > top_p
|
| 231 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 232 |
+
sorted_indices_to_remove[..., 0] = False
|
| 233 |
+
|
| 234 |
+
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
| 235 |
+
next_logits[:, indices_to_remove] = float('-inf')
|
| 236 |
+
|
| 237 |
+
# Sample or greedy
|
| 238 |
+
probs = torch.softmax(next_logits, dim=-1)
|
| 239 |
+
|
| 240 |
+
if do_sample:
|
| 241 |
+
next_token = torch.multinomial(probs, num_samples=1)
|
| 242 |
+
else:
|
| 243 |
+
next_token = torch.argmax(next_logits, dim=-1, keepdim=True)
|
| 244 |
+
|
| 245 |
+
# Append to input
|
| 246 |
+
input_ids = torch.cat([input_ids, next_token], dim=-1)
|
| 247 |
+
|
| 248 |
+
# Stream if requested (with full context decoding to preserve spaces)
|
| 249 |
+
if stream:
|
| 250 |
+
# Accumulate the generated token ID
|
| 251 |
+
generated_token_ids.append(next_token.item())
|
| 252 |
+
# Decode entire accumulated sequence (tokenizer has full context)
|
| 253 |
+
full_decoded_text = self.tokenizer.decode(generated_token_ids)
|
| 254 |
+
# Extract only NEW text since last yield
|
| 255 |
+
new_text = full_decoded_text[len(last_decoded_full):]
|
| 256 |
+
if new_text:
|
| 257 |
+
yield new_text
|
| 258 |
+
last_decoded_full = full_decoded_text
|
| 259 |
+
|
| 260 |
+
# Stop on EOS
|
| 261 |
+
if next_token.item() == self.tokenizer.eos_token_id:
|
| 262 |
+
break
|
| 263 |
+
|
| 264 |
+
if not stream:
|
| 265 |
+
# Return full generated text
|
| 266 |
+
generated_ids = input_ids[0, prompt_length:]
|
| 267 |
+
return self.tokenizer.decode(generated_ids, skip_special_tokens=True)
|
| 268 |
+
|
| 269 |
+
def count_tokens_generated(
|
| 270 |
+
self,
|
| 271 |
+
prompt: str,
|
| 272 |
+
max_length: int = 256,
|
| 273 |
+
) -> torch.Tensor:
|
| 274 |
+
"""
|
| 275 |
+
Measure generation speed (tokens per second).
|
| 276 |
+
|
| 277 |
+
Args:
|
| 278 |
+
prompt: Input prompt
|
| 279 |
+
max_length: Maximum generation length
|
| 280 |
+
|
| 281 |
+
Returns:
|
| 282 |
+
Number of tokens generated
|
| 283 |
+
"""
|
| 284 |
+
import time
|
| 285 |
+
|
| 286 |
+
start = time.time()
|
| 287 |
+
|
| 288 |
+
# Generate (we'll just do one forward pass to measure)
|
| 289 |
+
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
|
| 290 |
+
|
| 291 |
+
with torch.no_grad():
|
| 292 |
+
outputs = self.model(
|
| 293 |
+
input_ids=input_ids,
|
| 294 |
+
attention_mask=None,
|
| 295 |
+
labels=None,
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
elapsed = time.time() - start
|
| 299 |
+
tokens_per_sec = (input_ids.shape[1] + 1) / elapsed
|
| 300 |
+
|
| 301 |
+
return tokens_per_sec
|
code/TaoTrain/src/taoTrain/inference/tui.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""TUI (Terminal User Interface) for interactive chat."""
|
| 2 |
+
|
| 3 |
+
import sys
|
| 4 |
+
import time
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Optional
|
| 7 |
+
import click
|
| 8 |
+
from rich.console import Console
|
| 9 |
+
from rich.markdown import Markdown
|
| 10 |
+
from rich.panel import Panel
|
| 11 |
+
from rich.text import Text
|
| 12 |
+
from rich.table import Table
|
| 13 |
+
from textual.app import ComposeResult, RenderableType
|
| 14 |
+
from textual.containers import Container, Horizontal, Vertical
|
| 15 |
+
from textual.widgets import TextArea, Static, Button
|
| 16 |
+
from textual.binding import Binding
|
| 17 |
+
|
| 18 |
+
from taoTrain.inference import Inferencer
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class TokensPerSecDisplay(Static):
|
| 22 |
+
"""Display tokens per second metric."""
|
| 23 |
+
|
| 24 |
+
DEFAULT_CSS = """
|
| 25 |
+
TokensPerSecDisplay {
|
| 26 |
+
width: 100%;
|
| 27 |
+
height: 1;
|
| 28 |
+
background: $panel;
|
| 29 |
+
border: solid $accent;
|
| 30 |
+
}
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(self, tps: float = 0.0):
|
| 34 |
+
"""Initialize."""
|
| 35 |
+
super().__init__()
|
| 36 |
+
self.tps = tps
|
| 37 |
+
|
| 38 |
+
def render(self) -> RenderableType:
|
| 39 |
+
"""Render TPS display."""
|
| 40 |
+
text = f"Tokens/sec: {self.tps:.2f}"
|
| 41 |
+
return Text(text, style="bold cyan")
|
| 42 |
+
|
| 43 |
+
def update_tps(self, tps: float):
|
| 44 |
+
"""Update TPS value."""
|
| 45 |
+
self.tps = tps
|
| 46 |
+
self.update()
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class SimpleChat:
|
| 50 |
+
"""Simple CLI-based chat interface (fallback for testing)."""
|
| 51 |
+
|
| 52 |
+
def __init__(self, checkpoint_path: str | Path, tokenizer_path: Optional[str | Path] = None):
|
| 53 |
+
"""Initialize chat."""
|
| 54 |
+
self.checkpoint_path = Path(checkpoint_path)
|
| 55 |
+
self.tokenizer_path = tokenizer_path
|
| 56 |
+
|
| 57 |
+
print("\nLoading model...")
|
| 58 |
+
self.inferencer = Inferencer.load_from_checkpoint(
|
| 59 |
+
self.checkpoint_path,
|
| 60 |
+
tokenizer_path=self.tokenizer_path,
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# Print model info
|
| 64 |
+
console = Console()
|
| 65 |
+
info_table = Table(title="Model Information")
|
| 66 |
+
info_table.add_column("Property", style="cyan")
|
| 67 |
+
info_table.add_column("Value", style="green")
|
| 68 |
+
|
| 69 |
+
info_table.add_row("Checkpoint", str(self.checkpoint_path))
|
| 70 |
+
if self.tokenizer_path:
|
| 71 |
+
info_table.add_row("Tokenizer (override)", str(self.tokenizer_path))
|
| 72 |
+
|
| 73 |
+
console.print(info_table)
|
| 74 |
+
|
| 75 |
+
def run(self):
|
| 76 |
+
"""Run chat loop."""
|
| 77 |
+
console = Console()
|
| 78 |
+
|
| 79 |
+
console.print("\n[bold cyan]Chat Interface[/bold cyan]")
|
| 80 |
+
console.print("[dim]Type 'exit' or 'quit' to exit[/dim]\n")
|
| 81 |
+
|
| 82 |
+
while True:
|
| 83 |
+
try:
|
| 84 |
+
# Get user input
|
| 85 |
+
prompt = input("You: ").strip()
|
| 86 |
+
|
| 87 |
+
if prompt.lower() in ["exit", "quit"]:
|
| 88 |
+
console.print("\n[yellow]Goodbye![/yellow]")
|
| 89 |
+
break
|
| 90 |
+
|
| 91 |
+
if not prompt:
|
| 92 |
+
continue
|
| 93 |
+
|
| 94 |
+
# Generate response
|
| 95 |
+
console.print("\n[bold cyan]Assistant:[/bold cyan] ", end="")
|
| 96 |
+
|
| 97 |
+
start_time = time.time()
|
| 98 |
+
token_count = 0
|
| 99 |
+
|
| 100 |
+
# Stream generation
|
| 101 |
+
for token in self.inferencer.generate(
|
| 102 |
+
prompt,
|
| 103 |
+
max_length=256,
|
| 104 |
+
temperature=0.7,
|
| 105 |
+
top_p=0.95,
|
| 106 |
+
repetition_penalty=10,
|
| 107 |
+
stream=True,
|
| 108 |
+
):
|
| 109 |
+
console.print(token, end="", soft_wrap=True)
|
| 110 |
+
token_count += 1
|
| 111 |
+
|
| 112 |
+
elapsed = time.time() - start_time
|
| 113 |
+
tps = token_count / elapsed if elapsed > 0 else 0
|
| 114 |
+
|
| 115 |
+
console.print(f"\n\n[dim]({tps:.1f} tokens/sec, {token_count} tokens)[/dim]\n")
|
| 116 |
+
|
| 117 |
+
except KeyboardInterrupt:
|
| 118 |
+
console.print("\n\n[yellow]Chat interrupted.[/yellow]")
|
| 119 |
+
break
|
| 120 |
+
except Exception as e:
|
| 121 |
+
console.print(f"\n[red]Error: {e}[/red]\n")
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
@click.command()
|
| 125 |
+
@click.option(
|
| 126 |
+
"--model",
|
| 127 |
+
type=click.Path(exists=True),
|
| 128 |
+
required=True,
|
| 129 |
+
help="Path to model checkpoint (.pt file)",
|
| 130 |
+
)
|
| 131 |
+
@click.option(
|
| 132 |
+
"--tokenizer",
|
| 133 |
+
type=click.Path(exists=True),
|
| 134 |
+
required=False,
|
| 135 |
+
default=None,
|
| 136 |
+
help="Path to tokenizer file (.model or HuggingFace path). If not provided, uses tokenizer_path from checkpoint config.",
|
| 137 |
+
)
|
| 138 |
+
def main(model: str, tokenizer: Optional[str]):
|
| 139 |
+
"""
|
| 140 |
+
Interactive TUI chat with a trained model.
|
| 141 |
+
|
| 142 |
+
Example:
|
| 143 |
+
tui-chat --model checkpoints/best_model.pt
|
| 144 |
+
tui-chat --model checkpoints/best_model.pt --tokenizer path/to/tokenizer.model
|
| 145 |
+
"""
|
| 146 |
+
try:
|
| 147 |
+
chat = SimpleChat(model, tokenizer_path=tokenizer)
|
| 148 |
+
chat.run()
|
| 149 |
+
except FileNotFoundError:
|
| 150 |
+
click.echo(f"Error: Model file not found: {model}", err=True)
|
| 151 |
+
sys.exit(1)
|
| 152 |
+
except ValueError as e:
|
| 153 |
+
click.echo(f"Error: {e}", err=True)
|
| 154 |
+
sys.exit(1)
|
| 155 |
+
except Exception as e:
|
| 156 |
+
click.echo(f"Error: {e}", err=True)
|
| 157 |
+
sys.exit(1)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
if __name__ == "__main__":
|
| 161 |
+
main() # type: ignore
|
code/TaoTrain/src/taoTrain/logging/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Logging integrations."""
|
| 2 |
+
|
| 3 |
+
from .aim_logger import AimLogger
|
| 4 |
+
|
| 5 |
+
__all__ = ["AimLogger"]
|
code/TaoTrain/src/taoTrain/logging/aim_logger.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""AimStack logging integration."""
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Dict, Any, Optional
|
| 5 |
+
import subprocess
|
| 6 |
+
import json
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
from aim import Run
|
| 11 |
+
HAS_AIM = True
|
| 12 |
+
except ImportError:
|
| 13 |
+
HAS_AIM = False
|
| 14 |
+
|
| 15 |
+
from taoTrain.config import TrainingConfig
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class AimLogger:
|
| 19 |
+
"""AimStack logger for tracking training metrics and hyperparameters."""
|
| 20 |
+
|
| 21 |
+
def __init__(self, config: TrainingConfig):
|
| 22 |
+
"""
|
| 23 |
+
Initialize AimStack logger.
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
config: Training configuration
|
| 27 |
+
"""
|
| 28 |
+
self.config = config
|
| 29 |
+
self.run: Optional[Run] = None
|
| 30 |
+
|
| 31 |
+
if HAS_AIM:
|
| 32 |
+
# Initialize AimStack run
|
| 33 |
+
repo_path = Path(config.aim_repo)
|
| 34 |
+
repo_path.mkdir(parents=True, exist_ok=True)
|
| 35 |
+
|
| 36 |
+
self.run = Run(repo=str(repo_path))
|
| 37 |
+
|
| 38 |
+
# Log hyperparameters
|
| 39 |
+
self._log_hyperparameters()
|
| 40 |
+
else:
|
| 41 |
+
print("Warning: AimStack not installed. Install with: pip install aim")
|
| 42 |
+
|
| 43 |
+
def _log_hyperparameters(self):
|
| 44 |
+
"""Log hyperparameters to AimStack."""
|
| 45 |
+
if self.run is None:
|
| 46 |
+
return
|
| 47 |
+
|
| 48 |
+
# Log model config
|
| 49 |
+
self.run["hparams/model"] = {
|
| 50 |
+
"architecture": self.config.model.architecture_type.value,
|
| 51 |
+
"vocab_size": self.config.model.vocab_size,
|
| 52 |
+
"hidden_dim": self.config.model.hidden_dim,
|
| 53 |
+
"num_layers": self.config.model.num_layers,
|
| 54 |
+
"num_heads": self.config.model.num_heads,
|
| 55 |
+
"dropout": self.config.model.dropout,
|
| 56 |
+
"max_seq_length": self.config.model.max_seq_length,
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
# Log training config
|
| 60 |
+
self.run["hparams/training"] = {
|
| 61 |
+
"batch_size": self.config.batch_size,
|
| 62 |
+
"num_epochs": self.config.num_epochs,
|
| 63 |
+
"learning_rate": self.config.optimizer.learning_rate,
|
| 64 |
+
"weight_decay": self.config.optimizer.weight_decay,
|
| 65 |
+
"gradient_accumulation_steps": self.config.gradient_accumulation_steps,
|
| 66 |
+
"max_grad_norm": self.config.max_grad_norm,
|
| 67 |
+
"dtype": self.config.dtype.value,
|
| 68 |
+
"seed": self.config.seed,
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
# Log optimizer and scheduler config
|
| 72 |
+
self.run["hparams/optimizer"] = {
|
| 73 |
+
"optimizer_type": self.config.optimizer.optimizer_type.value,
|
| 74 |
+
"learning_rate": self.config.optimizer.learning_rate,
|
| 75 |
+
"weight_decay": self.config.optimizer.weight_decay,
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
self.run["hparams/scheduler"] = {
|
| 79 |
+
"scheduler_type": self.config.scheduler.scheduler_type.value,
|
| 80 |
+
"warmup_steps": self.config.scheduler.warmup_steps,
|
| 81 |
+
"warmup_ratio": self.config.scheduler.warmup_ratio,
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
# Log dataset config
|
| 85 |
+
self.run["hparams/dataset"] = {
|
| 86 |
+
"dataset_name": self.config.dataset.dataset_name,
|
| 87 |
+
"split": self.config.dataset.split,
|
| 88 |
+
"max_samples": self.config.dataset.max_samples,
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
# Log mode
|
| 92 |
+
self.run["hparams/mode"] = self.config.mode.value
|
| 93 |
+
|
| 94 |
+
# Log git hash if available
|
| 95 |
+
try:
|
| 96 |
+
git_hash = subprocess.check_output(
|
| 97 |
+
["git", "rev-parse", "HEAD"],
|
| 98 |
+
stderr=subprocess.DEVNULL
|
| 99 |
+
).decode().strip()
|
| 100 |
+
self.run["hparams/git_hash"] = git_hash
|
| 101 |
+
except:
|
| 102 |
+
pass
|
| 103 |
+
|
| 104 |
+
# Log timestamp
|
| 105 |
+
self.run["hparams/timestamp"] = datetime.now().isoformat()
|
| 106 |
+
|
| 107 |
+
def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None):
|
| 108 |
+
"""
|
| 109 |
+
Log metrics to AimStack.
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
metrics: Dict of metric names to values
|
| 113 |
+
step: Global step (optional, auto-increments if not provided)
|
| 114 |
+
"""
|
| 115 |
+
if self.run is None:
|
| 116 |
+
return
|
| 117 |
+
|
| 118 |
+
step = metrics.pop("step", step)
|
| 119 |
+
|
| 120 |
+
for metric_name, metric_value in metrics.items():
|
| 121 |
+
# Flatten nested dicts
|
| 122 |
+
if isinstance(metric_value, dict):
|
| 123 |
+
for nested_key, nested_val in metric_value.items():
|
| 124 |
+
self.run.track(
|
| 125 |
+
float(nested_val),
|
| 126 |
+
name=f"{metric_name}/{nested_key}",
|
| 127 |
+
step=step,
|
| 128 |
+
)
|
| 129 |
+
else:
|
| 130 |
+
try:
|
| 131 |
+
self.run.track(
|
| 132 |
+
float(metric_value),
|
| 133 |
+
name=metric_name,
|
| 134 |
+
step=step,
|
| 135 |
+
)
|
| 136 |
+
except (ValueError, TypeError):
|
| 137 |
+
# Skip non-numeric metrics
|
| 138 |
+
pass
|
| 139 |
+
|
| 140 |
+
def log_text(self, name: str, value: str, step: Optional[int] = None):
|
| 141 |
+
"""Log text content."""
|
| 142 |
+
if self.run is None:
|
| 143 |
+
return
|
| 144 |
+
|
| 145 |
+
# AimStack doesn't have direct text logging, use metadata
|
| 146 |
+
metadata = getattr(self.run, '_metadata', {})
|
| 147 |
+
if isinstance(metadata, dict):
|
| 148 |
+
metadata[name] = value
|
| 149 |
+
|
| 150 |
+
def finish(self):
|
| 151 |
+
"""Finish the run."""
|
| 152 |
+
if self.run:
|
| 153 |
+
self.run.close()
|
code/TaoTrain/src/taoTrain/models/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model architectures and registry."""
|
| 2 |
+
|
| 3 |
+
from .registry import get_model, register_architecture
|
| 4 |
+
|
| 5 |
+
__all__ = ["get_model", "register_architecture"]
|
code/TaoTrain/src/taoTrain/models/embeddings.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Low-Rank Factorized Embedding.
|
| 3 |
+
|
| 4 |
+
Uses standard nn.Linear for projection (NOT ternary quantization).
|
| 5 |
+
Embeddings should use full precision for good token representations.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class FactorizedEmbedding(nn.Module):
|
| 13 |
+
"""
|
| 14 |
+
Low-Rank Factorized Embedding: vocab → d_embed_rank → d_model
|
| 15 |
+
|
| 16 |
+
Uses standard Linear layers (no quantization) for full precision.
|
| 17 |
+
Reduces embedding parameters from vocab_size × d_model to:
|
| 18 |
+
vocab_size × d_embed_rank + d_embed_rank × d_model
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(self, vocab_size, d_model, d_embed_rank=96):
|
| 22 |
+
super().__init__()
|
| 23 |
+
self.vocab_size = vocab_size
|
| 24 |
+
self.d_model = d_model
|
| 25 |
+
self.d_embed_rank = d_embed_rank
|
| 26 |
+
|
| 27 |
+
# Embedding table: vocab → compressed rank
|
| 28 |
+
self.embed = nn.Embedding(vocab_size, d_embed_rank)
|
| 29 |
+
|
| 30 |
+
# Projection: compressed → full (standard Linear)
|
| 31 |
+
self.proj = nn.Linear(d_embed_rank, d_model, bias=False)
|
| 32 |
+
|
| 33 |
+
# Initialize with small weights for stable training
|
| 34 |
+
nn.init.normal_(self.embed.weight, mean=0.0, std=0.02)
|
| 35 |
+
nn.init.normal_(self.proj.weight, mean=0.0, std=0.02)
|
| 36 |
+
|
| 37 |
+
def forward(self, input_ids):
|
| 38 |
+
"""
|
| 39 |
+
Args:
|
| 40 |
+
input_ids: [batch_size, seq_len] tensor of token IDs
|
| 41 |
+
|
| 42 |
+
Returns:
|
| 43 |
+
embeddings: [batch_size, seq_len, d_model]
|
| 44 |
+
"""
|
| 45 |
+
x = self.embed(input_ids) # [B, S, d_embed_rank]
|
| 46 |
+
x = self.proj(x) # [B, S, d_model]
|
| 47 |
+
return x
|
| 48 |
+
|
| 49 |
+
def get_num_params(self):
|
| 50 |
+
"""Return total number of parameters."""
|
| 51 |
+
return self.vocab_size * self.d_embed_rank + self.d_embed_rank * self.d_model
|
code/TaoTrain/src/taoTrain/models/mla_components.py
ADDED
|
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DeepSeek-style Multi-head Latent Attention (MLA) with RoPE.
|
| 3 |
+
|
| 4 |
+
Key innovations:
|
| 5 |
+
1. KV compression to latent space (reduce KV memory)
|
| 6 |
+
2. Q stays in full dimension for expressive query space
|
| 7 |
+
3. RoPE positional embeddings on Q and K
|
| 8 |
+
4. Grouped Query Attention (GQA) for efficiency
|
| 9 |
+
5. Learnable head combination weights
|
| 10 |
+
6. Numerical stability via pre-norm and scaling
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
import math
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _residual_rms_norm(x, enabled=False, target=1.0, eps=1e-6, cap=None):
|
| 20 |
+
if not enabled and cap is None:
|
| 21 |
+
return x
|
| 22 |
+
rms = x.float().square().mean(dim=-1, keepdim=True).add(eps).sqrt()
|
| 23 |
+
if enabled:
|
| 24 |
+
scale = target / rms
|
| 25 |
+
else:
|
| 26 |
+
cap_tensor = torch.tensor(float(cap), dtype=rms.dtype, device=rms.device)
|
| 27 |
+
scale = torch.minimum(torch.ones_like(rms), cap_tensor / rms)
|
| 28 |
+
return x * scale.to(dtype=x.dtype)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class RotaryEmbedding(nn.Module):
|
| 32 |
+
"""Rotary position embeddings used in RoPE with optional YaRN extension.
|
| 33 |
+
|
| 34 |
+
YaRN (Yet another RoPE eXtension) allows context length interpolation via
|
| 35 |
+
frequency scaling. When yarn_alpha != 1.0 or seq_len > max_seq_length,
|
| 36 |
+
frequencies are dynamically scaled to support longer sequences.
|
| 37 |
+
|
| 38 |
+
Parameters:
|
| 39 |
+
dim: Embedding dimension (must be even)
|
| 40 |
+
rope_scale: Base RoPE scale factor (default: 40)
|
| 41 |
+
max_seq_length: Original trained sequence length (default: 1024)
|
| 42 |
+
yarn_alpha: YaRN interpolation factor (default: 1.0, no interpolation)
|
| 43 |
+
- values < 1.0: aggressive interpolation (faster context expansion)
|
| 44 |
+
- values > 1.0: conservative interpolation (safer)
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
def __init__(self, dim, rope_scale=40.0, max_seq_length=1024, yarn_alpha=1.0):
|
| 48 |
+
super().__init__()
|
| 49 |
+
assert dim % 2 == 0, "Dimension must be even for rotary embeddings"
|
| 50 |
+
self.dim = dim
|
| 51 |
+
self.rope_scale = rope_scale
|
| 52 |
+
self.max_seq_length = max_seq_length
|
| 53 |
+
self.yarn_alpha = yarn_alpha
|
| 54 |
+
|
| 55 |
+
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
|
| 56 |
+
self.register_buffer("inv_freq", inv_freq)
|
| 57 |
+
|
| 58 |
+
def _apply_yarn_scaling(self, freqs, seq_len):
|
| 59 |
+
"""Apply YaRN frequency scaling for context extension.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
freqs: [seq_len, dim] frequency tensor
|
| 63 |
+
seq_len: Current sequence length
|
| 64 |
+
|
| 65 |
+
Returns:
|
| 66 |
+
Scaled freqs if yarn is enabled and seq_len > max_seq_length, else original freqs
|
| 67 |
+
"""
|
| 68 |
+
# Only apply scaling if sequence exceeds training length or yarn_alpha != 1.0
|
| 69 |
+
if self.yarn_alpha == 1.0 and seq_len <= self.max_seq_length:
|
| 70 |
+
return freqs
|
| 71 |
+
|
| 72 |
+
# YaRN scaling factor: interpolate frequency reduction
|
| 73 |
+
# scale_factor = (seq_len / max_seq_length) ** (1 / yarn_alpha)
|
| 74 |
+
# Scales down frequencies to fit longer context while maintaining position distinctions
|
| 75 |
+
scale_factor = (seq_len / self.max_seq_length) ** (1.0 / self.yarn_alpha)
|
| 76 |
+
freqs = freqs / scale_factor
|
| 77 |
+
return freqs
|
| 78 |
+
|
| 79 |
+
def forward(self, seq_len, device):
|
| 80 |
+
"""Generate rotary embeddings for sequence with optional YaRN scaling.
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
seq_len: Current sequence length
|
| 84 |
+
device: Device to create embeddings on
|
| 85 |
+
|
| 86 |
+
Returns:
|
| 87 |
+
[seq_len, 2*dim] rotary embeddings (duplicated freqs)
|
| 88 |
+
"""
|
| 89 |
+
t = torch.arange(seq_len, device=device).type_as(self.inv_freq) / self.rope_scale
|
| 90 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq) # [seq_len, dim//2]
|
| 91 |
+
|
| 92 |
+
# Apply YaRN frequency scaling if enabled
|
| 93 |
+
freqs = self._apply_yarn_scaling(freqs, seq_len)
|
| 94 |
+
|
| 95 |
+
return torch.cat((freqs, freqs), dim=-1) # [seq_len, dim]
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def rotate_half(x):
|
| 99 |
+
"""Rotate half the hidden dims of the input."""
|
| 100 |
+
x1, x2 = x.chunk(2, dim=-1)
|
| 101 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def apply_rotary(x, cos, sin):
|
| 105 |
+
"""Apply rotary embeddings to input tensor.
|
| 106 |
+
|
| 107 |
+
Args:
|
| 108 |
+
x: [B, n_heads, seq_len, head_dim] or similar
|
| 109 |
+
cos: [seq_len, head_dim] or [1, 1, seq_len, head_dim]
|
| 110 |
+
sin: [seq_len, head_dim] or [1, 1, seq_len, head_dim]
|
| 111 |
+
"""
|
| 112 |
+
# Ensure cos/sin have the right dimensions for broadcasting
|
| 113 |
+
if cos.dim() == 2:
|
| 114 |
+
cos = cos.unsqueeze(0).unsqueeze(0)
|
| 115 |
+
sin = sin.unsqueeze(0).unsqueeze(0)
|
| 116 |
+
|
| 117 |
+
# Handle case where cos/sin may be shorter than x
|
| 118 |
+
cos = cos[..., :x.shape[-1]]
|
| 119 |
+
sin = sin[..., :x.shape[-1]]
|
| 120 |
+
|
| 121 |
+
# Split x based on cos dimensions
|
| 122 |
+
x_rot = x[..., :cos.shape[-1]]
|
| 123 |
+
x_base = x[..., cos.shape[-1]:]
|
| 124 |
+
|
| 125 |
+
# Apply rotation
|
| 126 |
+
x_rot = (x_rot * cos) + (rotate_half(x_rot) * sin)
|
| 127 |
+
|
| 128 |
+
# Concatenate rotated and base parts
|
| 129 |
+
return torch.cat([x_rot, x_base], dim=-1) if x_base.shape[-1] > 0 else x_rot
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class DeepSeekMLA(nn.Module):
|
| 133 |
+
"""
|
| 134 |
+
DeepSeek-style Multi-head Latent Attention (MLA).
|
| 135 |
+
|
| 136 |
+
Architecture:
|
| 137 |
+
1. Project input to Query: [B, seq_len, d_model] -> [B, seq_len, d_model]
|
| 138 |
+
2. Compress to KV latent: [B, seq_len, d_model] -> [B, seq_len, d_latent_kv]
|
| 139 |
+
3. Split into heads for attention
|
| 140 |
+
4. Apply RoPE to Q and K
|
| 141 |
+
5. Compute attention scores: (Q @ K^T) / sqrt(d_head)
|
| 142 |
+
6. Apply softmax and combine with values
|
| 143 |
+
7. Concatenate heads and project back to d_model
|
| 144 |
+
|
| 145 |
+
Parameters:
|
| 146 |
+
d_model: Model dimension
|
| 147 |
+
d_latent_kv: Latent dimension for KV compression
|
| 148 |
+
n_heads: Number of attention heads
|
| 149 |
+
d_rope: Dimension for RoPE (usually == d_head_dim)
|
| 150 |
+
dropout: Dropout probability
|
| 151 |
+
gqa_groups: Grouped Query Attention groups (1 = standard MLA, >1 = GQA)
|
| 152 |
+
"""
|
| 153 |
+
|
| 154 |
+
def __init__(self, d_model, d_latent_kv, n_heads, d_rope, dropout=0.1, gqa_groups=1,
|
| 155 |
+
rope_scale=40.0, max_seq_length=1024, yarn_alpha=1.0):
|
| 156 |
+
super().__init__()
|
| 157 |
+
self.d_model = d_model
|
| 158 |
+
self.d_latent_kv = d_latent_kv
|
| 159 |
+
self.n_heads = n_heads
|
| 160 |
+
self.d_rope = d_rope
|
| 161 |
+
self.gqa_groups = gqa_groups
|
| 162 |
+
|
| 163 |
+
assert d_model % n_heads == 0, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
|
| 164 |
+
assert d_latent_kv % n_heads == 0, f"d_latent_kv ({d_latent_kv}) must be divisible by n_heads ({n_heads})"
|
| 165 |
+
|
| 166 |
+
self.d_head_full = d_model // n_heads # Full head dimension for Q
|
| 167 |
+
self.d_head_latent = d_latent_kv // n_heads # Latent head dimension for K/V
|
| 168 |
+
|
| 169 |
+
# Scaling factor for attention scores
|
| 170 |
+
self.scale = 1.0 / math.sqrt(self.d_head_latent)
|
| 171 |
+
|
| 172 |
+
# Layer norm before attention for stability
|
| 173 |
+
self.norm = nn.LayerNorm(d_model)
|
| 174 |
+
|
| 175 |
+
# Q projection: d_model -> d_model (full dimension)
|
| 176 |
+
self.q_proj = nn.Linear(d_model, d_model, bias=False)
|
| 177 |
+
|
| 178 |
+
# K/V projections: d_model -> d_latent_kv (compressed)
|
| 179 |
+
self.k_proj = nn.Linear(d_model, d_latent_kv, bias=False)
|
| 180 |
+
self.v_proj = nn.Linear(d_model, d_latent_kv, bias=False)
|
| 181 |
+
|
| 182 |
+
# RoPE for position encoding with YaRN support
|
| 183 |
+
self.rotary = RotaryEmbedding(
|
| 184 |
+
d_rope,
|
| 185 |
+
rope_scale=rope_scale,
|
| 186 |
+
max_seq_length=max_seq_length,
|
| 187 |
+
yarn_alpha=yarn_alpha
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
# Output projection: d_latent_kv -> d_model
|
| 191 |
+
self.out_proj = nn.Linear(d_latent_kv, d_model, bias=False)
|
| 192 |
+
|
| 193 |
+
# Head combination weights (learnable scaling per head)
|
| 194 |
+
self.head_weights = nn.Parameter(torch.ones(n_heads))
|
| 195 |
+
|
| 196 |
+
# Dropout
|
| 197 |
+
self.attn_dropout = nn.Dropout(dropout)
|
| 198 |
+
self.proj_dropout = nn.Dropout(dropout)
|
| 199 |
+
|
| 200 |
+
def forward(self, x, attention_mask=None):
|
| 201 |
+
"""
|
| 202 |
+
Args:
|
| 203 |
+
x: [B, seq_len, d_model]
|
| 204 |
+
attention_mask: [B, seq_len] (1 = keep, 0 = mask) or
|
| 205 |
+
[B, 1, seq_len, seq_len] (causal mask)
|
| 206 |
+
|
| 207 |
+
Returns:
|
| 208 |
+
out: [B, seq_len, d_model]
|
| 209 |
+
"""
|
| 210 |
+
B, seq_len, _ = x.shape
|
| 211 |
+
device = x.device
|
| 212 |
+
|
| 213 |
+
# Pre-norm
|
| 214 |
+
x_norm = self.norm(x)
|
| 215 |
+
|
| 216 |
+
# Project to Q, K, V spaces
|
| 217 |
+
q = self.q_proj(x_norm) # [B, seq_len, d_model]
|
| 218 |
+
k = self.k_proj(x_norm) # [B, seq_len, d_latent_kv]
|
| 219 |
+
v = self.v_proj(x_norm) # [B, seq_len, d_latent_kv]
|
| 220 |
+
|
| 221 |
+
# ────────────────────────────────────────────────────────────────────────
|
| 222 |
+
# Reshape into multi-head format
|
| 223 |
+
# ────────────────────────────────────────────────────────────────────────
|
| 224 |
+
# Q: [B, seq_len, d_model] -> [B, seq_len, n_heads, d_head_full] -> [B, n_heads, seq_len, d_head_full]
|
| 225 |
+
q = q.view(B, seq_len, self.n_heads, self.d_head_full).transpose(1, 2)
|
| 226 |
+
|
| 227 |
+
# K: [B, seq_len, d_latent_kv] -> [B, seq_len, n_heads, d_head_latent] -> [B, n_heads, seq_len, d_head_latent]
|
| 228 |
+
k = k.view(B, seq_len, self.n_heads, self.d_head_latent).transpose(1, 2)
|
| 229 |
+
|
| 230 |
+
# V: [B, seq_len, d_latent_kv] -> [B, seq_len, n_heads, d_head_latent] -> [B, n_heads, seq_len, d_head_latent]
|
| 231 |
+
v = v.view(B, seq_len, self.n_heads, self.d_head_latent).transpose(1, 2)
|
| 232 |
+
|
| 233 |
+
# ────────────────────────────────────────────────────────────────────────
|
| 234 |
+
# Apply RoPE to Q and K
|
| 235 |
+
# ─────────────────────────────────��──────────────────────────────────────
|
| 236 |
+
if self.d_rope > 0:
|
| 237 |
+
# Generate RoPE embeddings: [seq_len, d_rope]
|
| 238 |
+
rotary_emb = self.rotary(seq_len, device) # [seq_len, d_rope]
|
| 239 |
+
cos = torch.cos(rotary_emb).unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, d_rope]
|
| 240 |
+
sin = torch.sin(rotary_emb).unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, d_rope]
|
| 241 |
+
|
| 242 |
+
# Apply RoPE to Q (only on first d_rope dimensions)
|
| 243 |
+
q_rope = apply_rotary(q[..., :self.d_rope], cos, sin) # [B, n_heads, seq_len, d_rope]
|
| 244 |
+
q = torch.cat([q_rope, q[..., self.d_rope:]], dim=-1) # Combine with remaining dims
|
| 245 |
+
|
| 246 |
+
# Apply RoPE to K (only on first d_rope dimensions)
|
| 247 |
+
k_rope = apply_rotary(k[..., :self.d_rope], cos, sin) # [B, n_heads, seq_len, d_rope]
|
| 248 |
+
k = torch.cat([k_rope, k[..., self.d_rope:]], dim=-1) # Combine with remaining dims
|
| 249 |
+
|
| 250 |
+
# ────────────────────────────────────────────────────────────────────────
|
| 251 |
+
# Compute attention using PyTorch 2.0+ fused scaled_dot_product_attention
|
| 252 |
+
# ────────────────────────────────────────────────────────────────────────
|
| 253 |
+
# Only use first d_head_latent dimensions of Q for attention
|
| 254 |
+
# K and V are already d_head_latent dimension
|
| 255 |
+
q_for_attn = q[..., :self.d_head_latent] # [B, n_heads, seq_len, d_head_latent]
|
| 256 |
+
|
| 257 |
+
# Convert attention mask to boolean format for scaled_dot_product_attention
|
| 258 |
+
# Input mask: 0 = mask (don't attend), 1 = keep (attend)
|
| 259 |
+
# Boolean mask: False = mask, True = attend
|
| 260 |
+
attn_mask_bool = None
|
| 261 |
+
if attention_mask is not None:
|
| 262 |
+
if attention_mask.dim() == 2:
|
| 263 |
+
# [B, seq_len] with {0, 1} -> [B, 1, 1, seq_len] with {False, True}
|
| 264 |
+
attn_mask_bool = attention_mask.bool().unsqueeze(1).unsqueeze(1)
|
| 265 |
+
else:
|
| 266 |
+
# Already 4D [B, 1, seq_len, seq_len], just convert to bool
|
| 267 |
+
attn_mask_bool = attention_mask.bool()
|
| 268 |
+
|
| 269 |
+
# Get dropout probability (0.0 when not training)
|
| 270 |
+
dropout_p = self.attn_dropout.p if self.training else 0.0
|
| 271 |
+
|
| 272 |
+
if hasattr(F, "scaled_dot_product_attention"):
|
| 273 |
+
# Apply fused attention operation when available.
|
| 274 |
+
out_heads = F.scaled_dot_product_attention(
|
| 275 |
+
q_for_attn, k, v,
|
| 276 |
+
attn_mask=attn_mask_bool,
|
| 277 |
+
dropout_p=dropout_p,
|
| 278 |
+
scale=None
|
| 279 |
+
) # [B, n_heads, seq_len, d_head_latent]
|
| 280 |
+
else:
|
| 281 |
+
scores = torch.matmul(q_for_attn, k.transpose(-2, -1)) * self.scale
|
| 282 |
+
if attn_mask_bool is not None:
|
| 283 |
+
scores = scores.masked_fill(~attn_mask_bool, torch.finfo(scores.dtype).min)
|
| 284 |
+
attn_weights = F.softmax(scores, dim=-1)
|
| 285 |
+
if dropout_p > 0.0:
|
| 286 |
+
attn_weights = F.dropout(attn_weights, p=dropout_p, training=True)
|
| 287 |
+
out_heads = torch.matmul(attn_weights, v)
|
| 288 |
+
|
| 289 |
+
# ────────────────────────────────────────────────────────────────────────
|
| 290 |
+
# Concatenate heads
|
| 291 |
+
# ────────────────────────────────────────────────────────────────────────
|
| 292 |
+
# [B, seq_len, n_heads, d_head_latent] -> [B, seq_len, d_latent_kv]
|
| 293 |
+
out_concat = out_heads.transpose(1, 2).reshape(B, seq_len, self.d_latent_kv)
|
| 294 |
+
|
| 295 |
+
# Project back to d_model
|
| 296 |
+
out = self.out_proj(out_concat) # [B, seq_len, d_model]
|
| 297 |
+
out = self.proj_dropout(out)
|
| 298 |
+
|
| 299 |
+
return out
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
class AttentionBlock(nn.Module):
|
| 303 |
+
"""
|
| 304 |
+
Attention block with pre-norm residual connection and feed-forward network.
|
| 305 |
+
|
| 306 |
+
Structure:
|
| 307 |
+
Input
|
| 308 |
+
├─> Norm ─┬─> MLA ──┬─> Residual Add
|
| 309 |
+
│ └────────┘
|
| 310 |
+
├────────────────────────────────────> Norm ─┬─> SwiGLU FFN ──┬─> Residual Add
|
| 311 |
+
│ └───────┘ │
|
| 312 |
+
└────────────────────────────────────────────────────────────> Output
|
| 313 |
+
"""
|
| 314 |
+
|
| 315 |
+
def __init__(self, d_model, d_latent_kv, n_heads, d_rope, d_ff, dropout=0.1, gqa_groups=1,
|
| 316 |
+
rope_scale=40.0, max_seq_length=1024, yarn_alpha=1.0,
|
| 317 |
+
residual_rms_norm=False, residual_rms_target=1.0, residual_rms_cap=None,
|
| 318 |
+
residual_rms_eps=1e-6):
|
| 319 |
+
super().__init__()
|
| 320 |
+
self.residual_rms_norm = residual_rms_norm
|
| 321 |
+
self.residual_rms_target = residual_rms_target
|
| 322 |
+
self.residual_rms_cap = residual_rms_cap
|
| 323 |
+
self.residual_rms_eps = residual_rms_eps
|
| 324 |
+
self.mla = DeepSeekMLA(d_model, d_latent_kv, n_heads, d_rope, dropout, gqa_groups,
|
| 325 |
+
rope_scale=rope_scale, max_seq_length=max_seq_length,
|
| 326 |
+
yarn_alpha=yarn_alpha)
|
| 327 |
+
|
| 328 |
+
# SwiGLU feed-forward network
|
| 329 |
+
self.ff_norm = nn.LayerNorm(d_model)
|
| 330 |
+
self.ff_gate = nn.Linear(d_model, d_ff, bias=False)
|
| 331 |
+
self.ff_value = nn.Linear(d_model, d_ff, bias=False)
|
| 332 |
+
self.ff_out = nn.Linear(d_ff, d_model, bias=False)
|
| 333 |
+
self.dropout = nn.Dropout(dropout)
|
| 334 |
+
|
| 335 |
+
def forward(self, x, attention_mask=None):
|
| 336 |
+
"""
|
| 337 |
+
Args:
|
| 338 |
+
x: [B, seq_len, d_model]
|
| 339 |
+
attention_mask: [B, seq_len] or [B, 1, seq_len, seq_len]
|
| 340 |
+
|
| 341 |
+
Returns:
|
| 342 |
+
out: [B, seq_len, d_model]
|
| 343 |
+
"""
|
| 344 |
+
# Attention with residual
|
| 345 |
+
attn_out = self.mla(x, attention_mask)
|
| 346 |
+
x = x + self.dropout(attn_out)
|
| 347 |
+
x = _residual_rms_norm(
|
| 348 |
+
x,
|
| 349 |
+
self.residual_rms_norm,
|
| 350 |
+
self.residual_rms_target,
|
| 351 |
+
self.residual_rms_eps,
|
| 352 |
+
self.residual_rms_cap,
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
# FFN with residual
|
| 356 |
+
ff_norm = self.ff_norm(x)
|
| 357 |
+
ff_gate = self.ff_gate(ff_norm)
|
| 358 |
+
ff_value = self.ff_value(ff_norm)
|
| 359 |
+
ff_out = ff_value * F.silu(ff_gate) # SwiGLU activation
|
| 360 |
+
ff_out = self.ff_out(ff_out)
|
| 361 |
+
x = x + self.dropout(ff_out)
|
| 362 |
+
x = _residual_rms_norm(
|
| 363 |
+
x,
|
| 364 |
+
self.residual_rms_norm,
|
| 365 |
+
self.residual_rms_target,
|
| 366 |
+
self.residual_rms_eps,
|
| 367 |
+
self.residual_rms_cap,
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
return x
|
code/TaoTrain/src/taoTrain/models/registry.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model architecture registry and factory."""
|
| 2 |
+
|
| 3 |
+
from typing import Dict, Type, Optional
|
| 4 |
+
import torch
|
| 5 |
+
from taoTrain.core import BaseModel
|
| 6 |
+
from taoTrain.config import ModelConfig
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# Global registry for model architectures
|
| 10 |
+
_ARCHITECTURE_REGISTRY: Dict[str, Type[BaseModel]] = {}
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def register_architecture(name: str):
|
| 14 |
+
"""Decorator to register a custom model architecture."""
|
| 15 |
+
def decorator(cls: Type[BaseModel]):
|
| 16 |
+
if name in _ARCHITECTURE_REGISTRY:
|
| 17 |
+
raise ValueError(f"Architecture '{name}' is already registered")
|
| 18 |
+
_ARCHITECTURE_REGISTRY[name] = cls
|
| 19 |
+
return cls
|
| 20 |
+
return decorator
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_registered_architectures() -> Dict[str, Type[BaseModel]]:
|
| 24 |
+
"""Get all registered architectures."""
|
| 25 |
+
return _ARCHITECTURE_REGISTRY.copy()
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def get_model(
|
| 29 |
+
config: ModelConfig,
|
| 30 |
+
device: Optional[torch.device] = None,
|
| 31 |
+
) -> BaseModel:
|
| 32 |
+
"""
|
| 33 |
+
Create a model instance from config.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
config: ModelConfig instance
|
| 37 |
+
device: Device to create model on (defaults to CPU)
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
Model instance
|
| 41 |
+
"""
|
| 42 |
+
if device is None:
|
| 43 |
+
device = torch.device('cpu')
|
| 44 |
+
|
| 45 |
+
# Handle both enum and string values
|
| 46 |
+
arch_type = config.architecture_type
|
| 47 |
+
if isinstance(arch_type, str):
|
| 48 |
+
arch_name = arch_type
|
| 49 |
+
else:
|
| 50 |
+
arch_name = arch_type.value
|
| 51 |
+
|
| 52 |
+
if arch_name not in _ARCHITECTURE_REGISTRY:
|
| 53 |
+
raise ValueError(
|
| 54 |
+
f"Unknown architecture: {arch_name}. "
|
| 55 |
+
f"Available: {list(_ARCHITECTURE_REGISTRY.keys())}"
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
model_class = _ARCHITECTURE_REGISTRY[arch_name]
|
| 59 |
+
model = model_class(config).to(device)
|
| 60 |
+
|
| 61 |
+
return model
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def register_builtin_architectures():
|
| 65 |
+
"""Register all built-in architectures."""
|
| 66 |
+
# Import here to register (avoid circular imports)
|
| 67 |
+
from . import transformer # noqa: F401
|
| 68 |
+
from . import taonet # noqa: F401
|
| 69 |
+
from . import taonet_ssm # noqa: F401
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# Auto-register built-in architectures when module is imported
|
| 73 |
+
register_builtin_architectures()
|
code/TaoTrain/src/taoTrain/models/taonet.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SimpleLLM - Pure Attention-based Language Model with DeepSeek MLA + RoPE.
|
| 3 |
+
|
| 4 |
+
Architecture:
|
| 5 |
+
- Token Embedding → Attention Blocks → Output Head
|
| 6 |
+
- Attention Blocks: Multi-head Latent Attention with RoPE positional embeddings
|
| 7 |
+
- Feed-forward: SwiGLU gates
|
| 8 |
+
- No state-space models (SSM), pure transformer architecture
|
| 9 |
+
- Full BF16 precision (no quantization)
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import math
|
| 13 |
+
from typing import Optional
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
|
| 18 |
+
from taoTrain.core import BaseModel
|
| 19 |
+
from taoTrain.config import ModelConfig
|
| 20 |
+
from .registry import register_architecture
|
| 21 |
+
from .mla_components import AttentionBlock
|
| 22 |
+
from .embeddings import FactorizedEmbedding
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@register_architecture("taonet")
|
| 26 |
+
class SimpleLLM(BaseModel):
|
| 27 |
+
"""
|
| 28 |
+
Pure attention-based language model with DeepSeek MLA + RoPE.
|
| 29 |
+
|
| 30 |
+
Stateless architecture - no internal state management needed.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
config: ModelConfig with:
|
| 34 |
+
- vocab_size: Vocabulary size
|
| 35 |
+
- hidden_dim: Model dimension (d_model)
|
| 36 |
+
- hidden_dim_ff: Feed-forward dimension (default: 4 * hidden_dim)
|
| 37 |
+
- num_layers: Number of attention blocks (n_layers)
|
| 38 |
+
- num_heads: Number of attention heads (n_attn_heads)
|
| 39 |
+
- d_latent_kv: KV compression dimension (default: 3/4 * hidden_dim)
|
| 40 |
+
- d_rope: RoPE dimension per head (default: hidden_dim // num_heads)
|
| 41 |
+
- max_seq_length: Maximum sequence length
|
| 42 |
+
- dropout: Dropout rate
|
| 43 |
+
- gqa_groups: Grouped Query Attention groups (default: 1)
|
| 44 |
+
- use_factorized_embedding: Use low-rank embedding (default: False)
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
def __init__(self, config: ModelConfig):
|
| 48 |
+
super().__init__(config)
|
| 49 |
+
|
| 50 |
+
# Parse config - use defaults if not specified
|
| 51 |
+
self.vocab_size = config.vocab_size
|
| 52 |
+
self.d_model = config.hidden_dim
|
| 53 |
+
self.n_layers = config.num_layers
|
| 54 |
+
self.n_heads = config.num_heads
|
| 55 |
+
self.dropout = config.dropout
|
| 56 |
+
|
| 57 |
+
# Optional parameters with smart defaults
|
| 58 |
+
self.d_latent_kv = config.d_latent_kv if config.d_latent_kv is not None else int(self.d_model * 0.75)
|
| 59 |
+
self.d_rope = config.d_rope if config.d_rope is not None else (self.d_model // self.n_heads)
|
| 60 |
+
self.d_ff = config.hidden_dim_ff if config.hidden_dim_ff is not None else (self.d_model * 4)
|
| 61 |
+
self.gqa_groups = getattr(config, 'gqa_groups', 1)
|
| 62 |
+
self.use_factorized_embedding = getattr(config, 'use_factorized_embedding', False)
|
| 63 |
+
self.d_embed_rank = getattr(config, 'd_embed_rank', 96)
|
| 64 |
+
|
| 65 |
+
# YaRN parameters for context length extension
|
| 66 |
+
self.rope_scale = getattr(config, 'rope_scale', 40.0)
|
| 67 |
+
self.yarn_enabled = getattr(config, 'yarn_enabled', False)
|
| 68 |
+
self.yarn_alpha = getattr(config, 'yarn_alpha', 1.0)
|
| 69 |
+
self.max_seq_length = config.max_seq_length
|
| 70 |
+
|
| 71 |
+
# Validate dimensions
|
| 72 |
+
assert self.d_model % self.n_heads == 0, \
|
| 73 |
+
f"hidden_dim ({self.d_model}) must be divisible by num_heads ({self.n_heads})"
|
| 74 |
+
assert self.d_latent_kv % self.n_heads == 0, \
|
| 75 |
+
f"d_latent_kv ({self.d_latent_kv}) must be divisible by num_heads ({self.n_heads})"
|
| 76 |
+
|
| 77 |
+
# Token embedding
|
| 78 |
+
if self.use_factorized_embedding:
|
| 79 |
+
self.token_embedding = FactorizedEmbedding(
|
| 80 |
+
self.vocab_size,
|
| 81 |
+
self.d_model,
|
| 82 |
+
self.d_embed_rank
|
| 83 |
+
)
|
| 84 |
+
else:
|
| 85 |
+
self.token_embedding = nn.Embedding(self.vocab_size, self.d_model)
|
| 86 |
+
|
| 87 |
+
# Embedding dropout
|
| 88 |
+
self.embedding_dropout = nn.Dropout(self.dropout)
|
| 89 |
+
|
| 90 |
+
# Attention blocks with MLA + SwiGLU FFN
|
| 91 |
+
self.blocks = nn.ModuleList()
|
| 92 |
+
for _ in range(self.n_layers):
|
| 93 |
+
self.blocks.append(
|
| 94 |
+
AttentionBlock(
|
| 95 |
+
d_model=self.d_model,
|
| 96 |
+
d_latent_kv=self.d_latent_kv,
|
| 97 |
+
n_heads=self.n_heads,
|
| 98 |
+
d_rope=self.d_rope,
|
| 99 |
+
d_ff=int(self.d_ff),
|
| 100 |
+
dropout=self.dropout,
|
| 101 |
+
gqa_groups=self.gqa_groups,
|
| 102 |
+
rope_scale=self.rope_scale,
|
| 103 |
+
max_seq_length=self.max_seq_length,
|
| 104 |
+
yarn_alpha=self.yarn_alpha,
|
| 105 |
+
)
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
# Final layer norm
|
| 109 |
+
self.final_norm = nn.LayerNorm(self.d_model)
|
| 110 |
+
|
| 111 |
+
# Output projection to vocabulary
|
| 112 |
+
self.output_head = nn.Linear(self.d_model, self.vocab_size, bias=False)
|
| 113 |
+
|
| 114 |
+
# Initialize weights
|
| 115 |
+
self.apply(self._init_weights)
|
| 116 |
+
|
| 117 |
+
# Cache for causal mask
|
| 118 |
+
self.register_buffer("causal_mask_cache", None, persistent=False)
|
| 119 |
+
|
| 120 |
+
self._print_architecture()
|
| 121 |
+
|
| 122 |
+
def _init_weights(self, module):
|
| 123 |
+
"""Initialize weights for stable training."""
|
| 124 |
+
if isinstance(module, nn.Linear):
|
| 125 |
+
nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 126 |
+
if module.bias is not None:
|
| 127 |
+
nn.init.zeros_(module.bias)
|
| 128 |
+
elif isinstance(module, nn.Embedding):
|
| 129 |
+
nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 130 |
+
|
| 131 |
+
def _print_architecture(self):
|
| 132 |
+
"""Print model architecture summary."""
|
| 133 |
+
total_params = sum(p.numel() for p in self.parameters())
|
| 134 |
+
trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| 135 |
+
|
| 136 |
+
print(f"\n{'='*70}")
|
| 137 |
+
print("MODEL ARCHITECTURE - TAОNET (DeepSeek MLA + RoPE)")
|
| 138 |
+
print(f"{'='*70}")
|
| 139 |
+
print(f"Embedding:")
|
| 140 |
+
if self.use_factorized_embedding:
|
| 141 |
+
embed_rank_params = self.vocab_size * self.d_embed_rank
|
| 142 |
+
embed_proj_params = self.d_embed_rank * self.d_model
|
| 143 |
+
print(f" Type: Factorized (rank={self.d_embed_rank})")
|
| 144 |
+
print(f" Rank layer: {embed_rank_params/1e6:>8.2f}M")
|
| 145 |
+
print(f" Projection: {embed_proj_params/1e6:>8.2f}M")
|
| 146 |
+
else:
|
| 147 |
+
embed_params = self.vocab_size * self.d_model
|
| 148 |
+
print(f" Type: Standard")
|
| 149 |
+
print(f" Params: {embed_params/1e6:>8.2f}M")
|
| 150 |
+
|
| 151 |
+
output_params = self.d_model * self.vocab_size
|
| 152 |
+
print(f"Output Head: {output_params/1e6:>8.2f}M")
|
| 153 |
+
print(f"Attention Blocks: {len(self.blocks):>10} layers × AttentionBlock")
|
| 154 |
+
print(f"{'─'*70}")
|
| 155 |
+
print(f"Total Parameters: {total_params/1e6:>8.2f}M (trainable: {trainable_params/1e6:.2f}M)")
|
| 156 |
+
print(f"{'─'*70}")
|
| 157 |
+
print(f"Configuration:")
|
| 158 |
+
print(f" Model dimension (d_model): {self.d_model}")
|
| 159 |
+
print(f" KV latent dimension (d_latent_kv): {self.d_latent_kv}")
|
| 160 |
+
print(f" Attention heads: {self.n_heads}")
|
| 161 |
+
print(f" Head dimension: {self.d_model // self.n_heads}")
|
| 162 |
+
print(f" RoPE dimension: {self.d_rope}")
|
| 163 |
+
print(f" Feed-forward dimension: {int(self.d_ff)}")
|
| 164 |
+
print(f" Number of layers: {self.n_layers}")
|
| 165 |
+
print(f" Max sequence length: {self.config.max_seq_length}")
|
| 166 |
+
print(f" Dropout: {self.dropout}")
|
| 167 |
+
print(f" GQA groups: {self.gqa_groups}")
|
| 168 |
+
print(f"{'='*70}\n")
|
| 169 |
+
|
| 170 |
+
def _get_causal_mask(self, seq_len, device):
|
| 171 |
+
"""Get or create causal mask for sequence."""
|
| 172 |
+
if self.causal_mask_cache is None or self.causal_mask_cache.size(-1) < seq_len:
|
| 173 |
+
# [seq_len, seq_len] lower triangular matrix (1 = attend, 0 = mask)
|
| 174 |
+
mask = torch.tril(torch.ones(seq_len, seq_len, device=device, dtype=torch.bool))
|
| 175 |
+
self.register_buffer("causal_mask_cache", mask, persistent=False)
|
| 176 |
+
return self.causal_mask_cache[:seq_len, :seq_len]
|
| 177 |
+
|
| 178 |
+
def forward(
|
| 179 |
+
self,
|
| 180 |
+
input_ids: torch.Tensor,
|
| 181 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 182 |
+
labels: Optional[torch.Tensor] = None,
|
| 183 |
+
) -> dict:
|
| 184 |
+
"""
|
| 185 |
+
Forward pass through the model.
|
| 186 |
+
|
| 187 |
+
Args:
|
| 188 |
+
input_ids: [batch_size, seq_len] tensor of token IDs
|
| 189 |
+
attention_mask: [batch_size, seq_len] tensor where 1 = valid, 0 = padding
|
| 190 |
+
labels: [batch_size, seq_len] target token IDs for loss computation
|
| 191 |
+
|
| 192 |
+
Returns:
|
| 193 |
+
Dictionary with:
|
| 194 |
+
- 'logits': [batch_size, seq_len, vocab_size] output logits
|
| 195 |
+
- 'loss': scalar loss (if labels provided, else None)
|
| 196 |
+
"""
|
| 197 |
+
batch_size, seq_len = input_ids.shape
|
| 198 |
+
device = input_ids.device
|
| 199 |
+
|
| 200 |
+
# Get causal mask: [seq_len, seq_len]
|
| 201 |
+
causal_mask = self._get_causal_mask(seq_len, device)
|
| 202 |
+
|
| 203 |
+
# Combine causal mask with attention mask if provided
|
| 204 |
+
if attention_mask is not None:
|
| 205 |
+
# attention_mask: [batch, seq_len] where 1 = valid, 0 = padding
|
| 206 |
+
# Expand to [batch, 1, 1, seq_len]
|
| 207 |
+
padding_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
|
| 208 |
+
# Combine with causal: [1, 1, seq_len, seq_len] * [batch, 1, 1, seq_len]
|
| 209 |
+
combined_mask = causal_mask.unsqueeze(0).unsqueeze(0) & padding_mask
|
| 210 |
+
# For MLA: convert to {0, 1} format
|
| 211 |
+
combined_mask = combined_mask.float()
|
| 212 |
+
else:
|
| 213 |
+
# Just causal mask
|
| 214 |
+
combined_mask = causal_mask.unsqueeze(0).unsqueeze(0).float()
|
| 215 |
+
|
| 216 |
+
# Embed tokens: [batch_size, seq_len] -> [batch_size, seq_len, d_model]
|
| 217 |
+
x = self.token_embedding(input_ids)
|
| 218 |
+
x = self.embedding_dropout(x)
|
| 219 |
+
|
| 220 |
+
# Pass through attention blocks
|
| 221 |
+
for block in self.blocks:
|
| 222 |
+
x = block(x, attention_mask=combined_mask)
|
| 223 |
+
|
| 224 |
+
# Final layer norm
|
| 225 |
+
x = self.final_norm(x)
|
| 226 |
+
|
| 227 |
+
# Output projection to vocabulary
|
| 228 |
+
logits = self.output_head(x) # [batch_size, seq_len, vocab_size]
|
| 229 |
+
|
| 230 |
+
# Compute loss if labels are provided
|
| 231 |
+
loss = None
|
| 232 |
+
if labels is not None:
|
| 233 |
+
# Flatten for loss computation
|
| 234 |
+
logits_flat = logits.view(-1, logits.size(-1)) # (batch * seq_len, vocab_size)
|
| 235 |
+
labels_flat = labels.view(-1)
|
| 236 |
+
|
| 237 |
+
# Only compute loss on valid targets (ignore -100 tokens for padding)
|
| 238 |
+
loss = F.cross_entropy(
|
| 239 |
+
logits_flat,
|
| 240 |
+
labels_flat,
|
| 241 |
+
reduction='mean',
|
| 242 |
+
ignore_index=-100
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
return {
|
| 246 |
+
'logits': logits,
|
| 247 |
+
'loss': loss,
|
| 248 |
+
}
|
code/TaoTrain/src/taoTrain/models/taonet_ssm.py
ADDED
|
@@ -0,0 +1,654 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""TaoNet variant that replaces MLA attention with an SSM mixer."""
|
| 2 |
+
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
from taoTrain.config import ModelConfig
|
| 10 |
+
from taoTrain.core import BaseModel
|
| 11 |
+
|
| 12 |
+
from .embeddings import FactorizedEmbedding
|
| 13 |
+
from .mla_components import AttentionBlock
|
| 14 |
+
from .registry import register_architecture
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _load_ssm_core(core: str):
|
| 18 |
+
try:
|
| 19 |
+
from gamma_space_model.modules.s4_ternary_dplr_ssm import S4TernaryDPLRSSM
|
| 20 |
+
from gamma_space_model.modules.ssm_gamma_s4 import SSMGammaS4
|
| 21 |
+
except ImportError as exc:
|
| 22 |
+
raise ImportError(
|
| 23 |
+
"taonet_ssm requires the Gamma Space Model package. Install the SSM repo "
|
| 24 |
+
"with `pip install -e /path/to/Taotern_SSM`, or put it on PYTHONPATH."
|
| 25 |
+
) from exc
|
| 26 |
+
if core == "gamma_s4":
|
| 27 |
+
return SSMGammaS4
|
| 28 |
+
if core == "dplr":
|
| 29 |
+
return S4TernaryDPLRSSM
|
| 30 |
+
raise ValueError(f"Unsupported ssm_core '{core}'.")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _padding_mask_from_attention_mask(attention_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
|
| 34 |
+
if attention_mask is None:
|
| 35 |
+
return None
|
| 36 |
+
if attention_mask.dim() == 2:
|
| 37 |
+
return attention_mask
|
| 38 |
+
if attention_mask.dim() == 4:
|
| 39 |
+
return attention_mask.bool().any(dim=-2).squeeze(1).to(dtype=attention_mask.dtype)
|
| 40 |
+
raise ValueError(
|
| 41 |
+
"Expected attention_mask with shape [batch, seq_len] or "
|
| 42 |
+
f"[batch, 1, seq_len, seq_len], got {tuple(attention_mask.shape)}."
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _hybrid_ssm_layer_indices(config: ModelConfig, num_layers: int) -> set[int]:
|
| 47 |
+
if config.hybrid_ssm_layers:
|
| 48 |
+
indices = set()
|
| 49 |
+
for item in config.hybrid_ssm_layers.split(","):
|
| 50 |
+
item = item.strip()
|
| 51 |
+
if not item:
|
| 52 |
+
continue
|
| 53 |
+
index = int(item)
|
| 54 |
+
if index < 0 or index >= num_layers:
|
| 55 |
+
raise ValueError(
|
| 56 |
+
f"hybrid_ssm_layers index {index} is outside [0, {num_layers - 1}]."
|
| 57 |
+
)
|
| 58 |
+
indices.add(index)
|
| 59 |
+
if not indices:
|
| 60 |
+
raise ValueError("hybrid_ssm_layers was set but did not contain any valid layer indices.")
|
| 61 |
+
return indices
|
| 62 |
+
|
| 63 |
+
if config.hybrid_pattern == "attention_first":
|
| 64 |
+
return {idx for idx in range(num_layers) if idx % 2 == 1}
|
| 65 |
+
if config.hybrid_pattern == "ssm_first":
|
| 66 |
+
return {idx for idx in range(num_layers) if idx % 2 == 0}
|
| 67 |
+
if config.hybrid_pattern == "single_ssm_middle":
|
| 68 |
+
return {num_layers // 2}
|
| 69 |
+
if config.hybrid_pattern == "single_ssm_late":
|
| 70 |
+
return {num_layers - 1}
|
| 71 |
+
raise ValueError(f"Unsupported hybrid_pattern '{config.hybrid_pattern}'.")
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class ChannelGate(nn.Module):
|
| 75 |
+
"""Elementwise gate with one scale and bias per model channel."""
|
| 76 |
+
|
| 77 |
+
def __init__(self, d_model: int) -> None:
|
| 78 |
+
super().__init__()
|
| 79 |
+
self.weight = nn.Parameter(torch.zeros(d_model))
|
| 80 |
+
self.bias = nn.Parameter(torch.full((d_model,), 2.0))
|
| 81 |
+
|
| 82 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 83 |
+
return x * self.weight + self.bias
|
| 84 |
+
|
| 85 |
+
def reset_parameters(self) -> None:
|
| 86 |
+
nn.init.zeros_(self.weight)
|
| 87 |
+
nn.init.constant_(self.bias, 2.0)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _build_gate(enabled: bool, gate_type: str, d_model: int) -> nn.Module | None:
|
| 91 |
+
if not enabled:
|
| 92 |
+
return None
|
| 93 |
+
if gate_type == "dense":
|
| 94 |
+
return nn.Linear(d_model, d_model)
|
| 95 |
+
if gate_type == "channel":
|
| 96 |
+
return ChannelGate(d_model)
|
| 97 |
+
raise ValueError(f"Unsupported ssm_gate_type '{gate_type}'.")
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _residual_rms_norm(
|
| 101 |
+
x: torch.Tensor,
|
| 102 |
+
enabled: bool,
|
| 103 |
+
target: float,
|
| 104 |
+
eps: float,
|
| 105 |
+
cap: Optional[float] = None,
|
| 106 |
+
) -> torch.Tensor:
|
| 107 |
+
if not enabled and cap is None:
|
| 108 |
+
return x
|
| 109 |
+
rms = x.float().square().mean(dim=-1, keepdim=True).add(eps).sqrt()
|
| 110 |
+
if enabled:
|
| 111 |
+
scale = target / rms
|
| 112 |
+
else:
|
| 113 |
+
cap_tensor = torch.tensor(float(cap), dtype=rms.dtype, device=rms.device)
|
| 114 |
+
scale = torch.minimum(torch.ones_like(rms), cap_tensor / rms)
|
| 115 |
+
return x * scale.to(dtype=x.dtype)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class SSMMixer(nn.Module):
|
| 119 |
+
"""Causal sequence mixer with the same residual-branch contract as MLA."""
|
| 120 |
+
|
| 121 |
+
def __init__(self, config: ModelConfig) -> None:
|
| 122 |
+
super().__init__()
|
| 123 |
+
SSMCore = _load_ssm_core(config.ssm_core)
|
| 124 |
+
|
| 125 |
+
self.d_model = config.hidden_dim
|
| 126 |
+
self.ssm_core = config.ssm_core
|
| 127 |
+
d_latent_kv = config.d_latent_kv if config.d_latent_kv is not None else int(self.d_model * 0.75)
|
| 128 |
+
self.ssm_hidden_dim = config.ssm_hidden_dim if config.ssm_hidden_dim is not None else d_latent_kv
|
| 129 |
+
self.ssm_mixer_dim = config.ssm_mixer_dim if config.ssm_mixer_dim is not None else self.d_model
|
| 130 |
+
self.ssm_num_lanes = config.ssm_num_lanes
|
| 131 |
+
self.ssm_lane_combine = config.ssm_lane_combine
|
| 132 |
+
self.ssm_lane_mode = config.ssm_lane_mode
|
| 133 |
+
self.ssm_split_mix = config.ssm_split_mix
|
| 134 |
+
self.use_padding_mask = config.ssm_use_padding_mask
|
| 135 |
+
self.branch_rms_norm = config.ssm_branch_rms_norm
|
| 136 |
+
self.branch_rms_eps = config.ssm_branch_rms_eps
|
| 137 |
+
self.branch_clip_value = config.ssm_branch_clip_value
|
| 138 |
+
if self.ssm_num_lanes < 1:
|
| 139 |
+
raise ValueError("ssm_num_lanes must be at least 1.")
|
| 140 |
+
if self.ssm_lane_mode not in {"full", "split"}:
|
| 141 |
+
raise ValueError(f"Unsupported ssm_lane_mode '{self.ssm_lane_mode}'.")
|
| 142 |
+
if self.ssm_split_mix not in {"none", "hadamard"}:
|
| 143 |
+
raise ValueError(f"Unsupported ssm_split_mix '{self.ssm_split_mix}'.")
|
| 144 |
+
if self.ssm_split_mix != "none" and self.ssm_lane_mode != "split":
|
| 145 |
+
raise ValueError("ssm_split_mix is only supported when ssm_lane_mode='split'.")
|
| 146 |
+
if self.ssm_split_mix == "hadamard" and self.ssm_num_lanes != 2:
|
| 147 |
+
raise ValueError("ssm_split_mix='hadamard' currently requires exactly two SSM lanes.")
|
| 148 |
+
if self.ssm_lane_mode == "split" and self.ssm_mixer_dim % self.ssm_num_lanes != 0:
|
| 149 |
+
raise ValueError(
|
| 150 |
+
"ssm_mixer_dim must be divisible by ssm_num_lanes when ssm_lane_mode='split'."
|
| 151 |
+
)
|
| 152 |
+
self.ssm_lane_dim = (
|
| 153 |
+
self.ssm_mixer_dim // self.ssm_num_lanes
|
| 154 |
+
if self.ssm_lane_mode == "split"
|
| 155 |
+
else self.ssm_mixer_dim
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
self.norm = nn.LayerNorm(self.d_model)
|
| 159 |
+
self.gate_type = config.ssm_gate_type
|
| 160 |
+
self.input_gate = _build_gate(config.ssm_input_gate, self.gate_type, self.d_model)
|
| 161 |
+
self.input_proj = (
|
| 162 |
+
nn.Linear(self.d_model, self.ssm_mixer_dim, bias=False)
|
| 163 |
+
if self.ssm_mixer_dim != self.d_model
|
| 164 |
+
else nn.Identity()
|
| 165 |
+
)
|
| 166 |
+
common_kwargs = {
|
| 167 |
+
"state_dim": self.ssm_lane_dim,
|
| 168 |
+
"hidden_dim": self.ssm_hidden_dim,
|
| 169 |
+
"dt_min": config.ssm_dt_min,
|
| 170 |
+
"dt_max": config.ssm_dt_max,
|
| 171 |
+
"dt_init": config.ssm_dt_init,
|
| 172 |
+
"use_D": config.ssm_use_d,
|
| 173 |
+
"kernel_mode": config.ssm_kernel_mode,
|
| 174 |
+
"kernel_threshold": config.ssm_kernel_threshold,
|
| 175 |
+
}
|
| 176 |
+
self.ssm_lanes = nn.ModuleList(
|
| 177 |
+
[self._build_ssm_lane(SSMCore, common_kwargs, config) for _ in range(self.ssm_num_lanes)]
|
| 178 |
+
)
|
| 179 |
+
self.ssm = self.ssm_lanes[0]
|
| 180 |
+
self.lane_weights = None
|
| 181 |
+
if self.ssm_lane_combine not in {"mean", "channel"}:
|
| 182 |
+
raise ValueError(f"Unsupported ssm_lane_combine '{self.ssm_lane_combine}'.")
|
| 183 |
+
if (
|
| 184 |
+
self.ssm_lane_mode == "full"
|
| 185 |
+
and self.ssm_num_lanes > 1
|
| 186 |
+
and self.ssm_lane_combine == "channel"
|
| 187 |
+
):
|
| 188 |
+
self.lane_weights = nn.Parameter(
|
| 189 |
+
torch.full((self.ssm_num_lanes, self.ssm_mixer_dim), 1.0 / self.ssm_num_lanes)
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
if config.ssm_activation == "gelu":
|
| 193 |
+
self.activation = nn.GELU()
|
| 194 |
+
elif config.ssm_activation == "silu":
|
| 195 |
+
self.activation = nn.SiLU()
|
| 196 |
+
elif config.ssm_activation in {"identity", "linear"}:
|
| 197 |
+
self.activation = nn.Identity()
|
| 198 |
+
else:
|
| 199 |
+
raise ValueError(f"Unsupported ssm_activation '{config.ssm_activation}'.")
|
| 200 |
+
|
| 201 |
+
self.output_gate = _build_gate(config.ssm_gate, self.gate_type, self.d_model)
|
| 202 |
+
self.out_proj = nn.Linear(self.ssm_mixer_dim, self.d_model, bias=False)
|
| 203 |
+
self.layer_scale = nn.Parameter(torch.full((self.d_model,), config.ssm_layer_scale_init))
|
| 204 |
+
self.local_shift_scale = None
|
| 205 |
+
if config.ssm_local_shift:
|
| 206 |
+
if config.ssm_local_shift_per_channel:
|
| 207 |
+
self.local_shift_scale = nn.Parameter(
|
| 208 |
+
torch.full((self.d_model,), float(config.ssm_local_shift_init))
|
| 209 |
+
)
|
| 210 |
+
else:
|
| 211 |
+
self.local_shift_scale = nn.Parameter(torch.tensor(float(config.ssm_local_shift_init)))
|
| 212 |
+
self.proj_dropout = nn.Dropout(config.dropout)
|
| 213 |
+
|
| 214 |
+
self._reset_parameters()
|
| 215 |
+
|
| 216 |
+
def _normalize_branch(self, ssm_out: torch.Tensor) -> torch.Tensor:
|
| 217 |
+
if not self.branch_rms_norm:
|
| 218 |
+
return ssm_out
|
| 219 |
+
rms = ssm_out.float().square().mean(dim=-1, keepdim=True).add(self.branch_rms_eps).rsqrt()
|
| 220 |
+
return ssm_out * rms.to(dtype=ssm_out.dtype)
|
| 221 |
+
|
| 222 |
+
def _build_ssm_lane(self, SSMCore, common_kwargs: dict, config: ModelConfig) -> nn.Module:
|
| 223 |
+
if config.ssm_core == "gamma_s4":
|
| 224 |
+
return SSMCore(
|
| 225 |
+
**common_kwargs,
|
| 226 |
+
discretization=config.ssm_discretization,
|
| 227 |
+
)
|
| 228 |
+
return SSMCore(
|
| 229 |
+
**common_kwargs,
|
| 230 |
+
rank=config.ssm_rank,
|
| 231 |
+
max_low_rank_scale=config.ssm_max_low_rank_scale,
|
| 232 |
+
finite_tail_correction=config.ssm_finite_tail_correction,
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
def _reset_parameters(self) -> None:
|
| 236 |
+
if isinstance(self.input_gate, nn.Linear):
|
| 237 |
+
nn.init.zeros_(self.input_gate.weight)
|
| 238 |
+
nn.init.constant_(self.input_gate.bias, 2.0)
|
| 239 |
+
elif isinstance(self.input_gate, ChannelGate):
|
| 240 |
+
self.input_gate.reset_parameters()
|
| 241 |
+
if isinstance(self.output_gate, nn.Linear):
|
| 242 |
+
nn.init.zeros_(self.output_gate.weight)
|
| 243 |
+
nn.init.constant_(self.output_gate.bias, 2.0)
|
| 244 |
+
elif isinstance(self.output_gate, ChannelGate):
|
| 245 |
+
self.output_gate.reset_parameters()
|
| 246 |
+
if isinstance(self.input_proj, nn.Linear):
|
| 247 |
+
nn.init.xavier_uniform_(self.input_proj.weight)
|
| 248 |
+
nn.init.xavier_uniform_(self.out_proj.weight)
|
| 249 |
+
else:
|
| 250 |
+
nn.init.eye_(self.out_proj.weight)
|
| 251 |
+
|
| 252 |
+
def forward(
|
| 253 |
+
self,
|
| 254 |
+
x: torch.Tensor,
|
| 255 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 256 |
+
) -> torch.Tensor:
|
| 257 |
+
x_norm = self.norm(x)
|
| 258 |
+
ssm_in = x_norm
|
| 259 |
+
if self.input_gate is not None:
|
| 260 |
+
ssm_in = ssm_in * torch.sigmoid(self.input_gate(x_norm))
|
| 261 |
+
ssm_in = self.input_proj(ssm_in)
|
| 262 |
+
|
| 263 |
+
padding_mask = _padding_mask_from_attention_mask(attention_mask) if self.use_padding_mask else None
|
| 264 |
+
lane_outputs = []
|
| 265 |
+
if self.ssm_lane_mode == "split":
|
| 266 |
+
lane_inputs = torch.split(ssm_in, self.ssm_lane_dim, dim=-1)
|
| 267 |
+
else:
|
| 268 |
+
lane_inputs = [ssm_in] * self.ssm_num_lanes
|
| 269 |
+
for lane, lane_input in zip(self.ssm_lanes, lane_inputs):
|
| 270 |
+
lane_out, _ = lane(
|
| 271 |
+
lane_input,
|
| 272 |
+
mask=padding_mask,
|
| 273 |
+
return_state=False,
|
| 274 |
+
)
|
| 275 |
+
lane_outputs.append(lane_out)
|
| 276 |
+
if self.ssm_lane_mode == "split":
|
| 277 |
+
if self.ssm_split_mix == "hadamard":
|
| 278 |
+
left, right = lane_outputs
|
| 279 |
+
inv_sqrt_2 = 0.7071067811865476
|
| 280 |
+
ssm_out = torch.cat((left + right, left - right), dim=-1) * inv_sqrt_2
|
| 281 |
+
else:
|
| 282 |
+
ssm_out = torch.cat(lane_outputs, dim=-1)
|
| 283 |
+
elif len(lane_outputs) == 1:
|
| 284 |
+
ssm_out = lane_outputs[0]
|
| 285 |
+
elif self.lane_weights is not None:
|
| 286 |
+
weights = self.lane_weights.to(dtype=lane_outputs[0].dtype, device=lane_outputs[0].device)
|
| 287 |
+
ssm_out = torch.stack(lane_outputs, dim=2)
|
| 288 |
+
ssm_out = (ssm_out * weights.view(1, 1, self.ssm_num_lanes, self.ssm_mixer_dim)).sum(dim=2)
|
| 289 |
+
else:
|
| 290 |
+
ssm_out = torch.stack(lane_outputs, dim=0).mean(dim=0)
|
| 291 |
+
ssm_out = self.activation(ssm_out)
|
| 292 |
+
ssm_out = self.out_proj(ssm_out)
|
| 293 |
+
|
| 294 |
+
if self.output_gate is not None:
|
| 295 |
+
ssm_out = ssm_out * torch.sigmoid(self.output_gate(x_norm))
|
| 296 |
+
|
| 297 |
+
ssm_out = self._normalize_branch(ssm_out)
|
| 298 |
+
ssm_out = ssm_out * self.layer_scale
|
| 299 |
+
if self.local_shift_scale is not None:
|
| 300 |
+
shifted = torch.zeros_like(x_norm)
|
| 301 |
+
shifted[:, 1:] = x_norm[:, :-1]
|
| 302 |
+
ssm_out = ssm_out + shifted * self.local_shift_scale
|
| 303 |
+
if self.branch_clip_value is not None:
|
| 304 |
+
ssm_out = torch.clamp(ssm_out, -self.branch_clip_value, self.branch_clip_value)
|
| 305 |
+
return self.proj_dropout(ssm_out)
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
class SSMAttentionBlock(nn.Module):
|
| 309 |
+
"""TaoNet block with Gamma SSM sequence mixing and the original SwiGLU FFN."""
|
| 310 |
+
|
| 311 |
+
def __init__(self, config: ModelConfig) -> None:
|
| 312 |
+
super().__init__()
|
| 313 |
+
d_model = config.hidden_dim
|
| 314 |
+
d_ff = config.hidden_dim_ff if config.hidden_dim_ff is not None else d_model * 4
|
| 315 |
+
|
| 316 |
+
self.mixer = SSMMixer(config)
|
| 317 |
+
self.residual_rms_norm = config.block_residual_rms_norm
|
| 318 |
+
self.residual_rms_target = config.block_residual_rms_target
|
| 319 |
+
self.residual_rms_cap = config.block_residual_rms_cap
|
| 320 |
+
self.residual_rms_eps = config.block_residual_rms_eps
|
| 321 |
+
self.ff_norm = nn.LayerNorm(d_model)
|
| 322 |
+
self.ff_gate = nn.Linear(d_model, int(d_ff), bias=False)
|
| 323 |
+
self.ff_value = nn.Linear(d_model, int(d_ff), bias=False)
|
| 324 |
+
self.ff_out = nn.Linear(int(d_ff), d_model, bias=False)
|
| 325 |
+
self.dropout = nn.Dropout(config.dropout)
|
| 326 |
+
|
| 327 |
+
def forward(
|
| 328 |
+
self,
|
| 329 |
+
x: torch.Tensor,
|
| 330 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 331 |
+
) -> torch.Tensor:
|
| 332 |
+
x = x + self.dropout(self.mixer(x, attention_mask=attention_mask))
|
| 333 |
+
x = _residual_rms_norm(
|
| 334 |
+
x,
|
| 335 |
+
self.residual_rms_norm,
|
| 336 |
+
self.residual_rms_target,
|
| 337 |
+
self.residual_rms_eps,
|
| 338 |
+
self.residual_rms_cap,
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
ff_norm = self.ff_norm(x)
|
| 342 |
+
ff_gate = self.ff_gate(ff_norm)
|
| 343 |
+
ff_value = self.ff_value(ff_norm)
|
| 344 |
+
ff_out = ff_value * F.silu(ff_gate)
|
| 345 |
+
ff_out = self.ff_out(ff_out)
|
| 346 |
+
x = x + self.dropout(ff_out)
|
| 347 |
+
return _residual_rms_norm(
|
| 348 |
+
x,
|
| 349 |
+
self.residual_rms_norm,
|
| 350 |
+
self.residual_rms_target,
|
| 351 |
+
self.residual_rms_eps,
|
| 352 |
+
self.residual_rms_cap,
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
@register_architecture("taonet_ssm")
|
| 357 |
+
class TaoNetSSMLLM(BaseModel):
|
| 358 |
+
"""TaoNet language model with SSM blocks replacing MLA attention."""
|
| 359 |
+
|
| 360 |
+
def __init__(self, config: ModelConfig):
|
| 361 |
+
super().__init__(config)
|
| 362 |
+
|
| 363 |
+
self.vocab_size = config.vocab_size
|
| 364 |
+
self.d_model = config.hidden_dim
|
| 365 |
+
self.n_layers = config.num_layers
|
| 366 |
+
self.n_heads = config.num_heads
|
| 367 |
+
self.dropout = config.dropout
|
| 368 |
+
self.d_latent_kv = config.d_latent_kv if config.d_latent_kv is not None else int(self.d_model * 0.75)
|
| 369 |
+
self.d_ff = config.hidden_dim_ff if config.hidden_dim_ff is not None else self.d_model * 4
|
| 370 |
+
self.use_factorized_embedding = getattr(config, "use_factorized_embedding", False)
|
| 371 |
+
self.d_embed_rank = getattr(config, "d_embed_rank", 96)
|
| 372 |
+
self.max_seq_length = config.max_seq_length
|
| 373 |
+
|
| 374 |
+
if self.use_factorized_embedding:
|
| 375 |
+
self.token_embedding = FactorizedEmbedding(
|
| 376 |
+
self.vocab_size,
|
| 377 |
+
self.d_model,
|
| 378 |
+
self.d_embed_rank,
|
| 379 |
+
)
|
| 380 |
+
else:
|
| 381 |
+
self.token_embedding = nn.Embedding(self.vocab_size, self.d_model)
|
| 382 |
+
|
| 383 |
+
self.embedding_dropout = nn.Dropout(self.dropout)
|
| 384 |
+
self.blocks = nn.ModuleList([SSMAttentionBlock(config) for _ in range(self.n_layers)])
|
| 385 |
+
self.final_norm = nn.LayerNorm(self.d_model)
|
| 386 |
+
self.output_head = nn.Linear(self.d_model, self.vocab_size, bias=False)
|
| 387 |
+
|
| 388 |
+
self.apply(self._init_weights)
|
| 389 |
+
for block in self.blocks:
|
| 390 |
+
block.mixer._reset_parameters()
|
| 391 |
+
|
| 392 |
+
self._print_architecture(config)
|
| 393 |
+
|
| 394 |
+
def _init_weights(self, module):
|
| 395 |
+
if isinstance(module, nn.Linear):
|
| 396 |
+
nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)
|
| 397 |
+
if module.bias is not None:
|
| 398 |
+
nn.init.zeros_(module.bias)
|
| 399 |
+
elif isinstance(module, nn.Embedding):
|
| 400 |
+
nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)
|
| 401 |
+
|
| 402 |
+
def _print_architecture(self, config: ModelConfig):
|
| 403 |
+
total_params = sum(p.numel() for p in self.parameters())
|
| 404 |
+
trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| 405 |
+
ssm_hidden_dim = config.ssm_hidden_dim if config.ssm_hidden_dim is not None else self.d_latent_kv
|
| 406 |
+
|
| 407 |
+
print(f"\n{'=' * 70}")
|
| 408 |
+
print(f"MODEL ARCHITECTURE - TAONET-SSM ({config.ssm_core} + SwiGLU)")
|
| 409 |
+
print(f"{'=' * 70}")
|
| 410 |
+
print(f"Embedding vocab: {self.vocab_size}")
|
| 411 |
+
print(f"Output Head: {(self.d_model * self.vocab_size) / 1e6:>8.2f}M")
|
| 412 |
+
print(f"SSM Blocks: {len(self.blocks):>8} layers x SSMMixer")
|
| 413 |
+
print(f"{'-' * 70}")
|
| 414 |
+
print(f"Total Parameters: {total_params / 1e6:>8.2f}M (trainable: {trainable_params / 1e6:.2f}M)")
|
| 415 |
+
print(f"{'-' * 70}")
|
| 416 |
+
print("Configuration:")
|
| 417 |
+
print(f" Model dimension (d_model): {self.d_model}")
|
| 418 |
+
print(f" SSM core: {config.ssm_core}")
|
| 419 |
+
print(f" SSM hidden dimension: {ssm_hidden_dim}")
|
| 420 |
+
print(f" SSM mixer dimension: {config.ssm_mixer_dim or self.d_model}")
|
| 421 |
+
print(f" SSM lanes: {config.ssm_num_lanes}")
|
| 422 |
+
print(f" SSM lane mode: {config.ssm_lane_mode}")
|
| 423 |
+
print(f" SSM split mix: {config.ssm_split_mix}")
|
| 424 |
+
print(f" SSM lane combine: {config.ssm_lane_combine}")
|
| 425 |
+
if config.ssm_core == "dplr":
|
| 426 |
+
print(f" SSM DPLR rank: {config.ssm_rank}")
|
| 427 |
+
print(f" SSM discretization: {config.ssm_discretization}")
|
| 428 |
+
print(f" SSM kernel mode: {config.ssm_kernel_mode}")
|
| 429 |
+
print(f" SSM kernel threshold: {config.ssm_kernel_threshold}")
|
| 430 |
+
print(f" SSM padding mask enabled: {config.ssm_use_padding_mask}")
|
| 431 |
+
print(f" SSM gate type: {config.ssm_gate_type}")
|
| 432 |
+
print(f" SSM branch RMS norm: {config.ssm_branch_rms_norm}")
|
| 433 |
+
print(f" SSM branch clip value: {config.ssm_branch_clip_value}")
|
| 434 |
+
print(f" Block residual RMS norm: {config.block_residual_rms_norm}")
|
| 435 |
+
print(f" Block residual RMS cap: {config.block_residual_rms_cap}")
|
| 436 |
+
print(f" SSM local shift enabled: {config.ssm_local_shift}")
|
| 437 |
+
print(f" SSM local shift per channel: {config.ssm_local_shift_per_channel}")
|
| 438 |
+
print(f" Feed-forward dimension: {int(self.d_ff)}")
|
| 439 |
+
print(f" Number of layers: {self.n_layers}")
|
| 440 |
+
print(f" Max sequence length: {self.max_seq_length}")
|
| 441 |
+
print(f" Dropout: {self.dropout}")
|
| 442 |
+
print(f"{'=' * 70}\n")
|
| 443 |
+
|
| 444 |
+
def forward(
|
| 445 |
+
self,
|
| 446 |
+
input_ids: torch.Tensor,
|
| 447 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 448 |
+
labels: Optional[torch.Tensor] = None,
|
| 449 |
+
) -> dict:
|
| 450 |
+
x = self.token_embedding(input_ids)
|
| 451 |
+
x = self.embedding_dropout(x)
|
| 452 |
+
|
| 453 |
+
for block in self.blocks:
|
| 454 |
+
x = block(x, attention_mask=attention_mask)
|
| 455 |
+
|
| 456 |
+
x = self.final_norm(x)
|
| 457 |
+
logits = self.output_head(x)
|
| 458 |
+
|
| 459 |
+
loss = None
|
| 460 |
+
if labels is not None:
|
| 461 |
+
loss = F.cross_entropy(
|
| 462 |
+
logits.view(-1, logits.size(-1)),
|
| 463 |
+
labels.view(-1),
|
| 464 |
+
reduction="mean",
|
| 465 |
+
ignore_index=-100,
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
return {
|
| 469 |
+
"logits": logits,
|
| 470 |
+
"loss": loss,
|
| 471 |
+
}
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
@register_architecture("taonet_hybrid")
|
| 475 |
+
class TaoNetHybridLLM(BaseModel):
|
| 476 |
+
"""TaoNet language model with alternating MLA attention and SSM mixer blocks."""
|
| 477 |
+
|
| 478 |
+
def __init__(self, config: ModelConfig):
|
| 479 |
+
super().__init__(config)
|
| 480 |
+
|
| 481 |
+
self.vocab_size = config.vocab_size
|
| 482 |
+
self.d_model = config.hidden_dim
|
| 483 |
+
self.n_layers = config.num_layers
|
| 484 |
+
self.n_heads = config.num_heads
|
| 485 |
+
self.dropout = config.dropout
|
| 486 |
+
self.d_latent_kv = config.d_latent_kv if config.d_latent_kv is not None else int(self.d_model * 0.75)
|
| 487 |
+
self.d_rope = config.d_rope if config.d_rope is not None else self.d_model // self.n_heads
|
| 488 |
+
self.d_ff = config.hidden_dim_ff if config.hidden_dim_ff is not None else self.d_model * 4
|
| 489 |
+
self.gqa_groups = getattr(config, "gqa_groups", 1)
|
| 490 |
+
self.use_factorized_embedding = getattr(config, "use_factorized_embedding", False)
|
| 491 |
+
self.d_embed_rank = getattr(config, "d_embed_rank", 96)
|
| 492 |
+
self.rope_scale = getattr(config, "rope_scale", 40.0)
|
| 493 |
+
self.yarn_alpha = getattr(config, "yarn_alpha", 1.0)
|
| 494 |
+
self.max_seq_length = config.max_seq_length
|
| 495 |
+
|
| 496 |
+
assert self.d_model % self.n_heads == 0, (
|
| 497 |
+
f"hidden_dim ({self.d_model}) must be divisible by num_heads ({self.n_heads})"
|
| 498 |
+
)
|
| 499 |
+
assert self.d_latent_kv % self.n_heads == 0, (
|
| 500 |
+
f"d_latent_kv ({self.d_latent_kv}) must be divisible by num_heads ({self.n_heads})"
|
| 501 |
+
)
|
| 502 |
+
|
| 503 |
+
if self.use_factorized_embedding:
|
| 504 |
+
self.token_embedding = FactorizedEmbedding(
|
| 505 |
+
self.vocab_size,
|
| 506 |
+
self.d_model,
|
| 507 |
+
self.d_embed_rank,
|
| 508 |
+
)
|
| 509 |
+
else:
|
| 510 |
+
self.token_embedding = nn.Embedding(self.vocab_size, self.d_model)
|
| 511 |
+
|
| 512 |
+
self.embedding_dropout = nn.Dropout(self.dropout)
|
| 513 |
+
self.blocks = nn.ModuleList()
|
| 514 |
+
self.block_kinds: list[str] = []
|
| 515 |
+
self.ssm_layer_indices = _hybrid_ssm_layer_indices(config, self.n_layers)
|
| 516 |
+
for layer_idx in range(self.n_layers):
|
| 517 |
+
if layer_idx in self.ssm_layer_indices:
|
| 518 |
+
self.blocks.append(SSMAttentionBlock(config))
|
| 519 |
+
self.block_kinds.append("ssm")
|
| 520 |
+
else:
|
| 521 |
+
self.blocks.append(
|
| 522 |
+
AttentionBlock(
|
| 523 |
+
d_model=self.d_model,
|
| 524 |
+
d_latent_kv=self.d_latent_kv,
|
| 525 |
+
n_heads=self.n_heads,
|
| 526 |
+
d_rope=self.d_rope,
|
| 527 |
+
d_ff=int(self.d_ff),
|
| 528 |
+
dropout=self.dropout,
|
| 529 |
+
gqa_groups=self.gqa_groups,
|
| 530 |
+
rope_scale=self.rope_scale,
|
| 531 |
+
max_seq_length=self.max_seq_length,
|
| 532 |
+
yarn_alpha=self.yarn_alpha,
|
| 533 |
+
residual_rms_norm=config.block_residual_rms_norm,
|
| 534 |
+
residual_rms_target=config.block_residual_rms_target,
|
| 535 |
+
residual_rms_cap=config.block_residual_rms_cap,
|
| 536 |
+
residual_rms_eps=config.block_residual_rms_eps,
|
| 537 |
+
)
|
| 538 |
+
)
|
| 539 |
+
self.block_kinds.append("attention")
|
| 540 |
+
|
| 541 |
+
self.final_norm = nn.LayerNorm(self.d_model)
|
| 542 |
+
self.output_head = nn.Linear(self.d_model, self.vocab_size, bias=False)
|
| 543 |
+
|
| 544 |
+
self.apply(self._init_weights)
|
| 545 |
+
for block in self.blocks:
|
| 546 |
+
mixer = getattr(block, "mixer", None)
|
| 547 |
+
if mixer is not None:
|
| 548 |
+
mixer._reset_parameters()
|
| 549 |
+
|
| 550 |
+
self.register_buffer("causal_mask_cache", None, persistent=False)
|
| 551 |
+
self._print_architecture(config)
|
| 552 |
+
|
| 553 |
+
def _init_weights(self, module):
|
| 554 |
+
if isinstance(module, nn.Linear):
|
| 555 |
+
nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)
|
| 556 |
+
if module.bias is not None:
|
| 557 |
+
nn.init.zeros_(module.bias)
|
| 558 |
+
elif isinstance(module, nn.Embedding):
|
| 559 |
+
nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)
|
| 560 |
+
|
| 561 |
+
def _get_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
|
| 562 |
+
if self.causal_mask_cache is None or self.causal_mask_cache.size(-1) < seq_len:
|
| 563 |
+
mask = torch.tril(torch.ones(seq_len, seq_len, device=device, dtype=torch.bool))
|
| 564 |
+
self.register_buffer("causal_mask_cache", mask, persistent=False)
|
| 565 |
+
return self.causal_mask_cache[:seq_len, :seq_len]
|
| 566 |
+
|
| 567 |
+
def _get_combined_mask(
|
| 568 |
+
self,
|
| 569 |
+
attention_mask: Optional[torch.Tensor],
|
| 570 |
+
seq_len: int,
|
| 571 |
+
device: torch.device,
|
| 572 |
+
) -> torch.Tensor:
|
| 573 |
+
causal_mask = self._get_causal_mask(seq_len, device)
|
| 574 |
+
if attention_mask is None:
|
| 575 |
+
return causal_mask.unsqueeze(0).unsqueeze(0).float()
|
| 576 |
+
padding_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
|
| 577 |
+
return (causal_mask.unsqueeze(0).unsqueeze(0) & padding_mask).float()
|
| 578 |
+
|
| 579 |
+
def _print_architecture(self, config: ModelConfig) -> None:
|
| 580 |
+
total_params = sum(p.numel() for p in self.parameters())
|
| 581 |
+
trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| 582 |
+
attention_blocks = self.block_kinds.count("attention")
|
| 583 |
+
ssm_blocks = self.block_kinds.count("ssm")
|
| 584 |
+
ssm_hidden_dim = config.ssm_hidden_dim if config.ssm_hidden_dim is not None else self.d_latent_kv
|
| 585 |
+
|
| 586 |
+
print(f"\n{'=' * 70}")
|
| 587 |
+
print(f"MODEL ARCHITECTURE - TAONET-HYBRID (MLA + {config.ssm_core} SSM)")
|
| 588 |
+
print(f"{'=' * 70}")
|
| 589 |
+
print(f"Embedding vocab: {self.vocab_size}")
|
| 590 |
+
print(f"Output Head: {(self.d_model * self.vocab_size) / 1e6:>8.2f}M")
|
| 591 |
+
print(f"Attention Blocks: {attention_blocks:>8} layers")
|
| 592 |
+
print(f"SSM Blocks: {ssm_blocks:>8} layers")
|
| 593 |
+
print(f"{'-' * 70}")
|
| 594 |
+
print(f"Total Parameters: {total_params / 1e6:>8.2f}M (trainable: {trainable_params / 1e6:.2f}M)")
|
| 595 |
+
print(f"{'-' * 70}")
|
| 596 |
+
print("Configuration:")
|
| 597 |
+
print(f" Model dimension (d_model): {self.d_model}")
|
| 598 |
+
print(f" KV latent dimension (d_latent_kv): {self.d_latent_kv}")
|
| 599 |
+
print(f" Attention heads: {self.n_heads}")
|
| 600 |
+
print(f" SSM core: {config.ssm_core}")
|
| 601 |
+
print(f" SSM hidden dimension: {ssm_hidden_dim}")
|
| 602 |
+
print(f" SSM mixer dimension: {config.ssm_mixer_dim or self.d_model}")
|
| 603 |
+
print(f" SSM lanes: {config.ssm_num_lanes}")
|
| 604 |
+
print(f" SSM lane mode: {config.ssm_lane_mode}")
|
| 605 |
+
print(f" SSM split mix: {config.ssm_split_mix}")
|
| 606 |
+
print(f" SSM lane combine: {config.ssm_lane_combine}")
|
| 607 |
+
if config.ssm_core == "dplr":
|
| 608 |
+
print(f" SSM DPLR rank: {config.ssm_rank}")
|
| 609 |
+
print(f" SSM finite-tail correction: {config.ssm_finite_tail_correction}")
|
| 610 |
+
print(f" SSM branch RMS norm: {config.ssm_branch_rms_norm}")
|
| 611 |
+
print(f" SSM branch clip value: {config.ssm_branch_clip_value}")
|
| 612 |
+
print(f" Block residual RMS norm: {config.block_residual_rms_norm}")
|
| 613 |
+
print(f" Block residual RMS cap: {config.block_residual_rms_cap}")
|
| 614 |
+
print(f" SSM local shift enabled: {config.ssm_local_shift}")
|
| 615 |
+
print(f" SSM gate type: {config.ssm_gate_type}")
|
| 616 |
+
print(f" Hybrid pattern: {config.hybrid_pattern}")
|
| 617 |
+
print(f" Hybrid SSM layers: {','.join(str(i) for i in sorted(self.ssm_layer_indices))}")
|
| 618 |
+
print(f" Feed-forward dimension: {int(self.d_ff)}")
|
| 619 |
+
print(f" Number of layers: {self.n_layers}")
|
| 620 |
+
print(f" Max sequence length: {self.max_seq_length}")
|
| 621 |
+
print(f" Dropout: {self.dropout}")
|
| 622 |
+
print(f"{'=' * 70}\n")
|
| 623 |
+
|
| 624 |
+
def forward(
|
| 625 |
+
self,
|
| 626 |
+
input_ids: torch.Tensor,
|
| 627 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 628 |
+
labels: Optional[torch.Tensor] = None,
|
| 629 |
+
) -> dict:
|
| 630 |
+
_, seq_len = input_ids.shape
|
| 631 |
+
combined_mask = self._get_combined_mask(attention_mask, seq_len, input_ids.device)
|
| 632 |
+
|
| 633 |
+
x = self.token_embedding(input_ids)
|
| 634 |
+
x = self.embedding_dropout(x)
|
| 635 |
+
|
| 636 |
+
for block in self.blocks:
|
| 637 |
+
x = block(x, attention_mask=combined_mask)
|
| 638 |
+
|
| 639 |
+
x = self.final_norm(x)
|
| 640 |
+
logits = self.output_head(x)
|
| 641 |
+
|
| 642 |
+
loss = None
|
| 643 |
+
if labels is not None:
|
| 644 |
+
loss = F.cross_entropy(
|
| 645 |
+
logits.view(-1, logits.size(-1)),
|
| 646 |
+
labels.view(-1),
|
| 647 |
+
reduction="mean",
|
| 648 |
+
ignore_index=-100,
|
| 649 |
+
)
|
| 650 |
+
|
| 651 |
+
return {
|
| 652 |
+
"logits": logits,
|
| 653 |
+
"loss": loss,
|
| 654 |
+
}
|
code/TaoTrain/src/taoTrain/models/transformer.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Standard Transformer language model implementation."""
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from typing import Optional
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
from taoTrain.core import BaseModel
|
| 10 |
+
from taoTrain.config import ModelConfig
|
| 11 |
+
from .registry import register_architecture
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# ============================================================================
|
| 15 |
+
# Components
|
| 16 |
+
# ============================================================================
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class PositionalEmbedding(nn.Module):
|
| 20 |
+
"""Sinusoidal positional embeddings."""
|
| 21 |
+
|
| 22 |
+
def __init__(self, dim: int, max_seq_length: int = 2048):
|
| 23 |
+
"""Initialize positional embeddings."""
|
| 24 |
+
super().__init__()
|
| 25 |
+
self.dim = dim
|
| 26 |
+
self.max_seq_length = max_seq_length
|
| 27 |
+
|
| 28 |
+
# Precompute positional embeddings
|
| 29 |
+
pe = torch.zeros(max_seq_length, dim)
|
| 30 |
+
pos = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
|
| 31 |
+
div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
|
| 32 |
+
|
| 33 |
+
pe[:, 0::2] = torch.sin(pos * div_term)
|
| 34 |
+
if dim % 2 == 1:
|
| 35 |
+
pe[:, 1::2] = torch.cos(pos * div_term[:-1])
|
| 36 |
+
else:
|
| 37 |
+
pe[:, 1::2] = torch.cos(pos * div_term)
|
| 38 |
+
|
| 39 |
+
self.register_buffer("pe", pe, persistent=False)
|
| 40 |
+
|
| 41 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 42 |
+
"""
|
| 43 |
+
Add positional embeddings to input.
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
x: Input tensor (batch, seq_len, hidden_dim)
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
Input + positional embeddings
|
| 50 |
+
"""
|
| 51 |
+
seq_len = x.shape[1]
|
| 52 |
+
return x + self.pe[:seq_len]
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class Attention(nn.Module):
|
| 56 |
+
"""Multi-head self-attention using scaled dot-product attention."""
|
| 57 |
+
|
| 58 |
+
def __init__(self, config: ModelConfig):
|
| 59 |
+
"""Initialize attention."""
|
| 60 |
+
super().__init__()
|
| 61 |
+
self.hidden_dim = config.hidden_dim
|
| 62 |
+
self.num_heads = config.num_heads
|
| 63 |
+
self.head_dim = config.head_dim
|
| 64 |
+
|
| 65 |
+
assert self.hidden_dim % self.num_heads == 0
|
| 66 |
+
|
| 67 |
+
# Linear projections
|
| 68 |
+
self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim)
|
| 69 |
+
self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim)
|
| 70 |
+
self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim)
|
| 71 |
+
self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim)
|
| 72 |
+
|
| 73 |
+
self.dropout_p = config.dropout
|
| 74 |
+
|
| 75 |
+
def forward(
|
| 76 |
+
self,
|
| 77 |
+
x: torch.Tensor,
|
| 78 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 79 |
+
) -> torch.Tensor:
|
| 80 |
+
"""
|
| 81 |
+
Forward pass using scaled_dot_product_attention.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
x: Shape (batch, seq_len, hidden_dim)
|
| 85 |
+
attention_mask: Shape (batch, seq_len)
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
Output: Shape (batch, seq_len, hidden_dim)
|
| 89 |
+
"""
|
| 90 |
+
batch_size, seq_len, _ = x.shape
|
| 91 |
+
|
| 92 |
+
# Project to Q, K, V
|
| 93 |
+
q = self.q_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim)
|
| 94 |
+
k = self.k_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim)
|
| 95 |
+
v = self.v_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim)
|
| 96 |
+
|
| 97 |
+
# Transpose for attention: (batch, num_heads, seq_len, head_dim)
|
| 98 |
+
q = q.transpose(1, 2)
|
| 99 |
+
k = k.transpose(1, 2)
|
| 100 |
+
v = v.transpose(1, 2)
|
| 101 |
+
|
| 102 |
+
# NOTE: PyTorch's scaled_dot_product_attention does NOT support both
|
| 103 |
+
# explicit attn_mask AND is_causal=True together.
|
| 104 |
+
# When is_causal=True, PyTorch handles causal masking automatically.
|
| 105 |
+
# Padding positions are handled separately via loss computation (labels=-100).
|
| 106 |
+
# See: https://github.com/pytorch/pytorch/issues/96099
|
| 107 |
+
|
| 108 |
+
# Compute attention using scaled_dot_product_attention
|
| 109 |
+
# is_causal=True automatically applies causal masking
|
| 110 |
+
# We do NOT pass attn_mask when is_causal=True
|
| 111 |
+
out = F.scaled_dot_product_attention(
|
| 112 |
+
q, k, v,
|
| 113 |
+
attn_mask=None, # Must be None when is_causal=True
|
| 114 |
+
dropout_p=self.dropout_p if self.training else 0.0,
|
| 115 |
+
is_causal=True,
|
| 116 |
+
scale=None # Uses default scale of 1/sqrt(head_dim)
|
| 117 |
+
) # (batch, num_heads, seq_len, head_dim)
|
| 118 |
+
|
| 119 |
+
# Transpose back and reshape
|
| 120 |
+
out = out.transpose(1, 2).contiguous() # (batch, seq_len, num_heads, head_dim)
|
| 121 |
+
out = out.reshape(batch_size, seq_len, self.hidden_dim)
|
| 122 |
+
|
| 123 |
+
# Output projection
|
| 124 |
+
out = self.out_proj(out)
|
| 125 |
+
|
| 126 |
+
return out
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class SwiGLU(nn.Module):
|
| 130 |
+
"""Swish Gated Linear Unit activation."""
|
| 131 |
+
|
| 132 |
+
def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.0):
|
| 133 |
+
"""
|
| 134 |
+
Initialize SwiGLU.
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
in_dim: Input dimension
|
| 138 |
+
out_dim: Intermediate/hidden dimension
|
| 139 |
+
dropout: Dropout rate
|
| 140 |
+
"""
|
| 141 |
+
super().__init__()
|
| 142 |
+
# Project to 2x the intermediate dimension (for value and gate)
|
| 143 |
+
self.fc1 = nn.Linear(in_dim, 2 * out_dim)
|
| 144 |
+
self.fc2 = nn.Linear(out_dim, in_dim) # Project back to input dimension
|
| 145 |
+
self.dropout = nn.Dropout(dropout)
|
| 146 |
+
|
| 147 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 148 |
+
"""
|
| 149 |
+
Forward pass with SwiGLU activation.
|
| 150 |
+
|
| 151 |
+
Args:
|
| 152 |
+
x: Input tensor
|
| 153 |
+
|
| 154 |
+
Returns:
|
| 155 |
+
Gated activation output (same dimension as input)
|
| 156 |
+
"""
|
| 157 |
+
# Project to 2x intermediate dimension
|
| 158 |
+
x = self.fc1(x)
|
| 159 |
+
|
| 160 |
+
# Split into value and gate
|
| 161 |
+
x, gate = x.chunk(2, dim=-1)
|
| 162 |
+
|
| 163 |
+
# SwiGLU: value * swish(gate) = value * gate * sigmoid(gate)
|
| 164 |
+
x = x * F.silu(gate) # SiLU is Swish: x * sigmoid(x)
|
| 165 |
+
|
| 166 |
+
x = self.dropout(x)
|
| 167 |
+
x = self.fc2(x) # Project back to input dimension
|
| 168 |
+
|
| 169 |
+
return x
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class FeedForward(nn.Module):
|
| 173 |
+
"""Feed-forward network with SwiGLU activation."""
|
| 174 |
+
|
| 175 |
+
def __init__(self, config: ModelConfig):
|
| 176 |
+
"""Initialize FFN with SwiGLU."""
|
| 177 |
+
super().__init__()
|
| 178 |
+
self.swiglu = SwiGLU(
|
| 179 |
+
in_dim=config.hidden_dim,
|
| 180 |
+
out_dim=config.intermediate_dim,
|
| 181 |
+
dropout=config.dropout
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 185 |
+
"""Forward pass with SwiGLU activation."""
|
| 186 |
+
return self.swiglu(x)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
class TransformerBlock(nn.Module):
|
| 190 |
+
"""Single transformer block with attention and FFN."""
|
| 191 |
+
|
| 192 |
+
def __init__(self, config: ModelConfig):
|
| 193 |
+
"""Initialize transformer block."""
|
| 194 |
+
super().__init__()
|
| 195 |
+
self.norm1 = nn.LayerNorm(config.hidden_dim)
|
| 196 |
+
self.attn = Attention(config)
|
| 197 |
+
self.norm2 = nn.LayerNorm(config.hidden_dim)
|
| 198 |
+
self.ffn = FeedForward(config)
|
| 199 |
+
|
| 200 |
+
def forward(
|
| 201 |
+
self,
|
| 202 |
+
x: torch.Tensor,
|
| 203 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 204 |
+
) -> torch.Tensor:
|
| 205 |
+
"""Forward pass with pre-norm residual connections."""
|
| 206 |
+
# Attention with residual
|
| 207 |
+
x = x + self.attn(self.norm1(x), attention_mask=attention_mask)
|
| 208 |
+
|
| 209 |
+
# FFN with residual
|
| 210 |
+
x = x + self.ffn(self.norm2(x))
|
| 211 |
+
|
| 212 |
+
return x
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
# ============================================================================
|
| 216 |
+
# Transformer LM
|
| 217 |
+
# ============================================================================
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
@register_architecture("transformer")
|
| 221 |
+
class TransformerLM(BaseModel):
|
| 222 |
+
"""Standard Transformer language model."""
|
| 223 |
+
|
| 224 |
+
def __init__(self, config: ModelConfig):
|
| 225 |
+
"""Initialize Transformer LM."""
|
| 226 |
+
super().__init__(config)
|
| 227 |
+
|
| 228 |
+
# Embeddings
|
| 229 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_dim)
|
| 230 |
+
self.pos_embed = PositionalEmbedding(config.hidden_dim, max_seq_length=config.max_seq_length)
|
| 231 |
+
self.dropout = nn.Dropout(config.dropout)
|
| 232 |
+
|
| 233 |
+
# Transformer blocks
|
| 234 |
+
self.blocks = nn.ModuleList([
|
| 235 |
+
TransformerBlock(config) for _ in range(config.num_layers)
|
| 236 |
+
])
|
| 237 |
+
|
| 238 |
+
# Final layer norm
|
| 239 |
+
self.final_norm = nn.LayerNorm(config.hidden_dim)
|
| 240 |
+
|
| 241 |
+
# Output projection (shared with input embeddings for efficiency)
|
| 242 |
+
self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
|
| 243 |
+
|
| 244 |
+
# Weight tying (optional)
|
| 245 |
+
self.lm_head.weight = self.embed_tokens.weight
|
| 246 |
+
|
| 247 |
+
# Initialize weights
|
| 248 |
+
self._init_weights()
|
| 249 |
+
|
| 250 |
+
def _init_weights(self):
|
| 251 |
+
"""Initialize model weights."""
|
| 252 |
+
for module in self.modules():
|
| 253 |
+
if isinstance(module, nn.Linear):
|
| 254 |
+
nn.init.normal_(module.weight, std=self.config.init_std)
|
| 255 |
+
if module.bias is not None:
|
| 256 |
+
nn.init.zeros_(module.bias)
|
| 257 |
+
elif isinstance(module, nn.Embedding):
|
| 258 |
+
nn.init.normal_(module.weight, std=self.config.init_std)
|
| 259 |
+
|
| 260 |
+
def forward(
|
| 261 |
+
self,
|
| 262 |
+
input_ids: torch.Tensor,
|
| 263 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 264 |
+
labels: Optional[torch.Tensor] = None,
|
| 265 |
+
) -> dict[str, torch.Tensor]:
|
| 266 |
+
"""
|
| 267 |
+
Forward pass.
|
| 268 |
+
|
| 269 |
+
Args:
|
| 270 |
+
input_ids: (batch_size, seq_len)
|
| 271 |
+
attention_mask: (batch_size, seq_len)
|
| 272 |
+
labels: (batch_size, seq_len) for loss computation
|
| 273 |
+
|
| 274 |
+
Returns:
|
| 275 |
+
Dict with 'logits' and optionally 'loss'
|
| 276 |
+
"""
|
| 277 |
+
batch_size, seq_len = input_ids.shape
|
| 278 |
+
|
| 279 |
+
# Embedding
|
| 280 |
+
x = self.embed_tokens(input_ids)
|
| 281 |
+
|
| 282 |
+
# Add positional embeddings
|
| 283 |
+
x = self.pos_embed(x)
|
| 284 |
+
|
| 285 |
+
x = self.dropout(x)
|
| 286 |
+
|
| 287 |
+
# Transformer blocks
|
| 288 |
+
for block in self.blocks:
|
| 289 |
+
x = block(x, attention_mask=attention_mask)
|
| 290 |
+
|
| 291 |
+
# Final normalization
|
| 292 |
+
x = self.final_norm(x)
|
| 293 |
+
|
| 294 |
+
# LM head
|
| 295 |
+
logits = self.lm_head(x) # (batch, seq_len, vocab_size)
|
| 296 |
+
|
| 297 |
+
# Loss computation
|
| 298 |
+
loss = None
|
| 299 |
+
if labels is not None:
|
| 300 |
+
# Flatten for loss computation
|
| 301 |
+
logits_flat = logits.view(-1, logits.size(-1)) # (batch * seq_len, vocab_size)
|
| 302 |
+
labels_flat = labels.view(-1)
|
| 303 |
+
|
| 304 |
+
# Only compute loss on valid targets (ignore -100 tokens)
|
| 305 |
+
loss = F.cross_entropy(
|
| 306 |
+
logits_flat,
|
| 307 |
+
labels_flat,
|
| 308 |
+
reduction='mean',
|
| 309 |
+
ignore_index=-100
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
return {
|
| 313 |
+
'logits': logits,
|
| 314 |
+
'loss': loss,
|
| 315 |
+
}
|
code/TaoTrain/src/taoTrain/optimizers/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Optimizer registry and factories."""
|
| 2 |
+
|
| 3 |
+
from .registry import (
|
| 4 |
+
register_optimizer,
|
| 5 |
+
get_optimizer,
|
| 6 |
+
get_registered_optimizers,
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"register_optimizer",
|
| 11 |
+
"get_optimizer",
|
| 12 |
+
"get_registered_optimizers",
|
| 13 |
+
]
|
code/TaoTrain/src/taoTrain/optimizers/adam.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Adam optimizer factory."""
|
| 2 |
+
|
| 3 |
+
import torch.optim as optim
|
| 4 |
+
from taoTrain.core.base import BaseModel
|
| 5 |
+
from taoTrain.config import TrainingConfig
|
| 6 |
+
from .registry import register_optimizer
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _separate_parameters(model: BaseModel) -> tuple[list, list]:
|
| 10 |
+
"""
|
| 11 |
+
Separate model parameters into decay and no-decay groups.
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
model: Model instance
|
| 15 |
+
|
| 16 |
+
Returns:
|
| 17 |
+
Tuple of (decay_params, no_decay_params)
|
| 18 |
+
"""
|
| 19 |
+
decay_params = []
|
| 20 |
+
no_decay_params = []
|
| 21 |
+
|
| 22 |
+
for name, param in model.named_parameters():
|
| 23 |
+
if not param.requires_grad:
|
| 24 |
+
continue
|
| 25 |
+
|
| 26 |
+
# Apply weight decay to all params except biases and layer norms
|
| 27 |
+
if 'bias' in name or 'norm' in name:
|
| 28 |
+
no_decay_params.append(param)
|
| 29 |
+
else:
|
| 30 |
+
decay_params.append(param)
|
| 31 |
+
|
| 32 |
+
return decay_params, no_decay_params
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@register_optimizer("adam")
|
| 36 |
+
def create_adam(model: BaseModel, config: TrainingConfig) -> optim.Adam:
|
| 37 |
+
"""
|
| 38 |
+
Create Adam optimizer with weight decay applied selectively.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
model: Model instance
|
| 42 |
+
config: TrainingConfig
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
Adam optimizer instance
|
| 46 |
+
"""
|
| 47 |
+
optimizer_config = config.optimizer
|
| 48 |
+
|
| 49 |
+
# Separate parameters for weight decay
|
| 50 |
+
decay_params, no_decay_params = _separate_parameters(model)
|
| 51 |
+
|
| 52 |
+
param_groups = [
|
| 53 |
+
{"params": decay_params, "weight_decay": optimizer_config.weight_decay},
|
| 54 |
+
{"params": no_decay_params, "weight_decay": 0.0},
|
| 55 |
+
]
|
| 56 |
+
|
| 57 |
+
optimizer = optim.Adam(
|
| 58 |
+
param_groups,
|
| 59 |
+
lr=optimizer_config.learning_rate,
|
| 60 |
+
betas=optimizer_config.betas,
|
| 61 |
+
eps=optimizer_config.eps,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
return optimizer
|
code/TaoTrain/src/taoTrain/optimizers/adamw.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""AdamW optimizer factory."""
|
| 2 |
+
|
| 3 |
+
import torch.optim as optim
|
| 4 |
+
from taoTrain.core.base import BaseModel
|
| 5 |
+
from taoTrain.config import TrainingConfig
|
| 6 |
+
from .registry import register_optimizer
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _separate_parameters(model: BaseModel) -> tuple[list, list]:
|
| 10 |
+
"""
|
| 11 |
+
Separate model parameters into decay and no-decay groups.
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
model: Model instance
|
| 15 |
+
|
| 16 |
+
Returns:
|
| 17 |
+
Tuple of (decay_params, no_decay_params)
|
| 18 |
+
"""
|
| 19 |
+
decay_params = []
|
| 20 |
+
no_decay_params = []
|
| 21 |
+
|
| 22 |
+
for name, param in model.named_parameters():
|
| 23 |
+
if not param.requires_grad:
|
| 24 |
+
continue
|
| 25 |
+
|
| 26 |
+
# Apply weight decay to all params except biases and layer norms
|
| 27 |
+
if 'bias' in name or 'norm' in name:
|
| 28 |
+
no_decay_params.append(param)
|
| 29 |
+
else:
|
| 30 |
+
decay_params.append(param)
|
| 31 |
+
|
| 32 |
+
return decay_params, no_decay_params
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@register_optimizer("adamw")
|
| 36 |
+
def create_adamw(model: BaseModel, config: TrainingConfig) -> optim.AdamW:
|
| 37 |
+
"""
|
| 38 |
+
Create AdamW optimizer with weight decay applied selectively.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
model: Model instance
|
| 42 |
+
config: TrainingConfig
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
AdamW optimizer instance
|
| 46 |
+
"""
|
| 47 |
+
optimizer_config = config.optimizer
|
| 48 |
+
|
| 49 |
+
# Separate parameters for weight decay
|
| 50 |
+
decay_params, no_decay_params = _separate_parameters(model)
|
| 51 |
+
|
| 52 |
+
param_groups = [
|
| 53 |
+
{"params": decay_params, "weight_decay": optimizer_config.weight_decay},
|
| 54 |
+
{"params": no_decay_params, "weight_decay": 0.0},
|
| 55 |
+
]
|
| 56 |
+
|
| 57 |
+
optimizer = optim.AdamW(
|
| 58 |
+
param_groups,
|
| 59 |
+
lr=optimizer_config.learning_rate,
|
| 60 |
+
betas=optimizer_config.betas,
|
| 61 |
+
eps=optimizer_config.eps,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
return optimizer
|
code/TaoTrain/src/taoTrain/optimizers/hybrid_muon_adamw.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hybrid Muon + AdamW Optimizer for TaoNet models.
|
| 3 |
+
|
| 4 |
+
Combines:
|
| 5 |
+
- Muon: Specialized optimization for 2D weight matrices (linear layers)
|
| 6 |
+
Leverages orthogonal/SVD-based updates for better convergence on matrix weights
|
| 7 |
+
- AdamW: Adaptive moment estimation for 1D parameters (biases, norms, embeddings)
|
| 8 |
+
|
| 9 |
+
Key Design:
|
| 10 |
+
- 2D weight matrices use Muon optimizer with separate LRs for different layer types
|
| 11 |
+
- 1D parameters use AdamW with lower learning rate
|
| 12 |
+
- Inherits from torch.optim.Optimizer for LR scheduler compatibility
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
from typing import Dict, List, Any
|
| 18 |
+
|
| 19 |
+
from .registry import register_optimizer
|
| 20 |
+
from taoTrain.config import TrainingConfig
|
| 21 |
+
from taoTrain.core.base import BaseModel
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _get_param_dimensionality(param: torch.Tensor) -> str:
|
| 25 |
+
"""
|
| 26 |
+
Determine if a parameter is 2D (weight matrix) or 1D (bias/embedding/norm).
|
| 27 |
+
|
| 28 |
+
Returns:
|
| 29 |
+
'weight_2d': Parameter has 2+ dimensions (for Muon)
|
| 30 |
+
'1d_other': Parameter is 1D (for AdamW)
|
| 31 |
+
"""
|
| 32 |
+
if param.dim() >= 2:
|
| 33 |
+
return 'weight_2d'
|
| 34 |
+
return '1d_other'
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class HybridMuonAdamW(torch.optim.Optimizer):
|
| 38 |
+
"""
|
| 39 |
+
Composite optimizer combining Muon (for 2D weights) and AdamW (for 1D params).
|
| 40 |
+
|
| 41 |
+
Why: Muon is specialized for 2D weight matrices in neural networks.
|
| 42 |
+
Biases, embeddings, and layer norms should use AdamW for adaptive convergence.
|
| 43 |
+
|
| 44 |
+
Inherits from torch.optim.Optimizer to be compatible with LR schedulers.
|
| 45 |
+
Manages two internal optimizers: Muon and AdamW.
|
| 46 |
+
|
| 47 |
+
Public interface compatible with standard PyTorch optimizers:
|
| 48 |
+
- step(): delegates to both internal optimizers
|
| 49 |
+
- zero_grad(set_to_none=True): delegates to both
|
| 50 |
+
- state_dict(): returns combined state
|
| 51 |
+
- load_state_dict(state): restores combined state
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
def __init__(
|
| 55 |
+
self,
|
| 56 |
+
muon_params_groups: List[Dict[str, Any]],
|
| 57 |
+
adamw_params_group: Dict[str, Any],
|
| 58 |
+
muon_kwargs: Dict[str, Any],
|
| 59 |
+
adamw_kwargs: Dict[str, Any]
|
| 60 |
+
):
|
| 61 |
+
"""
|
| 62 |
+
Initialize HybridMuonAdamW optimizer.
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
muon_params_groups: List of param groups for Muon optimizer
|
| 66 |
+
Each group should have 'params' and 'lr' keys
|
| 67 |
+
adamw_params_group: Dict param group for AdamW optimizer
|
| 68 |
+
Should have 'params' and 'lr' keys
|
| 69 |
+
muon_kwargs: Additional kwargs for torch.optim.Muon init
|
| 70 |
+
adamw_kwargs: Additional kwargs for torch.optim.AdamW init
|
| 71 |
+
"""
|
| 72 |
+
# Dummy params list for parent Optimizer class (required for registration)
|
| 73 |
+
# Real params are managed by internal optimizers
|
| 74 |
+
dummy_param = torch.nn.Parameter(torch.zeros(1))
|
| 75 |
+
super().__init__([dummy_param], {})
|
| 76 |
+
|
| 77 |
+
# Create internal optimizers with their parameter groups
|
| 78 |
+
try:
|
| 79 |
+
self.muon = torch.optim.Muon(muon_params_groups, **muon_kwargs)
|
| 80 |
+
except AttributeError:
|
| 81 |
+
raise RuntimeError(
|
| 82 |
+
"torch.optim.Muon not available. "
|
| 83 |
+
"Muon optimizer requires PyTorch 2.1+. "
|
| 84 |
+
"Please upgrade PyTorch: pip install --upgrade torch"
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
self.adamw = torch.optim.AdamW([adamw_params_group], **adamw_kwargs)
|
| 88 |
+
|
| 89 |
+
# Merge param_groups from both optimizers
|
| 90 |
+
# LR schedulers will update these merged groups
|
| 91 |
+
self.param_groups = self.muon.param_groups + self.adamw.param_groups
|
| 92 |
+
|
| 93 |
+
def step(self, closure=None):
|
| 94 |
+
"""Execute optimization step for both Muon and AdamW."""
|
| 95 |
+
if closure is not None:
|
| 96 |
+
loss = closure()
|
| 97 |
+
else:
|
| 98 |
+
loss = None
|
| 99 |
+
|
| 100 |
+
self.muon.step(closure)
|
| 101 |
+
self.adamw.step(closure)
|
| 102 |
+
|
| 103 |
+
return loss
|
| 104 |
+
|
| 105 |
+
def zero_grad(self, set_to_none: bool = False):
|
| 106 |
+
"""Zero gradients in both optimizers."""
|
| 107 |
+
self.muon.zero_grad(set_to_none=set_to_none)
|
| 108 |
+
self.adamw.zero_grad(set_to_none=set_to_none)
|
| 109 |
+
|
| 110 |
+
def state_dict(self) -> Dict[str, Any]:
|
| 111 |
+
"""Return combined state dict for both optimizers."""
|
| 112 |
+
return {
|
| 113 |
+
'muon': self.muon.state_dict(),
|
| 114 |
+
'adamw': self.adamw.state_dict(),
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
def load_state_dict(self, state_dict: Dict[str, Any]):
|
| 118 |
+
"""
|
| 119 |
+
Restore state from combined state dict.
|
| 120 |
+
|
| 121 |
+
Supports both new format (composite with Muon+AdamW) and legacy format
|
| 122 |
+
(AdamW-only checkpoints) for backward compatibility.
|
| 123 |
+
"""
|
| 124 |
+
if isinstance(state_dict, dict):
|
| 125 |
+
if 'muon' in state_dict and 'adamw' in state_dict:
|
| 126 |
+
# New format: composite optimizer with both Muon and AdamW
|
| 127 |
+
self.muon.load_state_dict(state_dict['muon'])
|
| 128 |
+
self.adamw.load_state_dict(state_dict['adamw'])
|
| 129 |
+
elif 'state' in state_dict or 'param_groups' in state_dict:
|
| 130 |
+
# Legacy format: old AdamW-only checkpoint
|
| 131 |
+
# Load into AdamW optimizer only, Muon starts fresh
|
| 132 |
+
try:
|
| 133 |
+
self.adamw.load_state_dict(state_dict)
|
| 134 |
+
print(" ⚠️ Loaded legacy AdamW-only checkpoint (Muon state initialized fresh)")
|
| 135 |
+
except Exception as e:
|
| 136 |
+
print(f" ⚠️ Failed to load optimizer state: {e}")
|
| 137 |
+
else:
|
| 138 |
+
print(f" ⚠️ Unknown checkpoint format")
|
| 139 |
+
else:
|
| 140 |
+
raise ValueError(f"Expected dict, got {type(state_dict)}")
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
@register_optimizer("hybrid_muon_adamw")
|
| 144 |
+
def create_hybrid_muon_adamw(model: BaseModel, training_config: TrainingConfig) -> HybridMuonAdamW:
|
| 145 |
+
"""
|
| 146 |
+
Factory function to create HybridMuonAdamW optimizer from model and config.
|
| 147 |
+
|
| 148 |
+
Parameter grouping strategy:
|
| 149 |
+
- Muon groups (2D weight matrices):
|
| 150 |
+
* Regular Linear 2D weights → learning_rate
|
| 151 |
+
* (BitLinear would use bitlinear_lr, but skipped in BF16 version)
|
| 152 |
+
- AdamW group (1D parameters):
|
| 153 |
+
* Biases, layer norms, embeddings → adamw_lr
|
| 154 |
+
|
| 155 |
+
Args:
|
| 156 |
+
model: PyTorch model to optimize
|
| 157 |
+
training_config: TrainingConfig with optimizer hyperparameters:
|
| 158 |
+
- learning_rate: LR for 2D Linear weights (Muon)
|
| 159 |
+
- adamw_lr: LR for 1D parameters (AdamW)
|
| 160 |
+
- weight_decay: L2 regularization
|
| 161 |
+
- betas: (beta1, beta2) for AdamW
|
| 162 |
+
- eps: epsilon for numerical stability
|
| 163 |
+
|
| 164 |
+
Returns:
|
| 165 |
+
HybridMuonAdamW optimizer instance
|
| 166 |
+
"""
|
| 167 |
+
|
| 168 |
+
# Separate parameters by dimensionality
|
| 169 |
+
linear_2d_weights = []
|
| 170 |
+
params_1d = []
|
| 171 |
+
|
| 172 |
+
# Classify all parameters
|
| 173 |
+
for module_name, module in model.named_modules():
|
| 174 |
+
for param_name, param in module.named_parameters(recurse=False):
|
| 175 |
+
if not param.requires_grad:
|
| 176 |
+
continue
|
| 177 |
+
|
| 178 |
+
param_dim = _get_param_dimensionality(param)
|
| 179 |
+
|
| 180 |
+
if param_dim == 'weight_2d' and isinstance(module, nn.Linear):
|
| 181 |
+
# 2D Linear weights → Muon
|
| 182 |
+
linear_2d_weights.append(param)
|
| 183 |
+
else:
|
| 184 |
+
# Everything else → AdamW (1D params + other 2D tensors)
|
| 185 |
+
params_1d.append(param)
|
| 186 |
+
|
| 187 |
+
# Verify we got all parameters
|
| 188 |
+
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 189 |
+
muon_params = sum(p.numel() for p in linear_2d_weights)
|
| 190 |
+
adamw_params = sum(p.numel() for p in params_1d)
|
| 191 |
+
assert total_params == muon_params + adamw_params, \
|
| 192 |
+
f"Parameter accounting error: {total_params} != {muon_params} + {adamw_params}"
|
| 193 |
+
|
| 194 |
+
# Prepare Muon parameter groups (one group with single LR for all Linear 2D weights)
|
| 195 |
+
muon_params_groups = [
|
| 196 |
+
{
|
| 197 |
+
'params': linear_2d_weights,
|
| 198 |
+
'lr': training_config.optimizer.learning_rate, # Use main learning_rate for Muon
|
| 199 |
+
}
|
| 200 |
+
]
|
| 201 |
+
|
| 202 |
+
# Prepare AdamW parameter group (1D parameters with lower LR)
|
| 203 |
+
adamw_params_group = {
|
| 204 |
+
'params': params_1d,
|
| 205 |
+
'lr': training_config.optimizer.adamw_lr, # Use adamw_lr for 1D params
|
| 206 |
+
'weight_decay': training_config.optimizer.weight_decay,
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
# Extract Muon kwargs (settings common to all Muon param groups)
|
| 210 |
+
muon_kwargs = {
|
| 211 |
+
'lr': training_config.optimizer.learning_rate, # Will be overridden by param_groups above
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
# Extract AdamW kwargs
|
| 215 |
+
adamw_kwargs = {
|
| 216 |
+
'betas': training_config.optimizer.betas,
|
| 217 |
+
'eps': training_config.optimizer.eps,
|
| 218 |
+
'weight_decay': training_config.optimizer.weight_decay,
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
# Print optimizer setup details
|
| 222 |
+
print(f"\n{'='*70}")
|
| 223 |
+
print("OPTIMIZER SETUP - HYBRID MUON + ADAMW")
|
| 224 |
+
print(f"{'='*70}")
|
| 225 |
+
print("\n[MUON - 2D Weight Matrices (Orthogonal Optimization)]")
|
| 226 |
+
print(f"Linear 2D weights: {muon_params/1e6:>8.2f}M")
|
| 227 |
+
print(f" Learning Rate: {training_config.optimizer.learning_rate}")
|
| 228 |
+
print(f"\n[ADAMW - 1D Parameters (Adaptive Moments)]")
|
| 229 |
+
print(f"Biases, embeddings, norms: {adamw_params/1e6:>8.2f}M")
|
| 230 |
+
print(f" Learning Rate: {training_config.optimizer.adamw_lr}")
|
| 231 |
+
print(f"{'─'*70}")
|
| 232 |
+
print(f"Total (Muon): {muon_params/1e6:>8.2f}M")
|
| 233 |
+
print(f"Total (AdamW): {adamw_params/1e6:>8.2f}M")
|
| 234 |
+
print(f"Total (All): {total_params/1e6:>8.2f}M")
|
| 235 |
+
print(f"{'─'*70}")
|
| 236 |
+
print(f"Hyperparameters:")
|
| 237 |
+
print(f" Weight Decay: {training_config.optimizer.weight_decay}")
|
| 238 |
+
print(f" Betas (AdamW): {training_config.optimizer.betas}")
|
| 239 |
+
print(f" Epsilon: {training_config.optimizer.eps}")
|
| 240 |
+
print(f"{'='*70}\n")
|
| 241 |
+
|
| 242 |
+
# Create and return optimizer
|
| 243 |
+
return HybridMuonAdamW(muon_params_groups, adamw_params_group, muon_kwargs, adamw_kwargs)
|
code/TaoTrain/src/taoTrain/optimizers/registry.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Optimizer registry and factory for instantiating optimizers."""
|
| 2 |
+
|
| 3 |
+
from typing import Dict, Type, Callable, Any
|
| 4 |
+
import torch.optim as optim
|
| 5 |
+
from taoTrain.core.base import BaseModel
|
| 6 |
+
from taoTrain.config import TrainingConfig, OptimizerEnum
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# Global registry for optimizers
|
| 10 |
+
_OPTIMIZER_REGISTRY: Dict[str, Callable] = {}
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def register_optimizer(name: str):
|
| 14 |
+
"""
|
| 15 |
+
Decorator to register a custom optimizer factory function.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
name: Name of the optimizer (e.g., 'adamw', 'adam', 'sgd')
|
| 19 |
+
"""
|
| 20 |
+
def decorator(fn: Callable) -> Callable:
|
| 21 |
+
if name in _OPTIMIZER_REGISTRY:
|
| 22 |
+
raise ValueError(f"Optimizer '{name}' is already registered")
|
| 23 |
+
_OPTIMIZER_REGISTRY[name] = fn
|
| 24 |
+
return fn
|
| 25 |
+
return decorator
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def get_registered_optimizers() -> Dict[str, Callable]:
|
| 29 |
+
"""Get all registered optimizer factory functions."""
|
| 30 |
+
return _OPTIMIZER_REGISTRY.copy()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_optimizer(
|
| 34 |
+
model: BaseModel,
|
| 35 |
+
config: TrainingConfig,
|
| 36 |
+
) -> optim.Optimizer:
|
| 37 |
+
"""
|
| 38 |
+
Create an optimizer instance from config.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
model: Model to optimize
|
| 42 |
+
config: TrainingConfig with optimizer configuration
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
Optimizer instance
|
| 46 |
+
|
| 47 |
+
Raises:
|
| 48 |
+
ValueError: If optimizer type is not registered
|
| 49 |
+
"""
|
| 50 |
+
# Handle both enum and string values
|
| 51 |
+
optimizer_type = config.optimizer.optimizer_type
|
| 52 |
+
if isinstance(optimizer_type, str):
|
| 53 |
+
optimizer_name = optimizer_type
|
| 54 |
+
else:
|
| 55 |
+
optimizer_name = optimizer_type.value
|
| 56 |
+
|
| 57 |
+
if optimizer_name not in _OPTIMIZER_REGISTRY:
|
| 58 |
+
raise ValueError(
|
| 59 |
+
f"Unknown optimizer: {optimizer_name}. "
|
| 60 |
+
f"Available: {list(_OPTIMIZER_REGISTRY.keys())}"
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
factory_fn = _OPTIMIZER_REGISTRY[optimizer_name]
|
| 64 |
+
return factory_fn(model, config)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def register_builtin_optimizers():
|
| 68 |
+
"""Register all built-in optimizers."""
|
| 69 |
+
# Import here to trigger decorator registration (avoid circular imports)
|
| 70 |
+
from . import adamw # noqa: F401
|
| 71 |
+
from . import adam # noqa: F401
|
| 72 |
+
from . import sgd # noqa: F401
|
| 73 |
+
from . import hybrid_muon_adamw # noqa: F401
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Auto-register built-in optimizers when module is imported
|
| 77 |
+
register_builtin_optimizers()
|
code/TaoTrain/src/taoTrain/optimizers/sgd.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SGD optimizer factory."""
|
| 2 |
+
|
| 3 |
+
import torch.optim as optim
|
| 4 |
+
from taoTrain.core.base import BaseModel
|
| 5 |
+
from taoTrain.config import TrainingConfig
|
| 6 |
+
from .registry import register_optimizer
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _separate_parameters(model: BaseModel) -> tuple[list, list]:
|
| 10 |
+
"""
|
| 11 |
+
Separate model parameters into decay and no-decay groups.
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
model: Model instance
|
| 15 |
+
|
| 16 |
+
Returns:
|
| 17 |
+
Tuple of (decay_params, no_decay_params)
|
| 18 |
+
"""
|
| 19 |
+
decay_params = []
|
| 20 |
+
no_decay_params = []
|
| 21 |
+
|
| 22 |
+
for name, param in model.named_parameters():
|
| 23 |
+
if not param.requires_grad:
|
| 24 |
+
continue
|
| 25 |
+
|
| 26 |
+
# Apply weight decay to all params except biases and layer norms
|
| 27 |
+
if 'bias' in name or 'norm' in name:
|
| 28 |
+
no_decay_params.append(param)
|
| 29 |
+
else:
|
| 30 |
+
decay_params.append(param)
|
| 31 |
+
|
| 32 |
+
return decay_params, no_decay_params
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@register_optimizer("sgd")
|
| 36 |
+
def create_sgd(model: BaseModel, config: TrainingConfig) -> optim.SGD:
|
| 37 |
+
"""
|
| 38 |
+
Create SGD optimizer with weight decay applied selectively.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
model: Model instance
|
| 42 |
+
config: TrainingConfig
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
SGD optimizer instance
|
| 46 |
+
"""
|
| 47 |
+
optimizer_config = config.optimizer
|
| 48 |
+
|
| 49 |
+
# Separate parameters for weight decay
|
| 50 |
+
decay_params, no_decay_params = _separate_parameters(model)
|
| 51 |
+
|
| 52 |
+
param_groups = [
|
| 53 |
+
{"params": decay_params, "weight_decay": optimizer_config.weight_decay},
|
| 54 |
+
{"params": no_decay_params, "weight_decay": 0.0},
|
| 55 |
+
]
|
| 56 |
+
|
| 57 |
+
optimizer = optim.SGD(
|
| 58 |
+
param_groups,
|
| 59 |
+
lr=optimizer_config.learning_rate,
|
| 60 |
+
momentum=optimizer_config.betas[0], # Use first beta as momentum
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
return optimizer
|
code/TaoTrain/src/taoTrain/schedulers/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Learning rate scheduler registry and factories."""
|
| 2 |
+
|
| 3 |
+
from .registry import (
|
| 4 |
+
register_scheduler,
|
| 5 |
+
get_scheduler,
|
| 6 |
+
get_registered_schedulers,
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"register_scheduler",
|
| 11 |
+
"get_scheduler",
|
| 12 |
+
"get_registered_schedulers",
|
| 13 |
+
]
|
code/TaoTrain/src/taoTrain/schedulers/constant.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Constant learning rate scheduler with optional warmup."""
|
| 2 |
+
|
| 3 |
+
import torch.optim as optim
|
| 4 |
+
from torch.optim.lr_scheduler import LambdaLR
|
| 5 |
+
from taoTrain.config import TrainingConfig
|
| 6 |
+
from .registry import register_scheduler
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@register_scheduler("constant")
|
| 10 |
+
def create_constant(
|
| 11 |
+
optimizer: optim.Optimizer,
|
| 12 |
+
config: TrainingConfig,
|
| 13 |
+
num_training_steps: int,
|
| 14 |
+
) -> LambdaLR:
|
| 15 |
+
"""
|
| 16 |
+
Create a constant learning rate scheduler with optional linear warmup.
|
| 17 |
+
|
| 18 |
+
Linearly increases learning rate from 0 to peak over warmup steps,
|
| 19 |
+
then keeps it constant for the rest of training.
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
optimizer: Optimizer instance
|
| 23 |
+
config: TrainingConfig with scheduler configuration
|
| 24 |
+
num_training_steps: Total number of training steps
|
| 25 |
+
|
| 26 |
+
Returns:
|
| 27 |
+
LambdaLR scheduler instance
|
| 28 |
+
"""
|
| 29 |
+
scheduler_config = config.scheduler
|
| 30 |
+
|
| 31 |
+
# Determine warmup steps
|
| 32 |
+
if scheduler_config.warmup_steps > 0:
|
| 33 |
+
warmup_steps = scheduler_config.warmup_steps
|
| 34 |
+
else:
|
| 35 |
+
warmup_steps = int(num_training_steps * scheduler_config.warmup_ratio)
|
| 36 |
+
|
| 37 |
+
def lr_lambda(step):
|
| 38 |
+
"""Constant learning rate with optional warmup."""
|
| 39 |
+
if step < warmup_steps:
|
| 40 |
+
# Linear warmup
|
| 41 |
+
return float(step) / float(max(1, warmup_steps))
|
| 42 |
+
return 1.0
|
| 43 |
+
|
| 44 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch=scheduler_config.last_epoch)
|
code/TaoTrain/src/taoTrain/schedulers/cosine_warmup.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Cosine annealing with warmup learning rate scheduler."""
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import torch.optim as optim
|
| 5 |
+
from torch.optim.lr_scheduler import LambdaLR
|
| 6 |
+
from taoTrain.config import TrainingConfig
|
| 7 |
+
from .registry import register_scheduler
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@register_scheduler("cosineWarmup")
|
| 11 |
+
def create_cosine_warmup(
|
| 12 |
+
optimizer: optim.Optimizer,
|
| 13 |
+
config: TrainingConfig,
|
| 14 |
+
num_training_steps: int,
|
| 15 |
+
) -> LambdaLR:
|
| 16 |
+
"""
|
| 17 |
+
Create a cosine annealing scheduler with optional linear warmup, steady phase, and decay.
|
| 18 |
+
|
| 19 |
+
Three-phase schedule:
|
| 20 |
+
1. Linear warmup: 0 → 1.0 (warmup_steps)
|
| 21 |
+
2. Steady phase: 1.0 (plateau at peak LR)
|
| 22 |
+
3. Cosine decay: 1.0 → min_lr_ratio
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
optimizer: Optimizer instance
|
| 26 |
+
config: TrainingConfig with scheduler configuration:
|
| 27 |
+
- warmup_steps: linear warmup duration (overrides warmup_ratio if > 0)
|
| 28 |
+
- warmup_ratio: warmup as fraction of total steps (default 0.1)
|
| 29 |
+
- steady_ratio: steady phase as fraction of total steps (default 0.0)
|
| 30 |
+
- min_lr_ratio: minimum LR at end as fraction of peak (default 0.0)
|
| 31 |
+
num_training_steps: Total number of training steps
|
| 32 |
+
|
| 33 |
+
Returns:
|
| 34 |
+
LambdaLR scheduler instance
|
| 35 |
+
"""
|
| 36 |
+
scheduler_config = config.scheduler
|
| 37 |
+
|
| 38 |
+
# Determine warmup steps
|
| 39 |
+
if scheduler_config.warmup_steps > 0:
|
| 40 |
+
warmup_steps = scheduler_config.warmup_steps
|
| 41 |
+
else:
|
| 42 |
+
warmup_steps = int(num_training_steps * scheduler_config.warmup_ratio)
|
| 43 |
+
|
| 44 |
+
# Determine steady phase steps
|
| 45 |
+
steady_steps = int(num_training_steps * scheduler_config.steady_ratio)
|
| 46 |
+
|
| 47 |
+
# Remaining steps for cosine decay
|
| 48 |
+
decay_steps = num_training_steps - warmup_steps - steady_steps
|
| 49 |
+
|
| 50 |
+
min_lr_ratio = scheduler_config.min_lr_ratio
|
| 51 |
+
num_cycles = scheduler_config.num_cycles
|
| 52 |
+
|
| 53 |
+
print(f"✓ CosineWarmup scheduler: warmup={warmup_steps}, steady={steady_steps}, decay={decay_steps} (total={num_training_steps})")
|
| 54 |
+
print(f" min_lr_ratio={min_lr_ratio}, num_cycles={num_cycles}")
|
| 55 |
+
|
| 56 |
+
def lr_lambda(step):
|
| 57 |
+
"""Three-phase LR schedule: warmup → steady → cosine decay."""
|
| 58 |
+
if step < warmup_steps:
|
| 59 |
+
# Phase 1: Linear warmup from 0 to 1.0
|
| 60 |
+
return float(step) / float(max(1, warmup_steps))
|
| 61 |
+
|
| 62 |
+
elif step < warmup_steps + steady_steps:
|
| 63 |
+
# Phase 2: Steady at peak LR (1.0)
|
| 64 |
+
return 1.0
|
| 65 |
+
|
| 66 |
+
else:
|
| 67 |
+
# Phase 3: Cosine decay from 1.0 to min_lr_ratio
|
| 68 |
+
decay_step = step - warmup_steps - steady_steps
|
| 69 |
+
progress = float(decay_step) / float(max(1, decay_steps))
|
| 70 |
+
|
| 71 |
+
# Cosine annealing: 0.5 * (1 + cos(π * progress))
|
| 72 |
+
cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
|
| 73 |
+
|
| 74 |
+
# Scale to reach min_lr_ratio at the end
|
| 75 |
+
return cosine_decay * (1.0 - min_lr_ratio) + min_lr_ratio
|
| 76 |
+
|
| 77 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch=scheduler_config.last_epoch)
|
code/TaoTrain/src/taoTrain/schedulers/linear_warmup.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Linear warmup learning rate scheduler."""
|
| 2 |
+
|
| 3 |
+
import torch.optim as optim
|
| 4 |
+
from torch.optim.lr_scheduler import LambdaLR
|
| 5 |
+
from taoTrain.config import TrainingConfig
|
| 6 |
+
from .registry import register_scheduler
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@register_scheduler("linearWarmup")
|
| 10 |
+
def create_linear_warmup(
|
| 11 |
+
optimizer: optim.Optimizer,
|
| 12 |
+
config: TrainingConfig,
|
| 13 |
+
num_training_steps: int,
|
| 14 |
+
) -> LambdaLR:
|
| 15 |
+
"""
|
| 16 |
+
Create a linear warmup scheduler.
|
| 17 |
+
|
| 18 |
+
Linearly increases learning rate from 0 to peak over warmup steps,
|
| 19 |
+
then keeps it constant.
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
optimizer: Optimizer instance
|
| 23 |
+
config: TrainingConfig with scheduler configuration
|
| 24 |
+
num_training_steps: Total number of training steps
|
| 25 |
+
|
| 26 |
+
Returns:
|
| 27 |
+
LambdaLR scheduler instance
|
| 28 |
+
"""
|
| 29 |
+
scheduler_config = config.scheduler
|
| 30 |
+
|
| 31 |
+
# Determine warmup steps
|
| 32 |
+
if scheduler_config.warmup_steps > 0:
|
| 33 |
+
warmup_steps = scheduler_config.warmup_steps
|
| 34 |
+
else:
|
| 35 |
+
warmup_steps = int(num_training_steps * scheduler_config.warmup_ratio)
|
| 36 |
+
|
| 37 |
+
def lr_lambda(step):
|
| 38 |
+
"""Linear warmup learning rate schedule."""
|
| 39 |
+
if step < warmup_steps:
|
| 40 |
+
return float(step) / float(max(1, warmup_steps))
|
| 41 |
+
return 1.0
|
| 42 |
+
|
| 43 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch=scheduler_config.last_epoch)
|
code/TaoTrain/src/taoTrain/schedulers/registry.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Scheduler registry and factory for instantiating learning rate schedulers."""
|
| 2 |
+
|
| 3 |
+
from typing import Dict, Callable, Optional
|
| 4 |
+
import torch.optim as optim
|
| 5 |
+
from torch.optim.lr_scheduler import LambdaLR
|
| 6 |
+
from taoTrain.config import TrainingConfig, SchedulerEnum
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# Global registry for schedulers
|
| 10 |
+
_SCHEDULER_REGISTRY: Dict[str, Callable] = {}
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def register_scheduler(name: str):
|
| 14 |
+
"""
|
| 15 |
+
Decorator to register a custom scheduler factory function.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
name: Name of the scheduler (e.g., 'linearWarmup', 'cosineWarmup', 'constant')
|
| 19 |
+
"""
|
| 20 |
+
def decorator(fn: Callable) -> Callable:
|
| 21 |
+
if name in _SCHEDULER_REGISTRY:
|
| 22 |
+
raise ValueError(f"Scheduler '{name}' is already registered")
|
| 23 |
+
_SCHEDULER_REGISTRY[name] = fn
|
| 24 |
+
return fn
|
| 25 |
+
return decorator
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def get_registered_schedulers() -> Dict[str, Callable]:
|
| 29 |
+
"""Get all registered scheduler factory functions."""
|
| 30 |
+
return _SCHEDULER_REGISTRY.copy()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_scheduler(
|
| 34 |
+
optimizer: optim.Optimizer,
|
| 35 |
+
config: TrainingConfig,
|
| 36 |
+
num_training_steps: int,
|
| 37 |
+
) -> LambdaLR:
|
| 38 |
+
"""
|
| 39 |
+
Create a learning rate scheduler instance from config.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
optimizer: Optimizer to schedule learning rate for
|
| 43 |
+
config: TrainingConfig with scheduler configuration
|
| 44 |
+
num_training_steps: Total number of training steps
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
Learning rate scheduler instance
|
| 48 |
+
|
| 49 |
+
Raises:
|
| 50 |
+
ValueError: If scheduler type is not registered
|
| 51 |
+
"""
|
| 52 |
+
# Handle both enum and string values
|
| 53 |
+
scheduler_type = config.scheduler.scheduler_type
|
| 54 |
+
if isinstance(scheduler_type, str):
|
| 55 |
+
scheduler_name = scheduler_type
|
| 56 |
+
else:
|
| 57 |
+
scheduler_name = scheduler_type.value
|
| 58 |
+
|
| 59 |
+
if scheduler_name not in _SCHEDULER_REGISTRY:
|
| 60 |
+
raise ValueError(
|
| 61 |
+
f"Unknown scheduler: {scheduler_name}. "
|
| 62 |
+
f"Available: {list(_SCHEDULER_REGISTRY.keys())}"
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
factory_fn = _SCHEDULER_REGISTRY[scheduler_name]
|
| 66 |
+
return factory_fn(optimizer, config, num_training_steps)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def register_builtin_schedulers():
|
| 70 |
+
"""Register all built-in schedulers."""
|
| 71 |
+
# Import here to trigger decorator registration (avoid circular imports)
|
| 72 |
+
from . import linear_warmup # noqa: F401
|
| 73 |
+
from . import cosine_warmup # noqa: F401
|
| 74 |
+
from . import constant # noqa: F401
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# Auto-register built-in schedulers when module is imported
|
| 78 |
+
register_builtin_schedulers()
|