# TaoTrain: Production-Grade LLM Training Framework
**TaoTrain** is a sophisticated PyTorch framework for training large language models at every scale—from experimental pretraining through supervised fine-tuning to reinforcement learning. Unlike fragmented training scripts or heavyweight frameworks, TaoTrain unifies the **entire training pipeline** in a clean, modular codebase that appeals to both ML engineers and software engineers.
## Current Taotern Work
TaoTrain now includes the Taotern comparison architectures used by the current SSM LLM work:
- `taonet`: the attention/MLA baseline.
- `taonet_ssm`: the TaoNet shell with the attention mixer replaced by the Gamma Space Model DPLR SSM.
- `taonet_hybrid`: an alternating attention/SSM TaoNet used for the current best 200M-class candidate.
The current selected deployment-oriented run is `hybrid_ssm_first_199m`, a `199,480,928` parameter model with 16 layers: SSM layers at `0,2,4,6,8,10,12,14` and attention layers at `1,3,5,7,9,11,13,15`. It uses the DPLR SSM core with split two-lane mixing, channel gates, per-channel local shift, and the faster convolution path for long-sequence training.
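The layer layout described above can be reproduced with a few lines of Python. This is only an illustrative sketch: `hybrid_layer_types` and its arguments are hypothetical names, not part of the TaoTrain API.

```python
def hybrid_layer_types(num_layers: int = 16, ssm_first: bool = True) -> list:
    """Return the mixer type for each layer index, alternating SSM and attention."""
    first, second = ("ssm", "attention") if ssm_first else ("attention", "ssm")
    return [first if i % 2 == 0 else second for i in range(num_layers)]


layout = hybrid_layer_types()
ssm_layers = [i for i, kind in enumerate(layout) if kind == "ssm"]
attention_layers = [i for i, kind in enumerate(layout) if kind == "attention"]
print(ssm_layers)        # [0, 2, 4, 6, 8, 10, 12, 14]
print(attention_layers)  # [1, 3, 5, 7, 9, 11, 13, 15]
```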
Remote run `taotern-200m-hybrid-chat-20260512` trains this model on TaoData for a 4B-token base stage and then runs SFT so the final artifact can be loaded as a chat model. The data-loading fixes added for this run are:
- Async JSONL iteration keeps polling while tokenization workers are alive instead of ending early after a temporary empty queue.
- Cached JSONL scan metadata is reused safely while recomputing chunk ranges for the active `samples_per_chunk` and `max_samples` settings.
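The first fix can be pictured with a small standard-library sketch. It is not TaoTrain's actual loader: `tokenize_worker` and `iter_tokenized` are hypothetical names, and the "tokenizer" is just `str.split`. The point it shows is that an empty queue ends iteration only once no worker is still alive.

```python
import queue
import threading
import time


def tokenize_worker(lines, out_queue):
    """Hypothetical worker: 'tokenizes' each line slowly and enqueues the result."""
    for line in lines:
        time.sleep(0.01)  # simulate slow tokenization
        out_queue.put(line.split())


def iter_tokenized(out_queue, workers, poll_interval=0.005):
    """Yield items until the queue is empty AND every worker has exited."""
    while True:
        try:
            yield out_queue.get(timeout=poll_interval)
        except queue.Empty:
            # An empty queue only means "done" once no worker can still produce.
            if not any(w.is_alive() for w in workers):
                # Drain items enqueued between the timeout and the liveness check.
                while True:
                    try:
                        yield out_queue.get_nowait()
                    except queue.Empty:
                        return


lines = [f"sample {i}" for i in range(20)]
q = queue.Queue()
workers = [threading.Thread(target=tokenize_worker, args=(lines[i::2], q)) for i in range(2)]
for w in workers:
    w.start()
results = list(iter_tokenized(q, workers))
for w in workers:
    w.join()
print(len(results))  # 20
```

Because the poll interval is shorter than the simulated tokenization time, the queue is often empty in this run while workers are still busy, which is the situation the fix addresses.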
## Why TaoTrain?
- **Complete Unified Pipeline**: Pretraining → SFT → RL in a single, consistent framework. No context switching between different codebases or architectures.
- **Production-Grade Engineering**: Type-safe Pydantic configs, comprehensive checkpointing, AimStack integration, and proper gradient handling—not research code, but a framework you can deploy.
- **Extensibility Without Modification**: Register custom models, optimizers, schedulers, and datasets via decorators. Experiment freely without forking the framework.
- **Developer Experience First**: Interactive TUI for inference, intuitive YAML configurations, async data loading that eliminates I/O bottlenecks, and clear abstractions that make the codebase a pleasure to work with.
## Key Capabilities
| Capability | Details |
|---|---|
| **Multi-Stage Training** | Unified infrastructure for pretraining, SFT, and RL. Share model checkpoints, logging, and evaluation across stages. |
| **Advanced Optimization** | Hybrid Muon + AdamW optimizer: efficient 2D weight updates via SVD-based methods + adaptive learning for 1D parameters. |
| **Modern Architectures** | DeepSeek MLA with grouped query attention (GQA), YaRN context extension, and factorized embeddings—all configurable via YAML. |
| **Production Features** | BF16 mixed precision training, gradient accumulation, proper gradient clipping, checkpoint resumption, and validation loops. |
| **Async Data Pipeline** | Background tokenization with multi-threaded workers. Stream billion-token datasets from JSONL without loading into memory. |
| **Interactive Inference** | TUI chat interface with real-time generation speed metrics and multi-model comparison. |
| **Logging & Monitoring** | AimStack integration tracks loss, metrics, hyperparameters, and git hashes for reproducibility. Visualize training runs in your browser. |
## Getting Started
### Installation
```bash
git clone https://github.com/lobakkang/taoTrain.git
cd taoTrain
pip install -e .
```
### Training Examples
**Pretraining on a custom dataset:**
```bash
train pretrain --config configs/pretrain.yaml
```
Train a model from scratch, learning representations from raw text via next-token prediction.
**Supervised Fine-tuning:**
```bash
train sft --config configs/sft.yaml
```
Fine-tune a pretrained model on instruction-response pairs for improved task performance.
**Reinforcement Learning (DPO):**
```bash
train rl --config configs/rl_dpo.yaml
```
Align models with human preferences using Direct Preference Optimization.
**Interactive Chat:**
```bash
tui-chat --model checkpoints/model.pt
```
Launch an interactive TUI to chat with your model and monitor generation metrics in real time.
### Configuration
All training is configured via YAML with Pydantic validation. Configs are type-safe and automatically validated:
```yaml
# configs/sft.yaml
model:
  architecture_type: "mla"      # DeepSeek MLA with GQA
  hidden_dim: 2048
  num_layers: 24
  num_heads: 32
  d_latent_kv: 1536             # KV compression factor

training:
  num_epochs: 3
  batch_size: 32
  learning_rate: 1e-4
  warmup_ratio: 0.1
  max_grad_norm: 1.0

optimizer:
  optimizer_type: "muon_adamw"  # Hybrid Muon + AdamW
  muon_momentum: 0.95

data:
  dataset_type: "sft_jsonl"     # or "sft_hf" for HuggingFace
  path: "data/sft_training.jsonl"

logging:
  log_to_aim: true
  aim_repo: "/tmp/aim_logs"
```
See `configs/` for complete examples.
## Project Architecture
```
src/taoTrain/
├── cli.py                      # Main CLI entry point
├── config.py                   # Pydantic configuration schemas
├── core/                       # Base abstractions
│   └── base.py                 # BaseModel, BaseDataset, BaseTrainer
├── models/                     # Pluggable architecture system
│   ├── registry.py             # Architecture factory with @register_architecture
│   ├── taonet.py               # SimpleLLM with DeepSeek MLA
│   ├── mla_components.py       # KV compression, GQA, YaRN
│   ├── embeddings.py           # Factorized embeddings
│   └── transformer.py          # Standard Transformer reference
├── data/                       # Advanced data pipeline
│   ├── factory.py              # Dataset factory (HF + JSONL backends)
│   ├── async_loader.py         # Async batch iteration (no I/O bottleneck)
│   ├── tokenization_queue.py   # Background multi-threaded tokenization
│   ├── chunk_manager.py        # Stream billion-token JSONL files
│   ├── hf_pretrain.py          # HuggingFace pretraining datasets
│   ├── hf_sft.py               # HuggingFace SFT datasets
│   ├── hf_rl.py                # HuggingFace RL datasets
│   ├── pretrain_jsonl.py       # JSONL pretraining
│   ├── sft_jsonl.py            # JSONL SFT with instructions
│   └── rl_jsonl.py             # JSONL RL with preferences
├── training/                   # Unified training infrastructure
│   └── trainer.py              # Trainer + PretrainTrainer, SFTTrainer, RLTrainer
├── optimizers/                 # Pluggable optimizer system
│   ├── registry.py             # Optimizer factory with @register_optimizer
│   ├── hybrid_muon_adamw.py    # Composite: Muon (2D) + AdamW (1D)
│   ├── adamw.py                # AdamW with weight decay
│   ├── adam.py                 # Standard Adam
│   └── sgd.py                  # SGD variants
├── schedulers/                 # Learning rate schedules
│   ├── registry.py             # LR scheduler factory
│   ├── cosine_warmup.py        # 3-phase: linear warmup → plateau → cosine decay
│   ├── linear_warmup.py        # Linear warmup + constant
│   └── constant.py             # Constant learning rate
├── inference/                  # Inference & interaction
│   ├── inferencer.py           # Load & run inference from checkpoints
│   └── tui.py                  # Interactive chat with metrics display
├── checkpointing/              # State management
│   └── checkpoint.py           # Save/load model + optimizer + config + metrics
├── logging/                    # Experiment tracking
│   └── aim_logger.py           # AimStack integration (loss, metrics, hyperparams)
├── benchmarks/                 # Evaluation tools
│   └── runner.py               # Perplexity, speed, and task-specific benchmarks
└── utils/
    └── helpers.py              # Utility functions

configs/                        # Example YAML configurations
├── pretrain.yaml               # Pretraining config
├── sft.yaml                    # SFT config
├── rl_dpo.yaml                 # RL/DPO config
└── tokenizer.yaml              # Tokenizer config

tests/                          # Unit & integration tests
└── test_dataset.py
```
## Extensible Architecture: The Registry Pattern
TaoTrain's power lies in its **pluggable design**. Add custom models, optimizers, schedulers, and datasets without modifying the framework.
### Custom Model Architecture
```python
from taoTrain.models import register_architecture, BaseModel
import torch.nn as nn


@register_architecture("custom_moe")
class MixtureOfExperts(BaseModel):
    """Your custom MoE architecture"""

    def __init__(self, config):
        super().__init__(config)
        self.experts = nn.ModuleList([
            nn.Linear(config.hidden_dim, config.hidden_dim)
            for _ in range(config.num_experts)
        ])
        self.router = nn.Linear(config.hidden_dim, config.num_experts)

    def forward(self, input_ids, attention_mask=None, labels=None):
        # Your implementation: compute_logits and compute_loss stand in for
        # your own embedding, routing, and loss code.
        logits = self.compute_logits(input_ids)
        loss = self.compute_loss(logits, labels) if labels is not None else None
        return {"logits": logits, "loss": loss}
```
Then use it in your config:
```yaml
model:
  architecture_type: "custom_moe"
  hidden_dim: 2048
  num_experts: 8
```
### Custom Optimizers & Schedulers
The same pattern works for optimizers and learning rate schedules:
```python
from taoTrain.optimizers import register_optimizer
from torch.optim import Optimizer


@register_optimizer("my_adaptive_optimizer")
class MyAdaptiveOptimizer(Optimizer):
    def step(self, closure=None):
        # Your optimization logic
        pass
```
```python
from taoTrain.schedulers import register_scheduler


@register_scheduler("my_schedule")
def my_schedule(initial_lr, step, total_steps, **kwargs):
    return initial_lr * (1.0 - step / total_steps)  # Linear decay
```
**The key principle**: No framework code needs to change. You register once, it's available everywhere.
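For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of how a decorator-based registry can work. It does not reproduce TaoTrain's `registry.py`: the `_SCHEDULERS` dictionary and `build_scheduler` helper are hypothetical names, and `register_scheduler` is redefined locally so the example runs on its own.

```python
from typing import Callable, Dict

_SCHEDULERS: Dict[str, Callable[..., float]] = {}  # hypothetical registry storage


def register_scheduler(name: str):
    """Decorator that stores the wrapped callable under `name`."""
    def decorator(fn):
        if name in _SCHEDULERS:
            raise ValueError(f"scheduler {name!r} is already registered")
        _SCHEDULERS[name] = fn
        return fn
    return decorator


def build_scheduler(name: str):
    """Look up a registered scheduler, reporting the known names on failure."""
    try:
        return _SCHEDULERS[name]
    except KeyError:
        raise KeyError(f"unknown scheduler {name!r}; known: {sorted(_SCHEDULERS)}") from None


@register_scheduler("my_schedule")
def my_schedule(initial_lr, step, total_steps, **kwargs):
    return initial_lr * (1.0 - step / total_steps)  # linear decay


# A config value such as scheduler_type: "my_schedule" would select this entry.
schedule = build_scheduler("my_schedule")
print(schedule(1e-3, step=500, total_steps=1000))  # 0.0005
```

The lookup is by string key only, which is why a new component becomes selectable from YAML as soon as the module defining it has been imported.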
### Dataset Backend Flexibility
Define custom datasets (JSONL, HF, streaming, etc.) and let the factory route to them:
```python
from taoTrain.core.base import BaseDataset  # location per the project layout above
from taoTrain.data import register_dataset


@register_dataset("pretrain", "my_backend")
class MyPretrainDataset(BaseDataset):
    def __init__(self, config):
        # Load from your custom backend
        pass

    def __getitem__(self, idx):
        return {"input_ids": ..., "attention_mask": ...}
```
Use in config:
```yaml
data:
  dataset_type: "pretrain"
  backend_type: "my_backend"  # Routes to MyPretrainDataset
```
## Framework Design Highlights
### Async Data Loading: No I/O Bottleneck
Most training frameworks load and tokenize data on the main training thread, blocking compute. TaoTrain's **multi-threaded tokenization pipeline**:
- Tokenizes data in background workers while your GPU trains
- Supports streaming billion-token JSONL files without loading into memory
- Chunks input by file size or sample count
- Caches scan metadata to avoid rescanning
**Result**: 10-100x faster data iteration on large datasets compared with tokenizing on the main training thread.
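Chunking by sample count can be sketched as follows. This is an assumption-laden illustration rather than the code in `chunk_manager.py`: the function name `chunk_ranges` is hypothetical, while `samples_per_chunk` and `max_samples` are the setting names mentioned earlier in this README.

```python
def chunk_ranges(total_samples, samples_per_chunk, max_samples=None):
    """Return [start, end) sample ranges covering at most `max_samples` samples."""
    if samples_per_chunk <= 0:
        raise ValueError("samples_per_chunk must be positive")
    limit = total_samples if max_samples is None else min(total_samples, max_samples)
    return [(start, min(start + samples_per_chunk, limit))
            for start in range(0, limit, samples_per_chunk)]


# A cached scan only has to remember the sample count; the ranges are cheap
# to recompute whenever the active settings change.
print(chunk_ranges(10, samples_per_chunk=4))                 # [(0, 4), (4, 8), (8, 10)]
print(chunk_ranges(10, samples_per_chunk=4, max_samples=6))  # [(0, 4), (4, 6)]
```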
### Type-Safe Configuration
Forget YAML parsing errors or mysterious config bugs. TaoTrain uses **Pydantic dataclasses** for configuration:
- Automatic type validation: a non-numeric value such as `learning_rate: "fast"` raises a validation error at load time instead of failing silently later
- Serialization: configs are part of checkpoints, ensuring reproducibility
- IDE support: autocomplete and type hints for all config fields
- Defaults: sensible defaults for all parameters
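A minimal sketch of the fail-fast behaviour, assuming Pydantic is installed. `TrainingConfig` here is a hypothetical subset of fields named after the YAML keys in this README, not TaoTrain's actual schema.

```python
from pydantic import BaseModel, ValidationError


class TrainingConfig(BaseModel):
    # Hypothetical subset of fields; defaults mirror the example YAML.
    num_epochs: int = 3
    batch_size: int = 32
    learning_rate: float = 1e-4
    max_grad_norm: float = 1.0


defaults = TrainingConfig()            # every field falls back to its default
print(defaults.batch_size)             # 32

try:
    TrainingConfig(learning_rate="fast")   # not a number: rejected immediately
    rejected = False
except ValidationError as err:
    rejected = True
    print("rejected:", err.errors()[0]["loc"])
```

In a loader, the dictionary parsed from YAML would be passed as keyword arguments in the same way, so a bad value stops the run before any training starts.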
### Benchmarking & Metrics
Track what matters:
- **Perplexity**: Language modeling quality on held-out data
- **Generation Speed**: Tokens-per-second (useful for TUI or deployment)
- **Task-Specific Accuracy**: Evaluate on downstream tasks
- **Training Metrics**: Loss curves, gradient norms, effective batch size
All logged to AimStack with git hashes for reproducibility.
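The first two metrics have simple definitions, sketched below with the standard library only. The helper names are hypothetical and are not taken from `benchmarks/runner.py`; perplexity is computed as the exponential of the mean per-token negative log-likelihood in nats.

```python
import math


def perplexity(token_nlls):
    """Perplexity = exp(mean negative log-likelihood per token, in nats)."""
    if not token_nlls:
        raise ValueError("need at least one token loss")
    return math.exp(sum(token_nlls) / len(token_nlls))


def tokens_per_second(num_tokens, elapsed_seconds):
    """Generation throughput over a measured wall-clock interval."""
    return num_tokens / elapsed_seconds


# A model that assigns probability 1/4 to every target token has perplexity 4.
print(round(perplexity([math.log(4)] * 8), 6))  # 4.0
print(tokens_per_second(512, 2.0))              # 256.0
```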
## Logging with AimStack
Automatically track and visualize experiments:
```bash
aim up --host 0.0.0.0
```
Then open `http://localhost:43800` to see:
- **Loss curves** per training step
- **Hyperparameters** (learning rate, batch size, model architecture)
- **Git hashes** for reproducibility
- **Custom metrics** (perplexity, validation accuracy, generation speed)
- **Compare runs**: Side-by-side experiment comparison
## Advanced Features
### Checkpointing with Resumption
TaoTrain saves complete training state:
```python
checkpoint = {
    "step": 12500,
    "model_state": model.state_dict(),
    "optimizer_state": optimizer.state_dict(),
    "config": config,  # Full config as Pydantic object
    "metrics": metrics_tracker.to_dict(),
}
```
Training can resume from any checkpoint without loss of state, and only the last N checkpoints are kept automatically.
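Keeping only the most recent N checkpoints could look like the sketch below. The `step_<n>.pt` file naming and the `prune_checkpoints` helper are assumptions made for illustration; TaoTrain's own naming and retention logic may differ.

```python
import tempfile
from pathlib import Path


def prune_checkpoints(ckpt_dir, keep_last):
    """Delete all but the `keep_last` highest-step checkpoints; return removed names."""
    def step_of(path):
        return int(path.stem.split("_")[-1])  # "step_12500" -> 12500

    ckpts = sorted(Path(ckpt_dir).glob("step_*.pt"), key=step_of)
    stale = ckpts[:-keep_last] if keep_last > 0 else ckpts
    for path in stale:
        path.unlink()
    return [p.name for p in stale]


with tempfile.TemporaryDirectory() as tmp:
    for step in (500, 1000, 1500, 12500):
        (Path(tmp) / f"step_{step}.pt").touch()
    removed = prune_checkpoints(tmp, keep_last=2)
    remaining = sorted(p.name for p in Path(tmp).iterdir())

print(removed)    # ['step_500.pt', 'step_1000.pt']
print(remaining)  # ['step_12500.pt', 'step_1500.pt']
```

Sorting by the parsed step number rather than by file name matters here: a plain string sort would place `step_12500.pt` before `step_1500.pt`.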
### Mixed Precision Training (BF16)
```yaml
training:
  use_bfloat16: true
  gradient_accumulation_steps: 4
```
- BF16 via `torch.autocast` for ~2x speedup with minimal accuracy loss
- Proper gradient scaling and clipping
- Compatible with all optimizers and architectures
### 3-Phase Learning Rate Schedule
```yaml
scheduler:
  scheduler_type: "cosine_warmup"
  warmup_ratio: 0.1   # 10% of training steps
  steady_ratio: 0.5   # 50% at steady rate
  min_lr_ratio: 0.1   # Final LR = 0.1 × initial_lr
  num_cycles: 1
```
This schedule:
1. **Linear warmup** (0 → 1) over 10% of steps
2. **Steady plateau** at full LR over 50% of steps
3. **Cosine decay** (1 → 0.1) over remaining 40% of steps
Holding the peak rate through a plateau before decaying is intended to give better convergence than a plain cosine or linear decay.
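The three phases can be written as a single multiplier on the base learning rate. The sketch below is an independent reconstruction from the description above, not the contents of `cosine_warmup.py`; it assumes a single cosine cycle (`num_cycles: 1`), and the function name `lr_multiplier` is hypothetical.

```python
import math


def lr_multiplier(step, total_steps, warmup_ratio=0.1, steady_ratio=0.5, min_lr_ratio=0.1):
    """Multiplier on the base LR: linear warmup, plateau, then cosine decay."""
    warmup_steps = int(total_steps * warmup_ratio)
    steady_end = warmup_steps + int(total_steps * steady_ratio)
    if step < warmup_steps:
        return step / max(1, warmup_steps)           # phase 1: 0 -> 1
    if step < steady_end:
        return 1.0                                   # phase 2: plateau
    progress = (step - steady_end) / max(1, total_steps - steady_end)
    cosine = 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))
    return min_lr_ratio + (1.0 - min_lr_ratio) * cosine  # phase 3: 1 -> min_lr_ratio


for step in (0, 50, 100, 600, 800, 1000):
    print(step, round(lr_multiplier(step, 1000), 4))
# 0 0.0
# 50 0.5
# 100 1.0
# 600 1.0
# 800 0.55
# 1000 0.1
```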
### Gradient Accumulation & Clipping
Simulate larger batch sizes with gradient accumulation:
```yaml
training:
  batch_size: 32
  gradient_accumulation_steps: 4  # Effective batch = 128
  max_grad_norm: 1.0              # Gradient clipping
```
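Why accumulation matches a larger batch can be checked on a toy one-parameter problem. The sketch below uses plain Python instead of PyTorch, and all names in it are hypothetical. It assumes equally sized micro-batches and a mean-reduced loss, in which case averaging the micro-batch gradients reproduces the full-batch gradient, and clipping is applied once to the accumulated result.

```python
def grad_mse(w, batch):
    """d/dw of mean((w*x - y)^2) over a batch of (x, y) pairs."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def clip(grad, max_norm):
    """Scale the gradient down if its magnitude exceeds max_norm."""
    norm = abs(grad)
    return grad * max_norm / norm if norm > max_norm else grad


data = [(float(x), 3.0 * x) for x in range(1, 9)]  # 8 samples of y = 3x
w, accumulation_steps, micro_batch = 0.0, 4, 2

accumulated = 0.0
for i in range(accumulation_steps):
    micro = data[i * micro_batch:(i + 1) * micro_batch]
    accumulated += grad_mse(w, micro) / accumulation_steps  # average, do not sum

full = grad_mse(w, data)
print(accumulated, full)       # -153.0 -153.0
print(clip(accumulated, 1.0))  # -1.0
```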
## Contributing
Contributions are welcome! TaoTrain is designed to make contributions easy:
1. **Add a model**: Subclass `BaseModel` and decorate it with `@register_architecture("name")`
2. **Add an optimizer**: Subclass `torch.optim.Optimizer` and decorate it with `@register_optimizer("name")`
3. **Add a dataset**: Subclass `BaseDataset` and decorate it with `@register_dataset(mode, backend_type)`
4. **Improve the core**: Submit PRs to `training/`, `data/`, `logging/`, etc.
Ensure new code includes:
- Type hints throughout
- Pydantic configs for new parameters
- Unit tests in `tests/`
- Documentation in docstrings and README
## Current Scope & Roadmap
### ✅ Currently Supported
- **Single GPU / single node** training
- **Pretraining, SFT, and RL training** stages
- **HuggingFace and JSONL** data backends
- **BF16 mixed precision** training
- **Checkpoint saving/loading** with resumption
- **Interactive inference** via TUI
- **Benchmarking** (perplexity, speed)
- **Pluggable architectures, optimizers, schedulers, datasets**
### 🚀 Roadmap (Future)
- **Distributed training** (DDP, FSDP) for multi-GPU/multi-node scaling
- **Quantization** support (INT8, QLoRA)
- **Advanced evaluation** (BLEU, ROUGE, custom tasks)
- **Streaming inference** with KV cache
- **Speculative decoding** for faster generation
- **Integration with popular model hubs** (Hugging Face Hub upload/download)
---
## Getting Help
- **Questions?** Open an issue on GitHub
- **Want to contribute?** See `CONTRIBUTING.md` (coming soon)
- **Found a bug?** Report it with a minimal reproduction script
## License
MIT