File size: 25,949 Bytes
910e0d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
import os
import math
import torch
import cv2
import numpy as np
from typing import List, Optional, Tuple, Dict
from dataclasses import replace
from math import sqrt
import json
import uuid
from pathlib import Path

# Base classes and utilities
from base import BaseDetector
from detection_schema import DetectionContext
from utils import DebugHandler
from config import SymbolConfig, TagConfig, LineConfig, PointConfig, JunctionConfig

# DeepLSD model for line detection
from deeplsd.models.deeplsd_inference import DeepLSD
from ultralytics import YOLO

# Detection schema: dataclasses for different objects
from detection_schema import (
    BBox,
    Coordinates,
    Point,
    Line,
    Symbol,
    Tag,
    SymbolType,
    LineStyle,
    ConnectionType,
    JunctionType,
    Junction
)

# Skeletonization and label processing for junction detection
from skimage.morphology import skeletonize
from skimage.measure import label


import os
import cv2
import torch
import numpy as np
from dataclasses import replace
from typing import List, Optional
from detection_utils import robust_merge_lines


class LineDetector(BaseDetector):
    """
    DeepLSD-based line detection with patch-based tiling and global merging.

    Pipeline: skeletonize the drawing, optionally mask/threshold, split the
    page into overlapping square patches, run DeepLSD on each patch, map the
    patch-local segments back to global coordinates, merge near-collinear
    segments, and emit `Line` objects into the DetectionContext.
    """

    def __init__(self,
                 config: LineConfig,
                 model_path: str,
                 model_config: dict,
                 device: torch.device,
                 debug_handler: DebugHandler = None):
        """
        Args:
            config: line-detection configuration.
            model_path: path to the DeepLSD checkpoint file.
            model_config: dict forwarded to the DeepLSD constructor.
            device: requested torch device. NOTE(review): currently ignored —
                the device is auto-selected below (mps > cuda > cpu); confirm
                no caller relies on forcing a specific device.
            debug_handler: optional handler used to save debug artifacts.
        """
        super().__init__(config, debug_handler)
        
        # Fix device selection for Apple Silicon
        if torch.backends.mps.is_available():
            self.device = torch.device("mps")
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        
        self.model_path = model_path
        self.model_config = model_config
        self.model = self._load_model(model_path)

        # Patch parameters
        self.patch_size = 512  # square tile edge length, pixels
        self.overlap = 10  # pixels shared between adjacent tiles

        # Merging thresholds
        self.angle_thresh = 5.0  # degrees
        self.dist_thresh = 5.0  # pixels

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """Thicken, invert, skeletonize, then re-dilate the drawing into a clean line map."""
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
        dilated = cv2.dilate(image, kernel, iterations=2)

        # Invert so dark ink becomes nonzero foreground for skeletonize.
        skeleton = cv2.bitwise_not(dilated)
        skeleton = skeletonize(skeleton // 255)
        skeleton = (skeleton * 255).astype(np.uint8)
        # NOTE(review): a 1x1 rectangular kernel makes this dilation a no-op
        # regardless of iterations — confirm whether a larger kernel was intended.
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 1))
        clean_image = cv2.dilate(skeleton, kernel, iterations=5)

        self.debug_handler.save_artifact(name="skeleton", data=clean_image, extension="png")

        return clean_image

    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        # NOTE(review): returns None despite the ndarray annotation, unlike the
        # sibling detectors which return the image — confirm callers ignore it.
        return None
    # -------------------------------------
    # 1) Load Model
    # -------------------------------------
    def _load_model(self, model_path: str) -> DeepLSD:
        """Load the DeepLSD checkpoint onto self.device and switch to eval mode.

        Raises:
            FileNotFoundError: if `model_path` does not exist.
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")
        ckpt = torch.load(model_path, map_location=self.device)
        model = DeepLSD(self.model_config)
        model.load_state_dict(ckpt["model"])
        return model.to(self.device).eval()

    # -------------------------------------
    # 2) Main Detection Pipeline
    # -------------------------------------
    def detect(self,
               image: np.ndarray,
               context: DetectionContext,
               mask_coords: Optional[List[BBox]] = None,
               *args,
               **kwargs) -> None:
        """
        Steps:
          - Optional mask + threshold
          - Tile into overlapping patches
          - For each patch => run DeepLSD => re-map lines to global coords
          - Merge lines robustly
          - Build final Line objects => add to context
        """
        mask_coords = mask_coords or []

        skeleton = self._preprocess(image)
        # (A) Optional mask + threshold if you want a binary
        #     If your model expects grayscale or binary, do it here:
        processed_img = self._apply_mask_and_threshold(skeleton, mask_coords)
        # (B) Patch-based inference => collect raw lines in global coords
        all_lines = self._detect_in_patches(processed_img)

        # (C) Merge the lines in the global coordinate system
        merged_line_segments = robust_merge_lines(
            all_lines,
            angle_thresh=self.angle_thresh,
            dist_thresh=self.dist_thresh
        )

        # (D) Convert merged segments => final Line objects, add to context
        for (x1, y1, x2, y2) in merged_line_segments:
            line_obj = self._create_line_object(x1, y1, x2, y2)
            context.add_line(line_obj)

    # -------------------------------------
    # 3) Optional Mask + Threshold
    # -------------------------------------
    def _apply_mask_and_threshold(self, image: np.ndarray, mask_coords: List[BBox]) -> np.ndarray:
        """White out rectangular areas, then threshold to binary (if needed)."""
        masked = image.copy()
        # Fill each masked region with white so no lines are detected there.
        for bbox in mask_coords:
            x1, y1 = int(bbox.xmin), int(bbox.ymin)
            x2, y2 = int(bbox.xmax), int(bbox.ymax)
            cv2.rectangle(masked, (x1, y1), (x2, y2), (255, 255, 255), -1)

        # If image has 3 channels, convert to grayscale
        if len(masked.shape) == 3:
            masked_gray = cv2.cvtColor(masked, cv2.COLOR_BGR2GRAY)
        else:
            masked_gray = masked

        # Binary threshold (adjust threshold as needed)
        # If your model expects a plain grayscale, skip threshold
        binary_img = cv2.threshold(masked_gray, 127, 255, cv2.THRESH_BINARY)[1]
        return binary_img

    # -------------------------------------
    # 4) Patch-Based Inference
    # -------------------------------------
    def _detect_in_patches(self, processed_img: np.ndarray) -> List[tuple]:
        """
        Break the image into overlapping patches, run DeepLSD,
        map local lines => global coords, and return the global line list.

        Returns:
            List of (x1, y1, x2, y2) segments in global image coordinates.
        """
        patch_size = self.patch_size
        overlap = self.overlap

        height, width = processed_img.shape[:2]
        step = patch_size - overlap

        all_lines = []

        for y in range(0, height, step):
            # Near the bottom edge, shift the window up so every patch is full-size.
            patch_ymax = min(y + patch_size, height)
            patch_ymin = patch_ymax - patch_size if (patch_ymax - y) < patch_size else y
            if patch_ymin < 0: patch_ymin = 0

            for x in range(0, width, step):
                # Same inward shift for the right edge.
                patch_xmax = min(x + patch_size, width)
                patch_xmin = patch_xmax - patch_size if (patch_xmax - x) < patch_size else x
                if patch_xmin < 0: patch_xmin = 0

                patch = processed_img[patch_ymin:patch_ymax, patch_xmin:patch_xmax]

                # Run model
                local_lines = self._run_model_inference(patch)

                # Convert local lines => global coords
                for ln in local_lines:
                    (x1_local, y1_local), (x2_local, y2_local) = ln

                    # offset by patch_xmin, patch_ymin
                    gx1 = x1_local + patch_xmin
                    gy1 = y1_local + patch_ymin
                    gx2 = x2_local + patch_xmin
                    gy2 = y2_local + patch_ymin

                    # Optional: clamp or filter lines partially out-of-bounds
                    # (currently: drop any segment with an endpoint outside the image)
                    if 0 <= gx1 < width and 0 <= gx2 < width and 0 <= gy1 < height and 0 <= gy2 < height:
                        all_lines.append((gx1, gy1, gx2, gy2))

        return all_lines

    # -------------------------------------
    # 5) Model Inference (Single Patch)
    # -------------------------------------
    def _run_model_inference(self, patch_img: np.ndarray) -> np.ndarray:
        """
        Run DeepLSD on a single patch (already masked/thresholded).
        patch_img shape: [patchH, patchW].
        Returns lines shape: [N, 2, 2].

        NOTE(review): the annotation says ndarray, but the model output may be
        a torch tensor depending on the DeepLSD version — confirm downstream
        unpacking handles both.
        """
        # Convert patch to float32 and scale
        inp = torch.tensor(patch_img, dtype=torch.float32, device=self.device)[None, None] / 255.0
        with torch.no_grad():
            output = self.model({"image": inp})
            lines = output["lines"][0]  # shape (N, 2, 2)
        return lines

    # -------------------------------------
    # 6) Convert Merged Segments => Line Objects
    # -------------------------------------
    def _create_line_object(self, x1: float, y1: float, x2: float, y2: float) -> Line:
        """
        Create a minimal `Line` object from the final merged coordinates.
        Each endpoint gets a small square bbox (`margin` px) around it.
        """
        margin = 2
        # Start point
        start_pt = Point(
            coords=Coordinates(int(x1), int(y1)),
            bbox=BBox(
                xmin=int(x1 - margin),
                ymin=int(y1 - margin),
                xmax=int(x1 + margin),
                ymax=int(y1 + margin)
            ),
            type=JunctionType.END,
            confidence=1.0
        )
        # End point
        end_pt = Point(
            coords=Coordinates(int(x2), int(y2)),
            bbox=BBox(
                xmin=int(x2 - margin),
                ymin=int(y2 - margin),
                xmax=int(x2 + margin),
                ymax=int(y2 + margin)
            ),
            type=JunctionType.END,
            confidence=1.0
        )

        # Overall bounding box
        x_min = int(min(x1, x2))
        x_max = int(max(x1, x2))
        y_min = int(min(y1, y2))
        y_max = int(max(y1, y2))

        line_obj = Line(
            start=start_pt,
            end=end_pt,
            bbox=BBox(xmin=x_min, ymin=y_min, xmax=x_max, ymax=y_max),
            style=LineStyle(
                connection_type=ConnectionType.SOLID,
                stroke_width=2,
                color="#000000"
            ),
            confidence=0.9,
            topological_links=[]
        )
        return line_obj

class PointDetector(BaseDetector):
    """
    Unifies line endpoints that lie close together.

    Reads every line from the context, greedily clusters endpoints that
    fall within ``threshold_distance`` of each other, and rewrites the
    lines so endpoints in the same cluster share one Point object.
    """

    def __init__(self,
                 config: PointConfig,
                 debug_handler: DebugHandler = None):
        # Pure geometry — there is no underlying ML model.
        super().__init__(config, debug_handler)
        self.threshold_distance = config.threshold_distance

    def _load_model(self, model_path: str):
        """Point unification needs no model; always returns None."""
        return None

    def detect(self, image: np.ndarray, context: DetectionContext, *args, **kwargs) -> None:
        """
        Cluster all line endpoints and snap each line to its cluster's
        canonical (first-seen) point, then refresh context.points so it
        holds exactly the unified endpoints.
        """
        # Gather every endpoint, start then end, per line.
        endpoints = []
        for ln in context.lines.values():
            endpoints.extend((ln.start, ln.end))

        clusters = self._cluster_points(endpoints, self.threshold_distance)

        # Map every non-canonical point id to its cluster's first point.
        canonical = {}
        for group in clusters:
            representative = group[0]
            for member in group[1:]:
                canonical[member.id] = representative

        # Rewrite line endpoints through the canonical map.
        for ln in context.lines.values():
            ln.start = canonical.get(ln.start.id, ln.start)
            ln.end = canonical.get(ln.end.id, ln.end)

        # Rebuild context.points from the (now unified) line endpoints so the
        # context carries no stale duplicates.
        refreshed = {}
        for ln in context.lines.values():
            refreshed[ln.start.id] = ln.start
            refreshed[ln.end.id] = ln.end
        context.points = refreshed

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """Identity — no image preprocessing is needed here."""
        return image

    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        """Identity — no image postprocessing is needed here."""
        return image

    # ----------------------
    # HELPER: clustering
    # ----------------------
    def _cluster_points(self, points: List[Point], threshold: float) -> List[List[Point]]:
        """
        Greedy single-pass clustering: each point joins the first existing
        cluster whose *first* member lies within `threshold`; otherwise it
        seeds a new cluster of its own.
        """
        groups = []
        for candidate in points:
            home = next(
                (g for g in groups if self._distance(candidate, g[0]) < threshold),
                None,
            )
            if home is None:
                groups.append([candidate])
            else:
                home.append(candidate)
        return groups

    def _distance(self, p1: Point, p2: Point) -> float:
        """Euclidean distance between the two points' coordinates."""
        return math.hypot(p1.coords.x - p2.coords.x,
                          p1.coords.y - p2.coords.y)


class JunctionDetector(BaseDetector):
    """
    Labels every context point as 'END', 'L', or 'T' by skeletonizing the
    binarized image and counting the skeleton branches that exit a small
    disc around the point. Also records one Junction object per point.
    """

    def __init__(self, config: JunctionConfig, debug_handler: DebugHandler = None):
        # Pure image analysis — no learned model is loaded.
        super().__init__(config, debug_handler)
        self.window_size = config.window_size
        self.radius = config.radius
        self.angle_threshold_lb = config.angle_threshold_lb
        self.angle_threshold_ub = config.angle_threshold_ub
        self.debug_handler = debug_handler or DebugHandler()

    def _load_model(self, model_path: str):
        """Skeleton analysis only — there is no model to load."""
        return None

    def detect(self,
               image: np.ndarray,
               context: DetectionContext,
               *args,
               **kwargs) -> None:
        """
        Skeletonize the image, classify each context point's junction type,
        then store a Junction (with its connected line ids) per point.
        """
        skel = self._create_skeleton(image)

        for point in context.points.values():
            point.type = self._classify_point(skel, point)

        # One Junction object per point (END points included; filter here
        # if only T/L junctions are wanted).
        self._record_junctions_in_context(context)

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """Grayscale (if needed) plus a fixed binary threshold at 127."""
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image
        return cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)[1]

    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        """Identity — no postprocessing required."""
        return image

    def _create_skeleton(self, raw_image: np.ndarray) -> np.ndarray:
        """Binarize, invert (ink becomes foreground), and skeletonize to 0/255."""
        binary = self._preprocess(raw_image)
        foreground = (cv2.bitwise_not(binary) > 127).astype(np.uint8)
        return skeletonize(foreground).astype(np.uint8) * 255

    def _classify_point(self, skeleton: np.ndarray, pt: Point) -> JunctionType:
        """
        Count the skeleton branches leaving a disc centered on `pt`:
        1 exit -> END, 2 exits -> possibly L (angle test), 3 exits -> T,
        anything else -> END.
        """
        half = self.window_size // 2
        cx, cy = pt.coords.x, pt.coords.y

        # Clip the analysis window to the image bounds.
        row_lo = max(0, cy - half)
        row_hi = min(skeleton.shape[0], cy + half + 1)
        col_lo = max(0, cx - half)
        col_hi = min(skeleton.shape[1], cx + half + 1)

        window = (skeleton[row_lo:row_hi, col_lo:col_hi] > 127).astype(np.uint8)

        # Keep only skeleton pixels inside a filled disc around the point.
        disc = np.zeros_like(window, dtype=np.uint8)
        cv2.circle(disc, (cx - col_lo, cy - row_lo), self.radius, 1, -1)
        masked = window & disc

        # Each 8-connected component inside the disc is one exiting branch.
        branches = label(masked, connectivity=2)
        exit_count = branches.max()

        if exit_count == 2:
            return self._check_angle_for_L(branches)
        if exit_count == 3:
            return JunctionType.T
        return JunctionType.END

    def _check_angle_for_L(self, labeled_region: np.ndarray) -> JunctionType:
        """
        Measure the heading between the first two pixels of branch #1; an
        acute angle inside [angle_threshold_lb, angle_threshold_ub] counts
        as an 'L', anything else falls back to END.
        """
        pixels = np.argwhere(labeled_region == 1)
        if len(pixels) < 2:
            return JunctionType.END

        (r1, c1), (r2, c2) = pixels[0], pixels[1]
        heading = math.degrees(math.atan2(r2 - r1, c2 - c1))
        acute = min(abs(heading), 180 - abs(heading))

        if self.angle_threshold_lb <= acute <= self.angle_threshold_ub:
            return JunctionType.L
        return JunctionType.END

    # -----------------------------------------
    #  EXTRA STEP: Create Junction objects
    # -----------------------------------------
    def _record_junctions_in_context(self, context: DetectionContext):
        """
        Build one Junction per context point (all types kept — filter to
        T/L here if desired) and attach the ids of every line touching it.
        """
        for point in context.points.values():
            junction = Junction(
                center=point.coords,
                junction_type=point.type,
            )

            # Collect the ids of all lines that share this endpoint.
            junction.connected_lines = [
                ln.id
                for ln in context.lines.values()
                if point.id in (ln.start.id, ln.end.id)
            ]

            context.add_junction(junction)

import json
import uuid

class SymbolDetector(BaseDetector):
    """
    A placeholder detector that reads precomputed symbol data
    from a JSON file and populates the context with Symbol objects.
    """

    def __init__(self,
                 config: SymbolConfig,
                 debug_handler: Optional[DebugHandler] = None,
                 symbol_json_path: str = "./symbols.json"):
        """
        Args:
            config: symbol-detection configuration.
            debug_handler: optional handler used to save debug artifacts.
            symbol_json_path: path to the precomputed symbol JSON file.
        """
        super().__init__(config=config, debug_handler=debug_handler)
        self.symbol_json_path = symbol_json_path

    def _load_model(self, model_path: str):
        """Not loading an actual model; symbol data is read from JSON."""
        return None

    def detect(self,
               image: np.ndarray,
               context: DetectionContext,
               *args,
               **kwargs) -> None:
        """
        Read symbol records from the JSON file and add one Symbol to the
        context per entry under the "detections" key. A missing or
        malformed file results in no symbols (logged as a debug artifact).
        """
        symbol_data = self._load_json_data(self.symbol_json_path)
        if not symbol_data:
            return

        for record in symbol_data.get("detections", []):
            context.add_symbol(self._parse_symbol_record(record))

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """Identity — the image is not used by this JSON-backed detector."""
        return image

    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        """Identity — the image is not used by this JSON-backed detector."""
        return image

    # --------------
    # HELPER METHODS
    # --------------
    def _load_json_data(self, json_path: str) -> dict:
        """
        Load the symbol JSON file. On a missing, unreadable, or malformed
        file, save an error artifact and return {} so detection degrades
        gracefully instead of crashing mid-pipeline.
        """
        if not os.path.exists(json_path):
            self.debug_handler.save_artifact(name="symbol_error",
                                             data=b"Missing symbol JSON file",
                                             extension="txt")
            return {}

        try:
            with open(json_path, "r", encoding="utf-8") as f:
                return json.load(f)
        except (json.JSONDecodeError, OSError):
            # Malformed/unreadable file: degrade the same way as a missing one.
            self.debug_handler.save_artifact(name="symbol_error",
                                             data=b"Unreadable symbol JSON file",
                                             extension="txt")
            return {}

    def _parse_symbol_record(self, record: dict) -> Symbol:
        """
        Build a Symbol object from one JSON record.

        Falls back to a fresh UUID when the record carries no symbol_id
        (mirrors TagDetector's Tag.id handling) so ids never collide on
        the empty string.
        """
        bbox_list = record.get("bbox", [0, 0, 0, 0])
        bbox_obj = BBox(
            xmin=bbox_list[0],
            ymin=bbox_list[1],
            xmax=bbox_list[2],
            ymax=bbox_list[3]
        )

        # Integer center of the bounding box.
        center_coords = Coordinates(
            x=(bbox_obj.xmin + bbox_obj.xmax) // 2,
            y=(bbox_obj.ymin + bbox_obj.ymax) // 2
        )

        return Symbol(
            id=record.get("symbol_id") or str(uuid.uuid4()),
            class_id=record.get("class_id", -1),
            original_label=record.get("original_label", ""),
            category=record.get("category", ""),
            type=record.get("type", ""),
            label=record.get("label", ""),
            bbox=bbox_obj,
            center=center_coords,
            confidence=record.get("confidence", 0.95),
            model_source=record.get("model_source", ""),
            connections=[]
        )

class TagDetector(BaseDetector):
    """
    A placeholder detector that reads precomputed tag data
    from a JSON file and populates the context with Tag objects.
    """

    def __init__(self,
                 config: TagConfig,
                 debug_handler: Optional[DebugHandler] = None,
                 tag_json_path: str = "./tags.json"):
        """
        Args:
            config: tag-detection configuration.
            debug_handler: optional handler used to save debug artifacts.
            tag_json_path: path to the precomputed tag JSON file.
        """
        super().__init__(config=config, debug_handler=debug_handler)
        self.tag_json_path = tag_json_path

    def _load_model(self, model_path: str):
        """Not loading an actual model; tag data is read from JSON."""
        return None

    def detect(self,
               image: np.ndarray,
               context: DetectionContext,
               *args,
               **kwargs) -> None:
        """
        Read tag records from the JSON file and add one Tag to the context
        per entry under the "detections" key. A missing or malformed file
        results in no tags (logged as a debug artifact).
        """
        tag_data = self._load_json_data(self.tag_json_path)
        if not tag_data:
            return

        for record in tag_data.get("detections", []):
            context.add_tag(self._parse_tag_record(record))

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """Identity — the image is not used by this JSON-backed detector."""
        return image

    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        """Identity — the image is not used by this JSON-backed detector."""
        return image

    # --------------
    # HELPER METHODS
    # --------------
    def _load_json_data(self, json_path: str) -> dict:
        """
        Load the tag JSON file. On a missing, unreadable, or malformed
        file, save an error artifact and return {} so detection degrades
        gracefully instead of crashing mid-pipeline.
        """
        if not os.path.exists(json_path):
            self.debug_handler.save_artifact(name="tag_error",
                                             data=b"Missing tag JSON file",
                                             extension="txt")
            return {}

        try:
            with open(json_path, "r", encoding="utf-8") as f:
                return json.load(f)
        except (json.JSONDecodeError, OSError):
            # Malformed/unreadable file: degrade the same way as a missing one.
            self.debug_handler.save_artifact(name="tag_error",
                                             data=b"Unreadable tag JSON file",
                                             extension="txt")
            return {}

    def _parse_tag_record(self, record: dict) -> Tag:
        """
        Build a Tag object from one JSON record; missing fields fall back
        to sensible defaults (a fresh UUID id, confidence 1.0, etc.).
        """
        bbox_list = record.get("bbox", [0, 0, 0, 0])
        bbox_obj = BBox(
            xmin=bbox_list[0],
            ymin=bbox_list[1],
            xmax=bbox_list[2],
            ymax=bbox_list[3]
        )

        return Tag(
            text=record.get("text", ""),
            bbox=bbox_obj,
            confidence=record.get("confidence", 1.0),
            source=record.get("source", ""),
            text_type=record.get("text_type", "Unknown"),
            id=record.get("id", str(uuid.uuid4())),
            font_size=record.get("font_size", 12),
            rotation=record.get("rotation", 0.0)
        )