|
|
""" |
|
|
Artist Style Embedding - Dataset & DataLoader |
|
|
""" |
|
|
import os |
|
|
import random |
|
|
import warnings |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Tuple |
|
|
from collections import defaultdict |
|
|
|
|
|
import torch |
|
|
from torch.utils.data import Dataset, DataLoader, Sampler |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
warnings.filterwarnings('ignore', category=UserWarning, module='PIL') |
|
|
|
|
|
|
|
|
class ArtistDataset(Dataset):
    """Multi-branch artist dataset.

    Every item pairs one full image with one face crop and one eye crop drawn
    at random from the same artist's pools.  The crops are *not* tied to the
    particular full image; they only share the artist.  When an artist has no
    usable crop for a branch, an all-zero tensor is returned for it and the
    matching ``has_face`` / ``has_eye`` flag is ``False``.
    """

    def __init__(
        self,
        dataset_root: str,
        dataset_face_root: str,
        dataset_eyes_root: str,
        artist_to_idx: Dict[str, int],
        image_paths: Dict[str, List[str]],
        face_paths: Dict[str, List[str]],
        eye_paths: Dict[str, List[str]],
        image_size: int = 224,
        is_training: bool = True,
    ):
        """Build the sample index and the per-artist crop pools.

        Args:
            dataset_root: Directory holding one sub-directory per artist.
            dataset_face_root: Root of the face crops (stored, not read here).
            dataset_eyes_root: Root of the eye crops (stored, not read here).
            artist_to_idx: Artist name -> integer class label.
            image_paths: Artist name -> full-image paths for this split.
            face_paths: Artist name -> face-crop paths for this split.
            eye_paths: Artist name -> eye-crop paths for this split.
            image_size: Side length of the square tensors that are produced.
            is_training: Selects the augmenting transform when ``True``,
                the deterministic evaluation transform otherwise.
        """
        self.dataset_root = Path(dataset_root)
        self.dataset_face_root = Path(dataset_face_root)
        self.dataset_eyes_root = Path(dataset_eyes_root)
        self.artist_to_idx = artist_to_idx
        self.image_size = image_size
        self.is_training = is_training

        # Only the file name is kept; __getitem__ rebuilds the location as
        # dataset_root / artist / file name.
        self.samples = [
            (artist, os.path.basename(img_path))
            for artist, paths in image_paths.items()
            for img_path in paths
        ]

        self.transform = self._get_transforms()
        self.transform_eval = self._get_eval_transforms()

        # Crop paths are used exactly as given (no re-rooting).
        self._face_cache = {
            artist: [Path(p) for p in paths] for artist, paths in face_paths.items()
        }
        self._eye_cache = {
            artist: [Path(p) for p in paths] for artist, paths in eye_paths.items()
        }

    def _get_transforms(self):
        """Return the training pipeline: resize, random crop and augmentations."""
        return transforms.Compose([
            transforms.Resize((self.image_size + 32, self.image_size + 32)),
            transforms.RandomCrop(self.image_size),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.02),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            transforms.RandomErasing(p=0.1, scale=(0.02, 0.1)),
        ])

    def _get_eval_transforms(self):
        """Return the deterministic pipeline: resize, to-tensor, normalize."""
        return transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def _load_image(self, path: Path):
        """Load *path* as an RGB PIL image, or return ``None`` if that fails.

        Images that may carry transparency (``RGBA``, ``LA``, ``P``) are
        composited onto a white background using their alpha channel.
        """
        try:
            # The context manager closes the underlying file; the converted
            # images returned below are fully loaded, independent copies.
            with Image.open(path) as img:
                if img.mode in ('RGBA', 'LA', 'P'):
                    # Normalise to RGBA so an alpha band always exists; this
                    # also makes 'LA' honour its alpha instead of dropping it.
                    rgba = img.convert('RGBA')
                    background = Image.new('RGB', rgba.size, (255, 255, 255))
                    background.paste(rgba, mask=rgba.split()[-1])
                    return background
                return img.convert('RGB')
        except Exception:
            # Deliberate best-effort: one unreadable file must not abort an
            # epoch.  Callers treat ``None`` as "skip / use placeholder".
            return None

    def _get_placeholder(self) -> torch.Tensor:
        """Return the all-zero tensor used when a branch has no image."""
        return torch.zeros(3, self.image_size, self.image_size)

    def _load_branch(self, candidates, transform):
        """Pick a random crop from *candidates* and transform it.

        Returns:
            ``(tensor, True)`` on success, ``(placeholder, False)`` when the
            pool is empty or the chosen file cannot be read.
        """
        if not candidates:
            return self._get_placeholder(), False
        crop = self._load_image(random.choice(candidates))
        if crop is None:
            return self._get_placeholder(), False
        return transform(crop), True

    def __len__(self):
        """Return the number of full-image samples."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict:
        """Return the sample at *idx* as a dict of tensors and metadata.

        If the full image cannot be read, the following samples are tried in
        order (wrapping around) and the first readable one is returned in its
        place, with *its* label and artist.

        Raises:
            IndexError: If *idx* is out of range.
            RuntimeError: If no full image in the dataset can be read.
        """
        transform = self.transform if self.is_training else self.transform_eval

        total = len(self.samples)
        current = idx
        full_img = None
        # Bounded, iterative fallback: at most one pass over the dataset
        # instead of unbounded recursion.
        for _ in range(max(total, 1)):
            artist, img_name = self.samples[current]
            full_img = self._load_image(self.dataset_root / artist / img_name)
            if full_img is not None:
                break
            current = (current + 1) % total
        if full_img is None:
            raise RuntimeError(
                f"No readable full image found in dataset ({total} samples tried)"
            )

        label = self.artist_to_idx[artist]
        full_tensor = transform(full_img)

        face_tensor, has_face = self._load_branch(self._face_cache.get(artist, []), transform)
        eye_tensor, has_eye = self._load_branch(self._eye_cache.get(artist, []), transform)

        return {
            'full': full_tensor,
            'face': face_tensor,
            'eye': eye_tensor,
            'has_face': has_face,
            'has_eye': has_eye,
            'label': label,
            'artist': artist,
        }
|
|
|
|
|
|
|
|
class PKSampler(Sampler):
    """P classes, K samples per class batch sampler for metric learning.

    Each iteration visits every class exactly once, in random order.  The
    classes are grouped ``p`` at a time and ``k`` sample indices are drawn
    per class, so a batch holds up to ``p * k`` indices.  The last batch is
    smaller when the number of classes is not a multiple of ``p``.
    """

    def __init__(self, dataset: "ArtistDataset", p: int = 32, k: int = 4):
        """Index the dataset's samples by class label.

        Args:
            dataset: Object exposing ``samples`` (a sequence of
                ``(artist, file_name)`` pairs) and ``artist_to_idx``.
            p: Number of classes per batch.
            k: Number of samples drawn per class.

        Raises:
            ValueError: If ``p`` or ``k`` is smaller than 1.
        """
        if p < 1 or k < 1:
            raise ValueError(f"p and k must be >= 1, got p={p}, k={k}")

        self.dataset = dataset
        self.p = p
        self.k = k

        self.class_to_indices = defaultdict(list)
        for idx, (artist, _) in enumerate(dataset.samples):
            label = dataset.artist_to_idx[artist]
            self.class_to_indices[label].append(idx)

        self.classes = list(self.class_to_indices.keys())

    def _draw(self, pool):
        """Take ``k`` indices from *pool*, reshuffling and wrapping when it runs out."""
        picked = []
        ptr = 0
        for _ in range(self.k):
            if ptr >= len(pool):
                # Fewer than k samples in this class: start over, so some
                # indices repeat within the batch.
                ptr = 0
                random.shuffle(pool)
            picked.append(pool[ptr])
            ptr += 1
        return picked

    def __iter__(self):
        """Yield lists of dataset indices, one list per batch, in random order."""
        # Shuffled private copies, so the stored index lists are not mutated.
        shuffled = {
            c: random.sample(indices, len(indices))
            for c, indices in self.class_to_indices.items()
        }

        class_order = self.classes.copy()
        random.shuffle(class_order)

        batches = []
        for start in range(0, len(class_order), self.p):
            batch = []
            for cls in class_order[start:start + self.p]:
                batch.extend(self._draw(shuffled[cls]))
            batches.append(batch)

        random.shuffle(batches)
        yield from batches

    def __len__(self):
        """Return the number of batches ``__iter__`` yields (partial batch included)."""
        # Ceiling division: the trailing partial batch is yielded too.
        return (len(self.classes) + self.p - 1) // self.p
|
|
|
|
|
|
|
|
def _list_image_files(directory: Path) -> List[str]:
    """Return the sorted ``*.jpg`` / ``*.png`` / ``*.webp`` paths directly inside *directory*."""
    if not directory.exists():
        return []
    found = []
    for pattern in ("*.jpg", "*.png", "*.webp"):
        found.extend(str(p) for p in directory.glob(pattern))
    # glob() yields entries in filesystem order; sorting makes the seeded
    # shuffle that follows reproducible across machines.
    found.sort()
    return found


def _split_three_way(paths: List[str], train_ratio: float, val_ratio: float):
    """Shuffle *paths* in place and cut it into ``(train, val, test)`` slices."""
    random.shuffle(paths)
    n = len(paths)
    # At least one item is requested for train and for val; slicing simply
    # yields empty lists when there are not enough items to go around.
    n_train = max(1, int(n * train_ratio))
    n_val = max(1, int(n * val_ratio))
    return paths[:n_train], paths[n_train:n_train + n_val], paths[n_train + n_val:]


def build_dataset_splits(
    dataset_root: str,
    dataset_face_root: str,
    dataset_eyes_root: str,
    min_images: int = 3,
    train_ratio: float = 0.8,
    val_ratio: float = 0.1,
    seed: int = 42,
) -> Tuple[Dict[str, int], Dict[str, Dict[str, List[str]]], Dict[str, Dict[str, List[str]]], Dict[str, Dict[str, List[str]]]]:
    """Scan the dataset directories and build per-artist train/val/test splits.

    Every sub-directory of *dataset_root* is treated as one artist.  Artists
    with fewer than *min_images* images are skipped.  Face and eye crops are
    looked up under the same artist name in their own roots and are split
    independently of the full images.

    Seeds the global ``random`` and ``numpy.random`` generators with *seed*.

    Args:
        dataset_root: Root directory of the full images.
        dataset_face_root: Root directory of the face crops.
        dataset_eyes_root: Root directory of the eye crops.
        min_images: Minimum number of full images an artist must have.
        train_ratio: Fraction of each artist's files assigned to train.
        val_ratio: Fraction assigned to val; the remainder goes to test.
        seed: Seed for the shuffles.

    Returns:
        artist_to_idx: artist name -> index mapping (names sorted).
        full_splits: {'train': {artist: [paths]}, 'val': {...}, 'test': {...}}
        face_splits: same structure
        eye_splits: same structure
    """
    random.seed(seed)
    np.random.seed(seed)

    dataset_path = Path(dataset_root)
    face_path = Path(dataset_face_root)
    eyes_path = Path(dataset_eyes_root)

    artist_images = {}
    artist_faces = {}
    artist_eyes = {}

    print("Scanning dataset...")
    # iterdir() order is filesystem-dependent; sort so results are stable.
    artists = sorted(
        (d for d in dataset_path.iterdir() if d.is_dir()),
        key=lambda d: d.name,
    )

    for artist_dir in tqdm(artists, desc="Loading artists"):
        artist_name = artist_dir.name

        images = _list_image_files(artist_dir)
        if len(images) < min_images:
            continue

        artist_images[artist_name] = images
        # Crops are only needed for artists that made the cut.
        artist_faces[artist_name] = _list_image_files(face_path / artist_name)
        artist_eyes[artist_name] = _list_image_files(eyes_path / artist_name)

    print(f"Found {len(artist_images)} artists with >= {min_images} images")

    artists_sorted = sorted(artist_images.keys())
    artist_to_idx = {name: idx for idx, name in enumerate(artists_sorted)}

    full_splits = {'train': {}, 'val': {}, 'test': {}}
    face_splits = {'train': {}, 'val': {}, 'test': {}}
    eye_splits = {'train': {}, 'val': {}, 'test': {}}

    # Fixed artist order + sorted file lists => the seeded shuffles always
    # consume the random stream in the same way.
    for artist in artists_sorted:
        for pool, splits in (
            (artist_images[artist], full_splits),
            (artist_faces[artist], face_splits),
            (artist_eyes[artist], eye_splits),
        ):
            train_part, val_part, test_part = _split_three_way(pool, train_ratio, val_ratio)
            splits['train'][artist] = train_part
            splits['val'][artist] = val_part
            splits['test'][artist] = test_part

    for split_name in ['train', 'val', 'test']:
        total_full = sum(len(imgs) for imgs in full_splits[split_name].values())
        total_face = sum(len(imgs) for imgs in face_splits[split_name].values())
        total_eye = sum(len(imgs) for imgs in eye_splits[split_name].values())
        print(f"{split_name}: {total_full} full, {total_face} face, {total_eye} eye images")

    return artist_to_idx, full_splits, face_splits, eye_splits
|
|
|
|
|
|
|
|
def collate_fn(batch):
    """Merge a list of dataset items into one batch dictionary.

    Image tensors are stacked along a new leading dimension, the flags and
    labels become 1-D tensors, and the artist names stay a plain list.
    """
    collated = {}

    # Per-branch image tensors -> one stacked tensor each.
    for key in ('full', 'face', 'eye'):
        collated[key] = torch.stack([sample[key] for sample in batch])

    # Python scalars -> 1-D tensors.
    for key in ('has_face', 'has_eye', 'label'):
        collated[key] = torch.tensor([sample[key] for sample in batch])

    collated['artist'] = [sample['artist'] for sample in batch]
    return collated
|
|
|
|
|
|
|
|
def create_dataloaders(
    config,
    artist_to_idx: Dict[str, int],
    full_splits: Dict[str, Dict[str, List[str]]],
    face_splits: Dict[str, Dict[str, List[str]]],
    eye_splits: Dict[str, Dict[str, List[str]]],
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build the train, validation and test dataloaders.

    The train loader draws P x K batches through ``PKSampler``; the val and
    test loaders iterate sequentially with ``config.train.batch_size``.

    Args:
        config: Object read for ``data.dataset_root``, ``data.dataset_face_root``,
            ``data.dataset_eyes_root``, ``data.image_size``, ``data.num_workers``,
            ``data.pin_memory``, ``train.batch_size`` and
            ``train.samples_per_class``.
        artist_to_idx: Artist name -> class index.
        full_splits: Full-image paths, keyed by split name then artist.
        face_splits: Face-crop paths, same structure.
        eye_splits: Eye-crop paths, same structure.

    Returns:
        ``(train_loader, val_loader, test_loader)``.

    Raises:
        ValueError: If ``samples_per_class`` is smaller than 1 or larger than
            ``batch_size``, which would leave no room for any class in a batch.
    """
    batch_size = config.train.batch_size
    k = config.train.samples_per_class
    if k < 1 or batch_size < k:
        # Previously this produced p == 0, which only surfaced later as an
        # empty batch inside the training loop.
        raise ValueError(
            f"train.samples_per_class ({k}) must be between 1 and "
            f"train.batch_size ({batch_size})"
        )

    def make_dataset(split: str, is_training: bool):
        """Instantiate ArtistDataset for one split."""
        return ArtistDataset(
            dataset_root=config.data.dataset_root,
            dataset_face_root=config.data.dataset_face_root,
            dataset_eyes_root=config.data.dataset_eyes_root,
            artist_to_idx=artist_to_idx,
            image_paths=full_splits[split],
            face_paths=face_splits[split],
            eye_paths=eye_splits[split],
            image_size=config.data.image_size,
            is_training=is_training,
        )

    def make_eval_loader(dataset):
        """Sequential, unshuffled loader used for validation and test."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=config.data.num_workers,
            pin_memory=config.data.pin_memory,
            collate_fn=collate_fn,
        )

    train_dataset = make_dataset('train', is_training=True)

    # No more classes per batch than exist, but never fewer than one so the
    # sampler's arithmetic stays valid even when no artists were found.
    p = max(1, min(batch_size // k, len(artist_to_idx)))

    print(f"PKSampler: P={p} classes × K={k} samples = {p*k} batch size")

    train_sampler = PKSampler(
        train_dataset,
        p=p,
        k=k,
    )

    train_loader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        num_workers=config.data.num_workers,
        pin_memory=config.data.pin_memory,
        collate_fn=collate_fn,
    )

    val_loader = make_eval_loader(make_dataset('val', is_training=False))
    test_loader = make_eval_loader(make_dataset('test', is_training=False))

    return train_loader, val_loader, test_loader
|
|
|