|
|
"""
Three-View-Style-Embedder - Fast Embedding Extraction

High-speed batched embedding extraction built on a DataLoader pipeline.
"""
|
|
import argparse |
|
|
import random |
|
|
import itertools |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Optional, Tuple |
|
|
import json |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
from tqdm import tqdm |
|
|
from collections import defaultdict |
|
|
|
|
|
from config import get_config |
|
|
from model import ArtistStyleModel |
|
|
|
|
|
|
|
|
class ArtistCombinationDataset(Dataset):
    """Dataset that pre-generates (full, face, eye) image combinations per artist.

    For every artist directory under ``dataset_root``, up to ``max_combinations``
    randomly sampled (full image, face crop, eye crop) triples are built ahead
    of time, so a DataLoader can iterate them with plain integer indexing.
    """

    def __init__(
        self,
        dataset_root: str,
        dataset_face_root: str,
        dataset_eyes_root: str,
        max_combinations: int = 30,
    ):
        """
        Args:
            dataset_root: Directory containing one sub-folder of full images per artist.
            dataset_face_root: Directory of per-artist face-crop folders (may be sparse).
            dataset_eyes_root: Directory of per-artist eye-crop folders (may be sparse).
            max_combinations: Maximum number of sampled triples kept per artist.
        """
        self.dataset_root = Path(dataset_root)
        self.dataset_face_root = Path(dataset_face_root)
        self.dataset_eyes_root = Path(dataset_eyes_root)
        self.max_combinations = max_combinations

        # Standard ImageNet preprocessing; all three views share one transform.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # samples: flat list of (artist, full_path, face_path, eye_path) tuples.
        # artist_to_indices: artist name -> indices into `samples`, used later
        # for per-artist averaging.
        self.samples = []
        self.artist_to_indices = defaultdict(list)

        self._build_samples()

    def _get_image_paths(self, folder: Path) -> List[Path]:
        """Return all jpg/png/webp files directly inside *folder* ([] if missing)."""
        if not folder.exists():
            return []
        return list(folder.glob("*.jpg")) + list(folder.glob("*.png")) + list(folder.glob("*.webp"))

    def _build_samples(self):
        """Pre-generate all sampled combinations for every artist."""
        print("Building sample combinations...")

        artist_dirs = [d for d in self.dataset_root.iterdir() if d.is_dir()]

        for artist_dir in tqdm(artist_dirs, desc="Preparing artists"):
            artist_name = artist_dir.name

            full_paths = self._get_image_paths(artist_dir)
            if not full_paths:
                continue

            face_paths = self._get_image_paths(self.dataset_face_root / artist_name)
            eye_paths = self._get_image_paths(self.dataset_eyes_root / artist_name)

            # A missing face/eye crop is represented by None so the cartesian
            # product still yields combinations for artists without crops.
            face_options = face_paths if face_paths else [None]
            eye_options = eye_paths if eye_paths else [None]

            all_combinations = list(itertools.product(full_paths, face_options, eye_options))
            random.shuffle(all_combinations)
            selected = all_combinations[:self.max_combinations]

            for full_path, face_path, eye_path in selected:
                idx = len(self.samples)
                self.samples.append((artist_name, full_path, face_path, eye_path))
                self.artist_to_indices[artist_name].append(idx)

        print(f"Total samples: {len(self.samples)} from {len(self.artist_to_indices)} artists")

    def _load_image(self, path: Optional[Path]) -> Optional[torch.Tensor]:
        """Load *path* as a normalized RGB tensor; return None for None/unreadable paths.

        Images with transparency (RGBA/LA/P modes) are composited onto a white
        background before conversion so transparent regions do not become black.
        """
        if path is None:
            return None
        try:
            img = Image.open(path)
            if img.mode in ('RGBA', 'LA', 'P'):
                background = Image.new('RGB', img.size, (255, 255, 255))
                if img.mode == 'P':
                    img = img.convert('RGBA')
                if img.mode in ('RGBA', 'LA'):
                    background.paste(img, mask=img.split()[-1])
                    img = background
                else:
                    img = img.convert('RGB')
            else:
                img = img.convert('RGB')
            return self.transform(img)
        except Exception:
            # Narrowed from a bare `except:` so Ctrl-C / SystemExit still
            # propagate; corrupt files are tolerated and the caller
            # substitutes a zero tensor.
            return None

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return one combination dict; a failed full-image load is flagged invalid."""
        artist_name, full_path, face_path, eye_path = self.samples[idx]

        full_tensor = self._load_image(full_path)
        if full_tensor is None:
            # Keep batch shapes stable: substitute zeros and mark invalid so
            # the extraction loop can exclude this sample from averaging.
            full_tensor = torch.zeros(3, 224, 224)
            valid = False
        else:
            valid = True

        face_tensor = self._load_image(face_path)
        has_face = face_tensor is not None
        if not has_face:
            face_tensor = torch.zeros(3, 224, 224)

        eye_tensor = self._load_image(eye_path)
        has_eye = eye_tensor is not None
        if not has_eye:
            eye_tensor = torch.zeros(3, 224, 224)

        return {
            'full': full_tensor,
            'face': face_tensor,
            'eye': eye_tensor,
            'has_face': has_face,
            'has_eye': has_eye,
            'artist': artist_name,
            'valid': valid,
            'idx': idx,
        }
|
|
|
|
|
|
|
|
def collate_fn(batch):
    """Assemble a list of dataset sample dicts into a single batched dict.

    Image tensors are stacked along a new batch dimension, boolean/int fields
    become 1-D tensors, and artist names stay a plain Python list.
    """
    tensor_keys = ('full', 'face', 'eye')
    collated = {key: torch.stack([sample[key] for sample in batch]) for key in tensor_keys}
    for flag_key in ('has_face', 'has_eye', 'valid', 'idx'):
        collated[flag_key] = torch.tensor([sample[flag_key] for sample in batch])
    collated['artist'] = [sample['artist'] for sample in batch]
    return collated
|
|
|
|
|
|
|
|
@torch.no_grad()
def extract_all_embeddings(
    checkpoint_path: str,
    dataset_root: str,
    dataset_face_root: str,
    dataset_eyes_root: str,
    output_path: str,
    max_combinations: int = 30,
    batch_size: int = 64,
    num_workers: int = 8,
    device: str = 'cuda',
):
    """Fast embedding extraction.

    Loads a trained ArtistStyleModel checkpoint, runs every pre-sampled
    (full, face, eye) combination through the model in batches, averages the
    valid embeddings per artist, L2-normalizes the means, and writes them to
    an .npz file plus a .json metadata sidecar.

    Args:
        checkpoint_path: Checkpoint containing 'model_state_dict' and 'artist_to_idx'.
        dataset_root: Root of per-artist full-image folders.
        dataset_face_root: Root of per-artist face-crop folders.
        dataset_eyes_root: Root of per-artist eye-crop folders.
        output_path: Destination .npz path (parent directories are created).
        max_combinations: Max sampled combinations per artist.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker process count.
        device: 'cuda' (falls back to CPU when unavailable) or 'cpu'.
    """

    requested_device = device
    if requested_device.startswith('cuda') and not torch.cuda.is_available():
        print(
            "[WARN] --device=cuda requested but torch.cuda.is_available() is False. "
            "Falling back to CPU. (Install a CUDA-enabled PyTorch build to use GPU.)"
        )
        requested_device = 'cpu'
    device = torch.device(requested_device)

    print("Loading model...")

    # NOTE(security): torch.load unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    config = get_config()

    model = ArtistStyleModel(
        num_classes=len(checkpoint['artist_to_idx']),
        embedding_dim=config.model.embedding_dim,
        hidden_dim=config.model.hidden_dim,
    )
    model.load_state_dict(checkpoint['model_state_dict'])

    # Half precision on GPU halves memory traffic; embeddings are cast back
    # to float32 before being accumulated below.
    if device.type == 'cuda':
        model = model.to(dtype=torch.float16)
    model = model.to(device)
    model.eval()

    dataset = ArtistCombinationDataset(
        dataset_root=dataset_root,
        dataset_face_root=dataset_face_root,
        dataset_eyes_root=dataset_eyes_root,
        max_combinations=max_combinations,
    )

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn,
        # pin_memory only speeds up host->GPU transfers; skip it on CPU to
        # avoid the pointless pinned-allocation warning.
        pin_memory=(device.type == 'cuda'),
    )

    # Preallocated result buffers, indexed by the dataset sample index so
    # DataLoader batch order does not matter.
    all_embeddings = torch.zeros(len(dataset), config.model.embedding_dim)
    all_valid = torch.zeros(len(dataset), dtype=torch.bool)

    print("Extracting embeddings with AMP...")
    for batch in tqdm(dataloader, desc="Processing"):
        full = batch['full'].to(device)
        face = batch['face'].to(device)
        eye = batch['eye'].to(device)
        has_face = batch['has_face'].to(device)
        has_eye = batch['has_eye'].to(device)
        indices = batch['idx']
        valid = batch['valid']

        # torch.amp.autocast replaces the deprecated torch.cuda.amp.autocast;
        # autocast is disabled entirely on CPU.
        with torch.amp.autocast(device_type=device.type, enabled=(device.type == 'cuda')):
            embeddings = model.get_embeddings(full, face, eye, has_face, has_eye)

        all_embeddings[indices] = embeddings.float().cpu()
        all_valid[indices] = valid

    print("Computing artist averages...")
    artist_embeddings = {}
    failed_artists = []

    for artist_name, indices in tqdm(dataset.artist_to_indices.items(), desc="Averaging"):
        indices = torch.tensor(indices)
        valid_mask = all_valid[indices]

        # Artists whose every full image failed to load produce no embedding.
        if valid_mask.sum() == 0:
            failed_artists.append(artist_name)
            continue

        valid_embeddings = all_embeddings[indices][valid_mask]
        mean_emb = valid_embeddings.mean(dim=0)
        mean_emb = F.normalize(mean_emb, p=2, dim=0)
        artist_embeddings[artist_name] = mean_emb.numpy()

    print(f"Success: {len(artist_embeddings)}, Failed: {len(failed_artists)}")

    # Guard: np.stack on an empty list raises an opaque ValueError.
    if not artist_embeddings:
        print("[WARN] No artist embeddings were produced; nothing to save.")
        return

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    artist_names = list(artist_embeddings.keys())
    embeddings_array = np.stack([artist_embeddings[name] for name in artist_names])

    np.savez_compressed(
        output_path,
        artist_names=np.array(artist_names),
        embeddings=embeddings_array,
    )
    print(f"Saved: {output_path}")

    # Metadata sidecar (same stem, .json suffix) for quick inspection.
    meta_path = output_path.with_suffix('.json')
    with open(meta_path, 'w') as f:
        json.dump({
            'num_artists': len(artist_embeddings),
            'embedding_dim': config.model.embedding_dim,
            'max_combinations': max_combinations,
            'failed_artists': failed_artists,
        }, f, indent=2)
    print(f"Saved: {meta_path}")
|
|
|
|
|
|
|
|
def load_embeddings(npz_path: str) -> Tuple[List[str], np.ndarray]:
    """Load previously saved artist names and their embedding matrix from an .npz file."""
    archive = np.load(npz_path)
    names = archive['artist_names'].tolist()
    return names, archive['embeddings']
|
|
|
|
|
|
|
|
def find_similar_artists(
    query_embedding: np.ndarray,
    artist_names: List[str],
    embeddings: np.ndarray,
    top_k: int = 10,
) -> List[Tuple[str, float]]:
    """Rank stored artists by cosine similarity to *query_embedding*.

    Returns the ``top_k`` most similar (artist_name, similarity) pairs,
    ordered from most to least similar.
    """
    unit_query = query_embedding / np.linalg.norm(query_embedding)
    row_norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    sims = (embeddings / row_norms) @ unit_query

    ranked = sims.argsort()[::-1][:top_k]
    results = []
    for i in ranked:
        results.append((artist_names[i], float(sims[i])))
    return results
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: parse arguments, seed all RNGs, and run the extraction."""
    parser = argparse.ArgumentParser(description='Three-View-Style-Embedder - Extract Embeddings')
    parser.add_argument('--checkpoint', type=str, required=True)

    # (flag, type, default) — declared in display order for --help.
    optional_specs = (
        ('--dataset_root', str, './dataset'),
        ('--dataset_face_root', str, './dataset_face'),
        ('--dataset_eyes_root', str, './dataset_eyes'),
        ('--output', str, './embeddings/artist_embeddings.npz'),
        ('--max_combinations', int, 30),
        ('--batch_size', int, 64),
        ('--num_workers', int, 8),
        ('--device', str, 'cuda'),
        ('--seed', int, 42),
    )
    for flag, value_type, default in optional_specs:
        parser.add_argument(flag, type=value_type, default=default)

    args = parser.parse_args()

    # Seed every RNG source so combination sampling is reproducible.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    extract_all_embeddings(
        checkpoint_path=args.checkpoint,
        dataset_root=args.dataset_root,
        dataset_face_root=args.dataset_face_root,
        dataset_eyes_root=args.dataset_eyes_root,
        output_path=args.output,
        max_combinations=args.max_combinations,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        device=args.device,
    )
|
|
|
|
|
|
|
|
# Run the CLI only when executed as a script (not when imported as a module).
if __name__ == '__main__':
    main()
|
|
|