"""Data model and in-memory container for detected elements.

Defines the dataclasses for points, symbols, lines, junctions and tags, plus
``DetectionContext``, which stores them keyed by id and converts them to and
from plain dictionaries / JSON.
"""

from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Dict
import uuid
from enum import Enum
import json
import numpy as np


def _new_id() -> str:
    """Return a fresh random UUID4 string used as the default element id."""
    return str(uuid.uuid4())


# ======================== Point ======================== #
class ConnectionType(Enum):
    """Drawing style of a line connection."""

    SOLID = "solid"
    DASHED = "dashed"
    PHANTOM = "phantom"


@dataclass
class Coordinates:
    """A 2-D position."""

    x: int
    y: int


@dataclass
class BBox:
    """Axis-aligned bounding box given by its min/max corners."""

    xmin: int
    ymin: int
    xmax: int
    ymax: int

    def width(self) -> int:
        """Return ``xmax - xmin``."""
        return self.xmax - self.xmin

    def height(self) -> int:
        """Return ``ymax - ymin``."""
        return self.ymax - self.ymin


def _coerce_bbox(value):
    """Return *value* as a ``BBox`` when it is a 4-item list or tuple.

    Any other value (including ``None`` and existing ``BBox`` instances) is
    returned unchanged.
    """
    if isinstance(value, (list, tuple)) and len(value) == 4:
        return BBox(*value)
    return value


# NOTE: JunctionType is defined once, here, because Point needs it at class
# creation time; Junction (below) reuses the same enum.
class JunctionType(str, Enum):
    """Kind of junction / point; members compare equal to their string value."""

    T = "T"
    L = "L"
    END = "END"


@dataclass
class Point:
    """A detected point with its position, bounding box and junction type."""

    coords: Coordinates
    bbox: BBox
    type: JunctionType
    confidence: float = 1.0
    id: str = field(default_factory=_new_id)


# ======================== Symbol ======================== #
class ValveSubtype(Enum):
    """Optional finer-grained classification stored in ``Symbol.subtype``."""

    GATE = "gate"
    GLOBE = "globe"
    BUTTERFLY = "butterfly"


class SymbolType(Enum):
    """Coarse symbol category."""

    VALVE = "valve"
    PUMP = "pump"
    SENSOR = "sensor"
    OTHER = "other"  # Added to handle unknown categories


@dataclass
class Symbol:
    """A detected symbol.

    ``bbox`` may be given as a ``BBox``, as a 4-item list/tuple (converted to
    ``BBox`` after initialization), or left as ``None``.
    """

    center: Coordinates
    symbol_type: SymbolType = field(default=SymbolType.OTHER)
    id: str = field(default_factory=_new_id)
    class_id: int = -1
    original_label: str = ""
    category: str = ""  # e.g., "inst"
    type: str = ""  # e.g., "ind"
    label: str = ""  # e.g., "Solenoid_actuator"
    bbox: Optional[BBox] = None
    confidence: float = 0.95
    model_source: str = ""  # e.g., "model2"
    connections: List[Point] = field(default_factory=list)
    subtype: Optional[ValveSubtype] = None
    model_metadata: dict = field(default_factory=dict)

    def __post_init__(self):
        """Convert a 4-item list/tuple ``bbox`` into a ``BBox`` object."""
        self.bbox = _coerce_bbox(self.bbox)


# ======================== Line ======================== #
@dataclass
class LineStyle:
    """Visual style of a line."""

    connection_type: ConnectionType
    stroke_width: int = 2
    color: str = "#000000"  # CSS-style colors


@dataclass
class Line:
    """A detected line between two points."""

    start: Point
    end: Point
    bbox: BBox
    id: str = field(default_factory=_new_id)
    style: LineStyle = field(default_factory=lambda: LineStyle(ConnectionType.SOLID))
    confidence: float = 0.90
    topological_links: List[str] = field(default_factory=list)  # Linked symbols/junctions


# ======================== Junction ======================== #
@dataclass
class JunctionProperties:
    """Optional attributes attached to a junction."""

    flow_direction: Optional[str] = None  # "in", "out"
    pressure: Optional[float] = None  # kPa


@dataclass
class Junction:
    """A junction with its position, type and the ids of connected lines."""

    center: Coordinates
    junction_type: JunctionType
    id: str = field(default_factory=_new_id)
    properties: JunctionProperties = field(default_factory=JunctionProperties)
    connected_lines: List[str] = field(default_factory=list)  # Line IDs


# ======================== Tag ======================== #
@dataclass
class Tag:
    """A detected text tag.

    ``bbox`` may be given as a ``BBox`` or as a 4-item list/tuple, which is
    converted to ``BBox`` after initialization.
    """

    text: str
    bbox: BBox
    confidence: float = 1.0
    source: str = ""  # e.g., "easyocr"
    text_type: str = "Unknown"  # e.g., "Unknown", could be something else later
    id: str = field(default_factory=_new_id)
    associated_element: Optional[str] = None  # ID of linked symbol/line (can be None)
    font_size: int = 12
    rotation: float = 0.0  # Degrees

    def __post_init__(self):
        """Convert a 4-item list/tuple ``bbox`` into a ``BBox`` object."""
        self.bbox = _coerce_bbox(self.bbox)


# ----------------------------
# DETECTION CONTEXT
# ----------------------------
@dataclass
class DetectionContext:
    """
    In-memory container for all detected elements (lines, points, symbols, junctions, tags).

    Each element is stored in a dict keyed by 'id' for quick lookup and update.
    """

    lines: Dict[str, Line] = field(default_factory=dict)
    points: Dict[str, Point] = field(default_factory=dict)
    symbols: Dict[str, Symbol] = field(default_factory=dict)
    junctions: Dict[str, Junction] = field(default_factory=dict)
    tags: Dict[str, Tag] = field(default_factory=dict)

    # -------------------------
    # 1) ADD / GET / REMOVE
    # -------------------------
    # add_* stores (or overwrites) the element under its id; get_* returns
    # None for an unknown id; remove_* is a no-op for an unknown id.
    def add_line(self, line: Line) -> None:
        self.lines[line.id] = line

    def get_line(self, line_id: str) -> Optional[Line]:
        return self.lines.get(line_id)

    def remove_line(self, line_id: str) -> None:
        self.lines.pop(line_id, None)

    def add_point(self, point: Point) -> None:
        self.points[point.id] = point

    def get_point(self, point_id: str) -> Optional[Point]:
        return self.points.get(point_id)

    def remove_point(self, point_id: str) -> None:
        self.points.pop(point_id, None)

    def add_symbol(self, symbol: Symbol) -> None:
        self.symbols[symbol.id] = symbol

    def get_symbol(self, symbol_id: str) -> Optional[Symbol]:
        return self.symbols.get(symbol_id)

    def remove_symbol(self, symbol_id: str) -> None:
        self.symbols.pop(symbol_id, None)

    def add_junction(self, junction: Junction) -> None:
        self.junctions[junction.id] = junction

    def get_junction(self, junction_id: str) -> Optional[Junction]:
        return self.junctions.get(junction_id)

    def remove_junction(self, junction_id: str) -> None:
        self.junctions.pop(junction_id, None)

    def add_tag(self, tag: Tag) -> None:
        self.tags[tag.id] = tag

    def get_tag(self, tag_id: str) -> Optional[Tag]:
        return self.tags.get(tag_id)

    def remove_tag(self, tag_id: str) -> None:
        self.tags.pop(tag_id, None)

    # -------------------------
    # 2) SERIALIZATION: to_dict / from_dict
    # -------------------------
    def to_dict(self) -> dict:
        """Convert all stored objects into a dictionary of plain containers.

        Values such as numpy scalars are passed through unchanged; ``to_json``
        converts them while encoding.
        """
        return {
            "lines": [self._line_to_dict(line) for line in self.lines.values()],
            "points": [self._point_to_dict(pt) for pt in self.points.values()],
            "symbols": [self._symbol_to_dict(sym) for sym in self.symbols.values()],
            "junctions": [self._junction_to_dict(jn) for jn in self.junctions.values()],
            "tags": [self._tag_to_dict(tg) for tg in self.tags.values()],
        }

    @classmethod
    def from_dict(cls, data: dict) -> "DetectionContext":
        """
        Create a new DetectionContext from a dictionary structure (e.g. loaded from JSON).

        Missing top-level keys are treated as empty collections.
        """
        context = cls()

        for pt_dict in data.get("points", []):
            context.add_point(cls._point_from_dict(pt_dict))

        for ln_dict in data.get("lines", []):
            context.add_line(cls._line_from_dict(ln_dict))

        for sym_dict in data.get("symbols", []):
            context.add_symbol(cls._symbol_from_dict(sym_dict))

        for jn_dict in data.get("junctions", []):
            context.add_junction(cls._junction_from_dict(jn_dict))

        for tg_dict in data.get("tags", []):
            context.add_tag(cls._tag_from_dict(tg_dict))

        return context

    # -------------------------
    # 3) HELPER METHODS FOR (DE)SERIALIZATION
    # -------------------------
    @staticmethod
    def _bbox_to_dict(bbox: BBox) -> dict:
        return {
            "xmin": bbox.xmin,
            "ymin": bbox.ymin,
            "xmax": bbox.xmax,
            "ymax": bbox.ymax,
        }

    @staticmethod
    def _bbox_from_dict(d: dict) -> BBox:
        return BBox(
            xmin=d["xmin"],
            ymin=d["ymin"],
            xmax=d["xmax"],
            ymax=d["ymax"],
        )

    @staticmethod
    def _coords_to_dict(coords: Coordinates) -> dict:
        return {"x": coords.x, "y": coords.y}

    @staticmethod
    def _coords_from_dict(d: dict) -> Coordinates:
        return Coordinates(x=d["x"], y=d["y"])

    @staticmethod
    def _line_style_to_dict(style: LineStyle) -> dict:
        return {
            "connection_type": style.connection_type.value,
            "stroke_width": style.stroke_width,
            "color": style.color,
        }

    @staticmethod
    def _line_style_from_dict(d: dict) -> LineStyle:
        return LineStyle(
            connection_type=ConnectionType(d["connection_type"]),
            stroke_width=d.get("stroke_width", 2),
            color=d.get("color", "#000000"),
        )

    @staticmethod
    def _point_to_dict(pt: Point) -> dict:
        return {
            "id": pt.id,
            "coords": DetectionContext._coords_to_dict(pt.coords),
            "bbox": DetectionContext._bbox_to_dict(pt.bbox),
            "type": pt.type.value,
            "confidence": pt.confidence,
        }

    @staticmethod
    def _point_from_dict(d: dict) -> Point:
        return Point(
            id=d["id"],
            coords=DetectionContext._coords_from_dict(d["coords"]),
            bbox=DetectionContext._bbox_from_dict(d["bbox"]),
            type=JunctionType(d["type"]),
            confidence=d.get("confidence", 1.0),
        )

    @staticmethod
    def _line_to_dict(ln: Line) -> dict:
        return {
            "id": ln.id,
            "start": DetectionContext._point_to_dict(ln.start),
            "end": DetectionContext._point_to_dict(ln.end),
            "bbox": DetectionContext._bbox_to_dict(ln.bbox),
            "style": DetectionContext._line_style_to_dict(ln.style),
            "confidence": ln.confidence,
            # Copy so that editing the exported dict cannot mutate the Line.
            "topological_links": list(ln.topological_links or []),
        }

    @staticmethod
    def _line_from_dict(d: dict) -> Line:
        style_data = d.get("style")
        # Fall back to the same default style the Line dataclass uses.
        style = (
            DetectionContext._line_style_from_dict(style_data)
            if style_data is not None
            else LineStyle(ConnectionType.SOLID)
        )
        return Line(
            id=d["id"],
            start=DetectionContext._point_from_dict(d["start"]),
            end=DetectionContext._point_from_dict(d["end"]),
            bbox=DetectionContext._bbox_from_dict(d["bbox"]),
            style=style,
            confidence=d.get("confidence", 0.90),
            # Copy so the Line does not share a list with the input dict.
            topological_links=list(d.get("topological_links") or []),
        )

    @staticmethod
    def _symbol_to_dict(sym: Symbol) -> dict:
        """Serialize every Symbol field; ``bbox`` is ``None`` when unset."""
        return {
            "id": sym.id,
            "symbol_type": sym.symbol_type.value,
            "class_id": sym.class_id,
            "original_label": sym.original_label,
            "category": sym.category,
            "type": sym.type,
            "label": sym.label,
            "model_source": sym.model_source,
            # Symbol.bbox defaults to None, so it must not be dereferenced blindly.
            "bbox": (
                DetectionContext._bbox_to_dict(sym.bbox)
                if sym.bbox is not None
                else None
            ),
            "center": DetectionContext._coords_to_dict(sym.center),
            "connections": [DetectionContext._point_to_dict(p) for p in sym.connections],
            "subtype": sym.subtype.value if sym.subtype else None,
            "confidence": sym.confidence,
            "model_metadata": sym.model_metadata,
        }

    @staticmethod
    def _symbol_from_dict(d: dict) -> Symbol:
        """Rebuild a Symbol; keys absent from older payloads use field defaults."""
        bbox_data = d.get("bbox")
        return Symbol(
            id=d["id"],
            symbol_type=SymbolType(d["symbol_type"]),
            class_id=d.get("class_id", -1),
            original_label=d.get("original_label", ""),
            category=d.get("category", ""),
            type=d.get("type", ""),
            label=d.get("label", ""),
            model_source=d.get("model_source", ""),
            bbox=(
                DetectionContext._bbox_from_dict(bbox_data)
                if bbox_data is not None
                else None
            ),
            center=DetectionContext._coords_from_dict(d["center"]),
            connections=[
                DetectionContext._point_from_dict(p) for p in d.get("connections") or []
            ],
            subtype=ValveSubtype(d["subtype"]) if d.get("subtype") else None,
            confidence=d.get("confidence", 0.95),
            # Shallow copy so the Symbol does not share a dict with the input.
            model_metadata=dict(d.get("model_metadata") or {}),
        )

    @staticmethod
    def _junction_props_to_dict(props: JunctionProperties) -> dict:
        return {
            "flow_direction": props.flow_direction,
            "pressure": props.pressure,
        }

    @staticmethod
    def _junction_props_from_dict(d: dict) -> JunctionProperties:
        return JunctionProperties(
            flow_direction=d.get("flow_direction"),
            pressure=d.get("pressure"),
        )

    @staticmethod
    def _junction_to_dict(jn: Junction) -> dict:
        return {
            "id": jn.id,
            "center": DetectionContext._coords_to_dict(jn.center),
            "junction_type": jn.junction_type.value,
            "properties": DetectionContext._junction_props_to_dict(jn.properties),
            # Copy so that editing the exported dict cannot mutate the Junction.
            "connected_lines": list(jn.connected_lines or []),
        }

    @staticmethod
    def _junction_from_dict(d: dict) -> Junction:
        return Junction(
            id=d["id"],
            center=DetectionContext._coords_from_dict(d["center"]),
            junction_type=JunctionType(d["junction_type"]),
            # A missing/None "properties" entry yields default JunctionProperties.
            properties=DetectionContext._junction_props_from_dict(
                d.get("properties") or {}
            ),
            connected_lines=list(d.get("connected_lines") or []),
        )

    @staticmethod
    def _tag_to_dict(tg: Tag) -> dict:
        """Serialize every Tag field."""
        return {
            "id": tg.id,
            "text": tg.text,
            "bbox": DetectionContext._bbox_to_dict(tg.bbox),
            "confidence": tg.confidence,
            "source": tg.source,
            "text_type": tg.text_type,
            "associated_element": tg.associated_element,
            "font_size": tg.font_size,
            "rotation": tg.rotation,
        }

    @staticmethod
    def _tag_from_dict(d: dict) -> Tag:
        """Rebuild a Tag; keys absent from older payloads use field defaults."""
        return Tag(
            id=d["id"],
            text=d["text"],
            bbox=DetectionContext._bbox_from_dict(d["bbox"]),
            confidence=d.get("confidence", 1.0),
            source=d.get("source", ""),
            text_type=d.get("text_type", "Unknown"),
            associated_element=d.get("associated_element"),
            font_size=d.get("font_size", 12),
            rotation=d.get("rotation", 0.0),
        )

    # -------------------------
    # 4) OPTIONAL UTILS
    # -------------------------
    def to_json(self, indent: int = 2) -> str:
        """Convert the context to a JSON string.

        Values the standard encoder rejects (numpy scalars/arrays, enums,
        objects with a ``__dict__``) are converted by ``_json_serializer``.
        """
        return json.dumps(self.to_dict(), default=self._json_serializer, indent=indent)

    @staticmethod
    def _json_serializer(obj):
        """``json.dumps`` fallback for numpy types, enums and plain objects.

        Raises:
            TypeError: if *obj* is none of the supported kinds.
        """
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()  # Convert arrays to lists
        if isinstance(obj, Enum):
            return obj.value  # Checked before __dict__: enum members have one too
        if hasattr(obj, "__dict__"):
            return obj.__dict__  # Any object exposing its attributes as a dict
        raise TypeError(f"Object of type {type(obj)} is not JSON serializable")

    @classmethod
    def from_json(cls, json_str: str) -> "DetectionContext":
        """Load DetectionContext from a JSON string."""
        data = json.loads(json_str)
        return cls.from_dict(data)