# arcisvlm / scripts/eval_benchmarks.py
# Hardik Sanghvi
# feat: integrate Gemma 4 E2B backbone for production-quality VLM inference
# commit 7a564e3
# (web file-viewer residue: Raw | History | Blame | Contribute | Delete | 15.1 kB)
"""Automated benchmark evaluation for ArcisVLM.
Usage:
python3 scripts/eval_benchmarks.py --ckpt checkpoints/v4_stage3_final.pt --config configs/scale_1.3b.yaml --benchmarks all
python3 scripts/eval_benchmarks.py --ckpt checkpoints/v4_stage3_final.pt --config configs/scale_1.3b.yaml --benchmarks vqav2,pope
"""
import argparse
import json
import os
import sys
import torch
import yaml
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from model.vlm import VLJEPAModel
from model.tokenizer_utils import load_tokenizer, validate_tokenizer_model_match
from evaluation.surveillance_eval import evaluate_selective_decode
def load_model(config_path: str, ckpt_path: str, device: str = "cuda"):
    """Build a VLJEPAModel from a YAML config and load checkpoint weights into it.

    Args:
        config_path: Path to the YAML model config.
        ckpt_path: Path to a checkpoint loaded with ``torch.load``. It is either a
            dict holding a ``"model_state_dict"`` entry (plus optional
            ``"epoch"``/``"loss"`` metadata) or a bare state dict.
        device: Device the checkpoint is mapped to and the model is moved to.

    Returns:
        ``(model, config)``: the model in eval mode on ``device`` and the parsed
        config dict.

    Raises:
        FileNotFoundError: if ``ckpt_path`` does not exist. Benchmarking a
            randomly initialised model would silently yield meaningless numbers.
    """
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # Fail before building the (large) model rather than evaluating random weights.
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    model = VLJEPAModel(config)
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    has_wrapper = isinstance(ckpt, dict) and "model_state_dict" in ckpt
    state_dict = ckpt["model_state_dict"] if has_wrapper else ckpt

    # DDP-wrapped models save keys as "module.<name>". Strip only that leading
    # prefix so parameter names containing "module." elsewhere stay intact.
    prefix = "module."
    cleaned = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
    missing, unexpected = model.load_state_dict(cleaned, strict=False)

    if has_wrapper:
        epoch = ckpt.get("epoch", "?")
        loss = ckpt.get("loss", "?")
        print(f"Loaded checkpoint: {ckpt_path} (epoch {epoch}, loss {loss})")
    else:
        print(f"Loaded checkpoint: {ckpt_path}")
    # strict=False tolerates mismatches, so report them for both checkpoint formats.
    if missing:
        print(f" Missing keys: {len(missing)}")
    if unexpected:
        print(f" Unexpected keys: {len(unexpected)}")

    model = model.to(device)
    model.eval()
    return model, config
def load_real_vqa_dataset(name: str, max_samples: int = 500, img_size: int = 448):
    """Load real VQA samples for ``name`` from HuggingFace (streamed).

    At most ``max_samples`` dataset records are read. Records without an image,
    or whose image cannot be converted, are skipped, so fewer samples may be
    returned. There is no dummy-data fallback.

    Args:
        name: Benchmark key: "vqav2", "gqa", "textvqa" or "scienceqa".
        max_samples: Maximum number of dataset records to read.
        img_size: Images are resized to ``img_size x img_size`` and normalized
            with the ImageNet mean/std.

    Returns:
        A non-empty list of dicts with keys "image" (tensor), "question" (str),
        "answer" (str) and "answers" (list of str). Returns ``None`` if ``name`` is
        unknown, if loading fails (missing ``datasets``/``torchvision``, network
        error, ...), or if no usable sample was found.
    """
    # benchmark -> (HF repo, split, question field, answer field)
    dataset_map = {
        "vqav2": ("merve/vqav2-small", "validation", "question", "multiple_choice_answer"),
        "gqa": ("lmms-lab/GQA", "testdev_balanced_instructions", "question", "answer"),
        "textvqa": ("lmms-lab/textvqa", "validation", "question", "answers"),
        "scienceqa": ("derek-thomas/ScienceQA", "test", "question", "answer"),
    }
    if name not in dataset_map:
        return None
    repo, split, q_key, a_key = dataset_map[name]
    try:
        from datasets import load_dataset
        from torchvision import transforms

        transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        print(f" Loading {name} from {repo} ({split})...")
        ds = load_dataset(repo, split=split, streaming=True)
        samples = []
        for i, item in enumerate(ds):
            if i >= max_samples:
                break
            # Skip samples without (usable) images.
            image = item.get("image")
            if image is None:
                continue
            try:
                image = transform(image.convert("RGB"))
            except Exception:
                continue  # corrupted or non-PIL image
            question = item.get(q_key, "")
            answer = item.get(a_key, "")
            # Multiple-choice records that store the answer as an index into a
            # "choices" list are mapped to the choice text, so the reference is
            # comparable with a generated answer.
            choices = item.get("choices")
            if (
                isinstance(answer, int)
                and not isinstance(answer, bool)
                and isinstance(choices, list)
                and 0 <= answer < len(choices)
            ):
                answer = choices[answer]
            if isinstance(answer, list):
                # NOTE(review): only the first 5 reference answers are kept;
                # confirm this matches what vqa_accuracy expects.
                answers = [str(a) for a in answer[:5]]
            else:
                answers = [str(answer)]
            samples.append({
                "image": image,
                "question": str(question),
                "answer": answers[0] if answers else "",
                "answers": answers,
            })
        if not samples:
            # Callers test for None; an empty list would be scored as 0% over 0 samples.
            print(f" [WARN] No usable samples found for {name}")
            return None
        print(f" Loaded {len(samples)} real samples for {name}")
        return samples
    except Exception as e:
        print(f" [WARN] Failed to load real {name}: {e}")
        return None
def build_fallback_dataset(name: str, num_samples: int = 100, img_size: int = 448):
    """Refuse to fabricate data: always raise, explaining how to obtain ``name``.

    ``num_samples`` and ``img_size`` are accepted only for signature
    compatibility and are ignored.

    Raises:
        RuntimeError: always, with download instructions in the message.
    """
    hint_lines = [
        f"Real dataset '{name}' not available. Download it first:",
        " pip install datasets",
        " # Datasets will be auto-downloaded from HuggingFace on first run.",
        " # If network unavailable, pre-download with:",
        f" python3 -c \"from datasets import load_dataset; load_dataset('{name}')\"",
    ]
    raise RuntimeError("\n".join(hint_lines))
def extract_answer(generated_text: str) -> str:
    """Extract the core answer from model output.

    Handles formats like:
        "The answer is: 3 people" -> "3 people"
        "Yes, there is a car in the image." -> "yes"
        "3" -> "3"

    Leading "answer" boilerplate is removed. If the first word is exactly "yes"
    or "no" (ignoring trailing punctuation), the lowercase word is returned.
    Otherwise the remaining text is returned with trailing punctuation removed
    and its original casing kept.
    """
    text = generated_text.strip()
    # Remove common prefixes
    for prefix in ("the answer is:", "answer:", "the answer is", "answer is"):
        if text.lower().startswith(prefix):
            text = text[len(prefix):].strip()
    # Yes/no questions: compare the whole first word so answers such as
    # "None", "nothing", "North" or "yesterday" are not collapsed to yes/no.
    words = text.lower().split(maxsplit=1)
    first_word = words[0].rstrip(".,;:!?") if words else ""
    if first_word in ("yes", "no"):
        return first_word
    # Remove trailing punctuation
    return text.rstrip(".,;!?").strip()
def run_benchmark(name: str, model, config, tokenizer, device: str, max_samples: int = 500):
    """Run a single benchmark by name and return its result dict.

    Args:
        name: One of "vqav2", "gqa", "textvqa", "scienceqa", "pope",
            "selective" or "arcisvlm_detect".
        model: Model passed through to the evaluator.
        config: Parsed model config; ``config["vision"]["img_size"]`` (default
            448) sets the image resolution of loaded datasets.
        tokenizer: Tokenizer passed through to the evaluator.
        device: Device string passed through to the evaluator.
        max_samples: Per-benchmark sample cap. POPE is additionally capped at
            200; "selective" always uses 1000 frames.

    Returns:
        The evaluator's metrics dict. When data cannot be loaded, the
        benchmark's headline metric(s) are -1 and "error" holds a message.
        Unknown names return ``{"error": ...}``. POPE and arcisvlm_detect results
        also carry a "warning" string describing the limits of their labels.
    """
    vision_cfg = config.get("vision") or {}
    img_size = vision_cfg.get("img_size", 448)

    if name == "selective":
        return evaluate_selective_decode(model, num_frames=1000, device=device)

    if name in ("vqav2", "gqa", "textvqa", "scienceqa"):
        # Load real data — no dummy fallback. `not samples` also rejects an
        # empty list, which would otherwise be scored as 0% over 0 samples.
        samples = load_real_vqa_dataset(name, max_samples=max_samples, img_size=img_size)
        if not samples:
            return {"accuracy": -1, "num_samples": 0, "error": f"Failed to load real {name} dataset. Install: pip install datasets"}
        return evaluate_vqa_enhanced(model, samples, tokenizer, device, max_samples=max_samples)

    if name == "pope":
        # Respect --max-samples, but keep the historical 200-sample ceiling.
        pope_n = min(max_samples, 200)
        vqa_samples = load_real_vqa_dataset("vqav2", max_samples=pope_n, img_size=img_size)
        if not vqa_samples:
            return {"f1": -1, "accuracy": -1, "error": "Failed to load VQAv2 for POPE. Install: pip install datasets"}
        # Build POPE-style yes/no questions. The object cycles through a fixed
        # list and the label alternates with the sample index; neither comes
        # from the image, so the metrics do not reflect real object presence.
        objects = ["person", "car", "dog", "cat", "tree", "chair", "table", "bike"]
        samples = []
        for i, s in enumerate(vqa_samples[:pope_n]):
            obj = objects[i % len(objects)]
            label = "yes" if i % 2 == 0 else "no"
            samples.append({
                "image": s["image"],
                "question": f"Is there a {obj} in the image?",
                "answer": label,
                "answers": [label],
            })
        warning = "POPE labels are synthetic (alternating yes/no by sample index), not derived from image annotations"
        print(f" [WARN] {warning}")
        result = evaluate_pope_enhanced(model, samples, tokenizer, device, max_samples=pope_n)
        result["warning"] = warning
        return result

    if name == "arcisvlm_detect":
        # Uses VQAv2 images; no COCO detection annotations are loaded here.
        samples = load_real_vqa_dataset("vqav2", max_samples=max_samples, img_size=img_size)
        if not samples:
            return {"precision": -1, "recall": -1, "f1": -1, "error": "No detection data available"}
        # Only the question is rewritten; the reference answers stay VQAv2 answers.
        for s in samples:
            s["question"] = "What objects are in this image?"
        warning = "arcisvlm_detect references are the original VQAv2 answers, not object annotations"
        print(f" [WARN] {warning}")
        result = evaluate_vqa_enhanced(model, samples, tokenizer, device, max_samples=max_samples)
        result["warning"] = warning
        return result

    return {"error": f"Unknown benchmark: {name}"}
def evaluate_vqa_enhanced(model, samples, tokenizer, device, max_samples=500):
    """Evaluate free-form VQA generations with ``vqa_accuracy``.

    For each of the first ``max_samples`` samples, the question is tokenized and
    answered with ``model.generate`` (32 new tokens, temperature 0.1). The
    output is reduced with ``extract_answer`` and scored against the sample's
    "answers" (falling back to "answer").

    Returns:
        dict with "accuracy" (mean per-sample accuracy x 100), "num_samples",
        "num_errors" (generations that raised; these are scored as wrong
        predictions) and "predictions" (per-sample records).
    """
    # Imported once here instead of on every loop iteration.
    from evaluation.vqa_eval import vqa_accuracy

    model.eval()
    total_acc = 0.0
    num_samples = 0
    num_errors = 0
    predictions = []
    n = min(len(samples), max_samples)
    for i in range(n):
        sample = samples[i]
        image = sample["image"]
        if image.dim() == 3:
            image = image.unsqueeze(0)  # add a batch dimension
        image = image.to(device)
        question = sample.get("question", "")
        # A missing or empty "answers" entry falls back to the single "answer".
        answers = sample.get("answers") or [sample.get("answer", "")]
        if isinstance(answers, str):
            answers = [answers]
        # Encode question
        q_tensor = torch.tensor([tokenizer.encode(question)], dtype=torch.long, device=device)
        # Generate
        with torch.no_grad():
            try:
                output_ids = model.generate(image, q_tensor, max_new_tokens=32, temperature=0.1)
                if output_ids is not None and output_ids.numel() > 0:
                    raw_text = tokenizer.decode(output_ids[0].cpu().tolist())
                    pred_text = extract_answer(raw_text)
                else:
                    pred_text = ""
            except Exception as e:
                # Keep evaluating, but count the failure. The error text is
                # scored like any other prediction.
                pred_text = f"[ERROR: {e}]"
                num_errors += 1
        # Compute accuracy
        acc = vqa_accuracy(pred_text, answers)
        total_acc += acc
        num_samples += 1
        predictions.append({
            "question": question,
            "prediction": pred_text,
            "answers": answers,
            "accuracy": acc,
        })
        if (i + 1) % 100 == 0:
            print(f" [{i+1}/{n}] running acc: {total_acc / num_samples * 100:.1f}%")
    if num_errors:
        print(f" [WARN] {num_errors}/{num_samples} generations raised errors")
    # Print a few examples
    print(f"\n Sample predictions:")
    for p in predictions[:5]:
        expected = str(p["answers"][0]) if p["answers"] else ""
        print(f" Q: {p['question'][:60]}")
        print(f" A: {p['prediction'][:60]} (expected: {expected[:30]})")
    print()
    return {
        "accuracy": total_acc / max(num_samples, 1) * 100,
        "num_samples": num_samples,
        "num_errors": num_errors,
        "predictions": predictions,
    }
def evaluate_pope_enhanced(model, samples, tokenizer, device, max_samples=200):
    """Evaluate yes/no object-presence questions (POPE style).

    Each sample provides an image tensor (3-D images get a batch dimension
    added), a "question" and an "answer" of "yes"/"no" (default "yes"). The
    generated text is reduced with ``extract_answer``. It counts as a "yes"
    prediction only if it contains the word "yes" and not the word "no";
    anything else, including empty output or a failed generation, counts as
    "no".

    Returns:
        dict with f1/precision/recall/accuracy/yes_ratio (percent), the
        confusion counts tp/fp/tn/fn, and "num_errors" (generations that raised).
    """
    model.eval()
    tp = fp = tn = fn = 0
    num_errors = 0
    predictions = []
    n = min(len(samples), max_samples)
    for i in range(n):
        sample = samples[i]
        image = sample["image"]
        if image.dim() == 3:
            image = image.unsqueeze(0)
        image = image.to(device)
        question = sample.get("question", "")
        gt = sample.get("answer", "yes").lower().strip()
        q_tensor = torch.tensor([tokenizer.encode(question)], dtype=torch.long, device=device)
        with torch.no_grad():
            try:
                output_ids = model.generate(image, q_tensor, max_new_tokens=16, temperature=0.1)
                if output_ids is not None and output_ids.numel() > 0:
                    raw = tokenizer.decode(output_ids[0].cpu().tolist())
                    pred = extract_answer(raw).lower()
                else:
                    pred = ""
            except Exception as e:
                # Scored as "no", but counted so failures are visible in the
                # results instead of silently skewing the metrics.
                pred = ""
                num_errors += 1
                if num_errors == 1:
                    print(f" [WARN] POPE generation failed: {e}")
        # Whole-word match: a substring test would read "know" or "nothing" as "no".
        words = {w.strip(".,;:!?\"'") for w in pred.split()}
        pred_yes = "yes" in words and "no" not in words
        gt_yes = gt == "yes"
        if pred_yes and gt_yes:
            tp += 1
        elif pred_yes:
            fp += 1
        elif not gt_yes:
            tn += 1
        else:
            fn += 1
        predictions.append({"question": question, "pred": pred, "gt": gt})
    total = tp + fp + tn + fn
    precision = tp / max(tp + fp, 1)
    recall = tp / max(tp + fn, 1)
    f1 = 2 * precision * recall / max(precision + recall, 1e-8)
    print(f"\n POPE: tp={tp} fp={fp} tn={tn} fn={fn}")
    if num_errors:
        print(f" [WARN] {num_errors}/{total} generations raised errors")
    print(f" Sample predictions:")
    for p in predictions[:5]:
        print(f" Q: {p['question'][:60]} -> {p['pred'][:20]} (gt: {p['gt']})")
    return {
        "f1": f1 * 100,
        "precision": precision * 100,
        "recall": recall * 100,
        "accuracy": (tp + tn) / max(total, 1) * 100,
        "yes_ratio": (tp + fp) / max(total, 1) * 100,
        "tp": tp, "fp": fp, "tn": tn, "fn": fn,
        "num_errors": num_errors,
    }
def evaluate_detection_enhanced(model, samples, tokenizer, device, max_samples=100):
    """Run the surveillance detection evaluator over ``samples``.

    The samples are wrapped in a ``UnifiedVLMDataset`` for the "coco_detect"
    task. Its ``img_size`` is the last dimension of the first sample's image
    tensor, or 448 when ``samples`` is empty. Returns whatever
    ``evaluate_detection`` returns.
    """
    from evaluation.surveillance_eval import evaluate_detection
    from data.multi_dataset import UnifiedVLMDataset

    if samples:
        side = samples[0]["image"].shape[-1]
    else:
        side = 448
    detect_set = UnifiedVLMDataset(samples, "coco_detect", img_size=side)
    return evaluate_detection(
        model,
        detect_set,
        tokenizer,
        device,
        max_samples=max_samples,
    )
def main():
    """CLI entry point: load model and tokenizer, run benchmarks, save results.

    Numeric metrics are printed as each benchmark finishes. A summary then shows
    each benchmark's headline metric (accuracy, else f1, else decode_ratio), or
    the error message for benchmarks that failed. All non-list result values are
    written as JSON to ``--output`` (default: ``benchmark_results.json`` next to
    the checkpoint).
    """
    parser = argparse.ArgumentParser(description="ArcisVLM Benchmark Evaluation")
    parser.add_argument("--ckpt", required=True, help="Checkpoint path")
    parser.add_argument("--config", default="configs/scale_1.3b.yaml", help="Config path")
    parser.add_argument("--benchmarks", default="all", help="Comma-separated benchmark names or 'all'")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--output", default=None, help="Save results JSON to path")
    parser.add_argument("--max-samples", type=int, default=500, help="Max samples per benchmark")
    args = parser.parse_args()
    print("=" * 70)
    print("ArcisVLM Benchmark Evaluation")
    print("=" * 70)

    all_benchmarks = ["vqav2", "gqa", "textvqa", "pope", "scienceqa", "selective", "arcisvlm_detect"]
    if args.benchmarks == "all":
        benchmarks = all_benchmarks
    else:
        # Drop empty entries (e.g. from a trailing comma) instead of running "".
        benchmarks = [b.strip() for b in args.benchmarks.split(",") if b.strip()]
    unknown = [b for b in benchmarks if b not in all_benchmarks]
    if unknown:
        # Warn before the slow model load; run_benchmark still records them as errors.
        print(f"[WARN] Unknown benchmark(s): {', '.join(unknown)}. Available: {', '.join(all_benchmarks)}")

    model, config = load_model(args.config, args.ckpt, args.device)
    # Load tokenizer using standardized utility (NO dummy fallback)
    print("\n--- Tokenizer ---")
    ckpt_dir = os.path.dirname(args.ckpt)
    tokenizer = load_tokenizer(config, checkpoint_dir=ckpt_dir)
    validate_tokenizer_model_match(tokenizer, model)

    results = {}
    for name in benchmarks:
        print(f"\n{'='*50}")
        print(f"Running: {name}")
        print(f"{'='*50}")
        result = run_benchmark(name, model, config, tokenizer, args.device, args.max_samples)
        results[name] = result
        for k, v in result.items():
            if isinstance(v, (int, float)):
                print(f" {k}: {v:.2f}")

    # Summary table
    print(f"\n{'='*60}")
    print("BENCHMARK RESULTS SUMMARY")
    print(f"{'='*60}")
    for name, result in results.items():
        if "error" in result:
            # Failed runs use -1 placeholders; show the reason instead.
            print(f" {name:20s}: ERROR - {result['error']}")
            continue
        key_metric = result.get("accuracy", result.get("f1", result.get("decode_ratio", "N/A")))
        if isinstance(key_metric, float):
            print(f" {name:20s}: {key_metric:.2f}%")
        else:
            print(f" {name:20s}: {key_metric}")

    # Save results (per-sample lists such as "predictions" are omitted)
    output_path = args.output or os.path.join(os.path.dirname(args.ckpt) or ".", "benchmark_results.json")
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    serializable = {
        name: {k: v for k, v in result.items() if not isinstance(v, list)}
        for name, result in results.items()
    }
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(serializable, f, indent=2)
    print(f"\nResults saved to: {output_path}")
if __name__ == "__main__":
    # Script entry point; main() returns None, so the exit status is 0 on success.
    sys.exit(main())