# NOTE(review): removed scraped web-page banner ("Spaces:" / "Runtime error") —
# it was a page-capture artifact, not part of the source module.
| import os | |
| import math | |
| import torch | |
| import cv2 | |
| import numpy as np | |
| from typing import List, Optional, Tuple, Dict | |
| from dataclasses import replace | |
| from math import sqrt | |
| import json | |
| import uuid | |
| from pathlib import Path | |
| # Base classes and utilities | |
| from base import BaseDetector | |
| from detection_schema import DetectionContext | |
| from utils import DebugHandler | |
| from config import SymbolConfig, TagConfig, LineConfig, PointConfig, JunctionConfig | |
| # DeepLSD model for line detection | |
| from deeplsd.models.deeplsd_inference import DeepLSD | |
| from ultralytics import YOLO | |
| # Detection schema: dataclasses for different objects | |
| from detection_schema import ( | |
| BBox, | |
| Coordinates, | |
| Point, | |
| Line, | |
| Symbol, | |
| Tag, | |
| SymbolType, | |
| LineStyle, | |
| ConnectionType, | |
| JunctionType, | |
| Junction | |
| ) | |
| # Skeletonization and label processing for junction detection | |
| from skimage.morphology import skeletonize | |
| from skimage.measure import label | |
| import os | |
| import cv2 | |
| import torch | |
| import numpy as np | |
| from dataclasses import replace | |
| from typing import List, Optional | |
| from detection_utils import robust_merge_lines | |
class LineDetector(BaseDetector):
    """
    DeepLSD-based line detection with patch-based tiling and global merging.

    Pipeline:
      1. Skeletonize the drawing so strokes are ~1 px wide.
      2. Optionally white-out masked regions and binarize.
      3. Tile the image into overlapping ``patch_size`` patches and run
         DeepLSD on each patch.
      4. Re-map patch-local segments to global coordinates.
      5. Merge near-collinear/near-touching segments and emit ``Line``
         objects into the shared ``DetectionContext``.
    """

    def __init__(self,
                 config: LineConfig,
                 model_path: str,
                 model_config: dict,
                 device: torch.device,
                 debug_handler: DebugHandler = None):
        """
        Args:
            config: Line-detection configuration.
            model_path: Path to the DeepLSD checkpoint file.
            model_config: Dict passed to the DeepLSD constructor.
            device: Requested device (currently overridden, see below).
            debug_handler: Optional artifact sink for debug images.
        """
        super().__init__(config, debug_handler)
        # FIX: _preprocess() saves a debug artifact unconditionally, so a None
        # handler would raise AttributeError. Fall back to a default handler,
        # matching the pattern already used by JunctionDetector.
        self.debug_handler = debug_handler or DebugHandler()
        # NOTE(review): the `device` argument is deliberately ignored in favor
        # of auto-selecting MPS (Apple Silicon) > CUDA > CPU — confirm callers
        # expect this override.
        if torch.backends.mps.is_available():
            self.device = torch.device("mps")
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        self.model_path = model_path
        self.model_config = model_config
        self.model = self._load_model(model_path)
        # Patch tiling parameters (pixels).
        self.patch_size = 512
        self.overlap = 10
        # Thresholds for the global segment merge.
        self.angle_thresh = 5.0  # degrees
        self.dist_thresh = 5.0   # pixels

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """
        Thicken strokes slightly, then skeletonize to 1-px-wide lines.

        Assumes a dark-on-light drawing: the dilated image is inverted before
        skeletonization so strokes become the foreground.
        """
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
        dilated = cv2.dilate(image, kernel, iterations=2)
        skeleton = cv2.bitwise_not(dilated)
        # skeletonize() expects a binary {0, 1} array.
        skeleton = skeletonize(skeleton // 255)
        skeleton = (skeleton * 255).astype(np.uint8)
        # NOTE(review): a 1x1 structuring element makes this dilation a no-op;
        # kept for behavioral parity — likely intended to be larger.
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 1))
        clean_image = cv2.dilate(skeleton, kernel, iterations=5)
        self.debug_handler.save_artifact(name="skeleton", data=clean_image, extension="png")
        return clean_image

    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        # FIX: previously returned None despite the np.ndarray annotation;
        # pass the image through unchanged like the other detectors do.
        return image

    # -------------------------------------
    # 1) Load Model
    # -------------------------------------
    def _load_model(self, model_path: str) -> DeepLSD:
        """
        Load the DeepLSD checkpoint onto ``self.device`` in eval mode.

        Raises:
            FileNotFoundError: if ``model_path`` does not exist.
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")
        ckpt = torch.load(model_path, map_location=self.device)
        model = DeepLSD(self.model_config)
        model.load_state_dict(ckpt["model"])
        return model.to(self.device).eval()

    # -------------------------------------
    # 2) Main Detection Pipeline
    # -------------------------------------
    def detect(self,
               image: np.ndarray,
               context: DetectionContext,
               mask_coords: Optional[List[BBox]] = None,
               *args,
               **kwargs) -> None:
        """
        Run the full line-detection pipeline and add results to ``context``.

        Steps:
          - Skeletonize, then optional mask + threshold
          - Tile into overlapping patches
          - For each patch => run DeepLSD => re-map lines to global coords
          - Merge lines robustly
          - Build final Line objects => add to context
        """
        mask_coords = mask_coords or []
        skeleton = self._preprocess(image)
        # (A) Optional mask + threshold to get a clean binary input.
        processed_img = self._apply_mask_and_threshold(skeleton, mask_coords)
        # (B) Patch-based inference => collect raw lines in global coords.
        all_lines = self._detect_in_patches(processed_img)
        # (C) Merge the lines in the global coordinate system.
        merged_line_segments = robust_merge_lines(
            all_lines,
            angle_thresh=self.angle_thresh,
            dist_thresh=self.dist_thresh
        )
        # (D) Convert merged segments => final Line objects, add to context.
        for (x1, y1, x2, y2) in merged_line_segments:
            context.add_line(self._create_line_object(x1, y1, x2, y2))

    # -------------------------------------
    # 3) Optional Mask + Threshold
    # -------------------------------------
    def _apply_mask_and_threshold(self, image: np.ndarray, mask_coords: List[BBox]) -> np.ndarray:
        """White out rectangular areas, convert to grayscale if needed, then binarize."""
        masked = image.copy()
        for bbox in mask_coords:
            x1, y1 = int(bbox.xmin), int(bbox.ymin)
            x2, y2 = int(bbox.xmax), int(bbox.ymax)
            # Filled white rectangle erases the masked region.
            cv2.rectangle(masked, (x1, y1), (x2, y2), (255, 255, 255), -1)
        # Collapse to a single channel before thresholding.
        if len(masked.shape) == 3:
            masked_gray = cv2.cvtColor(masked, cv2.COLOR_BGR2GRAY)
        else:
            masked_gray = masked
        # Fixed mid-gray threshold; adjust if the source contrast differs.
        binary_img = cv2.threshold(masked_gray, 127, 255, cv2.THRESH_BINARY)[1]
        return binary_img

    # -------------------------------------
    # 4) Patch-Based Inference
    # -------------------------------------
    def _detect_in_patches(self, processed_img: np.ndarray) -> List[tuple]:
        """
        Break the image into overlapping patches, run DeepLSD on each,
        map local lines to global coordinates, and return them as
        (x1, y1, x2, y2) tuples. Windows at the right/bottom edges are
        shifted back so every patch is full-size when the image allows it.
        """
        patch_size = self.patch_size
        overlap = self.overlap
        height, width = processed_img.shape[:2]
        step = patch_size - overlap
        all_lines = []
        for y in range(0, height, step):
            patch_ymax = min(y + patch_size, height)
            # Shift the window up if it would be cut short at the bottom edge.
            patch_ymin = patch_ymax - patch_size if (patch_ymax - y) < patch_size else y
            if patch_ymin < 0:
                patch_ymin = 0
            for x in range(0, width, step):
                patch_xmax = min(x + patch_size, width)
                patch_xmin = patch_xmax - patch_size if (patch_xmax - x) < patch_size else x
                if patch_xmin < 0:
                    patch_xmin = 0
                patch = processed_img[patch_ymin:patch_ymax, patch_xmin:patch_xmax]
                local_lines = self._run_model_inference(patch)
                # Convert local endpoint pairs to global coordinates.
                for ln in local_lines:
                    (x1_local, y1_local), (x2_local, y2_local) = ln
                    gx1 = x1_local + patch_xmin
                    gy1 = y1_local + patch_ymin
                    gx2 = x2_local + patch_xmin
                    gy2 = y2_local + patch_ymin
                    # Drop segments with any endpoint out of bounds.
                    if 0 <= gx1 < width and 0 <= gx2 < width and 0 <= gy1 < height and 0 <= gy2 < height:
                        all_lines.append((gx1, gy1, gx2, gy2))
        return all_lines

    # -------------------------------------
    # 5) Model Inference (Single Patch)
    # -------------------------------------
    def _run_model_inference(self, patch_img: np.ndarray) -> np.ndarray:
        """
        Run DeepLSD on a single (already masked/thresholded) patch.

        Args:
            patch_img: 2-D array of shape [patchH, patchW], values 0-255.

        Returns:
            Line segments of shape [N, 2, 2] (pairs of (x, y) endpoints),
            taken from the model's "lines" output for the single batch item.
        """
        # Add batch and channel dims, scale to [0, 1].
        inp = torch.tensor(patch_img, dtype=torch.float32, device=self.device)[None, None] / 255.0
        with torch.no_grad():
            output = self.model({"image": inp})
            lines = output["lines"][0]  # shape (N, 2, 2)
        return lines

    # -------------------------------------
    # 6) Convert Merged Segments => Line Objects
    # -------------------------------------
    def _create_line_object(self, x1: float, y1: float, x2: float, y2: float) -> Line:
        """
        Create a minimal ``Line`` object from final merged coordinates.
        Each endpoint gets a small +/- margin bounding box and defaults to
        JunctionType.END (refined later by JunctionDetector).
        """
        margin = 2
        # Start point.
        start_pt = Point(
            coords=Coordinates(int(x1), int(y1)),
            bbox=BBox(
                xmin=int(x1 - margin),
                ymin=int(y1 - margin),
                xmax=int(x1 + margin),
                ymax=int(y1 + margin)
            ),
            type=JunctionType.END,
            confidence=1.0
        )
        # End point.
        end_pt = Point(
            coords=Coordinates(int(x2), int(y2)),
            bbox=BBox(
                xmin=int(x2 - margin),
                ymin=int(y2 - margin),
                xmax=int(x2 + margin),
                ymax=int(y2 + margin)
            ),
            type=JunctionType.END,
            confidence=1.0
        )
        # Axis-aligned bounding box of the whole segment.
        x_min = int(min(x1, x2))
        x_max = int(max(x1, x2))
        y_min = int(min(y1, y2))
        y_max = int(max(y1, y2))
        return Line(
            start=start_pt,
            end=end_pt,
            bbox=BBox(xmin=x_min, ymin=y_min, xmax=x_max, ymax=y_max),
            style=LineStyle(
                connection_type=ConnectionType.SOLID,
                stroke_width=2,
                color="#000000"
            ),
            confidence=0.9,
            topological_links=[]
        )
class PointDetector(BaseDetector):
    """
    Unifies line endpoints that lie close together.

    Reads every line in the context, groups endpoints that fall within
    ``threshold_distance`` of each other, and rewires the lines so all
    endpoints of a group reference one shared ``Point`` instance. The
    context's point dictionary is then rebuilt from the surviving points.
    """

    def __init__(self,
                 config: PointConfig,
                 debug_handler: DebugHandler = None):
        super().__init__(config, debug_handler)  # purely geometric, no model
        self.threshold_distance = config.threshold_distance

    def _load_model(self, model_path: str):
        """Endpoint unification needs no learned model."""
        return None

    def detect(self, image: np.ndarray, context: DetectionContext, *args, **kwargs) -> None:
        """
        Cluster all line endpoints within ``threshold_distance`` and make
        clustered endpoints share a single canonical ``Point``, then refresh
        ``context.points`` from the rewired lines.
        """
        # Gather every endpoint (start and end) of every line.
        endpoints = [p for ln in context.lines.values() for p in (ln.start, ln.end)]

        # Map each duplicate point id to the canonical point of its cluster
        # (the first point placed into the cluster).
        alias = {}
        for group in self._cluster_points(endpoints, self.threshold_distance):
            canonical = group[0]
            for duplicate in group[1:]:
                alias[duplicate.id] = canonical

        # Rewire lines so clustered endpoints reference the canonical point.
        for ln in context.lines.values():
            ln.start = alias.get(ln.start.id, ln.start)
            ln.end = alias.get(ln.end.id, ln.end)

        # Rebuild the context's point table from what the lines now reference;
        # this drops the duplicates that were aliased away.
        context.points = {
            p.id: p for ln in context.lines.values() for p in (ln.start, ln.end)
        }

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """Image is passed through untouched."""
        return image

    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        """Image is passed through untouched."""
        return image

    # ----------------------
    # HELPER: clustering
    # ----------------------
    def _cluster_points(self, points: List[Point], threshold: float) -> List[List[Point]]:
        """
        Greedy single-pass clustering: each point joins the first existing
        cluster whose representative (its first member) is strictly closer
        than ``threshold``; otherwise it starts a new cluster.

        Returns:
            List of clusters, each a list of Points.
        """
        groups: List[List[Point]] = []
        for candidate in points:
            home = next(
                (g for g in groups if self._distance(candidate, g[0]) < threshold),
                None,
            )
            if home is None:
                groups.append([candidate])
            else:
                home.append(candidate)
        return groups

    def _distance(self, a: Point, b: Point) -> float:
        """Euclidean distance between the two points' coordinates."""
        dx = a.coords.x - b.coords.x
        dy = a.coords.y - b.coords.y
        return sqrt(dx * dx + dy * dy)
class JunctionDetector(BaseDetector):
    """
    Classifies points as 'END', 'L', or 'T' by skeletonizing the binarized
    image and counting how many skeleton branches leave a small disc around
    each point. Also creates a ``Junction`` per point in the context, linked
    to the lines that share that point.
    """

    def __init__(self, config: JunctionConfig, debug_handler: DebugHandler = None):
        super().__init__(config, debug_handler)  # no real model path
        self.window_size = config.window_size          # side of the square crop around each point
        self.radius = config.radius                    # disc radius used to count branch exits
        self.angle_threshold_lb = config.angle_threshold_lb
        self.angle_threshold_ub = config.angle_threshold_ub
        self.debug_handler = debug_handler or DebugHandler()

    def _load_model(self, model_path: str):
        """Not loading any actual model, just skeleton logic."""
        return None

    def detect(self,
               image: np.ndarray,
               context: DetectionContext,
               *args,
               **kwargs) -> None:
        """
        1) Convert to binary & skeletonize
        2) Classify each point in the context (END / L / T)
        3) Create a Junction for each point and store it in context.junctions,
           with 'connected_lines' referencing lines that share this point.
        """
        skeleton = self._create_skeleton(image)
        for pt in context.points.values():
            pt.type = self._classify_point(skeleton, pt)
        self._record_junctions_in_context(context)

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """Grayscale (if needed) + simple binary threshold at mid-gray."""
        if image.ndim == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image
        _, bin_image = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
        return bin_image

    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        """Image is passed through untouched."""
        return image

    def _create_skeleton(self, raw_image: np.ndarray) -> np.ndarray:
        """Binarize, invert (so strokes are foreground), and skeletonize."""
        bin_img = self._preprocess(raw_image)
        inv = cv2.bitwise_not(bin_img)
        # skeletonize() wants a binary {0, 1} array.
        inv_bool = (inv > 127).astype(np.uint8)
        skel = skeletonize(inv_bool).astype(np.uint8) * 255
        return skel

    def _classify_point(self, skeleton: np.ndarray, pt: Point) -> JunctionType:
        """
        Look at the skeleton inside a disc of ``self.radius`` centered on
        ``pt`` and classify by the number of connected skeleton components
        ("exits") found there: 1 => END, 2 => L-candidate (angle-checked),
        3 or more => T.
        """
        half_w = self.window_size // 2
        x, y = pt.coords.x, pt.coords.y
        # Clip the square crop to the image bounds.
        top = max(0, y - half_w)
        bottom = min(skeleton.shape[0], y + half_w + 1)
        left = max(0, x - half_w)
        right = min(skeleton.shape[1], x + half_w + 1)
        patch = (skeleton[top:bottom, left:right] > 127).astype(np.uint8)
        # Restrict to a filled disc around the (patch-local) point.
        circle_mask = np.zeros_like(patch, dtype=np.uint8)
        local_cx = x - left
        local_cy = y - top
        cv2.circle(circle_mask, (local_cx, local_cy), self.radius, 1, -1)
        circle_skel = patch & circle_mask
        # Each 8-connected component inside the disc is one branch "exit".
        labeled = label(circle_skel, connectivity=2)
        num_exits = labeled.max()
        if num_exits == 2:
            return self._check_angle_for_L(labeled, (local_cx, local_cy))
        if num_exits >= 3:
            # FIX: was `== 3`, which classified 4-way crossings as END.
            # Anything with three or more exits is at least a T junction.
            return JunctionType.T
        return JunctionType.END

    def _check_angle_for_L(self, labeled_region: np.ndarray, center: Tuple[int, int]) -> JunctionType:
        """
        Decide between 'L' and 'END' when exactly two branches leave the disc.

        FIX: the previous version measured the orientation of a single branch
        using two arbitrary pixels of region 1, which says nothing about the
        angle *between* the branches. Now we take the centroid of each labeled
        branch, form vectors from the disc center, and measure the angle
        between those vectors (0-180 degrees).

        NOTE(review): the configured [angle_threshold_lb, angle_threshold_ub]
        window is now interpreted as bounds on this inter-branch angle —
        confirm the config values match (an 'L' corner is ~90 degrees).
        """
        cx, cy = center
        vectors = []
        for branch in (1, 2):
            ys, xs = np.nonzero(labeled_region == branch)
            if xs.size == 0:
                # A branch vanished (degenerate region): treat as an endpoint.
                return JunctionType.END
            vectors.append((float(xs.mean()) - cx, float(ys.mean()) - cy))
        (dx1, dy1), (dx2, dy2) = vectors
        angle = abs(math.degrees(math.atan2(dy1, dx1) - math.atan2(dy2, dx2)))
        if angle > 180.0:
            angle = 360.0 - angle
        if self.angle_threshold_lb <= angle <= self.angle_threshold_ub:
            return JunctionType.L
        return JunctionType.END

    # -----------------------------------------
    # EXTRA STEP: Create Junction objects
    # -----------------------------------------
    def _record_junctions_in_context(self, context: DetectionContext):
        """
        Create a Junction object for each point in context.points and link it
        to the ids of the lines sharing that point.

        To keep only T/L junctions, filter on pt.type before adding.
        """
        # PERF: index lines by endpoint id once, instead of scanning every
        # line for every point (was O(points * lines)).
        lines_by_point = {}
        for ln in context.lines.values():
            lines_by_point.setdefault(ln.start.id, []).append(ln.id)
            if ln.end.id != ln.start.id:
                lines_by_point.setdefault(ln.end.id, []).append(ln.id)
        for pt in context.points.values():
            jn = Junction(
                center=pt.coords,
                junction_type=pt.type,
            )
            jn.connected_lines = lines_by_point.get(pt.id, [])
            context.add_junction(jn)
# NOTE(review): json and uuid are already imported at the top of the file;
# these re-imports are harmless no-ops kept so this section stays
# self-contained if the classes below are ever split out.
import json
import uuid
class SymbolDetector(BaseDetector):
    """
    Stand-in detector: instead of running a model, it loads precomputed
    symbol detections from a JSON file and registers them in the context.
    """

    def __init__(self,
                 config: SymbolConfig,
                 debug_handler: Optional[DebugHandler] = None,
                 symbol_json_path: str = "./symbols.json"):
        super().__init__(config=config, debug_handler=debug_handler)
        self.symbol_json_path = symbol_json_path

    def _load_model(self, model_path: str):
        """Nothing to load; symbol data comes from JSON."""
        return None

    def detect(self,
               image: np.ndarray,
               context: DetectionContext,
               *args,
               **kwargs) -> None:
        """Read the symbol JSON file and add one Symbol per detection record."""
        payload = self._load_json_data(self.symbol_json_path)
        if not payload:
            return
        # Detection records live under the "detections" key.
        for entry in payload.get("detections", []):
            context.add_symbol(self._parse_symbol_record(entry))

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """Image is passed through untouched."""
        return image

    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        """Image is passed through untouched."""
        return image

    # --------------
    # HELPER METHODS
    # --------------
    def _load_json_data(self, json_path: str) -> dict:
        """Return the parsed JSON payload, or {} (plus a debug artifact) when the file is absent."""
        if os.path.exists(json_path):
            with open(json_path, "r", encoding="utf-8") as handle:
                return json.load(handle)
        self.debug_handler.save_artifact(name="symbol_error",
                                         data=b"Missing symbol JSON file",
                                         extension="txt")
        return {}

    def _parse_symbol_record(self, record: dict) -> Symbol:
        """Build a Symbol object from one JSON detection record."""
        raw_box = record.get("bbox", [0, 0, 0, 0])
        box = BBox(
            xmin=raw_box[0],
            ymin=raw_box[1],
            xmax=raw_box[2],
            ymax=raw_box[3]
        )
        # Integer midpoint of the bounding box.
        center = Coordinates(
            x=(box.xmin + box.xmax) // 2,
            y=(box.ymin + box.ymax) // 2
        )
        return Symbol(
            id=record.get("symbol_id", ""),
            class_id=record.get("class_id", -1),
            original_label=record.get("original_label", ""),
            category=record.get("category", ""),
            type=record.get("type", ""),
            label=record.get("label", ""),
            bbox=box,
            center=center,
            confidence=record.get("confidence", 0.95),
            model_source=record.get("model_source", ""),
            connections=[]
        )
class TagDetector(BaseDetector):
    """
    Stand-in detector: instead of running a model, it loads precomputed
    tag (text) detections from a JSON file and registers them in the context.
    """

    def __init__(self,
                 config: TagConfig,
                 debug_handler: Optional[DebugHandler] = None,
                 tag_json_path: str = "./tags.json"):
        super().__init__(config=config, debug_handler=debug_handler)
        self.tag_json_path = tag_json_path

    def _load_model(self, model_path: str):
        """Nothing to load; tag data comes from JSON."""
        return None

    def detect(self,
               image: np.ndarray,
               context: DetectionContext,
               *args,
               **kwargs) -> None:
        """Read the tag JSON file and add one Tag per detection record."""
        payload = self._load_json_data(self.tag_json_path)
        if not payload:
            return
        # Detection records live under the "detections" key.
        for entry in payload.get("detections", []):
            context.add_tag(self._parse_tag_record(entry))

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """Image is passed through untouched."""
        return image

    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        """Image is passed through untouched."""
        return image

    # --------------
    # HELPER METHODS
    # --------------
    def _load_json_data(self, json_path: str) -> dict:
        """Return the parsed JSON payload, or {} (plus a debug artifact) when the file is absent."""
        if os.path.exists(json_path):
            with open(json_path, "r", encoding="utf-8") as handle:
                return json.load(handle)
        self.debug_handler.save_artifact(name="tag_error",
                                         data=b"Missing tag JSON file",
                                         extension="txt")
        return {}

    def _parse_tag_record(self, record: dict) -> Tag:
        """Build a Tag object from one JSON detection record."""
        raw_box = record.get("bbox", [0, 0, 0, 0])
        box = BBox(
            xmin=raw_box[0],
            ymin=raw_box[1],
            xmax=raw_box[2],
            ymax=raw_box[3]
        )
        return Tag(
            text=record.get("text", ""),
            bbox=box,
            confidence=record.get("confidence", 1.0),
            source=record.get("source", ""),
            text_type=record.get("text_type", "Unknown"),
            id=record.get("id", str(uuid.uuid4())),
            font_size=record.get("font_size", 12),
            rotation=record.get("rotation", 0.0)
        )