|
|
import os |
|
|
import sys |
|
|
|
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) |
|
|
import math |
|
|
import cv2 |
|
|
import numpy as np |
|
|
import torch |
|
|
import argparse |
|
|
|
|
|
from copy import deepcopy |
|
|
from eval.relpose.metadata import dataset_metadata |
|
|
from eval.relpose.utils import * |
|
|
|
|
|
from accelerate import PartialState |
|
|
from add_ckpt_path import add_path_to_dust3r |
|
|
|
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
def get_args_parser():
    """Build the command-line argument parser for relative-pose evaluation.

    Returns:
        argparse.ArgumentParser: parser exposing model-weight, dataset,
        output, and evaluation-loop options.
    """

    def _str2bool(value):
        # BUGFIX: the original used `type=bool`, which argparse applies to the
        # raw string, so any non-empty value -- including "False" -- parsed as
        # True. Parse common textual booleans explicitly instead.
        if isinstance(value, bool):
            return value
        return value.strip().lower() not in ("false", "0", "no", "")

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--weights",
        type=str,
        help="path to the model weights",
        default="",
    )

    parser.add_argument("--device", type=str, default="cuda", help="pytorch device")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="",
        help="value for outdir",
    )
    parser.add_argument(
        "--no_crop", type=_str2bool, default=True, help="whether to crop input data"
    )

    parser.add_argument(
        "--eval_dataset",
        type=str,
        default="sintel",
        choices=list(dataset_metadata.keys()),
    )
    # Image load size (was a string default "224"; use a real int).
    parser.add_argument("--size", type=int, default=224)

    parser.add_argument(
        "--pose_eval_stride", default=1, type=int, help="stride for pose evaluation"
    )
    parser.add_argument("--shuffle", action="store_true", default=False)
    parser.add_argument(
        "--full_seq",
        action="store_true",
        default=False,
        help="use full sequence for pose evaluation",
    )
    parser.add_argument(
        "--seq_list",
        nargs="+",
        default=None,
        help="list of sequences for pose evaluation",
    )

    # Number of passes over each sequence; passes after the first can be
    # frozen with --freeze_state.
    parser.add_argument("--revisit", type=int, default=1)
    parser.add_argument("--freeze_state", action="store_true", default=False)
    parser.add_argument("--solve_pose", action="store_true", default=False)
    return parser
|
|
|
|
|
|
|
|
def eval_pose_estimation(args, model, save_dir=None):
    """Evaluate camera-pose estimation on the dataset named by ``args.eval_dataset``.

    Looks up the dataset's image and mask roots in ``dataset_metadata`` and
    delegates the (possibly multi-process) evaluation to
    ``eval_pose_estimation_dist``.

    Returns:
        Tuple of (mean ATE, mean RPE translation, mean RPE rotation).
    """
    meta = dataset_metadata.get(args.eval_dataset)
    return eval_pose_estimation_dist(
        args,
        model,
        save_dir=save_dir,
        img_path=meta["img_path"],
        mask_path=meta["mask_path"],
    )
|
|
|
|
|
|
|
|
def eval_pose_estimation_dist(args, model, img_path, save_dir=None, mask_path=None):
    """Run pose evaluation over sequences, split across accelerate processes.

    For each sequence this loads frames, runs model inference, recovers
    camera poses and intrinsics, saves predictions, and -- when ground truth
    is available -- computes ATE / RPE metrics. Per-process metric logs are
    merged by the main process at the end.

    Args:
        args: parsed CLI namespace (see ``get_args_parser``).
        model: network passed to ``dust3r.inference.inference``.
        img_path: root directory holding one sub-directory per sequence.
        save_dir: output directory; defaults to ``args.output_dir``.
        mask_path: optional dataset-specific mask root.

    Returns:
        (avg_ate, avg_rpe_trans, avg_rpe_rot) averaged over processed results.
    """
    from dust3r.inference import inference

    metadata = dataset_metadata.get(args.eval_dataset)
    anno_path = metadata.get("anno_path", None)

    # Resolve the sequences to evaluate: explicit --seq_list wins; otherwise
    # fall back to the dataset metadata, optionally expanding to every
    # sub-directory of img_path when full_seq is in effect.
    seq_list = args.seq_list
    if seq_list is None:
        if metadata.get("full_seq", False):
            args.full_seq = True
        else:
            seq_list = metadata.get("seq_list", [])
    if args.full_seq:
        seq_list = os.listdir(img_path)
        seq_list = [
            seq for seq in seq_list if os.path.isdir(os.path.join(img_path, seq))
        ]
    seq_list = sorted(seq_list)

    if save_dir is None:
        save_dir = args.output_dir

    # Each accelerate process gets a disjoint slice of the sequences.
    distributed_state = PartialState()
    model.to(distributed_state.device)
    device = distributed_state.device

    with distributed_state.split_between_processes(seq_list) as seqs:
        ate_list = []
        rpe_trans_list = []
        rpe_rot_list = []
        load_img_size = args.size
        # Per-process log file; merged into _error_log.txt by the main process.
        error_log_path = f"{save_dir}/_error_log_{distributed_state.process_index}.txt"
        # Set when any sequence lacked ground truth (zeros recorded instead).
        bug = False
        for seq in tqdm(seqs):
            try:
                dir_path = metadata["dir_path_func"](img_path, seq)

                # Datasets may define a predicate to skip already-processed
                # or unsupported sequences.
                skip_condition = metadata.get("skip_condition", None)
                if skip_condition is not None and skip_condition(save_dir, seq):
                    continue

                mask_path_seq_func = metadata.get(
                    "mask_path_seq_func", lambda mask_path, seq: None
                )
                mask_path_seq = mask_path_seq_func(mask_path, seq)

                filelist = [
                    os.path.join(dir_path, name) for name in os.listdir(dir_path)
                ]
                filelist.sort()
                filelist = filelist[:: args.pose_eval_stride]

                # Build the per-view input dicts (all views are images here).
                views = prepare_input(
                    filelist,
                    [True for _ in filelist],
                    size=load_img_size,
                    crop=not args.no_crop,
                    revisit=args.revisit,
                    update=not args.freeze_state,
                )
                outputs, _ = inference(views, model, device)

                (
                    colors,
                    pts3ds_self,
                    pts3ds_other,
                    conf_self,
                    conf_other,
                    cam_dict,
                    pr_poses,
                ) = prepare_output(
                    outputs, revisit=args.revisit, solve_pose=args.solve_pose
                )

                # Persist predictions: TUM-format trajectory plus intrinsics.
                pred_traj = get_tum_poses(pr_poses)
                os.makedirs(f"{save_dir}/{seq}", exist_ok=True)
                save_tum_poses(pr_poses, f"{save_dir}/{seq}/pred_traj.txt")
                save_focals(cam_dict, f"{save_dir}/{seq}/pred_focal.txt")
                save_intrinsics(cam_dict, f"{save_dir}/{seq}/pred_intrinsics.txt")

                gt_traj_file = metadata["gt_traj_func"](img_path, anno_path, seq)
                traj_format = metadata.get("traj_format", None)

                # Sintel has its own loader signature; other datasets need a
                # declared traj_format, otherwise no ground truth is loaded.
                if args.eval_dataset == "sintel":
                    gt_traj = load_traj(
                        gt_traj_file=gt_traj_file, stride=args.pose_eval_stride
                    )
                elif traj_format is not None:
                    gt_traj = load_traj(
                        gt_traj_file=gt_traj_file,
                        traj_format=traj_format,
                        stride=args.pose_eval_stride,
                    )
                else:
                    gt_traj = None

                if gt_traj is not None:
                    ate, rpe_trans, rpe_rot = eval_metrics(
                        pred_traj,
                        gt_traj,
                        seq=seq,
                        filename=f"{save_dir}/{seq}_eval_metric.txt",
                    )
                    plot_trajectory(
                        pred_traj, gt_traj, title=seq, filename=f"{save_dir}/{seq}.png"
                    )
                else:
                    # No ground truth: record zeros and flag the run.
                    ate, rpe_trans, rpe_rot = 0, 0, 0
                    bug = True

                ate_list.append(ate)
                rpe_trans_list.append(rpe_trans)
                rpe_rot_list.append(rpe_rot)

                # Append this sequence's metrics to the per-process log.
                with open(error_log_path, "a") as f:
                    f.write(
                        f"{args.eval_dataset}-{seq: <16} | ATE: {ate:.5f}, RPE trans: {rpe_trans:.5f}, RPE rot: {rpe_rot:.5f}\n"
                    )
                    f.write(f"{ate:.5f}\n")
                    f.write(f"{rpe_trans:.5f}\n")
                    f.write(f"{rpe_rot:.5f}\n")

            except Exception as e:
                # Recoverable failures (OOM, degenerate trajectory alignment)
                # are logged and the sequence skipped; anything else re-raises.
                if "out of memory" in str(e):
                    torch.cuda.empty_cache()
                    with open(error_log_path, "a") as f:
                        f.write(
                            f"OOM error in sequence {seq}, skipping this sequence.\n"
                        )
                    print(f"OOM error in sequence {seq}, skipping...")
                elif "Degenerate covariance rank" in str(
                    e
                ) or "Eigenvalues did not converge" in str(e):
                    # Numerical failures inside trajectory metric evaluation.
                    with open(error_log_path, "a") as f:
                        f.write(f"Exception in sequence {seq}: {str(e)}\n")
                    print(f"Traj evaluation error in sequence {seq}, skipping.")
                else:
                    raise e

    distributed_state.wait_for_everyone()

    # Aggregate the per-sequence metric files written under save_dir.
    results = process_directory(save_dir)
    avg_ate, avg_rpe_trans, avg_rpe_rot = calculate_averages(results)

    if distributed_state.is_main_process:
        with open(f"{save_dir}/_error_log.txt", "a") as f:
            # NOTE(review): this `break` stops merging at the first missing
            # per-process log instead of skipping it -- confirm every process
            # is guaranteed to create its log file.
            for i in range(distributed_state.num_processes):
                if not os.path.exists(f"{save_dir}/_error_log_{i}.txt"):
                    break
                with open(f"{save_dir}/_error_log_{i}.txt", "r") as f_sub:
                    f.write(f_sub.read())
            f.write(
                f"Average ATE: {avg_ate:.5f}, Average RPE trans: {avg_rpe_trans:.5f}, Average RPE rot: {avg_rpe_rot:.5f}\n"
            )

    return avg_ate, avg_rpe_trans, avg_rpe_rot
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Parse CLI options, then make the dust3r package importable (based on
    # the checkpoint location) BEFORE importing anything from it.
    args = get_args_parser().parse_args()
    add_path_to_dust3r(args.weights)

    from dust3r.utils.image import load_images_for_eval as load_images
    from dust3r.post_process import estimate_focal_knowing_depth
    from dust3r.model import ARCroco3DStereo
    from dust3r.utils.camera import pose_encoding_to_camera
    from dust3r.utils.geometry import weighted_procrustes, geotrf

    # Force cropping on and full-sequence expansion off for this entry point,
    # regardless of the command line.
    args.no_crop = False
    args.full_seq = False
|
|
def recover_cam_params(pts3ds_self, pts3ds_other, conf_self, conf_other): |
|
|
B, H, W, _ = pts3ds_self.shape |
|
|
pp = ( |
|
|
torch.tensor([W // 2, H // 2], device=pts3ds_self.device) |
|
|
.float() |
|
|
.repeat(B, 1) |
|
|
.reshape(B, 1, 2) |
|
|
) |
|
|
focal = estimate_focal_knowing_depth(pts3ds_self, pp, focal_mode="weiszfeld") |
|
|
|
|
|
pts3ds_self = pts3ds_self.reshape(B, -1, 3) |
|
|
pts3ds_other = pts3ds_other.reshape(B, -1, 3) |
|
|
conf_self = conf_self.reshape(B, -1) |
|
|
conf_other = conf_other.reshape(B, -1) |
|
|
|
|
|
c2w = weighted_procrustes( |
|
|
pts3ds_self, |
|
|
pts3ds_other, |
|
|
torch.log(conf_self) * torch.log(conf_other), |
|
|
use_weights=True, |
|
|
return_T=True, |
|
|
) |
|
|
return c2w, focal, pp.reshape(B, 2) |
|
|
|
|
|
    def prepare_input(
        img_paths,
        img_mask,
        size,
        raymaps=None,
        raymap_mask=None,
        revisit=1,
        update=True,
        crop=True,
    ):
        """Load images and assemble the per-view dicts fed to inference.

        Args:
            img_paths: list of image file paths to load.
            img_mask: per-view booleans marking which views are real images.
            size: target image size passed to the loader.
            raymaps: optional list of ray-map tensors for non-image views.
            raymap_mask: per-view booleans marking which views are ray maps.
            revisit: number of passes over the sequence; >1 repeats all views.
            update: when False, repeat passes are marked as non-updating.
            crop: whether to crop images at load time.

        Returns:
            List of view dicts (img, ray_map, true_shape, idx/instance,
            camera_pose placeholder, and boolean mask/flag tensors).
        """
        images = load_images(img_paths, size=size, crop=crop)
        views = []
        if raymaps is None and raymap_mask is None:
            # Image-only input: every view is a real image; the ray map is an
            # all-NaN 6-channel placeholder matching the image's spatial size.
            num_views = len(images)

            for i in range(num_views):
                view = {
                    "img": images[i]["img"],
                    "ray_map": torch.full(
                        (
                            images[i]["img"].shape[0],
                            6,
                            images[i]["img"].shape[-2],
                            images[i]["img"].shape[-1],
                        ),
                        torch.nan,
                    ),
                    "true_shape": torch.from_numpy(images[i]["true_shape"]),
                    "idx": i,
                    "instance": str(i),
                    # Identity placeholder; poses are predicted, not provided.
                    "camera_pose": torch.from_numpy(
                        np.eye(4).astype(np.float32)
                    ).unsqueeze(0),
                    "img_mask": torch.tensor(True).unsqueeze(0),
                    "ray_mask": torch.tensor(False).unsqueeze(0),
                    "update": torch.tensor(True).unsqueeze(0),
                    "reset": torch.tensor(False).unsqueeze(0),
                }
                views.append(view)
        else:
            # Mixed input: interleave image views and ray-map views according
            # to the two masks, which must jointly cover all views exactly.
            num_views = len(images) + len(raymaps)
            assert len(img_mask) == len(raymap_mask) == num_views
            assert sum(img_mask) == len(images) and sum(raymap_mask) == len(raymaps)

            j = 0  # index of the next unconsumed image
            k = 0  # index of the next unconsumed ray map
            for i in range(num_views):
                view = {
                    "img": (
                        images[j]["img"]
                        if img_mask[i]
                        else torch.full_like(images[0]["img"], torch.nan)
                    ),
                    "ray_map": (
                        raymaps[k]
                        if raymap_mask[i]
                        else torch.full_like(raymaps[0], torch.nan)
                    ),
                    "true_shape": (
                        torch.from_numpy(images[j]["true_shape"])
                        if img_mask[i]
                        # Derive (H, W) from the ray map's spatial dims.
                        else torch.from_numpy(np.int32([raymaps[k].shape[1:-1][::-1]]))
                    ),
                    "idx": i,
                    "instance": str(i),
                    "camera_pose": torch.from_numpy(
                        np.eye(4).astype(np.float32)
                    ).unsqueeze(0),
                    "img_mask": torch.tensor(img_mask[i]).unsqueeze(0),
                    "ray_mask": torch.tensor(raymap_mask[i]).unsqueeze(0),
                    "update": torch.tensor(img_mask[i]).unsqueeze(0),
                    "reset": torch.tensor(False).unsqueeze(0),
                }
                if img_mask[i]:
                    j += 1
                if raymap_mask[i]:
                    k += 1
                views.append(view)
            assert j == len(images) and k == len(raymaps)

        if revisit > 1:
            # Repeat the whole sequence `revisit` times with fresh running
            # indices; passes after the first can be frozen (update=False).
            new_views = []
            for r in range(revisit):
                for i in range(len(views)):
                    new_view = deepcopy(views[i])
                    new_view["idx"] = r * len(views) + i
                    new_view["instance"] = str(r * len(views) + i)
                    if r > 0:
                        if not update:
                            new_view["update"] = torch.tensor(False).unsqueeze(0)
                    new_views.append(new_view)
            return new_views
        return views
|
|
|
|
|
    def prepare_output(outputs, revisit=1, solve_pose=False):
        """Extract point maps, confidences, camera parameters and poses.

        Keeps only the final pass over the sequence when it was revisited.
        Poses come either from a weighted-Procrustes solve on the point maps
        (``solve_pose=True``) or directly from the network's pose head.

        Returns:
            (colors, pts3ds_self, pts3ds_other, conf_self, conf_other,
             cam_dict, pr_poses)
        """
        # When revisiting, drop all but the last pass over the sequence.
        valid_length = len(outputs["pred"]) // revisit
        outputs["pred"] = outputs["pred"][-valid_length:]
        outputs["views"] = outputs["views"][-valid_length:]

        if solve_pose:
            # Solve poses/intrinsics geometrically from the two point maps.
            pts3ds_self = [
                output["pts3d_in_self_view"].cpu() for output in outputs["pred"]
            ]
            pts3ds_other = [
                output["pts3d_in_other_view"].cpu() for output in outputs["pred"]
            ]
            conf_self = [output["conf_self"].cpu() for output in outputs["pred"]]
            conf_other = [output["conf"].cpu() for output in outputs["pred"]]
            pr_poses, focal, pp = recover_cam_params(
                torch.cat(pts3ds_self, 0),
                torch.cat(pts3ds_other, 0),
                torch.cat(conf_self, 0),
                torch.cat(conf_other, 0),
            )
            pts3ds_self = torch.cat(pts3ds_self, 0)
        else:
            # Use the network's predicted camera poses directly.
            pts3ds_self = [
                output["pts3d_in_self_view"].cpu() for output in outputs["pred"]
            ]
            pts3ds_other = [
                output["pts3d_in_other_view"].cpu() for output in outputs["pred"]
            ]
            conf_self = [output["conf_self"].cpu() for output in outputs["pred"]]
            conf_other = [output["conf"].cpu() for output in outputs["pred"]]
            pts3ds_self = torch.cat(pts3ds_self, 0)
            pr_poses = [
                pose_encoding_to_camera(pred["camera_pose"].clone()).cpu()
                for pred in outputs["pred"]
            ]
            pr_poses = torch.cat(pr_poses, 0)

            # Principal point at the image center; focal via Weiszfeld
            # estimation on the self-view point maps.
            B, H, W, _ = pts3ds_self.shape
            pp = (
                torch.tensor([W // 2, H // 2], device=pts3ds_self.device)
                .float()
                .repeat(B, 1)
                .reshape(B, 2)
            )
            focal = estimate_focal_knowing_depth(
                pts3ds_self, pp, focal_mode="weiszfeld"
            )

        # Map predicted RGB from [-1, 1] back to [0, 1] for visualization.
        colors = [0.5 * (output["rgb"][0] + 1.0) for output in outputs["pred"]]
        cam_dict = {
            "focal": focal.cpu().numpy(),
            "pp": pp.cpu().numpy(),
        }
        return (
            colors,
            pts3ds_self,
            pts3ds_other,
            conf_self,
            conf_other,
            cam_dict,
            pr_poses,
        )
|
|
|
|
|
    # Instantiate the model from the given checkpoint and run the evaluation,
    # writing all results under args.output_dir.
    model = ARCroco3DStereo.from_pretrained(args.weights)
    eval_pose_estimation(args, model, save_dir=args.output_dir)
|
|
|