mumble-cleanup / src /cleanup /config.py
adikuma's picture
initial upload: cleanup code and 688-pair seed dataset
fd0b01f verified
Raw
History Blame Contribute Delete
2.12 kB
# typed dataclass wrappers over the yaml configs. consumed by every entry
# point so hyperparameters never sneak into code.
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
import yaml
# ---------- training ----------
@dataclass
class LoraConfig:
    """Typed view of the ``lora`` mapping nested in the training yaml.

    All fields are required (none has a default), so a key missing from the
    yaml fails when the dataclass is constructed. The annotations are
    descriptive only; dataclasses do not validate types at runtime.
    """

    # NOTE(review): field names look like the usual LoRA adapter options
    # (rank, scaling, dropout, bias mode, adapted module names) -- confirm
    # against whatever consumes this object.
    r: int
    alpha: int
    dropout: float
    bias: str
    target_modules: list
@dataclass
class TrainConfig:
    """Typed view of the top-level keys of the training yaml.

    Every field is required, so the yaml has to spell out each value. The
    annotations are descriptive only; nothing is type-checked at runtime.
    The group headings below follow the field names; what each value means
    is decided by the code that consumes it.
    """

    # model and data
    base_model: str
    language: str
    data_dir: str
    max_seq_length: int
    # None is accepted for the row counts -- presumably "no cap"; confirm
    # with the data-loading code.
    train_rows: Optional[int]
    val_rows: Optional[int]

    # adapter settings, held as their own dataclass
    lora: LoraConfig

    # optimizer
    learning_rate: float
    weight_decay: float
    warmup_ratio: float
    max_grad_norm: float
    adam_beta1: float
    adam_beta2: float
    adam_epsilon: float

    # schedule and batching
    num_epochs: int
    lr_scheduler_type: str
    train_batch_size: int
    eval_batch_size: int
    gradient_accumulation_steps: int

    # numeric precision switches
    bf16: bool
    fp16: bool
    tf32: bool

    # reproducibility
    seed: int

    # checkpoint selection and retention
    metric_for_best_model: str
    greater_is_better: bool
    save_total_limit: int

    # step intervals
    logging_steps: int
    eval_steps: int
    save_steps: int

    # data loading
    dataloader_num_workers: int
def load_train_config(path) -> TrainConfig:
    """Load the training yaml at *path* into a ``TrainConfig``.

    The nested ``lora`` mapping is lifted into a ``LoraConfig``. Every other
    top-level key is forwarded to ``TrainConfig`` as a keyword argument, so
    missing or unexpected keys surface as ``TypeError`` from the dataclass
    constructor.

    Args:
        path: anything ``pathlib.Path`` accepts.

    Returns:
        The populated ``TrainConfig``.

    Raises:
        ValueError: the file is empty or its top level is not a mapping.
        KeyError: the yaml has no ``lora`` section.
    """
    # decode explicitly as utf-8 so parsing does not depend on the locale
    raw = yaml.safe_load(Path(path).read_text(encoding="utf-8"))

    # safe_load gives None for an empty document and a list/scalar for other
    # shapes; fail with the file name instead of an opaque error on .pop()
    if not isinstance(raw, dict):
        raise ValueError(
            f"{path}: expected a yaml mapping at the top level, "
            f"got {type(raw).__name__}"
        )
    if "lora" not in raw:
        raise KeyError(f"{path}: missing required 'lora' section")

    lora = LoraConfig(**raw.pop("lora"))
    return TrainConfig(lora=lora, **raw)
# ---------- data ----------
@dataclass
class DataSplits:
    """Typed view of the ``splits`` mapping in the data yaml.

    All three values are required. Nothing here validates them.
    """

    # NOTE(review): presumably fractions of the dataset that sum to 1 --
    # confirm against the code that performs the split.
    train: float
    val: float
    test: float
@dataclass
class DataConfig:
    """Typed view of the top-level keys of the data yaml.

    All fields are required. ``splits`` is held as its own ``DataSplits``
    dataclass rather than as a raw dict.
    """

    # location of the seed data, kept as the plain string from the yaml
    seed_path: str
    splits: DataSplits
    # NOTE(review): presumably seeds the shuffle/split -- confirm at use site
    random_seed: int
def load_data_config(path) -> DataConfig:
    """Load the data yaml at *path* into a ``DataConfig``.

    The nested ``splits`` mapping is lifted into a ``DataSplits``. Every
    other top-level key is forwarded to ``DataConfig`` as a keyword argument,
    so missing or unexpected keys surface as ``TypeError`` from the dataclass
    constructor.

    Args:
        path: anything ``pathlib.Path`` accepts.

    Returns:
        The populated ``DataConfig``.

    Raises:
        ValueError: the file is empty or its top level is not a mapping.
        KeyError: the yaml has no ``splits`` section.
    """
    # decode explicitly as utf-8 so parsing does not depend on the locale
    raw = yaml.safe_load(Path(path).read_text(encoding="utf-8"))

    # safe_load gives None for an empty document and a list/scalar for other
    # shapes; fail with the file name instead of an opaque error on .pop()
    if not isinstance(raw, dict):
        raise ValueError(
            f"{path}: expected a yaml mapping at the top level, "
            f"got {type(raw).__name__}"
        )
    if "splits" not in raw:
        raise KeyError(f"{path}: missing required 'splits' section")

    splits = DataSplits(**raw.pop("splits"))
    return DataConfig(splits=splits, **raw)
# ---------- optional inject (v1.1) ----------
@dataclass
class InjectSampling:
    """Typed view of the ``sampling`` mapping in the inject yaml.

    Both bounds are required. Nothing here checks that min <= max.
    """

    # NOTE(review): presumably the range for how many ops are applied to one
    # example; whether the upper bound is inclusive is decided by the sampler.
    ops_per_example_min: int
    ops_per_example_max: int
@dataclass
class InjectConfig:
    """Typed view of the top-level keys of the inject yaml.

    ``sampling`` is required. ``ops`` is optional and left as a plain dict.
    """

    sampling: InjectSampling
    # default_factory gives each instance its own empty dict rather than one
    # shared mutable default
    ops: dict = field(default_factory=dict)
def load_inject_config(path) -> InjectConfig:
    """Load the inject yaml at *path* into an ``InjectConfig``.

    The nested ``sampling`` mapping is lifted into an ``InjectSampling``.
    Every other top-level key is forwarded to ``InjectConfig`` as a keyword
    argument; ``ops`` falls back to an empty dict when the yaml omits it, and
    unexpected keys surface as ``TypeError`` from the dataclass constructor.

    Args:
        path: anything ``pathlib.Path`` accepts.

    Returns:
        The populated ``InjectConfig``.

    Raises:
        ValueError: the file is empty or its top level is not a mapping.
        KeyError: the yaml has no ``sampling`` section.
    """
    # decode explicitly as utf-8 so parsing does not depend on the locale
    raw = yaml.safe_load(Path(path).read_text(encoding="utf-8"))

    # safe_load gives None for an empty document and a list/scalar for other
    # shapes; fail with the file name instead of an opaque error on .pop()
    if not isinstance(raw, dict):
        raise ValueError(
            f"{path}: expected a yaml mapping at the top level, "
            f"got {type(raw).__name__}"
        )
    if "sampling" not in raw:
        raise KeyError(f"{path}: missing required 'sampling' section")

    sampling = InjectSampling(**raw.pop("sampling"))
    return InjectConfig(sampling=sampling, **raw)