# HuggingFace Space: ENEAS Gradio app (Spaces page status header: Paused)
import json
import logging
import os
import shutil
import subprocess
import sys
import tempfile
import threading
import time
import urllib.request

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
# ===========================================
# LOGGING CONFIGURATION
# ===========================================
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

# Make the local 'eneas' package importable regardless of working directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

import spaces

try:
    from eneas.segmentation import UniqueInstanceSegmenter, GenericCategorySegmenter
    from eneas.segmentation.model_manager import ModelManager
except ImportError as e:
    logger.error(f"Error importing ENEAS: {e}")
    raise e
# ===========================================
# CONSTANTS
# ===========================================

# Hard cap on processed frames: ZeroGPU jobs time out, and processing runs
# at roughly one second per frame.
MAX_FRAMES = 150

OLLAMA_HOST = "127.0.0.1:11434"
OLLAMA_URL = f"http://{OLLAMA_HOST}"
OLLAMA_BIN = "./bin/ollama"

# Vision-language models offered in the UI dropdown.
VLM_MODELS = [
    "qwen3-vl:4b-instruct-q8_0",
    "qwen3-vl:2b-instruct-q8_0",
]

OUTPUT_BASE_DIR = "gradio_outputs"
os.makedirs(OUTPUT_BASE_DIR, exist_ok=True)
# ===========================================
# OLLAMA FUNCTIONS (FOR USE INSIDE @spaces.GPU)
# ===========================================
def get_ollama_env():
    """Build the environment for launching the Ollama binary.

    Sets host/origins, points HOME at the working directory, and appends the
    locally extracted ./lib directory to LD_LIBRARY_PATH so the standalone
    binary can locate its shared libraries.
    """
    cwd = os.getcwd()
    env = os.environ.copy()
    env["OLLAMA_HOST"] = OLLAMA_HOST
    env["OLLAMA_ORIGINS"] = "*"
    env["HOME"] = cwd

    lib_path = f"{cwd}/lib"
    if "LD_LIBRARY_PATH" in env:
        env["LD_LIBRARY_PATH"] = env["LD_LIBRARY_PATH"] + ":" + lib_path
    else:
        env["LD_LIBRARY_PATH"] = lib_path
    return env
def is_ollama_server_running() -> bool:
    """Return True if the Ollama server answers HTTP 200 on its root URL.

    Uses urllib directly instead of shelling out to `curl`, so the check
    works even when curl is not installed and avoids forking a subprocess.
    """
    try:
        with urllib.request.urlopen(OLLAMA_URL, timeout=5) as resp:
            return resp.status == 200
    except Exception:
        # Connection refused / timeout / non-2xx responses all mean "not ready".
        return False
def start_ollama_server_gpu():
    """
    Start the Ollama server INSIDE a @spaces.GPU context so it detects the GPU.

    Returns:
        bool: True if the server is up (already running, or started here).
    """
    if is_ollama_server_running():
        logger.info("Ollama server is already running.")
        return True

    logger.info("Starting Ollama server inside GPU context...")
    try:
        env = get_ollama_env()
        # DEVNULL instead of PIPE: nothing ever drains the pipes, and a full
        # pipe buffer would eventually block the server process.
        process = subprocess.Popen(
            [OLLAMA_BIN, "serve"],
            env=env,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )

        # Poll for readiness (max 30 seconds).
        max_retries = 30
        for i in range(max_retries):
            if is_ollama_server_running():
                logger.info(f"Ollama server started successfully in {i+1} seconds.")
                return True
            time.sleep(1)

        # Startup failed: don't leave a half-started child process behind
        # (the previous version dropped the handle and leaked the process).
        process.terminate()
        logger.error("Ollama server failed to start within 30 seconds.")
        return False
    except Exception as e:
        logger.error(f"Failed to start Ollama server: {e}")
        return False
def load_model_into_vram(model_name: str) -> bool:
    """
    Load a model into VRAM and pin it there (keep_alive=-1).

    Args:
        model_name: Name of the Ollama model to load.

    Returns:
        bool: True if the model loaded successfully.
    """
    logger.info(f"Loading model {model_name} into VRAM...")

    def _post(payload: dict, timeout: int) -> str:
        # json.dumps escapes the model name properly; the previous version
        # interpolated it into a raw JSON string via an f-string.
        req = urllib.request.Request(
            f"{OLLAMA_URL}/api/generate",
            data=json.dumps(payload).encode("utf-8"),
            headers={"Content-Type": "application/json"},
        )
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            return resp.read().decode("utf-8", errors="replace")

    try:
        # A minimal generation request triggers the model load (can be slow).
        body = _post({"model": model_name, "prompt": "hi", "stream": False}, timeout=120)
        if "error" in body.lower():
            logger.error(f"Error loading model: {body}")
            return False

        # keep_alive=-1 keeps the model resident in VRAM indefinitely.
        _post({"model": model_name, "keep_alive": -1}, timeout=10)

        logger.info(f"Model {model_name} loaded into VRAM successfully.")
        return True
    except TimeoutError:
        logger.error("Timeout while loading model into VRAM.")
        return False
    except Exception as e:
        logger.error(f"Error loading model into VRAM: {e}")
        return False
def log_active_models():
    """Log which models are currently resident in VRAM (via /api/ps), not just on disk."""
    try:
        with urllib.request.urlopen(f"{OLLAMA_URL}/api/ps", timeout=5) as resp:
            body = resp.read().decode("utf-8", errors="replace")
        logger.info(f"Active models in VRAM: {body}")
    except Exception as e:
        # Purely diagnostic; never fail the caller over this.
        logger.warning(f"Could not get active models: {e}")
def ensure_ollama_ready_gpu(model_name: str) -> bool:
    """
    Bring Ollama fully up with GPU support.
    MUST be called inside a @spaces.GPU decorated function.

    Steps:
      1. Start the Ollama server (the GPU is visible at this point).
      2. Load the requested model into VRAM.
      3. Log the models actually resident in VRAM.

    Args:
        model_name: Name of the Ollama model to use.

    Returns:
        bool: True if ready.

    Raises:
        RuntimeError: If the server or the model fails to come up.
    """
    logger.info(f"Ensuring Ollama is ready with GPU for model: {model_name}")

    if not start_ollama_server_gpu():
        raise RuntimeError("Failed to start Ollama server with GPU")

    if not load_model_into_vram(model_name):
        raise RuntimeError(f"Failed to load model {model_name} into VRAM")

    log_active_models()
    logger.info("Ollama is ready with GPU support!")
    return True
# ===========================================
# STARTUP: DOWNLOAD BINARY AND MODELS (CPU)
# ===========================================
def download_ollama_binary():
    """Download and extract the standalone Ollama binary if not already present."""
    if os.path.exists(OLLAMA_BIN):
        logger.info("Ollama binary already exists.")
        return True

    logger.info("Downloading Ollama binary (ZST)...")
    archive = "ollama.tar.zst"
    try:
        subprocess.run(
            ["curl", "-L", "https://ollama.com/download/ollama-linux-amd64.tar.zst", "-o", archive],
            check=True,
            timeout=300,
        )
        subprocess.run(["tar", "--zstd", "-xf", archive], check=True)
        subprocess.run(["chmod", "+x", OLLAMA_BIN], check=True)
        os.remove(archive)  # remove the archive once extracted
        logger.info("Ollama binary downloaded and extracted successfully.")
        return True
    except Exception as e:
        logger.error(f"Failed to download Ollama binary: {e}")
        return False
def pull_ollama_models():
    """
    Pre-download Ollama models at startup (runs on CPU).

    Temporarily starts the server, pulls each model in VLM_MODELS, then shuts
    the server down again (it is restarted later inside the GPU context).
    """
    logger.info("Pre-downloading Ollama models...")
    env = get_ollama_env()

    # Temporary server; DEVNULL avoids blocking on undrained pipes.
    server_process = subprocess.Popen(
        [OLLAMA_BIN, "serve"],
        env=env,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )

    # Wait for the server to come up.
    time.sleep(5)
    for _ in range(20):
        if is_ollama_server_running():
            break
        time.sleep(1)

    for model in VLM_MODELS:
        logger.info(f"Pulling model: {model}")
        try:
            result = subprocess.run(
                [OLLAMA_BIN, "pull", model],
                env=env,
                timeout=600,
                capture_output=True,
            )
            # The previous version logged success unconditionally, even when
            # `ollama pull` exited non-zero.
            if result.returncode == 0:
                logger.info(f"Model {model} pulled successfully.")
            else:
                detail = result.stderr.decode("utf-8", errors="replace").strip()
                logger.warning(f"Failed to pull model {model}: {detail}")
        except Exception as e:
            logger.warning(f"Failed to pull model {model}: {e}")

    # Stop the temporary server (restarted later inside the GPU context).
    server_process.terminate()
    try:
        server_process.wait(timeout=5)
    except subprocess.TimeoutExpired:
        server_process.kill()

    logger.info("Ollama models pre-download complete.")
def setup_ollama_startup():
    """Startup hook: fetch the Ollama binary, then pre-pull the VLM models."""
    download_ollama_binary()
    pull_ollama_models()
def setup_hf_models():
    """
    Warm-up: download the heavy HuggingFace models to disk at startup so the
    first inference does not hit a ZeroGPU timeout.
    """
    logger.info("Starting HuggingFace models download (Warm-up)...")
    try:
        manager = ModelManager()

        # (display label, HF repo id)
        hf_repos = [
            ("SeC-4B", "OpenIXCLab/SeC-4B"),                       # heavy, ~15GB
            ("Florence-2", "microsoft/Florence-2-large"),          # grounding
            ("SigLIP", "google/siglip2-base-patch16-naflex"),      # generic category
        ]
        for label, repo in hf_repos:
            logger.info(f"Downloading {label}...")
            manager.download(repo)

        # SAM2 checkpoint ships via a direct URL rather than the HF hub.
        logger.info("Downloading SAM2 checkpoint...")
        manager.download_url(
            "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
            "sam2.1_hiera_large.pt",
        )

        logger.info("All HuggingFace models downloaded successfully.")
    except Exception as e:
        # Best-effort warm-up: failures are logged, not fatal to startup.
        logger.error(f"Error during HF model download: {e}")
# ===========================================
# STARTUP: PARALLEL MODEL DOWNLOADS
# ===========================================
logger.info("Starting parallel model downloads at startup...")

# HF and Ollama downloads are independent; run them concurrently as daemon
# threads so app startup is not blocked.
t_hf = threading.Thread(target=setup_hf_models, daemon=True)
t_ollama = threading.Thread(target=setup_ollama_startup, daemon=True)
t_hf.start()
t_ollama.start()

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Main process device detection: {DEVICE}")
# ===========================================
# UTILITY FUNCTIONS
# ===========================================
def process_inputs_to_frames(input_data, output_folder: str) -> tuple:
    """
    Extract frames from a video (sampled at 1 FPS) or copy uploaded images
    into `output_folder`, enforcing the MAX_FRAMES limit to prevent ZeroGPU
    timeouts.

    Args:
        input_data: Video file, list of image files, or a single file object
            (anything with a .name attribute, or str()-able path).
        output_folder: Directory to (re)create and fill with frames.

    Returns:
        tuple: (output_folder path, list of saved frame file paths)

    Raises:
        gr.Error: If the video is too long or too many images were uploaded.
    """
    if os.path.exists(output_folder):
        shutil.rmtree(output_folder)
    os.makedirs(output_folder)

    frame_paths = []
    video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.webm'}

    input_list = input_data if isinstance(input_data, list) else [input_data]
    if not input_list:
        return output_folder, []

    first_file = input_list[0].name if hasattr(input_list[0], 'name') else str(input_list[0])
    ext = os.path.splitext(first_file)[1].lower()

    if ext in video_extensions:
        logger.info(f"Processing video: {first_file}...")
        cap = cv2.VideoCapture(first_file)
        video_fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames_original = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        if video_fps == 0 or np.isnan(video_fps):
            video_fps = 30  # sane default when container metadata is missing

        duration_sec = total_frames_original / video_fps
        if duration_sec > MAX_FRAMES:
            cap.release()
            msg = f"Video is too long ({int(duration_sec)}s). Max allowed is {MAX_FRAMES}s to avoid ZeroGPU timeout."
            logger.error(msg)
            raise gr.Error(msg)

        # Sample one frame per second of video.
        frame_interval = max(1, int(video_fps))
        count = 0
        saved_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if count % frame_interval == 0:
                # Enforce the cap BEFORE writing so no frame beyond MAX_FRAMES
                # lands on disk (the old code wrote one extra frame and then
                # raised).
                if saved_count >= MAX_FRAMES:
                    cap.release()
                    raise gr.Error(f"Limit reached: > {MAX_FRAMES} frames extracted.")
                filename = f"frame_{saved_count:05d}.jpg"
                filepath = os.path.join(output_folder, filename)
                cv2.imwrite(filepath, frame)
                frame_paths.append(filepath)
                saved_count += 1
            count += 1
        cap.release()
        logger.info(f"Video sampled at 1 FPS. Total frames: {saved_count}")
    else:
        if len(input_list) > MAX_FRAMES:
            raise gr.Error(f"Too many images! You uploaded {len(input_list)}. Max allowed is {MAX_FRAMES}.")
        logger.info(f"Processing {len(input_list)} images...")

        # sorted() instead of list.sort(): do not mutate the caller's list.
        ordered = sorted(input_list, key=lambda x: x.name if hasattr(x, 'name') else str(x))
        for i, f in enumerate(ordered):
            path = f.name if hasattr(f, 'name') else str(f)
            try:
                img = Image.open(path).convert("RGB")
                filename = f"frame_{i:05d}.jpg"
                filepath = os.path.join(output_folder, filename)
                img.save(filepath)
                frame_paths.append(filepath)
            except Exception as e:
                logger.warning(f"Skipping file {path}: {e}")

    return output_folder, frame_paths
def create_video_overlay(frames_folder: str, masks_dict: dict, output_path: str, fps: int = 5) -> str:
    """
    Render a video of the frames with segmentation masks alpha-blended on top.

    Args:
        frames_folder: Directory containing frame .jpg images.
        masks_dict: Maps frame index -> mask array (or list of mask arrays).
        output_path: Output video file path.
        fps: Frames per second of the output video.

    Returns:
        The output video path, or None if no frames were found/readable.
    """
    logger.info("Generating result video...")
    frame_files = sorted(f for f in os.listdir(frames_folder) if f.endswith(".jpg"))
    if not frame_files:
        return None

    first_frame = cv2.imread(os.path.join(frames_folder, frame_files[0]))
    if first_frame is None:
        # cv2.imread returns None on failure; previously this crashed on .shape.
        logger.error("Could not read first frame for overlay video.")
        return None
    height, width, _ = first_frame.shape

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    # Overlay color. NOTE(review): frames are BGR, so [255, 100, 0] renders
    # blue-ish rather than the "orange/gold" the original comment claimed —
    # confirm the intended color.
    mask_color = np.array([255, 100, 0], dtype=np.uint8)

    try:
        for i, filename in enumerate(frame_files):
            frame = cv2.imread(os.path.join(frames_folder, filename))
            if frame is None:
                continue  # skip unreadable frames instead of crashing
            mask_overlay = np.zeros_like(frame)
            if i in masks_dict:
                masks_data = masks_dict[i]
                if isinstance(masks_data, np.ndarray):
                    masks_list = [masks_data]
                elif isinstance(masks_data, list):
                    masks_list = masks_data
                else:
                    masks_list = []
                for mask in masks_list:
                    mask_overlay[mask > 0] = mask_color
            if np.any(mask_overlay):
                frame = cv2.addWeighted(frame, 1, mask_overlay, 0.5, 0)
            out.write(frame)
    finally:
        # Always release the writer so the file is finalized even on error.
        out.release()
    return output_path
# ===========================================
# UNIQUE INSTANCE SEGMENTATION
# ===========================================
def process_unique_upload(input_files):
    """
    Handle the 'Process Input' click on the Unique Instance tab: extract
    frames at 1 FPS into a temp dir and reveal the reference-frame slider.

    Returns:
        (first frame path, frames dir, cleared points, status text, slider update)
    """
    if not input_files:
        return None, None, [], "Please upload files first.", gr.Slider(value=0, maximum=0, visible=False)

    workdir = tempfile.mkdtemp()
    frames_dir, frame_paths = process_inputs_to_frames(input_files, workdir)
    total = len(frame_paths)
    if total == 0:
        return None, None, [], "No frames extracted.", gr.Slider(value=0, maximum=0, visible=False)

    slider_update = gr.Slider(
        value=0,
        minimum=0,
        maximum=total - 1,
        step=1,
        visible=True,
        interactive=True,
        label=f"Select Reference Frame (0 - {total - 1})",
    )
    return frame_paths[0], frames_dir, [], f"Processed {total} frames (1 FPS). Select target.", slider_update
def update_canvas_from_slider(frame_idx, frames_dir):
    """Return the RGB frame at the slider position and clear any annotation points."""
    if not frames_dir or not os.path.exists(frames_dir):
        return None, []
    path = os.path.join(frames_dir, f"frame_{int(frame_idx):05d}.jpg")
    if not os.path.exists(path):
        return None, []
    return cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB), []
def add_point(img, evt: gr.SelectData, points_state):
    """Record the clicked pixel and redraw markers for every stored point."""
    x, y = evt.index
    points_state.append((x, y))

    # Round-trip through PIL, then to BGR for OpenCV drawing (matches the
    # original pipeline exactly).
    canvas = cv2.cvtColor(np.array(Image.fromarray(img)), cv2.COLOR_RGB2BGR)
    for px, py in points_state:
        cv2.drawMarker(
            canvas, (px, py), (0, 255, 0),
            markerType=cv2.MARKER_TILTED_CROSS,
            markerSize=20,
            thickness=3,
        )
    return cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB), points_state
def run_unique_segmentation(input_files, points_state, text_prompt, sam_encoder, offload_gpu, cleanup_interval, frame_idx_slider):
    """
    Run Unique Instance segmentation: track one object, identified either by
    a text description (which takes priority) or by clicked points, across
    all extracted frames.
    """
    if not input_files:
        return None, "Error: Process input first."

    # Block until the HF warm-up thread has finished downloading models.
    if t_hf.is_alive():
        logger.info("Waiting for HF models download to finish...")
        t_hf.join()

    try:
        logger.info("Processing inputs on GPU node...")
        workdir = tempfile.mkdtemp()
        # Re-extract frames: GPU workers get ephemeral storage, so frames
        # produced during the earlier CPU step may not exist here.
        frames_dir, _ = process_inputs_to_frames(input_files, workdir)

        logger.info("Initializing UniqueInstanceSegmenter...")
        segmenter = UniqueInstanceSegmenter(
            sam_encoder=sam_encoder,
            memory_cleanup_interval=int(cleanup_interval),
            device="cuda",
        )
        if offload_gpu:
            segmenter.optimize_cuda_memory()

        annotation_frame = f"frame_{int(frame_idx_slider):05d}.jpg"
        if not os.path.exists(os.path.join(frames_dir, annotation_frame)):
            return None, f"Error: Frame {annotation_frame} not found."

        # Text prompt wins over clicked points.
        if text_prompt.strip():
            logger.info(f"Mode: Text -> {text_prompt}")
            result = segmenter.segment(
                frames_path=frames_dir,
                text=text_prompt,
                annotation_frame=annotation_frame,
                offload_frames_to_gpu=offload_gpu,
            )
        else:
            if not points_state:
                return None, "Please add points or text."
            logger.info(f"Mode: Points -> {points_state}")
            result = segmenter.segment(
                frames_path=frames_dir,
                points=points_state,
                annotation_frame=annotation_frame,
                offload_frames_to_gpu=offload_gpu,
            )

        output_vid = os.path.join(OUTPUT_BASE_DIR, "unique_output.mp4")
        return create_video_overlay(frames_dir, result.masks, output_vid), f"Completed. {result.num_frames} frames processed."

    except Exception as e:
        import traceback
        traceback.print_exc()
        logger.error(str(e))
        # Re-raise Gradio errors so they surface in the UI as-is.
        if isinstance(e, gr.Error):
            raise e
        return None, f"Error: {str(e)}"
# ===========================================
# GENERIC CATEGORY SEGMENTATION
# ===========================================
def run_generic_segmentation(input_files, category, accept_thresh, reject_thresh, vlm_model_name):
    """
    Run Generic Category segmentation: detect every instance of a text
    category in each frame using a VLM plus the segmentation ensemble.

    The Ollama server is started from within this call so that, when this
    runs inside a GPU context, Ollama can detect and use the GPU.

    NOTE(review): comments assume a @spaces.GPU context, but no decorator is
    visible on this function in this chunk — confirm @spaces.GPU is applied.
    """
    if not input_files:
        return None, "Error: Upload input."
    if not category.strip():
        return None, "Error: Please specify text."

    # Block until both startup download threads have finished.
    if t_hf.is_alive():
        logger.info("Waiting for HF models download...")
        t_hf.join()
    if t_ollama.is_alive():
        logger.info("Waiting for Ollama models download...")
        t_ollama.join()

    try:
        # Start Ollama here, inside the (intended) GPU context, so the server
        # detects the GPU rather than falling back to CPU.
        logger.info("=" * 50)
        logger.info("Starting Ollama server with GPU support...")
        logger.info("=" * 50)
        ensure_ollama_ready_gpu(vlm_model_name)

        logger.info("Ollama is running with GPU. Processing inputs...")
        workdir = tempfile.mkdtemp()
        frames_dir, _ = process_inputs_to_frames(input_files, workdir)

        logger.info(f"Initializing GenericCategorySegmenter with VLM: {vlm_model_name}")
        segmenter = GenericCategorySegmenter(
            device="cuda",
            vlm_model=vlm_model_name,
        )

        logger.info(f"Detecting category: {category}")
        result = segmenter.segment(
            frames_path=frames_dir,
            category=category,
            accept_threshold=accept_thresh,
            reject_threshold=reject_thresh,
            save_debug=False,
        )

        output_vid = os.path.join(OUTPUT_BASE_DIR, "generic_output.mp4")
        total_detections = sum(len(d) for d in result.metadata['detections'].values())
        return create_video_overlay(frames_dir, result.masks, output_vid), f"Completed! Total detections: {total_detections}"

    except Exception as e:
        import traceback
        traceback.print_exc()
        logger.error(f"Generic segmentation error: {e}")
        # Re-raise Gradio errors so they surface in the UI as-is.
        if isinstance(e, gr.Error):
            raise e
        return None, f"Error: {e}"
# ===========================================
# GRADIO UI
# ===========================================
# NOTE(review): the original file's indentation was lost; the layout nesting
# below (what sits inside each Row/Column) is a reasonable reconstruction —
# confirm against the deployed Space.
with gr.Blocks(title="ENEAS: Embedding-guided Neural Ensemble for Adaptive Segmentation") as demo:
    gr.Markdown(
        f"""
# ⚡ ENEAS: Embedding-guided Neural Ensemble for Adaptive Segmentation
**⚠️ IMPORTANT LIMITS:**
- Maximum **{MAX_FRAMES} FRAMES** to prevent ZeroGPU timeouts
- Videos are sampled at **1 FPS** → Max **{MAX_FRAMES} seconds** of video
- Exceeding these limits will stop execution
"""
    )

    with gr.Tabs():
        # ---- Tab 1: track one specific object across frames ----
        with gr.Tab("🎯 Unique Instance"):
            gr.Markdown("Track a specific object. Upload Video (1 FPS extraction) OR Images.")
            with gr.Row():
                with gr.Column(scale=1):
                    u_file = gr.File(
                        label="Input (Video or Images)",
                        file_count="multiple",
                        file_types=["video", "image"],
                    )
                    u_btn_proc = gr.Button("▶️ 1. Process Input (Extract 1 FPS)", variant="secondary")
                    u_slider = gr.Slider(label="Frame Selector", visible=False)
                    with gr.Accordion("Advanced Options", open=False):
                        u_enc = gr.Dropdown(
                            ["long-large", "long-small"],
                            value="long-large",
                            label="SAM2 Encoder",
                        )
                        u_offload = gr.Checkbox(label="GPU Memory Offload", value=False)
                with gr.Column(scale=2):
                    u_path_frames_cpu = gr.Textbox(visible=False)  # hidden: extracted-frames dir
                    points_state = gr.State([])                    # clicked annotation points
                    u_img = gr.Image(
                        label="Reference Frame (Click to add points)",
                        interactive=True,
                    )
                    u_txt = gr.Textbox(
                        label="Text Description (Grounding)",
                        placeholder="Points are ignored if text is provided.",
                    )
                    u_btn_run = gr.Button("🚀 2. Run Segmentation", variant="primary")
                    u_out = gr.Video(label="Result")
                    u_status = gr.Textbox(label="Status", interactive=False)

            # Event wiring for Tab 1.
            u_btn_proc.click(
                process_unique_upload,
                [u_file],
                [u_img, u_path_frames_cpu, points_state, u_status, u_slider],
            )
            u_slider.change(
                update_canvas_from_slider,
                inputs=[u_slider, u_path_frames_cpu],
                outputs=[u_img, points_state],
            )
            u_img.select(add_point, [u_img, points_state], [u_img, points_state])
            u_btn_run.click(
                run_unique_segmentation,
                [u_file, points_state, u_txt, u_enc, u_offload, gr.Number(10, visible=False), u_slider],
                [u_out, u_status],
            )
            gr.Examples(
                examples=[
                    [["examples/reporter.mp4"], "blonde woman with microphone"]
                ],
                inputs=[u_file, u_txt],
                label="Example",
            )

        # ---- Tab 2: detect every instance of a text category ----
        with gr.Tab("🔮 Generic Text"):
            gr.Markdown(
                f"""
Detect all instances of a text prompt in every frame (Max {MAX_FRAMES} frames).
**🚀 GPU-Accelerated:** Ollama VLM runs on GPU for fast inference.
First request includes ~15-20s server startup time.
"""
            )
            with gr.Row():
                g_file = gr.File(
                    label="Input (Video or Images)",
                    file_count="multiple",
                    file_types=["video", "image"],
                )
                g_cat = gr.Textbox(
                    label="Text prompt",
                    placeholder="e.g., person, chair, car, dog",
                )
            g_btn = gr.Button("🔍 Run Detection", variant="primary")
            with gr.Accordion("Detection Settings", open=True):
                g_accept = gr.Slider(
                    0.0, 1.0,
                    value=0.30,
                    label="Accept Threshold",
                    info="Higher = more confident detections only",
                )
                g_reject = gr.Slider(
                    0.0, 1.0,
                    value=0.1,
                    label="Reject Threshold",
                    info="Lower = filter out more false positives",
                )
                g_vlm = gr.Dropdown(
                    choices=VLM_MODELS,
                    value=VLM_MODELS[0],
                    label="VLM Model",
                    info="Larger models are more accurate but slower",
                )
            g_out = gr.Video(label="Result")
            g_stat = gr.Textbox(label="Detection Statistics", interactive=False)

            # Event wiring for Tab 2.
            g_btn.click(
                run_generic_segmentation,
                [g_file, g_cat, g_accept, g_reject, g_vlm],
                [g_out, g_stat],
            )
            gr.Examples(
                examples=[
                    [["examples/moving.mp4"], "person", 0.3, 0.1, "qwen3-vl:4b-instruct-q8_0"]
                ],
                inputs=[g_file, g_cat, g_accept, g_reject, g_vlm],
                label="Example",
            )
# ===========================================
# MAIN ENTRY POINT
# ===========================================
if __name__ == "__main__":
    demo.launch()