zhkleciel's picture
Super-squash branch 'main' using huggingface_hub
c013d55
#!/usr/bin/env python
"""Eval pipeline driver: extraction → structural metrics → VLM judge → JSON.
Extraction runs locally on your GPU (vLLM/HF); the VLM judge runs remotely
via the OpenRouter API. One process, sequential stages, one JSON file out.
"""
from __future__ import annotations
import argparse
import datetime as _dt
import json
import logging
import sys
import time
from pathlib import Path
from typing import Any
from extract import iter_eval_samples, run_extraction
from judge import initialize_per_key_evals, run_vlm_judge
# ─── metrics aggregation ───────────────────────────────────────────────────
def per_sample_structural(prediction_json: dict, ground_truth: dict, strict_valid: bool) -> dict[str, Any]:
    """Compute key-level structural metrics for a single prediction.

    Only top-level key names are compared; values are never inspected.

    Args:
        prediction_json: Parsed model output. Anything that is not a dict
            (``None``, a list, a scalar) is treated as having no keys, so a
            malformed prediction scores zero instead of raising.
        ground_truth: Reference object. Non-dict values are likewise treated
            as having no keys.
        strict_valid: Validity flag supplied by the caller; passed through
            unchanged as ``json_valid``.

    Returns:
        A dict with ``json_valid``, ``total_keys`` (ground-truth key count),
        ``total_pred_keys``, ``overlap_keys``, and ``key_precision`` /
        ``key_recall`` / ``key_f1`` computed over the two key sets.
    """
    # Guard against non-dict inputs: model output that parses to a JSON
    # array/scalar (or failed parsing entirely) has no keys to compare.
    pred_keys = set(prediction_json) if isinstance(prediction_json, dict) else set()
    gt_keys = set(ground_truth) if isinstance(ground_truth, dict) else set()
    overlap = pred_keys & gt_keys

    # Each ratio falls back to 0.0 when its denominator would be zero.
    precision = len(overlap) / len(pred_keys) if pred_keys else 0.0
    recall = len(overlap) / len(gt_keys) if gt_keys else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom > 0 else 0.0

    return {
        "json_valid": strict_valid,
        "total_keys": len(gt_keys),
        "total_pred_keys": len(pred_keys),
        "overlap_keys": len(overlap),
        "key_precision": precision,
        "key_recall": recall,
        "key_f1": f1,
    }
def aggregate(records: list[dict[str, Any]]) -> dict[str, Any]:
    """Roll per-sample records up into run-level metrics.

    Args:
        records: Per-sample dicts. Missing ``key_precision`` / ``key_recall``
            / ``key_f1`` entries count as 0.0; samples whose
            ``vlm_judge_avg`` is missing or ``None`` are left out of the
            judge average.

    Returns:
        ``{"samples_evaluated": 0}`` for an empty list; otherwise the JSON
        validity rate, macro-averaged key precision/recall/F1, the mean VLM
        judge score (``None`` when no sample has one) and the sample count.
    """
    total = len(records)
    if not total:
        return {"samples_evaluated": 0}

    def macro(field: str) -> float:
        # `total` is non-zero here, so the division is always safe.
        return sum(rec.get(field, 0.0) for rec in records) / total

    valid_count = len([rec for rec in records if rec.get("json_valid")])

    judged = []
    for rec in records:
        score = rec.get("vlm_judge_avg")
        if score is not None:
            judged.append(score)
    judge_avg = sum(judged) / len(judged) if judged else None

    summary: dict[str, Any] = {}
    summary["json_validity_rate"] = valid_count / total
    summary["key_precision_macro"] = macro("key_precision")
    summary["key_recall_macro"] = macro("key_recall")
    summary["key_f1_macro"] = macro("key_f1")
    summary["vlm_judge_score_avg"] = judge_avg
    summary["samples_evaluated"] = total
    return summary
def _strip_sample(rec: dict[str, Any]) -> dict[str, Any]:
    """Project a record onto the fields that belong in the output JSON.

    The identifying/payload fields must be present (a missing one raises
    ``KeyError``); metric and judge fields fall back to defaults. Any other
    entry in ``rec`` is left out of the result.
    """
    # Mandatory fields, in output order.
    slim: dict[str, Any] = {
        name: rec[name]
        for name in ("key", "schema", "ground_truth", "prediction_raw", "prediction_json")
    }

    # Structural metrics default to "nothing matched".
    slim["json_valid"] = rec.get("json_valid", False)
    for count_field in ("total_keys", "total_pred_keys"):
        slim[count_field] = rec.get(count_field, 0)
    for ratio_field in ("key_precision", "key_recall", "key_f1"):
        slim[ratio_field] = rec.get(ratio_field, 0.0)

    # Judge outputs stay None when the judge did not produce them.
    for judge_field in ("vlm_judge_avg", "vlm_judge_raw"):
        slim[judge_field] = rec.get(judge_field)

    # Fresh dict per call so callers never share a default.
    slim["per_key"] = rec.get("per_key", {})
    return slim
# ─── CLI ───────────────────────────────────────────────────────────────────
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the eval driver."""
    p = argparse.ArgumentParser(description="OpenRouter-judged structured-extraction eval.")
    p.add_argument("--checkpoint-path", required=True, help="HF id or local merged/LoRA dir.")
    p.add_argument("--data-path", default="./eval_data", help="WDS tar / dir / glob.")
    p.add_argument("--output-path", default="./eval_result.json")
    p.add_argument("--num-samples", type=int, default=0, help="Cap N samples (0 = all).")
    p.add_argument("--skip-samples", type=int, default=0)
    p.add_argument("--extraction-backend", choices=["auto", "vllm", "hf"], default="auto")
    p.add_argument("--extraction-batch", type=int, default=8)
    p.add_argument("--extraction-max-new-tokens", type=int, default=1024)
    p.add_argument("--extraction-gpu-mem-util", type=float, default=0.85)
    p.add_argument("--extraction-max-model-len", type=int, default=8192)
    p.add_argument("--vlm-judge", action=argparse.BooleanOptionalAction, default=True)
    p.add_argument("--vlm-judge-model", default="qwen/qwen3-vl-4b-instruct")
    p.add_argument("--vlm-judge-max-tokens", type=int, default=1024)
    p.add_argument("--judge-concurrency", type=int, default=16, help="Concurrent OpenRouter calls.")
    p.add_argument("--openrouter-api-key", default=None, help="Override $OPENROUTER_API_KEY.")
    p.add_argument("--log-level", default="INFO")
    return p


def main(argv: list[str] | None = None) -> int:
    """Run the eval pipeline end to end and write the result JSON.

    Stages, in order: load samples, run extraction, compute per-sample
    structural metrics, optionally run the VLM judge, then write a single
    JSON file and print a short summary to stdout.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, so existing ``main()`` callers
            are unaffected.

    Returns:
        ``0`` on completion. A VLM-judge failure is not fatal: it is recorded
        under ``metadata.judge_errors`` in the output and the run still
        returns ``0``.

    Raises:
        RuntimeError: If no samples could be loaded from ``--data-path``.
    """
    args = _build_arg_parser().parse_args(argv)
    logging.basicConfig(
        level=args.log_level.upper(),
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    )
    t_start = time.perf_counter()
    logger = logging.getLogger("run_eval")
    logger.info("=== OpenRouter-judged eval starting ===")

    # ── load samples ─────────────────────────────────────────────────────
    samples = list(
        iter_eval_samples(
            args.data_path,
            skip=args.skip_samples,
            limit=args.num_samples,
        )
    )
    if not samples:
        raise RuntimeError(
            f"No usable samples loaded from {args.data_path} — expected WDS tars "
            "with .jpg, .key_explanations, .structured_text per sample."
        )
    logger.info("Loaded %d sample(s).", len(samples))
    # Image bytes are kept in a side table keyed by sample key and handed to
    # the judge separately from the records.
    sample_images = {s.key: s.image_bytes for s in samples}

    # ── extraction ───────────────────────────────────────────────────────
    records = run_extraction(
        samples,
        model_path=args.checkpoint_path,
        backend=args.extraction_backend,
        max_new_tokens=args.extraction_max_new_tokens,
        max_model_len=args.extraction_max_model_len,
        gpu_mem_util=args.extraction_gpu_mem_util,
        batch=args.extraction_batch,
    )

    # ── structural metrics ───────────────────────────────────────────────
    for rec in records:
        rec.update(
            per_sample_structural(
                rec["prediction_json"],
                rec["ground_truth"],
                # When extraction did not report strict validity, fall back to
                # "parsed to something non-empty".
                rec.get("prediction_strict_valid", bool(rec["prediction_json"])),
            )
        )
    initialize_per_key_evals(records)
    judge_errors: dict[str, str] = {}

    # ── VLM judge ────────────────────────────────────────────────────────
    if args.vlm_judge:
        try:
            run_vlm_judge(
                records,
                sample_images=sample_images,
                model=args.vlm_judge_model,
                max_tokens=args.vlm_judge_max_tokens,
                concurrency=args.judge_concurrency,
                api_key=args.openrouter_api_key,
            )
        except Exception as e:
            # Deliberate best-effort boundary: structural metrics are still
            # worth writing out, so record the failure and carry on. The
            # traceback is logged so the cause is not lost.
            judge_errors["vlm_judge"] = repr(e)
            logger.warning(
                "VLM judge failed (%s); continuing without VLM scores.", e, exc_info=True
            )
            for rec in records:
                # Keep any score that was stored before the failure.
                rec.setdefault("vlm_judge_avg", None)
    else:
        for rec in records:
            rec["vlm_judge_avg"] = None

    # ── write output ─────────────────────────────────────────────────────
    elapsed = time.perf_counter() - t_start
    result = {
        "metadata": {
            "checkpoint_path": args.checkpoint_path,
            "data_path": args.data_path,
            "num_samples_evaluated": len(records),
            "extraction_backend": args.extraction_backend,
            "vlm_judge_model": args.vlm_judge_model if args.vlm_judge else None,
            "judge_errors": judge_errors or None,
            "elapsed_s": round(elapsed, 2),
            "timestamp_utc": _dt.datetime.now(_dt.timezone.utc).isoformat(),
        },
        "metrics": aggregate(records),
        "samples": [_strip_sample(rec) for rec in records],
    }
    out = Path(args.output_path)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(json.dumps(result, ensure_ascii=False, indent=2), encoding="utf-8")

    print()
    print("=== JUDGING SUMMARY ===")
    print(f"output={out}")
    for k, v in result["metrics"].items():
        # Floats get fixed precision; ints and None are printed as-is.
        print(f" {k}={v:.4f}" if isinstance(v, float) else f" {k}={v}")
    print(f" elapsed_s={elapsed:.1f}")
    print("=== JUDGING OK ===")
    return 0
# Script entry point: run the pipeline and use main()'s return value as the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())