Spaces:
Build error
Build error
| from dataclasses import dataclass, field | |
| from typing import List, Optional, Tuple, Dict | |
| import uuid | |
| from enum import Enum | |
| import json | |
| import numpy as np | |
| # ======================== Point ======================== # | |
class ConnectionType(Enum):
    """Line style categories used by ``LineStyle.connection_type``."""

    SOLID = "solid"
    DASHED = "dashed"
    PHANTOM = "phantom"
@dataclass
class Coordinates:
    """A 2-D integer position given by its ``x`` and ``y`` components."""

    x: int
    y: int
@dataclass
class BBox:
    """Axis-aligned bounding box described by its min/max corner values."""

    xmin: int
    ymin: int
    xmax: int
    ymax: int

    def width(self) -> int:
        """Return the horizontal extent, ``xmax - xmin``."""
        return self.xmax - self.xmin

    def height(self) -> int:
        """Return the vertical extent, ``ymax - ymin``."""
        return self.ymax - self.ymin
class JunctionType(str, Enum):
    """Junction categories; members are also ``str`` instances."""

    T = "T"
    L = "L"
    END = "END"
@dataclass
class Point:
    """A detected point: position, bounding box, junction type and confidence.

    ``id`` defaults to a freshly generated UUID4 string when not supplied.
    """

    coords: Coordinates
    bbox: BBox
    type: JunctionType
    confidence: float = 1.0
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
| # # ======================== Symbol ======================== # | |
| # class SymbolType(Enum): | |
| # VALVE = "valve" | |
| # PUMP = "pump" | |
| # SENSOR = "sensor" | |
| # # Add others as needed | |
| # | |
class ValveSubtype(Enum):
    """Valve sub-categories referenced by ``Symbol.subtype``."""

    GATE = "gate"
    GLOBE = "globe"
    BUTTERFLY = "butterfly"
| # | |
| # @dataclass | |
| # class Symbol: | |
| # symbol_type: SymbolType | |
| # bbox: BBox | |
| # center: Coordinates | |
| # connections: List[Point] = field(default_factory=list) | |
| # subtype: Optional[ValveSubtype] = None | |
| # id: str = field(default_factory=lambda: str(uuid.uuid4())) | |
| # confidence: float = 0.95 | |
| # model_metadata: dict = field(default_factory=dict) | |
| # ======================== Symbol ======================== # | |
class SymbolType(Enum):
    """Top-level symbol categories; ``OTHER`` covers unknown categories."""

    VALVE = "valve"
    PUMP = "pump"
    SENSOR = "sensor"
    OTHER = "other"  # Added to handle unknown categories
@dataclass
class Symbol:
    """A detected symbol with its classification, geometry and connections.

    Only ``center`` is required; every other field has a default. ``bbox``
    may be passed as a ``BBox`` or as a 4-element list/tuple, which is
    unpacked positionally into ``BBox(xmin, ymin, xmax, ymax)``.
    """

    center: Coordinates
    symbol_type: SymbolType = field(default=SymbolType.OTHER)
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    class_id: int = -1
    original_label: str = ""
    category: str = ""  # e.g., "inst"
    type: str = ""  # e.g., "ind"
    label: str = ""  # e.g., "Solenoid_actuator"
    bbox: Optional[BBox] = None  # None until a box is supplied
    confidence: float = 0.95
    model_source: str = ""  # e.g., "model2"
    connections: List[Point] = field(default_factory=list)
    subtype: Optional[ValveSubtype] = None
    model_metadata: dict = field(default_factory=dict)

    def __post_init__(self):
        """Normalize ``bbox`` to a ``BBox`` when given as a 4-item sequence."""
        # Tuples are accepted alongside lists so either form is converted.
        if isinstance(self.bbox, (list, tuple)) and len(self.bbox) == 4:
            self.bbox = BBox(*self.bbox)
| # ======================== Line ======================== # | |
@dataclass
class LineStyle:
    """Drawing attributes of a line: connection type, stroke width and color."""

    connection_type: ConnectionType
    stroke_width: int = 2
    color: str = "#000000"  # CSS-style colors
@dataclass
class Line:
    """A detected line between two points, with bounding box and style.

    ``id`` defaults to a new UUID4 string and ``style`` to a solid
    ``LineStyle``; both factories run per instance.
    """

    start: Point
    end: Point
    bbox: BBox
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    style: LineStyle = field(default_factory=lambda: LineStyle(ConnectionType.SOLID))
    confidence: float = 0.90
    topological_links: List[str] = field(default_factory=list)  # Linked symbols/junctions
| # ======================== Junction ======================== # | |
# NOTE(review): an identical JunctionType is already defined earlier in this
# file; this second definition rebinds the module-level name to a new class.
# Consider removing one of the two once callers are confirmed.
class JunctionType(str, Enum):
    """Junction categories; members are also ``str`` instances."""

    T = "T"
    L = "L"
    END = "END"
@dataclass
class JunctionProperties:
    """Optional attributes attached to a junction; both default to ``None``."""

    flow_direction: Optional[str] = None  # "in", "out"
    pressure: Optional[float] = None  # kPa
@dataclass
class Junction:
    """A junction located at ``center`` with a type, properties and line IDs.

    ``id`` defaults to a new UUID4 string; ``properties`` and
    ``connected_lines`` get fresh per-instance defaults.
    """

    center: Coordinates
    junction_type: JunctionType
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    properties: JunctionProperties = field(default_factory=JunctionProperties)
    connected_lines: List[str] = field(default_factory=list)  # Line IDs
| # # ======================== Tag ======================== # | |
| # @dataclass | |
| # class Tag: | |
| # text: str | |
| # bbox: BBox | |
| # associated_element: str # ID of linked symbol/line | |
| # id: str = field(default_factory=lambda: str(uuid.uuid4())) | |
| # font_size: int = 12 | |
| # rotation: float = 0.0 # Degrees | |
@dataclass
class Tag:
    """A detected text tag with its bounding box and optional linked element.

    ``bbox`` may be passed as a ``BBox`` or as a 4-element list/tuple, which
    is unpacked positionally into ``BBox(xmin, ymin, xmax, ymax)``.
    """

    text: str
    bbox: BBox
    confidence: float = 1.0
    source: str = ""  # e.g., "easyocr"
    text_type: str = "Unknown"  # e.g., "Unknown", could be something else later
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    associated_element: Optional[str] = None  # ID of linked symbol/line (can be None)
    font_size: int = 12
    rotation: float = 0.0  # Degrees

    def __post_init__(self):
        """Normalize ``bbox`` to a ``BBox`` when given as a 4-item sequence."""
        # Tuples are accepted alongside lists so either form is converted.
        if isinstance(self.bbox, (list, tuple)) and len(self.bbox) == 4:
            self.bbox = BBox(*self.bbox)
| # ---------------------------- | |
| # DETECTION CONTEXT | |
| # ---------------------------- | |
@dataclass
class DetectionContext:
    """
    In-memory container for all detected elements (lines, points, symbols, junctions, tags).
    Each element is stored in a dict keyed by 'id' for quick lookup and update.
    """

    lines: Dict[str, Line] = field(default_factory=dict)
    points: Dict[str, Point] = field(default_factory=dict)
    symbols: Dict[str, Symbol] = field(default_factory=dict)
    junctions: Dict[str, Junction] = field(default_factory=dict)
    tags: Dict[str, Tag] = field(default_factory=dict)

    # -------------------------
    # 1) ADD / GET / REMOVE
    # -------------------------
    # add_* stores under the element's id (overwriting an existing entry),
    # get_* returns None for an unknown id, remove_* ignores an unknown id.

    def add_line(self, line: Line) -> None:
        """Store ``line`` under ``line.id``."""
        self.lines[line.id] = line

    def get_line(self, line_id: str) -> Optional[Line]:
        """Return the line with ``line_id``, or ``None`` if absent."""
        return self.lines.get(line_id)

    def remove_line(self, line_id: str) -> None:
        """Drop the line with ``line_id`` if present."""
        self.lines.pop(line_id, None)

    def add_point(self, point: Point) -> None:
        """Store ``point`` under ``point.id``."""
        self.points[point.id] = point

    def get_point(self, point_id: str) -> Optional[Point]:
        """Return the point with ``point_id``, or ``None`` if absent."""
        return self.points.get(point_id)

    def remove_point(self, point_id: str) -> None:
        """Drop the point with ``point_id`` if present."""
        self.points.pop(point_id, None)

    def add_symbol(self, symbol: Symbol) -> None:
        """Store ``symbol`` under ``symbol.id``."""
        self.symbols[symbol.id] = symbol

    def get_symbol(self, symbol_id: str) -> Optional[Symbol]:
        """Return the symbol with ``symbol_id``, or ``None`` if absent."""
        return self.symbols.get(symbol_id)

    def remove_symbol(self, symbol_id: str) -> None:
        """Drop the symbol with ``symbol_id`` if present."""
        self.symbols.pop(symbol_id, None)

    def add_junction(self, junction: Junction) -> None:
        """Store ``junction`` under ``junction.id``."""
        self.junctions[junction.id] = junction

    def get_junction(self, junction_id: str) -> Optional[Junction]:
        """Return the junction with ``junction_id``, or ``None`` if absent."""
        return self.junctions.get(junction_id)

    def remove_junction(self, junction_id: str) -> None:
        """Drop the junction with ``junction_id`` if present."""
        self.junctions.pop(junction_id, None)

    def add_tag(self, tag: Tag) -> None:
        """Store ``tag`` under ``tag.id``."""
        self.tags[tag.id] = tag

    def get_tag(self, tag_id: str) -> Optional[Tag]:
        """Return the tag with ``tag_id``, or ``None`` if absent."""
        return self.tags.get(tag_id)

    def remove_tag(self, tag_id: str) -> None:
        """Drop the tag with ``tag_id`` if present."""
        self.tags.pop(tag_id, None)

    # -------------------------
    # 2) SERIALIZATION: to_dict / from_dict
    # -------------------------
    def to_dict(self) -> dict:
        """Convert all stored objects into a JSON-serializable dictionary."""
        return {
            "lines": [self._line_to_dict(line) for line in self.lines.values()],
            "points": [self._point_to_dict(pt) for pt in self.points.values()],
            "symbols": [self._symbol_to_dict(sym) for sym in self.symbols.values()],
            "junctions": [self._junction_to_dict(jn) for jn in self.junctions.values()],
            "tags": [self._tag_to_dict(tg) for tg in self.tags.values()],
        }

    @classmethod
    def from_dict(cls, data: dict) -> "DetectionContext":
        """
        Create a new DetectionContext from a dictionary structure (e.g. loaded from JSON).

        Missing top-level keys are treated as empty lists.
        """
        context = cls()

        # Points
        for pt_dict in data.get("points", []):
            context.add_point(cls._point_from_dict(pt_dict))

        # Lines
        for ln_dict in data.get("lines", []):
            context.add_line(cls._line_from_dict(ln_dict))

        # Symbols
        for sym_dict in data.get("symbols", []):
            context.add_symbol(cls._symbol_from_dict(sym_dict))

        # Junctions
        for jn_dict in data.get("junctions", []):
            context.add_junction(cls._junction_from_dict(jn_dict))

        # Tags
        for tg_dict in data.get("tags", []):
            context.add_tag(cls._tag_from_dict(tg_dict))

        return context

    # -------------------------
    # 3) HELPER METHODS FOR (DE)SERIALIZATION
    # -------------------------
    # These take no self/cls and are invoked as DetectionContext._x(obj),
    # cls._x(obj) or self._x(obj), so they must be static methods.

    @staticmethod
    def _bbox_to_dict(bbox: BBox) -> dict:
        """Serialize a BBox to its four corner values."""
        return {
            "xmin": bbox.xmin,
            "ymin": bbox.ymin,
            "xmax": bbox.xmax,
            "ymax": bbox.ymax,
        }

    @staticmethod
    def _bbox_from_dict(d: dict) -> BBox:
        """Build a BBox; all four keys are required (KeyError otherwise)."""
        return BBox(
            xmin=d["xmin"],
            ymin=d["ymin"],
            xmax=d["xmax"],
            ymax=d["ymax"],
        )

    @staticmethod
    def _coords_to_dict(coords: Coordinates) -> dict:
        """Serialize Coordinates to ``{"x": ..., "y": ...}``."""
        return {
            "x": coords.x,
            "y": coords.y,
        }

    @staticmethod
    def _coords_from_dict(d: dict) -> Coordinates:
        """Build Coordinates from the required ``x`` and ``y`` keys."""
        return Coordinates(x=d["x"], y=d["y"])

    @staticmethod
    def _line_style_to_dict(style: LineStyle) -> dict:
        """Serialize a LineStyle, storing the connection type by enum value."""
        return {
            "connection_type": style.connection_type.value,
            "stroke_width": style.stroke_width,
            "color": style.color,
        }

    @staticmethod
    def _line_style_from_dict(d: dict) -> LineStyle:
        """Build a LineStyle; width and color fall back to the class defaults."""
        return LineStyle(
            connection_type=ConnectionType(d["connection_type"]),
            stroke_width=d.get("stroke_width", 2),
            color=d.get("color", "#000000"),
        )

    @staticmethod
    def _point_to_dict(pt: Point) -> dict:
        """Serialize a Point, storing its junction type by enum value."""
        return {
            "id": pt.id,
            "coords": DetectionContext._coords_to_dict(pt.coords),
            "bbox": DetectionContext._bbox_to_dict(pt.bbox),
            "type": pt.type.value,
            "confidence": pt.confidence,
        }

    @staticmethod
    def _point_from_dict(d: dict) -> Point:
        """Build a Point; ``confidence`` defaults to 1.0 when absent."""
        return Point(
            id=d["id"],
            coords=DetectionContext._coords_from_dict(d["coords"]),
            bbox=DetectionContext._bbox_from_dict(d["bbox"]),
            type=JunctionType(d["type"]),
            confidence=d.get("confidence", 1.0),
        )

    @staticmethod
    def _line_to_dict(ln: Line) -> dict:
        """Serialize a Line with its endpoints embedded as full point dicts."""
        return {
            "id": ln.id,
            "start": DetectionContext._point_to_dict(ln.start),
            "end": DetectionContext._point_to_dict(ln.end),
            "bbox": DetectionContext._bbox_to_dict(ln.bbox),
            "style": DetectionContext._line_style_to_dict(ln.style),
            "confidence": ln.confidence,
            "topological_links": ln.topological_links,
        }

    @staticmethod
    def _line_from_dict(d: dict) -> Line:
        """Build a Line; ``confidence`` and ``topological_links`` are optional."""
        return Line(
            id=d["id"],
            start=DetectionContext._point_from_dict(d["start"]),
            end=DetectionContext._point_from_dict(d["end"]),
            bbox=DetectionContext._bbox_from_dict(d["bbox"]),
            style=DetectionContext._line_style_from_dict(d["style"]),
            confidence=d.get("confidence", 0.90),
            topological_links=d.get("topological_links", []),
        )

    @staticmethod
    def _symbol_to_dict(sym: Symbol) -> dict:
        """Serialize a Symbol, including its label/classification fields.

        ``bbox`` is emitted as ``None`` when the symbol has no box, because
        ``Symbol.bbox`` defaults to ``None``.
        """
        bbox = sym.bbox
        return {
            "id": sym.id,
            "symbol_type": sym.symbol_type.value,
            "bbox": DetectionContext._bbox_to_dict(bbox) if bbox is not None else None,
            "center": DetectionContext._coords_to_dict(sym.center),
            "connections": [DetectionContext._point_to_dict(p) for p in sym.connections],
            "subtype": sym.subtype.value if sym.subtype else None,
            "confidence": sym.confidence,
            "model_metadata": sym.model_metadata,
            # Fields below were previously lost on a to_dict/from_dict round-trip.
            "class_id": sym.class_id,
            "original_label": sym.original_label,
            "category": sym.category,
            "type": sym.type,
            "label": sym.label,
            "model_source": sym.model_source,
        }

    @staticmethod
    def _symbol_from_dict(d: dict) -> Symbol:
        """Build a Symbol; absent optional keys fall back to Symbol's defaults.

        ``id``, ``symbol_type`` and ``center`` are required. A missing or
        ``None`` ``bbox`` yields ``bbox=None``.
        """
        bbox_dict = d.get("bbox")
        return Symbol(
            id=d["id"],
            symbol_type=SymbolType(d["symbol_type"]),
            bbox=DetectionContext._bbox_from_dict(bbox_dict) if bbox_dict is not None else None,
            center=DetectionContext._coords_from_dict(d["center"]),
            connections=[DetectionContext._point_from_dict(p) for p in d.get("connections", [])],
            subtype=ValveSubtype(d["subtype"]) if d.get("subtype") else None,
            confidence=d.get("confidence", 0.95),
            model_metadata=d.get("model_metadata", {}),
            class_id=d.get("class_id", -1),
            original_label=d.get("original_label", ""),
            category=d.get("category", ""),
            type=d.get("type", ""),
            label=d.get("label", ""),
            model_source=d.get("model_source", ""),
        )

    @staticmethod
    def _junction_props_to_dict(props: JunctionProperties) -> dict:
        """Serialize JunctionProperties to a two-key dict."""
        return {
            "flow_direction": props.flow_direction,
            "pressure": props.pressure,
        }

    @staticmethod
    def _junction_props_from_dict(d: dict) -> JunctionProperties:
        """Build JunctionProperties; missing keys become ``None``."""
        return JunctionProperties(
            flow_direction=d.get("flow_direction"),
            pressure=d.get("pressure"),
        )

    @staticmethod
    def _junction_to_dict(jn: Junction) -> dict:
        """Serialize a Junction, storing its type by enum value."""
        return {
            "id": jn.id,
            "center": DetectionContext._coords_to_dict(jn.center),
            "junction_type": jn.junction_type.value,
            "properties": DetectionContext._junction_props_to_dict(jn.properties),
            "connected_lines": jn.connected_lines,
        }

    @staticmethod
    def _junction_from_dict(d: dict) -> Junction:
        """Build a Junction; ``connected_lines`` defaults to an empty list."""
        return Junction(
            id=d["id"],
            center=DetectionContext._coords_from_dict(d["center"]),
            junction_type=JunctionType(d["junction_type"]),
            properties=DetectionContext._junction_props_from_dict(d["properties"]),
            connected_lines=d.get("connected_lines", []),
        )

    @staticmethod
    def _tag_to_dict(tg: Tag) -> dict:
        """Serialize a Tag, including confidence, source and text type."""
        return {
            "id": tg.id,
            "text": tg.text,
            "bbox": DetectionContext._bbox_to_dict(tg.bbox),
            "associated_element": tg.associated_element,
            "font_size": tg.font_size,
            "rotation": tg.rotation,
            # Fields below were previously lost on a to_dict/from_dict round-trip.
            "confidence": tg.confidence,
            "source": tg.source,
            "text_type": tg.text_type,
        }

    @staticmethod
    def _tag_from_dict(d: dict) -> Tag:
        """Build a Tag; only ``id``, ``text`` and ``bbox`` are required."""
        return Tag(
            id=d["id"],
            text=d["text"],
            bbox=DetectionContext._bbox_from_dict(d["bbox"]),
            # The field is Optional, so a missing key maps to None.
            associated_element=d.get("associated_element"),
            font_size=d.get("font_size", 12),
            rotation=d.get("rotation", 0.0),
            confidence=d.get("confidence", 1.0),
            source=d.get("source", ""),
            text_type=d.get("text_type", "Unknown"),
        )

    # -------------------------
    # 4) OPTIONAL UTILS
    # -------------------------
    def to_json(self, indent: int = 2) -> str:
        """Convert context to JSON, ensuring dataclasses and numpy types are handled correctly."""
        return json.dumps(self.to_dict(), default=self._json_serializer, indent=indent)

    @staticmethod
    def _json_serializer(obj):
        """Handles numpy types and unknown objects for JSON serialization.

        Used as the ``default=`` hook of ``json.dumps``; raises ``TypeError``
        for objects it cannot convert.
        """
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()  # Convert arrays to lists
        if isinstance(obj, Enum):
            return obj.value  # Convert Enums to their values
        if hasattr(obj, "__dict__"):
            return obj.__dict__  # Any object with a __dict__, e.g. dataclass instances
        raise TypeError(f"Object of type {type(obj)} is not JSON serializable")

    @classmethod
    def from_json(cls, json_str: str) -> "DetectionContext":
        """Load DetectionContext from a JSON string."""
        data = json.loads(json_str)
        return cls.from_dict(data)