| """ | |
| src/training/config.py β TrainingConfig dataclass | |
| Single source of truth for all training hyperparameters. Import this from | |
| every training script instead of duplicating argparse defaults. | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import List | |
| # Generator label index mapping β must match GeneratorLabel enum in src/types.py | |
| # and the classification head in every model file. | |
| GENERATOR_CLASSES: List[str] = [ | |
| "real", # 0 | |
| "unknown_gan", # 1 | |
| "stable_diffusion", # 2 | |
| "midjourney", # 3 | |
| "dall_e", # 4 | |
| "flux", # 5 | |
| "firefly", # 6 | |
| "imagen", # 7 | |
| ] | |
| NUM_GENERATOR_CLASSES: int = len(GENERATOR_CLASSES) # 8 β never change this | |
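# Reverse lookup (illustrative addition, not part of the original mapping):
# data loaders that carry the generator name as a string can recover the class
# index expected by the generator head without re-deriving it locally.
GENERATOR_CLASS_TO_INDEX = {name: idx for idx, name in enumerate(GENERATOR_CLASSES)}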
@dataclass
class TrainingConfig:
    """Fingerprint engine training configuration (Phase 1)."""

    # ── Paths ────────────────────────────────────────────────────────────────
    data_dir: Path = Path("data/processed/fingerprint")
    output_dir: Path = Path("models/checkpoints/fingerprint")
    log_dir: Path = Path("training/logs")

    # ── Model ────────────────────────────────────────────────────────────────
    model_name: str = "vit_base_patch16_224"  # timm slug
    pretrained: bool = True
    num_binary_classes: int = 2
    num_generator_classes: int = NUM_GENERATOR_CLASSES  # 8

    # ── Training ─────────────────────────────────────────────────────────────
    epochs: int = 30
    batch_size: int = 64
    learning_rate: float = 2e-5
    weight_decay: float = 0.01
    warmup_steps: int = 500
    grad_accumulation: int = 1
    amp: bool = True
    generator_loss_weight: float = 0.3  # secondary objective weight

    # ── Optimiser ────────────────────────────────────────────────────────────
    optimizer: str = "adamw"
    scheduler: str = "cosine"

    # ── Early stopping ───────────────────────────────────────────────────────
    patience: int = 5  # stop if val AUC is flat for N epochs

    # ── Reproducibility ──────────────────────────────────────────────────────
    seed: int = 42

    # ── Kaggle ───────────────────────────────────────────────────────────────
    is_kaggle: bool = False

    def __post_init__(self) -> None:
        self.data_dir = Path(self.data_dir)
        self.output_dir = Path(self.output_dir)
        self.log_dir = Path(self.log_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.log_dir.mkdir(parents=True, exist_ok=True)

    def to_dict(self) -> dict:
        return {
            k: str(v) if isinstance(v, Path) else v
            for k, v in self.__dict__.items()
        }
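# Example consumer (illustrative sketch; the overrides shown are arbitrary,
# not recommended values):
#
#     from src.training.config import TrainingConfig
#
#     cfg = TrainingConfig(batch_size=32, is_kaggle=True)
#     print(cfg.to_dict())  # Path fields are stringified, ready for logging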
@dataclass
class CoherenceConfig:
    """Coherence engine training configuration (Phase 2)."""

    data_dir: Path = Path("data/processed/coherence")
    output_dir: Path = Path("models/checkpoints/coherence")
    log_dir: Path = Path("training/logs")

    epochs: int = 25
    batch_size: int = 16
    learning_rate: float = 1e-4
    weight_decay: float = 1e-4
    contrastive_weight: float = 0.1
    amp: bool = True
    patience: int = 5
    seed: int = 42
    is_kaggle: bool = False

    def __post_init__(self) -> None:
        self.data_dir = Path(self.data_dir)
        self.output_dir = Path(self.output_dir)
        self.log_dir = Path(self.log_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.log_dir.mkdir(parents=True, exist_ok=True)

    def to_dict(self) -> dict:
        return {
            k: str(v) if isinstance(v, Path) else v
            for k, v in self.__dict__.items()
        }
@dataclass
class SSTGNNConfig:
    """SSTGNN engine training configuration (Phase 3)."""

    data_dir: Path = Path("data/processed/sstgnn")
    output_dir: Path = Path("models/checkpoints/sstgnn")
    log_dir: Path = Path("training/logs")

    # Graph model
    in_channels: int = 5  # node feature dim: x, y, z, frame_idx/T, lm_idx/L
    hidden_dim: int = 64
    heads: int = 4
    num_gat_layers: int = 3

    # Training (note: no AMP by default - NaN risk on GNNs)
    epochs: int = 40
    batch_size: int = 8
    learning_rate: float = 5e-4
    weight_decay: float = 5e-4
    grad_clip: float = 1.0  # always clip for GNNs
    amp: bool = False
    patience: int = 8  # GNNs need longer to plateau
    seed: int = 42
    is_kaggle: bool = False

    def __post_init__(self) -> None:
        self.data_dir = Path(self.data_dir)
        self.output_dir = Path(self.output_dir)
        self.log_dir = Path(self.log_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.log_dir.mkdir(parents=True, exist_ok=True)

    def to_dict(self) -> dict:
        return {
            k: str(v) if isinstance(v, Path) else v
            for k, v in self.__dict__.items()
        }
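

if __name__ == "__main__":
    # Smoke test (illustrative addition, not part of the original module).
    # Instantiating each config also exercises __post_init__, which creates
    # the output/log directories as a side effect.
    import json

    for cfg_cls in (TrainingConfig, CoherenceConfig, SSTGNNConfig):
        print(cfg_cls.__name__)
        print(json.dumps(cfg_cls().to_dict(), indent=2))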