MeshMax committed on
Commit
408b528
·
verified ·
1 Parent(s): a379734

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -18
app.py CHANGED
@@ -13,7 +13,7 @@ from transformers import AutoTokenizer, AutoModel
13
  from torchvision import transforms
14
 
15
  # --- Config ---
16
- MODEL_URL = "https://drive.google.com/uc?export=download&id=10Y_HLjflL54H7iwP1oz1ZG1SV4SsK6Qw"
17
  MODEL_FILENAME = "finetuned_multimodal.pt"
18
  TEXT_MODEL = "sentence-transformers/LaBSE"
19
  IMG_MODEL = "vit_base_patch16_224"
@@ -21,20 +21,20 @@ IMG_SIZE = 224
21
  MAX_LENGTH = 512
22
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
23
 
24
- # --- Download model from Google Drive ---
25
  if not os.path.exists(MODEL_FILENAME):
26
- print(f"Downloading checkpoint from {MODEL_URL} ...")
27
- r = requests.get(MODEL_URL, stream=True)
28
- r.raise_for_status()
29
  with open(MODEL_FILENAME, "wb") as f:
30
- for chunk in r.iter_content(chunk_size=8192):
31
  if chunk:
32
  f.write(chunk)
33
- print("Download complete.")
34
  else:
35
- print("Checkpoint already exists locally.")
36
 
37
- # --- Define model ---
38
  class MultimodalRegressor(nn.Module):
39
  def __init__(self, text_dim=768, img_dim=768, proj_dim=768):
40
  super().__init__()
@@ -57,7 +57,7 @@ class MultimodalRegressor(nn.Module):
57
  fused = self.dropout(fused)
58
  return self.regressor(fused).squeeze(1)
59
 
60
- # --- Load models ---
61
  tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL)
62
  text_model = AutoModel.from_pretrained(TEXT_MODEL).to(DEVICE)
63
  img_model = timm.create_model(IMG_MODEL, pretrained=False, num_classes=0).to(DEVICE)
@@ -85,33 +85,45 @@ def compute_embedding(title, description, tags, thumbnail_url):
85
  text = f"{title} {description} {tags}"
86
  toks = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=MAX_LENGTH).to(DEVICE)
87
  with torch.no_grad():
88
- text_emb = text_model(**toks).pooler_output
 
 
 
 
 
 
89
  try:
90
- img = Image.open(BytesIO(requests.get(thumbnail_url).content)).convert("RGB")
 
91
  except Exception:
92
  img = Image.new("RGB", (IMG_SIZE, IMG_SIZE), color=(128, 128, 128))
 
93
  img_tensor = transform(img).unsqueeze(0).to(DEVICE)
94
  with torch.no_grad():
95
  img_emb = img_model(img_tensor)
96
  t_proj = head.text_proj(text_emb)
97
  i_proj = head.img_proj(img_emb)
98
- attn_out, _ = head.fusion_layer(query=t_proj.unsqueeze(1), key=i_proj.unsqueeze(1), value=i_proj.unsqueeze(1))
 
 
99
  fused = attn_out.squeeze(1)
100
  return fused.squeeze(0).cpu().numpy().tolist()
101
 
102
- # --- FastAPI + Gradio ---
103
  app = FastAPI()
104
 
105
  @app.post("/api/get_embedding")
106
  async def api_get_embedding(request: Request):
107
  data = await request.json()
108
- emb = compute_embedding(data.get("title", ""), data.get("description", ""),
109
- data.get("tags", ""), data.get("thumbnail_url", ""))
 
 
110
  return JSONResponse({"embedding": emb})
111
 
112
  def gradio_fn(title, description, tags, thumbnail_url):
113
  emb = compute_embedding(title, description, tags, thumbnail_url)
114
- return f"Embedding length {len(emb)}; first 10: {emb[:10]}"
115
 
116
  gr_interface = gr.Interface(
117
  fn=gradio_fn,
@@ -119,7 +131,7 @@ gr_interface = gr.Interface(
119
  gr.Textbox(label="Tags"), gr.Textbox(label="Thumbnail URL")],
120
  outputs="text",
121
  title="Video Embedding Generator",
122
- description="Generates fused multimodal embeddings from video metadata and thumbnail."
123
  )
124
 
125
  app = gr.mount_gradio_app(app, gr_interface, path="/")
 
13
  from torchvision import transforms
14
 
15
  # --- Config ---
16
+ MODEL_URL = "https://huggingface.co/MeshMax/video_tower/resolve/main/finetuned_multimodal.pt"
17
  MODEL_FILENAME = "finetuned_multimodal.pt"
18
  TEXT_MODEL = "sentence-transformers/LaBSE"
19
  IMG_MODEL = "vit_base_patch16_224"
 
21
  MAX_LENGTH = 512
22
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
23
 
24
# --- Download checkpoint if not already present ---
# Stream the checkpoint to a temporary ".part" file and atomically rename it
# into place. The previous version wrote directly to MODEL_FILENAME, so an
# interrupted download left a truncated file that the existence check below
# would mistake for a complete checkpoint on every subsequent start.
if not os.path.exists(MODEL_FILENAME):
    print(f"Downloading model from {MODEL_URL} ...")
    tmp_path = MODEL_FILENAME + ".part"
    try:
        # (connect, read) timeout so a hung connection cannot stall startup forever
        with requests.get(MODEL_URL, stream=True, timeout=(10, 300)) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # filter out keep-alive chunks
                        f.write(chunk)
        os.replace(tmp_path, MODEL_FILENAME)  # atomic rename: file appears complete or not at all
        print("Download done.")
    finally:
        # remove the partial file if the download failed mid-stream
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
else:
    print("Model file already exists.")
36
 
37
+ # --- Model definition (same as before) ---
38
  class MultimodalRegressor(nn.Module):
39
  def __init__(self, text_dim=768, img_dim=768, proj_dim=768):
40
  super().__init__()
 
57
  fused = self.dropout(fused)
58
  return self.regressor(fused).squeeze(1)
59
 
60
+ # --- Load backbone models + head ---
61
  tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL)
62
  text_model = AutoModel.from_pretrained(TEXT_MODEL).to(DEVICE)
63
  img_model = timm.create_model(IMG_MODEL, pretrained=False, num_classes=0).to(DEVICE)
 
85
  text = f"{title} {description} {tags}"
86
  toks = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=MAX_LENGTH).to(DEVICE)
87
  with torch.no_grad():
88
+ # Using pooler_output or fallback
89
+ out = text_model(**toks)
90
+ if hasattr(out, "pooler_output") and out.pooler_output is not None:
91
+ text_emb = out.pooler_output
92
+ else:
93
+ text_emb = out.last_hidden_state.mean(dim=1)
94
+
95
  try:
96
+ img_resp = requests.get(thumbnail_url, timeout=5)
97
+ img = Image.open(BytesIO(img_resp.content)).convert("RGB")
98
  except Exception:
99
  img = Image.new("RGB", (IMG_SIZE, IMG_SIZE), color=(128, 128, 128))
100
+
101
  img_tensor = transform(img).unsqueeze(0).to(DEVICE)
102
  with torch.no_grad():
103
  img_emb = img_model(img_tensor)
104
  t_proj = head.text_proj(text_emb)
105
  i_proj = head.img_proj(img_emb)
106
+ attn_out, _ = head.fusion_layer(
107
+ query=t_proj.unsqueeze(1), key=i_proj.unsqueeze(1), value=i_proj.unsqueeze(1)
108
+ )
109
  fused = attn_out.squeeze(1)
110
  return fused.squeeze(0).cpu().numpy().tolist()
111
 
112
+ # --- FastAPI + Gradio integration ---
113
  app = FastAPI()
114
 
115
@app.post("/api/get_embedding")
async def api_get_embedding(request: Request):
    """JSON API endpoint: return the fused multimodal embedding for one video.

    The request body is a JSON object; the keys title, description, tags and
    thumbnail_url are all optional and default to "" when absent. The response
    is a JSON object with the embedding under the key "embedding".
    """
    payload = await request.json()
    fields = ("title", "description", "tags", "thumbnail_url")
    title, description, tags, thumbnail_url = (payload.get(k, "") for k in fields)
    emb = compute_embedding(title, description, tags, thumbnail_url)
    return JSONResponse({"embedding": emb})
123
 
124
def gradio_fn(title, description, tags, thumbnail_url):
    """Gradio wrapper: compute the fused embedding and summarize it as text."""
    embedding = compute_embedding(title, description, tags, thumbnail_url)
    summary = f"Embedding length={len(embedding)}; first 10: {embedding[:10]}"
    return summary
127
 
128
  gr_interface = gr.Interface(
129
  fn=gradio_fn,
 
131
  gr.Textbox(label="Tags"), gr.Textbox(label="Thumbnail URL")],
132
  outputs="text",
133
  title="Video Embedding Generator",
134
+ description="Generates fused multimodal embeddings from video metadata",
135
  )
136
 
137
# Serve the Gradio UI at the root path on top of the FastAPI app, so the
# JSON endpoint above and the interactive demo share one server.
app = gr.mount_gradio_app(app, gr_interface, path="/")