import gradio as gr
import os
import cv2
import shutil
import tempfile
import numpy as np
import subprocess
import time
import threading
import torch
import sys
import json
import logging
from PIL import Image

# ===========================================
# LOGGING CONFIGURATION
# ===========================================
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

# Ensure Python sees the local 'eneas' folder
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

import spaces

try:
    from eneas.segmentation import UniqueInstanceSegmenter, GenericCategorySegmenter
    from eneas.segmentation.model_manager import ModelManager
except ImportError as e:
    logger.error(f"Error importing ENEAS: {e}")
    raise e

# ===========================================
# CONSTANTS
# ===========================================
MAX_FRAMES = 150  # Limit frames to avoid ZeroGPU Timeout (~1s/frame processing)

OLLAMA_HOST = "127.0.0.1:11434"
OLLAMA_URL = f"http://{OLLAMA_HOST}"
OLLAMA_BIN = "./bin/ollama"

VLM_MODELS = [
    "qwen3-vl:4b-instruct-q8_0",
    "qwen3-vl:2b-instruct-q8_0"
]

OUTPUT_BASE_DIR = "gradio_outputs"
os.makedirs(OUTPUT_BASE_DIR, exist_ok=True)


# ===========================================
# OLLAMA FUNCTIONS (FOR USE INSIDE @spaces.GPU)
# ===========================================
def get_ollama_env():
    """Get environment variables for Ollama process with GPU support.

    Returns:
        dict: Copy of the current environment augmented with the Ollama
        host/origins settings and the local ./lib path on LD_LIBRARY_PATH
        (needed by the locally extracted binary).
    """
    env = os.environ.copy()
    env["OLLAMA_HOST"] = OLLAMA_HOST
    env["OLLAMA_ORIGINS"] = "*"
    env["HOME"] = os.getcwd()

    # Add local lib path for the extracted binary
    cwd = os.getcwd()
    lib_path = f"{cwd}/lib"
    if "LD_LIBRARY_PATH" in env:
        env["LD_LIBRARY_PATH"] += f":{lib_path}"
    else:
        env["LD_LIBRARY_PATH"] = lib_path

    return env


def is_ollama_server_running() -> bool:
    """Check if the Ollama server is responding with HTTP 200."""
    try:
        result = subprocess.run(
            ["curl", "-s", "-o", "/dev/null", "-w", "%{http_code}", OLLAMA_URL],
            capture_output=True,
            text=True,
            timeout=5
        )
        return result.stdout.strip() == "200"
    except Exception:
        # Any failure (curl missing, timeout, connection refused) means "not running".
        return False


def start_ollama_server_gpu():
    """
    Start Ollama server INSIDE @spaces.GPU context.
    This ensures Ollama detects and uses the GPU.

    Returns:
        bool: True if server started successfully
    """
    if is_ollama_server_running():
        logger.info("Ollama server is already running.")
        return True

    logger.info("Starting Ollama server inside GPU context...")
    try:
        env = get_ollama_env()

        # Start server as background process
        process = subprocess.Popen(
            [OLLAMA_BIN, "serve"],
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE
        )

        # Wait for server to be ready (max 30 seconds)
        max_retries = 30
        for i in range(max_retries):
            if is_ollama_server_running():
                logger.info(f"Ollama server started successfully in {i+1} seconds.")
                return True
            time.sleep(1)

        logger.error("Ollama server failed to start within 30 seconds.")
        return False
    except Exception as e:
        logger.error(f"Failed to start Ollama server: {e}")
        return False


def load_model_into_vram(model_name: str) -> bool:
    """
    Load model into VRAM for faster inference.
    Uses keep_alive=-1 to keep model loaded.

    Args:
        model_name: Name of the Ollama model to load

    Returns:
        bool: True if model loaded successfully
    """
    logger.info(f"Loading model {model_name} into VRAM...")
    try:
        # Send a minimal request to trigger model loading.
        # Build the payload with json.dumps so the model name is always
        # quoted/escaped correctly (f-string interpolation is fragile here).
        load_payload = json.dumps({"model": model_name, "prompt": "hi", "stream": False})
        result = subprocess.run(
            ["curl", "-s", f"{OLLAMA_URL}/api/generate", "-d", load_payload],
            capture_output=True,
            text=True,
            timeout=120  # Model loading can take time
        )
        if "error" in result.stdout.lower():
            logger.error(f"Error loading model: {result.stdout}")
            return False

        # Set keep_alive to -1 to keep model in VRAM indefinitely
        keep_alive_payload = json.dumps({"model": model_name, "keep_alive": -1})
        subprocess.run(
            ["curl", "-s", f"{OLLAMA_URL}/api/generate", "-d", keep_alive_payload],
            capture_output=True,
            timeout=10
        )

        logger.info(f"Model {model_name} loaded into VRAM successfully.")
        return True
    except subprocess.TimeoutExpired:
        logger.error("Timeout while loading model into VRAM.")
        return False
    except Exception as e:
        logger.error(f"Error loading model into VRAM: {e}")
        return False


def log_active_models():
    """Log which models are currently loaded in VRAM (not just on disk)."""
    try:
        result = subprocess.run(
            ["curl", "-s", f"{OLLAMA_URL}/api/ps"],
            capture_output=True,
            text=True,
            timeout=5
        )
        logger.info(f"Active models in VRAM: {result.stdout}")
    except Exception as e:
        # Best-effort diagnostic only; never fail the pipeline over it.
        logger.warning(f"Could not get active models: {e}")


def ensure_ollama_ready_gpu(model_name: str) -> bool:
    """
    Main function to ensure Ollama is fully ready with GPU support.
    MUST be called inside @spaces.GPU decorated function.

    This function:
    1. Starts Ollama server (which will detect GPU)
    2. Loads the specified model into VRAM
    3. Logs which model is active

    Args:
        model_name: Name of the Ollama model to use

    Returns:
        bool: True if ready

    Raises:
        RuntimeError: If setup fails
    """
    logger.info(f"Ensuring Ollama is ready with GPU for model: {model_name}")

    # Step 1: Start server (will detect GPU since we're inside @spaces.GPU)
    if not start_ollama_server_gpu():
        raise RuntimeError("Failed to start Ollama server with GPU")

    # Step 2: Load model into VRAM
    if not load_model_into_vram(model_name):
        raise RuntimeError(f"Failed to load model {model_name} into VRAM")

    # Step 3: Log which model is actually active in VRAM
    log_active_models()

    logger.info("Ollama is ready with GPU support!")
    return True


# ===========================================
# STARTUP: DOWNLOAD BINARY AND MODELS (CPU)
# ===========================================
def download_ollama_binary():
    """Download Ollama binary if not present.

    Returns:
        bool: True if the binary exists or was downloaded successfully.
    """
    if os.path.exists(OLLAMA_BIN):
        logger.info("Ollama binary already exists.")
        return True

    logger.info("Downloading Ollama binary (ZST)...")
    try:
        subprocess.run(
            ["curl", "-L", "https://ollama.com/download/ollama-linux-amd64.tar.zst",
             "-o", "ollama.tar.zst"],
            check=True,
            timeout=300
        )
        subprocess.run(["tar", "--zstd", "-xf", "ollama.tar.zst"], check=True)
        subprocess.run(["chmod", "+x", OLLAMA_BIN], check=True)
        os.remove("ollama.tar.zst")  # Cleanup
        logger.info("Ollama binary downloaded and extracted successfully.")
        return True
    except Exception as e:
        logger.error(f"Failed to download Ollama binary: {e}")
        return False


def pull_ollama_models():
    """
    Pull Ollama models at startup (runs on CPU).
    This pre-downloads the models so they're ready when GPU is available.
    """
    logger.info("Pre-downloading Ollama models...")

    # Need to temporarily start server to pull models
    env = get_ollama_env()

    # Start server temporarily
    server_process = subprocess.Popen(
        [OLLAMA_BIN, "serve"],
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE
    )

    # Wait for server
    time.sleep(5)
    for _ in range(20):
        if is_ollama_server_running():
            break
        time.sleep(1)

    # Pull each model
    for model in VLM_MODELS:
        logger.info(f"Pulling model: {model}")
        try:
            proc = subprocess.run(
                [OLLAMA_BIN, "pull", model],
                env=env,
                timeout=600,
                capture_output=True
            )
            # subprocess.run without check=True never raises on failure,
            # so we must inspect the return code before claiming success.
            if proc.returncode == 0:
                logger.info(f"Model {model} pulled successfully.")
            else:
                logger.warning(
                    f"Failed to pull model {model}: {proc.stderr.decode(errors='replace')}"
                )
        except Exception as e:
            logger.warning(f"Failed to pull model {model}: {e}")

    # Stop server (we'll restart it inside GPU context later)
    server_process.terminate()
    try:
        server_process.wait(timeout=5)
    except subprocess.TimeoutExpired:
        server_process.kill()

    logger.info("Ollama models pre-download complete.")


def setup_ollama_startup():
    """Setup Ollama at startup: download binary and pull models."""
    download_ollama_binary()
    pull_ollama_models()


def setup_hf_models():
    """
    Downloads heavy HuggingFace models to disk at startup.
    This prevents ZeroGPU timeouts during the first inference.
    """
    logger.info("Starting HuggingFace models download (Warm-up)...")
    try:
        manager = ModelManager()

        # 1. SeC-4B (Heavy, ~15GB)
        logger.info("Downloading SeC-4B...")
        manager.download("OpenIXCLab/SeC-4B")

        # 2. Florence-2 (Grounding)
        logger.info("Downloading Florence-2...")
        manager.download("microsoft/Florence-2-large")

        # 3. SigLIP (For Generic Category)
        logger.info("Downloading SigLIP...")
        manager.download("google/siglip2-base-patch16-naflex")

        # 4. SAM2 Checkpoint (Direct URL)
        logger.info("Downloading SAM2 checkpoint...")
        manager.download_url(
            "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
            "sam2.1_hiera_large.pt"
        )

        logger.info("All HuggingFace models downloaded successfully.")
    except Exception as e:
        logger.error(f"Error during HF model download: {e}")


# ===========================================
# STARTUP: PARALLEL MODEL DOWNLOADS
# ===========================================
logger.info("Starting parallel model downloads at startup...")
t_hf = threading.Thread(target=setup_hf_models, daemon=True)
t_ollama = threading.Thread(target=setup_ollama_startup, daemon=True)
t_hf.start()
t_ollama.start()

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Main process device detection: {DEVICE}")


# ===========================================
# UTILITY FUNCTIONS
# ===========================================
def process_inputs_to_frames(input_data, output_folder: str) -> tuple:
    """
    Extracts frames from video (1 FPS) or copies images to output folder.
    Enforces MAX_FRAMES limit to prevent ZeroGPU timeouts.

    Args:
        input_data: Video file or list of image files
        output_folder: Directory to save extracted frames

    Returns:
        tuple: (output_folder path, list of frame file paths)

    Raises:
        gr.Error: If the video is too long or too many images were uploaded.
    """
    # Start from a clean output folder every time.
    if os.path.exists(output_folder):
        shutil.rmtree(output_folder)
    os.makedirs(output_folder)

    frame_paths = []
    video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.webm'}
    input_list = input_data if isinstance(input_data, list) else [input_data]
    if not input_list:
        return output_folder, []

    # The type of the first file decides the branch: video vs image batch.
    # NOTE(review): if a video is uploaded, any additional files are ignored.
    first_file = input_list[0].name if hasattr(input_list[0], 'name') else str(input_list[0])
    ext = os.path.splitext(first_file)[1].lower()

    if ext in video_extensions:
        # Process video file
        logger.info(f"Processing video: {first_file}...")
        cap = cv2.VideoCapture(first_file)
        video_fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames_original = cap.get(cv2.CAP_PROP_FRAME_COUNT)

        # Some containers report 0/NaN FPS; assume a sane default.
        if video_fps == 0 or np.isnan(video_fps):
            video_fps = 30
        duration_sec = total_frames_original / video_fps

        # Validate video duration (1 FPS sampling => duration == frame count)
        if duration_sec > MAX_FRAMES:
            cap.release()
            msg = (f"Video is too long ({int(duration_sec)}s). "
                   f"Max allowed is {MAX_FRAMES}s to avoid ZeroGPU timeout.")
            logger.error(msg)
            raise gr.Error(msg)

        # Sample at 1 FPS
        frame_interval = max(1, int(video_fps))
        count = 0
        saved_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if count % frame_interval == 0:
                # FPS rounding can yield one extra sample even after the
                # duration check passed; cap at MAX_FRAMES instead of
                # failing a job that was already validated.
                if saved_count >= MAX_FRAMES:
                    logger.warning(
                        f"Frame cap of {MAX_FRAMES} reached; skipping remaining frames."
                    )
                    break
                filename = f"frame_{saved_count:05d}.jpg"
                filepath = os.path.join(output_folder, filename)
                cv2.imwrite(filepath, frame)
                frame_paths.append(filepath)
                saved_count += 1
            count += 1
        cap.release()
        logger.info(f"Video sampled at 1 FPS. Total frames: {saved_count}")
    else:
        # Process image files
        if len(input_list) > MAX_FRAMES:
            raise gr.Error(
                f"Too many images! You uploaded {len(input_list)}. "
                f"Max allowed is {MAX_FRAMES}."
            )
        logger.info(f"Processing {len(input_list)} images...")
        # Sort to keep a deterministic frame order across uploads.
        input_list.sort(key=lambda x: x.name if hasattr(x, 'name') else str(x))
        for i, f in enumerate(input_list):
            path = f.name if hasattr(f, 'name') else str(f)
            try:
                img = Image.open(path).convert("RGB")
                filename = f"frame_{i:05d}.jpg"
                filepath = os.path.join(output_folder, filename)
                img.save(filepath)
                frame_paths.append(filepath)
            except Exception as e:
                # Skip unreadable files rather than failing the batch.
                logger.warning(f"Skipping file {path}: {e}")

    return output_folder, frame_paths


def create_video_overlay(frames_folder: str, masks_dict: dict, output_path: str,
                         fps: int = 5) -> str:
    """
    Creates a video from frames with segmentation masks overlaid.

    Args:
        frames_folder: Directory containing frame images
        masks_dict: Dictionary mapping frame index to mask arrays
        output_path: Output video file path
        fps: Frames per second for output video

    Returns:
        Output video path or None if failed
    """
    logger.info("Generating result video...")
    frame_files = sorted([f for f in os.listdir(frames_folder) if f.endswith(".jpg")])
    if not frame_files:
        return None

    first_frame = cv2.imread(os.path.join(frames_folder, frame_files[0]))
    if first_frame is None:
        # cv2.imread returns None on unreadable files; bail out instead of
        # crashing on first_frame.shape.
        logger.error(f"Could not read first frame in {frames_folder}.")
        return None
    height, width, _ = first_frame.shape

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    # Orange/Gold color for mask overlay
    mask_color = np.array([255, 100, 0], dtype=np.uint8)

    for i, filename in enumerate(frame_files):
        frame = cv2.imread(os.path.join(frames_folder, filename))
        mask_overlay = np.zeros_like(frame)

        if i in masks_dict:
            masks_data = masks_dict[i]
            # Accept either a single mask array or a list of masks per frame.
            masks_list = [masks_data] if isinstance(masks_data, np.ndarray) else (
                masks_data if isinstance(masks_data, list) else []
            )
            for mask in masks_list:
                mask_overlay[mask > 0] = mask_color

        if np.any(mask_overlay):
            frame = cv2.addWeighted(frame, 1, mask_overlay, 0.5, 0)
        out.write(frame)

    out.release()
    return output_path


# ===========================================
# UNIQUE INSTANCE SEGMENTATION
# ===========================================
def process_unique_upload(input_files):
    """
    Process uploaded files for Unique Instance segmentation.
    Extracts frames and prepares the UI for annotation.

    Returns:
        tuple: (first frame image path, frames dir, reset points state,
        status message, updated frame-selector slider)
    """
    if not input_files:
        return None, None, [], "Please upload files first.", gr.Slider(
            value=0, maximum=0, visible=False)

    temp_dir = tempfile.mkdtemp()
    frames_dir, frame_paths = process_inputs_to_frames(input_files, temp_dir)
    num_frames = len(frame_paths)
    if num_frames == 0:
        return None, None, [], "No frames extracted.", gr.Slider(
            value=0, maximum=0, visible=False)

    new_slider = gr.Slider(
        value=0,
        minimum=0,
        maximum=num_frames - 1,
        step=1,
        visible=True,
        interactive=True,
        label=f"Select Reference Frame (0 - {num_frames - 1})"
    )
    return (frame_paths[0], frames_dir, [],
            f"Processed {num_frames} frames (1 FPS). Select target.", new_slider)


def update_canvas_from_slider(frame_idx, frames_dir):
    """Update the displayed frame when slider changes.

    Also resets the point annotations (returns an empty points state).
    """
    if not frames_dir or not os.path.exists(frames_dir):
        return None, []
    filename = f"frame_{int(frame_idx):05d}.jpg"
    path = os.path.join(frames_dir, filename)
    if os.path.exists(path):
        img = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
        return img, []
    return None, []


def add_point(img, evt: gr.SelectData, points_state):
    """Add a point annotation to the image and redraw all markers."""
    if img is None:
        # Click before any frame is loaded; nothing to annotate.
        return img, points_state
    x, y = evt.index
    points_state.append((x, y))

    img_pil = Image.fromarray(img)
    img_cv = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)

    # Draw markers for all points
    for px, py in points_state:
        cv2.drawMarker(
            img_cv, (px, py), (0, 255, 0),
            markerType=cv2.MARKER_TILTED_CROSS, markerSize=20, thickness=3
        )
    return cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB), points_state


@spaces.GPU(duration=180)
def run_unique_segmentation(input_files, points_state, text_prompt, sam_encoder,
                            offload_gpu, cleanup_interval, frame_idx_slider):
    """
    Run Unique Instance segmentation on the uploaded frames.
    Tracks a specific object identified by points or text description.

    Returns:
        tuple: (result video path or None, status message)
    """
    if not input_files:
        return None, "Error: Process input first."

    # Wait for HF models to be downloaded
    if t_hf.is_alive():
        logger.info("Waiting for HF models download to finish...")
        t_hf.join()

    try:
        logger.info("Processing inputs on GPU node...")
        temp_dir = tempfile.mkdtemp()
        # Re-extract frames to ensure they exist on GPU ephemeral storage
        frames_dir, _ = process_inputs_to_frames(input_files, temp_dir)

        logger.info("Initializing UniqueInstanceSegmenter...")
        segmenter = UniqueInstanceSegmenter(
            sam_encoder=sam_encoder,
            memory_cleanup_interval=int(cleanup_interval),
            device="cuda"
        )
        if offload_gpu:
            segmenter.optimize_cuda_memory()

        annotation_frame = f"frame_{int(frame_idx_slider):05d}.jpg"
        if not os.path.exists(os.path.join(frames_dir, annotation_frame)):
            return None, f"Error: Frame {annotation_frame} not found."

        # Run segmentation based on input type (text takes precedence)
        if text_prompt.strip():
            logger.info(f"Mode: Text -> {text_prompt}")
            result = segmenter.segment(
                frames_path=frames_dir,
                text=text_prompt,
                annotation_frame=annotation_frame,
                offload_frames_to_gpu=offload_gpu
            )
        else:
            if not points_state:
                return None, "Please add points or text."
            logger.info(f"Mode: Points -> {points_state}")
            result = segmenter.segment(
                frames_path=frames_dir,
                points=points_state,
                annotation_frame=annotation_frame,
                offload_frames_to_gpu=offload_gpu
            )

        output_vid = os.path.join(OUTPUT_BASE_DIR, "unique_output.mp4")
        return (create_video_overlay(frames_dir, result.masks, output_vid),
                f"Completed. {result.num_frames} frames processed.")
    except Exception as e:
        import traceback
        traceback.print_exc()
        logger.error(str(e))
        # Let Gradio surface its own user-facing errors unchanged.
        if isinstance(e, gr.Error):
            raise e
        return None, f"Error: {str(e)}"


# ===========================================
# GENERIC CATEGORY SEGMENTATION
# ===========================================
@spaces.GPU(duration=180)
def run_generic_segmentation(input_files, category, accept_thresh, reject_thresh,
                             vlm_model_name):
    """
    Run Generic Category segmentation on the uploaded frames.
    Detects all instances of a specified category using VLM + segmentation.

    IMPORTANT: This function starts Ollama server INSIDE the GPU context,
    ensuring that Ollama can detect and use the GPU for inference.

    Returns:
        tuple: (result video path or None, status message)
    """
    if not input_files:
        return None, "Error: Upload input."
    if not category.strip():
        return None, "Error: Please specify text."

    # Wait for model downloads to complete
    if t_hf.is_alive():
        logger.info("Waiting for HF models download...")
        t_hf.join()
    if t_ollama.is_alive():
        logger.info("Waiting for Ollama models download...")
        t_ollama.join()

    try:
        # =========================================================
        # CRITICAL: Start Ollama INSIDE @spaces.GPU context
        # This ensures Ollama detects and uses the GPU!
        # =========================================================
        logger.info("=" * 50)
        logger.info("Starting Ollama server with GPU support...")
        logger.info("=" * 50)
        ensure_ollama_ready_gpu(vlm_model_name)

        logger.info("Ollama is running with GPU. Processing inputs...")

        # Process input frames
        temp_dir = tempfile.mkdtemp()
        frames_dir, _ = process_inputs_to_frames(input_files, temp_dir)

        logger.info(f"Initializing GenericCategorySegmenter with VLM: {vlm_model_name}")
        segmenter = GenericCategorySegmenter(
            device="cuda",
            vlm_model=vlm_model_name
        )

        logger.info(f"Detecting category: {category}")
        result = segmenter.segment(
            frames_path=frames_dir,
            category=category,
            accept_threshold=accept_thresh,
            reject_threshold=reject_thresh,
            save_debug=False
        )

        output_vid = os.path.join(OUTPUT_BASE_DIR, "generic_output.mp4")
        total_detections = sum(len(d) for d in result.metadata['detections'].values())
        return (create_video_overlay(frames_dir, result.masks, output_vid),
                f"Completed! Total detections: {total_detections}")
    except Exception as e:
        import traceback
        traceback.print_exc()
        logger.error(f"Generic segmentation error: {e}")
        if isinstance(e, gr.Error):
            raise e
        return None, f"Error: {e}"


# ===========================================
# GRADIO UI
# ===========================================
with gr.Blocks(title="ENEAS: Embedding-guided Neural Ensemble for Adaptive Segmentation") as demo:
    gr.Markdown(
        f"""
        # ⚡ ENEAS: Embedding-guided Neural Ensemble for Adaptive Segmentation

        **⚠️ IMPORTANT LIMITS:**
        - Maximum **{MAX_FRAMES} FRAMES** to prevent ZeroGPU timeouts
        - Videos are sampled at **1 FPS** → Max **{MAX_FRAMES} seconds** of video
        - Exceeding these limits will stop execution
        """
    )

    with gr.Tabs():
        # ===========================================
        # TAB 1: UNIQUE INSTANCE SEGMENTATION
        # ===========================================
        with gr.Tab("🎯 Unique Instance"):
            gr.Markdown("Track a specific object. Upload Video (1 FPS extraction) OR Images.")
            with gr.Row():
                with gr.Column(scale=1):
                    u_file = gr.File(
                        label="Input (Video or Images)",
                        file_count="multiple",
                        file_types=["video", "image"]
                    )
                    u_btn_proc = gr.Button("▶️ 1. Process Input (Extract 1 FPS)",
                                           variant="secondary")
                    u_slider = gr.Slider(label="Frame Selector", visible=False)
                    with gr.Accordion("Advanced Options", open=False):
                        u_enc = gr.Dropdown(
                            ["long-large", "long-small"],
                            value="long-large",
                            label="SAM2 Encoder"
                        )
                        u_offload = gr.Checkbox(label="GPU Memory Offload", value=False)
                with gr.Column(scale=2):
                    # Hidden textbox carrying the frames directory between steps.
                    u_path_frames_cpu = gr.Textbox(visible=False)
                    points_state = gr.State([])
                    u_img = gr.Image(
                        label="Reference Frame (Click to add points)",
                        interactive=True
                    )
                    u_txt = gr.Textbox(
                        label="Text Description (Grounding)",
                        placeholder="Points are ignored if text is provided."
                    )
                    u_btn_run = gr.Button("🚀 2. Run Segmentation", variant="primary")
                    u_out = gr.Video(label="Result")
                    u_status = gr.Textbox(label="Status", interactive=False)

            # Event handlers
            u_btn_proc.click(
                process_unique_upload,
                [u_file],
                [u_img, u_path_frames_cpu, points_state, u_status, u_slider]
            )
            u_slider.change(
                update_canvas_from_slider,
                inputs=[u_slider, u_path_frames_cpu],
                outputs=[u_img, points_state]
            )
            u_img.select(add_point, [u_img, points_state], [u_img, points_state])
            u_btn_run.click(
                run_unique_segmentation,
                [u_file, points_state, u_txt, u_enc, u_offload,
                 gr.Number(10, visible=False), u_slider],
                [u_out, u_status]
            )

            # Example for Unique Instance
            gr.Examples(
                examples=[
                    [["examples/reporter.mp4"], "blonde woman with microphone"]
                ],
                inputs=[u_file, u_txt],
                label="Example"
            )

        # ===========================================
        # TAB 2: GENERIC CATEGORY SEGMENTATION
        # ===========================================
        with gr.Tab("🔮 Generic Text"):
            gr.Markdown(
                f"""
                Detect all instances of a text prompt in every frame (Max {MAX_FRAMES} frames).

                **🚀 GPU-Accelerated:** Ollama VLM runs on GPU for fast inference.
                First request includes ~15-20s server startup time.
                """
            )
            with gr.Row():
                g_file = gr.File(
                    label="Input (Video or Images)",
                    file_count="multiple",
                    file_types=["video", "image"]
                )
                g_cat = gr.Textbox(
                    label="Text prompt",
                    placeholder="e.g., person, chair, car, dog"
                )
            g_btn = gr.Button("🔍 Run Detection", variant="primary")
            with gr.Accordion("Detection Settings", open=True):
                g_accept = gr.Slider(
                    0.0, 1.0, value=0.30,
                    label="Accept Threshold",
                    info="Higher = more confident detections only"
                )
                g_reject = gr.Slider(
                    0.0, 1.0, value=0.1,
                    label="Reject Threshold",
                    info="Lower = filter out more false positives"
                )
                g_vlm = gr.Dropdown(
                    choices=VLM_MODELS,
                    value=VLM_MODELS[0],
                    label="VLM Model",
                    info="Larger models are more accurate but slower"
                )
            g_out = gr.Video(label="Result")
            g_stat = gr.Textbox(label="Detection Statistics", interactive=False)

            g_btn.click(
                run_generic_segmentation,
                [g_file, g_cat, g_accept, g_reject, g_vlm],
                [g_out, g_stat]
            )

            # Example for Generic Category
            gr.Examples(
                examples=[
                    [["examples/moving.mp4"], "person", 0.3, 0.1,
                     "qwen3-vl:4b-instruct-q8_0"]
                ],
                inputs=[g_file, g_cat, g_accept, g_reject, g_vlm],
                label="Example"
            )


# ===========================================
# MAIN ENTRY POINT
# ===========================================
if __name__ == "__main__":
    demo.launch()