|
|
""" |
|
|
Artist Style Embedding - Evaluation and Inference |
|
|
""" |
|
|
import argparse |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Tuple |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import DataLoader |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
from tqdm import tqdm |
|
|
from sklearn.manifold import TSNE |
|
|
|
|
|
try: |
|
|
import matplotlib.pyplot as plt |
|
|
PLOT_AVAILABLE = True |
|
|
except ImportError: |
|
|
PLOT_AVAILABLE = False |
|
|
|
|
|
from config import get_config |
|
|
from model import ArtistStyleModel |
|
|
from dataset import ArtistDataset, build_dataset_splits, collate_fn |
|
|
|
|
|
|
|
|
class ArtistEmbeddingInference:
    """Single-image inference wrapper around a trained ``ArtistStyleModel``.

    Loads a checkpoint, rebuilds the model from the project config, moves it to
    the requested device (falling back to CPU when CUDA is unavailable) and
    exposes helpers to embed one image or rank the most likely artists.

    Only the full image is fed to the model: the face and eye inputs are zero
    tensors and their presence flags are ``False``.
    """

    def __init__(self, checkpoint_path: str, device: str = 'cuda'):
        """Load the checkpoint and prepare the model for inference.

        Args:
            checkpoint_path: Path to a checkpoint dict containing the keys
                ``'artist_to_idx'`` and ``'model_state_dict'``.
            device: Torch device string such as ``'cuda'``, ``'cuda:1'`` or
                ``'cpu'``. A CUDA request without CUDA support falls back to CPU.
        """
        requested_device = device
        if requested_device.startswith('cuda') and not torch.cuda.is_available():
            print(
                "[WARN] --device=cuda requested but torch.cuda.is_available() is False. "
                "Falling back to CPU. (Install a CUDA-enabled PyTorch build to use GPU.)"
            )
            requested_device = 'cpu'
        self.device = torch.device(requested_device)

        # Deserialize on CPU so GPU-saved checkpoints also open on CPU-only hosts.
        # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True;
        # confirm the checkpoint stores only tensors/plain containers.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        self.artist_to_idx = checkpoint['artist_to_idx']
        self.idx_to_artist = {v: k for k, v in self.artist_to_idx.items()}

        config = get_config()
        self.model = ArtistStyleModel(
            num_classes=len(self.artist_to_idx),
            embedding_dim=config.model.embedding_dim,
            hidden_dim=config.model.hidden_dim,
        )
        self.model.load_state_dict(checkpoint['model_state_dict'])

        # Half precision on GPU to cut memory use and latency.
        if self.device.type == 'cuda':
            self.model = self.model.to(dtype=torch.float16)
        self.model = self.model.to(self.device)
        self.model.eval()

        # Floating dtype the weights ended up in. Every input is cast to it;
        # otherwise float32 images hitting the float16 CUDA model raise a
        # dtype-mismatch error.
        self.dtype = next(
            (p.dtype for p in self.model.parameters() if p.is_floating_point()),
            torch.float32,
        )

        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def _prepare_inputs(self, image: 'Image.Image'):
        """Build ``(full, placeholder, has_part)`` model inputs for one image.

        ``full`` is the transformed image with a batch dimension, on
        ``self.device`` and in ``self.dtype``. ``placeholder`` is an all-zero
        tensor of the same shape, used for both the face and eye inputs.
        ``has_part`` is a single ``False`` flag.
        """
        # Normalize() uses 3-channel statistics, so RGBA/L/P images must be converted.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        full = self.transform(image).unsqueeze(0).to(self.device, dtype=self.dtype)
        placeholder = torch.zeros_like(full)
        has_part = torch.zeros(1, dtype=torch.bool, device=self.device)
        return full, placeholder, has_part

    def get_embedding(self, image: 'Image.Image') -> torch.Tensor:
        """Return the style embedding of ``image`` without the batch dimension."""
        full, placeholder, has_part = self._prepare_inputs(image)
        with torch.no_grad():
            embedding = self.model.get_embeddings(full, placeholder, placeholder, has_part, has_part)
        return embedding.squeeze(0)

    def predict_artist(self, image: 'Image.Image', top_k: int = 5) -> List[Tuple[str, float]]:
        """Rank the most likely artists for ``image``.

        Args:
            image: Input PIL image. Non-RGB images are converted to RGB.
            top_k: Maximum number of results. It is clamped to the number of
                known artists, so asking for more than exist does not raise.

        Returns:
            Up to ``top_k`` ``(artist_name, probability)`` pairs, in descending
            order of probability. Probabilities are a softmax over the model's
            ``'cosine'`` output.
            NOTE(review): if ``'cosine'`` holds raw cosine similarities in
            [-1, 1] with no training-time scale, this softmax is very flat.
            Treat the values as relative scores; confirm against the model.
        """
        full, placeholder, has_part = self._prepare_inputs(image)
        with torch.no_grad():
            output = self.model(full, placeholder, placeholder, has_part, has_part)

        # Softmax in float32: float16 loses precision across many classes.
        probs = F.softmax(output['cosine'].squeeze(0).float(), dim=0)
        k = max(0, min(top_k, probs.numel()))
        top_probs, top_indices = probs.topk(k)

        return [(self.idx_to_artist[idx.item()], prob.item()) for prob, idx in zip(top_probs, top_indices)]
|
|
|
|
|
|
|
|
def evaluate_model(checkpoint_path: str, dataset_root: str, dataset_face_root: str,
                   dataset_eyes_root: str, device: str = 'cuda'):
    """Print top-1 and top-5 accuracy of a checkpoint on the test split.

    Args:
        checkpoint_path: Checkpoint loadable by ``ArtistEmbeddingInference``.
        dataset_root: Root of the full-image dataset.
        dataset_face_root: Root of the face-crop dataset.
        dataset_eyes_root: Root of the eye-crop dataset.
        device: Requested torch device. Falls back to CPU when CUDA is unavailable.
    """
    inference = ArtistEmbeddingInference(checkpoint_path, device)
    model = inference.model
    # The inference wrapper casts the model to float16 on CUDA, while DataLoader
    # batches arrive as float32, so image tensors are cast to the weights' dtype.
    model_dtype = next(
        (p.dtype for p in model.parameters() if p.is_floating_point()),
        torch.float32,
    )
    config = get_config()

    dataset_artist_to_idx, full_splits, face_splits, eye_splits = build_dataset_splits(
        dataset_root, dataset_face_root, dataset_eyes_root,
        min_images=config.data.min_images_per_artist
    )

    # Labels must use the checkpoint's class indices. A freshly rebuilt mapping
    # can differ if the dataset changed since training, which would silently
    # compare predictions against the wrong labels.
    ckpt_idx = inference.artist_to_idx
    artist_to_idx = {a: ckpt_idx[a] for a in dataset_artist_to_idx if a in ckpt_idx}
    skipped = len(dataset_artist_to_idx) - len(artist_to_idx)
    if skipped:
        print(f"[WARN] {skipped} dataset artist(s) are unknown to the checkpoint and will be skipped.")
    if not artist_to_idx:
        print("No artists shared between the dataset and the checkpoint; nothing to evaluate.")
        return

    def _known(split):
        """Keep only the split entries whose artist is evaluated."""
        return {a: items for a, items in split.items() if a in artist_to_idx}

    test_dataset = ArtistDataset(
        dataset_root, dataset_face_root, dataset_eyes_root,
        artist_to_idx,
        _known(full_splits['test']), _known(face_splits['test']), _known(eye_splits['test']),
        is_training=False
    )
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4, collate_fn=collate_fn)

    total_correct = 0
    total_correct_top5 = 0
    total_samples = 0

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            full = batch['full'].to(inference.device, dtype=model_dtype)
            face = batch['face'].to(inference.device, dtype=model_dtype)
            eye = batch['eye'].to(inference.device, dtype=model_dtype)
            has_face = batch['has_face'].to(inference.device)
            has_eye = batch['has_eye'].to(inference.device)
            labels = batch['label'].to(inference.device)

            logits = model(full, face, eye, has_face, has_eye)['cosine']

            total_correct += (logits.argmax(dim=1) == labels).sum().item()

            # topk raises when k exceeds the number of classes.
            k = min(5, logits.size(1))
            _, topk_preds = logits.topk(k, dim=1)
            total_correct_top5 += (topk_preds == labels.unsqueeze(1)).any(dim=1).sum().item()

            total_samples += labels.size(0)

    accuracy = total_correct / total_samples if total_samples > 0 else 0
    accuracy_top5 = total_correct_top5 / total_samples if total_samples > 0 else 0

    print("\nEvaluation Results:")
    print("-" * 40)
    print(f"Top-1 Accuracy: {accuracy:.4f} ({total_correct}/{total_samples})")
    print(f"Top-5 Accuracy: {accuracy_top5:.4f} ({total_correct_top5}/{total_samples})")
|
|
|
|
|
|
|
|
def visualize_embeddings(checkpoint_path: str, dataset_root: str, dataset_face_root: str,
                         dataset_eyes_root: str, output_path: str = 'tsne.png',
                         max_artists: int = 50, device: str = 'cuda'):
    """Project test-split embeddings to 2-D with t-SNE and save a scatter plot.

    The first ``max_artists`` artists of the dataset's artist mapping are
    plotted, using at most 10 full images per artist from the test split.

    Args:
        checkpoint_path: Checkpoint loadable by ``ArtistEmbeddingInference``.
        dataset_root: Root of the full-image dataset.
        dataset_face_root: Root of the face-crop dataset.
        dataset_eyes_root: Root of the eye-crop dataset.
        output_path: Where the PNG is written.
        max_artists: Maximum number of artists to plot.
        device: Requested torch device. Falls back to CPU when CUDA is unavailable.
    """
    if not PLOT_AVAILABLE:
        print("matplotlib not available")
        return

    inference = ArtistEmbeddingInference(checkpoint_path, device)
    model = inference.model
    # The inference wrapper casts the model to float16 on CUDA, while DataLoader
    # batches arrive as float32, so image tensors are cast to the weights' dtype.
    model_dtype = next(
        (p.dtype for p in model.parameters() if p.is_floating_point()),
        torch.float32,
    )
    config = get_config()

    artist_to_idx, full_splits, face_splits, eye_splits = build_dataset_splits(
        dataset_root, dataset_face_root, dataset_eyes_root,
        min_images=config.data.min_images_per_artist
    )

    selected = list(artist_to_idx.keys())[:max_artists]
    selected_set = set(selected)
    filtered_full = {a: p[:10] for a, p in full_splits['test'].items() if a in selected_set}
    filtered_face = {a: face_splits['test'].get(a, []) for a in selected}
    filtered_eye = {a: eye_splits['test'].get(a, []) for a in selected}
    # Local 0..N-1 labels, used only to colour the points.
    filtered_idx = {a: i for i, a in enumerate(selected)}

    dataset = ArtistDataset(
        dataset_root, dataset_face_root, dataset_eyes_root,
        filtered_idx, filtered_full, filtered_face, filtered_eye,
        is_training=False
    )
    loader = DataLoader(dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

    all_embeddings, all_labels = [], []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Extracting"):
            full = batch['full'].to(inference.device, dtype=model_dtype)
            face = batch['face'].to(inference.device, dtype=model_dtype)
            eye = batch['eye'].to(inference.device, dtype=model_dtype)
            has_face = batch['has_face'].to(inference.device)
            has_eye = batch['has_eye'].to(inference.device)

            embeddings = model.get_embeddings(full, face, eye, has_face, has_eye)
            # Keep float32 copies on CPU for t-SNE, whatever the model dtype.
            all_embeddings.append(embeddings.float().cpu())
            all_labels.extend(batch['label'].tolist())

    num_samples = len(all_labels)
    if num_samples < 2:
        # torch.cat([]) and t-SNE both fail on (near-)empty input.
        print(f"Need at least 2 test samples for t-SNE, got {num_samples}")
        return

    embeddings = torch.cat(all_embeddings, dim=0).numpy()

    print("Running t-SNE...")
    # scikit-learn rejects perplexity >= n_samples, so cap it for small splits.
    tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, num_samples - 1))
    embeddings_2d = tsne.fit_transform(embeddings)

    plt.figure(figsize=(14, 10))
    # tab20 has only 20 distinct colours. Beyond that, sample a continuous map
    # so different artists do not end up sharing a colour.
    num_colors = len(selected)
    cmap = plt.cm.tab20 if num_colors <= 20 else plt.cm.turbo
    colors = cmap(np.linspace(0, 1, num_colors))

    labels = np.asarray(all_labels)
    for label in sorted(set(all_labels)):
        mask = labels == label
        plt.scatter(embeddings_2d[mask, 0], embeddings_2d[mask, 1], c=[colors[label]], alpha=0.7, s=50)

    plt.title('Artist Style Embeddings (t-SNE)')
    plt.tight_layout()
    plt.savefig(output_path, dpi=150)
    plt.close()
    print(f"Saved to {output_path}")
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point: ``evaluate``, ``visualize`` or ``predict`` mode."""
    parser = argparse.ArgumentParser(description="Artist style embedding evaluation and inference")
    parser.add_argument('--checkpoint', type=str, required=True, help="Path to the model checkpoint")
    parser.add_argument('--dataset_root', type=str, default='./dataset')
    parser.add_argument('--dataset_face_root', type=str, default='./dataset_face')
    parser.add_argument('--dataset_eyes_root', type=str, default='./dataset_eyes')
    parser.add_argument('--mode', type=str, default='evaluate', choices=['evaluate', 'visualize', 'predict'])
    parser.add_argument('--image', type=str, default=None, help="Image path (predict mode)")
    parser.add_argument('--output', type=str, default='tsne.png', help="Plot path (visualize mode)")
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--max_artists', type=int, default=50,
                        help="Number of artists to plot (visualize mode)")
    parser.add_argument('--top_k', type=int, default=10,
                        help="Number of predictions to show (predict mode)")

    args = parser.parse_args()

    if args.mode == 'evaluate':
        evaluate_model(args.checkpoint, args.dataset_root, args.dataset_face_root, args.dataset_eyes_root, args.device)
    elif args.mode == 'visualize':
        visualize_embeddings(
            args.checkpoint, args.dataset_root, args.dataset_face_root, args.dataset_eyes_root,
            args.output, max_artists=args.max_artists, device=args.device,
        )
    elif args.mode == 'predict':
        if not args.image:
            # Prints usage and exits with status 2 instead of "succeeding" silently.
            parser.error("--image required for predict mode")
        inference = ArtistEmbeddingInference(args.checkpoint, args.device)
        # Context manager closes the underlying file; convert() yields an independent copy.
        with Image.open(args.image) as img:
            image = img.convert('RGB')
        predictions = inference.predict_artist(image, top_k=args.top_k)
        print(f"\nTop {len(predictions)} Predictions:")
        for rank, (artist, prob) in enumerate(predictions, 1):
            print(f"{rank}. {artist}: {prob:.4f}")


if __name__ == '__main__':
    main()
|
|
|