Zhen Ye Claude Opus 4.6 committed on
Commit
e1fbf50
·
1 Parent(s): 3c61b44

feat: replace drone_yolo with YOLOv8-VisDrone detector

Browse files

Remove broken drone_yolo detector (rujutashashikanjoshi repo not found)
and replace with Mahadih534/YoloV8-VisDrone. Use hf_hub_download for
reliable weight fetching. Allow drone_detection mode to use user-selected
detector instead of hardcoding.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

app.py CHANGED
@@ -264,7 +264,7 @@ async def detect_endpoint(
264
  detector: Model to use (yolo11, detr_resnet50, grounding_dino)
265
  segmenter: Segmentation model to use (GSAM2-S/B/L, YSAM2-S/B/L)
266
  enable_depth: Whether to run legacy depth estimation (default: False)
267
- drone_detection uses the dedicated drone_yolo model.
268
 
269
  Returns:
270
  - For object_detection: Processed video with bounding boxes
@@ -344,7 +344,7 @@ async def detect_endpoint(
344
  os.close(fd)
345
 
346
  # Parse queries with mission awareness
347
- detector_name = "drone_yolo" if mode == "drone_detection" else detector
348
  mission_spec = None
349
 
350
  if queries.strip():
@@ -447,8 +447,8 @@ async def detect_async_endpoint(
447
  detector_name = detector
448
  mission_detector = detector # detector key used for mission query parsing
449
  if mode == "drone_detection":
450
- detector_name = "drone_yolo"
451
- mission_detector = "drone_yolo"
452
  elif mode == "segmentation":
453
  # Segmenter registry owns detector selection (GSAM2→GDINO, YSAM2→YOLO).
454
  # detector_name=None so the job doesn't forward it (avoids duplicate kwarg).
 
264
  detector: Model to use (yolo11, detr_resnet50, grounding_dino)
265
  segmenter: Segmentation model to use (GSAM2-S/B/L, YSAM2-S/B/L)
266
  enable_depth: Whether to run legacy depth estimation (default: False)
267
+ drone_detection uses the dedicated yolov8_visdrone model.
268
 
269
  Returns:
270
  - For object_detection: Processed video with bounding boxes
 
344
  os.close(fd)
345
 
346
  # Parse queries with mission awareness
347
+ detector_name = (detector or "yolov8_visdrone") if mode == "drone_detection" else detector
348
  mission_spec = None
349
 
350
  if queries.strip():
 
447
  detector_name = detector
448
  mission_detector = detector # detector key used for mission query parsing
449
  if mode == "drone_detection":
450
+ detector_name = detector or "yolov8_visdrone"
451
+ mission_detector = detector_name
452
  elif mode == "segmentation":
453
  # Segmenter registry owns detector selection (GSAM2→GDINO, YSAM2→YOLO).
454
  # detector_name=None so the job doesn't forward it (avoids duplicate kwarg).
frontend/index.html CHANGED
@@ -83,7 +83,6 @@
83
  <option value="YSAM2-S" data-kind="segmentation">YSAM2-S (Fast)</option>
84
  </optgroup>
85
  <optgroup label="Drone Detection Models">
86
- <option value="drone_yolo" data-kind="drone">Drone</option>
87
  <option value="yolov8_visdrone" data-kind="drone">VisDrone (YOLOv8)</option>
88
  </optgroup>
89
 
 
83
  <option value="YSAM2-S" data-kind="segmentation">YSAM2-S (Fast)</option>
84
  </optgroup>
85
  <optgroup label="Drone Detection Models">
 
86
  <option value="yolov8_visdrone" data-kind="drone">VisDrone (YOLOv8)</option>
87
  </optgroup>
88
 
models/detectors/drone_yolo.py DELETED
@@ -1,182 +0,0 @@
1
- import logging
2
- from typing import List, Sequence
3
-
4
- import numpy as np
5
- import torch
6
- from ultralytics import YOLO
7
-
8
- from models.detectors.base import DetectionResult, ObjectDetector
9
- from utils.tiling import get_slice_bboxes, slice_image, shift_bboxes, batched_nms
10
-
11
-
12
- class DroneYoloDetector(ObjectDetector):
13
- """Drone detector backed by a YOLO model on the Hugging Face Hub."""
14
-
15
- REPO_ID = "rujutashashikanjoshi/yolo12-drone-detection-0205-100m"
16
- supports_batch = True
17
- max_batch_size = 32
18
-
19
- def __init__(self, score_threshold: float = 0.3, device: str = None) -> None:
20
- self.name = "drone_yolo"
21
- self.score_threshold = score_threshold
22
- # CRITICAL: Store device as torch.device, NOT a string.
23
- # Ultralytics' select_device() sets CUDA_VISIBLE_DEVICES when it
24
- # receives a string like "cuda:0", restricting the entire process to
25
- # one GPU. Passing a torch.device object causes select_device() to
26
- # return immediately without touching the environment.
27
- if device:
28
- self.device = torch.device(device)
29
- else:
30
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
31
- logging.info(
32
- "Loading drone YOLO from HuggingFace Hub: %s onto %s",
33
- self.REPO_ID,
34
- self.device,
35
- )
36
- # Load directly from HuggingFace Hub using ultralytics native support
37
- self.model = YOLO(f"hf://{self.REPO_ID}")
38
- self.model.to(self.device)
39
- self.class_names = self.model.names
40
-
41
- def _filter_indices(self, label_names: Sequence[str], queries: Sequence[str]) -> List[int]:
42
- if not queries:
43
- return list(range(len(label_names)))
44
- allowed = {query.lower().strip() for query in queries if query}
45
- keep = [idx for idx, name in enumerate(label_names) if name.lower() in allowed]
46
- return keep or list(range(len(label_names)))
47
-
48
- def _parse_single_result(self, result, queries: Sequence[str]) -> DetectionResult:
49
- boxes = result.boxes
50
- if boxes is None or boxes.xyxy is None:
51
- empty = np.empty((0, 4), dtype=np.float32)
52
- return DetectionResult(empty, [], [], [])
53
-
54
- xyxy = boxes.xyxy.cpu().numpy()
55
- scores = boxes.conf.cpu().numpy().tolist()
56
- label_ids = boxes.cls.cpu().numpy().astype(int).tolist()
57
- label_names = [self.class_names.get(idx, f"class_{idx}") for idx in label_ids]
58
- keep_indices = self._filter_indices(label_names, queries)
59
- xyxy = xyxy[keep_indices] if len(xyxy) else xyxy
60
- scores = [scores[i] for i in keep_indices]
61
- label_ids = [label_ids[i] for i in keep_indices]
62
- label_names = [label_names[i] for i in keep_indices]
63
- return DetectionResult(
64
- boxes=xyxy,
65
- scores=scores,
66
- labels=label_ids,
67
- label_names=label_names,
68
- )
69
-
70
-
71
- def _predict_tiled(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
72
- """Run tiled inference for high-resolution frames."""
73
- # 1. Slice
74
- h, w = frame.shape[:2]
75
- # Heuristic: 1280x1280 tiles with 20% overlap
76
- slice_boxes = get_slice_bboxes(h, w, 1280, 1280, 0.2, 0.2)
77
- tiles = slice_image(frame, slice_boxes)
78
-
79
- # 2. Batch Inference
80
- # We can use our own model's batch prediction if we can trust it not to recurse strictly
81
- # But we need raw results to merge.
82
- # Actually proper way: run standard predict on tiles.
83
-
84
- all_boxes = []
85
- all_scores = []
86
- all_labels = []
87
-
88
- # Run in batches of max_batch_size to respect GPU memory
89
- batch_size = self.max_batch_size
90
- for i in range(0, len(tiles), batch_size):
91
- batch_tiles = tiles[i : i + batch_size]
92
- batch_slices = slice_boxes[i : i + batch_size]
93
-
94
- results = self.model.predict(
95
- source=batch_tiles,
96
- device=self.device,
97
- conf=self.score_threshold,
98
- imgsz=1280, # Run tiles at full res
99
- verbose=False,
100
- )
101
-
102
- for res, slice_coord in zip(results, batch_slices):
103
- if res.boxes is None: continue
104
- # Extract standard results
105
- boxes = res.boxes.xyxy.cpu().numpy().tolist()
106
- scores = res.boxes.conf.cpu().numpy().tolist()
107
- clss = res.boxes.cls.cpu().numpy().tolist()
108
-
109
- # Shift to global
110
- shifted = shift_bboxes(boxes, slice_coord)
111
-
112
- all_boxes.extend(shifted)
113
- all_scores.extend(scores)
114
- all_labels.extend(clss)
115
-
116
- if not all_boxes:
117
- empty = np.empty((0, 4), dtype=np.float32)
118
- return DetectionResult(empty, [], [], [])
119
-
120
- # 3. NMS Merge
121
- boxes_t = torch.tensor(all_boxes, device=self.device)
122
- scores_t = torch.tensor(all_scores, device=self.device)
123
- labels_t = torch.tensor(all_labels, device=self.device)
124
-
125
- keep = batched_nms(boxes_t, scores_t, labels_t, iou_threshold=0.4)
126
-
127
- final_boxes = boxes_t[keep].cpu().numpy()
128
- final_scores = scores_t[keep].cpu().tolist()
129
- final_labels = labels_t[keep].cpu().int().tolist()
130
-
131
- # 4. Filter & Format
132
- label_names = [self.class_names.get(idx, f"class_{idx}") for idx in final_labels]
133
- keep_indices = self._filter_indices(label_names, queries)
134
-
135
- if not keep_indices:
136
- empty = np.empty((0, 4), dtype=np.float32)
137
- return DetectionResult(empty, [], [], [])
138
-
139
- final_boxes = final_boxes[keep_indices]
140
- final_scores = [final_scores[i] for i in keep_indices]
141
- final_labels = [final_labels[i] for i in keep_indices]
142
- final_names = [label_names[i] for i in keep_indices]
143
-
144
- return DetectionResult(
145
- boxes=final_boxes,
146
- scores=final_scores,
147
- labels=final_labels,
148
- label_names=final_names
149
- )
150
-
151
- def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
152
- h, w = frame.shape[:2]
153
- # Enable tiling for 4Kish images (width > 3000)
154
- if w > 3000:
155
- return self._predict_tiled(frame, queries)
156
-
157
- device_arg = self.device
158
- results = self.model.predict(
159
- source=frame,
160
- device=device_arg,
161
- conf=self.score_threshold,
162
- imgsz=1280,
163
- verbose=False,
164
- )
165
- return self._parse_single_result(results[0], queries)
166
-
167
- def predict_batch(self, frames: Sequence[np.ndarray], queries: Sequence[str]) -> Sequence[DetectionResult]:
168
- # Mixed batch support is hard. Assume batch is uniform size.
169
- if not frames: return []
170
- h, w = frames[0].shape[:2]
171
-
172
- if w > 3000:
173
- return [self._predict_tiled(f, queries) for f in frames]
174
-
175
- results = self.model.predict(
176
- source=frames,
177
- device=self.device,
178
- conf=self.score_threshold,
179
- imgsz=1280,
180
- verbose=False,
181
- )
182
- return [self._parse_single_result(r, queries) for r in results]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/detectors/yolov8_visdrone.py CHANGED
@@ -1,13 +1,20 @@
1
  import logging
 
 
2
  from typing import List, Sequence
3
 
4
  import numpy as np
5
  import torch
 
6
  from ultralytics import YOLO
7
 
8
  from models.detectors.base import DetectionResult, ObjectDetector
9
  from utils.tiling import get_slice_bboxes, slice_image, shift_bboxes, batched_nms
10
 
 
 
 
 
11
 
12
  class YoloV8VisDroneDetector(ObjectDetector):
13
  """YOLOv8 detector fine-tuned on VisDrone dataset for aerial imagery."""
@@ -28,7 +35,14 @@ class YoloV8VisDroneDetector(ObjectDetector):
28
  self.REPO_ID,
29
  self.device,
30
  )
31
- self.model = YOLO(f"hf://{self.REPO_ID}")
 
 
 
 
 
 
 
32
  self.model.to(self.device)
33
  self.class_names = self.model.names
34
 
 
1
  import logging
2
+ import os
3
+ from pathlib import Path
4
  from typing import List, Sequence
5
 
6
  import numpy as np
7
  import torch
8
+ from huggingface_hub import hf_hub_download
9
  from ultralytics import YOLO
10
 
11
  from models.detectors.base import DetectionResult, ObjectDetector
12
  from utils.tiling import get_slice_bboxes, slice_image, shift_bboxes, batched_nms
13
 
14
+ _WEIGHTS_CACHE = Path(os.environ.get("YOLO_CACHE", "/tmp/yolo_weights"))
15
+ _WEIGHTS_CACHE.mkdir(parents=True, exist_ok=True)
16
+ _VISDRONE_PATH = _WEIGHTS_CACHE / "visDrone.pt"
17
+
18
 
19
  class YoloV8VisDroneDetector(ObjectDetector):
20
  """YOLOv8 detector fine-tuned on VisDrone dataset for aerial imagery."""
 
35
  self.REPO_ID,
36
  self.device,
37
  )
38
+ if not _VISDRONE_PATH.exists():
39
+ logging.info("Downloading visDrone.pt to %s ...", _VISDRONE_PATH)
40
+ hf_hub_download(
41
+ repo_id=self.REPO_ID,
42
+ filename="visDrone.pt",
43
+ local_dir=str(_WEIGHTS_CACHE),
44
+ )
45
+ self.model = YOLO(str(_VISDRONE_PATH))
46
  self.model.to(self.device)
47
  self.class_names = self.model.names
48
 
models/model_loader.py CHANGED
@@ -4,7 +4,6 @@ from typing import Callable, Dict, Optional
4
 
5
  from models.detectors.base import ObjectDetector
6
  from models.detectors.detr import DetrDetector
7
- from models.detectors.drone_yolo import DroneYoloDetector
8
  from models.detectors.grounding_dino import GroundingDinoDetector
9
  from models.detectors.yolov11 import Yolo11Detector
10
  from models.detectors.yolov8_visdrone import YoloV8VisDroneDetector
@@ -16,7 +15,6 @@ _REGISTRY: Dict[str, Callable[[], ObjectDetector]] = {
16
  "yolo11": Yolo11Detector,
17
  "detr_resnet50": DetrDetector,
18
  "grounding_dino": GroundingDinoDetector,
19
- "drone_yolo": DroneYoloDetector,
20
  "yolov8_visdrone": YoloV8VisDroneDetector,
21
  }
22
 
 
4
 
5
  from models.detectors.base import ObjectDetector
6
  from models.detectors.detr import DetrDetector
 
7
  from models.detectors.grounding_dino import GroundingDinoDetector
8
  from models.detectors.yolov11 import Yolo11Detector
9
  from models.detectors.yolov8_visdrone import YoloV8VisDroneDetector
 
15
  "yolo11": Yolo11Detector,
16
  "detr_resnet50": DetrDetector,
17
  "grounding_dino": GroundingDinoDetector,
 
18
  "yolov8_visdrone": YoloV8VisDroneDetector,
19
  }
20
 
utils/profiler.py CHANGED
@@ -20,7 +20,7 @@ logger = logging.getLogger(__name__)
20
  # Detectors whose predict() can be decomposed into processor -> model -> post_process
21
  _DECOMPOSABLE_DETECTORS = {"detr_resnet50", "grounding_dino"}
22
  # Detectors with opaque predict() calls (YOLO-based)
23
- _OPAQUE_DETECTORS = {"yolo11", "drone_yolo"}
24
 
25
 
26
  @dataclass
 
20
  # Detectors whose predict() can be decomposed into processor -> model -> post_process
21
  _DECOMPOSABLE_DETECTORS = {"detr_resnet50", "grounding_dino"}
22
  # Detectors with opaque predict() calls (YOLO-based)
23
+ _OPAQUE_DETECTORS = {"yolo11", "yolov8_visdrone"}
24
 
25
 
26
  @dataclass
utils/roofline.py CHANGED
@@ -18,7 +18,7 @@ _MODEL_FLOPS: Dict[str, float] = {
18
  "yolo11": 78.9, # YOLO11m ~79 GFLOPs at 640px
19
  "detr_resnet50": 86.0, # DETR-R50 ~86 GFLOPs at 800px
20
  "grounding_dino": 172.0, # Grounding DINO-B ~172 GFLOPs
21
- "drone_yolo": 78.9, # Same arch as YOLO11m-class model
22
 
23
  # Segmentation models (GFLOPs per keyframe)
24
  "GSAM2-S": 48.0, # SAM2 small encoder
@@ -37,7 +37,7 @@ _MODEL_BYTES: Dict[str, float] = {
37
  "yolo11": 52.0,
38
  "detr_resnet50": 166.0,
39
  "grounding_dino": 340.0,
40
- "drone_yolo": 52.0,
41
  "GSAM2-S": 92.0,
42
  "GSAM2-B": 180.0,
43
  "GSAM2-L": 400.0,
 
18
  "yolo11": 78.9, # YOLO11m ~79 GFLOPs at 640px
19
  "detr_resnet50": 86.0, # DETR-R50 ~86 GFLOPs at 800px
20
  "grounding_dino": 172.0, # Grounding DINO-B ~172 GFLOPs
21
+ "yolov8_visdrone": 78.9, # YOLOv8 VisDrone model
22
 
23
  # Segmentation models (GFLOPs per keyframe)
24
  "GSAM2-S": 48.0, # SAM2 small encoder
 
37
  "yolo11": 52.0,
38
  "detr_resnet50": 166.0,
39
  "grounding_dino": 340.0,
40
+ "yolov8_visdrone": 52.0,
41
  "GSAM2-S": 92.0,
42
  "GSAM2-B": 180.0,
43
  "GSAM2-L": 400.0,