| |
| """ |
| Section 5.3.6 — Embedding Structure Evaluation |
| =============================================== |
| |
| Verifies that the GAP-CLIP embedding subspaces encode the attributes they are |
| designed for, and tests zero-shot vision-language alignment. |
| |
| Test A — Different colors, same hierarchy: |
| The 64D hierarchy subspace should be MORE similar between two items that |
| share a category but differ in color, compared to the 16D color subspace. |
| Expected result: 1000/1000 pass. |
| Example: |
| In Test A, the code computes for each pair: |
| - sim_color = cosine between the color slice (emb[0:16]) |
| - sim_hier = cosine between the hierarchy slice (emb[16:80]) |
| - sim_full512 = cosine between the full 512-d embedding (emb) |
| The test checks: |
| - pair_ok = (sim_hier > sim_color) and (sim_hier > sim_full512) |
| Test B — Same color, different hierarchies: |
| The 16D color subspace should be MORE similar than both the full 512D |
| embedding and the 64D hierarchy subspace for items sharing a color but |
| differing in category. |
| Expected result: 1000/1000 pass. |
| |
| Test C — Subspace Decomposition Consistency: |
| Encode a full description (e.g. "red dress in cotton"), a standalone color |
| ("red"), and a standalone hierarchy ("dress"). Verify that: |
| - The color subspace (first 16D) of the full embedding is more similar |
| to the color-only embedding than to the hierarchy-only embedding. |
| - The hierarchy subspace (dims 16-80) of the full embedding is more |
| similar to the hierarchy-only embedding than to the color-only embedding. |
| Expected result: 1000/1000 pass. |
| |
| Test D — Zero-shot image-to-text classification: |
| Each image is used as a query; the highest-scoring text label (cosine in |
| shared latent space) is the predicted class. Accuracy is computed across |
| three datasets (Fashion-MNIST, KAGL Marqo, Internal). |
| |
| Paper reference: Section 5.3.6 and Table 4. |
| |
| Run directly: |
| python sec536_embedding_structure.py --tests AB # only tests A+B |
| python sec536_embedding_structure.py --tests ABCD # all tests |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import os |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
|
| from dataclasses import dataclass |
| from pathlib import Path |
| import random |
| import sys |
|
|
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
| from typing import Dict, List, Optional, Sequence, Tuple |
|
|
| import numpy as np |
| import pandas as pd |
| import requests |
| from sklearn.metrics import f1_score |
| import torch |
| import torch.nn.functional as F |
| from io import BytesIO |
| from PIL import Image, ImageOps |
| from torchvision import transforms |
| from torchvision import datasets |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
| from transformers import CLIPModel as CLIPModelTransformers |
| from transformers import CLIPProcessor |
|
|
| from training.hierarchy_model import HierarchyExtractor |
| from evaluation.type_aware_scoring import ( |
| TypeAwareParams, |
| compute_type_aware_scores, |
| ) |
| from evaluation.ensemble_scoring import ( |
| AdaptiveEnsembleParams, |
| EnsembleParams, |
| compute_prob_ensemble, |
| compute_prob_ensemble_adaptive, |
| rerank_top_k, |
| ) |
| from evaluation.hybrid_scoring import compute_hybrid_metrics |
| from evaluation.pure_boost_scoring import ( |
| compute_pure_boost_metrics, |
| encode_images_with_specialist_tta, |
| encode_text_with_specialist_ensembled, |
| ) |
|
|
# A project-level ``config`` module is optional. Any failure while importing it
# is tolerated; the hard-coded fallbacks below are used in that case.
try:
    import config as project_config
except Exception:
    project_config = None


# ``getattr`` on ``None`` simply yields the fallback, so these lines work
# whether or not the import above succeeded.
DEFAULT_COLOR_EMB_DIM = getattr(project_config, "color_emb_dim", 16)
DEFAULT_HIERARCHY_EMB_DIM = getattr(project_config, "hierarchy_emb_dim", 64)
DEFAULT_MAIN_EMB_DIM = getattr(project_config, "main_emb_dim", 512)
DEFAULT_MAIN_MODEL_PATH = getattr(
    project_config, "main_model_path", "models/gap_clip.pth"
)
DEFAULT_DEVICE = getattr(project_config, "device", torch.device("cpu"))


# Module-level extractor built once at import time from 18 category names.
_HIERARCHY_EXTRACTOR = HierarchyExtractor(
    [
        "accessories",
        "bodysuits",
        "bras",
        "coat",
        "dress",
        "jacket",
        "legging",
        "pant",
        "polo",
        "shirt",
        "shoes",
        "short",
        "skirt",
        "socks",
        "sweater",
        "swimwear",
        "top",
        "underwear",
    ],
    verbose=False,
)
|
|
|
|
@dataclass
class RuntimeConfig:
    """Runtime settings shared by the evaluation tests.

    Every field defaults to the matching module-level ``DEFAULT_*`` constant.
    """

    # Number of leading embedding dimensions treated as the color subspace.
    color_emb_dim: int = DEFAULT_COLOR_EMB_DIM
    # Number of dimensions, directly after the color slice, treated as the
    # hierarchy subspace.
    hierarchy_emb_dim: int = DEFAULT_HIERARCHY_EMB_DIM
    # Size of the full embedding.
    main_emb_dim: int = DEFAULT_MAIN_EMB_DIM
    # Path of the model checkpoint on disk.
    main_model_path: str = DEFAULT_MAIN_MODEL_PATH
    # Torch device the model and its inputs are moved to.
    device: torch.device = DEFAULT_DEVICE
|
|
# Default number of generated examples per test and of table rows printed.
DEFAULT_NUM_EXAMPLES = 10000
DEFAULT_NUM_PRINTED = 3


# Vocabulary the synthetic text queries are sampled from.
COLORS = [
    "yellow",
    "blue",
    "red",
    "green",
    "black",
    "white",
    "pink",
    "purple",
    "brown",
    "orange",
]
HIERARCHIES = [
    "dress",
    "shirt",
    "pants",
    "skirt",
    "jacket",
    "coat",
    "jeans",
    "sweater",
    "shorts",
    "top",
]


# Templates filled with ``color`` and ``hierarchy`` by ``build_text_query``.
LONG_TEXT_TEMPLATES = [
    "{color} {hierarchy}",
    "{color} {hierarchy} with buttons",
    "{color} {hierarchy} in cotton",
    "casual {color} {hierarchy} for women",
    "elegant {color} {hierarchy} with pockets",
]
|
|
|
|
| |
| |
| |
| |
# Prompt templates for zero-shot classification; each one is filled through
# ``template.format(label=...)``.
ZERO_SHOT_TEMPLATES = [
    # Generic photo / picture phrasing.
    "a photo of a {label}",
    "a photo of the {label}",
    "a picture of a {label}",
    "an image of a {label}",
    # Retail / catalog phrasing.
    "a product photo of a {label}",
    "a fashion photo of a {label}",
    "a catalog image of a {label}",
    "a close-up photo of a {label}",
    # Minimal phrasing.
    "a {label}",
    "clothing: {label}",
]
|
|
|
|
| |
| |
| |
# Per-dataset four-component fusion weights.
# NOTE(review): the component order is presumably (w_gen, w_hier, w_nocolor,
# w_color), as documented on ``compute_fused_scores`` — confirm at the call site.
DATASET_FUSION_WEIGHTS: Dict[str, Tuple[float, float, float, float]] = dict(
    internal=(0.5, 0.8, 0.2, 0.0),
    modanet=(0.5, 0.7, 0.3, 0.0),
    kagl=(0.3, 1.0, 0.3, 0.0),
    fmnist=(0.2, 1.0, 0.2, 0.0),
)


# Softmax temperature constant for zero-shot scoring (its consumer is not
# visible in this part of the file).
ZERO_SHOT_SOFTMAX_TAU = 0.01
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
# Per-dataset ensemble parameters. Each entry weights the five score components
# "full", "gen", "hier", "nocolor" and "color"; only the fmnist entry also sets
# explicit ``tau_full`` / ``tau_sub`` values.
DATASET_ENSEMBLE_PARAMS: Dict[str, EnsembleParams] = {
    "internal": EnsembleParams(
        weights={
            "full": 0.20,
            "gen": 0.25,
            "hier": 0.30,
            "nocolor": 0.20,
            "color": 0.05,
        },
    ),
    "modanet": EnsembleParams(
        weights={
            "full": 0.20,
            "gen": 0.25,
            "hier": 0.30,
            "nocolor": 0.20,
            "color": 0.05,
        },
    ),
    "kagl": EnsembleParams(
        weights={
            "full": 0.30,
            "gen": 0.30,
            "hier": 0.05,
            "nocolor": 0.30,
            "color": 0.05,
        },
    ),
    "fmnist": EnsembleParams(
        tau_full=0.01,
        tau_sub=0.5,
        weights={
            "full": 0.20,
            "gen": 0.20,
            "hier": 0.40,
            "nocolor": 0.20,
            "color": 0.0,
        },
    ),
}
|
|
| |
| |
| |
| |
# Per-dataset re-ranking parameters stored as an (int, float) pair.
# NOTE(review): presumably (top_k, blend weight) for ``rerank_top_k`` — confirm
# against the call site, which is not visible in this part of the file.
DATASET_RERANK_PARAMS: Dict[str, Tuple[int, float]] = dict(
    internal=(3, 0.4),
    modanet=(3, 0.4),
    kagl=(3, 0.5),
    fmnist=(3, 0.6),
)
|
|
|
|
# Per-dataset parameters for type-aware scoring. ``w_color`` is 0.0 and
# ``tau_type`` is 0.05 for every dataset; the other knobs are dataset-specific.
DATASET_TYPE_AWARE_PARAMS: Dict[str, TypeAwareParams] = {
    "internal": TypeAwareParams(
        w_hier=0.7,
        w_color=0.0,
        alpha=0.3,
        beta=0.6,
        gamma=0.1,
        delta=0.4,
        lambda_match=0.5,
        tau_type=0.05,
    ),
    "modanet": TypeAwareParams(
        w_hier=0.7,
        w_color=0.0,
        alpha=0.3,
        beta=0.6,
        gamma=0.1,
        delta=0.4,
        lambda_match=0.5,
        tau_type=0.05,
    ),
    "kagl": TypeAwareParams(
        w_hier=0.2,
        w_color=0.0,
        alpha=0.5,
        beta=0.6,
        gamma=0.2,
        delta=0.4,
        lambda_match=0.8,
        tau_type=0.05,
    ),
    "fmnist": TypeAwareParams(
        w_hier=1.0,
        w_color=0.0,
        alpha=0.1,
        beta=0.4,
        gamma=0.1,
        delta=0.3,
        lambda_match=1.0,
        tau_type=0.05,
    ),
}
|
|
|
|
def build_text_query(color: str, hierarchy: str) -> str:
    """Fill a randomly chosen entry of ``LONG_TEXT_TEMPLATES`` with the given words.

    Draws from the global ``random`` module state, so the result is only
    reproducible when that state has been seeded by the caller.
    """
    return random.choice(LONG_TEXT_TEMPLATES).format(color=color, hierarchy=hierarchy)
|
|
|
|
def resolve_runtime_config() -> RuntimeConfig:
    """Build a ``RuntimeConfig``, overriding fields from a local ``config`` module.

    When ``config`` can be imported, each field it defines replaces the
    default; fields it does not define keep their default. When the import (or
    any attribute read) fails, only the device is changed: CUDA if available,
    otherwise MPS if available, otherwise CPU.
    """
    cfg = RuntimeConfig()
    overridable = (
        "color_emb_dim",
        "hierarchy_emb_dim",
        "main_emb_dim",
        "main_model_path",
        "device",
    )
    try:
        import config

        for field_name in overridable:
            current = getattr(cfg, field_name)
            setattr(cfg, field_name, getattr(config, field_name, current))
    except Exception:
        # No usable config module: pick the best available accelerator.
        if torch.cuda.is_available():
            device_name = "cuda"
        elif torch.backends.mps.is_available():
            device_name = "mps"
        else:
            device_name = "cpu"
        cfg.device = torch.device(device_name)

    return cfg
|
|
|
|
def load_main_model(device: torch.device, main_model_path: str) -> Tuple[CLIPModelTransformers, CLIPProcessor]:
    """Load the GAP-CLIP model and its processor.

    The architecture and processor are instantiated from the
    ``laion/CLIP-ViT-B-32-laion2B-s34B-b79K`` pretrained name; the weights
    found at ``main_model_path`` are then loaded on top with ``strict=False``.
    The checkpoint may be a bare state dict or a dict holding it under
    ``"model_state_dict"``.

    Returns:
        ``(model, processor)`` with the model moved to ``device`` in eval mode.

    Raises:
        FileNotFoundError: if ``main_model_path`` does not exist.
    """
    checkpoint_file = Path(main_model_path)
    if not checkpoint_file.exists():
        raise FileNotFoundError(f"Main model checkpoint not found: {main_model_path}")

    clip_name = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
    model = CLIPModelTransformers.from_pretrained(clip_name)

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(str(checkpoint_file), map_location=device)
    is_wrapped = isinstance(checkpoint, dict) and "model_state_dict" in checkpoint
    state_dict = checkpoint["model_state_dict"] if is_wrapped else checkpoint
    # strict=False: missing or unexpected keys are not treated as errors.
    model.load_state_dict(state_dict, strict=False)

    model = model.to(device)
    model.eval()
    processor = CLIPProcessor.from_pretrained(clip_name)
    return model, processor
|
|
|
|
def encode_text(model, processor, text_queries, device):
    """Return the model's text features for one string or a list of strings.

    A single string is wrapped in a list first. The features are returned as
    produced by ``model.get_text_features`` (not normalized), computed under
    ``torch.no_grad()``.
    """
    queries = [text_queries] if isinstance(text_queries, str) else text_queries
    batch = processor(text=queries, return_tensors="pt", padding=True, truncation=True)
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    with torch.no_grad():
        return model.get_text_features(**batch)
|
|
|
|
def encode_image(model, processor, images, device):
    """Return the model's image features for one image or a list of images.

    Anything that is not a ``list`` is wrapped in a one-element list. The
    features are returned as produced by ``model.get_image_features`` (not
    normalized), computed under ``torch.no_grad()``.
    """
    image_list = images if isinstance(images, list) else [images]
    batch = processor(images=image_list, return_tensors="pt")
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    with torch.no_grad():
        return model.get_image_features(**batch)
|
|
|
|
def get_text_embedding(
    model: CLIPModelTransformers,
    processor: CLIPProcessor,
    device: torch.device,
    text: str,
) -> torch.Tensor:
    """Encode one text and return its L2-normalized embedding, batch axis removed."""
    raw_features = encode_text(model, processor, text, device)
    unit_features = F.normalize(raw_features, dim=-1)
    return unit_features.squeeze(0)
|
|
|
|
def cosine(a: torch.Tensor, b: torch.Tensor) -> float:
    """Return the cosine similarity between two vectors as a Python float."""
    lhs = a.unsqueeze(0)
    rhs = b.unsqueeze(0)
    similarity = F.cosine_similarity(lhs, rhs, dim=1)
    return similarity.item()
|
|
|
|
def delta_percent(reference: float, value: float) -> float:
    """Return the change of ``value`` relative to ``reference`` in percent.

    Computed as ``(value - reference) / |reference| * 100``, with the
    denominator floored at ``1e-8`` so that a zero reference does not divide
    by zero.
    """
    denom = abs(reference)
    if 1e-8 > denom:
        denom = 1e-8
    ratio = (value - reference) / denom
    return ratio * 100.0
|
|
|
|
def format_bool(ok: bool) -> str:
    """Render a truthy value as ``"PASS"`` and a falsy one as ``"FAIL"``."""
    if ok:
        return "PASS"
    return "FAIL"
|
|
|
|
def print_table(title: str, headers: List[str], rows: List[List[str]]) -> None:
    """Print a titled, column-aligned text table to stdout.

    Each column is as wide as its longest cell (header included) and cells are
    left-justified. Every row must have exactly ``len(headers)`` cells; any
    other length raises ``IndexError``.
    """
    banner = "=" * 120
    print("\n" + banner)
    print(title)
    print(banner)

    # Widest cell per column, over the header row followed by all data rows.
    col_widths = []
    for col in range(len(headers)):
        widest = len(str(headers[col]))
        for row in rows:
            widest = max(widest, len(str(row[col])))
        col_widths.append(widest)

    def render(cells: List[str]) -> str:
        padded = []
        for idx, cell in enumerate(cells):
            padded.append(str(cell).ljust(col_widths[idx]))
        return " | ".join(padded)

    print(render(headers))
    # The rule spans the padded cells plus the 3-character separators.
    print("-" * (sum(col_widths) + 3 * (len(headers) - 1)))
    for row in rows:
        print(render(row))
|
|
|
|
def run_test_a(
    model: CLIPModelTransformers,
    processor: CLIPProcessor,
    cfg: RuntimeConfig,
    num_examples: int,
    num_printed: int,
    test_name: str = "Test A",
) -> Dict[str, object]:
    """Test A: different colors, same hierarchy.

    For every generated pair of queries that share a hierarchy word but differ
    in color, the pair passes when the hierarchy-slice cosine is greater than
    both the color-slice cosine and the full-embedding cosine. A second,
    independent statistic checks whether the full-embedding cosine of the
    same-hierarchy pair exceeds that of a different-hierarchy pair.

    Args:
        model: CLIP model used to encode the text queries.
        processor: Processor matching ``model``.
        cfg: Supplies ``device``, ``color_emb_dim`` and ``hierarchy_emb_dim``.
        num_examples: Number of random pairs to generate; must be positive.
        num_printed: Number of example rows shown in the printed table.
        test_name: Label used in the printed output.

    Returns:
        Dict with ``overall`` (bool, every pair passed) and the float metrics
        ``accuracy_full512``, ``pass_rate``, ``hier_gt_full_rate``,
        ``hier_gt_color_rate``, ``avg_delta_color_vs_full`` and
        ``avg_delta_hier_vs_full``.

    Raises:
        ValueError: if ``num_examples`` is not positive (the rates would
            otherwise be a division by zero).
    """
    if num_examples <= 0:
        raise ValueError(f"num_examples must be a positive integer, got {num_examples}")

    # Pair generation. The order of the random draws is kept stable so that a
    # seeded run keeps producing the same queries.
    positive_pairs: List[Tuple[str, str]] = []
    negative_pairs: List[Tuple[str, str]] = []
    for _ in range(num_examples):
        hierarchy = random.choice(HIERARCHIES)
        c1, c2 = random.sample(COLORS, 2)
        negative_hierarchy = random.choice([h for h in HIERARCHIES if h != hierarchy])
        positive_pairs.append((build_text_query(c1, hierarchy), build_text_query(c2, hierarchy)))
        negative_pairs.append((build_text_query(c1, hierarchy), build_text_query(c2, negative_hierarchy)))

    # The query vocabulary is small, so the same strings recur many times;
    # encode each distinct string once instead of once per occurrence.
    embedding_cache: Dict[str, torch.Tensor] = {}

    def embed(text: str) -> torch.Tensor:
        if text not in embedding_cache:
            embedding_cache[text] = get_text_embedding(model, processor, cfg.device, text)
        return embedding_cache[text]

    color_end = cfg.color_emb_dim
    hier_end = cfg.color_emb_dim + cfg.hierarchy_emb_dim

    rows: List[List[str]] = []
    pair_outcomes: List[bool] = []
    full512_outcomes: List[bool] = []
    hier_gt_full_outcomes: List[bool] = []
    hier_gt_color_outcomes: List[bool] = []
    delta_color_vs_full_values: List[float] = []
    delta_hier_vs_full_values: List[float] = []

    # Only the right-hand query of each negative pair is used; it is compared
    # against the left-hand query of the positive pair.
    for (left, right), (_, negative_right) in zip(positive_pairs, negative_pairs):
        emb_left = embed(left)
        emb_right = embed(right)
        emb_negative_right = embed(negative_right)

        sim_color = cosine(emb_left[:color_end], emb_right[:color_end])
        sim_hier = cosine(emb_left[color_end:hier_end], emb_right[color_end:hier_end])
        sim_full512 = cosine(emb_left, emb_right)
        sim_full512_negative = cosine(emb_left, emb_negative_right)

        delta_color_vs_full_pct = delta_percent(sim_full512, sim_color)
        delta_hier_vs_full_pct = delta_percent(sim_full512, sim_hier)
        delta_color_vs_full_values.append(delta_color_vs_full_pct)
        delta_hier_vs_full_values.append(delta_hier_vs_full_pct)

        hierarchy_higher_than_full = sim_hier > sim_full512
        hierarchy_higher_than_color = sim_hier > sim_color
        pair_ok = hierarchy_higher_than_full and hierarchy_higher_than_color
        pair_outcomes.append(pair_ok)
        hier_gt_full_outcomes.append(hierarchy_higher_than_full)
        hier_gt_color_outcomes.append(hierarchy_higher_than_color)
        full512_outcomes.append(sim_full512 > sim_full512_negative)

        rows.append(
            [
                f"{left} vs {right}",
                f"{sim_color:.4f}",
                f"{sim_hier:.4f}",
                f"{sim_full512:.4f}",
                f"{delta_color_vs_full_pct:+.2f}%",
                f"{delta_hier_vs_full_pct:+.2f}%",
                format_bool(pair_ok),
            ]
        )

    print_table(
        f"{test_name}: Different colors, same hierarchy (showing {min(num_printed, len(rows))}/{len(rows)} examples)",
        [
            "Pair",
            "CosSim first16(color)",
            "CosSim hier64",
            "CosSim full512",
            "Delta first16 vs full512 (%)",
            "Delta hier64 vs full512 (%)",
            "Result",
        ],
        rows[:num_printed],
    )

    total = len(pair_outcomes)
    overall = all(pair_outcomes)
    pass_rate = sum(pair_outcomes) / total
    full512_accuracy = sum(full512_outcomes) / total
    hier_gt_full_rate = sum(hier_gt_full_outcomes) / total
    hier_gt_color_rate = sum(hier_gt_color_outcomes) / total
    avg_delta_color_vs_full = sum(delta_color_vs_full_values) / total
    avg_delta_hier_vs_full = sum(delta_hier_vs_full_values) / total

    print(f"{test_name} aggregate: {sum(pair_outcomes)}/{total} passed ({pass_rate:.2%})")
    print(f" sub-condition hier > full512: {sum(hier_gt_full_outcomes)}/{total} ({hier_gt_full_rate:.2%})")
    print(f" sub-condition hier > color: {sum(hier_gt_color_outcomes)}/{total} ({hier_gt_color_rate:.2%})")
    print(
        f"{test_name} full512 pair-discrimination accuracy "
        f"(same-hierarchy > different-hierarchy): {sum(full512_outcomes)}/{total} "
        f"({full512_accuracy:.2%})"
    )
    print(
        f"{test_name} avg deltas: "
        f"first16 vs full512 = {avg_delta_color_vs_full:+.2f}%, "
        f"hier64 vs full512 = {avg_delta_hier_vs_full:+.2f}%"
    )
    return {
        "overall": overall,
        "accuracy_full512": full512_accuracy,
        "pass_rate": pass_rate,
        "hier_gt_full_rate": hier_gt_full_rate,
        "hier_gt_color_rate": hier_gt_color_rate,
        "avg_delta_color_vs_full": avg_delta_color_vs_full,
        "avg_delta_hier_vs_full": avg_delta_hier_vs_full,
    }
|
|
|
|
def run_test_b(
    model: CLIPModelTransformers,
    processor: CLIPProcessor,
    cfg: RuntimeConfig,
    num_examples: int,
    num_printed: int,
    test_name: str = "Test B",
) -> Dict[str, object]:
    """Test B: same color, different hierarchies.

    For every generated pair of queries that share a color word but differ in
    hierarchy, the pair passes when the color-slice cosine is greater than
    both the full-embedding cosine and the hierarchy-slice cosine. A second,
    independent statistic checks whether the full-embedding cosine of the
    same-color pair exceeds that of a different-color pair.

    Args:
        model: CLIP model used to encode the text queries.
        processor: Processor matching ``model``.
        cfg: Supplies ``device``, ``color_emb_dim`` and ``hierarchy_emb_dim``.
        num_examples: Number of random pairs to generate; must be positive.
        num_printed: Number of example rows shown in the printed table.
        test_name: Label used in the printed output.

    Returns:
        Dict with ``overall`` (bool, every pair passed) and the float metrics
        ``accuracy_full512``, ``pass_rate``, ``color_gt_full_rate``,
        ``color_gt_hier_rate``, ``avg_delta_color_vs_full`` and
        ``avg_delta_hier_vs_full``.

    Raises:
        ValueError: if ``num_examples`` is not positive (the rates would
            otherwise be a division by zero).
    """
    if num_examples <= 0:
        raise ValueError(f"num_examples must be a positive integer, got {num_examples}")

    # Pair generation. The order of the random draws is kept stable so that a
    # seeded run keeps producing the same queries.
    positive_pairs: List[Tuple[str, str]] = []
    negative_pairs: List[Tuple[str, str]] = []
    for _ in range(num_examples):
        color = random.choice(COLORS)
        h1, h2 = random.sample(HIERARCHIES, 2)
        negative_color = random.choice([c for c in COLORS if c != color])
        positive_pairs.append((build_text_query(color, h1), build_text_query(color, h2)))
        negative_pairs.append((build_text_query(color, h1), build_text_query(negative_color, h2)))

    # The query vocabulary is small, so the same strings recur many times;
    # encode each distinct string once instead of once per occurrence.
    embedding_cache: Dict[str, torch.Tensor] = {}

    def embed(text: str) -> torch.Tensor:
        if text not in embedding_cache:
            embedding_cache[text] = get_text_embedding(model, processor, cfg.device, text)
        return embedding_cache[text]

    color_end = cfg.color_emb_dim
    hier_end = cfg.color_emb_dim + cfg.hierarchy_emb_dim

    rows: List[List[str]] = []
    pair_outcomes: List[bool] = []
    full512_outcomes: List[bool] = []
    color_gt_full_outcomes: List[bool] = []
    color_gt_hier_outcomes: List[bool] = []
    delta_color_vs_full_values: List[float] = []
    delta_hier_vs_full_values: List[float] = []

    # Only the right-hand query of each negative pair is used; it is compared
    # against the left-hand query of the positive pair.
    for (left, right), (_, negative_right) in zip(positive_pairs, negative_pairs):
        emb_left = embed(left)
        emb_right = embed(right)
        emb_negative_right = embed(negative_right)

        sim_512 = cosine(emb_left, emb_right)
        sim_16 = cosine(emb_left[:color_end], emb_right[:color_end])
        sim_hier = cosine(emb_left[color_end:hier_end], emb_right[color_end:hier_end])
        sim_512_negative = cosine(emb_left, emb_negative_right)

        delta_color_vs_full_pct = delta_percent(sim_512, sim_16)
        delta_hier_vs_full_pct = delta_percent(sim_512, sim_hier)
        delta_color_vs_full_values.append(delta_color_vs_full_pct)
        delta_hier_vs_full_values.append(delta_hier_vs_full_pct)

        first16_higher_than_full = sim_16 > sim_512
        color_higher_than_hier = sim_16 > sim_hier
        pair_ok = first16_higher_than_full and color_higher_than_hier
        pair_outcomes.append(pair_ok)
        color_gt_full_outcomes.append(first16_higher_than_full)
        color_gt_hier_outcomes.append(color_higher_than_hier)
        full512_outcomes.append(sim_512 > sim_512_negative)

        rows.append(
            [
                f"{left} vs {right}",
                f"{sim_16:.4f}",
                f"{sim_hier:.4f}",
                f"{sim_512:.4f}",
                f"{delta_color_vs_full_pct:+.2f}%",
                f"{delta_hier_vs_full_pct:+.2f}%",
                format_bool(pair_ok),
            ]
        )

    print_table(
        f"{test_name}: Same color, different hierarchies (showing {min(num_printed, len(rows))}/{len(rows)} examples)",
        [
            "Pair",
            "CosSim first16(color)",
            "CosSim hier64",
            "CosSim full512",
            "Delta first16 vs full512 (%)",
            "Delta hier64 vs full512 (%)",
            "Result",
        ],
        rows[:num_printed],
    )

    total = len(pair_outcomes)
    overall = all(pair_outcomes)
    pass_rate = sum(pair_outcomes) / total
    full512_accuracy = sum(full512_outcomes) / total
    color_gt_full_rate = sum(color_gt_full_outcomes) / total
    color_gt_hier_rate = sum(color_gt_hier_outcomes) / total
    avg_delta_color_vs_full = sum(delta_color_vs_full_values) / total
    avg_delta_hier_vs_full = sum(delta_hier_vs_full_values) / total

    print(f"{test_name} aggregate: {sum(pair_outcomes)}/{total} passed ({pass_rate:.2%})")
    print(f" sub-condition color > full512: {sum(color_gt_full_outcomes)}/{total} ({color_gt_full_rate:.2%})")
    print(f" sub-condition color > hier: {sum(color_gt_hier_outcomes)}/{total} ({color_gt_hier_rate:.2%})")
    print(
        f"{test_name} full512 pair-discrimination accuracy "
        f"(same-color > different-color): {sum(full512_outcomes)}/{total} "
        f"({full512_accuracy:.2%})"
    )
    print(
        f"{test_name} avg deltas: "
        f"first16 vs full512 = {avg_delta_color_vs_full:+.2f}%, "
        f"hier64 vs full512 = {avg_delta_hier_vs_full:+.2f}%"
    )
    return {
        "overall": overall,
        "accuracy_full512": full512_accuracy,
        "pass_rate": pass_rate,
        "color_gt_full_rate": color_gt_full_rate,
        "color_gt_hier_rate": color_gt_hier_rate,
        "avg_delta_color_vs_full": avg_delta_color_vs_full,
        "avg_delta_hier_vs_full": avg_delta_hier_vs_full,
    }
|
|
|
|
|
|
def run_test_c(
    model: CLIPModelTransformers,
    processor: CLIPProcessor,
    cfg: RuntimeConfig,
    num_examples: int,
    num_printed: int,
    test_name: str = "Test C",
) -> Dict[str, object]:
    """Test C: subspace decomposition consistency.

    Each example encodes a full description (e.g. "red dress in cotton"), the
    bare color word ("red") and the bare hierarchy word ("dress"), then checks:

    - color slice of the full embedding vs color slice of the color-only
      embedding ("match") is greater than vs color slice of the hierarchy-only
      embedding ("cross");
    - hierarchy slice of the full embedding vs hierarchy slice of the
      hierarchy-only embedding ("match") is greater than vs hierarchy slice of
      the color-only embedding ("cross").

    An example passes only when both conditions hold.

    Args:
        model: CLIP model used to encode the texts.
        processor: Processor matching ``model``.
        cfg: Supplies ``device``, ``color_emb_dim`` and ``hierarchy_emb_dim``.
        num_examples: Number of random examples to generate; must be positive.
        num_printed: Number of example rows shown in the printed table.
        test_name: Label used in the printed output.

    Returns:
        Dict with ``overall`` (bool, every example passed) and the float
        metrics ``pass_rate``, ``color_match_rate``, ``hier_match_rate``,
        ``avg_color_match``, ``avg_color_cross``, ``avg_hier_match`` and
        ``avg_hier_cross``.

    Raises:
        ValueError: if ``num_examples`` is not positive (the rates would
            otherwise be a division by zero).
    """
    if num_examples <= 0:
        raise ValueError(f"num_examples must be a positive integer, got {num_examples}")

    # The bare color / hierarchy words and the templated descriptions recur
    # constantly; encode each distinct string once.
    embedding_cache: Dict[str, torch.Tensor] = {}

    def embed(text: str) -> torch.Tensor:
        if text not in embedding_cache:
            embedding_cache[text] = get_text_embedding(model, processor, cfg.device, text)
        return embedding_cache[text]

    color_end = cfg.color_emb_dim
    hier_end = cfg.color_emb_dim + cfg.hierarchy_emb_dim

    rows: List[List[str]] = []
    color_match_outcomes: List[bool] = []
    hier_match_outcomes: List[bool] = []
    pair_outcomes: List[bool] = []
    sim_color_match_values: List[float] = []
    sim_color_cross_values: List[float] = []
    sim_hier_match_values: List[float] = []
    sim_hier_cross_values: List[float] = []

    for _ in range(num_examples):
        color = random.choice(COLORS)
        hierarchy = random.choice(HIERARCHIES)
        full_text = build_text_query(color, hierarchy)

        emb_full = embed(full_text)
        emb_color = embed(color)
        emb_hier = embed(hierarchy)

        # Color slice of each embedding.
        full_color = emb_full[:color_end]
        color_color = emb_color[:color_end]
        hier_color = emb_hier[:color_end]

        # Hierarchy slice of each embedding.
        full_hier = emb_full[color_end:hier_end]
        color_hier = emb_color[color_end:hier_end]
        hier_hier = emb_hier[color_end:hier_end]

        # "match": slice compared with the standalone embedding of the same attribute.
        sim_color_match = cosine(full_color, color_color)
        sim_hier_match = cosine(full_hier, hier_hier)

        # "cross": slice compared with the standalone embedding of the other attribute.
        sim_color_cross = cosine(full_color, hier_color)
        sim_hier_cross = cosine(full_hier, color_hier)

        sim_color_match_values.append(sim_color_match)
        sim_color_cross_values.append(sim_color_cross)
        sim_hier_match_values.append(sim_hier_match)
        sim_hier_cross_values.append(sim_hier_cross)

        color_ok = sim_color_match > sim_color_cross
        hier_ok = sim_hier_match > sim_hier_cross
        pair_ok = color_ok and hier_ok
        color_match_outcomes.append(color_ok)
        hier_match_outcomes.append(hier_ok)
        pair_outcomes.append(pair_ok)

        rows.append(
            [
                full_text,
                color,
                hierarchy,
                f"{sim_color_match:.4f}",
                f"{sim_color_cross:.4f}",
                f"{sim_hier_match:.4f}",
                f"{sim_hier_cross:.4f}",
                format_bool(pair_ok),
            ]
        )

    print_table(
        f"{test_name}: Subspace Decomposition Consistency "
        f"(showing {min(num_printed, len(rows))}/{len(rows)} examples)",
        [
            "Full description",
            "Color",
            "Hierarchy",
            "ColorSub match",
            "ColorSub cross",
            "HierSub match",
            "HierSub cross",
            "Result",
        ],
        rows[:num_printed],
    )

    total = len(pair_outcomes)
    pass_rate = sum(pair_outcomes) / total
    color_rate = sum(color_match_outcomes) / total
    hier_rate = sum(hier_match_outcomes) / total
    avg_color_match = sum(sim_color_match_values) / total
    avg_color_cross = sum(sim_color_cross_values) / total
    avg_hier_match = sum(sim_hier_match_values) / total
    avg_hier_cross = sum(sim_hier_cross_values) / total

    print(f"{test_name} aggregate: {sum(pair_outcomes)}/{total} passed ({pass_rate:.2%})")
    print(f" sub-condition color_match > color_cross: {sum(color_match_outcomes)}/{total} ({color_rate:.2%})")
    print(f" sub-condition hier_match > hier_cross: {sum(hier_match_outcomes)}/{total} ({hier_rate:.2%})")
    print(
        f"{test_name} avg similarities: "
        f"color_match={avg_color_match:.4f}, color_cross={avg_color_cross:.4f}, "
        f"hier_match={avg_hier_match:.4f}, hier_cross={avg_hier_cross:.4f}"
    )

    return {
        "overall": all(pair_outcomes),
        "pass_rate": pass_rate,
        "color_match_rate": color_rate,
        "hier_match_rate": hier_rate,
        "avg_color_match": avg_color_match,
        "avg_color_cross": avg_color_cross,
        "avg_hier_match": avg_hier_match,
        "avg_hier_cross": avg_hier_cross,
    }
|
|
|
|
# Fashion-MNIST class id -> label used for evaluation. Compared with the
# original names below, ids 5, 7 and 9 (Sandal, Sneaker, Ankle boot) all map
# to "shoes" and id 8 (Bag) maps to "accessories".
FASHION_MNIST_LABELS = dict(
    enumerate(
        [
            "top",
            "pant",
            "sweater",
            "dress",
            "coat",
            "shoes",
            "shirt",
            "shoes",
            "accessories",
            "shoes",
        ]
    )
)


# Fashion-MNIST class id -> original dataset class name.
FASHION_MNIST_ORIGINAL_LABELS = dict(
    enumerate(
        [
            "T-shirt",
            "Trouser",
            "Pullover",
            "Dress",
            "Coat",
            "Sandal",
            "Shirt",
            "Sneaker",
            "Bag",
            "Ankle boot",
        ]
    )
)


# Relative paths of the CSV files holding the evaluation data.
FASHION_MNIST_CSV = "data/fashion-mnist_test.csv"
INTERNAL_DATASET_CSV = "data/data.csv"
|
|
|
|
def fashion_mnist_pixels_to_tensor(pixel_values: np.ndarray, image_size: int = 224) -> torch.Tensor:
    """Turn the pixel values of one Fashion-MNIST image into an image tensor.

    The values are reshaped to 28x28, cast to ``uint8`` and replicated into
    three identical channels. The resulting image is resized to
    ``image_size`` x ``image_size``, converted with ``ToTensor`` and
    normalized with mean ``[0.485, 0.456, 0.406]`` and std
    ``[0.229, 0.224, 0.225]``.
    """
    gray = pixel_values.reshape(28, 28).astype(np.uint8)
    # Replicate the single channel three times along a new last axis.
    rgb = np.stack([gray, gray, gray], axis=-1)
    pipeline = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    return pipeline(Image.fromarray(rgb))
|
|
|
|
def get_image_embedding(
    model: CLIPModelTransformers, processor: CLIPProcessor, device: torch.device, image_tensor: torch.Tensor
) -> torch.Tensor:
    """Encode one image tensor and return its L2-normalized embedding.

    The tensor is converted to a PIL image so that ``processor`` applies its
    own preprocessing; ``encode_image`` then moves the processed batch to
    ``device``. The batch axis is removed from the result.

    NOTE(review): ``to_pil_image`` expects float tensors in the [0, 1] range.
    If the tensor was mean/std normalized beforehand (as
    ``fashion_mnist_pixels_to_tensor`` does), the pixel values seen by the
    processor may be distorted — confirm against the callers.
    """
    from torchvision.transforms.functional import to_pil_image

    # The PIL conversion needs the data on the CPU. Moving the tensor to
    # ``device`` first, as was done before, only added a round trip.
    pil_img = to_pil_image(image_tensor.cpu())
    return F.normalize(encode_image(model, processor, pil_img, device), dim=-1).squeeze(0)
|
|
|
|
def get_image_embedding_from_pil(
    model: CLIPModelTransformers,
    processor: CLIPProcessor,
    device: torch.device,
    image: Image.Image,
) -> torch.Tensor:
    """Encode one PIL image and return its L2-normalized embedding, batch axis removed."""
    raw_features = encode_image(model, processor, image, device)
    unit_features = F.normalize(raw_features, dim=-1)
    return unit_features.squeeze(0)
|
|
|
|
def get_text_embeddings_batch(
    model: CLIPModelTransformers,
    processor: CLIPProcessor,
    device: torch.device,
    texts: List[str],
) -> torch.Tensor:
    """Encode a list of texts and return one L2-normalized embedding per text."""
    raw_features = encode_text(model, processor, texts, device)
    return F.normalize(raw_features, dim=-1)
|
|
|
|
def get_prompt_ensembled_text_embeddings(
    model: CLIPModelTransformers,
    processor: CLIPProcessor,
    device: torch.device,
    labels: List[str],
    templates: List[str],
) -> torch.Tensor:
    """Return one prompt-ensembled text embedding per label.

    Every template is filled with every label and encoded (normalized) as one
    batch per template. The per-template embeddings are averaged across
    templates and the average is L2-normalized again.
    """
    per_template = [
        get_text_embeddings_batch(
            model,
            processor,
            device,
            [template.format(label=label) for label in labels],
        )
        for template in templates
    ]
    # Stack along a new leading template axis, then average that axis away.
    mean_embedding = torch.stack(per_template, dim=0).mean(dim=0)
    return F.normalize(mean_embedding, dim=-1)
|
|
|
|
def get_descriptor_ensembled_text_embeddings(
    model: CLIPModelTransformers,
    processor: CLIPProcessor,
    device: torch.device,
    descriptors_per_label: Dict[str, List[str]],
    labels: List[str],
    templates: List[str],
) -> torch.Tensor:
    """Return one text embedding per label, averaged over descriptors and templates.

    For each label, its descriptors (synonyms or finer-grained names; the
    label itself when it has no entry in ``descriptors_per_label``) are
    combined with every template. All resulting prompts are encoded in one
    batch, averaged into a centroid, and the centroid is L2-normalized. This
    lets one coarse label cover several finer categories — e.g. a broad
    "top" label that spans shirts, polos, sweaters, jackets and coats.

    Returns:
        Tensor with one L2-normalized row per entry of ``labels``, in order.
    """
    centroids: List[torch.Tensor] = []
    for label in labels:
        descriptors = descriptors_per_label.get(label, [label])
        # Descriptor-major order: all templates for the first descriptor, etc.
        prompts = [
            template.format(label=descriptor)
            for descriptor in descriptors
            for template in templates
        ]
        prompt_embs = get_text_embeddings_batch(model, processor, device, prompts)
        centroid = F.normalize(prompt_embs.mean(dim=0, keepdim=True), dim=-1)
        centroids.append(centroid)
    return torch.cat(centroids, dim=0)
|
|
|
|
| |
| |
| |
| |
# Descriptor lists per label: synonyms and finer-grained names that can be
# expanded into prompts for that label (see
# ``get_descriptor_ensembled_text_embeddings``).
KAGL_COARSE_DESCRIPTORS: Dict[str, List[str]] = {
    # Broad labels with wide descriptor sets.
    "accessories": [
        "accessory", "fashion accessory", "bag", "handbag",
        "backpack", "wallet", "watch", "belt",
        "scarf", "tie", "jewelry", "earrings",
        "necklace", "bracelet", "cap", "hat",
        "sunglasses", "eyewear", "headwear", "clutch",
    ],
    "dress": [
        "dress", "gown", "frock", "saree", "sari",
        "lehenga", "robe", "kurta dress", "sundress", "evening dress",
    ],
    "pant": [
        "pants", "trousers", "jeans", "leggings", "tights",
        "shorts", "skirt", "bottomwear", "joggers", "track pants",
        "capris", "lounge pants", "salwar", "chinos", "lower garment",
    ],
    "shoes": [
        "shoes", "footwear", "sneakers", "boots", "sandals",
        "heels", "flats", "loafers", "flip flops", "slippers",
    ],
    "socks": ["socks", "stockings", "hosiery"],
    "top": [
        "top", "topwear", "shirt", "t-shirt", "tshirt", "blouse",
        "sweater", "sweatshirt", "hoodie", "cardigan", "polo", "jacket",
        "coat", "blazer", "kurta", "kurti", "tunic", "upper garment",
    ],
    "underwear": [
        "underwear", "innerwear", "bra", "boxers", "briefs",
        "trunks", "camisole", "undershirt", "vest", "bodysuit",
        "sleepwear", "nightwear", "lingerie", "swimwear", "loungewear",
    ],
    # Narrower labels with their own, smaller descriptor sets.
    "shirt": [
        "shirt", "tshirt", "t-shirt", "blouse", "button-up", "button down",
    ],
    "polo": ["polo", "polo shirt", "polo tee"],
    "sweater": [
        "sweater", "sweatshirt", "hoodie", "cardigan", "jumper", "pullover",
    ],
    "jacket": ["jacket", "blazer", "windbreaker", "bomber"],
    "coat": ["coat", "overcoat", "trench coat", "parka"],
    "legging": ["leggings", "tights", "stretch pants"],
    "short": ["shorts", "boardshorts", "bermuda shorts"],
    "skirt": ["skirt", "miniskirt", "midi skirt"],
    "bras": ["bra", "brassiere"],
    "bodysuits": ["bodysuit", "leotard", "onesie", "jumpsuit", "romper"],
    "swimwear": ["swimsuit", "swimwear", "bikini", "trunks"],
}
|
|
|
|
def compute_subspace_accuracies(
    img_embs: torch.Tensor, text_embs: torch.Tensor, cfg: RuntimeConfig,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Predict the best-matching text index per image in three embedding views.

    Returns ``(preds_full, preds_color, preds_hier)`` as NumPy arrays of
    argmax indices into the rows of ``text_embs``:

    - ``preds_full`` uses the embeddings as given (no re-normalization, so
      they are expected to be normalized already);
    - ``preds_color`` uses the first ``cfg.color_emb_dim`` dimensions;
    - ``preds_hier`` uses the next ``cfg.hierarchy_emb_dim`` dimensions.

    The two slices are L2-normalized again before scoring.
    """
    color_end = cfg.color_emb_dim
    hier_end = cfg.color_emb_dim + cfg.hierarchy_emb_dim

    def best_label(img: torch.Tensor, txt: torch.Tensor) -> np.ndarray:
        scores = img @ txt.T
        return scores.argmax(dim=-1).cpu().numpy()

    def unit_slice(embs: torch.Tensor, lo: int, hi: int) -> torch.Tensor:
        return F.normalize(embs[:, lo:hi], dim=-1)

    preds_full = best_label(img_embs, text_embs)
    preds_color = best_label(
        unit_slice(img_embs, 0, color_end),
        unit_slice(text_embs, 0, color_end),
    )
    preds_hier = best_label(
        unit_slice(img_embs, color_end, hier_end),
        unit_slice(text_embs, color_end, hier_end),
    )
    return preds_full, preds_color, preds_hier
|
|
|
|
def _subspace_cosine(
    img_embs: torch.Tensor, text_embs: torch.Tensor, start: int, end: int
) -> torch.Tensor:
    """Pairwise cosine similarity restricted to embedding dims ``[start, end)``.

    Both inputs are sliced on their last axis and L2-normalized again, so the
    result is a cosine on the slice regardless of how the full vectors were
    scaled. Entry ``[i, j]`` compares row ``i`` of `img_embs` with row ``j``
    of `text_embs`.
    """
    img_unit, txt_unit = (
        F.normalize(embs[:, start:end], dim=-1) for embs in (img_embs, text_embs)
    )
    return torch.matmul(img_unit, txt_unit.T)
|
|
|
|
def _zscore_rowwise(scores: torch.Tensor) -> torch.Tensor:
    """Standardize each row across candidate labels.

    Every row is shifted to zero mean and divided by its sample standard
    deviation (plus ``1e-6`` so constant rows map to zeros instead of
    dividing by zero).

    With fewer than two candidates the sample standard deviation is undefined
    (NaN), which would turn every downstream fused score into NaN. There is
    nothing to rank in that case, so an all-zero tensor of the same shape is
    returned instead.
    """
    if scores.dim() == 0 or scores.size(-1) < 2:
        # A single candidate carries no relative information; avoid NaN.
        return torch.zeros_like(scores)
    mean = scores.mean(dim=-1, keepdim=True)
    std = scores.std(dim=-1, keepdim=True)
    return (scores - mean) / (std + 1e-6)
|
|
|
|
def compute_fused_scores(
    img_embs: torch.Tensor,
    text_embs: torch.Tensor,
    cfg: RuntimeConfig,
    weights: Tuple[float, float, float, float],
    mask_color: bool = False,
) -> Dict[str, torch.Tensor]:
    """Score images against text labels on each embedding subspace and fuse them.

    Four cosine sub-scores are computed on re-normalized slices:

    * ``gen``     -- dims after the hierarchy block up to the end;
    * ``hier``    -- the hierarchy block;
    * ``nocolor`` -- everything after the color block;
    * ``color``   -- the color block (leading ``cfg.color_emb_dim`` dims).

    Each sub-score is z-scored per query row and the four are summed with
    ``weights = (w_gen, w_hier, w_nocolor, w_color)`` to give ``fused``.
    ``full`` is the plain dot product over the whole embedding and does not
    take part in the fusion.

    When `mask_color` is true, a copy of `img_embs` has its color dims zeroed
    and is re-normalized before any score (including ``full``) is computed;
    the caller's tensor is not modified. The original author intends this for
    grayscale inputs such as FMNIST, where the color subspace is degenerate
    and its noise would otherwise leak into ``full``.

    Returns:
        Dict with keys ``full``, ``gen``, ``hier``, ``nocolor``, ``color``
        (raw similarities) and ``fused`` (weighted sum of z-scores).
    """
    color_end = cfg.color_emb_dim
    hier_end = cfg.color_emb_dim + cfg.hierarchy_emb_dim

    if mask_color:
        # Work on a copy so the caller's embeddings stay intact.
        masked = img_embs.clone()
        masked[:, : cfg.color_emb_dim] = 0.0
        img_embs = F.normalize(masked, dim=-1)

    last_dim = text_embs.size(-1)

    result: Dict[str, torch.Tensor] = {"full": img_embs @ text_embs.T}
    slice_bounds = (
        ("gen", hier_end, last_dim),
        ("hier", color_end, hier_end),
        ("nocolor", color_end, last_dim),
        ("color", 0, color_end),
    )
    for key, lo, hi in slice_bounds:
        result[key] = _subspace_cosine(img_embs, text_embs, lo, hi)

    w_gen, w_hier, w_nocolor, w_color = weights
    weighted_terms = [
        w_gen * _zscore_rowwise(result["gen"]),
        w_hier * _zscore_rowwise(result["hier"]),
        w_nocolor * _zscore_rowwise(result["nocolor"]),
        w_color * _zscore_rowwise(result["color"]),
    ]
    # Accumulate left-to-right, matching a chained `a + b + c + d`.
    fused = weighted_terms[0]
    for term in weighted_terms[1:]:
        fused = fused + term
    result["fused"] = fused
    return result
|
|
|
|
def apply_label_prior(
    logits: torch.Tensor,
    candidate_labels: List[str],
    tau: float = ZERO_SHOT_SOFTMAX_TAU,
) -> Tuple[torch.Tensor, float]:
    """Convert logits to probabilities and blend in the adaptive label prior.

    The logits are divided by the temperature `tau` and soft-maxed over the
    label axis. `get_adaptive_label_prior` supplies both a prior over
    `candidate_labels` and a mixing weight; when that weight is positive the
    result is ``(1 - w) * softmax + w * prior``. The weight shrinks as fewer
    candidate labels overlap with the internal dataset, so calling this
    unconditionally is safe.

    Returns:
        ``(probs, prior_weight)``.
    """
    posterior = torch.softmax(logits / tau, dim=-1)
    label_prior, mix = get_adaptive_label_prior(candidate_labels)
    if not (mix > 0.0):
        # No usable prior: hand back the plain softmax.
        return posterior, mix
    label_prior = label_prior.to(posterior.device)
    blended = posterior * (1.0 - mix) + label_prior * mix
    return blended, mix
|
|
|
|
def get_internal_label_prior(labels: List[str]) -> torch.Tensor:
    """Label prior derived from `hierarchy` frequencies in the internal CSV.

    The `hierarchy` column of ``INTERNAL_DATASET_CSV`` is normalized with
    `normalize_hierarchy_label` and counted. Each entry of `labels` receives
    its count plus a small smoothing constant, and the vector is normalized
    to sum to one.

    A uniform distribution is returned when the CSV is missing, cannot be
    read, or has no non-null rows.
    """

    def _uniform() -> torch.Tensor:
        return torch.ones(len(labels), dtype=torch.float32) / max(len(labels), 1)

    if not Path(INTERNAL_DATASET_CSV).exists():
        return _uniform()
    try:
        frame = pd.read_csv(INTERNAL_DATASET_CSV, usecols=["hierarchy"]).dropna()
    except Exception:
        # Best-effort: any read/parse problem degrades to the uniform prior.
        return _uniform()
    if len(frame) == 0:
        return _uniform()

    canonical = [
        normalize_hierarchy_label(value) for value in frame["hierarchy"].astype(str)
    ]
    frequency = pd.Series(canonical).value_counts().to_dict()

    smoothing = 1e-3  # keeps unseen labels at a non-zero probability
    mass = torch.tensor(
        [float(frequency.get(name, 0.0)) + smoothing for name in labels],
        dtype=torch.float32,
    )
    return mass / mass.sum()
|
|
|
|
def get_adaptive_label_prior(labels: List[str]) -> Tuple[torch.Tensor, float]:
    """Label prior plus a mixing weight that scales with label overlap.

    Frequencies come from the normalized `hierarchy` column of
    ``INTERNAL_DATASET_CSV``. A candidate label seen in that data gets its
    empirical frequency; an unseen one gets ``1 / len(labels)``. The vector
    is then normalized to sum to one.

    The recommended weight is ``0.15 * overlap ** 2``, where ``overlap`` is
    the fraction of `labels` present in the internal data. It therefore falls
    towards zero when most candidates are out of domain, so the prior does
    not penalise novel categories.

    Returns:
        ``(prior, weight)``. When the CSV is missing, unreadable or empty the
        prior is uniform and the weight is ``0.0``.
    """
    denom = max(len(labels), 1)
    uniform = torch.ones(len(labels), dtype=torch.float32) / denom

    if not Path(INTERNAL_DATASET_CSV).exists():
        return uniform, 0.0
    try:
        frame = pd.read_csv(INTERNAL_DATASET_CSV, usecols=["hierarchy"]).dropna()
    except Exception:
        # Best-effort: a broken CSV simply disables the prior.
        return uniform, 0.0
    if len(frame) == 0:
        return uniform, 0.0

    canonical = [
        normalize_hierarchy_label(value) for value in frame["hierarchy"].astype(str)
    ]
    counts = pd.Series(canonical).value_counts().to_dict()
    total_count = sum(counts.values())
    fallback_prob = 1.0 / denom

    matched = 0
    raw_probs = []
    for name in labels:
        if name in counts:
            matched += 1
            raw_probs.append(counts[name] / total_count)
        else:
            raw_probs.append(fallback_prob)

    prior = torch.tensor(raw_probs, dtype=torch.float32)
    prior = prior / prior.sum()

    overlap = matched / denom
    return prior, 0.15 * (overlap ** 2)
|
|
|
|
def _encode_images_batched(
    model, processor, device, pil_images: List[Image.Image], batch_size: int, desc: str,
    tta: bool = False,
) -> torch.Tensor:
    """Embed PIL images batch by batch and return L2-normalized rows.

    Images are passed to `encode_image` in slices of `batch_size`, with a
    progress bar labelled `desc`. If `tta` is true, each batch is also
    encoded after a horizontal flip; the two normalized embeddings are
    averaged and normalized again, at the cost of doubling encoding time.

    Returns:
        All batch embeddings concatenated along dim 0, or an empty
        ``(0, 512)`` tensor on `device` when `pil_images` is empty.
    """

    def _unit_embed(images) -> torch.Tensor:
        raw = encode_image(model, processor, images, device).to(device).float()
        return F.normalize(raw, dim=-1)

    chunks: List[torch.Tensor] = []
    for offset in tqdm(range(0, len(pil_images), batch_size), desc=desc):
        window = pil_images[offset : offset + batch_size]
        vectors = _unit_embed(window)
        if tta:
            mirrored = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in window]
            vectors = F.normalize((vectors + _unit_embed(mirrored)) / 2.0, dim=-1)
        chunks.append(vectors)

    if chunks:
        return torch.cat(chunks, dim=0)
    return torch.empty(0, 512, device=device)
|
|
|
|
def run_zero_shot_scoring(
    img_embs: torch.Tensor,
    text_embs_single: torch.Tensor,
    text_embs_ensembled: torch.Tensor,
    candidate_labels: List[str],
    all_labels: np.ndarray,
    cfg: RuntimeConfig,
    dataset_key: str,
    mask_color: bool = False,
    aux_img_embs: Optional[torch.Tensor] = None,
    aux_text_embs_single: Optional[torch.Tensor] = None,
    spec_img_embs: Optional[torch.Tensor] = None,
    spec_text_embs: Optional[torch.Tensor] = None,
) -> Dict[str, float]:
    """Shared scoring pipeline for Test D.

    Returns a metrics dict with the paper's baseline protocol plus every
    ablation step (prompt ensembling, per-subspace cosine, z-score fusion,
    fusion + adaptive label prior, probability ensembles, top-k re-ranking,
    optional hybrid / specialist boosts, and type-aware scoring).

    Args:
        img_embs: Image embeddings, one row per sample.
        text_embs_single: Text embeddings from one prompt per label
            (paper baseline).
        text_embs_ensembled: Prompt-ensembled text embeddings, one row per
            label.
        candidate_labels: Label names; index ``i`` matches row ``i`` of the
            text embeddings and class id ``i`` in `all_labels`.
        all_labels: Ground-truth class ids, one per image.
        cfg: Supplies the subspace sizes used by the fused scoring.
        dataset_key: Selects per-dataset settings from
            ``DATASET_FUSION_WEIGHTS``, ``DATASET_ENSEMBLE_PARAMS``,
            ``DATASET_RERANK_PARAMS`` and ``DATASET_TYPE_AWARE_PARAMS``;
            unknown keys fall back to the defaults written below.
        mask_color: Appropriate for grayscale datasets (FMNIST). Zeros the
            color dims of the image embeddings for fused and type-aware
            scoring only; the paper-protocol baseline is left untouched.
        aux_img_embs, aux_text_embs_single: Optional embeddings from an
            auxiliary model; the hybrid metrics are computed only when both
            are given.
        spec_img_embs, spec_text_embs: Optional embeddings from a specialist
            model; the pure-boost metrics are computed only when both are
            given.

    Returns:
        ``{}`` when `all_labels` is empty. Otherwise a dict of scalar
        metrics, plus two nested ``label -> F1`` dicts under
        ``per_class_f1_paper`` and ``per_class_f1_fused``.
    """
    if len(all_labels) == 0:
        return {}

    def _f1(preds: np.ndarray) -> float:
        # zero_division=0 matches `_macro_f1` and the per-class calls; the
        # score is the same as the default but no UndefinedMetricWarning is
        # emitted for classes that are never predicted.
        return float(
            f1_score(all_labels, preds, average="weighted", zero_division=0)
        )

    def _macro_f1(preds: np.ndarray) -> float:
        return float(f1_score(all_labels, preds, average="macro", zero_division=0))

    def _acc(preds: np.ndarray) -> float:
        return float((preds == all_labels).mean())

    # Paper baseline: single prompt per label, full-embedding similarity.
    # The similarity matrix is kept because re-ranking reuses it below.
    s_full_single = img_embs @ text_embs_single.T
    preds_paper = s_full_single.argmax(dim=-1).cpu().numpy()

    # Prompt-ensembled text embeddings, still on the full embedding.
    preds_full_ens = (img_embs @ text_embs_ensembled.T).argmax(dim=-1).cpu().numpy()

    # Per-subspace cosines and their z-score fusion.
    weights = DATASET_FUSION_WEIGHTS.get(dataset_key, (0.5, 0.7, 0.3, 0.0))
    scores = compute_fused_scores(
        img_embs, text_embs_ensembled, cfg, weights, mask_color=mask_color,
    )
    preds_gen = scores["gen"].argmax(dim=-1).cpu().numpy()
    preds_hier = scores["hier"].argmax(dim=-1).cpu().numpy()
    preds_nocolor = scores["nocolor"].argmax(dim=-1).cpu().numpy()
    preds_fused = scores["fused"].argmax(dim=-1).cpu().numpy()

    # Fused logits mixed with the adaptive label prior.
    probs, prior_w = apply_label_prior(scores["fused"], candidate_labels)
    preds_fused_prior = probs.argmax(dim=-1).cpu().numpy()

    # Probability ensemble over the sub-scores. Every component except
    # "full" is z-scored before being handed to the ensemble.
    sub_for_ens = {
        "full": scores["full"],
        "gen": _zscore_rowwise(scores["gen"]),
        "hier": _zscore_rowwise(scores["hier"]),
        "nocolor": _zscore_rowwise(scores["nocolor"]),
        "color": _zscore_rowwise(scores["color"]),
    }
    ens_params = DATASET_ENSEMBLE_PARAMS.get(dataset_key, EnsembleParams())
    p_ens = compute_prob_ensemble(sub_for_ens, ens_params)
    preds_prob_ens = p_ens.argmax(dim=-1).cpu().numpy()

    # Adaptive variant of the probability ensemble (default parameters).
    p_ens_adapt = compute_prob_ensemble_adaptive(sub_for_ens, AdaptiveEnsembleParams())
    preds_prob_ens_adapt = p_ens_adapt.argmax(dim=-1).cpu().numpy()

    # Top-k re-ranking of the fused scores using the single-prompt
    # full-embedding similarity.
    rerank_k, rerank_w = DATASET_RERANK_PARAMS.get(dataset_key, (3, 0.5))
    preds_rerank = (
        rerank_top_k(scores["fused"], s_full_single, k=rerank_k, rerank_weight=rerank_w)
        .cpu().numpy()
    )

    # Optional hybrid metrics: fused scores combined with an auxiliary model.
    hybrid_results: Dict[str, float] = {}
    if aux_img_embs is not None and aux_text_embs_single is not None:
        aux_full_single = aux_img_embs @ aux_text_embs_single.T
        hybrid_preds = compute_hybrid_metrics(
            scores["fused"], aux_full_single, dataset_key=dataset_key,
        )
        for name, preds_t in hybrid_preds.items():
            preds_np = preds_t.cpu().numpy()
            hybrid_results[f"f1_{name}"] = _f1(preds_np)

    # Optional pure-boost metrics: fused scores combined with a specialist.
    pure_boost_results: Dict[str, float] = {}
    if spec_img_embs is not None and spec_text_embs is not None:
        s_spec = spec_img_embs @ spec_text_embs.T
        pb_preds = compute_pure_boost_metrics(
            scores["fused"], s_spec, dataset_key=dataset_key,
        )
        for name, preds_t in pb_preds.items():
            preds_np = preds_t.cpu().numpy()
            pure_boost_results[f"f1_{name}"] = _f1(preds_np)

    # Type-aware scoring and its two ablations (no prior / no gating).
    ta_params = DATASET_TYPE_AWARE_PARAMS.get(dataset_key, TypeAwareParams())
    ta = compute_type_aware_scores(
        img_embs, text_embs_ensembled, candidate_labels, cfg, ta_params,
        extractor=_HIERARCHY_EXTRACTOR, normalize_fn=normalize_hierarchy_label,
        mask_color=mask_color,
    )
    preds_type_aware = ta["fused_ta"].argmax(dim=-1).cpu().numpy()
    preds_ta_no_prior = ta["fused_ta_no_prior"].argmax(dim=-1).cpu().numpy()
    preds_ta_no_gating = ta["fused_ta_no_gating"].argmax(dim=-1).cpu().numpy()

    # Type-aware diagnostics: parse rate, mean per-row entropy of P_type
    # (clamped so log(0) cannot occur), and the mean of C.
    parse_rate = float(ta["parse_rate"].item())
    P_type = ta["P_type"]
    p_log = torch.log(P_type.clamp_min(1e-12))
    type_entropy = float(-(P_type * p_log).sum(dim=-1).mean().item())
    mean_C = float(ta["C"].mean().item())

    # Per-class F1 for the baseline and the fused predictions. Passing the
    # explicit class ids keeps the output aligned with `candidate_labels`
    # even when a class never appears in the data.
    class_ids = list(range(len(candidate_labels)))
    per_class_paper = f1_score(
        all_labels, preds_paper, labels=class_ids,
        average=None, zero_division=0,
    )
    per_class_fused = f1_score(
        all_labels, preds_fused, labels=class_ids,
        average=None, zero_division=0,
    )

    return {
        # Paper protocol
        "accuracy": _acc(preds_paper),
        "weighted_f1": _f1(preds_paper),
        "macro_f1": _macro_f1(preds_paper),
        # Ablations on the main model
        "f1_full_ensembled": _f1(preds_full_ens),
        "f1_gen": _f1(preds_gen),
        "f1_hier": _f1(preds_hier),
        "f1_nocolor": _f1(preds_nocolor),
        "f1_fused": _f1(preds_fused),
        "macro_f1_fused": _macro_f1(preds_fused),
        "f1_fused_prior": _f1(preds_fused_prior),
        # Ensembles and re-ranking
        "f1_prob_ens": _f1(preds_prob_ens),
        "f1_prob_ens_adaptive": _f1(preds_prob_ens_adapt),
        "f1_rerank": _f1(preds_rerank),
        # Optional hybrid metrics (empty without an auxiliary model)
        **hybrid_results,
        # Optional pure-boost metrics (empty without a specialist model)
        **pure_boost_results,
        # Type-aware scoring and diagnostics
        "f1_type_aware": _f1(preds_type_aware),
        "f1_type_aware_no_prior": _f1(preds_ta_no_prior),
        "f1_type_aware_no_gating": _f1(preds_ta_no_gating),
        "type_parse_rate": parse_rate,
        "type_entropy": type_entropy,
        "mean_C": mean_C,
        "prior_weight": prior_w,
        "num_samples": int(len(all_labels)),
        "num_labels": len(candidate_labels),
        "per_class_f1_paper": {
            lbl: float(per_class_paper[i]) for i, lbl in enumerate(candidate_labels)
        },
        "per_class_f1_fused": {
            lbl: float(per_class_fused[i]) for i, lbl in enumerate(candidate_labels)
        },
    }
|
|
|
|
def _maybe_specialist_embeddings(
    spec_model, pil_images, candidate_labels, batch_size, device, desc, tta=True,
):
    """Encode images and labels with the specialist model, if one is supplied.

    Returns ``(spec_img_embs, spec_text_embs)``. When `spec_model` is ``None``
    nothing is encoded and ``(None, None)`` is returned, which callers pass
    straight through to disable the specialist metrics.
    """
    if spec_model is not None:
        image_side = encode_images_with_specialist_tta(
            spec_model, pil_images, batch_size, device, desc=desc, tta=tta,
        )
        text_side = encode_text_with_specialist_ensembled(
            spec_model, candidate_labels, ZERO_SHOT_TEMPLATES, device,
        )
        return image_side, text_side
    return None, None
|
|
|
|
def zero_shot_fashion_mnist(
    model,
    processor,
    device,
    cfg: RuntimeConfig,
    batch_size: int = 64,
    data_root: str = "./data",
    aux_model=None,
    aux_processor=None,
    spec_model=None,
    image_tta: bool = False,
) -> Dict[str, float]:
    """Zero-shot evaluation on the full Fashion-MNIST test split.

    The test split is downloaded to `data_root` if needed, converted to
    3-channel grayscale and colour-inverted before encoding. Text prompts are
    built from the dataset's own class names. With `image_tta`, each batch is
    also encoded after a horizontal flip and the two embeddings are averaged.
    Auxiliary and specialist models are optional and only add extra metrics.

    Scoring is delegated to `run_zero_shot_scoring` with
    ``dataset_key="fmnist"`` and ``mask_color=True``. A summary is printed
    and the metrics dict is returned.
    """

    def _collate(samples):
        # Keep PIL images as a plain list; only the targets become a tensor.
        images = [sample[0] for sample in samples]
        targets = torch.tensor([sample[1] for sample in samples])
        return images, targets

    def _unit_embed(encoder, proc, images) -> torch.Tensor:
        raw = encode_image(encoder, proc, images, device).to(device).float()
        return F.normalize(raw, dim=-1)

    fmnist = datasets.FashionMNIST(
        root=data_root, train=False, download=True,
        transform=transforms.Grayscale(num_output_channels=3),
    )
    loader = DataLoader(
        fmnist, batch_size=batch_size, shuffle=False, collate_fn=_collate,
    )

    candidate_labels = list(fmnist.classes)

    single_prompts = [f"a photo of a {label}" for label in candidate_labels]
    text_embs_single = get_text_embeddings_batch(
        model, processor, device, single_prompts,
    ).to(device).float()
    text_embs_ens = get_prompt_ensembled_text_embeddings(
        model, processor, device, candidate_labels, ZERO_SHOT_TEMPLATES,
    ).to(device).float()

    has_aux = aux_model is not None and aux_processor is not None
    aux_text_embs_single = None
    if has_aux:
        aux_text_embs_single = get_text_embeddings_batch(
            aux_model, aux_processor, device, single_prompts,
        ).to(device).float()

    # Single pass over the loader: main embeddings, optional auxiliary
    # embeddings, and the PIL images kept for the optional specialist.
    main_chunks: List[torch.Tensor] = []
    aux_chunks: List[torch.Tensor] = []
    all_pil: List[Image.Image] = []
    all_gt: List[int] = []
    for batch_images, batch_targets in tqdm(loader, desc="Zero-shot Fashion-MNIST"):
        batch_images = [ImageOps.invert(img) for img in batch_images]
        emb = _unit_embed(model, processor, batch_images)
        if image_tta:
            flipped = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in batch_images]
            emb = F.normalize(
                (emb + _unit_embed(model, processor, flipped)) / 2.0, dim=-1,
            )
        main_chunks.append(emb)
        if has_aux:
            aux_chunks.append(_unit_embed(aux_model, aux_processor, batch_images))
        all_pil.extend(batch_images)
        all_gt.extend(batch_targets.tolist())

    if main_chunks:
        img_embs = torch.cat(main_chunks, dim=0)
    else:
        img_embs = torch.empty(0, 512, device=device)
    aux_img_embs = torch.cat(aux_chunks, dim=0) if aux_chunks else None
    all_labels = np.asarray(all_gt, dtype=np.int64)

    spec_img_embs, spec_text_embs = _maybe_specialist_embeddings(
        spec_model, all_pil, candidate_labels, batch_size, device,
        desc="FMNIST specialist", tta=image_tta,
    )

    metrics = run_zero_shot_scoring(
        img_embs, text_embs_single, text_embs_ens, candidate_labels, all_labels,
        cfg, dataset_key="fmnist", mask_color=True,
        aux_img_embs=aux_img_embs, aux_text_embs_single=aux_text_embs_single,
        spec_img_embs=spec_img_embs, spec_text_embs=spec_text_embs,
    )

    def _m(key: str):
        # Missing metrics print as 0 rather than raising.
        return metrics.get(key, 0)

    print(
        "FMNIST zero-shot "
        f"paper={_m('weighted_f1'):.4f} "
        f"ens_full={_m('f1_full_ensembled'):.4f} "
        f"gen={_m('f1_gen'):.4f} "
        f"hier={_m('f1_hier'):.4f} "
        f"nocolor={_m('f1_nocolor'):.4f} "
        f"fused={_m('f1_fused'):.4f} "
        f"fused+prior={_m('f1_fused_prior'):.4f}"
    )
    print(
        "FMNIST ensemble "
        f"prob_ens={_m('f1_prob_ens'):.4f} "
        f"prob_ens_adaptive={_m('f1_prob_ens_adaptive'):.4f} "
        f"rerank_topk={_m('f1_rerank'):.4f}"
    )
    if any(k.startswith('f1_hybrid_') for k in metrics):
        print(
            "FMNIST hybrid "
            f"w30={_m('f1_hybrid_w30'):.4f} "
            f"w50={_m('f1_hybrid_w50'):.4f} "
            f"w70={_m('f1_hybrid_w70'):.4f} "
            f"rerank={_m('f1_hybrid_rerank'):.4f}"
        )
    if any(k.startswith('f1_pure_') for k in metrics):
        print(
            "FMNIST pure-boost "
            f"spec_only={_m('f1_pure_spec_only'):.4f} "
            f"w50={_m('f1_pure_boost_w50'):.4f} "
            f"w60={_m('f1_pure_boost_w60'):.4f} "
            f"w70={_m('f1_pure_boost_w70'):.4f}"
        )
    print(
        "FMNIST type-aware "
        f"ta={_m('f1_type_aware'):.4f} "
        f"ta_no_prior={_m('f1_type_aware_no_prior'):.4f} "
        f"ta_no_gating={_m('f1_type_aware_no_gating'):.4f} "
        f"parse_rate={_m('type_parse_rate'):.2f} "
        f"H(P_type)={_m('type_entropy'):.3f} "
        f"mean_C={_m('mean_C'):.3f}"
    )
    return metrics
|
|
|
|
|
|
def zero_shot_kagl(
    model,
    processor,
    device,
    cfg: RuntimeConfig,
    batch_size: int = 64,
    num_examples: int = 10000,
    aux_model=None,
    aux_processor=None,
    spec_model=None,
    image_tta: bool = False,
) -> Optional[Dict[str, float]]:
    """Zero-shot evaluation on KAGL Marqo using the `category2` field.

    Steps:
      1. Load ``Marqo/KAGL`` (split ``data``), shuffle with a fixed seed and
         keep at most `num_examples` rows.
      2. Keep rows that have a label and a decodable image.
      3. Drop rows whose raw label fails `is_clothing_label`.
      4. Map raw labels to canonical ones with `normalize_hierarchy_label`;
         the sorted set of canonical labels becomes the candidate list.
      5. Encode text (single prompt and descriptor ensemble) and images, then
         delegate to `run_zero_shot_scoring` with ``dataset_key="kagl"``.

    Diagnostics are printed along the way.

    Returns:
        The metrics dict, or ``None`` when the `datasets` package or the
        dataset itself is unavailable, or when no sample survives filtering.
    """
    from collections import Counter

    try:
        from datasets import load_dataset
    except Exception:
        print("Skipping zero_shot_kagl: datasets package not available")
        return None

    try:
        dataset = load_dataset("Marqo/KAGL", split="data")
    except Exception as exc:
        print(f"Skipping zero_shot_kagl: failed to load dataset ({exc})")
        return None

    subset_size = min(num_examples, len(dataset))
    dataset = dataset.shuffle(seed=42).select(range(subset_size))

    def _as_rgb(obj):
        # Accept a PIL-like object or a dict carrying raw bytes.
        if hasattr(obj, "convert"):
            return obj.convert("RGB")
        if isinstance(obj, dict) and "bytes" in obj:
            return Image.open(BytesIO(obj["bytes"])).convert("RGB")
        return None

    pil_images: List[Image.Image] = []
    labels_text: List[str] = []
    for item in dataset:
        raw_label = item.get("category2")
        image_obj = item.get("image")
        if raw_label is None or image_obj is None:
            continue
        image = _as_rgb(image_obj)
        if image is None:
            continue
        pil_images.append(image)
        labels_text.append(str(raw_label).strip())

    if not pil_images:
        print("Skipping zero_shot_kagl: no valid samples")
        return None

    # Report which raw labels are about to be dropped.
    raw_counts = Counter(labels_text)
    print(f" KAGL: raw samples loaded = {len(labels_text)}, unique raw labels = {len(raw_counts)}")
    oov_raw = sorted({lbl for lbl in raw_counts if not is_clothing_label(lbl)})
    if oov_raw:
        oov_total = sum(raw_counts[lbl] for lbl in oov_raw)
        print(f" KAGL: {len(oov_raw)} OOV raw labels covering {oov_total} samples (dropped): "
              f"{oov_raw[:15]}{'...' if len(oov_raw) > 15 else ''}")

    # Keep only samples whose raw label is recognised as clothing.
    kept_pairs = [
        (img, lbl) for img, lbl in zip(pil_images, labels_text) if is_clothing_label(lbl)
    ]
    dropped = len(labels_text) - len(kept_pairs)
    if dropped > 0:
        print(f" KAGL: filtered out {dropped} non-clothing samples "
              f"({dropped / len(labels_text):.1%})")
        pil_images = [img for img, _ in kept_pairs]
        labels_text = [lbl for _, lbl in kept_pairs]

    if not pil_images:
        print("Skipping zero_shot_kagl: no clothing samples after filter")
        return None

    # Collapse raw labels onto canonical labels and log the mapping.
    canonical_labels = [normalize_hierarchy_label(lbl) for lbl in labels_text]
    raw_to_canonical: Dict[str, Counter] = {}
    for raw, canon in zip(labels_text, canonical_labels):
        if raw not in raw_to_canonical:
            raw_to_canonical[raw] = Counter()
        raw_to_canonical[raw][canon] += 1
    print(f" KAGL: filtered samples = {len(canonical_labels)}, "
          f"unique canonical labels = {len(set(canonical_labels))}")
    print(" KAGL: raw -> canonical mapping (sample counts):")
    for raw in sorted(raw_to_canonical):
        items = ", ".join(f"{c}={n}" for c, n in raw_to_canonical[raw].most_common())
        print(f" {raw!r:24s} -> {items}")

    candidate_labels = sorted(set(canonical_labels))
    label_to_idx = {label: idx for idx, label in enumerate(candidate_labels)}
    all_labels = np.array(
        [label_to_idx[label] for label in canonical_labels], dtype=np.int64,
    )
    canonical_counts = Counter(canonical_labels)
    per_class_summary = ", ".join(
        f"{lbl}={canonical_counts[lbl]}" for lbl in candidate_labels
    )
    print(" KAGL: per-class sample counts: " + per_class_summary)

    # Text side: one prompt per label, plus the descriptor-based ensemble.
    single_prompts = [f"a photo of a {label}" for label in candidate_labels]
    text_embs_single = get_text_embeddings_batch(
        model, processor, device, single_prompts,
    ).to(device).float()
    text_embs_ens = get_descriptor_ensembled_text_embeddings(
        model, processor, device, KAGL_COARSE_DESCRIPTORS,
        candidate_labels, ZERO_SHOT_TEMPLATES,
    ).to(device).float()

    # Image side: main model, then optional auxiliary and specialist models.
    img_embs = _encode_images_batched(
        model, processor, device, pil_images, batch_size, desc="Zero-shot KAGL",
        tta=image_tta,
    )
    aux_img_embs = None
    aux_text_embs_single = None
    if aux_model is not None and aux_processor is not None:
        aux_text_embs_single = get_text_embeddings_batch(
            aux_model, aux_processor, device, single_prompts,
        ).to(device).float()
        aux_img_embs = _encode_images_batched(
            aux_model, aux_processor, device, pil_images, batch_size,
            desc="Zero-shot KAGL (aux)",
        )
    spec_img_embs, spec_text_embs = _maybe_specialist_embeddings(
        spec_model, pil_images, candidate_labels, batch_size, device,
        desc="KAGL specialist", tta=image_tta,
    )

    metrics = run_zero_shot_scoring(
        img_embs, text_embs_single, text_embs_ens, candidate_labels, all_labels,
        cfg, dataset_key="kagl", mask_color=False,
        aux_img_embs=aux_img_embs, aux_text_embs_single=aux_text_embs_single,
        spec_img_embs=spec_img_embs, spec_text_embs=spec_text_embs,
    )

    def _m(key: str):
        # Missing metrics print as 0 rather than raising.
        return metrics.get(key, 0)

    print(
        "KAGL zero-shot "
        f"paper={_m('weighted_f1'):.4f} "
        f"macro={_m('macro_f1'):.4f} "
        f"ens_full={_m('f1_full_ensembled'):.4f} "
        f"gen={_m('f1_gen'):.4f} "
        f"hier={_m('f1_hier'):.4f} "
        f"nocolor={_m('f1_nocolor'):.4f} "
        f"fused={_m('f1_fused'):.4f} "
        f"macro_fused={_m('macro_f1_fused'):.4f} "
        f"fused+prior={_m('f1_fused_prior'):.4f}"
    )
    pc_paper = metrics.get('per_class_f1_paper', {}) or {}
    pc_fused = metrics.get('per_class_f1_fused', {}) or {}
    if pc_paper:
        print(" KAGL per-class F1 (paper / fused):")
        for lbl in sorted(pc_paper):
            print(f" {lbl:14s} paper={pc_paper.get(lbl, 0):.3f} "
                  f"fused={pc_fused.get(lbl, 0):.3f}")
    print(
        "KAGL ensemble "
        f"prob_ens={_m('f1_prob_ens'):.4f} "
        f"prob_ens_adaptive={_m('f1_prob_ens_adaptive'):.4f} "
        f"rerank_topk={_m('f1_rerank'):.4f}"
    )
    if any(k.startswith('f1_hybrid_') for k in metrics):
        print(
            "KAGL hybrid "
            f"w30={_m('f1_hybrid_w30'):.4f} "
            f"w50={_m('f1_hybrid_w50'):.4f} "
            f"w70={_m('f1_hybrid_w70'):.4f} "
            f"rerank={_m('f1_hybrid_rerank'):.4f}"
        )
    if any(k.startswith('f1_pure_') for k in metrics):
        print(
            "KAGL pure-boost "
            f"spec_only={_m('f1_pure_spec_only'):.4f} "
            f"w30={_m('f1_pure_boost_w30'):.4f} "
            f"w40={_m('f1_pure_boost_w40'):.4f} "
            f"w50={_m('f1_pure_boost_w50'):.4f}"
        )
    print(
        "KAGL type-aware "
        f"ta={_m('f1_type_aware'):.4f} "
        f"ta_no_prior={_m('f1_type_aware_no_prior'):.4f} "
        f"ta_no_gating={_m('f1_type_aware_no_gating'):.4f} "
        f"parse_rate={_m('type_parse_rate'):.2f} "
        f"H(P_type)={_m('type_entropy'):.3f} "
        f"mean_C={_m('mean_C'):.3f}"
    )
    return metrics
|
|
|
|
def zero_shot_internal(
    model,
    processor,
    device,
    cfg: RuntimeConfig,
    batch_size: int = 64,
    num_examples: int = 10000,
    csv_path: str = INTERNAL_DATASET_CSV,
    aux_model=None,
    aux_processor=None,
    spec_model=None,
    image_tta: bool = False,
) -> Optional[Dict[str, float]]:
    """Zero-shot evaluation on the internal dataset described by `csv_path`.

    The CSV must have a `hierarchy` column and either `local_image_path`
    (preferred when present) or `image_url`. Rows are shuffled with a fixed
    seed and read until `num_examples` images have loaded; rows whose image
    cannot be opened or downloaded are skipped silently. Labels are
    normalized with `normalize_hierarchy_label`, and the sorted set of
    observed labels is the candidate list.

    Scoring is delegated to `run_zero_shot_scoring` with
    ``dataset_key="internal"``; a summary is printed.

    Returns:
        The metrics dict, or ``None`` when the CSV is missing, lacks the
        required columns, or yields no loadable image.
    """
    csv_file = Path(csv_path)
    if not csv_file.exists():
        print(f"Skipping zero_shot_internal: {csv_path} not found")
        return None

    df = pd.read_csv(csv_file)
    use_local = "local_image_path" in df.columns
    img_col = "local_image_path" if use_local else "image_url"
    required_cols = {"hierarchy", img_col}
    if not required_cols.issubset(df.columns):
        print(f"Skipping zero_shot_internal: missing required columns {required_cols}")
        return None

    def _open_local(path_value):
        # Try the recorded path first, then the same file name under
        # data/images; None means neither location exists.
        img_path = Path(str(path_value))
        if not img_path.exists():
            img_path = Path("data/images") / img_path.name
            if not img_path.exists():
                return None
        return Image.open(img_path).convert("RGB")

    def _download(url_value):
        response = requests.get(str(url_value), timeout=5)
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")

    df = df.dropna(subset=["hierarchy", img_col]).sample(frac=1.0, random_state=42)
    pil_images: List[Image.Image] = []
    labels_text: List[str] = []
    for _, row in df.iterrows():
        if len(pil_images) >= num_examples:
            break
        try:
            if use_local:
                image = _open_local(row["local_image_path"])
            else:
                image = _download(row["image_url"])
        except Exception:
            # Best-effort loading: unreadable or unreachable images are skipped.
            continue
        if image is None:
            continue
        label = normalize_hierarchy_label(str(row["hierarchy"]))
        pil_images.append(image)
        labels_text.append(label)

    if not pil_images:
        print("Skipping zero_shot_internal: no valid samples")
        return None

    candidate_labels = sorted(set(labels_text))
    label_to_idx = {label: idx for idx, label in enumerate(candidate_labels)}
    all_labels = np.array(
        [label_to_idx[label] for label in labels_text], dtype=np.int64,
    )

    # Text side: one prompt per label, plus the template ensemble.
    single_prompts = [f"a photo of a {label}" for label in candidate_labels]
    text_embs_single = get_text_embeddings_batch(
        model, processor, device, single_prompts,
    ).to(device).float()
    text_embs_ens = get_prompt_ensembled_text_embeddings(
        model, processor, device, candidate_labels, ZERO_SHOT_TEMPLATES,
    ).to(device).float()

    # Image side: main model, then optional auxiliary and specialist models.
    img_embs = _encode_images_batched(
        model, processor, device, pil_images, batch_size, desc="Zero-shot Internal",
        tta=image_tta,
    )
    aux_img_embs = None
    aux_text_embs_single = None
    if aux_model is not None and aux_processor is not None:
        aux_text_embs_single = get_text_embeddings_batch(
            aux_model, aux_processor, device, single_prompts,
        ).to(device).float()
        aux_img_embs = _encode_images_batched(
            aux_model, aux_processor, device, pil_images, batch_size,
            desc="Zero-shot Internal (aux)",
        )
    spec_img_embs, spec_text_embs = _maybe_specialist_embeddings(
        spec_model, pil_images, candidate_labels, batch_size, device,
        desc="Internal specialist", tta=image_tta,
    )

    metrics = run_zero_shot_scoring(
        img_embs, text_embs_single, text_embs_ens, candidate_labels, all_labels,
        cfg, dataset_key="internal", mask_color=False,
        aux_img_embs=aux_img_embs, aux_text_embs_single=aux_text_embs_single,
        spec_img_embs=spec_img_embs, spec_text_embs=spec_text_embs,
    )

    def _m(key: str):
        # Missing metrics print as 0 rather than raising.
        return metrics.get(key, 0)

    print(
        "Internal zero-shot "
        f"paper={_m('weighted_f1'):.4f} "
        f"ens_full={_m('f1_full_ensembled'):.4f} "
        f"gen={_m('f1_gen'):.4f} "
        f"hier={_m('f1_hier'):.4f} "
        f"nocolor={_m('f1_nocolor'):.4f} "
        f"fused={_m('f1_fused'):.4f} "
        f"fused+prior={_m('f1_fused_prior'):.4f}"
    )
    print(
        "Internal ensemble "
        f"prob_ens={_m('f1_prob_ens'):.4f} "
        f"prob_ens_adaptive={_m('f1_prob_ens_adaptive'):.4f} "
        f"rerank_topk={_m('f1_rerank'):.4f}"
    )
    if any(k.startswith('f1_hybrid_') for k in metrics):
        print(
            "Internal hybrid "
            f"w30={_m('f1_hybrid_w30'):.4f} "
            f"w50={_m('f1_hybrid_w50'):.4f} "
            f"w70={_m('f1_hybrid_w70'):.4f} "
            f"rerank={_m('f1_hybrid_rerank'):.4f}"
        )
    if any(k.startswith('f1_pure_') for k in metrics):
        print(
            "Internal pure-boost "
            f"spec_only={_m('f1_pure_spec_only'):.4f} "
            f"w40={_m('f1_pure_boost_w40'):.4f} "
            f"w50={_m('f1_pure_boost_w50'):.4f} "
            f"w60={_m('f1_pure_boost_w60'):.4f}"
        )
    print(
        "Internal type-aware "
        f"ta={_m('f1_type_aware'):.4f} "
        f"ta_no_prior={_m('f1_type_aware_no_prior'):.4f} "
        f"ta_no_gating={_m('f1_type_aware_no_gating'):.4f} "
        f"parse_rate={_m('type_parse_rate'):.2f} "
        f"H(P_type)={_m('type_entropy'):.3f} "
        f"mean_C={_m('mean_C'):.3f}"
    )
    return metrics
|
|
|
|
# Exact-match lookup from lower-cased dataset category strings to internal
# hierarchy labels.  Built once at import time rather than on every call.
# NOTE(review): a few near-duplicate spellings map to different targets
# ("t-shirt" -> "top" but "tshirt" -> "shirt"; "stockings" -> "legging" but
# "stocking" -> "socks").  They are preserved as found; confirm whether the
# split is intentional before unifying them.
_HIERARCHY_LABEL_SYNONYMS = {
    "t-shirt/top": "top",
    "top": "top",
    "tee": "top",
    "t-shirt": "top",
    "shirt": "shirt",
    "shirts": "shirt",
    "pullover": "sweater",
    "sweater": "sweater",
    "coat": "coat",
    "jacket": "jacket",
    "outerwear": "coat",
    "trouser": "pant",
    "trousers": "pant",
    "pants": "pant",
    "pant": "pant",
    "jeans": "pant",
    "dress": "dress",
    "skirt": "skirt",
    "shorts": "short",
    "short": "short",
    "sandal": "shoes",
    "sneaker": "shoes",
    "ankle boot": "shoes",
    "shoe": "shoes",
    "shoes": "shoes",
    "flip flops": "shoes",
    "footwear": "shoes",
    "shoe accessories": "shoes",
    "bag": "accessories",
    "bags": "accessories",
    "accessory": "accessories",
    "accessories": "accessories",
    "belts": "accessories",
    "eyewear": "accessories",
    "jewellery": "accessories",
    "jewelry": "accessories",
    "headwear": "accessories",
    "wallets": "accessories",
    "watches": "accessories",
    "mufflers": "accessories",
    "scarves": "accessories",
    "stoles": "accessories",
    "ties": "accessories",
    "topwear": "top",
    "bottomwear": "pant",
    "innerwear": "underwear",
    "loungewear and nightwear": "underwear",
    "saree": "dress",
    "boots": "shoes",
    "outer": "coat",
    "sunglasses": "accessories",
    "scarf & tie": "accessories",
    "scarf/tie": "accessories",
    "belt": "accessories",

    "tshirts": "shirt",
    "tshirt": "shirt",
    "tunics": "top",
    "tunic": "top",
    "kurta": "top",
    "kurtas": "top",
    "kurti": "top",
    "kurtis": "top",
    "blouse": "shirt",
    "blouses": "shirt",
    "camisoles": "top",
    "camisole": "top",
    "sweatshirt": "sweater",
    "sweatshirts": "sweater",
    "sweaters": "sweater",
    "jumper": "sweater",
    "jumpers": "sweater",
    "hoodie": "sweater",
    "hoodies": "sweater",
    "cardigan": "sweater",
    "cardigans": "sweater",
    "jackets": "jacket",
    "blazers": "jacket",
    "blazer": "jacket",
    "coats": "coat",
    "tracksuit": "jacket",
    "tracksuits": "jacket",
    "track pants": "pant",
    "lounge pants": "pant",
    "salwar": "pant",
    "salwar and dupatta": "pant",
    "patiala": "pant",
    "churidar": "pant",
    "churidars": "pant",
    "capris": "pant",
    "capri": "pant",
    "leggings": "legging",
    "tights": "legging",
    "stockings": "legging",
    "lounge shorts": "short",
    "skirts": "skirt",
    "skorts": "skirt",
    "skort": "skirt",
    "dresses": "dress",
    "nightdress": "dress",
    "nightdresses": "dress",
    "night suits": "dress",
    "night dress": "dress",
    "lounge tshirts": "top",
    "sarees": "dress",
    "lehenga choli": "dress",
    "lehenga": "dress",
    "cholis": "top",
    "choli": "top",
    "innerwear vests": "underwear",
    "boxers": "underwear",
    "boxer": "underwear",
    "briefs": "underwear",
    "brief": "underwear",
    "trunks": "underwear",
    "trunk": "underwear",
    "bra": "bras",
    "swim": "swimwear",
    "swimsuit": "swimwear",
    "swimsuits": "swimwear",
    "swim suit": "swimwear",
    "swimwear and beach wear": "swimwear",
    "rompers": "bodysuits",
    "romper": "bodysuits",
    "jumpsuits": "bodysuits",
    "jumpsuit": "bodysuits",
    "bodysuit": "bodysuits",
    "playsuit": "bodysuits",
    "playsuits": "bodysuits",
    "polos": "polo",
    "polo shirt": "polo",
    "polo shirts": "polo",
    "polo t-shirts": "polo",
    "casual shoes": "shoes",
    "formal shoes": "shoes",
    "sports shoes": "shoes",
    "sandals": "shoes",
    "flats": "shoes",
    "heels": "shoes",
    "booties": "shoes",
    "loafers": "shoes",
    "slippers": "shoes",
    "stocking": "socks",
    "handbags": "accessories",
    "handbag": "accessories",
    "backpacks": "accessories",
    "backpack": "accessories",
    "clutches": "accessories",
    "clutch": "accessories",
    "earrings": "accessories",
    "earring": "accessories",
    "necklaces": "accessories",
    "necklace": "accessories",
    "necklace and chains": "accessories",
    "rings": "accessories",
    "ring": "accessories",
    "bracelets": "accessories",
    "bracelet": "accessories",
    "anklets": "accessories",
    "anklet": "accessories",
    "bangles": "accessories",
    "bangle": "accessories",
    "cufflinks": "accessories",
    "pendants": "accessories",
    "pendant": "accessories",
    "caps": "accessories",
    "cap": "accessories",
    "hat": "accessories",
    "hats": "accessories",
    "duppata": "accessories",
    "dupatta": "accessories",
    "dupatta and stoles": "accessories",
    "scarf": "accessories",
    "stole": "accessories",
    "muffler": "accessories",
    "wallet": "accessories",
    "watch": "accessories",
    "tie": "accessories",
    "gloves": "accessories",
    "glove": "accessories",
}

# Last-resort substring fallbacks, tried in order; the first keyword found
# anywhere inside the label wins.
_HIERARCHY_LABEL_KEYWORD_FALLBACKS = (
    ("capri", "pant"),
    ("denim", "pant"),
    ("skinny", "pant"),
    ("boyfriend", "pant"),
    ("graphic", "top"),
    ("longsleeve", "top"),
    ("leather", "jacket"),
)


def normalize_hierarchy_label(raw_label: str) -> str:
    """Map a dataset category string to an internal hierarchy label.

    The input is stringified, stripped and lower-cased, then resolved in
    three stages, returning at the first hit:

    1. exact lookup in ``_HIERARCHY_LABEL_SYNONYMS``;
    2. ``_HIERARCHY_EXTRACTOR.extract_hierarchy(label)`` (module-level helper
       defined elsewhere in this file) when it returns a truthy value;
    3. substring match against ``_HIERARCHY_LABEL_KEYWORD_FALLBACKS``.

    Args:
        raw_label: Category name as it appears in the source dataset. Any
            object is accepted; it is converted with ``str()`` first.

    Returns:
        The mapped hierarchy label, or the cleaned (stripped, lower-cased)
        input unchanged when no stage recognises it.
    """
    label = str(raw_label).strip().lower()

    exact = _HIERARCHY_LABEL_SYNONYMS.get(label)
    if exact is not None:
        return exact

    result = _HIERARCHY_EXTRACTOR.extract_hierarchy(label)
    if result:
        return result

    for keyword, category in _HIERARCHY_LABEL_KEYWORD_FALLBACKS:
        if keyword in label:
            return category

    # Nothing matched: hand back the cleaned label so callers can still
    # compare it against their own vocabulary.
    return label
|
|
|
|
| |
| |
| |
| |
# Hierarchy labels that `is_clothing_label` treats as known; a label counts
# as clothing only if its normalized form is a member of this set.
_CLOTHING_VOCAB = frozenset(
    (
        "accessories",
        "bodysuits",
        "bras",
        "coat",
        "dress",
        "jacket",
        "legging",
        "pant",
        "polo",
        "shirt",
        "shoes",
        "short",
        "skirt",
        "socks",
        "sweater",
        "swimwear",
        "top",
        "underwear",
    )
)
|
|
|
|
def is_clothing_label(raw_label: str) -> bool:
    """Report whether `raw_label` normalizes to a known hierarchy label.

    The label is first passed through `normalize_hierarchy_label`; the result
    is then checked for membership in `_CLOTHING_VOCAB`.
    """
    canonical = normalize_hierarchy_label(raw_label)
    return canonical in _CLOTHING_VOCAB
|
|
|
|
| |
# ModaNet category id -> category name.
# NOTE(review): `load_modanet_samples` reads category names from the
# annotation file's own "categories" list and does not consult this table.
MODANET_CATEGORIES = {
    1: "bag",
    2: "belt",
    3: "boots",
    4: "footwear",
    5: "outer",
    6: "dress",
    7: "sunglasses",
    8: "pants",
    9: "top",
    10: "shorts",
    11: "skirt",
    12: "headwear",
    13: "scarf/tie",
}

# Locations of the COCO-style annotation file and the image directory, given
# as relative paths.
MODANET_ANNOTATIONS_JSON = "data/modanet_instances_train.json"
MODANET_IMAGES_DIR = "data/modanet_images/images"
|
|
|
|
def load_modanet_samples(
    num_examples: int,
) -> Tuple[List[Tuple[Image.Image, str]], List[Tuple[Image.Image, str]], List[Tuple[Image.Image, str]]]:
    """Return ``(baseline_samples, gap_samples, color_samples)`` from ModaNet.

    Loads from the local COCO-style JSON at ``MODANET_ANNOTATIONS_JSON`` and
    the image directory ``MODANET_IMAGES_DIR``. An image may carry several
    annotations; the one with the largest ``area`` decides its label (the
    first annotation wins ties).

    Args:
        num_examples: Maximum number of images to load. Images are visited in
            a fixed shuffled order (seed 42), so repeated calls pick the same
            subset.

    Returns:
        Three lists of ``(RGB image, label)`` pairs:

        * ``baseline_samples`` — labelled with the annotation file's own
          category name (``"unknown"`` if the category id is not listed);
        * ``gap_samples`` — the same images, labelled with
          ``normalize_hierarchy_label`` applied to that name;
        * ``color_samples`` — always an empty list.

        All three are empty when the annotation file or the image directory
        does not exist.
    """
    import json as _json

    ann_path = Path(MODANET_ANNOTATIONS_JSON)
    img_dir = Path(MODANET_IMAGES_DIR)

    if not ann_path.exists():
        print(f" Skipping ModaNet: annotations not found at {MODANET_ANNOTATIONS_JSON}")
        return [], [], []
    if not img_dir.exists():
        print(f" Skipping ModaNet: images directory not found at {MODANET_IMAGES_DIR}")
        return [], [], []

    print(" Loading ModaNet annotations...")
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(ann_path, encoding="utf-8") as f:
        coco = _json.load(f)

    cat_map = {c["id"]: c["name"] for c in coco["categories"]}
    img_map = {img["id"]: img["file_name"] for img in coco["images"]}

    # image_id -> (category_id, area) of the largest annotation seen so far.
    best_per_image: Dict[int, Tuple[int, float]] = {}
    for ann in coco["annotations"]:
        img_id = ann["image_id"]
        # A missing or null area counts as 0 so the comparison cannot fail.
        area = ann.get("area") or 0
        if img_id not in best_per_image or area > best_per_image[img_id][1]:
            best_per_image[img_id] = (ann["category_id"], area)

    # Private RNG: deterministic order without touching the global random state.
    image_ids = list(best_per_image)
    random.Random(42).shuffle(image_ids)

    baseline_samples: List[Tuple[Image.Image, str]] = []
    gap_samples: List[Tuple[Image.Image, str]] = []
    unreadable = 0

    for img_id in image_ids:
        if len(baseline_samples) >= num_examples:
            break
        file_name = img_map.get(img_id)
        if file_name is None:
            continue
        img_path = img_dir / file_name
        if not img_path.exists():
            continue
        try:
            # Image.open is lazy and keeps the file open; convert() yields an
            # independent in-memory image, so the handle is closed right away.
            with Image.open(img_path) as handle:
                image = handle.convert("RGB")
        except Exception:
            # Best effort: a corrupt or unsupported file must not abort the
            # run, but it is counted so the skip is visible.
            unreadable += 1
            continue

        cat_id, _ = best_per_image[img_id]
        native_label = cat_map.get(cat_id, "unknown")
        baseline_samples.append((image, native_label))
        gap_samples.append((image, normalize_hierarchy_label(native_label)))

    if unreadable:
        print(f" ModaNet: skipped {unreadable} unreadable image(s)")
    print(f" ModaNet: loaded {len(baseline_samples)} valid samples (from {len(best_per_image)} annotated images)")
    return baseline_samples, gap_samples, []
|
|
|
|
def zero_shot_modanet(
    model,
    processor,
    device,
    cfg: RuntimeConfig,
    batch_size: int = 64,
    num_examples: int = 10000,
    use_gap_labels: bool = True,
    aux_model=None,
    aux_processor=None,
    spec_model=None,
    image_tta: bool = False,
) -> Optional[Dict[str, float]]:
    """Run the zero-shot evaluation on ModaNet and print its metrics.

    Samples come from `load_modanet_samples`; the candidate classes are the
    distinct labels present in the loaded samples. Text embeddings are built
    both from a single prompt per class and from the `ZERO_SHOT_TEMPLATES`
    ensemble, then everything is handed to `run_zero_shot_scoring`.

    Args:
        model, processor, device: Main encoder, its processor, and the device
            used for the embeddings.
        cfg: Runtime configuration forwarded to `run_zero_shot_scoring`.
        batch_size: Batch size for image encoding.
        num_examples: Maximum number of ModaNet images to load.
        use_gap_labels: Use the normalized labels when True, otherwise the
            dataset's native category names.
        aux_model, aux_processor: Optional auxiliary encoder; used only when
            both are given. Its image encoding call does not pass `tta`.
        spec_model: Optional specialist model, forwarded to
            `_maybe_specialist_embeddings`.
        image_tta: Forwarded as `tta` to the main and specialist encoders.

    Returns:
        The metrics dict from `run_zero_shot_scoring`, or None when no
        samples could be loaded.
    """
    native_samples, normalized_samples, _ = load_modanet_samples(num_examples)
    samples = normalized_samples if use_gap_labels else native_samples
    if not samples:
        print("Skipping zero_shot_modanet: no valid samples")
        return None

    images = [pair[0] for pair in samples]
    sample_labels = [pair[1] for pair in samples]

    # Class list is the sorted set of labels actually present in the sample.
    class_names = sorted(set(sample_labels))
    class_index = {name: pos for pos, name in enumerate(class_names)}
    targets = np.array([class_index[name] for name in sample_labels], dtype=np.int64)

    prompts = [f"a photo of a {label}" for label in class_names]
    single_text_embs = get_text_embeddings_batch(model, processor, device, prompts)
    single_text_embs = single_text_embs.to(device).float()
    ensembled_text_embs = get_prompt_ensembled_text_embeddings(
        model, processor, device, class_names, ZERO_SHOT_TEMPLATES,
    )
    ensembled_text_embs = ensembled_text_embs.to(device).float()

    image_embs = _encode_images_batched(
        model, processor, device, images, batch_size, desc="Zero-shot ModaNet",
        tta=image_tta,
    )

    # The auxiliary encoder is optional and needs both model and processor.
    aux_image_embs = None
    aux_single_text_embs = None
    aux_available = aux_model is not None and aux_processor is not None
    if aux_available:
        aux_single_text_embs = get_text_embeddings_batch(
            aux_model, aux_processor, device, prompts,
        )
        aux_single_text_embs = aux_single_text_embs.to(device).float()
        aux_image_embs = _encode_images_batched(
            aux_model, aux_processor, device, images, batch_size,
            desc="Zero-shot ModaNet (aux)",
        )

    spec_image_embs, spec_text_embs = _maybe_specialist_embeddings(
        spec_model, images, class_names, batch_size, device,
        desc="ModaNet specialist", tta=image_tta,
    )

    metrics = run_zero_shot_scoring(
        image_embs, single_text_embs, ensembled_text_embs, class_names, targets,
        cfg, dataset_key="modanet", mask_color=False,
        aux_img_embs=aux_image_embs, aux_text_embs_single=aux_single_text_embs,
        spec_img_embs=spec_image_embs, spec_text_embs=spec_text_embs,
    )

    if use_gap_labels:
        label_kind = "GAP"
    else:
        label_kind = "native"

    print(
        f"ModaNet ({label_kind}) zero-shot "
        f"paper={metrics.get('weighted_f1', 0):.4f} "
        f"ens_full={metrics.get('f1_full_ensembled', 0):.4f} "
        f"gen={metrics.get('f1_gen', 0):.4f} "
        f"hier={metrics.get('f1_hier', 0):.4f} "
        f"nocolor={metrics.get('f1_nocolor', 0):.4f} "
        f"fused={metrics.get('f1_fused', 0):.4f} "
        f"fused+prior={metrics.get('f1_fused_prior', 0):.4f}"
    )
    print(
        f"ModaNet ({label_kind}) ensemble "
        f"prob_ens={metrics.get('f1_prob_ens', 0):.4f} "
        f"prob_ens_adaptive={metrics.get('f1_prob_ens_adaptive', 0):.4f} "
        f"rerank_topk={metrics.get('f1_rerank', 0):.4f}"
    )

    # Hybrid and pure-boost lines are printed only when such metrics exist.
    has_hybrid = any(k.startswith('f1_hybrid_') for k in metrics)
    if has_hybrid:
        print(
            f"ModaNet ({label_kind}) hybrid "
            f"w30={metrics.get('f1_hybrid_w30', 0):.4f} "
            f"w50={metrics.get('f1_hybrid_w50', 0):.4f} "
            f"w70={metrics.get('f1_hybrid_w70', 0):.4f} "
            f"rerank={metrics.get('f1_hybrid_rerank', 0):.4f}"
        )
    has_pure = any(k.startswith('f1_pure_') for k in metrics)
    if has_pure:
        print(
            f"ModaNet ({label_kind}) pure-boost "
            f"spec_only={metrics.get('f1_pure_spec_only', 0):.4f} "
            f"w40={metrics.get('f1_pure_boost_w40', 0):.4f} "
            f"w50={metrics.get('f1_pure_boost_w50', 0):.4f} "
            f"w60={metrics.get('f1_pure_boost_w60', 0):.4f}"
        )

    print(
        f"ModaNet ({label_kind}) type-aware "
        f"ta={metrics.get('f1_type_aware', 0):.4f} "
        f"ta_no_prior={metrics.get('f1_type_aware_no_prior', 0):.4f} "
        f"ta_no_gating={metrics.get('f1_type_aware_no_gating', 0):.4f} "
        f"parse_rate={metrics.get('type_parse_rate', 0):.2f} "
        f"H(P_type)={metrics.get('type_entropy', 0):.3f} "
        f"mean_C={metrics.get('mean_C', 0):.3f}"
    )
    return metrics
|
|
|
|
def _require_pass_rate(result, failure_prefix, threshold=0.95):
    """Raise AssertionError if `result["pass_rate"]` is below `threshold`; no-op for None."""
    if result is None:
        return
    pass_rate = float(result["pass_rate"])
    # Explicit raise instead of `assert`, so the gate still fires under
    # `python -O`.  `not >=` keeps NaN a failure, as the assert did.
    if not pass_rate >= threshold:
        raise AssertionError(f"{failure_prefix} {pass_rate:.2%} < {threshold:.0%}.")


def _format_metric_pct(result, key):
    """Format `result[key]` as a percentage, or "N/A" when it is unavailable."""
    if result is None:
        return "N/A"
    val = result.get(key)
    return f"{val:.2%}" if val is not None else "N/A"


def _load_specialist_model(device):
    """Best-effort load of the specialist HierarchyModel; returns None on any failure."""
    try:
        from evaluation.utils.model_loader import load_hierarchy_model
        try:
            import config as _project_config
            hier_path = getattr(_project_config, "hierarchy_model_path", "models/hierarchy_model.pth")
        except Exception:
            hier_path = "models/hierarchy_model.pth"
        if Path(hier_path).exists():
            print(f"Loading specialist HierarchyModel from {hier_path} ...")
            return load_hierarchy_model(hier_path, device)
        print(f" Specialist HierarchyModel not found at {hier_path}; pure-boost disabled")
    except Exception as exc:
        # Deliberately broad: the specialist is optional and Test D must still
        # run without it.
        print(f" Skipping pure-boost: failed to load specialist ({exc})")
    return None


def main(
    selected_tests: set[str],
    model=None,
    processor=None,
    baseline_model=None,
    baseline_processor=None,
) -> None:
    """Run the selected embedding-structure tests and print a summary.

    Args:
        selected_tests: Test ids to run; "A", "B", "C" and "D" are recognised.
        model, processor: Pre-loaded main model and processor. When either is
            None both are loaded from ``cfg.main_model_path``.
        baseline_model, baseline_processor: Pre-loaded baseline. When either
            is None and at least one test is selected, the
            ``patrickjohncyh/fashion-clip`` baseline is loaded.

    Raises:
        FileNotFoundError: The main checkpoint is needed but does not exist.
        AssertionError: Test A, B or C ran with a pass rate below 95%.
    """
    random.seed(42)
    cfg = resolve_runtime_config()

    if model is None or processor is None:
        model_path = Path(cfg.main_model_path)
        if not model_path.exists():
            raise FileNotFoundError(f"Main model checkpoint not found: {cfg.main_model_path}")
        print("Loading model...")
        print(f" device: {cfg.device}")
        print(f" checkpoint: {cfg.main_model_path}")
        print(f" dims: color={cfg.color_emb_dim}, hierarchy={cfg.hierarchy_emb_dim}, total={cfg.main_emb_dim}")
        model, processor = load_main_model(cfg.device, cfg.main_model_path)
        print("Model loaded.")
    else:
        print(f"Using pre-loaded GAP-CLIP model (dims: color={cfg.color_emb_dim}, hierarchy={cfg.hierarchy_emb_dim}, total={cfg.main_emb_dim})")

    result_a: Optional[Dict[str, object]] = None
    result_b: Optional[Dict[str, object]] = None
    result_c: Optional[Dict[str, object]] = None
    baseline_result_a: Optional[Dict[str, object]] = None
    baseline_result_b: Optional[Dict[str, object]] = None
    baseline_result_c: Optional[Dict[str, object]] = None

    if baseline_model is None or baseline_processor is None:
        if any(t in selected_tests for t in ("A", "B", "C", "D")):
            print("\nLoading baseline model (patrickjohncyh/fashion-clip)...")
            baseline_name = "patrickjohncyh/fashion-clip"
            baseline_processor = CLIPProcessor.from_pretrained(baseline_name)
            baseline_model = CLIPModelTransformers.from_pretrained(baseline_name).to(cfg.device)
            baseline_model.eval()
            print("Baseline model loaded.")

    # Fixed from here on: the baseline is not reassigned after this point.
    have_baseline = baseline_model is not None and baseline_processor is not None

    if "A" in selected_tests:
        result_a = run_test_a(
            model,
            processor,
            cfg,
            num_examples=DEFAULT_NUM_EXAMPLES,
            num_printed=DEFAULT_NUM_PRINTED,
        )
        if have_baseline:
            baseline_result_a = run_test_a(
                baseline_model,
                baseline_processor,
                cfg,
                num_examples=DEFAULT_NUM_EXAMPLES,
                num_printed=DEFAULT_NUM_PRINTED,
                test_name="Baseline Test A",
            )
    if "B" in selected_tests:
        result_b = run_test_b(
            model,
            processor,
            cfg,
            num_examples=DEFAULT_NUM_EXAMPLES,
            num_printed=DEFAULT_NUM_PRINTED,
        )
        if have_baseline:
            baseline_result_b = run_test_b(
                baseline_model,
                baseline_processor,
                cfg,
                num_examples=DEFAULT_NUM_EXAMPLES,
                num_printed=DEFAULT_NUM_PRINTED,
                test_name="Baseline Test B",
            )
    if "C" in selected_tests:
        result_c = run_test_c(
            model,
            processor,
            cfg,
            num_examples=DEFAULT_NUM_EXAMPLES,
            num_printed=DEFAULT_NUM_PRINTED,
        )
        if have_baseline:
            baseline_result_c = run_test_c(
                baseline_model,
                baseline_processor,
                cfg,
                num_examples=DEFAULT_NUM_EXAMPLES,
                num_printed=DEFAULT_NUM_PRINTED,
                test_name="Baseline Test C",
            )

    if "D" in selected_tests:
        if not have_baseline:
            raise AssertionError("Test D requires a baseline model and processor")

        print("\n" + "=" * 120)
        print("Test D — Notebook-style zero-shot accuracy")
        print("=" * 120)

        spec_model = _load_specialist_model(cfg.device)

        # Only the GAP runs receive the specialist model and image TTA; the
        # baseline runs use each function's defaults for those arguments.
        d_results: Dict[str, Dict[str, Optional[Dict[str, float]]]] = {
            "Fashion-MNIST": {
                "gap": zero_shot_fashion_mnist(model=model, processor=processor, device=cfg.device, cfg=cfg, batch_size=64,
                                               spec_model=spec_model, image_tta=True),
                "base": zero_shot_fashion_mnist(model=baseline_model, processor=baseline_processor, device=cfg.device, cfg=cfg, batch_size=64),
            },
            "KAGL Marqo": {
                "gap": zero_shot_kagl(model=model, processor=processor, device=cfg.device, cfg=cfg, batch_size=64, num_examples=DEFAULT_NUM_EXAMPLES,
                                      spec_model=spec_model, image_tta=True),
                "base": zero_shot_kagl(model=baseline_model, processor=baseline_processor, device=cfg.device, cfg=cfg, batch_size=64, num_examples=DEFAULT_NUM_EXAMPLES),
            },
            "Internal dataset": {
                "gap": zero_shot_internal(model=model, processor=processor, device=cfg.device, cfg=cfg, batch_size=64, num_examples=DEFAULT_NUM_EXAMPLES,
                                          spec_model=spec_model, image_tta=True),
                "base": zero_shot_internal(model=baseline_model, processor=baseline_processor, device=cfg.device, cfg=cfg, batch_size=64, num_examples=DEFAULT_NUM_EXAMPLES),
            },
            "ModaNet": {
                "gap": zero_shot_modanet(model=model, processor=processor, device=cfg.device, cfg=cfg, batch_size=64, num_examples=DEFAULT_NUM_EXAMPLES, use_gap_labels=True,
                                         spec_model=spec_model, image_tta=True),
                "base": zero_shot_modanet(model=baseline_model, processor=baseline_processor, device=cfg.device, cfg=cfg, batch_size=64, num_examples=DEFAULT_NUM_EXAMPLES, use_gap_labels=True),
            },
        }

        print("\n" + "-" * 120)
        print("Test D summary")
        print("-" * 120)
        summary_rows: List[List[str]] = []
        for ds in ["Fashion-MNIST", "KAGL Marqo", "ModaNet", "Internal dataset"]:
            gap_result = d_results[ds]["gap"]
            base_result = d_results[ds]["base"]
            summary_rows.append([
                ds,
                _format_metric_pct(gap_result, "accuracy"),
                _format_metric_pct(gap_result, "accuracy_color"),
                _format_metric_pct(gap_result, "accuracy_hier"),
                _format_metric_pct(base_result, "accuracy"),
                _format_metric_pct(base_result, "accuracy_color"),
                _format_metric_pct(base_result, "accuracy_hier"),
            ])
        print_table(
            "Test D — zero-shot accuracy (notebook protocol)",
            ["Dataset", "GAP full", "GAP color[0:16]", "GAP hier[16:80]", "Base full", "Base color[0:16]", "Base hier[16:80]"],
            summary_rows,
        )

    print("\n" + "=" * 120)
    print("Final Summary")
    print("=" * 120)
    print(f"Tests selected: {''.join(sorted(selected_tests))}")
    if result_a is not None:
        print(f"Test A overall: {format_bool(bool(result_a['overall']))}")
        print(f"Test A full512 accuracy: {float(result_a['accuracy_full512']):.2%}")
    if baseline_result_a is not None:
        print(f"Baseline Test A full512 accuracy: {float(baseline_result_a['accuracy_full512']):.2%}")
    if result_b is not None:
        print(f"Test B overall: {format_bool(bool(result_b['overall']))}")
        print(f"Test B full512 accuracy: {float(result_b['accuracy_full512']):.2%}")
    if baseline_result_b is not None:
        print(f"Baseline Test B full512 accuracy: {float(baseline_result_b['accuracy_full512']):.2%}")
    if result_c is not None:
        print(f"Test C overall: {format_bool(bool(result_c['overall']))}")
        print(f" pass rate: {float(result_c['pass_rate']):.2%}")
        print(f" avg color_match={float(result_c['avg_color_match']):.4f} vs cross={float(result_c['avg_color_cross']):.4f}")
        print(f" avg hier_match={float(result_c['avg_hier_match']):.4f} vs cross={float(result_c['avg_hier_cross']):.4f}")
    if baseline_result_c is not None:
        print(f"Baseline Test C overall: {format_bool(bool(baseline_result_c['overall']))}")
        print(f" baseline pass rate: {float(baseline_result_c['pass_rate']):.2%}")

    # Gates run after the summary so the numbers are printed even on failure.
    _require_pass_rate(result_a, "Test A failed: pass rate")
    _require_pass_rate(result_b, "Test B failed: pass rate")
    _require_pass_rate(result_c, "Test C failed: subspace decomposition pass rate")

    print("\nAll embedding-structure tests passed.")
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: choose which tests run and, optionally, how
    # many examples each one uses.
    parser = argparse.ArgumentParser(description="Embedding structure evaluation")
    parser.add_argument("--tests", default="ABCD", help="Which tests to run, e.g. 'C' or 'ABCD'")
    parser.add_argument("--num-examples", type=int, default=None, help="Override DEFAULT_NUM_EXAMPLES")
    args = parser.parse_args()
    if args.num_examples is not None:
        # Rebinding the module global works because main() reads
        # DEFAULT_NUM_EXAMPLES at call time.
        DEFAULT_NUM_EXAMPLES = args.num_examples

    # Each character of --tests is one test id, case-insensitive.
    known_tests = {"A", "B", "C", "D"}
    requested_tests = set(args.tests.upper())
    unknown_tests = requested_tests - known_tests
    if unknown_tests:
        # Warn instead of silently dropping them, so a typo is not mistaken
        # for a test that ran.
        print(f"Ignoring unrecognized test id(s): {''.join(sorted(unknown_tests))}")
    selected_tests = requested_tests & known_tests
    if not selected_tests:
        # Otherwise main() would run nothing and still print that all tests
        # passed.
        parser.error("--tests must name at least one of A, B, C, D")
    main(selected_tests)
|
|