# Initial commit: PID2Graph × Claude VLM evaluation + Gradio demo (59fa244)
"""Claude VLM-based P&ID graph extraction.
Uses the Anthropic SDK with:
- Vision input (base64-encoded P&ID image)
- Streaming (required for high `max_tokens` to avoid HTTP timeouts)
- Structured outputs via `output_config.format={"type": "json_schema", ...}`
so the model is forced to return our Pydantic schema
- Prompt caching on the (stable) system prompt so repeated evaluation runs
only pay full price once per 5-minute window
Public API:
extract_graph(image_path, client=None, model="claude-opus-4-6") -> GraphOut
extract_graph_tiled(image_path, ..., rows=2, cols=2) -> dict
The tiled variant splits the image into overlapping tiles, runs the VLM
on each, and merges the results using bbox-distance deduplication. Used
to break the ~50-symbol recall ceiling on large diagrams.
Note on scope (semantic-only mode):
This prompt intentionally excludes PID2Graph's line-level primitives
(connector, crossing, arrow, background, general). It targets only
the five semantic equipment categories the VLM can recognize. The
matching CLI flag `--semantic-only` filters the ground truth to the
same five categories so P/R/F1 are comparable.
"""
from __future__ import annotations
import base64
import json
from pathlib import Path
from typing import Optional
import anthropic
from .schema import GraphOut
from .tile import (
Tile,
filter_seam_artifacts,
merge_tile_graphs,
split_image,
tile_to_base64_png,
)
DEFAULT_MODEL = "claude-opus-4-6"
DEFAULT_MAX_TOKENS = 64000  # streaming is required at this size
DEFAULT_TEMPERATURE = 0.0  # deterministic sampling — kills run-to-run variance

# Stable system prompt; `_call_vlm` marks it with `cache_control` so repeated
# evaluation runs within the cache window only pay full input price once.
# (Mojibake "β€”" sequences — UTF-8 em-dashes mis-decoded as cp1252 — have been
# repaired so the model is not shown garbled bytes.)
SYSTEM_PROMPT = """\
You are an expert annotator for the PID2Graph benchmark, which turns
Piping and Instrumentation Diagrams (P&IDs) into node/edge graphs.
Your task: given a P&ID image (which may be a full diagram or a cropped
region of one), emit every SEMANTIC EQUIPMENT symbol visible as a node,
and every pipeline connection between two such symbols as an edge.
Line-level primitives (pipe junctions, crossings, arrowheads, background
regions) must NOT be emitted.
NODE CATEGORIES — use EXACTLY one of these 5 lowercase labels for `type`:
- valve all valve bodies: gate, ball, check, control, globe,
butterfly, needle, etc. Each valve body is one node.
- pump pumps and compressors (rotating equipment)
- tank storage tanks, drums, vessels, columns, reactors,
heat exchangers, any large process equipment
- instrumentation circle/balloon instrument bubbles (PI, TI, FIC, TC,
LT, PDT, ...). Each bubble is one node.
- inlet/outlet diagram boundary terminals where a pipe enters or
leaves the drawing. Use the literal string with a
slash: "inlet/outlet".
Do not invent other type names. Do not use synonyms. Do not use title case.
Do not emit connector, crossing, arrow, background, or general — those
are excluded from this task. If a drawn glyph does not clearly fit one of
the five categories above, skip it.
PER-NODE FIELDS:
id unique id within THIS response, e.g. "n1", "n2", "n3", ...
type one of the 5 labels above
label leave null — PID2Graph ground truth does not store printed tags
bbox REQUIRED. An object {xmin, ymin, xmax, ymax} with each coordinate
normalized to [0.0, 1.0] relative to the image you are shown.
(0, 0) is the top-left corner, (1, 1) is the bottom-right. Give
your best estimate of the symbol's tight bounding box. If you
cannot estimate it, still emit a best-effort box — never omit
the field and never return null.
EDGES — edges in PID2Graph are UNDIRECTED. Emit an edge whenever two of
the semantic nodes above are joined by a continuous pipeline, even if
that pipeline passes through several pipe junctions, crossings, or
arrowheads on its way. Those intermediate points must not appear as
nodes; collapse the whole physical pipeline into a single direct edge
between its two semantic endpoints.
PER-EDGE FIELDS:
source / target node ids from the `nodes` list
type use "solid"
label null
Guidelines:
* Be exhaustive within the semantic scope.
* Ignore the title block, legend, border, and revision history.
* If the image is a cropped region (e.g. one tile of a larger diagram),
only emit symbols and edges visible within this crop. Do not infer
connections that continue off-crop — another tile will cover them.
* Return ONLY the JSON object matching the schema — no prose, no markdown.
"""

# Short per-request user turn; the heavy instructions live in SYSTEM_PROMPT
# so they can be prompt-cached.
USER_INSTRUCTION = (
    "Extract the semantic-equipment graph from this P&ID image as JSON. "
    "Only the 5 categories listed in the instructions — no pipe primitives. "
    "Include a bbox (normalized [0, 1]) for every node."
)
def _guess_media_type(path: Path) -> str:
ext = path.suffix.lower()
return {
".png": "image/png",
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".gif": "image/gif",
".webp": "image/webp",
}.get(ext, "image/png")
def _load_image_b64(path: Path) -> tuple[str, str]:
    """Read *path* and return (base64-encoded contents, guessed MIME type)."""
    raw = path.read_bytes()
    encoded = base64.standard_b64encode(raw).decode("utf-8")
    return encoded, _guess_media_type(path)
def _graphout_schema() -> dict:
    """Build the JSON schema for GraphOut, ready for `output_config.format`."""
    schema = GraphOut.model_json_schema()
    return schema
def _call_vlm(
    image_data: str,
    media_type: str,
    client: anthropic.Anthropic,
    model: str,
    max_tokens: int,
    tag: str,
    temperature: float = DEFAULT_TEMPERATURE,
) -> GraphOut:
    """Shared VLM request path used by both whole-image and tile extraction.

    Args:
        image_data: base64-encoded image bytes.
        media_type: MIME type of the encoded image (e.g. "image/png").
        client: configured Anthropic client.
        model: model id to query.
        max_tokens: response token budget; streaming is used so large
            budgets do not hit HTTP timeouts.
        tag: short human-readable identifier (filename, tile name) used
            only to make error messages point at the right sample.
        temperature: 0.0 by default for reproducible evaluation runs.

    Returns:
        The response parsed and validated as a GraphOut.

    Raises:
        RuntimeError: if the model refuses, truncates at max_tokens,
            returns no text block, or returns invalid JSON.
    """
    schema = _graphout_schema()
    with client.messages.stream(
        model=model,
        max_tokens=max_tokens,
        temperature=temperature,
        # Cache the stable system prompt so repeated evaluation runs within
        # the cache window only pay full input price once.
        system=[
            {
                "type": "text",
                "text": SYSTEM_PROMPT,
                "cache_control": {"type": "ephemeral"},
            }
        ],
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": media_type,
                            "data": image_data,
                        },
                    },
                    {"type": "text", "text": USER_INSTRUCTION},
                ],
            }
        ],
        # Structured output: force the response to match the GraphOut schema.
        output_config={
            "format": {
                "type": "json_schema",
                "schema": schema,
            }
        },
    ) as stream:
        final = stream.get_final_message()
    # Surface model-side failure modes with the sample tag so a failing
    # evaluation run points at the right image/tile.
    if final.stop_reason == "refusal":
        raise RuntimeError(f"VLM refused to answer for {tag} (stop_reason=refusal)")
    if final.stop_reason == "max_tokens":
        raise RuntimeError(
            f"VLM hit max_tokens={max_tokens} for {tag} — "
            f"output truncated; raise max_tokens or simplify the prompt"
        )
    text_block = next((b for b in final.content if b.type == "text"), None)
    if text_block is None:
        raise RuntimeError(
            f"VLM emitted no text block for {tag} (stop_reason={final.stop_reason})"
        )
    try:
        data = json.loads(text_block.text)
    except json.JSONDecodeError as e:
        # Chain the decode error (`from e`) so the original parse traceback
        # is preserved for debugging instead of being swallowed.
        raise RuntimeError(
            f"VLM output is not valid JSON for {tag}: {e}\n"
            f"first 500 chars: {text_block.text[:500]}"
        ) from e
    return GraphOut.model_validate(data)
def extract_graph(
    image_path: Path,
    client: Optional[anthropic.Anthropic] = None,
    model: str = DEFAULT_MODEL,
    max_tokens: int = DEFAULT_MAX_TOKENS,
    temperature: float = DEFAULT_TEMPERATURE,
) -> GraphOut:
    """Extract the full-image P&ID graph via a single VLM call.

    A fresh Anthropic client is created when none is supplied.
    """
    client = client or anthropic.Anthropic()
    path = Path(image_path)
    encoded, media_type = _load_image_b64(path)
    return _call_vlm(
        encoded,
        media_type,
        client,
        model,
        max_tokens,
        path.name,
        temperature=temperature,
    )
def extract_graph_tiled(
    image_path: Path,
    client: Optional[anthropic.Anthropic] = None,
    model: str = DEFAULT_MODEL,
    max_tokens: int = DEFAULT_MAX_TOKENS,
    temperature: float = DEFAULT_TEMPERATURE,
    rows: int = 2,
    cols: int = 2,
    overlap: float = 0.1,
    dedup_px: float = 40.0,
    seam_filter: bool = True,
    seam_filter_types: tuple[str, ...] = ("inlet/outlet",),
    seam_threshold: float = 50.0,
    edge_threshold: float = 30.0,
) -> dict:
    """Extract a graph by tiling the image, querying each tile, and merging.

    The return value is a plain dict rather than a GraphOut: merged nodes
    carry extra `bbox` (global pixel coordinates) and `source_tiles` keys
    that the response schema does not allow. The shape remains compatible
    with `metrics.evaluate()`.

    Deduplication rules (implemented in `merge_tile_graphs`):
      * Nodes from different tiles merge iff they share a `type` and their
        bbox centers lie within `dedup_px` pixels in global coordinates.
      * Nodes lacking a bbox are dropped (nothing to dedup on).
      * Edges are remapped through the merge map; undirected duplicates
        collapse and self-loops are discarded.

    When `seam_filter` is on, `filter_seam_artifacts` additionally removes
    nodes of `seam_filter_types` that sit within `seam_threshold` pixels
    of a tile seam (edge endpoints use `edge_threshold`).
    """
    client = client or anthropic.Anthropic()
    path = Path(image_path)

    tile_list: list[Tile] = split_image(path, rows=rows, cols=cols, overlap=overlap)

    # One VLM call per tile; keep (graph, tile) pairs for merge + stats.
    results: list[tuple[GraphOut, Tile]] = []
    for tile in tile_list:
        b64_data, mime = tile_to_base64_png(tile)
        graph = _call_vlm(
            b64_data,
            mime,
            client,
            model,
            max_tokens,
            f"{path.name}:{tile.name}",
            temperature=temperature,
        )
        results.append((graph, tile))

    combined = merge_tile_graphs(results, dedup_px=dedup_px)
    if seam_filter:
        combined = filter_seam_artifacts(
            combined,
            tile_list,
            types=seam_filter_types,
            seam_threshold=seam_threshold,
            edge_threshold=edge_threshold,
        )

    # Attach per-tile provenance for downstream debugging / aggregation.
    stats = []
    for graph, tile in results:
        stats.append(
            {
                "tile": tile.name,
                "n_nodes": len(graph.nodes),
                "n_edges": len(graph.edges),
                "nodes_with_bbox": sum(1 for n in graph.nodes if n.bbox is not None),
            }
        )
    combined["tile_stats"] = stats
    return combined