Spaces:
Running
Running
Zhen Ye
committed on
Commit
·
537aca9
1
Parent(s):
6c02470
added drone detector
Browse files
- app.py +7 -13
- demo.html +26 -114
- models/detectors/drone_yolo.py +70 -0
- models/model_loader.py +2 -0
app.py
CHANGED
|
@@ -82,11 +82,12 @@ async def detect_endpoint(
|
|
| 82 |
queries: Comma-separated object classes for object_detection mode
|
| 83 |
detector: Model to use (hf_yolov8, detr_resnet50, grounding_dino)
|
| 84 |
segmenter: Segmentation model to use (sam3)
|
|
|
|
| 85 |
|
| 86 |
Returns:
|
| 87 |
- For object_detection: Processed video with bounding boxes
|
| 88 |
- For segmentation: Processed video with masks rendered
|
| 89 |
-
- For drone_detection:
|
| 90 |
"""
|
| 91 |
# Validate mode
|
| 92 |
if mode not in VALID_MODES:
|
|
@@ -142,17 +143,7 @@ async def detect_endpoint(
|
|
| 142 |
filename="segmented.mp4",
|
| 143 |
)
|
| 144 |
|
| 145 |
-
|
| 146 |
-
return JSONResponse(
|
| 147 |
-
status_code=200,
|
| 148 |
-
content={
|
| 149 |
-
"status": "coming_soon",
|
| 150 |
-
"message": "Drone detection mode is under development. Stay tuned!",
|
| 151 |
-
"mode": "drone_detection"
|
| 152 |
-
}
|
| 153 |
-
)
|
| 154 |
-
|
| 155 |
-
# Handle object detection mode
|
| 156 |
if video is None:
|
| 157 |
raise HTTPException(status_code=400, detail="Video file is required.")
|
| 158 |
|
|
@@ -171,14 +162,17 @@ async def detect_endpoint(
|
|
| 171 |
|
| 172 |
# Parse queries
|
| 173 |
query_list = [q.strip() for q in queries.split(",") if q.strip()]
|
|
|
|
|
|
|
| 174 |
|
| 175 |
# Run inference
|
| 176 |
try:
|
|
|
|
| 177 |
output_path = run_inference(
|
| 178 |
input_path,
|
| 179 |
output_path,
|
| 180 |
query_list,
|
| 181 |
-
detector_name=
|
| 182 |
)
|
| 183 |
except ValueError as exc:
|
| 184 |
logging.exception("Video processing failed.")
|
|
|
|
| 82 |
queries: Comma-separated object classes for object_detection mode
|
| 83 |
detector: Model to use (hf_yolov8, detr_resnet50, grounding_dino)
|
| 84 |
segmenter: Segmentation model to use (sam3)
|
| 85 |
+
drone_detection uses the dedicated drone_yolo model.
|
| 86 |
|
| 87 |
Returns:
|
| 88 |
- For object_detection: Processed video with bounding boxes
|
| 89 |
- For segmentation: Processed video with masks rendered
|
| 90 |
+
- For drone_detection: Processed video with bounding boxes
|
| 91 |
"""
|
| 92 |
# Validate mode
|
| 93 |
if mode not in VALID_MODES:
|
|
|
|
| 143 |
filename="segmented.mp4",
|
| 144 |
)
|
| 145 |
|
| 146 |
+
# Handle object detection or drone detection mode
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
if video is None:
|
| 148 |
raise HTTPException(status_code=400, detail="Video file is required.")
|
| 149 |
|
|
|
|
| 162 |
|
| 163 |
# Parse queries
|
| 164 |
query_list = [q.strip() for q in queries.split(",") if q.strip()]
|
| 165 |
+
if mode == "drone_detection" and not query_list:
|
| 166 |
+
query_list = ["drone"]
|
| 167 |
|
| 168 |
# Run inference
|
| 169 |
try:
|
| 170 |
+
detector_name = "drone_yolo" if mode == "drone_detection" else detector
|
| 171 |
output_path = run_inference(
|
| 172 |
input_path,
|
| 173 |
output_path,
|
| 174 |
query_list,
|
| 175 |
+
detector_name=detector_name,
|
| 176 |
)
|
| 177 |
except ValueError as exc:
|
| 178 |
logging.exception("Video processing failed.")
|
demo.html
CHANGED
|
@@ -276,55 +276,6 @@
|
|
| 276 |
display: none;
|
| 277 |
}
|
| 278 |
|
| 279 |
-
/* Modal */
|
| 280 |
-
.modal {
|
| 281 |
-
display: none;
|
| 282 |
-
position: fixed;
|
| 283 |
-
z-index: 1000;
|
| 284 |
-
left: 0;
|
| 285 |
-
top: 0;
|
| 286 |
-
width: 100%;
|
| 287 |
-
height: 100%;
|
| 288 |
-
background: rgba(15, 23, 42, 0.5);
|
| 289 |
-
align-items: center;
|
| 290 |
-
justify-content: center;
|
| 291 |
-
}
|
| 292 |
-
|
| 293 |
-
.modal.show {
|
| 294 |
-
display: flex;
|
| 295 |
-
}
|
| 296 |
-
|
| 297 |
-
.modal-content {
|
| 298 |
-
background: white;
|
| 299 |
-
padding: 30px;
|
| 300 |
-
border-radius: 12px;
|
| 301 |
-
max-width: 500px;
|
| 302 |
-
text-align: center;
|
| 303 |
-
}
|
| 304 |
-
|
| 305 |
-
.modal-content h2 {
|
| 306 |
-
margin-bottom: 15px;
|
| 307 |
-
color: #333;
|
| 308 |
-
}
|
| 309 |
-
|
| 310 |
-
.modal-content p {
|
| 311 |
-
margin-bottom: 20px;
|
| 312 |
-
color: #666;
|
| 313 |
-
}
|
| 314 |
-
|
| 315 |
-
.modal-btn {
|
| 316 |
-
padding: 10px 24px;
|
| 317 |
-
background: #1f2933;
|
| 318 |
-
color: #f9fafb;
|
| 319 |
-
border: none;
|
| 320 |
-
border-radius: 6px;
|
| 321 |
-
cursor: pointer;
|
| 322 |
-
font-size: 1rem;
|
| 323 |
-
}
|
| 324 |
-
|
| 325 |
-
.modal-btn:hover {
|
| 326 |
-
background: #111827;
|
| 327 |
-
}
|
| 328 |
</style>
|
| 329 |
</head>
|
| 330 |
<body>
|
|
@@ -346,10 +297,9 @@
|
|
| 346 |
<div class="mode-title">Segmentation</div>
|
| 347 |
</label>
|
| 348 |
|
| 349 |
-
<label class="mode-card
|
| 350 |
<input type="radio" name="mode" value="drone_detection">
|
| 351 |
<div class="mode-title">Drone Detection</div>
|
| 352 |
-
<span class="mode-badge">COMING SOON</span>
|
| 353 |
</label>
|
| 354 |
</div>
|
| 355 |
</div>
|
|
@@ -391,6 +341,16 @@
|
|
| 391 |
</div>
|
| 392 |
</div>
|
| 393 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 394 |
<!-- Video Upload -->
|
| 395 |
<div class="section">
|
| 396 |
<div class="input-group">
|
|
@@ -441,15 +401,6 @@
|
|
| 441 |
</div>
|
| 442 |
</div>
|
| 443 |
|
| 444 |
-
<!-- Coming Soon Modal -->
|
| 445 |
-
<div class="modal" id="comingSoonModal">
|
| 446 |
-
<div class="modal-content">
|
| 447 |
-
<h2>Coming Soon!</h2>
|
| 448 |
-
<p id="modalMessage"></p>
|
| 449 |
-
<button class="modal-btn" id="modalClose">Got it</button>
|
| 450 |
-
</div>
|
| 451 |
-
</div>
|
| 452 |
-
|
| 453 |
<script>
|
| 454 |
// State
|
| 455 |
let selectedMode = 'object_detection';
|
|
@@ -462,6 +413,7 @@
|
|
| 462 |
const queriesHint = document.getElementById('queriesHint');
|
| 463 |
const detectorSection = document.getElementById('detectorSection');
|
| 464 |
const segmenterSection = document.getElementById('segmenterSection');
|
|
|
|
| 465 |
const fileInput = document.getElementById('videoFile');
|
| 466 |
const fileLabel = document.getElementById('fileLabel');
|
| 467 |
const processBtn = document.getElementById('processBtn');
|
|
@@ -470,23 +422,12 @@
|
|
| 470 |
const originalVideo = document.getElementById('originalVideo');
|
| 471 |
const processedVideo = document.getElementById('processedVideo');
|
| 472 |
const downloadBtn = document.getElementById('downloadBtn');
|
| 473 |
-
const modal = document.getElementById('comingSoonModal');
|
| 474 |
-
const modalMessage = document.getElementById('modalMessage');
|
| 475 |
-
const modalClose = document.getElementById('modalClose');
|
| 476 |
-
|
| 477 |
// Mode selection handler
|
| 478 |
modeCards.forEach(card => {
|
| 479 |
card.addEventListener('click', (e) => {
|
| 480 |
const input = card.querySelector('input[type="radio"]');
|
| 481 |
const mode = input.value;
|
| 482 |
|
| 483 |
-
// Check if disabled
|
| 484 |
-
if (card.classList.contains('disabled')) {
|
| 485 |
-
e.preventDefault();
|
| 486 |
-
showComingSoonModal(mode);
|
| 487 |
-
return;
|
| 488 |
-
}
|
| 489 |
-
|
| 490 |
// Update selected state
|
| 491 |
modeCards.forEach(c => c.classList.remove('selected'));
|
| 492 |
card.classList.add('selected');
|
|
@@ -498,16 +439,19 @@
|
|
| 498 |
queriesHint.textContent = 'Example: person, car, dog, bicycle';
|
| 499 |
detectorSection.classList.remove('hidden');
|
| 500 |
segmenterSection.classList.add('hidden');
|
|
|
|
| 501 |
} else if (mode === 'segmentation') {
|
| 502 |
queriesLabel.textContent = 'Objects to Segment (comma-separated)';
|
| 503 |
queriesHint.textContent = 'Example: person, car, building, tree';
|
| 504 |
detectorSection.classList.add('hidden');
|
| 505 |
segmenterSection.classList.remove('hidden');
|
|
|
|
| 506 |
} else if (mode === 'drone_detection') {
|
| 507 |
-
queriesLabel.textContent = '
|
| 508 |
-
queriesHint.textContent = 'Example:
|
| 509 |
detectorSection.classList.add('hidden');
|
| 510 |
segmenterSection.classList.add('hidden');
|
|
|
|
| 511 |
}
|
| 512 |
|
| 513 |
// Always show queries section
|
|
@@ -555,20 +499,17 @@
|
|
| 555 |
});
|
| 556 |
|
| 557 |
if (response.ok) {
|
| 558 |
-
const contentType = response.headers.get('content-type');
|
| 559 |
-
|
| 560 |
-
if (contentType && contentType.includes('application/json')) {
|
| 561 |
-
// Coming soon response
|
| 562 |
const data = await response.json();
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
// Video response
|
| 566 |
-
const blob = await response.blob();
|
| 567 |
-
const videoUrl = URL.createObjectURL(blob);
|
| 568 |
-
processedVideo.src = videoUrl;
|
| 569 |
-
downloadBtn.href = videoUrl;
|
| 570 |
-
resultsSection.classList.remove('hidden');
|
| 571 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 572 |
} else {
|
| 573 |
const error = await response.json();
|
| 574 |
alert(`Error: ${error.detail || error.error || 'Processing failed'}`);
|
|
@@ -582,35 +523,6 @@
|
|
| 582 |
}
|
| 583 |
});
|
| 584 |
|
| 585 |
-
// Coming soon modal
|
| 586 |
-
function showComingSoonModal(mode) {
|
| 587 |
-
const messages = {
|
| 588 |
-
'drone_detection': 'Drone detection mode is under development. Stay tuned for specialized UAV and aerial object detection!'
|
| 589 |
-
};
|
| 590 |
-
modalMessage.textContent = messages[mode] || 'This feature is coming soon!';
|
| 591 |
-
modal.classList.add('show');
|
| 592 |
-
}
|
| 593 |
-
|
| 594 |
-
modalClose.addEventListener('click', () => {
|
| 595 |
-
modal.classList.remove('show');
|
| 596 |
-
// Reset to object detection
|
| 597 |
-
document.querySelector('input[value="object_detection"]').checked = true;
|
| 598 |
-
modeCards.forEach(c => c.classList.remove('selected'));
|
| 599 |
-
document.querySelector('input[value="object_detection"]').closest('.mode-card').classList.add('selected');
|
| 600 |
-
selectedMode = 'object_detection';
|
| 601 |
-
// Update labels for object detection mode
|
| 602 |
-
queriesLabel.textContent = 'Objects to Detect (comma-separated)';
|
| 603 |
-
queriesHint.textContent = 'Example: person, car, dog, bicycle';
|
| 604 |
-
detectorSection.classList.remove('hidden');
|
| 605 |
-
segmenterSection.classList.add('hidden');
|
| 606 |
-
});
|
| 607 |
-
|
| 608 |
-
// Close modal on background click
|
| 609 |
-
modal.addEventListener('click', (e) => {
|
| 610 |
-
if (e.target === modal) {
|
| 611 |
-
modalClose.click();
|
| 612 |
-
}
|
| 613 |
-
});
|
| 614 |
</script>
|
| 615 |
</body>
|
| 616 |
</html>
|
|
|
|
| 276 |
display: none;
|
| 277 |
}
|
| 278 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
</style>
|
| 280 |
</head>
|
| 281 |
<body>
|
|
|
|
| 297 |
<div class="mode-title">Segmentation</div>
|
| 298 |
</label>
|
| 299 |
|
| 300 |
+
<label class="mode-card">
|
| 301 |
<input type="radio" name="mode" value="drone_detection">
|
| 302 |
<div class="mode-title">Drone Detection</div>
|
|
|
|
| 303 |
</label>
|
| 304 |
</div>
|
| 305 |
</div>
|
|
|
|
| 341 |
</div>
|
| 342 |
</div>
|
| 343 |
|
| 344 |
+
<!-- Drone Model Selection -->
|
| 345 |
+
<div class="section hidden" id="droneModelSection">
|
| 346 |
+
<div class="input-group">
|
| 347 |
+
<label for="droneModel">2. Select Drone Model</label>
|
| 348 |
+
<select id="droneModel" disabled>
|
| 349 |
+
<option value="drone_yolo">Drone YOLO (HF pretrained)</option>
|
| 350 |
+
</select>
|
| 351 |
+
</div>
|
| 352 |
+
</div>
|
| 353 |
+
|
| 354 |
<!-- Video Upload -->
|
| 355 |
<div class="section">
|
| 356 |
<div class="input-group">
|
|
|
|
| 401 |
</div>
|
| 402 |
</div>
|
| 403 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
<script>
|
| 405 |
// State
|
| 406 |
let selectedMode = 'object_detection';
|
|
|
|
| 413 |
const queriesHint = document.getElementById('queriesHint');
|
| 414 |
const detectorSection = document.getElementById('detectorSection');
|
| 415 |
const segmenterSection = document.getElementById('segmenterSection');
|
| 416 |
+
const droneModelSection = document.getElementById('droneModelSection');
|
| 417 |
const fileInput = document.getElementById('videoFile');
|
| 418 |
const fileLabel = document.getElementById('fileLabel');
|
| 419 |
const processBtn = document.getElementById('processBtn');
|
|
|
|
| 422 |
const originalVideo = document.getElementById('originalVideo');
|
| 423 |
const processedVideo = document.getElementById('processedVideo');
|
| 424 |
const downloadBtn = document.getElementById('downloadBtn');
|
|
|
|
|
|
|
|
|
|
|
|
|
| 425 |
// Mode selection handler
|
| 426 |
modeCards.forEach(card => {
|
| 427 |
card.addEventListener('click', (e) => {
|
| 428 |
const input = card.querySelector('input[type="radio"]');
|
| 429 |
const mode = input.value;
|
| 430 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 431 |
// Update selected state
|
| 432 |
modeCards.forEach(c => c.classList.remove('selected'));
|
| 433 |
card.classList.add('selected');
|
|
|
|
| 439 |
queriesHint.textContent = 'Example: person, car, dog, bicycle';
|
| 440 |
detectorSection.classList.remove('hidden');
|
| 441 |
segmenterSection.classList.add('hidden');
|
| 442 |
+
droneModelSection.classList.add('hidden');
|
| 443 |
} else if (mode === 'segmentation') {
|
| 444 |
queriesLabel.textContent = 'Objects to Segment (comma-separated)';
|
| 445 |
queriesHint.textContent = 'Example: person, car, building, tree';
|
| 446 |
detectorSection.classList.add('hidden');
|
| 447 |
segmenterSection.classList.remove('hidden');
|
| 448 |
+
droneModelSection.classList.add('hidden');
|
| 449 |
} else if (mode === 'drone_detection') {
|
| 450 |
+
queriesLabel.textContent = 'Optional Labels (comma-separated)';
|
| 451 |
+
queriesHint.textContent = 'Example: drone, quadcopter';
|
| 452 |
detectorSection.classList.add('hidden');
|
| 453 |
segmenterSection.classList.add('hidden');
|
| 454 |
+
droneModelSection.classList.remove('hidden');
|
| 455 |
}
|
| 456 |
|
| 457 |
// Always show queries section
|
|
|
|
| 499 |
});
|
| 500 |
|
| 501 |
if (response.ok) {
|
| 502 |
+
const contentType = response.headers.get('content-type') || '';
|
| 503 |
+
if (contentType.includes('application/json')) {
|
|
|
|
|
|
|
| 504 |
const data = await response.json();
|
| 505 |
+
alert(data.message || 'Request completed.');
|
| 506 |
+
return;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 507 |
}
|
| 508 |
+
const blob = await response.blob();
|
| 509 |
+
const videoUrl = URL.createObjectURL(blob);
|
| 510 |
+
processedVideo.src = videoUrl;
|
| 511 |
+
downloadBtn.href = videoUrl;
|
| 512 |
+
resultsSection.classList.remove('hidden');
|
| 513 |
} else {
|
| 514 |
const error = await response.json();
|
| 515 |
alert(`Error: ${error.detail || error.error || 'Processing failed'}`);
|
|
|
|
| 523 |
}
|
| 524 |
});
|
| 525 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 526 |
</script>
|
| 527 |
</body>
|
| 528 |
</html>
|
models/detectors/drone_yolo.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
from typing import List, Sequence
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from huggingface_hub import hf_hub_download
|
| 8 |
+
from ultralytics import YOLO
|
| 9 |
+
|
| 10 |
+
from models.detectors.base import DetectionResult, ObjectDetector
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class DroneYoloDetector(ObjectDetector):
    """Object detector that runs a drone-specific YOLO checkpoint.

    The checkpoint is downloaded from the Hugging Face Hub repository
    ``REPO_ID``. The weight filename defaults to ``DEFAULT_WEIGHT`` and can be
    overridden through the ``DRONE_YOLO_WEIGHT`` environment variable.
    """

    REPO_ID = "rujutashashikanjoshi/yolo12-drone-detection-0205-100m"
    DEFAULT_WEIGHT = "best.pt"

    def __init__(self, score_threshold: float = 0.3) -> None:
        """Download the weights and load the model onto the chosen device.

        Args:
            score_threshold: Confidence value forwarded as ``conf`` to the
                model's ``predict`` call.
        """
        self.name = "drone_yolo"
        self.score_threshold = score_threshold

        # Use the first CUDA device when torch reports one, else the CPU.
        if torch.cuda.is_available():
            self.device = "cuda:0"
        else:
            self.device = "cpu"

        checkpoint_name = os.getenv("DRONE_YOLO_WEIGHT", self.DEFAULT_WEIGHT)
        logging.info(
            "Loading drone YOLO weights %s/%s onto %s",
            self.REPO_ID,
            checkpoint_name,
            self.device,
        )
        checkpoint_path = hf_hub_download(
            repo_id=self.REPO_ID,
            filename=checkpoint_name,
        )
        self.model = YOLO(checkpoint_path)
        self.model.to(self.device)
        # NOTE(review): predict() calls .get() on this, so it is presumably a
        # mapping of class id -> class name; confirm against ultralytics.
        self.class_names = self.model.names

    def _filter_indices(self, label_names: Sequence[str], queries: Sequence[str]) -> List[int]:
        """Return the positions in ``label_names`` that match ``queries``.

        Matching is case-insensitive. Every position is returned when
        ``queries`` is empty, and also when no label matches any query, so
        this method never narrows a non-empty detection list down to nothing.

        Args:
            label_names: Class name of each detection, in detection order.
            queries: Requested class names; falsy entries are ignored.

        Returns:
            Indices into ``label_names`` to keep.
        """
        everything = list(range(len(label_names)))
        if not queries:
            return everything

        wanted = set()
        for query in queries:
            if query:
                wanted.add(query.lower().strip())

        matched = []
        for position, label in enumerate(label_names):
            if label.lower() in wanted:
                matched.append(position)

        # Fall back to all detections rather than returning nothing.
        if matched:
            return matched
        return everything

    def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
        """Run the model on one frame and return the filtered detections.

        Args:
            frame: Image array passed to the model as ``source``.
            queries: Class names used to filter detections (see
                ``_filter_indices``).

        Returns:
            A ``DetectionResult`` holding the kept boxes (xyxy), scores,
            class ids and class names. When the model reports no boxes
            object, the result holds an empty ``(0, 4)`` float32 array and
            empty lists.
        """
        # The model's predict call is given the bare index 0 for any CUDA
        # device string, otherwise the literal "cpu".
        if self.device.startswith("cuda"):
            device_arg = 0
        else:
            device_arg = "cpu"

        outputs = self.model.predict(
            source=frame,
            device=device_arg,
            conf=self.score_threshold,
            verbose=False,
        )
        detections = outputs[0].boxes
        if detections is None or detections.xyxy is None:
            return DetectionResult(np.empty((0, 4), dtype=np.float32), [], [], [])

        # Move tensors to host memory and convert to plain Python/NumPy data.
        coords = detections.xyxy.cpu().numpy()
        confidences = detections.conf.cpu().numpy().tolist()
        class_ids = detections.cls.cpu().numpy().astype(int).tolist()
        class_labels = [
            self.class_names.get(class_id, f"class_{class_id}")
            for class_id in class_ids
        ]

        selected = self._filter_indices(class_labels, queries)
        # Only index the box array when it has rows.
        if len(coords):
            coords = coords[selected]

        kept_scores = []
        kept_ids = []
        kept_labels = []
        for position in selected:
            kept_scores.append(confidences[position])
            kept_ids.append(class_ids[position])
            kept_labels.append(class_labels[position])

        return DetectionResult(
            boxes=coords,
            scores=kept_scores,
            labels=kept_ids,
            label_names=kept_labels,
        )
|
models/model_loader.py
CHANGED
|
@@ -4,6 +4,7 @@ from typing import Callable, Dict, Optional
|
|
| 4 |
|
| 5 |
from models.detectors.base import ObjectDetector
|
| 6 |
from models.detectors.detr import DetrDetector
|
|
|
|
| 7 |
from models.detectors.grounding_dino import GroundingDinoDetector
|
| 8 |
from models.detectors.yolov8 import HuggingFaceYoloV8Detector
|
| 9 |
|
|
@@ -13,6 +14,7 @@ _REGISTRY: Dict[str, Callable[[], ObjectDetector]] = {
|
|
| 13 |
"hf_yolov8": HuggingFaceYoloV8Detector,
|
| 14 |
"detr_resnet50": DetrDetector,
|
| 15 |
"grounding_dino": GroundingDinoDetector,
|
|
|
|
| 16 |
}
|
| 17 |
|
| 18 |
|
|
|
|
| 4 |
|
| 5 |
from models.detectors.base import ObjectDetector
|
| 6 |
from models.detectors.detr import DetrDetector
|
| 7 |
+
from models.detectors.drone_yolo import DroneYoloDetector
|
| 8 |
from models.detectors.grounding_dino import GroundingDinoDetector
|
| 9 |
from models.detectors.yolov8 import HuggingFaceYoloV8Detector
|
| 10 |
|
|
|
|
| 14 |
"hf_yolov8": HuggingFaceYoloV8Detector,
|
| 15 |
"detr_resnet50": DetrDetector,
|
| 16 |
"grounding_dino": GroundingDinoDetector,
|
| 17 |
+
"drone_yolo": DroneYoloDetector,
|
| 18 |
}
|
| 19 |
|
| 20 |
|