perception / utils /tiling.py
Zhen Ye
feat: Implement SAHI Tiling for 4K video detection
af29397
import numpy as np
import torch
import logging
from typing import List, Tuple, Dict, Any, Optional
def get_slice_bboxes(
    image_height: int,
    image_width: int,
    slice_height: int = 640,
    slice_width: int = 640,
    overlap_height_ratio: float = 0.2,
    overlap_width_ratio: float = 0.2,
) -> List[List[int]]:
    """
    Calculate bounding boxes for overlapping slices that tile an image.

    Slices are produced row by row (top to bottom), left to right within a
    row. A slice that would cross the right or bottom edge is moved back so
    that it ends exactly on that edge; if the image is smaller than the
    slice in a dimension, the slice is clipped to the image in that dimension.

    Args:
        image_height: Height of the full image in pixels.
        image_width: Width of the full image in pixels.
        slice_height: Height of each slice in pixels (must be > 0).
        slice_width: Width of each slice in pixels (must be > 0).
        overlap_height_ratio: Vertical overlap between consecutive rows, as a
            fraction of ``slice_height``.
        overlap_width_ratio: Horizontal overlap between consecutive slices, as
            a fraction of ``slice_width``.

    Returns: List of [x_min, y_min, x_max, y_max]. Empty when the image has a
        non-positive height or width.

    Raises:
        ValueError: If a slice dimension is not positive, or the overlap is
            as large as the slice (the window would never advance).
    """
    if slice_height <= 0 or slice_width <= 0:
        raise ValueError(
            f"slice_height and slice_width must be positive, "
            f"got {slice_height} and {slice_width}"
        )

    y_overlap = int(slice_height * overlap_height_ratio)
    x_overlap = int(slice_width * overlap_width_ratio)

    # The window advances by (slice size - overlap) each step. A non-positive
    # stride would never reach the image edge and the loops below would spin
    # forever while growing the result list.
    if y_overlap >= slice_height or x_overlap >= slice_width:
        raise ValueError(
            f"overlap must be smaller than the slice size, got overlap ratios "
            f"{overlap_height_ratio} (height) and {overlap_width_ratio} (width)"
        )

    slice_bboxes = []
    y_max = y_min = 0
    while y_max < image_height:
        x_min = x_max = 0
        y_max = y_min + slice_height
        # Clamp the row to the bottom edge, keeping the full slice height
        # whenever the image is tall enough.
        if y_max > image_height:
            y_max = image_height
            y_min = max(0, image_height - slice_height)

        while x_max < image_width:
            x_max = x_min + slice_width
            # Clamp the slice to the right edge, keeping the full slice width
            # whenever the image is wide enough.
            if x_max > image_width:
                x_max = image_width
                x_min = max(0, image_width - slice_width)
            slice_bboxes.append([x_min, y_min, x_max, y_max])
            x_min = x_max - x_overlap

        y_min = y_max - y_overlap
    return slice_bboxes
def slice_image(
    image: np.ndarray,
    slice_bboxes: List[List[int]]
) -> List[np.ndarray]:
    """
    Cut one crop out of ``image`` per bounding box.

    Each box is [xmin, ymin, xmax, ymax]; rows are indexed by y and columns
    by x. Crops are produced by plain slicing, in the same order as
    ``slice_bboxes``.
    """
    return [
        image[top:bottom, left:right]
        for left, top, right, bottom in slice_bboxes
    ]
def shift_bboxes(
    bboxes: List[List[float]],
    slice_coords: List[int]
) -> List[List[float]]:
    """
    Translate boxes from slice-local coordinates into full-image coordinates.

    slice_coords: [xmin, ymin, xmax, ymax] of the slice inside the image;
        only the top-left corner (xmin, ymin) is used as the offset.
    bboxes: List of [xmin, ymin, xmax, ymax] relative to the slice.

    Returns a new list of [xmin, ymin, xmax, ymax]; the inputs are not modified.
    """
    offset_x, offset_y = slice_coords[0], slice_coords[1]
    # Index-based access (rather than unpacking) so only the first four
    # entries of each box are read.
    return [
        [b[0] + offset_x, b[1] + offset_y, b[2] + offset_x, b[3] + offset_y]
        for b in bboxes
    ]
def _greedy_nms_sorted(boxes: torch.Tensor, iou_threshold: float) -> torch.Tensor:
    """
    Greedy NMS over ``boxes`` that are already sorted by descending score.

    Returns the positions (into ``boxes``) of the boxes that survive, in
    descending-score order. ``boxes`` must contain at least one row.
    """
    remaining = torch.arange(boxes.size(0), device=boxes.device)
    kept = []
    while remaining.numel() > 0:
        # The best remaining box always survives.
        kept.append(remaining[:1])
        if remaining.numel() == 1:
            break

        top_box = boxes[remaining[0]]
        rest = remaining[1:]
        rest_boxes = boxes[rest]

        inter_w = (
            torch.min(top_box[2], rest_boxes[:, 2]) - torch.max(top_box[0], rest_boxes[:, 0])
        ).clamp(min=0)
        inter_h = (
            torch.min(top_box[3], rest_boxes[:, 3]) - torch.max(top_box[1], rest_boxes[:, 1])
        ).clamp(min=0)
        inter_area = inter_w * inter_h

        top_area = (top_box[2] - top_box[0]) * (top_box[3] - top_box[1])
        rest_area = (rest_boxes[:, 2] - rest_boxes[:, 0]) * (rest_boxes[:, 3] - rest_boxes[:, 1])
        union_area = top_area + rest_area - inter_area

        # Keep boxes with IoU <= threshold. Written without a division so
        # zero-area boxes need no epsilon, and so only IoU strictly greater
        # than the threshold suppresses (same rule as torchvision's NMS).
        remaining = rest[inter_area <= iou_threshold * union_area]
    return torch.cat(kept)


def batched_nms(
    boxes: torch.Tensor,
    scores: torch.Tensor,
    idxs: torch.Tensor,
    iou_threshold: float = 0.5
) -> torch.Tensor:
    """
    Performs non-maximum suppression in a batched fashion.

    Boxes are only compared against boxes that share the same value in
    ``idxs`` (e.g. the class id), so overlapping boxes of different
    categories never suppress each other.

    Uses ``torchvision.ops.batched_nms`` when torchvision is importable,
    otherwise falls back to a pure-PyTorch greedy implementation.

    Args:
        boxes: Boxes as rows of [x1, y1, x2, y2].
        scores: One score per box.
        idxs: One category index per box.
        iou_threshold: Boxes whose IoU with a higher-scoring kept box of the
            same category exceeds this value are discarded.

    Returns:
        int64 tensor of indices into ``boxes`` of the kept boxes, sorted by
        descending score (for both the torchvision and the fallback path).
    """
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)

    # Prefer the efficient torchvision kernel when available.
    try:
        import torchvision
        return torchvision.ops.batched_nms(boxes, scores, idxs, iou_threshold)
    except ImportError:
        pass

    # Fallback: greedy NMS per category (slow but standard).
    keep_chunks = []
    for label in idxs.unique():
        class_indices = torch.where(idxs == label)[0]
        order = torch.argsort(scores[class_indices], descending=True)
        class_indices = class_indices[order]
        kept_positions = _greedy_nms_sorted(boxes[class_indices], iou_threshold)
        keep_chunks.append(class_indices[kept_positions])

    if not keep_chunks:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)

    keep = torch.cat(keep_chunks)
    # Match torchvision's contract: results ordered by descending score
    # across all categories, not grouped per category.
    keep = keep[torch.argsort(scores[keep], descending=True)]
    return keep.to(torch.int64)