#!/usr/bin/env python3
"""
fVLM benchmark evaluation (batched MCQ scoring).

Usage:
    # 135M model
    python benchmark.py --llm /workspace/models/SmolLM2-135M-Instruct \
        --checkpoint /workspace/checkpoints/final/stage3/best.pt

    # 1.7B model
    python benchmark.py --llm /workspace/models/SmolLM2-1.7B-Instruct \
        --checkpoint /workspace/checkpoints/final_1.7B/stage3/latest.pt

    # Run specific benchmarks only
    python benchmark.py --checkpoint ... --only MVBench ScienceQA

Key optimizations:
1. Batch all MCQ options into ONE forward pass per mode (not N sequential)
2. Compute per-option CE from logits (avoid model's averaged loss)
3. Cache DINO encoding across modes (same frames, reuse kv_cache)
"""
import sys, os, json, tarfile, io, time, re, glob, gc
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from transformers import AutoTokenizer
from collections import defaultdict

from model import FoveatedVLM


# ─── Model / tokenizer ──────────────────────────────────────────────

def load_model(checkpoint_path, llm_name, dino_name="/workspace/models/dinov2-small", device="cuda"):
    """Build a FoveatedVLM, load checkpoint weights, and return it in bf16 eval mode.

    Args:
        checkpoint_path: Path to a .pt checkpoint containing "model_state_dict"
            (and optionally "step", used only for logging).
        llm_name: HuggingFace path of the backbone LLM.
        dino_name: HuggingFace path of the DINOv2 vision encoder.
        device: Target device for the model.

    Returns:
        The loaded model, moved to `device`, cast to bfloat16, in eval() mode.
    """
    model = FoveatedVLM(
        llm_name=llm_name,
        dino_name=dino_name,
        query_dim=384,
        visual_scale=0.14,
        lambda_coarse=0.0,
        deep_query=True,
    )
    # NOTE: weights_only=False unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    state_dict = ckpt["model_state_dict"]
    # Strip torch.compile's _orig_mod prefix if present
    if any("._orig_mod." in k for k in state_dict):
        state_dict = {k.replace("._orig_mod", ""): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    model = model.to(device).to(torch.bfloat16).eval()
    print(f"Loaded: {checkpoint_path} (step {ckpt.get('step', '?')})")
    return model


def load_tokenizer(llm_name):
    """Load the LLM tokenizer, aliasing pad_token to eos_token if unset."""
    tok = AutoTokenizer.from_pretrained(llm_name)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    return tok


# 224x224 resize + ImageNet mean/std normalization (DINOv2 preprocessing).
FRAME_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# ─── Data loading ────────────────────────────────────────────────────

def load_all_samples_from_shards(shard_pattern):
    """Load every sample from a set of webdataset-style tar shards.

    Each sample is a group of tar members sharing a key prefix:
    "{key}.json" (metadata) plus "{key}.{frame_idx}.jpg/png" frames.
    Samples missing either the json or all frames are skipped.

    Returns:
        List of dicts: {"key": str, "json": dict, "frames": [PIL.Image, ...]}
        with frames ordered by frame index.
    """
    shard_files = sorted(glob.glob(shard_pattern))
    print(f" Loading from {len(shard_files)} shards...")
    samples = []
    for shard_path in shard_files:
        with tarfile.open(shard_path, "r") as tar:
            members = tar.getmembers()
            grouped = {}
            for m in members:
                parts = m.name.split(".")
                if m.name.endswith(".json"):
                    key = parts[0]
                    grouped.setdefault(key, {"frames": {}})
                    grouped[key]["json"] = json.load(tar.extractfile(m))
                elif m.name.endswith(".jpg") or m.name.endswith(".png"):
                    key = parts[0]
                    # "{key}.{idx}.jpg" -> idx; single-image "{key}.jpg" -> 0
                    frame_idx = int(parts[1]) if len(parts) >= 3 else 0
                    grouped.setdefault(key, {"frames": {}})
                    img_data = tar.extractfile(m).read()
                    img = Image.open(io.BytesIO(img_data)).convert("RGB")
                    grouped[key]["frames"][frame_idx] = img
            for key in sorted(grouped.keys()):
                entry = grouped[key]
                if entry.get("json") and entry.get("frames"):
                    sorted_frames = [entry["frames"][i] for i in sorted(entry["frames"].keys())]
                    samples.append({
                        "key": key,
                        "json": entry["json"],
                        "frames": sorted_frames,
                    })
    print(f" Loaded {len(samples)} samples")
    return samples


def prepare_frames_tensor(pil_frames, device="cuda", replicate_to=8):
    """Transform PIL frames to tensor. Replicate single-frame images to match training.

    Returns:
        bf16 tensor of shape [1, T, 3, 224, 224] on `device` (T == replicate_to
        when a single input frame is replicated).
    """
    tensors = [FRAME_TRANSFORM(f) for f in pil_frames]
    frames = torch.stack(tensors)  # [T, 3, H, W]
    # Replicate single images to N frames (matches training with replicate_image_frames=8)
    if frames.shape[0] == 1 and replicate_to > 1:
        frames = frames.repeat(replicate_to, 1, 1, 1)  # [N, 3, H, W]
    return frames.unsqueeze(0).to(device, dtype=torch.bfloat16)


# ─── MCQ helpers ─────────────────────────────────────────────────────

def parse_mcq_options(user_text):
    """Extract MCQ options like "A. foo" from the question text.

    Returns:
        Dict mapping letter -> "LETTER. option text" (letter re-prepended so
        the scored answer string matches the ground-truth format).
    """
    options = {}
    for match in re.finditer(r'([A-Z])\.\s*(.+?)(?=\n[A-Z]\.|$)', user_text, re.DOTALL):
        options[match.group(1)] = match.group(1) + ". " + match.group(2).strip()
    return options


def extract_answer_letter(assistant_text):
    """Return the leading option letter of an answer ("B. ..." -> "B").

    Falls back to the first character, or "?" for an empty answer.
    """
    m = re.match(r'([A-Z])\.', assistant_text.strip())
    if m:
        return m.group(1)
    return assistant_text.strip()[0] if assistant_text.strip() else "?"


# ─── Batched option scoring (KEY OPTIMIZATION) ──────────────────────

@torch.no_grad()
def score_options_batched(model, tokenizer, frames, question_text, options_dict, mode, device):
    """
    Score all MCQ options in ONE batched forward pass.
    Returns dict {letter: loss} where lower loss = better match.

    Args:
        frames: [1, T, 3, H, W] tensor from prepare_frames_tensor (expanded
            across the option batch without copying).
        options_dict: {letter: full answer string} from parse_mcq_options.
        mode: Forward mode passed through to model.forward
            (e.g. "coarse_only" / "coarse_fine" / "autoregressive").
    """
    letters = sorted(options_dict.keys())
    if not letters:
        return {}

    # Tokenize prompt (shared across all options)
    prompt_messages = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": question_text},
    ]
    prompt_text = tokenizer.apply_chat_template(prompt_messages, tokenize=False, add_generation_prompt=True)
    prompt_ids = tokenizer.encode(prompt_text)
    # NOTE(review): assumes the generation-prompted prefix is an exact token
    # prefix of each full (prompt + assistant) sequence — holds for standard
    # chat templates; verify if the template is changed.
    S_prompt = len(prompt_ids)

    # Tokenize each full sequence (prompt + option)
    all_ids = []
    for letter in letters:
        messages = [
            {"role": "system", "content": "You are a helpful AI assistant."},
            {"role": "user", "content": question_text},
            {"role": "assistant", "content": options_dict[letter]},
        ]
        full_text = tokenizer.apply_chat_template(messages, tokenize=False)
        ids = tokenizer.encode(full_text)
        all_ids.append(ids)

    # Pad to same length
    max_len = max(len(ids) for ids in all_ids)
    pad_id = tokenizer.pad_token_id
    N = len(letters)
    batch_ids = torch.full((N, max_len), pad_id, dtype=torch.long, device=device)
    batch_attn = torch.zeros(N, max_len, dtype=torch.long, device=device)
    batch_loss_mask = torch.zeros(N, max_len, dtype=torch.float32, device=device)
    for i, ids in enumerate(all_ids):
        L = len(ids)
        batch_ids[i, :L] = torch.tensor(ids, dtype=torch.long)
        batch_attn[i, :L] = 1
        batch_loss_mask[i, S_prompt:L] = 1.0  # answer-only tokens

    # Expand frames: same image for all options (view, no copy)
    frames_batch = frames.expand(N, -1, -1, -1, -1)  # [N, T, 3, H, W]

    # Single batched forward
    with torch.amp.autocast("cuda", dtype=torch.bfloat16):
        result = model.forward(
            frames=frames_batch,
            input_ids=batch_ids,
            attention_mask=batch_attn,
            loss_mask=batch_loss_mask,
            mode=mode,
        )

    # Compute per-option loss from logits
    logits = result["logits"]  # [N, T_vis+S, V] (T_vis=T for coarse, 1 for autoregressive)
    T_visual = logits.shape[1] - batch_ids.shape[1]  # adaptive to mode

    # Extract text portion of logits
    text_logits = logits[:, T_visual:, :]  # [N, S, V]
    shift_logits = text_logits[:, :-1, :].contiguous()  # [N, S-1, V]
    shift_labels = batch_ids[:, 1:].contiguous()  # [N, S-1]
    shift_mask = batch_loss_mask[:, 1:].contiguous()  # [N, S-1]

    # Per-token CE loss
    # NOTE(review): pad_token may alias eos_token (see load_tokenizer), so
    # ignore_index=pad_id also zeroes the loss on a genuine trailing EOS in
    # the answer span — confirm this is intended.
    V = shift_logits.shape[-1]
    per_token_loss = F.cross_entropy(
        shift_logits.reshape(-1, V),
        shift_labels.reshape(-1),
        reduction="none",
        ignore_index=pad_id,
    ).reshape(N, -1)  # [N, S-1]

    # Average loss over answer tokens only (per option)
    masked_loss = (per_token_loss * shift_mask).sum(dim=1) / shift_mask.sum(dim=1).clamp(min=1)

    return {letters[i]: masked_loss[i].item() for i in range(N)}


# ─── MCQ benchmark evaluation ───────────────────────────────────────

@torch.no_grad()
def evaluate_mcq_benchmark(model, tokenizer, samples, benchmark_name, modes, device):
    """Evaluate MCQ accuracy over `samples` for each forward mode.

    The predicted option is the one with the lowest per-option CE loss.
    Accuracy is tracked overall and per category (last path component of
    the sample's "source" field). Samples with no parseable options are
    skipped. Progress is printed every 100 samples.

    Returns:
        {mode: {"correct": int, "total": int,
                "per_category": {cat: {"correct": int, "total": int}}}}
    """
    results = {mode: {"correct": 0, "total": 0,
                      "per_category": defaultdict(lambda: {"correct": 0, "total": 0})}
               for mode in modes}
    t0 = time.time()
    for i, sample in enumerate(samples):
        meta = sample["json"]
        user_text = meta["user"]
        gt_answer = meta["assistant"]
        source = meta.get("source", "unknown")
        category = source.split("/")[-1] if "/" in source else source
        gt_letter = extract_answer_letter(gt_answer)
        options = parse_mcq_options(user_text)
        if not options:
            continue
        frames = prepare_frames_tensor(sample["frames"], device=device)
        for mode in modes:
            option_losses = score_options_batched(
                model, tokenizer, frames, user_text, options, mode, device
            )
            if not option_losses:
                continue
            pred_letter = min(option_losses, key=option_losses.get)
            correct = (pred_letter == gt_letter)
            results[mode]["total"] += 1
            if correct:
                results[mode]["correct"] += 1
            results[mode]["per_category"][category]["total"] += 1
            if correct:
                results[mode]["per_category"][category]["correct"] += 1
        if (i + 1) % 100 == 0:
            elapsed = time.time() - t0
            for mode in modes:
                r = results[mode]
                acc = r["correct"] / max(r["total"], 1) * 100
                print(f" [{benchmark_name}] {i+1}/{len(samples)} | {mode}: {acc:.1f}% ({r['correct']}/{r['total']}) | {elapsed:.0f}s", flush=True)
    return results
# ─── Val loss evaluation ────────────────────────────────────────────

@torch.no_grad()
def evaluate_val_loss(model, tokenizer, shard_pattern, modes, device, max_samples=1000):
    """Compute average model loss over a held-out shard set, per mode.

    Uses pre-tokenized "token_ids"/"loss_mask" from the sample metadata when
    present; otherwise builds a chat-template sequence from "user" and
    "caption"/"assistant" with a full (all-ones) loss mask.

    Returns:
        {mode: {"total_loss": float, "count": int}} — caller divides for the
        average. Progress is printed every 200 samples.
    """
    samples = load_all_samples_from_shards(shard_pattern)
    if max_samples:
        samples = samples[:max_samples]
    results = {mode: {"total_loss": 0.0, "count": 0} for mode in modes}
    t0 = time.time()
    for i, sample in enumerate(samples):
        meta = sample["json"]
        frames = prepare_frames_tensor(sample["frames"], device=device)
        if "token_ids" in meta:
            # Pre-tokenized path: trust the shard's ids and (optional) mask.
            input_ids = torch.tensor(meta["token_ids"], dtype=torch.long).unsqueeze(0).to(device)
            loss_mask_vals = meta.get("loss_mask", [1] * len(meta["token_ids"]))
            loss_mask = torch.tensor(loss_mask_vals, dtype=torch.float32).unsqueeze(0).to(device)
        else:
            # Raw-text path: tokenize a full chat turn on the fly.
            caption = meta.get("caption", meta.get("assistant", ""))
            user = meta.get("user", "Describe this video.")
            messages = [
                {"role": "system", "content": "You are a helpful AI assistant."},
                {"role": "user", "content": user},
                {"role": "assistant", "content": caption},
            ]
            text = tokenizer.apply_chat_template(messages, tokenize=False)
            input_ids = tokenizer.encode(text, return_tensors="pt").to(device)
            loss_mask = torch.ones_like(input_ids, dtype=torch.float32)
        attention_mask = torch.ones_like(input_ids)
        for mode in modes:
            with torch.amp.autocast("cuda", dtype=torch.bfloat16):
                result = model.forward(
                    frames=frames,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    loss_mask=loss_mask,
                    mode=mode,
                )
            results[mode]["total_loss"] += result["loss"].item()
            results[mode]["count"] += 1
        if (i + 1) % 200 == 0:
            elapsed = time.time() - t0
            for mode in modes:
                r = results[mode]
                avg = r["total_loss"] / max(r["count"], 1)
                print(f" [val_10k] {i+1}/{len(samples)} | {mode}: loss={avg:.4f} | {elapsed:.0f}s", flush=True)
    return results


# ─── Main ────────────────────────────────────────────────────────────

def run_mcq_benchmark(model, tokenizer, name, shard_pattern, modes, device, all_results):
    """Load, evaluate, and free one MCQ benchmark.

    Results are written into `all_results` under a normalized key
    (lowercase, "-"/" " replaced with "_"). Skips silently (with a message)
    when no shards match `shard_pattern`.
    """
    shards = glob.glob(shard_pattern)
    if not shards:
        print(f" Skipping {name} — shards not found")
        return
    samples = load_all_samples_from_shards(shard_pattern)
    results = evaluate_mcq_benchmark(model, tokenizer, samples, name, modes, device)
    del samples
    gc.collect()  # free PIL images immediately
    # Normalize the benchmark name once (was previously computed twice).
    key = name.lower().replace("-", "_").replace(" ", "_")
    all_results[key] = {}
    for mode in modes:
        r = results[mode]
        acc = r["correct"] / max(r["total"], 1) * 100
        all_results[key][mode] = {
            "accuracy": acc,
            "correct": r["correct"],
            "total": r["total"],
            "per_category": {cat: {"accuracy": v["correct"] / max(v["total"], 1) * 100,
                                   "correct": v["correct"],
                                   "total": v["total"]}
                             for cat, v in r["per_category"].items()},
        }
        print(f" {mode}: {acc:.1f}% ({r['correct']}/{r['total']})")


def main():
    """CLI entry point: load model, run selected MCQ benchmarks, save + summarize."""
    import argparse
    parser = argparse.ArgumentParser(description="fVLM benchmark evaluation")
    parser.add_argument("--llm", default="/workspace/models/SmolLM2-135M-Instruct",
                        help="HuggingFace LLM path (e.g. SmolLM2-135M or 1.7B)")
    parser.add_argument("--checkpoint", default="/workspace/checkpoints/final/stage3/best.pt",
                        help="Path to model checkpoint (.pt)")
    parser.add_argument("--dino", default="/workspace/models/dinov2-small",
                        help="HuggingFace DINOv2 path")
    parser.add_argument("--only", nargs="+", default=None,
                        help="Run only specified benchmarks (e.g. --only MVBench ScienceQA)")
    parser.add_argument("--output", default=None,
                        help="Output JSON path (default: /workspace/benchmark_results_{model}.json)")
    parser.add_argument("--merge", action="store_true",
                        help="Merge with existing results file instead of overwriting")
    args = parser.parse_args()

    # Auto-detect model name for output file (e.g. "SmolLM2-1.7B-Instruct" -> "1.7B")
    model_name = os.path.basename(args.llm).replace("-Instruct", "").replace("SmolLM2-", "")
    if args.output is None:
        args.output = f"/workspace/benchmark_results_{model_name}.json"

    device = "cuda"
    modes = ["coarse_only", "coarse_fine", "autoregressive"]

    print("=" * 70)
    print(f"fVLM-{model_name} BENCHMARK EVALUATION")
    print(f" LLM: {args.llm}")
    print(f" Checkpoint: {args.checkpoint}")
    print(f" Output: {args.output}")
    print("=" * 70)

    print("\nLoading model (bf16)...")
    model = load_model(args.checkpoint, args.llm, args.dino, device)
    tokenizer = load_tokenizer(args.llm)

    # Load existing results if merging
    all_results = {}
    if args.merge and os.path.exists(args.output):
        with open(args.output) as f:
            all_results = json.load(f)
        print(f" Loaded existing results from {args.output}")

    t_global = time.time()

    # ─── MCQ benchmarks (load one at a time, free between) ───────
    benchmarks = [
        ("MVBench", "/workspace/data/eval/benchmarks/mvbench_shards/mvbench_*.tar"),
        ("Video-MME", "/workspace/data/eval/benchmarks/video_mme_shards/video_mme_*.tar"),
        ("ScienceQA", "/workspace/data/eval/benchmarks/scienceqa_shards/scienceqa_*.tar"),
        ("POPE", "/workspace/data/eval/benchmarks/pope_shards/pope_*.tar"),
        ("MLVU", "/workspace/data/eval/benchmarks/mlvu_shards/mlvu_*.tar"),
    ]
    if args.only:
        only_set = {n.lower() for n in args.only}
        benchmarks = [(n, p) for n, p in benchmarks if n.lower() in only_set]

    for name, pattern in benchmarks:
        print(f"\n{'-' * 70}")
        print(f"BENCHMARK: {name}")
        print(f"{'-' * 70}")
        run_mcq_benchmark(model, tokenizer, name, pattern, modes, device, all_results)

    # ─── Save results ────────────────────────────────────────────
    with open(args.output, "w") as f:
        json.dump(all_results, f, indent=2)
    print(f"\nResults saved: {args.output}")

    # ─── Summary ─────────────────────────────────────────────────
    total_time = time.time() - t_global
    print("\n" + "=" * 70)
    print(f"SUMMARY (total: {total_time:.0f}s = {total_time/60:.1f}min)")
    print("=" * 70)
    print(f"\n{'Benchmark':<15} {'Coarse-Only':>15} {'Coarse->Fine':>15} {'Autoregressive':>15}")
    print("-" * 62)
    for bench_name, bench_data in all_results.items():
        vals = []
        for mode in modes:
            if mode not in bench_data:
                vals.append("—")
            elif "accuracy" in bench_data[mode]:
                vals.append(f"{bench_data[mode]['accuracy']:.1f}%")
            else:
                # val-loss style entries (from evaluate_val_loss) store avg_loss
                vals.append(f"{bench_data[mode]['avg_loss']:.4f}")
        # Row built per-mode instead of hard-indexing vals[0..2]; identical
        # output for the fixed 3-mode list, robust if modes ever changes.
        print(f"{bench_name:<15} " + " ".join(f"{v:>15}" for v in vals))
    print("\n" + "=" * 70)


if __name__ == "__main__":
    main()