File size: 5,388 Bytes
dbc69f3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 | from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import yaml
from .utils import resolve_repo_path
@dataclass
class ModelConfig:
    """Settings built from the ``model`` section of an experiment config.

    ``name`` is the only required field; every other field has a default.
    How the values are consumed is decided by code outside this file.
    """

    # Required model identifier.
    name: str

    # Loading switches, all off by default.
    trust_remote_code: bool = False
    load_in_4bit: bool = False

    # LoRA adapter switch and its hyperparameters.
    use_lora_adapters: bool = False
    lora_r: int = 16
    lora_alpha: int = 16
    lora_dropout: float = 0.0
@dataclass
class TrainerConfig:
    """Settings built from the ``trainer`` section of an experiment config.

    ``output_dir`` is the only required field. The field names largely
    mirror common trainer arguments; how each one is interpreted is up to
    the training code that consumes this object (not visible in this file).
    """

    # Output location and run label.
    output_dir: str
    run_name: str = "grpo-run"

    # Training length, batch sizing and step size.
    max_steps: int = -1  # presumably "no step limit" -- confirm with the trainer
    num_train_epochs: float = 1.0
    per_device_train_batch_size: int = 1
    gradient_accumulation_steps: int = 8
    learning_rate: float = 1e-6

    # Logging and checkpoint cadence.
    logging_steps: int = 1
    save_steps: int = 25
    save_total_limit: int = 5

    # Precision, reproducibility, reporting and optimizer behaviour.
    bf16: bool = True
    seed: int = 42
    report_to: str = "wandb"
    optim: str = "adamw_torch"
    gradient_checkpointing: bool = True
    max_grad_norm: float = 1.0
    shuffle_dataset: bool = False

    # Learning-rate schedule.
    lr_scheduler_type: str = "cosine"
    lr_scheduler_kwargs: dict[str, Any] = field(default_factory=dict)
    warmup_steps: int = 20

    # Limits for sanity-check logging of examples.
    sanity_log_examples: int = 8
    sanity_log_max_chars: int = 300

    # Settings for the "permanent" checkpoints.
    permanent_checkpoint_steps: int = 300
    permanent_checkpoint_dir: str = "checkpoints/permanent"
@dataclass
class DataConfig:
    """Settings built from the optional ``data`` section of an experiment config.

    All fields have defaults, so the section may be omitted entirely.
    """

    # Key of the data provider; how it is resolved is not visible in this file.
    provider: str = "gsm8k_math_curriculum"
    # Dataset split to use.
    split: str = "train"
    # Optional sample cap; ``None`` leaves it unset.
    max_samples: int | None = None
@dataclass
class GenerationConfig:
    """Settings built from the optional ``generation`` section of an experiment config.

    All fields have defaults, so the section may be omitted entirely.
    """

    # Length limits (units are presumably tokens -- confirm with the consumer).
    max_prompt_length: int = 512
    max_completion_length: int = 256

    # Number of generations (presumably per prompt -- confirm with the consumer).
    num_generations: int = 4

    # Sampling parameters.
    temperature: float = 0.9
    top_p: float = 0.95
@dataclass
class ObjectiveConfig:
    """Settings built from the optional ``objective`` section of an experiment config.

    All fields have defaults, so the section may be omitted entirely.
    """

    # Name of the training objective.
    name: str = "token_grpo"
    # Free-form keyword arguments; each instance gets its own empty dict.
    kwargs: dict[str, Any] = field(default_factory=dict)
    # Optional class path (presumably a dotted import path -- confirm with the consumer).
    class_path: str | None = None
@dataclass
class RewardsConfig:
    """Settings built from the optional ``rewards`` section of an experiment config."""

    # Mapping from a string key to a dict of keyword arguments (presumably
    # keyed by reward name -- confirm with the consumer). Each instance gets
    # its own empty dict by default.
    kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
@dataclass
class AuthConfig:
    """Credentials built from the optional ``auth`` section of an experiment config.

    The two ``*_api_key`` fields hold secret values, so they are excluded
    from the generated ``repr``: printing or logging a config object must not
    leak credentials. They still take part in ``__init__`` and ``__eq__``
    exactly as before.
    """

    # Literal secrets; hidden from repr() so they never end up in logs.
    hf_api_key: str | None = field(default=None, repr=False)
    wandb_api_key: str | None = field(default=None, repr=False)

    # Presumably the names of environment variables to read the keys from
    # when they are not given literally -- confirm with the consumer.
    hf_api_key_env: str = "HF_TOKEN"
    wandb_api_key_env: str = "WANDB_API_KEY"
@dataclass
class StorageConfig:
    """Settings built from the optional ``storage`` section of an experiment config."""

    # Cache directory; how the path is resolved is not visible in this file.
    cache_dir: str = "cache"
@dataclass
class ThinkingKLConfig:
    """Scaling of the KL penalty for completion tokens inside the *inner* redacted thinking body."""

    # Weight applied to the KL penalty on those tokens.
    inner_kl_weight: float = 1.0
@dataclass
class ExperimentConfig:
    """Top-level experiment configuration, one attribute per config section.

    ``model`` and ``trainer`` are required. Every other section falls back
    to a freshly constructed default instance, and ``grpo`` to an empty dict.
    """

    # Required sections.
    model: ModelConfig
    trainer: TrainerConfig

    # Optional typed sections; default_factory gives each instance its own object.
    data: DataConfig = field(default_factory=DataConfig)
    generation: GenerationConfig = field(default_factory=GenerationConfig)
    objective: ObjectiveConfig = field(default_factory=ObjectiveConfig)
    rewards: RewardsConfig = field(default_factory=RewardsConfig)
    auth: AuthConfig = field(default_factory=AuthConfig)
    storage: StorageConfig = field(default_factory=StorageConfig)
    thinking_kl: ThinkingKLConfig = field(default_factory=ThinkingKLConfig)

    # Untyped pass-through mapping for the ``grpo`` section.
    grpo: dict[str, Any] = field(default_factory=dict)
# Shorthand spellings accepted in the YAML, mapped to their canonical names.
_OPTIM_ALIASES = {
    "adamw": "adamw_torch",
    "adamw_fused": "adamw_torch_fused",
}
_SCHEDULER_ALIASES = {
    "cosine_decay": "cosine",
}

# Trainer fields coerced to a numeric type because they may arrive from YAML
# as strings (PyYAML's default resolver reads an exponent literal without a
# decimal point, such as ``1e-6``, as a string).
_TRAINER_FLOAT_FIELDS = (
    "learning_rate",
    "max_grad_norm",
    "num_train_epochs",
)
_TRAINER_INT_FIELDS = (
    "max_steps",
    "per_device_train_batch_size",
    "gradient_accumulation_steps",
    "logging_steps",
    "save_steps",
    "save_total_limit",
    "seed",
    "warmup_steps",
    "sanity_log_examples",
    "sanity_log_max_chars",
    "permanent_checkpoint_steps",
)


def _normalize_trainer_section(trainer_raw: dict[str, Any]) -> dict[str, Any]:
    """Return a normalized copy of the raw ``trainer`` mapping.

    The input is never mutated; one shallow copy is made and edited in place.
    """
    trainer = dict(trainer_raw)

    # Backward-compatible alias: allow "optimizer" in YAML. An explicit
    # "optim" wins; the legacy key is always dropped so it does not reach
    # the TrainerConfig constructor.
    if "optimizer" in trainer:
        legacy_optim = trainer.pop("optimizer")
        trainer.setdefault("optim", legacy_optim)

    # Accept common shorthand names; unknown values pass through unchanged.
    if "optim" in trainer:
        trainer["optim"] = _OPTIM_ALIASES.get(trainer["optim"], trainer["optim"])
    if "lr_scheduler_type" in trainer:
        trainer["lr_scheduler_type"] = _SCHEDULER_ALIASES.get(
            trainer["lr_scheduler_type"], trainer["lr_scheduler_type"]
        )

    # Coerce only the keys that are present; absent keys keep the dataclass
    # defaults.
    for caster, names in ((float, _TRAINER_FLOAT_FIELDS), (int, _TRAINER_INT_FIELDS)):
        for name in names:
            if name in trainer:
                trainer[name] = caster(trainer[name])

    return trainer


def load_config(path: str | Path) -> ExperimentConfig:
    """Load an experiment YAML file and build an ``ExperimentConfig``.

    Args:
        path: Location of the YAML file; passed through ``resolve_repo_path``
            before being opened.

    Returns:
        The populated ``ExperimentConfig``. The ``model`` and ``trainer``
        sections are required. Every other section is optional, and a
        section that is absent or explicitly null falls back to its defaults.

    Raises:
        TypeError: If the file is empty or its top level is not a mapping,
            if the ``trainer`` section is not a mapping, or if a section
            contains a key its dataclass does not accept.
        KeyError: If the ``model`` or ``trainer`` section is missing.
        ValueError: If a numeric trainer field cannot be converted.
    """
    resolved = resolve_repo_path(path)
    with resolved.open("r", encoding="utf-8") as handle:
        raw = yaml.safe_load(handle)

    # safe_load returns None for an empty document and a list/scalar for
    # non-mapping documents; fail with a message that names the file instead
    # of an opaque "not subscriptable" error.
    if not isinstance(raw, dict):
        raise TypeError(
            f"Config file {resolved} must contain a top-level mapping, "
            f"got {type(raw).__name__}"
        )

    trainer_raw = raw["trainer"]
    if not isinstance(trainer_raw, dict):
        raise TypeError(
            f"'trainer' section in {resolved} must be a mapping, "
            f"got {type(trainer_raw).__name__}"
        )

    def section(name: str) -> Any:
        # A bare "name:" line parses to None; treat it like an absent section.
        value = raw.get(name)
        return {} if value is None else value

    return ExperimentConfig(
        model=ModelConfig(**raw["model"]),
        trainer=TrainerConfig(**_normalize_trainer_section(trainer_raw)),
        data=DataConfig(**section("data")),
        generation=GenerationConfig(**section("generation")),
        objective=ObjectiveConfig(**section("objective")),
        rewards=RewardsConfig(**section("rewards")),
        auth=AuthConfig(**section("auth")),
        storage=StorageConfig(**section("storage")),
        thinking_kl=ThinkingKLConfig(**section("thinking_kl")),
        grpo=dict(section("grpo")),
    )
|