Leacb4 committed on
Commit
fb6be9e
·
verified ·
1 Parent(s): ed400b3

Delete Evaluation

Browse files
Evaluation/0_shot_classification.py DELETED
@@ -1,512 +0,0 @@
1
- """
2
- Zero-shot classification evaluation on a new dataset.
3
- This file evaluates the main model's performance on unseen data by performing
4
- zero-shot classification. It compares three methods: color-to-color classification,
5
- text-to-text, and image-to-text. It generates confusion matrices and classification reports
6
- for each method to analyze the model's generalization capability.
7
- """
8
-
9
- import os
10
- # Set environment variable to disable tokenizers parallelism warnings
11
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
12
-
13
- import torch
14
- import torch.nn.functional as F
15
- import numpy as np
16
- import pandas as pd
17
- from torch.utils.data import Dataset
18
- import matplotlib.pyplot as plt
19
- from PIL import Image
20
- from torchvision import transforms
21
- from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
22
- import warnings
23
- import config
24
- from tqdm import tqdm
25
- from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
26
- import seaborn as sns
27
- from color_model import CLIPModel as ColorModel
28
- from hierarchy_model import Model, HierarchyExtractor
29
-
30
- # Suppress warnings
31
- warnings.filterwarnings("ignore", category=FutureWarning)
32
- warnings.filterwarnings("ignore", category=UserWarning)
33
-
34
def load_trained_model(model_path, device):
    """Load the fine-tuned CLIP model from a training checkpoint.

    Args:
        model_path: Path to the checkpoint file saved during training.
        device: torch.device to place the model on.

    Returns:
        (model, checkpoint): the model in eval mode plus the raw checkpoint
        dict (which also carries 'epoch' and 'best_val_loss').
    """
    print(f"Loading trained model from: {model_path}")

    ckpt = torch.load(model_path, map_location=device)

    # Start from the pretrained LAION backbone, then overlay the trained weights.
    base = CLIPModel_transformers.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
    base.load_state_dict(ckpt['model_state_dict'])
    model = base.to(device)
    model.eval()

    print(f"✅ Model loaded successfully!")
    print(f"📊 Training epoch: {ckpt['epoch']}")
    print(f"📉 Best validation loss: {ckpt['best_val_loss']:.4f}")

    return model, ckpt
56
-
57
def load_feature_models(device):
    """Load the specialized feature models (color and hierarchy) in eval mode.

    Returns:
        Dict mapping model name ('color' / 'hierarchy') to the loaded model.
    """
    # --- Color model (small embedding space, config.color_emb_dim) ---
    color_model = ColorModel(embed_dim=config.color_emb_dim).to(device)
    color_state = torch.load(config.color_model_path, map_location=device, weights_only=True)
    color_model.load_state_dict(color_state)
    color_model.eval()
    color_model.name = 'color'

    # --- Hierarchy model (larger embedding space, config.hierarchy_emb_dim) ---
    hier_ckpt = torch.load(config.hierarchy_model_path, map_location=device)
    hier_classes = hier_ckpt.get('hierarchy_classes', [])
    hierarchy_model = Model(
        num_hierarchy_classes=len(hier_classes),
        embed_dim=config.hierarchy_emb_dim,
    ).to(device)
    hierarchy_model.load_state_dict(hier_ckpt['model_state'])

    # Attach the extractor so the model can map raw text to hierarchy labels.
    hierarchy_model.set_hierarchy_extractor(HierarchyExtractor(hier_classes, verbose=False))
    hierarchy_model.eval()
    hierarchy_model.name = 'hierarchy'

    return {m.name: m for m in (color_model, hierarchy_model)}
84
-
85
def get_image_embedding(model, image, device):
    """Encode an image (or batch of images) into the normalized CLIP image space.

    Args:
        model: fine-tuned CLIP model exposing vision_model / visual_projection.
        image: tensor of shape (C, H, W) or (B, C, H, W); single-channel input
            is expanded to 3 channels by replication.
        device: torch.device the model lives on.

    Returns:
        L2-normalized image embeddings of shape (B, projection_dim).
    """
    model.eval()
    with torch.no_grad():
        # Grayscale -> RGB via channel replication.
        if image.dim() == 3 and image.size(0) == 1:
            image = image.expand(3, -1, -1)
        elif image.dim() == 4 and image.size(1) == 1:
            image = image.expand(-1, 3, -1, -1)

        # Promote a single image (C, H, W) to a batch of one (1, C, H, W).
        if image.dim() == 3:
            image = image.unsqueeze(0)

        pixel_values = image.to(device)

        # Run only the vision tower + projection (skips the text branch).
        pooled = model.vision_model(pixel_values=pixel_values).pooler_output
        embeddings = model.visual_projection(pooled)

        return F.normalize(embeddings, dim=-1)
106
-
107
def get_text_embedding(model, text, processor, device):
    """Encode a string (or list of strings) into the normalized CLIP text space.

    Args:
        model: fine-tuned CLIP model exposing text_model / text_projection.
        text: str or list[str] to embed.
        processor: CLIPProcessor used for tokenization/padding.
        device: torch.device the model lives on.

    Returns:
        L2-normalized text embeddings of shape (B, projection_dim).
    """
    model.eval()
    with torch.no_grad():
        batch = processor(text=text, padding=True, return_tensors="pt")
        batch = {key: tensor.to(device) for key, tensor in batch.items()}

        # Run only the text tower + projection (skips the vision branch).
        pooled = model.text_model(**batch).pooler_output
        embeddings = model.text_projection(pooled)

        return F.normalize(embeddings, dim=-1)
119
-
120
def evaluate_custom_csv_accuracy(model, dataset, processor, method='similarity'):
    """
    Evaluate the accuracy of the model on your custom CSV using text-to-text similarity.

    Each sample's full text is embedded and compared (cosine similarity) against
    the embeddings of every candidate color name; the closest color wins.

    Args:
        model: The trained CLIP model
        dataset: CustomCSVDataset yielding (image, text, color) tuples
        processor: CLIPProcessor
        method: kept for interface compatibility (currently unused)

    Returns:
        (true_labels, predicted_labels, accuracy)
    """
    print(f"\n📊 === Evaluation of the accuracy on custom CSV (TEXT-TO-TEXT method) ===")

    model.eval()
    # BUG FIX: get_text_embedding requires a device argument; the original
    # calls omitted it entirely (TypeError). Infer the device from the model.
    device = next(model.parameters()).device

    # Collect the label space from the dataset itself.
    all_colors = set()
    for i in range(len(dataset)):
        _, _, color = dataset[i]
        all_colors.add(color)

    color_list = sorted(all_colors)
    print(f"🎨 Colors found: {color_list}")

    true_labels = []
    predicted_labels = []

    # Pre-calculate the embeddings of the color descriptions (computed once, reused below).
    print("🔄 Pre-calculating the embeddings of the colors...")
    color_embeddings = {}
    for color in color_list:
        color_embeddings[color] = get_text_embedding(model, color, processor, device)

    print("🔄 Evaluation in progress...")
    correct_predictions = 0

    for idx in tqdm(range(len(dataset)), desc="Evaluation"):
        _, text, true_color = dataset[idx]

        # Embed the full sample text (not the color alone — see the color-only variant).
        text_emb = get_text_embedding(model, text, processor, device)

        # Nearest color by cosine similarity.
        best_similarity = -1
        predicted_color = color_list[0]
        for color, color_emb in color_embeddings.items():
            similarity = F.cosine_similarity(text_emb, color_emb, dim=1).item()
            if similarity > best_similarity:
                best_similarity = similarity
                predicted_color = color

        true_labels.append(true_color)
        predicted_labels.append(predicted_color)
        if true_color == predicted_color:
            correct_predictions += 1

    accuracy = accuracy_score(true_labels, predicted_labels)

    print(f"\n✅ Results of evaluation:")
    print(f"🎯 Global accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"📊 Correct predictions: {correct_predictions}/{len(true_labels)}")

    return true_labels, predicted_labels, accuracy
186
-
187
def evaluate_custom_csv_accuracy_image(model, dataset, processor, method='similarity'):
    """
    Evaluate the accuracy of the model on your custom CSV using image-to-text similarity.

    Each sample's image is embedded and compared (cosine similarity) against the
    embeddings of every candidate color name; the closest color wins.

    Args:
        model: The trained CLIP model
        dataset: CustomCSVDataset with images loaded
        processor: CLIPProcessor
        method: kept for interface compatibility (currently unused)

    Returns:
        (true_labels, predicted_labels, accuracy)
    """
    print(f"\n📊 === Evaluation of the accuracy on custom CSV (IMAGE-TO-TEXT method) ===")

    model.eval()
    # BUG FIX: the embedding helpers require the device. Originally
    # get_image_embedding was called with the processor in the device slot and
    # get_text_embedding with the device missing entirely.
    device = next(model.parameters()).device

    # Collect the label space from the dataset itself.
    all_colors = set()
    for i in range(len(dataset)):
        _, _, color = dataset[i]
        all_colors.add(color)

    color_list = sorted(all_colors)
    print(f"🎨 Colors found: {color_list}")

    true_labels = []
    predicted_labels = []

    # Pre-calculate the embeddings of the color descriptions.
    print("🔄 Pre-calculating the embeddings of the colors...")
    color_embeddings = {}
    for color in color_list:
        color_embeddings[color] = get_text_embedding(model, color, processor, device)

    print("🔄 Evaluation in progress...")
    correct_predictions = 0

    for idx in tqdm(range(len(dataset)), desc="Evaluation"):
        image, text, true_color = dataset[idx]

        # Key difference from the text-to-text variant: embed the image.
        image_emb = get_image_embedding(model, image, device)

        # Nearest color by cosine similarity.
        best_similarity = -1
        predicted_color = color_list[0]
        for color, color_emb in color_embeddings.items():
            similarity = F.cosine_similarity(image_emb, color_emb, dim=1).item()
            if similarity > best_similarity:
                best_similarity = similarity
                predicted_color = color

        true_labels.append(true_color)
        predicted_labels.append(predicted_color)
        if true_color == predicted_color:
            correct_predictions += 1

    accuracy = accuracy_score(true_labels, predicted_labels)

    print(f"\n✅ Results of evaluation:")
    print(f"🎯 Global accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"📊 Correct predictions: {correct_predictions}/{len(true_labels)}")

    return true_labels, predicted_labels, accuracy
253
-
254
def evaluate_custom_csv_accuracy_color_only(model, dataset, processor):
    """
    Evaluate the accuracy by encoding ONLY the color name (not the full text).

    This is a control test: if the embedding space is self-consistent, encoding
    the true color should retrieve itself with (near) perfect accuracy.

    Args:
        model: The trained CLIP model
        dataset: CustomCSVDataset yielding (image, text, color) tuples
        processor: CLIPProcessor

    Returns:
        (true_labels, predicted_labels, accuracy)
    """
    print(f"\n📊 === Evaluation of the accuracy on custom CSV (COLOR-TO-COLOR method) ===")
    print("🔬 This test encodes ONLY the color name, not the full text")

    model.eval()
    # BUG FIX: get_text_embedding requires a device argument; the original
    # calls omitted it entirely (TypeError). Infer the device from the model.
    device = next(model.parameters()).device

    # Collect the label space from the dataset itself.
    all_colors = set()
    for i in range(len(dataset)):
        _, _, color = dataset[i]
        all_colors.add(color)

    color_list = sorted(all_colors)
    print(f"🎨 Colors found: {color_list}")

    true_labels = []
    predicted_labels = []

    # Pre-calculate the embeddings of the color descriptions.
    print("🔄 Pre-calculating the embeddings of the colors...")
    color_embeddings = {}
    for color in color_list:
        color_embeddings[color] = get_text_embedding(model, color, processor, device)

    print("🔄 Evaluation in progress...")
    correct_predictions = 0

    for idx in tqdm(range(len(dataset)), desc="Evaluation"):
        _, _, true_color = dataset[idx]

        # KEY DIFFERENCE: embed the TRUE COLOR name only, not the full text.
        true_color_emb = get_text_embedding(model, true_color, processor, device)

        # Nearest color by cosine similarity.
        best_similarity = -1
        predicted_color = color_list[0]
        for color, color_emb in color_embeddings.items():
            similarity = F.cosine_similarity(true_color_emb, color_emb, dim=1).item()
            if similarity > best_similarity:
                best_similarity = similarity
                predicted_color = color

        true_labels.append(true_color)
        predicted_labels.append(predicted_color)
        if true_color == predicted_color:
            correct_predictions += 1

    accuracy = accuracy_score(true_labels, predicted_labels)

    print(f"\n✅ Results of evaluation:")
    print(f"🎯 Global accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"📊 Correct predictions: {correct_predictions}/{len(true_labels)}")

    return true_labels, predicted_labels, accuracy
321
-
322
def search_custom_csv_by_text(model, dataset, query, processor, top_k=5):
    """Search the dataset by free-text query, ranked by image/query similarity.

    NOTE(review): this function unpacks 5-tuples (image, text, color, _,
    image_path) from the dataset, unlike CustomCSVDataset in this file which
    yields 3-tuples — confirm which dataset type callers actually pass in.

    Args:
        model: The trained CLIP model.
        dataset: dataset yielding (image, text, color, _, image_path) items.
        query: free-text search string.
        processor: CLIPProcessor.
        top_k: number of best matches to return.

    Returns:
        List of up to top_k tuples (idx, similarity, text, color, color, image_path),
        sorted by similarity descending.
    """
    print(f"\n🔍 Search in custom CSV: '{query}'")

    # BUG FIX: both embedding helpers require the device; the original calls
    # omitted it. Infer the device from the model.
    device = next(model.parameters()).device

    # Get the embedding of the query
    query_emb = get_text_embedding(model, query, processor, device)

    similarities = []

    print("🔄 Calculating similarities...")
    for idx in tqdm(range(len(dataset)), desc="Processing"):
        image, text, color, _, image_path = dataset[idx]

        # Get the embedding of the image
        image_emb = get_image_embedding(model, image, device)

        # Compute the query/image cosine similarity
        similarity = F.cosine_similarity(query_emb, image_emb, dim=1).item()

        similarities.append((idx, similarity, text, color, color, image_path))

    # Sort by similarity, best match first
    similarities.sort(key=lambda x: x[1], reverse=True)

    return similarities[:top_k]
347
-
348
def plot_confusion_matrix(true_labels, predicted_labels, save_path=None, title_suffix="text"):
    """
    Display and save the confusion matrix, row-normalized to integer percentages.

    Args:
        true_labels: list of ground-truth labels.
        predicted_labels: list of predicted labels (same length).
        save_path: optional file path; when given, the figure is also saved.
        title_suffix: method name shown in the plot title.

    Returns:
        The raw (count-based) confusion matrix.
    """
    print("\n📈 === Generation of the confusion matrix ===")

    # Calculate the confusion matrix
    cm = confusion_matrix(true_labels, predicted_labels)

    # Get unique labels in sorted order
    unique_labels = sorted(set(true_labels + predicted_labels))

    # Calculate accuracy
    accuracy = accuracy_score(true_labels, predicted_labels)

    # Row-normalize to percentages, rounded to integers.
    # FIX: a label that appears only in predictions yields an all-zero row in
    # cm; the original division then produced divide-by-zero NaNs. Guard those
    # rows (they stay 0% everywhere, which is the correct display).
    row_sums = cm.sum(axis=1)[:, np.newaxis].astype('float')
    row_sums[row_sums == 0] = 1.0
    cm_percent = np.around(cm.astype('float') / row_sums * 100).astype(int)

    # Create the figure
    plt.figure(figsize=(12, 10))

    # Confusion matrix with integer percentages and class labels
    sns.heatmap(cm_percent,
                annot=True,
                fmt='d',
                cmap='Blues',
                cbar_kws={'label': 'Percentage (%)'},
                xticklabels=unique_labels,
                yticklabels=unique_labels)

    plt.title(f"Confusion Matrix for {title_suffix} - new data - accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)", fontsize=16)
    plt.xlabel('Predictions', fontsize=12)
    plt.ylabel('True colors', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"💾 Confusion matrix saved: {save_path}")

    plt.show()

    return cm
393
-
394
class CustomCSVDataset(Dataset):
    """Dataset over a dataframe with text, color, and optional image-path columns.

    Yields (image_tensor, text, color) triples. When image loading is disabled
    or fails, a zero tensor of shape (3, image_size, image_size) stands in.
    """

    def __init__(self, dataframe, image_size=224, load_images=True):
        self.dataframe = dataframe
        self.image_size = image_size
        self.load_images = load_images

        # CLIP preprocessing: resize + tensor + CLIP normalization statistics.
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                 std=[0.26862954, 0.26130258, 0.27577711])
        ])

    def __len__(self):
        return len(self.dataframe)

    def _zero_image(self):
        # Placeholder tensor used when no real image is available.
        return torch.zeros(3, self.image_size, self.image_size)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        text = row[config.text_column]
        colors = row[config.color_column]

        if not (self.load_images and config.column_local_image_path in row):
            # Images disabled or no path column: return the dummy image.
            return self._zero_image(), text, colors

        try:
            pil_image = Image.open(row[config.column_local_image_path]).convert('RGB')
            image = self.transform(pil_image)
        except Exception as e:
            print(f"Warning: Could not load image {row.get(config.column_local_image_path, 'unknown')}: {e}")
            image = self._zero_image()

        return image, text, colors
429
-
430
- if __name__ == "__main__":
431
- """Main function with evaluation"""
432
- print("🚀 === Test and Evaluation of the model on new dataset ===")
433
-
434
- # Load model
435
- print("🔧 Loading the model...")
436
- model, checkpoint = load_trained_model(config.main_model_path, config.device)
437
-
438
- # Create processor
439
- processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
440
-
441
- # Load new dataset
442
- print("📊 Loading the new dataset...")
443
- df = pd.read_csv(config.local_dataset_path) # replace local_dataset_path with a new df
444
-
445
- print("\n" + "="*80)
446
- print("🎨 COLOR-TO-COLOR CLASSIFICATION (Control Test)")
447
- print("="*80)
448
-
449
- # Create dataset without loading images
450
- dataset_color = CustomCSVDataset(df, load_images=False)
451
-
452
- # 0. Evaluation encoding ONLY the color (control test)
453
- true_labels_color, predicted_labels_color, accuracy_color = evaluate_custom_csv_accuracy_color_only(
454
- model, dataset_color, processor
455
- )
456
-
457
- # Confusion matrix for color-only
458
- confusion_matrix_color = plot_confusion_matrix(
459
- true_labels_color, predicted_labels_color,
460
- save_path="confusion_matrix_color_only.png",
461
- title_suffix="color-only"
462
- )
463
-
464
- print("\n" + "="*80)
465
- print("📝 TEXT-TO-TEXT CLASSIFICATION")
466
- print("="*80)
467
-
468
- # Create dataset without loading images for text-to-text
469
- dataset_text = CustomCSVDataset(df, load_images=False)
470
-
471
- # 1. Evaluation of the accuracy (text-to-text)
472
- true_labels_text, predicted_labels_text, accuracy_text = evaluate_custom_csv_accuracy(
473
- model, dataset_text, processor, method='similarity'
474
- )
475
-
476
- # 2. Confusion matrix for text
477
- confusion_matrix_text = plot_confusion_matrix(
478
- true_labels_text, predicted_labels_text,
479
- save_path="confusion_matrix_text.png",
480
- title_suffix="text"
481
- )
482
-
483
- print("\n" + "="*80)
484
- print("🖼️ IMAGE-TO-TEXT CLASSIFICATION")
485
- print("="*80)
486
-
487
- # Create dataset with images loaded for image-to-text
488
- dataset_image = CustomCSVDataset(df, load_images=True)
489
-
490
- # 3. Evaluation of the accuracy (image-to-text)
491
- true_labels_image, predicted_labels_image, accuracy_image = evaluate_custom_csv_accuracy_image(
492
- model, dataset_image, processor, method='similarity'
493
- )
494
-
495
- # 4. Confusion matrix for images
496
- confusion_matrix_image = plot_confusion_matrix(
497
- true_labels_image, predicted_labels_image,
498
- save_path="confusion_matrix_image.png",
499
- title_suffix="image"
500
- )
501
-
502
- # 5. Summary comparison
503
- print("\n" + "="*80)
504
- print("📊 SUMMARY")
505
- print("="*80)
506
- print(f"🎨 Color-to-Color Accuracy (Control): {accuracy_color:.4f} ({accuracy_color*100:.2f}%)")
507
- print(f"📝 Text-to-Text Accuracy: {accuracy_text:.4f} ({accuracy_text*100:.2f}%)")
508
- print(f"🖼️ Image-to-Text Accuracy: {accuracy_image:.4f} ({accuracy_image*100:.2f}%)")
509
- print(f"\n📊 Analysis:")
510
- print(f" • Loss from full text vs color-only: {abs(accuracy_color - accuracy_text):.4f} ({abs(accuracy_color - accuracy_text)*100:.2f}%)")
511
- print(f" • Difference text vs image: {abs(accuracy_text - accuracy_image):.4f} ({abs(accuracy_text - accuracy_image)*100:.2f}%)")
512
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Evaluation/basic_test_generalized.py DELETED
@@ -1,425 +0,0 @@
1
- """
2
- Generalized evaluation of the main model with sub-module comparison.
3
- This file evaluates the main model's performance by comparing specialized parts
4
- (color and hierarchy) with corresponding specialized models. It calculates similarity
5
- matrices, linear projections between embedding spaces, and generates detailed statistics
6
- on alignment between different representations.
7
- """
8
-
9
- import os
10
- import json
11
- import argparse
12
- import config
13
- import torch
14
- import torch.nn.functional as F
15
- import pandas as pd
16
- from PIL import Image
17
- from torchvision import transforms
18
- from transformers import CLIPProcessor, CLIPModel as CLIPModelTransformers
19
- from tqdm.auto import tqdm
20
-
21
- # Local imports
22
- from color_model import ColorCLIP as ColorModel, ColorDataset, Tokenizer
23
- from config import color_model_path, color_emb_dim, device, hierarchy_model_path, hierarchy_emb_dim
24
- from hierarchy_model import Model as HierarchyModel, HierarchyExtractor
25
-
26
-
27
def load_color_model(color_model_path, color_emb_dim, device):
    """Load the small color CLIP model and attach its tokenizer.

    Args:
        color_model_path: checkpoint path for the color model state dict.
        color_emb_dim: embedding dimension the model was trained with.
        device: torch.device to place the model on.

    Returns:
        The color model in eval mode with a `.tokenizer` attribute set.
    """
    state = torch.load(color_model_path, map_location=device, weights_only=True)
    model = ColorModel(vocab_size=39, embedding_dim=color_emb_dim).to(device)
    model.load_state_dict(state)

    # Attach the tokenizer used at training time.
    tokenizer = Tokenizer()
    with open(config.tokeniser_path, 'r') as f:
        vocab_dict = json.load(f)
    # NOTE(review): vocab_dict is read but never handed to the tokenizer —
    # presumably Tokenizer should be initialized from it; confirm upstream.
    model.tokenizer = tokenizer

    model.eval()
    return model
41
-
42
-
43
def get_emb_color_model(color_model, image_path_to_encode, text_to_encode):
    """Return (image_emb, text_emb) from the small color model for one sample.

    Args:
        color_model: loaded color model with image_encoder / text_encoder / tokenizer.
        image_path_to_encode: filesystem path to the image.
        text_to_encode: raw text to tokenize and embed.
    """
    # Preprocess the image with ImageNet statistics (matching this model's training).
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    pil_image = Image.open(image_path_to_encode).convert('RGB')
    image_batch = preprocess(pil_image).unsqueeze(0).to(device)  # [1, 3, 224, 224]

    with torch.no_grad():
        image_emb = color_model.image_encoder(image_batch)

    # Tokenize the text and pass the sequence length the encoder expects.
    token_ids = torch.tensor([color_model.tokenizer(text_to_encode)], dtype=torch.long, device=device)
    seq_len = token_ids.size(1) if token_ids.dim() > 1 else token_ids.size(0)
    lengths = torch.tensor([seq_len], dtype=torch.long, device=device)
    with torch.no_grad():
        txt_emb = color_model.text_encoder(token_ids, lengths)

    return image_emb, txt_emb
67
-
68
def load_main_model(main_model_path, device):
    """Load the fine-tuned main CLIP model plus its processor.

    Accepts both raw state-dict checkpoints and dicts that wrap the weights
    under 'model_state_dict'; falls back to loading only shape-matching keys.

    Args:
        main_model_path: checkpoint path.
        device: torch.device to place the model on.

    Returns:
        (main_model, processor) with the model in eval mode.
    """
    checkpoint = torch.load(main_model_path, map_location=device)
    # BUG FIX: this file imports the class as CLIPModelTransformers
    # (`from transformers import ... CLIPModel as CLIPModelTransformers`);
    # the original referenced the undefined name CLIPModel_transformers.
    main_model = CLIPModelTransformers.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
    state = checkpoint['model_state_dict'] if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint else checkpoint
    try:
        main_model.load_state_dict(state, strict=False)
    except Exception:
        # Fallback: keep only keys whose shapes match the fresh model.
        model_state = main_model.state_dict()
        filtered = {k: v for k, v in state.items() if k in model_state and model_state[k].shape == v.shape}
        main_model.load_state_dict(filtered, strict=False)
    main_model.to(device)
    main_model.eval()
    processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
    return main_model, processor
83
-
84
-
85
def load_hierarchy_model(hierarchy_model_path, device):
    """Load the hierarchy model with its class list and text extractor attached.

    Args:
        hierarchy_model_path: checkpoint path (dict with 'model_state' and
            optionally 'hierarchy_classes').
        device: torch.device to place the model on.

    Returns:
        The hierarchy model in eval mode.
    """
    ckpt = torch.load(hierarchy_model_path, map_location=device)
    classes = ckpt.get('hierarchy_classes', [])
    model = HierarchyModel(num_hierarchy_classes=len(classes), embed_dim=config.hierarchy_emb_dim).to(device)
    model.load_state_dict(ckpt['model_state'])
    # The extractor maps raw text to hierarchy labels for the text branch.
    model.set_hierarchy_extractor(HierarchyExtractor(classes, verbose=False))
    model.eval()
    return model
94
-
95
-
96
def get_emb_hierarchy_model(hierarchy_model, image_path_to_encode, text_to_encode):
    """Return (img_emb, txt_emb) from the hierarchy model for one sample.

    NOTE: unlike the color model path, no Normalize step is applied here —
    the image is fed as a raw [0, 1] tensor.
    """
    to_tensor = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    pil_image = Image.open(image_path_to_encode).convert('RGB')
    image_batch = to_tensor(pil_image).unsqueeze(0).to(device)

    with torch.no_grad():
        img_emb = hierarchy_model.get_image_embeddings(image_batch)
        txt_emb = hierarchy_model.get_text_embeddings(text_to_encode)

    return img_emb, txt_emb
109
-
110
def get_emb_main_model(main_model, processor, image_path_to_encode, text_to_encode):
    """Return (text_emb, image_emb) from the main CLIP model for one sample.

    Note the return order: TEXT first, then IMAGE — callers unpack it that way.

    Args:
        main_model: fine-tuned transformers CLIP model.
        processor: CLIPProcessor for text tokenization.
        image_path_to_encode: filesystem path to the image.
        text_to_encode: raw text to embed.
    """
    image = Image.open(image_path_to_encode).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    pixel_values = transform(image).unsqueeze(0).to(device)

    # Prepare text inputs via processor
    text_inputs = processor(text=[text_to_encode], return_tensors="pt", padding=True)
    text_inputs = {k: v.to(device) for k, v in text_inputs.items()}

    # FIX: run under no_grad like every other embedding helper in this file —
    # this is pure inference and previously built an unused autograd graph.
    with torch.no_grad():
        outputs = main_model(**text_inputs, pixel_values=pixel_values)

    return outputs.text_embeds, outputs.image_embeds
127
-
128
-
129
- if __name__ == '__main__':
130
- parser = argparse.ArgumentParser(description='Evaluate main model parts vs small models and build similarity matrices')
131
- parser.add_argument('--main-checkpoint', type=str, default='models/laion_explicable_model.pth')
132
- parser.add_argument('--color-checkpoint', type=str, default='models/color_model.pt')
133
- parser.add_argument('--csv', type=str, default='data/data_with_local_paths.csv')
134
- parser.add_argument('--color-emb-dim', type=int, default=16)
135
- parser.add_argument('--num-samples', type=int, default=200)
136
- parser.add_argument('--seed', type=int, default=42)
137
- parser.add_argument('--primary-metric', type=str, default='sim_color_txt_img',
138
- choices=['sim_txt_color_part', 'sim_img_color_part', 'sim_color_txt_img', 'sim_small_txt_img',
139
- 'sim_txt_hierarchy_part', 'sim_img_hierarchy_part'])
140
- parser.add_argument('--top-k', type=int, default=30)
141
- parser.add_argument('--heatmap', action='store_true')
142
- parser.add_argument('--l2-grid', type=str, default='1e-5,1e-4,1e-3,1e-2,1e-1')
143
- args = parser.parse_args()
144
-
145
- main_checkpoint = args.main_checkpoint
146
- color_checkpoint = args.color_checkpoint
147
- csv = args.csv
148
- color_emb_dim = args.color_emb_dim
149
- num_samples = args.num_samples
150
- seed = args.seed
151
- primary_metric = args.primary_metric
152
- top_k = args.top_k
153
- l2_grid = [float(x) for x in args.l2_grid.split(',') if x]
154
- device = torch.device("mps")
155
-
156
- df = pd.read_csv(csv)
157
-
158
- # Normalize colors (reduce aliasing and sparsity)
159
- def normalize_color(c):
160
- if pd.isna(c):
161
- return c
162
- s = str(c).strip().lower()
163
- aliases = {
164
- 'grey': 'gray',
165
- 'navy blue': 'navy',
166
- 'light blue': 'blue',
167
- 'dark blue': 'blue',
168
- 'light grey': 'gray',
169
- 'dark grey': 'gray',
170
- 'light gray': 'gray',
171
- 'dark gray': 'gray',
172
- }
173
- return aliases.get(s, s)
174
-
175
- if config.color_column in df.columns:
176
- df[config.color_column] = df[config.color_column].apply(normalize_color)
177
-
178
- color_model = load_color_model(color_checkpoint, color_emb_dim, device)
179
- main_model, processor = load_main_model(main_checkpoint, device)
180
- hierarchy_model = load_hierarchy_model(hierarchy_model_path, device)
181
-
182
- # Results container
183
- results = []
184
-
185
- # Accumulators for projection (A: main part, B: small model)
186
- color_txt_As, color_txt_Bs = [], []
187
- color_img_As, color_img_Bs = [], []
188
- hier_txt_As, hier_txt_Bs = [], []
189
- hier_img_As, hier_img_Bs = [], []
190
-
191
- # Ensure determinism for sampling
192
- pd.options.mode.copy_on_write = True
193
- rng = pd.Series(range(len(df)), dtype=int)
194
- _ = rng # silence lint
195
- torch.manual_seed(seed)
196
-
197
- unique_hiers = sorted(df[config.hierarchy_column].dropna().unique())
198
- unique_colors = sorted(df[config.color_column].dropna().unique())
199
-
200
- # Progress bar across all (hierarchy, color) pairs
201
- total_pairs = len(unique_hiers) * len(unique_colors)
202
- pair_pbar = tqdm(total=total_pairs, desc="Evaluating pairs", leave=False)
203
- for hierarchy in unique_hiers:
204
- for color in unique_colors:
205
- group = df[(df[config.hierarchy_column] == hierarchy) & (df[config.color_column] == color)]
206
-
207
- # Sample up to num_samples per (hierarchy, color)
208
- k = min(num_samples, len(group))
209
- group_iter = group.sample(n=k, random_state=seed) if len(group) > k else group.iloc[:k]
210
-
211
- # Progress bar for samples within the pair
212
- inner_pbar = tqdm(total=len(group_iter), desc=f"{hierarchy}/{color}", leave=False)
213
- for row_idx, (_, example) in enumerate(group_iter.iterrows()):
214
- try:
215
- image_emb, txt_emb = get_emb_color_model(color_model, example['local_image_path'], example['text'])
216
- image_emb_hier, txt_emb_hier = get_emb_hierarchy_model(hierarchy_model, example['local_image_path'], example['text'])
217
- text_emb_main_model, image_emb_main_model = get_emb_main_model(
218
- main_model, processor, example['local_image_path'], example['text']
219
- )
220
-
221
- color_part_txt = text_emb_main_model[:, :color_emb_dim]
222
- color_part_img = image_emb_main_model[:, :color_emb_dim]
223
- hier_part_txt = text_emb_main_model[:, color_emb_dim:color_emb_dim + hierarchy_emb_dim]
224
- hier_part_img = image_emb_main_model[:, color_emb_dim:color_emb_dim + hierarchy_emb_dim]
225
-
226
- # L2-normalize parts and small-model embeddings for stable cosine
227
- color_part_txt = F.normalize(color_part_txt, dim=1)
228
- color_part_img = F.normalize(color_part_img, dim=1)
229
- hier_part_txt = F.normalize(hier_part_txt, dim=1)
230
- hier_part_img = F.normalize(hier_part_img, dim=1)
231
- txt_emb = F.normalize(txt_emb, dim=1)
232
- image_emb = F.normalize(image_emb, dim=1)
233
- txt_emb_hier = F.normalize(txt_emb_hier, dim=1)
234
- image_emb_hier = F.normalize(image_emb_hier, dim=1)
235
-
236
- sim_txt_color_part = F.cosine_similarity(txt_emb, color_part_txt).item()
237
- sim_img_color_part = F.cosine_similarity(image_emb, color_part_img).item()
238
- sim_color_txt_img = F.cosine_similarity(color_part_txt, color_part_img).item()
239
- sim_small_txt_img = F.cosine_similarity(txt_emb, image_emb).item()
240
-
241
- sim_txt_hierarchy_part = F.cosine_similarity(txt_emb_hier, hier_part_txt).item()
242
- sim_img_hierarchy_part = F.cosine_similarity(image_emb_hier, hier_part_img).item()
243
-
244
- # Accumulate for projection fitting later
245
- color_txt_As.append(color_part_txt.squeeze(0).detach().cpu())
246
- color_txt_Bs.append(txt_emb.squeeze(0).detach().cpu())
247
- color_img_As.append(color_part_img.squeeze(0).detach().cpu())
248
- color_img_Bs.append(image_emb.squeeze(0).detach().cpu())
249
-
250
- hier_txt_As.append(hier_part_txt.squeeze(0).detach().cpu())
251
- hier_txt_Bs.append(txt_emb_hier.squeeze(0).detach().cpu())
252
- hier_img_As.append(hier_part_img.squeeze(0).detach().cpu())
253
- hier_img_Bs.append(image_emb_hier.squeeze(0).detach().cpu())
254
-
255
- results.append({
256
- 'hierarchy' "hierarchy",
257
- 'color': color,
258
- 'row_index': int(row_idx),
259
- 'sim_txt_color_part': float(sim_txt_color_part),
260
- 'sim_img_color_part': float(sim_img_color_part),
261
- 'sim_color_txt_img': float(sim_color_txt_img),
262
- 'sim_small_txt_img': float(sim_small_txt_img),
263
- 'sim_txt_hierarchy_part': float(sim_txt_hierarchy_part),
264
- 'sim_img_hierarchy_part': float(sim_img_hierarchy_part),
265
- })
266
- except Exception as e:
267
- print(f"Skipping example due to error: {e}")
268
- finally:
269
- inner_pbar.update(1)
270
- inner_pbar.close()
271
- pair_pbar.update(1)
272
- pair_pbar.close()
273
-
274
- results_df = pd.DataFrame(results)
275
-
276
- # Save raw results
277
- os.makedirs('evaluation_outputs', exist_ok=True)
278
- raw_path = os.path.join('evaluation_outputs', 'similarities_raw.csv')
279
- results_df.to_csv(raw_path, index=False)
280
- print(f"Saved raw similarities to {raw_path}")
281
-
282
- # Intelligent averages
283
- metrics = ['sim_txt_color_part', 'sim_img_color_part', 'sim_color_txt_img', 'sim_small_txt_img',
284
- 'sim_txt_hierarchy_part', 'sim_img_hierarchy_part']
285
-
286
- # Overall means
287
- overall_means = results_df[metrics].mean().to_frame(name='mean').T
288
- overall_means.insert(0, 'level', 'overall')
289
-
290
- # By hierarchy
291
- by_hierarchy = results_df.groupby(config.hierarchy_column)[metrics].mean().reset_index()
292
- by_hierarchy.insert(0, 'level', config.hierarchy_column)
293
-
294
- # By color
295
- by_color = results_df.groupby(config.color_column)[metrics].mean().reset_index()
296
- by_color.insert(0, 'level', config.color_column)
297
-
298
- # By hierarchy+color
299
- by_pair = results_df.groupby([config.hierarchy_column, config.color_column])[metrics].mean().reset_index()
300
- by_pair.insert(0, 'level', 'hierarchy_color')
301
-
302
- summary_df = pd.concat([overall_means, by_hierarchy, by_color, by_pair], ignore_index=True)
303
- summary_path = os.path.join('evaluation_outputs', 'similarities_summary.csv')
304
- summary_df.to_csv(summary_path, index=False)
305
- print(f"Saved summary statistics to {summary_path}")
306
-
307
- # =====================
308
- # Similarity matrices for best hierarchy-color combinations
309
- # =====================
310
- try:
311
- by_pair_core = results_df.groupby([config.hierarchy_column, config.color_column])[metrics].mean().reset_index()
312
- top_pairs = by_pair_core.nlargest(top_k, primary_metric)
313
- matrix = top_pairs.pivot(index=config.hierarchy_column, columns=config.color_column, values=primary_metric)
314
- os.makedirs('evaluation_outputs', exist_ok=True)
315
- matrix_csv_path = os.path.join('evaluation_outputs', f'similarity_matrix_{primary_metric}_top{top_k}.csv')
316
- matrix.to_csv(matrix_csv_path)
317
- print(f"Saved similarity matrix to {matrix_csv_path}")
318
-
319
- if args.heatmap:
320
- try:
321
- import seaborn as sns
322
- import matplotlib.pyplot as plt
323
- plt.figure(figsize=(max(6, 0.5 * len(matrix.columns)), max(4, 0.5 * len(matrix.index))))
324
- sns.heatmap(matrix, annot=False, cmap='viridis')
325
- plt.title(f'Similarity matrix (top {top_k}) - {primary_metric}')
326
- heatmap_path = os.path.join('evaluation_outputs', f'similarity_matrix_{primary_metric}_top{top_k}.png')
327
- plt.tight_layout()
328
- plt.savefig(heatmap_path, dpi=200)
329
- plt.close()
330
- print(f"Saved similarity heatmap to {heatmap_path}")
331
- except Exception as e:
332
- print(f"Skipping heatmap generation: {e}")
333
- except Exception as e:
334
- print(f"Skipping matrix generation: {e}")
335
-
336
- # =====================
337
- # Learn projections A->B and report projected cosine means
338
- # =====================
339
- def fit_ridge_projection(A, B, l2_reg=1e-3):
340
- # A: [N, D_in], B: [N, D_out]
341
- A = torch.stack(A) # [N, D_in]
342
- B = torch.stack(B) # [N, D_out]
343
- # Closed-form ridge: W = (A^T A + λI)^-1 A^T B
344
- AtA = A.T @ A
345
- D_in = AtA.shape[0]
346
- AtA_reg = AtA + l2_reg * torch.eye(D_in)
347
- W = torch.linalg.solve(AtA_reg, A.T @ B)
348
- return W # [D_in, D_out]
349
-
350
- def fit_ridge_with_cv(A, B, l2_values):
351
- # Simple holdout CV: 80/20 split
352
- if len(A) < 10:
353
- # Not enough data for split; fallback to middle lambda
354
- best_l2 = l2_values[min(len(l2_values) // 2, len(l2_values)-1)]
355
- W = fit_ridge_projection(A, B, best_l2)
356
- return W, best_l2, None
357
-
358
- N = len(A)
359
- idx = torch.randperm(N)
360
- split = int(0.8 * N)
361
- train_idx = idx[:split]
362
- val_idx = idx[split:]
363
-
364
- A_tensor = torch.stack(A)
365
- B_tensor = torch.stack(B)
366
-
367
- A_train, B_train = A_tensor[train_idx], B_tensor[train_idx]
368
- A_val, B_val = A_tensor[val_idx], B_tensor[val_idx]
369
-
370
- def to_list(t):
371
- return [row for row in t]
372
-
373
- best_l2 = None
374
- best_score = -1.0
375
- for l2 in l2_values:
376
- W = fit_ridge_projection(to_list(A_train), to_list(B_train), l2)
377
- score = mean_projected_cosine(to_list(A_val), to_list(B_val), W)
378
- if score > best_score:
379
- best_score = score
380
- best_l2 = l2
381
-
382
- # Refit on all with best_l2
383
- W_best = fit_ridge_projection(A, B, best_l2)
384
- return W_best, best_l2, best_score
385
-
386
- def mean_projected_cosine(A, B, W):
387
- A = torch.stack(A)
388
- B = torch.stack(B)
389
- A_proj = A @ W
390
- A_proj = F.normalize(A_proj, dim=1)
391
- B = F.normalize(B, dim=1)
392
- return torch.mean(torch.sum(A_proj * B, dim=1)).item()
393
-
394
- projection_report = {}
395
-
396
- if len(color_txt_As) >= 8:
397
- W_ct, best_l2_ct, cv_ct = fit_ridge_with_cv(color_txt_As, color_txt_Bs, l2_grid)
398
- projection_report['proj_sim_txt_color_part_mean'] = mean_projected_cosine(color_txt_As, color_txt_Bs, W_ct)
399
- projection_report['proj_txt_color_part_best_l2'] = best_l2_ct
400
- if cv_ct is not None:
401
- projection_report['proj_txt_color_part_cv_val'] = cv_ct
402
- if len(color_img_As) >= 8:
403
- W_ci, best_l2_ci, cv_ci = fit_ridge_with_cv(color_img_As, color_img_Bs, l2_grid)
404
- projection_report['proj_sim_img_color_part_mean'] = mean_projected_cosine(color_img_As, color_img_Bs, W_ci)
405
- projection_report['proj_img_color_part_best_l2'] = best_l2_ci
406
- if cv_ci is not None:
407
- projection_report['proj_img_color_part_cv_val'] = cv_ci
408
- if len(hier_txt_As) >= 8:
409
- W_ht, best_l2_ht, cv_ht = fit_ridge_with_cv(hier_txt_As, hier_txt_Bs, l2_grid)
410
- projection_report['proj_sim_txt_hierarchy_part_mean'] = mean_projected_cosine(hier_txt_As, hier_txt_Bs, W_ht)
411
- projection_report['proj_txt_hierarchy_part_best_l2'] = best_l2_ht
412
- if cv_ht is not None:
413
- projection_report['proj_txt_hierarchy_part_cv_val'] = cv_ht
414
- if len(hier_img_As) >= 8:
415
- W_hi, best_l2_hi, cv_hi = fit_ridge_with_cv(hier_img_As, hier_img_Bs, l2_grid)
416
- projection_report['proj_sim_img_hierarchy_part_mean'] = mean_projected_cosine(hier_img_As, hier_img_Bs, W_hi)
417
- projection_report['proj_img_hierarchy_part_best_l2'] = best_l2_hi
418
- if cv_hi is not None:
419
- projection_report['proj_img_hierarchy_part_cv_val'] = cv_hi
420
-
421
- proj_summary_path = os.path.join('evaluation_outputs', 'projection_summary.json')
422
- with open(proj_summary_path, 'w') as f:
423
- json.dump(projection_report, f, indent=2)
424
- print(f"Saved projection summary to {proj_summary_path}")
425
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Evaluation/evaluate_color_embeddings.py DELETED
@@ -1,1124 +0,0 @@
1
- """
2
- Comprehensive evaluation of color embeddings with Fashion-CLIP comparison.
3
- This file evaluates the quality of color embeddings generated by the ColorCLIP model
4
- by calculating intra-class and inter-class similarity metrics, classification accuracies,
5
- and generating confusion matrices. It also compares results with Fashion-CLIP as a baseline
6
- to measure relative performance.
7
- """
8
-
9
- import torch
10
- import torch.nn as nn
11
- import pandas as pd
12
- import numpy as np
13
- import matplotlib.pyplot as plt
14
- import seaborn as sns
15
- from sklearn.metrics.pairwise import cosine_similarity
16
- from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
17
- from collections import defaultdict
18
- import os
19
- import json
20
- from tqdm import tqdm
21
- from torch.utils.data import Dataset, DataLoader
22
- from torchvision import transforms
23
- import requests
24
- from io import BytesIO
25
- from PIL import Image
26
- import warnings
27
- warnings.filterwarnings('ignore')
28
- from color_model import ColorCLIP, Tokenizer, ImageEncoder, TextEncoder, collate_batch
29
- from torch.utils.data import DataLoader
30
- from transformers import CLIPProcessor, CLIPModel as TransformersCLIPModel
31
- import config
32
-
33
- class ColorDataset(Dataset):
34
- """
35
- Dataset class for color embedding evaluation.
36
-
37
- Handles loading images from various sources (local paths, URLs, bytes) and
38
- applying appropriate transformations for evaluation.
39
- """
40
- def __init__(self, dataframe):
41
- """
42
- Initialize the color dataset.
43
-
44
- Args:
45
- dataframe: DataFrame containing image paths/URLs, text, and color labels
46
- """
47
- self.dataframe = dataframe
48
- self.transform = transforms.Compose([
49
- transforms.Resize((224, 224)),
50
- transforms.ToTensor(),
51
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
52
- ])
53
-
54
- def __len__(self):
55
- return len(self.dataframe)
56
-
57
- def __getitem__(self, idx):
58
- row = self.dataframe.iloc[idx]
59
-
60
- # Handle image - it should be in row[config.column_url_image] and contain the image data
61
- image_data = row[config.column_url_image]
62
-
63
- try:
64
- # Check if image_data has 'bytes' key or is already PIL Image
65
- if isinstance(image_data, dict) and 'bytes' in image_data:
66
- image = Image.open(BytesIO(image_data['bytes'])).convert("RGB")
67
- elif hasattr(image_data, 'convert'): # Already a PIL Image
68
- image = image_data.convert("RGB")
69
- elif isinstance(image_data, str):
70
- # It's a file path (local or URL)
71
- if image_data.startswith('http'):
72
- # It's a URL - download the image
73
- response = requests.get(image_data, timeout=10)
74
- response.raise_for_status()
75
- image = Image.open(BytesIO(response.content)).convert("RGB")
76
- else:
77
- # It's a local file path
78
- image = Image.open(image_data).convert("RGB")
79
- else:
80
- # Assume it's bytes data
81
- image = Image.open(BytesIO(image_data)).convert("RGB")
82
-
83
- # Apply transform
84
- image = self.transform(image)
85
-
86
- except Exception as e:
87
- print(f"⚠️ Failed to load image {idx}: {e}")
88
- # Return a placeholder image
89
- image = torch.zeros(3, 224, 224)
90
-
91
- # Get text and color
92
- description = row[config.text_column]
93
- color = row[config.color_column]
94
-
95
- return image, description, color
96
-
97
- class EmbeddingEvaluator:
98
- """
99
- Evaluator for color embeddings generated by the ColorCLIP model.
100
-
101
- This class provides methods to evaluate the quality of color embeddings by computing
102
- similarity metrics, classification accuracies, and generating visualizations.
103
- """
104
-
105
- def __init__(self, model_path, embed_dim):
106
- """
107
- Initialize the embedding evaluator.
108
-
109
- Args:
110
- model_path: Path to the trained ColorCLIP model checkpoint
111
- embed_dim: Embedding dimension for the model
112
- """
113
- self.device = config.device
114
-
115
- # Initialize tokenizer first to get vocab size
116
- self.tokenizer = Tokenizer()
117
- vocab_size = None
118
-
119
- # Load vocabulary if available to determine vocab_size
120
- if os.path.exists(config.tokeniser_path):
121
- with open(config.tokeniser_path, 'r') as f:
122
- vocab_dict = json.load(f)
123
- # Manually load vocabulary
124
- self.tokenizer.word2idx = defaultdict(lambda: 0, {k: int(v) for k, v in vocab_dict.items()})
125
- self.tokenizer.idx2word = {int(v): k for k, v in vocab_dict.items() if int(v) > 0}
126
- self.tokenizer.counter = max(self.tokenizer.word2idx.values(), default=0) + 1
127
- vocab_size = self.tokenizer.counter
128
- print(f"Tokenizer vocabulary loaded from {config.tokeniser_path}")
129
- else:
130
- print(f"Warning: {config.tokeniser_path} not found. Using default tokenizer.")
131
-
132
- # Load checkpoint to get vocab_size and state_dict
133
- checkpoint = None
134
- if os.path.exists(model_path):
135
- checkpoint = torch.load(model_path, map_location=self.device)
136
-
137
- # Try to get vocab_size from model checkpoint if not already determined
138
- if vocab_size is None:
139
- # Try to get vocab_size from metadata
140
- if isinstance(checkpoint, dict) and 'vocab_size' in checkpoint:
141
- vocab_size = checkpoint['vocab_size']
142
- # Otherwise, try to infer from model state dict
143
- elif isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
144
- state_dict = checkpoint['model_state_dict']
145
- if 'text_encoder.embedding.weight' in state_dict:
146
- vocab_size = state_dict['text_encoder.embedding.weight'].shape[0]
147
- elif isinstance(checkpoint, dict) and 'text_encoder.embedding.weight' in checkpoint:
148
- vocab_size = checkpoint['text_encoder.embedding.weight'].shape[0]
149
-
150
- # Fallback to default if still not determined
151
- if vocab_size is None:
152
- vocab_size = 39 # Default fallback
153
- print(f"Warning: Could not determine vocab_size, using default: {vocab_size}")
154
-
155
- # Initialize model with determined vocab_size
156
- self.model = ColorCLIP(vocab_size=vocab_size, embedding_dim=embed_dim).to(self.device)
157
-
158
- # Load trained model state dict
159
- if checkpoint is not None:
160
- state_dict = checkpoint.get('model_state_dict', checkpoint)
161
- self.model.load_state_dict(state_dict)
162
- print(f"Model loaded from {model_path}")
163
- else:
164
- print(f"Warning: Model file {model_path} not found. Using untrained model.")
165
-
166
- self.model.eval()
167
-
168
- def extract_embeddings(self, dataloader, embedding_type='text'):
169
- """
170
- Extract embeddings from the model for a given dataloader.
171
-
172
- Args:
173
- dataloader: DataLoader containing images, texts, and colors
174
- embedding_type: Type of embeddings to extract ('text', 'image', or 'color')
175
-
176
- Returns:
177
- Tuple of (embeddings array, labels list, texts list)
178
- """
179
- all_embeddings = []
180
- all_labels = []
181
- all_texts = []
182
-
183
- with torch.no_grad():
184
- for images, texts, colors in tqdm(dataloader, desc=f"Extracting {embedding_type} embeddings"):
185
- if embedding_type == 'text':
186
- # Tokenize texts using the tokenizer
187
- tokenized_texts = [self.tokenizer(text) for text in texts]
188
- # Convert to tensors and pad sequences
189
- text_tensors = [torch.tensor(t, dtype=torch.long) for t in tokenized_texts]
190
- text_tokens = nn.utils.rnn.pad_sequence(text_tensors, batch_first=True, padding_value=0).to(self.device)
191
- lengths = torch.tensor([len(t) for t in tokenized_texts], dtype=torch.long).to(self.device)
192
- embeddings = self.model.text_encoder(text_tokens, lengths)
193
- labels = colors
194
- elif embedding_type == 'image':
195
- images = images.to(self.device)
196
- embeddings = self.model.image_encoder(images)
197
- labels = colors
198
- elif embedding_type == 'color':
199
- # Tokenize color names using the tokenizer
200
- tokenized_colors = [self.tokenizer(color) for color in colors]
201
- # Convert to tensors and pad sequences
202
- color_tensors = [torch.tensor(t, dtype=torch.long) for t in tokenized_colors]
203
- color_tokens = nn.utils.rnn.pad_sequence(color_tensors, batch_first=True, padding_value=0).to(self.device)
204
- lengths = torch.tensor([len(t) for t in tokenized_colors], dtype=torch.long).to(self.device)
205
- embeddings = self.model.text_encoder(color_tokens, lengths)
206
- labels = colors
207
-
208
- all_embeddings.append(embeddings.cpu().numpy())
209
- all_labels.extend(labels)
210
- all_texts.extend(texts)
211
-
212
- return np.vstack(all_embeddings), all_labels, all_texts
213
-
214
- def compute_similarity_metrics(self, embeddings, labels):
215
- """Compute intra-class and inter-class similarities"""
216
- similarities = cosine_similarity(embeddings)
217
-
218
- # Group embeddings by color
219
- color_groups = defaultdict(list)
220
- for i, color in enumerate(labels):
221
- color_groups[color].append(i)
222
-
223
- # Calculate intra-class similarities (same color)
224
- intra_class_similarities = []
225
- for color, indices in color_groups.items():
226
- if len(indices) > 1:
227
- for i in range(len(indices)):
228
- for j in range(i+1, len(indices)):
229
- sim = similarities[indices[i], indices[j]]
230
- intra_class_similarities.append(sim)
231
-
232
- # Calculate inter-class similarities (different colors)
233
- inter_class_similarities = []
234
- colors = list(color_groups.keys())
235
- for i in range(len(colors)):
236
- for j in range(i+1, len(colors)):
237
- color1_indices = color_groups[colors[i]]
238
- color2_indices = color_groups[colors[j]]
239
-
240
- for idx1 in color1_indices:
241
- for idx2 in color2_indices:
242
- sim = similarities[idx1, idx2]
243
- inter_class_similarities.append(sim)
244
-
245
- # Calculate classification accuracy using nearest neighbor in embedding space
246
- nn_accuracy = self.compute_embedding_accuracy(embeddings, labels, similarities)
247
-
248
- # Calculate classification accuracy using centroids
249
- centroid_accuracy = self.compute_centroid_accuracy(embeddings, labels)
250
-
251
- return {
252
- 'intra_class_similarities': intra_class_similarities,
253
- 'inter_class_similarities': inter_class_similarities,
254
- 'intra_class_mean': np.mean(intra_class_similarities) if intra_class_similarities else 0,
255
- 'inter_class_mean': np.mean(inter_class_similarities) if inter_class_similarities else 0,
256
- 'separation_score': np.mean(intra_class_similarities) - np.mean(inter_class_similarities) if intra_class_similarities and inter_class_similarities else 0,
257
- 'accuracy': nn_accuracy,
258
- 'centroid_accuracy': centroid_accuracy
259
- }
260
-
261
- def compute_embedding_accuracy(self, embeddings, labels, similarities):
262
- """Compute classification accuracy using nearest neighbor in embedding space"""
263
- correct_predictions = 0
264
- total_predictions = len(labels)
265
-
266
- for i in range(len(embeddings)):
267
- true_label = labels[i]
268
-
269
- # Find the most similar embedding (excluding itself)
270
- similarities_row = similarities[i].copy()
271
- similarities_row[i] = -1 # Exclude self-similarity
272
- nearest_neighbor_idx = np.argmax(similarities_row)
273
- predicted_label = labels[nearest_neighbor_idx]
274
-
275
- if predicted_label == true_label:
276
- correct_predictions += 1
277
-
278
- return correct_predictions / total_predictions if total_predictions > 0 else 0
279
-
280
- def compute_centroid_accuracy(self, embeddings, labels):
281
- """Compute classification accuracy using color centroids"""
282
- # Create centroids for each color
283
- unique_colors = list(set(labels))
284
- centroids = {}
285
-
286
- for color in unique_colors:
287
- color_indices = [i for i, label in enumerate(labels) if label == color]
288
- color_embeddings = embeddings[color_indices]
289
- centroids[color] = np.mean(color_embeddings, axis=0)
290
-
291
- # Classify each embedding to nearest centroid
292
- correct_predictions = 0
293
- total_predictions = len(labels)
294
-
295
- for i, embedding in enumerate(embeddings):
296
- true_label = labels[i]
297
-
298
- # Find closest centroid
299
- best_similarity = -1
300
- predicted_label = None
301
-
302
- for color, centroid in centroids.items():
303
- similarity = cosine_similarity([embedding], [centroid])[0][0]
304
- if similarity > best_similarity:
305
- best_similarity = similarity
306
- predicted_label = color
307
-
308
- if predicted_label == true_label:
309
- correct_predictions += 1
310
-
311
- return correct_predictions / total_predictions if total_predictions > 0 else 0
312
-
313
- def predict_colors_from_embeddings(self, embeddings, labels):
314
- """Predict colors from embeddings using centroid-based classification"""
315
- # Create color centroids from training data
316
- unique_colors = list(set(labels))
317
- centroids = {}
318
-
319
- for color in unique_colors:
320
- color_indices = [i for i, label in enumerate(labels) if label == color]
321
- color_embeddings = embeddings[color_indices]
322
- centroids[color] = np.mean(color_embeddings, axis=0)
323
-
324
- # Predict colors for all embeddings
325
- predictions = []
326
-
327
- for i, embedding in enumerate(embeddings):
328
- # Find closest centroid
329
- best_similarity = -1
330
- predicted_color = None
331
-
332
- for color, centroid in centroids.items():
333
- similarity = cosine_similarity([embedding], [centroid])[0][0]
334
- if similarity > best_similarity:
335
- best_similarity = similarity
336
- predicted_color = color
337
-
338
- predictions.append(predicted_color)
339
-
340
- return predictions
341
-
342
- def create_confusion_matrix(self, true_labels, predicted_labels, title="Confusion Matrix"):
343
- """Create and plot confusion matrix"""
344
- # Get unique labels
345
- unique_labels = sorted(list(set(true_labels + predicted_labels)))
346
-
347
- # Create confusion matrix
348
- cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels)
349
-
350
- # Calculate accuracy
351
- accuracy = accuracy_score(true_labels, predicted_labels)
352
-
353
- # Plot confusion matrix
354
- plt.figure(figsize=(12, 10))
355
- sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
356
- xticklabels=unique_labels, yticklabels=unique_labels)
357
- plt.title(f'{title}\nAccuracy: {accuracy:.3f} ({accuracy*100:.1f}%)')
358
- plt.ylabel('True Color')
359
- plt.xlabel('Predicted Color')
360
- plt.xticks(rotation=45)
361
- plt.yticks(rotation=0)
362
- plt.tight_layout()
363
-
364
- return plt.gcf(), accuracy, cm
365
-
366
- def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings"):
367
- """Evaluate classification performance and create confusion matrix"""
368
- # Predict colors
369
- predictions = self.predict_colors_from_embeddings(embeddings, labels)
370
-
371
- # Calculate accuracy
372
- accuracy = accuracy_score(labels, predictions)
373
-
374
- # Create confusion matrix
375
- fig, acc, cm = self.create_confusion_matrix(labels, predictions,
376
- f"{embedding_type} - Color Classification")
377
-
378
- # Generate classification report
379
- unique_labels = sorted(list(set(labels)))
380
- report = classification_report(labels, predictions, labels=unique_labels,
381
- target_names=unique_labels, output_dict=True)
382
-
383
- return {
384
- 'accuracy': accuracy,
385
- 'predictions': predictions,
386
- 'confusion_matrix': cm,
387
- 'classification_report': report,
388
- 'figure': fig
389
- }
390
-
391
- def evaluate_dataset(self, dataframe, dataset_name="Dataset"):
392
- """
393
- Evaluate embeddings on a given dataset.
394
-
395
- This method extracts embeddings for text, image, and color, computes similarity metrics,
396
- evaluates classification performance, and saves confusion matrices.
397
-
398
- Args:
399
- dataframe: DataFrame containing the dataset
400
- dataset_name: Name of the dataset for display purposes
401
-
402
- Returns:
403
- Dictionary containing evaluation results for text, image, and color embeddings
404
- """
405
- print(f"\n{'='*60}")
406
- print(f"Evaluating {dataset_name}")
407
- print(f"{'='*60}")
408
-
409
- # Create dataset and dataloader - use KaglDataset for kagl data
410
- if "kagl" in dataset_name.lower():
411
- dataset = KaglDataset(dataframe)
412
- else:
413
- dataset = ColorDataset(dataframe)
414
- # Optimize batch size and workers for faster processing
415
- dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)
416
-
417
- results = {}
418
-
419
- # Evaluate text embeddings
420
- text_embeddings, text_labels, texts = self.extract_embeddings(dataloader, 'text')
421
- text_metrics = self.compute_similarity_metrics(text_embeddings, text_labels)
422
- text_classification = self.evaluate_classification_performance(text_embeddings, text_labels, "Text Embeddings")
423
- text_metrics.update(text_classification)
424
- results['text'] = text_metrics
425
-
426
- # Evaluate image embeddings
427
- image_embeddings, image_labels, _ = self.extract_embeddings(dataloader, 'image')
428
- image_metrics = self.compute_similarity_metrics(image_embeddings, image_labels)
429
- image_classification = self.evaluate_classification_performance(image_embeddings, image_labels, "Image Embeddings")
430
- image_metrics.update(image_classification)
431
- results['image'] = image_metrics
432
-
433
- # Evaluate color embeddings
434
- color_embeddings, color_labels, _ = self.extract_embeddings(dataloader, 'color')
435
- color_metrics = self.compute_similarity_metrics(color_embeddings, color_labels)
436
- color_classification = self.evaluate_classification_performance(color_embeddings, color_labels, "Color Embeddings")
437
- color_metrics.update(color_classification)
438
- results['color'] = color_metrics
439
-
440
- # Print results
441
- print(f"\n{dataset_name} Results:")
442
- print("-" * 40)
443
- for emb_type, metrics in results.items():
444
- print(f"{emb_type.capitalize()} Embeddings:")
445
- print(f" Intra-class similarity (same color): {metrics['intra_class_mean']:.4f}")
446
- print(f" Inter-class similarity (diff colors): {metrics['inter_class_mean']:.4f}")
447
- print(f" Separation score: {metrics['separation_score']:.4f}")
448
- print(f" Nearest Neighbor Accuracy: {metrics['accuracy']:.4f} ({metrics['accuracy']*100:.1f}%)")
449
- print(f" Centroid Accuracy: {metrics['centroid_accuracy']:.4f} ({metrics['centroid_accuracy']*100:.1f}%)")
450
-
451
- # Classification report summary
452
- report = metrics['classification_report']
453
- print(f" 📊 Classification Performance:")
454
- print(f" • Macro Avg F1-Score: {report['macro avg']['f1-score']:.4f}")
455
- print(f" • Weighted Avg F1-Score: {report['weighted avg']['f1-score']:.4f}")
456
- print(f" • Support: {report['macro avg']['support']:.0f} samples")
457
- print()
458
-
459
- # Create visualizations
460
- os.makedirs('embedding_evaluation', exist_ok=True)
461
-
462
- # Confusion matrices
463
- results['text']['figure'].savefig(f'embedding_evaluation/{dataset_name.lower()}_text_confusion_matrix.png', dpi=300, bbox_inches='tight')
464
- plt.close(results['text']['figure'])
465
-
466
- results['image']['figure'].savefig(f'embedding_evaluation/{dataset_name.lower()}_image_confusion_matrix.png', dpi=300, bbox_inches='tight')
467
- plt.close(results['image']['figure'])
468
-
469
- results['color']['figure'].savefig(f'embedding_evaluation/{dataset_name.lower()}_color_confusion_matrix.png', dpi=300, bbox_inches='tight')
470
- plt.close(results['color']['figure'])
471
-
472
- return results
473
-
474
- class FashionCLIPDataset(Dataset):
475
- """
476
- Special dataset for Fashion-CLIP that doesn't normalize images.
477
-
478
- This dataset is used when evaluating with Fashion-CLIP baseline model,
479
- which requires different image preprocessing (no normalization).
480
- """
481
- def __init__(self, dataframe):
482
- """
483
- Initialize the Fashion-CLIP dataset.
484
-
485
- Args:
486
- dataframe: DataFrame containing image paths/URLs, text, and color labels
487
- """
488
- self.dataframe = dataframe
489
- # Only resize and convert to tensor, no normalization
490
- self.transform = transforms.Compose([
491
- transforms.Resize((224, 224)),
492
- transforms.ToTensor()
493
- ])
494
-
495
- def __len__(self):
496
- return len(self.dataframe)
497
-
498
- def __getitem__(self, idx):
499
- row = self.dataframe.iloc[idx]
500
-
501
- # Handle image - it should be in row[config.column_url_image] and contain the image data
502
- image_data = row[config.column_url_image]
503
-
504
- try:
505
- # Check if image_data has 'bytes' key or is already PIL Image
506
- if isinstance(image_data, dict) and 'bytes' in image_data:
507
- image = Image.open(BytesIO(image_data['bytes'])).convert("RGB")
508
- elif hasattr(image_data, 'convert'): # Already a PIL Image
509
- image = image_data.convert("RGB")
510
- elif isinstance(image_data, str):
511
- # It's a file path (local or URL)
512
- if image_data.startswith('http'):
513
- # It's a URL - download the image
514
- import requests
515
- response = requests.get(image_data, timeout=10)
516
- response.raise_for_status()
517
- image = Image.open(BytesIO(response.content)).convert("RGB")
518
- else:
519
- # It's a local file path
520
- image = Image.open(image_data).convert("RGB")
521
- else:
522
- # Assume it's bytes data
523
- image = Image.open(BytesIO(image_data)).convert("RGB")
524
-
525
- # Apply minimal transform (no normalization)
526
- image = self.transform(image)
527
-
528
- except Exception as e:
529
- print(f"⚠️ Failed to load image {idx}: {e}")
530
- # Return a placeholder image instead of undefined variable
531
- image = torch.zeros(3, 224, 224)
532
-
533
- # Get text and color
534
- description = row[config.text_column]
535
- color = row[config.color_column]
536
-
537
- return image, description, color
538
-
539
- class FashionCLIPEvaluator:
540
- """
541
- Evaluator for Fashion-CLIP baseline model.
542
-
543
- This class provides methods to evaluate embeddings from the Fashion-CLIP model
544
- and compare them with the custom ColorCLIP model.
545
- """
546
-
547
- def __init__(self):
548
- """
549
- Initialize the Fashion-CLIP evaluator.
550
-
551
- Loads the Fashion-CLIP model from Hugging Face and prepares it for evaluation.
552
- """
553
- # Load Fashion-CLIP model
554
- patrick_model_name = "patrickjohncyh/fashion-clip"
555
- print(f"🔄 Loading Fashion-CLIP model: {patrick_model_name}")
556
- self.processor = CLIPProcessor.from_pretrained(patrick_model_name)
557
- self.device = config.device
558
- self.model = TransformersCLIPModel.from_pretrained(patrick_model_name).to(self.device)
559
- self.model.eval()
560
- print(f"✅ Fashion-CLIP model loaded successfully")
561
-
562
- def extract_embeddings(self, dataloader, embedding_type='text'):
563
- """
564
- Extract embeddings from the Fashion-CLIP model.
565
-
566
- Args:
567
- dataloader: DataLoader containing images, texts, and colors
568
- embedding_type: Type of embeddings to extract ('text', 'image', or 'color')
569
-
570
- Returns:
571
- Tuple of (embeddings array, labels list, texts list)
572
- """
573
- all_embeddings = []
574
- all_labels = []
575
- all_texts = []
576
-
577
- with torch.no_grad():
578
- for images, texts, colors in tqdm(dataloader, desc=f"Extracting {embedding_type} embeddings (Fashion-CLIP)"):
579
- if embedding_type == 'text':
580
- # Process text through Fashion-CLIP
581
- inputs = self.processor(text=texts, return_tensors="pt", padding=True, truncation=True, max_length=77)
582
- inputs = {k: v.to(self.device) for k, v in inputs.items()}
583
- text_features = self.model.get_text_features(**inputs)
584
- text_features = text_features / text_features.norm(dim=-1, keepdim=True)
585
- embeddings = text_features.cpu().numpy()
586
- labels = colors
587
- elif embedding_type == 'image':
588
- # Convert tensors back to PIL images for CLIP processor
589
- pil_images = []
590
- for i in range(images.shape[0]):
591
- # Convert tensor back to PIL Image
592
- img_tensor = images[i]
593
- # Denormalize if needed (images should be in [0,1] range)
594
- if img_tensor.min() < 0 or img_tensor.max() > 1:
595
- # If normalized, denormalize
596
- img_tensor = (img_tensor + 1) / 2 # Assuming [-1,1] to [0,1]
597
- img_tensor = torch.clamp(img_tensor, 0, 1)
598
- img_pil = transforms.ToPILImage()(img_tensor)
599
- pil_images.append(img_pil)
600
-
601
- # Process images through Fashion-CLIP
602
- inputs = self.processor(images=pil_images, return_tensors="pt", padding=True)
603
- inputs = {k: v.to(self.device) for k, v in inputs.items()}
604
- image_features = self.model.get_image_features(**inputs)
605
- image_features = image_features / image_features.norm(dim=-1, keepdim=True)
606
- embeddings = image_features.cpu().numpy()
607
- labels = colors
608
- elif embedding_type == 'color':
609
- # Process color names as text through Fashion-CLIP
610
- inputs = self.processor(text=colors, return_tensors="pt", padding=True, truncation=True, max_length=77)
611
- inputs = {k: v.to(self.device) for k, v in inputs.items()}
612
- text_features = self.model.get_text_features(**inputs)
613
- text_features = text_features / text_features.norm(dim=-1, keepdim=True)
614
- embeddings = text_features.cpu().numpy()
615
- labels = colors
616
-
617
- all_embeddings.append(embeddings)
618
- all_labels.extend(labels)
619
- all_texts.extend(texts)
620
-
621
- return np.vstack(all_embeddings), all_labels, all_texts
622
-
623
- def compute_similarity_metrics(self, embeddings, labels):
624
- """Compute intra-class and inter-class similarities"""
625
- similarities = cosine_similarity(embeddings)
626
-
627
- # Group embeddings by color
628
- color_groups = defaultdict(list)
629
- for i, color in enumerate(labels):
630
- color_groups[color].append(i)
631
-
632
- # Calculate intra-class similarities (same color)
633
- intra_class_similarities = []
634
- for color, indices in color_groups.items():
635
- if len(indices) > 1:
636
- for i in range(len(indices)):
637
- for j in range(i+1, len(indices)):
638
- sim = similarities[indices[i], indices[j]]
639
- intra_class_similarities.append(sim)
640
-
641
- # Calculate inter-class similarities (different colors)
642
- inter_class_similarities = []
643
- colors = list(color_groups.keys())
644
- for i in range(len(colors)):
645
- for j in range(i+1, len(colors)):
646
- color1_indices = color_groups[colors[i]]
647
- color2_indices = color_groups[colors[j]]
648
-
649
- for idx1 in color1_indices:
650
- for idx2 in color2_indices:
651
- sim = similarities[idx1, idx2]
652
- inter_class_similarities.append(sim)
653
-
654
- # Calculate classification accuracy using nearest neighbor in embedding space
655
- nn_accuracy = self.compute_embedding_accuracy(embeddings, labels, similarities)
656
-
657
- # Calculate classification accuracy using centroids
658
- centroid_accuracy = self.compute_centroid_accuracy(embeddings, labels)
659
-
660
- return {
661
- 'intra_class_similarities': intra_class_similarities,
662
- 'inter_class_similarities': inter_class_similarities,
663
- 'intra_class_mean': np.mean(intra_class_similarities) if intra_class_similarities else 0,
664
- 'inter_class_mean': np.mean(inter_class_similarities) if inter_class_similarities else 0,
665
- 'separation_score': np.mean(intra_class_similarities) - np.mean(inter_class_similarities) if intra_class_similarities and inter_class_similarities else 0,
666
- 'accuracy': nn_accuracy,
667
- 'centroid_accuracy': centroid_accuracy
668
- }
669
-
670
- def compute_embedding_accuracy(self, embeddings, labels, similarities):
671
- """Compute classification accuracy using nearest neighbor in embedding space"""
672
- correct_predictions = 0
673
- total_predictions = len(labels)
674
-
675
- for i in range(len(embeddings)):
676
- true_label = labels[i]
677
-
678
- # Find the most similar embedding (excluding itself)
679
- similarities_row = similarities[i].copy()
680
- similarities_row[i] = -1 # Exclude self-similarity
681
- nearest_neighbor_idx = np.argmax(similarities_row)
682
- predicted_label = labels[nearest_neighbor_idx]
683
-
684
- if predicted_label == true_label:
685
- correct_predictions += 1
686
-
687
- return correct_predictions / total_predictions if total_predictions > 0 else 0
688
-
689
- def compute_centroid_accuracy(self, embeddings, labels):
690
- """Compute classification accuracy using color centroids"""
691
- # Create centroids for each color
692
- unique_colors = list(set(labels))
693
- centroids = {}
694
-
695
- for color in unique_colors:
696
- color_indices = [i for i, label in enumerate(labels) if label == color]
697
- color_embeddings = embeddings[color_indices]
698
- centroids[color] = np.mean(color_embeddings, axis=0)
699
-
700
- # Classify each embedding to nearest centroid
701
- correct_predictions = 0
702
- total_predictions = len(labels)
703
-
704
- for i, embedding in enumerate(embeddings):
705
- true_label = labels[i]
706
-
707
- # Find closest centroid
708
- best_similarity = -1
709
- predicted_label = None
710
-
711
- for color, centroid in centroids.items():
712
- similarity = cosine_similarity([embedding], [centroid])[0][0]
713
- if similarity > best_similarity:
714
- best_similarity = similarity
715
- predicted_label = color
716
-
717
- if predicted_label == true_label:
718
- correct_predictions += 1
719
-
720
- return correct_predictions / total_predictions if total_predictions > 0 else 0
721
-
722
- def predict_colors_from_embeddings(self, embeddings, labels):
723
- """Predict colors from embeddings using centroid-based classification"""
724
- # Create color centroids from training data
725
- unique_colors = list(set(labels))
726
- centroids = {}
727
-
728
- for color in unique_colors:
729
- color_indices = [i for i, label in enumerate(labels) if label == color]
730
- color_embeddings = embeddings[color_indices]
731
- centroids[color] = np.mean(color_embeddings, axis=0)
732
-
733
- # Predict colors for all embeddings
734
- predictions = []
735
-
736
- for i, embedding in enumerate(embeddings):
737
- # Find closest centroid
738
- best_similarity = -1
739
- predicted_color = None
740
-
741
- for color, centroid in centroids.items():
742
- similarity = cosine_similarity([embedding], [centroid])[0][0]
743
- if similarity > best_similarity:
744
- best_similarity = similarity
745
- predicted_color = color
746
-
747
- predictions.append(predicted_color)
748
-
749
- return predictions
750
-
751
- def create_confusion_matrix(self, true_labels, predicted_labels, title="Confusion Matrix"):
752
- """Create and plot confusion matrix"""
753
- # Get unique labels
754
- unique_labels = sorted(list(set(true_labels + predicted_labels)))
755
-
756
- # Create confusion matrix
757
- cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels)
758
-
759
- # Calculate accuracy
760
- accuracy = accuracy_score(true_labels, predicted_labels)
761
-
762
- # Plot confusion matrix
763
- plt.figure(figsize=(12, 10))
764
- sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
765
- xticklabels=unique_labels, yticklabels=unique_labels)
766
- plt.title(f'{title}\nAccuracy: {accuracy:.3f} ({accuracy*100:.1f}%)')
767
- plt.ylabel('True Color')
768
- plt.xlabel('Predicted Color')
769
- plt.xticks(rotation=45)
770
- plt.yticks(rotation=0)
771
- plt.tight_layout()
772
-
773
- return plt.gcf(), accuracy, cm
774
-
775
- def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings"):
776
- """Evaluate classification performance and create confusion matrix"""
777
- # Predict colors
778
- predictions = self.predict_colors_from_embeddings(embeddings, labels)
779
-
780
- # Calculate accuracy
781
- accuracy = accuracy_score(labels, predictions)
782
-
783
- # Create confusion matrix
784
- fig, acc, cm = self.create_confusion_matrix(labels, predictions,
785
- f"{embedding_type} - Color Classification (Fashion-CLIP)")
786
-
787
- # Generate classification report
788
- unique_labels = sorted(list(set(labels)))
789
- report = classification_report(labels, predictions, labels=unique_labels,
790
- target_names=unique_labels, output_dict=True)
791
-
792
- return {
793
- 'accuracy': accuracy,
794
- 'predictions': predictions,
795
- 'confusion_matrix': cm,
796
- 'classification_report': report,
797
- 'figure': fig
798
- }
799
-
800
- def evaluate_dataset(self, dataframe, dataset_name="Dataset"):
801
- """
802
- Evaluate Fashion-CLIP embeddings on a given dataset.
803
-
804
- This method extracts embeddings for text, image, and color, computes similarity metrics,
805
- evaluates classification performance, and saves confusion matrices.
806
-
807
- Args:
808
- dataframe: DataFrame containing the dataset
809
- dataset_name: Name of the dataset for display purposes
810
-
811
- Returns:
812
- Dictionary containing evaluation results for text, image, and color embeddings
813
- """
814
- print(f"\n{'='*60}")
815
- print(f"Evaluating {dataset_name} with Fashion-CLIP")
816
- print(f"{'='*60}")
817
-
818
- # Create dataset and dataloader - use FashionCLIPDataset for Fashion-CLIP
819
- if "kagl" in dataset_name.lower():
820
- dataset = KaglDataset(dataframe)
821
- else:
822
- dataset = FashionCLIPDataset(dataframe) # Use special dataset for Fashion-CLIP
823
- # Optimize batch size for Fashion-CLIP
824
- dataloader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)
825
-
826
- results = {}
827
-
828
- # Evaluate text embeddings
829
- text_embeddings, text_labels, texts = self.extract_embeddings(dataloader, 'text')
830
- text_metrics = self.compute_similarity_metrics(text_embeddings, text_labels)
831
- text_classification = self.evaluate_classification_performance(text_embeddings, text_labels, "Text Embeddings")
832
- text_metrics.update(text_classification)
833
- results['text'] = text_metrics
834
-
835
- # Evaluate image embeddings
836
- image_embeddings, image_labels, _ = self.extract_embeddings(dataloader, 'image')
837
- image_metrics = self.compute_similarity_metrics(image_embeddings, image_labels)
838
- image_classification = self.evaluate_classification_performance(image_embeddings, image_labels, "Image Embeddings")
839
- image_metrics.update(image_classification)
840
- results['image'] = image_metrics
841
-
842
- # Evaluate color embeddings
843
- color_embeddings, color_labels, _ = self.extract_embeddings(dataloader, 'color')
844
- color_metrics = self.compute_similarity_metrics(color_embeddings, color_labels)
845
- color_classification = self.evaluate_classification_performance(color_embeddings, color_labels, "Color Embeddings")
846
- color_metrics.update(color_classification)
847
- results['color'] = color_metrics
848
-
849
- # Print results
850
- print(f"\n{dataset_name} Results (Fashion-CLIP):")
851
- print("-" * 40)
852
- for emb_type, metrics in results.items():
853
- print(f"{emb_type.capitalize()} Embeddings:")
854
- print(f" Intra-class similarity (same color): {metrics['intra_class_mean']:.4f}")
855
- print(f" Inter-class similarity (diff colors): {metrics['inter_class_mean']:.4f}")
856
- print(f" Separation score: {metrics['separation_score']:.4f}")
857
- print(f" Nearest Neighbor Accuracy: {metrics['accuracy']:.4f} ({metrics['accuracy']*100:.1f}%)")
858
- print(f" Centroid Accuracy: {metrics['centroid_accuracy']:.4f} ({metrics['centroid_accuracy']*100:.1f}%)")
859
-
860
- # Classification report summary
861
- report = metrics['classification_report']
862
- print(f" 📊 Classification Performance:")
863
- print(f" • Macro Avg F1-Score: {report['macro avg']['f1-score']:.4f}")
864
- print(f" • Weighted Avg F1-Score: {report['weighted avg']['f1-score']:.4f}")
865
- print(f" • Support: {report['macro avg']['support']:.0f} samples")
866
- print()
867
-
868
- # Create visualizations
869
- os.makedirs('embedding_evaluation', exist_ok=True)
870
-
871
- # Confusion matrices
872
- results['text']['figure'].savefig(f'embedding_evaluation/{dataset_name.lower()}_text_confusion_matrix_fashion_clip.png', dpi=300, bbox_inches='tight')
873
- plt.close(results['text']['figure'])
874
-
875
- results['image']['figure'].savefig(f'embedding_evaluation/{dataset_name.lower()}_image_confusion_matrix_fashion_clip.png', dpi=300, bbox_inches='tight')
876
- plt.close(results['image']['figure'])
877
-
878
- results['color']['figure'].savefig(f'embedding_evaluation/{dataset_name.lower()}_color_confusion_matrix_fashion_clip.png', dpi=300, bbox_inches='tight')
879
- plt.close(results['color']['figure'])
880
-
881
- return results
882
-
883
- class KaglDataset(Dataset):
884
- """
885
- Dataset class for KAGL Marqo dataset evaluation.
886
-
887
- Handles loading images from the KAGL dataset format (with 'bytes' in image_url).
888
- """
889
- def __init__(self, dataframe):
890
- """
891
- Initialize the KAGL dataset.
892
-
893
- Args:
894
- dataframe: DataFrame containing image_url (with bytes), text, and color labels
895
- """
896
- self.dataframe = dataframe
897
- self.transform = transforms.Compose([
898
- transforms.Resize((224, 224)),
899
- transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
900
- transforms.ToTensor(),
901
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
902
- ])
903
-
904
- def __len__(self):
905
- return len(self.dataframe)
906
-
907
- def __getitem__(self, idx):
908
- row = self.dataframe.iloc[idx]
909
-
910
- # Handle image - it should be in row['image_url'] and contain the image data
911
- image_data = row["image_url"]
912
-
913
- # Check if image_data has 'bytes' key or is already PIL Image
914
- if isinstance(image_data, dict) and 'bytes' in image_data:
915
- image = Image.open(BytesIO(image_data['bytes'])).convert("RGB")
916
- elif hasattr(image_data, 'convert'): # Already a PIL Image
917
- image = image_data.convert("RGB")
918
- else:
919
- image = Image.open(BytesIO(image_data)).convert("RGB")
920
-
921
- image = self.transform(image)
922
-
923
- # Get text and color from kagl
924
- description = row['text']
925
- color = row['color']
926
-
927
- return image, description, color
928
-
929
- def load_kagl_marqo_dataset():
930
- """
931
- Load and prepare KAGL Marqo dataset from Hugging Face.
932
-
933
- This function loads the Marqo/KAGL dataset, filters for valid colors,
934
- and formats it for evaluation.
935
-
936
- Returns:
937
- DataFrame with columns: image_url, text, color
938
- """
939
- from datasets import load_dataset
940
- print("Loading kagl KAGL dataset...")
941
-
942
- # Load the dataset
943
- dataset = load_dataset("Marqo/KAGL")
944
- df = dataset["data"].to_pandas()
945
- print(f"✅ Dataset kagl loaded")
946
-
947
- # Prepare data - Replace baseColour
948
- df['baseColour'] = df['baseColour'].str.lower().str.replace("grey", "gray")
949
- df_test = df[df['baseColour'].notna()].copy()
950
-
951
- print(f"📊 Before filtering: {len(df_test)} samples")
952
-
953
- # Filter for common colors
954
- valid_colors = ['red', 'blue', 'green', 'yellow', 'purple', 'pink', 'orange',
955
- 'brown', 'black', 'white', 'gray', 'navy', 'maroon', 'beige']
956
- df_test = df_test[df_test['baseColour'].isin(valid_colors)]
957
-
958
- print(f"📊 After filtering invalid colors: {len(df_test)} samples")
959
- print(f"🎨 Valid colors found: {sorted(df_test['baseColour'].unique())}")
960
-
961
- if len(df_test) == 0:
962
- print("❌ No samples left after color filtering. Using mock dataset.")
963
-
964
- # Map to our expected column names
965
- kagl_formatted = pd.DataFrame({
966
- 'image_url': df_test['image_url'],
967
- 'text': df_test['text'],
968
- 'color': df_test['baseColour'].str.lower().str.replace("grey", "gray")
969
- })
970
-
971
- # Additional validation - remove rows with missing data
972
- print(f"📊 Before final validation: {len(kagl_formatted)} samples")
973
- kagl_formatted = kagl_formatted.dropna(subset=[config.column_url_image, config.text_column, config.color_column])
974
- print(f"📊 After removing missing data: {len(kagl_formatted)} samples")
975
-
976
- # Check for empty strings
977
- kagl_formatted = kagl_formatted[
978
- (kagl_formatted['text'].str.strip() != '') &
979
- (kagl_formatted['color'].str.strip() != '')
980
- ]
981
- print(f"📊 After removing empty strings: {len(kagl_formatted)} samples")
982
-
983
- print(f"📊 Final dataset size: {len(kagl_formatted)} samples")
984
-
985
- return kagl_formatted
986
-
987
- def create_comparison_table(val_results, kagl_results, val_results_fashion_clip, kagl_results_fashion_clip):
988
- """
989
- Create a structured comparison table between custom model and Fashion-CLIP baseline.
990
-
991
- Args:
992
- val_results: Evaluation results for custom model on validation dataset
993
- kagl_results: Evaluation results for custom model on KAGL dataset
994
- val_results_fashion_clip: Evaluation results for Fashion-CLIP on validation dataset
995
- kagl_results_fashion_clip: Evaluation results for Fashion-CLIP on KAGL dataset
996
-
997
- Returns:
998
- DataFrame containing the comparison table
999
- """
1000
-
1001
- # Create DataFrame for comparison
1002
- data = []
1003
-
1004
- # Define embedding types and their display names
1005
- embedding_types = [
1006
- ('text', 'Text Embeddings'),
1007
- ('image', 'Image Embeddings'),
1008
- ('color', 'Color Embeddings')
1009
- ]
1010
-
1011
- # Define datasets
1012
- datasets = [
1013
- ('Validation Dataset', val_results, val_results_fashion_clip),
1014
- ('kagl Marqo Dataset', kagl_results, kagl_results_fashion_clip)
1015
- ]
1016
-
1017
- for dataset_name, custom_results, baseline_results in datasets:
1018
- for emb_type, emb_display in embedding_types:
1019
- # Your custom model results
1020
- custom_metrics = custom_results[emb_type]
1021
- # Baseline model results
1022
- baseline_metrics = baseline_results[emb_type]
1023
-
1024
- data.append({
1025
- 'Dataset': dataset_name,
1026
- 'Embedding Type': emb_display,
1027
- 'Model': 'Your Model',
1028
- 'Separation Score': f"{custom_metrics['separation_score']:.4f}",
1029
- 'NN Accuracy (%)': f"{custom_metrics['accuracy']*100:.1f}%",
1030
- 'Centroid Accuracy (%)': f"{custom_metrics['centroid_accuracy']*100:.1f}%",
1031
- 'Intra-class Similarity': f"{custom_metrics['intra_class_mean']:.4f}",
1032
- 'Inter-class Similarity': f"{custom_metrics['inter_class_mean']:.4f}",
1033
- 'Macro F1-Score': f"{custom_metrics['classification_report']['macro avg']['f1-score']:.4f}",
1034
- 'Weighted F1-Score': f"{custom_metrics['classification_report']['weighted avg']['f1-score']:.4f}"
1035
- })
1036
-
1037
- data.append({
1038
- 'Dataset': dataset_name,
1039
- 'Embedding Type': emb_display,
1040
- 'Model': 'Fashion-CLIP (Baseline)',
1041
- 'Separation Score': f"{baseline_metrics['separation_score']:.4f}",
1042
- 'NN Accuracy (%)': f"{baseline_metrics['accuracy']*100:.1f}%",
1043
- 'Centroid Accuracy (%)': f"{baseline_metrics['centroid_accuracy']*100:.1f}%",
1044
- 'Intra-class Similarity': f"{baseline_metrics['intra_class_mean']:.4f}",
1045
- 'Inter-class Similarity': f"{baseline_metrics['inter_class_mean']:.4f}",
1046
- 'Macro F1-Score': f"{baseline_metrics['classification_report']['macro avg']['f1-score']:.4f}",
1047
- 'Weighted F1-Score': f"{baseline_metrics['classification_report']['weighted avg']['f1-score']:.4f}"
1048
- })
1049
-
1050
- # Create DataFrame
1051
- df_comparison = pd.DataFrame(data)
1052
-
1053
- # Save to CSV
1054
- df_comparison.to_csv('embedding_evaluation/model_comparison_table.csv', index=False)
1055
-
1056
- # Print formatted table
1057
- print(f"\n{'='*120}")
1058
- print("📊 COMPREHENSIVE MODEL COMPARISON TABLE")
1059
- print(f"{'='*120}")
1060
-
1061
- # Print table by dataset
1062
- for dataset_name in df_comparison['Dataset'].unique():
1063
- print(f"\n🔍 {dataset_name.upper()}")
1064
- print("-" * 120)
1065
-
1066
- dataset_df = df_comparison[df_comparison['Dataset'] == dataset_name]
1067
-
1068
- for emb_type in dataset_df['Embedding Type'].unique():
1069
- print(f"\n📈 {emb_type}:")
1070
- emb_df = dataset_df[dataset_df['Embedding Type'] == emb_type]
1071
-
1072
- # Print header
1073
- print(f"{'Model':<20} {'Separation':<12} {'NN Acc':<10} {'Centroid Acc':<13} {'Intra-class':<12} {'Inter-class':<12} {'Macro F1':<10} {'Weighted F1':<12}")
1074
- print("-" * 120)
1075
-
1076
- # Print data
1077
- for _, row in emb_df.iterrows():
1078
- print(f"{row['Model']:<20} {row['Separation Score']:<12} {row['NN Accuracy (%)']:<10} {row['Centroid Accuracy (%)']:<13} {row['Intra-class Similarity']:<12} {row['Inter-class Similarity']:<12} {row['Macro F1-Score']:<10} {row['Weighted F1-Score']:<12}")
1079
-
1080
- return df_comparison
1081
-
1082
- if __name__ == "__main__":
1083
-
1084
- # Initialize evaluator for your custom model
1085
- evaluator = EmbeddingEvaluator(model_path=config.color_model_path, embed_dim=config.color_emb_dim)
1086
-
1087
- # Initialize Fashion-CLIP evaluator
1088
- fashion_clip_evaluator = FashionCLIPEvaluator()
1089
-
1090
- # Load datasets
1091
- print("Loading datasets...")
1092
-
1093
- # Load validation dataset
1094
- df_val = pd.read_csv(config.local_dataset_path)
1095
-
1096
- # Filter for better quality data
1097
- print(f"📊 Original dataset size: {len(df_val)}")
1098
- samples_to_evaluate = 10000
1099
-
1100
-
1101
- # Load kagl Marqo dataset
1102
- kagl_df = load_kagl_marqo_dataset()
1103
-
1104
- # Evaluate your custom model on validation dataset
1105
- val_results = evaluator.evaluate_dataset(df_val, "Validation Dataset")
1106
-
1107
- # Evaluate your custom model on kagl Marqo dataset (reduced sample for speed)
1108
- kagl_results = evaluator.evaluate_dataset(kagl_df.sample(min(samples_to_evaluate, len(kagl_df)), random_state=42), "kagl Marqo Dataset")
1109
-
1110
- # Evaluate Fashion-CLIP on validation dataset
1111
- val_results_fashion_clip = fashion_clip_evaluator.evaluate_dataset(df_val, "Validation Dataset")
1112
-
1113
- # Create comprehensive comparison table
1114
- comparison_df = create_comparison_table(
1115
- val_results, kagl_results,
1116
- val_results_fashion_clip
1117
- )
1118
-
1119
- print(f"\n{'='*120}")
1120
- print("✅ Evaluation complete!")
1121
- print("📁 Confusion matrices saved in 'embedding_evaluation/' folder")
1122
- print("📁 Comparison table saved as 'model_comparison_table.csv'")
1123
- print("📁 Fashion-CLIP results are saved with '_fashion_clip' suffix.")
1124
- print(f"{'='*120}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Evaluation/fashion_search.py DELETED
@@ -1,365 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Fashion search system using multi-modal embeddings.
4
- This file implements a fashion search engine that allows searching for clothing items
5
- using text queries. It uses embeddings from the main model to calculate cosine similarities
6
- and return the most relevant items. The system pre-computes embeddings for all items
7
- in the dataset for fast search.
8
- """
9
-
10
- import torch
11
- import numpy as np
12
- import pandas as pd
13
- from PIL import Image
14
- import matplotlib.pyplot as plt
15
- from sklearn.metrics.pairwise import cosine_similarity
16
- from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
17
- import warnings
18
- import os
19
- from typing import List, Tuple, Union, Optional
20
- import argparse
21
-
22
- # Import custom models
23
- from color_model import CLIPModel as ColorModel
24
- from hierarchy_model import Model as HierarchyModel, HierarchyExtractor
25
- from main_model import CustomDataset
26
- import config
27
-
28
- warnings.filterwarnings("ignore")
29
-
30
- class FashionSearchEngine:
31
- """
32
- Fashion search engine using multi-modal embeddings with category emphasis
33
- """
34
-
35
- def __init__(self, top_k: int = 10, max_items: int = 10000):
36
- """
37
- Initialize the fashion search engine
38
- Args:
39
- top_k: Number of top results to return
40
- max_items: Maximum number of items to process (for faster initialization)
41
- hierarchy_weight: Weight for hierarchy/category dimensions (default: 2.0)
42
- color_weight: Weight for color dimensions (default: 1.0)
43
- """
44
- self.device = config.device
45
- self.top_k = top_k
46
- self.max_items = max_items
47
- self.color_dim = config.color_emb_dim
48
- self.hierarchy_dim = config.hierarchy_emb_dim
49
-
50
- # Load models
51
- self._load_models()
52
-
53
- # Load dataset
54
- self._load_dataset()
55
-
56
- # Pre-compute embeddings for all items
57
- self._precompute_embeddings()
58
-
59
- print("✅ Fashion Search Engine ready!")
60
-
61
- def _load_models(self):
62
- """Load all required models"""
63
- print("📦 Loading models...")
64
-
65
- # Load color model
66
- color_checkpoint = torch.load(config.color_model_path, map_location=self.device, weights_only=True)
67
- self.color_model = ColorModel(embed_dim=self.color_dim).to(self.device)
68
- self.color_model.load_state_dict(color_checkpoint)
69
- self.color_model.eval()
70
-
71
- # Load hierarchy model
72
- hierarchy_checkpoint = torch.load(config.hierarchy_model_path, map_location=self.device)
73
- self.hierarchy_classes = hierarchy_checkpoint.get('hierarchy_classes', [])
74
- self.hierarchy_model = HierarchyModel(
75
- num_hierarchy_classes=len(self.hierarchy_classes),
76
- embed_dim=self.hierarchy_dim
77
- ).to(self.device)
78
- self.hierarchy_model.load_state_dict(hierarchy_checkpoint['model_state'])
79
-
80
- # Set hierarchy extractor
81
- hierarchy_extractor = HierarchyExtractor(self.hierarchy_classes, verbose=False)
82
- self.hierarchy_model.set_hierarchy_extractor(hierarchy_extractor)
83
- self.hierarchy_model.eval()
84
-
85
- # Load main CLIP model - Use the trained model directly
86
- self.main_model = CLIPModel_transformers.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
87
-
88
- # Load the trained weights
89
- checkpoint = torch.load(config.main_model_path, map_location=self.device)
90
- if 'model_state_dict' in checkpoint:
91
- self.main_model.load_state_dict(checkpoint['model_state_dict'])
92
- else:
93
- # Fallback: try to load as state dict directly
94
- self.main_model.load_state_dict(checkpoint)
95
- print("✅ Loaded model weights directly")
96
-
97
- self.main_model.to(self.device)
98
- self.main_model.eval()
99
-
100
- # Load CLIP processor
101
- self.clip_processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
102
-
103
- print(f"✅ Models loaded - Colors: {self.color_dim}D, Hierarchy: {self.hierarchy_dim}D")
104
-
105
- def _load_dataset(self):
106
- """Load the fashion dataset"""
107
- print("📊 Loading dataset...")
108
-
109
- # Load dataset
110
- self.df = pd.read_csv(config.local_dataset_path)
111
- self.df_clean = self.df.dropna(subset=[config.column_local_image_path])
112
-
113
- # Create dataset object
114
- self.dataset = CustomDataset(self.df_clean)
115
- self.dataset.set_training_mode(False) # No augmentation for search
116
-
117
- print(f"✅ {len(self.df_clean)} items loaded for search")
118
-
119
- def _precompute_embeddings(self):
120
- """Pre-compute embeddings for all items in the dataset"""
121
- print("🔄 Pre-computing embeddings...")
122
-
123
- # OPTIMIZATION: Sample a subset for faster initialization
124
- print(f"⚠️ Dataset too large ({len(self.dataset)} items). Using stratified sampling of 10 items per color-category combination.")
125
-
126
- # Stratified sampling by color-category combinations
127
- sampled_df = self.df_clean.groupby([config.color_column, config.hierarchy_column]).sample(n=20, replace=False)
128
-
129
- # Get the original indices of sampled items
130
- sampled_indices = sampled_df.index.tolist()
131
-
132
- all_embeddings = []
133
- all_texts = []
134
- all_colors = []
135
- all_hierarchies = []
136
- all_images = []
137
- all_urls = []
138
-
139
- # Process in batches for efficiency
140
- batch_size = 32
141
-
142
- # Add progress bar
143
- from tqdm import tqdm
144
- total_batches = (len(sampled_indices) + batch_size - 1) // batch_size
145
-
146
- for i in tqdm(range(0, len(sampled_indices), batch_size),
147
- desc="Computing embeddings",
148
- total=total_batches):
149
- batch_end = min(i + batch_size, len(sampled_indices))
150
- batch_items = []
151
-
152
- for j in range(i, batch_end):
153
- try:
154
- # Use the original dataset with the sampled index
155
- original_idx = sampled_indices[j]
156
- image, text, color, hierarchy = self.dataset[original_idx]
157
- batch_items.append((image, text, color, hierarchy))
158
- all_texts.append(text)
159
- all_colors.append(color)
160
- all_hierarchies.append(hierarchy)
161
- all_images.append(self.df_clean.iloc[original_idx][config.column_local_image_path])
162
- all_urls.append(self.df_clean.iloc[original_idx][config.column_url_image])
163
- except Exception as e:
164
- print(f"⚠️ Skipping item {j}: {e}")
165
- continue
166
-
167
- if not batch_items:
168
- continue
169
-
170
- # Process batch
171
- images = torch.stack([item[0] for item in batch_items]).to(self.device)
172
- texts = [item[1] for item in batch_items]
173
-
174
- with torch.no_grad():
175
- # Get embeddings from main model (text embeddings only)
176
- text_inputs = self.clip_processor(text=texts, padding=True, return_tensors="pt")
177
- text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
178
-
179
- # Create dummy images for the model
180
- dummy_images = torch.zeros(len(texts), 3, 224, 224).to(self.device)
181
-
182
- outputs = self.main_model(**text_inputs, pixel_values=dummy_images)
183
- embeddings = outputs.text_embeds.cpu().numpy()
184
-
185
- all_embeddings.extend(embeddings)
186
-
187
- self.all_embeddings = np.array(all_embeddings)
188
- self.all_texts = all_texts
189
- self.all_colors = all_colors
190
- self.all_hierarchies = all_hierarchies
191
- self.all_images = all_images
192
- self.all_urls = all_urls
193
-
194
- print(f"✅ Pre-computed embeddings for {len(self.all_embeddings)} items")
195
-
196
- def search_by_text(self, query_text: str, filter_category: str = None) -> List[dict]:
197
- """
198
- Search for clothing items using text query
199
-
200
- Args:
201
- query_text: Text description to search for
202
-
203
- Returns:
204
- List of dictionaries containing search results
205
- """
206
- print(f"🔍 Searching for: '{query_text}'")
207
-
208
- # Get query embedding
209
- with torch.no_grad():
210
- text_inputs = self.clip_processor(text=[query_text], padding=True, return_tensors="pt")
211
- text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
212
-
213
- # Create a dummy image tensor to satisfy the model's requirements
214
- dummy_image = torch.zeros(1, 3, 224, 224).to(self.device)
215
-
216
- outputs = self.main_model(**text_inputs, pixel_values=dummy_image)
217
- query_embedding = outputs.text_embeds.cpu().numpy()
218
-
219
- # Calculate similarities
220
- similarities = cosine_similarity(query_embedding, self.all_embeddings)[0]
221
-
222
- # Get top-k results
223
- top_indices = np.argsort(similarities)[::-1][:self.top_k * 2] # Prendre plus de résultats
224
-
225
- results = []
226
- for idx in top_indices:
227
- if similarities[idx] > -0.5:
228
- # Filter by category if specified
229
- if filter_category and filter_category.lower() not in self.all_hierarchies[idx].lower():
230
- continue
231
-
232
- results.append({
233
- 'rank': len(results) + 1,
234
- 'image_path': self.all_images[idx],
235
- 'text': self.all_texts[idx],
236
- 'color': self.all_colors[idx],
237
- 'hierarchy': self.all_hierarchies[idx],
238
- 'similarity': float(similarities[idx]),
239
- 'index': int(idx),
240
- 'url': self.all_urls[idx]
241
- })
242
-
243
- if len(results) >= self.top_k:
244
- break
245
-
246
- print(f"✅ Found {len(results)} results")
247
- return results
248
-
249
-
250
- def display_results(self, results: List[dict], query_info: str = ""):
251
- """
252
- Display search results with images and information
253
-
254
- Args:
255
- results: List of search result dictionaries
256
- query_info: Information about the query
257
- """
258
- if not results:
259
- print("❌ No results found")
260
- return
261
-
262
- print(f"\n🎯 Search Results for: {query_info}")
263
- print("=" * 80)
264
-
265
- # Calculate grid layout
266
- n_results = len(results)
267
- cols = min(5, n_results)
268
- rows = (n_results + cols - 1) // cols
269
-
270
- fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows))
271
- if rows == 1:
272
- axes = axes.reshape(1, -1)
273
- elif cols == 1:
274
- axes = axes.reshape(-1, 1)
275
-
276
- for i, result in enumerate(results):
277
- row = i // cols
278
- col = i % cols
279
- ax = axes[row, col]
280
-
281
- try:
282
- # Load and display image
283
- image = Image.open(result['image_path'])
284
- ax.imshow(image)
285
- ax.axis('off')
286
-
287
- # Add title with similarity score
288
- title = f"#{result['rank']} (Similarity: {result['similarity']:.3f})\n{result['color']} {result['hierarchy']}"
289
- ax.set_title(title, fontsize=10, wrap=True)
290
-
291
- except Exception as e:
292
- ax.text(0.5, 0.5, f"Error loading image\n{result['image_path']}",
293
- ha='center', va='center', transform=ax.transAxes)
294
- ax.axis('off')
295
-
296
- # Hide empty subplots
297
- for i in range(n_results, rows * cols):
298
- row = i // cols
299
- col = i % cols
300
- axes[row, col].axis('off')
301
-
302
- plt.tight_layout()
303
- plt.show()
304
-
305
- # Print detailed results
306
- print("\n📋 Detailed Results:")
307
- for result in results:
308
- print(f"#{result['rank']:2d} | Similarity: {result['similarity']:.3f} | "
309
- f"Color: {result['color']:12s} | Category: {result['hierarchy']:15s} | "
310
- f"Text: {result['text'][:50]}...")
311
- print(f" 🔗 URL: {result['url']}")
312
- print()
313
-
314
-
315
- def main():
316
- """Main function for command-line usage"""
317
- parser = argparse.ArgumentParser(description="Fashion Search Engine with Category Emphasis")
318
- parser.add_argument("--query", "-q", type=str, help="Search query")
319
- parser.add_argument("--top-k", "-k", type=int, default=10, help="Number of results (default: 10)")
320
- parser.add_argument("--fast", "-f", action="store_true", help="Fast mode (less items)")
321
- parser.add_argument("--interactive", "-i", action="store_true", help="Interactive mode")
322
-
323
- args = parser.parse_args()
324
-
325
- print("🎯 Fashion Search Engine with Category Emphasis")
326
-
327
- search_engine = FashionSearchEngine(
328
- top_k=args.top_k,
329
- )
330
- print("✅ Ready!")
331
-
332
- # Single query mode
333
- if args.query:
334
- print(f"🔍 Search: '{args.query}'...")
335
- results = search_engine.search_by_text(args.query)
336
- search_engine.display_results(results, args.query)
337
-
338
-
339
- # Interactive mode
340
- print("Enter your query (e.g. 'red dress') or 'quit' to exit")
341
-
342
- while True:
343
- try:
344
- user_input = input("\n🔍 Query: ").strip()
345
- if not user_input or user_input.lower() in ['quit', 'exit', 'q']:
346
- print("👋 Goodbye!")
347
- break
348
-
349
- if user_input.startswith('verify '):
350
- if 'yellow accessories' in user_input:
351
- search_engine.display_yellow_accessories()
352
- continue
353
-
354
- print(f"🔍 Search: '{user_input}'...")
355
- results = search_engine.search_by_text(user_input)
356
- search_engine.display_results(results, user_input)
357
-
358
- except KeyboardInterrupt:
359
- print("\n👋 Goodbye!")
360
- break
361
- except Exception as e:
362
- print(f"❌ Error: {e}")
363
-
364
- if __name__ == "__main__":
365
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Evaluation/hierarchy_evaluation.py DELETED
@@ -1,589 +0,0 @@
1
- """
2
- Hierarchy embedding evaluation for clothing category classification.
3
- This file evaluates the quality of hierarchy embeddings generated by the hierarchy model
4
- by calculating intra-class and inter-class similarity metrics, nearest neighbor and centroid-based
5
- classification accuracies, and generating confusion matrices. It can be used on different datasets
6
- (local validation, Kagl Marqo) to measure model generalization.
7
- """
8
-
9
- import torch
10
- import pandas as pd
11
- import numpy as np
12
- import matplotlib.pyplot as plt
13
- import seaborn as sns
14
- from sklearn.metrics.pairwise import cosine_similarity
15
- from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
16
- from collections import defaultdict
17
- import os
18
- from tqdm import tqdm
19
- from torch.utils.data import Dataset, DataLoader
20
- from torchvision import transforms
21
- from sklearn.model_selection import train_test_split
22
- from io import BytesIO
23
- from PIL import Image
24
- import config
25
- import warnings
26
- warnings.filterwarnings('ignore')
27
- from hierarchy_model import Model, HierarchyExtractor, HierarchyDataset, collate_fn
28
-
29
-
30
- class EmbeddingEvaluator:
31
- """
32
- Evaluator for hierarchy embeddings generated by the hierarchy model.
33
-
34
- This class provides methods to evaluate the quality of hierarchy embeddings by computing
35
- similarity metrics, classification accuracies, and generating visualizations.
36
- """
37
-
38
- def __init__(self, model_path, directory):
39
- """
40
- Initialize the embedding evaluator.
41
-
42
- Args:
43
- model_path: Path to the trained hierarchy model checkpoint
44
- directory: Directory to save evaluation results and visualizations
45
- """
46
- self.device = config.device
47
- self.directory = directory
48
-
49
- # 1. Load the dataset
50
- CSV = config.local_dataset_path
51
- print(f"📁 Using dataset with local images: {CSV}")
52
- df = pd.read_csv(CSV)
53
-
54
- print(f"📁 Loaded {len(df)} samples")
55
-
56
- # 2. Get unique hierarchy classes from the dataset
57
- hierarchy_classes = sorted(df[config.hierarchy_column].unique().tolist())
58
- print(f"📋 Found {len(hierarchy_classes)} hierarchy classes")
59
-
60
- _, self.val_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df[config.hierarchy_column])
61
-
62
- # 3. Load the model
63
- if os.path.exists(model_path):
64
- checkpoint = torch.load(model_path, map_location=self.device)
65
-
66
- # Use model_config to avoid shadowing the imported config module
67
- model_config = checkpoint.get('config', {})
68
- saved_hierarchy_classes = checkpoint['hierarchy_classes']
69
-
70
- # Use the saved hierarchy classes
71
- self.hierarchy_classes = saved_hierarchy_classes
72
-
73
- # Create the hierarchy extractor
74
- self.vocab = HierarchyExtractor(saved_hierarchy_classes)
75
-
76
- # Create the model with the saved configuration
77
- self.model = Model(
78
- num_hierarchy_classes=len(saved_hierarchy_classes),
79
- embed_dim=model_config['embed_dim'],
80
- dropout=model_config['dropout']
81
- ).to(self.device)
82
-
83
- self.model.load_state_dict(checkpoint['model_state'])
84
-
85
- print(f"✅ Model loaded with:")
86
- print(f"📋 Hierarchy classes: {len(saved_hierarchy_classes)}")
87
- print(f"🎯 Embed dim: {model_config['embed_dim']}")
88
- print(f"💧 Dropout: {model_config['dropout']}")
89
- print(f"📅 Epoch: {checkpoint.get('epoch', 'unknown')}")
90
-
91
- else:
92
- raise FileNotFoundError(f"Model file {model_path} not found")
93
-
94
- self.model.eval()
95
-
96
- def create_dataloader(self, dataframe, batch_size=16):
97
- """
98
- Create a DataLoader for the hierarchy dataset.
99
-
100
- Args:
101
- dataframe: DataFrame containing the dataset
102
- batch_size: Batch size for the DataLoader
103
-
104
- Returns:
105
- DataLoader instance
106
- """
107
- dataset = HierarchyDataset(dataframe, image_size=224)
108
-
109
- dataloader = DataLoader(
110
- dataset,
111
- batch_size=batch_size,
112
- shuffle=False,
113
- collate_fn=lambda batch: collate_fn(batch, self.vocab),
114
- num_workers=0
115
- )
116
-
117
- return dataloader
118
-
119
- def extract_embeddings(self, dataloader, embedding_type='text'):
120
- """
121
- Extract embeddings from the model for a given dataloader.
122
-
123
- Args:
124
- dataloader: DataLoader containing images, texts, and hierarchy labels
125
- embedding_type: Type of embeddings to extract ('text' or 'image')
126
-
127
- Returns:
128
- Tuple of (embeddings array, labels list, texts list)
129
- """
130
- all_embeddings = []
131
- all_labels = []
132
- all_texts = []
133
-
134
- with torch.no_grad():
135
- for batch in tqdm(dataloader, desc=f"Extracting {embedding_type} embeddings"):
136
- images = batch['image'].to(self.device)
137
- hierarchy_indices = batch['hierarchy_indices'].to(self.device)
138
- hierarchy_labels = batch['hierarchy']
139
-
140
- # Forward pass
141
- out = self.model(image=images, hierarchy_indices=hierarchy_indices)
142
- embeddings = out['z_txt'] if embedding_type == 'text' else out['z_img'] if embedding_type == 'image' else out['z_txt']
143
-
144
- all_embeddings.append(embeddings.cpu().numpy())
145
- all_labels.extend(hierarchy_labels)
146
- all_texts.extend(hierarchy_labels)
147
-
148
- return np.vstack(all_embeddings), all_labels, all_texts
149
-
150
- def compute_similarity_metrics(self, embeddings, labels):
151
- """
152
- Compute intra-class and inter-class similarity metrics.
153
-
154
- Args:
155
- embeddings: Array of embeddings [N, embed_dim]
156
- labels: List of labels for each embedding
157
-
158
- Returns:
159
- Dictionary containing similarity metrics, accuracies, and separation scores
160
- """
161
- similarities = cosine_similarity(embeddings)
162
-
163
- # Group embeddings by hierarchy
164
- hierarchy_groups = defaultdict(list)
165
- for i, hierarchy in enumerate(labels):
166
- hierarchy_groups[hierarchy].append(i)
167
-
168
- # Calculate intra-class similarities (same hierarchy)
169
- intra_class_similarities = []
170
- for hierarchy, indices in hierarchy_groups.items():
171
- if len(indices) > 1:
172
- for i in range(len(indices)):
173
- for j in range(i+1, len(indices)):
174
- sim = similarities[indices[i], indices[j]]
175
- intra_class_similarities.append(sim)
176
-
177
-
178
- # Calculate inter-class similarities (different hierarchies)
179
- inter_class_similarities = []
180
- hierarchies = list(hierarchy_groups.keys())
181
- for i in range(len(hierarchies)):
182
- for j in range(i+1, len(hierarchies)):
183
- hierarchy1_indices = hierarchy_groups[hierarchies[i]]
184
- hierarchy2_indices = hierarchy_groups[hierarchies[j]]
185
-
186
- for idx1 in hierarchy1_indices:
187
- for idx2 in hierarchy2_indices:
188
- sim = similarities[idx1, idx2]
189
- inter_class_similarities.append(sim)
190
-
191
- # Calculate classification accuracy using nearest neighbor in embedding space
192
- nn_accuracy = self.compute_embedding_accuracy(embeddings, labels, similarities)
193
-
194
- # Calculate classification accuracy using centroids
195
- centroid_accuracy = self.compute_centroid_accuracy(embeddings, labels)
196
-
197
- return {
198
- 'intra_class_similarities': intra_class_similarities,
199
- 'inter_class_similarities': inter_class_similarities,
200
- 'intra_class_mean': np.mean(intra_class_similarities) if intra_class_similarities else 0,
201
- 'inter_class_mean': np.mean(inter_class_similarities) if inter_class_similarities else 0,
202
- 'separation_score': np.mean(intra_class_similarities) - np.mean(inter_class_similarities) if intra_class_similarities and inter_class_similarities else 0,
203
- 'accuracy': nn_accuracy,
204
- 'centroid_accuracy': centroid_accuracy
205
- }
206
-
207
- def compute_embedding_accuracy(self, embeddings, labels, similarities):
208
- """
209
- Compute classification accuracy using nearest neighbor in embedding space.
210
-
211
- Args:
212
- embeddings: Array of embeddings [N, embed_dim]
213
- labels: List of true labels
214
- similarities: Pre-computed similarity matrix [N, N]
215
-
216
- Returns:
217
- Accuracy score (float between 0 and 1)
218
- """
219
- correct_predictions = 0
220
- total_predictions = len(labels)
221
-
222
- for i in range(len(embeddings)):
223
- true_label = labels[i]
224
-
225
- # Find the most similar embedding (excluding itself)
226
- similarities_row = similarities[i].copy()
227
- similarities_row[i] = -1 # Exclude self-similarity
228
- nearest_neighbor_idx = np.argmax(similarities_row)
229
- predicted_label = labels[nearest_neighbor_idx]
230
-
231
- if predicted_label == true_label:
232
- correct_predictions += 1
233
-
234
- return correct_predictions / total_predictions if total_predictions > 0 else 0
235
-
236
- def compute_centroid_accuracy(self, embeddings, labels):
237
- """
238
- Compute classification accuracy using hierarchy centroids.
239
-
240
- Each hierarchy class is represented by its centroid (mean embedding), and each
241
- embedding is classified to the nearest centroid.
242
-
243
- Args:
244
- embeddings: Array of embeddings [N, embed_dim]
245
- labels: List of true labels
246
-
247
- Returns:
248
- Accuracy score (float between 0 and 1)
249
- """
250
- # Create centroids for each hierarchy
251
- unique_hierarchies = list(set(labels))
252
- centroids = {}
253
-
254
- for hierarchy in unique_hierarchies:
255
- hierarchy_indices = [i for i, label in enumerate(labels) if label == hierarchy]
256
- hierarchy_embeddings = embeddings[hierarchy_indices]
257
- centroids[hierarchy] = np.mean(hierarchy_embeddings, axis=0)
258
-
259
- # Classify each embedding to nearest centroid
260
- correct_predictions = 0
261
- total_predictions = len(labels)
262
-
263
- for i, embedding in enumerate(embeddings):
264
- true_label = labels[i]
265
-
266
- # Find closest centroid
267
- best_similarity = -1
268
- predicted_label = None
269
-
270
- for hierarchy, centroid in centroids.items():
271
- similarity = cosine_similarity([embedding], [centroid])[0][0]
272
- if similarity > best_similarity:
273
- best_similarity = similarity
274
- predicted_label = hierarchy
275
-
276
- if predicted_label == true_label:
277
- correct_predictions += 1
278
-
279
- return correct_predictions / total_predictions if total_predictions > 0 else 0
280
-
281
- def predict_hierarchy_from_embeddings(self, embeddings, labels):
282
- """
283
- Predict hierarchy from embeddings using centroid-based classification.
284
-
285
- Args:
286
- embeddings: Array of embeddings [N, embed_dim]
287
- labels: List of labels used to compute centroids
288
-
289
- Returns:
290
- List of predicted hierarchy labels
291
- """
292
- # Create hierarchy centroids from training data
293
- unique_hierarchies = list(set(labels))
294
- centroids = {}
295
-
296
- for hierarchy in unique_hierarchies:
297
- hierarchy_indices = [i for i, label in enumerate(labels) if label == hierarchy]
298
- hierarchy_embeddings = embeddings[hierarchy_indices]
299
- centroids[hierarchy] = np.mean(hierarchy_embeddings, axis=0)
300
-
301
- # Predict hierarchy for all embeddings
302
- predictions = []
303
-
304
- for i, embedding in enumerate(embeddings):
305
- # Find closest centroid
306
- best_similarity = -1
307
- predicted_hierarchy = None
308
-
309
- for hierarchy, centroid in centroids.items():
310
- similarity = cosine_similarity([embedding], [centroid])[0][0]
311
- if similarity > best_similarity:
312
- best_similarity = similarity
313
- predicted_hierarchy = hierarchy
314
-
315
- predictions.append(predicted_hierarchy)
316
-
317
- return predictions
318
-
319
- def create_confusion_matrix(self, true_labels, predicted_labels, title="Confusion Matrix"):
320
- """
321
- Create and plot a confusion matrix.
322
-
323
- Args:
324
- true_labels: List of true labels
325
- predicted_labels: List of predicted labels
326
- title: Title for the confusion matrix plot
327
-
328
- Returns:
329
- Tuple of (figure, accuracy, confusion_matrix)
330
- """
331
- # Get unique labels
332
- unique_labels = sorted(list(set(true_labels + predicted_labels)))
333
-
334
- # Create confusion matrix
335
- cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels)
336
-
337
- # Calculate accuracy
338
- accuracy = accuracy_score(true_labels, predicted_labels)
339
-
340
- # Plot confusion matrix
341
- plt.figure(figsize=(12, 10))
342
- sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
343
- xticklabels=unique_labels, yticklabels=unique_labels)
344
- plt.title(f'{title}\nAccuracy: {accuracy:.3f} ({accuracy*100:.1f}%)')
345
- plt.ylabel('True Hierarchy')
346
- plt.xlabel('Predicted Hierarchy')
347
- plt.xticks(rotation=45)
348
- plt.yticks(rotation=0)
349
- plt.tight_layout()
350
-
351
- return plt.gcf(), accuracy, cm
352
-
353
- def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings"):
354
- """
355
- Evaluate classification performance and create confusion matrix.
356
-
357
- Args:
358
- embeddings: Array of embeddings [N, embed_dim]
359
- labels: List of true labels
360
- embedding_type: Type of embeddings for display purposes
361
-
362
- Returns:
363
- Dictionary containing accuracy, predictions, confusion matrix, and classification report
364
- """
365
- # Predict hierarchy
366
- predictions = self.predict_hierarchy_from_embeddings(embeddings, labels)
367
-
368
- # Calculate accuracy
369
- accuracy = accuracy_score(labels, predictions)
370
-
371
- # Create confusion matrix
372
- fig, acc, cm = self.create_confusion_matrix(labels, predictions,
373
- f"{embedding_type} - Hierarchy Classification")
374
-
375
- # Generate classification report
376
- unique_labels = sorted(list(set(labels)))
377
- report = classification_report(labels, predictions, labels=unique_labels,
378
- target_names=unique_labels, output_dict=True)
379
-
380
- return {
381
- 'accuracy': accuracy,
382
- 'predictions': predictions,
383
- 'confusion_matrix': cm,
384
- 'classification_report': report,
385
- 'figure': fig
386
- }
387
-
388
- def evaluate_dataset(self, dataframe, dataset_name="Dataset"):
389
- """
390
- Evaluate embeddings on a given dataset.
391
-
392
- This method extracts embeddings for text and image, computes similarity metrics,
393
- evaluates classification performance, and saves confusion matrices.
394
-
395
- Args:
396
- dataframe: DataFrame containing the dataset
397
- dataset_name: Name of the dataset for display purposes
398
-
399
- Returns:
400
- Dictionary containing evaluation results for text and image embeddings
401
- """
402
- print(f"\n{'='*60}")
403
- print(f"Evaluating {dataset_name}")
404
- print(f"{'='*60}")
405
-
406
- # Create dataloader exactly as during training
407
- dataloader = self.create_dataloader(dataframe, batch_size=16)
408
-
409
- results = {}
410
-
411
- # Evaluate text embeddings
412
- text_embeddings, text_labels, texts = self.extract_embeddings(dataloader, 'text')
413
- text_metrics = self.compute_similarity_metrics(text_embeddings, text_labels)
414
- text_classification = self.evaluate_classification_performance(text_embeddings, text_labels, "Text Embeddings")
415
- text_metrics.update(text_classification)
416
- results['text'] = text_metrics
417
-
418
- # Evaluate image embeddings
419
- image_embeddings, image_labels, _ = self.extract_embeddings(dataloader, 'image')
420
- image_metrics = self.compute_similarity_metrics(image_embeddings, image_labels)
421
- image_classification = self.evaluate_classification_performance(image_embeddings, image_labels, "Image Embeddings")
422
- image_metrics.update(image_classification)
423
- results['image'] = image_metrics
424
-
425
- # Evaluate hierarchy embeddings
426
- hierarchy_embeddings, hierarchy_labels, _ = self.extract_embeddings(dataloader, 'category2')
427
- hierarchy_metrics = self.compute_similarity_metrics(hierarchy_embeddings, hierarchy_labels)
428
- hierarchy_classification = self.evaluate_classification_performance(hierarchy_embeddings, hierarchy_labels, "hierarchy Embeddings")
429
- hierarchy_metrics.update(hierarchy_classification)
430
- results['hierarchy'] = hierarchy_metrics
431
-
432
- # Print results
433
- print(f"\n{dataset_name} Results:")
434
- print("-" * 40)
435
- for emb_type, metrics in results.items():
436
- print(f"{emb_type.capitalize()} Embeddings:")
437
- print(f" Intra-class similarity (same hierarchy): {metrics['intra_class_mean']:.4f}")
438
- print(f" Inter-class similarity (diff hierarchy): {metrics['inter_class_mean']:.4f}")
439
- print(f" Separation score: {metrics['separation_score']:.4f}")
440
- print(f" Nearest Neighbor Accuracy: {metrics['accuracy']:.4f} ({metrics['accuracy']*100:.1f}%)")
441
- print(f" Centroid Accuracy: {metrics['centroid_accuracy']:.4f} ({metrics['centroid_accuracy']*100:.1f}%)")
442
-
443
- # Classification report summary
444
- report = metrics['classification_report']
445
- print(f" 📊 Classification Performance:")
446
- print(f" • Macro Avg F1-Score: {report['macro avg']['f1-score']:.4f}")
447
- print(f" • Weighted Avg F1-Score: {report['weighted avg']['f1-score']:.4f}")
448
- print(f" • Support: {report['macro avg']['support']:.0f} samples")
449
- print()
450
-
451
- # Create visualizations
452
- os.makedirs(f'{self.directory}', exist_ok=True)
453
-
454
- # Confusion matrices
455
- results['text']['figure'].savefig(f'{self.directory}/{dataset_name.lower()}_text_confusion_matrix.png', dpi=300, bbox_inches='tight')
456
- plt.close(results['text']['figure'])
457
-
458
- results['image']['figure'].savefig(f'{self.directory}/{dataset_name.lower()}_image_confusion_matrix.png', dpi=300, bbox_inches='tight')
459
- plt.close(results['image']['figure'])
460
-
461
- results['hierarchy']['figure'].savefig(f'{self.directory}/{dataset_name.lower()}_hierarchy_confusion_matrix.png', dpi=300, bbox_inches='tight')
462
- plt.close(results['hierarchy']['figure'])
463
-
464
- return results
465
-
466
- class KaglDataset(Dataset):
467
- def __init__(self, dataframe):
468
- self.dataframe = dataframe
469
- # Use VALIDATION transforms (no augmentation)
470
- self.transform = transforms.Compose([
471
- transforms.Resize((224, 224)),
472
- transforms.ToTensor(),
473
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
474
- ])
475
-
476
- def __len__(self):
477
- return len(self.dataframe)
478
-
479
- def __getitem__(self, idx):
480
- row = self.dataframe.iloc[idx]
481
-
482
- # Handle image
483
- image_data = row['image_url']
484
- image = Image.open(BytesIO(image_data['bytes'])).convert("RGB")
485
- image = self.transform(image)
486
-
487
- # Get text and hierarchy
488
- description = row['text']
489
- hierarchy = row['hierarchy']
490
-
491
- return image, description, hierarchy
492
-
493
- def load_Kagl_marqo_dataset(evaluator):
494
- """Load and prepare Kagl KAGL dataset"""
495
- from datasets import load_dataset
496
- print("Loading Kagl KAGL dataset...")
497
-
498
- # Load the dataset
499
- dataset = load_dataset("Marqo/KAGL")
500
- df = dataset["data"].to_pandas()
501
- print(f"✅ Dataset Kagl loaded")
502
- print(f"📊 Before filtering: {len(df)} samples")
503
- print(f"📋 Available columns: {list(df.columns)}")
504
-
505
- # Check available categories and map them to our hierarchy
506
- print(f"🎨 Available categories: {sorted(df['category2'].unique())}")
507
- # Apply mapping
508
- df['hierarchy'] = df['category2'].str.lower()
509
- df['hierarchy'] = df['hierarchy'].replace('bags', 'bag').replace('topwear', 'top').replace('flip flops', 'shoes').replace('sandal', 'shoes')
510
-
511
- # Filter to only include valid hierarchies that exist in our model
512
- valid_hierarchies = df['hierarchy'].dropna().unique()
513
- print(f"🎯 Valid hierarchies found: {sorted(valid_hierarchies)}")
514
- print(f"🎯 Model hierarchies: {sorted(evaluator.hierarchy_classes)}")
515
-
516
- # Filter to only include hierarchies that exist in our model
517
- df = df[df['hierarchy'].isin(evaluator.hierarchy_classes)]
518
- print(f"📊 After filtering to model hierarchies: {len(df)} samples")
519
-
520
- if len(df) == 0:
521
- print("❌ No samples left after hierarchy filtering.")
522
- return pd.DataFrame()
523
-
524
- # Ensure we have text and image data
525
- df = df.dropna(subset=['text', 'image'])
526
- print(f"📊 After removing missing text/image: {len(df)} samples")
527
-
528
- # Show sample of text data to verify quality
529
- print(f"📝 Sample texts:")
530
- for i, (text, hierarchy) in enumerate(zip(df['text'].head(3), df['hierarchy'].head(3))):
531
- print(f" {i+1}. [{hierarchy}] {text[:100]}...")
532
-
533
- print(f"📊 After sampling: {len(df)} samples")
534
- print(f"📊 Samples per hierarchy:")
535
- for hierarchy in sorted(df['hierarchy'].unique()):
536
- count = len(df[df['hierarchy'] == hierarchy])
537
- print(f" {hierarchy}: {count} samples")
538
-
539
- # Create formatted dataset with proper column names
540
- Kagl_formatted = pd.DataFrame({
541
- 'image_url': df['image'],
542
- 'text': df['text'],
543
- 'hierarchy': df['hierarchy']
544
- })
545
-
546
- print(f"📊 Final dataset size: {len(Kagl_formatted)} samples")
547
- return Kagl_formatted
548
-
549
- if __name__ == "__main__":
550
- device = config.device
551
- model_path = config.hierarchy_model_path
552
- directory = config.evaluation_directory
553
-
554
- print(f"🚀 Starting evaluation with {model_path}")
555
-
556
- evaluator = EmbeddingEvaluator(model_path, directory)
557
-
558
- print(f"📊 Final hierarchy classes after initialization: {len(evaluator.vocab.hierarchy_classes)} classes")
559
-
560
- # Evaluate on validation dataset (same subset as during training)
561
- print("\n" + "="*60)
562
- print("EVALUATING VALIDATION DATASET")
563
- print("="*60)
564
- val_results = evaluator.evaluate_dataset(evaluator.val_df, "Validation Dataset")
565
-
566
- print("\n" + "="*60)
567
- print("EVALUATING Kagl MARQO DATASET")
568
- print("="*60)
569
- df_Kagl_marqo = load_Kagl_marqo_dataset(evaluator)
570
- Kagl_results = evaluator.evaluate_dataset(df_Kagl_marqo, "Kagl Marqo Dataset")
571
-
572
- # Compare results
573
- print(f"\n{'='*60}")
574
- print("FINAL EVALUATION SUMMARY")
575
- print(f"{'='*60}")
576
-
577
- print("\n🔍 VALIDATION DATASET RESULTS:")
578
- print(f"Text - Separation: {val_results['text']['separation_score']:.4f} | NN Acc: {val_results['text']['accuracy']*100:.1f}% | Centroid Acc: {val_results['text']['centroid_accuracy']*100:.1f}%")
579
- print(f"Image - Separation: {val_results['image']['separation_score']:.4f} | NN Acc: {val_results['image']['accuracy']*100:.1f}% | Centroid Acc: {val_results['image']['centroid_accuracy']*100:.1f}%")
580
- print(f"hierarchy - Separation: {val_results['hierarchy']['separation_score']:.4f} | NN Acc: {val_results['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {val_results['hierarchy']['centroid_accuracy']*100:.1f}%")
581
-
582
- print("\n🌐 Kagl MARQO DATASET RESULTS:")
583
- print(f"Text - Separation: {Kagl_results['text']['separation_score']:.4f} | NN Acc: {Kagl_results['text']['accuracy']*100:.1f}% | Centroid Acc: {Kagl_results['text']['centroid_accuracy']*100:.1f}%")
584
- print(f"Image - Separation: {Kagl_results['image']['separation_score']:.4f} | NN Acc: {Kagl_results['image']['accuracy']*100:.1f}% | Centroid Acc: {Kagl_results['image']['centroid_accuracy']*100:.1f}%")
585
- print(f"Hierarchy - Separation: {Kagl_results['hierarchy']['separation_score']:.4f} | NN Acc: {Kagl_results['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {Kagl_results['hierarchy']['centroid_accuracy']*100:.1f}%")
586
-
587
-
588
- print(f"\n✅ Evaluation completed! Check 'improved_model_evaluation/' for visualization files.")
589
- print(f"📊 Final hierarchy classes used: {len(evaluator.vocab.hierarchy_classes)} classes")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Evaluation/hierarchy_evaluation_with_clip_baseline.py DELETED
@@ -1,808 +0,0 @@
1
- """
2
- Hierarchy embedding evaluation with CLIP baseline comparison.
3
- This file evaluates the quality of hierarchy embeddings from the custom model and compares them
4
- with a CLIP baseline model (OpenAI CLIP). It calculates similarity metrics, classification accuracies,
5
- and generates confusion matrices for both models to measure relative performance. It also supports
6
- evaluation on Fashion-MNIST and kagl Marqo datasets.
7
- """
8
-
9
- import torch
10
- import pandas as pd
11
- import numpy as np
12
- import matplotlib.pyplot as plt
13
- import seaborn as sns
14
- from sklearn.metrics.pairwise import cosine_similarity
15
- from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, f1_score
16
- from collections import defaultdict
17
- import os
18
- import requests
19
- from tqdm import tqdm
20
- from torch.utils.data import Dataset, DataLoader
21
- from torchvision import transforms
22
- from sklearn.model_selection import train_test_split
23
- from io import BytesIO
24
- from PIL import Image
25
- import warnings
26
- warnings.filterwarnings('ignore')
27
-
28
- # Import transformers CLIP
29
- from transformers import CLIPProcessor, CLIPModel as TransformersCLIPModel
30
-
31
- # Import your custom model
32
- from hierarchy_model import Model, HierarchyExtractor, HierarchyDataset, collate_fn
33
- import config
34
-
35
- def convert_fashion_mnist_to_image(pixel_values):
36
- """Convert Fashion-MNIST pixel values to PIL image"""
37
- # Reshape to 28x28 and convert to PIL Image
38
- image_array = np.array(pixel_values).reshape(28, 28).astype(np.uint8)
39
- # Convert to RGB by duplicating the grayscale channel
40
- image_array = np.stack([image_array] * 3, axis=-1)
41
- image = Image.fromarray(image_array)
42
- return image
43
-
44
- def get_fashion_mnist_labels():
45
- """Get Fashion-MNIST class labels"""
46
- return {
47
- 0: "T-shirt/top",
48
- 1: "Trouser",
49
- 2: "Pullover",
50
- 3: "Dress",
51
- 4: "Coat",
52
- 5: "Sandal",
53
- 6: "Shirt",
54
- 7: "Sneaker",
55
- 8: "Bag",
56
- 9: "Ankle boot"
57
- }
58
-
59
- class FashionMNISTDataset(Dataset):
60
- def __init__(self, dataframe, image_size=224):
61
- self.dataframe = dataframe
62
- self.image_size = image_size
63
- self.labels_map = get_fashion_mnist_labels()
64
-
65
- # Simple transforms for validation/inference
66
- self.transform = transforms.Compose([
67
- transforms.Resize((image_size, image_size)),
68
- transforms.ToTensor(),
69
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
70
- ])
71
-
72
- def __len__(self):
73
- return len(self.dataframe)
74
-
75
- def __getitem__(self, idx):
76
- row = self.dataframe.iloc[idx]
77
-
78
- # Get pixel values (columns 1-784)
79
- pixel_cols = [f'pixel{i}' for i in range(1, 785)]
80
- pixel_values = row[pixel_cols].values
81
-
82
- # Convert to image
83
- image = convert_fashion_mnist_to_image(pixel_values)
84
- image = self.transform(image)
85
-
86
- # Get text description
87
- text = row['text']
88
-
89
- # Get hierarchy label
90
- hierarchy = row['hierarchy']
91
-
92
- return image, text, hierarchy
93
-
94
- class CLIPBaselineEvaluator:
95
- def __init__(self, device='mps'):
96
- self.device = torch.device(device)
97
-
98
- # Load CLIP model and processor
99
- print("🤗 Loading CLIP baseline model from transformers...")
100
- self.clip_model = TransformersCLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
101
- self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
102
- self.clip_model.eval()
103
- print("✅ CLIP model loaded successfully")
104
-
105
- def extract_clip_embeddings(self, images, texts):
106
- """Extract CLIP embeddings for images and texts"""
107
- all_image_embeddings = []
108
- all_text_embeddings = []
109
-
110
- with torch.no_grad():
111
- for i in tqdm(range(len(images)), desc="Extracting CLIP embeddings"):
112
- # Process image
113
- if isinstance(images[i], torch.Tensor):
114
- # Convert tensor back to PIL Image
115
- image_tensor = images[i]
116
- if image_tensor.dim() == 3:
117
- # Denormalize
118
- mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
119
- std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
120
- image_tensor = image_tensor * std + mean
121
- image_tensor = torch.clamp(image_tensor, 0, 1)
122
-
123
- # Convert to PIL
124
- image_pil = transforms.ToPILImage()(image_tensor)
125
- elif isinstance(images[i], Image.Image):
126
- image_pil = images[i]
127
- else:
128
- raise ValueError(f"Unsupported image type: {type(images[i])}")
129
-
130
- # Process with CLIP
131
- inputs = self.clip_processor(
132
- text=texts[i],
133
- images=image_pil,
134
- return_tensors="pt",
135
- padding=True
136
- ).to(self.device)
137
-
138
- outputs = self.clip_model(**inputs)
139
-
140
- # Get normalized embeddings
141
- image_emb = outputs.image_embeds / outputs.image_embeds.norm(p=2, dim=-1, keepdim=True)
142
- text_emb = outputs.text_embeds / outputs.text_embeds.norm(p=2, dim=-1, keepdim=True)
143
-
144
- all_image_embeddings.append(image_emb.cpu().numpy())
145
- all_text_embeddings.append(text_emb.cpu().numpy())
146
-
147
- return np.vstack(all_image_embeddings), np.vstack(all_text_embeddings)
148
-
149
-
150
- class EmbeddingEvaluator:
151
- def __init__(self, model_path, directory):
152
- self.device = config.device
153
- self.directory = directory
154
-
155
- # 1. Load the dataset
156
- CSV = config.local_dataset_path
157
- print(f"📁 Using dataset with local images: {CSV}")
158
- df = pd.read_csv(CSV)
159
-
160
- print(f"📁 Loaded {len(df)} samples")
161
-
162
- # 2. Get unique hierarchy classes from the dataset
163
- hierarchy_classes = sorted(df[config.hierarchy_column].unique().tolist())
164
- print(f"📋 Found {len(hierarchy_classes)} hierarchy classes")
165
-
166
- _, self.val_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df['hierarchy'])
167
-
168
- # 3. Load the model
169
- if os.path.exists(model_path):
170
- checkpoint = torch.load(model_path, map_location=self.device)
171
-
172
- config = checkpoint.get('config', {})
173
- saved_hierarchy_classes = checkpoint['hierarchy_classes']
174
-
175
- # Use the saved hierarchy classes
176
- self.hierarchy_classes = saved_hierarchy_classes
177
-
178
- # Create the hierarchy extractor
179
- self.vocab = HierarchyExtractor(saved_hierarchy_classes)
180
-
181
- # Create the model with the saved configuration
182
- self.model = Model(
183
- num_hierarchy_classes=len(saved_hierarchy_classes),
184
- embed_dim=config['embed_dim'],
185
- dropout=config['dropout']
186
- ).to(self.device)
187
-
188
- self.model.load_state_dict(checkpoint['model_state'])
189
-
190
- print(f"✅ Custom model loaded with:")
191
- print(f"📋 Hierarchy classes: {len(saved_hierarchy_classes)}")
192
- print(f"🎯 Embed dim: {config['embed_dim']}")
193
- print(f"💧 Dropout: {config['dropout']}")
194
- print(f"📅 Epoch: {checkpoint.get('epoch', 'unknown')}")
195
-
196
- else:
197
- raise FileNotFoundError(f"Model file {model_path} not found")
198
-
199
- self.model.eval()
200
-
201
- # Initialize CLIP baseline
202
- self.clip_evaluator = CLIPBaselineEvaluator(device)
203
-
204
- def create_dataloader(self, dataframe, batch_size=16):
205
- """Create a dataloader for custom model"""
206
- # Check if this is Fashion-MNIST data (has pixel1 column)
207
- if 'pixel1' in dataframe.columns:
208
- print("🔍 Detected Fashion-MNIST data, using FashionMNISTDataset")
209
- dataset = FashionMNISTDataset(dataframe, image_size=224)
210
- else:
211
- dataset = HierarchyDataset(dataframe, image_size=224)
212
-
213
- dataloader = DataLoader(
214
- dataset,
215
- batch_size=batch_size,
216
- shuffle=False,
217
- collate_fn=lambda batch: collate_fn(batch, self.vocab),
218
- num_workers=0
219
- )
220
-
221
- return dataloader
222
-
223
- def create_clip_dataloader(self, dataframe, batch_size=16):
224
- """Create a dataloader for CLIP baseline"""
225
- # Check if this is Fashion-MNIST data (has pixel1 column)
226
- if 'pixel1' in dataframe.columns:
227
- print("🔍 Detected Fashion-MNIST data for CLIP, using FashionMNISTDataset")
228
- dataset = FashionMNISTDataset(dataframe, image_size=224)
229
- else:
230
- dataset = CLIPDataset(dataframe)
231
-
232
- dataloader = DataLoader(
233
- dataset,
234
- batch_size=batch_size,
235
- shuffle=False,
236
- num_workers=0
237
- )
238
-
239
- return dataloader
240
-
241
- def extract_custom_embeddings(self, dataloader, embedding_type='text'):
242
- """Extract embeddings from custom model"""
243
- all_embeddings = []
244
- all_labels = []
245
- all_texts = []
246
-
247
- with torch.no_grad():
248
- for batch in tqdm(dataloader, desc=f"Extracting custom {embedding_type} embeddings"):
249
- images = batch['image'].to(self.device)
250
- hierarchy_indices = batch['hierarchy_indices'].to(self.device)
251
- hierarchy_labels = batch['hierarchy']
252
-
253
- # Forward pass
254
- out = self.model(image=images, hierarchy_indices=hierarchy_indices)
255
- embeddings = out['z_txt'] if embedding_type == 'text' else out['z_img'] if embedding_type == 'image' else out['z_txt']
256
-
257
- all_embeddings.append(embeddings.cpu().numpy())
258
- all_labels.extend(hierarchy_labels)
259
- all_texts.extend(hierarchy_labels)
260
-
261
- return np.vstack(all_embeddings), all_labels, all_texts
262
-
263
- def compute_similarity_metrics(self, embeddings, labels):
264
- """Compute intra-class and inter-class similarities"""
265
- similarities = cosine_similarity(embeddings)
266
-
267
- # Group embeddings by hierarchy
268
- hierarchy_groups = defaultdict(list)
269
- for i, hierarchy in enumerate(labels):
270
- hierarchy_groups[hierarchy].append(i)
271
-
272
- # Calculate intra-class similarities (same hierarchy)
273
- intra_class_similarities = []
274
- for hierarchy, indices in hierarchy_groups.items():
275
- if len(indices) > 1:
276
- for i in range(len(indices)):
277
- for j in range(i+1, len(indices)):
278
- sim = similarities[indices[i], indices[j]]
279
- intra_class_similarities.append(sim)
280
-
281
- # Calculate inter-class similarities (different hierarchies)
282
- inter_class_similarities = []
283
- hierarchies = list(hierarchy_groups.keys())
284
- for i in range(len(hierarchies)):
285
- for j in range(i+1, len(hierarchies)):
286
- hierarchy1_indices = hierarchy_groups[hierarchies[i]]
287
- hierarchy2_indices = hierarchy_groups[hierarchies[j]]
288
-
289
- for idx1 in hierarchy1_indices:
290
- for idx2 in hierarchy2_indices:
291
- sim = similarities[idx1, idx2]
292
- inter_class_similarities.append(sim)
293
-
294
- # Calculate classification accuracy using nearest neighbor in embedding space
295
- nn_accuracy = self.compute_embedding_accuracy(embeddings, labels, similarities)
296
-
297
- # Calculate classification accuracy using centroids
298
- centroid_accuracy = self.compute_centroid_accuracy(embeddings, labels)
299
-
300
- return {
301
- 'intra_class_similarities': intra_class_similarities,
302
- 'inter_class_similarities': inter_class_similarities,
303
- 'intra_class_mean': np.mean(intra_class_similarities) if intra_class_similarities else 0,
304
- 'inter_class_mean': np.mean(inter_class_similarities) if inter_class_similarities else 0,
305
- 'separation_score': np.mean(intra_class_similarities) - np.mean(inter_class_similarities) if intra_class_similarities and inter_class_similarities else 0,
306
- 'accuracy': nn_accuracy,
307
- 'centroid_accuracy': centroid_accuracy
308
- }
309
-
310
- def compute_embedding_accuracy(self, embeddings, labels, similarities):
311
- """Compute classification accuracy using nearest neighbor in embedding space"""
312
- correct_predictions = 0
313
- total_predictions = len(labels)
314
-
315
- for i in range(len(embeddings)):
316
- true_label = labels[i]
317
-
318
- # Find the most similar embedding (excluding itself)
319
- similarities_row = similarities[i].copy()
320
- similarities_row[i] = -1 # Exclude self-similarity
321
- nearest_neighbor_idx = np.argmax(similarities_row)
322
- predicted_label = labels[nearest_neighbor_idx]
323
-
324
- if predicted_label == true_label:
325
- correct_predictions += 1
326
-
327
- return correct_predictions / total_predictions if total_predictions > 0 else 0
328
-
329
- def compute_centroid_accuracy(self, embeddings, labels):
330
- """Compute classification accuracy using hierarchy centroids"""
331
- # Create centroids for each hierarchy
332
- unique_hierarchies = list(set(labels))
333
- centroids = {}
334
-
335
- for hierarchy in unique_hierarchies:
336
- hierarchy_indices = [i for i, label in enumerate(labels) if label == hierarchy]
337
- hierarchy_embeddings = embeddings[hierarchy_indices]
338
- centroids[hierarchy] = np.mean(hierarchy_embeddings, axis=0)
339
-
340
- # Classify each embedding to nearest centroid
341
- correct_predictions = 0
342
- total_predictions = len(labels)
343
-
344
- for i, embedding in enumerate(embeddings):
345
- true_label = labels[i]
346
-
347
- # Find closest centroid
348
- best_similarity = -1
349
- predicted_label = None
350
-
351
- for hierarchy, centroid in centroids.items():
352
- similarity = cosine_similarity([embedding], [centroid])[0][0]
353
- if similarity > best_similarity:
354
- best_similarity = similarity
355
- predicted_label = hierarchy
356
-
357
- if predicted_label == true_label:
358
- correct_predictions += 1
359
-
360
- return correct_predictions / total_predictions if total_predictions > 0 else 0
361
-
362
- def create_confusion_matrix(self, true_labels, predicted_labels, title="Confusion Matrix"):
363
- """Create and plot confusion matrix"""
364
- # Get unique labels
365
- unique_labels = sorted(list(set(true_labels + predicted_labels)))
366
-
367
- # Create confusion matrix
368
- cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels)
369
-
370
- # Calculate accuracy
371
- accuracy = accuracy_score(true_labels, predicted_labels)
372
-
373
- # Plot confusion matrix
374
- plt.figure(figsize=(12, 10))
375
- sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
376
- xticklabels=unique_labels, yticklabels=unique_labels)
377
- plt.title(f'{title}\nAccuracy: {accuracy:.3f} ({accuracy*100:.1f}%)')
378
- plt.ylabel('True Hierarchy')
379
- plt.xlabel('Predicted Hierarchy')
380
- plt.xticks(rotation=45)
381
- plt.yticks(rotation=0)
382
- plt.tight_layout()
383
-
384
- return plt.gcf(), accuracy, cm
385
-
386
- def predict_hierarchy_from_embeddings(self, embeddings, labels):
387
- """Predict hierarchy from embeddings using centroid-based classification"""
388
- # Create hierarchy centroids from training data
389
- unique_hierarchies = list(set(labels))
390
- centroids = {}
391
-
392
- for hierarchy in unique_hierarchies:
393
- hierarchy_indices = [i for i, label in enumerate(labels) if label == hierarchy]
394
- hierarchy_embeddings = embeddings[hierarchy_indices]
395
- centroids[hierarchy] = np.mean(hierarchy_embeddings, axis=0)
396
-
397
- # Predict hierarchy for all embeddings
398
- predictions = []
399
-
400
- for i, embedding in enumerate(embeddings):
401
- # Find closest centroid
402
- best_similarity = -1
403
- predicted_hierarchy = None
404
-
405
- for hierarchy, centroid in centroids.items():
406
- similarity = cosine_similarity([embedding], [centroid])[0][0]
407
- if similarity > best_similarity:
408
- best_similarity = similarity
409
- predicted_hierarchy = hierarchy
410
-
411
- predictions.append(predicted_hierarchy)
412
-
413
- return predictions
414
-
415
- def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings"):
416
- """Evaluate classification performance and create confusion matrix"""
417
- # Predict hierarchy
418
- predictions = self.predict_hierarchy_from_embeddings(embeddings, labels)
419
-
420
- # Calculate accuracy
421
- accuracy = accuracy_score(labels, predictions)
422
-
423
- # Calculate F1 scores
424
- unique_labels = sorted(list(set(labels)))
425
- f1_macro = f1_score(labels, predictions, labels=unique_labels, average='macro', zero_division=0)
426
- f1_weighted = f1_score(labels, predictions, labels=unique_labels, average='weighted', zero_division=0)
427
- f1_per_class = f1_score(labels, predictions, labels=unique_labels, average=None, zero_division=0)
428
-
429
- # Create confusion matrix
430
- fig, acc, cm = self.create_confusion_matrix(labels, predictions,
431
- f"{embedding_type} - Hierarchy Classification")
432
-
433
- # Generate classification report
434
- report = classification_report(labels, predictions, labels=unique_labels,
435
- target_names=unique_labels, output_dict=True)
436
-
437
- return {
438
- 'accuracy': accuracy,
439
- 'f1_macro': f1_macro,
440
- 'f1_weighted': f1_weighted,
441
- 'f1_per_class': f1_per_class,
442
- 'predictions': predictions,
443
- 'confusion_matrix': cm,
444
- 'classification_report': report,
445
- 'figure': fig
446
- }
447
-
448
- def evaluate_dataset_with_baselines(self, dataframe, dataset_name="Dataset"):
449
- """Evaluate embeddings on a given dataset with both custom model and CLIP baseline"""
450
- print(f"\n{'='*60}")
451
- print(f"Evaluating {dataset_name}")
452
- print(f"{'='*60}")
453
-
454
- results = {}
455
-
456
- # ===== CUSTOM MODEL EVALUATION =====
457
- print(f"\n🔧 Evaluating Custom Model on {dataset_name}")
458
- print("-" * 40)
459
-
460
- # Create dataloader for custom model
461
- custom_dataloader = self.create_dataloader(dataframe, batch_size=16)
462
-
463
- # Evaluate text embeddings
464
- text_embeddings, text_labels, texts = self.extract_custom_embeddings(custom_dataloader, 'text')
465
- text_metrics = self.compute_similarity_metrics(text_embeddings, text_labels)
466
- text_classification = self.evaluate_classification_performance(text_embeddings, text_labels, "Custom Text Embeddings")
467
- text_metrics.update(text_classification)
468
- results['custom_text'] = text_metrics
469
-
470
- # Evaluate image embeddings
471
- image_embeddings, image_labels, _ = self.extract_custom_embeddings(custom_dataloader, 'image')
472
- image_metrics = self.compute_similarity_metrics(image_embeddings, image_labels)
473
- image_classification = self.evaluate_classification_performance(image_embeddings, image_labels, "Custom Image Embeddings")
474
- image_metrics.update(image_classification)
475
- results['custom_image'] = image_metrics
476
-
477
- # ===== CLIP BASELINE EVALUATION =====
478
- print(f"\n🤗 Evaluating CLIP Baseline on {dataset_name}")
479
- print("-" * 40)
480
-
481
- # Create dataloader for CLIP
482
- clip_dataloader = self.create_clip_dataloader(dataframe, batch_size=8) # Smaller batch for CLIP
483
-
484
- # Extract data for CLIP
485
- all_images = []
486
- all_texts = []
487
- all_labels = []
488
-
489
- for batch in tqdm(clip_dataloader, desc="Preparing data for CLIP"):
490
- images, texts, labels = batch
491
- all_images.extend(images)
492
- all_texts.extend(texts)
493
- all_labels.extend(labels)
494
-
495
- # Get CLIP embeddings
496
- clip_image_embeddings, clip_text_embeddings = self.clip_evaluator.extract_clip_embeddings(all_images, all_texts)
497
-
498
- # Evaluate CLIP text embeddings
499
- clip_text_metrics = self.compute_similarity_metrics(clip_text_embeddings, all_labels)
500
- clip_text_classification = self.evaluate_classification_performance(clip_text_embeddings, all_labels, "CLIP Text Embeddings")
501
- clip_text_metrics.update(clip_text_classification)
502
- results['clip_text'] = clip_text_metrics
503
-
504
- # Evaluate CLIP image embeddings
505
- clip_image_metrics = self.compute_similarity_metrics(clip_image_embeddings, all_labels)
506
- clip_image_classification = self.evaluate_classification_performance(clip_image_embeddings, all_labels, "CLIP Image Embeddings")
507
- clip_image_metrics.update(clip_image_classification)
508
- results['clip_image'] = clip_image_metrics
509
-
510
- # ===== PRINT COMPARISON RESULTS =====
511
- print(f"\n{dataset_name} Results Comparison:")
512
- print(f"Dataset size: {len(dataframe)} samples")
513
- print("=" * 80)
514
- print(f"{'Model':<20} {'Embedding':<10} {'Sep Score':<10} {'NN Acc':<8} {'Centroid Acc':<12} {'F1 Macro':<10}")
515
- print("-" * 80)
516
-
517
- for model_type in ['custom', 'clip']:
518
- for emb_type in ['text', 'image']:
519
- key = f"{model_type}_{emb_type}"
520
- if key in results:
521
- metrics = results[key]
522
- model_name = "Custom Model" if model_type == 'custom' else "CLIP Baseline"
523
- print(f"{model_name:<20} {emb_type.capitalize():<10} {metrics['separation_score']:<10.4f} {metrics['accuracy']*100:<8.1f}% {metrics['centroid_accuracy']*100:<12.1f}% {metrics['f1_macro']*100:<10.1f}%")
524
-
525
- # ===== SAVE VISUALIZATIONS =====
526
- os.makedirs(f'{self.directory}', exist_ok=True)
527
-
528
- # Save confusion matrices
529
- for key, metrics in results.items():
530
- if 'figure' in metrics:
531
- metrics['figure'].savefig(f'{self.directory}/{dataset_name.lower()}_{key}_confusion_matrix.png', dpi=300, bbox_inches='tight')
532
- plt.close(metrics['figure'])
533
-
534
- return results
535
-
536
-
537
- class CLIPDataset(Dataset):
538
- def __init__(self, dataframe):
539
- self.dataframe = dataframe
540
- # Use VALIDATION transforms (no augmentation)
541
- self.transform = transforms.Compose([
542
- transforms.Resize((224, 224)),
543
- transforms.ToTensor(),
544
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
545
- ])
546
-
547
- def __len__(self):
548
- return len(self.dataframe)
549
-
550
- def __getitem__(self, idx):
551
- row = self.dataframe.iloc[idx]
552
-
553
- # Handle image loading (same as HierarchyDataset)
554
- if config.column_local_image_path in row.index and pd.notna(row[config.column_local_image_path]):
555
- local_path = row[config.column_local_image_path]
556
- try:
557
- if os.path.exists(local_path):
558
- image = Image.open(local_path).convert("RGB")
559
- else:
560
- print(f"⚠️ Local image not found: {local_path}")
561
- image = Image.new('RGB', (224, 224), color='gray')
562
- except Exception as e:
563
- print(f"⚠️ Failed to load local image {idx}: {e}")
564
- image = Image.new('RGB', (224, 224), color='gray')
565
- elif isinstance(row[config.column_url_image], dict):
566
- image = Image.open(BytesIO(row[config.column_url_image]['bytes'])).convert('RGB')
567
- elif isinstance(row['image_url'], (list, np.ndarray)):
568
- pixels = np.array(row[config.column_url_image]).reshape(28, 28)
569
- image = Image.fromarray(pixels.astype(np.uint8)).convert("RGB")
570
- elif isinstance(row[config.column_url_image], Image.Image):
571
- # Handle PIL Image objects directly (for Fashion-MNIST)
572
- image = row[config.column_url_image].convert("RGB")
573
- else:
574
- try:
575
- response = requests.get(row[config.column_url_image], timeout=10)
576
- response.raise_for_status()
577
- image = Image.open(BytesIO(response.content)).convert("RGB")
578
- except Exception as e:
579
- print(f"⚠️ Failed to load image {idx}: {e}")
580
- image = Image.new('RGB', (224, 224), color='gray')
581
-
582
- # Apply transforms
583
- image_tensor = self.transform(image)
584
-
585
- description = row[config.text_column]
586
- hierarchy = row[config.hierarchy_column]
587
-
588
- return image_tensor, description, hierarchy
589
-
590
-
591
- def load_fashion_mnist_dataset(evaluator):
592
- """Load and prepare Fashion-MNIST test dataset"""
593
- print("Loading Fashion-MNIST test dataset...")
594
-
595
- # Load the dataset
596
- df = pd.read_csv(config.fashion_mnist_test_path)
597
- print(f"✅ Fashion-MNIST dataset loaded")
598
- print(f"📊 Total samples: {len(df)}")
599
-
600
- # Fashion-MNIST class labels mapping
601
- fashion_mnist_labels = get_fashion_mnist_labels()
602
-
603
- # Map labels to hierarchy classes
604
- hierarchy_mapping = {
605
- 'T-shirt/top': 'top',
606
- 'Trouser': 'bottom',
607
- 'Pullover': 'top',
608
- 'Dress': 'dress',
609
- 'Coat': 'top',
610
- 'Sandal': 'shoes',
611
- 'Shirt': 'top',
612
- 'Sneaker': 'shoes',
613
- 'Bag': 'bag',
614
- 'Ankle boot': 'shoes'
615
- }
616
-
617
- # Apply label mapping
618
- df['hierarchy'] = df['label'].map(fashion_mnist_labels).map(hierarchy_mapping)
619
-
620
- # Filter to only include hierarchies that exist in our model
621
- valid_hierarchies = df['hierarchy'].dropna().unique()
622
- print(f"🎯 Valid hierarchies found: {sorted(valid_hierarchies)}")
623
- print(f"🎯 Model hierarchies: {sorted(evaluator.hierarchy_classes)}")
624
-
625
- # Filter to only include hierarchies that exist in our model
626
- df = df[df['hierarchy'].isin(evaluator.hierarchy_classes)]
627
- print(f"📊 After filtering to model hierarchies: {len(df)} samples")
628
-
629
- if len(df) == 0:
630
- print("❌ No samples left after hierarchy filtering.")
631
- return pd.DataFrame()
632
-
633
- # Keep pixel columns as they are (FashionMNISTDataset will handle them)
634
-
635
- # Create text descriptions based on hierarchy
636
- text_descriptions = {
637
- 'top': 'A top clothing item',
638
- 'bottom': 'A bottom clothing item',
639
- 'dress': 'A dress',
640
- 'shoes': 'A pair of shoes',
641
- 'bag': 'A bag'
642
- }
643
-
644
- df['text'] = df['hierarchy'].map(text_descriptions)
645
-
646
- # Show sample of data
647
- print(f"📝 Sample data:")
648
- for i, (hierarchy, text) in enumerate(zip(df['hierarchy'].head(3), df['text'].head(3))):
649
- print(f" {i+1}. [{hierarchy}] {text}")
650
-
651
- df_test = df.copy()
652
-
653
- print(f"📊 After sampling: {len(df_test)} samples")
654
- print(f"📊 Samples per hierarchy:")
655
- for hierarchy in sorted(df_test['hierarchy'].unique()):
656
- count = len(df_test[df_test['hierarchy'] == hierarchy])
657
- print(f" {hierarchy}: {count} samples")
658
-
659
- # Create formatted dataset with proper column names
660
- # Keep all pixel columns for FashionMNISTDataset
661
- pixel_cols = [f'pixel{i}' for i in range(1, 785)]
662
- fashion_mnist_formatted = df_test[['label'] + pixel_cols + ['text', 'hierarchy']].copy()
663
-
664
- print(f"📊 Final dataset size: {len(fashion_mnist_formatted)} samples")
665
- return fashion_mnist_formatted
666
-
667
-
668
- def load_kagl_marqo_dataset(evaluator):
669
- """Load and prepare kagl dataset"""
670
- from datasets import load_dataset
671
- print("Loading kagl dataset...")
672
-
673
- # Load the dataset
674
- dataset = load_dataset("Marqo/KAGL")
675
- df = dataset["data"].to_pandas()
676
- print(f"✅ Dataset kagl loaded")
677
- print(f"📊 Before filtering: {len(df)} samples")
678
- print(f"📋 Available columns: {list(df.columns)}")
679
-
680
- # Check available categories and map them to our hierarchy
681
- print(f"🎨 Available categories: {sorted(df['category2'].unique())}")
682
- # Apply mapping
683
- df['hierarchy'] = df['category2'].str.lower()
684
- df['hierarchy'] = df['hierarchy'].replace('bags', 'bag').replace('topwear', 'top').replace('flip flops', 'shoes').replace('sandal', 'shoes')
685
-
686
- # Filter to only include valid hierarchies that exist in our model
687
- valid_hierarchies = df['hierarchy'].dropna().unique()
688
- print(f"🎯 Valid hierarchies found: {sorted(valid_hierarchies)}")
689
- print(f"🎯 Model hierarchies: {sorted(evaluator.hierarchy_classes)}")
690
-
691
- # Filter to only include hierarchies that exist in our model
692
- df = df[df['hierarchy'].isin(evaluator.hierarchy_classes)]
693
- print(f"📊 After filtering to model hierarchies: {len(df)} samples")
694
-
695
- if len(df) == 0:
696
- print("❌ No samples left after hierarchy filtering.")
697
- return pd.DataFrame()
698
-
699
- # Ensure we have text and image data
700
- df = df.dropna(subset=['text', 'image'])
701
- print(f"📊 After removing missing text/image: {len(df)} samples")
702
-
703
- # Show sample of text data to verify quality
704
- print(f"📝 Sample texts:")
705
- for i, (text, hierarchy) in enumerate(zip(df['text'].head(3), df['hierarchy'].head(3))):
706
- print(f" {i+1}. [{hierarchy}] {text[:100]}...")
707
-
708
- print(f"📊 After sampling: {len(df_test)} samples")
709
- print(f"📊 Samples per hierarchy:")
710
- for hierarchy in sorted(df_test['hierarchy'].unique()):
711
- count = len(df_test[df_test['hierarchy'] == hierarchy])
712
- print(f" {hierarchy}: {count} samples")
713
-
714
- # Create formatted dataset with proper column names
715
- kagl_formatted = pd.DataFrame({
716
- 'image_url': df_test['image'],
717
- 'text': df_test['text'],
718
- 'hierarchy': df_test['hierarchy']
719
- })
720
-
721
- print(f"📊 Final dataset size: {len(kagl_formatted)} samples")
722
- return kagl_formatted
723
-
724
-
725
- if __name__ == "__main__":
726
- device = config.device
727
- directory = config.evaluation_directory
728
-
729
- print(f"🚀 Starting evaluation with custom model: {config.hierarchy_model_path}")
730
- print(f"🤗 Including CLIP baseline comparison")
731
-
732
- evaluator = EmbeddingEvaluator(config.hierarchy_model_path, directory, device=device)
733
-
734
- print(f"📊 Final hierarchy classes after initialization: {len(evaluator.vocab.hierarchy_classes)} classes")
735
-
736
- # Evaluate on validation dataset (same subset as during training)
737
- print("\n" + "="*60)
738
- print("EVALUATING VALIDATION DATASET - CUSTOM MODEL vs CLIP BASELINE")
739
- print("="*60)
740
- val_results = evaluator.evaluate_dataset_with_baselines(evaluator.val_df, "Validation Dataset")
741
-
742
- print("\n" + "="*60)
743
- print("EVALUATING FASHION-MNIST TEST DATASET - CUSTOM MODEL vs CLIP BASELINE")
744
- print("="*60)
745
- df_fashion_mnist = load_fashion_mnist_dataset(evaluator)
746
- if len(df_fashion_mnist) > 0:
747
- fashion_mnist_results = evaluator.evaluate_dataset_with_baselines(df_fashion_mnist, "Fashion-MNIST Test Dataset")
748
- else:
749
- fashion_mnist_results = {}
750
-
751
- print("\n" + "="*60)
752
- print("EVALUATING kagl MARQO DATASET - CUSTOM MODEL vs CLIP BASELINE")
753
- print("="*60)
754
- df_kagl_marqo = load_kagl_marqo_dataset(evaluator)
755
- if len(df_kagl_marqo) > 0:
756
- kagl_results = evaluator.evaluate_dataset_with_baselines(df_kagl_marqo, "kagl Marqo Dataset")
757
- else:
758
- kagl_results = {}
759
-
760
- # Compare results
761
- print(f"\n{'='*80}")
762
- print("FINAL EVALUATION SUMMARY - CUSTOM MODEL vs CLIP BASELINE")
763
- print(f"{'='*80}")
764
-
765
- print("\n🔍 VALIDATION DATASET RESULTS:")
766
- print(f"Dataset size: {len(evaluator.val_df)} samples")
767
- print(f"{'Model':<20} {'Embedding':<10} {'Sep Score':<12} {'NN Acc':<10} {'Centroid Acc':<12} {'F1 Macro':<10}")
768
- print("-" * 80)
769
-
770
- for model_type in ['custom', 'clip']:
771
- for emb_type in ['text', 'image']:
772
- key = f"{model_type}_{emb_type}"
773
- if key in val_results:
774
- metrics = val_results[key]
775
- model_name = "Custom Model" if model_type == 'custom' else "CLIP Baseline"
776
- print(f"{model_name:<20} {emb_type.capitalize():<10} {metrics['separation_score']:<12.4f} {metrics['accuracy']*100:<10.1f}% {metrics['centroid_accuracy']*100:<12.1f}% {metrics['f1_macro']*100:<10.1f}%")
777
-
778
- if fashion_mnist_results:
779
- print("\n👗 FASHION-MNIST TEST DATASET RESULTS:")
780
- print(f"Dataset size: {len(df_fashion_mnist)} samples")
781
- print(f"{'Model':<20} {'Embedding':<10} {'Sep Score':<12} {'NN Acc':<10} {'Centroid Acc':<12} {'F1 Macro':<10}")
782
- print("-" * 80)
783
-
784
- for model_type in ['custom', 'clip']:
785
- for emb_type in ['text', 'image']:
786
- key = f"{model_type}_{emb_type}"
787
- if key in fashion_mnist_results:
788
- metrics = fashion_mnist_results[key]
789
- model_name = "Custom Model" if model_type == 'custom' else "CLIP Baseline"
790
- print(f"{model_name:<20} {emb_type.capitalize():<10} {metrics['separation_score']:<12.4f} {metrics['accuracy']*100:<10.1f}% {metrics['centroid_accuracy']*100:<12.1f}% {metrics['f1_macro']*100:<10.1f}%")
791
-
792
- if kagl_results:
793
- print("\n🌐 kagl MARQO DATASET RESULTS:")
794
- print(f"Dataset size: {len(df_kagl_marqo)} samples")
795
- print(f"{'Model':<20} {'Embedding':<10} {'Sep Score':<12} {'NN Acc':<10} {'Centroid Acc':<12} {'F1 Macro':<10}")
796
- print("-" * 80)
797
-
798
- for model_type in ['custom', 'clip']:
799
- for emb_type in ['text', 'image']:
800
- key = f"{model_type}_{emb_type}"
801
- if key in kagl_results:
802
- metrics = kagl_results[key]
803
- model_name = "Custom Model" if model_type == 'custom' else "CLIP Baseline"
804
- print(f"{model_name:<20} {emb_type.capitalize():<10} {metrics['separation_score']:<12.4f} {metrics['accuracy']*100:<10.1f}% {metrics['centroid_accuracy']*100:<12.1f}% {metrics['f1_macro']*100:<10.1f}%")
805
-
806
- print(f"\n✅ Evaluation completed! Check '{directory}/' for visualization files.")
807
- print(f"📊 Custom model hierarchy classes: {len(evaluator.vocab.hierarchy_classes)} classes")
808
- print(f"🤗 CLIP baseline comparison included")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Evaluation/main_model_evaluation.py DELETED
The diff for this file is too large to render. See raw diff
 
Evaluation/tsne_images.py DELETED
@@ -1,569 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Outputs several t-SNE visualizations with color and hierarchy overlays to
4
- verify that the main model separates colors well inside each hierarchy group.
5
- """
6
-
7
- import math
8
-
9
- import matplotlib.pyplot as plt
10
- import numpy as np
11
- import pandas as pd
12
- import seaborn as sns
13
- import torch
14
- from matplotlib.patches import Polygon
15
- from PIL import Image
16
- from sklearn.manifold import TSNE
17
- from sklearn.metrics import (
18
- silhouette_score,
19
- davies_bouldin_score,
20
- calinski_harabasz_score,
21
- )
22
- from sklearn.preprocessing import normalize
23
- from sklearn.metrics.pairwise import cosine_similarity
24
- from torch.utils.data import DataLoader, Dataset
25
- from torchvision import transforms
26
- from tqdm import tqdm
27
- from transformers import CLIPModel as CLIPModel_transformers, CLIPProcessor
28
-
29
- try:
30
- from scipy.spatial import ConvexHull
31
- except ImportError:
32
- ConvexHull = None
33
-
34
- from config import (
35
- color_column,
36
- color_emb_dim,
37
- column_local_image_path,
38
- device,
39
- hierarchy_column,
40
- hierarchy_emb_dim,
41
- images_dir,
42
- local_dataset_path,
43
- main_model_path,
44
- )
45
-
46
-
47
- class ImageDataset(Dataset):
48
- """Lightweight dataset to load local images along with colors and hierarchies."""
49
-
50
- def __init__(self, dataframe: pd.DataFrame, root_dir: str):
51
- self.df = dataframe.reset_index(drop=True)
52
- self.root_dir = root_dir
53
- self.transform = transforms.Compose(
54
- [
55
- transforms.Resize((224, 224)),
56
- transforms.ToTensor(),
57
- transforms.Normalize(
58
- mean=[0.485, 0.456, 0.406],
59
- std=[0.229, 0.224, 0.225],
60
- ),
61
- ]
62
- )
63
-
64
- def __len__(self):
65
- return len(self.df)
66
-
67
- def __getitem__(self, idx):
68
- row = self.df.iloc[idx]
69
- img_path = row[column_local_image_path]
70
- image = Image.open(img_path).convert("RGB")
71
- image = self.transform(image)
72
- color = row[color_column]
73
- hierarchy = row[hierarchy_column]
74
- return image, color, hierarchy
75
-
76
-
77
-
78
- def load_main_model():
79
- """Load the main model with the trained weights."""
80
- checkpoint = torch.load(main_model_path, map_location=device)
81
- state_dict = checkpoint.get("model_state_dict", checkpoint)
82
- model = CLIPModel_transformers.from_pretrained(
83
- "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
84
- )
85
- model.load_state_dict(state_dict)
86
- model.to(device)
87
- model.eval()
88
- # Load processor for text tokenization
89
- processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
90
- return model, processor
91
-
92
-
93
- def load_clip_baseline():
94
- """Load the CLIP baseline model from transformers."""
95
- print("🤗 Loading CLIP baseline model from transformers...")
96
- clip_model = CLIPModel_transformers.from_pretrained("openai/clip-vit-base-patch32").to(device)
97
- clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
98
- clip_model.eval()
99
- print("✅ CLIP baseline model loaded successfully")
100
- return clip_model, clip_processor
101
-
102
-
103
- def enforce_min_hierarchy_samples(df, min_per_hierarchy):
104
- """Filter out hierarchy groups with fewer than min_per_hierarchy rows."""
105
- if not min_per_hierarchy or min_per_hierarchy <= 0:
106
- return df
107
- counts = df[hierarchy_column].value_counts()
108
- keep_values = counts[counts >= min_per_hierarchy].index
109
- filtered = df[df[hierarchy_column].isin(keep_values)].reset_index(drop=True)
110
- return filtered
111
-
112
-
113
- def prepare_dataframe(df, sample_size, per_color_limit, min_per_hierarchy=None):
114
- """Subsample the dataframe to speed up the t-SNE."""
115
- if per_color_limit and per_color_limit > 0:
116
- df_limited = (
117
- df.groupby(color_column)
118
- .apply(lambda g: g.sample(min(len(g), per_color_limit), random_state=42))
119
- .reset_index(drop=True)
120
- )
121
- else:
122
- df_limited = df
123
-
124
- if sample_size and 0 < sample_size < len(df_limited):
125
- df_limited = df_limited.sample(sample_size, random_state=42).reset_index(
126
- drop=True
127
- )
128
- df_limited = enforce_min_hierarchy_samples(df_limited, min_per_hierarchy)
129
- return df_limited
130
-
131
-
132
- def compute_embeddings(model, dataloader):
133
- """Extract color, hierarchy, and combined embeddings."""
134
- color_embeddings = []
135
- hierarchy_embeddings = []
136
- color_labels = []
137
- hierarchy_labels = []
138
- with torch.no_grad():
139
- for images, colors, hierarchies in tqdm(
140
- dataloader, desc="Extracting embeddings"
141
- ):
142
- images = images.to(device)
143
- if images.shape[1] == 1: # safety in case
144
- images = images.expand(-1, 3, -1, -1)
145
- image_embeds = model.get_image_features(pixel_values=images)
146
- color_part = image_embeds[:, :color_emb_dim]
147
- hierarchy_part = image_embeds[
148
- :, color_emb_dim : color_emb_dim + hierarchy_emb_dim
149
- ]
150
- color_embeddings.append(color_part.cpu().numpy())
151
- hierarchy_embeddings.append(hierarchy_part.cpu().numpy())
152
- color_labels.extend(colors)
153
- hierarchy_labels.extend(hierarchies)
154
- return (
155
- np.concatenate(color_embeddings, axis=0),
156
- np.concatenate(hierarchy_embeddings, axis=0),
157
- color_labels,
158
- hierarchy_labels,
159
- )
160
-
161
-
162
- def compute_clip_embeddings(clip_model, clip_processor, dataloader):
163
- """Extract CLIP baseline embeddings (full image embeddings, not separated)."""
164
- all_embeddings = []
165
- color_labels = []
166
- hierarchy_labels = []
167
-
168
- with torch.no_grad():
169
- for images, colors, hierarchies in tqdm(
170
- dataloader, desc="Extracting CLIP embeddings"
171
- ):
172
- batch_embeddings = []
173
- for i in range(images.shape[0]):
174
- # Get single image from batch
175
- image_tensor = images[i] # Shape: (3, 224, 224)
176
-
177
- # Denormalize on CPU (safer for PIL conversion)
178
- mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
179
- std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
180
- image_tensor = image_tensor * std + mean
181
- image_tensor = torch.clamp(image_tensor, 0, 1)
182
-
183
- # Convert to PIL Image (must be on CPU)
184
- image_pil = transforms.ToPILImage()(image_tensor.cpu())
185
-
186
- # Process with CLIP (using empty text since we only need image embeddings)
187
- inputs = clip_processor(
188
- text="",
189
- images=image_pil,
190
- return_tensors="pt",
191
- padding=True
192
- ).to(device)
193
-
194
- outputs = clip_model(**inputs)
195
- # Get normalized image embeddings
196
- image_emb = outputs.image_embeds / outputs.image_embeds.norm(p=2, dim=-1, keepdim=True)
197
- batch_embeddings.append(image_emb.cpu().numpy())
198
-
199
- all_embeddings.append(np.vstack(batch_embeddings))
200
- color_labels.extend(colors)
201
- hierarchy_labels.extend(hierarchies)
202
-
203
- # For CLIP, we use the full embeddings for all visualizations
204
- # (no separation into color/hierarchy dimensions)
205
- full_embeddings = np.concatenate(all_embeddings, axis=0)
206
- return (
207
- full_embeddings, # color_embeddings (using full CLIP embeddings)
208
- full_embeddings, # hierarchy_embeddings (using full CLIP embeddings)
209
- full_embeddings, # color_hier_embeddings (using full CLIP embeddings)
210
- color_labels,
211
- hierarchy_labels,
212
- )
213
-
214
-
215
- def compute_dunn_index(embeddings, labels):
216
- """
217
- Compute the Dunn Index for clustering evaluation.
218
-
219
- The Dunn Index is the ratio of the minimum inter-cluster distance
220
- to the maximum intra-cluster distance. Higher values indicate better clustering.
221
-
222
- Args:
223
- embeddings: Array of embeddings [N, embed_dim]
224
- labels: Array of cluster labels [N]
225
-
226
- Returns:
227
- Dunn Index value (float) or None if calculation fails
228
- """
229
- try:
230
- unique_labels = np.unique(labels)
231
- if len(unique_labels) < 2:
232
- return None
233
-
234
- # Calculate intra-cluster distances (maximum within each cluster)
235
- max_intra_cluster_dist = 0
236
- for label in unique_labels:
237
- cluster_points = embeddings[labels == label]
238
- if len(cluster_points) > 1:
239
- # Calculate pairwise distances within cluster
240
- from scipy.spatial.distance import pdist
241
- intra_dists = pdist(cluster_points, metric='euclidean')
242
- if len(intra_dists) > 0:
243
- max_intra = np.max(intra_dists)
244
- max_intra_cluster_dist = max(max_intra_cluster_dist, max_intra)
245
-
246
- if max_intra_cluster_dist == 0:
247
- return None
248
-
249
- # Calculate inter-cluster distances (minimum between clusters)
250
- min_inter_cluster_dist = float('inf')
251
- for i, label1 in enumerate(unique_labels):
252
- for label2 in unique_labels[i+1:]:
253
- cluster1_points = embeddings[labels == label1]
254
- cluster2_points = embeddings[labels == label2]
255
-
256
- # Calculate distances between clusters
257
- from scipy.spatial.distance import cdist
258
- inter_dists = cdist(cluster1_points, cluster2_points, metric='euclidean')
259
- min_inter = np.min(inter_dists)
260
- min_inter_cluster_dist = min(min_inter_cluster_dist, min_inter)
261
-
262
- if min_inter_cluster_dist == float('inf'):
263
- return None
264
-
265
- # Dunn Index = minimum inter-cluster distance / maximum intra-cluster distance
266
- dunn_index = min_inter_cluster_dist / max_intra_cluster_dist
267
- return float(dunn_index)
268
- except Exception as e:
269
- print(f"⚠️ Error computing Dunn Index: {e}")
270
- return None
271
-
272
-
273
- def build_color_map(labels, prefer_true_colors=False):
274
- """Build a color mapping for labels."""
275
- unique_labels = sorted(set(labels))
276
- palette = sns.color_palette("husl", len(unique_labels))
277
- return {label: palette[idx] for idx, label in enumerate(unique_labels)}
278
-
279
-
280
- def compute_color_similarity_matrix(embeddings, colors, title="Color similarity (image embeddings)"):
281
- """Compute and visualize similarity matrix between color centroids."""
282
- # Use only the colors from the reference heatmap
283
- reference_colors = ['red', 'pink', 'blue', 'green', 'aqua', 'lime', 'yellow', 'orange',
284
- 'purple', 'brown', 'gray', 'black', 'white']
285
- # Map 'yelloworange' to 'yellow' or 'orange' if needed
286
- color_mapping = {
287
- 'yelloworange': 'yellow',
288
- 'grey': 'gray' # Handle grey/gray variation
289
- }
290
-
291
- # Filter to only include colors that are in the reference list
292
- filtered_colors = []
293
- filtered_embeddings = []
294
- for i, color in enumerate(colors):
295
- # Normalize color name
296
- normalized_color = color_mapping.get(color.lower(), color.lower())
297
- if normalized_color in reference_colors:
298
- filtered_colors.append(normalized_color)
299
- filtered_embeddings.append(embeddings[i])
300
-
301
- if len(filtered_colors) == 0:
302
- print("⚠️ No matching colors found in reference list")
303
- return None
304
-
305
- # Use only unique colors from reference that exist in data
306
- unique_colors = sorted([c for c in reference_colors if c in filtered_colors])
307
-
308
- # Convert to numpy arrays
309
- filtered_embeddings = np.array(filtered_embeddings)
310
- filtered_colors = np.array(filtered_colors)
311
-
312
- # Compute centroids for each color
313
- centroids = {}
314
- for color in unique_colors:
315
- color_mask = np.array([c == color for c in filtered_colors])
316
- if color_mask.sum() > 0:
317
- centroids[color] = np.mean(filtered_embeddings[color_mask], axis=0)
318
-
319
- # Compute similarity matrix
320
- similarity_matrix = np.zeros((len(unique_colors), len(unique_colors)))
321
- for i, color1 in enumerate(unique_colors):
322
- for j, color2 in enumerate(unique_colors):
323
- if i == j:
324
- similarity_matrix[i, j] = 1.0
325
- else:
326
- if color1 in centroids and color2 in centroids:
327
- similarity = cosine_similarity(
328
- [centroids[color1]],
329
- [centroids[color2]]
330
- )[0][0]
331
- similarity_matrix[i, j] = similarity
332
-
333
- # Create heatmap
334
- plt.figure(figsize=(12, 10))
335
- sns.heatmap(
336
- similarity_matrix,
337
- annot=True,
338
- fmt='.2f',
339
- cmap='RdYlBu_r',
340
- xticklabels=unique_colors,
341
- yticklabels=unique_colors,
342
- square=True,
343
- cbar_kws={'label': 'Cosine Similarity'},
344
- linewidths=0.5,
345
- vmin=-0.6,
346
- vmax=1.0
347
- )
348
-
349
- plt.title(title, fontsize=16, fontweight='bold', pad=20)
350
- plt.xlabel('Colors', fontsize=14, fontweight='bold')
351
- plt.ylabel('Colors', fontsize=14, fontweight='bold')
352
- plt.xticks(rotation=45, ha='right')
353
- plt.yticks(rotation=0)
354
- plt.tight_layout()
355
-
356
- output_path = "color_similarity_image_embeddings.png"
357
- plt.savefig(output_path, dpi=300, bbox_inches='tight')
358
- plt.close()
359
- print(f"✅ Color similarity heatmap saved: {output_path}")
360
-
361
- return similarity_matrix
362
-
363
-
364
- def run_tsne(embeddings,legend_labels,output_path,perplexity,title,scatter_color_labels=None,prefer_true_colors=False):
365
- """Calculate and plot a t-SNE projection."""
366
- tsne = TSNE(
367
- n_components=2,
368
- perplexity=perplexity,
369
- init="pca",
370
- learning_rate="auto",
371
- random_state=42,
372
- )
373
- reduced = tsne.fit_transform(embeddings)
374
-
375
- label_array = np.array(legend_labels)
376
- color_labels = (
377
- np.array(scatter_color_labels) if scatter_color_labels is not None else label_array
378
- )
379
-
380
- # Calculate silhouette scores
381
- unique_labels_list = sorted(set(label_array))
382
- if len(unique_labels_list) > 1 and len(label_array) > 1:
383
- # Convert labels to numeric indices for silhouette_score
384
- label_to_idx = {label: idx for idx, label in enumerate(unique_labels_list)}
385
- numeric_labels = np.array([label_to_idx[label] for label in label_array])
386
-
387
- # Calculate in original embedding space (ground truth - measures real separation)
388
- silhouette = silhouette_score(embeddings, numeric_labels, metric='euclidean')
389
- davies_bouldin = davies_bouldin_score(embeddings, numeric_labels)
390
- calinski_harabasz = calinski_harabasz_score(embeddings, numeric_labels)
391
- dunn = compute_dunn_index(embeddings, numeric_labels)
392
-
393
- else:
394
- silhouette = None
395
- davies_bouldin = None
396
- calinski_harabasz = None
397
- dunn = None
398
-
399
- # Helpful reference for the reported clustering indices:
400
- # • Silhouette Score ∈ [-1, 1] — closer to 1 means points fit their cluster well, 0 means overlap, < 0 suggests misassignment.
401
- # • Davies–Bouldin Index ∈ [0, +∞) — lower is better; quantifies average similarity between clusters relative to their size.
402
- # • Calinski–Harabasz Index ∈ [0, +∞) — higher is better; ratio of between-cluster dispersion to within-cluster dispersion.
403
- # • Dunn Index ∈ [0, +∞) — higher is better; compares the tightest cluster diameter to the closest distance between clusters.
404
-
405
- # Build color map for visualization
406
- color_map = build_color_map(color_labels, prefer_true_colors=prefer_true_colors)
407
- color_series = np.array([color_map[label] for label in color_labels])
408
-
409
- plt.figure(figsize=(10, 8))
410
- unique_labels = sorted(set(label_array))
411
- for label in unique_labels:
412
- mask = label_array == label
413
- if 'color' in title:
414
- c = label
415
- else:
416
- c = color_series[mask]
417
- plt.scatter(
418
- reduced[mask, 0],
419
- reduced[mask, 1],
420
- c=c,
421
- s=15,
422
- alpha=0.8,
423
- label=label,
424
- )
425
-
426
- # Add silhouette score to title
427
- if silhouette is not None:
428
- title_with_score = f"{title}\n(t-SNE Silhouette: {silhouette:.3f} | Davies-Bouldin: {davies_bouldin:.3f} | Calinski-Harabasz: {calinski_harabasz:.3f} | Dunn: {dunn:.3f})"
429
- else:
430
- title_with_score = title
431
-
432
- plt.title(title_with_score)
433
- plt.xlabel("t-SNE 1")
434
- plt.ylabel("t-SNE 2")
435
- plt.legend(
436
- bbox_to_anchor=(1.05, 1), loc="upper left", fontsize="small", frameon=False
437
- )
438
- plt.tight_layout()
439
- plt.savefig(output_path, dpi=300)
440
- plt.close()
441
- print(f"✅ Figure saved in {output_path}")
442
- print(f" 📊 t-SNE space: {silhouette:.3f} (matches visualization) | Davies-Bouldin: {davies_bouldin:.3f} | Calinski-Harabasz: {calinski_harabasz:.3f} | Dunn: {dunn:.3f}")
443
-
444
-
445
-
446
- def filter_valid_rows(dataframe: pd.DataFrame) -> pd.DataFrame:
447
- """Keep only rows with valid local image paths and colors."""
448
- dataframe = dataframe[dataframe['color'] != 'unknown'].copy()
449
- df = dataframe.dropna(
450
- subset=[column_local_image_path, color_column, hierarchy_column]
451
- ).copy()
452
- mask = df[column_local_image_path].apply(lambda x: isinstance(x, str) and len(x.strip()) > 0)
453
- return df[mask].reset_index(drop=True)
454
-
455
- if __name__ == "__main__":
456
- sample_size = None
457
- per_color_limit = 500
458
- min_per_hierarchy = 200
459
- batch_size = 32
460
- perplexity = 30
461
- output_color = "tsne_color_space.png"
462
- output_hierarchy = "tsne_hierarchy_space.png"
463
-
464
- print("📥 Loading the dataset...")
465
- df = pd.read_csv("data/data_with_local_paths.csv")
466
- df = filter_valid_rows(df)
467
- print(f"Total len if the dataset: {len(df)}")
468
- df = prepare_dataframe(df, sample_size, per_color_limit, min_per_hierarchy)
469
- print(f"✅ {len(df)} samples will be used for the t-SNE")
470
- print(f"Number of colors in the dataset: {len(df['color'].unique())}")
471
- print(f"Colors in the dataset: {df['color'].unique()}")
472
- dataset = ImageDataset(df, images_dir)
473
- dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)
474
-
475
- # 2) Loading the models
476
- print("⚙️ Loading the main model...")
477
- model, processor = load_main_model()
478
-
479
- print("⚙️ Loading CLIP baseline model...")
480
- clip_model, clip_processor = load_clip_baseline()
481
-
482
- # 3) Extracting the embeddings
483
- print("🎯 Extracting the embeddings...")
484
-
485
- (
486
- color_embeddings,
487
- hierarchy_embeddings,
488
- colors,
489
- hierarchies,
490
- ) = compute_embeddings(model, dataloader)
491
-
492
- # 4) Calculating the t-SNE
493
- print("🌀 Calculating the color t-SNE...")
494
- run_tsne(
495
- color_embeddings,
496
- colors,
497
- output_color,
498
- perplexity,
499
- "t-SNE of the color embeddings of the main model",
500
- scatter_color_labels=colors,
501
- prefer_true_colors=True,
502
- )
503
-
504
- print("🎨 Computing color similarity matrix from image embeddings...")
505
- compute_color_similarity_matrix(
506
- color_embeddings,
507
- colors,
508
- title="Color similarity (image embeddings - main model)"
509
- )
510
-
511
- print("🌀 Calculating the hierarchy t-SNE...")
512
- run_tsne(
513
- hierarchy_embeddings,
514
- hierarchies,
515
- output_hierarchy,
516
- perplexity,
517
- "t-SNE of the hierarchy embeddings of the main model",
518
- scatter_color_labels=hierarchies,
519
- )
520
-
521
- # ========== CLIP BASELINE EVALUATION ==========
522
- print("\n" + "="*60)
523
- print("🔄 Starting CLIP Baseline Evaluation")
524
- print("="*60)
525
-
526
- print("🎯 Extracting CLIP embeddings...")
527
- (
528
- clip_color_embeddings,
529
- clip_hierarchy_embeddings,
530
- clip_color_hier_embeddings,
531
- clip_colors,
532
- clip_hierarchies,
533
- ) = compute_clip_embeddings(clip_model, clip_processor, dataloader)
534
-
535
- # Output paths for CLIP baseline
536
- clip_output_color = "clip_baseline_tsne_color_space.png"
537
- clip_output_hierarchy = "clip_baseline_tsne_hierarchy_space.png"
538
-
539
- print("🌀 Calculating CLIP baseline color t-SNE...")
540
- run_tsne(
541
- clip_color_embeddings,
542
- clip_colors,
543
- clip_output_color,
544
- perplexity,
545
- "t-SNE of the color embeddings (CLIP Baseline)",
546
- scatter_color_labels=clip_colors,
547
- prefer_true_colors=True,
548
- )
549
-
550
- print("🎨 Computing color similarity matrix from image embeddings...")
551
- compute_color_similarity_matrix(
552
- clip_color_embeddings,
553
- clip_colors,
554
- title="Color similarity (image embeddings - CLIP Baseline)"
555
- )
556
-
557
- print("🌀 Calculating CLIP baseline hierarchy t-SNE...")
558
- run_tsne(
559
- clip_hierarchy_embeddings,
560
- clip_hierarchies,
561
- clip_output_hierarchy,
562
- perplexity,
563
- "t-SNE of the hierarchy embeddings (CLIP Baseline)",
564
- scatter_color_labels=clip_hierarchies,
565
- )
566
-
567
- print("\n✅ All t-SNE visualizations completed!")
568
- print(" - Main model: tsne_*.png")
569
- print(" - CLIP baseline: clip_baseline_tsne_*.png")