"""Pydantic configuration schemas for TaoTrain.""" from enum import Enum from typing import Optional, Literal from pathlib import Path import json from pydantic import BaseModel as PydanticBaseModel, Field, validator import yaml # ============================================================================ # Enums # ============================================================================ class DataTypeEnum(str, Enum): """Data types for training.""" FLOAT32 = "float32" FLOAT16 = "float16" BFLOAT16 = "bfloat16" class OptimizerEnum(str, Enum): """Supported optimizers.""" ADAM = "adam" ADAMW = "adamw" SGD = "sgd" HYBRID_MUON_ADAMW = "hybrid_muon_adamw" class ModelArchitectureEnum(str, Enum): """Built-in model architectures.""" TRANSFORMER = "transformer" TAONET = "taonet" TAONET_SSM = "taonet_ssm" TAONET_HYBRID = "taonet_hybrid" class SchedulerEnum(str, Enum): """Supported learning rate schedulers.""" LINEAR_WARMUP = "linearWarmup" COSINE_WARMUP = "cosineWarmup" CONSTANT = "constant" class RLMethodEnum(str, Enum): """Supported RL training methods.""" PPO = "ppo" DPO = "dpo" class TrainingModeEnum(str, Enum): """Training stages.""" PRETRAIN = "pretrain" SFT = "sft" RL = "rl" # ============================================================================ # Base Configs # ============================================================================ class BaseConfig(PydanticBaseModel): """Base Pydantic model with utility methods.""" class Config: """Pydantic config.""" arbitrary_types_allowed = True def to_dict(self) -> dict: """Convert to dictionary.""" data = self.model_dump(mode='json') # Enums -> strings return data def to_json_str(self) -> str: """Convert to JSON string.""" return json.dumps(self.to_dict(), indent=2) def save_yaml(self, path: str | Path) -> None: """Save config to YAML file.""" path = Path(path) path.parent.mkdir(parents=True, exist_ok=True) with open(path, 'w') as f: yaml.dump(self.to_dict(), f, default_flow_style=False, sort_keys=False) def save_json(self, path: str | Path) -> None: """Save config to JSON file.""" path = Path(path) path.parent.mkdir(parents=True, exist_ok=True) with open(path, 'w') as f: f.write(self.to_json_str()) @classmethod def load_yaml(cls, path: str | Path) -> "BaseConfig": """Load config from YAML file.""" with open(path) as f: data = yaml.safe_load(f) return cls(**data) @classmethod def load_json(cls, path: str | Path) -> "BaseConfig": """Load config from JSON file.""" with open(path) as f: data = json.load(f) return cls(**data) # ============================================================================ # Model Config # ============================================================================ class ModelConfig(BaseConfig): """Configuration for model architecture.""" architecture_type: ModelArchitectureEnum = Field( default=ModelArchitectureEnum.TRANSFORMER, description="Type of model architecture" ) # Transformer-specific vocab_size: int = Field(default=50257, description="Vocabulary size") hidden_dim: int = Field(default=768, description="Hidden dimension") num_layers: int = Field(default=12, description="Number of transformer blocks") num_heads: int = Field(default=12, description="Number of attention heads") head_dim: Optional[int] = Field( default=None, description="Head dimension (defaults to hidden_dim // num_heads)" ) intermediate_dim: Optional[int] = Field( default=None, description="FFN intermediate dimension (defaults to 4 * hidden_dim)" ) dropout: float = Field(default=0.1, description="Dropout rate") max_seq_length: int = Field(default=2048, description="Maximum sequence length") # TaoNet (DeepSeek MLA) specific d_latent_kv: Optional[int] = Field( default=None, description="KV compression dimension for MLA (defaults to 3/4 * hidden_dim). Only used for taonet architecture." ) d_rope: Optional[int] = Field( default=None, description="RoPE dimension per head (defaults to hidden_dim // num_heads). Only used for taonet architecture." ) gqa_groups: int = Field( default=1, description="Grouped Query Attention groups (1 = standard MLA, >1 = GQA). Only used for taonet architecture." ) hidden_dim_ff: Optional[int] = Field( default=None, description="Feed-forward intermediate dimension (defaults to 4 * hidden_dim)." ) use_factorized_embedding: bool = Field( default=False, description="Use low-rank factorized embedding instead of standard embedding (reduces params). Only for taonet." ) d_embed_rank: int = Field( default=96, description="Rank dimension for factorized embedding. Only used if use_factorized_embedding=True." ) # YaRN (Yet another RoPE eXtension) for context length extension rope_scale: float = Field( default=40.0, description="Base RoPE scale factor (default: 40.0). Controls position frequency base." ) yarn_enabled: bool = Field( default=False, description="Enable YaRN (Yet another RoPE eXtension) for context length interpolation." ) yarn_alpha: float = Field( default=1.0, description="YaRN interpolation smoothness (1.0=smooth, <1.0=aggressive, >1.0=conservative). Only used if yarn_enabled=True." ) # TaoNet-SSM specific: SSM mixer replacing MLA attention ssm_core: Literal["gamma_s4", "dplr"] = Field( default="gamma_s4", description="SSM core used by taonet_ssm. Use dplr for the ternary-aware DPLR SSM." ) ssm_hidden_dim: Optional[int] = Field( default=None, description="SSM hidden/state dimension for taonet_ssm. Defaults to d_latent_kv or hidden_dim." ) ssm_mixer_dim: Optional[int] = Field( default=None, description="Channel dimension processed by the SSM mixer. Defaults to hidden_dim; smaller values use an input/output projection bottleneck." ) ssm_num_lanes: int = Field( default=1, description="Number of independent SSM lanes inside each SSM mixer. Multiple lanes add SSM capacity with cheap elementwise combination." ) ssm_lane_combine: Literal["mean", "channel"] = Field( default="mean", description="How to combine multiple SSM lanes. Channel uses learned per-lane/per-channel elementwise weights." ) ssm_lane_mode: Literal["full", "split"] = Field( default="full", description="Whether each SSM lane processes the full mixer dimension or a disjoint split of the mixer channels." ) ssm_split_mix: Literal["none", "hadamard"] = Field( default="none", description="Optional ternary-friendly cross-lane mixer for split SSM lanes." ) ssm_rank: int = Field( default=1, description="Low-rank correction rank for ssm_core=dplr." ) ssm_max_low_rank_scale: float = Field( default=0.1, description="Maximum low-rank correction scale for ssm_core=dplr." ) ssm_finite_tail_correction: bool = Field( default=True, description="Enable exact finite-length tail correction for ssm_core=dplr. Disable for the faster approximate DPLR path." ) ssm_discretization: Literal["bilinear", "zoh", "euler"] = Field( default="bilinear", description="Discretization used by the Gamma SSM mixer." ) ssm_kernel_mode: Literal["auto", "recurrent", "conv", "conv_transfer"] = Field( default="auto", description="Gamma SSM execution path. Use auto/conv for full-sequence GPU training, conv_transfer to materialize frequency transfers, recurrent for step-wise tests." ) ssm_kernel_threshold: int = Field( default=64, description="Minimum sequence length for auto mode to use the convolutional Gamma SSM path." ) ssm_dt_min: float = Field(default=1e-3, description="Minimum learned SSM timestep.") ssm_dt_max: float = Field(default=1e-1, description="Maximum learned SSM timestep.") ssm_dt_init: float = Field(default=1e-2, description="Initial learned SSM timestep.") ssm_use_d: bool = Field(default=True, description="Enable direct skip term D in the Gamma SSM.") ssm_activation: Literal["gelu", "silu", "identity", "linear"] = Field( default="gelu", description="Activation applied to the Gamma SSM branch output." ) ssm_gate: bool = Field(default=True, description="Enable output gate on the Gamma SSM branch.") ssm_input_gate: bool = Field(default=True, description="Enable input gate before the Gamma SSM.") ssm_gate_type: Literal["dense", "channel"] = Field( default="dense", description="Gate implementation for enabled SSM input/output gates. Channel gates are elementwise and ternary-friendly." ) ssm_use_padding_mask: bool = Field( default=False, description="Apply dataset padding masks inside the SSM. Disabled by default so training can use the convolutional path." ) ssm_layer_scale_init: float = Field( default=0.1, description="Initial layer-scale multiplier for the Gamma SSM branch." ) ssm_branch_rms_norm: bool = Field( default=False, description="Normalize the SSM residual branch to unit RMS before layer-scale. Useful for stabilizing deep SSM/hybrid runs." ) ssm_branch_rms_eps: float = Field( default=1e-6, description="Numerical epsilon for optional SSM branch RMS normalization." ) ssm_branch_clip_value: Optional[float] = Field( default=None, description="Optional symmetric clamp applied to the SSM residual branch after layer-scale. None disables clamping." ) block_residual_rms_norm: bool = Field( default=False, description="Normalize the residual stream RMS after block residual additions. Intended for stabilizing deep SSM/hybrid experiments." ) block_residual_rms_target: float = Field( default=1.0, description="Target per-token RMS when block_residual_rms_norm is enabled." ) block_residual_rms_cap: Optional[float] = Field( default=None, description="Optional per-token RMS cap for the residual stream. Unlike block_residual_rms_norm, this only scales down tokens whose RMS exceeds the cap." ) block_residual_rms_eps: float = Field( default=1e-6, description="Numerical epsilon for optional block residual RMS normalization." ) ssm_local_shift: bool = Field( default=False, description="Add a cheap one-token causal shift/register branch to the taonet_ssm mixer." ) ssm_local_shift_init: float = Field( default=0.1, description="Initial scalar weight for the optional one-token local shift/register branch." ) ssm_local_shift_per_channel: bool = Field( default=False, description="Use one learned local-shift gain per model channel instead of one scalar." ) hybrid_pattern: Literal["attention_first", "ssm_first", "single_ssm_middle", "single_ssm_late"] = Field( default="attention_first", description="Layer pattern for taonet_hybrid when hybrid_ssm_layers is not set." ) hybrid_ssm_layers: Optional[str] = Field( default=None, description="Optional comma-separated 0-based layer indices that should use SSM blocks in taonet_hybrid." ) # Initializations init_std: float = Field(default=0.02, description="Weight initialization standard deviation") @validator("head_dim", always=True) def validate_head_dim(cls, v, values): """Validate head dimension.""" if v is None and 'hidden_dim' in values: return values['hidden_dim'] // values.get('num_heads', 12) return v @validator("intermediate_dim", always=True) def validate_intermediate_dim(cls, v, values): """Validate intermediate dimension.""" if v is None and 'hidden_dim' in values: return 4 * values['hidden_dim'] return v # ============================================================================ # Dataset Config # ============================================================================ class DatasetConfig(BaseConfig): """Configuration for dataset loading.""" # Local vs HuggingFace dataset selection local: bool = Field(default=False, description="Use local JSONL dataset instead of HuggingFace") # HuggingFace dataset fields dataset_name: Optional[str] = Field(default=None, description="HuggingFace dataset name (e.g., 'wikitext', 'openwebtext')") split: str = Field(default="train", description="Dataset split to use") config: Optional[str] = Field(default=None, description="Dataset config if multi-config (e.g., 'wikitext-103')") # Local JSONL dataset fields jsonl_path: Optional[str] = Field(default=None, description="Path to local JSONL dataset file") text_field: str = Field(default="text", description="Name of text field in JSONL") # Text column name varies by dataset text_column: str = Field(default="text", description="Name of text column in dataset") # Preprocessing max_samples: Optional[int] = Field( default=None, description="Limit dataset to N samples (useful for debugging)" ) cache_dir: str = Field(default=".cache/datasets", description="HuggingFace cache directory") # For SFT/RL datasets with instruction-response format instruction_column: Optional[str] = Field(default=None, description="Instruction column for SFT") response_column: Optional[str] = Field(default=None, description="Response column for SFT") prompt_column: Optional[str] = Field(default=None, description="Prompt column for RL") # Instruction template instruction_template: Optional[str] = Field( default=None, description="Template for combining instruction and response. E.g., '{instruction}\\n{response}'" ) # Tokenizer configuration tokenizer_type: Optional[str] = Field( default=None, description="Tokenizer type: 'huggingface' or 'sentencepiece'. If None, defaults based on tokenizer_path." ) tokenizer_path: Optional[str] = Field( default=None, description="Path to saved tokenizer (for SentencePiece: .model file, for HuggingFace: model name or local path)" ) # Chunked loading for large JSONL files enable_streaming: bool = Field( default=True, description="Enable streaming/chunked loading for large JSONL files to reduce memory usage" ) chunk_size_gb: float = Field( default=5.0, description="Approximate chunk size in GB (ignored if samples_per_chunk is set)" ) samples_per_chunk: Optional[int] = Field( default=1000, description="Number of samples per chunk (takes precedence over chunk_size_gb). Default: 1000 samples" ) # Chunk caching enable_chunk_metadata_cache: bool = Field( default=True, description="Enable caching of chunk metadata (file scan results) to avoid re-scanning large JSONL files" ) enable_chunk_data_cache: bool = Field( default=False, description="Enable caching of actual chunk data as separate files for faster loading (uses more disk space)" ) chunk_cache_dir: str = Field( default=".cache/chunks", description="Directory to store chunk metadata and data cache files" ) # Tokenization parallelization tokenizer_threads: int = Field( default=1, description="Number of background threads for tokenization (1-32 recommended). Higher values speed up tokenization but increase memory usage." ) @validator('jsonl_path', always=True) def validate_dataset_source(cls, v, values): """Validate that either local JSONL or HuggingFace dataset is specified.""" local = values.get('local', False) dataset_name = values.get('dataset_name') if local and not v: raise ValueError("jsonl_path must be provided when local=True") if not local and not dataset_name: raise ValueError("dataset_name must be provided when local=False (HuggingFace dataset)") return v @validator('tokenizer_threads') def validate_tokenizer_threads(cls, v): """Validate tokenizer_threads is a positive integer.""" if v < 1: raise ValueError("tokenizer_threads must be at least 1") if v > 128: raise ValueError("tokenizer_threads should not exceed 128 (recommended: 1-32)") return v # ============================================================================ # Tokenizer Config # ============================================================================ class TokenizerConfig(BaseConfig): """Configuration for tokenizer training.""" # Dataset source jsonl_path: str = Field(description="Path to JSONL file containing training data") text_field: str = Field(default="text", description="Field name in JSONL for text data") # Training configuration vocab_size: int = Field(default=50000, description="Vocabulary size") model_type: str = Field(default="unigram", description="SentencePiece model type (unigram, bpe, char, word)") character_coverage: float = Field( default=0.9995, description="Character coverage for SentencePiece training" ) output_dir: str = Field(default="tokenizers", description="Directory to save trained tokenizer") tokenizer_prefix: Optional[str] = Field( default=None, description="Prefix for tokenizer output files (default: model_type)" ) # SentencePiece token IDs unk_id: int = Field(default=0, description="Unknown token ID") bos_id: int = Field(default=1, description="Beginning of sentence token ID") eos_id: int = Field(default=2, description="End of sentence token ID") pad_id: int = Field(default=3, description="Padding token ID") # Custom special tokens - add custom tokens like , , , , , , , special_tokens: Optional[dict[str, int]] = Field( default=None, description="Custom special tokens mapping: {token: id}. Example: {'': 4, '': 5, '': 6, '': 7}" ) # Data sampling max_samples: Optional[int] = Field( default=None, description="Limit training to first N samples from JSONL (useful for quick testing)" ) # Tokenizer metadata tokenizer_name: Optional[str] = Field( default=None, description="Optional name for the tokenizer" ) # ============================================================================ # Training Config # ============================================================================ class OptimizerConfig(BaseConfig): """Optimizer configuration.""" optimizer_type: OptimizerEnum = Field(default=OptimizerEnum.ADAMW, description="Optimizer type") learning_rate: float = Field(default=1e-4, description="Peak learning rate (for Muon 2D weights)") adamw_lr: Optional[float] = Field( default=None, description="Learning rate for AdamW (1D parameters). If None, defaults to learning_rate / 10. Used in hybrid_muon_adamw optimizer." ) weight_decay: float = Field(default=1e-2, description="Weight decay (L2 regularization)") betas: tuple[float, float] = Field(default=(0.9, 0.999), description="Adam betas") eps: float = Field(default=1e-8, description="Optimizer epsilon") @validator('adamw_lr', always=True) def set_default_adamw_lr(cls, v, values): """Set default adamw_lr as 1/10 of learning_rate if not specified.""" if v is None and 'learning_rate' in values: return values['learning_rate'] / 10 return v class SchedulerConfig(BaseConfig): """Learning rate scheduler configuration.""" scheduler_type: SchedulerEnum = Field(default=SchedulerEnum.LINEAR_WARMUP, description="Scheduler type") warmup_steps: int = Field(default=0, description="Number of warmup steps (takes precedence over warmup_ratio)") warmup_ratio: float = Field(default=0.1, description="Warmup as fraction of total steps (used if warmup_steps=0)") # Cosine scheduler specific num_cycles: float = Field(default=0.5, description="Number of cycles for cosine schedule") last_epoch: int = Field(default=-1, description="Last epoch for scheduler") # TaoNet 3-phase scheduler (warmup -> steady -> cosine decay) steady_ratio: float = Field( default=0.0, description="Fraction of training steps at peak LR before cosine decay (0.0 = no steady phase). Only for cosineWarmup." ) min_lr_ratio: float = Field( default=0.0, description="Minimum LR as fraction of peak LR at end of training (0.0 = decay to 0). Only for cosineWarmup." ) @validator('warmup_ratio') def validate_warmup_ratio(cls, v): """Validate warmup ratio is between 0 and 1.""" if not 0 <= v <= 1: raise ValueError("warmup_ratio must be between 0 and 1") return v @validator('steady_ratio') def validate_steady_ratio(cls, v): """Validate steady ratio is between 0 and 1.""" if not 0 <= v <= 1: raise ValueError("steady_ratio must be between 0 and 1") return v @validator('min_lr_ratio') def validate_min_lr_ratio(cls, v): """Validate min_lr_ratio is between 0 and 1.""" if not 0 <= v <= 1: raise ValueError("min_lr_ratio must be between 0 and 1") return v @validator('warmup_steps') def validate_warmup_steps(cls, v): """Validate warmup steps is non-negative.""" if v < 0: raise ValueError("warmup_steps must be non-negative") return v class TrainingConfig(BaseConfig): """Base training configuration shared across all modes.""" # Data and model model: ModelConfig = Field(default_factory=ModelConfig, description="Model configuration") dataset: DatasetConfig = Field(description="Dataset configuration") # Training hyperparameters batch_size: int = Field(default=32, description="Batch size per device") num_epochs: int = Field(default=3, description="Number of training epochs") max_steps: Optional[int] = Field( default=None, description="Maximum steps (overrides num_epochs if set)" ) gradient_accumulation_steps: int = Field( default=1, description="Gradient accumulation steps" ) max_grad_norm: float = Field(default=1.0, description="Gradient clipping max norm") # Optimizer optimizer: OptimizerConfig = Field( default_factory=OptimizerConfig, description="Optimizer configuration" ) # Scheduler scheduler: SchedulerConfig = Field( default_factory=SchedulerConfig, description="Learning rate scheduler configuration" ) # Data type and device dtype: DataTypeEnum = Field( default=DataTypeEnum.BFLOAT16, description="Training data type" ) device: str = Field(default="cuda", description="Device to train on (cuda, cpu)") seed: int = Field(default=42, description="Random seed") # Checkpointing checkpoint_dir: str = Field(default="checkpoints", description="Directory to save checkpoints") checkpoint_path: Optional[str] = Field( default=None, description="Path to load pretrained checkpoint (for SFT/RL). If provided, loads weights before training starts." ) save_every_steps: int = Field(default=500, description="Save checkpoint every N steps") keep_last_n_checkpoints: int = Field(default=3, description="Keep only last N checkpoints") save_best_model: bool = Field(default=True, description="Save best model based on validation loss") # Validation eval_every_steps: int = Field(default=500, description="Evaluate every N steps") eval_samples: int = Field(default=1000, description="Number of validation samples") # Logging log_every_steps: int = Field(default=10, description="Log metrics every N steps") aim_repo: str = Field(default=".aim", description="AimStack repository path") # Misc num_workers: int = Field(default=0, description="Number of DataLoader workers") pin_memory: bool = Field(default=True, description="Pin memory for DataLoader") use_compile: bool = Field(default=False, description="Use torch.compile (experimental)") # Mode mode: TrainingModeEnum = Field(default=TrainingModeEnum.PRETRAIN, description="Training mode") # ============================================================================ # Stage-Specific Configs # ============================================================================ class PretrainConfig(TrainingConfig): """Configuration for pretraining.""" mode: Literal[TrainingModeEnum.PRETRAIN] = TrainingModeEnum.PRETRAIN # Pretraining-specific sequence_length: int = Field(default=1024, description="Sequence length for pretraining") class SFTConfig(TrainingConfig): """Configuration for supervised fine-tuning.""" mode: Literal[TrainingModeEnum.SFT] = TrainingModeEnum.SFT # SFT-specific response_loss_only: bool = Field( default=True, description="Only compute loss on response/assistant tokens (not instruction/user tokens). Uses -100 label masking." ) # Multi-turn conversation role tokens user_token: str = Field( default="", description="Special token representing user/instruction role in conversations" ) assistant_token: str = Field( default="", description="Special token representing assistant/response role in conversations" ) class RLConfig(TrainingConfig): """Configuration for reinforcement learning training.""" mode: Literal[TrainingModeEnum.RL] = TrainingModeEnum.RL # RL-specific rl_method: RLMethodEnum = Field( default=RLMethodEnum.PPO, description="RL training method (PPO or DPO)" ) # Reward model reward_model_path: str = Field(description="Path to trained reward model checkpoint") # PPO-specific ppo_epochs: int = Field(default=4, description="PPO inner epochs") ppo_clip_ratio: float = Field(default=0.2, description="PPO clipping ratio") entropy_coeff: float = Field(default=0.01, description="Entropy bonus coefficient") value_loss_coeff: float = Field(default=1.0, description="Value function loss coefficient") # DPO-specific (Direct Preference Optimization) dpo_beta: float = Field(default=0.1, description="DPO inverse temperature (beta)") # Prompt distribution prompt_dataset: Optional[DatasetConfig] = Field( default=None, description="Separate dataset for prompts (if different from main dataset)" ) generation_max_length: int = Field( default=256, description="Maximum length for generated responses during RL" ) # ============================================================================ # Factory function # ============================================================================ def load_config(path: str | Path, mode: TrainingModeEnum | str) -> TrainingConfig: """Load config file and return appropriate config class.""" if isinstance(mode, str): mode = TrainingModeEnum(mode) config_map = { TrainingModeEnum.PRETRAIN: PretrainConfig, TrainingModeEnum.SFT: SFTConfig, TrainingModeEnum.RL: RLConfig, } config_class = config_map[mode] path = Path(path) if path.suffix == '.yaml' or path.suffix == '.yml': return config_class.load_yaml(path) elif path.suffix == '.json': return config_class.load_json(path) else: raise ValueError(f"Unsupported config file format: {path.suffix}") def load_tokenizer_config(path: str | Path) -> TokenizerConfig: """Load tokenizer config from YAML or JSON file.""" path = Path(path) if path.suffix == '.yaml' or path.suffix == '.yml': return TokenizerConfig.load_yaml(path) elif path.suffix == '.json': return TokenizerConfig.load_json(path) else: raise ValueError(f"Unsupported config file format: {path.suffix}")