"""Claude VLM-based P&ID graph extraction.

Uses the Anthropic SDK with:

- Vision input (base64-encoded P&ID image)
- Streaming (required for high `max_tokens` to avoid HTTP timeouts)
- Structured outputs via `output_config.format={"type": "json_schema", ...}`
  so the model is forced to return our Pydantic schema
- Prompt caching on the (stable) system prompt so repeated evaluation runs
  only pay full price once per 5-minute window

Public API:

    extract_graph(image_path, client=None, model="claude-opus-4-6") -> GraphOut
    extract_graph_tiled(image_path, ..., rows=2, cols=2) -> dict

The tiled variant splits the image into overlapping tiles, runs the VLM
on each, and merges the results using bbox-distance deduplication. Used
to break the ~50-symbol recall ceiling on large diagrams.

Note on scope (semantic-only mode):
    This prompt intentionally excludes PID2Graph's line-level primitives
    (connector, crossing, arrow, background, general). It targets only
    the five semantic equipment categories the VLM can recognize. The
    matching CLI flag `--semantic-only` filters the ground truth to the
    same five categories so P/R/F1 are comparable.
"""
| from __future__ import annotations | |
| import base64 | |
| import json | |
| from pathlib import Path | |
| from typing import Optional | |
| import anthropic | |
| from .schema import GraphOut | |
| from .tile import ( | |
| Tile, | |
| filter_seam_artifacts, | |
| merge_tile_graphs, | |
| split_image, | |
| tile_to_base64_png, | |
| ) | |
# Defaults shared by extract_graph / extract_graph_tiled.
DEFAULT_MODEL = "claude-opus-4-6"
DEFAULT_MAX_TOKENS = 64000  # streaming is required at this size
DEFAULT_TEMPERATURE = 0.0  # deterministic sampling — kills run-to-run variance
# Stable system prompt — kept byte-identical across calls so the API's
# ephemeral prompt cache can reuse it (see cache_control in _call_vlm).
SYSTEM_PROMPT = """\
You are an expert annotator for the PID2Graph benchmark, which turns
Piping and Instrumentation Diagrams (P&IDs) into node/edge graphs.

Your task: given a P&ID image (which may be a full diagram or a cropped
region of one), emit every SEMANTIC EQUIPMENT symbol visible as a node,
and every pipeline connection between two such symbols as an edge.
Line-level primitives (pipe junctions, crossings, arrowheads, background
regions) must NOT be emitted.

NODE CATEGORIES — use EXACTLY one of these 5 lowercase labels for `type`:
- valve           all valve bodies: gate, ball, check, control, globe,
                  butterfly, needle, etc. Each valve body is one node.
- pump            pumps and compressors (rotating equipment)
- tank            storage tanks, drums, vessels, columns, reactors,
                  heat exchangers, any large process equipment
- instrumentation circle/balloon instrument bubbles (PI, TI, FIC, TC,
                  LT, PDT, ...). Each bubble is one node.
- inlet/outlet    diagram boundary terminals where a pipe enters or
                  leaves the drawing. Use the literal string with a
                  slash: "inlet/outlet".

Do not invent other type names. Do not use synonyms. Do not use title case.
Do not emit connector, crossing, arrow, background, or general — those
are excluded from this task. If a drawn glyph does not clearly fit one of
the five categories above, skip it.

PER-NODE FIELDS:
  id     unique id within THIS response, e.g. "n1", "n2", "n3", ...
  type   one of the 5 labels above
  label  leave null — PID2Graph ground truth does not store printed tags
  bbox   REQUIRED. An object {xmin, ymin, xmax, ymax} with each coordinate
         normalized to [0.0, 1.0] relative to the image you are shown.
         (0, 0) is the top-left corner, (1, 1) is the bottom-right. Give
         your best estimate of the symbol's tight bounding box. If you
         cannot estimate it, still emit a best-effort box — never omit
         the field and never return null.

EDGES — edges in PID2Graph are UNDIRECTED. Emit an edge whenever two of
the semantic nodes above are joined by a continuous pipeline, even if
that pipeline passes through several pipe junctions, crossings, or
arrowheads on its way. Those intermediate points must not appear as
nodes; collapse the whole physical pipeline into a single direct edge
between its two semantic endpoints.

PER-EDGE FIELDS:
  source / target  node ids from the `nodes` list
  type             use "solid"
  label            null

Guidelines:
* Be exhaustive within the semantic scope.
* Ignore the title block, legend, border, and revision history.
* If the image is a cropped region (e.g. one tile of a larger diagram),
  only emit symbols and edges visible within this crop. Do not infer
  connections that continue off-crop — another tile will cover them.
* Return ONLY the JSON object matching the schema — no prose, no markdown.
"""
# Short per-request instruction; the heavy lifting lives in SYSTEM_PROMPT.
USER_INSTRUCTION = (
    "Extract the semantic-equipment graph from this P&ID image as JSON. "
    "Only the 5 categories listed in the instructions — no pipe primitives. "
    "Include a bbox (normalized [0, 1]) for every node."
)
| def _guess_media_type(path: Path) -> str: | |
| ext = path.suffix.lower() | |
| return { | |
| ".png": "image/png", | |
| ".jpg": "image/jpeg", | |
| ".jpeg": "image/jpeg", | |
| ".gif": "image/gif", | |
| ".webp": "image/webp", | |
| }.get(ext, "image/png") | |
def _load_image_b64(path: Path) -> tuple[str, str]:
    """Read the image at *path* and return (base64 payload, media type)."""
    raw = path.read_bytes()
    payload = base64.standard_b64encode(raw).decode("utf-8")
    media_type = _guess_media_type(path)
    return payload, media_type
def _graphout_schema() -> dict:
    """Return the JSON schema of GraphOut, for `output_config.format`."""
    schema = GraphOut.model_json_schema()
    return schema
def _call_vlm(
    image_data: str,
    media_type: str,
    client: anthropic.Anthropic,
    model: str,
    max_tokens: int,
    tag: str,
    temperature: float = DEFAULT_TEMPERATURE,
) -> GraphOut:
    """Shared VLM request path used by both whole-image and tile extraction.

    Args:
        image_data: base64-encoded image payload.
        media_type: MIME type matching ``image_data`` (e.g. "image/png").
        client: configured Anthropic SDK client.
        model: model identifier to call.
        max_tokens: response-token budget; streaming is used so large
            budgets don't hit HTTP timeouts.
        tag: short human-readable identifier (filename, tile name) used
            only to make error messages point at the right sample.
        temperature: 0.0 by default for reproducible evaluation runs.

    Returns:
        The parsed ``GraphOut`` graph.

    Raises:
        RuntimeError: if the model refuses, the output is truncated at
            ``max_tokens``, no text block is returned, or the text is
            not valid JSON.
    """
    schema = _graphout_schema()
    # The system prompt is identical across calls, so mark it cacheable;
    # only the image + instruction below vary per request.
    with client.messages.stream(
        model=model,
        max_tokens=max_tokens,
        temperature=temperature,
        system=[
            {
                "type": "text",
                "text": SYSTEM_PROMPT,
                "cache_control": {"type": "ephemeral"},
            }
        ],
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": media_type,
                            "data": image_data,
                        },
                    },
                    {"type": "text", "text": USER_INSTRUCTION},
                ],
            }
        ],
        # Structured output: force the response to match our Pydantic schema.
        output_config={
            "format": {
                "type": "json_schema",
                "schema": schema,
            }
        },
    ) as stream:
        # Blocks until the stream completes and returns the full message.
        final = stream.get_final_message()
    if final.stop_reason == "refusal":
        raise RuntimeError(f"VLM refused to answer for {tag} (stop_reason=refusal)")
    if final.stop_reason == "max_tokens":
        raise RuntimeError(
            f"VLM hit max_tokens={max_tokens} for {tag} — "
            f"output truncated; raise max_tokens or simplify the prompt"
        )
    text_block = next((b for b in final.content if b.type == "text"), None)
    if text_block is None:
        raise RuntimeError(
            f"VLM emitted no text block for {tag} (stop_reason={final.stop_reason})"
        )
    try:
        data = json.loads(text_block.text)
    except json.JSONDecodeError as e:
        # Chain the decode error explicitly so the traceback shows both.
        raise RuntimeError(
            f"VLM output is not valid JSON for {tag}: {e}\n"
            f"first 500 chars: {text_block.text[:500]}"
        ) from e
    return GraphOut.model_validate(data)
def extract_graph(
    image_path: Path,
    client: Optional[anthropic.Anthropic] = None,
    model: str = DEFAULT_MODEL,
    max_tokens: int = DEFAULT_MAX_TOKENS,
    temperature: float = DEFAULT_TEMPERATURE,
) -> GraphOut:
    """Extract the semantic graph from the whole (un-tiled) image."""
    if client is None:
        client = anthropic.Anthropic()
    path = Path(image_path)
    payload, media_type = _load_image_b64(path)
    return _call_vlm(
        payload,
        media_type,
        client,
        model,
        max_tokens,
        path.name,
        temperature=temperature,
    )
def extract_graph_tiled(
    image_path: Path,
    client: Optional[anthropic.Anthropic] = None,
    model: str = DEFAULT_MODEL,
    max_tokens: int = DEFAULT_MAX_TOKENS,
    temperature: float = DEFAULT_TEMPERATURE,
    rows: int = 2,
    cols: int = 2,
    overlap: float = 0.1,
    dedup_px: float = 40.0,
    seam_filter: bool = True,
    seam_filter_types: tuple[str, ...] = ("inlet/outlet",),
    seam_threshold: float = 50.0,
    edge_threshold: float = 30.0,
) -> dict:
    """Split the image into tiles, extract each, and merge the results.

    The return value is a plain dict rather than a GraphOut: merged nodes
    carry extra `bbox` (global pixel coordinates) and `source_tiles`
    fields that the response schema has no room for. The dict shape is
    still compatible with `metrics.evaluate()`.

    Deduplication rules:
      * Two nodes from different tiles are merged iff they share the
        same `type` AND their bbox centers are within `dedup_px` pixels
        in the un-tiled global image.
      * Nodes with a null bbox are dropped (they can't be deduped).
      * Edges are remapped through the merge map; undirected duplicates
        are collapsed; self-loops (both endpoints merged to the same
        global node) are dropped.
    """
    if client is None:
        client = anthropic.Anthropic()
    image_path = Path(image_path)

    tiles: list[Tile] = split_image(image_path, rows=rows, cols=cols, overlap=overlap)

    # One VLM call per tile; keep (graph, tile) pairs for merging + stats.
    per_tile: list[tuple[GraphOut, Tile]] = []
    for t in tiles:
        b64_data, mtype = tile_to_base64_png(t)
        graph = _call_vlm(
            b64_data,
            mtype,
            client,
            model,
            max_tokens,
            f"{image_path.name}:{t.name}",
            temperature=temperature,
        )
        per_tile.append((graph, t))

    merged = merge_tile_graphs(per_tile, dedup_px=dedup_px)

    if seam_filter:
        merged = filter_seam_artifacts(
            merged,
            tiles,
            types=seam_filter_types,
            seam_threshold=seam_threshold,
            edge_threshold=edge_threshold,
        )

    # Attach provenance for downstream debugging / aggregation.
    stats = []
    for graph, t in per_tile:
        boxed = sum(1 for node in graph.nodes if node.bbox is not None)
        stats.append(
            {
                "tile": t.name,
                "n_nodes": len(graph.nodes),
                "n_edges": len(graph.edges),
                "nodes_with_bbox": boxed,
            }
        )
    merged["tile_stats"] = stats
    return merged