import json
import warnings
from dataclasses import dataclass, fields
from pathlib import Path

from transformers import TrainingArguments
|
|
# Accepted values for TrainingConfig.mode; checked in TrainingConfig.__post_init__.
VALID_MODES = {
    "marker",
    "qa_m",
    "qa_b",
}
|
|
|
|
@dataclass
class TrainingConfig:
    """Hyperparameters and paths for a fine-tuning run.

    Field values are validated on construction (see ``__post_init__``).
    A config can be loaded from a JSON file with :meth:`from_json`.
    It can be converted into Hugging Face ``TrainingArguments`` with
    :meth:`to_training_arguments`.
    """

    # Task formulation; must be a member of the module-level VALID_MODES.
    mode: str = "marker"
    # Dataset file; not read in this class (presumably consumed by the data loader).
    data_path: str = "data/data_augmented_256.jsonl"
    # Checkpoint directory for TrainingArguments; logs are written to <output_dir>/logs.
    output_dir: str = "models/"
    # Model identifier; not used in this class (presumably a Hugging Face hub id).
    model_name: str = "distilbert-base-uncased"
    # Presumably the tokenizer's max sequence length; not used in this class.
    max_len: int = 256
    num_train_epochs: int = 5
    per_device_train_batch_size: int = 32
    per_device_eval_batch_size: int = 64
    gradient_accumulation_steps: int = 1
    learning_rate: float = 2e-5
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01
    # Held-out fractions; the actual split happens outside this class.
    val_split: float = 0.1
    test_split: float = 0.1
    # Not passed to TrainingArguments here; presumably used by an early-stopping callback.
    early_stopping_patience: int = 3
    # NOTE(review): TrainingArguments may reject fp16=True without a supported GPU.
    fp16: bool = True
    # Used for both `seed` and `data_seed` in TrainingArguments.
    seed: int = 42
    logging_steps: int = 50
    save_total_limit: int = 2
    # Either "cross_entropy" or "focal" (checked in __post_init__).
    loss_fn: str = "cross_entropy"
    # Presumably only relevant when loss_fn == "focal"; not used in this class.
    focal_gamma: float = 2.0

    def __post_init__(self) -> None:
        """Validate field values.

        Raises:
            ValueError: On the first field whose value is out of range.
        """
        if self.mode not in VALID_MODES:
            # sorted() keeps the message stable; set iteration order is not.
            raise ValueError(
                f"mode must be one of {sorted(VALID_MODES)}, got '{self.mode}'"
            )
        if self.loss_fn not in ("cross_entropy", "focal"):
            raise ValueError(
                f"loss_fn must be 'cross_entropy' or 'focal', got '{self.loss_fn}'"
            )

        for name in ("val_split", "test_split"):
            value = getattr(self, name)
            # Written as a negated range check so NaN is rejected as well.
            if not 0.0 <= value < 1.0:
                raise ValueError(f"{name} must be in [0, 1), got {value!r}")
        # Presumably the training set is the remainder; a sum >= 1 leaves nothing.
        if self.val_split + self.test_split >= 1.0:
            raise ValueError(
                "val_split + test_split must be < 1, "
                f"got {self.val_split!r} + {self.test_split!r}"
            )

        for name in (
            "max_len",
            "per_device_train_batch_size",
            "per_device_eval_batch_size",
            "gradient_accumulation_steps",
        ):
            value = getattr(self, name)
            if value < 1:
                raise ValueError(f"{name} must be >= 1, got {value!r}")

    @classmethod
    def from_json(cls, path: str | Path) -> "TrainingConfig":
        """Build a config from a JSON file whose top-level value is an object.

        Keys that do not name a dataclass field are dropped and reported with
        a ``UserWarning``. Fields absent from the file keep their defaults.

        Args:
            path: Path to a UTF-8 encoded JSON file.

        Returns:
            A validated ``TrainingConfig``.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            TypeError: If the top-level JSON value is not an object.
            ValueError: If a field value fails validation in ``__post_init__``.
        """
        with open(path, "r", encoding="utf-8") as fh:
            raw = json.load(fh)
        # A JSON array or scalar would otherwise fail later with an opaque AttributeError.
        if not isinstance(raw, dict):
            raise TypeError(
                f"config file {path} must contain a JSON object, "
                f"got {type(raw).__name__}"
            )

        known = {fld.name for fld in fields(cls)}
        unknown = raw.keys() - known
        if unknown:
            # stacklevel=2 attributes the warning to the caller of from_json.
            warnings.warn(
                f"unknown config keys ignored: {sorted(unknown)}",
                stacklevel=2,
            )
        filtered = {key: value for key, value in raw.items() if key in known}
        return cls(**filtered)

    def to_training_arguments(self) -> TrainingArguments:
        """Translate this config into Hugging Face ``TrainingArguments``.

        Behaviour of the returned arguments:
        - Evaluation and checkpointing both run once per epoch.
        - The checkpoint with the highest ``macro_f1`` is reloaded at the end.
          That metric is presumably produced by the caller's
          ``compute_metrics``; verify.
        - Logs go to ``<output_dir>/logs``.
        - Reporting integrations are disabled (``report_to="none"``).
        - ``seed`` is used for both model and data seeding.
        """
        return TrainingArguments(
            output_dir=self.output_dir,
            num_train_epochs=self.num_train_epochs,
            per_device_train_batch_size=self.per_device_train_batch_size,
            per_device_eval_batch_size=self.per_device_eval_batch_size,
            gradient_accumulation_steps=self.gradient_accumulation_steps,
            learning_rate=self.learning_rate,
            warmup_ratio=self.warmup_ratio,
            weight_decay=self.weight_decay,
            eval_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            metric_for_best_model="macro_f1",
            greater_is_better=True,
            save_total_limit=self.save_total_limit,
            # Path joining avoids "models//logs" when output_dir has a trailing slash.
            logging_dir=str(Path(self.output_dir) / "logs"),
            logging_steps=self.logging_steps,
            report_to="none",
            fp16=self.fp16,
            seed=self.seed,
            data_seed=self.seed,
        )
|
|