"""Main pipeline for processing P&ID diagrams (symbols, tags, lines, points, junctions)."""
# Standard library imports first
import os
from pathlib import Path
from typing import List, Dict, Optional, Any, Tuple

# Get logger before any other imports
from logger import get_logger

logger = get_logger(__name__)

# Third-party imports
import cv2
import numpy as np
import torch

# Local imports
from base import BaseDetectionPipeline
from config import (
    LineConfig, ImageConfig, PointConfig, JunctionConfig,
    SymbolConfig, TagConfig
)
from utils import DebugHandler
from storage import StorageInterface

# Import detectors after logging is configured
from detectors import (
    LineDetector, PointDetector, JunctionDetector,
    SymbolDetector, TagDetector
)

# NOTE(review): DetectionContext, DetectionResult, JunctionType and
# StorageFactory are referenced below but never imported — locate their
# defining modules and add the imports here.
| class DiagramDetectionPipeline(BaseDetectionPipeline): | |
| """Main pipeline for processing P&ID diagrams""" | |
| def __init__( | |
| self, | |
| storage: StorageInterface, | |
| debug_handler: Optional[DebugHandler] = None | |
| ): | |
| super().__init__(storage, debug_handler) | |
| # Initialize detectors when needed | |
| self._line_detector = None | |
| self._point_detector = None | |
| self._junction_detector = None | |
| self._symbol_detector = None | |
| self._tag_detector = None | |
| def _load_image(self, image_path: str) -> np.ndarray: | |
| """Load image with validation.""" | |
| image_data = self.storage.load_file(image_path) | |
| image = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR) | |
| if image is None: | |
| raise ValueError(f"Failed to load image from {image_path}") | |
| return image | |
| def _crop_to_roi(self, image: np.ndarray, roi: Optional[list]) -> Tuple[np.ndarray, Tuple[int, int]]: | |
| """Crop to ROI if provided, else return full image.""" | |
| if roi is not None and len(roi) == 4: | |
| x_min, y_min, x_max, y_max = roi | |
| return image[y_min:y_max, x_min:x_max], (x_min, y_min) | |
| return image, (0, 0) | |
| def _remove_symbol_tag_bboxes(self, image: np.ndarray, context: DetectionContext) -> np.ndarray: | |
| """Fill symbol & tag bounding boxes with white to avoid line detection picking them up.""" | |
| masked = image.copy() | |
| for sym in context.symbols.values(): | |
| cv2.rectangle(masked, | |
| (sym.bbox.xmin, sym.bbox.ymin), | |
| (sym.bbox.xmax, sym.bbox.ymax), | |
| (255, 255, 255), # White | |
| thickness=-1) | |
| for tg in context.tags.values(): | |
| cv2.rectangle(masked, | |
| (tg.bbox.xmin, tg.bbox.ymin), | |
| (tg.bbox.xmax, tg.bbox.ymax), | |
| (255, 255, 255), | |
| thickness=-1) | |
| return masked | |
    def run(
        self,
        image_path: str,
        output_dir: str,
        config
    ) -> DetectionResult:
        """
        Run the full detection pipeline on one diagram image.

        Main pipeline steps (in local coords):
        1) Load + crop image
        2) Detect symbols & tags
        3) Make a copy for final debug images
        4) White out symbol/tag bounding boxes
        5) Detect lines, points, junctions
        6) Save final JSON
        7) Generate debug images with various combinations

        Args:
            image_path: storage path of the source diagram image.
            output_dir: directory receiving the JSON and debug images.
            config: run configuration; only config.roi is read here
                (presumably an ImageConfig — confirm with callers).

        Returns:
            DetectionResult: success=True with output paths, or
            success=False with the error message when any stage raises.
        """
        try:
            with self.debug_handler.track_performance("total_processing"):
                # 1) Load & crop
                image = self._load_image(image_path)
                cropped_image, roi_offset = self._crop_to_roi(image, config.roi)
                # 2) Create fresh context (accumulates all detections below)
                context = DetectionContext()
                # 3) Detect symbols
                # NOTE(review): symbol/tag detectors are used unguarded here
                # (unlike point/junction below) — confirm they are always set.
                with self.debug_handler.track_performance("symbol_detection"):
                    self.symbol_detector.detect(
                        cropped_image,
                        context=context,
                        roi_offset=roi_offset
                    )
                # 4) Detect tags
                with self.debug_handler.track_performance("tag_detection"):
                    self.tag_detector.detect(
                        cropped_image,
                        context=context,
                        roi_offset=roi_offset
                    )
                # Make a copy of the cropped image for final debug combos
                debug_cropped = cropped_image.copy()
                # 5) White-out symbol/tag bboxes in the original cropped image
                #    so the line detector ignores strokes inside them
                cropped_image = self._remove_symbol_tag_bboxes(cropped_image, context)
                # 6) Detect lines
                with self.debug_handler.track_performance("line_detection"):
                    self.line_detector.detect(cropped_image, context=context)
                # 7) Detect points (optional detector — skipped when falsy)
                if self.point_detector:
                    with self.debug_handler.track_performance("point_detection"):
                        self.point_detector.detect(cropped_image, context=context)
                # 8) Detect junctions (optional detector — skipped when falsy)
                if self.junction_detector:
                    with self.debug_handler.track_performance("junction_detection"):
                        self.junction_detector.detect(cropped_image, context=context)
                # 9) Save final JSON & any final images
                output_paths = self._persist_results(output_dir, image_path, context)
                # 10) Save debug images in local coords using debug_cropped
                self._save_all_combinations(debug_cropped, context, output_dir, image_path)
                return DetectionResult(
                    success=True,
                    processing_time=self.debug_handler.metrics.get('total_processing', 0),
                    json_path=output_paths.get('json_path'),
                    image_path=output_paths.get('image_path')  # Now returning the annotated image path
                )
        except Exception as e:
            # Boundary handler: any stage failure is logged and reported via
            # the result object instead of propagating to the caller.
            logger.error(f"Processing failed: {str(e)}")
            return DetectionResult(
                success=False,
                error=str(e)
            )
| # ------------------------------------------------ | |
| # HELPER FUNCTIONS | |
| # ------------------------------------------------ | |
| def _persist_results(self, output_dir: str, image_path: str, context: DetectionContext) -> dict: | |
| """Saves final JSON and debug images to disk.""" | |
| self.storage.create_directory(output_dir) | |
| base_name = Path(image_path).stem | |
| # Save JSON | |
| json_path = Path(output_dir) / f"{base_name}_detected_lines.json" | |
| context_json_str = context.to_json(indent=2) | |
| self.storage.save_file(str(json_path), context_json_str.encode('utf-8')) | |
| # Save annotated image for pipeline display | |
| annotated_image = self._draw_objects( | |
| self._load_image(image_path), | |
| context, | |
| draw_lines=True, | |
| draw_points=True, | |
| draw_symbols=True, | |
| draw_junctions=True, | |
| draw_tags=True | |
| ) | |
| image_path = Path(output_dir) / f"{base_name}_annotated.jpg" | |
| _, encoded = cv2.imencode('.jpg', annotated_image) | |
| self.storage.save_file(str(image_path), encoded.tobytes()) | |
| return { | |
| "json_path": str(json_path), | |
| "image_path": str(image_path) | |
| } | |
| def _save_all_combinations(self, local_image: np.ndarray, context: DetectionContext, | |
| output_dir: str, image_path: str) -> None: | |
| """Produce debug images with different combinations.""" | |
| base_name = Path(image_path).stem | |
| base_name = base_name.split("_")[0] | |
| combos = [ | |
| ("text_detected_symbols", dict(draw_symbols=True, draw_tags=False, draw_lines=False, draw_points=False, draw_junctions=False)), | |
| ("text_detected_texts", dict(draw_symbols=False, draw_tags=True, draw_lines=False, draw_points=False, draw_junctions=False)), | |
| ("text_detected_lines", dict(draw_symbols=False, draw_tags=False, draw_lines=True, draw_points=False, draw_junctions=False)), | |
| ] | |
| self.storage.create_directory(output_dir) | |
| for combo_name, flags in combos: | |
| annotated = self._draw_objects(local_image, context, **flags) | |
| save_name = f"{base_name}_{combo_name}.jpg" | |
| save_path = Path(output_dir) / save_name | |
| _, encoded = cv2.imencode('.jpg', annotated) | |
| self.storage.save_file(str(save_path), encoded.tobytes()) | |
| logger.info(f"Saved debug image: {save_path}") | |
| def _draw_objects(self, base_image: np.ndarray, context: DetectionContext, | |
| draw_lines: bool = True, draw_points: bool = True, | |
| draw_symbols: bool = True, draw_junctions: bool = True, | |
| draw_tags: bool = True) -> np.ndarray: | |
| """Draw detection results on a copy of base_image in local coords.""" | |
| annotated = base_image.copy() | |
| # Lines | |
| if draw_lines: | |
| for ln in context.lines.values(): | |
| cv2.line(annotated, | |
| (ln.start.coords.x, ln.start.coords.y), | |
| (ln.end.coords.x, ln.end.coords.y), | |
| (0, 255, 0), # green | |
| 2) | |
| # Points | |
| if draw_points: | |
| for pt in context.points.values(): | |
| cv2.circle(annotated, | |
| (pt.coords.x, pt.coords.y), | |
| 3, | |
| (0, 0, 255), # red | |
| -1) | |
| # Symbols | |
| if draw_symbols: | |
| for sym in context.symbols.values(): | |
| cv2.rectangle(annotated, | |
| (sym.bbox.xmin, sym.bbox.ymin), | |
| (sym.bbox.xmax, sym.bbox.ymax), | |
| (255, 255, 0), # cyan | |
| 2) | |
| cv2.circle(annotated, | |
| (sym.center.x, sym.center.y), | |
| 4, | |
| (255, 0, 255), # magenta | |
| -1) | |
| # Junctions | |
| if draw_junctions: | |
| for jn in context.junctions.values(): | |
| if jn.junction_type == JunctionType.T: | |
| color = (0, 165, 255) # orange | |
| elif jn.junction_type == JunctionType.L: | |
| color = (255, 0, 255) # magenta | |
| else: # END | |
| color = (0, 0, 255) # red | |
| cv2.circle(annotated, | |
| (jn.center.x, jn.center.y), | |
| 5, | |
| color, | |
| -1) | |
| # Tags | |
| if draw_tags: | |
| for tg in context.tags.values(): | |
| cv2.rectangle(annotated, | |
| (tg.bbox.xmin, tg.bbox.ymin), | |
| (tg.bbox.xmax, tg.bbox.ymax), | |
| (128, 0, 128), # purple | |
| 2) | |
| cv2.putText(annotated, | |
| tg.text, | |
| (tg.bbox.xmin, tg.bbox.ymin - 5), | |
| cv2.FONT_HERSHEY_SIMPLEX, | |
| 0.5, | |
| (128, 0, 128), | |
| 1) | |
| return annotated | |
| def detect_lines(self, image_path: str, output_dir: str, config: Optional[Dict] = None) -> Dict: | |
| """Legacy interface for line detection""" | |
| storage = StorageFactory.get_storage() | |
| debug_handler = DebugHandler(enabled=True, storage=storage) | |
| line_detector = LineDetector( | |
| config=LineConfig(), | |
| model_path="models/deeplsd_md.tar", | |
| device=torch.device("cpu"), | |
| debug_handler=debug_handler | |
| ) | |
| pipeline = DiagramDetectionPipeline( | |
| tag_detector=None, | |
| symbol_detector=None, | |
| line_detector=line_detector, | |
| point_detector=None, | |
| junction_detector=None, | |
| storage=storage, | |
| debug_handler=debug_handler | |
| ) | |
| result = pipeline.run(image_path, output_dir, ImageConfig()) | |
| return result | |
| def _validate_and_normalize_coordinates(self, points): | |
| """Validate and normalize coordinates to image space""" | |
| valid_points = [] | |
| for point in points: | |
| x, y = point['x'], point['y'] | |
| # Validate coordinates are within image bounds | |
| if 0 <= x <= self.image_width and 0 <= y <= self.image_height: | |
| # Normalize coordinates if needed | |
| valid_points.append({ | |
| 'x': int(x), | |
| 'y': int(y), | |
| 'type': point.get('type', 'unknown'), | |
| 'confidence': point.get('confidence', 1.0) | |
| }) | |
| return valid_points | |
if __name__ == "__main__":
    # 1) Initialize shared components
    storage = StorageFactory.get_storage()
    debug_handler = DebugHandler(enabled=True, storage=storage)

    # 2) DeepLSD model configuration for the line detector
    conf = {
        "detect_lines": True,
        "line_detection_params": {
            "merge": True,
            "filtering": True,
            "grad_thresh": 3,
            "grad_nfa": True
        }
    }

    # 3) Per-detector configuration objects
    line_config = LineConfig()
    point_config = PointConfig()
    junction_config = JunctionConfig()
    symbol_config = SymbolConfig()
    tag_config = TagConfig()

    # ========================== Detectors ========================== #
    symbol_detector = SymbolDetector(
        config=symbol_config,
        debug_handler=debug_handler
    )
    tag_detector = TagDetector(
        config=tag_config,
        debug_handler=debug_handler
    )
    line_detector = LineDetector(
        config=line_config,
        model_path="models/deeplsd_md.tar",
        model_config=conf,
        device=torch.device("cpu"),  # or "cuda" if available
        debug_handler=debug_handler
    )
    point_detector = PointDetector(
        config=point_config,
        debug_handler=debug_handler
    )
    junction_detector = JunctionDetector(
        config=junction_config,
        debug_handler=debug_handler
    )

    # 4) Create pipeline.
    # BUGFIX: __init__ accepts only (storage, debug_handler); passing the
    # detectors as keyword arguments raised TypeError.  Attach them after
    # construction under the public attribute names run() reads.
    pipeline = DiagramDetectionPipeline(
        storage=storage,
        debug_handler=debug_handler
    )
    pipeline.symbol_detector = symbol_detector
    pipeline.tag_detector = tag_detector
    pipeline.line_detector = line_detector
    pipeline.point_detector = point_detector
    pipeline.junction_detector = junction_detector

    # 5) Run pipeline
    result = pipeline.run(
        image_path="samples/images/0.jpg",
        output_dir="results/",
        config=ImageConfig()
    )
    if result.success:
        logger.info(f"Pipeline succeeded! See JSON at {result.json_path}")
    else:
        logger.error(f"Pipeline failed: {result.error}")