intelligent-pid / detection_schema.py
msIntui
Initial commit: Add core files for P&ID processing
9847531
from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Dict
import uuid
from enum import Enum
import json
import numpy as np
# ======================== Point ======================== #
class ConnectionType(Enum):
    """Enumeration of the connection styles referenced by ``LineStyle``."""

    SOLID = "solid"
    DASHED = "dashed"
    PHANTOM = "phantom"
@dataclass
class Coordinates:
    """A position expressed as an ``x`` / ``y`` pair."""

    x: int
    y: int
@dataclass
class BBox:
    """Rectangle stored as its minimum and maximum x/y values."""

    xmin: int
    ymin: int
    xmax: int
    ymax: int

    def width(self) -> int:
        """Return the horizontal extent, ``xmax - xmin``."""
        left, right = self.xmin, self.xmax
        return right - left

    def height(self) -> int:
        """Return the vertical extent, ``ymax - ymin``."""
        top, bottom = self.ymin, self.ymax
        return bottom - top
class JunctionType(str, Enum):
    """String-valued enumeration of junction kinds: ``T``, ``L`` and ``END``."""

    T = "T"
    L = "L"
    END = "END"
@dataclass
class Point:
    """A point element: position, box, junction kind, confidence and id."""

    coords: Coordinates
    bbox: BBox
    type: JunctionType
    confidence: float = 1.0
    # When no id is supplied, each instance receives its own UUID4 string.
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
# # ======================== Symbol ======================== #
# class SymbolType(Enum):
# VALVE = "valve"
# PUMP = "pump"
# SENSOR = "sensor"
# # Add others as needed
#
class ValveSubtype(Enum):
    """Enumeration of valve subtypes usable as ``Symbol.subtype``."""

    GATE = "gate"
    GLOBE = "globe"
    BUTTERFLY = "butterfly"
#
# @dataclass
# class Symbol:
# symbol_type: SymbolType
# bbox: BBox
# center: Coordinates
# connections: List[Point] = field(default_factory=list)
# subtype: Optional[ValveSubtype] = None
# id: str = field(default_factory=lambda: str(uuid.uuid4()))
# confidence: float = 0.95
# model_metadata: dict = field(default_factory=dict)
# ======================== Symbol ======================== #
class SymbolType(Enum):
    """Enumeration of symbol categories, with ``OTHER`` as the catch-all."""

    VALVE = "valve"
    PUMP = "pump"
    SENSOR = "sensor"
    OTHER = "other"  # Added to handle unknown categories
@dataclass
class Symbol:
    """A symbol element with its classification, geometry and provenance.

    Only ``center`` is required; every other field has a default.
    """

    center: Coordinates
    symbol_type: SymbolType = SymbolType.OTHER
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    class_id: int = -1
    original_label: str = ""
    category: str = ""  # e.g., "inst"
    type: str = ""  # e.g., "ind"
    label: str = ""  # e.g., "Solenoid_actuator"
    # None until a box is supplied; a 4-item list/tuple is converted to a
    # BBox in __post_init__.
    bbox: Optional[BBox] = None
    confidence: float = 0.95
    model_source: str = ""  # e.g., "model2"
    connections: List[Point] = field(default_factory=list)
    subtype: Optional[ValveSubtype] = None
    model_metadata: dict = field(default_factory=dict)

    def __post_init__(self):
        """Normalize a ``bbox`` given as a 4-item list or tuple into a ``BBox``.

        The four items are passed positionally, i.e. in the order
        (xmin, ymin, xmax, ymax). Any other value (including ``None`` or an
        existing ``BBox``) is left untouched.
        """
        if isinstance(self.bbox, (list, tuple)) and len(self.bbox) == 4:
            self.bbox = BBox(*self.bbox)
# ======================== Line ======================== #
@dataclass
class LineStyle:
    """Visual attributes of a line: connection type, stroke width and color."""

    connection_type: ConnectionType
    stroke_width: int = 2
    color: str = "#000000"  # CSS-style colors
@dataclass
class Line:
    """A line element running from ``start`` to ``end`` with a bounding box."""

    start: Point
    end: Point
    bbox: BBox
    # When no id is supplied, each instance receives its own UUID4 string.
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    # Each instance gets its own solid LineStyle unless one is provided.
    style: LineStyle = field(
        default_factory=lambda: LineStyle(ConnectionType.SOLID)
    )
    confidence: float = 0.90
    topological_links: List[str] = field(default_factory=list)  # Linked symbols/junctions
# ======================== Junction ======================== #
# NOTE(review): JunctionType is also declared earlier in this file (ahead of
# Point); this later declaration is the one bound at module scope afterwards.
class JunctionType(str, Enum):
    """String-valued enumeration of junction kinds: ``T``, ``L`` and ``END``."""

    T = "T"
    L = "L"
    END = "END"
@dataclass
class JunctionProperties:
    """Optional attributes attached to a junction; both default to ``None``."""

    flow_direction: Optional[str] = None  # "in", "out"
    pressure: Optional[float] = None  # kPa
@dataclass
class Junction:
    """A junction element: center, junction kind, properties and line ids."""

    center: Coordinates
    junction_type: JunctionType
    # When no id is supplied, each instance receives its own UUID4 string.
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    # Each instance gets its own JunctionProperties with all-None fields.
    properties: JunctionProperties = field(default_factory=JunctionProperties)
    connected_lines: List[str] = field(default_factory=list)  # Line IDs
# # ======================== Tag ======================== #
# @dataclass
# class Tag:
# text: str
# bbox: BBox
# associated_element: str # ID of linked symbol/line
# id: str = field(default_factory=lambda: str(uuid.uuid4()))
# font_size: int = 12
# rotation: float = 0.0 # Degrees
@dataclass
class Tag:
    """A text element with its bounding box and associated metadata.

    ``text`` and ``bbox`` are required; every other field has a default.
    """

    text: str
    bbox: BBox
    confidence: float = 1.0
    source: str = ""  # e.g., "easyocr"
    text_type: str = "Unknown"  # e.g., "Unknown", could be something else later
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    associated_element: Optional[str] = None  # ID of linked symbol/line (can be None)
    font_size: int = 12
    rotation: float = 0.0  # Degrees

    def __post_init__(self):
        """Normalize a ``bbox`` given as a 4-item list or tuple into a ``BBox``.

        The four items are passed positionally, i.e. in the order
        (xmin, ymin, xmax, ymax). An existing ``BBox`` is left untouched.
        """
        if isinstance(self.bbox, (list, tuple)) and len(self.bbox) == 4:
            self.bbox = BBox(*self.bbox)
# ----------------------------
# DETECTION CONTEXT
# ----------------------------
@dataclass
class DetectionContext:
    """
    In-memory container for all detected elements (lines, points, symbols, junctions, tags).
    Each element is stored in a dict keyed by 'id' for quick lookup and update.
    """

    lines: Dict[str, Line] = field(default_factory=dict)
    points: Dict[str, Point] = field(default_factory=dict)
    symbols: Dict[str, Symbol] = field(default_factory=dict)
    junctions: Dict[str, Junction] = field(default_factory=dict)
    tags: Dict[str, Tag] = field(default_factory=dict)

    # -------------------------
    # 1) ADD / GET / REMOVE
    # -------------------------
    # add_* overwrites any element already stored under the same id,
    # get_* returns None for an unknown id, and remove_* ignores unknown ids.

    def add_line(self, line: Line) -> None:
        """Store ``line`` under its id."""
        self.lines[line.id] = line

    def get_line(self, line_id: str) -> Optional[Line]:
        """Return the line stored under ``line_id``, or ``None``."""
        return self.lines.get(line_id)

    def remove_line(self, line_id: str) -> None:
        """Drop the line stored under ``line_id`` if present."""
        self.lines.pop(line_id, None)

    def add_point(self, point: Point) -> None:
        """Store ``point`` under its id."""
        self.points[point.id] = point

    def get_point(self, point_id: str) -> Optional[Point]:
        """Return the point stored under ``point_id``, or ``None``."""
        return self.points.get(point_id)

    def remove_point(self, point_id: str) -> None:
        """Drop the point stored under ``point_id`` if present."""
        self.points.pop(point_id, None)

    def add_symbol(self, symbol: Symbol) -> None:
        """Store ``symbol`` under its id."""
        self.symbols[symbol.id] = symbol

    def get_symbol(self, symbol_id: str) -> Optional[Symbol]:
        """Return the symbol stored under ``symbol_id``, or ``None``."""
        return self.symbols.get(symbol_id)

    def remove_symbol(self, symbol_id: str) -> None:
        """Drop the symbol stored under ``symbol_id`` if present."""
        self.symbols.pop(symbol_id, None)

    def add_junction(self, junction: Junction) -> None:
        """Store ``junction`` under its id."""
        self.junctions[junction.id] = junction

    def get_junction(self, junction_id: str) -> Optional[Junction]:
        """Return the junction stored under ``junction_id``, or ``None``."""
        return self.junctions.get(junction_id)

    def remove_junction(self, junction_id: str) -> None:
        """Drop the junction stored under ``junction_id`` if present."""
        self.junctions.pop(junction_id, None)

    def add_tag(self, tag: Tag) -> None:
        """Store ``tag`` under its id."""
        self.tags[tag.id] = tag

    def get_tag(self, tag_id: str) -> Optional[Tag]:
        """Return the tag stored under ``tag_id``, or ``None``."""
        return self.tags.get(tag_id)

    def remove_tag(self, tag_id: str) -> None:
        """Drop the tag stored under ``tag_id`` if present."""
        self.tags.pop(tag_id, None)

    # -------------------------
    # 2) SERIALIZATION: to_dict / from_dict
    # -------------------------
    def to_dict(self) -> dict:
        """Convert all stored objects into a dictionary of plain lists/dicts.

        The result has one list per element kind under the keys ``"lines"``,
        ``"points"``, ``"symbols"``, ``"junctions"`` and ``"tags"``.
        """
        return {
            "lines": [self._line_to_dict(line) for line in self.lines.values()],
            "points": [self._point_to_dict(pt) for pt in self.points.values()],
            "symbols": [self._symbol_to_dict(sym) for sym in self.symbols.values()],
            "junctions": [self._junction_to_dict(jn) for jn in self.junctions.values()],
            "tags": [self._tag_to_dict(tg) for tg in self.tags.values()],
        }

    @classmethod
    def from_dict(cls, data: dict) -> "DetectionContext":
        """
        Create a new DetectionContext from a dictionary structure (e.g. loaded from JSON).

        Missing top-level keys are treated as empty lists. Elements are
        loaded in the order points, lines, symbols, junctions, tags.
        """
        context = cls()
        loaders = (
            ("points", cls._point_from_dict, context.add_point),
            ("lines", cls._line_from_dict, context.add_line),
            ("symbols", cls._symbol_from_dict, context.add_symbol),
            ("junctions", cls._junction_from_dict, context.add_junction),
            ("tags", cls._tag_from_dict, context.add_tag),
        )
        for key, build, store in loaders:
            for item in data.get(key, []):
                store(build(item))
        return context

    # -------------------------
    # 3) HELPER METHODS FOR (DE)SERIALIZATION
    # -------------------------
    @staticmethod
    def _bbox_to_dict(bbox: BBox) -> dict:
        """Return the four box values keyed by field name."""
        return {
            "xmin": bbox.xmin,
            "ymin": bbox.ymin,
            "xmax": bbox.xmax,
            "ymax": bbox.ymax,
        }

    @staticmethod
    def _bbox_from_dict(d: dict) -> BBox:
        """Build a ``BBox``; all four keys are required (``KeyError`` otherwise)."""
        return BBox(
            xmin=d["xmin"],
            ymin=d["ymin"],
            xmax=d["xmax"],
            ymax=d["ymax"],
        )

    @staticmethod
    def _coords_to_dict(coords: Coordinates) -> dict:
        """Return ``{"x": ..., "y": ...}`` for ``coords``."""
        return {
            "x": coords.x,
            "y": coords.y,
        }

    @staticmethod
    def _coords_from_dict(d: dict) -> Coordinates:
        """Build ``Coordinates``; both keys are required."""
        return Coordinates(x=d["x"], y=d["y"])

    @staticmethod
    def _line_style_to_dict(style: LineStyle) -> dict:
        """Serialize a ``LineStyle``, storing the enum by its value."""
        return {
            "connection_type": style.connection_type.value,
            "stroke_width": style.stroke_width,
            "color": style.color,
        }

    @staticmethod
    def _line_style_from_dict(d: dict) -> LineStyle:
        """Build a ``LineStyle``; only ``connection_type`` is required."""
        return LineStyle(
            connection_type=ConnectionType(d["connection_type"]),
            stroke_width=d.get("stroke_width", 2),
            color=d.get("color", "#000000"),
        )

    @staticmethod
    def _point_to_dict(pt: Point) -> dict:
        """Serialize a ``Point``, storing its junction type by value."""
        return {
            "id": pt.id,
            "coords": DetectionContext._coords_to_dict(pt.coords),
            "bbox": DetectionContext._bbox_to_dict(pt.bbox),
            "type": pt.type.value,
            "confidence": pt.confidence,
        }

    @staticmethod
    def _point_from_dict(d: dict) -> Point:
        """Build a ``Point``; ``confidence`` falls back to 1.0 when absent."""
        return Point(
            id=d["id"],
            coords=DetectionContext._coords_from_dict(d["coords"]),
            bbox=DetectionContext._bbox_from_dict(d["bbox"]),
            type=JunctionType(d["type"]),
            confidence=d.get("confidence", 1.0),
        )

    @staticmethod
    def _line_to_dict(ln: Line) -> dict:
        """Serialize a ``Line`` with its end points embedded in full."""
        return {
            "id": ln.id,
            "start": DetectionContext._point_to_dict(ln.start),
            "end": DetectionContext._point_to_dict(ln.end),
            "bbox": DetectionContext._bbox_to_dict(ln.bbox),
            "style": DetectionContext._line_style_to_dict(ln.style),
            "confidence": ln.confidence,
            "topological_links": ln.topological_links,
        }

    @staticmethod
    def _line_from_dict(d: dict) -> Line:
        """Build a ``Line``; ``confidence`` and ``topological_links`` are optional."""
        return Line(
            id=d["id"],
            start=DetectionContext._point_from_dict(d["start"]),
            end=DetectionContext._point_from_dict(d["end"]),
            bbox=DetectionContext._bbox_from_dict(d["bbox"]),
            style=DetectionContext._line_style_from_dict(d["style"]),
            confidence=d.get("confidence", 0.90),
            topological_links=d.get("topological_links", []),
        )

    @staticmethod
    def _symbol_to_dict(sym: Symbol) -> dict:
        """Serialize every field of a ``Symbol``.

        ``bbox`` is emitted as ``None`` when the symbol has no box (the
        dataclass default), instead of failing on the missing attributes.
        """
        bbox = sym.bbox
        return {
            "id": sym.id,
            "symbol_type": sym.symbol_type.value,
            "bbox": DetectionContext._bbox_to_dict(bbox) if bbox is not None else None,
            "center": DetectionContext._coords_to_dict(sym.center),
            "connections": [DetectionContext._point_to_dict(p) for p in sym.connections],
            "subtype": sym.subtype.value if sym.subtype else None,
            "confidence": sym.confidence,
            "model_metadata": sym.model_metadata,
            # Classification/provenance fields, kept so a round trip is lossless.
            "class_id": sym.class_id,
            "original_label": sym.original_label,
            "category": sym.category,
            "type": sym.type,
            "label": sym.label,
            "model_source": sym.model_source,
        }

    @staticmethod
    def _symbol_from_dict(d: dict) -> Symbol:
        """Build a ``Symbol`` from its dictionary form.

        ``id``, ``symbol_type`` and ``center`` are required. A missing or
        ``None`` ``bbox`` yields a symbol without a box, and all remaining
        keys fall back to the ``Symbol`` field defaults.
        """
        bbox_data = d.get("bbox")
        subtype_value = d.get("subtype")
        return Symbol(
            id=d["id"],
            symbol_type=SymbolType(d["symbol_type"]),
            bbox=DetectionContext._bbox_from_dict(bbox_data) if bbox_data is not None else None,
            center=DetectionContext._coords_from_dict(d["center"]),
            connections=[DetectionContext._point_from_dict(p) for p in d.get("connections", [])],
            subtype=ValveSubtype(subtype_value) if subtype_value else None,
            confidence=d.get("confidence", 0.95),
            model_metadata=d.get("model_metadata", {}),
            class_id=d.get("class_id", -1),
            original_label=d.get("original_label", ""),
            category=d.get("category", ""),
            type=d.get("type", ""),
            label=d.get("label", ""),
            model_source=d.get("model_source", ""),
        )

    @staticmethod
    def _junction_props_to_dict(props: JunctionProperties) -> dict:
        """Serialize ``JunctionProperties`` field by field."""
        return {
            "flow_direction": props.flow_direction,
            "pressure": props.pressure,
        }

    @staticmethod
    def _junction_props_from_dict(d: dict) -> JunctionProperties:
        """Build ``JunctionProperties``; absent keys become ``None``."""
        return JunctionProperties(
            flow_direction=d.get("flow_direction"),
            pressure=d.get("pressure"),
        )

    @staticmethod
    def _junction_to_dict(jn: Junction) -> dict:
        """Serialize a ``Junction``, storing its type by value."""
        return {
            "id": jn.id,
            "center": DetectionContext._coords_to_dict(jn.center),
            "junction_type": jn.junction_type.value,
            "properties": DetectionContext._junction_props_to_dict(jn.properties),
            "connected_lines": jn.connected_lines,
        }

    @staticmethod
    def _junction_from_dict(d: dict) -> Junction:
        """Build a ``Junction``; only ``connected_lines`` is optional."""
        return Junction(
            id=d["id"],
            center=DetectionContext._coords_from_dict(d["center"]),
            junction_type=JunctionType(d["junction_type"]),
            properties=DetectionContext._junction_props_from_dict(d["properties"]),
            connected_lines=d.get("connected_lines", []),
        )

    @staticmethod
    def _tag_to_dict(tg: Tag) -> dict:
        """Serialize every field of a ``Tag``."""
        return {
            "id": tg.id,
            "text": tg.text,
            "bbox": DetectionContext._bbox_to_dict(tg.bbox),
            "associated_element": tg.associated_element,
            "font_size": tg.font_size,
            "rotation": tg.rotation,
            # Metadata fields, kept so a round trip is lossless.
            "confidence": tg.confidence,
            "source": tg.source,
            "text_type": tg.text_type,
        }

    @staticmethod
    def _tag_from_dict(d: dict) -> Tag:
        """Build a ``Tag`` from its dictionary form.

        ``id``, ``text`` and ``bbox`` are required; all remaining keys fall
        back to the ``Tag`` field defaults when absent.
        """
        return Tag(
            id=d["id"],
            text=d["text"],
            bbox=DetectionContext._bbox_from_dict(d["bbox"]),
            associated_element=d.get("associated_element"),
            font_size=d.get("font_size", 12),
            rotation=d.get("rotation", 0.0),
            confidence=d.get("confidence", 1.0),
            source=d.get("source", ""),
            text_type=d.get("text_type", "Unknown"),
        )

    # -------------------------
    # 4) OPTIONAL UTILS
    # -------------------------
    def to_json(self, indent: int = 2) -> str:
        """Return ``to_dict()`` rendered as JSON text.

        Values that ``json`` cannot encode natively are passed through
        ``_json_serializer``.
        """
        return json.dumps(self.to_dict(), default=self._json_serializer, indent=indent)

    @staticmethod
    def _json_serializer(obj):
        """Fallback encoder for ``json.dumps``.

        Converts numpy integers, floats, booleans and arrays, ``Enum``
        members (to their value) and any object exposing ``__dict__``.

        Raises:
            TypeError: if ``obj`` matches none of the handled kinds.
        """
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()  # Convert arrays to lists
        if isinstance(obj, Enum):
            return obj.value  # Convert Enums to their underlying values
        if hasattr(obj, "__dict__"):
            return obj.__dict__  # Convert dataclass-like objects to dict
        raise TypeError(f"Object of type {type(obj)} is not JSON serializable")

    @classmethod
    def from_json(cls, json_str: str) -> "DetectionContext":
        """Load DetectionContext from a JSON string."""
        return cls.from_dict(json.loads(json_str))