intelligent-pid / detectors.py
msIntui
Add centralized logging configuration
d3ea93c
# Standard library imports first
import os
import math
import json
import uuid
from abc import ABC, abstractmethod
from dataclasses import replace
from math import sqrt
from pathlib import Path
from typing import List, Optional, Tuple, Dict, Any
# Get logger before any other imports
from logger import get_logger
logger = get_logger(__name__)
# Third-party imports
import cv2
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from skimage.morphology import skeletonize
from skimage.measure import label
from ultralytics import YOLO
# Local imports
from storage import StorageInterface
from base import BaseDetector
from config import SymbolConfig, TagConfig, LineConfig, PointConfig, JunctionConfig
from line_detectors import OpenCVLineDetector, DEEPLSD_AVAILABLE
# Try to import DeepLSD, but don't fail if not available
try:
from line_detectors import DeepLSDDetector
logger.info("Successfully imported DeepLSD")
except ImportError as e:
logger.warning(f"DeepLSD import failed: {str(e)}. Will use OpenCV fallback.")
# Detection schema imports
from detection_schema import (
BBox, Coordinates, Point, Line, Symbol, Tag,
SymbolType, LineStyle, ConnectionType, JunctionType, Junction
)
# Detector class hierarchy
class Detector(ABC):
    """Abstract base shared by every P&ID detector.

    Holds the detector's configuration object and an optional debug handler
    that is used to persist visualization images.
    """

    def __init__(self, config: Any, debug_handler=None):
        # Detector-specific configuration; its type depends on the subclass.
        self.config = config
        # Optional; when set it must provide save_image(image, filename).
        self.debug_handler = debug_handler

    @abstractmethod
    def detect(self, image: np.ndarray) -> Dict:
        """Run detection on ``image`` and return the results as a dict."""

    def save_debug_image(self, image: np.ndarray, filename: str):
        """Forward ``image`` to the debug handler's ``save_image`` when a handler is set."""
        handler = self.debug_handler
        if not handler:
            return
        handler.save_image(image, filename)
class SymbolDetector(Detector):
    """Detector for symbols in P&ID diagrams, backed by one or more YOLO models."""

    def __init__(self, config, debug_handler=None):
        """Load a YOLO model for every entry in ``config.model_paths``.

        Entries whose path does not exist are skipped with a warning, so
        ``self.models`` may end up empty (``detect`` then returns no detections).
        """
        super().__init__(config, debug_handler)
        self.models = {}
        for name, path in config.model_paths.items():
            if os.path.exists(path):
                self.models[name] = YOLO(path)
            else:
                # Lazy %-formatting; emits the same text as before.
                logger.warning("Model not found at %s", path)

    def detect(self, image: np.ndarray) -> Dict:
        """Detect symbols with every loaded YOLO model.

        Args:
            image: Image passed unchanged to each model.

        Returns:
            ``{'detections': [...]}``. Each entry has these keys:
            ``'bbox'`` (``[x1, y1, x2, y2]`` as floats), ``'confidence'``,
            ``'class'`` (the model's class name) and ``'model'`` (the key
            from ``config.model_paths``). Detections from all models are
            concatenated in model order.
        """
        results = []
        for model_name, model in self.models.items():
            model_results = model(image, conf=self.config.confidence_threshold)[0]
            boxes = model_results.boxes
            if boxes is None:
                # Ultralytics leaves ``boxes`` unset for tasks without boxes
                # (e.g. classification); there is nothing to report.
                continue
            # Copy each tensor to host once per model rather than doing three
            # device->host transfers for every individual box.
            all_xyxy = boxes.xyxy.cpu().numpy()
            all_conf = boxes.conf.cpu().numpy()
            all_cls = boxes.cls.cpu().numpy()
            names = model_results.names
            for (x1, y1, x2, y2), conf, cls in zip(all_xyxy, all_conf, all_cls):
                results.append({
                    'bbox': [float(x1), float(y1), float(x2), float(y2)],
                    'confidence': float(conf),
                    'class': names[int(cls)],
                    'model': model_name,
                })
        return {'detections': results}
class TagDetector(Detector):
    """Text-tag detector for P&ID diagrams (placeholder implementation)."""

    def __init__(self, config, debug_handler=None):
        """Store the configuration and leave the OCR engine unset."""
        super().__init__(config, debug_handler)
        # Placeholder: no OCR engine is constructed yet.
        self.ocr = None

    def detect(self, image: np.ndarray) -> Dict:
        """Return detected text tags; currently always an empty list."""
        # TODO: implement text detection and recognition.
        found: List[Dict] = []
        return {'detections': found}
class LineDetector(Detector):
    """Detector for lines in P&ID diagrams.

    Delegates to a DeepLSD-based detector when DeepLSD is available and a
    model path is supplied; otherwise it falls back to the OpenCV detector.
    """

    def __init__(self, config, model_path=None, model_config=None, device='cpu', debug_handler=None):
        """Select and construct the underlying line detector.

        Args:
            config: Line-detection configuration, stored on ``self.config``.
            model_path: Path to DeepLSD weights. DeepLSD is used only when this
                is truthy and ``DEEPLSD_AVAILABLE`` is true.
            model_config: Accepted for interface compatibility; currently unused.
            device: Accepted for interface compatibility; currently unused
                (it is not forwarded to the DeepLSD detector).
            debug_handler: Optional handler used by ``save_debug_image``.
        """
        super().__init__(config, debug_handler)
        self.detector = None
        if DEEPLSD_AVAILABLE and model_path:
            # The module-level import of DeepLSDDetector sits in a try/except,
            # so the name can be unbound even when DEEPLSD_AVAILABLE is true.
            # Import locally so that case falls back to OpenCV instead of
            # raising NameError.
            try:
                from line_detectors import DeepLSDDetector
            except ImportError as exc:
                logger.warning("DeepLSD unavailable (%s); falling back to OpenCV", exc)
            else:
                self.detector = DeepLSDDetector(model_path)
                logger.info("Using DeepLSD for line detection")
        if self.detector is None:
            self.detector = OpenCVLineDetector()
            logger.info("Using OpenCV for line detection")

    def detect(self, image: np.ndarray) -> Dict:
        """Run the selected line detector on ``image`` and return its result unchanged."""
        return self.detector.detect(image)
class PointDetector(Detector):
    """Connection-point detector for P&ID diagrams (placeholder implementation)."""

    def detect(self, image: np.ndarray) -> Dict:
        """Return detected connection points; currently always an empty list."""
        # TODO: implement connection-point detection.
        return dict(detections=[])
class JunctionDetector(Detector):
    """Line-junction detector for P&ID diagrams (placeholder implementation)."""

    def detect(self, image: np.ndarray) -> Dict:
        """Return detected line junctions; currently always an empty list."""
        # TODO: implement junction detection.
        return dict(detections=[])