""" Artist Style Embedding - Configuration Maximum Performance Settings for RTX 5090 """ from dataclasses import dataclass, field from typing import Optional import torch @dataclass class DataConfig: # 데이터셋 경로 dataset_root: str = "./dataset" dataset_face_root: str = "./dataset_face" dataset_eyes_root: str = "./dataset_eyes" # 이미지 설정 image_size: int = 224 min_images_per_artist: int = 3 # 데이터 분할 train_ratio: float = 0.8 val_ratio: float = 0.1 test_ratio: float = 0.1 # 데이터 로딩 num_workers: int = 12 pin_memory: bool = True @dataclass class ModelConfig: # Backbone - EVA02-Large (최고 성능) backbone: str = "eva02_large_patch14_clip_224" backbone_pretrained: bool = True freeze_backbone_epochs: int = 0 # 처음부터 unfreeze # 임베딩 설정 embedding_dim: int = 512 hidden_dim: int = 1024 # Multi-branch 설정 - 모든 브랜치 활성화, 별도 백본 use_face_branch: bool = True use_eye_branch: bool = True share_backbone_weights: bool = False # 별도 백본으로 최고 성능 # Fusion 설정 fusion_type: str = "gated" num_attention_heads: int = 8 # Dropout dropout: float = 0.2 # 약간 높임 @dataclass class LossConfig: # ArcFace settings arcface_scale: float = 64.0 arcface_margin: float = 0.5 arcface_weight: float = 0.2 # Multi-Similarity Loss weight ms_loss_weight: float = 3.0 # Center Loss weight center_loss_weight: float = 0.01 @dataclass class TrainConfig: # 학습 설정 epochs: int = 100 batch_size: int = 128 # Optimizer - 더 높은 learning rate learning_rate: float = 5e-4 # 1e-4 → 5e-4 backbone_lr_multiplier: float = 0.2 # 0.1 → 0.2 (backbone도 더 학습) weight_decay: float = 0.01 # 0.05 → 0.01 (regularization 줄임) # Scheduler warmup_epochs: int = 3 # 5 → 3 min_lr: float = 1e-6 # Gradient max_grad_norm: float = 1.0 # Mixed precision use_amp: bool = True # 체크포인트 save_dir: str = "./checkpoints" save_every_n_epochs: int = 5 # 로깅 log_every_n_steps: int = 50 wandb_project: Optional[str] = "artist-style-embedding" wandb_run_name: Optional[str] = None # Sampling samples_per_class: int = 4 # Early stopping patience: int = 20 # 더 오래 기다림 # Device device: str = "cuda" if torch.cuda.is_available() else "cpu" # Random seed seed: int = 42 @dataclass class Config: data: DataConfig = field(default_factory=DataConfig) model: ModelConfig = field(default_factory=ModelConfig) loss: LossConfig = field(default_factory=LossConfig) train: TrainConfig = field(default_factory=TrainConfig) def __post_init__(self): if self.train.wandb_run_name is None: self.train.wandb_run_name = f"eva02_large_emb{self.model.embedding_dim}" def get_config(): return Config()