"""
Artist Style Embedding - Training Script
EVA02-Large, 3-branch, Maximum Performance
"""
import argparse
import random
import numpy as np
import torch
from config import get_config
from dataset import build_dataset_splits, create_dataloaders
from model import create_model
from losses import create_loss
from trainer import Trainer
def set_seed(seed: int) -> None:
    """Seed every RNG source (Python, NumPy, PyTorch CPU and CUDA) for reproducibility.

    Also forces cuDNN into deterministic mode, trading kernel-selection
    speed for run-to-run repeatability.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Deterministic cuDNN: disable autotuned (non-reproducible) kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def parse_args(argv=None):
    """Parse command-line arguments for the training run.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``, in
            which case argparse falls back to ``sys.argv[1:]`` — identical
            to the previous behavior. Accepting an explicit list makes the
            parser usable in tests without mutating ``sys.argv``.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description='Train Artist Style Embedding')

    # Data paths
    parser.add_argument('--dataset_root', type=str, default='./dataset')
    parser.add_argument('--dataset_face_root', type=str, default='./dataset_face')
    parser.add_argument('--dataset_eyes_root', type=str, default='./dataset_eyes')

    # Training
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=768)
    parser.add_argument('--lr', type=float, default=2e-4)

    # Other
    parser.add_argument('--save_dir', type=str, default='./checkpoints')
    parser.add_argument('--resume', type=str, default=None)
    parser.add_argument('--wandb_project', type=str, default='artist-style-embedding')
    parser.add_argument('--seed', type=int, default=42)

    return parser.parse_args(argv)
def _print_banner(cfg):
    """Echo the run configuration to stdout before training starts."""
    print("=" * 60)
    print("Artist Style Embedding - Maximum Performance")
    print("=" * 60)
    print("Backbone: EVA02-Large (3x separate)")
    print("Branches: Full + Face + Eye")
    print(f"Embedding dim: {cfg.model.embedding_dim}")
    print("Fusion: Gated")
    print(f"Epochs: {cfg.train.epochs}")
    print(f"Batch size: {cfg.train.batch_size}")
    print(f"Learning rate: {cfg.train.learning_rate}")
    print("=" * 60)


def main():
    """Entry point: wire up config, data, model, and loss, then run training."""
    args = parse_args()
    set_seed(args.seed)

    cfg = get_config()

    # Fold CLI overrides into the config object.
    cfg.data.dataset_root = args.dataset_root
    cfg.data.dataset_face_root = args.dataset_face_root
    cfg.data.dataset_eyes_root = args.dataset_eyes_root
    cfg.train.epochs = args.epochs
    cfg.train.batch_size = args.batch_size
    cfg.train.learning_rate = args.lr
    cfg.train.save_dir = args.save_dir
    cfg.train.wandb_project = args.wandb_project
    cfg.train.seed = args.seed

    _print_banner(cfg)

    # Dataset splits shared across the three branches (full / face / eye).
    print("\nBuilding dataset splits...")
    artist_to_idx, full_splits, face_splits, eye_splits = build_dataset_splits(
        cfg.data.dataset_root,
        cfg.data.dataset_face_root,
        cfg.data.dataset_eyes_root,
        min_images=cfg.data.min_images_per_artist,
        train_ratio=cfg.data.train_ratio,
        val_ratio=cfg.data.val_ratio,
        seed=cfg.train.seed,
    )
    n_artists = len(artist_to_idx)
    print(f"Number of artists: {n_artists}")

    print("\nCreating dataloaders...")
    train_loader, val_loader, _test_loader = create_dataloaders(
        cfg, artist_to_idx, full_splits, face_splits, eye_splits
    )

    print("\nCreating model (3x EVA02-Large)...")
    model = create_model(cfg, n_artists)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {n_params:,}")

    loss_fn = create_loss(cfg, n_artists)

    trainer = Trainer(
        model=model,
        loss_fn=loss_fn,
        train_loader=train_loader,
        val_loader=val_loader,
        config=cfg,
        artist_to_idx=artist_to_idx,
    )

    # Optionally restore model/optimizer state before training.
    if args.resume:
        trainer.load_checkpoint(args.resume)

    trainer.train()
    print("\nTraining complete!")
# Script entry guard: run training only when executed directly, not on import.
if __name__ == '__main__':
    main()