#!/usr/bin/env python3
"""
fVLM benchmark evaluation (batched MCQ scoring).
Usage:
# 135M model
python benchmark.py --llm /workspace/models/SmolLM2-135M-Instruct \
--checkpoint /workspace/checkpoints/final/stage3/best.pt
# 1.7B model
python benchmark.py --llm /workspace/models/SmolLM2-1.7B-Instruct \
--checkpoint /workspace/checkpoints/final_1.7B/stage3/latest.pt
# Run specific benchmarks only
python benchmark.py --checkpoint ... --only MVBench ScienceQA
Key optimizations:
1. Batch all MCQ options into ONE forward pass per mode (not N sequential)
2. Compute per-option CE from logits (avoid model's averaged loss)
3. Cache DINO encoding across modes (same frames, reuse kv_cache)
"""
import os, json, tarfile, io, time, re, glob, gc
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from transformers import AutoTokenizer
from collections import defaultdict
from model import FoveatedVLM
# ─── Model / tokenizer ──────────────────────────────────────────────
def load_model(checkpoint_path, llm_name, dino_name="/workspace/models/dinov2-small",
device="cuda"):
model = FoveatedVLM(
llm_name=llm_name, dino_name=dino_name,
query_dim=384, visual_scale=0.14, lambda_coarse=0.0, deep_query=True,
)
ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
state_dict = ckpt["model_state_dict"]
# Strip torch.compile's _orig_mod prefix if present
if any("._orig_mod." in k for k in state_dict):
state_dict = {k.replace("._orig_mod", ""): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model = model.to(device).to(torch.bfloat16).eval()
print(f"Loaded: {checkpoint_path} (step {ckpt.get('step', '?')})")
return model
def load_tokenizer(llm_name):
tok = AutoTokenizer.from_pretrained(llm_name)
if tok.pad_token is None:
tok.pad_token = tok.eos_token
return tok
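# ImageNet mean/std -- the standard normalization statistics used by DINOv2-style encoders.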
FRAME_TRANSFORM = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# ─── Data loading ────────────────────────────────────────────────────
def load_all_samples_from_shards(shard_pattern):
shard_files = sorted(glob.glob(shard_pattern))
print(f" Loading from {len(shard_files)} shards...")
samples = []
for shard_path in shard_files:
with tarfile.open(shard_path, "r") as tar:
members = tar.getmembers()
grouped = {}
for m in members:
parts = m.name.split(".")
if m.name.endswith(".json"):
key = parts[0]
if key not in grouped:
grouped[key] = {"frames": {}}
grouped[key]["json"] = json.load(tar.extractfile(m))
elif m.name.endswith(".jpg") or m.name.endswith(".png"):
key = parts[0]
frame_idx = int(parts[1]) if len(parts) >= 3 else 0
if key not in grouped:
grouped[key] = {"frames": {}}
img_data = tar.extractfile(m).read()
img = Image.open(io.BytesIO(img_data)).convert("RGB")
grouped[key]["frames"][frame_idx] = img
for key in sorted(grouped.keys()):
entry = grouped[key]
if entry.get("json") and entry.get("frames"):
sorted_frames = [entry["frames"][i] for i in sorted(entry["frames"].keys())]
samples.append({
"key": key,
"json": entry["json"],
"frames": sorted_frames,
})
print(f" Loaded {len(samples)} samples")
return samples
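# Shard layout assumed by the parser above (hypothetical example, inferred from the
# member-name handling: "<key>.json" plus "<key>.<frame_idx>.jpg" or "<key>.jpg"):
#   0001.json       -> {"user": "...", "assistant": "B. ...", "source": "mvbench/action"}
#   0001.000.jpg    -> frame 0
#   0001.001.jpg    -> frame 1
#   0002.json
#   0002.jpg        -> two-part name, parsed as frame 0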
def prepare_frames_tensor(pil_frames, device="cuda", replicate_to=8):
"""Transform PIL frames to tensor. Replicate single-frame images to match training."""
tensors = [FRAME_TRANSFORM(f) for f in pil_frames]
frames = torch.stack(tensors) # [T, 3, H, W]
# Replicate single images to N frames (matches training with replicate_image_frames=8)
if frames.shape[0] == 1 and replicate_to > 1:
frames = frames.repeat(replicate_to, 1, 1, 1) # [N, 3, H, W]
return frames.unsqueeze(0).to(device, dtype=torch.bfloat16)
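# Shape walk-through (illustrative): a single PIL image with replicate_to=8 goes
#   [1, 3, 224, 224] -> repeat -> [8, 3, 224, 224] -> unsqueeze -> [1, 8, 3, 224, 224]
# i.e. a batch of 1 with T=8 frames, mirroring training's replicate_image_frames=8.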
# ─── MCQ helpers ─────────────────────────────────────────────────────
def parse_mcq_options(user_text):
options = {}
for match in re.finditer(r'([A-Z])\.\s*(.+?)(?=\n[A-Z]\.|$)', user_text, re.DOTALL):
options[match.group(1)] = match.group(1) + ". " + match.group(2).strip()
return options
def extract_answer_letter(assistant_text):
m = re.match(r'([A-Z])\.', assistant_text.strip())
if m:
return m.group(1)
return assistant_text.strip()[0] if assistant_text.strip() else "?"
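# Illustrative behaviour of the two helpers above (hypothetical inputs):
#   parse_mcq_options("Which animal?\nA. cat\nB. dog")
#       -> {"A": "A. cat", "B": "B. dog"}
#   extract_answer_letter("B. dog")  -> "B"
#   extract_answer_letter("dog")     -> "d"   (fallback: first character)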
# ─── Batched option scoring (KEY OPTIMIZATION) ──────────────────────
@torch.no_grad()
def score_options_batched(model, tokenizer, frames, question_text, options_dict, mode, device):
"""
Score all MCQ options in ONE batched forward pass.
Returns dict {letter: loss} where lower loss = better match.
"""
letters = sorted(options_dict.keys())
if not letters:
return {}
# Tokenize prompt (shared across all options)
prompt_messages = [
{"role": "system", "content": "You are a helpful AI assistant."},
{"role": "user", "content": question_text},
]
prompt_text = tokenizer.apply_chat_template(prompt_messages, tokenize=False, add_generation_prompt=True)
prompt_ids = tokenizer.encode(prompt_text)
S_prompt = len(prompt_ids)
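    # Assumption: the chat template tokenizes the shared prompt prefix identically
    # when the assistant turn is appended, so tokens [S_prompt:] are answer-only.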
# Tokenize each full sequence (prompt + option)
all_ids = []
for letter in letters:
messages = [
{"role": "system", "content": "You are a helpful AI assistant."},
{"role": "user", "content": question_text},
{"role": "assistant", "content": options_dict[letter]},
]
full_text = tokenizer.apply_chat_template(messages, tokenize=False)
ids = tokenizer.encode(full_text)
all_ids.append(ids)
# Pad to same length
max_len = max(len(ids) for ids in all_ids)
pad_id = tokenizer.pad_token_id
N = len(letters)
batch_ids = torch.full((N, max_len), pad_id, dtype=torch.long, device=device)
batch_attn = torch.zeros(N, max_len, dtype=torch.long, device=device)
batch_loss_mask = torch.zeros(N, max_len, dtype=torch.float32, device=device)
for i, ids in enumerate(all_ids):
L = len(ids)
batch_ids[i, :L] = torch.tensor(ids, dtype=torch.long)
batch_attn[i, :L] = 1
batch_loss_mask[i, S_prompt:L] = 1.0 # answer-only tokens
# Expand frames: same image for all options
frames_batch = frames.expand(N, -1, -1, -1, -1) # [N, T, 3, H, W]
# Single batched forward
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
result = model.forward(
frames=frames_batch,
input_ids=batch_ids,
attention_mask=batch_attn,
loss_mask=batch_loss_mask,
mode=mode,
)
# Compute per-option loss from logits
logits = result["logits"] # [N, T_vis+S, V] (T_vis=T for coarse, 1 for autoregressive)
T_visual = logits.shape[1] - batch_ids.shape[1] # adaptive to mode
# Extract text portion of logits
text_logits = logits[:, T_visual:, :] # [N, S, V]
shift_logits = text_logits[:, :-1, :].contiguous() # [N, S-1, V]
shift_labels = batch_ids[:, 1:].contiguous() # [N, S-1]
shift_mask = batch_loss_mask[:, 1:].contiguous() # [N, S-1]
# Per-token CE loss
V = shift_logits.shape[-1]
per_token_loss = F.cross_entropy(
shift_logits.reshape(-1, V),
shift_labels.reshape(-1),
reduction="none",
ignore_index=pad_id,
).reshape(N, -1) # [N, S-1]
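    # Note: ignore_index=pad_id zeroes the CE at padded positions (they are also
    # masked below); if pad_token == eos_token, an EOS inside the answer span is
    # likewise zeroed yet still counted in shift_mask.sum(), slightly diluting the mean.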
# Average loss over answer tokens only (per option)
masked_loss = (per_token_loss * shift_mask).sum(dim=1) / shift_mask.sum(dim=1).clamp(min=1)
return {letters[i]: masked_loss[i].item() for i in range(N)}
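# Worked example of the shift alignment above (hypothetical 5-token sequence):
#   batch_ids    = [p0, p1, p2, a0, a1]   # p* = prompt tokens, a* = answer tokens
#   loss_mask    = [ 0,  0,  0,  1,  1]   # S_prompt = 3
#   shift_labels = [p1, p2, a0, a1]       # targets are the next tokens
#   shift_mask   = [ 0,  0,  1,  1]       # CE averaged over a0 and a1 only
# The option with the lowest mean answer-token CE is taken as the prediction.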
# ─── MCQ benchmark evaluation ───────────────────────────────────────
@torch.no_grad()
def evaluate_mcq_benchmark(model, tokenizer, samples, benchmark_name, modes, device):
results = {mode: {"correct": 0, "total": 0, "per_category": defaultdict(lambda: {"correct": 0, "total": 0})}
for mode in modes}
t0 = time.time()
for i, sample in enumerate(samples):
meta = sample["json"]
user_text = meta["user"]
gt_answer = meta["assistant"]
source = meta.get("source", "unknown")
        category = source.split("/")[-1]  # last path component, e.g. "mvbench/action" -> "action"
gt_letter = extract_answer_letter(gt_answer)
options = parse_mcq_options(user_text)
if not options:
continue
frames = prepare_frames_tensor(sample["frames"], device=device)
for mode in modes:
option_losses = score_options_batched(
model, tokenizer, frames, user_text, options, mode, device
)
if not option_losses:
continue
pred_letter = min(option_losses, key=option_losses.get)
correct = (pred_letter == gt_letter)
results[mode]["total"] += 1
if correct:
results[mode]["correct"] += 1
results[mode]["per_category"][category]["total"] += 1
if correct:
results[mode]["per_category"][category]["correct"] += 1
if (i + 1) % 100 == 0:
elapsed = time.time() - t0
for mode in modes:
r = results[mode]
acc = r["correct"] / max(r["total"], 1) * 100
print(f" [{benchmark_name}] {i+1}/{len(samples)} | {mode}: {acc:.1f}% ({r['correct']}/{r['total']}) | {elapsed:.0f}s", flush=True)
return results
# ─── Val loss evaluation ─────────────────────────────────────────────
@torch.no_grad()
def evaluate_val_loss(model, tokenizer, shard_pattern, modes, device, max_samples=1000):
samples = load_all_samples_from_shards(shard_pattern)
if max_samples:
samples = samples[:max_samples]
results = {mode: {"total_loss": 0.0, "count": 0} for mode in modes}
t0 = time.time()
for i, sample in enumerate(samples):
meta = sample["json"]
frames = prepare_frames_tensor(sample["frames"], device=device)
if "token_ids" in meta:
input_ids = torch.tensor(meta["token_ids"], dtype=torch.long).unsqueeze(0).to(device)
loss_mask_vals = meta.get("loss_mask", [1] * len(meta["token_ids"]))
loss_mask = torch.tensor(loss_mask_vals, dtype=torch.float32).unsqueeze(0).to(device)
else:
caption = meta.get("caption", meta.get("assistant", ""))
user = meta.get("user", "Describe this video.")
messages = [
{"role": "system", "content": "You are a helpful AI assistant."},
{"role": "user", "content": user},
{"role": "assistant", "content": caption},
]
text = tokenizer.apply_chat_template(messages, tokenize=False)
input_ids = tokenizer.encode(text, return_tensors="pt").to(device)
loss_mask = torch.ones_like(input_ids, dtype=torch.float32)
attention_mask = torch.ones_like(input_ids)
for mode in modes:
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
result = model.forward(
frames=frames, input_ids=input_ids,
attention_mask=attention_mask, loss_mask=loss_mask,
mode=mode,
)
results[mode]["total_loss"] += result["loss"].item()
results[mode]["count"] += 1
if (i + 1) % 200 == 0:
elapsed = time.time() - t0
for mode in modes:
r = results[mode]
avg = r["total_loss"] / max(r["count"], 1)
print(f" [val_10k] {i+1}/{len(samples)} | {mode}: loss={avg:.4f} | {elapsed:.0f}s", flush=True)
return results
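# Note: evaluate_val_loss is not invoked from main() below; entries carrying an
# "avg_loss" key (handled by the summary printer) are presumably merged in from
# a separate val-loss run via --merge.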
# ─── Main ────────────────────────────────────────────────────────────
def run_mcq_benchmark(model, tokenizer, name, shard_pattern, modes, device, all_results):
"""Load, evaluate, and free one MCQ benchmark."""
shards = glob.glob(shard_pattern)
if not shards:
print(f" Skipping {name} β€” shards not found")
return
samples = load_all_samples_from_shards(shard_pattern)
results = evaluate_mcq_benchmark(model, tokenizer, samples, name, modes, device)
del samples; gc.collect() # free PIL images immediately
    key = name.lower().replace("-", "_").replace(" ", "_")
    all_results[key] = {}
for mode in modes:
r = results[mode]
acc = r["correct"] / max(r["total"], 1) * 100
all_results[key][mode] = {
"accuracy": acc, "correct": r["correct"], "total": r["total"],
"per_category": {cat: {"accuracy": v["correct"]/max(v["total"],1)*100,
"correct": v["correct"], "total": v["total"]}
for cat, v in r["per_category"].items()},
}
print(f" {mode}: {acc:.1f}% ({r['correct']}/{r['total']})")
def main():
import argparse
parser = argparse.ArgumentParser(description="fVLM benchmark evaluation")
parser.add_argument("--llm", default="/workspace/models/SmolLM2-135M-Instruct",
help="HuggingFace LLM path (e.g. SmolLM2-135M or 1.7B)")
parser.add_argument("--checkpoint", default="/workspace/checkpoints/final/stage3/best.pt",
help="Path to model checkpoint (.pt)")
parser.add_argument("--dino", default="/workspace/models/dinov2-small",
help="HuggingFace DINOv2 path")
parser.add_argument("--only", nargs="+", default=None,
help="Run only specified benchmarks (e.g. --only MVBench ScienceQA)")
parser.add_argument("--output", default=None,
help="Output JSON path (default: /workspace/benchmark_results_{model}.json)")
parser.add_argument("--merge", action="store_true",
help="Merge with existing results file instead of overwriting")
args = parser.parse_args()
# Auto-detect model name for output file
model_name = os.path.basename(args.llm).replace("-Instruct", "").replace("SmolLM2-", "")
if args.output is None:
args.output = f"/workspace/benchmark_results_{model_name}.json"
device = "cuda"
modes = ["coarse_only", "coarse_fine", "autoregressive"]
print("=" * 70)
print(f"fVLM-{model_name} BENCHMARK EVALUATION")
print(f" LLM: {args.llm}")
print(f" Checkpoint: {args.checkpoint}")
print(f" Output: {args.output}")
print("=" * 70)
print("\nLoading model (bf16)...")
model = load_model(args.checkpoint, args.llm, args.dino, device)
tokenizer = load_tokenizer(args.llm)
# Load existing results if merging
all_results = {}
if args.merge and os.path.exists(args.output):
with open(args.output) as f:
all_results = json.load(f)
print(f" Loaded existing results from {args.output}")
t_global = time.time()
# ─── MCQ benchmarks (load one at a time, free between) ───────
benchmarks = [
("MVBench", "/workspace/data/eval/benchmarks/mvbench_shards/mvbench_*.tar"),
("Video-MME", "/workspace/data/eval/benchmarks/video_mme_shards/video_mme_*.tar"),
("ScienceQA", "/workspace/data/eval/benchmarks/scienceqa_shards/scienceqa_*.tar"),
("POPE", "/workspace/data/eval/benchmarks/pope_shards/pope_*.tar"),
("MLVU", "/workspace/data/eval/benchmarks/mlvu_shards/mlvu_*.tar"),
]
if args.only:
only_set = {n.lower() for n in args.only}
benchmarks = [(n, p) for n, p in benchmarks if n.lower() in only_set]
for i, (name, pattern) in enumerate(benchmarks):
print(f"\n{'-' * 70}")
print(f"BENCHMARK: {name}")
print(f"{'-' * 70}")
run_mcq_benchmark(model, tokenizer, name, pattern, modes, device, all_results)
# ─── Save results ────────────────────────────────────────────
with open(args.output, "w") as f:
json.dump(all_results, f, indent=2)
print(f"\nResults saved: {args.output}")
# ─── Summary ─────────────────────────────────────────────────
total_time = time.time() - t_global
print("\n" + "=" * 70)
print(f"SUMMARY (total: {total_time:.0f}s = {total_time/60:.1f}min)")
print("=" * 70)
print(f"\n{'Benchmark':<15} {'Coarse-Only':>15} {'Coarse->Fine':>15} {'Autoregressive':>15}")
print("-" * 62)
for bench_name, bench_data in all_results.items():
vals = []
for mode in modes:
if mode not in bench_data:
                vals.append("-")
elif "accuracy" in bench_data[mode]:
vals.append(f"{bench_data[mode]['accuracy']:.1f}%")
else:
vals.append(f"{bench_data[mode]['avg_loss']:.4f}")
print(f"{bench_name:<15} {vals[0]:>15} {vals[1]:>15} {vals[2]:>15}")
print("\n" + "=" * 70)
if __name__ == "__main__":
main()