# Standard library imports first
import os
from pathlib import Path
from typing import List, Dict, Optional, Any, Tuple

# Get logger before any other imports
from logger import get_logger
logger = get_logger(__name__)

# Third-party imports
import cv2
import numpy as np
import torch

# Local imports
from base import BaseDetectionPipeline
from config import (
    LineConfig, ImageConfig, PointConfig,
    JunctionConfig, SymbolConfig, TagConfig
)
from utils import DebugHandler
from storage import StorageInterface
# TODO(review): DetectionContext, DetectionResult, JunctionType and
# StorageFactory are used in this file but are not imported anywhere in it.
# Import them from the project modules that define them; until then run(),
# detect_lines() and main() fail with NameError when called. Annotations that
# mention these names are quoted so the module itself can still be imported.

# Import detectors after logging is configured
from detectors import (
    LineDetector,
    PointDetector,
    JunctionDetector,
    SymbolDetector,
    TagDetector
)


class DiagramDetectionPipeline(BaseDetectionPipeline):
    """Main pipeline for processing P&ID diagrams.

    Every detector is optional: a detector left as ``None`` makes ``run``
    skip the corresponding stage.
    """

    def __init__(
        self,
        storage: StorageInterface,
        debug_handler: Optional[DebugHandler] = None,
        *,
        tag_detector: Optional[TagDetector] = None,
        symbol_detector: Optional[SymbolDetector] = None,
        line_detector: Optional[LineDetector] = None,
        point_detector: Optional[PointDetector] = None,
        junction_detector: Optional[JunctionDetector] = None,
    ):
        """Create the pipeline.

        Args:
            storage: Backend used to read the input image and write outputs.
            debug_handler: Forwarded to ``BaseDetectionPipeline``; ``run``
                uses ``self.debug_handler`` for performance tracking.
            tag_detector: Optional tag detector (keyword-only).
            symbol_detector: Optional symbol detector (keyword-only).
            line_detector: Optional line detector (keyword-only).
            point_detector: Optional point detector (keyword-only).
            junction_detector: Optional junction detector (keyword-only).
        """
        super().__init__(storage, debug_handler)
        self._tag_detector = tag_detector
        self._symbol_detector = symbol_detector
        self._line_detector = line_detector
        self._point_detector = point_detector
        self._junction_detector = junction_detector

    # ------------------------------------------------
    # DETECTOR ACCESSORS (read by run())
    # ------------------------------------------------
    @property
    def tag_detector(self) -> Optional[TagDetector]:
        """Tag detector, or None to skip tag detection."""
        return self._tag_detector

    @property
    def symbol_detector(self) -> Optional[SymbolDetector]:
        """Symbol detector, or None to skip symbol detection."""
        return self._symbol_detector

    @property
    def line_detector(self) -> Optional[LineDetector]:
        """Line detector, or None to skip line detection."""
        return self._line_detector

    @property
    def point_detector(self) -> Optional[PointDetector]:
        """Point detector, or None to skip point detection."""
        return self._point_detector

    @property
    def junction_detector(self) -> Optional[JunctionDetector]:
        """Junction detector, or None to skip junction detection."""
        return self._junction_detector

    def _load_image(self, image_path: str) -> np.ndarray:
        """Load an image from storage and decode it as a BGR colour image.

        Raises:
            ValueError: If the file is empty or cannot be decoded.
        """
        image_data = self.storage.load_file(image_path)
        if not image_data:
            # cv2.imdecode asserts on an empty buffer; give a clearer error.
            raise ValueError(f"Failed to load image from {image_path}")
        image = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
        if image is None:
            raise ValueError(f"Failed to load image from {image_path}")
        return image

    def _crop_to_roi(
        self, image: np.ndarray, roi: Optional[list]
    ) -> Tuple[np.ndarray, Tuple[int, int]]:
        """Crop to ROI if provided, else return the full image.

        Args:
            image: Image array indexed as ``image[y, x]``.
            roi: ``(x_min, y_min, x_max, y_max)``; ignored unless it has
                exactly four entries.

        Returns:
            ``(cropped_view, (x_offset, y_offset))``. The offset is the
            top-left corner of the crop in full-image pixels.

        Raises:
            ValueError: If the ROI does not overlap the image.
        """
        if roi is None or len(roi) != 4:
            return image, (0, 0)

        height, width = image.shape[:2]
        x_min, y_min, x_max, y_max = (int(v) for v in roi)
        # Clamp to the image: numpy would otherwise treat a negative start as
        # "count from the end" and silently return the wrong region.
        x_min, x_max = max(0, x_min), min(width, x_max)
        y_min, y_max = max(0, y_min), min(height, y_max)
        if x_min >= x_max or y_min >= y_max:
            raise ValueError(
                f"ROI {roi} does not overlap image of size {width}x{height}"
            )
        return image[y_min:y_max, x_min:x_max], (x_min, y_min)

    def _remove_symbol_tag_bboxes(
        self, image: np.ndarray, context: "DetectionContext"
    ) -> np.ndarray:
        """Fill symbol & tag bounding boxes with white on a copy of ``image``.

        This keeps line detection from picking up strokes inside symbols and
        tag text. The input image is not modified.
        """
        masked = image.copy()
        white = (255, 255, 255)
        # NOTE(review): symbol/tag detectors receive roi_offset in run(). If they
        # store bboxes in full-image coordinates, these rectangles are shifted
        # relative to the cropped image — confirm the coordinate space.
        for detections in (context.symbols, context.tags):
            for obj in detections.values():
                cv2.rectangle(masked,
                              (obj.bbox.xmin, obj.bbox.ymin),
                              (obj.bbox.xmax, obj.bbox.ymax),
                              white,
                              thickness=-1)  # -1 => filled rectangle
        return masked

    def run(
        self,
        image_path: str,
        output_dir: str,
        config
    ) -> "DetectionResult":
        """Run the full detection pipeline on one image.

        Steps (local = ROI-cropped coordinates):
            1) Load + crop image (reads ``config.roi``)
            2) Detect symbols & tags on the unmasked crop
            3) White out symbol/tag bounding boxes
            4) Detect lines, points, junctions on the masked crop
            5) Save final JSON and annotated image
            6) Save debug images with various combinations

        Detectors that are ``None`` are skipped. Any exception is logged and
        reported as ``DetectionResult(success=False, error=...)``.
        """
        try:
            with self.debug_handler.track_performance("total_processing"):
                # 1) Load & crop
                image = self._load_image(image_path)
                cropped_image, roi_offset = self._crop_to_roi(image, config.roi)

                # Fresh context for this run
                context = DetectionContext()

                # 2) Symbols and tags
                if self.symbol_detector is not None:
                    with self.debug_handler.track_performance("symbol_detection"):
                        self.symbol_detector.detect(
                            cropped_image,
                            context=context,
                            roi_offset=roi_offset
                        )
                if self.tag_detector is not None:
                    with self.debug_handler.track_performance("tag_detection"):
                        self.tag_detector.detect(
                            cropped_image,
                            context=context,
                            roi_offset=roi_offset
                        )

                # 3) Masking works on a copy, so cropped_image stays clean for
                #    the debug renderings below.
                masked_image = self._remove_symbol_tag_bboxes(cropped_image, context)

                # 4) Lines, points, junctions on the masked image
                if self.line_detector is not None:
                    with self.debug_handler.track_performance("line_detection"):
                        self.line_detector.detect(masked_image, context=context)
                if self.point_detector is not None:
                    with self.debug_handler.track_performance("point_detection"):
                        self.point_detector.detect(masked_image, context=context)
                if self.junction_detector is not None:
                    with self.debug_handler.track_performance("junction_detection"):
                        self.junction_detector.detect(masked_image, context=context)

                # 5) JSON + annotated full image (reuses the already-decoded image)
                output_paths = self._persist_results(
                    output_dir, image_path, context, base_image=image
                )

                # 6) Debug images in local coords
                self._save_all_combinations(cropped_image, context, output_dir, image_path)
        except Exception as e:
            logger.exception(f"Processing failed: {e}")
            return DetectionResult(
                success=False,
                error=str(e)
            )

        # The timing is read only after the "total_processing" context has
        # exited, so it is available even if the handler records it on exit.
        return DetectionResult(
            success=True,
            processing_time=self.debug_handler.metrics.get('total_processing', 0),
            json_path=output_paths.get('json_path'),
            image_path=output_paths.get('image_path')  # annotated image path
        )

    # ------------------------------------------------
    # HELPER FUNCTIONS
    # ------------------------------------------------
    def _save_jpeg(self, image: np.ndarray, path: Path) -> None:
        """Encode ``image`` as JPEG and write it through the storage backend.

        Raises:
            ValueError: If OpenCV reports that encoding failed.
        """
        ok, encoded = cv2.imencode('.jpg', image)
        if not ok:
            raise ValueError(f"Failed to encode image for {path}")
        self.storage.save_file(str(path), encoded.tobytes())

    def _persist_results(
        self,
        output_dir: str,
        image_path: str,
        context: "DetectionContext",
        base_image: Optional[np.ndarray] = None,
    ) -> dict:
        """Save the detection JSON and a fully annotated image.

        Args:
            output_dir: Destination directory (created through storage).
            image_path: Input image path; its stem names the output files.
            context: Populated detection context.
            base_image: Image to annotate. When None, the image is re-loaded
                from ``image_path``.

        Returns:
            Dict with keys ``"json_path"`` and ``"image_path"``.
        """
        self.storage.create_directory(output_dir)
        base_name = Path(image_path).stem

        # JSON
        json_path = Path(output_dir) / f"{base_name}_detected_lines.json"
        context_json_str = context.to_json(indent=2)
        self.storage.save_file(str(json_path), context_json_str.encode('utf-8'))

        # Annotated image for pipeline display
        if base_image is None:
            base_image = self._load_image(image_path)
        # NOTE(review): this draws on the full (uncropped) image, while
        # _save_all_combinations draws on the ROI crop. Line/point/junction
        # detectors are not given the ROI offset, so with a non-trivial ROI one
        # of the two renderings is likely misaligned — confirm the intended
        # coordinate space of the context.
        annotated_image = self._draw_objects(
            base_image,
            context,
            draw_lines=True,
            draw_points=True,
            draw_symbols=True,
            draw_junctions=True,
            draw_tags=True
        )
        annotated_path = Path(output_dir) / f"{base_name}_annotated.jpg"
        self._save_jpeg(annotated_image, annotated_path)

        return {
            "json_path": str(json_path),
            "image_path": str(annotated_path)
        }

    def _save_all_combinations(self, local_image: np.ndarray,
                               context: "DetectionContext",
                               output_dir: str,
                               image_path: str) -> None:
        """Save one debug image per detection layer (symbols, tags, lines).

        Files are named ``<prefix>_<combo>.jpg``. ``<prefix>`` is the image stem
        up to the first ``"_"``; _persist_results uses the full stem instead.
        """
        base_name = Path(image_path).stem.split("_")[0]

        no_layers = dict(draw_symbols=False, draw_tags=False, draw_lines=False,
                         draw_points=False, draw_junctions=False)
        combos = [
            ("text_detected_symbols", {**no_layers, "draw_symbols": True}),
            ("text_detected_texts", {**no_layers, "draw_tags": True}),
            ("text_detected_lines", {**no_layers, "draw_lines": True}),
        ]

        self.storage.create_directory(output_dir)
        for combo_name, flags in combos:
            annotated = self._draw_objects(local_image, context, **flags)
            save_path = Path(output_dir) / f"{base_name}_{combo_name}.jpg"
            self._save_jpeg(annotated, save_path)
            logger.info(f"Saved debug image: {save_path}")

    def _draw_objects(self,
                      base_image: np.ndarray,
                      context: "DetectionContext",
                      draw_lines: bool = True,
                      draw_points: bool = True,
                      draw_symbols: bool = True,
                      draw_junctions: bool = True,
                      draw_tags: bool = True) -> np.ndarray:
        """Draw the selected detection layers on a copy of ``base_image``.

        Colours are BGR tuples, as OpenCV expects. ``base_image`` is not
        modified.
        """
        annotated = base_image.copy()

        # Lines
        if draw_lines:
            for ln in context.lines.values():
                cv2.line(annotated,
                         (ln.start.coords.x, ln.start.coords.y),
                         (ln.end.coords.x, ln.end.coords.y),
                         (0, 255, 0),  # green
                         2)

        # Points
        if draw_points:
            for pt in context.points.values():
                cv2.circle(annotated,
                           (pt.coords.x, pt.coords.y),
                           3,
                           (0, 0, 255),  # red
                           -1)

        # Symbols: bbox outline plus a filled centre marker
        if draw_symbols:
            for sym in context.symbols.values():
                cv2.rectangle(annotated,
                              (sym.bbox.xmin, sym.bbox.ymin),
                              (sym.bbox.xmax, sym.bbox.ymax),
                              (255, 255, 0),  # cyan
                              2)
                cv2.circle(annotated,
                           (sym.center.x, sym.center.y),
                           4,
                           (255, 0, 255),  # magenta
                           -1)

        # Junctions, coloured by type
        if draw_junctions:
            for jn in context.junctions.values():
                if jn.junction_type == JunctionType.T:
                    color = (0, 165, 255)  # orange
                elif jn.junction_type == JunctionType.L:
                    color = (255, 0, 255)  # magenta
                else:  # any other type (e.g. END)
                    color = (0, 0, 255)  # red
                cv2.circle(annotated,
                           (jn.center.x, jn.center.y),
                           5,
                           color,
                           -1)

        # Tags: bbox outline plus the tag text just above it
        if draw_tags:
            for tg in context.tags.values():
                cv2.rectangle(annotated,
                              (tg.bbox.xmin, tg.bbox.ymin),
                              (tg.bbox.xmax, tg.bbox.ymax),
                              (128, 0, 128),  # purple
                              2)
                cv2.putText(annotated,
                            tg.text,
                            (tg.bbox.xmin, tg.bbox.ymin - 5),
                            cv2.FONT_HERSHEY_SIMPLEX,
                            0.5,
                            (128, 0, 128),
                            1)

        return annotated

    def detect_lines(self, image_path: str, output_dir: str,
                     config: Optional[Dict] = None) -> "DetectionResult":
        """Legacy interface: run a line-only pipeline on a single image.

        Builds its own storage backend, debug handler and pipeline, and does
        not read ``self``. ``config`` is accepted for signature compatibility
        but is not used.
        """
        storage = StorageFactory.get_storage()
        debug_handler = DebugHandler(enabled=True, storage=storage)

        line_detector = LineDetector(
            config=LineConfig(),
            model_path="models/deeplsd_md.tar",
            device=torch.device("cpu"),
            debug_handler=debug_handler
        )

        pipeline = DiagramDetectionPipeline(
            tag_detector=None,
            symbol_detector=None,
            line_detector=line_detector,
            point_detector=None,
            junction_detector=None,
            storage=storage,
            debug_handler=debug_handler
        )

        return pipeline.run(image_path, output_dir, ImageConfig())

    def _validate_and_normalize_coordinates(self, points):
        """Keep points inside the image and cast their coordinates to int.

        Each point is a mapping with ``'x'`` and ``'y'``; ``'type'`` and
        ``'confidence'`` default to ``'unknown'`` and ``1.0``. Bounds are
        inclusive on both ends.

        NOTE(review): reads ``self.image_width`` / ``self.image_height``, which
        nothing in this file sets — confirm they are provided elsewhere.
        """
        width, height = self.image_width, self.image_height
        return [
            {
                'x': int(point['x']),
                'y': int(point['y']),
                'type': point.get('type', 'unknown'),
                'confidence': point.get('confidence', 1.0),
            }
            for point in points
            if 0 <= point['x'] <= width and 0 <= point['y'] <= height
        ]


def main() -> None:
    """Run the full detection pipeline on the bundled sample image."""
    # 1) Initialize components
    storage = StorageFactory.get_storage()
    debug_handler = DebugHandler(enabled=True, storage=storage)

    # 2) Model options for the line detector
    conf = {
        "detect_lines": True,
        "line_detection_params": {
            "merge": True,
            "filtering": True,
            "grad_thresh": 3,
            "grad_nfa": True
        }
    }

    # 3) Configure
    line_config = LineConfig()
    point_config = PointConfig()
    junction_config = JunctionConfig()
    symbol_config = SymbolConfig()
    tag_config = TagConfig()

    # ========================== Detectors ========================== #
    symbol_detector = SymbolDetector(
        config=symbol_config,
        debug_handler=debug_handler
    )

    tag_detector = TagDetector(
        config=tag_config,
        debug_handler=debug_handler
    )

    line_detector = LineDetector(
        config=line_config,
        model_path="models/deeplsd_md.tar",
        model_config=conf,
        device=torch.device("cpu"),  # or "cuda" if available
        debug_handler=debug_handler
    )

    point_detector = PointDetector(
        config=point_config,
        debug_handler=debug_handler
    )

    junction_detector = JunctionDetector(
        config=junction_config,
        debug_handler=debug_handler
    )

    # 4) Create pipeline
    pipeline = DiagramDetectionPipeline(
        tag_detector=tag_detector,
        symbol_detector=symbol_detector,
        line_detector=line_detector,
        point_detector=point_detector,
        junction_detector=junction_detector,
        storage=storage,
        debug_handler=debug_handler
    )

    # 5) Run pipeline
    result = pipeline.run(
        image_path="samples/images/0.jpg",
        output_dir="results/",
        config=ImageConfig()
    )

    if result.success:
        logger.info(f"Pipeline succeeded! See JSON at {result.json_path}")
    else:
        logger.error(f"Pipeline failed: {result.error}")


if __name__ == "__main__":
    main()