# perception2/models/detectors/drone_yolo.py
# Author: Zhen Ye
# fix: Resolve IndentationError in DroneYoloDetector due to displaced import
# Commit: 4e93b33
import logging
import os
from typing import List, Optional, Sequence

import numpy as np
import torch
from ultralytics import YOLO

from models.detectors.base import DetectionResult, ObjectDetector
from utils.tiling import get_slice_bboxes, slice_image, shift_bboxes, batched_nms
class DroneYoloDetector(ObjectDetector):
    """Drone detector backed by a YOLO model on the Hugging Face Hub.

    Frames wider than ``TILING_WIDTH_THRESHOLD`` pixels are processed with
    tiled (sliced) inference and merged via class-aware NMS; smaller frames
    run through a single full-frame prediction pass.
    """

    REPO_ID = "rujutashashikanjoshi/yolo12-drone-detection-0205-100m"
    supports_batch = True
    max_batch_size = 32
    # Width (px) above which a frame is routed through tiled inference
    # (heuristic aimed at ~4K footage).
    TILING_WIDTH_THRESHOLD = 3000
    # Tile geometry heuristic: 1280x1280 tiles with 20% overlap per axis.
    TILE_SIZE = 1280
    TILE_OVERLAP = 0.2
    # IoU threshold used when merging per-tile detections back together.
    NMS_IOU_THRESHOLD = 0.4
    # Inference resolution handed to the YOLO predictor.
    IMGSZ = 1280

    def __init__(self, score_threshold: float = 0.3, device: Optional[str] = None) -> None:
        """Load the detector weights from the Hugging Face Hub.

        Args:
            score_threshold: Minimum confidence for a detection to be kept.
            device: Torch device string (e.g. ``"cuda:0"``). Defaults to the
                first CUDA device when available, otherwise CPU.
        """
        self.name = "drone_yolo"
        self.score_threshold = score_threshold
        # Fall back to auto-detected hardware when no device was requested.
        if device:
            self.device = device
        else:
            self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        logging.info(
            "Loading drone YOLO from HuggingFace Hub: %s onto %s",
            self.REPO_ID,
            self.device,
        )
        # Load directly from HuggingFace Hub using ultralytics native support
        self.model = YOLO(f"hf://{self.REPO_ID}")
        self.model.to(self.device)
        # Mapping of class id -> human-readable name from the checkpoint.
        self.class_names = self.model.names

    def _filter_indices(self, label_names: Sequence[str], queries: Sequence[str]) -> List[int]:
        """Return indices of detections whose label matches any query.

        Matching is case-insensitive and whitespace-insensitive on the query
        side. With no queries — or when nothing matches — every index is
        returned, so filtering never silently drops all detections.
        """
        if not queries:
            return list(range(len(label_names)))
        allowed = {query.lower().strip() for query in queries if query}
        keep = [idx for idx, name in enumerate(label_names) if name.lower() in allowed]
        return keep or list(range(len(label_names)))

    def _parse_single_result(self, result, queries: Sequence[str]) -> DetectionResult:
        """Convert one ultralytics result into a query-filtered DetectionResult."""
        boxes = result.boxes
        if boxes is None or boxes.xyxy is None:
            empty = np.empty((0, 4), dtype=np.float32)
            return DetectionResult(empty, [], [], [])
        xyxy = boxes.xyxy.cpu().numpy()
        scores = boxes.conf.cpu().numpy().tolist()
        label_ids = boxes.cls.cpu().numpy().astype(int).tolist()
        label_names = [self.class_names.get(idx, f"class_{idx}") for idx in label_ids]
        keep_indices = self._filter_indices(label_names, queries)
        # Guard: fancy-indexing an empty array is avoided for clarity.
        xyxy = xyxy[keep_indices] if len(xyxy) else xyxy
        return DetectionResult(
            boxes=xyxy,
            scores=[scores[i] for i in keep_indices],
            labels=[label_ids[i] for i in keep_indices],
            label_names=[label_names[i] for i in keep_indices],
        )

    def _predict_tiled(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
        """Run tiled inference for high-resolution frames.

        The frame is sliced into overlapping tiles, each tile is inferred at
        full resolution, per-tile boxes are shifted back into global frame
        coordinates, and overlapping detections are merged with class-aware
        NMS before query filtering.
        """
        h, w = frame.shape[:2]
        slice_boxes = get_slice_bboxes(
            h, w, self.TILE_SIZE, self.TILE_SIZE, self.TILE_OVERLAP, self.TILE_OVERLAP
        )
        tiles = slice_image(frame, slice_boxes)
        all_boxes: List[List[float]] = []
        all_scores: List[float] = []
        all_labels: List[float] = []
        # Run tiles in bounded batches to respect GPU memory.
        batch_size = self.max_batch_size
        for start in range(0, len(tiles), batch_size):
            batch_tiles = tiles[start : start + batch_size]
            batch_slices = slice_boxes[start : start + batch_size]
            results = self.model.predict(
                source=batch_tiles,
                device=self.device,
                conf=self.score_threshold,
                imgsz=self.IMGSZ,  # Run tiles at full res
                verbose=False,
            )
            for res, slice_coord in zip(results, batch_slices):
                if res.boxes is None:
                    continue
                boxes = res.boxes.xyxy.cpu().numpy().tolist()
                scores = res.boxes.conf.cpu().numpy().tolist()
                clss = res.boxes.cls.cpu().numpy().tolist()
                # Translate tile-local boxes into global frame coordinates.
                all_boxes.extend(shift_bboxes(boxes, slice_coord))
                all_scores.extend(scores)
                all_labels.extend(clss)
        if not all_boxes:
            empty = np.empty((0, 4), dtype=np.float32)
            return DetectionResult(empty, [], [], [])
        # Merge duplicates created by overlapping tiles (class-aware NMS).
        boxes_t = torch.tensor(all_boxes, device=self.device)
        scores_t = torch.tensor(all_scores, device=self.device)
        labels_t = torch.tensor(all_labels, device=self.device)
        keep = batched_nms(boxes_t, scores_t, labels_t, iou_threshold=self.NMS_IOU_THRESHOLD)
        final_boxes = boxes_t[keep].cpu().numpy()
        final_scores = scores_t[keep].cpu().tolist()
        final_labels = labels_t[keep].cpu().int().tolist()
        # Filter by query, then format the merged output.
        label_names = [self.class_names.get(idx, f"class_{idx}") for idx in final_labels]
        keep_indices = self._filter_indices(label_names, queries)
        if not keep_indices:
            empty = np.empty((0, 4), dtype=np.float32)
            return DetectionResult(empty, [], [], [])
        return DetectionResult(
            boxes=final_boxes[keep_indices],
            scores=[final_scores[i] for i in keep_indices],
            labels=[final_labels[i] for i in keep_indices],
            label_names=[label_names[i] for i in keep_indices],
        )

    def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
        """Detect objects in a single frame, tiling wide frames automatically."""
        w = frame.shape[1]
        # Enable tiling for 4K-ish images (width > TILING_WIDTH_THRESHOLD).
        if w > self.TILING_WIDTH_THRESHOLD:
            return self._predict_tiled(frame, queries)
        results = self.model.predict(
            source=frame,
            device=self.device,
            conf=self.score_threshold,
            imgsz=self.IMGSZ,
            verbose=False,
        )
        return self._parse_single_result(results[0], queries)

    def predict_batch(self, frames: Sequence[np.ndarray], queries: Sequence[str]) -> Sequence[DetectionResult]:
        """Detect objects in a batch of frames.

        If ANY frame is wide enough to need tiling, the batch is routed
        per-frame through :meth:`predict` so each frame takes the correct
        path (the previous implementation only inspected ``frames[0]``,
        which mis-routed mixed-resolution batches). Uniform small-frame
        batches run in one batched model call.
        """
        if not frames:
            return []
        if any(f.shape[1] > self.TILING_WIDTH_THRESHOLD for f in frames):
            return [self.predict(f, queries) for f in frames]
        results = self.model.predict(
            source=frames,
            device=self.device,
            conf=self.score_threshold,
            imgsz=self.IMGSZ,
            verbose=False,
        )
        return [self._parse_single_result(r, queries) for r in results]