Spaces:
Running
Running
| """Image tiling and multi-tile result merging. | |
| Rationale: | |
| Whole-page P&ID extraction hits a node-recall ceiling around 50-60% | |
| on large diagrams because the VLM can only resolve ~50 symbols at | |
| ~1.15MP vision downsampling. Splitting the image into overlapping | |
| tiles, extracting each, then merging via bbox-distance deduplication | |
| lets the model zoom in on smaller regions at full pixel budget. | |
| Coordinate conventions: | |
| - The VLM sees an individual tile and reports bbox in *normalized* | |
| [0, 1] coordinates relative to that tile (see `schema.BBox`). | |
| - `merge_tile_graphs` converts each tile-local normalized bbox into | |
| global image pixel coordinates and deduplicates across tiles using | |
| the Euclidean distance between bbox centers. | |
| Overlap: | |
| Tiles are grown outward by `overlap` fraction of the un-overlapped | |
| tile size on each internal seam so a symbol that happens to straddle | |
| a split line appears fully in at least one tile. The dedup step then | |
| collapses the two detections into one node. | |
| """ | |
| from __future__ import annotations | |
| import base64 | |
| import io | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Iterable | |
| from PIL import Image | |
| from .schema import BBox, GraphOut | |
@dataclass
class Tile:
    """One cropped region of the source image, ready to send to the VLM.

    The offsets record where the crop sits inside the parent image so
    tile-local detections can be mapped back to global pixel space.
    (`@dataclass` is required: `split_image` constructs tiles with
    keyword arguments, which needs the generated `__init__`.)
    """

    image: Image.Image  # the cropped PIL image sent to the VLM
    x0: int  # top-left x in global (full-image) pixel coordinates
    y0: int  # top-left y in global pixel coordinates
    w: int  # tile width in global pixels
    h: int  # tile height in global pixels
    parent_w: int  # full source-image width in pixels
    parent_h: int  # full source-image height in pixels
    name: str  # short label, e.g. "r0c0"
def split_image(
    image_path: Path | str,
    rows: int = 2,
    cols: int = 2,
    overlap: float = 0.1,
) -> list[Tile]:
    """Cut the image at `image_path` into a `rows * cols` grid of overlapping tiles.

    Each tile is grown outward on its internal seams by `overlap` times the
    un-overlapped cell size; with `overlap=0.1` and a 2x2 grid every tile
    covers 60% of each dimension (50% cell + 10% seam growth). Growth is
    clipped at the outer image border, so no tile exceeds the parent extent.
    """
    page = Image.open(Path(image_path)).convert("RGB")
    width, height = page.size
    cell_w = width / cols
    cell_h = height / rows
    grow_x = int(round(cell_w * overlap))
    grow_y = int(round(cell_h * overlap))

    out: list[Tile] = []
    for r in range(rows):
        # Row bounds are the same for every column; compute them once.
        top = max(0, int(round(r * cell_h)) - grow_y)
        bottom = min(height, int(round((r + 1) * cell_h)) + grow_y)
        for c in range(cols):
            left = max(0, int(round(c * cell_w)) - grow_x)
            right = min(width, int(round((c + 1) * cell_w)) + grow_x)
            out.append(
                Tile(
                    image=page.crop((left, top, right, bottom)),
                    x0=left,
                    y0=top,
                    w=right - left,
                    h=bottom - top,
                    parent_w=width,
                    parent_h=height,
                    name=f"r{r}c{c}",
                )
            )
    return out
def tile_to_base64_png(tile: Tile) -> tuple[str, str]:
    """Serialize `tile` as PNG and return `(base64_data, media_type)`.

    The pair slots directly into the Messages API image `source` block
    as `data` and `media_type`.
    """
    with io.BytesIO() as buf:
        tile.image.save(buf, format="PNG")
        payload = buf.getvalue()
    return base64.standard_b64encode(payload).decode("utf-8"), "image/png"
| def _tile_bbox_to_global_px(bbox: BBox, tile: Tile) -> tuple[float, float, float, float]: | |
| """Convert a tile-local normalized bbox to global pixel coordinates.""" | |
| return ( | |
| tile.x0 + bbox.xmin * tile.w, | |
| tile.y0 + bbox.ymin * tile.h, | |
| tile.x0 + bbox.xmax * tile.w, | |
| tile.y0 + bbox.ymax * tile.h, | |
| ) | |
| def _center(bbox_px: tuple[float, float, float, float]) -> tuple[float, float]: | |
| xmin, ymin, xmax, ymax = bbox_px | |
| return (xmin + xmax) / 2.0, (ymin + ymax) / 2.0 | |
| class _MergedNode: | |
| id: str | |
| type: str | |
| label: str | None | |
| bbox_px: tuple[float, float, float, float] | |
| center: tuple[float, float] | |
| source_tiles: list[str] = field(default_factory=list) | |
def merge_tile_graphs(
    tile_results: Iterable[tuple[GraphOut, Tile]],
    dedup_px: float = 40.0,
) -> dict:
    """Merge per-tile predictions into one global graph.

    Two nodes are considered the same symbol iff they share the same
    `type` AND their bbox centers are within `dedup_px` pixels in the
    global (un-tiled) image coordinate space. The first occurrence
    wins; later duplicates are absorbed and their tile name is recorded
    on `source_tiles` for debugging.

    Edges are re-wired through the tile-local -> global id map. An edge
    whose endpoints map to the same global node is dropped, and
    undirected duplicates across tiles are collapsed.

    Returns a plain dict matching the shape `metrics.evaluate()` expects,
    plus a `directed: False` flag. Extra `bbox` and `source_tiles` fields
    on each node are ignored by the evaluator but useful for visualizing
    predictions later.
    """
    global_nodes: list[_MergedNode] = []
    id_map: dict[tuple[str, str], str] = {}  # (tile.name, tile_local_id) -> global id
    next_id = 1
    tile_results = list(tile_results)  # iterated twice: nodes first, then edges

    # --- nodes: dedup by (type, bbox center distance) -------------------
    for graph, tile in tile_results:
        for node in graph.nodes:
            if node.bbox is None:
                # No bbox -> can't dedupe across tiles safely. Skip; edges
                # touching this node are dropped later via the id_map miss.
                continue
            bbox_px = _tile_bbox_to_global_px(node.bbox, tile)
            cx, cy = _center(bbox_px)

            # Linear scan for the first same-type node within dedup_px.
            # Keep the node object itself (not just its id) so the absorb
            # branch below needs no second scan over global_nodes.
            matched: _MergedNode | None = None
            for gn in global_nodes:
                if gn.type != node.type:
                    continue
                dx = gn.center[0] - cx
                dy = gn.center[1] - cy
                if (dx * dx + dy * dy) ** 0.5 <= dedup_px:
                    matched = gn
                    break

            if matched is None:
                gid = f"n{next_id}"
                next_id += 1
                global_nodes.append(
                    _MergedNode(
                        id=gid,
                        type=node.type,
                        label=node.label,
                        bbox_px=bbox_px,
                        center=(cx, cy),
                        source_tiles=[tile.name],
                    )
                )
                id_map[(tile.name, node.id)] = gid
            else:
                # Absorb the duplicate; remember every tile that saw it.
                id_map[(tile.name, node.id)] = matched.id
                if tile.name not in matched.source_tiles:
                    matched.source_tiles.append(tile.name)

    # --- edges: remap ids, drop self-loops and duplicates ---------------
    edge_set: set[tuple[str, str]] = set()
    final_edges: list[dict] = []
    for graph, tile in tile_results:
        for edge in graph.edges:
            src = id_map.get((tile.name, edge.source))
            tgt = id_map.get((tile.name, edge.target))
            if src is None or tgt is None:
                continue  # endpoint was a bbox-less node dropped above
            if src == tgt:
                continue  # both endpoints collapsed into one global node
            a, b = sorted((src, tgt))  # undirected canonicalization
            key = (a, b)
            if key in edge_set:
                continue
            edge_set.add(key)
            final_edges.append(
                {
                    "source": a,
                    "target": b,
                    "type": edge.type or "solid",
                    "label": edge.label,
                }
            )

    out_nodes = [
        {
            "id": gn.id,
            "type": gn.type,
            "label": gn.label,
            "bbox": list(gn.bbox_px),
            "source_tiles": gn.source_tiles,
        }
        for gn in global_nodes
    ]
    return {
        "nodes": out_nodes,
        "edges": final_edges,
        "directed": False,
    }
| def _inner_seam_lines(tiles: list[Tile]) -> tuple[list[float], list[float]]: | |
| """Return (vertical_seam_xs, horizontal_seam_ys) for inner tile borders. | |
| For 2x2 tiling with overlap, each internal seam corresponds to TWO | |
| pixel x-coordinates: where one tile ends and where the adjacent tile | |
| begins. Both are returned so the FP filter can match either edge of | |
| the overlap band. | |
| """ | |
| if not tiles: | |
| return [], [] | |
| parent_w = tiles[0].parent_w | |
| parent_h = tiles[0].parent_h | |
| vx: set[float] = set() | |
| hy: set[float] = set() | |
| for t in tiles: | |
| if t.x0 > 0: | |
| vx.add(float(t.x0)) | |
| if t.x0 + t.w < parent_w: | |
| vx.add(float(t.x0 + t.w)) | |
| if t.y0 > 0: | |
| hy.add(float(t.y0)) | |
| if t.y0 + t.h < parent_h: | |
| hy.add(float(t.y0 + t.h)) | |
| return sorted(vx), sorted(hy) | |
def filter_seam_artifacts(
    merged: dict,
    tiles: list[Tile],
    types: tuple[str, ...] = ("inlet/outlet",),
    seam_threshold: float = 50.0,
    edge_threshold: float = 30.0,
) -> dict:
    """Prune nodes that look like tile-boundary false positives.

    A node is dropped only when ALL of the following hold:
      * its `type` appears in `types` (default: inlet/outlet only),
      * its bbox center lies within `seam_threshold` pixels of an inner
        tile seam (vertical OR horizontal), AND
      * its bbox center is NOT within `edge_threshold` pixels of the
        outer image border (real boundary inlet/outlets stay).

    Edges referencing dropped nodes are removed as well, and the dropped
    nodes are recorded under `seam_filtered` so the report can show what
    was pruned. Nodes whose `type` is not in `types`, or that carry no
    `bbox`, pass through untouched.
    """
    if not tiles:
        return merged

    full_w = tiles[0].parent_w
    full_h = tiles[0].parent_h
    seam_xs, seam_ys = _inner_seam_lines(tiles)
    candidate_types = set(types)

    surviving: list[dict] = []
    surviving_ids: set[str] = set()
    removed: list[dict] = []

    for node in merged["nodes"]:
        bbox = node.get("bbox")
        filterable = (
            node.get("type") in candidate_types and bbox and len(bbox) == 4
        )
        if filterable:
            cx = (bbox[0] + bbox[2]) / 2.0
            cy = (bbox[1] + bbox[3]) / 2.0
            on_border = (
                cx < edge_threshold
                or cx > full_w - edge_threshold
                or cy < edge_threshold
                or cy > full_h - edge_threshold
            )
            on_seam = any(abs(cx - sx) <= seam_threshold for sx in seam_xs) or any(
                abs(cy - sy) <= seam_threshold for sy in seam_ys
            )
            if on_seam and not on_border:
                removed.append({"id": node["id"], "type": node["type"], "bbox": bbox})
                continue
        surviving_ids.add(node["id"])
        surviving.append(node)

    kept_edges = [
        edge
        for edge in merged["edges"]
        if edge["source"] in surviving_ids and edge["target"] in surviving_ids
    ]

    result = dict(merged)
    result["nodes"] = surviving
    result["edges"] = kept_edges
    result["seam_filtered"] = removed
    return result