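# Video-depth evaluation for STream3R: stream each benchmark sequence through the
# model in causal mode and dump per-frame depth maps (colorized PNGs plus raw .npy
# arrays) for downstream metric computation.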
import os
import sys
import numpy as np
import torch
import argparse
from accelerate import PartialState
from tqdm import tqdm
from PIL import Image
import imageio.v2 as iio

from stream3r.models.stream3r import STream3R
from stream3r.stream_session import StreamSession
from stream3r.dust3r.utils.image import load_images_for_eval as load_images
from stream3r.dust3r.utils.device import collate_with_cat
from stream3r.utils.utils import ImgDust3r2Stream3r

sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from eval.video_depth.metadata import dataset_metadata
from eval.video_depth.utils import colorize

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cuda.matmul.allow_tf32 = True

# avoid high cpu usage
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
torch.set_num_threads(1)

# ===========================================
def save_depth_maps(pts3ds_self, path, conf_self=None, depth_maps=None):
    """Colorize per-frame depth (and optionally confidence) maps and save them as PNG + .npy."""
    if depth_maps is None:
        # Depth is the z-component of the per-frame self-view point maps.
        depth_maps = torch.stack([pts3d_self[..., -1] for pts3d_self in pts3ds_self], 0)
    min_depth = depth_maps.min()  # float(torch.quantile(out, 0.01))
    max_depth = depth_maps.max()  # float(torch.quantile(out, 0.99))
    colored_depth = colorize(
        depth_maps,
        cmap_name="Spectral_r",
        range=(min_depth, max_depth),
        append_cbar=True,
    )
    images = []
    if conf_self is not None:
        conf_selfs = torch.concat(conf_self, 0)
        min_conf = torch.log(conf_selfs.min())  # float(torch.quantile(out, 0.01))
        max_conf = torch.log(conf_selfs.max())  # float(torch.quantile(out, 0.99))
        colored_conf = colorize(
            torch.log(conf_selfs),
            cmap_name="jet",
            range=(min_conf, max_conf),
            append_cbar=True,
        )
    for i, depth_map in enumerate(colored_depth):
        # Save the colorized depth map, with the confidence map appended when available
        img_path = f"{path}/frame_{i:04d}.png"
        if conf_self is None:
            to_save = (depth_map * 255).detach().cpu().numpy().astype(np.uint8)
        else:
            to_save = torch.cat([depth_map, colored_conf[i]], dim=1)
            to_save = (to_save * 255).detach().cpu().numpy().astype(np.uint8)
        iio.imwrite(img_path, to_save)
        images.append(Image.open(img_path))
        # Also keep the raw depth values for metric evaluation
        np.save(f"{path}/frame_{i:04d}.npy", depth_maps[i].detach().cpu().numpy())
    return depth_maps


def get_args_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="cuda", help="pytorch device")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="",
        help="directory to write depth maps and logs to",
    )
    parser.add_argument(
        "--no_crop", type=bool, default=True, help="if set, do not crop the input images"
    )
    parser.add_argument(
        "--eval_dataset",
        type=str,
        default="sintel",
        choices=list(dataset_metadata.keys()),
    )
    parser.add_argument("--size", type=int, default=512)
    parser.add_argument(
        "--pose_eval_stride", default=1, type=int, help="stride for pose evaluation"
    )
    parser.add_argument(
        "--full_seq",
        action="store_true",
        default=False,
        help="use full sequence for pose evaluation",
    )
    parser.add_argument(
        "--seq_list",
        nargs="+",
        default=None,
        help="list of sequences for pose evaluation",
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default="",
        help="path to the checkpoint directory",
    )
    return parser


def eval_pose_estimation(args, model, save_dir=None):
    metadata = dataset_metadata.get(args.eval_dataset)
    img_path = metadata["img_path"]
    mask_path = metadata["mask_path"]
    ate_mean, rpe_trans_mean, rpe_rot_mean = eval_pose_estimation_dist(
        args, model, save_dir=save_dir, img_path=img_path, mask_path=mask_path
    )
    return ate_mean, rpe_trans_mean, rpe_rot_mean


def eval_pose_estimation_dist(args, model, img_path, save_dir=None, mask_path=None):
    metadata = dataset_metadata.get(args.eval_dataset)
    model.eval()

    seq_list = args.seq_list
    if seq_list is None:
        if metadata.get("full_seq", False):
            args.full_seq = True
        else:
            seq_list = metadata.get("seq_list", [])
        if args.full_seq:
            seq_list = os.listdir(img_path)
            seq_list = [
                seq for seq in seq_list
                if os.path.isdir(os.path.join(img_path, seq))
            ]
        seq_list = sorted(seq_list)

    if save_dir is None:
        save_dir = args.output_dir

    distributed_state = PartialState()
    model.to(distributed_state.device)
    device = distributed_state.device

    with distributed_state.split_between_processes(seq_list) as seqs:
        error_log_path = f"{save_dir}/_error_log_{distributed_state.process_index}.txt"  # unique log file per process
        for seq in tqdm(seqs):
            try:
                dir_path = metadata["dir_path_func"](img_path, seq)

                # Handle skip_condition
                skip_condition = metadata.get("skip_condition", None)
                if skip_condition is not None and skip_condition(save_dir, seq):
                    continue

                mask_path_seq_func = metadata.get(
                    "mask_path_seq_func", lambda mask_path, seq: None
                )
                mask_path_seq = mask_path_seq_func(mask_path, seq)

                filelist = [
                    os.path.join(dir_path, name) for name in os.listdir(dir_path)
                ]
                filelist.sort()
                filelist = filelist[::args.pose_eval_stride]

                images = load_images(
                    filelist,
                    size=518,
                    verbose=True,
                    crop=False,
                    patch_size=14,
                )
                images = collate_with_cat([tuple(images)])
                images = torch.stack([view["img"] for view in images], dim=1)
                images = ImgDust3r2Stream3r(images).to(device)

                with torch.no_grad():
                    # Feed the sequence frame by frame through the causal streaming session.
                    session = StreamSession(model, mode="causal")
                    for i in range(images.shape[1]):
                        image = images[:, i:i + 1]
                        predictions = session.forward_stream(image)

                print(f"Finished depth estimation of {len(filelist)} images")

                os.makedirs(f"{save_dir}/{seq}", exist_ok=True)
                save_depth_maps(
                    None,
                    f"{save_dir}/{seq}",
                    conf_self=None,
                    depth_maps=predictions["depth"].squeeze().cpu(),
                )
            except Exception as e:
                if "out of memory" in str(e):
                    # Handle OOM
                    torch.cuda.empty_cache()  # clear the CUDA memory
                    with open(error_log_path, "a") as f:
                        f.write(f"OOM error in sequence {seq}, skipping this sequence.\n")
                    print(f"OOM error in sequence {seq}, skipping...")
                elif "Degenerate covariance rank" in str(e) or "Eigenvalues did not converge" in str(e):
                    # Handle degenerate covariance rank / eigenvalue convergence failures from trajectory evaluation
                    with open(error_log_path, "a") as f:
                        f.write(f"Exception in sequence {seq}: {str(e)}\n")
                    print(f"Traj evaluation error in sequence {seq}, skipping.")
                else:
                    raise e  # rethrow if it's not an expected exception

    return None, None, None


def main():
    args = get_args_parser()
    args = args.parse_args()

    if args.eval_dataset == "sintel":
        args.full_seq = True
    else:
        args.full_seq = False
    args.no_crop = True

    model = STream3R.from_pretrained("yslan/STream3R").to(args.device)
    model.eval()

    eval_pose_estimation(args, model, save_dir=args.output_dir)


if __name__ == "__main__":
    main()
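
# Example invocation (illustrative only: the script path and output directory below
# are placeholders, not paths confirmed by the repo). Since the script uses
# accelerate's PartialState to split sequences across processes, it can be started
# with `accelerate launch`:
#
#   accelerate launch eval/video_depth/launch.py \
#       --eval_dataset sintel \
#       --output_dir ./results/sintel_video_depth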