|
|
|
|
|
""" |
|
|
3D Point Cloud Inference and Visualization Script |
|
|
|
|
|
This script performs inference using the ARCroco3DStereo model and visualizes the |
|
|
resulting 3D point clouds with the PointCloudViewer. Use the command-line arguments |
|
|
to adjust parameters such as the model checkpoint path, image sequence directory, |
|
|
image size, device, etc. |
|
|
|
|
|
Usage: |
|
|
python demo.py [--model_path MODEL_PATH] [--seq_path SEQ_PATH] [--size IMG_SIZE] |
|
|
[--device DEVICE] [--vis_threshold VIS_THRESHOLD] [--output_dir OUT_DIR] |
|
|
|
|
|
Example: |
|
|
python demo.py --model_path src/cut3r_512_dpt_4_64.pth \ |
|
|
--seq_path examples/001 --device cuda --size 512 |
|
|
""" |
|
|
|
|
|
import os |
|
|
import numpy as np |
|
|
import torch |
|
|
import time |
|
|
import glob |
|
|
import random |
|
|
import cv2 |
|
|
import argparse |
|
|
import tempfile |
|
|
import shutil |
|
|
from copy import deepcopy |
|
|
from add_ckpt_path import add_path_to_dust3r |
|
|
import imageio.v2 as iio |
|
|
|
|
|
|
|
|
# Fix the global RNG seed so any random sampling is reproducible across runs.
random.seed(42)
|
|
|
|
|
|
|
|
def parse_args():
    """Parse command-line arguments for the demo.

    Returns:
        argparse.Namespace: Parsed arguments with fields ``model_path``,
        ``seq_path``, ``device``, ``size``, ``vis_threshold``, and
        ``output_dir``.
    """
    parser = argparse.ArgumentParser(
        description="Run 3D point cloud inference and visualization using ARCroco3DStereo."
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="src/cut3r_512_dpt_4_64.pth",
        help="Path to the pretrained model checkpoint.",
    )
    parser.add_argument(
        "--seq_path",
        type=str,
        default="",
        help="Path to the directory containing the image sequence.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to run inference on (e.g., 'cuda' or 'cpu').",
    )
    parser.add_argument(
        "--size",
        type=int,
        # Use a real int default. The previous string default ("512") only
        # worked because argparse re-parses string defaults through `type`.
        default=512,
        help="Shape that input images will be rescaled to; if using 224+linear model, choose 224 otherwise 512",
    )
    parser.add_argument(
        "--vis_threshold",
        type=float,
        default=1.5,
        help="Visualization threshold for the point cloud viewer. Ranging from 1 to INF",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./demo_tmp",
        help="value for tempfile.tempdir",
    )

    return parser.parse_args()
|
|
|
|
|
|
|
|
def prepare_input(
    img_paths, img_mask, size, raymaps=None, raymap_mask=None, revisit=1, update=True
):
    """
    Build the list of view dictionaries consumed by the model from image
    paths and, optionally, ray maps.

    Args:
        img_paths (list): Image file paths.
        img_mask (list of bool): Flags marking which views are real images.
        size (int): Target image size for rescaling.
        raymaps (list, optional): Ray-map tensors for non-image views.
        raymap_mask (list, optional): Flags marking which views are ray maps.
        revisit (int): Number of passes over the whole sequence.
        update (bool): If False, passes after the first do not update state.

    Returns:
        list: A list of view dictionaries.
    """

    from src.dust3r.utils.image import load_images

    images = load_images(img_paths, size=size)

    def _identity_pose():
        # 1x4x4 identity camera-to-world pose placeholder.
        return torch.from_numpy(np.eye(4, dtype=np.float32)).unsqueeze(0)

    def _flag(value):
        # Scalar bool flag with a leading batch dimension.
        return torch.tensor(value).unsqueeze(0)

    views = []

    if raymaps is None and raymap_mask is None:
        # Image-only sequence: every view carries a real image plus a
        # NaN-filled placeholder ray map of matching spatial size.
        for idx, image in enumerate(images):
            img = image["img"]
            nan_raymap = torch.full(
                (img.shape[0], 6, img.shape[-2], img.shape[-1]), torch.nan
            )
            views.append(
                {
                    "img": img,
                    "ray_map": nan_raymap,
                    "true_shape": torch.from_numpy(image["true_shape"]),
                    "idx": idx,
                    "instance": str(idx),
                    "camera_pose": _identity_pose(),
                    "img_mask": _flag(True),
                    "ray_mask": _flag(False),
                    "update": _flag(True),
                    "reset": _flag(False),
                }
            )
    else:
        # Mixed sequence: interleave images and ray maps according to the
        # two masks, consuming each source list in order.
        num_views = len(images) + len(raymaps)
        assert len(img_mask) == len(raymap_mask) == num_views
        assert sum(img_mask) == len(images) and sum(raymap_mask) == len(raymaps)

        img_idx = 0
        ray_idx = 0
        for idx in range(num_views):
            if img_mask[idx]:
                img = images[img_idx]["img"]
                true_shape = torch.from_numpy(images[img_idx]["true_shape"])
            else:
                img = torch.full_like(images[0]["img"], torch.nan)
                # Derive (H, W) from the current ray map for image-less views.
                true_shape = torch.from_numpy(
                    np.int32([raymaps[ray_idx].shape[1:-1][::-1]])
                )
            if raymap_mask[idx]:
                ray_map = raymaps[ray_idx]
            else:
                ray_map = torch.full_like(raymaps[0], torch.nan)
            views.append(
                {
                    "img": img,
                    "ray_map": ray_map,
                    "true_shape": true_shape,
                    "idx": idx,
                    "instance": str(idx),
                    "camera_pose": _identity_pose(),
                    "img_mask": _flag(img_mask[idx]),
                    "ray_mask": _flag(raymap_mask[idx]),
                    "update": _flag(img_mask[idx]),
                    "reset": _flag(False),
                }
            )
            img_idx += int(img_mask[idx])
            ray_idx += int(raymap_mask[idx])
        assert img_idx == len(images) and ray_idx == len(raymaps)

    if revisit <= 1:
        return views

    # Repeat the whole sequence `revisit` times, renumbering each copy and
    # optionally freezing state updates on the repeat passes.
    repeated = []
    for pass_idx in range(revisit):
        for offset, view in enumerate(views):
            clone = deepcopy(view)
            new_idx = pass_idx * len(views) + offset
            clone["idx"] = new_idx
            clone["instance"] = str(new_idx)
            if pass_idx > 0 and not update:
                clone["update"] = torch.tensor(False).unsqueeze(0)
            repeated.append(clone)
    return repeated
|
|
|
|
|
|
|
|
def prepare_output(outputs, outdir, revisit=1, use_pose=True):
    """
    Process inference outputs to generate point clouds and camera parameters for visualization.

    Also writes per-frame depth maps, confidence maps, color images, and
    camera parameters (pose + intrinsics) under ``outdir``. Note: mutates
    ``outputs`` in place by trimming its "pred" and "views" lists.

    Args:
        outputs (dict): Inference outputs.
        outdir (str): Directory where per-frame results are saved.
        revisit (int): Number of revisits per view.
        use_pose (bool): Whether to transform points using camera pose.

    Returns:
        tuple: (points, colors, confidence, camera parameters dictionary)
    """
    from src.dust3r.utils.camera import pose_encoding_to_camera
    from src.dust3r.post_process import estimate_focal_knowing_depth
    from src.dust3r.utils.geometry import geotrf

    # If the sequence was revisited, keep only the final pass of predictions.
    valid_length = len(outputs["pred"]) // revisit
    outputs["pred"] = outputs["pred"][-valid_length:]
    outputs["views"] = outputs["views"][-valid_length:]

    # Per-frame point maps and confidences, in the self view and other view.
    pts3ds_self_ls = [output["pts3d_in_self_view"].cpu() for output in outputs["pred"]]
    pts3ds_other = [output["pts3d_in_other_view"].cpu() for output in outputs["pred"]]
    conf_self = [output["conf_self"].cpu() for output in outputs["pred"]]
    conf_other = [output["conf"].cpu() for output in outputs["pred"]]
    pts3ds_self = torch.cat(pts3ds_self_ls, 0)

    # Decode predicted camera poses; R/t are stacked camera-to-world parts.
    pr_poses = [
        pose_encoding_to_camera(pred["camera_pose"].clone()).cpu()
        for pred in outputs["pred"]
    ]
    R_c2w = torch.cat([pr_pose[:, :3, :3] for pr_pose in pr_poses], 0)
    t_c2w = torch.cat([pr_pose[:, :3, 3] for pr_pose in pr_poses], 0)

    if use_pose:
        # Re-express each frame's self-view points through its predicted
        # pose, and pair them with the self-view confidences.
        transformed_pts3ds_other = []
        for pose, pself in zip(pr_poses, pts3ds_self):
            transformed_pts3ds_other.append(geotrf(pose, pself.unsqueeze(0)))
        pts3ds_other = transformed_pts3ds_other
        conf_other = conf_self

    # Estimate per-frame focal lengths, assuming the principal point sits at
    # the image center. pts3ds_self is (B, H, W, 3) per the unpack below.
    B, H, W, _ = pts3ds_self.shape
    pp = torch.tensor([W // 2, H // 2], device=pts3ds_self.device).float().repeat(B, 1)
    focal = estimate_focal_knowing_depth(pts3ds_self, pp, focal_mode="weiszfeld")

    # Recover RGB colors from the normalized input images ([-1, 1] -> [0, 1]).
    colors = [
        0.5 * (output["img"].permute(0, 2, 3, 1) + 1.0) for output in outputs["views"]
    ]

    cam_dict = {
        "focal": focal.cpu().numpy(),
        "pp": pp.cpu().numpy(),
        "R": R_c2w.cpu().numpy(),
        "t": t_c2w.cpu().numpy(),
    }

    # Gather per-frame tensors for saving to disk.
    pts3ds_self_tosave = pts3ds_self
    depths_tosave = pts3ds_self_tosave[..., 2]  # z-component used as depth
    pts3ds_other_tosave = torch.cat(pts3ds_other)
    conf_self_tosave = torch.cat(conf_self)
    conf_other_tosave = torch.cat(conf_other)
    colors_tosave = torch.cat(
        [
            0.5 * (output["img"].permute(0, 2, 3, 1).cpu() + 1.0)
            for output in outputs["views"]
        ]
    )
    cam2world_tosave = torch.cat(pr_poses)
    # Build 3x3 pinhole intrinsics from the estimated focal and principal point.
    intrinsics_tosave = (
        torch.eye(3).unsqueeze(0).repeat(cam2world_tosave.shape[0], 1, 1)
    )
    intrinsics_tosave[:, 0, 0] = focal.detach().cpu()
    intrinsics_tosave[:, 1, 1] = focal.detach().cpu()
    intrinsics_tosave[:, 0, 2] = pp[:, 0]
    intrinsics_tosave[:, 1, 2] = pp[:, 1]

    # Write per-frame files: depth/conf as .npy, color as .png,
    # pose + intrinsics bundled in a .npz.
    os.makedirs(os.path.join(outdir, "depth"), exist_ok=True)
    os.makedirs(os.path.join(outdir, "conf"), exist_ok=True)
    os.makedirs(os.path.join(outdir, "color"), exist_ok=True)
    os.makedirs(os.path.join(outdir, "camera"), exist_ok=True)
    for f_id in range(len(pts3ds_self)):
        depth = depths_tosave[f_id].cpu().numpy()
        conf = conf_self_tosave[f_id].cpu().numpy()
        color = colors_tosave[f_id].cpu().numpy()
        c2w = cam2world_tosave[f_id].cpu().numpy()
        intrins = intrinsics_tosave[f_id].cpu().numpy()
        np.save(os.path.join(outdir, "depth", f"{f_id:06d}.npy"), depth)
        np.save(os.path.join(outdir, "conf", f"{f_id:06d}.npy"), conf)
        iio.imwrite(
            os.path.join(outdir, "color", f"{f_id:06d}.png"),
            (color * 255).astype(np.uint8),
        )
        np.savez(
            os.path.join(outdir, "camera", f"{f_id:06d}.npz"),
            pose=c2w,
            intrinsics=intrins,
        )

    return pts3ds_other, colors, conf_other, cam_dict
|
|
|
|
|
|
|
|
def parse_seq_path(p):
    """Resolve an input path into a list of per-frame image paths.

    A directory is globbed for files (sorted). Anything else is treated as
    a video: frames are decoded with OpenCV and written as JPEGs into a
    fresh temporary directory.

    Args:
        p (str): Directory of images or a video file path.

    Returns:
        tuple: (image paths, temp dir to clean up, or None for directories).
    """
    if os.path.isdir(p):
        # Directory of images: list and sort its contents.
        return sorted(glob.glob(f"{p}/*")), None

    # Treat the path as a video file and extract its frames.
    cap = cv2.VideoCapture(p)
    if not cap.isOpened():
        raise ValueError(f"Error opening video file {p}")
    video_fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if video_fps == 0:
        cap.release()
        raise ValueError(f"Error: Video FPS is 0 for {p}")
    frame_interval = 1
    frame_indices = list(range(0, total_frames, frame_interval))
    print(
        f" - Video FPS: {video_fps}, Frame Interval: {frame_interval}, Total Frames to Read: {len(frame_indices)}"
    )
    tmpdirname = tempfile.mkdtemp()
    img_paths = []
    for frame_idx in frame_indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ok, frame = cap.read()
        if not ok:
            # Stop at the first unreadable frame.
            break
        out_path = os.path.join(tmpdirname, f"frame_{frame_idx}.jpg")
        cv2.imwrite(out_path, frame)
        img_paths.append(out_path)
    cap.release()
    return img_paths, tmpdirname
|
|
|
|
|
|
|
|
def run_inference(args):
    """
    Execute the full inference and visualization pipeline.

    Loads the ARCroco3DStereo checkpoint, builds input views from
    ``args.seq_path``, runs inference, writes per-frame outputs to
    ``args.output_dir``, and launches the interactive point cloud viewer.

    Args:
        args: Parsed command-line arguments.
    """

    # Fall back to CPU if CUDA was requested but is unavailable.
    device = args.device
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available. Switching to CPU.")
        device = "cpu"

    # NOTE(review): presumably adds the checkpoint's dust3r sources to
    # sys.path so the deferred src.* imports below resolve — see add_ckpt_path.
    add_path_to_dust3r(args.model_path)

    # Deferred imports: these must happen after the path setup above.
    from src.dust3r.inference import inference, inference_recurrent
    from src.dust3r.model import ARCroco3DStereo
    from viser_utils import PointCloudViewer

    # Resolve the sequence into frame paths (video frames are extracted to
    # a temporary directory).
    img_paths, tmpdirname = parse_seq_path(args.seq_path)
    if not img_paths:
        print(f"No images found in {args.seq_path}. Please verify the path.")
        return

    print(f"Found {len(img_paths)} images in {args.seq_path}.")
    # All views are real images (no ray maps) in this pipeline.
    img_mask = [True] * len(img_paths)

    print("Preparing input views...")
    views = prepare_input(
        img_paths=img_paths,
        img_mask=img_mask,
        size=args.size,
        revisit=1,
        update=True,
    )
    # Frames are loaded into memory by now; remove any temp frame directory.
    if tmpdirname is not None:
        shutil.rmtree(tmpdirname)

    print(f"Loading model from {args.model_path}...")
    model = ARCroco3DStereo.from_pretrained(args.model_path).to(device)
    model.eval()

    print("Running inference...")
    start_time = time.time()
    outputs, state_args = inference(views, model, device)
    total_time = time.time() - start_time
    per_frame_time = total_time / len(views)
    print(
        f"Inference completed in {total_time:.2f} seconds (average {per_frame_time:.2f} s per frame)."
    )

    print("Preparing output for visualization...")
    # Also saves depth/conf/color/camera files under args.output_dir.
    pts3ds_other, colors, conf, cam_dict = prepare_output(
        outputs, args.output_dir, 1, True
    )

    # Convert tensors to numpy arrays for the viewer.
    pts3ds_to_vis = [p.cpu().numpy() for p in pts3ds_other]
    colors_to_vis = [c.cpu().numpy() for c in colors]
    edge_colors = [None] * len(pts3ds_to_vis)

    print("Launching point cloud viewer...")
    viewer = PointCloudViewer(
        model,
        state_args,
        pts3ds_to_vis,
        colors_to_vis,
        conf,
        cam_dict,
        device=device,
        edge_color_list=edge_colors,
        show_camera=True,
        vis_threshold=args.vis_threshold,
        size = args.size
    )
    viewer.run()
|
|
|
|
|
|
|
|
def main():
    """Entry point: parse CLI arguments and run the demo pipeline.

    Prints a usage hint and exits early when no input sequence is given.
    """
    args = parse_args()
    if not args.seq_path:
        # Fixed user-facing typo: "iteractively" -> "interactively".
        print(
            "No inputs found! Please use our gradio demo if you would like to interactively upload inputs."
        )
        return
    run_inference(args)
|
|
|
|
|
|
|
|
# Script entry point: only run the demo when executed directly.
if __name__ == "__main__":


    main()
|
|
|