"""Monocular depth evaluation script for STream3R."""
# Thread-limit env vars must be set BEFORE importing numpy/torch: OMP/MKL
# read them at import time, so setting them afterwards has no effect.
import os

os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"

import argparse
import sys
from pathlib import Path

import cv2
import matplotlib
import numpy as np
import torch
from tqdm import tqdm

from stream3r.models.stream3r import STream3R
from stream3r.dust3r.utils.device import collate_with_cat
from stream3r.dust3r.utils.image import load_images_for_eval as load_images
from stream3r.utils.utils import ImgDust3r2Stream3r

# Make the repo root importable so `eval.monodepth.metadata` resolves.
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from eval.monodepth.metadata import dataset_metadata

# Allow TF32 matmuls on Ampere+ GPUs for faster inference.
torch.backends.cuda.matmul.allow_tf32 = True

# Keep torch's intra-op parallelism down as well (avoid high CPU usage).
torch.set_num_threads(1)
# ===========================================
def colorize_depth(depth: np.ndarray,
                   mask: np.ndarray = None,
                   normalize: bool = True,
                   cmap: str = 'Spectral') -> np.ndarray:
    """Colorize a depth map via its disparity (1/depth) with a matplotlib colormap.

    Args:
        depth: 2-D array of depth values; non-positive entries are invalid.
        mask: optional boolean array; pixels where it is False are invalid.
        normalize: if True, rescale disparity to [0, 1] using the 0.1% / 99%
            quantiles of the valid pixels (robust to outliers).
        cmap: matplotlib colormap name.

    Returns:
        uint8 RGB image of shape (*depth.shape, 3); invalid pixels map to 0.
    """
    valid = depth > 0 if mask is None else (depth > 0) & mask
    depth = np.where(valid, depth, np.nan)
    disp = 1 / depth
    if normalize:
        min_disp = np.nanquantile(disp, 0.001)
        max_disp = np.nanquantile(disp, 0.99)
        # Guard against a constant disparity map: the original divided by
        # (max - min) unconditionally, yielding NaN/inf for flat inputs.
        scale = max_disp - min_disp
        disp = (disp - min_disp) / scale if scale > 0 else np.zeros_like(disp)
    # Use the `nan=` keyword: the original passed 0 positionally, which bound
    # it to `copy=` instead (same output only because `nan` defaults to 0.0).
    colored = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disp)[..., :3], nan=0)
    colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8))
    return colored
def get_args_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--device",
type=str,
default="cuda",
help="pytorch device")
parser.add_argument("--output_dir",
type=str,
default="",
help="value for outdir")
parser.add_argument("--no_crop",
type=bool,
default=True,
help="whether to crop input data")
parser.add_argument("--full_seq",
type=bool,
default=False,
help="whether to use all seqs")
parser.add_argument("--seq_list", default=None)
parser.add_argument("--eval_dataset",
type=str,
default="nyu",
choices=list(dataset_metadata.keys()))
return parser
def eval_mono_depth_estimation(args, model, device):
    """Run monodepth evaluation for the dataset selected in `args`.

    Looks up the dataset's metadata, resolves its image path (which may be a
    per-dataset callable), then evaluates each (filelist, save_dir) pair
    produced by the dataset's processing function.
    """
    ds_meta = dataset_metadata.get(args.eval_dataset)
    if ds_meta is None:
        raise ValueError(f"Unknown dataset: {args.eval_dataset}")

    # A dataset may override its static image path with a callable on `args`.
    if "img_path_func" in ds_meta:
        img_path = ds_meta["img_path_func"](args)
    else:
        img_path = ds_meta.get("img_path")

    processor = ds_meta.get("process_func")
    if processor is None:
        raise ValueError(
            f"No processing function defined for dataset: {args.eval_dataset}")

    for filelist, save_dir in processor(args, img_path):
        Path(save_dir).mkdir(parents=True, exist_ok=True)
        eval_mono_depth(args, model, device, filelist, save_dir=save_dir)
def eval_mono_depth(args, model, device, filelist, save_dir=None):
    """Predict a depth map for each image and optionally save outputs.

    For every image path in `filelist`, runs the model and, when `save_dir`
    is given, writes the raw depth as .npy plus a colorized .jpg preview.

    Args:
        args: parsed CLI arguments (kept for signature compatibility).
        model: STream3R model (expected to be in eval mode).
        device: torch device for inference.
        filelist: iterable of image file paths.
        save_dir: output directory; nothing is written when None.
    """
    for file in tqdm(filelist):
        file = [file]
        # Load one image at the model's expected resolution / patch size.
        images = load_images(
            file,
            size=518,
            verbose=True,
            crop=False,
            patch_size=14,
        )
        images = collate_with_cat([tuple(images)])
        images = torch.stack([view["img"] for view in images], dim=1)
        images = ImgDust3r2Stream3r(images).to(device)
        with torch.no_grad():
            predictions = model(images)
        # Take the first batch/view and drop the trailing channel dim.
        # Assumes predictions['depth'] is (B, V, H, W, 1) — TODO confirm upstream.
        depth_map = predictions['depth'][0, 0].squeeze(-1).cpu()
        if save_dir is not None:
            name = file[0].split('/')[-1]
            # NOTE: preserves the original naming, e.g. "img.png" -> "imgdepth.npy".
            np.save(
                f"{save_dir}/{name.replace('.png', 'depth.npy')}",
                depth_map.numpy(),  # already on CPU; the second .cpu() was redundant
            )
            depth_map = colorize_depth(depth_map)
            # colorize_depth returns RGB (matplotlib colormap), but cv2.imwrite
            # expects BGR — convert to avoid swapped red/blue in the preview.
            cv2.imwrite(
                f"{save_dir}/{name.replace('.png', 'depth.jpg')}",
                cv2.cvtColor(depth_map, cv2.COLOR_RGB2BGR),
            )
def main():
    """Entry point: parse CLI args, load the model, run evaluation."""
    parser = get_args_parser()
    args = parser.parse_args()
    # Sintel is the only dataset evaluated over full sequences.
    args.full_seq = args.eval_dataset == "sintel"
    model = STream3R.from_pretrained("yslan/STream3R").to(args.device)
    model.eval()
    eval_mono_depth_estimation(args, model, args.device)


if __name__ == "__main__":
    main()