Spaces:
Sleeping
Sleeping
| from http.client import MOVED_PERMANENTLY | |
| import io | |
| import ipdb # noqa: F401 | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import trimesh | |
| import torch | |
| import torchvision | |
| from pytorch3d.loss import chamfer_distance | |
| from scipy.spatial.transform import Rotation | |
| from diffusionsfm.inference.ddim import inference_ddim | |
| from diffusionsfm.utils.rays import ( | |
| Rays, | |
| cameras_to_rays, | |
| rays_to_cameras, | |
| rays_to_cameras_homography, | |
| ) | |
| from diffusionsfm.utils.geometry import ( | |
| compute_optimal_alignment, | |
| ) | |
# Shared colormap used to color-code per-frame borders in the visualizations.
cmap = plt.get_cmap("hsv")
def create_training_visualizations(
    model,
    images,
    device,
    cameras_gt,
    num_images,
    crop_parameters,
    pred_x0=False,
    no_crop_param_device="cpu",
    visualize_pred=False,
    return_first=False,
    calculate_intrinsics=False,
    mode=None,
    depths=None,
    scale_min=-1,
    scale_max=1,
    diffuse_depths=False,
    vis_mode=None,
    average_centers=True,
    full_num_patches_x=16,
    full_num_patches_y=16,
    use_homogeneous=False,
    distortion_coefficients=None,
):
    """
    Visualize the DDIM denoising trajectory of predicted rays against ground
    truth and recover predicted cameras from the final (or selected) rays.

    For each batch sample and each of its images, a figure is built with the
    intermediate ray predictions at 11 timesteps plus a ground-truth panel:
    columns 0-3 show segments/moments, columns 4-7 origins/directions, and
    (when ``diffuse_depths``) columns 8-11 the depth channel; the last column
    shows the input image with a color-coded border.

    Args:
        model: diffusion model exposing ``depth_resolution``, ``width`` and
            ``noise_scheduler.max_timesteps``.
        images: (B, N, 3, H, W) input images normalized to [-1, 1].
        device: device used by ``inference_ddim``.
        cameras_gt: per-sample ground-truth cameras.
        num_images: number of images per sample to visualize.
        crop_parameters: per-sample crop parameters.
        pred_x0: if True (and not ``return_first``), convert the last x0
            prediction to cameras instead of the final noisy rays.
        no_crop_param_device: device used when crop parameters are absent.
        visualize_pred: show x0 predictions instead of intermediate noisy rays.
        return_first: convert the first intermediate prediction to cameras.
        calculate_intrinsics: recover cameras via homography (intrinsics too).
        mode: ray parametrization mode passed through to ray helpers.
        depths: optional GT depths, indexable as ``depths[index][ii]``.
        scale_min, scale_max: clipping range for segment/origin visualization.
        diffuse_depths: if True, also plot the diffused depth channel.
        vis_mode: visualization mode; defaults to ``mode``.
        average_centers: passed through to the rays-to-cameras conversion.
        full_num_patches_x, full_num_patches_y: patch grid size when the
            model's depth resolution is 1.
        use_homogeneous: passed through to ``Rays.from_spatial``.
        distortion_coefficients: optional per-sample lens distortion.

    Returns:
        vis_images: list (one per batch sample) of vertically stacked
            visualization images (uint8 arrays).
        pred_cameras_batched: list of per-sample lists of predicted cameras.
        pred_rays: list of per-sample ``Rays`` objects.
    """
    # Work at the full patch resolution unless the model predicts rays at a
    # higher depth resolution.
    if model.depth_resolution == 1:
        W_in = W_out = full_num_patches_x
        H_in = H_out = full_num_patches_y
    else:
        W_in = H_in = model.width
        W_out = model.width * model.depth_resolution
        H_out = model.width * model.depth_resolution

    # Run DDIM inference, keeping intermediate noisy rays and x0 predictions.
    rays_final, rays_intermediate, pred_intermediate, _ = inference_ddim(
        model,
        images,
        device,
        crop_parameters=crop_parameters,
        eta=[1, 0],
        num_patches_x=W_in,
        num_patches_y=H_in,
        visualize=True,
    )

    if vis_mode is None:
        vis_mode = mode

    # Pick 11 timesteps (plus one GT panel) to tile into a 3 x 4 grid.
    T = model.noise_scheduler.max_timesteps
    if T == 1000:
        ts = [0, 100, 200, 300, 400, 500, 600, 700, 800, 900, 999]
    else:
        ts = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 99]

    # Get predicted cameras from rays
    pred_cameras_batched = []
    vis_images = []
    pred_rays = []
    for index in range(len(images)):
        pred_cameras = []
        per_sample_images = []
        for ii in range(num_images):
            # Ground-truth rays for this sample (all of its images).
            rays_gt = cameras_to_rays(
                cameras_gt[index],
                crop_parameters[index],
                no_crop_param_device=no_crop_param_device,
                num_patches_x=W_in,
                num_patches_y=H_in,
                depths=None if depths is None else depths[index],
                mode=mode,
                depth_resolution=model.depth_resolution,
                distortion_coefficients=(
                    None
                    if distortion_coefficients is None
                    else distortion_coefficients[index]
                ),
            )
            # Undo the [-1, 1] normalization for display.
            image_vis = (images[index, ii].cpu().permute(1, 2, 0).numpy() + 1) / 2

            if diffuse_depths:
                # Four extra columns for the depth trajectory.
                fig, axs = plt.subplots(3, 13, figsize=(15, 4.5), dpi=100)
            else:
                fig, axs = plt.subplots(3, 9, figsize=(12, 4.5), dpi=100)

            # Columns 0-3: segments (or moments) along the trajectory.
            for i, t in enumerate(ts):
                r, c = i // 4, i % 4
                if visualize_pred:
                    curr = pred_intermediate[t][index]
                else:
                    curr = rays_intermediate[t][index]
                rays = Rays.from_spatial(
                    curr,
                    mode=mode,
                    num_patches_x=H_in,
                    num_patches_y=W_in,
                    use_homogeneous=use_homogeneous,
                )

                if vis_mode == "segment":
                    # Rescale from [scale_min, scale_max] to [0, 1].
                    vis = (
                        torch.clip(
                            rays.get_segments()[ii], min=scale_min, max=scale_max
                        )
                        - scale_min
                    ) / (scale_max - scale_min)
                else:
                    # Unit-normalize moments, then map [-1, 1] -> [0, 1].
                    vis = (
                        torch.nn.functional.normalize(rays.get_moments()[ii], dim=-1)
                        + 1
                    ) / 2

                axs[r, c].imshow(vis.reshape(W_out, H_out, 3).cpu())
                axs[r, c].set_title(f"T={T - t}")

            # GT panel goes in the slot after the last timestep (relies on the
            # loop variable i surviving the loop above).
            i += 1
            r, c = i // 4, i % 4
            if vis_mode == "segment":
                vis = (
                    torch.clip(rays_gt.get_segments()[ii], min=scale_min, max=scale_max)
                    - scale_min
                ) / (scale_max - scale_min)
            else:
                vis = (
                    torch.nn.functional.normalize(rays_gt.get_moments()[ii], dim=-1) + 1
                ) / 2
            axs[r, c].imshow(vis.reshape(W_out, H_out, 3).cpu())
            type_str = "Endpoints" if vis_mode == "segment" else "Moments"
            axs[r, c].set_title(f"GT {type_str}")

            # Columns 4-7: origins (or directions) along the trajectory.
            for i, t in enumerate(ts):
                r, c = i // 4, i % 4 + 4
                if visualize_pred:
                    curr = pred_intermediate[t][index]
                else:
                    curr = rays_intermediate[t][index]
                rays = Rays.from_spatial(
                    curr,
                    mode,
                    num_patches_x=H_in,
                    num_patches_y=W_in,
                    use_homogeneous=use_homogeneous,
                )

                if vis_mode == "segment":
                    vis = (
                        torch.clip(
                            rays.get_origins(high_res=True)[ii],
                            min=scale_min,
                            max=scale_max,
                        )
                        - scale_min
                    ) / (scale_max - scale_min)
                else:
                    vis = (
                        torch.nn.functional.normalize(rays.get_directions()[ii], dim=-1)
                        + 1
                    ) / 2

                axs[r, c].imshow(vis.reshape(W_out, H_out, 3).cpu())
                axs[r, c].set_title(f"T={T - t}")

            # GT panel for origins/directions.
            i += 1
            r, c = i // 4, i % 4 + 4
            if vis_mode == "segment":
                vis = (
                    torch.clip(
                        rays_gt.get_origins(high_res=True)[ii],
                        min=scale_min,
                        max=scale_max,
                    )
                    - scale_min
                ) / (scale_max - scale_min)
            else:
                vis = (
                    torch.nn.functional.normalize(rays_gt.get_directions()[ii], dim=-1)
                    + 1
                ) / 2
            axs[r, c].imshow(vis.reshape(W_out, H_out, 3).cpu())
            type_str = "Origins" if vis_mode == "segment" else "Directions"
            axs[r, c].set_title(f"GT {type_str}")

            if diffuse_depths:
                # Columns 8-11: the diffused depth channel.
                for i, t in enumerate(ts):
                    r, c = i // 4, i % 4 + 8
                    if visualize_pred:
                        curr = pred_intermediate[t][index]
                    else:
                        curr = rays_intermediate[t][index]
                    rays = Rays.from_spatial(
                        curr,
                        mode,
                        num_patches_x=H_in,
                        num_patches_y=W_in,
                        use_homogeneous=use_homogeneous,
                    )

                    vis = rays.depths[ii]
                    if len(vis.shape) < 2:
                        vis = vis.reshape(H_out, W_out)
                    axs[r, c].imshow(vis.cpu())
                    axs[r, c].set_title(f"T={T - t}")

                # GT depth panel.
                i += 1
                r, c = i // 4, i % 4 + 8
                vis = depths[index][ii]
                # BUG FIX: the shape check previously inspected the *predicted*
                # rays.depths[ii] while reshaping the GT depth below.
                if len(vis.shape) < 2:
                    # NOTE(review): hard-coded 256x256 GT depth resolution —
                    # confirm against the dataloader.
                    vis = vis.reshape(256, 256)
                axs[r, c].imshow(vis.cpu())
                axs[r, c].set_title("GT Depths")

            # Last column, bottom row: the input image with a border colored
            # by this frame's index.
            axs[2, -1].imshow(image_vis)
            axs[2, -1].set_title("Input Image")
            for s in ["bottom", "top", "left", "right"]:
                axs[2, -1].spines[s].set_color(cmap(ii / (num_images)))
                axs[2, -1].spines[s].set_linewidth(5)

            for ax in axs.flatten():
                ax.set_xticks([])
                ax.set_yticks([])

            plt.tight_layout()
            img = plot_to_image(fig)
            plt.close()
            per_sample_images.append(img)

            # Select which ray prediction to convert into a camera.
            if return_first:
                rays_camera = pred_intermediate[0][index]
            elif pred_x0:
                rays_camera = pred_intermediate[-1][index]
            else:
                rays_camera = rays_final[index]
            rays = Rays.from_spatial(
                rays_camera,
                mode=mode,
                num_patches_x=H_in,
                num_patches_y=W_in,
                use_homogeneous=use_homogeneous,
            )
            if calculate_intrinsics:
                pred_camera = rays_to_cameras_homography(
                    rays=rays[ii, None],
                    crop_parameters=crop_parameters[index],
                    num_patches_x=W_in,
                    num_patches_y=H_in,
                    average_centers=average_centers,
                    depth_resolution=model.depth_resolution,
                )
            else:
                pred_camera = rays_to_cameras(
                    rays=rays[ii, None],
                    crop_parameters=crop_parameters[index],
                    no_crop_param_device=no_crop_param_device,
                    num_patches_x=W_in,
                    num_patches_y=H_in,
                    depth_resolution=model.depth_resolution,
                    average_centers=average_centers,
                )
            pred_cameras.append(pred_camera[0])

        pred_rays.append(rays)
        pred_cameras_batched.append(pred_cameras)
        vis_images.append(np.vstack(per_sample_images))

    return vis_images, pred_cameras_batched, pred_rays
def plot_to_image(figure, dpi=100):
    """Converts matplotlib fig to a png for logging with tf.summary.image."""
    buf = io.BytesIO()
    # Dump the raw RGBA framebuffer; the figure's bbox gives its pixel size.
    figure.savefig(buf, format="raw", dpi=dpi)
    plt.close(figure)
    buf.seek(0)
    height = int(figure.bbox.bounds[3])
    width = int(figure.bbox.bounds[2])
    flat = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    image = np.reshape(flat, newshape=(height, width, -1))
    # Drop the alpha channel.
    return image[..., :3]
def view_color_coded_images_from_tensor(images, depth=False):
    """
    Show up to 9 frames in a 3x3 grid, each with a border colored by its
    frame index (HSV colormap); unused cells are hidden.

    Args:
        images: (N, C, H, W) tensor; RGB frames are assumed to be in
            [-1, 1], depth frames are normalized by their max.
        depth: treat frames as single-channel depth maps.

    Returns:
        The matplotlib figure.
    """
    num_frames = images.shape[0]
    cmap = plt.get_cmap("hsv")
    num_rows, num_cols = 3, 3
    fig, axs = plt.subplots(num_rows, num_cols, figsize=(num_cols * 2, num_rows * 2))

    for idx, ax in enumerate(axs.flatten()):
        if idx >= num_frames:
            ax.axis("off")
            continue

        frame = images[idx]
        # Bring the frame to HxWxC for imshow.
        frame = frame.permute(1, 2, 0) if frame.shape[0] == 3 else frame.unsqueeze(-1)
        if depth:
            # Replicate the depth channel and normalize to [0, 1].
            frame = frame.repeat((1, 1, 3)) / torch.max(frame)
        else:
            # Undo the [-1, 1] normalization.
            frame = frame * 0.5 + 0.5

        ax.imshow(frame)
        border_color = cmap(idx / (num_frames))
        for side in ("bottom", "top", "left", "right"):
            ax.spines[side].set_color(border_color)
            ax.spines[side].set_linewidth(5)
        ax.set_xticks([])
        ax.set_yticks([])

    plt.tight_layout()
    return fig
def color_and_filter_points(points, images, mask, num_show, resolution):
    """
    Pair 3-D points with per-pixel colors and keep only points whose mask
    value exceeds 0.5.

    Args:
        points: point map reshapeable to (num_show * resolution**2, 3).
        images: (num_show, 3, H, W) frames in [-1, 1], resized to the
            point-map resolution for color lookup.
        mask: per-point validity values, reshapeable to the same flat length.
        num_show: number of frames contributing points.
        resolution: per-frame point-map side length.

    Returns:
        (points, colors): the surviving points and their RGB colors in [0, 1].
    """
    total = num_show * resolution * resolution

    # Match the image resolution to the point map and undo normalization.
    resized = torchvision.transforms.Resize(resolution)(images) * 0.5 + 0.5

    # Flatten everything and keep indices where the mask fires.
    flat_points = points.reshape(total, 3)
    keep = torch.argwhere(mask.reshape(total) > 0.5)[:, 0]

    flat_colors = resized.permute(0, 2, 3, 1).reshape(total, 3)
    return flat_points[keep], flat_colors[keep]
def filter_and_align_point_clouds(
    num_frames,
    gt_points,
    pred_points,
    gt_masks,
    pred_masks,
    images,
    metrics=False,
    num_patches_x=16,
):
    """
    Mask, colorize, and align predicted vs. ground-truth point clouds, and
    optionally compute error metrics between them.

    Args:
        num_frames: number of frames contributing points.
        gt_points / pred_points: point maps, reshapeable to
            (num_frames * num_patches_x**2, 3).
        gt_masks / pred_masks: per-point validity masks (values > 0.5 kept).
        images: frame images in [-1, 1], used to color the points.
        metrics: if True, also compute Chamfer distance and mean point error.
        num_patches_x: per-frame point-map resolution (assumed square).

    Returns:
        (gt_points, pred_points, gt_colors, pred_colors, [mse, cd, None]);
        mse and cd are None unless ``metrics`` is True. NOTE: despite the
        name, "mse" is the mean Euclidean distance between aligned, scaled
        point pairs, not a squared error.
    """
    # Filter and color points
    gt_points, gt_colors = color_and_filter_points(
        gt_points, images, gt_masks, num_show=num_frames, resolution=num_patches_x
    )
    pred_points, pred_colors = color_and_filter_points(
        pred_points, images, pred_masks, num_show=num_frames, resolution=num_patches_x
    )
    # Similarity-align the predicted cloud onto the ground truth before scoring.
    pred_points, _, _, _ = compute_optimal_alignment(
        gt_points.float(), pred_points.float()
    )
    # Normalize both clouds by the GT centroid and the *mean* distance to it,
    # so the metrics are scale-invariant. (An earlier comment claimed the
    # furthest point is scaled to distance 1; the code uses the mean.)
    centroid = torch.mean(gt_points, dim=0)
    dists = torch.norm(gt_points - centroid.unsqueeze(0), dim=-1)
    scale = torch.mean(dists)
    gt_points_scaled = (gt_points - centroid) / scale
    pred_points_scaled = (pred_points - centroid) / scale
    if metrics:
        cd, _ = chamfer_distance(
            pred_points_scaled.unsqueeze(0), gt_points_scaled.unsqueeze(0)
        )
        cd = cd.item()
        # Mean Euclidean distance between corresponding points.
        mse = torch.mean(
            torch.norm(pred_points_scaled - gt_points_scaled, dim=-1), dim=-1
        ).item()
    else:
        mse, cd = None, None
    return (
        gt_points,
        pred_points,
        gt_colors,
        pred_colors,
        [mse, cd, None],
    )
def add_scene_cam(scene, c2w, edge_color, image=None, focal=None, imsize=None, screen_width=0.03):
    """
    Add a wireframe camera-frustum mesh to a trimesh scene.

    Args:
        scene: trimesh scene to receive the geometry.
        c2w: 4x4 camera-to-world pose matrix.
        edge_color: RGB color written into the frustum's face colors.
        image: optional HxWx3 image; its shape sets the frustum aspect ratio.
        focal: optional focal length (scalar or ndarray) controlling depth.
        imsize: optional (W, H) used when no image is given.
        screen_width: frustum base width in world units.
    """
    # Coordinate flip between camera conventions (y and z negated).
    OPENGL = np.array([
        [1, 0, 0, 0],
        [0, -1, 0, 0],
        [0, 0, -1, 0],
        [0, 0, 0, 1]
    ])
    # Determine the (W, H) aspect from whichever input is available.
    if image is not None:
        H, W, THREE = image.shape
        assert THREE == 3
        if image.dtype != np.uint8:
            image = np.uint8(255*image)
    elif imsize is not None:
        W, H = imsize
    elif focal is not None:
        H = W = focal / 1.1
    else:
        H = W = 1
    if focal is None:
        focal = min(H, W) * 1.1  # default value
    elif isinstance(focal, np.ndarray):
        focal = focal[0]
    # create fake camera
    height = focal * screen_width / H
    width = screen_width * 0.5**0.5
    rot45 = np.eye(4)
    # A 4-sided cone is a square pyramid; rotate 45 degrees so its base is
    # axis-aligned with the image plane.
    rot45[:3, :3] = Rotation.from_euler('z', np.deg2rad(45)).as_matrix()
    rot45[2, 3] = -height  # set the tip of the cone = optical center
    aspect_ratio = np.eye(4)
    aspect_ratio[0, 0] = W/H
    transform = c2w @ OPENGL @ aspect_ratio @ rot45
    cam = trimesh.creation.cone(width, height, sections=4)
    # this is the camera mesh
    rot2 = np.eye(4)
    rot2[:3, :3] = Rotation.from_euler('z', np.deg2rad(4)).as_matrix()
    vertices = cam.vertices
    vertices_offset = 0.9 * cam.vertices
    # Stack three copies of the pyramid (original, shrunken, slightly
    # rotated); faces spanning the copies form thin wireframe edges.
    vertices = np.r_[vertices, vertices_offset, geotrf(rot2, cam.vertices)]
    vertices = geotrf(transform, vertices)
    faces = []
    for face in cam.faces:
        # Skip faces that touch the cone apex (vertex 0).
        if 0 in face:
            continue
        a, b, c = face
        # Indices of the same face in the shrunken copy.
        a2, b2, c2 = face + len(cam.vertices)
        # add 3 pseudo-edges
        faces.append((a, b, b2))
        faces.append((a, a2, c))
        faces.append((c2, b, c))
        faces.append((a, b2, a2))
        faces.append((a2, c, c2))
        faces.append((c2, b2, b))
    # no culling
    faces += [(c, b, a) for a, b, c in faces]
    # NOTE(review): faces 1 and 5 appear to be re-added as solid panels —
    # confirm which frustum sides these correspond to.
    for i,face in enumerate(cam.faces):
        if 0 in face:
            continue
        if i == 1 or i == 5:
            a, b, c = face
            faces.append((a, b, c))
    cam = trimesh.Trimesh(vertices=vertices, faces=faces)
    cam.visual.face_colors[:, :3] = edge_color
    scene.add_geometry(cam)
def geotrf(Trf, pts, ncol=None, norm=False):
    """Apply a geometric transformation to a list of 3-D points.

    Args:
        Trf: 3x3 or 4x4 projection matrix (typically a homography), or a
            batch of such matrices; numpy array or torch tensor.
        pts: numpy/torch/tuple of coordinates. Shape must be (..., 2) or (..., 3).
        ncol: int. Number of columns of the result (2 or 3).
        norm: float. If != 0, the result is projected on the z=norm plane.

    Returns:
        Array/tensor of transformed points, shape (*pts.shape[:-1], ncol).
    """
    assert Trf.ndim >= 2
    # Coerce pts to the same array family (and dtype, for torch) as Trf.
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif isinstance(Trf, torch.Tensor):
        pts = torch.as_tensor(pts, dtype=Trf.dtype)

    # adapt shape if necessary
    output_reshape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]

    # Fast path: batched torch matrices (B, d, d) against point maps (B, H, W, d).
    if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and
            Trf.ndim == 3 and pts.ndim == 4):
        d = pts.shape[3]
        if Trf.shape[-1] == d:
            # Pure linear transform.
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif Trf.shape[-1] == d+1:
            # Affine transform: linear part plus translation column.
            pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d]
        else:
            raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}')
    else:
        if Trf.ndim >= 3:
            n = Trf.ndim-2
            assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

            if pts.ndim > Trf.ndim:
                # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
                pts = pts[:, None, :]

        if pts.shape[-1]+1 == Trf.shape[-1]:
            # Homogeneous transform of non-homogeneous points.
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
        elif pts.shape[-1] == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf
        else:
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)

    if norm:
        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
        if norm != 1:
            pts *= norm

    res = pts[..., :ncol].reshape(*output_reshape, ncol)
    return res