| """Hugging Face Spaces Gradio demo for ROI-VAE compression. |
| |
| This application provides both a web UI and programmatic API for: |
| - ROI-based image and video compression using TIC VAE |
| - Segmentation (SAM3, YOLO, SegFormer, Mask2Former, MaskRCNN) |
| - Object detection (YOLO, DETR, Grounding DINO, YOLO-World, etc.) |
| |
| Web UI Tabs: |
| Image Tab: |
| - Upload image → Choose ROI method → Set mission/prompt |
| - Find ROI: extracts regions of interest via segmentation |
| - Transmit: compresses with ROI preservation (σ controls background) |
| - Optional: ROI highlight overlay, detection overlays |
| |
| Video Tab: |
| - Upload video → Choose ROI method → Set mission/prompt |
| - Find ROI: segments all frames, saves masks for reuse |
| - Transmit: compresses with cached masks (3-5x faster on repeat) |
| - Optional: detection overlays on compressed output |
| |
| API Endpoints (see API.md for full documentation): |
| Image: |
| /segment - Segment image → mask or overlay |
| /compress - Compress image with optional ROI mask |
| /detect - Run object detection → JSON or overlay |
| /process - Full pipeline: segment → compress → detect |
| |
| Video (Buffered): |
| /segment_video - Segment video → mask file or overlay video |
| /compress_video - Compress video with optional cached masks |
| /detect_video - Run detection on video → JSON or overlay video |
| /process_video - Full pipeline with static/dynamic modes |
| |
| """ |
|
|
| from __future__ import annotations |
|
|
| import asyncio |
| import json |
| import os |
| import re |
| import sys |
| from dataclasses import dataclass |
| from functools import lru_cache |
| from pathlib import Path |
| from typing import List, Optional, Sequence |
|
|
| import numpy as np |
| import gradio as gr |
|
|
| |
# Python 3.13 surfaces more ResourceWarnings (e.g. unclosed file descriptors);
# silence them best-effort so logs stay readable. Applies to subprocesses too
# via PYTHONWARNINGS (setdefault: an operator-provided value wins).
if sys.version_info >= (3, 13):

    os.environ.setdefault('PYTHONWARNINGS', 'ignore::ResourceWarning')
    try:

        import warnings
        warnings.filterwarnings("ignore", category=ResourceWarning)
        warnings.filterwarnings("ignore", message=".*file descriptor.*")
    except Exception:
        # Warning configuration must never prevent startup.
        pass
|
|
try:
    import spaces
except Exception:
    # The Hugging Face `spaces` package only exists on Spaces infrastructure.
    # For local/dev runs, substitute an object whose GPU decorator is a no-op
    # so `@spaces.GPU` and `@spaces.GPU(...)` both keep working.
    class _SpacesFallback:
        @staticmethod
        def GPU(*args, **kwargs):
            # Bare form: `@spaces.GPU` passes the function itself.
            if not kwargs and len(args) == 1 and callable(args[0]):
                return args[0]

            # Parameterized form: `@spaces.GPU(duration=...)` must return a
            # decorator; ours leaves the function untouched.
            def _passthrough(fn):
                return fn

            return _passthrough

    spaces = _SpacesFallback()
|
|
| import vae |
| from model_cache import ensure_default_checkpoint_dirs |
|
|
|
|
| def _default_device() -> str: |
| try: |
| import torch |
|
|
| return "cuda" if torch.cuda.is_available() else "cpu" |
| except Exception: |
| return "cpu" |
|
|
|
|
# TIC compression checkpoints ordered from lowest to highest bitrate.
# Tuple fields: (display name, training lambda, checkpoint path, N, M)
# where N/M are model width parameters forwarded to _get_compression_model.
# Quality levels 1-5 map to indices 0-4 of this list (see api_compress).
CHECKPOINTS = [

    ("Smallest file", 0.0035, "checkpoints/tic_lambda_0.0035.pth.tar", 128, 192),
    ("Smaller file", 0.013, "checkpoints/tic_lambda_0.013.pth.tar", 128, 192),
    ("Balanced", 0.025, "checkpoints/tic_lambda_0.025.pth.tar", 192, 192),
    ("Higher quality", 0.0483, "checkpoints/tic_lambda_0.0483.pth.tar", 192, 192),
    ("Best quality", 0.0932, "checkpoints/tic_lambda_0.0932.pth.tar", 192, 192),
]


# Open-vocabulary detectors: these require caller-supplied class prompts and
# are rejected/short-circuited elsewhere when none are given.
OPEN_VOCAB_DETECTORS = {"yolo_world", "grounding_dino"}
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| def _split_classes(classes_str: str) -> List[str]: |
| """Split comma-separated class string into list.""" |
| if not classes_str: |
| return [] |
| parts = [c.strip() for c in classes_str.split(",")] |
| return [c for c in parts if c] |
|
|
|
|
@spaces.GPU
def api_segment(
    image,
    prompt: str = "object",
    method: str = "sam3",
    return_overlay: bool = False,
) -> tuple:
    """
    Segment an image to find Regions of Interest (ROI).

    Args:
        image: Input image (PIL Image)
        prompt: Text prompt or comma-separated classes (e.g., "person, car")
        method: Segmentation method (sam3, yolo, segformer, mask2former, maskrcnn)
        return_overlay: If True, returns image with ROI highlighted instead of mask

    Returns:
        tuple: (result_image, roi_coverage, classes_used)
            - result_image: Grayscale mask OR image with ROI overlay (if return_overlay=True)
            - roi_coverage: Fraction of image covered by ROI (0.0-1.0)
            - classes_used: JSON string of classes/prompts used
    """
    import json

    from PIL import Image

    if image is None:
        return None, 0.0, "[]"

    ensure_default_checkpoint_dirs()
    device = _default_device()
    segmenter = _get_segmenter(method, device)

    # Resolve target classes: SAM3 takes free-text prompts, so an empty prompt
    # falls back to the generic "object"; fixed-vocabulary segmenters fall
    # back to their own default class list instead.
    requested = _split_classes(prompt)
    if method == "sam3":
        targets = requested or ["object"]
    else:
        targets = requested or segmenter.get_default_classes()

    mask = segmenter(image, target_classes=targets).astype(np.float32)
    roi_coverage = float(mask.mean())

    if return_overlay:
        # Tint the ROI green on top of the original image.
        result_image = vae.highlight_roi(image, mask, alpha=0.35, color=(0, 255, 0))
    else:
        # Return the mask itself as an 8-bit grayscale image (white = ROI).
        result_image = Image.fromarray((mask * 255).astype(np.uint8), mode="L")

    return result_image, roi_coverage, json.dumps(targets)
|
|
|
|
@spaces.GPU
def api_compress(
    image,
    mask_image = None,
    quality: int = 4,
    sigma: float = 0.3,
) -> tuple:
    """
    Compress an image with optional ROI preservation.

    Args:
        image: Input image (PIL Image)
        mask_image: Optional ROI mask (grayscale image, white=ROI)
        quality: Quality level 1-5 (1=smallest file, 5=best quality)
        sigma: Background preservation 0.01-1.0 (lower=more compression)

    Returns:
        tuple: (compressed_image, bpp, compression_ratio)
            - compressed_image: Reconstructed image after compression
            - bpp: Bits per pixel
            - compression_ratio: Compression ratio (24/bpp)
    """
    if image is None:
        return None, 0.0, 0.0

    ensure_default_checkpoint_dirs()
    device = _default_device()

    # Map the 1-5 quality level onto a checkpoint index, clamping rather than
    # raising on out-of-range values.
    ckpt_idx = min(max(int(quality) - 1, 0), len(CHECKPOINTS) - 1)
    _name, _lambda, ckpt_rel, N, M = CHECKPOINTS[ckpt_idx]
    model = _get_compression_model(ckpt_rel, device, N, M)

    if mask_image is None:
        # No ROI supplied: everything is treated as background.
        mask = np.zeros((image.height, image.width), dtype=np.float32)
    else:
        mask = np.array(mask_image.convert("L")).astype(np.float32) / 255.0
        if mask.shape != (image.height, image.width):
            # Mask resolution differs from the image; nearest-neighbour resize
            # keeps mask values crisp instead of blending edges.
            from PIL import Image as PILImage

            resized = PILImage.fromarray((mask * 255).astype(np.uint8), mode="L")
            resized = resized.resize((image.width, image.height), PILImage.NEAREST)
            mask = np.array(resized).astype(np.float32) / 255.0

    result = vae.compress_image(
        image=image,
        mask=mask,
        model=model,
        sigma=float(sigma),
        device=device,
    )

    bpp = float(result["bpp"])
    ratio = 24.0 / bpp if bpp > 0 else 0.0
    return result["compressed"], bpp, ratio
|
|
|
|
@spaces.GPU
def api_detect(
    image,
    method: str = "yolo",
    classes: str = "",
    confidence: float = 0.25,
    return_overlay: bool = False,
):
    """
    Run object detection on an image.

    Args:
        image: Input image (PIL Image)
        method: Detection method (yolo, yolo_world, grounding_dino, detr, etc.)
        classes: Comma-separated classes (required for yolo_world, grounding_dino)
        confidence: Confidence threshold 0.0-1.0
        return_overlay: If True, returns tuple (image_with_boxes, json). If False, returns json only.

    Returns:
        If return_overlay=False (default): str - JSON string with list of detections
        If return_overlay=True: tuple(Image, str) - Image with boxes and JSON string
    """
    import json

    if image is None:
        return (None, "[]") if return_overlay else "[]"

    ensure_default_checkpoint_dirs()
    device = _default_device()
    detector = _get_detector(method, device)

    det_kwargs = {"conf_threshold": float(confidence)}
    class_list = _split_classes(classes)
    if method in OPEN_VOCAB_DETECTORS:
        # Open-vocabulary detectors cannot run without class prompts; return
        # an empty result instead of erroring.
        if not class_list:
            return (image, "[]") if return_overlay else "[]"
        det_kwargs["classes"] = class_list

    dets = detector(image, **det_kwargs)
    detections_json = json.dumps(
        [
            {
                "label": d.label,
                "score": float(d.score),
                "bbox_xyxy": [float(x) for x in d.bbox_xyxy],
            }
            for d in dets
        ]
    )

    if not return_overlay:
        return detections_json

    from detection.utils import draw_detections

    # Draw on a copy so the caller's image is left untouched.
    return draw_detections(image.copy(), dets, color=(0, 255, 0)), detections_json
|
|
|
|
@spaces.GPU
def api_detect_overlay(
    image,
    method: str = "yolo",
    classes: str = "",
    confidence: float = 0.25,
):
    """
    Run object detection on an image and return image with bounding boxes.

    Thin wrapper over api_detect with return_overlay forced on; use /detect
    for JSON-only results.

    Args:
        image: Input image (PIL Image)
        method: Detection method (yolo, yolo_world, grounding_dino, detr, etc.)
        classes: Comma-separated classes (required for yolo_world, grounding_dino)
        confidence: Confidence threshold 0.0-1.0

    Returns:
        tuple(Image, str) - Image with bounding boxes and JSON string
    """
    return api_detect(
        image,
        method=method,
        classes=classes,
        confidence=confidence,
        return_overlay=True,
    )
|
|
|
|
@spaces.GPU
def api_process(
    image,
    prompt: str = "object",
    segmentation_method: str = "sam3",
    quality: int = 4,
    sigma: float = 0.3,
    run_detection: bool = False,
    detection_method: str = "yolo",
    detection_classes: str = "",
) -> tuple:
    """
    Full processing pipeline: segment → compress → (optionally) detect.

    Args:
        image: Input image (PIL Image)
        prompt: Text prompt or comma-separated classes for segmentation
        segmentation_method: ROI extraction method (sam3, yolo, segformer, etc.)
        quality: Compression quality 1-5
        sigma: Background preservation 0.01-1.0
        run_detection: Whether to run detection on output
        detection_method: Object detector to use
        detection_classes: Classes for open-vocab detectors

    Returns:
        tuple: (compressed_image, mask_image, bpp, compression_ratio, roi_coverage, detections_json)
    """
    if image is None:
        return None, None, 0.0, 0.0, 0.0, "[]"

    # Stage 1: ROI extraction (grayscale mask + coverage fraction).
    mask_image, roi_coverage, _classes_json = api_segment(image, prompt, segmentation_method)

    # Stage 2: ROI-aware compression of the original image.
    compressed, bpp, compression_ratio = api_compress(image, mask_image, quality, sigma)

    # Stage 3 (optional): detection on the compressed output. Open-vocab
    # detectors reuse the segmentation prompt when no classes are given.
    detections_json = "[]"
    if run_detection and compressed is not None:
        detections_json = api_detect(compressed, detection_method, detection_classes or prompt)

    return compressed, mask_image, bpp, compression_ratio, roi_coverage, detections_json
|
|
|
|
@spaces.GPU
def api_process_video(
    video_path: str,
    prompt: str = "object",
    segmentation_method: str = "sam3",
    mode: str = "static",
    quality: int = 4,
    sigma: float = 0.3,
    output_fps: float = 15.0,
    bandwidth_kbps: float = 500.0,
    min_fps: float = 5.0,
    max_fps: float = 30.0,
    aggressiveness: float = 0.5,
    run_detection: bool = False,
    detection_method: str = "yolo",
    mask_file_path: Optional[str] = None,
) -> tuple:
    """
    Process video with ROI-based compression.

    Args:
        video_path: Path to input video file
        prompt: Text prompt or comma-separated classes for segmentation
        segmentation_method: ROI extraction method (sam3, yolo, segformer, etc.)
        mode: "static" (fixed settings) or "dynamic" (bandwidth-adaptive)
        quality: Compression quality 1-5 (static mode)
        sigma: Background preservation 0.01-1.0 (static mode)
        output_fps: Target output framerate (static mode)
        bandwidth_kbps: Target bandwidth in kbps (dynamic mode)
        min_fps: Minimum framerate (dynamic mode)
        max_fps: Maximum framerate (dynamic mode)
        aggressiveness: Bandwidth savings strategy 0.0-1.0 (dynamic mode)
            - 0.0: Use full bandwidth, maintain high FPS always
            - 0.5: Moderate savings (default)
            - 1.0: Maximum savings, aggressive FPS reduction for low motion
        run_detection: Whether to run detection/tracking
        detection_method: Object detector to use
        mask_file_path: Optional path to pre-computed segmentation masks (skips segmentation)

    Returns:
        tuple: (output_video_path, stats_json)
            - output_video_path: Path to compressed video
            - stats_json: JSON string with compression statistics
    """
    import json
    import tempfile

    if video_path is None:
        return None, "{}"

    # Gradio/API callers may pass None or blank strings for optional fields;
    # normalize to the documented defaults before use.
    if segmentation_method is None or not segmentation_method.strip():
        segmentation_method = "sam3"
    if detection_method is None or not detection_method.strip():
        detection_method = "yolo"
    if prompt is None:
        prompt = "object"
    if mode is None or not mode.strip():
        mode = "static"

    ensure_default_checkpoint_dirs()
    device = _default_device()

    target_classes = _split_classes(prompt) if prompt else []

    try:
        from video import VideoProcessor, CompressionSettings, load_video_masks

        # Attempt to reuse masks from a prior /segment_video call; on any
        # failure, fall back to re-segmenting inside the processor.
        saved_masks = None
        if mask_file_path is not None:
            try:
                # Gradio file inputs may arrive as a dict payload.
                actual_path = mask_file_path
                if isinstance(mask_file_path, dict):
                    actual_path = mask_file_path.get("path", mask_file_path.get("name"))
                saved_masks = load_video_masks(actual_path)
                print(f"API: Loaded {len(saved_masks)} cached masks")
            except Exception as e:
                print(f"API: Failed to load masks: {e}, will re-segment")

        processor = VideoProcessor(device=device)
        processor.load_models(
            # Dynamic mode picks its own quality per chunk; 3 is just the
            # initial model loaded.
            quality_level=quality if mode == "static" else 3,
            segmentation_method=segmentation_method,
            detection_method=detection_method if run_detection else "yolo",
            enable_tracking=run_detection,
        )

        settings = CompressionSettings(
            mode=mode,
            quality_level=quality,
            sigma=sigma,
            output_fps=output_fps,
            target_bandwidth_kbps=bandwidth_kbps,
            min_fps=min_fps,
            max_fps=max_fps,
            aggressiveness=aggressiveness,
            segmentation_method=segmentation_method,
            target_classes=target_classes,
            enable_tracking=run_detection,
        )

        all_frames = []
        all_stats = []

        # Offline processing returns per-chunk results (dynamic mode adapts
        # fps/quality per chunk).
        if mode == "static":
            all_chunks = processor.process_static_offline(
                video_path, settings, saved_masks=saved_masks,
            )
        else:
            all_chunks = processor.process_dynamic_offline(
                video_path, settings, saved_masks=saved_masks,
            )

        for chunk in all_chunks:
            all_frames.extend(chunk.frames)
            all_stats.append({
                "chunk_index": chunk.chunk_index,
                "frames": len(chunk.frames),
                "fps": round(chunk.fps, 1),
                "quality": chunk.quality_level,
                "sigma": round(chunk.sigma, 2),
                "avg_bpp": round(chunk.avg_bpp, 3),
            })

        if all_frames:
            from video.video_processor import frames_to_video_bytes

            # Encode at the mean chunk fps; dynamic mode can vary fps per chunk.
            avg_fps = sum(s["fps"] for s in all_stats) / len(all_stats) if all_stats else 15
            video_bytes = frames_to_video_bytes(all_frames, avg_fps, format="mp4")

            # delete=False: the path must outlive this function for Gradio to
            # serve the file.
            with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
                tmp.write(video_bytes)
                output_path = tmp.name

            stats = {
                "total_frames": len(all_frames),
                "avg_fps": round(avg_fps, 1),
                "chunks": all_stats,
            }

            # Drop frame buffers before returning to cut peak memory.
            # NOTE(review): assumes torch.cuda.empty_cache() is a no-op on
            # CPU-only hosts (true when CUDA is uninitialized) — confirm for
            # the deployed torch version.
            del all_frames, all_chunks
            import gc
            import torch
            gc.collect()
            torch.cuda.empty_cache()

            return output_path, json.dumps(stats)

        return None, "{}"

    except Exception as e:
        # API surface never raises: report errors as JSON.
        return None, json.dumps({"error": str(e)})
    finally:
        # Release models/GPU memory whether or not processing succeeded.
        if 'processor' in locals():
            processor.cleanup()
|
|
|
|
@spaces.GPU
def api_segment_video(
    video_path: str,
    prompt: str = "object",
    method: str = "sam3",
    return_overlay: bool = False,
    output_fps: float = 15.0,
) -> tuple:
    """
    Segment a video to find Regions of Interest (ROI).

    Args:
        video_path: Path to input video file
        prompt: Text prompt or comma-separated classes (e.g., "person, car")
        method: Segmentation method (sam3, yolo, segformer, mask2former, maskrcnn)
        return_overlay: If True, returns video with ROI highlighted instead of mask file
        output_fps: Output framerate (max 30 FPS)

    Returns:
        tuple: (result_path, stats_json)
            - result_path: Path to mask file (NPZ) OR video with ROI overlay (if return_overlay=True)
            - stats_json: JSON string with frame count, coverage stats, and classes used
    """
    import json
    import tempfile
    import cv2
    from PIL import Image

    if video_path is None:
        return None, "{}"

    ensure_default_checkpoint_dirs()
    device = _default_device()

    target_classes = _split_classes(prompt) if prompt else ["object"]

    try:
        from video.video_processor import frames_to_video_bytes, MAX_PROCESSING_FPS, MAX_PROCESSING_HEIGHT
        from video.mask_cache import save_video_masks
        from video.chunk_compressor import smooth_masks_sdf
        from vae.visualization import highlight_roi
        from segmentation import create_segmenter

        segmenter = create_segmenter(method, device=device)

        # Probe the source video's native frame rate and dimensions.
        cap = cv2.VideoCapture(video_path)
        original_fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Subsample frames so the effective rate never exceeds the cap.
        effective_fps = min(output_fps, MAX_PROCESSING_FPS)
        frame_step = max(1, int(original_fps / effective_fps))

        # Downscale tall videos to bound per-frame segmentation cost.
        if original_height > MAX_PROCESSING_HEIGHT:
            scale_factor = MAX_PROCESSING_HEIGHT / original_height
            new_width = int(original_width * scale_factor)
            new_height = MAX_PROCESSING_HEIGHT
        else:
            new_width = original_width
            new_height = original_height

        # Decode, subsample and resize all frames up front (as RGB PIL images).
        # NOTE(review): cap.release() is skipped if decoding raises; consider
        # wrapping this loop in try/finally.
        pil_frames = []
        frame_idx = 0

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Keep every frame_step-th frame to hit the effective fps.
            if frame_idx % frame_step != 0:
                frame_idx += 1
                continue

            if new_width != original_width or new_height != original_height:
                frame = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_AREA)

            # OpenCV decodes BGR; segmenters expect RGB.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_frames.append(Image.fromarray(frame_rgb))
            frame_idx += 1

        cap.release()

        if not pil_frames:
            return None, json.dumps({"error": "No frames extracted from video"})

        # Size segmentation batches to available GPU memory.
        from video.gpu_memory import estimate_batch_sizes
        batch_est = estimate_batch_sizes(
            frame_height=new_height,
            frame_width=new_width,
            seg_method=method,
            device=str(device),
            total_frames=len(pil_frames),
        )
        seg_batch = batch_est.seg_batch_size
        print(f"API segment_video: {len(pil_frames)} frames, seg batch={seg_batch} ({batch_est.notes})")

        # Segment all frames; on CUDA OOM, halve the batch size and retry
        # (up to max_retries attempts) before giving up.
        import torch
        max_retries = 7
        all_masks = None
        for attempt in range(max_retries + 1):
            try:
                all_masks = []
                if hasattr(segmenter, 'segment_batch') and getattr(segmenter, 'supports_batch', False):
                    for i in range(0, len(pil_frames), seg_batch):
                        batch = pil_frames[i:i + seg_batch]
                        batch_masks = segmenter.segment_batch(batch, target_classes=target_classes)
                        all_masks.extend([m.astype('float32') for m in batch_masks])
                else:
                    # Segmenter has no batch API: run frame by frame.
                    for pil_frame in pil_frames:
                        mask = segmenter(pil_frame, target_classes=target_classes)
                        all_masks.append(mask.astype('float32'))
                break
            except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
                # Some OOMs surface as plain RuntimeError; match on message.
                if 'out of memory' in str(e).lower() and attempt < max_retries:
                    seg_batch = max(1, seg_batch // 2)

                    # Drop partial results and free GPU memory before retrying.
                    all_masks = None
                    import gc
                    gc.collect()
                    torch.cuda.empty_cache()
                    torch.cuda.synchronize()
                    print(f"API segment_video: OOM, retrying with batch={seg_batch} (attempt {attempt+1}/{max_retries})")
                    continue
                raise

        if all_masks is None:
            return None, json.dumps({"error": "Segmentation failed after OOM retries"})

        processed_frames = len(all_masks)
        # Coverage is measured on the raw (pre-smoothing) masks.
        coverage_sum = sum(float(m.mean()) for m in all_masks)

        # Overlays are also rendered from the raw masks.
        overlay_frames = None
        if return_overlay:
            overlay_frames = []
            for pil_frame, mask in zip(pil_frames, all_masks):
                if mask.sum() > 0:
                    highlighted = highlight_roi(pil_frame, mask, alpha=0.35, color=(0, 255, 0))
                else:
                    highlighted = pil_frame
                overlay_frames.append(highlighted)

        # Temporal smoothing for the cached masks used by /compress_video.
        # NOTE(review): this also runs in the overlay path where the smoothed
        # masks are discarded — wasted work, preserved as-is.
        all_masks = smooth_masks_sdf(all_masks, alpha=0.5, empty_thresh=10, patience=5)

        avg_coverage = coverage_sum / processed_frames if processed_frames > 0 else 0.0

        stats = {
            "total_frames": processed_frames,
            "fps": round(effective_fps, 1),
            "dimensions": [new_width, new_height],
            "avg_roi_coverage": round(avg_coverage, 4),
            "classes_used": target_classes,
        }

        if return_overlay:
            # Encode the highlighted frames; delete=False so Gradio can serve it.
            video_bytes = frames_to_video_bytes(overlay_frames, effective_fps, format="mp4")
            with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
                tmp.write(video_bytes)
                output_path = tmp.name
        else:
            # Persist the smoothed masks (NPZ) for reuse by /compress_video.
            output_path = save_video_masks(all_masks)
            stats["mask_file"] = output_path

        return output_path, json.dumps(stats)

    except Exception as e:
        # API surface never raises: report errors as JSON.
        return None, json.dumps({"error": str(e)})
|
|
|
|
@spaces.GPU
def api_compress_video(
    video_path: str,
    mask_file_path: Optional[str] = None,
    quality: int = 4,
    sigma: float = 0.3,
    output_fps: float = 15.0,
) -> tuple:
    """
    Compress a video with optional ROI preservation.

    Args:
        video_path: Path to input video file
        mask_file_path: Optional path to pre-computed masks (from api_segment_video)
        quality: Quality level 1-5 (1=smallest file, 5=best quality)
        sigma: Background preservation 0.01-1.0 (lower=more compression)
        output_fps: Target output framerate

    Returns:
        tuple: (compressed_video_path, stats_json)
            - compressed_video_path: Path to compressed video
            - stats_json: JSON string with compression statistics (bpp, ratio, etc.)
    """
    import json
    import tempfile

    if video_path is None:
        return None, "{}"

    ensure_default_checkpoint_dirs()
    device = _default_device()

    try:
        from video import VideoProcessor, CompressionSettings, load_video_masks
        from video.video_processor import frames_to_video_bytes

        # Try to reuse masks produced by /segment_video; on failure, proceed
        # without them (the processor handles missing masks).
        saved_masks = None
        if mask_file_path is not None:
            try:
                # Gradio file inputs may arrive as a dict payload.
                actual_path = mask_file_path
                if isinstance(mask_file_path, dict):
                    actual_path = mask_file_path.get("path", mask_file_path.get("name"))
                saved_masks = load_video_masks(actual_path)
                print(f"API: Loaded {len(saved_masks)} cached masks")
            except Exception as e:
                print(f"API: Failed to load masks: {e}, using empty masks")

        processor = VideoProcessor(device=device)
        processor.load_models(
            quality_level=quality,
            segmentation_method="sam3",
            detection_method="yolo",
            enable_tracking=False,
        )

        # Fixed-quality static compression; no target classes because masks
        # are either cached or empty here.
        settings = CompressionSettings(
            mode="static",
            quality_level=quality,
            sigma=sigma,
            output_fps=output_fps,
            segmentation_method="sam3",
            target_classes=[],
            enable_tracking=False,
        )

        all_chunks = processor.process_static_offline(video_path, settings, saved_masks=saved_masks)

        all_frames = []
        all_stats = []
        total_bytes = 0
        total_original_frames = 0

        for chunk in all_chunks:
            all_frames.extend(chunk.frames)
            total_bytes += chunk.estimated_bytes
            total_original_frames += chunk.original_frame_count
            all_stats.append({
                "chunk_index": chunk.chunk_index,
                "frames": len(chunk.frames),
                "fps": round(chunk.fps, 1),
                "quality": chunk.quality_level,
                "sigma": round(chunk.sigma, 2),
                "avg_bpp": round(chunk.avg_bpp, 3),
            })

        if not all_frames:
            return None, json.dumps({"error": "No frames produced"})

        # Encode at the mean chunk fps.
        avg_fps = sum(s["fps"] for s in all_stats) / len(all_stats) if all_stats else output_fps

        # delete=False: the path must outlive this function for Gradio to serve it.
        video_bytes = frames_to_video_bytes(all_frames, avg_fps, format="mp4")
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            tmp.write(video_bytes)
            output_path = tmp.name

        # Ratio of raw 24-bit RGB size (bits) to compressed size (bytes → bits).
        compression_ratio = 0.0
        if processor.video_dimensions and total_bytes > 0:
            w, h = processor.video_dimensions
            compression_ratio = (24 * total_original_frames * w * h) / (8 * total_bytes)

        stats = {
            "total_frames": len(all_frames),
            "original_frames": total_original_frames,
            "avg_fps": round(avg_fps, 1),
            "total_size_kb": round(total_bytes / 1024, 1),
            "avg_bpp": round(sum(s["avg_bpp"] for s in all_stats) / max(1, len(all_stats)), 3),
            "compression_ratio": round(compression_ratio, 2),
            "used_cached_masks": mask_file_path is not None and saved_masks is not None,
            "chunks": all_stats,
        }

        return output_path, json.dumps(stats)

    except Exception as e:
        # API surface never raises: report errors as JSON.
        return None, json.dumps({"error": str(e)})
    finally:
        # Release models/GPU memory whether or not compression succeeded.
        if 'processor' in locals():
            processor.cleanup()
|
|
|
|
@spaces.GPU
def api_detect_video(
    video_path: str,
    method: str = "yolo",
    classes: str = "",
    confidence: float = 0.25,
    return_overlay: bool = False,
    output_fps: float = 15.0,
) -> tuple:
    """
    Run object detection on a video.

    Args:
        video_path: Path to input video file
        method: Detection method (yolo, yolo_world, grounding_dino, detr, etc.)
        classes: Comma-separated classes (required for yolo_world, grounding_dino)
        confidence: Confidence threshold 0.0-1.0
        return_overlay: If True, returns video with detection boxes instead of JSON
        output_fps: Output framerate (max 30 FPS)

    Returns:
        tuple: (result_path, detections_json)
            - result_path: Video with detection boxes (if return_overlay=True), None otherwise
            - detections_json: JSON string with per-frame detections list
    """
    import json
    import tempfile
    import cv2
    from PIL import Image

    if video_path is None:
        return None, "[]"

    ensure_default_checkpoint_dirs()
    device = _default_device()

    try:
        from video.video_processor import frames_to_video_bytes, MAX_PROCESSING_FPS, MAX_PROCESSING_HEIGHT
        from detection import create_detector
        from detection.utils import draw_detections

        detector = create_detector(method, device=device)

        det_kwargs = {"conf_threshold": float(confidence)}
        class_list = _split_classes(classes)
        if method in OPEN_VOCAB_DETECTORS:
            # Open-vocabulary detectors cannot run without class prompts.
            if not class_list:
                return None, json.dumps({"error": f"{method} requires class prompts"})
            det_kwargs["classes"] = class_list

        # Probe the source video's native frame rate and dimensions.
        cap = cv2.VideoCapture(video_path)
        original_fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Subsample frames so the effective rate never exceeds the cap.
        effective_fps = min(output_fps, MAX_PROCESSING_FPS)
        frame_step = max(1, int(original_fps / effective_fps))

        # Downscale tall videos to bound per-frame detection cost.
        if original_height > MAX_PROCESSING_HEIGHT:
            scale_factor = MAX_PROCESSING_HEIGHT / original_height
            new_width = int(original_width * scale_factor)
            new_height = MAX_PROCESSING_HEIGHT
        else:
            new_width = original_width
            new_height = original_height

        all_detections = []
        overlay_frames = [] if return_overlay else None
        frame_idx = 0
        processed_frames = 0
        total_detections = 0

        # Stream frames: detect per frame, optionally drawing boxes.
        # NOTE(review): cap.release() is skipped if detection raises mid-loop;
        # consider try/finally.
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Keep every frame_step-th frame to hit the effective fps.
            if frame_idx % frame_step != 0:
                frame_idx += 1
                continue

            if new_width != original_width or new_height != original_height:
                frame = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_AREA)

            # OpenCV decodes BGR; detectors expect RGB.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_frame = Image.fromarray(frame_rgb)

            dets = detector(pil_frame, **det_kwargs)

            frame_detections = [
                {
                    "label": d.label,
                    "score": float(d.score),
                    "bbox_xyxy": [float(x) for x in d.bbox_xyxy],
                }
                for d in dets
            ]
            # frame_index counts processed (kept) frames, not source frames.
            all_detections.append({
                "frame_index": processed_frames,
                "detections": frame_detections,
            })
            total_detections += len(dets)

            if return_overlay:
                frame_with_dets = draw_detections(pil_frame, dets, color=(0, 255, 0))
                overlay_frames.append(frame_with_dets)

            processed_frames += 1
            frame_idx += 1

        cap.release()

        if processed_frames == 0:
            return None, json.dumps({"error": "No frames extracted from video"})

        result_data = {
            "total_frames": processed_frames,
            "total_detections": total_detections,
            "avg_detections_per_frame": round(total_detections / processed_frames, 2),
            "fps": round(effective_fps, 1),
            "dimensions": [new_width, new_height],
            "frames": all_detections,
        }

        if return_overlay:
            # Encode annotated frames; delete=False so Gradio can serve the path.
            video_bytes = frames_to_video_bytes(overlay_frames, effective_fps, format="mp4")
            with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
                tmp.write(video_bytes)
                output_path = tmp.name
            return output_path, json.dumps(result_data)
        else:
            return None, json.dumps(result_data)

    except Exception as e:
        # API surface never raises: report errors as JSON.
        return None, json.dumps({"error": str(e)})
|
|
|
|
| |
| |
| |
|
|
@dataclass(frozen=True)
class _ExtractResult:
    """Result of mission → class-label extraction."""
    # Labels chosen from the allowed vocabulary (possibly empty).
    classes: List[str]
    # Human-readable status note ("" on clean success; otherwise explains
    # why/how a fallback was used).
    note: str
|
|
|
|
| def _dets_to_dicts(dets) -> List[dict]: |
| out = [] |
| for d in dets or []: |
| out.append( |
| { |
| "label": d.label, |
| "score": float(d.score), |
| "bbox_xyxy": [float(x) for x in d.bbox_xyxy], |
| } |
| ) |
| return out |
|
|
|
|
def _dicts_to_dets(dets_dicts):
    """Rebuild Detection objects from their dict form.

    Malformed entries (missing keys, bad types) are silently skipped so stale
    UI state cannot crash rendering.
    """
    from detection.base import Detection

    detections = []
    for entry in dets_dicts or []:
        try:
            det = Detection(
                label=str(entry["label"]),
                score=float(entry["score"]),
                bbox_xyxy=[float(v) for v in entry["bbox_xyxy"]],
            )
        except Exception:
            continue
        detections.append(det)
    return detections
|
|
|
|
| def _render_input_preview(original_image, mask, highlight_roi: bool, show_detection: bool, dets_dicts): |
| if original_image is None: |
| return None |
| img = original_image.copy() |
| if highlight_roi and mask is not None: |
| img = vae.highlight_roi(img, mask) |
| if show_detection and dets_dicts: |
| from detection.utils import draw_detections |
|
|
| img = draw_detections(img, _dicts_to_dets(dets_dicts)) |
| return img |
|
|
|
|
| def _render_output_preview(compressed_image, mask, highlight_roi: bool, show_detection: bool, dets_dicts): |
| if compressed_image is None: |
| return None |
| img = compressed_image.copy() |
| if highlight_roi and mask is not None: |
| img = vae.highlight_roi(img, mask) |
| if show_detection and dets_dicts: |
| from detection.utils import draw_detections |
|
|
| img = draw_detections(img, _dicts_to_dets(dets_dicts), color=(255, 0, 0)) |
| return img |
|
|
|
|
| def _split_comma_list(text: str) -> List[str]: |
| parts = [p.strip() for p in (text or "").split(",")] |
| return [p for p in parts if p] |
|
|
|
|
| def _heuristic_extract(mission: str, allowed: Sequence[str], max_classes: int = 6) -> List[str]: |
| mission_l = (mission or "").lower() |
| if not mission_l.strip(): |
| return [] |
|
|
| allowed_l = {c.lower(): c for c in allowed} |
|
|
| |
| hits: List[str] = [] |
| for low, orig in allowed_l.items(): |
| if re.search(rf"\b{re.escape(low)}\b", mission_l): |
| hits.append(orig) |
|
|
| |
| hits = sorted(set(hits), key=lambda x: x.lower()) |
| return hits[:max_classes] |
|
|
|
|
def _extract_classes_via_openai(mission: str, allowed: Sequence[str], max_classes: int = 6) -> _ExtractResult:
    """Use OpenAI (if configured) to map mission → allowed class labels.

    If `OPENAI_API_KEY` is missing or the request fails, falls back to a heuristic.
    """
    # Normalize the allowed list: strip, drop empties, dedupe preserving order.
    allowed = list(dict.fromkeys([str(a).strip() for a in allowed if str(a).strip()]))
    if not (mission or "").strip():
        return _ExtractResult(classes=[], note="")

    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        # No API key: degrade gracefully to word-boundary matching.
        classes = _heuristic_extract(mission, allowed, max_classes=max_classes)
        note = "OpenAI not configured (set OPENAI_API_KEY); using heuristic extraction."
        return _ExtractResult(classes=classes, note=note)

    model = os.getenv("OPENAI_MODEL", "gpt-4o-mini")

    allowed_json = json.dumps(allowed, ensure_ascii=False)

    system = (
        "You extract segmentation class labels from a mission description. "
        "Return ONLY JSON and only choose labels from the provided allowed list."
    )
    user = (
        "Mission:\n"
        f"{mission}\n\n"
        "Allowed labels (JSON array):\n"
        f"{allowed_json}\n\n"
        f"Return JSON exactly like: {{\"classes\": [..]}} with at most {max_classes} labels."
    )

    try:
        from openai import OpenAI

        client = OpenAI(api_key=api_key)
        resp = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            temperature=0,  # deterministic extraction
        )
        content = (resp.choices[0].message.content or "").strip()

        # Tolerate prose around the answer: grab the outermost {...} span.
        m = re.search(r"\{.*\}", content, flags=re.DOTALL)
        if not m:
            raise ValueError("OpenAI did not return JSON")
        data = json.loads(m.group(0))
        raw = data.get("classes")
        if not isinstance(raw, list):
            raise ValueError("OpenAI JSON missing 'classes' list")

        # Keep only labels that appear in the allowed list (case-insensitive
        # match, original casing restored).
        allowed_l = {a.lower(): a for a in allowed}
        out: List[str] = []
        for item in raw:
            s = str(item).strip()
            if not s:
                continue
            key = s.lower()
            if key in allowed_l:
                out.append(allowed_l[key])

        # Dedupe preserving order, then cap.
        out = list(dict.fromkeys(out))[:max_classes]
        return _ExtractResult(classes=out, note="")
    except Exception as e:
        # Any failure (network, auth, malformed JSON) degrades to the heuristic.
        classes = _heuristic_extract(mission, allowed, max_classes=max_classes)
        note = f"OpenAI extraction failed; using heuristic. ({type(e).__name__})"
        return _ExtractResult(classes=classes, note=note)
|
|
|
|
@lru_cache(maxsize=16)
def _get_segmenter(method: str, device: str):
    """Create (and memoize) a segmenter for *method* on *device*."""
    from segmentation import create_segmenter

    return create_segmenter(method, device=device)
|
|
|
|
@lru_cache(maxsize=8)
def _get_detector(method: str, device: str):
    """Create (and memoize) an object detector for *method* on *device*."""
    from detection import create_detector

    return create_detector(method, device=device)
|
|
|
|
@lru_cache(maxsize=8)
def _get_compression_model(checkpoint_rel: str, device: str, N: int, M: int):
    """Load (and memoize) a VAE compression checkpoint.

    *checkpoint_rel* is resolved relative to this file's directory; N/M are
    the model's channel hyperparameters.
    """
    ckpt_path = str(Path(__file__).parent / checkpoint_rel)
    return vae.load_checkpoint(ckpt_path, N=int(N), M=int(M), device=device)
|
|
|
|
| def _format_classes_md( |
| segmenter: str, |
| mission: str, |
| extracted_classes: Sequence[str], |
| used_classes: Sequence[str], |
| note: str, |
| sam3_prompts: Optional[Sequence[str]] = None, |
| ) -> str: |
| if segmenter == "sam3": |
| prompts = list(sam3_prompts or []) |
| if not prompts: |
| prompt = (mission or "").strip() or "object" |
| prompts = [prompt] |
| joined = ", ".join([f"`{p}`" for p in prompts]) |
| return f"**Prompts (sam3):** {joined}" |
|
|
| extracted = list(extracted_classes or []) |
| used = list(used_classes or []) |
|
|
| extracted_txt = ", ".join([f"`{c}`" for c in extracted]) if extracted else "(none)" |
| used_txt = ", ".join([f"`{c}`" for c in used]) if used else "(none)" |
| suffix = f"\n\n*{note}*" if note else "" |
| |
| return f"**Extracted classes:** {extracted_txt}" |
|
|
|
|
def _compute_roi_mask(
    image,
    segmenter_method: str,
    mission: str,
    device: str,
):
    """Segment *image* into an ROI mask driven by the mission text.

    For sam3 the mission is split into text prompts; for class-based
    segmenters the mission is mapped onto the segmenter's class vocabulary
    (via OpenAI or a heuristic fallback).

    Returns a 5-tuple:
    (float32 mask, classes/prompts actually used, _ExtractResult,
     markdown summary for the UI, classes reusable by open-vocab detectors).
    """
    segmenter = _get_segmenter(segmenter_method, device)

    extracted = _ExtractResult(classes=[], note="")
    if segmenter_method == "sam3":
        # Prompt-driven path: split the mission into up to 6 phrases on
        # commas/semicolons/newlines/"and".
        raw = (mission or "").strip()
        if not raw:
            target = ["object"]
        else:
            parts = re.split(r",|;|\n|\band\b", raw, flags=re.IGNORECASE)
            parts = [p.strip() for p in parts if p.strip()]
            base_prompts = parts[:6] if parts else [raw]

            # Single-word prompts also get "a ..."/"the ..." variants, which
            # tend to segment more reliably.
            expanded: List[str] = []
            for p in base_prompts:
                expanded.append(p)
                if " " not in p and len(p) >= 3:
                    expanded.append(f"a {p}")
                    expanded.append(f"the {p}")
            # Case-insensitive dedup preserving first-seen order.
            seen = set()
            target = []
            for p in expanded:
                k = p.lower()
                if k in seen:
                    continue
                seen.add(k)
                target.append(p)
            target = target[:8] if target else [raw]
    else:
        # Class-based path: normalize whatever shape the segmenter reports
        # its vocabulary in.
        avail = segmenter.get_available_classes()
        if isinstance(avail, dict):
            allowed = list(avail.keys())
        elif isinstance(avail, (list, tuple, set)):
            allowed = list(avail)
        else:
            allowed = []

        extracted = _extract_classes_via_openai(mission or "", allowed)
        target = extracted.classes or segmenter.get_default_classes()

    mask = segmenter(image, target_classes=target)
    mask = mask.astype(np.float32)

    # sam3 prompts are not classes, so the class fields stay empty there.
    used = target if segmenter_method != "sam3" else []
    extracted_classes = extracted.classes if segmenter_method != "sam3" else []
    classes_md = _format_classes_md(
        segmenter_method,
        mission or "",
        extracted_classes=extracted_classes,
        used_classes=used,
        note=extracted.note,
        sam3_prompts=target if segmenter_method == "sam3" else None,
    )

    # For sam3 the prompts themselves feed open-vocab detectors downstream.
    open_vocab_classes = extracted.classes if segmenter_method != "sam3" else list(target)
    return mask, target, extracted, classes_md, open_vocab_classes
|
|
|
|
| def _on_upload(image): |
| if image is None: |
| return None, None, None, None, None, "", "", None |
| |
| return image.copy(), None, None, None, None, "", "", None |
|
|
|
|
def _compute_detections_dicts(
    image,
    detector_method: str,
    device: str,
    open_vocab_classes,
    segmenter_method: str,
    mission: str,
    custom_det_classes: str = "",
):
    """Run the chosen detector and return the detections as plain dicts.

    Open-vocabulary detectors (yolo_world / grounding_dino) require an
    explicit class list: the custom text box wins, then the classes produced
    during segmentation, then (for sam3) the mission text. With no classes
    available, returns [] without running the detector.
    """
    detector = _get_detector(detector_method, device)

    extra_kwargs = {}
    if detector_method in {"yolo_world", "grounding_dino"}:
        if custom_det_classes and custom_det_classes.strip():
            vocab = _split_comma_list(custom_det_classes)
        else:
            vocab = list(open_vocab_classes or [])
            if segmenter_method == "sam3" and not vocab:
                vocab = _split_comma_list(mission)
        if not vocab:
            # Nothing to prompt an open-vocab detector with.
            return []
        extra_kwargs["classes"] = vocab

    detections = detector(image, conf_threshold=0.25, **extra_kwargs)
    return _dets_to_dicts(detections)
|
|
|
|
@spaces.GPU
def _find_roi(
    original_image,
    segmenter_method: str,
    mission: str,
    input_highlight_roi: bool,
    input_enable_detection: bool,
    input_detector_method: str,
    input_det_classes: str = "",
):
    """Segment the uploaded image into an ROI and render the input preview.

    Returns a 6-tuple matching the Gradio outputs:
    (preview image, ROI mask, open-vocab classes, detection dicts,
     classes markdown, status markdown).
    """
    if original_image is None:
        # Bug fix: this early return previously yielded only 5 values while
        # every other path returns 6, breaking Gradio output unpacking.
        return None, None, None, None, "", "Please upload an image."

    ensure_default_checkpoint_dirs()
    device = _default_device()

    base = original_image.copy()
    try:
        mask, _target, _extracted, classes_md, open_vocab_classes = _compute_roi_mask(
            image=base,
            segmenter_method=segmenter_method,
            mission=mission,
            device=device,
        )

        # Warn when segmentation found (essentially) nothing.
        roi_frac = float(mask.mean()) if mask is not None else 0.0
        if roi_frac <= 1e-6:
            classes_md = classes_md + "\n\n**Warning:** ROI mask is empty. Try a simpler prompt (e.g., `person`, `car`) or switch ROI method."

        dets_dicts = None
        if input_enable_detection:
            dets_dicts = _compute_detections_dicts(
                image=base,
                detector_method=input_detector_method,
                device=device,
                open_vocab_classes=open_vocab_classes,
                segmenter_method=segmenter_method,
                mission=mission,
                custom_det_classes=input_det_classes,
            )

        viz = _render_input_preview(base, mask, input_highlight_roi, input_enable_detection, dets_dicts)
        status = f"**ROI:** computed. (ROI coverage: {roi_frac*100:.2f}%)"
        return viz, mask, open_vocab_classes, dets_dicts, classes_md, status
    except Exception as e:
        # Keep the uploaded image visible and surface the error in the status.
        return base, None, None, None, "", f"Error: {type(e).__name__}: {e}"
|
|
|
|
@spaces.GPU
def _compress(
    original_image,
    mask,
    open_vocab_classes,
    segmenter_method: str,
    mission: str,
    quality_level: int,
    background_preservation: float,
    output_highlight_roi: bool,
    output_enable_detection: bool,
    output_detector_method: str,
    output_det_classes: str = "",
):
    """Compress the image with ROI preservation; optionally detect on the result.

    Computes the ROI mask on demand when the user skipped "Find ROI".
    Returns a 7-tuple matching the Gradio outputs:
    (output preview, compressed image, mask, detection dicts,
     open-vocab classes, classes markdown, status markdown).
    """
    if original_image is None:
        return None, None, None, None, None, "", "Please upload an image."

    ensure_default_checkpoint_dirs()
    device = _default_device()
    base = original_image.copy()

    # Map the 1-based quality level onto the available checkpoints.
    idx = int(np.clip(int(quality_level) - 1, 0, len(CHECKPOINTS) - 1))
    quality_name, lambda_val, ckpt_rel, N, M = CHECKPOINTS[idx]

    try:
        classes_md = ""

        if mask is None:
            # No cached mask from "Find ROI": segment now.
            mask, _target, _extracted, classes_md, open_vocab_classes = _compute_roi_mask(
                image=base,
                segmenter_method=segmenter_method,
                mission=mission,
                device=device,
            )
        elif not classes_md:
            # Mask was cached; rebuild the markdown summary from cached classes.
            if segmenter_method == "sam3":
                classes_md = _format_classes_md(
                    "sam3",
                    mission or "",
                    extracted_classes=[],
                    used_classes=[],
                    note="",
                    sam3_prompts=list(open_vocab_classes or []),
                )
            else:
                classes_md = _format_classes_md(
                    segmenter_method,
                    mission or "",
                    extracted_classes=list(open_vocab_classes or []),
                    used_classes=list(open_vocab_classes or []),
                    note="",
                )

        model = _get_compression_model(ckpt_rel, device, int(N), int(M))

        # σ controls how much background (non-ROI) detail survives compression.
        sigma = float(background_preservation)
        result = vae.compress_image(image=base, mask=mask, model=model, sigma=sigma, device=device)
        compressed = result["compressed"]
        bpp = float(result["bpp"])

        # 24 bpp is the uncompressed RGB baseline.
        compression_ratio = (24.0 / bpp) if bpp > 0 else float("inf")

        roi_frac = float(mask.mean()) if mask is not None else 0.0
        if roi_frac <= 1e-6:
            classes_md = classes_md + "\n\n**Warning:** ROI mask is empty; compression will behave like background-only."

        dets_dicts = None
        det_note = ""
        if output_enable_detection:
            # Detect on the *compressed* image so overlays reflect what was
            # actually transmitted.
            dets_dicts = _compute_detections_dicts(
                image=compressed,
                detector_method=output_detector_method,
                device=device,
                open_vocab_classes=open_vocab_classes,
                segmenter_method=segmenter_method,
                mission=mission,
                custom_det_classes=output_det_classes,
            )
            det_note = f" | detections: {len(dets_dicts)}"

        status = (
            f"**Compression:** {quality_name} (λ={lambda_val}) | "
            f"**Background preservation:** σ={sigma:.2f} | "
            f"**Compression ratio:** {compression_ratio:.2f}×{det_note}"
        )

        output_viz = _render_output_preview(compressed, mask, output_highlight_roi, output_enable_detection, dets_dicts)
        return output_viz, compressed, mask, dets_dicts, open_vocab_classes, classes_md, status
    except Exception as e:
        # Preserve the mask/classes state so the user can retry without redoing ROI.
        return None, None, mask, None, open_vocab_classes, "", f"Error: {type(e).__name__}: {e}"
|
|
|
|
def _refresh_input_view(original_image, mask, dets_dicts, highlight_roi: bool, enable_detection: bool):
    """Re-render the input preview with the current overlay toggles."""
    rendered = _render_input_preview(original_image, mask, highlight_roi, enable_detection, dets_dicts)
    return rendered
|
|
|
|
def _refresh_output_view(compressed_image, mask, dets_dicts, highlight_roi: bool, enable_detection: bool):
    """Re-render the output preview with the current overlay toggles."""
    rendered = _render_output_preview(compressed_image, mask, highlight_roi, enable_detection, dets_dicts)
    return rendered
|
|
|
|
@spaces.GPU
def _on_input_detection_toggle(
    original_image,
    mask,
    open_vocab_classes,
    dets_dicts,
    segmenter_method: str,
    mission: str,
    enable_detection: bool,
    detector_method: str,
    highlight_roi: bool,
    custom_det_classes: str = "",
):
    """Handle toggling the input-side detection checkbox.

    Runs the detector only when detection is enabled and there are no cached
    results; otherwise just re-renders the preview with the current toggles.
    Returns (preview image, detection dicts).
    """
    if original_image is None:
        return None, dets_dicts

    if enable_detection and dets_dicts is None:
        ensure_default_checkpoint_dirs()
        device = _default_device()
        dets_dicts = _compute_detections_dicts(
            image=original_image,
            detector_method=detector_method,
            device=device,
            open_vocab_classes=open_vocab_classes,
            segmenter_method=segmenter_method,
            mission=mission,
            custom_det_classes=custom_det_classes,
        )

    viz = _render_input_preview(original_image, mask, highlight_roi, enable_detection, dets_dicts)
    return viz, dets_dicts
|
|
|
|
@spaces.GPU
def _on_output_detection_toggle(
    compressed_image,
    mask,
    open_vocab_classes,
    dets_dicts,
    segmenter_method: str,
    mission: str,
    enable_detection: bool,
    detector_method: str,
    highlight_roi: bool,
    custom_det_classes: str = "",
):
    """Handle toggling the output-side detection checkbox.

    Runs the detector (on the compressed image) only when detection is enabled
    and no cached results exist; otherwise just re-renders the preview.
    Returns (preview image, detection dicts).
    """
    if compressed_image is None:
        return None, dets_dicts

    if enable_detection and dets_dicts is None:
        ensure_default_checkpoint_dirs()
        device = _default_device()
        dets_dicts = _compute_detections_dicts(
            image=compressed_image,
            detector_method=detector_method,
            device=device,
            open_vocab_classes=open_vocab_classes,
            segmenter_method=segmenter_method,
            mission=mission,
            custom_det_classes=custom_det_classes,
        )

    viz = _render_output_preview(compressed_image, mask, highlight_roi, enable_detection, dets_dicts)
    return viz, dets_dicts
|
|
|
|
@spaces.GPU
def _on_input_detector_change(
    original_image,
    mask,
    open_vocab_classes,
    segmenter_method: str,
    mission: str,
    enable_detection: bool,
    detector_method: str,
    highlight_roi: bool,
    custom_det_classes: str = "",
):
    """Recompute input-side detections when the detector choice changes.

    Passing cached detections as None forces the toggle handler to rerun the
    newly selected detector instead of reusing stale results.
    """
    return _on_input_detection_toggle(
        original_image,
        mask,
        open_vocab_classes,
        None,  # drop the cache so the new detector actually runs
        segmenter_method,
        mission,
        enable_detection,
        detector_method,
        highlight_roi,
        custom_det_classes,
    )
|
|
|
|
@spaces.GPU
def _on_output_detector_change(
    compressed_image,
    mask,
    open_vocab_classes,
    segmenter_method: str,
    mission: str,
    enable_detection: bool,
    detector_method: str,
    highlight_roi: bool,
    custom_det_classes: str = "",
):
    """Recompute output-side detections when the detector choice changes.

    Passing cached detections as None forces the toggle handler to rerun the
    newly selected detector instead of reusing stale results.
    """
    return _on_output_detection_toggle(
        compressed_image,
        mask,
        open_vocab_classes,
        None,  # drop the cache so the new detector actually runs
        segmenter_method,
        mission,
        enable_detection,
        detector_method,
        highlight_roi,
        custom_det_classes,
    )
|
|
|
|
| |
| |
| |
|
|
|
|
def _process_video_streaming(
    video_path,
    mode: str,
    segmenter_method: str,
    mission: str,
    # Static-mode settings
    quality_level: int,
    sigma: float,
    output_fps: float,
    # Dynamic-mode settings
    bandwidth_kbps: float,
    min_fps: float,
    max_fps: float,
    chunk_duration: float,
    # ROI overlay toggles
    show_input_roi: bool,
    show_output_roi: bool,
    # Detection overlay toggles
    show_input_detection: bool,
    show_output_detection: bool,
):
    """Process video with streaming output.

    This is a generator function that yields updates for real-time streaming.
    Returns both input (with optional detections) and output (with optional detections) videos.

    Yields 7-tuples matching the Gradio outputs:
    (input video path, output video path, progress markdown, status markdown,
     stats dict, last detection frame, current tracks).
    """
    import tempfile
    import os

    if video_path is None:
        yield None, None, "No video uploaded", "", {}, None, []
        return

    ensure_default_checkpoint_dirs()
    device = _default_device()

    # The mission text doubles as a comma-separated target-class list here.
    target_classes = _split_classes(mission) if mission else []

    need_detection = show_input_detection or show_output_detection
    need_roi_highlight = show_input_roi or show_output_roi

    try:
        from video import VideoProcessor, CompressionSettings
        from video.video_processor import frames_to_video_bytes
        from detection import SimpleTracker, draw_tracks, create_detector
        from detection.utils import draw_detections
        from detection.base import Detection
        from vae.visualization import highlight_roi
        from segmentation import create_segmenter
        from PIL import Image
        import cv2

        processor = VideoProcessor(device=device)

        # Early yield so the UI shows activity while models load.
        yield None, None, "**Loading models...**", "", {}, None, []

        processor.load_models(
            quality_level=quality_level if mode == "Static" else 3,
            segmentation_method=segmenter_method,
            detection_method="yolo",
            enable_tracking=need_detection,
        )

        settings = CompressionSettings(
            mode="static" if mode == "Static" else "dynamic",
            quality_level=quality_level,
            sigma=sigma,
            output_fps=output_fps,
            target_bandwidth_kbps=bandwidth_kbps,
            chunk_duration_sec=chunk_duration,
            min_fps=min_fps,
            max_fps=max_fps,
            segmentation_method=segmenter_method,
            target_classes=target_classes,
            enable_tracking=False,
        )

        all_input_frames = []
        all_output_frames = []
        all_stats = []
        all_tracks = []
        total_bytes = 0
        total_original_frames = 0
        last_detection_frame = None

        # Extract the raw input frames up front for optional input-side overlays.
        input_frames_extracted = []
        cap = cv2.VideoCapture(video_path)
        input_fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            input_frames_extracted.append(Image.fromarray(frame_rgb))
        cap.release()

        current_progress = "**Processing...**"

        def progress_callback(current, total, message):
            # Captured via closure; the latest value is surfaced in each yield.
            nonlocal current_progress
            pct = int(100 * current / total) if total > 0 else 0
            current_progress = f"**Progress:** {pct}% - {message}"

        if mode == "Static":
            all_chunks = processor.process_static_offline(
                video_path, settings, progress_callback=progress_callback
            )
        else:
            all_chunks = processor.process_dynamic_offline(
                video_path, settings, progress_callback=progress_callback
            )

        # Separate trackers so input/output overlays keep independent track IDs.
        detector = None
        tracker = None
        input_tracker = None
        if need_detection:
            detector = create_detector("yolo", device=device)
            tracker = SimpleTracker(iou_threshold=0.3, max_age=30)
            input_tracker = SimpleTracker(iou_threshold=0.3, max_age=30)

        segmenter = None
        if need_roi_highlight:
            segmenter = create_segmenter(segmenter_method, device=device)

        def apply_roi_highlight(frame_img, seg):
            # Tint the segmented region green; no-op when the mask is empty.
            if seg is None:
                return frame_img

            mask = seg(frame_img, target_classes=target_classes if target_classes else ["object"])
            if mask is not None and mask.sum() > 0:
                return highlight_roi(frame_img, mask, alpha=0.35, color=(0, 255, 0))
            return frame_img

        all_input_frames = []
        all_tracks = []
        last_detection_frame = None

        for c_idx, chunk in enumerate(all_chunks):
            chunk_output_frames = []
            for frame in chunk.frames:
                processed_frame = frame

                if show_output_roi and segmenter is not None:
                    processed_frame = apply_roi_highlight(processed_frame, segmenter)

                if show_output_detection and detector is not None:
                    # Detect on the raw frame but draw onto the (possibly
                    # ROI-highlighted) processed frame.
                    dets = detector(frame, conf_threshold=0.25)
                    det_dicts = [{"label": d.label, "score": d.score, "bbox_xyxy": list(d.bbox_xyxy)} for d in dets]
                    tracks = tracker.update(det_dicts)

                    frame_with_dets = draw_detections(processed_frame.copy(), dets, color=(0, 255, 0))
                    frame_with_dets = draw_tracks(frame_with_dets, tracks, show_id=True, show_trail=True)
                    chunk_output_frames.append(frame_with_dets)

                    last_detection_frame = frame_with_dets
                    all_tracks = tracks
                else:
                    chunk_output_frames.append(processed_frame)

            all_output_frames.extend(chunk_output_frames)
            total_bytes += chunk.estimated_bytes
            total_original_frames += chunk.original_frame_count

            need_input_processing = (show_input_detection or show_input_roi) and input_frames_extracted
            if need_input_processing:
                # Approximate which slice of the raw input corresponds to this
                # chunk (chunks may be temporally subsampled).
                input_start = int(len(input_frames_extracted) * (chunk.chunk_index * chunk.original_frame_count) / max(1, total_original_frames))
                input_end = min(len(input_frames_extracted), input_start + len(chunk.frames))

                for frame in input_frames_extracted[input_start:input_end]:
                    processed_input = frame

                    if show_input_roi and segmenter is not None:
                        processed_input = apply_roi_highlight(processed_input, segmenter)

                    if show_input_detection and detector is not None and input_tracker is not None:
                        dets = detector(frame, conf_threshold=0.25)
                        det_dicts = [{"label": d.label, "score": d.score, "bbox_xyxy": list(d.bbox_xyxy)} for d in dets]
                        tracks = input_tracker.update(det_dicts)

                        frame_with_dets = draw_detections(processed_input.copy(), dets, color=(0, 255, 0))
                        frame_with_dets = draw_tracks(frame_with_dets, tracks, show_id=True, show_trail=True)
                        all_input_frames.append(frame_with_dets)
                    else:
                        all_input_frames.append(processed_input)

            chunk_stats = {
                "chunk_index": chunk.chunk_index,
                "frames": len(chunk.frames),
                "fps": round(chunk.fps, 1),
                "quality": chunk.quality_level,
                "sigma": round(chunk.sigma, 2),
                "avg_bpp": round(chunk.avg_bpp, 3),
                "size_kb": round(chunk.estimated_bytes / 1024, 1),
            }

            if chunk.motion_metrics is not None:
                chunk_stats["motion"] = round(chunk.motion_metrics.motion_magnitude, 3)
                chunk_stats["complexity"] = round(chunk.motion_metrics.complexity, 3)

            all_stats.append(chunk_stats)

            # Stream an intermediate preview after every processed chunk.
            if all_output_frames:
                try:
                    avg_fps = sum(s["fps"] for s in all_stats) / len(all_stats) if all_stats else 15

                    video_bytes = frames_to_video_bytes(all_output_frames, avg_fps, format="mp4")
                    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
                        tmp.write(video_bytes)
                        output_video_path = tmp.name

                    input_video_path = video_path
                    if (show_input_detection or show_input_roi) and all_input_frames:
                        input_video_bytes = frames_to_video_bytes(all_input_frames, avg_fps, format="mp4")
                        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
                            tmp.write(input_video_bytes)
                            input_video_path = tmp.name

                    summary_stats = {
                        "chunks_processed": len(all_stats),
                        "total_frames": len(all_output_frames),
                        "original_frames": total_original_frames,
                        "frame_reduction": f"{100 * (1 - len(all_output_frames) / max(1, total_original_frames)):.1f}%",
                        "total_size_kb": round(total_bytes / 1024, 1),
                        "avg_bpp": round(sum(s["avg_bpp"] for s in all_stats) / max(1, len(all_stats)), 3),
                        "per_chunk": all_stats,
                    }

                    yield (
                        input_video_path,
                        output_video_path,
                        current_progress,
                        f"**Chunk {chunk.chunk_index + 1}** | FPS: {chunk.fps:.1f} | Quality: {chunk.quality_level} | Size: {chunk.estimated_bytes/1024:.1f}KB",
                        summary_stats,
                        last_detection_frame,
                        all_tracks,
                    )
                except Exception as e:
                    # Keep streaming even if an intermediate encode fails.
                    yield (
                        video_path,
                        None,
                        current_progress,
                        f"Encoding error: {e}",
                        {"chunks_processed": len(all_stats), "per_chunk": all_stats},
                        last_detection_frame,
                        all_tracks,
                    )

        if not all_output_frames:
            yield None, None, "**No frames processed**", "Error: No frames produced", {}, None, []
            return

        avg_fps = sum(s["fps"] for s in all_stats) / len(all_stats) if all_stats else 15

        # Final full-length encode of the compressed output.
        video_bytes = frames_to_video_bytes(all_output_frames, avg_fps, format="mp4")
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            tmp.write(video_bytes)
            final_output_path = tmp.name

        final_input_path = video_path
        if (show_input_detection or show_input_roi) and all_input_frames:
            input_video_bytes = frames_to_video_bytes(all_input_frames, avg_fps, format="mp4")
            with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
                tmp.write(input_video_bytes)
                final_input_path = tmp.name

        # Raw 24-bpp RGB baseline vs. actual transmitted bytes.
        compression_ratio = (24 * total_original_frames * processor.video_dimensions[0] * processor.video_dimensions[1]) / (8 * max(1, total_bytes))

        final_stats = {
            "status": "complete",
            "total_chunks": len(all_stats),
            "total_frames": len(all_output_frames),
            "original_frames": total_original_frames,
            "frame_reduction": f"{100 * (1 - len(all_output_frames) / max(1, total_original_frames)):.1f}%",
            "total_size_kb": round(total_bytes / 1024, 1),
            "compression_ratio": f"{compression_ratio:.1f}x",
            "avg_bpp": round(sum(s["avg_bpp"] for s in all_stats) / max(1, len(all_stats)), 3),
            "video_fps": processor.video_fps,
            "video_dimensions": processor.video_dimensions,
            "per_chunk": all_stats,
        }

        final_status = (
            f"**Complete!** | {len(all_output_frames)} frames | "
            f"{total_bytes/1024:.1f}KB | {compression_ratio:.1f}x compression"
        )

        yield (
            final_input_path,
            final_output_path,
            "**Processing complete!**",
            final_status,
            final_stats,
            last_detection_frame,
            all_tracks,
        )

    except Exception as e:
        import traceback
        error_msg = f"Error: {type(e).__name__}: {e}\n{traceback.format_exc()}"
        yield None, None, f"**Error:** {type(e).__name__}", error_msg, {}, None, []
    finally:
        # The processor may hold GPU memory; release it even on error. Guarded
        # because the exception may fire before 'processor' is assigned.
        if 'processor' in locals():
            processor.cleanup()
|
|
|
|
@spaces.GPU
def _find_video_roi(
    video_path,
    segmenter_method: str,
    mission: str,
    show_detection: bool,
):
    """
    Find ROI in video and create ROI-highlighted video.

    Preprocesses video to 480p max height and caps FPS at 30.
    Saves segmentation masks to file for reuse in compression.

    Returns:
        - roi_video_path: Path to video with ROI highlights
        - status: Status message
        - roi_frames: List of ROI-highlighted PIL frames (for caching)
        - mask_file_path: Path to saved segmentation masks file
    """
    import tempfile
    import cv2
    from PIL import Image

    if video_path is None:
        # Bug fix: this path previously returned only 3 values while the
        # documented contract and all other paths return 4, breaking Gradio
        # output unpacking.
        return None, "**Error:** No video uploaded", None, None

    ensure_default_checkpoint_dirs()
    device = _default_device()

    target_classes = _split_classes(mission) if mission else []

    try:
        from video.video_processor import frames_to_video_bytes, MAX_PROCESSING_FPS, MAX_PROCESSING_HEIGHT
        from vae.visualization import highlight_roi
        from segmentation import create_segmenter
        from detection import create_detector
        from detection.utils import draw_detections
        from video.mask_cache import save_video_masks
        from video.chunk_compressor import smooth_masks_sdf

        segmenter = create_segmenter(segmenter_method, device=device)

        detector = None
        if show_detection:
            detector = create_detector("yolo", device=device)

        # Probe the source video's native geometry and frame rate.
        cap = cv2.VideoCapture(video_path)
        original_fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Subsample frames so the effective rate never exceeds the cap.
        effective_fps = min(original_fps, MAX_PROCESSING_FPS)
        frame_step = max(1, int(original_fps / effective_fps))

        # Downscale tall videos to the processing height, preserving aspect.
        if original_height > MAX_PROCESSING_HEIGHT:
            scale_factor = MAX_PROCESSING_HEIGHT / original_height
            new_width = int(original_width * scale_factor)
            new_height = MAX_PROCESSING_HEIGHT
        else:
            new_width = original_width
            new_height = original_height

        pil_frames = []
        frame_idx = 0

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            if frame_idx % frame_step != 0:
                frame_idx += 1
                continue

            if new_width != original_width or new_height != original_height:
                frame = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_AREA)

            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_frames.append(Image.fromarray(frame_rgb))
            frame_idx += 1

        cap.release()

        if not pil_frames:
            return None, "**Error:** No frames extracted from video", None, None

        # Size segmentation batches to the available GPU memory.
        from video.gpu_memory import estimate_batch_sizes
        batch_est = estimate_batch_sizes(
            frame_height=new_height,
            frame_width=new_width,
            seg_method=segmenter_method,
            device=str(device),
            total_frames=len(pil_frames),
        )
        seg_batch = batch_est.seg_batch_size
        print(f"UI _find_video_roi: {len(pil_frames)} frames, seg batch={seg_batch} ({batch_est.notes})")

        # Segment every frame, halving the batch size on CUDA OOM.
        import torch
        max_retries = 7
        all_masks = None
        prompts = target_classes if target_classes else ["object"]
        for attempt in range(max_retries + 1):
            try:
                all_masks = []
                if hasattr(segmenter, 'segment_batch') and getattr(segmenter, 'supports_batch', False):
                    for i in range(0, len(pil_frames), seg_batch):
                        batch = pil_frames[i:i + seg_batch]
                        batch_masks = segmenter.segment_batch(batch, target_classes=prompts)
                        all_masks.extend([m.astype('float32') for m in batch_masks])
                else:
                    # Segmenter has no batch API: go frame by frame.
                    for pil_frame in pil_frames:
                        mask = segmenter(pil_frame, target_classes=prompts)
                        all_masks.append(mask.astype('float32'))
                break
            except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
                if 'out of memory' in str(e).lower() and attempt < max_retries:
                    seg_batch = max(1, seg_batch // 2)
                    # Drop partial results and reclaim GPU memory before retrying.
                    all_masks = None
                    import gc
                    gc.collect()
                    torch.cuda.empty_cache()
                    torch.cuda.synchronize()
                    print(f"UI _find_video_roi: OOM, retrying with batch={seg_batch} (attempt {attempt+1}/{max_retries})")
                    continue
                raise

        if all_masks is None:
            return None, "**Error:** Segmentation failed after OOM retries", None, None

        # Temporally smooth masks to reduce flicker between frames.
        all_masks = smooth_masks_sdf(all_masks, alpha=0.5, empty_thresh=10, patience=5)

        roi_frames = []
        for pil_frame, mask in zip(pil_frames, all_masks):
            if mask is not None and mask.sum() > 0:
                highlighted = highlight_roi(pil_frame, mask, alpha=0.35, color=(0, 255, 0))
            else:
                highlighted = pil_frame

            if show_detection and detector is not None:
                dets = detector(pil_frame, conf_threshold=0.25)
                highlighted = draw_detections(highlighted, dets, color=(255, 0, 0))

            roi_frames.append(highlighted)

        # Persist masks so "Transmit" can skip re-segmentation (3-5x faster).
        mask_file_path = save_video_masks(all_masks)

        video_bytes = frames_to_video_bytes(roi_frames, effective_fps, format="mp4")
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            tmp.write(video_bytes)
            output_path = tmp.name

        status = f"**ROI Found** | {len(roi_frames)} frames @ {effective_fps:.0f}fps ({new_width}x{new_height}) | Masks saved"
        if show_detection:
            status += " | Detection enabled"

        return output_path, status, roi_frames, mask_file_path

    except Exception as e:
        import traceback
        # Fix: traceback was imported but unused; log the full stack trace
        # server-side while the UI gets only the short summary.
        traceback.print_exc()
        return None, f"**Error:** {type(e).__name__}: {e}", None, None
|
|
|
|
@spaces.GPU
def _compress_video(
    video_path,
    segmenter_method: str,
    mission: str,
    mode: str,
    # Static-mode settings
    quality_level: int,
    sigma: float,
    output_fps: float,
    # Dynamic-mode settings
    bandwidth_kbps: float,
    min_fps: float,
    max_fps: float,
    chunk_duration: float,
    # Cached outputs from the "Find ROI" step
    roi_frames_cache,
    mask_file_path,
):
    """
    Compress video using ROI-based compression.

    If mask_file_path is provided, reuses saved segmentation masks.
    Otherwise, performs segmentation during compression.

    Args:
        video_path: Path to the uploaded video (None if nothing uploaded).
        segmenter_method: Segmentation backend name (e.g. "sam3").
        mission: Free-text mission string; split into target classes.
        mode: "Static" (fixed settings) or "Dynamic" (bandwidth-adaptive).
        quality_level, sigma, output_fps: Static-mode parameters.
        bandwidth_kbps, min_fps, max_fps, chunk_duration: Dynamic-mode parameters.
        roi_frames_cache: ROI frames from the Find-ROI step (currently unused
            here; kept for interface compatibility with the Gradio wiring).
        mask_file_path: Optional path to saved per-frame masks for reuse.

    Returns:
        - compressed_video_path: Path to compressed video
        - status: Status message
        - stats: Compression statistics dict
        - compressed_state: Path for detection step
    """
    import tempfile
    import traceback

    if video_path is None:
        return None, "**Error:** No video uploaded", {}, None

    ensure_default_checkpoint_dirs()
    device = _default_device()

    target_classes = _split_classes(mission) if mission else []

    # Sentinel so the finally-block can clean up only if construction succeeded.
    processor = None
    try:
        from video import VideoProcessor, CompressionSettings, load_video_masks
        from video.video_processor import frames_to_video_bytes

        # Best-effort reuse of cached masks; fall back to re-segmentation.
        saved_masks = None
        if mask_file_path is not None:
            try:
                saved_masks = load_video_masks(mask_file_path)
                print(f"Loaded {len(saved_masks)} cached masks")
            except Exception as e:
                print(f"Failed to load masks: {e}, will re-segment")

        processor = VideoProcessor(device=device)

        # Dynamic mode swaps quality levels internally, so load a mid level (3).
        processor.load_models(
            quality_level=quality_level if mode == "Static" else 3,
            segmentation_method=segmenter_method,
            detection_method="yolo",
            enable_tracking=False,
        )

        settings = CompressionSettings(
            mode="static" if mode == "Static" else "dynamic",
            quality_level=quality_level,
            sigma=sigma,
            output_fps=output_fps,
            target_bandwidth_kbps=bandwidth_kbps,
            chunk_duration_sec=chunk_duration,
            min_fps=min_fps,
            max_fps=max_fps,
            segmentation_method=segmenter_method,
            target_classes=target_classes,
            enable_tracking=False,
        )

        if mode == "Static":
            all_chunks = processor.process_static_offline(video_path, settings, saved_masks=saved_masks)
        else:
            all_chunks = processor.process_dynamic_offline(video_path, settings, saved_masks=saved_masks)

        # Flatten chunk outputs and collect per-chunk statistics.
        all_frames = []
        all_stats = []
        total_bytes = 0
        total_original_frames = 0

        for chunk in all_chunks:
            all_frames.extend(chunk.frames)
            total_bytes += chunk.estimated_bytes
            total_original_frames += chunk.original_frame_count

            chunk_stat = {
                "chunk": chunk.chunk_index,
                "frames": len(chunk.frames),
                "fps": round(chunk.fps, 1),
                "quality": chunk.quality_level,
                "sigma": round(chunk.sigma, 2),
                "bpp": round(chunk.avg_bpp, 3),
                "size_kb": round(chunk.estimated_bytes / 1024, 1),
            }
            if chunk.motion_metrics is not None:
                chunk_stat["motion"] = round(chunk.motion_metrics.motion_magnitude, 3)
            all_stats.append(chunk_stat)

        if not all_frames:
            return None, "**Error:** No frames produced", {}, None

        # Dynamic mode may vary fps per chunk; encode at the average.
        avg_fps = sum(s["fps"] for s in all_stats) / len(all_stats) if all_stats else output_fps

        video_bytes = frames_to_video_bytes(all_frames, avg_fps, format="mp4")
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            tmp.write(video_bytes)
            output_path = tmp.name

        # Ratio vs. raw 24-bit RGB frames (bits_raw / bits_compressed).
        if processor.video_dimensions:
            w, h = processor.video_dimensions
            compression_ratio = (24 * total_original_frames * w * h) / (8 * max(1, total_bytes))
        else:
            compression_ratio = 0

        stats = {
            "mode": mode,
            "total_frames": len(all_frames),
            "original_frames": total_original_frames,
            "frame_reduction": f"{100 * (1 - len(all_frames) / max(1, total_original_frames)):.1f}%",
            "total_size_kb": round(total_bytes / 1024, 1),
            "compression_ratio": f"{compression_ratio:.1f}x",
            "avg_bpp": round(sum(s["bpp"] for s in all_stats) / max(1, len(all_stats)), 3),
            "chunks": all_stats,
        }

        status = (
            f"**Compressed!** | {len(all_frames)} frames | "
            f"{total_bytes/1024:.1f}KB | {compression_ratio:.1f}x compression"
        )
        if saved_masks is not None:
            status += " | Using cached masks"

        return output_path, status, stats, output_path

    except Exception as e:
        # Surface the full traceback in the server log; the UI gets a summary.
        traceback.print_exc()
        return None, f"**Error:** {type(e).__name__}: {e}", {}, None
    finally:
        # Release GPU/model resources whether compression succeeded or not.
        if processor is not None:
            processor.cleanup()
|
|
|
|
@spaces.GPU
def _run_video_detection(
    compressed_video_path,
    det_method: str,
    det_classes: str,
):
    """
    Run object detection on compressed video.

    Args:
        compressed_video_path: Path to the compressed video from the
            transmit step (None if compression was not run).
        det_method: Detector backend name (e.g. "yolo").
        det_classes: Comma-separated class prompts, required only for
            open-vocabulary detectors.

    Returns:
        - detection_video_path: Path to video with detection overlays
        - status: Status message
    """
    import tempfile
    import traceback
    import cv2
    from PIL import Image

    if compressed_video_path is None:
        return None, "**Error:** No compressed video available. Run compression first."

    ensure_default_checkpoint_dirs()
    device = _default_device()

    try:
        from video.video_processor import frames_to_video_bytes
        from detection import create_detector
        from detection.utils import draw_detections

        detector = create_detector(det_method, device=device)

        det_kwargs = {"conf_threshold": 0.25}
        if det_method in OPEN_VOCAB_DETECTORS:
            classes = _split_classes(det_classes)
            if not classes:
                return None, f"**Error:** {det_method} requires class prompts"
            det_kwargs["classes"] = classes

        cap = cv2.VideoCapture(compressed_video_path)
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0

        det_frames = []
        total_detections = 0

        # Ensure the capture is released even if the detector raises mid-loop.
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                # OpenCV decodes BGR; detectors expect RGB PIL images.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                pil_frame = Image.fromarray(frame_rgb)

                dets = detector(pil_frame, **det_kwargs)
                total_detections += len(dets)

                frame_with_dets = draw_detections(pil_frame, dets, color=(0, 255, 0))
                det_frames.append(frame_with_dets)
        finally:
            cap.release()

        if not det_frames:
            return None, "**Error:** No frames extracted"

        video_bytes = frames_to_video_bytes(det_frames, fps, format="mp4")
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            tmp.write(video_bytes)
            output_path = tmp.name

        avg_dets = total_detections / len(det_frames)
        status = f"**Detection complete** | {len(det_frames)} frames | {total_detections} total detections ({avg_dets:.1f}/frame)"

        return output_path, status

    except Exception as e:
        # Surface the full traceback in the server log; the UI gets a summary.
        traceback.print_exc()
        return None, f"**Error:** {type(e).__name__}: {e}"
|
|
|
|
def build_app() -> gr.Blocks:
    """Construct the Gradio UI: Image tab, Video tab, and hidden API endpoints.

    Returns:
        The assembled :class:`gr.Blocks` application (not yet launched).
    """
    # Lazy imports: the model factories pull in heavy packages, so defer them
    # until the app is actually being built.
    from segmentation.factory import get_available_methods
    from detection import get_available_detectors

    # Backend names used to populate the segmentation/detection dropdowns.
    seg_methods = get_available_methods()
    det_methods = get_available_detectors()

    with gr.Blocks(title="Contextual Communication Demo") as demo:
        gr.Markdown(
            "# Contextual Communication Demo\n"
            "Mission-driven contextual compression for bandwidth-limited ISR: preserve decision-critical regions while compactly transmitting the rest."
        )

        with gr.Tabs():
            # ----------------------------------------------------------------
            # Image tab: upload -> Find ROI -> Transmit, with optional ROI
            # highlight and detection overlays on both input and output.
            # ----------------------------------------------------------------
            with gr.TabItem("Image", id=0):
                with gr.Row():
                    # Left column: input image + ROI/mission controls.
                    with gr.Column(scale=1):
                        inp = gr.Image(type="pil", label="Input image", height=480)

                        # Hidden per-session state: original image, ROI mask,
                        # resolved open-vocab classes, cached detections on
                        # input/output, and the compressed result image.
                        original_state = gr.State(None)
                        mask_state = gr.State(None)
                        open_vocab_classes_state = gr.State(None)
                        input_dets_state = gr.State(None)
                        compressed_state = gr.State(None)
                        output_dets_state = gr.State(None)
                        segmenter = gr.Dropdown(
                            choices=seg_methods,
                            value="sam3" if "sam3" in seg_methods else (seg_methods[0] if seg_methods else None),
                            label="Context Extraction Method (ROI)",
                        )
                        mission = gr.Textbox(
                            label="Mission",
                            placeholder="e.g., Collect intelligence on potential air threats (drones, aircraft)",
                            lines=1,
                            autoscroll=False,
                        )
                        # Shows the classes derived from the mission text.
                        chosen_classes = gr.Markdown("")

                        # Input-side view options: ROI highlight + detection
                        # overlay (detector + optional open-vocab class row).
                        with gr.Group():
                            input_roi_highlight = gr.Checkbox(value=True, label="Highlight ROI")
                            input_det_toggle = gr.Checkbox(value=False, label="Enable object detection")
                            input_detector = gr.Dropdown(
                                choices=det_methods,
                                value="yolo" if "yolo" in det_methods else (det_methods[0] if det_methods else None),
                                label="Detector",
                                interactive=True,
                                visible=False,
                            )
                            # Only shown for open-vocabulary detectors.
                            with gr.Row(visible=False) as input_det_classes_row:
                                input_det_classes = gr.Textbox(
                                    placeholder="person, car, drone, aircraft",
                                    scale=4,
                                    show_label=False,
                                    container=False,
                                )
                                input_det_apply_btn = gr.Button("Apply", scale=2, variant="secondary", min_width=60)

                        find_roi_btn = gr.Button("Find Context (ROI)")

                    # Right column: decoded output + compression controls.
                    with gr.Column(scale=1):
                        out_out = gr.Image(type="pil", label="Decoded image (after transmission)", height=480)

                        quality = gr.Slider(
                            minimum=1,
                            maximum=len(CHECKPOINTS),
                            value=4 if len(CHECKPOINTS) >= 4 else 1,
                            step=1,
                            label="Transmission quality (higher = larger payload)",
                        )
                        # Empty markdown used purely as vertical spacing.
                        _blank1 = gr.Markdown("")
                        bg = gr.Slider(
                            minimum=0.01,
                            maximum=1.0,
                            value=0.3,
                            step=0.01,
                            label="Background preservation (higher = keep more context)",
                        )
                        _blank2 = gr.Markdown("")

                        # Output-side view options (mirror of the input group).
                        with gr.Group():
                            output_roi_highlight = gr.Checkbox(value=False, label="Highlight ROI")
                            output_det_toggle = gr.Checkbox(value=False, label="Enable object detection")
                            output_detector = gr.Dropdown(
                                choices=det_methods,
                                value="yolo" if "yolo" in det_methods else (det_methods[0] if det_methods else None),
                                label="Detector",
                                interactive=True,
                                visible=False,
                            )
                            with gr.Row(visible=False) as output_det_classes_row:
                                output_det_classes = gr.Textbox(
                                    placeholder="person, car, drone, aircraft",
                                    scale=4,
                                    show_label=False,
                                    container=False,
                                )
                                output_det_apply_btn = gr.Button("Apply", scale=1, size="sm", variant="secondary", min_width=60)

                        compress_btn = gr.Button("Transmit (Compress)", variant="primary")
                        status = gr.Markdown("")

                # --- Detector visibility toggles ---------------------------------

                def _toggle_detector_ui(enabled: bool, detector: str):
                    # Show detector dropdown when detection is enabled; show the
                    # class row only for open-vocab detectors.
                    is_open_vocab = detector in OPEN_VOCAB_DETECTORS
                    return gr.update(visible=enabled), gr.update(visible=enabled and is_open_vocab)

                def _on_detector_type_change(enabled: bool, detector: str):
                    # Re-evaluate class-row visibility when the detector changes.
                    is_open_vocab = detector in OPEN_VOCAB_DETECTORS
                    return gr.update(visible=enabled and is_open_vocab)

                input_det_toggle.change(
                    _toggle_detector_ui,
                    inputs=[input_det_toggle, input_detector],
                    outputs=[input_detector, input_det_classes_row],
                )
                output_det_toggle.change(
                    _toggle_detector_ui,
                    inputs=[output_det_toggle, output_detector],
                    outputs=[output_detector, output_det_classes_row],
                )
                input_detector.change(
                    _on_detector_type_change,
                    inputs=[input_det_toggle, input_detector],
                    outputs=[input_det_classes_row],
                )
                output_detector.change(
                    _on_detector_type_change,
                    inputs=[output_det_toggle, output_detector],
                    outputs=[output_det_classes_row],
                )

                # --- Primary actions: upload, Find ROI, Transmit ------------------

                # On upload: reset all cached state for the new image.
                inp.upload(
                    _on_upload,
                    inputs=[inp],
                    outputs=[
                        original_state,
                        mask_state,
                        open_vocab_classes_state,
                        input_dets_state,
                        output_dets_state,
                        chosen_classes,
                        status,
                        compressed_state,
                    ],
                )

                find_roi_btn.click(
                    _find_roi,
                    inputs=[original_state, segmenter, mission, input_roi_highlight, input_det_toggle, input_detector, input_det_classes],
                    outputs=[inp, mask_state, open_vocab_classes_state, input_dets_state, chosen_classes, status],
                )

                compress_btn.click(
                    _compress,
                    inputs=[
                        original_state,
                        mask_state,
                        open_vocab_classes_state,
                        segmenter,
                        mission,
                        quality,
                        bg,
                        output_roi_highlight,
                        output_det_toggle,
                        output_detector,
                        output_det_classes,
                    ],
                    outputs=[out_out, compressed_state, mask_state, output_dets_state, open_vocab_classes_state, chosen_classes, status],
                )

                # --- Input-view refresh on overlay option changes -----------------
                input_roi_highlight.change(
                    _refresh_input_view,
                    inputs=[original_state, mask_state, input_dets_state, input_roi_highlight, input_det_toggle],
                    outputs=[inp],
                )
                input_det_toggle.change(
                    _on_input_detection_toggle,
                    inputs=[
                        original_state,
                        mask_state,
                        open_vocab_classes_state,
                        input_dets_state,
                        segmenter,
                        mission,
                        input_det_toggle,
                        input_detector,
                        input_roi_highlight,
                        input_det_classes,
                    ],
                    outputs=[inp, input_dets_state],
                )
                input_detector.change(
                    _on_input_detector_change,
                    inputs=[
                        original_state,
                        mask_state,
                        open_vocab_classes_state,
                        segmenter,
                        mission,
                        input_det_toggle,
                        input_detector,
                        input_roi_highlight,
                        input_det_classes,
                    ],
                    outputs=[inp, input_dets_state],
                )

                # --- Output-view refresh on overlay option changes ----------------
                output_roi_highlight.change(
                    _refresh_output_view,
                    inputs=[compressed_state, mask_state, output_dets_state, output_roi_highlight, output_det_toggle],
                    outputs=[out_out],
                )
                output_det_toggle.change(
                    _on_output_detection_toggle,
                    inputs=[
                        compressed_state,
                        mask_state,
                        open_vocab_classes_state,
                        output_dets_state,
                        segmenter,
                        mission,
                        output_det_toggle,
                        output_detector,
                        output_roi_highlight,
                        output_det_classes,
                    ],
                    outputs=[out_out, output_dets_state],
                )
                output_detector.change(
                    _on_output_detector_change,
                    inputs=[
                        compressed_state,
                        mask_state,
                        open_vocab_classes_state,
                        segmenter,
                        mission,
                        output_det_toggle,
                        output_detector,
                        output_roi_highlight,
                        output_det_classes,
                    ],
                    outputs=[out_out, output_dets_state],
                )

                # Pressing Enter in the class textbox re-runs the detector.
                input_det_classes.submit(
                    _on_input_detector_change,
                    inputs=[
                        original_state,
                        mask_state,
                        open_vocab_classes_state,
                        segmenter,
                        mission,
                        input_det_toggle,
                        input_detector,
                        input_roi_highlight,
                        input_det_classes,
                    ],
                    outputs=[inp, input_dets_state],
                )
                output_det_classes.submit(
                    _on_output_detector_change,
                    inputs=[
                        compressed_state,
                        mask_state,
                        open_vocab_classes_state,
                        segmenter,
                        mission,
                        output_det_toggle,
                        output_detector,
                        output_roi_highlight,
                        output_det_classes,
                    ],
                    outputs=[out_out, output_dets_state],
                )

                # "Apply" buttons mirror the textbox submit behavior.
                input_det_apply_btn.click(
                    _on_input_detector_change,
                    inputs=[
                        original_state,
                        mask_state,
                        open_vocab_classes_state,
                        segmenter,
                        mission,
                        input_det_toggle,
                        input_detector,
                        input_roi_highlight,
                        input_det_classes,
                    ],
                    outputs=[inp, input_dets_state],
                )
                output_det_apply_btn.click(
                    _on_output_detector_change,
                    inputs=[
                        compressed_state,
                        mask_state,
                        open_vocab_classes_state,
                        segmenter,
                        mission,
                        output_det_toggle,
                        output_detector,
                        output_roi_highlight,
                        output_det_classes,
                    ],
                    outputs=[out_out, output_dets_state],
                )

            # ----------------------------------------------------------------
            # Video tab: Find ROI (caches masks) -> Transmit (compress) ->
            # optional detection on the compressed result.
            # ----------------------------------------------------------------
            with gr.TabItem("Video", id=1):
                gr.Markdown(
                    "## Video Compression\n\n"
                    "**Workflow:** Upload video → Select ROI method & mission → Find ROI → Tune compression → Transmit"
                )

                # Hidden state: ROI frames, saved-mask file path (for reuse on
                # Transmit), and the compressed video path (for detection).
                video_roi_frames_state = gr.State(None)
                video_mask_file_state = gr.State(None)
                video_compressed_state = gr.State(None)

                # Three side-by-side players: input, ROI preview, compressed.
                with gr.Row(height=360):
                    video_input = gr.Video(
                        label="Input Video",
                        height=360,
                    )
                    video_roi_output = gr.Video(
                        label="ROI Highlighted Video",
                        height=360,
                        autoplay=True,
                    )
                    video_compressed_output = gr.Video(
                        label="Compressed Output",
                        height=360,
                        autoplay=True,
                    )

                with gr.Row():
                    # Left: ROI extraction controls.
                    with gr.Column(scale=2):
                        video_segmenter = gr.Dropdown(
                            choices=seg_methods,
                            value="sam3" if "sam3" in seg_methods else (seg_methods[0] if seg_methods else None),
                            label="ROI Method",
                        )
                        video_mission = gr.Textbox(
                            label="Mission / Target Classes",
                            placeholder="e.g., drone, person, vehicle, aircraft",
                            lines=1,
                        )
                        video_show_detection = gr.Checkbox(
                            value=False,
                            label="Show Object Detection on ROI Video",
                        )
                        video_find_roi_btn = gr.Button("Find ROI", variant="secondary")
                        video_roi_status = gr.Markdown("")

                    # Right: compression mode + mode-specific settings.
                    with gr.Column(scale=1):
                        video_mode = gr.Radio(
                            choices=["Static", "Dynamic"],
                            value="Static",
                            label="Compression Mode",
                            info="Static: fixed settings | Dynamic: bandwidth-adaptive",
                        )

                        # Shown only in Static mode.
                        with gr.Group(visible=True) as static_settings_group:
                            video_quality = gr.Slider(
                                minimum=1,
                                maximum=5,
                                value=4,
                                step=1,
                                label="Quality (1=smallest, 5=best)",
                            )
                            video_sigma = gr.Slider(
                                minimum=0.01,
                                maximum=1.0,
                                value=0.3,
                                step=0.01,
                                label="Background Preservation (σ)",
                            )
                            video_output_fps = gr.Slider(
                                minimum=1,
                                maximum=60,
                                value=10,
                                step=1,
                                label="Output FPS",
                            )

                        # Shown only in Dynamic (bandwidth-adaptive) mode.
                        with gr.Group(visible=False) as dynamic_settings_group:
                            video_bandwidth = gr.Slider(
                                minimum=50,
                                maximum=5000,
                                value=500,
                                step=50,
                                label="Target Bandwidth (kbps)",
                            )
                            video_min_fps = gr.Slider(
                                minimum=1,
                                maximum=30,
                                value=5,
                                step=1,
                                label="Min FPS",
                            )
                            video_max_fps = gr.Slider(
                                minimum=15,
                                maximum=60,
                                value=24,
                                step=1,
                                label="Max FPS",
                            )
                            video_chunk_duration = gr.Slider(
                                minimum=0.5,
                                maximum=5.0,
                                value=1.0,
                                step=0.5,
                                label="Chunk Duration (sec)",
                            )

                        video_transmit_btn = gr.Button("Transmit (Compress)", variant="primary")
                        video_compress_status = gr.Markdown("")

                with gr.Accordion("Compression Statistics", open=False):
                    video_stats = gr.JSON(label="Stats")

                # Optional detection pass over the compressed output.
                with gr.Accordion("Try Object Detection", open=False):
                    video_det_method = gr.Dropdown(
                        choices=det_methods,
                        value="yolo" if "yolo" in det_methods else (det_methods[0] if det_methods else None),
                        label="Detection Method",
                    )
                    video_det_prompt = gr.Textbox(
                        label="Classes (open-vocab)",
                        placeholder="person, car, drone",
                        lines=1,
                        visible=False,
                    )
                    video_run_detection_btn = gr.Button("Run Detection", variant="secondary")
                    video_det_status = gr.Markdown("")
                    video_detection_output = gr.Video(
                        label="Detection Output",
                        height=200,
                        autoplay=True,
                    )

                # Swap the visible settings group when the mode radio changes.
                def _toggle_video_mode(mode: str):
                    is_static = mode == "Static"
                    return (
                        gr.update(visible=is_static),
                        gr.update(visible=not is_static),
                    )

                video_mode.change(
                    _toggle_video_mode,
                    inputs=[video_mode],
                    outputs=[static_settings_group, dynamic_settings_group],
                )

                # Class-prompt box is only relevant for open-vocab detectors.
                def _toggle_det_prompt(method: str):
                    return gr.update(visible=method in OPEN_VOCAB_DETECTORS)

                video_det_method.change(
                    _toggle_det_prompt,
                    inputs=[video_det_method],
                    outputs=[video_det_prompt],
                )

                video_find_roi_btn.click(
                    _find_video_roi,
                    inputs=[
                        video_input,
                        video_segmenter,
                        video_mission,
                        video_show_detection,
                    ],
                    outputs=[
                        video_roi_output,
                        video_roi_status,
                        video_roi_frames_state,
                        video_mask_file_state,
                    ],
                )

                video_transmit_btn.click(
                    _compress_video,
                    inputs=[
                        video_input,
                        video_segmenter,
                        video_mission,
                        video_mode,
                        # Static-mode settings
                        video_quality,
                        video_sigma,
                        video_output_fps,
                        # Dynamic-mode settings
                        video_bandwidth,
                        video_min_fps,
                        video_max_fps,
                        video_chunk_duration,
                        # Cached ROI state from the Find-ROI step
                        video_roi_frames_state,
                        video_mask_file_state,
                    ],
                    outputs=[
                        video_compressed_output,
                        video_compress_status,
                        video_stats,
                        video_compressed_state,
                    ],
                )

                video_run_detection_btn.click(
                    _run_video_detection,
                    inputs=[
                        video_compressed_state,
                        video_det_method,
                        video_det_prompt,
                    ],
                    outputs=[
                        video_detection_output,
                        video_det_status,
                    ],
                )

        # --------------------------------------------------------------------
        # Hidden components + buttons that only exist to expose named API
        # endpoints (/segment, /compress, /detect, ...). API clients pass
        # values directly, so the components are never shown.
        # --------------------------------------------------------------------
        with gr.Row(visible=False):
            # /segment inputs/outputs
            _api_seg_image = gr.Image(type="pil")
            _api_seg_prompt = gr.Textbox()
            _api_seg_method = gr.Textbox()
            _api_seg_return_overlay = gr.Checkbox()
            _api_seg_mask = gr.Image(type="pil")
            _api_seg_coverage = gr.Number()
            _api_seg_classes = gr.Textbox()

            # /compress inputs/outputs
            _api_comp_image = gr.Image(type="pil")
            _api_comp_mask = gr.Image(type="pil")
            _api_comp_quality = gr.Number()
            _api_comp_sigma = gr.Number()
            _api_comp_output = gr.Image(type="pil")
            _api_comp_bpp = gr.Number()
            _api_comp_ratio = gr.Number()

            # /detect inputs/outputs (JSON result)
            _api_det_image = gr.Image(type="pil")
            _api_det_method = gr.Textbox()
            _api_det_classes = gr.Textbox()
            _api_det_conf = gr.Number()
            _api_det_result = gr.Textbox()

            # /detect_overlay inputs/outputs (image result)
            _api_det_ov_image = gr.Image(type="pil")
            _api_det_ov_method = gr.Textbox()
            _api_det_ov_classes = gr.Textbox()
            _api_det_ov_conf = gr.Number()
            _api_det_ov_image_out = gr.Image(type="pil")
            _api_det_ov_result = gr.Textbox()

            # /process (full pipeline) inputs/outputs
            _api_proc_image = gr.Image(type="pil")
            _api_proc_prompt = gr.Textbox()
            _api_proc_seg_method = gr.Textbox()
            _api_proc_quality = gr.Number()
            _api_proc_sigma = gr.Number()
            _api_proc_run_det = gr.Checkbox()
            _api_proc_det_method = gr.Textbox()
            _api_proc_det_classes = gr.Textbox()
            _api_proc_out_image = gr.Image(type="pil")
            _api_proc_out_mask = gr.Image(type="pil")
            _api_proc_out_bpp = gr.Number()
            _api_proc_out_ratio = gr.Number()
            _api_proc_out_coverage = gr.Number()
            _api_proc_out_dets = gr.Textbox()

            gr.Button(visible=False).click(
                api_segment,
                inputs=[_api_seg_image, _api_seg_prompt, _api_seg_method, _api_seg_return_overlay],
                outputs=[_api_seg_mask, _api_seg_coverage, _api_seg_classes],
                api_name="segment",
            )

            gr.Button(visible=False).click(
                api_compress,
                inputs=[_api_comp_image, _api_comp_mask, _api_comp_quality, _api_comp_sigma],
                outputs=[_api_comp_output, _api_comp_bpp, _api_comp_ratio],
                api_name="compress",
            )

            gr.Button(visible=False).click(
                api_detect,
                inputs=[_api_det_image, _api_det_method, _api_det_classes, _api_det_conf],
                outputs=[_api_det_result],
                api_name="detect",
            )

            gr.Button(visible=False).click(
                api_detect_overlay,
                inputs=[_api_det_ov_image, _api_det_ov_method, _api_det_ov_classes, _api_det_ov_conf],
                outputs=[_api_det_ov_image_out, _api_det_ov_result],
                api_name="detect_overlay",
            )

            gr.Button(visible=False).click(
                api_process,
                inputs=[
                    _api_proc_image, _api_proc_prompt, _api_proc_seg_method,
                    _api_proc_quality, _api_proc_sigma,
                    _api_proc_run_det, _api_proc_det_method, _api_proc_det_classes,
                ],
                outputs=[
                    _api_proc_out_image, _api_proc_out_mask,
                    _api_proc_out_bpp, _api_proc_out_ratio, _api_proc_out_coverage,
                    _api_proc_out_dets,
                ],
                api_name="process",
            )

        # Video API endpoints.
        # NOTE(review): `_api_vid_seg_method` and `_api_vid_det_method` are each
        # defined twice in this row; the later definitions shadow the earlier
        # ones, so the /process_video wiring below binds the components created
        # for /segment_video and /detect_video. API calls supply values in the
        # request so this is likely benign, but verify it is intentional.
        with gr.Row(visible=False):
            # /process_video inputs/outputs
            _api_vid_input = gr.Video()
            _api_vid_prompt = gr.Textbox()
            _api_vid_seg_method = gr.Textbox()
            _api_vid_mode = gr.Textbox()
            _api_vid_quality = gr.Number()
            _api_vid_sigma = gr.Number()
            _api_vid_fps = gr.Number()
            _api_vid_bandwidth = gr.Number()
            _api_vid_min_fps = gr.Number()
            _api_vid_max_fps = gr.Number()
            _api_vid_aggressiveness = gr.Number()
            _api_vid_run_det = gr.Checkbox()
            _api_vid_det_method = gr.Textbox()
            _api_vid_mask_file = gr.Textbox()
            _api_vid_output = gr.Video()
            _api_vid_stats = gr.Textbox()

            # /segment_video inputs/outputs
            _api_vid_seg_input = gr.Video()
            _api_vid_seg_prompt = gr.Textbox()
            _api_vid_seg_method = gr.Textbox()
            _api_vid_seg_return_overlay = gr.Checkbox()
            _api_vid_seg_output_fps = gr.Number()
            _api_vid_seg_output = gr.File()
            _api_vid_seg_output_video = gr.Video()
            _api_vid_seg_stats = gr.Textbox()

            # /compress_video inputs/outputs
            _api_vid_comp_input = gr.Video()
            _api_vid_comp_mask_file = gr.Textbox()
            _api_vid_comp_quality = gr.Number()
            _api_vid_comp_sigma = gr.Number()
            _api_vid_comp_fps = gr.Number()
            _api_vid_comp_output = gr.Video()
            _api_vid_comp_stats = gr.Textbox()

            # /detect_video inputs/outputs
            _api_vid_det_input = gr.Video()
            _api_vid_det_method = gr.Textbox()
            _api_vid_det_classes = gr.Textbox()
            _api_vid_det_conf = gr.Number()
            _api_vid_det_return_overlay = gr.Checkbox()
            _api_vid_det_fps = gr.Number()
            _api_vid_det_output = gr.Video()
            _api_vid_det_result = gr.Textbox()

            gr.Button(visible=False).click(
                api_process_video,
                inputs=[
                    _api_vid_input, _api_vid_prompt, _api_vid_seg_method,
                    _api_vid_mode, _api_vid_quality, _api_vid_sigma, _api_vid_fps,
                    _api_vid_bandwidth, _api_vid_min_fps, _api_vid_max_fps,
                    _api_vid_aggressiveness, _api_vid_run_det, _api_vid_det_method, _api_vid_mask_file,
                ],
                outputs=[_api_vid_output, _api_vid_stats],
                api_name="process_video",
            )

            gr.Button(visible=False).click(
                api_segment_video,
                inputs=[
                    _api_vid_seg_input, _api_vid_seg_prompt, _api_vid_seg_method,
                    _api_vid_seg_return_overlay, _api_vid_seg_output_fps,
                ],
                outputs=[_api_vid_seg_output_video, _api_vid_seg_stats],
                api_name="segment_video",
            )

            gr.Button(visible=False).click(
                api_compress_video,
                inputs=[
                    _api_vid_comp_input, _api_vid_comp_mask_file,
                    _api_vid_comp_quality, _api_vid_comp_sigma, _api_vid_comp_fps,
                ],
                outputs=[_api_vid_comp_output, _api_vid_comp_stats],
                api_name="compress_video",
            )

            gr.Button(visible=False).click(
                api_detect_video,
                inputs=[
                    _api_vid_det_input, _api_vid_det_method, _api_vid_det_classes,
                    _api_vid_det_conf, _api_vid_det_return_overlay, _api_vid_det_fps,
                ],
                outputs=[_api_vid_det_output, _api_vid_det_result],
                api_name="detect_video",
            )

    return demo
|
|
|
|
if __name__ == "__main__":
    # Python 3.13+ can lack a pre-created event loop in the main thread;
    # install a fresh one up front so Gradio's async machinery has one.
    if sys.version_info >= (3, 13):
        try:
            asyncio.set_event_loop(asyncio.new_event_loop())
        except Exception:
            pass

    demo_app = build_app()
    # NOTE(review): `theme` is conventionally a gr.Blocks(...) argument; confirm
    # the installed Gradio version accepts it as a launch() keyword.
    demo_app.launch(theme=gr.themes.Soft(), show_error=True)
|
|