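# NOTE: the next three lines are web-page residue (uploader avatar text, commit
# message, short commit id) kept as comments so the module parses.
# brian4dwell's picture
# add stream3r
# 9d31508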
import os
import sys
import numpy as np
import torch
import argparse
from accelerate import PartialState
from tqdm import tqdm
from PIL import Image
import imageio.v2 as iio
from stream3r.models.stream3r import STream3R
from stream3r.stream_session import StreamSession
from stream3r.dust3r.utils.image import load_images_for_eval as load_images
from stream3r.dust3r.utils.device import collate_with_cat
from stream3r.utils.utils import ImgDust3r2Stream3r
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from eval.video_depth.metadata import dataset_metadata
from eval.video_depth.utils import colorize
# Default compute device: CUDA when torch can see a GPU, CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Allow TF32 for CUDA matrix multiplications.
torch.backends.cuda.matmul.allow_tf32 = True

# avoid high cpu usage: cap every numeric backend's thread pool at one thread
os.environ.update(
    {
        "OMP_NUM_THREADS": "1",
        "MKL_NUM_THREADS": "1",
        "NUMEXPR_NUM_THREADS": "1",
        "OPENBLAS_NUM_THREADS": "1",
    }
)
torch.set_num_threads(1)
# ===========================================
def save_depth_maps(pts3ds_self, path, conf_self=None, depth_maps=None):
    """Write colorized depth frames (PNG) and raw depth arrays (NPY) into ``path``.

    For every frame ``i`` two files are produced: ``frame_{i:04d}.png`` (the
    colorized depth, optionally concatenated with a colorized confidence map)
    and ``frame_{i:04d}.npy`` (the raw depth values of that frame).

    Args:
        pts3ds_self: Iterable of point-map tensors; the last channel of each
            is taken as depth. Only read when ``depth_maps`` is None.
        path: Output directory. It is not created here, so it must already
            exist.
        conf_self: Optional sequence of confidence tensors. When given they
            are concatenated along dim 0, log-scaled and colorized with the
            "jet" colormap, then joined to each depth image along dim 1.
        depth_maps: Optional pre-computed depth tensor indexed by frame along
            dim 0. Takes precedence over ``pts3ds_self``.

    Returns:
        The depth tensor that was saved: ``depth_maps`` as passed in, or the
        stack built from ``pts3ds_self``.
    """
    if depth_maps is None:
        depth_maps = torch.stack([pts3d_self[..., -1] for pts3d_self in pts3ds_self], 0)

    # One colour range for the whole sequence (global min/max), so the same
    # colour means the same depth in every frame.
    min_depth = depth_maps.min()
    max_depth = depth_maps.max()
    colored_depth = colorize(
        depth_maps,
        cmap_name="Spectral_r",
        range=(min_depth, max_depth),
        append_cbar=True,
    )

    if conf_self is not None:
        conf_selfs = torch.concat(conf_self, 0)
        # Confidence is visualised in log space, again with a global range.
        min_conf = torch.log(conf_selfs.min())
        max_conf = torch.log(conf_selfs.max())
        colored_conf = colorize(
            torch.log(conf_selfs),
            cmap_name="jet",
            range=(min_conf, max_conf),
            append_cbar=True,
        )

    for i, depth_map in enumerate(colored_depth):
        frame_stem = f"{path}/frame_{i:04d}"
        if conf_self is None:
            frame = depth_map
        else:
            frame = torch.cat([depth_map, colored_conf[i]], dim=1)
        to_save = (frame * 255).detach().cpu().numpy().astype(np.uint8)
        iio.imwrite(f"{frame_stem}.png", to_save)
        # The written PNG is intentionally not re-opened: the previous
        # Image.open() result was never used and left a file handle open
        # for every frame.
        np.save(f"{frame_stem}.npy", depth_maps[i].detach().cpu().numpy())

    return depth_maps
def str2bool(value):
    """Convert a command-line token such as ``true``/``False``/``1``/``0`` to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as True; this parser interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string stays False, matching what bool("") produced before.
    if text in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args_parser():
    """Build the command-line parser for the evaluation script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--device``, ``--output_dir``,
        ``--no_crop``, ``--eval_dataset``, ``--size``, ``--pose_eval_stride``,
        ``--full_seq``, ``--seq_list`` and ``--checkpoint_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="cuda", help="pytorch device")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="",
        help="value for outdir",
    )
    # str2bool instead of type=bool so that "--no_crop False" really is False.
    parser.add_argument(
        "--no_crop", type=str2bool, default=True, help="whether to crop input data"
    )
    parser.add_argument(
        "--eval_dataset",
        type=str,
        default="sintel",
        choices=list(dataset_metadata.keys()),
    )
    parser.add_argument("--size", type=int, default=512)
    parser.add_argument(
        "--pose_eval_stride", default=1, type=int, help="stride for pose evaluation"
    )
    parser.add_argument(
        "--full_seq",
        action="store_true",
        default=False,
        help="use full sequence for pose evaluation",
    )
    parser.add_argument(
        "--seq_list",
        nargs="+",
        default=None,
        help="list of sequences for pose evaluation",
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default="",
        help="path to the checkpoint directory",
    )
    return parser
def eval_pose_estimation(args, model, save_dir=None):
    """Look up the dataset's image and mask roots, then run the evaluation.

    Args:
        args: Parsed CLI namespace; ``args.eval_dataset`` selects the entry
            of ``dataset_metadata`` to use.
        model: Model forwarded unchanged to ``eval_pose_estimation_dist``.
        save_dir: Optional output directory, forwarded unchanged.

    Returns:
        The three values returned by ``eval_pose_estimation_dist``, in order
        (named ATE mean, RPE translation mean and RPE rotation mean here).
    """
    dataset_cfg = dataset_metadata.get(args.eval_dataset)
    image_root = dataset_cfg["img_path"]
    mask_root = dataset_cfg["mask_path"]

    results = eval_pose_estimation_dist(
        args,
        model,
        img_path=image_root,
        save_dir=save_dir,
        mask_path=mask_root,
    )
    ate, rpe_trans, rpe_rot = results
    return ate, rpe_trans, rpe_rot
def eval_pose_estimation_dist(args,
                              model,
                              img_path,
                              save_dir=None,
                              mask_path=None):
    """Run streaming depth inference on each selected sequence and save the depth maps.

    Despite the name, no pose metrics are computed here: each sequence is fed
    frame by frame through a causal ``StreamSession`` and the resulting depth
    is written with ``save_depth_maps`` to ``{save_dir}/{seq}``. Sequences are
    divided between processes with ``PartialState.split_between_processes``.

    Args:
        args: Parsed CLI namespace. Reads ``eval_dataset``, ``seq_list``,
            ``full_seq``, ``pose_eval_stride`` and ``output_dir``.
            ``args.full_seq`` is set to True when the dataset metadata
            requests it.
        model: Model to evaluate; switched to eval mode and moved to the
            process's device.
        img_path: Root directory containing one sub-directory per sequence.
        save_dir: Output root; falls back to ``args.output_dir`` when None.
        mask_path: Mask root, only handed to the metadata's
            ``mask_path_seq_func``.

    Returns:
        Always ``(None, None, None)``.

    Raises:
        Exception: any per-sequence error is re-raised unless its message
            mentions "out of memory", "Degenerate covariance rank" or
            "Eigenvalues did not converge"; those are appended to a
            per-process error log and the sequence is skipped.
    """
    metadata = dataset_metadata.get(args.eval_dataset)
    model.eval()

    seq_list = args.seq_list
    if seq_list is None:
        if metadata.get("full_seq", False):
            args.full_seq = True
        else:
            seq_list = metadata.get("seq_list", [])
        if args.full_seq:
            # Every sub-directory of img_path counts as one sequence.
            seq_list = os.listdir(img_path)
            seq_list = [
                seq for seq in seq_list
                if os.path.isdir(os.path.join(img_path, seq))
            ]
        seq_list = sorted(seq_list)

    if save_dir is None:
        save_dir = args.output_dir
    # The error log is written directly inside save_dir, possibly before any
    # per-sequence directory has been created, so make sure it exists first.
    # An empty save_dir is left alone (os.makedirs("") would raise).
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    distributed_state = PartialState()
    model.to(distributed_state.device)
    device = distributed_state.device

    with distributed_state.split_between_processes(seq_list) as seqs:
        # Unique log file per process
        error_log_path = f"{save_dir}/_error_log_{distributed_state.process_index}.txt"
        for seq in tqdm(seqs):
            try:
                dir_path = metadata["dir_path_func"](img_path, seq)

                # Handle skip_condition
                skip_condition = metadata.get("skip_condition", None)
                if skip_condition is not None and skip_condition(save_dir, seq):
                    continue

                # NOTE(review): the resolved mask path is not used anywhere
                # below; the call is kept so behaviour is unchanged.
                mask_path_seq_func = metadata.get("mask_path_seq_func",
                                                  lambda mask_path, seq: None)
                mask_path_seq = mask_path_seq_func(mask_path, seq)

                filelist = [
                    os.path.join(dir_path, name)
                    for name in os.listdir(dir_path)
                ]
                filelist.sort()
                filelist = filelist[::args.pose_eval_stride]

                # NOTE(review): the resolution is fixed at 518 with patch
                # size 14; args.size is not consulted here.
                images = load_images(
                    filelist,
                    size=518,
                    verbose=True,
                    crop=False,
                    patch_size=14,
                )
                images = collate_with_cat([tuple(images)])
                images = torch.stack([view["img"] for view in images], dim=1)
                images = ImgDust3r2Stream3r(images).to(device)

                with torch.no_grad():
                    session = StreamSession(model, mode="causal")
                    # Feed one frame at a time; only the output of the final
                    # call is kept. Presumably it covers all frames seen so
                    # far — TODO confirm against StreamSession.
                    for frame_idx in range(images.shape[1]):
                        frame = images[:, frame_idx:frame_idx + 1]
                        predictions = session.forward_stream(frame)

                print(
                    f"Finished depth estmation of {len(filelist)} images"
                )
                os.makedirs(f"{save_dir}/{seq}", exist_ok=True)
                # NOTE(review): squeeze() drops every size-1 dim, so a
                # single-frame sequence would also lose its frame axis.
                save_depth_maps(None,
                                f"{save_dir}/{seq}",
                                conf_self=None,
                                depth_maps=predictions['depth'].squeeze().cpu())
            except Exception as e:
                message = str(e)
                if "out of memory" in message:
                    # Handle OOM
                    torch.cuda.empty_cache()  # Clear the CUDA memory
                    with open(error_log_path, "a") as f:
                        f.write(
                            f"OOM error in sequence {seq}, skipping this sequence.\n"
                        )
                    print(f"OOM error in sequence {seq}, skipping...")
                elif ("Degenerate covariance rank" in message
                      or "Eigenvalues did not converge" in message):
                    # Handle Degenerate covariance rank exception and
                    # Eigenvalues did not converge exception
                    with open(error_log_path, "a") as f:
                        f.write(f"Exception in sequence {seq}: {message}\n")
                    print(
                        f"Traj evaluation error in sequence {seq}, skipping.")
                else:
                    # Not an expected failure: propagate with the original
                    # traceback intact.
                    raise

    return None, None, None
def main():
    """Parse the command line, load the pretrained model and run the evaluation."""
    parser = get_args_parser()
    args = parser.parse_args()

    # full_seq is forced on for sintel and off for every other dataset,
    # regardless of the --full_seq flag; no_crop is always forced to True.
    args.full_seq = args.eval_dataset == "sintel"
    args.no_crop = True

    model = STream3R.from_pretrained("yslan/STream3R")
    model = model.to(args.device)
    model.eval()

    eval_pose_estimation(args, model, save_dir=args.output_dir)
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()