import torch
import logging
import os
from pathlib import Path

# NOTE: these names look unused, but they must stay in module scope:
# `initialize()` calls eval() on a constructor expression read from the
# checkpoint, which presumably references `AsymmetricCroCo3DStereo` (and
# possibly `inf`) — confirm against the checkpoint's `args.model` string.
from dust3r.model import AsymmetricCroCo3DStereo, inf  # noqa: F401

logger = logging.getLogger(__name__)

# Default checkpoint fetched by `download_weights` when none exists locally.
DEFAULT_WEIGHTS_URL = (
    "https://huggingface.co/camenduru/dust3r/resolve/main/"
    "DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"
)


def download_weights(model_path: str, url: str = DEFAULT_WEIGHTS_URL) -> None:
    """Download model weights to ``model_path`` if nothing exists there yet.

    The data is first written to a sibling ``<name>.part`` file and only moved
    to ``model_path`` once the download has finished. An interrupted download
    therefore never leaves a truncated checkpoint that later calls would
    mistake for a complete one.

    Args:
        model_path: Destination path of the checkpoint file. Missing parent
            directories are created.
        url: Source URL. Defaults to ``DEFAULT_WEIGHTS_URL``.

    Raises:
        Whatever ``urllib.request.urlretrieve`` / ``os.replace`` raise; the
        partial file is removed before the exception propagates.
    """
    if os.path.exists(model_path):
        logger.info("Weights already exist at %s", model_path)
        return

    logger.info("Weights not found. Downloading...")
    destination = Path(model_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    # Same directory as the destination, so the final os.replace is a rename
    # within one filesystem rather than a cross-device copy.
    partial = destination.with_name(destination.name + ".part")

    import urllib.request

    logger.info("Downloading from %s", url)
    try:
        urllib.request.urlretrieve(url, str(partial))
        os.replace(partial, destination)
    except BaseException:
        # Also covers KeyboardInterrupt: never leave a half-written file behind.
        partial.unlink(missing_ok=True)
        raise
    logger.info("Download complete!")


def _patch_model_args(model_args):
    """Rewrite the checkpoint's model-constructor expression before eval().

    Replaces ``ManyAR_PatchEmbed`` with ``PatchEmbedDust3R`` and forces
    ``landscape_only=False``, appending the keyword when it is absent.
    """
    model_args = model_args.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
    if isinstance(model_args, str) and 'landscape_only' not in model_args:
        # Assumes the expression ends with the constructor's closing ")":
        # strip it and append the keyword argument.
        model_args = model_args[:-1] + ', landscape_only=False)'
    elif isinstance(model_args, str):
        # Spaces are stripped so that e.g. "landscape_only = True" also matches.
        model_args = model_args.replace(" ", "").replace(
            'landscape_only=True', 'landscape_only=False'
        )
    return model_args


def initialize(model_path: str, device: str) -> torch.nn.Module:
    """Build the model from a checkpoint and move it to ``device``.

    Downloads the checkpoint first if ``model_path`` does not exist.

    Args:
        model_path: Path of the checkpoint file. It must contain an ``'args'``
            entry whose ``.model`` attribute is a constructor expression, and a
            ``'model'`` entry holding the state dict.
        device: Target device passed to ``Module.to`` (e.g. ``"cpu"``).

    Returns:
        The instantiated module with weights loaded, placed on ``device``.
    """
    download_weights(model_path)

    logger.info("Loading model from: %s", model_path)
    logger.info("Loading checkpoint...")
    # SECURITY: weights_only=False unpickles arbitrary objects, and eval()
    # below executes a string taken from the checkpoint. Only load trusted
    # checkpoints.
    ckpt = torch.load(model_path, map_location='cpu', weights_only=False)

    logger.info("Parsing model arguments...")
    model_args = _patch_model_args(ckpt['args'].model)

    logger.info("Instantiating model...")
    net = eval(model_args)

    logger.info("Loading model weights...")
    load_result = net.load_state_dict(ckpt['model'], strict=False)
    # strict=False tolerates mismatches; surface them instead of hiding them.
    missing = getattr(load_result, "missing_keys", None)
    unexpected = getattr(load_result, "unexpected_keys", None)
    if missing:
        logger.warning("Missing keys in checkpoint: %s", missing)
    if unexpected:
        logger.warning("Unexpected keys in checkpoint: %s", unexpected)
    # The weights now live in `net`; release the checkpoint's CPU copy before
    # the (possibly large) device transfer.
    del ckpt

    logger.info("Moving model to %s...", device)
    model = net.to(device)
    logger.info("Model initialization complete!")
    return model