"""Gradio demo: P&ID graph extraction with Claude VLM + evaluation. Usage (local): ANTHROPIC_API_KEY=sk-ant-... python app.py Or put the key in a `.env` next to this file. The app: 1. Takes a P&ID image (preset or upload) 2. Runs extraction (optionally tiled 2x2) via Claude Opus 4.6 3. If a ground-truth graphml is provided, collapses it to semantic-only form and computes node/edge P/R/F1 via `pid2graph_eval.metrics` 4. Draws both the prediction and the ground truth as NetworkX graphs using bbox-based layouts so the topology matches the source image """ from __future__ import annotations import json import os import time from pathlib import Path from typing import Optional # Matplotlib backend must be set before pyplot import for headless use; # the `matplotlib.use()` call below taints every subsequent import with # E402 ("module level import not at top of file"), which is expected here. import matplotlib matplotlib.use("Agg") import matplotlib.patches as mpatches # noqa: E402 import matplotlib.pyplot as plt # noqa: E402 import networkx as nx # noqa: E402 import anthropic # noqa: E402 # --------------------------------------------------------------------------- # Gradio 4.44 / gradio_client 1.3.0 bug workaround # --------------------------------------------------------------------------- # At `/info` boot, Gradio walks every component's JSON schema via # `gradio_client.utils._json_schema_to_python_type`. That function does not # handle bool schemas (`additionalProperties: false` or `true`, both of which # are valid JSON Schema) — it recurses with the bool and then `if "const" in # schema:` on line 863 raises `TypeError: argument of type 'bool' is not # iterable`. Patch the function here before importing gradio so the crash is # avoided regardless of which component triggers it. (Fixed upstream in later # gradio_client releases; we stay on 4.44 because Python 3.9 can't run # gradio 5.) Harmless on versions where the bug is already fixed. 
import gradio_client.utils as _gc_utils  # noqa: E402

# Keep a handle on the unpatched converter so the wrapper can delegate.
_orig_json_schema_to_python_type = _gc_utils._json_schema_to_python_type


def _patched_json_schema_to_python_type(schema, defs=None):  # type: ignore[override]
    """Bool-tolerant wrapper around gradio_client's schema-to-type converter.

    JSON Schema allows a schema to be the bare boolean ``True`` (accept any
    value) or ``False`` (accept no value); the 1.3.0 converter chokes on
    those, so handle them here and delegate everything else unchanged.
    """
    if isinstance(schema, bool):
        if schema:
            return "Any"
        return "None"
    return _orig_json_schema_to_python_type(schema, defs)


_gc_utils._json_schema_to_python_type = _patched_json_schema_to_python_type

import gradio as gr  # noqa: E402 (must come after the monkey-patch)
from dotenv import load_dotenv  # noqa: E402

from pid2graph_eval.extractor import (  # noqa: E402
    DEFAULT_MODEL,
    extract_graph,
    extract_graph_tiled,
)
from pid2graph_eval.gt_loader import (  # noqa: E402
    SEMANTIC_EQUIPMENT_TYPES,
    collapse_through_primitives,
    filter_by_types,
    load_graphml,
)
from pid2graph_eval.metrics import evaluate  # noqa: E402

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
APP_ROOT = Path(__file__).parent
SAMPLES_DIR = APP_ROOT / "samples"

# Pick up ANTHROPIC_API_KEY (and friends) from a `.env` next to this file.
load_dotenv(APP_ROOT / ".env")

# Presets: (display name) -> (image path, graphml path)
PRESETS: dict[str, tuple[Path, Path]] = {
    "OPEN100 #1 — small (27 semantic nodes)": (
        SAMPLES_DIR / "open100_01_small.png",
        SAMPLES_DIR / "open100_01_small.graphml",
    ),
    "OPEN100 #3 — medium (53 semantic nodes)": (
        SAMPLES_DIR / "open100_03_medium.png",
        SAMPLES_DIR / "open100_03_medium.graphml",
    ),
    "OPEN100 #0 — large (82 semantic nodes)": (
        SAMPLES_DIR / "open100_00_large.png",
        SAMPLES_DIR / "open100_00_large.graphml",
    ),
}

NONE_LABEL = "(none — upload your own)"

# Fixed palette so pred/GT visualizations use matching colors.
# One color per semantic equipment type; unknown types fall back to grey
# inside `draw_graph`. Shared by both prediction and GT figures so the two
# plots are directly comparable.
TYPE_COLORS: dict[str, str] = {
    "valve": "#ff6b6b",
    "pump": "#4ecdc4",
    "tank": "#ffd93d",
    "instrumentation": "#6bcfff",
    "inlet/outlet": "#c47bff",
}

# Pre-built legend patches (one per type) reused by every figure.
LEGEND_HANDLES = [
    mpatches.Patch(color=c, label=t) for t, c in TYPE_COLORS.items()
]


# ---------------------------------------------------------------------------
# Visualization
# ---------------------------------------------------------------------------
def _bbox_to_xyxy(bbox) -> Optional[tuple[float, float, float, float]]:
    """Normalize a bbox to `(xmin, ymin, xmax, ymax)` floats.

    Accepts both shapes that flow through the pipeline:

    * **list / tuple** `[x1, y1, x2, y2]` — produced by `gt_loader._bbox`
      and by `tile.merge_tile_graphs` for the tiled pred path.
    * **dict** `{"xmin": ..., "ymin": ..., "xmax": ..., "ymax": ...}` —
      produced by `GraphOut.to_dict()` in single-shot mode, because the
      Pydantic `BBox` model round-trips through `model_dump()`.

    Returns `None` if the bbox is missing or malformed.
    """
    if bbox is None:
        return None
    if isinstance(bbox, dict):
        try:
            return (
                float(bbox["xmin"]),
                float(bbox["ymin"]),
                float(bbox["xmax"]),
                float(bbox["ymax"]),
            )
        except (KeyError, TypeError, ValueError):
            # Missing key or a non-numeric coordinate — treat as "no bbox".
            return None
    if isinstance(bbox, (list, tuple)) and len(bbox) == 4:
        try:
            return (
                float(bbox[0]),
                float(bbox[1]),
                float(bbox[2]),
                float(bbox[3]),
            )
        except (TypeError, ValueError):
            return None
    # Anything else (wrong length, wrong type) is malformed.
    return None


def draw_graph(graph_dict: dict, title: str, figsize: tuple[int, int] = (8, 6)) -> plt.Figure:
    """Render a graph as a matplotlib figure.

    Node positions come from bbox centers when available (so the drawing
    preserves the spatial layout of the original P&ID); nodes without a
    bbox fall back to networkx spring layout.
    """
    fig, ax = plt.subplots(figsize=figsize, dpi=110)
    G = nx.Graph()
    pos: dict[str, tuple[float, float]] = {}
    # `colors` is kept index-aligned with `node_list` for draw_networkx_nodes.
    colors: list[str] = []
    node_list: list[str] = []
    for n in graph_dict.get("nodes", []):
        nid = n["id"]
        G.add_node(nid)
        node_list.append(nid)
        # Unknown/missing type renders grey so the node is still visible.
        colors.append(TYPE_COLORS.get(n.get("type", ""), "#cccccc"))
        coords = _bbox_to_xyxy(n.get("bbox"))
        if coords is not None:
            x1, y1, x2, y2 = coords
            cx = (x1 + x2) / 2.0
            cy = (y1 + y2) / 2.0
            # Image coordinates grow downward; matplotlib's y grows upward.
            pos[nid] = (cx, -cy)  # flip y so the image is right-side up
    for e in graph_dict.get("edges", []):
        s, t = e.get("source"), e.get("target")
        # Skip edges that reference nodes we never saw (defensive against
        # inconsistent VLM output).
        if s in G.nodes and t in G.nodes:
            G.add_edge(s, t)
    # Fall back to spring layout for any nodes that lack a bbox.
    missing = [nid for nid in G.nodes if nid not in pos]
    if missing:
        if not pos:
            # No node had a bbox at all — lay out the whole graph.
            pos = nx.spring_layout(G, seed=42)
        else:
            # Place missing nodes near the existing bbox cloud center.
            xs = [p[0] for p in pos.values()]
            ys = [p[1] for p in pos.values()]
            cx0 = sum(xs) / len(xs)
            cy0 = sum(ys) / len(ys)
            for nid in missing:
                pos[nid] = (cx0, cy0)
    nx.draw_networkx_edges(G, pos, alpha=0.35, width=0.6, ax=ax)
    nx.draw_networkx_nodes(
        G,
        pos,
        nodelist=node_list,
        node_color=colors,
        node_size=55,
        linewidths=0.5,
        edgecolors="#222",
        ax=ax,
    )
    ax.set_title(title, fontsize=11)
    ax.set_aspect("equal")
    ax.axis("off")
    ax.legend(handles=LEGEND_HANDLES, loc="lower right", fontsize=7, framealpha=0.9)
    fig.tight_layout()
    return fig


# ---------------------------------------------------------------------------
# Pipeline
# ---------------------------------------------------------------------------
def _preset_paths(preset_name: str) -> tuple[Optional[str], Optional[str]]:
    """Resolve a preset dropdown selection to (image_path, graphml_path).

    Returns `(None, None)` for the "(none)" sentinel or an unknown name;
    each returned path is also `None` if the sample file is missing on disk.
    """
    if preset_name == NONE_LABEL or preset_name not in PRESETS:
        return None, None
    img, gt = PRESETS[preset_name]
    return (str(img) if img.exists() else None, str(gt) if gt.exists() else None)


def _format_metrics(metrics: dict, latency_s: float, mode: str) -> str:
    """Format an `evaluate()` result dict as a Markdown metrics table.

    Expects `metrics` to carry "nodes"/"edges" sub-dicts with
    precision/recall/f1/tp/fp/fn, plus the n_pred_*/n_gt_* counts
    (the shape produced by `pid2graph_eval.metrics.evaluate`).
    """
    nm = metrics["nodes"]
    em = metrics["edges"]
    return f"""
### Metrics

| | Precision | Recall | F1 | TP | FP | FN |
|---|---:|---:|---:|---:|---:|---:|
| **Nodes** | {nm['precision']:.3f} | {nm['recall']:.3f} | **{nm['f1']:.3f}** | {nm['tp']} | {nm['fp']} | {nm['fn']} |
| **Edges** | {em['precision']:.3f} | {em['recall']:.3f} | **{em['f1']:.3f}** | {em['tp']} | {em['fp']} | {em['fn']} |

- Pred: **{metrics['n_pred_nodes']}** ノード / **{metrics['n_pred_edges']}** エッジ
- GT (semantic-collapsed): **{metrics['n_gt_nodes']}** ノード / **{metrics['n_gt_edges']}** エッジ
- Mode: `{mode}` · Latency: **{latency_s:.1f}s**
"""


def _format_pred_only(pred_dict: dict, latency_s: float, mode: str) -> str:
    """Format a prediction-only summary (used when no GT graphml is given)."""
    return f"""
### Prediction

- **{len(pred_dict['nodes'])}** ノード / **{len(pred_dict['edges'])}** エッジ
- Mode: `{mode}` · Latency: **{latency_s:.1f}s**
- (正解 graphml 未指定のため評価スキップ)
"""


def run_extraction(
    preset_name: str,
    image_path: Optional[str],
    gt_path: Optional[str],
    use_tiling: bool,
    progress: gr.Progress = gr.Progress(),
) -> tuple[str, Optional[plt.Figure], Optional[plt.Figure], str]:
    """Entry point wired to the Run button.

    Returns `(metrics_markdown, pred_figure, gt_figure, pred_json)`;
    the figure slots are `None` and the JSON slot `""` on early-exit
    validation failures or extraction errors.
    """
    # Preset overrides manual upload so the demo is reproducible.
    preset_img, preset_gt = _preset_paths(preset_name)
    if preset_img:
        image_path = preset_img
    if preset_gt:
        gt_path = preset_gt
    if not image_path:
        return (
            "⚠️ 画像をアップロードするか、プリセットを選択してください。",
            None,
            None,
            "",
        )
    if not os.environ.get("ANTHROPIC_API_KEY"):
        return (
            "⚠️ `ANTHROPIC_API_KEY` が設定されていません。`.env` に追記して再起動してください。",
            None,
            None,
            "",
        )
    # The SDK reads ANTHROPIC_API_KEY from the environment (loaded above).
    client = anthropic.Anthropic()
    mode = "tiled 2x2 + seam filter" if use_tiling else "single-shot"
    try:
        progress(0.05, desc=f"VLM 抽出開始 ({mode})…")
        t0 = time.time()
        if use_tiling:
            # Tiled path already returns a plain dict (merged across tiles).
            pred_dict = extract_graph_tiled(
                Path(image_path),
                client=client,
                rows=2,
                cols=2,
                overlap=0.1,
                dedup_px=40.0,
            )
        else:
            # Single-shot path returns a GraphOut model; normalize to dict.
            pred = extract_graph(Path(image_path), client=client)
            pred_dict = pred.to_dict()
        latency = time.time() - t0
        progress(0.55, desc="予測を semantic types に絞り込み…")
        # Defensive: drop anything non-semantic the VLM may have emitted.
        pred_dict = filter_by_types(pred_dict, SEMANTIC_EQUIPMENT_TYPES)
    except Exception as e:
        # UI boundary: surface any extraction failure as a message rather
        # than crashing the Gradio event handler.
        return (f"❌ VLM 抽出中にエラー: `{e}`", None, None, "")
    progress(0.65, desc="予測グラフを描画…")
    pred_fig = draw_graph(
        pred_dict,
        title=f"Prediction — {len(pred_dict['nodes'])} nodes, {len(pred_dict['edges'])} edges",
    )
    gt_fig = None
    metrics_md = _format_pred_only(pred_dict, latency, mode)
    if gt_path and Path(gt_path).exists():
        try:
            progress(0.75, desc="GT graphml をロード & 縮約…")
            gt_raw = load_graphml(Path(gt_path))
            # Collapse piping primitives so GT is semantic-equipment-only,
            # matching what the prediction is filtered down to.
            gt_dict = collapse_through_primitives(gt_raw, SEMANTIC_EQUIPMENT_TYPES)
            progress(0.85, desc="P/R/F1 を評価…")
            metrics = evaluate(
                pred_dict,
                gt_dict,
                directed=False,
                match_threshold=0.5,
            )
            metrics_md = _format_metrics(metrics, latency, mode)
            progress(0.95, desc="GT グラフを描画…")
            gt_fig = draw_graph(
                gt_dict,
                title=f"Ground Truth — {len(gt_dict['nodes'])} nodes, {len(gt_dict['edges'])} edges",
            )
        except Exception as e:
            # GT problems shouldn't hide the prediction — append a warning
            # to the prediction-only summary instead.
            metrics_md += f"\n\n⚠️ GT 処理でエラー: `{e}`"
    # Strip heavy-ish keys before JSON display.
    display_dict = {
        "nodes": pred_dict["nodes"],
        "edges": pred_dict["edges"],
    }
    pred_json = json.dumps(display_dict, indent=2, ensure_ascii=False)
    progress(1.0, desc="完了")
    return metrics_md, pred_fig, gt_fig, pred_json


def on_preset_change(preset_name: str) -> tuple[Optional[str], Optional[str]]:
    """When a preset is picked, auto-fill the image and graphml fields."""
    img, gt = _preset_paths(preset_name)
    return img, gt


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
DESCRIPTION = """
# PID2Graph × Claude VLM Demo

P&ID (配管計装図) を Claude Opus 4.6 のビジョンで読み取り、シンボル(valve / pump / tank / instrumentation / inlet・outlet)とその接続関係を JSON グラフに変換します。 正解 graphml を指定すると、ノード/エッジ単位の Precision / Recall / F1 を算出します。

- **タイル分割 (2x2)**: 大きな図面では 1 枚を 4 タイルに分割してから抽出し、マージ時に bbox 距離で重複排除 + タイル境界の inlet/outlet FP を後処理で除去します。
- **評価ルール**: GT 側は semantic equipment のみを残し、配管プリミティブ (connector / crossing / arrow / background) を経由する接続を 1 エッジに縮約します。
- **VLM 設定**: `temperature=0` で決定論的サンプリング、構造化出力で JSON スキーマを強制。
"""


def build_ui() -> gr.Blocks:
    """Assemble the Gradio Blocks layout and wire up event handlers.

    Layout: left column = inputs (preset, image, GT file, tiling toggle,
    run button); right column = metrics markdown, the two plots, and a
    collapsible raw-JSON view of the prediction.
    """
    with gr.Blocks(title="PID2Graph × Claude VLM Demo") as demo:
        gr.Markdown(DESCRIPTION)
        with gr.Row():
            with gr.Column(scale=1):
                preset = gr.Dropdown(
                    choices=[NONE_LABEL] + list(PRESETS.keys()),
                    value=NONE_LABEL,
                    label="プリセット (OPEN100 より)",
                )
                image_in = gr.Image(
                    type="filepath",
                    label="P&ID 画像",
                    height=260,
                )
                gt_in = gr.File(
                    label="正解 graphml (任意)",
                    file_types=[".graphml", ".xml"],
                    type="filepath",
                )
                tiling = gr.Checkbox(
                    value=True,
                    label="タイル分割 (2x2) で抽出 — 高精度だがコスト・時間 4 倍",
                )
                run_btn = gr.Button("抽出実行", variant="primary")
                gr.Markdown(
                    "モデル: `" + DEFAULT_MODEL + "` · 所要時間目安: single ~20s / tiled ~60-80s"
                )
            with gr.Column(scale=2):
                metrics_md = gr.Markdown()
                with gr.Row():
                    pred_plot = gr.Plot(label="Prediction")
                    gt_plot = gr.Plot(label="Ground Truth")
                with gr.Accordion("予測 JSON (nodes / edges)", open=False):
                    # NOTE: using Textbox rather than `gr.Code(language="json")`
                    # because the latter's schema has tripped the gradio_client
                    # `additionalProperties: false` bug on 4.44.1 in the past.
                    # Textbox is a plain string component — zero schema surface.
                    pred_json = gr.Textbox(
                        label="",
                        lines=20,
                        max_lines=30,
                        show_copy_button=True,
                        interactive=False,
                    )
        # Selecting a preset pre-fills both file inputs.
        preset.change(on_preset_change, inputs=[preset], outputs=[image_in, gt_in])
        run_btn.click(
            run_extraction,
            inputs=[preset, image_in, gt_in, tiling],
            outputs=[metrics_md, pred_plot, gt_plot, pred_json],
        )
    return demo


if __name__ == "__main__":
    # HF Spaces runs this inside a container where localhost is not
    # reachable from outside — binding to 0.0.0.0 is required, otherwise
    # Gradio raises "When localhost is not accessible, a shareable link
    # must be created". Locally this is harmless.
    #
    # `show_api=False` hides the docs panel in the UI; the monkey-patch
    # at the top of this file is what actually prevents the 4.44 schema
    # crash (now kept as defensive dead-code since we pin to 4.31.5).
    build_ui().launch(show_api=False, server_name="0.0.0.0")