|
|
import os |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import requests |
|
|
from io import BytesIO |
|
|
from PIL import Image |
|
|
import timm |
|
|
import gradio as gr |
|
|
from fastapi import FastAPI, Request |
|
|
from fastapi.responses import JSONResponse |
|
|
from transformers import AutoTokenizer, AutoModel |
|
|
from torchvision import transforms |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
|
# Hugging Face model id for the text encoder (multilingual sentence embeddings).
TEXT_MODEL = "sentence-transformers/LaBSE"

# timm architecture name for the image encoder.
IMG_MODEL = "vit_base_patch16_224"

# Side length (pixels) of the square input expected by the ViT.
IMG_SIZE = 224

# Token-truncation limit passed to the tokenizer.
MAX_LENGTH = 512

# Run on GPU when available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
|
class MultimodalRegressor(nn.Module):
    """Fuse a text and an image embedding via cross-attention, regress a scalar.

    The projected text embedding attends over the projected image embedding;
    the attended vector is passed through a small MLP head that outputs one
    value per sample.
    """

    def __init__(self, text_dim=768, img_dim=768, proj_dim=768):
        super().__init__()
        # Project both modalities into a shared space before fusion.
        self.text_proj = nn.Linear(text_dim, proj_dim)
        self.img_proj = nn.Linear(img_dim, proj_dim)
        # Cross-attention: text acts as query, image as key/value.
        self.fusion_layer = nn.MultiheadAttention(
            embed_dim=proj_dim, num_heads=8, batch_first=True
        )
        self.dropout = nn.Dropout(0.1)
        hidden = proj_dim // 2
        self.regressor = nn.Sequential(
            nn.Linear(proj_dim, hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden, 1),
        )

    def forward(self, text_emb, img_emb):
        """Return a (batch,) tensor of regression scores.

        Args:
            text_emb: (batch, text_dim) text embeddings.
            img_emb: (batch, img_dim) image embeddings.
        """
        # Each modality becomes a length-1 sequence for MultiheadAttention.
        query = self.text_proj(text_emb).unsqueeze(1)
        key_value = self.img_proj(img_emb).unsqueeze(1)
        attended, _ = self.fusion_layer(query=query, key=key_value, value=key_value)
        fused = self.dropout(attended.squeeze(1))
        # (batch, 1) -> (batch,)
        return self.regressor(fused).squeeze(1)
|
|
|
|
|
|
|
|
# Instantiate the encoders and fusion head once at import time.
# NOTE: from_pretrained may download weights on first run (network side effect).
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL)
text_model = AutoModel.from_pretrained(TEXT_MODEL).to(DEVICE)
# pretrained=False, num_classes=0: no ImageNet head, the backbone returns
# features; finetuned weights are restored from the downloaded checkpoint.
img_model = timm.create_model(IMG_MODEL, pretrained=False, num_classes=0).to(DEVICE)
head = MultimodalRegressor().to(DEVICE)
|
|
|
|
|
|
|
|
def load_model_if_needed():
    """Download the finetuned multimodal checkpoint from the Hugging Face Hub.

    Tries a normal, cache-friendly download first; if that fails for any
    reason (transient network error, corrupted cache entry), retries once
    with force_download=True to bypass the local cache.

    Returns:
        str: local filesystem path to the downloaded checkpoint file.

    Raises:
        Whatever hf_hub_download raises if the forced retry also fails.
    """
    # Single source of truth for the download parameters (previously the
    # entire call was duplicated verbatim in the except branch).
    download_kwargs = {
        "repo_id": "MeshMax/video_tower",
        "filename": "finetuned_multimodal.pt",
        "local_dir": None,
        "local_dir_use_symlinks": False,
        "cache_dir": None,
    }
    try:
        model_path = hf_hub_download(**download_kwargs)
    except Exception as e:
        print(f"Download failed: {e}. Retrying with force_download...")
        # Re-download from scratch in case the first failure came from a
        # corrupted cached file.
        model_path = hf_hub_download(**download_kwargs, force_download=True)
    print(f"Model loaded from: {model_path}")
    return model_path
|
|
|
|
|
# --- Restore finetuned weights into each component from the checkpoint. ---
model_path = load_model_if_needed()
# NOTE(review): weights_only=False unpickles arbitrary Python objects from the
# downloaded file — acceptable only because the Hub repo is trusted; confirm.
ckpt = torch.load(model_path, map_location=DEVICE, weights_only=False)
# The checkpoint may contain any subset of the three state dicts; components
# without a saved state keep their freshly-initialized weights.
if "text_model_state" in ckpt:
    text_model.load_state_dict(ckpt["text_model_state"])
if "img_model_state" in ckpt:
    img_model.load_state_dict(ckpt["img_model_state"])
if "head_state" in ckpt:
    head.load_state_dict(ckpt["head_state"])

# Inference only: switch off dropout and other train-time behavior.
text_model.eval()
img_model.eval()
head.eval()
|
|
|
|
|
# Preprocessing for the ViT branch: resize to the model's expected square
# input, convert to a [0, 1] tensor, then scale each channel to [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)
|
|
|
|
|
def compute_embedding(title, description, tags, thumbnail_url):
    """Compute a fused text+image embedding for one video.

    Args:
        title: video title (free text).
        description: video description (free text).
        tags: tag string; concatenated with title/description for the encoder.
        thumbnail_url: URL of the thumbnail image. On any fetch or decode
            failure a neutral grey placeholder is used instead, so a bad URL
            never makes the endpoint fail.

    Returns:
        list[float]: the fused embedding (one value per fusion dimension).
    """
    text = f"{title} {description} {tags}"
    toks = tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=MAX_LENGTH
    ).to(DEVICE)

    # Fetch the thumbnail; fall back to a grey placeholder on any failure.
    try:
        img_resp = requests.get(thumbnail_url, timeout=5)
        img_resp.raise_for_status()
        img = Image.open(BytesIO(img_resp.content)).convert("RGB")
    except Exception:
        img = Image.new("RGB", (IMG_SIZE, IMG_SIZE), color=(128, 128, 128))
    img_tensor = transform(img).unsqueeze(0).to(DEVICE)

    # Run the entire inference path under no_grad: previously the projection
    # and fusion steps ran outside the no_grad blocks, needlessly building an
    # autograd graph for inference-only work.
    with torch.no_grad():
        out = text_model(**toks)
        if getattr(out, "pooler_output", None) is not None:
            text_emb = out.pooler_output
        else:
            # Fallback: mean-pool the token states when the model exposes no
            # pooled output.
            text_emb = out.last_hidden_state.mean(dim=1)

        img_emb = img_model(img_tensor)

        # Reuse the head's projections + cross-attention (text queries image),
        # but stop before dropout/regressor: the fused vector IS the output.
        t_proj = head.text_proj(text_emb).unsqueeze(1)
        i_proj = head.img_proj(img_emb).unsqueeze(1)
        attn_out, _ = head.fusion_layer(query=t_proj, key=i_proj, value=i_proj)
        fused = attn_out.squeeze(1)

    return fused.squeeze(0).cpu().numpy().tolist()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gradio_fn(title, description, tags, thumbnail_url):
    """Adapter between the Gradio form fields and compute_embedding."""
    return compute_embedding(title, description, tags, thumbnail_url)
|
|
|
|
|
|
|
|
# Gradio UI: four free-text inputs (metadata + thumbnail URL) mapped to a
# single textbox showing the generated embedding.
demo = gr.Interface(
    fn=gradio_fn,
    inputs=[
        gr.Textbox(label="Title", placeholder="Enter video title..."),
        gr.Textbox(label="Description", placeholder="Enter video description..."),
        gr.Textbox(label="Tags", placeholder="Enter comma-separated tags..."),
        gr.Textbox(label="Thumbnail URL", placeholder="Enter image URL (e.g., https://example.com/thumb.jpg)...")
    ],
    outputs=gr.Textbox(label="Generated Embedding", lines=10),
    title="Video Embedding Generator",
    description="Generates fused multimodal embeddings from video metadata using LaBSE + ViT + Custom Fusion.",
    # Example row shown below the form; clicking it fills the inputs.
    examples=[
        ["Test Video", "A sample description.", "ml, ai, video", "https://via.placeholder.com/224"]
    ],
    # NOTE(review): allow_flagging was deprecated/renamed (flagging_mode) in
    # Gradio 4.x+ — keep as-is unless the installed Gradio version requires it.
    allow_flagging="never"
)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Bound the request queue, then serve on all interfaces at port 7860.
    demo.queue(max_size=10)
    launch_opts = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_api": True,
        "debug": False,
        "favicon_path": None,
    }
    demo.launch(**launch_opts)