# app.py
import os

import requests
import torch
import torch.nn as nn
from io import BytesIO
from PIL import Image
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import gradio as gr
from transformers import AutoTokenizer, AutoModel
import timm
from torchvision import transforms
# -----------------------
# Config (mirrors the training notebook)
# -----------------------
MODEL_FILENAME = "finetuned_multimodal.pt" # upload this to your Space
TEXT_MODEL = "sentence-transformers/LaBSE"
IMG_MODEL = "vit_base_patch16_224"
IMG_SIZE = 224
MAX_LENGTH = 512
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# -----------------------
# Model class (exact from your notebook)
# -----------------------
class MultimodalRegressor(nn.Module):
    def __init__(self, text_dim=768, img_dim=768, proj_dim=768):  # keep dims consistent with training
        super().__init__()
        self.text_proj = nn.Linear(text_dim, proj_dim)
        self.img_proj = nn.Linear(img_dim, proj_dim)
        # batch_first=True per your notebook
        self.fusion_layer = nn.MultiheadAttention(embed_dim=proj_dim, num_heads=8, batch_first=True)
        self.dropout = nn.Dropout(0.1)
        self.regressor = nn.Sequential(
            nn.Linear(proj_dim, proj_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(proj_dim // 2, 1)
        )

    def forward(self, text_emb, img_emb):
        # Treat each embedding as a length-1 sequence for cross-attention
        t = self.text_proj(text_emb).unsqueeze(1)  # (batch, 1, proj_dim)
        i = self.img_proj(img_emb).unsqueeze(1)    # (batch, 1, proj_dim)
        # Text queries attend over the image embedding
        attn_out, _ = self.fusion_layer(query=t, key=i, value=i)
        fused = attn_out.squeeze(1)                # (batch, proj_dim)
        fused = self.dropout(fused)
        return self.regressor(fused).squeeze(1)    # (batch,)
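
# A minimal shape check for the head (a local-debugging sketch with made-up
# dummy tensors; left commented out so it never runs inside the Space):
#
#   head = MultimodalRegressor()
#   t = torch.randn(2, 768)   # fake text embeddings, (batch, text_dim)
#   i = torch.randn(2, 768)   # fake image embeddings, (batch, img_dim)
#   print(head(t, i).shape)   # expected: torch.Size([2])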
# -----------------------
# Utilities: image transform & helpers
# -----------------------
img_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

def load_image_from_url(url):
    try:
        resp = requests.get(url, timeout=6)
        resp.raise_for_status()
        img = Image.open(BytesIO(resp.content)).convert("RGB")
        return img
    except Exception:
        # Return a gray fallback image if the thumbnail fetch fails
        return Image.new("RGB", (IMG_SIZE, IMG_SIZE), color=(128, 128, 128))
def text_to_embedding(tokenizer, text_model, texts):
    # texts: list[str] (a batch); returns a tensor of shape (batch, text_dim)
    text_model.eval()
    with torch.no_grad():
        toks = tokenizer(texts, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt")
        toks = {k: v.to(DEVICE) for k, v in toks.items()}
        out = text_model(**toks)
        # Prefer pooler_output if available, else mean-pool the last hidden state
        if hasattr(out, "pooler_output") and out.pooler_output is not None:
            emb = out.pooler_output
        else:
            last = out.last_hidden_state  # (batch, seq, dim)
            emb = last.mean(dim=1)
    return emb  # already on DEVICE
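
# Example (hypothetical input): with LaBSE, whose hidden size is 768,
#   text_to_embedding(tokenizer, text_model, ["hello world"])
# returns a tensor of shape (1, 768).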
# -----------------------
# Load pretrained backbone models + head; load checkpoint
# -----------------------
print("Device:", DEVICE)
print("Loading tokenizer and text model:", TEXT_MODEL)
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL)
text_model = AutoModel.from_pretrained(TEXT_MODEL).to(DEVICE)
print("Creating image model:", IMG_MODEL)
# create_model(..., num_classes=0) returns features vector for many timm models
img_model = timm.create_model(IMG_MODEL, pretrained=False, num_classes=0).to(DEVICE)
multimodal_head = MultimodalRegressor().to(DEVICE)
# Load checkpoint (robust to different key names)
if not os.path.exists(MODEL_FILENAME):
    print(f"WARNING: {MODEL_FILENAME} not found in the Space. Place your checkpoint at the repository root.")
else:
    print("Loading checkpoint:", MODEL_FILENAME)
    ckpt = torch.load(MODEL_FILENAME, map_location=DEVICE)
    # Expected keys from the notebook: 'text_model_state', 'img_model_state', 'head_state'
    # (older key names are accepted as a fallback)
    if "text_model_state" in ckpt:
        text_model.load_state_dict(ckpt["text_model_state"])
    elif "text_state_dict" in ckpt:
        text_model.load_state_dict(ckpt["text_state_dict"])
    else:
        print("No text_model_state found in checkpoint (skipping).")
    if "img_model_state" in ckpt:
        img_model.load_state_dict(ckpt["img_model_state"])
    elif "img_state_dict" in ckpt:
        img_model.load_state_dict(ckpt["img_state_dict"])
    else:
        print("No img_model_state found in checkpoint (skipping).")
    if "head_state" in ckpt:
        multimodal_head.load_state_dict(ckpt["head_state"])
    elif "head_state_dict" in ckpt:
        multimodal_head.load_state_dict(ckpt["head_state_dict"])
    else:
        print("No head_state found in checkpoint (skipping).")
text_model.eval()
img_model.eval()
multimodal_head.eval()
print("Models ready.")
# -----------------------
# Inference: create the fused embedding (same pipeline the notebook used)
# -----------------------
def compute_fused_embedding(title: str, description: str, tags: str, thumbnail_url: str):
    # Build text and image inputs
    text = " ".join([str(title or ""), str(description or ""), str(tags or "")]).strip()
    texts = [text]
    # Text embedding (batch of 1)
    t_emb = text_to_embedding(tokenizer, text_model, texts)  # shape (1, text_dim)
    # Image embedding: preprocess and forward
    img = load_image_from_url(thumbnail_url)
    img_tensor = img_transform(img).unsqueeze(0).to(DEVICE)  # (1, 3, H, W)
    with torch.no_grad():
        i_emb = img_model(img_tensor)  # expected shape (1, img_dim)
        # Project, then fuse via the head's attention layer (exactly as in the notebook);
        # kept inside no_grad so inference builds no autograd graph
        t_proj = multimodal_head.text_proj(t_emb)  # (1, proj_dim)
        i_proj = multimodal_head.img_proj(i_emb)   # (1, proj_dim)
        # MultiheadAttention expects (batch, seq, dim) because batch_first=True
        attn_out, _ = multimodal_head.fusion_layer(
            query=t_proj.unsqueeze(1),  # (1, 1, proj_dim)
            key=i_proj.unsqueeze(1),    # (1, 1, proj_dim)
            value=i_proj.unsqueeze(1)   # (1, 1, proj_dim)
        )
    fused = attn_out.squeeze(1)  # (1, proj_dim)
    fused_np = fused.squeeze(0).cpu().numpy().tolist()  # (proj_dim,) -> plain list
    return fused_np
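
# Example call (hypothetical URL; an unreachable thumbnail falls back to the
# gray placeholder image above):
#   emb = compute_fused_embedding("Cute cat", "A cat doing flips", "cat,funny",
#                                 "https://example.com/sample.jpg")
#   len(emb)  # 768 (proj_dim)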
# -----------------------
# FastAPI + Gradio integration
# -----------------------
app = FastAPI()

@app.post("/api/get_embedding")
async def api_get_embedding(request: Request):
    payload = await request.json()
    title = payload.get("title", "")
    description = payload.get("description", "")
    tags = payload.get("tags", "")
    thumbnail_url = payload.get("thumbnail_url", "")
    try:
        emb = compute_fused_embedding(title, description, tags, thumbnail_url)
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
    return JSONResponse({"embedding": emb})
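
# Example request (a sketch; substitute your Space's hostname for the placeholder):
#   curl -X POST https://<your-space>.hf.space/api/get_embedding \
#        -H "Content-Type: application/json" \
#        -d '{"title": "Cute cat", "description": "A cat doing flips",
#             "tags": "cat,funny", "thumbnail_url": "https://example.com/sample.jpg"}'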
# Gradio UI for quick testing (truncated embedding shown)
def gradio_fn(title, description, tags, thumbnail_url):
    try:
        emb = compute_fused_embedding(title, description, tags, thumbnail_url)
        return f"embedding (len={len(emb)}): {emb[:10]} ... (truncated)"
    except Exception as e:
        return f"Error: {e}"

gr_interface = gr.Interface(
    fn=gradio_fn,
    inputs=[
        gr.Textbox(label="Title", lines=1),
        gr.Textbox(label="Description", lines=3),
        gr.Textbox(label="Tags", lines=1),
        gr.Textbox(label="Thumbnail URL", lines=1),
    ],
    outputs=gr.Textbox(label="Embedding (truncated)"),
    title="Multimodal Embedding (Notebook -> Space)",
    description="Provide a title, description, tags, and a thumbnail URL; returns the fused multimodal embedding (vector).",
    examples=[
        ["Cute cat", "A cat doing flips", "cat,funny", "https://example.com/sample.jpg"]
    ]
)
# Mount Gradio app at root
app = gr.mount_gradio_app(app, gr_interface, path="/")
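
# For running outside a Space (a sketch; Hugging Face Spaces start the server
# for you, conventionally on port 7860):
if __name__ == "__main__":
    import uvicorn  # assumed to be installed alongside fastapi
    uvicorn.run(app, host="0.0.0.0", port=7860)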