sharifIslam commited on
Commit
9d22a21
·
1 Parent(s): d423f20

Add automatic weight download

Browse files
Files changed (1) hide show
  1. model.py +20 -0
model.py CHANGED
@@ -1,11 +1,31 @@
1
  import torch
2
  import logging
 
 
3
  from dust3r.model import AsymmetricCroCo3DStereo
4
 
5
  logger = logging.getLogger(__name__)
6
 
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  def initialize(model_path: str, device: str) -> torch.nn.Module:
 
9
  logger.info(f"Loading model from: {model_path}")
10
 
11
  logger.info("Loading checkpoint...")
 
1
  import torch
2
  import logging
3
+ import os
4
+ from pathlib import Path
5
  from dust3r.model import AsymmetricCroCo3DStereo
6
 
7
  logger = logging.getLogger(__name__)
8
 
9
 
10
+ def download_weights(model_path: str):
11
+ """Download model weights if they don't exist"""
12
+ if os.path.exists(model_path):
13
+ logger.info(f"Weights already exist at {model_path}")
14
+ return
15
+
16
+ logger.info("Weights not found. Downloading...")
17
+ Path(model_path).parent.mkdir(parents=True, exist_ok=True)
18
+
19
+ import urllib.request
20
+ url = "https://huggingface.co/camenduru/dust3r/resolve/main/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"
21
+
22
+ logger.info(f"Downloading from {url}")
23
+ urllib.request.urlretrieve(url, model_path)
24
+ logger.info("Download complete!")
25
+
26
+
27
  def initialize(model_path: str, device: str) -> torch.nn.Module:
28
+ download_weights(model_path)
29
  logger.info(f"Loading model from: {model_path}")
30
 
31
  logger.info("Loading checkpoint...")