"""
fVLM benchmark evaluation (batched MCQ scoring).

Usage:
    # 135M model
    python benchmark.py --llm /workspace/models/SmolLM2-135M-Instruct \
        --checkpoint /workspace/checkpoints/final/stage3/best.pt

    # 1.7B model
    python benchmark.py --llm /workspace/models/SmolLM2-1.7B-Instruct \
        --checkpoint /workspace/checkpoints/final_1.7B/stage3/latest.pt

    # Run specific benchmarks only
    python benchmark.py --checkpoint ... --only MVBench ScienceQA

Key optimizations:
    1. Batch all MCQ options into ONE forward pass per mode (not N sequential passes)
    2. Compute per-option CE from the logits (avoids the model's averaged loss)
    3. Cache the DINO encoding across modes (same frames, reuse kv_cache)
"""

import sys, os, json, tarfile, io, time, re, glob, gc

import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from transformers import AutoTokenizer
from collections import defaultdict

from model import FoveatedVLM


def load_model(checkpoint_path, llm_name, dino_name="/workspace/models/dinov2-small",
               device="cuda"):
    model = FoveatedVLM(
        llm_name=llm_name, dino_name=dino_name,
        query_dim=384, visual_scale=0.14, lambda_coarse=0.0, deep_query=True,
    )
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    state_dict = ckpt["model_state_dict"]

    # Strip torch.compile wrapper prefixes (keys look like "llm._orig_mod.layers...").
    if any("._orig_mod." in k for k in state_dict):
        state_dict = {k.replace("._orig_mod", ""): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    model = model.to(device).to(torch.bfloat16).eval()
    print(f"Loaded: {checkpoint_path} (step {ckpt.get('step', '?')})")
    return model


def load_tokenizer(llm_name):
    tok = AutoTokenizer.from_pretrained(llm_name)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token  # no dedicated pad token: reuse EOS for padding
    return tok


# Standard ImageNet normalization statistics, matching DINOv2 preprocessing.
FRAME_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
|

def load_all_samples_from_shards(shard_pattern):
    """Read webdataset-style tar shards into {key, json, frames} sample dicts."""
    shard_files = sorted(glob.glob(shard_pattern))
    print(f" Loading from {len(shard_files)} shards...")
    samples = []
    for shard_path in shard_files:
        with tarfile.open(shard_path, "r") as tar:
            members = tar.getmembers()
            grouped = {}
            for m in members:
                parts = m.name.split(".")
                if m.name.endswith(".json"):
                    key = parts[0]
                    if key not in grouped:
                        grouped[key] = {"frames": {}}
                    grouped[key]["json"] = json.load(tar.extractfile(m))
                elif m.name.endswith(".jpg") or m.name.endswith(".png"):
                    key = parts[0]
                    # Member names look like "<key>.<frame_idx>.jpg"; a plain
                    # "<key>.jpg" is treated as frame 0.
                    frame_idx = int(parts[1]) if len(parts) >= 3 else 0
                    if key not in grouped:
                        grouped[key] = {"frames": {}}
                    img_data = tar.extractfile(m).read()
                    img = Image.open(io.BytesIO(img_data)).convert("RGB")
                    grouped[key]["frames"][frame_idx] = img
            for key in sorted(grouped.keys()):
                entry = grouped[key]
                if entry.get("json") and entry.get("frames"):
                    sorted_frames = [entry["frames"][i] for i in sorted(entry["frames"].keys())]
                    samples.append({
                        "key": key,
                        "json": entry["json"],
                        "frames": sorted_frames,
                    })
    print(f" Loaded {len(samples)} samples")
    return samples


def prepare_frames_tensor(pil_frames, device="cuda", replicate_to=8):
    """Transform PIL frames to a tensor. Replicate single-frame images to match training."""
    tensors = [FRAME_TRANSFORM(f) for f in pil_frames]
    frames = torch.stack(tensors)

    if frames.shape[0] == 1 and replicate_to > 1:
        frames = frames.repeat(replicate_to, 1, 1, 1)
    return frames.unsqueeze(0).to(device, dtype=torch.bfloat16)
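
# Shape sketch: with the default replicate_to=8, a single image gives
#   prepare_frames_tensor([img]).shape == torch.Size([1, 8, 3, 224, 224])
# and an 8-frame clip gives the same shape without the repeat().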


def parse_mcq_options(user_text):
    """Extract options like "A. cat" from the question text into {letter: full_option}."""
    options = {}
    for match in re.finditer(r'([A-Z])\.\s*(.+?)(?=\n[A-Z]\.|$)', user_text, re.DOTALL):
        options[match.group(1)] = match.group(1) + ". " + match.group(2).strip()
    return options
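
# Illustrative parse (hypothetical question text):
#   >>> parse_mcq_options("Which animal?\nA. cat\nB. dog")
#   {'A': 'A. cat', 'B': 'B. dog'}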


def extract_answer_letter(assistant_text):
    """Pull the leading option letter from a ground-truth answer like "B. dog"."""
    m = re.match(r'([A-Z])\.', assistant_text.strip())
    if m:
        return m.group(1)
    return assistant_text.strip()[0] if assistant_text.strip() else "?"
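
# e.g. extract_answer_letter("B. dog") == "B"; a bare "B" also yields "B"
# via the first-character fallback, and an empty string yields "?".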


@torch.no_grad()
def score_options_batched(model, tokenizer, frames, question_text, options_dict, mode, device):
    """
    Score all MCQ options in ONE batched forward pass.
    Returns a dict {letter: loss} where lower loss = better match.
    """
    letters = sorted(options_dict.keys())
    if not letters:
        return {}

    # Tokenize the shared prompt once to find where the answer tokens start.
    prompt_messages = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": question_text},
    ]
    prompt_text = tokenizer.apply_chat_template(prompt_messages, tokenize=False, add_generation_prompt=True)
    prompt_ids = tokenizer.encode(prompt_text)
    S_prompt = len(prompt_ids)

    # Tokenize prompt + candidate answer for every option.
    all_ids = []
    for letter in letters:
        messages = [
            {"role": "system", "content": "You are a helpful AI assistant."},
            {"role": "user", "content": question_text},
            {"role": "assistant", "content": options_dict[letter]},
        ]
        full_text = tokenizer.apply_chat_template(messages, tokenize=False)
        ids = tokenizer.encode(full_text)
        all_ids.append(ids)

    # Right-pad everything to a common length.
    max_len = max(len(ids) for ids in all_ids)
    pad_id = tokenizer.pad_token_id
    N = len(letters)

    batch_ids = torch.full((N, max_len), pad_id, dtype=torch.long, device=device)
    batch_attn = torch.zeros(N, max_len, dtype=torch.long, device=device)
    batch_loss_mask = torch.zeros(N, max_len, dtype=torch.float32, device=device)

    for i, ids in enumerate(all_ids):
        L = len(ids)
        batch_ids[i, :L] = torch.tensor(ids, dtype=torch.long)
        batch_attn[i, :L] = 1
        batch_loss_mask[i, S_prompt:L] = 1.0  # score only the answer tokens

    # Same frames for every option; expand() is a view, so no copy is made.
    frames_batch = frames.expand(N, -1, -1, -1, -1)

    with torch.amp.autocast("cuda", dtype=torch.bfloat16):
        result = model.forward(
            frames=frames_batch,
            input_ids=batch_ids,
            attention_mask=batch_attn,
            loss_mask=batch_loss_mask,
            mode=mode,
        )

    # The model prepends visual tokens; keep only the text positions.
    logits = result["logits"]
    T_visual = logits.shape[1] - batch_ids.shape[1]

    text_logits = logits[:, T_visual:, :]
    shift_logits = text_logits[:, :-1, :].contiguous()
    shift_labels = batch_ids[:, 1:].contiguous()
    shift_mask = batch_loss_mask[:, 1:].contiguous()

    # Per-token CE (unreduced), so each option keeps its own loss.
    V = shift_logits.shape[-1]
    per_token_loss = F.cross_entropy(
        shift_logits.reshape(-1, V),
        shift_labels.reshape(-1),
        reduction="none",
        ignore_index=pad_id,
    ).reshape(N, -1)

    # Masked mean over each option's answer tokens.
    masked_loss = (per_token_loss * shift_mask).sum(dim=1) / shift_mask.sum(dim=1).clamp(min=1)

    return {letters[i]: masked_loss[i].item() for i in range(N)}
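
# Scoring rule: loss_i = mean CE over option i's answer tokens; the prediction
# is the argmin over options. Usage sketch (mirrors evaluate_mcq_benchmark below):
#   losses = score_options_batched(model, tokenizer, frames, q_text, options, "coarse_fine", "cuda")
#   pred_letter = min(losses, key=losses.get)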


@torch.no_grad()
def evaluate_mcq_benchmark(model, tokenizer, samples, benchmark_name, modes, device):
    results = {mode: {"correct": 0, "total": 0, "per_category": defaultdict(lambda: {"correct": 0, "total": 0})}
               for mode in modes}
    t0 = time.time()

    for i, sample in enumerate(samples):
        meta = sample["json"]
        user_text = meta["user"]
        gt_answer = meta["assistant"]
        source = meta.get("source", "unknown")
        category = source.split("/")[-1] if "/" in source else source
        gt_letter = extract_answer_letter(gt_answer)
        options = parse_mcq_options(user_text)
        if not options:
            continue

        frames = prepare_frames_tensor(sample["frames"], device=device)

        for mode in modes:
            option_losses = score_options_batched(
                model, tokenizer, frames, user_text, options, mode, device
            )
            if not option_losses:
                continue
            pred_letter = min(option_losses, key=option_losses.get)
            correct = (pred_letter == gt_letter)
            results[mode]["total"] += 1
            results[mode]["per_category"][category]["total"] += 1
            if correct:
                results[mode]["correct"] += 1
                results[mode]["per_category"][category]["correct"] += 1

        if (i + 1) % 100 == 0:
            elapsed = time.time() - t0
            for mode in modes:
                r = results[mode]
                acc = r["correct"] / max(r["total"], 1) * 100
                print(f" [{benchmark_name}] {i+1}/{len(samples)} | {mode}: {acc:.1f}% ({r['correct']}/{r['total']}) | {elapsed:.0f}s", flush=True)

    return results
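
# Per-mode result shape (counts and category key are illustrative):
#   results["coarse_only"] == {"correct": 312, "total": 500,
#                              "per_category": {"<source>": {"correct": 312, "total": 500}}}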


@torch.no_grad()
def evaluate_val_loss(model, tokenizer, shard_pattern, modes, device, max_samples=1000):
    """Average LM loss over a validation shard (utility; not called from main())."""
    samples = load_all_samples_from_shards(shard_pattern)
    if max_samples:
        samples = samples[:max_samples]

    results = {mode: {"total_loss": 0.0, "count": 0} for mode in modes}
    t0 = time.time()

    for i, sample in enumerate(samples):
        meta = sample["json"]
        frames = prepare_frames_tensor(sample["frames"], device=device)

        # Prefer pre-tokenized samples; otherwise rebuild the chat transcript.
        if "token_ids" in meta:
            input_ids = torch.tensor(meta["token_ids"], dtype=torch.long).unsqueeze(0).to(device)
            loss_mask_vals = meta.get("loss_mask", [1] * len(meta["token_ids"]))
            loss_mask = torch.tensor(loss_mask_vals, dtype=torch.float32).unsqueeze(0).to(device)
        else:
            caption = meta.get("caption", meta.get("assistant", ""))
            user = meta.get("user", "Describe this video.")
            messages = [
                {"role": "system", "content": "You are a helpful AI assistant."},
                {"role": "user", "content": user},
                {"role": "assistant", "content": caption},
            ]
            text = tokenizer.apply_chat_template(messages, tokenize=False)
            input_ids = tokenizer.encode(text, return_tensors="pt").to(device)
            loss_mask = torch.ones_like(input_ids, dtype=torch.float32)

        attention_mask = torch.ones_like(input_ids)

        for mode in modes:
            with torch.amp.autocast("cuda", dtype=torch.bfloat16):
                result = model.forward(
                    frames=frames, input_ids=input_ids,
                    attention_mask=attention_mask, loss_mask=loss_mask,
                    mode=mode,
                )
            results[mode]["total_loss"] += result["loss"].item()
            results[mode]["count"] += 1

        if (i + 1) % 200 == 0:
            elapsed = time.time() - t0
            for mode in modes:
                r = results[mode]
                avg = r["total_loss"] / max(r["count"], 1)
                print(f" [val_10k] {i+1}/{len(samples)} | {mode}: loss={avg:.4f} | {elapsed:.0f}s", flush=True)

    return results
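
# Average loss per mode is total_loss / count, e.g. (shard path is illustrative):
#   r = evaluate_val_loss(model, tokenizer, "/workspace/data/val/val_*.tar", ["coarse_fine"], "cuda")
#   avg = r["coarse_fine"]["total_loss"] / max(r["coarse_fine"]["count"], 1)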


def run_mcq_benchmark(model, tokenizer, name, shard_pattern, modes, device, all_results):
    """Load, evaluate, and free one MCQ benchmark."""
    shards = glob.glob(shard_pattern)
    if not shards:
        print(f" Skipping {name}: shards not found")
        return
    samples = load_all_samples_from_shards(shard_pattern)
    results = evaluate_mcq_benchmark(model, tokenizer, samples, name, modes, device)
    del samples; gc.collect()

    key = name.lower().replace("-", "_").replace(" ", "_")
    all_results[key] = {}
    for mode in modes:
        r = results[mode]
        acc = r["correct"] / max(r["total"], 1) * 100
        all_results[key][mode] = {
            "accuracy": acc, "correct": r["correct"], "total": r["total"],
            "per_category": {cat: {"accuracy": v["correct"] / max(v["total"], 1) * 100,
                                   "correct": v["correct"], "total": v["total"]}
                             for cat, v in r["per_category"].items()},
        }
        print(f" {mode}: {acc:.1f}% ({r['correct']}/{r['total']})")


def main():
    import argparse
    parser = argparse.ArgumentParser(description="fVLM benchmark evaluation")
    parser.add_argument("--llm", default="/workspace/models/SmolLM2-135M-Instruct",
                        help="HuggingFace LLM path (e.g. SmolLM2-135M or 1.7B)")
    parser.add_argument("--checkpoint", default="/workspace/checkpoints/final/stage3/best.pt",
                        help="Path to model checkpoint (.pt)")
    parser.add_argument("--dino", default="/workspace/models/dinov2-small",
                        help="HuggingFace DINOv2 path")
    parser.add_argument("--only", nargs="+", default=None,
                        help="Run only specified benchmarks (e.g. --only MVBench ScienceQA)")
    parser.add_argument("--output", default=None,
                        help="Output JSON path (default: /workspace/benchmark_results_{model}.json)")
    parser.add_argument("--merge", action="store_true",
                        help="Merge with existing results file instead of overwriting")
    args = parser.parse_args()

    # Derive a short model tag, e.g. "135M" or "1.7B".
    model_name = os.path.basename(args.llm).replace("-Instruct", "").replace("SmolLM2-", "")
    if args.output is None:
        args.output = f"/workspace/benchmark_results_{model_name}.json"

    device = "cuda"
    modes = ["coarse_only", "coarse_fine", "autoregressive"]

    print("=" * 70)
    print(f"fVLM-{model_name} BENCHMARK EVALUATION")
    print(f" LLM: {args.llm}")
    print(f" Checkpoint: {args.checkpoint}")
    print(f" Output: {args.output}")
    print("=" * 70)

    print("\nLoading model (bf16)...")
    model = load_model(args.checkpoint, args.llm, args.dino, device)
    tokenizer = load_tokenizer(args.llm)

    all_results = {}
    if args.merge and os.path.exists(args.output):
        with open(args.output) as f:
            all_results = json.load(f)
        print(f" Loaded existing results from {args.output}")

    t_global = time.time()

    benchmarks = [
        ("MVBench", "/workspace/data/eval/benchmarks/mvbench_shards/mvbench_*.tar"),
        ("Video-MME", "/workspace/data/eval/benchmarks/video_mme_shards/video_mme_*.tar"),
        ("ScienceQA", "/workspace/data/eval/benchmarks/scienceqa_shards/scienceqa_*.tar"),
        ("POPE", "/workspace/data/eval/benchmarks/pope_shards/pope_*.tar"),
        ("MLVU", "/workspace/data/eval/benchmarks/mlvu_shards/mlvu_*.tar"),
    ]

    if args.only:
        only_set = {n.lower() for n in args.only}
        benchmarks = [(n, p) for n, p in benchmarks if n.lower() in only_set]

    for name, pattern in benchmarks:
        print(f"\n{'-' * 70}")
        print(f"BENCHMARK: {name}")
        print(f"{'-' * 70}")
        run_mcq_benchmark(model, tokenizer, name, pattern, modes, device, all_results)

    with open(args.output, "w") as f:
        json.dump(all_results, f, indent=2)
    print(f"\nResults saved: {args.output}")

    # Summary table; "-" marks a mode with no stored result.
    total_time = time.time() - t_global
    print("\n" + "=" * 70)
    print(f"SUMMARY (total: {total_time:.0f}s = {total_time/60:.1f}min)")
    print("=" * 70)
    print(f"\n{'Benchmark':<15} {'Coarse-Only':>15} {'Coarse->Fine':>15} {'Autoregressive':>15}")
    print("-" * 62)
    for bench_name, bench_data in all_results.items():
        vals = []
        for mode in modes:
            if mode not in bench_data:
                vals.append("-")
            elif "accuracy" in bench_data[mode]:
                vals.append(f"{bench_data[mode]['accuracy']:.1f}%")
            else:
                vals.append(f"{bench_data[mode]['avg_loss']:.4f}")
        print(f"{bench_name:<15} {vals[0]:>15} {vals[1]:>15} {vals[2]:>15}")
    print("\n" + "=" * 70)


if __name__ == "__main__":
    main()