File size: 3,921 Bytes
546ff88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""
Artist Style Embedding - Training Script
EVA02-Large, 3-branch, Maximum Performance
"""
import argparse
import random
import numpy as np
import torch

from config import get_config
from dataset import build_dataset_splits, create_dataloaders
from model import create_model
from losses import create_loss
from trainer import Trainer


def set_seed(seed: int):
    """Seed every RNG source (Python, NumPy, PyTorch CPU/CUDA) for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Force deterministic cuDNN kernels and disable autotuning so repeated
    # runs with the same seed produce identical results.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def parse_args():
    """Define and parse the command-line options for a training run."""
    parser = argparse.ArgumentParser(description='Train Artist Style Embedding')
    add = parser.add_argument

    # Dataset roots for the three input branches (full image / face crop / eye crop).
    add('--dataset_root', type=str, default='./dataset')
    add('--dataset_face_root', type=str, default='./dataset_face')
    add('--dataset_eyes_root', type=str, default='./dataset_eyes')

    # Core training hyperparameters.
    add('--epochs', type=int, default=50)
    add('--batch_size', type=int, default=768)
    add('--lr', type=float, default=2e-4)

    # Output, resume, logging, and reproducibility options.
    add('--save_dir', type=str, default='./checkpoints')
    add('--resume', type=str, default=None)
    add('--wandb_project', type=str, default='artist-style-embedding')
    add('--seed', type=int, default=42)

    return parser.parse_args()


def main():
    """Entry point: read CLI args, build data/model/loss, and run training."""
    args = parse_args()
    set_seed(args.seed)

    config = get_config()

    # Propagate command-line overrides into the config object.
    for field, value in (
        ('dataset_root', args.dataset_root),
        ('dataset_face_root', args.dataset_face_root),
        ('dataset_eyes_root', args.dataset_eyes_root),
    ):
        setattr(config.data, field, value)
    for field, value in (
        ('epochs', args.epochs),
        ('batch_size', args.batch_size),
        ('learning_rate', args.lr),
        ('save_dir', args.save_dir),
        ('wandb_project', args.wandb_project),
        ('seed', args.seed),
    ):
        setattr(config.train, field, value)

    # Banner: echo the effective configuration before any heavy work starts.
    rule = "=" * 60
    print(rule)
    print("Artist Style Embedding - Maximum Performance")
    print(rule)
    print(f"Backbone: EVA02-Large (3x separate)")
    print(f"Branches: Full + Face + Eye")
    print(f"Embedding dim: {config.model.embedding_dim}")
    print(f"Fusion: Gated")
    print(f"Epochs: {config.train.epochs}")
    print(f"Batch size: {config.train.batch_size}")
    print(f"Learning rate: {config.train.learning_rate}")
    print(rule)

    # Discover artists and split each branch's images into train/val/test.
    print("\nBuilding dataset splits...")
    label_map, splits_full, splits_face, splits_eye = build_dataset_splits(
        config.data.dataset_root,
        config.data.dataset_face_root,
        config.data.dataset_eyes_root,
        min_images=config.data.min_images_per_artist,
        train_ratio=config.data.train_ratio,
        val_ratio=config.data.val_ratio,
        seed=config.train.seed,
    )

    n_classes = len(label_map)
    print(f"Number of artists: {n_classes}")

    # One loader per split; each yields aligned full/face/eye batches.
    print("\nCreating dataloaders...")
    train_loader, val_loader, test_loader = create_dataloaders(
        config, label_map, splits_full, splits_face, splits_eye
    )

    print("\nCreating model (3x EVA02-Large)...")
    model = create_model(config, n_classes)

    n_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {n_params:,}")

    # Loss is built per-class-count (e.g. for a classification/metric head).
    loss_fn = create_loss(config, n_classes)

    # Trainer owns the optimization loop, validation, checkpointing, logging.
    trainer = Trainer(
        model=model,
        loss_fn=loss_fn,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
        artist_to_idx=label_map,
    )

    # Optionally restore weights/optimizer state from a previous checkpoint.
    if args.resume:
        trainer.load_checkpoint(args.resume)

    trainer.train()

    print("\nTraining complete!")


if __name__ == '__main__':
    main()