"""
src/training/config.py — TrainingConfig, CoherenceConfig and SSTGNNConfig dataclasses
Single source of truth for all training hyperparameters. Import this from
every training script instead of duplicating argparse defaults.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import List
# Generator label index mapping — must match the GeneratorLabel enum in
# src/types.py and the classification head in every model file.
# The position of each name in this list is its integer label.
GENERATOR_CLASSES: List[str] = [
    "real",              # label 0
    "unknown_gan",       # label 1
    "stable_diffusion",  # label 2
    "midjourney",        # label 3
    "dall_e",            # label 4
    "flux",              # label 5
    "firefly",           # label 6
    "imagen",            # label 7
]

# Derived from the list above (currently 8) — never change this independently.
NUM_GENERATOR_CLASSES: int = len(GENERATOR_CLASSES)
@dataclass
class TrainingConfig:
    """Fingerprint engine training configuration (Phase 1).

    Path fields accept ``str`` or ``Path`` and are normalised to ``Path`` in
    ``__post_init__``. Constructing an instance creates ``output_dir`` and
    ``log_dir`` on disk (``mkdir(parents=True, exist_ok=True)``).

    Raises:
        ValueError: if ``num_generator_classes`` differs from
            ``NUM_GENERATOR_CLASSES``, or if ``batch_size`` or
            ``grad_accumulation`` is smaller than 1. Validation runs before
            any directory is created.
    """

    # ── Paths ─────────────────────────────────────────────────────────────
    data_dir: Path = Path("data/processed/fingerprint")
    output_dir: Path = Path("models/checkpoints/fingerprint")
    log_dir: Path = Path("training/logs")

    # ── Model ─────────────────────────────────────────────────────────────
    model_name: str = "vit_base_patch16_224"  # timm slug
    pretrained: bool = True
    num_binary_classes: int = 2
    # Must equal NUM_GENERATOR_CLASSES (the label mapping); checked in __post_init__.
    num_generator_classes: int = NUM_GENERATOR_CLASSES  # 8

    # ── Training ──────────────────────────────────────────────────────────
    epochs: int = 30
    batch_size: int = 64
    learning_rate: float = 2e-5
    weight_decay: float = 0.01
    warmup_steps: int = 500
    grad_accumulation: int = 1
    amp: bool = True
    generator_loss_weight: float = 0.3  # secondary objective weight

    # ── Optimiser ─────────────────────────────────────────────────────────
    optimizer: str = "adamw"
    scheduler: str = "cosine"

    # ── Early stopping ────────────────────────────────────────────────────
    patience: int = 5  # stop if val AUC flat for N epochs

    # ── Reproducibility ───────────────────────────────────────────────────
    seed: int = 42

    # ── Kaggle ────────────────────────────────────────────────────────────
    is_kaggle: bool = False

    def __post_init__(self) -> None:
        """Normalise path fields, validate invariants, then create output dirs."""
        self.data_dir = Path(self.data_dir)
        self.output_dir = Path(self.output_dir)
        self.log_dir = Path(self.log_dir)
        # Validate before touching the filesystem so a rejected config leaves
        # no stray directories behind.
        self._validate()
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.log_dir.mkdir(parents=True, exist_ok=True)

    def _validate(self) -> None:
        """Raise ValueError for settings that cannot describe a valid run."""
        # The generator head size must agree with the label index mapping.
        if self.num_generator_classes != NUM_GENERATOR_CLASSES:
            raise ValueError(
                f"num_generator_classes={self.num_generator_classes} does not "
                f"match the {NUM_GENERATOR_CLASSES} labels in GENERATOR_CLASSES; "
                "the generator head and label mapping would disagree"
            )
        if self.batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")
        if self.grad_accumulation < 1:
            raise ValueError(
                f"grad_accumulation must be >= 1, got {self.grad_accumulation}"
            )

    def to_dict(self) -> dict:
        """Return instance attributes as a dict, with ``Path`` values as ``str``."""
        return {
            k: str(v) if isinstance(v, Path) else v
            for k, v in self.__dict__.items()
        }
@dataclass
class CoherenceConfig:
    """Hyperparameters for the Phase 2 coherence-engine training run.

    Directory fields may be supplied as ``str`` or ``Path``; they are turned
    into ``Path`` objects after init. Building an instance also creates
    ``output_dir`` and ``log_dir`` (including missing parents) on disk.
    """

    # Filesystem locations
    data_dir: Path = Path("data/processed/coherence")
    output_dir: Path = Path("models/checkpoints/coherence")
    log_dir: Path = Path("training/logs")

    # Training
    epochs: int = 25
    batch_size: int = 16
    learning_rate: float = 1e-4
    weight_decay: float = 1e-4
    contrastive_weight: float = 0.1
    amp: bool = True
    patience: int = 5

    # Run metadata
    seed: int = 42
    is_kaggle: bool = False

    def __post_init__(self) -> None:
        """Coerce the three directory fields to Path, then create the writable ones."""
        for attr in ("data_dir", "output_dir", "log_dir"):
            setattr(self, attr, Path(getattr(self, attr)))
        for directory in (self.output_dir, self.log_dir):
            directory.mkdir(parents=True, exist_ok=True)

    def to_dict(self) -> dict:
        """Snapshot the instance attributes, rendering Path values as str."""
        snapshot = {}
        for name, value in vars(self).items():
            snapshot[name] = str(value) if isinstance(value, Path) else value
        return snapshot
@dataclass
class SSTGNNConfig:
    """Hyperparameters for the Phase 3 SSTGNN graph-engine training run.

    Directory fields may be supplied as ``str`` or ``Path``; they are turned
    into ``Path`` objects after init. Building an instance also creates
    ``output_dir`` and ``log_dir`` (including missing parents) on disk.
    """

    # Filesystem locations
    data_dir: Path = Path("data/processed/sstgnn")
    output_dir: Path = Path("models/checkpoints/sstgnn")
    log_dir: Path = Path("training/logs")

    # Graph model
    in_channels: int = 5  # node feature dim: x, y, z, frame_idx/T, lm_idx/L
    hidden_dim: int = 64
    heads: int = 4
    num_gat_layers: int = 3

    # Training — AMP is off by default (NaN risk on GNNs)
    epochs: int = 40
    batch_size: int = 8
    learning_rate: float = 5e-4
    weight_decay: float = 5e-4
    grad_clip: float = 1.0  # always clip for GNNs
    amp: bool = False
    patience: int = 8  # GNNs need longer to plateau

    # Run metadata
    seed: int = 42
    is_kaggle: bool = False

    def __post_init__(self) -> None:
        """Coerce the three directory fields to Path, then create the writable ones."""
        for attr in ("data_dir", "output_dir", "log_dir"):
            setattr(self, attr, Path(getattr(self, attr)))
        for directory in (self.output_dir, self.log_dir):
            directory.mkdir(parents=True, exist_ok=True)

    def to_dict(self) -> dict:
        """Snapshot the instance attributes, rendering Path values as str."""
        snapshot = {}
        for name, value in vars(self).items():
            snapshot[name] = str(value) if isinstance(value, Path) else value
        return snapshot