# Initial commit: PID2Graph × Claude VLM evaluation + Gradio demo (59fa244)
"""CLI entrypoint for the PID2Graph evaluation skeleton.
Usage:
python -m pid2graph_eval.cli \
--image-dir path/to/images \
--gt-dir path/to/graphml \
--output results.json \
--limit 10
The loader pairs each image with the graphml of the same stem
(`A-001.png` ↔ `A-001.graphml`). Adjust `pair_samples` once you know the
actual PID2Graph on-disk layout.
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
import anthropic
from tqdm import tqdm
from .extractor import (
DEFAULT_MAX_TOKENS,
DEFAULT_MODEL,
extract_graph,
extract_graph_tiled,
)
from .gt_loader import (
SEMANTIC_EQUIPMENT_TYPES,
collapse_through_primitives,
filter_by_types,
load_graphml,
summarize,
)
from .metrics import aggregate, evaluate
# Recognized file extensions (lowercase; matching is case-insensitive).
IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".webp")
GT_EXTS = (".graphml", ".xml")


def pair_samples(image_dir: Path, gt_dir: Path) -> list[tuple[Path, Path]]:
    """Match each image to a ground-truth file by filename stem.

    Both directories are searched recursively. Matching is case-insensitive
    on the file extension, the result is globally sorted by image path, and
    each stem appears at most once (the first image in sort order wins), so
    a drawing stored as both ``a.png`` and ``a.jpg`` is not evaluated twice.
    When a stem has both a ``.graphml`` and an ``.xml`` ground truth, the
    extension listed first in ``GT_EXTS`` (``.graphml``) is preferred.

    Returns a list of ``(image_path, gt_path)`` tuples; images without a
    matching ground-truth stem are silently skipped.
    """
    # Index GT files by stem; setdefault makes the first extension in
    # GT_EXTS win instead of letting later ones silently overwrite.
    gt_by_stem: dict[str, Path] = {}
    for ext in GT_EXTS:
        for p in sorted(gt_dir.rglob("*")):
            if p.is_file() and p.suffix.lower() == ext:
                gt_by_stem.setdefault(p.stem, p)

    images = sorted(
        p
        for p in image_dir.rglob("*")
        if p.is_file() and p.suffix.lower() in IMAGE_EXTS
    )

    pairs: list[tuple[Path, Path]] = []
    seen_stems: set[str] = set()
    for img in images:
        gt = gt_by_stem.get(img.stem)
        if gt is not None and img.stem not in seen_stems:
            seen_stems.add(img.stem)
            pairs.append((img, gt))
    return pairs
def run(args: argparse.Namespace) -> int:
    """Execute the evaluation loop described by the parsed CLI *args*.

    Pairs images with ground-truth graphs, optionally calls the Claude VLM
    extractor per sample, scores predictions, and writes a JSON report to
    ``args.output``.

    Returns a process exit code: 0 on success, 1 when no (image, graphml)
    pairs are found, 2 when an input directory is missing.
    """
    image_dir = Path(args.image_dir)
    gt_dir = Path(args.gt_dir)
    output = Path(args.output)
    # Fail fast on bad paths before creating the API client.
    if not image_dir.exists():
        print(f"error: --image-dir does not exist: {image_dir}", file=sys.stderr)
        return 2
    if not gt_dir.exists():
        print(f"error: --gt-dir does not exist: {gt_dir}", file=sys.stderr)
        return 2
    pairs = pair_samples(image_dir, gt_dir)
    if args.limit:
        pairs = pairs[: args.limit]
    if not pairs:
        print("error: no (image, graphml) pairs found — check stems match", file=sys.stderr)
        return 1
    print(f"found {len(pairs)} sample pair(s)")
    # Credentials are read from the environment by the anthropic SDK.
    client = anthropic.Anthropic()
    per_sample: list[dict] = []
    errors: list[dict] = []
    for image_path, gt_path in tqdm(pairs, desc="eval"):
        try:
            gt_graph = load_graphml(gt_path)
        except Exception as e:  # parsing a single bad file shouldn't kill the run
            errors.append({"sample": image_path.stem, "stage": "gt_load", "error": str(e)})
            continue
        # Semantic-only mode: drop line-primitive nodes from the GT and
        # re-wire the remaining semantic nodes via the original pipe
        # connectivity (BFS through primitives). This matches the format
        # the VLM is instructed to emit — one direct edge per physical
        # pipeline, regardless of how many junctions it passes through.
        if args.semantic_only:
            gt_graph = collapse_through_primitives(gt_graph, SEMANTIC_EQUIPMENT_TYPES)
        if args.dry_run:
            # Dry run: record only the GT summary (loader debugging), no VLM call.
            per_sample.append(
                {
                    "sample": image_path.stem,
                    "gt_summary": summarize(gt_graph),
                }
            )
            continue
        try:
            if args.tile_rows > 1 or args.tile_cols > 1:
                # Tiled mode: the extractor splits the drawing into a grid
                # and is given the merge/dedup/seam-filter knobs from the CLI.
                pred_dict = extract_graph_tiled(
                    image_path,
                    client=client,
                    model=args.model,
                    max_tokens=args.max_tokens,
                    rows=args.tile_rows,
                    cols=args.tile_cols,
                    overlap=args.tile_overlap,
                    dedup_px=args.dedup_px,
                    seam_filter=not args.no_seam_filter,
                    seam_threshold=args.seam_threshold_px,
                    edge_threshold=args.edge_threshold_px,
                )
            else:
                pred = extract_graph(
                    image_path,
                    client=client,
                    model=args.model,
                    max_tokens=args.max_tokens,
                )
                pred_dict = pred.to_dict()
        except Exception as e:
            # VLM/API failures are recorded per sample; the run continues.
            errors.append({"sample": image_path.stem, "stage": "vlm", "error": str(e)})
            continue
        if args.semantic_only:
            # Apply the same type filter to the prediction as to the GT.
            # NOTE(review): the "directed" default here is False but the
            # evaluation default below is True — confirm this asymmetry is
            # intentional in filter_by_types.
            pred_dict = filter_by_types(
                {**pred_dict, "directed": gt_graph.get("directed", False)},
                SEMANTIC_EQUIPMENT_TYPES,
            )
        # Default to whatever the GT file says; allow CLI override.
        if args.force_undirected:
            directed = False
        elif args.force_directed:
            directed = True
        else:
            directed = gt_graph.get("directed", True)
        metrics = evaluate(
            pred_dict,
            gt_graph,
            directed=directed,
            match_threshold=args.match_threshold,
        )
        per_sample.append(
            {
                "sample": image_path.stem,
                "metrics": metrics,
                "pred": pred_dict,
                "gt_summary": summarize(gt_graph),
            }
        )
    # Echo the effective configuration so the report is self-describing.
    result: dict = {
        "config": {
            "model": args.model,
            "max_tokens": args.max_tokens,
            "directed": (
                False if args.force_undirected
                else True if args.force_directed
                else "auto"
            ),
            "semantic_only": args.semantic_only,
            "match_threshold": args.match_threshold,
            "tile_rows": args.tile_rows,
            "tile_cols": args.tile_cols,
            "tile_overlap": args.tile_overlap,
            "dedup_px": args.dedup_px,
            "seam_filter": not args.no_seam_filter,
            "seam_threshold_px": args.seam_threshold_px,
            "edge_threshold_px": args.edge_threshold_px,
            "limit": args.limit,
            "dry_run": args.dry_run,
        },
        "per_sample": per_sample,
        "errors": errors,
    }
    # Dry-run samples carry no "metrics" key, hence the guard on both sides.
    if not args.dry_run and per_sample:
        result["aggregate"] = aggregate([s["metrics"] for s in per_sample if "metrics" in s])
    output.parent.mkdir(parents=True, exist_ok=True)
    output.write_text(json.dumps(result, indent=2, ensure_ascii=False))
    print(f"wrote {output}")
    # Console summary of the micro-averaged node/edge scores.
    if "aggregate" in result:
        agg = result["aggregate"]
        print(
            f"  nodes F1={agg['nodes_micro']['f1']:.3f} "
            f"P={agg['nodes_micro']['precision']:.3f} "
            f"R={agg['nodes_micro']['recall']:.3f}"
        )
        print(
            f"  edges F1={agg['edges_micro']['f1']:.3f} "
            f"P={agg['edges_micro']['precision']:.3f} "
            f"R={agg['edges_micro']['recall']:.3f}"
        )
    if errors:
        print(f"  {len(errors)} error(s) — see `errors` in output JSON")
    return 0
def main() -> int:
    """Build the command-line parser and hand the parsed options to run()."""
    parser = argparse.ArgumentParser(description="PID2Graph VLM evaluation skeleton")

    # Required inputs and report destination.
    parser.add_argument("--image-dir", required=True, help="Directory containing P&ID images")
    parser.add_argument("--gt-dir", required=True, help="Directory containing graphml ground truth")
    parser.add_argument("--output", default="results.json", help="Where to write the JSON report")

    # Model selection and generation budget.
    parser.add_argument("--model", default=DEFAULT_MODEL, help=f"Claude model id (default: {DEFAULT_MODEL})")
    parser.add_argument("--max-tokens", type=int, default=DEFAULT_MAX_TOKENS,
                        help=f"VLM max output tokens (default: {DEFAULT_MAX_TOKENS}, streamed)")

    # Evaluation scope and matching.
    parser.add_argument("--limit", type=int, default=0, help="Only process the first N samples (0 = all)")
    parser.add_argument("--match-threshold", type=float, default=0.5, help="Node similarity threshold")
    parser.add_argument("--semantic-only", action="store_true",
                        help=("Restrict both prediction and GT to semantic "
                              "equipment (valve, pump, tank, instrumentation, "
                              "inlet/outlet); drops pipe-primitive nodes like "
                              "connector/crossing/arrow."))

    # Tiled extraction geometry.
    parser.add_argument("--tile-rows", type=int, default=1,
                        help="Tile the image into this many rows before VLM extraction (default 1 = off)")
    parser.add_argument("--tile-cols", type=int, default=1,
                        help="Tile the image into this many columns before VLM extraction (default 1 = off)")
    parser.add_argument("--tile-overlap", type=float, default=0.1,
                        help="Fractional overlap between adjacent tiles (default 0.1 = 10%%)")
    parser.add_argument("--dedup-px", type=float, default=40.0,
                        help="Bbox-center distance (pixels) under which two same-type nodes are merged")

    # Tile-seam false-positive filtering.
    parser.add_argument("--no-seam-filter", action="store_true",
                        help=("Disable the inlet/outlet tile-seam FP filter. "
                              "By default, tiled extraction drops inlet/outlet "
                              "nodes whose bbox center sits within 50px of an "
                              "inner tile seam and is not within 30px of the "
                              "outer image border."))
    parser.add_argument("--seam-threshold-px", type=float, default=50.0,
                        help="Distance (px) from an inner tile seam that triggers FP filtering")
    parser.add_argument("--edge-threshold-px", type=float, default=30.0,
                        help="Distance (px) from the outer image edge that exempts a node from filtering")

    # Edge-direction handling overrides.
    parser.add_argument("--force-undirected", action="store_true",
                        help="Force undirected edge matching (default: use whatever the GT file says)")
    parser.add_argument("--force-directed", action="store_true",
                        help="Force directed edge matching (default: use whatever the GT file says)")

    # Debugging aid.
    parser.add_argument("--dry-run", action="store_true",
                        help="Skip VLM calls; just load GT and print summaries (for loader debugging)")

    return run(parser.parse_args())
# Script entry point: propagate main()'s return value as the exit code.
if __name__ == "__main__":
    raise SystemExit(main())