| |
| import os |
| import io |
| import time |
| import json |
| import torch |
| import torch.nn as nn |
| import numpy as np |
| from PIL import Image |
| from io import BytesIO |
| import requests |
|
|
| from fastapi import FastAPI, Request |
| from fastapi.responses import JSONResponse |
| import gradio as gr |
|
|
| from transformers import AutoTokenizer, AutoModel |
| import timm |
| from torchvision import transforms |
|
|
| |
| |
| |
# --- Configuration constants ---
MODEL_FILENAME = "finetuned_multimodal.pt"  # checkpoint expected at the repository root
TEXT_MODEL = "sentence-transformers/LaBSE"  # HF text encoder (768-dim hidden size)
IMG_MODEL = "vit_base_patch16_224"  # timm ViT backbone used as the image encoder
IMG_SIZE = 224  # input resolution expected by the ViT
MAX_LENGTH = 512  # tokenizer truncation length
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
| |
| |
| |
class MultimodalRegressor(nn.Module):
    """Fuse a text embedding and an image embedding and regress one scalar.

    Both modalities are linearly projected into a shared space, fused with
    single-token cross-attention (text queries attend over the image token),
    and fed through a small MLP head that emits one score per example.
    """

    def __init__(self, text_dim=768, img_dim=768, proj_dim=768):
        super().__init__()
        # Per-modality projections into the shared fusion space.
        self.text_proj = nn.Linear(text_dim, proj_dim)
        self.img_proj = nn.Linear(img_dim, proj_dim)
        # Cross-attention fusion; batch_first keeps tensors (B, seq, dim).
        self.fusion_layer = nn.MultiheadAttention(embed_dim=proj_dim, num_heads=8, batch_first=True)
        self.dropout = nn.Dropout(0.1)
        # Two-layer MLP head producing a single scalar.
        self.regressor = nn.Sequential(
            nn.Linear(proj_dim, proj_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(proj_dim // 2, 1),
        )

    def forward(self, text_emb, img_emb):
        """text_emb: (B, text_dim); img_emb: (B, img_dim) -> (B,) scores."""
        query = self.text_proj(text_emb).unsqueeze(1)  # (B, 1, proj_dim)
        kv = self.img_proj(img_emb).unsqueeze(1)       # (B, 1, proj_dim)
        fused, _ = self.fusion_layer(query=query, key=kv, value=kv)
        fused = self.dropout(fused.squeeze(1))
        return self.regressor(fused).squeeze(1)
|
|
| |
| |
| |
# Preprocessing for the ViT backbone: resize to the model's input size,
# convert to a float tensor in [0, 1], then scale channels to [-1, 1].
img_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
|
|
def load_image_from_url(url, timeout=6):
    """Fetch an image over HTTP and return it as an RGB PIL image.

    Args:
        url: Image URL to download.
        timeout: Request timeout in seconds (default 6 keeps the
            original hard-coded behavior).

    Returns:
        The downloaded image converted to RGB. On ANY failure (bad URL,
        timeout, HTTP error status, undecodable payload) a neutral gray
        IMG_SIZE x IMG_SIZE placeholder is returned instead. This
        best-effort fallback is deliberate: the embedding service must
        not crash on a broken thumbnail.
    """
    try:
        resp = requests.get(url, timeout=timeout)
        resp.raise_for_status()  # surface 4xx/5xx as exceptions
        return Image.open(BytesIO(resp.content)).convert("RGB")
    except Exception:
        # Gray placeholder keeps the downstream pipeline running.
        return Image.new("RGB", (IMG_SIZE, IMG_SIZE), color=(128, 128, 128))
|
|
def text_to_embedding(tokenizer, text_model, texts):
    """Encode a list of strings into one embedding per string.

    Tokenizes with padding/truncation to MAX_LENGTH, runs the model
    under no_grad on DEVICE, and returns the pooler output when the
    model provides one; otherwise falls back to mean-pooling the last
    hidden states over the sequence dimension.
    """
    text_model.eval()
    with torch.no_grad():
        batch = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH,
            return_tensors="pt",
        )
        batch = {name: tensor.to(DEVICE) for name, tensor in batch.items()}
        outputs = text_model(**batch)
        pooled = getattr(outputs, "pooler_output", None)
        if pooled is not None:
            return pooled
        return outputs.last_hidden_state.mean(dim=1)
|
|
| |
| |
| |
# ---------------------------------------------------------------------------
# Model construction and checkpoint loading (runs once at import time).
# ---------------------------------------------------------------------------
print("Device:", DEVICE)
print("Loading tokenizer and text model:", TEXT_MODEL)
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL)
text_model = AutoModel.from_pretrained(TEXT_MODEL).to(DEVICE)


print("Creating image model:", IMG_MODEL)
# pretrained=False: weights come from the checkpoint below, not timm's hub.
# num_classes=0 strips the classification head so the model emits features.
img_model = timm.create_model(IMG_MODEL, pretrained=False, num_classes=0).to(DEVICE)


multimodal_head = MultimodalRegressor().to(DEVICE)


def _load_component_state(module, ckpt, primary_key, fallback_key):
    """Load ``module``'s weights from ``ckpt``, trying two key spellings.

    Checkpoints produced by different notebook versions use either
    ``*_state`` or ``*_state_dict`` keys; try both. When neither is
    present, print a skip notice (message preserved from the original
    per-component branches) so a partial checkpoint still loads.
    """
    if primary_key in ckpt:
        module.load_state_dict(ckpt[primary_key])
    elif fallback_key in ckpt:
        module.load_state_dict(ckpt[fallback_key])
    else:
        print(f"No {primary_key} found in checkpoint (skipping).")


if not os.path.exists(MODEL_FILENAME):
    print(f"WARNING: {MODEL_FILENAME} not found in the Space. Place your checkpoint at the repository root.")
else:
    print("Loading checkpoint:", MODEL_FILENAME)
    # NOTE(security): torch.load unpickles arbitrary objects; only load
    # checkpoints shipped with this Space / from a trusted source.
    ckpt = torch.load(MODEL_FILENAME, map_location=DEVICE)
    _load_component_state(text_model, ckpt, "text_model_state", "text_state_dict")
    _load_component_state(img_model, ckpt, "img_model_state", "img_state_dict")
    _load_component_state(multimodal_head, ckpt, "head_state", "head_state_dict")


# Inference-only service: freeze dropout/batch-norm behavior everywhere.
text_model.eval()
img_model.eval()
multimodal_head.eval()
print("Models ready.")
|
|
| |
| |
| |
def compute_fused_embedding(title: str, description: str, tags: str, thumbnail_url: str):
    """Compute the fused multimodal embedding for one item.

    Joins title/description/tags into one text, encodes it with the text
    model, encodes the thumbnail with the image model, then runs the
    head's projections + cross-attention fusion (but NOT its final
    regressor) to obtain the fused vector.

    Args:
        title, description, tags: Free-text fields; None is treated as "".
        thumbnail_url: Image URL; download failures fall back to a gray
            placeholder (see load_image_from_url).

    Returns:
        The fused embedding as a plain Python list of floats (proj_dim
        entries), ready for JSON serialization.
    """
    # None-safe concatenation of the three text fields.
    text = " ".join([str(title or ""), str(description or ""), str(tags or "")]).strip()

    # Text branch: (1, text_dim) embedding (already no_grad internally).
    t_emb = text_to_embedding(tokenizer, text_model, [text])

    # Image branch + fusion. Fix: the original guarded only the image
    # forward with no_grad and ran the projections/attention with
    # gradient tracking enabled, building a useless autograd graph on
    # every request; wrap the whole inference path instead.
    img = load_image_from_url(thumbnail_url)
    img_tensor = img_transform(img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        i_emb = img_model(img_tensor)
        t_proj = multimodal_head.text_proj(t_emb)
        i_proj = multimodal_head.img_proj(i_emb)
        # Mirrors MultimodalRegressor.forward up to (and excluding) the
        # dropout/regressor stages: text queries attend over the image token.
        attn_out, _ = multimodal_head.fusion_layer(
            query=t_proj.unsqueeze(1),
            key=i_proj.unsqueeze(1),
            value=i_proj.unsqueeze(1),
        )
    fused = attn_out.squeeze(1)
    return fused.squeeze(0).cpu().numpy().tolist()
|
|
| |
| |
| |
# FastAPI application exposing the JSON embedding endpoint; the Gradio UI
# is mounted onto the same app further down the file.
app = FastAPI()
|
|
@app.post("/api/get_embedding")
async def api_get_embedding(request: Request):
    """JSON API: accept item fields, return the fused embedding.

    Expects a JSON body with optional "title", "description", "tags"
    and "thumbnail_url" keys (missing keys default to ""). Responds
    with {"embedding": [...]} on success, or {"error": "..."} with
    HTTP 500 on any failure.
    """
    payload = await request.json()
    fields = [payload.get(key, "") for key in ("title", "description", "tags", "thumbnail_url")]

    try:
        emb = compute_fused_embedding(*fields)
    except Exception as e:
        # Service boundary: report the failure to the client as a 500.
        return JSONResponse({"error": str(e)}, status_code=500)

    return JSONResponse({"embedding": emb})
|
|
| |
def gradio_fn(title, description, tags, thumbnail_url):
    """Gradio wrapper: return a truncated, human-readable embedding string.

    Any failure is rendered as an "Error: ..." string rather than raised,
    so the UI always shows something.
    """
    try:
        emb = compute_fused_embedding(title, description, tags, thumbnail_url)
    except Exception as e:
        return f"Error: {e}"
    return f"embedding (len={len(emb)}): {emb[:10]} ... (truncated)"
|
|
# Gradio UI definition: four text inputs mapped onto gradio_fn, with a
# single truncated-embedding textbox as the output.
gr_interface = gr.Interface(
    fn=gradio_fn,
    inputs=[
        gr.Textbox(label="Title", lines=1),
        gr.Textbox(label="Description", lines=3),
        gr.Textbox(label="Tags", lines=1),
        gr.Textbox(label="Thumbnail URL", lines=1),
    ],
    outputs=gr.Textbox(label="Embedding (truncated)"),
    title="Multimodal Embedding (Notebook -> Space)",
    description="Provide title, description, tags and thumbnail URL. Returns fused multimodal embedding (vector).",
    examples=[
        ["Cute cat", "A cat doing flips", "cat,funny", "https://example.com/sample.jpg"]
    ]
)
|
|
| |
# Serve the Gradio UI at "/" on the same FastAPI app that exposes the JSON API.
app = gr.mount_gradio_app(app, gr_interface, path="/")
|
|