| |
| """ |
| Evaluate Venice-H1 on pre-extracted feature caches. |
| |
| Usage: |
| python evaluate.py --checkpoint checkpoints/best.pt \ |
| --splits data/cached_testA_refcoco_unc_feats.pt \ |
| data/cached_testB_refcoco_unc_feats.pt |
| """ |
|
|
| import argparse |
| import json |
| from pathlib import Path |
|
|
| import torch |
| from torch.utils.data import DataLoader |
|
|
| from venice_h1.model import VeniceH1Reranker |
| from train import FeatureCacheDataset, evaluate |
|
|
|
|
| def main(): |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--checkpoint", required=True, |
| help="Path to checkpoint .pt file") |
| parser.add_argument("--splits", nargs="+", required=True, |
| help="One or more feature cache .pt files to evaluate") |
| parser.add_argument("--tau", type=float, default=0.5, |
| help="Failure Gate threshold") |
| parser.add_argument("--no_grid", action="store_true") |
| parser.add_argument("--batch_size", type=int, default=512) |
| parser.add_argument("--device", default="cuda") |
| args = parser.parse_args() |
|
|
| device = torch.device(args.device if torch.cuda.is_available() else "cpu") |
|
|
| |
| ckpt = torch.load(args.checkpoint, map_location="cpu") |
| state = ckpt["model"] if "model" in ckpt else ckpt |
|
|
| |
| use_grid = not args.no_grid |
| model = VeniceH1Reranker(use_grid=use_grid).to(device) |
| model.load_state_dict(state, strict=False) |
| model.eval() |
|
|
| print(f"Loaded: {args.checkpoint}") |
| print(f"Parameters: {model.num_parameters():,} | tau={args.tau}") |
| print() |
|
|
| results = {} |
| for split_path in args.splits: |
| if not Path(split_path).exists(): |
| print(f" [skip] {split_path} not found") |
| continue |
|
|
| ds = FeatureCacheDataset(split_path, use_grid=use_grid) |
| loader = DataLoader(ds, batch_size=args.batch_size, |
| shuffle=False, num_workers=4) |
| split_name = Path(split_path).stem.replace("cached_", "").replace("_feats", "") |
|
|
| metrics = evaluate(model, loader, device, tau=args.tau) |
| results[split_name] = metrics |
|
|
| print(f" {split_name:<30s} " |
| f"AUC={metrics['gate_auc']:.3f} " |
| f"harmful={metrics['harmful_switch']*100:.3f}% " |
| f"correct_rerank={metrics['correct_rerank']*100:.1f}% " |
| f"failure_rate={metrics['failure_rate']*100:.1f}%") |
|
|
| print() |
| out_path = "eval_results.json" |
| with open(out_path, "w") as f: |
| json.dump(results, f, indent=2) |
| print(f"Saved results → {out_path}") |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|