|
|
""" |
|
|
Zero-shot classification evaluation on a new dataset. |
|
|
This file evaluates the main model's performance on unseen data by performing |
|
|
zero-shot classification. It compares three methods: color-to-color classification, |
|
|
text-to-text, and image-to-text. It generates confusion matrices and classification reports |
|
|
for each method to analyze the model's generalization capability. |
|
|
""" |
|
|
|
|
|
import os |
|
|
|
|
|
os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
from torch.utils.data import Dataset |
|
|
import matplotlib.pyplot as plt |
|
|
from PIL import Image |
|
|
from torchvision import transforms |
|
|
from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers |
|
|
import warnings |
|
|
import config |
|
|
from tqdm import tqdm |
|
|
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report |
|
|
import seaborn as sns |
|
|
from color_model import CLIPModel as ColorModel |
|
|
from hierarchy_model import Model, HierarchyExtractor |
|
|
|
|
|
|
|
|
warnings.filterwarnings("ignore", category=FutureWarning) |
|
|
warnings.filterwarnings("ignore", category=UserWarning) |
|
|
|
|
|
def load_trained_model(model_path, device,
                       base_model_name='laion/CLIP-ViT-B-32-laion2B-s34B-b79K'):
    """Load the fine-tuned CLIP model from a training checkpoint.

    Args:
        model_path: Path of the checkpoint file read with ``torch.load``.
        device: Device the checkpoint tensors are mapped to and the model is
            moved to.
        base_model_name: Hugging Face identifier of the architecture the
            checkpoint weights are loaded into. Defaults to the identifier
            that used to be hard-coded in this function.

    Returns:
        tuple: ``(model, checkpoint)`` - the model in eval mode on ``device``
        and the raw checkpoint mapping.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
    """
    print(f"Loading trained model from: {model_path}")

    # NOTE(review): torch.load unpickles the file. Unless ``weights_only=True``
    # is in effect this can execute arbitrary code, so only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(model_path, map_location=device)

    model = CLIPModel_transformers.from_pretrained(base_model_name)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    print("✅ Model loaded successfully!")

    # The training metadata is purely informational: a checkpoint that lacks
    # it must not abort the evaluation after the weights loaded fine.
    epoch = checkpoint.get('epoch', 'unknown')
    best_val_loss = checkpoint.get('best_val_loss')
    print(f"π Training epoch: {epoch}")
    if best_val_loss is None:
        print("π Best validation loss: unknown")
    else:
        print(f"π Best validation loss: {best_val_loss:.4f}")

    return model, checkpoint
|
|
|
|
|
def load_feature_models(device):
    """Instantiate the color and hierarchy feature models.

    Both models are built from the checkpoints referenced in ``config``,
    moved to ``device`` and switched to eval mode.

    Args:
        device: Device used as ``map_location`` and as the models' target.

    Returns:
        dict: The two models keyed by their ``name`` attribute
        (``'color'`` and ``'hierarchy'``).
    """
    # Color model: the checkpoint is the bare state dict.
    color_state = torch.load(
        config.color_model_path, map_location=device, weights_only=True
    )
    color_net = ColorModel(embed_dim=config.color_emb_dim).to(device)
    color_net.load_state_dict(color_state)
    color_net.eval()
    color_net.name = 'color'

    # Hierarchy model: the checkpoint bundles the weights ('model_state')
    # with the class list that sizes the model.
    hierarchy_ckpt = torch.load(config.hierarchy_model_path, map_location=device)
    class_names = hierarchy_ckpt.get('hierarchy_classes', [])
    hierarchy_net = Model(
        num_hierarchy_classes=len(class_names),
        embed_dim=config.hierarchy_emb_dim,
    ).to(device)
    hierarchy_net.load_state_dict(hierarchy_ckpt['model_state'])
    hierarchy_net.set_hierarchy_extractor(
        HierarchyExtractor(class_names, verbose=False)
    )
    hierarchy_net.eval()
    hierarchy_net.name = 'hierarchy'

    return {color_net.name: color_net, hierarchy_net.name: hierarchy_net}
|
|
|
|
|
def get_image_embedding(model, image, device):
    """Encode an image tensor into an L2-normalised embedding.

    Args:
        model: Object exposing ``eval()``, ``vision_model`` and
            ``visual_projection``.
        image: Tensor with 3 dims (single image) or 4 dims (batch). When the
            channel axis has size 1 it is expanded to 3 channels.
        device: Device the pixel tensor is moved to before the forward pass.

    Returns:
        The projected pooled vision output, normalised along the last axis.
        A 3-dim input gains a leading batch axis of size 1.
    """
    model.eval()

    with torch.no_grad():
        if image.dim() == 3:
            # Single image: replicate a lone channel, then add the batch axis.
            if image.size(0) == 1:
                image = image.expand(3, -1, -1)
            image = image.unsqueeze(0)
        elif image.dim() == 4 and image.size(1) == 1:
            # Batch of single-channel images.
            image = image.expand(-1, 3, -1, -1)

        pixel_values = image.to(device)
        pooled = model.vision_model(pixel_values=pixel_values).pooler_output
        projected = model.visual_projection(pooled)

        return F.normalize(projected, dim=-1)
|
|
|
|
|
def get_text_embedding(model, text, processor, device=None):
    """Encode text into an L2-normalised embedding.

    Args:
        model: Model exposing ``eval()``, ``text_model`` and
            ``text_projection``.
        text: A string or list of strings handed to ``processor``.
        processor: Callable tokenizer/processor returning a mapping of
            tensors when called with ``return_tensors="pt"``.
        device: Device the token tensors are moved to. Optional because the
            callers in this file invoke the function with three arguments;
            when omitted, the device of the model's first parameter is used
            (CPU if the model has no parameters).

    Returns:
        The projected pooled text output, normalised along the last axis.
    """
    model.eval()

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    with torch.no_grad():
        # truncation=True keeps over-long descriptions within the tokenizer's
        # maximum length instead of overflowing the text encoder's position
        # embeddings.
        text_inputs = processor(
            text=text, padding=True, truncation=True, return_tensors="pt"
        )
        text_inputs = {key: value.to(device) for key, value in text_inputs.items()}

        text_outputs = model.text_model(**text_inputs)
        text_features = model.text_projection(text_outputs.pooler_output)

        return F.normalize(text_features, dim=-1)
|
|
|
|
|
def evaluate_custom_csv_accuracy(model, dataset, processor, method='similarity'):
    """Zero-shot color classification from each sample's full text.

    Every distinct color label of ``dataset`` is embedded once as text; each
    sample's text is then embedded and assigned the color whose embedding has
    the highest cosine similarity.

    Args:
        model: The trained CLIP model.
        dataset: Indexable dataset whose items are ``(image, text, color)``.
        processor: CLIPProcessor used to tokenize text.
        method: Kept for backward compatibility; this function does not read
            it.

    Returns:
        tuple: ``(true_labels, predicted_labels, accuracy)``.

    Raises:
        ValueError: If ``dataset`` is empty.
    """
    print("\nπ === Evaluation of the accuracy on custom CSV (TEXT-TO-TEXT method) ===")

    model.eval()
    # get_text_embedding needs the target device; take it from the model so
    # tokens and weights are guaranteed to live on the same one.
    device = next(model.parameters()).device

    num_samples = len(dataset)
    if num_samples == 0:
        raise ValueError("Cannot evaluate an empty dataset")

    color_list = sorted({dataset[i][2] for i in range(num_samples)})
    print(f"π¨ Colors found: {color_list}")

    print("π Pre-calculating the embeddings of the colors...")
    # One row per color, in color_list order, so the argmax index below maps
    # straight back to a label.
    color_matrix = torch.cat(
        [get_text_embedding(model, color, processor, device) for color in color_list],
        dim=0,
    )

    print("π Evaluation in progress...")
    true_labels = []
    predicted_labels = []

    for idx in tqdm(range(num_samples), desc="Evaluation"):
        _, text, true_color = dataset[idx]

        text_emb = get_text_embedding(model, text, processor, device)
        # The single text row broadcasts against all color rows at once.
        similarities = F.cosine_similarity(text_emb, color_matrix, dim=1)
        predicted_color = color_list[int(similarities.argmax())]

        true_labels.append(true_color)
        predicted_labels.append(predicted_color)

    correct_predictions = sum(
        truth == guess for truth, guess in zip(true_labels, predicted_labels)
    )
    accuracy = accuracy_score(true_labels, predicted_labels)

    print("\n✅ Results of evaluation:")
    print(f"π― Global accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"π Correct predictions: {correct_predictions}/{len(true_labels)}")

    return true_labels, predicted_labels, accuracy
|
|
|
|
|
def evaluate_custom_csv_accuracy_image(model, dataset, processor, method='similarity'):
    """Zero-shot color classification from each sample's image.

    Every distinct color label of ``dataset`` is embedded once as text; each
    sample's image is then embedded and assigned the color whose text
    embedding has the highest cosine similarity.

    Args:
        model: The trained CLIP model.
        dataset: Indexable dataset whose items are ``(image, text, color)``
            with ``image`` already a tensor.
        processor: CLIPProcessor used to tokenize the color names.
        method: Kept for backward compatibility; this function does not read
            it.

    Returns:
        tuple: ``(true_labels, predicted_labels, accuracy)``.

    Raises:
        ValueError: If ``dataset`` is empty.
    """
    print("\nπ === Evaluation of the accuracy on custom CSV (IMAGE-TO-TEXT method) ===")

    model.eval()
    # Both embedding helpers take a device; the processor is NOT a device, so
    # derive the real one from the model.
    device = next(model.parameters()).device

    num_samples = len(dataset)
    if num_samples == 0:
        raise ValueError("Cannot evaluate an empty dataset")

    # NOTE: indexing the dataset here loads each item once just to read its
    # label; the evaluation loop below loads it a second time.
    color_list = sorted({dataset[i][2] for i in range(num_samples)})
    print(f"π¨ Colors found: {color_list}")

    print("π Pre-calculating the embeddings of the colors...")
    # One row per color, in color_list order, so the argmax index below maps
    # straight back to a label.
    color_matrix = torch.cat(
        [get_text_embedding(model, color, processor, device) for color in color_list],
        dim=0,
    )

    print("π Evaluation in progress...")
    true_labels = []
    predicted_labels = []

    for idx in tqdm(range(num_samples), desc="Evaluation"):
        image, _, true_color = dataset[idx]

        image_emb = get_image_embedding(model, image, device)
        # The single image row broadcasts against all color rows at once.
        similarities = F.cosine_similarity(image_emb, color_matrix, dim=1)
        predicted_color = color_list[int(similarities.argmax())]

        true_labels.append(true_color)
        predicted_labels.append(predicted_color)

    correct_predictions = sum(
        truth == guess for truth, guess in zip(true_labels, predicted_labels)
    )
    accuracy = accuracy_score(true_labels, predicted_labels)

    print("\n✅ Results of evaluation:")
    print(f"π― Global accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"π Correct predictions: {correct_predictions}/{len(true_labels)}")

    return true_labels, predicted_labels, accuracy
|
|
|
|
|
def evaluate_custom_csv_accuracy_color_only(model, dataset, processor):
    """Control test: classify each sample from its color name alone.

    Only the color label is encoded (never the full text), and it is matched
    against the embeddings of all distinct color labels. This checks that the
    text embedding space maps every color name back to itself.

    Args:
        model: The trained CLIP model.
        dataset: Indexable dataset whose items are ``(image, text, color)``.
        processor: CLIPProcessor used to tokenize the color names.

    Returns:
        tuple: ``(true_labels, predicted_labels, accuracy)``.

    Raises:
        ValueError: If ``dataset`` is empty.
    """
    print("\nπ === Evaluation of the accuracy on custom CSV (COLOR-TO-COLOR method) ===")
    print("π¬ This test encodes ONLY the color name, not the full text")

    model.eval()
    # get_text_embedding needs the target device; take it from the model.
    device = next(model.parameters()).device

    true_labels = [dataset[i][2] for i in range(len(dataset))]
    if not true_labels:
        raise ValueError("Cannot evaluate an empty dataset")

    color_list = sorted(set(true_labels))
    print(f"π¨ Colors found: {color_list}")

    print("π Pre-calculating the embeddings of the colors...")
    color_matrix = torch.cat(
        [get_text_embedding(model, color, processor, device) for color in color_list],
        dim=0,
    )

    # A sample's query is its own label, so the prediction depends on the
    # label only: resolve it once per color instead of re-encoding the same
    # string for every row of the dataset.
    nearest_color = {}
    for row, color in enumerate(color_list):
        similarities = F.cosine_similarity(
            color_matrix[row:row + 1], color_matrix, dim=1
        )
        nearest_color[color] = color_list[int(similarities.argmax())]

    print("π Evaluation in progress...")
    predicted_labels = [
        nearest_color[true_color]
        for true_color in tqdm(true_labels, desc="Evaluation")
    ]

    correct_predictions = sum(
        truth == guess for truth, guess in zip(true_labels, predicted_labels)
    )
    accuracy = accuracy_score(true_labels, predicted_labels)

    print("\n✅ Results of evaluation:")
    print(f"π― Global accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"π Correct predictions: {correct_predictions}/{len(true_labels)}")

    return true_labels, predicted_labels, accuracy
|
|
|
|
|
def search_custom_csv_by_text(model, dataset, query, processor, top_k=5):
    """Rank the dataset's images by similarity to a text query.

    Args:
        model: The trained CLIP model.
        dataset: Indexable dataset. Items must start with
            ``(image, text, color)``; if further fields follow, the last one
            is reported as the image path.
        query: Text to search for.
        processor: CLIPProcessor used to tokenize the query.
        top_k: Number of best matches to return.

    Returns:
        list: Up to ``top_k`` tuples
        ``(idx, similarity, text, color, color, image_path)`` sorted by
        decreasing similarity. The color appears twice to keep the tuple
        layout callers already receive; ``image_path`` is ``None`` when the
        dataset item carries no path.
    """
    print(f"\nπ Search in custom CSV: '{query}'")

    # Both embedding helpers take a device; the processor is NOT a device, so
    # derive the real one from the model.
    device = next(model.parameters()).device

    query_emb = get_text_embedding(model, query, processor, device)

    similarities = []

    print("π Calculating similarities...")
    for idx in tqdm(range(len(dataset)), desc="Processing"):
        # Accept both the 3-field items of CustomCSVDataset and longer items
        # whose last field is the image path.
        image, text, color, *extra = dataset[idx]
        image_path = extra[-1] if extra else None

        image_emb = get_image_embedding(model, image, device)
        similarity = F.cosine_similarity(query_emb, image_emb, dim=1).item()

        similarities.append((idx, similarity, text, color, color, image_path))

    similarities.sort(key=lambda entry: entry[1], reverse=True)

    return similarities[:top_k]
|
|
|
|
|
def plot_confusion_matrix(true_labels, predicted_labels, save_path=None, title_suffix="text"):
    """Plot a row-normalised confusion matrix, optionally saving it.

    Args:
        true_labels: Sequence of ground-truth labels.
        predicted_labels: Sequence of predicted labels, aligned with
            ``true_labels``.
        save_path: If truthy, the figure is written there at 300 dpi.
        title_suffix: Name of the evaluated method, inserted in the title.

    Returns:
        The raw (count) confusion matrix, rows = true labels and columns =
        predictions, both in sorted label order.
    """
    print("\nπ === Generation of the confusion matrix ===")

    # Build the label order once and hand it to sklearn so the matrix rows
    # and the tick labels are guaranteed to line up. Using set union (rather
    # than ``+``) also works when the inputs are not plain lists.
    unique_labels = sorted(set(true_labels) | set(predicted_labels))
    cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels)

    accuracy = accuracy_score(true_labels, predicted_labels)

    # Row-wise percentages. A label that only occurs in the predictions has an
    # all-zero row; leave it at 0% instead of dividing by zero (which would
    # turn into garbage once cast to int).
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_percent = np.divide(
        cm.astype('float'),
        row_totals,
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals != 0,
    ) * 100
    cm_percent = np.around(cm_percent).astype(int)

    fig = plt.figure(figsize=(12, 10))

    sns.heatmap(cm_percent,
                annot=True,
                fmt='d',
                cmap='Blues',
                cbar_kws={'label': 'Percentage (%)'},
                xticklabels=unique_labels,
                yticklabels=unique_labels)

    plt.title(f"Confusion Matrix for {title_suffix} - new data - accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)", fontsize=16)
    plt.xlabel('Predictions', fontsize=12)
    plt.ylabel('True colors', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"πΎ Confusion matrix saved: {save_path}")

    plt.show()
    # Release the figure: this function is called once per method and open
    # figures would otherwise accumulate on non-interactive backends.
    plt.close(fig)

    return cm
|
|
|
|
|
class CustomCSVDataset(Dataset):
    """Dataset over a DataFrame yielding ``(image, text, color)`` triples.

    Column names are taken from ``config`` (``text_column``,
    ``color_column``, ``column_local_image_path``).
    """

    def __init__(self, dataframe, image_size=224, load_images=True):
        """Store the frame and build the image preprocessing pipeline.

        Args:
            dataframe: Source rows; accessed positionally with ``iloc``.
            image_size: Side length images are resized to (square).
            load_images: When False no file is opened and every item carries
                an all-zero placeholder image.
        """
        self.dataframe = dataframe
        self.image_size = image_size
        self.load_images = load_images

        # presumably the CLIP preprocessing mean/std - TODO confirm against
        # the processor configuration of the model being evaluated.
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                 std=[0.26862954, 0.26130258, 0.27577711])
        ])

    def __len__(self):
        """Return the number of rows in the underlying frame."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return ``(image, text, colors)`` for the row at position ``idx``.

        ``image`` is the transformed RGB tensor, or a ``3 x image_size x
        image_size`` zero tensor when images are disabled, the path column is
        absent, or the file could not be loaded (a warning is printed).
        """
        row = self.dataframe.iloc[idx]
        text = row[config.text_column]
        colors = row[config.color_column]

        image = None
        if self.load_images and config.column_local_image_path in row:
            try:
                # Image.open keeps the file handle open until the image is
                # closed; the context manager releases it as soon as the
                # pixels have been converted.
                with Image.open(row[config.column_local_image_path]) as raw_image:
                    image = self.transform(raw_image.convert('RGB'))
            except Exception as e:
                # Deliberately broad: one unreadable file must not abort a
                # whole evaluation run; the sample falls back to the
                # placeholder below.
                print(f"Warning: Could not load image {row.get(config.column_local_image_path, 'unknown')}: {e}")

        if image is None:
            image = torch.zeros(3, self.image_size, self.image_size)

        return image, text, colors
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the color-only, text and image zero-shot
    # evaluations on the new dataset, plot one confusion matrix per method
    # and print a comparison of the three accuracies.
    section_rule = "=" * 80

    print("π === Test and Evaluation of the model on new dataset ===")

    # Fine-tuned model and the matching processor.
    print("π§ Loading the model...")
    model, checkpoint = load_trained_model(config.main_model_path, config.device)

    processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')

    print("π Loading the new dataset...")
    df = pd.read_csv(config.local_dataset_path)

    # 1) Control test: color name against color names (no images needed).
    print(f"\n{section_rule}")
    print("π¨ COLOR-TO-COLOR CLASSIFICATION (Control Test)")
    print(section_rule)

    dataset_color = CustomCSVDataset(df, load_images=False)
    true_labels_color, predicted_labels_color, accuracy_color = (
        evaluate_custom_csv_accuracy_color_only(model, dataset_color, processor)
    )
    confusion_matrix_color = plot_confusion_matrix(
        true_labels_color,
        predicted_labels_color,
        save_path="confusion_matrix_color_only.png",
        title_suffix="color-only",
    )

    # 2) Full text against color names (no images needed).
    print(f"\n{section_rule}")
    print("π TEXT-TO-TEXT CLASSIFICATION")
    print(section_rule)

    dataset_text = CustomCSVDataset(df, load_images=False)
    true_labels_text, predicted_labels_text, accuracy_text = (
        evaluate_custom_csv_accuracy(
            model, dataset_text, processor, method='similarity'
        )
    )
    confusion_matrix_text = plot_confusion_matrix(
        true_labels_text,
        predicted_labels_text,
        save_path="confusion_matrix_text.png",
        title_suffix="text",
    )

    # 3) Image against color names (images are loaded from disk).
    print(f"\n{section_rule}")
    print("πΌοΈ IMAGE-TO-TEXT CLASSIFICATION")
    print(section_rule)

    dataset_image = CustomCSVDataset(df, load_images=True)
    true_labels_image, predicted_labels_image, accuracy_image = (
        evaluate_custom_csv_accuracy_image(
            model, dataset_image, processor, method='similarity'
        )
    )
    confusion_matrix_image = plot_confusion_matrix(
        true_labels_image,
        predicted_labels_image,
        save_path="confusion_matrix_image.png",
        title_suffix="image",
    )

    # Side-by-side accuracies plus the absolute gaps between methods.
    color_text_gap = abs(accuracy_color - accuracy_text)
    text_image_gap = abs(accuracy_text - accuracy_image)

    print(f"\n{section_rule}")
    print("π SUMMARY")
    print(section_rule)
    print(f"π¨ Color-to-Color Accuracy (Control): {accuracy_color:.4f} ({accuracy_color*100:.2f}%)")
    print(f"π Text-to-Text Accuracy: {accuracy_text:.4f} ({accuracy_text*100:.2f}%)")
    print(f"πΌοΈ Image-to-Text Accuracy: {accuracy_image:.4f} ({accuracy_image*100:.2f}%)")
    print("\nπ Analysis:")
    print(f" β’ Loss from full text vs color-only: {color_text_gap:.4f} ({color_text_gap*100:.2f}%)")
    print(f" β’ Difference text vs image: {text_image_gap:.4f} ({text_image_gap*100:.2f}%)")
|
|
|