"""Image tiling and multi-tile result merging. Rationale: Whole-page P&ID extraction hits a node-recall ceiling around 50-60% on large diagrams because the VLM can only resolve ~50 symbols at ~1.15MP vision downsampling. Splitting the image into overlapping tiles, extracting each, then merging via bbox-distance deduplication lets the model zoom in on smaller regions at full pixel budget. Coordinate conventions: - The VLM sees an individual tile and reports bbox in *normalized* [0, 1] coordinates relative to that tile (see `schema.BBox`). - `merge_tile_graphs` converts each tile-local normalized bbox into global image pixel coordinates and deduplicates across tiles using the Euclidean distance between bbox centers. Overlap: Tiles are grown outward by `overlap` fraction of the un-overlapped tile size on each internal seam so a symbol that happens to straddle a split line appears fully in at least one tile. The dedup step then collapses the two detections into one node. """ from __future__ import annotations import base64 import io from dataclasses import dataclass, field from pathlib import Path from typing import Iterable from PIL import Image from .schema import BBox, GraphOut @dataclass class Tile: """One cropped region of the source image, ready to send to the VLM.""" image: Image.Image x0: int # top-left x in global (full-image) pixel coordinates y0: int # top-left y in global pixel coordinates w: int # tile width in global pixels h: int # tile height in global pixels parent_w: int parent_h: int name: str # short label, e.g. "r0c0" def split_image( image_path: Path | str, rows: int = 2, cols: int = 2, overlap: float = 0.1, ) -> list[Tile]: """Split `image_path` into `rows * cols` overlapping tiles. `overlap` is the fractional grow-outward applied to each seam. With `overlap=0.1` and 2x2 tiling each tile is 60% wide and 60% tall (50% + 10% overlap on the internal edge). Boundary tiles are clipped to the image extent so they never exceed the parent dimensions. """ img = Image.open(Path(image_path)).convert("RGB") W, H = img.size base_tw = W / cols base_th = H / rows ow = int(round(base_tw * overlap)) oh = int(round(base_th * overlap)) tiles: list[Tile] = [] for r in range(rows): for c in range(cols): x0 = max(0, int(round(c * base_tw)) - ow) y0 = max(0, int(round(r * base_th)) - oh) x1 = min(W, int(round((c + 1) * base_tw)) + ow) y1 = min(H, int(round((r + 1) * base_th)) + oh) tile_img = img.crop((x0, y0, x1, y1)) tiles.append( Tile( image=tile_img, x0=x0, y0=y0, w=x1 - x0, h=y1 - y0, parent_w=W, parent_h=H, name=f"r{r}c{c}", ) ) return tiles def tile_to_base64_png(tile: Tile) -> tuple[str, str]: """Encode a tile as base64 PNG for the Messages API `source.data` field.""" buf = io.BytesIO() tile.image.save(buf, format="PNG") return base64.standard_b64encode(buf.getvalue()).decode("utf-8"), "image/png" def _tile_bbox_to_global_px(bbox: BBox, tile: Tile) -> tuple[float, float, float, float]: """Convert a tile-local normalized bbox to global pixel coordinates.""" return ( tile.x0 + bbox.xmin * tile.w, tile.y0 + bbox.ymin * tile.h, tile.x0 + bbox.xmax * tile.w, tile.y0 + bbox.ymax * tile.h, ) def _center(bbox_px: tuple[float, float, float, float]) -> tuple[float, float]: xmin, ymin, xmax, ymax = bbox_px return (xmin + xmax) / 2.0, (ymin + ymax) / 2.0 @dataclass class _MergedNode: id: str type: str label: str | None bbox_px: tuple[float, float, float, float] center: tuple[float, float] source_tiles: list[str] = field(default_factory=list) def merge_tile_graphs( tile_results: Iterable[tuple[GraphOut, Tile]], dedup_px: float = 40.0, ) -> dict: """Merge per-tile predictions into one global graph. Two nodes are considered the same symbol iff they share the same `type` AND their bbox centers are within `dedup_px` pixels in the global (un-tiled) image coordinate space. The first occurrence wins; later duplicates are absorbed and their tile name is recorded on `source_tiles` for debugging. Edges are re-wired through the tile-local → global id map. An edge whose endpoints map to the same global node is dropped, and undirected duplicates across tiles are collapsed. Returns a plain dict matching the shape `metrics.evaluate()` expects, plus a `directed: False` flag. Extra `bbox` and `source_tiles` fields on each node are ignored by the evaluator but useful for visualizing predictions later. """ global_nodes: list[_MergedNode] = [] id_map: dict[tuple[str, str], str] = {} # (tile.name, tile_local_id) -> global id next_id = 1 tile_results = list(tile_results) # --- nodes: dedup by (type, bbox center distance) ------------------- for graph, tile in tile_results: for node in graph.nodes: if node.bbox is None: # No bbox → can't dedupe across tiles safely. Skip. continue bbox_px = _tile_bbox_to_global_px(node.bbox, tile) cx, cy = _center(bbox_px) matched_gid: str | None = None for gn in global_nodes: if gn.type != node.type: continue dx = gn.center[0] - cx dy = gn.center[1] - cy if (dx * dx + dy * dy) ** 0.5 <= dedup_px: matched_gid = gn.id break if matched_gid is None: gid = f"n{next_id}" next_id += 1 global_nodes.append( _MergedNode( id=gid, type=node.type, label=node.label, bbox_px=bbox_px, center=(cx, cy), source_tiles=[tile.name], ) ) id_map[(tile.name, node.id)] = gid else: id_map[(tile.name, node.id)] = matched_gid for gn in global_nodes: if gn.id == matched_gid: if tile.name not in gn.source_tiles: gn.source_tiles.append(tile.name) break # --- edges: remap ids, drop self-loops and duplicates --------------- edge_set: set[tuple[str, str]] = set() final_edges: list[dict] = [] for graph, tile in tile_results: for edge in graph.edges: src = id_map.get((tile.name, edge.source)) tgt = id_map.get((tile.name, edge.target)) if src is None or tgt is None: continue if src == tgt: continue a, b = sorted((src, tgt)) # undirected canonicalization key = (a, b) if key in edge_set: continue edge_set.add(key) final_edges.append( { "source": a, "target": b, "type": edge.type or "solid", "label": edge.label, } ) out_nodes = [ { "id": gn.id, "type": gn.type, "label": gn.label, "bbox": list(gn.bbox_px), "source_tiles": gn.source_tiles, } for gn in global_nodes ] return { "nodes": out_nodes, "edges": final_edges, "directed": False, } def _inner_seam_lines(tiles: list[Tile]) -> tuple[list[float], list[float]]: """Return (vertical_seam_xs, horizontal_seam_ys) for inner tile borders. For 2x2 tiling with overlap, each internal seam corresponds to TWO pixel x-coordinates: where one tile ends and where the adjacent tile begins. Both are returned so the FP filter can match either edge of the overlap band. """ if not tiles: return [], [] parent_w = tiles[0].parent_w parent_h = tiles[0].parent_h vx: set[float] = set() hy: set[float] = set() for t in tiles: if t.x0 > 0: vx.add(float(t.x0)) if t.x0 + t.w < parent_w: vx.add(float(t.x0 + t.w)) if t.y0 > 0: hy.add(float(t.y0)) if t.y0 + t.h < parent_h: hy.add(float(t.y0 + t.h)) return sorted(vx), sorted(hy) def filter_seam_artifacts( merged: dict, tiles: list[Tile], types: tuple[str, ...] = ("inlet/outlet",), seam_threshold: float = 50.0, edge_threshold: float = 30.0, ) -> dict: """Drop nodes that look like tile-boundary false positives. A node is filtered iff ALL of the following hold: * its `type` is in `types` (default: inlet/outlet only), * its bbox center lies within `seam_threshold` pixels of any inner tile seam (vertical OR horizontal), AND * its bbox center is NOT within `edge_threshold` pixels of the outer image border (real boundary inlet/outlets stay). Edges referencing dropped nodes are also removed. The dropped nodes are recorded under `seam_filtered` so the report can show what was pruned. Nodes whose `type` is not in `types`, or that have no `bbox`, are passed through untouched. """ if not tiles: return merged parent_w = tiles[0].parent_w parent_h = tiles[0].parent_h vx, hy = _inner_seam_lines(tiles) types_set = set(types) keep_ids: set[str] = set() new_nodes: list[dict] = [] dropped: list[dict] = [] for n in merged["nodes"]: if n.get("type") not in types_set: keep_ids.add(n["id"]) new_nodes.append(n) continue bbox = n.get("bbox") if not bbox or len(bbox) != 4: keep_ids.add(n["id"]) new_nodes.append(n) continue cx = (bbox[0] + bbox[2]) / 2.0 cy = (bbox[1] + bbox[3]) / 2.0 near_outer = ( cx < edge_threshold or cx > parent_w - edge_threshold or cy < edge_threshold or cy > parent_h - edge_threshold ) near_seam_x = any(abs(cx - x) <= seam_threshold for x in vx) near_seam_y = any(abs(cy - y) <= seam_threshold for y in hy) near_seam = near_seam_x or near_seam_y if near_seam and not near_outer: dropped.append({"id": n["id"], "type": n["type"], "bbox": bbox}) continue keep_ids.add(n["id"]) new_nodes.append(n) new_edges = [ e for e in merged["edges"] if e["source"] in keep_ids and e["target"] in keep_ids ] out = dict(merged) out["nodes"] = new_nodes out["edges"] = new_edges out["seam_filtered"] = dropped return out