# deepkick's picture
# Initial commit: PID2Graph × Claude VLM evaluation + Gradio demo
# 59fa244
"""Image tiling and multi-tile result merging.
Rationale:
Whole-page P&ID extraction hits a node-recall ceiling around 50-60%
on large diagrams because the VLM can only resolve ~50 symbols at
~1.15MP vision downsampling. Splitting the image into overlapping
tiles, extracting each, then merging via bbox-distance deduplication
lets the model zoom in on smaller regions at full pixel budget.
Coordinate conventions:
- The VLM sees an individual tile and reports bbox in *normalized*
[0, 1] coordinates relative to that tile (see `schema.BBox`).
- `merge_tile_graphs` converts each tile-local normalized bbox into
global image pixel coordinates and deduplicates across tiles using
the Euclidean distance between bbox centers.
Overlap:
Tiles are grown outward by `overlap` fraction of the un-overlapped
tile size on each internal seam so a symbol that happens to straddle
a split line appears fully in at least one tile. The dedup step then
collapses the two detections into one node.
"""
from __future__ import annotations
import base64
import io
from dataclasses import dataclass, field
from pathlib import Path
from typing import Iterable
from PIL import Image
from .schema import BBox, GraphOut
@dataclass
class Tile:
    """A single crop of the parent image plus the geometry to map it back.

    Every coordinate and size below is expressed in pixels of the full
    (un-tiled) parent image, so tile-local results can be translated into
    global space with `x0`/`y0` and scaled with `w`/`h`.
    """

    image: Image.Image  # cropped pixels that get sent to the VLM
    x0: int  # global x of the tile's left edge
    y0: int  # global y of the tile's top edge
    w: int  # tile width (global pixels)
    h: int  # tile height (global pixels)
    parent_w: int  # width of the full source image
    parent_h: int  # height of the full source image
    name: str  # grid label such as "r0c0" (row index, column index)
def split_image(
    image_path: Path | str,
    rows: int = 2,
    cols: int = 2,
    overlap: float = 0.1,
) -> list[Tile]:
    """Split `image_path` into `rows * cols` overlapping tiles.

    The image is first divided into an even `rows x cols` grid. Each grid
    cell is then grown outward on every side by `overlap` times the cell's
    own size (rounded to whole pixels), and clipped to the image extent.
    Only internal seams actually gain overlap because of that clipping.

    With `overlap=0.1` and 2x2 tiling, each tile covers 55% of the image
    width and height: the 50% grid cell plus 0.1 * 50% = 5% across the
    internal seam. Neighbouring tiles therefore share a band 10% of the
    image wide around each seam.

    Tiles are returned in row-major order (r0c0, r0c1, ..., r1c0, ...).

    Args:
        image_path: path to an image Pillow can open; it is converted to RGB.
        rows: number of grid rows, >= 1.
        cols: number of grid columns, >= 1.
        overlap: non-negative grow-outward fraction of the grid cell size.

    Returns:
        One `Tile` per grid cell, holding its crop and global placement.

    Raises:
        ValueError: if `rows` or `cols` is < 1, or `overlap` is negative.
    """
    if rows < 1 or cols < 1:
        raise ValueError(f"rows and cols must be >= 1, got rows={rows}, cols={cols}")
    if overlap < 0:
        raise ValueError(f"overlap must be non-negative, got {overlap}")

    # Use a context manager so the underlying file handle is closed promptly;
    # convert() loads the pixels and returns an independent copy.
    with Image.open(Path(image_path)) as src:
        img = src.convert("RGB")
    W, H = img.size
    base_tw = W / cols
    base_th = H / rows
    ow = round(base_tw * overlap)
    oh = round(base_th * overlap)

    tiles: list[Tile] = []
    for r in range(rows):
        for c in range(cols):
            # Grow the grid cell by the overlap margin, clipped to the image.
            x0 = max(0, round(c * base_tw) - ow)
            y0 = max(0, round(r * base_th) - oh)
            x1 = min(W, round((c + 1) * base_tw) + ow)
            y1 = min(H, round((r + 1) * base_th) + oh)
            tiles.append(
                Tile(
                    image=img.crop((x0, y0, x1, y1)),
                    x0=x0,
                    y0=y0,
                    w=x1 - x0,
                    h=y1 - y0,
                    parent_w=W,
                    parent_h=H,
                    name=f"r{r}c{c}",
                )
            )
    return tiles
def tile_to_base64_png(tile: Tile) -> tuple[str, str]:
    """Return `(base64_data, media_type)` for a tile rendered as PNG.

    The pair is shaped for the Messages API image `source` block: the first
    element goes into `source.data`, the second is always "image/png".
    """
    with io.BytesIO() as stream:
        tile.image.save(stream, format="PNG")
        png_bytes = stream.getvalue()
    encoded = base64.standard_b64encode(png_bytes).decode("utf-8")
    return encoded, "image/png"
def _tile_bbox_to_global_px(bbox: BBox, tile: Tile) -> tuple[float, float, float, float]:
"""Convert a tile-local normalized bbox to global pixel coordinates."""
return (
tile.x0 + bbox.xmin * tile.w,
tile.y0 + bbox.ymin * tile.h,
tile.x0 + bbox.xmax * tile.w,
tile.y0 + bbox.ymax * tile.h,
)
def _center(bbox_px: tuple[float, float, float, float]) -> tuple[float, float]:
xmin, ymin, xmax, ymax = bbox_px
return (xmin + xmax) / 2.0, (ymin + ymax) / 2.0
@dataclass
class _MergedNode:
id: str
type: str
label: str | None
bbox_px: tuple[float, float, float, float]
center: tuple[float, float]
source_tiles: list[str] = field(default_factory=list)
def _nearest_same_type(
    candidates: list[_MergedNode],
    node_type: str,
    cx: float,
    cy: float,
    max_dist: float,
) -> _MergedNode | None:
    """Return the same-type candidate whose center is nearest (cx, cy) within `max_dist`."""
    best: _MergedNode | None = None
    best_dist = 0.0
    for gn in candidates:
        if gn.type != node_type:
            continue
        dx = gn.center[0] - cx
        dy = gn.center[1] - cy
        dist = (dx * dx + dy * dy) ** 0.5
        # Strict `<` keeps the earliest node on exact ties, so results are stable.
        if dist <= max_dist and (best is None or dist < best_dist):
            best, best_dist = gn, dist
    return best


def _merge_nodes(
    tile_results: list[tuple[GraphOut, Tile]],
    dedup_px: float,
) -> tuple[list[_MergedNode], dict[tuple[str, str], str]]:
    """Deduplicate nodes across tiles; return merged nodes and the local->global id map."""
    global_nodes: list[_MergedNode] = []
    # (tile.name, tile_local_id) -> global id
    id_map: dict[tuple[str, str], str] = {}
    for graph, tile in tile_results:
        for node in graph.nodes:
            if node.bbox is None:
                # No bbox -> it can't be placed globally or deduped safely. Skip.
                continue
            bbox_px = _tile_bbox_to_global_px(node.bbox, tile)
            cx, cy = _center(bbox_px)
            match = _nearest_same_type(global_nodes, node.type, cx, cy, dedup_px)
            if match is None:
                match = _MergedNode(
                    id=f"n{len(global_nodes) + 1}",
                    type=node.type,
                    label=node.label,
                    bbox_px=bbox_px,
                    center=(cx, cy),
                    source_tiles=[tile.name],
                )
                global_nodes.append(match)
            else:
                # The first occurrence wins, but don't lose a label it lacked.
                if not match.label and node.label:
                    match.label = node.label
                if tile.name not in match.source_tiles:
                    match.source_tiles.append(tile.name)
            id_map[(tile.name, node.id)] = match.id
    return global_nodes, id_map


def _merge_edges(
    tile_results: list[tuple[GraphOut, Tile]],
    id_map: dict[tuple[str, str], str],
) -> list[dict]:
    """Remap tile-local edges to global ids, dropping dangling edges, self-loops and undirected duplicates."""
    seen: set[tuple[str, str]] = set()
    final_edges: list[dict] = []
    for graph, tile in tile_results:
        for edge in graph.edges:
            src = id_map.get((tile.name, edge.source))
            tgt = id_map.get((tile.name, edge.target))
            # Drop edges whose endpoints were skipped (no bbox / unknown id),
            # and edges collapsed into a self-loop by node dedup.
            if src is None or tgt is None or src == tgt:
                continue
            a, b = sorted((src, tgt))  # undirected canonicalization
            if (a, b) in seen:
                continue
            seen.add((a, b))
            final_edges.append(
                {
                    "source": a,
                    "target": b,
                    "type": edge.type or "solid",
                    "label": edge.label,
                }
            )
    return final_edges


def merge_tile_graphs(
    tile_results: Iterable[tuple[GraphOut, Tile]],
    dedup_px: float = 40.0,
) -> dict:
    """Merge per-tile predictions into one global graph.

    Two nodes are considered the same symbol iff they share the same
    `type` AND their bbox centers are within `dedup_px` pixels in the
    global (un-tiled) image coordinate space. If several merged nodes
    qualify, the one with the nearest center is chosen. The first
    occurrence's id, bbox and center are kept. A later duplicate only
    fills in `label` when the kept node has none, and its tile name is
    recorded on `source_tiles` for debugging. Dedup does not check tile
    identity, so same-type detections within `dedup_px` in one tile are
    also collapsed.

    Nodes without a bbox cannot be placed in global coordinates. They are
    dropped, along with any edge that references them.

    Edges are re-wired through the tile-local -> global id map. An edge
    whose endpoints map to the same global node is dropped, and
    undirected duplicates across tiles are collapsed; the first edge's
    `type`/`label` is kept.

    Args:
        tile_results: (graph, tile) pairs. Any iterable is accepted; it is
            materialised once because it is traversed twice.
        dedup_px: max center distance, in global pixels, for merging.

    Returns:
        A plain dict matching the shape `metrics.evaluate()` expects, plus
        a `directed: False` flag. The extra `bbox` and `source_tiles` fields
        on each node are ignored by the evaluator but useful for
        visualizing predictions later.
    """
    tile_results = list(tile_results)
    global_nodes, id_map = _merge_nodes(tile_results, dedup_px)
    final_edges = _merge_edges(tile_results, id_map)
    out_nodes = [
        {
            "id": gn.id,
            "type": gn.type,
            "label": gn.label,
            "bbox": list(gn.bbox_px),
            "source_tiles": gn.source_tiles,
        }
        for gn in global_nodes
    ]
    return {
        "nodes": out_nodes,
        "edges": final_edges,
        "directed": False,
    }
def _inner_seam_lines(tiles: list[Tile]) -> tuple[list[float], list[float]]:
"""Return (vertical_seam_xs, horizontal_seam_ys) for inner tile borders.
For 2x2 tiling with overlap, each internal seam corresponds to TWO
pixel x-coordinates: where one tile ends and where the adjacent tile
begins. Both are returned so the FP filter can match either edge of
the overlap band.
"""
if not tiles:
return [], []
parent_w = tiles[0].parent_w
parent_h = tiles[0].parent_h
vx: set[float] = set()
hy: set[float] = set()
for t in tiles:
if t.x0 > 0:
vx.add(float(t.x0))
if t.x0 + t.w < parent_w:
vx.add(float(t.x0 + t.w))
if t.y0 > 0:
hy.add(float(t.y0))
if t.y0 + t.h < parent_h:
hy.add(float(t.y0 + t.h))
return sorted(vx), sorted(hy)
def filter_seam_artifacts(
    merged: dict,
    tiles: list[Tile],
    types: tuple[str, ...] = ("inlet/outlet",),
    seam_threshold: float = 50.0,
    edge_threshold: float = 30.0,
) -> dict:
    """Drop nodes that look like tile-boundary false positives.

    A node is filtered iff ALL of the following hold:
      * its `type` is in `types` (default: inlet/outlet only),
      * its bbox center lies within `seam_threshold` pixels of any inner
        tile seam (vertical OR horizontal), AND
      * its bbox center is NOT within `edge_threshold` pixels of the
        outer image border (real boundary inlet/outlets stay).

    Edges referencing dropped nodes are also removed. The dropped nodes
    are recorded under `seam_filtered` so the report can show what was
    pruned. Nodes whose `type` is not in `types`, or that have no usable
    4-element `bbox`, are passed through untouched.

    Args:
        merged: graph dict as produced by `merge_tile_graphs` (node `bbox`
            in global pixel coordinates).
        tiles: the tiles the graph was built from; they define the seams
            and the parent image size.
        types: node types eligible for filtering.
        seam_threshold: max center distance (px) to a seam line.
        edge_threshold: margin (px) from the outer border inside which
            nodes are always kept.

    Returns:
        A shallow copy of `merged` with `nodes`, `edges` and
        `seam_filtered` replaced. The input dict is not mutated. With no
        tiles there are no seams, so nothing is dropped and `seam_filtered`
        is an empty list. The output shape is the same on every path.
    """
    out = dict(merged)
    if not tiles:
        out["seam_filtered"] = []
        return out

    parent_w = tiles[0].parent_w
    parent_h = tiles[0].parent_h
    vx, hy = _inner_seam_lines(tiles)
    types_set = set(types)

    keep_ids: set[str] = set()
    new_nodes: list[dict] = []
    dropped: list[dict] = []
    for n in merged["nodes"]:
        bbox = n.get("bbox")
        if n.get("type") in types_set and bbox and len(bbox) == 4:
            cx = (bbox[0] + bbox[2]) / 2.0
            cy = (bbox[1] + bbox[3]) / 2.0
            near_outer = (
                cx < edge_threshold
                or cx > parent_w - edge_threshold
                or cy < edge_threshold
                or cy > parent_h - edge_threshold
            )
            near_seam = any(abs(cx - x) <= seam_threshold for x in vx) or any(
                abs(cy - y) <= seam_threshold for y in hy
            )
            if near_seam and not near_outer:
                dropped.append({"id": n["id"], "type": n["type"], "bbox": bbox})
                continue
        keep_ids.add(n["id"])
        new_nodes.append(n)

    out["nodes"] = new_nodes
    out["edges"] = [
        e
        for e in merged["edges"]
        if e["source"] in keep_ids and e["target"] in keep_ids
    ]
    out["seam_filtered"] = dropped
    return out