Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn.functional as F | |
| from torch import Tensor | |
| import numpy as np | |
| from tqdm import tqdm | |
| from PIL import Image | |
| from optgs.experimental.edgs.utils import ( | |
| select_cameras_kmeans, | |
| k_closest_vectors, | |
| aggregate_confidences_and_warps, | |
| extract_keypoints_and_colors, | |
| triangulate_points, | |
| select_best_keypoints | |
| ) | |
def _chw_tensor_to_pil(image: Tensor) -> "Image.Image":
    """Convert one (3, H, W) image tensor to an 8-bit PIL image.

    Values are multiplied by 255 and clipped to [0, 255], so the input is
    presumably expected in [0, 1] -- TODO confirm against the caller.
    """
    hwc = image.detach().cpu().numpy().transpose(1, 2, 0)  # [H, W, 3]
    return Image.fromarray(np.clip(hwc * 255, 0, 255).astype(np.uint8))


def init_gaussians_with_corr(
    viewpoints_img: Tensor,
    viewpoints_w2c: Tensor,
    viewpoints_proj: Tensor,
    camera_centers: Tensor,
    init_opacity: float,
    roma_model_type: str,
    verbose: bool = False,
    *,
    num_refs: int = 180,
    nns_per_ref: int = 3,
    matches_per_ref: int = 20000,
    proj_err_tolerance: float = 0.01,
):
    """Initialise Gaussian parameters from RoMa dense correspondences.

    A RoMa matcher is instantiated, a subset of reference views is chosen
    with k-means over the flattened camera matrices, and every reference view
    is matched against its nearest views. Sampled matches are triangulated
    with the projection matrices, the best triangulation per keypoint is
    kept, and the resulting points are returned as initial Gaussians.

    Args:
        viewpoints_img: (N, 3, H, W) original images. Converted to 8-bit by
            multiplying with 255, so presumably in [0, 1] -- TODO confirm.
        viewpoints_w2c: (N, 4, 4) per-view camera matrices; only their
            flattened values are used here (k-means and nearest-neighbour
            search). NOTE(review): the parameter name says world-to-camera
            while the original inline comment said "camera-to-world,
            column-major" -- confirm the convention with the caller.
        viewpoints_proj: (N, 4, 4) projection matrices handed to
            ``triangulate_points`` (original comment: column-major).
        camera_centers: (N, 3) camera centres in world coordinates.
        init_opacity: Opacity assigned to every well-triangulated point.
        roma_model_type: ``"indoors"`` selects ``roma_indoor``; any other
            value selects ``roma_outdoor``.
        verbose: Print diagnostics and forward the flag to the helpers.
        num_refs: Upper bound on the number of reference views
            (default taken from the original EDGS code).
        nns_per_ref: Number of nearest views matched per reference view
            (EDGS default).
        matches_per_ref: Number of correspondences sampled per reference
            view (EDGS default).
        proj_err_tolerance: Reprojection error above which a point is
            treated as bad and its opacity is suppressed (EDGS default).

    Returns:
        A tuple ``(closest_indices_selected, visualizations, points_dict)``:
        the nearest-view indices as a NumPy array, the dict that was passed
        to ``aggregate_confidences_and_warps`` as ``output_dict``, and a dict
        with keys ``"xyz"``, ``"rgb"``, ``"scales"`` and ``"opacities"``
        holding the concatenated per-point tensors.

    Raises:
        ValueError: If fewer than two views are supplied.
        ImportError: If the ``romatch`` package is not installed.
    """
    n_views = viewpoints_img.shape[0]
    # Matching and triangulation need a second view; fail before the
    # (expensive) model load instead of with an IndexError in the warm-up.
    if n_views < 2:
        raise ValueError(
            "init_gaussians_with_corr needs at least two views to match and "
            f"triangulate correspondences, got {n_views}."
        )
    device = viewpoints_w2c.device

    try:
        from romatch import roma_outdoor, roma_indoor
    except ImportError as e:
        raise ImportError(
            "The edgs initializer requires RoMa (romatch), which is not "
            "installed. Install it with: "
            "pip install git+https://github.com/Parskatt/RoMa.git"
        ) from e

    if roma_model_type == "indoors":
        roma_model = roma_indoor(device=device)
    else:
        roma_model = roma_outdoor(device=device)
    roma_model.upsample_preds = False
    roma_model.symmetric = False

    M = matches_per_ref
    upper_thresh = roma_model.sample_thresh
    visualizations = {}
    num_reference_frames = min(num_refs, n_views)
    num_nns_per_reference = min(nns_per_ref, n_views)

    # Select reference cameras with k-means over the flattened matrices.
    viewpoint_cam_all = viewpoints_w2c.reshape(n_views, -1)  # (N_VIEWS, 16)
    selected_indices = sorted(
        select_cameras_kmeans(
            cameras=viewpoint_cam_all.detach().cpu().numpy(),
            K=num_reference_frames,
        )
    )

    # Find the k-closest vectors for each vector.
    closest_indices = k_closest_vectors(viewpoint_cam_all, num_nns_per_reference)
    if verbose:
        print("Indices of k-closest vectors for each vector:\n", closest_indices)
    closest_indices_selected = closest_indices.detach().cpu().numpy()

    all_new_xyz = []
    all_new_rgb = []
    all_new_scaling = []
    all_new_opacities_raw = []

    # Run roma_model.match once to warm the model up; the result is discarded.
    imA = _chw_tensor_to_pil(viewpoints_img[0])
    imB = _chw_tensor_to_pil(viewpoints_img[1])
    warp, certainty_warp = roma_model.match(imA, imB, device=device)
    if verbose:
        print("Once run full roma_model.match warp.shape:", warp.shape)
        print("Once run full roma_model.match certainty_warp.shape:", certainty_warp.shape)
    del warp, certainty_warp
    torch.cuda.empty_cache()

    for source_idx in tqdm(selected_indices):
        # 1. Compute keypoints and warps against all neighbouring views.
        (
            certainties_max,
            _warps_max,
            certainties_max_idcs,
            imA,
            imB_compound,
            _certainties_all,
            warps_all,
        ) = aggregate_confidences_and_warps(
            viewpoints_img=viewpoints_img,
            closest_indices=closest_indices_selected,
            roma_model=roma_model,
            source_idx=source_idx,
            verbose=verbose,
            output_dict=visualizations,
        )

        # 2. Sample match locations, weighted by certainty. Work on a clone
        # so the certainties passed to the helpers below stay untouched;
        # everything above the model's sampling threshold gets equal weight.
        certainty = certainties_max.clone()
        certainty[certainty > upper_thresh] = 1
        certainty = certainty.reshape(-1)
        good_samples = torch.multinomial(
            certainty,
            num_samples=min(M, len(certainty)),
            replacement=False,
        )

        # 3. Triangulate the sampled matches against every neighbour view.
        nn_points = []
        nn_errors_proj1 = []
        nn_errors_proj2 = []
        for nn_idx in tqdm(range(len(warps_all))):
            matches_nn = warps_all[nn_idx].reshape(-1, 4)[good_samples]
            # Extract keypoints and colors.
            kptsA_np, kptsB_np, _, kptsA_color, _ = extract_keypoints_and_colors(
                imA, imB_compound, certainties_max, certainties_max_idcs, matches_nn, roma_model
            )
            proj_matrix_A = viewpoints_proj[source_idx]
            proj_matrix_B = viewpoints_proj[closest_indices_selected[source_idx, nn_idx]]
            # Fewer than M samples may be available; keep the number of
            # projection matrices equal to the number of keypoints.
            num_pts = min(M, len(kptsA_np))
            points, errors_proj1, errors_proj2 = triangulate_points(
                P1=torch.stack([proj_matrix_A] * num_pts, dim=0),
                P2=torch.stack([proj_matrix_B] * num_pts, dim=0),
                k1_x=kptsA_np[:num_pts, 0], k1_y=kptsA_np[:num_pts, 1],
                k2_x=kptsB_np[:num_pts, 0], k2_y=kptsB_np[:num_pts, 1])
            nn_points.append(points)
            nn_errors_proj1.append(errors_proj1)
            nn_errors_proj2.append(errors_proj2)

        selected_points, selected_proj_errors = select_best_keypoints(
            NNs_triangulated_points=torch.stack(nn_points, dim=0),
            NNs_errors_proj1=np.stack(nn_errors_proj1, axis=0),
            NNs_errors_proj2=np.stack(nn_errors_proj2, axis=0))

        # 4. Save as gaussians.
        new_xyz = selected_points[:, :-1]  # drop the last column
        all_new_xyz.append(new_xyz)
        # Colours come from the reference-image keypoints of the last
        # neighbour iteration, rescaled from 0..255 to 0..1.
        all_new_rgb.append(torch.tensor(kptsA_color.astype(np.float32) / 255.).to(device))

        mask_bad_points = torch.tensor(
            selected_proj_errors > proj_err_tolerance,
            dtype=torch.float32).to(device)
        if verbose:
            print("Number of bad points for source_idx", source_idx, ":", mask_bad_points.sum().item())

        # Work in logit space: subtracting 10 from the logit of a badly
        # reprojected point drives its opacity towards zero after sigmoid.
        new_opacities = torch.ones((new_xyz.shape[0]), device=device) * init_opacity
        new_opacities_raw = torch.logit(new_opacities) - mask_bad_points * (1e1)
        all_new_opacities_raw.append(new_opacities_raw)

        # Isotropic scale proportional to the distance from the reference camera.
        camera_center = camera_centers[source_idx].unsqueeze(0)  # (1, 3)
        dist_points_to_cam = torch.linalg.norm(camera_center - new_xyz, dim=1, ord=2)
        all_new_scaling.append((dist_points_to_cam * 0.001).unsqueeze(1).repeat(1, 3))

    all_new_opacities_raw = torch.cat(all_new_opacities_raw, dim=0)
    points_dict = {
        "xyz": torch.cat(all_new_xyz, dim=0),
        "rgb": torch.cat(all_new_rgb, dim=0),
        "scales": torch.cat(all_new_scaling, dim=0),
        "opacities": torch.sigmoid(all_new_opacities_raw),
    }
    return closest_indices_selected, visualizations, points_dict