# Intelligent_PID / detectors.py
# (header recovered from web-UI capture: author msIntui,
#  commit 910e0d4 "feat: initial clean deployment")
import os
import math
import torch
import cv2
import numpy as np
from typing import List, Optional, Tuple, Dict
from dataclasses import replace
from math import sqrt
import json
import uuid
from pathlib import Path
# Base classes and utilities
from base import BaseDetector
from detection_schema import DetectionContext
from utils import DebugHandler
from config import SymbolConfig, TagConfig, LineConfig, PointConfig, JunctionConfig
# DeepLSD model for line detection
from deeplsd.models.deeplsd_inference import DeepLSD
from ultralytics import YOLO
# Detection schema: dataclasses for different objects
from detection_schema import (
BBox,
Coordinates,
Point,
Line,
Symbol,
Tag,
SymbolType,
LineStyle,
ConnectionType,
JunctionType,
Junction
)
# Skeletonization and label processing for junction detection
from skimage.morphology import skeletonize
from skimage.measure import label
import os
import cv2
import torch
import numpy as np
from dataclasses import replace
from typing import List, Optional
from detection_utils import robust_merge_lines
class LineDetector(BaseDetector):
    """
    DeepLSD-based line detection with patch-based tiling and global merging.

    Pipeline: skeletonize the drawing, optionally mask/threshold it, run
    DeepLSD on overlapping square patches, remap each patch's segments to
    global image coordinates, merge near-collinear segments, and register
    the resulting `Line` objects on the DetectionContext.
    """
    def __init__(self,
                 config: LineConfig,
                 model_path: str,
                 model_config: dict,
                 device: torch.device,
                 debug_handler: DebugHandler = None):
        """
        Args:
            config: line-detection configuration.
            model_path: path to the DeepLSD checkpoint file.
            model_config: dict forwarded to the DeepLSD constructor.
            device: requested torch device. NOTE(review): this argument is
                currently ignored — the device is re-derived below
                (mps > cuda > cpu); confirm callers expect that.
            debug_handler: optional sink for debug artifacts.
        """
        super().__init__(config, debug_handler)
        # Fix device selection for Apple Silicon
        if torch.backends.mps.is_available():
            self.device = torch.device("mps")
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        self.model_path = model_path
        self.model_config = model_config
        self.model = self._load_model(model_path)
        # Patch parameters
        self.patch_size = 512  # square tile side, in pixels
        self.overlap = 10  # pixels shared between neighbouring tiles
        # Merging thresholds
        self.angle_thresh = 5.0  # degrees
        self.dist_thresh = 5.0  # pixels
    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """
        Thin the drawing to 1-px strokes before line detection.

        Dilate, invert, skeletonize, then dilate the skeleton again.
        Assumes `image` is an 8-bit image with values in {0, 255}
        (the `// 255` below relies on it) — TODO confirm with callers.
        """
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
        dilated = cv2.dilate(image, kernel, iterations=2)
        # Invert so strokes become foreground for skeletonize().
        skeleton = cv2.bitwise_not(dilated)
        skeleton = skeletonize(skeleton // 255)
        skeleton = (skeleton * 255).astype(np.uint8)
        # NOTE(review): a 1x1 rectangular kernel makes this dilation a
        # no-op regardless of iterations — was a larger kernel intended?
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 1))
        clean_image = cv2.dilate(skeleton, kernel, iterations=5)
        # Presumes BaseDetector substitutes a default handler when None is
        # passed — verify, otherwise this raises AttributeError.
        self.debug_handler.save_artifact(name="skeleton", data=clean_image, extension="png")
        return clean_image
    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        # NOTE(review): returns None despite the np.ndarray annotation;
        # callers must not rely on the return value.
        return None
    # -------------------------------------
    # 1) Load Model
    # -------------------------------------
    def _load_model(self, model_path: str) -> DeepLSD:
        """
        Load the DeepLSD checkpoint onto self.device and set eval mode.

        Raises:
            FileNotFoundError: if `model_path` does not exist.
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")
        ckpt = torch.load(model_path, map_location=self.device)
        model = DeepLSD(self.model_config)
        model.load_state_dict(ckpt["model"])
        return model.to(self.device).eval()
    # -------------------------------------
    # 2) Main Detection Pipeline
    # -------------------------------------
    def detect(self,
               image: np.ndarray,
               context: DetectionContext,
               mask_coords: Optional[List[BBox]] = None,
               *args,
               **kwargs) -> None:
        """
        Steps:
          - Optional mask + threshold
          - Tile into overlapping patches
          - For each patch => run DeepLSD => re-map lines to global coords
          - Merge lines robustly
          - Build final Line objects => add to context

        Args:
            image: source drawing (grayscale or BGR).
            context: detection context that receives the Line objects.
            mask_coords: rectangles to white out before detection.
        """
        mask_coords = mask_coords or []
        skeleton = self._preprocess(image)
        # (A) Optional mask + threshold if you want a binary
        # If your model expects grayscale or binary, do it here:
        processed_img = self._apply_mask_and_threshold(skeleton, mask_coords)
        # (B) Patch-based inference => collect raw lines in global coords
        all_lines = self._detect_in_patches(processed_img)
        # (C) Merge the lines in the global coordinate system
        merged_line_segments = robust_merge_lines(
            all_lines,
            angle_thresh=self.angle_thresh,
            dist_thresh=self.dist_thresh
        )
        # (D) Convert merged segments => final Line objects, add to context
        for (x1, y1, x2, y2) in merged_line_segments:
            line_obj = self._create_line_object(x1, y1, x2, y2)
            context.add_line(line_obj)
    # -------------------------------------
    # 3) Optional Mask + Threshold
    # -------------------------------------
    def _apply_mask_and_threshold(self, image: np.ndarray, mask_coords: List[BBox]) -> np.ndarray:
        """White out rectangular areas, then threshold to binary (if needed)."""
        masked = image.copy()
        for bbox in mask_coords:
            x1, y1 = int(bbox.xmin), int(bbox.ymin)
            x2, y2 = int(bbox.xmax), int(bbox.ymax)
            # Filled white rectangle removes the region from detection.
            cv2.rectangle(masked, (x1, y1), (x2, y2), (255, 255, 255), -1)
        # If image has 3 channels, convert to grayscale
        if len(masked.shape) == 3:
            masked_gray = cv2.cvtColor(masked, cv2.COLOR_BGR2GRAY)
        else:
            masked_gray = masked
        # Binary threshold (adjust threshold as needed)
        # If your model expects a plain grayscale, skip threshold
        binary_img = cv2.threshold(masked_gray, 127, 255, cv2.THRESH_BINARY)[1]
        return binary_img
    # -------------------------------------
    # 4) Patch-Based Inference
    # -------------------------------------
    def _detect_in_patches(self, processed_img: np.ndarray) -> List[tuple]:
        """
        Break the image into overlapping patches, run DeepLSD,
        map local lines => global coords, and return the global line list.

        Returns:
            List of (x1, y1, x2, y2) segments in global image coordinates.
        """
        patch_size = self.patch_size
        overlap = self.overlap
        height, width = processed_img.shape[:2]
        step = patch_size - overlap
        all_lines = []
        for y in range(0, height, step):
            patch_ymax = min(y + patch_size, height)
            # Anchor the last row of tiles to the image edge so each tile is
            # full-size whenever the image is at least patch_size tall.
            patch_ymin = patch_ymax - patch_size if (patch_ymax - y) < patch_size else y
            if patch_ymin < 0: patch_ymin = 0
            for x in range(0, width, step):
                patch_xmax = min(x + patch_size, width)
                # Same edge-anchoring trick along x.
                patch_xmin = patch_xmax - patch_size if (patch_xmax - x) < patch_size else x
                if patch_xmin < 0: patch_xmin = 0
                patch = processed_img[patch_ymin:patch_ymax, patch_xmin:patch_xmax]
                # Run model
                local_lines = self._run_model_inference(patch)
                # Convert local lines => global coords
                for ln in local_lines:
                    (x1_local, y1_local), (x2_local, y2_local) = ln
                    # offset by patch_xmin, patch_ymin
                    gx1 = x1_local + patch_xmin
                    gy1 = y1_local + patch_ymin
                    gx2 = x2_local + patch_xmin
                    gy2 = y2_local + patch_ymin
                    # Optional: clamp or filter lines partially out-of-bounds
                    if 0 <= gx1 < width and 0 <= gx2 < width and 0 <= gy1 < height and 0 <= gy2 < height:
                        all_lines.append((gx1, gy1, gx2, gy2))
        return all_lines
    # -------------------------------------
    # 5) Model Inference (Single Patch)
    # -------------------------------------
    def _run_model_inference(self, patch_img: np.ndarray) -> np.ndarray:
        """
        Run DeepLSD on a single patch (already masked/thresholded).
        patch_img shape: [patchH, patchW].
        Returns lines shape: [N, 2, 2].

        NOTE(review): `output["lines"][0]` may be a torch tensor rather than
        an np.ndarray depending on the DeepLSD build — the downstream
        arithmetic tolerates either, but confirm before relying on the
        annotation.
        """
        # Convert patch to float32 and scale
        inp = torch.tensor(patch_img, dtype=torch.float32, device=self.device)[None, None] / 255.0
        with torch.no_grad():
            output = self.model({"image": inp})
        lines = output["lines"][0]  # shape (N, 2, 2)
        return lines
    # -------------------------------------
    # 6) Convert Merged Segments => Line Objects
    # -------------------------------------
    def _create_line_object(self, x1: float, y1: float, x2: float, y2: float) -> Line:
        """
        Create a minimal `Line` object from the final merged coordinates.

        Each endpoint gets a small (2 px margin) bbox around it; the line's
        own bbox is the axis-aligned box spanned by the two endpoints.
        Style and confidence are fixed defaults.
        """
        margin = 2
        # Start point
        start_pt = Point(
            coords=Coordinates(int(x1), int(y1)),
            bbox=BBox(
                xmin=int(x1 - margin),
                ymin=int(y1 - margin),
                xmax=int(x1 + margin),
                ymax=int(y1 + margin)
            ),
            type=JunctionType.END,
            confidence=1.0
        )
        # End point
        end_pt = Point(
            coords=Coordinates(int(x2), int(y2)),
            bbox=BBox(
                xmin=int(x2 - margin),
                ymin=int(y2 - margin),
                xmax=int(x2 + margin),
                ymax=int(y2 + margin)
            ),
            type=JunctionType.END,
            confidence=1.0
        )
        # Overall bounding box
        x_min = int(min(x1, x2))
        x_max = int(max(x1, x2))
        y_min = int(min(y1, y2))
        y_max = int(max(y1, y2))
        line_obj = Line(
            start=start_pt,
            end=end_pt,
            bbox=BBox(xmin=x_min, ymin=y_min, xmax=x_max, ymax=y_max),
            style=LineStyle(
                connection_type=ConnectionType.SOLID,
                stroke_width=2,
                color="#000000"
            ),
            confidence=0.9,
            topological_links=[]
        )
        return line_obj
class PointDetector(BaseDetector):
    """
    Unifies nearby line endpoints into shared Point objects.

    1) Reads lines from the context
    2) Clusters endpoints within 'threshold_distance'
    3) Updates lines so that shared endpoints reference the same Point object
    """
    def __init__(self,
                 config: PointConfig,
                 debug_handler: DebugHandler = None):
        # Pure-geometry detector: there is no model to load.
        super().__init__(config, debug_handler)
        self.threshold_distance = config.threshold_distance
    def _load_model(self, model_path: str):
        """No model needed for simple point unification."""
        return None
    def detect(self, image: np.ndarray, context: DetectionContext, *args, **kwargs) -> None:
        """
        Main method called by the pipeline.
        1) Gather all line endpoints from context
        2) Cluster them within 'threshold_distance'
        3) Update the line endpoints so they reference the unified cluster point
        """
        # 1) Collect every endpoint (start and end) of every line.
        endpoints = [ep
                     for ln in context.lines.values()
                     for ep in (ln.start, ln.end)]
        # 2) Group endpoints that lie within the threshold of each other.
        clusters = self._cluster_points(endpoints, self.threshold_distance)
        # 3) Each cluster is represented by its first member; map every
        #    other member's id onto that canonical point.
        canonical_by_id = {}
        for group in clusters:
            representative = group[0]
            for member in group[1:]:
                canonical_by_id[member.id] = representative
        # 4) Rewire lines so clustered endpoints share one Point object.
        for ln in context.lines.values():
            ln.start = canonical_by_id.get(ln.start.id, ln.start)
            ln.end = canonical_by_id.get(ln.end.id, ln.end)
        # Finally, rebuild context.points so it holds exactly the unique
        # (deduplicated) endpoints now referenced by the lines.
        deduped = {}
        for ln in context.lines.values():
            deduped[ln.start.id] = ln.start
            deduped[ln.end.id] = ln.end
        context.points = deduped
    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """No specific image preprocessing needed."""
        return image
    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        """No specific image postprocessing needed."""
        return image
    # ----------------------
    # HELPER: clustering
    # ----------------------
    def _cluster_points(self, points: List[Point], threshold: float) -> List[List[Point]]:
        """
        Naive single-pass clustering.

        Each point joins the first existing cluster whose representative
        (its first member) lies within `threshold`; otherwise it seeds a
        new cluster. Returns a list of clusters (lists of Points).
        """
        groups: List[List[Point]] = []
        for candidate in points:
            home = next(
                (g for g in groups if self._distance(candidate, g[0]) < threshold),
                None,
            )
            if home is None:
                groups.append([candidate])
            else:
                home.append(candidate)
        return groups
    def _distance(self, p1: Point, p2: Point) -> float:
        """Euclidean distance between two points' coordinates."""
        return sqrt((p1.coords.x - p2.coords.x) ** 2
                    + (p1.coords.y - p2.coords.y) ** 2)
class JunctionDetector(BaseDetector):
    """
    Classifies points as 'END', 'L', or 'T' by skeletonizing the binarized image
    and analyzing local connectivity. Also creates Junction objects in the context.
    """
    def __init__(self, config: JunctionConfig, debug_handler: DebugHandler = None):
        """
        Args:
            config: supplies window_size (local patch side), radius (of the
                circular probe), and the [lb, ub] angle band counted as 'L'.
            debug_handler: optional; a default DebugHandler is created when None.
        """
        super().__init__(config, debug_handler)  # no real model path
        self.window_size = config.window_size
        self.radius = config.radius
        self.angle_threshold_lb = config.angle_threshold_lb
        self.angle_threshold_ub = config.angle_threshold_ub
        self.debug_handler = debug_handler or DebugHandler()
    def _load_model(self, model_path: str):
        """Not loading any actual model, just skeleton logic."""
        return None
    def detect(self,
               image: np.ndarray,
               context: DetectionContext,
               *args,
               **kwargs) -> None:
        """
        1) Convert to binary & skeletonize
        2) Classify each point in the context
        3) Create a Junction for each point and store it in context.junctions
           (with 'connected_lines' referencing lines that share this point).
        """
        # 1) Preprocess -> skeleton
        skeleton = self._create_skeleton(image)
        # 2) Classify each point
        for pt in context.points.values():
            pt.type = self._classify_point(skeleton, pt)
        # 3) Create a Junction object for each point
        # If you prefer only T or L, you can filter out END points.
        self._record_junctions_in_context(context)
    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """We might do thresholding; let's do a simple binary threshold."""
        if image.ndim == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image
        _, bin_image = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
        return bin_image
    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        # No postprocessing; image is returned unchanged.
        return image
    def _create_skeleton(self, raw_image: np.ndarray) -> np.ndarray:
        """Skeletonize the binarized image.

        Inverts first so dark strokes become foreground, then returns a
        uint8 skeleton with values in {0, 255}.
        """
        bin_img = self._preprocess(raw_image)
        # For skeletonize, we need a boolean array
        inv = cv2.bitwise_not(bin_img)
        inv_bool = (inv > 127).astype(np.uint8)
        skel = skeletonize(inv_bool).astype(np.uint8) * 255
        return skel
    def _classify_point(self, skeleton: np.ndarray, pt: Point) -> JunctionType:
        """
        Given a skeleton image, look around 'pt' in a local window
        to determine if it's an END, L, or T.

        Counts the connected skeleton regions that intersect a circular
        probe around the point: 1 region => END, 2 => candidate 'L'
        (angle-checked), 3 => T. Zero or more than 3 regions fall through
        to the END default.
        """
        classification = JunctionType.END  # default
        half_w = self.window_size // 2
        # Assumes pt.coords.x / .y are ints (cv2.circle needs an int
        # center) — TODO confirm upstream point construction.
        x, y = pt.coords.x, pt.coords.y
        # Clamp the window so it stays inside the image.
        top = max(0, y - half_w)
        bottom = min(skeleton.shape[0], y + half_w + 1)
        left = max(0, x - half_w)
        right = min(skeleton.shape[1], x + half_w + 1)
        patch = (skeleton[top:bottom, left:right] > 127).astype(np.uint8)
        # create circular mask
        circle_mask = np.zeros_like(patch, dtype=np.uint8)
        local_cx = x - left
        local_cy = y - top
        cv2.circle(circle_mask, (local_cx, local_cy), self.radius, 1, -1)
        circle_skel = patch & circle_mask
        # label connected regions
        labeled = label(circle_skel, connectivity=2)
        num_exits = labeled.max()
        if num_exits == 1:
            classification = JunctionType.END
        elif num_exits == 2:
            # check angle for L
            classification = self._check_angle_for_L(labeled)
        elif num_exits == 3:
            classification = JunctionType.T
        return classification
    def _check_angle_for_L(self, labeled_region: np.ndarray) -> JunctionType:
        """
        If the angle between two branches is within
        [angle_threshold_lb, angle_threshold_ub], it's 'L'.
        Otherwise default to END.

        NOTE(review): only pixels of the FIRST labeled region (label == 1)
        are examined, and the angle is taken between that region's first
        two pixels — not between the two branches. Verify this matches the
        intended geometry.
        """
        coords = np.argwhere(labeled_region == 1)
        if len(coords) < 2:
            return JunctionType.END
        (y1, x1), (y2, x2) = coords[:2]
        dx = x2 - x1
        dy = y2 - y1
        angle = math.degrees(math.atan2(dy, dx))
        # Fold into [0, 90] so orientation/direction don't matter.
        acute_angle = min(abs(angle), 180 - abs(angle))
        if self.angle_threshold_lb <= acute_angle <= self.angle_threshold_ub:
            return JunctionType.L
        return JunctionType.END
    # -----------------------------------------
    # EXTRA STEP: Create Junction objects
    # -----------------------------------------
    def _record_junctions_in_context(self, context: DetectionContext):
        """
        Create a Junction object for each point in context.points.
        If you only want T/L points as junctions, filter them out.
        Also track any lines that connect to this point.
        """
        for pt in context.points.values():
            # If you prefer to store all points as junction, do it:
            # or if you want only T or L, do:
            #   if pt.type in {JunctionType.T, JunctionType.L}: ...
            jn = Junction(
                center=pt.coords,
                junction_type=pt.type,
                # add more properties if needed
            )
            # find lines that connect to this point (by endpoint identity)
            connected_lines = []
            for ln in context.lines.values():
                if ln.start.id == pt.id or ln.end.id == pt.id:
                    connected_lines.append(ln.id)
            jn.connected_lines = connected_lines
            # add to context
            context.add_junction(jn)
import json
import uuid
class SymbolDetector(BaseDetector):
    """
    A placeholder detector that reads precomputed symbol data
    from a JSON file and populates the context with Symbol objects.

    No image processing is performed, and coordinates are taken from the
    JSON verbatim — no ROI/crop offset is applied.
    """
    def __init__(self,
                 config: SymbolConfig,
                 debug_handler: Optional[DebugHandler] = None,
                 symbol_json_path: str = "./symbols.json"):
        """
        Args:
            config: symbol-detector configuration.
            debug_handler: optional sink for debug artifacts.
            symbol_json_path: JSON file holding precomputed detections
                under the top-level "detections" key.
        """
        super().__init__(config=config, debug_handler=debug_handler)
        self.symbol_json_path = symbol_json_path
    def _load_model(self, model_path: str):
        """Not loading an actual model; symbol data is read from JSON."""
        return None
    def detect(self,
               image: np.ndarray,
               context: DetectionContext,
               *args,
               **kwargs) -> None:
        """
        Read symbol records from the JSON file and add one Symbol per
        record to the context. A missing or empty file is a silent no-op.
        """
        symbol_data = self._load_json_data(self.symbol_json_path)
        if not symbol_data:
            return
        # Records live under the "detections" key of the JSON document.
        for record in symbol_data.get("detections", []):
            context.add_symbol(self._parse_symbol_record(record))
    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """No image preprocessing needed."""
        return image
    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        """No image postprocessing needed."""
        return image
    # --------------
    # HELPER METHODS
    # --------------
    def _load_json_data(self, json_path: str) -> dict:
        """Return the parsed JSON document, or {} when the file is missing."""
        if not os.path.exists(json_path):
            # Degrade to "no symbols" on a missing file, recording the
            # problem as a debug artifact when a handler is available.
            if self.debug_handler is not None:
                self.debug_handler.save_artifact(name="symbol_error",
                                                 data=b"Missing symbol JSON file",
                                                 extension="txt")
            return {}
        with open(json_path, "r", encoding="utf-8") as f:
            return json.load(f)
    def _parse_symbol_record(self, record: dict) -> Symbol:
        """
        Build a Symbol from one JSON record.

        The bbox list ([xmin, ymin, xmax, ymax]) is used as-is; the center
        is the integer midpoint of the bbox. Missing fields fall back to
        neutral defaults.
        """
        bbox_list = record.get("bbox", [0, 0, 0, 0])
        bbox_obj = BBox(
            xmin=bbox_list[0],
            ymin=bbox_list[1],
            xmax=bbox_list[2],
            ymax=bbox_list[3]
        )
        # Integer midpoint of the bounding box.
        center_coords = Coordinates(
            x=(bbox_obj.xmin + bbox_obj.xmax) // 2,
            y=(bbox_obj.ymin + bbox_obj.ymax) // 2
        )
        return Symbol(
            id=record.get("symbol_id", ""),
            class_id=record.get("class_id", -1),
            original_label=record.get("original_label", ""),
            category=record.get("category", ""),
            type=record.get("type", ""),
            label=record.get("label", ""),
            bbox=bbox_obj,
            center=center_coords,
            confidence=record.get("confidence", 0.95),
            model_source=record.get("model_source", ""),
            connections=[]
        )
class TagDetector(BaseDetector):
    """
    A placeholder detector that reads precomputed tag data
    from a JSON file and populates the context with Tag objects.

    No image processing is performed, and coordinates are taken from the
    JSON verbatim — no ROI/crop offset is applied.
    """
    def __init__(self,
                 config: TagConfig,
                 debug_handler: Optional[DebugHandler] = None,
                 tag_json_path: str = "./tags.json"):
        """
        Args:
            config: tag-detector configuration.
            debug_handler: optional sink for debug artifacts.
            tag_json_path: JSON file holding precomputed detections
                under the top-level "detections" key.
        """
        super().__init__(config=config, debug_handler=debug_handler)
        self.tag_json_path = tag_json_path
    def _load_model(self, model_path: str):
        """Not loading an actual model; tag data is read from JSON."""
        return None
    def detect(self,
               image: np.ndarray,
               context: DetectionContext,
               *args,
               **kwargs) -> None:
        """
        Read tag records from the JSON file and add one Tag per record to
        the context. A missing or empty file is a silent no-op.
        """
        tag_data = self._load_json_data(self.tag_json_path)
        if not tag_data:
            return
        # Records live under the "detections" key of the JSON document.
        for record in tag_data.get("detections", []):
            context.add_tag(self._parse_tag_record(record))
    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """No image preprocessing needed."""
        return image
    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        """No image postprocessing needed."""
        return image
    # --------------
    # HELPER METHODS
    # --------------
    def _load_json_data(self, json_path: str) -> dict:
        """Return the parsed JSON document, or {} when the file is missing."""
        if not os.path.exists(json_path):
            # Degrade to "no tags" on a missing file, recording the
            # problem as a debug artifact when a handler is available.
            if self.debug_handler is not None:
                self.debug_handler.save_artifact(name="tag_error",
                                                 data=b"Missing tag JSON file",
                                                 extension="txt")
            return {}
        with open(json_path, "r", encoding="utf-8") as f:
            return json.load(f)
    def _parse_tag_record(self, record: dict) -> Tag:
        """
        Build a Tag from one JSON record.

        The bbox list ([xmin, ymin, xmax, ymax]) is used as-is. Missing
        fields fall back to neutral defaults; a fresh UUID is generated
        when the record has no "id".
        """
        bbox_list = record.get("bbox", [0, 0, 0, 0])
        bbox_obj = BBox(
            xmin=bbox_list[0],
            ymin=bbox_list[1],
            xmax=bbox_list[2],
            ymax=bbox_list[3]
        )
        return Tag(
            text=record.get("text", ""),
            bbox=bbox_obj,
            confidence=record.get("confidence", 1.0),
            source=record.get("source", ""),
            text_type=record.get("text_type", "Unknown"),
            id=record.get("id", str(uuid.uuid4())),
            font_size=record.get("font_size", 12),
            rotation=record.get("rotation", 0.0)
        )