# VidTower / app.py — HuggingFace Space entry point
# (last commit: 9b639e2 "Update app.py", verified, by MeshMax)
import os
import torch
import torch.nn as nn
import requests
from io import BytesIO
from PIL import Image
import timm
import gradio as gr
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from transformers import AutoTokenizer, AutoModel
from torchvision import transforms
from huggingface_hub import hf_hub_download # NEW
# --- Config ---
TEXT_MODEL = "sentence-transformers/LaBSE"
IMG_MODEL = "vit_base_patch16_224"
IMG_SIZE = 224
MAX_LENGTH = 512
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# --- Model definition (unchanged) ---
class MultimodalRegressor(nn.Module):
    """Fuse a text embedding and an image embedding with cross-attention,
    then regress a single scalar per sample.

    Both embeddings are projected into a shared space; the text projection
    attends over the image projection, and the attended vector feeds a
    small MLP head that outputs one value.
    """

    def __init__(self, text_dim=768, img_dim=768, proj_dim=768):
        super().__init__()
        # Per-modality projections into the shared fusion space.
        self.text_proj = nn.Linear(text_dim, proj_dim)
        self.img_proj = nn.Linear(img_dim, proj_dim)
        # Cross-attention: text queries attend over image keys/values.
        self.fusion_layer = nn.MultiheadAttention(
            embed_dim=proj_dim, num_heads=8, batch_first=True
        )
        self.dropout = nn.Dropout(0.1)
        # Two-layer MLP producing the scalar prediction.
        self.regressor = nn.Sequential(
            nn.Linear(proj_dim, proj_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(proj_dim // 2, 1),
        )

    def forward(self, text_emb, img_emb):
        # Add a singleton sequence dimension so MultiheadAttention sees
        # (batch, seq=1, proj_dim) tokens.
        query_tok = self.text_proj(text_emb).unsqueeze(1)
        image_tok = self.img_proj(img_emb).unsqueeze(1)
        attended, _ = self.fusion_layer(query=query_tok, key=image_tok, value=image_tok)
        fused = self.dropout(attended.squeeze(1))
        # (batch, 1) -> (batch,)
        return self.regressor(fused).squeeze(1)
# --- Load backbone models + head ---
# Text encoder + tokenizer come from the Hub; the ViT is created without
# pretrained weights (pretrained=False) because the fine-tuned state dict
# is loaded from the checkpoint below. num_classes=0 strips the classifier
# so the ViT emits pooled features.
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL)
text_model = AutoModel.from_pretrained(TEXT_MODEL).to(DEVICE)
img_model = timm.create_model(IMG_MODEL, pretrained=False, num_classes=0).to(DEVICE)
head = MultimodalRegressor().to(DEVICE)
# NEW: Dynamic load with cache
def load_model_if_needed():
    """Download the fine-tuned checkpoint from the Hub and return its local path.

    Uses the default huggingface_hub cache (~/.cache — persistent, not /tmp).
    On any download/cache failure, retries exactly once with force_download=True
    to overwrite a potentially corrupted cache entry.

    Returns:
        str: filesystem path to `finetuned_multimodal.pt`.
    """
    def _download(force):
        # Single definition of the download call (previously duplicated
        # verbatim in the try and the fallback branch).
        return hf_hub_download(
            repo_id="MeshMax/video_tower",
            filename="finetuned_multimodal.pt",
            local_dir=None,  # default cache dir (persistent, no /tmp)
            # NOTE(review): local_dir_use_symlinks is deprecated/ignored in
            # recent huggingface_hub releases — kept for old-version compat.
            local_dir_use_symlinks=False,
            cache_dir=None,
            force_download=force,
        )

    try:
        model_path = _download(force=False)
    except Exception as e:
        print(f"Download failed: {e}. Retrying with force_download...")
        # Fallback: force re-download if the cache is corrupted.
        model_path = _download(force=True)
    print(f"Model loaded from: {model_path}")
    return model_path
# Download (or reuse cached) checkpoint, then restore the fine-tuned weights.
model_path = load_model_if_needed()
# SECURITY NOTE(review): weights_only=False unpickles arbitrary objects from
# the checkpoint — acceptable only because the file comes from the project's
# own trusted repo (MeshMax/video_tower).
ckpt = torch.load(model_path, map_location=DEVICE, weights_only=False)
# Each state dict is optional so partial checkpoints still load.
if "text_model_state" in ckpt:
    text_model.load_state_dict(ckpt["text_model_state"])
if "img_model_state" in ckpt:
    img_model.load_state_dict(ckpt["img_model_state"])
if "head_state" in ckpt:
    head.load_state_dict(ckpt["head_state"])
# Inference only: switch off dropout etc. in all three modules.
text_model.eval()
img_model.eval()
head.eval()
# Thumbnail preprocessing for the ViT branch: resize to 224x224, convert to
# a [0, 1] tensor, then normalize with mean/std 0.5 to map pixels into [-1, 1].
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
def compute_embedding(title, description, tags, thumbnail_url):
    """Return the fused text+image embedding as a Python list of floats.

    Note: this intentionally returns the pre-regressor fused vector
    (length proj_dim=768), NOT the scalar regression output.

    Args:
        title, description, tags: metadata strings, concatenated for LaBSE.
        thumbnail_url: image URL; on any fetch/decode failure a neutral
            gray placeholder is used so the pipeline never crashes.
    """
    # -- Text branch: concatenated metadata through LaBSE --
    text = f"{title} {description} {tags}"
    toks = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=MAX_LENGTH).to(DEVICE)
    with torch.no_grad():
        out = text_model(**toks)
        # Prefer the pooled output; fall back to mean pooling when absent.
        if hasattr(out, "pooler_output") and out.pooler_output is not None:
            text_emb = out.pooler_output
        else:
            text_emb = out.last_hidden_state.mean(dim=1)
    # -- Image branch: fetch thumbnail, best-effort fallback on failure --
    try:
        img_resp = requests.get(thumbnail_url, timeout=5)
        img_resp.raise_for_status()  # raise on HTTP errors (4xx/5xx)
        img = Image.open(BytesIO(img_resp.content)).convert("RGB")
    except Exception as e:
        # Deliberate best-effort: log instead of silently swallowing, then
        # substitute a gray placeholder so embedding generation continues.
        print(f"Thumbnail fetch failed ({e}); using gray placeholder.")
        img = Image.new("RGB", (IMG_SIZE, IMG_SIZE), color=(128, 128, 128))
    img_tensor = transform(img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        img_emb = img_model(img_tensor)
        # -- Fusion: reuse the head's projection + cross-attention layers.
        # FIX: this block previously ran OUTSIDE torch.no_grad(), building an
        # unused autograd graph on every request (wasted memory/CPU).
        t_proj = head.text_proj(text_emb)
        i_proj = head.img_proj(img_emb)
        attn_out, _ = head.fusion_layer(
            query=t_proj.unsqueeze(1), key=i_proj.unsqueeze(1), value=i_proj.unsqueeze(1)
        )
        fused = attn_out.squeeze(1)
    return fused.squeeze(0).cpu().numpy().tolist()
# --- Keep everything up to compute_embedding() unchanged ---
# (Imports, config, model loading, transform, compute_embedding)
# --- NEW: Pure Gradio Interface + Launch (remove FastAPI entirely) ---
def gradio_fn(title, description, tags, thumbnail_url):
    """Gradio adapter: forward the four UI fields to compute_embedding().

    Returns the embedding as a list; Gradio stringifies it for the
    Textbox output component.
    """
    return compute_embedding(title, description, tags, thumbnail_url)
# Create the interface (same as before)
# Four free-text inputs map 1:1 onto gradio_fn's parameters; the single
# output Textbox shows the stringified 768-float embedding list.
demo = gr.Interface(  # Renamed from gr_interface for clarity
    fn=gradio_fn,
    inputs=[
        gr.Textbox(label="Title", placeholder="Enter video title..."),
        gr.Textbox(label="Description", placeholder="Enter video description..."),
        gr.Textbox(label="Tags", placeholder="Enter comma-separated tags..."),
        gr.Textbox(label="Thumbnail URL", placeholder="Enter image URL (e.g., https://example.com/thumb.jpg)...")
    ],
    outputs=gr.Textbox(label="Generated Embedding", lines=10),
    title="Video Embedding Generator",
    description="Generates fused multimodal embeddings from video metadata using LaBSE + ViT + Custom Fusion.",
    examples=[  # Optional: Add sample inputs for easy testing
        ["Test Video", "A sample description.", "ml, ai, video", "https://via.placeholder.com/224"]
    ],
    # NOTE(review): allow_flagging is deprecated in Gradio 4.x in favor of
    # flagging_mode — confirm against the pinned gradio version.
    allow_flagging="never"  # Disable if not needed
)
# NEW: Launch directly (blocks and serves everything)
if __name__ == "__main__":
    demo.queue(max_size=10)  # Enable queuing for concurrency (HF handles 10+ users)
    demo.launch(
        server_name="0.0.0.0",   # bind all interfaces (required inside the HF container)
        server_port=7860,        # HF Spaces' expected port
        share=False,             # HF provides public URL, no need for temp share
        show_api=True,           # Exposes /api endpoint automatically (see below)
        debug=False,             # Set True for more logs during testing
        favicon_path=None        # Optional: Add custom favicon later
    )