Spaces:
Paused
Paused
| import os | |
| import numpy | |
| import torch | |
| import igl | |
| from . import MeshProcessor | |
# Multiplier applied to the raw WKS values (presumably wave kernel signature
# descriptors) before SourceMesh stacks them onto the centroid features.
WKS_FACTOR = 1000
# WKS feature width as defined by MeshProcessor, re-exported from this module.
WKS_DIM = MeshProcessor.WKS_DIM
| import numpy as np | |
| import sys | |
| import random | |
| import time | |
class SourceMesh:
    """Data structure for the source mesh to be mapped.

    Wraps a ``MeshProcessor`` instance and exposes what the mapping code
    consumes:

    - per-triangle centroid features: points and normals, optionally followed
      by ``WKS_FACTOR``-scaled WKS columns;
    - the vertex positions;
    - the Poisson / SPLU solvers taken from the processor's differential
      operators;
    - any extra per-mesh fields named in ``extra_source_fields``.

    Construction only stores configuration. Call :meth:`load` to read the
    mesh and populate the derived data.
    """

    def __init__(self, source_ind, source_dir, extra_source_fields,
                 random_scale, ttype, use_wks=False, random_centering=False,
                 cpuonly=False):
        """Store configuration; no mesh data is read here.

        Args:
            source_ind: identifier of this source mesh. It is stored as
                ``self.source_ind`` and not used inside this class.
            source_dir: path of a mesh directory or a mesh file, used by
                :meth:`load`.
            extra_source_fields: keys of extra data fetched from the mesh
                processor while loading. The ``'samples'`` entry is centered
                and scaled like the vertices.
            random_scale: mesh scale factor. Only ``1`` is supported;
                :meth:`load` raises ``NotImplementedError`` otherwise.
            ttype: type passed to ``Tensor.type`` when converting loaded
                arrays (e.g. a torch dtype).
            use_wks: load WKS data and append it to the centroid features.
            random_centering: jitter the centering point as an augmentation.
            cpuonly: forwarded to the ``MeshProcessor`` constructors.
        """
        self.__use_wks = use_wks
        self.source_ind = source_ind
        self.source_dir = source_dir
        self.centroids_and_normals = None
        # When True, loading subtracts source_mesh_centroid from positions.
        self.center_source = True
        self.poisson = None
        self.splu = None
        self.__source_global_translation_to_original = 0
        self.__extra_keys = extra_source_fields
        self.__loaded_data = {}
        self.__ttype = ttype
        self.__random_scale = random_scale
        self.random_centering = random_centering
        self.source_mesh_centroid = None
        self.mesh_processor = None
        self.cpuonly = cpuonly
        # Populated by load(). Initialized here so get_vertices() before
        # load() returns None instead of raising AttributeError.
        self.source_vertices = None

    def get_vertices(self):
        """Return the (possibly centered) vertex tensor, or None before load()."""
        return self.source_vertices

    def get_global_translation_to_original(self):
        """Return the translation subtracted from positions when centering.

        The value is ``0`` if the mesh was not centered or not loaded yet.
        """
        return self.__source_global_translation_to_original

    def vertices_from_jacobians(self, d):
        """Delegate to ``self.poisson.solve_poisson`` to recover vertices from ``d``."""
        return self.poisson.solve_poisson(d)

    def jacobians_from_vertices(self, v):
        """Delegate to ``self.poisson.jacobians_from_vertices``."""
        return self.poisson.jacobians_from_vertices(v)

    def restrict_jacobians(self, J):
        """Delegate to ``self.poisson.restrict_jacobians``."""
        return self.poisson.restrict_jacobians(J)

    def get_loaded_data(self, key: str):
        """Return the extra field loaded under ``key``, or None if absent."""
        return self.__loaded_data.get(key)

    def get_source_triangles(self):
        """Return the mesh faces as reported by the mesh processor."""
        return self.mesh_processor.get_faces()

    def to(self, device):
        """Move solvers, centroid features and extra data to ``device``.

        Must be called after :meth:`load`. Returns ``self`` for chaining.
        """
        self.poisson = self.poisson.to(device)
        self.splu = self.splu.to(device)
        self.centroids_and_normals = self.centroids_and_normals.to(device)
        for key in self.__loaded_data:
            self.__loaded_data[key] = self.__loaded_data[key].to(device)
        return self

    def __init_from_mesh_data(self):
        """Populate vertices, features, solvers and extra data from the processor.

        Raises:
            NotImplementedError: if ``random_scale`` is not 1. Scaling would
                also require rescaling the differential operators and WKS
                features, which is not implemented.
        """
        assert self.mesh_processor is not None
        if self.__random_scale != 1:
            # This used to print and call sys.exit(), which silently ended the
            # whole process with exit status 0. Raise instead, before doing
            # any expensive work.
            raise NotImplementedError(
                "Diff ops and WKS need to be multiplied accordingly. "
                "Not implemented for now")

        self.mesh_processor.prepare_differential_operators_for_use(self.__ttype)
        # NOTE(review): torch.from_numpy shares memory with the numpy array,
        # and .type() returns the same tensor when the dtype already matches.
        # The in-place ops below may therefore also modify the mesh
        # processor's arrays. Confirm this is intended.
        self.source_vertices = torch.from_numpy(
            self.mesh_processor.get_vertices()).type(self.__ttype)
        # Scale is guaranteed to be 1 here; kept to mark where scaling belongs.
        self.source_vertices *= self.__random_scale

        vertices_np = self.source_vertices.numpy()
        corners = igl.bounding_box(vertices_np)[0]
        # corners[0] and corners[-1] are opposite corners of the axis-aligned
        # box, so their midpoint is the box center.
        self.source_mesh_centroid = (corners[0] + corners[-1]) / 2
        if self.random_centering:
            # Augmentation: shift each coordinate of the centering point
            # uniformly within +/-20% of the bounding-box diagonal.
            diag = igl.bounding_box_diagonal(vertices_np)
            jitter = [(2 * random.random() - 1) * diag * 0.2 for _ in range(3)]
            self.source_mesh_centroid = self.source_mesh_centroid + jitter

        # Network input: centroid points and normals, optionally with WKS.
        centroids = self.mesh_processor.get_centroids()
        features = centroids.points_and_normals
        if self.__use_wks:
            features = numpy.hstack((features, WKS_FACTOR * centroids.wks))
        self.centroids_and_normals = torch.from_numpy(features).type(self.__ttype)

        if self.center_source:
            c = self.source_mesh_centroid
            # Only columns 0:3 are translated; they presumably hold the
            # centroid positions, while normals/WKS are translation invariant.
            self.centroids_and_normals[:, 0:3] -= c
            self.source_vertices -= c
            self.__source_global_translation_to_original = c

        self.poisson = self.mesh_processor.diff_ops.poisson_solver
        self.splu = self.mesh_processor.diff_ops.MyCuSPLU_solver

        # Extra fields get the same preprocessing as the vertices.
        for key in self.__extra_keys:
            data = torch.from_numpy(self.mesh_processor.get_data(key))
            if key == 'samples':
                if self.center_source:
                    data -= self.get_mesh_centroid()
                data *= self.__random_scale
            self.__loaded_data[key] = data.unsqueeze(0).type(self.__ttype)

    def load(self, source_v=None, source_f=None):
        """Create the mesh processor and populate all derived data.

        If both ``source_v`` and ``source_f`` are given, the mesh is built
        from those arrays. Otherwise ``source_dir`` is loaded as a directory
        when it is one, and as a mesh file path otherwise.

        Raises:
            NotImplementedError: if ``random_scale`` is not 1.
        """
        factory = MeshProcessor.MeshProcessor
        options = dict(cpuonly=self.cpuonly,
                       load_wks_samples=self.__use_wks,
                       load_wks_centroids=self.__use_wks)
        if source_v is not None and source_f is not None:
            self.mesh_processor = factory.meshprocessor_from_array(
                source_v, source_f, self.source_dir, self.__ttype, **options)
        elif os.path.isdir(self.source_dir):
            self.mesh_processor = factory.meshprocessor_from_directory(
                self.source_dir, self.__ttype, **options)
        else:
            self.mesh_processor = factory.meshprocessor_from_file(
                self.source_dir, self.__ttype, **options)
        self.__init_from_mesh_data()

    def get_point_dim(self):
        """Return the number of feature columns per centroid."""
        return self.centroids_and_normals.shape[1]

    def get_centroids_and_normals(self):
        """Return the centroid feature tensor built by load()."""
        return self.centroids_and_normals

    def get_mesh_centroid(self):
        """Return the point subtracted when centering (None before load())."""
        return self.source_mesh_centroid

    def pin_memory(self):
        """Return ``self`` without doing anything.

        Pinning of the solver and tensors is currently disabled.
        """
        return self