venice-h1 / evaluate.py
OdaxAI's picture
Add evaluate.py
6598593 verified
Raw
History Blame Contribute Delete
2.7 kB
#!/usr/bin/env python3
"""
Evaluate Venice-H1 on pre-extracted feature caches.
Usage:
python evaluate.py --checkpoint checkpoints/best.pt \
--splits data/cached_testA_refcoco_unc_feats.pt \
data/cached_testB_refcoco_unc_feats.pt
"""
import argparse
import json
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from venice_h1.model import VeniceH1Reranker
from train import FeatureCacheDataset, evaluate
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint", required=True,
help="Path to checkpoint .pt file")
parser.add_argument("--splits", nargs="+", required=True,
help="One or more feature cache .pt files to evaluate")
parser.add_argument("--tau", type=float, default=0.5,
help="Failure Gate threshold")
parser.add_argument("--no_grid", action="store_true")
parser.add_argument("--batch_size", type=int, default=512)
parser.add_argument("--device", default="cuda")
args = parser.parse_args()
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
# Load checkpoint
ckpt = torch.load(args.checkpoint, map_location="cpu")
state = ckpt["model"] if "model" in ckpt else ckpt
# Infer model config from weights
use_grid = not args.no_grid
model = VeniceH1Reranker(use_grid=use_grid).to(device)
model.load_state_dict(state, strict=False)
model.eval()
print(f"Loaded: {args.checkpoint}")
print(f"Parameters: {model.num_parameters():,} | tau={args.tau}")
print()
results = {}
for split_path in args.splits:
if not Path(split_path).exists():
print(f" [skip] {split_path} not found")
continue
ds = FeatureCacheDataset(split_path, use_grid=use_grid)
loader = DataLoader(ds, batch_size=args.batch_size,
shuffle=False, num_workers=4)
split_name = Path(split_path).stem.replace("cached_", "").replace("_feats", "")
metrics = evaluate(model, loader, device, tau=args.tau)
results[split_name] = metrics
print(f" {split_name:<30s} "
f"AUC={metrics['gate_auc']:.3f} "
f"harmful={metrics['harmful_switch']*100:.3f}% "
f"correct_rerank={metrics['correct_rerank']*100:.1f}% "
f"failure_rate={metrics['failure_rate']*100:.1f}%")
print()
out_path = "eval_results.json"
with open(out_path, "w") as f:
json.dump(results, f, indent=2)
print(f"Saved results → {out_path}")
if __name__ == "__main__":
main()