"""
Artist Style Embedding - Dataset & DataLoader
"""

import os
import random
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from collections import defaultdict

import torch
from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms
from PIL import Image
import numpy as np
from tqdm import tqdm

# Suppress PIL warnings
warnings.filterwarnings('ignore', category=UserWarning, module='PIL')

# Glob patterns treated as images when scanning an artist directory.
_IMAGE_PATTERNS = ("*.jpg", "*.png", "*.webp")


class ArtistDataset(Dataset):
    """Multi-branch artist dataset.

    Each item pairs one full image of an artist with one randomly chosen
    face crop and one randomly chosen eye crop of the *same artist* (not
    necessarily taken from the same full image).

    Args:
        dataset_root: Directory holding one sub-directory per artist with
            the full images.
        dataset_face_root: Root directory of the face crops (stored only).
        dataset_eyes_root: Root directory of the eye crops (stored only).
        artist_to_idx: Artist name -> integer class label.
        image_paths: Full images of this split, ``{artist: [paths]}``. Only
            the file name of each path is kept; the image is re-resolved as
            ``dataset_root / artist / file_name``.
        face_paths: Face images of this split, ``{artist: [paths]}``.
        eye_paths: Eye images of this split, ``{artist: [paths]}``.
        image_size: Side length of the square tensors that are produced.
        is_training: Use the augmenting transform when True, the
            deterministic evaluation transform otherwise.
    """

    def __init__(
        self,
        dataset_root: str,
        dataset_face_root: str,
        dataset_eyes_root: str,
        artist_to_idx: Dict[str, int],
        image_paths: Dict[str, List[str]],  # full images of this split
        face_paths: Dict[str, List[str]],  # face images of this split
        eye_paths: Dict[str, List[str]],  # eye images of this split
        image_size: int = 224,
        is_training: bool = True,
    ):
        self.dataset_root = Path(dataset_root)
        self.dataset_face_root = Path(dataset_face_root)
        self.dataset_eyes_root = Path(dataset_eyes_root)
        self.artist_to_idx = artist_to_idx
        self.image_size = image_size
        self.is_training = is_training

        # Flat sample list of (artist, file name) pairs.
        self.samples = [
            (artist, os.path.basename(img_path))
            for artist, paths in image_paths.items()
            for img_path in paths
        ]

        self.transform = self._get_transforms()
        self.transform_eval = self._get_eval_transforms()

        # Face/eye paths per artist (already restricted to this split).
        self._face_cache = {
            artist: [Path(p) for p in paths]
            for artist, paths in face_paths.items()
        }
        self._eye_cache = {
            artist: [Path(p) for p in paths]
            for artist, paths in eye_paths.items()
        }

    def _get_transforms(self):
        """Build the training transform (resize, random crop and augmentations)."""
        return transforms.Compose([
            transforms.Resize((self.image_size + 32, self.image_size + 32)),
            transforms.RandomCrop(self.image_size),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.02),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            transforms.RandomErasing(p=0.1, scale=(0.02, 0.1)),
        ])

    def _get_eval_transforms(self):
        """Build the deterministic evaluation transform (resize + normalize)."""
        return transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def _load_image(self, path: Path) -> Optional[Image.Image]:
        """Load ``path`` as an RGB image, or return None if it cannot be read.

        Images that may carry transparency (RGBA, LA, palette) are
        composited onto a white background instead of having their alpha
        channel dropped.
        """
        try:
            # The context manager closes the underlying file handle; the
            # images returned below are fully loaded copies.
            with Image.open(path) as img:
                if img.mode in ('RGBA', 'LA', 'P'):
                    # Going through RGBA gives every such mode a real alpha
                    # band to use as the paste mask (LA included).
                    rgba = img.convert('RGBA')
                    background = Image.new('RGB', rgba.size, (255, 255, 255))
                    background.paste(rgba, mask=rgba.split()[-1])
                    return background
                return img.convert('RGB')
        except Exception:
            # Deliberate best effort: an unreadable file is reported as None
            # so that callers can fall back instead of aborting an epoch.
            return None

    def _get_placeholder(self) -> torch.Tensor:
        """Return the all-zero tensor used when a face/eye crop is missing."""
        return torch.zeros(3, self.image_size, self.image_size)

    def _load_branch(self, paths: List[Path], transform) -> Tuple[torch.Tensor, bool]:
        """Pick a random image from ``paths`` and transform it.

        Returns:
            ``(tensor, True)`` on success, ``(placeholder, False)`` when
            ``paths`` is empty or the chosen file cannot be read.
        """
        if not paths:
            return self._get_placeholder(), False
        img = self._load_image(random.choice(paths))
        if img is None:
            return self._get_placeholder(), False
        return transform(img), True

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict:
        """Return the sample at ``idx``.

        If the full image of that sample is unreadable, the following
        samples are tried in order (wrapping around); the returned label and
        artist then belong to the sample that was actually loaded.

        Raises:
            IndexError: If ``idx`` is out of range.
            RuntimeError: If no full image in the dataset can be loaded.
        """
        transform = self.transform if self.is_training else self.transform_eval
        num_samples = len(self.samples)

        # Bounded scan instead of unbounded recursion: every sample is tried
        # at most once before giving up.
        current = idx
        for _ in range(max(num_samples, 1)):
            artist, img_name = self.samples[current]
            full_img = self._load_image(self.dataset_root / artist / img_name)
            if full_img is not None:
                break
            current = (current + 1) % num_samples
        else:
            raise RuntimeError(
                f"No readable full image found in dataset at {self.dataset_root}"
            )

        label = self.artist_to_idx[artist]
        full_tensor = transform(full_img)

        face_tensor, has_face = self._load_branch(
            self._face_cache.get(artist, []), transform
        )
        eye_tensor, has_eye = self._load_branch(
            self._eye_cache.get(artist, []), transform
        )

        return {
            'full': full_tensor,
            'face': face_tensor,
            'eye': eye_tensor,
            'has_face': has_face,
            'has_eye': has_eye,
            'label': label,
            'artist': artist,
        }


class PKSampler(Sampler):
    """P classes, K samples per class sampler for metric learning.

    One pass visits every class exactly once, in random order. Each batch
    holds ``k`` sample indices for each of ``p`` classes; the last batch
    contains fewer classes when the class count is not a multiple of ``p``.
    A class with fewer than ``k`` samples contributes repeated indices.

    Intended to be passed to ``DataLoader`` as ``batch_sampler``.

    Raises:
        ValueError: If ``p`` or ``k`` is smaller than 1.
    """

    def __init__(self, dataset: ArtistDataset, p: int = 32, k: int = 4):
        if p < 1 or k < 1:
            raise ValueError(f"p and k must be >= 1 (got p={p}, k={k})")

        self.dataset = dataset
        self.p = p
        self.k = k

        # label -> indices into dataset.samples
        self.class_to_indices = defaultdict(list)
        for idx, (artist, _) in enumerate(dataset.samples):
            label = dataset.artist_to_idx[artist]
            self.class_to_indices[label].append(idx)

        self.classes = list(self.class_to_indices.keys())

    def __iter__(self):
        # Fresh per-pass shuffled copy of every class's indices.
        class_indices = {
            c: random.sample(indices, len(indices))
            for c, indices in self.class_to_indices.items()
        }

        class_order = self.classes.copy()
        random.shuffle(class_order)

        batches = []
        batch = []
        classes_in_batch = 0

        for cls in class_order:
            indices = class_indices[cls]
            ptr = 0
            for _ in range(self.k):
                if ptr >= len(indices):
                    # Class exhausted before K draws: reshuffle and reuse.
                    ptr = 0
                    random.shuffle(indices)
                batch.append(indices[ptr])
                ptr += 1

            classes_in_batch += 1
            if classes_in_batch == self.p:
                batches.append(batch)
                batch = []
                classes_in_batch = 0

        if batch:
            # Trailing partial batch (fewer than P classes).
            batches.append(batch)

        random.shuffle(batches)
        yield from batches

    def __len__(self):
        # Ceiling division: __iter__ also yields the trailing partial batch.
        return (len(self.classes) + self.p - 1) // self.p


def _list_images(directory: Path) -> List[str]:
    """Return the sorted paths of the image files directly inside ``directory``."""
    return sorted(
        str(path)
        for pattern in _IMAGE_PATTERNS
        for path in directory.glob(pattern)
    )


def _split_paths(
    paths: List[str],
    train_ratio: float,
    val_ratio: float,
) -> Tuple[List[str], List[str], List[str]]:
    """Shuffle ``paths`` and cut them into (train, val, test) lists.

    Train and val are each sized to at least one element; when there are
    too few paths the later splits simply come out empty.
    """
    shuffled = list(paths)
    random.shuffle(shuffled)

    n = len(shuffled)
    n_train = max(1, int(n * train_ratio))
    n_val = max(1, int(n * val_ratio))

    return (
        shuffled[:n_train],
        shuffled[n_train:n_train + n_val],
        shuffled[n_train + n_val:],
    )


def build_dataset_splits(
    dataset_root: str,
    dataset_face_root: str,
    dataset_eyes_root: str,
    min_images: int = 3,
    train_ratio: float = 0.8,
    val_ratio: float = 0.1,
    seed: int = 42,
) -> Tuple[Dict[str, int], Dict[str, Dict[str, List[str]]], Dict[str, Dict[str, List[str]]], Dict[str, Dict[str, List[str]]]]:
    """Scan the dataset directories and split each artist's images.

    Artists with fewer than ``min_images`` full images are skipped. Full,
    face and eye images are shuffled and split independently per artist
    with the same ratios; whatever is left after train and val goes to test.

    Note: seeds the global ``random`` and ``numpy.random`` generators.
    Directories and files are sorted before shuffling so that a given
    ``seed`` yields the same splits regardless of filesystem listing order.

    Returns:
        artist_to_idx: artist name -> index mapping
        full_splits: {'train': {artist: [paths]}, 'val': {...}, 'test': {...}}
        face_splits: same structure
        eye_splits: same structure
    """
    random.seed(seed)
    np.random.seed(seed)

    dataset_path = Path(dataset_root)
    face_path = Path(dataset_face_root)
    eyes_path = Path(dataset_eyes_root)

    artist_images = {}
    artist_faces = {}
    artist_eyes = {}

    print("Scanning dataset...")

    artist_dirs = sorted(
        (d for d in dataset_path.iterdir() if d.is_dir()),
        key=lambda d: d.name,
    )

    for artist_dir in tqdm(artist_dirs, desc="Loading artists"):
        artist_name = artist_dir.name

        # Full images
        images = _list_images(artist_dir)
        if len(images) < min_images:
            continue
        artist_images[artist_name] = images

        # Face images
        face_dir = face_path / artist_name
        artist_faces[artist_name] = _list_images(face_dir) if face_dir.exists() else []

        # Eye images
        eye_dir = eyes_path / artist_name
        artist_eyes[artist_name] = _list_images(eye_dir) if eye_dir.exists() else []

    print(f"Found {len(artist_images)} artists with >= {min_images} images")

    artists_sorted = sorted(artist_images)
    artist_to_idx = {name: idx for idx, name in enumerate(artists_sorted)}

    split_names = ('train', 'val', 'test')
    full_splits = {name: {} for name in split_names}
    face_splits = {name: {} for name in split_names}
    eye_splits = {name: {} for name in split_names}

    for artist in artists_sorted:
        # Full, face and eye images are split with the same ratios.
        per_branch = (
            (full_splits, artist_images[artist]),
            (face_splits, artist_faces[artist]),
            (eye_splits, artist_eyes[artist]),
        )
        for branch_splits, paths in per_branch:
            parts = _split_paths(paths, train_ratio, val_ratio)
            for split_name, part in zip(split_names, parts):
                branch_splits[split_name][artist] = part

    # Print statistics
    for split_name in split_names:
        total_full = sum(len(imgs) for imgs in full_splits[split_name].values())
        total_face = sum(len(imgs) for imgs in face_splits[split_name].values())
        total_eye = sum(len(imgs) for imgs in eye_splits[split_name].values())
        print(f"{split_name}: {total_full} full, {total_face} face, {total_eye} eye images")

    return artist_to_idx, full_splits, face_splits, eye_splits


def collate_fn(batch):
    """Collate a list of ``ArtistDataset`` items into one batch dict.

    Image tensors are stacked along a new first dimension; flags and labels
    become tensors; artist names stay a plain list.
    """
    def stack(key):
        return torch.stack([item[key] for item in batch])

    def as_tensor(key):
        return torch.tensor([item[key] for item in batch])

    return {
        'full': stack('full'),
        'face': stack('face'),
        'eye': stack('eye'),
        'has_face': as_tensor('has_face'),
        'has_eye': as_tensor('has_eye'),
        'label': as_tensor('label'),
        'artist': [item['artist'] for item in batch],
    }


def create_dataloaders(
    config,
    artist_to_idx: Dict[str, int],
    full_splits: Dict[str, Dict[str, List[str]]],
    face_splits: Dict[str, Dict[str, List[str]]],
    eye_splits: Dict[str, Dict[str, List[str]]],
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Create the train, validation and test data loaders.

    The train loader draws P x K batches through ``PKSampler``; the val and
    test loaders iterate sequentially with ``config.train.batch_size``.

    Reads ``config.data.{dataset_root, dataset_face_root, dataset_eyes_root,
    image_size, num_workers, pin_memory}`` and
    ``config.train.{batch_size, samples_per_class}``.

    Returns:
        ``(train_loader, val_loader, test_loader)``
    """
    def make_dataset(split: str, is_training: bool) -> ArtistDataset:
        return ArtistDataset(
            dataset_root=config.data.dataset_root,
            dataset_face_root=config.data.dataset_face_root,
            dataset_eyes_root=config.data.dataset_eyes_root,
            artist_to_idx=artist_to_idx,
            image_paths=full_splits[split],
            face_paths=face_splits[split],
            eye_paths=eye_splits[split],
            image_size=config.data.image_size,
            is_training=is_training,
        )

    def make_eval_loader(split: str) -> DataLoader:
        return DataLoader(
            make_dataset(split, is_training=False),
            batch_size=config.train.batch_size,
            shuffle=False,
            num_workers=config.data.num_workers,
            pin_memory=config.data.pin_memory,
            collate_fn=collate_fn,
        )

    train_dataset = make_dataset('train', is_training=True)

    # Derive P and K from the batch size:
    # batch_size = P * K, with K fixed to samples_per_class.
    k = config.train.samples_per_class
    p = config.train.batch_size // k  # e.g. batch_size=256 with K=4 gives P=64
    p = min(p, len(artist_to_idx))  # never ask for more classes than exist

    print(f"PKSampler: P={p} classes × K={k} samples = {p*k} batch size")

    train_sampler = PKSampler(
        train_dataset,
        p=p,
        k=k,
    )

    train_loader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        num_workers=config.data.num_workers,
        pin_memory=config.data.pin_memory,
        collate_fn=collate_fn,
    )

    val_loader = make_eval_loader('val')
    test_loader = make_eval_loader('test')

    return train_loader, val_loader, test_loader