import json
from dataclasses import dataclass, fields
from pathlib import Path

from transformers import TrainingArguments

# Accepted values for ``TrainingConfig.mode``.
VALID_MODES = {"marker", "qa_m", "qa_b"}


@dataclass
class TrainingConfig:
    """Hyperparameters and paths for a training run.

    The configuration can be built directly, or loaded from a JSON file with
    :meth:`from_json`. :meth:`to_training_arguments` converts the subset of
    fields that Hugging Face ``TrainingArguments`` understands.

    Raises:
        ValueError: On construction, if ``mode`` is not in ``VALID_MODES`` or
            ``loss_fn`` is neither ``"cross_entropy"`` nor ``"focal"``.
    """

    # --- Not forwarded to TrainingArguments by this module ------------------
    # NOTE(review): these are presumably consumed by the data-loading / model
    # code elsewhere in the project; nothing in this file reads them except
    # the validation of ``mode`` and ``loss_fn``.
    mode: str = "marker"  # must be one of VALID_MODES
    data_path: str = "data/data_augmented_256.jsonl"
    output_dir: str = "models/"  # also forwarded; parent of the logs directory
    model_name: str = "distilbert-base-uncased"
    max_len: int = 256

    # --- Forwarded verbatim to TrainingArguments ----------------------------
    num_train_epochs: int = 5
    per_device_train_batch_size: int = 32
    per_device_eval_batch_size: int = 64
    gradient_accumulation_steps: int = 1
    learning_rate: float = 2e-5
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01

    # --- Not forwarded to TrainingArguments by this module ------------------
    val_split: float = 0.1
    test_split: float = 0.1
    early_stopping_patience: int = 3

    # --- Forwarded verbatim to TrainingArguments ----------------------------
    fp16: bool = True
    seed: int = 42  # used for both ``seed`` and ``data_seed``
    logging_steps: int = 50
    save_total_limit: int = 2

    # --- Not forwarded to TrainingArguments by this module ------------------
    loss_fn: str = "cross_entropy"  # "cross_entropy" or "focal"
    focal_gamma: float = 2.0

    def __post_init__(self):
        """Validate the enumerated string fields right after construction."""
        if self.mode not in VALID_MODES:
            raise ValueError(f"mode must be one of {VALID_MODES}, got '{self.mode}'")
        if self.loss_fn not in ("cross_entropy", "focal"):
            raise ValueError(f"loss_fn must be 'cross_entropy' or 'focal', got '{self.loss_fn}'")

    @classmethod
    def from_json(cls, path: str | Path) -> "TrainingConfig":
        """Build a config from a JSON file containing a single object.

        Keys that match a dataclass field override that field's default; keys
        that match no field are reported on stdout and otherwise ignored.

        Args:
            path: Location of the UTF-8 encoded JSON file.

        Returns:
            A validated ``TrainingConfig``.

        Raises:
            ValueError: If the top-level JSON value is not an object, or if
                the resulting config fails ``__post_init__`` validation.
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        with open(path, "r", encoding="utf-8") as f:
            raw = json.load(f)

        # A list/string/number at the top level would otherwise fail later
        # with an opaque AttributeError on ``.keys()``.
        if not isinstance(raw, dict):
            raise ValueError(
                f"config file must contain a JSON object, got {type(raw).__name__}"
            )

        known = {field.name for field in fields(cls)}
        unknown = set(raw) - known
        if unknown:
            print(f"Warning: unknown config keys ignored: {unknown}")

        filtered = {k: v for k, v in raw.items() if k in known}
        return cls(**filtered)

    def to_training_arguments(self) -> TrainingArguments:
        """Translate this config into Hugging Face ``TrainingArguments``.

        Besides the forwarded fields, a few options are fixed here: evaluation
        and checkpointing both run once per epoch, the best checkpoint
        (highest ``macro_f1``) is reloaded at the end of training, and
        experiment-tracker reporting is turned off.

        Returns:
            A ``TrainingArguments`` instance whose logs go to
            ``<output_dir>/logs``.
        """
        # Path joining avoids "models//logs" for an output_dir with a trailing
        # slash, and avoids the absolute "/logs" for an empty output_dir.
        logging_dir = str(Path(self.output_dir) / "logs")

        return TrainingArguments(
            output_dir=self.output_dir,
            num_train_epochs=self.num_train_epochs,
            per_device_train_batch_size=self.per_device_train_batch_size,
            per_device_eval_batch_size=self.per_device_eval_batch_size,
            gradient_accumulation_steps=self.gradient_accumulation_steps,
            learning_rate=self.learning_rate,
            warmup_ratio=self.warmup_ratio,
            weight_decay=self.weight_decay,
            eval_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            # NOTE(review): the metrics function handed to the Trainer must
            # report a "macro_f1" value for best-model selection to work.
            metric_for_best_model="macro_f1",
            greater_is_better=True,
            save_total_limit=self.save_total_limit,
            logging_dir=logging_dir,
            logging_steps=self.logging_steps,
            report_to="none",
            fp16=self.fp16,
            seed=self.seed,
            data_seed=self.seed,
        )