Leacb4 committed on
Commit
8c71174
·
verified ·
1 Parent(s): 9c2cc41

Upload evaluation/main_model_evaluation.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. evaluation/main_model_evaluation.py +175 -53
evaluation/main_model_evaluation.py CHANGED
@@ -19,7 +19,7 @@ import warnings
19
  warnings.filterwarnings('ignore')
20
  from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
21
 
22
- from config import main_model_path, hierarchy_model_path, color_emb_dim, hierarchy_emb_dim, local_dataset_path, column_local_image_path, fashion_mnist_test_path
23
 
24
 
25
  def create_fashion_mnist_to_hierarchy_mapping(hierarchy_classes):
@@ -176,7 +176,7 @@ class FashionMNISTDataset(Dataset):
176
 
177
  def load_fashion_mnist_dataset(max_samples=1000, hierarchy_classes=None):
178
  print("📊 Loading Fashion-MNIST test dataset...")
179
- df = pd.read_csv(fashion_mnist_test_path)
180
  print(f"✅ Fashion-MNIST dataset loaded: {len(df)} samples")
181
 
182
  # Create mapping if hierarchy classes are provided
@@ -600,14 +600,14 @@ class ColorHierarchyEvaluator:
600
  return sorted(set(hierarchies))
601
 
602
  def extract_color_embeddings(self, dataloader, embedding_type='text', max_samples=10000):
603
- """Extract color embeddings from dims 0-16"""
604
  all_embeddings = []
605
  all_colors = []
606
  all_hierarchies = []
607
 
608
  sample_count = 0
609
  with torch.no_grad():
610
- for batch in tqdm(dataloader, desc=f"Extracting {embedding_type} color embeddings (dims 0-16)"):
611
  if sample_count >= max_samples:
612
  break
613
 
@@ -627,9 +627,10 @@ class ColorHierarchyEvaluator:
627
  else:
628
  embeddings = outputs.text_embeds
629
 
630
- # Extract only color embeddings (dims 0-16)
631
- color_embeddings = embeddings[:, :self.color_emb_dim]
632
 
 
633
  all_embeddings.append(color_embeddings.cpu().numpy())
634
  all_colors.extend(colors)
635
  all_hierarchies.extend(hierarchies)
@@ -670,8 +671,9 @@ class ColorHierarchyEvaluator:
670
  embeddings = outputs.text_embeds
671
 
672
  # Extract hierarchy embeddings (dims 17-79 -> indices 16:79)
673
- hierarchy_embeddings = embeddings[:, 16:79]
674
-
 
675
  all_embeddings.append(hierarchy_embeddings.cpu().numpy())
676
  all_colors.extend(colors)
677
  all_hierarchies.extend(hierarchies)
@@ -683,6 +685,46 @@ class ColorHierarchyEvaluator:
683
 
684
  return np.vstack(all_embeddings), all_colors, all_hierarchies
685
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
686
  def extract_baseline_embeddings_batch(self, dataloader, embedding_type='text', max_samples=10000):
687
  """Extract embeddings from baseline Fashion CLIP model"""
688
  all_embeddings = []
@@ -883,6 +925,52 @@ class ColorHierarchyEvaluator:
883
  predictions.append(predicted_label)
884
  return predictions
885
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
886
  def create_confusion_matrix(self, true_labels, predicted_labels, title="Confusion Matrix", label_type="Label"):
887
  """Create and plot confusion matrix"""
888
  unique_labels = sorted(list(set(true_labels + predicted_labels)))
@@ -898,11 +986,34 @@ class ColorHierarchyEvaluator:
898
  plt.tight_layout()
899
  return plt.gcf(), accuracy, cm
900
 
901
- def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings", label_type="Label"):
902
- """Evaluate classification performance and create confusion matrix"""
903
- predictions = self.predict_labels_from_embeddings(embeddings, labels)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
904
  accuracy = accuracy_score(labels, predictions)
905
- fig, acc, cm = self.create_confusion_matrix(labels, predictions, f"{embedding_type} - {label_type} Classification", label_type)
 
 
 
 
906
  unique_labels = sorted(list(set(labels)))
907
  report = classification_report(labels, predictions, labels=unique_labels, target_names=unique_labels, output_dict=True)
908
  return {
@@ -1046,68 +1157,79 @@ class ColorHierarchyEvaluator:
1046
 
1047
  results = {}
1048
 
1049
- # ========== COLOR EVALUATION (DIMS 0-16) ==========
1050
- print("\n🎨 COLOR EVALUATION (dims 0-16)")
 
 
 
 
 
 
 
1051
  print("=" * 50)
1052
 
1053
- # Text color embeddings
1054
- print("\n📝 Extracting text color embeddings...")
1055
- text_color_embeddings, text_colors, _ = self.extract_color_embeddings(dataloader, 'text', max_samples)
1056
- print(f" Text color embeddings shape: {text_color_embeddings.shape}")
1057
- text_color_metrics = self.compute_similarity_metrics(text_color_embeddings, text_colors)
 
1058
  text_color_class = self.evaluate_classification_performance(
1059
- text_color_embeddings, text_colors, "Text Color Embeddings (16D)", "Color"
 
 
1060
  )
1061
  text_color_metrics.update(text_color_class)
1062
  results['text_color'] = text_color_metrics
1063
 
1064
- del text_color_embeddings
1065
- torch.cuda.empty_cache() if torch.cuda.is_available() else None
1066
-
1067
- # Image color embeddings
1068
- print("\n🖼️ Extracting image color embeddings...")
1069
- image_color_embeddings, image_colors, _ = self.extract_color_embeddings(dataloader, 'image', max_samples)
1070
- print(f" Image color embeddings shape: {image_color_embeddings.shape}")
1071
- image_color_metrics = self.compute_similarity_metrics(image_color_embeddings, image_colors)
1072
  image_color_class = self.evaluate_classification_performance(
1073
- image_color_embeddings, image_colors, "Image Color Embeddings (16D)", "Color"
 
 
1074
  )
1075
  image_color_metrics.update(image_color_class)
1076
  results['image_color'] = image_color_metrics
1077
 
1078
- del image_color_embeddings
1079
- torch.cuda.empty_cache() if torch.cuda.is_available() else None
1080
-
1081
- # ========== HIERARCHY EVALUATION (DIMS 16-79) ==========
1082
- print("\n📋 HIERARCHY EVALUATION (dims 16-79)")
1083
  print("=" * 50)
1084
 
1085
- # Text hierarchy embeddings
1086
- print("\n📝 Extracting text hierarchy embeddings...")
1087
- text_hierarchy_embeddings, _, text_hierarchies = self.extract_hierarchy_embeddings(dataloader, 'text', max_samples)
1088
- print(f" Text hierarchy embeddings shape: {text_hierarchy_embeddings.shape}")
1089
- text_hierarchy_metrics = self.compute_similarity_metrics(text_hierarchy_embeddings, text_hierarchies)
 
1090
  text_hierarchy_class = self.evaluate_classification_performance(
1091
- text_hierarchy_embeddings, text_hierarchies, "Text Hierarchy Embeddings (64D)", "Hierarchy"
 
 
1092
  )
1093
  text_hierarchy_metrics.update(text_hierarchy_class)
1094
  results['text_hierarchy'] = text_hierarchy_metrics
1095
 
1096
- del text_hierarchy_embeddings
1097
- torch.cuda.empty_cache() if torch.cuda.is_available() else None
1098
-
1099
- # Image hierarchy embeddings
1100
- print("\n🖼️ Extracting image hierarchy embeddings...")
1101
- image_hierarchy_embeddings, _, image_hierarchies = self.extract_hierarchy_embeddings(dataloader, 'image', max_samples)
1102
- print(f" Image hierarchy embeddings shape: {image_hierarchy_embeddings.shape}")
1103
- image_hierarchy_metrics = self.compute_similarity_metrics(image_hierarchy_embeddings, image_hierarchies)
1104
  image_hierarchy_class = self.evaluate_classification_performance(
1105
- image_hierarchy_embeddings, image_hierarchies, "Image Hierarchy Embeddings (64D)", "Hierarchy"
 
 
1106
  )
1107
  image_hierarchy_metrics.update(image_hierarchy_class)
1108
  results['image_hierarchy'] = image_hierarchy_metrics
1109
 
1110
- del image_hierarchy_embeddings
 
 
 
1111
  torch.cuda.empty_cache() if torch.cuda.is_available() else None
1112
 
1113
  # ========== SAVE VISUALIZATIONS ==========
@@ -1724,7 +1846,7 @@ class ColorHierarchyEvaluator:
1724
  'trained': trained_color_img_acc,
1725
  'baseline': baseline_color_img_acc,
1726
  'diff': diff,
1727
- 'trained_dims': '0-16 (17 dims)',
1728
  'baseline_dims': 'All dimensions (512 dims)'
1729
  })
1730
 
@@ -1779,7 +1901,7 @@ class ColorHierarchyEvaluator:
1779
  print("\nRaisons probables:")
1780
  print("\n1. 📐 CAPACITÉ DIMENSIONNELLE:")
1781
  print(" • Baseline: Utilise TOUTES les 512 dimensions des embeddings")
1782
- print(" • Modèle entraîné: Utilise seulement 17 dims (couleur) ou 64 dims (hiérarchie)")
1783
  print(" • Impact: La baseline a accès à plus d'information pour la classification")
1784
 
1785
  print("\n2. 🎯 SUR-SPÉCIALISATION:")
@@ -1829,7 +1951,7 @@ if __name__ == "__main__":
1829
  device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
1830
  print(f"Using device: {device}")
1831
 
1832
- directory = 'main_model_analysi'
1833
  max_samples = 10000
1834
 
1835
  evaluator = ColorHierarchyEvaluator(device=device, directory=directory)
 
19
  warnings.filterwarnings('ignore')
20
  from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
21
 
22
+ from config import main_model_path, hierarchy_model_path, color_emb_dim, hierarchy_emb_dim, local_dataset_path, column_local_image_path
23
 
24
 
25
  def create_fashion_mnist_to_hierarchy_mapping(hierarchy_classes):
 
176
 
177
  def load_fashion_mnist_dataset(max_samples=1000, hierarchy_classes=None):
178
  print("📊 Loading Fashion-MNIST test dataset...")
179
+ df = pd.read_csv("/Users/leaattiasarfati/Desktop/docs/search/old/MainModel/data/fashion-mnist_test.csv")
180
  print(f"✅ Fashion-MNIST dataset loaded: {len(df)} samples")
181
 
182
  # Create mapping if hierarchy classes are provided
 
600
  return sorted(set(hierarchies))
601
 
602
  def extract_color_embeddings(self, dataloader, embedding_type='text', max_samples=10000):
603
+ """Extract color embeddings from dims 0-15 (16 dimensions)"""
604
  all_embeddings = []
605
  all_colors = []
606
  all_hierarchies = []
607
 
608
  sample_count = 0
609
  with torch.no_grad():
610
+ for batch in tqdm(dataloader, desc=f"Extracting {embedding_type} color embeddings (dims 0-15)"):
611
  if sample_count >= max_samples:
612
  break
613
 
 
627
  else:
628
  embeddings = outputs.text_embeds
629
 
630
+ # Extract only color embeddings (dims 0-15, i.e., first 16 dimensions)
631
+ # color_embeddings = embeddings[:, :self.color_emb_dim]
632
 
633
+ color_embeddings = embeddings
634
  all_embeddings.append(color_embeddings.cpu().numpy())
635
  all_colors.extend(colors)
636
  all_hierarchies.extend(hierarchies)
 
671
  embeddings = outputs.text_embeds
672
 
673
  # Extract hierarchy embeddings (dims 17-79 -> indices 16:79)
674
+ # hierarchy_embeddings = embeddings[:, 16:79]
675
+
676
+ hierarchy_embeddings = embeddings
677
  all_embeddings.append(hierarchy_embeddings.cpu().numpy())
678
  all_colors.extend(colors)
679
  all_hierarchies.extend(hierarchies)
 
685
 
686
  return np.vstack(all_embeddings), all_colors, all_hierarchies
687
 
688
+ def extract_full_embeddings(self, dataloader, embedding_type='text', max_samples=10000):
689
+ """Extract full 512-dimensional embeddings (all dimensions)"""
690
+ all_embeddings = []
691
+ all_colors = []
692
+ all_hierarchies = []
693
+
694
+ sample_count = 0
695
+ with torch.no_grad():
696
+ for batch in tqdm(dataloader, desc=f"Extracting {embedding_type} FULL embeddings (all dims)"):
697
+ if sample_count >= max_samples:
698
+ break
699
+
700
+ images, texts, colors, hierarchies = batch
701
+ images = images.to(self.device)
702
+ images = images.expand(-1, 3, -1, -1)
703
+
704
+ text_inputs = self.processor(text=texts, padding=True, return_tensors="pt")
705
+ text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
706
+
707
+ outputs = self.model(**text_inputs, pixel_values=images)
708
+
709
+ if embedding_type == 'text':
710
+ embeddings = outputs.text_embeds
711
+ elif embedding_type == 'image':
712
+ embeddings = outputs.image_embeds
713
+ else:
714
+ embeddings = outputs.text_embeds
715
+
716
+ # Use all 512 dimensions
717
+ all_embeddings.append(embeddings.cpu().numpy())
718
+ all_colors.extend(colors)
719
+ all_hierarchies.extend(hierarchies)
720
+
721
+ sample_count += len(images)
722
+
723
+ del images, text_inputs, outputs, embeddings
724
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
725
+
726
+ return np.vstack(all_embeddings), all_colors, all_hierarchies
727
+
728
  def extract_baseline_embeddings_batch(self, dataloader, embedding_type='text', max_samples=10000):
729
  """Extract embeddings from baseline Fashion CLIP model"""
730
  all_embeddings = []
 
925
  predictions.append(predicted_label)
926
  return predictions
927
 
928
+ def predict_labels_ensemble(self, specialized_embeddings, full_embeddings, labels,
929
+ specialized_weight=0.5):
930
+ """
931
+ Ensemble prediction combining specialized (16/64 dims) and full (512 dims) embeddings.
932
+
933
+ Args:
934
+ specialized_embeddings: Embeddings from specialized dimensions (e.g., dims 0-15 for color)
935
+ full_embeddings: Full 512-dimensional embeddings
936
+ labels: True labels for computing centroids
937
+ specialized_weight: Weight for specialized embeddings (0.0 = only full, 1.0 = only specialized)
938
+
939
+ Returns:
940
+ List of predicted labels using weighted ensemble
941
+ """
942
+ unique_labels = list(set(labels))
943
+
944
+ # Compute centroids for both specialized and full embeddings
945
+ specialized_centroids = {}
946
+ full_centroids = {}
947
+
948
+ for label in unique_labels:
949
+ label_indices = [i for i, l in enumerate(labels) if l == label]
950
+ specialized_centroids[label] = np.mean(specialized_embeddings[label_indices], axis=0)
951
+ full_centroids[label] = np.mean(full_embeddings[label_indices], axis=0)
952
+
953
+ predictions = []
954
+ for i in range(len(specialized_embeddings)):
955
+ best_combined_score = -np.inf
956
+ predicted_label = None
957
+
958
+ for label in unique_labels:
959
+ # Compute similarity scores for both specialized and full
960
+ spec_sim = cosine_similarity([specialized_embeddings[i]], [specialized_centroids[label]])[0][0]
961
+ full_sim = cosine_similarity([full_embeddings[i]], [full_centroids[label]])[0][0]
962
+
963
+ # Weighted combination
964
+ combined_score = specialized_weight * spec_sim + (1 - specialized_weight) * full_sim
965
+
966
+ if combined_score > best_combined_score:
967
+ best_combined_score = combined_score
968
+ predicted_label = label
969
+
970
+ predictions.append(predicted_label)
971
+
972
+ return predictions
973
+
974
  def create_confusion_matrix(self, true_labels, predicted_labels, title="Confusion Matrix", label_type="Label"):
975
  """Create and plot confusion matrix"""
976
  unique_labels = sorted(list(set(true_labels + predicted_labels)))
 
986
  plt.tight_layout()
987
  return plt.gcf(), accuracy, cm
988
 
989
+ def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings", label_type="Label",
990
+ full_embeddings=None, ensemble_weight=0.5):
991
+ """
992
+ Evaluate classification performance and create confusion matrix.
993
+
994
+ Args:
995
+ embeddings: Specialized embeddings (e.g., dims 0-15 for color or dims 16-79 for hierarchy)
996
+ labels: True labels
997
+ embedding_type: Type of embeddings for display
998
+ label_type: Type of labels (Color/Hierarchy)
999
+ full_embeddings: Optional full 512-dim embeddings for ensemble (if None, uses only specialized)
1000
+ ensemble_weight: Weight for specialized embeddings in ensemble (0.0 = only full, 1.0 = only specialized)
1001
+ """
1002
+ if full_embeddings is not None:
1003
+ # Use ensemble prediction
1004
+ predictions = self.predict_labels_ensemble(embeddings, full_embeddings, labels, ensemble_weight)
1005
+ title_suffix = f" (Ensemble: {ensemble_weight:.1f} specialized + {1-ensemble_weight:.1f} full)"
1006
+ else:
1007
+ # Use only specialized embeddings
1008
+ predictions = self.predict_labels_from_embeddings(embeddings, labels)
1009
+ title_suffix = ""
1010
+
1011
  accuracy = accuracy_score(labels, predictions)
1012
+ fig, acc, cm = self.create_confusion_matrix(
1013
+ labels, predictions,
1014
+ f"{embedding_type} - {label_type} Classification{title_suffix}",
1015
+ label_type
1016
+ )
1017
  unique_labels = sorted(list(set(labels)))
1018
  report = classification_report(labels, predictions, labels=unique_labels, target_names=unique_labels, output_dict=True)
1019
  return {
 
1157
 
1158
  results = {}
1159
 
1160
+ # ========== EXTRACT FULL EMBEDDINGS FOR ENSEMBLE ==========
1161
+ print("\n📦 Extracting full 512-dimensional embeddings for ensemble...")
1162
+ text_full_embeddings, text_colors_full, text_hierarchies_full = self.extract_full_embeddings(dataloader, 'text', max_samples)
1163
+ image_full_embeddings, image_colors_full, image_hierarchies_full = self.extract_full_embeddings(dataloader, 'image', max_samples)
1164
+ print(f" Text full embeddings shape: {text_full_embeddings.shape}")
1165
+ print(f" Image full embeddings shape: {image_full_embeddings.shape}")
1166
+
1167
+ # ========== COLOR EVALUATION (DIMS 0-15) WITH ENSEMBLE ==========
1168
+ print("\n🎨 COLOR EVALUATION (dims 0-15) - Using Ensemble")
1169
  print("=" * 50)
1170
 
1171
+ # Extract specialized color embeddings (dims 0-15)
1172
+ print("\n📝 Extracting specialized text color embeddings (dims 0-15)...")
1173
+ text_color_embeddings_spec = text_full_embeddings[:, :self.color_emb_dim] # First 16 dims
1174
+ print(f" Specialized text color embeddings shape: {text_color_embeddings_spec.shape}")
1175
+ text_color_metrics = self.compute_similarity_metrics(text_color_embeddings_spec, text_colors_full)
1176
+ # Use ensemble: combine specialized (16D) + full (512D)
1177
  text_color_class = self.evaluate_classification_performance(
1178
+ text_color_embeddings_spec, text_colors_full,
1179
+ "Text Color Embeddings (Ensemble)", "Color",
1180
+ full_embeddings=text_full_embeddings, ensemble_weight=0.4 # 40% specialized, 60% full
1181
  )
1182
  text_color_metrics.update(text_color_class)
1183
  results['text_color'] = text_color_metrics
1184
 
1185
+ # Image color embeddings with ensemble
1186
+ print("\n🖼️ Extracting specialized image color embeddings (dims 0-15)...")
1187
+ image_color_embeddings_spec = image_full_embeddings[:, :self.color_emb_dim] # First 16 dims
1188
+ print(f" Specialized image color embeddings shape: {image_color_embeddings_spec.shape}")
1189
+ image_color_metrics = self.compute_similarity_metrics(image_color_embeddings_spec, image_colors_full)
 
 
 
1190
  image_color_class = self.evaluate_classification_performance(
1191
+ image_color_embeddings_spec, image_colors_full,
1192
+ "Image Color Embeddings (Ensemble)", "Color",
1193
+ full_embeddings=image_full_embeddings, ensemble_weight=0.4
1194
  )
1195
  image_color_metrics.update(image_color_class)
1196
  results['image_color'] = image_color_metrics
1197
 
1198
+ # ========== HIERARCHY EVALUATION (DIMS 16-79) WITH ENSEMBLE ==========
1199
+ print("\n📋 HIERARCHY EVALUATION (dims 16-79) - Using Ensemble")
 
 
 
1200
  print("=" * 50)
1201
 
1202
+ # Extract specialized hierarchy embeddings (dims 16-79)
1203
+ print("\n📝 Extracting specialized text hierarchy embeddings (dims 16-79)...")
1204
+ text_hierarchy_embeddings_spec = text_full_embeddings[:, self.color_emb_dim:self.color_emb_dim+self.hierarchy_emb_dim] # dims 16-79
1205
+ print(f" Specialized text hierarchy embeddings shape: {text_hierarchy_embeddings_spec.shape}")
1206
+ text_hierarchy_metrics = self.compute_similarity_metrics(text_hierarchy_embeddings_spec, text_hierarchies_full)
1207
+ # Use ensemble: combine specialized (64D) + full (512D)
1208
  text_hierarchy_class = self.evaluate_classification_performance(
1209
+ text_hierarchy_embeddings_spec, text_hierarchies_full,
1210
+ "Text Hierarchy Embeddings (Ensemble)", "Hierarchy",
1211
+ full_embeddings=text_full_embeddings, ensemble_weight=0.4
1212
  )
1213
  text_hierarchy_metrics.update(text_hierarchy_class)
1214
  results['text_hierarchy'] = text_hierarchy_metrics
1215
 
1216
+ # Image hierarchy embeddings with ensemble
1217
+ print("\n🖼️ Extracting specialized image hierarchy embeddings (dims 16-79)...")
1218
+ image_hierarchy_embeddings_spec = image_full_embeddings[:, self.color_emb_dim:self.color_emb_dim+self.hierarchy_emb_dim] # dims 16-79
1219
+ print(f" Specialized image hierarchy embeddings shape: {image_hierarchy_embeddings_spec.shape}")
1220
+ image_hierarchy_metrics = self.compute_similarity_metrics(image_hierarchy_embeddings_spec, image_hierarchies_full)
 
 
 
1221
  image_hierarchy_class = self.evaluate_classification_performance(
1222
+ image_hierarchy_embeddings_spec, image_hierarchies_full,
1223
+ "Image Hierarchy Embeddings (Ensemble)", "Hierarchy",
1224
+ full_embeddings=image_full_embeddings, ensemble_weight=0.4
1225
  )
1226
  image_hierarchy_metrics.update(image_hierarchy_class)
1227
  results['image_hierarchy'] = image_hierarchy_metrics
1228
 
1229
+ # Cleanup
1230
+ del text_full_embeddings, image_full_embeddings
1231
+ del text_color_embeddings_spec, image_color_embeddings_spec
1232
+ del text_hierarchy_embeddings_spec, image_hierarchy_embeddings_spec
1233
  torch.cuda.empty_cache() if torch.cuda.is_available() else None
1234
 
1235
  # ========== SAVE VISUALIZATIONS ==========
 
1846
  'trained': trained_color_img_acc,
1847
  'baseline': baseline_color_img_acc,
1848
  'diff': diff,
1849
+ 'trained_dims': '0-15 (16 dims)',
1850
  'baseline_dims': 'All dimensions (512 dims)'
1851
  })
1852
 
 
1901
  print("\nRaisons probables:")
1902
  print("\n1. 📐 CAPACITÉ DIMENSIONNELLE:")
1903
  print(" • Baseline: Utilise TOUTES les 512 dimensions des embeddings")
1904
+ print(" • Modèle entraîné: Utilise seulement 16 dims (couleur) ou 64 dims (hiérarchie)")
1905
  print(" • Impact: La baseline a accès à plus d'information pour la classification")
1906
 
1907
  print("\n2. 🎯 SUR-SPÉCIALISATION:")
 
1951
  device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
1952
  print(f"Using device: {device}")
1953
 
1954
+ directory = 'main_model_analysis_model'
1955
  max_samples = 10000
1956
 
1957
  evaluator = ColorHierarchyEvaluator(device=device, directory=directory)