import os
import torch
import torch.nn as nn
import requests
from io import BytesIO
from PIL import Image
import timm
import gradio as gr
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from transformers import AutoTokenizer, AutoModel
from torchvision import transforms
from huggingface_hub import hf_hub_download  # NEW

# --- Config ---
TEXT_MODEL = "sentence-transformers/LaBSE"
IMG_MODEL = "vit_base_patch16_224"
IMG_SIZE = 224
MAX_LENGTH = 512
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# --- Model definition (unchanged) ---
class MultimodalRegressor(nn.Module):
    """Cross-attention fusion of a text embedding and an image embedding,
    followed by a small MLP head that regresses one scalar per sample.

    The text embedding acts as the attention query; the image embedding
    supplies both key and value.
    """

    def __init__(self, text_dim=768, img_dim=768, proj_dim=768):
        super().__init__()
        self.text_proj = nn.Linear(text_dim, proj_dim)
        self.img_proj = nn.Linear(img_dim, proj_dim)
        self.fusion_layer = nn.MultiheadAttention(embed_dim=proj_dim, num_heads=8, batch_first=True)
        self.dropout = nn.Dropout(0.1)
        self.regressor = nn.Sequential(
            nn.Linear(proj_dim, proj_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(proj_dim // 2, 1),
        )

    def forward(self, text_emb, img_emb):
        # Each modality becomes a length-1 sequence so MultiheadAttention
        # (batch_first=True) can fuse them: query=text, key/value=image.
        t = self.text_proj(text_emb).unsqueeze(1)
        i = self.img_proj(img_emb).unsqueeze(1)
        attn_out, _ = self.fusion_layer(query=t, key=i, value=i)
        fused = attn_out.squeeze(1)
        fused = self.dropout(fused)
        return self.regressor(fused).squeeze(1)


# --- Load backbone models + head ---
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL)
text_model = AutoModel.from_pretrained(TEXT_MODEL).to(DEVICE)
img_model = timm.create_model(IMG_MODEL, pretrained=False, num_classes=0).to(DEVICE)
head = MultimodalRegressor().to(DEVICE)


# NEW: Dynamic load with cache
def load_model_if_needed():
    """Download the finetuned checkpoint from the HF Hub and return its local path.

    Uses the default ~/.cache hub cache (persistent across restarts). On any
    download error, retries once with force_download=True in case the cached
    copy is corrupted.
    """
    try:
        model_path = hf_hub_download(
            repo_id="MeshMax/video_tower",
            filename="finetuned_multimodal.pt",
            local_dir=None,  # CHANGED: Use default ~/.cache (persistent, no /tmp)
            local_dir_use_symlinks=False,  # NOTE(review): deprecated in recent huggingface_hub — confirm the pinned version still accepts it
            cache_dir=None,
        )
        print(f"Model loaded from: {model_path}")
        return model_path
    except Exception as e:
        # BUGFIX: this message was split by a raw newline inside the string
        # literal (a SyntaxError); rejoined onto one line.
        print(f"Download failed: {e}. Retrying with force_download...")
        # Fallback: Force re-download if cache is corrupted
        model_path = hf_hub_download(
            repo_id="MeshMax/video_tower",
            filename="finetuned_multimodal.pt",
            local_dir=None,
            local_dir_use_symlinks=False,
            cache_dir=None,
            force_download=True,  # Overwrite if needed
        )
        return model_path


model_path = load_model_if_needed()
# SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
# objects from the checkpoint — acceptable only because the repo is trusted.
# Prefer weights_only=True if the checkpoint holds plain state dicts.
ckpt = torch.load(model_path, map_location=DEVICE, weights_only=False)
# Each sub-state is optional so partial checkpoints still load.
if "text_model_state" in ckpt:
    text_model.load_state_dict(ckpt["text_model_state"])
if "img_model_state" in ckpt:
    img_model.load_state_dict(ckpt["img_model_state"])
if "head_state" in ckpt:
    head.load_state_dict(ckpt["head_state"])

# Inference only: disable dropout / switch norm layers to eval statistics.
text_model.eval()
img_model.eval()
head.eval()

# ViT preprocessing: resize to the model's input size, scale to [0,1],
# then normalize each channel to roughly [-1, 1].
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


def compute_embedding(title, description, tags, thumbnail_url):
    """Return the fused multimodal embedding as a list of proj_dim (768) floats.

    Title, description and tags are concatenated and encoded with LaBSE; the
    thumbnail is fetched over HTTP and encoded with the ViT. Both embeddings
    are projected and fused through the head's cross-attention layer. If the
    thumbnail cannot be fetched or decoded, a flat grey placeholder image is
    used so the call still succeeds.
    """
    text = f"{title} {description} {tags}"
    toks = tokenizer(text, return_tensors="pt", truncation=True, padding=True,
                     max_length=MAX_LENGTH).to(DEVICE)
    with torch.no_grad():
        out = text_model(**toks)
        if hasattr(out, "pooler_output") and out.pooler_output is not None:
            text_emb = out.pooler_output
        else:
            # Fallback for encoders without a pooler: mean over token states.
            text_emb = out.last_hidden_state.mean(dim=1)

    try:
        img_resp = requests.get(thumbnail_url, timeout=5)
        img_resp.raise_for_status()  # IMPROVED: Raise on HTTP errors
        img = Image.open(BytesIO(img_resp.content)).convert("RGB")
    except Exception:
        # Best effort: degrade to a neutral grey image rather than failing.
        img = Image.new("RGB", (IMG_SIZE, IMG_SIZE), color=(128, 128, 128))

    img_tensor = transform(img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        img_emb = img_model(img_tensor)
        # BUGFIX: the projection + fusion below previously ran OUTSIDE
        # torch.no_grad(), needlessly building an autograd graph (and holding
        # activations) on every inference call. Same math, now grad-free.
        t_proj = head.text_proj(text_emb)
        i_proj = head.img_proj(img_emb)
        attn_out, _ = head.fusion_layer(
            query=t_proj.unsqueeze(1),
            key=i_proj.unsqueeze(1),
            value=i_proj.unsqueeze(1),
        )
        fused = attn_out.squeeze(1)
    # Note: This is proj_dim=768, not 1—adjust if regression output
    return fused.squeeze(0).cpu().numpy().tolist()

# --- Keep everything up to compute_embedding() unchanged ---
# (Imports, config, model loading, transform, compute_embedding)

# --- NEW: Pure Gradio Interface + Launch (remove FastAPI entirely) ---

def gradio_fn(title, description, tags, thumbnail_url):
    """UI adapter: delegate straight to compute_embedding.

    Returns the full embedding list; Gradio stringifies it cleanly for the
    output textbox.
    """
    return compute_embedding(title, description, tags, thumbnail_url)


# The four metadata fields, in the exact order gradio_fn expects them.
_input_widgets = [
    gr.Textbox(label="Title", placeholder="Enter video title..."),
    gr.Textbox(label="Description", placeholder="Enter video description..."),
    gr.Textbox(label="Tags", placeholder="Enter comma-separated tags..."),
    gr.Textbox(label="Thumbnail URL", placeholder="Enter image URL (e.g., https://example.com/thumb.jpg)..."),
]

# One ready-made sample row so the app can be smoke-tested with a click.
_examples = [
    ["Test Video", "A sample description.", "ml, ai, video", "https://via.placeholder.com/224"],
]

demo = gr.Interface(
    fn=gradio_fn,
    inputs=_input_widgets,
    outputs=gr.Textbox(label="Generated Embedding", lines=10),
    title="Video Embedding Generator",
    description="Generates fused multimodal embeddings from video metadata using LaBSE + ViT + Custom Fusion.",
    examples=_examples,
    allow_flagging="never",  # flagging is not useful for this demo
)

# NEW: Launch directly (blocks and serves everything)
if __name__ == "__main__":
    # Queue incoming requests so concurrent users are served in order.
    demo.queue(max_size=10)
    demo.launch(
        server_name="0.0.0.0",  # bind all interfaces (required on hosted Spaces)
        server_port=7860,
        share=False,            # host provides the public URL; no temp share link
        show_api=True,          # expose the auto-generated /api endpoint
        debug=False,            # flip to True for verbose logs while testing
        favicon_path=None,      # Optional: Add custom favicon later
    )