# Slicing utilities for tiled (SAHI-style) object detection inference.
import numpy as np
import torch
import logging
from typing import List, Tuple, Dict, Any, Optional
def get_slice_bboxes(
    image_height: int,
    image_width: int,
    slice_height: int = 640,
    slice_width: int = 640,
    overlap_height_ratio: float = 0.2,
    overlap_width_ratio: float = 0.2,
) -> List[List[int]]:
    """
    Calculate bounding boxes for slices with overlap.

    Slices are generated left-to-right, top-to-bottom. Slices that would
    cross the right/bottom image edge are shifted inward, so every slice is
    exactly ``slice_width`` x ``slice_height`` when the image is large enough
    (smaller images yield a single full-image slice).

    Args:
        image_height: Height of the full image in pixels. Must be > 0.
        image_width: Width of the full image in pixels. Must be > 0.
        slice_height: Height of each slice. Must be > 0.
        slice_width: Width of each slice. Must be > 0.
        overlap_height_ratio: Fractional vertical overlap between rows, in [0, 1).
        overlap_width_ratio: Fractional horizontal overlap between columns, in [0, 1).

    Returns:
        List of [x_min, y_min, x_max, y_max] slice boxes in pixel coordinates.

    Raises:
        ValueError: If any dimension is non-positive or an overlap ratio is
            outside [0, 1). An overlap ratio >= 1 would make the sliding
            window advance by a non-positive step and loop forever.
    """
    if image_height <= 0 or image_width <= 0:
        raise ValueError("image dimensions must be positive")
    if slice_height <= 0 or slice_width <= 0:
        raise ValueError("slice dimensions must be positive")
    if not (0.0 <= overlap_height_ratio < 1.0 and 0.0 <= overlap_width_ratio < 1.0):
        raise ValueError("overlap ratios must be in [0, 1)")
    slice_bboxes = []
    y_max = y_min = 0
    y_overlap = int(slice_height * overlap_height_ratio)
    x_overlap = int(slice_width * overlap_width_ratio)
    while y_max < image_height:
        x_min = x_max = 0
        y_max = y_min + slice_height
        while x_max < image_width:
            x_max = x_min + slice_width
            # Shift edge slices inward so they keep the full slice size.
            if y_max > image_height:
                y_max = image_height
                y_min = max(0, image_height - slice_height)
            if x_max > image_width:
                x_max = image_width
                x_min = max(0, image_width - slice_width)
            slice_bboxes.append([x_min, y_min, x_max, y_max])
            x_min = x_max - x_overlap
        y_min = y_max - y_overlap
    return slice_bboxes
def slice_image(
    image: np.ndarray,
    slice_bboxes: List[List[int]]
) -> List[np.ndarray]:
    """Crop one patch per [x_min, y_min, x_max, y_max] box and return them in order."""
    return [image[y0:y1, x0:x1] for x0, y0, x1, y1 in slice_bboxes]
def shift_bboxes(
    bboxes: List[List[float]],
    slice_coords: List[int]
) -> List[List[float]]:
    """
    Shift boxes from slice-local coordinates into global image coordinates.

    Args:
        bboxes: Boxes as [xmin, ymin, xmax, ymax] in slice coordinates.
        slice_coords: The slice's [xmin, ymin, xmax, ymax] within the full
            image; only the top-left corner is used as the translation offset.

    Returns:
        A new list of boxes translated by the slice origin.
    """
    dx, dy = slice_coords[0], slice_coords[1]
    return [[x1 + dx, y1 + dy, x2 + dx, y2 + dy] for x1, y1, x2, y2 in bboxes]
def batched_nms(
    boxes: torch.Tensor,
    scores: torch.Tensor,
    idxs: torch.Tensor,
    iou_threshold: float = 0.5
) -> torch.Tensor:
    """
    Performs non-maximum suppression in a batched (class-aware) fashion.

    Boxes are only suppressed against other boxes with the same label in
    ``idxs``. Uses torchvision's accelerated implementation when available,
    otherwise falls back to a pure-PyTorch greedy NMS.

    Args:
        boxes: (N, 4) tensor of [x1, y1, x2, y2] boxes.
        scores: (N,) tensor of confidence scores.
        idxs: (N,) integer tensor of per-box class labels.
        iou_threshold: Boxes with IoU above this value are suppressed.

    Returns:
        Int64 tensor of indices of the kept boxes, sorted by decreasing
        score — matching torchvision.ops.batched_nms semantics on both paths.
    """
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
    # Prefer the efficient C++/CUDA implementation when available.
    try:
        import torchvision
        return torchvision.ops.batched_nms(boxes, scores, idxs, iou_threshold)
    except ImportError:
        pass
    # Pure-PyTorch greedy NMS fallback (slow but standard).
    keep_indices: List[int] = []
    for label in idxs.unique():
        mask = (idxs == label)
        cls_boxes = boxes[mask]
        original_indices = torch.where(mask)[0]
        # Process this class's boxes in descending score order.
        order = torch.argsort(scores[mask], descending=True)
        cls_boxes = cls_boxes[order]
        original_indices = original_indices[order]
        while cls_boxes.size(0) > 0:
            # The highest-scoring remaining box is always kept.
            keep_indices.append(int(original_indices[0]))
            if cls_boxes.size(0) == 1:
                break
            current_box = cls_boxes[0].unsqueeze(0)
            rest_boxes = cls_boxes[1:]
            # IoU of the kept box against every remaining box.
            x1 = torch.max(current_box[:, 0], rest_boxes[:, 0])
            y1 = torch.max(current_box[:, 1], rest_boxes[:, 1])
            x2 = torch.min(current_box[:, 2], rest_boxes[:, 2])
            y2 = torch.min(current_box[:, 3], rest_boxes[:, 3])
            inter_area = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
            box_area = (current_box[:, 2] - current_box[:, 0]) * (current_box[:, 3] - current_box[:, 1])
            rest_area = (rest_boxes[:, 2] - rest_boxes[:, 0]) * (rest_boxes[:, 3] - rest_boxes[:, 1])
            union_area = box_area + rest_area - inter_area
            # Epsilon guards against division by zero for degenerate boxes.
            iou = inter_area / (union_area + 1e-6)
            # Keep only boxes whose overlap with the kept box is below threshold.
            keep_mask = iou < iou_threshold
            cls_boxes = rest_boxes[keep_mask]
            original_indices = original_indices[1:][keep_mask]
    keep = torch.tensor(keep_indices, dtype=torch.int64, device=boxes.device)
    # Match torchvision's contract: indices sorted by decreasing score.
    return keep[torch.argsort(scores[keep], descending=True)]