""" src/training/config.py — TrainingConfig dataclass Single source of truth for all training hyperparameters. Import this from every training script instead of duplicating argparse defaults. """ from __future__ import annotations from dataclasses import dataclass, field from pathlib import Path from typing import List # Generator label index mapping — must match GeneratorLabel enum in src/types.py # and the classification head in every model file. GENERATOR_CLASSES: List[str] = [ "real", # 0 "unknown_gan", # 1 "stable_diffusion", # 2 "midjourney", # 3 "dall_e", # 4 "flux", # 5 "firefly", # 6 "imagen", # 7 ] NUM_GENERATOR_CLASSES: int = len(GENERATOR_CLASSES) # 8 — never change this @dataclass class TrainingConfig: """Fingerprint engine training configuration (Phase 1).""" # ── Paths ───────────────────────────────────────────────────────────────── data_dir: Path = Path("data/processed/fingerprint") output_dir: Path = Path("models/checkpoints/fingerprint") log_dir: Path = Path("training/logs") # ── Model ───────────────────────────────────────────────────────────────── model_name: str = "vit_base_patch16_224" # timm slug pretrained: bool = True num_binary_classes: int = 2 num_generator_classes: int = NUM_GENERATOR_CLASSES # 8 # ── Training ────────────────────────────────────────────────────────────── epochs: int = 30 batch_size: int = 64 learning_rate: float = 2e-5 weight_decay: float = 0.01 warmup_steps: int = 500 grad_accumulation: int = 1 amp: bool = True generator_loss_weight: float = 0.3 # secondary objective weight # ── Optimiser ───────────────────────────────────────────────────────────── optimizer: str = "adamw" scheduler: str = "cosine" # ── Early stopping ──────────────────────────────────────────────────────── patience: int = 5 # stop if val AUC flat for N epochs # ── Reproducibility ─────────────────────────────────────────────────────── seed: int = 42 # ── Kaggle ──────────────────────────────────────────────────────────────── is_kaggle: bool = False def __post_init__(self) -> None: self.data_dir = Path(self.data_dir) self.output_dir = Path(self.output_dir) self.log_dir = Path(self.log_dir) self.output_dir.mkdir(parents=True, exist_ok=True) self.log_dir.mkdir(parents=True, exist_ok=True) def to_dict(self) -> dict: return { k: str(v) if isinstance(v, Path) else v for k, v in self.__dict__.items() } @dataclass class CoherenceConfig: """Coherence engine training configuration (Phase 2).""" data_dir: Path = Path("data/processed/coherence") output_dir: Path = Path("models/checkpoints/coherence") log_dir: Path = Path("training/logs") epochs: int = 25 batch_size: int = 16 learning_rate: float = 1e-4 weight_decay: float = 1e-4 contrastive_weight: float = 0.1 amp: bool = True patience: int = 5 seed: int = 42 is_kaggle: bool = False def __post_init__(self) -> None: self.data_dir = Path(self.data_dir) self.output_dir = Path(self.output_dir) self.log_dir = Path(self.log_dir) self.output_dir.mkdir(parents=True, exist_ok=True) self.log_dir.mkdir(parents=True, exist_ok=True) def to_dict(self) -> dict: return { k: str(v) if isinstance(v, Path) else v for k, v in self.__dict__.items() } @dataclass class SSTGNNConfig: """SSTGNN engine training configuration (Phase 3).""" data_dir: Path = Path("data/processed/sstgnn") output_dir: Path = Path("models/checkpoints/sstgnn") log_dir: Path = Path("training/logs") # Graph model in_channels: int = 5 # node feature dim: x, y, z, frame_idx/T, lm_idx/L hidden_dim: int = 64 heads: int = 4 num_gat_layers: int = 3 # Training — note: no AMP by default (NaN risk on GNNs) epochs: int = 40 batch_size: int = 8 learning_rate: float = 5e-4 weight_decay: float = 5e-4 grad_clip: float = 1.0 # always clip for GNNs amp: bool = False patience: int = 8 # GNNs need longer to plateau seed: int = 42 is_kaggle: bool = False def __post_init__(self) -> None: self.data_dir = Path(self.data_dir) self.output_dir = Path(self.output_dir) self.log_dir = Path(self.log_dir) self.output_dir.mkdir(parents=True, exist_ok=True) self.log_dir.mkdir(parents=True, exist_ok=True) def to_dict(self) -> dict: return { k: str(v) if isinstance(v, Path) else v for k, v in self.__dict__.items() }