# DepthLens — src/models/detector.py
# Author: Rishabh Jain
# Initial upload — depth-aware scene description system (commit 5412d82)
"""
Object detection using YOLOv8n.
Wraps the ultralytics YOLO interface and returns detections in the format
expected by build_depth_context: (boxes, classes, confidences).
"""
import numpy as np
import torch
from ultralytics import YOLO
from ..config import CONF_THRESHOLD, YOLO_MODEL
class ObjectDetector:
    """YOLOv8n object detector.

    Downloads ``yolov8n.pt`` on first use (cached by ultralytics in
    ``~/.cache/ultralytics/``). Subsequent loads use the cached weights.
    """

    def __init__(self) -> None:
        """Load YOLOv8n onto the available device (CUDA if present, else CPU)."""
        print("Loading YOLOv8n...")
        self.model = YOLO(YOLO_MODEL)
        # YOLO() always loads onto CPU; .to() moves the underlying PyTorch
        # model in-place so inference can run on the GPU when available.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        if self.device == "cuda":
            print(
                f" GPU memory allocated: "
                f"{torch.cuda.memory_allocated() / 1024**2:.0f} MB"
            )

    def detect(
        self, image: np.ndarray
    ) -> tuple[np.ndarray, list[str], list[float]]:
        """Run detection on an RGB image.

        Args:
            image: uint8 RGB numpy array of shape (H, W, 3).

        Returns:
            boxes: float32 array of shape (N, 4) as [x1, y1, x2, y2]
                in pixel coordinates.
            classes: List of N class-name strings.
            confidences: List of N confidence floats in [0, 1].
        """
        # ultralytics treats raw numpy input as BGR (OpenCV convention) and
        # performs its own BGR->RGB flip internally before inference, so an
        # RGB input must be converted to BGR here or the channels end up
        # swapped. The copy via ascontiguousarray matters: image[..., ::-1]
        # alone is a negative-stride view, which the OpenCV-backed
        # preprocessing inside ultralytics can reject.
        bgr = np.ascontiguousarray(image[..., ::-1])
        with torch.inference_mode():
            results = self.model(
                bgr,
                conf=CONF_THRESHOLD,
                verbose=False,
                device=self.device,
            )

        result = results[0]
        det = result.boxes
        if det is None or len(det) == 0:
            # No detections above threshold: return empty but correctly
            # shaped/typed outputs so callers can iterate unconditionally.
            return np.empty((0, 4), dtype=np.float32), [], []

        boxes = det.xyxy.cpu().numpy().astype(np.float32)       # (N, 4)
        confidences = det.conf.cpu().numpy().tolist()           # (N,)
        class_ids = det.cls.cpu().numpy().astype(int).tolist()  # (N,)
        classes = [result.names[cid] for cid in class_ids]
        return boxes, classes, confidences