TaoTrain: Production-Grade LLM Training Framework

TaoTrain is a sophisticated PyTorch framework for training large language models at every scale, from experimental pretraining through supervised fine-tuning to reinforcement learning. Unlike fragmented training scripts or heavyweight frameworks, TaoTrain unifies the entire training pipeline in a clean, modular codebase that appeals to both ML engineers and software engineers.

Current Taotern Work

TaoTrain now includes the Taotern comparison architectures used by the current SSM LLM work:

  • taonet: the attention/MLA baseline.
  • taonet_ssm: the TaoNet shell with the attention mixer replaced by the Gamma Space Model DPLR SSM.
  • taonet_hybrid: an alternating attention/SSM TaoNet used for the current best 200M-class candidate.

The current selected deployment-oriented run is hybrid_ssm_first_199m, a 199,480,928 parameter model with 16 layers: SSM layers at 0,2,4,6,8,10,12,14 and attention layers at 1,3,5,7,9,11,13,15. It uses the DPLR SSM core with split two-lane mixing, channel gates, per-channel local shift, and the faster convolution path for long-sequence training.
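
The alternating layout described above can be sketched in a few lines. This is only an illustration of the SSM-first pattern; `hybrid_layer_types` and its `ssm_first` flag are hypothetical names, not part of TaoTrain's API.

```python
def hybrid_layer_types(num_layers: int, ssm_first: bool = True) -> list:
    """Alternate SSM and attention mixers, one mixer type per layer."""
    first, second = ("ssm", "attention") if ssm_first else ("attention", "ssm")
    return [first if i % 2 == 0 else second for i in range(num_layers)]


layout = hybrid_layer_types(16)
ssm_layers = [i for i, kind in enumerate(layout) if kind == "ssm"]
attention_layers = [i for i, kind in enumerate(layout) if kind == "attention"]
print("SSM layers:      ", ssm_layers)
print("Attention layers:", attention_layers)
```

With 16 layers and `ssm_first=True`, the even indices are SSM layers and the odd indices are attention layers, matching the run described above.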

Remote run taotern-200m-hybrid-chat-20260512 trains this model on TaoData for a 4B-token base stage and then runs SFT so the final artifact can be loaded as a chat model. The data-pipeline fixes added for this run are:

  • Async JSONL iteration keeps polling while tokenization workers are alive instead of ending early after a temporary empty queue.
  • Cached JSONL scan metadata is reused safely while recomputing chunk ranges for the active samples_per_chunk and max_samples settings.
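
The first fix can be illustrated with a small standard-library sketch. It is not TaoTrain's actual loader: `iter_tokenized` and `slow_tokenizer` are hypothetical names, and the real pipeline may differ. The point is that an empty queue ends iteration only once no worker can still produce items.

```python
import queue
import threading
import time


def iter_tokenized(out_queue, workers, poll_interval=0.01):
    """Yield items until the queue is empty AND every worker has exited."""
    while True:
        try:
            yield out_queue.get(timeout=poll_interval)
        except queue.Empty:
            # A temporarily empty queue is not the end while a worker is alive.
            if any(w.is_alive() for w in workers):
                continue
            # Drain anything enqueued between the timeout and the liveness check.
            while True:
                try:
                    yield out_queue.get_nowait()
                except queue.Empty:
                    return


def slow_tokenizer(lines, out_queue):
    for line in lines:
        time.sleep(0.05)  # slower than the poll interval, so the queue runs dry
        out_queue.put(line.split())


q = queue.Queue()
worker = threading.Thread(target=slow_tokenizer, args=(["a b", "c d e", "f"], q))
worker.start()
batches = list(iter_tokenized(q, [worker]))
worker.join()
print(batches)
```

A loop that stopped at the first `queue.Empty` would return nothing here, because the first item arrives only after several empty polls.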

Why TaoTrain?

  • Complete Unified Pipeline: Pretraining → SFT → RL in a single, consistent framework. No context switching between different codebases or architectures.
  • Production-Grade Engineering: Type-safe Pydantic configs, comprehensive checkpointing, AimStack integration, and proper gradient handling. Not research code, but a framework you can deploy.
  • Extensibility Without Modification: Register custom models, optimizers, schedulers, and datasets via decorators. Experiment freely without forking the framework.
  • Developer Experience First: Interactive TUI for inference, intuitive YAML configurations, async data loading that eliminates I/O bottlenecks, and clear abstractions that make the codebase a pleasure to work with.

Key Capabilities

| Capability | Details |
| --- | --- |
| Multi-Stage Training | Unified infrastructure for pretraining, SFT, and RL. Share model checkpoints, logging, and evaluation across stages. |
| Advanced Optimization | Hybrid Muon + AdamW optimizer: efficient 2D weight updates via SVD-based methods + adaptive learning for 1D parameters. |
| Modern Architectures | DeepSeek MLA with grouped query attention (GQA), YaRN context extension, and factorized embeddings, all configurable via YAML. |
| Production Features | BF16 mixed precision training, gradient accumulation, proper gradient clipping, checkpoint resumption, and validation loops. |
| Async Data Pipeline | Background tokenization with multi-threaded workers. Stream billion-token datasets from JSONL without loading into memory. |
| Interactive Inference | TUI chat interface with real-time generation speed metrics and multi-model comparison. |
| Logging & Monitoring | AimStack integration tracks loss, metrics, hyperparameters, and git hashes for reproducibility. Visualize training runs in your browser. |

Getting Started

Installation

git clone https://github.com/lobakkang/taoTrain.git
cd taoTrain
pip install -e .

Training Examples

Pretraining on a custom dataset:

train pretrain --config configs/pretrain.yaml

Starts from scratch, learns representations from raw text via next-token prediction.

Supervised Fine-tuning:

train sft --config configs/sft.yaml

Fine-tune a pretrained model on instruction-response pairs for improved task performance.

Reinforcement Learning (DPO):

train rl --config configs/rl_dpo.yaml

Align models with human preferences using Direct Preference Optimization.
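
For readers new to DPO, the per-pair objective can be written in a few lines of plain Python. This is a sketch of the standard DPO loss, not TaoTrain's RLTrainer code; the function name and the `beta=0.1` default are assumptions made for illustration. Inputs are summed log-probabilities of the chosen and rejected responses under the policy and under a frozen reference model.

```python
import math


def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """DPO loss for one preference pair: -log(sigmoid(beta * margin))."""
    chosen_reward = policy_chosen_logp - ref_chosen_logp
    rejected_reward = policy_rejected_logp - ref_rejected_logp
    margin = beta * (chosen_reward - rejected_reward)
    # -log(sigmoid(m)) equals softplus(-m); this form avoids overflow in exp().
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))


# Policy identical to the reference: no preference learned yet.
baseline = dpo_loss(-10.0, -12.0, -10.0, -12.0)
# Policy raises the chosen response and lowers the rejected one.
improved = dpo_loss(-8.0, -14.0, -10.0, -12.0)
print(baseline, improved)
```

When the policy matches the reference the loss is log 2, and it falls as the policy separates chosen from rejected responses more strongly than the reference does.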

Interactive Chat:

tui-chat --model checkpoints/model.pt

Launch an interactive TUI to chat with your model and monitor generation metrics in real-time.

Configuration

All training is configured via YAML with Pydantic validation. Configs are type-safe and automatically validated:

# configs/sft.yaml
model:
  architecture_type: "mla"  # DeepSeek MLA with GQA
  hidden_dim: 2048
  num_layers: 24
  num_heads: 32
  d_latent_kv: 1536  # Latent dimension of the compressed KV

training:
  num_epochs: 3
  batch_size: 32
  learning_rate: 1e-4
  warmup_ratio: 0.1
  max_grad_norm: 1.0

optimizer:
  optimizer_type: "muon_adamw"  # Hybrid Muon + AdamW
  muon_momentum: 0.95

data:
  dataset_type: "sft_jsonl"  # or "sft_hf" for HuggingFace
  path: "data/sft_training.jsonl"
  
logging:
  log_to_aim: true
  aim_repo: "/tmp/aim_logs"

See configs/ for complete examples.

Project Architecture

src/taoTrain/
├── cli.py                      # Main CLI entry point
├── config.py                   # Pydantic configuration schemas
│
├── core/                       # Base abstractions
│   └── base.py                 # BaseModel, BaseDataset, BaseTrainer
│
├── models/                     # Pluggable architecture system
│   ├── registry.py             # Architecture factory with @register_architecture
│   ├── taonet.py               # SimpleLLM with DeepSeek MLA
│   ├── mla_components.py       # KV compression, GQA, YaRN
│   ├── embeddings.py           # Factorized embeddings
│   └── transformer.py          # Standard Transformer reference
│
├── data/                       # Advanced data pipeline
│   ├── factory.py              # Dataset factory (HF + JSONL backends)
│   ├── async_loader.py         # Async batch iteration (no I/O bottleneck)
│   ├── tokenization_queue.py   # Background multi-threaded tokenization
│   ├── chunk_manager.py        # Stream billion-token JSONL files
│   ├── hf_pretrain.py          # HuggingFace pretraining datasets
│   ├── hf_sft.py               # HuggingFace SFT datasets
│   ├── hf_rl.py                # HuggingFace RL datasets
│   ├── pretrain_jsonl.py       # JSONL pretraining
│   ├── sft_jsonl.py            # JSONL SFT with instructions
│   └── rl_jsonl.py             # JSONL RL with preferences
│
├── training/                   # Unified training infrastructure
│   └── trainer.py              # Trainer + PretrainTrainer, SFTTrainer, RLTrainer
│
├── optimizers/                 # Pluggable optimizer system
│   ├── registry.py             # Optimizer factory with @register_optimizer
│   ├── hybrid_muon_adamw.py    # Composite: Muon (2D) + AdamW (1D)
│   ├── adamw.py                # AdamW with weight decay
│   ├── adam.py                 # Standard Adam
│   └── sgd.py                  # SGD variants
│
├── schedulers/                 # Learning rate schedules
│   ├── registry.py             # LR scheduler factory
│   ├── cosine_warmup.py        # 3-phase: linear warmup → plateau → cosine decay
│   ├── linear_warmup.py        # Linear warmup + constant
│   └── constant.py             # Constant learning rate
│
├── inference/                  # Inference & interaction
│   ├── inferencer.py           # Load & run inference from checkpoints
│   └── tui.py                  # Interactive chat with metrics display
│
├── checkpointing/              # State management
│   └── checkpoint.py           # Save/load model + optimizer + config + metrics
│
├── logging/                    # Experiment tracking
│   └── aim_logger.py           # AimStack integration (loss, metrics, hyperparams)
│
├── benchmarks/                 # Evaluation tools
│   └── runner.py               # Perplexity, speed, and task-specific benchmarks
│
└── utils/
    └── helpers.py              # Utility functions

configs/                        # Example YAML configurations
├── pretrain.yaml               # Pretraining config
├── sft.yaml                    # SFT config
├── rl_dpo.yaml                 # RL/DPO config
└── tokenizer.yaml              # Tokenizer config

tests/                          # Unit & integration tests
└── test_dataset.py

Extensible Architecture: The Registry Pattern

TaoTrain's power lies in its pluggable design. Add custom models, optimizers, schedulers, and datasets without modifying the framework.

Custom Model Architecture

from taoTrain.models import register_architecture, BaseModel
import torch.nn as nn

@register_architecture("custom_moe")
class MixtureOfExperts(BaseModel):
    """Your custom MoE architecture"""
    def __init__(self, config):
        super().__init__(config)
        self.experts = nn.ModuleList([
            nn.Linear(config.hidden_dim, config.hidden_dim)
            for _ in range(config.num_experts)
        ])
        self.router = nn.Linear(config.hidden_dim, config.num_experts)
    
    def forward(self, input_ids, attention_mask=None, labels=None):
        # Your implementation: compute_logits and compute_loss are placeholders
        # for methods you define on this class.
        logits = self.compute_logits(input_ids)
        loss = self.compute_loss(logits, labels) if labels is not None else None
        return {"logits": logits, "loss": loss}

Then use it in your config:

model:
  architecture_type: "custom_moe"
  hidden_dim: 2048
  num_experts: 8

Custom Optimizers & Schedulers

The same pattern works for optimizers and learning rate schedules:

from taoTrain.optimizers import register_optimizer
from torch.optim import Optimizer

@register_optimizer("my_adaptive_optimizer")
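class MyAdaptiveOptimizer(Optimizer):
    def __init__(self, params, lr=1e-3):
        # torch.optim.Optimizer requires the parameters and a dict of defaults
        super().__init__(params, defaults={"lr": lr})

    def step(self, closure=None):
        # Your optimization logic
        pass

from taoTrain.schedulers import register_scheduler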

@register_scheduler("my_schedule")
def my_schedule(initial_lr, step, total_steps, **kwargs):
    return initial_lr * (1.0 - step / total_steps)  # Linear decay

The key principle: No framework code needs to change. You register once, it's available everywhere.

Dataset Backend Flexibility

Define custom datasets (JSONL, HF, streaming, etc.) and let the factory route to them:

from taoTrain.core.base import BaseDataset  # base class listed under core/base.py
from taoTrain.data import register_dataset

@register_dataset("pretrain", "my_backend")
class MyPretrainDataset(BaseDataset):
    def __init__(self, config):
        # Load from your custom backend
        pass
    
    def __getitem__(self, idx):
        return {"input_ids": ..., "attention_mask": ...}

Use in config:

data:
  dataset_type: "pretrain"
  backend_type: "my_backend"  # Routes to MyPretrainDataset

Why the TaoTrain Framework?

Async Data Loading: No I/O Bottleneck

Simple training scripts often load and tokenize data on the main training thread, which stalls compute. TaoTrain's multi-threaded tokenization pipeline:

  • Tokenizes data in background workers while your GPU trains
  • Supports streaming billion-token JSONL files without loading into memory
  • Intelligent chunking (by file size or sample count)
  • Metadata caching to avoid rescanning

Result: 10-100x faster data iteration on large datasets, relative to tokenizing on the training thread.
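
Chunking by sample count combines naturally with metadata caching: the expensive file scan is cached once, while chunk boundaries are cheap to recompute whenever the settings change. The sketch below assumes the cache stores only a total sample count; `chunk_ranges` and the `cached_scan` dictionary are hypothetical, while `samples_per_chunk` and `max_samples` are the setting names used earlier in this README.

```python
def chunk_ranges(total_samples, samples_per_chunk, max_samples=None):
    """Return half-open (start, end) sample ranges covering the active samples."""
    if samples_per_chunk <= 0:
        raise ValueError("samples_per_chunk must be positive")
    active = total_samples if max_samples is None else min(total_samples, max_samples)
    return [
        (start, min(start + samples_per_chunk, active))
        for start in range(0, active, samples_per_chunk)
    ]


# Pretend this came from a cached scan of the JSONL file (assumed format).
cached_scan = {"total_samples": 10}

ranges_a = chunk_ranges(cached_scan["total_samples"], samples_per_chunk=4)
ranges_b = chunk_ranges(cached_scan["total_samples"], samples_per_chunk=3, max_samples=7)
print(ranges_a)
print(ranges_b)
```

Both calls reuse the same cached count, yet each produces ranges that match its own settings, so stale chunk boundaries never leak from one run into the next.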

Type-Safe Configuration

Forget YAML parsing errors or mysterious config bugs. TaoTrain uses Pydantic dataclasses for configuration:

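  • Automatic type validation: an invalid value such as learning_rate: "fast" raises a validation error instead of failing silently (numeric strings such as "1e-4" are still coerced to floats in Pydantic's default mode)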
  • Serialization: configs are part of checkpoints, ensuring reproducibility
  • IDE support: autocomplete and type hints for all config fields
  • Defaults: sensible defaults for all parameters
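
A short sketch of what this looks like in practice, assuming Pydantic is installed. `TrainingConfig` and its constraints are hypothetical and only borrow field names from the YAML example above; the real schemas in config.py may differ.

```python
from pydantic import BaseModel, Field, ValidationError


class TrainingConfig(BaseModel):
    # Hypothetical subset of fields, each with a default and a range constraint.
    learning_rate: float = Field(default=1e-4, gt=0)
    batch_size: int = Field(default=32, gt=0)
    warmup_ratio: float = Field(default=0.1, ge=0, le=1)


# Valid values pass; omitted fields fall back to their defaults.
ok = TrainingConfig(learning_rate=3e-4, batch_size=16)

# A value that cannot be interpreted as a float is rejected at load time.
try:
    TrainingConfig(learning_rate="fast")
    error = None
except ValidationError as exc:
    error = exc

print(ok)
print("rejected:", error is not None)
```

The failure surfaces when the config is constructed, before any model or data is loaded.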

Benchmarking & Metrics

Track what matters:

  • Perplexity: Language modeling quality on held-out data
  • Generation Speed: Tokens-per-second (useful for TUI or deployment)
  • Task-Specific Accuracy: Evaluate on downstream tasks
  • Training Metrics: Loss curves, gradient norms, effective batch size

All logged to AimStack with git hashes for reproducibility.
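
The first two metrics have simple definitions, sketched here in plain Python. The function names are illustrative and are not the API of benchmarks/runner.py. Perplexity is the exponential of the mean per-token negative log-likelihood (natural log), and generation speed is tokens divided by wall-clock seconds.

```python
import math


def perplexity(token_nlls):
    """exp(mean negative log-likelihood per token)."""
    if not token_nlls:
        raise ValueError("need at least one token")
    return math.exp(sum(token_nlls) / len(token_nlls))


def tokens_per_second(num_tokens, elapsed_seconds):
    """Generation throughput over a measured wall-clock interval."""
    return num_tokens / elapsed_seconds


# A model that assigns probability 1/4 to every token has perplexity 4.
uniform_nll = [math.log(4)] * 10
ppl = perplexity(uniform_nll)
speed = tokens_per_second(256, 2.0)
print(ppl, speed)
```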

Logging with AimStack

Automatically track and visualize experiments:

aim up --host 0.0.0.0

Then open http://localhost:43800 to see:

  • Loss curves per training step
  • Hyperparameters (learning rate, batch size, model architecture)
  • Git hashes for reproducibility
  • Custom metrics (perplexity, validation accuracy, generation speed)
  • Compare runs: Side-by-side experiment comparison

Advanced Features

Checkpointing with Resumption

TaoTrain saves complete training state:

checkpoint = {
    "step": 12500,
    "model_state": model.state_dict(),
    "optimizer_state": optimizer.state_dict(),
    "config": config,  # Full config as Pydantic object
    "metrics": metrics_tracker.to_dict(),
}

Resume training from any checkpoint without loss of state. Keep last N checkpoints automatically.
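
One way a "keep last N" policy can work is sketched below with the standard library only. The `step_<n>.pt` naming scheme and the `prune_checkpoints` helper are assumptions for illustration; TaoTrain's checkpoint.py may name and select files differently.

```python
import re
import tempfile
from pathlib import Path

STEP_PATTERN = re.compile(r"^step_(\d+)\.pt$")  # hypothetical naming scheme


def prune_checkpoints(directory, keep_last):
    """Delete all but the `keep_last` highest-step checkpoints; return kept names."""
    if keep_last < 1:
        raise ValueError("keep_last must be at least 1")
    found = []
    for path in Path(directory).iterdir():
        match = STEP_PATTERN.match(path.name)
        if match:
            found.append((int(match.group(1)), path))
    found.sort(key=lambda item: item[0])  # numeric order, not string order
    for _, path in found[:-keep_last]:
        path.unlink()
    return [path.name for _, path in found[-keep_last:]]


with tempfile.TemporaryDirectory() as tmp:
    for step in (500, 1000, 12500, 2000):
        (Path(tmp) / f"step_{step}.pt").touch()
    kept = prune_checkpoints(tmp, keep_last=2)
    remaining = sorted(p.name for p in Path(tmp).iterdir())

print(kept)
```

Sorting on the parsed step number matters: sorted as strings, step_12500.pt would come before step_2000.pt and the wrong file could be deleted.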

Mixed Precision Training (BF16)

training:
  use_bfloat16: true
  gradient_accumulation_steps: 4

  • BF16 via torch.autocast for ~2x speedup with minimal accuracy loss
  • Proper gradient scaling and clipping
  • Compatible with all optimizers and architectures

3-Phase Learning Rate Schedule

scheduler:
  scheduler_type: "cosine_warmup"
  warmup_ratio: 0.1          # 10% of training steps
  steady_ratio: 0.5          # 50% at steady rate
  min_lr_ratio: 0.1          # Final LR = 0.1 × initial_lr
  num_cycles: 1

This schedule:

  1. Linear warmup (0 → 1) over 10% of steps
  2. Steady plateau at full LR over 50% of steps
  3. Cosine decay (1 → 0.1) over remaining 40% of steps

The aim is better convergence than a plain cosine or linear decay.
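
The three phases can be written as a single multiplier on the base learning rate. The sketch below follows the parameter names in the YAML above but is not the code in cosine_warmup.py; it assumes a single cosine cycle (num_cycles: 1) and a 0-indexed step.

```python
import math


def lr_multiplier(step, total_steps, warmup_ratio=0.1, steady_ratio=0.5,
                  min_lr_ratio=0.1):
    """Factor applied to the base LR at `step`."""
    warmup_steps = int(total_steps * warmup_ratio)
    steady_end = warmup_steps + int(total_steps * steady_ratio)
    if step < warmup_steps:
        return step / warmup_steps                 # phase 1: linear 0 -> 1
    if step < steady_end:
        return 1.0                                 # phase 2: plateau at full LR
    decay_steps = max(1, total_steps - steady_end)
    progress = min(1.0, (step - steady_end) / decay_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    # phase 3: cosine decay from 1 down to min_lr_ratio
    return min_lr_ratio + (1.0 - min_lr_ratio) * cosine


for step in (0, 50, 100, 599, 800, 1000):
    print(step, round(lr_multiplier(step, total_steps=1000), 4))
```

With 1000 total steps, warmup covers steps 0-99, the plateau covers 100-599, and the decay reaches 0.1 × the base rate at step 1000.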

Gradient Accumulation & Clipping

Simulate larger batch sizes with gradient accumulation:

training:
  batch_size: 32
  gradient_accumulation_steps: 4  # Effective batch = 128
  max_grad_norm: 1.0               # Gradient clipping
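
The arithmetic behind these two settings, sketched on plain Python lists rather than tensors. The helper names are hypothetical. Averaging matches a single large batch only when the micro-batches are equal in size and the loss is mean-reduced.

```python
import math


def accumulate(micro_batch_grads):
    """Average per-micro-batch gradients, as one large batch would produce."""
    steps = len(micro_batch_grads)
    accumulated = [0.0] * len(micro_batch_grads[0])
    for grads in micro_batch_grads:
        for i, g in enumerate(grads):
            accumulated[i] += g / steps  # same effect as scaling each loss by 1/steps
    return accumulated


def clip_by_global_norm(grads, max_norm):
    """Rescale gradients so their joint L2 norm is at most `max_norm`."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total_norm + 1e-6))  # epsilon avoids division by zero
    return [g * scale for g in grads], total_norm


batch_size, accumulation_steps = 32, 4
effective_batch = batch_size * accumulation_steps

micro_grads = [[3.0, 0.0], [3.0, 8.0], [3.0, 0.0], [3.0, 8.0]]
grads = accumulate(micro_grads)
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
print(effective_batch, grads, norm, clipped)
```

Clipping is applied once to the accumulated gradient, just before the optimizer step, not to each micro-batch separately.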

Contributing

Contributions are welcome! TaoTrain is designed to make contributions easy:

  1. Add a model: Subclass BaseModel and decorate it with @register_architecture("name")
  2. Add an optimizer: Subclass torch.optim.Optimizer and decorate it with @register_optimizer("name")
  3. Add a dataset: Subclass BaseDataset and decorate it with @register_dataset(mode, backend_type)
  4. Improve the core: Submit PRs to training/, data/, logging/, etc.

Ensure new code includes:

  • Type hints throughout
  • Pydantic configs for new parameters
  • Unit tests in tests/
  • Documentation in docstrings and README

Current Scope & Roadmap

✅ Currently Supported

  • Single GPU / single node training
  • Pretraining, SFT, and RL training stages
  • HuggingFace and JSONL data backends
  • BF16 mixed precision training
  • Checkpoint saving/loading with resumption
  • Interactive inference via TUI
  • Benchmarking (perplexity, speed)
  • Pluggable architectures, optimizers, schedulers, datasets

🚀 Roadmap (Future)

  • Distributed training (DDP, FSDP) for multi-GPU/multi-node scaling
  • Quantization support (INT8, QLoRA)
  • Advanced evaluation (BLEU, ROUGE, custom tasks)
  • Streaming inference with KV cache
  • Speculative decoding for faster generation
  • Integration with popular model hubs (Hugging Face Hub upload/download)

Getting Help

  • Questions? Open an issue on GitHub
  • Want to contribute? See CONTRIBUTING.md (coming soon)
  • Found a bug? Report it with a minimal reproduction script

License

MIT