Hugging Face Spaces page capture (Space status: Sleeping)
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import cv2 | |
| import tempfile | |
| from PIL import Image | |
| from transformers import AutoProcessor, AutoModel | |
# Load model on startup (module import time): pick GPU when available, then pull
# the SigLIP checkpoint used for both text and image embeddings below.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Processor handles both tokenization (text) and image preprocessing for SigLIP.
processor = AutoProcessor.from_pretrained("google/siglip-so400m-patch14-384")
# eval() disables dropout etc.; all inference below runs under torch.no_grad().
model = AutoModel.from_pretrained("google/siglip-so400m-patch14-384").to(device).eval()
def encode_text(text):
    """Encode a text query into a unit-norm SigLIP embedding.

    Args:
        text: Query string; empty or None yields None.

    Returns:
        list[float]: L2-normalized embedding vector, or None for empty input.
    """
    if not text:
        return None
    with torch.no_grad():
        # SigLIP was trained with padding="max_length" (64 tokens); the HF model
        # card warns that plain padding=True produces degraded text embeddings.
        inputs = processor(
            text=[text], return_tensors="pt", padding="max_length", truncation=True
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        emb = model.get_text_features(**inputs)
        # Normalize so downstream similarity is a plain dot product / cosine.
        emb = emb / emb.norm(dim=-1, keepdim=True)
        return emb[0].cpu().numpy().tolist()
def encode_video(video_path):
    """Sample a video at ~1 frame/second and encode frames with SigLIP.

    Args:
        video_path: Path to a video file readable by OpenCV; falsy yields None.

    Returns:
        dict: {"embeddings": list of unit-norm vectors,
               "timestamps": list of frame offsets in seconds},
        or None if the path is empty, the video is unreadable, or no frames
        were decoded.
    """
    if not video_path:
        return None
    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            # Unreadable file or missing FPS metadata — bail out.
            # (Original leaked the capture here; finally now releases it.)
            return None
        interval = max(1, int(fps))  # keep roughly 1 frame per second
        frames, timestamps = [], []
        idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if idx % interval == 0:
                # OpenCV decodes BGR; the SigLIP processor expects RGB images.
                frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
                timestamps.append(idx / fps)
            idx += 1
    finally:
        cap.release()  # always release the capture, even on early return
    if not frames:
        return None
    # Encode in batches of 8 to bound peak memory on GPU/CPU.
    embeddings = []
    with torch.no_grad():
        for i in range(0, len(frames), 8):
            batch = frames[i:i + 8]
            inputs = processor(images=batch, return_tensors="pt", padding=True)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            emb = model.get_image_features(**inputs)
            # Unit-normalize so similarity search is a plain dot product.
            emb = emb / emb.norm(dim=-1, keepdim=True)
            embeddings.extend(emb.cpu().numpy().tolist())
    return {"embeddings": embeddings, "timestamps": timestamps}
# Gradio interface: two tabs exposing the text and video encoders as JSON APIs.
with gr.Blocks() as demo:
    gr.Markdown("# Video Search API")

    with gr.Tab("Encode Text"):
        query_box = gr.Textbox(label="Query")
        text_emb_json = gr.JSON(label="Embedding")
        text_btn = gr.Button("Encode")
        text_btn.click(encode_text, query_box, text_emb_json)

    with gr.Tab("Encode Video"):
        clip_input = gr.Video(label="Video")
        clip_emb_json = gr.JSON(label="Embeddings + Timestamps")
        video_btn = gr.Button("Encode")
        video_btn.click(encode_video, clip_input, clip_emb_json)

demo.launch()