sv-task / src /schemas /config.py
lamossta's picture
config dataclass for models
4ce549d
import json
import warnings
from dataclasses import dataclass, fields
from pathlib import Path

from transformers import TrainingArguments
VALID_MODES = {"marker", "qa_m", "qa_b"}
@dataclass
class TrainingConfig:
mode: str = "marker"
data_path: str = "data/data_augmented_256.jsonl"
output_dir: str = "models/"
model_name: str = "distilbert-base-uncased"
max_len: int = 256
num_train_epochs: int = 5
per_device_train_batch_size: int = 32
per_device_eval_batch_size: int = 64
gradient_accumulation_steps: int = 1
learning_rate: float = 2e-5
warmup_ratio: float = 0.1
weight_decay: float = 0.01
val_split: float = 0.1
test_split: float = 0.1
early_stopping_patience: int = 3
fp16: bool = True
seed: int = 42
logging_steps: int = 50
save_total_limit: int = 2
loss_fn: str = "cross_entropy"
focal_gamma: float = 2.0
def __post_init__(self):
if self.mode not in VALID_MODES:
raise ValueError(f"mode must be one of {VALID_MODES}, got '{self.mode}'")
if self.loss_fn not in ("cross_entropy", "focal"):
raise ValueError(f"loss_fn must be 'cross_entropy' or 'focal', got '{self.loss_fn}'")
@classmethod
def from_json(cls, path: str | Path) -> "TrainingConfig":
with open(path, "r", encoding="utf-8") as f:
raw = json.load(f)
known = {f.name for f in fields(cls)}
unknown = set(raw.keys()) - known
if unknown:
print(f"Warning: unknown config keys ignored: {unknown}")
filtered = {k: v for k, v in raw.items() if k in known}
return cls(**filtered)
def to_training_arguments(self) -> TrainingArguments:
return TrainingArguments(
output_dir=self.output_dir,
num_train_epochs=self.num_train_epochs,
per_device_train_batch_size=self.per_device_train_batch_size,
per_device_eval_batch_size=self.per_device_eval_batch_size,
gradient_accumulation_steps=self.gradient_accumulation_steps,
learning_rate=self.learning_rate,
warmup_ratio=self.warmup_ratio,
weight_decay=self.weight_decay,
eval_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
metric_for_best_model="macro_f1",
greater_is_better=True,
save_total_limit=self.save_total_limit,
logging_dir=f"{self.output_dir}/logs",
logging_steps=self.logging_steps,
report_to="none",
fp16=self.fp16,
seed=self.seed,
data_seed=self.seed,
)