# NOTE: the hosting page's header text ("Spaces: Configuration error") was captured here during extraction; it is not part of the script.
| import os | |
| import sys | |
| import torch | |
| import argparse | |
| from tqdm import tqdm | |
| from accelerate import PartialState | |
| from stream3r.models.stream3r import STream3R | |
| from stream3r.stream_session import StreamSession | |
| from stream3r.dust3r.utils.image import load_images_for_eval as load_images | |
| from stream3r.dust3r.utils.device import collate_with_cat | |
| from stream3r.models.components.utils.pose_enc import pose_encoding_to_extri_intri | |
| from stream3r.dust3r.utils.geometry import inv | |
| from stream3r.utils.utils import ImgDust3r2Stream3r | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) | |
| from eval.relpose.metadata import dataset_metadata | |
| from eval.relpose.utils import * | |
# Allow TF32 for CUDA matrix multiplications.
torch.backends.cuda.matmul.allow_tf32 = True

# Avoid high CPU usage: cap every numeric backend's thread pool at one thread.
# NOTE(review): these variables are set after torch has been imported, so
# libraries that read them only at load time may ignore them -- confirm.
# torch's own thread count is pinned explicitly on the last line.
os.environ.update({
    "OMP_NUM_THREADS": "1",
    "MKL_NUM_THREADS": "1",
    "NUMEXPR_NUM_THREADS": "1",
    "OPENBLAS_NUM_THREADS": "1",
})
torch.set_num_threads(1)
| # =========================================== | |
def get_args_parser():
    """Build the command-line parser for the pose-evaluation script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--device``,
        ``--output_dir``, ``--no_crop``, ``--eval_dataset``, ``--size``,
        ``--pose_eval_stride``, ``--shuffle``, ``--full_seq``, ``--seq_list``
        and ``--freeze_state``. ``--eval_dataset`` is restricted to the keys
        of ``dataset_metadata``.
    """

    def _str2bool(value):
        """Convert a textual flag value to ``bool``.

        ``type=bool`` is not usable with argparse: ``bool("False")`` is
        ``True`` because any non-empty string is truthy.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        # The empty string stays falsy, matching what bool("") returned.
        if lowered in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError(
            f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--device",
                        type=str,
                        default="cuda",
                        help="pytorch device")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="",
        help="value for outdir",
    )
    parser.add_argument("--no_crop",
                        type=_str2bool,
                        default=True,
                        help="whether to crop input data")
    parser.add_argument(
        "--eval_dataset",
        type=str,
        default="sintel",
        choices=list(dataset_metadata.keys()),
    )
    # An int default avoids relying on argparse converting a string default.
    parser.add_argument("--size", type=int, default=224)
    parser.add_argument("--pose_eval_stride",
                        default=1,
                        type=int,
                        help="stride for pose evaluation")
    parser.add_argument("--shuffle", action="store_true", default=False)
    parser.add_argument(
        "--full_seq",
        action="store_true",
        default=False,
        help="use full sequence for pose evaluation",
    )
    parser.add_argument(
        "--seq_list",
        nargs="+",
        default=None,
        help="list of sequences for pose evaluation",
    )
    parser.add_argument("--freeze_state", action="store_true", default=False)
    return parser
def eval_pose_estimation_dist(args,
                              model,
                              img_path,
                              save_dir=None,
                              mask_path=None):
    """Run streaming pose estimation per sequence and score the trajectories.

    The sequence list is split between processes with
    ``PartialState.split_between_processes``. For each sequence the frames
    are fed one at a time through a causal ``StreamSession``; the predicted
    poses are saved in TUM format and, when a ground-truth trajectory can be
    loaded, compared against it.

    Args:
        args: parsed CLI namespace. Reads ``eval_dataset``, ``seq_list``,
            ``full_seq``, ``pose_eval_stride`` and ``output_dir``.
            ``args.full_seq`` is set to ``True`` when the dataset metadata
            requests full-sequence evaluation.
        model: network handed to ``StreamSession``; moved to this process's
            device.
        img_path: dataset root that contains one directory per sequence.
        save_dir: output directory; falls back to ``args.output_dir``.
        mask_path: forwarded to the metadata's ``mask_path_seq_func``.

    Returns:
        tuple: ``(avg_ate, avg_rpe_trans, avg_rpe_rot)`` as produced by
        ``calculate_averages(process_directory(save_dir))``.

    Raises:
        Exception: any per-sequence error other than an out-of-memory error
            or a trajectory-alignment failure ("Degenerate covariance rank",
            "Eigenvalues did not converge") is re-raised.
    """
    metadata = dataset_metadata.get(args.eval_dataset)
    anno_path = metadata.get("anno_path", None)

    seq_list = args.seq_list
    if seq_list is None:
        if metadata.get("full_seq", False):
            args.full_seq = True
        else:
            seq_list = metadata.get("seq_list", [])
        if args.full_seq:
            # Every sub-directory of the dataset root counts as a sequence.
            seq_list = [
                seq for seq in os.listdir(img_path)
                if os.path.isdir(os.path.join(img_path, seq))
            ]
        seq_list = sorted(seq_list)

    if save_dir is None:
        save_dir = args.output_dir
    if save_dir:
        # Create the output root up front so the per-process log can be
        # written even when the very first sequence fails before its own
        # sub-directory has been created.
        os.makedirs(save_dir, exist_ok=True)

    distributed_state = PartialState()
    model.to(distributed_state.device)
    device = distributed_state.device

    with distributed_state.split_between_processes(seq_list) as seqs:
        # Unique log file per process
        error_log_path = f"{save_dir}/_error_log_{distributed_state.process_index}.txt"

        for seq in tqdm(seqs):
            try:
                dir_path = metadata["dir_path_func"](img_path, seq)

                # Handle skip_condition
                skip_condition = metadata.get("skip_condition", None)
                if skip_condition is not None and skip_condition(
                        save_dir, seq):
                    continue

                mask_path_seq_func = metadata.get(
                    "mask_path_seq_func", lambda mask_path, seq: None)
                # NOTE(review): the resolved mask path is not consumed
                # anywhere below; the call is kept as in the original.
                mask_path_seq = mask_path_seq_func(mask_path, seq)

                filelist = sorted(
                    os.path.join(dir_path, name)
                    for name in os.listdir(dir_path))
                filelist = filelist[::args.pose_eval_stride]

                images = load_images(
                    filelist,
                    size=518,
                    verbose=True,
                    crop=False,
                    patch_size=14,
                )
                images = collate_with_cat([tuple(images)])
                images = torch.stack([view["img"] for view in images], dim=1)
                images = ImgDust3r2Stream3r(images).to(device)

                with torch.no_grad():
                    session = StreamSession(model, mode="causal")
                    # Feed frames one by one; only the predictions returned
                    # by the last call are used below.
                    for frame_idx in range(images.shape[1]):
                        image = images[:, frame_idx:frame_idx + 1]
                        predictions = session.forward_stream(image)

                    extrinsic, _ = pose_encoding_to_extri_intri(
                        predictions["pose_enc"],
                        predictions["images"].shape[-2:])

                    # Append the row [0, 0, 0, 1] to each extrinsic and
                    # invert the result.
                    bottom_row = torch.tensor([[0, 0, 0, 1]], device=device)
                    pr_poses = [
                        inv(torch.cat([extrinsic[0, pose_idx], bottom_row],
                                      dim=0))
                        for pose_idx in range(extrinsic.shape[1])
                    ]

                pred_traj = get_tum_poses(pr_poses)
                os.makedirs(f"{save_dir}/{seq}", exist_ok=True)
                save_tum_poses(pr_poses, f"{save_dir}/{seq}/pred_traj.txt")

                gt_traj_file = metadata["gt_traj_func"](img_path, anno_path,
                                                        seq)
                traj_format = metadata.get("traj_format", None)

                if args.eval_dataset == "sintel":
                    gt_traj = load_traj(gt_traj_file=gt_traj_file,
                                        stride=args.pose_eval_stride)
                elif traj_format is not None:
                    gt_traj = load_traj(
                        gt_traj_file=gt_traj_file,
                        traj_format=traj_format,
                        stride=args.pose_eval_stride,
                    )
                else:
                    gt_traj = None

                if gt_traj is not None:
                    ate, rpe_trans, rpe_rot = eval_metrics(
                        pred_traj,
                        gt_traj,
                        seq=seq,
                        filename=f"{save_dir}/{seq}_eval_metric.txt",
                    )
                    plot_trajectory(pred_traj,
                                    gt_traj,
                                    title=seq,
                                    filename=f"{save_dir}/{seq}.png")
                else:
                    # No ground truth available: log zeros for this sequence.
                    ate, rpe_trans, rpe_rot = 0, 0, 0

                # Write to error log after each sequence
                with open(error_log_path, "a") as f:
                    f.write(
                        f"{args.eval_dataset}-{seq: <16} | ATE: {ate:.5f}, RPE trans: {rpe_trans:.5f}, RPE rot: {rpe_rot:.5f}\n"
                    )
                    f.write(f"{ate:.5f}\n")
                    f.write(f"{rpe_trans:.5f}\n")
                    f.write(f"{rpe_rot:.5f}\n")

            except Exception as e:
                message = str(e)
                if "out of memory" in message:
                    # Handle OOM
                    torch.cuda.empty_cache()  # Clear the CUDA memory
                    with open(error_log_path, "a") as f:
                        f.write(
                            f"OOM error in sequence {seq}, skipping this sequence.\n"
                        )
                    print(f"OOM error in sequence {seq}, skipping...")
                elif ("Degenerate covariance rank" in message
                      or "Eigenvalues did not converge" in message):
                    # Trajectory alignment failed; record it and move on.
                    with open(error_log_path, "a") as f:
                        f.write(f"Exception in sequence {seq}: {message}\n")
                    print(
                        f"Traj evaluation error in sequence {seq}, skipping.")
                else:
                    # Not an expected failure: re-raise with the original
                    # traceback intact.
                    raise

    distributed_state.wait_for_everyone()

    results = process_directory(save_dir)
    avg_ate, avg_rpe_trans, avg_rpe_rot = calculate_averages(results)

    # Write the averages to the error log (only on the main process)
    if distributed_state.is_main_process:
        with open(f"{save_dir}/_error_log.txt", "a") as f:
            # Copy the error log from each process to the main error log
            for rank in range(distributed_state.num_processes):
                sub_log_path = f"{save_dir}/_error_log_{rank}.txt"
                if not os.path.exists(sub_log_path):
                    # A process that logged nothing has no file; keep going
                    # so the logs of later processes are still merged.
                    continue
                with open(sub_log_path, "r") as f_sub:
                    f.write(f_sub.read())
            f.write(
                f"Average ATE: {avg_ate:.5f}, Average RPE trans: {avg_rpe_trans:.5f}, Average RPE rot: {avg_rpe_rot:.5f}\n"
            )

    return avg_ate, avg_rpe_trans, avg_rpe_rot
def eval_pose_estimation(args, model, save_dir=None):
    """Evaluate pose estimation on the dataset named by ``args.eval_dataset``.

    Looks up the dataset's ``img_path`` and ``mask_path`` in
    ``dataset_metadata`` and delegates to ``eval_pose_estimation_dist``.

    Returns:
        tuple: the three mean metrics returned by
        ``eval_pose_estimation_dist`` (ATE, RPE translation, RPE rotation).
    """
    dataset_cfg = dataset_metadata.get(args.eval_dataset)
    image_root = dataset_cfg["img_path"]
    mask_root = dataset_cfg["mask_path"]

    mean_ate, mean_rpe_trans, mean_rpe_rot = eval_pose_estimation_dist(
        args,
        model,
        image_root,
        save_dir=save_dir,
        mask_path=mask_root,
    )
    return mean_ate, mean_rpe_trans, mean_rpe_rot
def main():
    """Parse CLI arguments, load the pretrained model and run the evaluation."""
    cli_args = get_args_parser().parse_args()

    # These two options are forced off here, whatever the command line said.
    cli_args.full_seq = False
    cli_args.no_crop = False

    network = STream3R.from_pretrained("yslan/STream3R")
    network = network.to(cli_args.device)
    network.eval()

    eval_pose_estimation(cli_args, network, save_dir=cli_args.output_dir)


if __name__ == "__main__":
    main()