# NOTE: the next three lines are residue from the hosting page, kept as comments.
# Stylique's picture
# Upload 65 files
# f498ac0 verified
import os
import numpy
import torch
import igl
from . import MeshProcessor
# Alias of the WKS dimension defined by MeshProcessor (not referenced in the
# visible part of this file). WKS presumably stands for Wave Kernel
# Signature -- TODO confirm.
WKS_DIM = MeshProcessor.WKS_DIM
# Multiplier applied to the centroid WKS features before SourceMesh stacks
# them next to the centroid points/normals.
WKS_FACTOR = 1000
import numpy as np
import sys
import random
import time
class SourceMesh:
    '''
    Data structure for the source mesh to be mapped.

    The object is configured by ``__init__`` and populated by ``load``, which
    builds a ``MeshProcessor`` and then caches: the source vertices, the
    centroid points/normals (optionally with WKS features appended), the
    solvers exposed by the mesh processor's differential operators, and any
    extra per-key data requested through ``extra_source_fields``.
    '''

    def __init__(self, source_ind, source_dir, extra_source_fields,
                 random_scale, ttype, use_wks=False, random_centering=False,
                 cpuonly=False):
        '''
        Store the configuration; no mesh data is read until ``load`` is called.

        :param source_ind: identifier of the source, stored unchanged
        :param source_dir: path handed to the MeshProcessor factories by ``load``
        :param extra_source_fields: keys fetched with ``mesh_processor.get_data`` on load
        :param random_scale: scale factor; any value other than 1 aborts loading
        :param ttype: tensor type passed to ``Tensor.type`` for all loaded arrays
        :param use_wks: append WKS features (scaled by WKS_FACTOR) to the centroid features
        :param random_centering: jitter the centering point on load (augmentation)
        :param cpuonly: forwarded to the MeshProcessor factories
        '''
        self.__use_wks = use_wks
        self.source_ind = source_ind
        self.source_dir = source_dir
        # Initialised here so that get_vertices() returns None before load(),
        # like the other getters, instead of raising AttributeError.
        self.source_vertices = None
        self.centroids_and_normals = None
        self.center_source = True
        self.poisson = None
        self.splu = None
        self.__source_global_translation_to_original = 0
        self.__extra_keys = extra_source_fields
        self.__loaded_data = {}
        self.__ttype = ttype
        self.__random_scale = random_scale
        self.random_centering = random_centering
        self.source_mesh_centroid = None
        self.mesh_processor = None
        self.cpuonly = cpuonly

    def get_vertices(self):
        '''Return the cached source vertices (None until ``load`` has run).'''
        return self.source_vertices

    def get_global_translation_to_original(self):
        '''Return the translation subtracted while centering (0 until ``load`` has run).'''
        return self.__source_global_translation_to_original

    def vertices_from_jacobians(self, d):
        '''Return ``self.poisson.solve_poisson(d)``.'''
        return self.poisson.solve_poisson(d)

    def jacobians_from_vertices(self, v):
        '''Return ``self.poisson.jacobians_from_vertices(v)``.'''
        return self.poisson.jacobians_from_vertices(v)

    def restrict_jacobians(self, J):
        '''Return ``self.poisson.restrict_jacobians(J)``.'''
        return self.poisson.restrict_jacobians(J)

    def get_loaded_data(self, key: str):
        '''Return the extra data stored under ``key``, or None when the key is absent.'''
        return self.__loaded_data.get(key)

    def get_source_triangles(self):
        '''Return the faces reported by the mesh processor.'''
        return self.mesh_processor.get_faces()

    def to(self, device):
        '''
        Move the solvers, the centroid features and all extra data to ``device``.

        ``source_vertices`` is not moved by this method. Must be called after
        ``load``; before that the solver attributes are still None.

        :return: self, to allow chaining
        '''
        self.poisson = self.poisson.to(device)
        self.splu = self.splu.to(device)
        self.centroids_and_normals = self.centroids_and_normals.to(device)
        # Only values are replaced, so mutating while iterating items() is safe.
        for key, value in self.__loaded_data.items():
            self.__loaded_data[key] = value.to(device)
        return self

    def __init_from_mesh_data(self):
        '''
        Populate the cached tensors and solvers from ``self.mesh_processor``.

        Exits the process (SystemExit, non-zero status) when ``random_scale``
        is not 1, because scaling the differential operators and WKS
        accordingly is not implemented.
        '''
        assert self.mesh_processor is not None
        self.mesh_processor.prepare_differential_operators_for_use(self.__ttype)  # call 1
        # NOTE(review): torch.from_numpy shares memory with the source array,
        # so the in-place updates below may also touch the array held by the
        # mesh processor when no dtype conversion copy happens -- confirm this
        # is intended.
        self.source_vertices = torch.from_numpy(
            self.mesh_processor.get_vertices()).type(self.__ttype)
        if self.__random_scale != 1:
            # Passing the message to sys.exit writes it to stderr and yields
            # exit status 1; a bare sys.exit() would report success.
            sys.exit("Diff ops and WKS need to be multiplied accordingly. Not implemented for now")
        # Effectively a multiplication by 1 while other scales are unsupported.
        self.source_vertices *= self.__random_scale

        vertices_np = self.source_vertices.numpy()
        bb = igl.bounding_box(vertices_np)[0]
        diag = igl.bounding_box_diagonal(vertices_np)
        # The centering point is derived from the first and last bounding-box
        # corners rather than from the mean of the vertices.
        self.source_mesh_centroid = (bb[0] + bb[-1]) / 2
        if self.random_centering:
            # Centering augmentation: shift each axis by a uniform random
            # offset of at most 20% of the bounding-box diagonal.
            jitter = [(2 * random.random() - 1) * diag * 0.2 for _ in range(3)]
            self.source_mesh_centroid = self.source_mesh_centroid + jitter

        # Load input to NJF MLP
        centroids = self.mesh_processor.get_centroids()
        centroid_points_and_normals = centroids.points_and_normals
        if self.__use_wks:
            wks = WKS_FACTOR * centroids.wks
            centroid_points_and_normals = numpy.hstack((centroid_points_and_normals, wks))
        self.centroids_and_normals = torch.from_numpy(
            centroid_points_and_normals).type(self.__ttype)

        if self.center_source:
            c = self.source_mesh_centroid
            # Only the first three columns are translated; the remaining
            # columns are left untouched.
            self.centroids_and_normals[:, 0:3] -= c
            self.source_vertices -= c
            self.__source_global_translation_to_original = c

        self.poisson = self.mesh_processor.diff_ops.poisson_solver
        self.splu = self.mesh_processor.diff_ops.MyCuSPLU_solver

        # Essentially here we load pointnet data and apply the same preprocessing
        for key in self.__extra_keys:
            data = torch.from_numpy(self.mesh_processor.get_data(key))
            if key == 'samples':
                if self.center_source:
                    data -= self.get_mesh_centroid()
                data *= self.__random_scale
            # Stored with an extra leading dimension of size 1.
            self.__loaded_data[key] = data.unsqueeze(0).type(self.__ttype)

    def load(self, source_v=None, source_f=None):
        '''
        Build the mesh processor and populate this object from it.

        When both ``source_v`` and ``source_f`` are given the processor is
        built from those arrays; otherwise it is built from ``source_dir``,
        treated as a directory if it is one and as a file if not.
        '''
        factory = MeshProcessor.MeshProcessor
        options = dict(cpuonly=self.cpuonly,
                       load_wks_samples=self.__use_wks,
                       load_wks_centroids=self.__use_wks)
        if source_v is not None and source_f is not None:
            self.mesh_processor = factory.meshprocessor_from_array(
                source_v, source_f, self.source_dir, self.__ttype, **options)
        elif os.path.isdir(self.source_dir):
            self.mesh_processor = factory.meshprocessor_from_directory(
                self.source_dir, self.__ttype, **options)
        else:
            self.mesh_processor = factory.meshprocessor_from_file(
                self.source_dir, self.__ttype, **options)
        self.__init_from_mesh_data()

    def get_point_dim(self):
        '''Return the number of columns of the centroid feature tensor.'''
        return self.centroids_and_normals.shape[1]

    def get_centroids_and_normals(self):
        '''Return the centroid feature tensor (None until ``load`` has run).'''
        return self.centroids_and_normals

    def get_mesh_centroid(self):
        '''Return the point used for centering (None until ``load`` has run).'''
        return self.source_mesh_centroid

    def pin_memory(self):
        '''Currently a no-op (memory pinning is disabled); returns self.'''
        return self