MeshMax committed on
Commit
ab0f28d
·
verified ·
1 Parent(s): d35c7a5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +215 -0
app.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import os
3
+ import io
4
+ import time
5
+ import json
6
+ import torch
7
+ import torch.nn as nn
8
+ import numpy as np
9
+ from PIL import Image
10
+ from io import BytesIO
11
+ import requests
12
+
13
+ from fastapi import FastAPI, Request
14
+ from fastapi.responses import JSONResponse
15
+ import gradio as gr
16
+
17
+ from transformers import AutoTokenizer, AutoModel
18
+ import timm
19
+ from torchvision import transforms
20
+
21
# -----------------------
# Config — mirror your notebook
# -----------------------
MODEL_FILENAME = "finetuned_multimodal.pt"  # checkpoint path; upload this file to your Space repo root
TEXT_MODEL = "sentence-transformers/LaBSE"  # HF hub id of the text encoder backbone
IMG_MODEL = "vit_base_patch16_224"          # timm id of the image encoder backbone
IMG_SIZE = 224                              # ViT input resolution (pixels per side)
MAX_LENGTH = 512                            # tokenizer truncation length (tokens)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when available
30
+
31
# -----------------------
# Model class (exact from your notebook)
# -----------------------
class MultimodalRegressor(nn.Module):
    """Fuse one text and one image embedding via cross-attention and regress a scalar.

    The text vector acts as the attention query and the image vector as the
    key/value; the fused vector is fed to a small MLP producing one value
    per example.
    """

    def __init__(self, text_dim=768, img_dim=768, proj_dim=768):  # keep dims consistent with training
        super().__init__()
        # NOTE: layer creation order is preserved so seeded random init matches.
        self.text_proj = nn.Linear(text_dim, proj_dim)
        self.img_proj = nn.Linear(img_dim, proj_dim)
        # batch_first=True per your notebook
        self.fusion_layer = nn.MultiheadAttention(embed_dim=proj_dim, num_heads=8, batch_first=True)
        self.dropout = nn.Dropout(0.1)
        self.regressor = nn.Sequential(
            nn.Linear(proj_dim, proj_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(proj_dim // 2, 1),
        )

    def forward(self, text_emb, img_emb):
        # Project each modality into the shared space and add a length-1
        # sequence axis: MultiheadAttention wants (batch, seq, dim).
        query = self.text_proj(text_emb).unsqueeze(1)
        key_value = self.img_proj(img_emb).unsqueeze(1)
        attended, _ = self.fusion_layer(query=query, key=key_value, value=key_value)
        fused = self.dropout(attended.squeeze(1))
        # (batch, 1) -> (batch,)
        return self.regressor(fused).squeeze(1)
56
+
57
# -----------------------
# Utilities: image transform & helpers
# -----------------------
# ViT preprocessing: square resize, tensor conversion, then per-channel
# rescaling of [0, 1] pixel values into [-1, 1].
_PREPROCESS_STEPS = [
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
img_transform = transforms.Compose(_PREPROCESS_STEPS)
65
+
66
def load_image_from_url(url, timeout=6.0):
    """Fetch an image over HTTP and return it as an RGB PIL image.

    Parameters
    ----------
    url : str
        Thumbnail URL to download.
    timeout : float, optional
        Per-request timeout in seconds (default 6.0, matching the previous
        hard-coded value — now a parameter so callers can tune it).

    Returns
    -------
    PIL.Image.Image
        The decoded image, or a flat gray IMG_SIZE x IMG_SIZE placeholder
        when the URL is empty or the fetch/decode fails; inference should
        degrade gracefully rather than crash on a bad thumbnail.
    """
    if not url:
        # Skip the network round-trip entirely for empty/None URLs.
        return Image.new("RGB", (IMG_SIZE, IMG_SIZE), color=(128, 128, 128))
    try:
        resp = requests.get(url, timeout=timeout)
        resp.raise_for_status()
        return Image.open(BytesIO(resp.content)).convert("RGB")
    except Exception:
        # Return a gray image fallback if thumbnail fetch fails (best-effort).
        return Image.new("RGB", (IMG_SIZE, IMG_SIZE), color=(128, 128, 128))
75
+
76
def text_to_embedding(tokenizer, text_model, texts):
    """Encode a batch of strings into sentence embeddings.

    Args:
        tokenizer: HF tokenizer matching ``text_model``.
        text_model: HF encoder already moved to DEVICE.
        texts: list[str] — the batch to encode.

    Returns:
        Tensor of shape (batch, text_dim), resident on DEVICE.
    """
    text_model.eval()
    with torch.no_grad():
        encoded = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH,
            return_tensors="pt",
        )
        encoded = {name: tensor.to(DEVICE) for name, tensor in encoded.items()}
        outputs = text_model(**encoded)
        # Prefer the pooled CLS-style vector when the model provides one;
        # otherwise mean-pool the token states along the sequence axis.
        pooled = getattr(outputs, "pooler_output", None)
        if pooled is not None:
            return pooled
        return outputs.last_hidden_state.mean(dim=1)
91
+
92
# -----------------------
# Load pretrained backbone models + head; load checkpoint
# -----------------------
print("Device:", DEVICE)
print("Loading tokenizer and text model:", TEXT_MODEL)
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL)
text_model = AutoModel.from_pretrained(TEXT_MODEL).to(DEVICE)

print("Creating image model:", IMG_MODEL)
# create_model(..., num_classes=0) returns features vector for many timm models
img_model = timm.create_model(IMG_MODEL, pretrained=False, num_classes=0).to(DEVICE)

multimodal_head = MultimodalRegressor().to(DEVICE)


def _load_component(module, ckpt, primary_key, fallback_key):
    """Load ``module`` weights from ``ckpt`` under either key name; warn and skip if absent."""
    if primary_key in ckpt:
        module.load_state_dict(ckpt[primary_key])
    elif fallback_key in ckpt:
        module.load_state_dict(ckpt[fallback_key])
    else:
        print(f"No {primary_key} found in checkpoint (skipping).")


# Load checkpoint (robust to different key names)
if not os.path.exists(MODEL_FILENAME):
    print(f"WARNING: {MODEL_FILENAME} not found in the Space. Place your checkpoint at the repository root.")
else:
    print("Loading checkpoint:", MODEL_FILENAME)
    # NOTE(security): torch.load unpickles arbitrary objects — only load
    # checkpoints you trust (this one ships with the Space itself).
    ckpt = torch.load(MODEL_FILENAME, map_location=DEVICE)
    # expected keys from notebook: 'text_model_state', 'img_model_state', 'head_state'
    _load_component(text_model, ckpt, "text_model_state", "text_state_dict")
    _load_component(img_model, ckpt, "img_model_state", "img_state_dict")
    _load_component(multimodal_head, ckpt, "head_state", "head_state_dict")

# Inference-only service: switch all modules to eval mode (disables dropout).
text_model.eval()
img_model.eval()
multimodal_head.eval()
print("Models ready.")
138
+
139
# -----------------------
# Inference: create fused embedding (same pipeline notebook used)
# -----------------------
def compute_fused_embedding(title: str, description: str, tags: str, thumbnail_url: str):
    """Return the fused text+image embedding for one item as a list of floats.

    Text fields are concatenated and encoded by the text backbone; the
    thumbnail is fetched and encoded by the image backbone; both are
    projected and fused through the trained head's cross-attention layer
    (query = text, key/value = image), mirroring the training pipeline.

    Args:
        title, description, tags: free-text metadata (any may be None/empty).
        thumbnail_url: thumbnail URL; a gray placeholder is used on fetch failure.

    Returns:
        list[float] of length proj_dim — JSON-serializable fused embedding.
    """
    # Build text and image inputs
    text = " ".join([str(title or ""), str(description or ""), str(tags or "")]).strip()
    texts = [text]

    # Text embedding (batch of 1); text_to_embedding handles no_grad internally.
    t_emb = text_to_embedding(tokenizer, text_model, texts)  # shape (1, text_dim)

    # Image embedding: preprocess and forward
    img = load_image_from_url(thumbnail_url)
    img_tensor = img_transform(img).unsqueeze(0).to(DEVICE)  # (1, 3, H, W)

    # BUGFIX: the head projections and attention previously ran *outside* any
    # no_grad context, building a useless autograd graph at inference time.
    with torch.no_grad():
        i_emb = img_model(img_tensor)  # expected shape (1, img_dim)

        # Project, fuse via head's fusion layer (exactly as in notebook)
        t_proj = multimodal_head.text_proj(t_emb)  # (1, proj_dim)
        i_proj = multimodal_head.img_proj(i_emb)   # (1, proj_dim)

        # MultiheadAttention expects (batch, seq, dim) because batch_first=True
        attn_out, _ = multimodal_head.fusion_layer(
            query=t_proj.unsqueeze(1),  # (1, 1, proj_dim)
            key=i_proj.unsqueeze(1),    # (1, 1, proj_dim)
            value=i_proj.unsqueeze(1)   # (1, 1, proj_dim)
        )
        fused = attn_out.squeeze(1)  # (1, proj_dim)

    # (1, proj_dim) -> plain list of proj_dim floats
    return fused.squeeze(0).cpu().numpy().tolist()
169
+
170
# -----------------------
# FastAPI + Gradio integration
# -----------------------
# Bare FastAPI application; JSON endpoint is registered below and the
# Gradio UI is mounted onto it at "/" at the bottom of the file.
app = FastAPI()
174
+
175
@app.post("/api/get_embedding")
async def api_get_embedding(request: Request):
    """POST JSON {title, description, tags, thumbnail_url} -> {"embedding": [...]}.

    All fields are optional and default to empty strings. Returns HTTP 400
    for a malformed or non-object JSON body, and HTTP 500 (with the error
    message) if embedding computation fails.
    """
    try:
        payload = await request.json()
    except Exception:
        # Previously a malformed body bubbled up as an unhandled server error.
        return JSONResponse({"error": "invalid JSON body"}, status_code=400)
    if not isinstance(payload, dict):
        return JSONResponse({"error": "JSON body must be an object"}, status_code=400)

    title = payload.get("title", "")
    description = payload.get("description", "")
    tags = payload.get("tags", "")
    thumbnail_url = payload.get("thumbnail_url", "")

    try:
        emb = compute_fused_embedding(title, description, tags, thumbnail_url)
    except Exception as e:
        # API boundary: report the failure to the client instead of crashing.
        return JSONResponse({"error": str(e)}, status_code=500)

    return JSONResponse({"embedding": emb})
189
+
190
# Gradio UI for quick testing (truncated embedding shown)
def gradio_fn(title, description, tags, thumbnail_url):
    """UI wrapper: compute the fused embedding and render a short preview string."""
    try:
        emb = compute_fused_embedding(title, description, tags, thumbnail_url)
    except Exception as e:
        return f"Error: {e}"
    return f"embedding (len={len(emb)}): {emb[:10]} ... (truncated)"
197
+
198
# Quick-test UI: four text inputs, one truncated-embedding text output.
_input_widgets = [
    gr.Textbox(label="Title", lines=1),
    gr.Textbox(label="Description", lines=3),
    gr.Textbox(label="Tags", lines=1),
    gr.Textbox(label="Thumbnail URL", lines=1),
]

gr_interface = gr.Interface(
    fn=gradio_fn,
    inputs=_input_widgets,
    outputs=gr.Textbox(label="Embedding (truncated)"),
    title="Multimodal Embedding (Notebook -> Space)",
    description="Provide title, description, tags and thumbnail URL. Returns fused multimodal embedding (vector).",
    examples=[
        ["Cute cat", "A cat doing flips", "cat,funny", "https://example.com/sample.jpg"]
    ],
)
213
+
214
# Mount Gradio app at root
# (rebinding `app` keeps the FastAPI routes above and serves the UI at "/")
app = gr.mount_gradio_app(app, gr_interface, path="/")