#!/usr/bin/env python3
"""Run exact-unit grounded-CBU verification requests against vLLM."""

from __future__ import annotations

import argparse
import asyncio
import base64
import json
import time
from io import BytesIO
from pathlib import Path
from typing import Any

import aiohttp
from PIL import Image, ImageFile

# Ask Pillow to load truncated image files instead of raising on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Closed set of per-unit verdicts accepted in a model response.
STATUSES = [
    "grounded",
    "unsupported",
    "uncertain",
    "invalid_text_unit",
    "not_a_visual_claim",
    "image_unavailable",
]
# Built once so validate() does not rebuild a set for every unit result.
_STATUS_SET = frozenset(STATUSES)


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the verification runner."""
    parser = argparse.ArgumentParser(description="Run exact-unit grounded-CBU verification requests")
    parser.add_argument("--input", required=True)
    parser.add_argument("--output", required=True)
    parser.add_argument("--urls", default="http://localhost:8000")
    parser.add_argument("--model", default="Qwen/Qwen3.5-122B-A10B-FP8")
    parser.add_argument("--max-requests", type=int, default=None)
    parser.add_argument("--concurrency", type=int, default=32)
    parser.add_argument("--max-tokens", type=int, default=2048)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--timeout-sec", type=int, default=600)
    parser.add_argument("--image-mode", choices=["auto", "file", "data", "url"], default="auto")
    parser.add_argument("--structured-json", action="store_true")
    parser.add_argument("--resume", action="store_true", help="Append to output and skip request_ids already present.")
    parser.add_argument(
        "--resume-ok-only",
        action="store_true",
        help="With --resume, skip only previously successful request_ids so timeout/schema failures are retried.",
    )
    parser.add_argument(
        "--skip-ok-from",
        default=None,
        help="JSONL response log whose successful request_ids should be skipped while writing a separate output.",
    )
    return parser.parse_args()


def iter_requests(path: Path, max_requests: int | None) -> list[dict[str, Any]]:
    """Load request rows from a JSONL file.

    Blank lines are skipped and do not count toward ``max_requests``.

    Args:
        path: JSONL file with one request object per line.
        max_requests: Maximum number of rows to return, or ``None`` for all.

    Returns:
        The decoded rows in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    rows: list[dict[str, Any]] = []
    with path.open("r", encoding="utf-8") as handle:
        for line in handle:
            if max_requests is not None and len(rows) >= max_requests:
                break
            if line.strip():
                rows.append(json.loads(line))
    return rows


def image_url_for(row: dict[str, Any], mode: str) -> str:
    """Return the image URL to embed in the chat request for ``row``.

    * ``auto`` / ``data`` with an ``image_path``: the image is re-encoded as an
      RGB JPEG and returned as a base64 ``data:`` URL.
    * ``file`` with an ``image_path``: the resolved ``file://`` URI.
    * anything else: the row's ``image_url``.

    Args:
        row: Request row; reads ``image_path``, ``image_url`` and ``request_id``.
        mode: One of ``auto``, ``file``, ``data`` or ``url``.

    Raises:
        ValueError: If ``mode`` is ``file`` but the row has no ``image_path``,
            or if the row has to fall back to ``image_url`` and has none.
    """
    image_path = row.get("image_path")
    if image_path:
        if mode in {"auto", "data"}:
            with Image.open(Path(image_path)) as source:
                # JPEG cannot store alpha/palette modes, so normalise to RGB first.
                rgb = source if source.mode == "RGB" else source.convert("RGB")
                buffer = BytesIO()
                rgb.save(buffer, format="JPEG", quality=88)
            encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
            return f"data:image/jpeg;base64,{encoded}"
        if mode == "file":
            return Path(image_path).resolve().as_uri()
    elif mode == "file":
        raise ValueError(f"request {row.get('request_id')} has no image_path")

    image_url = row.get("image_url")
    if not image_url:
        # Previously surfaced as a bare KeyError('image_url') with no request context.
        raise ValueError(f"request {row.get('request_id')} has no image_url (image mode {mode!r})")
    return image_url


def response_schema(unit_ids: list[str]) -> dict[str, Any]:
    """Build the JSON schema for a response covering exactly ``unit_ids``.

    The schema requires one result object per unit id (array length is pinned
    to ``len(unit_ids)``) with a status from ``STATUSES``, a confidence in
    [0, 1] and an evidence string of at most 180 characters.
    """
    unit_result = {
        "type": "object",
        "properties": {
            "unit_id": {"type": "string", "enum": unit_ids},
            "status": {"type": "string", "enum": STATUSES},
            "confidence": {"type": "number", "minimum": 0.0, "maximum": 1.0},
            "evidence": {"type": "string", "maxLength": 180},
        },
        "required": ["unit_id", "status", "confidence", "evidence"],
        "additionalProperties": False,
    }
    return {
        "type": "object",
        "properties": {
            "caption_id": {"type": "string"},
            "unit_results": {
                "type": "array",
                "minItems": len(unit_ids),
                "maxItems": len(unit_ids),
                "items": unit_result,
            },
        },
        "required": ["caption_id", "unit_results"],
        "additionalProperties": False,
    }


def validate(parsed: Any, row: dict[str, Any]) -> str | None:
    """Check a decoded model response against the request it answers.

    Args:
        parsed: The JSON-decoded response content.
        row: The request row; its ``claimed_units`` supply the expected unit ids.

    Returns:
        ``None`` when the response is well-formed, otherwise a short
        description of the first problem found.

    Raises:
        KeyError: If an entry of ``claimed_units`` has no ``unit_id``.
    """
    if not isinstance(parsed, dict):
        return "top-level response is not an object"
    if not isinstance(parsed.get("caption_id"), str):
        return "caption_id is not a string"
    results = parsed.get("unit_results")
    if not isinstance(results, list):
        return "unit_results is not an array"

    expected = [unit["unit_id"] for unit in row.get("claimed_units", [])]
    seen: list[str] = []
    for index, result in enumerate(results):
        if not isinstance(result, dict):
            return f"unit_results[{index}] is not an object"
        unit_id = result.get("unit_id")
        if not isinstance(unit_id, str):
            return f"unit_results[{index}].unit_id is not a string"
        seen.append(unit_id)

        status = result.get("status")
        # The isinstance guard keeps unhashable values (lists, dicts) from
        # raising TypeError in the membership test.
        if not isinstance(status, str) or status not in _STATUS_SET:
            return f"unit_results[{index}].status has invalid value"

        confidence = result.get("confidence")
        # bool is a subclass of int, so it must be excluded explicitly.
        if isinstance(confidence, bool) or not isinstance(confidence, (int, float)):
            return f"unit_results[{index}].confidence is not numeric"
        # Mirrors the minimum/maximum declared in response_schema(); also rejects NaN.
        if not 0.0 <= confidence <= 1.0:
            return f"unit_results[{index}].confidence is outside [0, 1]"

        if not isinstance(result.get("evidence"), str):
            return f"unit_results[{index}].evidence is not a string"

    # Checked before the mismatch test so duplicates get the specific message.
    if len(seen) != len(set(seen)):
        return "duplicate unit_id in response"
    if sorted(seen) != sorted(expected):
        return f"unit_id set mismatch: expected={len(expected)} seen={len(seen)}"
    return None


def payload_for(row: dict[str, Any], args: argparse.Namespace) -> dict[str, Any]:
    """Build the chat-completions request body for one request row.

    Reads ``system_prompt``, ``user_prompt`` and ``claimed_units`` from the row
    and resolves the image through :func:`image_url_for`. When
    ``args.structured_json`` is set, the response schema is attached under
    ``structured_outputs``.

    Raises:
        KeyError: If the row lacks ``system_prompt`` or ``user_prompt``.
        ValueError: Propagated from :func:`image_url_for`.
    """
    unit_ids = [unit["unit_id"] for unit in row.get("claimed_units", [])]
    user_content = [
        {"type": "text", "text": row["user_prompt"]},
        {"type": "image_url", "image_url": {"url": image_url_for(row, args.image_mode)}},
    ]
    payload: dict[str, Any] = {
        "model": args.model,
        "max_tokens": args.max_tokens,
        "temperature": args.temperature,
        "messages": [
            {"role": "system", "content": row["system_prompt"]},
            {"role": "user", "content": user_content},
        ],
        "chat_template_kwargs": {"enable_thinking": False},
    }
    if args.structured_json:
        payload["structured_outputs"] = {"json": response_schema(unit_ids)}
    return payload


async def post_one(
    session: aiohttp.ClientSession,
    url: str,
    row: dict[str, Any],
    args: argparse.Namespace,
) -> dict[str, Any]:
    """Send one request and return a JSON-serialisable result record.

    Failures are reported in the record (``ok`` is ``False``) instead of being
    raised, so one bad request does not abort the whole run.

    Args:
        session: Shared aiohttp session.
        url: Base URL of the server; ``/v1/chat/completions`` is appended.
        row: Request row; must contain ``request_id``.
        args: Parsed CLI options.

    Returns:
        A dict that always has ``request_id``, ``ok``, ``status``,
        ``elapsed_sec`` and ``request``; the remaining keys depend on the outcome.
    """
    endpoint = f"{url.rstrip('/')}/v1/chat/completions"
    start = time.perf_counter()
    try:
        # Image decoding and base64 encoding are blocking CPU/disk work; run
        # them in a worker thread so other in-flight requests keep progressing.
        payload = await asyncio.to_thread(payload_for, row, args)
        # NOTE(review): placeholder bearer token; presumably the server does
        # not check it - confirm.
        headers = {"Authorization": "Bearer sk-fake"}
        async with session.post(endpoint, json=payload, headers=headers) as response:
            text = await response.text()
            elapsed = time.perf_counter() - start
            if response.status >= 400:
                return {
                    "request_id": row["request_id"],
                    "ok": False,
                    "status": response.status,
                    "elapsed_sec": round(elapsed, 4),
                    "error": text[:4000],
                    "request": row,
                }

            try:
                body = json.loads(text)
                content = body["choices"][0]["message"]["content"]
            except (ValueError, LookupError, TypeError) as exc:
                # Keep the HTTP status and a body excerpt so a malformed
                # success response can still be diagnosed from the log.
                return {
                    "request_id": row["request_id"],
                    "ok": False,
                    "status": response.status,
                    "elapsed_sec": round(elapsed, 4),
                    "error": f"malformed completion body: {exc!r}; body={text[:4000]}",
                    "request": row,
                }

            parsed = None
            parse_error = None
            schema_error = None
            try:
                parsed = json.loads(content)
                schema_error = validate(parsed, row)
            except Exception as exc:  # noqa: BLE001
                parse_error = repr(exc)
            return {
                "request_id": row["request_id"],
                "ok": parse_error is None and schema_error is None,
                "status": response.status,
                "elapsed_sec": round(elapsed, 4),
                "model": args.model,
                "usage": body.get("usage", {}),
                "response_text": content,
                "parsed": parsed,
                "parse_error": parse_error,
                "schema_error": schema_error,
                "request": row,
            }
    except Exception as exc:  # noqa: BLE001
        return {
            "request_id": row["request_id"],
            "ok": False,
            "status": None,
            "elapsed_sec": round(time.perf_counter() - start, 4),
            "error": repr(exc),
            "request": row,
        }


def _load_request_ids(path: Path, *, ok_only: bool) -> set[str]:
    """Collect string ``request_id`` values from a JSONL response log.

    Lines that are blank, not valid JSON, or not JSON objects are ignored.
    With ``ok_only`` set, records whose ``ok`` field is falsy are ignored too.
    """
    request_ids: set[str] = set()
    with path.open("r", encoding="utf-8") as handle:
        for line in handle:
            if not line.strip():
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                # Tolerate unparsable lines, e.g. a partially written record.
                continue
            if not isinstance(record, dict):
                continue
            if ok_only and not record.get("ok"):
                continue
            request_id = record.get("request_id")
            if isinstance(request_id, str):
                request_ids.add(request_id)
    return request_ids


def _lacks_trailing_newline(path: Path) -> bool:
    """Return True when ``path`` is non-empty and does not end with a newline."""
    with path.open("rb") as handle:
        handle.seek(0, 2)
        if handle.tell() == 0:
            return False
        handle.seek(-1, 2)
        return handle.read(1) != b"\n"


async def run(args: argparse.Namespace) -> int:
    """Run all pending requests and write one JSON result per line.

    Requests are spread round-robin over ``--urls`` and limited to
    ``--concurrency`` in flight. Results are written in completion order and
    flushed after every record. Progress is printed as JSON every ten
    completions and once more as a final summary.

    Returns:
        Always ``0``.

    Raises:
        ValueError: If ``--urls`` names no endpoint or ``--concurrency`` is
            below 1. Raised before the output file is opened, so a
            misconfigured run cannot truncate an existing log.
    """
    urls = [item.strip() for item in args.urls.split(",") if item.strip()]
    if not urls:
        raise ValueError("--urls must name at least one endpoint")
    if args.concurrency < 1:
        # Semaphore(0) would block every request forever.
        raise ValueError("--concurrency must be at least 1")

    rows = iter_requests(Path(args.input), args.max_requests)
    output = Path(args.output)
    output.parent.mkdir(parents=True, exist_ok=True)

    resuming = bool(args.resume and output.exists())
    seen_request_ids: set[str] = set()
    if args.skip_ok_from:
        seen_request_ids |= _load_request_ids(Path(args.skip_ok_from), ok_only=True)
    if resuming:
        seen_request_ids |= _load_request_ids(output, ok_only=args.resume_ok_only)

    pending = [row for row in rows if row.get("request_id") not in seen_request_ids]
    # Number of input rows actually skipped, not the number of ids found in logs.
    skipped = len(rows) - len(pending)

    # A partial last line would otherwise be glued onto the first new record.
    needs_newline = resuming and _lacks_trailing_newline(output)

    timeout = aiohttp.ClientTimeout(total=args.timeout_sec)
    connector = aiohttp.TCPConnector(limit=args.concurrency)
    sem = asyncio.Semaphore(args.concurrency)
    ok = 0
    total = 0
    mode = "a" if args.resume else "w"
    with output.open(mode, encoding="utf-8") as handle:
        if needs_newline:
            handle.write("\n")
        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:

            async def guarded(index: int, row: dict[str, Any]) -> dict[str, Any]:
                async with sem:
                    return await post_one(session, urls[index % len(urls)], row, args)

            tasks = [asyncio.create_task(guarded(index, row)) for index, row in enumerate(pending)]
            for task in asyncio.as_completed(tasks):
                result = await task
                handle.write(json.dumps(result, ensure_ascii=False) + "\n")
                # Flush per record so completed results are not left sitting
                # in the write buffer while the run continues.
                handle.flush()
                total += 1
                ok += int(bool(result.get("ok")))
                if total % 10 == 0 or total == len(pending):
                    print(
                        json.dumps(
                            {
                                "completed": total,
                                "ok": ok,
                                "total": len(pending),
                                "skipped_existing": skipped,
                            },
                            ensure_ascii=False,
                        )
                    )
    print(json.dumps({"output": str(output), "completed": total, "ok": ok, "skipped_existing": skipped}, indent=2))
    return 0


def main() -> int:
    """Parse CLI arguments, run the async driver and return its exit code."""
    return asyncio.run(run(parse_args()))


if __name__ == "__main__":
    raise SystemExit(main())