Spaces:
Sleeping
Sleeping
import logging
import os
from typing import List, Optional, Sequence

import numpy as np
import torch
from ultralytics import YOLO

from models.detectors.base import DetectionResult, ObjectDetector
from utils.tiling import batched_nms, get_slice_bboxes, shift_bboxes, slice_image
class DroneYoloDetector(ObjectDetector):
    """Drone detector backed by a YOLO model on the Hugging Face Hub.

    Frames wider than ``TILE_WIDTH_THRESHOLD`` pixels are processed with
    tiled (sliced) inference so small drones stay detectable: the frame is
    cut into overlapping square tiles, each tile is run through the model
    at full resolution, and the per-tile boxes are shifted back to global
    coordinates and merged with class-aware NMS.
    """

    REPO_ID = "rujutashashikanjoshi/yolo12-drone-detection-0205-100m"

    # Width (px) above which a frame is considered "4K-ish" and tiled.
    TILE_WIDTH_THRESHOLD = 3000
    # Tile edge length; also used as the inference imgsz everywhere.
    TILE_SIZE = 1280
    # Fractional overlap between neighbouring tiles (both axes).
    TILE_OVERLAP = 0.2
    # IoU threshold when merging duplicate detections from overlapping tiles.
    MERGE_IOU_THRESHOLD = 0.4

    supports_batch = True
    max_batch_size = 32

    def __init__(self, score_threshold: float = 0.3, device: Optional[str] = None) -> None:
        """Load the model from the Hub and move it to the target device.

        Args:
            score_threshold: Minimum confidence for a detection to be kept.
            device: Torch device string (e.g. ``"cuda:0"``). Defaults to the
                first CUDA device when available, otherwise CPU.
        """
        self.name = "drone_yolo"
        self.score_threshold = score_threshold
        self.device = device or ("cuda:0" if torch.cuda.is_available() else "cpu")
        logging.info(
            "Loading drone YOLO from HuggingFace Hub: %s onto %s",
            self.REPO_ID,
            self.device,
        )
        # Ultralytics natively resolves "hf://" URIs to Hub checkpoints.
        self.model = YOLO(f"hf://{self.REPO_ID}")
        self.model.to(self.device)
        self.class_names = self.model.names

    @staticmethod
    def _empty_result() -> DetectionResult:
        """Return a DetectionResult carrying zero detections."""
        return DetectionResult(np.empty((0, 4), dtype=np.float32), [], [], [])

    def _filter_indices(self, label_names: Sequence[str], queries: Sequence[str]) -> List[int]:
        """Return indices of detections whose label matches one of *queries*.

        Matching is case-insensitive. With no queries — or no matches — every
        index is kept, so an unknown query degrades to "return everything"
        rather than silently dropping all detections.
        """
        if not queries:
            return list(range(len(label_names)))
        allowed = {query.lower().strip() for query in queries if query}
        keep = [idx for idx, name in enumerate(label_names) if name.lower() in allowed]
        return keep or list(range(len(label_names)))

    def _parse_single_result(self, result, queries: Sequence[str]) -> DetectionResult:
        """Convert one ultralytics result object into a filtered DetectionResult."""
        boxes = result.boxes
        if boxes is None or boxes.xyxy is None:
            return self._empty_result()
        xyxy = boxes.xyxy.cpu().numpy()
        scores = boxes.conf.cpu().numpy().tolist()
        label_ids = boxes.cls.cpu().numpy().astype(int).tolist()
        label_names = [self.class_names.get(idx, f"class_{idx}") for idx in label_ids]
        keep_indices = self._filter_indices(label_names, queries)
        xyxy = xyxy[keep_indices] if len(xyxy) else xyxy
        return DetectionResult(
            boxes=xyxy,
            scores=[scores[i] for i in keep_indices],
            labels=[label_ids[i] for i in keep_indices],
            label_names=[label_names[i] for i in keep_indices],
        )

    def _predict_tiled(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
        """Run tiled inference for a high-resolution frame.

        Slices the frame into overlapping tiles, runs batched inference on
        the tiles, shifts per-tile boxes back to global coordinates, and
        merges overlapping duplicates with class-aware NMS before applying
        the query filter.
        """
        h, w = frame.shape[:2]
        slice_boxes = get_slice_bboxes(
            h, w, self.TILE_SIZE, self.TILE_SIZE, self.TILE_OVERLAP, self.TILE_OVERLAP
        )
        tiles = slice_image(frame, slice_boxes)

        all_boxes: List[List[float]] = []
        all_scores: List[float] = []
        all_labels: List[float] = []
        # Process tiles in chunks of max_batch_size to bound GPU memory use.
        batch_size = self.max_batch_size
        for start in range(0, len(tiles), batch_size):
            batch_tiles = tiles[start : start + batch_size]
            batch_slices = slice_boxes[start : start + batch_size]
            results = self.model.predict(
                source=batch_tiles,
                device=self.device,
                conf=self.score_threshold,
                imgsz=self.TILE_SIZE,  # run tiles at full resolution
                verbose=False,
            )
            for res, slice_coord in zip(results, batch_slices):
                if res.boxes is None:
                    continue
                tile_boxes = res.boxes.xyxy.cpu().numpy().tolist()
                # Shift tile-local boxes into global frame coordinates.
                all_boxes.extend(shift_bboxes(tile_boxes, slice_coord))
                all_scores.extend(res.boxes.conf.cpu().numpy().tolist())
                all_labels.extend(res.boxes.cls.cpu().numpy().tolist())

        if not all_boxes:
            return self._empty_result()

        # Class-aware NMS removes duplicates from overlapping tile regions.
        boxes_t = torch.tensor(all_boxes, device=self.device)
        scores_t = torch.tensor(all_scores, device=self.device)
        labels_t = torch.tensor(all_labels, device=self.device)
        keep = batched_nms(boxes_t, scores_t, labels_t, iou_threshold=self.MERGE_IOU_THRESHOLD)
        final_boxes = boxes_t[keep].cpu().numpy()
        final_scores = scores_t[keep].cpu().tolist()
        final_labels = labels_t[keep].cpu().int().tolist()

        label_names = [self.class_names.get(idx, f"class_{idx}") for idx in final_labels]
        keep_indices = self._filter_indices(label_names, queries)
        if not keep_indices:
            return self._empty_result()
        return DetectionResult(
            boxes=final_boxes[keep_indices],
            scores=[final_scores[i] for i in keep_indices],
            labels=[final_labels[i] for i in keep_indices],
            label_names=[label_names[i] for i in keep_indices],
        )

    def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
        """Detect objects in a single frame, tiling when it is very wide."""
        _, w = frame.shape[:2]
        if w > self.TILE_WIDTH_THRESHOLD:
            return self._predict_tiled(frame, queries)
        results = self.model.predict(
            source=frame,
            device=self.device,
            conf=self.score_threshold,
            imgsz=self.TILE_SIZE,
            verbose=False,
        )
        return self._parse_single_result(results[0], queries)

    def predict_batch(self, frames: Sequence[np.ndarray], queries: Sequence[str]) -> Sequence[DetectionResult]:
        """Detect objects in a batch of frames.

        Frames wide enough to need tiling are routed through tiled inference
        individually; the remaining frames share one batched predict call.
        Output order matches input order. (The original implementation
        inspected only ``frames[0]``, so a mixed batch whose first frame was
        small ran wide frames un-tiled.)
        """
        if not frames:
            return []
        outputs: List[Optional[DetectionResult]] = [None] * len(frames)
        direct_indices: List[int] = []
        for i, frame in enumerate(frames):
            if frame.shape[1] > self.TILE_WIDTH_THRESHOLD:
                outputs[i] = self._predict_tiled(frame, queries)
            else:
                direct_indices.append(i)
        if direct_indices:
            results = self.model.predict(
                source=[frames[i] for i in direct_indices],
                device=self.device,
                conf=self.score_threshold,
                imgsz=self.TILE_SIZE,
                verbose=False,
            )
            for i, res in zip(direct_indices, results):
                outputs[i] = self._parse_single_result(res, queries)
        return outputs