import functools
import logging
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch


def get_slice_bboxes(
    image_height: int,
    image_width: int,
    slice_height: int = 640,
    slice_width: int = 640,
    overlap_height_ratio: float = 0.2,
    overlap_width_ratio: float = 0.2,
) -> List[List[int]]:
    """Compute overlapping slice windows that tile an image.

    Windows are produced row by row (top to bottom, left to right). Consecutive
    windows overlap by ``int(slice_size * overlap_ratio)`` pixels. The last
    window of each row/column is shifted back so it ends exactly on the image
    border; it keeps its full size unless the image is smaller than one slice.

    Args:
        image_height: Image height in pixels.
        image_width: Image width in pixels.
        slice_height: Height of each slice window; must be positive.
        slice_width: Width of each slice window; must be positive.
        overlap_height_ratio: Vertical overlap as a fraction of
            ``slice_height``; must be < 1.
        overlap_width_ratio: Horizontal overlap as a fraction of
            ``slice_width``; must be < 1.

    Returns:
        List of ``[x_min, y_min, x_max, y_max]`` windows. The list is empty
        when either image dimension is <= 0.

    Raises:
        ValueError: If a slice dimension is not positive or an overlap ratio
            is >= 1. Either would stop the window from ever advancing, which
            previously caused an infinite loop.
    """
    if slice_height <= 0 or slice_width <= 0:
        raise ValueError(
            f"slice dimensions must be positive, got "
            f"slice_height={slice_height}, slice_width={slice_width}"
        )
    if overlap_height_ratio >= 1 or overlap_width_ratio >= 1:
        raise ValueError(
            f"overlap ratios must be < 1, got "
            f"overlap_height_ratio={overlap_height_ratio}, "
            f"overlap_width_ratio={overlap_width_ratio}"
        )

    slice_bboxes: List[List[int]] = []
    y_max = y_min = 0
    y_overlap = int(slice_height * overlap_height_ratio)
    x_overlap = int(slice_width * overlap_width_ratio)

    while y_max < image_height:
        x_min = x_max = 0
        y_max = y_min + slice_height
        while x_max < image_width:
            x_max = x_min + slice_width
            # Pull the final row/column back inside the image instead of
            # emitting a truncated window.
            if y_max > image_height:
                y_max = image_height
                y_min = max(0, image_height - slice_height)
            if x_max > image_width:
                x_max = image_width
                x_min = max(0, image_width - slice_width)
            slice_bboxes.append([x_min, y_min, x_max, y_max])
            x_min = x_max - x_overlap
        y_min = y_max - y_overlap
    return slice_bboxes


def slice_image(
    image: np.ndarray, slice_bboxes: List[List[int]]
) -> List[np.ndarray]:
    """Crop ``image`` to each ``[x_min, y_min, x_max, y_max]`` window.

    Crops index the first two axes as ``image[y_min:y_max, x_min:x_max]``.
    They are NumPy slices (views), not copies.
    """
    return [
        image[ymin:ymax, xmin:xmax]
        for xmin, ymin, xmax, ymax in slice_bboxes
    ]


def shift_bboxes(
    bboxes: List[List[float]], slice_coords: List[int]
) -> List[List[float]]:
    """Translate boxes from slice-local to full-image coordinates.

    Args:
        bboxes: Boxes as ``[x1, y1, x2, y2]`` relative to the slice. Only the
            first four elements of each box are used and returned.
        slice_coords: The slice window ``[x_min, y_min, x_max, y_max]``. Only
            its top-left corner is used as the offset.

    Returns:
        New list of shifted ``[x1, y1, x2, y2]`` boxes.
    """
    shift_x, shift_y = slice_coords[0], slice_coords[1]
    return [
        [box[0] + shift_x, box[1] + shift_y, box[2] + shift_x, box[3] + shift_y]
        for box in bboxes
    ]


@functools.lru_cache(maxsize=None)
def _load_torchvision_batched_nms():
    """Return ``torchvision.ops.batched_nms``, or ``None`` if unavailable.

    Cached so that a missing torchvision does not trigger a fresh (failing)
    import attempt on every call. The import stays lazy, so there is no cost
    at module import time.
    """
    try:
        from torchvision.ops import batched_nms as tv_batched_nms
    except ImportError:
        return None
    return tv_batched_nms


def _greedy_nms(
    boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float
) -> List[int]:
    """Run greedy single-class NMS; return kept row positions, best score first."""
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    order = torch.argsort(scores, descending=True)
    kept: List[int] = []
    while order.numel() > 0:
        best = order[0]
        kept.append(int(best))
        rest = order[1:]
        if rest.numel() == 0:
            break
        x1 = torch.max(boxes[best, 0], boxes[rest, 0])
        y1 = torch.max(boxes[best, 1], boxes[rest, 1])
        x2 = torch.min(boxes[best, 2], boxes[rest, 2])
        y2 = torch.min(boxes[best, 3], boxes[rest, 3])
        inter = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
        # Exact IoU (no epsilon): an additive epsilon distorts IoU for small or
        # normalized-coordinate boxes.
        iou = inter / (areas[best] + areas[rest] - inter)
        # Suppress only IoU strictly above the threshold (torchvision's rule).
        # A NaN from two zero-area boxes compares False and is therefore kept,
        # also as torchvision does.
        order = rest[~(iou > iou_threshold)]
    return kept


def batched_nms(
    boxes: torch.Tensor,
    scores: torch.Tensor,
    idxs: torch.Tensor,
    iou_threshold: float = 0.5,
) -> torch.Tensor:
    """Perform class-aware non-maximum suppression.

    Boxes with different ``idxs`` values never suppress each other. Uses
    ``torchvision.ops.batched_nms`` when torchvision is installed. Otherwise it
    falls back to a pure-PyTorch greedy NMS with the same contract.

    Args:
        boxes: Boxes as ``[x1, y1, x2, y2]`` rows (assumed shape ``(N, 4)``).
        scores: Per-box scores, shape ``(N,)``.
        idxs: Per-box category labels, shape ``(N,)``.
        iou_threshold: Boxes whose IoU with a higher-scoring box of the same
            category is greater than this are discarded.

    Returns:
        ``int64`` tensor of kept indices, sorted by decreasing score, on
        ``boxes.device``.
    """
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)

    tv_batched_nms = _load_torchvision_batched_nms()
    if tv_batched_nms is not None:
        return tv_batched_nms(boxes, scores, idxs, iou_threshold)

    keep: List[int] = []
    for label in idxs.unique():
        member_idx = torch.where(idxs == label)[0]
        local_keep = _greedy_nms(
            boxes[member_idx], scores[member_idx], iou_threshold
        )
        keep.extend(int(member_idx[i]) for i in local_keep)

    keep_t = torch.tensor(keep, dtype=torch.int64, device=boxes.device)
    # Match torchvision's contract: order by decreasing score across all
    # classes, not grouped by class.
    return keep_t[torch.argsort(scores[keep_t], descending=True)]