Zhen Ye committed on
Commit
af29397
·
1 Parent(s): 284ce20

feat: Implement SAHI Tiling for 4K video detection

Browse files

- Added utils/tiling.py for image slicing and NMS
- Updated DroneYolo and HFYoloV8 to auto-tile images > 3000px width
- Uses 1280x1280 tiles with 20% overlap for maximum small object recall

models/detectors/drone_yolo.py CHANGED
@@ -62,7 +62,94 @@ class DroneYoloDetector(ObjectDetector):
62
  label_names=label_names,
63
  )
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
 
 
 
 
 
66
  device_arg = self.device
67
  results = self.model.predict(
68
  source=frame,
@@ -74,6 +161,13 @@ class DroneYoloDetector(ObjectDetector):
74
  return self._parse_single_result(results[0], queries)
75
 
76
  def predict_batch(self, frames: Sequence[np.ndarray], queries: Sequence[str]) -> Sequence[DetectionResult]:
 
 
 
 
 
 
 
77
  results = self.model.predict(
78
  source=frames,
79
  device=self.device,
 
62
  label_names=label_names,
63
  )
64
 
65
def _predict_tiled(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
    """Run sliced (SAHI-style) inference on a high-resolution frame.

    The frame is cut into overlapping tiles, each tile is run through the
    underlying YOLO model, per-tile boxes are shifted back into global
    frame coordinates, and duplicates across overlapping tiles are merged
    with class-aware NMS.

    Args:
        frame: Full-resolution image as an H x W x C numpy array.
        queries: Label names to keep; filtering is delegated to
            ``self._filter_indices``.

    Returns:
        A DetectionResult with boxes in global frame coordinates.
    """
    # Local imports: the tiling path costs nothing unless used.
    # NOTE(review): the original added the tiling import at *class* scope
    # (making the helpers class attributes) and used `torch` without ever
    # importing it in this module — confirm module-level imports.
    import torch
    from utils.tiling import get_slice_bboxes, slice_image, shift_bboxes, batched_nms

    tile_size = 1280      # single source of truth for slice size AND imgsz
    overlap_ratio = 0.2   # 20% overlap for small-object recall at tile seams
    nms_iou = 0.4

    h, w = frame.shape[:2]
    slice_boxes = get_slice_bboxes(h, w, tile_size, tile_size, overlap_ratio, overlap_ratio)
    tiles = slice_image(frame, slice_boxes)

    all_boxes = []
    all_scores = []
    all_labels = []

    # Run tiles in chunks of max_batch_size to bound GPU memory usage.
    batch_size = self.max_batch_size
    for i in range(0, len(tiles), batch_size):
        batch_tiles = tiles[i : i + batch_size]
        batch_slices = slice_boxes[i : i + batch_size]

        results = self.model.predict(
            source=batch_tiles,
            device=self.device,
            conf=self.score_threshold,
            imgsz=tile_size,  # run each tile at its native resolution
            verbose=False,
        )

        for res, slice_coord in zip(results, batch_slices):
            if res.boxes is None:
                continue
            boxes = res.boxes.xyxy.cpu().numpy().tolist()
            scores = res.boxes.conf.cpu().numpy().tolist()
            clss = res.boxes.cls.cpu().numpy().tolist()

            # Translate tile-local boxes into global frame coordinates.
            all_boxes.extend(shift_bboxes(boxes, slice_coord))
            all_scores.extend(scores)
            all_labels.extend(clss)

    if not all_boxes:
        empty = np.empty((0, 4), dtype=np.float32)
        return DetectionResult(empty, [], [], [])

    # Merge duplicates across overlapping tiles with class-aware NMS.
    boxes_t = torch.tensor(all_boxes, device=self.device)
    scores_t = torch.tensor(all_scores, device=self.device)
    labels_t = torch.tensor(all_labels, device=self.device)

    keep = batched_nms(boxes_t, scores_t, labels_t, iou_threshold=nms_iou)

    final_boxes = boxes_t[keep].cpu().numpy()
    final_scores = scores_t[keep].cpu().tolist()
    final_labels = labels_t[keep].cpu().int().tolist()

    # Restrict detections to the requested query labels.
    label_names = [self.class_names.get(idx, f"class_{idx}") for idx in final_labels]
    keep_indices = self._filter_indices(label_names, queries)

    if not keep_indices:
        empty = np.empty((0, 4), dtype=np.float32)
        return DetectionResult(empty, [], [], [])

    return DetectionResult(
        boxes=final_boxes[keep_indices],
        scores=[final_scores[i] for i in keep_indices],
        labels=[final_labels[i] for i in keep_indices],
        label_names=[label_names[i] for i in keep_indices],
    )
146
+
147
  def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
148
+ h, w = frame.shape[:2]
149
+ # Enable tiling for 4Kish images (width > 3000)
150
+ if w > 3000:
151
+ return self._predict_tiled(frame, queries)
152
+
153
  device_arg = self.device
154
  results = self.model.predict(
155
  source=frame,
 
161
  return self._parse_single_result(results[0], queries)
162
 
163
  def predict_batch(self, frames: Sequence[np.ndarray], queries: Sequence[str]) -> Sequence[DetectionResult]:
164
+ # Mixed batch support is hard. Assume batch is uniform size.
165
+ if not frames: return []
166
+ h, w = frames[0].shape[:2]
167
+
168
+ if w > 3000:
169
+ return [self._predict_tiled(f, queries) for f in frames]
170
+
171
  results = self.model.predict(
172
  source=frames,
173
  device=self.device,
models/detectors/yolov8.py CHANGED
@@ -7,6 +7,7 @@ from huggingface_hub import hf_hub_download
7
  from ultralytics import YOLO
8
 
9
  from models.detectors.base import DetectionResult, ObjectDetector
 
10
 
11
 
12
  class HuggingFaceYoloV8Detector(ObjectDetector):
@@ -64,7 +65,81 @@ class HuggingFaceYoloV8Detector(ObjectDetector):
64
  label_names=label_names,
65
  )
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
 
 
 
 
68
  results = self.model.predict(
69
  source=frame,
70
  device=self.device,
@@ -75,6 +150,11 @@ class HuggingFaceYoloV8Detector(ObjectDetector):
75
  return self._parse_single_result(results[0], queries)
76
 
77
  def predict_batch(self, frames: Sequence[np.ndarray], queries: Sequence[str]) -> Sequence[DetectionResult]:
 
 
 
 
 
78
  results = self.model.predict(
79
  source=frames,
80
  device=self.device,
 
7
  from ultralytics import YOLO
8
 
9
  from models.detectors.base import DetectionResult, ObjectDetector
10
+ from utils.tiling import get_slice_bboxes, slice_image, shift_bboxes, batched_nms
11
 
12
 
13
  class HuggingFaceYoloV8Detector(ObjectDetector):
 
65
  label_names=label_names,
66
  )
67
 
68
def _predict_tiled(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
    """Run sliced (SAHI-style) inference on a high-resolution frame.

    Slices the frame into overlapping tiles, runs each tile through the
    YOLO model, shifts tile-local boxes into global frame coordinates and
    merges duplicates with class-aware NMS.

    Args:
        frame: Full-resolution image as an H x W x C numpy array.
        queries: Label names to keep; filtering is delegated to
            ``self._filter_indices``.

    Returns:
        A DetectionResult with boxes in global frame coordinates.
    """
    # NOTE(review): this module's visible import block never imports
    # `torch`; importing locally guarantees availability — confirm the
    # module-level imports and drop this if redundant.
    import torch

    tile_size = 1280      # single source of truth for slice size AND imgsz
    overlap_ratio = 0.2   # 20% overlap for small-object recall at tile seams
    nms_iou = 0.4

    h, w = frame.shape[:2]
    slice_boxes = get_slice_bboxes(h, w, tile_size, tile_size, overlap_ratio, overlap_ratio)
    tiles = slice_image(frame, slice_boxes)

    all_boxes = []
    all_scores = []
    all_labels = []

    # Run tiles in chunks of max_batch_size to bound GPU memory usage.
    batch_size = self.max_batch_size
    for i in range(0, len(tiles), batch_size):
        batch_tiles = tiles[i : i + batch_size]
        batch_slices = slice_boxes[i : i + batch_size]

        results = self.model.predict(
            source=batch_tiles,
            device=self.device,
            conf=self.score_threshold,
            imgsz=tile_size,  # run each tile at its native resolution
            verbose=False,
        )

        for res, slice_coord in zip(results, batch_slices):
            if res.boxes is None:
                continue
            boxes = res.boxes.xyxy.cpu().numpy().tolist()
            scores = res.boxes.conf.cpu().numpy().tolist()
            clss = res.boxes.cls.cpu().numpy().tolist()

            # Translate tile-local boxes into global frame coordinates.
            all_boxes.extend(shift_bboxes(boxes, slice_coord))
            all_scores.extend(scores)
            all_labels.extend(clss)

    if not all_boxes:
        empty = np.empty((0, 4), dtype=np.float32)
        return DetectionResult(empty, [], [], [])

    # Merge duplicates across overlapping tiles with class-aware NMS.
    boxes_t = torch.tensor(all_boxes, device=self.device)
    scores_t = torch.tensor(all_scores, device=self.device)
    labels_t = torch.tensor(all_labels, device=self.device)

    keep = batched_nms(boxes_t, scores_t, labels_t, iou_threshold=nms_iou)

    final_boxes = boxes_t[keep].cpu().numpy()
    final_scores = scores_t[keep].cpu().tolist()
    final_labels = labels_t[keep].cpu().int().tolist()

    # Restrict detections to the requested query labels.
    label_names = [self.class_names.get(idx, f"class_{idx}") for idx in final_labels]
    keep_indices = self._filter_indices(label_names, queries)

    if not keep_indices:
        empty = np.empty((0, 4), dtype=np.float32)
        return DetectionResult(empty, [], [], [])

    return DetectionResult(
        boxes=final_boxes[keep_indices],
        scores=[final_scores[i] for i in keep_indices],
        labels=[final_labels[i] for i in keep_indices],
        label_names=[label_names[i] for i in keep_indices],
    )
137
+
138
  def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
139
+ h, w = frame.shape[:2]
140
+ if w > 3000:
141
+ return self._predict_tiled(frame, queries)
142
+
143
  results = self.model.predict(
144
  source=frame,
145
  device=self.device,
 
150
  return self._parse_single_result(results[0], queries)
151
 
152
  def predict_batch(self, frames: Sequence[np.ndarray], queries: Sequence[str]) -> Sequence[DetectionResult]:
153
+ if not frames: return []
154
+ h, w = frames[0].shape[:2]
155
+ if w > 3000:
156
+ return [self._predict_tiled(f, queries) for f in frames]
157
+
158
  results = self.model.predict(
159
  source=frames,
160
  device=self.device,
utils/tiling.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import logging
4
+ from typing import List, Tuple, Dict, Any, Optional
5
+
6
def get_slice_bboxes(
    image_height: int,
    image_width: int,
    slice_height: int = 640,
    slice_width: int = 640,
    overlap_height_ratio: float = 0.2,
    overlap_width_ratio: float = 0.2,
) -> List[List[int]]:
    """Compute overlapping slice windows covering the whole image.

    Windows are ``slice_width`` x ``slice_height``; consecutive windows
    overlap by the given ratios. Windows that would run past an image edge
    are pushed back inside so every returned box lies fully within the
    image (matching SAHI's slicing behavior).

    Returns:
        List of ``[x_min, y_min, x_max, y_max]`` pixel boxes.
    """
    slice_bboxes: List[List[int]] = []
    y_overlap = int(slice_height * overlap_height_ratio)
    x_overlap = int(slice_width * overlap_width_ratio)

    y_min = 0
    y_max = 0
    while y_max < image_height:
        y_max = y_min + slice_height
        # Clamp the row to the bottom edge without mutating the loop vars.
        row_y_min, row_y_max = y_min, y_max
        if row_y_max > image_height:
            row_y_max = image_height
            row_y_min = max(0, image_height - slice_height)

        x_min = 0
        x_max = 0
        while x_max < image_width:
            x_max = x_min + slice_width
            # Clamp the column to the right edge.
            col_x_min, col_x_max = x_min, x_max
            if col_x_max > image_width:
                col_x_max = image_width
                col_x_min = max(0, image_width - slice_width)

            slice_bboxes.append([col_x_min, row_y_min, col_x_max, row_y_max])
            x_min = x_max - x_overlap

        # Advance to the next row exactly once per row. The original
        # advanced y_min alongside the x advance in the inner loop, which
        # corrupts the y extent of every box after the first column of a row.
        y_min = y_max - y_overlap

    return slice_bboxes
45
+
46
def slice_image(
    image: np.ndarray,
    slice_bboxes: List[List[int]]
) -> List[np.ndarray]:
    """Return one crop of *image* per ``[xmin, ymin, xmax, ymax]`` box.

    Crops are numpy views into the original image (no copying).
    """
    return [image[y0:y1, x0:x1] for x0, y0, x1, y1 in slice_bboxes]
56
+
57
def shift_bboxes(
    bboxes: List[List[float]],
    slice_coords: List[int]
) -> List[List[float]]:
    """Translate tile-local xyxy boxes into global image coordinates.

    slice_coords: [xmin, ymin, xmax, ymax] of the tile inside the full image;
    only the tile origin (xmin, ymin) is used as the translation offset.
    bboxes: list of [xmin, ymin, xmax, ymax] boxes in tile coordinates.
    """
    dx, dy = slice_coords[0], slice_coords[1]
    return [[x1 + dx, y1 + dy, x2 + dx, y2 + dy] for x1, y1, x2, y2 in bboxes]
79
+
80
def batched_nms(
    boxes: torch.Tensor,
    scores: torch.Tensor,
    idxs: torch.Tensor,
    iou_threshold: float = 0.5
) -> torch.Tensor:
    """Class-aware non-maximum suppression.

    Boxes with different ``idxs`` (class labels) never suppress each other.
    Uses ``torchvision.ops.batched_nms`` when available; otherwise falls
    back to a pure-torch greedy implementation.

    Args:
        boxes: (N, 4) tensor of xyxy boxes.
        scores: (N,) tensor of confidence scores.
        idxs: (N,) tensor of integer class labels.
        iou_threshold: same-class boxes with IoU above this are suppressed.

    Returns:
        1-D int64 tensor of indices into ``boxes`` to keep.
    """
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)

    # Fast path: torchvision's C-accelerated implementation.
    # (The original also probed `ultralytics.utils.ops.non_max_suppression`
    # here but never called it — that dead import block is removed.)
    try:
        import torchvision
        return torchvision.ops.batched_nms(boxes, scores, idxs, iou_threshold)
    except ImportError:
        pass

    # Greedy per-class NMS fallback: O(N^2) worst case but dependency-free.
    keep_indices: List[torch.Tensor] = []
    for label in idxs.unique():
        mask = idxs == label
        cls_boxes = boxes[mask]
        original_indices = torch.where(mask)[0]

        # Process this class's boxes in descending score order.
        order = torch.argsort(scores[mask], descending=True)
        cls_boxes = cls_boxes[order]
        original_indices = original_indices[order]

        while cls_boxes.size(0) > 0:
            # The highest-scoring remaining box is always kept.
            keep_indices.append(original_indices[0])
            if cls_boxes.size(0) == 1:
                break

            current = cls_boxes[0].unsqueeze(0)
            rest = cls_boxes[1:]

            # Vectorized IoU of the kept box against all remaining boxes.
            x1 = torch.max(current[:, 0], rest[:, 0])
            y1 = torch.max(current[:, 1], rest[:, 1])
            x2 = torch.min(current[:, 2], rest[:, 2])
            y2 = torch.min(current[:, 3], rest[:, 3])
            inter = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
            area_cur = (current[:, 2] - current[:, 0]) * (current[:, 3] - current[:, 1])
            area_rest = (rest[:, 2] - rest[:, 0]) * (rest[:, 3] - rest[:, 1])
            iou = inter / (area_cur + area_rest - inter + 1e-6)

            # Keep only same-class boxes that do not overlap too much.
            survivors = iou < iou_threshold
            cls_boxes = rest[survivors]
            original_indices = original_indices[1:][survivors]

    # torch.stack avoids torch.tensor() on a list of 0-dim tensors, which is
    # a deprecated/fragile construction path.
    return torch.stack(keep_indices).to(torch.int64)