# PyMAF — modules/PyMAF/utils/mesh_generation.py
# (header reconstructed: the original lines were web-page scrape residue,
#  including a syntactically invalid line, not Python code)
import time
import torch
import trimesh
import numpy as np
import torch.optim as optim
from torch import autograd
from torch.utils.data import TensorDataset, DataLoader
from .common import make_3d_grid
from .utils import libmcubes
from .utils.libmise import MISE
from .utils.libsimplify import simplify_mesh
from .common import transform_pointcloud
class Generator3D(object):
    ''' Generator class for DVR-style occupancy models.

    It provides functions to generate the final mesh from a latent code,
    as well as refining options (mesh refinement, simplification, vertex
    normal and vertex color estimation).

    Args:
        model (nn.Module): trained DVR model
        points_batch_size (int): batch size for points evaluation
        threshold (float): threshold value for the occupancy decision
        refinement_step (int): number of refinement steps (0 disables)
        device (device): pytorch device
        resolution0 (int): start resolution for MISE
        upsampling_steps (int): number of upsampling steps for MISE
        with_normals (bool): whether normals should be estimated
        padding (float): how much padding should be used for MISE
        simplify_nfaces (int): number of faces the mesh should be
            simplified to (None disables simplification)
        with_color (bool): whether vertex colors should be estimated
        refine_max_faces (int): max number of faces which are used as
            batch size for the refinement process
    '''

    def __init__(self, model, points_batch_size=100000,
                 threshold=0.5, refinement_step=0, device=None,
                 resolution0=16, upsampling_steps=3,
                 with_normals=False, padding=0.1,
                 simplify_nfaces=None, with_color=False,
                 refine_max_faces=10000):
        self.model = model.to(device)
        self.points_batch_size = points_batch_size
        self.refinement_step = refinement_step
        self.threshold = threshold
        self.device = device
        self.resolution0 = resolution0
        self.upsampling_steps = upsampling_steps
        self.with_normals = with_normals
        self.padding = padding
        self.simplify_nfaces = simplify_nfaces
        self.with_color = with_color
        self.refine_max_faces = refine_max_faces

    def generate_mesh(self, data, return_stats=True):
        ''' Generates the output mesh for a single sample.

        Args:
            data (dict): data dictionary; ``data['inputs']`` is fed to the
                model encoder (missing key falls back to an empty tensor)
            return_stats (bool): whether stats should be returned.
                NOTE(review): currently ignored — a (mesh, stats_dict)
                tuple is always returned, matching existing callers.

        Returns:
            tuple: (trimesh.Trimesh, dict of timing statistics)
        '''
        self.model.eval()
        device = self.device
        stats_dict = {}
        inputs = data.get('inputs', torch.empty(1, 0)).to(device)
        kwargs = {}
        c = self.model.encode_inputs(inputs)
        mesh = self.generate_from_latent(c, stats_dict=stats_dict,
                                         data=data, **kwargs)
        return mesh, stats_dict

    def generate_meshes(self, data, return_stats=True):
        ''' Generates the output meshes for a batch (batch size >= 1).

        Each batch element is encoded and meshed independently.

        Args:
            data (dict): data dictionary; ``data['inputs']`` has a leading
                batch dimension (missing key falls back to an empty tensor)
            return_stats (bool): whether stats should be returned.
                NOTE(review): currently ignored; per-mesh stats are
                accumulated into a throwaway dict.

        Returns:
            list: one trimesh.Trimesh per batch element
        '''
        self.model.eval()
        device = self.device
        stats_dict = {}
        inputs = data.get('inputs', torch.empty(1, 1, 0)).to(device)
        meshes = []
        for i in range(inputs.shape[0]):
            input_i = inputs[i].unsqueeze(0)
            c = self.model.encode_inputs(input_i)
            mesh = self.generate_from_latent(c, stats_dict=stats_dict)
            meshes.append(mesh)
        return meshes

    def generate_pointcloud(self, mesh, data=None, n_points=2000000,
                            scale_back=True):
        ''' Generates a point cloud by sampling the surface of the mesh.

        Args:
            mesh (trimesh): mesh to sample from
            data (dict): data dictionary; only needed when scale_back is
                True (must then contain 'camera.scale_mat_0')
            n_points (int): number of point cloud points
            scale_back (bool): whether to undo scaling (requires a scale
                matrix in the data dictionary)

        Returns:
            trimesh.Trimesh: point cloud stored as a vertices-only mesh
        '''
        pcl = mesh.sample(n_points).astype(np.float32)
        if scale_back:
            # Guard against data=None (the default) — previously this
            # raised AttributeError instead of warning.
            scale_mat = None if data is None else data.get(
                'camera.scale_mat_0', None)
            if scale_mat is not None:
                pcl = transform_pointcloud(pcl, scale_mat[0])
            else:
                print('Warning: No scale_mat found!')
        pcl_out = trimesh.Trimesh(vertices=pcl, process=False)
        return pcl_out

    def generate_from_latent(self, c=None, pl=None, stats_dict=None,
                             data=None, **kwargs):
        ''' Generates mesh from latent.

        Args:
            c (tensor): latent conditioned code c
            pl (tensor): predicted plane parameters
            stats_dict (dict): stats dictionary, filled with timings
                (a fresh dict is created when None; a mutable default
                would be shared across calls)
            data (dict): unused here, accepted for interface symmetry

        Returns:
            trimesh.Trimesh: extracted mesh
        '''
        if stats_dict is None:
            stats_dict = {}
        # Occupancy threshold in logit space, since eval_points returns
        # logits rather than probabilities.
        threshold = np.log(self.threshold) - np.log(1. - self.threshold)
        t0 = time.time()
        # Compute bounding box size
        box_size = 1 + self.padding
        # Shortcut: single dense evaluation when no upsampling is requested
        if self.upsampling_steps == 0:
            nx = self.resolution0
            pointsf = box_size * make_3d_grid(
                (-0.5,)*3, (0.5,)*3, (nx,)*3
            )
            values = self.eval_points(pointsf, c, pl, **kwargs).cpu().numpy()
            value_grid = values.reshape(nx, nx, nx)
        else:
            # Multi-resolution iso-surface extraction: repeatedly query
            # only the points MISE marks as unresolved.
            mesh_extractor = MISE(
                self.resolution0, self.upsampling_steps, threshold)
            points = mesh_extractor.query()
            while points.shape[0] != 0:
                # Query points
                pointsf = torch.FloatTensor(points).to(self.device)
                # Normalize integer grid coordinates to the bounding box
                pointsf = 2 * pointsf / mesh_extractor.resolution
                pointsf = box_size * (pointsf - 1.0)
                # Evaluate model and update
                values = self.eval_points(
                    pointsf, c, pl, **kwargs).cpu().numpy()
                values = values.astype(np.float64)
                mesh_extractor.update(points, values)
                points = mesh_extractor.query()
            value_grid = mesh_extractor.to_dense()
        # Extract mesh
        stats_dict['time (eval points)'] = time.time() - t0
        mesh = self.extract_mesh(value_grid, c, stats_dict=stats_dict)
        return mesh

    def eval_points(self, p, c=None, pl=None, **kwargs):
        ''' Evaluates the occupancy logits for the points in batches.

        Args:
            p (tensor): points, shape (N, 3)
            c (tensor): latent conditioned code c
            pl (tensor): predicted plane parameters

        Returns:
            tensor: occupancy logits on CPU, concatenated over batches
        '''
        p_split = torch.split(p, self.points_batch_size)
        occ_hats = []
        for pi in p_split:
            pi = pi.unsqueeze(0).to(self.device)
            with torch.no_grad():
                occ_hat = self.model.decode(pi, c, pl, **kwargs).logits
            occ_hats.append(occ_hat.squeeze(0).detach().cpu())
        occ_hat = torch.cat(occ_hats, dim=0)
        return occ_hat

    def extract_mesh(self, occ_hat, c=None, stats_dict=None):
        ''' Extracts the mesh from the predicted occupancy grid.

        Args:
            occ_hat (tensor): value grid of occupancy logits
            c (tensor): latent conditioned code c
            stats_dict (dict): stats dictionary, filled with timings
                (a fresh dict is created when None; a mutable default
                would be shared across calls)

        Returns:
            trimesh.Trimesh: extracted (optionally simplified/refined) mesh
        '''
        if stats_dict is None:
            stats_dict = {}
        # Some short hands
        n_x, n_y, n_z = occ_hat.shape
        box_size = 1 + self.padding
        threshold = np.log(self.threshold) - np.log(1. - self.threshold)
        # Pad with large negative values so the surface is closed at the
        # grid boundary (makes the mesh watertight).
        t0 = time.time()
        occ_hat_padded = np.pad(
            occ_hat, 1, 'constant', constant_values=-1e6)
        vertices, triangles = libmcubes.marching_cubes(
            occ_hat_padded, threshold)
        stats_dict['time (marching cubes)'] = time.time() - t0
        # Strange behaviour in libmcubes: vertices are shifted by 0.5
        vertices -= 0.5
        # Undo padding
        vertices -= 1
        # Normalize to bounding box
        vertices /= np.array([n_x-1, n_y-1, n_z-1])
        vertices *= 2
        vertices = box_size * (vertices - 1)
        # Estimate normals if needed
        if self.with_normals and not vertices.shape[0] == 0:
            t0 = time.time()
            normals = self.estimate_normals(vertices, c)
            stats_dict['time (normals)'] = time.time() - t0
        else:
            normals = None
        # Create mesh
        mesh = trimesh.Trimesh(vertices, triangles,
                               vertex_normals=normals,
                               process=False)
        # Directly return if mesh is empty
        if vertices.shape[0] == 0:
            return mesh
        # TODO: normals are lost here
        if self.simplify_nfaces is not None:
            t0 = time.time()
            mesh = simplify_mesh(mesh, self.simplify_nfaces, 5.)
            stats_dict['time (simplify)'] = time.time() - t0
        # Refine mesh (modifies mesh vertices in place)
        if self.refinement_step > 0:
            t0 = time.time()
            self.refine_mesh(mesh, occ_hat, c)
            stats_dict['time (refine)'] = time.time() - t0
        # Estimate vertex colors; rebuild the mesh so colors are attached
        if self.with_color and not vertices.shape[0] == 0:
            t0 = time.time()
            vertex_colors = self.estimate_colors(np.array(mesh.vertices), c)
            stats_dict['time (color)'] = time.time() - t0
            mesh = trimesh.Trimesh(
                vertices=mesh.vertices, faces=mesh.faces,
                vertex_normals=mesh.vertex_normals,
                vertex_colors=vertex_colors, process=False)
        return mesh

    def estimate_colors(self, vertices, c=None):
        ''' Estimates vertex colors by evaluating the texture field.

        Args:
            vertices (numpy array): vertices of the mesh
            c (tensor): latent conditioned code c

        Returns:
            numpy array: (N, 4) uint8 RGBA colors (alpha fixed at 255)
        '''
        device = self.device
        vertices = torch.FloatTensor(vertices)
        vertices_split = torch.split(vertices, self.points_batch_size)
        colors = []
        for vi in vertices_split:
            vi = vi.to(device)
            with torch.no_grad():
                ci = self.model.decode_color(
                    vi.unsqueeze(0), c).squeeze(0).cpu()
            colors.append(ci)
        colors = np.concatenate(colors, axis=0)
        colors = np.clip(colors, 0, 1)
        colors = (colors * 255).astype(np.uint8)
        # Append a fully opaque alpha channel
        colors = np.concatenate([
            colors, np.full((colors.shape[0], 1), 255, dtype=np.uint8)],
            axis=1)
        return colors

    def estimate_normals(self, vertices, c=None):
        ''' Estimates the normals by computing the gradient of the
        occupancy logits with respect to the vertex positions.

        Args:
            vertices (numpy array): vertices of the mesh
            c (tensor): latent conditioned code c

        Returns:
            numpy array: unit normals, one per vertex
        '''
        device = self.device
        vertices = torch.FloatTensor(vertices)
        vertices_split = torch.split(vertices, self.points_batch_size)
        normals = []
        c = c.unsqueeze(0)
        for vi in vertices_split:
            vi = vi.unsqueeze(0).to(device)
            vi.requires_grad_()
            occ_hat = self.model.decode(vi, c).logits
            out = occ_hat.sum()
            out.backward()
            # Normals point against the occupancy gradient
            ni = -vi.grad
            ni = ni / torch.norm(ni, dim=-1, keepdim=True)
            ni = ni.squeeze(0).cpu().numpy()
            normals.append(ni)
        normals = np.concatenate(normals, axis=0)
        return normals

    def refine_mesh(self, mesh, occ_hat, c=None):
        ''' Refines the predicted mesh by optimizing vertex positions.

        The mesh vertices are updated in place; the refined mesh is also
        returned for convenience.

        Args:
            mesh (trimesh object): predicted mesh
            occ_hat (tensor): predicted occupancy grid
            c (tensor): latent conditioned code c

        Returns:
            trimesh object: the refined mesh (same object as the input)
        '''
        self.model.eval()
        # Some shorthands
        n_x, n_y, n_z = occ_hat.shape
        assert(n_x == n_y == n_z)
        # threshold = np.log(self.threshold) - np.log(1. - self.threshold)
        threshold = self.threshold
        # Vertex parameter
        v0 = torch.FloatTensor(mesh.vertices).to(self.device)
        v = torch.nn.Parameter(v0.clone())
        # Faces of mesh
        faces = torch.LongTensor(mesh.faces)
        # detach c; otherwise graph needs to be retained
        # caused by new Pytorch version?
        c = c.detach()
        # Start optimization
        optimizer = optim.RMSprop([v], lr=1e-5)
        # Subsample faces per step; useful when using a high extraction
        # resolution / when working on small GPUs.
        ds_faces = TensorDataset(faces)
        dataloader = DataLoader(ds_faces, batch_size=self.refine_max_faces,
                                shuffle=True)
        it_r = 0
        while it_r < self.refinement_step:
            for f_it in dataloader:
                f_it = f_it[0].to(self.device)
                optimizer.zero_grad()
                # Sample a random point on each face via Dirichlet
                # barycentric weights
                face_vertex = v[f_it]
                eps = np.random.dirichlet((0.5, 0.5, 0.5), size=f_it.shape[0])
                eps = torch.FloatTensor(eps).to(self.device)
                face_point = (face_vertex * eps[:, :, None]).sum(dim=1)
                face_v1 = face_vertex[:, 1, :] - face_vertex[:, 0, :]
                face_v2 = face_vertex[:, 2, :] - face_vertex[:, 1, :]
                # dim=1 must be explicit: legacy torch.cross infers the
                # *first* dimension of size 3, which is wrong whenever the
                # face batch happens to contain exactly 3 faces.
                face_normal = torch.cross(face_v1, face_v2, dim=1)
                face_normal = face_normal / \
                    (face_normal.norm(dim=1, keepdim=True) + 1e-10)
                # Occupancy probability at the sampled surface points,
                # evaluated in chunks to bound memory
                face_value = torch.cat([
                    torch.sigmoid(self.model.decode(p_split, c).logits)
                    for p_split in torch.split(
                        face_point.unsqueeze(0), 20000, dim=1)], dim=1)
                # Target normal = negative gradient of occupancy
                normal_target = -autograd.grad(
                    [face_value.sum()], [face_point], create_graph=True)[0]
                normal_target = \
                    normal_target / \
                    (normal_target.norm(dim=1, keepdim=True) + 1e-10)
                # Pull surface points to the decision boundary and align
                # face normals with the occupancy gradient
                loss_target = (face_value - threshold).pow(2).mean()
                loss_normal = \
                    (face_normal - normal_target).pow(2).sum(dim=1).mean()
                loss = loss_target + 0.01 * loss_normal
                # Update
                loss.backward()
                optimizer.step()
                # Update it_r
                it_r += 1
                if it_r >= self.refinement_step:
                    break
        mesh.vertices = v.data.cpu().numpy()
        return mesh