| from __future__ import annotations |
|
|
| from dataclasses import dataclass, field |
| from pathlib import Path |
| from typing import Any |
|
|
| import yaml |
|
|
| from .utils import resolve_repo_path |
|
|
|
|
@dataclass
class ModelConfig:
    """Model selection and optional LoRA adapter settings (the ``model:`` YAML section).

    Only ``name`` is required; every other field carries a default.
    """

    # Identifier of the model to load — presumably a hub id or local path;
    # TODO confirm against the loader (not visible in this module).
    name: str

    # Loading switches; they are consumed outside this module.
    trust_remote_code: bool = False
    load_in_4bit: bool = False

    # LoRA adapter toggle plus its hyperparameters (rank / alpha / dropout by
    # the usual naming — confirm at the call site that builds the adapter).
    use_lora_adapters: bool = False
    lora_r: int = 16
    lora_alpha: int = 16
    lora_dropout: float = 0.0
|
|
|
|
@dataclass
class TrainerConfig:
    """Training-loop settings populated from the ``trainer:`` YAML section.

    ``output_dir`` is the only required field. When loaded through
    ``load_config``, the YAML key ``optimizer`` is accepted as a synonym for
    ``optim``, the shorthands ``adamw``/``adamw_fused`` and ``cosine_decay``
    are mapped to ``adamw_torch``/``adamw_torch_fused`` and ``cosine``, and the
    numeric fields are coerced with ``float()``/``int()``.
    """

    # --- run identity ---
    output_dir: str
    run_name: str = "grpo-run"

    # --- run length (presumably max_steps=-1 defers to num_train_epochs, as
    # in HF TrainingArguments — confirm with the trainer that consumes this) ---
    max_steps: int = -1
    num_train_epochs: float = 1.0

    # --- batching ---
    per_device_train_batch_size: int = 1
    gradient_accumulation_steps: int = 8

    # --- optimisation ---
    learning_rate: float = 1.0e-6

    # --- logging and rolling checkpoints ---
    logging_steps: int = 1
    save_steps: int = 25
    save_total_limit: int = 5

    # --- numerics / reproducibility / reporting ---
    bf16: bool = True
    seed: int = 42
    report_to: str = "wandb"

    # --- optimiser and gradients ---
    optim: str = "adamw_torch"
    gradient_checkpointing: bool = True
    max_grad_norm: float = 1.0
    shuffle_dataset: bool = False

    # --- learning-rate schedule ---
    lr_scheduler_type: str = "cosine"
    lr_scheduler_kwargs: dict[str, Any] = field(default_factory=dict)
    warmup_steps: int = 20

    # --- sample logging: how many examples and how many characters each ---
    sanity_log_examples: int = 8
    sanity_log_max_chars: int = 300

    # --- checkpoints kept outside the rolling save_total_limit window
    # (presumably; the consumer is not visible here) ---
    permanent_checkpoint_steps: int = 300
    permanent_checkpoint_dir: str = "checkpoints/permanent"
|
|
|
|
|
|
@dataclass
class DataConfig:
    """Dataset selection, populated from the optional ``data:`` YAML section."""

    # Name of the data provider to use; resolved outside this module.
    provider: str = "gsm8k_math_curriculum"
    split: str = "train"
    # Presumably None means "no cap on the number of samples" — confirm in the provider.
    max_samples: int | None = None
|
|
|
|
@dataclass
class GenerationConfig:
    """Sampling settings, populated from the optional ``generation:`` YAML section."""

    # Length limits — presumably measured in tokens; TODO confirm with the generator.
    max_prompt_length: int = 512
    max_completion_length: int = 256
    # Completions sampled per prompt.
    num_generations: int = 4
    # Sampling parameters.
    temperature: float = 0.9
    top_p: float = 0.95
|
|
|
|
@dataclass
class ObjectiveConfig:
    """Training-objective selection, populated from the optional ``objective:`` YAML section."""

    # Registered objective name.
    name: str = "token_grpo"
    # Extra keyword arguments for the objective; a fresh dict per instance.
    kwargs: dict[str, Any] = field(default_factory=dict)
    # Optional explicit class path — presumably overrides ``name`` when set;
    # TODO confirm in the objective factory.
    class_path: str | None = None
|
|
|
|
@dataclass
class RewardsConfig:
    """Reward settings, populated from the optional ``rewards:`` YAML section."""

    # Mapping of string keys to per-entry keyword dicts — presumably keyed by
    # reward name; verify against the reward registry. Fresh dict per instance.
    kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
|
|
|
|
@dataclass
class AuthConfig:
    """Credentials, populated from the optional ``auth:`` YAML section.

    Keys can be given inline (``*_api_key``) or by naming the environment
    variable that holds them (``*_api_key_env``). The inline secrets are
    excluded from ``repr()`` so printing or logging a config (including the
    enclosing ``ExperimentConfig``, whose repr nests this one) does not leak
    them. They still take part in ``==`` comparison.
    """

    # Inline secrets: hidden from repr to keep them out of logs and tracebacks.
    hf_api_key: str | None = field(default=None, repr=False)
    wandb_api_key: str | None = field(default=None, repr=False)
    # Names of environment variables that may hold the keys.
    hf_api_key_env: str = "HF_TOKEN"
    wandb_api_key_env: str = "WANDB_API_KEY"
|
|
|
|
@dataclass
class StorageConfig:
    """Storage locations, populated from the optional ``storage:`` YAML section."""

    # Cache directory; relative by default — how it is resolved is decided by
    # the consumer (not visible here).
    cache_dir: str = "cache"
|
|
|
|
@dataclass
class ThinkingKLConfig:
    """Scale KL penalty on completion tokens that overlap the *inner* redacted thinking body.

    Populated from the optional ``thinking_kl:`` YAML section.
    """

    # Multiplier for the KL penalty on those tokens; at the default 1.0 they are
    # presumably weighted like other completion tokens — confirm in the objective.
    inner_kl_weight: float = 1.0
|
|
|
|
@dataclass
class ExperimentConfig:
    """Top-level experiment configuration, one attribute per YAML section.

    ``model`` and ``trainer`` are required. Every other section is
    default-constructed when omitted, with a fresh instance per
    ``ExperimentConfig``.
    """

    # Required sections.
    model: ModelConfig
    trainer: TrainerConfig

    # Optional typed sections.
    data: DataConfig = field(default_factory=DataConfig)
    generation: GenerationConfig = field(default_factory=GenerationConfig)
    objective: ObjectiveConfig = field(default_factory=ObjectiveConfig)
    rewards: RewardsConfig = field(default_factory=RewardsConfig)
    auth: AuthConfig = field(default_factory=AuthConfig)
    storage: StorageConfig = field(default_factory=StorageConfig)
    thinking_kl: ThinkingKLConfig = field(default_factory=ThinkingKLConfig)

    # Free-form options from the ``grpo:`` section; not interpreted in this module.
    grpo: dict[str, Any] = field(default_factory=dict)
|
|
|
|
# Shorthand -> canonical names accepted in the trainer section. Values not
# listed here pass through unchanged.
_OPTIM_ALIASES: dict[str, str] = {
    "adamw": "adamw_torch",
    "adamw_fused": "adamw_torch_fused",
}
_SCHEDULER_ALIASES: dict[str, str] = {
    "cosine_decay": "cosine",
}

# Trainer keys coerced to numbers. PyYAML's safe_load follows YAML 1.1, which
# returns values such as ``1e-6`` (no decimal point) as strings, so numeric
# fields are converted explicitly.
_TRAINER_FLOAT_FIELDS = frozenset(
    {
        "learning_rate",
        "max_grad_norm",
        "num_train_epochs",
    }
)
_TRAINER_INT_FIELDS = frozenset(
    {
        "max_steps",
        "per_device_train_batch_size",
        "gradient_accumulation_steps",
        "logging_steps",
        "save_steps",
        "save_total_limit",
        "seed",
        "warmup_steps",
        "sanity_log_examples",
        "sanity_log_max_chars",
        "permanent_checkpoint_steps",
    }
)


def _as_mapping(value: Any, section: str) -> dict[str, Any]:
    """Return a shallow-copied dict for one YAML section.

    A missing or ``null`` section becomes ``{}`` so dataclass defaults apply.
    Any other non-mapping value raises ``TypeError``, naming the section.
    """
    if value is None:
        return {}
    if not isinstance(value, dict):
        raise TypeError(
            f"config section {section!r} must be a mapping, "
            f"got {type(value).__name__}"
        )
    return dict(value)


def _normalize_trainer_section(trainer: dict[str, Any]) -> dict[str, Any]:
    """Apply key/value aliases and numeric coercion to *trainer* in place and return it."""
    # ``optimizer`` is a synonym for ``optim``. An explicit ``optim`` wins, and
    # the synonym is always dropped because TrainerConfig has no such field.
    if "optimizer" in trainer:
        optimizer = trainer.pop("optimizer")
        trainer.setdefault("optim", optimizer)

    if "optim" in trainer:
        trainer["optim"] = _OPTIM_ALIASES.get(trainer["optim"], trainer["optim"])
    if "lr_scheduler_type" in trainer:
        scheduler = trainer["lr_scheduler_type"]
        trainer["lr_scheduler_type"] = _SCHEDULER_ALIASES.get(scheduler, scheduler)

    for key in _TRAINER_FLOAT_FIELDS:
        if key in trainer:
            trainer[key] = float(trainer[key])
    for key in _TRAINER_INT_FIELDS:
        if key in trainer:
            trainer[key] = int(trainer[key])
    return trainer


def load_config(path: str | Path) -> ExperimentConfig:
    """Load an experiment YAML file into an ``ExperimentConfig``.

    Args:
        path: Config file location, resolved via ``resolve_repo_path``.

    Returns:
        The populated ``ExperimentConfig``. Optional sections that are absent
        or ``null`` fall back to their dataclass defaults.

    Raises:
        KeyError: if the required ``model`` or ``trainer`` section is missing.
        TypeError: if the document or a section is not a mapping, a section
            has keys its dataclass does not define, or a required field
            (e.g. ``model.name``) is missing.
        ValueError: if a numeric trainer field holds an unparsable string.
    """
    resolved = resolve_repo_path(path)
    with resolved.open("r", encoding="utf-8") as handle:
        # safe_load only builds plain YAML types (no arbitrary Python objects).
        raw = yaml.safe_load(handle)

    # An empty file yields None. Report that clearly instead of failing later
    # with "'NoneType' object is not subscriptable".
    if not isinstance(raw, dict):
        raise TypeError(
            f"{resolved}: expected a YAML mapping at the top level, "
            f"got {type(raw).__name__}"
        )

    # Work on a copy so the parsed document is never mutated.
    trainer_raw = _normalize_trainer_section(_as_mapping(raw["trainer"], "trainer"))

    return ExperimentConfig(
        model=ModelConfig(**_as_mapping(raw["model"], "model")),
        trainer=TrainerConfig(**trainer_raw),
        data=DataConfig(**_as_mapping(raw.get("data"), "data")),
        generation=GenerationConfig(**_as_mapping(raw.get("generation"), "generation")),
        objective=ObjectiveConfig(**_as_mapping(raw.get("objective"), "objective")),
        rewards=RewardsConfig(**_as_mapping(raw.get("rewards"), "rewards")),
        auth=AuthConfig(**_as_mapping(raw.get("auth"), "auth")),
        storage=StorageConfig(**_as_mapping(raw.get("storage"), "storage")),
        thinking_kl=ThinkingKLConfig(**_as_mapping(raw.get("thinking_kl"), "thinking_kl")),
        grpo=_as_mapping(raw.get("grpo"), "grpo"),
    )
|
|