"""
Artist Style Embedding - Configuration
Maximum Performance Settings for RTX 5090
"""
from dataclasses import dataclass, field
from typing import Optional
import torch
@dataclass
class DataConfig:
    """Dataset locations, image preprocessing, split ratios, and loader settings."""

    # Dataset roots: full images plus pre-cropped face / eye regions.
    dataset_root: str = "./dataset"
    dataset_face_root: str = "./dataset_face"
    dataset_eyes_root: str = "./dataset_eyes"

    # Image settings: input resolution and the minimum number of images
    # an artist must have to be included.
    image_size: int = 224
    min_images_per_artist: int = 3

    # Train / validation / test split ratios (sum to 1.0).
    train_ratio: float = 0.8
    val_ratio: float = 0.1
    test_ratio: float = 0.1

    # DataLoader settings.
    num_workers: int = 12
    pin_memory: bool = True
@dataclass
class ModelConfig:
    """Backbone, embedding, multi-branch, and fusion hyperparameters."""

    # Backbone: EVA02-Large (best performance in this project's experiments).
    backbone: str = "eva02_large_patch14_clip_224"
    backbone_pretrained: bool = True
    freeze_backbone_epochs: int = 0  # 0 = unfreeze from the very first epoch

    # Embedding head dimensions.
    embedding_dim: int = 512
    hidden_dim: int = 1024

    # Multi-branch: all branches enabled, each with its own backbone.
    use_face_branch: bool = True
    use_eye_branch: bool = True
    share_backbone_weights: bool = False  # separate backbones for best performance

    # Fusion of branch features.
    fusion_type: str = "gated"
    num_attention_heads: int = 8

    # Regularization (slightly raised from a lower earlier value).
    dropout: float = 0.2
@dataclass
class LossConfig:
    """Weights and hyperparameters for the combined training loss."""

    # ArcFace: scale, angular margin, and its weight in the total loss.
    arcface_scale: float = 64.0
    arcface_margin: float = 0.5
    arcface_weight: float = 0.2

    # Weight of the Multi-Similarity loss term.
    ms_loss_weight: float = 3.0

    # Weight of the Center loss term.
    center_loss_weight: float = 0.01
@dataclass
class TrainConfig:
    """Optimization, scheduling, logging, and runtime settings for training."""

    # Core training loop.
    epochs: int = 100
    batch_size: int = 128

    # Optimizer — raised from earlier runs (1e-4 -> 5e-4); the backbone
    # trains at a fraction of the head LR (multiplier 0.1 -> 0.2), and
    # weight decay was reduced (0.05 -> 0.01) to ease regularization.
    learning_rate: float = 5e-4
    backbone_lr_multiplier: float = 0.2
    weight_decay: float = 0.01

    # LR scheduler: warmup (shortened 5 -> 3 epochs) down to a floor LR.
    warmup_epochs: int = 3
    min_lr: float = 1e-6

    # Gradient clipping.
    max_grad_norm: float = 1.0

    # Automatic mixed precision.
    use_amp: bool = True

    # Checkpointing.
    save_dir: str = "./checkpoints"
    save_every_n_epochs: int = 5

    # Logging / Weights & Biases.
    log_every_n_steps: int = 50
    wandb_project: Optional[str] = "artist-style-embedding"
    wandb_run_name: Optional[str] = None

    # Batch sampling: images drawn per class by the sampler.
    samples_per_class: int = 4

    # Early stopping — generous patience to wait out slow improvements.
    patience: int = 20

    # Device — NOTE: evaluated once at class-definition (import) time.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Reproducibility.
    seed: int = 42
@dataclass
class Config:
    """Top-level configuration aggregating data, model, loss, and training sections."""

    data: DataConfig = field(default_factory=DataConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    loss: LossConfig = field(default_factory=LossConfig)
    train: TrainConfig = field(default_factory=TrainConfig)

    def __post_init__(self):
        # Derive a default W&B run name from the embedding size when the
        # caller has not supplied one explicitly.
        if self.train.wandb_run_name is not None:
            return
        self.train.wandb_run_name = f"eva02_large_emb{self.model.embedding_dim}"
def get_config() -> "Config":
    """Build and return a fresh Config with all default settings."""
    return Config()