Spaces:
Runtime error
Runtime error
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
import trimesh

# Module-level logger.
logger = logging.getLogger(__name__)
class PrimSDF(nn.Module):
    """Primitive-based volumetric representation of an SDF plus material fields.

    The model holds ``num_prims`` axis-aligned cubic primitives. Each primitive
    has a global scale and translation (``srt_param``) and a dense feature grid
    of ``dim_feat`` channels at ``prim_shape ** 3`` grid nodes (``feat_param``).
    With the default ``dim_feat=6`` the channels are
    ``[SDF, R, G, B, roughness, metallic]``.

    A query point is evaluated by trilinearly sampling the feature grid of
    every primitive whose cube contains the point and blending the samples
    with the normalized weights from :meth:`prim_weight`.
    """

    def __init__(
        self,
        mesh_obj=None,
        f_sdf=None,
        geo_fn=None,
        asset_list=None,
        num_prims=1024,
        dim_feat=6,
        prim_shape=8,
        init_scale=0.05,
        sdf2alpha_var=0.005,
        auto_scale_init=True,
        init_sampling="uniform",
    ):
        super().__init__()
        self.num_prims = num_prims
        # 6 channels features - [SDF, R, G, B, roughness, metallic]
        self.dim_feat = dim_feat
        self.prim_shape = prim_shape
        self.sdf_sampled_point = None
        self.auto_scale_init = auto_scale_init
        self.init_sampling = init_sampling
        self.sdf2alpha_var = sdf2alpha_var
        # assume the mesh is normalized to [-1, 1] cube
        self.mesh_obj = mesh_obj
        self.f_sdf = f_sdf

        # Per-primitive [scale, tx, ty, tz]; read through the `scale` / `pos`
        # properties.
        self.srt_param = nn.parameter.Parameter(torch.zeros(self.num_prims, 1 + 3))
        # Per-primitive feature grid, flattened channel-major as
        # [dim_feat, prim_shape, prim_shape, prim_shape] (see grid_sample_feat).
        self.feat_param = nn.parameter.Parameter(
            torch.zeros(self.num_prims, self.dim_feat * (self.prim_shape ** 3))
        )

        # Column ranges of feat_param holding each channel group
        # (end indices are exclusive).
        voxels = self.prim_shape ** 3
        self.geo_start_index = 0
        self.geo_end_index = self.geo_start_index + voxels  # 1 SDF channel
        self.tex_start_index = self.geo_end_index
        self.tex_end_index = self.tex_start_index + voxels * 3  # RGB
        self.mat_start_index = self.tex_end_index
        self.mat_end_index = self.mat_start_index + voxels * 2  # roughness, metallic

        # local_grid - [prim_shape^3, 3]: grid-node coordinates in the
        # primitive's local [-1, 1]^3 frame. With 'ij' indexing and the
        # (z, y, x) stacking order, flat index k holds (x, y, z) = (W, H, D)
        # coordinates matching the flattened [D, H, W] feature layout used by
        # F.grid_sample. An equivalent construction uses indexing='xy' and
        # stacks (meshz, meshx, meshy).
        xx = torch.linspace(-1, 1, self.prim_shape)
        meshx, meshy, meshz = torch.meshgrid(xx, xx, xx, indexing="ij")
        self.local_grid = torch.stack((meshz, meshy, meshx), dim=-1).reshape(-1, 3)

        if self.f_sdf is not None and geo_fn is not None and asset_list is not None:
            self._init_param(
                init_scale=init_scale,
                geo_fn=geo_fn,
                asset_list=asset_list,
                sampling=self.init_sampling,
            )

    def _init_param(self, init_scale, geo_fn, asset_list, sampling="uniform"):
        """Parameter-initialisation hook called from ``__init__``.

        NOTE(review): this is a no-op in this version of the file; the
        parameters keep their zero initialisation.
        """

    def forward(self, x):
        """Evaluate SDF, albedo and material at query points.

        Args:
            x: query points, shape [bs, 3].

        Returns:
            dict with ``'sdf'`` [bs, 1], ``'tex'`` [bs, 3] (RGB clipped to
            [0, 1]) and ``'mat'`` [bs, 2] (roughness, metallic clipped to
            [0, 1]).
        """
        weights = self.prim_weight(x)
        output = self.grid_sample_feat(x, weights)
        return {
            "sdf": output[:, 0:1],
            "tex": torch.clip(output[:, 1:4], min=0.0, max=1.0),
            "mat": torch.clip(output[:, 4:6], min=0.0, max=1.0),
        }

    def grid_sample_feat(self, x, weights):
        """Blend trilinear samples of every primitive grid covering each point.

        Implements I_V: for every (point, primitive) pair with positive
        weight, the primitive's feature grid is sampled at the point's local
        coordinates, and the weighted samples are summed per point.

        Args:
            x: query points, shape [bs, 3].
            weights: normalized primitive weights, shape [bs, num_prims].

        Returns:
            Tensor of shape [bs, dim_feat]. In eval mode, the SDF channel of
            points not covered by any primitive is replaced by an
            extrapolated value (see below).
        """
        bs = x.shape[0]
        # Local coordinates of every point in every primitive: [bs, N, 3].
        sampled_point = (x[:, None, :] - self.pos[None, ...]) / self.scale[None, ...]
        mask = weights > 0
        ind_bs, ind_nprim = torch.where(mask)
        n_hits = ind_nprim.shape[0]

        weighted_feat = x.new_zeros((bs, self.dim_feat))
        if n_hits > 0:
            masked_sampled_point = sampled_point[ind_bs, ind_nprim, :].reshape(
                n_hits, 1, 1, 1, 3
            )
            feat4sample = self.feat[ind_nprim, :].reshape(
                n_hits, self.dim_feat, self.prim_shape, self.prim_shape, self.prim_shape
            )
            # mode='bilinear' on a 5-D input performs trilinear interpolation.
            sampled_feat = F.grid_sample(
                feat4sample,
                masked_sampled_point,
                mode="bilinear",
                padding_mode="zeros",
                align_corners=True,
            ).reshape(n_hits, self.dim_feat)
            # weights[mask] enumerates entries in the same row-major order as
            # torch.where(mask), so rows line up with ind_bs / ind_nprim.
            weighted_sampled_feat = sampled_feat * weights[mask][:, None]
            weighted_feat.index_add_(0, ind_bs, weighted_sampled_feat)

        # At inference time, fill in an approximate SDF for points not covered
        # by any primitive (they received no contribution above).
        if not self.training:
            bs_mask = weights.sum(1) <= 0
            if bs_mask.any():
                x_out = x[bs_mask]
                # Nearest primitive by distance to its centre.
                dist = torch.norm(x_out[:, None, :] - self.pos[None, ...], p=2, dim=-1)
                _, min_dist_ind = dist.min(1)
                nearest_prim_pos = self.pos[min_dist_ind, :]
                nearest_prim_scale = self.scale[min_dist_ind, :]
                # World-space grid nodes of each nearest primitive:
                # [M, prim_shape^3, 3].
                candidate_nearest_pts = (
                    nearest_prim_pos[:, None, :]
                    + nearest_prim_scale[..., None] * self.local_grid.to(x)[None, :]
                )
                pts_dist = torch.norm(
                    x_out[:, None, :] - candidate_nearest_pts, p=2, dim=-1
                )
                min_dist, min_dist_pts_ind = pts_dist.min(1)
                # SDF stored at the nearest grid node.
                min_pts_sdf = self.feat_geo[min_dist_ind, min_dist_pts_ind]
                # Keep the node's SDF sign and add the Euclidean distance to it.
                approx_sdf = min_pts_sdf + min_dist * torch.sign(min_pts_sdf)
                weighted_feat[bs_mask, 0:1] = approx_sdf[:, None]
        return weighted_feat

    def prim_weight(self, x):
        """Per-primitive blending weights for query points.

        The raw weight is ``relu(1 - ||(x - pos) / scale||_inf)``: 1 at a
        primitive's centre, falling linearly to 0 at its cube boundary. The
        weights are then normalized to sum to ~1 per point. The 1e-6 avoids
        division by zero for points outside every primitive, which keep
        all-zero weights.

        Args:
            x: query points, shape [bs, 3].

        Returns:
            Tensor of shape [bs, num_prims].
        """
        local = (x[:, None, :] - self.pos[None, ...]) / self.scale[None, ...]
        weights = F.relu(1 - torch.norm(local, p=float("inf"), dim=-1))
        return weights / (torch.sum(weights, dim=-1, keepdim=True) + 1e-6)

    def sdf2alpha(self, sdf):
        """Map SDF values to opacity via ``exp(-(sdf / sdf2alpha_var) ** 2)``.

        Alpha is 1 where sdf == 0 and decays with |sdf|; ``sdf2alpha_var``
        sets the width of the falloff.
        """
        return torch.exp(-(sdf / self.sdf2alpha_var) ** 2)

    @property
    def pos(self):
        """Primitive centres (translation), columns 1:4 of srt_param -> [num_prims, 3]."""
        return self.srt_param[:, 1:4]

    @property
    def scale(self):
        """Primitive half-extent (global scale), column 0 of srt_param -> [num_prims, 1]."""
        return self.srt_param[:, 0:1]

    @property
    def feat(self):
        """All flattened per-primitive features, [num_prims, dim_feat * prim_shape^3]."""
        return self.feat_param

    @property
    def feat_geo(self):
        """SDF channel of every primitive grid (geo_start_index:geo_end_index)."""
        return self.feat_param[:, self.geo_start_index:self.geo_end_index]

    @property
    def feat_tex(self):
        """RGB channels of every primitive grid (tex_start_index:tex_end_index)."""
        return self.feat_param[:, self.tex_start_index:self.tex_end_index]

    @property
    def feat_mat(self):
        """Roughness/metallic channels of every primitive grid (mat_start_index:mat_end_index)."""
        return self.feat_param[:, self.mat_start_index:self.mat_end_index]