# GuideFlow3D/third_party/PartField/partfield/model_trainer_pvcnn_only_demo.py
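"""Lightning demo module for PartField PVCNN-only inference.

Encodes an input point cloud into triplane features with TriPlanePC2Encoder,
refines them with TriplaneTransformer, and returns the high-resolution
part-feature planes for a single object at predict time.
"""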
import torch
import lightning.pytorch as pl
import torch.nn as nn
import os
import trimesh
import numpy as np
from torch.utils.data import DataLoader
from third_party.PartField.partfield.model.PVCNN.encoder_pc import TriPlanePC2Encoder, sample_triplane_feat
from third_party.PartField.partfield.model.triplane import TriplaneTransformer, get_grid_coord #, sample_from_planes, Voxel2Triplane
from third_party.PartField.partfield.model.model_utils import VanillaMLP
from third_party.PartField.partfield.dataloader import Demo_Dataset

class Model(pl.LightningModule):
    def __init__(self, cfg, obj_path):
        super().__init__()
        self.obj_path = obj_path
        self.save_hyperparameters()
        self.cfg = cfg
        self.automatic_optimization = False

        self.triplane_resolution = cfg.triplane_resolution
        self.triplane_channels_low = cfg.triplane_channels_low
        # Refines the low-resolution triplane from the point-cloud encoder
        # into a higher-resolution, higher-dimensional triplane.
        self.triplane_transformer = TriplaneTransformer(
            input_dim=cfg.triplane_channels_low * 2,
            transformer_dim=1024,
            transformer_layers=6,
            transformer_heads=8,
            triplane_low_res=32,
            triplane_high_res=128,
            triplane_dim=cfg.triplane_channels_high,
        )
        # Decodes a signed-distance value from the first 64 triplane channels.
        self.sdf_decoder = VanillaMLP(input_dim=64,
                                      output_dim=1,
                                      out_activation="tanh",
                                      n_neurons=64,
                                      n_hidden_layers=6)
        self.use_pvcnn = cfg.use_pvcnnonly
        self.use_2d_feat = cfg.use_2d_feat
        if self.use_pvcnn:
            # Point-cloud encoder that produces the initial triplane features.
            self.pvcnn = TriPlanePC2Encoder(
                cfg.pvcnn,
                device="cuda",
                shape_min=-1,
                shape_length=2,
                use_2d_feat=self.use_2d_feat)
        self.logit_scale = nn.Parameter(torch.tensor([1.0], requires_grad=True))
        self.grid_coord = get_grid_coord(256)

        self.mse_loss = torch.nn.MSELoss()
        self.l1_loss = torch.nn.L1Loss(reduction='none')

        if cfg.regress_2d_feat:
            # Optional head that regresses 192-dim 2D features from the triplane.
            self.feat_decoder = VanillaMLP(input_dim=64,
                                           output_dim=192,
                                           out_activation="GELU",
                                           n_neurons=64,
                                           n_hidden_layers=6)

    def predict_dataloader(self):
        dataset = Demo_Dataset(self.obj_path)
        dataloader = DataLoader(dataset,
                                num_workers=self.cfg.dataset.val_num_workers,
                                batch_size=self.cfg.dataset.val_batch_size,
                                shuffle=False,
                                pin_memory=True,
                                drop_last=False)
        return dataloader

    @torch.no_grad()
    def predict_step(self, batch, batch_idx):
        N = batch['pc'].shape[0]
        assert N == 1  # the demo pipeline processes one object at a time

        # Encode the point cloud into a low-resolution triplane, then refine it.
        pc_feat = self.pvcnn(batch['pc'], batch['pc'])
        planes = self.triplane_transformer(pc_feat)
        # The first 64 channels form the SDF triplane; the rest hold part features.
        sdf_planes, part_planes = torch.split(planes, [64, planes.shape[2] - 64], dim=2)

        def sample_points(vertices, faces, n_point_per_face):
            # Generate random barycentric coordinates; taking the square root of
            # the first coordinate makes the samples uniform over each triangle.
            # Borrowed from Kaolin:
            # https://github.com/NVIDIAGameWorks/kaolin/blob/master/kaolin/ops/mesh/trianglemesh.py#L43
            n_f = faces.shape[0]
            u = torch.sqrt(torch.rand((n_f, n_point_per_face, 1),
                                      device=vertices.device,
                                      dtype=vertices.dtype))
            v = torch.rand((n_f, n_point_per_face, 1),
                           device=vertices.device,
                           dtype=vertices.dtype)
            w0 = 1 - u
            w1 = u * (1 - v)
            w2 = u * v
            # Gather the three vertices of every face and blend them.
            face_v_0 = torch.index_select(vertices, 0, faces[:, 0].reshape(-1))
            face_v_1 = torch.index_select(vertices, 0, faces[:, 1].reshape(-1))
            face_v_2 = torch.index_select(vertices, 0, faces[:, 2].reshape(-1))
            points = w0 * face_v_0.unsqueeze(dim=1) + w1 * face_v_1.unsqueeze(dim=1) + w2 * face_v_2.unsqueeze(dim=1)
            return points

        def sample_and_mean_memory_save_version(part_planes, tensor_vertices, n_point_per_face):
            # Query the triplane in chunks of n_sample_each points to avoid OOM,
            # then average the per-point features over each face's samples.
            n_sample_each = self.cfg.n_sample_each
            n_v = tensor_vertices.shape[1]
            n_sample = (n_v + n_sample_each - 1) // n_sample_each  # ceil division, avoids an empty tail chunk
            all_sample = []
            for i_sample in range(n_sample):
                sampled_feature = sample_triplane_feat(
                    part_planes,
                    tensor_vertices[:, i_sample * n_sample_each: (i_sample + 1) * n_sample_each])
                assert sampled_feature.shape[1] % n_point_per_face == 0
                sampled_feature = sampled_feature.reshape(1, -1, n_point_per_face, sampled_feature.shape[-1])
                sampled_feature = torch.mean(sampled_feature, dim=-2)
                all_sample.append(sampled_feature)
            return torch.cat(all_sample, dim=1)
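
        # Sketch of how the two helpers above would combine to extract per-face
        # part features (assumes a trimesh mesh for the current object and an
        # n_point_per_face entry in cfg; this demo path instead returns the raw
        # planes and leaves the helpers unused):
        #
        #   mesh = trimesh.load(self.obj_path, force='mesh')
        #   vertices = torch.tensor(mesh.vertices, dtype=torch.float32, device=self.device)
        #   faces = torch.tensor(mesh.faces, dtype=torch.long, device=self.device)
        #   points = sample_points(vertices, faces, self.cfg.n_point_per_face)  # (n_f, k, 3)
        #   tensor_vertices = points.reshape(1, -1, 3)                          # one flat query batch
        #   face_feat = sample_and_mean_memory_save_version(
        #       part_planes, tensor_vertices, self.cfg.n_point_per_face)        # (1, n_f, C)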

        # Move the part-feature planes to CPU for saving downstream.
        part_planes = part_planes.cpu().numpy()
        return part_planes, batch['uid'][0]
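
if __name__ == "__main__":
    # Minimal inference sketch, assuming an OmegaConf-compatible YAML config
    # exposing the attributes used above (triplane_*, pvcnn, dataset.*, ...);
    # the config and mesh paths below are placeholders, not repo-verified.
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("configs/final/demo.yaml")
    model = Model(cfg, obj_path="assets/example.obj")
    trainer = pl.Trainer(accelerator="gpu", devices=1)
    # predict_dataloader() is defined on the module, so no dataloader is passed.
    predictions = trainer.predict(model)  # list of (part_planes, uid) tuples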