# deepkick's picture
# Initial commit: PID2Graph × Claude VLM evaluation + Gradio demo
# 59fa244
"""Common graph schema shared by predictions and normalized ground truth.
The VLM is asked to emit a `GraphOut` instance (enforced via structured outputs).
Ground truth graphml is normalized into the same dict shape by `gt_loader.py`,
so both sides can be fed directly into `metrics.evaluate()`.
"""
from __future__ import annotations
from typing import List, Optional
from pydantic import BaseModel, ConfigDict, Field
class BBox(BaseModel):
    """Axis-aligned bounding box in normalized image coordinates.
    Each coordinate is in [0.0, 1.0] relative to the image the VLM was
    shown. When extraction runs on a cropped tile, these are tile-local
    coordinates; the tiled merger converts them back to global image
    pixel coordinates before deduplication.
    """

    # NOTE(review): pydantic presumably copies the class docstring and every
    # Field description into the JSON schema handed to the VLM, so that text
    # is kept verbatim here -- treat edits to it as prompt changes.

    # Unknown keys in the input are rejected instead of silently ignored.
    model_config = ConfigDict(extra="forbid")

    # The [0, 1] range is documented only; no constraint on these fields
    # enforces it, nor that xmin <= xmax / ymin <= ymax.
    xmin: float = Field(
        description="Left edge, normalized [0, 1] within the visible image.",
    )
    ymin: float = Field(
        description="Top edge, normalized [0, 1].",
    )
    xmax: float = Field(
        description="Right edge, normalized [0, 1].",
    )
    ymax: float = Field(
        description="Bottom edge, normalized [0, 1].",
    )
class NodeOut(BaseModel):
    """A single symbol in the P&ID (pump, valve, tank, sensor, ...)."""

    # NOTE(review): the docstring and Field descriptions presumably end up in
    # the JSON schema shown to the VLM; they are reproduced verbatim.

    # Unknown keys in the input are rejected instead of silently ignored.
    model_config = ConfigDict(extra="forbid")

    # Required identity of the node; edges refer to nodes by this id.
    id: str = Field(
        description="Unique node id within this diagram (e.g. 'n1').",
    )

    # Free-form category string: the description lists examples, but nothing
    # here restricts the value to a closed vocabulary.
    type: str = Field(
        description=(
            "Category of the symbol, lowercase, singular. Examples: "
            "'pump', 'valve', 'tank', 'heat_exchanger', 'sensor', 'controller', "
            "'compressor', 'column'."
        ),
    )

    # Optional printed tag; defaults to None when absent.
    label: Optional[str] = Field(
        description="Tag/identifier printed next to the symbol if visible (e.g. 'P-101').",
        default=None,
    )

    # Optional at the schema level (defaults to None), even though the
    # description says the tiled pipeline needs it for deduplication.
    bbox: Optional[BBox] = Field(
        description=(
            "Tight bounding box of the symbol, "
            "with each coordinate normalized to [0, 1] relative to the image the VLM sees. "
            "Required by the tiled extraction pipeline for deduplication."
        ),
        default=None,
    )
class EdgeOut(BaseModel):
    """A pipeline or signal connection between two symbols."""

    # NOTE(review): the docstring and Field descriptions presumably end up in
    # the JSON schema shown to the VLM; they are reproduced verbatim.

    # Unknown keys in the input are rejected instead of silently ignored.
    model_config = ConfigDict(extra="forbid")

    # Endpoints are plain strings; this class does not check that they match
    # the id of any node in the enclosing graph.
    source: str = Field(
        description="Source node id.",
    )
    target: str = Field(
        description="Target node id.",
    )

    # Optional connection kind; defaults to None when not given.
    type: Optional[str] = Field(
        description="Connection type: 'pipe', 'signal', 'electrical', ...",
        default=None,
    )

    # Optional line tag; defaults to None when not given.
    label: Optional[str] = Field(
        description="Line tag if visible (e.g. '2\"-PW-101').",
        default=None,
    )
class GraphOut(BaseModel):
    """Top-level graph extracted from a P&ID image."""

    # NOTE(review): the class docstring presumably ends up in the JSON schema
    # shown to the VLM; it is reproduced verbatim.

    # Unknown keys in the input are rejected instead of silently ignored.
    model_config = ConfigDict(extra="forbid")

    # Both lists are required; nothing here cross-checks edge endpoints
    # against node ids or enforces unique node ids.
    nodes: List[NodeOut]
    edges: List[EdgeOut]

    def to_dict(self) -> dict:
        """Return the graph as a plain dict for `metrics.evaluate()`.

        Returns:
            A dict with exactly two keys, ``"nodes"`` and ``"edges"``, each
            holding a list of the ``model_dump()`` output of the
            corresponding items, in their original order.
        """
        dumped_nodes = [node.model_dump() for node in self.nodes]
        dumped_edges = [edge.model_dump() for edge in self.edges]
        return {"nodes": dumped_nodes, "edges": dumped_edges}