Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import os | |
| import torch | |
| from pathlib import Path | |
| from multiprocessing import Process, Queue | |
| from yacs.config import CfgNode | |
| from mini_dpvo.utils import Timer | |
| from mini_dpvo.dpvo import DPVO | |
| from mini_dpvo.stream import image_stream, video_stream | |
| import rerun as rr | |
| from jaxtyping import UInt8, Float64, Float32 | |
| from scipy.spatial.transform import Rotation | |
| from dataclasses import dataclass | |
| from timeit import default_timer as timer | |
| import cv2 | |
| import mmcv | |
| from tqdm import tqdm | |
| from mini_dust3r.api import OptimizedResult, inferece_dust3r | |
| from mini_dust3r.model import AsymmetricCroCo3DStereo | |
@dataclass
class DPVOPrediction:
    """Bundle of the final outputs of a DPVO run.

    The class is instantiated with keyword arguments, so it relies on the
    ``__init__`` generated by ``@dataclass``; without the decorator the
    keyword construction raises ``TypeError``.
    """

    # Final per-keyframe poses; consumers in this module read the first three
    # values as a translation and the remaining values as a quaternion.
    final_poses: Float32[torch.Tensor, "num_keyframes 7"]  # noqa: F722
    # One timestamp per keyframe, aligned with ``final_poses``.
    tstamps: Float64[torch.Tensor, "num_keyframes"]  # noqa: F821
    # Flattened 3D points, one row per patch.
    final_points: Float32[torch.Tensor, "buffer_size*num_patches 3"]  # noqa: F722
    # Colours matching ``final_points`` once flattened to (-1, 3).
    final_colors: UInt8[torch.Tensor, "buffer_size num_patches 3"]  # noqa: F722
def log_trajectory(
    parent_log_path: Path,
    poses: Float32[torch.Tensor, "buffer_size 7"],  # noqa: F722
    points: Float32[torch.Tensor, "buffer_size*num_patches 3"],  # noqa: F722
    colors: UInt8[torch.Tensor, "buffer_size num_patches 3"],  # noqa: F722
    intri_np: Float64[np.ndarray, "4"],
    bgr_hw3: UInt8[np.ndarray, "h w 3"],  # noqa: F722
    path_list: list,
    jpg_quality: int = 90,
):
    """Log the current image, camera pose, travelled path and points to rerun.

    Args:
        parent_log_path: Root entity path; everything is logged beneath it.
        poses: Pose buffer. Rows whose first six values are all zero are
            ignored; the last remaining row is used as the current pose
            (first three values translation, the rest a quaternion).
        points: Point buffer. All-zero rows are ignored.
        colors: Colours, flattened to (-1, 3) and filtered with the same
            row mask as ``points``.
        intri_np: Intrinsics read as ``[fx, fy, cx, cy]``.
        bgr_hw3: Current frame in BGR order; converted to RGB for logging.
        path_list: Accumulated camera positions. It is mutated in place: the
            newest position is appended.
        jpg_quality: JPEG quality used when compressing the logged image.

    Returns:
        The same ``path_list`` object that was passed in.
    """
    cam_path = f"{parent_log_path}/camera"

    # --- image + pinhole model -------------------------------------------
    image_rgb = mmcv.bgr2rgb(bgr_hw3)
    compressed_image = rr.Image(image_rgb).compress(jpeg_quality=jpg_quality)
    rr.log(f"{cam_path}/pinhole/image", compressed_image)

    pinhole = rr.Pinhole(
        height=bgr_hw3.shape[0],
        width=bgr_hw3.shape[1],
        focal_length=[intri_np[0], intri_np[1]],
        principal_point=[intri_np[2], intri_np[3]],
    )
    rr.log(f"{cam_path}/pinhole", pinhole)

    # --- keep only buffer rows that hold data ----------------------------
    # "not all zero" written as "any non-zero".
    used_pose_rows = (poses[:, :6] != 0).any(dim=1)
    used_point_rows = (points != 0).any(dim=1)
    active_poses = poses[used_pose_rows]
    active_points = points[used_point_rows]

    # --- current camera pose ----------------------------------------------
    # The most recent pose is the last row that survived the mask.
    latest_pose = active_poses[-1].numpy(force=True)
    pose_matrix = np.eye(4)
    pose_matrix[:3, :3] = Rotation.from_quat(latest_pose[3:]).as_matrix()
    pose_matrix[0:3, 3] = latest_pose[:3]
    # The stored pose is inverted before being logged.
    world_T_cam = np.linalg.inv(pose_matrix)

    path_list.append(world_T_cam[:3, 3].copy().tolist())
    camera_transform = rr.Transform3D(
        translation=world_T_cam[:3, 3],
        mat3x3=world_T_cam[:3, :3],
        from_parent=False,
    )
    rr.log(f"{cam_path}", camera_transform)

    # --- travelled path as a single red line strip ------------------------
    rr.log(
        f"{parent_log_path}/path",
        rr.LineStrips3D(strips=[path_list], colors=[255, 0, 0]),
    )

    # --- point cloud with outlier removal ---------------------------------
    # Points are kept only if they are closer to the median pose position
    # than five times the farthest pose.
    pose_positions = active_poses[:, :3].numpy(force=True)
    trajectory_center = np.median(pose_positions, axis=0)

    xyz = active_points.view(-1, 3).numpy(force=True)
    rgb = colors.view(-1, 3)[used_point_rows].numpy(force=True)

    point_radii = np.linalg.norm(xyz - trajectory_center, axis=1)
    pose_radii = np.linalg.norm(pose_positions - trajectory_center, axis=1)
    keep = point_radii < pose_radii.max() * 5

    rr.log(
        f"{parent_log_path}/pointcloud",
        rr.Points3D(positions=xyz[keep], colors=rgb[keep]),
    )
    return path_list
def log_final(
    parent_log_path: Path,
    final_poses: Float32[torch.Tensor, "num_keyframes 7"],
    tstamps: Float64[torch.Tensor, "num_keyframes"],  # noqa: F821
    final_points: Float32[torch.Tensor, "buffer_size*num_patches 3"],
    final_colors: UInt8[torch.Tensor, "buffer_size num_patches 3"],
):
    """Log one rerun transform per final pose.

    Each pose is logged at ``{parent_log_path}/camera_{idx}``, using its first
    three values as the translation and the remaining values as a quaternion.

    Args:
        parent_log_path: Root entity path for the logged cameras.
        final_poses: Poses to log, one per row.
        tstamps: Iterated in lockstep with ``final_poses``; the values
            themselves are not used.
        final_points: Accepted but not used here.
        final_colors: Accepted but not used here.
    """
    # zip() is kept so iteration stops at the shorter of poses / timestamps.
    for idx, (pose, _tstamp) in enumerate(zip(final_poses, tstamps)):
        translation = pose[:3]
        rotation_33 = Rotation.from_quat(pose[3:]).as_matrix()
        transform = rr.Transform3D(
            translation=translation, mat3x3=rotation_33, from_parent=False
        )
        rr.log(f"{parent_log_path}/camera_{idx}", transform)
def create_reader(
    imagedir: str, calib: str | None, stride: int, skip: int, queue: Queue
) -> Process:
    """Build (but do not start) the process that feeds frames into ``queue``.

    Args:
        imagedir: A directory path selects ``image_stream``; anything else is
            handed to ``video_stream``.
        calib: Calibration argument forwarded to the stream function.
        stride: Forwarded to the stream function.
        skip: Forwarded to the stream function.
        queue: Queue the stream function receives as its first argument.

    Returns:
        An unstarted ``Process`` targeting the selected stream function.
    """
    stream_fn = image_stream if os.path.isdir(imagedir) else video_stream
    return Process(target=stream_fn, args=(queue, imagedir, calib, stride, skip))
def calculate_num_frames(video_or_image_dir: str, stride: int, skip: int) -> int:
    """Estimate how many frames remain after applying ``skip`` and ``stride``.

    Args:
        video_or_image_dir: A directory (every regular file directly inside it
            is counted, whatever its extension) or a path that OpenCV opens
            as a video.
        stride: Step between consumed frames.
        skip: Number of leading frames to drop.

    Returns:
        ``(frame_count - skip) // stride``, clamped to zero so that skipping
        past the end of the input never yields a negative count.
    """
    if os.path.isdir(video_or_image_dir):
        # Booleans sum as 0/1, so this counts the regular files only.
        frame_count = sum(
            os.path.isfile(os.path.join(video_or_image_dir, name))
            for name in os.listdir(video_or_image_dir)
        )
    else:
        cap = cv2.VideoCapture(video_or_image_dir)
        try:
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        finally:
            # Release the capture handle even if the query above fails.
            cap.release()

    return max(0, (frame_count - skip) // stride)
def calib_from_dust3r(
    bgr_hw3: UInt8[np.ndarray, "height width 3"],
    model: AsymmetricCroCo3DStereo,
    device: str,
) -> Float64[np.ndarray, "3 3"]:
    """
    Estimates the camera intrinsics of a single image with mini-dust3r.

    The image is written to a private temporary directory, which is handed to
    ``inferece_dust3r`` as its image directory. The predicted intrinsics
    refer to the downscaled image the model worked on, so they are rescaled
    to the resolution of ``bgr_hw3`` before being returned.

    Args:
        bgr_hw3: The input image in BGR format with shape (height, width, 3).
        model: The DUSt3R model used for inference.
        device: The device to run the inference on.
    Returns:
        The calibration matrix with shape (3, 3), expressed in the pixel
        coordinates of the original image.
    """
    import tempfile  # local import keeps this block self-contained

    # A fresh directory per call guarantees that the image directory holds
    # only this frame (no stale or concurrently written files) and that the
    # file is removed even when inference raises.
    with tempfile.TemporaryDirectory(prefix="dpvo_") as tmp_dir:
        tmp_path = Path(tmp_dir) / "tmp.png"
        mmcv.imwrite(bgr_hw3, str(tmp_path))
        optimized_results: OptimizedResult = inferece_dust3r(
            image_dir_or_list=tmp_path.parent,
            model=model,
            device=device,
            batch_size=1,
        )

    # get predicted intrinsics in original image size
    downscaled_h, downscaled_w, _ = optimized_results.rgb_hw3_list[0].shape
    orig_h, orig_w, _ = bgr_hw3.shape

    # Scaling factors
    scaling_factor_x = orig_w / downscaled_w
    scaling_factor_y = orig_h / downscaled_h

    # Scale the intrinsic matrix to the original image size; work on a copy
    # so the result object is left untouched.
    K_33_original = optimized_results.K_b33[0].copy()
    K_33_original[0, 0] *= scaling_factor_x  # fx
    K_33_original[1, 1] *= scaling_factor_y  # fy
    K_33_original[0, 2] *= scaling_factor_x  # cx
    K_33_original[1, 2] *= scaling_factor_y  # cy
    return K_33_original
def inference_dpvo(
    cfg: CfgNode,
    network_path: str,
    imagedir: str,
    calib: str | None,
    stride: int = 1,
    skip: int = 0,
    timeit: bool = False,
) -> tuple[DPVOPrediction, float]:
    """Run DPVO over an image directory or video, logging progress to rerun.

    Frames are produced by a separate reader process and consumed from a
    bounded queue. Frames and intrinsics are moved to CUDA before being passed
    to DPVO, so a CUDA device is required.

    Args:
        cfg: DPVO configuration node.
        network_path: Path to the DPVO network weights.
        imagedir: Image directory or video file to process.
        calib: Calibration argument forwarded to the reader. When ``None``,
            intrinsics are estimated once with DUSt3R from the first frame
            taken off the queue and reused for every frame; that first frame
            is not passed to DPVO.
        stride: Frame stride forwarded to the reader.
        skip: Number of leading frames skipped by the reader.
        timeit: Enables the per-frame ``Timer`` around the SLAM call.

    Returns:
        The final prediction and the elapsed time in seconds, measured from
        just after the reader starts until the post-stream updates finish
        (before the reader is joined and before ``slam.terminate()``).
    """
    slam = None
    queue = Queue(maxsize=8)
    reader: Process = create_reader(imagedir, calib, stride, skip, queue)
    reader.start()

    try:
        parent_log_path = Path("world")
        rr.log(f"{parent_log_path}", rr.ViewCoordinates.RDF, timeless=True)

        start = timer()
        total_frames = calculate_num_frames(imagedir, stride, skip)

        # estimate camera intrinsics if not provided
        if calib is None:
            dust3r_device = (
                "mps"
                if torch.backends.mps.is_available()
                else "cuda"
                if torch.cuda.is_available()
                else "cpu"
            )
            dust3r_model = AsymmetricCroCo3DStereo.from_pretrained(
                "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"
            ).to(dust3r_device)
            # This frame is used for calibration only.
            _, bgr_hw3, _ = queue.get()
            K_33_pred = calib_from_dust3r(bgr_hw3, dust3r_model, dust3r_device)
            intri_np_dust3r = np.array(
                [K_33_pred[0, 0], K_33_pred[1, 1], K_33_pred[0, 2], K_33_pred[1, 2]]
            )

        # path list for visualizing the trajectory
        path_list = []

        with tqdm(total=total_frames, desc="Processing Frames") as pbar:
            while True:
                t: int
                bgr_hw3: UInt8[np.ndarray, "h w 3"]
                intri_np: Float64[np.ndarray, "4"]
                (t, bgr_hw3, intri_np_calib) = queue.get()
                # ``intri_np_dust3r`` only exists when calib is None; the
                # conditional expression never evaluates the other branch.
                intri_np = intri_np_calib if calib is not None else intri_np_dust3r
                # queue will have a (-1, image, intrinsics) tuple when the reader is done
                if t < 0:
                    break

                rr.set_time_sequence(timeline="timestep", sequence=t)

                # permute moves channels first: (h, w, 3) -> (3, h, w)
                bgr_3hw: UInt8[torch.Tensor, "3 h w"] = (
                    torch.from_numpy(bgr_hw3).permute(2, 0, 1).cuda()
                )
                intri_torch: Float64[torch.Tensor, "4"] = torch.from_numpy(
                    intri_np
                ).cuda()

                # DPVO is created lazily because it needs the frame size.
                if slam is None:
                    slam = DPVO(
                        cfg, network_path, ht=bgr_3hw.shape[1], wd=bgr_3hw.shape[2]
                    )

                with Timer("SLAM", enabled=timeit):
                    slam(t, bgr_3hw, intri_torch)

                if slam.is_initialized:
                    poses: Float32[torch.Tensor, "buffer_size 7"] = slam.poses_
                    points: Float32[torch.Tensor, "buffer_size*num_patches 3"] = (
                        slam.points_
                    )
                    colors: UInt8[torch.Tensor, "buffer_size num_patches 3"] = (
                        slam.colors_
                    )
                    path_list = log_trajectory(
                        parent_log_path=parent_log_path,
                        poses=poses,
                        points=points,
                        colors=colors,
                        intri_np=intri_np,
                        bgr_hw3=bgr_hw3,
                        path_list=path_list,
                    )
                pbar.update(1)

        # extra update iterations once the stream is exhausted
        for _ in range(12):
            slam.update()

        total_time: float = timer() - start
        print(f"Total time: {total_time:.2f}s")
    except BaseException:
        # On any failure the reader may be blocked putting into the full
        # queue; stop it explicitly so it is not leaked and join() cannot
        # hang.
        reader.terminate()
        reader.join()
        raise

    # Normal path: the sentinel was the reader's last item, so it has finished.
    reader.join()

    final_poses: Float32[torch.Tensor, "num_keyframes 7"]
    tstamps: Float64[torch.Tensor, "num_keyframes"]  # noqa: F821
    final_poses, tstamps = slam.terminate()
    final_points: Float32[torch.Tensor, "buffer_size*num_patches 3"] = slam.points_
    final_colors: UInt8[torch.Tensor, "buffer_size num_patches 3"] = slam.colors_

    dpvo_pred = DPVOPrediction(
        final_poses=final_poses,
        tstamps=tstamps,
        final_points=final_points,
        final_colors=final_colors,
    )
    return dpvo_pred, total_time