File size: 11,232 Bytes
59fa244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
"""Image tiling and multi-tile result merging.

Rationale:
    Whole-page P&ID extraction hits a node-recall ceiling around 50-60%
    on large diagrams because the VLM can only resolve ~50 symbols at
    ~1.15MP vision downsampling. Splitting the image into overlapping
    tiles, extracting each, then merging via bbox-distance deduplication
    lets the model zoom in on smaller regions at full pixel budget.

Coordinate conventions:
    - The VLM sees an individual tile and reports bbox in *normalized*
      [0, 1] coordinates relative to that tile (see `schema.BBox`).
    - `merge_tile_graphs` converts each tile-local normalized bbox into
      global image pixel coordinates and deduplicates across tiles using
      the Euclidean distance between bbox centers.

Overlap:
    Tiles are grown outward by `overlap` fraction of the un-overlapped
    tile size on each internal seam so a symbol that happens to straddle
    a split line appears fully in at least one tile. The dedup step then
    collapses the two detections into one node.
"""

from __future__ import annotations

import base64
import io
from dataclasses import dataclass, field
from pathlib import Path
from typing import Iterable

from PIL import Image

from .schema import BBox, GraphOut


@dataclass
class Tile:
    """A rectangular crop of the source image plus its placement metadata.

    Attributes:
        image: The cropped pixels, ready to send to the VLM.
        x0: Left edge of the crop in global (full-image) pixel coordinates.
        y0: Top edge of the crop in global pixel coordinates.
        w: Crop width in global pixels.
        h: Crop height in global pixels.
        parent_w: Width of the image the tile was cut from.
        parent_h: Height of the image the tile was cut from.
        name: Short label for the tile, e.g. "r0c0".
    """

    image: Image.Image
    x0: int
    y0: int
    w: int
    h: int
    parent_w: int
    parent_h: int
    name: str


def split_image(
    image_path: Path | str,
    rows: int = 2,
    cols: int = 2,
    overlap: float = 0.1,
) -> list[Tile]:
    """Split `image_path` into `rows * cols` overlapping tiles.

    The image is divided into an even `rows` x `cols` grid. Each cell is
    then grown outward on every side by `overlap` times the base
    (un-overlapped) cell size and clipped to the image extent, so only
    internal seams actually gain overlap and no tile exceeds the parent
    dimensions. With `overlap=0.1` and 2x2 tiling each tile is 55% wide
    and 55% tall: the 50% base cell plus 10% of that cell (5% of the
    image) past the internal seam.

    Args:
        image_path: Path of the image to open; it is converted to RGB.
        rows: Number of tile rows, must be >= 1.
        cols: Number of tile columns, must be >= 1.
        overlap: Fraction of the base cell size added on each side of a
            cell, must be >= 0.

    Returns:
        Tiles in row-major order, named ``"r{row}c{col}"``.

    Raises:
        ValueError: If `rows` or `cols` is less than 1, or `overlap` is
            negative.
    """
    # Fail early with a clear message instead of a ZeroDivisionError (0) or
    # a silently empty / shrunken tiling (negative values).
    if rows < 1 or cols < 1:
        raise ValueError(
            f"rows and cols must be >= 1, got rows={rows}, cols={cols}"
        )
    if overlap < 0:
        raise ValueError(f"overlap must be >= 0, got {overlap}")

    # `convert` returns an independent, fully loaded image, so the file
    # handle behind `Image.open` can be released right away.
    with Image.open(Path(image_path)) as src:
        img = src.convert("RGB")
    W, H = img.size

    base_tw = W / cols
    base_th = H / rows
    ow = int(round(base_tw * overlap))
    oh = int(round(base_th * overlap))

    tiles: list[Tile] = []
    for r in range(rows):
        # Row extents do not depend on the column; compute them once.
        y0 = max(0, int(round(r * base_th)) - oh)
        y1 = min(H, int(round((r + 1) * base_th)) + oh)
        for c in range(cols):
            x0 = max(0, int(round(c * base_tw)) - ow)
            x1 = min(W, int(round((c + 1) * base_tw)) + ow)
            tiles.append(
                Tile(
                    image=img.crop((x0, y0, x1, y1)),
                    x0=x0,
                    y0=y0,
                    w=x1 - x0,
                    h=y1 - y0,
                    parent_w=W,
                    parent_h=H,
                    name=f"r{r}c{c}",
                )
            )
    return tiles


def tile_to_base64_png(tile: Tile) -> tuple[str, str]:
    """Serialize a tile's image as PNG and return it base64-encoded.

    Returns:
        A ``(data, media_type)`` pair: the base64 text of the PNG bytes
        (for the Messages API `source.data` field) and the constant
        ``"image/png"``.
    """
    media_type = "image/png"
    with io.BytesIO() as stream:
        tile.image.save(stream, format="PNG")
        png_bytes = stream.getvalue()
    encoded = base64.standard_b64encode(png_bytes)
    return encoded.decode("utf-8"), media_type


def _tile_bbox_to_global_px(bbox: BBox, tile: Tile) -> tuple[float, float, float, float]:
    """Map a tile-relative normalized bbox onto full-image pixel coordinates.

    Each normalized coordinate is scaled by the tile's size and shifted by
    the tile's top-left corner. The result is ordered
    ``(xmin, ymin, xmax, ymax)``.
    """
    origin_x, origin_y = tile.x0, tile.y0
    span_x, span_y = tile.w, tile.h

    left = origin_x + bbox.xmin * span_x
    top = origin_y + bbox.ymin * span_y
    right = origin_x + bbox.xmax * span_x
    bottom = origin_y + bbox.ymax * span_y
    return left, top, right, bottom


def _center(bbox_px: tuple[float, float, float, float]) -> tuple[float, float]:
    """Return the ``(x, y)`` midpoint of an ``(xmin, ymin, xmax, ymax)`` box."""
    left, top, right, bottom = bbox_px
    mid_x = (left + right) / 2.0
    mid_y = (top + bottom) / 2.0
    return mid_x, mid_y


@dataclass
class _MergedNode:
    """A deduplicated node in the merged, whole-image graph."""

    # Global id assigned during merging (e.g. "n1").
    id: str
    # Symbol type; duplicates are only matched within the same type.
    type: str
    # Label carried over from the first detection of this node.
    label: str | None
    # (xmin, ymin, xmax, ymax) in global image pixels.
    bbox_px: tuple[float, float, float, float]
    # (x, y) midpoint of `bbox_px`, used for the dedup distance check.
    center: tuple[float, float]
    # Names of the tiles in which this node was detected.
    source_tiles: list[str] = field(default_factory=list)


def merge_tile_graphs(
    tile_results: Iterable[tuple[GraphOut, Tile]],
    dedup_px: float = 40.0,
) -> dict:
    """Combine per-tile graphs into a single whole-image graph.

    Nodes:
        Every node bbox is converted to global pixel coordinates. A node
        is treated as a repeat of an already-merged node when both have
        the same `type` and their bbox centers are at most `dedup_px`
        pixels apart. The earliest detection keeps its id, label and
        bbox; later repeats only add their tile name to `source_tiles`.
        Nodes without a bbox are skipped, because they cannot be placed
        in global coordinates.

    Edges:
        Endpoints are translated from tile-local ids to global ids.
        Edges with an unmapped endpoint, edges whose endpoints collapse
        to the same global node, and repeats of an already-seen
        undirected pair are discarded.

    Returns:
        A dict with `nodes`, `edges` and `directed: False`. Each node
        carries `id`, `type`, `label`, `bbox` (global pixels, as a list)
        and `source_tiles`; each edge carries `source`, `target`, `type`
        (defaulting to "solid") and `label`.
    """

    def _same_symbol(candidate: _MergedNode, kind: str, x: float, y: float) -> bool:
        # Same type and center within the dedup radius (inclusive).
        if candidate.type != kind:
            return False
        dx = candidate.center[0] - x
        dy = candidate.center[1] - y
        return (dx * dx + dy * dy) ** 0.5 <= dedup_px

    # Materialize once: the input is walked twice (nodes, then edges).
    pairs = list(tile_results)

    merged_nodes: list[_MergedNode] = []
    # (tile.name, tile-local node id) -> global node id
    local_to_global: dict[tuple[str, str], str] = {}

    # --- pass 1: place nodes globally and fold duplicates ----------------
    for graph, tile in pairs:
        for node in graph.nodes:
            if node.bbox is None:
                continue
            box_px = _tile_bbox_to_global_px(node.bbox, tile)
            mid_x, mid_y = _center(box_px)

            twin = next(
                (
                    candidate
                    for candidate in merged_nodes
                    if _same_symbol(candidate, node.type, mid_x, mid_y)
                ),
                None,
            )
            if twin is None:
                # Ids are handed out sequentially: n1, n2, ...
                twin = _MergedNode(
                    id=f"n{len(merged_nodes) + 1}",
                    type=node.type,
                    label=node.label,
                    bbox_px=box_px,
                    center=(mid_x, mid_y),
                    source_tiles=[tile.name],
                )
                merged_nodes.append(twin)
            elif tile.name not in twin.source_tiles:
                twin.source_tiles.append(tile.name)

            local_to_global[(tile.name, node.id)] = twin.id

    # --- pass 2: rewire edges onto global ids -----------------------------
    seen_pairs: set[tuple[str, str]] = set()
    merged_edges: list[dict] = []
    for graph, tile in pairs:
        for edge in graph.edges:
            src = local_to_global.get((tile.name, edge.source))
            tgt = local_to_global.get((tile.name, edge.target))
            if src is None or tgt is None or src == tgt:
                continue
            # Order the endpoints so A-B and B-A count as the same edge.
            pair = (src, tgt) if src < tgt else (tgt, src)
            if pair in seen_pairs:
                continue
            seen_pairs.add(pair)
            merged_edges.append(
                {
                    "source": pair[0],
                    "target": pair[1],
                    "type": edge.type or "solid",
                    "label": edge.label,
                }
            )

    node_dicts: list[dict] = []
    for merged in merged_nodes:
        node_dicts.append(
            {
                "id": merged.id,
                "type": merged.type,
                "label": merged.label,
                "bbox": list(merged.bbox_px),
                "source_tiles": merged.source_tiles,
            }
        )

    return {"nodes": node_dicts, "edges": merged_edges, "directed": False}


def _inner_seam_lines(tiles: list[Tile]) -> tuple[list[float], list[float]]:
    """Collect the tile borders that fall strictly inside the parent image.

    Returns ``(vertical_seam_xs, horizontal_seam_ys)``, each sorted
    ascending and free of duplicates. A tile border counts as an inner
    seam when it does not coincide with the parent image border: a
    left/top edge greater than 0, or a right/bottom edge smaller than
    the parent size. With overlapping tiles each internal seam therefore
    yields two coordinates (where one tile ends and where its neighbour
    begins), so callers can match either edge of the overlap band.

    The parent size is read from the first tile. An empty `tiles` list
    gives ``([], [])``.
    """
    if not tiles:
        return [], []

    full_w, full_h = tiles[0].parent_w, tiles[0].parent_h

    left_edges = {float(t.x0) for t in tiles if t.x0 > 0}
    right_edges = {float(t.x0 + t.w) for t in tiles if t.x0 + t.w < full_w}
    top_edges = {float(t.y0) for t in tiles if t.y0 > 0}
    bottom_edges = {float(t.y0 + t.h) for t in tiles if t.y0 + t.h < full_h}

    return sorted(left_edges | right_edges), sorted(top_edges | bottom_edges)


def filter_seam_artifacts(
    merged: dict,
    tiles: list[Tile],
    types: tuple[str, ...] = ("inlet/outlet",),
    seam_threshold: float = 50.0,
    edge_threshold: float = 30.0,
) -> dict:
    """Drop nodes that look like tile-boundary false positives.

    A node is filtered iff ALL of the following hold:
      * its `type` is in `types` (default: inlet/outlet only),
      * its bbox center lies within `seam_threshold` pixels of any inner
        tile seam (vertical OR horizontal), AND
      * its bbox center is NOT within `edge_threshold` pixels of the
        outer image border (real boundary inlet/outlets stay).

    Nodes whose `type` is not in `types`, or that have no 4-element
    `bbox`, are passed through untouched. Edges referencing dropped nodes
    are removed as well.

    Args:
        merged: Graph dict with `nodes` and `edges` lists. It is not
            modified.
        tiles: The tiles the graph was merged from; the parent image size
            is read from the first one.
        types: Node types eligible for filtering.
        seam_threshold: Maximum center-to-seam distance, in pixels, for a
            node to count as "on a seam".
        edge_threshold: Centers closer than this, in pixels, to the outer
            image border are always kept.

    Returns:
        A shallow copy of `merged` with `nodes` and `edges` replaced by
        the filtered lists and a `seam_filtered` list describing each
        dropped node (`id`, `type`, `bbox`). The key is always present;
        with no `tiles` there are no seams, so nothing is dropped and
        `seam_filtered` is empty.
    """
    out = dict(merged)

    if not tiles:
        # No tiling means no seams. Still return the same shape as every
        # other path so consumers can rely on `seam_filtered` existing.
        out["seam_filtered"] = []
        return out

    parent_w = tiles[0].parent_w
    parent_h = tiles[0].parent_h
    vx, hy = _inner_seam_lines(tiles)
    types_set = set(types)

    keep_ids: set[str] = set()
    new_nodes: list[dict] = []
    dropped: list[dict] = []

    for n in merged["nodes"]:
        bbox = n.get("bbox")
        # Only nodes of a targeted type with a usable bbox can be judged.
        if n.get("type") not in types_set or not bbox or len(bbox) != 4:
            keep_ids.add(n["id"])
            new_nodes.append(n)
            continue

        cx = (bbox[0] + bbox[2]) / 2.0
        cy = (bbox[1] + bbox[3]) / 2.0

        near_outer = (
            cx < edge_threshold
            or cx > parent_w - edge_threshold
            or cy < edge_threshold
            or cy > parent_h - edge_threshold
        )
        near_seam = any(abs(cx - x) <= seam_threshold for x in vx) or any(
            abs(cy - y) <= seam_threshold for y in hy
        )

        if near_seam and not near_outer:
            dropped.append({"id": n["id"], "type": n["type"], "bbox": bbox})
            continue

        keep_ids.add(n["id"])
        new_nodes.append(n)

    out["nodes"] = new_nodes
    out["edges"] = [
        e
        for e in merged["edges"]
        if e["source"] in keep_ids and e["target"] in keep_ids
    ]
    out["seam_filtered"] = dropped
    return out