Spaces:
Running
Running
| """Load PID2Graph ground-truth graphml files and normalize to the common schema. | |
| PID2Graph OPEN100 schema (observed): | |
| node attrs: label (category), xmin, xmax, ymin, ymax (bounding box) | |
| edge attrs: edge_label (only 'solid' observed) | |
| graph: undirected | |
| Ten node categories are used across the dataset: | |
| connector, crossing, arrow, instrumentation, valve, general, | |
| inlet/outlet, background, tank, pump | |
| There are no printed tags (like "P-101") in the graphml — PID2Graph is a | |
| pure symbol + connectivity benchmark, not an OCR benchmark. The `label` | |
| field in our normalized dict is therefore always None for this dataset | |
| and metrics fall back to type-only matching. | |
| """ | |
| from __future__ import annotations | |
| from pathlib import Path | |
| from typing import Iterable, Optional | |
| import networkx as nx | |
# Attribute names probed, in priority order, when reading a node's / an
# edge's type out of graphml. PID2Graph itself uses `label` (nodes) and
# `edge_label` (edges); the remaining candidates let other graphml-based
# datasets with different naming conventions load through the same code.
NODE_TYPE_KEYS: tuple[str, ...] = ("label", "type", "category", "class")
EDGE_TYPE_KEYS: tuple[str, ...] = ("edge_label", "type", "category")

# The ten official PID2Graph OPEN100 node categories, in a fixed order so
# extractor.py can paste them verbatim into the VLM prompt.
PID2GRAPH_NODE_TYPES: tuple[str, ...] = (
    "connector", "crossing", "arrow", "instrumentation", "valve",
    "general", "inlet/outlet", "background", "tank", "pump",
)

# Categories scored by the "semantic-only" evaluation mode: every official
# category except the line-level primitives and the `general` catch-all,
# whose shape definition is too vague for zero-shot VLM detection.
SEMANTIC_EQUIPMENT_TYPES: frozenset[str] = frozenset(PID2GRAPH_NODE_TYPES) - {
    "connector",   # line-level primitive
    "crossing",    # line-level primitive
    "arrow",       # line-level primitive
    "background",  # line-level primitive
    "general",     # too vague a catch-all
}
| def _norm_type(t: Optional[str]) -> str: | |
| return (t or "").strip().lower() | |
def filter_by_types(graph: dict, allowed: frozenset[str]) -> dict:
    """Return a copy of `graph` restricted to nodes whose type is in `allowed`.

    A node's type is compared after trimming whitespace and lowercasing
    (a missing/None type compares as ""). An edge survives only if both
    of its endpoints survive. Every key other than `nodes` and `edges`
    (e.g. `directed`, `tile_stats`, `seam_filtered`) is carried over
    unchanged so downstream code can still inspect provenance. The input
    dict is not mutated; kept node/edge dicts are shared, not copied.
    """
    kept_nodes = [
        node
        for node in graph["nodes"]
        if (node.get("type") or "").strip().lower() in allowed
    ]
    surviving_ids = {node["id"] for node in kept_nodes}
    kept_edges = [
        edge
        for edge in graph["edges"]
        if edge["source"] in surviving_ids and edge["target"] in surviving_ids
    ]
    # Existing keys keep their position; only the two lists are replaced.
    return {**graph, "nodes": kept_nodes, "edges": kept_edges}
def collapse_through_primitives(graph: dict, semantic_types: frozenset[str]) -> dict:
    """Keep only semantic nodes; re-wire edges by walking through primitives.

    Two semantic nodes are connected in the result iff there is a path
    between them in the original graph consisting of zero or more
    NON-semantic nodes (e.g. `connector`, `crossing`, `arrow`). This
    matches what the VLM is asked to produce: one direct semantic-to-
    semantic edge per physical pipeline, regardless of how many pipe
    junctions it passes through.

    Implementation: each connected component of non-semantic nodes is
    flood-filled exactly once, and every semantic node bordering that
    component is paired with every other one; direct semantic-semantic
    edges are added separately. This costs O(V + E) plus the size of the
    output, instead of one full traversal per semantic node.

    The resulting graph is always treated as undirected — PID2Graph's
    underlying graphml is undirected and path-based equivalence has no
    natural orientation.

    Args:
        graph: Normalized graph dict with `nodes` (each having an `id`
            and optionally a `type`) and `edges` (each having `source`
            and `target`). Edges whose endpoints are not listed in
            `nodes` are ignored.
        semantic_types: Type names treated as semantic, compared against
            each node's type after `_norm_type` (trim + lowercase).

    Returns:
        A new graph dict whose `nodes` are the semantic nodes (original
        dicts, original order) and whose `edges` are one record per
        unordered semantic pair, sorted, with `source` <= `target`,
        type "solid" and no self-loops. `directed` is always False. Any
        other keys of `graph` are passed through unchanged, as in
        `filter_by_types`.
    """
    from itertools import combinations

    sem_ids: set[str] = {
        n["id"] for n in graph["nodes"] if _norm_type(n.get("type")) in semantic_types
    }

    # Undirected adjacency for the full graph.
    adj: dict[str, list[str]] = {n["id"]: [] for n in graph["nodes"]}
    for e in graph["edges"]:
        s, t = e["source"], e["target"]
        if s in adj and t in adj:
            adj[s].append(t)
            adj[t].append(s)

    new_edges: set[tuple[str, str]] = set()

    # Paths through zero primitives: direct semantic-to-semantic edges.
    for s in sem_ids:
        for nb in adj[s]:
            if nb in sem_ids and nb != s:
                new_edges.add((s, nb) if s < nb else (nb, s))

    # Paths through one or more primitives: flood-fill each primitive
    # component once (iterative DFS) and collect the semantic nodes on
    # its border; those are all mutually reachable via primitives only.
    seen: set[str] = set()
    for root in adj:
        if root in sem_ids or root in seen:
            continue
        seen.add(root)
        stack: list[str] = [root]
        border: set[str] = set()
        while stack:
            cur = stack.pop()
            for nb in adj[cur]:
                if nb in sem_ids:
                    # Semantic boundary: record it, never expand past it.
                    border.add(nb)
                elif nb not in seen:
                    seen.add(nb)
                    stack.append(nb)
        # combinations over a sorted set yields (a, b) with a < b and no
        # self-pairs, matching the canonical orientation used above.
        new_edges.update(combinations(sorted(border), 2))

    new_nodes = [n for n in graph["nodes"] if n["id"] in sem_ids]
    new_edges_list = [
        {
            "source": a,
            "target": b,
            "type": "solid",
            "label": None,
            "raw_attrs": {},
        }
        for a, b in sorted(new_edges)
    ]

    # Pass provenance keys through (consistent with filter_by_types) and
    # override the graph payload plus the orientation flag.
    out = dict(graph)
    out["nodes"] = new_nodes
    out["edges"] = new_edges_list
    out["directed"] = False
    return out
| def _first_attr(attrs: dict, keys: Iterable[str]) -> Optional[str]: | |
| for k in keys: | |
| v = attrs.get(k) | |
| if v is None: | |
| continue | |
| s = str(v).strip() | |
| if s: | |
| return s | |
| return None | |
| def _bbox(attrs: dict) -> Optional[list[float]]: | |
| try: | |
| return [ | |
| float(attrs["xmin"]), | |
| float(attrs["ymin"]), | |
| float(attrs["xmax"]), | |
| float(attrs["ymax"]), | |
| ] | |
| except (KeyError, TypeError, ValueError): | |
| return None | |
def load_graphml(path: Path) -> dict:
    """Read a graphml file and return it as `{nodes, edges, directed}`.

    Node records carry `id` (stringified), `type` (first non-blank value
    among NODE_TYPE_KEYS, or ""), `label` (always None), `bbox` (from the
    xmin/ymin/xmax/ymax attributes, or None) and `raw_attrs`. Edge records
    carry `source`/`target` (stringified), `type` (first non-blank value
    among EDGE_TYPE_KEYS, or None), `label` (always None) and `raw_attrs`.
    `raw_attrs` is a copy of the original attribute dict so experiments
    can try alternative fields without re-reading the file.
    """
    nx_graph = nx.read_graphml(path)

    node_records: list[dict] = [
        {
            "id": str(node_key),
            "type": _first_attr(data, NODE_TYPE_KEYS) or "",
            "label": None,  # PID2Graph has no printed tag in GT
            "bbox": _bbox(data),
            "raw_attrs": dict(data),
        }
        for node_key, data in nx_graph.nodes(data=True)
    ]

    edge_records: list[dict] = [
        {
            "source": str(head),
            "target": str(tail),
            "type": _first_attr(data, EDGE_TYPE_KEYS),
            "label": None,
            "raw_attrs": dict(data),
        }
        for head, tail, data in nx_graph.edges(data=True)
    ]

    return {
        "nodes": node_records,
        "edges": edge_records,
        "directed": nx_graph.is_directed(),
    }
def summarize(graph: dict) -> dict:
    """Return quick stats for sanity-checking the loader on a new dataset.

    Node and edge types are tallied in first-seen order; a falsy type
    ("" or None) is counted under "<empty>". `directed` defaults to False
    when the graph dict has no such key.
    """

    def _tally(kinds) -> dict[str, int]:
        counts: dict[str, int] = {}
        for kind in kinds:
            counts[kind] = counts.get(kind, 0) + 1
        return counts

    node_types = _tally(node["type"] or "<empty>" for node in graph["nodes"])
    edge_types = _tally(edge["type"] or "<empty>" for edge in graph["edges"])

    return {
        "n_nodes": len(graph["nodes"]),
        "n_edges": len(graph["edges"]),
        "directed": graph.get("directed", False),
        "node_types": node_types,
        "edge_types": edge_types,
    }