# deepdetection / src/training/config.py
# Committed by akagtag — "Initial commit" (4e75170)
"""
src/training/config.py — training configuration dataclasses
(TrainingConfig, CoherenceConfig, SSTGNNConfig).
Single source of truth for all training hyperparameters. Import these from
every training script instead of duplicating argparse defaults.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import List
# Generator label index mapping — must match GeneratorLabel enum in src/types.py
# and the classification head in every model file.
# The position of each name in this list is its integer label (see the index
# comments), so entries must never be reordered, inserted, or removed.
GENERATOR_CLASSES: List[str] = [
    "real",  # 0
    "unknown_gan",  # 1
    "stable_diffusion",  # 2
    "midjourney",  # 3
    "dall_e",  # 4
    "flux",  # 5
    "firefly",  # 6
    "imagen",  # 7
]
# Derived from the list above rather than hard-coded, so it stays in sync.
NUM_GENERATOR_CLASSES: int = len(GENERATOR_CLASSES)  # 8 — never change this
@dataclass
class TrainingConfig:
    """Fingerprint engine training configuration (Phase 1).

    On construction, ``__post_init__`` coerces ``data_dir``, ``output_dir`` and
    ``log_dir`` to ``Path`` (so plain strings are accepted) and creates
    ``output_dir`` and ``log_dir`` on disk. ``data_dir`` is only coerced, never
    created.

    Raises:
        ValueError: if ``num_generator_classes`` differs from
            ``len(GENERATOR_CLASSES)``. The generator head size must match the
            module-level label mapping.
    """

    # ── Paths ─────────────────────────────────────────────────────────────────
    data_dir: Path = Path("data/processed/fingerprint")
    output_dir: Path = Path("models/checkpoints/fingerprint")
    log_dir: Path = Path("training/logs")

    # ── Model ─────────────────────────────────────────────────────────────────
    model_name: str = "vit_base_patch16_224"  # timm slug
    pretrained: bool = True
    num_binary_classes: int = 2
    # 8; must equal len(GENERATOR_CLASSES) — enforced in __post_init__.
    num_generator_classes: int = NUM_GENERATOR_CLASSES

    # ── Training ──────────────────────────────────────────────────────────────
    epochs: int = 30
    batch_size: int = 64
    learning_rate: float = 2e-5
    weight_decay: float = 0.01
    warmup_steps: int = 500
    grad_accumulation: int = 1
    amp: bool = True
    generator_loss_weight: float = 0.3  # secondary objective weight

    # ── Optimiser ─────────────────────────────────────────────────────────────
    optimizer: str = "adamw"
    scheduler: str = "cosine"

    # ── Early stopping ────────────────────────────────────────────────────────
    patience: int = 5  # stop if val AUC flat for N epochs

    # ── Reproducibility ───────────────────────────────────────────────────────
    seed: int = 42

    # ── Kaggle ────────────────────────────────────────────────────────────────
    is_kaggle: bool = False

    def __post_init__(self) -> None:
        """Validate the generator head size, normalise paths, create output dirs."""
        # Check the invariant before touching the filesystem so a rejected
        # config does not leave stray directories behind.
        expected = len(GENERATOR_CLASSES)
        if self.num_generator_classes != expected:
            raise ValueError(
                f"num_generator_classes must be {expected} to match "
                f"GENERATOR_CLASSES, got {self.num_generator_classes}"
            )
        self.data_dir = Path(self.data_dir)
        self.output_dir = Path(self.output_dir)
        self.log_dir = Path(self.log_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.log_dir.mkdir(parents=True, exist_ok=True)

    def to_dict(self) -> dict:
        """Return the instance attributes as a dict, with ``Path`` values as ``str``.

        All other values are returned as-is (not copied).
        """
        return {
            k: str(v) if isinstance(v, Path) else v
            for k, v in self.__dict__.items()
        }
@dataclass
class CoherenceConfig:
    """Hyperparameters and paths for training the coherence engine (Phase 2).

    Building an instance converts the three directory fields to ``Path``
    objects and ensures ``output_dir`` and ``log_dir`` exist on disk.
    ``data_dir`` is converted but never created.
    """

    # Filesystem locations
    data_dir: Path = Path("data/processed/coherence")
    output_dir: Path = Path("models/checkpoints/coherence")
    log_dir: Path = Path("training/logs")

    # Optimisation schedule and loss weighting
    epochs: int = 25
    batch_size: int = 16
    learning_rate: float = 1e-4
    weight_decay: float = 1e-4
    contrastive_weight: float = 0.1
    amp: bool = True

    # Early stopping, reproducibility, Kaggle flag
    patience: int = 5
    seed: int = 42
    is_kaggle: bool = False

    def __post_init__(self) -> None:
        """Coerce the directory fields to ``Path`` and create the writable ones."""
        for name in ("data_dir", "output_dir", "log_dir"):
            setattr(self, name, Path(getattr(self, name)))
        # data_dir is deliberately not created here.
        for writable in (self.output_dir, self.log_dir):
            writable.mkdir(parents=True, exist_ok=True)

    def to_dict(self) -> dict:
        """Snapshot the instance attributes, stringifying any ``Path`` values."""
        snapshot = {}
        for key, value in vars(self).items():
            snapshot[key] = str(value) if isinstance(value, Path) else value
        return snapshot
@dataclass
class SSTGNNConfig:
    """Hyperparameters and paths for training the SSTGNN engine (Phase 3).

    ``__post_init__`` turns the directory fields into ``Path`` objects and
    makes ``output_dir`` and ``log_dir`` exist on disk. ``data_dir`` is only
    converted, never created.
    """

    # Filesystem locations
    data_dir: Path = Path("data/processed/sstgnn")
    output_dir: Path = Path("models/checkpoints/sstgnn")
    log_dir: Path = Path("training/logs")

    # Graph model shape
    in_channels: int = 5  # node features: x, y, z, frame_idx/T, lm_idx/L
    hidden_dim: int = 64
    heads: int = 4
    num_gat_layers: int = 3

    # Training — AMP is off by default (NaN risk on GNNs)
    epochs: int = 40
    batch_size: int = 8
    learning_rate: float = 5e-4
    weight_decay: float = 5e-4
    grad_clip: float = 1.0  # GNN training always clips
    amp: bool = False
    patience: int = 8  # GNNs take longer to plateau
    seed: int = 42
    is_kaggle: bool = False

    def __post_init__(self) -> None:
        """Normalise directory fields to ``Path`` and create the output folders."""
        for attr in ("data_dir", "output_dir", "log_dir"):
            setattr(self, attr, Path(getattr(self, attr)))
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.log_dir.mkdir(parents=True, exist_ok=True)

    def to_dict(self) -> dict:
        """Return a dict of instance attributes with ``Path`` values rendered as ``str``."""

        def _plain(value):
            # Paths are stringified; everything else passes through untouched.
            if isinstance(value, Path):
                return str(value)
            return value

        return {name: _plain(value) for name, value in vars(self).items()}