File size: 3,005 Bytes
59fa244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""Common graph schema shared by predictions and normalized ground truth.

The VLM is asked to emit a `GraphOut` instance (enforced via structured outputs).
Ground-truth GraphML is normalized by `gt_loader.py` into the same dict shape
that `GraphOut.to_dict()` produces, so both sides can be fed directly into
`metrics.evaluate()`.
"""

from __future__ import annotations

from typing import List, Optional

from pydantic import BaseModel, ConfigDict, Field


class BBox(BaseModel):
    """Axis-aligned bounding box in normalized image coordinates.

    Each coordinate is in [0.0, 1.0] relative to the image the VLM was
    shown. When extraction runs on a cropped tile, these are tile-local
    coordinates; the tiled merger converts them back to global image
    pixel coordinates before deduplication.
    """

    # pydantic copies this class docstring and every Field `description`
    # into the generated JSON schema, which is what the VLM is constrained
    # by under structured outputs. Editing that text changes the prompt,
    # not just the docs.
    #
    # extra="forbid": undeclared keys fail validation, and the JSON schema
    # gets `additionalProperties: false`.
    model_config = ConfigDict(extra="forbid")

    # Declaration order fixes the property order in the JSON schema and in
    # `model_dump()` output.
    # NOTE(review): the [0, 1] range and xmin <= xmax / ymin <= ymax are only
    # requested through the descriptions. No validator enforces them, so
    # out-of-range or swapped values from the VLM are accepted as-is.
    xmin: float = Field(description="Left edge, normalized [0, 1] within the visible image.")
    ymin: float = Field(description="Top edge, normalized [0, 1].")
    xmax: float = Field(description="Right edge, normalized [0, 1].")
    ymax: float = Field(description="Bottom edge, normalized [0, 1].")


class NodeOut(BaseModel):
    """A single symbol in the P&ID (pump, valve, tank, sensor, ...)."""

    # Undeclared keys fail validation. The docstring above and each Field
    # `description` below are emitted into the JSON schema shown to the VLM,
    # so treat that text (and the field order) as part of the prompt.
    model_config = ConfigDict(extra="forbid")

    # Identifier that edges use in `source` / `target`.
    # NOTE(review): uniqueness is only requested in the description; nothing
    # in this module rejects duplicate ids.
    id: str = Field(description="Unique node id within this diagram (e.g. 'n1').")
    # Free-form category string. The lowercase/singular convention is asked
    # for, not normalized or validated here.
    type: str = Field(
        description=(
            "Category of the symbol, lowercase, singular. "
            "Examples: 'pump', 'valve', 'tank', 'heat_exchanger', 'sensor', "
            "'controller', 'compressor', 'column'."
        )
    )
    # None when no tag is visible or the VLM omits it.
    label: Optional[str] = Field(
        default=None,
        description="Tag/identifier printed next to the symbol if visible (e.g. 'P-101').",
    )
    # Typed as Optional with a None default even though the description calls
    # it required for tiled deduplication.
    # NOTE(review): presumably tiled consumers must cope with bbox=None;
    # confirm against the merger.
    bbox: Optional[BBox] = Field(
        default=None,
        description=(
            "Tight bounding box of the symbol, with each coordinate "
            "normalized to [0, 1] relative to the image the VLM sees. "
            "Required by the tiled extraction pipeline for deduplication."
        ),
    )


class EdgeOut(BaseModel):
    """A pipeline or signal connection between two symbols."""

    # Undeclared keys fail validation. This docstring, the field order and
    # the Field descriptions all feed the JSON schema the VLM must follow.
    model_config = ConfigDict(extra="forbid")

    # Endpoints are node ids (as in NodeOut.id).
    # NOTE(review): this model does not check that they refer to nodes that
    # actually exist, so dangling ids pass validation.
    source: str = Field(description="Source node id.")
    target: str = Field(description="Target node id.")
    # Free-form connection kind; None when the VLM does not supply one.
    type: Optional[str] = Field(
        default=None,
        description="Connection type: 'pipe', 'signal', 'electrical', ...",
    )
    # Printed line tag, if any; None otherwise.
    label: Optional[str] = Field(
        default=None,
        description="Line tag if visible (e.g. '2\"-PW-101').",
    )


class GraphOut(BaseModel):
    """Top-level graph extracted from a P&ID image."""

    # The class docstring, config and field order below are emitted into the
    # structured-output JSON schema; keep them stable.
    model_config = ConfigDict(extra="forbid")

    nodes: List[NodeOut]
    edges: List[EdgeOut]

    def to_dict(self) -> dict:
        """Convert the graph into the plain-dict form `metrics.evaluate()` reads.

        Returns:
            A dict with exactly two keys, ``"nodes"`` then ``"edges"``. Each
            maps to a list holding the ``model_dump()`` of every `NodeOut` /
            `EdgeOut`, in the original list order. A node's ``bbox`` is dumped
            as a nested dict, or ``None`` when absent.
        """
        result: dict = {}
        # Fixed section order keeps the key order "nodes" -> "edges".
        for section, members in (("nodes", self.nodes), ("edges", self.edges)):
            result[section] = [member.model_dump() for member in members]
        return result