"""
Artist Style Embedding - Training Script
EVA02-Large, 3-branch, Maximum Performance
"""
import argparse
import random

import numpy as np
import torch

from config import get_config
from dataset import build_dataset_splits, create_dataloaders
from model import create_model
from losses import create_loss
from trainer import Trainer


def set_seed(seed: int) -> None:
    """Seed all RNGs (Python, NumPy, PyTorch CPU/CUDA) for reproducibility.

    Args:
        seed: Seed value applied to every random source.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN kernels trade throughput for exact run-to-run
    # repeatability; benchmark autotuning is disabled for the same reason.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for the training run.

    Returns:
        argparse.Namespace with data paths, training hyperparameters,
        checkpoint/resume options, wandb project name, and RNG seed.
    """
    parser = argparse.ArgumentParser(description='Train Artist Style Embedding')

    # Data paths — one root per branch (full image / face crop / eye crop).
    parser.add_argument('--dataset_root', type=str, default='./dataset')
    parser.add_argument('--dataset_face_root', type=str, default='./dataset_face')
    parser.add_argument('--dataset_eyes_root', type=str, default='./dataset_eyes')

    # Training
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=768)
    parser.add_argument('--lr', type=float, default=2e-4)

    # Other
    parser.add_argument('--save_dir', type=str, default='./checkpoints')
    parser.add_argument('--resume', type=str, default=None)
    parser.add_argument('--wandb_project', type=str, default='artist-style-embedding')
    parser.add_argument('--seed', type=int, default=42)

    return parser.parse_args()


def main() -> None:
    """Entry point: build config, data, model, and loss, then run training."""
    args = parse_args()
    set_seed(args.seed)

    config = get_config()

    # Override config defaults with CLI arguments.
    config.data.dataset_root = args.dataset_root
    config.data.dataset_face_root = args.dataset_face_root
    config.data.dataset_eyes_root = args.dataset_eyes_root
    config.train.epochs = args.epochs
    config.train.batch_size = args.batch_size
    config.train.learning_rate = args.lr
    config.train.save_dir = args.save_dir
    config.train.wandb_project = args.wandb_project
    config.train.seed = args.seed

    # Print config summary.
    print("=" * 60)
    print("Artist Style Embedding - Maximum Performance")
    print("=" * 60)
    print(f"Backbone: EVA02-Large (3x separate)")
    # NOTE(review): this string was split mid-literal in the original source
    # (a syntax error); reconstructed to the evident intended text.
    print(f"Branches: Full + Face + Eye")
    print(f"Embedding dim: {config.model.embedding_dim}")
    print(f"Fusion: Gated")
    print(f"Epochs: {config.train.epochs}")
    print(f"Batch size: {config.train.batch_size}")
    print(f"Learning rate: {config.train.learning_rate}")
    print("=" * 60)

    # Build dataset splits (shared artist->index mapping across branches).
    print("\nBuilding dataset splits...")
    artist_to_idx, full_splits, face_splits, eye_splits = build_dataset_splits(
        config.data.dataset_root,
        config.data.dataset_face_root,
        config.data.dataset_eyes_root,
        min_images=config.data.min_images_per_artist,
        train_ratio=config.data.train_ratio,
        val_ratio=config.data.val_ratio,
        seed=config.train.seed,
    )
    num_classes = len(artist_to_idx)
    print(f"Number of artists: {num_classes}")

    # Create dataloaders for all three branches.
    print("\nCreating dataloaders...")
    train_loader, val_loader, test_loader = create_dataloaders(
        config, artist_to_idx, full_splits, face_splits, eye_splits
    )

    # Create model (three separate EVA02-Large backbones).
    print("\nCreating model (3x EVA02-Large)...")
    model = create_model(config, num_classes)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")

    # Create loss
    loss_fn = create_loss(config, num_classes)

    # Create trainer
    trainer = Trainer(
        model=model,
        loss_fn=loss_fn,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
        artist_to_idx=artist_to_idx,
    )

    if args.resume:
        trainer.load_checkpoint(args.resume)

    # Train
    trainer.train()

    print("\nTraining complete!")


if __name__ == '__main__':
    main()