| from __future__ import division, print_function |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import voxelize_cuda |
| from torch.autograd import Function |
|
|
|
|
class VoxelizationFunction(Function):
    """
    Autograd function that voxelizes SMPL tetrahedrons into a semantic volume.

    The heavy lifting is delegated to the ``voxelize_cuda`` extension, so the
    function is implemented only for CUDA tensors.

    NOTE(review): no ``backward`` is defined in this class, so gradients
    cannot be propagated through it as written -- confirm this is intended.
    """

    @staticmethod
    def forward(
        ctx,
        smpl_vertices,
        smpl_face_center,
        smpl_face_normal,
        smpl_vertex_code,
        smpl_face_code,
        smpl_tetrahedrons,
        volume_res,
        sigma,
        smooth_kernel_size,
    ):
        """
        Run the forward voxelization pass.

        Args:
            ctx: autograd context; the call configuration is stored on it.
            smpl_vertices: vertex tensor; dim 0 is the batch, dim 1 indexes
                vertices. Its device decides where the volumes are allocated.
            smpl_face_center: per-face tensor; only its size along dim 1 is
                checked here.
            smpl_face_normal: per-face tensor; only its size along dim 1 is
                checked here.
            smpl_vertex_code: per-vertex code tensor; dim 1 must match
                ``smpl_vertices``.
            smpl_face_code: per-face code tensor; only its size along dim 1
                is checked here.
            smpl_tetrahedrons: tetrahedron tensor handed to the CUDA kernel.
            volume_res: edge length (in voxels) of the cubic output volume.
            sigma: forwarded unchanged to the CUDA kernel.
            smooth_kernel_size: stored on ``ctx``; not otherwise used here.

        Returns:
            The semantic volume produced by the kernel, allocated as
            ``(batch_size, volume_res, volume_res, volume_res, 3)``.
            Output format: (batch_size, z_dims, y_dims, x_dims, channel_num).

        Raises:
            AssertionError: if the per-vertex or per-face inputs disagree on
                their size along dim 1.
        """
        assert smpl_vertices.size()[1] == smpl_vertex_code.size()[1]
        assert smpl_face_center.size()[1] == smpl_face_normal.size()[1]
        assert smpl_face_center.size()[1] == smpl_face_code.size()[1]

        ctx.batch_size = smpl_vertices.size()[0]
        ctx.volume_res = volume_res
        ctx.sigma = sigma
        ctx.smooth_kernel_size = smooth_kernel_size
        ctx.smpl_vertex_num = smpl_vertices.size()[1]
        ctx.device = smpl_vertices.device

        # Only the tensors actually handed to the kernel need a contiguous
        # layout; the face tensors are validated above but never passed on.
        smpl_vertices = smpl_vertices.contiguous()
        smpl_vertex_code = smpl_vertex_code.contiguous()
        smpl_tetrahedrons = smpl_tetrahedrons.contiguous()

        # Allocate the outputs on the same device as the input vertices. The
        # legacy torch.cuda.FloatTensor constructor would place them on the
        # *current* CUDA device, which is wrong when the inputs live on a
        # different GPU.
        grid_shape = (
            ctx.batch_size,
            ctx.volume_res,
            ctx.volume_res,
            ctx.volume_res,
        )
        occ_volume = torch.zeros(
            grid_shape, dtype=torch.float32, device=ctx.device
        )
        semantic_volume = torch.zeros(
            grid_shape + (3,), dtype=torch.float32, device=ctx.device
        )
        # The weight sum starts at a small positive value rather than zero.
        weight_sum_volume = torch.full(
            grid_shape, 1e-3, dtype=torch.float32, device=ctx.device
        )

        (
            occ_volume,
            semantic_volume,
            weight_sum_volume,
        ) = voxelize_cuda.forward_semantic_voxelization(
            smpl_vertices,
            smpl_vertex_code,
            smpl_tetrahedrons,
            occ_volume,
            semantic_volume,
            weight_sum_volume,
            sigma,
        )

        return semantic_volume
|
|
|
|
class Voxelization(nn.Module):
    """
    Wrapper around the autograd function VoxelizationFunction.

    ``update_param`` must be called before ``forward``: it registers the
    batched code/index buffers that ``forward`` and the gather helpers read.
    """

    def __init__(
        self,
        smpl_vertex_code,
        smpl_face_code,
        smpl_face_indices,
        smpl_tetraderon_indices,
        volume_res,
        sigma,
        smooth_kernel_size,
        batch_size,
    ):
        """
        Store the un-batched SMPL templates and the voxelization settings.

        Args:
            smpl_vertex_code: per-vertex code tensor (tiled per batch later).
            smpl_face_code: per-face code tensor (tiled per batch later).
            smpl_face_indices: 2-D index tensor with 3 columns (triangles).
            smpl_tetraderon_indices: 2-D index tensor with 4 columns.
            volume_res: edge length (in voxels) of the output volume.
            sigma: forwarded to VoxelizationFunction.
            smooth_kernel_size: forwarded to VoxelizationFunction.
            batch_size: number of copies made when tiling the templates.

        Raises:
            AssertionError: if either index tensor is not 2-D with the
                expected number of columns.
        """
        super(Voxelization, self).__init__()
        assert len(smpl_face_indices.shape) == 2
        assert len(smpl_tetraderon_indices.shape) == 2
        assert smpl_face_indices.shape[1] == 3
        assert smpl_tetraderon_indices.shape[1] == 4

        self.volume_res = volume_res
        self.sigma = sigma
        self.smooth_kernel_size = smooth_kernel_size
        self.batch_size = batch_size
        # Set by update_param from the tensor it receives.
        self.device = None

        self.smpl_vertex_code = smpl_vertex_code
        self.smpl_face_code = smpl_face_code
        self.smpl_face_indices = smpl_face_indices
        self.smpl_tetraderon_indices = smpl_tetraderon_indices

    def update_param(self, voxel_faces):
        """
        Replace the tetrahedron indices and (re)build the batched buffers.

        Args:
            voxel_faces: tetrahedron index tensor. It replaces the indices
                given to ``__init__`` and is used as given (not tiled), so it
                must broadcast against a per-batch offset of shape
                ``(batch, 1, 1)``. Its device becomes the device of every
                registered buffer.
        """
        self.device = voxel_faces.device
        self.smpl_tetraderon_indices = voxel_faces

        tiling = (self.batch_size, 1, 1)
        batched = {
            "smpl_vertex_code_batch": torch.tile(self.smpl_vertex_code, tiling),
            "smpl_face_code_batch": torch.tile(self.smpl_face_code, tiling),
            "smpl_face_indices_batch": torch.tile(
                self.smpl_face_indices, tiling
            ),
            "smpl_tetraderon_indices_batch": self.smpl_tetraderon_indices,
        }
        for name, tensor in batched.items():
            self.register_buffer(name, tensor.contiguous().to(self.device))

    def forward(self, smpl_vertices):
        """
        Generate semantic volumes from SMPL vertices.

        Args:
            smpl_vertices: float32 CUDA tensor; dim 0 is the batch, dim 1
                indexes vertices, dim 2 holds the 3 coordinates.

        Returns:
            The volume from VoxelizationFunction with its last (channel)
            axis moved to position 1.

        Raises:
            TypeError: if ``smpl_vertices`` is not a float32 CUDA tensor.
        """
        self.check_input(smpl_vertices)
        smpl_faces = self.vertices_to_faces(smpl_vertices)
        smpl_tetrahedrons = self.vertices_to_tetrahedrons(smpl_vertices)
        smpl_face_center = self.calc_face_centers(smpl_faces)
        smpl_face_normal = self.calc_face_normals(smpl_faces)
        # Only the leading vertices that have a code are passed on as
        # "surface" vertices; the gathers above use all vertices.
        smpl_surface_vertex_num = self.smpl_vertex_code_batch.size()[1]
        smpl_vertices_surface = smpl_vertices[:, :smpl_surface_vertex_num, :]
        vol = VoxelizationFunction.apply(
            smpl_vertices_surface,
            smpl_face_center,
            smpl_face_normal,
            self.smpl_vertex_code_batch,
            self.smpl_face_code_batch,
            smpl_tetrahedrons,
            self.volume_res,
            self.sigma,
            self.smooth_kernel_size,
        )
        return vol.permute((0, 4, 1, 2, 3))

    @staticmethod
    def _gather_by_batched_index(vertices, index_batch):
        """Gather ``vertices`` rows for every entry of a batched index tensor."""
        assert vertices.ndimension() == 3
        bs, nv = vertices.shape[:2]
        # Shift each batch's indices into the flattened (bs * nv) vertex
        # array. The offsets are created on the index buffer's own device so
        # the addition stays valid even after the module has been moved with
        # .to()/.cuda(), when a cached device attribute would be stale.
        offsets = (
            torch.arange(bs, dtype=torch.int64, device=index_batch.device) * nv
        )[:, None, None]
        flat_vertices = vertices.reshape((bs * nv, 3))
        return flat_vertices[(index_batch + offsets).long()]

    def vertices_to_faces(self, vertices):
        """
        Look up the three corner positions of every triangle.

        Args:
            vertices: tensor of shape (batch, num_vertices, 3).

        Returns:
            Tensor shaped like the face index buffer with a trailing
            coordinate axis of size 3.
        """
        return self._gather_by_batched_index(
            vertices, self.smpl_face_indices_batch
        )

    def vertices_to_tetrahedrons(self, vertices):
        """
        Look up the four corner positions of every tetrahedron.

        Args:
            vertices: tensor of shape (batch, num_vertices, 3).

        Returns:
            Tensor shaped like the broadcast tetrahedron index buffer with a
            trailing coordinate axis of size 3.
        """
        return self._gather_by_batched_index(
            vertices, self.smpl_tetraderon_indices_batch
        )

    def calc_face_centers(self, face_verts):
        """
        Average the three corners of every triangle.

        Args:
            face_verts: tensor of shape (batch, num_faces, 3, 3).

        Returns:
            Tensor of shape (batch, num_faces, 3).
        """
        assert len(face_verts.shape) == 4
        assert face_verts.shape[2] == 3
        assert face_verts.shape[3] == 3
        bs, nf = face_verts.shape[:2]
        corner_sum = (
            face_verts[:, :, 0, :]
            + face_verts[:, :, 1, :]
            + face_verts[:, :, 2, :]
        )
        return (corner_sum / 3.0).reshape((bs, nf, 3))

    def calc_face_normals(self, face_verts):
        """
        Compute a unit normal for every triangle.

        Args:
            face_verts: tensor of shape (batch, num_faces, 3, 3).

        Returns:
            Tensor of shape (batch, num_faces, 3) holding the normalized
            cross product of the edges (v0 - v1) and (v2 - v1).
        """
        assert len(face_verts.shape) == 4
        assert face_verts.shape[2] == 3
        assert face_verts.shape[3] == 3
        bs, nf = face_verts.shape[:2]
        flat_faces = face_verts.reshape((bs * nf, 3, 3))
        v10 = flat_faces[:, 0] - flat_faces[:, 1]
        v12 = flat_faces[:, 2] - flat_faces[:, 1]
        # dim must be explicit: without it torch.cross uses the first
        # dimension of size 3, which is the face axis when bs * nf == 3.
        normals = F.normalize(torch.cross(v10, v12, dim=1), eps=1e-5)
        return normals.reshape((bs, nf, 3))

    def check_input(self, x):
        """
        Validate that ``x`` is a float32 CUDA tensor.

        Raises:
            TypeError: if ``x`` is not on a CUDA device, or is not float32.
        """
        # Tensor.device is a torch.device, so comparing it with the string
        # "cpu" is not a reliable test; use the tensor's own flag instead.
        if not x.is_cuda:
            raise TypeError("Voxelization module supports only cuda tensors")
        if x.dtype != torch.float32:
            raise TypeError("Voxelization module supports only float32 tensors")
|
|