"""Claude VLM-based P&ID graph extraction. Uses the Anthropic SDK with: - Vision input (base64-encoded P&ID image) - Streaming (required for high `max_tokens` to avoid HTTP timeouts) - Structured outputs via `output_config.format={"type": "json_schema", ...}` so the model is forced to return our Pydantic schema - Prompt caching on the (stable) system prompt so repeated evaluation runs only pay full price once per 5-minute window Public API: extract_graph(image_path, client=None, model="claude-opus-4-6") -> GraphOut extract_graph_tiled(image_path, ..., rows=2, cols=2) -> dict The tiled variant splits the image into overlapping tiles, runs the VLM on each, and merges the results using bbox-distance deduplication. Used to break the ~50-symbol recall ceiling on large diagrams. Note on scope (semantic-only mode): This prompt intentionally excludes PID2Graph's line-level primitives (connector, crossing, arrow, background, general). It targets only the five semantic equipment categories the VLM can recognize. The matching CLI flag `--semantic-only` filters the ground truth to the same five categories so P/R/F1 are comparable. """ from __future__ import annotations import base64 import json from pathlib import Path from typing import Optional import anthropic from .schema import GraphOut from .tile import ( Tile, filter_seam_artifacts, merge_tile_graphs, split_image, tile_to_base64_png, ) DEFAULT_MODEL = "claude-opus-4-6" DEFAULT_MAX_TOKENS = 64000 # streaming is required at this size DEFAULT_TEMPERATURE = 0.0 # deterministic sampling — kills run-to-run variance SYSTEM_PROMPT = """\ You are an expert annotator for the PID2Graph benchmark, which turns Piping and Instrumentation Diagrams (P&IDs) into node/edge graphs. Your task: given a P&ID image (which may be a full diagram or a cropped region of one), emit every SEMANTIC EQUIPMENT symbol visible as a node, and every pipeline connection between two such symbols as an edge. Line-level primitives (pipe junctions, crossings, arrowheads, background regions) must NOT be emitted. NODE CATEGORIES — use EXACTLY one of these 5 lowercase labels for `type`: - valve all valve bodies: gate, ball, check, control, globe, butterfly, needle, etc. Each valve body is one node. - pump pumps and compressors (rotating equipment) - tank storage tanks, drums, vessels, columns, reactors, heat exchangers, any large process equipment - instrumentation circle/balloon instrument bubbles (PI, TI, FIC, TC, LT, PDT, ...). Each bubble is one node. - inlet/outlet diagram boundary terminals where a pipe enters or leaves the drawing. Use the literal string with a slash: "inlet/outlet". Do not invent other type names. Do not use synonyms. Do not use title case. Do not emit connector, crossing, arrow, background, or general — those are excluded from this task. If a drawn glyph does not clearly fit one of the five categories above, skip it. PER-NODE FIELDS: id unique id within THIS response, e.g. "n1", "n2", "n3", ... type one of the 5 labels above label leave null — PID2Graph ground truth does not store printed tags bbox REQUIRED. An object {xmin, ymin, xmax, ymax} with each coordinate normalized to [0.0, 1.0] relative to the image you are shown. (0, 0) is the top-left corner, (1, 1) is the bottom-right. Give your best estimate of the symbol's tight bounding box. If you cannot estimate it, still emit a best-effort box — never omit the field and never return null. EDGES — edges in PID2Graph are UNDIRECTED. Emit an edge whenever two of the semantic nodes above are joined by a continuous pipeline, even if that pipeline passes through several pipe junctions, crossings, or arrowheads on its way. Those intermediate points must not appear as nodes; collapse the whole physical pipeline into a single direct edge between its two semantic endpoints. PER-EDGE FIELDS: source / target node ids from the `nodes` list type use "solid" label null Guidelines: * Be exhaustive within the semantic scope. * Ignore the title block, legend, border, and revision history. * If the image is a cropped region (e.g. one tile of a larger diagram), only emit symbols and edges visible within this crop. Do not infer connections that continue off-crop — another tile will cover them. * Return ONLY the JSON object matching the schema — no prose, no markdown. """ USER_INSTRUCTION = ( "Extract the semantic-equipment graph from this P&ID image as JSON. " "Only the 5 categories listed in the instructions — no pipe primitives. " "Include a bbox (normalized [0, 1]) for every node." ) def _guess_media_type(path: Path) -> str: ext = path.suffix.lower() return { ".png": "image/png", ".jpg": "image/jpeg", ".jpeg": "image/jpeg", ".gif": "image/gif", ".webp": "image/webp", }.get(ext, "image/png") def _load_image_b64(path: Path) -> tuple[str, str]: data = base64.standard_b64encode(path.read_bytes()).decode("utf-8") return data, _guess_media_type(path) def _graphout_schema() -> dict: """JSON schema for GraphOut, ready for `output_config.format`.""" return GraphOut.model_json_schema() def _call_vlm( image_data: str, media_type: str, client: anthropic.Anthropic, model: str, max_tokens: int, tag: str, temperature: float = DEFAULT_TEMPERATURE, ) -> GraphOut: """Shared VLM request path used by both whole-image and tile extraction. `tag` is a short human-readable identifier (filename, tile name) used only to make error messages point at the right sample. `temperature=0.0` by default for reproducible evaluation runs. """ schema = _graphout_schema() with client.messages.stream( model=model, max_tokens=max_tokens, temperature=temperature, system=[ { "type": "text", "text": SYSTEM_PROMPT, "cache_control": {"type": "ephemeral"}, } ], messages=[ { "role": "user", "content": [ { "type": "image", "source": { "type": "base64", "media_type": media_type, "data": image_data, }, }, {"type": "text", "text": USER_INSTRUCTION}, ], } ], output_config={ "format": { "type": "json_schema", "schema": schema, } }, ) as stream: final = stream.get_final_message() if final.stop_reason == "refusal": raise RuntimeError(f"VLM refused to answer for {tag} (stop_reason=refusal)") if final.stop_reason == "max_tokens": raise RuntimeError( f"VLM hit max_tokens={max_tokens} for {tag} — " f"output truncated; raise max_tokens or simplify the prompt" ) text_block = next((b for b in final.content if b.type == "text"), None) if text_block is None: raise RuntimeError( f"VLM emitted no text block for {tag} (stop_reason={final.stop_reason})" ) try: data = json.loads(text_block.text) except json.JSONDecodeError as e: raise RuntimeError( f"VLM output is not valid JSON for {tag}: {e}\n" f"first 500 chars: {text_block.text[:500]}" ) return GraphOut.model_validate(data) def extract_graph( image_path: Path, client: Optional[anthropic.Anthropic] = None, model: str = DEFAULT_MODEL, max_tokens: int = DEFAULT_MAX_TOKENS, temperature: float = DEFAULT_TEMPERATURE, ) -> GraphOut: """Run the VLM on the whole image and return the parsed graph.""" client = client or anthropic.Anthropic() image_path = Path(image_path) image_data, media_type = _load_image_b64(image_path) return _call_vlm( image_data, media_type, client, model, max_tokens, image_path.name, temperature=temperature, ) def extract_graph_tiled( image_path: Path, client: Optional[anthropic.Anthropic] = None, model: str = DEFAULT_MODEL, max_tokens: int = DEFAULT_MAX_TOKENS, temperature: float = DEFAULT_TEMPERATURE, rows: int = 2, cols: int = 2, overlap: float = 0.1, dedup_px: float = 40.0, seam_filter: bool = True, seam_filter_types: tuple[str, ...] = ("inlet/outlet",), seam_threshold: float = 50.0, edge_threshold: float = 30.0, ) -> dict: """Tile the image, extract each tile, merge via bbox-distance dedup. Returns a plain dict (not a GraphOut) because merged nodes carry extra `bbox` (in global pixel coordinates) and `source_tiles` fields that don't fit the response schema. The dict shape is still compatible with `metrics.evaluate()`. Deduplication rules: * Two nodes from different tiles are merged iff they share the same `type` AND their bbox centers are within `dedup_px` pixels in the un-tiled global image. * Nodes with a null bbox are dropped (they can't be deduped). * Edges are remapped through the merge map; undirected duplicates are collapsed; self-loops (both endpoints merged to the same global node) are dropped. """ client = client or anthropic.Anthropic() image_path = Path(image_path) tiles: list[Tile] = split_image(image_path, rows=rows, cols=cols, overlap=overlap) per_tile: list[tuple[GraphOut, Tile]] = [] for tile in tiles: image_data, media_type = tile_to_base64_png(tile) tag = f"{image_path.name}:{tile.name}" graph = _call_vlm( image_data, media_type, client, model, max_tokens, tag, temperature=temperature, ) per_tile.append((graph, tile)) merged = merge_tile_graphs(per_tile, dedup_px=dedup_px) if seam_filter: merged = filter_seam_artifacts( merged, tiles, types=seam_filter_types, seam_threshold=seam_threshold, edge_threshold=edge_threshold, ) # Attach provenance for downstream debugging / aggregation. merged["tile_stats"] = [ { "tile": tile.name, "n_nodes": len(graph.nodes), "n_edges": len(graph.edges), "nodes_with_bbox": sum(1 for n in graph.nodes if n.bbox is not None), } for graph, tile in per_tile ] return merged