deepkick's picture
Initial commit: PID2Graph × Claude VLM evaluation + Gradio demo
59fa244
"""Load PID2Graph ground-truth graphml files and normalize to the common schema.
PID2Graph OPEN100 schema (observed):
node attrs: label (category), xmin, xmax, ymin, ymax (bounding box)
edge attrs: edge_label (only 'solid' observed)
graph: undirected
Ten node categories are used across the dataset:
connector, crossing, arrow, instrumentation, valve, general,
inlet/outlet, background, tank, pump
There are no printed tags (like "P-101") in the graphml — PID2Graph is a
pure symbol + connectivity benchmark, not an OCR benchmark. The `label`
field in our normalized dict is therefore always None for this dataset
and metrics fall back to type-only matching.
"""
from __future__ import annotations
from pathlib import Path
from typing import Iterable, Optional
import networkx as nx
# Candidate attribute names for the node / edge category, in priority
# order: `load_graphml` hands these to `_first_attr`, which returns the
# first one holding a non-empty value. PID2Graph itself stores the node
# category under `label` and the edge category under `edge_label`; the
# remaining names are fallbacks so other graphml-based datasets with
# different conventions can be supported by extending the tuples.
NODE_TYPE_KEYS: tuple[str, ...] = ("label", "type", "category", "class")
EDGE_TYPE_KEYS: tuple[str, ...] = ("edge_label", "type", "category")
# The ten official PID2Graph OPEN100 node categories, in a fixed order.
# Exposed at module level so extractor.py can put them straight into the
# VLM prompt.
PID2GRAPH_NODE_TYPES: tuple[str, ...] = (
    "connector", "crossing", "arrow", "instrumentation", "valve",
    "general", "inlet/outlet", "background", "tank", "pump",
)
# Node types kept by the "semantic-only" evaluation mode: real equipment
# and instrument symbols. Line-level primitives (connector / crossing /
# arrow / background) are left out, and so is the `general` catch-all,
# whose shape definition is too vague for zero-shot VLM detection to
# handle.
SEMANTIC_EQUIPMENT_TYPES: frozenset[str] = frozenset(
    ("valve", "pump", "tank", "instrumentation", "inlet/outlet")
)
def _norm_type(t: Optional[str]) -> str:
return (t or "").strip().lower()
def filter_by_types(graph: dict, allowed: frozenset[str]) -> dict:
    """Return a shallow copy of `graph` restricted to nodes typed in `allowed`.

    A node survives when its normalized (trimmed, lower-cased) `type` is
    a member of `allowed`. An edge survives only when BOTH of its
    endpoints survive. Every other top-level key (e.g. `directed`,
    `tile_stats`, `seam_filtered`) is carried over unchanged so
    downstream code can still inspect provenance. The surviving node and
    edge dicts are the same objects as in the input, not copies.
    """
    surviving = [
        node for node in graph["nodes"] if _norm_type(node.get("type")) in allowed
    ]
    surviving_ids = {node["id"] for node in surviving}

    result = {**graph}
    result["nodes"] = surviving
    result["edges"] = [
        edge
        for edge in graph["edges"]
        if edge["source"] in surviving_ids and edge["target"] in surviving_ids
    ]
    return result
def collapse_through_primitives(graph: dict, semantic_types: frozenset[str]) -> dict:
    """Drop non-semantic nodes and reconnect the semantic ones across them.

    Two semantic nodes end up joined in the result iff the original
    graph contains a path between them whose interior consists of zero
    or more NON-semantic nodes (e.g. `connector`, `crossing`, `arrow`).
    This mirrors what the VLM is asked to produce: one direct
    semantic-to-semantic edge per physical pipeline, no matter how many
    pipe junctions lie along it.

    The result is always undirected — PID2Graph's underlying graphml is
    undirected and path-based equivalence has no natural orientation.
    Only `nodes`, `edges` and `directed` are returned; other top-level
    keys of `graph` are not carried over. Every synthesized edge has
    type "solid", no label and empty `raw_attrs`, and edges are emitted
    in sorted (source, target) order.
    """
    nodes = graph["nodes"]
    semantic = {n["id"] for n in nodes if _norm_type(n.get("type")) in semantic_types}

    # Symmetric adjacency over the whole graph; edges that mention an
    # unknown node id are ignored.
    neighbours: dict[str, list[str]] = {n["id"]: [] for n in nodes}
    for edge in graph["edges"]:
        src, dst = edge["source"], edge["target"]
        if src not in neighbours or dst not in neighbours:
            continue
        neighbours[src].append(dst)
        neighbours[dst].append(src)

    # Stack-driven (depth-first) walk from every semantic node through
    # primitive nodes. Reaching another semantic node records a pair and
    # the walk does not continue past it. Primitives may be revisited
    # from different origins; the set of pairs removes duplicates.
    pairs: set[tuple[str, str]] = set()
    for origin in semantic:
        seen = {origin}
        frontier: list[str] = [origin]
        while frontier:
            here = frontier.pop()
            for nxt in neighbours.get(here, ()):
                if nxt in seen:
                    continue
                seen.add(nxt)
                if nxt not in semantic:
                    frontier.append(nxt)
                    continue
                # Semantic boundary: store the pair in canonical order.
                low, high = sorted((origin, nxt))
                pairs.add((low, high))

    return {
        "nodes": [n for n in nodes if n["id"] in semantic],
        "edges": [
            {
                "source": low,
                "target": high,
                "type": "solid",
                "label": None,
                "raw_attrs": {},
            }
            for low, high in sorted(pairs)
        ],
        "directed": False,
    }
def _first_attr(attrs: dict, keys: Iterable[str]) -> Optional[str]:
for k in keys:
v = attrs.get(k)
if v is None:
continue
s = str(v).strip()
if s:
return s
return None
def _bbox(attrs: dict) -> Optional[list[float]]:
try:
return [
float(attrs["xmin"]),
float(attrs["ymin"]),
float(attrs["xmax"]),
float(attrs["ymax"]),
]
except (KeyError, TypeError, ValueError):
return None
def load_graphml(path: Path) -> dict:
    """Read a graphml file and return it as `{nodes, edges, directed}`.

    Node records carry `id` (stringified), `type` (first non-empty
    value among NODE_TYPE_KEYS, or "" when none is set), `label`
    (always None), `bbox` and `raw_attrs`. Edge records carry `source`
    and `target` (stringified), `type` (first non-empty value among
    EDGE_TYPE_KEYS, or None), `label` (always None) and `raw_attrs`.
    `raw_attrs` is a copy of the original attributes so experiments can
    try alternative fields without re-reading the file.
    """
    parsed = nx.read_graphml(path)

    node_records: list[dict] = [
        {
            "id": str(node_id),
            "type": _first_attr(attrs, NODE_TYPE_KEYS) or "",
            "label": None,  # PID2Graph has no printed tag in GT
            "bbox": _bbox(attrs),
            "raw_attrs": dict(attrs),
        }
        for node_id, attrs in parsed.nodes(data=True)
    ]

    edge_records: list[dict] = [
        {
            "source": str(src),
            "target": str(dst),
            "type": _first_attr(attrs, EDGE_TYPE_KEYS),
            "label": None,
            "raw_attrs": dict(attrs),
        }
        for src, dst, attrs in parsed.edges(data=True)
    ]

    return {
        "nodes": node_records,
        "edges": edge_records,
        "directed": parsed.is_directed(),
    }
def summarize(graph: dict) -> dict:
    """Quick stats for sanity-checking the loader on a new dataset.

    Args:
        graph: Graph dict with `nodes` and `edges` lists. The `directed`
            flag is optional and is reported as False when absent.

    Returns:
        A dict with `n_nodes`, `n_edges`, `directed`, and two histograms,
        `node_types` and `edge_types`, mapping each type to its count.
        Records whose `type` is missing, None or empty are counted under
        "<empty>".
    """

    def _tally(items: list[dict]) -> dict[str, int]:
        # Read the type with `.get` rather than indexing, as
        # `filter_by_types` and `collapse_through_primitives` do, so a
        # record without a `type` key is counted instead of raising
        # KeyError.
        counts: dict[str, int] = {}
        for item in items:
            kind = item.get("type") or "<empty>"
            counts[kind] = counts.get(kind, 0) + 1
        return counts

    nodes = graph["nodes"]
    edges = graph["edges"]
    return {
        "n_nodes": len(nodes),
        "n_edges": len(edges),
        "directed": graph.get("directed", False),
        "node_types": _tally(nodes),
        "edge_types": _tally(edges),
    }