Spaces:
Running
Running
| """Common graph schema shared by predictions and normalized ground truth. | |
| The VLM is asked to emit a `GraphOut` instance (enforced via structured outputs). | |
| Ground truth graphml is normalized into the same dict shape by `gt_loader.py`, | |
| so both sides can be fed directly into `metrics.evaluate()`. | |
| """ | |
| from __future__ import annotations | |
| from typing import List, Optional | |
| from pydantic import BaseModel, ConfigDict, Field | |
class BBox(BaseModel):
    """Axis-aligned bounding box in normalized image coordinates.
    Each coordinate is in [0.0, 1.0] relative to the image the VLM was
    shown. When extraction runs on a cropped tile, these are tile-local
    coordinates; the tiled merger converts them back to global image
    pixel coordinates before deduplication.
    """

    # NOTE: pydantic copies the class docstring and each Field description
    # into the generated JSON schema, so this text presumably reaches the VLM
    # through the structured-output schema -- edit it deliberately.

    # Keys other than the four below are a validation error, not ignored.
    model_config = ConfigDict(extra="forbid")

    # Declaration order fixes the property order in the JSON schema and the
    # key order of model_dump(). The fields are only type-checked as floats:
    # neither the [0, 1] range nor xmin <= xmax / ymin <= ymax is enforced.
    xmin: float = Field(
        description="Left edge, normalized [0, 1] within the visible image."
    )
    ymin: float = Field(
        description="Top edge, normalized [0, 1]."
    )
    xmax: float = Field(
        description="Right edge, normalized [0, 1]."
    )
    ymax: float = Field(
        description="Bottom edge, normalized [0, 1]."
    )
class NodeOut(BaseModel):
    """A single symbol in the P&ID (pump, valve, tank, sensor, ...)."""

    # Unexpected keys from the model output are rejected, not dropped.
    model_config = ConfigDict(extra="forbid")

    # `id` and `type` shadow builtins only as attribute names. They are part
    # of the dict shape shared with the normalized ground truth, so they stay.
    id: str = Field(description="Unique node id within this diagram (e.g. 'n1').")

    # "lowercase, singular" is requested through the description alone;
    # nothing in this model normalizes the value.
    type: str = Field(
        description="Category of the symbol, lowercase, singular. "
        "Examples: 'pump', 'valve', 'tank', 'heat_exchanger', "
        "'sensor', 'controller', 'compressor', 'column'."
    )

    label: Optional[str] = Field(
        None,
        description="Tag/identifier printed next to the symbol if visible (e.g. 'P-101').",
    )

    # Declared optional (defaults to None) even though the description tells
    # the VLM that the tiled pipeline needs it; consumers must handle None.
    bbox: Optional[BBox] = Field(
        None,
        description="Tight bounding box of the symbol, with each coordinate "
        "normalized to [0, 1] relative to the image the VLM sees. "
        "Required by the tiled extraction pipeline for deduplication.",
    )
class EdgeOut(BaseModel):
    """A pipeline or signal connection between two symbols."""

    # Unexpected keys from the model output are rejected, not dropped.
    model_config = ConfigDict(extra="forbid")

    # Endpoints refer to NodeOut ids by value; this model does not check that
    # a node with that id actually exists in the surrounding graph.
    source: str = Field(description="Source node id.")
    target: str = Field(description="Target node id.")

    # Optional metadata; both default to None when the VLM omits them.
    type: Optional[str] = Field(
        None, description="Connection type: 'pipe', 'signal', 'electrical', ..."
    )
    label: Optional[str] = Field(
        None, description="Line tag if visible (e.g. '2\"-PW-101')."
    )
class GraphOut(BaseModel):
    """Top-level graph extracted from a P&ID image."""

    # Unexpected top-level keys are rejected, not dropped.
    model_config = ConfigDict(extra="forbid")

    # Both lists are required (no defaults) but may be empty.
    nodes: List[NodeOut]
    edges: List[EdgeOut]

    def to_dict(self) -> dict:
        """Return the plain-dict form that `metrics.evaluate()` consumes.

        Returns:
            A new dict ``{"nodes": [...], "edges": [...]}`` in which every
            entry is the ``model_dump()`` of the corresponding NodeOut /
            EdgeOut, so an absent bbox shows up as ``None`` and a present
            one as a nested dict.
        """
        # Dump element by element (nodes first, then edges) so each entry is
        # serialized by its own model.
        node_dicts = [node.model_dump() for node in self.nodes]
        edge_dicts = [edge.model_dump() for edge in self.edges]
        return {"nodes": node_dicts, "edges": edge_dicts}