Spaces:
Configuration error
Configuration error
| import os | |
| import sys | |
| from copy import deepcopy | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) | |
| import time | |
| import torch | |
| import argparse | |
| import numpy as np | |
| import open3d as o3d | |
| import os.path as osp | |
| from torch.utils.data._utils.collate import default_collate | |
| from tqdm import tqdm | |
| from stream3r.models.stream3r import STream3R | |
| from stream3r.dust3r.utils.geometry import geotrf | |
| from stream3r.models.components.utils.geometry import unproject_depth_map_to_point_map | |
| from stream3r.models.components.utils.pose_enc import pose_encoding_to_extri_intri | |
| from stream3r.utils.utils import ImgDust3r2Stream3r | |
| from eval.mv_recon.criterion import Regr3D_t_ScaleShiftInv, L21 | |
| from eval.mv_recon.utils import accuracy, completion | |
| from eval.mv_recon.data import SevenScenes, NRGBD | |
# Allow TF32 in CUDA matrix multiplications.
torch.backends.cuda.matmul.allow_tf32 = True

# avoid high cpu usage: pin every numeric backend to a single thread.
# NOTE(review): these variables are set after torch/numpy were imported;
# backends that read them only at load time may ignore them -- confirm.
os.environ.update(
    dict.fromkeys(
        (
            "OMP_NUM_THREADS",
            "MKL_NUM_THREADS",
            "NUMEXPR_NUM_THREADS",
            "OPENBLAS_NUM_THREADS",
        ),
        "1",
    )
)
torch.set_num_threads(1)
| # =========================================== | |
def get_args_parser():
    """Build the command-line parser for the 3D reconstruction evaluation.

    The parser is created with ``add_help=False``, so it registers no
    ``-h/--help`` option of its own.

    Returns:
        argparse.ArgumentParser: parser exposing ``--weights``, ``--device``,
        ``--model_name``, ``--conf_thresh``, ``--output_dir``, ``--size``,
        ``--revisit`` and ``--freeze``.
    """
    # (flag, add_argument keyword arguments), in registration order.
    option_specs = (
        ("--weights", dict(type=str, default="", help="ckpt name")),
        ("--device", dict(type=str, default="cuda:0", help="device")),
        ("--model_name", dict(type=str, default="")),
        ("--conf_thresh",
         dict(type=float, default=0.0, help="confidence threshold")),
        ("--output_dir", dict(type=str, default="", help="value for outdir")),
        ("--size", dict(type=int, default=512)),
        ("--revisit", dict(type=int, default=1, help="revisit times")),
        ("--freeze", dict(action="store_true")),
    )

    cli = argparse.ArgumentParser("3D Reconstruction evaluation",
                                  add_help=False)
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli
def _resolve_resolution(size):
    """Translate the ``--size`` value into the datasets' ``resolution`` argument."""
    if size == 518:
        return (518, 392)
    if size == 512:
        return (512, 384)
    if size == 224:
        return 224
    raise NotImplementedError(
        f"unsupported size {size!r}; expected 518, 512 or 224")


def _repeat_views(batch, revisit, update):
    """Return ``batch`` repeated ``revisit`` times with renumbered idx/instance.

    For every repetition after the first, and only when ``update`` is false,
    the copied view's ``"update"`` entry is replaced by an all-False mask.
    """
    num_views = len(batch)
    repeated = []
    for rep in range(revisit):
        for i, view in enumerate(batch):
            new_view = deepcopy(view)
            flat_idx = rep * num_views + i
            new_view["idx"] = [flat_idx for _ in range(len(view["idx"]))]
            new_view["instance"] = [
                str(flat_idx) for _ in range(len(view["instance"]))
            ]
            if rep > 0 and not update:
                new_view["update"] = torch.zeros_like(view["update"]).bool()
            repeated.append(new_view)
    return repeated


def _write_mean_metrics(save_path):
    """Merge ``logs_<i>.txt`` files and append a mean line in ``logs_all.txt``.

    Reads ``logs_0.txt`` .. ``logs_7.txt`` from ``save_path`` (stopping at the
    first missing file), parses each line with the module-level ``regex`` and
    writes the concatenated logs followed by the per-metric means.
    """
    to_write = ""
    # Copy the log from each process to the main log
    for i in range(8):
        sub_log = osp.join(save_path, f"logs_{i}.txt")
        if not os.path.exists(sub_log):
            break
        with open(sub_log, "r") as f_sub:
            to_write += f_sub.read()

    metrics = defaultdict(list)
    for line in to_write.strip().split("\n"):
        match = regex.match(line)
        if not match:
            continue
        data = match.groupdict()
        # Exclude 'scene_id' from metrics as it's an identifier
        for key, value in data.items():
            if key != "scene_id":
                metrics[key].append(float(value))
        metrics["nc"].append((float(data["nc1"]) + float(data["nc2"])) / 2)
        metrics["nc_med"].append(
            (float(data["nc1_med"]) + float(data["nc2_med"])) / 2)

    mean_metrics = {
        metric: sum(values) / len(values)
        for metric, values in metrics.items()
    }
    print_str = f"{'mean'.ljust(20)}: " + "".join(
        f"{m_name}: {m_value:.3f} | "
        for m_name, m_value in mean_metrics.items()) + "\n"

    # Written only after parsing succeeded, so a malformed log line cannot
    # leave behind a truncated logs_all.txt.
    with open(osp.join(save_path, "logs_all.txt"), "w") as f:
        f.write(to_write + print_str)


def main(args):
    """Run the multi-view reconstruction evaluation on 7scenes and NRGBD.

    For every sequence of every dataset the function runs the STream3R model,
    unprojects the predicted depth into world points, aligns prediction and
    ground truth, writes ``<scene>.npy``, ``<scene>-mask.ply`` and
    ``<scene>-gt.ply`` under ``<output_dir>/<dataset>/``, refines the alignment
    with ICP, and appends accuracy / completion / normal-consistency numbers
    to ``logs_0.txt``.  After each dataset, ``logs_all.txt`` is written with
    the per-scene lines and their means.

    Args:
        args: namespace produced by ``get_args_parser()``.  Reads ``size``,
            ``model_name``, ``output_dir``, ``revisit``, ``freeze`` and
            ``device``.

    Raises:
        NotImplementedError: if ``args.size`` is not 518, 512 or 224.
    """
    resolution = _resolve_resolution(args.size)

    datasets_all = {
        "7scenes":
        SevenScenes(
            split="test",
            ROOT="./data/7scenes",
            resolution=resolution,
            num_seq=1,
            full_video=True,
            kf_every=200,
        ),  # 20),
        "NRGBD":
        NRGBD(
            split="test",
            ROOT="./data/neural_rgbd",
            resolution=resolution,
            num_seq=1,
            full_video=True,
            kf_every=500,
        ),
    }

    model_name = args.model_name

    # Honour --device (previously parsed but ignored); fall back to the CPU
    # when a CUDA device is requested on a machine without CUDA.
    requested_device = getattr(args, "device", None) or "cuda"
    if str(requested_device).startswith("cuda") and not torch.cuda.is_available():
        requested_device = "cpu"
    device = torch.device(requested_device)

    # NOTE(review): the checkpoint id is hard-coded; --weights is not used.
    model = STream3R.from_pretrained("yslan/STream3R").to(device)
    model.eval()

    # os.makedirs("") raises, and "" is the default --output_dir.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    criterion = Regr3D_t_ScaleShiftInv(L21, norm_mode=False, gt_scale=True)

    # View entries that are left where they are instead of moved to `device`.
    ignore_keys = {
        "depthmap",
        "dataset",
        "label",
        "instance",
        "idx",
        "true_shape",
        "rng",
    }

    with torch.no_grad():
        for name_data, dataset in datasets_all.items():
            save_path = osp.join(args.output_dir, name_data)
            os.makedirs(save_path, exist_ok=True)
            log_file = osp.join(save_path, "logs_0.txt")
            # ICP correspondence distance passed to registration_icp.
            threshold = 100 if "DTU" in name_data else 0.1

            for data_idx in tqdm(list(range(len(dataset)))):
                batch = default_collate([dataset[data_idx]])
                for view in batch:
                    for name in view.keys():  # pseudo_focal
                        if name in ignore_keys:
                            continue
                        value = view[name]
                        if isinstance(value, (tuple, list)):
                            view[name] = [
                                x.to(device, non_blocking=True) for x in value
                            ]
                        else:
                            view[name] = value.to(device, non_blocking=True)

                # `revisit` must exist for every model_name: it is used below
                # to trim the predictions.  It used to be bound only inside
                # the branch, and 0 would divide by zero.
                revisit = 1
                if model_name == "ours" or model_name == "stream3r":
                    revisit = max(1, args.revisit)
                    update = not args.freeze
                    if revisit > 1:
                        # repeat input for 'revisit' times
                        batch = _repeat_views(batch, revisit, update)

                # CPU copies of the views, handed to the criterion.
                batch_cpu = [{
                    k: v.to("cpu") if isinstance(v, torch.Tensor) else v
                    for k, v in sample.items()
                } for sample in batch]

                with torch.autocast("cuda", enabled=False):
                    images = torch.cat([item["img"] for item in batch])
                    images = ImgDust3r2Stream3r(images).to(device)
                    predictions = model(images)

                extrinsic, intrinsic = pose_encoding_to_extri_intri(
                    predictions["pose_enc"],
                    predictions["images"].shape[-2:])
                world_points_from_depth = unproject_depth_map_to_point_map(
                    predictions["depth"].cpu().numpy().squeeze(0),
                    extrinsic.cpu().numpy().squeeze(0),
                    intrinsic.cpu().numpy().squeeze(0),
                )
                world_points_from_depth = torch.from_numpy(
                    world_points_from_depth).unsqueeze(0).to(device=device)
                confs = predictions["depth_conf"]

                # convert preds into a list with one entry per view
                preds = [{
                    "pts3d": world_points_from_depth[0][idx:idx + 1].cpu(),
                    "conf": confs[0][idx:idx + 1],
                } for idx in range(world_points_from_depth.shape[1])]

                # Keep only the last pass over the sequence.  `views` is
                # trimmed together with `preds` and `batch` so the three
                # lists stay index-aligned when revisit > 1.
                valid_length = len(preds) // revisit
                preds = preds[-valid_length:]
                batch = batch[-valid_length:]
                views = batch_cpu[-valid_length:]

                # Evaluation
                print(
                    f"Evaluation for {name_data} {data_idx+1}/{len(dataset)}"
                )
                gt_pts, pred_pts, _, _, _, monitoring = (
                    criterion.get_all_pts3d_t(views, preds))
                shift_z = monitoring["gt_shift_z"].cpu().numpy().item()

                in_camera1 = None
                pts_all = []
                pts_gt_all = []
                images_all = []
                masks_all = []
                for j, view in enumerate(batch):
                    if in_camera1 is None:
                        in_camera1 = view["camera_pose"][0].cpu()

                    image = view["img"].permute(0, 2, 3, 1).cpu().numpy()[0]
                    mask = view["valid_mask"].cpu().numpy()[0]
                    pts = pred_pts[j].cpu().numpy()[0]
                    pts_gt = gt_pts[j].detach().cpu().numpy()[0]

                    # Central 224x224 crop around the image centre.
                    height, width = image.shape[:2]
                    cx = width // 2
                    cy = height // 2
                    left, top = cx - 112, cy - 112
                    right, bottom = cx + 112, cy + 112
                    image = image[top:bottom, left:right]
                    mask = mask[top:bottom, left:right]
                    pts = pts[top:bottom, left:right]
                    pts_gt = pts_gt[top:bottom, left:right]

                    #### Align predicted 3D points to the ground truth
                    pts[..., -1] += shift_z
                    pts = geotrf(in_camera1, pts)
                    pts_gt[..., -1] += shift_z
                    pts_gt = geotrf(in_camera1, pts_gt)

                    # (image + 1) / 2 rescales the stored image values.
                    images_all.append((image[None, ...] + 1.0) / 2.0)
                    pts_all.append(pts[None, ...])
                    pts_gt_all.append(pts_gt[None, ...])
                    masks_all.append(mask[None, ...])

                images_all = np.concatenate(images_all, axis=0)
                pts_all = np.concatenate(pts_all, axis=0)
                pts_gt_all = np.concatenate(pts_gt_all, axis=0)
                masks_all = np.concatenate(masks_all, axis=0)

                scene_id = batch[-1]["label"][0].rsplit("/", 1)[0]
                scene_tag = scene_id.replace("/", "_")

                save_params = {
                    "images_all": images_all,
                    "pts_all": pts_all,
                    "pts_gt_all": pts_gt_all,
                    "masks_all": masks_all,
                }
                np.save(os.path.join(save_path, f"{scene_tag}.npy"),
                        save_params)

                valid = masks_all > 0
                pts_all_masked = pts_all[valid]
                pts_gt_all_masked = pts_gt_all[valid]
                images_all_masked = images_all[valid]

                pcd = o3d.geometry.PointCloud()
                pcd.points = o3d.utility.Vector3dVector(
                    pts_all_masked.reshape(-1, 3))
                pcd.colors = o3d.utility.Vector3dVector(
                    images_all_masked.reshape(-1, 3))
                o3d.io.write_point_cloud(
                    os.path.join(save_path, f"{scene_tag}-mask.ply"), pcd)

                pcd_gt = o3d.geometry.PointCloud()
                pcd_gt.points = o3d.utility.Vector3dVector(
                    pts_gt_all_masked.reshape(-1, 3))
                pcd_gt.colors = o3d.utility.Vector3dVector(
                    images_all_masked.reshape(-1, 3))
                o3d.io.write_point_cloud(
                    os.path.join(save_path, f"{scene_tag}-gt.ply"), pcd_gt)

                # Refine the prediction-to-GT alignment with point-to-point
                # ICP, starting from the identity transform.
                reg_p2p = o3d.pipelines.registration.registration_icp(
                    pcd,
                    pcd_gt,
                    threshold,
                    np.eye(4),
                    o3d.pipelines.registration.
                    TransformationEstimationPointToPoint(),
                )
                pcd = pcd.transform(reg_p2p.transformation)

                pcd.estimate_normals()
                pcd_gt.estimate_normals()
                gt_normal = np.asarray(pcd_gt.normals)
                pred_normal = np.asarray(pcd.normals)

                acc, acc_med, nc1, nc1_med = accuracy(
                    pcd_gt.points, pcd.points, gt_normal, pred_normal)
                comp, comp_med, nc2, nc2_med = completion(
                    pcd_gt.points, pcd.points, gt_normal, pred_normal)

                # This exact format is parsed back by the module-level regex.
                result_line = (
                    f"Idx: {scene_id}, Acc: {acc}, Comp: {comp}, NC1: {nc1}, NC2: {nc2} - Acc_med: {acc_med}, Compc_med: {comp_med}, NC1c_med: {nc1_med}, NC2c_med: {nc2_med}"
                )
                print(result_line)
                # Context manager closes the handle (it used to be leaked).
                with open(log_file, "a") as log_f:
                    print(result_line, file=log_f)

                # release cuda memory
                torch.cuda.empty_cache()

            _write_mean_metrics(save_path)
| from collections import defaultdict | |
| import re | |
# Verbose regular expression for one per-scene result line as written to the
# logs by main():
#   "Idx: <scene>, Acc: .., Comp: .., NC1: .., NC2: .. - Acc_med: ..,
#    Compc_med: .., NC1c_med: .., NC2c_med: .."
# Each value is captured as "everything up to the next comma" in a named
# group; main() converts every group except `scene_id` with float().
# Because the pattern is compiled with re.VERBOSE, the line breaks and other
# whitespace inside the literal are not part of the match.
pattern = r"""
Idx:\s*(?P<scene_id>[^,]+),\s*
Acc:\s*(?P<acc>[^,]+),\s*
Comp:\s*(?P<comp>[^,]+),\s*
NC1:\s*(?P<nc1>[^,]+),\s*
NC2:\s*(?P<nc2>[^,]+)\s*-\s*
Acc_med:\s*(?P<acc_med>[^,]+),\s*
Compc_med:\s*(?P<comp_med>[^,]+),\s*
NC1c_med:\s*(?P<nc1_med>[^,]+),\s*
NC2c_med:\s*(?P<nc2_med>[^,]+)
"""
# Compiled once at import time; main() calls regex.match() on each log line.
regex = re.compile(pattern, re.VERBOSE)
if __name__ == "__main__":
    # Script entry point: parse the command line and run the evaluation.
    main(get_args_parser().parse_args())