Spaces:
Running
Running
File size: 3,005 Bytes
"""Common graph schema shared by predictions and normalized ground truth.

The VLM is asked to emit a `GraphOut` instance (enforced via structured outputs).
Ground-truth GraphML is normalized into the same dict shape by `gt_loader.py`,
so both sides can be fed directly into `metrics.evaluate()`.
"""
from __future__ import annotations
from typing import List, Optional
from pydantic import BaseModel, ConfigDict, Field
class BBox(BaseModel):
    """Axis-aligned bounding box in normalized image coordinates.
    Each coordinate is in [0.0, 1.0] relative to the image the VLM was
    shown. When extraction runs on a cropped tile, these are tile-local
    coordinates; the tiled merger converts them back to global image
    pixel coordinates before deduplication.
    """

    # NOTE: pydantic exports the class docstring and every Field description
    # into the generated JSON schema, so that text is left exactly as it was.

    # Unknown keys are rejected rather than silently ignored.
    model_config = ConfigDict(extra="forbid")

    # Horizontal extent (left, then right edge).
    xmin: float = Field(
        description="Left edge, normalized [0, 1] within the visible image.",
    )
    # Vertical extent, top edge.
    ymin: float = Field(
        description="Top edge, normalized [0, 1].",
    )
    xmax: float = Field(
        description="Right edge, normalized [0, 1].",
    )
    # Vertical extent, bottom edge.
    ymax: float = Field(
        description="Bottom edge, normalized [0, 1].",
    )
class NodeOut(BaseModel):
    """A single symbol in the P&ID (pump, valve, tank, sensor, ...)."""

    # NOTE: pydantic exports the class docstring and every Field description
    # into the generated JSON schema, so that text is left exactly as it was.

    # Unknown keys are rejected rather than silently ignored.
    model_config = ConfigDict(extra="forbid")

    # --- required fields -------------------------------------------------
    id: str = Field(
        description="Unique node id within this diagram (e.g. 'n1').",
    )
    # Free-form string: the categories in the description are examples only,
    # nothing here restricts the value to that list.
    type: str = Field(
        description=(
            "Category of the symbol, lowercase, singular. "
            "Examples: 'pump', 'valve', 'tank', 'heat_exchanger', 'sensor', "
            "'controller', 'compressor', 'column'."
        ),
    )

    # --- optional fields (default to None when omitted) -------------------
    label: Optional[str] = Field(
        description="Tag/identifier printed next to the symbol if visible (e.g. 'P-101').",
        default=None,
    )
    # NOTE(review): the description calls this "Required by the tiled
    # extraction pipeline", yet the field is Optional with a None default, so
    # this model itself does not enforce its presence.
    bbox: Optional[BBox] = Field(
        description=(
            "Tight bounding box of the symbol, with each coordinate "
            "normalized to [0, 1] relative to the image the VLM sees. "
            "Required by the tiled extraction pipeline for deduplication."
        ),
        default=None,
    )
class EdgeOut(BaseModel):
    """A pipeline or signal connection between two symbols."""

    # NOTE: pydantic exports the class docstring and every Field description
    # into the generated JSON schema, so that text is left exactly as it was.

    # Unknown keys are rejected rather than silently ignored.
    model_config = ConfigDict(extra="forbid")

    # Endpoints are node ids given as plain strings; this model does not
    # check that they refer to an existing node.
    source: str = Field(
        description="Source node id.",
    )
    target: str = Field(
        description="Target node id.",
    )

    # Optional annotations; both default to None when omitted.
    type: Optional[str] = Field(
        description="Connection type: 'pipe', 'signal', 'electrical', ...",
        default=None,
    )
    label: Optional[str] = Field(
        description="Line tag if visible (e.g. '2\"-PW-101').",
        default=None,
    )
class GraphOut(BaseModel):
    """Top-level graph extracted from a P&ID image."""

    # NOTE: pydantic exports the class docstring into the generated JSON
    # schema, so that text is left exactly as it was.

    # Unknown top-level keys are rejected rather than silently ignored.
    model_config = ConfigDict(extra="forbid")

    # Both lists are required (no default), though either may be empty.
    nodes: List[NodeOut]
    edges: List[EdgeOut]

    def to_dict(self) -> dict:
        """Return the graph as plain dicts, the shape `metrics.evaluate()` consumes.

        Returns:
            A dict with two keys, ``"nodes"`` and ``"edges"``, each holding
            the ``model_dump()`` of every item in the corresponding list,
            in the original order.
        """
        # Each item is dumped individually rather than dumping the whole model.
        node_dicts = [node.model_dump() for node in self.nodes]
        edge_dicts = [edge.model_dump() for edge in self.edges]
        return {"nodes": node_dicts, "edges": edge_dicts}
|