Leacb4 commited on
Commit
43824ca
·
verified ·
1 Parent(s): fc5b142

Upload evaluation/color_evaluation.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. evaluation/color_evaluation.py +919 -0
evaluation/color_evaluation.py ADDED
@@ -0,0 +1,919 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
4
+
5
+ import torch
6
+ import pandas as pd
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+ import seaborn as sns
10
+ import difflib
11
+ from sklearn.metrics.pairwise import cosine_similarity
12
+ from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
13
+ from collections import defaultdict
14
+ from tqdm import tqdm
15
+ from torch.utils.data import Dataset, DataLoader
16
+ from torchvision import transforms
17
+ from PIL import Image
18
+ from io import BytesIO
19
+ import warnings
20
+ warnings.filterwarnings('ignore')
21
+ from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
22
+
23
+ from config import (
24
+ color_model_path,
25
+ color_emb_dim,
26
+ local_dataset_path,
27
+ column_local_image_path,
28
+ tokeniser_path,
29
+ )
30
+ from color_model import ColorCLIP, Tokenizer
31
+
32
+
33
+ class KaggleDataset(Dataset):
34
+ """Dataset class for KAGL Marqo dataset"""
35
+ def __init__(self, dataframe, image_size=224):
36
+ self.dataframe = dataframe
37
+ self.image_size = image_size
38
+
39
+ # Transforms for validation (no augmentation)
40
+ self.transform = transforms.Compose([
41
+ transforms.Resize((224, 224)),
42
+ transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2), # AUGMENTATION
43
+ transforms.ToTensor(),
44
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
45
+ ])
46
+
47
+ def __len__(self):
48
+ return len(self.dataframe)
49
+
50
+ def __getitem__(self, idx):
51
+ row = self.dataframe.iloc[idx]
52
+
53
+ # Handle image - it should be in row['image_url'] and contain the image data as bytes
54
+ image_data = row['image_url']
55
+
56
+ # Check if image_data has 'bytes' key or is already PIL Image
57
+ if isinstance(image_data, dict) and 'bytes' in image_data:
58
+ image = Image.open(BytesIO(image_data['bytes'])).convert("RGB")
59
+ elif hasattr(image_data, 'convert'): # Already a PIL Image
60
+ image = image_data.convert("RGB")
61
+ else:
62
+ # Assume it's raw bytes
63
+ image = Image.open(BytesIO(image_data)).convert("RGB")
64
+
65
+ # Apply validation transform
66
+ image = self.transform(image)
67
+
68
+ # Get text and labels
69
+ description = row['text']
70
+ color = row['color']
71
+
72
+ return image, description, color
73
+
74
+
75
+ def load_kaggle_marqo_dataset(max_samples=5000):
76
+ """Load and prepare Kaggle KAGL dataset with memory optimization"""
77
+ from datasets import load_dataset
78
+ print("📊 Loading Kaggle KAGL dataset...")
79
+
80
+ # Load the dataset
81
+ dataset = load_dataset("Marqo/KAGL")
82
+ df = dataset["data"].to_pandas()
83
+ print(f"✅ Dataset Kaggle loaded")
84
+ print(f" Before filtering: {len(df)} samples")
85
+ print(f" Available columns: {list(df.columns)}")
86
+
87
+ # Ensure we have text and image data
88
+ df = df.dropna(subset=['text', 'image'])
89
+ print(f" After removing missing text/image: {len(df)} samples")
90
+
91
+ df_test = df.copy()
92
+
93
+ # Limit to max_samples with RANDOM SAMPLING to get diverse colors
94
+ if len(df_test) > max_samples:
95
+ df_test = df_test.sample(n=max_samples, random_state=42)
96
+ print(f"📊 Randomly sampled {max_samples} samples from Kaggle dataset")
97
+
98
+ # Create formatted dataset with proper column names
99
+ kaggle_formatted = pd.DataFrame({
100
+ 'image_url': df_test['image'], # This contains image data as bytes
101
+ 'text': df_test['text'],
102
+ 'color': df_test['baseColour'].str.lower().str.replace("grey", "gray") # Use actual colors
103
+ })
104
+
105
+ # Filter out rows with None/NaN colors
106
+ before_color_filter = len(kaggle_formatted)
107
+ kaggle_formatted = kaggle_formatted.dropna(subset=['color'])
108
+ if len(kaggle_formatted) < before_color_filter:
109
+ print(f" After removing missing colors: {len(kaggle_formatted)} samples (removed {before_color_filter - len(kaggle_formatted)} samples)")
110
+
111
+ # Filter for colors that were used during training (11 colors)
112
+ valid_colors = ['beige', 'black', 'blue', 'brown', 'green', 'orange', 'pink', 'purple', 'red', 'white', 'yellow']
113
+ before_valid_filter = len(kaggle_formatted)
114
+ kaggle_formatted = kaggle_formatted[kaggle_formatted['color'].isin(valid_colors)]
115
+ print(f" After filtering for valid colors: {len(kaggle_formatted)} samples (removed {before_valid_filter - len(kaggle_formatted)} samples)")
116
+ print(f" Valid colors found: {sorted(kaggle_formatted['color'].unique())}")
117
+
118
+ print(f" Final dataset size: {len(kaggle_formatted)} samples")
119
+
120
+ # Show color distribution in final dataset
121
+ print(f"🎨 Color distribution in Kaggle dataset:")
122
+ color_counts = kaggle_formatted['color'].value_counts()
123
+ for color in color_counts.index:
124
+ print(f" {color}: {color_counts[color]} samples")
125
+
126
+ return KaggleDataset(kaggle_formatted)
127
+
128
+
129
+ class LocalDataset(Dataset):
130
+ """Dataset class for local validation dataset"""
131
+ def __init__(self, dataframe, image_size=224):
132
+ self.dataframe = dataframe
133
+ self.image_size = image_size
134
+
135
+ # Transforms for validation (no augmentation)
136
+ self.transform = transforms.Compose([
137
+ transforms.Resize((224, 224)),
138
+ transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2), # AUGMENTATION
139
+ transforms.ToTensor(),
140
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
141
+ ])
142
+
143
+ def __len__(self):
144
+ return len(self.dataframe)
145
+
146
+ def __getitem__(self, idx):
147
+ row = self.dataframe.iloc[idx]
148
+
149
+ # Load image from local path
150
+ image_path = row[column_local_image_path]
151
+ try:
152
+ image = Image.open(image_path).convert("RGB")
153
+ except Exception as e:
154
+ print(f"Error loading image at index {idx} from {image_path}: {e}")
155
+ # Create a dummy image if loading fails
156
+ image = Image.new('RGB', (224, 224), color='gray')
157
+
158
+ # Apply validation transform
159
+ image = self.transform(image)
160
+
161
+ # Get text and labels
162
+ description = row['text']
163
+ color = row['color']
164
+
165
+ return image, description, color
166
+
167
+
168
+ def load_local_validation_dataset(max_samples=5000):
169
+ """Load and prepare local validation dataset"""
170
+ print("📊 Loading local validation dataset...")
171
+
172
+ df = pd.read_csv(local_dataset_path)
173
+ print(f"✅ Dataset loaded: {len(df)} samples")
174
+
175
+ # Filter out rows with NaN values in image path
176
+ df_clean = df.dropna(subset=[column_local_image_path])
177
+ print(f"📊 After filtering NaN image paths: {len(df_clean)} samples")
178
+
179
+ # Filter for colors that were used during training (11 colors)
180
+ valid_colors = ['beige', 'black', 'blue', 'brown', 'green', 'orange', 'pink', 'purple', 'red', 'white', 'yellow']
181
+ if 'color' in df_clean.columns:
182
+ before_valid_filter = len(df_clean)
183
+ df_clean = df_clean[df_clean['color'].isin(valid_colors)]
184
+ print(f"📊 After filtering for valid colors: {len(df_clean)} samples (removed {before_valid_filter - len(df_clean)} samples)")
185
+ print(f"🎨 Valid colors found: {sorted(df_clean['color'].unique())}")
186
+
187
+ # Limit to max_samples with RANDOM SAMPLING to get diverse colors
188
+ if len(df_clean) > max_samples:
189
+ df_clean = df_clean.sample(n=max_samples, random_state=42)
190
+ print(f"📊 Randomly sampled {max_samples} samples")
191
+
192
+ print(f"📊 Using {len(df_clean)} samples for evaluation")
193
+
194
+ # Show color distribution after sampling
195
+ if 'color' in df_clean.columns:
196
+ print(f"🎨 Color distribution in sampled data:")
197
+ color_counts = df_clean['color'].value_counts()
198
+ print(f" Total unique colors: {len(color_counts)}")
199
+ for color in color_counts.index[:15]: # Show top 15
200
+ print(f" {color}: {color_counts[color]} samples")
201
+
202
+ return LocalDataset(df_clean)
203
+
204
+
205
+ def collate_fn_filter_none(batch):
206
+ """Collate function that filters out None values from batch with debug print"""
207
+ # Filter out None values
208
+ original_len = len(batch)
209
+ batch = [item for item in batch if item is not None]
210
+
211
+ if original_len > len(batch):
212
+ print(f"⚠️ Filtered out {original_len - len(batch)} None values from batch (original: {original_len}, filtered: {len(batch)})")
213
+
214
+ if len(batch) == 0:
215
+ # Return empty batch with correct structure
216
+ print("⚠️ Empty batch after filtering None values")
217
+ return torch.tensor([]), [], []
218
+
219
+ images, texts, colors = zip(*batch)
220
+ images = torch.stack(images, dim=0)
221
+ return images, list(texts), list(colors)
222
+
223
+
224
+ class ColorEvaluator:
225
+ """Evaluate color 16 embeddings"""
226
+
227
+ def __init__(self, device='mps', directory="color_model_analysis"):
228
+ self.device = torch.device(device)
229
+ self.directory = directory
230
+ self.color_emb_dim = color_emb_dim
231
+ os.makedirs(self.directory, exist_ok=True)
232
+
233
+ # Load baseline Fashion CLIP model
234
+ print("📦 Loading baseline Fashion CLIP model...")
235
+ patrick_model_name = "patrickjohncyh/fashion-clip"
236
+ self.baseline_processor = CLIPProcessor.from_pretrained(patrick_model_name)
237
+ self.baseline_model = CLIPModel_transformers.from_pretrained(patrick_model_name).to(self.device)
238
+ self.baseline_model.eval()
239
+ print("✅ Baseline Fashion CLIP model loaded successfully")
240
+
241
+ # Load specialized color model (16D)
242
+ self.color_model = None
243
+ self.color_tokenizer = None
244
+ self._load_color_model()
245
+
246
+ def _load_color_model(self):
247
+ """Load the specialized 16D color model and tokenizer."""
248
+ if self.color_model is not None and self.color_tokenizer is not None:
249
+ return
250
+
251
+ if not os.path.exists(color_model_path):
252
+ raise FileNotFoundError(f"Color model file {color_model_path} not found")
253
+ if not os.path.exists(tokeniser_path):
254
+ raise FileNotFoundError(f"Tokenizer vocab file {tokeniser_path} not found")
255
+
256
+ print("🎨 Loading specialized color model (16D)...")
257
+
258
+ # Load checkpoint first to get the actual vocab size
259
+ state_dict = torch.load(color_model_path, map_location=self.device)
260
+
261
+ # Get vocab size from the embedding weight shape in checkpoint
262
+ vocab_size = state_dict['text_encoder.embedding.weight'].shape[0]
263
+ print(f" Detected vocab size from checkpoint: {vocab_size}")
264
+
265
+ # Load tokenizer vocab
266
+ with open(tokeniser_path, "r") as f:
267
+ vocab = json.load(f)
268
+
269
+ self.color_tokenizer = Tokenizer()
270
+ self.color_tokenizer.load_vocab(vocab)
271
+
272
+ # Create model with the vocab size from checkpoint (not from tokenizer)
273
+ self.color_model = ColorCLIP(vocab_size=vocab_size, embedding_dim=self.color_emb_dim)
274
+
275
+ # Load state dict
276
+ self.color_model.load_state_dict(state_dict)
277
+ self.color_model.to(self.device)
278
+ self.color_model.eval()
279
+ print("✅ Color model loaded successfully")
280
+
281
+ def _tokenize_color_texts(self, texts):
282
+ """Tokenize texts with the color tokenizer and return padded tensors."""
283
+ token_lists = [self.color_tokenizer(t) for t in texts]
284
+ max_len = max((len(toks) for toks in token_lists), default=0)
285
+ max_len = max_len if max_len > 0 else 1
286
+
287
+ input_ids = torch.zeros(len(texts), max_len, dtype=torch.long, device=self.device)
288
+ lengths = torch.zeros(len(texts), dtype=torch.long, device=self.device)
289
+
290
+ for i, toks in enumerate(token_lists):
291
+ if len(toks) > 0:
292
+ input_ids[i, :len(toks)] = torch.tensor(toks, dtype=torch.long, device=self.device)
293
+ lengths[i] = len(toks)
294
+ else:
295
+ lengths[i] = 1 # avoid zero-length
296
+
297
+ return input_ids, lengths
298
+
299
+ def extract_color_embeddings(self, dataloader, embedding_type='text', max_samples=10000):
300
+ """Extract 16D color embeddings from specialized color model."""
301
+ self._load_color_model()
302
+ all_embeddings = []
303
+ all_colors = []
304
+
305
+ sample_count = 0
306
+ with torch.no_grad():
307
+ for batch in tqdm(dataloader, desc=f"Extracting {embedding_type} color embeddings"):
308
+ if sample_count >= max_samples:
309
+ break
310
+
311
+ images, texts, colors = batch
312
+ images = images.to(self.device)
313
+ images = images.expand(-1, 3, -1, -1)
314
+
315
+ if embedding_type == 'text':
316
+ input_ids, lengths = self._tokenize_color_texts(texts)
317
+ embeddings = self.color_model.text_encoder(input_ids, lengths)
318
+ elif embedding_type == 'image':
319
+ embeddings = self.color_model.image_encoder(images)
320
+ else:
321
+ input_ids, lengths = self._tokenize_color_texts(texts)
322
+ embeddings = self.color_model.text_encoder(input_ids, lengths)
323
+
324
+ all_embeddings.append(embeddings.cpu().numpy())
325
+ normalized_colors = [str(c).lower().strip().replace("grey", "gray") for c in colors]
326
+ all_colors.extend(normalized_colors)
327
+
328
+ sample_count += len(images)
329
+
330
+ del images, embeddings
331
+ if embedding_type != 'image':
332
+ del input_ids, lengths
333
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
334
+
335
+ return np.vstack(all_embeddings), all_colors
336
+
337
+ def extract_baseline_embeddings_batch(self, dataloader, embedding_type='text', max_samples=10000):
338
+ """Extract embeddings from baseline Fashion CLIP model"""
339
+ all_embeddings = []
340
+ all_colors = []
341
+
342
+ sample_count = 0
343
+
344
+ with torch.no_grad():
345
+ for batch in tqdm(dataloader, desc=f"Extracting baseline {embedding_type} embeddings"):
346
+ if sample_count >= max_samples:
347
+ break
348
+
349
+ images, texts, colors = batch
350
+ images = images.to(self.device)
351
+ images = images.expand(-1, 3, -1, -1) # Ensure 3 channels
352
+
353
+ # Process text inputs with baseline processor
354
+ text_inputs = self.baseline_processor(text=texts, padding=True, return_tensors="pt")
355
+ text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
356
+
357
+ # Forward pass through baseline model
358
+ outputs = self.baseline_model(**text_inputs, pixel_values=images)
359
+
360
+ # Extract embeddings based on type
361
+ if embedding_type == 'text':
362
+ embeddings = outputs.text_embeds
363
+ elif embedding_type == 'image':
364
+ embeddings = outputs.image_embeds
365
+ else:
366
+ embeddings = outputs.text_embeds
367
+
368
+ all_embeddings.append(embeddings.cpu().numpy())
369
+ all_colors.extend(colors)
370
+
371
+ sample_count += len(images)
372
+
373
+ # Clear GPU memory
374
+ del images, text_inputs, outputs, embeddings
375
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
376
+
377
+ return np.vstack(all_embeddings), all_colors
378
+
379
+ def compute_similarity_metrics(self, embeddings, labels):
380
+ """Compute intra-class and inter-class similarities - optimized version"""
381
+ max_samples = min(5000, len(embeddings))
382
+ if len(embeddings) > max_samples:
383
+ indices = np.random.choice(len(embeddings), max_samples, replace=False)
384
+ embeddings = embeddings[indices]
385
+ labels = [labels[i] for i in indices]
386
+
387
+ similarities = cosine_similarity(embeddings)
388
+
389
+ # Create label groups using numpy for faster indexing
390
+ label_array = np.array(labels)
391
+ unique_labels = np.unique(label_array)
392
+ label_groups = {label: np.where(label_array == label)[0] for label in unique_labels}
393
+
394
+ # Compute intra-class similarities using vectorized operations
395
+ intra_class_similarities = []
396
+ for label, indices in label_groups.items():
397
+ if len(indices) > 1:
398
+ # Extract submatrix for this class
399
+ class_similarities = similarities[np.ix_(indices, indices)]
400
+ # Get upper triangle (excluding diagonal)
401
+ triu_indices = np.triu_indices_from(class_similarities, k=1)
402
+ intra_class_similarities.extend(class_similarities[triu_indices].tolist())
403
+
404
+ # Compute inter-class similarities using vectorized operations
405
+ inter_class_similarities = []
406
+ labels_list = list(label_groups.keys())
407
+ for i in range(len(labels_list)):
408
+ for j in range(i + 1, len(labels_list)):
409
+ label1_indices = label_groups[labels_list[i]]
410
+ label2_indices = label_groups[labels_list[j]]
411
+ # Extract submatrix between two classes
412
+ inter_sims = similarities[np.ix_(label1_indices, label2_indices)]
413
+ inter_class_similarities.extend(inter_sims.flatten().tolist())
414
+
415
+ nn_accuracy = self.compute_embedding_accuracy(embeddings, labels, similarities)
416
+ centroid_accuracy = self.compute_centroid_accuracy(embeddings, labels)
417
+
418
+ return {
419
+ 'intra_class_similarities': intra_class_similarities,
420
+ 'inter_class_similarities': inter_class_similarities,
421
+ 'intra_class_mean': float(np.mean(intra_class_similarities)) if intra_class_similarities else 0.0,
422
+ 'inter_class_mean': float(np.mean(inter_class_similarities)) if inter_class_similarities else 0.0,
423
+ 'separation_score': float(np.mean(intra_class_similarities) - np.mean(inter_class_similarities)) if intra_class_similarities and inter_class_similarities else 0.0,
424
+ 'accuracy': nn_accuracy,
425
+ 'centroid_accuracy': centroid_accuracy,
426
+ }
427
+
428
+ def compute_embedding_accuracy(self, embeddings, labels, similarities):
429
+ """Compute classification accuracy using nearest neighbor"""
430
+ correct_predictions = 0
431
+ total_predictions = len(labels)
432
+ for i in range(len(embeddings)):
433
+ true_label = labels[i]
434
+ similarities_row = similarities[i].copy()
435
+ similarities_row[i] = -1
436
+ nearest_neighbor_idx = int(np.argmax(similarities_row))
437
+ predicted_label = labels[nearest_neighbor_idx]
438
+ if predicted_label == true_label:
439
+ correct_predictions += 1
440
+ return correct_predictions / total_predictions if total_predictions > 0 else 0.0
441
+
442
+ def compute_centroid_accuracy(self, embeddings, labels):
443
+ """Compute classification accuracy using centroids - optimized vectorized version"""
444
+ unique_labels = list(set(labels))
445
+
446
+ # Compute centroids efficiently
447
+ centroids = {}
448
+ for label in unique_labels:
449
+ label_mask = np.array(labels) == label
450
+ centroids[label] = np.mean(embeddings[label_mask], axis=0)
451
+
452
+ # Stack centroids for vectorized similarity computation
453
+ centroid_matrix = np.vstack([centroids[label] for label in unique_labels])
454
+
455
+ # Compute all similarities at once
456
+ similarities = cosine_similarity(embeddings, centroid_matrix)
457
+
458
+ # Get predicted labels
459
+ predicted_indices = np.argmax(similarities, axis=1)
460
+ predicted_labels = [unique_labels[idx] for idx in predicted_indices]
461
+
462
+ # Compute accuracy
463
+ correct_predictions = sum(pred == true for pred, true in zip(predicted_labels, labels))
464
+ return correct_predictions / len(labels) if len(labels) > 0 else 0.0
465
+
466
+ def predict_labels_from_embeddings(self, embeddings, labels):
467
+ """Predict labels from embeddings using centroid-based classification - optimized vectorized version"""
468
+ # Filter out None labels when computing centroids
469
+ unique_labels = [l for l in set(labels) if l is not None]
470
+ if len(unique_labels) == 0:
471
+ # If no valid labels, return None for all predictions
472
+ return [None] * len(embeddings)
473
+
474
+ # Compute centroids efficiently
475
+ centroids = {}
476
+ for label in unique_labels:
477
+ label_mask = np.array(labels) == label
478
+ if np.any(label_mask):
479
+ centroids[label] = np.mean(embeddings[label_mask], axis=0)
480
+
481
+ # Stack centroids for vectorized similarity computation
482
+ centroid_labels = list(centroids.keys())
483
+ centroid_matrix = np.vstack([centroids[label] for label in centroid_labels])
484
+
485
+ # Compute all similarities at once
486
+ similarities = cosine_similarity(embeddings, centroid_matrix)
487
+
488
+ # Get predicted labels
489
+ predicted_indices = np.argmax(similarities, axis=1)
490
+ predictions = [centroid_labels[idx] for idx in predicted_indices]
491
+
492
+ return predictions
493
+
494
+ def create_confusion_matrix(self, true_labels, predicted_labels, title="Confusion Matrix", label_type="Label"):
495
+ """Create and plot confusion matrix"""
496
+ unique_labels = sorted(list(set(true_labels + predicted_labels)))
497
+ cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels)
498
+ accuracy = accuracy_score(true_labels, predicted_labels)
499
+ plt.figure(figsize=(12, 10))
500
+ sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=unique_labels, yticklabels=unique_labels)
501
+ plt.title(f'{title}\nAccuracy: {accuracy:.3f} ({accuracy*100:.1f}%)')
502
+ plt.ylabel(f'True {label_type}')
503
+ plt.xlabel(f'Predicted {label_type}')
504
+ plt.xticks(rotation=45)
505
+ plt.yticks(rotation=0)
506
+ plt.tight_layout()
507
+ return plt.gcf(), accuracy, cm
508
+
509
+ def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings", label_type="Label"):
510
+ """
511
+ Evaluate classification performance and create confusion matrix.
512
+
513
+ Args:
514
+ embeddings: Embeddings
515
+ labels: True labels
516
+ embedding_type: Type of embeddings for display
517
+ label_type: Type of labels (Color)
518
+ full_embeddings: Optional full 512-dim embeddings for ensemble (if None, uses only embeddings)
519
+ ensemble_weight: Weight for embeddings in ensemble (0.0 = only full, 1.0 = only embeddings)
520
+ """
521
+
522
+ predictions = self.predict_labels_from_embeddings(embeddings, labels)
523
+ title_suffix = ""
524
+
525
+ # Filter out None values from labels and predictions
526
+ valid_indices = [i for i, (label, pred) in enumerate(zip(labels, predictions))
527
+ if label is not None and pred is not None]
528
+
529
+ if len(valid_indices) == 0:
530
+ print(f"⚠️ Warning: No valid labels/predictions found (all are None)")
531
+ return {
532
+ 'accuracy': 0.0,
533
+ 'predictions': predictions,
534
+ 'confusion_matrix': None,
535
+ 'classification_report': None,
536
+ 'figure': None,
537
+ }
538
+
539
+ filtered_labels = [labels[i] for i in valid_indices]
540
+ filtered_predictions = [predictions[i] for i in valid_indices]
541
+
542
+ accuracy = accuracy_score(filtered_labels, filtered_predictions)
543
+ fig, acc, cm = self.create_confusion_matrix(
544
+ filtered_labels, filtered_predictions,
545
+ f"{embedding_type} - {label_type} Classification{title_suffix}",
546
+ label_type
547
+ )
548
+ unique_labels = sorted(list(set(filtered_labels)))
549
+ report = classification_report(filtered_labels, filtered_predictions, labels=unique_labels, target_names=unique_labels, output_dict=True)
550
+ return {
551
+ 'accuracy': accuracy,
552
+ 'predictions': predictions,
553
+ 'confusion_matrix': cm,
554
+ 'classification_report': report,
555
+ 'figure': fig,
556
+ }
557
+
558
+
559
+ def evaluate_kaggle_marqo(self, max_samples):
560
+ """Evaluate both color embeddings on KAGL Marqo dataset"""
561
+ print(f"\n{'='*60}")
562
+ print("Evaluating KAGL Marqo Dataset with Color embeddings")
563
+ print(f"Max samples: {max_samples}")
564
+ print(f"{'='*60}")
565
+
566
+ kaggle_dataset = load_kaggle_marqo_dataset(max_samples)
567
+ if kaggle_dataset is None:
568
+ print("❌ Failed to load KAGL dataset")
569
+ return None
570
+
571
+ dataloader = DataLoader(kaggle_dataset, batch_size=8, shuffle=False, num_workers=0, collate_fn=collate_fn_filter_none)
572
+
573
+ results = {}
574
+
575
+ # ========== EXTRACT BASELINE EMBEDDINGS ==========
576
+ print("\n📦 Extracting baseline embeddings...")
577
+ text_full_embeddings, text_colors_full = self.extract_color_embeddings(dataloader, embedding_type='text', max_samples=max_samples)
578
+ image_full_embeddings, image_colors_full = self.extract_color_embeddings(dataloader, embedding_type='image', max_samples=max_samples)
579
+ text_color_metrics = self.compute_similarity_metrics(text_full_embeddings, text_colors_full)
580
+ text_color_class = self.evaluate_classification_performance(
581
+ text_full_embeddings, text_colors_full,
582
+ "Text Color Embeddings (Baseline)", "Color",
583
+ )
584
+ text_color_metrics.update(text_color_class)
585
+ results['text_color'] = text_color_metrics
586
+ image_color_metrics = self.compute_similarity_metrics(image_full_embeddings, image_colors_full)
587
+ image_color_class = self.evaluate_classification_performance(
588
+ image_full_embeddings, image_colors_full,
589
+ "Image Color Embeddings (Baseline)", "Color",
590
+ )
591
+ image_color_metrics.update(image_color_class)
592
+ results['image_color'] = image_color_metrics
593
+ del text_full_embeddings, image_full_embeddings
594
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
595
+
596
+ # ========== SAVE VISUALIZATIONS ==========
597
+ os.makedirs(self.directory, exist_ok=True)
598
+ for key in ['text_color', 'image_color']:
599
+ results[key]['figure'].savefig(
600
+ f"{self.directory}/kaggle_{key.replace('_', '_')}_confusion_matrix.png",
601
+ dpi=300,
602
+ bbox_inches='tight',
603
+ )
604
+ plt.close(results[key]['figure'])
605
+
606
+ return results
607
+
608
+ def evaluate_local_validation(self, max_samples):
609
+ """Evaluate both color embeddings on local validation dataset"""
610
+ print(f"\n{'='*60}")
611
+ print("Evaluating Local Validation Dataset")
612
+ print(" Color embeddings")
613
+ print(f"Max samples: {max_samples}")
614
+ print(f"{'='*60}")
615
+
616
+ local_dataset = load_local_validation_dataset(max_samples)
617
+ dataloader = DataLoader(local_dataset, batch_size=8, shuffle=False, num_workers=0)
618
+
619
+ results = {}
620
+
621
+ # ========== COLOR EVALUATION ==========
622
+ print("\n🎨 COLOR EVALUATION ")
623
+ print("=" * 50)
624
+
625
+ # Text color embeddings
626
+ print("\n📝 Extracting text color embeddings...")
627
+ text_color_embeddings, text_colors = self.extract_color_embeddings(dataloader, 'text', max_samples)
628
+ print(f" Text color embeddings shape: {text_color_embeddings.shape}")
629
+ text_color_metrics = self.compute_similarity_metrics(text_color_embeddings, text_colors)
630
+ text_color_class = self.evaluate_classification_performance(
631
+ text_color_embeddings, text_colors, "Text Color Embeddings (Baseline)", "Color"
632
+ )
633
+ text_color_metrics.update(text_color_class)
634
+ results['text_color'] = text_color_metrics
635
+
636
+ del text_color_embeddings
637
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
638
+
639
+ # Image color embeddings
640
+ print("\n🖼️ Extracting image color embeddings...")
641
+ image_color_embeddings, image_colors = self.extract_color_embeddings(dataloader, 'image', max_samples)
642
+ print(f" Image color embeddings shape: {image_color_embeddings.shape}")
643
+ image_color_metrics = self.compute_similarity_metrics(image_color_embeddings, image_colors)
644
+ image_color_class = self.evaluate_classification_performance(
645
+ image_color_embeddings, image_colors, "Image Color Embeddings (Baseline)", "Color"
646
+ )
647
+ image_color_metrics.update(image_color_class)
648
+ results['image_color'] = image_color_metrics
649
+
650
+ del image_color_embeddings
651
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
652
+ # ========== SAVE VISUALIZATIONS ==========
653
+ os.makedirs(self.directory, exist_ok=True)
654
+ for key in ['text_color', 'image_color']:
655
+ results[key]['figure'].savefig(
656
+ f"{self.directory}/local_{key.replace('_', '_')}_confusion_matrix.png",
657
+ dpi=300,
658
+ bbox_inches='tight',
659
+ )
660
+ plt.close(results[key]['figure'])
661
+
662
+ return results
663
+
664
+
665
+ def evaluate_baseline_kaggle_marqo(self, max_samples=5000):
666
+ """Evaluate baseline Fashion CLIP model on KAGL Marqo dataset"""
667
+ print(f"\n{'='*60}")
668
+ print("Evaluating Baseline Fashion CLIP on KAGL Marqo Dataset")
669
+ print(f"Max samples: {max_samples}")
670
+ print(f"{'='*60}")
671
+
672
+ # Load KAGL Marqo dataset
673
+ kaggle_dataset = load_kaggle_marqo_dataset(max_samples)
674
+ if kaggle_dataset is None:
675
+ print("❌ Failed to load KAGL dataset")
676
+ return None
677
+
678
+ # Create dataloader
679
+ dataloader = DataLoader(kaggle_dataset, batch_size=8, shuffle=False, num_workers=0, collate_fn=collate_fn_filter_none)
680
+
681
+ results = {}
682
+
683
+ # Evaluate text embeddings
684
+ print("\n📝 Extracting baseline text embeddings from KAGL Marqo...")
685
+ text_embeddings, text_colors = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples)
686
+ print(f" Baseline text embeddings shape: {text_embeddings.shape} (using all {text_embeddings.shape[1]} dimensions)")
687
+ text_color_metrics = self.compute_similarity_metrics(text_embeddings, text_colors)
688
+
689
+ text_color_classification = self.evaluate_classification_performance(
690
+ text_embeddings, text_colors, "Baseline KAGL Marqo Text Embeddings - Color", "Color"
691
+ )
692
+ text_color_metrics.update(text_color_classification)
693
+ results['text'] = {
694
+ 'color': text_color_metrics
695
+ }
696
+
697
+ # Clear memory
698
+ del text_embeddings
699
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
700
+
701
+ # Evaluate image embeddings
702
+ print("\n🖼️ Extracting baseline image embeddings from KAGL Marqo...")
703
+ image_embeddings, image_colors = self.extract_baseline_embeddings_batch(dataloader, 'image', max_samples)
704
+ print(f" Baseline image embeddings shape: {image_embeddings.shape} (using all {image_embeddings.shape[1]} dimensions)")
705
+ image_color_metrics = self.compute_similarity_metrics(image_embeddings, image_colors)
706
+
707
+ image_color_classification = self.evaluate_classification_performance(
708
+ image_embeddings, image_colors, "Baseline KAGL Marqo Image Embeddings - Color", "Color"
709
+ )
710
+ image_color_metrics.update(image_color_classification)
711
+ results['image'] = {
712
+ 'color': image_color_metrics
713
+ }
714
+
715
+ # Clear memory
716
+ del image_embeddings
717
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
718
+
719
+ # ========== SAVE VISUALIZATIONS ==========
720
+ os.makedirs(self.directory, exist_ok=True)
721
+ for key in ['text', 'image']:
722
+ for subkey in ['color']:
723
+ figure = results[key][subkey]['figure']
724
+ figure.savefig(
725
+ f"{self.directory}/kaggle_baseline_{key}_{subkey}_confusion_matrix.png",
726
+ dpi=300,
727
+ bbox_inches='tight',
728
+ )
729
+ plt.close(figure)
730
+
731
+ return results
732
+
733
+ def evaluate_baseline_local_validation(self, max_samples=5000):
734
+ """Evaluate baseline Fashion CLIP model on local validation dataset"""
735
+ print(f"\n{'='*60}")
736
+ print("Evaluating Baseline Fashion CLIP on Local Validation Dataset")
737
+ print(f"Max samples: {max_samples}")
738
+ print(f"{'='*60}")
739
+
740
+ # Load local validation dataset
741
+ local_dataset = load_local_validation_dataset(max_samples)
742
+ if local_dataset is None:
743
+ print("❌ Failed to load local validation dataset")
744
+ return None
745
+
746
+ # Create dataloader
747
+ dataloader = DataLoader(local_dataset, batch_size=8, shuffle=False, num_workers=0)
748
+
749
+ results = {}
750
+
751
+ # Evaluate text embeddings
752
+ print("\n📝 Extracting baseline text embeddings from Local Validation...")
753
+ text_embeddings, text_colors = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples)
754
+ print(f" Baseline text embeddings shape: {text_embeddings.shape} (using all {text_embeddings.shape[1]} dimensions)")
755
+ text_color_metrics = self.compute_similarity_metrics(text_embeddings, text_colors)
756
+
757
+ text_color_classification = self.evaluate_classification_performance(
758
+ text_embeddings, text_colors, "Baseline Local Validation Text Embeddings - Color", "Color"
759
+ )
760
+ text_color_metrics.update(text_color_classification)
761
+ results['text'] = {
762
+ 'color': text_color_metrics
763
+ }
764
+
765
+ # Clear memory
766
+ del text_embeddings
767
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
768
+
769
+ # Evaluate image embeddings
770
+ print("\n🖼️ Extracting baseline image embeddings from Local Validation...")
771
+ image_embeddings, image_colors = self.extract_baseline_embeddings_batch(dataloader, 'image', max_samples)
772
+ print(f" Baseline image embeddings shape: {image_embeddings.shape} (using all {image_embeddings.shape[1]} dimensions)")
773
+ image_color_metrics = self.compute_similarity_metrics(image_embeddings, image_colors)
774
+
775
+ image_color_classification = self.evaluate_classification_performance(
776
+ image_embeddings, image_colors, "Baseline Local Validation Image Embeddings - Color", "Color"
777
+ )
778
+ image_color_metrics.update(image_color_classification)
779
+ results['image'] = {
780
+ 'color': image_color_metrics
781
+ }
782
+
783
+ # Clear memory
784
+ del image_embeddings
785
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
786
+
787
+ # ========== SAVE VISUALIZATIONS ==========
788
+ os.makedirs(self.directory, exist_ok=True)
789
+ for key in ['text', 'image']:
790
+ for subkey in ['color']:
791
+ figure = results[key][subkey]['figure']
792
+ figure.savefig(
793
+ f"{self.directory}/local_baseline_{key}_{subkey}_confusion_matrix.png",
794
+ dpi=300,
795
+ bbox_inches='tight',
796
+ )
797
+ plt.close(figure)
798
+
799
+ return results
800
+
801
+ def analyze_baseline_vs_trained_performance(self, results_trained, results_baseline, dataset_name):
802
+ """
803
+ Analyse et explique pourquoi la baseline peut performer mieux que le modèle entraîné
804
+
805
+ Raisons possibles:
806
+ 1. Capacité dimensionnelle: Baseline utilise toutes les dimensions (512), modèle entraîné utilise seulement des sous-espaces (17 ou 64 dims)
807
+ 2. Distribution shift: Dataset de validation différent de celui d'entraînement
808
+ 3. Overfitting: Modèle trop spécialisé sur le dataset d'entraînement
809
+ 4. Généralisation: Baseline pré-entraînée sur un dataset plus large et diversifié
810
+ 5. Perte d'information: Spécialisation excessive peut causer perte d'information générale
811
+ """
812
+ print(f"\n{'='*60}")
813
+ print(f"📊 ANALYSE: Baseline vs Modèle Entraîné - {dataset_name}")
814
+ print(f"{'='*60}")
815
+
816
+ # Comparer les métriques pour chaque type d'embedding
817
+ comparisons = []
818
+
819
+ # Text Color
820
+ trained_color_text_acc = results_trained.get('text_color', {}).get('accuracy', 0)
821
+ baseline_color_text_acc = results_baseline.get('text', {}).get('color', {}).get('accuracy', 0)
822
+ if trained_color_text_acc > 0 and baseline_color_text_acc > 0:
823
+ diff = baseline_color_text_acc - trained_color_text_acc
824
+ comparisons.append({
825
+ 'type': 'Text Color',
826
+ 'trained': trained_color_text_acc,
827
+ 'baseline': baseline_color_text_acc,
828
+ 'diff': diff,
829
+ 'trained_dims': '0-15 (16 dims)',
830
+ 'baseline_dims': 'All dimensions (512 dims)'
831
+ })
832
+
833
+ # Image Color
834
+ trained_color_img_acc = results_trained.get('image_color', {}).get('accuracy', 0)
835
+ baseline_color_img_acc = results_baseline.get('image', {}).get('color', {}).get('accuracy', 0)
836
+ if trained_color_img_acc > 0 and baseline_color_img_acc > 0:
837
+ diff = baseline_color_img_acc - trained_color_img_acc
838
+ comparisons.append({
839
+ 'type': 'Image Color',
840
+ 'trained': trained_color_img_acc,
841
+ 'baseline': baseline_color_img_acc,
842
+ 'diff': diff,
843
+ 'trained_dims': '0-15 (16 dims)',
844
+ 'baseline_dims': 'All dimensions (512 dims)'
845
+ })
846
+
847
+ return comparisons
848
+
849
+
850
+
851
+ if __name__ == "__main__":
852
+ device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
853
+ print(f"Using device: {device}")
854
+
855
+ directory = 'color_model_analysis'
856
+ max_samples = 10000
857
+
858
+ evaluator = ColorEvaluator(device=device, directory=directory)
859
+
860
+ # Evaluate KAGL Marqo
861
+ print("\n" + "="*60)
862
+ print("🚀 Starting evaluation of KAGL Marqo with Color embeddings")
863
+ print("="*60)
864
+ results_kaggle = evaluator.evaluate_kaggle_marqo(max_samples=max_samples)
865
+
866
+ print(f"\n{'='*60}")
867
+ print("KAGL MARQO EVALUATION SUMMARY")
868
+ print(f"{'='*60}")
869
+
870
+ print("\n🎨 COLOR CLASSIFICATION RESULTS:")
871
+ print(f" Text - NN Acc: {results_kaggle['text_color']['accuracy']*100:.1f}% | Centroid Acc: {results_kaggle['text_color']['centroid_accuracy']*100:.1f}% | Separation: {results_kaggle['text_color']['separation_score']:.4f}")
872
+ print(f" Image - NN Acc: {results_kaggle['image_color']['accuracy']*100:.1f}% | Centroid Acc: {results_kaggle['image_color']['centroid_accuracy']*100:.1f}% | Separation: {results_kaggle['image_color']['separation_score']:.4f}")
873
+
874
+ # Evaluate Baseline Fashion CLIP on KAGL Marqo
875
+ print("\n" + "="*60)
876
+ print("🚀 Starting evaluation of Baseline Fashion CLIP on KAGL Marqo")
877
+ print("="*60)
878
+ results_baseline_kaggle = evaluator.evaluate_baseline_kaggle_marqo(max_samples=max_samples)
879
+
880
+ print(f"\n{'='*60}")
881
+ print("BASELINE KAGL MARQO EVALUATION SUMMARY")
882
+ print(f"{'='*60}")
883
+
884
+ print("\n🎨 COLOR CLASSIFICATION RESULTS (Baseline):")
885
+ print(f" Text - NN Acc: {results_baseline_kaggle['text']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_kaggle['text']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_kaggle['text']['color']['separation_score']:.4f}")
886
+ print(f" Image - NN Acc: {results_baseline_kaggle['image']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_kaggle['image']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_kaggle['image']['color']['separation_score']:.4f}")
887
+
888
+ # Evaluate Local Validation Dataset
889
+ print("\n" + "="*60)
890
+ print("🚀 Starting evaluation of Local Validation Dataset with Color embeddings")
891
+ print("="*60)
892
+ results_local = evaluator.evaluate_local_validation(max_samples=max_samples)
893
+
894
+ if results_local is not None:
895
+ print(f"\n{'='*60}")
896
+ print("LOCAL VALIDATION DATASET EVALUATION SUMMARY")
897
+ print(f"{'='*60}")
898
+
899
+ print("\n🎨 COLOR CLASSIFICATION RESULTS:")
900
+ print(f" Text - NN Acc: {results_local['text_color']['accuracy']*100:.1f}% | Centroid Acc: {results_local['text_color']['centroid_accuracy']*100:.1f}% | Separation: {results_local['text_color']['separation_score']:.4f}")
901
+ print(f" Image - NN Acc: {results_local['image_color']['accuracy']*100:.1f}% | Centroid Acc: {results_local['image_color']['centroid_accuracy']*100:.1f}% | Separation: {results_local['image_color']['separation_score']:.4f}")
902
+
903
+ # Evaluate Baseline Fashion CLIP on Local Validation
904
+ print("\n" + "="*60)
905
+ print("🚀 Starting evaluation of Baseline Fashion CLIP on Local Validation")
906
+ print("="*60)
907
+ results_baseline_local = evaluator.evaluate_baseline_local_validation(max_samples=max_samples)
908
+
909
+ if results_baseline_local is not None:
910
+ print(f"\n{'='*60}")
911
+ print("BASELINE LOCAL VALIDATION EVALUATION SUMMARY")
912
+ print(f"{'='*60}")
913
+
914
+ print("\n🎨 COLOR CLASSIFICATION RESULTS (Baseline):")
915
+ print(f" Text - NN Acc: {results_baseline_local['text']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_local['text']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_local['text']['color']['separation_score']:.4f}")
916
+ print(f" Image - NN Acc: {results_baseline_local['image']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_local['image']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_local['image']['color']['separation_score']:.4f}")
917
+
918
+
919
+ print(f"\n✅ Evaluation completed! Check '{directory}/' for visualization files.")