# Last change (242f528, deepkick): fix HF Spaces launch — pin gradio==4.31.5
# and bind 0.0.0.0.
"""Gradio demo: P&ID graph extraction with Claude VLM + evaluation.
Usage (local):
ANTHROPIC_API_KEY=sk-ant-... python app.py
Or put the key in a `.env` next to this file. The app:
1. Takes a P&ID image (preset or upload)
2. Runs extraction (optionally tiled 2x2) via Claude Opus 4.6
3. If a ground-truth graphml is provided, collapses it to semantic-only
form and computes node/edge P/R/F1 via `pid2graph_eval.metrics`
4. Draws both the prediction and the ground truth as NetworkX graphs
using bbox-based layouts so the topology matches the source image
"""
from __future__ import annotations
import json
import os
import time
from pathlib import Path
from typing import Optional
# Matplotlib backend must be set before pyplot import for headless use;
# the `matplotlib.use()` call below taints every subsequent import with
# E402 ("module level import not at top of file"), which is expected here.
import matplotlib
matplotlib.use("Agg")
import matplotlib.patches as mpatches # noqa: E402
import matplotlib.pyplot as plt # noqa: E402
import networkx as nx # noqa: E402
import anthropic # noqa: E402
# ---------------------------------------------------------------------------
# Gradio 4.44 / gradio_client 1.3.0 bug workaround
# ---------------------------------------------------------------------------
# At `/info` boot, Gradio walks every component's JSON schema via
# `gradio_client.utils._json_schema_to_python_type`. That function does not
# handle bool schemas (`additionalProperties: false` or `true`, both of which
# are valid JSON Schema) — it recurses with the bool and then `if "const" in
# schema:` on line 863 raises `TypeError: argument of type 'bool' is not
# iterable`. Patch the function here before importing gradio so the crash is
# avoided regardless of which component triggers it. (Fixed upstream in later
# gradio_client releases; Python 3.9 can't run gradio 5, and we now pin
# gradio==4.31.5, where this patch is defensive dead-code.) Harmless on
# versions where the bug is already fixed.
import gradio_client.utils as _gc_utils # noqa: E402
_orig_json_schema_to_python_type = _gc_utils._json_schema_to_python_type
def _patched_json_schema_to_python_type(schema, defs=None): # type: ignore[override]
if isinstance(schema, bool):
# `True` means "any value is allowed"; `False` means "no value".
return "Any" if schema else "None"
return _orig_json_schema_to_python_type(schema, defs)
_gc_utils._json_schema_to_python_type = _patched_json_schema_to_python_type
import gradio as gr # noqa: E402 (must come after the monkey-patch)
from dotenv import load_dotenv # noqa: E402
from pid2graph_eval.extractor import ( # noqa: E402
DEFAULT_MODEL,
extract_graph,
extract_graph_tiled,
)
from pid2graph_eval.gt_loader import ( # noqa: E402
SEMANTIC_EQUIPMENT_TYPES,
collapse_through_primitives,
filter_by_types,
load_graphml,
)
from pid2graph_eval.metrics import evaluate # noqa: E402
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# Resolve everything relative to this file so the app works from any CWD.
APP_ROOT = Path(__file__).parent
SAMPLES_DIR = APP_ROOT / "samples"
# Pull ANTHROPIC_API_KEY (and friends) from a sibling .env, if present.
load_dotenv(APP_ROOT / ".env")
# Presets: (display name) -> (image path, graphml path)
PRESETS: dict[str, tuple[Path, Path]] = {
    "OPEN100 #1 — small (27 semantic nodes)": (
        SAMPLES_DIR / "open100_01_small.png",
        SAMPLES_DIR / "open100_01_small.graphml",
    ),
    "OPEN100 #3 — medium (53 semantic nodes)": (
        SAMPLES_DIR / "open100_03_medium.png",
        SAMPLES_DIR / "open100_03_medium.graphml",
    ),
    "OPEN100 #0 — large (82 semantic nodes)": (
        SAMPLES_DIR / "open100_00_large.png",
        SAMPLES_DIR / "open100_00_large.graphml",
    ),
}
# Dropdown sentinel meaning "no preset selected".
NONE_LABEL = "(none — upload your own)"
# Fixed palette so pred/GT visualizations use matching colors.
TYPE_COLORS: dict[str, str] = {
    "valve": "#ff6b6b",
    "pump": "#4ecdc4",
    "tank": "#ffd93d",
    "instrumentation": "#6bcfff",
    "inlet/outlet": "#c47bff",
}
# One legend patch per type; shared by every figure draw_graph produces.
LEGEND_HANDLES = [
    mpatches.Patch(color=c, label=t) for t, c in TYPE_COLORS.items()
]
# ---------------------------------------------------------------------------
# Visualization
# ---------------------------------------------------------------------------
def _bbox_to_xyxy(bbox) -> Optional[tuple[float, float, float, float]]:
"""Normalize a bbox to `(xmin, ymin, xmax, ymax)` floats.
Accepts both shapes that flow through the pipeline:
* **list / tuple** `[x1, y1, x2, y2]` — produced by `gt_loader._bbox`
and by `tile.merge_tile_graphs` for the tiled pred path.
* **dict** `{"xmin": ..., "ymin": ..., "xmax": ..., "ymax": ...}` —
produced by `GraphOut.to_dict()` in single-shot mode, because the
Pydantic `BBox` model round-trips through `model_dump()`.
Returns `None` if the bbox is missing or malformed.
"""
if bbox is None:
return None
if isinstance(bbox, dict):
try:
return (
float(bbox["xmin"]),
float(bbox["ymin"]),
float(bbox["xmax"]),
float(bbox["ymax"]),
)
except (KeyError, TypeError, ValueError):
return None
if isinstance(bbox, (list, tuple)) and len(bbox) == 4:
try:
return (
float(bbox[0]),
float(bbox[1]),
float(bbox[2]),
float(bbox[3]),
)
except (TypeError, ValueError):
return None
return None
def draw_graph(graph_dict: dict, title: str, figsize=(8, 6)) -> plt.Figure:
    """Render a graph dict as a matplotlib figure.

    Node positions come from bbox centers when available, so the drawing
    preserves the spatial layout of the original P&ID.  Nodes without a
    bbox are placed at the centroid of the positioned nodes, or via
    networkx spring layout when no node has a bbox at all.
    """
    fig, ax = plt.subplots(figsize=figsize, dpi=110)
    graph = nx.Graph()
    layout: dict[str, tuple[float, float]] = {}
    node_ids: list[str] = []
    fill_colors: list[str] = []

    for node in graph_dict.get("nodes", []):
        node_id = node["id"]
        graph.add_node(node_id)
        node_ids.append(node_id)
        # Unknown types render grey so they are visible but unobtrusive.
        fill_colors.append(TYPE_COLORS.get(node.get("type", ""), "#cccccc"))
        coords = _bbox_to_xyxy(node.get("bbox"))
        if coords is not None:
            x1, y1, x2, y2 = coords
            # Flip y so the drawing is right-side up (image y grows down).
            layout[node_id] = ((x1 + x2) / 2.0, -(y1 + y2) / 2.0)

    for edge in graph_dict.get("edges", []):
        src, dst = edge.get("source"), edge.get("target")
        # Skip edges that reference nodes we never saw.
        if src in graph.nodes and dst in graph.nodes:
            graph.add_edge(src, dst)

    # Any node that lacked a bbox still needs a position.
    unplaced = [nid for nid in graph.nodes if nid not in layout]
    if unplaced:
        if layout:
            # Drop them at the center of the existing bbox cloud.
            xs = [p[0] for p in layout.values()]
            ys = [p[1] for p in layout.values()]
            centroid = (sum(xs) / len(xs), sum(ys) / len(ys))
            for nid in unplaced:
                layout[nid] = centroid
        else:
            layout = nx.spring_layout(graph, seed=42)

    nx.draw_networkx_edges(graph, layout, alpha=0.35, width=0.6, ax=ax)
    nx.draw_networkx_nodes(
        graph, layout,
        nodelist=node_ids,
        node_color=fill_colors,
        node_size=55,
        linewidths=0.5,
        edgecolors="#222",
        ax=ax,
    )
    ax.set_title(title, fontsize=11)
    ax.set_aspect("equal")
    ax.axis("off")
    ax.legend(handles=LEGEND_HANDLES, loc="lower right", fontsize=7, framealpha=0.9)
    fig.tight_layout()
    return fig
# ---------------------------------------------------------------------------
# Pipeline
# ---------------------------------------------------------------------------
def _preset_paths(preset_name: str) -> tuple[Optional[str], Optional[str]]:
    """Resolve a preset dropdown selection to (image_path, graphml_path).

    Either element is ``None`` when the preset is unknown, the "(none)"
    sentinel was selected, or the corresponding sample file is missing.
    """
    paths = PRESETS.get(preset_name)
    if preset_name == NONE_LABEL or paths is None:
        return None, None
    img, gt = paths
    image_path = str(img) if img.exists() else None
    graphml_path = str(gt) if gt.exists() else None
    return image_path, graphml_path
def _format_metrics(metrics: dict, latency_s: float, mode: str) -> str:
nm = metrics["nodes"]
em = metrics["edges"]
return f"""
### Metrics
| | Precision | Recall | F1 | TP | FP | FN |
|---|---:|---:|---:|---:|---:|---:|
| **Nodes** | {nm['precision']:.3f} | {nm['recall']:.3f} | **{nm['f1']:.3f}** | {nm['tp']} | {nm['fp']} | {nm['fn']} |
| **Edges** | {em['precision']:.3f} | {em['recall']:.3f} | **{em['f1']:.3f}** | {em['tp']} | {em['fp']} | {em['fn']} |
- Pred: **{metrics['n_pred_nodes']}** ノード / **{metrics['n_pred_edges']}** エッジ
- GT (semantic-collapsed): **{metrics['n_gt_nodes']}** ノード / **{metrics['n_gt_edges']}** エッジ
- Mode: `{mode}` · Latency: **{latency_s:.1f}s**
"""
def _format_pred_only(pred_dict: dict, latency_s: float, mode: str) -> str:
return f"""
### Prediction
- **{len(pred_dict['nodes'])}** ノード / **{len(pred_dict['edges'])}** エッジ
- Mode: `{mode}` · Latency: **{latency_s:.1f}s**
- (正解 graphml 未指定のため評価スキップ)
"""
def run_extraction(
    preset_name: str,
    image_path: Optional[str],
    gt_path: Optional[str],
    use_tiling: bool,
    progress: gr.Progress = gr.Progress(),
) -> tuple[str, Optional[plt.Figure], Optional[plt.Figure], str]:
    """Entry point wired to the Run button.

    Args:
        preset_name: dropdown selection; a real preset overrides the
            manually supplied image/graphml paths.
        image_path: filepath from the image widget (may be None).
        gt_path: filepath from the graphml widget (may be None).
        use_tiling: extract via 2x2 tiles + seam filter instead of one shot.
        progress: injected by Gradio; `gr.Progress()` as a default value
            is the documented Gradio idiom, not an accidental mutable default.

    Returns:
        ``(markdown, pred_figure, gt_figure, pred_json)`` matching the
        four outputs wired in `build_ui`.  On user/config errors or
        extraction failure the figures are None and the markdown carries
        the (Japanese) warning message.
    """
    # Preset overrides manual upload so the demo is reproducible.
    preset_img, preset_gt = _preset_paths(preset_name)
    if preset_img:
        image_path = preset_img
    if preset_gt:
        gt_path = preset_gt
    if not image_path:
        return (
            "⚠️ 画像をアップロードするか、プリセットを選択してください。",
            None, None, "",
        )
    # Fail fast with a clear message before paying for a model call.
    if not os.environ.get("ANTHROPIC_API_KEY"):
        return (
            "⚠️ `ANTHROPIC_API_KEY` が設定されていません。`.env` に追記して再起動してください。",
            None, None, "",
        )
    # Client reads ANTHROPIC_API_KEY from the environment (loaded via .env).
    client = anthropic.Anthropic()
    mode = "tiled 2x2 + seam filter" if use_tiling else "single-shot"
    try:
        progress(0.05, desc=f"VLM 抽出開始 ({mode})…")
        t0 = time.time()
        if use_tiling:
            # Tiled path already returns a plain dict (merged tile graphs).
            pred_dict = extract_graph_tiled(
                Path(image_path),
                client=client,
                rows=2,
                cols=2,
                overlap=0.1,
                dedup_px=40.0,
            )
        else:
            # Single-shot path returns a model object; normalize to dict.
            pred = extract_graph(Path(image_path), client=client)
            pred_dict = pred.to_dict()
        latency = time.time() - t0
        progress(0.55, desc="予測を semantic types に絞り込み…")
        # Defensive: drop anything non-semantic the VLM may have emitted.
        pred_dict = filter_by_types(pred_dict, SEMANTIC_EQUIPMENT_TYPES)
    except Exception as e:
        # Broad catch is deliberate: this is the UI boundary, and any
        # extraction failure should surface as a message, not a traceback.
        return (f"❌ VLM 抽出中にエラー: `{e}`", None, None, "")
    progress(0.65, desc="予測グラフを描画…")
    pred_fig = draw_graph(
        pred_dict,
        title=f"Prediction — {len(pred_dict['nodes'])} nodes, {len(pred_dict['edges'])} edges",
    )
    # Defaults used when no (valid) ground truth was supplied.
    gt_fig = None
    metrics_md = _format_pred_only(pred_dict, latency, mode)
    if gt_path and Path(gt_path).exists():
        try:
            progress(0.75, desc="GT graphml をロード & 縮約…")
            gt_raw = load_graphml(Path(gt_path))
            # Collapse piping primitives so GT matches the pred's semantics.
            gt_dict = collapse_through_primitives(gt_raw, SEMANTIC_EQUIPMENT_TYPES)
            progress(0.85, desc="P/R/F1 を評価…")
            metrics = evaluate(
                pred_dict,
                gt_dict,
                directed=False,
                match_threshold=0.5,
            )
            metrics_md = _format_metrics(metrics, latency, mode)
            progress(0.95, desc="GT グラフを描画…")
            gt_fig = draw_graph(
                gt_dict,
                title=f"Ground Truth — {len(gt_dict['nodes'])} nodes, {len(gt_dict['edges'])} edges",
            )
        except Exception as e:
            # GT problems must not discard the prediction — append a note.
            metrics_md += f"\n\n⚠️ GT 処理でエラー: `{e}`"
    # Strip heavy-ish keys before JSON display.
    display_dict = {
        "nodes": pred_dict["nodes"],
        "edges": pred_dict["edges"],
    }
    pred_json = json.dumps(display_dict, indent=2, ensure_ascii=False)
    progress(1.0, desc="完了")
    return metrics_md, pred_fig, gt_fig, pred_json
def on_preset_change(preset_name: str):
    """When a preset is picked, auto-fill the image and graphml fields."""
    # `_preset_paths` already returns the (image, graphml) pair in the
    # order the two output components expect.
    return _preset_paths(preset_name)
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
# Top-of-page Markdown (Japanese): what the demo does and how the tiling,
# evaluation, and VLM settings work.  Rendered verbatim by gr.Markdown.
DESCRIPTION = """
# PID2Graph × Claude VLM Demo
P&ID (配管計装図) を Claude Opus 4.6 のビジョンで読み取り、シンボル(valve / pump /
tank / instrumentation / inlet・outlet)とその接続関係を JSON グラフに変換します。
正解 graphml を指定すると、ノード/エッジ単位の Precision / Recall / F1 を算出します。
- **タイル分割 (2x2)**: 大きな図面では 1 枚を 4 タイルに分割してから抽出し、マージ時に
bbox 距離で重複排除 + タイル境界の inlet/outlet FP を後処理で除去します。
- **評価ルール**: GT 側は semantic equipment のみを残し、配管プリミティブ (connector /
crossing / arrow / background) を経由する接続を 1 エッジに縮約します。
- **VLM 設定**: `temperature=0` で決定論的サンプリング、構造化出力で JSON スキーマを強制。
"""
def build_ui() -> gr.Blocks:
    """Assemble the Gradio Blocks app and wire its events.

    Layout: inputs (preset / image / graphml / tiling toggle) in a
    narrow left column, results (metrics, two plots, raw JSON) in a
    wider right column.  Returns the unlaunched Blocks instance.
    """
    with gr.Blocks(title="PID2Graph × Claude VLM Demo") as demo:
        gr.Markdown(DESCRIPTION)
        with gr.Row():
            # --- Left column: inputs --------------------------------
            with gr.Column(scale=1):
                preset = gr.Dropdown(
                    choices=[NONE_LABEL] + list(PRESETS.keys()),
                    value=NONE_LABEL,
                    label="プリセット (OPEN100 より)",
                )
                image_in = gr.Image(
                    type="filepath",
                    label="P&ID 画像",
                    height=260,
                )
                gt_in = gr.File(
                    label="正解 graphml (任意)",
                    file_types=[".graphml", ".xml"],
                    type="filepath",
                )
                tiling = gr.Checkbox(
                    value=True,
                    label="タイル分割 (2x2) で抽出 — 高精度だがコスト・時間 4 倍",
                )
                run_btn = gr.Button("抽出実行", variant="primary")
                gr.Markdown(
                    "モデル: `" + DEFAULT_MODEL + "` · 所要時間目安: single ~20s / tiled ~60-80s"
                )
            # --- Right column: results ------------------------------
            with gr.Column(scale=2):
                metrics_md = gr.Markdown()
                with gr.Row():
                    pred_plot = gr.Plot(label="Prediction")
                    gt_plot = gr.Plot(label="Ground Truth")
                with gr.Accordion("予測 JSON (nodes / edges)", open=False):
                    # NOTE: using Textbox rather than `gr.Code(language="json")`
                    # because the latter's schema has tripped the gradio_client
                    # `additionalProperties: false` bug on 4.44.1 in the past.
                    # Textbox is a plain string component — zero schema surface.
                    pred_json = gr.Textbox(
                        label="",
                        lines=20,
                        max_lines=30,
                        show_copy_button=True,
                        interactive=False,
                    )
        # Event wiring: preset selection auto-fills the file inputs;
        # the Run button drives the whole extraction/evaluation pipeline.
        preset.change(on_preset_change, inputs=[preset], outputs=[image_in, gt_in])
        run_btn.click(
            run_extraction,
            inputs=[preset, image_in, gt_in, tiling],
            outputs=[metrics_md, pred_plot, gt_plot, pred_json],
        )
    return demo
if __name__ == "__main__":
    # HF Spaces runs this inside a container where localhost is not
    # reachable from outside — binding to 0.0.0.0 is required, otherwise
    # Gradio raises "When localhost is not accessible, a shareable link
    # must be created". Locally this is harmless.
    #
    # `show_api=False` hides the API docs panel in the UI; the monkey-patch
    # at the top of this file is what actually prevents the 4.44 schema
    # crash (now kept as defensive dead-code since we pin to 4.31.5).
    build_ui().launch(show_api=False, server_name="0.0.0.0")