|
|
|
|
|
""" |
|
|
3D Point Cloud Inference and Visualization Script |
|
|
|
|
|
This script performs inference using the ARCroco3DStereo model and visualizes the |
|
|
resulting 3D point clouds with the PointCloudViewer. Use the command-line arguments |
|
|
to adjust parameters such as the model checkpoint path, image sequence directory, |
|
|
image size, device, etc. |
|
|
|
|
|
Usage: |
|
|
python demo_ga.py [--model_path MODEL_PATH] [--seq_path SEQ_PATH] [--size IMG_SIZE] |
|
|
[--device DEVICE] [--vis_threshold VIS_THRESHOLD] [--output_dir OUT_DIR] |
|
|
|
|
|
Example: |
|
|
python demo_ga.py --model_path src/cut3r_512_dpt_4_64.pth \ |
|
|
--seq_path examples/001 --device cuda --size 512 |
|
|
""" |
|
|
|
|
|
import os |
|
|
import numpy as np |
|
|
import torch |
|
|
import time |
|
|
import glob |
|
|
import random |
|
|
import cv2 |
|
|
import argparse |
|
|
import tempfile |
|
|
import shutil |
|
|
from copy import deepcopy |
|
|
from add_ckpt_path import add_path_to_dust3r |
|
|
import imageio.v2 as iio |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
# Fix the RNG seed so any randomized behavior in this script is reproducible.
random.seed(42)
|
|
|
|
|
def forward_backward_permutations(n, interval=1):
    """Build index orderings over ``n`` views.

    Returns the identity ordering first, then for every split point ``s`` in
    ``1..n-1`` two strided orderings: the forward tail ``[s, s+interval, ...]``
    and the reversed prefix ``[s, s-interval, ...]`` down toward 0.

    Args:
        n (int): Number of views to index.
        interval (int): Stride used for both the forward and backward slices.

    Returns:
        list[list[int]]: ``1 + 2*(n-1)`` index orderings.
    """
    identity = list(range(n))
    orderings = [identity]
    for split in range(1, n):
        # Forward: start at the split point and stride toward the end.
        orderings.append(identity[split::interval])
        # Backward: start at the split point and stride toward the start.
        orderings.append(identity[: split + 1][::-interval])
    return orderings
|
|
|
|
|
def listify(elems):
    """Flatten one level of nesting: chain the sub-sequences of *elems* into one list."""
    flat = []
    for group in elems:
        flat.extend(group)
    return flat
|
|
|
|
|
|
|
|
def collate_with_cat(whatever, lists=False):
    """Recursively collate nested containers, concatenating tensor-like leaves.

    Dicts are collated value-wise. Sequences are dispatched on the type of
    their first element: scalars pass through unchanged, tuples are collated
    column-wise, dicts key-wise, and tensors / ndarrays are either flattened
    into one list (``lists=True``) or concatenated along dim 0. Any other
    element type falls back to ``sum``-chaining. Inputs that are neither a
    dict nor a sequence (and sequences whose first element is None) yield
    ``None``.
    """
    if isinstance(whatever, dict):
        return {key: collate_with_cat(val, lists=lists) for key, val in whatever.items()}

    if not isinstance(whatever, (tuple, list)):
        # Mirrors the original behavior: unhandled top-level types collate to None.
        return None

    if not whatever:
        return whatever

    first = whatever[0]
    seq_type = type(whatever)

    if first is None:
        return None
    if isinstance(first, (bool, float, int, str)):
        return whatever
    if isinstance(first, tuple):
        # Transpose the sequence of tuples and collate each column.
        return seq_type(collate_with_cat(column, lists=lists) for column in zip(*whatever))
    if isinstance(first, dict):
        return {
            key: collate_with_cat([entry[key] for entry in whatever], lists=lists)
            for key in first
        }

    if isinstance(first, torch.Tensor):
        return listify(whatever) if lists else torch.cat(whatever)
    if isinstance(first, np.ndarray):
        if lists:
            return listify(whatever)
        return torch.cat([torch.from_numpy(arr) for arr in whatever])

    # Fallback: chain the sequences together, preserving the container type.
    return sum(whatever, seq_type())
|
|
|
|
|
|
|
|
def parse_args(argv=None):
    """
    Parse command-line arguments.

    Args:
        argv (list[str] | None): Argument list to parse. Defaults to
            ``sys.argv[1:]`` when None, so existing callers are unaffected.

    Returns:
        argparse.Namespace: Parsed arguments (model_path, seq_path, device,
        size, vis_threshold, output_dir).
    """
    parser = argparse.ArgumentParser(
        description="Run 3D point cloud inference and visualization using ARCroco3DStereo."
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="src/cut3r_512_dpt_4_64.pth",
        help="Path to the pretrained model checkpoint.",
    )
    parser.add_argument(
        "--seq_path",
        type=str,
        default="",
        help="Path to the directory containing the image sequence.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to run inference on (e.g., 'cuda' or 'cpu').",
    )
    parser.add_argument(
        "--size",
        type=int,
        # Was the string "512"; a real int matches `type` and avoids relying
        # on argparse's string-default re-parsing behavior.
        default=512,
        help="Shape that input images will be rescaled to; if using 224+linear model, choose 224 otherwise 512",
    )
    parser.add_argument(
        "--vis_threshold",
        type=float,
        default=1.5,
        help="Visualization threshold for the point cloud viewer. Ranging from 1 to INF",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./demo_tmp",
        help="value for tempfile.tempdir",
    )

    return parser.parse_args(argv)
|
|
|
|
|
|
|
|
def prepare_input(
    img_paths, img_mask, size, raymaps=None, raymap_mask=None, revisit=1, update=True
):
    """
    Prepare input views for inference from a list of image paths.

    Loads all images, then builds one list of view dicts per index ordering
    produced by ``forward_backward_permutations`` (interval fixed to 2).

    NOTE(review): ``img_mask``, ``raymaps``, ``raymap_mask``, ``revisit`` and
    ``update`` are currently ignored by this implementation — every view is
    emitted with img_mask=True, ray_mask=False, update=True, reset=False and
    an all-NaN ray map. They are kept in the signature for interface
    compatibility; confirm whether they should take effect.

    Args:
        img_paths (list): List of image file paths.
        img_mask (list of bool): Flags indicating valid images (unused).
        size (int): Target image size passed to ``load_images``.
        raymaps (list, optional): List of ray maps (unused).
        raymap_mask (list, optional): Flags indicating valid ray maps (unused).
        revisit (int): How many times to revisit each view (unused).
        update (bool): Whether to update the state on revisits (unused).

    Returns:
        list: One list of view dictionaries per index ordering.
    """

    # Deferred project import: requires add_path_to_dust3r to have run first.
    from src.dust3r.utils.image import load_images

    images = load_images(img_paths, size=size)
    views = []
    num_views = len(images)
    # Identity ordering plus strided forward/backward sub-sequences.
    all_permutations = forward_backward_permutations(num_views, interval=2)
    for permute in all_permutations:
        _views = []
        for idx, i in enumerate(permute):
            view = {
                "img": images[i]["img"],
                # Placeholder ray map: 6 channels of NaN at image resolution
                # (this demo drives the model with images only).
                "ray_map": torch.full(
                    (
                        images[i]["img"].shape[0],
                        6,
                        images[i]["img"].shape[-2],
                        images[i]["img"].shape[-1],
                    ),
                    torch.nan,
                ),
                "true_shape": torch.from_numpy(images[i]["true_shape"]),
                "idx": i,
                "instance": str(i),
                # Identity pose; real poses are recovered by global alignment.
                "camera_pose": torch.from_numpy(np.eye(4).astype(np.float32)).unsqueeze(
                    0
                ),
                "img_mask": torch.tensor(True).unsqueeze(0),
                "ray_mask": torch.tensor(False).unsqueeze(0),
                "update": torch.tensor(True).unsqueeze(0),
                "reset": torch.tensor(False).unsqueeze(0),
            }
            _views.append(view)
        views.append(_views)
    return views
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def prepare_output(output, outdir, device):
    """
    Globally align collated pairwise predictions and save per-frame results.

    Runs the dust3r PointCloudOptimizer global aligner (MST init, 300 iters,
    linear LR schedule), cleans the point cloud, then writes per-frame
    depth/conf ``.npy`` files, color ``.png`` files and camera ``.npz`` files
    (pose + intrinsics) under *outdir*.

    Args:
        output (dict): Collated pairwise predictions with keys
            view1/view2/pred1/pred2 (as assembled in ``run_inference``).
        outdir (str): Directory under which depth/, conf/, color/ and
            camera/ subdirectories are created.
        device: Torch device used for the optimization.

    Returns:
        tuple: ``(pts3ds_other, colors, confs, cam_dict)`` — the first three
        are lists of CPU tensors with a leading batch dim of 1; ``cam_dict``
        holds numpy arrays for focal, pp, R and t.
    """
    # Deferred project import: requires add_path_to_dust3r to have run first.
    from cloud_opt.dust3r_opt import global_aligner, GlobalAlignerMode

    # The aligner optimizes poses/depths with autograd, so gradients must be
    # enabled even if the caller runs under torch.no_grad().
    with torch.enable_grad():
        mode = GlobalAlignerMode.PointCloudOptimizer
        scene = global_aligner(
            output,
            device=device,
            mode=mode,
            verbose=True,
        )
        lr = 0.01
        loss = scene.compute_global_alignment(
            init="mst",
            niter=300,
            schedule="linear",
            lr=lr,
        )
        scene.clean_pointcloud()
        pts3d = scene.get_pts3d()
        depths = scene.get_depthmaps()
        poses = scene.get_im_poses()
        focals = scene.get_focals()
        pps = scene.get_principal_points()
        confs = scene.get_conf(mode="none")

    # Detach to CPU and add a leading batch dimension of 1 per frame.
    pts3ds_other = [pts.detach().cpu().unsqueeze(0) for pts in pts3d]
    depths = [d.detach().cpu().unsqueeze(0) for d in depths]
    colors = [torch.from_numpy(img).unsqueeze(0) for img in scene.imgs]
    confs = [conf.detach().cpu().unsqueeze(0) for conf in confs]
    cam_dict = {
        "focal": focals.detach().cpu().numpy(),
        "pp": pps.detach().cpu().numpy(),
        "R": poses.detach().cpu().numpy()[..., :3, :3],
        "t": poses.detach().cpu().numpy()[..., :3, 3],
    }

    depths_tosave = torch.cat(depths)
    pts3ds_other_tosave = torch.cat(pts3ds_other)
    conf_self_tosave = torch.cat(confs)
    colors_tosave = torch.cat(colors)
    cam2world_tosave = poses.detach().cpu()
    # Assemble 3x3 intrinsics; the same focal is used for fx and fy.
    intrinsics_tosave = (
        torch.eye(3).unsqueeze(0).repeat(cam2world_tosave.shape[0], 1, 1)
    )
    intrinsics_tosave[:, 0, 0] = focals[:, 0].detach().cpu()
    intrinsics_tosave[:, 1, 1] = focals[:, 0].detach().cpu()
    intrinsics_tosave[:, 0, 2] = pps[:, 0].detach().cpu()
    intrinsics_tosave[:, 1, 2] = pps[:, 1].detach().cpu()

    os.makedirs(os.path.join(outdir, "depth"), exist_ok=True)
    os.makedirs(os.path.join(outdir, "conf"), exist_ok=True)
    os.makedirs(os.path.join(outdir, "color"), exist_ok=True)
    os.makedirs(os.path.join(outdir, "camera"), exist_ok=True)
    for f_id in range(len(depths_tosave)):
        depth = depths_tosave[f_id].cpu().numpy()
        conf = conf_self_tosave[f_id].cpu().numpy()
        color = colors_tosave[f_id].cpu().numpy()
        c2w = cam2world_tosave[f_id].cpu().numpy()
        intrins = intrinsics_tosave[f_id].cpu().numpy()
        np.save(os.path.join(outdir, "depth", f"{f_id:06d}.npy"), depth)
        np.save(os.path.join(outdir, "conf", f"{f_id:06d}.npy"), conf)
        # NOTE(review): assumes scene.imgs holds floats in [0, 1] — confirm.
        iio.imwrite(
            os.path.join(outdir, "color", f"{f_id:06d}.png"),
            (color * 255).astype(np.uint8),
        )
        np.savez(
            os.path.join(outdir, "camera", f"{f_id:06d}.npz"),
            pose=c2w,
            intrinsics=intrins,
        )

    return pts3ds_other, colors, confs, cam_dict
|
|
|
|
|
|
|
|
def parse_seq_path(p):
    """Resolve *p* into a list of image file paths.

    A directory yields its sorted contents directly. Anything else is opened
    as a video with OpenCV and every frame is written as a JPEG into a fresh
    temporary directory.

    Args:
        p (str): Directory of images, or a video file path.

    Returns:
        tuple: ``(img_paths, tmpdirname)`` where ``tmpdirname`` is None for
        the directory case, otherwise the temp dir the caller must remove.

    Raises:
        ValueError: If the video cannot be opened or reports an FPS of 0.
    """
    if os.path.isdir(p):
        return sorted(glob.glob(f"{p}/*")), None

    cap = cv2.VideoCapture(p)
    if not cap.isOpened():
        raise ValueError(f"Error opening video file {p}")

    video_fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if video_fps == 0:
        cap.release()
        raise ValueError(f"Error: Video FPS is 0 for {p}")

    # Every frame is extracted (interval fixed to 1).
    frame_interval = 1
    frame_indices = list(range(0, total_frames, frame_interval))
    print(
        f" - Video FPS: {video_fps}, Frame Interval: {frame_interval}, Total Frames to Read: {len(frame_indices)}"
    )

    tmpdirname = tempfile.mkdtemp()
    img_paths = []
    for frame_idx in frame_indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ok, frame = cap.read()
        if not ok:
            # Stop at the first unreadable frame (e.g. truncated video).
            break
        out_path = os.path.join(tmpdirname, f"frame_{frame_idx}.jpg")
        cv2.imwrite(out_path, frame)
        img_paths.append(out_path)
    cap.release()
    return img_paths, tmpdirname
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_inference(args):
    """
    Execute the full inference and visualization pipeline.

    Steps: resolve device -> collect images -> build per-ordering views ->
    load the ARCroco3DStereo model -> run pairwise inference over every view
    ordering -> canonically sort and collate the pairs -> global alignment
    and saving via ``prepare_output`` -> launch the point cloud viewer.

    Args:
        args: Parsed command-line arguments (model_path, seq_path, size,
            device, vis_threshold, output_dir).
    """

    device = args.device
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available. Switching to CPU.")
        device = "cpu"

    # Make the dust3r package importable relative to the checkpoint path.
    add_path_to_dust3r(args.model_path)

    # Deferred imports: these only resolve after add_path_to_dust3r ran.
    from src.dust3r.inference import inference, inference_recurrent
    from src.dust3r.model import ARCroco3DStereo
    from viser_utils import PointCloudViewer

    img_paths, tmpdirname = parse_seq_path(args.seq_path)
    if not img_paths:
        print(f"No images found in {args.seq_path}. Please verify the path.")
        return

    print(f"Found {len(img_paths)} images in {args.seq_path}.")
    img_mask = [True] * len(img_paths)

    print("Preparing input views...")
    views = prepare_input(
        img_paths=img_paths,
        img_mask=img_mask,
        size=args.size,
        revisit=1,
        update=True,
    )
    # Frames extracted from a video live in a temp dir; remove it now that
    # load_images has read them.
    if tmpdirname is not None:
        shutil.rmtree(tmpdirname)

    print(f"Loading model from {args.model_path}...")
    model = ARCroco3DStereo.from_pretrained(args.model_path).to(device)
    model.eval()

    print("Running inference...")
    start_time = time.time()
    # Pairwise accumulation: view1/pred1 is the first view of each ordering,
    # paired against every later view in that ordering (view2/pred2).
    output = {
        "view1": [],
        "view2": [],
        "pred1": [],
        "pred2": [],
    }
    edges = []
    for _views in views:
        # NOTE(review): state_args keeps only the value from the LAST
        # ordering, and is handed to PointCloudViewer below — confirm this is
        # intentional. Also raises NameError there if `views` is empty.
        outputs, state_args = inference(_views, model, device)
        for view_id in range(1, len(outputs["views"])):
            output["view1"].append(outputs["views"][0])
            output["view2"].append(outputs["views"][view_id])
            output["pred1"].append(outputs["pred"][0])
            output["pred2"].append(outputs["pred"][view_id])

            edges.append((outputs["views"][0]["idx"], outputs["views"][view_id]["idx"]))
    list_of_tuples = edges
    # Canonical edge ordering: forward pairs (i < j) before backward ones,
    # then by the smaller endpoint, then the larger.
    sorted_indices = sorted(
        range(len(list_of_tuples)),
        key=lambda x: (
            list_of_tuples[x][0] > list_of_tuples[x][1],
            (
                list_of_tuples[x][1]
                if list_of_tuples[x][0] > list_of_tuples[x][1]
                else list_of_tuples[x][0]
            ),
            (
                list_of_tuples[x][0]
                if list_of_tuples[x][0] > list_of_tuples[x][1]
                else list_of_tuples[x][1]
            ),
        ),
    )
    new_output = {
        "view1": [],
        "view2": [],
        "pred1": [],
        "pred2": [],
    }
    for i in sorted_indices:
        new_output["view1"].append(output["view1"][i])
        new_output["view2"].append(output["view2"][i])
        new_output["pred1"].append(output["pred1"][i])
        new_output["pred2"].append(output["pred2"][i])
    # Batch each list of per-pair dicts into stacked tensors for the aligner.
    output["view1"] = collate_with_cat(new_output["view1"])
    output["view2"] = collate_with_cat(new_output["view2"])
    output["pred1"] = collate_with_cat(new_output["pred1"])
    output["pred2"] = collate_with_cat(new_output["pred2"])

    total_time = time.time() - start_time
    # NOTE(review): len(views) counts orderings, not frames, so the reported
    # "per frame" average below is approximate.
    per_frame_time = total_time / len(views)
    print(
        f"Inference completed in {total_time:.2f} seconds (average {per_frame_time:.2f} s per frame)."
    )

    print("Preparing output for visualization...")

    pts3ds_other, colors, conf, cam_dict = prepare_output(
        output, args.output_dir, device
    )

    # Convert to numpy for the viewer.
    pts3ds_to_vis = [p.cpu().numpy() for p in pts3ds_other]
    colors_to_vis = [c.cpu().numpy() for c in colors]
    edge_colors = [None] * len(pts3ds_to_vis)

    print("Launching point cloud viewer...")
    viewer = PointCloudViewer(
        model,
        state_args,
        pts3ds_to_vis,
        colors_to_vis,
        conf,
        cam_dict,
        device=device,
        edge_color_list=edge_colors,
        show_camera=True,
        vis_threshold=args.vis_threshold,
        size=args.size,
    )
    viewer.run()
|
|
|
|
|
|
|
|
def main():
    """Entry point: parse CLI arguments and run the inference pipeline.

    Prints a usage hint and exits when no ``--seq_path`` was supplied.
    """
    args = parse_args()
    if not args.seq_path:
        # Fixed typo in the user-facing message: "iteractively" -> "interactively".
        print(
            "No inputs found! Please use our gradio demo if you would like to interactively upload inputs."
        )
        return
    run_inference(args)
|
|
|
|
|
|
|
|
# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|
|
|