# Three-View-Style-Embedder-Combined / extract_embeddings.py
# Author: iljung1106 — Initial commit (546ff88)
"""
Three-View-Style-Embedder - Fast Embedding Extraction
DataLoader ๊ธฐ๋ฐ˜ ๊ณ ์† ๋ฐฐ์น˜ ์ฒ˜๋ฆฌ
"""
import argparse
import random
import itertools
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import json
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
from tqdm import tqdm
from collections import defaultdict
from config import get_config
from model import ArtistStyleModel
class ArtistCombinationDataset(Dataset):
    """Dataset that pre-builds (full, face, eye) image combinations per artist.

    For every artist directory under ``dataset_root``, up to
    ``max_combinations`` random (full, face, eye) path triples are drawn from
    the cross product of the artist's full images, face crops, and eye crops.
    A missing face/eye folder is represented by a ``None`` placeholder so an
    artist with at least one full image still yields samples.
    """

    def __init__(
        self,
        dataset_root: str,
        dataset_face_root: str,
        dataset_eyes_root: str,
        max_combinations: int = 30,
    ):
        self.dataset_root = Path(dataset_root)
        self.dataset_face_root = Path(dataset_face_root)
        self.dataset_eyes_root = Path(dataset_eyes_root)
        self.max_combinations = max_combinations
        # Standard ImageNet preprocessing: 224x224, ImageNet mean/std.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # Pre-built samples: (artist_name, full_path, face_path, eye_path)
        self.samples = []
        self.artist_to_indices = defaultdict(list)  # artist -> [sample indices]
        self._build_samples()

    def _get_image_paths(self, folder: Path) -> List[Path]:
        """Return all .jpg/.png/.webp files in *folder* (empty list if missing)."""
        if not folder.exists():
            return []
        return list(folder.glob("*.jpg")) + list(folder.glob("*.png")) + list(folder.glob("*.webp"))

    def _build_samples(self):
        """Enumerate and randomly subsample all (full, face, eye) combinations per artist."""
        print("Building sample combinations...")
        artist_dirs = [d for d in self.dataset_root.iterdir() if d.is_dir()]
        for artist_dir in tqdm(artist_dirs, desc="Preparing artists"):
            artist_name = artist_dir.name
            full_paths = self._get_image_paths(artist_dir)
            if not full_paths:
                # An artist without any full image contributes no samples.
                continue
            face_paths = self._get_image_paths(self.dataset_face_root / artist_name)
            eye_paths = self._get_image_paths(self.dataset_eyes_root / artist_name)
            # Use a None placeholder when a view is missing so the cross
            # product still produces samples for this artist.
            face_options = face_paths if face_paths else [None]
            eye_options = eye_paths if eye_paths else [None]
            all_combinations = list(itertools.product(full_paths, face_options, eye_options))
            random.shuffle(all_combinations)
            selected = all_combinations[:self.max_combinations]
            # Register samples and remember their indices per artist for
            # per-artist averaging later.
            for full_path, face_path, eye_path in selected:
                idx = len(self.samples)
                self.samples.append((artist_name, full_path, face_path, eye_path))
                self.artist_to_indices[artist_name].append(idx)
        print(f"Total samples: {len(self.samples)} from {len(self.artist_to_indices)} artists")

    def _load_image(self, path: Optional[Path]) -> Optional[torch.Tensor]:
        """Load *path* as a normalized RGB tensor; return None on failure or no path.

        Images with transparency (RGBA/LA, or palette images that may carry
        alpha) are composited onto a white background before conversion so the
        alpha channel does not leak black pixels into the model input.
        """
        if path is None:
            return None
        try:
            img = Image.open(path)
            if img.mode in ('RGBA', 'LA', 'P'):
                background = Image.new('RGB', img.size, (255, 255, 255))
                if img.mode == 'P':
                    img = img.convert('RGBA')
                if img.mode in ('RGBA', 'LA'):
                    # Last band is the alpha channel; use it as the paste mask.
                    background.paste(img, mask=img.split()[-1])
                    img = background
                else:
                    img = img.convert('RGB')
            else:
                img = img.convert('RGB')
            return self.transform(img)
        except Exception:
            # Best-effort: treat a corrupt/unreadable file as missing.
            # (Was a bare `except:`, which also swallowed KeyboardInterrupt
            # and SystemExit; narrowed to Exception.)
            return None

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return one sample dict; an unreadable full image yields zeros with valid=False."""
        artist_name, full_path, face_path, eye_path = self.samples[idx]
        full_tensor = self._load_image(full_path)
        if full_tensor is None:
            # Keep batch shape intact but flag the sample so it can be
            # excluded when averaging embeddings downstream.
            full_tensor = torch.zeros(3, 224, 224)
            valid = False
        else:
            valid = True
        face_tensor = self._load_image(face_path)
        has_face = face_tensor is not None
        if not has_face:
            face_tensor = torch.zeros(3, 224, 224)
        eye_tensor = self._load_image(eye_path)
        has_eye = eye_tensor is not None
        if not has_eye:
            eye_tensor = torch.zeros(3, 224, 224)
        return {
            'full': full_tensor,
            'face': face_tensor,
            'eye': eye_tensor,
            'has_face': has_face,
            'has_eye': has_eye,
            'artist': artist_name,
            'valid': valid,
            'idx': idx,
        }
def collate_fn(batch):
    """Collate a list of sample dicts into a single batched dict.

    Image tensors are stacked along a new batch dimension; scalar flags and
    indices become 1-D tensors; artist names stay a plain Python list.
    """
    collated = {}
    for key in ('full', 'face', 'eye'):
        collated[key] = torch.stack([sample[key] for sample in batch])
    for key in ('has_face', 'has_eye', 'valid', 'idx'):
        collated[key] = torch.tensor([sample[key] for sample in batch])
    collated['artist'] = [sample['artist'] for sample in batch]
    return collated
@torch.no_grad()
def extract_all_embeddings(
    checkpoint_path: str,
    dataset_root: str,
    dataset_face_root: str,
    dataset_eyes_root: str,
    output_path: str,
    max_combinations: int = 30,
    batch_size: int = 64,
    num_workers: int = 8,
    device: str = 'cuda',
):
    """Extract per-artist mean style embeddings and save them to disk.

    Runs the model over all pre-built (full, face, eye) combinations with a
    DataLoader, averages the valid embeddings per artist, L2-normalizes each
    mean, then writes an ``.npz`` archive (artist_names + embeddings) and a
    ``.json`` metadata sidecar with the same stem.

    Raises:
        RuntimeError: if no artist produced a single valid embedding.
    """
    requested_device = device
    if requested_device.startswith('cuda') and not torch.cuda.is_available():
        print(
            "[WARN] --device=cuda requested but torch.cuda.is_available() is False. "
            "Falling back to CPU. (Install a CUDA-enabled PyTorch build to use GPU.)"
        )
        requested_device = 'cpu'
    device = torch.device(requested_device)

    # Load model.
    print("Loading model...")
    # Always load checkpoint on CPU to avoid duplicating large tensors on GPU.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    config = get_config()
    model = ArtistStyleModel(
        num_classes=len(checkpoint['artist_to_idx']),
        embedding_dim=config.model.embedding_dim,
        hidden_dim=config.model.hidden_dim,
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    # Reduce VRAM: keep weights in FP16 on CUDA.
    if device.type == 'cuda':
        model = model.to(dtype=torch.float16)
    model = model.to(device)
    model.eval()

    # Build dataset and loader.
    dataset = ArtistCombinationDataset(
        dataset_root=dataset_root,
        dataset_face_root=dataset_face_root,
        dataset_eyes_root=dataset_eyes_root,
        max_combinations=max_combinations,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True,
    )

    # Preallocated output buffers, indexed by dataset sample index.
    all_embeddings = torch.zeros(len(dataset), config.model.embedding_dim)
    all_valid = torch.zeros(len(dataset), dtype=torch.bool)

    # Batched inference under autocast (mixed precision on CUDA only).
    print("Extracting embeddings with AMP...")
    use_amp = device.type == 'cuda'
    for batch in tqdm(dataloader, desc="Processing"):
        # non_blocking=True pairs with pin_memory=True for async H2D copies.
        full = batch['full'].to(device, non_blocking=True)
        face = batch['face'].to(device, non_blocking=True)
        eye = batch['eye'].to(device, non_blocking=True)
        has_face = batch['has_face'].to(device, non_blocking=True)
        has_eye = batch['has_eye'].to(device, non_blocking=True)
        indices = batch['idx']
        valid = batch['valid']
        # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
        with torch.autocast(device_type=device.type, enabled=use_amp):
            embeddings = model.get_embeddings(full, face, eye, has_face, has_eye)
        all_embeddings[indices] = embeddings.float().cpu()
        all_valid[indices] = valid

    # Per-artist averaging over valid samples only.
    print("Computing artist averages...")
    artist_embeddings = {}
    failed_artists = []
    for artist_name, indices in tqdm(dataset.artist_to_indices.items(), desc="Averaging"):
        indices = torch.tensor(indices)
        valid_mask = all_valid[indices]
        if valid_mask.sum() == 0:
            # Every combination for this artist failed to load.
            failed_artists.append(artist_name)
            continue
        valid_embeddings = all_embeddings[indices][valid_mask]
        mean_emb = valid_embeddings.mean(dim=0)
        mean_emb = F.normalize(mean_emb, p=2, dim=0)
        artist_embeddings[artist_name] = mean_emb.numpy()
    print(f"Success: {len(artist_embeddings)}, Failed: {len(failed_artists)}")
    if not artist_embeddings:
        # Fail loudly instead of letting np.stack raise an opaque ValueError.
        raise RuntimeError("No valid embeddings were extracted; nothing to save.")

    # Save the embeddings archive. np.savez_compressed appends '.npz' to the
    # filename if the suffix is missing.
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    artist_names = list(artist_embeddings.keys())
    embeddings_array = np.stack([artist_embeddings[name] for name in artist_names])
    np.savez_compressed(
        output_path,
        artist_names=np.array(artist_names),
        embeddings=embeddings_array,
    )
    print(f"Saved: {output_path}")

    # Metadata sidecar. Artist names may be non-ASCII, so write UTF-8 and
    # keep characters readable instead of \u-escaping them.
    meta_path = output_path.with_suffix('.json')
    with open(meta_path, 'w', encoding='utf-8') as f:
        json.dump({
            'num_artists': len(artist_embeddings),
            'embedding_dim': config.model.embedding_dim,
            'max_combinations': max_combinations,
            'failed_artists': failed_artists,
        }, f, indent=2, ensure_ascii=False)
    print(f"Saved: {meta_path}")
def load_embeddings(npz_path: str) -> Tuple[List[str], np.ndarray]:
    """Load artist names and their embedding matrix from a saved .npz archive."""
    archive = np.load(npz_path)
    names = archive['artist_names'].tolist()
    matrix = archive['embeddings']
    return names, matrix
def find_similar_artists(
    query_embedding: np.ndarray,
    artist_names: List[str],
    embeddings: np.ndarray,
    top_k: int = 10,
) -> List[Tuple[str, float]]:
    """Rank stored artists by cosine similarity to the query embedding.

    Returns the ``top_k`` most similar artists as (name, score) pairs,
    highest score first.
    """
    # Normalize both sides so the dot product equals cosine similarity.
    unit_query = query_embedding / np.linalg.norm(query_embedding)
    row_norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit_rows = embeddings / row_norms
    scores = unit_rows @ unit_query
    ranked = np.argsort(scores)[::-1][:top_k]
    results = []
    for i in ranked:
        results.append((artist_names[i], float(scores[i])))
    return results
def main():
    """CLI entry point: parse arguments, seed all RNGs, run extraction."""
    parser = argparse.ArgumentParser(description='Three-View-Style-Embedder - Extract Embeddings')
    parser.add_argument('--checkpoint', type=str, required=True)
    # Remaining options share the same add_argument shape; declare them as a
    # (flag, type, default) table to keep the spec compact.
    for flag, arg_type, default in (
        ('--dataset_root', str, './dataset'),
        ('--dataset_face_root', str, './dataset_face'),
        ('--dataset_eyes_root', str, './dataset_eyes'),
        ('--output', str, './embeddings/artist_embeddings.npz'),
        ('--max_combinations', int, 30),
        ('--batch_size', int, 64),
        ('--num_workers', int, 8),
        ('--device', str, 'cuda'),
        ('--seed', int, 42),
    ):
        parser.add_argument(flag, type=arg_type, default=default)
    args = parser.parse_args()
    # Seed every RNG that influences combination sampling for reproducibility.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(args.seed)
    extract_all_embeddings(
        checkpoint_path=args.checkpoint,
        dataset_root=args.dataset_root,
        dataset_face_root=args.dataset_face_root,
        dataset_eyes_root=args.dataset_eyes_root,
        output_path=args.output,
        max_combinations=args.max_combinations,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        device=args.device,
    )


if __name__ == '__main__':
    main()