"""Claude VLM-based P&ID graph extraction.

Uses the Anthropic SDK with:
- Vision input (base64-encoded P&ID image)
- Streaming (required for high `max_tokens` to avoid HTTP timeouts)
- Structured outputs via `output_config.format={"type": "json_schema", ...}`
  so the model is forced to return our Pydantic schema
- Prompt caching on the (stable) system prompt so repeated evaluation runs
  only pay full price once per 5-minute window

Public API:
    extract_graph(image_path, client=None, model="claude-opus-4-6") -> GraphOut
    extract_graph_tiled(image_path, ..., rows=2, cols=2) -> dict

The tiled variant splits the image into overlapping tiles, runs the VLM
on each, and merges the results using bbox-distance deduplication. Used
to break the ~50-symbol recall ceiling on large diagrams.

Note on scope (semantic-only mode):
    This prompt intentionally excludes PID2Graph's line-level primitives
    (connector, crossing, arrow, background, general). It targets only
    the five semantic equipment categories the VLM can recognize. The
    matching CLI flag `--semantic-only` filters the ground truth to the
    same five categories so P/R/F1 are comparable.
"""

from __future__ import annotations

import base64
import json
from pathlib import Path
from typing import Optional

import anthropic

from .schema import GraphOut
from .tile import (
    Tile,
    filter_seam_artifacts,
    merge_tile_graphs,
    split_image,
    tile_to_base64_png,
)

DEFAULT_MODEL = "claude-opus-4-6"
DEFAULT_MAX_TOKENS = 64000  # streaming is required at this size
DEFAULT_TEMPERATURE = 0.0  # deterministic sampling — kills run-to-run variance

SYSTEM_PROMPT = """\
You are an expert annotator for the PID2Graph benchmark, which turns
Piping and Instrumentation Diagrams (P&IDs) into node/edge graphs.

Your task: given a P&ID image (which may be a full diagram or a cropped
region of one), emit every SEMANTIC EQUIPMENT symbol visible as a node,
and every pipeline connection between two such symbols as an edge.
Line-level primitives (pipe junctions, crossings, arrowheads, background
regions) must NOT be emitted.

NODE CATEGORIES — use EXACTLY one of these 5 lowercase labels for `type`:

  - valve             all valve bodies: gate, ball, check, control, globe,
                      butterfly, needle, etc. Each valve body is one node.
  - pump              pumps and compressors (rotating equipment)
  - tank              storage tanks, drums, vessels, columns, reactors,
                      heat exchangers, any large process equipment
  - instrumentation   circle/balloon instrument bubbles (PI, TI, FIC, TC,
                      LT, PDT, ...). Each bubble is one node.
  - inlet/outlet      diagram boundary terminals where a pipe enters or
                      leaves the drawing. Use the literal string with a
                      slash: "inlet/outlet".

Do not invent other type names. Do not use synonyms. Do not use title case.
Do not emit connector, crossing, arrow, background, or general — those
are excluded from this task. If a drawn glyph does not clearly fit one of
the five categories above, skip it.

PER-NODE FIELDS:
  id     unique id within THIS response, e.g. "n1", "n2", "n3", ...
  type   one of the 5 labels above
  label  leave null — PID2Graph ground truth does not store printed tags
  bbox   REQUIRED. An object {xmin, ymin, xmax, ymax} with each coordinate
         normalized to [0.0, 1.0] relative to the image you are shown.
         (0, 0) is the top-left corner, (1, 1) is the bottom-right. Give
         your best estimate of the symbol's tight bounding box. If you
         cannot estimate it, still emit a best-effort box — never omit
         the field and never return null.

EDGES — edges in PID2Graph are UNDIRECTED. Emit an edge whenever two of
the semantic nodes above are joined by a continuous pipeline, even if
that pipeline passes through several pipe junctions, crossings, or
arrowheads on its way. Those intermediate points must not appear as
nodes; collapse the whole physical pipeline into a single direct edge
between its two semantic endpoints.

PER-EDGE FIELDS:
  source / target  node ids from the `nodes` list
  type             use "solid"
  label            null

Guidelines:
  * Be exhaustive within the semantic scope.
  * Ignore the title block, legend, border, and revision history.
  * If the image is a cropped region (e.g. one tile of a larger diagram),
    only emit symbols and edges visible within this crop. Do not infer
    connections that continue off-crop — another tile will cover them.
  * Return ONLY the JSON object matching the schema — no prose, no markdown.
"""

USER_INSTRUCTION = (
    "Extract the semantic-equipment graph from this P&ID image as JSON. "
    "Only the 5 categories listed in the instructions — no pipe primitives. "
    "Include a bbox (normalized [0, 1]) for every node."
)


def _guess_media_type(path: Path) -> str:
    ext = path.suffix.lower()
    return {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".gif": "image/gif",
        ".webp": "image/webp",
    }.get(ext, "image/png")


def _load_image_b64(path: Path) -> tuple[str, str]:
    data = base64.standard_b64encode(path.read_bytes()).decode("utf-8")
    return data, _guess_media_type(path)


def _graphout_schema() -> dict:
    """JSON schema for GraphOut, ready for `output_config.format`."""
    return GraphOut.model_json_schema()


def _call_vlm(
    image_data: str,
    media_type: str,
    client: anthropic.Anthropic,
    model: str,
    max_tokens: int,
    tag: str,
    temperature: float = DEFAULT_TEMPERATURE,
) -> GraphOut:
    """Shared VLM request path used by both whole-image and tile extraction.

    `tag` is a short human-readable identifier (filename, tile name) used
    only to make error messages point at the right sample.

    `temperature=0.0` by default for reproducible evaluation runs.
    """
    schema = _graphout_schema()

    with client.messages.stream(
        model=model,
        max_tokens=max_tokens,
        temperature=temperature,
        system=[
            {
                "type": "text",
                "text": SYSTEM_PROMPT,
                "cache_control": {"type": "ephemeral"},
            }
        ],
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": media_type,
                            "data": image_data,
                        },
                    },
                    {"type": "text", "text": USER_INSTRUCTION},
                ],
            }
        ],
        output_config={
            "format": {
                "type": "json_schema",
                "schema": schema,
            }
        },
    ) as stream:
        final = stream.get_final_message()

    if final.stop_reason == "refusal":
        raise RuntimeError(f"VLM refused to answer for {tag} (stop_reason=refusal)")
    if final.stop_reason == "max_tokens":
        raise RuntimeError(
            f"VLM hit max_tokens={max_tokens} for {tag} — "
            f"output truncated; raise max_tokens or simplify the prompt"
        )

    text_block = next((b for b in final.content if b.type == "text"), None)
    if text_block is None:
        raise RuntimeError(
            f"VLM emitted no text block for {tag} (stop_reason={final.stop_reason})"
        )

    try:
        data = json.loads(text_block.text)
    except json.JSONDecodeError as e:
        raise RuntimeError(
            f"VLM output is not valid JSON for {tag}: {e}\n"
            f"first 500 chars: {text_block.text[:500]}"
        ) from e

    return GraphOut.model_validate(data)


def extract_graph(
    image_path: Path,
    client: Optional[anthropic.Anthropic] = None,
    model: str = DEFAULT_MODEL,
    max_tokens: int = DEFAULT_MAX_TOKENS,
    temperature: float = DEFAULT_TEMPERATURE,
) -> GraphOut:
    """Run the VLM on the whole image and return the parsed graph."""
    client = client or anthropic.Anthropic()
    image_path = Path(image_path)
    image_data, media_type = _load_image_b64(image_path)
    return _call_vlm(
        image_data, media_type, client, model, max_tokens,
        image_path.name, temperature=temperature,
    )


def extract_graph_tiled(
    image_path: Path,
    client: Optional[anthropic.Anthropic] = None,
    model: str = DEFAULT_MODEL,
    max_tokens: int = DEFAULT_MAX_TOKENS,
    temperature: float = DEFAULT_TEMPERATURE,
    rows: int = 2,
    cols: int = 2,
    overlap: float = 0.1,
    dedup_px: float = 40.0,
    seam_filter: bool = True,
    seam_filter_types: tuple[str, ...] = ("inlet/outlet",),
    seam_threshold: float = 50.0,
    edge_threshold: float = 30.0,
) -> dict:
    """Tile the image, extract each tile, merge via bbox-distance dedup.

    Returns a plain dict (not a GraphOut) because merged nodes carry
    extra `bbox` (in global pixel coordinates) and `source_tiles` fields
    that don't fit the response schema. The dict shape is still
    compatible with `metrics.evaluate()`.

    Deduplication rules:
      * Two nodes from different tiles are merged iff they share the
        same `type` AND their bbox centers are within `dedup_px` pixels
        in the un-tiled global image.
      * Nodes with a null bbox are dropped (they can't be deduped).
      * Edges are remapped through the merge map; undirected duplicates
        are collapsed; self-loops (both endpoints merged to the same
        global node) are dropped.
    """
    client = client or anthropic.Anthropic()
    image_path = Path(image_path)
    tiles: list[Tile] = split_image(image_path, rows=rows, cols=cols, overlap=overlap)

    per_tile: list[tuple[GraphOut, Tile]] = []
    for tile in tiles:
        image_data, media_type = tile_to_base64_png(tile)
        tag = f"{image_path.name}:{tile.name}"
        graph = _call_vlm(
            image_data, media_type, client, model, max_tokens,
            tag, temperature=temperature,
        )
        per_tile.append((graph, tile))

    merged = merge_tile_graphs(per_tile, dedup_px=dedup_px)

    if seam_filter:
        merged = filter_seam_artifacts(
            merged,
            tiles,
            types=seam_filter_types,
            seam_threshold=seam_threshold,
            edge_threshold=edge_threshold,
        )

    # Attach provenance for downstream debugging / aggregation.
    merged["tile_stats"] = [
        {
            "tile": tile.name,
            "n_nodes": len(graph.nodes),
            "n_edges": len(graph.edges),
            "nodes_with_bbox": sum(1 for n in graph.nodes if n.bbox is not None),
        }
        for graph, tile in per_tile
    ]
    return merged