arcisvlm / scripts /eval_retrieval.py
Hardik Sanghvi
feat: integrate Gemma 4 E2B backbone for production-quality VLM inference
7a564e3
Raw
History Blame Contribute Delete
3.6 kB
"""Image-text retrieval evaluation for Stage 1 Gate 2 check.
Computes image-to-text Recall@1, Recall@5, and Recall@10 on a set of image-text pairs.
"""
import argparse
import os
import sys
import torch
import yaml
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from model.vlm import VLJEPAModel
from model.tokenizer import BPETokenizer
def evaluate_retrieval(model, images, texts, tokenizer, device="cuda"):
    """Compute image-to-text retrieval recall for paired images and captions.

    Image ``i`` is treated as the correct match for ``texts[i]``. Images and
    captions are embedded one sample at a time, the embeddings are
    L2-normalised, and every image ranks all captions by cosine similarity.

    Args:
        model: Object exposing ``eval()``, ``get_embedding(image_batch)`` and
            ``y_encoder(token_ids)`` (a ``VLJEPAModel`` in this project).
        images: Tensor indexed by sample along dim 0 ([N, 3, H, W]). Only the
            first ``len(texts)`` entries are used.
        texts: List of N caption strings.
        tokenizer: Object whose ``encode(text)`` returns a list of token ids.
        device: Device the inputs are moved to / created on before encoding.

    Returns:
        dict with ``"R@1"``, ``"R@5"``, ``"R@10"`` (percentages in [0, 100])
        and ``"num_samples"``. When ``texts`` is empty all recalls are 0.0.

    Raises:
        ValueError: If there are fewer images than captions.
    """
    model.eval()
    num_samples = len(texts)

    if len(images) < num_samples:
        raise ValueError(
            f"Expected at least {num_samples} images to pair with the "
            f"captions, got {len(images)}"
        )
    if num_samples == 0:
        # Nothing to rank; avoid torch.cat([]) and a division by zero.
        return {"R@1": 0.0, "R@5": 0.0, "R@10": 0.0, "num_samples": 0}

    # Pure inference: keep both encoders and the ranking out of autograd.
    with torch.no_grad():
        # Image embeddings, one sample per forward pass -> [N, embed_dim]
        img_embeds = torch.cat(
            [
                model.get_embedding(images[i:i + 1].to(device))
                for i in range(num_samples)
            ],
            dim=0,
        )

        # Text embeddings via the Y-encoder. Captions are encoded one at a
        # time because their token sequences may differ in length.
        txt_embeds = torch.cat(
            [
                model.y_encoder(
                    torch.tensor(
                        [tokenizer.encode(text)],
                        dtype=torch.long,
                        device=device,
                    )
                )
                for text in texts
            ],
            dim=0,
        )

        # Cosine similarity matrix: rows are images, columns are captions.
        img_embeds = torch.nn.functional.normalize(img_embeds, dim=-1)
        txt_embeds = torch.nn.functional.normalize(txt_embeds, dim=-1)
        sims = img_embeds @ txt_embeds.T  # [N, N]

        # For every image, find the position of its own caption in the
        # similarity-sorted caption list (0 == best match). Each row of the
        # comparison holds exactly one True, so nonzero() yields one column
        # index per row, in row order.
        order = sims.argsort(dim=1, descending=True)
        targets = torch.arange(num_samples, device=sims.device).unsqueeze(1)
        ranks = (order == targets).nonzero(as_tuple=True)[1]

    results = {
        f"R@{k}": int((ranks < k).sum().item()) / num_samples * 100
        for k in (1, 5, 10)
    }
    results["num_samples"] = num_samples
    return results
def main():
    """Run the Gate 2 retrieval check from the command line.

    Loads the YAML model config, builds a ``VLJEPAModel``, restores weights
    from ``--ckpt`` when that file exists, then scores retrieval on randomly
    generated images paired with synthetic captions. Prints R@1/R@5/R@10 and
    whether the Gate 2 criterion (R@1 > 25%) is met.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", required=True)
    parser.add_argument("--config", default="configs/scale_1.3b.yaml")
    parser.add_argument("--num_samples", type=int, default=100)
    parser.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    args = parser.parse_args()

    # Recall is undefined without samples; reject this before any heavy work.
    if args.num_samples < 1:
        parser.error("--num_samples must be a positive integer")

    with open(args.config, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    model = VLJEPAModel(config)
    if os.path.exists(args.ckpt):
        # NOTE(security): torch.load unpickles the file and can execute
        # arbitrary code; only point --ckpt at checkpoints you trust.
        ckpt = torch.load(args.ckpt, map_location=args.device)
        # Accept both a training checkpoint wrapper and a bare state dict.
        state_dict = ckpt["model_state_dict"] if "model_state_dict" in ckpt else ckpt
        load_result = model.load_state_dict(state_dict, strict=False)
        # strict=False hides mismatches; report them so a partially loaded
        # model is not mistaken for a fully restored one.
        missing = list(getattr(load_result, "missing_keys", []) or [])
        unexpected = list(getattr(load_result, "unexpected_keys", []) or [])
        if missing or unexpected:
            print(
                f"WARNING: checkpoint loaded with {len(missing)} missing and "
                f"{len(unexpected)} unexpected keys",
                file=sys.stderr,
            )
    else:
        # Keep the original best-effort behaviour, but make it visible.
        print(
            f"WARNING: checkpoint '{args.ckpt}' not found; evaluating "
            "randomly initialised weights",
            file=sys.stderr,
        )
    model = model.to(args.device).eval()

    tokenizer = BPETokenizer(
        vocab_size=config.get("tokenizer", {}).get("vocab_size", 32768)
    )
    # The tokenizer is fitted on placeholder text only.
    tokenizer.train(["dummy retrieval text"] * 50)

    # Synthetic evaluation set: random-noise images with numbered captions.
    img_size = config["vision"]["img_size"]
    images = torch.randn(args.num_samples, 3, img_size, img_size)
    texts = [f"A test caption number {i}" for i in range(args.num_samples)]

    results = evaluate_retrieval(model, images, texts, tokenizer, args.device)

    print("\nRetrieval Results:")
    for k, v in results.items():
        # Recall values are floats (percentages); num_samples is an int.
        if isinstance(v, float):
            print(f" {k}: {v:.2f}%")
        else:
            print(f" {k}: {v}")

    # Gate 2 check
    if results["R@1"] > 25:
        print("\n✅ GATE 2 PASSED: R@1 > 25%")
    else:
        print(f"\n❌ GATE 2 FAILED: R@1 = {results['R@1']:.2f}% (need > 25%)")
# Run the evaluation only when this file is executed as a script, so that
# importing the module (e.g. to reuse evaluate_retrieval) has no side effects.
if __name__ == "__main__":
    main()