"""Load PID2Graph ground-truth graphml files and normalize to the common schema. PID2Graph OPEN100 schema (observed): node attrs: label (category), xmin, xmax, ymin, ymax (bounding box) edge attrs: edge_label (only 'solid' observed) graph: undirected Ten node categories are used across the dataset: connector, crossing, arrow, instrumentation, valve, general, inlet/outlet, background, tank, pump There are no printed tags (like "P-101") in the graphml — PID2Graph is a pure symbol + connectivity benchmark, not an OCR benchmark. The `label` field in our normalized dict is therefore always None for this dataset and metrics fall back to type-only matching. """ from __future__ import annotations from pathlib import Path from typing import Iterable, Optional import networkx as nx # The PID2Graph graphml files store the category under `label`. Kept as a # tuple so other graphml-based datasets with different conventions can be # supported by extending the candidate list. NODE_TYPE_KEYS: tuple[str, ...] = ("label", "type", "category", "class") EDGE_TYPE_KEYS: tuple[str, ...] = ("edge_label", "type", "category") # The official PID2Graph OPEN100 categories. Exposed so extractor.py can # put them straight into the VLM prompt. PID2GRAPH_NODE_TYPES: tuple[str, ...] = ( "connector", "crossing", "arrow", "instrumentation", "valve", "general", "inlet/outlet", "background", "tank", "pump", ) # The subset used by the "semantic-only" evaluation mode: real equipment # and instrument symbols, excluding line-level primitives (connector / # crossing / arrow / background) AND the `general` catch-all, whose # shape definition is too vague for zero-shot VLM detection to handle. 
SEMANTIC_EQUIPMENT_TYPES: frozenset[str] = frozenset(
    {
        "valve",
        "pump",
        "tank",
        "instrumentation",
        "inlet/outlet",
    }
)


def _norm_type(t: Optional[str]) -> str:
    """Return `t` stripped and lower-cased; None (or empty) becomes ""."""
    return (t or "").strip().lower()


def _norm_type_set(types: Iterable[str]) -> frozenset[str]:
    """Normalize every entry of `types` the same way node types are normalized.

    Node types are compared after `_norm_type`, so the allowed set has to go
    through the same normalization or entries such as "Valve" never match.
    """
    return frozenset(_norm_type(t) for t in types)


def filter_by_types(graph: dict, allowed: frozenset[str]) -> dict:
    """Return a copy of `graph` keeping only nodes whose type is in `allowed`.

    Type comparison is case- and whitespace-insensitive on both sides.
    Edges are kept only when BOTH endpoints survive the filter. All
    non-{nodes, edges} keys (e.g. `directed`, `tile_stats`, `seam_filtered`)
    are passed through unchanged so downstream code can still inspect
    provenance.

    The copy is shallow: the surviving node and edge dicts are the same
    objects as in the input graph; only the containers are new.
    """
    allowed_norm = _norm_type_set(allowed)

    new_nodes: list[dict] = [
        n for n in graph["nodes"] if _norm_type(n.get("type")) in allowed_norm
    ]
    keep_ids: set[str] = {n["id"] for n in new_nodes}
    new_edges: list[dict] = [
        e
        for e in graph["edges"]
        if e["source"] in keep_ids and e["target"] in keep_ids
    ]

    out = dict(graph)
    out["nodes"] = new_nodes
    out["edges"] = new_edges
    return out


def collapse_through_primitives(graph: dict, semantic_types: frozenset[str]) -> dict:
    """Keep only semantic nodes; re-wire edges by walking through primitives.

    Two semantic nodes are connected in the result iff there is a path
    between them in the original graph consisting of zero or more
    NON-semantic nodes (e.g. `connector`, `crossing`, `arrow`). This matches
    what the VLM is asked to produce: one direct semantic-to-semantic edge
    per physical pipeline, regardless of how many pipe junctions it passes
    through.

    Type comparison is case- and whitespace-insensitive on both sides.

    The resulting graph is always treated as undirected — PID2Graph's
    underlying graphml is undirected and path-based equivalence has no
    natural orientation. Unlike `filter_by_types`, only `nodes`, `edges`
    and `directed` are returned; other keys of `graph` are not carried over.
    Every synthesized edge has type "solid", no label and empty `raw_attrs`,
    and the edge list is sorted by (source, target).
    """
    semantic = _norm_type_set(semantic_types)
    sem_ids: set[str] = {
        n["id"] for n in graph["nodes"] if _norm_type(n.get("type")) in semantic
    }

    # Undirected adjacency for the full graph. Edges that reference an id
    # not present in `nodes` are ignored.
    adj: dict[str, list[str]] = {n["id"]: [] for n in graph["nodes"]}
    for e in graph["edges"]:
        s, t = e["source"], e["target"]
        if s in adj and t in adj:
            adj[s].append(t)
            adj[t].append(s)

    new_edges: set[tuple[str, str]] = set()

    # Depth-first walk (LIFO stack) from each semantic node through
    # primitive nodes; whenever we land on another semantic node, record
    # the edge and stop expanding past it. Traversal order does not affect
    # the result because only reachability matters. Visiting primitives
    # multiple times from different starting points is fine; the edge-set
    # deduplicates results.
    for start in sem_ids:
        visited = {start}
        stack: list[str] = [start]
        while stack:
            cur = stack.pop()
            for nb in adj.get(cur, ()):
                if nb in visited:
                    continue
                visited.add(nb)
                if nb in sem_ids:
                    # Store endpoints in sorted order so (a, b) and (b, a)
                    # collapse to a single undirected edge.
                    a, b = sorted((start, nb))
                    new_edges.add((a, b))
                    # Don't recurse past a semantic boundary.
                else:
                    stack.append(nb)

    new_nodes = [n for n in graph["nodes"] if n["id"] in sem_ids]
    new_edges_list = [
        {
            "source": a,
            "target": b,
            "type": "solid",
            "label": None,
            "raw_attrs": {},
        }
        for a, b in sorted(new_edges)
    ]
    return {
        "nodes": new_nodes,
        "edges": new_edges_list,
        "directed": False,
    }


def _first_attr(attrs: dict, keys: Iterable[str]) -> Optional[str]:
    """Return the first non-empty attribute among `keys`, as a stripped string.

    Keys are tried in order. Values that are None or that stringify to
    whitespace only are skipped. Returns None when no key yields a value.
    """
    for k in keys:
        v = attrs.get(k)
        if v is None:
            continue
        s = str(v).strip()
        if s:
            return s
    return None


def _bbox(attrs: dict) -> Optional[list[float]]:
    """Return `[xmin, ymin, xmax, ymax]` as floats, or None if unavailable.

    None is returned when any of the four attributes is missing or cannot
    be converted to float.
    """
    try:
        return [
            float(attrs["xmin"]),
            float(attrs["ymin"]),
            float(attrs["xmax"]),
            float(attrs["ymax"]),
        ]
    except (KeyError, TypeError, ValueError):
        return None


def load_graphml(path: Path) -> dict:
    """Parse a graphml file into `{nodes, edges, directed}`.

    Node and edge ids are converted to `str`. A node's `type` is the first
    non-empty attribute among `NODE_TYPE_KEYS` (or "" when none is set); an
    edge's `type` is the first non-empty attribute among `EDGE_TYPE_KEYS`
    (or None when none is set). `label` is always None.

    Each node/edge keeps its original attributes under `raw_attrs` so
    experiments can try alternative fields without re-reading the file.

    Errors raised by `networkx.read_graphml` (missing or malformed file)
    are not caught here.
    """
    G = nx.read_graphml(path)

    nodes: list[dict] = []
    for node_id, attrs in G.nodes(data=True):
        nodes.append(
            {
                "id": str(node_id),
                "type": _first_attr(attrs, NODE_TYPE_KEYS) or "",
                "label": None,  # PID2Graph has no printed tag in GT
                "bbox": _bbox(attrs),
                "raw_attrs": dict(attrs),
            }
        )

    edges: list[dict] = []
    for u, v, attrs in G.edges(data=True):
        edges.append(
            {
                "source": str(u),
                "target": str(v),
                "type": _first_attr(attrs, EDGE_TYPE_KEYS),
                "label": None,
                "raw_attrs": dict(attrs),
            }
        )

    return {
        "nodes": nodes,
        "edges": edges,
        "directed": G.is_directed(),
    }


def _count_by_type(items: Iterable[dict]) -> dict[str, int]:
    """Count `items` by their raw `type` value; missing/None/empty count as ""."""
    counts: dict[str, int] = {}
    for item in items:
        t = item.get("type") or ""
        counts[t] = counts.get(t, 0) + 1
    return counts


def summarize(graph: dict) -> dict:
    """Quick stats for sanity-checking the loader on a new dataset.

    Returns node and edge totals, the `directed` flag (False when absent)
    and per-type counts for nodes and edges. Types are counted as stored
    (no case/whitespace normalization); a missing, None or empty type is
    counted under "".
    """
    return {
        "n_nodes": len(graph["nodes"]),
        "n_edges": len(graph["edges"]),
        "directed": graph.get("directed", False),
        "node_types": _count_by_type(graph["nodes"]),
        "edge_types": _count_by_type(graph["edges"]),
    }