| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import os |
| import sys |
| import subprocess |
| import argparse |
| import logging |
| import glob |
| import numpy as np |
| from time import time |
| import urllib |
|
|
| |
| try: |
| import open3d as o3d |
| except ImportError: |
| raise ImportError( |
| "Please install open3d and scipy with `pip install open3d scipy`." |
| ) |
|
|
| import torch |
| import torch.nn as nn |
| import torch.utils.data |
| import torch.optim as optim |
| from torch.utils.data.sampler import Sampler |
|
|
| import MinkowskiEngine as ME |
|
|
|
|
class InfSampler(Sampler):
    """Samples elements randomly, without replacement.

    When the permutation is exhausted, a fresh one is generated, so the
    sampler never raises StopIteration -- useful for iteration-based
    (rather than epoch-based) training loops.

    Arguments:
        data_source (Dataset): dataset to sample from
        shuffle (bool): if True, draw a new random permutation each pass;
            otherwise yield indices in a fixed (descending) order.
    """

    def __init__(self, data_source, shuffle=False):
        self.data_source = data_source
        self.shuffle = shuffle
        self.reset_permutation()

    def reset_permutation(self):
        n = len(self.data_source)
        if self.shuffle:
            perm = torch.randperm(n)
        else:
            # Bug fix: previously the plain int `len(data_source)` reached
            # `.tolist()` when shuffle was False, raising AttributeError.
            perm = torch.arange(n)
        self._perm = perm.tolist()

    def __iter__(self):
        return self

    def __next__(self):
        # Refill instead of terminating, making the sampler infinite.
        if len(self._perm) == 0:
            self.reset_permutation()
        return self._perm.pop()

    def __len__(self):
        return len(self.data_source)
|
|
|
|
def resample_mesh(mesh_cad, density=1):
    r"""Sample a point cloud uniformly from the surface of a triangle mesh.

    Vectorized barycentric sampling; fast at the cost of some memory. See
    https://chrischoy.github.io/research/barycentric-coordinate-for-mesh-sampling/

    param mesh_cad: low-polygon triangle mesh in o3d.geometry.TriangleMesh
    param density: density of the point cloud per unit area
    return: (N, 3) numpy array of sampled surface points

    Reference:
        [1] Barycentric coordinate system
        \begin{align}
            P = (1 - \sqrt{r_1})A + \sqrt{r_1} (1 - r_2) B + \sqrt{r_1} r_2 C
        \end{align}
    """
    tri = np.array(mesh_cad.triangles).astype(int)
    vtx = np.array(mesh_cad.vertices)

    # Cross product of two edges: its magnitude is twice each face's area.
    edge_cross = np.cross(
        vtx[tri[:, 0], :] - vtx[tri[:, 2], :],
        vtx[tri[:, 1], :] - vtx[tri[:, 2], :],
    )
    face_areas = np.sqrt(np.sum(edge_cross ** 2, 1))

    n_samples = (np.sum(face_areas) * density).astype(int)

    # Allocate samples per face proportionally to area; `ceil` overshoots,
    # so randomly trim the surplus to approach the requested total.
    n_samples_per_face = np.ceil(density * face_areas).astype(int)
    floor_num = np.sum(n_samples_per_face) - n_samples
    if floor_num > 0:
        candidates = np.where(n_samples_per_face > 0)[0]
        trimmed = np.random.choice(candidates, floor_num, replace=True)
        n_samples_per_face[trimmed] -= 1

    n_samples = np.sum(n_samples_per_face)

    # Expand the per-face counts into one face index per sample.
    sample_face_idx = np.repeat(
        np.arange(len(n_samples_per_face)), n_samples_per_face
    )

    # One (r1, r2) barycentric pair per sample.
    r = np.random.rand(n_samples, 2)
    A = vtx[tri[sample_face_idx, 0], :]
    B = vtx[tri[sample_face_idx, 1], :]
    C = vtx[tri[sample_face_idx, 2], :]
    r1_sqrt = np.sqrt(r[:, 0:1])
    P = (1 - r1_sqrt) * A + r1_sqrt * (1 - r[:, 1:]) * B + r1_sqrt * r[:, 1:] * C

    return P
|
|
|
|
# Fixed 3x3 matrix applied via `rotate(M)` in `visualize` to orient both the
# reconstructed and the ground-truth point clouds before rendering.
M = np.array(
    [
        [0.80656762, -0.5868724, -0.07091862],
        [0.3770505, 0.418344, 0.82632997],
        [-0.45528188, -0.6932309, 0.55870326],
    ]
)

# This example requires open3d >= 0.8; the minor version is parsed from the
# "major.minor.patch" version string.
assert (
    int(o3d.__version__.split(".")[1]) >= 8
), f"Requires open3d version >= 0.8, the current version is {o3d.__version__}"

# Fetch the pruned ModelNet40 dataset on first run; the shell script performs
# the download and extraction into ./ModelNet40.
if not os.path.exists("ModelNet40"):
    logging.info("Downloading the pruned ModelNet40 dataset...")
    subprocess.run(["sh", "./examples/download_modelnet40.sh"])
|
|
|
|
| |
| |
| |
def PointCloud(points, colors=None):
    """Build an open3d PointCloud from an (N, 3) array of points and,
    optionally, an (N, 3) array of per-point colors."""
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(points)
    if colors is None:
        return cloud
    cloud.colors = o3d.utility.Vector3dVector(colors)
    return cloud
|
|
|
|
def collate_pointcloud_fn(list_data):
    """Collate a batch of (coords, xyz, idx) samples from ModelNet40Dataset.

    Bug fix: `ModelNet40Dataset.__getitem__` returns None for meshes with
    insufficient sampling density, and `zip(*list_data)` raised TypeError on
    such entries; they are now dropped from the batch first.

    Returns:
        dict with:
            "coords": tuple of per-sample quantized coordinate arrays,
            "xyzs":   list of per-sample float tensors of continuous points,
            "labels": LongTensor of per-sample indices.
    """
    # Drop skipped samples (None) before unpacking.
    list_data = [data for data in list_data if data is not None]
    coords, feats, labels = list(zip(*list_data))

    return {
        "coords": coords,
        "xyzs": [torch.from_numpy(feat).float() for feat in feats],
        "labels": torch.LongTensor(labels),
    }
|
|
|
|
class ModelNet40Dataset(torch.utils.data.Dataset):
    """ModelNet40 "chair/train" meshes resampled into dense point clouds.

    Each item: the mesh is normalized into the unit cube, surface-sampled
    with `resample_mesh`, scaled by `config.resolution`, and sparse-quantized
    with MinkowskiEngine. Resampled clouds are cached in memory on first use.

    NOTE(review): `phase` only affects log messages -- the chair/train split
    is loaded regardless. Presumably intentional for this reconstruction
    (overfitting) demo; confirm before reusing elsewhere.
    """

    def __init__(self, phase, transform=None, config=None):
        self.phase = phase
        self.files = []
        # idx -> resampled xyz array, filled lazily by __getitem__.
        self.cache = {}
        self.data_objects = []
        self.transform = transform
        # Voxel-grid resolution used to scale/quantize coordinates.
        self.resolution = config.resolution
        # Last 10%-milestone logged, to avoid repeating progress messages.
        self.last_cache_percent = 0

        self.root = "./ModelNet40"
        fnames = glob.glob(os.path.join(self.root, "chair/train/*.off"))
        fnames = sorted([os.path.relpath(fname, self.root) for fname in fnames])
        self.files = fnames
        assert len(self.files) > 0, "No file loaded"
        logging.info(
            f"Loading the subset {phase} from {self.root} with {len(self.files)} files"
        )
        # Surface-sampling density (points per unit area) for resample_mesh.
        self.density = 30000

        # Silence per-mesh open3d warnings emitted while reading .off files.
        o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Error)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return (quantized coords, selected continuous xyz, idx), or None
        when the mesh yields too few surface samples (callers must handle the
        None -- e.g. by filtering it out in the collate function)."""
        mesh_file = os.path.join(self.root, self.files[idx])
        if idx in self.cache:
            xyz = self.cache[idx]
        else:
            # Load the mesh and normalize it into the unit cube, preserving
            # the aspect ratio by dividing by the largest extent.
            assert os.path.exists(mesh_file)
            pcd = o3d.io.read_triangle_mesh(mesh_file)
            vertices = np.asarray(pcd.vertices)
            vmax = vertices.max(0, keepdims=True)
            vmin = vertices.min(0, keepdims=True)
            pcd.vertices = o3d.utility.Vector3dVector(
                (vertices - vmin) / (vmax - vmin).max()
            )

            # Oversample points on the surface and cache the result.
            xyz = resample_mesh(pcd, density=self.density)
            self.cache[idx] = xyz
            cache_percent = int((len(self.cache) / len(self)) * 100)
            # Log caching progress once per 10% milestone.
            if (
                cache_percent > 0
                and cache_percent % 10 == 0
                and cache_percent != self.last_cache_percent
            ):
                logging.info(
                    f"Cached {self.phase}: {len(self.cache)} / {len(self)}: {cache_percent}%"
                )
                self.last_cache_percent = cache_percent

        # Constant feature of 1 per point (occupancy only; no color/normals).
        feats = np.ones((len(xyz), 1))

        # Reject under-sampled meshes; see the docstring note about None.
        if len(xyz) < 1000:
            logging.info(
                f"Skipping {mesh_file}: does not have sufficient CAD sampling density after resampling: {len(xyz)}."
            )
            return None

        if self.transform:
            xyz, feats = self.transform(xyz, feats)

        # Scale to the voxel grid and keep one point per occupied voxel.
        xyz = xyz * self.resolution
        coords, inds = ME.utils.sparse_quantize(xyz, return_index=True)

        return (coords, xyz[inds], idx)
|
|
|
def make_data_loader(
    phase, augment_data, batch_size, shuffle, num_workers, repeat, config
):
    """Build a DataLoader over ModelNet40Dataset.

    With `repeat` set, an InfSampler drives the loader so it never exhausts
    (iteration-based training); otherwise standard epoch iteration with the
    given `shuffle` flag is used.

    NOTE(review): `augment_data` is currently unused -- confirm intent.
    """
    dataset = ModelNet40Dataset(phase, config=config)

    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=collate_pointcloud_fn,
        pin_memory=False,
        drop_last=False,
    )
    if repeat:
        loader_kwargs["sampler"] = InfSampler(dataset, shuffle)
    else:
        loader_kwargs["shuffle"] = shuffle

    return torch.utils.data.DataLoader(dataset, **loader_kwargs)
|
|
|
|
# Route log records to stdout with a short-hostname prefix on every line.
ch = logging.StreamHandler(sys.stdout)
logging.getLogger().setLevel(logging.INFO)
logging.basicConfig(
    format=os.uname()[1].split(".")[0] + " %(asctime)s %(message)s",
    datefmt="%m/%d %H:%M:%S",
    handlers=[ch],
)

# Command-line configuration shared by training and visualization.
parser = argparse.ArgumentParser()
parser.add_argument("--resolution", type=int, default=128)
parser.add_argument("--max_iter", type=int, default=30000)
parser.add_argument("--val_freq", type=int, default=1000)
parser.add_argument("--batch_size", default=16, type=int)
parser.add_argument("--lr", default=1e-2, type=float)
parser.add_argument("--momentum", type=float, default=0.9)
parser.add_argument("--weight_decay", type=float, default=1e-4)
parser.add_argument("--num_workers", type=int, default=1)
parser.add_argument("--stat_freq", type=int, default=50)
parser.add_argument("--weights", type=str, default="modelnet_reconstruction.pth")
parser.add_argument("--load_optimizer", type=str, default="true")
parser.add_argument("--eval", action="store_true")
parser.add_argument("--max_visualization", type=int, default=4)
|
|
| |
| |
| |
|
|
|
|
class GenerativeNet(nn.Module):
    """Sparse generative decoder that grows a point cloud from a single seed
    coordinate per sample.

    Six upsampling stages; each starts with a stride-2 generative transposed
    convolution (block1 contains two of them), for 2^7 = 128x total
    upsampling -- matching the stride-128 seed coordinates fed in `train`.
    Every stage has a 1x1 classifier head (`blockN_cls`) predicting voxel
    occupancy; voxels classified as empty are pruned before the next stage so
    the coordinate set stays sparse.
    """

    # Channel widths: block1 maps in_nchannel -> ch[0] -> ch[1]; each later
    # block N maps ch[N-1] -> ch[N].
    CHANNELS = [1024, 512, 256, 128, 64, 32, 16]

    def __init__(self, resolution, in_nchannel=512):
        nn.Module.__init__(self)

        self.resolution = resolution

        ch = self.CHANNELS

        # Block 1: two stride-2 generative upsamplings (4x), each followed by
        # a 3x3x3 convolution with BatchNorm + ELU.
        self.block1 = nn.Sequential(
            ME.MinkowskiGenerativeConvolutionTranspose(
                in_nchannel, ch[0], kernel_size=2, stride=2, dimension=3
            ),
            ME.MinkowskiBatchNorm(ch[0]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(ch[0], ch[0], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(ch[0]),
            ME.MinkowskiELU(),
            ME.MinkowskiGenerativeConvolutionTranspose(
                ch[0], ch[1], kernel_size=2, stride=2, dimension=3
            ),
            ME.MinkowskiBatchNorm(ch[1]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(ch[1], ch[1], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(ch[1]),
            ME.MinkowskiELU(),
        )

        # Occupancy classifier for block 1's output (1 logit per voxel).
        self.block1_cls = ME.MinkowskiConvolution(
            ch[1], 1, kernel_size=1, bias=True, dimension=3
        )

        # Block 2: one stride-2 upsampling + conv.
        self.block2 = nn.Sequential(
            ME.MinkowskiGenerativeConvolutionTranspose(
                ch[1], ch[2], kernel_size=2, stride=2, dimension=3
            ),
            ME.MinkowskiBatchNorm(ch[2]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(ch[2], ch[2], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(ch[2]),
            ME.MinkowskiELU(),
        )

        self.block2_cls = ME.MinkowskiConvolution(
            ch[2], 1, kernel_size=1, bias=True, dimension=3
        )

        # Block 3: one stride-2 upsampling + conv.
        self.block3 = nn.Sequential(
            ME.MinkowskiGenerativeConvolutionTranspose(
                ch[2], ch[3], kernel_size=2, stride=2, dimension=3
            ),
            ME.MinkowskiBatchNorm(ch[3]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(ch[3], ch[3], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(ch[3]),
            ME.MinkowskiELU(),
        )

        self.block3_cls = ME.MinkowskiConvolution(
            ch[3], 1, kernel_size=1, bias=True, dimension=3
        )

        # Block 4: one stride-2 upsampling + conv.
        self.block4 = nn.Sequential(
            ME.MinkowskiGenerativeConvolutionTranspose(
                ch[3], ch[4], kernel_size=2, stride=2, dimension=3
            ),
            ME.MinkowskiBatchNorm(ch[4]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(ch[4], ch[4], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(ch[4]),
            ME.MinkowskiELU(),
        )

        self.block4_cls = ME.MinkowskiConvolution(
            ch[4], 1, kernel_size=1, bias=True, dimension=3
        )

        # Block 5: one stride-2 upsampling + conv.
        self.block5 = nn.Sequential(
            ME.MinkowskiGenerativeConvolutionTranspose(
                ch[4], ch[5], kernel_size=2, stride=2, dimension=3
            ),
            ME.MinkowskiBatchNorm(ch[5]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(ch[5], ch[5], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(ch[5]),
            ME.MinkowskiELU(),
        )

        self.block5_cls = ME.MinkowskiConvolution(
            ch[5], 1, kernel_size=1, bias=True, dimension=3
        )

        # Block 6: final stride-2 upsampling + conv to full resolution.
        self.block6 = nn.Sequential(
            ME.MinkowskiGenerativeConvolutionTranspose(
                ch[5], ch[6], kernel_size=2, stride=2, dimension=3
            ),
            ME.MinkowskiBatchNorm(ch[6]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(ch[6], ch[6], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(ch[6]),
            ME.MinkowskiELU(),
        )

        self.block6_cls = ME.MinkowskiConvolution(
            ch[6], 1, kernel_size=1, bias=True, dimension=3
        )

        # Removes coordinates whose keep-mask is False between stages.
        self.pruning = ME.MinkowskiPruning()

    @torch.no_grad()
    def get_target(self, out, target_key, kernel_size=1):
        """Boolean mask over `out`'s coordinates: True where the coordinate
        coincides with a ground-truth coordinate at this tensor stride."""
        target = torch.zeros(len(out), dtype=torch.bool, device=out.device)
        cm = out.coordinate_manager
        # Downsample the full-resolution target coordinates to out's stride.
        strided_target_key = cm.stride(
            target_key,
            out.tensor_stride[0],
        )
        # Map out's coordinates onto the strided targets; matched input
        # indices are marked as positives.
        kernel_map = cm.kernel_map(
            out.coordinate_map_key,
            strided_target_key,
            kernel_size=kernel_size,
            region_type=1,
        )
        for k, curr_in in kernel_map.items():
            target[curr_in[0].long()] = 1
        return target

    def valid_batch_map(self, batch_map):
        # A batch map is valid only if every batch element has coordinates.
        for b in batch_map:
            if len(b) == 0:
                return False
        return True

    def forward(self, z, target_key):
        """Decode latent sparse tensor `z` through six upsample/classify/
        prune stages.

        Returns (per-stage occupancy logits, per-stage boolean targets,
        final pruned sparse tensor).
        """
        out_cls, targets = [], []

        # Stage 1: upsample, classify occupancy, and collect the training
        # target at this stride. (Stages 2-6 repeat the same pattern.)
        out1 = self.block1(z)
        out1_cls = self.block1_cls(out1)
        target = self.get_target(out1, target_key)
        targets.append(target)
        out_cls.append(out1_cls)
        keep1 = (out1_cls.F > 0).squeeze()

        # During training also keep all ground-truth voxels (`+` on bool
        # tensors is logical OR) so deeper stages receive supervision even
        # when the classifier is still wrong.
        if self.training:
            keep1 += target

        # Remove voxels predicted (and, at eval time, not forced) to be empty.
        out1 = self.pruning(out1, keep1)

        # Stage 2.
        out2 = self.block2(out1)
        out2_cls = self.block2_cls(out2)
        target = self.get_target(out2, target_key)
        targets.append(target)
        out_cls.append(out2_cls)
        keep2 = (out2_cls.F > 0).squeeze()

        if self.training:
            keep2 += target

        out2 = self.pruning(out2, keep2)

        # Stage 3.
        out3 = self.block3(out2)
        out3_cls = self.block3_cls(out3)
        target = self.get_target(out3, target_key)
        targets.append(target)
        out_cls.append(out3_cls)
        keep3 = (out3_cls.F > 0).squeeze()

        if self.training:
            keep3 += target

        out3 = self.pruning(out3, keep3)

        # Stage 4.
        out4 = self.block4(out3)
        out4_cls = self.block4_cls(out4)
        target = self.get_target(out4, target_key)
        targets.append(target)
        out_cls.append(out4_cls)
        keep4 = (out4_cls.F > 0).squeeze()

        if self.training:
            keep4 += target

        out4 = self.pruning(out4, keep4)

        # Stage 5.
        out5 = self.block5(out4)
        out5_cls = self.block5_cls(out5)
        target = self.get_target(out5, target_key)
        targets.append(target)
        out_cls.append(out5_cls)
        keep5 = (out5_cls.F > 0).squeeze()

        if self.training:
            keep5 += target

        out5 = self.pruning(out5, keep5)

        # Stage 6 (final): note no target is forced into keep6 here, so the
        # returned geometry reflects only the classifier's predictions.
        out6 = self.block6(out5)
        out6_cls = self.block6_cls(out6)
        target = self.get_target(out6, target_key)
        targets.append(target)
        out_cls.append(out6_cls)
        keep6 = (out6_cls.F > 0).squeeze()

        out6 = self.pruning(out6, keep6)

        return out_cls, targets, out6
|
|
|
|
def train(net, dataloader, device, config):
    """Train the generative network to reconstruct the dataset point clouds.

    The latent input is a one-hot vector identifying the training sample
    (one channel per sample, hence `in_nchannel = len(dataset)`), so the
    network memorizes the dataset -- this is a reconstruction demo, not a
    generalizing model.

    A checkpoint (weights, optimizer, scheduler, iteration) is written to
    `config.weights` every `config.val_freq` iterations; the LR is decayed
    at the same cadence.
    """
    in_nchannel = len(dataloader.dataset)

    optimizer = optim.SGD(
        net.parameters(),
        lr=config.lr,
        momentum=config.momentum,
        weight_decay=config.weight_decay,
    )
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.95)

    crit = nn.BCEWithLogitsLoss()

    net.train()
    train_iter = iter(dataloader)
    # `get_last_lr()` replaces the deprecated `get_lr()`.
    logging.info(f"LR: {scheduler.get_last_lr()}")
    for i in range(config.max_iter):

        s = time()
        # Bug fix: `train_iter.next()` is the removed Python-2-style iterator
        # method; modern DataLoader iterators only implement `__next__`.
        data_dict = next(train_iter)
        d = time() - s

        optimizer.zero_grad()
        # One seed coordinate per batch element at the coarsest stride.
        init_coords = torch.zeros((config.batch_size, 4), dtype=torch.int)
        init_coords[:, 0] = torch.arange(config.batch_size)

        # One-hot latent code selecting which sample to reconstruct.
        in_feat = torch.zeros((config.batch_size, in_nchannel))
        in_feat[torch.arange(config.batch_size), data_dict["labels"]] = 1

        sin = ME.SparseTensor(
            features=in_feat,
            coordinates=init_coords,
            tensor_stride=config.resolution,
            device=device,
        )

        # Register the ground-truth coordinates with the same coordinate
        # manager so per-stage targets can be derived in the network.
        cm = sin.coordinate_manager
        target_key, _ = cm.insert_and_map(
            ME.utils.batched_coordinates(data_dict["xyzs"]).to(device),
            string_id="target",
        )

        # Per-level occupancy logits and boolean targets; the BCE losses are
        # averaged over the pyramid levels.
        out_cls, targets, sout = net(sin, target_key)
        num_layers, loss = len(out_cls), 0
        for out_cl, target in zip(out_cls, targets):
            curr_loss = crit(out_cl.F.squeeze(), target.type(out_cl.F.dtype).to(device))
            loss += curr_loss / num_layers

        loss.backward()
        optimizer.step()
        t = time() - s

        if i % config.stat_freq == 0:
            logging.info(
                f"Iter: {i}, Loss: {loss.item():.3e}, Depths: {len(out_cls)} Data Loading Time: {d:.3e}, Tot Time: {t:.3e}"
            )

        if i % config.val_freq == 0 and i > 0:
            torch.save(
                {
                    "state_dict": net.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                    "curr_iter": i,
                },
                config.weights,
            )

            # Decay the LR at checkpoint cadence; stepping a 0.95 exponential
            # decay every iteration would collapse the LR within ~1k steps.
            scheduler.step()
            logging.info(f"LR: {scheduler.get_last_lr()}")

            net.train()
|
|
|
|
def visualize(net, dataloader, device, config):
    """Reconstruct samples from their one-hot latent codes and render each
    reconstruction next to its ground-truth point cloud with open3d.

    Stops after `config.max_visualization` batches.
    """
    in_nchannel = len(dataloader.dataset)
    net.eval()
    crit = nn.BCEWithLogitsLoss()
    n_vis = 0

    for data_dict in dataloader:
        # One seed coordinate per batch element at the coarsest stride.
        init_coords = torch.zeros((config.batch_size, 4), dtype=torch.int)
        init_coords[:, 0] = torch.arange(config.batch_size)

        # One-hot latent code selecting which sample to reconstruct.
        in_feat = torch.zeros((config.batch_size, in_nchannel))
        in_feat[torch.arange(config.batch_size), data_dict["labels"]] = 1

        sin = ME.SparseTensor(
            features=in_feat,
            coordinates=init_coords,
            tensor_stride=config.resolution,
            device=device,
        )

        # Register ground-truth coordinates so per-stage targets can be made.
        cm = sin.coordinate_manager
        target_key, _ = cm.insert_and_map(
            ME.utils.batched_coordinates(data_dict["xyzs"]).to(device),
            string_id="target",
        )

        out_cls, targets, sout = net(sin, target_key)
        num_layers = len(out_cls)
        loss = 0
        for out_cl, target in zip(out_cls, targets):
            level_loss = crit(out_cl.F.squeeze(), target.type(out_cl.F.dtype).to(device))
            loss += level_loss / num_layers

        # Render reconstruction (right) and ground truth (left) side by side.
        batch_coords, batch_feats = sout.decomposed_coordinates_and_features
        for b, (coords, feats) in enumerate(zip(batch_coords, batch_feats)):
            recon = PointCloud(coords.cpu())
            recon.estimate_normals()
            recon.translate([0.6 * config.resolution, 0, 0])
            recon.rotate(M)
            truth = PointCloud(data_dict["xyzs"][b])
            truth.translate([-0.6 * config.resolution, 0, 0])
            truth.estimate_normals()
            truth.rotate(M)
            o3d.visualization.draw_geometries([recon, truth])

            n_vis += 1
            if n_vis > config.max_visualization:
                return
|
|
|
|
| if __name__ == "__main__": |
| config = parser.parse_args() |
| logging.info(config) |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
| dataloader = make_data_loader( |
| "val", |
| augment_data=True, |
| batch_size=config.batch_size, |
| shuffle=True, |
| num_workers=config.num_workers, |
| repeat=True, |
| config=config, |
| ) |
| in_nchannel = len(dataloader.dataset) |
|
|
| net = GenerativeNet(config.resolution, in_nchannel=in_nchannel) |
| net.to(device) |
|
|
| logging.info(net) |
|
|
| if not config.eval: |
| train(net, dataloader, device, config) |
| else: |
| if not os.path.exists(config.weights): |
| logging.info(f"Downloaing pretrained weights. This might take a while...") |
| urllib.request.urlretrieve( |
| "https://bit.ly/36d9m1n", filename=config.weights |
| ) |
|
|
| logging.info(f"Loading weights from {config.weights}") |
| checkpoint = torch.load(config.weights) |
| net.load_state_dict(checkpoint["state_dict"]) |
|
|
| visualize(net, dataloader, device, config) |
|
|