Spaces:
Paused
Paused
Zhen Ye
committed on
Commit
·
af29397
1
Parent(s):
284ce20
feat: Implement SAHI Tiling for 4K video detection
Browse files
- Added utils/tiling.py for image slicing and NMS
- Updated DroneYolo and HFYoloV8 to auto-tile images > 3000px width
- Uses 1280x1280 tiles with 20% overlap for maximum small object recall
- models/detectors/drone_yolo.py +94 -0
- models/detectors/yolov8.py +80 -0
- utils/tiling.py +153 -0
models/detectors/drone_yolo.py
CHANGED
|
@@ -62,7 +62,94 @@ class DroneYoloDetector(ObjectDetector):
|
|
| 62 |
label_names=label_names,
|
| 63 |
)
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
device_arg = self.device
|
| 67 |
results = self.model.predict(
|
| 68 |
source=frame,
|
|
@@ -74,6 +161,13 @@ class DroneYoloDetector(ObjectDetector):
|
|
| 74 |
return self._parse_single_result(results[0], queries)
|
| 75 |
|
| 76 |
def predict_batch(self, frames: Sequence[np.ndarray], queries: Sequence[str]) -> Sequence[DetectionResult]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
results = self.model.predict(
|
| 78 |
source=frames,
|
| 79 |
device=self.device,
|
|
|
|
| 62 |
label_names=label_names,
|
| 63 |
)
|
| 64 |
|
| 65 |
+
from utils.tiling import get_slice_bboxes, slice_image, shift_bboxes, batched_nms
|
| 66 |
+
|
| 67 |
+
def _predict_tiled(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
    """Tiled (SAHI-style) inference for high-resolution frames.

    Slices the frame into overlapping 1280x1280 tiles, runs the model on
    each tile in GPU-sized batches, shifts detections back into full-frame
    coordinates, and merges duplicates with class-aware NMS.
    """
    # 1. Slice the frame into overlapping tiles (20% overlap on both axes).
    frame_h, frame_w = frame.shape[:2]
    tile_coords = get_slice_bboxes(frame_h, frame_w, 1280, 1280, 0.2, 0.2)
    tile_images = slice_image(frame, tile_coords)

    merged_boxes = []
    merged_scores = []
    merged_classes = []

    # 2. Batched inference — chunked by max_batch_size to respect GPU memory.
    chunk = self.max_batch_size
    for start in range(0, len(tile_images), chunk):
        chunk_tiles = tile_images[start : start + chunk]
        chunk_coords = tile_coords[start : start + chunk]

        predictions = self.model.predict(
            source=chunk_tiles,
            device=self.device,
            conf=self.score_threshold,
            imgsz=1280,  # run tiles at their native resolution
            verbose=False,
        )

        for prediction, coords in zip(predictions, chunk_coords):
            if prediction.boxes is None:
                continue
            tile_boxes = prediction.boxes.xyxy.cpu().numpy().tolist()
            tile_scores = prediction.boxes.conf.cpu().numpy().tolist()
            tile_classes = prediction.boxes.cls.cpu().numpy().tolist()

            # Translate tile-local boxes into full-frame coordinates.
            merged_boxes.extend(shift_bboxes(tile_boxes, coords))
            merged_scores.extend(tile_scores)
            merged_classes.extend(tile_classes)

    if not merged_boxes:
        empty = np.empty((0, 4), dtype=np.float32)
        return DetectionResult(empty, [], [], [])

    # 3. Class-aware NMS across all tiles to drop duplicates in overlap zones.
    boxes_t = torch.tensor(merged_boxes, device=self.device)
    scores_t = torch.tensor(merged_scores, device=self.device)
    labels_t = torch.tensor(merged_classes, device=self.device)

    keep = batched_nms(boxes_t, scores_t, labels_t, iou_threshold=0.4)

    final_boxes = boxes_t[keep].cpu().numpy()
    final_scores = scores_t[keep].cpu().tolist()
    final_labels = labels_t[keep].cpu().int().tolist()

    # 4. Keep only detections whose class names match the requested queries.
    label_names = [self.class_names.get(idx, f"class_{idx}") for idx in final_labels]
    keep_indices = self._filter_indices(label_names, queries)

    if not keep_indices:
        empty = np.empty((0, 4), dtype=np.float32)
        return DetectionResult(empty, [], [], [])

    return DetectionResult(
        boxes=final_boxes[keep_indices],
        scores=[final_scores[i] for i in keep_indices],
        labels=[final_labels[i] for i in keep_indices],
        label_names=[label_names[i] for i in keep_indices],
    )
|
| 146 |
+
|
| 147 |
def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
|
| 148 |
+
h, w = frame.shape[:2]
|
| 149 |
+
# Enable tiling for 4Kish images (width > 3000)
|
| 150 |
+
if w > 3000:
|
| 151 |
+
return self._predict_tiled(frame, queries)
|
| 152 |
+
|
| 153 |
device_arg = self.device
|
| 154 |
results = self.model.predict(
|
| 155 |
source=frame,
|
|
|
|
| 161 |
return self._parse_single_result(results[0], queries)
|
| 162 |
|
| 163 |
def predict_batch(self, frames: Sequence[np.ndarray], queries: Sequence[str]) -> Sequence[DetectionResult]:
|
| 164 |
+
# Mixed batch support is hard. Assume batch is uniform size.
|
| 165 |
+
if not frames: return []
|
| 166 |
+
h, w = frames[0].shape[:2]
|
| 167 |
+
|
| 168 |
+
if w > 3000:
|
| 169 |
+
return [self._predict_tiled(f, queries) for f in frames]
|
| 170 |
+
|
| 171 |
results = self.model.predict(
|
| 172 |
source=frames,
|
| 173 |
device=self.device,
|
models/detectors/yolov8.py
CHANGED
|
@@ -7,6 +7,7 @@ from huggingface_hub import hf_hub_download
|
|
| 7 |
from ultralytics import YOLO
|
| 8 |
|
| 9 |
from models.detectors.base import DetectionResult, ObjectDetector
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
class HuggingFaceYoloV8Detector(ObjectDetector):
|
|
@@ -64,7 +65,81 @@ class HuggingFaceYoloV8Detector(ObjectDetector):
|
|
| 64 |
label_names=label_names,
|
| 65 |
)
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
results = self.model.predict(
|
| 69 |
source=frame,
|
| 70 |
device=self.device,
|
|
@@ -75,6 +150,11 @@ class HuggingFaceYoloV8Detector(ObjectDetector):
|
|
| 75 |
return self._parse_single_result(results[0], queries)
|
| 76 |
|
| 77 |
def predict_batch(self, frames: Sequence[np.ndarray], queries: Sequence[str]) -> Sequence[DetectionResult]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
results = self.model.predict(
|
| 79 |
source=frames,
|
| 80 |
device=self.device,
|
|
|
|
| 7 |
from ultralytics import YOLO
|
| 8 |
|
| 9 |
from models.detectors.base import DetectionResult, ObjectDetector
|
| 10 |
+
from utils.tiling import get_slice_bboxes, slice_image, shift_bboxes, batched_nms
|
| 11 |
|
| 12 |
|
| 13 |
class HuggingFaceYoloV8Detector(ObjectDetector):
|
|
|
|
| 65 |
label_names=label_names,
|
| 66 |
)
|
| 67 |
|
| 68 |
+
def _predict_tiled(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
    """Run tiled inference for high-resolution frames.

    The frame is cut into overlapping 1280x1280 windows; each window batch
    is passed through the model, per-window boxes are shifted back into
    global coordinates, and overlapping duplicates are removed via NMS.
    """
    img_h, img_w = frame.shape[:2]
    # Heuristic: 1280x1280 tiles with 20% overlap.
    windows = get_slice_bboxes(img_h, img_w, 1280, 1280, 0.2, 0.2)
    crops = slice_image(frame, windows)

    gathered_boxes = []
    gathered_scores = []
    gathered_labels = []

    # Chunk the tiles so a single call never exceeds max_batch_size.
    step = self.max_batch_size
    for offset in range(0, len(crops), step):
        crop_batch = crops[offset : offset + step]
        window_batch = windows[offset : offset + step]

        # Tiles are run at full 1280px resolution.
        batch_results = self.model.predict(
            source=crop_batch,
            device=self.device,
            conf=self.score_threshold,
            imgsz=1280,
            verbose=False,
        )

        for result, window in zip(batch_results, window_batch):
            if result.boxes is None:
                continue
            local_boxes = result.boxes.xyxy.cpu().numpy().tolist()
            local_scores = result.boxes.conf.cpu().numpy().tolist()
            local_classes = result.boxes.cls.cpu().numpy().tolist()

            gathered_boxes.extend(shift_bboxes(local_boxes, window))
            gathered_scores.extend(local_scores)
            gathered_labels.extend(local_classes)

    if not gathered_boxes:
        empty = np.empty((0, 4), dtype=np.float32)
        return DetectionResult(empty, [], [], [])

    # Merge detections from all windows with class-aware NMS.
    boxes_t = torch.tensor(gathered_boxes, device=self.device)
    scores_t = torch.tensor(gathered_scores, device=self.device)
    labels_t = torch.tensor(gathered_labels, device=self.device)

    keep = batched_nms(boxes_t, scores_t, labels_t, iou_threshold=0.4)

    final_boxes = boxes_t[keep].cpu().numpy()
    final_scores = scores_t[keep].cpu().tolist()
    final_labels = labels_t[keep].cpu().int().tolist()

    # Restrict to the classes the caller asked for.
    label_names = [self.class_names.get(idx, f"class_{idx}") for idx in final_labels]
    keep_indices = self._filter_indices(label_names, queries)

    if not keep_indices:
        empty = np.empty((0, 4), dtype=np.float32)
        return DetectionResult(empty, [], [], [])

    return DetectionResult(
        boxes=final_boxes[keep_indices],
        scores=[final_scores[i] for i in keep_indices],
        labels=[final_labels[i] for i in keep_indices],
        label_names=[label_names[i] for i in keep_indices],
    )
|
| 137 |
+
|
| 138 |
def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
|
| 139 |
+
h, w = frame.shape[:2]
|
| 140 |
+
if w > 3000:
|
| 141 |
+
return self._predict_tiled(frame, queries)
|
| 142 |
+
|
| 143 |
results = self.model.predict(
|
| 144 |
source=frame,
|
| 145 |
device=self.device,
|
|
|
|
| 150 |
return self._parse_single_result(results[0], queries)
|
| 151 |
|
| 152 |
def predict_batch(self, frames: Sequence[np.ndarray], queries: Sequence[str]) -> Sequence[DetectionResult]:
|
| 153 |
+
if not frames: return []
|
| 154 |
+
h, w = frames[0].shape[:2]
|
| 155 |
+
if w > 3000:
|
| 156 |
+
return [self._predict_tiled(f, queries) for f in frames]
|
| 157 |
+
|
| 158 |
results = self.model.predict(
|
| 159 |
source=frames,
|
| 160 |
device=self.device,
|
utils/tiling.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import logging
|
| 4 |
+
from typing import List, Tuple, Dict, Any, Optional
|
| 5 |
+
|
| 6 |
+
def get_slice_bboxes(
    image_height: int,
    image_width: int,
    slice_height: int = 640,
    slice_width: int = 640,
    overlap_height_ratio: float = 0.2,
    overlap_width_ratio: float = 0.2,
) -> List[List[int]]:
    """
    Compute tile windows covering an image, with overlap between neighbours.

    Tiles are emitted row by row. Tiles touching the right/bottom edge are
    shifted inward so every tile keeps the full slice size whenever the
    image is large enough.

    Args:
        image_height: Full image height in pixels.
        image_width: Full image width in pixels.
        slice_height: Tile height in pixels (must be positive).
        slice_width: Tile width in pixels (must be positive).
        overlap_height_ratio: Fraction of the tile height shared with the
            vertically adjacent tile; must lie in [0, 1).
        overlap_width_ratio: Fraction of the tile width shared with the
            horizontally adjacent tile; must lie in [0, 1).

    Returns:
        List of [x_min, y_min, x_max, y_max] tile windows.

    Raises:
        ValueError: if a slice dimension is non-positive or an overlap
            ratio falls outside [0, 1) — either would make the scan step
            non-positive and the loops below would never terminate.
    """
    if slice_height <= 0 or slice_width <= 0:
        raise ValueError("slice_height and slice_width must be positive")
    if not (0 <= overlap_height_ratio < 1 and 0 <= overlap_width_ratio < 1):
        raise ValueError("overlap ratios must be in [0, 1)")

    slice_bboxes = []
    y_max = y_min = 0
    y_overlap = int(slice_height * overlap_height_ratio)
    x_overlap = int(slice_width * overlap_width_ratio)

    while y_max < image_height:
        x_min = x_max = 0
        y_max = y_min + slice_height

        while x_max < image_width:
            x_max = x_min + slice_width

            # Clamp tiles that spill past the image, shifting them inward
            # so they stay full-sized whenever possible.
            if y_max > image_height:
                y_max = image_height
                y_min = max(0, image_height - slice_height)

            if x_max > image_width:
                x_max = image_width
                x_min = max(0, image_width - slice_width)

            slice_bboxes.append([x_min, y_min, x_max, y_max])

            x_min = x_max - x_overlap

        y_min = y_max - y_overlap

    return slice_bboxes
|
| 45 |
+
|
| 46 |
+
def slice_image(
    image: np.ndarray,
    slice_bboxes: List[List[int]]
) -> List[np.ndarray]:
    """Return the crops of ``image`` described by ``slice_bboxes``.

    Each bbox is [x_min, y_min, x_max, y_max]; crops are numpy slices of
    the original array (views, not copies).
    """
    return [image[y1:y2, x1:x2] for x1, y1, x2, y2 in slice_bboxes]
|
| 56 |
+
|
| 57 |
+
def shift_bboxes(
    bboxes: List[List[float]],
    slice_coords: List[int]
) -> List[List[float]]:
    """
    Translate slice-local boxes into global image coordinates.

    Args:
        bboxes: list of [xmin, ymin, xmax, ymax] boxes in tile coordinates.
        slice_coords: the tile's [xmin, ymin, xmax, ymax] within the image;
            only the tile origin (xmin, ymin) is used as the offset.

    Returns:
        A new list of boxes, each offset by the tile origin.
    """
    dx, dy = slice_coords[0], slice_coords[1]
    return [[x1 + dx, y1 + dy, x2 + dx, y2 + dy] for x1, y1, x2, y2 in bboxes]
|
| 79 |
+
|
| 80 |
+
def batched_nms(
    boxes: torch.Tensor,
    scores: torch.Tensor,
    idxs: torch.Tensor,
    iou_threshold: float = 0.5
) -> torch.Tensor:
    """
    Class-aware non-maximum suppression.

    Uses ``torchvision.ops.batched_nms`` when available; otherwise falls
    back to a pure-torch greedy implementation. Suppression never crosses
    class boundaries.

    Args:
        boxes: (N, 4) tensor of [x1, y1, x2, y2] boxes.
        scores: (N,) tensor of confidence scores.
        idxs: (N,) tensor of class labels.
        iou_threshold: within a class, boxes overlapping a higher-scoring
            box with IoU >= this value are dropped.

    Returns:
        1-D int64 tensor of kept indices, sorted by descending score
        (matching the torchvision fast path).
    """
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)

    # Fast path: C-accelerated implementation when torchvision is installed.
    try:
        import torchvision
        return torchvision.ops.batched_nms(boxes, scores, idxs, iou_threshold)
    except ImportError:
        pass

    # Pure-torch greedy fallback, processed one class at a time.
    # (The previous version also tried importing ultralytics' NMS here but
    # never used it — that dead block has been removed.)
    keep_indices: List[torch.Tensor] = []
    for label in idxs.unique():
        mask = idxs == label
        cls_boxes = boxes[mask]
        cls_scores = scores[mask]
        original_indices = torch.where(mask)[0]

        # Visit candidates in descending score order.
        order = torch.argsort(cls_scores, descending=True)
        cls_boxes = cls_boxes[order]
        original_indices = original_indices[order]

        while cls_boxes.size(0) > 0:
            keep_indices.append(original_indices[0])
            if cls_boxes.size(0) == 1:
                break

            current = cls_boxes[0].unsqueeze(0)
            rest = cls_boxes[1:]

            # IoU between the just-kept box and every remaining candidate.
            x1 = torch.max(current[:, 0], rest[:, 0])
            y1 = torch.max(current[:, 1], rest[:, 1])
            x2 = torch.min(current[:, 2], rest[:, 2])
            y2 = torch.min(current[:, 3], rest[:, 3])

            inter_area = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
            current_area = (current[:, 2] - current[:, 0]) * (current[:, 3] - current[:, 1])
            rest_area = (rest[:, 2] - rest[:, 0]) * (rest[:, 3] - rest[:, 1])
            union_area = current_area + rest_area - inter_area

            iou = inter_area / (union_area + 1e-6)

            # Keep only candidates with low overlap.
            survivors = iou < iou_threshold
            cls_boxes = rest[survivors]
            original_indices = original_indices[1:][survivors]

    if not keep_indices:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)

    # torch.stack (not torch.tensor on a list of 0-dim tensors, which relies
    # on deprecated conversion), then sort by score for torchvision parity.
    keep = torch.stack(keep_indices).to(torch.int64)
    return keep[torch.argsort(scores[keep], descending=True)]
|