Spaces:
Running on Zero
| """ | |
| ViCLIP ZeroGPU Space for Cadayn/EagleEye | |
| Video-text embedding using ViCLIP for semantic video search. | |
| Better than CLIP for temporal understanding. | |
| Features: | |
| - Video segment embeddings | |
| - Text query embeddings | |
| - Similarity search | |
| - Multi-frame temporal pooling | |
| API Endpoints: | |
| - /api/embed_video - Get video embedding | |
| - /api/embed_text - Get text embedding | |
| - /api/similarity - Compute video-text similarity | |
| """ | |
| from __future__ import annotations | |
| import base64 | |
| import io | |
| import os | |
| import tempfile | |
| import traceback | |
| from typing import Any | |
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| import spaces | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoModel, AutoProcessor, CLIPTokenizer | |
# Hugging Face Hub id of the ViCLIP video-text encoder (remote code).
MODEL_ID = "OpenGVLab/ViCLIP-L-14-hf"
# Dimensionality of the embeddings this Space returns (ViT-L/14 width).
EMBEDDING_DIM = 768
# Lazily-initialised singletons; populated on first call to load_model().
model = None
processor = None
tokenizer = None
def load_model():
    """Lazily initialise and cache the ViCLIP model, processor and tokenizer.

    The three objects are stored in module-level globals so the heavy
    download/initialisation happens only once per process.

    Returns:
        Tuple of (model, processor, tokenizer).
    """
    global model, processor, tokenizer
    if model is not None:
        # Already initialised on a previous call — reuse the cached objects.
        return model, processor, tokenizer
    print(f"Loading {MODEL_ID}...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Half precision only makes sense on GPU; CPU inference stays in fp32.
    dtype = torch.float16 if device == "cuda" else torch.float32
    loaded = AutoModel.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        torch_dtype=dtype,
    )
    model = loaded.to(device).eval()
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    # ViCLIP reuses the standard CLIP tokenizer vocabulary.
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    print(f"ViCLIP loaded on {device}")
    return model, processor, tokenizer
def extract_frames(video_path: str, num_frames: int = 8) -> list[Image.Image]:
    """Extract up to *num_frames* frames sampled uniformly across a video.

    Args:
        video_path: Path to a video file readable by OpenCV.
        num_frames: Number of frames to sample. Coerced to an int >= 1 so
            float values coming from UI widgets (``gr.Number`` emits floats)
            no longer crash ``range()``.

    Returns:
        List of RGB PIL images; empty if the video cannot be opened or
        reports no frames.
    """
    # gr.Number / JSON clients may deliver a float; range() needs an int.
    num_frames = max(1, int(num_frames))
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        # Unreadable path or unsupported codec.
        cap.release()
        return []
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # CAP_PROP_FRAME_COUNT is container metadata and may be 0 or negative.
    if total_frames <= 0:
        cap.release()
        return []
    frame_indices = [int(i * total_frames / num_frames) for i in range(num_frames)]
    frames = []
    for idx in frame_indices:
        # Seek then decode; if the seek lands past the last decodable frame,
        # read() returns False and the frame is simply skipped.
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = cap.read()
        if ret:
            # OpenCV decodes to BGR; PIL expects RGB.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame_rgb))
    cap.release()
    return frames
def embed_video_file(video_file, num_frames: int = 8) -> str:
    """Embed an uploaded video and return the vector as comma-separated text.

    Args:
        video_file: Filesystem path string, or a Gradio file object
            exposing ``.name``.
        num_frames: Number of frames to sample uniformly.

    Returns:
        Comma-separated embedding values, or a human-readable error string.
    """
    try:
        if video_file is None:
            return "Please upload a video."
        # Resolve the on-disk path from whatever Gradio handed us.
        if isinstance(video_file, str):
            video_path = video_file
        elif hasattr(video_file, 'name'):
            video_path = video_file.name
        else:
            return f"Error: Unexpected file type: {type(video_file)}"
        vit, vproc, _ = load_model()
        sampled = extract_frames(video_path, num_frames)
        if not sampled:
            return "Error: Could not extract frames from video."
        batch = vproc(images=sampled, return_tensors="pt")
        batch = {key: tensor.to(vit.device) for key, tensor in batch.items()}
        with torch.no_grad():
            if hasattr(vit, "get_image_features"):
                per_frame = vit.get_image_features(**batch)
            else:
                per_frame = vit.vision_model(**batch).pooler_output
        # Mean-pool across frames, then L2-normalise the pooled vector.
        pooled = per_frame.mean(dim=0)
        pooled = pooled / pooled.norm()
        values = pooled.cpu().numpy().tolist()
        return ",".join(f"{v:.6f}" for v in values)
    except Exception as e:
        error_msg = f"Error embedding video: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return error_msg
def embed_text_query(text: str) -> str:
    """Embed a free-text query and return the vector as comma-separated text.

    Args:
        text: Natural-language query.

    Returns:
        Comma-separated embedding values, or a human-readable error string.
    """
    try:
        if not text or not text.strip():
            return "Please enter a text query."
        clip, _, tok = load_model()
        encoded = tok(
            [text],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=77,  # CLIP's fixed text context length
        )
        encoded = {key: tensor.to(clip.device) for key, tensor in encoded.items()}
        with torch.no_grad():
            if hasattr(clip, "get_text_features"):
                features = clip.get_text_features(**encoded)
            else:
                features = clip.text_model(**encoded).pooler_output
        # L2-normalise so cosine similarity reduces to a dot product.
        features = features / features.norm(dim=-1, keepdim=True)
        values = features.squeeze(0).cpu().numpy().tolist()
        return ",".join(f"{v:.6f}" for v in values)
    except Exception as e:
        error_msg = f"Error embedding text: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return error_msg
| def api_embed_video( | |
| video_url: str | None = None, | |
| video_base64: str | None = None, | |
| num_frames: int = 8, | |
| ) -> dict[str, Any]: | |
| """ | |
| API endpoint for video embedding from EagleEye. | |
| Args: | |
| video_url: URL to video file | |
| video_base64: Base64 encoded video | |
| num_frames: Number of frames to sample | |
| Returns: | |
| JSON response with embedding vector | |
| """ | |
| try: | |
| video_path = None | |
| temp_file = None | |
| if video_url: | |
| import requests | |
| response = requests.get(video_url, timeout=120, stream=True) | |
| response.raise_for_status() | |
| temp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) | |
| for chunk in response.iter_content(chunk_size=8192): | |
| temp_file.write(chunk) | |
| temp_file.close() | |
| video_path = temp_file.name | |
| elif video_base64: | |
| video_bytes = base64.b64decode(video_base64) | |
| temp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) | |
| temp_file.write(video_bytes) | |
| temp_file.close() | |
| video_path = temp_file.name | |
| else: | |
| return {"error": "No video provided", "success": False} | |
| model, processor, _ = load_model() | |
| frames = extract_frames(video_path, num_frames) | |
| if not frames: | |
| if temp_file: | |
| os.unlink(temp_file.name) | |
| return {"error": "Could not extract frames from video", "success": False} | |
| inputs = processor(images=frames, return_tensors="pt") | |
| inputs = {k: v.to(model.device) for k, v in inputs.items()} | |
| with torch.no_grad(): | |
| if hasattr(model, "get_image_features"): | |
| frame_embeddings = model.get_image_features(**inputs) | |
| else: | |
| outputs = model.vision_model(**inputs) | |
| frame_embeddings = outputs.pooler_output | |
| video_embedding = frame_embeddings.mean(dim=0) | |
| video_embedding = video_embedding / video_embedding.norm() | |
| if temp_file: | |
| os.unlink(temp_file.name) | |
| embedding_list = video_embedding.cpu().numpy().tolist() | |
| return { | |
| "success": True, | |
| "embedding": embedding_list, | |
| "dim": len(embedding_list), | |
| "frames_sampled": len(frames), | |
| "model": MODEL_ID, | |
| } | |
| except Exception as e: | |
| if temp_file and os.path.exists(temp_file.name): | |
| os.unlink(temp_file.name) | |
| return {"error": str(e), "success": False, "traceback": traceback.format_exc()} | |
def api_embed_text(text: str) -> dict[str, Any]:
    """API endpoint for text embedding from EagleEye.

    Args:
        text: Text query to embed.

    Returns:
        JSON-serialisable dict with the embedding vector on success, or a
        dict with "success": False and an "error" message.
    """
    try:
        if not text or not text.strip():
            return {"error": "Text is required", "success": False}
        clip, _, tok = load_model()
        encoded = tok(
            [text],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=77,  # CLIP's fixed text context length
        )
        encoded = {key: tensor.to(clip.device) for key, tensor in encoded.items()}
        with torch.no_grad():
            if hasattr(clip, "get_text_features"):
                features = clip.get_text_features(**encoded)
            else:
                features = clip.text_model(**encoded).pooler_output
        # Unit-normalise so downstream cosine similarity is a dot product.
        features = features / features.norm(dim=-1, keepdim=True)
        vector = features.squeeze(0).cpu().numpy().tolist()
        return {
            "success": True,
            "embedding": vector,
            "dim": len(vector),
            "text": text,
            "model": MODEL_ID,
        }
    except Exception as e:
        return {"error": str(e), "success": False, "traceback": traceback.format_exc()}
| def api_embed_frames( | |
| frames_base64: list[str], | |
| timestamps: list[float] | None = None, | |
| ) -> dict[str, Any]: | |
| """ | |
| API endpoint for embedding multiple frames from EagleEye. | |
| Used for segment-level embeddings in video search. | |
| Args: | |
| frames_base64: List of base64 encoded frames | |
| timestamps: Optional timestamps for each frame | |
| Returns: | |
| JSON response with per-frame embeddings | |
| """ | |
| try: | |
| if not frames_base64: | |
| return {"error": "No frames provided", "success": False} | |
| model, processor, _ = load_model() | |
| frames = [] | |
| for frame_b64 in frames_base64: | |
| image_bytes = base64.b64decode(frame_b64) | |
| image = Image.open(io.BytesIO(image_bytes)).convert("RGB") | |
| frames.append(image) | |
| inputs = processor(images=frames, return_tensors="pt") | |
| inputs = {k: v.to(model.device) for k, v in inputs.items()} | |
| with torch.no_grad(): | |
| if hasattr(model, "get_image_features"): | |
| frame_embeddings = model.get_image_features(**inputs) | |
| else: | |
| outputs = model.vision_model(**inputs) | |
| frame_embeddings = outputs.pooler_output | |
| frame_embeddings = frame_embeddings / frame_embeddings.norm(dim=-1, keepdim=True) | |
| embeddings_list = frame_embeddings.cpu().numpy().tolist() | |
| pooled = frame_embeddings.mean(dim=0) | |
| pooled = pooled / pooled.norm() | |
| pooled_list = pooled.cpu().numpy().tolist() | |
| result = { | |
| "success": True, | |
| "frame_embeddings": embeddings_list, | |
| "pooled_embedding": pooled_list, | |
| "dim": len(pooled_list), | |
| "num_frames": len(frames), | |
| "model": MODEL_ID, | |
| } | |
| if timestamps: | |
| result["timestamps"] = timestamps | |
| return result | |
| except Exception as e: | |
| return {"error": str(e), "success": False, "traceback": traceback.format_exc()} | |
# ---------------------------------------------------------------------------
# Gradio UI: visible demo tabs plus a hidden row of components whose .change
# events are exposed (via api_name=...) as named endpoints for gradio_client.
# ---------------------------------------------------------------------------
with gr.Blocks(title="ViCLIP for Cadayn") as demo:
    gr.Markdown("""
# ViCLIP - Video-Text Embeddings
Powered by [ViCLIP-L-14](https://huggingface.co/OpenGVLab/ViCLIP-L-14-hf) on ZeroGPU.
**Capabilities:**
- Video segment embeddings (768-dim)
- Text query embeddings
- Temporal-aware video understanding
- Semantic video search
**API Endpoints for EagleEye:**
- `POST /call/api_embed_video` - Video segment embedding
- `POST /call/api_embed_text` - Text query embedding
- `POST /call/api_embed_frames` - Multi-frame embeddings
""")
    # Hidden API interfaces for EagleEye integration
    with gr.Row(visible=False):
        # Video embedding API
        api_vid_url = gr.Textbox()
        api_vid_b64 = gr.Textbox()
        # NOTE(review): gr.Number delivers floats — confirm the handler
        # coerces num_frames to int before it reaches range().
        api_vid_frames = gr.Number(value=8)
        api_vid_output = gr.JSON()
        # Text embedding API
        api_text_input = gr.Textbox()
        api_text_output = gr.JSON()
        # Frames embedding API
        # NOTE(review): api_embed_frames expects list[str] but a Textbox
        # delivers a single str — verify how EagleEye calls this endpoint.
        api_frames_input = gr.Textbox()
        api_frames_ts = gr.Textbox()
        api_frames_output = gr.JSON()
    # Registering .change with api_name makes each handler callable through
    # gradio_client at /api_embed_video, /api_embed_text, /api_embed_frames.
    api_vid_url.change(
        fn=api_embed_video,
        inputs=[api_vid_url, api_vid_b64, api_vid_frames],
        outputs=api_vid_output,
        api_name="api_embed_video",
    )
    api_text_input.change(
        fn=api_embed_text,
        inputs=[api_text_input],
        outputs=api_text_output,
        api_name="api_embed_text",
    )
    api_frames_input.change(
        fn=api_embed_frames,
        inputs=[api_frames_input, api_frames_ts],
        outputs=api_frames_output,
        api_name="api_embed_frames",
    )
    # Interactive demo: upload a video, sample N frames, show the embedding.
    with gr.Tab("Video Embedding"):
        with gr.Row():
            with gr.Column():
                video_input = gr.File(
                    label="Upload Video",
                    file_types=[".mp4", ".avi", ".mov", ".mkv", ".webm"],
                )
                num_frames_slider = gr.Slider(
                    minimum=4,
                    maximum=32,
                    value=8,
                    step=4,
                    label="Number of Frames",
                )
                video_btn = gr.Button("Get Embedding", variant="primary")
            with gr.Column():
                video_output = gr.Textbox(label="Embedding (768-dim)", lines=10)
        video_btn.click(
            fn=embed_video_file,
            inputs=[video_input, num_frames_slider],
            outputs=video_output,
        )
    # Interactive demo: embed a free-text query.
    with gr.Tab("Text Embedding"):
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(
                    label="Text Query",
                    placeholder="e.g., a person scoring a goal",
                    lines=2,
                )
                text_btn = gr.Button("Get Embedding", variant="primary")
            with gr.Column():
                text_output = gr.Textbox(label="Embedding (768-dim)", lines=10)
        text_btn.click(
            fn=embed_text_query,
            inputs=[text_input],
            outputs=text_output,
        )
    # Static documentation tab with gradio_client usage examples.
    with gr.Tab("API"):
        gr.Markdown("""
## API Usage for EagleEye Integration
### Video Embedding
```python
from gradio_client import Client
client = Client("Cadayn/viclip-zerogpu")
result = client.predict(
video_url="https://example.com/clip.mp4",
num_frames=8,
api_name="/api_embed_video"
)
print(result)
# {"success": True, "embedding": [...], "dim": 768, ...}
```
### Text Embedding
```python
result = client.predict(
text="a soccer player scoring a goal",
api_name="/api_embed_text"
)
print(result)
# {"success": True, "embedding": [...], "dim": 768, ...}
```
### Multi-Frame Embeddings
```python
result = client.predict(
frames_base64=["frame1_b64", "frame2_b64", ...],
timestamps=[0.0, 1.0, 2.0, ...],
api_name="/api_embed_frames"
)
print(result)
# {"success": True, "frame_embeddings": [[...], [...]], "pooled_embedding": [...], ...}
```
""")
# Launch the Space's Gradio app when run as a script.
if __name__ == "__main__":
    demo.launch()