diff --git a/.gitattributes b/.gitattributes index d609215bb9564c734f68e8f68f7fd5b7e1ed0a2a..f779ec7105ba94b2c2fbddace2d0af684c156002 100644 --- a/.gitattributes +++ b/.gitattributes @@ -3,3 +3,5 @@ *.vocab filter=lfs diff=lfs merge=lfs -text *.csv filter=lfs diff=lfs merge=lfs -text +code/Taotern_SSM/Gamma[[:space:]]Distributed[[:space:]]Ternary[[:space:]]HiPPO.pdf filter=lfs diff=lfs merge=lfs -text +code/Taotern_LLM_Experiments/docs/Taotern_Documentation_AI_Architecture.zip filter=lfs diff=lfs merge=lfs -text diff --git a/code/TaoTrain/src/taoTrain.egg-info/PKG-INFO b/code/TaoTrain/src/taoTrain.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..7059427e080e512355b3da3e366a23ce5fb0a901 --- /dev/null +++ b/code/TaoTrain/src/taoTrain.egg-info/PKG-INFO @@ -0,0 +1,451 @@ +Metadata-Version: 2.4 +Name: taoTrain +Version: 0.1.0 +Summary: Clean, modular PyTorch LLM training framework with pluggable architectures, AimStack logging, and TUI inference +Author-email: Felix +License: MIT +Requires-Python: >=3.10 +Description-Content-Type: text/markdown +Requires-Dist: torch>=2.0.0 +Requires-Dist: transformers>=4.30.0 +Requires-Dist: datasets>=2.10.0 +Requires-Dist: pydantic>=2.0.0 +Requires-Dist: pydantic-settings>=2.0.0 +Requires-Dist: aim>=3.15.0 +Requires-Dist: click>=8.1.0 +Requires-Dist: rich>=13.0.0 +Requires-Dist: textual>=0.30.0 +Requires-Dist: numpy>=1.24.0 +Requires-Dist: tqdm>=4.65.0 +Requires-Dist: sentencepiece>=0.1.99 +Provides-Extra: dev +Requires-Dist: pytest>=7.4.0; extra == "dev" +Requires-Dist: pytest-cov>=4.1.0; extra == "dev" +Requires-Dist: pytest-xdist>=3.3.0; extra == "dev" +Requires-Dist: black>=23.7.0; extra == "dev" +Requires-Dist: ruff>=0.0.280; extra == "dev" +Requires-Dist: typing-extensions>=4.7.0; extra == "dev" + +# TaoTrain: Production-Grade LLM Training Framework + +**TaoTrain** is a sophisticated PyTorch framework for training large language models at every scale—from experimental pretraining through supervised fine-tuning to reinforcement learning. Unlike fragmented training scripts or heavyweight frameworks, TaoTrain unifies the **entire training pipeline** in a clean, modular codebase that appeals to both ML engineers and software engineers. + +## Current Taotern Work + +TaoTrain now includes the Taotern comparison architectures used by the current SSM LLM work: + +- `taonet`: the attention/MLA baseline. +- `taonet_ssm`: the TaoNet shell with the attention mixer replaced by the Gamma Space Model DPLR SSM. +- `taonet_hybrid`: an alternating attention/SSM TaoNet used for the current best 200M-class candidate. + +The current selected deployment-oriented run is `hybrid_ssm_first_199m`, a `199,480,928` parameter model with 16 layers: SSM layers at `0,2,4,6,8,10,12,14` and attention layers at `1,3,5,7,9,11,13,15`. It uses the DPLR SSM core with split two-lane mixing, channel gates, per-channel local shift, and the faster convolution path for long-sequence training. + +Remote run `taotern-200m-hybrid-chat-20260512` trains this model on TaoData for a 4B-token base stage and then runs SFT so the final artifact can be loaded as a chat model. The trainable fixes added for this run are: + +- Async JSONL iteration keeps polling while tokenization workers are alive instead of ending early after a temporary empty queue. +- Cached JSONL scan metadata is reused safely while recomputing chunk ranges for the active `samples_per_chunk` and `max_samples` settings. + +## Why TaoTrain? + +- **Complete Unified Pipeline**: Pretraining → SFT → RL in a single, consistent framework. No context switching between different codebases or architectures. +- **Production-Grade Engineering**: Type-safe Pydantic configs, comprehensive checkpointing, AimStack integration, and proper gradient handling—not research code, but a framework you can deploy. +- **Extensibility Without Modification**: Register custom models, optimizers, schedulers, and datasets via decorators. Experiment freely without forking the framework. +- **Developer Experience First**: Interactive TUI for inference, intuitive YAML configurations, async data loading that eliminates I/O bottlenecks, and clear abstractions that make the codebase a pleasure to work with. + +## Key Capabilities + +| Capability | Details | +|---|---| +| **Multi-Stage Training** | Unified infrastructure for pretraining, SFT, and RL. Share model checkpoints, logging, and evaluation across stages. | +| **Advanced Optimization** | Hybrid Muon + AdamW optimizer: efficient 2D weight updates via SVD-based methods + adaptive learning for 1D parameters. | +| **Modern Architectures** | DeepSeek MLA with grouped query attention (GQA), YaRN context extension, and factorized embeddings—all configurable via YAML. | +| **Production Features** | BF16 mixed precision training, gradient accumulation, proper gradient clipping, checkpoint resumption, and validation loops. | +| **Async Data Pipeline** | Background tokenization with multi-threaded workers. Stream billion-token datasets from JSONL without loading into memory. | +| **Interactive Inference** | TUI chat interface with real-time generation speed metrics and multi-model comparison. | +| **Logging & Monitoring** | AimStack integration tracks loss, metrics, hyperparameters, and git hashes for reproducibility. Visualize training runs in your browser. | + +## Getting Started + +### Installation + +```bash +git clone https://github.com/lobakkang/taoTrain.git +cd taoTrain +pip install -e . +``` + +### Training Examples + +**Pretraining on a custom dataset:** +```bash +train pretrain --config configs/pretrain.yaml +``` +Starts from scratch, learns representations from raw text via next-token prediction. + +**Supervised Fine-tuning:** +```bash +train sft --config configs/sft.yaml +``` +Fine-tune a pretrained model on instruction-response pairs for improved task performance. + +**Reinforcement Learning (DPO):** +```bash +train rl --config configs/rl_dpo.yaml +``` +Align models with human preferences using Direct Preference Optimization. + +**Interactive Chat:** +```bash +tui-chat --model checkpoints/model.pt +``` +Launch an interactive TUI to chat with your model and monitor generation metrics in real-time. + +### Configuration + +All training is configured via YAML with Pydantic validation. Configs are type-safe and automatically validated: + +```yaml +# configs/sft.yaml +model: + architecture_type: "mla" # DeepSeek MLA with GQA + hidden_dim: 2048 + num_layers: 24 + num_heads: 32 + d_latent_kv: 1536 # KV compression factor + +training: + num_epochs: 3 + batch_size: 32 + learning_rate: 1e-4 + warmup_ratio: 0.1 + max_grad_norm: 1.0 + +optimizer: + optimizer_type: "muon_adamw" # Hybrid Muon + AdamW + muon_momentum: 0.95 + +data: + dataset_type: "sft_jsonl" # or "sft_hf" for HuggingFace + path: "data/sft_training.jsonl" + +logging: + log_to_aim: true + aim_repo: "/tmp/aim_logs" +``` + +See `configs/` for complete examples. + +## Project Architecture + +``` +src/taoTrain/ +├── cli.py # Main CLI entry point +├── config.py # Pydantic configuration schemas +│ +├── core/ # Base abstractions +│ └── base.py # BaseModel, BaseDataset, BaseTrainer +│ +├── models/ # Pluggable architecture system +│ ├── registry.py # Architecture factory with @register_architecture +│ ├── taonet.py # SimpleLLM with DeepSeek MLA +│ ├── mla_components.py # KV compression, GQA, YaRN +│ ├── embeddings.py # Factorized embeddings +│ └── transformer.py # Standard Transformer reference +│ +├── data/ # Advanced data pipeline +│ ├── factory.py # Dataset factory (HF + JSONL backends) +│ ├── async_loader.py # Async batch iteration (no I/O bottleneck) +│ ├── tokenization_queue.py # Background multi-threaded tokenization +│ ├── chunk_manager.py # Stream billion-token JSONL files +│ ├── hf_pretrain.py # HuggingFace pretraining datasets +│ ├── hf_sft.py # HuggingFace SFT datasets +│ ├── hf_rl.py # HuggingFace RL datasets +│ ├── pretrain_jsonl.py # JSONL pretraining +│ ├── sft_jsonl.py # JSONL SFT with instructions +│ └── rl_jsonl.py # JSONL RL with preferences +│ +├── training/ # Unified training infrastructure +│ └── trainer.py # Trainer + PretrainTrainer, SFTTrainer, RLTrainer +│ +├── optimizers/ # Pluggable optimizer system +│ ├── registry.py # Optimizer factory with @register_optimizer +│ ├── hybrid_muon_adamw.py # Composite: Muon (2D) + AdamW (1D) +│ ├── adamw.py # AdamW with weight decay +│ ├── adam.py # Standard Adam +│ └── sgd.py # SGD variants +│ +├── schedulers/ # Learning rate schedules +│ ├── registry.py # LR scheduler factory +│ ├── cosine_warmup.py # 3-phase: linear warmup → plateau → cosine decay +│ ├── linear_warmup.py # Linear warmup + constant +│ └── constant.py # Constant learning rate +│ +├── inference/ # Inference & interaction +│ ├── inferencer.py # Load & run inference from checkpoints +│ └── tui.py # Interactive chat with metrics display +│ +├── checkpointing/ # State management +│ └── checkpoint.py # Save/load model + optimizer + config + metrics +│ +├── logging/ # Experiment tracking +│ └── aim_logger.py # AimStack integration (loss, metrics, hyperparams) +│ +├── benchmarks/ # Evaluation tools +│ └── runner.py # Perplexity, speed, and task-specific benchmarks +│ +└── utils/ + └── helpers.py # Utility functions + +configs/ # Example YAML configurations +├── pretrain.yaml # Pretraining config +├── sft.yaml # SFT config +├── rl_dpo.yaml # RL/DPO config +└── tokenizer.yaml # Tokenizer config + +tests/ # Unit & integration tests +└── test_dataset.py +``` + +## Extensible Architecture: The Registry Pattern + +TaoTrain's power lies in its **pluggable design**. Add custom models, optimizers, schedulers, and datasets without modifying the framework. + +### Custom Model Architecture + +```python +from taoTrain.models import register_architecture, BaseModel +import torch.nn as nn + +@register_architecture("custom_moe") +class MixtureOfExperts(BaseModel): + """Your custom MoE architecture""" + def __init__(self, config): + super().__init__(config) + self.experts = nn.ModuleList([ + nn.Linear(config.hidden_dim, config.hidden_dim) + for _ in range(config.num_experts) + ]) + self.router = nn.Linear(config.hidden_dim, config.num_experts) + + def forward(self, input_ids, attention_mask=None): + # Your implementation + logits = self.compute_logits(input_ids) + loss = self.compute_loss(logits, labels) if labels is not None else None + return {"logits": logits, "loss": loss} +``` + +Then use it in your config: + +```yaml +model: + architecture_type: "custom_moe" + hidden_dim: 2048 + num_experts: 8 +``` + +### Custom Optimizers & Schedulers + +The same pattern works for optimizers and learning rate schedules: + +```python +from taoTrain.optimizers import register_optimizer +from torch.optim import Optimizer + +@register_optimizer("my_adaptive_optimizer") +class MyAdaptiveOptimizer(Optimizer): + def step(self, closure=None): + # Your optimization logic + pass +``` + +```python +from taoTrain.schedulers import register_scheduler + +@register_scheduler("my_schedule") +def my_schedule(initial_lr, step, total_steps, **kwargs): + return initial_lr * (1.0 - step / total_steps) # Linear decay +``` + +**The key principle**: No framework code needs to change. You register once, it's available everywhere. + +### Dataset Backend Flexibility + +Define custom datasets (JSONL, HF, streaming, etc.) and let the factory route to them: + +```python +from taoTrain.data import register_dataset + +@register_dataset("pretrain", "my_backend") +class MyPretrainDataset(BaseDataset): + def __init__(self, config): + # Load from your custom backend + pass + + def __getitem__(self, idx): + return {"input_ids": ..., "attention_mask": ...} +``` + +Use in config: + +```yaml +data: + dataset_type: "pretrain" + backend_type: "my_backend" # Routes to MyPretrainDataset +``` + +## Why TaoTrain Framework? + +### Async Data Loading: No I/O Bottleneck + +Most training frameworks load and tokenize data on the main training thread, blocking compute. TaoTrain's **multi-threaded tokenization pipeline**: + +- Tokenizes data in background workers while your GPU trains +- Supports streaming billion-token JSONL files without loading into memory +- Intelligent chunking (by file size or sample count) +- Metadata caching to avoid rescanning + +**Result**: 10-100x faster data iteration on large datasets. + +### Type-Safe Configuration + +Forget YAML parsing errors or mysterious config bugs. TaoTrain uses **Pydantic dataclasses** for configuration: + +- Automatic type validation: mistyped `learning_rate: "1e-4"` becomes an error, not silent failure +- Serialization: configs are part of checkpoints, ensuring reproducibility +- IDE support: autocomplete and type hints for all config fields +- Defaults: sensible defaults for all parameters + +### Benchmarking & Metrics + +Track what matters: + +- **Perplexity**: Language modeling quality on held-out data +- **Generation Speed**: Tokens-per-second (useful for TUI or deployment) +- **Task-Specific Accuracy**: Evaluate on downstream tasks +- **Training Metrics**: Loss curves, gradient norms, effective batch size + +All logged to AimStack with git hashes for reproducibility. + +## Logging with AimStack + +Automatically track and visualize experiments: + +```bash +aim up --host 0.0.0.0 +``` + +Then open `http://localhost:43800` to see: + +- **Loss curves** per training step +- **Hyperparameters** (learning rate, batch size, model architecture) +- **Git hashes** for reproducibility +- **Custom metrics** (perplexity, validation accuracy, generation speed) +- **Compare runs**: Side-by-side experiment comparison + +## Advanced Features + +### Checkpointing with Resumption + +TaoTrain saves complete training state: + +```python +checkpoint = { + "step": 12500, + "model_state": model.state_dict(), + "optimizer_state": optimizer.state_dict(), + "config": config, # Full config as Pydantic object + "metrics": metrics_tracker.to_dict(), +} +``` + +Resume training from any checkpoint without loss of state. Keep last N checkpoints automatically. + +### Mixed Precision Training (BF16) + +```yaml +training: + use_bfloat16: true + gradient_accumulation_steps: 4 +``` + +- BF16 via `torch.autocast` for ~2x speedup with minimal accuracy loss +- Proper gradient scaling and clipping +- Compatible with all optimizers and architectures + +### 3-Phase Learning Rate Schedule + +```yaml +scheduler: + scheduler_type: "cosine_warmup" + warmup_ratio: 0.1 # 10% of training steps + steady_ratio: 0.5 # 50% at steady rate + min_lr_ratio: 0.1 # Final LR = 0.1 × initial_lr + num_cycles: 1 +``` + +This schedule: +1. **Linear warmup** (0 → 1) over 10% of steps +2. **Steady plateau** at full LR over 50% of steps +3. **Cosine decay** (1 → 0.1) over remaining 40% of steps + +Better convergence than simple cosine or linear decay. + +### Gradient Accumulation & Clipping + +Simulate larger batch sizes with gradient accumulation: + +```yaml +training: + batch_size: 32 + gradient_accumulation_steps: 4 # Effective batch = 128 + max_grad_norm: 1.0 # Gradient clipping +``` + +## Contributing + +Contributions are welcome! TaoTrain is designed to make contributions easy: + +1. **Add a model**: Implement `BaseModel` and `@register_architecture("name")` +2. **Add an optimizer**: Implement `torch.optim.Optimizer` and `@register_optimizer("name")` +3. **Add a dataset**: Implement `BaseDataset` and `@register_dataset(mode, backend_type)` +4. **Improve the core**: Submit PRs to `training/`, `data/`, `logging/`, etc. + +Ensure new code includes: +- Type hints throughout +- Pydantic configs for new parameters +- Unit tests in `tests/` +- Documentation in docstrings and README + +## Current Scope & Roadmap + +### ✅ Currently Supported + +- **Single GPU / single node** training +- **Pretraining, SFT, and RL training** stages +- **HuggingFace and JSONL** data backends +- **BF16 mixed precision** training +- **Checkpoint saving/loading** with resumption +- **Interactive inference** via TUI +- **Benchmarking** (perplexity, speed) +- **Pluggable architectures, optimizers, schedulers, datasets** + +### 🚀 Roadmap (Future) + +- **Distributed training** (DDP, FSDP) for multi-GPU/multi-node scaling +- **Quantization** support (INT8, QLoRA) +- **Advanced evaluation** (BLEU, ROUGE, custom tasks) +- **Streaming inference** with KV cache +- **Speculative decoding** for faster generation +- **Integration with popular model hubs** (Hugging Face Hub upload/download) + +--- + +## Getting Help + +- **Questions?** Open an issue on GitHub +- **Want to contribute?** See `CONTRIBUTING.md` (coming soon) +- **Found a bug?** Report it with a minimal reproduction script + +## License + +MIT diff --git a/code/TaoTrain/src/taoTrain.egg-info/SOURCES.txt b/code/TaoTrain/src/taoTrain.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..081fe9244821316c828be0281ef65768959e545d --- /dev/null +++ b/code/TaoTrain/src/taoTrain.egg-info/SOURCES.txt @@ -0,0 +1,65 @@ +README.md +pyproject.toml +src/taoTrain/__init__.py +src/taoTrain/cli.py +src/taoTrain/config.py +src/taoTrain.egg-info/PKG-INFO +src/taoTrain.egg-info/SOURCES.txt +src/taoTrain.egg-info/dependency_links.txt +src/taoTrain.egg-info/entry_points.txt +src/taoTrain.egg-info/requires.txt +src/taoTrain.egg-info/top_level.txt +src/taoTrain/benchmarks/__init__.py +src/taoTrain/benchmarks/runner.py +src/taoTrain/checkpointing/__init__.py +src/taoTrain/checkpointing/checkpoint.py +src/taoTrain/core/__init__.py +src/taoTrain/core/base.py +src/taoTrain/data/__init__.py +src/taoTrain/data/async_loader.py +src/taoTrain/data/chunk_manager.py +src/taoTrain/data/factory.py +src/taoTrain/data/hf_base.py +src/taoTrain/data/hf_pretrain.py +src/taoTrain/data/hf_rl.py +src/taoTrain/data/hf_sft.py +src/taoTrain/data/jsonl_base.py +src/taoTrain/data/loaders.py +src/taoTrain/data/pretrain_jsonl.py +src/taoTrain/data/rl_jsonl.py +src/taoTrain/data/sft_jsonl.py +src/taoTrain/data/sft_utils.py +src/taoTrain/data/tokenization_queue.py +src/taoTrain/data/tokenizer.py +src/taoTrain/inference/__init__.py +src/taoTrain/inference/inferencer.py +src/taoTrain/inference/tui.py +src/taoTrain/logging/__init__.py +src/taoTrain/logging/aim_logger.py +src/taoTrain/models/__init__.py +src/taoTrain/models/embeddings.py +src/taoTrain/models/mla_components.py +src/taoTrain/models/registry.py +src/taoTrain/models/taonet.py +src/taoTrain/models/taonet_ssm.py +src/taoTrain/models/transformer.py +src/taoTrain/optimizers/__init__.py +src/taoTrain/optimizers/adam.py +src/taoTrain/optimizers/adamw.py +src/taoTrain/optimizers/hybrid_muon_adamw.py +src/taoTrain/optimizers/registry.py +src/taoTrain/optimizers/sgd.py +src/taoTrain/schedulers/__init__.py +src/taoTrain/schedulers/constant.py +src/taoTrain/schedulers/cosine_warmup.py +src/taoTrain/schedulers/linear_warmup.py +src/taoTrain/schedulers/registry.py +src/taoTrain/tokenizers/__init__.py +src/taoTrain/tokenizers/trainer.py +src/taoTrain/training/__init__.py +src/taoTrain/training/trainer.py +src/taoTrain/utils/__init__.py +src/taoTrain/utils/helpers.py +tests/test_dataset.py +tests/test_sft_masking.py +tests/test_taonet_ssm.py \ No newline at end of file diff --git a/code/TaoTrain/src/taoTrain.egg-info/requires.txt b/code/TaoTrain/src/taoTrain.egg-info/requires.txt new file mode 100644 index 0000000000000000000000000000000000000000..26d1e115948ce8cfe5ecca552d4db5f2ac4b0fbd --- /dev/null +++ b/code/TaoTrain/src/taoTrain.egg-info/requires.txt @@ -0,0 +1,20 @@ +torch>=2.0.0 +transformers>=4.30.0 +datasets>=2.10.0 +pydantic>=2.0.0 +pydantic-settings>=2.0.0 +aim>=3.15.0 +click>=8.1.0 +rich>=13.0.0 +textual>=0.30.0 +numpy>=1.24.0 +tqdm>=4.65.0 +sentencepiece>=0.1.99 + +[dev] +pytest>=7.4.0 +pytest-cov>=4.1.0 +pytest-xdist>=3.3.0 +black>=23.7.0 +ruff>=0.0.280 +typing-extensions>=4.7.0 diff --git a/code/TaoTrain/src/taoTrain.egg-info/top_level.txt b/code/TaoTrain/src/taoTrain.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..f2bbd2247de54234a6cc4feac6e7bbd3b80489ed --- /dev/null +++ b/code/TaoTrain/src/taoTrain.egg-info/top_level.txt @@ -0,0 +1 @@ +taoTrain diff --git a/code/TaoTrain/src/taoTrain/benchmarks/__init__.py b/code/TaoTrain/src/taoTrain/benchmarks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0dfbf8f50bf3fa8a04fbe15f5230a3fbe00ebd08 --- /dev/null +++ b/code/TaoTrain/src/taoTrain/benchmarks/__init__.py @@ -0,0 +1,5 @@ +"""Benchmarking suite.""" + +from .runner import BenchmarkRunner + +__all__ = ["BenchmarkRunner"] diff --git a/code/TaoTrain/src/taoTrain/benchmarks/runner.py b/code/TaoTrain/src/taoTrain/benchmarks/runner.py new file mode 100644 index 0000000000000000000000000000000000000000..84963289182ec88bb9cbceb791d6417f94302bf5 --- /dev/null +++ b/code/TaoTrain/src/taoTrain/benchmarks/runner.py @@ -0,0 +1,221 @@ +"""Benchmarking suite for evaluating trained models.""" + +import time +from pathlib import Path +from typing import Optional, Dict +import torch +from torch.utils.data import DataLoader + +from taoTrain.core import BaseModel +from taoTrain.config import TrainingConfig +from taoTrain.data.loaders import get_dataloader +from taoTrain.inference import Inferencer + + +class BenchmarkRunner: + """Run benchmarks on a trained model.""" + + def __init__( + self, + model: BaseModel, + device: torch.device, + dtype: torch.dtype = torch.float32, + ): + """ + Initialize benchmark runner. + + Args: + model: Trained model + device: Device for inference + dtype: Data type + """ + self.model = model.to(device) + self.model.eval() + self.device = device + self.dtype = dtype + + @staticmethod + def load_from_checkpoint( + checkpoint_path: str | Path, + device: Optional[torch.device] = None, + ) -> "BenchmarkRunner": + """Load model from checkpoint.""" + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + checkpoint = torch.load(checkpoint_path, map_location=device) + + # Reconstruct model config + from taoTrain.config import ModelConfig + from taoTrain.models import get_model + + model_config = ModelConfig(**checkpoint.get("config", {}).get("model", {})) + model = get_model(model_config, device=device) + model.load_state_dict(checkpoint["model_state_dict"]) + + return BenchmarkRunner(model, device) + + def benchmark_perplexity( + self, + dataset: "DataLoader", + num_batches: Optional[int] = None, + ) -> float: + """ + Compute perplexity on a dataset. + + Args: + dataset: DataLoader for evaluation + num_batches: Limit evaluation to N batches + + Returns: + Perplexity (exp of average loss) + """ + total_loss = 0.0 + total_tokens = 0 + + with torch.no_grad(): + for batch_idx, batch in enumerate(dataset): + if num_batches and batch_idx >= num_batches: + break + + # Move to device + input_ids = batch["input_ids"].to(self.device) + attention_mask = batch.get("attention_mask") + if attention_mask is not None: + attention_mask = attention_mask.to(self.device) + labels = batch.get("labels") + if labels is not None: + labels = labels.to(self.device) + + # Forward pass + with torch.autocast( + device_type="cuda" if self.device.type == "cuda" else "cpu", + dtype=torch.bfloat16 if self.dtype == torch.bfloat16 else torch.float32, + ): + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels, + ) + loss = outputs.get("loss") + + if loss is not None: + total_loss += loss.item() * input_ids.shape[0] + total_tokens += input_ids.shape[0] + + avg_loss = total_loss / total_tokens if total_tokens > 0 else float('inf') + perplexity = torch.exp(torch.tensor(avg_loss)).item() + + return perplexity + + def benchmark_throughput( + self, + batch_size: int = 32, + seq_length: int = 1024, + num_iters: int = 10, + ) -> Dict[str, float]: + """ + Benchmark forward pass throughput. + + Args: + batch_size: Batch size + seq_length: Sequence length + num_iters: Number of iterations + + Returns: + Dict with throughput metrics + """ + # Create dummy batch + dummy_input = torch.randint( + 0, self.model.config.vocab_size, + (batch_size, seq_length) + ).to(self.device) + + # Warmup + with torch.no_grad(): + for _ in range(2): + _ = self.model(dummy_input) + + torch.cuda.synchronize() if torch.cuda.is_available() else None + + # Benchmark forward pass + start = time.time() + + with torch.no_grad(): + for _ in range(num_iters): + _ = self.model(dummy_input) + + torch.cuda.synchronize() if torch.cuda.is_available() else None + + elapsed = time.time() - start + + total_tokens = batch_size * seq_length * num_iters + tokens_per_sec = total_tokens / elapsed + + return { + "throughput_tokens_per_sec": tokens_per_sec, + "throughput_samples_per_sec": (batch_size * num_iters) / elapsed, + "avg_time_per_iter_ms": (elapsed / num_iters) * 1000, + } + + def benchmark_memory(self) -> Dict[str, float]: + """ + Benchmark peak GPU memory usage. + + Returns: + Dict with memory stats + """ + if not torch.cuda.is_available(): + return {"peak_memory_gb": 0.0} + + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + + # Create dummy batch + dummy_input = torch.randint( + 0, self.model.config.vocab_size, + (16, 1024) + ).to(self.device) + + with torch.no_grad(): + _ = self.model(dummy_input) + + torch.cuda.synchronize() + + peak_memory = torch.cuda.max_memory_allocated() / (1024 ** 3) # GB + + return {"peak_memory_gb": peak_memory} + + def run_all_benchmarks( + self, + dataset: Optional["DataLoader"] = None, + batch_size: int = 32, + seq_length: int = 1024, + ) -> Dict[str, float]: + """ + Run all benchmarks. + + Args: + dataset: DataLoader for perplexity benchmark + batch_size: Batch size for throughput benchmark + seq_length: Sequence length for throughput benchmark + + Returns: + Dict with all benchmark results + """ + results = {} + + if dataset is not None: + print("Running perplexity benchmark...") + ppl = self.benchmark_perplexity(dataset, num_batches=10) + results["perplexity"] = ppl + + print("Running throughput benchmark...") + throughput = self.benchmark_throughput(batch_size, seq_length) + results.update(throughput) + + print("Running memory benchmark...") + memory = self.benchmark_memory() + results.update(memory) + + return results diff --git a/code/TaoTrain/src/taoTrain/checkpointing/__init__.py b/code/TaoTrain/src/taoTrain/checkpointing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1d4ad85dfdb6c491e3b927ad666fc1f5bc5df30e --- /dev/null +++ b/code/TaoTrain/src/taoTrain/checkpointing/__init__.py @@ -0,0 +1,5 @@ +"""Checkpoint management.""" + +from .checkpoint import CheckpointManager + +__all__ = ["CheckpointManager"] diff --git a/code/TaoTrain/src/taoTrain/checkpointing/checkpoint.py b/code/TaoTrain/src/taoTrain/checkpointing/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..c11721447c5c8e4e8bf17dfb3e3c5f5bcc85712a --- /dev/null +++ b/code/TaoTrain/src/taoTrain/checkpointing/checkpoint.py @@ -0,0 +1,194 @@ +"""Checkpoint management utilities. + +Canonical Checkpoint Format (new): + { + 'step': int, # Training step number + 'model_state': Dict[str, Tensor], # Model state dict + 'optimizer_state': Dict, # Optimizer state dict (optional) + 'config': Dict, # TrainingConfig as dict + 'metrics': Dict[str, float], # Training metrics + 'global_step': int, # (deprecated, kept for compat) same as step + 'current_epoch': int, # (optional) current epoch number + 'best_loss': float, # (optional) best validation loss + } + +Legacy Checkpoint Format (old, from BaseTrainer): + { + 'global_step': int, + 'current_epoch': int, + 'best_loss': float, + 'model_state_dict': Dict[str, Tensor], # ← Note: uses '_dict' suffix + 'optimizer_state_dict': Dict, + 'config': Dict, + } + +The load() function auto-detects and migrates legacy format to canonical format. +""" + +from pathlib import Path +from typing import Dict, Any, Optional +import torch +from taoTrain.config import TrainingConfig + + +class CheckpointManager: + """Manage model checkpoints with versioning.""" + + def __init__( + self, + checkpoint_dir: str | Path, + keep_last_n: int = 3, + track_best: bool = True, + ): + """ + Initialize checkpoint manager. + + Args: + checkpoint_dir: Directory to save checkpoints + keep_last_n: Number of recent checkpoints to keep + track_best: Whether to track best model + """ + self.checkpoint_dir = Path(checkpoint_dir) + self.checkpoint_dir.mkdir(parents=True, exist_ok=True) + + self.keep_last_n = keep_last_n + self.track_best = track_best + + self.best_metric = None + self.best_metric_name = None + self.saved_checkpoints = [] + + def save( + self, + step: int, + model_state: Dict[str, Any], + optimizer_state: Optional[Dict[str, Any]] = None, + config: Optional[TrainingConfig] = None, + metrics: Optional[Dict[str, float]] = None, + is_best: bool = False, + ) -> Path: + """ + Save a checkpoint. + + Args: + step: Training step + model_state: Model state dict + optimizer_state: Optimizer state dict + config: Training config + metrics: Metrics dict + is_best: Whether this is the best model so far + + Returns: + Path to saved checkpoint + """ + checkpoint = { + "step": step, + "model_state": model_state, + "optimizer_state": optimizer_state, + "config": config.to_dict() if config else None, + "metrics": metrics or {}, + } + + filename = f"checkpoint_step_{step:06d}.pt" + if is_best: + filename = "best_model.pt" + + path = self.checkpoint_dir / filename + torch.save(checkpoint, path) + + # Track saved checkpoints + if not is_best: + self.saved_checkpoints.append((step, path)) + + # Clean up old checkpoints + if len(self.saved_checkpoints) > self.keep_last_n: + _, old_path = self.saved_checkpoints.pop(0) + if old_path.exists(): + old_path.unlink() + + return path + + def load( + self, + checkpoint_path: str | Path, + device: Optional[torch.device] = None, + ) -> Dict[str, Any]: + """ + Load a checkpoint with backward-compatible format handling. + + Auto-detects checkpoint format (canonical or legacy) and normalizes + to canonical format in-memory. Legacy checkpoints are migrated without + modifying the file. + + Args: + checkpoint_path: Path to checkpoint + device: Device to load to + + Returns: + Checkpoint dict in canonical format with 'model_state' key + """ + if device is None: + device = torch.device("cpu") + + checkpoint = torch.load(checkpoint_path, map_location=device) + + # Auto-detect and migrate legacy format to canonical format + checkpoint = self._normalize_checkpoint_format(checkpoint) + + return checkpoint + + def _normalize_checkpoint_format(self, checkpoint: Dict[str, Any]) -> Dict[str, Any]: + """ + Normalize checkpoint to canonical format. + + Detects if checkpoint is in legacy format (from BaseTrainer with 'model_state_dict') + and migrates it to canonical format (with 'model_state'). + + Args: + checkpoint: Raw checkpoint dict + + Returns: + Normalized checkpoint dict with canonical keys + """ + # Check if this is a legacy checkpoint (has 'model_state_dict' but not 'model_state') + if "model_state_dict" in checkpoint and "model_state" not in checkpoint: + # Migrate legacy format to canonical + migrated = { + "step": checkpoint.get("global_step", 0), + "model_state": checkpoint["model_state_dict"], + "optimizer_state": checkpoint.get("optimizer_state_dict"), + "config": checkpoint.get("config"), + "metrics": {}, + # Keep legacy keys for backward compatibility in code that uses them + "global_step": checkpoint.get("global_step", 0), + "current_epoch": checkpoint.get("current_epoch", 0), + "best_loss": checkpoint.get("best_loss", float('inf')), + } + print(f"\n✓ [CheckpointManager] Detected legacy checkpoint format. Auto-migrated to canonical format.") + return migrated + + # Already in canonical format or unknown format + if "model_state" not in checkpoint: + # If neither format detected, ensure model_state is accessible + # (might be a raw state_dict) + print(f"\n⚠ [CheckpointManager] Checkpoint format unclear. Assuming raw state_dict format.") + checkpoint["model_state"] = checkpoint + + return checkpoint + + def get_latest(self) -> Optional[Path]: + """Get path to latest checkpoint.""" + if not self.saved_checkpoints: + return None + return self.saved_checkpoints[-1][1] + + def get_best(self) -> Optional[Path]: + """Get path to best checkpoint.""" + best_path = self.checkpoint_dir / "best_model.pt" + if best_path.exists(): + return best_path + return None + + def list_checkpoints(self) -> list[Path]: + """List all saved checkpoints.""" + return sorted(self.checkpoint_dir.glob("checkpoint_step_*.pt")) diff --git a/code/TaoTrain/src/taoTrain/core/__init__.py b/code/TaoTrain/src/taoTrain/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5c03a85860b357b94280d1addb302fc60bbae433 --- /dev/null +++ b/code/TaoTrain/src/taoTrain/core/__init__.py @@ -0,0 +1,5 @@ +"""Base classes for models, trainers, and datasets.""" + +from .base import BaseModel, BaseTrainer, BaseDataset, create_model, create_datasets + +__all__ = ["BaseModel", "BaseTrainer", "BaseDataset", "create_model", "create_datasets"] diff --git a/code/TaoTrain/src/taoTrain/core/base.py b/code/TaoTrain/src/taoTrain/core/base.py new file mode 100644 index 0000000000000000000000000000000000000000..d7692bb84b6f15d6396f9621a93e1da0088cca01 --- /dev/null +++ b/code/TaoTrain/src/taoTrain/core/base.py @@ -0,0 +1,271 @@ +"""Base classes for models, trainers, and datasets.""" + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Optional, Any, Iterator +import torch +import torch.nn as nn +from torch.utils.data import Dataset as TorchDataset +from taoTrain.config import TrainingConfig, ModelConfig + + +# ============================================================================ +# Base Model +# ============================================================================ + + +class BaseModel(nn.Module, ABC): + """Abstract base class for language models.""" + + def __init__(self, config: ModelConfig): + """Initialize model with config.""" + super().__init__() + self.config = config + + @abstractmethod + def forward( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + ) -> dict[str, torch.Tensor]: + """ + Forward pass. + + Args: + input_ids: Shape (batch_size, seq_length) + attention_mask: Shape (batch_size, seq_length), optional + labels: Shape (batch_size, seq_length), optional (for loss computation) + + Returns: + Dict with keys: + - 'logits': Shape (batch_size, seq_length, vocab_size) + - 'loss': Scalar (if labels provided) + """ + pass + + def count_parameters(self) -> int: + """Count total trainable parameters.""" + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + def get_num_layers(self) -> int: + """Get number of layers (for model architecture).""" + return self.config.num_layers + + +# ============================================================================ +# Base Dataset +# ============================================================================ + + +class BaseDataset(TorchDataset, ABC): + """Abstract base class for datasets.""" + + def __init__(self, config: "TrainingConfig"): + """Initialize dataset.""" + self.config = config + self.data = None + + @abstractmethod + def __len__(self) -> int: + """Return dataset size.""" + pass + + @abstractmethod + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + """ + Get a single sample. + + Returns: + Dict with keys: + - 'input_ids': 1D tensor of token IDs + - 'attention_mask': 1D tensor of attention mask + - 'labels': 1D tensor of labels (optional) + """ + pass + + def load_dataset(self) -> None: + """Load dataset from HuggingFace or other source.""" + pass + + def preprocess(self) -> None: + """Preprocess dataset (tokenization, etc).""" + pass + + +# ============================================================================ +# Base Trainer +# ============================================================================ + + +class BaseTrainer(ABC): + """Abstract base class for trainers.""" + + def __init__( + self, + model: BaseModel, + train_dataset: BaseDataset, + val_dataset: Optional[BaseDataset], + config: TrainingConfig, + device: torch.device, + ): + """Initialize trainer.""" + self.model = model.to(device) + self.train_dataset = train_dataset + self.val_dataset = val_dataset + self.config = config + self.device = device + + # Training state + self.global_step = 0 + self.current_epoch = 0 + self.best_loss = float('inf') + + # Logging + self.logger = None + + # Optimizer and scheduler (to be set up by subclass) + self.optimizer = None + self.scheduler = None + + @abstractmethod + def training_step(self, batch: dict[str, torch.Tensor]) -> dict[str, float]: + """ + Single training step. + + Args: + batch: Training batch with input_ids, attention_mask, labels, etc. + + Returns: + Dict with metrics (e.g., {'loss': 0.5, 'accuracy': 0.8}) + """ + pass + + @abstractmethod + def validation_step(self, batch: dict[str, torch.Tensor]) -> dict[str, float]: + """ + Single validation step. + + Args: + batch: Validation batch + + Returns: + Dict with validation metrics + """ + pass + + @abstractmethod + def train_epoch(self) -> dict[str, float]: + """ + Train for one epoch. + + Returns: + Dict with epoch-level metrics + """ + pass + + @abstractmethod + def validate(self) -> dict[str, float]: + """ + Run validation on the entire validation set. + + Returns: + Dict with validation metrics + """ + pass + + def save_checkpoint(self, path: str | Path) -> None: + """ + Save checkpoint in canonical format. + + Uses canonical checkpoint format: + { + 'step': int, + 'model_state': state_dict, + 'optimizer_state': state_dict, + 'config': dict, + 'metrics': dict, + 'global_step': int, # Legacy compat + 'current_epoch': int, # Legacy compat + 'best_loss': float, # Legacy compat + } + + Args: + path: Path to save checkpoint + """ + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + + # Save in canonical format + checkpoint = { + # Canonical format keys + 'step': self.global_step, + 'model_state': self.model.state_dict(), + 'optimizer_state': self.optimizer.state_dict() if self.optimizer else None, + 'config': self.config.to_dict(), + 'metrics': {}, + # Legacy format keys (for backward compatibility with code that reads them) + 'global_step': self.global_step, + 'current_epoch': self.current_epoch, + 'best_loss': self.best_loss, + } + + torch.save(checkpoint, path) + + def load_checkpoint(self, path: str | Path) -> None: + """ + Load checkpoint (handles both canonical and legacy formats). + + Args: + path: Path to checkpoint + """ + path = Path(path) + checkpoint = torch.load(path, map_location=self.device) + + # Try canonical keys first, fall back to legacy keys + model_state_key = 'model_state' if 'model_state' in checkpoint else 'model_state_dict' + optimizer_state_key = 'optimizer_state' if 'optimizer_state' in checkpoint else 'optimizer_state_dict' + + self.model.load_state_dict(checkpoint[model_state_key]) + if self.optimizer and checkpoint.get(optimizer_state_key): + self.optimizer.load_state_dict(checkpoint[optimizer_state_key]) + + # Try canonical 'step' first, fall back to legacy 'global_step' + self.global_step = checkpoint.get('step', checkpoint.get('global_step', 0)) + self.current_epoch = checkpoint.get('current_epoch', 0) + self.best_loss = checkpoint.get('best_loss', float('inf')) + + def _get_lr(self) -> float: + """Get current learning rate from optimizer.""" + for param_group in self.optimizer.param_groups: + return param_group['lr'] + return 0.0 + + +# ============================================================================ +# Utility functions +# ============================================================================ + + +def create_model(config: TrainingConfig, device: torch.device) -> BaseModel: + """Create model from config (calls registry).""" + from taoTrain.models import get_model + return get_model(config.model, device=device) + + +def create_datasets( + config: TrainingConfig, +) -> tuple[BaseDataset, Optional[BaseDataset]]: + """Create train and validation datasets using factory pattern.""" + # Import here to avoid circular imports + from taoTrain.data import DatasetFactory + + # Create train dataset + train_dataset = DatasetFactory.create_dataset(config, split="train") + + # Create validation dataset (only for HuggingFace datasets with explicit validation split) + val_dataset = None + if not config.dataset.local and hasattr(config.dataset, "validation_split"): + val_dataset = DatasetFactory.create_dataset(config, split="validation") + + return train_dataset, val_dataset diff --git a/code/TaoTrain/src/taoTrain/data/__init__.py b/code/TaoTrain/src/taoTrain/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2d16f57da63db43248eabb3281d8ae8883ca81c8 --- /dev/null +++ b/code/TaoTrain/src/taoTrain/data/__init__.py @@ -0,0 +1,56 @@ +"""Dataset implementations and loaders.""" + +# HuggingFace-based datasets are optional for JSONL-only deployments. +try: + from .hf_base import BaseHFDataset + from .hf_pretrain import PretrainDataset + from .hf_sft import SFTDataset + from .hf_rl import RLDataset +except ImportError: + BaseHFDataset = None + PretrainDataset = None + SFTDataset = None + RLDataset = None + +# JSONL-based datasets (async-only) +from .jsonl_base import BaseJSONLDataset +from .pretrain_jsonl import PretrainJSONLDataset +from .sft_jsonl import SFTJSONLDataset +from .rl_jsonl import RLJSONLDataset + +# Utilities +from .tokenizer import SentencePieceTokenizerWrapper +from .sft_utils import ( + parse_sft_record, + build_sft_sequence_tokens, + apply_response_masking, + build_response_only_next_token_labels, +) +from .loaders import get_dataloader +from .async_loader import AsyncBatchIterator +from .tokenization_queue import TokenizationQueue +from .factory import DatasetFactory + +__all__ = [ + # HuggingFace datasets + "BaseHFDataset", + "PretrainDataset", + "SFTDataset", + "RLDataset", + # JSONL datasets + "BaseJSONLDataset", + "PretrainJSONLDataset", + "SFTJSONLDataset", + "RLJSONLDataset", + # Utilities + "SentencePieceTokenizerWrapper", + "parse_sft_record", + "build_sft_sequence_tokens", + "apply_response_masking", + "build_response_only_next_token_labels", + # Data loading + "get_dataloader", + "AsyncBatchIterator", + "TokenizationQueue", + "DatasetFactory", +] diff --git a/code/TaoTrain/src/taoTrain/data/async_loader.py b/code/TaoTrain/src/taoTrain/data/async_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..9b42c60db963b96f8e31fafbe084e79a6d6d7ddc --- /dev/null +++ b/code/TaoTrain/src/taoTrain/data/async_loader.py @@ -0,0 +1,204 @@ +"""Async batch iterator for training with background tokenization.""" + +from typing import Dict, List, Optional, Any, Iterator +import torch + +from taoTrain.data.tokenization_queue import TokenizationQueue +from taoTrain.data.sft_utils import build_response_only_next_token_labels + + +class AsyncBatchIterator: + """ + Iterator that yields batches from a tokenization queue. + + This allows batches to be consumed directly from the background tokenization + thread without waiting for all chunks to be tokenized upfront. + + The iterator: + 1. Pulls pre-tokenized chunks from the TokenizationQueue + 2. Yields individual samples or batches + 3. Handles movement to device (GPU/CPU) at batch level + 4. Supports gradient accumulation + """ + + def __init__( + self, + tokenization_queue: TokenizationQueue, + batch_size: int, + device: torch.device, + drop_last: bool = True, + gradient_accumulation_steps: int = 1, + ): + """ + Initialize async batch iterator. + + Args: + tokenization_queue: TokenizationQueue instance + batch_size: Batch size for yielding batches + device: torch.device to move batches to + drop_last: If True, drop last incomplete batch + gradient_accumulation_steps: For logging purposes (not used here) + """ + self.queue = tokenization_queue + self.batch_size = batch_size + self.device = device + self.drop_last = drop_last + self.gradient_accumulation_steps = gradient_accumulation_steps + + # State for iteration + self._current_chunk: Optional[Dict[str, List]] = None + self._current_idx = 0 + self._samples_yielded = 0 + self._finished = False + + def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: + """Return iterator (self).""" + # Reset state for new epoch + self._current_chunk = None + self._current_idx = 0 + self._samples_yielded = 0 + self._finished = False + + # Reset tokenization queue for epochs 2+ + if self.queue._next_chunk_idx > 0: + print(f"\n✓ Resetting TokenizationQueue for next epoch (cur_idx={self.queue._next_chunk_idx})") + self.queue.reset_for_next_epoch() + + # Start tokenization threads once per iterator creation + if not self.queue._threads: + print("\n✓ Starting TokenizationQueue worker threads...") + self.queue.start() + else: + print(f"\n⚠ TokenizationQueue threads already running: {len(self.queue._threads)} active") + + return self + + def __next__(self) -> Dict[str, torch.Tensor]: + """ + Get next batch. + + Yields: + Dict with 'input_ids', 'attention_mask', 'labels' (all as torch tensors on device) + + Raises: + StopIteration: When no more batches available + """ + batch = self._get_next_batch() + + if batch is None: + print("AsyncBatchIterator: No more batches available, stopping iteration.") + raise StopIteration + + return batch + + def _get_next_batch(self) -> Optional[Dict[str, torch.Tensor]]: + """ + Fetch and collate the next batch. + + Returns: + Dict with batch tensors, or None if iteration exhausted + """ + batch_input_ids = [] + batch_attention_masks = [] + batch_labels = [] + + while len(batch_input_ids) < self.batch_size: + # Try to get next sample from current chunk + if self._current_chunk is None or self._current_idx >= len(self._current_chunk["input_ids"]): + # Need new chunk + self._current_chunk = self.queue.get_next_chunk(timeout=30.0) # 30s polling timeout + + if self._current_chunk is None: + if not self.queue.is_exhausted: + continue + # Queue exhausted + chunk_count = self.queue._next_chunk_idx if hasattr(self.queue, '_next_chunk_idx') else 'unknown' + print(f"AsyncBatchIterator: No more chunks (processed {chunk_count}/{len(self.queue._chunk_order)})") + print(f"AsyncBatchIterator: Samples yielded so far: {self._samples_yielded}") + self._finished = True + break + + self._current_idx = 0 + + # Get sample from current chunk + input_ids = self._current_chunk["input_ids"][self._current_idx] + attention_mask = self._current_chunk["attention_mask"][self._current_idx] + + # Generate labels based on SFT or pretrain mode + if "mask" in self._current_chunk: + # SFT mode: use mask to determine which tokens to train on + # mask=0 → label=-100 (ignore), mask=1 → label=input_id (train on) + mask = self._current_chunk["mask"][self._current_idx] + labels = build_response_only_next_token_labels(input_ids, mask) + else: + # Pretrain mode: shift labels by 1 for next-token prediction + # Position i predicts token at position i+1 + labels = input_ids[1:] + [-100] # Append -100 as final position + + # Mark padding tokens as -100 to ignore in loss computation + for i, mask_val in enumerate(attention_mask): + if mask_val == 0: + labels[i] = -100 + + batch_input_ids.append(input_ids) + batch_attention_masks.append(attention_mask) + batch_labels.append(labels) + + self._current_idx += 1 + self._samples_yielded += 1 + + # Return batch if we have any samples, respecting drop_last + if len(batch_input_ids) == 0: + print(f"AsyncBatchIterator: No samples collected for batch. Finished={self._finished}, returning None.") + return None + + if len(batch_input_ids) < self.batch_size and self.drop_last: + incomplete_pct = (len(batch_input_ids) / self.batch_size) * 100 + print(f"AsyncBatchIterator: Batch incomplete ({len(batch_input_ids)}/{self.batch_size} = {incomplete_pct:.1f}%) and drop_last=True, returning None.") + return None + + return self._collate_batch(batch_input_ids, batch_attention_masks, batch_labels) + + def _collate_batch( + self, + batch_input_ids: List[List[int]], + batch_attention_masks: List[List[int]], + batch_labels: List[List[int]], + ) -> Dict[str, torch.Tensor]: + """ + Collate batch samples and move to device. + + Args: + batch_input_ids: List of token ID lists + batch_attention_masks: List of attention mask lists + batch_labels: List of label lists + + Returns: + Collated batch as torch tensors on device + """ + # Convert to tensors + input_ids_tensor = torch.tensor(batch_input_ids, dtype=torch.long, device=self.device) + attention_mask_tensor = torch.tensor(batch_attention_masks, dtype=torch.long, device=self.device) + labels_tensor = torch.tensor(batch_labels, dtype=torch.long, device=self.device) + + return { + "input_ids": input_ids_tensor, + "attention_mask": attention_mask_tensor, + "labels": labels_tensor, + } + + def __len__(self) -> int: + """Return approximate number of batches.""" + total_samples = len(self.queue) + if self.drop_last: + return total_samples // self.batch_size + else: + return (total_samples + self.batch_size - 1) // self.batch_size + + def shutdown(self): + """Shutdown the async iterator and background thread.""" + self.queue.shutdown(wait=True) + + def __del__(self): + """Cleanup on deletion.""" + self.shutdown() diff --git a/code/TaoTrain/src/taoTrain/data/chunk_manager.py b/code/TaoTrain/src/taoTrain/data/chunk_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..88351a224a7d307dda139e63c08cff2d5ec09308 --- /dev/null +++ b/code/TaoTrain/src/taoTrain/data/chunk_manager.py @@ -0,0 +1,452 @@ +"""Chunk manager for streaming large JSONL datasets.""" + +import os +import json +import hashlib +from typing import Tuple, Optional, Dict, Any +from pathlib import Path +from tqdm import tqdm + + +class ChunkManager: + """ + Manages chunked reading of large JSONL files. + + This class handles: + - File scanning to count total lines without loading all text + - Estimating chunk boundaries based on file size + - Tracking which line ranges belong to each chunk + """ + + def __init__(self, jsonl_path: str, chunk_size_gb: float = 5.0, + samples_per_chunk: Optional[int] = None, + enable_metadata_cache: bool = True, chunk_cache_dir: str = ".cache/chunks", + max_samples: Optional[int] = None): + """ + Initialize ChunkManager. + + Args: + jsonl_path: Path to JSONL file + chunk_size_gb: Approximate chunk size in GB (ignored if samples_per_chunk is set) + samples_per_chunk: Number of samples per chunk (takes precedence over chunk_size_gb) + enable_metadata_cache: Enable caching of file scan metadata + chunk_cache_dir: Directory to store cache files + max_samples: Limit total samples to at most this many (if total_lines > max_samples) + + Raises: + FileNotFoundError: If JSONL file doesn't exist + ValueError: If file is empty + """ + self.jsonl_path = Path(jsonl_path) + self.chunk_size_bytes = int(chunk_size_gb * 1024 ** 3) # Convert GB to bytes + self.max_samples = max_samples # Limit total samples if specified + print (f"Initializing ChunkManager for {self.jsonl_path} with target chunk size {chunk_size_gb} GB") + if samples_per_chunk is not None: + print(f" Overriding chunk size with {samples_per_chunk} samples per chunk") + if max_samples is not None: + print(f" Limiting dataset to {max_samples} samples") + self.samples_per_chunk = samples_per_chunk # If set, overrides GB-based chunking + self.enable_metadata_cache = enable_metadata_cache + self.chunk_cache_dir = Path(chunk_cache_dir) + + if not self.jsonl_path.exists(): + raise FileNotFoundError(f"JSONL file not found: {self.jsonl_path}") + + self.file_size_bytes = os.path.getsize(self.jsonl_path) + self.file_mtime = os.path.getmtime(self.jsonl_path) + + if self.file_size_bytes == 0: + raise ValueError("JSONL file is empty") + + # Will be populated by _scan_file() + self.total_lines = 0 + self.effective_lines = 0 + self.line_sizes = [] # bytes per line + self.valid_line_offsets = [] # byte offset of each VALID JSON line (for seeking) + self.chunk_line_ranges = [] # [(start_line, end_line), ...] + + # Try to load from cache first + cache_loaded = False + if self.enable_metadata_cache: + cache_loaded = self._load_metadata_cache() + + # If cache not used, scan the file + if not cache_loaded: + self._scan_file() + self._compute_chunk_ranges() + + # Save metadata cache for future runs + if self.enable_metadata_cache: + self._save_metadata_cache() + else: + # Cache stores file scan metadata. Recompute chunk ranges for the + # current training config so samples_per_chunk/max_samples changes + # are honored without rescanning the large JSONL file. + self._compute_chunk_ranges() + + def _get_cache_path(self) -> Path: + """Get the metadata cache file path for this JSONL file.""" + # Create a hash of the file path to use as cache filename + file_hash = hashlib.md5(str(self.jsonl_path.absolute()).encode()).hexdigest()[:8] + cache_file = self.chunk_cache_dir / f"{file_hash}.metadata.json" + return cache_file + + def _load_metadata_cache(self) -> bool: + """ + Load metadata from cache if it exists and is valid. + + Returns: + True if cache was loaded successfully, False otherwise + """ + cache_file = self._get_cache_path() + + if not cache_file.exists(): + return False + + try: + with open(cache_file, 'r', encoding='utf-8') as f: + cache_data = json.load(f) + + # Validate cache: check file hasn't changed + if (cache_data.get('file_size') != self.file_size_bytes or + cache_data.get('file_mtime') != self.file_mtime or + cache_data.get('jsonl_path') != str(self.jsonl_path.absolute())): + return False + + # Load cached data + self.total_lines = cache_data.get('total_lines', 0) + self.line_sizes = cache_data.get('line_sizes', []) + self.valid_line_offsets = cache_data.get('valid_line_offsets', []) + # Convert loaded lists back to tuples for chunk_line_ranges + chunk_ranges = cache_data.get('chunk_line_ranges', []) + self.chunk_line_ranges = [tuple(r) for r in chunk_ranges] + self.chunk_size_bytes = cache_data.get('chunk_size_bytes', self.chunk_size_bytes) + + print(f"✓ Loaded scan metadata from cache: {cache_file.name}") + print(f" Found {self.total_lines:,} valid JSON lines in {len(self.chunk_line_ranges)} chunks") + return True + + except Exception as e: + # If cache loading fails, fall back to scanning + return False + + def _save_metadata_cache(self) -> None: + """Save metadata cache to file.""" + cache_file = self._get_cache_path() + cache_file.parent.mkdir(parents=True, exist_ok=True) + + cache_data = { + 'jsonl_path': str(self.jsonl_path.absolute()), + 'file_size': self.file_size_bytes, + 'file_mtime': self.file_mtime, + 'total_lines': self.total_lines, + 'line_sizes': self.line_sizes, + 'valid_line_offsets': self.valid_line_offsets, + 'chunk_line_ranges': self.chunk_line_ranges, + 'chunk_size_bytes': self.chunk_size_bytes, + } + + try: + # Write atomically using a temp file + rename + temp_file = cache_file.with_suffix('.tmp') + with open(temp_file, 'w', encoding='utf-8') as f: + json.dump(cache_data, f, indent=2) + temp_file.replace(cache_file) + print(f" Saved scan metadata to cache: {cache_file.name}") + except Exception as e: + print(f" ⚠ Warning: failed to save cache: {e}") + + def _get_chunk_cache_dir(self) -> Path: + """Get the directory for storing cached chunk data for this JSONL file.""" + file_hash = hashlib.md5(str(self.jsonl_path.absolute()).encode()).hexdigest()[:8] + chunk_dir = self.chunk_cache_dir / "chunks" / file_hash + return chunk_dir + + def _get_chunk_cache_file(self, chunk_num: int) -> Path: + """Get the cache file path for a specific chunk.""" + chunk_dir = self._get_chunk_cache_dir() + return chunk_dir / f"chunk_{chunk_num:06d}.jsonl" + + def _get_chunk_index_file(self) -> Path: + """Get the index file that lists all cached chunks.""" + chunk_dir = self._get_chunk_cache_dir() + return chunk_dir / "index.json" + + def extract_and_cache_chunks(self) -> Dict[str, Any]: + """ + Extract chunks from the original JSONL file and save them as separate cached files. + + This is optional and should be called manually if you want to pre-cache chunks + for faster repeated access. It can significantly speed up training but uses more disk space. + + Returns: + Dictionary with cache information: + - 'cache_dir': path to cache directory + - 'num_chunks': number of chunks cached + - 'total_size_gb': total size of cached chunks + """ + chunk_dir = self._get_chunk_cache_dir() + chunk_dir.mkdir(parents=True, exist_ok=True) + + print(f"💾 Extracting {len(self.chunk_line_ranges)} chunks to cache...") + total_size = 0 + + for chunk_num in range(len(self.chunk_line_ranges)): + cache_file = self._get_chunk_cache_file(chunk_num) + + # Skip if already cached + if cache_file.exists(): + total_size += os.path.getsize(cache_file) + continue + + # Read chunk and save to cache file + chunk_examples = self.read_chunk(chunk_num, _from_cache=False) + + with open(cache_file, 'w', encoding='utf-8') as f: + for obj in chunk_examples: + f.write(json.dumps(obj) + '\n') + + total_size += os.path.getsize(cache_file) + if (chunk_num + 1) % max(1, len(self.chunk_line_ranges) // 10) == 0: + print(f" - Cached {chunk_num + 1}/{len(self.chunk_line_ranges)} chunks...") + + # Write index file + index_data = { + 'jsonl_path': str(self.jsonl_path.absolute()), + 'num_chunks': len(self.chunk_line_ranges), + 'chunk_ranges': self.chunk_line_ranges, + } + with open(self._get_chunk_index_file(), 'w', encoding='utf-8') as f: + json.dump(index_data, f, indent=2) + + print(f"✓ Cached {len(self.chunk_line_ranges)} chunks ({total_size / (1024**3):.2f} GB)") + + return { + 'cache_dir': str(chunk_dir), + 'num_chunks': len(self.chunk_line_ranges), + 'total_size_gb': total_size / (1024**3), + } + + def clear_chunk_cache(self, keep_metadata: bool = False) -> None: + """ + Clear cached chunk data. + + Args: + keep_metadata: If True, only remove chunk files, keep the metadata cache + """ + chunk_dir = self._get_chunk_cache_dir() + + if chunk_dir.exists(): + import shutil + shutil.rmtree(chunk_dir) + print(f"✓ Cleared chunk cache: {chunk_dir}") + + if not keep_metadata: + cache_file = self._get_cache_path() + if cache_file.exists(): + cache_file.unlink() + print(f"✓ Cleared metadata cache: {cache_file}") + + def _scan_file(self) -> None: + """ + Scan JSONL file to count lines and track offsets. + + This reads the file once to: + - Count total valid JSON lines + - Record byte offset of each VALID line for seeking + - Estimate size per line + """ + print(f"📖 Scanning JSONL file: {self.jsonl_path}") + print(f" File size: {self.file_size_bytes / (1024**3):.2f} GB") + + self.valid_line_offsets = [] + current_offset = 0 + valid_lines = 0 + + try: + with open(self.jsonl_path, 'r', encoding='utf-8') as f: + for line in tqdm(f, desc="Scanning JSONL", unit=" lines"): + # Skip empty lines - don't count toward line numbers + if not line.strip(): + current_offset += len(line.encode('utf-8')) + continue + + try: + json.loads(line) + # Valid JSON line - record its starting byte offset + self.valid_line_offsets.append(current_offset) + valid_lines += 1 + + line_bytes = len(line.encode('utf-8')) + self.line_sizes.append(line_bytes) + + except json.JSONDecodeError: + # Skip invalid JSON lines - don't count toward line numbers + pass + + current_offset += len(line.encode('utf-8')) + + except Exception as e: + raise ValueError(f"Error scanning JSONL file: {e}") + + self.total_lines = valid_lines + + if self.total_lines == 0: + raise ValueError("No valid JSON lines found in JSONL file") + + print(f"✓ Found {self.total_lines:,} valid JSON lines") + + # Calculate average line size + avg_line_size = sum(self.line_sizes) / len(self.line_sizes) if self.line_sizes else 0 + print(f" Average line size: {avg_line_size:.2f} bytes") + print(f" Chunk size target: {self.chunk_size_bytes / (1024**3):.2f} GB") + + def _compute_chunk_ranges(self) -> None: + """ + Compute line ranges for each chunk based on target chunk size. + + If samples_per_chunk is set, uses that. Otherwise, divides file + based on chunk_size_bytes. If max_samples is set, limits chunks to cover + at most max_samples lines. + """ + if self.total_lines == 0: + self.chunk_line_ranges = [] + return + + # Apply max_samples limit to effective line count + self.effective_lines = self.total_lines + if self.max_samples is not None: + self.effective_lines = min(self.total_lines, self.max_samples) + + # Determine lines per chunk + if self.samples_per_chunk is not None: + # Use explicit sample count + lines_per_chunk = self.samples_per_chunk + else: + # Use GB-based calculation + avg_line_size = sum(self.line_sizes) / len(self.line_sizes) if self.line_sizes else 1 + lines_per_chunk = max(1, int(self.chunk_size_bytes / avg_line_size)) + + chunk_ranges = [] + start_line = 0 + + # Create chunks up to self.effective_lines (honors max_samples) + while start_line < self.effective_lines: + end_line = min(start_line + lines_per_chunk, self.effective_lines) + chunk_ranges.append((start_line, end_line)) + start_line = end_line + + self.chunk_line_ranges = chunk_ranges + self.num_chunks = len(chunk_ranges) + + print(f" Divided into {self.num_chunks} chunks (covering {self.effective_lines:,} lines)") + + def get_chunk_indices(self, chunk_num: int) -> Tuple[int, int]: + """ + Get (start_line, end_line) for a given chunk number. + + Args: + chunk_num: Chunk number (0-indexed) + + Returns: + Tuple of (start_line, end_line) where end_line is exclusive + + Raises: + IndexError: If chunk_num is out of range + """ + if chunk_num < 0 or chunk_num >= len(self.chunk_line_ranges): + raise IndexError(f"Chunk {chunk_num} out of range [0, {len(self.chunk_line_ranges)-1}]") + + return self.chunk_line_ranges[chunk_num] + + def read_chunk(self, chunk_num: int, _from_cache: bool = True) -> list[dict]: + """ + Read a specific chunk and return parsed JSON objects. + + If chunk cache is available, reads from cache. Otherwise reads from original JSONL + using file.seek() for O(1) lookup instead of O(n) scanning. + + Args: + chunk_num: Chunk number (0-indexed) + _from_cache: Internal parameter to force reading from original (used during cache extraction) + + Returns: + List of parsed JSON objects from that chunk + + Raises: + IndexError: If chunk_num is out of range + ValueError: If JSON parsing fails + """ + # Try to read from cache first (if it exists) + if _from_cache: + cache_file = self._get_chunk_cache_file(chunk_num) + if cache_file.exists(): + examples = [] + try: + with open(cache_file, 'r', encoding='utf-8') as f: + for line in f: + if line.strip(): + try: + obj = json.loads(line) + examples.append(obj) + except json.JSONDecodeError: + pass + return examples + except Exception as e: + print(f" ⚠ Warning: failed to read chunk from cache, falling back to original: {e}") + + # Read from original JSONL file using seek optimization + start_line, end_line = self.get_chunk_indices(chunk_num) + + examples = [] + + with open(self.jsonl_path, 'r', encoding='utf-8') as f: + # Seek to the byte offset of the start line + # This is O(1) instead of O(start_line) iteration + if start_line < len(self.valid_line_offsets): + f.seek(self.valid_line_offsets[start_line]) + else: + # Fallback if valid_line_offsets not available (shouldn't happen) + f.seek(0) + + current_line = start_line + + # Read lines from start_line to end_line + for line in f: + # Skip empty lines + if not line.strip(): + continue + + # Stop when we've read enough lines + if current_line >= end_line: + break + + try: + obj = json.loads(line) + examples.append(obj) + current_line += 1 + except json.JSONDecodeError: + # Skip invalid JSON lines, but don't increment line counter + # This maintains alignment with line numbering from scan + pass + + return examples + + @property + def num_chunks(self) -> int: + """Return number of chunks.""" + return len(self.chunk_line_ranges) + + @num_chunks.setter + def num_chunks(self, value: int) -> None: + """Set number of chunks (internal use).""" + self._num_chunks = value + + def __repr__(self) -> str: + """String representation.""" + return ( + f"ChunkManager(file={self.jsonl_path.name}, " + f"size={self.file_size_bytes/(1024**3):.2f}GB, " + f"lines={self.effective_lines:,}, " + f"chunks={self.num_chunks})" + ) diff --git a/code/TaoTrain/src/taoTrain/data/factory.py b/code/TaoTrain/src/taoTrain/data/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..1efec5ca283cbb7f6789ce0d2bfb74f959a1759f --- /dev/null +++ b/code/TaoTrain/src/taoTrain/data/factory.py @@ -0,0 +1,108 @@ +"""Factory for creating datasets based on configuration.""" + +from taoTrain.config import TrainingConfig, TrainingModeEnum +from taoTrain.data.pretrain_jsonl import PretrainJSONLDataset +from taoTrain.data.sft_jsonl import SFTJSONLDataset +from taoTrain.data.rl_jsonl import RLJSONLDataset + +try: + from taoTrain.data.hf_pretrain import PretrainDataset + from taoTrain.data.hf_sft import SFTDataset + from taoTrain.data.hf_rl import RLDataset +except ImportError: + PretrainDataset = None + SFTDataset = None + RLDataset = None + + +class DatasetFactory: + """Factory for creating datasets based on configuration.""" + + # Registry of dataset classes by mode and backend + DATASETS = { + (TrainingModeEnum.PRETRAIN, "jsonl"): PretrainJSONLDataset, + (TrainingModeEnum.SFT, "jsonl"): SFTJSONLDataset, + (TrainingModeEnum.RL, "jsonl"): RLJSONLDataset, + } + + if PretrainDataset is not None: + DATASETS.update({ + (TrainingModeEnum.PRETRAIN, "huggingface"): PretrainDataset, + (TrainingModeEnum.SFT, "huggingface"): SFTDataset, + (TrainingModeEnum.RL, "huggingface"): RLDataset, + }) + + @staticmethod + def create_dataset( + config: TrainingConfig, + split: str = "train", + ): + """ + Create dataset instance based on configuration. + + Args: + config: Training configuration + split: Dataset split (train, validation, test) - primarily for HuggingFace datasets + + Returns: + Dataset instance matching the configured mode and backend + + Raises: + ValueError: If configuration is invalid or unsupported mode/backend combination + """ + # Determine backend: JSONL or HuggingFace + if config.dataset.local: + backend = "jsonl" + else: + backend = "huggingface" + + # Get mode + mode = config.mode + + # Look up dataset class + key = (mode, backend) + if key not in DatasetFactory.DATASETS: + if backend == "huggingface": + raise ImportError( + "HuggingFace dataset support requires the optional 'datasets' dependency. " + "Install project dependencies before using dataset.local=false." + ) + raise ValueError( + f"Unsupported dataset configuration: mode={mode.value}, backend={backend}. " + f"Supported: {list(DatasetFactory.DATASETS.keys())}" + ) + + dataset_class = DatasetFactory.DATASETS[key] + + # Instantiate dataset + if backend == "jsonl": + # JSONL datasets don't use split parameter + return dataset_class(config) + else: + # HuggingFace datasets use split parameter + return dataset_class(config, split=split) + + @staticmethod + def register_dataset(mode: TrainingModeEnum, backend: str, dataset_class): + """ + Register a custom dataset class. + + Args: + mode: Training mode (e.g., TrainingModeEnum.PRETRAIN) + backend: Backend name (e.g., "jsonl", "huggingface") + dataset_class: Dataset class to register + """ + DatasetFactory.DATASETS[(mode, backend)] = dataset_class + + @staticmethod + def list_available_datasets(): + """List all available dataset configurations.""" + configs = {} + for (mode, backend), dataset_class in DatasetFactory.DATASETS.items(): + key = f"{mode.value}_{backend}" + configs[key] = { + "mode": mode.value, + "backend": backend, + "class": dataset_class.__name__, + } + return configs diff --git a/code/TaoTrain/src/taoTrain/data/hf_base.py b/code/TaoTrain/src/taoTrain/data/hf_base.py new file mode 100644 index 0000000000000000000000000000000000000000..a98fc1c521f16383270b4aa712f56a98f7d76b45 --- /dev/null +++ b/code/TaoTrain/src/taoTrain/data/hf_base.py @@ -0,0 +1,82 @@ +"""Base class for HuggingFace-based datasets.""" + +from typing import Optional, Dict +import torch +from torch.utils.data import Dataset +from datasets import load_dataset +from transformers import AutoTokenizer +from taoTrain.config import TrainingConfig + + +class BaseHFDataset(Dataset): + """Base class for HuggingFace-based datasets.""" + + def __init__(self, config: TrainingConfig, split: str = "train"): + """ + Initialize dataset. + + Args: + config: Training configuration + split: Dataset split (train, validation, test) + """ + self.config = config + self.split = split + self.data = None + self.tokenizer = None + + # Load tokenizer + self._load_tokenizer() + + # Load and preprocess dataset + self._load_dataset() + self._preprocess() + + def _load_tokenizer(self): + """Load tokenizer from HuggingFace.""" + # Default to GPT-2 tokenizer if not specified + tokenizer_name = getattr(self.config, 'tokenizer_name', 'gpt2') + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + + # Set pad token if not set + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + def _load_dataset(self): + """Load dataset from HuggingFace.""" + dataset_config = self.config.dataset + + try: + # Load dataset + if dataset_config.config: + self.data = load_dataset( + dataset_config.dataset_name, + dataset_config.config, + split=self.split, + cache_dir=dataset_config.cache_dir, + trust_remote_code=True, + ) + else: + self.data = load_dataset( + dataset_config.dataset_name, + split=self.split, + cache_dir=dataset_config.cache_dir, + trust_remote_code=True, + ) + except Exception as e: + raise ValueError(f"Failed to load dataset {dataset_config.dataset_name}: {e}") + + # Limit samples if specified + if dataset_config.max_samples: + self.data = self.data.select(range(min(dataset_config.max_samples, len(self.data)))) + + def _preprocess(self): + """Preprocess dataset (to be implemented by subclasses).""" + pass + + def __len__(self) -> int: + """Return dataset length.""" + return len(self.data) + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + """Get item (to be implemented by subclasses).""" + pass diff --git a/code/TaoTrain/src/taoTrain/data/hf_pretrain.py b/code/TaoTrain/src/taoTrain/data/hf_pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..20ef6b59dd1d02f30e43b8fccd01b598421ec9a9 --- /dev/null +++ b/code/TaoTrain/src/taoTrain/data/hf_pretrain.py @@ -0,0 +1,78 @@ +"""Pretrain dataset for HuggingFace datasets.""" + +from typing import Dict +import torch +from taoTrain.config import TrainingConfig +from taoTrain.data.hf_base import BaseHFDataset + + +class PretrainDataset(BaseHFDataset): + """Dataset for pretraining with raw text.""" + + def _preprocess(self): + """Tokenize text data.""" + dataset_config = self.config.dataset + text_column = dataset_config.text_column + + def tokenize_function(examples): + # Concatenate all texts + concatenated_examples = { + k: sum(examples[k], []) for k in examples.keys() + } + + total_length = len(concatenated_examples[text_column]) + # We'll use max_seq_length for training + total_length = (total_length // self.config.model.max_seq_length) * self.config.model.max_seq_length + + # Tokenize + tokenized = self.tokenizer( + concatenated_examples[text_column], + truncation=False, # We'll chunk below + return_special_tokens_mask=False, + ) + + # Chunk tokenized text + result = { + "input_ids": [], + "attention_mask": [], + } + + for i in range(0, total_length, self.config.model.max_seq_length): + result["input_ids"].append( + tokenized["input_ids"][i:i + self.config.model.max_seq_length] + ) + result["attention_mask"].append( + tokenized["attention_mask"][i:i + self.config.model.max_seq_length] + ) + + return result + + # Preprocess in batches + self.data = self.data.map( + tokenize_function, + batched=True, + batch_size=100, + remove_columns=self.data.column_names, + desc="Tokenizing...", + ) + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + """Get preprocessed sample.""" + item = self.data[idx] + + input_ids = torch.tensor(item["input_ids"], dtype=torch.long) + attention_mask = torch.tensor(item["attention_mask"], dtype=torch.long) + + # For pretrain, labels = input_ids shifted by 1 (next token prediction) + # Position i predicts token at position i+1 + labels = input_ids[1:].clone() + labels = torch.cat([labels, torch.tensor([-100])], dim=0) + + # Mark padding tokens as -100 to ignore in loss computation + labels[attention_mask == 0] = -100 + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels, + } diff --git a/code/TaoTrain/src/taoTrain/data/hf_rl.py b/code/TaoTrain/src/taoTrain/data/hf_rl.py new file mode 100644 index 0000000000000000000000000000000000000000..6aa4617d0bb892f9dcea73b3143c7b07dfc4f88d --- /dev/null +++ b/code/TaoTrain/src/taoTrain/data/hf_rl.py @@ -0,0 +1,73 @@ +"""RL dataset for HuggingFace datasets.""" + +from typing import Dict +import torch +from taoTrain.config import TrainingConfig +from taoTrain.data.hf_base import BaseHFDataset + + +class RLDataset(BaseHFDataset): + """Dataset for RL training with prompts.""" + + def _preprocess(self): + """Prepare prompts for RL.""" + dataset_config = self.config.dataset + + # For RL, we typically just need prompts (no responses) + # The responses will be generated by the model during training + + if dataset_config.prompt_column: + # Use existing prompt column + def extract_prompt(example): + return {"prompt": example[dataset_config.prompt_column]} + + self.data = self.data.map( + extract_prompt, + remove_columns=self.data.column_names, + desc="Extracting prompts...", + ) + else: + # For general datasets, just use the text column as prompt + def identity(example): + return {"prompt": example.get(dataset_config.text_column, "")} + + self.data = self.data.map( + identity, + remove_columns=self.data.column_names, + desc="Preparing prompts...", + ) + + # Tokenize prompts + def tokenize_function(examples): + tokenized = self.tokenizer( + examples["prompt"], + truncation=True, + max_length=self.config.model.max_seq_length, + padding="max_length", + return_attention_mask=True, + ) + return tokenized + + self.data = self.data.map( + tokenize_function, + batched=True, + batch_size=100, + remove_columns=self.data.column_names, + desc="Tokenizing prompts...", + ) + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + """Get preprocessed prompt.""" + item = self.data[idx] + + input_ids = torch.tensor(item["input_ids"], dtype=torch.long) + attention_mask = torch.tensor(item["attention_mask"], dtype=torch.long) + + # For RL, we don't have labels yet + # They're generated during training + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + # "labels" will be None or set by the trainer + } diff --git a/code/TaoTrain/src/taoTrain/data/hf_sft.py b/code/TaoTrain/src/taoTrain/data/hf_sft.py new file mode 100644 index 0000000000000000000000000000000000000000..0e95480d1fa16a7b76611bac36877a420d753254 --- /dev/null +++ b/code/TaoTrain/src/taoTrain/data/hf_sft.py @@ -0,0 +1,81 @@ +"""SFT dataset for HuggingFace datasets.""" + +from typing import Dict +import torch +from taoTrain.config import TrainingConfig +from taoTrain.data.hf_base import BaseHFDataset + + +class SFTDataset(BaseHFDataset): + """Dataset for supervised fine-tuning with instruction-response pairs.""" + + def _preprocess(self): + """Process instruction-response pairs.""" + dataset_config = self.config.dataset + + def format_example(example): + """Format instruction and response.""" + instruction = example.get(dataset_config.instruction_column, "") + response = example.get(dataset_config.response_column, "") + + if dataset_config.instruction_template: + # Use custom template + text = dataset_config.instruction_template.format( + instruction=instruction, + response=response + ) + else: + # Default template + text = f"{instruction}\n{response}" + + return {"text": text} + + # Format examples + self.data = self.data.map( + format_example, + remove_columns=[ + col for col in self.data.column_names + if col not in ["text"] + ] if "text" not in self.data.column_names else [], + desc="Formatting examples...", + ) + + # Tokenize + def tokenize_function(examples): + tokenized = self.tokenizer( + examples["text"], + truncation=True, + max_length=self.config.model.max_seq_length, + padding="max_length", + return_attention_mask=True, + ) + return tokenized + + self.data = self.data.map( + tokenize_function, + batched=True, + batch_size=100, + remove_columns=self.data.column_names, + desc="Tokenizing...", + ) + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + """Get preprocessed sample.""" + item = self.data[idx] + + input_ids = torch.tensor(item["input_ids"], dtype=torch.long) + attention_mask = torch.tensor(item["attention_mask"], dtype=torch.long) + + # For SFT, labels = input_ids shifted by 1 (next token prediction) + # Position i predicts token at position i+1 + labels = input_ids[1:].clone() + labels = torch.cat([labels, torch.tensor([-100])], dim=0) + + # Mark padding tokens as -100 to ignore in loss computation + labels[attention_mask == 0] = -100 + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels, + } diff --git a/code/TaoTrain/src/taoTrain/data/jsonl_base.py b/code/TaoTrain/src/taoTrain/data/jsonl_base.py new file mode 100644 index 0000000000000000000000000000000000000000..2c246bbcd11164b4b1da63073da8d0855ad7a70a --- /dev/null +++ b/code/TaoTrain/src/taoTrain/data/jsonl_base.py @@ -0,0 +1,220 @@ +"""Base class for local JSONL-based datasets (async-only).""" + +import json +from typing import Optional, Dict, Any +import torch +from torch.utils.data import Dataset +from taoTrain.config import TrainingConfig +from taoTrain.data.chunk_manager import ChunkManager +from taoTrain.data.tokenizer import SentencePieceTokenizerWrapper + + +class BaseJSONLDataset(Dataset): + """ + Base class for local JSONL-based datasets with async-only streaming. + + Designed for use with AsyncBatchIterator and TokenizationQueue. + All data loading and preprocessing happens asynchronously in background threads. + """ + + def __init__(self, config: TrainingConfig, split: str = "train"): + """ + Initialize JSONL dataset with chunked loading. + + Args: + config: Training configuration + split: Dataset split (train, validation, test) - not used for JSONL but kept for compatibility + + Note: + Requires AsyncBatchIterator and TokenizationQueue for data loading. + See taoTrain/data/async_loader.py for usage. + """ + self.config = config + self.split = split + self.tokenizer = None + + # Initialize chunk manager for streaming + dataset_config = self.config.dataset + jsonl_path = dataset_config.jsonl_path + + if not jsonl_path: + raise ValueError("jsonl_path must be provided for local JSONL datasets") + + # Create chunk manager + enable_streaming = dataset_config.enable_streaming + chunk_size_gb = dataset_config.chunk_size_gb + samples_per_chunk = dataset_config.samples_per_chunk + enable_metadata_cache = dataset_config.enable_chunk_metadata_cache + chunk_cache_dir = dataset_config.chunk_cache_dir + max_samples = dataset_config.max_samples + + if enable_streaming: + self.chunk_manager = ChunkManager( + jsonl_path, + chunk_size_gb=chunk_size_gb, + samples_per_chunk=samples_per_chunk, + enable_metadata_cache=enable_metadata_cache, + chunk_cache_dir=chunk_cache_dir, + max_samples=max_samples + ) + print(f"✓ {self.chunk_manager}") + else: + self.chunk_manager = None + + # Current chunk data + self._current_chunk_num = None + self._current_chunk_data = None # {"text": [...]} or preprocessed data + self._text_field = dataset_config.text_field + + # Load tokenizer + print("✓ Loading tokenizer...") + self._load_tokenizer() + + print("✓ Dataset initialization complete (async mode - chunks loaded on-demand).") + + def _load_tokenizer(self): + """Load tokenizer (from local SentencePiece or HuggingFace).""" + dataset_config = self.config.dataset + + # Check if tokenizer_path is specified + if dataset_config.tokenizer_path: + tokenizer_type = dataset_config.tokenizer_type + + # Auto-detect tokenizer type based on file extension + if tokenizer_type is None: + if dataset_config.tokenizer_path.endswith('.model'): + tokenizer_type = 'sentencepiece' + else: + tokenizer_type = 'huggingface' + + if tokenizer_type == 'sentencepiece': + # Load SentencePiece tokenizer + try: + import sentencepiece as spm + sp = spm.SentencePieceProcessor() + sp.Load(dataset_config.tokenizer_path) + # Wrap SentencePiece in a compatible interface + self.tokenizer = SentencePieceTokenizerWrapper(sp) + except ImportError: + raise ImportError("SentencePiece not installed. Install with: pip install sentencepiece") + except Exception as e: + raise ValueError(f"Failed to load SentencePiece tokenizer from {dataset_config.tokenizer_path}: {e}") + else: + # Load HuggingFace tokenizer from path + try: + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained(dataset_config.tokenizer_path) + except ImportError as e: + raise ImportError("HuggingFace tokenizers require the optional 'transformers' dependency") from e + except Exception as e: + raise ValueError(f"Failed to load HuggingFace tokenizer from {dataset_config.tokenizer_path}: {e}") + else: + # Default to GPT-2 tokenizer + try: + from transformers import AutoTokenizer + except ImportError as e: + raise ImportError("Default GPT-2 tokenizer requires the optional 'transformers' dependency") from e + tokenizer_name = getattr(self.config, 'tokenizer_name', 'gpt2') + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + + # Set pad token if not set (for HuggingFace tokenizers) + if hasattr(self.tokenizer, 'pad_token') and self.tokenizer.pad_token is None: + if hasattr(self.tokenizer, 'eos_token'): + self.tokenizer.pad_token = self.tokenizer.eos_token + + def _load_chunk(self, chunk_num: int): + """ + Load a specific chunk from JSONL file. + + Args: + chunk_num: Chunk number to load (0-indexed) + """ + if not self.chunk_manager: + return + + if chunk_num == self._current_chunk_num and self._current_chunk_data is not None: + # Already loaded + return + + # Read chunk + chunk_examples = self.chunk_manager.read_chunk(chunk_num) + + # Convert to text data + texts = [] + for obj in chunk_examples: + if self._text_field in obj: + texts.append(obj[self._text_field]) + + self._current_chunk_data = {"text": texts} + self._current_chunk_num = chunk_num + + # Preprocess chunk (tokenization happens in background via AsyncBatchIterator) + self._preprocess_chunk() + + def _get_chunk_for_idx(self, idx: int) -> int: + """ + Determine which chunk contains the given global index. + + Args: + idx: Global index + + Returns: + Chunk number (0-indexed) + """ + if not self.chunk_manager: + return 0 + + current_line = 0 + for chunk_num, (start_line, end_line) in enumerate(self.chunk_manager.chunk_line_ranges): + if idx < (end_line - start_line): + return chunk_num + idx -= (end_line - start_line) + + # Shouldn't reach here + return 0 + + def _get_local_idx_in_chunk(self, global_idx: int) -> int: + """ + Convert global index to local index within the chunk. + + Args: + global_idx: Global index + + Returns: + Local index within the chunk + """ + if not self.chunk_manager: + return global_idx + + current_line = 0 + for chunk_num, (start_line, end_line) in enumerate(self.chunk_manager.chunk_line_ranges): + chunk_size = end_line - start_line + if global_idx < chunk_size: + return global_idx + global_idx -= chunk_size + + return 0 + + def _preprocess(self): + """Preprocess dataset (to be implemented by subclasses).""" + pass + + def _preprocess_chunk(self): + """ + Preprocess current chunk (to be implemented by subclasses). + + This is called after a chunk is loaded by AsyncBatchIterator. + """ + pass + + def __len__(self) -> int: + """Return dataset length.""" + if self.chunk_manager: + return self.chunk_manager.effective_lines + elif self._current_chunk_data and "text" in self._current_chunk_data: + return len(self._current_chunk_data.get("text", [])) + return 0 + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + """Get item (to be implemented by subclasses).""" + pass diff --git a/code/TaoTrain/src/taoTrain/data/loaders.py b/code/TaoTrain/src/taoTrain/data/loaders.py new file mode 100644 index 0000000000000000000000000000000000000000..71fa000d9ebdbfa5769ad456a865adf78c85f6e7 --- /dev/null +++ b/code/TaoTrain/src/taoTrain/data/loaders.py @@ -0,0 +1,85 @@ +"""DataLoader utilities.""" + +from typing import Optional +import torch +from torch.utils.data import DataLoader, Dataset +from taoTrain.config import TrainingConfig + + +def get_dataloader( + dataset: Dataset, + config: TrainingConfig, + shuffle: bool = True, + drop_last: bool = True, +) -> DataLoader: + """ + Create a DataLoader from a dataset. + + **NOTE**: For JSONL-based datasets (PretrainJSONLDataset, SFTJSONLDataset, etc.), + this function is now deprecated in favor of AsyncBatchIterator for better performance. + AsyncBatchIterator enables tokenization to happen in parallel with training, + avoiding the startup bottleneck of tokenizing all data upfront. + + See: taoTrain/data/async_loader.py for the new async loading approach. + The trainer automatically uses AsyncBatchIterator for JSONL datasets. + + Args: + dataset: PyTorch Dataset instance + config: Training configuration + shuffle: Whether to shuffle data + drop_last: Whether to drop last incomplete batch + + Returns: + DataLoader instance + """ + + def collate_fn(batch): + """Collate function for padding sequences.""" + # Batch is a list of dicts + collated = {} + keys = batch[0].keys() + + for key in keys: + items = [item[key] for item in batch] + + # Stack tensors + if isinstance(items[0], torch.Tensor): + if key in ["input_ids", "labels"]: + # Pad sequences + max_len = max(item.shape[0] for item in items) + padded = [] + for item in items: + if len(item.shape) == 1: + # 1D tensor - pad it + pad_len = max_len - item.shape[0] + if pad_len > 0: + item = torch.nn.functional.pad(item, (0, pad_len), value=-100 if key == "labels" else 0) + padded.append(item) + collated[key] = torch.stack(padded) + elif key == "attention_mask": + # Also pad attention mask + max_len = max(item.shape[0] for item in items) + padded = [] + for item in items: + if len(item.shape) == 1: + pad_len = max_len - item.shape[0] + if pad_len > 0: + item = torch.nn.functional.pad(item, (0, pad_len), value=0) + padded.append(item) + collated[key] = torch.stack(padded) + else: + collated[key] = torch.stack(items) + else: + collated[key] = items + + return collated + + return DataLoader( + dataset, + batch_size=config.batch_size, + shuffle=shuffle, + drop_last=drop_last, + num_workers=config.num_workers, + pin_memory=config.pin_memory, + collate_fn=collate_fn, + ) diff --git a/code/TaoTrain/src/taoTrain/data/pretrain_jsonl.py b/code/TaoTrain/src/taoTrain/data/pretrain_jsonl.py new file mode 100644 index 0000000000000000000000000000000000000000..e6029012d4bb8c1a6e9fd95982e5575fd26a923b --- /dev/null +++ b/code/TaoTrain/src/taoTrain/data/pretrain_jsonl.py @@ -0,0 +1,65 @@ +"""Pretrain JSONL dataset with async-only streaming.""" + +from typing import Dict +import torch +from taoTrain.config import TrainingConfig +from taoTrain.data.jsonl_base import BaseJSONLDataset + + +class PretrainJSONLDataset(BaseJSONLDataset): + """Dataset for pretraining with local JSONL files with chunked loading.""" + + def _preprocess_chunk(self): + """Tokenize current chunk of text data.""" + if not self._current_chunk_data or "text" not in self._current_chunk_data: + return + + max_seq_length = self.config.model.max_seq_length + texts = self._current_chunk_data["text"] + + # Tokenize all texts in this chunk + all_token_ids = [] + all_attention_masks = [] + + for text in texts: + tokenized = self.tokenizer( + text, + truncation=True, + max_length=max_seq_length, + padding="max_length", + return_attention_mask=True, + ) + all_token_ids.append(tokenized["input_ids"]) + all_attention_masks.append(tokenized["attention_mask"]) + + self._current_chunk_data = { + "input_ids": all_token_ids, + "attention_mask": all_attention_masks, + } + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + """Get preprocessed sample, loading chunk if needed.""" + # Load appropriate chunk if using streaming + if self.chunk_manager: + chunk_num = self._get_chunk_for_idx(idx) + if chunk_num != self._current_chunk_num: + self._load_chunk(chunk_num) + local_idx = self._get_local_idx_in_chunk(idx) + else: + local_idx = idx + + input_ids = torch.tensor(self._current_chunk_data["input_ids"][local_idx], dtype=torch.long) + attention_mask = torch.tensor(self._current_chunk_data["attention_mask"][local_idx], dtype=torch.long) + + # For pretrain, labels = input_ids shifted + labels = input_ids[1:].clone() + labels = torch.cat([labels, torch.tensor([-100])], dim=0) + + # Replace padding token labels with -100 to ignore in labels + labels[attention_mask == 0] = -100 + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels, + } diff --git a/code/TaoTrain/src/taoTrain/data/rl_jsonl.py b/code/TaoTrain/src/taoTrain/data/rl_jsonl.py new file mode 100644 index 0000000000000000000000000000000000000000..a6bbaf1b632f98232710211d7dfd484cfadbfe09 --- /dev/null +++ b/code/TaoTrain/src/taoTrain/data/rl_jsonl.py @@ -0,0 +1,58 @@ +"""RL JSONL dataset with async-only streaming.""" + +from typing import Dict +import torch +from taoTrain.config import TrainingConfig +from taoTrain.data.jsonl_base import BaseJSONLDataset + + +class RLJSONLDataset(BaseJSONLDataset): + """Dataset for RL training with local JSONL files with chunked loading.""" + + def _preprocess_chunk(self): + """Prepare prompts for RL from current chunk.""" + if not self._current_chunk_data or "text" not in self._current_chunk_data: + return + + max_seq_length = self.config.model.max_seq_length + texts = self._current_chunk_data["text"] + + # Tokenize all prompts in this chunk + all_token_ids = [] + all_attention_masks = [] + + for text in texts: + tokenized = self.tokenizer( + text, + truncation=True, + max_length=max_seq_length, + padding="max_length", + return_attention_mask=True, + ) + all_token_ids.append(tokenized["input_ids"]) + all_attention_masks.append(tokenized["attention_mask"]) + + self._current_chunk_data = { + "input_ids": all_token_ids, + "attention_mask": all_attention_masks, + } + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + """Get preprocessed prompt, loading chunk if needed.""" + # Load appropriate chunk if using streaming + if self.chunk_manager: + chunk_num = self._get_chunk_for_idx(idx) + if chunk_num != self._current_chunk_num: + self._load_chunk(chunk_num) + local_idx = self._get_local_idx_in_chunk(idx) + else: + local_idx = idx + + input_ids = torch.tensor(self._current_chunk_data["input_ids"][local_idx], dtype=torch.long) + attention_mask = torch.tensor(self._current_chunk_data["attention_mask"][local_idx], dtype=torch.long) + + # For RL, no labels yet (generated during training) + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + } diff --git a/code/TaoTrain/src/taoTrain/data/sft_jsonl.py b/code/TaoTrain/src/taoTrain/data/sft_jsonl.py new file mode 100644 index 0000000000000000000000000000000000000000..0678a5f5f9fa0e3a0c820070830f49d3210e799a --- /dev/null +++ b/code/TaoTrain/src/taoTrain/data/sft_jsonl.py @@ -0,0 +1,156 @@ +"""SFT JSONL dataset with async-only streaming and response-masking.""" + +from typing import Dict +import torch +from taoTrain.config import TrainingConfig +from taoTrain.data.jsonl_base import BaseJSONLDataset +from taoTrain.data.sft_utils import ( + parse_sft_record, + build_sft_sequence_tokens, + build_response_only_next_token_labels, +) + + +class SFTJSONLDataset(BaseJSONLDataset): + """ + Dataset for supervised fine-tuning with local JSONL files with chunked loading. + + Supports both single-turn and multi-turn SFT data: + - Single-turn: {"input": "...", "output": "..."} + - Multi-turn: {"turns": [{"user": "...", "assistant": "..."}, ...]} + + With response-only loss masking: only trains on assistant/response tokens. + """ + + def __init__(self, *args, **kwargs): + """Initialize dataset.""" + super().__init__(*args, **kwargs) + # Store full records for parsing (not just text field) + self._current_chunk_records = None + + # Get SFT-specific config + self.sft_config = self.config if hasattr(self.config, 'mode') else None + self.user_token = getattr(self.sft_config, 'user_token', '') if self.sft_config else '' + self.assistant_token = getattr(self.sft_config, 'assistant_token', '') if self.sft_config else '' + self.response_loss_only = getattr(self.sft_config, 'response_loss_only', True) if self.sft_config else True + + def _load_chunk(self, chunk_num: int): + """ + Load a specific chunk from JSONL file, preserving full records for SFT parsing. + + Args: + chunk_num: Chunk number to load (0-indexed) + """ + if not self.chunk_manager: + return + + if chunk_num == self._current_chunk_num and self._current_chunk_data is not None: + # Already loaded + return + + # Read chunk - get full record objects + chunk_examples = self.chunk_manager.read_chunk(chunk_num) + + # Store full records for SFT parsing (not just text field) + self._current_chunk_records = chunk_examples + + # Initialize data structures + self._current_chunk_data = { + "input_ids": [], + "attention_mask": [], + "mask": [], + } + self._current_chunk_num = chunk_num + + # Preprocess this chunk (tokenize and mask) + self._preprocess_chunk() + + def _preprocess_chunk(self): + """ + Process SFT records from current chunk into tokenized sequences with masking. + + Parses each record (single-turn or multi-turn) and generates: + - Token sequences with role markers + - Masking info (0=ignore, 1=train) + - Labels with -100 for ignored tokens + """ + if not self._current_chunk_records: + return + + max_seq_length = self.config.model.max_seq_length + + all_input_ids = [] + all_attention_masks = [] + all_masks = [] + + for record in self._current_chunk_records: + try: + # Parse record into (user, assistant) turns + turns, is_multi_turn = parse_sft_record(record, self.config) + + if not turns: + # Fallback: try to use "text" field if present + if "text" in record: + turns = [(record["text"], "")] + else: + continue # Skip invalid records + + # Build token sequence with role tokens and masking + input_ids, attention_mask, mask = build_sft_sequence_tokens( + turns=turns, + tokenizer=self.tokenizer, + user_token=self.user_token, + assistant_token=self.assistant_token, + max_seq_length=max_seq_length, + ) + + all_input_ids.append(input_ids) + all_attention_masks.append(attention_mask) + all_masks.append(mask) + + except Exception as e: + # Log and skip problematic records + print(f"Warning: Failed to process SFT record: {e}") + continue + + # Update chunk data with tokenized sequences and masks + self._current_chunk_data = { + "input_ids": all_input_ids, + "attention_mask": all_attention_masks, + "mask": all_masks, + } + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + """ + Get preprocessed sample with response-only loss masking. + + Args: + idx: Sample index + + Returns: + Dict with input_ids, attention_mask, and labels (with -100 for ignored tokens) + """ + # Load appropriate chunk if using streaming + if self.chunk_manager: + chunk_num = self._get_chunk_for_idx(idx) + if chunk_num != self._current_chunk_num: + self._load_chunk(chunk_num) + local_idx = self._get_local_idx_in_chunk(idx) + else: + local_idx = idx + + # Get tokenized data + input_ids = torch.tensor(self._current_chunk_data["input_ids"][local_idx], dtype=torch.long) + attention_mask = torch.tensor(self._current_chunk_data["attention_mask"][local_idx], dtype=torch.long) + mask = self._current_chunk_data["mask"][local_idx] + + labels = torch.tensor( + build_response_only_next_token_labels(input_ids.tolist(), mask), + dtype=torch.long, + ) + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels, + } diff --git a/code/TaoTrain/src/taoTrain/data/sft_utils.py b/code/TaoTrain/src/taoTrain/data/sft_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..76507132a7ad9bfd618fe1e8673c7ef7772812ca --- /dev/null +++ b/code/TaoTrain/src/taoTrain/data/sft_utils.py @@ -0,0 +1,161 @@ +"""SFT utility functions for parsing and masking.""" + +from typing import Dict, Any, List, Tuple +from taoTrain.config import TrainingConfig + + +def parse_sft_record(record: Dict[str, Any], config: TrainingConfig) -> Tuple[List[Tuple[str, str]], bool]: + """ + Parse JSONL record into list of (user, assistant) turns. + + Supports two formats: + 1. Single-turn: {"input": "...", "output": "..."} + 2. Multi-turn: {"turns": [{"user": "...", "assistant": "..."}, ...]} + + Args: + record: JSONL record (dict) + config: Training configuration + + Returns: + (turns_list, is_multi_turn) where: + - turns_list: List of (user_text, assistant_text) tuples + - is_multi_turn: Whether this is a multi-turn record + """ + # Check for multi-turn format + if "turns" in record: + turns = [] + for turn in record["turns"]: + if isinstance(turn, dict) and "user" in turn and "assistant" in turn: + turns.append((turn["user"], turn["assistant"])) + if turns: + return turns, True + + # Check for single-turn format with input/output fields + if "input" in record and "output" in record: + return [(record["input"], record["output"])], False + + # Fallback: check for instruction/response fields (from config) + dataset_config = config.dataset + instruction_col = dataset_config.instruction_column or "instruction" + response_col = dataset_config.response_column or "response" + + if instruction_col in record and response_col in record: + return [(record[instruction_col], record[response_col])], False + + # Fallback: assume pre-formatted "text" field (old format) + if "text" in record: + return [(record["text"], "")], False + + return [], False + + +def build_sft_sequence_tokens( + turns: List[Tuple[str, str]], + tokenizer, + user_token: str = "", + assistant_token: str = "", + max_seq_length: int = 1024, +) -> Tuple[List[int], List[int], List[int]]: + """ + Build token sequence for SFT with role tokens and generate masking info. + + Sequence format: + [user_token_id] user_tokens [assistant_token_id] assistant_tokens ... [eos_token_id] + + Mask values: + - 0 (ignore): user input regions and role tokens → loss=-100 + - 1 (train): assistant output regions → compute loss + + Args: + turns: List of (user_text, assistant_text) tuples + tokenizer: Tokenizer instance + user_token: Role token for user (e.g., "") + assistant_token: Role token for assistant (e.g., "") + max_seq_length: Maximum sequence length + + Returns: + (input_ids, attention_mask, mask) where: + - input_ids: Token IDs for the full sequence + - attention_mask: Attention mask (1 for real tokens, 0 for padding) + - mask: Loss mask (0=ignore, 1=train loss) + """ + input_ids = [] + mask = [] + + # Get token IDs for special tokens + user_token_ids = tokenizer(user_token, add_special_tokens=False)["input_ids"] + assistant_token_ids = tokenizer(assistant_token, add_special_tokens=False)["input_ids"] + + # Process each turn + for user_text, assistant_text in turns: + # User role marker + input_ids.extend(user_token_ids) + mask.extend([0] * len(user_token_ids)) # Mask role token + + # User message tokens + user_tokens = tokenizer(user_text, add_special_tokens=False)["input_ids"] + input_ids.extend(user_tokens) + mask.extend([0] * len(user_tokens)) # Mask user input + + # Assistant role marker + input_ids.extend(assistant_token_ids) + mask.extend([0] * len(assistant_token_ids)) # Mask role token + + # Assistant message tokens + assistant_tokens = tokenizer(assistant_text, add_special_tokens=False)["input_ids"] + input_ids.extend(assistant_tokens) + mask.extend([1] * len(assistant_tokens)) # Train on assistant output + + # Add EOS token if exists + if hasattr(tokenizer, 'eos_token_id') and tokenizer.eos_token_id is not None: + input_ids.append(tokenizer.eos_token_id) + mask.append(0) # Mask EOS token + + # Truncate if too long + if len(input_ids) > max_seq_length: + input_ids = input_ids[:max_seq_length] + mask = mask[:max_seq_length] + + # Pad to max_seq_length + padding_len = max_seq_length - len(input_ids) + if padding_len > 0: + input_ids.extend([tokenizer.pad_token_id or 0] * padding_len) + mask.extend([0] * padding_len) # Mask padding tokens + + # Create attention mask (1 for real tokens, 0 for padding) + attention_mask = [1 if i < len(input_ids) - padding_len else 0 for i in range(len(input_ids))] + + return input_ids, attention_mask, mask + + +def apply_response_masking(input_ids: List[int], mask: List[int]) -> List[int]: + """ + Apply response-only loss masking by converting mask values to label format. + + Args: + input_ids: Token IDs + mask: Mask array (0=ignore, 1=train) + + Returns: + labels: Where mask=0 tokens have label=-100 (ignore in loss), mask=1 tokens have label=input_id + """ + labels = input_ids.copy() + for i, m in enumerate(mask): + if m == 0: + labels[i] = -100 # CrossEntropyLoss will ignore this token + return labels + + +def build_response_only_next_token_labels(input_ids: List[int], mask: List[int]) -> List[int]: + """ + Build next-token labels for SFT response-only training. + + Position i predicts token i+1, so the loss mask must be applied to the target + token, not the current input token. This trains the first assistant token from + the assistant role marker and avoids training on masked EOS/padding targets. + """ + if len(input_ids) != len(mask): + raise ValueError(f"input_ids and mask must have the same length: {len(input_ids)} != {len(mask)}") + + labels = apply_response_masking(input_ids, mask) + return labels[1:] + [-100] diff --git a/code/TaoTrain/src/taoTrain/data/tokenization_queue.py b/code/TaoTrain/src/taoTrain/data/tokenization_queue.py new file mode 100644 index 0000000000000000000000000000000000000000..6631cd1ee57b4e44cdd43f75b960dc709733503a --- /dev/null +++ b/code/TaoTrain/src/taoTrain/data/tokenization_queue.py @@ -0,0 +1,410 @@ +"""Background tokenization queue for streaming large JSONL datasets.""" + +import queue +import threading +import time +from typing import Dict, List, Optional, Any, Callable +import torch + +from taoTrain.data.chunk_manager import ChunkManager + + +class TokenizationQueue: + """ + Background threads that continuously tokenize chunks and stores them in a queue. + + This allows tokenization to happen in parallel with training, avoiding the bottleneck + of tokenizing all data upfront before training starts. + + Supports multiple worker threads for faster throughput. Each thread greedily + grabs the next available chunk using an atomic counter. + + Attributes: + total_items: Total number of samples across all chunks + queue_size: Maximum number of chunks to buffer in memory + num_threads: Number of worker threads for tokenization + """ + + def __init__( + self, + chunk_manager: ChunkManager, + tokenizer: Any, + config: "TrainingConfig", # type: ignore + max_queue_size: int = 2, + shuffle_chunks: bool = True, + num_threads: int = 1, + ): + """ + Initialize tokenization queue with multithreading support. + + Args: + chunk_manager: ChunkManager instance loaded with chunks + tokenizer: Tokenizer instance (HuggingFace or SentencePiece wrapper) + config: Training configuration with model and dataset settings + max_queue_size: Maximum chunks to buffer in queue (memory constraint) + shuffle_chunks: Whether to shuffle chunk order at initialization + num_threads: Number of worker threads for tokenization (default: 1) + + Raises: + ValueError: If chunk_manager has no chunks or num_threads < 1 + """ + if chunk_manager.num_chunks == 0: + raise ValueError("ChunkManager must have at least one chunk") + if num_threads < 1: + raise ValueError(f"num_threads must be >= 1, got {num_threads}") + + self.chunk_manager = chunk_manager + self.tokenizer = tokenizer + self.config = config + self.max_queue_size = max_queue_size + self.shuffle_chunks = shuffle_chunks + self.num_threads = num_threads + + # Detect SFT mode: check for response_loss_only flag + self.is_sft_mode = hasattr(config, 'response_loss_only') and config.response_loss_only + + # Calculate total items across all chunks + self.total_items = chunk_manager.effective_lines + + # Thread-safe queue for tokenized chunks + self._queue: queue.Queue[Dict[str, List]] = queue.Queue(maxsize=max_queue_size) + + # Control signals + self._stop_event = threading.Event() + self._error_event = threading.Event() + self._error_messages: List[str] = [] + self._threads: List[threading.Thread] = [] + + # Thread-safe chunk distribution + self._next_chunk_idx = 0 + self._chunk_idx_lock = threading.Lock() + self._active_threads = 0 + self._active_threads_lock = threading.Lock() + + # Chunk ordering + self._chunk_order = list(range(chunk_manager.num_chunks)) + print(f"TokenizationQueue initialized with {chunk_manager.num_chunks} chunks, total {chunk_manager.effective_lines} samples") + print(f"Using {num_threads} tokenization worker thread{'s' if num_threads != 1 else ''}") + print(f"Max queue size: {max_queue_size} chunks (memory constraint)") + if self.shuffle_chunks: + import random + random.shuffle(self._chunk_order) + + def _get_next_chunk_idx(self) -> Optional[int]: + """ + Atomically get the next chunk index for processing. + + Returns: + Chunk index to process, or None if all chunks have been assigned + """ + with self._chunk_idx_lock: + if self._next_chunk_idx < len(self._chunk_order): + chunk_idx = self._chunk_order[self._next_chunk_idx] + self._next_chunk_idx += 1 + return chunk_idx + return None + + def start(self): + """Start the tokenization background worker threads.""" + if self._threads: + raise RuntimeError(f"Tokenization threads already started ({len(self._threads)} active)") + + # Create and start N worker threads + for thread_id in range(self.num_threads): + thread = threading.Thread(target=self._worker, args=(thread_id,), daemon=False) + self._threads.append(thread) + thread.start() + + def _worker(self, thread_id: int): + """ + Worker thread target: greedy chunk processing with thread-safe distribution. + + Args: + thread_id: Identifier for this worker thread + """ + with self._active_threads_lock: + self._active_threads += 1 + + try: + while True: + # Check for stop signal + if self._stop_event.is_set(): + break + + # Get next chunk to process (atomic operation) + chunk_num = self._get_next_chunk_idx() + if chunk_num is None: + # All chunks assigned + break + + # Load chunk + chunk_examples = self.chunk_manager.read_chunk(chunk_num) + + # Tokenize chunk based on mode + if self.is_sft_mode: + tokenized_chunk = self._tokenize_batch_sft(chunk_examples) + else: + # Extract texts for pretrain + text_field = self.config.dataset.text_field + texts = [obj.get(text_field, "") for obj in chunk_examples] + tokenized_chunk = self._tokenize_batch(texts) + + # Put in queue (blocks if queue is full) + self._queue.put(tokenized_chunk) + print(f"[Worker-{thread_id}] Processed chunk {chunk_num}, put {len(tokenized_chunk['input_ids'])} samples in queue") + except Exception as e: + error_msg = f"[Worker-{thread_id}] {str(e)}" + print(f"Worker-{thread_id} encountered an error: {error_msg}") + # Thread-safe append to error list + self._error_messages.append(error_msg) + self._error_event.set() + finally: + with self._active_threads_lock: + self._active_threads -= 1 + remaining = self._active_threads + print(f"[Worker-{thread_id}] Finished processing. Active threads remaining: {remaining}") + def _tokenize_batch(self, texts: List[str]) -> Dict[str, List]: + """ + Tokenize a batch of texts, join with EOS, and split into fixed-size sequences. + + This packs multiple documents into longer sequences separated by EOS tokens, + then splits the concatenated tokens into N fixed-size chunks of max_seq_length. + + Args: + texts: List of text strings + + Returns: + Dict with 'input_ids' and 'attention_mask' lists, where each element + is a fixed-size sequence of length max_seq_length + """ + max_seq_length = self.config.model.max_seq_length + + # Get EOS token ID + eos_token_id = self.tokenizer.eos_token_id + unk_token_id = self.tokenizer.unk_token_id + if eos_token_id is None: + raise ValueError("Tokenizer does not have an EOS token defined") + if unk_token_id is None: + raise ValueError("Tokenizer does not have an UNK token defined") + + # Tokenize all texts without truncation + all_token_ids = [] + + for i, text in enumerate(texts): + tokenized = self.tokenizer( + text, + truncation=False, + return_attention_mask=False, + ) + + # Remove UNK tokens from tokenized output (if any) + tokenized["input_ids"] = [tid for tid in tokenized["input_ids"] if tid != unk_token_id] + + all_token_ids.extend(tokenized["input_ids"]) + # Add EOS token between documents (except after the last one) + if i < len(texts) - 1: + all_token_ids.append(eos_token_id) + + # Split into N fixed-size sequences + sequences_input_ids = [] + sequences_attention_masks = [] + + for i in range(0, len(all_token_ids), max_seq_length): + seq = all_token_ids[i : i + max_seq_length] + + # Pad sequence if it's shorter than max_seq_length + if len(seq) < max_seq_length: + # Create attention mask before padding + attention_mask = [1] * len(seq) + [0] * (max_seq_length - len(seq)) + # Pad with 0 (assuming 0 is the pad token, or use tokenizer.pad_token_id) + pad_token_id = self.tokenizer.pad_token_id or 0 + seq = seq + [pad_token_id] * (max_seq_length - len(seq)) + else: + attention_mask = [1] * max_seq_length + + sequences_input_ids.append(seq) + sequences_attention_masks.append(attention_mask) + + return { + "input_ids": sequences_input_ids, + "attention_mask": sequences_attention_masks, + } + + def _tokenize_batch_sft(self, records: List[Dict[str, Any]]) -> Dict[str, List]: + """ + Tokenize a batch of SFT records with role tokens and response masking. + + Processes each record (single-turn or multi-turn) and generates sequences + with role markers and masking (0=ignore user, 1=train on assistant). + + Args: + records: List of JSONL record dicts with various SFT formats + + Returns: + Dict with 'input_ids', 'attention_mask', and 'mask' lists, where each + element is a fixed-size sequence of length max_seq_length with masking info + """ + # Import here to avoid circular imports + from taoTrain.data.sft_utils import parse_sft_record, build_sft_sequence_tokens + + max_seq_length = self.config.model.max_seq_length + user_token = getattr(self.config, 'user_token', '') + assistant_token = getattr(self.config, 'assistant_token', '') + + sequences_input_ids = [] + sequences_attention_masks = [] + sequences_masks = [] + + for record in records: + try: + # Parse SFT record (supports multiple formats) + turns, is_multi_turn = parse_sft_record(record, self.config) + + if not turns: + # Skip records that couldn't be parsed + continue + + # Build token sequence with role tokens and response masking + input_ids, attention_mask, mask = build_sft_sequence_tokens( + turns=turns, + tokenizer=self.tokenizer, + user_token=user_token, + assistant_token=assistant_token, + max_seq_length=max_seq_length, + ) + + sequences_input_ids.append(input_ids) + sequences_attention_masks.append(attention_mask) + sequences_masks.append(mask) + + except Exception as e: + # Log error but continue processing + print(f"Warning: Failed to tokenize SFT record: {e}") + continue + + return { + "input_ids": sequences_input_ids, + "attention_mask": sequences_attention_masks, + "mask": sequences_masks, + } + + def get_next_chunk(self, timeout: Optional[float] = None) -> Optional[Dict[str, List]]: + """ + Get the next tokenized chunk from the queue. + + This is a blocking call that waits for the next chunk to be tokenized. + Returns None if queue is closed or all chunks have been processed. + + CRITICAL: Always attempts to drain the queue first before returning None. + This prevents abandoning buffered chunks when threads finish. + + Args: + timeout: Timeout in seconds (None = wait indefinitely) + + Returns: + Dict with tokenized chunk, or None if queue is exhausted + + Raises: + RuntimeError: If an error occurred in any worker thread + """ + if self._error_event.is_set(): + error_summary = "; ".join(self._error_messages) if self._error_messages else "Unknown error" + raise RuntimeError(f"Tokenization thread error: {error_summary}") + + # PRIORITY: Try to get from queue first (may have buffered items) + try: + chunk = self._queue.get(timeout=timeout) + return chunk + except queue.Empty: + # Queue is empty - check if threads are still working + with self._active_threads_lock: + if self._active_threads == 0 and self._next_chunk_idx >= len(self._chunk_order): + # All chunks assigned AND no active threads = true exhaustion + return None + # Queue temporarily empty but threads still working - signal to wait + return None + + @property + def is_exhausted(self) -> bool: + """Return True only when all chunks are assigned and all workers are idle.""" + with self._active_threads_lock: + return self._active_threads == 0 and self._next_chunk_idx >= len(self._chunk_order) + + def shutdown(self, wait: bool = True): + """ + Shutdown the tokenization worker threads gracefully. + + Args: + wait: If True, wait for all threads to finish; otherwise return immediately + """ + if not self._threads: + return + + # Signal threads to stop + self._stop_event.set() + + # Drain queue to unblock threads if they're waiting to put + try: + while True: + self._queue.get_nowait() + except queue.Empty: + pass + + # Wait for all threads to finish + if wait: + for thread in self._threads: + thread.join(timeout=5.0) + if thread.is_alive(): + print(f"⚠ Tokenization thread {thread.name} did not terminate cleanly") + + # Clear thread list to allow fresh start in next epoch + self._threads.clear() + print("✓ TokenizationQueue shutdown complete, thread list cleared") + + def reset_for_next_epoch(self): + """ + Reset queue state for the next epoch. + + This allows the same TokenizationQueue to be reused across multiple epochs. + Resets the chunk index counter, reshuffles chunks (if enabled), and clears + any buffered items and error state. + + Called by AsyncBatchIterator at the start of epoch 2+. + """ + # Reset iteration counter + self._next_chunk_idx = 0 + + # Reshuffle chunk order if enabled + if self.shuffle_chunks: + import random + random.shuffle(self._chunk_order) + print(f"✓ Reshuffled chunk order for next epoch: {self._chunk_order}") + + # Drain any remaining items from queue + items_drained = 0 + try: + while True: + self._queue.get_nowait() + items_drained += 1 + except queue.Empty: + pass + + if items_drained > 0: + print(f"⚠ Drained {items_drained} items from queue before epoch reset") + + # Clear error state + self._error_event.clear() + self._error_messages.clear() + + # Clear threads list so new threads will be started in next epoch + self._threads.clear() + + print(f"✓ TokenizationQueue reset for next epoch. Ready to process {len(self._chunk_order)} chunks") + + def __len__(self) -> int: + """Return total number of samples.""" + return self.total_items + + def __del__(self): + """Cleanup on deletion.""" + self.shutdown(wait=False) diff --git a/code/TaoTrain/src/taoTrain/data/tokenizer.py b/code/TaoTrain/src/taoTrain/data/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..ed621b18a7fed445d0f07a065a1b438f7945e93f --- /dev/null +++ b/code/TaoTrain/src/taoTrain/data/tokenizer.py @@ -0,0 +1,118 @@ +"""SentencePiece tokenizer wrapper for HuggingFace compatibility.""" + +from typing import Optional, List, Union + + +class SentencePieceTokenizerWrapper: + """Wrapper to make SentencePiece tokenizer compatible with HuggingFace interface.""" + + def __init__(self, sp_processor): + """ + Initialize wrapper. + + Args: + sp_processor: sentencepiece.SentencePieceProcessor instance + """ + self.sp = sp_processor + self.vocab_size = self.sp.vocab_size() + self.pad_token_id = self.sp.pad_id() + self.eos_token_id = self.sp.eos_id() + self.bos_token_id = self.sp.bos_id() + self.unk_token_id = self.sp.unk_id() + + def __call__(self, text, **kwargs): + """ + Tokenize text. + + Args: + text: Input text or list of texts + **kwargs: Additional arguments (truncation, max_length, padding, return_attention_mask) + + Returns: + Dict with input_ids and attention_mask + """ + # Handle both single string and list of strings + is_single = isinstance(text, str) + texts = [text] if is_single else text + + max_length = kwargs.get('max_length', None) + padding = kwargs.get('padding', None) + truncation = kwargs.get('truncation', False) + return_attention_mask = kwargs.get('return_attention_mask', True) + + # Tokenize all texts + all_input_ids = [] + for t in texts: + tokens = self.sp.encode(t, out_type=int) + + # Truncate if needed + if truncation and max_length and len(tokens) > max_length: + tokens = tokens[:max_length] + + all_input_ids.append(tokens) + + # Padding + if padding or max_length: + target_length = max_length or max(len(ids) for ids in all_input_ids) if all_input_ids else 0 + padded_input_ids = [] + padded_attention_masks = [] + + for ids in all_input_ids: + pad_length = target_length - len(ids) + if pad_length > 0: + padded_ids = ids + [self.pad_token_id] * pad_length + else: + padded_ids = ids[:target_length] + + padded_input_ids.append(padded_ids) + attention_mask = [1] * len(ids) + [0] * (target_length - len(ids)) + padded_attention_masks.append(attention_mask) + + result = { + "input_ids": padded_input_ids if not is_single else padded_input_ids[0], + } + if return_attention_mask: + result["attention_mask"] = padded_attention_masks if not is_single else padded_attention_masks[0] + else: + result = { + "input_ids": all_input_ids[0] if is_single else all_input_ids, + } + if return_attention_mask: + attention_masks = [[1] * len(ids) for ids in all_input_ids] + result["attention_mask"] = attention_masks[0] if is_single else attention_masks + + return result + + def encode(self, text, return_tensors=None, **kwargs): + """Encode text to token IDs.""" + result = self(text, **kwargs) + input_ids = result["input_ids"] + + if return_tensors == "pt": + import torch + # Ensure input_ids is a 1D list of ints + if isinstance(input_ids[0], list): + input_ids = input_ids[0] + return torch.tensor(input_ids, dtype=torch.long).unsqueeze(0) + + return input_ids + + def encode_plus(self, text, **kwargs): + """Encode text with additional information (HuggingFace compatibility).""" + return self(text, **kwargs) + + def decode(self, token_ids, skip_special_tokens=False, **kwargs): + """Decode token IDs to text.""" + if hasattr(token_ids, 'tolist'): # Handle torch tensors + token_ids = token_ids.tolist() + + # Handle various input formats + if isinstance(token_ids, (list, tuple)): + if len(token_ids) > 0 and isinstance(token_ids[0], (list, tuple)): + token_ids = token_ids[0] + + # Ensure it's a list of ints + if not isinstance(token_ids, list): + token_ids = [int(t) for t in token_ids] + + return self.sp.decode(token_ids) diff --git a/code/TaoTrain/src/taoTrain/inference/__init__.py b/code/TaoTrain/src/taoTrain/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b6ac9f0da6faa20e018f4c8c1aa47c3602a7d380 --- /dev/null +++ b/code/TaoTrain/src/taoTrain/inference/__init__.py @@ -0,0 +1,5 @@ +"""Inference engines.""" + +from .inferencer import Inferencer + +__all__ = ["Inferencer"] diff --git a/code/TaoTrain/src/taoTrain/inference/inferencer.py b/code/TaoTrain/src/taoTrain/inference/inferencer.py new file mode 100644 index 0000000000000000000000000000000000000000..0d7512ab3d0ba3ec39f2825f6a43a82ceebd8466 --- /dev/null +++ b/code/TaoTrain/src/taoTrain/inference/inferencer.py @@ -0,0 +1,301 @@ +"""Inference engine for model generation.""" + +from pathlib import Path +from typing import Optional, Iterator, Any +import torch +from transformers import AutoTokenizer +from rich.console import Console +from rich.table import Table + +from taoTrain.core import BaseModel +from taoTrain.config import ModelConfig + + +class Inferencer: + """Inference engine for text generation.""" + + def __init__( + self, + model: BaseModel, + tokenizer: Any, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + """ + Initialize inferencer. + + Args: + model: Trained model + tokenizer: Tokenizer instance (HuggingFace or SentencePiece wrapped) + device: Device for inference + dtype: Data type for inference + """ + self.model = model + self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.dtype = dtype or torch.float32 + self.tokenizer = tokenizer + + # Move model to device and set eval mode + self.model = self.model.to(self.device) + self.model.eval() + + # Set pad token if needed (for HuggingFace tokenizers) + if hasattr(self.tokenizer, 'pad_token') and self.tokenizer.pad_token is None: + if hasattr(self.tokenizer, 'eos_token'): + self.tokenizer.pad_token = self.tokenizer.eos_token + + @staticmethod + def _load_tokenizer(tokenizer_path: str | Path) -> Any: + """ + Load tokenizer from path (SentencePiece or HuggingFace). + + Args: + tokenizer_path: Path to tokenizer file or HuggingFace model name + + Returns: + Tokenizer instance + + Raises: + ValueError: If tokenizer cannot be loaded + """ + tokenizer_path = str(tokenizer_path) + + # Auto-detect tokenizer type based on file extension + if tokenizer_path.endswith('.model'): + # Load SentencePiece tokenizer + try: + import sentencepiece as spm + sp = spm.SentencePieceProcessor() + sp.Load(tokenizer_path) + # Wrap SentencePiece in a compatible interface + from taoTrain.data import SentencePieceTokenizerWrapper + return SentencePieceTokenizerWrapper(sp) + except ImportError: + raise ImportError("SentencePiece not installed. Install with: pip install sentencepiece") + except Exception as e: + raise ValueError(f"Failed to load SentencePiece tokenizer from {tokenizer_path}: {e}") + else: + # Load HuggingFace tokenizer + try: + return AutoTokenizer.from_pretrained(tokenizer_path) + except Exception as e: + raise ValueError(f"Failed to load HuggingFace tokenizer from {tokenizer_path}: {e}") + + @staticmethod + def _print_tokenizer_info(tokenizer: Any, tokenizer_path: str) -> None: + """Print tokenizer information.""" + console = Console() + table = Table(title="Tokenizer Information") + table.add_column("Property", style="cyan") + table.add_column("Value", style="green") + + table.add_row("Type", "SentencePiece" if tokenizer_path.endswith('.model') else "HuggingFace") + table.add_row("Path", str(tokenizer_path)) + + if hasattr(tokenizer, 'vocab_size'): + table.add_row("Vocab Size", str(tokenizer.vocab_size)) + + console.print(table) + + @staticmethod + def load_from_checkpoint( + checkpoint_path: str | Path, + tokenizer_path: Optional[str | Path] = None, + device: Optional[torch.device] = None, + ) -> "Inferencer": + """ + Load model from checkpoint and create inferencer. + + Handles both canonical and legacy checkpoint formats: + - Canonical: uses 'model_state' key + - Legacy: uses 'model_state_dict' key + + Args: + checkpoint_path: Path to checkpoint file + tokenizer_path: Optional path to tokenizer (overrides checkpoint's tokenizer_path) + device: Device for inference + + Returns: + Inferencer instance + + Raises: + ValueError: If no tokenizer path found in checkpoint or arguments + """ + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Load checkpoint using CheckpointManager for automatic format normalization + from taoTrain.checkpointing.checkpoint import CheckpointManager + checkpoint_manager = CheckpointManager(checkpoint_path.parent if isinstance(checkpoint_path, Path) else Path(checkpoint_path).parent) + checkpoint = checkpoint_manager.load(checkpoint_path, device=device) + + config_dict = checkpoint.get("config", {}) + + # Extract tokenizer path from checkpoint config or use provided override + if tokenizer_path is None: + # Try to get tokenizer_path from checkpoint config + dataset_config = config_dict.get("dataset", {}) + tokenizer_path = dataset_config.get("tokenizer_path") + + if not tokenizer_path: + raise ValueError( + f"No tokenizer path found in checkpoint config at {checkpoint_path}. " + "Please provide --tokenizer argument with path to tokenizer file." + ) + + # Load tokenizer + console = Console() + console.print("\n[bold cyan]Loading tokenizer...[/bold cyan]") + tokenizer = Inferencer._load_tokenizer(tokenizer_path) + Inferencer._print_tokenizer_info(tokenizer, str(tokenizer_path)) + + # Reconstruct model config + from taoTrain.config import ModelConfig + model_config = ModelConfig(**config_dict.get("model", {})) + + # Create and load model + # CheckpointManager.load() normalizes to 'model_state' key + from taoTrain.models import get_model + model = get_model(model_config, device=device) + model.load_state_dict(checkpoint["model_state"]) + + return Inferencer(model, tokenizer, device) + + def generate( + self, + prompt: str, + max_length: int = 256, + temperature: float = 0.7, + top_p: float = 0.95, + top_k: Optional[int] = None, + repetition_penalty: float = 1.0, + do_sample: bool = True, + stream: bool = False, + ) -> str | Iterator[str]: + """ + Generate text from a prompt. + + Args: + prompt: Input prompt + max_length: Maximum generation length + temperature: Temperature for sampling + top_p: Nucleus sampling parameter + top_k: Top-k sampling parameter + repetition_penalty: Penalty for repeated tokens (1.0 = no penalty, >1.0 = penalize) + do_sample: Whether to sample or use greedy decoding + stream: Whether to stream tokens + + Yields/Returns: + Generated text (or stream of tokens if stream=True) + """ + # Tokenize prompt + input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device) + prompt_length = input_ids.shape[1] + + # For streaming with full context decoding + generated_token_ids = [] # Accumulate all generated tokens + last_decoded_full = "" # Cache full decoded text from previous step + + with torch.no_grad(): + for step in range(max_length): + # Forward pass + outputs = self.model( + input_ids=input_ids, + attention_mask=None, + labels=None, + ) + + logits = outputs["logits"] + + # Get logits for next token + next_logits = logits[:, -1, :] / temperature + + # Apply repetition penalty to previously generated tokens + if repetition_penalty != 1.0: + generated_ids = input_ids[0, prompt_length:] + unique_ids = torch.unique(generated_ids) + for token_id in unique_ids: + next_logits[0, token_id] /= repetition_penalty + + # Apply top-k and top-p sampling + if top_k is not None: + indices_to_remove = next_logits < torch.topk(next_logits, top_k)[0][..., -1, None] + next_logits[indices_to_remove] = float('-inf') + + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(next_logits, descending=True) + probs = torch.softmax(sorted_logits, dim=-1) + cumsum_probs = torch.cumsum(probs, dim=-1) + + sorted_indices_to_remove = cumsum_probs > top_p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = False + + indices_to_remove = sorted_indices[sorted_indices_to_remove] + next_logits[:, indices_to_remove] = float('-inf') + + # Sample or greedy + probs = torch.softmax(next_logits, dim=-1) + + if do_sample: + next_token = torch.multinomial(probs, num_samples=1) + else: + next_token = torch.argmax(next_logits, dim=-1, keepdim=True) + + # Append to input + input_ids = torch.cat([input_ids, next_token], dim=-1) + + # Stream if requested (with full context decoding to preserve spaces) + if stream: + # Accumulate the generated token ID + generated_token_ids.append(next_token.item()) + # Decode entire accumulated sequence (tokenizer has full context) + full_decoded_text = self.tokenizer.decode(generated_token_ids) + # Extract only NEW text since last yield + new_text = full_decoded_text[len(last_decoded_full):] + if new_text: + yield new_text + last_decoded_full = full_decoded_text + + # Stop on EOS + if next_token.item() == self.tokenizer.eos_token_id: + break + + if not stream: + # Return full generated text + generated_ids = input_ids[0, prompt_length:] + return self.tokenizer.decode(generated_ids, skip_special_tokens=True) + + def count_tokens_generated( + self, + prompt: str, + max_length: int = 256, + ) -> torch.Tensor: + """ + Measure generation speed (tokens per second). + + Args: + prompt: Input prompt + max_length: Maximum generation length + + Returns: + Number of tokens generated + """ + import time + + start = time.time() + + # Generate (we'll just do one forward pass to measure) + input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device) + + with torch.no_grad(): + outputs = self.model( + input_ids=input_ids, + attention_mask=None, + labels=None, + ) + + elapsed = time.time() - start + tokens_per_sec = (input_ids.shape[1] + 1) / elapsed + + return tokens_per_sec diff --git a/code/TaoTrain/src/taoTrain/inference/tui.py b/code/TaoTrain/src/taoTrain/inference/tui.py new file mode 100644 index 0000000000000000000000000000000000000000..627a535fe4a89b7cc1c03537af4f74cb4b06ef9b --- /dev/null +++ b/code/TaoTrain/src/taoTrain/inference/tui.py @@ -0,0 +1,161 @@ +"""TUI (Terminal User Interface) for interactive chat.""" + +import sys +import time +from pathlib import Path +from typing import Optional +import click +from rich.console import Console +from rich.markdown import Markdown +from rich.panel import Panel +from rich.text import Text +from rich.table import Table +from textual.app import ComposeResult, RenderableType +from textual.containers import Container, Horizontal, Vertical +from textual.widgets import TextArea, Static, Button +from textual.binding import Binding + +from taoTrain.inference import Inferencer + + +class TokensPerSecDisplay(Static): + """Display tokens per second metric.""" + + DEFAULT_CSS = """ + TokensPerSecDisplay { + width: 100%; + height: 1; + background: $panel; + border: solid $accent; + } + """ + + def __init__(self, tps: float = 0.0): + """Initialize.""" + super().__init__() + self.tps = tps + + def render(self) -> RenderableType: + """Render TPS display.""" + text = f"Tokens/sec: {self.tps:.2f}" + return Text(text, style="bold cyan") + + def update_tps(self, tps: float): + """Update TPS value.""" + self.tps = tps + self.update() + + +class SimpleChat: + """Simple CLI-based chat interface (fallback for testing).""" + + def __init__(self, checkpoint_path: str | Path, tokenizer_path: Optional[str | Path] = None): + """Initialize chat.""" + self.checkpoint_path = Path(checkpoint_path) + self.tokenizer_path = tokenizer_path + + print("\nLoading model...") + self.inferencer = Inferencer.load_from_checkpoint( + self.checkpoint_path, + tokenizer_path=self.tokenizer_path, + ) + + # Print model info + console = Console() + info_table = Table(title="Model Information") + info_table.add_column("Property", style="cyan") + info_table.add_column("Value", style="green") + + info_table.add_row("Checkpoint", str(self.checkpoint_path)) + if self.tokenizer_path: + info_table.add_row("Tokenizer (override)", str(self.tokenizer_path)) + + console.print(info_table) + + def run(self): + """Run chat loop.""" + console = Console() + + console.print("\n[bold cyan]Chat Interface[/bold cyan]") + console.print("[dim]Type 'exit' or 'quit' to exit[/dim]\n") + + while True: + try: + # Get user input + prompt = input("You: ").strip() + + if prompt.lower() in ["exit", "quit"]: + console.print("\n[yellow]Goodbye![/yellow]") + break + + if not prompt: + continue + + # Generate response + console.print("\n[bold cyan]Assistant:[/bold cyan] ", end="") + + start_time = time.time() + token_count = 0 + + # Stream generation + for token in self.inferencer.generate( + prompt, + max_length=256, + temperature=0.7, + top_p=0.95, + repetition_penalty=10, + stream=True, + ): + console.print(token, end="", soft_wrap=True) + token_count += 1 + + elapsed = time.time() - start_time + tps = token_count / elapsed if elapsed > 0 else 0 + + console.print(f"\n\n[dim]({tps:.1f} tokens/sec, {token_count} tokens)[/dim]\n") + + except KeyboardInterrupt: + console.print("\n\n[yellow]Chat interrupted.[/yellow]") + break + except Exception as e: + console.print(f"\n[red]Error: {e}[/red]\n") + + +@click.command() +@click.option( + "--model", + type=click.Path(exists=True), + required=True, + help="Path to model checkpoint (.pt file)", +) +@click.option( + "--tokenizer", + type=click.Path(exists=True), + required=False, + default=None, + help="Path to tokenizer file (.model or HuggingFace path). If not provided, uses tokenizer_path from checkpoint config.", +) +def main(model: str, tokenizer: Optional[str]): + """ + Interactive TUI chat with a trained model. + + Example: + tui-chat --model checkpoints/best_model.pt + tui-chat --model checkpoints/best_model.pt --tokenizer path/to/tokenizer.model + """ + try: + chat = SimpleChat(model, tokenizer_path=tokenizer) + chat.run() + except FileNotFoundError: + click.echo(f"Error: Model file not found: {model}", err=True) + sys.exit(1) + except ValueError as e: + click.echo(f"Error: {e}", err=True) + sys.exit(1) + except Exception as e: + click.echo(f"Error: {e}", err=True) + sys.exit(1) + + +if __name__ == "__main__": + main() # type: ignore diff --git a/code/TaoTrain/src/taoTrain/logging/__init__.py b/code/TaoTrain/src/taoTrain/logging/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ecd8ac46dedef63c27b074e972ee1ce9425b3717 --- /dev/null +++ b/code/TaoTrain/src/taoTrain/logging/__init__.py @@ -0,0 +1,5 @@ +"""Logging integrations.""" + +from .aim_logger import AimLogger + +__all__ = ["AimLogger"] diff --git a/code/TaoTrain/src/taoTrain/logging/aim_logger.py b/code/TaoTrain/src/taoTrain/logging/aim_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..449c0b247c770fa612e3b9bc13d180a0c1820a46 --- /dev/null +++ b/code/TaoTrain/src/taoTrain/logging/aim_logger.py @@ -0,0 +1,153 @@ +"""AimStack logging integration.""" + +from pathlib import Path +from typing import Dict, Any, Optional +import subprocess +import json +from datetime import datetime + +try: + from aim import Run + HAS_AIM = True +except ImportError: + HAS_AIM = False + +from taoTrain.config import TrainingConfig + + +class AimLogger: + """AimStack logger for tracking training metrics and hyperparameters.""" + + def __init__(self, config: TrainingConfig): + """ + Initialize AimStack logger. + + Args: + config: Training configuration + """ + self.config = config + self.run: Optional[Run] = None + + if HAS_AIM: + # Initialize AimStack run + repo_path = Path(config.aim_repo) + repo_path.mkdir(parents=True, exist_ok=True) + + self.run = Run(repo=str(repo_path)) + + # Log hyperparameters + self._log_hyperparameters() + else: + print("Warning: AimStack not installed. Install with: pip install aim") + + def _log_hyperparameters(self): + """Log hyperparameters to AimStack.""" + if self.run is None: + return + + # Log model config + self.run["hparams/model"] = { + "architecture": self.config.model.architecture_type.value, + "vocab_size": self.config.model.vocab_size, + "hidden_dim": self.config.model.hidden_dim, + "num_layers": self.config.model.num_layers, + "num_heads": self.config.model.num_heads, + "dropout": self.config.model.dropout, + "max_seq_length": self.config.model.max_seq_length, + } + + # Log training config + self.run["hparams/training"] = { + "batch_size": self.config.batch_size, + "num_epochs": self.config.num_epochs, + "learning_rate": self.config.optimizer.learning_rate, + "weight_decay": self.config.optimizer.weight_decay, + "gradient_accumulation_steps": self.config.gradient_accumulation_steps, + "max_grad_norm": self.config.max_grad_norm, + "dtype": self.config.dtype.value, + "seed": self.config.seed, + } + + # Log optimizer and scheduler config + self.run["hparams/optimizer"] = { + "optimizer_type": self.config.optimizer.optimizer_type.value, + "learning_rate": self.config.optimizer.learning_rate, + "weight_decay": self.config.optimizer.weight_decay, + } + + self.run["hparams/scheduler"] = { + "scheduler_type": self.config.scheduler.scheduler_type.value, + "warmup_steps": self.config.scheduler.warmup_steps, + "warmup_ratio": self.config.scheduler.warmup_ratio, + } + + # Log dataset config + self.run["hparams/dataset"] = { + "dataset_name": self.config.dataset.dataset_name, + "split": self.config.dataset.split, + "max_samples": self.config.dataset.max_samples, + } + + # Log mode + self.run["hparams/mode"] = self.config.mode.value + + # Log git hash if available + try: + git_hash = subprocess.check_output( + ["git", "rev-parse", "HEAD"], + stderr=subprocess.DEVNULL + ).decode().strip() + self.run["hparams/git_hash"] = git_hash + except: + pass + + # Log timestamp + self.run["hparams/timestamp"] = datetime.now().isoformat() + + def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None): + """ + Log metrics to AimStack. + + Args: + metrics: Dict of metric names to values + step: Global step (optional, auto-increments if not provided) + """ + if self.run is None: + return + + step = metrics.pop("step", step) + + for metric_name, metric_value in metrics.items(): + # Flatten nested dicts + if isinstance(metric_value, dict): + for nested_key, nested_val in metric_value.items(): + self.run.track( + float(nested_val), + name=f"{metric_name}/{nested_key}", + step=step, + ) + else: + try: + self.run.track( + float(metric_value), + name=metric_name, + step=step, + ) + except (ValueError, TypeError): + # Skip non-numeric metrics + pass + + def log_text(self, name: str, value: str, step: Optional[int] = None): + """Log text content.""" + if self.run is None: + return + + # AimStack doesn't have direct text logging, use metadata + metadata = getattr(self.run, '_metadata', {}) + if isinstance(metadata, dict): + metadata[name] = value + + def finish(self): + """Finish the run.""" + if self.run: + self.run.close() diff --git a/code/TaoTrain/src/taoTrain/models/__init__.py b/code/TaoTrain/src/taoTrain/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5ba818b0be87729d44103df4588fa9095b84b860 --- /dev/null +++ b/code/TaoTrain/src/taoTrain/models/__init__.py @@ -0,0 +1,5 @@ +"""Model architectures and registry.""" + +from .registry import get_model, register_architecture + +__all__ = ["get_model", "register_architecture"] diff --git a/code/TaoTrain/src/taoTrain/models/embeddings.py b/code/TaoTrain/src/taoTrain/models/embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..d12998f38a28bf39e04c8932e59dc44bc65e3f1c --- /dev/null +++ b/code/TaoTrain/src/taoTrain/models/embeddings.py @@ -0,0 +1,51 @@ +""" +Low-Rank Factorized Embedding. + +Uses standard nn.Linear for projection (NOT ternary quantization). +Embeddings should use full precision for good token representations. +""" + +import torch +import torch.nn as nn + + +class FactorizedEmbedding(nn.Module): + """ + Low-Rank Factorized Embedding: vocab → d_embed_rank → d_model + + Uses standard Linear layers (no quantization) for full precision. + Reduces embedding parameters from vocab_size × d_model to: + vocab_size × d_embed_rank + d_embed_rank × d_model + """ + + def __init__(self, vocab_size, d_model, d_embed_rank=96): + super().__init__() + self.vocab_size = vocab_size + self.d_model = d_model + self.d_embed_rank = d_embed_rank + + # Embedding table: vocab → compressed rank + self.embed = nn.Embedding(vocab_size, d_embed_rank) + + # Projection: compressed → full (standard Linear) + self.proj = nn.Linear(d_embed_rank, d_model, bias=False) + + # Initialize with small weights for stable training + nn.init.normal_(self.embed.weight, mean=0.0, std=0.02) + nn.init.normal_(self.proj.weight, mean=0.0, std=0.02) + + def forward(self, input_ids): + """ + Args: + input_ids: [batch_size, seq_len] tensor of token IDs + + Returns: + embeddings: [batch_size, seq_len, d_model] + """ + x = self.embed(input_ids) # [B, S, d_embed_rank] + x = self.proj(x) # [B, S, d_model] + return x + + def get_num_params(self): + """Return total number of parameters.""" + return self.vocab_size * self.d_embed_rank + self.d_embed_rank * self.d_model diff --git a/code/TaoTrain/src/taoTrain/models/mla_components.py b/code/TaoTrain/src/taoTrain/models/mla_components.py new file mode 100644 index 0000000000000000000000000000000000000000..2157d77d169519e7c63190c33792f083740c586b --- /dev/null +++ b/code/TaoTrain/src/taoTrain/models/mla_components.py @@ -0,0 +1,370 @@ +""" +DeepSeek-style Multi-head Latent Attention (MLA) with RoPE. + +Key innovations: +1. KV compression to latent space (reduce KV memory) +2. Q stays in full dimension for expressive query space +3. RoPE positional embeddings on Q and K +4. Grouped Query Attention (GQA) for efficiency +5. Learnable head combination weights +6. Numerical stability via pre-norm and scaling +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + + +def _residual_rms_norm(x, enabled=False, target=1.0, eps=1e-6, cap=None): + if not enabled and cap is None: + return x + rms = x.float().square().mean(dim=-1, keepdim=True).add(eps).sqrt() + if enabled: + scale = target / rms + else: + cap_tensor = torch.tensor(float(cap), dtype=rms.dtype, device=rms.device) + scale = torch.minimum(torch.ones_like(rms), cap_tensor / rms) + return x * scale.to(dtype=x.dtype) + + +class RotaryEmbedding(nn.Module): + """Rotary position embeddings used in RoPE with optional YaRN extension. + + YaRN (Yet another RoPE eXtension) allows context length interpolation via + frequency scaling. When yarn_alpha != 1.0 or seq_len > max_seq_length, + frequencies are dynamically scaled to support longer sequences. + + Parameters: + dim: Embedding dimension (must be even) + rope_scale: Base RoPE scale factor (default: 40) + max_seq_length: Original trained sequence length (default: 1024) + yarn_alpha: YaRN interpolation factor (default: 1.0, no interpolation) + - values < 1.0: aggressive interpolation (faster context expansion) + - values > 1.0: conservative interpolation (safer) + """ + + def __init__(self, dim, rope_scale=40.0, max_seq_length=1024, yarn_alpha=1.0): + super().__init__() + assert dim % 2 == 0, "Dimension must be even for rotary embeddings" + self.dim = dim + self.rope_scale = rope_scale + self.max_seq_length = max_seq_length + self.yarn_alpha = yarn_alpha + + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def _apply_yarn_scaling(self, freqs, seq_len): + """Apply YaRN frequency scaling for context extension. + + Args: + freqs: [seq_len, dim] frequency tensor + seq_len: Current sequence length + + Returns: + Scaled freqs if yarn is enabled and seq_len > max_seq_length, else original freqs + """ + # Only apply scaling if sequence exceeds training length or yarn_alpha != 1.0 + if self.yarn_alpha == 1.0 and seq_len <= self.max_seq_length: + return freqs + + # YaRN scaling factor: interpolate frequency reduction + # scale_factor = (seq_len / max_seq_length) ** (1 / yarn_alpha) + # Scales down frequencies to fit longer context while maintaining position distinctions + scale_factor = (seq_len / self.max_seq_length) ** (1.0 / self.yarn_alpha) + freqs = freqs / scale_factor + return freqs + + def forward(self, seq_len, device): + """Generate rotary embeddings for sequence with optional YaRN scaling. + + Args: + seq_len: Current sequence length + device: Device to create embeddings on + + Returns: + [seq_len, 2*dim] rotary embeddings (duplicated freqs) + """ + t = torch.arange(seq_len, device=device).type_as(self.inv_freq) / self.rope_scale + freqs = torch.einsum("i,j->ij", t, self.inv_freq) # [seq_len, dim//2] + + # Apply YaRN frequency scaling if enabled + freqs = self._apply_yarn_scaling(freqs, seq_len) + + return torch.cat((freqs, freqs), dim=-1) # [seq_len, dim] + + +def rotate_half(x): + """Rotate half the hidden dims of the input.""" + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary(x, cos, sin): + """Apply rotary embeddings to input tensor. + + Args: + x: [B, n_heads, seq_len, head_dim] or similar + cos: [seq_len, head_dim] or [1, 1, seq_len, head_dim] + sin: [seq_len, head_dim] or [1, 1, seq_len, head_dim] + """ + # Ensure cos/sin have the right dimensions for broadcasting + if cos.dim() == 2: + cos = cos.unsqueeze(0).unsqueeze(0) + sin = sin.unsqueeze(0).unsqueeze(0) + + # Handle case where cos/sin may be shorter than x + cos = cos[..., :x.shape[-1]] + sin = sin[..., :x.shape[-1]] + + # Split x based on cos dimensions + x_rot = x[..., :cos.shape[-1]] + x_base = x[..., cos.shape[-1]:] + + # Apply rotation + x_rot = (x_rot * cos) + (rotate_half(x_rot) * sin) + + # Concatenate rotated and base parts + return torch.cat([x_rot, x_base], dim=-1) if x_base.shape[-1] > 0 else x_rot + + +class DeepSeekMLA(nn.Module): + """ + DeepSeek-style Multi-head Latent Attention (MLA). + + Architecture: + 1. Project input to Query: [B, seq_len, d_model] -> [B, seq_len, d_model] + 2. Compress to KV latent: [B, seq_len, d_model] -> [B, seq_len, d_latent_kv] + 3. Split into heads for attention + 4. Apply RoPE to Q and K + 5. Compute attention scores: (Q @ K^T) / sqrt(d_head) + 6. Apply softmax and combine with values + 7. Concatenate heads and project back to d_model + + Parameters: + d_model: Model dimension + d_latent_kv: Latent dimension for KV compression + n_heads: Number of attention heads + d_rope: Dimension for RoPE (usually == d_head_dim) + dropout: Dropout probability + gqa_groups: Grouped Query Attention groups (1 = standard MLA, >1 = GQA) + """ + + def __init__(self, d_model, d_latent_kv, n_heads, d_rope, dropout=0.1, gqa_groups=1, + rope_scale=40.0, max_seq_length=1024, yarn_alpha=1.0): + super().__init__() + self.d_model = d_model + self.d_latent_kv = d_latent_kv + self.n_heads = n_heads + self.d_rope = d_rope + self.gqa_groups = gqa_groups + + assert d_model % n_heads == 0, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})" + assert d_latent_kv % n_heads == 0, f"d_latent_kv ({d_latent_kv}) must be divisible by n_heads ({n_heads})" + + self.d_head_full = d_model // n_heads # Full head dimension for Q + self.d_head_latent = d_latent_kv // n_heads # Latent head dimension for K/V + + # Scaling factor for attention scores + self.scale = 1.0 / math.sqrt(self.d_head_latent) + + # Layer norm before attention for stability + self.norm = nn.LayerNorm(d_model) + + # Q projection: d_model -> d_model (full dimension) + self.q_proj = nn.Linear(d_model, d_model, bias=False) + + # K/V projections: d_model -> d_latent_kv (compressed) + self.k_proj = nn.Linear(d_model, d_latent_kv, bias=False) + self.v_proj = nn.Linear(d_model, d_latent_kv, bias=False) + + # RoPE for position encoding with YaRN support + self.rotary = RotaryEmbedding( + d_rope, + rope_scale=rope_scale, + max_seq_length=max_seq_length, + yarn_alpha=yarn_alpha + ) + + # Output projection: d_latent_kv -> d_model + self.out_proj = nn.Linear(d_latent_kv, d_model, bias=False) + + # Head combination weights (learnable scaling per head) + self.head_weights = nn.Parameter(torch.ones(n_heads)) + + # Dropout + self.attn_dropout = nn.Dropout(dropout) + self.proj_dropout = nn.Dropout(dropout) + + def forward(self, x, attention_mask=None): + """ + Args: + x: [B, seq_len, d_model] + attention_mask: [B, seq_len] (1 = keep, 0 = mask) or + [B, 1, seq_len, seq_len] (causal mask) + + Returns: + out: [B, seq_len, d_model] + """ + B, seq_len, _ = x.shape + device = x.device + + # Pre-norm + x_norm = self.norm(x) + + # Project to Q, K, V spaces + q = self.q_proj(x_norm) # [B, seq_len, d_model] + k = self.k_proj(x_norm) # [B, seq_len, d_latent_kv] + v = self.v_proj(x_norm) # [B, seq_len, d_latent_kv] + + # ──────────────────────────────────────────────────────────────────────── + # Reshape into multi-head format + # ──────────────────────────────────────────────────────────────────────── + # Q: [B, seq_len, d_model] -> [B, seq_len, n_heads, d_head_full] -> [B, n_heads, seq_len, d_head_full] + q = q.view(B, seq_len, self.n_heads, self.d_head_full).transpose(1, 2) + + # K: [B, seq_len, d_latent_kv] -> [B, seq_len, n_heads, d_head_latent] -> [B, n_heads, seq_len, d_head_latent] + k = k.view(B, seq_len, self.n_heads, self.d_head_latent).transpose(1, 2) + + # V: [B, seq_len, d_latent_kv] -> [B, seq_len, n_heads, d_head_latent] -> [B, n_heads, seq_len, d_head_latent] + v = v.view(B, seq_len, self.n_heads, self.d_head_latent).transpose(1, 2) + + # ──────────────────────────────────────────────────────────────────────── + # Apply RoPE to Q and K + # ──────────────────────────────────────────────────────────────────────── + if self.d_rope > 0: + # Generate RoPE embeddings: [seq_len, d_rope] + rotary_emb = self.rotary(seq_len, device) # [seq_len, d_rope] + cos = torch.cos(rotary_emb).unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, d_rope] + sin = torch.sin(rotary_emb).unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, d_rope] + + # Apply RoPE to Q (only on first d_rope dimensions) + q_rope = apply_rotary(q[..., :self.d_rope], cos, sin) # [B, n_heads, seq_len, d_rope] + q = torch.cat([q_rope, q[..., self.d_rope:]], dim=-1) # Combine with remaining dims + + # Apply RoPE to K (only on first d_rope dimensions) + k_rope = apply_rotary(k[..., :self.d_rope], cos, sin) # [B, n_heads, seq_len, d_rope] + k = torch.cat([k_rope, k[..., self.d_rope:]], dim=-1) # Combine with remaining dims + + # ──────────────────────────────────────────────────────────────────────── + # Compute attention using PyTorch 2.0+ fused scaled_dot_product_attention + # ──────────────────────────────────────────────────────────────────────── + # Only use first d_head_latent dimensions of Q for attention + # K and V are already d_head_latent dimension + q_for_attn = q[..., :self.d_head_latent] # [B, n_heads, seq_len, d_head_latent] + + # Convert attention mask to boolean format for scaled_dot_product_attention + # Input mask: 0 = mask (don't attend), 1 = keep (attend) + # Boolean mask: False = mask, True = attend + attn_mask_bool = None + if attention_mask is not None: + if attention_mask.dim() == 2: + # [B, seq_len] with {0, 1} -> [B, 1, 1, seq_len] with {False, True} + attn_mask_bool = attention_mask.bool().unsqueeze(1).unsqueeze(1) + else: + # Already 4D [B, 1, seq_len, seq_len], just convert to bool + attn_mask_bool = attention_mask.bool() + + # Get dropout probability (0.0 when not training) + dropout_p = self.attn_dropout.p if self.training else 0.0 + + if hasattr(F, "scaled_dot_product_attention"): + # Apply fused attention operation when available. + out_heads = F.scaled_dot_product_attention( + q_for_attn, k, v, + attn_mask=attn_mask_bool, + dropout_p=dropout_p, + scale=None + ) # [B, n_heads, seq_len, d_head_latent] + else: + scores = torch.matmul(q_for_attn, k.transpose(-2, -1)) * self.scale + if attn_mask_bool is not None: + scores = scores.masked_fill(~attn_mask_bool, torch.finfo(scores.dtype).min) + attn_weights = F.softmax(scores, dim=-1) + if dropout_p > 0.0: + attn_weights = F.dropout(attn_weights, p=dropout_p, training=True) + out_heads = torch.matmul(attn_weights, v) + + # ──────────────────────────────────────────────────────────────────────── + # Concatenate heads + # ──────────────────────────────────────────────────────────────────────── + # [B, seq_len, n_heads, d_head_latent] -> [B, seq_len, d_latent_kv] + out_concat = out_heads.transpose(1, 2).reshape(B, seq_len, self.d_latent_kv) + + # Project back to d_model + out = self.out_proj(out_concat) # [B, seq_len, d_model] + out = self.proj_dropout(out) + + return out + + +class AttentionBlock(nn.Module): + """ + Attention block with pre-norm residual connection and feed-forward network. + + Structure: + Input + ├─> Norm ─┬─> MLA ──┬─> Residual Add + │ └────────┘ + ├────────────────────────────────────> Norm ─┬─> SwiGLU FFN ──┬─> Residual Add + │ └───────┘ │ + └────────────────────────────────────────────────────────────> Output + """ + + def __init__(self, d_model, d_latent_kv, n_heads, d_rope, d_ff, dropout=0.1, gqa_groups=1, + rope_scale=40.0, max_seq_length=1024, yarn_alpha=1.0, + residual_rms_norm=False, residual_rms_target=1.0, residual_rms_cap=None, + residual_rms_eps=1e-6): + super().__init__() + self.residual_rms_norm = residual_rms_norm + self.residual_rms_target = residual_rms_target + self.residual_rms_cap = residual_rms_cap + self.residual_rms_eps = residual_rms_eps + self.mla = DeepSeekMLA(d_model, d_latent_kv, n_heads, d_rope, dropout, gqa_groups, + rope_scale=rope_scale, max_seq_length=max_seq_length, + yarn_alpha=yarn_alpha) + + # SwiGLU feed-forward network + self.ff_norm = nn.LayerNorm(d_model) + self.ff_gate = nn.Linear(d_model, d_ff, bias=False) + self.ff_value = nn.Linear(d_model, d_ff, bias=False) + self.ff_out = nn.Linear(d_ff, d_model, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, attention_mask=None): + """ + Args: + x: [B, seq_len, d_model] + attention_mask: [B, seq_len] or [B, 1, seq_len, seq_len] + + Returns: + out: [B, seq_len, d_model] + """ + # Attention with residual + attn_out = self.mla(x, attention_mask) + x = x + self.dropout(attn_out) + x = _residual_rms_norm( + x, + self.residual_rms_norm, + self.residual_rms_target, + self.residual_rms_eps, + self.residual_rms_cap, + ) + + # FFN with residual + ff_norm = self.ff_norm(x) + ff_gate = self.ff_gate(ff_norm) + ff_value = self.ff_value(ff_norm) + ff_out = ff_value * F.silu(ff_gate) # SwiGLU activation + ff_out = self.ff_out(ff_out) + x = x + self.dropout(ff_out) + x = _residual_rms_norm( + x, + self.residual_rms_norm, + self.residual_rms_target, + self.residual_rms_eps, + self.residual_rms_cap, + ) + + return x diff --git a/code/TaoTrain/src/taoTrain/models/registry.py b/code/TaoTrain/src/taoTrain/models/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..d5ffc8fcb7de3db731725d50d98c13504b05feb1 --- /dev/null +++ b/code/TaoTrain/src/taoTrain/models/registry.py @@ -0,0 +1,73 @@ +"""Model architecture registry and factory.""" + +from typing import Dict, Type, Optional +import torch +from taoTrain.core import BaseModel +from taoTrain.config import ModelConfig + + +# Global registry for model architectures +_ARCHITECTURE_REGISTRY: Dict[str, Type[BaseModel]] = {} + + +def register_architecture(name: str): + """Decorator to register a custom model architecture.""" + def decorator(cls: Type[BaseModel]): + if name in _ARCHITECTURE_REGISTRY: + raise ValueError(f"Architecture '{name}' is already registered") + _ARCHITECTURE_REGISTRY[name] = cls + return cls + return decorator + + +def get_registered_architectures() -> Dict[str, Type[BaseModel]]: + """Get all registered architectures.""" + return _ARCHITECTURE_REGISTRY.copy() + + +def get_model( + config: ModelConfig, + device: Optional[torch.device] = None, +) -> BaseModel: + """ + Create a model instance from config. + + Args: + config: ModelConfig instance + device: Device to create model on (defaults to CPU) + + Returns: + Model instance + """ + if device is None: + device = torch.device('cpu') + + # Handle both enum and string values + arch_type = config.architecture_type + if isinstance(arch_type, str): + arch_name = arch_type + else: + arch_name = arch_type.value + + if arch_name not in _ARCHITECTURE_REGISTRY: + raise ValueError( + f"Unknown architecture: {arch_name}. " + f"Available: {list(_ARCHITECTURE_REGISTRY.keys())}" + ) + + model_class = _ARCHITECTURE_REGISTRY[arch_name] + model = model_class(config).to(device) + + return model + + +def register_builtin_architectures(): + """Register all built-in architectures.""" + # Import here to register (avoid circular imports) + from . import transformer # noqa: F401 + from . import taonet # noqa: F401 + from . import taonet_ssm # noqa: F401 + + +# Auto-register built-in architectures when module is imported +register_builtin_architectures() diff --git a/code/TaoTrain/src/taoTrain/models/taonet.py b/code/TaoTrain/src/taoTrain/models/taonet.py new file mode 100644 index 0000000000000000000000000000000000000000..61c6eee39e9b6fba527d9ac446e49af548af0ff9 --- /dev/null +++ b/code/TaoTrain/src/taoTrain/models/taonet.py @@ -0,0 +1,248 @@ +""" +SimpleLLM - Pure Attention-based Language Model with DeepSeek MLA + RoPE. + +Architecture: +- Token Embedding → Attention Blocks → Output Head +- Attention Blocks: Multi-head Latent Attention with RoPE positional embeddings +- Feed-forward: SwiGLU gates +- No state-space models (SSM), pure transformer architecture +- Full BF16 precision (no quantization) +""" + +import math +from typing import Optional +import torch +import torch.nn as nn +import torch.nn.functional as F + +from taoTrain.core import BaseModel +from taoTrain.config import ModelConfig +from .registry import register_architecture +from .mla_components import AttentionBlock +from .embeddings import FactorizedEmbedding + + +@register_architecture("taonet") +class SimpleLLM(BaseModel): + """ + Pure attention-based language model with DeepSeek MLA + RoPE. + + Stateless architecture - no internal state management needed. + + Args: + config: ModelConfig with: + - vocab_size: Vocabulary size + - hidden_dim: Model dimension (d_model) + - hidden_dim_ff: Feed-forward dimension (default: 4 * hidden_dim) + - num_layers: Number of attention blocks (n_layers) + - num_heads: Number of attention heads (n_attn_heads) + - d_latent_kv: KV compression dimension (default: 3/4 * hidden_dim) + - d_rope: RoPE dimension per head (default: hidden_dim // num_heads) + - max_seq_length: Maximum sequence length + - dropout: Dropout rate + - gqa_groups: Grouped Query Attention groups (default: 1) + - use_factorized_embedding: Use low-rank embedding (default: False) + """ + + def __init__(self, config: ModelConfig): + super().__init__(config) + + # Parse config - use defaults if not specified + self.vocab_size = config.vocab_size + self.d_model = config.hidden_dim + self.n_layers = config.num_layers + self.n_heads = config.num_heads + self.dropout = config.dropout + + # Optional parameters with smart defaults + self.d_latent_kv = config.d_latent_kv if config.d_latent_kv is not None else int(self.d_model * 0.75) + self.d_rope = config.d_rope if config.d_rope is not None else (self.d_model // self.n_heads) + self.d_ff = config.hidden_dim_ff if config.hidden_dim_ff is not None else (self.d_model * 4) + self.gqa_groups = getattr(config, 'gqa_groups', 1) + self.use_factorized_embedding = getattr(config, 'use_factorized_embedding', False) + self.d_embed_rank = getattr(config, 'd_embed_rank', 96) + + # YaRN parameters for context length extension + self.rope_scale = getattr(config, 'rope_scale', 40.0) + self.yarn_enabled = getattr(config, 'yarn_enabled', False) + self.yarn_alpha = getattr(config, 'yarn_alpha', 1.0) + self.max_seq_length = config.max_seq_length + + # Validate dimensions + assert self.d_model % self.n_heads == 0, \ + f"hidden_dim ({self.d_model}) must be divisible by num_heads ({self.n_heads})" + assert self.d_latent_kv % self.n_heads == 0, \ + f"d_latent_kv ({self.d_latent_kv}) must be divisible by num_heads ({self.n_heads})" + + # Token embedding + if self.use_factorized_embedding: + self.token_embedding = FactorizedEmbedding( + self.vocab_size, + self.d_model, + self.d_embed_rank + ) + else: + self.token_embedding = nn.Embedding(self.vocab_size, self.d_model) + + # Embedding dropout + self.embedding_dropout = nn.Dropout(self.dropout) + + # Attention blocks with MLA + SwiGLU FFN + self.blocks = nn.ModuleList() + for _ in range(self.n_layers): + self.blocks.append( + AttentionBlock( + d_model=self.d_model, + d_latent_kv=self.d_latent_kv, + n_heads=self.n_heads, + d_rope=self.d_rope, + d_ff=int(self.d_ff), + dropout=self.dropout, + gqa_groups=self.gqa_groups, + rope_scale=self.rope_scale, + max_seq_length=self.max_seq_length, + yarn_alpha=self.yarn_alpha, + ) + ) + + # Final layer norm + self.final_norm = nn.LayerNorm(self.d_model) + + # Output projection to vocabulary + self.output_head = nn.Linear(self.d_model, self.vocab_size, bias=False) + + # Initialize weights + self.apply(self._init_weights) + + # Cache for causal mask + self.register_buffer("causal_mask_cache", None, persistent=False) + + self._print_architecture() + + def _init_weights(self, module): + """Initialize weights for stable training.""" + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=0.02) + + def _print_architecture(self): + """Print model architecture summary.""" + total_params = sum(p.numel() for p in self.parameters()) + trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) + + print(f"\n{'='*70}") + print("MODEL ARCHITECTURE - TAОNET (DeepSeek MLA + RoPE)") + print(f"{'='*70}") + print(f"Embedding:") + if self.use_factorized_embedding: + embed_rank_params = self.vocab_size * self.d_embed_rank + embed_proj_params = self.d_embed_rank * self.d_model + print(f" Type: Factorized (rank={self.d_embed_rank})") + print(f" Rank layer: {embed_rank_params/1e6:>8.2f}M") + print(f" Projection: {embed_proj_params/1e6:>8.2f}M") + else: + embed_params = self.vocab_size * self.d_model + print(f" Type: Standard") + print(f" Params: {embed_params/1e6:>8.2f}M") + + output_params = self.d_model * self.vocab_size + print(f"Output Head: {output_params/1e6:>8.2f}M") + print(f"Attention Blocks: {len(self.blocks):>10} layers × AttentionBlock") + print(f"{'─'*70}") + print(f"Total Parameters: {total_params/1e6:>8.2f}M (trainable: {trainable_params/1e6:.2f}M)") + print(f"{'─'*70}") + print(f"Configuration:") + print(f" Model dimension (d_model): {self.d_model}") + print(f" KV latent dimension (d_latent_kv): {self.d_latent_kv}") + print(f" Attention heads: {self.n_heads}") + print(f" Head dimension: {self.d_model // self.n_heads}") + print(f" RoPE dimension: {self.d_rope}") + print(f" Feed-forward dimension: {int(self.d_ff)}") + print(f" Number of layers: {self.n_layers}") + print(f" Max sequence length: {self.config.max_seq_length}") + print(f" Dropout: {self.dropout}") + print(f" GQA groups: {self.gqa_groups}") + print(f"{'='*70}\n") + + def _get_causal_mask(self, seq_len, device): + """Get or create causal mask for sequence.""" + if self.causal_mask_cache is None or self.causal_mask_cache.size(-1) < seq_len: + # [seq_len, seq_len] lower triangular matrix (1 = attend, 0 = mask) + mask = torch.tril(torch.ones(seq_len, seq_len, device=device, dtype=torch.bool)) + self.register_buffer("causal_mask_cache", mask, persistent=False) + return self.causal_mask_cache[:seq_len, :seq_len] + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + ) -> dict: + """ + Forward pass through the model. + + Args: + input_ids: [batch_size, seq_len] tensor of token IDs + attention_mask: [batch_size, seq_len] tensor where 1 = valid, 0 = padding + labels: [batch_size, seq_len] target token IDs for loss computation + + Returns: + Dictionary with: + - 'logits': [batch_size, seq_len, vocab_size] output logits + - 'loss': scalar loss (if labels provided, else None) + """ + batch_size, seq_len = input_ids.shape + device = input_ids.device + + # Get causal mask: [seq_len, seq_len] + causal_mask = self._get_causal_mask(seq_len, device) + + # Combine causal mask with attention mask if provided + if attention_mask is not None: + # attention_mask: [batch, seq_len] where 1 = valid, 0 = padding + # Expand to [batch, 1, 1, seq_len] + padding_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool() + # Combine with causal: [1, 1, seq_len, seq_len] * [batch, 1, 1, seq_len] + combined_mask = causal_mask.unsqueeze(0).unsqueeze(0) & padding_mask + # For MLA: convert to {0, 1} format + combined_mask = combined_mask.float() + else: + # Just causal mask + combined_mask = causal_mask.unsqueeze(0).unsqueeze(0).float() + + # Embed tokens: [batch_size, seq_len] -> [batch_size, seq_len, d_model] + x = self.token_embedding(input_ids) + x = self.embedding_dropout(x) + + # Pass through attention blocks + for block in self.blocks: + x = block(x, attention_mask=combined_mask) + + # Final layer norm + x = self.final_norm(x) + + # Output projection to vocabulary + logits = self.output_head(x) # [batch_size, seq_len, vocab_size] + + # Compute loss if labels are provided + loss = None + if labels is not None: + # Flatten for loss computation + logits_flat = logits.view(-1, logits.size(-1)) # (batch * seq_len, vocab_size) + labels_flat = labels.view(-1) + + # Only compute loss on valid targets (ignore -100 tokens for padding) + loss = F.cross_entropy( + logits_flat, + labels_flat, + reduction='mean', + ignore_index=-100 + ) + + return { + 'logits': logits, + 'loss': loss, + } diff --git a/code/TaoTrain/src/taoTrain/models/taonet_ssm.py b/code/TaoTrain/src/taoTrain/models/taonet_ssm.py new file mode 100644 index 0000000000000000000000000000000000000000..35dcbd560152eaab6b666dcbd44ce75383d332ef --- /dev/null +++ b/code/TaoTrain/src/taoTrain/models/taonet_ssm.py @@ -0,0 +1,654 @@ +"""TaoNet variant that replaces MLA attention with an SSM mixer.""" + +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from taoTrain.config import ModelConfig +from taoTrain.core import BaseModel + +from .embeddings import FactorizedEmbedding +from .mla_components import AttentionBlock +from .registry import register_architecture + + +def _load_ssm_core(core: str): + try: + from gamma_space_model.modules.s4_ternary_dplr_ssm import S4TernaryDPLRSSM + from gamma_space_model.modules.ssm_gamma_s4 import SSMGammaS4 + except ImportError as exc: + raise ImportError( + "taonet_ssm requires the Gamma Space Model package. Install the SSM repo " + "with `pip install -e /path/to/Taotern_SSM`, or put it on PYTHONPATH." + ) from exc + if core == "gamma_s4": + return SSMGammaS4 + if core == "dplr": + return S4TernaryDPLRSSM + raise ValueError(f"Unsupported ssm_core '{core}'.") + + +def _padding_mask_from_attention_mask(attention_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + if attention_mask is None: + return None + if attention_mask.dim() == 2: + return attention_mask + if attention_mask.dim() == 4: + return attention_mask.bool().any(dim=-2).squeeze(1).to(dtype=attention_mask.dtype) + raise ValueError( + "Expected attention_mask with shape [batch, seq_len] or " + f"[batch, 1, seq_len, seq_len], got {tuple(attention_mask.shape)}." + ) + + +def _hybrid_ssm_layer_indices(config: ModelConfig, num_layers: int) -> set[int]: + if config.hybrid_ssm_layers: + indices = set() + for item in config.hybrid_ssm_layers.split(","): + item = item.strip() + if not item: + continue + index = int(item) + if index < 0 or index >= num_layers: + raise ValueError( + f"hybrid_ssm_layers index {index} is outside [0, {num_layers - 1}]." + ) + indices.add(index) + if not indices: + raise ValueError("hybrid_ssm_layers was set but did not contain any valid layer indices.") + return indices + + if config.hybrid_pattern == "attention_first": + return {idx for idx in range(num_layers) if idx % 2 == 1} + if config.hybrid_pattern == "ssm_first": + return {idx for idx in range(num_layers) if idx % 2 == 0} + if config.hybrid_pattern == "single_ssm_middle": + return {num_layers // 2} + if config.hybrid_pattern == "single_ssm_late": + return {num_layers - 1} + raise ValueError(f"Unsupported hybrid_pattern '{config.hybrid_pattern}'.") + + +class ChannelGate(nn.Module): + """Elementwise gate with one scale and bias per model channel.""" + + def __init__(self, d_model: int) -> None: + super().__init__() + self.weight = nn.Parameter(torch.zeros(d_model)) + self.bias = nn.Parameter(torch.full((d_model,), 2.0)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x * self.weight + self.bias + + def reset_parameters(self) -> None: + nn.init.zeros_(self.weight) + nn.init.constant_(self.bias, 2.0) + + +def _build_gate(enabled: bool, gate_type: str, d_model: int) -> nn.Module | None: + if not enabled: + return None + if gate_type == "dense": + return nn.Linear(d_model, d_model) + if gate_type == "channel": + return ChannelGate(d_model) + raise ValueError(f"Unsupported ssm_gate_type '{gate_type}'.") + + +def _residual_rms_norm( + x: torch.Tensor, + enabled: bool, + target: float, + eps: float, + cap: Optional[float] = None, +) -> torch.Tensor: + if not enabled and cap is None: + return x + rms = x.float().square().mean(dim=-1, keepdim=True).add(eps).sqrt() + if enabled: + scale = target / rms + else: + cap_tensor = torch.tensor(float(cap), dtype=rms.dtype, device=rms.device) + scale = torch.minimum(torch.ones_like(rms), cap_tensor / rms) + return x * scale.to(dtype=x.dtype) + + +class SSMMixer(nn.Module): + """Causal sequence mixer with the same residual-branch contract as MLA.""" + + def __init__(self, config: ModelConfig) -> None: + super().__init__() + SSMCore = _load_ssm_core(config.ssm_core) + + self.d_model = config.hidden_dim + self.ssm_core = config.ssm_core + d_latent_kv = config.d_latent_kv if config.d_latent_kv is not None else int(self.d_model * 0.75) + self.ssm_hidden_dim = config.ssm_hidden_dim if config.ssm_hidden_dim is not None else d_latent_kv + self.ssm_mixer_dim = config.ssm_mixer_dim if config.ssm_mixer_dim is not None else self.d_model + self.ssm_num_lanes = config.ssm_num_lanes + self.ssm_lane_combine = config.ssm_lane_combine + self.ssm_lane_mode = config.ssm_lane_mode + self.ssm_split_mix = config.ssm_split_mix + self.use_padding_mask = config.ssm_use_padding_mask + self.branch_rms_norm = config.ssm_branch_rms_norm + self.branch_rms_eps = config.ssm_branch_rms_eps + self.branch_clip_value = config.ssm_branch_clip_value + if self.ssm_num_lanes < 1: + raise ValueError("ssm_num_lanes must be at least 1.") + if self.ssm_lane_mode not in {"full", "split"}: + raise ValueError(f"Unsupported ssm_lane_mode '{self.ssm_lane_mode}'.") + if self.ssm_split_mix not in {"none", "hadamard"}: + raise ValueError(f"Unsupported ssm_split_mix '{self.ssm_split_mix}'.") + if self.ssm_split_mix != "none" and self.ssm_lane_mode != "split": + raise ValueError("ssm_split_mix is only supported when ssm_lane_mode='split'.") + if self.ssm_split_mix == "hadamard" and self.ssm_num_lanes != 2: + raise ValueError("ssm_split_mix='hadamard' currently requires exactly two SSM lanes.") + if self.ssm_lane_mode == "split" and self.ssm_mixer_dim % self.ssm_num_lanes != 0: + raise ValueError( + "ssm_mixer_dim must be divisible by ssm_num_lanes when ssm_lane_mode='split'." + ) + self.ssm_lane_dim = ( + self.ssm_mixer_dim // self.ssm_num_lanes + if self.ssm_lane_mode == "split" + else self.ssm_mixer_dim + ) + + self.norm = nn.LayerNorm(self.d_model) + self.gate_type = config.ssm_gate_type + self.input_gate = _build_gate(config.ssm_input_gate, self.gate_type, self.d_model) + self.input_proj = ( + nn.Linear(self.d_model, self.ssm_mixer_dim, bias=False) + if self.ssm_mixer_dim != self.d_model + else nn.Identity() + ) + common_kwargs = { + "state_dim": self.ssm_lane_dim, + "hidden_dim": self.ssm_hidden_dim, + "dt_min": config.ssm_dt_min, + "dt_max": config.ssm_dt_max, + "dt_init": config.ssm_dt_init, + "use_D": config.ssm_use_d, + "kernel_mode": config.ssm_kernel_mode, + "kernel_threshold": config.ssm_kernel_threshold, + } + self.ssm_lanes = nn.ModuleList( + [self._build_ssm_lane(SSMCore, common_kwargs, config) for _ in range(self.ssm_num_lanes)] + ) + self.ssm = self.ssm_lanes[0] + self.lane_weights = None + if self.ssm_lane_combine not in {"mean", "channel"}: + raise ValueError(f"Unsupported ssm_lane_combine '{self.ssm_lane_combine}'.") + if ( + self.ssm_lane_mode == "full" + and self.ssm_num_lanes > 1 + and self.ssm_lane_combine == "channel" + ): + self.lane_weights = nn.Parameter( + torch.full((self.ssm_num_lanes, self.ssm_mixer_dim), 1.0 / self.ssm_num_lanes) + ) + + if config.ssm_activation == "gelu": + self.activation = nn.GELU() + elif config.ssm_activation == "silu": + self.activation = nn.SiLU() + elif config.ssm_activation in {"identity", "linear"}: + self.activation = nn.Identity() + else: + raise ValueError(f"Unsupported ssm_activation '{config.ssm_activation}'.") + + self.output_gate = _build_gate(config.ssm_gate, self.gate_type, self.d_model) + self.out_proj = nn.Linear(self.ssm_mixer_dim, self.d_model, bias=False) + self.layer_scale = nn.Parameter(torch.full((self.d_model,), config.ssm_layer_scale_init)) + self.local_shift_scale = None + if config.ssm_local_shift: + if config.ssm_local_shift_per_channel: + self.local_shift_scale = nn.Parameter( + torch.full((self.d_model,), float(config.ssm_local_shift_init)) + ) + else: + self.local_shift_scale = nn.Parameter(torch.tensor(float(config.ssm_local_shift_init))) + self.proj_dropout = nn.Dropout(config.dropout) + + self._reset_parameters() + + def _normalize_branch(self, ssm_out: torch.Tensor) -> torch.Tensor: + if not self.branch_rms_norm: + return ssm_out + rms = ssm_out.float().square().mean(dim=-1, keepdim=True).add(self.branch_rms_eps).rsqrt() + return ssm_out * rms.to(dtype=ssm_out.dtype) + + def _build_ssm_lane(self, SSMCore, common_kwargs: dict, config: ModelConfig) -> nn.Module: + if config.ssm_core == "gamma_s4": + return SSMCore( + **common_kwargs, + discretization=config.ssm_discretization, + ) + return SSMCore( + **common_kwargs, + rank=config.ssm_rank, + max_low_rank_scale=config.ssm_max_low_rank_scale, + finite_tail_correction=config.ssm_finite_tail_correction, + ) + + def _reset_parameters(self) -> None: + if isinstance(self.input_gate, nn.Linear): + nn.init.zeros_(self.input_gate.weight) + nn.init.constant_(self.input_gate.bias, 2.0) + elif isinstance(self.input_gate, ChannelGate): + self.input_gate.reset_parameters() + if isinstance(self.output_gate, nn.Linear): + nn.init.zeros_(self.output_gate.weight) + nn.init.constant_(self.output_gate.bias, 2.0) + elif isinstance(self.output_gate, ChannelGate): + self.output_gate.reset_parameters() + if isinstance(self.input_proj, nn.Linear): + nn.init.xavier_uniform_(self.input_proj.weight) + nn.init.xavier_uniform_(self.out_proj.weight) + else: + nn.init.eye_(self.out_proj.weight) + + def forward( + self, + x: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + x_norm = self.norm(x) + ssm_in = x_norm + if self.input_gate is not None: + ssm_in = ssm_in * torch.sigmoid(self.input_gate(x_norm)) + ssm_in = self.input_proj(ssm_in) + + padding_mask = _padding_mask_from_attention_mask(attention_mask) if self.use_padding_mask else None + lane_outputs = [] + if self.ssm_lane_mode == "split": + lane_inputs = torch.split(ssm_in, self.ssm_lane_dim, dim=-1) + else: + lane_inputs = [ssm_in] * self.ssm_num_lanes + for lane, lane_input in zip(self.ssm_lanes, lane_inputs): + lane_out, _ = lane( + lane_input, + mask=padding_mask, + return_state=False, + ) + lane_outputs.append(lane_out) + if self.ssm_lane_mode == "split": + if self.ssm_split_mix == "hadamard": + left, right = lane_outputs + inv_sqrt_2 = 0.7071067811865476 + ssm_out = torch.cat((left + right, left - right), dim=-1) * inv_sqrt_2 + else: + ssm_out = torch.cat(lane_outputs, dim=-1) + elif len(lane_outputs) == 1: + ssm_out = lane_outputs[0] + elif self.lane_weights is not None: + weights = self.lane_weights.to(dtype=lane_outputs[0].dtype, device=lane_outputs[0].device) + ssm_out = torch.stack(lane_outputs, dim=2) + ssm_out = (ssm_out * weights.view(1, 1, self.ssm_num_lanes, self.ssm_mixer_dim)).sum(dim=2) + else: + ssm_out = torch.stack(lane_outputs, dim=0).mean(dim=0) + ssm_out = self.activation(ssm_out) + ssm_out = self.out_proj(ssm_out) + + if self.output_gate is not None: + ssm_out = ssm_out * torch.sigmoid(self.output_gate(x_norm)) + + ssm_out = self._normalize_branch(ssm_out) + ssm_out = ssm_out * self.layer_scale + if self.local_shift_scale is not None: + shifted = torch.zeros_like(x_norm) + shifted[:, 1:] = x_norm[:, :-1] + ssm_out = ssm_out + shifted * self.local_shift_scale + if self.branch_clip_value is not None: + ssm_out = torch.clamp(ssm_out, -self.branch_clip_value, self.branch_clip_value) + return self.proj_dropout(ssm_out) + + +class SSMAttentionBlock(nn.Module): + """TaoNet block with Gamma SSM sequence mixing and the original SwiGLU FFN.""" + + def __init__(self, config: ModelConfig) -> None: + super().__init__() + d_model = config.hidden_dim + d_ff = config.hidden_dim_ff if config.hidden_dim_ff is not None else d_model * 4 + + self.mixer = SSMMixer(config) + self.residual_rms_norm = config.block_residual_rms_norm + self.residual_rms_target = config.block_residual_rms_target + self.residual_rms_cap = config.block_residual_rms_cap + self.residual_rms_eps = config.block_residual_rms_eps + self.ff_norm = nn.LayerNorm(d_model) + self.ff_gate = nn.Linear(d_model, int(d_ff), bias=False) + self.ff_value = nn.Linear(d_model, int(d_ff), bias=False) + self.ff_out = nn.Linear(int(d_ff), d_model, bias=False) + self.dropout = nn.Dropout(config.dropout) + + def forward( + self, + x: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + x = x + self.dropout(self.mixer(x, attention_mask=attention_mask)) + x = _residual_rms_norm( + x, + self.residual_rms_norm, + self.residual_rms_target, + self.residual_rms_eps, + self.residual_rms_cap, + ) + + ff_norm = self.ff_norm(x) + ff_gate = self.ff_gate(ff_norm) + ff_value = self.ff_value(ff_norm) + ff_out = ff_value * F.silu(ff_gate) + ff_out = self.ff_out(ff_out) + x = x + self.dropout(ff_out) + return _residual_rms_norm( + x, + self.residual_rms_norm, + self.residual_rms_target, + self.residual_rms_eps, + self.residual_rms_cap, + ) + + +@register_architecture("taonet_ssm") +class TaoNetSSMLLM(BaseModel): + """TaoNet language model with SSM blocks replacing MLA attention.""" + + def __init__(self, config: ModelConfig): + super().__init__(config) + + self.vocab_size = config.vocab_size + self.d_model = config.hidden_dim + self.n_layers = config.num_layers + self.n_heads = config.num_heads + self.dropout = config.dropout + self.d_latent_kv = config.d_latent_kv if config.d_latent_kv is not None else int(self.d_model * 0.75) + self.d_ff = config.hidden_dim_ff if config.hidden_dim_ff is not None else self.d_model * 4 + self.use_factorized_embedding = getattr(config, "use_factorized_embedding", False) + self.d_embed_rank = getattr(config, "d_embed_rank", 96) + self.max_seq_length = config.max_seq_length + + if self.use_factorized_embedding: + self.token_embedding = FactorizedEmbedding( + self.vocab_size, + self.d_model, + self.d_embed_rank, + ) + else: + self.token_embedding = nn.Embedding(self.vocab_size, self.d_model) + + self.embedding_dropout = nn.Dropout(self.dropout) + self.blocks = nn.ModuleList([SSMAttentionBlock(config) for _ in range(self.n_layers)]) + self.final_norm = nn.LayerNorm(self.d_model) + self.output_head = nn.Linear(self.d_model, self.vocab_size, bias=False) + + self.apply(self._init_weights) + for block in self.blocks: + block.mixer._reset_parameters() + + self._print_architecture(config) + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std) + + def _print_architecture(self, config: ModelConfig): + total_params = sum(p.numel() for p in self.parameters()) + trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) + ssm_hidden_dim = config.ssm_hidden_dim if config.ssm_hidden_dim is not None else self.d_latent_kv + + print(f"\n{'=' * 70}") + print(f"MODEL ARCHITECTURE - TAONET-SSM ({config.ssm_core} + SwiGLU)") + print(f"{'=' * 70}") + print(f"Embedding vocab: {self.vocab_size}") + print(f"Output Head: {(self.d_model * self.vocab_size) / 1e6:>8.2f}M") + print(f"SSM Blocks: {len(self.blocks):>8} layers x SSMMixer") + print(f"{'-' * 70}") + print(f"Total Parameters: {total_params / 1e6:>8.2f}M (trainable: {trainable_params / 1e6:.2f}M)") + print(f"{'-' * 70}") + print("Configuration:") + print(f" Model dimension (d_model): {self.d_model}") + print(f" SSM core: {config.ssm_core}") + print(f" SSM hidden dimension: {ssm_hidden_dim}") + print(f" SSM mixer dimension: {config.ssm_mixer_dim or self.d_model}") + print(f" SSM lanes: {config.ssm_num_lanes}") + print(f" SSM lane mode: {config.ssm_lane_mode}") + print(f" SSM split mix: {config.ssm_split_mix}") + print(f" SSM lane combine: {config.ssm_lane_combine}") + if config.ssm_core == "dplr": + print(f" SSM DPLR rank: {config.ssm_rank}") + print(f" SSM discretization: {config.ssm_discretization}") + print(f" SSM kernel mode: {config.ssm_kernel_mode}") + print(f" SSM kernel threshold: {config.ssm_kernel_threshold}") + print(f" SSM padding mask enabled: {config.ssm_use_padding_mask}") + print(f" SSM gate type: {config.ssm_gate_type}") + print(f" SSM branch RMS norm: {config.ssm_branch_rms_norm}") + print(f" SSM branch clip value: {config.ssm_branch_clip_value}") + print(f" Block residual RMS norm: {config.block_residual_rms_norm}") + print(f" Block residual RMS cap: {config.block_residual_rms_cap}") + print(f" SSM local shift enabled: {config.ssm_local_shift}") + print(f" SSM local shift per channel: {config.ssm_local_shift_per_channel}") + print(f" Feed-forward dimension: {int(self.d_ff)}") + print(f" Number of layers: {self.n_layers}") + print(f" Max sequence length: {self.max_seq_length}") + print(f" Dropout: {self.dropout}") + print(f"{'=' * 70}\n") + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + ) -> dict: + x = self.token_embedding(input_ids) + x = self.embedding_dropout(x) + + for block in self.blocks: + x = block(x, attention_mask=attention_mask) + + x = self.final_norm(x) + logits = self.output_head(x) + + loss = None + if labels is not None: + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), + labels.view(-1), + reduction="mean", + ignore_index=-100, + ) + + return { + "logits": logits, + "loss": loss, + } + + +@register_architecture("taonet_hybrid") +class TaoNetHybridLLM(BaseModel): + """TaoNet language model with alternating MLA attention and SSM mixer blocks.""" + + def __init__(self, config: ModelConfig): + super().__init__(config) + + self.vocab_size = config.vocab_size + self.d_model = config.hidden_dim + self.n_layers = config.num_layers + self.n_heads = config.num_heads + self.dropout = config.dropout + self.d_latent_kv = config.d_latent_kv if config.d_latent_kv is not None else int(self.d_model * 0.75) + self.d_rope = config.d_rope if config.d_rope is not None else self.d_model // self.n_heads + self.d_ff = config.hidden_dim_ff if config.hidden_dim_ff is not None else self.d_model * 4 + self.gqa_groups = getattr(config, "gqa_groups", 1) + self.use_factorized_embedding = getattr(config, "use_factorized_embedding", False) + self.d_embed_rank = getattr(config, "d_embed_rank", 96) + self.rope_scale = getattr(config, "rope_scale", 40.0) + self.yarn_alpha = getattr(config, "yarn_alpha", 1.0) + self.max_seq_length = config.max_seq_length + + assert self.d_model % self.n_heads == 0, ( + f"hidden_dim ({self.d_model}) must be divisible by num_heads ({self.n_heads})" + ) + assert self.d_latent_kv % self.n_heads == 0, ( + f"d_latent_kv ({self.d_latent_kv}) must be divisible by num_heads ({self.n_heads})" + ) + + if self.use_factorized_embedding: + self.token_embedding = FactorizedEmbedding( + self.vocab_size, + self.d_model, + self.d_embed_rank, + ) + else: + self.token_embedding = nn.Embedding(self.vocab_size, self.d_model) + + self.embedding_dropout = nn.Dropout(self.dropout) + self.blocks = nn.ModuleList() + self.block_kinds: list[str] = [] + self.ssm_layer_indices = _hybrid_ssm_layer_indices(config, self.n_layers) + for layer_idx in range(self.n_layers): + if layer_idx in self.ssm_layer_indices: + self.blocks.append(SSMAttentionBlock(config)) + self.block_kinds.append("ssm") + else: + self.blocks.append( + AttentionBlock( + d_model=self.d_model, + d_latent_kv=self.d_latent_kv, + n_heads=self.n_heads, + d_rope=self.d_rope, + d_ff=int(self.d_ff), + dropout=self.dropout, + gqa_groups=self.gqa_groups, + rope_scale=self.rope_scale, + max_seq_length=self.max_seq_length, + yarn_alpha=self.yarn_alpha, + residual_rms_norm=config.block_residual_rms_norm, + residual_rms_target=config.block_residual_rms_target, + residual_rms_cap=config.block_residual_rms_cap, + residual_rms_eps=config.block_residual_rms_eps, + ) + ) + self.block_kinds.append("attention") + + self.final_norm = nn.LayerNorm(self.d_model) + self.output_head = nn.Linear(self.d_model, self.vocab_size, bias=False) + + self.apply(self._init_weights) + for block in self.blocks: + mixer = getattr(block, "mixer", None) + if mixer is not None: + mixer._reset_parameters() + + self.register_buffer("causal_mask_cache", None, persistent=False) + self._print_architecture(config) + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std) + + def _get_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor: + if self.causal_mask_cache is None or self.causal_mask_cache.size(-1) < seq_len: + mask = torch.tril(torch.ones(seq_len, seq_len, device=device, dtype=torch.bool)) + self.register_buffer("causal_mask_cache", mask, persistent=False) + return self.causal_mask_cache[:seq_len, :seq_len] + + def _get_combined_mask( + self, + attention_mask: Optional[torch.Tensor], + seq_len: int, + device: torch.device, + ) -> torch.Tensor: + causal_mask = self._get_causal_mask(seq_len, device) + if attention_mask is None: + return causal_mask.unsqueeze(0).unsqueeze(0).float() + padding_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool() + return (causal_mask.unsqueeze(0).unsqueeze(0) & padding_mask).float() + + def _print_architecture(self, config: ModelConfig) -> None: + total_params = sum(p.numel() for p in self.parameters()) + trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) + attention_blocks = self.block_kinds.count("attention") + ssm_blocks = self.block_kinds.count("ssm") + ssm_hidden_dim = config.ssm_hidden_dim if config.ssm_hidden_dim is not None else self.d_latent_kv + + print(f"\n{'=' * 70}") + print(f"MODEL ARCHITECTURE - TAONET-HYBRID (MLA + {config.ssm_core} SSM)") + print(f"{'=' * 70}") + print(f"Embedding vocab: {self.vocab_size}") + print(f"Output Head: {(self.d_model * self.vocab_size) / 1e6:>8.2f}M") + print(f"Attention Blocks: {attention_blocks:>8} layers") + print(f"SSM Blocks: {ssm_blocks:>8} layers") + print(f"{'-' * 70}") + print(f"Total Parameters: {total_params / 1e6:>8.2f}M (trainable: {trainable_params / 1e6:.2f}M)") + print(f"{'-' * 70}") + print("Configuration:") + print(f" Model dimension (d_model): {self.d_model}") + print(f" KV latent dimension (d_latent_kv): {self.d_latent_kv}") + print(f" Attention heads: {self.n_heads}") + print(f" SSM core: {config.ssm_core}") + print(f" SSM hidden dimension: {ssm_hidden_dim}") + print(f" SSM mixer dimension: {config.ssm_mixer_dim or self.d_model}") + print(f" SSM lanes: {config.ssm_num_lanes}") + print(f" SSM lane mode: {config.ssm_lane_mode}") + print(f" SSM split mix: {config.ssm_split_mix}") + print(f" SSM lane combine: {config.ssm_lane_combine}") + if config.ssm_core == "dplr": + print(f" SSM DPLR rank: {config.ssm_rank}") + print(f" SSM finite-tail correction: {config.ssm_finite_tail_correction}") + print(f" SSM branch RMS norm: {config.ssm_branch_rms_norm}") + print(f" SSM branch clip value: {config.ssm_branch_clip_value}") + print(f" Block residual RMS norm: {config.block_residual_rms_norm}") + print(f" Block residual RMS cap: {config.block_residual_rms_cap}") + print(f" SSM local shift enabled: {config.ssm_local_shift}") + print(f" SSM gate type: {config.ssm_gate_type}") + print(f" Hybrid pattern: {config.hybrid_pattern}") + print(f" Hybrid SSM layers: {','.join(str(i) for i in sorted(self.ssm_layer_indices))}") + print(f" Feed-forward dimension: {int(self.d_ff)}") + print(f" Number of layers: {self.n_layers}") + print(f" Max sequence length: {self.max_seq_length}") + print(f" Dropout: {self.dropout}") + print(f"{'=' * 70}\n") + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + ) -> dict: + _, seq_len = input_ids.shape + combined_mask = self._get_combined_mask(attention_mask, seq_len, input_ids.device) + + x = self.token_embedding(input_ids) + x = self.embedding_dropout(x) + + for block in self.blocks: + x = block(x, attention_mask=combined_mask) + + x = self.final_norm(x) + logits = self.output_head(x) + + loss = None + if labels is not None: + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), + labels.view(-1), + reduction="mean", + ignore_index=-100, + ) + + return { + "logits": logits, + "loss": loss, + } diff --git a/code/TaoTrain/src/taoTrain/models/transformer.py b/code/TaoTrain/src/taoTrain/models/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..cd4e38199e59300e6546c966d3db83ad208bd0cd --- /dev/null +++ b/code/TaoTrain/src/taoTrain/models/transformer.py @@ -0,0 +1,315 @@ +"""Standard Transformer language model implementation.""" + +import math +from typing import Optional +import torch +import torch.nn as nn +import torch.nn.functional as F + +from taoTrain.core import BaseModel +from taoTrain.config import ModelConfig +from .registry import register_architecture + + +# ============================================================================ +# Components +# ============================================================================ + + +class PositionalEmbedding(nn.Module): + """Sinusoidal positional embeddings.""" + + def __init__(self, dim: int, max_seq_length: int = 2048): + """Initialize positional embeddings.""" + super().__init__() + self.dim = dim + self.max_seq_length = max_seq_length + + # Precompute positional embeddings + pe = torch.zeros(max_seq_length, dim) + pos = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim)) + + pe[:, 0::2] = torch.sin(pos * div_term) + if dim % 2 == 1: + pe[:, 1::2] = torch.cos(pos * div_term[:-1]) + else: + pe[:, 1::2] = torch.cos(pos * div_term) + + self.register_buffer("pe", pe, persistent=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Add positional embeddings to input. + + Args: + x: Input tensor (batch, seq_len, hidden_dim) + + Returns: + Input + positional embeddings + """ + seq_len = x.shape[1] + return x + self.pe[:seq_len] + + +class Attention(nn.Module): + """Multi-head self-attention using scaled dot-product attention.""" + + def __init__(self, config: ModelConfig): + """Initialize attention.""" + super().__init__() + self.hidden_dim = config.hidden_dim + self.num_heads = config.num_heads + self.head_dim = config.head_dim + + assert self.hidden_dim % self.num_heads == 0 + + # Linear projections + self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim) + self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim) + self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim) + self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim) + + self.dropout_p = config.dropout + + def forward( + self, + x: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass using scaled_dot_product_attention. + + Args: + x: Shape (batch, seq_len, hidden_dim) + attention_mask: Shape (batch, seq_len) + + Returns: + Output: Shape (batch, seq_len, hidden_dim) + """ + batch_size, seq_len, _ = x.shape + + # Project to Q, K, V + q = self.q_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim) + k = self.k_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim) + v = self.v_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim) + + # Transpose for attention: (batch, num_heads, seq_len, head_dim) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # NOTE: PyTorch's scaled_dot_product_attention does NOT support both + # explicit attn_mask AND is_causal=True together. + # When is_causal=True, PyTorch handles causal masking automatically. + # Padding positions are handled separately via loss computation (labels=-100). + # See: https://github.com/pytorch/pytorch/issues/96099 + + # Compute attention using scaled_dot_product_attention + # is_causal=True automatically applies causal masking + # We do NOT pass attn_mask when is_causal=True + out = F.scaled_dot_product_attention( + q, k, v, + attn_mask=None, # Must be None when is_causal=True + dropout_p=self.dropout_p if self.training else 0.0, + is_causal=True, + scale=None # Uses default scale of 1/sqrt(head_dim) + ) # (batch, num_heads, seq_len, head_dim) + + # Transpose back and reshape + out = out.transpose(1, 2).contiguous() # (batch, seq_len, num_heads, head_dim) + out = out.reshape(batch_size, seq_len, self.hidden_dim) + + # Output projection + out = self.out_proj(out) + + return out + + +class SwiGLU(nn.Module): + """Swish Gated Linear Unit activation.""" + + def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.0): + """ + Initialize SwiGLU. + + Args: + in_dim: Input dimension + out_dim: Intermediate/hidden dimension + dropout: Dropout rate + """ + super().__init__() + # Project to 2x the intermediate dimension (for value and gate) + self.fc1 = nn.Linear(in_dim, 2 * out_dim) + self.fc2 = nn.Linear(out_dim, in_dim) # Project back to input dimension + self.dropout = nn.Dropout(dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass with SwiGLU activation. + + Args: + x: Input tensor + + Returns: + Gated activation output (same dimension as input) + """ + # Project to 2x intermediate dimension + x = self.fc1(x) + + # Split into value and gate + x, gate = x.chunk(2, dim=-1) + + # SwiGLU: value * swish(gate) = value * gate * sigmoid(gate) + x = x * F.silu(gate) # SiLU is Swish: x * sigmoid(x) + + x = self.dropout(x) + x = self.fc2(x) # Project back to input dimension + + return x + + +class FeedForward(nn.Module): + """Feed-forward network with SwiGLU activation.""" + + def __init__(self, config: ModelConfig): + """Initialize FFN with SwiGLU.""" + super().__init__() + self.swiglu = SwiGLU( + in_dim=config.hidden_dim, + out_dim=config.intermediate_dim, + dropout=config.dropout + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass with SwiGLU activation.""" + return self.swiglu(x) + + +class TransformerBlock(nn.Module): + """Single transformer block with attention and FFN.""" + + def __init__(self, config: ModelConfig): + """Initialize transformer block.""" + super().__init__() + self.norm1 = nn.LayerNorm(config.hidden_dim) + self.attn = Attention(config) + self.norm2 = nn.LayerNorm(config.hidden_dim) + self.ffn = FeedForward(config) + + def forward( + self, + x: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with pre-norm residual connections.""" + # Attention with residual + x = x + self.attn(self.norm1(x), attention_mask=attention_mask) + + # FFN with residual + x = x + self.ffn(self.norm2(x)) + + return x + + +# ============================================================================ +# Transformer LM +# ============================================================================ + + +@register_architecture("transformer") +class TransformerLM(BaseModel): + """Standard Transformer language model.""" + + def __init__(self, config: ModelConfig): + """Initialize Transformer LM.""" + super().__init__(config) + + # Embeddings + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_dim) + self.pos_embed = PositionalEmbedding(config.hidden_dim, max_seq_length=config.max_seq_length) + self.dropout = nn.Dropout(config.dropout) + + # Transformer blocks + self.blocks = nn.ModuleList([ + TransformerBlock(config) for _ in range(config.num_layers) + ]) + + # Final layer norm + self.final_norm = nn.LayerNorm(config.hidden_dim) + + # Output projection (shared with input embeddings for efficiency) + self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False) + + # Weight tying (optional) + self.lm_head.weight = self.embed_tokens.weight + + # Initialize weights + self._init_weights() + + def _init_weights(self): + """Initialize model weights.""" + for module in self.modules(): + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, std=self.config.init_std) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=self.config.init_std) + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + ) -> dict[str, torch.Tensor]: + """ + Forward pass. + + Args: + input_ids: (batch_size, seq_len) + attention_mask: (batch_size, seq_len) + labels: (batch_size, seq_len) for loss computation + + Returns: + Dict with 'logits' and optionally 'loss' + """ + batch_size, seq_len = input_ids.shape + + # Embedding + x = self.embed_tokens(input_ids) + + # Add positional embeddings + x = self.pos_embed(x) + + x = self.dropout(x) + + # Transformer blocks + for block in self.blocks: + x = block(x, attention_mask=attention_mask) + + # Final normalization + x = self.final_norm(x) + + # LM head + logits = self.lm_head(x) # (batch, seq_len, vocab_size) + + # Loss computation + loss = None + if labels is not None: + # Flatten for loss computation + logits_flat = logits.view(-1, logits.size(-1)) # (batch * seq_len, vocab_size) + labels_flat = labels.view(-1) + + # Only compute loss on valid targets (ignore -100 tokens) + loss = F.cross_entropy( + logits_flat, + labels_flat, + reduction='mean', + ignore_index=-100 + ) + + return { + 'logits': logits, + 'loss': loss, + } diff --git a/code/TaoTrain/src/taoTrain/optimizers/__init__.py b/code/TaoTrain/src/taoTrain/optimizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4d4084621a2127df5249dc51d0e1365d01b33ae9 --- /dev/null +++ b/code/TaoTrain/src/taoTrain/optimizers/__init__.py @@ -0,0 +1,13 @@ +"""Optimizer registry and factories.""" + +from .registry import ( + register_optimizer, + get_optimizer, + get_registered_optimizers, +) + +__all__ = [ + "register_optimizer", + "get_optimizer", + "get_registered_optimizers", +] diff --git a/code/TaoTrain/src/taoTrain/optimizers/adam.py b/code/TaoTrain/src/taoTrain/optimizers/adam.py new file mode 100644 index 0000000000000000000000000000000000000000..9514e6b028548cb49bce574f9cf26f1a8ca1c5da --- /dev/null +++ b/code/TaoTrain/src/taoTrain/optimizers/adam.py @@ -0,0 +1,64 @@ +"""Adam optimizer factory.""" + +import torch.optim as optim +from taoTrain.core.base import BaseModel +from taoTrain.config import TrainingConfig +from .registry import register_optimizer + + +def _separate_parameters(model: BaseModel) -> tuple[list, list]: + """ + Separate model parameters into decay and no-decay groups. + + Args: + model: Model instance + + Returns: + Tuple of (decay_params, no_decay_params) + """ + decay_params = [] + no_decay_params = [] + + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + + # Apply weight decay to all params except biases and layer norms + if 'bias' in name or 'norm' in name: + no_decay_params.append(param) + else: + decay_params.append(param) + + return decay_params, no_decay_params + + +@register_optimizer("adam") +def create_adam(model: BaseModel, config: TrainingConfig) -> optim.Adam: + """ + Create Adam optimizer with weight decay applied selectively. + + Args: + model: Model instance + config: TrainingConfig + + Returns: + Adam optimizer instance + """ + optimizer_config = config.optimizer + + # Separate parameters for weight decay + decay_params, no_decay_params = _separate_parameters(model) + + param_groups = [ + {"params": decay_params, "weight_decay": optimizer_config.weight_decay}, + {"params": no_decay_params, "weight_decay": 0.0}, + ] + + optimizer = optim.Adam( + param_groups, + lr=optimizer_config.learning_rate, + betas=optimizer_config.betas, + eps=optimizer_config.eps, + ) + + return optimizer diff --git a/code/TaoTrain/src/taoTrain/optimizers/adamw.py b/code/TaoTrain/src/taoTrain/optimizers/adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..fa2bfb6723031183e5588543685e37f925d1c392 --- /dev/null +++ b/code/TaoTrain/src/taoTrain/optimizers/adamw.py @@ -0,0 +1,64 @@ +"""AdamW optimizer factory.""" + +import torch.optim as optim +from taoTrain.core.base import BaseModel +from taoTrain.config import TrainingConfig +from .registry import register_optimizer + + +def _separate_parameters(model: BaseModel) -> tuple[list, list]: + """ + Separate model parameters into decay and no-decay groups. + + Args: + model: Model instance + + Returns: + Tuple of (decay_params, no_decay_params) + """ + decay_params = [] + no_decay_params = [] + + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + + # Apply weight decay to all params except biases and layer norms + if 'bias' in name or 'norm' in name: + no_decay_params.append(param) + else: + decay_params.append(param) + + return decay_params, no_decay_params + + +@register_optimizer("adamw") +def create_adamw(model: BaseModel, config: TrainingConfig) -> optim.AdamW: + """ + Create AdamW optimizer with weight decay applied selectively. + + Args: + model: Model instance + config: TrainingConfig + + Returns: + AdamW optimizer instance + """ + optimizer_config = config.optimizer + + # Separate parameters for weight decay + decay_params, no_decay_params = _separate_parameters(model) + + param_groups = [ + {"params": decay_params, "weight_decay": optimizer_config.weight_decay}, + {"params": no_decay_params, "weight_decay": 0.0}, + ] + + optimizer = optim.AdamW( + param_groups, + lr=optimizer_config.learning_rate, + betas=optimizer_config.betas, + eps=optimizer_config.eps, + ) + + return optimizer diff --git a/code/TaoTrain/src/taoTrain/optimizers/hybrid_muon_adamw.py b/code/TaoTrain/src/taoTrain/optimizers/hybrid_muon_adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..2165ce6f7ac9d9bb70d49e0ee4ea3713e04a7d30 --- /dev/null +++ b/code/TaoTrain/src/taoTrain/optimizers/hybrid_muon_adamw.py @@ -0,0 +1,243 @@ +""" +Hybrid Muon + AdamW Optimizer for TaoNet models. + +Combines: +- Muon: Specialized optimization for 2D weight matrices (linear layers) + Leverages orthogonal/SVD-based updates for better convergence on matrix weights +- AdamW: Adaptive moment estimation for 1D parameters (biases, norms, embeddings) + +Key Design: +- 2D weight matrices use Muon optimizer with separate LRs for different layer types +- 1D parameters use AdamW with lower learning rate +- Inherits from torch.optim.Optimizer for LR scheduler compatibility +""" + +import torch +import torch.nn as nn +from typing import Dict, List, Any + +from .registry import register_optimizer +from taoTrain.config import TrainingConfig +from taoTrain.core.base import BaseModel + + +def _get_param_dimensionality(param: torch.Tensor) -> str: + """ + Determine if a parameter is 2D (weight matrix) or 1D (bias/embedding/norm). + + Returns: + 'weight_2d': Parameter has 2+ dimensions (for Muon) + '1d_other': Parameter is 1D (for AdamW) + """ + if param.dim() >= 2: + return 'weight_2d' + return '1d_other' + + +class HybridMuonAdamW(torch.optim.Optimizer): + """ + Composite optimizer combining Muon (for 2D weights) and AdamW (for 1D params). + + Why: Muon is specialized for 2D weight matrices in neural networks. + Biases, embeddings, and layer norms should use AdamW for adaptive convergence. + + Inherits from torch.optim.Optimizer to be compatible with LR schedulers. + Manages two internal optimizers: Muon and AdamW. + + Public interface compatible with standard PyTorch optimizers: + - step(): delegates to both internal optimizers + - zero_grad(set_to_none=True): delegates to both + - state_dict(): returns combined state + - load_state_dict(state): restores combined state + """ + + def __init__( + self, + muon_params_groups: List[Dict[str, Any]], + adamw_params_group: Dict[str, Any], + muon_kwargs: Dict[str, Any], + adamw_kwargs: Dict[str, Any] + ): + """ + Initialize HybridMuonAdamW optimizer. + + Args: + muon_params_groups: List of param groups for Muon optimizer + Each group should have 'params' and 'lr' keys + adamw_params_group: Dict param group for AdamW optimizer + Should have 'params' and 'lr' keys + muon_kwargs: Additional kwargs for torch.optim.Muon init + adamw_kwargs: Additional kwargs for torch.optim.AdamW init + """ + # Dummy params list for parent Optimizer class (required for registration) + # Real params are managed by internal optimizers + dummy_param = torch.nn.Parameter(torch.zeros(1)) + super().__init__([dummy_param], {}) + + # Create internal optimizers with their parameter groups + try: + self.muon = torch.optim.Muon(muon_params_groups, **muon_kwargs) + except AttributeError: + raise RuntimeError( + "torch.optim.Muon not available. " + "Muon optimizer requires PyTorch 2.1+. " + "Please upgrade PyTorch: pip install --upgrade torch" + ) + + self.adamw = torch.optim.AdamW([adamw_params_group], **adamw_kwargs) + + # Merge param_groups from both optimizers + # LR schedulers will update these merged groups + self.param_groups = self.muon.param_groups + self.adamw.param_groups + + def step(self, closure=None): + """Execute optimization step for both Muon and AdamW.""" + if closure is not None: + loss = closure() + else: + loss = None + + self.muon.step(closure) + self.adamw.step(closure) + + return loss + + def zero_grad(self, set_to_none: bool = False): + """Zero gradients in both optimizers.""" + self.muon.zero_grad(set_to_none=set_to_none) + self.adamw.zero_grad(set_to_none=set_to_none) + + def state_dict(self) -> Dict[str, Any]: + """Return combined state dict for both optimizers.""" + return { + 'muon': self.muon.state_dict(), + 'adamw': self.adamw.state_dict(), + } + + def load_state_dict(self, state_dict: Dict[str, Any]): + """ + Restore state from combined state dict. + + Supports both new format (composite with Muon+AdamW) and legacy format + (AdamW-only checkpoints) for backward compatibility. + """ + if isinstance(state_dict, dict): + if 'muon' in state_dict and 'adamw' in state_dict: + # New format: composite optimizer with both Muon and AdamW + self.muon.load_state_dict(state_dict['muon']) + self.adamw.load_state_dict(state_dict['adamw']) + elif 'state' in state_dict or 'param_groups' in state_dict: + # Legacy format: old AdamW-only checkpoint + # Load into AdamW optimizer only, Muon starts fresh + try: + self.adamw.load_state_dict(state_dict) + print(" ⚠️ Loaded legacy AdamW-only checkpoint (Muon state initialized fresh)") + except Exception as e: + print(f" ⚠️ Failed to load optimizer state: {e}") + else: + print(f" ⚠️ Unknown checkpoint format") + else: + raise ValueError(f"Expected dict, got {type(state_dict)}") + + +@register_optimizer("hybrid_muon_adamw") +def create_hybrid_muon_adamw(model: BaseModel, training_config: TrainingConfig) -> HybridMuonAdamW: + """ + Factory function to create HybridMuonAdamW optimizer from model and config. + + Parameter grouping strategy: + - Muon groups (2D weight matrices): + * Regular Linear 2D weights → learning_rate + * (BitLinear would use bitlinear_lr, but skipped in BF16 version) + - AdamW group (1D parameters): + * Biases, layer norms, embeddings → adamw_lr + + Args: + model: PyTorch model to optimize + training_config: TrainingConfig with optimizer hyperparameters: + - learning_rate: LR for 2D Linear weights (Muon) + - adamw_lr: LR for 1D parameters (AdamW) + - weight_decay: L2 regularization + - betas: (beta1, beta2) for AdamW + - eps: epsilon for numerical stability + + Returns: + HybridMuonAdamW optimizer instance + """ + + # Separate parameters by dimensionality + linear_2d_weights = [] + params_1d = [] + + # Classify all parameters + for module_name, module in model.named_modules(): + for param_name, param in module.named_parameters(recurse=False): + if not param.requires_grad: + continue + + param_dim = _get_param_dimensionality(param) + + if param_dim == 'weight_2d' and isinstance(module, nn.Linear): + # 2D Linear weights → Muon + linear_2d_weights.append(param) + else: + # Everything else → AdamW (1D params + other 2D tensors) + params_1d.append(param) + + # Verify we got all parameters + total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + muon_params = sum(p.numel() for p in linear_2d_weights) + adamw_params = sum(p.numel() for p in params_1d) + assert total_params == muon_params + adamw_params, \ + f"Parameter accounting error: {total_params} != {muon_params} + {adamw_params}" + + # Prepare Muon parameter groups (one group with single LR for all Linear 2D weights) + muon_params_groups = [ + { + 'params': linear_2d_weights, + 'lr': training_config.optimizer.learning_rate, # Use main learning_rate for Muon + } + ] + + # Prepare AdamW parameter group (1D parameters with lower LR) + adamw_params_group = { + 'params': params_1d, + 'lr': training_config.optimizer.adamw_lr, # Use adamw_lr for 1D params + 'weight_decay': training_config.optimizer.weight_decay, + } + + # Extract Muon kwargs (settings common to all Muon param groups) + muon_kwargs = { + 'lr': training_config.optimizer.learning_rate, # Will be overridden by param_groups above + } + + # Extract AdamW kwargs + adamw_kwargs = { + 'betas': training_config.optimizer.betas, + 'eps': training_config.optimizer.eps, + 'weight_decay': training_config.optimizer.weight_decay, + } + + # Print optimizer setup details + print(f"\n{'='*70}") + print("OPTIMIZER SETUP - HYBRID MUON + ADAMW") + print(f"{'='*70}") + print("\n[MUON - 2D Weight Matrices (Orthogonal Optimization)]") + print(f"Linear 2D weights: {muon_params/1e6:>8.2f}M") + print(f" Learning Rate: {training_config.optimizer.learning_rate}") + print(f"\n[ADAMW - 1D Parameters (Adaptive Moments)]") + print(f"Biases, embeddings, norms: {adamw_params/1e6:>8.2f}M") + print(f" Learning Rate: {training_config.optimizer.adamw_lr}") + print(f"{'─'*70}") + print(f"Total (Muon): {muon_params/1e6:>8.2f}M") + print(f"Total (AdamW): {adamw_params/1e6:>8.2f}M") + print(f"Total (All): {total_params/1e6:>8.2f}M") + print(f"{'─'*70}") + print(f"Hyperparameters:") + print(f" Weight Decay: {training_config.optimizer.weight_decay}") + print(f" Betas (AdamW): {training_config.optimizer.betas}") + print(f" Epsilon: {training_config.optimizer.eps}") + print(f"{'='*70}\n") + + # Create and return optimizer + return HybridMuonAdamW(muon_params_groups, adamw_params_group, muon_kwargs, adamw_kwargs) diff --git a/code/TaoTrain/src/taoTrain/optimizers/registry.py b/code/TaoTrain/src/taoTrain/optimizers/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..2964c6d4163759d99a3ec46cad12be4eb73bd42c --- /dev/null +++ b/code/TaoTrain/src/taoTrain/optimizers/registry.py @@ -0,0 +1,77 @@ +"""Optimizer registry and factory for instantiating optimizers.""" + +from typing import Dict, Type, Callable, Any +import torch.optim as optim +from taoTrain.core.base import BaseModel +from taoTrain.config import TrainingConfig, OptimizerEnum + + +# Global registry for optimizers +_OPTIMIZER_REGISTRY: Dict[str, Callable] = {} + + +def register_optimizer(name: str): + """ + Decorator to register a custom optimizer factory function. + + Args: + name: Name of the optimizer (e.g., 'adamw', 'adam', 'sgd') + """ + def decorator(fn: Callable) -> Callable: + if name in _OPTIMIZER_REGISTRY: + raise ValueError(f"Optimizer '{name}' is already registered") + _OPTIMIZER_REGISTRY[name] = fn + return fn + return decorator + + +def get_registered_optimizers() -> Dict[str, Callable]: + """Get all registered optimizer factory functions.""" + return _OPTIMIZER_REGISTRY.copy() + + +def get_optimizer( + model: BaseModel, + config: TrainingConfig, +) -> optim.Optimizer: + """ + Create an optimizer instance from config. + + Args: + model: Model to optimize + config: TrainingConfig with optimizer configuration + + Returns: + Optimizer instance + + Raises: + ValueError: If optimizer type is not registered + """ + # Handle both enum and string values + optimizer_type = config.optimizer.optimizer_type + if isinstance(optimizer_type, str): + optimizer_name = optimizer_type + else: + optimizer_name = optimizer_type.value + + if optimizer_name not in _OPTIMIZER_REGISTRY: + raise ValueError( + f"Unknown optimizer: {optimizer_name}. " + f"Available: {list(_OPTIMIZER_REGISTRY.keys())}" + ) + + factory_fn = _OPTIMIZER_REGISTRY[optimizer_name] + return factory_fn(model, config) + + +def register_builtin_optimizers(): + """Register all built-in optimizers.""" + # Import here to trigger decorator registration (avoid circular imports) + from . import adamw # noqa: F401 + from . import adam # noqa: F401 + from . import sgd # noqa: F401 + from . import hybrid_muon_adamw # noqa: F401 + + +# Auto-register built-in optimizers when module is imported +register_builtin_optimizers() diff --git a/code/TaoTrain/src/taoTrain/optimizers/sgd.py b/code/TaoTrain/src/taoTrain/optimizers/sgd.py new file mode 100644 index 0000000000000000000000000000000000000000..2985f5fde9de9d9316af04714eeee7f3cc72be64 --- /dev/null +++ b/code/TaoTrain/src/taoTrain/optimizers/sgd.py @@ -0,0 +1,63 @@ +"""SGD optimizer factory.""" + +import torch.optim as optim +from taoTrain.core.base import BaseModel +from taoTrain.config import TrainingConfig +from .registry import register_optimizer + + +def _separate_parameters(model: BaseModel) -> tuple[list, list]: + """ + Separate model parameters into decay and no-decay groups. + + Args: + model: Model instance + + Returns: + Tuple of (decay_params, no_decay_params) + """ + decay_params = [] + no_decay_params = [] + + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + + # Apply weight decay to all params except biases and layer norms + if 'bias' in name or 'norm' in name: + no_decay_params.append(param) + else: + decay_params.append(param) + + return decay_params, no_decay_params + + +@register_optimizer("sgd") +def create_sgd(model: BaseModel, config: TrainingConfig) -> optim.SGD: + """ + Create SGD optimizer with weight decay applied selectively. + + Args: + model: Model instance + config: TrainingConfig + + Returns: + SGD optimizer instance + """ + optimizer_config = config.optimizer + + # Separate parameters for weight decay + decay_params, no_decay_params = _separate_parameters(model) + + param_groups = [ + {"params": decay_params, "weight_decay": optimizer_config.weight_decay}, + {"params": no_decay_params, "weight_decay": 0.0}, + ] + + optimizer = optim.SGD( + param_groups, + lr=optimizer_config.learning_rate, + momentum=optimizer_config.betas[0], # Use first beta as momentum + ) + + return optimizer diff --git a/code/TaoTrain/src/taoTrain/schedulers/__init__.py b/code/TaoTrain/src/taoTrain/schedulers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bb822276efdbbcadaed7a721d67243464abbf687 --- /dev/null +++ b/code/TaoTrain/src/taoTrain/schedulers/__init__.py @@ -0,0 +1,13 @@ +"""Learning rate scheduler registry and factories.""" + +from .registry import ( + register_scheduler, + get_scheduler, + get_registered_schedulers, +) + +__all__ = [ + "register_scheduler", + "get_scheduler", + "get_registered_schedulers", +] diff --git a/code/TaoTrain/src/taoTrain/schedulers/constant.py b/code/TaoTrain/src/taoTrain/schedulers/constant.py new file mode 100644 index 0000000000000000000000000000000000000000..39e15af13444e0a61bc03bae1c7bad9b7e2b2d92 --- /dev/null +++ b/code/TaoTrain/src/taoTrain/schedulers/constant.py @@ -0,0 +1,44 @@ +"""Constant learning rate scheduler with optional warmup.""" + +import torch.optim as optim +from torch.optim.lr_scheduler import LambdaLR +from taoTrain.config import TrainingConfig +from .registry import register_scheduler + + +@register_scheduler("constant") +def create_constant( + optimizer: optim.Optimizer, + config: TrainingConfig, + num_training_steps: int, +) -> LambdaLR: + """ + Create a constant learning rate scheduler with optional linear warmup. + + Linearly increases learning rate from 0 to peak over warmup steps, + then keeps it constant for the rest of training. + + Args: + optimizer: Optimizer instance + config: TrainingConfig with scheduler configuration + num_training_steps: Total number of training steps + + Returns: + LambdaLR scheduler instance + """ + scheduler_config = config.scheduler + + # Determine warmup steps + if scheduler_config.warmup_steps > 0: + warmup_steps = scheduler_config.warmup_steps + else: + warmup_steps = int(num_training_steps * scheduler_config.warmup_ratio) + + def lr_lambda(step): + """Constant learning rate with optional warmup.""" + if step < warmup_steps: + # Linear warmup + return float(step) / float(max(1, warmup_steps)) + return 1.0 + + return LambdaLR(optimizer, lr_lambda, last_epoch=scheduler_config.last_epoch) diff --git a/code/TaoTrain/src/taoTrain/schedulers/cosine_warmup.py b/code/TaoTrain/src/taoTrain/schedulers/cosine_warmup.py new file mode 100644 index 0000000000000000000000000000000000000000..b70c611b989637fbb1e93054fc3ef9f6e1dc0254 --- /dev/null +++ b/code/TaoTrain/src/taoTrain/schedulers/cosine_warmup.py @@ -0,0 +1,77 @@ +"""Cosine annealing with warmup learning rate scheduler.""" + +import math +import torch.optim as optim +from torch.optim.lr_scheduler import LambdaLR +from taoTrain.config import TrainingConfig +from .registry import register_scheduler + + +@register_scheduler("cosineWarmup") +def create_cosine_warmup( + optimizer: optim.Optimizer, + config: TrainingConfig, + num_training_steps: int, +) -> LambdaLR: + """ + Create a cosine annealing scheduler with optional linear warmup, steady phase, and decay. + + Three-phase schedule: + 1. Linear warmup: 0 → 1.0 (warmup_steps) + 2. Steady phase: 1.0 (plateau at peak LR) + 3. Cosine decay: 1.0 → min_lr_ratio + + Args: + optimizer: Optimizer instance + config: TrainingConfig with scheduler configuration: + - warmup_steps: linear warmup duration (overrides warmup_ratio if > 0) + - warmup_ratio: warmup as fraction of total steps (default 0.1) + - steady_ratio: steady phase as fraction of total steps (default 0.0) + - min_lr_ratio: minimum LR at end as fraction of peak (default 0.0) + num_training_steps: Total number of training steps + + Returns: + LambdaLR scheduler instance + """ + scheduler_config = config.scheduler + + # Determine warmup steps + if scheduler_config.warmup_steps > 0: + warmup_steps = scheduler_config.warmup_steps + else: + warmup_steps = int(num_training_steps * scheduler_config.warmup_ratio) + + # Determine steady phase steps + steady_steps = int(num_training_steps * scheduler_config.steady_ratio) + + # Remaining steps for cosine decay + decay_steps = num_training_steps - warmup_steps - steady_steps + + min_lr_ratio = scheduler_config.min_lr_ratio + num_cycles = scheduler_config.num_cycles + + print(f"✓ CosineWarmup scheduler: warmup={warmup_steps}, steady={steady_steps}, decay={decay_steps} (total={num_training_steps})") + print(f" min_lr_ratio={min_lr_ratio}, num_cycles={num_cycles}") + + def lr_lambda(step): + """Three-phase LR schedule: warmup → steady → cosine decay.""" + if step < warmup_steps: + # Phase 1: Linear warmup from 0 to 1.0 + return float(step) / float(max(1, warmup_steps)) + + elif step < warmup_steps + steady_steps: + # Phase 2: Steady at peak LR (1.0) + return 1.0 + + else: + # Phase 3: Cosine decay from 1.0 to min_lr_ratio + decay_step = step - warmup_steps - steady_steps + progress = float(decay_step) / float(max(1, decay_steps)) + + # Cosine annealing: 0.5 * (1 + cos(π * progress)) + cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress)) + + # Scale to reach min_lr_ratio at the end + return cosine_decay * (1.0 - min_lr_ratio) + min_lr_ratio + + return LambdaLR(optimizer, lr_lambda, last_epoch=scheduler_config.last_epoch) diff --git a/code/TaoTrain/src/taoTrain/schedulers/linear_warmup.py b/code/TaoTrain/src/taoTrain/schedulers/linear_warmup.py new file mode 100644 index 0000000000000000000000000000000000000000..30ac6f51ca94c59dcf4cd66c01e973e6f0a793cc --- /dev/null +++ b/code/TaoTrain/src/taoTrain/schedulers/linear_warmup.py @@ -0,0 +1,43 @@ +"""Linear warmup learning rate scheduler.""" + +import torch.optim as optim +from torch.optim.lr_scheduler import LambdaLR +from taoTrain.config import TrainingConfig +from .registry import register_scheduler + + +@register_scheduler("linearWarmup") +def create_linear_warmup( + optimizer: optim.Optimizer, + config: TrainingConfig, + num_training_steps: int, +) -> LambdaLR: + """ + Create a linear warmup scheduler. + + Linearly increases learning rate from 0 to peak over warmup steps, + then keeps it constant. + + Args: + optimizer: Optimizer instance + config: TrainingConfig with scheduler configuration + num_training_steps: Total number of training steps + + Returns: + LambdaLR scheduler instance + """ + scheduler_config = config.scheduler + + # Determine warmup steps + if scheduler_config.warmup_steps > 0: + warmup_steps = scheduler_config.warmup_steps + else: + warmup_steps = int(num_training_steps * scheduler_config.warmup_ratio) + + def lr_lambda(step): + """Linear warmup learning rate schedule.""" + if step < warmup_steps: + return float(step) / float(max(1, warmup_steps)) + return 1.0 + + return LambdaLR(optimizer, lr_lambda, last_epoch=scheduler_config.last_epoch) diff --git a/code/TaoTrain/src/taoTrain/schedulers/registry.py b/code/TaoTrain/src/taoTrain/schedulers/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..519ea047f0d2be35c34fc218d1881942515089ad --- /dev/null +++ b/code/TaoTrain/src/taoTrain/schedulers/registry.py @@ -0,0 +1,78 @@ +"""Scheduler registry and factory for instantiating learning rate schedulers.""" + +from typing import Dict, Callable, Optional +import torch.optim as optim +from torch.optim.lr_scheduler import LambdaLR +from taoTrain.config import TrainingConfig, SchedulerEnum + + +# Global registry for schedulers +_SCHEDULER_REGISTRY: Dict[str, Callable] = {} + + +def register_scheduler(name: str): + """ + Decorator to register a custom scheduler factory function. + + Args: + name: Name of the scheduler (e.g., 'linearWarmup', 'cosineWarmup', 'constant') + """ + def decorator(fn: Callable) -> Callable: + if name in _SCHEDULER_REGISTRY: + raise ValueError(f"Scheduler '{name}' is already registered") + _SCHEDULER_REGISTRY[name] = fn + return fn + return decorator + + +def get_registered_schedulers() -> Dict[str, Callable]: + """Get all registered scheduler factory functions.""" + return _SCHEDULER_REGISTRY.copy() + + +def get_scheduler( + optimizer: optim.Optimizer, + config: TrainingConfig, + num_training_steps: int, +) -> LambdaLR: + """ + Create a learning rate scheduler instance from config. + + Args: + optimizer: Optimizer to schedule learning rate for + config: TrainingConfig with scheduler configuration + num_training_steps: Total number of training steps + + Returns: + Learning rate scheduler instance + + Raises: + ValueError: If scheduler type is not registered + """ + # Handle both enum and string values + scheduler_type = config.scheduler.scheduler_type + if isinstance(scheduler_type, str): + scheduler_name = scheduler_type + else: + scheduler_name = scheduler_type.value + + if scheduler_name not in _SCHEDULER_REGISTRY: + raise ValueError( + f"Unknown scheduler: {scheduler_name}. " + f"Available: {list(_SCHEDULER_REGISTRY.keys())}" + ) + + factory_fn = _SCHEDULER_REGISTRY[scheduler_name] + return factory_fn(optimizer, config, num_training_steps) + + +def register_builtin_schedulers(): + """Register all built-in schedulers.""" + # Import here to trigger decorator registration (avoid circular imports) + from . import linear_warmup # noqa: F401 + from . import cosine_warmup # noqa: F401 + from . import constant # noqa: F401 + + +# Auto-register built-in schedulers when module is imported +register_builtin_schedulers() diff --git a/code/TaoTrain/src/taoTrain/tokenizers/__init__.py b/code/TaoTrain/src/taoTrain/tokenizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..676788a50607b8f6f96950fe9cca81d1b5d63b3f --- /dev/null +++ b/code/TaoTrain/src/taoTrain/tokenizers/__init__.py @@ -0,0 +1,5 @@ +"""Tokenizer utilities and training.""" + +from .trainer import TokenizerTrainer + +__all__ = ["TokenizerTrainer"] diff --git a/code/TaoTrain/src/taoTrain/tokenizers/trainer.py b/code/TaoTrain/src/taoTrain/tokenizers/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..657572bbfbdd61784e260803d431a8648154d38c --- /dev/null +++ b/code/TaoTrain/src/taoTrain/tokenizers/trainer.py @@ -0,0 +1,249 @@ +"""Tokenizer training utilities.""" + +import os +import json +import tempfile +from pathlib import Path +from typing import Optional, Dict, Any +from taoTrain.data.chunk_manager import ChunkManager + + +class TokenizerTrainer: + """Train SentencePiece tokenizers from JSONL data.""" + + @staticmethod + def train_from_config(config: "TokenizerConfig") -> dict: # type: ignore + """ + Train a tokenizer from a TokenizerConfig object. + + Args: + config: TokenizerConfig instance + + Returns: + Dict with paths to generated tokenizer files + """ + # Build special tokens string for SentencePiece if provided + user_defined_symbols = None + if config.special_tokens: + # Sort by ID and format as comma-separated tokens + sorted_tokens = sorted(config.special_tokens.items(), key=lambda x: x[1]) + user_defined_symbols = ','.join([token for token, _ in sorted_tokens]) + + return TokenizerTrainer.train_sentencepiece( + jsonl_path=config.jsonl_path, + output_dir=config.output_dir, + vocab_size=config.vocab_size, + model_type=config.model_type, + character_coverage=config.character_coverage, + unk_id=config.unk_id, + bos_id=config.bos_id, + eos_id=config.eos_id, + pad_id=config.pad_id, + tokenizer_prefix=config.tokenizer_prefix, + text_field=config.text_field, + max_samples=config.max_samples, + user_defined_symbols=user_defined_symbols, + ) + + @staticmethod + def train_sentencepiece( + jsonl_path: str, + output_dir: str = "tokenizers", + vocab_size: int = 50000, + model_type: str = "unigram", + character_coverage: float = 0.9995, + unk_id: int = 0, + bos_id: int = 1, + eos_id: int = 2, + pad_id: int = 3, + tokenizer_prefix: Optional[str] = None, + text_field: str = "text", + max_samples: Optional[int] = None, + user_defined_symbols: Optional[str] = None, + ) -> dict: + """ + Train a SentencePiece tokenizer from JSONL data. + + Args: + jsonl_path: Path to JSONL file containing text data + output_dir: Directory to save tokenizer files + vocab_size: Vocabulary size for the tokenizer + model_type: Model type (unigram, bpe, char, word) + character_coverage: Character coverage for SentencePiece + unk_id: Unknown token ID + bos_id: Beginning of sentence token ID + eos_id: End of sentence token ID + pad_id: Padding token ID + tokenizer_prefix: Prefix for tokenizer model files (default: model_type) + text_field: Field name in JSONL for text data (default: "text") + max_samples: Limit training to first N samples (optional) + user_defined_symbols: Custom special tokens as comma-separated string (optional) + + Returns: + Dict with paths to generated tokenizer files + + Raises: + ImportError: If SentencePiece is not installed + FileNotFoundError: If JSONL file doesn't exist + ValueError: If JSONL file is invalid or empty + """ + try: + import sentencepiece as spm + except ImportError: + raise ImportError( + "SentencePiece not installed. Install with: pip install sentencepiece" + ) + + # Validate paths + jsonl_path = Path(jsonl_path) + if not jsonl_path.exists(): + raise FileNotFoundError(f"JSONL file not found: {jsonl_path}") + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Set tokenizer prefix + if tokenizer_prefix is None: + tokenizer_prefix = model_type + + # Extract text data from JSONL using ChunkManager for large files + if max_samples: + print(f"📖 Reading JSONL file (up to {max_samples:,} samples): {jsonl_path}") + else: + print(f"📖 Reading JSONL file: {jsonl_path}") + + # Use ChunkManager for efficient streaming (with metadata caching enabled) + chunk_manager = ChunkManager( + jsonl_path, + chunk_size_gb=5.0, + enable_metadata_cache=True, + chunk_cache_dir=".cache/chunks" + ) + + # Write text to temporary file for SentencePiece (use streaming) + with tempfile.NamedTemporaryFile( + mode='w', + suffix='.txt', + delete=False, + encoding='utf-8' + ) as tmp: + text_count = 0 + + # Process chunks one at a time + for chunk_num in range(chunk_manager.num_chunks): + print(f" - Processing chunk {chunk_num + 1}/{chunk_manager.num_chunks}...") + chunk_examples = chunk_manager.read_chunk(chunk_num) + + for obj in chunk_examples: + # Check if we've reached max_samples limit + if max_samples and text_count >= max_samples: + break + + # Extract text from specified field or try common field names + text = None + if text_field in obj and isinstance(obj[text_field], str): + text = obj[text_field] + else: + # Fallback to common field names + for field in ['text', 'content', 'data', 'body']: + if field in obj and isinstance(obj[field], str): + text = obj[field] + break + + if text: + # Clean text: remove newlines, extra spaces + clean_text = ' '.join(text.split()) + tmp.write(clean_text + '\n') + text_count += 1 + + # Break outer loop if max_samples reached + if max_samples and text_count >= max_samples: + print(f"Reached max_samples limit of {max_samples:,}. Stopping data processing.") + break + + tmp_path = tmp.name + + if text_count == 0: + os.remove(tmp_path) + raise ValueError("No valid text data found in JSONL file") + + sample_info = f"{text_count:,} samples" if max_samples else f"{text_count:,} lines" + print(f"✓ Processed {sample_info} with text data from {chunk_manager.num_chunks} chunks") + + try: + # Train SentencePiece model + print(f"🔧 Training SentencePiece {model_type} tokenizer...") + print(f" - Vocabulary size: {vocab_size}") + print(f" - Character coverage: {character_coverage}") + if user_defined_symbols: + print(f" - Special tokens: {user_defined_symbols}") + + model_path = output_dir / tokenizer_prefix + + # Prepare training arguments + train_kwargs = { + 'input': tmp_path, + 'model_prefix': str(model_path), + 'vocab_size': vocab_size, + 'model_type': model_type, + 'character_coverage': character_coverage, + 'unk_id': unk_id, + 'bos_id': bos_id, + 'eos_id': eos_id, + 'pad_id': pad_id, + # Additional options + 'normalization_rule_name': 'identity', + 'split_digits': True + } + + # Add user-defined symbols if provided + if user_defined_symbols: + train_kwargs['user_defined_symbols'] = user_defined_symbols + + spm.SentencePieceTrainer.train(**train_kwargs) + + model_file = model_path.with_suffix('.model') + vocab_file = model_path.with_suffix('.vocab') + + if model_file.exists() and vocab_file.exists(): + print(f"✅ Tokenizer trained successfully!") + print(f" - Model: {model_file}") + print(f" - Vocab: {vocab_file}") + + return { + "model_file": str(model_file), + "vocab_file": str(vocab_file), + "output_dir": str(output_dir), + "vocab_size": vocab_size, + "model_type": model_type, + } + else: + raise RuntimeError("SentencePiece training didn't produce output files") + + finally: + # Clean up temporary file + if os.path.exists(tmp_path): + os.remove(tmp_path) + + @staticmethod + def validate_tokenizer(model_path: str) -> bool: + """ + Validate that a SentencePiece tokenizer file is valid. + + Args: + model_path: Path to .model file + + Returns: + True if valid, False otherwise + """ + try: + import sentencepiece as spm + sp = spm.SentencePieceProcessor() + sp.Load(model_path) + # Try a simple encode/decode + test_text = "Hello world" + tokens = sp.encode(test_text, out_type=int) + decoded = sp.decode(tokens) + return len(tokens) > 0 + except Exception: + return False diff --git a/code/Taotern_LLM_Experiments/docs/Taotern_Documentation_AI_Architecture.zip b/code/Taotern_LLM_Experiments/docs/Taotern_Documentation_AI_Architecture.zip new file mode 100644 index 0000000000000000000000000000000000000000..6be913adc7ba86afa1e7f944fc5a441e6c61948d --- /dev/null +++ b/code/Taotern_LLM_Experiments/docs/Taotern_Documentation_AI_Architecture.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c8aa7c1a1f7bdaae0f9bd62dc9a8e47f8623dae9db70dc16431127b0e1d81443 +size 59482 diff --git a/code/Taotern_SSM/Gamma Distributed Ternary HiPPO.pdf b/code/Taotern_SSM/Gamma Distributed Ternary HiPPO.pdf new file mode 100644 index 0000000000000000000000000000000000000000..34295ee95659a44da78228d8f0565a329a3bca25 --- /dev/null +++ b/code/Taotern_SSM/Gamma Distributed Ternary HiPPO.pdf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:914c54870bc267c4a0182f18056b5fade98e942fe83a7bbbca3dafb7c507bd2d +size 117289 diff --git a/tokenizer.model b/tokenizer.model new file mode 100644 index 0000000000000000000000000000000000000000..ebbc9f06d6aeae5eb2c9e5210d7ee9f4eee72c9e --- /dev/null +++ b/tokenizer.model @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8a501d7d4bee4e7db1833ce27e613099a6042feeb947c469901a03e6f7ecae08 +size 144164 diff --git a/tokenizer/tokenizer.model b/tokenizer/tokenizer.model new file mode 100644 index 0000000000000000000000000000000000000000..ebbc9f06d6aeae5eb2c9e5210d7ee9f4eee72c9e --- /dev/null +++ b/tokenizer/tokenizer.model @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8a501d7d4bee4e7db1833ce27e613099a6042feeb947c469901a03e6f7ecae08 +size 144164