"""Claude VLM-based P&ID graph extraction.

Uses the Anthropic SDK with:

- Vision input (base64-encoded P&ID image)
- Streaming (required for high `max_tokens` to avoid HTTP timeouts)
- Structured outputs via `output_config.format={"type": "json_schema", ...}`
  so the model is forced to return our Pydantic schema
- Prompt caching on the (stable) system prompt so repeated evaluation runs
  only pay full price once per 5-minute window

Public API:

    extract_graph(image_path, client=None, model="claude-opus-4-6") -> GraphOut
    extract_graph_tiled(image_path, ..., rows=2, cols=2) -> dict

The tiled variant splits the image into overlapping tiles, runs the VLM
on each, and merges the results using bbox-distance deduplication. Used
to break the ~50-symbol recall ceiling on large diagrams.

Note on scope (semantic-only mode): this prompt intentionally excludes
PID2Graph's line-level primitives (connector, crossing, arrow, background,
general). It targets only the five semantic equipment categories the VLM
can recognize. The matching CLI flag `--semantic-only` filters the ground
truth to the same five categories so P/R/F1 are comparable.
"""
from __future__ import annotations

import base64
import json
from pathlib import Path
from typing import Optional

import anthropic

from .schema import GraphOut
from .tile import (
    Tile,
    filter_seam_artifacts,
    merge_tile_graphs,
    split_image,
    tile_to_base64_png,
)

DEFAULT_MODEL = "claude-opus-4-6"
DEFAULT_MAX_TOKENS = 64000  # streaming is required at this size
DEFAULT_TEMPERATURE = 0.0  # deterministic sampling kills run-to-run variance

SYSTEM_PROMPT = """\
You are an expert annotator for the PID2Graph benchmark, which turns
Piping and Instrumentation Diagrams (P&IDs) into node/edge graphs.
Your task: given a P&ID image (which may be a full diagram or a cropped
region of one), emit every SEMANTIC EQUIPMENT symbol visible as a node,
and every pipeline connection between two such symbols as an edge.
Line-level primitives (pipe junctions, crossings, arrowheads, background
regions) must NOT be emitted.
NODE CATEGORIES β use EXACTLY one of these 5 lowercase labels for `type`:
- valve all valve bodies: gate, ball, check, control, globe,
butterfly, needle, etc. Each valve body is one node.
- pump pumps and compressors (rotating equipment)
- tank storage tanks, drums, vessels, columns, reactors,
heat exchangers, any large process equipment
- instrumentation circle/balloon instrument bubbles (PI, TI, FIC, TC,
LT, PDT, ...). Each bubble is one node.
- inlet/outlet diagram boundary terminals where a pipe enters or
leaves the drawing. Use the literal string with a
slash: "inlet/outlet".
Do not invent other type names. Do not use synonyms. Do not use title case.
Do not emit connector, crossing, arrow, background, or general β those
are excluded from this task. If a drawn glyph does not clearly fit one of
the five categories above, skip it.
PER-NODE FIELDS:
id unique id within THIS response, e.g. "n1", "n2", "n3", ...
type one of the 5 labels above
label leave null β PID2Graph ground truth does not store printed tags
bbox REQUIRED. An object {xmin, ymin, xmax, ymax} with each coordinate
normalized to [0.0, 1.0] relative to the image you are shown.
(0, 0) is the top-left corner, (1, 1) is the bottom-right. Give
your best estimate of the symbol's tight bounding box. If you
cannot estimate it, still emit a best-effort box β never omit
the field and never return null.
EDGES β edges in PID2Graph are UNDIRECTED. Emit an edge whenever two of
the semantic nodes above are joined by a continuous pipeline, even if
that pipeline passes through several pipe junctions, crossings, or
arrowheads on its way. Those intermediate points must not appear as
nodes; collapse the whole physical pipeline into a single direct edge
between its two semantic endpoints.
PER-EDGE FIELDS:
source / target node ids from the `nodes` list
type use "solid"
label null
Guidelines:
* Be exhaustive within the semantic scope.
* Ignore the title block, legend, border, and revision history.
* If the image is a cropped region (e.g. one tile of a larger diagram),
only emit symbols and edges visible within this crop. Do not infer
connections that continue off-crop β another tile will cover them.
* Return ONLY the JSON object matching the schema β no prose, no markdown.
"""
USER_INSTRUCTION = (
    "Extract the semantic-equipment graph from this P&ID image as JSON. "
    "Only the 5 categories listed in the instructions - no pipe primitives. "
    "Include a bbox (normalized [0, 1]) for every node."
)

def _guess_media_type(path: Path) -> str:
    ext = path.suffix.lower()
    return {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".gif": "image/gif",
        ".webp": "image/webp",
    }.get(ext, "image/png")


def _load_image_b64(path: Path) -> tuple[str, str]:
    data = base64.standard_b64encode(path.read_bytes()).decode("utf-8")
    return data, _guess_media_type(path)


def _graphout_schema() -> dict:
    """JSON schema for GraphOut, ready for `output_config.format`."""
    return GraphOut.model_json_schema()

def _call_vlm(
    image_data: str,
    media_type: str,
    client: anthropic.Anthropic,
    model: str,
    max_tokens: int,
    tag: str,
    temperature: float = DEFAULT_TEMPERATURE,
) -> GraphOut:
    """Shared VLM request path used by both whole-image and tile extraction.

    `tag` is a short human-readable identifier (filename, tile name) used
    only to make error messages point at the right sample.
    `temperature=0.0` by default for reproducible evaluation runs.
    """
    schema = _graphout_schema()
    with client.messages.stream(
        model=model,
        max_tokens=max_tokens,
        temperature=temperature,
        system=[
            {
                "type": "text",
                "text": SYSTEM_PROMPT,
                "cache_control": {"type": "ephemeral"},
            }
        ],
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": media_type,
                            "data": image_data,
                        },
                    },
                    {"type": "text", "text": USER_INSTRUCTION},
                ],
            }
        ],
        output_config={
            "format": {
                "type": "json_schema",
                "schema": schema,
            }
        },
    ) as stream:
        final = stream.get_final_message()

    if final.stop_reason == "refusal":
        raise RuntimeError(f"VLM refused to answer for {tag} (stop_reason=refusal)")
    if final.stop_reason == "max_tokens":
        raise RuntimeError(
            f"VLM hit max_tokens={max_tokens} for {tag} - "
            f"output truncated; raise max_tokens or simplify the prompt"
        )

    text_block = next((b for b in final.content if b.type == "text"), None)
    if text_block is None:
        raise RuntimeError(
            f"VLM emitted no text block for {tag} (stop_reason={final.stop_reason})"
        )

    try:
        data = json.loads(text_block.text)
    except json.JSONDecodeError as e:
        raise RuntimeError(
            f"VLM output is not valid JSON for {tag}: {e}\n"
            f"first 500 chars: {text_block.text[:500]}"
        ) from e

    return GraphOut.model_validate(data)

def extract_graph(
    image_path: Path,
    client: Optional[anthropic.Anthropic] = None,
    model: str = DEFAULT_MODEL,
    max_tokens: int = DEFAULT_MAX_TOKENS,
    temperature: float = DEFAULT_TEMPERATURE,
) -> GraphOut:
    """Run the VLM on the whole image and return the parsed graph."""
    client = client or anthropic.Anthropic()
    image_path = Path(image_path)
    image_data, media_type = _load_image_b64(image_path)
    return _call_vlm(
        image_data, media_type, client, model, max_tokens,
        image_path.name, temperature=temperature,
    )

def extract_graph_tiled(
    image_path: Path,
    client: Optional[anthropic.Anthropic] = None,
    model: str = DEFAULT_MODEL,
    max_tokens: int = DEFAULT_MAX_TOKENS,
    temperature: float = DEFAULT_TEMPERATURE,
    rows: int = 2,
    cols: int = 2,
    overlap: float = 0.1,
    dedup_px: float = 40.0,
    seam_filter: bool = True,
    seam_filter_types: tuple[str, ...] = ("inlet/outlet",),
    seam_threshold: float = 50.0,
    edge_threshold: float = 30.0,
) -> dict:
    """Tile the image, extract each tile, merge via bbox-distance dedup.

    Returns a plain dict (not a GraphOut) because merged nodes carry
    extra `bbox` (in global pixel coordinates) and `source_tiles` fields
    that don't fit the response schema. The dict shape is still
    compatible with `metrics.evaluate()`.

    Deduplication rules:

    * Two nodes from different tiles are merged iff they share the
      same `type` AND their bbox centers are within `dedup_px` pixels
      in the un-tiled global image.
    * Nodes with a null bbox are dropped (they can't be deduped).
    * Edges are remapped through the merge map; undirected duplicates
      are collapsed; self-loops (both endpoints merged to the same
      global node) are dropped.
    """
    client = client or anthropic.Anthropic()
    image_path = Path(image_path)
    tiles: list[Tile] = split_image(image_path, rows=rows, cols=cols, overlap=overlap)

    per_tile: list[tuple[GraphOut, Tile]] = []
    for tile in tiles:
        image_data, media_type = tile_to_base64_png(tile)
        tag = f"{image_path.name}:{tile.name}"
        graph = _call_vlm(
            image_data, media_type, client, model, max_tokens,
            tag, temperature=temperature,
        )
        per_tile.append((graph, tile))

    merged = merge_tile_graphs(per_tile, dedup_px=dedup_px)
    if seam_filter:
        merged = filter_seam_artifacts(
            merged,
            tiles,
            types=seam_filter_types,
            seam_threshold=seam_threshold,
            edge_threshold=edge_threshold,
        )

    # Attach provenance for downstream debugging / aggregation.
    merged["tile_stats"] = [
        {
            "tile": tile.name,
            "n_nodes": len(graph.nodes),
            "n_edges": len(graph.edges),
            "nodes_with_bbox": sum(1 for n in graph.nodes if n.bbox is not None),
        }
        for graph, tile in per_tile
    ]
    return merged
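

# The bbox-distance dedup rule documented in `extract_graph_tiled` can be
# stated precisely as code. The authoritative implementation lives in
# `.tile.merge_tile_graphs`; the two helpers below are an illustrative-only
# restatement of that rule (names are hypothetical, unused elsewhere).
def _bbox_center(bbox: dict) -> tuple[float, float]:
    # Center of an {xmin, ymin, xmax, ymax} box.
    return (bbox["xmin"] + bbox["xmax"]) / 2, (bbox["ymin"] + bbox["ymax"]) / 2


def _same_symbol(a: dict, b: dict, dedup_px: float = 40.0) -> bool:
    # Two tile-local nodes merge iff they share a `type` and their bbox
    # centers are within `dedup_px` pixels in global image coordinates.
    if a["type"] != b["type"] or a.get("bbox") is None or b.get("bbox") is None:
        return False
    (ax, ay), (bx, by) = _bbox_center(a["bbox"]), _bbox_center(b["bbox"])
    return ((ax - bx) ** 2 + (ay - by) ** 2) ** 0.5 <= dedup_px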