File size: 3,104 Bytes
546ff88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""
Artist Style Embedding - Configuration
Maximum Performance Settings for RTX 5090
"""
from dataclasses import dataclass, field
from typing import Optional
import torch


@dataclass
class DataConfig:
    """Dataset locations, image preprocessing, split ratios and loader settings."""

    # Dataset paths
    dataset_root: str = "./dataset"
    dataset_face_root: str = "./dataset_face"
    dataset_eyes_root: str = "./dataset_eyes"

    # Image settings
    image_size: int = 224
    min_images_per_artist: int = 3

    # Data split: the three ratios must be non-negative and sum to 1.0
    train_ratio: float = 0.8
    val_ratio: float = 0.1
    test_ratio: float = 0.1

    # Data loading
    num_workers: int = 12
    pin_memory: bool = True

    def __post_init__(self):
        """Validate the train/val/test split ratios.

        Raises:
            ValueError: if any ratio is negative, or if the ratios do not sum
                to 1.0 (within a small tolerance for float rounding).
        """
        ratios = {
            "train_ratio": self.train_ratio,
            "val_ratio": self.val_ratio,
            "test_ratio": self.test_ratio,
        }
        if any(value < 0 for value in ratios.values()):
            raise ValueError(f"Split ratios must be non-negative, got {ratios}")
        total = sum(ratios.values())
        # Tolerance absorbs float rounding, e.g. 0.7 + 0.2 + 0.1 == 0.9999999999999999.
        if abs(total - 1.0) > 1e-6:
            raise ValueError(f"Split ratios must sum to 1.0, got {total} from {ratios}")


@dataclass
class ModelConfig:
    """Backbone, embedding-head, branch and fusion settings for the model."""

    # --- Backbone: EVA02-Large (chosen for best accuracy) ---
    backbone: str = "eva02_large_patch14_clip_224"
    backbone_pretrained: bool = True
    freeze_backbone_epochs: int = 0  # 0 = backbone is unfrozen from the very start

    # --- Embedding head ---
    embedding_dim: int = 512
    hidden_dim: int = 1024

    # --- Multi-branch: every branch enabled, each with its own backbone ---
    use_face_branch: bool = True
    use_eye_branch: bool = True
    share_backbone_weights: bool = False  # separate backbones per branch for best accuracy

    # --- Fusion ---
    fusion_type: str = "gated"
    num_attention_heads: int = 8

    # --- Regularization ---
    dropout: float = 0.2  # raised slightly


@dataclass
class LossConfig:
    """Hyper-parameters and mixing weights for the ArcFace, Multi-Similarity and Center loss terms."""

    # ArcFace: logit scale, additive angular margin, and weight in the total loss
    arcface_scale: float = 64.0
    arcface_margin: float = 0.5
    arcface_weight: float = 0.2

    # Weight of the Multi-Similarity loss term
    ms_loss_weight: float = 3.0

    # Weight of the Center loss term
    center_loss_weight: float = 0.01


@dataclass
class TrainConfig:
    """Optimization, scheduling, checkpointing, logging and runtime settings."""

    # Training length and batch size
    epochs: int = 100
    batch_size: int = 128

    # Optimizer - higher learning rate than before
    learning_rate: float = 5e-4  # 1e-4 -> 5e-4
    backbone_lr_multiplier: float = 0.2  # 0.1 -> 0.2 (let the backbone train more)
    weight_decay: float = 0.01  # 0.05 -> 0.01 (less regularization)

    # Scheduler
    warmup_epochs: int = 3  # 5 -> 3
    min_lr: float = 1e-6

    # Gradient clipping
    max_grad_norm: float = 1.0

    # Mixed precision
    use_amp: bool = True

    # Checkpoints
    save_dir: str = "./checkpoints"
    save_every_n_epochs: int = 5

    # Logging
    log_every_n_steps: int = 50
    wandb_project: Optional[str] = "artist-style-embedding"
    wandb_run_name: Optional[str] = None

    # Sampling
    samples_per_class: int = 4

    # Early stopping
    patience: int = 20  # wait longer before stopping

    # Device: resolved when each instance is created (default_factory) rather
    # than once at import time as a frozen class-level default.
    device: str = field(
        default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu"
    )

    # Random seed
    seed: int = 42


@dataclass
class Config:
    """Top-level configuration bundling the data, model, loss and training sections."""

    data: DataConfig = field(default_factory=DataConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    loss: LossConfig = field(default_factory=LossConfig)
    train: TrainConfig = field(default_factory=TrainConfig)

    def __post_init__(self):
        """Fill in a default W&B run name when none was given.

        The name is ``<backbone tag>_emb<embedding_dim>``. The tag is the first
        two ``_``-separated tokens of ``model.backbone``, so the default backbone
        ``eva02_large_patch14_clip_224`` still yields ``eva02_large_emb512``.
        An explicitly set ``train.wandb_run_name`` is left untouched. Note that
        this writes into ``self.train`` in place.
        """
        if self.train.wandb_run_name is None:
            # Derived from the configured backbone (not hard-coded) so runs with
            # another backbone are labelled correctly.
            backbone_tag = "_".join(self.model.backbone.split("_")[:2])
            self.train.wandb_run_name = f"{backbone_tag}_emb{self.model.embedding_dim}"


def get_config():
    """Build and return a ``Config`` populated entirely with default values."""
    default_config = Config()
    return default_config