"""Detector implementations for a P&ID-style drawing pipeline.

Contains:
- LineDetector:     DeepLSD-based line detection (patch tiling + global merge)
- PointDetector:    unifies nearby line endpoints into shared Point objects
- JunctionDetector: classifies points as END / L / T via skeleton analysis
- SymbolDetector:   loads precomputed symbol detections from JSON
- TagDetector:      loads precomputed tag detections from JSON
"""

import json
import math
import os
import uuid
from dataclasses import replace
from math import sqrt
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import cv2
import numpy as np
import torch

# Skeletonization and label processing for junction detection
from skimage.measure import label
from skimage.morphology import skeletonize
from ultralytics import YOLO

# Base classes and utilities
from base import BaseDetector
from config import JunctionConfig, LineConfig, PointConfig, SymbolConfig, TagConfig
# DeepLSD model for line detection
from deeplsd.models.deeplsd_inference import DeepLSD
# Detection schema: dataclasses for different objects
from detection_schema import (
    BBox,
    ConnectionType,
    Coordinates,
    DetectionContext,
    Junction,
    JunctionType,
    Line,
    LineStyle,
    Point,
    Symbol,
    SymbolType,
    Tag,
)
from detection_utils import robust_merge_lines
from utils import DebugHandler


class LineDetector(BaseDetector):
    """DeepLSD-based line detection with patch-based tiling and global merging."""

    def __init__(self,
                 config: LineConfig,
                 model_path: str,
                 model_config: dict,
                 device: torch.device,
                 debug_handler: DebugHandler = None):
        super().__init__(config, debug_handler)
        # Fix device selection for Apple Silicon: prefer MPS, then CUDA, then CPU.
        # NOTE(review): the `device` argument is intentionally overridden here.
        if torch.backends.mps.is_available():
            self.device = torch.device("mps")
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        # Guard against a None debug handler, consistent with JunctionDetector
        # (``_preprocess`` uses it unconditionally).
        self.debug_handler = debug_handler or DebugHandler()
        self.model_path = model_path
        self.model_config = model_config
        self.model = self._load_model(model_path)
        # Patch parameters
        self.patch_size = 512
        self.overlap = 10
        # Merging thresholds
        self.angle_thresh = 5.0  # degrees
        self.dist_thresh = 5.0   # pixels

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """Thicken strokes, skeletonize, then re-dilate to a clean 1-ish px image."""
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
        dilated = cv2.dilate(image, kernel, iterations=2)
        # Invert so foreground strokes become white for skeletonize.
        skeleton = cv2.bitwise_not(dilated)
        skeleton = skeletonize(skeleton // 255)
        skeleton = (skeleton * 255).astype(np.uint8)
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 1))
        clean_image = cv2.dilate(skeleton, kernel, iterations=5)
        self.debug_handler.save_artifact(name="skeleton", data=clean_image, extension="png")
        return clean_image

    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        # No postprocessing for this detector.
        return None

    # -------------------------------------
    # 1) Load Model
    # -------------------------------------
    def _load_model(self, model_path: str) -> DeepLSD:
        """Load DeepLSD weights from a checkpoint onto ``self.device``.

        Raises:
            FileNotFoundError: if ``model_path`` does not exist.
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")
        ckpt = torch.load(model_path, map_location=self.device)
        model = DeepLSD(self.model_config)
        model.load_state_dict(ckpt["model"])
        return model.to(self.device).eval()

    # -------------------------------------
    # 2) Main Detection Pipeline
    # -------------------------------------
    def detect(self,
               image: np.ndarray,
               context: DetectionContext,
               mask_coords: Optional[List[BBox]] = None,
               *args, **kwargs) -> None:
        """
        Steps:
        - Optional mask + threshold
        - Tile into overlapping patches
        - For each patch => run DeepLSD => re-map lines to global coords
        - Merge lines robustly
        - Build final Line objects => add to context
        """
        mask_coords = mask_coords or []
        skeleton = self._preprocess(image)

        # (A) Optional mask + threshold if you want a binary
        # If your model expects grayscale or binary, do it here:
        processed_img = self._apply_mask_and_threshold(skeleton, mask_coords)

        # (B) Patch-based inference => collect raw lines in global coords
        all_lines = self._detect_in_patches(processed_img)

        # (C) Merge the lines in the global coordinate system
        merged_line_segments = robust_merge_lines(
            all_lines,
            angle_thresh=self.angle_thresh,
            dist_thresh=self.dist_thresh
        )

        # (D) Convert merged segments => final Line objects, add to context
        for (x1, y1, x2, y2) in merged_line_segments:
            line_obj = self._create_line_object(x1, y1, x2, y2)
            context.add_line(line_obj)

    # -------------------------------------
    # 3) Optional Mask + Threshold
    # -------------------------------------
    def _apply_mask_and_threshold(self, image: np.ndarray,
                                  mask_coords: List[BBox]) -> np.ndarray:
        """White out rectangular areas, then threshold to binary (if needed)."""
        masked = image.copy()
        for bbox in mask_coords:
            x1, y1 = int(bbox.xmin), int(bbox.ymin)
            x2, y2 = int(bbox.xmax), int(bbox.ymax)
            cv2.rectangle(masked, (x1, y1), (x2, y2), (255, 255, 255), -1)

        # If image has 3 channels, convert to grayscale
        if len(masked.shape) == 3:
            masked_gray = cv2.cvtColor(masked, cv2.COLOR_BGR2GRAY)
        else:
            masked_gray = masked

        # Binary threshold (adjust threshold as needed)
        # If your model expects a plain grayscale, skip threshold
        binary_img = cv2.threshold(masked_gray, 127, 255, cv2.THRESH_BINARY)[1]
        return binary_img

    # -------------------------------------
    # 4) Patch-Based Inference
    # -------------------------------------
    def _detect_in_patches(self, processed_img: np.ndarray) -> List[tuple]:
        """
        Break the image into overlapping patches, run DeepLSD, map local
        lines => global coords, and return the global line list.
        """
        patch_size = self.patch_size
        overlap = self.overlap
        height, width = processed_img.shape[:2]
        step = patch_size - overlap
        all_lines = []

        for y in range(0, height, step):
            patch_ymax = min(y + patch_size, height)
            # Shift the window back so edge patches stay full-size.
            patch_ymin = patch_ymax - patch_size if (patch_ymax - y) < patch_size else y
            if patch_ymin < 0:
                patch_ymin = 0
            for x in range(0, width, step):
                patch_xmax = min(x + patch_size, width)
                patch_xmin = patch_xmax - patch_size if (patch_xmax - x) < patch_size else x
                if patch_xmin < 0:
                    patch_xmin = 0

                patch = processed_img[patch_ymin:patch_ymax, patch_xmin:patch_xmax]

                # Run model
                local_lines = self._run_model_inference(patch)

                # Convert local lines => global coords
                for ln in local_lines:
                    (x1_local, y1_local), (x2_local, y2_local) = ln
                    # offset by patch_xmin, patch_ymin
                    gx1 = x1_local + patch_xmin
                    gy1 = y1_local + patch_ymin
                    gx2 = x2_local + patch_xmin
                    gy2 = y2_local + patch_ymin
                    # Optional: clamp or filter lines partially out-of-bounds
                    if 0 <= gx1 < width and 0 <= gx2 < width and 0 <= gy1 < height and 0 <= gy2 < height:
                        all_lines.append((gx1, gy1, gx2, gy2))
        return all_lines

    # -------------------------------------
    # 5) Model Inference (Single Patch)
    # -------------------------------------
    def _run_model_inference(self, patch_img: np.ndarray) -> np.ndarray:
        """
        Run DeepLSD on a single patch (already masked/thresholded).
        patch_img shape: [patchH, patchW]. Returns lines shape: [N, 2, 2].
        """
        # Convert patch to float32 and scale to [0, 1]; add batch/channel dims.
        inp = torch.tensor(patch_img, dtype=torch.float32, device=self.device)[None, None] / 255.0
        with torch.no_grad():
            output = self.model({"image": inp})
        lines = output["lines"][0]  # shape (N, 2, 2)
        return lines

    # -------------------------------------
    # 6) Convert Merged Segments => Line Objects
    # -------------------------------------
    def _create_line_object(self, x1: float, y1: float,
                            x2: float, y2: float) -> Line:
        """Create a minimal `Line` object from the final merged coordinates."""
        margin = 2
        # Start point
        start_pt = Point(
            coords=Coordinates(int(x1), int(y1)),
            bbox=BBox(
                xmin=int(x1 - margin), ymin=int(y1 - margin),
                xmax=int(x1 + margin), ymax=int(y1 + margin)
            ),
            type=JunctionType.END,
            confidence=1.0
        )
        # End point
        end_pt = Point(
            coords=Coordinates(int(x2), int(y2)),
            bbox=BBox(
                xmin=int(x2 - margin), ymin=int(y2 - margin),
                xmax=int(x2 + margin), ymax=int(y2 + margin)
            ),
            type=JunctionType.END,
            confidence=1.0
        )
        # Overall bounding box
        x_min = int(min(x1, x2))
        x_max = int(max(x1, x2))
        y_min = int(min(y1, y2))
        y_max = int(max(y1, y2))
        line_obj = Line(
            start=start_pt,
            end=end_pt,
            bbox=BBox(xmin=x_min, ymin=y_min, xmax=x_max, ymax=y_max),
            style=LineStyle(
                connection_type=ConnectionType.SOLID,
                stroke_width=2,
                color="#000000"
            ),
            confidence=0.9,
            topological_links=[]
        )
        return line_obj


class PointDetector(BaseDetector):
    """
    A detector that:
    1) Reads lines from the context
    2) Clusters endpoints within 'threshold_distance'
    3) Updates lines so that shared endpoints reference the same Point object
    """

    def __init__(self, config: PointConfig, debug_handler: DebugHandler = None):
        super().__init__(config, debug_handler)
        # No real model to load
        self.threshold_distance = config.threshold_distance

    def _load_model(self, model_path: str):
        """No model needed for simple point unification."""
        return None

    def detect(self, image: np.ndarray,
               context: DetectionContext,
               *args, **kwargs) -> None:
        """
        Main method called by the pipeline.
        1) Gather all line endpoints from context
        2) Cluster them within 'threshold_distance'
        3) Update the line endpoints so they reference the unified cluster point
        """
        # 1) Collect all endpoints
        endpoints = []
        for line in context.lines.values():
            endpoints.append(line.start)
            endpoints.append(line.end)

        # 2) Cluster endpoints
        clusters = self._cluster_points(endpoints, self.threshold_distance)

        # 3) Build a dictionary of "representative" points
        #    So that each cluster has one "canonical" point
        #    Then we link all the points in that cluster to the canonical reference
        unified_point_map = {}
        for cluster in clusters:
            # let's pick the first point in the cluster as the "representative"
            rep_point = cluster[0]
            for p in cluster[1:]:
                unified_point_map[p.id] = rep_point

        # 4) Update all lines to reference the canonical point
        for line in context.lines.values():
            # unify start
            if line.start.id in unified_point_map:
                line.start = unified_point_map[line.start.id]
            # unify end
            if line.end.id in unified_point_map:
                line.end = unified_point_map[line.end.id]

        # We could also store the final set of unique points back in context.points
        # (e.g. clearing old duplicates).
        # That step is optional: you might prefer to keep everything in lines only,
        # or you might want context.points as a separate reference.
        # If you want to keep unique points in context.points:
        new_points = {}
        for line in context.lines.values():
            new_points[line.start.id] = line.start
            new_points[line.end.id] = line.end
        context.points = new_points  # replace the dictionary of points

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """No specific image preprocessing needed."""
        return image

    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        """No specific image postprocessing needed."""
        return image

    # ----------------------
    # HELPER: clustering
    # ----------------------
    def _cluster_points(self, points: List[Point], threshold: float) -> List[List[Point]]:
        """
        Very naive clustering:
        1) Start from the first point
        2) If it's within threshold of an existing cluster's representative,
           put it in that cluster
        3) Otherwise start a new cluster
        Return: list of clusters, each is a list of Points
        """
        clusters = []
        for pt in points:
            placed = False
            for cluster in clusters:
                # pick the first point in the cluster as reference
                ref_pt = cluster[0]
                if self._distance(pt, ref_pt) < threshold:
                    cluster.append(pt)
                    placed = True
                    break
            if not placed:
                clusters.append([pt])
        return clusters

    def _distance(self, p1: Point, p2: Point) -> float:
        """Euclidean distance between two Points' coordinates."""
        dx = p1.coords.x - p2.coords.x
        dy = p1.coords.y - p2.coords.y
        return sqrt(dx*dx + dy*dy)


class JunctionDetector(BaseDetector):
    """
    Classifies points as 'END', 'L', or 'T' by skeletonizing the binarized
    image and analyzing local connectivity.
    Also creates Junction objects in the context.
    """

    def __init__(self, config: JunctionConfig, debug_handler: DebugHandler = None):
        super().__init__(config, debug_handler)
        # no real model path
        self.window_size = config.window_size
        self.radius = config.radius
        self.angle_threshold_lb = config.angle_threshold_lb
        self.angle_threshold_ub = config.angle_threshold_ub
        self.debug_handler = debug_handler or DebugHandler()

    def _load_model(self, model_path: str):
        """Not loading any actual model, just skeleton logic."""
        return None

    def detect(self, image: np.ndarray,
               context: DetectionContext,
               *args, **kwargs) -> None:
        """
        1) Convert to binary & skeletonize
        2) Classify each point in the context
        3) Create a Junction for each point and store it in context.junctions
           (with 'connected_lines' referencing lines that share this point).
        """
        # 1) Preprocess -> skeleton
        skeleton = self._create_skeleton(image)

        # 2) Classify each point
        for pt in context.points.values():
            pt.type = self._classify_point(skeleton, pt)

        # 3) Create a Junction object for each point
        #    If you prefer only T or L, you can filter out END points.
        self._record_junctions_in_context(context)

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """We might do thresholding; let's do a simple binary threshold."""
        if image.ndim == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image
        _, bin_image = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
        return bin_image

    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        return image

    def _create_skeleton(self, raw_image: np.ndarray) -> np.ndarray:
        """Skeletonize the binarized image."""
        bin_img = self._preprocess(raw_image)
        # For skeletonize, we need a boolean array
        inv = cv2.bitwise_not(bin_img)
        inv_bool = (inv > 127).astype(np.uint8)
        skel = skeletonize(inv_bool).astype(np.uint8) * 255
        return skel

    def _classify_point(self, skeleton: np.ndarray, pt: Point) -> JunctionType:
        """
        Given a skeleton image, look around 'pt' in a local window to
        determine if it's an END, L, or T.
        """
        classification = JunctionType.END  # default
        half_w = self.window_size // 2
        x, y = pt.coords.x, pt.coords.y

        top = max(0, y - half_w)
        bottom = min(skeleton.shape[0], y + half_w + 1)
        left = max(0, x - half_w)
        right = min(skeleton.shape[1], x + half_w + 1)
        patch = (skeleton[top:bottom, left:right] > 127).astype(np.uint8)

        # create circular mask centered on the point (in patch-local coords)
        circle_mask = np.zeros_like(patch, dtype=np.uint8)
        local_cx = x - left
        local_cy = y - top
        cv2.circle(circle_mask, (local_cx, local_cy), self.radius, 1, -1)
        circle_skel = patch & circle_mask

        # label connected regions; each region is one branch exiting the point
        labeled = label(circle_skel, connectivity=2)
        num_exits = labeled.max()
        if num_exits == 1:
            classification = JunctionType.END
        elif num_exits == 2:
            # check angle for L
            classification = self._check_angle_for_L(labeled)
        elif num_exits == 3:
            classification = JunctionType.T
        return classification

    def _check_angle_for_L(self, labeled_region: np.ndarray) -> JunctionType:
        """
        If the angle between two branches is within
        [angle_threshold_lb, angle_threshold_ub], it's 'L'.
        Otherwise default to END.
        """
        coords = np.argwhere(labeled_region == 1)
        if len(coords) < 2:
            return JunctionType.END
        # NOTE(review): only the first two pixels of branch 1 are used here —
        # confirm this approximates the branch direction well enough.
        (y1, x1), (y2, x2) = coords[:2]
        dx = x2 - x1
        dy = y2 - y1
        angle = math.degrees(math.atan2(dy, dx))
        acute_angle = min(abs(angle), 180 - abs(angle))
        if self.angle_threshold_lb <= acute_angle <= self.angle_threshold_ub:
            return JunctionType.L
        return JunctionType.END

    # -----------------------------------------
    # EXTRA STEP: Create Junction objects
    # -----------------------------------------
    def _record_junctions_in_context(self, context: DetectionContext):
        """
        Create a Junction object for each point in context.points.
        If you only want T/L points as junctions, filter them out.
        Also track any lines that connect to this point.
        """
        for pt in context.points.values():
            # If you prefer to store all points as junction, do it:
            # or if you want only T or L, do:
            # if pt.type in {JunctionType.T, JunctionType.L}: ...
            jn = Junction(
                center=pt.coords,
                junction_type=pt.type,
                # add more properties if needed
            )
            # find lines that connect to this point
            connected_lines = []
            for ln in context.lines.values():
                if ln.start.id == pt.id or ln.end.id == pt.id:
                    connected_lines.append(ln.id)
            jn.connected_lines = connected_lines
            # add to context
            context.add_junction(jn)


class SymbolDetector(BaseDetector):
    """
    A placeholder detector that reads precomputed symbol data
    from a JSON file and populates the context with Symbol objects.
    """

    def __init__(self,
                 config: SymbolConfig,
                 debug_handler: Optional[DebugHandler] = None,
                 symbol_json_path: str = "./symbols.json"):
        super().__init__(config=config, debug_handler=debug_handler)
        # Guard against a None debug handler (``_load_json_data`` uses it),
        # consistent with JunctionDetector.
        self.debug_handler = debug_handler or DebugHandler()
        self.symbol_json_path = symbol_json_path

    def _load_model(self, model_path: str):
        """Not loading an actual model; symbol data is read from JSON."""
        return None

    def detect(self, image: np.ndarray,
               context: DetectionContext,
               *args, **kwargs) -> None:
        """
        Reads from a JSON file containing symbol info and updates context.
        (A roi_offset-based coordinate adjustment was previously applied
        here and is currently disabled.)
        """
        symbol_data = self._load_json_data(self.symbol_json_path)
        if not symbol_data:
            return
        for record in symbol_data.get("detections", []):  # Fix: Use "detections" key
            sym_obj = self._parse_symbol_record(record)
            context.add_symbol(sym_obj)

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        return image

    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        return image

    # --------------
    # HELPER METHODS
    # --------------
    def _load_json_data(self, json_path: str) -> dict:
        """Load the symbol JSON; on a missing file, record an artifact and
        return an empty dict instead of raising."""
        if not os.path.exists(json_path):
            self.debug_handler.save_artifact(name="symbol_error",
                                             data=b"Missing symbol JSON file",
                                             extension="txt")
            return {}
        with open(json_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def _parse_symbol_record(self, record: dict) -> Symbol:
        """Builds a Symbol object from a JSON record.

        (Cropping-offset adjustment of the bbox is currently disabled.)
        """
        bbox_list = record.get("bbox", [0, 0, 0, 0])
        bbox_obj = BBox(
            xmin=bbox_list[0],
            ymin=bbox_list[1],
            xmax=bbox_list[2],
            ymax=bbox_list[3]
        )
        # Compute the center
        center_coords = Coordinates(
            x=(bbox_obj.xmin + bbox_obj.xmax) // 2,
            y=(bbox_obj.ymin + bbox_obj.ymax) // 2
        )
        return Symbol(
            id=record.get("symbol_id", ""),
            class_id=record.get("class_id", -1),
            original_label=record.get("original_label", ""),
            category=record.get("category", ""),
            type=record.get("type", ""),
            label=record.get("label", ""),
            bbox=bbox_obj,
            center=center_coords,
            confidence=record.get("confidence", 0.95),
            model_source=record.get("model_source", ""),
            connections=[]
        )


class TagDetector(BaseDetector):
    """
    A placeholder detector that reads precomputed tag data
    from a JSON file and populates the context with Tag objects.
    """

    def __init__(self,
                 config: TagConfig,
                 debug_handler: Optional[DebugHandler] = None,
                 tag_json_path: str = "./tags.json"):
        super().__init__(config=config, debug_handler=debug_handler)
        # Guard against a None debug handler (``_load_json_data`` uses it),
        # consistent with JunctionDetector.
        self.debug_handler = debug_handler or DebugHandler()
        self.tag_json_path = tag_json_path

    def _load_model(self, model_path: str):
        """Not loading an actual model; tag data is read from JSON."""
        return None

    def detect(self, image: np.ndarray,
               context: DetectionContext,
               *args, **kwargs) -> None:
        """
        Reads from a JSON file containing tag info and updates context.
        (A roi_offset-based coordinate adjustment was previously applied
        here and is currently disabled.)
        """
        tag_data = self._load_json_data(self.tag_json_path)
        if not tag_data:
            return
        for record in tag_data.get("detections", []):  # Fix: Use "detections" key
            tag_obj = self._parse_tag_record(record)
            context.add_tag(tag_obj)

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        return image

    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        return image

    # --------------
    # HELPER METHODS
    # --------------
    def _load_json_data(self, json_path: str) -> dict:
        """Load the tag JSON; on a missing file, record an artifact and
        return an empty dict instead of raising."""
        if not os.path.exists(json_path):
            self.debug_handler.save_artifact(name="tag_error",
                                             data=b"Missing tag JSON file",
                                             extension="txt")
            return {}
        with open(json_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def _parse_tag_record(self, record: dict) -> Tag:
        """Builds a Tag object from a JSON record.

        (Cropping-offset adjustment of the bbox is currently disabled.)
        """
        bbox_list = record.get("bbox", [0, 0, 0, 0])
        bbox_obj = BBox(
            xmin=bbox_list[0],
            ymin=bbox_list[1],
            xmax=bbox_list[2],
            ymax=bbox_list[3]
        )
        return Tag(
            text=record.get("text", ""),
            bbox=bbox_obj,
            confidence=record.get("confidence", 1.0),
            source=record.get("source", ""),
            text_type=record.get("text_type", "Unknown"),
            id=record.get("id", str(uuid.uuid4())),
            font_size=record.get("font_size", 12),
            rotation=record.get("rotation", 0.0)
        )