Leacb4 committed on
Commit
9f6ec81
·
verified ·
1 Parent(s): 70f9f13

Upload evaluation/main_model_evaluation.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. evaluation/main_model_evaluation.py +123 -592
evaluation/main_model_evaluation.py CHANGED
@@ -19,7 +19,7 @@ import warnings
19
  warnings.filterwarnings('ignore')
20
  from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
21
 
22
- from config import main_model_path, hierarchy_model_path, color_emb_dim, hierarchy_emb_dim, local_dataset_path, column_local_image_path
23
 
24
 
25
  def create_fashion_mnist_to_hierarchy_mapping(hierarchy_classes):
@@ -508,6 +508,15 @@ def load_local_validation_dataset(max_samples=5000):
508
  print("โŒ No valid samples after filtering.")
509
  return None
510
 
 
 
 
 
 
 
 
 
 
511
  # Ensure we have required columns
512
  required_cols = ['text', 'hierarchy']
513
  missing_cols = [col for col in required_cols if col not in df_clean.columns]
@@ -515,9 +524,10 @@ def load_local_validation_dataset(max_samples=5000):
515
  print(f"โŒ Missing required columns: {missing_cols}")
516
  return None
517
 
518
- # Limit to max_samples
519
  if len(df_clean) > max_samples:
520
- df_clean = df_clean.head(max_samples)
 
521
 
522
  print(f"๐Ÿ“Š Using {len(df_clean)} samples for evaluation")
523
  print(f" Samples per hierarchy:")
@@ -525,6 +535,14 @@ def load_local_validation_dataset(max_samples=5000):
525
  count = len(df_clean[df_clean['hierarchy'] == hierarchy])
526
  print(f" {hierarchy}: {count} samples")
527
 
 
 
 
 
 
 
 
 
528
  return LocalDataset(df_clean)
529
 
530
 
@@ -726,7 +744,12 @@ class ColorHierarchyEvaluator:
726
  return np.vstack(all_embeddings), all_colors, all_hierarchies
727
 
728
  def extract_baseline_embeddings_batch(self, dataloader, embedding_type='text', max_samples=10000):
729
- """Extract embeddings from baseline Fashion CLIP model"""
 
 
 
 
 
730
  all_embeddings = []
731
  all_colors = []
732
  all_hierarchies = []
@@ -739,23 +762,57 @@ class ColorHierarchyEvaluator:
739
  break
740
 
741
  images, texts, colors, hierarchies = batch
742
- images = images.to(self.device)
743
- images = images.expand(-1, 3, -1, -1) # Ensure 3 channels
744
-
745
- # Process text inputs with baseline processor
746
- text_inputs = self.baseline_processor(text=texts, padding=True, return_tensors="pt")
747
- text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
748
-
749
- # Forward pass through baseline model
750
- outputs = self.baseline_model(**text_inputs, pixel_values=images)
751
 
752
  # Extract embeddings based on type
753
  if embedding_type == 'text':
754
- embeddings = outputs.text_embeds
 
 
 
 
 
 
 
 
 
 
755
  elif embedding_type == 'image':
756
- embeddings = outputs.image_embeds
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
757
  else:
758
- embeddings = outputs.text_embeds
 
 
 
 
 
759
 
760
  all_embeddings.append(embeddings.cpu().numpy())
761
  all_colors.extend(colors)
@@ -764,62 +821,13 @@ class ColorHierarchyEvaluator:
764
  sample_count += len(images)
765
 
766
  # Clear GPU memory
767
- del images, text_inputs, outputs, embeddings
768
- torch.cuda.empty_cache() if torch.cuda.is_available() else None
769
-
770
- return np.vstack(all_embeddings), all_colors, all_hierarchies
771
-
772
- def extract_full_embeddings(self, dataloader, embedding_type='text', max_samples=10000):
773
- """
774
- Extrait TOUTES les dimensions des embeddings du modรจle entraรฎnรฉ (pas seulement les sous-espaces spรฉcialisรฉs)
775
-
776
- Cette mรฉthode permet de comparer les performances en utilisant toutes les dimensions disponibles,
777
- similaire ร  la baseline qui utilise toutes ses dimensions.
778
-
779
- Diffรฉrence avec extract_color_embeddings et extract_hierarchy_embeddings:
780
- - extract_color_embeddings: utilise seulement dims 0-15 (16 dimensions)
781
- - extract_hierarchy_embeddings: utilise seulement dims 16-79 (64 dimensions)
782
- - extract_full_embeddings: utilise toutes les dimensions (ex: 512 dimensions)
783
-
784
- Cela peut amรฉliorer les performances car toutes les informations sont disponibles.
785
- """
786
- all_embeddings = []
787
- all_colors = []
788
- all_hierarchies = []
789
-
790
- sample_count = 0
791
- with torch.no_grad():
792
- for batch in tqdm(dataloader, desc=f"Extracting {embedding_type} FULL embeddings (all dims)"):
793
- if sample_count >= max_samples:
794
- break
795
-
796
- images, texts, colors, hierarchies = batch
797
- images = images.to(self.device)
798
- images = images.expand(-1, 3, -1, -1)
799
-
800
- text_inputs = self.processor(text=texts, padding=True, return_tensors="pt")
801
- text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
802
-
803
- outputs = self.model(**text_inputs, pixel_values=images)
804
-
805
- if embedding_type == 'text':
806
- embeddings = outputs.text_embeds
807
- elif embedding_type == 'image':
808
- embeddings = outputs.image_embeds
809
  else:
810
- embeddings = outputs.text_embeds
811
-
812
- # Utiliser TOUTES les dimensions (pas seulement un sous-espace)
813
- # Cela permet d'avoir accรจs ร  toute l'information disponible dans l'embedding
814
- all_embeddings.append(embeddings.cpu().numpy())
815
- all_colors.extend(colors)
816
- all_hierarchies.extend(hierarchies)
817
-
818
- sample_count += len(images)
819
-
820
- del images, text_inputs, outputs, embeddings
821
  torch.cuda.empty_cache() if torch.cuda.is_available() else None
822
-
823
  return np.vstack(all_embeddings), all_colors, all_hierarchies
824
 
825
  def compute_similarity_metrics(self, embeddings, labels):
@@ -1052,75 +1060,55 @@ class ColorHierarchyEvaluator:
1052
 
1053
  results = {}
1054
 
1055
- # ========== COLOR EVALUATION (DIMS 0-15) ==========
1056
- print("\n๐ŸŽจ COLOR EVALUATION (dims 0-15)")
1057
- print("=" * 50)
1058
-
1059
- # Text color embeddings
1060
- print("\n๐Ÿ“ Extracting text color embeddings...")
1061
- text_color_embeddings, text_colors, _ = self.extract_color_embeddings(dataloader, 'text', max_samples)
1062
- print(f" Text color embeddings shape: {text_color_embeddings.shape}")
1063
- text_color_metrics = self.compute_similarity_metrics(text_color_embeddings, text_colors)
1064
- text_color_class = self.evaluate_classification_performance(
1065
- text_color_embeddings, text_colors, "Text Color Embeddings (16D)", "Color"
1066
- )
1067
- text_color_metrics.update(text_color_class)
1068
- results['text_color'] = text_color_metrics
1069
-
1070
- del text_color_embeddings
1071
- torch.cuda.empty_cache() if torch.cuda.is_available() else None
1072
-
1073
- # Image color embeddings
1074
- print("\n๐Ÿ–ผ๏ธ Extracting image color embeddings...")
1075
- image_color_embeddings, image_colors, _ = self.extract_color_embeddings(dataloader, 'image', max_samples)
1076
- print(f" Image color embeddings shape: {image_color_embeddings.shape}")
1077
- image_color_metrics = self.compute_similarity_metrics(image_color_embeddings, image_colors)
1078
- image_color_class = self.evaluate_classification_performance(
1079
- image_color_embeddings, image_colors, "Image Color Embeddings (16D)", "Color"
1080
- )
1081
- image_color_metrics.update(image_color_class)
1082
- results['image_color'] = image_color_metrics
1083
-
1084
- del image_color_embeddings
1085
- torch.cuda.empty_cache() if torch.cuda.is_available() else None
1086
 
1087
- # ========== HIERARCHY EVALUATION (DIMS 16-79) ==========
1088
- print("\n๐Ÿ“‹ HIERARCHY EVALUATION (dims 16-79)")
1089
  print("=" * 50)
1090
 
1091
- # Text hierarchy embeddings
1092
- print("\n๐Ÿ“ Extracting text hierarchy embeddings...")
1093
- text_hierarchy_embeddings, _, text_hierarchies = self.extract_hierarchy_embeddings(dataloader, 'text', max_samples)
1094
- print(f" Text hierarchy embeddings shape: {text_hierarchy_embeddings.shape}")
1095
- text_hierarchy_metrics = self.compute_similarity_metrics(text_hierarchy_embeddings, text_hierarchies)
 
1096
  text_hierarchy_class = self.evaluate_classification_performance(
1097
- text_hierarchy_embeddings, text_hierarchies, "Text Hierarchy Embeddings (64D)", "Hierarchy"
 
 
1098
  )
1099
  text_hierarchy_metrics.update(text_hierarchy_class)
1100
  results['text_hierarchy'] = text_hierarchy_metrics
1101
 
1102
- del text_hierarchy_embeddings
1103
- torch.cuda.empty_cache() if torch.cuda.is_available() else None
1104
-
1105
- # Image hierarchy embeddings
1106
- print("\n๐Ÿ–ผ๏ธ Extracting image hierarchy embeddings...")
1107
- image_hierarchy_embeddings, _, image_hierarchies = self.extract_hierarchy_embeddings(dataloader, 'image', max_samples)
1108
- print(f" Image hierarchy embeddings shape: {image_hierarchy_embeddings.shape}")
1109
- image_hierarchy_metrics = self.compute_similarity_metrics(image_hierarchy_embeddings, image_hierarchies)
1110
  image_hierarchy_class = self.evaluate_classification_performance(
1111
- image_hierarchy_embeddings, image_hierarchies, "Image Hierarchy Embeddings (64D)", "Hierarchy"
 
 
1112
  )
1113
  image_hierarchy_metrics.update(image_hierarchy_class)
1114
  results['image_hierarchy'] = image_hierarchy_metrics
1115
 
1116
- del image_hierarchy_embeddings
 
 
 
1117
  torch.cuda.empty_cache() if torch.cuda.is_available() else None
1118
 
1119
  # ========== SAVE VISUALIZATIONS ==========
1120
  os.makedirs(self.directory, exist_ok=True)
1121
- for key in ['text_color', 'image_color', 'text_hierarchy', 'image_hierarchy']:
1122
  results[key]['figure'].savefig(
1123
- f"{self.directory}/{key.replace('_', '_')}_confusion_matrix.png",
1124
  dpi=300,
1125
  bbox_inches='tight',
1126
  )
@@ -1245,11 +1233,11 @@ class ColorHierarchyEvaluator:
1245
  return results
1246
 
1247
  def evaluate_local_validation(self, max_samples):
1248
- """Evaluate both color and hierarchy embeddings on local validation dataset"""
1249
  print(f"\n{'='*60}")
1250
  print("Evaluating Local Validation Dataset")
1251
- print(" Color embeddings: dims 0-15")
1252
- print(" Hierarchy embeddings: dims 16-79")
1253
  print(f"Max samples: {max_samples}")
1254
  print(f"{'='*60}")
1255
 
@@ -1283,8 +1271,8 @@ class ColorHierarchyEvaluator:
1283
 
1284
  results = {}
1285
 
1286
- # ========== COLOR EVALUATION (DIMS 0-15) ==========
1287
- print("\n๐ŸŽจ COLOR EVALUATION (dims 0-15)")
1288
  print("=" * 50)
1289
 
1290
  # Text color embeddings
@@ -1315,8 +1303,8 @@ class ColorHierarchyEvaluator:
1315
  del image_color_embeddings
1316
  torch.cuda.empty_cache() if torch.cuda.is_available() else None
1317
 
1318
- # ========== HIERARCHY EVALUATION (DIMS 16-79) ==========
1319
- print("\n๐Ÿ“‹ HIERARCHY EVALUATION (dims 16-79)")
1320
  print("=" * 50)
1321
 
1322
  # Text hierarchy embeddings
@@ -1359,192 +1347,6 @@ class ColorHierarchyEvaluator:
1359
 
1360
  return results
1361
 
1362
- def evaluate_full_embeddings(self, dataloader, dataset_name, max_samples=10000):
1363
- """
1364
- Evaluate using ALL 512 dimensions from our trained model (not just specialized subspaces)
1365
- This allows fair comparison with baseline which uses all 512 dimensions.
1366
- """
1367
- print(f"\n{'='*60}")
1368
- print(f"Evaluating {dataset_name} with FULL 512-dimensional embeddings (Our Model)")
1369
- print(f"Max samples: {max_samples}")
1370
- print(f"{'='*60}")
1371
-
1372
- results = {}
1373
-
1374
- # ========== COLOR EVALUATION WITH FULL EMBEDDINGS ==========
1375
- print("\n๐ŸŽจ COLOR EVALUATION (512 dims - Full Embeddings)")
1376
- print("=" * 50)
1377
-
1378
- # Text color embeddings
1379
- print("\n๐Ÿ“ Extracting text FULL embeddings for color classification...")
1380
- text_full_embeddings, text_colors, _ = self.extract_full_embeddings(dataloader, 'text', max_samples)
1381
- print(f" Text full embeddings shape: {text_full_embeddings.shape} (using all {text_full_embeddings.shape[1]} dimensions)")
1382
- text_color_metrics = self.compute_similarity_metrics(text_full_embeddings, text_colors)
1383
- text_color_class = self.evaluate_classification_performance(
1384
- text_full_embeddings, text_colors, "Text Full Embeddings (512D) - Color", "Color"
1385
- )
1386
- text_color_metrics.update(text_color_class)
1387
- results['text_color'] = text_color_metrics
1388
-
1389
- del text_full_embeddings
1390
- torch.cuda.empty_cache() if torch.cuda.is_available() else None
1391
-
1392
- # Image color embeddings
1393
- print("\n๐Ÿ–ผ๏ธ Extracting image FULL embeddings for color classification...")
1394
- image_full_embeddings, image_colors, _ = self.extract_full_embeddings(dataloader, 'image', max_samples)
1395
- print(f" Image full embeddings shape: {image_full_embeddings.shape} (using all {image_full_embeddings.shape[1]} dimensions)")
1396
- image_color_metrics = self.compute_similarity_metrics(image_full_embeddings, image_colors)
1397
- image_color_class = self.evaluate_classification_performance(
1398
- image_full_embeddings, image_colors, "Image Full Embeddings (512D) - Color", "Color"
1399
- )
1400
- image_color_metrics.update(image_color_class)
1401
- results['image_color'] = image_color_metrics
1402
-
1403
- del image_full_embeddings
1404
- torch.cuda.empty_cache() if torch.cuda.is_available() else None
1405
-
1406
- # ========== HIERARCHY EVALUATION WITH FULL EMBEDDINGS ==========
1407
- print("\n๐Ÿ“‹ HIERARCHY EVALUATION (512 dims - Full Embeddings)")
1408
- print("=" * 50)
1409
-
1410
- # Text hierarchy embeddings
1411
- print("\n๐Ÿ“ Extracting text FULL embeddings for hierarchy classification...")
1412
- text_full_embeddings, _, text_hierarchies = self.extract_full_embeddings(dataloader, 'text', max_samples)
1413
- print(f" Text full embeddings shape: {text_full_embeddings.shape} (using all {text_full_embeddings.shape[1]} dimensions)")
1414
- text_hierarchy_metrics = self.compute_similarity_metrics(text_full_embeddings, text_hierarchies)
1415
- text_hierarchy_class = self.evaluate_classification_performance(
1416
- text_full_embeddings, text_hierarchies, "Text Full Embeddings (512D) - Hierarchy", "Hierarchy"
1417
- )
1418
- text_hierarchy_metrics.update(text_hierarchy_class)
1419
- results['text_hierarchy'] = text_hierarchy_metrics
1420
-
1421
- del text_full_embeddings
1422
- torch.cuda.empty_cache() if torch.cuda.is_available() else None
1423
-
1424
- # Image hierarchy embeddings
1425
- print("\n๐Ÿ–ผ๏ธ Extracting image FULL embeddings for hierarchy classification...")
1426
- image_full_embeddings, _, image_hierarchies = self.extract_full_embeddings(dataloader, 'image', max_samples)
1427
- print(f" Image full embeddings shape: {image_full_embeddings.shape} (using all {image_full_embeddings.shape[1]} dimensions)")
1428
- image_hierarchy_metrics = self.compute_similarity_metrics(image_full_embeddings, image_hierarchies)
1429
- image_hierarchy_class = self.evaluate_classification_performance(
1430
- image_full_embeddings, image_hierarchies, "Image Full Embeddings (512D) - Hierarchy", "Hierarchy"
1431
- )
1432
- image_hierarchy_metrics.update(image_hierarchy_class)
1433
- results['image_hierarchy'] = image_hierarchy_metrics
1434
-
1435
- del image_full_embeddings
1436
- torch.cuda.empty_cache() if torch.cuda.is_available() else None
1437
-
1438
- # ========== SAVE VISUALIZATIONS ==========
1439
- os.makedirs(self.directory, exist_ok=True)
1440
- dataset_prefix = dataset_name.lower().replace(' ', '_').replace('-', '_')
1441
- for key in ['text_color', 'image_color', 'text_hierarchy', 'image_hierarchy']:
1442
- results[key]['figure'].savefig(
1443
- f"{self.directory}/{dataset_prefix}_full_{key.replace('_', '_')}_confusion_matrix.png",
1444
- dpi=300,
1445
- bbox_inches='tight',
1446
- )
1447
- plt.close(results[key]['figure'])
1448
-
1449
- return results
1450
-
1451
- def compare_subspace_vs_full_embeddings(self, results_subspace, results_full, dataset_name):
1452
- """
1453
- Compare performance between specialized subspaces (16/64 dims) vs full embeddings (512 dims)
1454
- """
1455
- print(f"\n{'='*60}")
1456
- print(f"๐Ÿ“Š COMPARISON: Subspace vs Full Embeddings - {dataset_name}")
1457
- print(f"{'='*60}")
1458
-
1459
- comparisons = []
1460
-
1461
- # Text Color
1462
- subspace_color_text_acc = results_subspace.get('text_color', {}).get('accuracy', 0)
1463
- full_color_text_acc = results_full.get('text_color', {}).get('accuracy', 0)
1464
- if subspace_color_text_acc > 0 and full_color_text_acc > 0:
1465
- diff = full_color_text_acc - subspace_color_text_acc
1466
- comparisons.append({
1467
- 'type': 'Text Color',
1468
- 'subspace': subspace_color_text_acc,
1469
- 'full': full_color_text_acc,
1470
- 'diff': diff,
1471
- 'subspace_dims': '0-15 (16 dims)',
1472
- 'full_dims': 'All 512 dims'
1473
- })
1474
-
1475
- # Image Color
1476
- subspace_color_img_acc = results_subspace.get('image_color', {}).get('accuracy', 0)
1477
- full_color_img_acc = results_full.get('image_color', {}).get('accuracy', 0)
1478
- if subspace_color_img_acc > 0 and full_color_img_acc > 0:
1479
- diff = full_color_img_acc - subspace_color_img_acc
1480
- comparisons.append({
1481
- 'type': 'Image Color',
1482
- 'subspace': subspace_color_img_acc,
1483
- 'full': full_color_img_acc,
1484
- 'diff': diff,
1485
- 'subspace_dims': '0-15 (16 dims)',
1486
- 'full_dims': 'All 512 dims'
1487
- })
1488
-
1489
- # Text Hierarchy
1490
- subspace_hier_text_acc = results_subspace.get('text_hierarchy', {}).get('accuracy', 0)
1491
- full_hier_text_acc = results_full.get('text_hierarchy', {}).get('accuracy', 0)
1492
- if subspace_hier_text_acc > 0 and full_hier_text_acc > 0:
1493
- diff = full_hier_text_acc - subspace_hier_text_acc
1494
- comparisons.append({
1495
- 'type': 'Text Hierarchy',
1496
- 'subspace': subspace_hier_text_acc,
1497
- 'full': full_hier_text_acc,
1498
- 'diff': diff,
1499
- 'subspace_dims': '16-79 (64 dims)',
1500
- 'full_dims': 'All 512 dims'
1501
- })
1502
-
1503
- # Image Hierarchy
1504
- subspace_hier_img_acc = results_subspace.get('image_hierarchy', {}).get('accuracy', 0)
1505
- full_hier_img_acc = results_full.get('image_hierarchy', {}).get('accuracy', 0)
1506
- if subspace_hier_img_acc > 0 and full_hier_img_acc > 0:
1507
- diff = full_hier_img_acc - subspace_hier_img_acc
1508
- comparisons.append({
1509
- 'type': 'Image Hierarchy',
1510
- 'subspace': subspace_hier_img_acc,
1511
- 'full': full_hier_img_acc,
1512
- 'diff': diff,
1513
- 'subspace_dims': '16-79 (64 dims)',
1514
- 'full_dims': 'All 512 dims'
1515
- })
1516
-
1517
- # Display comparisons
1518
- print("\n๐Ÿ“ˆ PERFORMANCE COMPARISON:")
1519
- print("-" * 60)
1520
- for comp in comparisons:
1521
- better = "โœ… Full (512D)" if comp['diff'] > 0 else "โœ… Subspace"
1522
- print(f"\n{comp['type']}:")
1523
- print(f" Subspace ({comp['subspace_dims']}): {comp['subspace']*100:.2f}%")
1524
- print(f" Full ({comp['full_dims']}): {comp['full']*100:.2f}%")
1525
- print(f" Difference: {comp['diff']*100:+.2f}% โ†’ {better}")
1526
-
1527
- print(f"\n{'='*60}")
1528
- print("๐Ÿ’ก INTERPRETATION:")
1529
- print(f"{'='*60}")
1530
- full_better_count = sum(1 for c in comparisons if c['diff'] > 0)
1531
-
1532
- if full_better_count > len(comparisons) / 2:
1533
- print("\nโœ… Full embeddings (512D) perform better on most metrics.")
1534
- print(" This suggests that using all dimensions provides more information")
1535
- print(" for classification, even though specialized subspaces offer interpretability.")
1536
- else:
1537
- print("\nโœ… Specialized subspaces perform competitively or better.")
1538
- print(" This validates the effectiveness of dimensional specialization")
1539
- print(" while maintaining interpretability advantages.")
1540
-
1541
- print("\n๐Ÿ“Š Trade-off summary:")
1542
- print(" โ€ข Subspace (16/64 dims): Better interpretability, task-specific")
1543
- print(" โ€ข Full (512 dims): More information, potentially better accuracy")
1544
- print(" โ€ข Use case: Subspace for explainability, Full for maximum performance")
1545
-
1546
- return comparisons
1547
-
1548
  def evaluate_baseline_fashion_mnist(self, max_samples=1000):
1549
  """Evaluate baseline Fashion CLIP model on Fashion-MNIST"""
1550
  print(f"\n{'='*60}")
@@ -1568,22 +1370,15 @@ class ColorHierarchyEvaluator:
1568
 
1569
  # Evaluate text embeddings
1570
  print("\n๐Ÿ“ Extracting baseline text embeddings from Fashion-MNIST...")
1571
- text_embeddings, text_colors, text_hierarchies = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples)
1572
  print(f" Baseline text embeddings shape: {text_embeddings.shape} (using all {text_embeddings.shape[1]} dimensions)")
1573
- text_color_metrics = self.compute_similarity_metrics(text_embeddings, text_colors)
1574
  text_hierarchy_metrics = self.compute_similarity_metrics(text_embeddings, text_hierarchies)
1575
-
1576
- text_color_classification = self.evaluate_classification_performance(
1577
- text_embeddings, text_colors, "Baseline Fashion-MNIST Text Embeddings - Color", "Color"
1578
- )
1579
  text_hierarchy_classification = self.evaluate_classification_performance(
1580
  text_embeddings, text_hierarchies, "Baseline Fashion-MNIST Text Embeddings - Hierarchy", "Hierarchy"
1581
  )
1582
 
1583
- text_color_metrics.update(text_color_classification)
1584
  text_hierarchy_metrics.update(text_hierarchy_classification)
1585
  results['text'] = {
1586
- 'color': text_color_metrics,
1587
  'hierarchy': text_hierarchy_metrics
1588
  }
1589
 
@@ -1595,20 +1390,14 @@ class ColorHierarchyEvaluator:
1595
  print("\n๐Ÿ–ผ๏ธ Extracting baseline image embeddings from Fashion-MNIST...")
1596
  image_embeddings, image_colors, image_hierarchies = self.extract_baseline_embeddings_batch(dataloader, 'image', max_samples)
1597
  print(f" Baseline image embeddings shape: {image_embeddings.shape} (using all {image_embeddings.shape[1]} dimensions)")
1598
- image_color_metrics = self.compute_similarity_metrics(image_embeddings, image_colors)
1599
  image_hierarchy_metrics = self.compute_similarity_metrics(image_embeddings, image_hierarchies)
1600
 
1601
- image_color_classification = self.evaluate_classification_performance(
1602
- image_embeddings, image_colors, "Baseline Fashion-MNIST Image Embeddings - Color", "Color"
1603
- )
1604
  image_hierarchy_classification = self.evaluate_classification_performance(
1605
  image_embeddings, image_hierarchies, "Baseline Fashion-MNIST Image Embeddings - Hierarchy", "Hierarchy"
1606
  )
1607
 
1608
- image_color_metrics.update(image_color_classification)
1609
  image_hierarchy_metrics.update(image_hierarchy_classification)
1610
  results['image'] = {
1611
- 'color': image_color_metrics,
1612
  'hierarchy': image_hierarchy_metrics
1613
  }
1614
 
@@ -1619,7 +1408,7 @@ class ColorHierarchyEvaluator:
1619
  # ========== SAVE VISUALIZATIONS ==========
1620
  os.makedirs(self.directory, exist_ok=True)
1621
  for key in ['text', 'image']:
1622
- for subkey in ['color', 'hierarchy']:
1623
  figure = results[key][subkey]['figure']
1624
  figure.savefig(
1625
  f"{self.directory}/fashion_baseline_{key}_{subkey}_confusion_matrix.png",
@@ -1804,172 +1593,27 @@ class ColorHierarchyEvaluator:
1804
 
1805
  return results
1806
 
1807
- def analyze_baseline_vs_trained_performance(self, results_trained, results_baseline, dataset_name):
1808
- """
1809
- Analyse et explique pourquoi la baseline peut performer mieux que le modรจle entraรฎnรฉ
1810
-
1811
- Raisons possibles:
1812
- 1. Capacitรฉ dimensionnelle: Baseline utilise toutes les dimensions (512), modรจle entraรฎnรฉ utilise seulement des sous-espaces (17 ou 64 dims)
1813
- 2. Distribution shift: Dataset de validation diffรฉrent de celui d'entraรฎnement
1814
- 3. Overfitting: Modรจle trop spรฉcialisรฉ sur le dataset d'entraรฎnement
1815
- 4. Gรฉnรฉralisation: Baseline prรฉ-entraรฎnรฉe sur un dataset plus large et diversifiรฉ
1816
- 5. Perte d'information: Spรฉcialisation excessive peut causer perte d'information gรฉnรฉrale
1817
- """
1818
- print(f"\n{'='*60}")
1819
- print(f"๐Ÿ“Š ANALYSE: Baseline vs Modรจle Entraรฎnรฉ - {dataset_name}")
1820
- print(f"{'='*60}")
1821
-
1822
- # Comparer les mรฉtriques pour chaque type d'embedding
1823
- comparisons = []
1824
-
1825
- # Text Color
1826
- trained_color_text_acc = results_trained.get('text_color', {}).get('accuracy', 0)
1827
- baseline_color_text_acc = results_baseline.get('text', {}).get('color', {}).get('accuracy', 0)
1828
- if trained_color_text_acc > 0 and baseline_color_text_acc > 0:
1829
- diff = baseline_color_text_acc - trained_color_text_acc
1830
- comparisons.append({
1831
- 'type': 'Text Color',
1832
- 'trained': trained_color_text_acc,
1833
- 'baseline': baseline_color_text_acc,
1834
- 'diff': diff,
1835
- 'trained_dims': '0-15 (16 dims)',
1836
- 'baseline_dims': 'All dimensions (512 dims)'
1837
- })
1838
-
1839
- # Image Color
1840
- trained_color_img_acc = results_trained.get('image_color', {}).get('accuracy', 0)
1841
- baseline_color_img_acc = results_baseline.get('image', {}).get('color', {}).get('accuracy', 0)
1842
- if trained_color_img_acc > 0 and baseline_color_img_acc > 0:
1843
- diff = baseline_color_img_acc - trained_color_img_acc
1844
- comparisons.append({
1845
- 'type': 'Image Color',
1846
- 'trained': trained_color_img_acc,
1847
- 'baseline': baseline_color_img_acc,
1848
- 'diff': diff,
1849
- 'trained_dims': '0-15 (16 dims)',
1850
- 'baseline_dims': 'All dimensions (512 dims)'
1851
- })
1852
-
1853
- # Text Hierarchy
1854
- trained_hier_text_acc = results_trained.get('text_hierarchy', {}).get('accuracy', 0)
1855
- baseline_hier_text_acc = results_baseline.get('text', {}).get('hierarchy', {}).get('accuracy', 0)
1856
- if trained_hier_text_acc > 0 and baseline_hier_text_acc > 0:
1857
- diff = baseline_hier_text_acc - trained_hier_text_acc
1858
- comparisons.append({
1859
- 'type': 'Text Hierarchy',
1860
- 'trained': trained_hier_text_acc,
1861
- 'baseline': baseline_hier_text_acc,
1862
- 'diff': diff,
1863
- 'trained_dims': '16-79 (64 dims)',
1864
- 'baseline_dims': 'All dimensions (512 dims)'
1865
- })
1866
-
1867
- # Image Hierarchy
1868
- trained_hier_img_acc = results_trained.get('image_hierarchy', {}).get('accuracy', 0)
1869
- baseline_hier_img_acc = results_baseline.get('image', {}).get('hierarchy', {}).get('accuracy', 0)
1870
- if trained_hier_img_acc > 0 and baseline_hier_img_acc > 0:
1871
- diff = baseline_hier_img_acc - trained_hier_img_acc
1872
- comparisons.append({
1873
- 'type': 'Image Hierarchy',
1874
- 'trained': trained_hier_img_acc,
1875
- 'baseline': baseline_hier_img_acc,
1876
- 'diff': diff,
1877
- 'trained_dims': '16-79 (64 dims)',
1878
- 'baseline_dims': 'All dimensions (512 dims)'
1879
- })
1880
-
1881
- # Afficher les comparaisons
1882
- print("\n๐Ÿ“ˆ COMPARAISON DES PERFORMANCES:")
1883
- print("-" * 60)
1884
- for comp in comparisons:
1885
- better = "โœ… Baseline" if comp['diff'] > 0 else "โœ… Modรจle Entraรฎnรฉ"
1886
- print(f"\n{comp['type']}:")
1887
- print(f" Modรจle Entraรฎnรฉ ({comp['trained_dims']}): {comp['trained']*100:.2f}%")
1888
- print(f" Baseline ({comp['baseline_dims']}): {comp['baseline']*100:.2f}%")
1889
- print(f" Diffรฉrence: {comp['diff']*100:+.2f}% โ†’ {better}")
1890
-
1891
- # Analyse des raisons
1892
- print(f"\n{'='*60}")
1893
- print("๐Ÿ” EXPLICATIONS POSSIBLES:")
1894
- print(f"{'='*60}")
1895
-
1896
- avg_diff = np.mean([abs(c['diff']) for c in comparisons]) if comparisons else 0
1897
- baseline_better_count = sum(1 for c in comparisons if c['diff'] > 0)
1898
-
1899
- if baseline_better_count > len(comparisons) / 2:
1900
- print("\nโš ๏ธ La baseline performe mieux sur la majoritรฉ des mรฉtriques.")
1901
- print("\nRaisons probables:")
1902
- print("\n1. ๐Ÿ“ CAPACITร‰ DIMENSIONNELLE:")
1903
- print(" โ€ข Baseline: Utilise TOUTES les 512 dimensions des embeddings")
1904
- print(" โ€ข Modรจle entraรฎnรฉ: Utilise seulement 16 dims (couleur) ou 64 dims (hiรฉrarchie)")
1905
- print(" โ€ข Impact: La baseline a accรจs ร  plus d'information pour la classification")
1906
-
1907
- print("\n2. ๐ŸŽฏ SUR-SPร‰CIALISATION:")
1908
- print(" โ€ข Le modรจle entraรฎnรฉ a รฉtรฉ spรฉcialisรฉ pour sรฉparer couleur et hiรฉrarchie")
1909
- print(" โ€ข Cette spรฉcialisation peut causer une perte d'information gรฉnรฉrale")
1910
- print(" โ€ข Les dimensions non utilisรฉes peuvent contenir de l'information utile")
1911
-
1912
- print("\n3. ๐Ÿ“Š DISTRIBUTION SHIFT:")
1913
- print(" โ€ข Le dataset de validation peut avoir une distribution diffรฉrente")
1914
- print(" โ€ข Le modรจle entraรฎnรฉ peut avoir overfittรฉ sur le dataset d'entraรฎnement")
1915
- print(" โ€ข La baseline prรฉ-entraรฎnรฉe est plus robuste car entraรฎnรฉe sur plus de donnรฉes")
1916
-
1917
- print("\n4. ๐ŸŒ Gร‰Nร‰RALISATION:")
1918
- print(" โ€ข Baseline Fashion CLIP: Entraรฎnรฉe sur un large dataset diversifiรฉ")
1919
- print(" โ€ข Modรจle entraรฎnรฉ: Entraรฎnรฉ sur un dataset plus spรฉcifique")
1920
- print(" โ€ข La baseline peut mieux gรฉnรฉraliser ร  des distributions nouvelles")
1921
-
1922
- print("\n5. ๐Ÿ”„ TRADE-OFF SPร‰CIALISATION vs CAPACITร‰:")
1923
- print(" โ€ข Spรฉcialisation (modรจle entraรฎnรฉ): Meilleure sรฉparation explicable")
1924
- print(" โ€ข Capacitรฉ (baseline): Plus d'information pour meilleure performance brute")
1925
- print(" โ€ข C'est un compromis entre interprรฉtabilitรฉ et performance")
1926
-
1927
- print(f"\n{'='*60}")
1928
- print("๐Ÿ’ก RECOMMANDATIONS:")
1929
- print(f"{'='*60}")
1930
- print("\n1. Analyser les matrices de confusion pour voir les types d'erreurs")
1931
- print("2. Vรฉrifier si le modรจle entraรฎnรฉ performe mieux sur le dataset d'entraรฎnement")
1932
- print("\n3. ๐Ÿ”ง CONSIDร‰RER UTILISER TOUTES LES DIMENSIONS POUR LA CLASSIFICATION FINALE:")
1933
- print(" Actuellement:")
1934
- print(" โ€ข Modรจle entraรฎnรฉ: utilise seulement dims 0-15 (couleur) ou dims 16-79 (hiรฉrarchie)")
1935
- print(" โ€ข Baseline: utilise toutes les 512 dimensions")
1936
- print(" ")
1937
- print(" Solution proposรฉe:")
1938
- print(" โ€ข Utiliser TOUTES les dimensions du modรจle entraรฎnรฉ (ex: 512 dims) pour la classification")
1939
- print(" โ€ข Cela permet d'avoir accรจs ร  toute l'information disponible")
1940
- print(" โ€ข Mรฉthode disponible: extract_full_embeddings() pour extraire toutes les dimensions")
1941
- print(" โ€ข Vous pouvez alors comparer:")
1942
- print(" - Spรฉcialisรฉ (16 ou 64 dims) โ†’ meilleur pour interprรฉtabilitรฉ")
1943
- print(" - Complet (512 dims) โ†’ meilleur pour performance brute")
1944
- print("\n4. Utiliser les embeddings spรฉcialisรฉs pour l'interprรฉtabilitรฉ, pas pour la classification brute")
1945
- print("5. Si la performance est critique, combiner spรฉcialisรฉ + gรฉnรฉral (ensemble)")
1946
-
1947
- return comparisons
1948
 
1949
 
1950
  if __name__ == "__main__":
1951
  device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
1952
  print(f"Using device: {device}")
1953
 
1954
- directory = 'main_model_analysis_model'
1955
  max_samples = 10000
1956
 
1957
  evaluator = ColorHierarchyEvaluator(device=device, directory=directory)
1958
 
1959
  # Evaluate Fashion-MNIST
1960
  print("\n" + "="*60)
1961
- print("๐Ÿš€ Starting evaluation of Fashion-MNIST with Color & Hierarchy embeddings")
1962
  print("="*60)
1963
  results_fashion = evaluator.evaluate_fashion_mnist(max_samples=max_samples)
1964
 
1965
  print(f"\n{'='*60}")
1966
  print("FASHION-MNIST EVALUATION SUMMARY")
1967
  print(f"{'='*60}")
1968
-
1969
- print("\n๐ŸŽจ COLOR CLASSIFICATION RESULTS (dims 0-15):")
1970
- print(f" Text - NN Acc: {results_fashion['text_color']['accuracy']*100:.1f}% | Centroid Acc: {results_fashion['text_color']['centroid_accuracy']*100:.1f}% | Separation: {results_fashion['text_color']['separation_score']:.4f}")
1971
- print(f" Image - NN Acc: {results_fashion['image_color']['accuracy']*100:.1f}% | Centroid Acc: {results_fashion['image_color']['centroid_accuracy']*100:.1f}% | Separation: {results_fashion['image_color']['separation_score']:.4f}")
1972
-
1973
  print("\n๐Ÿ“‹ HIERARCHY CLASSIFICATION RESULTS (dims 16-79):")
1974
  print(f" Text - NN Acc: {results_fashion['text_hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_fashion['text_hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_fashion['text_hierarchy']['separation_score']:.4f}")
1975
  print(f" Image - NN Acc: {results_fashion['image_hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_fashion['image_hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_fashion['image_hierarchy']['separation_score']:.4f}")
@@ -1983,48 +1627,12 @@ if __name__ == "__main__":
1983
  print(f"\n{'='*60}")
1984
  print("BASELINE FASHION-MNIST EVALUATION SUMMARY")
1985
  print(f"{'='*60}")
1986
-
1987
- print("\n๐ŸŽจ COLOR CLASSIFICATION RESULTS (Baseline):")
1988
- print(f" Text - NN Acc: {results_baseline['text']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline['text']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline['text']['color']['separation_score']:.4f}")
1989
- print(f" Image - NN Acc: {results_baseline['image']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline['image']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline['image']['color']['separation_score']:.4f}")
1990
-
1991
  print("\n๐Ÿ“‹ HIERARCHY CLASSIFICATION RESULTS (Baseline):")
1992
  print(f" Text - NN Acc: {results_baseline['text']['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline['text']['hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline['text']['hierarchy']['separation_score']:.4f}")
1993
  print(f" Image - NN Acc: {results_baseline['image']['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline['image']['hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline['image']['hierarchy']['separation_score']:.4f}")
1994
 
1995
- # Analyse comparative pour Fashion-MNIST
1996
- evaluator.analyze_baseline_vs_trained_performance(
1997
- results_fashion,
1998
- results_baseline,
1999
- "Fashion-MNIST"
2000
- )
2001
-
2002
- # Evaluate Fashion-MNIST with FULL 512-dimensional embeddings
2003
- print("\n" + "="*60)
2004
- print("๐Ÿš€ Starting evaluation of Fashion-MNIST with FULL 512-dimensional embeddings")
2005
- print("="*60)
2006
- target_hierarchy_classes = evaluator.validation_hierarchy_classes or evaluator.hierarchy_classes
2007
- fashion_dataset = load_fashion_mnist_dataset(max_samples, hierarchy_classes=target_hierarchy_classes)
2008
- fashion_dataloader = DataLoader(fashion_dataset, batch_size=8, shuffle=False, num_workers=0)
2009
- results_fashion_full = evaluator.evaluate_full_embeddings(fashion_dataloader, "Fashion-MNIST", max_samples=max_samples)
2010
-
2011
- print(f"\n{'='*60}")
2012
- print("FASHION-MNIST FULL EMBEDDINGS (512D) EVALUATION SUMMARY")
2013
- print(f"{'='*60}")
2014
- print("\n๐ŸŽจ COLOR CLASSIFICATION RESULTS (512 dims):")
2015
- print(f" Text - NN Acc: {results_fashion_full['text_color']['accuracy']*100:.1f}% | Centroid Acc: {results_fashion_full['text_color']['centroid_accuracy']*100:.1f}% | Separation: {results_fashion_full['text_color']['separation_score']:.4f}")
2016
- print(f" Image - NN Acc: {results_fashion_full['image_color']['accuracy']*100:.1f}% | Centroid Acc: {results_fashion_full['image_color']['centroid_accuracy']*100:.1f}% | Separation: {results_fashion_full['image_color']['separation_score']:.4f}")
2017
- print("\n๐Ÿ“‹ HIERARCHY CLASSIFICATION RESULTS (512 dims):")
2018
- print(f" Text - NN Acc: {results_fashion_full['text_hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_fashion_full['text_hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_fashion_full['text_hierarchy']['separation_score']:.4f}")
2019
- print(f" Image - NN Acc: {results_fashion_full['image_hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_fashion_full['image_hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_fashion_full['image_hierarchy']['separation_score']:.4f}")
2020
-
2021
- # Compare subspace vs full embeddings for Fashion-MNIST
2022
- evaluator.compare_subspace_vs_full_embeddings(
2023
- results_fashion,
2024
- results_fashion_full,
2025
- "Fashion-MNIST"
2026
- )
2027
-
2028
  # Evaluate KAGL Marqo
2029
  print("\n" + "="*60)
2030
  print("๐Ÿš€ Starting evaluation of KAGL Marqo with Color & Hierarchy embeddings")
@@ -2062,41 +1670,7 @@ if __name__ == "__main__":
2062
  print("\n๐Ÿ“‹ HIERARCHY CLASSIFICATION RESULTS (Baseline):")
2063
  print(f" Text - NN Acc: {results_baseline_kaggle['text']['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_kaggle['text']['hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_kaggle['text']['hierarchy']['separation_score']:.4f}")
2064
  print(f" Image - NN Acc: {results_baseline_kaggle['image']['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_kaggle['image']['hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_kaggle['image']['hierarchy']['separation_score']:.4f}")
2065
-
2066
- # Analyse comparative pour KAGL Marqo
2067
- if results_kaggle is not None:
2068
- evaluator.analyze_baseline_vs_trained_performance(
2069
- results_kaggle,
2070
- results_baseline_kaggle,
2071
- "KAGL Marqo Dataset"
2072
- )
2073
 
2074
- # Evaluate KAGL Marqo with FULL 512-dimensional embeddings
2075
- print("\n" + "="*60)
2076
- print("๐Ÿš€ Starting evaluation of KAGL Marqo with FULL 512-dimensional embeddings")
2077
- print("="*60)
2078
- kaggle_dataset = load_kaggle_marqo_dataset(evaluator, max_samples)
2079
- if kaggle_dataset is not None:
2080
- kaggle_dataloader = DataLoader(kaggle_dataset, batch_size=8, shuffle=False, num_workers=0)
2081
- results_kaggle_full = evaluator.evaluate_full_embeddings(kaggle_dataloader, "KAGL Marqo", max_samples=max_samples)
2082
-
2083
- print(f"\n{'='*60}")
2084
- print("KAGL MARQO FULL EMBEDDINGS (512D) EVALUATION SUMMARY")
2085
- print(f"{'='*60}")
2086
- print("\n๐ŸŽจ COLOR CLASSIFICATION RESULTS (512 dims):")
2087
- print(f" Text - NN Acc: {results_kaggle_full['text_color']['accuracy']*100:.1f}% | Centroid Acc: {results_kaggle_full['text_color']['centroid_accuracy']*100:.1f}% | Separation: {results_kaggle_full['text_color']['separation_score']:.4f}")
2088
- print(f" Image - NN Acc: {results_kaggle_full['image_color']['accuracy']*100:.1f}% | Centroid Acc: {results_kaggle_full['image_color']['centroid_accuracy']*100:.1f}% | Separation: {results_kaggle_full['image_color']['separation_score']:.4f}")
2089
- print("\n๐Ÿ“‹ HIERARCHY CLASSIFICATION RESULTS (512 dims):")
2090
- print(f" Text - NN Acc: {results_kaggle_full['text_hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_kaggle_full['text_hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_kaggle_full['text_hierarchy']['separation_score']:.4f}")
2091
- print(f" Image - NN Acc: {results_kaggle_full['image_hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_kaggle_full['image_hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_kaggle_full['image_hierarchy']['separation_score']:.4f}")
2092
-
2093
- # Compare subspace vs full embeddings for KAGL Marqo
2094
- evaluator.compare_subspace_vs_full_embeddings(
2095
- results_kaggle,
2096
- results_kaggle_full,
2097
- "KAGL Marqo"
2098
- )
2099
-
2100
  # Evaluate Local Validation Dataset
2101
  print("\n" + "="*60)
2102
  print("๐Ÿš€ Starting evaluation of Local Validation Dataset with Color & Hierarchy embeddings")
@@ -2134,46 +1708,3 @@ if __name__ == "__main__":
2134
  print("\n๐Ÿ“‹ HIERARCHY CLASSIFICATION RESULTS (Baseline):")
2135
  print(f" Text - NN Acc: {results_baseline_local['text']['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_local['text']['hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_local['text']['hierarchy']['separation_score']:.4f}")
2136
  print(f" Image - NN Acc: {results_baseline_local['image']['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_local['image']['hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_local['image']['hierarchy']['separation_score']:.4f}")
2137
-
2138
- # Analyse comparative pour le dataset de validation local
2139
- if results_local is not None:
2140
- evaluator.analyze_baseline_vs_trained_performance(
2141
- results_local,
2142
- results_baseline_local,
2143
- "Local Validation Dataset"
2144
- )
2145
-
2146
- # Evaluate Local Validation with FULL 512-dimensional embeddings
2147
- print("\n" + "="*60)
2148
- print("๐Ÿš€ Starting evaluation of Local Validation with FULL 512-dimensional embeddings")
2149
- print("="*60)
2150
- local_dataset = load_local_validation_dataset(max_samples)
2151
- if local_dataset is not None:
2152
- # Filter to only include hierarchies that exist in our model
2153
- if len(local_dataset.dataframe) > 0:
2154
- valid_df = local_dataset.dataframe[local_dataset.dataframe['hierarchy'].isin(evaluator.hierarchy_classes)]
2155
- if len(valid_df) > 0:
2156
- if len(valid_df) < len(local_dataset.dataframe):
2157
- local_dataset = LocalDataset(valid_df)
2158
-
2159
- local_dataloader = DataLoader(local_dataset, batch_size=8, shuffle=False, num_workers=0)
2160
- results_local_full = evaluator.evaluate_full_embeddings(local_dataloader, "Local Validation", max_samples=max_samples)
2161
-
2162
- print(f"\n{'='*60}")
2163
- print("LOCAL VALIDATION FULL EMBEDDINGS (512D) EVALUATION SUMMARY")
2164
- print(f"{'='*60}")
2165
- print("\n๐ŸŽจ COLOR CLASSIFICATION RESULTS (512 dims):")
2166
- print(f" Text - NN Acc: {results_local_full['text_color']['accuracy']*100:.1f}% | Centroid Acc: {results_local_full['text_color']['centroid_accuracy']*100:.1f}% | Separation: {results_local_full['text_color']['separation_score']:.4f}")
2167
- print(f" Image - NN Acc: {results_local_full['image_color']['accuracy']*100:.1f}% | Centroid Acc: {results_local_full['image_color']['centroid_accuracy']*100:.1f}% | Separation: {results_local_full['image_color']['separation_score']:.4f}")
2168
- print("\n๐Ÿ“‹ HIERARCHY CLASSIFICATION RESULTS (512 dims):")
2169
- print(f" Text - NN Acc: {results_local_full['text_hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_local_full['text_hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_local_full['text_hierarchy']['separation_score']:.4f}")
2170
- print(f" Image - NN Acc: {results_local_full['image_hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_local_full['image_hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_local_full['image_hierarchy']['separation_score']:.4f}")
2171
-
2172
- # Compare subspace vs full embeddings for Local Validation
2173
- evaluator.compare_subspace_vs_full_embeddings(
2174
- results_local,
2175
- results_local_full,
2176
- "Local Validation"
2177
- )
2178
-
2179
- print(f"\nโœ… Evaluation completed! Check '{directory}/' for visualization files.")
 
19
  warnings.filterwarnings('ignore')
20
  from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
21
 
22
+ from config import main_model_path, hierarchy_model_path, color_model_path, color_emb_dim, hierarchy_emb_dim, local_dataset_path, column_local_image_path
23
 
24
 
25
  def create_fashion_mnist_to_hierarchy_mapping(hierarchy_classes):
 
508
  print("โŒ No valid samples after filtering.")
509
  return None
510
 
511
+ # NO COLOR FILTERING for local dataset - keep all colors for comprehensive evaluation
512
+ if 'color' in df_clean.columns:
513
+ print(f"๐ŸŽจ Total unique colors in dataset: {len(df_clean['color'].unique())}")
514
+ print(f"๐ŸŽจ Colors found: {sorted(df_clean['color'].unique())}")
515
+ print(f"๐ŸŽจ Color distribution (top 15):")
516
+ color_counts = df_clean['color'].value_counts()
517
+ for color in color_counts.index[:15]: # Show top 15 colors
518
+ print(f" {color}: {color_counts[color]} samples")
519
+
520
  # Ensure we have required columns
521
  required_cols = ['text', 'hierarchy']
522
  missing_cols = [col for col in required_cols if col not in df_clean.columns]
 
524
  print(f"โŒ Missing required columns: {missing_cols}")
525
  return None
526
 
527
+ # Limit to max_samples with RANDOM SAMPLING to get diverse colors
528
  if len(df_clean) > max_samples:
529
+ df_clean = df_clean.sample(n=max_samples, random_state=42)
530
+ print(f"๐Ÿ“Š Randomly sampled {max_samples} samples")
531
 
532
  print(f"๐Ÿ“Š Using {len(df_clean)} samples for evaluation")
533
  print(f" Samples per hierarchy:")
 
535
  count = len(df_clean[df_clean['hierarchy'] == hierarchy])
536
  print(f" {hierarchy}: {count} samples")
537
 
538
+ # Show color distribution after sampling
539
+ if 'color' in df_clean.columns:
540
+ print(f"\n๐ŸŽจ Color distribution in sampled data:")
541
+ color_counts = df_clean['color'].value_counts()
542
+ print(f" Total unique colors: {len(color_counts)}")
543
+ for color in color_counts.index[:15]: # Show top 15
544
+ print(f" {color}: {color_counts[color]} samples")
545
+
546
  return LocalDataset(df_clean)
547
 
548
 
 
744
  return np.vstack(all_embeddings), all_colors, all_hierarchies
745
 
746
  def extract_baseline_embeddings_batch(self, dataloader, embedding_type='text', max_samples=10000):
747
+ """
748
+ Extract embeddings from baseline Fashion CLIP model.
749
+
750
+ This method properly processes images and text through the Fashion-CLIP processor
751
+ and applies L2 normalization to embeddings, matching the evaluation in evaluate_color_embeddings.py
752
+ """
753
  all_embeddings = []
754
  all_colors = []
755
  all_hierarchies = []
 
762
  break
763
 
764
  images, texts, colors, hierarchies = batch
 
 
 
 
 
 
 
 
 
765
 
766
  # Extract embeddings based on type
767
  if embedding_type == 'text':
768
+ # Process text through Fashion-CLIP processor
769
+ text_inputs = self.baseline_processor(text=texts, return_tensors="pt", padding=True, truncation=True, max_length=77)
770
+ text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
771
+
772
+ # Get text features using the dedicated method
773
+ text_features = self.baseline_model.get_text_features(**text_inputs)
774
+
775
+ # Apply L2 normalization (critical for CLIP!)
776
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
777
+ embeddings = text_features
778
+
779
  elif embedding_type == 'image':
780
+ # Convert tensor images back to PIL Images for proper processing
781
+ pil_images = []
782
+ for i in range(images.shape[0]):
783
+ img_tensor = images[i]
784
+
785
+ # Denormalize if the images were normalized (undo ImageNet normalization)
786
+ # Check if images are normalized (values outside [0,1])
787
+ if img_tensor.min() < 0 or img_tensor.max() > 1:
788
+ # Undo ImageNet normalization
789
+ mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
790
+ std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
791
+ img_tensor = img_tensor * std + mean
792
+ img_tensor = torch.clamp(img_tensor, 0, 1)
793
+
794
+ # Convert to PIL Image
795
+ img_pil = transforms.ToPILImage()(img_tensor)
796
+ pil_images.append(img_pil)
797
+
798
+ # Process images through Fashion-CLIP processor (will apply its own normalization)
799
+ image_inputs = self.baseline_processor(images=pil_images, return_tensors="pt")
800
+ image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()}
801
+
802
+ # Get image features using the dedicated method
803
+ image_features = self.baseline_model.get_image_features(**image_inputs)
804
+
805
+ # Apply L2 normalization (critical for CLIP!)
806
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
807
+ embeddings = image_features
808
+
809
  else:
810
+ # Default to text
811
+ text_inputs = self.baseline_processor(text=texts, return_tensors="pt", padding=True, truncation=True, max_length=77)
812
+ text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
813
+ text_features = self.baseline_model.get_text_features(**text_inputs)
814
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
815
+ embeddings = text_features
816
 
817
  all_embeddings.append(embeddings.cpu().numpy())
818
  all_colors.extend(colors)
 
821
  sample_count += len(images)
822
 
823
  # Clear GPU memory
824
+ del embeddings
825
+ if embedding_type == 'image':
826
+ del pil_images, image_inputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
827
  else:
828
+ del text_inputs
 
 
 
 
 
 
 
 
 
 
829
  torch.cuda.empty_cache() if torch.cuda.is_available() else None
830
+
831
  return np.vstack(all_embeddings), all_colors, all_hierarchies
832
 
833
  def compute_similarity_metrics(self, embeddings, labels):
 
1060
 
1061
  results = {}
1062
 
1063
+ # ========== EXTRACT FULL EMBEDDINGS FOR ENSEMBLE ==========
1064
+ print("\n๐Ÿ“ฆ Extracting full 512-dimensional embeddings for ensemble...")
1065
+ text_full_embeddings, text_colors_full, text_hierarchies_full = self.extract_full_embeddings(dataloader, 'text', max_samples)
1066
+ image_full_embeddings, image_colors_full, image_hierarchies_full = self.extract_full_embeddings(dataloader, 'image', max_samples)
1067
+ print(f" Text full embeddings shape: {text_full_embeddings.shape}")
1068
+ print(f" Image full embeddings shape: {image_full_embeddings.shape}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1069
 
1070
+ # ========== HIERARCHY EVALUATION (DIMS 16-79) WITH ENSEMBLE ==========
1071
+ print("\n๐Ÿ“‹ HIERARCHY EVALUATION (dims 16-79) - Using Ensemble")
1072
  print("=" * 50)
1073
 
1074
+ # Extract specialized hierarchy embeddings (dims 16-79)
1075
+ print("\n๐Ÿ“ Extracting specialized text hierarchy embeddings (dims 16-79)...")
1076
+ text_hierarchy_embeddings_spec = text_full_embeddings[:, self.color_emb_dim:self.color_emb_dim+self.hierarchy_emb_dim] # dims 16-79
1077
+ print(f" Specialized text hierarchy embeddings shape: {text_hierarchy_embeddings_spec.shape}")
1078
+ text_hierarchy_metrics = self.compute_similarity_metrics(text_hierarchy_embeddings_spec, text_hierarchies_full)
1079
+ # Use ensemble: combine specialized (64D) + full (512D)
1080
  text_hierarchy_class = self.evaluate_classification_performance(
1081
+ text_hierarchy_embeddings_spec, text_hierarchies_full,
1082
+ "Text Hierarchy Embeddings (Ensemble)", "Hierarchy",
1083
+ full_embeddings=text_full_embeddings, ensemble_weight=0.4
1084
  )
1085
  text_hierarchy_metrics.update(text_hierarchy_class)
1086
  results['text_hierarchy'] = text_hierarchy_metrics
1087
 
1088
+ # Image hierarchy embeddings with ensemble
1089
+ print("\n๐Ÿ–ผ๏ธ Extracting specialized image hierarchy embeddings (dims 16-79)...")
1090
+ image_hierarchy_embeddings_spec = image_full_embeddings[:, self.color_emb_dim:self.color_emb_dim+self.hierarchy_emb_dim] # dims 16-79
1091
+ print(f" Specialized image hierarchy embeddings shape: {image_hierarchy_embeddings_spec.shape}")
1092
+ image_hierarchy_metrics = self.compute_similarity_metrics(image_hierarchy_embeddings_spec, image_hierarchies_full)
 
 
 
1093
  image_hierarchy_class = self.evaluate_classification_performance(
1094
+ image_hierarchy_embeddings_spec, image_hierarchies_full,
1095
+ "Image Hierarchy Embeddings (Ensemble)", "Hierarchy",
1096
+ full_embeddings=image_full_embeddings, ensemble_weight=0.4
1097
  )
1098
  image_hierarchy_metrics.update(image_hierarchy_class)
1099
  results['image_hierarchy'] = image_hierarchy_metrics
1100
 
1101
+ # Cleanup
1102
+ del text_full_embeddings, image_full_embeddings
1103
+ del text_color_embeddings_spec, image_color_embeddings_spec
1104
+ del text_hierarchy_embeddings_spec, image_hierarchy_embeddings_spec
1105
  torch.cuda.empty_cache() if torch.cuda.is_available() else None
1106
 
1107
  # ========== SAVE VISUALIZATIONS ==========
1108
  os.makedirs(self.directory, exist_ok=True)
1109
+ for key in ['text_hierarchy', 'image_hierarchy']:
1110
  results[key]['figure'].savefig(
1111
+ f"{self.directory}/fashion_{key.replace('_', '_')}_confusion_matrix.png",
1112
  dpi=300,
1113
  bbox_inches='tight',
1114
  )
 
1233
  return results
1234
 
1235
  def evaluate_local_validation(self, max_samples):
1236
+ """Evaluate both color and hierarchy embeddings on local validation dataset (NO ENSEMBLE - only specialized embeddings)"""
1237
  print(f"\n{'='*60}")
1238
  print("Evaluating Local Validation Dataset")
1239
+ print(" Color embeddings: dims 0-15 (specialized only, no ensemble)")
1240
+ print(" Hierarchy embeddings: dims 16-79 (specialized only, no ensemble)")
1241
  print(f"Max samples: {max_samples}")
1242
  print(f"{'='*60}")
1243
 
 
1271
 
1272
  results = {}
1273
 
1274
+ # ========== COLOR EVALUATION (DIMS 0-15) - SPECIALIZED ONLY ==========
1275
+ print("\n๐ŸŽจ COLOR EVALUATION (dims 0-15) - Specialized embeddings only")
1276
  print("=" * 50)
1277
 
1278
  # Text color embeddings
 
1303
  del image_color_embeddings
1304
  torch.cuda.empty_cache() if torch.cuda.is_available() else None
1305
 
1306
+ # ========== HIERARCHY EVALUATION (DIMS 16-79) - SPECIALIZED ONLY ==========
1307
+ print("\n๐Ÿ“‹ HIERARCHY EVALUATION (dims 16-79) - Specialized embeddings only")
1308
  print("=" * 50)
1309
 
1310
  # Text hierarchy embeddings
 
1347
 
1348
  return results
1349
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1350
  def evaluate_baseline_fashion_mnist(self, max_samples=1000):
1351
  """Evaluate baseline Fashion CLIP model on Fashion-MNIST"""
1352
  print(f"\n{'='*60}")
 
1370
 
1371
  # Evaluate text embeddings
1372
  print("\n๐Ÿ“ Extracting baseline text embeddings from Fashion-MNIST...")
1373
+ text_embeddings, _, text_hierarchies = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples)
1374
  print(f" Baseline text embeddings shape: {text_embeddings.shape} (using all {text_embeddings.shape[1]} dimensions)")
 
1375
  text_hierarchy_metrics = self.compute_similarity_metrics(text_embeddings, text_hierarchies)
 
 
 
 
1376
  text_hierarchy_classification = self.evaluate_classification_performance(
1377
  text_embeddings, text_hierarchies, "Baseline Fashion-MNIST Text Embeddings - Hierarchy", "Hierarchy"
1378
  )
1379
 
 
1380
  text_hierarchy_metrics.update(text_hierarchy_classification)
1381
  results['text'] = {
 
1382
  'hierarchy': text_hierarchy_metrics
1383
  }
1384
 
 
1390
  print("\n๐Ÿ–ผ๏ธ Extracting baseline image embeddings from Fashion-MNIST...")
1391
  image_embeddings, image_colors, image_hierarchies = self.extract_baseline_embeddings_batch(dataloader, 'image', max_samples)
1392
  print(f" Baseline image embeddings shape: {image_embeddings.shape} (using all {image_embeddings.shape[1]} dimensions)")
 
1393
  image_hierarchy_metrics = self.compute_similarity_metrics(image_embeddings, image_hierarchies)
1394
 
 
 
 
1395
  image_hierarchy_classification = self.evaluate_classification_performance(
1396
  image_embeddings, image_hierarchies, "Baseline Fashion-MNIST Image Embeddings - Hierarchy", "Hierarchy"
1397
  )
1398
 
 
1399
  image_hierarchy_metrics.update(image_hierarchy_classification)
1400
  results['image'] = {
 
1401
  'hierarchy': image_hierarchy_metrics
1402
  }
1403
 
 
1408
  # ========== SAVE VISUALIZATIONS ==========
1409
  os.makedirs(self.directory, exist_ok=True)
1410
  for key in ['text', 'image']:
1411
+ for subkey in ['hierarchy']:
1412
  figure = results[key][subkey]['figure']
1413
  figure.savefig(
1414
  f"{self.directory}/fashion_baseline_{key}_{subkey}_confusion_matrix.png",
 
1593
 
1594
  return results
1595
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1596
 
1597
 
1598
  if __name__ == "__main__":
1599
  device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
1600
  print(f"Using device: {device}")
1601
 
1602
+ directory = 'main_model_analysis'
1603
  max_samples = 10000
1604
 
1605
  evaluator = ColorHierarchyEvaluator(device=device, directory=directory)
1606
 
1607
  # Evaluate Fashion-MNIST
1608
  print("\n" + "="*60)
1609
+ print("๐Ÿš€ Starting evaluation of Fashion-MNIST Hierarchy embeddings")
1610
  print("="*60)
1611
  results_fashion = evaluator.evaluate_fashion_mnist(max_samples=max_samples)
1612
 
1613
  print(f"\n{'='*60}")
1614
  print("FASHION-MNIST EVALUATION SUMMARY")
1615
  print(f"{'='*60}")
1616
+
 
 
 
 
1617
  print("\n๐Ÿ“‹ HIERARCHY CLASSIFICATION RESULTS (dims 16-79):")
1618
  print(f" Text - NN Acc: {results_fashion['text_hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_fashion['text_hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_fashion['text_hierarchy']['separation_score']:.4f}")
1619
  print(f" Image - NN Acc: {results_fashion['image_hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_fashion['image_hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_fashion['image_hierarchy']['separation_score']:.4f}")
 
1627
  print(f"\n{'='*60}")
1628
  print("BASELINE FASHION-MNIST EVALUATION SUMMARY")
1629
  print(f"{'='*60}")
1630
+
 
 
 
 
1631
  print("\n๐Ÿ“‹ HIERARCHY CLASSIFICATION RESULTS (Baseline):")
1632
  print(f" Text - NN Acc: {results_baseline['text']['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline['text']['hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline['text']['hierarchy']['separation_score']:.4f}")
1633
  print(f" Image - NN Acc: {results_baseline['image']['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline['image']['hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline['image']['hierarchy']['separation_score']:.4f}")
1634
 
1635
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1636
  # Evaluate KAGL Marqo
1637
  print("\n" + "="*60)
1638
  print("๐Ÿš€ Starting evaluation of KAGL Marqo with Color & Hierarchy embeddings")
 
1670
  print("\n๐Ÿ“‹ HIERARCHY CLASSIFICATION RESULTS (Baseline):")
1671
  print(f" Text - NN Acc: {results_baseline_kaggle['text']['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_kaggle['text']['hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_kaggle['text']['hierarchy']['separation_score']:.4f}")
1672
  print(f" Image - NN Acc: {results_baseline_kaggle['image']['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_kaggle['image']['hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_kaggle['image']['hierarchy']['separation_score']:.4f}")
 
 
 
 
 
 
 
 
1673
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1674
  # Evaluate Local Validation Dataset
1675
  print("\n" + "="*60)
1676
  print("๐Ÿš€ Starting evaluation of Local Validation Dataset with Color & Hierarchy embeddings")
 
1708
  print("\n๐Ÿ“‹ HIERARCHY CLASSIFICATION RESULTS (Baseline):")
1709
  print(f" Text - NN Acc: {results_baseline_local['text']['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_local['text']['hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_local['text']['hierarchy']['separation_score']:.4f}")
1710
  print(f" Image - NN Acc: {results_baseline_local['image']['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_local['image']['hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_local['image']['hierarchy']['separation_score']:.4f}")