Spaces:
Sleeping
Sleeping
| import sys | |
| from pathlib import Path | |
| from typing import Union | |
| import h5py | |
| import numpy as np | |
| import open3d as o3d | |
| import torch | |
| from rich.progress import track | |
| from salad.utils.paths import SPAGHETTI_DIR | |
| from salad.utils import nputil, thutil, sysutil, meshutil | |
# TODO: rewrite SPAGHETTI's relative-path dependencies so the sys.path
# add/remove helpers below are no longer needed.
def add_spaghetti_path(spaghetti_path=SPAGHETTI_DIR):
    """Put the SPAGHETTI directory on ``sys.path`` (no-op if already there)."""
    path_str = str(spaghetti_path)
    if path_str not in sys.path:
        sys.path.append(path_str)
def delete_spaghetti_path(
    spaghetti_path=SPAGHETTI_DIR,
):
    """Remove the SPAGHETTI directory from ``sys.path`` (no-op if absent)."""
    path_str = str(spaghetti_path)
    if path_str in sys.path:
        sys.path.remove(path_str)
def load_spaghetti(device, tag="chairs_large"):
    """Load a pretrained SPAGHETTI model for inference.

    Args:
        device: torch device the model is moved to.
        tag: pretrained checkpoint tag; one of "chairs_large",
            "airplanes", "tables".
    Returns:
        The SPAGHETTI model in eval mode with all parameters frozen.
    Raises:
        ValueError: if ``tag`` is not a known checkpoint tag.
    """
    # Validate with an explicit raise: an ``assert`` would be stripped
    # under ``python -O`` (and the original message was an f-string with
    # no placeholders, so it never showed the offending tag).
    valid_tags = ("chairs_large", "airplanes", "tables")
    if tag not in valid_tags:
        raise ValueError(f"tag should be one of {valid_tags}, got {tag!r}.")
    # SPAGHETTI relies on relative paths; make it importable first.
    add_spaghetti_path()
    from salad.spaghetti.options import Options
    from salad.spaghetti.ui import occ_inference

    opt = Options()
    opt.dataset_size = 1
    opt.device = device
    opt.tag = tag
    infer_module = occ_inference.Inference(opt)
    spaghetti = infer_module.model.to(device)
    spaghetti.eval()
    # Inference only: freeze every parameter.
    for p in spaghetti.parameters():
        p.requires_grad_(False)
    delete_spaghetti_path()
    return spaghetti
def load_mesher(
    device,
    min_res=64,
):
    """Build a marching-cubes mesher on ``device``.

    Args:
        device: torch device the mesher runs on.
        min_res: minimum marching-cubes grid resolution.
    Returns:
        A ``MarchingCubesMeshing`` instance.
    """
    # Mirror load_spaghetti: the original deleted the SPAGHETTI path
    # without ever adding it, so the import could fail when load_mesher
    # is called on its own (SPAGHETTI uses relative-path imports).
    add_spaghetti_path()
    from salad.spaghetti.utils.mcubes_meshing import MarchingCubesMeshing

    mesher = MarchingCubesMeshing(device=device, min_res=min_res)
    delete_spaghetti_path()
    return mesher
def get_mesh_and_pc(spaghetti, mesher, zc):
    """Extract a mesh from latent ``zc`` and Poisson-sample a point cloud from it.

    Returns:
        (vertices, faces, point_cloud) as numpy arrays.
    """
    vertices, faces = get_mesh_from_spaghetti(spaghetti, mesher, zc)
    return vertices, faces, poisson_sampling(vertices, faces)
def get_mesh_from_spaghetti(spaghetti, mesher, zc, res=256):
    """Run marching cubes on the occupancy field defined by latent ``zc``.

    Returns:
        (vertices, faces) as numpy arrays.
    """
    mesh = mesher.occ_meshing(
        decoder=get_occ_func(spaghetti, zc), res=res, get_time=False, verbose=False
    )
    vertices, faces = (thutil.th2np(part) for part in mesh)
    return vertices, faces
def poisson_sampling(
    vert: np.ndarray, face: np.ndarray, num_points: int = 2048
) -> np.ndarray:
    """Poisson-disk sample a point cloud from a triangle mesh.

    The annotations are fixed: ``np.array`` is a factory function, not a
    type — ``np.ndarray`` is the correct annotation. The previously
    hard-coded sample count is now a backward-compatible parameter.

    Args:
        vert: (V, 3) float vertex positions.
        face: (F, 3) int triangle indices.
        num_points: number of points to sample (default 2048, as before).
    Returns:
        (num_points, 3) float32 array of sampled points.
    """
    vert_o3d = o3d.utility.Vector3dVector(vert)
    face_o3d = o3d.utility.Vector3iVector(face)
    mesh_o3d = o3d.geometry.TriangleMesh(vert_o3d, face_o3d)
    pc_o3d = mesh_o3d.sample_points_poisson_disk(num_points)
    return np.asarray(pc_o3d.points).astype(np.float32)
def get_occ_func(spaghetti, zc):
    """Return a closure mapping query points to occupancy values in [-1, 1].

    ``zc`` is converted to a tensor on the model's device and batched
    before the closure is built, so no ``nonlocal`` rebinding is needed.
    """
    zc = nputil.np2th(zc).to(spaghetti.device)
    if zc.dim() == 2:
        zc = zc.unsqueeze(0)

    def forward(x):
        out = spaghetti.occupancy_network(x.unsqueeze(0), zc)[0, :]
        # Map sigmoid output from [0, 1] to [-1, 1] (in-place sigmoid).
        return 2 * out.sigmoid_() - 1

    return forward
def generate_zc_from_sj_gaus(
    spaghetti,
    sj: Union[torch.Tensor, np.ndarray],
    gaus: Union[torch.Tensor, np.ndarray],
):
    """
    Merge per-part intrinsic latents with their Gaussians into zc tokens.

    Input:
        sj: [B,16,512] or [16,512]
        gaus: [B,16,16] or [16,16]
    Output:
        zc: [B,16,512]
    """
    device = spaghetti.device
    sj_t = nputil.np2th(sj)
    gaus_t = nputil.np2th(gaus)
    assert sj_t.dim() == gaus_t.dim()
    if sj_t.dim() == 2:
        sj_t = sj_t.unsqueeze(0)
    zcs, _ = spaghetti.merge_zh(sj_t.to(device), batch_gaus_to_gmms(gaus_t, device))
    return zcs
def generate_zc_from_za(spaghetti, za: Union[torch.Tensor, np.ndarray]):
    """Decode global latent ``za`` into per-part zc tokens via SPAGHETTI."""
    za_t = nputil.np2th(za).to(spaghetti.device)
    sjs, gmms = spaghetti.decomposition_control(za_t)
    zcs, _ = spaghetti.merge_zh(sjs, gmms)
    return zcs
def generate_gaus_from_za(spaghetti, za):
    """Decode global latent ``za`` into flattened per-part Gaussian parameters."""
    _, gmms = spaghetti.decomposition_control(za)
    # SPAGHETTI returns either a nested list or a flat sequence of tensors.
    parts = gmms[0] if isinstance(gmms[0], list) else list(gmms)
    return torch.cat([flatten_gmms_item(p) for p in parts], -1)
def generate_zc_from_single_phase_latent(
    spaghetti, sj_gaus: Union[torch.Tensor, np.ndarray]
):
    """Split a concatenated [sj | gaus] latent and merge it into zc tokens.

    The last dimension is 512 (intrinsic) + 16 (Gaussian) = 528 wide.
    """
    latent = nputil.np2th(sj_gaus).to(spaghetti.device)
    sj_part, gaus_part = latent.split(split_size=[512, 16], dim=-1)
    return generate_zc_from_sj_gaus(spaghetti, sj_part, gaus_part)
def flatten_gmms_item(x):
    """
    Drop the singleton dim and flatten trailing shape dims.

    Input: [B,1,G,*shapes]
    Output: [B,G,-1]
    """
    batch, _, num_parts = x.shape[:3]
    return x.reshape(batch, num_parts, -1)
def batch_gmms_to_gaus(gmms):
    """
    Concatenate GMM component tensors into one flat parameter tensor.

    Input:
        [T(B,1,G,3), T(B,1,G,3,3), T(B,1,G), T(B,1,G,3)]
    Output:
        T(B,G,16)
    """
    # Shallow-copy so the caller's container is never mutated.
    items = gmms[0].copy() if isinstance(gmms[0], list) else list(gmms).copy()
    flat = [flatten_gmms_item(item) for item in items]
    return torch.cat(flat, -1)
def batch_gaus_to_gmms(gaus, device="cpu"):
    """
    Unpack flat Gaussian parameters into SPAGHETTI's GMM component list.

    Input: T(B,G,16)
    Output: [mu: T(B,1,G,3), eivec: T(B,1,G,3,3), pi: T(B,1,G), eival: T(B,1,G,3)]
    """
    g = nputil.np2th(gaus).to(device)
    if g.dim() < 3:
        # Promote a single (G,16) sample to a batch of one.
        g = g.unsqueeze(0)
    B, G = g.shape[0], g.shape[1]
    # Channel layout per part: mu(0:3), eigenvectors(3:12), pi(12), eigenvalues(13:16).
    return [
        g[..., :3].reshape(B, 1, G, 3),
        g[..., 3:12].reshape(B, 1, G, 3, 3),
        g[..., 12].reshape(B, 1, G),
        g[..., 13:16].reshape(B, 1, G, 3),
    ]
def reflect_and_concat_gmms(gmms: torch.Tensor):
    """
    Reflect each GMM across the x=0 plane and append the mirrored parts.

    Input:
        gmms: (B, 8, 16). A batch of GMMs
    Output:
        new_gmms: (B, 16, 16)
    """
    gmms = nputil.np2th(gmms)
    # Clone so the caller's tensor is never mutated.
    gmms = gmms.clone()
    if gmms.dim() == 2:
        gmms = gmms.unsqueeze(0)
    # Reflection matrix: negate the x axis. `.to(gmms)` matches dtype/device.
    affine = torch.eye(3).to(gmms)
    affine[0, 0] = -1.0
    # Channel layout per part: mu(3), eigenvectors(9), pi(1), eigenvalues(3).
    mu, p, phi, eigen = torch.split(gmms, [3, 9, 1, 3], dim=2)
    if affine.ndim == 2:
        # Broadcast the single reflection matrix over the batch.
        affine = affine.unsqueeze(0).expand(mu.size(0), *affine.shape)
    bs, n_part, _ = mu.shape
    p = p.reshape(bs, n_part, 3, 3)
    # Apply the reflection to means and to each eigenvector basis.
    mu_r = torch.einsum("bad, bnd -> bna", affine, mu)
    p_r = torch.einsum("bad, bncd -> bnca", affine, p)
    p_r = p_r.reshape(bs, n_part, -1)
    # Mixture weights and eigenvalues are invariant under reflection.
    gmms_t = torch.cat([mu_r, p_r, phi, eigen], dim=2)
    assert (
        gmms.shape == gmms_t.shape
    ), "Input and reflected gmms shapes must be the same"
    # Original parts first, mirrored parts appended along the part axis.
    return torch.cat([gmms, gmms_t], dim=1)
def clip_eigenvalues(gaus: Union[torch.Tensor, np.ndarray], eps=1e-4):
    """
    Clamp each part's eigenvalues (channels 13:16) to at least ``eps``.

    Input:
        gaus: [B,G,16] or [G,16]
    Output:
        gaus_clipped: [B,G,16] or [G,16] torch.Tensor
    """
    out = nputil.np2th(gaus).clone()
    # Basic slicing yields a view, so the in-place clamp writes into `out`.
    out[..., 13:16].clamp_(min=eps)
    return out
def project_eigenvectors(gaus: Union[torch.Tensor, np.ndarray]):
    """
    Snap each part's 3x3 eigenvector block to the closest orthonormal basis.

    Input:
        gaus: [B,G,16] or [G,16]
    Output:
        gaus_projected: [B,G,16] or [1,G,16]
    """
    out = nputil.np2th(gaus).clone()
    if out.ndim == 2:
        out = out.unsqueeze(0)
    # Channels 3:12 hold the flattened 3x3 eigenvector matrix of each part.
    out[:, :, 3:12] = get_orthonormal_bases_svd(out[:, :, 3:12])
    return out
def get_orthonormal_bases_svd(vs: torch.Tensor):
    """
    Implements the solution for the Orthogonal Procrustes problem,
    which projects a matrix to the closest rotation matrix / reflection matrix using SVD.

    The determinant sign is intentionally NOT forced to +1: inputs may
    contain reflection matrices, and those must be preserved.

    Args:
        vs: Tensor of shape (B, M, 9), each row a flattened 3x3 matrix.
    Returns:
        p: Tensor of shape (B, M, 9) of flattened orthogonal matrices.
    """
    # Compute SVDs of all 3x3 matrices in one batch.
    b, m, _ = vs.shape
    vs_ = vs.reshape(b * m, 3, 3)
    U, _, Vh = torch.linalg.svd(vs_)
    # Closest orthogonal matrix is simply U @ Vh. The original multiplied
    # in an identity `sigma` and computed an unused determinant (leftovers
    # from an optional det=+1 correction); both were dead work and the
    # result is bitwise identical without them.
    p = torch.bmm(U, Vh)
    return p.reshape(b, m, 9)
def save_meshes_and_pointclouds(
    spaghetti,
    mesher,
    zcs,
    save_top_dir,
    mesh_save_dir=None,
    pc_save_dir=None,
    num_shapes=2000,
):
    """Extract meshes and 2048-point clouds for ``zcs[:num_shapes]`` and save them.

    Writes ``{i}.obj`` files to ``mesh_save_dir``, ``{i}.npy`` point clouds
    to ``pc_save_dir``, and one aggregated HDF5 file
    (``o3d_all_pointclouds.hdf5``) to ``save_top_dir``.

    Args:
        spaghetti: SPAGHETTI model used to decode occupancies.
        mesher: marching-cubes mesher.
        zcs: indexable collection of zc latents, length >= num_shapes.
        save_top_dir: top-level output directory.
        mesh_save_dir: optional mesh directory (default: save_top_dir/meshes).
        pc_save_dir: optional pc directory (default: save_top_dir/pointclouds).
        num_shapes: number of latents to process.
    """
    save_top_dir = Path(save_top_dir)
    print(f"Save dir is: {save_top_dir}")
    # Create every output directory up front, including user-supplied ones
    # (the original only created the defaults and never created
    # save_top_dir itself, crashing on absent paths). parents=True makes
    # nested paths work.
    save_top_dir.mkdir(parents=True, exist_ok=True)
    if mesh_save_dir is None:
        mesh_save_dir = save_top_dir / "meshes"
    if pc_save_dir is None:
        pc_save_dir = save_top_dir / "pointclouds"
    mesh_save_dir = Path(mesh_save_dir)
    pc_save_dir = Path(pc_save_dir)
    mesh_save_dir.mkdir(parents=True, exist_ok=True)
    pc_save_dir.mkdir(parents=True, exist_ok=True)

    all_pointclouds = np.zeros((num_shapes, 2048, 3))
    for i in track(range(num_shapes), description="extracting pc and mesh"):
        zc = zcs[i]
        vert_np, face_np, pc_np = get_mesh_and_pc(spaghetti, mesher, zc)
        sysutil.clean_gpu()
        all_pointclouds[i] = pc_np
        meshutil.write_obj_triangle(mesh_save_dir / f"{i}.obj", vert_np, face_np)
        np.save(pc_save_dir / f"{i}.npy", pc_np)
        # Mid-run checkpoint so a crash late in the loop doesn't lose
        # everything; the final write below overwrites the same file.
        if i == 1000:
            with h5py.File(save_top_dir / "o3d_all_pointclouds.hdf5", "w") as f:
                f["data"] = all_pointclouds[:1000]
    with h5py.File(save_top_dir / "o3d_all_pointclouds.hdf5", "w") as f:
        f["data"] = all_pointclouds