| |
|
|
| import pickle |
| from functools import lru_cache |
| from typing import Dict, Optional, Tuple |
| import torch |
|
|
| from detectron2.utils.file_io import PathManager |
|
|
| from densepose.data.meshes.catalog import MeshCatalog, MeshInfo |
|
|
|
|
def _maybe_copy_to_device(
    attribute: Optional[torch.Tensor], device: torch.device
) -> Optional[torch.Tensor]:
    """Return ``attribute`` moved to ``device``; pass ``None`` through unchanged."""
    return None if attribute is None else attribute.to(device)
|
|
|
|
class Mesh:
    """
    Triangular mesh whose attributes are either given directly or loaded lazily.

    Attributes passed to the constructor are stored as-is. Any attribute left as
    ``None`` is loaded on first access from the files referenced by
    ``mesh_info`` (when provided) and then kept on the instance.
    """

    def __init__(
        self,
        vertices: Optional[torch.Tensor] = None,
        faces: Optional[torch.Tensor] = None,
        geodists: Optional[torch.Tensor] = None,
        symmetry: Optional[Dict[str, torch.Tensor]] = None,
        texcoords: Optional[torch.Tensor] = None,
        mesh_info: Optional[MeshInfo] = None,
        device: Optional[torch.device] = None,
    ):
        """
        Args:
            vertices (tensor [N, 3] of float32): vertex coordinates in 3D
            faces (tensor [M, 3] of long): triangular face represented as 3
                vertex indices
            geodists (tensor [N, N] of float32): geodesic distances from
                vertex `i` to vertex `j` (optional, default: None)
            symmetry (dict: str -> tensor): various mesh symmetry data:
                - "vertex_transforms": vertex mapping under horizontal flip,
                  tensor of size [N] of type long; vertex `i` is mapped to
                  vertex `tensor[i]` (optional, default: None)
            texcoords (tensor [N, 2] of float32): texture coordinates, i.e. global
                and normalized mesh UVs (optional, default: None)
            mesh_info (MeshInfo type): necessary to load the attributes on-the-go,
                can be used instead of passing all the variables one by one
            device (torch.device): device of the Mesh. If not provided, will use
                the device of the first provided tensor, falling back to CPU
        """
        self._vertices = vertices
        self._faces = faces
        self._geodists = geodists
        self._symmetry = symmetry
        self._texcoords = texcoords
        self.mesh_info = mesh_info
        self.device = device

        # Vertices must either be given or be loadable through mesh_info.
        assert self._vertices is not None or self.mesh_info is not None

        all_fields = [self._vertices, self._faces, self._geodists, self._texcoords]

        # Infer the device from the first provided dense tensor, then from the
        # symmetry data, and fall back to CPU if nothing was provided.
        if self.device is None:
            for field in all_fields:
                if field is not None:
                    self.device = field.device
                    break
        if self.device is None and symmetry:
            self.device = next(iter(symmetry.values())).device
        self.device = torch.device("cpu") if self.device is None else self.device

        # All provided tensors must live on the mesh device.
        assert all(var.device == self.device for var in all_fields if var is not None)
        if symmetry:
            assert all(value.device == self.device for value in symmetry.values())
        # Compare against None explicitly: truth-testing a tensor with more than
        # one element raises "Boolean value of Tensor ... is ambiguous".
        if texcoords is not None and vertices is not None:
            assert len(vertices) == len(texcoords)

    def to(self, device: torch.device):
        """
        Return a new Mesh with every already-present attribute copied to ``device``.

        Attributes that have not been loaded yet stay ``None`` in the copy. The
        copy shares ``mesh_info``, so those attributes can still be loaded
        lazily, and they are then placed on ``device``.
        """
        device_symmetry = self._symmetry
        if device_symmetry:
            device_symmetry = {
                key: value.to(device) for key, value in device_symmetry.items()
            }
        return Mesh(
            _maybe_copy_to_device(self._vertices, device),
            _maybe_copy_to_device(self._faces, device),
            _maybe_copy_to_device(self._geodists, device),
            device_symmetry,
            _maybe_copy_to_device(self._texcoords, device),
            self.mesh_info,
            device,
        )

    @property
    def vertices(self):
        """Vertex coordinates; loaded from ``mesh_info.data`` on first access if absent."""
        if self._vertices is None and self.mesh_info is not None:
            self._vertices = load_mesh_data(self.mesh_info.data, "vertices", self.device)
        return self._vertices

    @property
    def faces(self):
        """Face vertex indices; loaded from ``mesh_info.data`` on first access if absent."""
        # NOTE(review): load_mesh_data converts to torch.float by default, so
        # lazily-loaded faces are float, while the constructor docstring
        # describes long indices. Confirm what downstream users expect before
        # changing the dtype.
        if self._faces is None and self.mesh_info is not None:
            self._faces = load_mesh_data(self.mesh_info.data, "faces", self.device)
        return self._faces

    @property
    def geodists(self):
        """Geodesic distances; loaded from ``mesh_info.geodists`` on first access if absent."""
        if self._geodists is None and self.mesh_info is not None:
            self._geodists = load_mesh_auxiliary_data(self.mesh_info.geodists, self.device)
        return self._geodists

    @property
    def symmetry(self):
        """Symmetry data dict; loaded from ``mesh_info.symmetry`` on first access if absent."""
        if self._symmetry is None and self.mesh_info is not None:
            self._symmetry = load_mesh_symmetry(self.mesh_info.symmetry, self.device)
        return self._symmetry

    @property
    def texcoords(self):
        """Texture coordinates; loaded from ``mesh_info.texcoords`` on first access if absent."""
        if self._texcoords is None and self.mesh_info is not None:
            self._texcoords = load_mesh_auxiliary_data(self.mesh_info.texcoords, self.device)
        return self._texcoords

    def get_geodists(self):
        """
        Return geodesic distances. If they are neither given nor loadable,
        fall back to computing them.

        ``geodists`` is a read-only property, so the computed value is cached in
        ``_geodists``. Assigning through the property would raise AttributeError.
        """
        if self.geodists is None:
            self._geodists = self._compute_geodists()
        return self._geodists

    def _compute_geodists(self):
        """Placeholder: geodesic distance computation is not implemented; returns None."""
        geodists = None
        return geodists
|
|
|
|
def load_mesh_data(
    mesh_fpath: str,
    field: str,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.float,
) -> torch.Tensor:
    """
    Load a single field of a pickled mesh file as a tensor.

    Args:
        mesh_fpath (str): path, opened through PathManager, to a pickle file
            whose deserialized object is indexable by ``field`` (presumably a
            dict; verify against the mesh files in use)
        field (str): key of the entry to load, e.g. "vertices" or "faces"
        device (torch.device): device to place the tensor on (optional)
        dtype (torch.dtype): dtype of the returned tensor. Defaults to
            ``torch.float``, the only dtype used before this parameter existed.
    Returns:
        tensor holding the contents of ``field``
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load mesh
    # files from trusted locations.
    with PathManager.open(mesh_fpath, "rb") as hFile:
        mesh_data = pickle.load(hFile)
    return torch.as_tensor(mesh_data[field], dtype=dtype).to(device)
|
|
|
|
def load_mesh_auxiliary_data(
    fpath: str, device: Optional[torch.device] = None
) -> Optional[torch.Tensor]:
    """
    Load a pickled array-like object and return it as a float32 tensor.

    In this module it is used for geodesic distances and texture coordinates.
    The path is first resolved to a local copy via
    ``PathManager.get_local_path`` and then opened.

    Args:
        fpath (str): path to the pickle file
        device (torch.device): device to place the tensor on (optional)
    """
    local_path = PathManager.get_local_path(fpath)
    with PathManager.open(local_path, "rb") as handle:
        payload = pickle.load(handle)
    as_float = torch.as_tensor(payload, dtype=torch.float)
    return as_float.to(device)
|
|
|
|
@lru_cache()
def load_mesh_symmetry(
    symmetry_fpath: str, device: Optional[torch.device] = None
) -> Optional[Dict[str, torch.Tensor]]:
    """
    Load mesh symmetry data from a pickle file.

    Only the "vertex_transforms" entry is read. It is converted to a long tensor
    on ``device``. Results are memoized on the call arguments by ``lru_cache``,
    so repeated calls return the same dict object; callers should not mutate it.

    Args:
        symmetry_fpath (str): path to the pickled symmetry data
        device (torch.device): device to place the tensors on (optional)
    """
    with PathManager.open(symmetry_fpath, "rb") as handle:
        loaded = pickle.load(handle)
    vertex_transforms = torch.as_tensor(loaded["vertex_transforms"], dtype=torch.long)
    return {"vertex_transforms": vertex_transforms.to(device)}
|
|
|
|
@lru_cache()
def create_mesh(mesh_name: str, device: Optional[torch.device] = None) -> Mesh:
    """
    Build a lazily-loading Mesh for a catalog entry.

    Memoized on the call arguments by ``lru_cache``, so repeated calls return the
    same Mesh instance. Attributes it loads lazily are therefore shared by all
    callers.

    Args:
        mesh_name (str): key into ``MeshCatalog``
        device (torch.device): device for the mesh attributes (optional)
    """
    catalog_entry = MeshCatalog[mesh_name]
    return Mesh(mesh_info=catalog_entry, device=device)
|
|