MeshMax committed on
Commit
a5705f1
·
verified ·
1 Parent(s): fa94305

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -154
app.py CHANGED
@@ -1,42 +1,45 @@
1
  # app.py
2
  import os
3
- import io
4
- import time
5
- import json
6
  import torch
7
  import torch.nn as nn
8
- import numpy as np
9
- from PIL import Image
10
- from io import BytesIO
11
  import requests
12
-
 
 
 
13
  from fastapi import FastAPI, Request
14
  from fastapi.responses import JSONResponse
15
- import gradio as gr
16
-
17
  from transformers import AutoTokenizer, AutoModel
18
- import timm
19
  from torchvision import transforms
20
 
21
- # -----------------------
22
- # Config — mirror your notebook
23
- # -----------------------
24
- MODEL_FILENAME = "finetuned_multimodal.pt" # upload this to your Space
25
  TEXT_MODEL = "sentence-transformers/LaBSE"
26
  IMG_MODEL = "vit_base_patch16_224"
27
  IMG_SIZE = 224
28
  MAX_LENGTH = 512
29
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
30
 
31
- # -----------------------
32
- # Model class (exact from your notebook)
33
- # -----------------------
 
 
 
 
 
 
 
 
 
 
 
34
  class MultimodalRegressor(nn.Module):
35
- def __init__(self, text_dim=768, img_dim=768, proj_dim=768): # keep dims consistent with training
36
  super().__init__()
37
  self.text_proj = nn.Linear(text_dim, proj_dim)
38
  self.img_proj = nn.Linear(img_dim, proj_dim)
39
- # batch_first=True per your notebook
40
  self.fusion_layer = nn.MultiheadAttention(embed_dim=proj_dim, num_heads=8, batch_first=True)
41
  self.dropout = nn.Dropout(0.1)
42
  self.regressor = nn.Sequential(
@@ -54,162 +57,69 @@ class MultimodalRegressor(nn.Module):
54
  fused = self.dropout(fused)
55
  return self.regressor(fused).squeeze(1)
56
 
57
- # -----------------------
58
- # Utilities: image transform & helpers
59
- # -----------------------
60
- img_transform = transforms.Compose([
61
- transforms.Resize((IMG_SIZE, IMG_SIZE)),
62
- transforms.ToTensor(),
63
- transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
64
- ])
65
-
66
- def load_image_from_url(url):
67
- try:
68
- resp = requests.get(url, timeout=6)
69
- resp.raise_for_status()
70
- img = Image.open(BytesIO(resp.content)).convert("RGB")
71
- return img
72
- except Exception:
73
- # Return a gray image fallback if thumbnail fetch fails
74
- return Image.new("RGB", (IMG_SIZE, IMG_SIZE), color=(128, 128, 128))
75
-
76
- def text_to_embedding(tokenizer, text_model, texts):
77
- # texts: list[str] (batch)
78
- # Return tensor shape (batch, text_dim)
79
- text_model.eval()
80
- with torch.no_grad():
81
- toks = tokenizer(texts, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt")
82
- toks = {k: v.to(DEVICE) for k, v in toks.items()}
83
- out = text_model(**toks)
84
- # prefer pooler_output if available, else mean of last_hidden_state
85
- if hasattr(out, "pooler_output") and out.pooler_output is not None:
86
- emb = out.pooler_output
87
- else:
88
- last = out.last_hidden_state # (batch, seq, dim)
89
- emb = last.mean(dim=1)
90
- return emb # already on DEVICE
91
-
92
- # -----------------------
93
- # Load pretrained backbone models + head; load checkpoint
94
- # -----------------------
95
- print("Device:", DEVICE)
96
- print("Loading tokenizer and text model:", TEXT_MODEL)
97
  tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL)
98
  text_model = AutoModel.from_pretrained(TEXT_MODEL).to(DEVICE)
99
-
100
- print("Creating image model:", IMG_MODEL)
101
- # create_model(..., num_classes=0) returns features vector for many timm models
102
  img_model = timm.create_model(IMG_MODEL, pretrained=False, num_classes=0).to(DEVICE)
 
103
 
104
- multimodal_head = MultimodalRegressor().to(DEVICE)
105
-
106
- # Load checkpoint (robust to different key names)
107
- if not os.path.exists(MODEL_FILENAME):
108
- print(f"WARNING: {MODEL_FILENAME} not found in the Space. Place your checkpoint at the repository root.")
109
- else:
110
- print("Loading checkpoint:", MODEL_FILENAME)
111
- ckpt = torch.load(MODEL_FILENAME, map_location=DEVICE)
112
- # expected keys from notebook: 'text_model_state', 'img_model_state', 'head_state'
113
- if "text_model_state" in ckpt:
114
- text_model.load_state_dict(ckpt["text_model_state"])
115
- elif "text_state_dict" in ckpt:
116
- text_model.load_state_dict(ckpt["text_state_dict"])
117
- else:
118
- print("No text_model_state found in checkpoint (skipping).")
119
-
120
- if "img_model_state" in ckpt:
121
- img_model.load_state_dict(ckpt["img_model_state"])
122
- elif "img_state_dict" in ckpt:
123
- img_model.load_state_dict(ckpt["img_state_dict"])
124
- else:
125
- print("No img_model_state found in checkpoint (skipping).")
126
-
127
- if "head_state" in ckpt:
128
- multimodal_head.load_state_dict(ckpt["head_state"])
129
- elif "head_state_dict" in ckpt:
130
- multimodal_head.load_state_dict(ckpt["head_state_dict"])
131
- else:
132
- print("No head_state found in checkpoint (skipping).")
133
 
134
  text_model.eval()
135
  img_model.eval()
136
- multimodal_head.eval()
137
- print("Models ready.")
138
 
139
- # -----------------------
140
- # Inference: create fused embedding (same pipeline notebook used)
141
- # -----------------------
142
- def compute_fused_embedding(title: str, description: str, tags: str, thumbnail_url: str):
143
- # Build text and image inputs
144
- text = " ".join([str(title or ""), str(description or ""), str(tags or "")]).strip()
145
- texts = [text]
146
-
147
- # Text embedding (batch of 1)
148
- t_emb = text_to_embedding(tokenizer, text_model, texts) # shape (1, text_dim)
149
 
150
- # Image embedding: preprocess and forward
151
- img = load_image_from_url(thumbnail_url)
152
- img_tensor = img_transform(img).unsqueeze(0).to(DEVICE) # (1,3,H,W)
153
  with torch.no_grad():
154
- i_emb = img_model(img_tensor) # expected shape (1, img_dim)
155
-
156
- # Project, fuse via head's fusion layer (exactly as in notebook)
157
- t_proj = multimodal_head.text_proj(t_emb) # (1, proj_dim)
158
- i_proj = multimodal_head.img_proj(i_emb) # (1, proj_dim)
159
-
160
- # MultiheadAttention expects (batch, seq, dim) because batch_first=True
161
- attn_out, _ = multimodal_head.fusion_layer(
162
- query=t_proj.unsqueeze(1), # (1, 1, proj_dim)
163
- key=i_proj.unsqueeze(1), # (1, 1, proj_dim)
164
- value=i_proj.unsqueeze(1) # (1, 1, proj_dim)
165
- )
166
- fused = attn_out.squeeze(1) # (1, proj_dim) -> (proj_dim,)
167
- fused_np = fused.squeeze(0).cpu().numpy().tolist()
168
- return fused_np
169
 
170
- # -----------------------
171
- # FastAPI + Gradio integration
172
- # -----------------------
173
  app = FastAPI()
174
 
175
  @app.post("/api/get_embedding")
176
  async def api_get_embedding(request: Request):
177
- payload = await request.json()
178
- title = payload.get("title", "")
179
- description = payload.get("description", "")
180
- tags = payload.get("tags", "")
181
- thumbnail_url = payload.get("thumbnail_url", "")
182
-
183
- try:
184
- emb = compute_fused_embedding(title, description, tags, thumbnail_url)
185
- except Exception as e:
186
- return JSONResponse({"error": str(e)}, status_code=500)
187
-
188
  return JSONResponse({"embedding": emb})
189
 
190
- # Gradio UI for quick testing (truncated embedding shown)
191
  def gradio_fn(title, description, tags, thumbnail_url):
192
- try:
193
- emb = compute_fused_embedding(title, description, tags, thumbnail_url)
194
- return f"embedding (len={len(emb)}): {emb[:10]} ... (truncated)"
195
- except Exception as e:
196
- return f"Error: {e}"
197
 
198
  gr_interface = gr.Interface(
199
  fn=gradio_fn,
200
- inputs=[
201
- gr.Textbox(label="Title", lines=1),
202
- gr.Textbox(label="Description", lines=3),
203
- gr.Textbox(label="Tags", lines=1),
204
- gr.Textbox(label="Thumbnail URL", lines=1),
205
- ],
206
- outputs=gr.Textbox(label="Embedding (truncated)"),
207
- title="Multimodal Embedding (Notebook -> Space)",
208
- description="Provide title, description, tags and thumbnail URL. Returns fused multimodal embedding (vector).",
209
- examples=[
210
- ["Cute cat", "A cat doing flips", "cat,funny", "https://example.com/sample.jpg"]
211
- ]
212
  )
213
 
214
- # Mount Gradio app at root
215
  app = gr.mount_gradio_app(app, gr_interface, path="/")
 
1
  # app.py
2
  import os
 
 
 
3
  import torch
4
  import torch.nn as nn
 
 
 
5
  import requests
6
+ from io import BytesIO
7
+ from PIL import Image
8
+ import timm
9
+ import gradio as gr
10
  from fastapi import FastAPI, Request
11
  from fastapi.responses import JSONResponse
 
 
12
  from transformers import AutoTokenizer, AutoModel
 
13
  from torchvision import transforms
14
 
15
# --- Config ---
# The fine-tuned checkpoint is fetched from Google Drive on first start-up.
MODEL_URL = "https://drive.google.com/uc?export=download&id=10Y_HLjflL54H7iwP1oz1ZG1SV4SsK6Qw"
MODEL_FILENAME = "finetuned_multimodal.pt"

TEXT_MODEL = "sentence-transformers/LaBSE"  # multilingual text encoder (HF hub id)
IMG_MODEL = "vit_base_patch16_224"          # timm ViT backbone
IMG_SIZE = 224                              # input resolution expected by the ViT
MAX_LENGTH = 512                            # tokenizer truncation length
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
23
 
24
# --- Download model from Google Drive ---
if not os.path.exists(MODEL_FILENAME):
    print(f"Downloading checkpoint from {MODEL_URL} ...")
    # Stream to disk in chunks so the whole checkpoint never sits in RAM.
    # The timeout keeps a dead link from hanging start-up forever, and the
    # context manager releases the HTTP connection when done.
    # NOTE(review): Google Drive returns an HTML confirm page for very large
    # files via this URL form — verify the downloaded file is a valid
    # checkpoint, or switch to the `gdown` flow if downloads start failing.
    with requests.get(MODEL_URL, stream=True, timeout=(10, 60)) as r:
        r.raise_for_status()
        with open(MODEL_FILENAME, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                if chunk:  # skip keep-alive chunks
                    f.write(chunk)
    print("Download complete.")
else:
    print("Checkpoint already exists locally.")
36
+
37
+ # --- Define model ---
38
  class MultimodalRegressor(nn.Module):
39
+ def __init__(self, text_dim=768, img_dim=768, proj_dim=768):
40
  super().__init__()
41
  self.text_proj = nn.Linear(text_dim, proj_dim)
42
  self.img_proj = nn.Linear(img_dim, proj_dim)
 
43
  self.fusion_layer = nn.MultiheadAttention(embed_dim=proj_dim, num_heads=8, batch_first=True)
44
  self.dropout = nn.Dropout(0.1)
45
  self.regressor = nn.Sequential(
 
57
  fused = self.dropout(fused)
58
  return self.regressor(fused).squeeze(1)
59
 
60
# --- Load models ---
# Text side: LaBSE tokenizer + encoder.
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL)
text_model = AutoModel.from_pretrained(TEXT_MODEL).to(DEVICE)

# Image side: timm ViT with num_classes=0 so the forward pass returns the
# pooled feature vector instead of classification logits.
img_model = timm.create_model(IMG_MODEL, pretrained=False, num_classes=0).to(DEVICE)
head = MultimodalRegressor().to(DEVICE)

# Restore the fine-tuned weights. Each state dict is optional so a partial
# checkpoint still loads what it has.
# NOTE(review): torch.load unpickles arbitrary objects from a file that was
# just downloaded — consider weights_only=True (torch >= 2.0) if the
# checkpoint is a plain state-dict bundle.
ckpt = torch.load(MODEL_FILENAME, map_location=DEVICE)
for state_key, module in (
    ("text_model_state", text_model),
    ("img_model_state", img_model),
    ("head_state", head),
):
    if state_key in ckpt:
        module.load_state_dict(ckpt[state_key])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
# Inference only: switch every sub-model out of training mode (disables
# dropout in the fusion head, among other things).
for _module in (text_model, img_model, head):
    _module.eval()

# Thumbnail preprocessing: resize to the ViT input size, scale to [0, 1],
# then normalize each channel to [-1, 1] — presumably matching the
# transform used at training time (TODO confirm against the notebook).
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
 
 
 
 
 
83
 
84
def compute_embedding(title, description, tags, thumbnail_url):
    """Return the fused multimodal embedding for one item as list[float].

    Concatenates title/description/tags into one text, embeds it with the
    text encoder, embeds the thumbnail with the ViT, then fuses the two
    projections through the trained attention head.

    Any failure to fetch or decode the thumbnail falls back to a neutral
    gray image, so text-only inputs still produce an embedding.
    """
    # Coerce None to "" so the literal string "None" never reaches the model.
    text = f"{title or ''} {description or ''} {tags or ''}"
    toks = tokenizer(
        text, return_tensors="pt", truncation=True,
        padding=True, max_length=MAX_LENGTH,
    ).to(DEVICE)

    # Fetch the thumbnail. The timeout keeps a dead URL from hanging the
    # request; any failure (network, bad image bytes) uses the gray fallback.
    try:
        resp = requests.get(thumbnail_url, timeout=6)
        resp.raise_for_status()
        img = Image.open(BytesIO(resp.content)).convert("RGB")
    except Exception:
        img = Image.new("RGB", (IMG_SIZE, IMG_SIZE), color=(128, 128, 128))
    img_tensor = transform(img).unsqueeze(0).to(DEVICE)

    # Run the whole pipeline under no_grad — the projections and fusion were
    # previously outside it and built a needless autograd graph per request.
    with torch.no_grad():
        out = text_model(**toks)
        if getattr(out, "pooler_output", None) is not None:
            text_emb = out.pooler_output
        else:
            # Encoders without a pooler: mean-pool the last hidden states.
            text_emb = out.last_hidden_state.mean(dim=1)
        img_emb = img_model(img_tensor)  # (1, img_dim)

        # Text attends to the image; singleton sequence dim for MHA
        # (batch_first=True expects (batch, seq, dim)).
        t_proj = head.text_proj(text_emb)
        i_proj = head.img_proj(img_emb)
        attn_out, _ = head.fusion_layer(
            query=t_proj.unsqueeze(1),
            key=i_proj.unsqueeze(1),
            value=i_proj.unsqueeze(1),
        )
        fused = attn_out.squeeze(1)  # (1, proj_dim)
    return fused.squeeze(0).cpu().numpy().tolist()
 
 
101
 
102
# --- FastAPI + Gradio ---
app = FastAPI()

@app.post("/api/get_embedding")
async def api_get_embedding(request: Request):
    """POST JSON {title, description, tags, thumbnail_url} -> {"embedding": [...]}.

    Missing fields default to "". Inference failures return a JSON error
    with HTTP 500 instead of an unhandled server traceback.
    """
    data = await request.json()
    try:
        # NOTE(review): compute_embedding is synchronous model inference and
        # blocks the event loop; consider run_in_threadpool under load.
        emb = compute_embedding(
            data.get("title", ""),
            data.get("description", ""),
            data.get("tags", ""),
            data.get("thumbnail_url", ""),
        )
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
    return JSONResponse({"embedding": emb})
111
 
 
112
def gradio_fn(title, description, tags, thumbnail_url):
    """Gradio wrapper: return a short human-readable summary of the embedding.

    Errors are rendered as text in the UI rather than raising, so a bad
    thumbnail URL or model failure doesn't surface as a stack trace.
    """
    try:
        emb = compute_embedding(title, description, tags, thumbnail_url)
    except Exception as e:
        return f"Error: {e}"
    return f"Embedding length {len(emb)}; first 10: {emb[:10]}"
 
 
 
115
 
116
# Simple Gradio UI for manual testing; mounted at "/" on the same FastAPI
# app that serves the JSON endpoint.
gr_interface = gr.Interface(
    fn=gradio_fn,
    inputs=[
        gr.Textbox(label="Title"),
        gr.Textbox(label="Description"),
        gr.Textbox(label="Tags"),
        gr.Textbox(label="Thumbnail URL"),
    ],
    outputs="text",
    title="Video Embedding Generator",
    description="Generates fused multimodal embeddings from video metadata and thumbnail.",
)

app = gr.mount_gradio_app(app, gr_interface, path="/")