Spaces:
Running
Running
| """CLI entrypoint for the PID2Graph evaluation skeleton. | |
| Usage: | |
| python -m pid2graph_eval.cli \ | |
| --image-dir path/to/images \ | |
| --gt-dir path/to/graphml \ | |
| --output results.json \ | |
| --limit 10 | |
| The loader pairs each image with the graphml of the same stem | |
| (`A-001.png` ↔ `A-001.graphml`). Adjust `pair_samples` once you know the | |
| actual PID2Graph on-disk layout. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import sys | |
| from pathlib import Path | |
| import anthropic | |
| from tqdm import tqdm | |
| from .extractor import ( | |
| DEFAULT_MAX_TOKENS, | |
| DEFAULT_MODEL, | |
| extract_graph, | |
| extract_graph_tiled, | |
| ) | |
| from .gt_loader import ( | |
| SEMANTIC_EQUIPMENT_TYPES, | |
| collapse_through_primitives, | |
| filter_by_types, | |
| load_graphml, | |
| summarize, | |
| ) | |
| from .metrics import aggregate, evaluate | |
# File suffixes recognized when recursively scanning the input directories
# in pair_samples(); comparison is exact, so upper-case variants (.PNG) are
# not matched.
IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".webp")
GT_EXTS = (".graphml", ".xml")
def pair_samples(
    image_dir: Path,
    gt_dir: Path,
    image_exts: tuple[str, ...] | None = None,
    gt_exts: tuple[str, ...] | None = None,
) -> list[tuple[Path, Path]]:
    """Match each image to a ground-truth file by filename stem.

    Recursively scans *gt_dir* for ground-truth files and *image_dir* for
    images, pairing files whose stems are equal (``A-001.png`` ↔
    ``A-001.graphml``). Images with no matching GT stem are skipped.

    Args:
        image_dir: Root directory searched recursively for images.
        gt_dir: Root directory searched recursively for ground-truth files.
        image_exts: Image suffixes to consider (default: ``IMAGE_EXTS``).
        gt_exts: Ground-truth suffixes to consider (default: ``GT_EXTS``).

    Returns:
        ``(image_path, gt_path)`` tuples, sorted by image path within each
        extension group.
    """
    if image_exts is None:
        image_exts = IMAGE_EXTS
    if gt_exts is None:
        gt_exts = GT_EXTS
    gt_by_stem: dict[str, Path] = {}
    for ext in gt_exts:
        for p in gt_dir.rglob(f"*{ext}"):
            # On stem collisions, files found later win (e.g. a .xml file
            # silently overrides a .graphml file with the same stem).
            gt_by_stem[p.stem] = p
    pairs: list[tuple[Path, Path]] = []
    for ext in image_exts:
        for img in sorted(image_dir.rglob(f"*{ext}")):
            gt = gt_by_stem.get(img.stem)
            if gt is not None:
                pairs.append((img, gt))
    return pairs
def run(args: argparse.Namespace) -> int:
    """Execute the evaluation loop described by *args*.

    Pairs images with ground-truth graphml files, optionally queries the
    VLM for each image (whole-image or tiled), scores predictions against
    the ground truth, and writes a JSON report to ``args.output``.

    Returns a process exit code: 0 on success, 1 when no (image, GT)
    pairs were found, 2 when an input directory does not exist.
    """
    image_dir = Path(args.image_dir)
    gt_dir = Path(args.gt_dir)
    output = Path(args.output)
    # Fail fast on bad paths before doing any work.
    if not image_dir.exists():
        print(f"error: --image-dir does not exist: {image_dir}", file=sys.stderr)
        return 2
    if not gt_dir.exists():
        print(f"error: --gt-dir does not exist: {gt_dir}", file=sys.stderr)
        return 2
    pairs = pair_samples(image_dir, gt_dir)
    if args.limit:
        pairs = pairs[: args.limit]
    if not pairs:
        print("error: no (image, graphml) pairs found — check stems match", file=sys.stderr)
        return 1
    print(f"found {len(pairs)} sample pair(s)")
    # Constructed even in --dry-run; no API call happens until one of the
    # extract_graph* functions is invoked below.
    client = anthropic.Anthropic()
    per_sample: list[dict] = []
    errors: list[dict] = []
    for image_path, gt_path in tqdm(pairs, desc="eval"):
        try:
            gt_graph = load_graphml(gt_path)
        except Exception as e:  # parsing a single bad file shouldn't kill the run
            errors.append({"sample": image_path.stem, "stage": "gt_load", "error": str(e)})
            continue
        # Semantic-only mode: drop line-primitive nodes from the GT and
        # re-wire the remaining semantic nodes via the original pipe
        # connectivity (BFS through primitives). This matches the format
        # the VLM is instructed to emit — one direct edge per physical
        # pipeline, regardless of how many junctions it passes through.
        if args.semantic_only:
            gt_graph = collapse_through_primitives(gt_graph, SEMANTIC_EQUIPMENT_TYPES)
        if args.dry_run:
            # No VLM call — record only the GT summary (loader debugging).
            per_sample.append(
                {
                    "sample": image_path.stem,
                    "gt_summary": summarize(gt_graph),
                }
            )
            continue
        try:
            # Tiling is enabled when either dimension exceeds 1; the tiled
            # path returns a plain dict, the whole-image path an object
            # converted via to_dict() — both end as `pred_dict`.
            if args.tile_rows > 1 or args.tile_cols > 1:
                pred_dict = extract_graph_tiled(
                    image_path,
                    client=client,
                    model=args.model,
                    max_tokens=args.max_tokens,
                    rows=args.tile_rows,
                    cols=args.tile_cols,
                    overlap=args.tile_overlap,
                    dedup_px=args.dedup_px,
                    seam_filter=not args.no_seam_filter,
                    seam_threshold=args.seam_threshold_px,
                    edge_threshold=args.edge_threshold_px,
                )
            else:
                pred = extract_graph(
                    image_path,
                    client=client,
                    model=args.model,
                    max_tokens=args.max_tokens,
                )
                pred_dict = pred.to_dict()
        except Exception as e:
            # A failed VLM call is recorded and skipped, not fatal.
            errors.append({"sample": image_path.stem, "stage": "vlm", "error": str(e)})
            continue
        if args.semantic_only:
            # NOTE(review): the `directed` fallback here is False, but the
            # evaluation fallback below uses True — confirm which default
            # the GT files actually require.
            pred_dict = filter_by_types(
                {**pred_dict, "directed": gt_graph.get("directed", False)},
                SEMANTIC_EQUIPMENT_TYPES,
            )
        # Default to whatever the GT file says; allow CLI override.
        if args.force_undirected:
            directed = False
        elif args.force_directed:
            directed = True
        else:
            directed = gt_graph.get("directed", True)
        metrics = evaluate(
            pred_dict,
            gt_graph,
            directed=directed,
            match_threshold=args.match_threshold,
        )
        per_sample.append(
            {
                "sample": image_path.stem,
                "metrics": metrics,
                "pred": pred_dict,
                "gt_summary": summarize(gt_graph),
            }
        )
    # Echo the full effective configuration into the report so results
    # stay interpretable after the fact.
    result: dict = {
        "config": {
            "model": args.model,
            "max_tokens": args.max_tokens,
            "directed": (
                False if args.force_undirected
                else True if args.force_directed
                else "auto"
            ),
            "semantic_only": args.semantic_only,
            "match_threshold": args.match_threshold,
            "tile_rows": args.tile_rows,
            "tile_cols": args.tile_cols,
            "tile_overlap": args.tile_overlap,
            "dedup_px": args.dedup_px,
            "seam_filter": not args.no_seam_filter,
            "seam_threshold_px": args.seam_threshold_px,
            "edge_threshold_px": args.edge_threshold_px,
            "limit": args.limit,
            "dry_run": args.dry_run,
        },
        "per_sample": per_sample,
        "errors": errors,
    }
    # Dry-run samples carry no "metrics" key, hence the guard on both the
    # flag and the filtered comprehension.
    if not args.dry_run and per_sample:
        result["aggregate"] = aggregate([s["metrics"] for s in per_sample if "metrics" in s])
    output.parent.mkdir(parents=True, exist_ok=True)
    output.write_text(json.dumps(result, indent=2, ensure_ascii=False))
    print(f"wrote {output}")
    # Console summary of micro-averaged node/edge scores, when computed.
    if "aggregate" in result:
        agg = result["aggregate"]
        print(
            f"  nodes F1={agg['nodes_micro']['f1']:.3f} "
            f"P={agg['nodes_micro']['precision']:.3f} "
            f"R={agg['nodes_micro']['recall']:.3f}"
        )
        print(
            f"  edges F1={agg['edges_micro']['f1']:.3f} "
            f"P={agg['edges_micro']['precision']:.3f} "
            f"R={agg['edges_micro']['recall']:.3f}"
        )
    if errors:
        print(f"  {len(errors)} error(s) — see `errors` in output JSON")
    # Per-sample failures are reported in the JSON, not via exit code.
    return 0
def main() -> int:
    """Build the CLI argument parser, parse argv, and run the evaluation.

    Returns the exit code produced by :func:`run`.
    """
    parser = argparse.ArgumentParser(description="PID2Graph VLM evaluation skeleton")

    # --- input / output -------------------------------------------------
    parser.add_argument("--image-dir", required=True, help="Directory containing P&ID images")
    parser.add_argument("--gt-dir", required=True, help="Directory containing graphml ground truth")
    parser.add_argument("--output", default="results.json", help="Where to write the JSON report")

    # --- model ----------------------------------------------------------
    parser.add_argument(
        "--model", default=DEFAULT_MODEL, help=f"Claude model id (default: {DEFAULT_MODEL})"
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=DEFAULT_MAX_TOKENS,
        help=f"VLM max output tokens (default: {DEFAULT_MAX_TOKENS}, streamed)",
    )

    # --- evaluation scope -----------------------------------------------
    parser.add_argument(
        "--limit", type=int, default=0, help="Only process the first N samples (0 = all)"
    )
    parser.add_argument(
        "--match-threshold", type=float, default=0.5, help="Node similarity threshold"
    )
    parser.add_argument(
        "--semantic-only",
        action="store_true",
        help=(
            "Restrict both prediction and GT to semantic equipment "
            "(valve, pump, tank, instrumentation, inlet/outlet); "
            "drops pipe-primitive nodes like connector/crossing/arrow."
        ),
    )

    # --- tiling ---------------------------------------------------------
    parser.add_argument(
        "--tile-rows",
        type=int,
        default=1,
        help="Tile the image into this many rows before VLM extraction (default 1 = off)",
    )
    parser.add_argument(
        "--tile-cols",
        type=int,
        default=1,
        help="Tile the image into this many columns before VLM extraction (default 1 = off)",
    )
    parser.add_argument(
        "--tile-overlap",
        type=float,
        default=0.1,
        help="Fractional overlap between adjacent tiles (default 0.1 = 10%%)",
    )
    parser.add_argument(
        "--dedup-px",
        type=float,
        default=40.0,
        help="Bbox-center distance (pixels) under which two same-type nodes are merged",
    )
    parser.add_argument(
        "--no-seam-filter",
        action="store_true",
        help=(
            "Disable the inlet/outlet tile-seam FP filter. By default, "
            "tiled extraction drops inlet/outlet nodes whose bbox center "
            "sits within 50px of an inner tile seam and is not within "
            "30px of the outer image border."
        ),
    )
    parser.add_argument(
        "--seam-threshold-px",
        type=float,
        default=50.0,
        help="Distance (px) from an inner tile seam that triggers FP filtering",
    )
    parser.add_argument(
        "--edge-threshold-px",
        type=float,
        default=30.0,
        help="Distance (px) from the outer image edge that exempts a node from filtering",
    )

    # --- edge-direction handling and debugging --------------------------
    parser.add_argument(
        "--force-undirected",
        action="store_true",
        help="Force undirected edge matching (default: use whatever the GT file says)",
    )
    parser.add_argument(
        "--force-directed",
        action="store_true",
        help="Force directed edge matching (default: use whatever the GT file says)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Skip VLM calls; just load GT and print summaries (for loader debugging)",
    )

    return run(parser.parse_args())
if __name__ == "__main__":
    # Propagate main()'s integer exit code to the shell.
    raise SystemExit(main())