File size: 6,358 Bytes
ab0f28d a5705f1 ab0f28d d36dff5 ab0f28d a5705f1 ab0f28d d36dff5 ab0f28d a5705f1 ab0f28d 408b528 ab0f28d a5705f1 ab0f28d d36dff5 751e04c d36dff5 a5705f1 ab0f28d a5705f1 ab0f28d a5705f1 ab0f28d a5705f1 ab0f28d 408b528 a5705f1 408b528 d36dff5 408b528 a5705f1 408b528 a5705f1 408b528 a5705f1 d36dff5 ab0f28d 49705d6 ab0f28d 49705d6 ab0f28d a5705f1 9b639e2 ab0f28d 49705d6 ab0f28d 49705d6 a5705f1 49705d6 ab0f28d 49705d6 b2321f6 49705d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import os
import torch
import torch.nn as nn
import requests
from io import BytesIO
from PIL import Image
import timm
import gradio as gr
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from transformers import AutoTokenizer, AutoModel
from torchvision import transforms
from huggingface_hub import hf_hub_download # NEW
# --- Config ---
TEXT_MODEL = "sentence-transformers/LaBSE"  # multilingual text encoder; pooler output feeds text_dim=768 below
IMG_MODEL = "vit_base_patch16_224"  # timm ViT backbone name (classification head stripped at creation)
IMG_SIZE = 224  # input resolution expected by the patch16_224 ViT
MAX_LENGTH = 512  # tokenizer truncation length
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when available
# --- Model definition (unchanged) ---
class MultimodalRegressor(nn.Module):
    """Cross-attention fusion of text and image embeddings with an MLP scalar head.

    Both embeddings are projected into a shared space, the text vector attends
    to the image vector via single-token multi-head attention, and the fused
    representation is regressed to one scalar per batch item.
    """

    def __init__(self, text_dim=768, img_dim=768, proj_dim=768):
        super().__init__()
        # Attribute names are part of the public contract: compute_embedding()
        # and the saved state_dict both reference them directly.
        self.text_proj = nn.Linear(text_dim, proj_dim)
        self.img_proj = nn.Linear(img_dim, proj_dim)
        self.fusion_layer = nn.MultiheadAttention(embed_dim=proj_dim, num_heads=8, batch_first=True)
        self.dropout = nn.Dropout(0.1)
        self.regressor = nn.Sequential(
            nn.Linear(proj_dim, proj_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(proj_dim // 2, 1),
        )

    def forward(self, text_emb, img_emb):
        """Fuse (B, text_dim) and (B, img_dim) embeddings; return (B,) scores."""
        query = self.text_proj(text_emb).unsqueeze(1)      # (B, 1, proj_dim)
        key_value = self.img_proj(img_emb).unsqueeze(1)    # (B, 1, proj_dim)
        attended, _ = self.fusion_layer(query=query, key=key_value, value=key_value)
        fused = self.dropout(attended.squeeze(1))
        return self.regressor(fused).squeeze(1)
# --- Load backbone models + head ---
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL)
text_model = AutoModel.from_pretrained(TEXT_MODEL).to(DEVICE)
# pretrained=False: ViT weights come from the fine-tuned checkpoint loaded below.
# num_classes=0 removes the classification head so the model outputs pooled features.
img_model = timm.create_model(IMG_MODEL, pretrained=False, num_classes=0).to(DEVICE)
head = MultimodalRegressor().to(DEVICE)
# NEW: Dynamic load with cache
def load_model_if_needed():
    """Download the fine-tuned checkpoint from the Hub, using the default cache.

    Returns the local filesystem path to ``finetuned_multimodal.pt``. On the
    first failure (network hiccup, corrupted cache entry) it retries once with
    ``force_download=True``; a second failure re-raises the last error.

    Note: the deprecated ``local_dir_use_symlinks`` kwarg was dropped — it is a
    no-op when ``local_dir`` is None and newer huggingface_hub versions ignore it.
    """
    last_error = None
    for force in (False, True):
        try:
            model_path = hf_hub_download(
                repo_id="MeshMax/video_tower",
                filename="finetuned_multimodal.pt",
                local_dir=None,  # use default ~/.cache (persistent, no /tmp)
                cache_dir=None,
                force_download=force,  # second pass overwrites a corrupted cache entry
            )
            print(f"Model loaded from: {model_path}")
            return model_path
        except Exception as e:  # hf_hub_download raises several error types
            last_error = e
            if not force:
                print(f"Download failed: {e}. Retrying with force_download...")
    raise last_error
model_path = load_model_if_needed()
# SECURITY NOTE(review): weights_only=False unpickles arbitrary objects from the
# checkpoint — acceptable only because the repo is trusted. If the file holds
# pure state_dicts, prefer weights_only=True to close the pickle attack surface.
ckpt = torch.load(model_path, map_location=DEVICE, weights_only=False)
# The checkpoint may contain any subset of the three state dicts; load whichever exist.
if "text_model_state" in ckpt:
    text_model.load_state_dict(ckpt["text_model_state"])
if "img_model_state" in ckpt:
    img_model.load_state_dict(ckpt["img_model_state"])
if "head_state" in ckpt:
    head.load_state_dict(ckpt["head_state"])
# Inference mode: disables dropout and puts norm layers in eval behavior.
text_model.eval()
img_model.eval()
head.eval()
# Thumbnail preprocessing: resize to the ViT input size and scale pixels to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
def compute_embedding(title, description, tags, thumbnail_url):
    """Return the fused 768-d multimodal embedding for one video, as a Python list.

    Concatenates title/description/tags for the text encoder, fetches the
    thumbnail (falling back to a neutral gray placeholder on any fetch/decode
    error), then runs the head's projections + cross-attention to fuse the
    two modalities.

    Note: this deliberately returns the proj_dim=768 fused vector and skips
    head.regressor — it is an embedding endpoint, not the scalar prediction.
    """
    text = f"{title} {description} {tags}"
    toks = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=MAX_LENGTH).to(DEVICE)
    with torch.no_grad():
        out = text_model(**toks)
        # Prefer the pooled output; mean-pool token states for models without a pooler.
        if hasattr(out, "pooler_output") and out.pooler_output is not None:
            text_emb = out.pooler_output
        else:
            text_emb = out.last_hidden_state.mean(dim=1)
    try:
        img_resp = requests.get(thumbnail_url, timeout=5)
        img_resp.raise_for_status()  # IMPROVED: Raise on HTTP errors
        img = Image.open(BytesIO(img_resp.content)).convert("RGB")
    except Exception:
        # Best-effort: any network/decode failure degrades to a gray placeholder
        # rather than failing the request.
        img = Image.new("RGB", (IMG_SIZE, IMG_SIZE), color=(128, 128, 128))
    img_tensor = transform(img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        img_emb = img_model(img_tensor)
        # FIX: the projections and fusion now also run under no_grad — the
        # original tracked gradients here, building a useless autograd graph
        # (extra memory + compute) on every inference request.
        t_proj = head.text_proj(text_emb)
        i_proj = head.img_proj(img_emb)
        attn_out, _ = head.fusion_layer(
            query=t_proj.unsqueeze(1), key=i_proj.unsqueeze(1), value=i_proj.unsqueeze(1)
        )
        fused = attn_out.squeeze(1)
    return fused.squeeze(0).cpu().numpy().tolist()
# --- Gradio interface + launch (FastAPI layer removed; Gradio serves everything) ---
def gradio_fn(title, description, tags, thumbnail_url):
    """Gradio adapter: forward the four metadata fields to compute_embedding().

    Returns the embedding as a plain list; Gradio stringifies it for the
    output Textbox.
    """
    return compute_embedding(title, description, tags, thumbnail_url)
# Build the Gradio interface: four free-text inputs -> one embedding textbox.
demo = gr.Interface(
    fn=gradio_fn,
    inputs=[
        gr.Textbox(label="Title", placeholder="Enter video title..."),
        gr.Textbox(label="Description", placeholder="Enter video description..."),
        gr.Textbox(label="Tags", placeholder="Enter comma-separated tags..."),
        gr.Textbox(label="Thumbnail URL", placeholder="Enter image URL (e.g., https://example.com/thumb.jpg)...")
    ],
    outputs=gr.Textbox(label="Generated Embedding", lines=10),
    title="Video Embedding Generator",
    description="Generates fused multimodal embeddings from video metadata using LaBSE + ViT + Custom Fusion.",
    examples=[  # sample inputs for quick manual testing
        ["Test Video", "A sample description.", "ml, ai, video", "https://via.placeholder.com/224"]
    ],
    # NOTE(review): allow_flagging was deprecated (renamed to flagging_mode) in
    # newer Gradio releases — confirm the pinned Gradio version accepts this kwarg.
    allow_flagging="never"  # Disable if not needed
)
# NEW: Launch directly (blocks and serves everything)
if __name__ == "__main__":
    demo.queue(max_size=10)  # Enable queuing for concurrency (HF handles 10+ users)
    demo.launch(
        server_name="0.0.0.0",  # bind all interfaces so the Space's proxy can reach the server
        server_port=7860,  # standard Hugging Face Spaces port
        share=False,  # HF provides public URL, no need for temp share
        show_api=True,  # Exposes /api endpoint automatically
        debug=False,  # Set True for more logs during testing
        favicon_path=None  # Optional: Add custom favicon later
    )