File size: 6,358 Bytes
ab0f28d
 
 
 
a5705f1
 
 
 
ab0f28d
 
 
 
d36dff5
ab0f28d
a5705f1
ab0f28d
 
 
 
 
 
d36dff5
ab0f28d
a5705f1
ab0f28d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
408b528
ab0f28d
 
 
a5705f1
ab0f28d
d36dff5
 
751e04c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d36dff5
 
 
a5705f1
 
 
 
 
 
ab0f28d
 
 
a5705f1
ab0f28d
a5705f1
 
 
 
 
ab0f28d
a5705f1
 
 
ab0f28d
408b528
 
 
 
 
 
a5705f1
408b528
d36dff5
408b528
a5705f1
 
408b528
a5705f1
 
 
 
 
408b528
 
 
a5705f1
d36dff5
ab0f28d
49705d6
 
ab0f28d
49705d6
ab0f28d
a5705f1
9b639e2
ab0f28d
49705d6
 
ab0f28d
49705d6
 
 
 
 
 
 
a5705f1
49705d6
 
 
 
 
ab0f28d
 
49705d6
b2321f6
49705d6
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import os
import torch
import torch.nn as nn
import requests
from io import BytesIO
from PIL import Image
import timm
import gradio as gr
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from transformers import AutoTokenizer, AutoModel
from torchvision import transforms
from huggingface_hub import hf_hub_download  # NEW

# --- Config ---
# Backbone identifiers and preprocessing limits shared by the model loaders
# and compute_embedding().
TEXT_MODEL: str = "sentence-transformers/LaBSE"  # Hub id of the text encoder
IMG_MODEL: str = "vit_base_patch16_224"  # timm architecture name of the image encoder
IMG_SIZE: int = 224  # thumbnails are resized to IMG_SIZE x IMG_SIZE
MAX_LENGTH: int = 512  # tokenizer truncation limit, in tokens

# Use the GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# --- Model definition ---
class MultimodalRegressor(nn.Module):
    """Fuse a text embedding with an image embedding and regress one scalar.

    Both inputs are linearly projected to ``proj_dim``. The projected text
    vector is used as the query of an 8-head cross-attention whose key/value
    is the projected image vector, and a two-layer MLP maps the fused vector
    to a single score per row.

    NOTE(review): key/value form a sequence of length 1, so the attention
    softmax is identically 1 and the fused vector depends only on the image;
    the text input cannot change the output. Confirm this is intended.

    Args:
        text_dim: width of the incoming text embedding.
        img_dim: width of the incoming image embedding.
        proj_dim: shared projection width and attention ``embed_dim``
            (must be divisible by the 8 heads).
    """

    def __init__(self, text_dim=768, img_dim=768, proj_dim=768):
        super().__init__()
        hidden_dim = proj_dim // 2
        # Attribute names and the Sequential index order make up the
        # state_dict keys that checkpoints are loaded against -- keep stable.
        self.text_proj = nn.Linear(text_dim, proj_dim)
        self.img_proj = nn.Linear(img_dim, proj_dim)
        self.fusion_layer = nn.MultiheadAttention(
            embed_dim=proj_dim,
            num_heads=8,
            batch_first=True,
        )
        self.dropout = nn.Dropout(0.1)
        mlp_layers = [
            nn.Linear(proj_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, 1),
        ]
        self.regressor = nn.Sequential(*mlp_layers)

    def forward(self, text_emb, img_emb):
        """Return one regression score per row of the (batch, dim) inputs."""
        # Each embedding becomes a length-1 sequence: (batch, 1, proj_dim).
        query = self.text_proj(text_emb).unsqueeze(1)
        memory = self.img_proj(img_emb).unsqueeze(1)
        attended, _weights = self.fusion_layer(query=query, key=memory, value=memory)
        pooled = self.dropout(attended.squeeze(1))
        score = self.regressor(pooled)
        # (batch, 1) -> (batch,)
        return score.squeeze(1)

# --- Load backbone models + head ---
# Built at import time. Further down, each module's weights are replaced from
# the fine-tuned checkpoint when it contains the matching "*_state" entry.
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL)
text_model = AutoModel.from_pretrained(TEXT_MODEL).to(DEVICE)
# pretrained=False: the ViT starts from random weights and depends on the
# checkpoint's "img_model_state". num_classes=0 drops timm's classifier so the
# model returns pooled features (presumably 768-wide for ViT-B, matching the
# head's img_dim default -- confirm if IMG_MODEL changes).
img_model = timm.create_model(IMG_MODEL, pretrained=False, num_classes=0).to(DEVICE)
head = MultimodalRegressor().to(DEVICE)

# Checkpoint download (uses the default Hugging Face cache directory).
def load_model_if_needed():
    """Download the fine-tuned checkpoint from the Hub and return its local path.

    The file goes to the default Hugging Face cache, so a previously downloaded
    copy is reused. If the first attempt raises, the download is retried once
    with ``force_download=True`` to replace a possibly corrupted cached copy.

    Returns:
        str: local filesystem path of ``finetuned_multimodal.pt``.

    Raises:
        Any exception raised by the retry propagates to the caller.
    """
    # Shared arguments for both attempts. local_dir / cache_dir are left at
    # their defaults (None). The deprecated local_dir_use_symlinks flag is not
    # passed: it has no effect without local_dir.
    download_kwargs = {
        "repo_id": "MeshMax/video_tower",
        "filename": "finetuned_multimodal.pt",
    }
    try:
        model_path = hf_hub_download(**download_kwargs)
    except Exception as e:
        print(f"Download failed: {e}. Retrying with force_download...")
        # Overwrite whatever is in the cache in case it is corrupted.
        model_path = hf_hub_download(**download_kwargs, force_download=True)
    print(f"Model loaded from: {model_path}")
    return model_path

model_path = load_model_if_needed()
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code stored in the file. Acceptable only while
# the Hub repo is trusted; try weights_only=True if the checkpoint holds
# nothing but tensors and plain containers.
ckpt = torch.load(model_path, map_location=DEVICE, weights_only=False)
for _state_key, _module in (
    ("text_model_state", text_model),
    ("img_model_state", img_model),
    ("head_state", head),
):
    if _state_key in ckpt:
        _module.load_state_dict(ckpt[_state_key])
    else:
        # Not fatal, but the module keeps its construction-time weights --
        # random ones for img_model (pretrained=False) and for head.
        print(
            f"WARNING: checkpoint has no '{_state_key}'; "
            f"keeping initial weights for {type(_module).__name__}."
        )
# load_state_dict copied the tensors into the live modules; drop the checkpoint
# so its duplicate copy (already placed on DEVICE) can be freed.
del ckpt

# Inference only: switch every module to eval mode (disables the head's Dropout).
text_model.eval()
img_model.eval()
head.eval()

# Thumbnail preprocessing: force a square IMG_SIZE x IMG_SIZE image, convert it
# to a float tensor, then map each channel via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

def _join_text_fields(title, description, tags):
    """Join the text fields with single spaces, treating ``None`` as ""."""
    # Mapping None to "" keeps a missing field from being tokenized as the
    # literal word "None".
    return " ".join("" if field is None else f"{field}" for field in (title, description, tags))


def _gray_placeholder():
    """Return the neutral gray image used when no thumbnail is available."""
    return Image.new("RGB", (IMG_SIZE, IMG_SIZE), color=(128, 128, 128))


def _fetch_thumbnail(thumbnail_url):
    """Download ``thumbnail_url`` and decode it as RGB; gray placeholder on failure.

    Failure is deliberately non-fatal. An unreachable URL, an HTTP error status
    or undecodable bytes all yield the placeholder, so an embedding is still produced.
    """
    if not thumbnail_url:
        # Empty form field or None: nothing to fetch, skip the doomed request.
        return _gray_placeholder()
    # NOTE(review): the URL comes straight from the public form, so the server
    # will request any host it can reach (including internal addresses) and
    # reads the whole response body into memory. Consider a host allow-list
    # and a size cap.
    try:
        resp = requests.get(thumbnail_url, timeout=5)
        resp.raise_for_status()
        return Image.open(BytesIO(resp.content)).convert("RGB")
    except Exception as err:  # best effort: never fail a request over a thumbnail
        print(f"Thumbnail unavailable ({err}); using gray placeholder.")
        return _gray_placeholder()


def compute_embedding(title, description, tags, thumbnail_url):
    """Return the fused text/image embedding for one video as a list of floats.

    Args:
        title, description, tags: text fields, joined with spaces before
            tokenization; ``None`` is treated as an empty string.
        thumbnail_url: image URL. If it is empty or cannot be fetched and
            decoded, a gray placeholder image is used instead.

    Returns:
        list[float]: the cross-attention output of ``head.fusion_layer``,
        taken before the head's dropout and regressor MLP. It therefore has
        ``proj_dim`` values (768 for the default-constructed head), not a
        single score.
    """
    text = _join_text_fields(title, description, tags)
    toks = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=MAX_LENGTH,
    ).to(DEVICE)
    img_tensor = transform(_fetch_thumbnail(thumbnail_url)).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        out = text_model(**toks)
        # Prefer the model's pooled output; otherwise mean-pool the token
        # states (a batch of one sequence has no padding to exclude).
        pooled = getattr(out, "pooler_output", None)
        text_emb = pooled if pooled is not None else out.last_hidden_state.mean(dim=1)

        img_emb = img_model(img_tensor)

        # Same projection + cross-attention as MultimodalRegressor.forward,
        # stopping before dropout and the regressor.
        # NOTE(review): key/value hold a single image token, so the attention
        # softmax is identically 1 and the result depends only on the image;
        # text_emb cannot influence it. Revisit the fusion if text should matter.
        t_proj = head.text_proj(text_emb).unsqueeze(1)
        i_proj = head.img_proj(img_emb).unsqueeze(1)
        attn_out, _ = head.fusion_layer(query=t_proj, key=i_proj, value=i_proj)

    # (1, 1, proj_dim) -> (proj_dim,) -> plain Python floats.
    return attn_out.squeeze(1).squeeze(0).cpu().tolist()

# --- Gradio interface + launch ---
# The app is served through Gradio alone; the FastAPI / JSONResponse imports at
# the top of the file are currently unused.
def gradio_fn(title, description, tags, thumbnail_url):
    emb = compute_embedding(title, description, tags, thumbnail_url)
    return emb  # Returns the full list directly—Gradio stringifies it cleanly

def _build_demo():
    """Assemble the Gradio form that wraps gradio_fn."""
    # Inputs in the positional order gradio_fn expects.
    input_boxes = [
        gr.Textbox(label="Title", placeholder="Enter video title..."),
        gr.Textbox(label="Description", placeholder="Enter video description..."),
        gr.Textbox(label="Tags", placeholder="Enter comma-separated tags..."),
        gr.Textbox(label="Thumbnail URL", placeholder="Enter image URL (e.g., https://example.com/thumb.jpg)..."),
    ]
    output_box = gr.Textbox(label="Generated Embedding", lines=10)
    # One sample row, values ordered like input_boxes.
    # NOTE(review): via.placeholder.com may no longer serve images. If the
    # fetch fails, compute_embedding falls back to a gray placeholder rather
    # than erroring.
    sample_rows = [
        ["Test Video", "A sample description.", "ml, ai, video", "https://via.placeholder.com/224"],
    ]
    return gr.Interface(
        fn=gradio_fn,
        inputs=input_boxes,
        outputs=output_box,
        title="Video Embedding Generator",
        description="Generates fused multimodal embeddings from video metadata using LaBSE + ViT + Custom Fusion.",
        examples=sample_rows,
        # NOTE(review): newer Gradio releases rename this to flagging_mode;
        # keep in sync with the pinned Gradio version.
        allow_flagging="never",
    )


demo = _build_demo()

# Script entry point: serve the app on all interfaces, port 7860 (blocks).
if __name__ == "__main__":
    # queue() returns the same app object, so launch() is chained onto it.
    # max_size=10 caps the number of pending events; once full, new requests
    # are turned away until the queue drains.
    demo.queue(max_size=10).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,  # no temporary public share link
        show_api=True,  # keep the auto-generated API docs/endpoints visible
        debug=False,  # flip to True for verbose logs while testing
        favicon_path=None,
    )