# intelligent-pid / line_detection_ai.py
# Last change: "Add centralized logging configuration" (commit d3ea93c, author msIntui)
from __future__ import annotations  # defer annotation evaluation (several annotated names are not imported)

# Standard library imports first
import os
from pathlib import Path
from typing import List, Dict, Optional, Any, Tuple

# Get logger before any other imports
from logger import get_logger
logger = get_logger(__name__)

# Third-party imports
import cv2
import numpy as np
import torch

# Local imports
from base import BaseDetectionPipeline
from config import (
    LineConfig, ImageConfig, PointConfig, JunctionConfig,
    SymbolConfig, TagConfig
)
from utils import DebugHandler
from storage import StorageInterface
# TODO(review): DetectionContext, DetectionResult, JunctionType and
# StorageFactory are referenced below but never imported — their home
# modules are not visible from this file; add the proper project imports.

# Import detectors after logging is configured
from detectors import (
    LineDetector, PointDetector, JunctionDetector,
    SymbolDetector, TagDetector
)
class DiagramDetectionPipeline(BaseDetectionPipeline):
    """Main pipeline for processing P&ID diagrams.

    Runs symbol and tag detection first, whites out their bounding boxes so
    the line detector is not confused by glyph/symbol strokes, then detects
    lines, points and junctions. Results are persisted as JSON plus an
    annotated image, and per-layer debug renderings are written as well.
    All detection coordinates are in local (cropped) space.
    """

    def __init__(
        self,
        storage: StorageInterface,
        debug_handler: Optional[DebugHandler] = None,
        *,
        line_detector: Optional[LineDetector] = None,
        point_detector: Optional[PointDetector] = None,
        junction_detector: Optional[JunctionDetector] = None,
        symbol_detector: Optional[SymbolDetector] = None,
        tag_detector: Optional[TagDetector] = None,
    ):
        """Create the pipeline.

        Args:
            storage: Backend used for all file reads/writes.
            debug_handler: Optional performance/debug tracker.
            line_detector, point_detector, junction_detector,
            symbol_detector, tag_detector: Optional pre-built detectors.
                Both ``detect_lines`` and the ``__main__`` block in this
                file construct the pipeline with these keyword arguments,
                which previously raised a TypeError because ``__init__`` did
                not accept them. Defaults of ``None`` preserve the original
                "initialize detectors when needed" behaviour.
        """
        super().__init__(storage, debug_handler)
        # NOTE(review): run() reads `self.line_detector` etc. (no leading
        # underscore) — presumably BaseDetectionPipeline exposes properties
        # backed by these attributes; confirm against the base class.
        self._line_detector = line_detector
        self._point_detector = point_detector
        self._junction_detector = junction_detector
        self._symbol_detector = symbol_detector
        self._tag_detector = tag_detector

    def _load_image(self, image_path: str) -> np.ndarray:
        """Load an image via the storage backend and decode it as BGR.

        Raises:
            ValueError: if the bytes cannot be decoded as an image.
        """
        image_data = self.storage.load_file(image_path)
        image = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
        if image is None:
            raise ValueError(f"Failed to load image from {image_path}")
        return image

    def _crop_to_roi(self, image: np.ndarray, roi: Optional[list]) -> Tuple[np.ndarray, Tuple[int, int]]:
        """Crop to ROI if provided, else return the full image.

        Args:
            image: Full diagram image.
            roi: ``[x_min, y_min, x_max, y_max]`` in image coordinates, or
                None / malformed to skip cropping.

        Returns:
            ``(cropped_image, (x_offset, y_offset))`` — the offset maps
            local (cropped) coordinates back to the original image.
        """
        if roi is not None and len(roi) == 4:
            x_min, y_min, x_max, y_max = roi
            return image[y_min:y_max, x_min:x_max], (x_min, y_min)
        return image, (0, 0)

    def _remove_symbol_tag_bboxes(self, image: np.ndarray, context: DetectionContext) -> np.ndarray:
        """Fill symbol & tag bounding boxes with white on a copy of *image*.

        This prevents the line detector from picking up symbol outlines and
        text strokes as line segments. The input image is not mutated.
        """
        masked = image.copy()
        # thickness=-1 makes cv2.rectangle fill the box solid.
        for sym in context.symbols.values():
            cv2.rectangle(masked,
                          (sym.bbox.xmin, sym.bbox.ymin),
                          (sym.bbox.xmax, sym.bbox.ymax),
                          (255, 255, 255),  # white
                          thickness=-1)
        for tg in context.tags.values():
            cv2.rectangle(masked,
                          (tg.bbox.xmin, tg.bbox.ymin),
                          (tg.bbox.xmax, tg.bbox.ymax),
                          (255, 255, 255),
                          thickness=-1)
        return masked

    def run(
        self,
        image_path: str,
        output_dir: str,
        config
    ) -> DetectionResult:
        """Execute the full detection pipeline on one diagram image.

        Main pipeline steps (in local coords):
          1) Load + crop image
          2) Detect symbols & tags
          3) Make a copy for final debug images
          4) White out symbol/tag bounding boxes
          5) Detect lines, points, junctions
          6) Save final JSON
          7) Generate debug images with various combinations

        Args:
            image_path: Source image path understood by the storage backend.
            output_dir: Directory for JSON and annotated outputs.
            config: Object with at least a ``roi`` attribute (e.g. ImageConfig).

        Returns:
            DetectionResult with success flag, timing and output paths, or
            ``success=False`` and the error message on failure.
        """
        try:
            with self.debug_handler.track_performance("total_processing"):
                # 1) Load & crop
                image = self._load_image(image_path)
                cropped_image, roi_offset = self._crop_to_roi(image, config.roi)
                # 2) Fresh context accumulates all detections for this run.
                context = DetectionContext()
                # 3) Detect symbols
                with self.debug_handler.track_performance("symbol_detection"):
                    self.symbol_detector.detect(
                        cropped_image,
                        context=context,
                        roi_offset=roi_offset
                    )
                # 4) Detect tags
                with self.debug_handler.track_performance("tag_detection"):
                    self.tag_detector.detect(
                        cropped_image,
                        context=context,
                        roi_offset=roi_offset
                    )
                # Keep an unmasked copy for the final debug renderings.
                debug_cropped = cropped_image.copy()
                # 5) White-out symbol/tag bboxes so line detection ignores them.
                cropped_image = self._remove_symbol_tag_bboxes(cropped_image, context)
                # 6) Detect lines
                with self.debug_handler.track_performance("line_detection"):
                    self.line_detector.detect(cropped_image, context=context)
                # 7) Detect points (optional stage)
                if self.point_detector:
                    with self.debug_handler.track_performance("point_detection"):
                        self.point_detector.detect(cropped_image, context=context)
                # 8) Detect junctions (optional stage)
                if self.junction_detector:
                    with self.debug_handler.track_performance("junction_detection"):
                        self.junction_detector.detect(cropped_image, context=context)
                # 9) Save final JSON & the fully annotated image
                output_paths = self._persist_results(output_dir, image_path, context)
                # 10) Save debug images in local coords using the unmasked copy
                self._save_all_combinations(debug_cropped, context, output_dir, image_path)
                return DetectionResult(
                    success=True,
                    processing_time=self.debug_handler.metrics.get('total_processing', 0),
                    json_path=output_paths.get('json_path'),
                    image_path=output_paths.get('image_path')  # annotated image path
                )
        except Exception as e:
            # logger.exception keeps the traceback, unlike the plain error log.
            logger.exception("Processing failed: %s", e)
            return DetectionResult(
                success=False,
                error=str(e)
            )

    # ------------------------------------------------
    # HELPER FUNCTIONS
    # ------------------------------------------------
    def _persist_results(self, output_dir: str, image_path: str, context: DetectionContext) -> dict:
        """Save the detection JSON and a fully annotated image.

        Returns:
            Dict with ``json_path`` and ``image_path`` (annotated image).
        """
        self.storage.create_directory(output_dir)
        base_name = Path(image_path).stem
        # Save JSON
        json_path = Path(output_dir) / f"{base_name}_detected_lines.json"
        context_json_str = context.to_json(indent=2)
        self.storage.save_file(str(json_path), context_json_str.encode('utf-8'))
        # Save annotated image (drawn on the ORIGINAL, uncropped image).
        annotated_image = self._draw_objects(
            self._load_image(image_path),
            context,
            draw_lines=True,
            draw_points=True,
            draw_symbols=True,
            draw_junctions=True,
            draw_tags=True
        )
        # Use a distinct local name instead of shadowing the image_path param.
        annotated_path = Path(output_dir) / f"{base_name}_annotated.jpg"
        _, encoded = cv2.imencode('.jpg', annotated_image)
        self.storage.save_file(str(annotated_path), encoded.tobytes())
        return {
            "json_path": str(json_path),
            "image_path": str(annotated_path)
        }

    def _save_all_combinations(self, local_image: np.ndarray, context: DetectionContext,
                               output_dir: str, image_path: str) -> None:
        """Produce per-layer debug images (symbols only, tags only, lines only)."""
        base_name = Path(image_path).stem
        # Keep only the prefix before the first underscore for debug names.
        base_name = base_name.split("_")[0]
        combos = [
            ("text_detected_symbols", dict(draw_symbols=True, draw_tags=False, draw_lines=False, draw_points=False, draw_junctions=False)),
            ("text_detected_texts", dict(draw_symbols=False, draw_tags=True, draw_lines=False, draw_points=False, draw_junctions=False)),
            ("text_detected_lines", dict(dra_symbols=False, draw_tags=False, draw_lines=True, draw_points=False, draw_junctions=False)),
        ]
        self.storage.create_directory(output_dir)
        for combo_name, flags in combos:
            annotated = self._draw_objects(local_image, context, **flags)
            save_name = f"{base_name}_{combo_name}.jpg"
            save_path = Path(output_dir) / save_name
            _, encoded = cv2.imencode('.jpg', annotated)
            self.storage.save_file(str(save_path), encoded.tobytes())
            logger.info(f"Saved debug image: {save_path}")

    def _draw_objects(self, base_image: np.ndarray, context: DetectionContext,
                      draw_lines: bool = True, draw_points: bool = True,
                      draw_symbols: bool = True, draw_junctions: bool = True,
                      draw_tags: bool = True) -> np.ndarray:
        """Draw selected detection layers on a copy of *base_image*.

        Coordinates are taken verbatim from the context objects and are
        assumed to be ints in local coords — TODO confirm detectors never
        emit floats (cv2 drawing requires integer points).
        """
        annotated = base_image.copy()
        # Lines
        if draw_lines:
            for ln in context.lines.values():
                cv2.line(annotated,
                         (ln.start.coords.x, ln.start.coords.y),
                         (ln.end.coords.x, ln.end.coords.y),
                         (0, 255, 0),  # green
                         2)
        # Points
        if draw_points:
            for pt in context.points.values():
                cv2.circle(annotated,
                           (pt.coords.x, pt.coords.y),
                           3,
                           (0, 0, 255),  # red
                           -1)
        # Symbols: bounding box plus center marker
        if draw_symbols:
            for sym in context.symbols.values():
                cv2.rectangle(annotated,
                              (sym.bbox.xmin, sym.bbox.ymin),
                              (sym.bbox.xmax, sym.bbox.ymax),
                              (255, 255, 0),  # cyan (BGR)
                              2)
                cv2.circle(annotated,
                           (sym.center.x, sym.center.y),
                           4,
                           (255, 0, 255),  # magenta
                           -1)
        # Junctions: color-coded by junction type
        if draw_junctions:
            for jn in context.junctions.values():
                if jn.junction_type == JunctionType.T:
                    color = (0, 165, 255)  # orange
                elif jn.junction_type == JunctionType.L:
                    color = (255, 0, 255)  # magenta
                else:  # END
                    color = (0, 0, 255)  # red
                cv2.circle(annotated,
                           (jn.center.x, jn.center.y),
                           5,
                           color,
                           -1)
        # Tags: bounding box plus the recognized text above it
        if draw_tags:
            for tg in context.tags.values():
                cv2.rectangle(annotated,
                              (tg.bbox.xmin, tg.bbox.ymin),
                              (tg.bbox.xmax, tg.bbox.ymax),
                              (128, 0, 128),  # purple
                              2)
                cv2.putText(annotated,
                            tg.text,
                            (tg.bbox.xmin, tg.bbox.ymin - 5),
                            cv2.FONT_HERSHEY_SIMPLEX,
                            0.5,
                            (128, 0, 128),
                            1)
        return annotated

    def detect_lines(self, image_path: str, output_dir: str, config: Optional[Dict] = None) -> Dict:
        """Legacy interface for line-only detection.

        Builds a throwaway pipeline with just a LineDetector and runs it.
        Note: ``self`` and ``config`` are unused — kept for signature
        compatibility with legacy callers.
        """
        # NOTE(review): StorageFactory is never imported in this file —
        # presumably it lives in the `storage` module; confirm.
        storage = StorageFactory.get_storage()
        debug_handler = DebugHandler(enabled=True, storage=storage)
        line_detector = LineDetector(
            config=LineConfig(),
            model_path="models/deeplsd_md.tar",
            device=torch.device("cpu"),
            debug_handler=debug_handler
        )
        pipeline = DiagramDetectionPipeline(
            tag_detector=None,
            symbol_detector=None,
            line_detector=line_detector,
            point_detector=None,
            junction_detector=None,
            storage=storage,
            debug_handler=debug_handler
        )
        result = pipeline.run(image_path, output_dir, ImageConfig())
        return result

    def _validate_and_normalize_coordinates(self, points):
        """Keep only points inside the image bounds, normalized to int coords.

        NOTE(review): relies on ``self.image_width`` / ``self.image_height``,
        which are never assigned in this file — presumably set by the base
        class or a caller; confirm.

        Args:
            points: Iterable of dicts with at least ``x`` and ``y`` keys.

        Returns:
            List of dicts with int ``x``/``y`` plus ``type`` (default
            'unknown') and ``confidence`` (default 1.0).
        """
        valid_points = []
        for point in points:
            x, y = point['x'], point['y']
            # Drop points outside the image bounds.
            if 0 <= x <= self.image_width and 0 <= y <= self.image_height:
                valid_points.append({
                    'x': int(x),
                    'y': int(y),
                    'type': point.get('type', 'unknown'),
                    'confidence': point.get('confidence', 1.0)
                })
        return valid_points
def main() -> None:
    """Build all detectors, assemble the pipeline, and run it on a sample image."""
    # 1) Initialize components
    # NOTE(review): StorageFactory is never imported in this file —
    # presumably it lives in the `storage` module; confirm and import it.
    storage = StorageFactory.get_storage()
    debug_handler = DebugHandler(enabled=True, storage=storage)
    # 2) Model configuration forwarded to the DeepLSD line detector.
    conf = {
        "detect_lines": True,
        "line_detection_params": {
            "merge": True,
            "filtering": True,
            "grad_thresh": 3,
            "grad_nfa": True
        }
    }
    # 3) Per-detector configuration objects (all defaults)
    line_config = LineConfig()
    point_config = PointConfig()
    junction_config = JunctionConfig()
    symbol_config = SymbolConfig()
    tag_config = TagConfig()
    # ========================== Detectors ========================== #
    symbol_detector = SymbolDetector(
        config=symbol_config,
        debug_handler=debug_handler
    )
    tag_detector = TagDetector(
        config=tag_config,
        debug_handler=debug_handler
    )
    line_detector = LineDetector(
        config=line_config,
        model_path="models/deeplsd_md.tar",
        model_config=conf,
        device=torch.device("cpu"),  # or "cuda" if available
        debug_handler=debug_handler
    )
    point_detector = PointDetector(
        config=point_config,
        debug_handler=debug_handler)
    junction_detector = JunctionDetector(
        config=junction_config,
        debug_handler=debug_handler
    )
    # 4) Create pipeline
    pipeline = DiagramDetectionPipeline(
        tag_detector=tag_detector,
        symbol_detector=symbol_detector,
        line_detector=line_detector,
        point_detector=point_detector,
        junction_detector=junction_detector,
        storage=storage,
        debug_handler=debug_handler
    )
    # 5) Run pipeline on the bundled sample image
    result = pipeline.run(
        image_path="samples/images/0.jpg",
        output_dir="results/",
        config=ImageConfig()
    )
    if result.success:
        logger.info(f"Pipeline succeeded! See JSON at {result.json_path}")
    else:
        logger.error(f"Pipeline failed: {result.error}")


if __name__ == "__main__":
    main()