"""CLI entrypoint for the PID2Graph evaluation skeleton. Usage: python -m pid2graph_eval.cli \ --image-dir path/to/images \ --gt-dir path/to/graphml \ --output results.json \ --limit 10 The loader pairs each image with the graphml of the same stem (`A-001.png` ↔ `A-001.graphml`). Adjust `pair_samples` once you know the actual PID2Graph on-disk layout. """ from __future__ import annotations import argparse import json import sys from pathlib import Path import anthropic from tqdm import tqdm from .extractor import ( DEFAULT_MAX_TOKENS, DEFAULT_MODEL, extract_graph, extract_graph_tiled, ) from .gt_loader import ( SEMANTIC_EQUIPMENT_TYPES, collapse_through_primitives, filter_by_types, load_graphml, summarize, ) from .metrics import aggregate, evaluate IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".webp") GT_EXTS = (".graphml", ".xml") def pair_samples(image_dir: Path, gt_dir: Path) -> list[tuple[Path, Path]]: """Match each image to a graphml file by filename stem.""" gt_by_stem: dict[str, Path] = {} for ext in GT_EXTS: for p in gt_dir.rglob(f"*{ext}"): gt_by_stem[p.stem] = p pairs: list[tuple[Path, Path]] = [] for ext in IMAGE_EXTS: for img in sorted(image_dir.rglob(f"*{ext}")): gt = gt_by_stem.get(img.stem) if gt is not None: pairs.append((img, gt)) return pairs def run(args: argparse.Namespace) -> int: image_dir = Path(args.image_dir) gt_dir = Path(args.gt_dir) output = Path(args.output) if not image_dir.exists(): print(f"error: --image-dir does not exist: {image_dir}", file=sys.stderr) return 2 if not gt_dir.exists(): print(f"error: --gt-dir does not exist: {gt_dir}", file=sys.stderr) return 2 pairs = pair_samples(image_dir, gt_dir) if args.limit: pairs = pairs[: args.limit] if not pairs: print("error: no (image, graphml) pairs found — check stems match", file=sys.stderr) return 1 print(f"found {len(pairs)} sample pair(s)") client = anthropic.Anthropic() per_sample: list[dict] = [] errors: list[dict] = [] for image_path, gt_path in tqdm(pairs, desc="eval"): try: gt_graph = load_graphml(gt_path) except Exception as e: # parsing a single bad file shouldn't kill the run errors.append({"sample": image_path.stem, "stage": "gt_load", "error": str(e)}) continue # Semantic-only mode: drop line-primitive nodes from the GT and # re-wire the remaining semantic nodes via the original pipe # connectivity (BFS through primitives). This matches the format # the VLM is instructed to emit — one direct edge per physical # pipeline, regardless of how many junctions it passes through. if args.semantic_only: gt_graph = collapse_through_primitives(gt_graph, SEMANTIC_EQUIPMENT_TYPES) if args.dry_run: per_sample.append( { "sample": image_path.stem, "gt_summary": summarize(gt_graph), } ) continue try: if args.tile_rows > 1 or args.tile_cols > 1: pred_dict = extract_graph_tiled( image_path, client=client, model=args.model, max_tokens=args.max_tokens, rows=args.tile_rows, cols=args.tile_cols, overlap=args.tile_overlap, dedup_px=args.dedup_px, seam_filter=not args.no_seam_filter, seam_threshold=args.seam_threshold_px, edge_threshold=args.edge_threshold_px, ) else: pred = extract_graph( image_path, client=client, model=args.model, max_tokens=args.max_tokens, ) pred_dict = pred.to_dict() except Exception as e: errors.append({"sample": image_path.stem, "stage": "vlm", "error": str(e)}) continue if args.semantic_only: pred_dict = filter_by_types( {**pred_dict, "directed": gt_graph.get("directed", False)}, SEMANTIC_EQUIPMENT_TYPES, ) # Default to whatever the GT file says; allow CLI override. if args.force_undirected: directed = False elif args.force_directed: directed = True else: directed = gt_graph.get("directed", True) metrics = evaluate( pred_dict, gt_graph, directed=directed, match_threshold=args.match_threshold, ) per_sample.append( { "sample": image_path.stem, "metrics": metrics, "pred": pred_dict, "gt_summary": summarize(gt_graph), } ) result: dict = { "config": { "model": args.model, "max_tokens": args.max_tokens, "directed": ( False if args.force_undirected else True if args.force_directed else "auto" ), "semantic_only": args.semantic_only, "match_threshold": args.match_threshold, "tile_rows": args.tile_rows, "tile_cols": args.tile_cols, "tile_overlap": args.tile_overlap, "dedup_px": args.dedup_px, "seam_filter": not args.no_seam_filter, "seam_threshold_px": args.seam_threshold_px, "edge_threshold_px": args.edge_threshold_px, "limit": args.limit, "dry_run": args.dry_run, }, "per_sample": per_sample, "errors": errors, } if not args.dry_run and per_sample: result["aggregate"] = aggregate([s["metrics"] for s in per_sample if "metrics" in s]) output.parent.mkdir(parents=True, exist_ok=True) output.write_text(json.dumps(result, indent=2, ensure_ascii=False)) print(f"wrote {output}") if "aggregate" in result: agg = result["aggregate"] print( f" nodes F1={agg['nodes_micro']['f1']:.3f} " f"P={agg['nodes_micro']['precision']:.3f} " f"R={agg['nodes_micro']['recall']:.3f}" ) print( f" edges F1={agg['edges_micro']['f1']:.3f} " f"P={agg['edges_micro']['precision']:.3f} " f"R={agg['edges_micro']['recall']:.3f}" ) if errors: print(f" {len(errors)} error(s) — see `errors` in output JSON") return 0 def main() -> int: p = argparse.ArgumentParser(description="PID2Graph VLM evaluation skeleton") p.add_argument("--image-dir", required=True, help="Directory containing P&ID images") p.add_argument("--gt-dir", required=True, help="Directory containing graphml ground truth") p.add_argument("--output", default="results.json", help="Where to write the JSON report") p.add_argument("--model", default=DEFAULT_MODEL, help=f"Claude model id (default: {DEFAULT_MODEL})") p.add_argument( "--max-tokens", type=int, default=DEFAULT_MAX_TOKENS, help=f"VLM max output tokens (default: {DEFAULT_MAX_TOKENS}, streamed)", ) p.add_argument("--limit", type=int, default=0, help="Only process the first N samples (0 = all)") p.add_argument("--match-threshold", type=float, default=0.5, help="Node similarity threshold") p.add_argument( "--semantic-only", action="store_true", help=( "Restrict both prediction and GT to semantic equipment " "(valve, pump, tank, instrumentation, inlet/outlet); " "drops pipe-primitive nodes like connector/crossing/arrow." ), ) p.add_argument( "--tile-rows", type=int, default=1, help="Tile the image into this many rows before VLM extraction (default 1 = off)", ) p.add_argument( "--tile-cols", type=int, default=1, help="Tile the image into this many columns before VLM extraction (default 1 = off)", ) p.add_argument( "--tile-overlap", type=float, default=0.1, help="Fractional overlap between adjacent tiles (default 0.1 = 10%%)", ) p.add_argument( "--dedup-px", type=float, default=40.0, help="Bbox-center distance (pixels) under which two same-type nodes are merged", ) p.add_argument( "--no-seam-filter", action="store_true", help=( "Disable the inlet/outlet tile-seam FP filter. By default, " "tiled extraction drops inlet/outlet nodes whose bbox center " "sits within 50px of an inner tile seam and is not within " "30px of the outer image border." ), ) p.add_argument( "--seam-threshold-px", type=float, default=50.0, help="Distance (px) from an inner tile seam that triggers FP filtering", ) p.add_argument( "--edge-threshold-px", type=float, default=30.0, help="Distance (px) from the outer image edge that exempts a node from filtering", ) p.add_argument( "--force-undirected", action="store_true", help="Force undirected edge matching (default: use whatever the GT file says)", ) p.add_argument( "--force-directed", action="store_true", help="Force directed edge matching (default: use whatever the GT file says)", ) p.add_argument( "--dry-run", action="store_true", help="Skip VLM calls; just load GT and print summaries (for loader debugging)", ) return run(p.parse_args()) if __name__ == "__main__": raise SystemExit(main())