Update app.py
Browse files
app.py
CHANGED
|
@@ -50,15 +50,28 @@ head = MultimodalRegressor().to(DEVICE)
|
|
| 50 |
|
| 51 |
# NEW: Dynamic load with cache
|
| 52 |
def load_model_if_needed():
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
model_path = load_model_if_needed()
|
| 64 |
ckpt = torch.load(model_path, map_location=DEVICE, weights_only=False)
|
|
|
|
| 50 |
|
| 51 |
# NEW: Dynamic load with cache
|
| 52 |
def load_model_if_needed():
|
| 53 |
+
try:
|
| 54 |
+
model_path = hf_hub_download(
|
| 55 |
+
repo_id="MeshMax/video_tower",
|
| 56 |
+
filename="finetuned_multimodal.pt",
|
| 57 |
+
local_dir=None, # CHANGED: Use default ~/.cache (persistent, no /tmp)
|
| 58 |
+
local_dir_use_symlinks=False,
|
| 59 |
+
cache_dir=None
|
| 60 |
+
)
|
| 61 |
+
print(f"Model loaded from: {model_path}")
|
| 62 |
+
return model_path
|
| 63 |
+
except Exception as e:
|
| 64 |
+
print(f"Download failed: {e}. Retrying with force_download...")
|
| 65 |
+
# Fallback: Force re-download if cache is corrupted
|
| 66 |
+
model_path = hf_hub_download(
|
| 67 |
+
repo_id="MeshMax/video_tower",
|
| 68 |
+
filename="finetuned_multimodal.pt",
|
| 69 |
+
local_dir=None,
|
| 70 |
+
local_dir_use_symlinks=False,
|
| 71 |
+
cache_dir=None,
|
| 72 |
+
force_download=True # Overwrite if needed
|
| 73 |
+
)
|
| 74 |
+
return model_path
|
| 75 |
|
| 76 |
model_path = load_model_if_needed()
|
| 77 |
ckpt = torch.load(model_path, map_location=DEVICE, weights_only=False)
|