neuralese_temp / src /hackable /config.py
psidharth567's picture
Export neuralese codebase (cache and .env excluded).
dbc69f3
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import yaml
from .utils import resolve_repo_path
@dataclass
class ModelConfig:
    """Values of the ``model`` section of an experiment config.

    ``load_config`` builds this from the ``model`` mapping of the YAML file.
    Only ``name`` is mandatory; every other field falls back to the default
    declared here when the key is absent.
    """

    # Required: no default, so it must be supplied by the config.
    name: str

    # Loading flags, all disabled unless the config turns them on.
    trust_remote_code: bool = False
    load_in_4bit: bool = False

    # NOTE(review): the lora_* values presumably feed a LoRA adapter setup
    # elsewhere in the project -- confirm against the consumer.
    use_lora_adapters: bool = False
    lora_r: int = 16
    lora_alpha: int = 16
    lora_dropout: float = 0.0
@dataclass
class TrainerConfig:
    """Values of the ``trainer`` section of an experiment config.

    ``load_config`` builds this from the (normalised) ``trainer`` mapping of
    the YAML file. ``output_dir`` is the only mandatory field; all others use
    the defaults declared below when the key is missing.
    """

    # Required: no default, so it must be supplied by the config.
    output_dir: str
    run_name: str = "grpo-run"

    # Schedule length and batching.
    # NOTE(review): -1 presumably means "no explicit step cap" -- confirm
    # against the trainer that consumes this value.
    max_steps: int = -1
    num_train_epochs: float = 1.0
    per_device_train_batch_size: int = 1
    gradient_accumulation_steps: int = 8
    learning_rate: float = 1e-6

    # Logging and checkpoint cadence.
    logging_steps: int = 1
    save_steps: int = 25
    save_total_limit: int = 5

    # Precision, reproducibility and reporting.
    bf16: bool = True
    seed: int = 42
    report_to: str = "wandb"

    # Optimiser and scheduler. ``load_config`` also accepts a few shorthand
    # spellings for ``optim`` and ``lr_scheduler_type`` and rewrites them
    # before this class is instantiated.
    optim: str = "adamw_torch"
    gradient_checkpointing: bool = True
    max_grad_norm: float = 1.0
    shuffle_dataset: bool = False
    lr_scheduler_type: str = "cosine"
    # A factory is used so instances never share one mutable dict.
    lr_scheduler_kwargs: dict[str, Any] = field(default_factory=dict)
    warmup_steps: int = 20

    # Sanity-logging limits.
    sanity_log_examples: int = 8
    sanity_log_max_chars: int = 300

    # "Permanent" checkpoint settings.
    # NOTE(review): presumably checkpoints kept outside save_total_limit
    # rotation -- confirm against the training loop.
    permanent_checkpoint_steps: int = 300
    permanent_checkpoint_dir: str = "checkpoints/permanent"
@dataclass
class DataConfig:
    """Values of the optional ``data`` section of an experiment config.

    Every field has a default, so the section may be left out of the YAML
    file entirely.
    """

    # Identifier of the dataset provider; resolved elsewhere in the project.
    provider: str = "gsm8k_math_curriculum"
    split: str = "train"
    # NOTE(review): None presumably means "use every sample" -- confirm
    # against the data-loading code.
    max_samples: int | None = None
@dataclass
class GenerationConfig:
    """Values of the optional ``generation`` section of an experiment config.

    All fields carry defaults, so the section may be omitted from the YAML.
    """

    # Length limits.
    # NOTE(review): units are presumably tokens -- confirm with the consumer.
    max_prompt_length: int = 512
    max_completion_length: int = 256

    # Number of generations (presumably per prompt -- confirm).
    num_generations: int = 4

    # Sampling parameters.
    temperature: float = 0.9
    top_p: float = 0.95
@dataclass
class ObjectiveConfig:
    """Values of the optional ``objective`` section of an experiment config.

    All fields carry defaults, so the section may be omitted from the YAML.
    """

    # Name of the training objective; looked up elsewhere in the project.
    name: str = "token_grpo"
    # Free-form keyword arguments; a factory keeps instances from sharing
    # a single mutable dict.
    kwargs: dict[str, Any] = field(default_factory=dict)
    # NOTE(review): presumably a dotted import path that overrides ``name``
    # when set -- confirm against the objective factory.
    class_path: str | None = None
@dataclass
class RewardsConfig:
    """Values of the optional ``rewards`` section of an experiment config."""

    # Mapping from a string key to a dict of keyword arguments.
    # NOTE(review): the outer key is presumably a reward-function name --
    # confirm against the reward registry.
    # A factory keeps instances from sharing a single mutable dict.
    kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
@dataclass
class AuthConfig:
    """Values of the optional ``auth`` section of an experiment config.

    Holds credentials either inline (``*_api_key``) or by naming the
    environment variable to read them from (``*_api_key_env``).

    The inline key fields are excluded from the generated ``repr`` so that
    printing or logging this object (or an ``ExperimentConfig`` containing
    it) does not write secrets to logs. They still take part in ``__init__``
    and equality comparison exactly as before.
    """

    # Inline secrets: kept out of repr() on purpose.
    hf_api_key: str | None = field(default=None, repr=False)
    wandb_api_key: str | None = field(default=None, repr=False)

    # Names of the environment variables associated with each key.
    # NOTE(review): presumably used as a fallback when the inline key is
    # None -- confirm against the code that performs the login.
    hf_api_key_env: str = "HF_TOKEN"
    wandb_api_key_env: str = "WANDB_API_KEY"
@dataclass
class StorageConfig:
    """Values of the optional ``storage`` section of an experiment config."""

    # Cache directory, stored as a plain string.
    # NOTE(review): presumably resolved relative to the repository root by
    # the consumer -- confirm.
    cache_dir: str = "cache"
@dataclass
class ThinkingKLConfig:
    """Weighting of the KL penalty inside the redacted thinking body.

    The weight scales the KL penalty on completion tokens that overlap the
    *inner* redacted thinking body.
    """

    # Multiplier for the KL penalty on those tokens; defaults to 1.0.
    inner_kl_weight: float = 1.0
@dataclass
class ExperimentConfig:
    """Top-level experiment configuration: one attribute per config section.

    ``load_config`` builds an instance from a YAML file. ``model`` and
    ``trainer`` are mandatory; every other section is created from its own
    defaults when it is not supplied.
    """

    # Required sections.
    model: ModelConfig
    trainer: TrainerConfig

    # Optional typed sections. Factories give each instance its own
    # freshly constructed sub-config.
    data: DataConfig = field(default_factory=DataConfig)
    generation: GenerationConfig = field(default_factory=GenerationConfig)
    objective: ObjectiveConfig = field(default_factory=ObjectiveConfig)
    rewards: RewardsConfig = field(default_factory=RewardsConfig)
    auth: AuthConfig = field(default_factory=AuthConfig)
    storage: StorageConfig = field(default_factory=StorageConfig)
    thinking_kl: ThinkingKLConfig = field(default_factory=ThinkingKLConfig)

    # Untyped pass-through mapping taken from the ``grpo`` section.
    grpo: dict[str, Any] = field(default_factory=dict)
# Shorthand spellings accepted in YAML, mapped to the value stored on TrainerConfig.
_OPTIM_ALIASES = {
    "adamw": "adamw_torch",
    "adamw_fused": "adamw_torch_fused",
}
_SCHEDULER_ALIASES = {
    "cosine_decay": "cosine",
}

# Trainer keys coerced to a numeric type, because YAML may deliver them as
# strings (e.g. PyYAML reads ``1e-6`` -- an exponent without a decimal
# point -- as a string).
_TRAINER_FLOAT_FIELDS = (
    "learning_rate",
    "max_grad_norm",
    "num_train_epochs",
)
_TRAINER_INT_FIELDS = (
    "max_steps",
    "per_device_train_batch_size",
    "gradient_accumulation_steps",
    "logging_steps",
    "save_steps",
    "save_total_limit",
    "seed",
    "warmup_steps",
    "sanity_log_examples",
    "sanity_log_max_chars",
    "permanent_checkpoint_steps",
)


def _coerce_int(value: Any) -> int:
    """Convert *value* to ``int``, also accepting integral float strings like ``"1e3"``."""
    try:
        return int(value)
    except ValueError:
        # Strings such as "1e3" or "300.0" are rejected by int(); accept
        # them only when they denote a whole number, otherwise re-raise
        # the original error.
        as_float = float(value)
        if not as_float.is_integer():
            raise
        return int(as_float)


def _section(raw: dict[str, Any], name: str, *, required: bool = False) -> Any:
    """Return ``raw[name]``, treating a null (or absent optional) section as ``{}``."""
    section = raw[name] if required else raw.get(name)
    # A YAML key with no value (e.g. ``data:`` followed only by comments)
    # parses to None; treat it like an empty mapping.
    return {} if section is None else section


def _normalize_trainer(section: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of the trainer mapping with aliases resolved and numbers coerced."""
    # Work on a copy so the parsed YAML mapping is never mutated.
    trainer = dict(section)

    # Backward-compatible alias: allow "optimizer" in YAML. An explicit
    # "optim" wins; the legacy key is always dropped.
    if "optimizer" in trainer:
        legacy_optim = trainer.pop("optimizer")
        trainer.setdefault("optim", legacy_optim)

    # Accept common shorthand names.
    if "optim" in trainer:
        trainer["optim"] = _OPTIM_ALIASES.get(trainer["optim"], trainer["optim"])
    if "lr_scheduler_type" in trainer:
        trainer["lr_scheduler_type"] = _SCHEDULER_ALIASES.get(
            trainer["lr_scheduler_type"], trainer["lr_scheduler_type"]
        )

    # Normalize numeric fields that may come from YAML as strings.
    for key in _TRAINER_FLOAT_FIELDS:
        if key in trainer:
            trainer[key] = float(trainer[key])
    for key in _TRAINER_INT_FIELDS:
        if key in trainer:
            trainer[key] = _coerce_int(trainer[key])
    return trainer


def load_config(path: str | Path) -> ExperimentConfig:
    """Load an experiment YAML file and build an ``ExperimentConfig`` from it.

    The path is first passed through ``resolve_repo_path`` and the resulting
    file is parsed with ``yaml.safe_load``. The ``model`` and ``trainer``
    sections are required; ``data``, ``generation``, ``objective``,
    ``rewards``, ``auth``, ``storage``, ``thinking_kl`` and ``grpo`` are
    optional and default to empty mappings when absent or null.

    The trainer section is normalised before use:

    * ``optimizer`` is accepted as a legacy alias for ``optim``;
    * shorthand optimiser / scheduler names are expanded;
    * known numeric fields are coerced to ``float`` / ``int``.

    Args:
        path: Location of the YAML config, as accepted by
            ``resolve_repo_path``.

    Returns:
        The populated ``ExperimentConfig``.

    Raises:
        TypeError: If the file is empty or its top level is not a mapping,
            or if a section contains keys its config class does not accept.
        KeyError: If the ``model`` or ``trainer`` section is missing.
        ValueError: If a numeric trainer field cannot be converted.
    """
    resolved = resolve_repo_path(path)
    with resolved.open("r", encoding="utf-8") as handle:
        raw = yaml.safe_load(handle)

    # safe_load returns None for an empty document; fail with a clear message
    # instead of an opaque "not subscriptable" error further down.
    if not isinstance(raw, dict):
        raise TypeError(
            f"Config file {resolved} must contain a top-level mapping, "
            f"got {type(raw).__name__}"
        )

    trainer_raw = _normalize_trainer(_section(raw, "trainer", required=True))

    return ExperimentConfig(
        model=ModelConfig(**_section(raw, "model", required=True)),
        trainer=TrainerConfig(**trainer_raw),
        data=DataConfig(**_section(raw, "data")),
        generation=GenerationConfig(**_section(raw, "generation")),
        objective=ObjectiveConfig(**_section(raw, "objective")),
        rewards=RewardsConfig(**_section(raw, "rewards")),
        auth=AuthConfig(**_section(raw, "auth")),
        storage=StorageConfig(**_section(raw, "storage")),
        thinking_kl=ThinkingKLConfig(**_section(raw, "thinking_kl")),
        grpo=dict(_section(raw, "grpo")),
    )