Spaces:
Sleeping
Sleeping
| import logging | |
| from threading import RLock | |
| from typing import Any, Dict, List, Optional, Sequence, Tuple | |
| import cv2 | |
| import numpy as np | |
| from models.model_loader import load_detector | |
| from models.segmenters.model_loader import load_segmenter | |
| from utils.video import extract_frames, write_video | |
def draw_boxes(frame: np.ndarray, boxes: np.ndarray) -> np.ndarray:
    """Return a copy of *frame* with each box outlined in green.

    Args:
        frame: BGR image array.
        boxes: Iterable of (x1, y1, x2, y2) coordinates, or ``None``.

    Returns:
        Annotated copy of the frame; the input frame is never mutated.
    """
    annotated = frame.copy()
    if boxes is None:
        return annotated
    for box in boxes:
        x1, y1, x2, y2 = (int(coord) for coord in box)
        cv2.rectangle(annotated, (x1, y1), (x2, y2), (0, 255, 0), thickness=2)
    return annotated
def draw_masks(frame: np.ndarray, masks: np.ndarray, alpha: float = 0.45) -> np.ndarray:
    """Blend each binary mask onto a copy of *frame* with a cycling color palette.

    Args:
        frame: BGR image array.
        masks: Sequence of 2-D (or leading-channel 3-D) masks, or ``None``.
        alpha: Blend weight applied to each mask overlay.

    Returns:
        Annotated copy of the frame; the input frame is never mutated.
    """
    canvas = frame.copy()
    if masks is None or len(masks) == 0:
        return canvas
    palette = [
        (255, 0, 0),
        (0, 255, 0),
        (0, 0, 255),
        (255, 255, 0),
        (0, 255, 255),
        (255, 0, 255),
    ]
    for i, mask in enumerate(masks):
        if mask is None:
            continue
        # Some segmenters emit a (1, H, W) mask; keep only the spatial plane.
        if mask.ndim == 3:
            mask = mask[0]
        # Resize with nearest-neighbor so the mask stays binary.
        if mask.shape[:2] != canvas.shape[:2]:
            mask = cv2.resize(mask, (canvas.shape[1], canvas.shape[0]), interpolation=cv2.INTER_NEAREST)
        selected = mask.astype(bool)
        tint = np.zeros_like(canvas, dtype=np.uint8)
        tint[selected] = palette[i % len(palette)]
        canvas = cv2.addWeighted(canvas, 1.0, tint, alpha, 0)
    return canvas
| def _build_detection_records( | |
| boxes: np.ndarray, | |
| scores: Sequence[float], | |
| labels: Sequence[int], | |
| queries: Sequence[str], | |
| label_names: Optional[Sequence[str]] = None, | |
| ) -> List[Dict[str, Any]]: | |
| detections: List[Dict[str, Any]] = [] | |
| for idx, box in enumerate(boxes): | |
| if label_names is not None and idx < len(label_names): | |
| label = label_names[idx] | |
| else: | |
| label_idx = int(labels[idx]) if idx < len(labels) else -1 | |
| if 0 <= label_idx < len(queries): | |
| label = queries[label_idx] | |
| else: | |
| label = f"label_{label_idx}" | |
| detections.append( | |
| { | |
| "label": label, | |
| "score": float(scores[idx]) if idx < len(scores) else 0.0, | |
| "bbox": [int(coord) for coord in box.tolist()], | |
| } | |
| ) | |
| return detections | |
| _MODEL_LOCKS: Dict[str, RLock] = {} | |
| _MODEL_LOCKS_GUARD = RLock() | |
| def _get_model_lock(kind: str, name: str) -> RLock: | |
| key = f"{kind}:{name}" | |
| with _MODEL_LOCKS_GUARD: | |
| lock = _MODEL_LOCKS.get(key) | |
| if lock is None: | |
| lock = RLock() | |
| _MODEL_LOCKS[key] = lock | |
| return lock | |
def infer_frame(
    frame: np.ndarray,
    queries: Sequence[str],
    detector_name: Optional[str] = None,
) -> tuple[np.ndarray, List[Dict[str, Any]]]:
    """Run object detection on one frame.

    Args:
        frame: BGR image array.
        queries: Text queries for the detector; falls back to ["object"] if empty.
        detector_name: Optional detector identifier passed to the loader.

    Returns:
        Tuple of (frame annotated with boxes, list of detection dicts).

    Raises:
        Exception: Re-raises any inference failure after logging it.
    """
    detector = load_detector(detector_name)
    text_queries = list(queries) or ["object"]
    try:
        # Serialize access to the shared detector instance.
        with _get_model_lock("detector", detector.name):
            result = detector.predict(frame, text_queries)
        detections = _build_detection_records(
            result.boxes, result.scores, result.labels, text_queries, result.label_names
        )
    except Exception:
        logging.exception("Inference failed for queries %s", text_queries)
        raise
    return draw_boxes(frame, result.boxes), detections
def infer_segmentation_frame(
    frame: np.ndarray,
    text_queries: Optional[List[str]] = None,
    segmenter_name: Optional[str] = None,
) -> tuple[np.ndarray, Any]:
    """Run segmentation on one frame.

    Args:
        frame: BGR image array.
        text_queries: Optional text prompts forwarded to the segmenter.
        segmenter_name: Optional segmenter identifier passed to the loader.

    Returns:
        Tuple of (frame with masks blended in, raw segmenter result).
    """
    segmenter = load_segmenter(segmenter_name)
    # Serialize access to the shared segmenter instance.
    with _get_model_lock("segmenter", segmenter.name):
        result = segmenter.predict(frame, text_prompts=text_queries)
    return draw_masks(frame, result.masks), result
def extract_first_frame(video_path: str) -> Tuple[np.ndarray, float, int, int]:
    """Decode the first frame of a video along with basic stream properties.

    Args:
        video_path: Path to a video file readable by OpenCV.

    Returns:
        Tuple of (first frame, fps, width, height). fps may be 0.0 when the
        container does not report a frame rate.

    Raises:
        ValueError: If the file cannot be opened or yields no frames.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError("Unable to open video.")
    # Fix: release the capture even if a property read or decode raises,
    # so the underlying file/codec handle is never leaked.
    try:
        fps = cap.get(cv2.CAP_PROP_FPS) or 0.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        success, frame = cap.read()
    finally:
        cap.release()
    if not success or frame is None:
        raise ValueError("Video decode produced zero frames.")
    return frame, fps, width, height
def process_first_frame(
    video_path: str,
    queries: List[str],
    mode: str,
    detector_name: Optional[str] = None,
    segmenter_name: Optional[str] = None,
) -> Tuple[np.ndarray, List[Dict[str, Any]]]:
    """Run detection or segmentation on only the first frame of a video.

    Args:
        video_path: Path to the input video.
        queries: Text queries for the chosen model.
        mode: "segmentation" selects the segmenter; anything else the detector.
        detector_name: Optional detector identifier.
        segmenter_name: Optional segmenter identifier.

    Returns:
        Tuple of (annotated frame, detections). Detections are always empty
        in segmentation mode.
    """
    first_frame, _, _, _ = extract_first_frame(video_path)
    if mode != "segmentation":
        return infer_frame(first_frame, queries, detector_name=detector_name)
    annotated, _ = infer_segmentation_frame(
        first_frame, text_queries=queries, segmenter_name=segmenter_name
    )
    return annotated, []
def run_inference(
    input_video_path: str,
    output_video_path: str,
    queries: List[str],
    max_frames: Optional[int] = None,
    detector_name: Optional[str] = None,
) -> str:
    """
    Run object detection inference on a video.

    Args:
        input_video_path: Path to input video
        output_video_path: Path to write processed video
        queries: List of object classes to detect (e.g., ["person", "car"])
        max_frames: Optional frame limit for testing
        detector_name: Detector to use (default: hf_yolov8)

    Returns:
        Path to processed output video

    Raises:
        ValueError: If the input video cannot be decoded.
    """
    try:
        frames, fps, width, height = extract_frames(input_video_path)
    except ValueError:
        # Fix: the bound exception variable was unused; logging.exception
        # already records the active traceback before we re-raise.
        logging.exception("Failed to decode video at %s", input_video_path)
        raise

    # Use provided queries or default to common objects
    if not queries:
        queries = ["person", "car", "truck", "motorcycle", "bicycle", "bus", "train", "airplane"]
        logging.info("No queries provided, using defaults: %s", queries)
    logging.info("Detection queries: %s", queries)

    # Select detector
    active_detector = detector_name or "hf_yolov8"
    logging.info("Using detector: %s", active_detector)

    # Annotate each frame, honoring the optional frame cap
    processed_frames: List[np.ndarray] = []
    for idx, frame in enumerate(frames):
        if max_frames is not None and idx >= max_frames:
            break
        logging.debug("Processing frame %d", idx)
        processed_frame, _ = infer_frame(frame, queries, detector_name=active_detector)
        processed_frames.append(processed_frame)

    # Write output video
    write_video(processed_frames, output_video_path, fps=fps, width=width, height=height)
    logging.info("Processed video written to: %s", output_video_path)
    return output_video_path
def run_segmentation(
    input_video_path: str,
    output_video_path: str,
    queries: List[str],
    max_frames: Optional[int] = None,
    segmenter_name: Optional[str] = None,
) -> str:
    """
    Run text-prompted segmentation on a video.

    Args:
        input_video_path: Path to input video
        output_video_path: Path to write processed video
        queries: Text prompts forwarded to the segmenter
        max_frames: Optional frame limit for testing
        segmenter_name: Segmenter to use (default: sam3)

    Returns:
        Path to processed output video

    Raises:
        ValueError: If the input video cannot be decoded.
    """
    try:
        frames, fps, width, height = extract_frames(input_video_path)
    except ValueError:
        # Fix: the bound exception variable was unused; logging.exception
        # already records the active traceback before we re-raise.
        logging.exception("Failed to decode video at %s", input_video_path)
        raise

    active_segmenter = segmenter_name or "sam3"
    logging.info("Using segmenter: %s with queries: %s", active_segmenter, queries)

    # Blend masks onto each frame, honoring the optional frame cap
    processed_frames: List[np.ndarray] = []
    for idx, frame in enumerate(frames):
        if max_frames is not None and idx >= max_frames:
            break
        logging.debug("Processing frame %d", idx)
        processed_frame, _ = infer_segmentation_frame(frame, text_queries=queries, segmenter_name=active_segmenter)
        processed_frames.append(processed_frame)

    write_video(processed_frames, output_video_path, fps=fps, width=width, height=height)
    logging.info("Segmented video written to: %s", output_video_path)
    return output_video_path