Zhen Ye committed on
Commit
537aca9
·
1 Parent(s): 6c02470

added drone detector

Browse files
Files changed (4) hide show
  1. app.py +7 -13
  2. demo.html +26 -114
  3. models/detectors/drone_yolo.py +70 -0
  4. models/model_loader.py +2 -0
app.py CHANGED
@@ -82,11 +82,12 @@ async def detect_endpoint(
82
  queries: Comma-separated object classes for object_detection mode
83
  detector: Model to use (hf_yolov8, detr_resnet50, grounding_dino)
84
  segmenter: Segmentation model to use (sam3)
 
85
 
86
  Returns:
87
  - For object_detection: Processed video with bounding boxes
88
  - For segmentation: Processed video with masks rendered
89
- - For drone_detection: JSON with "coming_soon" status
90
  """
91
  # Validate mode
92
  if mode not in VALID_MODES:
@@ -142,17 +143,7 @@ async def detect_endpoint(
142
  filename="segmented.mp4",
143
  )
144
 
145
- if mode == "drone_detection":
146
- return JSONResponse(
147
- status_code=200,
148
- content={
149
- "status": "coming_soon",
150
- "message": "Drone detection mode is under development. Stay tuned!",
151
- "mode": "drone_detection"
152
- }
153
- )
154
-
155
- # Handle object detection mode
156
  if video is None:
157
  raise HTTPException(status_code=400, detail="Video file is required.")
158
 
@@ -171,14 +162,17 @@ async def detect_endpoint(
171
 
172
  # Parse queries
173
  query_list = [q.strip() for q in queries.split(",") if q.strip()]
 
 
174
 
175
  # Run inference
176
  try:
 
177
  output_path = run_inference(
178
  input_path,
179
  output_path,
180
  query_list,
181
- detector_name=detector,
182
  )
183
  except ValueError as exc:
184
  logging.exception("Video processing failed.")
 
82
  queries: Comma-separated object classes for object_detection mode
83
  detector: Model to use (hf_yolov8, detr_resnet50, grounding_dino)
84
  segmenter: Segmentation model to use (sam3)
85
+ drone_detection uses the dedicated drone_yolo model.
86
 
87
  Returns:
88
  - For object_detection: Processed video with bounding boxes
89
  - For segmentation: Processed video with masks rendered
90
+ - For drone_detection: Processed video with bounding boxes
91
  """
92
  # Validate mode
93
  if mode not in VALID_MODES:
 
143
  filename="segmented.mp4",
144
  )
145
 
146
+ # Handle object detection or drone detection mode
 
 
 
 
 
 
 
 
 
 
147
  if video is None:
148
  raise HTTPException(status_code=400, detail="Video file is required.")
149
 
 
162
 
163
  # Parse queries
164
  query_list = [q.strip() for q in queries.split(",") if q.strip()]
165
+ if mode == "drone_detection" and not query_list:
166
+ query_list = ["drone"]
167
 
168
  # Run inference
169
  try:
170
+ detector_name = "drone_yolo" if mode == "drone_detection" else detector
171
  output_path = run_inference(
172
  input_path,
173
  output_path,
174
  query_list,
175
+ detector_name=detector_name,
176
  )
177
  except ValueError as exc:
178
  logging.exception("Video processing failed.")
demo.html CHANGED
@@ -276,55 +276,6 @@
276
  display: none;
277
  }
278
 
279
- /* Modal */
280
- .modal {
281
- display: none;
282
- position: fixed;
283
- z-index: 1000;
284
- left: 0;
285
- top: 0;
286
- width: 100%;
287
- height: 100%;
288
- background: rgba(15, 23, 42, 0.5);
289
- align-items: center;
290
- justify-content: center;
291
- }
292
-
293
- .modal.show {
294
- display: flex;
295
- }
296
-
297
- .modal-content {
298
- background: white;
299
- padding: 30px;
300
- border-radius: 12px;
301
- max-width: 500px;
302
- text-align: center;
303
- }
304
-
305
- .modal-content h2 {
306
- margin-bottom: 15px;
307
- color: #333;
308
- }
309
-
310
- .modal-content p {
311
- margin-bottom: 20px;
312
- color: #666;
313
- }
314
-
315
- .modal-btn {
316
- padding: 10px 24px;
317
- background: #1f2933;
318
- color: #f9fafb;
319
- border: none;
320
- border-radius: 6px;
321
- cursor: pointer;
322
- font-size: 1rem;
323
- }
324
-
325
- .modal-btn:hover {
326
- background: #111827;
327
- }
328
  </style>
329
  </head>
330
  <body>
@@ -346,10 +297,9 @@
346
  <div class="mode-title">Segmentation</div>
347
  </label>
348
 
349
- <label class="mode-card disabled">
350
  <input type="radio" name="mode" value="drone_detection">
351
  <div class="mode-title">Drone Detection</div>
352
- <span class="mode-badge">COMING SOON</span>
353
  </label>
354
  </div>
355
  </div>
@@ -391,6 +341,16 @@
391
  </div>
392
  </div>
393
 
 
 
 
 
 
 
 
 
 
 
394
  <!-- Video Upload -->
395
  <div class="section">
396
  <div class="input-group">
@@ -441,15 +401,6 @@
441
  </div>
442
  </div>
443
 
444
- <!-- Coming Soon Modal -->
445
- <div class="modal" id="comingSoonModal">
446
- <div class="modal-content">
447
- <h2>Coming Soon!</h2>
448
- <p id="modalMessage"></p>
449
- <button class="modal-btn" id="modalClose">Got it</button>
450
- </div>
451
- </div>
452
-
453
  <script>
454
  // State
455
  let selectedMode = 'object_detection';
@@ -462,6 +413,7 @@
462
  const queriesHint = document.getElementById('queriesHint');
463
  const detectorSection = document.getElementById('detectorSection');
464
  const segmenterSection = document.getElementById('segmenterSection');
 
465
  const fileInput = document.getElementById('videoFile');
466
  const fileLabel = document.getElementById('fileLabel');
467
  const processBtn = document.getElementById('processBtn');
@@ -470,23 +422,12 @@
470
  const originalVideo = document.getElementById('originalVideo');
471
  const processedVideo = document.getElementById('processedVideo');
472
  const downloadBtn = document.getElementById('downloadBtn');
473
- const modal = document.getElementById('comingSoonModal');
474
- const modalMessage = document.getElementById('modalMessage');
475
- const modalClose = document.getElementById('modalClose');
476
-
477
  // Mode selection handler
478
  modeCards.forEach(card => {
479
  card.addEventListener('click', (e) => {
480
  const input = card.querySelector('input[type="radio"]');
481
  const mode = input.value;
482
 
483
- // Check if disabled
484
- if (card.classList.contains('disabled')) {
485
- e.preventDefault();
486
- showComingSoonModal(mode);
487
- return;
488
- }
489
-
490
  // Update selected state
491
  modeCards.forEach(c => c.classList.remove('selected'));
492
  card.classList.add('selected');
@@ -498,16 +439,19 @@
498
  queriesHint.textContent = 'Example: person, car, dog, bicycle';
499
  detectorSection.classList.remove('hidden');
500
  segmenterSection.classList.add('hidden');
 
501
  } else if (mode === 'segmentation') {
502
  queriesLabel.textContent = 'Objects to Segment (comma-separated)';
503
  queriesHint.textContent = 'Example: person, car, building, tree';
504
  detectorSection.classList.add('hidden');
505
  segmenterSection.classList.remove('hidden');
 
506
  } else if (mode === 'drone_detection') {
507
- queriesLabel.textContent = 'Drone Types to Detect (comma-separated)';
508
- queriesHint.textContent = 'Example: quadcopter, fixed-wing, drone';
509
  detectorSection.classList.add('hidden');
510
  segmenterSection.classList.add('hidden');
 
511
  }
512
 
513
  // Always show queries section
@@ -555,20 +499,17 @@
555
  });
556
 
557
  if (response.ok) {
558
- const contentType = response.headers.get('content-type');
559
-
560
- if (contentType && contentType.includes('application/json')) {
561
- // Coming soon response
562
  const data = await response.json();
563
- showComingSoonModal(data.mode);
564
- } else {
565
- // Video response
566
- const blob = await response.blob();
567
- const videoUrl = URL.createObjectURL(blob);
568
- processedVideo.src = videoUrl;
569
- downloadBtn.href = videoUrl;
570
- resultsSection.classList.remove('hidden');
571
  }
 
 
 
 
 
572
  } else {
573
  const error = await response.json();
574
  alert(`Error: ${error.detail || error.error || 'Processing failed'}`);
@@ -582,35 +523,6 @@
582
  }
583
  });
584
 
585
- // Coming soon modal
586
- function showComingSoonModal(mode) {
587
- const messages = {
588
- 'drone_detection': 'Drone detection mode is under development. Stay tuned for specialized UAV and aerial object detection!'
589
- };
590
- modalMessage.textContent = messages[mode] || 'This feature is coming soon!';
591
- modal.classList.add('show');
592
- }
593
-
594
- modalClose.addEventListener('click', () => {
595
- modal.classList.remove('show');
596
- // Reset to object detection
597
- document.querySelector('input[value="object_detection"]').checked = true;
598
- modeCards.forEach(c => c.classList.remove('selected'));
599
- document.querySelector('input[value="object_detection"]').closest('.mode-card').classList.add('selected');
600
- selectedMode = 'object_detection';
601
- // Update labels for object detection mode
602
- queriesLabel.textContent = 'Objects to Detect (comma-separated)';
603
- queriesHint.textContent = 'Example: person, car, dog, bicycle';
604
- detectorSection.classList.remove('hidden');
605
- segmenterSection.classList.add('hidden');
606
- });
607
-
608
- // Close modal on background click
609
- modal.addEventListener('click', (e) => {
610
- if (e.target === modal) {
611
- modalClose.click();
612
- }
613
- });
614
  </script>
615
  </body>
616
  </html>
 
276
  display: none;
277
  }
278
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
  </style>
280
  </head>
281
  <body>
 
297
  <div class="mode-title">Segmentation</div>
298
  </label>
299
 
300
+ <label class="mode-card">
301
  <input type="radio" name="mode" value="drone_detection">
302
  <div class="mode-title">Drone Detection</div>
 
303
  </label>
304
  </div>
305
  </div>
 
341
  </div>
342
  </div>
343
 
344
+ <!-- Drone Model Selection -->
345
+ <div class="section hidden" id="droneModelSection">
346
+ <div class="input-group">
347
+ <label for="droneModel">2. Select Drone Model</label>
348
+ <select id="droneModel" disabled>
349
+ <option value="drone_yolo">Drone YOLO (HF pretrained)</option>
350
+ </select>
351
+ </div>
352
+ </div>
353
+
354
  <!-- Video Upload -->
355
  <div class="section">
356
  <div class="input-group">
 
401
  </div>
402
  </div>
403
 
 
 
 
 
 
 
 
 
 
404
  <script>
405
  // State
406
  let selectedMode = 'object_detection';
 
413
  const queriesHint = document.getElementById('queriesHint');
414
  const detectorSection = document.getElementById('detectorSection');
415
  const segmenterSection = document.getElementById('segmenterSection');
416
+ const droneModelSection = document.getElementById('droneModelSection');
417
  const fileInput = document.getElementById('videoFile');
418
  const fileLabel = document.getElementById('fileLabel');
419
  const processBtn = document.getElementById('processBtn');
 
422
  const originalVideo = document.getElementById('originalVideo');
423
  const processedVideo = document.getElementById('processedVideo');
424
  const downloadBtn = document.getElementById('downloadBtn');
 
 
 
 
425
  // Mode selection handler
426
  modeCards.forEach(card => {
427
  card.addEventListener('click', (e) => {
428
  const input = card.querySelector('input[type="radio"]');
429
  const mode = input.value;
430
 
 
 
 
 
 
 
 
431
  // Update selected state
432
  modeCards.forEach(c => c.classList.remove('selected'));
433
  card.classList.add('selected');
 
439
  queriesHint.textContent = 'Example: person, car, dog, bicycle';
440
  detectorSection.classList.remove('hidden');
441
  segmenterSection.classList.add('hidden');
442
+ droneModelSection.classList.add('hidden');
443
  } else if (mode === 'segmentation') {
444
  queriesLabel.textContent = 'Objects to Segment (comma-separated)';
445
  queriesHint.textContent = 'Example: person, car, building, tree';
446
  detectorSection.classList.add('hidden');
447
  segmenterSection.classList.remove('hidden');
448
+ droneModelSection.classList.add('hidden');
449
  } else if (mode === 'drone_detection') {
450
+ queriesLabel.textContent = 'Optional Labels (comma-separated)';
451
+ queriesHint.textContent = 'Example: drone, quadcopter';
452
  detectorSection.classList.add('hidden');
453
  segmenterSection.classList.add('hidden');
454
+ droneModelSection.classList.remove('hidden');
455
  }
456
 
457
  // Always show queries section
 
499
  });
500
 
501
  if (response.ok) {
502
+ const contentType = response.headers.get('content-type') || '';
503
+ if (contentType.includes('application/json')) {
 
 
504
  const data = await response.json();
505
+ alert(data.message || 'Request completed.');
506
+ return;
 
 
 
 
 
 
507
  }
508
+ const blob = await response.blob();
509
+ const videoUrl = URL.createObjectURL(blob);
510
+ processedVideo.src = videoUrl;
511
+ downloadBtn.href = videoUrl;
512
+ resultsSection.classList.remove('hidden');
513
  } else {
514
  const error = await response.json();
515
  alert(`Error: ${error.detail || error.error || 'Processing failed'}`);
 
523
  }
524
  });
525
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526
  </script>
527
  </body>
528
  </html>
models/detectors/drone_yolo.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import logging
import os
from typing import List, Sequence

import numpy as np
import torch
from huggingface_hub import hf_hub_download
from ultralytics import YOLO

from models.detectors.base import DetectionResult, ObjectDetector


class DroneYoloDetector(ObjectDetector):
    """Drone detector backed by a YOLO model on the Hugging Face Hub."""

    REPO_ID = "rujutashashikanjoshi/yolo12-drone-detection-0205-100m"
    DEFAULT_WEIGHT = "best.pt"

    def __init__(self, score_threshold: float = 0.3) -> None:
        """Download the checkpoint from the Hub and load it onto GPU if available.

        Args:
            score_threshold: Minimum confidence for a detection to be kept.
        """
        self.name = "drone_yolo"
        self.score_threshold = score_threshold
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        # The checkpoint filename may be overridden via the environment to
        # point at an alternative weight file in the same repository.
        checkpoint = os.getenv("DRONE_YOLO_WEIGHT", self.DEFAULT_WEIGHT)
        logging.info(
            "Loading drone YOLO weights %s/%s onto %s",
            self.REPO_ID,
            checkpoint,
            self.device,
        )
        local_path = hf_hub_download(repo_id=self.REPO_ID, filename=checkpoint)
        self.model = YOLO(local_path)
        self.model.to(self.device)
        # Mapping of class id -> human-readable class name, as stored in the
        # checkpoint.
        self.class_names = self.model.names

    def _filter_indices(self, label_names: Sequence[str], queries: Sequence[str]) -> List[int]:
        """Return indices of detections whose class name matches a query.

        With no queries, or when no query matches any detected class name,
        every detection index is returned (best-effort fallback rather than
        silently dropping all boxes).
        """
        if not queries:
            return list(range(len(label_names)))
        wanted = {q.lower().strip() for q in queries if q}
        matched = [i for i, name in enumerate(label_names) if name.lower() in wanted]
        if matched:
            return matched
        # No query matched the model's label set; keep everything.
        return list(range(len(label_names)))

    def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
        """Run the YOLO model on a single frame and filter detections by *queries*.

        Args:
            frame: Image array (BGR/RGB as expected by ultralytics' predict).
            queries: Optional class-name filters; see ``_filter_indices``.

        Returns:
            A ``DetectionResult`` with xyxy boxes, confidence scores,
            integer class ids, and class names for the kept detections.
        """
        inference_device = 0 if self.device.startswith("cuda") else "cpu"
        outputs = self.model.predict(
            source=frame,
            device=inference_device,
            conf=self.score_threshold,
            verbose=False,
        )
        detections = outputs[0].boxes
        if detections is None or detections.xyxy is None:
            # No detections at all: return an empty, well-shaped result.
            return DetectionResult(np.empty((0, 4), dtype=np.float32), [], [], [])

        coords = detections.xyxy.cpu().numpy()
        confidences = detections.conf.cpu().numpy().tolist()
        class_ids = detections.cls.cpu().numpy().astype(int).tolist()
        names = [self.class_names.get(cid, f"class_{cid}") for cid in class_ids]
        keep = self._filter_indices(names, queries)
        return DetectionResult(
            boxes=coords[keep] if len(coords) else coords,
            scores=[confidences[i] for i in keep],
            labels=[class_ids[i] for i in keep],
            label_names=[names[i] for i in keep],
        )
models/model_loader.py CHANGED
@@ -4,6 +4,7 @@ from typing import Callable, Dict, Optional
4
 
5
  from models.detectors.base import ObjectDetector
6
  from models.detectors.detr import DetrDetector
 
7
  from models.detectors.grounding_dino import GroundingDinoDetector
8
  from models.detectors.yolov8 import HuggingFaceYoloV8Detector
9
 
@@ -13,6 +14,7 @@ _REGISTRY: Dict[str, Callable[[], ObjectDetector]] = {
13
  "hf_yolov8": HuggingFaceYoloV8Detector,
14
  "detr_resnet50": DetrDetector,
15
  "grounding_dino": GroundingDinoDetector,
 
16
  }
17
 
18
 
 
4
 
5
  from models.detectors.base import ObjectDetector
6
  from models.detectors.detr import DetrDetector
7
+ from models.detectors.drone_yolo import DroneYoloDetector
8
  from models.detectors.grounding_dino import GroundingDinoDetector
9
  from models.detectors.yolov8 import HuggingFaceYoloV8Detector
10
 
 
14
  "hf_yolov8": HuggingFaceYoloV8Detector,
15
  "detr_resnet50": DetrDetector,
16
  "grounding_dino": GroundingDinoDetector,
17
+ "drone_yolo": DroneYoloDetector,
18
  }
19
 
20