Spaces:
Paused
Paused
File size: 6,356 Bytes
f498ac0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import os
import numpy
import torch
import igl
from . import MeshProcessor
# Dimension of the WKS features (presumably Wave Kernel Signature), re-exported
# from MeshProcessor; not referenced in the visible part of this file.
WKS_DIM = MeshProcessor.WKS_DIM
# Multiplier applied to the centroid WKS features (centroids.wks) before they
# are horizontally stacked onto the per-centroid points-and-normals features
# in SourceMesh when use_wks is enabled.
WKS_FACTOR = 1000
import numpy as np
import sys
import random
import time
class SourceMesh:
    '''
    Datastructure for the source mesh to be mapped.

    Wraps a MeshProcessor and exposes what the mapping code consumes:
    (optionally centered) vertices, per-face centroid features (points and
    normals, optionally with scaled WKS features appended), the Poisson solver
    that turns Jacobians back into vertices, and any extra per-mesh data
    requested through ``extra_source_fields``.
    '''

    def __init__(self, source_ind, source_dir, extra_source_fields,
                 random_scale, ttype, use_wks=False, random_centering=False,
                 cpuonly=False):
        """
        Args:
            source_ind: stored as ``self.source_ind``; not used by this class.
            source_dir: mesh directory or mesh file path; see ``load``.
            extra_source_fields: iterable of keys fetched via
                ``mesh_processor.get_data`` during ``load``.
            random_scale: vertex scale factor. Only 1 is currently supported;
                any other value makes ``load`` raise NotImplementedError.
            ttype: torch dtype that loaded tensors are cast to.
            use_wks: if True, append ``WKS_FACTOR``-scaled WKS features to the
                centroid features and ask the MeshProcessor to load them.
            random_centering: if True, randomly jitter the centering point
                (augmentation).
            cpuonly: forwarded to the MeshProcessor factory methods.
        """
        self.__use_wks = use_wks
        self.source_ind = source_ind
        self.source_dir = source_dir
        self.centroids_and_normals = None
        self.center_source = True
        self.poisson = None
        self.splu = None
        # Set by load(); None until then so get_vertices() does not raise
        # AttributeError on a not-yet-loaded mesh.
        self.source_vertices = None
        self.__source_global_translation_to_original = 0
        self.__extra_keys = extra_source_fields
        self.__loaded_data = {}
        self.__ttype = ttype
        self.__random_scale = random_scale
        self.random_centering = random_centering
        self.source_mesh_centroid = None
        self.mesh_processor = None
        self.cpuonly = cpuonly

    def get_vertices(self):
        """Return the (possibly centered) source vertices, or None before ``load``."""
        return self.source_vertices

    def get_global_translation_to_original(self):
        """Return the translation that was subtracted when centering (0 if none)."""
        return self.__source_global_translation_to_original

    def vertices_from_jacobians(self, d):
        """Solve the Poisson system for vertices given Jacobians ``d``."""
        # NOTE: an alternative path via self.splu.solve(d) is disabled.
        return self.poisson.solve_poisson(d)

    def jacobians_from_vertices(self, v):
        """Compute Jacobians of vertex positions ``v`` via the Poisson solver."""
        return self.poisson.jacobians_from_vertices(v)

    def restrict_jacobians(self, J):
        """Delegate to the Poisson solver's ``restrict_jacobians``."""
        return self.poisson.restrict_jacobians(J)

    def get_loaded_data(self, key: str):
        """Return the extra tensor loaded for ``key``, or None if absent."""
        return self.__loaded_data.get(key)

    def get_source_triangles(self):
        """Return the face array held by the mesh processor."""
        return self.mesh_processor.get_faces()

    def to(self, device):
        """Move solvers and tensors to ``device`` in place; return ``self``."""
        self.poisson = self.poisson.to(device)
        # The SPLU solver (presumably CUDA-specific, given its name) may be
        # None when the mesh processor did not build one.
        if self.splu is not None:
            self.splu = self.splu.to(device)
        self.centroids_and_normals = self.centroids_and_normals.to(device)
        for key in self.__loaded_data.keys():
            self.__loaded_data[key] = self.__loaded_data[key].to(device)
        return self

    def __init_from_mesh_data(self):
        """Derive all tensors and solvers from ``self.mesh_processor``.

        Raises:
            NotImplementedError: if ``random_scale`` != 1, because the
                differential operators and WKS features are not rescaled.
        """
        assert self.mesh_processor is not None
        self.mesh_processor.prepare_differential_operators_for_use(self.__ttype)  # call 1
        self.source_vertices = torch.from_numpy(
            self.mesh_processor.get_vertices()).type(self.__ttype)
        if self.__random_scale != 1:
            # Previously print() + sys.exit(): that terminated the whole host
            # process, with exit status 0 (success). Raise instead.
            raise NotImplementedError(
                "Diff ops and WKS need to be multiplied accordingly. Not implemented for now")
        # No-op while only a scale of 1 is supported.
        self.source_vertices *= self.__random_scale

        vertices_np = self.source_vertices.numpy()
        bb = igl.bounding_box(vertices_np)[0]
        diag = igl.bounding_box_diagonal(vertices_np)
        # Midpoint of the first and last corners returned by igl.bounding_box
        # (presumably opposite corners, i.e. the box center).
        self.source_mesh_centroid = (bb[0] + bb[-1]) / 2
        if self.random_centering:
            # Centering augmentation: shift each axis by up to +/-20% of diag.
            jitter = [(2 * random.random() - 1) * diag * 0.2 for _ in range(3)]
            self.source_mesh_centroid = self.source_mesh_centroid + jitter

        # Input features for the per-centroid network.
        centroids = self.mesh_processor.get_centroids()
        centroid_points_and_normals = centroids.points_and_normals
        if self.__use_wks:
            wks = WKS_FACTOR * centroids.wks
            centroid_points_and_normals = numpy.hstack((centroid_points_and_normals, wks))
        self.centroids_and_normals = torch.from_numpy(
            centroid_points_and_normals).type(self.__ttype)
        if self.center_source:
            c = self.source_mesh_centroid
            # Only the first three feature columns are translated.
            self.centroids_and_normals[:, 0:3] -= c
            self.source_vertices -= c
            self.__source_global_translation_to_original = c
        self.poisson = self.mesh_processor.diff_ops.poisson_solver
        self.splu = self.mesh_processor.diff_ops.MyCuSPLU_solver

        # Extra per-mesh data; 'samples' gets the same centering/scaling as the vertices.
        for key in self.__extra_keys:
            data = torch.from_numpy(self.mesh_processor.get_data(key))
            if key == 'samples':
                if self.center_source:
                    data -= self.get_mesh_centroid()
                data *= self.__random_scale
            self.__loaded_data[key] = data.unsqueeze(0).type(self.__ttype)

    def load(self, source_v=None, source_f=None):
        """Build the MeshProcessor and derive all tensors.

        If both ``source_v`` and ``source_f`` are given, the mesh is built from
        them (``source_dir`` is still forwarded). Otherwise ``source_dir`` is
        loaded as a directory if it is one, else as a mesh file.

        Raises:
            NotImplementedError: if ``random_scale`` != 1.
        """
        processor_cls = MeshProcessor.MeshProcessor
        common_kwargs = dict(cpuonly=self.cpuonly,
                             load_wks_samples=self.__use_wks,
                             load_wks_centroids=self.__use_wks)
        if source_v is not None and source_f is not None:
            self.mesh_processor = processor_cls.meshprocessor_from_array(
                source_v, source_f, self.source_dir, self.__ttype, **common_kwargs)
        elif os.path.isdir(self.source_dir):
            self.mesh_processor = processor_cls.meshprocessor_from_directory(
                self.source_dir, self.__ttype, **common_kwargs)
        else:
            self.mesh_processor = processor_cls.meshprocessor_from_file(
                self.source_dir, self.__ttype, **common_kwargs)
        self.__init_from_mesh_data()

    def get_point_dim(self):
        """Return the number of feature columns per centroid."""
        return self.centroids_and_normals.shape[1]

    def get_centroids_and_normals(self):
        """Return the per-centroid feature tensor."""
        return self.centroids_and_normals

    def get_mesh_centroid(self):
        """Return the point used for centering (may include random jitter)."""
        return self.source_mesh_centroid

    def pin_memory(self):
        """Intentionally a no-op (memory pinning is disabled); return ``self``."""
        return self
|