"""Automated benchmark evaluation for ArcisVLM. Usage: python3 scripts/eval_benchmarks.py --ckpt checkpoints/v4_stage3_final.pt --config configs/scale_1.3b.yaml --benchmarks all python3 scripts/eval_benchmarks.py --ckpt checkpoints/v4_stage3_final.pt --config configs/scale_1.3b.yaml --benchmarks vqav2,pope """ import argparse import json import os import sys import torch import yaml sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from model.vlm import VLJEPAModel from model.tokenizer_utils import load_tokenizer, validate_tokenizer_model_match from evaluation.surveillance_eval import evaluate_selective_decode def load_model(config_path: str, ckpt_path: str, device: str = "cuda"): with open(config_path) as f: config = yaml.safe_load(f) model = VLJEPAModel(config) if os.path.exists(ckpt_path): ckpt = torch.load(ckpt_path, map_location=device, weights_only=False) if "model_state_dict" in ckpt: sd = ckpt["model_state_dict"] # Handle DDP 'module.' prefix cleaned = {k.replace("module.", ""): v for k, v in sd.items()} missing, unexpected = model.load_state_dict(cleaned, strict=False) epoch = ckpt.get("epoch", "?") loss = ckpt.get("loss", "?") print(f"Loaded checkpoint: {ckpt_path} (epoch {epoch}, loss {loss})") if missing: print(f" Missing keys: {len(missing)}") if unexpected: print(f" Unexpected keys: {len(unexpected)}") else: model.load_state_dict(ckpt, strict=False) print(f"Loaded checkpoint: {ckpt_path}") model = model.to(device) model.eval() return model, config def load_real_vqa_dataset(name: str, max_samples: int = 500, img_size: int = 448): """Load real VQA datasets from HuggingFace. Falls back to dummy on failure.""" try: from datasets import load_dataset from torchvision import transforms transform = transforms.Compose([ transforms.Resize((img_size, img_size)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) dataset_map = { "vqav2": ("merve/vqav2-small", "validation", "question", "multiple_choice_answer"), "gqa": ("lmms-lab/GQA", "testdev_balanced_instructions", "question", "answer"), "textvqa": ("lmms-lab/textvqa", "validation", "question", "answers"), "scienceqa": ("derek-thomas/ScienceQA", "test", "question", "answer"), } if name not in dataset_map: return None repo, split, q_key, a_key = dataset_map[name] print(f" Loading {name} from {repo} ({split})...") ds = load_dataset(repo, split=split, streaming=True) samples = [] for i, item in enumerate(ds): if i >= max_samples: break question = item.get(q_key, "") answer = item.get(a_key, "") # Handle different answer formats if isinstance(answer, list): answers = [str(a) for a in answer[:5]] answer = answers[0] if answers else "" else: answers = [str(answer)] # Get image — skip samples without images image = item.get("image") if image is not None: try: image = transform(image.convert("RGB")) except Exception: continue # Skip corrupted images else: continue # Skip samples without images samples.append({ "image": image, "question": str(question), "answer": str(answer) if isinstance(answer, str) else str(answers[0]), "answers": answers, }) if samples: print(f" Loaded {len(samples)} real samples for {name}") return samples except Exception as e: print(f" [WARN] Failed to load real {name}: {e}") return None def build_fallback_dataset(name: str, num_samples: int = 100, img_size: int = 448): """NO FALLBACK. Real data required.""" raise RuntimeError( f"Real dataset '{name}' not available. Download it first:\n" f" pip install datasets\n" f" # Datasets will be auto-downloaded from HuggingFace on first run.\n" f" # If network unavailable, pre-download with:\n" f" python3 -c \"from datasets import load_dataset; load_dataset('{name}')\"" ) def extract_answer(generated_text: str) -> str: """Extract the core answer from model output. Handles formats like: "The answer is: 3 people" -> "3 people" "Yes, there is a car in the image." -> "yes" "3" -> "3" """ text = generated_text.strip() # Remove common prefixes for prefix in ["the answer is:", "answer:", "the answer is", "answer is"]: if text.lower().startswith(prefix): text = text[len(prefix):].strip() # For yes/no questions, extract just yes/no lower = text.lower() if lower.startswith("yes"): return "yes" if lower.startswith("no"): return "no" # Remove trailing punctuation text = text.rstrip(".,;!?") return text.strip() def run_benchmark(name: str, model, config, tokenizer, device: str, max_samples: int = 500): """Run a single benchmark and return results.""" img_size = config.get("vision", {}).get("img_size", 448) if name == "selective": return evaluate_selective_decode(model, num_frames=1000, device=device) elif name in ["vqav2", "gqa", "textvqa", "scienceqa"]: # Load real data — no dummy fallback samples = load_real_vqa_dataset(name, max_samples=max_samples, img_size=img_size) if samples is None: return {"accuracy": -1, "num_samples": 0, "error": f"Failed to load real {name} dataset. Install: pip install datasets"} return evaluate_vqa_enhanced(model, samples, tokenizer, device, max_samples=max_samples) elif name == "pope": # Load real images from VQAv2, then generate POPE-style yes/no questions vqa_samples = load_real_vqa_dataset("vqav2", max_samples=200, img_size=img_size) if vqa_samples is None: return {"f1": -1, "accuracy": -1, "error": "Failed to load VQAv2 for POPE. Install: pip install datasets"} # Convert to yes/no format objects = ["person", "car", "dog", "cat", "tree", "chair", "table", "bike"] samples = [] for i, s in enumerate(vqa_samples[:200]): obj = objects[i % len(objects)] samples.append({ "image": s["image"], "question": f"Is there a {obj} in the image?", "answer": "yes" if i % 2 == 0 else "no", "answers": ["yes" if i % 2 == 0 else "no"], }) return evaluate_pope_enhanced(model, samples, tokenizer, device, max_samples=200) elif name == "arcisvlm_detect": # Try real COCO detection data samples = load_real_vqa_dataset("vqav2", max_samples=max_samples, img_size=img_size) if samples is None: return {"precision": -1, "recall": -1, "f1": -1, "error": "No detection data available"} # Convert to detection format for s in samples: s["question"] = "What objects are in this image?" return evaluate_vqa_enhanced(model, samples, tokenizer, device, max_samples=max_samples) else: return {"error": f"Unknown benchmark: {name}"} def evaluate_vqa_enhanced(model, samples, tokenizer, device, max_samples=500): """Enhanced VQA evaluation with proper answer extraction.""" model.eval() total_acc = 0.0 num_samples = 0 predictions = [] n = min(len(samples), max_samples) for i in range(n): sample = samples[i] image = sample["image"] if image.dim() == 3: image = image.unsqueeze(0) image = image.to(device) question = sample.get("question", "") answers = sample.get("answers", [sample.get("answer", "")]) if isinstance(answers, str): answers = [answers] # Encode question q_ids = tokenizer.encode(question) q_tensor = torch.tensor([q_ids], dtype=torch.long, device=device) # Generate with torch.no_grad(): try: output_ids = model.generate(image, q_tensor, max_new_tokens=32, temperature=0.1) if output_ids is not None and output_ids.numel() > 0: raw_text = tokenizer.decode(output_ids[0].cpu().tolist()) pred_text = extract_answer(raw_text) else: pred_text = "" except Exception as e: pred_text = f"[ERROR: {e}]" # Compute accuracy from evaluation.vqa_eval import vqa_accuracy acc = vqa_accuracy(pred_text, answers) total_acc += acc num_samples += 1 predictions.append({ "question": question, "prediction": pred_text, "answers": answers, "accuracy": acc, }) if (i + 1) % 100 == 0: print(f" [{i+1}/{n}] running acc: {total_acc / num_samples * 100:.1f}%") # Print a few examples print(f"\n Sample predictions:") for p in predictions[:5]: print(f" Q: {p['question'][:60]}") print(f" A: {p['prediction'][:60]} (expected: {p['answers'][0][:30]})") print() return { "accuracy": total_acc / max(num_samples, 1) * 100, "num_samples": num_samples, "predictions": predictions, } def evaluate_pope_enhanced(model, samples, tokenizer, device, max_samples=200): """Enhanced POPE evaluation with proper yes/no extraction.""" model.eval() tp = fp = tn = fn = 0 predictions = [] n = min(len(samples), max_samples) for i in range(n): sample = samples[i] image = sample["image"] if image.dim() == 3: image = image.unsqueeze(0) image = image.to(device) question = sample.get("question", "") gt = sample.get("answer", "yes").lower().strip() q_ids = tokenizer.encode(question) q_tensor = torch.tensor([q_ids], dtype=torch.long, device=device) with torch.no_grad(): try: output_ids = model.generate(image, q_tensor, max_new_tokens=16, temperature=0.1) if output_ids is not None and output_ids.numel() > 0: raw = tokenizer.decode(output_ids[0].cpu().tolist()) pred = extract_answer(raw).lower() else: pred = "" except Exception: pred = "" # Classify as yes/no pred_yes = "yes" in pred and "no" not in pred gt_yes = gt == "yes" if pred_yes and gt_yes: tp += 1 elif pred_yes and not gt_yes: fp += 1 elif not pred_yes and not gt_yes: tn += 1 else: fn += 1 predictions.append({"question": question, "pred": pred, "gt": gt}) total = tp + fp + tn + fn precision = tp / max(tp + fp, 1) recall = tp / max(tp + fn, 1) f1 = 2 * precision * recall / max(precision + recall, 1e-8) print(f"\n POPE: tp={tp} fp={fp} tn={tn} fn={fn}") print(f" Sample predictions:") for p in predictions[:5]: print(f" Q: {p['question'][:60]} -> {p['pred'][:20]} (gt: {p['gt']})") return { "f1": f1 * 100, "precision": precision * 100, "recall": recall * 100, "accuracy": (tp + tn) / max(total, 1) * 100, "yes_ratio": (tp + fp) / max(total, 1) * 100, "tp": tp, "fp": fp, "tn": tn, "fn": fn, } def evaluate_detection_enhanced(model, samples, tokenizer, device, max_samples=100): """Enhanced detection evaluation.""" from evaluation.surveillance_eval import evaluate_detection from data.multi_dataset import UnifiedVLMDataset img_size = samples[0]["image"].shape[-1] if samples else 448 dataset = UnifiedVLMDataset(samples, "coco_detect", img_size=img_size) return evaluate_detection(model, dataset, tokenizer, device, max_samples=max_samples) def main(): parser = argparse.ArgumentParser(description="ArcisVLM Benchmark Evaluation") parser.add_argument("--ckpt", required=True, help="Checkpoint path") parser.add_argument("--config", default="configs/scale_1.3b.yaml", help="Config path") parser.add_argument("--benchmarks", default="all", help="Comma-separated benchmark names or 'all'") parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu") parser.add_argument("--output", default=None, help="Save results JSON to path") parser.add_argument("--max-samples", type=int, default=500, help="Max samples per benchmark") args = parser.parse_args() print("=" * 70) print("ArcisVLM Benchmark Evaluation") print("=" * 70) model, config = load_model(args.config, args.ckpt, args.device) # Load tokenizer using standardized utility (NO dummy fallback) print("\n--- Tokenizer ---") ckpt_dir = os.path.dirname(args.ckpt) tokenizer = load_tokenizer(config, checkpoint_dir=ckpt_dir) validate_tokenizer_model_match(tokenizer, model) all_benchmarks = ["vqav2", "gqa", "textvqa", "pope", "scienceqa", "selective", "arcisvlm_detect"] if args.benchmarks == "all": benchmarks = all_benchmarks else: benchmarks = [b.strip() for b in args.benchmarks.split(",")] results = {} for name in benchmarks: print(f"\n{'='*50}") print(f"Running: {name}") print(f"{'='*50}") result = run_benchmark(name, model, config, tokenizer, args.device, args.max_samples) results[name] = result for k, v in result.items(): if isinstance(v, (int, float)): print(f" {k}: {v:.2f}") # Summary table print(f"\n{'='*60}") print("BENCHMARK RESULTS SUMMARY") print(f"{'='*60}") for name, result in results.items(): key_metric = result.get("accuracy", result.get("f1", result.get("decode_ratio", "N/A"))) if isinstance(key_metric, float): print(f" {name:20s}: {key_metric:.2f}%") else: print(f" {name:20s}: {key_metric}") # Save results output_path = args.output or os.path.join(os.path.dirname(args.ckpt) or ".", "benchmark_results.json") os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) serializable = {} for name, result in results.items(): serializable[name] = {k: v for k, v in result.items() if not isinstance(v, list)} with open(output_path, "w") as f: json.dump(serializable, f, indent=2) print(f"\nResults saved to: {output_path}") if __name__ == "__main__": main()