"""Image-text retrieval evaluation for Stage 1 Gate 2 check.

Computes Recall@1, R@5, R@10 on a set of image-text pairs.
"""

import argparse
import os
import sys

import torch
import yaml

# Make the parent of this file's directory importable so `model.*` resolves
# when the script is run directly.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from model.vlm import VLJEPAModel
from model.tokenizer import BPETokenizer


def evaluate_retrieval(model, images, texts, tokenizer, device="cuda"):
    """Compute image-to-text retrieval metrics.

    Image ``i`` is paired with ``texts[i]``. For each image, every caption is
    ranked by cosine similarity. A hit at K means fewer than K *other*
    captions score at least as high as the true one. Ties count against the
    true caption, so a model that emits identical embeddings for everything
    scores at chance level instead of a spurious 100%.

    Args:
        model: VLJEPAModel exposing ``get_embedding(images)`` and
            ``y_encoder(token_ids)``. Both are assumed to return one
            embedding row per input, i.e. [1, embed_dim]; TODO confirm.
        images: [N, 3, H, W] tensor
        texts: list of N strings (caption i belongs to image i)
        tokenizer: BPE tokenizer with an ``encode(str) -> list[int]`` method
        device: compute device

    Returns:
        dict with "R@1", "R@5", "R@10" (percentages) and "num_samples".

    Raises:
        ValueError: if there are no pairs, or the image and text counts differ.
    """
    num_pairs = len(texts)
    if num_pairs == 0:
        raise ValueError("evaluate_retrieval needs at least one image-text pair")
    if len(images) != num_pairs:
        raise ValueError(
            f"got {len(images)} images but {num_pairs} texts; "
            "pairs must align by index"
        )

    model.eval()
    # Both encoders run without autograd: evaluation only, so no graphs should
    # be kept alive across the per-sample loop.
    with torch.no_grad():
        img_embeds = torch.cat(
            [model.get_embedding(images[i:i + 1].to(device)) for i in range(num_pairs)],
            dim=0,
        )  # [N, embed_dim]

        # Text embeddings via the Y-encoder, one caption at a time. This avoids
        # padding, whose id/convention is not visible here.
        txt_embeds = torch.cat(
            [
                model.y_encoder(
                    torch.tensor([tokenizer.encode(text)], dtype=torch.long, device=device)
                )
                for text in texts
            ],
            dim=0,
        )  # [N, embed_dim]

    # Cosine similarity in float32. Low-precision dtypes (e.g. bf16) would
    # quantize scores and manufacture ties.
    img_embeds = torch.nn.functional.normalize(img_embeds.float(), dim=-1)
    txt_embeds = torch.nn.functional.normalize(txt_embeds.float(), dim=-1)
    sims = img_embeds @ txt_embeds.T  # [N, N]
    # A NaN score must never count as a win, so map NaN to the worst possible
    # similarity.
    sims = sims.masked_fill(torch.isnan(sims), float("-inf"))

    # rank[i] = number of *other* captions scoring >= the true caption for
    # image i. Counting ">=" breaks ties pessimistically and deterministically,
    # unlike argsort whose order among equal values is unspecified.
    true_scores = sims.diagonal().unsqueeze(1)  # [N, 1]
    competing = sims >= true_scores
    competing &= ~torch.eye(num_pairs, dtype=torch.bool, device=sims.device)
    ranks = competing.sum(dim=1)  # [N]

    def recall_at(k):
        # Percentage of images whose true caption lands in the top k.
        return (ranks < k).sum().item() / num_pairs * 100

    return {
        "R@1": recall_at(1),
        "R@5": recall_at(5),
        "R@10": recall_at(10),
        "num_samples": num_pairs,
    }


def main():
    """Load a checkpoint, run retrieval evaluation, and print the Gate 2 verdict."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", required=True)
    parser.add_argument("--config", default="configs/scale_1.3b.yaml")
    parser.add_argument("--num_samples", type=int, default=100)
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()

    with open(args.config, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    model = VLJEPAModel(config)
    if os.path.exists(args.ckpt):
        # Deserialize on CPU: the freshly built model lives on CPU, so loading
        # straight onto the GPU would only hold a second full copy there.
        # NOTE: torch.load unpickles arbitrary objects; only load trusted files.
        ckpt = torch.load(args.ckpt, map_location="cpu")
        state_dict = ckpt["model_state_dict"] if "model_state_dict" in ckpt else ckpt
        incompatible = model.load_state_dict(state_dict, strict=False)
        # strict=False silently tolerates mismatches; surface them so a wrong
        # or partial checkpoint is not mistaken for a real result.
        if incompatible.missing_keys or incompatible.unexpected_keys:
            print(
                f"Warning: checkpoint/model key mismatch "
                f"({len(incompatible.missing_keys)} missing, "
                f"{len(incompatible.unexpected_keys)} unexpected)",
                file=sys.stderr,
            )
    else:
        print(
            f"Warning: checkpoint {args.ckpt!r} not found; "
            "evaluating randomly initialized weights",
            file=sys.stderr,
        )
    model = model.to(args.device).eval()

    tokenizer = BPETokenizer(vocab_size=config.get("tokenizer", {}).get("vocab_size", 32768))
    # NOTE(review): the tokenizer is trained on a placeholder corpus that does
    # not contain the caption text below (e.g. digits); confirm BPETokenizer
    # handles unseen characters, or load the tokenizer used in training.
    tokenizer.train(["dummy retrieval text"] * 50)

    # NOTE(review): images are random noise and captions are synthetic, so this
    # run exercises the pipeline but does not measure real retrieval quality.
    img_size = config["vision"]["img_size"]
    images = torch.randn(args.num_samples, 3, img_size, img_size)
    texts = [f"A test caption number {i}" for i in range(args.num_samples)]

    results = evaluate_retrieval(model, images, texts, tokenizer, args.device)

    print("\nRetrieval Results:")
    for k, v in results.items():
        if isinstance(v, float):
            print(f"  {k}: {v:.2f}%")
        else:
            print(f"  {k}: {v}")

    # Gate 2 check
    if results["R@1"] > 25:
        print("\n✅ GATE 2 PASSED: R@1 > 25%")
    else:
        print(f"\n❌ GATE 2 FAILED: R@1 = {results['R@1']:.2f}% (need > 25%)")


if __name__ == "__main__":
    main()