Spaces:
Sleeping
Sleeping
Commit
·
9b8ec7b
1
Parent(s):
a025d2f
Fix downloaded file path
Browse files
streamlit_apps/app_utils/depth_model.py
CHANGED
|
@@ -39,13 +39,16 @@ class DPTDepth(BaseDepthModel):
|
|
| 39 |
weights_path = os.path.join("weights", weights_fname)
|
| 40 |
if not os.path.isfile(weights_path):
|
| 41 |
from huggingface_hub import hf_hub_download
|
| 42 |
-
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
| 44 |
omnidata_ckpt = torch.load(
|
| 45 |
weights_path,
|
| 46 |
map_location="cpu",
|
| 47 |
)
|
| 48 |
-
|
| 49 |
self.model = DPTDepthModel()
|
| 50 |
self.model.load_state_dict(omnidata_ckpt)
|
| 51 |
self.model: DPTDepthModel = self.model.to(device).eval()
|
|
|
|
| 39 |
weights_path = os.path.join("weights", weights_fname)
|
| 40 |
if not os.path.isfile(weights_path):
|
| 41 |
from huggingface_hub import hf_hub_download
|
| 42 |
+
|
| 43 |
+
downloaded_filepath = hf_hub_download(
|
| 44 |
+
repo_id="RGBD-SOD/S-MultiMAE", filename=weights_fname
|
| 45 |
+
)
|
| 46 |
+
os.system(f'mv "{downloaded_filepath}" weights')
|
| 47 |
omnidata_ckpt = torch.load(
|
| 48 |
weights_path,
|
| 49 |
map_location="cpu",
|
| 50 |
)
|
| 51 |
+
|
| 52 |
self.model = DPTDepthModel()
|
| 53 |
self.model.load_state_dict(omnidata_ckpt)
|
| 54 |
self.model: DPTDepthModel = self.model.to(device).eval()
|
streamlit_apps/app_utils/sod_selection_ui.py
CHANGED
|
@@ -29,14 +29,15 @@ def load_smultimae_model(
|
|
| 29 |
cfg = arg_cfg[sod_model_config_key]()
|
| 30 |
|
| 31 |
weights_fname = f"s-multimae-{cfg.experiment_name}-top{top}.pth"
|
| 32 |
-
ckpt_path = os.path.join(
|
| 33 |
-
"weights", weights_fname
|
| 34 |
-
)
|
| 35 |
print(ckpt_path)
|
| 36 |
if not os.path.isfile(ckpt_path):
|
| 37 |
from huggingface_hub import hf_hub_download
|
| 38 |
-
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
| 40 |
assert os.path.isfile(ckpt_path)
|
| 41 |
|
| 42 |
# sod_model = ModelPL.load_from_checkpoint(
|
|
|
|
| 29 |
cfg = arg_cfg[sod_model_config_key]()
|
| 30 |
|
| 31 |
weights_fname = f"s-multimae-{cfg.experiment_name}-top{top}.pth"
|
| 32 |
+
ckpt_path = os.path.join("weights", weights_fname)
|
|
|
|
|
|
|
| 33 |
print(ckpt_path)
|
| 34 |
if not os.path.isfile(ckpt_path):
|
| 35 |
from huggingface_hub import hf_hub_download
|
| 36 |
+
|
| 37 |
+
downloaded_filepath = hf_hub_download(
|
| 38 |
+
repo_id="RGBD-SOD/S-MultiMAE", filename=weights_fname
|
| 39 |
+
)
|
| 40 |
+
os.system(f'mv "{downloaded_filepath}" weights')
|
| 41 |
assert os.path.isfile(ckpt_path)
|
| 42 |
|
| 43 |
# sod_model = ModelPL.load_from_checkpoint(
|