| import torch |
| import lightning.pytorch as pl |
| |
| from torch.utils.data import DataLoader |
| from partfield.model.UNet.model import ResidualUNet3D |
| from partfield.model.triplane import TriplaneTransformer, get_grid_coord |
| from partfield.model.model_utils import VanillaMLP |
| import torch.nn.functional as F |
| import torch.nn as nn |
| import os |
| import trimesh |
| import skimage |
| import numpy as np |
| import h5py |
| import torch.distributed as dist |
| from partfield.model.PVCNN.encoder_pc import TriPlanePC2Encoder, sample_triplane_feat |
| import json |
| import gc |
| import time |
| from plyfile import PlyData, PlyElement |
|
|
|
|
class Model(pl.LightningModule):
    """Lightning module wiring a point-cloud triplane encoder, a triplane
    transformer and MLP decoders together.

    The constructor builds the sub-modules from ``cfg``; :meth:`encode` runs
    the encoder -> transformer -> triplane-sampling path without gradients and
    returns per-point features.

    Args:
        cfg: Configuration object. The attributes read here are
            ``triplane_resolution``, ``triplane_channels_low``,
            ``triplane_channels_high``, ``use_pvcnnonly``, ``use_2d_feat``,
            ``pvcnn`` (only when ``use_pvcnnonly`` is truthy) and
            ``regress_2d_feat``.
    """

    # Number of leading entries along dim 2 of the transformer output that
    # ``encode`` splits off as the SDF planes; the remainder are part planes.
    _SDF_PLANE_CHANNELS = 64
    # Size of the last dimension of the tensor returned by ``encode``.
    # Presumably equals ``cfg.triplane_channels_high - 64`` -- TODO confirm.
    _POINT_FEAT_DIM = 448

    def __init__(self, cfg):
        super().__init__()

        # Must be called directly from __init__: Lightning collects the
        # constructor arguments from this frame.
        self.save_hyperparameters()
        self.cfg = cfg
        # Optimizer stepping is not delegated to Lightning's automatic loop.
        self.automatic_optimization = False
        self.triplane_resolution = cfg.triplane_resolution
        self.triplane_channels_low = cfg.triplane_channels_low

        self.triplane_transformer = TriplaneTransformer(
            input_dim=cfg.triplane_channels_low * 2,
            transformer_dim=1024,
            transformer_layers=6,
            transformer_heads=8,
            triplane_low_res=32,
            triplane_high_res=128,
            triplane_dim=cfg.triplane_channels_high,
        )
        self.sdf_decoder = VanillaMLP(
            input_dim=64,
            output_dim=1,
            out_activation="tanh",
            n_neurons=64,
            n_hidden_layers=6,
        )

        self.use_pvcnn = cfg.use_pvcnnonly
        self.use_2d_feat = cfg.use_2d_feat
        if self.use_pvcnn:
            # NOTE(review): the encoder is constructed with a hard-coded
            # "cuda" device string; confirm whether TriPlanePC2Encoder can
            # follow the module's device instead.
            self.pvcnn = TriPlanePC2Encoder(
                cfg.pvcnn,
                device="cuda",
                shape_min=-1,
                shape_length=2,
                use_2d_feat=self.use_2d_feat,
            )

        self.logit_scale = nn.Parameter(torch.tensor([1.0], requires_grad=True))
        self.grid_coord = get_grid_coord(256)
        self.mse_loss = torch.nn.MSELoss()
        # reduction='none' keeps the element-wise L1 values.
        self.l1_loss = torch.nn.L1Loss(reduction='none')

        if cfg.regress_2d_feat:
            self.feat_decoder = VanillaMLP(
                input_dim=64,
                output_dim=192,
                out_activation="GELU",
                n_neurons=64,
                n_hidden_layers=6,
            )

    @torch.no_grad()
    def encode(self, points):
        """Compute per-point features for a batch of point clouds.

        Args:
            points: Tensor whose first dimension is treated as the batch size
                and whose last dimension holds at least three entries; only
                ``points[..., :3]`` is used. Presumably shaped
                ``(N, P, >=3)`` with coordinates in ``[-1, 1]`` -- TODO
                confirm against callers.

        Returns:
            The sampled part-plane features reshaped to ``(N, -1, 448)``.

        Raises:
            AttributeError: If the model was built with
                ``cfg.use_pvcnnonly`` falsy, so no point-cloud encoder exists.
        """
        # The encoder is only created when cfg.use_pvcnnonly is truthy; fail
        # with an explicit message instead of an opaque attribute lookup error.
        encoder = getattr(self, "pvcnn", None)
        if encoder is None:
            raise AttributeError(
                "Model.encode requires the point-cloud encoder 'pvcnn', which "
                "is only built when cfg.use_pvcnnonly is enabled."
            )

        batch_size = points.shape[0]
        pcd = points[..., :3]

        # The same coordinates are passed as both arguments of the encoder.
        planes = self.triplane_transformer(encoder(pcd, pcd))

        # Split along dim 2: the leading 64 entries are the SDF planes (unused
        # here), everything after them forms the part planes.
        _sdf_planes, part_planes = torch.split(
            planes,
            [self._SDF_PLANE_CHANNELS, planes.shape[2] - self._SDF_PLANE_CHANNELS],
            dim=2,
        )

        # Put the query points on the device that holds the planes. A bare
        # ``.cuda()`` targets the *current* CUDA device, which need not be the
        # one the planes live on. Reshape/device moves keep the dtype of pcd.
        query_points = pcd.reshape(batch_size, -1, 3).to(part_planes.device)
        point_feat = sample_triplane_feat(part_planes, query_points)

        return point_feat.reshape(batch_size, -1, self._POINT_FEAT_DIM)