""" Generalized evaluation of the main model with sub-module comparison. This file evaluates the main model's performance by comparing specialized parts (color and hierarchy) with corresponding specialized models. It calculates similarity matrices, linear projections between embedding spaces, and generates detailed statistics on alignment between different representations. """ import os import json import argparse import config import torch import torch.nn.functional as F import pandas as pd from PIL import Image from torchvision import transforms from transformers import CLIPProcessor, CLIPModel as CLIPModelTransformers from tqdm.auto import tqdm # Local imports from color_model import ColorCLIP as ColorModel, ColorDataset, Tokenizer from config import color_model_path, color_emb_dim, device, hierarchy_model_path, hierarchy_emb_dim from hierarchy_model import Model as HierarchyModel, HierarchyExtractor def load_color_model(color_model_path, color_emb_dim, device): # Load color model color_checkpoint = torch.load(color_model_path, map_location=device, weights_only=True) color_model = ColorModel(vocab_size=39, embedding_dim=color_emb_dim).to(device) color_model.load_state_dict(color_checkpoint) # Load and set the tokenizer tokenizer = Tokenizer() with open(config.tokeniser_path, 'r') as f: vocab_dict = json.load(f) color_model.tokenizer = tokenizer color_model.eval() return color_model def get_emb_color_model(color_model, image_path_to_encode, text_to_encode): # Load and preprocess image image = Image.open(image_path_to_encode).convert('RGB') transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) processed_image = transform(image) # Get embeddings processed_image_batch = processed_image.unsqueeze(0).to(device) # Shape: [1, 3, 224, 224] with torch.no_grad(): image_emb = color_model.image_encoder(processed_image_batch) # Text embedding via tokenizer + text_encoder token_ids = torch.tensor([color_model.tokenizer(text_to_encode)], dtype=torch.long, device=device) lengths = torch.tensor([token_ids.size(1) if token_ids.dim() > 1 else token_ids.size(0)], dtype=torch.long, device=device) with torch.no_grad(): txt_emb = color_model.text_encoder(token_ids, lengths) return image_emb, txt_emb def load_main_model(main_model_path, device): checkpoint = torch.load(main_model_path, map_location=device) main_model = CLIPModel_transformers.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K') state = checkpoint['model_state_dict'] if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint else checkpoint try: main_model.load_state_dict(state, strict=False) except Exception: # Fallback: filter matching keys model_state = main_model.state_dict() filtered = {k: v for k, v in state.items() if k in model_state and model_state[k].shape == v.shape} main_model.load_state_dict(filtered, strict=False) main_model.to(device) main_model.eval() processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K') return main_model, processor def load_hierarchy_model(hierarchy_model_path, device): checkpoint = torch.load(hierarchy_model_path, map_location=device) hierarchy_classes = checkpoint.get('hierarchy_classes', []) model = HierarchyModel(num_hierarchy_classes=len(hierarchy_classes), embed_dim=config.hierarchy_emb_dim).to(device) model.load_state_dict(checkpoint['model_state']) extractor = HierarchyExtractor(hierarchy_classes, verbose=False) model.set_hierarchy_extractor(extractor) 
def get_emb_hierarchy_model(hierarchy_model, image_path_to_encode, text_to_encode):
    image = Image.open(image_path_to_encode).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    image_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        img_emb = hierarchy_model.get_image_embeddings(image_tensor)
        txt_emb = hierarchy_model.get_text_embeddings(text_to_encode)
    return img_emb, txt_emb


def get_emb_main_model(main_model, processor, image_path_to_encode, text_to_encode):
    image = Image.open(image_path_to_encode).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    image = transform(image).unsqueeze(0).to(device)

    # Prepare text inputs via the processor
    text_inputs = processor(text=[text_to_encode], return_tensors="pt", padding=True)
    text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
    with torch.no_grad():  # inference only, matching the other encoders
        outputs = main_model(**text_inputs, pixel_values=image)
    text_emb = outputs.text_embeds
    image_emb = outputs.image_embeds
    return text_emb, image_emb


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Evaluate main model parts vs small models and build similarity matrices')
    parser.add_argument('--main-checkpoint', type=str, default='models/laion_explicable_model.pth')
    parser.add_argument('--color-checkpoint', type=str, default='models/color_model.pt')
    parser.add_argument('--csv', type=str, default='data/data_with_local_paths.csv')
    parser.add_argument('--color-emb-dim', type=int, default=16)
    parser.add_argument('--num-samples', type=int, default=200)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--primary-metric', type=str, default='sim_color_txt_img',
                        choices=['sim_txt_color_part', 'sim_img_color_part', 'sim_color_txt_img',
                                 'sim_small_txt_img', 'sim_txt_hierarchy_part', 'sim_img_hierarchy_part'])
    parser.add_argument('--top-k', type=int, default=30)
    parser.add_argument('--heatmap', action='store_true')
    parser.add_argument('--l2-grid', type=str, default='1e-5,1e-4,1e-3,1e-2,1e-1')
    args = parser.parse_args()

    main_checkpoint = args.main_checkpoint
    color_checkpoint = args.color_checkpoint
    csv = args.csv
    color_emb_dim = args.color_emb_dim
    num_samples = args.num_samples
    seed = args.seed
    primary_metric = args.primary_metric
    top_k = args.top_k
    l2_grid = [float(x) for x in args.l2_grid.split(',') if x]
    # Prefer Apple's MPS backend when available, fall back to CPU otherwise
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

    df = pd.read_csv(csv)

    # Normalize colors (reduce aliasing and sparsity)
    def normalize_color(c):
        if pd.isna(c):
            return c
        s = str(c).strip().lower()
        aliases = {
            'grey': 'gray',
            'navy blue': 'navy',
            'light blue': 'blue',
            'dark blue': 'blue',
            'light grey': 'gray',
            'dark grey': 'gray',
            'light gray': 'gray',
            'dark gray': 'gray',
        }
        return aliases.get(s, s)

    if config.color_column in df.columns:
        df[config.color_column] = df[config.color_column].apply(normalize_color)

    color_model = load_color_model(color_checkpoint, color_emb_dim, device)
    main_model, processor = load_main_model(main_checkpoint, device)
    hierarchy_model = load_hierarchy_model(hierarchy_model_path, device)

    # Results container
    results = []

    # Accumulators for projection (A: main-model part, B: small model)
    color_txt_As, color_txt_Bs = [], []
    color_img_As, color_img_Bs = [], []
    hier_txt_As, hier_txt_Bs = [], []
    hier_img_As, hier_img_Bs = [], []

    # Avoid chained-assignment surprises when sampling sub-frames
    pd.options.mode.copy_on_write = True
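    # The main model's embedding is treated below as a concatenation of
    # specialized sub-spaces, [color | hierarchy | rest]. A minimal sketch of
    # the slicing (the 512-dim total is illustrative for a ViT-B/32 CLIP):
    #
    #   emb = torch.randn(1, 512)
    #   color_part = emb[:, :color_emb_dim]
    #   hier_part = emb[:, color_emb_dim:color_emb_dim + hierarchy_emb_dim]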
    # Ensure determinism for sampling
    torch.manual_seed(seed)

    unique_hiers = sorted(df[config.hierarchy_column].dropna().unique())
    unique_colors = sorted(df[config.color_column].dropna().unique())

    # Progress bar across all (hierarchy, color) pairs
    total_pairs = len(unique_hiers) * len(unique_colors)
    pair_pbar = tqdm(total=total_pairs, desc="Evaluating pairs", leave=False)

    for hierarchy in unique_hiers:
        for color in unique_colors:
            group = df[(df[config.hierarchy_column] == hierarchy) & (df[config.color_column] == color)]

            # Sample up to num_samples per (hierarchy, color)
            k = min(num_samples, len(group))
            group_iter = group.sample(n=k, random_state=seed) if len(group) > k else group.iloc[:k]

            # Progress bar for samples within the pair
            inner_pbar = tqdm(total=len(group_iter), desc=f"{hierarchy}/{color}", leave=False)
            for row_idx, (_, example) in enumerate(group_iter.iterrows()):
                try:
                    image_emb, txt_emb = get_emb_color_model(color_model, example['local_image_path'], example['text'])
                    image_emb_hier, txt_emb_hier = get_emb_hierarchy_model(hierarchy_model, example['local_image_path'], example['text'])
                    text_emb_main_model, image_emb_main_model = get_emb_main_model(
                        main_model, processor, example['local_image_path'], example['text']
                    )

                    # Slice the specialized parts out of the main embeddings
                    color_part_txt = text_emb_main_model[:, :color_emb_dim]
                    color_part_img = image_emb_main_model[:, :color_emb_dim]
                    hier_part_txt = text_emb_main_model[:, color_emb_dim:color_emb_dim + hierarchy_emb_dim]
                    hier_part_img = image_emb_main_model[:, color_emb_dim:color_emb_dim + hierarchy_emb_dim]

                    # L2-normalize parts and small-model embeddings for stable cosine
                    color_part_txt = F.normalize(color_part_txt, dim=1)
                    color_part_img = F.normalize(color_part_img, dim=1)
                    hier_part_txt = F.normalize(hier_part_txt, dim=1)
                    hier_part_img = F.normalize(hier_part_img, dim=1)
                    txt_emb = F.normalize(txt_emb, dim=1)
                    image_emb = F.normalize(image_emb, dim=1)
                    txt_emb_hier = F.normalize(txt_emb_hier, dim=1)
                    image_emb_hier = F.normalize(image_emb_hier, dim=1)

                    sim_txt_color_part = F.cosine_similarity(txt_emb, color_part_txt).item()
                    sim_img_color_part = F.cosine_similarity(image_emb, color_part_img).item()
                    sim_color_txt_img = F.cosine_similarity(color_part_txt, color_part_img).item()
                    sim_small_txt_img = F.cosine_similarity(txt_emb, image_emb).item()
                    sim_txt_hierarchy_part = F.cosine_similarity(txt_emb_hier, hier_part_txt).item()
                    sim_img_hierarchy_part = F.cosine_similarity(image_emb_hier, hier_part_img).item()

                    # Accumulate for projection fitting later
                    color_txt_As.append(color_part_txt.squeeze(0).detach().cpu())
                    color_txt_Bs.append(txt_emb.squeeze(0).detach().cpu())
                    color_img_As.append(color_part_img.squeeze(0).detach().cpu())
                    color_img_Bs.append(image_emb.squeeze(0).detach().cpu())
                    hier_txt_As.append(hier_part_txt.squeeze(0).detach().cpu())
                    hier_txt_Bs.append(txt_emb_hier.squeeze(0).detach().cpu())
                    hier_img_As.append(hier_part_img.squeeze(0).detach().cpu())
                    hier_img_Bs.append(image_emb_hier.squeeze(0).detach().cpu())

                    results.append({
                        'hierarchy': hierarchy,
                        'color': color,
                        'row_index': int(row_idx),
                        'sim_txt_color_part': float(sim_txt_color_part),
                        'sim_img_color_part': float(sim_img_color_part),
                        'sim_color_txt_img': float(sim_color_txt_img),
                        'sim_small_txt_img': float(sim_small_txt_img),
                        'sim_txt_hierarchy_part': float(sim_txt_hierarchy_part),
                        'sim_img_hierarchy_part': float(sim_img_hierarchy_part),
                    })
                except Exception as e:
                    print(f"Skipping example due to error: {e}")
                finally:
                    inner_pbar.update(1)
            inner_pbar.close()
            pair_pbar.update(1)
    pair_pbar.close()

    results_df = pd.DataFrame(results)
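    # Each row of results_df now looks like this (the 'tops'/'red' labels and
    # numeric values are illustrative only):
    #
    #   hierarchy  color  row_index  sim_txt_color_part  ...  sim_img_hierarchy_part
    #   'tops'     'red'  0          0.41                ...  0.37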
    # Save raw results
    os.makedirs('evaluation_outputs', exist_ok=True)
    raw_path = os.path.join('evaluation_outputs', 'similarities_raw.csv')
    results_df.to_csv(raw_path, index=False)
    print(f"Saved raw similarities to {raw_path}")

    # Aggregated averages
    metrics = ['sim_txt_color_part', 'sim_img_color_part', 'sim_color_txt_img',
               'sim_small_txt_img', 'sim_txt_hierarchy_part', 'sim_img_hierarchy_part']

    # Overall means
    overall_means = results_df[metrics].mean().to_frame(name='mean').T
    overall_means.insert(0, 'level', 'overall')

    # By hierarchy (results_df uses the literal 'hierarchy'/'color' column names
    # written in the loop above, not the CSV column names from config)
    by_hierarchy = results_df.groupby('hierarchy')[metrics].mean().reset_index()
    by_hierarchy.insert(0, 'level', 'hierarchy')

    # By color
    by_color = results_df.groupby('color')[metrics].mean().reset_index()
    by_color.insert(0, 'level', 'color')

    # By hierarchy + color
    by_pair = results_df.groupby(['hierarchy', 'color'])[metrics].mean().reset_index()
    by_pair.insert(0, 'level', 'hierarchy_color')

    summary_df = pd.concat([overall_means, by_hierarchy, by_color, by_pair], ignore_index=True)
    summary_path = os.path.join('evaluation_outputs', 'similarities_summary.csv')
    summary_df.to_csv(summary_path, index=False)
    print(f"Saved summary statistics to {summary_path}")

    # =====================
    # Similarity matrices for the best hierarchy-color combinations
    # =====================
    try:
        by_pair_core = results_df.groupby(['hierarchy', 'color'])[metrics].mean().reset_index()
        top_pairs = by_pair_core.nlargest(top_k, primary_metric)
        matrix = top_pairs.pivot(index='hierarchy', columns='color', values=primary_metric)

        matrix_csv_path = os.path.join('evaluation_outputs', f'similarity_matrix_{primary_metric}_top{top_k}.csv')
        matrix.to_csv(matrix_csv_path)
        print(f"Saved similarity matrix to {matrix_csv_path}")

        if args.heatmap:
            try:
                import seaborn as sns
                import matplotlib.pyplot as plt
                plt.figure(figsize=(max(6, 0.5 * len(matrix.columns)), max(4, 0.5 * len(matrix.index))))
                sns.heatmap(matrix, annot=False, cmap='viridis')
                plt.title(f'Similarity matrix (top {top_k}) - {primary_metric}')
                heatmap_path = os.path.join('evaluation_outputs', f'similarity_matrix_{primary_metric}_top{top_k}.png')
                plt.tight_layout()
                plt.savefig(heatmap_path, dpi=200)
                plt.close()
                print(f"Saved similarity heatmap to {heatmap_path}")
            except Exception as e:
                print(f"Skipping heatmap generation: {e}")
    except Exception as e:
        print(f"Skipping matrix generation: {e}")

    # =====================
    # Learn projections A -> B and report projected cosine means
    # =====================
    def fit_ridge_projection(A, B, l2_reg=1e-3):
        # A: list of [D_in] tensors, B: list of [D_out] tensors
        A = torch.stack(A)  # [N, D_in]
        B = torch.stack(B)  # [N, D_out]
        # Closed-form ridge regression: W = (A^T A + lambda * I)^-1 A^T B
        AtA = A.T @ A
        D_in = AtA.shape[0]
        AtA_reg = AtA + l2_reg * torch.eye(D_in)
        W = torch.linalg.solve(AtA_reg, A.T @ B)
        return W  # [D_in, D_out]
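    # A minimal self-check of the closed-form ridge solution above, on
    # synthetic data (illustrative only; defined but never called by the
    # evaluation flow, so it is safe to delete):
    def _ridge_projection_self_check():
        torch.manual_seed(0)
        A = [torch.randn(8) for _ in range(64)]              # inputs, N=64, D_in=8
        W_true = torch.randn(8, 4)                           # ground-truth map, D_out=4
        B = [a @ W_true + 0.01 * torch.randn(4) for a in A]  # noisy targets
        W = fit_ridge_projection(A, B, l2_reg=1e-4)
        # With light regularization and low noise, W should sit close to W_true.
        assert torch.allclose(W, W_true, atol=0.1)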
    def mean_projected_cosine(A, B, W):
        # Mean cosine similarity between projected A and B
        A = torch.stack(A)
        B = torch.stack(B)
        A_proj = F.normalize(A @ W, dim=1)
        B = F.normalize(B, dim=1)
        return torch.mean(torch.sum(A_proj * B, dim=1)).item()

    def fit_ridge_with_cv(A, B, l2_values):
        # Simple holdout CV: 80/20 split
        if len(A) < 10:
            # Not enough data for a split; fall back to the middle lambda
            best_l2 = l2_values[min(len(l2_values) // 2, len(l2_values) - 1)]
            W = fit_ridge_projection(A, B, best_l2)
            return W, best_l2, None

        N = len(A)
        idx = torch.randperm(N)
        split = int(0.8 * N)
        train_idx = idx[:split]
        val_idx = idx[split:]
        A_tensor = torch.stack(A)
        B_tensor = torch.stack(B)
        A_train, B_train = A_tensor[train_idx], B_tensor[train_idx]
        A_val, B_val = A_tensor[val_idx], B_tensor[val_idx]

        def to_list(t):
            return [row for row in t]

        best_l2 = None
        best_score = -1.0
        for l2 in l2_values:
            W = fit_ridge_projection(to_list(A_train), to_list(B_train), l2)
            score = mean_projected_cosine(to_list(A_val), to_list(B_val), W)
            if score > best_score:
                best_score = score
                best_l2 = l2

        # Refit on all data with the best lambda
        W_best = fit_ridge_projection(A, B, best_l2)
        return W_best, best_l2, best_score

    projection_report = {}

    if len(color_txt_As) >= 8:
        W_ct, best_l2_ct, cv_ct = fit_ridge_with_cv(color_txt_As, color_txt_Bs, l2_grid)
        projection_report['proj_sim_txt_color_part_mean'] = mean_projected_cosine(color_txt_As, color_txt_Bs, W_ct)
        projection_report['proj_txt_color_part_best_l2'] = best_l2_ct
        if cv_ct is not None:
            projection_report['proj_txt_color_part_cv_val'] = cv_ct

    if len(color_img_As) >= 8:
        W_ci, best_l2_ci, cv_ci = fit_ridge_with_cv(color_img_As, color_img_Bs, l2_grid)
        projection_report['proj_sim_img_color_part_mean'] = mean_projected_cosine(color_img_As, color_img_Bs, W_ci)
        projection_report['proj_img_color_part_best_l2'] = best_l2_ci
        if cv_ci is not None:
            projection_report['proj_img_color_part_cv_val'] = cv_ci

    if len(hier_txt_As) >= 8:
        W_ht, best_l2_ht, cv_ht = fit_ridge_with_cv(hier_txt_As, hier_txt_Bs, l2_grid)
        projection_report['proj_sim_txt_hierarchy_part_mean'] = mean_projected_cosine(hier_txt_As, hier_txt_Bs, W_ht)
        projection_report['proj_txt_hierarchy_part_best_l2'] = best_l2_ht
        if cv_ht is not None:
            projection_report['proj_txt_hierarchy_part_cv_val'] = cv_ht

    if len(hier_img_As) >= 8:
        W_hi, best_l2_hi, cv_hi = fit_ridge_with_cv(hier_img_As, hier_img_Bs, l2_grid)
        projection_report['proj_sim_img_hierarchy_part_mean'] = mean_projected_cosine(hier_img_As, hier_img_Bs, W_hi)
        projection_report['proj_img_hierarchy_part_best_l2'] = best_l2_hi
        if cv_hi is not None:
            projection_report['proj_img_hierarchy_part_cv_val'] = cv_hi

    proj_summary_path = os.path.join('evaluation_outputs', 'projection_summary.json')
    with open(proj_summary_path, 'w') as f:
        json.dump(projection_report, f, indent=2)
    print(f"Saved projection summary to {proj_summary_path}")
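    # Expected shape of projection_summary.json (numbers illustrative):
    #
    #   {
    #     "proj_sim_txt_color_part_mean": 0.72,
    #     "proj_txt_color_part_best_l2": 0.001,
    #     "proj_txt_color_part_cv_val": 0.70,
    #     ...
    #   }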