"""Hugging Face Spaces Gradio demo for ROI-VAE compression.
This application provides both a web UI and programmatic API for:
- ROI-based image and video compression using TIC VAE
- Segmentation (SAM3, YOLO, SegFormer, Mask2Former, MaskRCNN)
- Object detection (YOLO, DETR, Grounding DINO, YOLO-World, etc.)
Web UI Tabs:
Image Tab:
- Upload image → Choose ROI method → Set mission/prompt
- Find ROI: extracts regions of interest via segmentation
- Transmit: compresses with ROI preservation (σ controls background)
- Optional: ROI highlight overlay, detection overlays
Video Tab:
- Upload video → Choose ROI method → Set mission/prompt
- Find ROI: segments all frames, saves masks for reuse
- Transmit: compresses with cached masks (3-5x faster on repeat)
- Optional: detection overlays on compressed output
API Endpoints (see API.md for full documentation):
Image:
/segment - Segment image → mask or overlay
/compress - Compress image with optional ROI mask
/detect - Run object detection → JSON or overlay
/process - Full pipeline: segment → compress → detect
Video (Buffered):
/segment_video - Segment video → mask file or overlay video
/compress_video - Compress video with optional cached masks
/detect_video - Run detection on video → JSON or overlay video
/process_video - Full pipeline with static/dynamic modes
"""
from __future__ import annotations
import asyncio
import json
import os
import re
import sys
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import List, Optional, Sequence
import numpy as np
import gradio as gr
# Fix Python 3.13 asyncio event loop cleanup issues
if sys.version_info >= (3, 13):
# Set environment variable to suppress asyncio warnings
os.environ.setdefault('PYTHONWARNINGS', 'ignore::ResourceWarning')
try:
# Suppress asyncio file descriptor warnings during cleanup
import warnings
warnings.filterwarnings("ignore", category=ResourceWarning)
warnings.filterwarnings("ignore", message=".*file descriptor.*")
except Exception:
pass
try:
import spaces # Hugging Face Spaces (ZeroGPU / GPU scheduling)
except Exception: # pragma: no cover
class _SpacesFallback:
@staticmethod
def GPU(*_args, **_kwargs):
def _decorator(fn):
return fn
# Handle both @spaces.GPU and @spaces.GPU() usage patterns
if len(_args) == 1 and callable(_args[0]) and not _kwargs:
# Called as @spaces.GPU (without parentheses)
return _args[0]
# Called as @spaces.GPU() (with parentheses)
return _decorator
spaces = _SpacesFallback()
import vae
from model_cache import ensure_default_checkpoint_dirs
def _default_device() -> str:
try:
import torch
return "cuda" if torch.cuda.is_available() else "cpu"
except Exception:
return "cpu"
CHECKPOINTS = [
# (label, lambda, relative_path, N, M)
("Smallest file", 0.0035, "checkpoints/tic_lambda_0.0035.pth.tar", 128, 192),
("Smaller file", 0.013, "checkpoints/tic_lambda_0.013.pth.tar", 128, 192),
("Balanced", 0.025, "checkpoints/tic_lambda_0.025.pth.tar", 192, 192),
("Higher quality", 0.0483, "checkpoints/tic_lambda_0.0483.pth.tar", 192, 192),
("Best quality", 0.0932, "checkpoints/tic_lambda_0.0932.pth.tar", 192, 192),
]
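# API quality levels 1-5 map to CHECKPOINTS[quality - 1] (clamped to range).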
# Open-vocabulary detectors that require class prompts
OPEN_VOCAB_DETECTORS = {"yolo_world", "grounding_dino"}
# =============================================================================
# Public API Functions (exposed via Gradio API)
# =============================================================================
# These functions are designed for external use via the Gradio Client.
# They use @spaces.GPU for ZeroGPU support on Hugging Face Spaces.
#
# Usage from Python:
# from gradio_client import Client, handle_file
# client = Client("your-space-name")
# result = client.predict(handle_file("img.jpg"), "car", "sam3", False, api_name="/segment")
#
# IMAGE API ENDPOINTS:
# /segment - Segment image → (mask_or_overlay, coverage, classes_json)
# /compress - Compress image → (compressed, bpp, ratio)
# /detect - Object detection → (overlay_or_none, detections_json)
# /process - Full pipeline → (compressed, mask, bpp, ratio, coverage, dets_json)
#
# VIDEO API ENDPOINTS:
# /segment_video - Segment video → (mask_file_or_overlay_video, stats_json)
# /compress_video - Compress video → (compressed_video, stats_json)
# /detect_video - Detection on video → (overlay_video_or_none, detections_json)
# /process_video - Full pipeline → (compressed_video, stats_json)
#
# See API.md for complete documentation with examples.
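#
# Example for a video endpoint (a sketch; the Space name, file path, and
# argument order are assumptions based on api_process_video's signature --
# check the Space's /view_api page for the actual wiring):
#   out_path, stats_json = client.predict(
#       handle_file("clip.mp4"), "person, car", "sam3", "static",
#       api_name="/process_video",
#   )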
# =============================================================================
def _split_classes(classes_str: str) -> List[str]:
"""Split comma-separated class string into list."""
if not classes_str:
return []
parts = [c.strip() for c in classes_str.split(",")]
return [c for c in parts if c]
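# e.g., _split_classes("person, car, ") -> ["person", "car"]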
@spaces.GPU
def api_segment(
image,
prompt: str = "object",
method: str = "sam3",
return_overlay: bool = False,
) -> tuple:
"""
Segment an image to find Regions of Interest (ROI).
Args:
image: Input image (PIL Image)
prompt: Text prompt or comma-separated classes (e.g., "person, car")
method: Segmentation method (sam3, yolo, segformer, mask2former, maskrcnn)
return_overlay: If True, returns image with ROI highlighted instead of mask
Returns:
tuple: (result_image, roi_coverage, classes_used)
- result_image: Grayscale mask OR image with ROI overlay (if return_overlay=True)
- roi_coverage: Fraction of image covered by ROI (0.0-1.0)
- classes_used: JSON string of classes/prompts used
"""
from PIL import Image
import json
if image is None:
return None, 0.0, "[]"
ensure_default_checkpoint_dirs()
device = _default_device()
# Get segmenter
segmenter = _get_segmenter(method, device)
# Parse prompt into classes
if method == "sam3":
targets = _split_classes(prompt) or ["object"]
else:
targets = _split_classes(prompt)
if not targets:
targets = segmenter.get_default_classes()
# Run segmentation
mask = segmenter(image, target_classes=targets)
mask = mask.astype(np.float32)
roi_coverage = float(mask.mean())
# Return overlay or mask based on parameter
if return_overlay:
# Return image with ROI highlighted
result_image = vae.highlight_roi(image, mask, alpha=0.35, color=(0, 255, 0))
else:
# Convert mask to image for Gradio (default behavior)
mask_uint8 = (mask * 255).astype(np.uint8)
result_image = Image.fromarray(mask_uint8, mode="L")
return result_image, roi_coverage, json.dumps(targets)
@spaces.GPU
def api_compress(
image,
mask_image = None,
quality: int = 4,
sigma: float = 0.3,
) -> tuple:
"""
Compress an image with optional ROI preservation.
Args:
image: Input image (PIL Image)
mask_image: Optional ROI mask (grayscale image, white=ROI)
quality: Quality level 1-5 (1=smallest file, 5=best quality)
sigma: Background preservation 0.01-1.0 (lower=more compression)
Returns:
tuple: (compressed_image, bpp, compression_ratio)
- compressed_image: Reconstructed image after compression
- bpp: Bits per pixel
- compression_ratio: Compression ratio (24/bpp)
"""
if image is None:
return None, 0.0, 0.0
ensure_default_checkpoint_dirs()
device = _default_device()
# Get quality config
idx = max(0, min(int(quality) - 1, len(CHECKPOINTS) - 1))
_name, _lambda, ckpt_rel, N, M = CHECKPOINTS[idx]
# Load model
model = _get_compression_model(ckpt_rel, device, N, M)
# Prepare mask
if mask_image is not None:
mask = np.array(mask_image.convert("L")).astype(np.float32) / 255.0
# Resize if needed
if mask.shape != (image.height, image.width):
from PIL import Image as PILImage
mask_resized = PILImage.fromarray((mask * 255).astype(np.uint8), mode="L")
mask_resized = mask_resized.resize((image.width, image.height), PILImage.NEAREST)
mask = np.array(mask_resized).astype(np.float32) / 255.0
else:
mask = np.zeros((image.height, image.width), dtype=np.float32)
# Compress
result = vae.compress_image(
image=image,
mask=mask,
model=model,
sigma=float(sigma),
device=device,
)
bpp = float(result["bpp"])
compression_ratio = 24.0 / bpp if bpp > 0 else 0.0
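    # e.g., bpp = 0.5 -> ratio = 48x relative to raw 24-bit RGB (24 / 0.5).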
return result["compressed"], bpp, compression_ratio
@spaces.GPU
def api_detect(
image,
method: str = "yolo",
classes: str = "",
confidence: float = 0.25,
return_overlay: bool = False,
):
"""
Run object detection on an image.
Args:
image: Input image (PIL Image)
method: Detection method (yolo, yolo_world, grounding_dino, detr, etc.)
classes: Comma-separated classes (required for yolo_world, grounding_dino)
confidence: Confidence threshold 0.0-1.0
return_overlay: If True, returns tuple (image_with_boxes, json). If False, returns json only.
Returns:
If return_overlay=False (default): str - JSON string with list of detections
If return_overlay=True: tuple(Image, str) - Image with boxes and JSON string
"""
import json
if image is None:
return (None, "[]") if return_overlay else "[]"
ensure_default_checkpoint_dirs()
device = _default_device()
detector = _get_detector(method, device)
det_kwargs = {"conf_threshold": float(confidence)}
# Handle open-vocab detectors
class_list = _split_classes(classes)
if method in OPEN_VOCAB_DETECTORS:
if not class_list:
return (image, "[]") if return_overlay else "[]"
det_kwargs["classes"] = class_list
dets = detector(image, **det_kwargs)
results = [
{
"label": d.label,
"score": float(d.score),
"bbox_xyxy": [float(x) for x in d.bbox_xyxy],
}
for d in dets
]
detections_json = json.dumps(results)
if return_overlay:
from detection.utils import draw_detections
image_with_dets = draw_detections(image.copy(), dets, color=(0, 255, 0))
return image_with_dets, detections_json
# For backwards compatibility - original /detect returned just JSON string
return detections_json
@spaces.GPU
def api_detect_overlay(
image,
method: str = "yolo",
classes: str = "",
confidence: float = 0.25,
):
"""
Run object detection on an image and return image with bounding boxes.
This is a separate endpoint for getting detection overlays.
For JSON-only results, use /detect.
Args:
image: Input image (PIL Image)
method: Detection method (yolo, yolo_world, grounding_dino, detr, etc.)
classes: Comma-separated classes (required for yolo_world, grounding_dino)
confidence: Confidence threshold 0.0-1.0
Returns:
tuple(Image, str) - Image with bounding boxes and JSON string
"""
return api_detect(image, method, classes, confidence, return_overlay=True)
@spaces.GPU
def api_process(
image,
prompt: str = "object",
segmentation_method: str = "sam3",
quality: int = 4,
sigma: float = 0.3,
run_detection: bool = False,
detection_method: str = "yolo",
detection_classes: str = "",
) -> tuple:
"""
Full processing pipeline: segment → compress → (optionally) detect.
Args:
image: Input image (PIL Image)
prompt: Text prompt or comma-separated classes for segmentation
segmentation_method: ROI extraction method (sam3, yolo, segformer, etc.)
quality: Compression quality 1-5
sigma: Background preservation 0.01-1.0
run_detection: Whether to run detection on output
detection_method: Object detector to use
detection_classes: Classes for open-vocab detectors
Returns:
tuple: (compressed_image, mask_image, bpp, compression_ratio, roi_coverage, detections_json)
"""
import json
if image is None:
return None, None, 0.0, 0.0, 0.0, "[]"
# Step 1: Segment
mask_image, roi_coverage, classes_json = api_segment(image, prompt, segmentation_method)
# Step 2: Compress
compressed, bpp, compression_ratio = api_compress(image, mask_image, quality, sigma)
# Step 3: Optional detection
detections_json = "[]"
if run_detection and compressed is not None:
det_classes = detection_classes or prompt
detections_json = api_detect(compressed, detection_method, det_classes)
return compressed, mask_image, bpp, compression_ratio, roi_coverage, detections_json
@spaces.GPU
def api_process_video(
video_path: str,
prompt: str = "object",
segmentation_method: str = "sam3",
mode: str = "static",
quality: int = 4,
sigma: float = 0.3,
output_fps: float = 15.0,
bandwidth_kbps: float = 500.0,
min_fps: float = 5.0,
max_fps: float = 30.0,
aggressiveness: float = 0.5,
run_detection: bool = False,
detection_method: str = "yolo",
mask_file_path: Optional[str] = None,
) -> tuple:
"""
Process video with ROI-based compression.
Args:
video_path: Path to input video file
prompt: Text prompt or comma-separated classes for segmentation
segmentation_method: ROI extraction method (sam3, yolo, segformer, etc.)
mode: "static" (fixed settings) or "dynamic" (bandwidth-adaptive)
quality: Compression quality 1-5 (static mode)
sigma: Background preservation 0.01-1.0 (static mode)
output_fps: Target output framerate (static mode)
bandwidth_kbps: Target bandwidth in kbps (dynamic mode)
min_fps: Minimum framerate (dynamic mode)
max_fps: Maximum framerate (dynamic mode)
aggressiveness: Bandwidth savings strategy 0.0-1.0 (dynamic mode)
- 0.0: Use full bandwidth, maintain high FPS always
- 0.5: Moderate savings (default)
- 1.0: Maximum savings, aggressive FPS reduction for low motion
run_detection: Whether to run detection/tracking
detection_method: Object detector to use
mask_file_path: Optional path to pre-computed segmentation masks (skips segmentation)
Returns:
tuple: (output_video_path, stats_json)
- output_video_path: Path to compressed video
- stats_json: JSON string with compression statistics
"""
import json
import tempfile
if video_path is None:
return None, "{}"
# Sanitize parameters - Gradio may pass None which overrides defaults
if segmentation_method is None or not segmentation_method.strip():
segmentation_method = "sam3"
if detection_method is None or not detection_method.strip():
detection_method = "yolo"
if prompt is None:
prompt = "object"
if mode is None or not mode.strip():
mode = "static"
ensure_default_checkpoint_dirs()
device = _default_device()
target_classes = _split_classes(prompt) if prompt else []
try:
from video import VideoProcessor, CompressionSettings, load_video_masks
# Load saved masks if provided
saved_masks = None
if mask_file_path is not None:
try:
# Handle both string paths and Gradio FileData dicts
actual_path = mask_file_path
if isinstance(mask_file_path, dict):
actual_path = mask_file_path.get("path", mask_file_path.get("name"))
saved_masks = load_video_masks(actual_path)
print(f"API: Loaded {len(saved_masks)} cached masks")
except Exception as e:
print(f"API: Failed to load masks: {e}, will re-segment")
processor = VideoProcessor(device=device)
processor.load_models(
quality_level=quality if mode == "static" else 3,
segmentation_method=segmentation_method,
detection_method=detection_method if run_detection else "yolo",
enable_tracking=run_detection,
)
settings = CompressionSettings(
mode=mode,
quality_level=quality,
sigma=sigma,
output_fps=output_fps,
target_bandwidth_kbps=bandwidth_kbps,
min_fps=min_fps,
max_fps=max_fps,
aggressiveness=aggressiveness,
segmentation_method=segmentation_method,
target_classes=target_classes,
enable_tracking=run_detection,
)
# Process video – fully offline two-phase (segment-all → compress-all)
all_frames = []
all_stats = []
if mode == "static":
all_chunks = processor.process_static_offline(
video_path, settings, saved_masks=saved_masks,
)
else: # dynamic – offline batched pipeline
all_chunks = processor.process_dynamic_offline(
video_path, settings, saved_masks=saved_masks,
)
for chunk in all_chunks:
all_frames.extend(chunk.frames)
all_stats.append({
"chunk_index": chunk.chunk_index,
"frames": len(chunk.frames),
"fps": round(chunk.fps, 1),
"quality": chunk.quality_level,
"sigma": round(chunk.sigma, 2),
"avg_bpp": round(chunk.avg_bpp, 3),
})
# Create output video
if all_frames:
from video.video_processor import frames_to_video_bytes
avg_fps = sum(s["fps"] for s in all_stats) / len(all_stats) if all_stats else 15
video_bytes = frames_to_video_bytes(all_frames, avg_fps, format="mp4")
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
tmp.write(video_bytes)
output_path = tmp.name
stats = {
"total_frames": len(all_frames),
"avg_fps": round(avg_fps, 1),
"chunks": all_stats,
}
# Clean up to free memory
del all_frames, all_chunks
import gc
import torch
gc.collect()
torch.cuda.empty_cache()
return output_path, json.dumps(stats)
return None, "{}"
except Exception as e:
return None, json.dumps({"error": str(e)})
finally:
# Always cleanup processor to release models
if 'processor' in locals():
processor.cleanup()
@spaces.GPU
def api_segment_video(
video_path: str,
prompt: str = "object",
method: str = "sam3",
return_overlay: bool = False,
output_fps: float = 15.0,
) -> tuple:
"""
Segment a video to find Regions of Interest (ROI).
Args:
video_path: Path to input video file
prompt: Text prompt or comma-separated classes (e.g., "person, car")
method: Segmentation method (sam3, yolo, segformer, mask2former, maskrcnn)
return_overlay: If True, returns video with ROI highlighted instead of mask file
output_fps: Output framerate (max 30 FPS)
Returns:
tuple: (result_path, stats_json)
- result_path: Path to mask file (NPZ) OR video with ROI overlay (if return_overlay=True)
- stats_json: JSON string with frame count, coverage stats, and classes used
"""
import json
import tempfile
import cv2
from PIL import Image
if video_path is None:
return None, "{}"
ensure_default_checkpoint_dirs()
device = _default_device()
target_classes = _split_classes(prompt) if prompt else ["object"]
try:
from video.video_processor import frames_to_video_bytes, MAX_PROCESSING_FPS, MAX_PROCESSING_HEIGHT
from video.mask_cache import save_video_masks
from video.chunk_compressor import smooth_masks_sdf
from vae.visualization import highlight_roi
from segmentation import create_segmenter
# Load segmenter
segmenter = create_segmenter(method, device=device)
# Extract frames from video with preprocessing
cap = cv2.VideoCapture(video_path)
original_fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
# Calculate effective FPS (capped)
effective_fps = min(output_fps, MAX_PROCESSING_FPS)
frame_step = max(1, int(original_fps / effective_fps))
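        # e.g., a 30 fps source with a 15 fps target keeps every 2nd frame.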
# Calculate resize dimensions
if original_height > MAX_PROCESSING_HEIGHT:
scale_factor = MAX_PROCESSING_HEIGHT / original_height
new_width = int(original_width * scale_factor)
new_height = MAX_PROCESSING_HEIGHT
else:
new_width = original_width
new_height = original_height
# Extract all frames first
pil_frames = []
frame_idx = 0
while True:
ret, frame = cap.read()
if not ret:
break
# Skip frames for FPS limiting
if frame_idx % frame_step != 0:
frame_idx += 1
continue
# Resize if needed
if new_width != original_width or new_height != original_height:
frame = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_AREA)
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
pil_frames.append(Image.fromarray(frame_rgb))
frame_idx += 1
cap.release()
if not pil_frames:
return None, json.dumps({"error": "No frames extracted from video"})
# Auto-detect batch size from GPU memory
from video.gpu_memory import estimate_batch_sizes
batch_est = estimate_batch_sizes(
frame_height=new_height,
frame_width=new_width,
seg_method=method,
device=str(device),
total_frames=len(pil_frames),
)
seg_batch = batch_est.seg_batch_size
print(f"API segment_video: {len(pil_frames)} frames, seg batch={seg_batch} ({batch_est.notes})")
# Batch segment all frames with OOM retry
import torch
max_retries = 7
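        # Each OOM retry halves the batch, so it can shrink by up to 2**7 = 128x
        # (floored at 1) before the error is re-raised.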
all_masks = None
for attempt in range(max_retries + 1):
try:
all_masks = []
if hasattr(segmenter, 'segment_batch') and getattr(segmenter, 'supports_batch', False):
for i in range(0, len(pil_frames), seg_batch):
batch = pil_frames[i:i + seg_batch]
batch_masks = segmenter.segment_batch(batch, target_classes=target_classes)
all_masks.extend([m.astype('float32') for m in batch_masks])
else:
for pil_frame in pil_frames:
mask = segmenter(pil_frame, target_classes=target_classes)
all_masks.append(mask.astype('float32'))
break # success
except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
if 'out of memory' in str(e).lower() and attempt < max_retries:
seg_batch = max(1, seg_batch // 2)
# Aggressive memory cleanup to prevent leaks
all_masks = None
import gc
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
print(f"API segment_video: OOM, retrying with batch={seg_batch} (attempt {attempt+1}/{max_retries})")
continue
raise
if all_masks is None:
return None, json.dumps({"error": "Segmentation failed after OOM retries"})
processed_frames = len(all_masks)
coverage_sum = sum(float(m.mean()) for m in all_masks)
# Create overlay frames if requested
overlay_frames = None
if return_overlay:
overlay_frames = []
for pil_frame, mask in zip(pil_frames, all_masks):
if mask.sum() > 0:
highlighted = highlight_roi(pil_frame, mask, alpha=0.35, color=(0, 255, 0))
else:
highlighted = pil_frame
overlay_frames.append(highlighted)
# Apply SDF temporal smoothing to masks
all_masks = smooth_masks_sdf(all_masks, alpha=0.5, empty_thresh=10, patience=5)
avg_coverage = coverage_sum / processed_frames if processed_frames > 0 else 0.0
stats = {
"total_frames": processed_frames,
"fps": round(effective_fps, 1),
"dimensions": [new_width, new_height],
"avg_roi_coverage": round(avg_coverage, 4),
"classes_used": target_classes,
}
if return_overlay:
# Return video with ROI overlays
video_bytes = frames_to_video_bytes(overlay_frames, effective_fps, format="mp4")
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
tmp.write(video_bytes)
output_path = tmp.name
else:
# Save masks to file for later use
output_path = save_video_masks(all_masks)
stats["mask_file"] = output_path
return output_path, json.dumps(stats)
except Exception as e:
return None, json.dumps({"error": str(e)})
@spaces.GPU
def api_compress_video(
video_path: str,
mask_file_path: Optional[str] = None,
quality: int = 4,
sigma: float = 0.3,
output_fps: float = 15.0,
) -> tuple:
"""
Compress a video with optional ROI preservation.
Args:
video_path: Path to input video file
mask_file_path: Optional path to pre-computed masks (from api_segment_video)
quality: Quality level 1-5 (1=smallest file, 5=best quality)
sigma: Background preservation 0.01-1.0 (lower=more compression)
output_fps: Target output framerate
Returns:
tuple: (compressed_video_path, stats_json)
- compressed_video_path: Path to compressed video
- stats_json: JSON string with compression statistics (bpp, ratio, etc.)
"""
import json
import tempfile
if video_path is None:
return None, "{}"
ensure_default_checkpoint_dirs()
device = _default_device()
try:
from video import VideoProcessor, CompressionSettings, load_video_masks
from video.video_processor import frames_to_video_bytes
# Load saved masks if provided
saved_masks = None
if mask_file_path is not None:
try:
# Handle both string paths and Gradio FileData dicts
actual_path = mask_file_path
if isinstance(mask_file_path, dict):
actual_path = mask_file_path.get("path", mask_file_path.get("name"))
saved_masks = load_video_masks(actual_path)
print(f"API: Loaded {len(saved_masks)} cached masks")
except Exception as e:
print(f"API: Failed to load masks: {e}, using empty masks")
# Create processor
processor = VideoProcessor(device=device)
processor.load_models(
quality_level=quality,
segmentation_method="sam3", # Won't be used if masks provided
detection_method="yolo",
enable_tracking=False,
)
# Create settings - use static mode for API simplicity
settings = CompressionSettings(
mode="static",
quality_level=quality,
sigma=sigma,
output_fps=output_fps,
segmentation_method="sam3",
target_classes=[],
enable_tracking=False,
)
# Process video using offline batch processing
all_chunks = processor.process_static_offline(video_path, settings, saved_masks=saved_masks)
all_frames = []
all_stats = []
total_bytes = 0
total_original_frames = 0
for chunk in all_chunks:
all_frames.extend(chunk.frames)
total_bytes += chunk.estimated_bytes
total_original_frames += chunk.original_frame_count
all_stats.append({
"chunk_index": chunk.chunk_index,
"frames": len(chunk.frames),
"fps": round(chunk.fps, 1),
"quality": chunk.quality_level,
"sigma": round(chunk.sigma, 2),
"avg_bpp": round(chunk.avg_bpp, 3),
})
if not all_frames:
return None, json.dumps({"error": "No frames produced"})
# Calculate final FPS
avg_fps = sum(s["fps"] for s in all_stats) / len(all_stats) if all_stats else output_fps
# Create output video
video_bytes = frames_to_video_bytes(all_frames, avg_fps, format="mp4")
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
tmp.write(video_bytes)
output_path = tmp.name
# Calculate compression stats
compression_ratio = 0.0
if processor.video_dimensions and total_bytes > 0:
w, h = processor.video_dimensions
compression_ratio = (24 * total_original_frames * w * h) / (8 * total_bytes)
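            # Raw 24-bit RGB bits over compressed bits, e.g. a 100-frame
            # 640x480 clip compressed to 500 KiB gives
            # (24*100*640*480) / (8*512000) = 180x.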
stats = {
"total_frames": len(all_frames),
"original_frames": total_original_frames,
"avg_fps": round(avg_fps, 1),
"total_size_kb": round(total_bytes / 1024, 1),
"avg_bpp": round(sum(s["avg_bpp"] for s in all_stats) / max(1, len(all_stats)), 3),
"compression_ratio": round(compression_ratio, 2),
"used_cached_masks": mask_file_path is not None and saved_masks is not None,
"chunks": all_stats,
}
return output_path, json.dumps(stats)
except Exception as e:
return None, json.dumps({"error": str(e)})
finally:
# Always cleanup processor to release models
if 'processor' in locals():
processor.cleanup()
@spaces.GPU
def api_detect_video(
video_path: str,
method: str = "yolo",
classes: str = "",
confidence: float = 0.25,
return_overlay: bool = False,
output_fps: float = 15.0,
) -> tuple:
"""
Run object detection on a video.
Args:
video_path: Path to input video file
method: Detection method (yolo, yolo_world, grounding_dino, detr, etc.)
classes: Comma-separated classes (required for yolo_world, grounding_dino)
confidence: Confidence threshold 0.0-1.0
return_overlay: If True, returns video with detection boxes instead of JSON
output_fps: Output framerate (max 30 FPS)
Returns:
tuple: (result_path, detections_json)
- result_path: Video with detection boxes (if return_overlay=True), None otherwise
- detections_json: JSON string with per-frame detections list
"""
import json
import tempfile
import cv2
from PIL import Image
if video_path is None:
return None, "[]"
ensure_default_checkpoint_dirs()
device = _default_device()
try:
from video.video_processor import frames_to_video_bytes, MAX_PROCESSING_FPS, MAX_PROCESSING_HEIGHT
from detection import create_detector
from detection.utils import draw_detections
# Load detector
detector = create_detector(method, device=device)
# Parse classes for open-vocab detectors
det_kwargs = {"conf_threshold": float(confidence)}
class_list = _split_classes(classes)
if method in OPEN_VOCAB_DETECTORS:
if not class_list:
return None, json.dumps({"error": f"{method} requires class prompts"})
det_kwargs["classes"] = class_list
# Extract frames from video
cap = cv2.VideoCapture(video_path)
original_fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
# Calculate effective FPS (capped)
effective_fps = min(output_fps, MAX_PROCESSING_FPS)
frame_step = max(1, int(original_fps / effective_fps))
# Calculate resize dimensions
if original_height > MAX_PROCESSING_HEIGHT:
scale_factor = MAX_PROCESSING_HEIGHT / original_height
new_width = int(original_width * scale_factor)
new_height = MAX_PROCESSING_HEIGHT
else:
new_width = original_width
new_height = original_height
all_detections = []
overlay_frames = [] if return_overlay else None
frame_idx = 0
processed_frames = 0
total_detections = 0
while True:
ret, frame = cap.read()
if not ret:
break
# Skip frames for FPS limiting
if frame_idx % frame_step != 0:
frame_idx += 1
continue
# Resize if needed
if new_width != original_width or new_height != original_height:
frame = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_AREA)
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
pil_frame = Image.fromarray(frame_rgb)
# Run detection
dets = detector(pil_frame, **det_kwargs)
frame_detections = [
{
"label": d.label,
"score": float(d.score),
"bbox_xyxy": [float(x) for x in d.bbox_xyxy],
}
for d in dets
]
all_detections.append({
"frame_index": processed_frames,
"detections": frame_detections,
})
total_detections += len(dets)
# Create overlay frame if requested
if return_overlay:
frame_with_dets = draw_detections(pil_frame, dets, color=(0, 255, 0))
overlay_frames.append(frame_with_dets)
processed_frames += 1
frame_idx += 1
cap.release()
if processed_frames == 0:
return None, json.dumps({"error": "No frames extracted from video"})
result_data = {
"total_frames": processed_frames,
"total_detections": total_detections,
"avg_detections_per_frame": round(total_detections / processed_frames, 2),
"fps": round(effective_fps, 1),
"dimensions": [new_width, new_height],
"frames": all_detections,
}
if return_overlay:
# Return video with detection overlays
video_bytes = frames_to_video_bytes(overlay_frames, effective_fps, format="mp4")
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
tmp.write(video_bytes)
output_path = tmp.name
return output_path, json.dumps(result_data)
else:
return None, json.dumps(result_data)
except Exception as e:
return None, json.dumps({"error": str(e)})
# =============================================================================
# Internal data classes and helpers
# =============================================================================
@dataclass(frozen=True)
class _ExtractResult:
classes: List[str]
note: str
def _dets_to_dicts(dets) -> List[dict]:
out = []
for d in dets or []:
out.append(
{
"label": d.label,
"score": float(d.score),
"bbox_xyxy": [float(x) for x in d.bbox_xyxy],
}
)
return out
def _dicts_to_dets(dets_dicts):
from detection.base import Detection
out = []
for dd in dets_dicts or []:
try:
out.append(
Detection(
label=str(dd["label"]),
score=float(dd["score"]),
bbox_xyxy=[float(x) for x in dd["bbox_xyxy"]],
)
)
except Exception:
continue
return out
def _render_input_preview(original_image, mask, highlight_roi: bool, show_detection: bool, dets_dicts):
if original_image is None:
return None
img = original_image.copy()
if highlight_roi and mask is not None:
img = vae.highlight_roi(img, mask)
if show_detection and dets_dicts:
from detection.utils import draw_detections
img = draw_detections(img, _dicts_to_dets(dets_dicts))
return img
def _render_output_preview(compressed_image, mask, highlight_roi: bool, show_detection: bool, dets_dicts):
if compressed_image is None:
return None
img = compressed_image.copy()
if highlight_roi and mask is not None:
img = vae.highlight_roi(img, mask)
if show_detection and dets_dicts:
from detection.utils import draw_detections
img = draw_detections(img, _dicts_to_dets(dets_dicts), color=(255, 0, 0))
return img
def _split_comma_list(text: str) -> List[str]:
parts = [p.strip() for p in (text or "").split(",")]
return [p for p in parts if p]
def _heuristic_extract(mission: str, allowed: Sequence[str], max_classes: int = 6) -> List[str]:
mission_l = (mission or "").lower()
if not mission_l.strip():
return []
allowed_l = {c.lower(): c for c in allowed}
# Simple keyword containment: pick allowed labels mentioned in mission.
hits: List[str] = []
for low, orig in allowed_l.items():
if re.search(rf"\b{re.escape(low)}\b", mission_l):
hits.append(orig)
# Cap to a small, stable list.
hits = sorted(set(hits), key=lambda x: x.lower())
return hits[:max_classes]
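# e.g., _heuristic_extract("track the person near the red car", ["person", "car", "dog"])
# returns ["car", "person"]: word-boundary hits, deduplicated, sorted, capped at max_classes.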
def _extract_classes_via_openai(mission: str, allowed: Sequence[str], max_classes: int = 6) -> _ExtractResult:
"""Use OpenAI (if configured) to map mission → allowed class labels.
If `OPENAI_API_KEY` is missing or the request fails, falls back to a heuristic.
"""
allowed = list(dict.fromkeys([str(a).strip() for a in allowed if str(a).strip()]))
if not (mission or "").strip():
return _ExtractResult(classes=[], note="")
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
classes = _heuristic_extract(mission, allowed, max_classes=max_classes)
note = "OpenAI not configured (set OPENAI_API_KEY); using heuristic extraction."
return _ExtractResult(classes=classes, note=note)
model = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
# Keep the payload tight; allowed lists can be large.
allowed_json = json.dumps(allowed, ensure_ascii=False)
system = (
"You extract segmentation class labels from a mission description. "
"Return ONLY JSON and only choose labels from the provided allowed list."
)
user = (
"Mission:\n"
f"{mission}\n\n"
"Allowed labels (JSON array):\n"
f"{allowed_json}\n\n"
f"Return JSON exactly like: {{\"classes\": [..]}} with at most {max_classes} labels."
)
try:
from openai import OpenAI
client = OpenAI(api_key=api_key)
resp = client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": system},
{"role": "user", "content": user},
],
temperature=0,
)
content = (resp.choices[0].message.content or "").strip()
# Be tolerant to extra text; extract first JSON object.
m = re.search(r"\{.*\}", content, flags=re.DOTALL)
if not m:
raise ValueError("OpenAI did not return JSON")
data = json.loads(m.group(0))
raw = data.get("classes")
if not isinstance(raw, list):
raise ValueError("OpenAI JSON missing 'classes' list")
allowed_l = {a.lower(): a for a in allowed}
out: List[str] = []
for item in raw:
s = str(item).strip()
if not s:
continue
key = s.lower()
if key in allowed_l:
out.append(allowed_l[key])
out = list(dict.fromkeys(out))[:max_classes]
return _ExtractResult(classes=out, note="")
except Exception as e:
classes = _heuristic_extract(mission, allowed, max_classes=max_classes)
note = f"OpenAI extraction failed; using heuristic. ({type(e).__name__})"
return _ExtractResult(classes=classes, note=note)
@lru_cache(maxsize=16)
def _get_segmenter(method: str, device: str):
from segmentation import create_segmenter
return create_segmenter(method, device=device)
@lru_cache(maxsize=8)
def _get_detector(method: str, device: str):
from detection import create_detector
return create_detector(method, device=device)
@lru_cache(maxsize=8)
def _get_compression_model(checkpoint_rel: str, device: str, N: int, M: int):
ckpt_path = str(Path(__file__).parent / checkpoint_rel)
return vae.load_checkpoint(ckpt_path, N=int(N), M=int(M), device=device)
def _format_classes_md(
segmenter: str,
mission: str,
extracted_classes: Sequence[str],
used_classes: Sequence[str],
note: str,
sam3_prompts: Optional[Sequence[str]] = None,
) -> str:
if segmenter == "sam3":
prompts = list(sam3_prompts or [])
if not prompts:
prompt = (mission or "").strip() or "object"
prompts = [prompt]
joined = ", ".join([f"`{p}`" for p in prompts])
return f"**Prompts (sam3):** {joined}"
extracted = list(extracted_classes or [])
used = list(used_classes or [])
extracted_txt = ", ".join([f"`{c}`" for c in extracted]) if extracted else "(none)"
used_txt = ", ".join([f"`{c}`" for c in used]) if used else "(none)"
suffix = f"\n\n*{note}*" if note else ""
# return f"**GPT extracted:** {extracted_txt}\n\n**Used for ROI:** {used_txt}{suffix}"
return f"**Extracted classes:** {extracted_txt}"
def _compute_roi_mask(
image,
segmenter_method: str,
mission: str,
device: str,
):
segmenter = _get_segmenter(segmenter_method, device)
extracted = _ExtractResult(classes=[], note="")
if segmenter_method == "sam3":
# SAM3 is prompt-based. Use mission text, but split into multiple prompts
# so phrases like "person and car" work better.
raw = (mission or "").strip()
if not raw:
target = ["object"]
else:
parts = re.split(r",|;|\n|\band\b", raw, flags=re.IGNORECASE)
parts = [p.strip() for p in parts if p.strip()]
base_prompts = parts[:6] if parts else [raw]
# Expand simple single-word prompts to help OWL-ViT latch on.
expanded: List[str] = []
for p in base_prompts:
expanded.append(p)
if " " not in p and len(p) >= 3:
expanded.append(f"a {p}")
expanded.append(f"the {p}")
# keep unique, preserve order
seen = set()
target = []
for p in expanded:
k = p.lower()
if k in seen:
continue
seen.add(k)
target.append(p)
target = target[:8] if target else [raw]
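            # e.g., mission "person and car" yields prompts
            # ["person", "a person", "the person", "car", "a car", "the car"].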
else:
avail = segmenter.get_available_classes()
if isinstance(avail, dict):
allowed = list(avail.keys())
elif isinstance(avail, (list, tuple, set)):
allowed = list(avail)
else:
allowed = []
extracted = _extract_classes_via_openai(mission or "", allowed)
target = extracted.classes or segmenter.get_default_classes()
mask = segmenter(image, target_classes=target)
mask = mask.astype(np.float32)
used = target if segmenter_method != "sam3" else []
extracted_classes = extracted.classes if segmenter_method != "sam3" else []
classes_md = _format_classes_md(
segmenter_method,
mission or "",
extracted_classes=extracted_classes,
used_classes=used,
note=extracted.note,
sam3_prompts=target if segmenter_method == "sam3" else None,
)
open_vocab_classes = extracted.classes if segmenter_method != "sam3" else list(target)
return mask, target, extracted, classes_md, open_vocab_classes
def _on_upload(image):
if image is None:
return None, None, None, None, None, "", "", None
# Only set pristine state and clear derived outputs.
return image.copy(), None, None, None, None, "", "", None
def _compute_detections_dicts(
image,
detector_method: str,
device: str,
open_vocab_classes,
segmenter_method: str,
mission: str,
custom_det_classes: str = "",
):
detector = _get_detector(detector_method, device)
det_kwargs = {}
if detector_method in {"yolo_world", "grounding_dino"}:
# Prefer custom detection classes from UI textbox if provided
if custom_det_classes and custom_det_classes.strip():
classes = _split_comma_list(custom_det_classes)
else:
# Fall back to open_vocab_classes from segmentation or mission text
classes = list(open_vocab_classes or [])
if segmenter_method == "sam3" and not classes:
classes = _split_comma_list(mission)
if not classes:
# No classes specified for open-vocab detector - return empty instead of error
return []
det_kwargs["classes"] = classes
dets = detector(image, conf_threshold=0.25, **det_kwargs)
return _dets_to_dicts(dets)
@spaces.GPU
def _find_roi(
original_image,
segmenter_method: str,
mission: str,
input_highlight_roi: bool,
input_enable_detection: bool,
input_detector_method: str,
input_det_classes: str = "",
):
if original_image is None:
return None, None, None, None, "Please upload an image."
ensure_default_checkpoint_dirs()
device = _default_device()
base = original_image.copy()
try:
mask, _target, _extracted, classes_md, open_vocab_classes = _compute_roi_mask(
image=base,
segmenter_method=segmenter_method,
mission=mission,
device=device,
)
roi_frac = float(mask.mean()) if mask is not None else 0.0
if roi_frac <= 1e-6:
classes_md = classes_md + "\n\n**Warning:** ROI mask is empty. Try a simpler prompt (e.g., `person`, `car`) or switch ROI method."
dets_dicts = None
if input_enable_detection:
dets_dicts = _compute_detections_dicts(
image=base,
detector_method=input_detector_method,
device=device,
open_vocab_classes=open_vocab_classes,
segmenter_method=segmenter_method,
mission=mission,
custom_det_classes=input_det_classes,
)
viz = _render_input_preview(base, mask, input_highlight_roi, input_enable_detection, dets_dicts)
status = f"**ROI:** computed. (ROI coverage: {roi_frac*100:.2f}%)"
return viz, mask, open_vocab_classes, dets_dicts, classes_md, status
except Exception as e:
return base, None, None, None, "", f"Error: {type(e).__name__}: {e}"
@spaces.GPU
def _compress(
original_image,
mask,
open_vocab_classes,
segmenter_method: str,
mission: str,
quality_level: int,
background_preservation: float,
output_highlight_roi: bool,
output_enable_detection: bool,
output_detector_method: str,
output_det_classes: str = "",
):
if original_image is None:
return None, None, None, None, None, "", "Please upload an image."
ensure_default_checkpoint_dirs()
device = _default_device()
base = original_image.copy()
idx = int(np.clip(int(quality_level) - 1, 0, len(CHECKPOINTS) - 1))
quality_name, lambda_val, ckpt_rel, N, M = CHECKPOINTS[idx]
try:
classes_md = ""
if mask is None:
mask, _target, _extracted, classes_md, open_vocab_classes = _compute_roi_mask(
image=base,
segmenter_method=segmenter_method,
mission=mission,
device=device,
)
elif not classes_md:
# If we already have a mask, keep class UI meaningful.
if segmenter_method == "sam3":
classes_md = _format_classes_md(
"sam3",
mission or "",
extracted_classes=[],
used_classes=[],
note="",
sam3_prompts=list(open_vocab_classes or []),
)
else:
classes_md = _format_classes_md(
segmenter_method,
mission or "",
extracted_classes=list(open_vocab_classes or []),
used_classes=list(open_vocab_classes or []),
note="",
)
model = _get_compression_model(ckpt_rel, device, int(N), int(M))
sigma = float(background_preservation)
result = vae.compress_image(image=base, mask=mask, model=model, sigma=sigma, device=device)
compressed = result["compressed"]
bpp = float(result["bpp"])
compression_ratio = (24.0 / bpp) if bpp > 0 else float("inf")
roi_frac = float(mask.mean()) if mask is not None else 0.0
if roi_frac <= 1e-6:
classes_md = classes_md + "\n\n**Warning:** ROI mask is empty; compression will behave like background-only."
dets_dicts = None
det_note = ""
if output_enable_detection:
dets_dicts = _compute_detections_dicts(
image=compressed,
detector_method=output_detector_method,
device=device,
open_vocab_classes=open_vocab_classes,
segmenter_method=segmenter_method,
mission=mission,
custom_det_classes=output_det_classes,
)
det_note = f" | detections: {len(dets_dicts)}"
status = (
f"**Compression:** {quality_name} (λ={lambda_val}) | "
f"**Background preservation:** σ={sigma:.2f} | "
f"**Compression ratio:** {compression_ratio:.2f}×{det_note}"
)
output_viz = _render_output_preview(compressed, mask, output_highlight_roi, output_enable_detection, dets_dicts)
return output_viz, compressed, mask, dets_dicts, open_vocab_classes, classes_md, status
except Exception as e:
return None, None, mask, None, open_vocab_classes, "", f"Error: {type(e).__name__}: {e}"
def _refresh_input_view(original_image, mask, dets_dicts, highlight_roi: bool, enable_detection: bool):
return _render_input_preview(original_image, mask, highlight_roi, enable_detection, dets_dicts)
def _refresh_output_view(compressed_image, mask, dets_dicts, highlight_roi: bool, enable_detection: bool):
return _render_output_preview(compressed_image, mask, highlight_roi, enable_detection, dets_dicts)
@spaces.GPU
def _on_input_detection_toggle(
original_image,
mask,
open_vocab_classes,
dets_dicts,
segmenter_method: str,
mission: str,
enable_detection: bool,
detector_method: str,
highlight_roi: bool,
custom_det_classes: str = "",
):
if original_image is None:
return None, dets_dicts
if enable_detection and dets_dicts is None:
ensure_default_checkpoint_dirs()
device = _default_device()
dets_dicts = _compute_detections_dicts(
image=original_image,
detector_method=detector_method,
device=device,
open_vocab_classes=open_vocab_classes,
segmenter_method=segmenter_method,
mission=mission,
custom_det_classes=custom_det_classes,
)
viz = _render_input_preview(original_image, mask, highlight_roi, enable_detection, dets_dicts)
return viz, dets_dicts
@spaces.GPU
def _on_output_detection_toggle(
compressed_image,
mask,
open_vocab_classes,
dets_dicts,
segmenter_method: str,
mission: str,
enable_detection: bool,
detector_method: str,
highlight_roi: bool,
custom_det_classes: str = "",
):
if compressed_image is None:
return None, dets_dicts
if enable_detection and dets_dicts is None:
ensure_default_checkpoint_dirs()
device = _default_device()
dets_dicts = _compute_detections_dicts(
image=compressed_image,
detector_method=detector_method,
device=device,
open_vocab_classes=open_vocab_classes,
segmenter_method=segmenter_method,
mission=mission,
custom_det_classes=custom_det_classes,
)
viz = _render_output_preview(compressed_image, mask, highlight_roi, enable_detection, dets_dicts)
return viz, dets_dicts
@spaces.GPU
def _on_input_detector_change(
original_image,
mask,
open_vocab_classes,
segmenter_method: str,
mission: str,
enable_detection: bool,
detector_method: str,
highlight_roi: bool,
custom_det_classes: str = "",
):
# Changing the detector invalidates cached detections.
dets_dicts = None
return _on_input_detection_toggle(
original_image,
mask,
open_vocab_classes,
dets_dicts,
segmenter_method,
mission,
enable_detection,
detector_method,
highlight_roi,
custom_det_classes,
)
@spaces.GPU
def _on_output_detector_change(
compressed_image,
mask,
open_vocab_classes,
segmenter_method: str,
mission: str,
enable_detection: bool,
detector_method: str,
highlight_roi: bool,
custom_det_classes: str = "",
):
dets_dicts = None
return _on_output_detection_toggle(
compressed_image,
mask,
open_vocab_classes,
dets_dicts,
segmenter_method,
mission,
enable_detection,
detector_method,
highlight_roi,
custom_det_classes,
)
# =============================================================================
# Video Processing Functions
# =============================================================================
def _process_video_streaming(
video_path,
mode: str,
segmenter_method: str,
mission: str,
# Static settings
quality_level: int,
sigma: float,
output_fps: float,
# Dynamic settings
bandwidth_kbps: float,
min_fps: float,
max_fps: float,
chunk_duration: float,
# ROI highlight toggles
show_input_roi: bool,
show_output_roi: bool,
# Detection toggles
show_input_detection: bool,
show_output_detection: bool,
):
"""Process video with streaming output.
This is a generator function that yields updates for real-time streaming.
Returns both input (with optional detections) and output (with optional detections) videos.
"""
import tempfile
import os
if video_path is None:
yield None, None, "No video uploaded", "", {}, None, []
return
ensure_default_checkpoint_dirs()
device = _default_device()
# Parse target classes
target_classes = _split_classes(mission) if mission else []
# Determine if we need detection (for input or output)
need_detection = show_input_detection or show_output_detection
need_roi_highlight = show_input_roi or show_output_roi
try:
from video import VideoProcessor, CompressionSettings
from video.video_processor import frames_to_video_bytes
from detection import SimpleTracker, draw_tracks, create_detector
from detection.utils import draw_detections
from detection.base import Detection
from vae.visualization import highlight_roi
from segmentation import create_segmenter
from PIL import Image
import cv2
# Create processor
processor = VideoProcessor(device=device)
# Initial status update
yield None, None, "**Loading models...**", "", {}, None, []
# Load models - always use YOLO for detection since it's fast
processor.load_models(
quality_level=quality_level if mode == "Static" else 3,
segmentation_method=segmenter_method,
detection_method="yolo",
enable_tracking=need_detection,
)
# Create settings - disable internal tracking, we'll handle it ourselves
settings = CompressionSettings(
mode="static" if mode == "Static" else "dynamic",
quality_level=quality_level,
sigma=sigma,
output_fps=output_fps,
target_bandwidth_kbps=bandwidth_kbps,
chunk_duration_sec=chunk_duration,
min_fps=min_fps,
max_fps=max_fps,
segmentation_method=segmenter_method,
target_classes=target_classes,
enable_tracking=False, # We'll handle detection separately
)
# Collect all frames
all_input_frames = [] # Original input frames
all_output_frames = [] # Compressed output frames
all_stats = []
all_tracks = []
total_bytes = 0
total_original_frames = 0
last_detection_frame = None
# Extract input frames for detection visualization
input_frames_extracted = []
cap = cv2.VideoCapture(video_path)
input_fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
while True:
ret, frame = cap.read()
if not ret:
break
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
input_frames_extracted.append(Image.fromarray(frame_rgb))
cap.release()
# Progress tracking
current_progress = "**Processing...**"
def progress_callback(current, total, message):
nonlocal current_progress
pct = int(100 * current / total) if total > 0 else 0
current_progress = f"**Progress:** {pct}% - {message}"
# Process video using offline batch processing
if mode == "Static":
all_chunks = processor.process_static_offline(
video_path, settings, progress_callback=progress_callback
)
else:
all_chunks = processor.process_dynamic_offline(
video_path, settings, progress_callback=progress_callback
)
# Load detector for visualization if needed
detector = None
tracker = None
input_tracker = None
if need_detection:
detector = create_detector("yolo", device=device)
tracker = SimpleTracker(iou_threshold=0.3, max_age=30)
input_tracker = SimpleTracker(iou_threshold=0.3, max_age=30)
# Load segmenter for ROI highlighting if needed
segmenter = None
if need_roi_highlight:
segmenter = create_segmenter(segmenter_method, device=device)
# Helper to apply ROI highlight to a frame
def apply_roi_highlight(frame_img, seg):
if seg is None:
return frame_img
# Get mask for frame
mask = seg(frame_img, target_classes=target_classes if target_classes else ["object"])
if mask is not None and mask.sum() > 0:
return highlight_roi(frame_img, mask, alpha=0.35, color=(0, 255, 0))
return frame_img
# Process chunks (already complete from offline processing)
for c_idx, chunk in enumerate(all_chunks):
# Process output frames for this chunk
chunk_output_frames = []
for frame in chunk.frames:
processed_frame = frame
# Apply ROI highlight if enabled
if show_output_roi and segmenter is not None:
processed_frame = apply_roi_highlight(processed_frame, segmenter)
# Apply detection if enabled
if show_output_detection and detector is not None:
# Run detection on output frame (use original for detection, draw on processed)
dets = detector(frame, conf_threshold=0.25)
det_dicts = [{"label": d.label, "score": d.score, "bbox_xyxy": list(d.bbox_xyxy)} for d in dets]
tracks = tracker.update(det_dicts)
# Draw detections and tracks on the processed frame
frame_with_dets = draw_detections(processed_frame.copy(), dets, color=(0, 255, 0))
frame_with_dets = draw_tracks(frame_with_dets, tracks, show_id=True, show_trail=True)
chunk_output_frames.append(frame_with_dets)
last_detection_frame = frame_with_dets
all_tracks = tracks
else:
chunk_output_frames.append(processed_frame)
all_output_frames.extend(chunk_output_frames)
total_bytes += chunk.estimated_bytes
total_original_frames += chunk.original_frame_count
# Process corresponding input frames for this chunk (sample from extracted)
need_input_processing = (show_input_detection or show_input_roi) and input_frames_extracted
if need_input_processing:
                # Approximate which input frames correspond to this chunk
                # (proportional mapping; assumes roughly uniform chunk sizes)
input_start = int(len(input_frames_extracted) * (chunk.chunk_index * chunk.original_frame_count) / max(1, total_original_frames))
input_end = min(len(input_frames_extracted), input_start + len(chunk.frames))
for frame in input_frames_extracted[input_start:input_end]:
processed_input = frame
# Apply ROI highlight if enabled
if show_input_roi and segmenter is not None:
processed_input = apply_roi_highlight(processed_input, segmenter)
# Apply detection if enabled
if show_input_detection and detector is not None and input_tracker is not None:
dets = detector(frame, conf_threshold=0.25)
det_dicts = [{"label": d.label, "score": d.score, "bbox_xyxy": list(d.bbox_xyxy)} for d in dets]
tracks = input_tracker.update(det_dicts)
frame_with_dets = draw_detections(processed_input.copy(), dets, color=(0, 255, 0))
frame_with_dets = draw_tracks(frame_with_dets, tracks, show_id=True, show_trail=True)
all_input_frames.append(frame_with_dets)
else:
all_input_frames.append(processed_input)
# Collect stats
chunk_stats = {
"chunk_index": chunk.chunk_index,
"frames": len(chunk.frames),
"fps": round(chunk.fps, 1),
"quality": chunk.quality_level,
"sigma": round(chunk.sigma, 2),
"avg_bpp": round(chunk.avg_bpp, 3),
"size_kb": round(chunk.estimated_bytes / 1024, 1),
}
# Add motion metrics for dynamic mode
if chunk.motion_metrics is not None:
chunk_stats["motion"] = round(chunk.motion_metrics.motion_magnitude, 3)
chunk_stats["complexity"] = round(chunk.motion_metrics.complexity, 3)
all_stats.append(chunk_stats)
# Create intermediate video for streaming update
if all_output_frames:
try:
avg_fps = sum(s["fps"] for s in all_stats) / len(all_stats) if all_stats else 15
# Create output video
video_bytes = frames_to_video_bytes(all_output_frames, avg_fps, format="mp4")
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
tmp.write(video_bytes)
output_video_path = tmp.name
# Create input video with overlays if enabled
input_video_path = video_path # Default to original
if (show_input_detection or show_input_roi) and all_input_frames:
input_video_bytes = frames_to_video_bytes(all_input_frames, avg_fps, format="mp4")
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
tmp.write(input_video_bytes)
input_video_path = tmp.name
# Build summary stats
summary_stats = {
"chunks_processed": len(all_stats),
"total_frames": len(all_output_frames),
"original_frames": total_original_frames,
"frame_reduction": f"{100 * (1 - len(all_output_frames) / max(1, total_original_frames)):.1f}%",
"total_size_kb": round(total_bytes / 1024, 1),
"avg_bpp": round(sum(s["avg_bpp"] for s in all_stats) / max(1, len(all_stats)), 3),
"per_chunk": all_stats,
}
# Yield streaming update
yield (
input_video_path,
output_video_path,
current_progress,
f"**Chunk {chunk.chunk_index + 1}** | FPS: {chunk.fps:.1f} | Quality: {chunk.quality_level} | Size: {chunk.estimated_bytes/1024:.1f}KB",
summary_stats,
last_detection_frame,
all_tracks,
)
except Exception as e:
# Continue processing even if video encoding fails
yield (
video_path,
None,
current_progress,
f"Encoding error: {e}",
{"chunks_processed": len(all_stats), "per_chunk": all_stats},
last_detection_frame,
all_tracks,
)
# Final output
if not all_output_frames:
yield None, None, "**No frames processed**", "Error: No frames produced", {}, None, []
return
avg_fps = sum(s["fps"] for s in all_stats) / len(all_stats) if all_stats else 15
# Create final output video
video_bytes = frames_to_video_bytes(all_output_frames, avg_fps, format="mp4")
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
tmp.write(video_bytes)
final_output_path = tmp.name
# Create final input video with overlays if enabled
final_input_path = video_path
if (show_input_detection or show_input_roi) and all_input_frames:
input_video_bytes = frames_to_video_bytes(all_input_frames, avg_fps, format="mp4")
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
tmp.write(input_video_bytes)
final_input_path = tmp.name
# Final summary
compression_ratio = (24 * total_original_frames * processor.video_dimensions[0] * processor.video_dimensions[1]) / (8 * max(1, total_bytes))
final_stats = {
"status": "complete",
"total_chunks": len(all_stats),
"total_frames": len(all_output_frames),
"original_frames": total_original_frames,
"frame_reduction": f"{100 * (1 - len(all_output_frames) / max(1, total_original_frames)):.1f}%",
"total_size_kb": round(total_bytes / 1024, 1),
"compression_ratio": f"{compression_ratio:.1f}x",
"avg_bpp": round(sum(s["avg_bpp"] for s in all_stats) / max(1, len(all_stats)), 3),
"video_fps": processor.video_fps,
"video_dimensions": processor.video_dimensions,
"per_chunk": all_stats,
}
final_status = (
f"**Complete!** | {len(all_output_frames)} frames | "
f"{total_bytes/1024:.1f}KB | {compression_ratio:.1f}x compression"
)
yield (
final_input_path,
final_output_path,
"**Processing complete!**",
final_status,
final_stats,
last_detection_frame,
all_tracks,
)
except Exception as e:
import traceback
error_msg = f"Error: {type(e).__name__}: {e}\n{traceback.format_exc()}"
yield None, None, f"**Error:** {type(e).__name__}", error_msg, {}, None, []
finally:
# Always cleanup processor to release models
if 'processor' in locals():
processor.cleanup()
@spaces.GPU
def _find_video_roi(
video_path,
segmenter_method: str,
mission: str,
show_detection: bool,
):
"""
Find ROI in video and create ROI-highlighted video.
    Downscales video to at most MAX_PROCESSING_HEIGHT (480p) and caps the
    frame rate at MAX_PROCESSING_FPS (30).
Saves segmentation masks to file for reuse in compression.
Returns:
- roi_video_path: Path to video with ROI highlights
- status: Status message
- roi_frames: List of ROI-highlighted PIL frames (for caching)
- mask_file_path: Path to saved segmentation masks file
"""
import tempfile
import cv2
from PIL import Image
if video_path is None:
return None, "**Error:** No video uploaded", None
ensure_default_checkpoint_dirs()
device = _default_device()
target_classes = _split_classes(mission) if mission else []
try:
from video.video_processor import frames_to_video_bytes, MAX_PROCESSING_FPS, MAX_PROCESSING_HEIGHT
from vae.visualization import highlight_roi
from segmentation import create_segmenter
from detection import create_detector
from detection.utils import draw_detections
from video.mask_cache import save_video_masks
from video.chunk_compressor import smooth_masks_sdf
# Load segmenter
segmenter = create_segmenter(segmenter_method, device=device)
# Load detector if needed
detector = None
if show_detection:
detector = create_detector("yolo", device=device)
# Extract frames from video with preprocessing
cap = cv2.VideoCapture(video_path)
original_fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
# Calculate effective FPS (capped)
effective_fps = min(original_fps, MAX_PROCESSING_FPS)
frame_step = max(1, int(original_fps / effective_fps))
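        # e.g., a 60 fps source with MAX_PROCESSING_FPS = 30 yields
        # frame_step = 2, so every other frame is kept.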
# Calculate resize dimensions
if original_height > MAX_PROCESSING_HEIGHT:
scale_factor = MAX_PROCESSING_HEIGHT / original_height
new_width = int(original_width * scale_factor)
new_height = MAX_PROCESSING_HEIGHT
else:
new_width = original_width
new_height = original_height
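        # e.g., with MAX_PROCESSING_HEIGHT = 480, a 1920x1080 input is scaled
        # by 480/1080 to 853x480 (width truncated by int()).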
# Extract all frames first
pil_frames = []
frame_idx = 0
while True:
ret, frame = cap.read()
if not ret:
break
# Skip frames for FPS limiting
if frame_idx % frame_step != 0:
frame_idx += 1
continue
# Resize if needed
if new_width != original_width or new_height != original_height:
frame = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_AREA)
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
pil_frames.append(Image.fromarray(frame_rgb))
frame_idx += 1
cap.release()
if not pil_frames:
return None, "**Error:** No frames extracted from video", None, None
# Auto-detect batch size from GPU memory
from video.gpu_memory import estimate_batch_sizes
batch_est = estimate_batch_sizes(
frame_height=new_height,
frame_width=new_width,
seg_method=segmenter_method,
device=str(device),
total_frames=len(pil_frames),
)
seg_batch = batch_est.seg_batch_size
print(f"UI _find_video_roi: {len(pil_frames)} frames, seg batch={seg_batch} ({batch_est.notes})")
# Batch segment all frames with OOM retry
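        # On CUDA OOM the batch size is halved and the whole pass restarts, so
        # with max_retries = 7 the batch can shrink as low as 1 before the
        # error is re-raised.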
import torch
max_retries = 7
all_masks = None
prompts = target_classes if target_classes else ["object"]
for attempt in range(max_retries + 1):
try:
all_masks = []
if hasattr(segmenter, 'segment_batch') and getattr(segmenter, 'supports_batch', False):
for i in range(0, len(pil_frames), seg_batch):
batch = pil_frames[i:i + seg_batch]
batch_masks = segmenter.segment_batch(batch, target_classes=prompts)
all_masks.extend([m.astype('float32') for m in batch_masks])
else:
for pil_frame in pil_frames:
mask = segmenter(pil_frame, target_classes=prompts)
all_masks.append(mask.astype('float32'))
break # success
except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
if 'out of memory' in str(e).lower() and attempt < max_retries:
seg_batch = max(1, seg_batch // 2)
# Aggressive memory cleanup to prevent leaks
all_masks = None
import gc
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
print(f"UI _find_video_roi: OOM, retrying with batch={seg_batch} (attempt {attempt+1}/{max_retries})")
continue
raise
if all_masks is None:
return None, "**Error:** Segmentation failed after OOM retries", None, None
# Apply SDF temporal smoothing to masks
all_masks = smooth_masks_sdf(all_masks, alpha=0.5, empty_thresh=10, patience=5)
# Create ROI-highlighted frames
roi_frames = []
for pil_frame, mask in zip(pil_frames, all_masks):
if mask is not None and mask.sum() > 0:
highlighted = highlight_roi(pil_frame, mask, alpha=0.35, color=(0, 255, 0))
else:
highlighted = pil_frame
# Apply detection overlay if enabled
if show_detection and detector is not None:
dets = detector(pil_frame, conf_threshold=0.25)
highlighted = draw_detections(highlighted, dets, color=(255, 0, 0))
roi_frames.append(highlighted)
# Save masks to file for reuse
mask_file_path = save_video_masks(all_masks)
# Create output video
video_bytes = frames_to_video_bytes(roi_frames, effective_fps, format="mp4")
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
tmp.write(video_bytes)
output_path = tmp.name
status = f"**ROI Found** | {len(roi_frames)} frames @ {effective_fps:.0f}fps ({new_width}x{new_height}) | Masks saved"
if show_detection:
status += " | Detection enabled"
return output_path, status, roi_frames, mask_file_path
    except Exception as e:
        import traceback
        traceback.print_exc()  # surface the full stack trace in the logs
        return None, f"**Error:** {type(e).__name__}: {e}", None, None
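# The mask file returned above is the handoff between "Find ROI" and
# "Transmit": _compress_video reloads it via load_video_masks() so repeated
# compression runs skip re-segmentation. A minimal sketch of that round trip
# (illustrative; not executed at import time):
#
#     from video.mask_cache import save_video_masks
#     from video import load_video_masks
#     mask_file = save_video_masks(all_masks)   # after segmentation
#     masks = load_video_masks(mask_file)       # inside _compress_video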
@spaces.GPU
def _compress_video(
video_path,
segmenter_method: str,
mission: str,
mode: str,
# Static settings
quality_level: int,
sigma: float,
output_fps: float,
# Dynamic settings
bandwidth_kbps: float,
min_fps: float,
max_fps: float,
chunk_duration: float,
# State
    roi_frames_cache,  # currently unused; kept so the UI state wiring stays stable
    mask_file_path,  # path to saved segmentation masks from "Find ROI", if any
):
"""
Compress video using ROI-based compression.
If mask_file_path is provided, reuses saved segmentation masks.
Otherwise, performs segmentation during compression.
Returns:
- compressed_video_path: Path to compressed video
- status: Status message
- stats: Compression statistics dict
- compressed_state: Path for detection step
"""
import tempfile
if video_path is None:
return None, "**Error:** No video uploaded", {}, None
ensure_default_checkpoint_dirs()
device = _default_device()
target_classes = _split_classes(mission) if mission else []
try:
from video import VideoProcessor, CompressionSettings, load_video_masks
from video.video_processor import frames_to_video_bytes
# Load saved masks if available
saved_masks = None
if mask_file_path is not None:
try:
saved_masks = load_video_masks(mask_file_path)
print(f"Loaded {len(saved_masks)} cached masks")
except Exception as e:
print(f"Failed to load masks: {e}, will re-segment")
# Create processor
processor = VideoProcessor(device=device)
# Load models
processor.load_models(
quality_level=quality_level if mode == "Static" else 3,
segmentation_method=segmenter_method,
detection_method="yolo",
enable_tracking=False,
)
# Create settings
settings = CompressionSettings(
mode="static" if mode == "Static" else "dynamic",
quality_level=quality_level,
sigma=sigma,
output_fps=output_fps,
target_bandwidth_kbps=bandwidth_kbps,
chunk_duration_sec=chunk_duration,
min_fps=min_fps,
max_fps=max_fps,
segmentation_method=segmenter_method,
target_classes=target_classes,
enable_tracking=False,
)
# Process video using offline batch processing
if mode == "Static":
all_chunks = processor.process_static_offline(video_path, settings, saved_masks=saved_masks)
else:
all_chunks = processor.process_dynamic_offline(video_path, settings, saved_masks=saved_masks)
all_frames = []
all_stats = []
total_bytes = 0
total_original_frames = 0
for chunk in all_chunks:
all_frames.extend(chunk.frames)
total_bytes += chunk.estimated_bytes
total_original_frames += chunk.original_frame_count
chunk_stat = {
"chunk": chunk.chunk_index,
"frames": len(chunk.frames),
"fps": round(chunk.fps, 1),
"quality": chunk.quality_level,
"sigma": round(chunk.sigma, 2),
"bpp": round(chunk.avg_bpp, 3),
"size_kb": round(chunk.estimated_bytes / 1024, 1),
}
if chunk.motion_metrics is not None:
chunk_stat["motion"] = round(chunk.motion_metrics.motion_magnitude, 3)
all_stats.append(chunk_stat)
if not all_frames:
return None, "**Error:** No frames produced", {}, None
# Calculate final FPS
avg_fps = sum(s["fps"] for s in all_stats) / len(all_stats) if all_stats else output_fps
# Create output video
video_bytes = frames_to_video_bytes(all_frames, avg_fps, format="mp4")
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
tmp.write(video_bytes)
output_path = tmp.name
# Calculate compression stats
if processor.video_dimensions:
w, h = processor.video_dimensions
compression_ratio = (24 * total_original_frames * w * h) / (8 * max(1, total_bytes))
else:
compression_ratio = 0
stats = {
"mode": mode,
"total_frames": len(all_frames),
"original_frames": total_original_frames,
"frame_reduction": f"{100 * (1 - len(all_frames) / max(1, total_original_frames)):.1f}%",
"total_size_kb": round(total_bytes / 1024, 1),
"compression_ratio": f"{compression_ratio:.1f}x",
"avg_bpp": round(sum(s["bpp"] for s in all_stats) / max(1, len(all_stats)), 3),
"chunks": all_stats,
}
status = (
f"**Compressed!** | {len(all_frames)} frames | "
f"{total_bytes/1024:.1f}KB | {compression_ratio:.1f}x compression"
)
if saved_masks is not None:
status += " | Using cached masks"
return output_path, status, stats, output_path
    except Exception as e:
        import traceback
        traceback.print_exc()  # surface the full stack trace in the logs
        return None, f"**Error:** {type(e).__name__}: {e}", {}, None
finally:
# Always cleanup processor to release models
if 'processor' in locals():
processor.cleanup()
@spaces.GPU
def _run_video_detection(
compressed_video_path,
det_method: str,
det_classes: str,
):
"""
Run object detection on compressed video.
Returns:
- detection_video_path: Path to video with detection overlays
- status: Status message
"""
import tempfile
import cv2
from PIL import Image
if compressed_video_path is None:
return None, "**Error:** No compressed video available. Run compression first."
ensure_default_checkpoint_dirs()
device = _default_device()
try:
from video.video_processor import frames_to_video_bytes
from detection import create_detector
from detection.utils import draw_detections
# Load detector
detector = create_detector(det_method, device=device)
# Parse classes for open-vocab detectors
det_kwargs = {"conf_threshold": 0.25}
if det_method in OPEN_VOCAB_DETECTORS:
classes = _split_classes(det_classes)
if not classes:
return None, f"**Error:** {det_method} requires class prompts"
det_kwargs["classes"] = classes
# Extract frames
cap = cv2.VideoCapture(compressed_video_path)
fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
det_frames = []
total_detections = 0
while True:
ret, frame = cap.read()
if not ret:
break
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
pil_frame = Image.fromarray(frame_rgb)
# Run detection
dets = detector(pil_frame, **det_kwargs)
total_detections += len(dets)
# Draw detections
frame_with_dets = draw_detections(pil_frame, dets, color=(0, 255, 0))
det_frames.append(frame_with_dets)
cap.release()
if not det_frames:
return None, "**Error:** No frames extracted"
# Create output video
video_bytes = frames_to_video_bytes(det_frames, fps, format="mp4")
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
tmp.write(video_bytes)
output_path = tmp.name
avg_dets = total_detections / len(det_frames)
status = f"**Detection complete** | {len(det_frames)} frames | {total_detections} total detections ({avg_dets:.1f}/frame)"
return output_path, status
    except Exception as e:
        import traceback
        traceback.print_exc()  # surface the full stack trace in the logs
        return None, f"**Error:** {type(e).__name__}: {e}"
def build_app() -> gr.Blocks:
from segmentation.factory import get_available_methods
from detection import get_available_detectors
seg_methods = get_available_methods()
det_methods = get_available_detectors()
    with gr.Blocks(title="Contextual Communication Demo", theme=gr.themes.Soft()) as demo:
gr.Markdown(
"# Contextual Communication Demo\n"
"Mission-driven contextual compression for bandwidth-limited ISR: preserve decision-critical regions while compactly transmitting the rest."
)
with gr.Tabs():
with gr.TabItem("Image", id=0):
with gr.Row():
with gr.Column(scale=1):
inp = gr.Image(type="pil", label="Input image", height=480)
original_state = gr.State(None)
mask_state = gr.State(None)
open_vocab_classes_state = gr.State(None)
input_dets_state = gr.State(None)
compressed_state = gr.State(None)
output_dets_state = gr.State(None)
segmenter = gr.Dropdown(
choices=seg_methods,
value="sam3" if "sam3" in seg_methods else (seg_methods[0] if seg_methods else None),
label="Context Extraction Method (ROI)",
)
mission = gr.Textbox(
label="Mission",
placeholder="e.g., Collect intelligence on potential air threats (drones, aircraft)",
lines=1,
autoscroll=False,
)
chosen_classes = gr.Markdown("")
with gr.Group():
input_roi_highlight = gr.Checkbox(value=True, label="Highlight ROI")
input_det_toggle = gr.Checkbox(value=False, label="Enable object detection")
input_detector = gr.Dropdown(
choices=det_methods,
value="yolo" if "yolo" in det_methods else (det_methods[0] if det_methods else None),
label="Detector",
interactive=True,
visible=False,
)
with gr.Row(visible=False) as input_det_classes_row:
input_det_classes = gr.Textbox(
placeholder="person, car, drone, aircraft",
scale=4,
show_label=False,
container=False,
)
input_det_apply_btn = gr.Button("Apply", scale=2, variant="secondary", min_width=60)
find_roi_btn = gr.Button("Find Context (ROI)")
with gr.Column(scale=1):
out_out = gr.Image(type="pil", label="Decoded image (after transmission)", height=480)
quality = gr.Slider(
minimum=1,
maximum=len(CHECKPOINTS),
value=4 if len(CHECKPOINTS) >= 4 else 1,
step=1,
label="Transmission quality (higher = larger payload)",
)
_blank1 = gr.Markdown("")
bg = gr.Slider(
minimum=0.01,
maximum=1.0,
value=0.3,
step=0.01,
label="Background preservation (higher = keep more context)",
)
_blank2 = gr.Markdown("")
with gr.Group():
output_roi_highlight = gr.Checkbox(value=False, label="Highlight ROI")
output_det_toggle = gr.Checkbox(value=False, label="Enable object detection")
output_detector = gr.Dropdown(
choices=det_methods,
value="yolo" if "yolo" in det_methods else (det_methods[0] if det_methods else None),
label="Detector",
interactive=True,
visible=False,
)
with gr.Row(visible=False) as output_det_classes_row:
output_det_classes = gr.Textbox(
placeholder="person, car, drone, aircraft",
scale=4,
show_label=False,
container=False,
)
output_det_apply_btn = gr.Button("Apply", scale=1, size="sm", variant="secondary", min_width=60)
compress_btn = gr.Button("Transmit (Compress)", variant="primary")
status = gr.Markdown("")
# Use global OPEN_VOCAB_DETECTORS defined at module level
def _toggle_detector_ui(enabled: bool, detector: str):
is_open_vocab = detector in OPEN_VOCAB_DETECTORS
return gr.update(visible=enabled), gr.update(visible=enabled and is_open_vocab)
def _on_detector_type_change(enabled: bool, detector: str):
is_open_vocab = detector in OPEN_VOCAB_DETECTORS
return gr.update(visible=enabled and is_open_vocab)
input_det_toggle.change(
_toggle_detector_ui,
inputs=[input_det_toggle, input_detector],
outputs=[input_detector, input_det_classes_row],
)
output_det_toggle.change(
_toggle_detector_ui,
inputs=[output_det_toggle, output_detector],
outputs=[output_detector, output_det_classes_row],
)
input_detector.change(
_on_detector_type_change,
inputs=[input_det_toggle, input_detector],
outputs=[input_det_classes_row],
)
output_detector.change(
_on_detector_type_change,
inputs=[output_det_toggle, output_detector],
outputs=[output_det_classes_row],
)
# IMPORTANT: use upload event so programmatic updates to the image
# (ROI overlays) don't trigger state resets and cause stacking.
inp.upload(
_on_upload,
inputs=[inp],
outputs=[
original_state,
mask_state,
open_vocab_classes_state,
input_dets_state,
output_dets_state,
chosen_classes,
status,
compressed_state,
],
)
find_roi_btn.click(
_find_roi,
inputs=[original_state, segmenter, mission, input_roi_highlight, input_det_toggle, input_detector, input_det_classes],
outputs=[inp, mask_state, open_vocab_classes_state, input_dets_state, chosen_classes, status],
)
compress_btn.click(
_compress,
inputs=[
original_state,
mask_state,
open_vocab_classes_state,
segmenter,
mission,
quality,
bg,
output_roi_highlight,
output_det_toggle,
output_detector,
output_det_classes,
],
outputs=[out_out, compressed_state, mask_state, output_dets_state, open_vocab_classes_state, chosen_classes, status],
)
# Reactive input toggles: purely re-render from cached state.
input_roi_highlight.change(
_refresh_input_view,
inputs=[original_state, mask_state, input_dets_state, input_roi_highlight, input_det_toggle],
outputs=[inp],
)
input_det_toggle.change(
_on_input_detection_toggle,
inputs=[
original_state,
mask_state,
open_vocab_classes_state,
input_dets_state,
segmenter,
mission,
input_det_toggle,
input_detector,
input_roi_highlight,
input_det_classes,
],
outputs=[inp, input_dets_state],
)
input_detector.change(
_on_input_detector_change,
inputs=[
original_state,
mask_state,
open_vocab_classes_state,
segmenter,
mission,
input_det_toggle,
input_detector,
input_roi_highlight,
input_det_classes,
],
outputs=[inp, input_dets_state],
)
# Reactive output toggles: purely re-render from cached state.
output_roi_highlight.change(
_refresh_output_view,
inputs=[compressed_state, mask_state, output_dets_state, output_roi_highlight, output_det_toggle],
outputs=[out_out],
)
output_det_toggle.change(
_on_output_detection_toggle,
inputs=[
compressed_state,
mask_state,
open_vocab_classes_state,
output_dets_state,
segmenter,
mission,
output_det_toggle,
output_detector,
output_roi_highlight,
output_det_classes,
],
outputs=[out_out, output_dets_state],
)
output_detector.change(
_on_output_detector_change,
inputs=[
compressed_state,
mask_state,
open_vocab_classes_state,
segmenter,
mission,
output_det_toggle,
output_detector,
output_roi_highlight,
output_det_classes,
],
outputs=[out_out, output_dets_state],
)
# Re-run detection when custom classes textbox changes (submit/blur)
input_det_classes.submit(
_on_input_detector_change,
inputs=[
original_state,
mask_state,
open_vocab_classes_state,
segmenter,
mission,
input_det_toggle,
input_detector,
input_roi_highlight,
input_det_classes,
],
outputs=[inp, input_dets_state],
)
output_det_classes.submit(
_on_output_detector_change,
inputs=[
compressed_state,
mask_state,
open_vocab_classes_state,
segmenter,
mission,
output_det_toggle,
output_detector,
output_roi_highlight,
output_det_classes,
],
outputs=[out_out, output_dets_state],
)
# Apply buttons for open-vocab detection classes
input_det_apply_btn.click(
_on_input_detector_change,
inputs=[
original_state,
mask_state,
open_vocab_classes_state,
segmenter,
mission,
input_det_toggle,
input_detector,
input_roi_highlight,
input_det_classes,
],
outputs=[inp, input_dets_state],
)
output_det_apply_btn.click(
_on_output_detector_change,
inputs=[
compressed_state,
mask_state,
open_vocab_classes_state,
segmenter,
mission,
output_det_toggle,
output_detector,
output_roi_highlight,
output_det_classes,
],
outputs=[out_out, output_dets_state],
)
with gr.TabItem("Video", id=1):
# ===============================================================
# VIDEO TAB - ROI-based video compression
# Layout:
# Row 1: [Input Video] [ROI Highlighted Video] [Compressed Output]
# Row 2: [ROI Controls (2/3 width)] [Compression Controls (1/3 width)]
# ===============================================================
gr.Markdown(
"## Video Compression\n\n"
"**Workflow:** Upload video → Select ROI method & mission → Find ROI → Tune compression → Transmit"
)
# State variables for video processing
video_roi_frames_state = gr.State(None) # ROI-highlighted frames
video_mask_file_state = gr.State(None) # Saved segmentation masks file
video_compressed_state = gr.State(None) # Compressed output path
# ================== ROW 1: Three Video Panes ==================
with gr.Row(height=360):
video_input = gr.Video(
label="Input Video",
height=360,
)
video_roi_output = gr.Video(
label="ROI Highlighted Video",
height=360,
autoplay=True,
)
video_compressed_output = gr.Video(
label="Compressed Output",
height=360,
autoplay=True,
)
# ================== ROW 2: Controls ==================
with gr.Row():
# Left column: ROI Controls (2/3 width - under Input + ROI videos)
with gr.Column(scale=2):
video_segmenter = gr.Dropdown(
choices=seg_methods,
value="sam3" if "sam3" in seg_methods else (seg_methods[0] if seg_methods else None),
label="ROI Method",
)
video_mission = gr.Textbox(
label="Mission / Target Classes",
placeholder="e.g., drone, person, vehicle, aircraft",
lines=1,
)
video_show_detection = gr.Checkbox(
value=False,
label="Show Object Detection on ROI Video",
)
video_find_roi_btn = gr.Button("Find ROI", variant="secondary")
video_roi_status = gr.Markdown("")
# Right column: Compression Controls (1/3 width - under Compressed Output)
with gr.Column(scale=1):
# Mode selection
video_mode = gr.Radio(
choices=["Static", "Dynamic"],
value="Static",
label="Compression Mode",
info="Static: fixed settings | Dynamic: bandwidth-adaptive",
)
# Static mode settings
with gr.Group(visible=True) as static_settings_group:
video_quality = gr.Slider(
minimum=1,
maximum=5,
value=4,
step=1,
label="Quality (1=smallest, 5=best)",
)
video_sigma = gr.Slider(
minimum=0.01,
maximum=1.0,
value=0.3,
step=0.01,
label="Background Preservation (σ)",
)
video_output_fps = gr.Slider(
minimum=1,
maximum=60,
value=10,
step=1,
label="Output FPS",
)
# Dynamic mode settings
with gr.Group(visible=False) as dynamic_settings_group:
video_bandwidth = gr.Slider(
minimum=50,
maximum=5000,
value=500,
step=50,
label="Target Bandwidth (kbps)",
)
video_min_fps = gr.Slider(
minimum=1,
maximum=30,
value=5,
step=1,
label="Min FPS",
)
video_max_fps = gr.Slider(
minimum=15,
maximum=60,
value=24,
step=1,
label="Max FPS",
)
video_chunk_duration = gr.Slider(
minimum=0.5,
maximum=5.0,
value=1.0,
step=0.5,
label="Chunk Duration (sec)",
)
video_transmit_btn = gr.Button("Transmit (Compress)", variant="primary")
video_compress_status = gr.Markdown("")
# Stats accordion
with gr.Accordion("Compression Statistics", open=False):
video_stats = gr.JSON(label="Stats")
# Object Detection accordion (hidden by default)
with gr.Accordion("Try Object Detection", open=False):
video_det_method = gr.Dropdown(
choices=det_methods,
value="yolo" if "yolo" in det_methods else (det_methods[0] if det_methods else None),
label="Detection Method",
)
video_det_prompt = gr.Textbox(
label="Classes (open-vocab)",
placeholder="person, car, drone",
lines=1,
visible=False,
)
video_run_detection_btn = gr.Button("Run Detection", variant="secondary")
video_det_status = gr.Markdown("")
video_detection_output = gr.Video(
label="Detection Output",
height=200,
autoplay=True,
)
# ================== Event Handlers ==================
# Mode toggle handler
def _toggle_video_mode(mode: str):
is_static = mode == "Static"
return (
gr.update(visible=is_static),
gr.update(visible=not is_static),
)
video_mode.change(
_toggle_video_mode,
inputs=[video_mode],
outputs=[static_settings_group, dynamic_settings_group],
)
# Show/hide detection prompt for open-vocab detectors
def _toggle_det_prompt(method: str):
return gr.update(visible=method in OPEN_VOCAB_DETECTORS)
video_det_method.change(
_toggle_det_prompt,
inputs=[video_det_method],
outputs=[video_det_prompt],
)
# Find ROI handler
video_find_roi_btn.click(
_find_video_roi,
inputs=[
video_input,
video_segmenter,
video_mission,
video_show_detection,
],
outputs=[
video_roi_output,
video_roi_status,
video_roi_frames_state,
video_mask_file_state,
],
)
# Transmit (Compress) handler
video_transmit_btn.click(
_compress_video,
inputs=[
video_input,
video_segmenter,
video_mission,
video_mode,
# Static settings
video_quality,
video_sigma,
video_output_fps,
# Dynamic settings
video_bandwidth,
video_min_fps,
video_max_fps,
video_chunk_duration,
# State
video_roi_frames_state,
video_mask_file_state,
],
outputs=[
video_compressed_output,
video_compress_status,
video_stats,
video_compressed_state,
],
)
# Run Detection handler
video_run_detection_btn.click(
_run_video_detection,
inputs=[
video_compressed_state,
video_det_method,
video_det_prompt,
],
outputs=[
video_detection_output,
video_det_status,
],
)
# =================================================================
# Register API endpoints with clean names
# These are the primary endpoints for external use
# =================================================================
# Hidden interface components for API endpoints
with gr.Row(visible=False):
# Segment API
_api_seg_image = gr.Image(type="pil")
_api_seg_prompt = gr.Textbox()
_api_seg_method = gr.Textbox()
_api_seg_return_overlay = gr.Checkbox()
_api_seg_mask = gr.Image(type="pil")
_api_seg_coverage = gr.Number()
_api_seg_classes = gr.Textbox()
# Compress API
_api_comp_image = gr.Image(type="pil")
_api_comp_mask = gr.Image(type="pil")
_api_comp_quality = gr.Number()
_api_comp_sigma = gr.Number()
_api_comp_output = gr.Image(type="pil")
_api_comp_bpp = gr.Number()
_api_comp_ratio = gr.Number()
# Detect API (original - returns JSON only)
_api_det_image = gr.Image(type="pil")
_api_det_method = gr.Textbox()
_api_det_classes = gr.Textbox()
_api_det_conf = gr.Number()
_api_det_result = gr.Textbox() # JSON result only
# Detect Overlay API (returns image + JSON)
_api_det_ov_image = gr.Image(type="pil")
_api_det_ov_method = gr.Textbox()
_api_det_ov_classes = gr.Textbox()
_api_det_ov_conf = gr.Number()
_api_det_ov_image_out = gr.Image(type="pil")
_api_det_ov_result = gr.Textbox()
# Process API
_api_proc_image = gr.Image(type="pil")
_api_proc_prompt = gr.Textbox()
_api_proc_seg_method = gr.Textbox()
_api_proc_quality = gr.Number()
_api_proc_sigma = gr.Number()
_api_proc_run_det = gr.Checkbox()
_api_proc_det_method = gr.Textbox()
_api_proc_det_classes = gr.Textbox()
_api_proc_out_image = gr.Image(type="pil")
_api_proc_out_mask = gr.Image(type="pil")
_api_proc_out_bpp = gr.Number()
_api_proc_out_ratio = gr.Number()
_api_proc_out_coverage = gr.Number()
_api_proc_out_dets = gr.Textbox()
# Register segment endpoint
gr.Button(visible=False).click(
api_segment,
inputs=[_api_seg_image, _api_seg_prompt, _api_seg_method, _api_seg_return_overlay],
outputs=[_api_seg_mask, _api_seg_coverage, _api_seg_classes],
api_name="segment",
)
# Register compress endpoint
gr.Button(visible=False).click(
api_compress,
inputs=[_api_comp_image, _api_comp_mask, _api_comp_quality, _api_comp_sigma],
outputs=[_api_comp_output, _api_comp_bpp, _api_comp_ratio],
api_name="compress",
)
# Register detect endpoint (original - returns JSON only for backwards compatibility)
gr.Button(visible=False).click(
api_detect,
inputs=[_api_det_image, _api_det_method, _api_det_classes, _api_det_conf],
outputs=[_api_det_result],
api_name="detect",
)
# Register detect_overlay endpoint (returns image with boxes + JSON)
gr.Button(visible=False).click(
api_detect_overlay,
inputs=[_api_det_ov_image, _api_det_ov_method, _api_det_ov_classes, _api_det_ov_conf],
outputs=[_api_det_ov_image_out, _api_det_ov_result],
api_name="detect_overlay",
)
# Register process endpoint
gr.Button(visible=False).click(
api_process,
inputs=[
_api_proc_image, _api_proc_prompt, _api_proc_seg_method,
_api_proc_quality, _api_proc_sigma,
_api_proc_run_det, _api_proc_det_method, _api_proc_det_classes,
],
outputs=[
_api_proc_out_image, _api_proc_out_mask,
_api_proc_out_bpp, _api_proc_out_ratio, _api_proc_out_coverage,
_api_proc_out_dets,
],
api_name="process",
)
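        # Minimal client-side sketch for the image endpoints (illustrative;
        # "user/space-name" stands in for this Space's id):
        #
        #     from gradio_client import Client, handle_file
        #     client = Client("user/space-name")
        #     mask, coverage, classes = client.predict(
        #         handle_file("frame.png"), "drone, vehicle", "sam3", False,
        #         api_name="/segment",
        #     )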
# Video API components (hidden)
with gr.Row(visible=False):
# Video process (full pipeline)
_api_vid_input = gr.Video()
_api_vid_prompt = gr.Textbox()
            _api_vid_proc_seg_method = gr.Textbox()  # distinct from _api_vid_seg_method below
_api_vid_mode = gr.Textbox()
_api_vid_quality = gr.Number()
_api_vid_sigma = gr.Number()
_api_vid_fps = gr.Number()
_api_vid_bandwidth = gr.Number()
_api_vid_min_fps = gr.Number()
_api_vid_max_fps = gr.Number()
_api_vid_aggressiveness = gr.Number()
_api_vid_run_det = gr.Checkbox()
            _api_vid_proc_det_method = gr.Textbox()  # distinct from _api_vid_det_method below
_api_vid_mask_file = gr.Textbox()
_api_vid_output = gr.Video()
_api_vid_stats = gr.Textbox()
# Video segment
_api_vid_seg_input = gr.Video()
_api_vid_seg_prompt = gr.Textbox()
_api_vid_seg_method = gr.Textbox()
_api_vid_seg_return_overlay = gr.Checkbox()
_api_vid_seg_output_fps = gr.Number()
            _api_vid_seg_output = gr.File()  # currently unregistered; segment_video returns overlay video + stats
_api_vid_seg_output_video = gr.Video()
_api_vid_seg_stats = gr.Textbox()
# Video compress
_api_vid_comp_input = gr.Video()
_api_vid_comp_mask_file = gr.Textbox()
_api_vid_comp_quality = gr.Number()
_api_vid_comp_sigma = gr.Number()
_api_vid_comp_fps = gr.Number()
_api_vid_comp_output = gr.Video()
_api_vid_comp_stats = gr.Textbox()
# Video detect
_api_vid_det_input = gr.Video()
_api_vid_det_method = gr.Textbox()
_api_vid_det_classes = gr.Textbox()
_api_vid_det_conf = gr.Number()
_api_vid_det_return_overlay = gr.Checkbox()
_api_vid_det_fps = gr.Number()
_api_vid_det_output = gr.Video()
_api_vid_det_result = gr.Textbox()
# Register video process endpoint (full pipeline)
gr.Button(visible=False).click(
api_process_video,
inputs=[
                _api_vid_input, _api_vid_prompt, _api_vid_proc_seg_method,
_api_vid_mode, _api_vid_quality, _api_vid_sigma, _api_vid_fps,
_api_vid_bandwidth, _api_vid_min_fps, _api_vid_max_fps,
                _api_vid_aggressiveness, _api_vid_run_det, _api_vid_proc_det_method, _api_vid_mask_file,
],
outputs=[_api_vid_output, _api_vid_stats],
api_name="process_video",
)
# Register video segment endpoint
gr.Button(visible=False).click(
api_segment_video,
inputs=[
_api_vid_seg_input, _api_vid_seg_prompt, _api_vid_seg_method,
_api_vid_seg_return_overlay, _api_vid_seg_output_fps,
],
outputs=[_api_vid_seg_output_video, _api_vid_seg_stats],
api_name="segment_video",
)
# Register video compress endpoint
gr.Button(visible=False).click(
api_compress_video,
inputs=[
_api_vid_comp_input, _api_vid_comp_mask_file,
_api_vid_comp_quality, _api_vid_comp_sigma, _api_vid_comp_fps,
],
outputs=[_api_vid_comp_output, _api_vid_comp_stats],
api_name="compress_video",
)
# Register video detect endpoint
gr.Button(visible=False).click(
api_detect_video,
inputs=[
_api_vid_det_input, _api_vid_det_method, _api_vid_det_classes,
_api_vid_det_conf, _api_vid_det_return_overlay, _api_vid_det_fps,
],
outputs=[_api_vid_det_output, _api_vid_det_result],
api_name="detect_video",
)
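        # The buffered video endpoints follow the same pattern; argument order
        # mirrors the inputs lists above (illustrative values):
        #
        #     video_path, stats_json = client.predict(
        #         handle_file("clip.mp4"), "drone", "sam3", "static",
        #         3, 0.3, 10.0, 500.0, 5.0, 24.0, 1.0, False, "yolo", "",
        #         api_name="/process_video",
        #     )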
return demo
if __name__ == "__main__":
# Ensure proper event loop handling for Python 3.13+
if sys.version_info >= (3, 13):
try:
# Set up a new event loop to avoid cleanup issues
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except Exception:
pass
app = build_app()
    app.launch(show_error=True)
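    # When run locally ("python app.py"), Gradio serves the UI at
    # http://127.0.0.1:7860 by default; the api_name endpoints registered
    # above are then callable via gradio_client against that URL.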