# Standard library imports first
import os
import math
import json
import uuid
from abc import ABC, abstractmethod
from dataclasses import replace
from math import sqrt
from pathlib import Path
from typing import List, Optional, Tuple, Dict, Any

# Get logger before any other imports
from logger import get_logger

logger = get_logger(__name__)

# Third-party imports
import cv2
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from skimage.morphology import skeletonize
from skimage.measure import label
from ultralytics import YOLO

# Local imports
from storage import StorageInterface
from base import BaseDetector
from config import SymbolConfig, TagConfig, LineConfig, PointConfig, JunctionConfig
from line_detectors import OpenCVLineDetector, DEEPLSD_AVAILABLE

# Try to import DeepLSD, but don't fail if not available
try:
    from line_detectors import DeepLSDDetector
    logger.info("Successfully imported DeepLSD")
except ImportError as e:
    logger.warning(f"DeepLSD import failed: {str(e)}. Will use OpenCV fallback.")

# Detection schema imports
from detection_schema import (
    BBox, Coordinates, Point, Line, Symbol, Tag,
    SymbolType, LineStyle, ConnectionType, JunctionType, Junction
)

# Rest of the classes...


class Detector(ABC):
    """Base class for all detectors.

    Stores the detector configuration and an optional debug handler, and
    requires subclasses to implement :meth:`detect`.
    """

    def __init__(self, config: Any, debug_handler=None):
        """Store the configuration and the optional debug handler.

        Args:
            config: Detector-specific configuration object.
            debug_handler: Optional object exposing
                ``save_image(image, filename)``; used by
                :meth:`save_debug_image`.
        """
        self.config = config
        self.debug_handler = debug_handler

    @abstractmethod
    def detect(self, image: np.ndarray) -> Dict:
        """Perform detection on the image."""

    def save_debug_image(self, image: np.ndarray, filename: str):
        """Save a debug visualization if a debug handler is available.

        Does nothing when no (truthy) debug handler was supplied.
        """
        if self.debug_handler:
            self.debug_handler.save_image(image, filename)


class SymbolDetector(Detector):
    """Detector for symbols in P&ID diagrams."""

    def __init__(self, config, debug_handler=None):
        """Load one YOLO model per entry of ``config.model_paths``.

        Args:
            config: Object exposing ``model_paths`` (mapping of model name to
                weights path) and ``confidence_threshold`` (read in
                :meth:`detect`).
            debug_handler: Optional debug handler, forwarded to the base class.
        """
        super().__init__(config, debug_handler)
        self.models = {}
        for name, path in config.model_paths.items():
            if os.path.exists(path):
                self.models[name] = YOLO(path)
            else:
                # A missing weights file is not fatal: the model is skipped
                # and detection continues with whatever models did load.
                logger.warning(f"Model not found at {path}")

    def detect(self, image: np.ndarray) -> Dict:
        """Detect symbols using every loaded YOLO model.

        Args:
            image: Image passed unchanged to each model.

        Returns:
            ``{'detections': [...]}`` where each entry is a dict with keys
            ``'bbox'`` (``[x1, y1, x2, y2]`` floats), ``'confidence'``
            (float), ``'class'`` (name looked up in the result's ``names``)
            and ``'model'`` (the key of the model in ``self.models``).
            The list is empty when no model is loaded or nothing is found.
        """
        results = []

        for model_name, model in self.models.items():
            # Only the first result is used (one image in, one result out).
            model_results = model(image, conf=self.config.confidence_threshold)[0]
            boxes = model_results.boxes
            if boxes is None:
                # Some results carry no box container at all; treat that as
                # "no detections" instead of failing on iteration.
                continue

            class_names = model_results.names
            # Convert each field once per model instead of once per box:
            # this avoids N small device-to-host transfers and the per-item
            # indexing of the boxes container. ``tolist()`` yields plain
            # Python floats, matching the previous ``float(...)`` casts.
            box_rows = zip(
                boxes.xyxy.tolist(),
                boxes.conf.tolist(),
                boxes.cls.tolist(),
            )
            for (x1, y1, x2, y2), conf, cls in box_rows:
                results.append({
                    'bbox': [x1, y1, x2, y2],
                    'confidence': conf,
                    'class': class_names[int(cls)],
                    'model': model_name,
                })

        return {'detections': results}


class TagDetector(Detector):
    """Detector for text tags in P&ID diagrams."""

    def __init__(self, config, debug_handler=None):
        """Store configuration; the OCR engine is not initialized yet."""
        super().__init__(config, debug_handler)
        self.ocr = None  # Initialize OCR engine here

    def detect(self, image: np.ndarray) -> Dict:
        """Detect and recognize text tags.

        Placeholder: currently always returns an empty detection list.
        """
        # Implement text detection logic
        return {'detections': []}


class LineDetector(Detector):
    """Detector for lines in P&ID diagrams."""

    def __init__(self, config, model_path=None, model_config=None,
                 device='cpu', debug_handler=None):
        """Select the line-detection backend.

        Args:
            config: Detector configuration, stored by the base class.
            model_path: Path handed to ``DeepLSDDetector``; when falsy the
                OpenCV backend is used.
            model_config: Accepted for interface compatibility; not used here.
            device: Accepted for interface compatibility; not used here.
            debug_handler: Optional debug handler, forwarded to the base class.
        """
        super().__init__(config, debug_handler)

        # Try to use DeepLSD if available, otherwise fall back to OpenCV
        if DEEPLSD_AVAILABLE and model_path:
            self.detector = DeepLSDDetector(model_path)
            logger.info("Using DeepLSD for line detection")
        else:
            self.detector = OpenCVLineDetector()
            logger.info("Using OpenCV for line detection")

    def detect(self, image: np.ndarray) -> Dict:
        """Delegate line detection to the selected backend."""
        return self.detector.detect(image)


class PointDetector(Detector):
    """Detector for connection points in P&ID diagrams."""

    def detect(self, image: np.ndarray) -> Dict:
        """Detect connection points.

        Placeholder: currently always returns an empty detection list.
        """
        # Implement point detection logic
        return {'detections': []}


class JunctionDetector(Detector):
    """Detector for line junctions in P&ID diagrams."""

    def detect(self, image: np.ndarray) -> Dict:
        """Detect line junctions.

        Placeholder: currently always returns an empty detection list.
        """
        # Implement junction detection logic
        return {'detections': []}