"""
ViCLIP ZeroGPU Space for Cadayn/EagleEye
Video-text embedding using ViCLIP for semantic video search.
Better than CLIP for temporal understanding.
Features:
- Video segment embeddings
- Text query embeddings
- Similarity search
- Multi-frame temporal pooling
API Endpoints:
- /api/embed_video - Get video embedding
- /api/embed_text - Get text embedding
- /api/similarity - Compute video-text similarity
"""
from __future__ import annotations
import base64
import io
import json
import os
import tempfile
import traceback
from typing import Any
import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor, CLIPTokenizer
MODEL_ID = "OpenGVLab/ViCLIP-L-14-hf"
EMBEDDING_DIM = 768
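# All embeddings below are L2-normalized, so a plain dot product equals cosine
# similarity. A minimal client-side sketch (illustrative only; this helper is
# not wired into the app):
def cosine_similarity_from_strings(a: str, b: str) -> float:
    """Cosine similarity of two comma-separated embedding strings."""
    va = np.array([float(x) for x in a.split(",")], dtype=np.float32)
    vb = np.array([float(x) for x in b.split(",")], dtype=np.float32)
    return float(np.dot(va, vb))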
model = None
processor = None
tokenizer = None
def load_model():
"""Load ViCLIP model."""
global model, processor, tokenizer
if model is None:
print(f"Loading {MODEL_ID}...")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModel.from_pretrained(
MODEL_ID,
trust_remote_code=True,
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
model = model.to(device).eval()
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
print(f"ViCLIP loaded on {device}")
return model, processor, tokenizer
def extract_frames(video_path: str, num_frames: int = 8) -> list[Image.Image]:
"""Extract frames from video at uniform intervals."""
    num_frames = int(num_frames)  # Gradio Number/Slider inputs can arrive as floats
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total_frames == 0 or num_frames <= 0:
        cap.release()
        return []
    # Uniformly spaced indices; the max index stays strictly below total_frames.
    frame_indices = [int(i * total_frames / num_frames) for i in range(num_frames)]
frames = []
for idx in frame_indices:
cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
ret, frame = cap.read()
if ret:
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frames.append(Image.fromarray(frame_rgb))
cap.release()
return frames
@spaces.GPU(duration=120)
def embed_video_file(video_file, num_frames: int = 8) -> str:
"""
Get embedding for a video file.
Args:
video_file: Video file from Gradio
num_frames: Number of frames to sample
Returns:
Embedding as comma-separated string
"""
try:
if video_file is None:
return "Please upload a video."
if isinstance(video_file, str):
video_path = video_file
elif hasattr(video_file, 'name'):
video_path = video_file.name
else:
return f"Error: Unexpected file type: {type(video_file)}"
model, processor, _ = load_model()
frames = extract_frames(video_path, num_frames)
if not frames:
return "Error: Could not extract frames from video."
inputs = processor(images=frames, return_tensors="pt")
        # Match the model's device and dtype (fp16 on GPU) so inputs and weights agree.
        inputs = {k: v.to(model.device, dtype=model.dtype) for k, v in inputs.items()}
with torch.no_grad():
if hasattr(model, "get_image_features"):
frame_embeddings = model.get_image_features(**inputs)
else:
outputs = model.vision_model(**inputs)
frame_embeddings = outputs.pooler_output
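        # Temporal pooling: average the per-frame embeddings, then L2-normalize.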
video_embedding = frame_embeddings.mean(dim=0)
video_embedding = video_embedding / video_embedding.norm()
embedding_list = video_embedding.cpu().numpy().tolist()
return ",".join(f"{v:.6f}" for v in embedding_list)
except Exception as e:
error_msg = f"Error embedding video: {str(e)}\n{traceback.format_exc()}"
print(error_msg)
return error_msg
@spaces.GPU(duration=60)
def embed_text_query(text: str) -> str:
"""
Get embedding for a text query.
Args:
text: Text query
Returns:
Embedding as comma-separated string
"""
try:
if not text or not text.strip():
return "Please enter a text query."
model, _, tokenizer = load_model()
tokens = tokenizer(
[text],
return_tensors="pt",
padding=True,
truncation=True,
max_length=77,
)
tokens = {k: v.to(model.device) for k, v in tokens.items()}
with torch.no_grad():
if hasattr(model, "get_text_features"):
text_embedding = model.get_text_features(**tokens)
else:
outputs = model.text_model(**tokens)
text_embedding = outputs.pooler_output
text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)
embedding_list = text_embedding.squeeze(0).cpu().numpy().tolist()
return ",".join(f"{v:.6f}" for v in embedding_list)
except Exception as e:
error_msg = f"Error embedding text: {str(e)}\n{traceback.format_exc()}"
print(error_msg)
return error_msg
@spaces.GPU(duration=180)
def api_embed_video(
video_url: str | None = None,
video_base64: str | None = None,
num_frames: int = 8,
) -> dict[str, Any]:
"""
API endpoint for video embedding from EagleEye.
Args:
video_url: URL to video file
video_base64: Base64 encoded video
num_frames: Number of frames to sample
Returns:
JSON response with embedding vector
"""
try:
video_path = None
temp_file = None
if video_url:
import requests
response = requests.get(video_url, timeout=120, stream=True)
response.raise_for_status()
temp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
for chunk in response.iter_content(chunk_size=8192):
temp_file.write(chunk)
temp_file.close()
video_path = temp_file.name
elif video_base64:
video_bytes = base64.b64decode(video_base64)
temp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
temp_file.write(video_bytes)
temp_file.close()
video_path = temp_file.name
else:
return {"error": "No video provided", "success": False}
model, processor, _ = load_model()
frames = extract_frames(video_path, num_frames)
if not frames:
if temp_file:
os.unlink(temp_file.name)
return {"error": "Could not extract frames from video", "success": False}
inputs = processor(images=frames, return_tensors="pt")
        inputs = {k: v.to(model.device, dtype=model.dtype) for k, v in inputs.items()}
with torch.no_grad():
if hasattr(model, "get_image_features"):
frame_embeddings = model.get_image_features(**inputs)
else:
outputs = model.vision_model(**inputs)
frame_embeddings = outputs.pooler_output
video_embedding = frame_embeddings.mean(dim=0)
video_embedding = video_embedding / video_embedding.norm()
if temp_file:
os.unlink(temp_file.name)
embedding_list = video_embedding.cpu().numpy().tolist()
return {
"success": True,
"embedding": embedding_list,
"dim": len(embedding_list),
"frames_sampled": len(frames),
"model": MODEL_ID,
}
    except Exception as e:
        if temp_file is not None and os.path.exists(temp_file.name):
            temp_file.close()  # safe no-op if already closed
            os.unlink(temp_file.name)
        return {"error": str(e), "success": False, "traceback": traceback.format_exc()}
@spaces.GPU(duration=60)
def api_embed_text(text: str) -> dict[str, Any]:
"""
API endpoint for text embedding from EagleEye.
Args:
text: Text query to embed
Returns:
JSON response with embedding vector
"""
try:
if not text or not text.strip():
return {"error": "Text is required", "success": False}
model, _, tokenizer = load_model()
tokens = tokenizer(
[text],
return_tensors="pt",
padding=True,
truncation=True,
max_length=77,
)
tokens = {k: v.to(model.device) for k, v in tokens.items()}
with torch.no_grad():
if hasattr(model, "get_text_features"):
text_embedding = model.get_text_features(**tokens)
else:
outputs = model.text_model(**tokens)
text_embedding = outputs.pooler_output
text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)
embedding_list = text_embedding.squeeze(0).cpu().numpy().tolist()
return {
"success": True,
"embedding": embedding_list,
"dim": len(embedding_list),
"text": text,
"model": MODEL_ID,
}
except Exception as e:
return {"error": str(e), "success": False, "traceback": traceback.format_exc()}
@spaces.GPU(duration=180)
def api_embed_frames(
frames_base64: list[str],
timestamps: list[float] | None = None,
) -> dict[str, Any]:
"""
API endpoint for embedding multiple frames from EagleEye.
Used for segment-level embeddings in video search.
Args:
frames_base64: List of base64 encoded frames
timestamps: Optional timestamps for each frame
Returns:
JSON response with per-frame embeddings
"""
    try:
        # The Gradio API wires Textbox inputs here, so lists may arrive as JSON strings.
        if isinstance(frames_base64, str):
            frames_base64 = json.loads(frames_base64) if frames_base64.strip() else []
        if isinstance(timestamps, str):
            timestamps = json.loads(timestamps) if timestamps.strip() else None
        if not frames_base64:
            return {"error": "No frames provided", "success": False}
model, processor, _ = load_model()
frames = []
for frame_b64 in frames_base64:
image_bytes = base64.b64decode(frame_b64)
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
frames.append(image)
inputs = processor(images=frames, return_tensors="pt")
        inputs = {k: v.to(model.device, dtype=model.dtype) for k, v in inputs.items()}
with torch.no_grad():
if hasattr(model, "get_image_features"):
frame_embeddings = model.get_image_features(**inputs)
else:
outputs = model.vision_model(**inputs)
frame_embeddings = outputs.pooler_output
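        # Normalize per frame, then mean-pool into one segment-level vector below.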
frame_embeddings = frame_embeddings / frame_embeddings.norm(dim=-1, keepdim=True)
embeddings_list = frame_embeddings.cpu().numpy().tolist()
pooled = frame_embeddings.mean(dim=0)
pooled = pooled / pooled.norm()
pooled_list = pooled.cpu().numpy().tolist()
result = {
"success": True,
"frame_embeddings": embeddings_list,
"pooled_embedding": pooled_list,
"dim": len(pooled_list),
"num_frames": len(frames),
"model": MODEL_ID,
}
if timestamps:
result["timestamps"] = timestamps
return result
except Exception as e:
return {"error": str(e), "success": False, "traceback": traceback.format_exc()}
with gr.Blocks(title="ViCLIP for Cadayn") as demo:
gr.Markdown("""
# ViCLIP - Video-Text Embeddings
Powered by [ViCLIP-L-14](https://huggingface.co/OpenGVLab/ViCLIP-L-14-hf) on ZeroGPU.
**Capabilities:**
- Video segment embeddings (768-dim)
- Text query embeddings
- Temporal-aware video understanding
- Semantic video search
**API Endpoints for EagleEye:**
- `POST /call/api_embed_video` - Video segment embedding
- `POST /call/api_embed_text` - Text query embedding
- `POST /call/api_embed_frames` - Multi-frame embeddings
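
Each endpoint uses Gradio's two-step `/call` flow: POST the inputs, then stream
the result by event ID. A minimal raw-HTTP sketch, assuming Gradio 4+ and this
Space's public URL (adjust `BASE` for your deployment):

```python
import requests

BASE = "https://cadayn-viclip-zerogpu.hf.space"  # assumed Space URL
r = requests.post(f"{BASE}/call/api_embed_text", json={"data": ["a goal"]})
event_id = r.json()["event_id"]
# The result streams back as server-sent events; the final data line holds it.
with requests.get(f"{BASE}/call/api_embed_text/{event_id}", stream=True) as s:
    for line in s.iter_lines():
        if line.startswith(b"data:"):
            print(line[5:].decode())
```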
""")
# Hidden API interfaces for EagleEye integration
with gr.Row(visible=False):
# Video embedding API
api_vid_url = gr.Textbox()
api_vid_b64 = gr.Textbox()
        api_vid_frames = gr.Number(value=8, precision=0)  # precision=0 returns an int
api_vid_output = gr.JSON()
# Text embedding API
api_text_input = gr.Textbox()
api_text_output = gr.JSON()
# Frames embedding API
api_frames_input = gr.Textbox()
api_frames_ts = gr.Textbox()
api_frames_output = gr.JSON()
api_vid_url.change(
fn=api_embed_video,
inputs=[api_vid_url, api_vid_b64, api_vid_frames],
outputs=api_vid_output,
api_name="api_embed_video",
)
api_text_input.change(
fn=api_embed_text,
inputs=[api_text_input],
outputs=api_text_output,
api_name="api_embed_text",
)
api_frames_input.change(
fn=api_embed_frames,
inputs=[api_frames_input, api_frames_ts],
outputs=api_frames_output,
api_name="api_embed_frames",
)
with gr.Tab("Video Embedding"):
with gr.Row():
with gr.Column():
video_input = gr.File(
label="Upload Video",
file_types=[".mp4", ".avi", ".mov", ".mkv", ".webm"],
)
num_frames_slider = gr.Slider(
minimum=4,
maximum=32,
value=8,
step=4,
label="Number of Frames",
)
video_btn = gr.Button("Get Embedding", variant="primary")
with gr.Column():
video_output = gr.Textbox(label="Embedding (768-dim)", lines=10)
video_btn.click(
fn=embed_video_file,
inputs=[video_input, num_frames_slider],
outputs=video_output,
)
with gr.Tab("Text Embedding"):
with gr.Row():
with gr.Column():
text_input = gr.Textbox(
label="Text Query",
placeholder="e.g., a person scoring a goal",
lines=2,
)
text_btn = gr.Button("Get Embedding", variant="primary")
with gr.Column():
text_output = gr.Textbox(label="Embedding (768-dim)", lines=10)
text_btn.click(
fn=embed_text_query,
inputs=[text_input],
outputs=text_output,
)
with gr.Tab("API"):
gr.Markdown("""
## API Usage for EagleEye Integration
### Video Embedding
```python
from gradio_client import Client
client = Client("Cadayn/viclip-zerogpu")
result = client.predict(
video_url="https://example.com/clip.mp4",
num_frames=8,
api_name="/api_embed_video"
)
print(result)
# {"success": True, "embedding": [...], "dim": 768, ...}
```
### Text Embedding
```python
result = client.predict(
text="a soccer player scoring a goal",
api_name="/api_embed_text"
)
print(result)
# {"success": True, "embedding": [...], "dim": 768, ...}
```
### Multi-Frame Embeddings
This endpoint is wired to text inputs, so pass the lists JSON-encoded:
```python
import json

result = client.predict(
    frames_base64=json.dumps(["frame1_b64", "frame2_b64"]),
    timestamps=json.dumps([0.0, 1.0]),
    api_name="/api_embed_frames"
)
print(result)
# {"success": True, "frame_embeddings": [[...], [...]], "pooled_embedding": [...], ...}
```
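### Computing Similarity
All embeddings are L2-normalized, so video-text similarity is a plain dot
product. A client-side sketch, assuming `video_result` and `text_result` hold
the responses from the calls above:
```python
import numpy as np

video_vec = np.array(video_result["embedding"])
text_vec = np.array(text_result["embedding"])
similarity = float(np.dot(video_vec, text_vec))  # cosine similarity in [-1, 1]
```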
""")
if __name__ == "__main__":
demo.launch()