ISR / inference.py
Zhen Ye
added async first frame/video detection
f0b6460
raw
history blame
7.84 kB
import logging
from threading import RLock
from typing import Any, Dict, List, Optional, Sequence, Tuple
import cv2
import numpy as np
from models.model_loader import load_detector
from models.segmenters.model_loader import load_segmenter
from utils.video import extract_frames, write_video
def draw_boxes(frame: np.ndarray, boxes: np.ndarray) -> np.ndarray:
    """Return a copy of *frame* with a green rectangle drawn for each box.

    Each box is an (x1, y1, x2, y2) coordinate row; passing ``None`` for
    *boxes* yields an unmodified copy of the frame.
    """
    annotated = frame.copy()
    if boxes is None:
        return annotated
    for box in boxes:
        x1, y1, x2, y2 = (int(c) for c in box)
        cv2.rectangle(annotated, (x1, y1), (x2, y2), (0, 255, 0), thickness=2)
    return annotated
def draw_masks(frame: np.ndarray, masks: np.ndarray, alpha: float = 0.45) -> np.ndarray:
    """Return a copy of *frame* with each mask alpha-blended in as a colored overlay.

    Masks are cycled through a fixed six-color palette; a mask that does not
    match the frame size is nearest-neighbor resized first. ``None`` or empty
    *masks* yields an unmodified copy.
    """
    blended = frame.copy()
    if masks is None or len(masks) == 0:
        return blended
    palette = [
        (255, 0, 0),
        (0, 255, 0),
        (0, 0, 255),
        (255, 255, 0),
        (0, 255, 255),
        (255, 0, 255),
    ]
    frame_h, frame_w = blended.shape[:2]
    for index, raw_mask in enumerate(masks):
        if raw_mask is None:
            continue
        # Some segmenters emit a leading channel axis; keep the first plane.
        mask = raw_mask[0] if raw_mask.ndim == 3 else raw_mask
        if mask.shape[:2] != (frame_h, frame_w):
            mask = cv2.resize(mask, (frame_w, frame_h), interpolation=cv2.INTER_NEAREST)
        tint = np.zeros_like(blended, dtype=np.uint8)
        tint[mask.astype(bool)] = palette[index % len(palette)]
        blended = cv2.addWeighted(blended, 1.0, tint, alpha, 0)
    return blended
def _build_detection_records(
boxes: np.ndarray,
scores: Sequence[float],
labels: Sequence[int],
queries: Sequence[str],
label_names: Optional[Sequence[str]] = None,
) -> List[Dict[str, Any]]:
detections: List[Dict[str, Any]] = []
for idx, box in enumerate(boxes):
if label_names is not None and idx < len(label_names):
label = label_names[idx]
else:
label_idx = int(labels[idx]) if idx < len(labels) else -1
if 0 <= label_idx < len(queries):
label = queries[label_idx]
else:
label = f"label_{label_idx}"
detections.append(
{
"label": label,
"score": float(scores[idx]) if idx < len(scores) else 0.0,
"bbox": [int(coord) for coord in box.tolist()],
}
)
return detections
_MODEL_LOCKS: Dict[str, RLock] = {}
_MODEL_LOCKS_GUARD = RLock()
def _get_model_lock(kind: str, name: str) -> RLock:
key = f"{kind}:{name}"
with _MODEL_LOCKS_GUARD:
lock = _MODEL_LOCKS.get(key)
if lock is None:
lock = RLock()
_MODEL_LOCKS[key] = lock
return lock
def infer_frame(
    frame: np.ndarray,
    queries: Sequence[str],
    detector_name: Optional[str] = None,
) -> tuple[np.ndarray, List[Dict[str, Any]]]:
    """Run object detection on one frame.

    Returns a (annotated frame, detection records) pair; failures are logged
    with the active queries and re-raised.
    """
    detector = load_detector(detector_name)
    # Fall back to a generic prompt when no queries were supplied.
    text_queries = list(queries) or ["object"]
    try:
        with _get_model_lock("detector", detector.name):
            result = detector.predict(frame, text_queries)
            detections = _build_detection_records(
                result.boxes, result.scores, result.labels, text_queries, result.label_names
            )
    except Exception:
        logging.exception("Inference failed for queries %s", text_queries)
        raise
    return draw_boxes(frame, result.boxes), detections
def infer_segmentation_frame(
    frame: np.ndarray,
    text_queries: Optional[List[str]] = None,
    segmenter_name: Optional[str] = None,
) -> tuple[np.ndarray, Any]:
    """Run text-prompted segmentation on one frame.

    Returns a (mask-overlaid frame, raw segmenter result) pair. The model
    call is serialized through the per-model lock.
    """
    segmenter = load_segmenter(segmenter_name)
    with _get_model_lock("segmenter", segmenter.name):
        seg_result = segmenter.predict(frame, text_prompts=text_queries)
    return draw_masks(frame, seg_result.masks), seg_result
def extract_first_frame(video_path: str) -> Tuple[np.ndarray, float, int, int]:
    """Decode only the first frame of a video.

    Args:
        video_path: Path to the video file.

    Returns:
        Tuple of (first frame, fps, width, height); fps is 0.0 when the
        container does not report one.

    Raises:
        ValueError: If the file cannot be opened or yields no frames.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError("Unable to open video.")
    # try/finally guarantees the capture handle is released even if a
    # property read or frame decode raises (the original leaked it then).
    try:
        fps = cap.get(cv2.CAP_PROP_FPS) or 0.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        success, frame = cap.read()
    finally:
        cap.release()
    if not success or frame is None:
        raise ValueError("Video decode produced zero frames.")
    return frame, fps, width, height
def process_first_frame(
    video_path: str,
    queries: List[str],
    mode: str,
    detector_name: Optional[str] = None,
    segmenter_name: Optional[str] = None,
) -> Tuple[np.ndarray, List[Dict[str, Any]]]:
    """Run detection or segmentation on just the first frame of a video.

    When *mode* is "segmentation" the detections list is always empty;
    otherwise it holds the detector's records for the frame.
    """
    first_frame, _fps, _w, _h = extract_first_frame(video_path)
    if mode == "segmentation":
        overlay, _ = infer_segmentation_frame(
            first_frame, text_queries=queries, segmenter_name=segmenter_name
        )
        return overlay, []
    return infer_frame(first_frame, queries, detector_name=detector_name)
def run_inference(
    input_video_path: str,
    output_video_path: str,
    queries: List[str],
    max_frames: Optional[int] = None,
    detector_name: Optional[str] = None,
) -> str:
    """
    Run object detection inference on a video.
    Args:
        input_video_path: Path to input video
        output_video_path: Path to write processed video
        queries: List of object classes to detect (e.g., ["person", "car"])
        max_frames: Optional frame limit for testing
        detector_name: Detector to use (default: hf_yolov8)
    Returns:
        Path to processed output video
    """
    try:
        frames, fps, width, height = extract_frames(input_video_path)
    except ValueError:  # was `as exc` — binding was never used
        logging.exception("Failed to decode video at %s", input_video_path)
        raise
    # Use provided queries or default to common objects
    if not queries:
        queries = ["person", "car", "truck", "motorcycle", "bicycle", "bus", "train", "airplane"]
        logging.info("No queries provided, using defaults: %s", queries)
    logging.info("Detection queries: %s", queries)
    # Select detector
    active_detector = detector_name or "hf_yolov8"
    logging.info("Using detector: %s", active_detector)
    # Process frames, stopping early when max_frames is set
    processed_frames: List[np.ndarray] = []
    for idx, frame in enumerate(frames):
        if max_frames is not None and idx >= max_frames:
            break
        logging.debug("Processing frame %d", idx)
        processed_frame, _ = infer_frame(frame, queries, detector_name=active_detector)
        processed_frames.append(processed_frame)
    # Write output video
    write_video(processed_frames, output_video_path, fps=fps, width=width, height=height)
    logging.info("Processed video written to: %s", output_video_path)
    return output_video_path
def run_segmentation(
    input_video_path: str,
    output_video_path: str,
    queries: List[str],
    max_frames: Optional[int] = None,
    segmenter_name: Optional[str] = None,
) -> str:
    """
    Run text-prompted segmentation on a video.
    Args:
        input_video_path: Path to input video
        output_video_path: Path to write processed video
        queries: Text prompts to segment
        max_frames: Optional frame limit for testing
        segmenter_name: Segmenter to use (default: sam3)
    Returns:
        Path to processed output video
    """
    try:
        frames, fps, width, height = extract_frames(input_video_path)
    except ValueError:  # was `as exc` — binding was never used
        logging.exception("Failed to decode video at %s", input_video_path)
        raise
    active_segmenter = segmenter_name or "sam3"
    logging.info("Using segmenter: %s with queries: %s", active_segmenter, queries)
    processed_frames: List[np.ndarray] = []
    for idx, frame in enumerate(frames):
        if max_frames is not None and idx >= max_frames:
            break
        logging.debug("Processing frame %d", idx)
        processed_frame, _ = infer_segmentation_frame(frame, text_queries=queries, segmenter_name=active_segmenter)
        processed_frames.append(processed_frame)
    write_video(processed_frames, output_video_path, fps=fps, width=width, height=height)
    logging.info("Segmented video written to: %s", output_video_path)
    return output_video_path