# Page-scrape residue (not part of the module): Spaces: Running / File size: 11,232 Bytes / 59fa244
"""Image tiling and multi-tile result merging.
Rationale:
Whole-page P&ID extraction hits a node-recall ceiling around 50-60%
on large diagrams because the VLM can only resolve ~50 symbols at
~1.15MP vision downsampling. Splitting the image into overlapping
tiles, extracting each, then merging via bbox-distance deduplication
lets the model zoom in on smaller regions at full pixel budget.
Coordinate conventions:
- The VLM sees an individual tile and reports bbox in *normalized*
[0, 1] coordinates relative to that tile (see `schema.BBox`).
- `merge_tile_graphs` converts each tile-local normalized bbox into
global image pixel coordinates and deduplicates across tiles using
the Euclidean distance between bbox centers.
Overlap:
Tiles are grown outward by `overlap` fraction of the un-overlapped
tile size on each internal seam so a symbol that happens to straddle
a split line appears fully in at least one tile. The dedup step then
collapses the two detections into one node.
"""
from __future__ import annotations
import base64
import io
from dataclasses import dataclass, field
from pathlib import Path
from typing import Iterable
from PIL import Image
from .schema import BBox, GraphOut
@dataclass
class Tile:
    """A cropped region of the source image, ready to send to the VLM.

    All geometry is expressed in pixels of the full (un-tiled) image.

    Attributes:
        image: Pixel data of the crop.
        x0: Global x coordinate of the tile's top-left corner.
        y0: Global y coordinate of the tile's top-left corner.
        w: Tile width in global pixels.
        h: Tile height in global pixels.
        parent_w: Width of the full image the tile belongs to.
        parent_h: Height of the full image the tile belongs to.
        name: Short label, e.g. "r0c0".
    """

    image: Image.Image
    x0: int
    y0: int
    w: int
    h: int
    parent_w: int
    parent_h: int
    name: str
def split_image(
    image_path: Path | str,
    rows: int = 2,
    cols: int = 2,
    overlap: float = 0.1,
) -> list[Tile]:
    """Split `image_path` into `rows * cols` overlapping tiles.

    The image is divided into a `rows` x `cols` grid of equal base cells.
    Every cell is then grown outward on each side by `overlap` times the
    base cell size and clipped to the image extent, so boundary tiles
    never exceed the parent dimensions and only internal edges actually
    gain pixels.

    Example: with `overlap=0.1` and 2x2 tiling the base cell is 50% of
    the image, the grow-out is 10% of that (5% of the image), and each
    tile ends up 55% wide and 55% tall.

    Args:
        image_path: Path of the image to open; it is converted to RGB.
        rows: Number of tile rows (>= 1).
        cols: Number of tile columns (>= 1).
        overlap: Grow-out per side, as a fraction of the base cell size.

    Returns:
        Tiles in row-major order, named "r{row}c{col}".

    Raises:
        ValueError: If `rows` or `cols` is less than 1.
    """
    # Validate before touching the file: zero would otherwise surface as a
    # bare ZeroDivisionError and negatives would silently produce no tiles.
    if rows < 1 or cols < 1:
        raise ValueError(
            f"rows and cols must be >= 1 (got rows={rows}, cols={cols})"
        )

    # convert() yields an independent in-memory image, so the file handle
    # can be released as soon as the conversion is done.
    with Image.open(Path(image_path)) as src:
        img = src.convert("RGB")
    W, H = img.size

    base_tw = W / cols
    base_th = H / rows
    ow = int(round(base_tw * overlap))
    oh = int(round(base_th * overlap))

    tiles: list[Tile] = []
    for r in range(rows):
        # Row extent is the same for every column; compute it once per row.
        y0 = max(0, int(round(r * base_th)) - oh)
        y1 = min(H, int(round((r + 1) * base_th)) + oh)
        for c in range(cols):
            x0 = max(0, int(round(c * base_tw)) - ow)
            x1 = min(W, int(round((c + 1) * base_tw)) + ow)
            tiles.append(
                Tile(
                    image=img.crop((x0, y0, x1, y1)),
                    x0=x0,
                    y0=y0,
                    w=x1 - x0,
                    h=y1 - y0,
                    parent_w=W,
                    parent_h=H,
                    name=f"r{r}c{c}",
                )
            )
    return tiles
def tile_to_base64_png(tile: Tile) -> tuple[str, str]:
    """Serialize a tile's image as PNG and return it base64-encoded.

    Returns:
        A `(data, media_type)` pair: the base64 text of the PNG bytes
        (intended for the Messages API `source.data` field) and the
        constant media type "image/png".
    """
    with io.BytesIO() as png_buffer:
        tile.image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.standard_b64encode(png_bytes)
    return encoded.decode("utf-8"), "image/png"
def _tile_bbox_to_global_px(bbox: BBox, tile: Tile) -> tuple[float, float, float, float]:
"""Convert a tile-local normalized bbox to global pixel coordinates."""
return (
tile.x0 + bbox.xmin * tile.w,
tile.y0 + bbox.ymin * tile.h,
tile.x0 + bbox.xmax * tile.w,
tile.y0 + bbox.ymax * tile.h,
)
def _center(bbox_px: tuple[float, float, float, float]) -> tuple[float, float]:
xmin, ymin, xmax, ymax = bbox_px
return (xmin + xmax) / 2.0, (ymin + ymax) / 2.0
@dataclass
class _MergedNode:
id: str
type: str
label: str | None
bbox_px: tuple[float, float, float, float]
center: tuple[float, float]
source_tiles: list[str] = field(default_factory=list)
def merge_tile_graphs(
    tile_results: Iterable[tuple[GraphOut, Tile]],
    dedup_px: float = 40.0,
) -> dict:
    """Merge per-tile predictions into one global graph.

    Nodes: a tile-local node is treated as a re-detection of an already
    merged node iff both share the same `type` AND their bbox centers are
    within `dedup_px` pixels in the global (un-tiled) image coordinate
    space. When several merged nodes qualify, the *nearest* one is chosen
    (ties go to the earliest-created node), so a detection in an overlap
    band is not absorbed into a merely close neighbour. The first
    occurrence supplies the node's label and bbox; later duplicates only
    add their tile name to `source_tiles` for debugging. Nodes without a
    bbox cannot be placed globally and are skipped.

    Edges are re-wired through the tile-local -> global id map. An edge is
    dropped if either endpoint was skipped or unknown, or if both
    endpoints map to the same global node. Undirected duplicates across
    tiles are collapsed; the first occurrence supplies `type`/`label`.

    Args:
        tile_results: `(graph, tile)` pairs; consumed once, so a one-shot
            iterator is fine.
        dedup_px: Maximum center distance, in global pixels, at which two
            same-type nodes count as the same symbol.

    Returns:
        A plain dict with `nodes`, `edges` and `directed: False`, matching
        the shape `metrics.evaluate()` expects. Each node also carries
        `bbox` (global pixels) and `source_tiles`, which the evaluator
        ignores but which help when visualizing predictions.
    """
    # Materialize up front: the results are walked twice (nodes, then edges).
    tile_results = list(tile_results)

    global_nodes: list[_MergedNode] = []
    id_map: dict[tuple[str, str], str] = {}  # (tile.name, tile_local_id) -> global id

    # --- nodes: dedup by (type, nearest bbox center within dedup_px) ------
    for graph, tile in tile_results:
        for node in graph.nodes:
            if node.bbox is None:
                # No bbox -> can't dedupe across tiles safely. Skip.
                continue
            bbox_px = _tile_bbox_to_global_px(node.bbox, tile)
            cx, cy = _center(bbox_px)

            nearest: _MergedNode | None = None
            nearest_dist = 0.0
            for gn in global_nodes:
                if gn.type != node.type:
                    continue
                dx = gn.center[0] - cx
                dy = gn.center[1] - cy
                dist = (dx * dx + dy * dy) ** 0.5
                if dist > dedup_px:
                    continue
                # Strict '<' keeps the earliest-created node on ties.
                if nearest is None or dist < nearest_dist:
                    nearest = gn
                    nearest_dist = dist

            if nearest is None:
                # Nodes are only ever appended, so the count yields n1, n2, ...
                gid = f"n{len(global_nodes) + 1}"
                global_nodes.append(
                    _MergedNode(
                        id=gid,
                        type=node.type,
                        label=node.label,
                        bbox_px=bbox_px,
                        center=(cx, cy),
                        source_tiles=[tile.name],
                    )
                )
                id_map[(tile.name, node.id)] = gid
            else:
                id_map[(tile.name, node.id)] = nearest.id
                if tile.name not in nearest.source_tiles:
                    nearest.source_tiles.append(tile.name)

    # --- edges: remap ids, drop self-loops and duplicates -----------------
    edge_set: set[tuple[str, str]] = set()
    final_edges: list[dict] = []
    for graph, tile in tile_results:
        for edge in graph.edges:
            src = id_map.get((tile.name, edge.source))
            tgt = id_map.get((tile.name, edge.target))
            if src is None or tgt is None:
                continue
            if src == tgt:
                continue
            a, b = sorted((src, tgt))  # undirected canonicalization
            key = (a, b)
            if key in edge_set:
                continue
            edge_set.add(key)
            final_edges.append(
                {
                    "source": a,
                    "target": b,
                    "type": edge.type or "solid",
                    "label": edge.label,
                }
            )

    out_nodes = [
        {
            "id": gn.id,
            "type": gn.type,
            "label": gn.label,
            "bbox": list(gn.bbox_px),
            "source_tiles": gn.source_tiles,
        }
        for gn in global_nodes
    ]
    return {
        "nodes": out_nodes,
        "edges": final_edges,
        "directed": False,
    }
def _inner_seam_lines(tiles: list[Tile]) -> tuple[list[float], list[float]]:
"""Return (vertical_seam_xs, horizontal_seam_ys) for inner tile borders.
For 2x2 tiling with overlap, each internal seam corresponds to TWO
pixel x-coordinates: where one tile ends and where the adjacent tile
begins. Both are returned so the FP filter can match either edge of
the overlap band.
"""
if not tiles:
return [], []
parent_w = tiles[0].parent_w
parent_h = tiles[0].parent_h
vx: set[float] = set()
hy: set[float] = set()
for t in tiles:
if t.x0 > 0:
vx.add(float(t.x0))
if t.x0 + t.w < parent_w:
vx.add(float(t.x0 + t.w))
if t.y0 > 0:
hy.add(float(t.y0))
if t.y0 + t.h < parent_h:
hy.add(float(t.y0 + t.h))
return sorted(vx), sorted(hy)
def filter_seam_artifacts(
    merged: dict,
    tiles: list[Tile],
    types: tuple[str, ...] = ("inlet/outlet",),
    seam_threshold: float = 50.0,
    edge_threshold: float = 30.0,
) -> dict:
    """Drop nodes that look like tile-boundary false positives.

    A node is filtered iff ALL of the following hold:
      * its `type` is in `types` (default: inlet/outlet only),
      * its bbox center lies within `seam_threshold` pixels of any inner
        tile seam (vertical OR horizontal), AND
      * its bbox center is NOT within `edge_threshold` pixels of the
        outer image border (real boundary inlet/outlets stay).

    Edges referencing dropped nodes are also removed. Nodes whose `type`
    is not in `types`, or that have no usable 4-element `bbox`, are
    passed through untouched.

    Args:
        merged: Graph dict with `nodes` and `edges`; it is not mutated.
        tiles: Tiles the graph was merged from; they define the seams and
            the parent image size.
        types: Node types eligible for filtering.
        seam_threshold: Max distance (px) from a seam for a node to count
            as "near the seam".
        edge_threshold: Max distance (px) from the outer image border for
            a node to count as a genuine boundary symbol.

    Returns:
        A shallow copy of `merged` with filtered `nodes`/`edges` and a
        `seam_filtered` list describing every dropped node. The key is
        always present, including when `tiles` is empty (nothing is
        filtered and the list is empty).
    """
    if not tiles:
        # No seams to test against. Return the same shape as the normal
        # path so callers can read `seam_filtered` unconditionally.
        passthrough = dict(merged)
        passthrough["seam_filtered"] = []
        return passthrough

    parent_w = tiles[0].parent_w
    parent_h = tiles[0].parent_h
    vx, hy = _inner_seam_lines(tiles)
    types_set = set(types)

    keep_ids: set[str] = set()
    new_nodes: list[dict] = []
    dropped: list[dict] = []
    for n in merged["nodes"]:
        bbox = n.get("bbox")
        eligible = n.get("type") in types_set and bbox and len(bbox) == 4
        if eligible:
            cx = (bbox[0] + bbox[2]) / 2.0
            cy = (bbox[1] + bbox[3]) / 2.0
            near_outer = (
                cx < edge_threshold
                or cx > parent_w - edge_threshold
                or cy < edge_threshold
                or cy > parent_h - edge_threshold
            )
            near_seam = any(abs(cx - x) <= seam_threshold for x in vx) or any(
                abs(cy - y) <= seam_threshold for y in hy
            )
            if near_seam and not near_outer:
                dropped.append({"id": n["id"], "type": n["type"], "bbox": bbox})
                continue
        keep_ids.add(n["id"])
        new_nodes.append(n)

    # An edge survives only if both of its endpoints survived.
    new_edges = [
        e
        for e in merged["edges"]
        if e["source"] in keep_ids and e["target"] in keep_ids
    ]

    out = dict(merged)
    out["nodes"] = new_nodes
    out["edges"] = new_edges
    out["seam_filtered"] = dropped
    return out
# (stray "|" left over from the page scrape)