File size: 5,388 Bytes
dbc69f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import yaml

from .utils import resolve_repo_path


@dataclass
class ModelConfig:
    """Model-loading options, built from the ``model`` section of the YAML config.

    Only ``name`` is required; every other field has a default.
    """

    name: str
    # NOTE(review): the flags below are presumably forwarded to the model loader
    # elsewhere in the project -- confirm at the call site.
    trust_remote_code: bool = False
    load_in_4bit: bool = False
    # LoRA hyperparameters; presumably only consulted when use_lora_adapters is
    # True -- confirm at the call site.
    use_lora_adapters: bool = False
    lora_r: int = 16
    lora_alpha: int = 16
    lora_dropout: float = 0.0


@dataclass
class TrainerConfig:
    """Training-loop options, built from the ``trainer`` section of the YAML config.

    ``output_dir`` is the only required field. ``load_config`` rewrites a few
    inputs before constructing this class: a legacy ``optimizer`` key becomes
    ``optim``, shorthand names such as ``adamw`` / ``cosine_decay`` are mapped
    to canonical values, and numeric fields supplied as strings are converted.
    """

    output_dir: str
    run_name: str = "grpo-run"
    # NOTE(review): -1 presumably means "no step cap; use num_train_epochs"
    # (the Hugging Face TrainingArguments convention) -- confirm.
    max_steps: int = -1
    num_train_epochs: float = 1.0
    per_device_train_batch_size: int = 1
    gradient_accumulation_steps: int = 8
    learning_rate: float = 1.0e-6
    logging_steps: int = 1
    save_steps: int = 25
    save_total_limit: int = 5
    bf16: bool = True
    seed: int = 42
    report_to: str = "wandb"
    optim: str = "adamw_torch"
    gradient_checkpointing: bool = True
    max_grad_norm: float = 1.0
    shuffle_dataset: bool = False
    lr_scheduler_type: str = "cosine"
    # default_factory gives every instance its own dict instead of a shared one.
    lr_scheduler_kwargs: dict[str, Any] = field(default_factory=dict)
    warmup_steps: int = 20
    # Presumably: how many example completions to log for sanity checking and
    # the per-example character cap -- confirm at the logging call site.
    sanity_log_examples: int = 8
    sanity_log_max_chars: int = 300
    # Presumably: interval and directory for checkpoints kept outside the
    # save_total_limit rotation -- confirm at the checkpointing call site.
    permanent_checkpoint_steps: int = 300
    permanent_checkpoint_dir: str = "checkpoints/permanent"



@dataclass
class DataConfig:
    """Dataset selection, built from the optional ``data`` section of the YAML config."""

    # Provider identifier; presumably resolved against a registry elsewhere -- confirm.
    provider: str = "gsm8k_math_curriculum"
    split: str = "train"
    # None presumably means "use every sample" -- confirm in the data loader.
    max_samples: int | None = None


@dataclass
class GenerationConfig:
    """Sampling options, built from the optional ``generation`` section of the YAML config."""

    # Presumably measured in tokens -- confirm against the tokenizer call sites.
    max_prompt_length: int = 512
    max_completion_length: int = 256
    # Presumably the number of completions sampled per prompt (the GRPO group
    # size) -- confirm.
    num_generations: int = 4
    temperature: float = 0.9
    top_p: float = 0.95


@dataclass
class ObjectiveConfig:
    """Training-objective selection, built from the optional ``objective`` section."""

    name: str = "token_grpo"
    # Extra keyword arguments for the objective; a fresh dict per instance.
    kwargs: dict[str, Any] = field(default_factory=dict)
    # NOTE(review): presumably an import path that, when set, is used instead
    # of resolving ``name`` -- confirm where the objective is constructed.
    class_path: str | None = None


@dataclass
class RewardsConfig:
    """Reward options, built from the optional ``rewards`` section of the YAML config."""

    # Maps a string key to a dict of keyword arguments. Keys are presumably
    # reward-function names -- confirm where rewards are instantiated.
    kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)


@dataclass
class AuthConfig:
    """Credentials, built from the optional ``auth`` section of the YAML config.

    Keys may be given inline (``hf_api_key`` / ``wandb_api_key``). The ``*_env``
    fields hold environment-variable names; presumably they are consulted when
    the inline key is unset -- that lookup is not in this module, confirm at the
    caller.

    The inline key fields are excluded from ``repr`` so that printing or logging
    a config object (including an enclosing ``ExperimentConfig``) does not leak
    secrets. They still take part in ``==`` comparisons.
    """

    hf_api_key: str | None = field(default=None, repr=False)
    wandb_api_key: str | None = field(default=None, repr=False)
    hf_api_key_env: str = "HF_TOKEN"
    wandb_api_key_env: str = "WANDB_API_KEY"


@dataclass
class StorageConfig:
    """Storage locations, built from the optional ``storage`` section of the YAML config."""

    # Relative by default; presumably resolved against the repo root or CWD
    # elsewhere -- confirm at the call site.
    cache_dir: str = "cache"


@dataclass
class ThinkingKLConfig:
    """Scale the KL penalty on completion tokens that overlap the *inner* redacted thinking body.

    Built from the optional ``thinking_kl`` section of the YAML config.
    """

    # Multiplier for the KL term on those tokens. At 1.0 they are presumably
    # penalized like any other token -- confirm in the objective implementation.
    inner_kl_weight: float = 1.0


@dataclass
class ExperimentConfig:
    """Top-level experiment configuration, one field per YAML section.

    ``model`` and ``trainer`` are required; every other section falls back to
    its dataclass defaults.
    """

    model: ModelConfig
    trainer: TrainerConfig
    data: DataConfig = field(default_factory=DataConfig)
    generation: GenerationConfig = field(default_factory=GenerationConfig)
    objective: ObjectiveConfig = field(default_factory=ObjectiveConfig)
    rewards: RewardsConfig = field(default_factory=RewardsConfig)
    auth: AuthConfig = field(default_factory=AuthConfig)
    storage: StorageConfig = field(default_factory=StorageConfig)
    thinking_kl: ThinkingKLConfig = field(default_factory=ThinkingKLConfig)
    # Free-form mapping; load_config passes it through as a plain dict without
    # validating its keys.
    grpo: dict[str, Any] = field(default_factory=dict)


# Shorthand values accepted in the YAML ``trainer`` section, mapped to the
# canonical names stored in TrainerConfig.
_OPTIM_ALIASES = {
    "adamw": "adamw_torch",
    "adamw_fused": "adamw_torch_fused",
}
_SCHEDULER_ALIASES = {
    "cosine_decay": "cosine",
}

# Numeric fields coerced after parsing. PyYAML implements YAML 1.1 float
# resolution, under which a scalar such as ``1e-6`` (no decimal point) is loaded
# as a string; quoted numbers are strings as well.
_TRAINER_FLOAT_FIELDS = frozenset({
    "learning_rate",
    "max_grad_norm",
    "num_train_epochs",
})
_TRAINER_INT_FIELDS = frozenset({
    "max_steps",
    "per_device_train_batch_size",
    "gradient_accumulation_steps",
    "logging_steps",
    "save_steps",
    "save_total_limit",
    "seed",
    "warmup_steps",
    "sanity_log_examples",
    "sanity_log_max_chars",
    "permanent_checkpoint_steps",
})
# Float-typed fields of the other sections, keyed by YAML section name; they are
# subject to the same string-instead-of-float issue.
_SECTION_FLOAT_FIELDS = {
    "model": frozenset({"lora_dropout"}),
    "generation": frozenset({"temperature", "top_p"}),
    "thinking_kl": frozenset({"inner_kl_weight"}),
}


def _section(raw: dict[str, Any], key: str, *, required: bool = False) -> dict[str, Any]:
    """Return a shallow copy of the mapping stored under *key* in *raw*.

    A missing optional section, or any section written as an empty YAML key
    (``data:`` parses to ``None``), yields ``{}`` so dataclass defaults apply.

    Raises:
        KeyError: *required* is true and *key* is absent.
        TypeError: the section is present but is not a mapping.
    """
    if required and key not in raw:
        raise KeyError(key)
    value = raw.get(key)
    if value is None:
        return {}
    if not isinstance(value, dict):
        raise TypeError(
            f"config section {key!r} must be a mapping, got {type(value).__name__}"
        )
    # Copy so later normalization never mutates the parsed document.
    return dict(value)


def _coerce_numeric(
    section: dict[str, Any],
    section_name: str,
    float_fields: frozenset[str],
    int_fields: frozenset[str] = frozenset(),
) -> None:
    """Convert the listed keys of *section*, where present, to float / int in place.

    ``int()`` is applied as-is, so a float value is truncated (``3.7`` -> ``3``).

    Raises:
        TypeError, ValueError: a value cannot be converted. The message names
            the offending ``section.key`` and the original error is chained.
    """
    for convert, keys in ((float, float_fields), (int, int_fields)):
        for key in section.keys() & keys:
            value = section[key]
            try:
                section[key] = convert(value)
            except (TypeError, ValueError) as err:
                # Same exception type as before, with a message that points at the field.
                raise type(err)(
                    f"cannot convert config field {section_name}.{key}={value!r} "
                    f"to {convert.__name__}"
                ) from err


def _normalize_trainer(trainer: dict[str, Any]) -> dict[str, Any]:
    """Resolve legacy/shorthand keys and coerce numeric fields of *trainer* in place.

    Returns *trainer* for convenience.
    """
    # Backward-compatible alias: "optimizer" is accepted for "optim". It is always
    # removed (TrainerConfig has no such field); an explicit "optim" takes precedence.
    if "optimizer" in trainer:
        trainer.setdefault("optim", trainer.pop("optimizer"))
    if "optim" in trainer:
        optim = trainer["optim"]
        trainer["optim"] = _OPTIM_ALIASES.get(optim, optim)
    if "lr_scheduler_type" in trainer:
        scheduler = trainer["lr_scheduler_type"]
        trainer["lr_scheduler_type"] = _SCHEDULER_ALIASES.get(scheduler, scheduler)
    _coerce_numeric(trainer, "trainer", _TRAINER_FLOAT_FIELDS, _TRAINER_INT_FIELDS)
    return trainer


def load_config(path: str | Path) -> ExperimentConfig:
    """Load an :class:`ExperimentConfig` from a YAML file.

    *path* is passed through ``resolve_repo_path`` before opening.

    The ``model`` and ``trainer`` sections are required. Every other section is
    optional, and an optional section that is absent or left empty (``data:``)
    uses the dataclass defaults.

    In the ``trainer`` section, the legacy ``optimizer`` key and shorthand
    optimizer/scheduler names are resolved. Numeric fields listed in the
    module-level field tables above are converted when YAML delivered them as
    strings.

    Raises:
        KeyError: the ``model`` or ``trainer`` section is missing.
        TypeError: the document or a section is not a mapping, a section has
            keys the target dataclass does not accept, or a numeric field has a
            non-numeric type.
        ValueError: a numeric field is a string that does not parse as a number.
    """
    resolved = resolve_repo_path(path)
    with resolved.open("r", encoding="utf-8") as handle:
        # safe_load constructs only plain YAML types (no arbitrary Python objects).
        raw = yaml.safe_load(handle)
    # An empty file loads as None; a bare scalar or list is also unusable.
    if not isinstance(raw, dict):
        raise TypeError(
            f"config file {resolved} must contain a YAML mapping at the top level, "
            f"got {type(raw).__name__}"
        )

    trainer_raw = _normalize_trainer(_section(raw, "trainer", required=True))
    model_raw = _section(raw, "model", required=True)
    generation_raw = _section(raw, "generation")
    thinking_kl_raw = _section(raw, "thinking_kl")
    _coerce_numeric(model_raw, "model", _SECTION_FLOAT_FIELDS["model"])
    _coerce_numeric(generation_raw, "generation", _SECTION_FLOAT_FIELDS["generation"])
    _coerce_numeric(thinking_kl_raw, "thinking_kl", _SECTION_FLOAT_FIELDS["thinking_kl"])

    return ExperimentConfig(
        model=ModelConfig(**model_raw),
        trainer=TrainerConfig(**trainer_raw),
        data=DataConfig(**_section(raw, "data")),
        generation=GenerationConfig(**generation_raw),
        objective=ObjectiveConfig(**_section(raw, "objective")),
        rewards=RewardsConfig(**_section(raw, "rewards")),
        auth=AuthConfig(**_section(raw, "auth")),
        storage=StorageConfig(**_section(raw, "storage")),
        thinking_kl=ThinkingKLConfig(**thinking_kl_raw),
        grpo=_section(raw, "grpo"),
    )