"""
Generalized evaluation of the main model with sub-module comparison.

This script evaluates the main model by comparing its specialized parts
(color and hierarchy) with the corresponding standalone specialized models.
It computes similarity matrices, fits linear projections between the
embedding spaces, and generates detailed statistics on the alignment
between the different representations.
"""

import os
import json
import argparse

import torch
import torch.nn.functional as F
import pandas as pd
from PIL import Image
from torchvision import transforms
from transformers import CLIPProcessor, CLIPModel as CLIPModelTransformers
from tqdm.auto import tqdm

import config
from config import color_model_path, color_emb_dim, device, hierarchy_model_path, hierarchy_emb_dim
from color_model import ColorCLIP as ColorModel, ColorDataset, Tokenizer
from hierarchy_model import Model as HierarchyModel, HierarchyExtractor

def load_color_model(color_model_path, color_emb_dim, device):
    """Load the trained color model and attach its tokenizer."""
    color_checkpoint = torch.load(color_model_path, map_location=device, weights_only=True)
    # vocab_size must match the vocabulary the checkpoint was trained with.
    color_model = ColorModel(vocab_size=39, embedding_dim=color_emb_dim).to(device)
    color_model.load_state_dict(color_checkpoint)

    tokenizer = Tokenizer()
    with open(config.tokeniser_path, 'r') as f:
        vocab_dict = json.load(f)
    # The saved vocabulary is loaded for reference; the Tokenizer is assumed
    # to rebuild the same vocabulary internally.
    color_model.tokenizer = tokenizer

    color_model.eval()
    return color_model

def get_emb_color_model(color_model, image_path_to_encode, text_to_encode):
    """Encode an image/text pair with the color model; returns (image_emb, txt_emb)."""
    image = Image.open(image_path_to_encode).convert('RGB')

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    processed_image = transform(image)

    processed_image_batch = processed_image.unsqueeze(0).to(device)
    with torch.no_grad():
        image_emb = color_model.image_encoder(processed_image_batch)

    # token_ids has shape (1, seq_len) because of the wrapping list, so the
    # sequence length is always size(1).
    token_ids = torch.tensor([color_model.tokenizer(text_to_encode)], dtype=torch.long, device=device)
    lengths = torch.tensor([token_ids.size(1)], dtype=torch.long, device=device)
    with torch.no_grad():
        txt_emb = color_model.text_encoder(token_ids, lengths)

    return image_emb, txt_emb
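
# Example (illustrative; the image path and caption are placeholders):
#   img_emb, txt_emb = get_emb_color_model(color_model, 'images/example.jpg', 'a red dress')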

def load_main_model(main_model_path, device):
    """Load the fine-tuned main CLIP model and its processor."""
    checkpoint = torch.load(main_model_path, map_location=device)
    main_model = CLIPModelTransformers.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
    state = checkpoint['model_state_dict'] if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint else checkpoint
    try:
        main_model.load_state_dict(state, strict=False)
    except Exception:
        # Shape mismatches raise even with strict=False, so fall back to
        # loading only the tensors whose shapes match the base architecture.
        model_state = main_model.state_dict()
        filtered = {k: v for k, v in state.items() if k in model_state and model_state[k].shape == v.shape}
        main_model.load_state_dict(filtered, strict=False)
    main_model.to(device)
    main_model.eval()
    processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
    return main_model, processor

def load_hierarchy_model(hierarchy_model_path, device):
    """Load the trained hierarchy model and attach its hierarchy extractor."""
    checkpoint = torch.load(hierarchy_model_path, map_location=device)
    hierarchy_classes = checkpoint.get('hierarchy_classes', [])
    model = HierarchyModel(num_hierarchy_classes=len(hierarchy_classes), embed_dim=config.hierarchy_emb_dim).to(device)
    model.load_state_dict(checkpoint['model_state'])
    extractor = HierarchyExtractor(hierarchy_classes, verbose=False)
    model.set_hierarchy_extractor(extractor)
    model.eval()
    return model

def get_emb_hierarchy_model(hierarchy_model, image_path_to_encode, text_to_encode):
    """Encode an image/text pair with the hierarchy model; returns (img_emb, txt_emb)."""
    image = Image.open(image_path_to_encode).convert('RGB')
    # Unlike the color and main pipelines, no Normalize step is applied here,
    # presumably matching how the hierarchy model was trained.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    image_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        img_emb = hierarchy_model.get_image_embeddings(image_tensor)
        txt_emb = hierarchy_model.get_text_embeddings(text_to_encode)

    return img_emb, txt_emb

def get_emb_main_model(main_model, processor, image_path_to_encode, text_to_encode):
    """Encode an image/text pair with the main model; note the (text, image) return order."""
    image = Image.open(image_path_to_encode).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = transform(image).unsqueeze(0).to(device)

    text_inputs = processor(text=[text_to_encode], return_tensors="pt", padding=True)
    text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
    with torch.no_grad():
        outputs = main_model(**text_inputs, pixel_values=image)
    text_emb = outputs.text_embeds
    image_emb = outputs.image_embeds

    return text_emb, image_emb
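
# Assumed layout of the main model's joint embedding, inferred from the
# slicing in the evaluation loop below:
#   [0 : color_emb_dim)                                  -> color sub-space
#   [color_emb_dim : color_emb_dim + hierarchy_emb_dim)  -> hierarchy sub-space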

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Evaluate main model parts vs small models and build similarity matrices')
    parser.add_argument('--main-checkpoint', type=str, default='models/laion_explicable_model.pth')
    parser.add_argument('--color-checkpoint', type=str, default='models/color_model.pt')
    parser.add_argument('--csv', type=str, default='data/data_with_local_paths.csv')
    parser.add_argument('--color-emb-dim', type=int, default=16)
    parser.add_argument('--num-samples', type=int, default=200)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--primary-metric', type=str, default='sim_color_txt_img',
                        choices=['sim_txt_color_part', 'sim_img_color_part', 'sim_color_txt_img', 'sim_small_txt_img',
                                 'sim_txt_hierarchy_part', 'sim_img_hierarchy_part'])
    parser.add_argument('--top-k', type=int, default=30)
    parser.add_argument('--heatmap', action='store_true')
    parser.add_argument('--l2-grid', type=str, default='1e-5,1e-4,1e-3,1e-2,1e-1')
    args = parser.parse_args()

    main_checkpoint = args.main_checkpoint
    color_checkpoint = args.color_checkpoint
    csv = args.csv
    color_emb_dim = args.color_emb_dim
    num_samples = args.num_samples
    seed = args.seed
    primary_metric = args.primary_metric
    top_k = args.top_k
    l2_grid = [float(x) for x in args.l2_grid.split(',') if x]
    # Prefer MPS when available, but fall back to CUDA or CPU so the script
    # also runs off Apple hardware.
    if torch.backends.mps.is_available():
        device = torch.device('mps')
    elif torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    df = pd.read_csv(csv)

    def normalize_color(c):
        """Map common color aliases and spelling variants to canonical names."""
        if pd.isna(c):
            return c
        s = str(c).strip().lower()
        aliases = {
            'grey': 'gray',
            'navy blue': 'navy',
            'light blue': 'blue',
            'dark blue': 'blue',
            'light grey': 'gray',
            'dark grey': 'gray',
            'light gray': 'gray',
            'dark gray': 'gray',
        }
        return aliases.get(s, s)
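
    # For example, normalize_color('Navy Blue ') -> 'navy', while unknown
    # colors such as 'teal' pass through unchanged.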

    if config.color_column in df.columns:
        df[config.color_column] = df[config.color_column].apply(normalize_color)

    color_model = load_color_model(color_checkpoint, color_emb_dim, device)
    main_model, processor = load_main_model(main_checkpoint, device)
    hierarchy_model = load_hierarchy_model(hierarchy_model_path, device)

    results = []

    # Embedding pairs collected for the ridge projections fitted below:
    # the A lists hold main-model sub-space embeddings, the B lists the
    # corresponding specialized-model embeddings.
    color_txt_As, color_txt_Bs = [], []
    color_img_As, color_img_Bs = [], []
    hier_txt_As, hier_txt_Bs = [], []
    hier_img_As, hier_img_Bs = [], []

    pd.options.mode.copy_on_write = True
    torch.manual_seed(seed)

    unique_hiers = sorted(df[config.hierarchy_column].dropna().unique())
    unique_colors = sorted(df[config.color_column].dropna().unique())

    total_pairs = len(unique_hiers) * len(unique_colors)
    pair_pbar = tqdm(total=total_pairs, desc="Evaluating pairs", leave=False)

    for hierarchy in unique_hiers:
        for color in unique_colors:
            group = df[(df[config.hierarchy_column] == hierarchy) & (df[config.color_column] == color)]

            k = min(num_samples, len(group))
            group_iter = group.sample(n=k, random_state=seed) if len(group) > k else group.iloc[:k]

            inner_pbar = tqdm(total=len(group_iter), desc=f"{hierarchy}/{color}", leave=False)
            for row_idx, (_, example) in enumerate(group_iter.iterrows()):
                try:
                    image_emb, txt_emb = get_emb_color_model(color_model, example['local_image_path'], example['text'])
                    image_emb_hier, txt_emb_hier = get_emb_hierarchy_model(hierarchy_model, example['local_image_path'], example['text'])
                    text_emb_main_model, image_emb_main_model = get_emb_main_model(
                        main_model, processor, example['local_image_path'], example['text']
                    )

                    # Slice the specialized sub-spaces out of the main model's joint embedding.
                    color_part_txt = text_emb_main_model[:, :color_emb_dim]
                    color_part_img = image_emb_main_model[:, :color_emb_dim]
                    hier_part_txt = text_emb_main_model[:, color_emb_dim:color_emb_dim + hierarchy_emb_dim]
                    hier_part_img = image_emb_main_model[:, color_emb_dim:color_emb_dim + hierarchy_emb_dim]

                    # L2-normalize all embeddings before taking cosine similarities.
                    color_part_txt = F.normalize(color_part_txt, dim=1)
                    color_part_img = F.normalize(color_part_img, dim=1)
                    hier_part_txt = F.normalize(hier_part_txt, dim=1)
                    hier_part_img = F.normalize(hier_part_img, dim=1)
                    txt_emb = F.normalize(txt_emb, dim=1)
                    image_emb = F.normalize(image_emb, dim=1)
                    txt_emb_hier = F.normalize(txt_emb_hier, dim=1)
                    image_emb_hier = F.normalize(image_emb_hier, dim=1)

                    sim_txt_color_part = F.cosine_similarity(txt_emb, color_part_txt).item()
                    sim_img_color_part = F.cosine_similarity(image_emb, color_part_img).item()
                    sim_color_txt_img = F.cosine_similarity(color_part_txt, color_part_img).item()
                    sim_small_txt_img = F.cosine_similarity(txt_emb, image_emb).item()

                    sim_txt_hierarchy_part = F.cosine_similarity(txt_emb_hier, hier_part_txt).item()
                    sim_img_hierarchy_part = F.cosine_similarity(image_emb_hier, hier_part_img).item()

                    # Collect (main-model part, specialized-model) pairs for the
                    # ridge projection fits below.
                    color_txt_As.append(color_part_txt.squeeze(0).detach().cpu())
                    color_txt_Bs.append(txt_emb.squeeze(0).detach().cpu())
                    color_img_As.append(color_part_img.squeeze(0).detach().cpu())
                    color_img_Bs.append(image_emb.squeeze(0).detach().cpu())

                    hier_txt_As.append(hier_part_txt.squeeze(0).detach().cpu())
                    hier_txt_Bs.append(txt_emb_hier.squeeze(0).detach().cpu())
                    hier_img_As.append(hier_part_img.squeeze(0).detach().cpu())
                    hier_img_Bs.append(image_emb_hier.squeeze(0).detach().cpu())

                    # Key the results by the configured column names so the
                    # group-bys on results_df below resolve correctly.
                    results.append({
                        config.hierarchy_column: hierarchy,
                        config.color_column: color,
                        'row_index': int(row_idx),
                        'sim_txt_color_part': float(sim_txt_color_part),
                        'sim_img_color_part': float(sim_img_color_part),
                        'sim_color_txt_img': float(sim_color_txt_img),
                        'sim_small_txt_img': float(sim_small_txt_img),
                        'sim_txt_hierarchy_part': float(sim_txt_hierarchy_part),
                        'sim_img_hierarchy_part': float(sim_img_hierarchy_part),
                    })
                except Exception as e:
                    print(f"Skipping example due to error: {e}")
                finally:
                    inner_pbar.update(1)
            inner_pbar.close()
            pair_pbar.update(1)
    pair_pbar.close()

    results_df = pd.DataFrame(results)

    os.makedirs('evaluation_outputs', exist_ok=True)
    raw_path = os.path.join('evaluation_outputs', 'similarities_raw.csv')
    results_df.to_csv(raw_path, index=False)
    print(f"Saved raw similarities to {raw_path}")

    metrics = ['sim_txt_color_part', 'sim_img_color_part', 'sim_color_txt_img', 'sim_small_txt_img',
               'sim_txt_hierarchy_part', 'sim_img_hierarchy_part']

    # Aggregate the similarity metrics at four levels: overall, per hierarchy,
    # per color, and per (hierarchy, color) pair.
    overall_means = results_df[metrics].mean().to_frame(name='mean').T
    overall_means.insert(0, 'level', 'overall')

    by_hierarchy = results_df.groupby(config.hierarchy_column)[metrics].mean().reset_index()
    by_hierarchy.insert(0, 'level', config.hierarchy_column)

    by_color = results_df.groupby(config.color_column)[metrics].mean().reset_index()
    by_color.insert(0, 'level', config.color_column)

    by_pair = results_df.groupby([config.hierarchy_column, config.color_column])[metrics].mean().reset_index()
    by_pair.insert(0, 'level', 'hierarchy_color')

    summary_df = pd.concat([overall_means, by_hierarchy, by_color, by_pair], ignore_index=True)
    summary_path = os.path.join('evaluation_outputs', 'similarities_summary.csv')
    summary_df.to_csv(summary_path, index=False)
    print(f"Saved summary statistics to {summary_path}")

    # Build a hierarchy x color matrix of the primary metric for the top-k pairs.
    try:
        by_pair_core = results_df.groupby([config.hierarchy_column, config.color_column])[metrics].mean().reset_index()
        top_pairs = by_pair_core.nlargest(top_k, primary_metric)
        matrix = top_pairs.pivot(index=config.hierarchy_column, columns=config.color_column, values=primary_metric)
        matrix_csv_path = os.path.join('evaluation_outputs', f'similarity_matrix_{primary_metric}_top{top_k}.csv')
        matrix.to_csv(matrix_csv_path)
        print(f"Saved similarity matrix to {matrix_csv_path}")

        if args.heatmap:
            try:
                import seaborn as sns
                import matplotlib.pyplot as plt
                plt.figure(figsize=(max(6, 0.5 * len(matrix.columns)), max(4, 0.5 * len(matrix.index))))
                sns.heatmap(matrix, annot=False, cmap='viridis')
                plt.title(f'Similarity matrix (top {top_k}) - {primary_metric}')
                heatmap_path = os.path.join('evaluation_outputs', f'similarity_matrix_{primary_metric}_top{top_k}.png')
                plt.tight_layout()
                plt.savefig(heatmap_path, dpi=200)
                plt.close()
                print(f"Saved similarity heatmap to {heatmap_path}")
            except Exception as e:
                print(f"Skipping heatmap generation: {e}")
    except Exception as e:
        print(f"Skipping matrix generation: {e}")

    def fit_ridge_projection(A, B, l2_reg=1e-3):
        """Fit a linear map W from space A to space B via ridge regression."""
        A = torch.stack(A)
        B = torch.stack(B)

        # Closed-form ridge solution: W = (A^T A + l2_reg * I)^{-1} A^T B,
        # computed with a linear solve rather than an explicit inverse.
        AtA = A.T @ A
        D_in = AtA.shape[0]
        AtA_reg = AtA + l2_reg * torch.eye(D_in)
        W = torch.linalg.solve(AtA_reg, A.T @ B)
        return W
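
    # fit_ridge_projection minimizes ||A @ W - B||_F^2 + l2_reg * ||W||_F^2.
    # Sanity check (illustrative): with B = A and a tiny l2_reg, the fitted
    # map should be close to the identity:
    #   A = [torch.randn(8) for _ in range(64)]
    #   W = fit_ridge_projection(A, A, l2_reg=1e-6)
    #   assert torch.allclose(W, torch.eye(8), atol=1e-2)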

    def fit_ridge_with_cv(A, B, l2_values):
        """Select l2_reg on an 80/20 split by projected cosine, then refit on all data."""
        if len(A) < 10:
            # Too few samples for a meaningful split: fall back to the middle
            # of the l2 grid and skip validation.
            best_l2 = l2_values[min(len(l2_values) // 2, len(l2_values) - 1)]
            W = fit_ridge_projection(A, B, best_l2)
            return W, best_l2, None

        N = len(A)
        idx = torch.randperm(N)
        split = int(0.8 * N)
        train_idx = idx[:split]
        val_idx = idx[split:]

        A_tensor = torch.stack(A)
        B_tensor = torch.stack(B)

        A_train, B_train = A_tensor[train_idx], B_tensor[train_idx]
        A_val, B_val = A_tensor[val_idx], B_tensor[val_idx]

        def to_list(t):
            return [row for row in t]

        best_l2 = None
        best_score = -1.0
        for l2 in l2_values:
            W = fit_ridge_projection(to_list(A_train), to_list(B_train), l2)
            score = mean_projected_cosine(to_list(A_val), to_list(B_val), W)
            if score > best_score:
                best_score = score
                best_l2 = l2

        # Refit on the full dataset with the selected regularization strength.
        W_best = fit_ridge_projection(A, B, best_l2)
        return W_best, best_l2, best_score

    def mean_projected_cosine(A, B, W):
        """Mean cosine similarity between the projected A embeddings (A @ W) and B."""
        A = torch.stack(A)
        B = torch.stack(B)
        A_proj = A @ W
        A_proj = F.normalize(A_proj, dim=1)
        B = F.normalize(B, dim=1)
        return torch.mean(torch.sum(A_proj * B, dim=1)).item()
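
    # End-to-end usage sketch (illustrative, with random data):
    #   A = [torch.randn(16) for _ in range(100)]
    #   B = [torch.randn(512) for _ in range(100)]
    #   W, best_l2, val_score = fit_ridge_with_cv(A, B, [1e-4, 1e-3, 1e-2])
    #   print(mean_projected_cosine(A, B, W))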

    projection_report = {}

    # Fit a projection for each sub-space only when enough pairs were collected.
    if len(color_txt_As) >= 8:
        W_ct, best_l2_ct, cv_ct = fit_ridge_with_cv(color_txt_As, color_txt_Bs, l2_grid)
        projection_report['proj_sim_txt_color_part_mean'] = mean_projected_cosine(color_txt_As, color_txt_Bs, W_ct)
        projection_report['proj_txt_color_part_best_l2'] = best_l2_ct
        if cv_ct is not None:
            projection_report['proj_txt_color_part_cv_val'] = cv_ct
    if len(color_img_As) >= 8:
        W_ci, best_l2_ci, cv_ci = fit_ridge_with_cv(color_img_As, color_img_Bs, l2_grid)
        projection_report['proj_sim_img_color_part_mean'] = mean_projected_cosine(color_img_As, color_img_Bs, W_ci)
        projection_report['proj_img_color_part_best_l2'] = best_l2_ci
        if cv_ci is not None:
            projection_report['proj_img_color_part_cv_val'] = cv_ci
    if len(hier_txt_As) >= 8:
        W_ht, best_l2_ht, cv_ht = fit_ridge_with_cv(hier_txt_As, hier_txt_Bs, l2_grid)
        projection_report['proj_sim_txt_hierarchy_part_mean'] = mean_projected_cosine(hier_txt_As, hier_txt_Bs, W_ht)
        projection_report['proj_txt_hierarchy_part_best_l2'] = best_l2_ht
        if cv_ht is not None:
            projection_report['proj_txt_hierarchy_part_cv_val'] = cv_ht
    if len(hier_img_As) >= 8:
        W_hi, best_l2_hi, cv_hi = fit_ridge_with_cv(hier_img_As, hier_img_Bs, l2_grid)
        projection_report['proj_sim_img_hierarchy_part_mean'] = mean_projected_cosine(hier_img_As, hier_img_Bs, W_hi)
        projection_report['proj_img_hierarchy_part_best_l2'] = best_l2_hi
        if cv_hi is not None:
            projection_report['proj_img_hierarchy_part_cv_val'] = cv_hi

    proj_summary_path = os.path.join('evaluation_outputs', 'projection_summary.json')
    with open(proj_summary_path, 'w') as f:
        json.dump(projection_report, f, indent=2)
    print(f"Saved projection summary to {proj_summary_path}")