msIntui commited on
Commit
7f22c74
·
1 Parent(s): 690b5e4

Fix merge conflict in detectors.py

Browse files
Files changed (1) hide show
  1. detectors.py +103 -1052
detectors.py CHANGED
@@ -3,12 +3,17 @@ import math
3
  import torch
4
  import cv2
5
  import numpy as np
6
- from typing import List, Optional, Tuple, Dict
7
  from dataclasses import replace
8
  from math import sqrt
9
  import json
10
  import uuid
11
  from pathlib import Path
 
 
 
 
 
12
 
13
  # Base classes and utilities
14
  from base import BaseDetector
@@ -18,7 +23,6 @@ from config import SymbolConfig, TagConfig, LineConfig, PointConfig, JunctionCon
18
 
19
  # DeepLSD model for line detection
20
  from deeplsd.models.deeplsd_inference import DeepLSD
21
- from ultralytics import YOLO
22
 
23
  # Detection schema: dataclasses for different objects
24
  from detection_schema import (
@@ -39,1058 +43,105 @@ from detection_schema import (
39
  from skimage.morphology import skeletonize
40
  from skimage.measure import label
41
 
 
 
42
 
43
- class LineDetector(BaseDetector):
44
- """
45
- DeepLSD-based line detection that populates newly detected lines (and naive endpoints)
46
- directly into a DetectionContext.
47
- """
48
-
49
- <<<<<<< HEAD
50
- def __init__(self, model_path=None, model=None, model_config=None, device=None, debug_handler=None):
51
- self.device = device or torch.device('cpu')
52
  self.debug_handler = debug_handler
53
- if model is not None:
54
- self.model = model
55
- else:
56
- super().__init__(model_path)
57
- self.config = model_config or {}
58
- self.scale_factor = 8.0 # Inverse of 0.5 scaling
59
- self.margin = 10 # BBox expansion margin
60
- =======
61
- def __init__(self,
62
- config: LineConfig,
63
- model_path: str,
64
- model_config: dict,
65
- device: torch.device,
66
- debug_handler: DebugHandler = None):
67
- self.device = device
68
- self.model_path = model_path
69
- self.model_config = model_config
70
- super().__init__(config, debug_handler)
71
- self._load_params()
72
- self.model = self._load_model(model_path)
73
- self.scale_factor = 0.75 # For downscaling input to model
74
- self.margin = 10
75
- >>>>>>> temp/test-integration
76
-
77
- # -------------------------------------
78
- # BaseDetector requirements
79
- # -------------------------------------
80
- def _load_model(self, model_path: str) -> DeepLSD:
81
- """Load and configure the DeepLSD model."""
82
- if not os.path.exists(model_path):
83
- raise FileNotFoundError(f"Model file not found: {model_path}")
84
- ckpt = torch.load(model_path, map_location=self.device)
85
- <<<<<<< HEAD
86
- model = DeepLSD(self.config)
87
- model.load_state_dict(ckpt['model'])
88
- =======
89
- model = DeepLSD(self.model_config)
90
- model.load_state_dict(ckpt["model"])
91
- >>>>>>> temp/test-integration
92
- return model.to(self.device).eval()
93
-
94
- def _preprocess(self, image: np.ndarray) -> np.ndarray:
95
- """
96
- Not used directly here. We'll handle our own
97
- masking + threshold steps in the detect() method.
98
- """
99
- return image
100
-
101
- def _postprocess(self, image: np.ndarray) -> np.ndarray:
102
- """
103
- Not used directly. Postprocessing is integrated
104
- into detect() after we create lines.
105
- """
106
- return image
107
-
108
- # -------------------------------------
109
- # Our main detection method
110
- # -------------------------------------
111
- def detect(self,
112
- image: np.ndarray,
113
- context: DetectionContext,
114
- mask_coords: Optional[List[BBox]] = None,
115
- *args,
116
- **kwargs) -> None:
117
- """
118
- Main detection pipeline:
119
- 1) Apply mask
120
- 2) Convert to binary & downscale
121
- 3) Run DeepLSD
122
- 4) Build minimal Line objects (with naive endpoints)
123
- 5) Scale lines to original resolution
124
- 6) Store the lines into the context
125
-
126
- We do NOT unify endpoints here or classify them as T/L/etc.
127
- """
128
- mask_coords = mask_coords or []
129
-
130
- # (A) Preprocess
131
- processed_img = self._apply_mask_and_downscale(image, mask_coords)
132
-
133
- # (B) Inference
134
- raw_output = self._run_model_inference(processed_img)
135
-
136
- # (C) Create lines in downscaled space
137
- downscaled_lines = self._create_lines_from_output(raw_output)
138
-
139
- # (D) Scale them to original resolution
140
- lines_scaled = [self._scale_line(ln) for ln in downscaled_lines]
141
-
142
- # (E) Add them to context
143
- for line in lines_scaled:
144
- context.add_line(line)
145
-
146
- # -------------------------------------
147
- # Internal helpers
148
- # -------------------------------------
149
- def _load_params(self):
150
- """Load any model_config parameters if needed."""
151
  pass
 
 
 
 
 
152
 
153
- def _apply_mask_and_downscale(self, image: np.ndarray, mask_coords: List[BBox]) -> np.ndarray:
154
- """Apply rectangular mask, then threshold, then downscale."""
155
- masked = self._apply_masking(image, mask_coords)
156
- gray = cv2.cvtColor(masked, cv2.COLOR_RGB2GRAY)
157
- <<<<<<< HEAD
158
- binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)[1]
159
- return cv2.resize(binary, None, fx=1/self.scale_factor, fy=1/self.scale_factor)
160
- =======
161
- binary_full = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)[1]
162
- >>>>>>> temp/test-integration
163
-
164
- kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 1))
165
- dilated = cv2.dilate(binary_full, kernel, iterations=2)
166
-
167
- # Downscale
168
- binary_downscaled = cv2.resize(
169
- dilated,
170
- None,
171
- fx=self.scale_factor,
172
- fy=self.scale_factor
173
- )
174
- return binary_downscaled
175
-
176
- def _apply_masking(self, image: np.ndarray, mask_coords: List[BBox]) -> np.ndarray:
177
- """White-out rectangular areas to ignore them."""
178
- masked = image.copy()
179
- for bbox in mask_coords:
180
- x1, y1 = int(bbox.xmin), int(bbox.ymin)
181
- x2, y2 = int(bbox.xmax), int(bbox.ymax)
182
- cv2.rectangle(masked, (x1, y1), (x2, y2), (255, 255, 255), -1)
183
- return masked
184
-
185
- def _run_model_inference(self, downscaled_binary: np.ndarray) -> np.ndarray:
186
- """Run DeepLSD on the downscaled binary image, returning raw lines [N, 2, 2]."""
187
- tensor = torch.tensor(downscaled_binary, dtype=torch.float32, device=self.device)[None, None] / 255.0
188
- # tensor = torch.tensor(downscaled_binary, dtype=torch.float32, device=self.device)[None, None] / 255.0
189
- with torch.no_grad():
190
- output = self.model({"image": tensor})
191
- # shape: [batch, num_lines, 2, 2]
192
- return output["lines"][0]
193
-
194
- def _create_lines_from_output(self, model_output: np.ndarray) -> List[Line]:
195
- """
196
- Convert each [2,2] line segment into a minimal Line with naive endpoints (type=END).
197
- Coordinates are in downscaled space.
198
- """
199
- lines = []
200
- for endpoints in model_output:
201
- (x1, y1), (x2, y2) = endpoints # shape (2,) each
202
-
203
- p_start = self._create_point(x1, y1)
204
- p_end = self._create_point(x2, y2)
205
-
206
- # minimal bounding box in downscaled coords
207
- x_min = min(x1, x2)
208
- x_max = max(x1, x2)
209
- y_min = min(y1, y2)
210
- y_max = max(y1, y2)
211
-
212
- line_obj = Line(
213
- start=p_start,
214
- end=p_end,
215
- bbox=BBox(
216
- xmin=int(x_min),
217
- ymin=int(y_min),
218
- xmax=int(x_max),
219
- ymax=int(y_max)
220
- ),
221
- # style / confidence / ID assigned by default
222
- style=LineStyle(
223
- connection_type=ConnectionType.SOLID,
224
- stroke_width=2,
225
- color="#000000"
226
- ),
227
- confidence=0.9,
228
- topological_links=[]
229
- )
230
- lines.append(line_obj)
231
-
232
- return lines
233
-
234
- def _create_point(self, x: float, y: float) -> Point:
235
- """
236
- Creates a naive 'END'-type Point at downscaled coords.
237
- We'll scale it later.
238
- """
239
- margin = 2
240
- return Point(
241
- coords=Coordinates(x=int(x), y=int(y)),
242
- bbox=BBox(
243
- xmin=int(x - margin),
244
- ymin=int(y - margin),
245
- xmax=int(x + margin),
246
- ymax=int(y + margin)
247
- ),
248
- type=JunctionType.END, # no classification here
249
- confidence=1.0
250
- )
251
-
252
- def _scale_line(self, line: Line) -> Line:
253
- """
254
- Scale line's start/end points + bounding box to original resolution.
255
- """
256
- scaled_start = self._scale_point(line.start)
257
- scaled_end = self._scale_point(line.end)
258
-
259
- # recalc bounding box in original scale
260
- new_bbox = BBox(
261
- xmin=min(scaled_start.bbox.xmin, scaled_end.bbox.xmin),
262
- ymin=min(scaled_start.bbox.ymin, scaled_end.bbox.ymin),
263
- xmax=max(scaled_start.bbox.xmax, scaled_end.bbox.xmax),
264
- ymax=max(scaled_start.bbox.ymax, scaled_end.bbox.ymax)
265
- )
266
-
267
- return replace(line, start=scaled_start, end=scaled_end, bbox=new_bbox)
268
-
269
- def _scale_point(self, point: Point) -> Point:
270
- sx = int(point.coords.x * 1/self.scale_factor)
271
- sy = int(point.coords.y * 1/self.scale_factor)
272
-
273
- bb = point.bbox
274
- scaled_bbox = BBox(
275
- xmin=int(bb.xmin * 1/self.scale_factor),
276
- ymin=int(bb.ymin * 1/self.scale_factor),
277
- xmax=int(bb.xmax * 1/self.scale_factor),
278
- ymax=int(bb.ymax * 1/self.scale_factor)
279
- )
280
- return replace(point, coords=Coordinates(sx, sy), bbox=scaled_bbox)
281
-
282
-
283
- class PointDetector(BaseDetector):
284
- """
285
- A detector that:
286
- 1) Reads lines from the context
287
- 2) Clusters endpoints within 'threshold_distance'
288
- 3) Updates lines so that shared endpoints reference the same Point object
289
- """
290
-
291
- def __init__(self,
292
- config:PointConfig,
293
- debug_handler: DebugHandler = None):
294
- super().__init__(config, debug_handler) # No real model to load
295
- self.threshold_distance = config.threshold_distance
296
-
297
- def _load_model(self, model_path: str):
298
- """No model needed for simple point unification."""
299
- return None
300
-
301
- def detect(self, image: np.ndarray, context: DetectionContext, *args, **kwargs) -> None:
302
- """
303
- Main method called by the pipeline.
304
- 1) Gather all line endpoints from context
305
- 2) Cluster them within 'threshold_distance'
306
- 3) Update the line endpoints so they reference the unified cluster point
307
- """
308
- # 1) Collect all endpoints
309
- endpoints = []
310
- for line in context.lines.values():
311
- endpoints.append(line.start)
312
- endpoints.append(line.end)
313
-
314
- # 2) Cluster endpoints
315
- clusters = self._cluster_points(endpoints, self.threshold_distance)
316
-
317
- # 3) Build a dictionary of "representative" points
318
- # So that each cluster has one "canonical" point
319
- # Then we link all the points in that cluster to the canonical reference
320
- unified_point_map = {}
321
- for cluster in clusters:
322
- # let's pick the first point in the cluster as the "representative"
323
- rep_point = cluster[0]
324
- for p in cluster[1:]:
325
- unified_point_map[p.id] = rep_point
326
-
327
- # 4) Update all lines to reference the canonical point
328
- for line in context.lines.values():
329
- # unify start
330
- if line.start.id in unified_point_map:
331
- line.start = unified_point_map[line.start.id]
332
- # unify end
333
- if line.end.id in unified_point_map:
334
- line.end = unified_point_map[line.end.id]
335
-
336
- # We could also store the final set of unique points back in context.points
337
- # (e.g. clearing old duplicates).
338
- # That step is optional: you might prefer to keep everything in lines only,
339
- # or you might want context.points as a separate reference.
340
-
341
- # If you want to keep unique points in context.points:
342
- new_points = {}
343
- for line in context.lines.values():
344
- new_points[line.start.id] = line.start
345
- new_points[line.end.id] = line.end
346
- context.points = new_points # replace the dictionary of points
347
-
348
- def _preprocess(self, image: np.ndarray) -> np.ndarray:
349
- """No specific image preprocessing needed."""
350
- return image
351
-
352
- def _postprocess(self, image: np.ndarray) -> np.ndarray:
353
- """No specific image postprocessing needed."""
354
- return image
355
-
356
- # ----------------------
357
- # HELPER: clustering
358
- # ----------------------
359
- def _cluster_points(self, points: List[Point], threshold: float) -> List[List[Point]]:
360
- """
361
- Very naive clustering:
362
- 1) Start from the first point
363
- 2) If it's within threshold of an existing cluster's representative,
364
- put it in that cluster
365
- 3) Otherwise start a new cluster
366
- Return: list of clusters, each is a list of Points
367
- """
368
- clusters = []
369
-
370
- for pt in points:
371
- placed = False
372
- for cluster in clusters:
373
- # pick the first point in the cluster as reference
374
- ref_pt = cluster[0]
375
- if self._distance(pt, ref_pt) < threshold:
376
- cluster.append(pt)
377
- placed = True
378
- break
379
-
380
- if not placed:
381
- clusters.append([pt])
382
-
383
- return clusters
384
-
385
- def _distance(self, p1: Point, p2: Point) -> float:
386
- dx = p1.coords.x - p2.coords.x
387
- dy = p1.coords.y - p2.coords.y
388
- return sqrt(dx*dx + dy*dy)
389
-
390
-
391
- class JunctionDetector(BaseDetector):
392
- """
393
- Classifies points as 'END', 'L', or 'T' by skeletonizing the binarized image
394
- and analyzing local connectivity. Also creates Junction objects in the context.
395
- """
396
-
397
- def __init__(self, config: JunctionConfig, debug_handler: DebugHandler = None):
398
- super().__init__(config, debug_handler) # no real model path
399
- self.window_size = config.window_size
400
- self.radius = config.radius
401
- self.angle_threshold_lb = config.angle_threshold_lb
402
- self.angle_threshold_ub = config.angle_threshold_ub
403
- self.debug_handler = debug_handler or DebugHandler()
404
-
405
- def _load_model(self, model_path: str):
406
- """Not loading any actual model, just skeleton logic."""
407
- return None
408
-
409
- def detect(self,
410
- image: np.ndarray,
411
- context: DetectionContext,
412
- *args,
413
- **kwargs) -> None:
414
- """
415
- 1) Convert to binary & skeletonize
416
- 2) Classify each point in the context
417
- 3) Create a Junction for each point and store it in context.junctions
418
- (with 'connected_lines' referencing lines that share this point).
419
- """
420
- # 1) Preprocess -> skeleton
421
- skeleton = self._create_skeleton(image)
422
-
423
- # 2) Classify each point
424
- for pt in context.points.values():
425
- pt.type = self._classify_point(skeleton, pt)
426
-
427
- # 3) Create a Junction object for each point
428
- # If you prefer only T or L, you can filter out END points.
429
- self._record_junctions_in_context(context)
430
 
431
- def _preprocess(self, image: np.ndarray) -> np.ndarray:
432
- """We might do thresholding; let's do a simple binary threshold."""
433
- if image.ndim == 3:
434
- gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
435
- else:
436
- gray = image
437
- _, bin_image = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
438
- return bin_image
439
-
440
- def _postprocess(self, image: np.ndarray) -> np.ndarray:
441
- return image
442
-
443
- def _create_skeleton(self, raw_image: np.ndarray) -> np.ndarray:
444
- """Skeletonize the binarized image."""
445
- bin_img = self._preprocess(raw_image)
446
- # For skeletonize, we need a boolean array
447
- inv = cv2.bitwise_not(bin_img)
448
- inv_bool = (inv > 127).astype(np.uint8)
449
- skel = skeletonize(inv_bool).astype(np.uint8) * 255
450
- return skel
451
-
452
- def _classify_point(self, skeleton: np.ndarray, pt: Point) -> JunctionType:
453
- """
454
- Given a skeleton image, look around 'pt' in a local window
455
- to determine if it's an END, L, or T.
456
- """
457
- classification = JunctionType.END # default
458
-
459
- half_w = self.window_size // 2
460
- x, y = pt.coords.x, pt.coords.y
461
-
462
- top = max(0, y - half_w)
463
- bottom = min(skeleton.shape[0], y + half_w + 1)
464
- left = max(0, x - half_w)
465
- right = min(skeleton.shape[1], x + half_w + 1)
466
-
467
- patch = (skeleton[top:bottom, left:right] > 127).astype(np.uint8)
468
-
469
- # create circular mask
470
- circle_mask = np.zeros_like(patch, dtype=np.uint8)
471
- local_cx = x - left
472
- local_cy = y - top
473
- cv2.circle(circle_mask, (local_cx, local_cy), self.radius, 1, -1)
474
- circle_skel = patch & circle_mask
475
-
476
- # label connected regions
477
- labeled = label(circle_skel, connectivity=2)
478
- num_exits = labeled.max()
479
-
480
- if num_exits == 1:
481
- classification = JunctionType.END
482
- elif num_exits == 2:
483
- # check angle for L
484
- classification = self._check_angle_for_L(labeled)
485
- elif num_exits == 3:
486
- classification = JunctionType.T
487
-
488
- return classification
489
-
490
- def _check_angle_for_L(self, labeled_region: np.ndarray) -> JunctionType:
491
- """
492
- If the angle between two branches is within
493
- [angle_threshold_lb, angle_threshold_ub], it's 'L'.
494
- Otherwise default to END.
495
- """
496
- coords = np.argwhere(labeled_region == 1)
497
- if len(coords) < 2:
498
- return JunctionType.END
499
-
500
- (y1, x1), (y2, x2) = coords[:2]
501
- dx = x2 - x1
502
- dy = y2 - y1
503
- angle = math.degrees(math.atan2(dy, dx))
504
- acute_angle = min(abs(angle), 180 - abs(angle))
505
-
506
- if self.angle_threshold_lb <= acute_angle <= self.angle_threshold_ub:
507
- return JunctionType.L
508
- return JunctionType.END
509
-
510
- # -----------------------------------------
511
- # EXTRA STEP: Create Junction objects
512
- # -----------------------------------------
513
- def _record_junctions_in_context(self, context: DetectionContext):
514
- """
515
- Create a Junction object for each point in context.points.
516
- If you only want T/L points as junctions, filter them out.
517
- Also track any lines that connect to this point.
518
- """
519
-
520
- for pt in context.points.values():
521
- # If you prefer to store all points as junction, do it:
522
- # or if you want only T or L, do:
523
- # if pt.type in {JunctionType.T, JunctionType.L}: ...
524
-
525
- jn = Junction(
526
- center=pt.coords,
527
- junction_type=pt.type,
528
- # add more properties if needed
529
- )
530
-
531
- # find lines that connect to this point
532
- connected_lines = []
533
- for ln in context.lines.values():
534
- if ln.start.id == pt.id or ln.end.id == pt.id:
535
- connected_lines.append(ln.id)
536
-
537
- jn.connected_lines = connected_lines
538
-
539
- # add to context
540
- context.add_junction(jn)
541
-
542
- # from loguru import logger
543
- #
544
- #
545
- # class SymbolDetector(BaseDetector):
546
- # """
547
- # YOLO-based symbol detector using multiple confidence thresholds,
548
- # merges final detections, and stores them in the context.
549
- # """
550
- #
551
- # def __init__(self, config: SymbolConfig, debug_handler: Optional[DebugHandler] = None):
552
- # super().__init__(config, debug_handler)
553
- # self.config = config
554
- # self.debug_handler = debug_handler or DebugHandler()
555
- # self.models = self._load_models()
556
- # self.class_map = self._build_class_map()
557
- #
558
- # logger.info("Symbol detector initialized with config: %s", self.config)
559
- #
560
- # # -----------------------------
561
- # # BaseDetector Implementation
562
- # # -----------------------------
563
- # def _load_model(self, model_path: str):
564
- # """We won't use this single-model loader; see _load_models()."""
565
- # pass
566
- #
567
- # def detect(self,
568
- # image: np.ndarray,
569
- # context: DetectionContext,
570
- # roi_offset: Tuple[int, int],
571
- # *args,
572
- # **kwargs) -> None:
573
- # """
574
- # Run multi-threshold YOLO detection for each model, pick best threshold,
575
- # merge detections, and store Symbol objects in context.
576
- # """
577
- # try:
578
- # with self.debug_handler.track_performance("symbol_detection"):
579
- # # 1) Possibly preprocess & resize
580
- # processed_img = self._preprocess(image)
581
- # resized_img, scale_factor = self._resize_image(processed_img)
582
- #
583
- # # 2) Detect with all models, each using multiple thresholds
584
- # all_detections = []
585
- # for model_name, model in self.models.items():
586
- # best_detections = self._detect_best_threshold(
587
- # model, resized_img, image.shape, scale_factor, model_name
588
- # )
589
- # all_detections.extend(best_detections)
590
- #
591
- # # 3) Merge detections using NMS logic
592
- # merged_detections = self._merge_detections(all_detections)
593
- #
594
- # # 4) Update context with final symbols
595
- # self._update_context(merged_detections, context)
596
- #
597
- # # 5) Create optional debug image artifact
598
- # debug_image = self._create_debug_image(processed_img, merged_detections)
599
- # _, debug_img_encoded = cv2.imencode('.jpg', debug_image)
600
- # self.debug_handler.save_artifact(
601
- # name="symbol_detection_debug",
602
- # data=debug_img_encoded.tobytes(),
603
- # extension="jpg"
604
- # )
605
- #
606
- # except Exception as e:
607
- # logger.error("Symbol detection failed: %s", str(e), exc_info=True)
608
- # self.debug_handler.save_artifact(
609
- # name="symbol_detection_error",
610
- # data=f"Detection error: {str(e)}".encode('utf-8'),
611
- # extension="txt"
612
- # )
613
- #
614
- # def _preprocess(self, image: np.ndarray) -> np.ndarray:
615
- # """Preprocess if needed (e.g., histogram equalization)."""
616
- # if self.config.apply_preprocessing:
617
- # gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
618
- # equalized = cv2.equalizeHist(gray)
619
- # # convert back to BGR for YOLO
620
- # return cv2.cvtColor(equalized, cv2.COLOR_GRAY2BGR)
621
- # return image.copy()
622
- #
623
- # def _postprocess(self, image: np.ndarray) -> np.ndarray:
624
- # return None
625
- #
626
- # # -----------------------------
627
- # # Internal Helpers
628
- # # -----------------------------
629
- # def _load_models(self) -> Dict[str, YOLO]:
630
- # """Load multiple YOLO models from config."""
631
- # models = {}
632
- # for model_name, path_str in self.config.model_paths.items():
633
- # path = Path(path_str)
634
- # if not path.exists():
635
- # raise FileNotFoundError(f"Model file not found: {path_str}")
636
- # models[model_name] = YOLO(str(path))
637
- # logger.info(f"Loaded model '{model_name}' from {path_str}")
638
- # return models
639
- #
640
- # def _build_class_map(self) -> Dict[int, SymbolType]:
641
- # """
642
- # Convert config symbol_type_mapping (like {"pump": "PUMP"})
643
- # into a dictionary from YOLO class_id to SymbolType.
644
- # If you have a fixed list of YOLO classes, you can map them here.
645
- # """
646
- # # For example, if YOLO has classes like ["valve", "pump", ...],
647
- # # you might want to do something more dynamic.
648
- # # For now, let's just return an empty dict or handle it in detection.
649
- # return {}
650
- #
651
- # def _resize_image(self, image: np.ndarray) -> Tuple[np.ndarray, float]:
652
- # """Resize while maintaining aspect ratio if needed."""
653
- # h, w = image.shape[:2]
654
- # if not self.config.resize_image:
655
- # return image, 1.0
656
- #
657
- # if max(w, h) > self.config.max_dimension:
658
- # scale = self.config.max_dimension / max(w, h)
659
- # new_w, new_h = int(w * scale), int(h * scale)
660
- # resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
661
- # return resized, scale
662
- # return image, 1.0
663
- #
664
- # def _detect_best_threshold(self,
665
- # model: YOLO,
666
- # resized_img: np.ndarray,
667
- # orig_shape: Tuple[int, int, int],
668
- # scale_factor: float,
669
- # model_name: str) -> List[Dict]:
670
- # """
671
- # Run detection across multiple confidence thresholds.
672
- # Use the threshold that yields the 'best metric' (currently # of detections).
673
- # """
674
- # best_metric = -1
675
- # best_threshold = 0.5
676
- # best_detections_list = []
677
- #
678
- # # Evaluate each threshold
679
- # for thresh in self.config.confidence_thresholds:
680
- # # Run YOLO detection
681
- # # Setting conf=thresh or conf=0.0 + we do filtering ourselves.
682
- # results = model.predict(
683
- # source=resized_img,
684
- # imgsz=self.config.max_dimension,
685
- # conf=0.0, # We'll filter manually below
686
- # verbose=False
687
- # )
688
- #
689
- # # Convert to detection dict
690
- # detections_list = []
691
- # for result in results:
692
- # for box in result.boxes:
693
- # conf_val = float(box.conf[0])
694
- # if conf_val >= thresh:
695
- # # Convert bounding box coords to original (local) coords
696
- # x1, y1, x2, y2 = self._scale_coordinates(
697
- # box.xyxy[0].cpu().numpy(),
698
- # resized_img.shape, # shape after resizing
699
- # scale_factor
700
- # )
701
- # class_id = int(box.cls[0])
702
- # label = result.names[class_id] if result.names else "unknown_label"
703
- #
704
- # # parse label (category, type, new_label)
705
- # category, type_str, new_label = self._parse_label(label)
706
- #
707
- # detection_info = {
708
- # "symbol_id": str(uuid.uuid4()),
709
- # "class_id": class_id,
710
- # "original_label": label,
711
- # "category": category,
712
- # "type": type_str,
713
- # "label": new_label,
714
- # "confidence": conf_val,
715
- # "bbox": [x1, y1, x2, y2],
716
- # "model_source": model_name
717
- # }
718
- # detections_list.append(detection_info)
719
- #
720
- # # Evaluate
721
- # metric = self._evaluate_detections(detections_list)
722
- # if metric > best_metric:
723
- # best_metric = metric
724
- # best_threshold = thresh
725
- # best_detections_list = detections_list
726
- #
727
- # logger.info(f"For model {model_name}, best threshold={best_threshold:.2f} with {best_metric} detections.")
728
- # return best_detections_list
729
- #
730
- # def _evaluate_detections(self, detections_list: List[Dict]) -> int:
731
- # """A simple metric: # of detections."""
732
- # return len(detections_list)
733
- #
734
- # def _parse_label(self, label: str) -> Tuple[str, str, str]:
735
- # """
736
- # Attempt to parse the YOLO label into (category, type, new_label).
737
- # Example label: "inst_ind_Solenoid_actuator"
738
- # -> category=inst, type=ind, new_label="Solenoid_actuator"
739
- # If no underscores, we fallback to "Unknown" for type.
740
- # """
741
- # split_label = label.split('_')
742
- # if len(split_label) >= 3:
743
- # category = split_label[0]
744
- # type_ = split_label[1]
745
- # new_label = '_'.join(split_label[2:])
746
- # elif len(split_label) == 2:
747
- # category = split_label[0]
748
- # type_ = split_label[1]
749
- # new_label = split_label[1]
750
- # elif len(split_label) == 1:
751
- # category = split_label[0]
752
- # type_ = "Unknown"
753
- # new_label = split_label[0]
754
- # else:
755
- # logger.warning(f"Unexpected label format: {label}")
756
- # return ("Unknown", "Unknown", label)
757
- #
758
- # return (category, type_, new_label)
759
- #
760
- # def _scale_coordinates(self,
761
- # coords: np.ndarray,
762
- # resized_shape: Tuple[int, int, int],
763
- # scale_factor: float) -> Tuple[int, int, int, int]:
764
- # """
765
- # Scale YOLO's [x1,y1,x2,y2] from the resized image back to the original local coords.
766
- # """
767
- # x1, y1, x2, y2 = coords
768
- # # Because we resized by scale_factor
769
- # # so original coordinate = coords / scale_factor
770
- # return (
771
- # int(x1 / scale_factor),
772
- # int(y1 / scale_factor),
773
- # int(x2 / scale_factor),
774
- # int(y2 / scale_factor),
775
- # )
776
- #
777
- # def _merge_detections(self, all_detections: List[Dict]) -> List[Dict]:
778
- # """Merge using NMS-like approach (IoU-based) across all models."""
779
- # if not all_detections:
780
- # return []
781
- #
782
- # # Sort by confidence (descending)
783
- # all_detections.sort(key=lambda x: x['confidence'], reverse=True)
784
- # keep = [True] * len(all_detections)
785
- #
786
- # for i in range(len(all_detections)):
787
- # if not keep[i]:
788
- # continue
789
- # for j in range(i + 1, len(all_detections)):
790
- # if not keep[j]:
791
- # continue
792
- # # Merge if same class_id & high IoU
793
- # if (all_detections[i]['class_id'] == all_detections[j]['class_id'] and
794
- # self._calculate_iou(all_detections[i]['bbox'], all_detections[j]['bbox']) > 0.5):
795
- # keep[j] = False
796
- #
797
- # return [det for idx, det in enumerate(all_detections) if keep[idx]]
798
- #
799
- # def _calculate_iou(self, box1: List[int], box2: List[int]) -> float:
800
- # """Intersection over Union"""
801
- # x_left = max(box1[0], box2[0])
802
- # y_top = max(box1[1], box2[1])
803
- # x_right = min(box1[2], box2[2])
804
- # y_bottom = min(box1[3], box2[3])
805
- #
806
- # inter_w = max(0, x_right - x_left)
807
- # inter_h = max(0, y_bottom - y_top)
808
- # intersection = inter_w * inter_h
809
- #
810
- # area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
811
- # area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
812
- # union = float(area1 + area2 - intersection)
813
- # return intersection / union if union > 0 else 0.0
814
- #
815
- # def _update_context(self, detections: List[Dict], context: DetectionContext) -> None:
816
- # """Convert final detections into Symbol objects & add to context."""
817
- # for det in detections:
818
- # x1, y1, x2, y2 = det['bbox']
819
- # # Use your Symbol dataclass from detection_schema
820
- # symbol_obj = Symbol(
821
- # bbox=BBox(xmin=x1, ymin=y1, xmax=x2, ymax=y2),
822
- # center=Coordinates(x=(x1 + x2) // 2, y=(y1 + y2) // 2),
823
- # symbol_type=SymbolType.OTHER, # default
824
- # confidence=det['confidence'],
825
- # model_source=det['model_source'],
826
- # class_id=det['class_id'],
827
- # original_label=det['original_label'],
828
- # category=det['category'],
829
- # type=det['type'],
830
- # label=det['label']
831
- # )
832
- # context.add_symbol(symbol_obj)
833
- #
834
- # def _create_debug_image(self, image: np.ndarray, detections: List[Dict]) -> np.ndarray:
835
- # """Optional: draw bounding boxes & labels on a copy of 'image'."""
836
- # debug_img = image.copy()
837
- # for det in detections:
838
- # x1, y1, x2, y2 = det['bbox']
839
- # cv2.rectangle(debug_img, (x1, y1), (x2, y2), (0, 255, 0), 2)
840
- # txt = f"{det['label']} {det['confidence']:.2f}"
841
- # cv2.putText(debug_img, txt, (x1, max(0, y1 - 10)),
842
- # cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
843
- # return debug_img
844
- #
845
- #
846
- # class TagDetector(BaseDetector):
847
- # """
848
- # A placeholder detector that reads precomputed tag data
849
- # from a JSON file and populates the context with Tag objects.
850
- # """
851
- #
852
- # def __init__(self,
853
- # config: TagConfig,
854
- # debug_handler: Optional[DebugHandler] = None,
855
- # tag_json_path: str = "./tags.json"):
856
- # super().__init__(config=config, debug_handler=debug_handler)
857
- # self.tag_json_path = tag_json_path
858
- #
859
- # def _load_model(self, model_path: str):
860
- # """Not loading an actual model; tag data is read from JSON."""
861
- # return None
862
- #
863
- # def detect(self,
864
- # image: np.ndarray,
865
- # context: DetectionContext,
866
- # roi_offset: Tuple[int, int],
867
- # *args,
868
- # **kwargs) -> None:
869
- # """
870
- # Reads from a JSON file containing tag info,
871
- # adjusts coordinates using roi_offset, and updates context.
872
- # """
873
- #
874
- # tag_data = self._load_json_data(self.tag_json_path)
875
- # if not tag_data:
876
- # return
877
- #
878
- # x_min, y_min = roi_offset # Offset values from cropping
879
- #
880
- # for record in tag_data.get("detections", []): # Fix: Use "detections" key
881
- # tag_obj = self._parse_tag_record(record, x_min, y_min)
882
- # context.add_tag(tag_obj)
883
- #
884
- # def _preprocess(self, image: np.ndarray) -> np.ndarray:
885
- # return image
886
- #
887
- # def _postprocess(self, image: np.ndarray) -> np.ndarray:
888
- # return image
889
- #
890
- # # --------------
891
- # # HELPER METHODS
892
- # # --------------
893
- # def _load_json_data(self, json_path: str) -> dict:
894
- # if not os.path.exists(json_path):
895
- # self.debug_handler.save_artifact(name="tag_error",
896
- # data=b"Missing tag JSON file",
897
- # extension="txt")
898
- # return {}
899
- #
900
- # with open(json_path, "r", encoding="utf-8") as f:
901
- # return json.load(f)
902
- #
903
- # def _parse_tag_record(self, record: dict, x_min: int, y_min: int) -> Tag:
904
- # """
905
- # Builds a Tag object from a JSON record, adjusting coordinates for cropping.
906
- # """
907
- # bbox_list = record.get("bbox", [0, 0, 0, 0])
908
- # bbox_obj = BBox(
909
- # xmin=bbox_list[0] - x_min,
910
- # ymin=bbox_list[1] - y_min,
911
- # xmax=bbox_list[2] - x_min,
912
- # ymax=bbox_list[3] - y_min
913
- # )
914
- #
915
- # return Tag(
916
- # text=record.get("text", ""),
917
- # bbox=bbox_obj,
918
- # confidence=record.get("confidence", 1.0),
919
- # source=record.get("source", ""),
920
- # text_type=record.get("text_type", "Unknown"),
921
- # id=record.get("id", str(uuid.uuid4())),
922
- # font_size=record.get("font_size", 12),
923
- # rotation=record.get("rotation", 0.0)
924
- # )
925
-
926
-
927
- import json
928
- import uuid
929
-
930
- class SymbolDetector(BaseDetector):
931
- """
932
- A placeholder detector that reads precomputed symbol data
933
- from a JSON file and populates the context with Symbol objects.
934
- """
935
-
936
- def __init__(self,
937
- config: SymbolConfig,
938
- debug_handler: Optional[DebugHandler] = None,
939
- symbol_json_path: str = "./symbols.json"):
940
- super().__init__(config=config, debug_handler=debug_handler)
941
- self.symbol_json_path = symbol_json_path
942
-
943
- def _load_model(self, model_path: str):
944
- """Not loading an actual model; symbol data is read from JSON."""
945
- return None
946
-
947
- def detect(self,
948
- image: np.ndarray,
949
- context: DetectionContext,
950
- roi_offset: Tuple[int, int],
951
- *args,
952
- **kwargs) -> None:
953
- """
954
- Reads from a JSON file containing symbol info,
955
- adjusts coordinates using roi_offset, and updates context.
956
- """
957
- symbol_data = self._load_json_data(self.symbol_json_path)
958
- if not symbol_data:
959
- return
960
-
961
- x_min, y_min = roi_offset # Offset values from cropping
962
-
963
- for record in symbol_data.get("detections", []): # Fix: Use "detections" key
964
- sym_obj = self._parse_symbol_record(record, x_min, y_min)
965
- context.add_symbol(sym_obj)
966
-
967
- def _preprocess(self, image: np.ndarray) -> np.ndarray:
968
- return image
969
-
970
- def _postprocess(self, image: np.ndarray) -> np.ndarray:
971
- return image
972
-
973
- # --------------
974
- # HELPER METHODS
975
- # --------------
976
- def _load_json_data(self, json_path: str) -> dict:
977
- if not os.path.exists(json_path):
978
- self.debug_handler.save_artifact(name="symbol_error",
979
- data=b"Missing symbol JSON file",
980
- extension="txt")
981
- return {}
982
-
983
- with open(json_path, "r", encoding="utf-8") as f:
984
- return json.load(f)
985
-
986
- def _parse_symbol_record(self, record: dict, x_min: int, y_min: int) -> Symbol:
987
- """
988
- Builds a Symbol object from a JSON record, adjusting coordinates for cropping.
989
- """
990
- bbox_list = record.get("bbox", [0, 0, 0, 0])
991
- bbox_obj = BBox(
992
- xmin=bbox_list[0] - x_min,
993
- ymin=bbox_list[1] - y_min,
994
- xmax=bbox_list[2] - x_min,
995
- ymax=bbox_list[3] - y_min
996
- )
997
-
998
- # Compute the center
999
- center_coords = Coordinates(
1000
- x=(bbox_obj.xmin + bbox_obj.xmax) // 2,
1001
- y=(bbox_obj.ymin + bbox_obj.ymax) // 2
1002
- )
1003
-
1004
- return Symbol(
1005
- id=record.get("symbol_id", ""),
1006
- class_id=record.get("class_id", -1),
1007
- original_label=record.get("original_label", ""),
1008
- category=record.get("category", ""),
1009
- type=record.get("type", ""),
1010
- label=record.get("label", ""),
1011
- bbox=bbox_obj,
1012
- center=center_coords,
1013
- confidence=record.get("confidence", 0.95),
1014
- model_source=record.get("model_source", ""),
1015
- connections=[]
1016
- )
1017
-
1018
- class TagDetector(BaseDetector):
1019
- """
1020
- A placeholder detector that reads precomputed tag data
1021
- from a JSON file and populates the context with Tag objects.
1022
- """
1023
-
1024
- def __init__(self,
1025
- config: TagConfig,
1026
- debug_handler: Optional[DebugHandler] = None,
1027
- tag_json_path: str = "./tags.json"):
1028
- super().__init__(config=config, debug_handler=debug_handler)
1029
- self.tag_json_path = tag_json_path
1030
-
1031
- def _load_model(self, model_path: str):
1032
- """Not loading an actual model; tag data is read from JSON."""
1033
- return None
1034
-
1035
- def detect(self,
1036
- image: np.ndarray,
1037
- context: DetectionContext,
1038
- roi_offset: Tuple[int, int],
1039
- *args,
1040
- **kwargs) -> None:
1041
- """
1042
- Reads from a JSON file containing tag info,
1043
- adjusts coordinates using roi_offset, and updates context.
1044
- """
1045
-
1046
- tag_data = self._load_json_data(self.tag_json_path)
1047
- if not tag_data:
1048
- return
1049
-
1050
- x_min, y_min = roi_offset # Offset values from cropping
1051
-
1052
- for record in tag_data.get("detections", []): # Fix: Use "detections" key
1053
- tag_obj = self._parse_tag_record(record, x_min, y_min)
1054
- context.add_tag(tag_obj)
1055
-
1056
- def _preprocess(self, image: np.ndarray) -> np.ndarray:
1057
- return image
1058
-
1059
- def _postprocess(self, image: np.ndarray) -> np.ndarray:
1060
- return image
1061
-
1062
- # --------------
1063
- # HELPER METHODS
1064
- # --------------
1065
- def _load_json_data(self, json_path: str) -> dict:
1066
- if not os.path.exists(json_path):
1067
- self.debug_handler.save_artifact(name="tag_error",
1068
- data=b"Missing tag JSON file",
1069
- extension="txt")
1070
- return {}
1071
-
1072
- with open(json_path, "r", encoding="utf-8") as f:
1073
- return json.load(f)
1074
-
1075
- def _parse_tag_record(self, record: dict, x_min: int, y_min: int) -> Tag:
1076
- """
1077
- Builds a Tag object from a JSON record, adjusting coordinates for cropping.
1078
- """
1079
- bbox_list = record.get("bbox", [0, 0, 0, 0])
1080
- bbox_obj = BBox(
1081
- xmin=bbox_list[0] - x_min,
1082
- ymin=bbox_list[1] - y_min,
1083
- xmax=bbox_list[2] - x_min,
1084
- ymax=bbox_list[3] - y_min
1085
- )
1086
-
1087
- return Tag(
1088
- text=record.get("text", ""),
1089
- bbox=bbox_obj,
1090
- confidence=record.get("confidence", 1.0),
1091
- source=record.get("source", ""),
1092
- text_type=record.get("text_type", "Unknown"),
1093
- id=record.get("id", str(uuid.uuid4())),
1094
- font_size=record.get("font_size", 12),
1095
- rotation=record.get("rotation", 0.0)
1096
- )
 
3
  import torch
4
  import cv2
5
  import numpy as np
6
+ from typing import List, Optional, Tuple, Dict, Any
7
  from dataclasses import replace
8
  from math import sqrt
9
  import json
10
  import uuid
11
  from pathlib import Path
12
+ from abc import ABC, abstractmethod
13
+ from ultralytics import YOLO
14
+ from PIL import Image
15
+ import matplotlib.pyplot as plt
16
+ from storage import StorageInterface
17
 
18
  # Base classes and utilities
19
  from base import BaseDetector
 
23
 
24
  # DeepLSD model for line detection
25
  from deeplsd.models.deeplsd_inference import DeepLSD
 
26
 
27
  # Detection schema: dataclasses for different objects
28
  from detection_schema import (
 
43
  from skimage.morphology import skeletonize
44
  from skimage.measure import label
45
 
46
+ # Configure logging
47
+ logger = logging.getLogger(__name__)
48
 
49
+ class Detector(ABC):
50
+ """Base class for all detectors"""
51
+
52
+ def __init__(self, config: Any, debug_handler=None):
53
+ self.config = config
 
 
 
 
54
  self.debug_handler = debug_handler
55
+
56
+ @abstractmethod
57
+ def detect(self, image: np.ndarray) -> Dict:
58
+ """Perform detection on the image"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  pass
60
+
61
+ def save_debug_image(self, image: np.ndarray, filename: str):
62
+ """Save debug visualization if debug handler is available"""
63
+ if self.debug_handler:
64
+ self.debug_handler.save_image(image, filename)
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
+ class SymbolDetector(Detector):
68
+ """Detector for symbols in P&ID diagrams"""
69
+
70
+ def __init__(self, config, debug_handler=None):
71
+ super().__init__(config, debug_handler)
72
+ self.models = {}
73
+ for name, path in config.model_paths.items():
74
+ if os.path.exists(path):
75
+ self.models[name] = YOLO(path)
76
+ else:
77
+ logger.warning(f"Model not found at {path}")
78
+
79
+ def detect(self, image: np.ndarray) -> Dict:
80
+ """Detect symbols using multiple YOLO models"""
81
+ results = []
82
+
83
+ # Process with each model
84
+ for model_name, model in self.models.items():
85
+ model_results = model(image, conf=self.config.confidence_threshold)[0]
86
+ boxes = model_results.boxes
87
+
88
+ for box in boxes:
89
+ x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
90
+ conf = box.conf[0].cpu().numpy()
91
+ cls = box.cls[0].cpu().numpy()
92
+ cls_name = model_results.names[int(cls)]
93
+
94
+ results.append({
95
+ 'bbox': [float(x1), float(y1), float(x2), float(y2)],
96
+ 'confidence': float(conf),
97
+ 'class': cls_name,
98
+ 'model': model_name
99
+ })
100
+
101
+ return {'detections': results}
102
+
103
+
104
+ class TagDetector(Detector):
105
+ """Detector for text tags in P&ID diagrams"""
106
+
107
+ def __init__(self, config, debug_handler=None):
108
+ super().__init__(config, debug_handler)
109
+ self.ocr = None # Initialize OCR engine here
110
+
111
+ def detect(self, image: np.ndarray) -> Dict:
112
+ """Detect and recognize text tags"""
113
+ # Implement text detection logic
114
+ return {'detections': []}
115
+
116
+
117
+ class LineDetector(Detector):
118
+ """Detector for lines in P&ID diagrams"""
119
+
120
+ def __init__(self, config, model_path=None, model_config=None, device='cpu', debug_handler=None):
121
+ super().__init__(config, debug_handler)
122
+ self.model_path = model_path
123
+ self.model_config = model_config or {}
124
+ self.device = device
125
+
126
+ def detect(self, image: np.ndarray) -> Dict:
127
+ """Detect lines using DeepLSD or other methods"""
128
+ # Implement line detection logic
129
+ return {'detections': []}
130
+
131
+
132
+ class PointDetector(Detector):
133
+ """Detector for connection points in P&ID diagrams"""
134
+
135
+ def detect(self, image: np.ndarray) -> Dict:
136
+ """Detect connection points"""
137
+ # Implement point detection logic
138
+ return {'detections': []}
139
+
140
+
141
+ class JunctionDetector(Detector):
142
+ """Detector for line junctions in P&ID diagrams"""
143
+
144
+ def detect(self, image: np.ndarray) -> Dict:
145
+ """Detect line junctions"""
146
+ # Implement junction detection logic
147
+ return {'detections': []}