Upload evaluation/main_model_evaluation.py with huggingface_hub
Browse files- evaluation/main_model_evaluation.py +175 -53
evaluation/main_model_evaluation.py
CHANGED
|
@@ -19,7 +19,7 @@ import warnings
|
|
| 19 |
warnings.filterwarnings('ignore')
|
| 20 |
from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
|
| 21 |
|
| 22 |
-
from config import main_model_path, hierarchy_model_path, color_emb_dim, hierarchy_emb_dim, local_dataset_path, column_local_image_path
|
| 23 |
|
| 24 |
|
| 25 |
def create_fashion_mnist_to_hierarchy_mapping(hierarchy_classes):
|
|
@@ -176,7 +176,7 @@ class FashionMNISTDataset(Dataset):
|
|
| 176 |
|
| 177 |
def load_fashion_mnist_dataset(max_samples=1000, hierarchy_classes=None):
|
| 178 |
print("📊 Loading Fashion-MNIST test dataset...")
|
| 179 |
-
df = pd.read_csv(
|
| 180 |
print(f"✅ Fashion-MNIST dataset loaded: {len(df)} samples")
|
| 181 |
|
| 182 |
# Create mapping if hierarchy classes are provided
|
|
@@ -600,14 +600,14 @@ class ColorHierarchyEvaluator:
|
|
| 600 |
return sorted(set(hierarchies))
|
| 601 |
|
| 602 |
def extract_color_embeddings(self, dataloader, embedding_type='text', max_samples=10000):
|
| 603 |
-
"""Extract color embeddings from dims 0-16"""
|
| 604 |
all_embeddings = []
|
| 605 |
all_colors = []
|
| 606 |
all_hierarchies = []
|
| 607 |
|
| 608 |
sample_count = 0
|
| 609 |
with torch.no_grad():
|
| 610 |
-
for batch in tqdm(dataloader, desc=f"Extracting {embedding_type} color embeddings (dims 0-
|
| 611 |
if sample_count >= max_samples:
|
| 612 |
break
|
| 613 |
|
|
@@ -627,9 +627,10 @@ class ColorHierarchyEvaluator:
|
|
| 627 |
else:
|
| 628 |
embeddings = outputs.text_embeds
|
| 629 |
|
| 630 |
-
# Extract only color embeddings (dims 0-16)
|
| 631 |
-
color_embeddings = embeddings[:, :self.color_emb_dim]
|
| 632 |
|
|
|
|
| 633 |
all_embeddings.append(color_embeddings.cpu().numpy())
|
| 634 |
all_colors.extend(colors)
|
| 635 |
all_hierarchies.extend(hierarchies)
|
|
@@ -670,8 +671,9 @@ class ColorHierarchyEvaluator:
|
|
| 670 |
embeddings = outputs.text_embeds
|
| 671 |
|
| 672 |
# Extract hierarchy embeddings (dims 17-79 -> indices 16:79)
|
| 673 |
-
hierarchy_embeddings = embeddings[:, 16:79]
|
| 674 |
-
|
|
|
|
| 675 |
all_embeddings.append(hierarchy_embeddings.cpu().numpy())
|
| 676 |
all_colors.extend(colors)
|
| 677 |
all_hierarchies.extend(hierarchies)
|
|
@@ -683,6 +685,46 @@ class ColorHierarchyEvaluator:
|
|
| 683 |
|
| 684 |
return np.vstack(all_embeddings), all_colors, all_hierarchies
|
| 685 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 686 |
def extract_baseline_embeddings_batch(self, dataloader, embedding_type='text', max_samples=10000):
|
| 687 |
"""Extract embeddings from baseline Fashion CLIP model"""
|
| 688 |
all_embeddings = []
|
|
@@ -883,6 +925,52 @@ class ColorHierarchyEvaluator:
|
|
| 883 |
predictions.append(predicted_label)
|
| 884 |
return predictions
|
| 885 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 886 |
def create_confusion_matrix(self, true_labels, predicted_labels, title="Confusion Matrix", label_type="Label"):
|
| 887 |
"""Create and plot confusion matrix"""
|
| 888 |
unique_labels = sorted(list(set(true_labels + predicted_labels)))
|
|
@@ -898,11 +986,34 @@ class ColorHierarchyEvaluator:
|
|
| 898 |
plt.tight_layout()
|
| 899 |
return plt.gcf(), accuracy, cm
|
| 900 |
|
| 901 |
-
def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings", label_type="Label"
|
| 902 |
-
|
| 903 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 904 |
accuracy = accuracy_score(labels, predictions)
|
| 905 |
-
fig, acc, cm = self.create_confusion_matrix(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 906 |
unique_labels = sorted(list(set(labels)))
|
| 907 |
report = classification_report(labels, predictions, labels=unique_labels, target_names=unique_labels, output_dict=True)
|
| 908 |
return {
|
|
@@ -1046,68 +1157,79 @@ class ColorHierarchyEvaluator:
|
|
| 1046 |
|
| 1047 |
results = {}
|
| 1048 |
|
| 1049 |
-
# ==========
|
| 1050 |
-
print("\n
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1051 |
print("=" * 50)
|
| 1052 |
|
| 1053 |
-
#
|
| 1054 |
-
print("\n📝 Extracting text color embeddings...")
|
| 1055 |
-
|
| 1056 |
-
print(f"
|
| 1057 |
-
text_color_metrics = self.compute_similarity_metrics(
|
|
|
|
| 1058 |
text_color_class = self.evaluate_classification_performance(
|
| 1059 |
-
|
|
|
|
|
|
|
| 1060 |
)
|
| 1061 |
text_color_metrics.update(text_color_class)
|
| 1062 |
results['text_color'] = text_color_metrics
|
| 1063 |
|
| 1064 |
-
|
| 1065 |
-
|
| 1066 |
-
|
| 1067 |
-
|
| 1068 |
-
|
| 1069 |
-
image_color_embeddings, image_colors, _ = self.extract_color_embeddings(dataloader, 'image', max_samples)
|
| 1070 |
-
print(f" Image color embeddings shape: {image_color_embeddings.shape}")
|
| 1071 |
-
image_color_metrics = self.compute_similarity_metrics(image_color_embeddings, image_colors)
|
| 1072 |
image_color_class = self.evaluate_classification_performance(
|
| 1073 |
-
|
|
|
|
|
|
|
| 1074 |
)
|
| 1075 |
image_color_metrics.update(image_color_class)
|
| 1076 |
results['image_color'] = image_color_metrics
|
| 1077 |
|
| 1078 |
-
|
| 1079 |
-
|
| 1080 |
-
|
| 1081 |
-
# ========== HIERARCHY EVALUATION (DIMS 16-79) ==========
|
| 1082 |
-
print("\n📋 HIERARCHY EVALUATION (dims 16-79)")
|
| 1083 |
print("=" * 50)
|
| 1084 |
|
| 1085 |
-
#
|
| 1086 |
-
print("\n📝 Extracting text hierarchy embeddings...")
|
| 1087 |
-
|
| 1088 |
-
print(f"
|
| 1089 |
-
text_hierarchy_metrics = self.compute_similarity_metrics(
|
|
|
|
| 1090 |
text_hierarchy_class = self.evaluate_classification_performance(
|
| 1091 |
-
|
|
|
|
|
|
|
| 1092 |
)
|
| 1093 |
text_hierarchy_metrics.update(text_hierarchy_class)
|
| 1094 |
results['text_hierarchy'] = text_hierarchy_metrics
|
| 1095 |
|
| 1096 |
-
|
| 1097 |
-
|
| 1098 |
-
|
| 1099 |
-
|
| 1100 |
-
|
| 1101 |
-
image_hierarchy_embeddings, _, image_hierarchies = self.extract_hierarchy_embeddings(dataloader, 'image', max_samples)
|
| 1102 |
-
print(f" Image hierarchy embeddings shape: {image_hierarchy_embeddings.shape}")
|
| 1103 |
-
image_hierarchy_metrics = self.compute_similarity_metrics(image_hierarchy_embeddings, image_hierarchies)
|
| 1104 |
image_hierarchy_class = self.evaluate_classification_performance(
|
| 1105 |
-
|
|
|
|
|
|
|
| 1106 |
)
|
| 1107 |
image_hierarchy_metrics.update(image_hierarchy_class)
|
| 1108 |
results['image_hierarchy'] = image_hierarchy_metrics
|
| 1109 |
|
| 1110 |
-
|
|
|
|
|
|
|
|
|
|
| 1111 |
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
| 1112 |
|
| 1113 |
# ========== SAVE VISUALIZATIONS ==========
|
|
@@ -1724,7 +1846,7 @@ class ColorHierarchyEvaluator:
|
|
| 1724 |
'trained': trained_color_img_acc,
|
| 1725 |
'baseline': baseline_color_img_acc,
|
| 1726 |
'diff': diff,
|
| 1727 |
-
'trained_dims': '0-
|
| 1728 |
'baseline_dims': 'All dimensions (512 dims)'
|
| 1729 |
})
|
| 1730 |
|
|
@@ -1779,7 +1901,7 @@ class ColorHierarchyEvaluator:
|
|
| 1779 |
print("\nRaisons probables:")
|
| 1780 |
print("\n1. 📐 CAPACITÉ DIMENSIONNELLE:")
|
| 1781 |
print(" • Baseline: Utilise TOUTES les 512 dimensions des embeddings")
|
| 1782 |
-
print(" • Modèle entraîné: Utilise seulement
|
| 1783 |
print(" • Impact: La baseline a accès à plus d'information pour la classification")
|
| 1784 |
|
| 1785 |
print("\n2. 🎯 SUR-SPÉCIALISATION:")
|
|
@@ -1829,7 +1951,7 @@ if __name__ == "__main__":
|
|
| 1829 |
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
| 1830 |
print(f"Using device: {device}")
|
| 1831 |
|
| 1832 |
-
directory = '
|
| 1833 |
max_samples = 10000
|
| 1834 |
|
| 1835 |
evaluator = ColorHierarchyEvaluator(device=device, directory=directory)
|
|
|
|
| 19 |
warnings.filterwarnings('ignore')
|
| 20 |
from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
|
| 21 |
|
| 22 |
+
from config import main_model_path, hierarchy_model_path, color_emb_dim, hierarchy_emb_dim, local_dataset_path, column_local_image_path
|
| 23 |
|
| 24 |
|
| 25 |
def create_fashion_mnist_to_hierarchy_mapping(hierarchy_classes):
|
|
|
|
| 176 |
|
| 177 |
def load_fashion_mnist_dataset(max_samples=1000, hierarchy_classes=None):
|
| 178 |
print("📊 Loading Fashion-MNIST test dataset...")
|
| 179 |
+
df = pd.read_csv("/Users/leaattiasarfati/Desktop/docs/search/old/MainModel/data/fashion-mnist_test.csv")
|
| 180 |
print(f"✅ Fashion-MNIST dataset loaded: {len(df)} samples")
|
| 181 |
|
| 182 |
# Create mapping if hierarchy classes are provided
|
|
|
|
| 600 |
return sorted(set(hierarchies))
|
| 601 |
|
| 602 |
def extract_color_embeddings(self, dataloader, embedding_type='text', max_samples=10000):
|
| 603 |
+
"""Extract color embeddings from dims 0-15 (16 dimensions)"""
|
| 604 |
all_embeddings = []
|
| 605 |
all_colors = []
|
| 606 |
all_hierarchies = []
|
| 607 |
|
| 608 |
sample_count = 0
|
| 609 |
with torch.no_grad():
|
| 610 |
+
for batch in tqdm(dataloader, desc=f"Extracting {embedding_type} color embeddings (dims 0-15)"):
|
| 611 |
if sample_count >= max_samples:
|
| 612 |
break
|
| 613 |
|
|
|
|
| 627 |
else:
|
| 628 |
embeddings = outputs.text_embeds
|
| 629 |
|
| 630 |
+
# Extract only color embeddings (dims 0-15, i.e., first 16 dimensions)
|
| 631 |
+
# color_embeddings = embeddings[:, :self.color_emb_dim]
|
| 632 |
|
| 633 |
+
color_embeddings = embeddings
|
| 634 |
all_embeddings.append(color_embeddings.cpu().numpy())
|
| 635 |
all_colors.extend(colors)
|
| 636 |
all_hierarchies.extend(hierarchies)
|
|
|
|
| 671 |
embeddings = outputs.text_embeds
|
| 672 |
|
| 673 |
# Extract hierarchy embeddings (dims 17-79 -> indices 16:79)
|
| 674 |
+
# hierarchy_embeddings = embeddings[:, 16:79]
|
| 675 |
+
|
| 676 |
+
hierarchy_embeddings = embeddings
|
| 677 |
all_embeddings.append(hierarchy_embeddings.cpu().numpy())
|
| 678 |
all_colors.extend(colors)
|
| 679 |
all_hierarchies.extend(hierarchies)
|
|
|
|
| 685 |
|
| 686 |
return np.vstack(all_embeddings), all_colors, all_hierarchies
|
| 687 |
|
| 688 |
+
def extract_full_embeddings(self, dataloader, embedding_type='text', max_samples=10000):
    """Extract full 512-dimensional embeddings (all dimensions).

    Args:
        dataloader: Yields (images, texts, colors, hierarchies) batches.
        embedding_type: 'text' or 'image'; any other value falls back to
            text embeddings (matches the behavior of the sibling extractors).
        max_samples: Stop iterating once at least this many samples were seen.

    Returns:
        Tuple of (embeddings ndarray of shape [N, 512], colors list,
        hierarchies list).
    """
    all_embeddings = []
    all_colors = []
    all_hierarchies = []

    sample_count = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc=f"Extracting {embedding_type} FULL embeddings (all dims)"):
            if sample_count >= max_samples:
                break

            images, texts, colors, hierarchies = batch
            images = images.to(self.device)
            # Grayscale input: replicate the single channel to 3 for CLIP.
            images = images.expand(-1, 3, -1, -1)

            text_inputs = self.processor(text=texts, padding=True, return_tensors="pt")
            text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}

            outputs = self.model(**text_inputs, pixel_values=images)

            if embedding_type == 'image':
                embeddings = outputs.image_embeds
            else:
                # 'text' and any unrecognized value both use text embeddings
                # (the original if/elif/else had identical text branches).
                embeddings = outputs.text_embeds

            # Use all 512 dimensions
            all_embeddings.append(embeddings.cpu().numpy())
            all_colors.extend(colors)
            all_hierarchies.extend(hierarchies)

            sample_count += len(images)

            del images, text_inputs, outputs, embeddings
            # Idiomatic statement instead of a conditional expression used
            # purely for its side effect.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    return np.vstack(all_embeddings), all_colors, all_hierarchies
|
| 727 |
+
|
| 728 |
def extract_baseline_embeddings_batch(self, dataloader, embedding_type='text', max_samples=10000):
|
| 729 |
"""Extract embeddings from baseline Fashion CLIP model"""
|
| 730 |
all_embeddings = []
|
|
|
|
| 925 |
predictions.append(predicted_label)
|
| 926 |
return predictions
|
| 927 |
|
| 928 |
+
def predict_labels_ensemble(self, specialized_embeddings, full_embeddings, labels,
                            specialized_weight=0.5):
    """
    Ensemble prediction combining specialized (16/64 dims) and full (512 dims) embeddings.

    Args:
        specialized_embeddings: [N, d] embeddings from specialized dimensions
            (e.g., dims 0-15 for color)
        full_embeddings: [N, 512] full embeddings
        labels: True labels, used to compute per-class centroids
        specialized_weight: Weight for specialized embeddings
            (0.0 = only full, 1.0 = only specialized)

    Returns:
        List of predicted labels using the weighted ensemble
    """
    unique_labels = list(set(labels))
    labels_arr = np.asarray(labels)

    # Stack per-class centroids into [K, d] matrices so that similarity
    # can be computed in a single batched call per embedding space.
    specialized_centroids = np.stack(
        [specialized_embeddings[labels_arr == label].mean(axis=0) for label in unique_labels]
    )
    full_centroids = np.stack(
        [full_embeddings[labels_arr == label].mean(axis=0) for label in unique_labels]
    )

    # [N, K] similarity matrices in two vectorized calls — the previous
    # version made one cosine_similarity([...], [...]) call per sample per
    # label (N*K Python-level calls), which was quadratically slow.
    spec_sim = cosine_similarity(specialized_embeddings, specialized_centroids)
    full_sim = cosine_similarity(full_embeddings, full_centroids)

    # Weighted combination of the two similarity spaces.
    combined = specialized_weight * spec_sim + (1 - specialized_weight) * full_sim

    # argmax keeps the FIRST maximum along each row, which matches the
    # original strict ">" comparison over unique_labels in the same order.
    return [unique_labels[i] for i in combined.argmax(axis=1)]
|
| 973 |
+
|
| 974 |
def create_confusion_matrix(self, true_labels, predicted_labels, title="Confusion Matrix", label_type="Label"):
|
| 975 |
"""Create and plot confusion matrix"""
|
| 976 |
unique_labels = sorted(list(set(true_labels + predicted_labels)))
|
|
|
|
| 986 |
plt.tight_layout()
|
| 987 |
return plt.gcf(), accuracy, cm
|
| 988 |
|
| 989 |
+
def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings", label_type="Label",
|
| 990 |
+
full_embeddings=None, ensemble_weight=0.5):
|
| 991 |
+
"""
|
| 992 |
+
Evaluate classification performance and create confusion matrix.
|
| 993 |
+
|
| 994 |
+
Args:
|
| 995 |
+
embeddings: Specialized embeddings (e.g., dims 0-15 for color or dims 16-79 for hierarchy)
|
| 996 |
+
labels: True labels
|
| 997 |
+
embedding_type: Type of embeddings for display
|
| 998 |
+
label_type: Type of labels (Color/Hierarchy)
|
| 999 |
+
full_embeddings: Optional full 512-dim embeddings for ensemble (if None, uses only specialized)
|
| 1000 |
+
ensemble_weight: Weight for specialized embeddings in ensemble (0.0 = only full, 1.0 = only specialized)
|
| 1001 |
+
"""
|
| 1002 |
+
if full_embeddings is not None:
|
| 1003 |
+
# Use ensemble prediction
|
| 1004 |
+
predictions = self.predict_labels_ensemble(embeddings, full_embeddings, labels, ensemble_weight)
|
| 1005 |
+
title_suffix = f" (Ensemble: {ensemble_weight:.1f} specialized + {1-ensemble_weight:.1f} full)"
|
| 1006 |
+
else:
|
| 1007 |
+
# Use only specialized embeddings
|
| 1008 |
+
predictions = self.predict_labels_from_embeddings(embeddings, labels)
|
| 1009 |
+
title_suffix = ""
|
| 1010 |
+
|
| 1011 |
accuracy = accuracy_score(labels, predictions)
|
| 1012 |
+
fig, acc, cm = self.create_confusion_matrix(
|
| 1013 |
+
labels, predictions,
|
| 1014 |
+
f"{embedding_type} - {label_type} Classification{title_suffix}",
|
| 1015 |
+
label_type
|
| 1016 |
+
)
|
| 1017 |
unique_labels = sorted(list(set(labels)))
|
| 1018 |
report = classification_report(labels, predictions, labels=unique_labels, target_names=unique_labels, output_dict=True)
|
| 1019 |
return {
|
|
|
|
| 1157 |
|
| 1158 |
results = {}
|
| 1159 |
|
| 1160 |
+
# ========== EXTRACT FULL EMBEDDINGS FOR ENSEMBLE ==========
|
| 1161 |
+
print("\n📦 Extracting full 512-dimensional embeddings for ensemble...")
|
| 1162 |
+
text_full_embeddings, text_colors_full, text_hierarchies_full = self.extract_full_embeddings(dataloader, 'text', max_samples)
|
| 1163 |
+
image_full_embeddings, image_colors_full, image_hierarchies_full = self.extract_full_embeddings(dataloader, 'image', max_samples)
|
| 1164 |
+
print(f" Text full embeddings shape: {text_full_embeddings.shape}")
|
| 1165 |
+
print(f" Image full embeddings shape: {image_full_embeddings.shape}")
|
| 1166 |
+
|
| 1167 |
+
# ========== COLOR EVALUATION (DIMS 0-15) WITH ENSEMBLE ==========
|
| 1168 |
+
print("\n🎨 COLOR EVALUATION (dims 0-15) - Using Ensemble")
|
| 1169 |
print("=" * 50)
|
| 1170 |
|
| 1171 |
+
# Extract specialized color embeddings (dims 0-15)
|
| 1172 |
+
print("\n📝 Extracting specialized text color embeddings (dims 0-15)...")
|
| 1173 |
+
text_color_embeddings_spec = text_full_embeddings[:, :self.color_emb_dim] # First 16 dims
|
| 1174 |
+
print(f" Specialized text color embeddings shape: {text_color_embeddings_spec.shape}")
|
| 1175 |
+
text_color_metrics = self.compute_similarity_metrics(text_color_embeddings_spec, text_colors_full)
|
| 1176 |
+
# Use ensemble: combine specialized (16D) + full (512D)
|
| 1177 |
text_color_class = self.evaluate_classification_performance(
|
| 1178 |
+
text_color_embeddings_spec, text_colors_full,
|
| 1179 |
+
"Text Color Embeddings (Ensemble)", "Color",
|
| 1180 |
+
full_embeddings=text_full_embeddings, ensemble_weight=0.4 # 40% specialized, 60% full
|
| 1181 |
)
|
| 1182 |
text_color_metrics.update(text_color_class)
|
| 1183 |
results['text_color'] = text_color_metrics
|
| 1184 |
|
| 1185 |
+
# Image color embeddings with ensemble
|
| 1186 |
+
print("\n🖼️ Extracting specialized image color embeddings (dims 0-15)...")
|
| 1187 |
+
image_color_embeddings_spec = image_full_embeddings[:, :self.color_emb_dim] # First 16 dims
|
| 1188 |
+
print(f" Specialized image color embeddings shape: {image_color_embeddings_spec.shape}")
|
| 1189 |
+
image_color_metrics = self.compute_similarity_metrics(image_color_embeddings_spec, image_colors_full)
|
|
|
|
|
|
|
|
|
|
| 1190 |
image_color_class = self.evaluate_classification_performance(
|
| 1191 |
+
image_color_embeddings_spec, image_colors_full,
|
| 1192 |
+
"Image Color Embeddings (Ensemble)", "Color",
|
| 1193 |
+
full_embeddings=image_full_embeddings, ensemble_weight=0.4
|
| 1194 |
)
|
| 1195 |
image_color_metrics.update(image_color_class)
|
| 1196 |
results['image_color'] = image_color_metrics
|
| 1197 |
|
| 1198 |
+
# ========== HIERARCHY EVALUATION (DIMS 16-79) WITH ENSEMBLE ==========
|
| 1199 |
+
print("\n📋 HIERARCHY EVALUATION (dims 16-79) - Using Ensemble")
|
|
|
|
|
|
|
|
|
|
| 1200 |
print("=" * 50)
|
| 1201 |
|
| 1202 |
+
# Extract specialized hierarchy embeddings (dims 16-79)
|
| 1203 |
+
print("\n📝 Extracting specialized text hierarchy embeddings (dims 16-79)...")
|
| 1204 |
+
text_hierarchy_embeddings_spec = text_full_embeddings[:, self.color_emb_dim:self.color_emb_dim+self.hierarchy_emb_dim] # dims 16-79
|
| 1205 |
+
print(f" Specialized text hierarchy embeddings shape: {text_hierarchy_embeddings_spec.shape}")
|
| 1206 |
+
text_hierarchy_metrics = self.compute_similarity_metrics(text_hierarchy_embeddings_spec, text_hierarchies_full)
|
| 1207 |
+
# Use ensemble: combine specialized (64D) + full (512D)
|
| 1208 |
text_hierarchy_class = self.evaluate_classification_performance(
|
| 1209 |
+
text_hierarchy_embeddings_spec, text_hierarchies_full,
|
| 1210 |
+
"Text Hierarchy Embeddings (Ensemble)", "Hierarchy",
|
| 1211 |
+
full_embeddings=text_full_embeddings, ensemble_weight=0.4
|
| 1212 |
)
|
| 1213 |
text_hierarchy_metrics.update(text_hierarchy_class)
|
| 1214 |
results['text_hierarchy'] = text_hierarchy_metrics
|
| 1215 |
|
| 1216 |
+
# Image hierarchy embeddings with ensemble
|
| 1217 |
+
print("\n🖼️ Extracting specialized image hierarchy embeddings (dims 16-79)...")
|
| 1218 |
+
image_hierarchy_embeddings_spec = image_full_embeddings[:, self.color_emb_dim:self.color_emb_dim+self.hierarchy_emb_dim] # dims 16-79
|
| 1219 |
+
print(f" Specialized image hierarchy embeddings shape: {image_hierarchy_embeddings_spec.shape}")
|
| 1220 |
+
image_hierarchy_metrics = self.compute_similarity_metrics(image_hierarchy_embeddings_spec, image_hierarchies_full)
|
|
|
|
|
|
|
|
|
|
| 1221 |
image_hierarchy_class = self.evaluate_classification_performance(
|
| 1222 |
+
image_hierarchy_embeddings_spec, image_hierarchies_full,
|
| 1223 |
+
"Image Hierarchy Embeddings (Ensemble)", "Hierarchy",
|
| 1224 |
+
full_embeddings=image_full_embeddings, ensemble_weight=0.4
|
| 1225 |
)
|
| 1226 |
image_hierarchy_metrics.update(image_hierarchy_class)
|
| 1227 |
results['image_hierarchy'] = image_hierarchy_metrics
|
| 1228 |
|
| 1229 |
+
# Cleanup
|
| 1230 |
+
del text_full_embeddings, image_full_embeddings
|
| 1231 |
+
del text_color_embeddings_spec, image_color_embeddings_spec
|
| 1232 |
+
del text_hierarchy_embeddings_spec, image_hierarchy_embeddings_spec
|
| 1233 |
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
| 1234 |
|
| 1235 |
# ========== SAVE VISUALIZATIONS ==========
|
|
|
|
| 1846 |
'trained': trained_color_img_acc,
|
| 1847 |
'baseline': baseline_color_img_acc,
|
| 1848 |
'diff': diff,
|
| 1849 |
+
'trained_dims': '0-15 (16 dims)',
|
| 1850 |
'baseline_dims': 'All dimensions (512 dims)'
|
| 1851 |
})
|
| 1852 |
|
|
|
|
| 1901 |
print("\nRaisons probables:")
|
| 1902 |
print("\n1. 📐 CAPACITÉ DIMENSIONNELLE:")
|
| 1903 |
print(" • Baseline: Utilise TOUTES les 512 dimensions des embeddings")
|
| 1904 |
+
print(" • Modèle entraîné: Utilise seulement 16 dims (couleur) ou 64 dims (hiérarchie)")
|
| 1905 |
print(" • Impact: La baseline a accès à plus d'information pour la classification")
|
| 1906 |
|
| 1907 |
print("\n2. 🎯 SUR-SPÉCIALISATION:")
|
|
|
|
| 1951 |
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
| 1952 |
print(f"Using device: {device}")
|
| 1953 |
|
| 1954 |
+
directory = 'main_model_analysis_model'
|
| 1955 |
max_samples = 10000
|
| 1956 |
|
| 1957 |
evaluator = ColorHierarchyEvaluator(device=device, directory=directory)
|