"""Experiment configuration dataclasses and the YAML loader that builds them."""

from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import yaml

from .utils import resolve_repo_path


@dataclass
class ModelConfig:
    """Settings read from the ``model`` section of the experiment YAML."""

    # Model identifier; presumably a Hugging Face hub id or local path — confirm.
    name: str
    trust_remote_code: bool = False
    load_in_4bit: bool = False
    use_lora_adapters: bool = False
    # LoRA hyperparameters; presumably only used when use_lora_adapters is True.
    lora_r: int = 16
    lora_alpha: int = 16
    lora_dropout: float = 0.0


@dataclass
class TrainerConfig:
    """Settings read from the ``trainer`` section; ``output_dir`` is required.

    ``load_config`` normalizes this section before construction: the legacy
    ``optimizer`` key is accepted in place of ``optim``, shorthand optimizer
    and scheduler names are expanded, and numeric fields supplied as strings
    are converted to ``float``/``int``.
    """

    output_dir: str
    run_name: str = "grpo-run"
    max_steps: int = -1
    num_train_epochs: float = 1.0
    per_device_train_batch_size: int = 1
    gradient_accumulation_steps: int = 8
    learning_rate: float = 1.0e-6
    logging_steps: int = 1
    save_steps: int = 25
    save_total_limit: int = 5
    bf16: bool = True
    seed: int = 42
    report_to: str = "wandb"
    optim: str = "adamw_torch"
    gradient_checkpointing: bool = True
    max_grad_norm: float = 1.0
    shuffle_dataset: bool = False
    lr_scheduler_type: str = "cosine"
    lr_scheduler_kwargs: dict[str, Any] = field(default_factory=dict)
    warmup_steps: int = 20
    sanity_log_examples: int = 8
    sanity_log_max_chars: int = 300
    permanent_checkpoint_steps: int = 300
    permanent_checkpoint_dir: str = "checkpoints/permanent"


@dataclass
class DataConfig:
    """Settings read from the ``data`` section."""

    provider: str = "gsm8k_math_curriculum"
    split: str = "train"
    # None presumably means "use every sample" — confirm with the data provider.
    max_samples: int | None = None


@dataclass
class GenerationConfig:
    """Sampling settings read from the ``generation`` section."""

    max_prompt_length: int = 512
    max_completion_length: int = 256
    num_generations: int = 4
    temperature: float = 0.9
    top_p: float = 0.95


@dataclass
class ObjectiveConfig:
    """Training-objective selection read from the ``objective`` section."""

    name: str = "token_grpo"
    kwargs: dict[str, Any] = field(default_factory=dict)
    # Presumably an import path that overrides ``name`` when set — confirm.
    class_path: str | None = None


@dataclass
class RewardsConfig:
    """Reward settings read from the ``rewards`` section."""

    # Per-reward keyword arguments; presumably keyed by reward name — confirm.
    kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)


@dataclass
class AuthConfig:
    """Credentials read from the ``auth`` section.

    The literal API-key fields are excluded from ``repr`` so that printing or
    logging a config object does not leak secrets. Note that
    ``dataclasses.asdict`` still includes them.
    """

    hf_api_key: str | None = field(default=None, repr=False)
    wandb_api_key: str | None = field(default=None, repr=False)
    # Environment-variable names; presumably consulted when the literal keys
    # above are unset — confirm with the caller.
    hf_api_key_env: str = "HF_TOKEN"
    wandb_api_key_env: str = "WANDB_API_KEY"


@dataclass
class StorageConfig:
    """Storage locations read from the ``storage`` section."""

    cache_dir: str = "cache"


@dataclass
class ThinkingKLConfig:
    """Scale KL penalty on completion tokens that overlap the *inner* redacted thinking body."""

    inner_kl_weight: float = 1.0


@dataclass
class ExperimentConfig:
    """Top-level config; ``model`` and ``trainer`` are required, the rest default."""

    model: ModelConfig
    trainer: TrainerConfig
    data: DataConfig = field(default_factory=DataConfig)
    generation: GenerationConfig = field(default_factory=GenerationConfig)
    objective: ObjectiveConfig = field(default_factory=ObjectiveConfig)
    rewards: RewardsConfig = field(default_factory=RewardsConfig)
    auth: AuthConfig = field(default_factory=AuthConfig)
    storage: StorageConfig = field(default_factory=StorageConfig)
    thinking_kl: ThinkingKLConfig = field(default_factory=ThinkingKLConfig)
    # Free-form options copied as-is from the ``grpo`` section (not validated here).
    grpo: dict[str, Any] = field(default_factory=dict)


# Shorthand spellings accepted in YAML, mapped to their canonical names.
_OPTIM_ALIASES: dict[str, str] = {
    "adamw": "adamw_torch",
    "adamw_fused": "adamw_torch_fused",
}
_SCHEDULER_ALIASES: dict[str, str] = {
    "cosine_decay": "cosine",
}

# Trainer fields coerced to numbers. PyYAML follows YAML 1.1, where values
# such as ``1e-6`` (no decimal point) load as strings rather than floats.
_TRAINER_FLOAT_FIELDS = frozenset(
    {
        "learning_rate",
        "max_grad_norm",
        "num_train_epochs",
    }
)
_TRAINER_INT_FIELDS = frozenset(
    {
        "max_steps",
        "per_device_train_batch_size",
        "gradient_accumulation_steps",
        "logging_steps",
        "save_steps",
        "save_total_limit",
        "seed",
        "warmup_steps",
        "sanity_log_examples",
        "sanity_log_max_chars",
        "permanent_checkpoint_steps",
    }
)


def _section(raw: dict[str, Any], name: str, *, required: bool = False) -> dict[str, Any]:
    """Return top-level section *name* as a dict, treating an empty (``null``) section as ``{}``."""
    if required and name not in raw:
        raise KeyError(f"config is missing required section {name!r}")
    value = raw.get(name)
    if value is None:
        # Covers both an absent key and ``name:`` written with no body.
        return {}
    if not isinstance(value, dict):
        raise TypeError(
            f"config section {name!r} must be a mapping, got {type(value).__name__}"
        )
    return value


def _coerce_float(key: str, value: Any) -> float:
    """Convert trainer field *key* to ``float``, naming the field on failure."""
    try:
        return float(value)
    except (TypeError, ValueError) as err:
        raise ValueError(f"trainer.{key} must be a number, got {value!r}") from err


def _coerce_int(key: str, value: Any) -> int:
    """Convert trainer field *key* to ``int``.

    Accepts ints (including bools, as ``int()`` does), integral floats, and
    numeric strings such as ``"8"``, ``"1e3"`` or ``"25.0"``. Non-integral
    values raise instead of being silently truncated.
    """
    if isinstance(value, str):
        text = value.strip()
        try:
            return int(text)
        except ValueError:
            pass  # fall through to float parsing for forms like "1e3" / "25.0"
        try:
            value = float(text)
        except ValueError as err:
            raise ValueError(f"trainer.{key} must be an integer, got {text!r}") from err
    if isinstance(value, float):
        if not value.is_integer():  # also rejects inf / nan
            raise ValueError(f"trainer.{key} must be a whole number, got {value!r}")
        return int(value)
    try:
        return int(value)
    except (TypeError, ValueError) as err:
        raise ValueError(f"trainer.{key} must be an integer, got {value!r}") from err


def _normalize_trainer(trainer_raw: dict[str, Any]) -> dict[str, Any]:
    """Return a normalized copy of the ``trainer`` section (the input is not mutated)."""
    trainer = dict(trainer_raw)

    # Backward-compatible alias: "optimizer" is accepted, but "optim" wins if both are set.
    if "optimizer" in trainer:
        legacy_optimizer = trainer.pop("optimizer")
        trainer.setdefault("optim", legacy_optimizer)

    # Expand common shorthand names; unknown values pass through unchanged.
    if "optim" in trainer:
        trainer["optim"] = _OPTIM_ALIASES.get(trainer["optim"], trainer["optim"])
    if "lr_scheduler_type" in trainer:
        trainer["lr_scheduler_type"] = _SCHEDULER_ALIASES.get(
            trainer["lr_scheduler_type"], trainer["lr_scheduler_type"]
        )

    # Normalize numeric fields that may come from YAML as strings.
    for key in _TRAINER_FLOAT_FIELDS:
        if key in trainer:
            trainer[key] = _coerce_float(key, trainer[key])
    for key in _TRAINER_INT_FIELDS:
        if key in trainer:
            trainer[key] = _coerce_int(key, trainer[key])
    return trainer


def load_config(path: str | Path) -> ExperimentConfig:
    """Load an experiment YAML file into an :class:`ExperimentConfig`.

    Args:
        path: Config file location, resolved via ``resolve_repo_path``
            (presumably relative to the repository root — confirm in ``utils``).

    Returns:
        The parsed config. The ``model`` and ``trainer`` sections are required;
        every other section falls back to its dataclass defaults when it is
        absent or empty.

    Raises:
        KeyError: the ``model`` or ``trainer`` section is missing.
        ValueError: the file is empty, or a numeric trainer field cannot be
            converted.
        TypeError: the document or a section is not a mapping, or a section
            contains keys its dataclass does not define.
    """
    resolved = resolve_repo_path(path)
    with resolved.open("r", encoding="utf-8") as handle:
        raw = yaml.safe_load(handle)

    if raw is None:
        raise ValueError(f"config file {resolved} is empty")
    if not isinstance(raw, dict):
        raise TypeError(
            f"config file {resolved} must contain a mapping at the top level, "
            f"got {type(raw).__name__}"
        )

    trainer_raw = _normalize_trainer(_section(raw, "trainer", required=True))
    return ExperimentConfig(
        model=ModelConfig(**_section(raw, "model", required=True)),
        trainer=TrainerConfig(**trainer_raw),
        data=DataConfig(**_section(raw, "data")),
        generation=GenerationConfig(**_section(raw, "generation")),
        objective=ObjectiveConfig(**_section(raw, "objective")),
        rewards=RewardsConfig(**_section(raw, "rewards")),
        auth=AuthConfig(**_section(raw, "auth")),
        storage=StorageConfig(**_section(raw, "storage")),
        thinking_kl=ThinkingKLConfig(**_section(raw, "thinking_kl")),
        grpo=dict(_section(raw, "grpo")),
    )