# app.py — Gradio Space by bhuvan-2005 (commit b494970, verified)
import gradio as gr
import torch
import numpy as np
import cv2
import tempfile
from PIL import Image
from transformers import AutoProcessor, AutoModel
# Load model once at import time so both Gradio handlers share a single
# model instance instead of reloading per request.
device = "cuda" if torch.cuda.is_available() else "cpu"
# SigLIP so400m @ 384px: joint image/text embedding model; the processor
# handles both text tokenization and image preprocessing.
processor = AutoProcessor.from_pretrained("google/siglip-so400m-patch14-384")
# .eval() switches off training-mode layers (dropout etc.); inference only.
model = AutoModel.from_pretrained("google/siglip-so400m-patch14-384").to(device).eval()
def encode_text(text):
    """Encode a text query into an L2-normalized SigLIP embedding.

    Args:
        text: Query string. Falsy input (empty string / None) yields None.

    Returns:
        list[float] | None: Unit-norm embedding vector as a plain Python
        list (JSON-serializable for the Gradio API), or None for empty input.
    """
    if not text:
        return None
    with torch.no_grad():
        # Bug fix: SigLIP was trained with text padded to max_length;
        # padding=True (pad-to-longest) produces degraded embeddings.
        # The model card explicitly requires padding="max_length".
        inputs = processor(text=[text], return_tensors="pt", padding="max_length")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        emb = model.get_text_features(**inputs)
        # Normalize so downstream similarity is a plain dot product.
        emb = emb / emb.norm(dim=-1, keepdim=True)
        return emb[0].cpu().numpy().tolist()
def encode_video(video_path):
    """Sample ~1 frame per second from a video and encode each frame with SigLIP.

    Args:
        video_path: Path to a video file readable by OpenCV. Falsy yields None.

    Returns:
        dict | None: {"embeddings": list of L2-normalized vectors,
        "timestamps": seconds offset of each sampled frame}, or None when
        the path is falsy, the FPS metadata is unreadable, or no frames decode.
    """
    if not video_path:
        return None
    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            # Unreadable/invalid metadata — give up on this file.
            return None
        interval = max(1, int(fps))  # keep ~1 frame per second
        frames, timestamps = [], []
        idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if idx % interval == 0:
                # OpenCV decodes BGR; PIL and the processor expect RGB.
                frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
                timestamps.append(idx / fps)
            idx += 1
    finally:
        # Bug fix: the capture previously leaked on the fps <= 0 early return;
        # releasing in finally covers every exit path.
        cap.release()
    if not frames:
        return None
    # Encode in small batches to bound peak memory.
    embeddings = []
    with torch.no_grad():
        for i in range(0, len(frames), 8):
            batch = frames[i:i + 8]
            inputs = processor(images=batch, return_tensors="pt", padding=True)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            emb = model.get_image_features(**inputs)
            # Unit-normalize so similarity is a plain dot product.
            emb = emb / emb.norm(dim=-1, keepdim=True)
            embeddings.extend(emb.cpu().numpy().tolist())
    return {"embeddings": embeddings, "timestamps": timestamps}
# Gradio interface: two tabs exposing the encoders as simple JSON endpoints.
with gr.Blocks() as demo:
    gr.Markdown("# Video Search API")

    with gr.Tab("Encode Text"):
        query_box = gr.Textbox(label="Query")
        embedding_json = gr.JSON(label="Embedding")
        text_btn = gr.Button("Encode")
        text_btn.click(encode_text, query_box, embedding_json)

    with gr.Tab("Encode Video"):
        video_in = gr.Video(label="Video")
        result_json = gr.JSON(label="Embeddings + Timestamps")
        video_btn = gr.Button("Encode")
        video_btn.click(encode_video, video_in, result_json)

demo.launch()