MeshMax committed on
Commit
d36dff5
·
verified ·
1 Parent(s): b8fc529

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -22
app.py CHANGED
@@ -1,4 +1,3 @@
1
- # app.py
2
  import os
3
  import torch
4
  import torch.nn as nn
@@ -11,30 +10,16 @@ from fastapi import FastAPI, Request
11
  from fastapi.responses import JSONResponse
12
  from transformers import AutoTokenizer, AutoModel
13
  from torchvision import transforms
 
14
 
15
  # --- Config ---
16
- MODEL_URL = "https://huggingface.co/MeshMax/video_tower/resolve/main/finetuned_multimodal.pt?download=true"
17
- MODEL_FILENAME = "finetuned_multimodal.pt"
18
  TEXT_MODEL = "sentence-transformers/LaBSE"
19
  IMG_MODEL = "vit_base_patch16_224"
20
  IMG_SIZE = 224
21
  MAX_LENGTH = 512
22
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
23
 
24
- # --- Download checkpoint if not already present ---
25
- if not os.path.exists(MODEL_FILENAME):
26
- print(f"Downloading model from {MODEL_URL} ...")
27
- response = requests.get(MODEL_URL, stream=True)
28
- response.raise_for_status()
29
- with open(MODEL_FILENAME, "wb") as f:
30
- for chunk in response.iter_content(chunk_size=8192):
31
- if chunk:
32
- f.write(chunk)
33
- print("Download done.")
34
- else:
35
- print("Model file already exists.")
36
-
37
- # --- Model definition (same as before) ---
38
  class MultimodalRegressor(nn.Module):
39
  def __init__(self, text_dim=768, img_dim=768, proj_dim=768):
40
  super().__init__()
@@ -63,7 +48,20 @@ text_model = AutoModel.from_pretrained(TEXT_MODEL).to(DEVICE)
63
  img_model = timm.create_model(IMG_MODEL, pretrained=False, num_classes=0).to(DEVICE)
64
  head = MultimodalRegressor().to(DEVICE)
65
 
66
- ckpt = torch.load(MODEL_FILENAME, map_location=DEVICE, weights_only=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  if "text_model_state" in ckpt:
68
  text_model.load_state_dict(ckpt["text_model_state"])
69
  if "img_model_state" in ckpt:
@@ -85,7 +83,6 @@ def compute_embedding(title, description, tags, thumbnail_url):
85
  text = f"{title} {description} {tags}"
86
  toks = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=MAX_LENGTH).to(DEVICE)
87
  with torch.no_grad():
88
- # Using pooler_output or fallback
89
  out = text_model(**toks)
90
  if hasattr(out, "pooler_output") and out.pooler_output is not None:
91
  text_emb = out.pooler_output
@@ -94,6 +91,7 @@ def compute_embedding(title, description, tags, thumbnail_url):
94
 
95
  try:
96
  img_resp = requests.get(thumbnail_url, timeout=5)
 
97
  img = Image.open(BytesIO(img_resp.content)).convert("RGB")
98
  except Exception:
99
  img = Image.new("RGB", (IMG_SIZE, IMG_SIZE), color=(128, 128, 128))
@@ -107,9 +105,9 @@ def compute_embedding(title, description, tags, thumbnail_url):
107
  query=t_proj.unsqueeze(1), key=i_proj.unsqueeze(1), value=i_proj.unsqueeze(1)
108
  )
109
  fused = attn_out.squeeze(1)
110
- return fused.squeeze(0).cpu().numpy().tolist()
111
 
112
- # --- FastAPI + Gradio integration ---
113
  app = FastAPI()
114
 
115
  @app.post("/api/get_embedding")
@@ -134,4 +132,4 @@ gr_interface = gr.Interface(
134
  description="Generates fused multimodal embeddings from video metadata",
135
  )
136
 
137
- app = gr.mount_gradio_app(app, gr_interface, path="/")
 
 
1
  import os
2
  import torch
3
  import torch.nn as nn
 
10
  from fastapi.responses import JSONResponse
11
  from transformers import AutoTokenizer, AutoModel
12
  from torchvision import transforms
13
+ from huggingface_hub import hf_hub_download # NEW
14
 
15
  # --- Config ---
 
 
16
  TEXT_MODEL = "sentence-transformers/LaBSE"
17
  IMG_MODEL = "vit_base_patch16_224"
18
  IMG_SIZE = 224
19
  MAX_LENGTH = 512
20
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
21
 
22
+ # --- Model definition (unchanged) ---
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  class MultimodalRegressor(nn.Module):
24
  def __init__(self, text_dim=768, img_dim=768, proj_dim=768):
25
  super().__init__()
 
48
  img_model = timm.create_model(IMG_MODEL, pretrained=False, num_classes=0).to(DEVICE)
49
  head = MultimodalRegressor().to(DEVICE)
50
 
51
# NEW: Dynamic load with cache
def load_model_if_needed(local_dir="/tmp"):
    """Download the finetuned checkpoint from the Hub and return its local path.

    hf_hub_download is idempotent: if the file is already present in the Hub
    cache it is not re-downloaded, so calling this repeatedly is cheap.

    Args:
        local_dir: Directory the checkpoint file is placed in (default: /tmp).

    Returns:
        str: Filesystem path to ``finetuned_multimodal.pt``.
    """
    # NOTE(review): `local_dir_use_symlinks` is deprecated in recent
    # huggingface_hub releases (files are always copied into local_dir),
    # and `cache_dir=None` is the default — both arguments were dropped.
    model_path = hf_hub_download(
        repo_id="MeshMax/video_tower",
        filename="finetuned_multimodal.pt",
        local_dir=local_dir,
    )
    print(f"Model loaded from: {model_path}")
    return model_path
62
+
63
+ model_path = load_model_if_needed()
64
+ ckpt = torch.load(model_path, map_location=DEVICE, weights_only=False)
65
  if "text_model_state" in ckpt:
66
  text_model.load_state_dict(ckpt["text_model_state"])
67
  if "img_model_state" in ckpt:
 
83
  text = f"{title} {description} {tags}"
84
  toks = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=MAX_LENGTH).to(DEVICE)
85
  with torch.no_grad():
 
86
  out = text_model(**toks)
87
  if hasattr(out, "pooler_output") and out.pooler_output is not None:
88
  text_emb = out.pooler_output
 
91
 
92
  try:
93
  img_resp = requests.get(thumbnail_url, timeout=5)
94
+ img_resp.raise_for_status() # IMPROVED: Raise on HTTP errors
95
  img = Image.open(BytesIO(img_resp.content)).convert("RGB")
96
  except Exception:
97
  img = Image.new("RGB", (IMG_SIZE, IMG_SIZE), color=(128, 128, 128))
 
105
  query=t_proj.unsqueeze(1), key=i_proj.unsqueeze(1), value=i_proj.unsqueeze(1)
106
  )
107
  fused = attn_out.squeeze(1)
108
+ return fused.squeeze(0).cpu().numpy().tolist() # Note: This is proj_dim=768, not 1—adjust if regression output
109
 
110
+ # --- FastAPI + Gradio (unchanged) ---
111
  app = FastAPI()
112
 
113
  @app.post("/api/get_embedding")
 
132
  description="Generates fused multimodal embeddings from video metadata",
133
  )
134
 
135
+ app = gr.mount_gradio_app(app, gr_interface, path="/")