Leacb4 committed on
Commit
61f1c2b
·
verified ·
1 Parent(s): 11dbd66

Upload evaluation/evaluate_color_embeddings.py with huggingface_hub

evaluation/evaluate_color_embeddings.py ADDED
@@ -0,0 +1,1124 @@
+ """
+ Comprehensive evaluation of color embeddings with a Fashion-CLIP comparison.
+
+ This script evaluates the quality of the color embeddings produced by the
+ ColorCLIP model by computing intra-class and inter-class similarity metrics
+ and classification accuracies, and by generating confusion matrices. It also
+ evaluates Fashion-CLIP as a baseline to measure relative performance.
+ """
+
+ import torch
+ import torch.nn as nn
+ import pandas as pd
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ from sklearn.metrics.pairwise import cosine_similarity
+ from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
+ from collections import defaultdict
+ import os
+ import json
+ from tqdm import tqdm
+ from torch.utils.data import Dataset, DataLoader
+ from torchvision import transforms
+ import requests
+ from io import BytesIO
+ from PIL import Image
+ import warnings
+ warnings.filterwarnings('ignore')
+ from color_model import ColorCLIP, Tokenizer, ImageEncoder, TextEncoder, collate_batch
+ from transformers import CLIPProcessor, CLIPModel as TransformersCLIPModel
+ import config
+
+ class ColorDataset(Dataset):
+     """
+     Dataset class for color embedding evaluation.
+
+     Handles loading images from various sources (local paths, URLs, raw bytes)
+     and applying the transformations used for evaluation.
+     """
+     def __init__(self, dataframe):
+         """
+         Initialize the color dataset.
+
+         Args:
+             dataframe: DataFrame containing image paths/URLs, text, and color labels
+         """
+         self.dataframe = dataframe
+         self.transform = transforms.Compose([
+             transforms.Resize((224, 224)),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+         ])
+
+     def __len__(self):
+         return len(self.dataframe)
+
+     def __getitem__(self, idx):
+         row = self.dataframe.iloc[idx]
+
+         # The image lives in row[config.column_url_image] and may arrive in
+         # several forms: a dict with raw bytes, a PIL image, a path/URL, or bytes
+         image_data = row[config.column_url_image]
+
+         try:
+             # Check if image_data has a 'bytes' key or is already a PIL Image
+             if isinstance(image_data, dict) and 'bytes' in image_data:
+                 image = Image.open(BytesIO(image_data['bytes'])).convert("RGB")
+             elif hasattr(image_data, 'convert'):  # Already a PIL Image
+                 image = image_data.convert("RGB")
+             elif isinstance(image_data, str):
+                 # It's a file path (local or URL)
+                 if image_data.startswith('http'):
+                     # It's a URL - download the image
+                     response = requests.get(image_data, timeout=10)
+                     response.raise_for_status()
+                     image = Image.open(BytesIO(response.content)).convert("RGB")
+                 else:
+                     # It's a local file path
+                     image = Image.open(image_data).convert("RGB")
+             else:
+                 # Assume raw bytes
+                 image = Image.open(BytesIO(image_data)).convert("RGB")
+
+             # Apply the transform
+             image = self.transform(image)
+
+         except Exception as e:
+             print(f"⚠️ Failed to load image {idx}: {e}")
+             # Return a black placeholder image so the batch still collates
+             image = torch.zeros(3, 224, 224)
+
+         # Get text and color
+         description = row[config.text_column]
+         color = row[config.color_column]
+
+         return image, description, color
+
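+ # Note: with the default DataLoader collate function, each batch arrives as
+ # (images: FloatTensor[B, 3, 224, 224], texts: list[str], colors: list[str]),
+ # which is the shape extract_embeddings() below expects.
+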
+ class EmbeddingEvaluator:
+     """
+     Evaluator for color embeddings generated by the ColorCLIP model.
+
+     This class provides methods to evaluate the quality of color embeddings by computing
+     similarity metrics, classification accuracies, and generating visualizations.
+     """
+
+     def __init__(self, model_path, embed_dim):
+         """
+         Initialize the embedding evaluator.
+
+         Args:
+             model_path: Path to the trained ColorCLIP model checkpoint
+             embed_dim: Embedding dimension for the model
+         """
+         self.device = config.device
+
+         # Initialize the tokenizer first to get the vocab size
+         self.tokenizer = Tokenizer()
+         vocab_size = None
+
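+         # The vocabulary file is assumed to be a flat JSON mapping from token to
+         # index, e.g. {"red": 1, "blue": 2, ...}, with index 0 reserved for
+         # unknown/padding tokens (hence the defaultdict fallback to 0 below).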
+         # Load vocabulary if available to determine vocab_size
+         if os.path.exists(config.tokeniser_path):
+             with open(config.tokeniser_path, 'r') as f:
+                 vocab_dict = json.load(f)
+             # Manually load the vocabulary
+             self.tokenizer.word2idx = defaultdict(lambda: 0, {k: int(v) for k, v in vocab_dict.items()})
+             self.tokenizer.idx2word = {int(v): k for k, v in vocab_dict.items() if int(v) > 0}
+             self.tokenizer.counter = max(self.tokenizer.word2idx.values(), default=0) + 1
+             vocab_size = self.tokenizer.counter
+             print(f"Tokenizer vocabulary loaded from {config.tokeniser_path}")
+         else:
+             print(f"Warning: {config.tokeniser_path} not found. Using default tokenizer.")
+
+         # Load the checkpoint to get the vocab_size and state_dict
+         checkpoint = None
+         if os.path.exists(model_path):
+             checkpoint = torch.load(model_path, map_location=self.device)
+
+             # Try to get vocab_size from the checkpoint if not already determined
+             if vocab_size is None:
+                 # Prefer explicit metadata
+                 if isinstance(checkpoint, dict) and 'vocab_size' in checkpoint:
+                     vocab_size = checkpoint['vocab_size']
+                 # Otherwise, infer it from the embedding matrix in the state dict
+                 elif isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
+                     state_dict = checkpoint['model_state_dict']
+                     if 'text_encoder.embedding.weight' in state_dict:
+                         vocab_size = state_dict['text_encoder.embedding.weight'].shape[0]
+                 elif isinstance(checkpoint, dict) and 'text_encoder.embedding.weight' in checkpoint:
+                     vocab_size = checkpoint['text_encoder.embedding.weight'].shape[0]
+
+         # Fall back to a default if the vocab size is still unknown
+         if vocab_size is None:
+             vocab_size = 39  # Default fallback
+             print(f"Warning: Could not determine vocab_size, using default: {vocab_size}")
+
+         # Initialize the model with the determined vocab_size
+         self.model = ColorCLIP(vocab_size=vocab_size, embedding_dim=embed_dim).to(self.device)
+
+         # Load the trained weights
+         if checkpoint is not None:
+             state_dict = checkpoint.get('model_state_dict', checkpoint)
+             self.model.load_state_dict(state_dict)
+             print(f"Model loaded from {model_path}")
+         else:
+             print(f"Warning: Model file {model_path} not found. Using untrained model.")
+
+         self.model.eval()
+
+     def extract_embeddings(self, dataloader, embedding_type='text'):
+         """
+         Extract embeddings from the model for a given dataloader.
+
+         Args:
+             dataloader: DataLoader yielding (images, texts, colors) batches
+             embedding_type: Type of embeddings to extract ('text', 'image', or 'color')
+
+         Returns:
+             Tuple of (embeddings array, labels list, texts list)
+         """
+         all_embeddings = []
+         all_labels = []
+         all_texts = []
+
+         with torch.no_grad():
+             for images, texts, colors in tqdm(dataloader, desc=f"Extracting {embedding_type} embeddings"):
+                 if embedding_type == 'text':
+                     # Tokenize the descriptions, then pad them into a single batch
+                     tokenized_texts = [self.tokenizer(text) for text in texts]
+                     text_tensors = [torch.tensor(t, dtype=torch.long) for t in tokenized_texts]
+                     text_tokens = nn.utils.rnn.pad_sequence(text_tensors, batch_first=True, padding_value=0).to(self.device)
+                     lengths = torch.tensor([len(t) for t in tokenized_texts], dtype=torch.long).to(self.device)
+                     embeddings = self.model.text_encoder(text_tokens, lengths)
+                     labels = colors
+                 elif embedding_type == 'image':
+                     images = images.to(self.device)
+                     embeddings = self.model.image_encoder(images)
+                     labels = colors
+                 elif embedding_type == 'color':
+                     # Tokenize the color names, then pad them into a single batch
+                     tokenized_colors = [self.tokenizer(color) for color in colors]
+                     color_tensors = [torch.tensor(t, dtype=torch.long) for t in tokenized_colors]
+                     color_tokens = nn.utils.rnn.pad_sequence(color_tensors, batch_first=True, padding_value=0).to(self.device)
+                     lengths = torch.tensor([len(t) for t in tokenized_colors], dtype=torch.long).to(self.device)
+                     embeddings = self.model.text_encoder(color_tokens, lengths)
+                     labels = colors
+
+                 all_embeddings.append(embeddings.cpu().numpy())
+                 all_labels.extend(labels)
+                 all_texts.extend(texts)
+
+         return np.vstack(all_embeddings), all_labels, all_texts
+
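+     # Note: unlike the Fashion-CLIP path below, these embeddings are not
+     # explicitly L2-normalized. Cosine similarity is scale-invariant, so the
+     # similarity metrics remain directly comparable between the two models.
+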
+     def compute_similarity_metrics(self, embeddings, labels):
+         """Compute intra-class and inter-class similarities"""
+         similarities = cosine_similarity(embeddings)
+
+         # Group embedding indices by color
+         color_groups = defaultdict(list)
+         for i, color in enumerate(labels):
+             color_groups[color].append(i)
+
+         # Intra-class similarities (pairs with the same color)
+         intra_class_similarities = []
+         for color, indices in color_groups.items():
+             if len(indices) > 1:
+                 for i in range(len(indices)):
+                     for j in range(i + 1, len(indices)):
+                         intra_class_similarities.append(similarities[indices[i], indices[j]])
+
+         # Inter-class similarities (pairs with different colors)
+         inter_class_similarities = []
+         colors = list(color_groups.keys())
+         for i in range(len(colors)):
+             for j in range(i + 1, len(colors)):
+                 color1_indices = color_groups[colors[i]]
+                 color2_indices = color_groups[colors[j]]
+                 for idx1 in color1_indices:
+                     for idx2 in color2_indices:
+                         inter_class_similarities.append(similarities[idx1, idx2])
+
+         # Classification accuracy using the nearest neighbor in embedding space
+         nn_accuracy = self.compute_embedding_accuracy(embeddings, labels, similarities)
+
+         # Classification accuracy using per-color centroids
+         centroid_accuracy = self.compute_centroid_accuracy(embeddings, labels)
+
+         return {
+             'intra_class_similarities': intra_class_similarities,
+             'inter_class_similarities': inter_class_similarities,
+             'intra_class_mean': np.mean(intra_class_similarities) if intra_class_similarities else 0,
+             'inter_class_mean': np.mean(inter_class_similarities) if inter_class_similarities else 0,
+             'separation_score': np.mean(intra_class_similarities) - np.mean(inter_class_similarities) if intra_class_similarities and inter_class_similarities else 0,
+             'accuracy': nn_accuracy,
+             'centroid_accuracy': centroid_accuracy
+         }
+
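+     # For intuition: if same-color pairs average 0.80 cosine similarity and
+     # different-color pairs average 0.35, the separation score is 0.45.
+     # Higher is better; a score near 0 means the colors are not separated.
+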
+     def compute_embedding_accuracy(self, embeddings, labels, similarities):
+         """Compute classification accuracy using the nearest neighbor in embedding space"""
+         correct_predictions = 0
+         total_predictions = len(labels)
+
+         for i in range(len(embeddings)):
+             true_label = labels[i]
+
+             # Find the most similar embedding (excluding the sample itself)
+             similarities_row = similarities[i].copy()
+             similarities_row[i] = -1  # Exclude self-similarity
+             nearest_neighbor_idx = np.argmax(similarities_row)
+             predicted_label = labels[nearest_neighbor_idx]
+
+             if predicted_label == true_label:
+                 correct_predictions += 1
+
+         return correct_predictions / total_predictions if total_predictions > 0 else 0
+
+     def compute_centroid_accuracy(self, embeddings, labels):
+         """Compute classification accuracy using color centroids"""
+         # Compute a centroid for each color
+         unique_colors = list(set(labels))
+         centroids = {}
+
+         for color in unique_colors:
+             color_indices = [i for i, label in enumerate(labels) if label == color]
+             centroids[color] = np.mean(embeddings[color_indices], axis=0)
+
+         # Classify each embedding to its nearest centroid
+         correct_predictions = 0
+         total_predictions = len(labels)
+
+         for i, embedding in enumerate(embeddings):
+             true_label = labels[i]
+
+             # Find the closest centroid by cosine similarity
+             best_similarity = -1
+             predicted_label = None
+
+             for color, centroid in centroids.items():
+                 similarity = cosine_similarity([embedding], [centroid])[0][0]
+                 if similarity > best_similarity:
+                     best_similarity = similarity
+                     predicted_label = color
+
+             if predicted_label == true_label:
+                 correct_predictions += 1
+
+         return correct_predictions / total_predictions if total_predictions > 0 else 0
+
+     def predict_colors_from_embeddings(self, embeddings, labels):
+         """Predict colors from embeddings using centroid-based classification"""
+         # Build color centroids from the labelled embeddings
+         unique_colors = list(set(labels))
+         centroids = {}
+
+         for color in unique_colors:
+             color_indices = [i for i, label in enumerate(labels) if label == color]
+             centroids[color] = np.mean(embeddings[color_indices], axis=0)
+
+         # Predict a color for every embedding
+         predictions = []
+
+         for embedding in embeddings:
+             # Find the closest centroid by cosine similarity
+             best_similarity = -1
+             predicted_color = None
+
+             for color, centroid in centroids.items():
+                 similarity = cosine_similarity([embedding], [centroid])[0][0]
+                 if similarity > best_similarity:
+                     best_similarity = similarity
+                     predicted_color = color
+
+             predictions.append(predicted_color)
+
+         return predictions
+
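+     # Note: the centroids above are built from the same labelled embeddings that
+     # are subsequently classified, so these accuracies are in-sample and should
+     # be read as slightly optimistic.
+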
+     def create_confusion_matrix(self, true_labels, predicted_labels, title="Confusion Matrix"):
+         """Create and plot a confusion matrix"""
+         # Use the union of true and predicted labels as the axis ordering
+         unique_labels = sorted(set(true_labels + predicted_labels))
+
+         cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels)
+         accuracy = accuracy_score(true_labels, predicted_labels)
+
+         # Plot the confusion matrix as a heatmap
+         plt.figure(figsize=(12, 10))
+         sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
+                     xticklabels=unique_labels, yticklabels=unique_labels)
+         plt.title(f'{title}\nAccuracy: {accuracy:.3f} ({accuracy*100:.1f}%)')
+         plt.ylabel('True Color')
+         plt.xlabel('Predicted Color')
+         plt.xticks(rotation=45)
+         plt.yticks(rotation=0)
+         plt.tight_layout()
+
+         return plt.gcf(), accuracy, cm
+
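+     # The caller owns the returned figure: evaluate_dataset() below saves each
+     # one to disk and then closes it with plt.close() so figures do not pile up
+     # across the three embedding types.
+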
+     def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings"):
+         """Evaluate classification performance and create a confusion matrix"""
+         # Predict colors via the centroid classifier
+         predictions = self.predict_colors_from_embeddings(embeddings, labels)
+
+         accuracy = accuracy_score(labels, predictions)
+
+         fig, acc, cm = self.create_confusion_matrix(labels, predictions,
+                                                     f"{embedding_type} - Color Classification")
+
+         # Per-class precision/recall/F1 report
+         unique_labels = sorted(set(labels))
+         report = classification_report(labels, predictions, labels=unique_labels,
+                                        target_names=unique_labels, output_dict=True)
+
+         return {
+             'accuracy': accuracy,
+             'predictions': predictions,
+             'confusion_matrix': cm,
+             'classification_report': report,
+             'figure': fig
+         }
+
+     def evaluate_dataset(self, dataframe, dataset_name="Dataset"):
+         """
+         Evaluate embeddings on a given dataset.
+
+         This method extracts text, image, and color embeddings, computes similarity
+         metrics, evaluates classification performance, and saves confusion matrices.
+
+         Args:
+             dataframe: DataFrame containing the dataset
+             dataset_name: Name of the dataset for display purposes
+
+         Returns:
+             Dictionary containing evaluation results for text, image, and color embeddings
+         """
+         print(f"\n{'='*60}")
+         print(f"Evaluating {dataset_name}")
+         print(f"{'='*60}")
+
+         # Create the dataset and dataloader - use KaglDataset for KAGL data
+         if "kagl" in dataset_name.lower():
+             dataset = KaglDataset(dataframe)
+         else:
+             dataset = ColorDataset(dataframe)
+         # Larger batches and multiple workers speed up extraction
+         dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)
+
+         results = {}
+
+         # Evaluate text embeddings
+         text_embeddings, text_labels, texts = self.extract_embeddings(dataloader, 'text')
+         text_metrics = self.compute_similarity_metrics(text_embeddings, text_labels)
+         text_classification = self.evaluate_classification_performance(text_embeddings, text_labels, "Text Embeddings")
+         text_metrics.update(text_classification)
+         results['text'] = text_metrics
+
+         # Evaluate image embeddings
+         image_embeddings, image_labels, _ = self.extract_embeddings(dataloader, 'image')
+         image_metrics = self.compute_similarity_metrics(image_embeddings, image_labels)
+         image_classification = self.evaluate_classification_performance(image_embeddings, image_labels, "Image Embeddings")
+         image_metrics.update(image_classification)
+         results['image'] = image_metrics
+
+         # Evaluate color embeddings
+         color_embeddings, color_labels, _ = self.extract_embeddings(dataloader, 'color')
+         color_metrics = self.compute_similarity_metrics(color_embeddings, color_labels)
+         color_classification = self.evaluate_classification_performance(color_embeddings, color_labels, "Color Embeddings")
+         color_metrics.update(color_classification)
+         results['color'] = color_metrics
+
+         # Print results
+         print(f"\n{dataset_name} Results:")
+         print("-" * 40)
+         for emb_type, metrics in results.items():
+             print(f"{emb_type.capitalize()} Embeddings:")
+             print(f"  Intra-class similarity (same color): {metrics['intra_class_mean']:.4f}")
+             print(f"  Inter-class similarity (diff colors): {metrics['inter_class_mean']:.4f}")
+             print(f"  Separation score: {metrics['separation_score']:.4f}")
+             print(f"  Nearest Neighbor Accuracy: {metrics['accuracy']:.4f} ({metrics['accuracy']*100:.1f}%)")
+             print(f"  Centroid Accuracy: {metrics['centroid_accuracy']:.4f} ({metrics['centroid_accuracy']*100:.1f}%)")
+
+             # Classification report summary
+             report = metrics['classification_report']
+             print("  📊 Classification Performance:")
+             print(f"    • Macro Avg F1-Score: {report['macro avg']['f1-score']:.4f}")
+             print(f"    • Weighted Avg F1-Score: {report['weighted avg']['f1-score']:.4f}")
+             print(f"    • Support: {report['macro avg']['support']:.0f} samples")
+             print()
+
+         # Save the confusion matrices (spaces in the dataset name would otherwise
+         # end up in the filenames)
+         os.makedirs('embedding_evaluation', exist_ok=True)
+         file_prefix = dataset_name.lower().replace(' ', '_')
+
+         results['text']['figure'].savefig(f'embedding_evaluation/{file_prefix}_text_confusion_matrix.png', dpi=300, bbox_inches='tight')
+         plt.close(results['text']['figure'])
+
+         results['image']['figure'].savefig(f'embedding_evaluation/{file_prefix}_image_confusion_matrix.png', dpi=300, bbox_inches='tight')
+         plt.close(results['image']['figure'])
+
+         results['color']['figure'].savefig(f'embedding_evaluation/{file_prefix}_color_confusion_matrix.png', dpi=300, bbox_inches='tight')
+         plt.close(results['color']['figure'])
+
+         return results
+
+ class FashionCLIPDataset(Dataset):
+     """
+     Dataset for Fashion-CLIP evaluation that does not normalize images.
+
+     Fashion-CLIP's processor applies its own preprocessing, so this dataset only
+     resizes images and converts them to tensors (no normalization).
+     """
+     def __init__(self, dataframe):
+         """
+         Initialize the Fashion-CLIP dataset.
+
+         Args:
+             dataframe: DataFrame containing image paths/URLs, text, and color labels
+         """
+         self.dataframe = dataframe
+         # Only resize and convert to tensor; no normalization
+         self.transform = transforms.Compose([
+             transforms.Resize((224, 224)),
+             transforms.ToTensor()
+         ])
+
+     def __len__(self):
+         return len(self.dataframe)
+
+     def __getitem__(self, idx):
+         row = self.dataframe.iloc[idx]
+
+         # The image lives in row[config.column_url_image]
+         image_data = row[config.column_url_image]
+
+         try:
+             # Check if image_data has a 'bytes' key or is already a PIL Image
+             if isinstance(image_data, dict) and 'bytes' in image_data:
+                 image = Image.open(BytesIO(image_data['bytes'])).convert("RGB")
+             elif hasattr(image_data, 'convert'):  # Already a PIL Image
+                 image = image_data.convert("RGB")
+             elif isinstance(image_data, str):
+                 # It's a file path (local or URL)
+                 if image_data.startswith('http'):
+                     # It's a URL - download the image
+                     response = requests.get(image_data, timeout=10)
+                     response.raise_for_status()
+                     image = Image.open(BytesIO(response.content)).convert("RGB")
+                 else:
+                     # It's a local file path
+                     image = Image.open(image_data).convert("RGB")
+             else:
+                 # Assume raw bytes
+                 image = Image.open(BytesIO(image_data)).convert("RGB")
+
+             # Apply the minimal transform (no normalization)
+             image = self.transform(image)
+
+         except Exception as e:
+             print(f"⚠️ Failed to load image {idx}: {e}")
+             # Return a black placeholder image so the batch still collates
+             image = torch.zeros(3, 224, 224)
+
+         # Get text and color
+         description = row[config.text_column]
+         color = row[config.color_column]
+
+         return image, description, color
+
+ class FashionCLIPEvaluator:
+     """
+     Evaluator for the Fashion-CLIP baseline model.
+
+     Provides the same evaluation methods as EmbeddingEvaluator so the custom
+     ColorCLIP model and the Fashion-CLIP baseline can be compared directly.
+     """
+
+     def __init__(self):
+         """
+         Initialize the Fashion-CLIP evaluator.
+
+         Loads the Fashion-CLIP model from Hugging Face and prepares it for evaluation.
+         """
+         # Load the Fashion-CLIP model
+         patrick_model_name = "patrickjohncyh/fashion-clip"
+         print(f"🔄 Loading Fashion-CLIP model: {patrick_model_name}")
+         self.processor = CLIPProcessor.from_pretrained(patrick_model_name)
+         self.device = config.device
+         self.model = TransformersCLIPModel.from_pretrained(patrick_model_name).to(self.device)
+         self.model.eval()
+         print("✅ Fashion-CLIP model loaded successfully")
+
+     def extract_embeddings(self, dataloader, embedding_type='text'):
+         """
+         Extract embeddings from the Fashion-CLIP model.
+
+         Args:
+             dataloader: DataLoader yielding (images, texts, colors) batches
+             embedding_type: Type of embeddings to extract ('text', 'image', or 'color')
+
+         Returns:
+             Tuple of (embeddings array, labels list, texts list)
+         """
+         all_embeddings = []
+         all_labels = []
+         all_texts = []
+
+         with torch.no_grad():
+             for images, texts, colors in tqdm(dataloader, desc=f"Extracting {embedding_type} embeddings (Fashion-CLIP)"):
+                 if embedding_type == 'text':
+                     # Encode the descriptions with Fashion-CLIP's text tower
+                     inputs = self.processor(text=texts, return_tensors="pt", padding=True, truncation=True, max_length=77)
+                     inputs = {k: v.to(self.device) for k, v in inputs.items()}
+                     text_features = self.model.get_text_features(**inputs)
+                     text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+                     embeddings = text_features.cpu().numpy()
+                     labels = colors
+                 elif embedding_type == 'image':
+                     # Convert tensors back to PIL images for the CLIP processor
+                     pil_images = []
+                     for i in range(images.shape[0]):
+                         img_tensor = images[i]
+                         # KaglDataset applies ImageNet normalization; undo it so the
+                         # tensor is back in [0, 1] before converting to PIL
+                         if img_tensor.min() < 0 or img_tensor.max() > 1:
+                             mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
+                             std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
+                             img_tensor = img_tensor * std + mean
+                         img_tensor = torch.clamp(img_tensor, 0, 1)
+                         pil_images.append(transforms.ToPILImage()(img_tensor))
+
+                     # Encode the images with Fashion-CLIP's vision tower
+                     inputs = self.processor(images=pil_images, return_tensors="pt", padding=True)
+                     inputs = {k: v.to(self.device) for k, v in inputs.items()}
+                     image_features = self.model.get_image_features(**inputs)
+                     image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+                     embeddings = image_features.cpu().numpy()
+                     labels = colors
+                 elif embedding_type == 'color':
+                     # Encode the color names as text with Fashion-CLIP
+                     inputs = self.processor(text=colors, return_tensors="pt", padding=True, truncation=True, max_length=77)
+                     inputs = {k: v.to(self.device) for k, v in inputs.items()}
+                     text_features = self.model.get_text_features(**inputs)
+                     text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+                     embeddings = text_features.cpu().numpy()
+                     labels = colors
+
+                 all_embeddings.append(embeddings)
+                 all_labels.extend(labels)
+                 all_texts.extend(texts)
+
+         return np.vstack(all_embeddings), all_labels, all_texts
+
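+     # Note: the CLIPProcessor applies Fashion-CLIP's own resizing and
+     # normalization to the PIL images, so the dataset-level 224x224 resize is
+     # only a convenience, not the preprocessing the model actually sees.
+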
+     # The metric helpers below mirror EmbeddingEvaluator's implementations so
+     # that both models are scored in exactly the same way.
+     def compute_similarity_metrics(self, embeddings, labels):
+         """Compute intra-class and inter-class similarities"""
+         similarities = cosine_similarity(embeddings)
+
+         # Group embedding indices by color
+         color_groups = defaultdict(list)
+         for i, color in enumerate(labels):
+             color_groups[color].append(i)
+
+         # Intra-class similarities (pairs with the same color)
+         intra_class_similarities = []
+         for color, indices in color_groups.items():
+             if len(indices) > 1:
+                 for i in range(len(indices)):
+                     for j in range(i + 1, len(indices)):
+                         intra_class_similarities.append(similarities[indices[i], indices[j]])
+
+         # Inter-class similarities (pairs with different colors)
+         inter_class_similarities = []
+         colors = list(color_groups.keys())
+         for i in range(len(colors)):
+             for j in range(i + 1, len(colors)):
+                 color1_indices = color_groups[colors[i]]
+                 color2_indices = color_groups[colors[j]]
+                 for idx1 in color1_indices:
+                     for idx2 in color2_indices:
+                         inter_class_similarities.append(similarities[idx1, idx2])
+
+         # Classification accuracy using the nearest neighbor in embedding space
+         nn_accuracy = self.compute_embedding_accuracy(embeddings, labels, similarities)
+
+         # Classification accuracy using per-color centroids
+         centroid_accuracy = self.compute_centroid_accuracy(embeddings, labels)
+
+         return {
+             'intra_class_similarities': intra_class_similarities,
+             'inter_class_similarities': inter_class_similarities,
+             'intra_class_mean': np.mean(intra_class_similarities) if intra_class_similarities else 0,
+             'inter_class_mean': np.mean(inter_class_similarities) if inter_class_similarities else 0,
+             'separation_score': np.mean(intra_class_similarities) - np.mean(inter_class_similarities) if intra_class_similarities and inter_class_similarities else 0,
+             'accuracy': nn_accuracy,
+             'centroid_accuracy': centroid_accuracy
+         }
+
+     def compute_embedding_accuracy(self, embeddings, labels, similarities):
+         """Compute classification accuracy using the nearest neighbor in embedding space"""
+         correct_predictions = 0
+         total_predictions = len(labels)
+
+         for i in range(len(embeddings)):
+             true_label = labels[i]
+
+             # Find the most similar embedding (excluding the sample itself)
+             similarities_row = similarities[i].copy()
+             similarities_row[i] = -1  # Exclude self-similarity
+             nearest_neighbor_idx = np.argmax(similarities_row)
+             predicted_label = labels[nearest_neighbor_idx]
+
+             if predicted_label == true_label:
+                 correct_predictions += 1
+
+         return correct_predictions / total_predictions if total_predictions > 0 else 0
+
+     def compute_centroid_accuracy(self, embeddings, labels):
+         """Compute classification accuracy using color centroids"""
+         # Compute a centroid for each color
+         unique_colors = list(set(labels))
+         centroids = {}
+
+         for color in unique_colors:
+             color_indices = [i for i, label in enumerate(labels) if label == color]
+             centroids[color] = np.mean(embeddings[color_indices], axis=0)
+
+         # Classify each embedding to its nearest centroid
+         correct_predictions = 0
+         total_predictions = len(labels)
+
+         for i, embedding in enumerate(embeddings):
+             true_label = labels[i]
+
+             # Find the closest centroid by cosine similarity
+             best_similarity = -1
+             predicted_label = None
+
+             for color, centroid in centroids.items():
+                 similarity = cosine_similarity([embedding], [centroid])[0][0]
+                 if similarity > best_similarity:
+                     best_similarity = similarity
+                     predicted_label = color
+
+             if predicted_label == true_label:
+                 correct_predictions += 1
+
+         return correct_predictions / total_predictions if total_predictions > 0 else 0
+
+     def predict_colors_from_embeddings(self, embeddings, labels):
+         """Predict colors from embeddings using centroid-based classification"""
+         # Build color centroids from the labelled embeddings
+         unique_colors = list(set(labels))
+         centroids = {}
+
+         for color in unique_colors:
+             color_indices = [i for i, label in enumerate(labels) if label == color]
+             centroids[color] = np.mean(embeddings[color_indices], axis=0)
+
+         # Predict a color for every embedding
+         predictions = []
+
+         for embedding in embeddings:
+             # Find the closest centroid by cosine similarity
+             best_similarity = -1
+             predicted_color = None
+
+             for color, centroid in centroids.items():
+                 similarity = cosine_similarity([embedding], [centroid])[0][0]
+                 if similarity > best_similarity:
+                     best_similarity = similarity
+                     predicted_color = color
+
+             predictions.append(predicted_color)
+
+         return predictions
+
+     def create_confusion_matrix(self, true_labels, predicted_labels, title="Confusion Matrix"):
+         """Create and plot a confusion matrix"""
+         # Use the union of true and predicted labels as the axis ordering
+         unique_labels = sorted(set(true_labels + predicted_labels))
+
+         cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels)
+         accuracy = accuracy_score(true_labels, predicted_labels)
+
+         # Plot the confusion matrix as a heatmap
+         plt.figure(figsize=(12, 10))
+         sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
+                     xticklabels=unique_labels, yticklabels=unique_labels)
+         plt.title(f'{title}\nAccuracy: {accuracy:.3f} ({accuracy*100:.1f}%)')
+         plt.ylabel('True Color')
+         plt.xlabel('Predicted Color')
+         plt.xticks(rotation=45)
+         plt.yticks(rotation=0)
+         plt.tight_layout()
+
+         return plt.gcf(), accuracy, cm
+
+     def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings"):
+         """Evaluate classification performance and create a confusion matrix"""
+         # Predict colors via the centroid classifier
+         predictions = self.predict_colors_from_embeddings(embeddings, labels)
+
+         accuracy = accuracy_score(labels, predictions)
+
+         fig, acc, cm = self.create_confusion_matrix(labels, predictions,
+                                                     f"{embedding_type} - Color Classification (Fashion-CLIP)")
+
+         # Per-class precision/recall/F1 report
+         unique_labels = sorted(set(labels))
+         report = classification_report(labels, predictions, labels=unique_labels,
+                                        target_names=unique_labels, output_dict=True)
+
+         return {
+             'accuracy': accuracy,
+             'predictions': predictions,
+             'confusion_matrix': cm,
+             'classification_report': report,
+             'figure': fig
+         }
+
+     def evaluate_dataset(self, dataframe, dataset_name="Dataset"):
+         """
+         Evaluate Fashion-CLIP embeddings on a given dataset.
+
+         This method extracts text, image, and color embeddings, computes similarity
+         metrics, evaluates classification performance, and saves confusion matrices.
+
+         Args:
+             dataframe: DataFrame containing the dataset
+             dataset_name: Name of the dataset for display purposes
+
+         Returns:
+             Dictionary containing evaluation results for text, image, and color embeddings
+         """
+         print(f"\n{'='*60}")
+         print(f"Evaluating {dataset_name} with Fashion-CLIP")
+         print(f"{'='*60}")
+
+         # Create the dataset and dataloader - use KaglDataset for KAGL data
+         if "kagl" in dataset_name.lower():
+             dataset = KaglDataset(dataframe)
+         else:
+             dataset = FashionCLIPDataset(dataframe)  # Unnormalized images for Fashion-CLIP
+         # Smaller batches keep the CLIP processor's memory use manageable
+         dataloader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)
+
+         results = {}
+
+         # Evaluate text embeddings
+         text_embeddings, text_labels, texts = self.extract_embeddings(dataloader, 'text')
+         text_metrics = self.compute_similarity_metrics(text_embeddings, text_labels)
+         text_classification = self.evaluate_classification_performance(text_embeddings, text_labels, "Text Embeddings")
+         text_metrics.update(text_classification)
+         results['text'] = text_metrics
+
+         # Evaluate image embeddings
+         image_embeddings, image_labels, _ = self.extract_embeddings(dataloader, 'image')
+         image_metrics = self.compute_similarity_metrics(image_embeddings, image_labels)
+         image_classification = self.evaluate_classification_performance(image_embeddings, image_labels, "Image Embeddings")
+         image_metrics.update(image_classification)
+         results['image'] = image_metrics
+
+         # Evaluate color embeddings
+         color_embeddings, color_labels, _ = self.extract_embeddings(dataloader, 'color')
+         color_metrics = self.compute_similarity_metrics(color_embeddings, color_labels)
+         color_classification = self.evaluate_classification_performance(color_embeddings, color_labels, "Color Embeddings")
+         color_metrics.update(color_classification)
+         results['color'] = color_metrics
+
+         # Print results
+         print(f"\n{dataset_name} Results (Fashion-CLIP):")
+         print("-" * 40)
+         for emb_type, metrics in results.items():
+             print(f"{emb_type.capitalize()} Embeddings:")
+             print(f"  Intra-class similarity (same color): {metrics['intra_class_mean']:.4f}")
+             print(f"  Inter-class similarity (diff colors): {metrics['inter_class_mean']:.4f}")
+             print(f"  Separation score: {metrics['separation_score']:.4f}")
+             print(f"  Nearest Neighbor Accuracy: {metrics['accuracy']:.4f} ({metrics['accuracy']*100:.1f}%)")
+             print(f"  Centroid Accuracy: {metrics['centroid_accuracy']:.4f} ({metrics['centroid_accuracy']*100:.1f}%)")
+
+             # Classification report summary
+             report = metrics['classification_report']
+             print("  📊 Classification Performance:")
+             print(f"    • Macro Avg F1-Score: {report['macro avg']['f1-score']:.4f}")
+             print(f"    • Weighted Avg F1-Score: {report['weighted avg']['f1-score']:.4f}")
+             print(f"    • Support: {report['macro avg']['support']:.0f} samples")
+             print()
+
+         # Save the confusion matrices
+         os.makedirs('embedding_evaluation', exist_ok=True)
+         file_prefix = dataset_name.lower().replace(' ', '_')
+
+         results['text']['figure'].savefig(f'embedding_evaluation/{file_prefix}_text_confusion_matrix_fashion_clip.png', dpi=300, bbox_inches='tight')
+         plt.close(results['text']['figure'])
+
+         results['image']['figure'].savefig(f'embedding_evaluation/{file_prefix}_image_confusion_matrix_fashion_clip.png', dpi=300, bbox_inches='tight')
+         plt.close(results['image']['figure'])
+
+         results['color']['figure'].savefig(f'embedding_evaluation/{file_prefix}_color_confusion_matrix_fashion_clip.png', dpi=300, bbox_inches='tight')
+         plt.close(results['color']['figure'])
+
+         return results
+
+ class KaglDataset(Dataset):
+     """
+     Dataset class for the Marqo/KAGL dataset.
+
+     Handles loading images from the KAGL format (raw bytes stored under
+     'image_url').
+     """
+     def __init__(self, dataframe):
+         """
+         Initialize the KAGL dataset.
+
+         Args:
+             dataframe: DataFrame containing image_url (with bytes), text, and color labels
+         """
+         self.dataframe = dataframe
+         # No color augmentation here: jittering brightness/contrast/saturation at
+         # evaluation time would distort the very attribute being evaluated
+         self.transform = transforms.Compose([
+             transforms.Resize((224, 224)),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+         ])
+
+     def __len__(self):
+         return len(self.dataframe)
+
+     def __getitem__(self, idx):
+         row = self.dataframe.iloc[idx]
+
+         # The image bytes live in row['image_url']
+         image_data = row["image_url"]
+
+         # Check if image_data has a 'bytes' key or is already a PIL Image
+         if isinstance(image_data, dict) and 'bytes' in image_data:
+             image = Image.open(BytesIO(image_data['bytes'])).convert("RGB")
+         elif hasattr(image_data, 'convert'):  # Already a PIL Image
+             image = image_data.convert("RGB")
+         else:
+             image = Image.open(BytesIO(image_data)).convert("RGB")
+
+         image = self.transform(image)
+
+         # Get text and color from the KAGL columns
+         description = row['text']
+         color = row['color']
+
+         return image, description, color
+
+ def load_kagl_marqo_dataset():
+     """
+     Load and prepare the Marqo/KAGL dataset from Hugging Face.
+
+     Loads the Marqo/KAGL dataset, filters it down to a fixed set of common
+     colors, and formats it for evaluation.
+
+     Returns:
+         DataFrame with columns: image_url, text, color
+     """
+     from datasets import load_dataset
+     print("Loading Marqo/KAGL dataset...")
+
+     # Load the dataset
+     dataset = load_dataset("Marqo/KAGL")
+     df = dataset["data"].to_pandas()
+     print("✅ Marqo/KAGL dataset loaded")
+
+     # Normalize the color labels (lowercase, unify 'grey'/'gray' spelling)
+     df['baseColour'] = df['baseColour'].str.lower().str.replace("grey", "gray")
+     df_test = df[df['baseColour'].notna()].copy()
+
+     print(f"📊 Before filtering: {len(df_test)} samples")
+
+     # Keep only common colors
+     valid_colors = ['red', 'blue', 'green', 'yellow', 'purple', 'pink', 'orange',
+                     'brown', 'black', 'white', 'gray', 'navy', 'maroon', 'beige']
+     df_test = df_test[df_test['baseColour'].isin(valid_colors)]
+
+     print(f"📊 After filtering invalid colors: {len(df_test)} samples")
+     print(f"🎨 Valid colors found: {sorted(df_test['baseColour'].unique())}")
+
+     if len(df_test) == 0:
+         print("❌ No samples left after color filtering; the returned DataFrame will be empty.")
+
+     # Map to the expected column names (labels are already normalized above)
+     kagl_formatted = pd.DataFrame({
+         'image_url': df_test['image_url'],
+         'text': df_test['text'],
+         'color': df_test['baseColour']
+     })
+
+     # Remove rows with missing data
+     print(f"📊 Before final validation: {len(kagl_formatted)} samples")
+     kagl_formatted = kagl_formatted.dropna(subset=['image_url', 'text', 'color'])
+     print(f"📊 After removing missing data: {len(kagl_formatted)} samples")
+
+     # Remove rows whose text or color is an empty string
+     kagl_formatted = kagl_formatted[
+         (kagl_formatted['text'].str.strip() != '') &
+         (kagl_formatted['color'].str.strip() != '')
+     ]
+     print(f"📊 After removing empty strings: {len(kagl_formatted)} samples")
+
+     print(f"📊 Final dataset size: {len(kagl_formatted)} samples")
+
+     return kagl_formatted
+
+ def create_comparison_table(val_results, kagl_results, val_results_fashion_clip, kagl_results_fashion_clip):
+     """
+     Create a structured comparison table between the custom model and the Fashion-CLIP baseline.
+
+     Args:
+         val_results: Evaluation results for the custom model on the validation dataset
+         kagl_results: Evaluation results for the custom model on the KAGL dataset
+         val_results_fashion_clip: Evaluation results for Fashion-CLIP on the validation dataset
+         kagl_results_fashion_clip: Evaluation results for Fashion-CLIP on the KAGL dataset
+
+     Returns:
+         DataFrame containing the comparison table
+     """
+     # Collect one row per (dataset, embedding type, model)
+     data = []
+
+     # Embedding types and their display names
+     embedding_types = [
+         ('text', 'Text Embeddings'),
+         ('image', 'Image Embeddings'),
+         ('color', 'Color Embeddings')
+     ]
+
+     # Datasets with results for both models
+     datasets = [
+         ('Validation Dataset', val_results, val_results_fashion_clip),
+         ('KAGL Marqo Dataset', kagl_results, kagl_results_fashion_clip)
+     ]
+
+     for dataset_name, custom_results, baseline_results in datasets:
+         for emb_type, emb_display in embedding_types:
+             # Custom model results
+             custom_metrics = custom_results[emb_type]
+             # Baseline model results
+             baseline_metrics = baseline_results[emb_type]
+
+             data.append({
+                 'Dataset': dataset_name,
+                 'Embedding Type': emb_display,
+                 'Model': 'Your Model',
+                 'Separation Score': f"{custom_metrics['separation_score']:.4f}",
+                 'NN Accuracy (%)': f"{custom_metrics['accuracy']*100:.1f}%",
+                 'Centroid Accuracy (%)': f"{custom_metrics['centroid_accuracy']*100:.1f}%",
+                 'Intra-class Similarity': f"{custom_metrics['intra_class_mean']:.4f}",
+                 'Inter-class Similarity': f"{custom_metrics['inter_class_mean']:.4f}",
+                 'Macro F1-Score': f"{custom_metrics['classification_report']['macro avg']['f1-score']:.4f}",
+                 'Weighted F1-Score': f"{custom_metrics['classification_report']['weighted avg']['f1-score']:.4f}"
+             })
+
+             data.append({
+                 'Dataset': dataset_name,
+                 'Embedding Type': emb_display,
+                 'Model': 'Fashion-CLIP (Baseline)',
+                 'Separation Score': f"{baseline_metrics['separation_score']:.4f}",
+                 'NN Accuracy (%)': f"{baseline_metrics['accuracy']*100:.1f}%",
+                 'Centroid Accuracy (%)': f"{baseline_metrics['centroid_accuracy']*100:.1f}%",
+                 'Intra-class Similarity': f"{baseline_metrics['intra_class_mean']:.4f}",
+                 'Inter-class Similarity': f"{baseline_metrics['inter_class_mean']:.4f}",
+                 'Macro F1-Score': f"{baseline_metrics['classification_report']['macro avg']['f1-score']:.4f}",
+                 'Weighted F1-Score': f"{baseline_metrics['classification_report']['weighted avg']['f1-score']:.4f}"
+             })
+
+     # Create the DataFrame and save it
+     df_comparison = pd.DataFrame(data)
+     df_comparison.to_csv('embedding_evaluation/model_comparison_table.csv', index=False)
+
+     # Print a formatted table, grouped by dataset and embedding type
+     print(f"\n{'='*120}")
+     print("📊 COMPREHENSIVE MODEL COMPARISON TABLE")
+     print(f"{'='*120}")
+
+     for dataset_name in df_comparison['Dataset'].unique():
+         print(f"\n🔍 {dataset_name.upper()}")
+         print("-" * 120)
+
+         dataset_df = df_comparison[df_comparison['Dataset'] == dataset_name]
+
+         for emb_type in dataset_df['Embedding Type'].unique():
+             print(f"\n📈 {emb_type}:")
+             emb_df = dataset_df[dataset_df['Embedding Type'] == emb_type]
+
+             # Header row
+             print(f"{'Model':<20} {'Separation':<12} {'NN Acc':<10} {'Centroid Acc':<13} {'Intra-class':<12} {'Inter-class':<12} {'Macro F1':<10} {'Weighted F1':<12}")
+             print("-" * 120)
+
+             # Data rows
+             for _, row in emb_df.iterrows():
+                 print(f"{row['Model']:<20} {row['Separation Score']:<12} {row['NN Accuracy (%)']:<10} {row['Centroid Accuracy (%)']:<13} {row['Intra-class Similarity']:<12} {row['Inter-class Similarity']:<12} {row['Macro F1-Score']:<10} {row['Weighted F1-Score']:<12}")
+
+     return df_comparison
+
+ if __name__ == "__main__":
+
+     # Initialize the evaluator for the custom ColorCLIP model
+     evaluator = EmbeddingEvaluator(model_path=config.color_model_path, embed_dim=config.color_emb_dim)
+
+     # Initialize the Fashion-CLIP baseline evaluator
+     fashion_clip_evaluator = FashionCLIPEvaluator()
+
+     # Load datasets
+     print("Loading datasets...")
+
+     # Load the validation dataset
+     df_val = pd.read_csv(config.local_dataset_path)
+     print(f"📊 Original dataset size: {len(df_val)}")
+
+     # Cap the number of KAGL samples evaluated, for speed
+     samples_to_evaluate = 10000
+
+     # Load the Marqo/KAGL dataset and draw one fixed sample so both models
+     # are evaluated on exactly the same data
+     kagl_df = load_kagl_marqo_dataset()
+     kagl_sample = kagl_df.sample(min(samples_to_evaluate, len(kagl_df)), random_state=42)
+
+     # Evaluate the custom model on the validation dataset
+     val_results = evaluator.evaluate_dataset(df_val, "Validation Dataset")
+
+     # Evaluate the custom model on the KAGL sample
+     kagl_results = evaluator.evaluate_dataset(kagl_sample, "KAGL Marqo Dataset")
+
+     # Evaluate Fashion-CLIP on both datasets
+     val_results_fashion_clip = fashion_clip_evaluator.evaluate_dataset(df_val, "Validation Dataset")
+     kagl_results_fashion_clip = fashion_clip_evaluator.evaluate_dataset(kagl_sample, "KAGL Marqo Dataset")
+
+     # Create the comprehensive comparison table
+     comparison_df = create_comparison_table(
+         val_results, kagl_results,
+         val_results_fashion_clip, kagl_results_fashion_clip
+     )
+
+     print(f"\n{'='*120}")
+     print("✅ Evaluation complete!")
+     print("📁 Confusion matrices saved in the 'embedding_evaluation/' folder")
+     print("📁 Comparison table saved as 'embedding_evaluation/model_comparison_table.csv'")
+     print("📁 Fashion-CLIP results are saved with a '_fashion_clip' suffix.")
+     print(f"{'='*120}")