| | import time |
| | from typing import List |
| |
|
| | import cv2 |
| | import numpy as np |
| | import viser |
| | import viser.transforms as vt |
| | import hydra |
| | import torch |
| | from torch import Tensor |
| | from jaxtyping import Float |
| | from PIL import Image |
| | from torchvision import transforms as TF |
| | from einops import repeat |
| | import matplotlib |
| |
|
| | from vggt.utils.pose_enc import pose_encoding_to_extri_intri |
| |
|
| | from dpm.model import VDPM |
| | from util.transforms import transform_points |
| |
|
| |
|
# Target frame sampling rate in frames per second, used by extract_frames()
# to decide how many decoded video frames to keep.
VIDEO_SAMPLE_HZ = 1.0
| |
|
| |
|
def assign_colours(pts3d, colour=(0, 0, 1)):
    """Build a per-point uint8 RGB colour array with a single uniform colour.

    Args:
        pts3d: array-like of shape (n, 3); only the leading dimension is used.
        colour: RGB triple with components in [0, 1], applied to every point.
            (Default changed from a mutable list to a tuple — same value,
            avoids the shared-mutable-default pitfall.)

    Returns:
        np.uint8 array of shape (n, 3) with values in [0, 255].
    """
    num_points = pts3d.shape[0]
    # Tile the single colour across all points, then scale to 8-bit range.
    colors = (
        np.tile(np.array([colour]), (num_points, 1)) * 255
    ).astype(np.uint8)
    return colors
| |
|
| |
|
def compute_box_edges(corners):
    """
    Compute all edges of a 3D bounding box.

    Args:
        corners: torch tensor of shape (8, 3) containing the coordinates of the 8 corners
            of a 3D bounding box

    Returns:
        edges: torch tensor of shape (12, 2, 3) containing the 12 edges of the box,
            each represented as a pair of 3D coordinates [start_point, end_point]
        colors: torch uint8 tensor of shape (12, 3), one distinct RGB colour per edge
    """
    # Corner-index pairs for the 12 edges. Each group of four pairs differs
    # in exactly one bit of the corner index (bit 0, bit 1, bit 2), i.e. one
    # group per box axis under the usual binary corner ordering.
    edge_indices = torch.tensor([
        [0, 1], [2, 3], [4, 5], [6, 7],
        [0, 2], [1, 3], [4, 6], [5, 7],
        [0, 4], [1, 5], [2, 6], [3, 7]
    ], dtype=torch.long, device=corners.device)

    # Advanced indexing gathers both endpoints of every edge at once,
    # replacing the original Python loop: result shape (12, 2, 3).
    edges = corners[edge_indices]

    # A fixed palette of 12 distinct colours, one per edge.
    colors = torch.tensor([
        [255, 0, 0],
        [0, 255, 0],
        [0, 0, 255],
        [255, 255, 0],
        [0, 255, 255],
        [255, 0, 255],
        [255, 128, 0],
        [128, 0, 255],
        [128, 255, 0],
        [255, 0, 128],
        [0, 128, 255],
        [128, 0, 0]
    ], dtype=torch.uint8, device=corners.device)

    return edges, colors
| |
|
| |
|
class TrackVisualiser:
    """Renders motion trails for 3D point tracks in a viser scene.

    Each consecutive frame pair becomes one line-segment scene node; at
    playback time only the trailing window of segments before the current
    frame is made visible.
    """

    def __init__(self,
                 server: "viser.ViserServer"
                 ):
        # Number of frame-to-frame segments shown behind the current frame.
        self._trail_length = 12
        self._server = server

    def remove_static_tracks(self,
                             tracks: "Float[Tensor, 't n 3']",
                             threshold=0.025
                             ) -> "Float[Tensor, 't n 3']":
        """Keep only tracks whose maximum displacement between any pair of
        frames exceeds `threshold`.

        NOTE: builds a (t, t, n, 3) pairwise-difference tensor, so memory
        grows quadratically with the number of frames.
        """
        delta = tracks[None, ...] - tracks[:, None, ...]
        # Largest Euclidean displacement over all frame pairs, per track.
        # (The original applied .abs() before the norm — redundant, since the
        # norm squares components anyway.)
        max_displ = torch.linalg.norm(delta, dim=-1).amax((0, 1))
        dynamic = max_displ > threshold
        tracks_filtered = tracks[:, dynamic, :]
        return tracks_filtered

    def set_data(self,
                 tracks_xyz: "Float[Tensor, 't n 3']",
                 ):
        """Subsample up to 1000 tracks, colour them by height rank, and create
        one (initially hidden) line-segment node per frame transition."""
        tracks_xyz = tracks_xyz.numpy()
        # tracks_xyz is (t, n, 3): the track count is axis 1, not axis 0.
        # (Fixed: the original printed shape[0], which is the frame count.)
        print("num actual tracks", tracks_xyz.shape[1])
        num_tracks = min(1000, tracks_xyz.shape[1])
        indices = np.random.choice(tracks_xyz.shape[1], num_tracks, replace=False)
        tracks_xyz = tracks_xyz[:, indices]
        # Order tracks by their y coordinate in the first frame so the hue
        # ramp below varies smoothly with that coordinate.
        sorted_indices = np.argsort(tracks_xyz[0, ..., 1])
        tracks_xyz = tracks_xyz[:, sorted_indices]
        color_map = matplotlib.colormaps.get_cmap('hsv')
        cmap_norm = matplotlib.colors.Normalize(vmin=0, vmax=tracks_xyz.shape[1] - 1)
        colours = np.zeros((num_tracks, 3), dtype=np.float32)
        for t_idx in range(num_tracks):
            color = color_map(cmap_norm(t_idx))[:3]
            colours[t_idx] = color
        # One colour per segment endpoint: (n, 2, 3).
        colours = colours[:, None, :].repeat(2, axis=1)

        n_frames = tracks_xyz.shape[0]
        segment_nodes: list = []
        for k in range(1, n_frames):
            # Segment from frame k-1 to frame k for every track: (n, 2, 3).
            segment = tracks_xyz[k-1:k+1].swapaxes(0, 1)
            segment_node = self._server.scene.add_line_segments(
                f"/track_vis/{k}",
                segment,
                colours
            )
            segment_node.visible = False
            segment_nodes.append(segment_node)
        self._segment_nodes = segment_nodes

    def set_current_frame(self, f_idx: int):
        """Show only the trail segments in the window ending at frame f_idx."""
        start_idx = max(1, f_idx - self._trail_length + 1)
        for node in self._segment_nodes:
            node.visible = False
        # The segment for transition (idx-1 -> idx) is stored at index idx-1.
        for idx in range(start_idx, f_idx + 1):
            self._segment_nodes[idx-1].visible = True
| |
|
| |
|
class ViserViewer:
    """Interactive viser front-end for browsing dynamic point-map predictions.

    Owns the GUI (timestep / point-size / confidence sliders, play button,
    preview images), the per-frame point-cloud and camera-frustum scene
    nodes, and a TrackVisualiser showing motion trails.
    """

    def __init__(self, model, device, port=8080):
        self.device = device
        self.model = model
        self.port = port

        # Placeholder frame count until set_data() provides real data.
        self.S = 5
        self.need_update = True
        self.need_sequence_change = False
        self.is_playing = False
        self.last_update_time = time.time()

        self.server = viser.ViserServer(port=self.port)
        self._setup_gui()
        self._setup_event_handlers()

        self._track_visualiser = TrackVisualiser(self.server)

    def _setup_gui(self):
        """Create the GUI controls (buttons, sliders, preview images)."""
        server = self.server
        server.gui.configure_theme(control_layout="floating", control_width="large", show_logo=False)
        self.seq_selector = server.gui.add_button("Next example")
        self.play_button = server.gui.add_button("Play")
        self.scene_label = server.gui.add_text(
            "Sequence ID",
            initial_value="",
            disabled=True
        )

        self.gui_point_size = server.gui.add_slider(
            "Point size",
            min=0.0005,
            max=0.002,
            step=0.0005,
            initial_value=0.001,
        )

        self.gui_timestep = server.gui.add_slider(
            "Time",
            min=0,
            max=self.S-1,
            step=1,
            initial_value=0,
        )
        self.conf_slider = server.gui.add_slider(
            "Confidence",
            min=0.0,
            max=1.0,
            step=0.01,
            initial_value=0.3,
        )
        self.prev_timestep = self.gui_timestep.value

        # Preview panes for the reference frame and the current frame,
        # initialised with white placeholders.
        self.rgb0_vis = self.server.gui.add_image(
            np.ones((100,100,3), dtype=np.uint8) * 255,
            label="rgb_0"
        )
        self.rgbt_vis = self.server.gui.add_image(
            np.ones((100,100,3), dtype=np.uint8) * 255,
            label="rgb_t"
        )

    def set_scene_label(self, example_idx):
        """Display "<sequence>_<frame>" for the given dataset example index.

        NOTE(review): reads self.dataset, which is never assigned in this
        class — it must be attached externally before the "Next example"
        button is used.
        """
        seq_idx, frame_idx = self.dataset.idx_to_seq_frame_id(example_idx)
        scene_name = self.dataset.seq_keys()[seq_idx]
        self.scene_label.value = f"{scene_name}_{frame_idx}"
        print("setting scene label", f"{scene_name}_{frame_idx}")

    def _setup_event_handlers(self):
        """Wire up all GUI callbacks."""
        @self.seq_selector.on_click
        def _(_) -> None:
            """Choose random sequence to display"""
            with self.server.atomic():
                num_scenes = len(self.dataset)
                example_idx = np.random.randint(num_scenes)
                self.set_scene_label(example_idx)
                views = self.dataset[example_idx]
                views = process_example(views, self.device)
                # Fixed: the original called compute_predictions(self, model,
                # views) with an undefined `model`, unpacked 4 values from a
                # 3-value return, and passed 4 args to the 3-parameter
                # visualise_reconstruction.
                # NOTE(review): compute_predictions expects an image tensor;
                # this dataset path passes `views` — confirm before relying
                # on this code path.
                pointmaps, extrinsic, _ = compute_predictions(self.model, views)
                self.visualise_reconstruction(views, pointmaps, extrinsic)
                self.server.flush()
            self.need_update = True
            self.need_sequence_change = True

        @self.play_button.on_click
        def _(_) -> None:
            # Toggle autoplay; run() advances the timestep while playing.
            self.is_playing = not self.is_playing
            self.play_button.text = "Pause" if self.is_playing else "Play"

        @self.gui_point_size.on_update
        def _(_):
            for node in self.point_nodes:
                node.point_size = self.gui_point_size.value

        @self.conf_slider.on_update
        def _(_):
            # Threshold changed: rebuild the scene on the next update() tick.
            self.need_update = True

        @self.gui_timestep.on_update
        def _(_) -> None:
            """Toggle frame visibility when the timestep slider changes"""
            current_timestep = self.gui_timestep.value
            with self.server.atomic():
                self.frame_nodes[current_timestep].visible = True
                self.frame_nodes[self.prev_timestep].visible = False
                self.prev_timestep = current_timestep
                if self._track_visualiser is not None:
                    self._track_visualiser.set_current_frame(current_timestep)
                self._update_image_t()
            self.server.flush()

    def continue_loop(self):
        """Return True while the current sequence should keep displaying."""
        return not self.need_sequence_change

    def set_data(
        self,
        pts3d_v0_t1: "Float[Tensor, 's h w 3']",
        confs: "Float[Tensor, 'h w']",
        img_v0: "Float[Tensor, '3 h w']",
        imgs: "List[Float[Tensor, '3 h w']]",
        instance_ids,
        panoptic_v0,
        extrinsic,
    ):
        """Store the data to visualise and request a scene rebuild.

        Args:
            pts3d_v0_t1: per-timestep point maps for view 0.
            confs: per-pixel confidence of the reference frame.
            img_v0: reference RGB image (values presumably in [0, 1], given
                the *255 conversion in _do_update — confirm upstream).
            imgs: one RGB image per timestep, for the preview pane.
            instance_ids, panoptic_v0: stored but not used by _do_update.
            extrinsic: optional per-frame camera matrices (or None).
        """
        self.S = pts3d_v0_t1.shape[0]
        self.gui_timestep.max = self.S - 1

        self.pts3d_v0_t1 = pts3d_v0_t1
        self.img_v0 = img_v0
        self.imgs = imgs
        self.panoptic_v0 = panoptic_v0
        self.instance_ids = instance_ids
        self.confs = confs
        self.extrinsic = extrinsic

        self.need_update = True
        self.need_sequence_change = False

    def update(self):
        """Rebuild the scene if requested; polled from run()."""
        if not self.need_update:
            return
        self._do_update()
        self.need_update = False

    def _do_update(self):
        """Rebuild the whole viser scene from the currently stored data."""
        self.server.scene.reset()

        img_v0 = self.img_v0
        rgb_v0 = (img_v0 * 255.0).to(torch.uint8).permute(1, 2, 0).numpy()

        def get_coloured_pointclouds(pts_img, color=None):
            # Flatten an (h, w, 3) point map; default colours come from the
            # reference image pixels, confidences from the reference frame.
            return {
                "pts3d": pts_img.view(-1, 3),
                "rgb": rgb_v0.reshape(-1, 3) if color is None else color,
                "conf": self.confs.view(-1)
            }

        points3d = dict()
        for s in range(self.S):
            points3d[f"v0_t{s}"] = get_coloured_pointclouds(self.pts3d_v0_t1[s])
        point_size = float(self.gui_point_size.value)

        # Fixed axis-swap applied to points and cameras for display —
        # presumably converts the model's convention to viser's (y/z swap
        # with a sign flip); TODO confirm.
        T = torch.tensor([
            [1, 0, 0, 0],
            [0, 0, 1, 0],
            [0, -1, 0, 0],
            [0, 0, 0, 1]
        ], dtype=torch.float32)

        view_colours = np.array([
            [0, 0, 1],
            [1, 0, 0],
            [0, 1, 0],
            [1, 1, 0],
            [1, 0, 1],
            [0, 1, 1],
            [0.5, 0, 0],
            [0, 0.5, 0],
            [0, 0, 0.5],
            [0.5, 0.5, 0]
        ], np.float32)

        if self.extrinsic is not None:
            extrinsic = self.extrinsic
            S = extrinsic.shape[0]

            T_c2ws = [extrinsic[s] for s in range(S)]

            # One camera frustum per frame, in the display coordinate frame.
            for v, T_c2w in enumerate(T_c2ws):
                T_c2w = (T @ T_c2w).numpy()
                H, W = img_v0.shape[1:3]
                f_x = 600  # nominal focal length in pixels — TODO confirm
                fov = 2 * np.arctan2(W / 2, f_x)
                aspect = W / H
                self.server.scene.add_camera_frustum(
                    f"/frames/t{v}/camera/pred",
                    fov=fov,
                    aspect=aspect,
                    scale=0.1,
                    color=view_colours[0],
                    wxyz=vt.SO3.from_matrix(T_c2w[:3, :3]).wxyz,
                    position=T_c2w[:3, -1],
                )

        for pts in points3d.values():
            pts["pts3d"] = transform_points(T, pts["pts3d"])

        reference_frame_id = 0

        # Confidence threshold chosen as the slider's quantile of the
        # reference frame's confidences.
        confs = points3d[f"v0_t{reference_frame_id}"]["conf"]
        thresh = confs[confs.argsort()][int(confs.size()[0] * self.conf_slider.value)].item()
        good_points = (confs > thresh).numpy()

        # Tracks = the same confident pixels traced through every timestep.
        tracks = torch.stack([points3d[f"v0_t{s}"]["pts3d"] for s in range(self.S)])
        tracks = tracks[:, good_points, :]
        if self._track_visualiser is not None:
            tracks_filtered = self._track_visualiser.remove_static_tracks(tracks)
            self._track_visualiser.set_data(tracks_filtered)

        frame_nodes: list[viser.FrameHandle] = []
        point_nodes: list[viser.PointCloudHandle] = []
        for s in range(self.S):
            v = points3d[f"v0_t{s}"]
            pts3d = v["pts3d"]
            colours = v["rgb"]
            pts3d_ = pts3d.numpy()[good_points, :]
            # A fixed colour (tuple) applies to all points; per-point colour
            # arrays are filtered like the points themselves.
            colours_ = colours if isinstance(colours, tuple) else colours[good_points]
            point_node = self.server.scene.add_point_cloud(
                name=f"/frames/t{s}/xyz",
                points=pts3d_,
                colors=colours_,
                point_size=point_size,
            )
            point_nodes.append(point_node)
            # One parent frame node per timestep so a whole frame toggles at
            # once; only the current timestep starts visible.
            frame_node = self.server.scene.add_frame(f"/frames/t{s}", show_axes=False)
            frame_node.visible = s == self.gui_timestep.value
            frame_nodes.append(frame_node)
        self.point_nodes = point_nodes
        self.frame_nodes = frame_nodes

        scene_centre = points3d["v0_t0"]["pts3d"].mean(dim=0)

        # Point every connected client's camera at the scene centre.
        for client in self.server.get_clients().values():
            camera = client.camera
            camera.look_at = scene_centre

        self.rgb0_vis.image = rgb_v0
        self._update_image_t()

    def _update_image_t(self):
        """Refresh the preview pane with the image of the current timestep."""
        rgb_vt = (self.imgs[self.gui_timestep.value] * 255.0).to(torch.uint8).permute(1, 2, 0).numpy()
        self.rgbt_vis.image = rgb_vt

    def visualise_reconstruction(self, images, pred, extrinsic):
        """Convert raw model predictions into the layout set_data() expects.

        Args:
            images: (S, 3, H, W) image tensor (H, W taken from the last two
                dims below).
            pred: sequence of per-frame prediction dicts with "pts3d" and
                "conf" entries — exact shapes assumed from the indexing
                below; TODO confirm against the model's output contract.
            extrinsic: optional per-frame camera matrices, or None.
        """
        S = len(pred)
        pts3d_all = [pr["pts3d"] for pr in pred]
        conf_all = [pr["conf"] for pr in pred]

        # Points of view 0 across all timesteps.
        pts3d_v0 = torch.stack([pts3d_all[s][:, 0] for s in range(S)], dim=1)
        pred_dynamic = dict(pts3d=pts3d_v0)

        pred_pts_t1 = pred_dynamic["pts3d"]

        pts3d_t1 = pred_pts_t1[0].detach()

        indices = torch.arange(S).to(torch.int64)
        pts3d_t1 = pts3d_t1[indices]
        confs_t1 = conf_all[0][0, 0]
        if extrinsic is not None:
            extrinsic = extrinsic[indices, ...].cpu()

        H, W = images.shape[-2:]
        imgs = images.cpu()
        img_v0 = images[0]

        # No panoptic segmentation in this pipeline: pass an empty mask and
        # an empty instance list.
        panoptic_1 = torch.zeros((H, W), dtype=torch.uint8, device=self.device)
        valid_instances = []

        self.set_data(
            pts3d_t1.cpu(),
            confs_t1.cpu(),
            img_v0.cpu(),
            imgs,
            valid_instances,
            panoptic_1.cpu(),
            extrinsic,
        )

    def run(self):
        """Run the visualization event loop"""
        while True:
            current_time = time.time()
            # Advance roughly 10 frames/sec while playing, wrapping around.
            if self.is_playing and current_time - self.last_update_time > 0.1:
                self.gui_timestep.value = (self.gui_timestep.value + 1) % self.S
                self.last_update_time = current_time
            self.update()
            time.sleep(1e-3)
| |
|
def process_example(views, device):
    """Prepare dataset views for model input.

    For every tensor-valued field: add a leading batch dimension, convert
    numpy arrays to torch tensors, then move everything to `device`.
    Non-listed keys are left untouched; views are modified in place and
    also returned.
    """
    tensors = ['img', 'camera_pose', 'T_WV_norm', 'camera_intrinsics', 'pts3d_t0', 'pts3d_t1', 'valid_mask_t0', 'valid_mask_t1']

    # Pass 1: batch-dim + numpy -> torch conversion.
    for view in views:
        for key in tensors:
            if key in view:
                batched = view[key][None, ...]
                if isinstance(batched, np.ndarray):
                    batched = torch.from_numpy(batched)
                view[key] = batched

    # Pass 2: device transfer.
    for view in views:
        for key in tensors:
            if key in view:
                view[key] = view[key].to(device, non_blocking=True)
    return views
| |
|
| |
|
def compute_predictions(model, images):
    """Run model inference and recover camera parameters.

    Args:
        model: model exposing .inference(None, images=...).
        images: image tensor; a leading batch dimension is added here.

    Returns:
        (pointmaps, extrinsic_WC, intrinsic) where extrinsic_WC is the
        (S, 4, 4) inverse of the homogenised pose-encoding extrinsics
        (presumably camera-to-world, given the WC naming — confirm).
    """
    print("model inference started")

    t0 = time.perf_counter()

    with torch.no_grad():
        result = model.inference(None, images=images.unsqueeze(0))
    print("model inference finished")
    t1 = time.perf_counter()
    print(f"Execution time: {t1 - t0:.6f} seconds")

    pointmaps = result["pointmaps"]

    # Decode the pose encoding into 3x4 extrinsics + intrinsics at the
    # pointmap resolution, then drop the batch dimension.
    pose_enc = result["pose_enc"]
    HW = pointmaps[0]["pts3d"].shape[2:4]
    extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, HW)
    extrinsic = extrinsic[0]
    S = extrinsic.shape[0]

    # Append the homogeneous bottom row to every 3x4 matrix and invert.
    bottom_row = repeat(torch.tensor([0, 0, 0, 1]), "c -> s 1 c", s=S)
    extrinsic_CW = torch.cat([extrinsic.cpu(), bottom_row], dim=1)
    extrinsic_WC = torch.linalg.inv(extrinsic_CW)

    return pointmaps, extrinsic_WC, intrinsic
| |
|
| |
|
def extract_frames(input_video):
    """Decode a video with OpenCV and return RGB frames sampled at roughly
    VIDEO_SAMPLE_HZ frames per second.

    Args:
        input_video: path to the video file.

    Returns:
        list of RGB numpy frames (possibly empty if the video has no frames).
    """
    torch.cuda.empty_cache()

    capture = cv2.VideoCapture(input_video)

    # Keep every `frame_interval`-th frame; always at least every frame.
    fps = float(capture.get(cv2.CAP_PROP_FPS) or 0.0)
    frame_interval = max(int(fps / max(VIDEO_SAMPLE_HZ, 1e-6)), 1)

    images = []
    frame_idx = 0
    kept = 0
    try:
        while True:
            ok, frame_bgr = capture.read()
            if not ok:
                break
            if frame_idx % frame_interval == 0:
                # OpenCV decodes BGR; convert to RGB for downstream use.
                images.append(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
                kept += 1
            frame_idx += 1
    finally:
        # Always release the capture handle, even on error.
        capture.release()

    return images
| |
|
| |
|
def preprocess_images(images_np, mode="crop"):
    """Resize (and crop or pad) raw RGB frames into a model-ready tensor.

    Args:
        images_np: non-empty list of HxWx3 (or RGBA) numpy images.
        mode: "crop" — width is fixed to 518 and excess height is
            centre-cropped; "pad" — the longer side is fixed to 518 and the
            result is padded with white to 518x518.

    Returns:
        float tensor of shape (S, 3, H', W') with values in [0, 1]; for a
        single input image a batch dimension is still present.

    Raises:
        ValueError: if `images_np` is empty or `mode` is invalid.
    """
    if len(images_np) == 0:
        raise ValueError("At least 1 image is required")

    if mode not in ["crop", "pad"]:
        raise ValueError("Mode must be either 'crop' or 'pad'")

    images = []
    shapes = set()  # (H, W) of every processed image, to detect mismatches
    to_tensor = TF.ToTensor()
    target_size = 518  # presumably the model's expected input side — confirm

    for img_np in images_np:

        img = Image.fromarray(img_np)

        # Flatten any alpha channel onto a white background before RGB
        # conversion, so transparency does not become black.
        if img.mode == "RGBA":
            background = Image.new("RGBA", img.size, (255, 255, 255, 255))
            img = Image.alpha_composite(background, img)

        img = img.convert("RGB")

        width, height = img.size

        if mode == "pad":
            # Scale the longer side to target_size; round the other side to
            # a multiple of 14 (patch size of the downstream ViT — confirm).
            if width >= height:
                new_width = target_size
                new_height = round(height * (new_width / width) / 14) * 14
            else:
                new_height = target_size
                new_width = round(width * (new_height / height) / 14) * 14
        else:
            # Crop mode: width is always fixed to target_size.
            new_width = target_size
            new_height = round(height * (new_width / width) / 14) * 14

        img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
        img = to_tensor(img)  # (3, H, W), floats in [0, 1]

        # Centre-crop excess height down to target_size (crop mode only).
        if mode == "crop" and new_height > target_size:
            start_y = (new_height - target_size) // 2
            img = img[:, start_y : start_y + target_size, :]

        # Pad mode: white-pad symmetrically up to target_size x target_size.
        if mode == "pad":
            h_padding = target_size - img.shape[1]
            w_padding = target_size - img.shape[2]

            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left

                # value=1.0 pads with white in the [0, 1] tensor range.
                img = torch.nn.functional.pad(
                    img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
                )

        shapes.add((img.shape[1], img.shape[2]))
        images.append(img)

    # If the per-image shapes still differ (possible in crop mode with mixed
    # aspect ratios), white-pad everything to the common maximum size so
    # torch.stack below succeeds.
    if len(shapes) > 1:
        print(f"Warning: Found images with different shapes: {shapes}")

        max_height = max(shape[0] for shape in shapes)
        max_width = max(shape[1] for shape in shapes)

        padded_images = []
        for img in images:
            h_padding = max_height - img.shape[1]
            w_padding = max_width - img.shape[2]

            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left

                img = torch.nn.functional.pad(
                    img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
                )
            padded_images.append(img)
        images = padded_images

    images = torch.stack(images)  # (S, 3, H, W)

    # Guarantee a batch dimension even for a single input image.
    if len(images_np) == 1:
        if images.dim() == 3:
            images = images.unsqueeze(0)

    return images
| |
|
| |
|
def load_model(cfg, device) -> VDPM:
    """Build a VDPM model, load its pretrained checkpoint from the Hugging
    Face hub, and return it in eval mode on `device`."""
    checkpoint_url = "https://huggingface.co/edgarsucar/vdpm/resolve/main/model.pt"

    model = VDPM(cfg).to(device)
    state_dict = torch.hub.load_state_dict_from_url(
        checkpoint_url,
        file_name="vdpm_model.pt",
        progress=True
    )
    # load_state_dict returns a report of missing/unexpected keys; print it.
    print(model.load_state_dict(state_dict, strict=True))

    model.eval()
    return model
| |
|
| |
|
@hydra.main(config_path="configs", config_name="visualise")
def main(cfg) -> None:
    """Load the model, run inference on the configured video, and serve an
    interactive viser visualisation."""
    device = 'cuda:0'
    torch.backends.cuda.matmul.allow_tf32 = True

    model = load_model(cfg, device)

    viewer = ViserViewer(model, device, cfg.vis.port)

    # Sample frames from the input video and prepare them for the model.
    frames = extract_frames(cfg.vis.input_video)
    images = preprocess_images(frames).to(device)

    pointmaps, extrinsic, _ = compute_predictions(model, images)
    viewer.visualise_reconstruction(images, pointmaps, extrinsic)
    viewer.run()
| |
|
| |
|
# Script entry point (hydra parses CLI overrides inside main()).
if __name__ == "__main__":
    main()
| |
|