| |
| |
| |
| from dust3r.training import get_args_parser, train, load_model |
| from dust3r.pose_eval import eval_pose_estimation, pose_estimation_custom |
| from dust3r.depth_eval import eval_mono_depth_estimation |
| import croco.utils.misc as misc |
| import torch |
| import torch.backends.cudnn as cudnn |
| import numpy as np |
| import os |
|
|
def main(pretrained, dir_path, output_dir, use_pred_mask, n_iter) -> None:
    """Run custom pose estimation with programmatic overrides.

    The argument namespace is produced by ``get_args_parser().parse_args()``
    and the five parameters below then overwrite the same-named attributes
    on it before anything else runs.

    NOTE(review): ``parse_args()`` is called without an explicit argument
    list, so it presumably still reads the process command line -- confirm
    this is intended when calling ``main`` from another program.

    Args:
        pretrained: Stored as ``args.pretrained`` (presumably the checkpoint
            consumed by ``load_model`` -- confirm).
        dir_path: Stored as ``args.dir_path``.
        output_dir: Stored as ``args.output_dir``; the directory is created
            if missing and passed to ``pose_estimation_custom`` as
            ``save_dir``.
        use_pred_mask: Stored as ``args.use_pred_mask``.
        n_iter: Stored as ``args.n_iter``.
    """
    args = get_args_parser().parse_args()
    args.pretrained = pretrained
    args.dir_path = dir_path
    args.output_dir = output_dir
    args.use_pred_mask = use_pred_mask
    args.n_iter = n_iter

    misc.init_distributed_mode(args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Offset the base seed by the process rank so each process gets its own
    # seed for both torch and numpy.
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = args.cudnn_benchmark

    model, _ = load_model(args, device)
    os.makedirs(args.output_dir, exist_ok=True)

    pose_estimation_custom(args, model, device, save_dir=args.output_dir)
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: modes starting with 'eval' run one evaluation and
    # exit; every other mode falls through to training.
    args = get_args_parser().parse_args()
    if args.mode.startswith('eval'):
        misc.init_distributed_mode(args)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Offset the base seed by the process rank so each process gets its
        # own seed for both torch and numpy.
        seed = args.seed + misc.get_rank()
        torch.manual_seed(seed)
        np.random.seed(seed)
        cudnn.benchmark = args.cudnn_benchmark

        model, _ = load_model(args, device)
        os.makedirs(args.output_dir, exist_ok=True)

        # The modes are mutually exclusive string values, so at most one
        # branch runs. An unrecognised 'eval*' mode runs nothing and still
        # exits with status 0.
        if args.mode == 'eval_pose':
            # Only the three aggregate metrics are reported; the remaining
            # two returned values are not used here.
            ate_mean, rpe_trans_mean, rpe_rot_mean, _outfile_list, _bug = eval_pose_estimation(
                args, model, device, save_dir=args.output_dir
            )
            print(f'ATE mean: {ate_mean}, RPE trans mean: {rpe_trans_mean}, RPE rot mean: {rpe_rot_mean}')
        elif args.mode == 'eval_pose_custom':
            pose_estimation_custom(args, model, device, save_dir=args.output_dir)
        elif args.mode == 'eval_depth':
            eval_mono_depth_estimation(args, model, device)

        # Stop here so evaluation never continues into training. SystemExit
        # is raised directly because the bare exit() helper is installed by
        # the site module and is not guaranteed to exist in every run.
        raise SystemExit(0)

    train(args)
|
|
|
|