# StarMist0012's picture
# Add files using upload-large-folder tool
# e2bfccc verified
"""Pydantic configuration schemas for TaoTrain."""
from enum import Enum
from typing import Optional, Literal
from pathlib import Path
import json
from pydantic import BaseModel as PydanticBaseModel, Field, validator
import yaml
# ============================================================================
# Enums
# ============================================================================
class DataTypeEnum(str, Enum):
    """Numeric precision choices for training.

    Because of the ``str`` mixin, each member compares equal to its string
    value (e.g. ``DataTypeEnum.BFLOAT16 == "bfloat16"``).
    """

    # Member order is part of the enum's iteration order; keep it stable.
    FLOAT32 = "float32"    # 32-bit float
    FLOAT16 = "float16"    # 16-bit float
    BFLOAT16 = "bfloat16"  # 16-bit brain float
class OptimizerEnum(str, Enum):
    """Optimizer identifiers accepted by ``OptimizerConfig.optimizer_type``."""

    ADAM = "adam"
    ADAMW = "adamw"
    SGD = "sgd"
    # Muon for 2D weights plus AdamW for 1D parameters, per the
    # OptimizerConfig learning-rate field descriptions.
    HYBRID_MUON_ADAMW = "hybrid_muon_adamw"
class ModelArchitectureEnum(str, Enum):
    """Names of the model architectures that ship with TaoTrain."""

    TRANSFORMER = "transformer"
    TAONET = "taonet"                # MLA attention variant
    TAONET_SSM = "taonet_ssm"        # SSM mixer in place of MLA attention
    TAONET_HYBRID = "taonet_hybrid"  # mix of attention and SSM layers
class SchedulerEnum(str, Enum):
    """Learning-rate schedule identifiers.

    Unlike the other enums in this module, the values are camelCase
    (``"linearWarmup"``, ``"cosineWarmup"``); configs must spell them exactly.
    """

    LINEAR_WARMUP = "linearWarmup"
    COSINE_WARMUP = "cosineWarmup"
    CONSTANT = "constant"
class RLMethodEnum(str, Enum):
    """Reinforcement-learning algorithms selectable via ``RLConfig.rl_method``."""

    PPO = "ppo"  # Proximal Policy Optimization
    DPO = "dpo"  # Direct Preference Optimization
class TrainingModeEnum(str, Enum):
    """Training stage; selects the config class used by ``load_config``."""

    PRETRAIN = "pretrain"
    SFT = "sft"  # supervised fine-tuning
    RL = "rl"    # reinforcement learning
# ============================================================================
# Base Configs
# ============================================================================
class BaseConfig(PydanticBaseModel):
    """Base Pydantic model with dict/JSON/YAML (de)serialization helpers.

    All config classes in this module inherit from it, so any config can be
    written with ``save_yaml``/``save_json`` and read back with
    ``load_yaml``/``load_json``.
    """

    # Pydantic v2 configuration. This module already requires v2 (it calls
    # ``model_dump``), and the v1-style nested ``class Config`` is deprecated
    # under v2 and emits a deprecation warning.
    model_config = {"arbitrary_types_allowed": True}

    def to_dict(self) -> dict:
        """Return the config as a dict of JSON-compatible values.

        ``mode='json'`` turns enum members into their string values, so the
        result can be handed straight to ``json``/``yaml``.
        """
        return self.model_dump(mode='json')

    def to_json_str(self) -> str:
        """Return the config as a JSON string indented by two spaces."""
        return json.dumps(self.to_dict(), indent=2)

    def save_yaml(self, path: str | Path) -> None:
        """Write the config to ``path`` as YAML, creating parent directories.

        Keys keep field-declaration order (``sort_keys=False``).
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        # Explicit UTF-8 so the file does not depend on the platform locale.
        with open(path, 'w', encoding='utf-8') as f:
            yaml.dump(self.to_dict(), f, default_flow_style=False, sort_keys=False)

    def save_json(self, path: str | Path) -> None:
        """Write the config to ``path`` as JSON, creating parent directories."""
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, 'w', encoding='utf-8') as f:
            f.write(self.to_json_str())

    @classmethod
    def load_yaml(cls, path: str | Path) -> "BaseConfig":
        """Build an instance of ``cls`` from a YAML file.

        An empty file is treated as an empty mapping, so all fields take their
        defaults (and any required field is reported by Pydantic).

        Raises:
            TypeError: If the YAML document's top level is not a mapping.
            pydantic.ValidationError: If the data fails field validation.
        """
        # yaml.safe_load: never construct arbitrary Python objects from config.
        with open(path, encoding='utf-8') as f:
            data = yaml.safe_load(f)
        if data is None:
            data = {}
        if not isinstance(data, dict):
            raise TypeError(
                f"Config file {path} must contain a mapping at the top level, "
                f"got {type(data).__name__}"
            )
        return cls(**data)

    @classmethod
    def load_json(cls, path: str | Path) -> "BaseConfig":
        """Build an instance of ``cls`` from a JSON file.

        Raises:
            TypeError: If the JSON document's top level is not an object.
            pydantic.ValidationError: If the data fails field validation.
        """
        with open(path, encoding='utf-8') as f:
            data = json.load(f)
        if not isinstance(data, dict):
            raise TypeError(
                f"Config file {path} must contain a JSON object at the top level, "
                f"got {type(data).__name__}"
            )
        return cls(**data)
# ============================================================================
# Model Config
# ============================================================================
class ModelConfig(BaseConfig):
    """Configuration for model architecture.

    Covers every built-in architecture in ``ModelArchitectureEnum``. Fields in
    an architecture-specific group are only meaningful for that architecture,
    as stated in their descriptions. When ``head_dim`` or ``intermediate_dim``
    is left unset, the validators at the end of the class derive it from
    ``hidden_dim``.
    """

    architecture_type: ModelArchitectureEnum = Field(
        default=ModelArchitectureEnum.TRANSFORMER,
        description="Type of model architecture"
    )

    # Transformer-specific
    vocab_size: int = Field(default=50257, description="Vocabulary size")
    hidden_dim: int = Field(default=768, description="Hidden dimension")
    num_layers: int = Field(default=12, description="Number of transformer blocks")
    num_heads: int = Field(default=12, description="Number of attention heads")
    head_dim: Optional[int] = Field(
        default=None,
        description="Head dimension (defaults to hidden_dim // num_heads)"
    )
    intermediate_dim: Optional[int] = Field(
        default=None,
        description="FFN intermediate dimension (defaults to 4 * hidden_dim)"
    )
    dropout: float = Field(default=0.1, description="Dropout rate")
    max_seq_length: int = Field(default=2048, description="Maximum sequence length")

    # TaoNet (DeepSeek MLA) specific
    d_latent_kv: Optional[int] = Field(
        default=None,
        description="KV compression dimension for MLA (defaults to 3/4 * hidden_dim). Only used for taonet architecture."
    )
    d_rope: Optional[int] = Field(
        default=None,
        description="RoPE dimension per head (defaults to hidden_dim // num_heads). Only used for taonet architecture."
    )
    gqa_groups: int = Field(
        default=1,
        description="Grouped Query Attention groups (1 = standard MLA, >1 = GQA). Only used for taonet architecture."
    )
    hidden_dim_ff: Optional[int] = Field(
        default=None,
        description="Feed-forward intermediate dimension (defaults to 4 * hidden_dim)."
    )
    use_factorized_embedding: bool = Field(
        default=False,
        description="Use low-rank factorized embedding instead of standard embedding (reduces params). Only for taonet."
    )
    d_embed_rank: int = Field(
        default=96,
        description="Rank dimension for factorized embedding. Only used if use_factorized_embedding=True."
    )

    # YaRN (Yet another RoPE eXtension) for context length extension
    rope_scale: float = Field(
        default=40.0,
        description="Base RoPE scale factor (default: 40.0). Controls position frequency base."
    )
    yarn_enabled: bool = Field(
        default=False,
        description="Enable YaRN (Yet another RoPE eXtension) for context length interpolation."
    )
    yarn_alpha: float = Field(
        default=1.0,
        description="YaRN interpolation smoothness (1.0=smooth, <1.0=aggressive, >1.0=conservative). Only used if yarn_enabled=True."
    )

    # TaoNet-SSM specific: SSM mixer replacing MLA attention
    ssm_core: Literal["gamma_s4", "dplr"] = Field(
        default="gamma_s4",
        description="SSM core used by taonet_ssm. Use dplr for the ternary-aware DPLR SSM."
    )
    ssm_hidden_dim: Optional[int] = Field(
        default=None,
        description="SSM hidden/state dimension for taonet_ssm. Defaults to d_latent_kv or hidden_dim."
    )
    ssm_mixer_dim: Optional[int] = Field(
        default=None,
        description="Channel dimension processed by the SSM mixer. Defaults to hidden_dim; smaller values use an input/output projection bottleneck."
    )
    ssm_num_lanes: int = Field(
        default=1,
        description="Number of independent SSM lanes inside each SSM mixer. Multiple lanes add SSM capacity with cheap elementwise combination."
    )
    ssm_lane_combine: Literal["mean", "channel"] = Field(
        default="mean",
        description="How to combine multiple SSM lanes. Channel uses learned per-lane/per-channel elementwise weights."
    )
    ssm_lane_mode: Literal["full", "split"] = Field(
        default="full",
        description="Whether each SSM lane processes the full mixer dimension or a disjoint split of the mixer channels."
    )
    ssm_split_mix: Literal["none", "hadamard"] = Field(
        default="none",
        description="Optional ternary-friendly cross-lane mixer for split SSM lanes."
    )
    ssm_rank: int = Field(
        default=1,
        description="Low-rank correction rank for ssm_core=dplr."
    )
    ssm_max_low_rank_scale: float = Field(
        default=0.1,
        description="Maximum low-rank correction scale for ssm_core=dplr."
    )
    ssm_finite_tail_correction: bool = Field(
        default=True,
        description="Enable exact finite-length tail correction for ssm_core=dplr. Disable for the faster approximate DPLR path."
    )
    ssm_discretization: Literal["bilinear", "zoh", "euler"] = Field(
        default="bilinear",
        description="Discretization used by the Gamma SSM mixer."
    )
    ssm_kernel_mode: Literal["auto", "recurrent", "conv", "conv_transfer"] = Field(
        default="auto",
        description="Gamma SSM execution path. Use auto/conv for full-sequence GPU training, conv_transfer to materialize frequency transfers, recurrent for step-wise tests."
    )
    ssm_kernel_threshold: int = Field(
        default=64,
        description="Minimum sequence length for auto mode to use the convolutional Gamma SSM path."
    )
    ssm_dt_min: float = Field(default=1e-3, description="Minimum learned SSM timestep.")
    ssm_dt_max: float = Field(default=1e-1, description="Maximum learned SSM timestep.")
    ssm_dt_init: float = Field(default=1e-2, description="Initial learned SSM timestep.")
    ssm_use_d: bool = Field(default=True, description="Enable direct skip term D in the Gamma SSM.")
    ssm_activation: Literal["gelu", "silu", "identity", "linear"] = Field(
        default="gelu",
        description="Activation applied to the Gamma SSM branch output."
    )
    ssm_gate: bool = Field(default=True, description="Enable output gate on the Gamma SSM branch.")
    ssm_input_gate: bool = Field(default=True, description="Enable input gate before the Gamma SSM.")
    ssm_gate_type: Literal["dense", "channel"] = Field(
        default="dense",
        description="Gate implementation for enabled SSM input/output gates. Channel gates are elementwise and ternary-friendly."
    )
    ssm_use_padding_mask: bool = Field(
        default=False,
        description="Apply dataset padding masks inside the SSM. Disabled by default so training can use the convolutional path."
    )
    ssm_layer_scale_init: float = Field(
        default=0.1,
        description="Initial layer-scale multiplier for the Gamma SSM branch."
    )
    ssm_branch_rms_norm: bool = Field(
        default=False,
        description="Normalize the SSM residual branch to unit RMS before layer-scale. Useful for stabilizing deep SSM/hybrid runs."
    )
    ssm_branch_rms_eps: float = Field(
        default=1e-6,
        description="Numerical epsilon for optional SSM branch RMS normalization."
    )
    ssm_branch_clip_value: Optional[float] = Field(
        default=None,
        description="Optional symmetric clamp applied to the SSM residual branch after layer-scale. None disables clamping."
    )
    block_residual_rms_norm: bool = Field(
        default=False,
        description="Normalize the residual stream RMS after block residual additions. Intended for stabilizing deep SSM/hybrid experiments."
    )
    block_residual_rms_target: float = Field(
        default=1.0,
        description="Target per-token RMS when block_residual_rms_norm is enabled."
    )
    block_residual_rms_cap: Optional[float] = Field(
        default=None,
        description="Optional per-token RMS cap for the residual stream. Unlike block_residual_rms_norm, this only scales down tokens whose RMS exceeds the cap."
    )
    block_residual_rms_eps: float = Field(
        default=1e-6,
        description="Numerical epsilon for optional block residual RMS normalization."
    )
    ssm_local_shift: bool = Field(
        default=False,
        description="Add a cheap one-token causal shift/register branch to the taonet_ssm mixer."
    )
    ssm_local_shift_init: float = Field(
        default=0.1,
        description="Initial scalar weight for the optional one-token local shift/register branch."
    )
    ssm_local_shift_per_channel: bool = Field(
        default=False,
        description="Use one learned local-shift gain per model channel instead of one scalar."
    )

    # TaoNet-hybrid layer layout (which layers use SSM blocks)
    hybrid_pattern: Literal["attention_first", "ssm_first", "single_ssm_middle", "single_ssm_late"] = Field(
        default="attention_first",
        description="Layer pattern for taonet_hybrid when hybrid_ssm_layers is not set."
    )
    hybrid_ssm_layers: Optional[str] = Field(
        default=None,
        description="Optional comma-separated 0-based layer indices that should use SSM blocks in taonet_hybrid."
    )

    # Initializations
    init_std: float = Field(default=0.02, description="Weight initialization standard deviation")

    @validator("head_dim", always=True)
    def validate_head_dim(cls, v, values):
        """Default ``head_dim`` to ``hidden_dim // num_heads`` when unset.

        ``hidden_dim``/``num_heads`` are absent from ``values`` only if they
        failed their own validation. In that case, that error is reported and
        the fallback of 12 heads merely keeps this validator from crashing.

        Raises:
            ValueError: If a default must be derived and ``num_heads`` is not
                positive. Without this check, ``num_heads=0`` would raise a raw
                ZeroDivisionError that Pydantic does not wrap.
        """
        if v is None and 'hidden_dim' in values:
            num_heads = values.get('num_heads', 12)
            if num_heads <= 0:
                raise ValueError(
                    f"num_heads must be positive to derive head_dim, got {num_heads}"
                )
            return values['hidden_dim'] // num_heads
        return v

    @validator("intermediate_dim", always=True)
    def validate_intermediate_dim(cls, v, values):
        """Default ``intermediate_dim`` to ``4 * hidden_dim`` when unset."""
        if v is None and 'hidden_dim' in values:
            return 4 * values['hidden_dim']
        return v
# ============================================================================
# Dataset Config
# ============================================================================
class DatasetConfig(BaseConfig):
    """Where training text comes from and how it is loaded and tokenized.

    Exactly one source is expected: a local JSONL file (``local=True`` plus
    ``jsonl_path``) or a HuggingFace dataset (``local=False`` plus
    ``dataset_name``). The ``validate_dataset_source`` validator enforces this.
    """

    # --- Source selection -------------------------------------------------
    local: bool = Field(default=False, description="Use local JSONL dataset instead of HuggingFace")

    # --- HuggingFace source -----------------------------------------------
    dataset_name: Optional[str] = Field(default=None, description="HuggingFace dataset name (e.g., 'wikitext', 'openwebtext')")
    split: str = Field(default="train", description="Dataset split to use")
    config: Optional[str] = Field(default=None, description="Dataset config if multi-config (e.g., 'wikitext-103')")

    # --- Local JSONL source -----------------------------------------------
    # Declared after ``local`` and ``dataset_name`` so that both are already
    # in ``values`` when this field's validator runs.
    jsonl_path: Optional[str] = Field(default=None, description="Path to local JSONL dataset file")
    text_field: str = Field(default="text", description="Name of text field in JSONL")

    # Text column name (differs between datasets)
    text_column: str = Field(default="text", description="Name of text column in dataset")

    # --- Preprocessing ----------------------------------------------------
    max_samples: Optional[int] = Field(
        default=None,
        description="Limit dataset to N samples (useful for debugging)"
    )
    cache_dir: str = Field(default=".cache/datasets", description="HuggingFace cache directory")

    # --- Instruction/response columns (SFT / RL) --------------------------
    instruction_column: Optional[str] = Field(default=None, description="Instruction column for SFT")
    response_column: Optional[str] = Field(default=None, description="Response column for SFT")
    prompt_column: Optional[str] = Field(default=None, description="Prompt column for RL")
    instruction_template: Optional[str] = Field(
        default=None,
        description="Template for combining instruction and response. E.g., '{instruction}\\n{response}'"
    )

    # --- Tokenizer --------------------------------------------------------
    tokenizer_type: Optional[str] = Field(
        default=None,
        description="Tokenizer type: 'huggingface' or 'sentencepiece'. If None, defaults based on tokenizer_path."
    )
    tokenizer_path: Optional[str] = Field(
        default=None,
        description="Path to saved tokenizer (for SentencePiece: .model file, for HuggingFace: model name or local path)"
    )

    # --- Chunked loading of large JSONL files -----------------------------
    enable_streaming: bool = Field(
        default=True,
        description="Enable streaming/chunked loading for large JSONL files to reduce memory usage"
    )
    chunk_size_gb: float = Field(
        default=5.0,
        description="Approximate chunk size in GB (ignored if samples_per_chunk is set)"
    )
    samples_per_chunk: Optional[int] = Field(
        default=1000,
        description="Number of samples per chunk (takes precedence over chunk_size_gb). Default: 1000 samples"
    )

    # --- Chunk caches -----------------------------------------------------
    enable_chunk_metadata_cache: bool = Field(
        default=True,
        description="Enable caching of chunk metadata (file scan results) to avoid re-scanning large JSONL files"
    )
    enable_chunk_data_cache: bool = Field(
        default=False,
        description="Enable caching of actual chunk data as separate files for faster loading (uses more disk space)"
    )
    chunk_cache_dir: str = Field(
        default=".cache/chunks",
        description="Directory to store chunk metadata and data cache files"
    )

    # --- Tokenization parallelism -----------------------------------------
    tokenizer_threads: int = Field(
        default=1,
        description="Number of background threads for tokenization (1-32 recommended). Higher values speed up tokenization but increase memory usage."
    )

    @validator('jsonl_path', always=True)
    def validate_dataset_source(cls, v, values):
        """Ensure the chosen source has its required identifier.

        ``always=True`` makes this run even when ``jsonl_path`` is left at its
        default, so a missing ``dataset_name`` is caught too.
        """
        use_local = values.get('local', False)
        if use_local:
            if not v:
                raise ValueError("jsonl_path must be provided when local=True")
        elif not values.get('dataset_name'):
            raise ValueError("dataset_name must be provided when local=False (HuggingFace dataset)")
        return v

    @validator('tokenizer_threads')
    def validate_tokenizer_threads(cls, v):
        """Bound ``tokenizer_threads`` to the closed range 1..128."""
        problem = None
        if v < 1:
            problem = "tokenizer_threads must be at least 1"
        elif v > 128:
            problem = "tokenizer_threads should not exceed 128 (recommended: 1-32)"
        if problem is not None:
            raise ValueError(problem)
        return v
# ============================================================================
# Tokenizer Config
# ============================================================================
class TokenizerConfig(BaseConfig):
    """Configuration for training a tokenizer (SentencePiece, per the fields).

    ``jsonl_path`` is the only required field.
    """

    # ``model_type`` begins with ``model_``, which Pydantic v2 (before 2.10)
    # treats as a protected namespace and warns about when the class is
    # created. The field name is kept for compatibility with existing config
    # files, so protected namespaces are disabled instead. Pydantic merges this
    # with the settings inherited from BaseConfig.
    model_config = {"protected_namespaces": ()}

    # Dataset source
    jsonl_path: str = Field(description="Path to JSONL file containing training data")
    text_field: str = Field(default="text", description="Field name in JSONL for text data")

    # Training configuration
    vocab_size: int = Field(default=50000, description="Vocabulary size")
    model_type: str = Field(default="unigram", description="SentencePiece model type (unigram, bpe, char, word)")
    character_coverage: float = Field(
        default=0.9995,
        description="Character coverage for SentencePiece training"
    )
    output_dir: str = Field(default="tokenizers", description="Directory to save trained tokenizer")
    tokenizer_prefix: Optional[str] = Field(
        default=None,
        description="Prefix for tokenizer output files (default: model_type)"
    )

    # SentencePiece token IDs
    unk_id: int = Field(default=0, description="Unknown token ID")
    bos_id: int = Field(default=1, description="Beginning of sentence token ID")
    eos_id: int = Field(default=2, description="End of sentence token ID")
    pad_id: int = Field(default=3, description="Padding token ID")

    # Custom special tokens, e.g. <think>, <user>, <assistant>, <image>
    special_tokens: Optional[dict[str, int]] = Field(
        default=None,
        description="Custom special tokens mapping: {token: id}. Example: {'<think>': 4, '<user>': 5, '<assistant>': 6, '<image>': 7}"
    )

    # Data sampling
    max_samples: Optional[int] = Field(
        default=None,
        description="Limit training to first N samples from JSONL (useful for quick testing)"
    )

    # Tokenizer metadata
    tokenizer_name: Optional[str] = Field(
        default=None,
        description="Optional name for the tokenizer"
    )
# ============================================================================
# Training Config
# ============================================================================
class OptimizerConfig(BaseConfig):
    """Optimizer settings.

    For ``hybrid_muon_adamw``, the field descriptions state that
    ``learning_rate`` applies to Muon (2D weights) and ``adamw_lr`` to AdamW
    (1D parameters). ``adamw_lr`` falls back to ``learning_rate / 10``.
    """

    optimizer_type: OptimizerEnum = Field(default=OptimizerEnum.ADAMW, description="Optimizer type")
    learning_rate: float = Field(default=1e-4, description="Peak learning rate (for Muon 2D weights)")
    # Declared after ``learning_rate`` so the default can be derived from it.
    adamw_lr: Optional[float] = Field(
        default=None,
        description="Learning rate for AdamW (1D parameters). If None, defaults to learning_rate / 10. Used in hybrid_muon_adamw optimizer."
    )
    weight_decay: float = Field(default=1e-2, description="Weight decay (L2 regularization)")
    betas: tuple[float, float] = Field(default=(0.9, 0.999), description="Adam betas")
    eps: float = Field(default=1e-8, description="Optimizer epsilon")

    @validator('adamw_lr', always=True)
    def set_default_adamw_lr(cls, v, values):
        """Fill ``adamw_lr`` with ``learning_rate / 10`` when it was not given.

        If ``learning_rate`` failed its own validation it is missing from
        ``values``, and ``v`` is passed through unchanged.
        """
        if v is not None or 'learning_rate' not in values:
            return v
        return values['learning_rate'] / 10
class SchedulerConfig(BaseConfig):
    """Learning-rate schedule settings.

    Warmup length comes from ``warmup_steps`` when it is non-zero, otherwise
    from ``warmup_ratio``. ``steady_ratio`` and ``min_lr_ratio`` apply to the
    ``cosineWarmup`` schedule only, per their descriptions.
    """

    scheduler_type: SchedulerEnum = Field(default=SchedulerEnum.LINEAR_WARMUP, description="Scheduler type")
    warmup_steps: int = Field(default=0, description="Number of warmup steps (takes precedence over warmup_ratio)")
    warmup_ratio: float = Field(default=0.1, description="Warmup as fraction of total steps (used if warmup_steps=0)")

    # Cosine-schedule knobs
    num_cycles: float = Field(default=0.5, description="Number of cycles for cosine schedule")
    last_epoch: int = Field(default=-1, description="Last epoch for scheduler")

    # Three-phase TaoNet schedule: warmup, then steady at peak, then cosine decay
    steady_ratio: float = Field(
        default=0.0,
        description="Fraction of training steps at peak LR before cosine decay (0.0 = no steady phase). Only for cosineWarmup."
    )
    min_lr_ratio: float = Field(
        default=0.0,
        description="Minimum LR as fraction of peak LR at end of training (0.0 = decay to 0). Only for cosineWarmup."
    )

    # Range checks below use the chained ``0 <= v <= 1`` form, which also
    # rejects NaN.

    @validator('warmup_ratio')
    def validate_warmup_ratio(cls, v):
        """Accept only warmup ratios in the closed interval [0, 1]."""
        if 0 <= v <= 1:
            return v
        raise ValueError("warmup_ratio must be between 0 and 1")

    @validator('steady_ratio')
    def validate_steady_ratio(cls, v):
        """Accept only steady-phase ratios in the closed interval [0, 1]."""
        if 0 <= v <= 1:
            return v
        raise ValueError("steady_ratio must be between 0 and 1")

    @validator('min_lr_ratio')
    def validate_min_lr_ratio(cls, v):
        """Accept only minimum-LR ratios in the closed interval [0, 1]."""
        if 0 <= v <= 1:
            return v
        raise ValueError("min_lr_ratio must be between 0 and 1")

    @validator('warmup_steps')
    def validate_warmup_steps(cls, v):
        """Accept only non-negative warmup step counts."""
        if v >= 0:
            return v
        raise ValueError("warmup_steps must be non-negative")
class TrainingConfig(BaseConfig):
    """Settings shared by every training stage (pretrain, SFT, RL).

    ``dataset`` is the only required field. The model, optimizer and
    scheduler sub-configs are built from their own defaults when omitted.
    """

    # ---- What to train on ------------------------------------------------
    model: ModelConfig = Field(default_factory=ModelConfig, description="Model configuration")
    dataset: DatasetConfig = Field(description="Dataset configuration")

    # ---- Loop hyperparameters --------------------------------------------
    batch_size: int = Field(default=32, description="Batch size per device")
    num_epochs: int = Field(default=3, description="Number of training epochs")
    max_steps: Optional[int] = Field(
        default=None,
        description="Maximum steps (overrides num_epochs if set)"
    )
    gradient_accumulation_steps: int = Field(
        default=1,
        description="Gradient accumulation steps"
    )
    max_grad_norm: float = Field(default=1.0, description="Gradient clipping max norm")

    # ---- Optimization sub-configs ----------------------------------------
    optimizer: OptimizerConfig = Field(
        default_factory=OptimizerConfig,
        description="Optimizer configuration"
    )
    scheduler: SchedulerConfig = Field(
        default_factory=SchedulerConfig,
        description="Learning rate scheduler configuration"
    )

    # ---- Numerics, hardware and reproducibility --------------------------
    dtype: DataTypeEnum = Field(
        default=DataTypeEnum.BFLOAT16,
        description="Training data type"
    )
    device: str = Field(default="cuda", description="Device to train on (cuda, cpu)")
    seed: int = Field(default=42, description="Random seed")

    # ---- Checkpoints -----------------------------------------------------
    checkpoint_dir: str = Field(default="checkpoints", description="Directory to save checkpoints")
    checkpoint_path: Optional[str] = Field(
        default=None,
        description="Path to load pretrained checkpoint (for SFT/RL). If provided, loads weights before training starts."
    )
    save_every_steps: int = Field(default=500, description="Save checkpoint every N steps")
    keep_last_n_checkpoints: int = Field(default=3, description="Keep only last N checkpoints")
    save_best_model: bool = Field(default=True, description="Save best model based on validation loss")

    # ---- Evaluation ------------------------------------------------------
    eval_every_steps: int = Field(default=500, description="Evaluate every N steps")
    eval_samples: int = Field(default=1000, description="Number of validation samples")

    # ---- Logging ---------------------------------------------------------
    log_every_steps: int = Field(default=10, description="Log metrics every N steps")
    aim_repo: str = Field(default=".aim", description="AimStack repository path")

    # ---- DataLoader / compilation ----------------------------------------
    num_workers: int = Field(default=0, description="Number of DataLoader workers")
    pin_memory: bool = Field(default=True, description="Pin memory for DataLoader")
    use_compile: bool = Field(default=False, description="Use torch.compile (experimental)")

    # ---- Stage -----------------------------------------------------------
    # Stage subclasses narrow this to a single Literal value.
    mode: TrainingModeEnum = Field(default=TrainingModeEnum.PRETRAIN, description="Training mode")
# ============================================================================
# Stage-Specific Configs
# ============================================================================
class PretrainConfig(TrainingConfig):
    """Training settings for the pretraining stage.

    ``mode`` is pinned to ``pretrain``; a config file declaring any other mode
    fails validation.
    """

    mode: Literal[TrainingModeEnum.PRETRAIN] = TrainingModeEnum.PRETRAIN

    # Pretraining-only knob
    sequence_length: int = Field(default=1024, description="Sequence length for pretraining")
class SFTConfig(TrainingConfig):
    """Training settings for supervised fine-tuning.

    ``mode`` is pinned to ``sft``. The role-token strings name the markers
    that separate user and assistant turns in conversations.
    """

    mode: Literal[TrainingModeEnum.SFT] = TrainingModeEnum.SFT

    # Loss masking
    response_loss_only: bool = Field(
        default=True,
        description="Only compute loss on response/assistant tokens (not instruction/user tokens). Uses -100 label masking."
    )

    # Role markers for multi-turn conversations
    user_token: str = Field(
        default="<user>",
        description="Special token representing user/instruction role in conversations"
    )
    assistant_token: str = Field(
        default="<assistant>",
        description="Special token representing assistant/response role in conversations"
    )
class RLConfig(TrainingConfig):
    """Configuration for reinforcement learning training (PPO or DPO).

    ``reward_model_path`` is required only when ``rl_method`` is PPO. DPO
    (Direct Preference Optimization) trains on preference data without a
    separately trained reward model, so DPO configs may omit the path.
    NOTE(review): confirm the DPO trainer does not read ``reward_model_path``.
    """

    mode: Literal[TrainingModeEnum.RL] = TrainingModeEnum.RL

    # RL-specific. Declared before reward_model_path so its validator can see it.
    rl_method: RLMethodEnum = Field(
        default=RLMethodEnum.PPO,
        description="RL training method (PPO or DPO)"
    )

    # Reward model (required for PPO, enforced by validate_reward_model_path)
    reward_model_path: Optional[str] = Field(
        default=None,
        description="Path to trained reward model checkpoint (required when rl_method is ppo)"
    )

    # PPO-specific
    ppo_epochs: int = Field(default=4, description="PPO inner epochs")
    ppo_clip_ratio: float = Field(default=0.2, description="PPO clipping ratio")
    entropy_coeff: float = Field(default=0.01, description="Entropy bonus coefficient")
    value_loss_coeff: float = Field(default=1.0, description="Value function loss coefficient")

    # DPO-specific (Direct Preference Optimization)
    dpo_beta: float = Field(default=0.1, description="DPO inverse temperature (beta)")

    # Prompt distribution
    prompt_dataset: Optional[DatasetConfig] = Field(
        default=None,
        description="Separate dataset for prompts (if different from main dataset)"
    )
    generation_max_length: int = Field(
        default=256,
        description="Maximum length for generated responses during RL"
    )

    @validator('reward_model_path', always=True)
    def validate_reward_model_path(cls, v, values):
        """Require a reward model checkpoint path when training with PPO.

        ``always=True`` runs the check even when the field is left at its
        ``None`` default. If ``rl_method`` failed its own validation, it is
        missing from ``values`` and this check is skipped; the ``rl_method``
        error is reported instead.
        """
        if values.get('rl_method') == RLMethodEnum.PPO and not v:
            raise ValueError("reward_model_path must be provided when rl_method='ppo'")
        return v
# ============================================================================
# Factory function
# ============================================================================
def load_config(path: str | Path, mode: TrainingModeEnum | str) -> TrainingConfig:
    """Load a training config file into the config class for ``mode``.

    Args:
        path: Path to a ``.yaml``/``.yml`` or ``.json`` file. The extension is
            matched case-insensitively, so ``CONFIG.YAML`` is accepted.
        mode: Training stage, as a ``TrainingModeEnum`` member or its string
            value (``"pretrain"``, ``"sft"``, ``"rl"``).

    Returns:
        A ``PretrainConfig``, ``SFTConfig`` or ``RLConfig`` instance. A ``mode``
        key inside the file must agree with ``mode``, because each class pins
        its ``mode`` field to a single Literal value.

    Raises:
        ValueError: If ``mode`` is an unknown string or the file extension is
            not supported.
    """
    if isinstance(mode, str):
        mode = TrainingModeEnum(mode)

    config_map = {
        TrainingModeEnum.PRETRAIN: PretrainConfig,
        TrainingModeEnum.SFT: SFTConfig,
        TrainingModeEnum.RL: RLConfig,
    }
    config_class = config_map[mode]

    path = Path(path)
    suffix = path.suffix.lower()
    if suffix in ('.yaml', '.yml'):
        return config_class.load_yaml(path)
    if suffix == '.json':
        return config_class.load_json(path)
    raise ValueError(f"Unsupported config file format: {path.suffix}")
def load_tokenizer_config(path: str | Path) -> TokenizerConfig:
    """Load a ``TokenizerConfig`` from a YAML or JSON file.

    Args:
        path: Path to a ``.yaml``/``.yml`` or ``.json`` file. The extension is
            matched case-insensitively.

    Returns:
        The validated ``TokenizerConfig``.

    Raises:
        ValueError: If the file extension is not supported.
    """
    path = Path(path)
    suffix = path.suffix.lower()
    if suffix in ('.yaml', '.yml'):
        return TokenizerConfig.load_yaml(path)
    if suffix == '.json':
        return TokenizerConfig.load_json(path)
    raise ValueError(f"Unsupported config file format: {path.suffix}")