| from typing import Union, Tuple, List |
|
|
| import numpy as np |
| import torch |
| from skimage import measure |
|
|
|
|
class MeshExtractResult:
    """Triangle mesh returned by a surface extractor, with normals attached.

    The constructor keeps the vertex and face tensors it is given, casts the
    faces to ``long`` and immediately derives per-face and per-vertex normals.
    ``tsdf_v``, ``tsdf_s`` and ``reg_loss`` are initialised to ``None``.
    """

    def __init__(self, verts, faces, vertex_attrs=None, res=64):
        self.verts = verts
        self.faces = faces.long()  # used as gather/scatter indices below
        self.vertex_attrs = vertex_attrs
        self.face_normal = self.comput_face_normals()
        self.vert_normal = self.comput_v_normals()
        self.res = res
        # Flagged as successful only when neither input is empty.
        self.success = not (verts.shape[0] == 0 or faces.shape[0] == 0)

        # Optional slots; this class never fills them in itself.
        self.tsdf_v = None
        self.tsdf_s = None
        self.reg_loss = None

    def _raw_face_normals(self):
        """Return the three corner-index columns and unnormalized face normals.

        Each normal is the cross product of the two triangle edges that leave
        the first corner, so its length grows with the size of the face.
        """
        corners = tuple(self.faces[..., k].long() for k in range(3))
        p0, p1, p2 = (self.verts[idx, :] for idx in corners)
        return corners, torch.cross(p1 - p0, p2 - p0, dim=-1)

    def comput_face_normals(self):
        """Return unit face normals, repeated once per triangle corner.

        The result has one extra axis of length 3 inserted after the face
        axis, i.e. every corner of a face carries the same normal.
        """
        _, raw = self._raw_face_normals()
        unit = torch.nn.functional.normalize(raw, dim=1)
        return unit.unsqueeze(1).repeat(1, 3, 1)

    def comput_v_normals(self):
        """Return unit vertex normals with the same shape as ``self.verts``.

        Unnormalized face normals are summed onto each of their three corner
        vertices and the sums are normalized afterwards, so larger faces
        contribute more to a vertex than smaller ones.
        """
        corners, raw = self._raw_face_normals()
        accum = torch.zeros_like(self.verts)
        for idx in corners:
            # Broadcast the vertex index over the three normal components.
            accum.scatter_add_(0, idx.unsqueeze(-1).repeat(1, 3), raw)
        return torch.nn.functional.normalize(accum, dim=1)
|
|
|
|
def center_vertices(vertices):
    """Shift ``vertices`` so the midpoint of their bounding box sits at zero.

    The bounding box is taken along the first axis (per-column minimum and
    maximum); the input itself is not modified.
    """
    lower = vertices.min(dim=0).values
    upper = vertices.max(dim=0).values
    midpoint = (lower + upper) * 0.5
    return vertices - midpoint
|
|
|
|
class SurfaceExtractor:
    """Base class that converts a batch of grid logits into meshes.

    Subclasses implement :meth:`run` for a single grid. Calling the extractor
    walks the first axis of the input and wraps every ``(verts, faces)`` pair
    returned by ``run`` in a ``MeshExtractResult``.
    """

    def _compute_box_stat(
        self, bounds: Union[Tuple[float], List[float], float], octree_resolution: int
    ):
        """Derive the grid size and the bounding-box origin and extent.

        Args:
            bounds: Either a single number ``b``, expanded to the symmetric
                box ``[-b, -b, -b, b, b, b]``, or a sequence whose first three
                entries are the box minimum and next three the box maximum.
            octree_resolution: Resolution per axis; converted with ``int()``.

        Returns:
            ``(grid_size, bbox_min, bbox_size)``. ``grid_size`` is a list of
            three ints, each ``octree_resolution + 1``; ``bbox_min`` and
            ``bbox_size`` are NumPy arrays built from ``bounds``.
        """
        # Accept any real scalar, not just ``float``: an int bound such as
        # ``1`` used to skip this branch and crash on the slicing below.
        if isinstance(bounds, (int, float, np.integer, np.floating)):
            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]

        bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
        bbox_size = bbox_max - bbox_min
        points_per_axis = int(octree_resolution) + 1
        grid_size = [points_per_axis] * 3
        return grid_size, bbox_min, bbox_size

    def run(self, *args, **kwargs):
        """Extract ``(verts, faces)`` from one grid; subclasses must override.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        # Raise rather than return the exception class, so a missing override
        # fails with a clear error instead of a confusing unpacking failure.
        raise NotImplementedError

    def __call__(self, grid_logits, **kwargs):
        """Run extraction for every entry along the first axis of the input.

        Extraction is best-effort per item: when anything raises for one grid
        the traceback is printed and ``None`` is stored for that item, so the
        returned list always has one entry per input grid.

        Args:
            grid_logits: Batched input; ``grid_logits[i]`` is handed to
                :meth:`run`.
            **kwargs: Forwarded unchanged to :meth:`run`. The value under
                ``"octree_resolution"`` is also stored on each result as
                ``res``; if it is missing the item fails and becomes ``None``.

        Returns:
            A list containing a ``MeshExtractResult`` or ``None`` per item.
        """
        outputs = []
        for i in range(grid_logits.shape[0]):
            try:
                verts, faces = self.run(grid_logits[i], **kwargs)
                outputs.append(
                    MeshExtractResult(
                        verts=verts.float(),
                        faces=faces,
                        res=kwargs["octree_resolution"],
                    )
                )
            except Exception:
                # Deliberate boundary: report the failure and keep processing
                # the remaining batch items.
                import traceback

                traceback.print_exc()
                outputs.append(None)
        return outputs
|
|
|
|
class MCSurfaceExtractor(SurfaceExtractor):
    """Surface extractor that runs scikit-image's marching cubes on the CPU."""

    def run(self, grid_logit, *, mc_level, bounds, octree_resolution, **kwargs):
        """Extract a triangle mesh from one grid at iso-level ``mc_level``.

        Args:
            grid_logit: Tensor holding the volume; it is cast to float, moved
                to the CPU and converted to NumPy for marching cubes.
            mc_level: Iso-value passed to ``measure.marching_cubes``.
            bounds: Bounding box description accepted by
                ``_compute_box_stat``.
            octree_resolution: Resolution used to derive the grid size.
            **kwargs: Accepted and ignored.

        Returns:
            ``(verts, faces)`` as ``float32`` and ``long`` tensors on the
            device of ``grid_logit``; the vertex order of every face is
            reversed relative to what marching cubes produced.
        """
        target_device = grid_logit.device
        volume = grid_logit.float().cpu().numpy()
        mc_verts, mc_faces, _normals, _values = measure.marching_cubes(
            volume, mc_level, method="lewiner"
        )

        grid_size, bbox_min, bbox_size = self._compute_box_stat(
            bounds, octree_resolution
        )
        # Rescale marching-cubes coordinates into the bounding box.
        # NOTE(review): divides by ``grid_size`` (resolution + 1) - confirm
        # this matches the index range marching cubes emits.
        world_verts = mc_verts / grid_size * bbox_size + bbox_min

        verts = torch.tensor(world_verts, device=target_device, dtype=torch.float32)
        faces = torch.tensor(
            np.ascontiguousarray(mc_faces), device=target_device, dtype=torch.long
        )
        # Swap first and last corner of each triangle, flipping its winding.
        return verts, faces[:, [2, 1, 0]]
|
|
|
|
class DMCSurfaceExtractor(SurfaceExtractor):
    """Surface extractor backed by ``diso.DiffDMC`` (an optional dependency)."""

    def run(self, grid_logit, *, octree_resolution, **kwargs):
        """Extract a triangle mesh from one grid using ``DiffDMC``.

        The ``DiffDMC`` instance is created on first use, moved to the device
        of that first ``grid_logit`` and cached on ``self.dmc``.

        Args:
            grid_logit: Tensor holding the volume. It is negated and divided
                by ``octree_resolution`` before being passed on as the SDF.
            octree_resolution: Resolution used for the SDF scaling and for the
                bounding-box computation.
            **kwargs: Must contain ``"bounds"``, in any form accepted by
                ``_compute_box_stat`` (scalar or 6-element sequence).

        Returns:
            ``(verts, faces)`` as returned by ``DiffDMC``, with ``verts``
            mapped into the bounding box.

        Raises:
            ImportError: If ``diso`` cannot be imported.
            KeyError: If ``"bounds"`` is missing from ``kwargs``.
        """
        device = grid_logit.device
        if not hasattr(self, "dmc"):
            # Imported lazily so this module loads without the optional
            # package. ``Exception`` (not a bare ``except``) keeps the
            # ImportError contract for callers while letting
            # KeyboardInterrupt/SystemExit through; the cause stays chained.
            try:
                from diso import DiffDMC
            except Exception as err:
                raise ImportError(
                    "Please install diso via `pip install diso`, or set mc_algo to 'mc'"
                ) from err
            self.dmc = DiffDMC(dtype=torch.float32).to(device)

        sdf = -grid_logit / octree_resolution
        sdf = sdf.to(torch.float32).contiguous()
        # NOTE(review): with ``normalize=True`` the vertices are presumably in
        # the unit cube - confirm against the diso documentation.
        verts, faces = self.dmc(sdf, deform=None, return_quads=False, normalize=True)

        # Use the computed box so 6-element bounds work as well; for a scalar
        # bound ``b`` this equals the previous ``verts * b * 2 - b``.
        _, bbox_min, bbox_size = self._compute_box_stat(
            kwargs["bounds"], octree_resolution
        )
        scale = torch.as_tensor(bbox_size, dtype=verts.dtype, device=verts.device)
        offset = torch.as_tensor(bbox_min, dtype=verts.dtype, device=verts.device)
        return verts * scale + offset, faces
|
|