|
|
""" |
|
|
Artist Style Embedding - Configuration |
|
|
Maximum Performance Settings for RTX 5090 |
|
|
""" |
|
|
from dataclasses import dataclass, field |
|
|
from typing import Optional |
|
|
import torch |
|
|
|
|
|
|
|
|
@dataclass
class DataConfig:
    """Dataset paths, image sizing, split fractions and loader settings.

    Field order is part of the generated ``__init__`` signature and must
    not be rearranged.
    """

    # --- Dataset directories ---------------------------------------------
    # Three separate roots; judging by the names they hold full images,
    # face crops and eye crops respectively (confirm against the loader).
    dataset_root: str = "./dataset"
    dataset_face_root: str = "./dataset_face"
    dataset_eyes_root: str = "./dataset_eyes"

    # --- Image / filtering -----------------------------------------------
    image_size: int = 224
    # Presumably artists with fewer images than this are dropped -- verify
    # in the dataset-building code.
    min_images_per_artist: int = 3

    # --- Train / val / test split ----------------------------------------
    # The defaults add up to 1.0; nothing in this class enforces that for
    # caller-supplied values.
    train_ratio: float = 0.8
    val_ratio: float = 0.1
    test_ratio: float = 0.1

    # --- Loader settings -------------------------------------------------
    # Presumably forwarded to the data loader -- confirm at the call site.
    num_workers: int = 12
    pin_memory: bool = True
|
|
|
|
|
|
|
|
@dataclass
class ModelConfig:
    """Backbone choice, embedding sizes, branch toggles and fusion options.

    Field order is part of the generated ``__init__`` signature and must
    not be rearranged.
    """

    # --- Backbone --------------------------------------------------------
    # The default looks like a timm-style model identifier (confirm which
    # model factory consumes it).
    backbone: str = "eva02_large_patch14_clip_224"
    backbone_pretrained: bool = True
    # 0 presumably means the backbone is never frozen -- verify in trainer.
    freeze_backbone_epochs: int = 0

    # --- Embedding head --------------------------------------------------
    embedding_dim: int = 512
    hidden_dim: int = 1024

    # --- Auxiliary branches ----------------------------------------------
    use_face_branch: bool = True
    use_eye_branch: bool = True
    share_backbone_weights: bool = False

    # --- Fusion ----------------------------------------------------------
    # Set of accepted fusion_type values is defined by the model code, not
    # here.
    fusion_type: str = "gated"
    num_attention_heads: int = 8

    # --- Regularisation --------------------------------------------------
    dropout: float = 0.2
|
|
|
|
|
|
|
|
@dataclass
class LossConfig:
    """Hyperparameters and relative weights for the training loss terms.

    Field order is part of the generated ``__init__`` signature and must
    not be rearranged.
    """

    # --- ArcFace ---------------------------------------------------------
    arcface_scale: float = 64.0
    arcface_margin: float = 0.5
    arcface_weight: float = 0.2

    # --- "ms" loss -------------------------------------------------------
    # Presumably a multi-similarity loss term -- confirm in the loss module.
    ms_loss_weight: float = 3.0

    # --- Center loss -----------------------------------------------------
    center_loss_weight: float = 0.01
|
|
|
|
|
|
|
|
@dataclass
class TrainConfig:
    """Optimisation, scheduling, checkpointing and logging settings.

    Field order is part of the generated ``__init__`` signature and must
    not be rearranged.
    """

    # --- Run length / batching -------------------------------------------
    epochs: int = 100
    batch_size: int = 128

    # --- Optimiser -------------------------------------------------------
    learning_rate: float = 5e-4
    # Presumably the backbone uses learning_rate * this factor -- verify in
    # the optimiser setup.
    backbone_lr_multiplier: float = 0.2
    weight_decay: float = 0.01

    # --- LR schedule -----------------------------------------------------
    warmup_epochs: int = 3
    min_lr: float = 1e-6

    # --- Gradient clipping -----------------------------------------------
    max_grad_norm: float = 1.0

    # --- Mixed precision -------------------------------------------------
    use_amp: bool = True

    # --- Checkpointing ---------------------------------------------------
    save_dir: str = "./checkpoints"
    save_every_n_epochs: int = 5

    # --- Logging ---------------------------------------------------------
    log_every_n_steps: int = 50
    wandb_project: Optional[str] = "artist-style-embedding"
    # Left as None here; Config.__post_init__ fills in a generated name.
    wandb_run_name: Optional[str] = None

    # --- Batch sampling --------------------------------------------------
    samples_per_class: int = 4

    # --- Early stopping --------------------------------------------------
    patience: int = 20

    # --- Device ----------------------------------------------------------
    # This expression runs once, when the class body is executed (i.e. at
    # import time), not on every instantiation.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # --- Reproducibility -------------------------------------------------
    seed: int = 42
|
|
|
|
|
|
|
|
@dataclass
class Config:
    """Top-level configuration bundling the data, model, loss and train sections.

    Each section defaults to a freshly constructed instance of its config
    class, so separate ``Config()`` objects do not share section objects.
    """

    data: DataConfig = field(default_factory=DataConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    loss: LossConfig = field(default_factory=LossConfig)
    train: TrainConfig = field(default_factory=TrainConfig)

    def __post_init__(self):
        """Fill in ``train.wandb_run_name`` when the caller left it as None.

        The generated name embeds ``model.embedding_dim``. Note that this
        mutates the ``train`` object in place, including one passed in by
        the caller.
        """
        # An explicitly chosen run name is always kept untouched.
        if self.train.wandb_run_name is not None:
            return

        self.train.wandb_run_name = f"eva02_large_emb{self.model.embedding_dim}"
|
|
|
|
|
|
|
|
def get_config() -> "Config":
    """Build and return a ``Config`` populated entirely with default values."""
    default_config = Config()
    return default_config
|
|
|