# Source scraped from GitHub: author iljung1106, "Initial commit", 546ff88
"""
Artist Style Embedding - Training Script
EVA02-Large, 3-branch, Maximum Performance
"""
import argparse
import random
import numpy as np
import torch
from config import get_config
from dataset import build_dataset_splits, create_dataloaders
from model import create_model
from losses import create_loss
from trainer import Trainer
def set_seed(seed: int) -> None:
    """Seed every RNG in play (stdlib, NumPy, Torch CPU/CUDA) for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Force cuDNN to pick deterministic kernels; disables the autotuner,
    # trading some speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def parse_args():
    """Build the CLI parser and parse ``sys.argv`` into a namespace."""
    # (flag, type, default) — order matters only for --help output.
    spec = [
        # Data paths
        ('--dataset_root', str, './dataset'),
        ('--dataset_face_root', str, './dataset_face'),
        ('--dataset_eyes_root', str, './dataset_eyes'),
        # Training
        ('--epochs', int, 50),
        ('--batch_size', int, 768),
        ('--lr', float, 2e-4),
        # Other
        ('--save_dir', str, './checkpoints'),
        ('--resume', str, None),
        ('--wandb_project', str, 'artist-style-embedding'),
        ('--seed', int, 42),
    ]
    cli = argparse.ArgumentParser(description='Train Artist Style Embedding')
    for flag, arg_type, default in spec:
        cli.add_argument(flag, type=arg_type, default=default)
    return cli.parse_args()
def _apply_overrides(config, args) -> None:
    """Copy CLI arguments onto the config object (CLI values win over defaults)."""
    config.data.dataset_root = args.dataset_root
    config.data.dataset_face_root = args.dataset_face_root
    config.data.dataset_eyes_root = args.dataset_eyes_root
    config.train.epochs = args.epochs
    config.train.batch_size = args.batch_size
    config.train.learning_rate = args.lr
    config.train.save_dir = args.save_dir
    config.train.wandb_project = args.wandb_project
    config.train.seed = args.seed


def _print_banner(config) -> None:
    """Print a human-readable summary of the run configuration."""
    print("=" * 60)
    print("Artist Style Embedding - Maximum Performance")
    print("=" * 60)
    # f-prefixes dropped from placeholder-free strings (lint F541); output unchanged.
    print("Backbone: EVA02-Large (3x separate)")
    print("Branches: Full + Face + Eye")
    print(f"Embedding dim: {config.model.embedding_dim}")
    print("Fusion: Gated")
    print(f"Epochs: {config.train.epochs}")
    print(f"Batch size: {config.train.batch_size}")
    print(f"Learning rate: {config.train.learning_rate}")
    print("=" * 60)


def main():
    """Entry point: parse CLI, build data/model/loss, and run training.

    Side effects: seeds global RNGs, prints progress to stdout, and
    writes checkpoints via the Trainer (to ``config.train.save_dir``).
    """
    args = parse_args()
    set_seed(args.seed)

    config = get_config()
    _apply_overrides(config, args)
    _print_banner(config)

    # Build per-artist train/val/test splits for all three branches.
    print("\nBuilding dataset splits...")
    artist_to_idx, full_splits, face_splits, eye_splits = build_dataset_splits(
        config.data.dataset_root,
        config.data.dataset_face_root,
        config.data.dataset_eyes_root,
        min_images=config.data.min_images_per_artist,
        train_ratio=config.data.train_ratio,
        val_ratio=config.data.val_ratio,
        seed=config.train.seed,
    )
    num_classes = len(artist_to_idx)
    print(f"Number of artists: {num_classes}")

    # Create dataloaders. test_loader is not consumed here (Trainer only
    # takes train/val); kept to unpack the 3-tuple returned by the helper.
    print("\nCreating dataloaders...")
    train_loader, val_loader, test_loader = create_dataloaders(
        config, artist_to_idx, full_splits, face_splits, eye_splits
    )

    # Create model (one EVA02-Large backbone per branch).
    print("\nCreating model (3x EVA02-Large)...")
    model = create_model(config, num_classes)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")

    # Loss depends on the number of classes (classification-style head).
    loss_fn = create_loss(config, num_classes)

    trainer = Trainer(
        model=model,
        loss_fn=loss_fn,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
        artist_to_idx=artist_to_idx,
    )
    # Optionally resume from a checkpoint path given on the CLI.
    if args.resume:
        trainer.load_checkpoint(args.resume)

    trainer.train()
    print("\nTraining complete!")


if __name__ == '__main__':
    main()