Spaces:
Runtime error
Runtime error
| import torch | |
| import logging | |
| import os | |
| from pathlib import Path | |
| from dust3r.model import AsymmetricCroCo3DStereo, inf | |
| logger = logging.getLogger(__name__) | |
# Default location of the DUSt3R ViT-Large / 512 / DPT checkpoint.
DEFAULT_WEIGHTS_URL = (
    "https://huggingface.co/camenduru/dust3r/resolve/main/"
    "DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"
)


def download_weights(model_path: str, url: str = DEFAULT_WEIGHTS_URL) -> None:
    """Download model weights if they don't exist.

    The download is written to a temporary file in the destination directory
    and only moved to ``model_path`` once it has completed. An interrupted
    download therefore never leaves a truncated checkpoint behind that a later
    call would mistake for a complete one (the existence check below would
    otherwise skip re-downloading it).

    Args:
        model_path: Destination path of the checkpoint file. Missing parent
            directories are created.
        url: Source URL of the checkpoint. Defaults to ``DEFAULT_WEIGHTS_URL``.

    Raises:
        urllib.error.URLError: If the download fails (propagated from
            ``urllib.request.urlretrieve``, including ``ContentTooShortError``).
        OSError: If the destination cannot be created or written.
    """
    import tempfile
    import urllib.request

    if os.path.exists(model_path):
        logger.info(f"Weights already exist at {model_path}")
        return

    logger.info("Weights not found. Downloading...")
    target_dir = Path(model_path).parent
    target_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"Downloading from {url}")
    # Create the temp file in the same directory so os.replace() is an atomic
    # rename on the same filesystem.
    fd, tmp_path = tempfile.mkstemp(dir=str(target_dir), suffix=".part")
    os.close(fd)
    try:
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, model_path)
    except BaseException:
        # Remove the partial download so the next call retries from scratch;
        # BaseException also covers KeyboardInterrupt. Always re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    logger.info("Download complete!")
def _patch_model_args(model_args: str) -> str:
    """Rewrite a checkpoint's model-constructor string for inference.

    Swaps ``ManyAR_PatchEmbed`` for ``PatchEmbedDust3R`` and forces
    ``landscape_only=False``. The keyword is appended as the last constructor
    argument when absent; otherwise an existing ``landscape_only=True`` is
    flipped.

    Raises:
        ValueError: If the keyword must be appended but the string does not
            end with a closing parenthesis (i.e. is not a call expression).
    """
    patched = model_args.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R").rstrip()
    if "landscape_only" not in patched:
        if not patched.endswith(")"):
            raise ValueError(
                f"Unexpected model constructor string in checkpoint: {model_args!r}"
            )
        # Insert the keyword just before the call's closing parenthesis.
        return patched[:-1] + ", landscape_only=False)"
    # Spaces are stripped first so that e.g. 'landscape_only = True' also matches.
    return patched.replace(" ", "").replace("landscape_only=True", "landscape_only=False")


def initialize(model_path: str, device: str) -> torch.nn.Module:
    """Download (if needed), build, and load the DUSt3R model.

    Args:
        model_path: Path of the checkpoint file. It is downloaded first when
            missing (see ``download_weights``).
        device: Target device passed to ``Module.to`` (e.g. ``"cpu"``,
            ``"cuda"``).

    Returns:
        The model, instantiated from the constructor string stored in the
        checkpoint, with the checkpoint's weights loaded and moved to ``device``.

    Raises:
        ValueError: If the checkpoint's constructor string has an unexpected form.
    """
    download_weights(model_path)
    logger.info(f"Loading model from: {model_path}")

    logger.info("Loading checkpoint...")
    # SECURITY: weights_only=False unpickles arbitrary objects; it is needed
    # because ckpt['args'] is a pickled object with a `.model` attribute
    # (presumably an argparse Namespace). Only load checkpoints from a trusted
    # source.
    ckpt = torch.load(model_path, map_location='cpu', weights_only=False)

    logger.info("Parsing model arguments...")
    args = _patch_model_args(ckpt['args'].model)

    logger.info("Instantiating model...")
    # SECURITY: eval() executes the constructor expression stored in the
    # checkpoint (it resolves names such as AsymmetricCroCo3DStereo and inf from
    # this module). Trusted checkpoints only.
    net = eval(args)

    logger.info("Loading model weights...")
    # strict=False tolerates key mismatches between checkpoint and model; surface
    # them instead of discarding the result so silent partial loads are visible.
    incompatible = net.load_state_dict(ckpt['model'], strict=False)
    if incompatible.missing_keys:
        logger.warning(f"Missing keys when loading weights: {incompatible.missing_keys}")
    if incompatible.unexpected_keys:
        logger.warning(f"Unexpected keys when loading weights: {incompatible.unexpected_keys}")
    # load_state_dict copied the tensors into `net`; drop the checkpoint's CPU
    # copy before moving to the device to reduce peak memory.
    del ckpt

    logger.info(f"Moving model to {device}...")
    model = net.to(device)
    logger.info("Model initialization complete!")
    return model