| |
| |
| |
| |
| |
| |
| |
| import os |
| import torch |
| import tempfile |
|
|
| import mast3r.utils.path_to_dust3r |
| from dust3r.model import AsymmetricCroCo3DStereo |
| from mast3r.model import AsymmetricMASt3R |
| from dust3r.demo import get_args_parser as dust3r_get_args_parser |
| from dust3r.demo import main_demo |
|
|
| import matplotlib.pyplot as pl |
# Let CUDA matmuls use TensorFloat-32 where the hardware supports it
# (trades a little float32 precision for throughput, per PyTorch docs).
torch.backends.cuda.matmul.allow_tf32 = True

# Switch matplotlib into interactive mode.
pl.ion()
|
|
|
|
def get_args_parser():
    """Build the MASt3R demo command-line parser from the DUSt3R demo parser.

    All options come from ``dust3r.demo.get_args_parser``. The
    ``model_name`` option additionally accepts the MASt3R checkpoint name,
    and the program name shown in usage/help is set to ``mast3r demo``.

    Safe to call repeatedly: the MASt3R name is only added once, and the
    original ``choices`` container is never mutated in place.

    Returns:
        The extended parser (presumably an ``argparse.ArgumentParser``, as
        returned by the DUSt3R helper).
    """
    mast3r_model_name = 'MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric'
    parser = dust3r_get_args_parser()

    # argparse has no public API for editing an already-registered option,
    # so the private ``_actions`` list is used to locate ``model_name``.
    for action in parser._actions:
        if action.dest != 'model_name':
            continue
        # choices=None means argparse accepts any value: nothing to extend.
        if action.choices is None or mast3r_model_name in action.choices:
            continue
        # Rebind instead of appending in place so a container possibly shared
        # with the DUSt3R parser is left untouched, and tuples also work.
        action.choices = [*action.choices, mast3r_model_name]

    parser.prog = 'mast3r demo'
    return parser
|
|
|
|
if __name__ == '__main__':
    parser = get_args_parser()
    args = parser.parse_args()

    # Optionally redirect all tempfile-created files/dirs to a user-chosen
    # location (module-level default used by the tempfile module).
    if args.tmp_dir is not None:
        tmp_path = args.tmp_dir
        os.makedirs(tmp_path, exist_ok=True)
        tempfile.tempdir = tmp_path

    # An explicit --server_name wins; otherwise bind to all interfaces only
    # when --local_network was requested, else stay on loopback.
    if args.server_name is not None:
        server_name = args.server_name
    else:
        server_name = '0.0.0.0' if args.local_network else '127.0.0.1'

    # Explicit weights path/ID wins; otherwise derive it from the model name
    # under the "naver/" prefix.
    if args.weights is not None:
        weights_path = args.weights
    else:
        weights_path = "naver/" + args.model_name

    # Try loading as a MASt3R checkpoint first and fall back to the DUSt3R
    # architecture. Only the load is guarded: a failure while moving the model
    # to the device must surface as-is rather than trigger a bogus reload.
    try:
        net = AsymmetricMASt3R.from_pretrained(weights_path)
    except Exception as load_err:
        if not args.silent:
            print(f'Could not load {weights_path} as AsymmetricMASt3R '
                  f'({load_err!r}); falling back to AsymmetricCroCo3DStereo')
        net = AsymmetricCroCo3DStereo.from_pretrained(weights_path)
    model = net.to(args.device)

    # The demo writes its outputs into a temporary directory that is removed
    # when the demo returns.
    with tempfile.TemporaryDirectory(suffix='dust3r_gradio_demo') as tmpdirname:
        if not args.silent:
            print('Outputing stuff in', tmpdirname)
        main_demo(tmpdirname, model, args.device, args.image_size, server_name, args.server_port, silent=args.silent)
|
|