Leacb4 committed on
Commit
4c2095f
Β·
verified Β·
1 Parent(s): f16c04f

Upload evaluation/hierarchy_evaluation.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. evaluation/hierarchy_evaluation.py +589 -0
evaluation/hierarchy_evaluation.py ADDED
@@ -0,0 +1,589 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hierarchy embedding evaluation for clothing category classification.
3
+ This file evaluates the quality of hierarchy embeddings generated by the hierarchy model
4
+ by calculating intra-class and inter-class similarity metrics, nearest neighbor and centroid-based
5
+ classification accuracies, and generating confusion matrices. It can be used on different datasets
6
+ (local validation, Kagl Marqo) to measure model generalization.
7
+ """
8
+
9
+ import torch
10
+ import pandas as pd
11
+ import numpy as np
12
+ import matplotlib.pyplot as plt
13
+ import seaborn as sns
14
+ from sklearn.metrics.pairwise import cosine_similarity
15
+ from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
16
+ from collections import defaultdict
17
+ import os
18
+ from tqdm import tqdm
19
+ from torch.utils.data import Dataset, DataLoader
20
+ from torchvision import transforms
21
+ from sklearn.model_selection import train_test_split
22
+ from io import BytesIO
23
+ from PIL import Image
24
+ import config
25
+ import warnings
26
+ warnings.filterwarnings('ignore')
27
+ from hierarchy_model import Model, HierarchyExtractor, HierarchyDataset, collate_fn
28
+
29
+
30
class EmbeddingEvaluator:
    """
    Evaluator for hierarchy embeddings generated by the hierarchy model.

    Computes intra-class / inter-class cosine-similarity statistics,
    nearest-neighbor and centroid-based classification accuracies, and
    confusion matrices for the text, image, and hierarchy embeddings
    produced by a trained hierarchy model.
    """

    def __init__(self, model_path, directory):
        """
        Initialize the embedding evaluator.

        Args:
            model_path: Path to the trained hierarchy model checkpoint.
            directory: Directory to save evaluation results and visualizations.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        self.device = config.device
        self.directory = directory

        # 1. Load the dataset
        CSV = config.local_dataset_path
        print(f"πŸ“ Using dataset with local images: {CSV}")
        df = pd.read_csv(CSV)

        print(f"πŸ“ Loaded {len(df)} samples")

        # 2. Get unique hierarchy classes from the dataset (informational only;
        # the classes actually used come from the checkpoint below).
        hierarchy_classes = sorted(df[config.hierarchy_column].unique().tolist())
        print(f"πŸ“‹ Found {len(hierarchy_classes)} hierarchy classes")

        # Fixed seed + stratification reproduce the training-time validation split.
        _, self.val_df = train_test_split(
            df, test_size=0.2, random_state=42, stratify=df[config.hierarchy_column]
        )

        # 3. Load the model — fail fast if the checkpoint is missing.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file {model_path} not found")

        checkpoint = torch.load(model_path, map_location=self.device)

        # Use model_config to avoid shadowing the imported config module
        model_config = checkpoint.get('config', {})
        saved_hierarchy_classes = checkpoint['hierarchy_classes']

        # Evaluate against the classes the model was trained on, not the CSV's.
        self.hierarchy_classes = saved_hierarchy_classes
        self.vocab = HierarchyExtractor(saved_hierarchy_classes)

        # Rebuild the model with the configuration stored in the checkpoint.
        self.model = Model(
            num_hierarchy_classes=len(saved_hierarchy_classes),
            embed_dim=model_config['embed_dim'],
            dropout=model_config['dropout'],
        ).to(self.device)
        self.model.load_state_dict(checkpoint['model_state'])

        print(f"βœ… Model loaded with:")
        print(f"πŸ“‹ Hierarchy classes: {len(saved_hierarchy_classes)}")
        print(f"🎯 Embed dim: {model_config['embed_dim']}")
        print(f"πŸ’§ Dropout: {model_config['dropout']}")
        print(f"πŸ“… Epoch: {checkpoint.get('epoch', 'unknown')}")

        self.model.eval()

    def create_dataloader(self, dataframe, batch_size=16):
        """
        Create a DataLoader over ``dataframe`` matching the training setup.

        Args:
            dataframe: DataFrame containing the dataset.
            batch_size: Batch size for the DataLoader.

        Returns:
            DataLoader instance (no shuffling; deterministic evaluation order).
        """
        dataset = HierarchyDataset(dataframe, image_size=224)

        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            # collate_fn needs the vocab to turn hierarchy strings into indices.
            collate_fn=lambda batch: collate_fn(batch, self.vocab),
            num_workers=0,
        )

    def extract_embeddings(self, dataloader, embedding_type='text'):
        """
        Extract embeddings from the model for a given dataloader.

        Args:
            dataloader: DataLoader yielding images, hierarchy indices, and labels.
            embedding_type: 'image' selects the image projection; any other
                value selects the text projection.

        Returns:
            Tuple of (embeddings array [N, embed_dim], labels list, texts list).
        """
        # NOTE(review): only 'image' selects z_img. Every other value —
        # including the 'category2' passed by evaluate_dataset — falls back to
        # the text embeddings z_txt, exactly as the original nested ternary
        # did. If the model exposes a dedicated hierarchy embedding, this is
        # where it should be wired in.
        output_key = 'z_img' if embedding_type == 'image' else 'z_txt'

        all_embeddings = []
        all_labels = []
        all_texts = []

        with torch.no_grad():
            for batch in tqdm(dataloader, desc=f"Extracting {embedding_type} embeddings"):
                images = batch['image'].to(self.device)
                hierarchy_indices = batch['hierarchy_indices'].to(self.device)
                hierarchy_labels = batch['hierarchy']

                # Forward pass
                out = self.model(image=images, hierarchy_indices=hierarchy_indices)
                embeddings = out[output_key]

                all_embeddings.append(embeddings.cpu().numpy())
                all_labels.extend(hierarchy_labels)
                all_texts.extend(hierarchy_labels)

        return np.vstack(all_embeddings), all_labels, all_texts

    def compute_similarity_metrics(self, embeddings, labels):
        """
        Compute intra-class and inter-class similarity metrics.

        Args:
            embeddings: Array of embeddings [N, embed_dim].
            labels: List of labels for each embedding.

        Returns:
            Dictionary containing similarity statistics, nearest-neighbor
            accuracy, centroid accuracy, and separation score.
        """
        similarities = cosine_similarity(embeddings)

        # Vectorized pair extraction: the strict upper triangle (k=1) visits
        # every unordered pair exactly once and skips self-similarities,
        # replacing the original O(N^2) Python loops with the same values.
        label_arr = np.asarray(labels)
        same_class = label_arr[:, None] == label_arr[None, :]
        upper = np.triu(np.ones(similarities.shape, dtype=bool), k=1)
        intra_class_similarities = similarities[same_class & upper].tolist()
        inter_class_similarities = similarities[~same_class & upper].tolist()

        intra_mean = np.mean(intra_class_similarities) if intra_class_similarities else 0
        inter_mean = np.mean(inter_class_similarities) if inter_class_similarities else 0
        # Separation is only meaningful when both pair sets are non-empty.
        separation = (intra_mean - inter_mean
                      if intra_class_similarities and inter_class_similarities else 0)

        # Classification accuracy using nearest neighbor in embedding space
        nn_accuracy = self.compute_embedding_accuracy(embeddings, labels, similarities)

        # Classification accuracy using per-class centroids
        centroid_accuracy = self.compute_centroid_accuracy(embeddings, labels)

        return {
            'intra_class_similarities': intra_class_similarities,
            'inter_class_similarities': inter_class_similarities,
            'intra_class_mean': intra_mean,
            'inter_class_mean': inter_mean,
            'separation_score': separation,
            'accuracy': nn_accuracy,
            'centroid_accuracy': centroid_accuracy,
        }

    def compute_embedding_accuracy(self, embeddings, labels, similarities):
        """
        Compute classification accuracy using nearest neighbor in embedding space.

        Args:
            embeddings: Array of embeddings [N, embed_dim].
            labels: List of true labels.
            similarities: Pre-computed similarity matrix [N, N].

        Returns:
            Accuracy score (float between 0 and 1).
        """
        if not labels:
            return 0

        # Exclude self-similarity so a sample can never predict itself, then
        # take the per-row argmax — identical to the original per-row loop.
        sims = similarities.copy()
        np.fill_diagonal(sims, -1)
        nearest = sims.argmax(axis=1)

        label_arr = np.asarray(labels)
        return float(np.mean(label_arr[nearest] == label_arr))

    def _compute_centroids(self, embeddings, labels):
        """Return {hierarchy: mean embedding} for every label present."""
        label_arr = np.asarray(labels)
        return {
            hierarchy: embeddings[label_arr == hierarchy].mean(axis=0)
            for hierarchy in set(labels)
        }

    def compute_centroid_accuracy(self, embeddings, labels):
        """
        Compute classification accuracy using hierarchy centroids.

        Each class is represented by its centroid (mean embedding) and each
        embedding is assigned to the most cosine-similar centroid. Note the
        centroids are built from the same data being scored, so this is an
        optimistic (in-sample) accuracy.

        Args:
            embeddings: Array of embeddings [N, embed_dim].
            labels: List of true labels.

        Returns:
            Accuracy score (float between 0 and 1).
        """
        if not labels:
            return 0

        # Reuse the shared centroid classifier instead of duplicating it.
        predictions = self.predict_hierarchy_from_embeddings(embeddings, labels)
        return float(np.mean(np.asarray(predictions) == np.asarray(labels)))

    def predict_hierarchy_from_embeddings(self, embeddings, labels):
        """
        Predict hierarchy for each embedding via centroid-based classification.

        Args:
            embeddings: Array of embeddings [N, embed_dim].
            labels: List of labels used to compute the centroids.

        Returns:
            List of predicted hierarchy labels (one per embedding).
        """
        centroids = self._compute_centroids(embeddings, labels)
        names = list(centroids.keys())
        centroid_matrix = np.stack([centroids[name] for name in names])

        # One vectorized [N, num_classes] similarity computation replaces the
        # original per-embedding, per-centroid cosine_similarity calls.
        sims = cosine_similarity(embeddings, centroid_matrix)
        return [names[i] for i in sims.argmax(axis=1)]

    def create_confusion_matrix(self, true_labels, predicted_labels, title="Confusion Matrix"):
        """
        Create and plot a confusion matrix.

        Args:
            true_labels: List of true labels.
            predicted_labels: List of predicted labels.
            title: Title for the confusion matrix plot.

        Returns:
            Tuple of (figure, accuracy, confusion_matrix).
        """
        # Union of both label sets so rows/columns cover every observed class.
        unique_labels = sorted(set(true_labels + predicted_labels))

        cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels)
        accuracy = accuracy_score(true_labels, predicted_labels)

        plt.figure(figsize=(12, 10))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=unique_labels, yticklabels=unique_labels)
        plt.title(f'{title}\nAccuracy: {accuracy:.3f} ({accuracy*100:.1f}%)')
        plt.ylabel('True Hierarchy')
        plt.xlabel('Predicted Hierarchy')
        plt.xticks(rotation=45)
        plt.yticks(rotation=0)
        plt.tight_layout()

        return plt.gcf(), accuracy, cm

    def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings"):
        """
        Evaluate centroid-based classification and create a confusion matrix.

        Args:
            embeddings: Array of embeddings [N, embed_dim].
            labels: List of true labels.
            embedding_type: Display name for plot titles.

        Returns:
            Dictionary containing accuracy, predictions, confusion matrix,
            classification report, and the matplotlib figure.
        """
        # Predict hierarchy (centroid-based, in-sample)
        predictions = self.predict_hierarchy_from_embeddings(embeddings, labels)

        accuracy = accuracy_score(labels, predictions)

        fig, acc, cm = self.create_confusion_matrix(labels, predictions,
                                                    f"{embedding_type} - Hierarchy Classification")

        # Per-class precision/recall/F1 as a dict for programmatic access.
        unique_labels = sorted(set(labels))
        report = classification_report(labels, predictions, labels=unique_labels,
                                       target_names=unique_labels, output_dict=True)

        return {
            'accuracy': accuracy,
            'predictions': predictions,
            'confusion_matrix': cm,
            'classification_report': report,
            'figure': fig,
        }

    def evaluate_dataset(self, dataframe, dataset_name="Dataset"):
        """
        Evaluate embeddings on a given dataset.

        Extracts text, image, and hierarchy embeddings, computes similarity
        metrics and classification performance for each, prints a summary,
        and saves one confusion-matrix PNG per embedding type.

        Args:
            dataframe: DataFrame containing the dataset.
            dataset_name: Name of the dataset for display and file names.

        Returns:
            Dictionary keyed by 'text' / 'image' / 'hierarchy' with the
            combined metrics for each embedding type.
        """
        print(f"\n{'='*60}")
        print(f"Evaluating {dataset_name}")
        print(f"{'='*60}")

        # Create dataloader exactly as during training
        dataloader = self.create_dataloader(dataframe, batch_size=16)

        results = {}

        # (results key, extract_embeddings type argument, display title)
        for emb_key, type_arg, title in (
            ('text', 'text', "Text Embeddings"),
            ('image', 'image', "Image Embeddings"),
            # 'category2' currently falls back to text embeddings inside
            # extract_embeddings — see the NOTE there.
            ('hierarchy', 'category2', "hierarchy Embeddings"),
        ):
            embeddings, emb_labels, _ = self.extract_embeddings(dataloader, type_arg)
            metrics = self.compute_similarity_metrics(embeddings, emb_labels)
            classification = self.evaluate_classification_performance(embeddings, emb_labels, title)
            # BUG FIX: a plain metrics.update(classification) used to clobber
            # the nearest-neighbor 'accuracy' with the centroid-based one, so
            # "Nearest Neighbor Accuracy" was silently reported wrong. The
            # classification 'accuracy' duplicates 'centroid_accuracy', so it
            # is dropped rather than merged.
            classification.pop('accuracy', None)
            metrics.update(classification)
            results[emb_key] = metrics

        # Print results
        print(f"\n{dataset_name} Results:")
        print("-" * 40)
        for emb_type, metrics in results.items():
            print(f"{emb_type.capitalize()} Embeddings:")
            print(f" Intra-class similarity (same hierarchy): {metrics['intra_class_mean']:.4f}")
            print(f" Inter-class similarity (diff hierarchy): {metrics['inter_class_mean']:.4f}")
            print(f" Separation score: {metrics['separation_score']:.4f}")
            print(f" Nearest Neighbor Accuracy: {metrics['accuracy']:.4f} ({metrics['accuracy']*100:.1f}%)")
            print(f" Centroid Accuracy: {metrics['centroid_accuracy']:.4f} ({metrics['centroid_accuracy']*100:.1f}%)")

            # Classification report summary
            report = metrics['classification_report']
            print(f" πŸ“Š Classification Performance:")
            print(f" β€’ Macro Avg F1-Score: {report['macro avg']['f1-score']:.4f}")
            print(f" β€’ Weighted Avg F1-Score: {report['weighted avg']['f1-score']:.4f}")
            print(f" β€’ Support: {report['macro avg']['support']:.0f} samples")
            print()

        # Save one confusion matrix per embedding type, then free the figure.
        os.makedirs(f'{self.directory}', exist_ok=True)
        for emb_type in ('text', 'image', 'hierarchy'):
            fig = results[emb_type]['figure']
            fig.savefig(f'{self.directory}/{dataset_name.lower()}_{emb_type}_confusion_matrix.png',
                        dpi=300, bbox_inches='tight')
            plt.close(fig)

        return results
465
+
466
class KaglDataset(Dataset):
    """Dataset wrapper yielding (image tensor, text, hierarchy) per Kagl row."""

    def __init__(self, dataframe):
        self.dataframe = dataframe
        # Validation-style preprocessing only — no augmentation.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        sample = self.dataframe.iloc[idx]

        # Decode the raw image bytes stored under 'image_url'.
        raw_image = sample['image_url']
        pil_image = Image.open(BytesIO(raw_image['bytes'])).convert("RGB")
        image_tensor = self.transform(pil_image)

        # Paired text description and hierarchy label.
        return image_tensor, sample['text'], sample['hierarchy']
492
+
493
def load_Kagl_marqo_dataset(evaluator):
    """Load and prepare Kagl KAGL dataset"""
    from datasets import load_dataset
    print("Loading Kagl KAGL dataset...")

    # Pull the 'data' split into a pandas frame.
    dataset = load_dataset("Marqo/KAGL")
    df = dataset["data"].to_pandas()
    print(f"βœ… Dataset Kagl loaded")
    print(f"πŸ“Š Before filtering: {len(df)} samples")
    print(f"πŸ“‹ Available columns: {list(df.columns)}")

    # Inspect the raw categories before mapping them onto our hierarchy.
    print(f"🎨 Available categories: {sorted(df['category2'].unique())}")

    # Normalize case, then remap dataset-specific names to model hierarchy
    # names in a single pass (equivalent to the former chained replaces).
    df['hierarchy'] = df['category2'].str.lower()
    df['hierarchy'] = df['hierarchy'].replace({
        'bags': 'bag',
        'topwear': 'top',
        'flip flops': 'shoes',
        'sandal': 'shoes',
    })

    # Report which mapped hierarchies the model actually knows about.
    valid_hierarchies = df['hierarchy'].dropna().unique()
    print(f"🎯 Valid hierarchies found: {sorted(valid_hierarchies)}")
    print(f"🎯 Model hierarchies: {sorted(evaluator.hierarchy_classes)}")

    # Keep only rows whose hierarchy exists in the trained model.
    df = df[df['hierarchy'].isin(evaluator.hierarchy_classes)]
    print(f"πŸ“Š After filtering to model hierarchies: {len(df)} samples")

    if df.empty:
        print("❌ No samples left after hierarchy filtering.")
        return pd.DataFrame()

    # Drop rows missing either modality.
    df = df.dropna(subset=['text', 'image'])
    print(f"πŸ“Š After removing missing text/image: {len(df)} samples")

    # Spot-check a few descriptions to verify text quality.
    print(f"πŸ“ Sample texts:")
    sample_pairs = zip(df['text'].head(3), df['hierarchy'].head(3))
    for position, (sample_text, sample_hierarchy) in enumerate(sample_pairs):
        print(f" {position+1}. [{sample_hierarchy}] {sample_text[:100]}...")

    print(f"πŸ“Š After sampling: {len(df)} samples")
    print(f"πŸ“Š Samples per hierarchy:")
    for name in sorted(df['hierarchy'].unique()):
        count = len(df[df['hierarchy'] == name])
        print(f" {name}: {count} samples")

    # Rename columns to the schema the evaluator expects.
    Kagl_formatted = pd.DataFrame({
        'image_url': df['image'],
        'text': df['text'],
        'hierarchy': df['hierarchy'],
    })

    print(f"πŸ“Š Final dataset size: {len(Kagl_formatted)} samples")
    return Kagl_formatted
548
+
549
+ if __name__ == "__main__":
550
+ device = config.device
551
+ model_path = config.hierarchy_model_path
552
+ directory = config.evaluation_directory
553
+
554
+ print(f"πŸš€ Starting evaluation with {model_path}")
555
+
556
+ evaluator = EmbeddingEvaluator(model_path, directory)
557
+
558
+ print(f"πŸ“Š Final hierarchy classes after initialization: {len(evaluator.vocab.hierarchy_classes)} classes")
559
+
560
+ # Evaluate on validation dataset (same subset as during training)
561
+ print("\n" + "="*60)
562
+ print("EVALUATING VALIDATION DATASET")
563
+ print("="*60)
564
+ val_results = evaluator.evaluate_dataset(evaluator.val_df, "Validation Dataset")
565
+
566
+ print("\n" + "="*60)
567
+ print("EVALUATING Kagl MARQO DATASET")
568
+ print("="*60)
569
+ df_Kagl_marqo = load_Kagl_marqo_dataset(evaluator)
570
+ Kagl_results = evaluator.evaluate_dataset(df_Kagl_marqo, "Kagl Marqo Dataset")
571
+
572
+ # Compare results
573
+ print(f"\n{'='*60}")
574
+ print("FINAL EVALUATION SUMMARY")
575
+ print(f"{'='*60}")
576
+
577
+ print("\nπŸ” VALIDATION DATASET RESULTS:")
578
+ print(f"Text - Separation: {val_results['text']['separation_score']:.4f} | NN Acc: {val_results['text']['accuracy']*100:.1f}% | Centroid Acc: {val_results['text']['centroid_accuracy']*100:.1f}%")
579
+ print(f"Image - Separation: {val_results['image']['separation_score']:.4f} | NN Acc: {val_results['image']['accuracy']*100:.1f}% | Centroid Acc: {val_results['image']['centroid_accuracy']*100:.1f}%")
580
+ print(f"hierarchy - Separation: {val_results['hierarchy']['separation_score']:.4f} | NN Acc: {val_results['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {val_results['hierarchy']['centroid_accuracy']*100:.1f}%")
581
+
582
+ print("\n🌐 Kagl MARQO DATASET RESULTS:")
583
+ print(f"Text - Separation: {Kagl_results['text']['separation_score']:.4f} | NN Acc: {Kagl_results['text']['accuracy']*100:.1f}% | Centroid Acc: {Kagl_results['text']['centroid_accuracy']*100:.1f}%")
584
+ print(f"Image - Separation: {Kagl_results['image']['separation_score']:.4f} | NN Acc: {Kagl_results['image']['accuracy']*100:.1f}% | Centroid Acc: {Kagl_results['image']['centroid_accuracy']*100:.1f}%")
585
+ print(f"Hierarchy - Separation: {Kagl_results['hierarchy']['separation_score']:.4f} | NN Acc: {Kagl_results['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {Kagl_results['hierarchy']['centroid_accuracy']*100:.1f}%")
586
+
587
+
588
+ print(f"\nβœ… Evaluation completed! Check 'improved_model_evaluation/' for visualization files.")
589
+ print(f"πŸ“Š Final hierarchy classes used: {len(evaluator.vocab.hierarchy_classes)} classes")