# typed dataclass wrappers over the yaml configs. consumed by every entry
# point so hyperparameters never sneak into code.

from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Union

import yaml


def _load_yaml_mapping(path: Union[str, Path]) -> dict:
    """Read *path* as UTF-8 YAML and return its top-level mapping.

    Raises:
        ValueError: if the document is empty or its top level is not a
            mapping (``yaml.safe_load`` returns ``None``, a list or a scalar
            in those cases, which the loaders cannot unpack).
    """
    # Explicit encoding: read_text() would otherwise use the locale default,
    # which can misdecode non-ASCII config values on non-UTF-8 systems.
    raw = yaml.safe_load(Path(path).read_text(encoding="utf-8"))
    if not isinstance(raw, dict):
        raise ValueError(
            f"{path}: expected a YAML mapping at the top level, "
            f"got {type(raw).__name__}"
        )
    return raw


# ---------- training ----------


@dataclass
class LoraConfig:
    """Values of the ``lora`` section of a training config."""

    r: int
    alpha: int
    dropout: float
    bias: str
    target_modules: list


@dataclass
class TrainConfig:
    """Typed view of a training YAML file.

    Every field is required (none has a default); the nested ``lora``
    section is held as a :class:`LoraConfig`.
    """

    base_model: str
    language: str
    data_dir: str
    max_seq_length: int
    # Optional: the YAML may set these to null.
    train_rows: Optional[int]
    val_rows: Optional[int]
    lora: LoraConfig
    learning_rate: float
    weight_decay: float
    warmup_ratio: float
    max_grad_norm: float
    adam_beta1: float
    adam_beta2: float
    adam_epsilon: float
    num_epochs: int
    lr_scheduler_type: str
    train_batch_size: int
    eval_batch_size: int
    gradient_accumulation_steps: int
    bf16: bool
    fp16: bool
    tf32: bool
    seed: int
    metric_for_best_model: str
    greater_is_better: bool
    save_total_limit: int
    logging_steps: int
    eval_steps: int
    save_steps: int
    dataloader_num_workers: int


def load_train_config(path) -> TrainConfig:
    """Load a training config from the YAML file at *path*.

    Raises:
        ValueError: if the file is empty or not a top-level mapping.
        KeyError: if the ``lora`` section is missing.
        TypeError: if keys do not match the dataclass fields (missing or
            unexpected keys).
    """
    raw = _load_yaml_mapping(path)
    # Pop the nested section first so the remaining keys map 1:1 onto the
    # flat TrainConfig fields.
    lora = LoraConfig(**raw.pop("lora"))
    return TrainConfig(lora=lora, **raw)


# ---------- data ----------


@dataclass
class DataSplits:
    """Values of the ``splits`` section of a data config."""

    train: float
    val: float
    test: float


@dataclass
class DataConfig:
    """Typed view of a data YAML file."""

    seed_path: str
    splits: DataSplits
    random_seed: int


def load_data_config(path) -> DataConfig:
    """Load a data config from the YAML file at *path*.

    Raises:
        ValueError: if the file is empty or not a top-level mapping.
        KeyError: if the ``splits`` section is missing.
        TypeError: if keys do not match the dataclass fields.
    """
    raw = _load_yaml_mapping(path)
    splits = DataSplits(**raw.pop("splits"))
    return DataConfig(splits=splits, **raw)


# ---------- optional inject (v1.1) ----------


@dataclass
class InjectSampling:
    """Values of the ``sampling`` section of an inject config."""

    ops_per_example_min: int
    ops_per_example_max: int


@dataclass
class InjectConfig:
    """Typed view of an inject YAML file.

    ``ops`` defaults to an empty dict when the file has no ``ops`` key.
    """

    sampling: InjectSampling
    ops: dict = field(default_factory=dict)


def load_inject_config(path) -> InjectConfig:
    """Load an inject config from the YAML file at *path*.

    Raises:
        ValueError: if the file is empty or not a top-level mapping.
        KeyError: if the ``sampling`` section is missing.
        TypeError: if keys do not match the dataclass fields.
    """
    raw = _load_yaml_mapping(path)
    sampling = InjectSampling(**raw.pop("sampling"))
    return InjectConfig(sampling=sampling, **raw)