gap-clip / evaluation /basic_test_generalized.py
Leacb4's picture
Upload evaluation/basic_test_generalized.py with huggingface_hub
8b62bf6 verified
"""
Generalized evaluation of the main model with sub-module comparison.
This file evaluates the main model's performance by comparing specialized parts
(color and hierarchy) with corresponding specialized models. It calculates similarity
matrices, linear projections between embedding spaces, and generates detailed statistics
on alignment between different representations.
"""
import os
import json
import argparse
import config
import torch
import torch.nn.functional as F
import pandas as pd
from PIL import Image
from torchvision import transforms
from transformers import CLIPProcessor, CLIPModel as CLIPModelTransformers
from tqdm.auto import tqdm
# Local imports
from color_model import ColorCLIP as ColorModel, ColorDataset, Tokenizer
from config import color_model_path, color_emb_dim, device, hierarchy_model_path, hierarchy_emb_dim
from hierarchy_model import Model as HierarchyModel, HierarchyExtractor
def load_color_model(color_model_path, color_emb_dim, device):
    """Build the standalone color model, load its weights and attach a tokenizer.

    Args:
        color_model_path: Path to the color model state dict (loaded with
            ``weights_only=True``).
        color_emb_dim: Embedding size passed to ``ColorModel``.
        device: Target device for the weights and the model.

    Returns:
        The ``ColorModel`` in eval mode, with a ``Tokenizer`` instance stored on
        its ``tokenizer`` attribute.
    """
    state_dict = torch.load(color_model_path, map_location=device, weights_only=True)
    model = ColorModel(vocab_size=39, embedding_dim=color_emb_dim).to(device)
    model.load_state_dict(state_dict)

    model_tokenizer = Tokenizer()
    # NOTE(review): the vocabulary file is parsed but the result is never handed
    # to the tokenizer. Either Tokenizer() already knows its vocabulary, or the
    # vocab should be loaded into it here — confirm against color_model.Tokenizer.
    with open(config.tokeniser_path, 'r') as vocab_file:
        _vocab = json.load(vocab_file)
    model.tokenizer = model_tokenizer

    model.eval()
    return model
def get_emb_color_model(color_model, image_path_to_encode, text_to_encode):
    """Encode one image/text pair with the standalone color model.

    The image is resized to 224x224 and normalized with ImageNet mean/std. The
    text is tokenized with ``color_model.tokenizer`` into a single-row batch.

    Args:
        color_model: Model exposing ``image_encoder``, ``text_encoder`` and
            ``tokenizer`` (as set up by ``load_color_model``).
        image_path_to_encode: Path of the image file to encode.
        text_to_encode: Text to encode.

    Returns:
        ``(image_emb, txt_emb)``, both computed on the module-level ``device``.
    """
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    pil_image = Image.open(image_path_to_encode).convert('RGB')
    pixel_batch = preprocess(pil_image).unsqueeze(0).to(device)  # [1, 3, 224, 224]

    # Wrap the id list in an outer list so the text encoder receives a batch of one.
    ids = torch.tensor([color_model.tokenizer(text_to_encode)], dtype=torch.long, device=device)
    seq_len = ids.size(1) if ids.dim() > 1 else ids.size(0)
    lengths = torch.tensor([seq_len], dtype=torch.long, device=device)

    with torch.no_grad():
        image_emb = color_model.image_encoder(pixel_batch)
        txt_emb = color_model.text_encoder(ids, lengths)
    return image_emb, txt_emb
def load_main_model(main_model_path, device, model_name='laion/CLIP-ViT-B-32-laion2B-s34B-b79K'):
    """Load the fine-tuned main CLIP model and its processor.

    A pretrained CLIP backbone (``model_name``) is instantiated first. The
    checkpoint weights are then loaded on top of it non-strictly. If that load
    raises (e.g. a tensor shape mismatch), only the checkpoint entries whose name
    and shape both match the backbone are loaded.

    Args:
        main_model_path: Path to a ``torch.save`` checkpoint. It is either a raw
            state dict or a dict holding one under ``'model_state_dict'``.
        device: Device used as ``map_location`` and to host the model.
        model_name: Hugging Face id of the CLIP backbone and processor.

    Returns:
        ``(main_model, processor)``, with the model in eval mode.
    """
    # NOTE: torch.load may unpickle arbitrary objects; only load trusted checkpoints.
    checkpoint = torch.load(main_model_path, map_location=device)
    # transformers' CLIPModel is imported under the alias CLIPModelTransformers.
    main_model = CLIPModelTransformers.from_pretrained(model_name)
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state = checkpoint['model_state_dict']
    else:
        state = checkpoint
    try:
        main_model.load_state_dict(state, strict=False)
    except RuntimeError:
        # strict=False still raises on shape mismatches; keep only the entries
        # whose name and shape match the backbone.
        model_state = main_model.state_dict()
        filtered = {
            k: v for k, v in state.items()
            if k in model_state and model_state[k].shape == v.shape
        }
        skipped = len(state) - len(filtered)
        if skipped:
            print(f"Warning: skipped {skipped} checkpoint tensors with unknown names or mismatched shapes")
        main_model.load_state_dict(filtered, strict=False)
    main_model.to(device)
    main_model.eval()
    processor = CLIPProcessor.from_pretrained(model_name)
    return main_model, processor
def load_hierarchy_model(hierarchy_model_path, device):
    """Restore the standalone hierarchy model from a checkpoint.

    Args:
        hierarchy_model_path: Checkpoint containing ``'model_state'`` and,
            optionally, ``'hierarchy_classes'`` (an empty list is used if absent).
        device: Device used as ``map_location`` and to host the model.

    Returns:
        The ``HierarchyModel`` in eval mode, with a ``HierarchyExtractor`` built
        from the checkpoint's class list attached.
    """
    ckpt = torch.load(hierarchy_model_path, map_location=device)
    class_names = ckpt.get('hierarchy_classes', [])
    hierarchy_model = HierarchyModel(
        num_hierarchy_classes=len(class_names),
        embed_dim=config.hierarchy_emb_dim,
    ).to(device)
    hierarchy_model.load_state_dict(ckpt['model_state'])
    hierarchy_model.set_hierarchy_extractor(HierarchyExtractor(class_names, verbose=False))
    hierarchy_model.eval()
    return hierarchy_model
def get_emb_hierarchy_model(hierarchy_model, image_path_to_encode, text_to_encode):
    """Encode one image/text pair with the standalone hierarchy model.

    The image is resized to 224x224 and converted to a tensor without mean/std
    normalization. NOTE(review): presumably this matches how the hierarchy model
    was trained — confirm.

    Args:
        hierarchy_model: Model exposing ``get_image_embeddings`` and
            ``get_text_embeddings``.
        image_path_to_encode: Path of the image file to encode.
        text_to_encode: Text passed directly to ``get_text_embeddings``.

    Returns:
        ``(img_emb, txt_emb)``.
    """
    to_tensor_224 = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    rgb = Image.open(image_path_to_encode).convert('RGB')
    batch = to_tensor_224(rgb).unsqueeze(0).to(device)
    with torch.no_grad():
        img_emb = hierarchy_model.get_image_embeddings(batch)
        txt_emb = hierarchy_model.get_text_embeddings(text_to_encode)
    return img_emb, txt_emb
def get_emb_main_model(main_model, processor, image_path_to_encode, text_to_encode):
    """Encode one image/text pair with the main CLIP model.

    Args:
        main_model: The CLIP model returned by ``load_main_model``.
        processor: The matching ``CLIPProcessor``; only its text side is used here.
        image_path_to_encode: Path of the image file to encode.
        text_to_encode: Text to encode.

    Returns:
        ``(text_emb, image_emb)`` from ``outputs.text_embeds`` and
        ``outputs.image_embeds``. The order is text first, unlike
        ``get_emb_color_model`` and ``get_emb_hierarchy_model``, which return the
        image embedding first.
    """
    image = Image.open(image_path_to_encode).convert('RGB')
    # NOTE(review): images are normalized with ImageNet mean/std rather than the
    # processor's own CLIP statistics; presumably this matches how the main model
    # was fine-tuned — confirm against the training code.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    pixel_values = transform(image).unsqueeze(0).to(device)
    # Tokenize the text with the CLIP processor.
    text_inputs = processor(text=[text_to_encode], return_tensors="pt", padding=True)
    text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
    # Inference only: without no_grad the forward pass records an autograd graph
    # for the whole CLIP model on every call.
    with torch.no_grad():
        outputs = main_model(**text_inputs, pixel_values=pixel_values)
    return outputs.text_embeds, outputs.image_embeds
if __name__ == '__main__':
    # Per-example similarity metrics; also the valid values for --primary-metric.
    metrics = ['sim_txt_color_part', 'sim_img_color_part', 'sim_color_txt_img', 'sim_small_txt_img',
               'sim_txt_hierarchy_part', 'sim_img_hierarchy_part']

    parser = argparse.ArgumentParser(description='Evaluate main model parts vs small models and build similarity matrices')
    parser.add_argument('--main-checkpoint', type=str, default='models/laion_explicable_model.pth')
    parser.add_argument('--color-checkpoint', type=str, default='models/color_model.pt')
    parser.add_argument('--hierarchy-checkpoint', type=str, default=hierarchy_model_path)
    parser.add_argument('--csv', type=str, default='data/data_with_local_paths.csv')
    parser.add_argument('--color-emb-dim', type=int, default=16)
    parser.add_argument('--num-samples', type=int, default=200)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--primary-metric', type=str, default='sim_color_txt_img', choices=metrics)
    parser.add_argument('--top-k', type=int, default=30)
    parser.add_argument('--heatmap', action='store_true')
    parser.add_argument('--l2-grid', type=str, default='1e-5,1e-4,1e-3,1e-2,1e-1')
    parser.add_argument('--device', type=str, default=None,
                        help='Torch device to use; defaults to MPS when available, otherwise config.device')
    args = parser.parse_args()

    main_checkpoint = args.main_checkpoint
    color_checkpoint = args.color_checkpoint
    csv_path = args.csv
    color_emb_dim = args.color_emb_dim
    num_samples = args.num_samples
    seed = args.seed
    primary_metric = args.primary_metric
    top_k = args.top_k
    l2_grid = [float(x) for x in args.l2_grid.split(',') if x.strip()]
    if not l2_grid:
        parser.error('--l2-grid must contain at least one value')

    # This runs at module level, so it rebinds the global `device` that the
    # get_emb_* helpers read.
    if args.device:
        device = torch.device(args.device)
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = config.device

    hier_col = config.hierarchy_column
    color_col = config.color_column
    output_dir = 'evaluation_outputs'

    df = pd.read_csv(csv_path)

    # Normalize colors to reduce aliasing and sparsity.
    color_aliases = {
        'grey': 'gray',
        'navy blue': 'navy',
        'light blue': 'blue',
        'dark blue': 'blue',
        'light grey': 'gray',
        'dark grey': 'gray',
        'light gray': 'gray',
        'dark gray': 'gray',
    }

    def normalize_color(c):
        """Lower-case/strip a color label and map known aliases; NaN passes through."""
        if pd.isna(c):
            return c
        s = str(c).strip().lower()
        return color_aliases.get(s, s)

    if color_col in df.columns:
        df[color_col] = df[color_col].apply(normalize_color)

    color_model = load_color_model(color_checkpoint, color_emb_dim, device)
    main_model, processor = load_main_model(main_checkpoint, device)
    hierarchy_model = load_hierarchy_model(args.hierarchy_checkpoint, device)

    results = []
    # Accumulators for projection fitting (A: main-model part, B: small model).
    color_txt_As, color_txt_Bs = [], []
    color_img_As, color_img_Bs = [], []
    hier_txt_As, hier_txt_Bs = [], []
    hier_img_As, hier_img_Bs = [], []

    # Enable pandas copy-on-write semantics for the filtering below.
    pd.options.mode.copy_on_write = True
    # Seeds torch's RNG, which the CV split in fit_ridge_with_cv uses.
    torch.manual_seed(seed)

    unique_hiers = sorted(df[hier_col].dropna().unique())
    unique_colors = sorted(df[color_col].dropna().unique())
    # The hierarchy part of the main embedding follows the color part.
    hier_slice = slice(color_emb_dim, color_emb_dim + hierarchy_emb_dim)

    pair_pbar = tqdm(total=len(unique_hiers) * len(unique_colors), desc="Evaluating pairs", leave=False)
    for hierarchy in unique_hiers:
        for color in unique_colors:
            group = df[(df[hier_col] == hierarchy) & (df[color_col] == color)]
            # Sample up to num_samples rows per (hierarchy, color) pair.
            if len(group) > num_samples:
                group = group.sample(n=num_samples, random_state=seed)
            inner_pbar = tqdm(total=len(group), desc=f"{hierarchy}/{color}", leave=False)
            for row_idx, (_, example) in enumerate(group.iterrows()):
                try:
                    image_path = example['local_image_path']
                    text = example['text']
                    image_emb, txt_emb = get_emb_color_model(color_model, image_path, text)
                    image_emb_hier, txt_emb_hier = get_emb_hierarchy_model(hierarchy_model, image_path, text)
                    text_emb_main_model, image_emb_main_model = get_emb_main_model(
                        main_model, processor, image_path, text
                    )
                    # Slice the main embeddings into color/hierarchy parts and
                    # L2-normalize everything for stable cosine similarities.
                    color_part_txt = F.normalize(text_emb_main_model[:, :color_emb_dim], dim=1)
                    color_part_img = F.normalize(image_emb_main_model[:, :color_emb_dim], dim=1)
                    hier_part_txt = F.normalize(text_emb_main_model[:, hier_slice], dim=1)
                    hier_part_img = F.normalize(image_emb_main_model[:, hier_slice], dim=1)
                    txt_emb = F.normalize(txt_emb, dim=1)
                    image_emb = F.normalize(image_emb, dim=1)
                    txt_emb_hier = F.normalize(txt_emb_hier, dim=1)
                    image_emb_hier = F.normalize(image_emb_hier, dim=1)

                    sims = {
                        'sim_txt_color_part': F.cosine_similarity(txt_emb, color_part_txt).item(),
                        'sim_img_color_part': F.cosine_similarity(image_emb, color_part_img).item(),
                        'sim_color_txt_img': F.cosine_similarity(color_part_txt, color_part_img).item(),
                        'sim_small_txt_img': F.cosine_similarity(txt_emb, image_emb).item(),
                        'sim_txt_hierarchy_part': F.cosine_similarity(txt_emb_hier, hier_part_txt).item(),
                        'sim_img_hierarchy_part': F.cosine_similarity(image_emb_hier, hier_part_img).item(),
                    }

                    color_txt_As.append(color_part_txt.squeeze(0).detach().cpu())
                    color_txt_Bs.append(txt_emb.squeeze(0).detach().cpu())
                    color_img_As.append(color_part_img.squeeze(0).detach().cpu())
                    color_img_Bs.append(image_emb.squeeze(0).detach().cpu())
                    hier_txt_As.append(hier_part_txt.squeeze(0).detach().cpu())
                    hier_txt_Bs.append(txt_emb_hier.squeeze(0).detach().cpu())
                    hier_img_As.append(hier_part_img.squeeze(0).detach().cpu())
                    hier_img_Bs.append(image_emb_hier.squeeze(0).detach().cpu())

                    # Key rows by the config column names so the groupby/pivot
                    # calls below find them.
                    results.append({
                        hier_col: hierarchy,
                        color_col: color,
                        'row_index': int(row_idx),
                        **{name: float(value) for name, value in sims.items()},
                    })
                except Exception as e:
                    # Best-effort: a bad image/row should not abort the whole run.
                    print(f"Skipping example due to error: {e}")
                finally:
                    inner_pbar.update(1)
            inner_pbar.close()
            pair_pbar.update(1)
    pair_pbar.close()

    if not results:
        raise SystemExit('No examples were evaluated successfully; nothing to summarize.')

    results_df = pd.DataFrame(results)
    os.makedirs(output_dir, exist_ok=True)
    raw_path = os.path.join(output_dir, 'similarities_raw.csv')
    results_df.to_csv(raw_path, index=False)
    print(f"Saved raw similarities to {raw_path}")

    # Summary means: overall, per hierarchy, per color and per (hierarchy, color).
    overall_means = results_df[metrics].mean().to_frame(name='mean').T
    overall_means.insert(0, 'level', 'overall')
    by_hierarchy = results_df.groupby(hier_col)[metrics].mean().reset_index()
    by_hierarchy.insert(0, 'level', hier_col)
    by_color = results_df.groupby(color_col)[metrics].mean().reset_index()
    by_color.insert(0, 'level', color_col)
    by_pair = results_df.groupby([hier_col, color_col])[metrics].mean().reset_index()
    by_pair.insert(0, 'level', 'hierarchy_color')
    summary_df = pd.concat([overall_means, by_hierarchy, by_color, by_pair], ignore_index=True)
    summary_path = os.path.join(output_dir, 'similarities_summary.csv')
    summary_df.to_csv(summary_path, index=False)
    print(f"Saved summary statistics to {summary_path}")

    # =====================
    # Similarity matrices for best hierarchy-color combinations
    # =====================
    try:
        top_pairs = by_pair.nlargest(top_k, primary_metric)
        matrix = top_pairs.pivot(index=hier_col, columns=color_col, values=primary_metric)
        matrix_csv_path = os.path.join(output_dir, f'similarity_matrix_{primary_metric}_top{top_k}.csv')
        matrix.to_csv(matrix_csv_path)
        print(f"Saved similarity matrix to {matrix_csv_path}")
        if args.heatmap:
            try:
                import seaborn as sns
                import matplotlib.pyplot as plt
                plt.figure(figsize=(max(6, 0.5 * len(matrix.columns)), max(4, 0.5 * len(matrix.index))))
                sns.heatmap(matrix, annot=False, cmap='viridis')
                plt.title(f'Similarity matrix (top {top_k}) - {primary_metric}')
                heatmap_path = os.path.join(output_dir, f'similarity_matrix_{primary_metric}_top{top_k}.png')
                plt.tight_layout()
                plt.savefig(heatmap_path, dpi=200)
                plt.close()
                print(f"Saved similarity heatmap to {heatmap_path}")
            except Exception as e:
                print(f"Skipping heatmap generation: {e}")
    except Exception as e:
        print(f"Skipping matrix generation: {e}")

    # =====================
    # Learn projections A->B and report projected cosine means
    # =====================
    def fit_ridge_projection(A, B, l2_reg=1e-3):
        """Closed-form ridge regression W = (A^T A + l2_reg*I)^-1 A^T B.

        A and B are lists of 1-D tensors (rows); returns W of shape [D_in, D_out].
        """
        A = torch.stack(A)  # [N, D_in]
        B = torch.stack(B)  # [N, D_out]
        AtA = A.T @ A
        AtA_reg = AtA + l2_reg * torch.eye(AtA.shape[0], dtype=AtA.dtype)
        return torch.linalg.solve(AtA_reg, A.T @ B)

    def mean_projected_cosine(A, B, W):
        """Mean row-wise cosine similarity between A @ W and B (lists of rows)."""
        A_proj = F.normalize(torch.stack(A) @ W, dim=1)
        B_norm = F.normalize(torch.stack(B), dim=1)
        return torch.mean(torch.sum(A_proj * B_norm, dim=1)).item()

    def fit_ridge_with_cv(A, B, l2_values):
        """Pick the ridge strength by an 80/20 holdout split, then refit on all rows.

        Returns ``(W, best_l2, val_score)``. ``val_score`` is None when there are
        fewer than 10 rows (the middle grid value is used instead) or when no
        candidate produced a comparable score (e.g. all NaN).
        """
        if len(A) < 10:
            # Not enough data for a split; fall back to the middle lambda.
            best_l2 = l2_values[len(l2_values) // 2]
            return fit_ridge_projection(A, B, best_l2), best_l2, None
        n = len(A)
        perm = torch.randperm(n)
        split = int(0.8 * n)
        A_all, B_all = torch.stack(A), torch.stack(B)
        A_train, B_train = list(A_all[perm[:split]]), list(B_all[perm[:split]])
        A_val, B_val = list(A_all[perm[split:]]), list(B_all[perm[split:]])
        # Start from the first candidate so best_l2 is never None.
        best_l2, best_score = l2_values[0], float('-inf')
        for l2 in l2_values:
            W = fit_ridge_projection(A_train, B_train, l2)
            score = mean_projected_cosine(A_val, B_val, W)
            if score > best_score:
                best_score, best_l2 = score, l2
        W_best = fit_ridge_projection(A, B, best_l2)
        return W_best, best_l2, (best_score if best_score > float('-inf') else None)

    projection_report = {}
    # Same order as before, so the torch RNG draws (and the output keys) are unchanged.
    projection_specs = [
        ('txt_color_part', color_txt_As, color_txt_Bs),
        ('img_color_part', color_img_As, color_img_Bs),
        ('txt_hierarchy_part', hier_txt_As, hier_txt_Bs),
        ('img_hierarchy_part', hier_img_As, hier_img_Bs),
    ]
    for name, As, Bs in projection_specs:
        if len(As) < 8:
            continue
        W, best_l2, cv_val = fit_ridge_with_cv(As, Bs, l2_grid)
        # NOTE: this mean is measured on the same rows W was refit on (in-sample).
        projection_report[f'proj_sim_{name}_mean'] = mean_projected_cosine(As, Bs, W)
        projection_report[f'proj_{name}_best_l2'] = best_l2
        if cv_val is not None:
            projection_report[f'proj_{name}_cv_val'] = cv_val

    proj_summary_path = os.path.join(output_dir, 'projection_summary.json')
    with open(proj_summary_path, 'w') as f:
        json.dump(projection_report, f, indent=2)
    print(f"Saved projection summary to {proj_summary_path}")