File size: 4,817 Bytes
3f1e63e
9847531
 
 
 
3f1e63e
 
 
 
 
9847531
d3ea93c
 
 
9847531
3f1e63e
 
 
 
 
 
 
 
 
 
 
 
 
 
589b914
c2516d1
589b914
391393a
589b914
 
391393a
589b914
9847531
589b914
c2516d1
 
 
 
66da89f
589b914
b76ce91
7f22c74
 
 
 
 
9847531
7f22c74
 
 
 
9847531
7f22c74
 
 
 
 
9847531
 
7f22c74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b76ce91
 
 
 
391393a
b76ce91
 
391393a
b76ce91
 
7f22c74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# Standard library imports first
import os
import math
import json
import uuid
from abc import ABC, abstractmethod
from dataclasses import replace
from math import sqrt
from pathlib import Path
from typing import List, Optional, Tuple, Dict, Any

# Get logger before any other imports
from logger import get_logger
logger = get_logger(__name__)

# Third-party imports
import cv2
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from skimage.morphology import skeletonize
from skimage.measure import label
from ultralytics import YOLO

# Local imports
from storage import StorageInterface
from base import BaseDetector
from config import SymbolConfig, TagConfig, LineConfig, PointConfig, JunctionConfig
from line_detectors import OpenCVLineDetector, DEEPLSD_AVAILABLE

# Try to import DeepLSD, but don't fail if not available
try:
    from line_detectors import DeepLSDDetector
    logger.info("Successfully imported DeepLSD")
except ImportError as e:
    logger.warning(f"DeepLSD import failed: {str(e)}. Will use OpenCV fallback.")

# Detection schema imports
from detection_schema import (
    BBox, Coordinates, Point, Line, Symbol, Tag,
    SymbolType, LineStyle, ConnectionType, JunctionType, Junction
)

# Detector classes

class Detector(ABC):
    """Abstract base shared by every P&ID detector.

    Holds the detector's configuration object and an optional debug handler.
    Concrete subclasses must implement :meth:`detect`.
    """

    def __init__(self, config: Any, debug_handler=None):
        # Both collaborators are stored as plain attributes for subclasses.
        self.debug_handler = debug_handler
        self.config = config

    @abstractmethod
    def detect(self, image: np.ndarray) -> Dict:
        """Run detection on *image* and return a result dictionary."""
        ...

    def save_debug_image(self, image: np.ndarray, filename: str):
        """Hand *image* to the debug handler under *filename*.

        Does nothing when no (or a falsy) debug handler was supplied.
        """
        handler = self.debug_handler
        if not handler:
            return
        handler.save_image(image, filename)


class SymbolDetector(Detector):
    """Symbol detector that pools predictions from several YOLO models.

    One model is loaded per entry in ``config.model_paths``. Entries whose
    weight file does not exist are skipped with a warning.
    """

    def __init__(self, config, debug_handler=None):
        super().__init__(config, debug_handler)
        self.models = {}
        for model_name, weights_path in config.model_paths.items():
            if not os.path.exists(weights_path):
                logger.warning(f"Model not found at {weights_path}")
                continue
            self.models[model_name] = YOLO(weights_path)

    @staticmethod
    def _box_to_dict(box, class_names, model_name):
        """Convert one predicted box into a plain-Python detection record."""
        x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
        score = box.conf[0].cpu().numpy()
        class_index = box.cls[0].cpu().numpy()
        label = class_names[int(class_index)]
        return {
            'bbox': [float(coord) for coord in (x1, y1, x2, y2)],
            'confidence': float(score),
            'class': label,
            'model': model_name,
        }

    def detect(self, image: np.ndarray) -> Dict:
        """Run every loaded model on *image* and collect all boxes.

        Returns:
            ``{'detections': [...]}``. Each record has ``bbox`` (x1, y1, x2, y2
            as floats), ``confidence``, ``class`` (the model's label name) and
            ``model`` (the key the model was registered under).
        """
        detections = []
        for model_name, model in self.models.items():
            # Only the first result is used: a single image is passed in.
            prediction = model(image, conf=self.config.confidence_threshold)[0]
            detections.extend(
                self._box_to_dict(box, prediction.names, model_name)
                for box in prediction.boxes
            )
        return {'detections': detections}


class TagDetector(Detector):
    """Placeholder detector for text tags in P&ID diagrams.

    No OCR engine is wired up yet, so :meth:`detect` always reports no tags.
    """

    def __init__(self, config, debug_handler=None):
        super().__init__(config, debug_handler)
        # TODO: construct the OCR engine here; it stays unset for now.
        self.ocr = None

    def detect(self, image: np.ndarray) -> Dict:
        """Return an empty detection list (text recognition not implemented)."""
        return dict(detections=[])


class LineDetector(Detector):
    """Detector for lines in P&ID diagrams.

    Delegates to DeepLSD when the backend reports itself available and a
    ``model_path`` is given. Otherwise it uses the OpenCV-based detector.
    """

    def __init__(self, config, model_path=None, model_config=None, device='cpu', debug_handler=None):
        super().__init__(config, debug_handler)
        # NOTE(review): model_config and device are accepted but not forwarded to
        # the backend, because DeepLSDDetector's constructor signature is not
        # visible here. They are kept on the instance so they are inspectable.
        self.model_config = model_config
        self.device = device
        self.detector = self._build_backend(model_path)

    @staticmethod
    def _build_backend(model_path):
        """Return a DeepLSD detector when usable, else an OpenCV detector."""
        if DEEPLSD_AVAILABLE and model_path:
            try:
                # Resolve the class lazily. The module-level optional import can
                # fail even when DEEPLSD_AVAILABLE is true. Relying on that global
                # name would then raise NameError instead of falling back.
                from line_detectors import DeepLSDDetector
            except ImportError as err:
                logger.warning(
                    f"DeepLSD reported available but could not be imported: {err}. "
                    "Using OpenCV fallback."
                )
            else:
                backend = DeepLSDDetector(model_path)
                logger.info("Using DeepLSD for line detection")
                return backend
        backend = OpenCVLineDetector()
        logger.info("Using OpenCV for line detection")
        return backend

    def detect(self, image: np.ndarray) -> Dict:
        """Run the selected backend on *image* and return its result unchanged."""
        return self.detector.detect(image)


class PointDetector(Detector):
    """Placeholder detector for connection points in P&ID diagrams."""

    def detect(self, image: np.ndarray) -> Dict:
        """Return an empty detection list (point detection not implemented yet)."""
        # TODO: implement connection-point detection.
        return dict(detections=[])


class JunctionDetector(Detector):
    """Placeholder detector for line junctions in P&ID diagrams."""

    def detect(self, image: np.ndarray) -> Dict:
        """Return an empty detection list (junction detection not implemented yet)."""
        # TODO: implement junction detection.
        return dict(detections=[])