Leacb4 committed on
Commit
d1b4c8f
·
verified ·
1 Parent(s): 4c2095f

Upload evaluation/hierarchy_evaluation_with_clip_baseline.py with huggingface_hub

Browse files
evaluation/hierarchy_evaluation_with_clip_baseline.py ADDED
@@ -0,0 +1,808 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hierarchy embedding evaluation with CLIP baseline comparison.
3
+ This file evaluates the quality of hierarchy embeddings from the custom model and compares them
4
+ with a CLIP baseline model (OpenAI CLIP). It calculates similarity metrics, classification accuracies,
5
+ and generates confusion matrices for both models to measure relative performance. It also supports
6
+ evaluation on the Fashion-MNIST and Marqo KAGL datasets.
7
+ """
8
+
9
+ import torch
10
+ import pandas as pd
11
+ import numpy as np
12
+ import matplotlib.pyplot as plt
13
+ import seaborn as sns
14
+ from sklearn.metrics.pairwise import cosine_similarity
15
+ from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, f1_score
16
+ from collections import defaultdict
17
+ import os
18
+ import requests
19
+ from tqdm import tqdm
20
+ from torch.utils.data import Dataset, DataLoader
21
+ from torchvision import transforms
22
+ from sklearn.model_selection import train_test_split
23
+ from io import BytesIO
24
+ from PIL import Image
25
+ import warnings
26
+ warnings.filterwarnings('ignore')
27
+
28
+ # Import transformers CLIP
29
+ from transformers import CLIPProcessor, CLIPModel as TransformersCLIPModel
30
+
31
+ # Import your custom model
32
+ from hierarchy_model import Model, HierarchyExtractor, HierarchyDataset, collate_fn
33
+ import config
34
+
35
def convert_fashion_mnist_to_image(pixel_values):
    """Build an RGB PIL image from a flat sequence of 784 Fashion-MNIST pixels.

    Args:
        pixel_values: Flat array-like of 784 intensities (one 28x28 image).

    Returns:
        A ``PIL.Image.Image`` whose three channels all carry the same
        grayscale data.
    """
    # Interpret the flat values as a 28x28 grid of 8-bit intensities.
    grayscale = np.array(pixel_values).reshape(28, 28).astype(np.uint8)
    # Replicate the single channel along a new trailing axis -> (28, 28, 3).
    rgb = np.stack([grayscale, grayscale, grayscale], axis=-1)
    return Image.fromarray(rgb)
43
+
44
def get_fashion_mnist_labels():
    """Return the mapping from Fashion-MNIST integer label (0-9) to class name."""
    # Position in this tuple is the integer label used by the dataset.
    class_names = (
        "T-shirt/top",
        "Trouser",
        "Pullover",
        "Dress",
        "Coat",
        "Sandal",
        "Shirt",
        "Sneaker",
        "Bag",
        "Ankle boot",
    )
    return dict(enumerate(class_names))
58
+
59
class FashionMNISTDataset(Dataset):
    """Dataset over a Fashion-MNIST style dataframe with flat pixel columns.

    Each row is expected to provide the columns ``pixel1`` .. ``pixel784``
    plus ``text`` and ``hierarchy``. Items are returned as
    ``(image_tensor, text, hierarchy)``.
    """

    def __init__(self, dataframe, image_size=224):
        """Store the dataframe and build the inference-time transform.

        Args:
            dataframe: Source rows (pixel columns, ``text``, ``hierarchy``).
            image_size: Side length the 28x28 image is resized to.
        """
        self.dataframe = dataframe
        self.image_size = image_size
        self.labels_map = get_fashion_mnist_labels()

        # The pixel column names never change, so build the list once here
        # instead of rebuilding 784 strings on every __getitem__ call.
        self.pixel_cols = [f'pixel{i}' for i in range(1, 785)]

        # Simple transforms for validation/inference (no augmentation).
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        """Return the number of rows in the underlying dataframe."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return ``(image_tensor, text, hierarchy)`` for the row at position ``idx``."""
        row = self.dataframe.iloc[idx]

        # Rebuild the RGB image from the flat pixel columns, then normalize it.
        pixel_values = row[self.pixel_cols].values
        image = self.transform(convert_fashion_mnist_to_image(pixel_values))

        return image, row['text'], row['hierarchy']
93
+
94
class CLIPBaselineEvaluator:
    """Thin wrapper around the pretrained ``openai/clip-vit-base-patch32`` model.

    Used as the baseline: it turns (image, text) pairs into L2-normalized
    CLIP image and text embeddings.
    """

    def __init__(self, device='mps'):
        """Load the CLIP model and processor.

        Args:
            device: Anything accepted by ``torch.device`` (string or device).
        """
        self.device = torch.device(device)

        # Load CLIP model and processor
        print("πŸ€— Loading CLIP baseline model from transformers...")
        self.clip_model = TransformersCLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.clip_model.eval()
        print("βœ… CLIP model loaded successfully")

    def extract_clip_embeddings(self, images, texts):
        """Extract normalized CLIP embeddings for paired images and texts.

        Args:
            images: Sequence of ``torch.Tensor`` or ``PIL.Image.Image`` items.
                3-D tensors are assumed to be ImageNet-normalized (as the
                datasets in this file produce them) and are denormalized
                before conversion to PIL.
            texts: Sequence of strings, aligned with ``images``.

        Returns:
            ``(image_embeddings, text_embeddings)`` as 2-D numpy arrays with
            one row per input pair.

        Raises:
            ValueError: If ``images`` is empty or an item has an unsupported type.
        """
        if len(images) == 0:
            # np.vstack([]) would raise an opaque ValueError; say what went wrong.
            raise ValueError("extract_clip_embeddings received no images")

        # ImageNet statistics used by the dataset transforms; built once.
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        to_pil = transforms.ToPILImage()

        all_image_embeddings = []
        all_text_embeddings = []

        with torch.no_grad():
            for i in tqdm(range(len(images)), desc="Extracting CLIP embeddings"):
                item = images[i]
                if isinstance(item, torch.Tensor):
                    image_tensor = item
                    if image_tensor.dim() == 3:
                        # Undo the dataset normalization so CLIP's own
                        # preprocessing sees pixel values in [0, 1].
                        image_tensor = image_tensor * std.to(image_tensor.device) + mean.to(image_tensor.device)
                        image_tensor = torch.clamp(image_tensor, 0, 1)
                    # ToPILImage expects a CPU tensor.
                    image_pil = to_pil(image_tensor.cpu())
                elif isinstance(item, Image.Image):
                    image_pil = item
                else:
                    raise ValueError(f"Unsupported image type: {type(item)}")

                # truncation=True keeps long descriptions within CLIP's fixed
                # text context length instead of failing inside the text encoder.
                inputs = self.clip_processor(
                    text=texts[i],
                    images=image_pil,
                    return_tensors="pt",
                    padding=True,
                    truncation=True
                ).to(self.device)

                outputs = self.clip_model(**inputs)

                # L2-normalize so dot products are cosine similarities.
                image_emb = outputs.image_embeds / outputs.image_embeds.norm(p=2, dim=-1, keepdim=True)
                text_emb = outputs.text_embeds / outputs.text_embeds.norm(p=2, dim=-1, keepdim=True)

                all_image_embeddings.append(image_emb.cpu().numpy())
                all_text_embeddings.append(text_emb.cpu().numpy())

        return np.vstack(all_image_embeddings), np.vstack(all_text_embeddings)
148
+
149
+
150
class EmbeddingEvaluator:
    """Evaluate hierarchy embeddings of the custom model against a CLIP baseline.

    On construction the evaluator loads the local dataset, carves out the 20%
    validation split, restores the custom model from a checkpoint and builds a
    ``CLIPBaselineEvaluator``. ``evaluate_dataset_with_baselines`` then computes
    similarity metrics, nearest-neighbour / centroid accuracies, F1 scores and
    confusion matrices for both models.
    """

    def __init__(self, model_path, directory, device=None):
        """Load the dataset, the custom model checkpoint and the CLIP baseline.

        Args:
            model_path: Path to the custom model checkpoint.
            directory: Output directory for confusion-matrix images.
            device: Device for both models; defaults to ``config.device``.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        self.device = config.device if device is None else device
        self.directory = directory

        # 1. Load the dataset
        csv_path = config.local_dataset_path
        print(f"πŸ“ Using dataset with local images: {csv_path}")
        df = pd.read_csv(csv_path)

        print(f"πŸ“ Loaded {len(df)} samples")

        # 2. Get unique hierarchy classes from the dataset
        hierarchy_classes = sorted(df[config.hierarchy_column].unique().tolist())
        print(f"πŸ“‹ Found {len(hierarchy_classes)} hierarchy classes")

        # NOTE(review): the split stratifies on the literal 'hierarchy' column
        # while the class list above uses config.hierarchy_column; kept as-is so
        # the validation subset does not change -- confirm both name one column.
        _, self.val_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df['hierarchy'])

        # 3. Load the model
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file {model_path} not found")

        # SECURITY NOTE: torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(model_path, map_location=self.device)

        # Named model_config on purpose: binding the name `config` here would
        # make it a local variable and break every `config.<attr>` lookup in
        # this method (UnboundLocalError).
        model_config = checkpoint.get('config', {})
        saved_hierarchy_classes = checkpoint['hierarchy_classes']

        # Use the hierarchy classes saved with the checkpoint.
        self.hierarchy_classes = saved_hierarchy_classes
        self.vocab = HierarchyExtractor(saved_hierarchy_classes)

        # Re-create the model with the saved configuration.
        self.model = Model(
            num_hierarchy_classes=len(saved_hierarchy_classes),
            embed_dim=model_config['embed_dim'],
            dropout=model_config['dropout']
        ).to(self.device)

        self.model.load_state_dict(checkpoint['model_state'])

        print(f"βœ… Custom model loaded with:")
        print(f"πŸ“‹ Hierarchy classes: {len(saved_hierarchy_classes)}")
        print(f"🎯 Embed dim: {model_config['embed_dim']}")
        print(f"πŸ’§ Dropout: {model_config['dropout']}")
        print(f"πŸ“… Epoch: {checkpoint.get('epoch', 'unknown')}")

        self.model.eval()

        # Initialize CLIP baseline on the same device as the custom model.
        self.clip_evaluator = CLIPBaselineEvaluator(self.device)

    def create_dataloader(self, dataframe, batch_size=16):
        """Create a dataloader for the custom model.

        Dataframes with a ``pixel1`` column are treated as Fashion-MNIST data.
        """
        if 'pixel1' in dataframe.columns:
            print("πŸ” Detected Fashion-MNIST data, using FashionMNISTDataset")
            dataset = FashionMNISTDataset(dataframe, image_size=224)
        else:
            dataset = HierarchyDataset(dataframe, image_size=224)

        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=lambda batch: collate_fn(batch, self.vocab),
            num_workers=0
        )

    def create_clip_dataloader(self, dataframe, batch_size=16):
        """Create a dataloader for the CLIP baseline.

        Dataframes with a ``pixel1`` column are treated as Fashion-MNIST data.
        """
        if 'pixel1' in dataframe.columns:
            print("πŸ” Detected Fashion-MNIST data for CLIP, using FashionMNISTDataset")
            dataset = FashionMNISTDataset(dataframe, image_size=224)
        else:
            dataset = CLIPDataset(dataframe)

        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0
        )

    def extract_custom_embeddings(self, dataloader, embedding_type='text'):
        """Extract embeddings from the custom model.

        Args:
            dataloader: Loader yielding dict batches with ``image``,
                ``hierarchy_indices`` and ``hierarchy`` entries.
            embedding_type: ``'image'`` selects ``z_img``; any other value
                selects ``z_txt``.

        Returns:
            ``(embeddings, labels, texts)``; ``texts`` repeats the hierarchy
            labels, which are the only text the custom model consumes here.
        """
        output_key = 'z_img' if embedding_type == 'image' else 'z_txt'

        all_embeddings = []
        all_labels = []
        all_texts = []

        with torch.no_grad():
            for batch in tqdm(dataloader, desc=f"Extracting custom {embedding_type} embeddings"):
                images = batch['image'].to(self.device)
                hierarchy_indices = batch['hierarchy_indices'].to(self.device)
                hierarchy_labels = batch['hierarchy']

                out = self.model(image=images, hierarchy_indices=hierarchy_indices)

                all_embeddings.append(out[output_key].cpu().numpy())
                all_labels.extend(hierarchy_labels)
                all_texts.extend(hierarchy_labels)

        return np.vstack(all_embeddings), all_labels, all_texts

    def compute_similarity_metrics(self, embeddings, labels):
        """Compute intra-class / inter-class cosine similarities and accuracies.

        Returns:
            Dict with the raw similarity lists, their means, the separation
            score (intra mean minus inter mean), ``accuracy`` (nearest
            neighbour), ``nn_accuracy`` (same value under a key that later
            ``dict.update`` calls with classification results do not
            overwrite) and ``centroid_accuracy``.
        """
        similarities = cosine_similarity(embeddings)

        # Group sample indices by hierarchy label.
        hierarchy_groups = defaultdict(list)
        for index, hierarchy in enumerate(labels):
            hierarchy_groups[hierarchy].append(index)
        group_indices = [np.asarray(indices) for indices in hierarchy_groups.values()]

        # Intra-class: upper triangle of each class's own similarity block.
        intra_class_similarities = []
        for indices in group_indices:
            if len(indices) > 1:
                block = similarities[np.ix_(indices, indices)]
                upper_rows, upper_cols = np.triu_indices(len(indices), k=1)
                intra_class_similarities.extend(block[upper_rows, upper_cols])

        # Inter-class: the full block between every pair of distinct classes.
        inter_class_similarities = []
        for first in range(len(group_indices)):
            for second in range(first + 1, len(group_indices)):
                block = similarities[np.ix_(group_indices[first], group_indices[second])]
                inter_class_similarities.extend(block.ravel())

        intra_mean = np.mean(intra_class_similarities) if intra_class_similarities else 0
        inter_mean = np.mean(inter_class_similarities) if inter_class_similarities else 0
        have_both = bool(intra_class_similarities) and bool(inter_class_similarities)

        nn_accuracy = self.compute_embedding_accuracy(embeddings, labels, similarities)
        centroid_accuracy = self.compute_centroid_accuracy(embeddings, labels)

        return {
            'intra_class_similarities': intra_class_similarities,
            'inter_class_similarities': inter_class_similarities,
            'intra_class_mean': intra_mean,
            'inter_class_mean': inter_mean,
            'separation_score': intra_mean - inter_mean if have_both else 0,
            'accuracy': nn_accuracy,
            'nn_accuracy': nn_accuracy,
            'centroid_accuracy': centroid_accuracy
        }

    def compute_embedding_accuracy(self, embeddings, labels, similarities):
        """Compute leave-one-out nearest-neighbour classification accuracy.

        Args:
            embeddings: Per-sample embeddings (only the length is used).
            labels: Per-sample labels.
            similarities: Square pairwise similarity matrix.

        Returns:
            Fraction of samples whose most similar *other* sample shares their
            label; 0 when fewer than two samples are available.
        """
        total_predictions = len(labels)
        if total_predictions < 2:
            # A lone sample has no neighbour; do not count it as matching itself.
            return 0

        correct_predictions = 0
        for i in range(len(embeddings)):
            row = np.array(similarities[i], dtype=float)
            row[i] = -np.inf  # Exclude self-similarity
            if labels[int(np.argmax(row))] == labels[i]:
                correct_predictions += 1

        return correct_predictions / total_predictions

    def compute_centroid_accuracy(self, embeddings, labels):
        """Compute accuracy of nearest-centroid classification on ``embeddings``."""
        total_predictions = len(labels)
        if total_predictions == 0:
            return 0

        predictions = self.predict_hierarchy_from_embeddings(embeddings, labels)
        correct_predictions = sum(
            1 for predicted, actual in zip(predictions, labels) if predicted == actual
        )
        return correct_predictions / total_predictions

    def create_confusion_matrix(self, true_labels, predicted_labels, title="Confusion Matrix"):
        """Create and plot a confusion matrix.

        Returns:
            ``(figure, accuracy, confusion_matrix_array)``.
        """
        # A set union works for any label sequences, not only Python lists.
        unique_labels = sorted(set(true_labels) | set(predicted_labels))

        cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels)
        accuracy = accuracy_score(true_labels, predicted_labels)

        plt.figure(figsize=(12, 10))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=unique_labels, yticklabels=unique_labels)
        plt.title(f'{title}\nAccuracy: {accuracy:.3f} ({accuracy*100:.1f}%)')
        plt.ylabel('True Hierarchy')
        plt.xlabel('Predicted Hierarchy')
        plt.xticks(rotation=45)
        plt.yticks(rotation=0)
        plt.tight_layout()

        return plt.gcf(), accuracy, cm

    def predict_hierarchy_from_embeddings(self, embeddings, labels):
        """Predict a hierarchy per embedding via nearest-centroid classification.

        Centroids are the per-class means of ``embeddings`` themselves. Classes
        are ordered by first appearance so ties resolve the same way on every
        run.

        Returns:
            List with one predicted label per embedding.
        """
        if len(labels) == 0:
            return []

        embeddings = np.asarray(embeddings)

        # Insertion-ordered grouping: class -> indices of its samples.
        class_indices = defaultdict(list)
        for index, label in enumerate(labels):
            class_indices[label].append(index)
        classes = list(class_indices)

        centroid_matrix = np.stack(
            [embeddings[class_indices[label]].mean(axis=0) for label in classes]
        )

        # One vectorized call instead of one sklearn call per (sample, centroid).
        centroid_similarities = cosine_similarity(embeddings, centroid_matrix)
        return [classes[column] for column in centroid_similarities.argmax(axis=1)]

    def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings"):
        """Evaluate centroid-based classification and build its confusion matrix.

        Returns:
            Dict with ``accuracy``, macro / weighted / per-class F1,
            ``predictions``, ``confusion_matrix``, ``classification_report``
            and the matplotlib ``figure``.
        """
        predictions = self.predict_hierarchy_from_embeddings(embeddings, labels)
        accuracy = accuracy_score(labels, predictions)

        unique_labels = sorted(set(labels))
        f1_macro = f1_score(labels, predictions, labels=unique_labels, average='macro', zero_division=0)
        f1_weighted = f1_score(labels, predictions, labels=unique_labels, average='weighted', zero_division=0)
        f1_per_class = f1_score(labels, predictions, labels=unique_labels, average=None, zero_division=0)

        fig, _, cm = self.create_confusion_matrix(labels, predictions,
                                                  f"{embedding_type} - Hierarchy Classification")

        # zero_division=0 matches the f1_score calls above.
        report = classification_report(labels, predictions, labels=unique_labels,
                                       target_names=[str(label) for label in unique_labels],
                                       output_dict=True, zero_division=0)

        return {
            'accuracy': accuracy,
            'f1_macro': f1_macro,
            'f1_weighted': f1_weighted,
            'f1_per_class': f1_per_class,
            'predictions': predictions,
            'confusion_matrix': cm,
            'classification_report': report,
            'figure': fig
        }

    def _evaluate_embeddings(self, embeddings, labels, title):
        """Merge similarity metrics and classification results for one embedding set.

        The classification ``accuracy`` (centroid-based) overwrites the
        similarity ``accuracy`` key in the merged dict; the nearest-neighbour
        value remains available as ``nn_accuracy``.
        """
        metrics = self.compute_similarity_metrics(embeddings, labels)
        metrics.update(self.evaluate_classification_performance(embeddings, labels, title))
        return metrics

    @staticmethod
    def _format_percent(value):
        """Format a 0-1 fraction as a percentage string with one decimal."""
        return f"{value * 100:.1f}%"

    def evaluate_dataset_with_baselines(self, dataframe, dataset_name="Dataset"):
        """Evaluate a dataset with both the custom model and the CLIP baseline.

        Prints a comparison table and saves one confusion-matrix PNG per
        model/embedding combination into ``self.directory``.

        Returns:
            Dict keyed by ``custom_text``, ``custom_image``, ``clip_text`` and
            ``clip_image``, each holding the merged metrics dict.
        """
        print(f"\n{'='*60}")
        print(f"Evaluating {dataset_name}")
        print(f"{'='*60}")

        results = {}

        # ===== CUSTOM MODEL EVALUATION =====
        print(f"\nπŸ”§ Evaluating Custom Model on {dataset_name}")
        print("-" * 40)

        custom_dataloader = self.create_dataloader(dataframe, batch_size=16)

        text_embeddings, text_labels, _ = self.extract_custom_embeddings(custom_dataloader, 'text')
        results['custom_text'] = self._evaluate_embeddings(
            text_embeddings, text_labels, "Custom Text Embeddings")

        image_embeddings, image_labels, _ = self.extract_custom_embeddings(custom_dataloader, 'image')
        results['custom_image'] = self._evaluate_embeddings(
            image_embeddings, image_labels, "Custom Image Embeddings")

        # ===== CLIP BASELINE EVALUATION =====
        print(f"\nπŸ€— Evaluating CLIP Baseline on {dataset_name}")
        print("-" * 40)

        clip_dataloader = self.create_clip_dataloader(dataframe, batch_size=8)  # Smaller batch for CLIP

        # Flatten the batches into per-sample lists for the CLIP extractor.
        all_images = []
        all_texts = []
        all_labels = []
        for images, texts, labels in tqdm(clip_dataloader, desc="Preparing data for CLIP"):
            all_images.extend(images)
            all_texts.extend(texts)
            all_labels.extend(labels)

        clip_image_embeddings, clip_text_embeddings = self.clip_evaluator.extract_clip_embeddings(all_images, all_texts)

        results['clip_text'] = self._evaluate_embeddings(
            clip_text_embeddings, all_labels, "CLIP Text Embeddings")
        results['clip_image'] = self._evaluate_embeddings(
            clip_image_embeddings, all_labels, "CLIP Image Embeddings")

        # ===== PRINT COMPARISON RESULTS =====
        print(f"\n{dataset_name} Results Comparison:")
        print(f"Dataset size: {len(dataframe)} samples")
        print("=" * 80)
        print(f"{'Model':<20} {'Embedding':<10} {'Sep Score':<10} {'NN Acc':<8} {'Centroid Acc':<12} {'F1 Macro':<10}")
        print("-" * 80)

        for model_type in ['custom', 'clip']:
            for emb_type in ['text', 'image']:
                metrics = results.get(f"{model_type}_{emb_type}")
                if metrics is None:
                    continue
                model_name = "Custom Model" if model_type == 'custom' else "CLIP Baseline"
                # Percentages are rendered before padding so the '%' sign stays
                # attached to its number; NN accuracy is read from its own key
                # because 'accuracy' holds the centroid value after the merge.
                nn_acc = self._format_percent(metrics['nn_accuracy'])
                centroid_acc = self._format_percent(metrics['centroid_accuracy'])
                f1_macro = self._format_percent(metrics['f1_macro'])
                print(f"{model_name:<20} {emb_type.capitalize():<10} {metrics['separation_score']:<10.4f} {nn_acc:<8} {centroid_acc:<12} {f1_macro:<10}")

        # ===== SAVE VISUALIZATIONS =====
        os.makedirs(f'{self.directory}', exist_ok=True)

        for key, metrics in results.items():
            if 'figure' in metrics:
                metrics['figure'].savefig(f'{self.directory}/{dataset_name.lower()}_{key}_confusion_matrix.png', dpi=300, bbox_inches='tight')
                plt.close(metrics['figure'])

        return results
535
+
536
+
537
class CLIPDataset(Dataset):
    """Dataset yielding ``(image_tensor, description, hierarchy)`` for the CLIP baseline.

    Images are taken from a local path when available, otherwise from the
    ``config.column_url_image`` column, which may hold a dict with raw bytes,
    a flat 28x28 pixel sequence, a PIL image, or a URL.
    """

    def __init__(self, dataframe):
        """Store the dataframe and build the validation transform (no augmentation)."""
        self.dataframe = dataframe
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        """Return the number of rows in the underlying dataframe."""
        return len(self.dataframe)

    @staticmethod
    def _placeholder_image():
        """Return the gray stand-in used when an image cannot be loaded."""
        return Image.new('RGB', (224, 224), color='gray')

    def _load_image(self, row, idx):
        """Resolve the RGB PIL image for ``row``, falling back to a gray placeholder."""
        local_column = config.column_local_image_path
        if local_column in row.index and pd.notna(row[local_column]):
            local_path = row[local_column]
            try:
                if os.path.exists(local_path):
                    return Image.open(local_path).convert("RGB")
                print(f"⚠️ Local image not found: {local_path}")
            except Exception as e:
                # Best effort by design: a bad file must not abort the evaluation.
                print(f"⚠️ Failed to load local image {idx}: {e}")
            return self._placeholder_image()

        # Read the image source once, always through the configured column
        # name (one branch previously hard-coded 'image_url').
        source = row[config.column_url_image]

        if isinstance(source, dict):
            return Image.open(BytesIO(source['bytes'])).convert('RGB')
        if isinstance(source, (list, np.ndarray)):
            pixels = np.array(source).reshape(28, 28)
            return Image.fromarray(pixels.astype(np.uint8)).convert("RGB")
        if isinstance(source, Image.Image):
            # PIL Image objects are used directly (for Fashion-MNIST).
            return source.convert("RGB")

        try:
            response = requests.get(source, timeout=10)
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as e:
            # Best effort by design: a failed download yields the placeholder.
            print(f"⚠️ Failed to load image {idx}: {e}")
            return self._placeholder_image()

    def __getitem__(self, idx):
        """Return ``(image_tensor, description, hierarchy)`` for the row at position ``idx``."""
        row = self.dataframe.iloc[idx]

        image_tensor = self.transform(self._load_image(row, idx))
        description = row[config.text_column]
        hierarchy = row[config.hierarchy_column]

        return image_tensor, description, hierarchy
589
+
590
+
591
def load_fashion_mnist_dataset(evaluator):
    """Load and prepare the Fashion-MNIST test dataset.

    Reads the CSV at ``config.fashion_mnist_test_path``, maps each integer
    label to one of the coarse hierarchy classes, keeps only rows whose
    hierarchy is known to the model and attaches a short text description.

    Args:
        evaluator: Object exposing ``hierarchy_classes`` (the model's classes).

    Returns:
        DataFrame with ``label``, ``pixel1``..``pixel784``, ``text`` and
        ``hierarchy`` columns, or an empty DataFrame when no row matches a
        model hierarchy.
    """
    print("Loading Fashion-MNIST test dataset...")

    df = pd.read_csv(config.fashion_mnist_test_path)
    print(f"βœ… Fashion-MNIST dataset loaded")
    print(f"πŸ“Š Total samples: {len(df)}")

    # Fashion-MNIST class labels mapping
    fashion_mnist_labels = get_fashion_mnist_labels()

    # Map class names to hierarchy classes
    hierarchy_mapping = {
        'T-shirt/top': 'top',
        'Trouser': 'bottom',
        'Pullover': 'top',
        'Dress': 'dress',
        'Coat': 'top',
        'Sandal': 'shoes',
        'Shirt': 'top',
        'Sneaker': 'shoes',
        'Bag': 'bag',
        'Ankle boot': 'shoes'
    }

    df['hierarchy'] = df['label'].map(fashion_mnist_labels).map(hierarchy_mapping)

    valid_hierarchies = df['hierarchy'].dropna().unique()
    print(f"🎯 Valid hierarchies found: {sorted(valid_hierarchies)}")
    print(f"🎯 Model hierarchies: {sorted(evaluator.hierarchy_classes)}")

    # Keep only hierarchies the model knows. The explicit copy makes the
    # column assignment below operate on an independent frame rather than on
    # a filtered slice (chained assignment).
    df = df[df['hierarchy'].isin(evaluator.hierarchy_classes)].copy()
    print(f"πŸ“Š After filtering to model hierarchies: {len(df)} samples")

    if len(df) == 0:
        print("❌ No samples left after hierarchy filtering.")
        return pd.DataFrame()

    # Create text descriptions based on hierarchy
    text_descriptions = {
        'top': 'A top clothing item',
        'bottom': 'A bottom clothing item',
        'dress': 'A dress',
        'shoes': 'A pair of shoes',
        'bag': 'A bag'
    }
    df['text'] = df['hierarchy'].map(text_descriptions)

    print(f"πŸ“ Sample data:")
    for i, (hierarchy, text) in enumerate(zip(df['hierarchy'].head(3), df['text'].head(3))):
        print(f" {i+1}. [{hierarchy}] {text}")

    # No sub-sampling is applied; every filtered row is evaluated.
    print(f"πŸ“Š After sampling: {len(df)} samples")
    print(f"πŸ“Š Samples per hierarchy:")
    for hierarchy, count in df['hierarchy'].value_counts().sort_index().items():
        print(f" {hierarchy}: {count} samples")

    # Keep all pixel columns; FashionMNISTDataset rebuilds the images from them.
    pixel_cols = [f'pixel{i}' for i in range(1, 785)]
    fashion_mnist_formatted = df[['label'] + pixel_cols + ['text', 'hierarchy']].copy()

    print(f"πŸ“Š Final dataset size: {len(fashion_mnist_formatted)} samples")
    return fashion_mnist_formatted
666
+
667
+
668
def load_kagl_marqo_dataset(evaluator):
    """Load and prepare the Marqo KAGL dataset.

    Downloads ``Marqo/KAGL`` through ``datasets.load_dataset``, derives a
    hierarchy from the lower-cased ``category2`` column, keeps only rows whose
    hierarchy is known to the model and drops rows without text or image.

    Args:
        evaluator: Object exposing ``hierarchy_classes`` (the model's classes).

    Returns:
        DataFrame with ``image_url``, ``text`` and ``hierarchy`` columns, or
        an empty DataFrame when no row matches a model hierarchy.
    """
    from datasets import load_dataset
    print("Loading kagl dataset...")

    dataset = load_dataset("Marqo/KAGL")
    df = dataset["data"].to_pandas()
    print(f"βœ… Dataset kagl loaded")
    print(f"πŸ“Š Before filtering: {len(df)} samples")
    print(f"πŸ“‹ Available columns: {list(df.columns)}")

    # dropna() first: sorting a mix of strings and NaN raises TypeError.
    print(f"🎨 Available categories: {sorted(df['category2'].dropna().unique())}")

    # Map dataset categories onto the model's hierarchy names.
    category_aliases = {
        'bags': 'bag',
        'topwear': 'top',
        'flip flops': 'shoes',
        'sandal': 'shoes',
    }
    df['hierarchy'] = df['category2'].str.lower().replace(category_aliases)

    valid_hierarchies = df['hierarchy'].dropna().unique()
    print(f"🎯 Valid hierarchies found: {sorted(valid_hierarchies)}")
    print(f"🎯 Model hierarchies: {sorted(evaluator.hierarchy_classes)}")

    # Keep only hierarchies that exist in the model.
    df = df[df['hierarchy'].isin(evaluator.hierarchy_classes)]
    print(f"πŸ“Š After filtering to model hierarchies: {len(df)} samples")

    if len(df) == 0:
        print("❌ No samples left after hierarchy filtering.")
        return pd.DataFrame()

    # Ensure we have text and image data
    df = df.dropna(subset=['text', 'image'])
    print(f"πŸ“Š After removing missing text/image: {len(df)} samples")

    print(f"πŸ“ Sample texts:")
    for i, (text, hierarchy) in enumerate(zip(df['text'].head(3), df['hierarchy'].head(3))):
        print(f" {i+1}. [{hierarchy}] {text[:100]}...")

    # The evaluation frame was previously referenced without ever being
    # assigned (NameError). No sub-sampling is applied: all rows are used.
    df_test = df.copy()

    print(f"πŸ“Š After sampling: {len(df_test)} samples")
    print(f"πŸ“Š Samples per hierarchy:")
    for hierarchy, count in df_test['hierarchy'].value_counts().sort_index().items():
        print(f" {hierarchy}: {count} samples")

    # Create formatted dataset with proper column names
    kagl_formatted = pd.DataFrame({
        'image_url': df_test['image'],
        'text': df_test['text'],
        'hierarchy': df_test['hierarchy']
    })

    print(f"πŸ“Š Final dataset size: {len(kagl_formatted)} samples")
    return kagl_formatted
723
+
724
+
725
def _print_summary_table(heading, dataset_size, results):
    """Print one dataset's comparison table (custom model vs CLIP baseline).

    Args:
        heading: Title line printed above the table.
        dataset_size: Number of evaluated samples.
        results: Dict keyed by ``custom_text`` / ``custom_image`` /
            ``clip_text`` / ``clip_image`` with per-combination metrics.
    """
    print(heading)
    print(f"Dataset size: {dataset_size} samples")
    print(f"{'Model':<20} {'Embedding':<10} {'Sep Score':<12} {'NN Acc':<10} {'Centroid Acc':<12} {'F1 Macro':<10}")
    print("-" * 80)

    for model_type in ('custom', 'clip'):
        for emb_type in ('text', 'image'):
            metrics = results.get(f"{model_type}_{emb_type}")
            if metrics is None:
                continue
            model_name = "Custom Model" if model_type == 'custom' else "CLIP Baseline"
            # Prefer a dedicated nearest-neighbour key when the metrics carry
            # one; 'accuracy' is the fallback.
            nn_accuracy = metrics.get('nn_accuracy', metrics['accuracy'])
            # Render "NN.N%" first, then pad, so the percent sign stays next to
            # its number instead of trailing the column padding.
            nn_text = f"{nn_accuracy * 100:.1f}%"
            centroid_text = f"{metrics['centroid_accuracy'] * 100:.1f}%"
            f1_text = f"{metrics['f1_macro'] * 100:.1f}%"
            print(f"{model_name:<20} {emb_type.capitalize():<10} {metrics['separation_score']:<12.4f} {nn_text:<10} {centroid_text:<12} {f1_text:<10}")


if __name__ == "__main__":
    # Stays a module-level name: evaluator code in this file may read it.
    device = config.device
    directory = config.evaluation_directory

    print(f"πŸš€ Starting evaluation with custom model: {config.hierarchy_model_path}")
    print(f"πŸ€— Including CLIP baseline comparison")

    # Positional arguments only: the constructor is declared as
    # (model_path, directory), so an extra device keyword must not be required.
    evaluator = EmbeddingEvaluator(config.hierarchy_model_path, directory)

    print(f"πŸ“Š Final hierarchy classes after initialization: {len(evaluator.vocab.hierarchy_classes)} classes")

    # Evaluate on validation dataset (same subset as during training)
    print("\n" + "="*60)
    print("EVALUATING VALIDATION DATASET - CUSTOM MODEL vs CLIP BASELINE")
    print("="*60)
    val_results = evaluator.evaluate_dataset_with_baselines(evaluator.val_df, "Validation Dataset")

    print("\n" + "="*60)
    print("EVALUATING FASHION-MNIST TEST DATASET - CUSTOM MODEL vs CLIP BASELINE")
    print("="*60)
    df_fashion_mnist = load_fashion_mnist_dataset(evaluator)
    if len(df_fashion_mnist) > 0:
        fashion_mnist_results = evaluator.evaluate_dataset_with_baselines(df_fashion_mnist, "Fashion-MNIST Test Dataset")
    else:
        fashion_mnist_results = {}

    print("\n" + "="*60)
    print("EVALUATING kagl MARQO DATASET - CUSTOM MODEL vs CLIP BASELINE")
    print("="*60)
    df_kagl_marqo = load_kagl_marqo_dataset(evaluator)
    if len(df_kagl_marqo) > 0:
        kagl_results = evaluator.evaluate_dataset_with_baselines(df_kagl_marqo, "kagl Marqo Dataset")
    else:
        kagl_results = {}

    # Compare results
    print(f"\n{'='*80}")
    print("FINAL EVALUATION SUMMARY - CUSTOM MODEL vs CLIP BASELINE")
    print(f"{'='*80}")

    _print_summary_table("\nπŸ” VALIDATION DATASET RESULTS:", len(evaluator.val_df), val_results)

    if fashion_mnist_results:
        _print_summary_table("\nπŸ‘— FASHION-MNIST TEST DATASET RESULTS:", len(df_fashion_mnist), fashion_mnist_results)

    if kagl_results:
        _print_summary_table("\n🌐 kagl MARQO DATASET RESULTS:", len(df_kagl_marqo), kagl_results)

    print(f"\nβœ… Evaluation completed! Check '{directory}/' for visualization files.")
    print(f"πŸ“Š Custom model hierarchy classes: {len(evaluator.vocab.hierarchy_classes)} classes")
    print(f"πŸ€— CLIP baseline comparison included")