Leacb4 committed on
Commit
70f9f13
·
verified ·
1 Parent(s): 43824ca

Upload evaluation/hierarchy_evaluation.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. evaluation/hierarchy_evaluation.py +471 -252
evaluation/hierarchy_evaluation.py CHANGED
@@ -1,9 +1,9 @@
1
  """
2
- Hierarchy embedding evaluation for clothing category classification.
3
- This file evaluates the quality of hierarchy embeddings generated by the hierarchy model
4
- by calculating intra-class and inter-class similarity metrics, nearest neighbor and centroid-based
5
- classification accuracies, and generating confusion matrices. It can be used on different datasets
6
- (local validation, Kagl Marqo) to measure model generalization.
7
  """
8
 
9
  import torch
@@ -12,59 +12,165 @@ import numpy as np
12
  import matplotlib.pyplot as plt
13
  import seaborn as sns
14
  from sklearn.metrics.pairwise import cosine_similarity
15
- from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
16
  from collections import defaultdict
17
  import os
 
18
  from tqdm import tqdm
19
  from torch.utils.data import Dataset, DataLoader
20
  from torchvision import transforms
21
  from sklearn.model_selection import train_test_split
22
  from io import BytesIO
23
  from PIL import Image
24
- import config
25
  import warnings
26
  warnings.filterwarnings('ignore')
 
 
 
 
 
27
  from hierarchy_model import Model, HierarchyExtractor, HierarchyDataset, collate_fn
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
 
30
  class EmbeddingEvaluator:
31
- """
32
- Evaluator for hierarchy embeddings generated by the hierarchy model.
33
-
34
- This class provides methods to evaluate the quality of hierarchy embeddings by computing
35
- similarity metrics, classification accuracies, and generating visualizations.
36
- """
37
-
38
  def __init__(self, model_path, directory):
39
- """
40
- Initialize the embedding evaluator.
41
-
42
- Args:
43
- model_path: Path to the trained hierarchy model checkpoint
44
- directory: Directory to save evaluation results and visualizations
45
- """
46
- self.device = config.device
47
  self.directory = directory
48
 
49
  # 1. Load the dataset
50
- CSV = config.local_dataset_path
51
- print(f"πŸ“ Using dataset with local images: {CSV}")
52
- df = pd.read_csv(CSV)
53
 
54
  print(f"πŸ“ Loaded {len(df)} samples")
55
 
56
  # 2. Get unique hierarchy classes from the dataset
57
- hierarchy_classes = sorted(df[config.hierarchy_column].unique().tolist())
58
  print(f"πŸ“‹ Found {len(hierarchy_classes)} hierarchy classes")
59
 
60
- _, self.val_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df[config.hierarchy_column])
61
 
62
  # 3. Load the model
63
  if os.path.exists(model_path):
64
  checkpoint = torch.load(model_path, map_location=self.device)
65
 
66
- # Use model_config to avoid shadowing the imported config module
67
- model_config = checkpoint.get('config', {})
68
  saved_hierarchy_classes = checkpoint['hierarchy_classes']
69
 
70
  # Use the saved hierarchy classes
@@ -76,35 +182,34 @@ class EmbeddingEvaluator:
76
  # Create the model with the saved configuration
77
  self.model = Model(
78
  num_hierarchy_classes=len(saved_hierarchy_classes),
79
- embed_dim=model_config['embed_dim'],
80
- dropout=model_config['dropout']
81
  ).to(self.device)
82
 
83
  self.model.load_state_dict(checkpoint['model_state'])
84
 
85
- print(f"βœ… Model loaded with:")
86
  print(f"πŸ“‹ Hierarchy classes: {len(saved_hierarchy_classes)}")
87
- print(f"🎯 Embed dim: {model_config['embed_dim']}")
88
- print(f"πŸ’§ Dropout: {model_config['dropout']}")
89
  print(f"πŸ“… Epoch: {checkpoint.get('epoch', 'unknown')}")
90
 
91
  else:
92
  raise FileNotFoundError(f"Model file {model_path} not found")
93
 
94
  self.model.eval()
 
 
 
95
 
96
  def create_dataloader(self, dataframe, batch_size=16):
97
- """
98
- Create a DataLoader for the hierarchy dataset.
99
-
100
- Args:
101
- dataframe: DataFrame containing the dataset
102
- batch_size: Batch size for the DataLoader
103
-
104
- Returns:
105
- DataLoader instance
106
- """
107
- dataset = HierarchyDataset(dataframe, image_size=224)
108
 
109
  dataloader = DataLoader(
110
  dataset,
@@ -116,23 +221,32 @@ class EmbeddingEvaluator:
116
 
117
  return dataloader
118
 
119
- def extract_embeddings(self, dataloader, embedding_type='text'):
120
- """
121
- Extract embeddings from the model for a given dataloader.
 
 
 
 
 
122
 
123
- Args:
124
- dataloader: DataLoader containing images, texts, and hierarchy labels
125
- embedding_type: Type of embeddings to extract ('text' or 'image')
126
-
127
- Returns:
128
- Tuple of (embeddings array, labels list, texts list)
129
- """
 
 
 
 
130
  all_embeddings = []
131
  all_labels = []
132
  all_texts = []
133
 
134
  with torch.no_grad():
135
- for batch in tqdm(dataloader, desc=f"Extracting {embedding_type} embeddings"):
136
  images = batch['image'].to(self.device)
137
  hierarchy_indices = batch['hierarchy_indices'].to(self.device)
138
  hierarchy_labels = batch['hierarchy']
@@ -146,18 +260,9 @@ class EmbeddingEvaluator:
146
  all_texts.extend(hierarchy_labels)
147
 
148
  return np.vstack(all_embeddings), all_labels, all_texts
149
-
150
  def compute_similarity_metrics(self, embeddings, labels):
151
- """
152
- Compute intra-class and inter-class similarity metrics.
153
-
154
- Args:
155
- embeddings: Array of embeddings [N, embed_dim]
156
- labels: List of labels for each embedding
157
-
158
- Returns:
159
- Dictionary containing similarity metrics, accuracies, and separation scores
160
- """
161
  similarities = cosine_similarity(embeddings)
162
 
163
  # Group embeddings by hierarchy
@@ -174,7 +279,6 @@ class EmbeddingEvaluator:
174
  sim = similarities[indices[i], indices[j]]
175
  intra_class_similarities.append(sim)
176
 
177
-
178
  # Calculate inter-class similarities (different hierarchies)
179
  inter_class_similarities = []
180
  hierarchies = list(hierarchy_groups.keys())
@@ -205,17 +309,7 @@ class EmbeddingEvaluator:
205
  }
206
 
207
  def compute_embedding_accuracy(self, embeddings, labels, similarities):
208
- """
209
- Compute classification accuracy using nearest neighbor in embedding space.
210
-
211
- Args:
212
- embeddings: Array of embeddings [N, embed_dim]
213
- labels: List of true labels
214
- similarities: Pre-computed similarity matrix [N, N]
215
-
216
- Returns:
217
- Accuracy score (float between 0 and 1)
218
- """
219
  correct_predictions = 0
220
  total_predictions = len(labels)
221
 
@@ -234,19 +328,7 @@ class EmbeddingEvaluator:
234
  return correct_predictions / total_predictions if total_predictions > 0 else 0
235
 
236
  def compute_centroid_accuracy(self, embeddings, labels):
237
- """
238
- Compute classification accuracy using hierarchy centroids.
239
-
240
- Each hierarchy class is represented by its centroid (mean embedding), and each
241
- embedding is classified to the nearest centroid.
242
-
243
- Args:
244
- embeddings: Array of embeddings [N, embed_dim]
245
- labels: List of true labels
246
-
247
- Returns:
248
- Accuracy score (float between 0 and 1)
249
- """
250
  # Create centroids for each hierarchy
251
  unique_hierarchies = list(set(labels))
252
  centroids = {}
@@ -277,18 +359,33 @@ class EmbeddingEvaluator:
277
  correct_predictions += 1
278
 
279
  return correct_predictions / total_predictions if total_predictions > 0 else 0
280
-
281
- def predict_hierarchy_from_embeddings(self, embeddings, labels):
282
- """
283
- Predict hierarchy from embeddings using centroid-based classification.
 
284
 
285
- Args:
286
- embeddings: Array of embeddings [N, embed_dim]
287
- labels: List of labels used to compute centroids
288
-
289
- Returns:
290
- List of predicted hierarchy labels
291
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  # Create hierarchy centroids from training data
293
  unique_hierarchies = list(set(labels))
294
  centroids = {}
@@ -315,155 +412,130 @@ class EmbeddingEvaluator:
315
  predictions.append(predicted_hierarchy)
316
 
317
  return predictions
318
-
319
- def create_confusion_matrix(self, true_labels, predicted_labels, title="Confusion Matrix"):
320
- """
321
- Create and plot a confusion matrix.
322
-
323
- Args:
324
- true_labels: List of true labels
325
- predicted_labels: List of predicted labels
326
- title: Title for the confusion matrix plot
327
-
328
- Returns:
329
- Tuple of (figure, accuracy, confusion_matrix)
330
- """
331
- # Get unique labels
332
- unique_labels = sorted(list(set(true_labels + predicted_labels)))
333
-
334
- # Create confusion matrix
335
- cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels)
336
-
337
- # Calculate accuracy
338
- accuracy = accuracy_score(true_labels, predicted_labels)
339
-
340
- # Plot confusion matrix
341
- plt.figure(figsize=(12, 10))
342
- sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
343
- xticklabels=unique_labels, yticklabels=unique_labels)
344
- plt.title(f'{title}\nAccuracy: {accuracy:.3f} ({accuracy*100:.1f}%)')
345
- plt.ylabel('True Hierarchy')
346
- plt.xlabel('Predicted Hierarchy')
347
- plt.xticks(rotation=45)
348
- plt.yticks(rotation=0)
349
- plt.tight_layout()
350
-
351
- return plt.gcf(), accuracy, cm
352
-
353
  def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings"):
354
- """
355
- Evaluate classification performance and create confusion matrix.
356
-
357
- Args:
358
- embeddings: Array of embeddings [N, embed_dim]
359
- labels: List of true labels
360
- embedding_type: Type of embeddings for display purposes
361
-
362
- Returns:
363
- Dictionary containing accuracy, predictions, confusion matrix, and classification report
364
- """
365
  # Predict hierarchy
366
  predictions = self.predict_hierarchy_from_embeddings(embeddings, labels)
367
 
368
  # Calculate accuracy
369
  accuracy = accuracy_score(labels, predictions)
370
 
 
 
 
 
 
 
371
  # Create confusion matrix
372
  fig, acc, cm = self.create_confusion_matrix(labels, predictions,
373
  f"{embedding_type} - Hierarchy Classification")
374
 
375
  # Generate classification report
376
- unique_labels = sorted(list(set(labels)))
377
  report = classification_report(labels, predictions, labels=unique_labels,
378
  target_names=unique_labels, output_dict=True)
379
 
380
  return {
381
  'accuracy': accuracy,
 
 
 
382
  'predictions': predictions,
383
  'confusion_matrix': cm,
384
  'classification_report': report,
385
  'figure': fig
386
  }
387
-
388
- def evaluate_dataset(self, dataframe, dataset_name="Dataset"):
389
- """
390
- Evaluate embeddings on a given dataset.
391
-
392
- This method extracts embeddings for text and image, computes similarity metrics,
393
- evaluates classification performance, and saves confusion matrices.
394
-
395
- Args:
396
- dataframe: DataFrame containing the dataset
397
- dataset_name: Name of the dataset for display purposes
398
-
399
- Returns:
400
- Dictionary containing evaluation results for text and image embeddings
401
- """
402
  print(f"\n{'='*60}")
403
  print(f"Evaluating {dataset_name}")
404
  print(f"{'='*60}")
405
 
406
- # Create dataloader exactly as during training
407
- dataloader = self.create_dataloader(dataframe, batch_size=16)
408
-
409
  results = {}
410
 
 
 
 
 
 
 
 
411
  # Evaluate text embeddings
412
- text_embeddings, text_labels, texts = self.extract_embeddings(dataloader, 'text')
413
  text_metrics = self.compute_similarity_metrics(text_embeddings, text_labels)
414
- text_classification = self.evaluate_classification_performance(text_embeddings, text_labels, "Text Embeddings")
415
  text_metrics.update(text_classification)
416
- results['text'] = text_metrics
417
 
418
  # Evaluate image embeddings
419
- image_embeddings, image_labels, _ = self.extract_embeddings(dataloader, 'image')
420
  image_metrics = self.compute_similarity_metrics(image_embeddings, image_labels)
421
- image_classification = self.evaluate_classification_performance(image_embeddings, image_labels, "Image Embeddings")
422
  image_metrics.update(image_classification)
423
- results['image'] = image_metrics
424
-
425
- # Evaluate hierarchy embeddings
426
- hierarchy_embeddings, hierarchy_labels, _ = self.extract_embeddings(dataloader, 'category2')
427
- hierarchy_metrics = self.compute_similarity_metrics(hierarchy_embeddings, hierarchy_labels)
428
- hierarchy_classification = self.evaluate_classification_performance(hierarchy_embeddings, hierarchy_labels, "hierarchy Embeddings")
429
- hierarchy_metrics.update(hierarchy_classification)
430
- results['hierarchy'] = hierarchy_metrics
431
 
432
- # Print results
433
- print(f"\n{dataset_name} Results:")
434
  print("-" * 40)
435
- for emb_type, metrics in results.items():
436
- print(f"{emb_type.capitalize()} Embeddings:")
437
- print(f" Intra-class similarity (same hierarchy): {metrics['intra_class_mean']:.4f}")
438
- print(f" Inter-class similarity (diff hierarchy): {metrics['inter_class_mean']:.4f}")
439
- print(f" Separation score: {metrics['separation_score']:.4f}")
440
- print(f" Nearest Neighbor Accuracy: {metrics['accuracy']:.4f} ({metrics['accuracy']*100:.1f}%)")
441
- print(f" Centroid Accuracy: {metrics['centroid_accuracy']:.4f} ({metrics['centroid_accuracy']*100:.1f}%)")
442
-
443
- # Classification report summary
444
- report = metrics['classification_report']
445
- print(f" πŸ“Š Classification Performance:")
446
- print(f" β€’ Macro Avg F1-Score: {report['macro avg']['f1-score']:.4f}")
447
- print(f" β€’ Weighted Avg F1-Score: {report['weighted avg']['f1-score']:.4f}")
448
- print(f" β€’ Support: {report['macro avg']['support']:.0f} samples")
449
- print()
450
-
451
- # Create visualizations
452
- os.makedirs(f'{self.directory}', exist_ok=True)
453
 
454
- # Confusion matrices
455
- results['text']['figure'].savefig(f'{self.directory}/{dataset_name.lower()}_text_confusion_matrix.png', dpi=300, bbox_inches='tight')
456
- plt.close(results['text']['figure'])
457
 
458
- results['image']['figure'].savefig(f'{self.directory}/{dataset_name.lower()}_image_confusion_matrix.png', dpi=300, bbox_inches='tight')
459
- plt.close(results['image']['figure'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
460
 
461
- results['hierarchy']['figure'].savefig(f'{self.directory}/{dataset_name.lower()}_hierarchy_confusion_matrix.png', dpi=300, bbox_inches='tight')
462
- plt.close(results['hierarchy']['figure'])
 
 
 
463
 
464
  return results
465
 
466
- class KaglDataset(Dataset):
 
467
  def __init__(self, dataframe):
468
  self.dataframe = dataframe
469
  # Use VALIDATION transforms (no augmentation)
@@ -479,26 +551,130 @@ class KaglDataset(Dataset):
479
  def __getitem__(self, idx):
480
  row = self.dataframe.iloc[idx]
481
 
482
- # Handle image
483
- image_data = row['image_url']
484
- image = Image.open(BytesIO(image_data['bytes'])).convert("RGB")
485
- image = self.transform(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
486
 
487
- # Get text and hierarchy
488
- description = row['text']
489
- hierarchy = row['hierarchy']
 
490
 
491
- return image, description, hierarchy
492
 
493
- def load_Kagl_marqo_dataset(evaluator):
494
- """Load and prepare Kagl KAGL dataset"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
495
  from datasets import load_dataset
496
- print("Loading Kagl KAGL dataset...")
497
 
498
  # Load the dataset
499
  dataset = load_dataset("Marqo/KAGL")
500
  df = dataset["data"].to_pandas()
501
- print(f"βœ… Dataset Kagl loaded")
502
  print(f"πŸ“Š Before filtering: {len(df)} samples")
503
  print(f"πŸ“‹ Available columns: {list(df.columns)}")
504
 
@@ -530,60 +706,103 @@ def load_Kagl_marqo_dataset(evaluator):
530
  for i, (text, hierarchy) in enumerate(zip(df['text'].head(3), df['hierarchy'].head(3))):
531
  print(f" {i+1}. [{hierarchy}] {text[:100]}...")
532
 
533
- print(f"πŸ“Š After sampling: {len(df)} samples")
534
  print(f"πŸ“Š Samples per hierarchy:")
535
- for hierarchy in sorted(df['hierarchy'].unique()):
536
- count = len(df[df['hierarchy'] == hierarchy])
537
  print(f" {hierarchy}: {count} samples")
538
 
539
  # Create formatted dataset with proper column names
540
- Kagl_formatted = pd.DataFrame({
541
- 'image_url': df['image'],
542
- 'text': df['text'],
543
- 'hierarchy': df['hierarchy']
544
  })
545
 
546
- print(f"πŸ“Š Final dataset size: {len(Kagl_formatted)} samples")
547
- return Kagl_formatted
 
548
 
549
  if __name__ == "__main__":
550
- device = config.device
551
- model_path = config.hierarchy_model_path
552
- directory = config.evaluation_directory
553
 
554
- print(f"πŸš€ Starting evaluation with {model_path}")
 
555
 
556
- evaluator = EmbeddingEvaluator(model_path, directory)
557
 
558
  print(f"πŸ“Š Final hierarchy classes after initialization: {len(evaluator.vocab.hierarchy_classes)} classes")
559
 
560
  # Evaluate on validation dataset (same subset as during training)
561
  print("\n" + "="*60)
562
- print("EVALUATING VALIDATION DATASET")
563
  print("="*60)
564
- val_results = evaluator.evaluate_dataset(evaluator.val_df, "Validation Dataset")
565
 
566
  print("\n" + "="*60)
567
- print("EVALUATING Kagl MARQO DATASET")
568
  print("="*60)
569
- df_Kagl_marqo = load_Kagl_marqo_dataset(evaluator)
570
- Kagl_results = evaluator.evaluate_dataset(df_Kagl_marqo, "Kagl Marqo Dataset")
 
 
 
 
 
 
 
 
 
 
 
 
571
 
572
  # Compare results
573
- print(f"\n{'='*60}")
574
- print("FINAL EVALUATION SUMMARY")
575
- print(f"{'='*60}")
576
 
577
  print("\nπŸ” VALIDATION DATASET RESULTS:")
578
- print(f"Text - Separation: {val_results['text']['separation_score']:.4f} | NN Acc: {val_results['text']['accuracy']*100:.1f}% | Centroid Acc: {val_results['text']['centroid_accuracy']*100:.1f}%")
579
- print(f"Image - Separation: {val_results['image']['separation_score']:.4f} | NN Acc: {val_results['image']['accuracy']*100:.1f}% | Centroid Acc: {val_results['image']['centroid_accuracy']*100:.1f}%")
580
- print(f"hierarchy - Separation: {val_results['hierarchy']['separation_score']:.4f} | NN Acc: {val_results['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {val_results['hierarchy']['centroid_accuracy']*100:.1f}%")
581
 
582
- print("\n🌐 Kagl MARQO DATASET RESULTS:")
583
- print(f"Text - Separation: {Kagl_results['text']['separation_score']:.4f} | NN Acc: {Kagl_results['text']['accuracy']*100:.1f}% | Centroid Acc: {Kagl_results['text']['centroid_accuracy']*100:.1f}%")
584
- print(f"Image - Separation: {Kagl_results['image']['separation_score']:.4f} | NN Acc: {Kagl_results['image']['accuracy']*100:.1f}% | Centroid Acc: {Kagl_results['image']['centroid_accuracy']*100:.1f}%")
585
- print(f"Hierarchy - Separation: {Kagl_results['hierarchy']['separation_score']:.4f} | NN Acc: {Kagl_results['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {Kagl_results['hierarchy']['centroid_accuracy']*100:.1f}%")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
586
 
 
 
 
 
 
 
 
 
 
 
 
 
 
587
 
588
- print(f"\nβœ… Evaluation completed! Check 'improved_model_evaluation/' for visualization files.")
589
- print(f"πŸ“Š Final hierarchy classes used: {len(evaluator.vocab.hierarchy_classes)} classes")
 
 
1
  """
2
+ Hierarchy embedding evaluation with CLIP baseline comparison.
3
+ This file evaluates the quality of hierarchy embeddings from the custom model and compares them
4
+ with a CLIP baseline model (Fashion-CLIP by patrickjohncyh). It calculates similarity metrics,
5
+ classification accuracies, and generates confusion matrices for both models to measure relative
6
+ performance. It also supports evaluation on Fashion-MNIST and kagl Marqo datasets.
7
  """
8
 
9
  import torch
 
12
  import matplotlib.pyplot as plt
13
  import seaborn as sns
14
  from sklearn.metrics.pairwise import cosine_similarity
15
+ from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, f1_score
16
  from collections import defaultdict
17
  import os
18
+ import requests
19
  from tqdm import tqdm
20
  from torch.utils.data import Dataset, DataLoader
21
  from torchvision import transforms
22
  from sklearn.model_selection import train_test_split
23
  from io import BytesIO
24
  from PIL import Image
25
+ from config import device, hierarchy_model_path, hierarchy_column, local_dataset_path
26
  import warnings
27
  warnings.filterwarnings('ignore')
28
+
29
+ # Import transformers CLIP
30
+ from transformers import CLIPProcessor, CLIPModel as TransformersCLIPModel
31
+
32
+ # Import your custom model
33
  from hierarchy_model import Model, HierarchyExtractor, HierarchyDataset, collate_fn
34
+ import config
35
+
36
+ def convert_fashion_mnist_to_image(pixel_values):
37
+ """Convert Fashion-MNIST pixel values to PIL image"""
38
+ # Reshape to 28x28 and convert to PIL Image
39
+ image_array = np.array(pixel_values).reshape(28, 28).astype(np.uint8)
40
+ # Convert to RGB by duplicating the grayscale channel
41
+ image_array = np.stack([image_array] * 3, axis=-1)
42
+ image = Image.fromarray(image_array)
43
+ return image
44
+
45
+ def get_fashion_mnist_labels():
46
+ """Get Fashion-MNIST class labels"""
47
+ return {
48
+ 0: "T-shirt/top",
49
+ 1: "Trouser",
50
+ 2: "Pullover",
51
+ 3: "Dress",
52
+ 4: "Coat",
53
+ 5: "Sandal",
54
+ 6: "Shirt",
55
+ 7: "Sneaker",
56
+ 8: "Bag",
57
+ 9: "Ankle boot"
58
+ }
59
+
60
+ class FashionMNISTDataset(Dataset):
61
+ def __init__(self, dataframe, image_size=224):
62
+ self.dataframe = dataframe
63
+ self.image_size = image_size
64
+ self.labels_map = get_fashion_mnist_labels()
65
+
66
+ # Simple transforms for validation/inference
67
+ self.transform = transforms.Compose([
68
+ transforms.Resize((image_size, image_size)),
69
+ transforms.ToTensor(),
70
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
71
+ ])
72
+
73
+ def __len__(self):
74
+ return len(self.dataframe)
75
+
76
+ def __getitem__(self, idx):
77
+ row = self.dataframe.iloc[idx]
78
+
79
+ # Get pixel values (columns 1-784)
80
+ pixel_cols = [f'pixel{i}' for i in range(1, 785)]
81
+ pixel_values = row[pixel_cols].values
82
+
83
+ # Convert to image
84
+ image = convert_fashion_mnist_to_image(pixel_values)
85
+ image = self.transform(image)
86
+
87
+ # Get text description
88
+ text = row['text']
89
+
90
+ # Get hierarchy label
91
+ hierarchy = row['hierarchy']
92
+
93
+ return image, text, hierarchy
94
+
95
+ class CLIPBaselineEvaluator:
96
+ def __init__(self, device='mps'):
97
+ self.device = torch.device(device)
98
+
99
+ # Load Fashion-CLIP model and processor
100
+ print("πŸ€— Loading Fashion-CLIP baseline model from transformers...")
101
+ patrick_model_name = "patrickjohncyh/fashion-clip"
102
+ self.clip_model = TransformersCLIPModel.from_pretrained(patrick_model_name).to(self.device)
103
+ self.clip_processor = CLIPProcessor.from_pretrained(patrick_model_name)
104
+
105
+ self.clip_model.eval()
106
+ print("βœ… Fashion-CLIP model loaded successfully")
107
+
108
+ def extract_clip_embeddings(self, images, texts):
109
+ """Extract Fashion-CLIP embeddings for images and texts"""
110
+ all_image_embeddings = []
111
+ all_text_embeddings = []
112
+
113
+ with torch.no_grad():
114
+ for i in tqdm(range(len(images)), desc="Extracting CLIP embeddings"):
115
+ # Process image
116
+ if isinstance(images[i], torch.Tensor):
117
+ # Convert tensor back to PIL Image
118
+ image_tensor = images[i]
119
+ if image_tensor.dim() == 3:
120
+ # Denormalize
121
+ mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
122
+ std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
123
+ image_tensor = image_tensor * std + mean
124
+ image_tensor = torch.clamp(image_tensor, 0, 1)
125
+
126
+ # Convert to PIL
127
+ image_pil = transforms.ToPILImage()(image_tensor)
128
+ elif isinstance(images[i], Image.Image):
129
+ image_pil = images[i]
130
+ else:
131
+ raise ValueError(f"Unsupported image type: {type(images[i])}")
132
+
133
+ # Process with Fashion-CLIP
134
+ inputs = self.clip_processor(
135
+ text=texts[i],
136
+ images=image_pil,
137
+ return_tensors="pt",
138
+ padding=True
139
+ ).to(self.device)
140
+
141
+ outputs = self.clip_model(**inputs)
142
+
143
+ # Get normalized embeddings
144
+ image_emb = outputs.image_embeds / outputs.image_embeds.norm(p=2, dim=-1, keepdim=True)
145
+ text_emb = outputs.text_embeds / outputs.text_embeds.norm(p=2, dim=-1, keepdim=True)
146
+
147
+ all_image_embeddings.append(image_emb.cpu().numpy())
148
+ all_text_embeddings.append(text_emb.cpu().numpy())
149
+
150
+ return np.vstack(all_image_embeddings), np.vstack(all_text_embeddings)
151
 
152
 
153
  class EmbeddingEvaluator:
 
 
 
 
 
 
 
154
  def __init__(self, model_path, directory):
 
 
 
 
 
 
 
 
155
  self.directory = directory
156
 
157
  # 1. Load the dataset
158
+ print(f"πŸ“ Using dataset with local images: {local_dataset_path}")
159
+ df = pd.read_csv(local_dataset_path)
 
160
 
161
  print(f"πŸ“ Loaded {len(df)} samples")
162
 
163
  # 2. Get unique hierarchy classes from the dataset
164
+ hierarchy_classes = sorted(df[hierarchy_column].unique().tolist())
165
  print(f"πŸ“‹ Found {len(hierarchy_classes)} hierarchy classes")
166
 
167
+ _, self.val_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df['hierarchy'])
168
 
169
  # 3. Load the model
170
  if os.path.exists(model_path):
171
  checkpoint = torch.load(model_path, map_location=self.device)
172
 
173
+ config = checkpoint.get('config', {})
 
174
  saved_hierarchy_classes = checkpoint['hierarchy_classes']
175
 
176
  # Use the saved hierarchy classes
 
182
  # Create the model with the saved configuration
183
  self.model = Model(
184
  num_hierarchy_classes=len(saved_hierarchy_classes),
185
+ embed_dim=config['embed_dim'],
186
+ dropout=config['dropout']
187
  ).to(self.device)
188
 
189
  self.model.load_state_dict(checkpoint['model_state'])
190
 
191
+ print(f"βœ… Custom model loaded with:")
192
  print(f"πŸ“‹ Hierarchy classes: {len(saved_hierarchy_classes)}")
193
+ print(f"🎯 Embed dim: {config['embed_dim']}")
194
+ print(f"πŸ’§ Dropout: {config['dropout']}")
195
  print(f"πŸ“… Epoch: {checkpoint.get('epoch', 'unknown')}")
196
 
197
  else:
198
  raise FileNotFoundError(f"Model file {model_path} not found")
199
 
200
  self.model.eval()
201
+
202
+ # Initialize Fashion-CLIP baseline
203
+ self.clip_evaluator = CLIPBaselineEvaluator(device)
204
 
205
  def create_dataloader(self, dataframe, batch_size=16):
206
+ """Create a dataloader for custom model"""
207
+ # Check if this is Fashion-MNIST data (has pixel1 column)
208
+ if 'pixel1' in dataframe.columns:
209
+ print("πŸ” Detected Fashion-MNIST data, using FashionMNISTDataset")
210
+ dataset = FashionMNISTDataset(dataframe, image_size=224)
211
+ else:
212
+ dataset = HierarchyDataset(dataframe, image_size=224)
 
 
 
 
213
 
214
  dataloader = DataLoader(
215
  dataset,
 
221
 
222
  return dataloader
223
 
224
+ def create_clip_dataloader(self, dataframe, batch_size=16):
225
+ """Create a dataloader for Fashion-CLIP baseline"""
226
+ # Check if this is Fashion-MNIST data (has pixel1 column)
227
+ if 'pixel1' in dataframe.columns:
228
+ print("πŸ” Detected Fashion-MNIST data for Fashion-CLIP, using FashionMNISTDataset")
229
+ dataset = FashionMNISTDataset(dataframe, image_size=224)
230
+ else:
231
+ dataset = CLIPDataset(dataframe)
232
 
233
+ dataloader = DataLoader(
234
+ dataset,
235
+ batch_size=batch_size,
236
+ shuffle=False,
237
+ num_workers=0
238
+ )
239
+
240
+ return dataloader
241
+
242
+ def extract_custom_embeddings(self, dataloader, embedding_type='text'):
243
+ """Extract embeddings from custom model"""
244
  all_embeddings = []
245
  all_labels = []
246
  all_texts = []
247
 
248
  with torch.no_grad():
249
+ for batch in tqdm(dataloader, desc=f"Extracting custom {embedding_type} embeddings"):
250
  images = batch['image'].to(self.device)
251
  hierarchy_indices = batch['hierarchy_indices'].to(self.device)
252
  hierarchy_labels = batch['hierarchy']
 
260
  all_texts.extend(hierarchy_labels)
261
 
262
  return np.vstack(all_embeddings), all_labels, all_texts
263
+
264
  def compute_similarity_metrics(self, embeddings, labels):
265
+ """Compute intra-class and inter-class similarities"""
 
 
 
 
 
 
 
 
 
266
  similarities = cosine_similarity(embeddings)
267
 
268
  # Group embeddings by hierarchy
 
279
  sim = similarities[indices[i], indices[j]]
280
  intra_class_similarities.append(sim)
281
 
 
282
  # Calculate inter-class similarities (different hierarchies)
283
  inter_class_similarities = []
284
  hierarchies = list(hierarchy_groups.keys())
 
309
  }
310
 
311
  def compute_embedding_accuracy(self, embeddings, labels, similarities):
312
+ """Compute classification accuracy using nearest neighbor in embedding space"""
 
 
 
 
 
 
 
 
 
 
313
  correct_predictions = 0
314
  total_predictions = len(labels)
315
 
 
328
  return correct_predictions / total_predictions if total_predictions > 0 else 0
329
 
330
  def compute_centroid_accuracy(self, embeddings, labels):
331
+ """Compute classification accuracy using hierarchy centroids"""
 
 
 
 
 
 
 
 
 
 
 
 
332
  # Create centroids for each hierarchy
333
  unique_hierarchies = list(set(labels))
334
  centroids = {}
 
359
  correct_predictions += 1
360
 
361
  return correct_predictions / total_predictions if total_predictions > 0 else 0
362
+
363
+ def create_confusion_matrix(self, true_labels, predicted_labels, title="Confusion Matrix"):
364
+ """Create and plot confusion matrix"""
365
+ # Get unique labels
366
+ unique_labels = sorted(list(set(true_labels + predicted_labels)))
367
 
368
+ # Create confusion matrix
369
+ cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels)
370
+
371
+ # Calculate accuracy
372
+ accuracy = accuracy_score(true_labels, predicted_labels)
373
+
374
+ # Plot confusion matrix
375
+ plt.figure(figsize=(12, 10))
376
+ sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
377
+ xticklabels=unique_labels, yticklabels=unique_labels)
378
+ plt.title(f'{title}\nAccuracy: {accuracy:.3f} ({accuracy*100:.1f}%)')
379
+ plt.ylabel('True Hierarchy')
380
+ plt.xlabel('Predicted Hierarchy')
381
+ plt.xticks(rotation=45)
382
+ plt.yticks(rotation=0)
383
+ plt.tight_layout()
384
+
385
+ return plt.gcf(), accuracy, cm
386
+
387
+ def predict_hierarchy_from_embeddings(self, embeddings, labels):
388
+ """Predict hierarchy from embeddings using centroid-based classification"""
389
  # Create hierarchy centroids from training data
390
  unique_hierarchies = list(set(labels))
391
  centroids = {}
 
412
  predictions.append(predicted_hierarchy)
413
 
414
  return predictions
415
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
416
  def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings"):
417
+ """Evaluate classification performance and create confusion matrix"""
 
 
 
 
 
 
 
 
 
 
418
  # Predict hierarchy
419
  predictions = self.predict_hierarchy_from_embeddings(embeddings, labels)
420
 
421
  # Calculate accuracy
422
  accuracy = accuracy_score(labels, predictions)
423
 
424
+ # Calculate F1 scores
425
+ unique_labels = sorted(list(set(labels)))
426
+ f1_macro = f1_score(labels, predictions, labels=unique_labels, average='macro', zero_division=0)
427
+ f1_weighted = f1_score(labels, predictions, labels=unique_labels, average='weighted', zero_division=0)
428
+ f1_per_class = f1_score(labels, predictions, labels=unique_labels, average=None, zero_division=0)
429
+
430
  # Create confusion matrix
431
  fig, acc, cm = self.create_confusion_matrix(labels, predictions,
432
  f"{embedding_type} - Hierarchy Classification")
433
 
434
  # Generate classification report
 
435
  report = classification_report(labels, predictions, labels=unique_labels,
436
  target_names=unique_labels, output_dict=True)
437
 
438
  return {
439
  'accuracy': accuracy,
440
+ 'f1_macro': f1_macro,
441
+ 'f1_weighted': f1_weighted,
442
+ 'f1_per_class': f1_per_class,
443
  'predictions': predictions,
444
  'confusion_matrix': cm,
445
  'classification_report': report,
446
  'figure': fig
447
  }
448
+
449
+ def evaluate_dataset_with_baselines(self, dataframe, dataset_name="Dataset"):
450
+ """Evaluate embeddings on a given dataset with both custom model and CLIP baseline"""
 
 
 
 
 
 
 
 
 
 
 
 
451
  print(f"\n{'='*60}")
452
  print(f"Evaluating {dataset_name}")
453
  print(f"{'='*60}")
454
 
 
 
 
455
  results = {}
456
 
457
+ # ===== CUSTOM MODEL EVALUATION =====
458
+ print(f"\nπŸ”§ Evaluating Custom Model on {dataset_name}")
459
+ print("-" * 40)
460
+
461
+ # Create dataloader for custom model
462
+ custom_dataloader = self.create_dataloader(dataframe, batch_size=16)
463
+
464
  # Evaluate text embeddings
465
+ text_embeddings, text_labels, texts = self.extract_custom_embeddings(custom_dataloader, 'text')
466
  text_metrics = self.compute_similarity_metrics(text_embeddings, text_labels)
467
+ text_classification = self.evaluate_classification_performance(text_embeddings, text_labels, "Custom Text Embeddings")
468
  text_metrics.update(text_classification)
469
+ results['custom_text'] = text_metrics
470
 
471
  # Evaluate image embeddings
472
+ image_embeddings, image_labels, _ = self.extract_custom_embeddings(custom_dataloader, 'image')
473
  image_metrics = self.compute_similarity_metrics(image_embeddings, image_labels)
474
+ image_classification = self.evaluate_classification_performance(image_embeddings, image_labels, "Custom Image Embeddings")
475
  image_metrics.update(image_classification)
476
+ results['custom_image'] = image_metrics
 
 
 
 
 
 
 
477
 
478
+ # ===== FASHION-CLIP BASELINE EVALUATION =====
479
+ print(f"\nπŸ€— Evaluating Fashion-CLIP Baseline on {dataset_name}")
480
  print("-" * 40)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
481
 
482
+ # Create dataloader for Fashion-CLIP
483
+ clip_dataloader = self.create_clip_dataloader(dataframe, batch_size=8) # Smaller batch for Fashion-CLIP
 
484
 
485
+ # Extract data for Fashion-CLIP
486
+ all_images = []
487
+ all_texts = []
488
+ all_labels = []
489
+
490
+ for batch in tqdm(clip_dataloader, desc="Preparing data for Fashion-CLIP"):
491
+ images, texts, labels = batch
492
+ all_images.extend(images)
493
+ all_texts.extend(texts)
494
+ all_labels.extend(labels)
495
+
496
+ # Get Fashion-CLIP embeddings
497
+ clip_image_embeddings, clip_text_embeddings = self.clip_evaluator.extract_clip_embeddings(all_images, all_texts)
498
+
499
+ # Evaluate Fashion-CLIP text embeddings
500
+ clip_text_metrics = self.compute_similarity_metrics(clip_text_embeddings, all_labels)
501
+ clip_text_classification = self.evaluate_classification_performance(clip_text_embeddings, all_labels, "Fashion-CLIP Text Embeddings")
502
+ clip_text_metrics.update(clip_text_classification)
503
+ results['clip_text'] = clip_text_metrics
504
+
505
+ # Evaluate Fashion-CLIP image embeddings
506
+ clip_image_metrics = self.compute_similarity_metrics(clip_image_embeddings, all_labels)
507
+ clip_image_classification = self.evaluate_classification_performance(clip_image_embeddings, all_labels, "Fashion-CLIP Image Embeddings")
508
+ clip_image_metrics.update(clip_image_classification)
509
+ results['clip_image'] = clip_image_metrics
510
+
511
+ # ===== PRINT COMPARISON RESULTS =====
512
+ print(f"\n{dataset_name} Results Comparison:")
513
+ print(f"Dataset size: {len(dataframe)} samples")
514
+ print("=" * 80)
515
+ print(f"{'Model':<20} {'Embedding':<10} {'Sep Score':<10} {'NN Acc':<8} {'Centroid Acc':<12} {'F1 Macro':<10}")
516
+ print("-" * 80)
517
+
518
+ for model_type in ['custom', 'clip']:
519
+ for emb_type in ['text', 'image']:
520
+ key = f"{model_type}_{emb_type}"
521
+ if key in results:
522
+ metrics = results[key]
523
+ model_name = "Custom Model" if model_type == 'custom' else "Fashion-CLIP Baseline"
524
+ print(f"{model_name:<20} {emb_type.capitalize():<10} {metrics['separation_score']:<10.4f} {metrics['accuracy']*100:<8.1f}% {metrics['centroid_accuracy']*100:<12.1f}% {metrics['f1_macro']*100:<10.1f}%")
525
+
526
+ # ===== SAVE VISUALIZATIONS =====
527
+ os.makedirs(f'{self.directory}', exist_ok=True)
528
 
529
+ # Save confusion matrices
530
+ for key, metrics in results.items():
531
+ if 'figure' in metrics:
532
+ metrics['figure'].savefig(f'{self.directory}/{dataset_name.lower()}_{key}_confusion_matrix.png', dpi=300, bbox_inches='tight')
533
+ plt.close(metrics['figure'])
534
 
535
  return results
536
 
537
+
538
+ class CLIPDataset(Dataset):
539
  def __init__(self, dataframe):
540
  self.dataframe = dataframe
541
  # Use VALIDATION transforms (no augmentation)
 
551
  def __getitem__(self, idx):
552
  row = self.dataframe.iloc[idx]
553
 
554
+ # Handle image loading (same as HierarchyDataset)
555
+ if config.column_local_image_path in row.index and pd.notna(row[config.column_local_image_path]):
556
+ local_path = row[config.column_local_image_path]
557
+ try:
558
+ if os.path.exists(local_path):
559
+ image = Image.open(local_path).convert("RGB")
560
+ else:
561
+ print(f"⚠️ Local image not found: {local_path}")
562
+ image = Image.new('RGB', (224, 224), color='gray')
563
+ except Exception as e:
564
+ print(f"⚠️ Failed to load local image {idx}: {e}")
565
+ image = Image.new('RGB', (224, 224), color='gray')
566
+ elif isinstance(row[config.column_url_image], dict):
567
+ image = Image.open(BytesIO(row[config.column_url_image]['bytes'])).convert('RGB')
568
+ elif isinstance(row['image_url'], (list, np.ndarray)):
569
+ pixels = np.array(row[config.column_url_image]).reshape(28, 28)
570
+ image = Image.fromarray(pixels.astype(np.uint8)).convert("RGB")
571
+ elif isinstance(row[config.column_url_image], Image.Image):
572
+ # Handle PIL Image objects directly (for Fashion-MNIST)
573
+ image = row[config.column_url_image].convert("RGB")
574
+ else:
575
+ try:
576
+ response = requests.get(row[config.column_url_image], timeout=10)
577
+ response.raise_for_status()
578
+ image = Image.open(BytesIO(response.content)).convert("RGB")
579
+ except Exception as e:
580
+ print(f"⚠️ Failed to load image {idx}: {e}")
581
+ image = Image.new('RGB', (224, 224), color='gray')
582
+
583
+ # Apply transforms
584
+ image_tensor = self.transform(image)
585
 
586
+ description = row[config.text_column]
587
+ hierarchy = row[config.hierarchy_column]
588
+
589
+ return image_tensor, description, hierarchy
590
 
 
591
 
592
def load_fashion_mnist_dataset(evaluator):
    """Load the Fashion-MNIST test CSV and adapt it to the hierarchy schema.

    Maps the 10 Fashion-MNIST classes onto the evaluator's hierarchy classes
    (top/bottom/dress/shoes/bag), filters out rows whose hierarchy the model
    does not know, and attaches a generic per-hierarchy text description.

    Args:
        evaluator: EmbeddingEvaluator whose ``hierarchy_classes`` define the
            set of valid hierarchy labels.

    Returns:
        DataFrame with columns ['label', 'pixel1'..'pixel784', 'text',
        'hierarchy'], or an empty DataFrame if no rows survive filtering.
    """
    print("Loading Fashion-MNIST test dataset...")

    df = pd.read_csv(config.fashion_mnist_test_path)
    print("βœ… Fashion-MNIST dataset loaded")
    print(f"πŸ“Š Total samples: {len(df)}")

    # Integer label -> class-name mapping (helper defined elsewhere in this file)
    fashion_mnist_labels = get_fashion_mnist_labels()

    # Fashion-MNIST class name -> model hierarchy class
    hierarchy_mapping = {
        'T-shirt/top': 'top',
        'Trouser': 'bottom',
        'Pullover': 'top',
        'Dress': 'dress',
        'Coat': 'top',
        'Sandal': 'shoes',
        'Shirt': 'top',
        'Sneaker': 'shoes',
        'Bag': 'bag',
        'Ankle boot': 'shoes',
    }

    # Two-step map: numeric label -> class name -> hierarchy
    df['hierarchy'] = df['label'].map(fashion_mnist_labels).map(hierarchy_mapping)

    valid_hierarchies = df['hierarchy'].dropna().unique()
    print(f"🎯 Valid hierarchies found: {sorted(valid_hierarchies)}")
    print(f"🎯 Model hierarchies: {sorted(evaluator.hierarchy_classes)}")

    # Keep only hierarchies the model was trained on
    df = df[df['hierarchy'].isin(evaluator.hierarchy_classes)]
    print(f"πŸ“Š After filtering to model hierarchies: {len(df)} samples")

    if len(df) == 0:
        print("❌ No samples left after hierarchy filtering.")
        return pd.DataFrame()

    # Fashion-MNIST has no captions: synthesize a generic one per hierarchy
    text_descriptions = {
        'top': 'A top clothing item',
        'bottom': 'A bottom clothing item',
        'dress': 'A dress',
        'shoes': 'A pair of shoes',
        'bag': 'A bag',
    }
    df['text'] = df['hierarchy'].map(text_descriptions)

    # Show sample of data
    print("πŸ“ Sample data:")
    for i, (hierarchy, text) in enumerate(zip(df['hierarchy'].head(3), df['text'].head(3))):
        print(f" {i+1}. [{hierarchy}] {text}")

    df_test = df.copy()

    print(f"πŸ“Š After sampling: {len(df_test)} samples")
    print("πŸ“Š Samples per hierarchy:")
    for hierarchy in sorted(df_test['hierarchy'].unique()):
        count = len(df_test[df_test['hierarchy'] == hierarchy])
        print(f" {hierarchy}: {count} samples")

    # Keep label + raw pixel columns (consumed downstream by the
    # Fashion-MNIST dataset class) plus the derived text/hierarchy columns
    pixel_cols = [f'pixel{i}' for i in range(1, 785)]
    fashion_mnist_formatted = df_test[['label'] + pixel_cols + ['text', 'hierarchy']].copy()

    print(f"πŸ“Š Final dataset size: {len(fashion_mnist_formatted)} samples")
    return fashion_mnist_formatted
667
+
668
+
669
+ def load_kagl_marqo_dataset(evaluator):
670
+ """Load and prepare kagl dataset"""
671
  from datasets import load_dataset
672
+ print("Loading kagl dataset...")
673
 
674
  # Load the dataset
675
  dataset = load_dataset("Marqo/KAGL")
676
  df = dataset["data"].to_pandas()
677
+ print(f"βœ… Dataset kagl loaded")
678
  print(f"πŸ“Š Before filtering: {len(df)} samples")
679
  print(f"πŸ“‹ Available columns: {list(df.columns)}")
680
 
 
706
  for i, (text, hierarchy) in enumerate(zip(df['text'].head(3), df['hierarchy'].head(3))):
707
  print(f" {i+1}. [{hierarchy}] {text[:100]}...")
708
 
709
+ print(f"πŸ“Š After sampling: {len(df_test)} samples")
710
  print(f"πŸ“Š Samples per hierarchy:")
711
+ for hierarchy in sorted(df_test['hierarchy'].unique()):
712
+ count = len(df_test[df_test['hierarchy'] == hierarchy])
713
  print(f" {hierarchy}: {count} samples")
714
 
715
  # Create formatted dataset with proper column names
716
+ kagl_formatted = pd.DataFrame({
717
+ 'image_url': df_test['image'],
718
+ 'text': df_test['text'],
719
+ 'hierarchy': df_test['hierarchy']
720
  })
721
 
722
+ print(f"πŸ“Š Final dataset size: {len(kagl_formatted)} samples")
723
+ return kagl_formatted
724
+
725
 
726
if __name__ == "__main__":
    directory = "hierarchy_model_analysis"

    # NOTE(review): `hierarchy_model_path` is referenced here but not assigned
    # in this block — it is presumably defined earlier at module level; confirm.
    print(f"πŸš€ Starting evaluation with custom model: {hierarchy_model_path}")
    print("πŸ€— Including Fashion-CLIP baseline comparison")

    evaluator = EmbeddingEvaluator(hierarchy_model_path, directory)

    print(f"πŸ“Š Final hierarchy classes after initialization: {len(evaluator.vocab.hierarchy_classes)} classes")

    def _print_results_table(results):
        """Print one metrics row per (model, embedding-type) pair in *results*."""
        print(f"{'Model':<20} {'Embedding':<10} {'Sep Score':<12} {'NN Acc':<10} {'Centroid Acc':<12} {'F1 Macro':<10}")
        print("-" * 80)
        for model_type in ['custom', 'clip']:
            for emb_type in ['text', 'image']:
                key = f"{model_type}_{emb_type}"
                if key in results:
                    m = results[key]
                    model_name = "Custom Model" if model_type == 'custom' else "Fashion-CLIP Baseline"
                    print(f"{model_name:<20} {emb_type.capitalize():<10} "
                          f"{m['separation_score']:<12.4f} {m['accuracy']*100:<10.1f}% "
                          f"{m['centroid_accuracy']*100:<12.1f}% {m['f1_macro']*100:<10.1f}%")

    # Evaluate on validation dataset (same subset as during training)
    print("\n" + "="*60)
    print("EVALUATING VALIDATION DATASET - CUSTOM MODEL vs FASHION-CLIP BASELINE")
    print("="*60)
    val_results = evaluator.evaluate_dataset_with_baselines(evaluator.val_df, "Validation Dataset")

    # Evaluate on the Fashion-MNIST test set (out-of-distribution check)
    print("\n" + "="*60)
    print("EVALUATING FASHION-MNIST TEST DATASET - CUSTOM MODEL vs FASHION-CLIP BASELINE")
    print("="*60)
    df_fashion_mnist = load_fashion_mnist_dataset(evaluator)
    if len(df_fashion_mnist) > 0:
        fashion_mnist_results = evaluator.evaluate_dataset_with_baselines(df_fashion_mnist, "Fashion-MNIST Test Dataset")
    else:
        fashion_mnist_results = {}

    # Evaluate on the KAGL Marqo dataset (generalization check)
    print("\n" + "="*60)
    print("EVALUATING kagl MARQO DATASET - CUSTOM MODEL vs FASHION-CLIP BASELINE")
    print("="*60)
    df_kagl_marqo = load_kagl_marqo_dataset(evaluator)
    if len(df_kagl_marqo) > 0:
        kagl_results = evaluator.evaluate_dataset_with_baselines(df_kagl_marqo, "kagl Marqo Dataset")
    else:
        kagl_results = {}

    # Final side-by-side summary of all datasets
    print(f"\n{'='*80}")
    print("FINAL EVALUATION SUMMARY - CUSTOM MODEL vs FASHION-CLIP BASELINE")
    print(f"{'='*80}")

    print("\nπŸ” VALIDATION DATASET RESULTS:")
    print(f"Dataset size: {len(evaluator.val_df)} samples")
    _print_results_table(val_results)

    if fashion_mnist_results:
        print("\nπŸ‘— FASHION-MNIST TEST DATASET RESULTS:")
        print(f"Dataset size: {len(df_fashion_mnist)} samples")
        _print_results_table(fashion_mnist_results)

    if kagl_results:
        print("\n🌐 kagl MARQO DATASET RESULTS:")
        print(f"Dataset size: {len(df_kagl_marqo)} samples")
        _print_results_table(kagl_results)

    print(f"\nβœ… Evaluation completed! Check '{directory}/' for visualization files.")
    print(f"πŸ“Š Custom model hierarchy classes: {len(evaluator.vocab.hierarchy_classes)} classes")
    print("πŸ€— Fashion-CLIP baseline comparison included")