File size: 11,716 Bytes
546ff88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
"""
Three-View-Style-Embedder - Fast Embedding Extraction
DataLoader ๊ธฐ๋ฐ˜ ๊ณ ์† ๋ฐฐ์น˜ ์ฒ˜๋ฆฌ
"""
import argparse
import random
import itertools
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import json

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
from tqdm import tqdm
from collections import defaultdict

from config import get_config
from model import ArtistStyleModel


class ArtistCombinationDataset(Dataset):
    """Dataset that pre-builds (full, face, eye) image combinations per artist.

    For every artist directory under ``dataset_root`` it forms the Cartesian
    product of the artist's full images with the matching face crops (from
    ``dataset_face_root``) and eye crops (from ``dataset_eyes_root``),
    shuffles the combinations, and keeps at most ``max_combinations`` of them.
    Missing face/eye folders are represented by ``None`` placeholders so that
    every artist with at least one full image still yields samples.
    """

    def __init__(
        self,
        dataset_root: str,
        dataset_face_root: str,
        dataset_eyes_root: str,
        max_combinations: int = 30,
    ):
        self.dataset_root = Path(dataset_root)
        self.dataset_face_root = Path(dataset_face_root)
        self.dataset_eyes_root = Path(dataset_eyes_root)
        self.max_combinations = max_combinations

        # Standard ImageNet preprocessing (224x224, ImageNet mean/std).
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # All samples pre-built as (artist_name, full_path, face_path, eye_path).
        self.samples = []
        self.artist_to_indices = defaultdict(list)  # artist -> [sample indices]

        self._build_samples()

    def _get_image_paths(self, folder: Path) -> List[Path]:
        """Return all .jpg/.png/.webp files in *folder* ([] if it doesn't exist)."""
        if not folder.exists():
            return []
        return list(folder.glob("*.jpg")) + list(folder.glob("*.png")) + list(folder.glob("*.webp"))

    def _build_samples(self):
        """Pre-generate up to ``max_combinations`` samples for every artist."""
        print("Building sample combinations...")

        artist_dirs = [d for d in self.dataset_root.iterdir() if d.is_dir()]

        for artist_dir in tqdm(artist_dirs, desc="Preparing artists"):
            artist_name = artist_dir.name

            full_paths = self._get_image_paths(artist_dir)
            if not full_paths:
                # An artist without any full image contributes no samples.
                continue

            face_paths = self._get_image_paths(self.dataset_face_root / artist_name)
            eye_paths = self._get_image_paths(self.dataset_eyes_root / artist_name)

            # Cartesian product; [None] placeholder keeps the product non-empty
            # when a view is missing for this artist.
            face_options = face_paths if face_paths else [None]
            eye_options = eye_paths if eye_paths else [None]

            all_combinations = list(itertools.product(full_paths, face_options, eye_options))
            random.shuffle(all_combinations)
            selected = all_combinations[:self.max_combinations]

            # Register samples and remember which indices belong to this artist.
            for full_path, face_path, eye_path in selected:
                idx = len(self.samples)
                self.samples.append((artist_name, full_path, face_path, eye_path))
                self.artist_to_indices[artist_name].append(idx)

        print(f"Total samples: {len(self.samples)} from {len(self.artist_to_indices)} artists")

    def _load_image(self, path: Optional[Path]) -> Optional[torch.Tensor]:
        """Load *path* as a normalized RGB tensor; None for missing/unreadable files.

        Images with an alpha channel (RGBA/LA, and palette images that may carry
        transparency) are flattened onto a white background before conversion.
        """
        if path is None:
            return None
        try:
            img = Image.open(path)
            if img.mode in ('RGBA', 'LA', 'P'):
                background = Image.new('RGB', img.size, (255, 255, 255))
                if img.mode == 'P':
                    # Palette images are promoted to RGBA so their (possible)
                    # transparency can be composited below.
                    img = img.convert('RGBA')
                # At this point the mode is always RGBA or LA; composite the
                # alpha channel onto the white background.
                background.paste(img, mask=img.split()[-1])
                img = background
            else:
                img = img.convert('RGB')
            return self.transform(img)
        except Exception:
            # Corrupt/unreadable file: signal failure to the caller instead of
            # crashing the worker. (Was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit.)
            return None

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return one sample dict; failed loads are replaced by zero tensors.

        ``valid`` is False when the full image failed to load; ``has_face`` /
        ``has_eye`` are False when the respective crop is absent or unreadable.
        """
        artist_name, full_path, face_path, eye_path = self.samples[idx]

        full_tensor = self._load_image(full_path)
        if full_tensor is None:
            full_tensor = torch.zeros(3, 224, 224)
            valid = False
        else:
            valid = True

        face_tensor = self._load_image(face_path)
        has_face = face_tensor is not None
        if not has_face:
            face_tensor = torch.zeros(3, 224, 224)

        eye_tensor = self._load_image(eye_path)
        has_eye = eye_tensor is not None
        if not has_eye:
            eye_tensor = torch.zeros(3, 224, 224)

        return {
            'full': full_tensor,
            'face': face_tensor,
            'eye': eye_tensor,
            'has_face': has_face,
            'has_eye': has_eye,
            'artist': artist_name,
            'valid': valid,
            'idx': idx,
        }


def collate_fn(batch):
    """Collate per-sample dicts into one batch dict.

    Image tensors are stacked along a new batch dimension, flags/indices become
    1-D tensors, and artist names stay a plain Python list.
    """
    batched = {key: torch.stack([sample[key] for sample in batch])
               for key in ('full', 'face', 'eye')}
    for key in ('has_face', 'has_eye', 'valid', 'idx'):
        batched[key] = torch.tensor([sample[key] for sample in batch])
    batched['artist'] = [sample['artist'] for sample in batch]
    return batched


@torch.no_grad()
def extract_all_embeddings(
    checkpoint_path: str,
    dataset_root: str,
    dataset_face_root: str,
    dataset_eyes_root: str,
    output_path: str,
    max_combinations: int = 30,
    batch_size: int = 64,
    num_workers: int = 8,
    device: str = 'cuda',
):
    """Extract per-artist mean style embeddings and save them to disk.

    Runs batched inference over every pre-built (full, face, eye) combination,
    averages the valid embeddings per artist, L2-normalizes the mean, and
    writes a compressed ``.npz`` (artist_names + embeddings) plus a ``.json``
    metadata sidecar next to it.

    Args:
        checkpoint_path: Path to a training checkpoint containing
            ``model_state_dict`` and ``artist_to_idx``.
        dataset_root / dataset_face_root / dataset_eyes_root: Image roots.
        output_path: Destination ``.npz`` path (parent dirs are created).
        max_combinations: Max samples per artist.
        batch_size / num_workers: DataLoader settings.
        device: 'cuda' or 'cpu'; falls back to CPU when CUDA is unavailable.

    Raises:
        RuntimeError: If no artist produced a single valid embedding.
    """

    requested_device = device
    if requested_device.startswith('cuda') and not torch.cuda.is_available():
        print(
            "[WARN] --device=cuda requested but torch.cuda.is_available() is False. "
            "Falling back to CPU. (Install a CUDA-enabled PyTorch build to use GPU.)"
        )
        requested_device = 'cpu'
    device = torch.device(requested_device)

    # Load model.
    print("Loading model...")
    # Always load checkpoint on CPU to avoid duplicating large tensors on GPU.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    config = get_config()

    model = ArtistStyleModel(
        num_classes=len(checkpoint['artist_to_idx']),
        embedding_dim=config.model.embedding_dim,
        hidden_dim=config.model.hidden_dim,
    )
    model.load_state_dict(checkpoint['model_state_dict'])

    # Reduce VRAM: keep weights in FP16 on CUDA.
    if device.type == 'cuda':
        model = model.to(dtype=torch.float16)
    model = model.to(device)
    model.eval()

    # Build the dataset of pre-generated combinations.
    dataset = ArtistCombinationDataset(
        dataset_root=dataset_root,
        dataset_face_root=dataset_face_root,
        dataset_eyes_root=dataset_eyes_root,
        max_combinations=max_combinations,
    )

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True,
    )

    # Pre-allocated storage; rows are written by sample index, so order is
    # independent of batch iteration order.
    all_embeddings = torch.zeros(len(dataset), config.model.embedding_dim)
    all_valid = torch.zeros(len(dataset), dtype=torch.bool)

    # Batched inference (with AMP on CUDA).
    print("Extracting embeddings with AMP...")
    for batch in tqdm(dataloader, desc="Processing"):
        full = batch['full'].to(device)
        face = batch['face'].to(device)
        eye = batch['eye'].to(device)
        has_face = batch['has_face'].to(device)
        has_eye = batch['has_eye'].to(device)
        indices = batch['idx']
        valid = batch['valid']

        # NOTE(review): torch.cuda.amp.autocast is deprecated in recent torch
        # in favor of torch.amp.autocast('cuda', ...) — consider migrating
        # once the minimum torch version is known.
        with torch.cuda.amp.autocast(enabled=(device.type == 'cuda')):
            embeddings = model.get_embeddings(full, face, eye, has_face, has_eye)

        all_embeddings[indices] = embeddings.float().cpu()
        all_valid[indices] = valid

    # Per-artist mean over valid samples, then L2-normalize.
    print("Computing artist averages...")
    artist_embeddings = {}
    failed_artists = []

    for artist_name, indices in tqdm(dataset.artist_to_indices.items(), desc="Averaging"):
        indices = torch.tensor(indices)
        valid_mask = all_valid[indices]

        if valid_mask.sum() == 0:
            # Every combination for this artist failed to load.
            failed_artists.append(artist_name)
            continue

        valid_embeddings = all_embeddings[indices][valid_mask]
        mean_emb = valid_embeddings.mean(dim=0)
        mean_emb = F.normalize(mean_emb, p=2, dim=0)
        artist_embeddings[artist_name] = mean_emb.numpy()

    print(f"Success: {len(artist_embeddings)}, Failed: {len(failed_artists)}")

    # Guard: np.stack([]) would raise an opaque ValueError; fail with a clear
    # message instead of writing an empty/corrupt output file.
    if not artist_embeddings:
        raise RuntimeError(
            "No valid embeddings were extracted for any artist; "
            "nothing to save. Check the dataset roots and image files."
        )

    # Save results.
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    artist_names = list(artist_embeddings.keys())
    embeddings_array = np.stack([artist_embeddings[name] for name in artist_names])

    np.savez_compressed(
        output_path,
        artist_names=np.array(artist_names),
        embeddings=embeddings_array,
    )
    print(f"Saved: {output_path}")

    # Metadata sidecar.
    meta_path = output_path.with_suffix('.json')
    with open(meta_path, 'w') as f:
        json.dump({
            'num_artists': len(artist_embeddings),
            'embedding_dim': config.model.embedding_dim,
            'max_combinations': max_combinations,
            'failed_artists': failed_artists,
        }, f, indent=2)
    print(f"Saved: {meta_path}")


def load_embeddings(npz_path: str) -> Tuple[List[str], np.ndarray]:
    """Load saved artist embeddings: (artist name list, embedding matrix)."""
    archive = np.load(npz_path)
    names = archive['artist_names'].tolist()
    return names, archive['embeddings']


def find_similar_artists(
    query_embedding: np.ndarray,
    artist_names: List[str],
    embeddings: np.ndarray,
    top_k: int = 10,
) -> List[Tuple[str, float]]:
    """Return the ``top_k`` artists most cosine-similar to the query.

    Both the query and every embedding row are L2-normalized before the dot
    product, so the scores are cosine similarities in [-1, 1], sorted
    descending.
    """
    unit_query = query_embedding / np.linalg.norm(query_embedding)
    row_norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    cosines = (embeddings / row_norms) @ unit_query

    # Descending order via reversed ascending argsort (keeps tie order
    # identical to the original implementation).
    ranked = np.argsort(cosines)[::-1][:top_k]
    return [(artist_names[i], float(cosines[i])) for i in ranked]


def main():
    """CLI entry point: parse arguments, seed all RNGs, run extraction."""
    parser = argparse.ArgumentParser(description='Three-View-Style-Embedder - Extract Embeddings')
    parser.add_argument('--checkpoint', type=str, required=True)
    for flag, default in (
        ('--dataset_root', './dataset'),
        ('--dataset_face_root', './dataset_face'),
        ('--dataset_eyes_root', './dataset_eyes'),
    ):
        parser.add_argument(flag, type=str, default=default)
    parser.add_argument('--output', type=str, default='./embeddings/artist_embeddings.npz')
    for flag, default in (
        ('--max_combinations', 30),
        ('--batch_size', 64),
        ('--num_workers', 8),
    ):
        parser.add_argument(flag, type=int, default=default)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--seed', type=int, default=42)

    args = parser.parse_args()

    # Seed every RNG source for reproducible combination sampling.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(args.seed)

    extract_all_embeddings(
        checkpoint_path=args.checkpoint,
        dataset_root=args.dataset_root,
        dataset_face_root=args.dataset_face_root,
        dataset_eyes_root=args.dataset_eyes_root,
        output_path=args.output,
        max_combinations=args.max_combinations,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        device=args.device,
    )


if __name__ == '__main__':
    main()