| from typing import *
|
| import torch
|
| from ..voxel import Voxel
|
| import cumesh
|
| from flex_gemm.ops.grid_sample import grid_sample_3d
|
|
|
|
|
class Mesh:
    """Mesh stored as a vertex tensor, a face-index tensor and optional vertex attributes.

    ``fill_holes``, ``remove_faces`` and ``simplify`` delegate to ``cumesh`` on
    CUDA and then replace ``vertices`` / ``faces`` in place, moving the results
    back to the device the mesh was on before the call.
    """

    def __init__(self, vertices, faces, vertex_attrs=None):
        """Create a mesh.

        Args:
            vertices: Vertex tensor; stored converted to float32.
            faces: Face-index tensor; stored converted to int32.
            vertex_attrs: Optional per-vertex attribute tensor, stored unchanged.
        """
        self.vertices = vertices.float()
        self.faces = faces.int()
        self.vertex_attrs = vertex_attrs

    @property
    def device(self):
        """Device of the vertex tensor, used as the device of the whole mesh."""
        return self.vertices.device

    def to(self, device, non_blocking=False):
        """Return a new Mesh whose tensors have been moved to ``device``."""
        return Mesh(
            self.vertices.to(device, non_blocking=non_blocking),
            self.faces.to(device, non_blocking=non_blocking),
            self.vertex_attrs.to(device, non_blocking=non_blocking) if self.vertex_attrs is not None else None,
        )

    def cuda(self, non_blocking=False):
        """Return a new Mesh on the default CUDA device."""
        return self.to('cuda', non_blocking=non_blocking)

    def cpu(self, non_blocking=False):
        """Return a new Mesh on the CPU.

        ``non_blocking`` mirrors :meth:`cuda`; its default of ``False`` matches
        the previous parameterless behaviour.
        """
        return self.to('cpu', non_blocking=non_blocking)

    def _to_cumesh(self):
        """Create a ``cumesh.CuMesh`` initialised from this mesh's geometry on CUDA."""
        mesh = cumesh.CuMesh()
        mesh.init(self.vertices.cuda(), self.faces.cuda())
        return mesh

    def _read_back(self, mesh):
        """Replace ``vertices`` / ``faces`` with ``mesh``'s geometry on this mesh's current device."""
        new_vertices, new_faces = mesh.read()
        # ``device`` is derived from ``self.vertices``, so capture it before reassigning.
        device = self.device
        self.vertices = new_vertices.to(device)
        self.faces = new_faces.to(device)
        # NOTE(review): ``vertex_attrs`` is left untouched although the vertex set
        # may have changed; callers relying on it should recompute it.

    def fill_holes(self, max_hole_perimeter=3e-2):
        """Fill boundary holes in place using cumesh.

        Args:
            max_hole_perimeter: Forwarded to ``CuMesh.fill_holes``; presumably the
                largest hole perimeter that gets filled — confirm against cumesh.

        The mesh is left unchanged when it has no boundaries or no boundary loops.
        """
        mesh = self._to_cumesh()
        mesh.get_edges()
        mesh.get_boundary_info()
        if mesh.num_boundaries == 0:
            return  # no open boundary: nothing to fill
        # These topology queries presumably build state that get_boundary_loops()
        # depends on; keep them in this order.
        mesh.get_vertex_edge_adjacency()
        mesh.get_vertex_boundary_adjacency()
        mesh.get_manifold_boundary_adjacency()
        mesh.read_manifold_boundary_adjacency()
        mesh.get_boundary_connected_components()
        mesh.get_boundary_loops()
        if mesh.num_boundary_loops == 0:
            return
        mesh.fill_holes(max_hole_perimeter=max_hole_perimeter)
        self._read_back(mesh)

    def remove_faces(self, face_mask: torch.Tensor):
        """Remove the faces selected by ``face_mask`` in place using cumesh.

        Args:
            face_mask: Mask forwarded to ``CuMesh.remove_faces`` (presumably one
                boolean per face). It is moved to CUDA together with the geometry
                so CPU-resident meshes work as well.
        """
        mesh = self._to_cumesh()
        mesh.remove_faces(face_mask.cuda())
        self._read_back(mesh)

    def simplify(self, target=1000000, verbose: bool = False, options: Optional[dict] = None):
        """Simplify the mesh in place using cumesh.

        Args:
            target: Forwarded to ``CuMesh.simplify``; presumably a target face count.
            verbose: Forwarded to ``CuMesh.simplify``.
            options: Extra options forwarded to ``CuMesh.simplify``; ``None`` means ``{}``.
        """
        # Fresh dict per call instead of a shared mutable default argument.
        options = {} if options is None else options
        mesh = self._to_cumesh()
        mesh.simplify(target, verbose=verbose, options=options)
        self._read_back(mesh)
|
|
|
|
|
class TextureFilterMode:
    """Plain-int constants naming texture filter modes."""

    # Enumerated in declaration order: CLOSEST == 0, LINEAR == 1.
    CLOSEST, LINEAR = range(2)
|
|
|
|
|
class TextureWrapMode:
    """Plain-int constants naming texture wrap modes."""

    # Enumerated in declaration order: 0, 1, 2.
    CLAMP_TO_EDGE, REPEAT, MIRRORED_REPEAT = range(3)
|
|
|
|
|
class AlphaMode:
    """Plain-int constants naming how a material's alpha is interpreted."""

    # Enumerated in declaration order: OPAQUE == 0, MASK == 1, BLEND == 2.
    OPAQUE, MASK, BLEND = range(3)
|
|
|
|
|
class Texture:
    """An image tensor together with its filter and wrap mode constants."""

    def __init__(
        self,
        image: torch.Tensor,
        filter_mode: TextureFilterMode = TextureFilterMode.LINEAR,
        wrap_mode: TextureWrapMode = TextureWrapMode.REPEAT
    ):
        # The modes are the plain int constants defined on the mode classes.
        self.image, self.filter_mode, self.wrap_mode = image, filter_mode, wrap_mode

    def to(self, device, non_blocking=False):
        """Return a new Texture whose image is on ``device``; the modes are carried over."""
        moved_image = self.image.to(device, non_blocking=non_blocking)
        return Texture(moved_image, filter_mode=self.filter_mode, wrap_mode=self.wrap_mode)
|
|
|
|
|
class PbrMaterial:
    """PBR material description: optional textures plus scalar / colour factors.

    All arguments are stored as given, except ``base_color_factor``. It is
    stored as a float32 tensor holding only its first three components, so a
    4-component (e.g. RGBA) factor keeps just the first three values.
    """

    def __init__(
        self,
        base_color_texture: Optional[Texture] = None,
        base_color_factor: Union[torch.Tensor, Sequence[float]] = (1.0, 1.0, 1.0),
        metallic_texture: Optional[Texture] = None,
        metallic_factor: float = 1.0,
        roughness_texture: Optional[Texture] = None,
        roughness_factor: float = 1.0,
        alpha_texture: Optional[Texture] = None,
        alpha_factor: float = 1.0,
        alpha_mode: AlphaMode = AlphaMode.OPAQUE,
        alpha_cutoff: float = 0.5,
    ):
        self.base_color_texture = base_color_texture
        if isinstance(base_color_factor, torch.Tensor):
            # Same result as torch.tensor(t): a detached float32 copy on t's
            # device, but without torch's "copy construct from a tensor" warning
            # (which ``to()`` triggered on every call).
            factor = base_color_factor.detach().clone().to(torch.float32)
        else:
            factor = torch.tensor(base_color_factor, dtype=torch.float32)
        self.base_color_factor = factor[:3]
        self.metallic_texture = metallic_texture
        self.metallic_factor = metallic_factor
        self.roughness_texture = roughness_texture
        self.roughness_factor = roughness_factor
        self.alpha_texture = alpha_texture
        self.alpha_factor = alpha_factor
        self.alpha_mode = alpha_mode
        self.alpha_cutoff = alpha_cutoff

    def to(self, device, non_blocking=False):
        """Return a new PbrMaterial with every texture and the base colour factor on ``device``."""
        return PbrMaterial(
            base_color_texture=self.base_color_texture.to(device, non_blocking=non_blocking) if self.base_color_texture is not None else None,
            base_color_factor=self.base_color_factor.to(device, non_blocking=non_blocking),
            metallic_texture=self.metallic_texture.to(device, non_blocking=non_blocking) if self.metallic_texture is not None else None,
            metallic_factor=self.metallic_factor,
            roughness_texture=self.roughness_texture.to(device, non_blocking=non_blocking) if self.roughness_texture is not None else None,
            roughness_factor=self.roughness_factor,
            alpha_texture=self.alpha_texture.to(device, non_blocking=non_blocking) if self.alpha_texture is not None else None,
            alpha_factor=self.alpha_factor,
            alpha_mode=self.alpha_mode,
            alpha_cutoff=self.alpha_cutoff,
        )
|
|
|
|
|
class MeshWithPbrMaterial(Mesh):
    """Mesh carrying UV coordinates, material ids and a list of PBR materials.

    NOTE(review): the inherited in-place geometry operations (``fill_holes``,
    ``remove_faces``, ``simplify``) only rewrite ``vertices`` / ``faces``;
    ``uv_coords`` and ``material_ids`` are not updated by them.
    """

    def __init__(self,
                 vertices,
                 faces,
                 material_ids,
                 uv_coords,
                 materials: List[PbrMaterial],
                 ):
        """Create a textured mesh.

        Args:
            vertices: Vertex tensor; stored converted to float32.
            faces: Face-index tensor; stored converted to int32.
            material_ids: Tensor of indices into ``materials`` (presumably one per face).
            uv_coords: Tensor of texture coordinates.
            materials: The PBR materials referenced by ``material_ids``.
        """
        # Reuse Mesh's dtype normalisation and define ``vertex_attrs`` (None), so
        # code written against the Mesh interface does not hit AttributeError.
        super().__init__(vertices, faces)
        self.material_ids = material_ids
        self.uv_coords = uv_coords
        self.materials = materials
        # Channel slices: base_color -> 0:3, metallic -> 3, roughness -> 4,
        # alpha -> 5; presumably offsets into a packed 6-channel attribute tensor.
        self.layout = {
            'base_color': slice(0, 3),
            'metallic': slice(3, 4),
            'roughness': slice(4, 5),
            'alpha': slice(5, 6),
        }

    def to(self, device, non_blocking=False):
        """Return a new MeshWithPbrMaterial with all tensors and materials moved to ``device``."""
        return MeshWithPbrMaterial(
            self.vertices.to(device, non_blocking=non_blocking),
            self.faces.to(device, non_blocking=non_blocking),
            self.material_ids.to(device, non_blocking=non_blocking),
            self.uv_coords.to(device, non_blocking=non_blocking),
            [material.to(device, non_blocking=non_blocking) for material in self.materials],
        )
|
|
|
|
|
class MeshWithVoxel(Mesh, Voxel):
    """Mesh paired with a sparse voxel grid of attributes that can be sampled at points."""

    def __init__(self,
                 vertices: torch.Tensor,
                 faces: torch.Tensor,
                 origin: list,
                 voxel_size: float,
                 coords: torch.Tensor,
                 attrs: torch.Tensor,
                 voxel_shape: torch.Size,
                 layout: Optional[Dict] = None,
                 ):
        """Create a mesh with an attached voxel attribute grid.

        Args:
            vertices: Vertex tensor; stored converted to float32.
            faces: Face-index tensor; stored converted to int32.
            origin: World position of the grid origin; stored as a float32 tensor
                on the vertices' device.
            voxel_size: Edge length of one voxel, in the units of ``origin``.
            coords: Integer voxel coordinates of the occupied voxels (presumably (N, 3)).
            attrs: Attribute values for the occupied voxels (presumably (N, C)).
            voxel_shape: Shape descriptor forwarded to ``grid_sample_3d``.
            layout: Optional mapping of attribute names to channel slices. ``None``
                gives a fresh empty dict, so instances never share one default dict.
        """
        # Sets vertices/faces (with dtype normalisation) and vertex_attrs=None.
        # Voxel.__init__ is intentionally not invoked; the voxel fields are
        # assigned directly below.
        Mesh.__init__(self, vertices, faces)
        self.origin = torch.tensor(origin, dtype=torch.float32, device=self.device)
        self.voxel_size = voxel_size
        self.coords = coords
        self.attrs = attrs
        self.voxel_shape = voxel_shape
        self.layout = {} if layout is None else layout

    def to(self, device, non_blocking=False):
        """Return a new MeshWithVoxel with all tensors moved to ``device``."""
        return MeshWithVoxel(
            self.vertices.to(device, non_blocking=non_blocking),
            self.faces.to(device, non_blocking=non_blocking),
            self.origin.tolist(),
            self.voxel_size,
            self.coords.to(device, non_blocking=non_blocking),
            self.attrs.to(device, non_blocking=non_blocking),
            self.voxel_shape,
            self.layout,
        )

    def query_attrs(self, xyz):
        """Sample the voxel attributes at arbitrary positions with trilinear interpolation.

        Args:
            xyz: Query positions in the same frame as ``origin`` (presumably
                shaped (..., 3)); flattened to ``(1, -1, 3)`` before sampling.

        Returns:
            ``grid_sample_3d(...)[0]``, presumably one row of attributes per query
            point — TODO confirm the exact shape against flex_gemm.
        """
        # Convert world positions to continuous voxel-index units.
        grid = ((xyz - self.origin) / self.voxel_size).reshape(1, -1, 3)
        # Prepend a zero batch index to each voxel coordinate (single-batch grid).
        batched_coords = torch.cat([torch.zeros_like(self.coords[..., :1]), self.coords], dim=-1)
        vertex_attrs = grid_sample_3d(
            self.attrs,
            batched_coords,
            self.voxel_shape,
            grid,
            mode='trilinear'
        )[0]
        return vertex_attrs

    def query_vertex_attrs(self):
        """Sample the voxel attributes at every mesh vertex."""
        return self.query_attrs(self.vertices)
|
|
|