|
|
""" |
|
|
Artist Style Embedding - Training Script |
|
|
EVA02-Large, 3-branch, Maximum Performance |
|
|
""" |
|
|
import argparse |
|
|
import random |
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
from config import get_config |
|
|
from dataset import build_dataset_splits, create_dataloaders |
|
|
from model import create_model |
|
|
from losses import create_loss |
|
|
from trainer import Trainer |
|
|
|
|
|
|
|
|
def set_seed(seed: int):
    """Seed every RNG in play (Python, NumPy, PyTorch CPU/CUDA) for reproducibility."""
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,  # no-op on CPU-only builds
    ):
        seeder(seed)
    # Force deterministic cuDNN kernels and disable the autotuner, which
    # would otherwise pick kernels non-deterministically per run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse command-line options for the training run.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``, in
            which case argparse reads ``sys.argv[1:]`` as before. Accepting
            an explicit list keeps the parser unit-testable and callable
            from other scripts.

    Returns:
        argparse.Namespace with dataset paths, training hyperparameters,
        and checkpoint / logging options.
    """
    parser = argparse.ArgumentParser(description='Train Artist Style Embedding')

    # Dataset roots: full illustrations plus pre-cropped face and eye regions
    # (one root per model branch).
    parser.add_argument('--dataset_root', type=str, default='./dataset')
    parser.add_argument('--dataset_face_root', type=str, default='./dataset_face')
    parser.add_argument('--dataset_eyes_root', type=str, default='./dataset_eyes')

    # Core training hyperparameters.
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=768)
    parser.add_argument('--lr', type=float, default=2e-4)

    # Checkpointing, resume, experiment tracking, reproducibility.
    parser.add_argument('--save_dir', type=str, default='./checkpoints')
    parser.add_argument('--resume', type=str, default=None)
    parser.add_argument('--wandb_project', type=str, default='artist-style-embedding')
    parser.add_argument('--seed', type=int, default=42)

    return parser.parse_args(argv)
|
|
|
|
|
|
|
|
def main():
    """Orchestrate a full training run: config, data, model, loss, trainer."""
    args = parse_args()
    set_seed(args.seed)

    cfg = get_config()

    # Command-line values take precedence over config-file defaults.
    cfg.data.dataset_root = args.dataset_root
    cfg.data.dataset_face_root = args.dataset_face_root
    cfg.data.dataset_eyes_root = args.dataset_eyes_root
    cfg.train.epochs = args.epochs
    cfg.train.batch_size = args.batch_size
    cfg.train.learning_rate = args.lr
    cfg.train.save_dir = args.save_dir
    cfg.train.wandb_project = args.wandb_project
    cfg.train.seed = args.seed

    # Startup banner summarizing the effective configuration.
    rule = "=" * 60
    for line in (
        rule,
        "Artist Style Embedding - Maximum Performance",
        rule,
        f"Backbone: EVA02-Large (3x separate)",
        f"Branches: Full + Face + Eye",
        f"Embedding dim: {cfg.model.embedding_dim}",
        f"Fusion: Gated",
        f"Epochs: {cfg.train.epochs}",
        f"Batch size: {cfg.train.batch_size}",
        f"Learning rate: {cfg.train.learning_rate}",
        rule,
    ):
        print(line)

    # Split the three image roots into train/val/test, keyed by artist.
    print("\nBuilding dataset splits...")
    artist_to_idx, full_splits, face_splits, eye_splits = build_dataset_splits(
        cfg.data.dataset_root,
        cfg.data.dataset_face_root,
        cfg.data.dataset_eyes_root,
        min_images=cfg.data.min_images_per_artist,
        train_ratio=cfg.data.train_ratio,
        val_ratio=cfg.data.val_ratio,
        seed=cfg.train.seed,
    )

    num_classes = len(artist_to_idx)
    print(f"Number of artists: {num_classes}")

    print("\nCreating dataloaders...")
    train_loader, val_loader, test_loader = create_dataloaders(
        cfg, artist_to_idx, full_splits, face_splits, eye_splits
    )

    # Three EVA02-Large backbones: full image, face crop, eye crop.
    print("\nCreating model (3x EVA02-Large)...")
    model = create_model(cfg, num_classes)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")

    loss_fn = create_loss(cfg, num_classes)

    trainer = Trainer(
        model=model,
        loss_fn=loss_fn,
        train_loader=train_loader,
        val_loader=val_loader,
        config=cfg,
        artist_to_idx=artist_to_idx,
    )

    # Optionally restore a previous checkpoint before training.
    if args.resume:
        trainer.load_checkpoint(args.resume)

    trainer.train()

    print("\nTraining complete!")
|
|
|
|
|
|
|
|
# Script entry point: run training only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()
|
|
|