# eneas/app.py — Gradio Space entry point (commit dcb39c7 "fix", by javipd99)
import json
import logging
import os
import shutil
import subprocess
import sys
import tempfile
import threading
import time

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
# ===========================================
# LOGGING CONFIGURATION
# ===========================================
# Log to stdout so messages appear in the Space's container logs.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)
# Ensure Python sees the local 'eneas' folder
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import spaces
try:
from eneas.segmentation import UniqueInstanceSegmenter, GenericCategorySegmenter
from eneas.segmentation.model_manager import ModelManager
except ImportError as e:
logger.error(f"Error importing ENEAS: {e}")
raise e
# ===========================================
# CONSTANTS
# ===========================================
MAX_FRAMES = 150  # Limit frames to avoid ZeroGPU Timeout (~1s/frame processing)
OLLAMA_HOST = "127.0.0.1:11434"  # Bind address for the local Ollama server
OLLAMA_URL = f"http://{OLLAMA_HOST}"  # Base URL used by the curl helpers below
OLLAMA_BIN = "./bin/ollama"  # Path of the extracted Ollama binary
# VLM models offered in the Generic Category tab (pre-pulled at startup).
VLM_MODELS = [
    "qwen3-vl:4b-instruct-q8_0",
    "qwen3-vl:2b-instruct-q8_0"
]
OUTPUT_BASE_DIR = "gradio_outputs"  # Where result videos are written
os.makedirs(OUTPUT_BASE_DIR, exist_ok=True)
# ===========================================
# OLLAMA FUNCTIONS (FOR USE INSIDE @spaces.GPU)
# ===========================================
def get_ollama_env():
    """Build the environment for launching the Ollama binary.

    Starts from a copy of the current environment and overrides what the
    bundled Ollama needs: bind address, permissive CORS origins, HOME
    (so Ollama stores its data under the working dir), and an
    LD_LIBRARY_PATH entry pointing at the extracted `lib` directory.
    """
    env = dict(os.environ)
    env["OLLAMA_HOST"] = OLLAMA_HOST
    env["OLLAMA_ORIGINS"] = "*"
    env["HOME"] = os.getcwd()
    # The extracted binary ships its shared libraries in ./lib.
    lib_path = f"{os.getcwd()}/lib"
    if "LD_LIBRARY_PATH" in env:
        env["LD_LIBRARY_PATH"] = env["LD_LIBRARY_PATH"] + f":{lib_path}"
    else:
        env["LD_LIBRARY_PATH"] = lib_path
    return env
def is_ollama_server_running() -> bool:
    """Return True when the local Ollama HTTP endpoint answers with 200."""
    probe_cmd = ["curl", "-s", "-o", "/dev/null", "-w", "%{http_code}", OLLAMA_URL]
    try:
        probe = subprocess.run(
            probe_cmd,
            capture_output=True,
            text=True,
            timeout=5
        )
    except Exception:
        # curl missing, timeout, etc. — treat all failures as "not running".
        return False
    return probe.stdout.strip() == "200"
def start_ollama_server_gpu():
    """
    Start Ollama server INSIDE @spaces.GPU context.
    This ensures Ollama detects and uses the GPU.

    Returns:
        bool: True if server started successfully
    """
    if is_ollama_server_running():
        logger.info("Ollama server is already running.")
        return True
    logger.info("Starting Ollama server inside GPU context...")
    try:
        env = get_ollama_env()
        # Start server as background process.
        # DEVNULL instead of PIPE: nothing ever drains the pipes, so a
        # chatty server would eventually fill the pipe buffer and block.
        process = subprocess.Popen(
            [OLLAMA_BIN, "serve"],
            env=env,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL
        )
        # Wait for server to be ready (max 30 seconds)
        max_retries = 30
        for i in range(max_retries):
            if is_ollama_server_running():
                logger.info(f"Ollama server started successfully in {i+1} seconds.")
                return True
            time.sleep(1)
        # Don't leave a half-started server process behind on failure.
        process.terminate()
        logger.error("Ollama server failed to start within 30 seconds.")
        return False
    except Exception as e:
        logger.error(f"Failed to start Ollama server: {e}")
        return False
def load_model_into_vram(model_name: str) -> bool:
    """
    Load model into VRAM for faster inference.
    Uses keep_alive=-1 to keep model loaded.

    Args:
        model_name: Name of the Ollama model to load

    Returns:
        bool: True if model loaded successfully
    """
    logger.info(f"Loading model {model_name} into VRAM...")
    # Serialize the payloads with json.dumps instead of f-string
    # interpolation so a model name containing quotes or backslashes
    # cannot produce malformed JSON.
    warmup_payload = json.dumps({"model": model_name, "prompt": "hi", "stream": False})
    pin_payload = json.dumps({"model": model_name, "keep_alive": -1})
    try:
        # Send a minimal request to trigger model loading
        result = subprocess.run(
            ["curl", "-s", f"{OLLAMA_URL}/api/generate", "-d", warmup_payload],
            capture_output=True,
            text=True,
            timeout=120  # Model loading can take time
        )
        if "error" in result.stdout.lower():
            logger.error(f"Error loading model: {result.stdout}")
            return False
        # Set keep_alive to -1 to keep model in VRAM
        subprocess.run(
            ["curl", "-s", f"{OLLAMA_URL}/api/generate", "-d", pin_payload],
            capture_output=True,
            timeout=10
        )
        logger.info(f"Model {model_name} loaded into VRAM successfully.")
        return True
    except subprocess.TimeoutExpired:
        logger.error("Timeout while loading model into VRAM.")
        return False
    except Exception as e:
        logger.error(f"Error loading model into VRAM: {e}")
        return False
def log_active_models():
    """Query /api/ps and log which models are resident in VRAM (not just on disk)."""
    try:
        ps = subprocess.run(
            ["curl", "-s", f"{OLLAMA_URL}/api/ps"],
            capture_output=True,
            text=True,
            timeout=5
        )
    except Exception as e:
        logger.warning(f"Could not get active models: {e}")
    else:
        logger.info(f"Active models in VRAM: {ps.stdout}")
def ensure_ollama_ready_gpu(model_name: str) -> bool:
    """
    Bring Ollama fully online with GPU support for the given model.

    MUST be called inside a @spaces.GPU decorated function so the server
    process it spawns inherits GPU visibility.

    Steps:
        1. Start the Ollama server (detects the GPU in this context).
        2. Load the requested model into VRAM.
        3. Log which models are actually resident in VRAM.

    Args:
        model_name: Name of the Ollama model to use

    Returns:
        bool: True if ready

    Raises:
        RuntimeError: If the server or the model fails to come up.
    """
    logger.info(f"Ensuring Ollama is ready with GPU for model: {model_name}")
    if not start_ollama_server_gpu():
        raise RuntimeError("Failed to start Ollama server with GPU")
    if not load_model_into_vram(model_name):
        raise RuntimeError(f"Failed to load model {model_name} into VRAM")
    log_active_models()
    logger.info("Ollama is ready with GPU support!")
    return True
# ===========================================
# STARTUP: DOWNLOAD BINARY AND MODELS (CPU)
# ===========================================
def download_ollama_binary():
    """Fetch and unpack the Ollama binary unless it is already present.

    Returns:
        bool: True on success (or if already downloaded), False otherwise.
    """
    if os.path.exists(OLLAMA_BIN):
        logger.info("Ollama binary already exists.")
        return True
    logger.info("Downloading Ollama binary (ZST)...")
    archive = "ollama.tar.zst"
    try:
        subprocess.run(
            ["curl", "-L", "https://ollama.com/download/ollama-linux-amd64.tar.zst", "-o", archive],
            check=True,
            timeout=300
        )
        subprocess.run(["tar", "--zstd", "-xf", archive], check=True)
        subprocess.run(["chmod", "+x", OLLAMA_BIN], check=True)
        os.remove(archive)  # Cleanup
        logger.info("Ollama binary downloaded and extracted successfully.")
        return True
    except Exception as e:
        logger.error(f"Failed to download Ollama binary: {e}")
        return False
def pull_ollama_models():
    """
    Pull Ollama models at startup (runs on CPU).
    This pre-downloads the models so they're ready when GPU is available.
    """
    logger.info("Pre-downloading Ollama models...")
    # Need to temporarily start server to pull models
    env = get_ollama_env()
    # DEVNULL instead of PIPE: the pipes are never drained here, and a full
    # pipe buffer could block the server process.
    server_process = subprocess.Popen(
        [OLLAMA_BIN, "serve"],
        env=env,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL
    )
    try:
        # Wait for server
        time.sleep(5)
        for _ in range(20):
            if is_ollama_server_running():
                break
            time.sleep(1)
        else:
            # Keep the original best-effort behavior: still attempt pulls.
            logger.warning("Temporary Ollama server not confirmed ready; attempting pulls anyway.")
        # Pull each model
        for model in VLM_MODELS:
            logger.info(f"Pulling model: {model}")
            try:
                subprocess.run(
                    [OLLAMA_BIN, "pull", model],
                    env=env,
                    timeout=600,
                    capture_output=True
                )
                logger.info(f"Model {model} pulled successfully.")
            except Exception as e:
                logger.warning(f"Failed to pull model {model}: {e}")
    finally:
        # Always stop the temporary server, even if something above raised
        # (we'll restart it inside the GPU context later).
        server_process.terminate()
        try:
            server_process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            server_process.kill()
    logger.info("Ollama models pre-download complete.")
def setup_ollama_startup():
    """Setup Ollama at startup: download binary, then pull models.

    Skips the model pull when the binary download failed, since pulling
    runs the binary itself and would only produce cascading errors.
    """
    if download_ollama_binary():
        pull_ollama_models()
    else:
        logger.error("Skipping Ollama model pull: binary is unavailable.")
def setup_hf_models():
    """
    Downloads heavy HuggingFace models to disk at startup.
    This prevents ZeroGPU timeouts during the first inference.
    """
    logger.info("Starting HuggingFace models download (Warm-up)...")
    try:
        manager = ModelManager()
        # Hub repos, in download order: segmentation backbone, grounding
        # model, and the embedding model for Generic Category mode.
        hub_repos = [
            ("SeC-4B", "OpenIXCLab/SeC-4B"),
            ("Florence-2", "microsoft/Florence-2-large"),
            ("SigLIP", "google/siglip2-base-patch16-naflex"),
        ]
        for label, repo_id in hub_repos:
            logger.info(f"Downloading {label}...")
            manager.download(repo_id)
        # SAM2 checkpoint comes from a direct URL, not the Hub.
        logger.info("Downloading SAM2 checkpoint...")
        manager.download_url(
            "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
            "sam2.1_hiera_large.pt"
        )
        logger.info("All HuggingFace models downloaded successfully.")
    except Exception as e:
        logger.error(f"Error during HF model download: {e}")
# ===========================================
# STARTUP: PARALLEL MODEL DOWNLOADS
# ===========================================
logger.info("Starting parallel model downloads at startup...")
# Daemon threads so downloads never block app startup or process exit;
# the GPU handlers join() these before running inference.
t_hf = threading.Thread(target=setup_hf_models, daemon=True)
t_ollama = threading.Thread(target=setup_ollama_startup, daemon=True)
t_hf.start()
t_ollama.start()
# NOTE: on ZeroGPU this typically reports "cpu" at startup; CUDA is only
# visible inside @spaces.GPU-decorated functions.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Main process device detection: {DEVICE}")
# ===========================================
# UTILITY FUNCTIONS
# ===========================================
def process_inputs_to_frames(input_data, output_folder: str) -> tuple:
    """
    Extracts frames from video (1 FPS) or copies images to output folder.
    Enforces MAX_FRAMES limit to prevent ZeroGPU timeouts.

    Args:
        input_data: Video file or list of image files
        output_folder: Directory to save extracted frames

    Returns:
        tuple: (output_folder path, list of frame file paths)

    Raises:
        gr.Error: If the video is too long or too many images are uploaded.
    """
    # Start from a clean folder so stale frames never leak into a new run.
    if os.path.exists(output_folder):
        shutil.rmtree(output_folder)
    os.makedirs(output_folder)
    frame_paths = []
    video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.webm'}
    input_list = input_data if isinstance(input_data, list) else [input_data]
    if not input_list:
        return output_folder, []
    # Gradio file objects expose .name (a temp path); plain paths are str.
    first_file = input_list[0].name if hasattr(input_list[0], 'name') else str(input_list[0])
    ext = os.path.splitext(first_file)[1].lower()
    if ext in video_extensions:
        # Process video file (only the first upload is read in video mode)
        logger.info(f"Processing video: {first_file}...")
        cap = cv2.VideoCapture(first_file)
        video_fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames_original = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        # Some containers report 0/NaN FPS; fall back to assuming 30.
        if video_fps == 0 or np.isnan(video_fps):
            video_fps = 30
        duration_sec = total_frames_original / video_fps
        # Validate video duration (1 FPS sampling means seconds == frames)
        if duration_sec > MAX_FRAMES:
            cap.release()
            msg = f"Video is too long ({int(duration_sec)}s). Max allowed is {MAX_FRAMES}s to avoid ZeroGPU timeout."
            logger.error(msg)
            raise gr.Error(msg)
        # Sample at 1 FPS: keep one frame per `frame_interval` decoded frames
        frame_interval = max(1, int(video_fps))
        count = 0
        saved_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if count % frame_interval == 0:
                filename = f"frame_{saved_count:05d}.jpg"
                filepath = os.path.join(output_folder, filename)
                cv2.imwrite(filepath, frame)
                frame_paths.append(filepath)
                saved_count += 1
                # Safety net in case the metadata-based duration check was
                # fooled by wrong container metadata.
                if saved_count > MAX_FRAMES:
                    cap.release()
                    raise gr.Error(f"Limit reached: > {MAX_FRAMES} frames extracted.")
            count += 1
        cap.release()
        logger.info(f"Video sampled at 1 FPS. Total frames: {saved_count}")
    else:
        # Process image files
        if len(input_list) > MAX_FRAMES:
            raise gr.Error(f"Too many images! You uploaded {len(input_list)}. Max allowed is {MAX_FRAMES}.")
        logger.info(f"Processing {len(input_list)} images...")
        # Sort by filename so frame ordering is deterministic.
        input_list.sort(key=lambda x: x.name if hasattr(x, 'name') else str(x))
        for i, f in enumerate(input_list):
            path = f.name if hasattr(f, 'name') else str(f)
            try:
                # Re-encode everything as RGB JPEG so downstream code sees
                # uniformly named, uniformly formatted frames.
                img = Image.open(path).convert("RGB")
                filename = f"frame_{i:05d}.jpg"
                filepath = os.path.join(output_folder, filename)
                img.save(filepath)
                frame_paths.append(filepath)
            except Exception as e:
                logger.warning(f"Skipping file {path}: {e}")
    return output_folder, frame_paths
def create_video_overlay(frames_folder: str, masks_dict: dict, output_path: str, fps: int = 5) -> str:
    """
    Render the extracted frames into a video with segmentation masks
    alpha-blended on top.

    Args:
        frames_folder: Directory containing frame images
        masks_dict: Dictionary mapping frame index to mask arrays
        output_path: Output video file path
        fps: Frames per second for output video

    Returns:
        Output video path or None if failed
    """
    logger.info("Generating result video...")
    frame_files = sorted(f for f in os.listdir(frames_folder) if f.endswith(".jpg"))
    if not frame_files:
        return None
    # Derive output dimensions from the first frame.
    sample = cv2.imread(os.path.join(frames_folder, frame_files[0]))
    height, width, _ = sample.shape
    writer = cv2.VideoWriter(
        output_path,
        cv2.VideoWriter_fourcc(*'mp4v'),
        fps,
        (width, height)
    )
    # NOTE(review): OpenCV frames are BGR, so [255, 100, 0] is a blue-ish
    # tint despite the original "orange/gold" label — confirm intended color.
    overlay_color = np.array([255, 100, 0], dtype=np.uint8)
    for idx, name in enumerate(frame_files):
        frame = cv2.imread(os.path.join(frames_folder, name))
        overlay = np.zeros_like(frame)
        masks_data = masks_dict.get(idx)
        # Accept a single mask array or a list of masks per frame.
        if isinstance(masks_data, np.ndarray):
            frame_masks = [masks_data]
        elif isinstance(masks_data, list):
            frame_masks = masks_data
        else:
            frame_masks = []
        for mask in frame_masks:
            overlay[mask > 0] = overlay_color
        if np.any(overlay):
            frame = cv2.addWeighted(frame, 1, overlay, 0.5, 0)
        writer.write(frame)
    writer.release()
    return output_path
# ===========================================
# UNIQUE INSTANCE SEGMENTATION
# ===========================================
def process_unique_upload(input_files):
    """
    Prepare uploaded media for Unique Instance segmentation.

    Extracts frames into a fresh temp directory and configures the
    reference-frame slider for annotation.

    Returns:
        tuple: (first frame path, frames dir, reset points list,
                status message, slider component update)
    """
    if not input_files:
        return None, None, [], "Please upload files first.", gr.Slider(value=0, maximum=0, visible=False)
    frames_dir, frame_paths = process_inputs_to_frames(input_files, tempfile.mkdtemp())
    total = len(frame_paths)
    if total == 0:
        return None, None, [], "No frames extracted.", gr.Slider(value=0, maximum=0, visible=False)
    # Re-create the slider sized to the extracted frame range.
    selector = gr.Slider(
        value=0,
        minimum=0,
        maximum=total - 1,
        step=1,
        visible=True,
        interactive=True,
        label=f"Select Reference Frame (0 - {total - 1})"
    )
    return frame_paths[0], frames_dir, [], f"Processed {total} frames (1 FPS). Select target.", selector
def update_canvas_from_slider(frame_idx, frames_dir):
    """Return the RGB frame for the slider position, resetting click points."""
    if not frames_dir or not os.path.exists(frames_dir):
        return None, []
    path = os.path.join(frames_dir, f"frame_{int(frame_idx):05d}.jpg")
    if not os.path.exists(path):
        return None, []
    # Frames are stored as BGR JPEGs; the canvas expects RGB.
    return cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB), []
def add_point(img, evt: gr.SelectData, points_state):
    """
    Record a clicked point and redraw markers for all points on the frame.

    Args:
        img: Current RGB frame as a numpy array (from gr.Image).
        evt: Gradio select event carrying the click coordinates.
        points_state: Mutable list of (x, y) points; appended in place.

    Returns:
        tuple: (annotated RGB image, updated points list)
    """
    x, y = evt.index
    points_state.append((x, y))
    # Convert directly to BGR for OpenCV drawing. The previous
    # numpy -> PIL -> numpy round-trip was redundant: img is already
    # a numpy array.
    img_cv = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
    # Draw markers for all points
    for px, py in points_state:
        cv2.drawMarker(
            img_cv, (px, py), (0, 255, 0),
            markerType=cv2.MARKER_TILTED_CROSS,
            markerSize=20,
            thickness=3
        )
    return cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB), points_state
@spaces.GPU(duration=180)
def run_unique_segmentation(input_files, points_state, text_prompt, sam_encoder, offload_gpu, cleanup_interval, frame_idx_slider):
    """
    Run Unique Instance segmentation on the uploaded frames.
    Tracks a specific object identified by points or text description.

    Args:
        input_files: Raw uploads (video or images) from the gr.File widget.
        points_state: List of (x, y) click points on the reference frame.
        text_prompt: Optional grounding text; when non-empty it takes
            priority and the click points are ignored.
        sam_encoder: SAM2 encoder variant name from the dropdown.
        offload_gpu: Whether to enable GPU memory offloading.
        cleanup_interval: Frames between memory cleanups (cast to int).
        frame_idx_slider: Index of the annotated reference frame.

    Returns:
        tuple: (result video path or None, status message)
    """
    if not input_files:
        return None, "Error: Process input first."
    # Wait for HF models to be downloaded
    if t_hf.is_alive():
        logger.info("Waiting for HF models download to finish...")
        t_hf.join()
    try:
        logger.info("Processing inputs on GPU node...")
        temp_dir = tempfile.mkdtemp()
        # Re-extract frames to ensure they exist on GPU ephemeral storage
        frames_dir, _ = process_inputs_to_frames(input_files, temp_dir)
        logger.info("Initializing UniqueInstanceSegmenter...")
        segmenter = UniqueInstanceSegmenter(
            sam_encoder=sam_encoder,
            memory_cleanup_interval=int(cleanup_interval),
            device="cuda"
        )
        if offload_gpu:
            segmenter.optimize_cuda_memory()
        annotation_frame = f"frame_{int(frame_idx_slider):05d}.jpg"
        if not os.path.exists(os.path.join(frames_dir, annotation_frame)):
            return None, f"Error: Frame {annotation_frame} not found."
        # Run segmentation based on input type
        if text_prompt.strip():
            # Text grounding mode: click points are ignored.
            logger.info(f"Mode: Text -> {text_prompt}")
            result = segmenter.segment(
                frames_path=frames_dir,
                text=text_prompt,
                annotation_frame=annotation_frame,
                offload_frames_to_gpu=offload_gpu
            )
        else:
            if not points_state:
                return None, "Please add points or text."
            logger.info(f"Mode: Points -> {points_state}")
            result = segmenter.segment(
                frames_path=frames_dir,
                points=points_state,
                annotation_frame=annotation_frame,
                offload_frames_to_gpu=offload_gpu
            )
        output_vid = os.path.join(OUTPUT_BASE_DIR, "unique_output.mp4")
        return create_video_overlay(frames_dir, result.masks, output_vid), f"Completed. {result.num_frames} frames processed."
    except Exception as e:
        import traceback
        traceback.print_exc()
        logger.error(str(e))
        # Re-raise gr.Error so Gradio shows the message to the user directly.
        if isinstance(e, gr.Error):
            raise e
        return None, f"Error: {str(e)}"
# ===========================================
# GENERIC CATEGORY SEGMENTATION
# ===========================================
@spaces.GPU(duration=180)
def run_generic_segmentation(input_files, category, accept_thresh, reject_thresh, vlm_model_name):
    """
    Run Generic Category segmentation on the uploaded frames.
    Detects all instances of a specified category using VLM + segmentation.

    IMPORTANT: This function starts Ollama server INSIDE the GPU context,
    ensuring that Ollama can detect and use the GPU for inference.

    Args:
        input_files: Raw uploads (video or images) from the gr.File widget.
        category: Text prompt naming the category to detect.
        accept_thresh: Accept-threshold slider value passed to the segmenter.
        reject_thresh: Reject-threshold slider value passed to the segmenter.
        vlm_model_name: Ollama VLM model to load and use.

    Returns:
        tuple: (result video path or None, status/statistics message)
    """
    if not input_files:
        return None, "Error: Upload input."
    if not category.strip():
        return None, "Error: Please specify text."
    # Wait for model downloads to complete
    if t_hf.is_alive():
        logger.info("Waiting for HF models download...")
        t_hf.join()
    if t_ollama.is_alive():
        logger.info("Waiting for Ollama models download...")
        t_ollama.join()
    try:
        # =========================================================
        # CRITICAL: Start Ollama INSIDE @spaces.GPU context
        # This ensures Ollama detects and uses the GPU!
        # =========================================================
        logger.info("=" * 50)
        logger.info("Starting Ollama server with GPU support...")
        logger.info("=" * 50)
        ensure_ollama_ready_gpu(vlm_model_name)
        logger.info("Ollama is running with GPU. Processing inputs...")
        # Process input frames
        temp_dir = tempfile.mkdtemp()
        frames_dir, _ = process_inputs_to_frames(input_files, temp_dir)
        logger.info(f"Initializing GenericCategorySegmenter with VLM: {vlm_model_name}")
        segmenter = GenericCategorySegmenter(
            device="cuda",
            vlm_model=vlm_model_name
        )
        logger.info(f"Detecting category: {category}")
        result = segmenter.segment(
            frames_path=frames_dir,
            category=category,
            accept_threshold=accept_thresh,
            reject_threshold=reject_thresh,
            save_debug=False
        )
        output_vid = os.path.join(OUTPUT_BASE_DIR, "generic_output.mp4")
        # Sum per-frame detection counts for the status message.
        total_detections = sum(len(d) for d in result.metadata['detections'].values())
        return create_video_overlay(frames_dir, result.masks, output_vid), f"Completed! Total detections: {total_detections}"
    except Exception as e:
        import traceback
        traceback.print_exc()
        logger.error(f"Generic segmentation error: {e}")
        # Re-raise gr.Error so Gradio shows the message to the user directly.
        if isinstance(e, gr.Error):
            raise e
        return None, f"Error: {e}"
# ===========================================
# GRADIO UI
# ===========================================
# Build the Gradio UI: two tabs sharing the MAX_FRAMES limits defined above.
with gr.Blocks(title="ENEAS: Embedding-guided Neural Ensemble for Adaptive Segmentation") as demo:
    gr.Markdown(
        f"""
        # ⚡ ENEAS: Embedding-guided Neural Ensemble for Adaptive Segmentation
        **⚠️ IMPORTANT LIMITS:**
        - Maximum **{MAX_FRAMES} FRAMES** to prevent ZeroGPU timeouts
        - Videos are sampled at **1 FPS** → Max **{MAX_FRAMES} seconds** of video
        - Exceeding these limits will stop execution
        """
    )
    with gr.Tabs():
        # ===========================================
        # TAB 1: UNIQUE INSTANCE SEGMENTATION
        # ===========================================
        with gr.Tab("🎯 Unique Instance"):
            gr.Markdown("Track a specific object. Upload Video (1 FPS extraction) OR Images.")
            with gr.Row():
                with gr.Column(scale=1):
                    u_file = gr.File(
                        label="Input (Video or Images)",
                        file_count="multiple",
                        file_types=["video", "image"]
                    )
                    u_btn_proc = gr.Button("▶️ 1. Process Input (Extract 1 FPS)", variant="secondary")
                    # Hidden until frames are extracted, then becomes the
                    # reference-frame picker (reconfigured by the callback).
                    u_slider = gr.Slider(label="Frame Selector", visible=False)
                    with gr.Accordion("Advanced Options", open=False):
                        u_enc = gr.Dropdown(
                            ["long-large", "long-small"],
                            value="long-large",
                            label="SAM2 Encoder"
                        )
                        u_offload = gr.Checkbox(label="GPU Memory Offload", value=False)
                with gr.Column(scale=2):
                    # Hidden textbox carries the frames-dir path between callbacks.
                    u_path_frames_cpu = gr.Textbox(visible=False)
                    points_state = gr.State([])
                    u_img = gr.Image(
                        label="Reference Frame (Click to add points)",
                        interactive=True
                    )
                    u_txt = gr.Textbox(
                        label="Text Description (Grounding)",
                        placeholder="Points are ignored if text is provided."
                    )
                    u_btn_run = gr.Button("🚀 2. Run Segmentation", variant="primary")
                    u_out = gr.Video(label="Result")
                    u_status = gr.Textbox(label="Status", interactive=False)
            # Event handlers
            u_btn_proc.click(
                process_unique_upload,
                [u_file],
                [u_img, u_path_frames_cpu, points_state, u_status, u_slider]
            )
            u_slider.change(
                update_canvas_from_slider,
                inputs=[u_slider, u_path_frames_cpu],
                outputs=[u_img, points_state]
            )
            u_img.select(add_point, [u_img, points_state], [u_img, points_state])
            u_btn_run.click(
                run_unique_segmentation,
                # The hidden gr.Number(10) supplies the memory cleanup interval.
                [u_file, points_state, u_txt, u_enc, u_offload, gr.Number(10, visible=False), u_slider],
                [u_out, u_status]
            )
            # Example for Unique Instance
            gr.Examples(
                examples=[
                    [["examples/reporter.mp4"], "blonde woman with microphone"]
                ],
                inputs=[u_file, u_txt],
                label="Example"
            )
        # ===========================================
        # TAB 2: GENERIC CATEGORY SEGMENTATION
        # ===========================================
        with gr.Tab("🔮 Generic Text"):
            gr.Markdown(
                f"""
                Detect all instances of a text prompt in every frame (Max {MAX_FRAMES} frames).
                **🚀 GPU-Accelerated:** Ollama VLM runs on GPU for fast inference.
                First request includes ~15-20s server startup time.
                """
            )
            with gr.Row():
                g_file = gr.File(
                    label="Input (Video or Images)",
                    file_count="multiple",
                    file_types=["video", "image"]
                )
                g_cat = gr.Textbox(
                    label="Text prompt",
                    placeholder="e.g., person, chair, car, dog"
                )
                g_btn = gr.Button("🔍 Run Detection", variant="primary")
            with gr.Accordion("Detection Settings", open=True):
                g_accept = gr.Slider(
                    0.0, 1.0,
                    value=0.30,
                    label="Accept Threshold",
                    info="Higher = more confident detections only"
                )
                g_reject = gr.Slider(
                    0.0, 1.0,
                    value=0.1,
                    label="Reject Threshold",
                    info="Lower = filter out more false positives"
                )
                g_vlm = gr.Dropdown(
                    choices=VLM_MODELS,
                    value=VLM_MODELS[0],
                    label="VLM Model",
                    info="Larger models are more accurate but slower"
                )
            g_out = gr.Video(label="Result")
            g_stat = gr.Textbox(label="Detection Statistics", interactive=False)
            g_btn.click(
                run_generic_segmentation,
                [g_file, g_cat, g_accept, g_reject, g_vlm],
                [g_out, g_stat]
            )
            # Example for Generic Category
            gr.Examples(
                examples=[
                    [["examples/moving.mp4"], "person", 0.3, 0.1, "qwen3-vl:4b-instruct-q8_0"]
                ],
                inputs=[g_file, g_cat, g_accept, g_reject, g_vlm],
                label="Example"
            )
# ===========================================
# MAIN ENTRY POINT
# ===========================================
if __name__ == "__main__":
    demo.launch()