# add stream3r (commit 9d31508)
import os
import sys
from copy import deepcopy
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
import time
import torch
import argparse
import numpy as np
import open3d as o3d
import os.path as osp
from torch.utils.data._utils.collate import default_collate
from tqdm import tqdm
from stream3r.models.stream3r import STream3R
from stream3r.dust3r.utils.geometry import geotrf
from stream3r.models.components.utils.geometry import unproject_depth_map_to_point_map
from stream3r.models.components.utils.pose_enc import pose_encoding_to_extri_intri
from stream3r.utils.utils import ImgDust3r2Stream3r
from eval.mv_recon.criterion import Regr3D_t_ScaleShiftInv, L21
from eval.mv_recon.utils import accuracy, completion
from eval.mv_recon.data import SevenScenes, NRGBD
# Allow TF32 matmul kernels (faster on Ampere+ GPUs, slightly lower precision).
torch.backends.cuda.matmul.allow_tf32 = True
# avoid high cpu usage
# Pin every BLAS/OpenMP-style backend (and torch itself) to a single thread.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
torch.set_num_threads(1)
# ===========================================
def get_args_parser():
    """Build the CLI parser for the 3D reconstruction evaluation script.

    Returns:
        argparse.ArgumentParser: parser with weights/device/model/output and
        evaluation-control options (image size, revisit count, freeze flag).
    """
    parser = argparse.ArgumentParser("3D Reconstruction evaluation",
                                     add_help=False)
    # Checkpoint / model selection.
    parser.add_argument("--weights", type=str, default="", help="ckpt name")
    parser.add_argument("--device", type=str, default="cuda:0", help="device")
    parser.add_argument("--model_name", type=str, default="")
    # Evaluation behavior.
    parser.add_argument("--conf_thresh", type=float, default=0.0,
                        help="confidence threshold")
    parser.add_argument("--output_dir", type=str, default="",
                        help="value for outdir")
    parser.add_argument("--size", type=int, default=512)
    parser.add_argument("--revisit", type=int, default=1,
                        help="revisit times")
    parser.add_argument("--freeze", action="store_true")
    return parser
def main(args):
    """Evaluate STream3R multi-view reconstruction on 7-Scenes and NRGBD.

    For each dataset sequence: run the model, unproject predicted depth maps
    with the predicted cameras into world-space point maps, align predictions
    to ground truth (shift + first-camera frame + ICP), then compute
    accuracy / completion / normal-consistency metrics. Per-scene artifacts
    (.npy point dumps, .ply clouds) and log files are written under
    ``args.output_dir``.

    Args:
        args: parsed namespace from ``get_args_parser()``.

    Raises:
        NotImplementedError: if ``args.size`` is not one of 518/512/224.
    """
    # Map requested image size to the dataset working resolution.
    if args.size == 518:
        resolution = (518, 392)
    elif args.size == 512:
        resolution = (512, 384)
    elif args.size == 224:
        resolution = 224
    else:
        raise NotImplementedError
    # Full-video test splits, one sequence per scene, keyframe-subsampled.
    datasets_all = {
        "7scenes":
        SevenScenes(
            split="test",
            ROOT="./data/7scenes",
            resolution=resolution,
            num_seq=1,
            full_video=True,
            kf_every=200,
        ),  # 20),
        "NRGBD":
        NRGBD(
            split="test",
            ROOT="./data/neural_rgbd",
            resolution=resolution,
            num_seq=1,
            full_video=True,
            kf_every=500,
        ),
    }
    # NOTE(review): dead assignment — overwritten two lines below.
    device = 'cuda'
    model_name = args.model_name
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): args.weights is never used; weights always come from the
    # HuggingFace hub checkpoint below — confirm this is intended.
    model = STream3R.from_pretrained("yslan/STream3R").to(device)
    model.eval()
    os.makedirs(args.output_dir, exist_ok=True)
    # Scale/shift-invariant 3D regression criterion; used here only to obtain
    # aligned GT/pred point maps and the alignment monitoring values.
    criterion = Regr3D_t_ScaleShiftInv(L21, norm_mode=False, gt_scale=True)
    with torch.no_grad():
        for name_data, dataset in datasets_all.items():
            save_path = osp.join(args.output_dir, name_data)
            os.makedirs(save_path, exist_ok=True)
            log_file = osp.join(save_path,
                                f"logs_0.txt")
            # Running sums of per-scene metrics (mean and median variants).
            # NOTE(review): accumulated but never reported — aggregation is
            # instead redone from the log files at the end of the loop.
            acc_all = 0
            acc_all_med = 0
            comp_all = 0
            comp_all_med = 0
            nc1_all = 0
            nc1_all_med = 0
            nc2_all = 0
            nc2_all_med = 0
            idxs = list(range(len(dataset)))
            for data_idx in tqdm(idxs):
                batch = default_collate([dataset[data_idx]])
                # Keys left on CPU / untouched when moving views to device.
                ignore_keys = set([
                    "depthmap",
                    "dataset",
                    "label",
                    "instance",
                    "idx",
                    "true_shape",
                    "rng",
                ])
                for view in batch:
                    for name in view.keys():  # pseudo_focal
                        if name in ignore_keys:
                            continue
                        if isinstance(view[name], tuple) or isinstance(
                                view[name], list):
                            view[name] = [
                                x.to(device, non_blocking=True)
                                for x in view[name]
                            ]
                        else:
                            view[name] = view[name].to(device,
                                                       non_blocking=True)
                # NOTE(review): only this branch defines preds/revisit/views;
                # any other model_name will fail below — confirm intended.
                if model_name == "ours" or model_name == "stream3r":
                    revisit = args.revisit
                    update = not args.freeze
                    if revisit > 1:
                        # repeat input for 'revisit' times
                        new_views = []
                        for r in range(revisit):
                            for i in range(len(batch)):
                                new_view = deepcopy(batch[i])
                                # Re-index duplicated views so each copy has
                                # a unique idx / instance id.
                                new_view["idx"] = [
                                    (r * len(batch) + i)
                                    for _ in range(len(batch[i]["idx"]))
                                ]
                                new_view["instance"] = [
                                    str(r * len(batch) + i) for _ in range(
                                        len(batch[i]["instance"]))
                                ]
                                if r > 0:
                                    if not update:
                                        # Disable state updates on revisit
                                        # passes when --freeze is set.
                                        new_view[
                                            "update"] = torch.zeros_like(
                                                batch[i]["update"]).bool()
                                new_views.append(new_view)
                        batch = new_views
                    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                    # CPU copy of the batch for the evaluation criterion.
                    batch_cpu = [
                        {
                            k: v.to('cpu') if isinstance(v, torch.Tensor) else v for k, v in sample.items()
                        } for sample in batch
                    ]
                    # move all stuffs in batch to cuda
                    with torch.autocast('cuda', enabled=False):
                        images = torch.cat([item['img'] for item in batch])
                        # Convert DUSt3R-normalized images to STream3R input.
                        images = ImgDust3r2Stream3r(images).to(device)
                        with torch.no_grad():
                            predictions = model(images)
                            # Decode predicted pose encodings into camera
                            # extrinsics/intrinsics at the image resolution.
                            extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], predictions["images"].shape[-2:])
                        # Unproject predicted depth with predicted cameras
                        # into world-space point maps.
                        world_points_from_depth = unproject_depth_map_to_point_map(
                            predictions["depth"].cpu().numpy().squeeze(0),
                            extrinsic.cpu().numpy().squeeze(0),
                            intrinsic.cpu().numpy().squeeze(0)
                        )
                        world_points_from_depth = torch.from_numpy(world_points_from_depth).unsqueeze(0).to(device=device)
                    preds = world_points_from_depth
                    confs = predictions["depth_conf"]
                    all_preds = []
                    # One dict per view: point map + depth confidence.
                    for idx in range(preds.shape[1]):
                        all_preds.append(
                            {'pts3d': preds[0][idx:idx+1].cpu(), 'conf': confs[0][idx:idx+1]}
                        )
                    # convert preds into list
                    views = batch_cpu
                    preds = all_preds
                # Keep only the final pass when the sequence was revisited.
                valid_length = len(preds) // revisit
                preds = preds[-valid_length:]
                batch = batch[-valid_length:]
                # Evaluation
                print(
                    f"Evaluation for {name_data} {data_idx+1}/{len(dataset)}"
                )
                gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
                    criterion.get_all_pts3d_t(views, preds))
                pred_scale, gt_scale, pred_shift_z, gt_shift_z = (
                    monitoring["pred_scale"],
                    monitoring["gt_scale"],
                    monitoring["pred_shift_z"],
                    monitoring["gt_shift_z"],
                )
                in_camera1 = None
                pts_all = []
                pts_gt_all = []
                images_all = []
                masks_all = []
                conf_all = []
                for j, view in enumerate(batch):
                    # First view's camera defines the common world frame.
                    if in_camera1 is None:
                        in_camera1 = view["camera_pose"][0].cpu()
                    image = view["img"].permute(0, 2, 3,
                                                1).cpu().numpy()[0]
                    mask = view["valid_mask"].cpu().numpy()[0]
                    pts = pred_pts[j].cpu().numpy()[0]
                    conf = preds[j]["conf"].cpu().data.numpy()[0]
                    # mask = mask & (conf > 1.8)
                    pts_gt = gt_pts[j].detach().cpu().numpy()[0]
                    # Center-crop all maps to 224x224 before evaluation.
                    H, W = image.shape[:2]
                    cx = W // 2
                    cy = H // 2
                    l, t = cx - 112, cy - 112
                    r, b = cx + 112, cy + 112
                    image = image[t:b, l:r]
                    mask = mask[t:b, l:r]
                    pts = pts[t:b, l:r]
                    pts_gt = pts_gt[t:b, l:r]
                    #### Align predicted 3D points to the ground truth
                    pts[..., -1] += gt_shift_z.cpu().numpy().item()
                    pts = geotrf(in_camera1, pts)
                    pts_gt[..., -1] += gt_shift_z.cpu().numpy().item()
                    pts_gt = geotrf(in_camera1, pts_gt)
                    # Images come in [-1, 1]; remap to [0, 1] for colors.
                    images_all.append((image[None, ...] + 1.0) / 2.0)
                    pts_all.append(pts[None, ...])
                    pts_gt_all.append(pts_gt[None, ...])
                    masks_all.append(mask[None, ...])
                    conf_all.append(conf[None, ...])
                images_all = np.concatenate(images_all, axis=0)
                pts_all = np.concatenate(pts_all, axis=0)
                pts_gt_all = np.concatenate(pts_gt_all, axis=0)
                masks_all = np.concatenate(masks_all, axis=0)
                # 'view' is the last view from the loop above; its label
                # identifies the scene (path prefix before the last '/').
                scene_id = view["label"][0].rsplit("/", 1)[0]
                save_params = {}
                save_params["images_all"] = images_all
                save_params["pts_all"] = pts_all
                save_params["pts_gt_all"] = pts_gt_all
                save_params["masks_all"] = masks_all
                np.save(
                    os.path.join(save_path,
                                 f"{scene_id.replace('/', '_')}.npy"),
                    save_params,
                )
                # ICP correspondence threshold: DTU uses mm-scale units.
                if "DTU" in name_data:
                    threshold = 100
                else:
                    threshold = 0.1
                pts_all_masked = pts_all[masks_all > 0]
                pts_gt_all_masked = pts_gt_all[masks_all > 0]
                images_all_masked = images_all[masks_all > 0]
                pcd = o3d.geometry.PointCloud()
                pcd.points = o3d.utility.Vector3dVector(
                    pts_all_masked.reshape(-1, 3))
                pcd.colors = o3d.utility.Vector3dVector(
                    images_all_masked.reshape(-1, 3))
                o3d.io.write_point_cloud(
                    os.path.join(save_path,
                                 f"{scene_id.replace('/', '_')}-mask.ply"),
                    pcd,
                )
                pcd_gt = o3d.geometry.PointCloud()
                pcd_gt.points = o3d.utility.Vector3dVector(
                    pts_gt_all_masked.reshape(-1, 3))
                pcd_gt.colors = o3d.utility.Vector3dVector(
                    images_all_masked.reshape(-1, 3))
                o3d.io.write_point_cloud(
                    os.path.join(save_path,
                                 f"{scene_id.replace('/', '_')}-gt.ply"),
                    pcd_gt,
                )
                # Refine the global alignment with point-to-point ICP before
                # computing metrics.
                trans_init = np.eye(4)
                reg_p2p = o3d.pipelines.registration.registration_icp(
                    pcd,
                    pcd_gt,
                    threshold,
                    trans_init,
                    o3d.pipelines.registration.
                    TransformationEstimationPointToPoint(),
                )
                transformation = reg_p2p.transformation
                pcd = pcd.transform(transformation)
                pcd.estimate_normals()
                pcd_gt.estimate_normals()
                gt_normal = np.asarray(pcd_gt.normals)
                pred_normal = np.asarray(pcd.normals)
                acc, acc_med, nc1, nc1_med = accuracy(
                    pcd_gt.points, pcd.points, gt_normal, pred_normal)
                comp, comp_med, nc2, nc2_med = completion(
                    pcd_gt.points, pcd.points, gt_normal, pred_normal)
                print(
                    f"Idx: {scene_id}, Acc: {acc}, Comp: {comp}, NC1: {nc1}, NC2: {nc2} - Acc_med: {acc_med}, Compc_med: {comp_med}, NC1c_med: {nc1_med}, NC2c_med: {nc2_med}"
                )
                # NOTE(review): open() inside print leaks the file handle
                # (never explicitly closed) — consider a with-block.
                print(
                    f"Idx: {scene_id}, Acc: {acc}, Comp: {comp}, NC1: {nc1}, NC2: {nc2} - Acc_med: {acc_med}, Compc_med: {comp_med}, NC1c_med: {nc1_med}, NC2c_med: {nc2_med}",
                    file=open(log_file, "a"),
                )
                acc_all += acc
                comp_all += comp
                nc1_all += nc1
                nc2_all += nc2
                acc_all_med += acc_med
                comp_all_med += comp_med
                nc1_all_med += nc1_med
                nc2_all_med += nc2_med
                # release cuda memory
                torch.cuda.empty_cache()
            to_write = ""
            # Copy the error log from each process to the main error log
            for i in range(8):
                if not os.path.exists(osp.join(save_path,
                                               f"logs_{i}.txt")):
                    break
                with open(osp.join(save_path, f"logs_{i}.txt"),
                          "r") as f_sub:
                    to_write += f_sub.read()
            with open(osp.join(save_path, f"logs_all.txt"), "w") as f:
                log_data = to_write
                # Parse each per-scene log line back into numeric metrics
                # using the module-level `regex` / `defaultdict`.
                metrics = defaultdict(list)
                for line in log_data.strip().split("\n"):
                    match = regex.match(line)
                    if match:
                        data = match.groupdict()
                        # Exclude 'scene_id' from metrics as it's an identifier
                        for key, value in data.items():
                            if key != "scene_id":
                                metrics[key].append(float(value))
                        metrics["nc"].append(
                            (float(data["nc1"]) + float(data["nc2"])) / 2)
                        metrics["nc_med"].append(
                            (float(data["nc1_med"]) +
                             float(data["nc2_med"])) / 2)
                mean_metrics = {
                    metric: sum(values) / len(values)
                    for metric, values in metrics.items()
                }
                c_name = "mean"
                print_str = f"{c_name.ljust(20)}: "
                for m_name in mean_metrics:
                    print_num = np.mean(mean_metrics[m_name])
                    print_str = print_str + f"{m_name}: {print_num:.3f} | "
                print_str = print_str + "\n"
                # Aggregated log = all per-scene lines + one mean summary row.
                f.write(to_write + print_str)
# NOTE(review): these module-level definitions sit below main(), which is
# unconventional, but they execute at import time — before the
# `if __name__ == "__main__"` block calls main() — so `defaultdict` and
# `regex` are already in scope when main() uses them.
from collections import defaultdict
import re
# Verbose regex matching the per-scene metric lines written to the log files
# (whitespace/newlines in the pattern are ignored under re.VERBOSE).
pattern = r"""
Idx:\s*(?P<scene_id>[^,]+),\s*
Acc:\s*(?P<acc>[^,]+),\s*
Comp:\s*(?P<comp>[^,]+),\s*
NC1:\s*(?P<nc1>[^,]+),\s*
NC2:\s*(?P<nc2>[^,]+)\s*-\s*
Acc_med:\s*(?P<acc_med>[^,]+),\s*
Compc_med:\s*(?P<comp_med>[^,]+),\s*
NC1c_med:\s*(?P<nc1_med>[^,]+),\s*
NC2c_med:\s*(?P<nc2_med>[^,]+)
"""
regex = re.compile(pattern, re.VERBOSE)
if __name__ == "__main__":
    # Parse CLI arguments and launch the evaluation.
    args = get_args_parser().parse_args()
    main(args)