"""Common graph schema shared by predictions and normalized ground truth. The VLM is asked to emit a `GraphOut` instance (enforced via structured outputs). Ground truth graphml is normalized into the same dict shape by `gt_loader.py`, so both sides can be fed directly into `metrics.evaluate()`. """ from __future__ import annotations from typing import List, Optional from pydantic import BaseModel, ConfigDict, Field class BBox(BaseModel): """Axis-aligned bounding box in normalized image coordinates. Each coordinate is in [0.0, 1.0] relative to the image the VLM was shown. When extraction runs on a cropped tile, these are tile-local coordinates; the tiled merger converts them back to global image pixel coordinates before deduplication. """ model_config = ConfigDict(extra="forbid") xmin: float = Field(description="Left edge, normalized [0, 1] within the visible image.") ymin: float = Field(description="Top edge, normalized [0, 1].") xmax: float = Field(description="Right edge, normalized [0, 1].") ymax: float = Field(description="Bottom edge, normalized [0, 1].") class NodeOut(BaseModel): """A single symbol in the P&ID (pump, valve, tank, sensor, ...).""" model_config = ConfigDict(extra="forbid") id: str = Field(description="Unique node id within this diagram (e.g. 'n1').") type: str = Field( description=( "Category of the symbol, lowercase, singular. " "Examples: 'pump', 'valve', 'tank', 'heat_exchanger', 'sensor', " "'controller', 'compressor', 'column'." ) ) label: Optional[str] = Field( default=None, description="Tag/identifier printed next to the symbol if visible (e.g. 'P-101').", ) bbox: Optional[BBox] = Field( default=None, description=( "Tight bounding box of the symbol, with each coordinate " "normalized to [0, 1] relative to the image the VLM sees. " "Required by the tiled extraction pipeline for deduplication." ), ) class EdgeOut(BaseModel): """A pipeline or signal connection between two symbols.""" model_config = ConfigDict(extra="forbid") source: str = Field(description="Source node id.") target: str = Field(description="Target node id.") type: Optional[str] = Field( default=None, description="Connection type: 'pipe', 'signal', 'electrical', ...", ) label: Optional[str] = Field( default=None, description="Line tag if visible (e.g. '2\"-PW-101').", ) class GraphOut(BaseModel): """Top-level graph extracted from a P&ID image.""" model_config = ConfigDict(extra="forbid") nodes: List[NodeOut] edges: List[EdgeOut] def to_dict(self) -> dict: """Shape that `metrics.evaluate()` consumes.""" return { "nodes": [n.model_dump() for n in self.nodes], "edges": [e.model_dump() for e in self.edges], }