# add stream3r
# source commit: 9d31508 (brian4dwell)
import os
import sys
import torch
import argparse
from tqdm import tqdm
from accelerate import PartialState
from stream3r.models.stream3r import STream3R
from stream3r.stream_session import StreamSession
from stream3r.dust3r.utils.image import load_images_for_eval as load_images
from stream3r.dust3r.utils.device import collate_with_cat
from stream3r.models.components.utils.pose_enc import pose_encoding_to_extri_intri
from stream3r.dust3r.utils.geometry import inv
from stream3r.utils.utils import ImgDust3r2Stream3r
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from eval.relpose.metadata import dataset_metadata
from eval.relpose.utils import *
# Allow TF32 for CUDA matmuls: faster on Ampere+ GPUs at slightly reduced
# precision (acceptable for evaluation).
torch.backends.cuda.matmul.allow_tf32 = True
# avoid high cpu usage: pin every CPU math backend (OpenMP, MKL, NumExpr,
# OpenBLAS) and torch itself to a single thread.
# NOTE(review): these env vars are set after `import torch`; libraries that
# read them at import time may have already sized their thread pools —
# confirm this still has the intended effect.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
torch.set_num_threads(1)
# ===========================================
def get_args_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--device",
type=str,
default="cuda",
help="pytorch device")
parser.add_argument(
"--output_dir",
type=str,
default="",
help="value for outdir",
)
parser.add_argument("--no_crop",
type=bool,
default=True,
help="whether to crop input data")
parser.add_argument(
"--eval_dataset",
type=str,
default="sintel",
choices=list(dataset_metadata.keys()),
)
parser.add_argument("--size", type=int, default="224")
parser.add_argument("--pose_eval_stride",
default=1,
type=int,
help="stride for pose evaluation")
parser.add_argument("--shuffle", action="store_true", default=False)
parser.add_argument(
"--full_seq",
action="store_true",
default=False,
help="use full sequence for pose evaluation",
)
parser.add_argument(
"--seq_list",
nargs="+",
default=None,
help="list of sequences for pose evaluation",
)
parser.add_argument("--freeze_state", action="store_true", default=False)
return parser
def eval_pose_estimation_dist(args,
                              model,
                              img_path,
                              save_dir=None,
                              mask_path=None):
    """Evaluate camera-pose estimation over sequences, split across processes.

    Each process handles its slice of the sequence list, streams the frames
    through the model, saves predicted TUM trajectories, compares against
    ground truth when available, and appends per-sequence metrics to its own
    log file. The main process then concatenates the per-process logs and
    records the dataset-wide averages.

    Args:
        args: parsed CLI namespace (eval_dataset, seq_list, full_seq,
            pose_eval_stride, output_dir, ...).
        model: STream3R model used for streaming inference.
        img_path: root directory containing per-sequence image folders.
        save_dir: output directory; defaults to ``args.output_dir``.
        mask_path: optional root for per-sequence masks.

    Returns:
        Tuple ``(avg_ate, avg_rpe_trans, avg_rpe_rot)`` averaged over all
        evaluated sequences.
    """
    metadata = dataset_metadata.get(args.eval_dataset)
    anno_path = metadata.get("anno_path", None)

    # Resolve the sequence list: explicit --seq_list wins, then the dataset's
    # default list; full-sequence datasets enumerate every subdirectory.
    seq_list = args.seq_list
    if seq_list is None:
        if metadata.get("full_seq", False):
            args.full_seq = True
        else:
            seq_list = metadata.get("seq_list", [])
    if args.full_seq:
        seq_list = os.listdir(img_path)
        seq_list = [
            seq for seq in seq_list
            if os.path.isdir(os.path.join(img_path, seq))
        ]
    seq_list = sorted(seq_list)

    if save_dir is None:
        save_dir = args.output_dir

    distributed_state = PartialState()
    model.to(distributed_state.device)
    device = distributed_state.device

    with distributed_state.split_between_processes(seq_list) as seqs:
        ate_list = []
        rpe_trans_list = []
        rpe_rot_list = []
        error_log_path = f"{save_dir}/_error_log_{distributed_state.process_index}.txt"  # Unique log file per process
        for seq in tqdm(seqs):
            try:
                dir_path = metadata["dir_path_func"](img_path, seq)
                # Handle skip_condition
                skip_condition = metadata.get("skip_condition", None)
                if skip_condition is not None and skip_condition(
                        save_dir, seq):
                    continue

                mask_path_seq_func = metadata.get("mask_path_seq_func",
                                                  lambda mask_path, seq: None)
                # NOTE(review): mask_path_seq is computed but never used
                # below — confirm whether masking was meant to be applied.
                mask_path_seq = mask_path_seq_func(mask_path, seq)

                filelist = [
                    os.path.join(dir_path, name)
                    for name in os.listdir(dir_path)
                ]
                filelist.sort()
                filelist = filelist[::args.pose_eval_stride]

                images = load_images(
                    filelist,
                    size=518,
                    verbose=True,
                    crop=False,
                    patch_size=14,
                )
                images = collate_with_cat([tuple(images)])
                images = torch.stack([view["img"] for view in images], dim=1)
                images = ImgDust3r2Stream3r(images).to(device)

                with torch.no_grad():
                    # Stream frames one at a time; the session accumulates
                    # state, so the final `predictions` covers all frames.
                    session = StreamSession(model, mode="causal")
                    for frame_idx in range(images.shape[1]):
                        image = images[:, frame_idx:frame_idx + 1]
                        predictions = session.forward_stream(image)

                extrinsic, _ = pose_encoding_to_extri_intri(
                    predictions["pose_enc"], predictions["images"].shape[-2:])
                # Convert each 3x4 extrinsic into a 4x4 matrix and invert to
                # obtain camera-to-world poses.
                pr_poses = []
                for view_idx in range(extrinsic.shape[1]):
                    pr_poses.append(
                        inv(torch.cat([
                            extrinsic[0, view_idx],
                            torch.tensor([[0, 0, 0, 1]], device=device)
                        ], dim=0)))
                pred_traj = get_tum_poses(pr_poses)

                os.makedirs(f"{save_dir}/{seq}", exist_ok=True)
                save_tum_poses(pr_poses, f"{save_dir}/{seq}/pred_traj.txt")

                gt_traj_file = metadata["gt_traj_func"](img_path, anno_path,
                                                        seq)
                traj_format = metadata.get("traj_format", None)
                if args.eval_dataset == "sintel":
                    gt_traj = load_traj(gt_traj_file=gt_traj_file,
                                        stride=args.pose_eval_stride)
                elif traj_format is not None:
                    gt_traj = load_traj(
                        gt_traj_file=gt_traj_file,
                        traj_format=traj_format,
                        stride=args.pose_eval_stride,
                    )
                else:
                    gt_traj = None

                if gt_traj is not None:
                    ate, rpe_trans, rpe_rot = eval_metrics(
                        pred_traj,
                        gt_traj,
                        seq=seq,
                        filename=f"{save_dir}/{seq}_eval_metric.txt",
                    )
                    plot_trajectory(pred_traj,
                                    gt_traj,
                                    title=seq,
                                    filename=f"{save_dir}/{seq}.png")
                else:
                    # No ground truth available: record zeros so the sequence
                    # still shows up in the per-process log.
                    ate, rpe_trans, rpe_rot = 0, 0, 0

                ate_list.append(ate)
                rpe_trans_list.append(rpe_trans)
                rpe_rot_list.append(rpe_rot)

                # Write to error log after each sequence
                with open(error_log_path, "a") as f:
                    f.write(
                        f"{args.eval_dataset}-{seq: <16} | ATE: {ate:.5f}, RPE trans: {rpe_trans:.5f}, RPE rot: {rpe_rot:.5f}\n"
                    )
                    f.write(f"{ate:.5f}\n")
                    f.write(f"{rpe_trans:.5f}\n")
                    f.write(f"{rpe_rot:.5f}\n")
            except Exception as e:
                if "out of memory" in str(e):
                    # Handle OOM
                    torch.cuda.empty_cache()  # Clear the CUDA memory
                    with open(error_log_path, "a") as f:
                        f.write(
                            f"OOM error in sequence {seq}, skipping this sequence.\n"
                        )
                    print(f"OOM error in sequence {seq}, skipping...")
                elif "Degenerate covariance rank" in str(
                        e) or "Eigenvalues did not converge" in str(e):
                    # Handle Degenerate covariance rank exception and Eigenvalues did not converge exception
                    with open(error_log_path, "a") as f:
                        f.write(f"Exception in sequence {seq}: {str(e)}\n")
                    print(
                        f"Traj evaluation error in sequence {seq}, skipping.")
                else:
                    raise e  # Rethrow if it's not an expected exception

    distributed_state.wait_for_everyone()

    # Every process re-reads the metric files; only rank 0 writes the summary.
    results = process_directory(save_dir)
    avg_ate, avg_rpe_trans, avg_rpe_rot = calculate_averages(results)

    # Write the averages to the error log (only on the main process)
    if distributed_state.is_main_process:
        with open(f"{save_dir}/_error_log.txt", "a") as f:
            # Copy the error log from each process to the main error log
            for i in range(distributed_state.num_processes):
                if not os.path.exists(f"{save_dir}/_error_log_{i}.txt"):
                    break
                with open(f"{save_dir}/_error_log_{i}.txt", "r") as f_sub:
                    f.write(f_sub.read())
            f.write(
                f"Average ATE: {avg_ate:.5f}, Average RPE trans: {avg_rpe_trans:.5f}, Average RPE rot: {avg_rpe_rot:.5f}\n"
            )
    return avg_ate, avg_rpe_trans, avg_rpe_rot
def eval_pose_estimation(args, model, save_dir=None):
    """Run pose-estimation evaluation for the dataset named by ``args.eval_dataset``.

    Looks up the dataset's image and mask roots in ``dataset_metadata`` and
    delegates the actual work to :func:`eval_pose_estimation_dist`.

    Returns:
        Tuple ``(ate_mean, rpe_trans_mean, rpe_rot_mean)``.
    """
    meta = dataset_metadata.get(args.eval_dataset)
    return eval_pose_estimation_dist(
        args,
        model,
        img_path=meta["img_path"],
        mask_path=meta["mask_path"],
        save_dir=save_dir,
    )
def main():
    """Entry point: parse CLI options, load STream3R, and evaluate poses."""
    args = get_args_parser().parse_args()
    # These two flags are forced off here regardless of the CLI values.
    args.full_seq = False
    args.no_crop = False

    model = STream3R.from_pretrained("yslan/STream3R")
    model = model.to(args.device)
    model.eval()

    eval_pose_estimation(args, model, save_dir=args.output_dir)


if __name__ == "__main__":
    main()