Text Generation
Transformers
PyTorch
English
taonet_mini_t2
taonet
taotern
ssm
state-space-model
dplr
custom_code
experimental
Instructions to use TaoTern/TaoNet-mini-T2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use TaoTern/TaoNet-mini-T2 with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="TaoTern/TaoNet-mini-T2", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("TaoTern/TaoNet-mini-T2", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use TaoTern/TaoNet-mini-T2 with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm
# Start the vLLM server:
vllm serve "TaoTern/TaoNet-mini-T2"
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "TaoTern/TaoNet-mini-T2",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
Use Docker
docker model run hf.co/TaoTern/TaoNet-mini-T2
- SGLang
How to use TaoTern/TaoNet-mini-T2 with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang
# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "TaoTern/TaoNet-mini-T2" \
    --host 0.0.0.0 \
    --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "TaoTern/TaoNet-mini-T2",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
Use Docker images
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "TaoTern/TaoNet-mini-T2" \
        --host 0.0.0.0 \
        --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "TaoTern/TaoNet-mini-T2",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
- Docker Model Runner
How to use TaoTern/TaoNet-mini-T2 with Docker Model Runner:
docker model run hf.co/TaoTern/TaoNet-mini-T2
| """Pydantic configuration schemas for TaoTrain.""" | |
| from enum import Enum | |
| from typing import Optional, Literal | |
| from pathlib import Path | |
| import json | |
| from pydantic import BaseModel as PydanticBaseModel, Field, validator | |
| import yaml | |
| # ============================================================================ | |
| # Enums | |
| # ============================================================================ | |
class DataTypeEnum(str, Enum):
    """Numeric precision choices for training.

    Because of the ``str`` mixin, members compare equal to their string
    values, so ``DataTypeEnum("bfloat16")`` round-trips from config files.
    """

    FLOAT32 = "float32"    # 32-bit float
    FLOAT16 = "float16"    # 16-bit float
    BFLOAT16 = "bfloat16"  # bfloat16
class OptimizerEnum(str, Enum):
    """Optimizer identifiers accepted by ``OptimizerConfig.optimizer_type``."""

    ADAM = "adam"
    ADAMW = "adamw"
    SGD = "sgd"
    # Muon for 2D weights combined with AdamW for 1D parameters
    # (see the learning_rate / adamw_lr descriptions in OptimizerConfig).
    HYBRID_MUON_ADAMW = "hybrid_muon_adamw"
class ModelArchitectureEnum(str, Enum):
    """Architecture selector for ``ModelConfig.architecture_type``."""

    TRANSFORMER = "transformer"      # plain transformer
    TAONET = "taonet"                # TaoNet with MLA attention
    TAONET_SSM = "taonet_ssm"        # TaoNet with an SSM mixer in place of attention
    TAONET_HYBRID = "taonet_hybrid"  # mix of attention and SSM blocks
class SchedulerEnum(str, Enum):
    """Learning-rate schedule identifiers (note the camelCase string values)."""

    LINEAR_WARMUP = "linearWarmup"
    COSINE_WARMUP = "cosineWarmup"
    CONSTANT = "constant"
class RLMethodEnum(str, Enum):
    """Reinforcement-learning algorithm selector for ``RLConfig.rl_method``."""

    PPO = "ppo"  # Proximal Policy Optimization
    DPO = "dpo"  # Direct Preference Optimization
class TrainingModeEnum(str, Enum):
    """Training stage; selects the config class used by ``load_config``."""

    PRETRAIN = "pretrain"
    SFT = "sft"  # supervised fine-tuning
    RL = "rl"    # reinforcement learning
| # ============================================================================ | |
| # Base Configs | |
| # ============================================================================ | |
class BaseConfig(PydanticBaseModel):
    """Base Pydantic model for TaoTrain configs with JSON/YAML helpers.

    Subclasses inherit ``to_dict``/``to_json_str`` plus ``save_yaml``,
    ``save_json`` and the ``load_yaml``/``load_json`` alternate constructors.
    """

    # NOTE(review): the file uses pydantic v2 APIs (``model_dump``), where a
    # nested ``Config`` class is deprecated in favour of ``model_config``;
    # kept as-is for compatibility with subclasses.
    class Config:
        """Pydantic config."""

        arbitrary_types_allowed = True

    def to_dict(self) -> dict:
        """Return the config as a JSON-compatible dict (enums become their values)."""
        return self.model_dump(mode='json')

    def to_json_str(self) -> str:
        """Return the config serialized as a JSON string indented by 2 spaces."""
        return json.dumps(self.to_dict(), indent=2)

    def save_yaml(self, path: str | Path) -> None:
        """Write the config to a YAML file, creating parent directories as needed.

        Keys are written in field-definition order (``sort_keys=False``) using
        block style (``default_flow_style=False``).
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, 'w', encoding='utf-8') as f:
            yaml.dump(self.to_dict(), f, default_flow_style=False, sort_keys=False)

    def save_json(self, path: str | Path) -> None:
        """Write the config to a JSON file, creating parent directories as needed."""
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, 'w', encoding='utf-8') as f:
            f.write(self.to_json_str())

    # Both loaders must be classmethods: callers invoke them on the class
    # (e.g. ``PretrainConfig.load_yaml(path)``); without the decorator the
    # path would be bound to ``cls`` and the call would fail.
    @classmethod
    def load_yaml(cls, path: str | Path) -> "BaseConfig":
        """Build an instance of ``cls`` from a YAML file.

        An empty document is treated as an empty mapping, so pydantic reports
        any missing required fields instead of failing on ``None``.

        Raises:
            TypeError: If the top-level YAML value is not a mapping.
        """
        with open(path, encoding='utf-8') as f:
            data = yaml.safe_load(f)
        if data is None:
            data = {}
        if not isinstance(data, dict):
            raise TypeError(
                f"Config file {path} must contain a mapping at the top level, "
                f"got {type(data).__name__}"
            )
        return cls(**data)

    @classmethod
    def load_json(cls, path: str | Path) -> "BaseConfig":
        """Build an instance of ``cls`` from a JSON file.

        A top-level JSON ``null`` is treated as an empty object.

        Raises:
            TypeError: If the top-level JSON value is not an object.
        """
        with open(path, encoding='utf-8') as f:
            data = json.load(f)
        if data is None:
            data = {}
        if not isinstance(data, dict):
            raise TypeError(
                f"Config file {path} must contain an object at the top level, "
                f"got {type(data).__name__}"
            )
        return cls(**data)
| # ============================================================================ | |
| # Model Config | |
| # ============================================================================ | |
class ModelConfig(BaseConfig):
    """Model architecture hyperparameters.

    Covers the plain transformer, TaoNet (MLA attention) and the TaoNet
    SSM/hybrid variants. Fields whose description says they are specific to
    one architecture are presumably ignored by the others (TODO confirm in the
    model builders). ``head_dim`` and ``intermediate_dim`` are filled in from
    ``hidden_dim``/``num_heads`` when left unset.
    """

    architecture_type: ModelArchitectureEnum = Field(
        default=ModelArchitectureEnum.TRANSFORMER,
        description="Type of model architecture"
    )

    # Core transformer dimensions
    vocab_size: int = Field(default=50257, description="Vocabulary size")
    hidden_dim: int = Field(default=768, description="Hidden dimension")
    num_layers: int = Field(default=12, description="Number of transformer blocks")
    num_heads: int = Field(default=12, description="Number of attention heads")
    head_dim: Optional[int] = Field(
        default=None,
        description="Head dimension (defaults to hidden_dim // num_heads)"
    )
    intermediate_dim: Optional[int] = Field(
        default=None,
        description="FFN intermediate dimension (defaults to 4 * hidden_dim)"
    )
    dropout: float = Field(default=0.1, description="Dropout rate")
    max_seq_length: int = Field(default=2048, description="Maximum sequence length")

    # TaoNet (DeepSeek-style MLA) options
    d_latent_kv: Optional[int] = Field(
        default=None,
        description="KV compression dimension for MLA (defaults to 3/4 * hidden_dim). Only used for taonet architecture."
    )
    d_rope: Optional[int] = Field(
        default=None,
        description="RoPE dimension per head (defaults to hidden_dim // num_heads). Only used for taonet architecture."
    )
    gqa_groups: int = Field(
        default=1,
        description="Grouped Query Attention groups (1 = standard MLA, >1 = GQA). Only used for taonet architecture."
    )
    hidden_dim_ff: Optional[int] = Field(
        default=None,
        description="Feed-forward intermediate dimension (defaults to 4 * hidden_dim)."
    )
    use_factorized_embedding: bool = Field(
        default=False,
        description="Use low-rank factorized embedding instead of standard embedding (reduces params). Only for taonet."
    )
    d_embed_rank: int = Field(
        default=96,
        description="Rank dimension for factorized embedding. Only used if use_factorized_embedding=True."
    )

    # RoPE base scale and optional YaRN context-length extension
    rope_scale: float = Field(
        default=40.0,
        description="Base RoPE scale factor (default: 40.0). Controls position frequency base."
    )
    yarn_enabled: bool = Field(
        default=False,
        description="Enable YaRN (Yet another RoPE eXtension) for context length interpolation."
    )
    yarn_alpha: float = Field(
        default=1.0,
        description="YaRN interpolation smoothness (1.0=smooth, <1.0=aggressive, >1.0=conservative). Only used if yarn_enabled=True."
    )

    # TaoNet-SSM: SSM mixer that replaces MLA attention
    ssm_core: Literal["gamma_s4", "dplr"] = Field(
        default="gamma_s4",
        description="SSM core used by taonet_ssm. Use dplr for the ternary-aware DPLR SSM."
    )
    ssm_hidden_dim: Optional[int] = Field(
        default=None,
        description="SSM hidden/state dimension for taonet_ssm. Defaults to d_latent_kv or hidden_dim."
    )
    ssm_mixer_dim: Optional[int] = Field(
        default=None,
        description="Channel dimension processed by the SSM mixer. Defaults to hidden_dim; smaller values use an input/output projection bottleneck."
    )
    ssm_num_lanes: int = Field(
        default=1,
        description="Number of independent SSM lanes inside each SSM mixer. Multiple lanes add SSM capacity with cheap elementwise combination."
    )
    ssm_lane_combine: Literal["mean", "channel"] = Field(
        default="mean",
        description="How to combine multiple SSM lanes. Channel uses learned per-lane/per-channel elementwise weights."
    )
    ssm_lane_mode: Literal["full", "split"] = Field(
        default="full",
        description="Whether each SSM lane processes the full mixer dimension or a disjoint split of the mixer channels."
    )
    ssm_split_mix: Literal["none", "hadamard"] = Field(
        default="none",
        description="Optional ternary-friendly cross-lane mixer for split SSM lanes."
    )
    ssm_rank: int = Field(
        default=1,
        description="Low-rank correction rank for ssm_core=dplr."
    )
    ssm_max_low_rank_scale: float = Field(
        default=0.1,
        description="Maximum low-rank correction scale for ssm_core=dplr."
    )
    ssm_finite_tail_correction: bool = Field(
        default=True,
        description="Enable exact finite-length tail correction for ssm_core=dplr. Disable for the faster approximate DPLR path."
    )
    ssm_discretization: Literal["bilinear", "zoh", "euler"] = Field(
        default="bilinear",
        description="Discretization used by the Gamma SSM mixer."
    )
    ssm_kernel_mode: Literal["auto", "recurrent", "conv", "conv_transfer"] = Field(
        default="auto",
        description="Gamma SSM execution path. Use auto/conv for full-sequence GPU training, conv_transfer to materialize frequency transfers, recurrent for step-wise tests."
    )
    ssm_kernel_threshold: int = Field(
        default=64,
        description="Minimum sequence length for auto mode to use the convolutional Gamma SSM path."
    )
    ssm_dt_min: float = Field(default=1e-3, description="Minimum learned SSM timestep.")
    ssm_dt_max: float = Field(default=1e-1, description="Maximum learned SSM timestep.")
    ssm_dt_init: float = Field(default=1e-2, description="Initial learned SSM timestep.")
    ssm_use_d: bool = Field(default=True, description="Enable direct skip term D in the Gamma SSM.")
    ssm_activation: Literal["gelu", "silu", "identity", "linear"] = Field(
        default="gelu",
        description="Activation applied to the Gamma SSM branch output."
    )
    ssm_gate: bool = Field(default=True, description="Enable output gate on the Gamma SSM branch.")
    ssm_input_gate: bool = Field(default=True, description="Enable input gate before the Gamma SSM.")
    ssm_gate_type: Literal["dense", "channel"] = Field(
        default="dense",
        description="Gate implementation for enabled SSM input/output gates. Channel gates are elementwise and ternary-friendly."
    )
    ssm_use_padding_mask: bool = Field(
        default=False,
        description="Apply dataset padding masks inside the SSM. Disabled by default so training can use the convolutional path."
    )
    ssm_layer_scale_init: float = Field(
        default=0.1,
        description="Initial layer-scale multiplier for the Gamma SSM branch."
    )
    ssm_branch_rms_norm: bool = Field(
        default=False,
        description="Normalize the SSM residual branch to unit RMS before layer-scale. Useful for stabilizing deep SSM/hybrid runs."
    )
    ssm_branch_rms_eps: float = Field(
        default=1e-6,
        description="Numerical epsilon for optional SSM branch RMS normalization."
    )
    ssm_branch_clip_value: Optional[float] = Field(
        default=None,
        description="Optional symmetric clamp applied to the SSM residual branch after layer-scale. None disables clamping."
    )

    # Residual-stream RMS controls
    block_residual_rms_norm: bool = Field(
        default=False,
        description="Normalize the residual stream RMS after block residual additions. Intended for stabilizing deep SSM/hybrid experiments."
    )
    block_residual_rms_target: float = Field(
        default=1.0,
        description="Target per-token RMS when block_residual_rms_norm is enabled."
    )
    block_residual_rms_cap: Optional[float] = Field(
        default=None,
        description="Optional per-token RMS cap for the residual stream. Unlike block_residual_rms_norm, this only scales down tokens whose RMS exceeds the cap."
    )
    block_residual_rms_eps: float = Field(
        default=1e-6,
        description="Numerical epsilon for optional block residual RMS normalization."
    )

    # Optional one-token local shift/register branch for the SSM mixer
    ssm_local_shift: bool = Field(
        default=False,
        description="Add a cheap one-token causal shift/register branch to the taonet_ssm mixer."
    )
    ssm_local_shift_init: float = Field(
        default=0.1,
        description="Initial scalar weight for the optional one-token local shift/register branch."
    )
    ssm_local_shift_per_channel: bool = Field(
        default=False,
        description="Use one learned local-shift gain per model channel instead of one scalar."
    )

    # TaoNet-hybrid layer layout
    hybrid_pattern: Literal["attention_first", "ssm_first", "single_ssm_middle", "single_ssm_late"] = Field(
        default="attention_first",
        description="Layer pattern for taonet_hybrid when hybrid_ssm_layers is not set."
    )
    hybrid_ssm_layers: Optional[str] = Field(
        default=None,
        description="Optional comma-separated 0-based layer indices that should use SSM blocks in taonet_hybrid."
    )

    # Initialization
    init_std: float = Field(default=0.02, description="Weight initialization standard deviation")

    # ``always=True`` makes pydantic run these validators for the default
    # ``None`` as well; without it (or without the decorator at all) the
    # derived defaults promised in the field descriptions are never applied.
    # ``values`` holds the fields declared (and successfully validated) above.
    @validator('head_dim', always=True)
    def validate_head_dim(cls, v, values):
        """Default ``head_dim`` to ``hidden_dim // num_heads`` when unset.

        Raises:
            ValueError: If ``head_dim`` must be derived but ``num_heads`` is
                not positive (avoids a ZeroDivisionError).
        """
        if v is None and 'hidden_dim' in values:
            num_heads = values.get('num_heads', 12)
            if num_heads <= 0:
                raise ValueError("num_heads must be positive to derive head_dim")
            return values['hidden_dim'] // num_heads
        return v

    @validator('intermediate_dim', always=True)
    def validate_intermediate_dim(cls, v, values):
        """Default ``intermediate_dim`` to ``4 * hidden_dim`` when unset."""
        if v is None and 'hidden_dim' in values:
            return 4 * values['hidden_dim']
        return v
| # ============================================================================ | |
| # Dataset Config | |
| # ============================================================================ | |
class DatasetConfig(BaseConfig):
    """Dataset source, preprocessing, tokenizer and chunked-loading settings.

    The data source is selected by ``local``: a local JSONL file
    (``local=True``, requires ``jsonl_path``) or a HuggingFace dataset
    (``local=False``, requires ``dataset_name``).
    """

    # Source selection: local JSONL vs HuggingFace
    local: bool = Field(default=False, description="Use local JSONL dataset instead of HuggingFace")

    # HuggingFace dataset fields
    dataset_name: Optional[str] = Field(default=None, description="HuggingFace dataset name (e.g., 'wikitext', 'openwebtext')")
    split: str = Field(default="train", description="Dataset split to use")
    config: Optional[str] = Field(default=None, description="Dataset config if multi-config (e.g., 'wikitext-103')")

    # Local JSONL dataset fields
    jsonl_path: Optional[str] = Field(default=None, description="Path to local JSONL dataset file")
    text_field: str = Field(default="text", description="Name of text field in JSONL")

    # Text column name varies by dataset
    text_column: str = Field(default="text", description="Name of text column in dataset")

    # Preprocessing
    max_samples: Optional[int] = Field(
        default=None,
        description="Limit dataset to N samples (useful for debugging)"
    )
    cache_dir: str = Field(default=".cache/datasets", description="HuggingFace cache directory")

    # Instruction/response columns for SFT and prompt column for RL
    instruction_column: Optional[str] = Field(default=None, description="Instruction column for SFT")
    response_column: Optional[str] = Field(default=None, description="Response column for SFT")
    prompt_column: Optional[str] = Field(default=None, description="Prompt column for RL")
    instruction_template: Optional[str] = Field(
        default=None,
        description="Template for combining instruction and response. E.g., '{instruction}\\n{response}'"
    )

    # Tokenizer
    tokenizer_type: Optional[str] = Field(
        default=None,
        description="Tokenizer type: 'huggingface' or 'sentencepiece'. If None, defaults based on tokenizer_path."
    )
    tokenizer_path: Optional[str] = Field(
        default=None,
        description="Path to saved tokenizer (for SentencePiece: .model file, for HuggingFace: model name or local path)"
    )

    # Chunked loading for large JSONL files
    enable_streaming: bool = Field(
        default=True,
        description="Enable streaming/chunked loading for large JSONL files to reduce memory usage"
    )
    chunk_size_gb: float = Field(
        default=5.0,
        description="Approximate chunk size in GB (ignored if samples_per_chunk is set)"
    )
    samples_per_chunk: Optional[int] = Field(
        default=1000,
        description="Number of samples per chunk (takes precedence over chunk_size_gb). Default: 1000 samples"
    )

    # Chunk caching
    enable_chunk_metadata_cache: bool = Field(
        default=True,
        description="Enable caching of chunk metadata (file scan results) to avoid re-scanning large JSONL files"
    )
    enable_chunk_data_cache: bool = Field(
        default=False,
        description="Enable caching of actual chunk data as separate files for faster loading (uses more disk space)"
    )
    chunk_cache_dir: str = Field(
        default=".cache/chunks",
        description="Directory to store chunk metadata and data cache files"
    )

    # Tokenization parallelism
    tokenizer_threads: int = Field(
        default=1,
        description="Number of background threads for tokenization (1-32 recommended). Higher values speed up tokenization but increase memory usage."
    )

    # Attached to ``jsonl_path`` (declared after ``local`` and
    # ``dataset_name``, so both are available in ``values``). ``always=True``
    # runs the check even when ``jsonl_path`` is left at its ``None`` default,
    # which is exactly the case it must reject when ``local=True``.
    @validator('jsonl_path', always=True)
    def validate_dataset_source(cls, v, values):
        """Require ``jsonl_path`` when ``local`` is set, else ``dataset_name``.

        Raises:
            ValueError: If the field required by the selected source is missing.
        """
        local = values.get('local', False)
        dataset_name = values.get('dataset_name')
        if local and not v:
            raise ValueError("jsonl_path must be provided when local=True")
        if not local and not dataset_name:
            raise ValueError("dataset_name must be provided when local=False (HuggingFace dataset)")
        return v

    @validator('tokenizer_threads')
    def validate_tokenizer_threads(cls, v):
        """Require ``1 <= tokenizer_threads <= 128``.

        Raises:
            ValueError: If the value is outside that range.
        """
        if v < 1:
            raise ValueError("tokenizer_threads must be at least 1")
        if v > 128:
            raise ValueError("tokenizer_threads should not exceed 128 (recommended: 1-32)")
        return v
| # ============================================================================ | |
| # Tokenizer Config | |
| # ============================================================================ | |
class TokenizerConfig(BaseConfig):
    """Settings for training a SentencePiece tokenizer from a JSONL corpus.

    Only ``jsonl_path`` is required; every other field has a default.
    """

    # -- input corpus --
    jsonl_path: str = Field(description="Path to JSONL file containing training data")
    text_field: str = Field("text", description="Field name in JSONL for text data")

    # -- SentencePiece training parameters --
    vocab_size: int = Field(50000, description="Vocabulary size")
    model_type: str = Field("unigram", description="SentencePiece model type (unigram, bpe, char, word)")
    character_coverage: float = Field(0.9995, description="Character coverage for SentencePiece training")

    # -- output location --
    output_dir: str = Field("tokenizers", description="Directory to save trained tokenizer")
    tokenizer_prefix: Optional[str] = Field(None, description="Prefix for tokenizer output files (default: model_type)")

    # -- reserved token ids --
    unk_id: int = Field(0, description="Unknown token ID")
    bos_id: int = Field(1, description="Beginning of sentence token ID")
    eos_id: int = Field(2, description="End of sentence token ID")
    pad_id: int = Field(3, description="Padding token ID")

    # -- extra user-defined special tokens, as {token: id} --
    special_tokens: Optional[dict[str, int]] = Field(
        None,
        description="Custom special tokens mapping: {token: id}. Example: {'<think>': 4, '<user>': 5, '<assistant>': 6, '<image>': 7}",
    )

    # -- corpus subsampling --
    max_samples: Optional[int] = Field(
        None,
        description="Limit training to first N samples from JSONL (useful for quick testing)",
    )

    # -- metadata --
    tokenizer_name: Optional[str] = Field(None, description="Optional name for the tokenizer")
| # ============================================================================ | |
| # Training Config | |
| # ============================================================================ | |
class OptimizerConfig(BaseConfig):
    """Optimizer hyperparameters.

    For ``hybrid_muon_adamw``, ``learning_rate`` is the peak LR for the Muon
    (2D weight) group and ``adamw_lr`` the LR for the AdamW (1D parameter)
    group. ``adamw_lr`` defaults to ``learning_rate / 10`` when not given.
    """

    optimizer_type: OptimizerEnum = Field(default=OptimizerEnum.ADAMW, description="Optimizer type")
    learning_rate: float = Field(default=1e-4, description="Peak learning rate (for Muon 2D weights)")
    adamw_lr: Optional[float] = Field(
        default=None,
        description="Learning rate for AdamW (1D parameters). If None, defaults to learning_rate / 10. Used in hybrid_muon_adamw optimizer."
    )
    weight_decay: float = Field(default=1e-2, description="Weight decay (L2 regularization)")
    betas: tuple[float, float] = Field(default=(0.9, 0.999), description="Adam betas")
    eps: float = Field(default=1e-8, description="Optimizer epsilon")

    # ``always=True`` runs the validator even when ``adamw_lr`` is left at its
    # ``None`` default; without it (or without the decorator) the documented
    # ``learning_rate / 10`` default would never be filled in.
    @validator('adamw_lr', always=True)
    def set_default_adamw_lr(cls, v, values):
        """Derive ``adamw_lr`` as ``learning_rate / 10`` when it is not given."""
        if v is None and 'learning_rate' in values:
            return values['learning_rate'] / 10
        return v
class SchedulerConfig(BaseConfig):
    """Learning-rate scheduler settings.

    Warmup length is ``warmup_steps`` when non-zero, otherwise
    ``warmup_ratio`` of the total steps. ``steady_ratio`` and ``min_lr_ratio``
    configure the warmup -> steady -> cosine-decay schedule used by
    ``cosineWarmup``. All ratio fields must lie in ``[0, 1]``.
    """

    scheduler_type: SchedulerEnum = Field(default=SchedulerEnum.LINEAR_WARMUP, description="Scheduler type")
    warmup_steps: int = Field(default=0, description="Number of warmup steps (takes precedence over warmup_ratio)")
    warmup_ratio: float = Field(default=0.1, description="Warmup as fraction of total steps (used if warmup_steps=0)")

    # Cosine scheduler specific
    num_cycles: float = Field(default=0.5, description="Number of cycles for cosine schedule")
    last_epoch: int = Field(default=-1, description="Last epoch for scheduler")

    # TaoNet 3-phase scheduler (warmup -> steady -> cosine decay)
    steady_ratio: float = Field(
        default=0.0,
        description="Fraction of training steps at peak LR before cosine decay (0.0 = no steady phase). Only for cosineWarmup."
    )
    min_lr_ratio: float = Field(
        default=0.0,
        description="Minimum LR as fraction of peak LR at end of training (0.0 = decay to 0). Only for cosineWarmup."
    )

    # The validators below were previously plain methods that pydantic never
    # invoked; the decorators register them on their fields. Defaults are
    # already in range, so ``always`` is not needed.
    @validator('warmup_ratio')
    def validate_warmup_ratio(cls, v):
        """Require ``0 <= warmup_ratio <= 1``."""
        if not 0 <= v <= 1:
            raise ValueError("warmup_ratio must be between 0 and 1")
        return v

    @validator('steady_ratio')
    def validate_steady_ratio(cls, v):
        """Require ``0 <= steady_ratio <= 1``."""
        if not 0 <= v <= 1:
            raise ValueError("steady_ratio must be between 0 and 1")
        return v

    @validator('min_lr_ratio')
    def validate_min_lr_ratio(cls, v):
        """Require ``0 <= min_lr_ratio <= 1``."""
        if not 0 <= v <= 1:
            raise ValueError("min_lr_ratio must be between 0 and 1")
        return v

    @validator('warmup_steps')
    def validate_warmup_steps(cls, v):
        """Require ``warmup_steps >= 0``."""
        if v < 0:
            raise ValueError("warmup_steps must be non-negative")
        return v
class TrainingConfig(BaseConfig):
    """Settings shared by every training stage (pretrain, SFT, RL).

    Bundles the nested model/dataset/optimizer/scheduler configs with loop
    hyperparameters, precision/device, checkpointing, evaluation, logging and
    DataLoader options. ``dataset`` is the only required field.
    """

    # --- nested component configs ---
    model: ModelConfig = Field(default_factory=ModelConfig, description="Model configuration")
    dataset: DatasetConfig = Field(description="Dataset configuration")

    # --- loop hyperparameters ---
    batch_size: int = Field(32, description="Batch size per device")
    num_epochs: int = Field(3, description="Number of training epochs")
    max_steps: Optional[int] = Field(None, description="Maximum steps (overrides num_epochs if set)")
    gradient_accumulation_steps: int = Field(1, description="Gradient accumulation steps")
    max_grad_norm: float = Field(1.0, description="Gradient clipping max norm")

    # --- optimization ---
    optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig, description="Optimizer configuration")
    scheduler: SchedulerConfig = Field(
        default_factory=SchedulerConfig,
        description="Learning rate scheduler configuration",
    )

    # --- precision, device, reproducibility ---
    dtype: DataTypeEnum = Field(DataTypeEnum.BFLOAT16, description="Training data type")
    device: str = Field("cuda", description="Device to train on (cuda, cpu)")
    seed: int = Field(42, description="Random seed")

    # --- checkpointing ---
    checkpoint_dir: str = Field("checkpoints", description="Directory to save checkpoints")
    checkpoint_path: Optional[str] = Field(
        None,
        description="Path to load pretrained checkpoint (for SFT/RL). If provided, loads weights before training starts.",
    )
    save_every_steps: int = Field(500, description="Save checkpoint every N steps")
    keep_last_n_checkpoints: int = Field(3, description="Keep only last N checkpoints")
    save_best_model: bool = Field(True, description="Save best model based on validation loss")

    # --- evaluation ---
    eval_every_steps: int = Field(500, description="Evaluate every N steps")
    eval_samples: int = Field(1000, description="Number of validation samples")

    # --- logging ---
    log_every_steps: int = Field(10, description="Log metrics every N steps")
    aim_repo: str = Field(".aim", description="AimStack repository path")

    # --- DataLoader / runtime ---
    num_workers: int = Field(0, description="Number of DataLoader workers")
    pin_memory: bool = Field(True, description="Pin memory for DataLoader")
    use_compile: bool = Field(False, description="Use torch.compile (experimental)")

    # Stage selector; subclasses narrow this to a single Literal value.
    mode: TrainingModeEnum = Field(TrainingModeEnum.PRETRAIN, description="Training mode")
| # ============================================================================ | |
| # Stage-Specific Configs | |
| # ============================================================================ | |
class PretrainConfig(TrainingConfig):
    """Pretraining config: the shared training settings plus a sequence length."""

    # Pin the inherited ``mode`` to the pretrain stage.
    mode: Literal[TrainingModeEnum.PRETRAIN] = TrainingModeEnum.PRETRAIN

    sequence_length: int = Field(1024, description="Sequence length for pretraining")
class SFTConfig(TrainingConfig):
    """Supervised fine-tuning config.

    Adds response-only loss masking and the role tokens used to delimit
    user and assistant turns in multi-turn conversations.
    """

    # Pin the inherited ``mode`` to the SFT stage.
    mode: Literal[TrainingModeEnum.SFT] = TrainingModeEnum.SFT

    response_loss_only: bool = Field(
        True,
        description="Only compute loss on response/assistant tokens (not instruction/user tokens). Uses -100 label masking.",
    )

    # Role markers for multi-turn conversations
    user_token: str = Field(
        "<user>",
        description="Special token representing user/instruction role in conversations",
    )
    assistant_token: str = Field(
        "<assistant>",
        description="Special token representing assistant/response role in conversations",
    )
class RLConfig(TrainingConfig):
    """Reinforcement-learning config (PPO or DPO).

    ``reward_model_path`` is required regardless of ``rl_method``. PPO- and
    DPO-specific knobs are grouped below; an optional separate prompt dataset
    can override the main one for prompts.
    """

    # Pin the inherited ``mode`` to the RL stage.
    mode: Literal[TrainingModeEnum.RL] = TrainingModeEnum.RL

    rl_method: RLMethodEnum = Field(RLMethodEnum.PPO, description="RL training method (PPO or DPO)")
    reward_model_path: str = Field(description="Path to trained reward model checkpoint")

    # PPO knobs
    ppo_epochs: int = Field(4, description="PPO inner epochs")
    ppo_clip_ratio: float = Field(0.2, description="PPO clipping ratio")
    entropy_coeff: float = Field(0.01, description="Entropy bonus coefficient")
    value_loss_coeff: float = Field(1.0, description="Value function loss coefficient")

    # DPO knob
    dpo_beta: float = Field(0.1, description="DPO inverse temperature (beta)")

    # Prompt source and generation length
    prompt_dataset: Optional[DatasetConfig] = Field(
        None,
        description="Separate dataset for prompts (if different from main dataset)",
    )
    generation_max_length: int = Field(256, description="Maximum length for generated responses during RL")
| # ============================================================================ | |
| # Factory function | |
| # ============================================================================ | |
def load_config(path: str | Path, mode: TrainingModeEnum | str) -> TrainingConfig:
    """Load a training config file into the config class for ``mode``.

    Args:
        path: Path to a ``.yaml``/``.yml`` or ``.json`` file. The suffix is
            matched case-insensitively, so ``config.YAML`` is accepted.
        mode: Training stage, either a ``TrainingModeEnum`` member or its
            string value (``"pretrain"``, ``"sft"``, ``"rl"``).

    Returns:
        A ``PretrainConfig``, ``SFTConfig`` or ``RLConfig`` instance.

    Raises:
        ValueError: If ``mode`` is not a valid ``TrainingModeEnum`` value or
            the file suffix is not a supported format.
    """
    # TrainingModeEnum subclasses str, so enum members also pass through here
    # unchanged; plain strings are converted (raising ValueError if unknown).
    if isinstance(mode, str):
        mode = TrainingModeEnum(mode)

    config_class = {
        TrainingModeEnum.PRETRAIN: PretrainConfig,
        TrainingModeEnum.SFT: SFTConfig,
        TrainingModeEnum.RL: RLConfig,
    }[mode]

    path = Path(path)
    suffix = path.suffix.lower()
    if suffix in ('.yaml', '.yml'):
        return config_class.load_yaml(path)
    if suffix == '.json':
        return config_class.load_json(path)
    raise ValueError(f"Unsupported config file format: {path.suffix}")
def load_tokenizer_config(path: str | Path) -> TokenizerConfig:
    """Load a ``TokenizerConfig`` from a YAML or JSON file.

    Args:
        path: Path to a ``.yaml``/``.yml`` or ``.json`` file. The suffix is
            matched case-insensitively.

    Returns:
        The parsed ``TokenizerConfig``.

    Raises:
        ValueError: If the file suffix is not a supported format.
    """
    path = Path(path)
    suffix = path.suffix.lower()
    if suffix in ('.yaml', '.yml'):
        return TokenizerConfig.load_yaml(path)
    if suffix == '.json':
        return TokenizerConfig.load_json(path)
    raise ValueError(f"Unsupported config file format: {path.suffix}")