Spaces:
Runtime error
Runtime error
Commit ·
9d22a21
1
Parent(s): d423f20
Add automatic weight download
Browse files
model.py
CHANGED
|
@@ -1,11 +1,31 @@
|
|
| 1 |
import torch
|
| 2 |
import logging
|
|
|
|
|
|
|
| 3 |
from dust3r.model import AsymmetricCroCo3DStereo
|
| 4 |
|
| 5 |
logger = logging.getLogger(__name__)
|
| 6 |
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
def initialize(model_path: str, device: str) -> torch.nn.Module:
|
|
|
|
| 9 |
logger.info(f"Loading model from: {model_path}")
|
| 10 |
|
| 11 |
logger.info("Loading checkpoint...")
|
|
|
|
| 1 |
import torch
|
| 2 |
import logging
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
from dust3r.model import AsymmetricCroCo3DStereo
|
| 6 |
|
| 7 |
logger = logging.getLogger(__name__)
|
| 8 |
|
| 9 |
|
| 10 |
+
def download_weights(model_path: str):
|
| 11 |
+
"""Download model weights if they don't exist"""
|
| 12 |
+
if os.path.exists(model_path):
|
| 13 |
+
logger.info(f"Weights already exist at {model_path}")
|
| 14 |
+
return
|
| 15 |
+
|
| 16 |
+
logger.info("Weights not found. Downloading...")
|
| 17 |
+
Path(model_path).parent.mkdir(parents=True, exist_ok=True)
|
| 18 |
+
|
| 19 |
+
import urllib.request
|
| 20 |
+
url = "https://huggingface.co/camenduru/dust3r/resolve/main/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"
|
| 21 |
+
|
| 22 |
+
logger.info(f"Downloading from {url}")
|
| 23 |
+
urllib.request.urlretrieve(url, model_path)
|
| 24 |
+
logger.info("Download complete!")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
def initialize(model_path: str, device: str) -> torch.nn.Module:
|
| 28 |
+
download_weights(model_path)
|
| 29 |
logger.info(f"Loading model from: {model_path}")
|
| 30 |
|
| 31 |
logger.info("Loading checkpoint...")
|