# Copyright (c) Chris Choy (chrischoy@ai.stanford.edu).
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
# of the Software, and to permit persons to whom the Software is furnished to do
# so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural
# Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part
# of the code.
import os
import sys
import subprocess
import argparse
import logging
import numpy as np
from time import time
import urllib.request

# Must be imported before large libs
try:
    import open3d as o3d
except ImportError:
    raise ImportError("Please install open3d with `pip install open3d`.")

import torch
import torch.nn as nn
import torch.utils.data
import torch.optim as optim

import MinkowskiEngine as ME

from examples.reconstruction import ModelNet40Dataset, InfSampler

# Fixed rotation applied to the point clouds before visualization.
M = np.array(
    [
        [0.80656762, -0.5868724, -0.07091862],
        [0.3770505, 0.418344, 0.82632997],
        [-0.45528188, -0.6932309, 0.55870326],
    ]
)

assert (
    int(o3d.__version__.split(".")[1]) >= 8
), f"Requires open3d version >= 0.8, the current version is {o3d.__version__}"

if not os.path.exists("ModelNet40"):
    logging.info("Downloading the pruned ModelNet40 dataset...")
    subprocess.run(["sh", "./examples/download_modelnet40.sh"])

###############################################################################
# Utility functions
###############################################################################
def PointCloud(points, colors=None):
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    if colors is not None:
        pcd.colors = o3d.utility.Vector3dVector(colors)
    return pcd


class CollationAndTransformation:
    def __init__(self, resolution):
        self.resolution = resolution

    def random_crop(self, coords_list):
        # Keep only the voxels in the first third of the x-axis to create the
        # partial input that the network must complete.
        crop_coords_list = []
        for coords in coords_list:
            sel = coords[:, 0] < self.resolution / 3
            crop_coords_list.append(coords[sel])
        return crop_coords_list

    def __call__(self, list_data):
        coords, feats, labels = list(zip(*list_data))
        coords = self.random_crop(coords)

        # Concatenate all lists
        return {
            "coords": ME.utils.batched_coordinates(coords),
            "xyzs": [torch.from_numpy(feat).float() for feat in feats],
            "cropped_coords": coords,
            "labels": torch.LongTensor(labels),
        }


def make_data_loader(
    phase, augment_data, batch_size, shuffle, num_workers, repeat, config
):
    dset = ModelNet40Dataset(phase, config=config)

    args = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "collate_fn": CollationAndTransformation(config.resolution),
        "pin_memory": False,
        "drop_last": False,
    }

    if repeat:
        args["sampler"] = InfSampler(dset, shuffle)
    else:
        args["shuffle"] = shuffle

    loader = torch.utils.data.DataLoader(dset, **args)
    return loader
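

# A minimal illustration (not executed by this script) of what the collate
# function above produces: ME.utils.batched_coordinates prepends a batch-index
# column to each coordinate array before concatenating them, e.g.
#
#   ME.utils.batched_coordinates(
#       [np.array([[0, 0, 0], [0, 0, 1]]), np.array([[2, 2, 2]])]
#   )
#   # tensor([[0, 0, 0, 0],
#   #         [0, 0, 0, 1],
#   #         [1, 2, 2, 2]], dtype=torch.int32)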
ch = logging.StreamHandler(sys.stdout)
logging.getLogger().setLevel(logging.INFO)
logging.basicConfig(
    format="%(asctime)s %(message)s",
    datefmt="%m/%d %H:%M:%S",
    handlers=[ch],
)

parser = argparse.ArgumentParser()
parser.add_argument("--resolution", type=int, default=128)
parser.add_argument("--max_iter", type=int, default=30000)
parser.add_argument("--val_freq", type=int, default=1000)
parser.add_argument("--batch_size", default=16, type=int)
parser.add_argument("--lr", default=1e-2, type=float)
parser.add_argument("--momentum", type=float, default=0.9)
parser.add_argument("--weight_decay", type=float, default=1e-4)
parser.add_argument("--num_workers", type=int, default=1)
parser.add_argument("--stat_freq", type=int, default=50)
parser.add_argument("--weights", type=str, default="modelnet_completion.pth")
parser.add_argument("--load_optimizer", type=str, default="true")
parser.add_argument("--eval", action="store_true")
parser.add_argument("--max_visualization", type=int, default=4)

###############################################################################
# End of utility functions
###############################################################################


class CompletionNet(nn.Module):

    ENC_CHANNELS = [16, 32, 64, 128, 256, 512, 1024]
    DEC_CHANNELS = [16, 32, 64, 128, 256, 512, 1024]

    def __init__(self, resolution, in_nchannel=512):
        nn.Module.__init__(self)

        self.resolution = resolution

        enc_ch = self.ENC_CHANNELS
        dec_ch = self.DEC_CHANNELS

        # Encoder
        self.enc_block_s1 = nn.Sequential(
            ME.MinkowskiConvolution(1, enc_ch[0], kernel_size=3, stride=1, dimension=3),
            ME.MinkowskiBatchNorm(enc_ch[0]),
            ME.MinkowskiELU(),
        )

        self.enc_block_s1s2 = nn.Sequential(
            ME.MinkowskiConvolution(
                enc_ch[0], enc_ch[1], kernel_size=2, stride=2, dimension=3
            ),
            ME.MinkowskiBatchNorm(enc_ch[1]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(enc_ch[1], enc_ch[1], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(enc_ch[1]),
            ME.MinkowskiELU(),
        )

        self.enc_block_s2s4 = nn.Sequential(
            ME.MinkowskiConvolution(
                enc_ch[1], enc_ch[2], kernel_size=2, stride=2, dimension=3
            ),
            ME.MinkowskiBatchNorm(enc_ch[2]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(enc_ch[2], enc_ch[2], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(enc_ch[2]),
            ME.MinkowskiELU(),
        )

        self.enc_block_s4s8 = nn.Sequential(
            ME.MinkowskiConvolution(
                enc_ch[2], enc_ch[3], kernel_size=2, stride=2, dimension=3
            ),
            ME.MinkowskiBatchNorm(enc_ch[3]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(enc_ch[3], enc_ch[3], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(enc_ch[3]),
            ME.MinkowskiELU(),
        )

        self.enc_block_s8s16 = nn.Sequential(
            ME.MinkowskiConvolution(
                enc_ch[3], enc_ch[4], kernel_size=2, stride=2, dimension=3
            ),
            ME.MinkowskiBatchNorm(enc_ch[4]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(enc_ch[4], enc_ch[4], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(enc_ch[4]),
            ME.MinkowskiELU(),
        )

        self.enc_block_s16s32 = nn.Sequential(
            ME.MinkowskiConvolution(
                enc_ch[4], enc_ch[5], kernel_size=2, stride=2, dimension=3
            ),
            ME.MinkowskiBatchNorm(enc_ch[5]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(enc_ch[5], enc_ch[5], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(enc_ch[5]),
            ME.MinkowskiELU(),
        )

        self.enc_block_s32s64 = nn.Sequential(
            ME.MinkowskiConvolution(
                enc_ch[5], enc_ch[6], kernel_size=2, stride=2, dimension=3
            ),
            ME.MinkowskiBatchNorm(enc_ch[6]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(enc_ch[6], enc_ch[6], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(enc_ch[6]),
            ME.MinkowskiELU(),
        )
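
        # A reading aid (no executable effect): each enc_block_sXsY halves the
        # resolution with a stride-2 convolution, so the tensor stride grows
        # 1 -> 2 -> 4 -> 8 -> 16 -> 32 -> 64 through the encoder.  The decoder
        # below mirrors this, doubling the resolution back to stride 1 and
        # pruning predicted-empty voxels at every scale.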
        # Decoder
        self.dec_block_s64s32 = nn.Sequential(
            ME.MinkowskiGenerativeConvolutionTranspose(
                enc_ch[6],
                dec_ch[5],
                kernel_size=4,
                stride=2,
                dimension=3,
            ),
            ME.MinkowskiBatchNorm(dec_ch[5]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(dec_ch[5], dec_ch[5], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(dec_ch[5]),
            ME.MinkowskiELU(),
        )

        self.dec_s32_cls = ME.MinkowskiConvolution(
            dec_ch[5], 1, kernel_size=1, bias=True, dimension=3
        )

        self.dec_block_s32s16 = nn.Sequential(
            ME.MinkowskiGenerativeConvolutionTranspose(
                dec_ch[5],
                dec_ch[4],
                kernel_size=2,
                stride=2,
                dimension=3,
            ),
            ME.MinkowskiBatchNorm(dec_ch[4]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(dec_ch[4], dec_ch[4], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(dec_ch[4]),
            ME.MinkowskiELU(),
        )

        self.dec_s16_cls = ME.MinkowskiConvolution(
            dec_ch[4], 1, kernel_size=1, bias=True, dimension=3
        )

        self.dec_block_s16s8 = nn.Sequential(
            ME.MinkowskiGenerativeConvolutionTranspose(
                dec_ch[4],
                dec_ch[3],
                kernel_size=2,
                stride=2,
                dimension=3,
            ),
            ME.MinkowskiBatchNorm(dec_ch[3]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(dec_ch[3], dec_ch[3], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(dec_ch[3]),
            ME.MinkowskiELU(),
        )

        self.dec_s8_cls = ME.MinkowskiConvolution(
            dec_ch[3], 1, kernel_size=1, bias=True, dimension=3
        )

        self.dec_block_s8s4 = nn.Sequential(
            ME.MinkowskiGenerativeConvolutionTranspose(
                dec_ch[3],
                dec_ch[2],
                kernel_size=2,
                stride=2,
                dimension=3,
            ),
            ME.MinkowskiBatchNorm(dec_ch[2]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(dec_ch[2], dec_ch[2], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(dec_ch[2]),
            ME.MinkowskiELU(),
        )

        self.dec_s4_cls = ME.MinkowskiConvolution(
            dec_ch[2], 1, kernel_size=1, bias=True, dimension=3
        )

        self.dec_block_s4s2 = nn.Sequential(
            ME.MinkowskiGenerativeConvolutionTranspose(
                dec_ch[2],
                dec_ch[1],
                kernel_size=2,
                stride=2,
                dimension=3,
            ),
            ME.MinkowskiBatchNorm(dec_ch[1]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(dec_ch[1], dec_ch[1], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(dec_ch[1]),
            ME.MinkowskiELU(),
        )

        self.dec_s2_cls = ME.MinkowskiConvolution(
            dec_ch[1], 1, kernel_size=1, bias=True, dimension=3
        )

        self.dec_block_s2s1 = nn.Sequential(
            ME.MinkowskiGenerativeConvolutionTranspose(
                dec_ch[1],
                dec_ch[0],
                kernel_size=2,
                stride=2,
                dimension=3,
            ),
            ME.MinkowskiBatchNorm(dec_ch[0]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(dec_ch[0], dec_ch[0], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(dec_ch[0]),
            ME.MinkowskiELU(),
        )

        self.dec_s1_cls = ME.MinkowskiConvolution(
            dec_ch[0], 1, kernel_size=1, bias=True, dimension=3
        )

        # pruning
        self.pruning = ME.MinkowskiPruning()

    def get_target(self, out, target_key, kernel_size=1):
        # For every coordinate in `out`, mark whether it also exists in the
        # ground-truth coordinate set at the same tensor stride.
        with torch.no_grad():
            target = torch.zeros(len(out), dtype=torch.bool, device=out.device)
            cm = out.coordinate_manager
            strided_target_key = cm.stride(
                target_key,
                out.tensor_stride[0],
            )
            kernel_map = cm.kernel_map(
                out.coordinate_map_key,
                strided_target_key,
                kernel_size=kernel_size,
                region_type=1,
            )
            for k, curr_in in kernel_map.items():
                target[curr_in[0].long()] = 1
        return target

    def valid_batch_map(self, batch_map):
        for b in batch_map:
            if len(b) == 0:
                return False
        return True

    def forward(self, partial_in, target_key):
        out_cls, targets = [], []

        enc_s1 = self.enc_block_s1(partial_in)
        enc_s2 = self.enc_block_s1s2(enc_s1)
        enc_s4 = self.enc_block_s2s4(enc_s2)
        enc_s8 = self.enc_block_s4s8(enc_s4)
        enc_s16 = self.enc_block_s8s16(enc_s8)
        enc_s32 = self.enc_block_s16s32(enc_s16)
        enc_s64 = self.enc_block_s32s64(enc_s32)
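
        # Each decoder stage below follows the same pattern: upsample with a
        # generative transposed convolution, add the encoder features from the
        # same scale, classify every voxel as occupied or empty, and prune the
        # voxels predicted empty.  During training, ground-truth voxels are
        # also kept (keep += target) so that the target coordinates survive to
        # the next scale and the deeper classifiers see positive examples.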
        ##################################################
        # Decoder 64 -> 32
        ##################################################
        dec_s32 = self.dec_block_s64s32(enc_s64)

        # Add encoder features
        dec_s32 = dec_s32 + enc_s32
        dec_s32_cls = self.dec_s32_cls(dec_s32)
        keep_s32 = (dec_s32_cls.F > 0).squeeze()

        target = self.get_target(dec_s32, target_key)
        targets.append(target)
        out_cls.append(dec_s32_cls)

        if self.training:
            keep_s32 += target

        # Remove voxels s32
        dec_s32 = self.pruning(dec_s32, keep_s32)

        ##################################################
        # Decoder 32 -> 16
        ##################################################
        dec_s16 = self.dec_block_s32s16(dec_s32)

        # Add encoder features
        dec_s16 = dec_s16 + enc_s16
        dec_s16_cls = self.dec_s16_cls(dec_s16)
        keep_s16 = (dec_s16_cls.F > 0).squeeze()

        target = self.get_target(dec_s16, target_key)
        targets.append(target)
        out_cls.append(dec_s16_cls)

        if self.training:
            keep_s16 += target

        # Remove voxels s16
        dec_s16 = self.pruning(dec_s16, keep_s16)

        ##################################################
        # Decoder 16 -> 8
        ##################################################
        dec_s8 = self.dec_block_s16s8(dec_s16)

        # Add encoder features
        dec_s8 = dec_s8 + enc_s8
        dec_s8_cls = self.dec_s8_cls(dec_s8)

        target = self.get_target(dec_s8, target_key)
        targets.append(target)
        out_cls.append(dec_s8_cls)
        keep_s8 = (dec_s8_cls.F > 0).squeeze()

        if self.training:
            keep_s8 += target

        # Remove voxels s8
        dec_s8 = self.pruning(dec_s8, keep_s8)

        ##################################################
        # Decoder 8 -> 4
        ##################################################
        dec_s4 = self.dec_block_s8s4(dec_s8)

        # Add encoder features
        dec_s4 = dec_s4 + enc_s4
        dec_s4_cls = self.dec_s4_cls(dec_s4)

        target = self.get_target(dec_s4, target_key)
        targets.append(target)
        out_cls.append(dec_s4_cls)
        keep_s4 = (dec_s4_cls.F > 0).squeeze()

        if self.training:
            keep_s4 += target

        # Remove voxels s4
        dec_s4 = self.pruning(dec_s4, keep_s4)

        ##################################################
        # Decoder 4 -> 2
        ##################################################
        dec_s2 = self.dec_block_s4s2(dec_s4)

        # Add encoder features
        dec_s2 = dec_s2 + enc_s2
        dec_s2_cls = self.dec_s2_cls(dec_s2)

        target = self.get_target(dec_s2, target_key)
        targets.append(target)
        out_cls.append(dec_s2_cls)
        keep_s2 = (dec_s2_cls.F > 0).squeeze()

        if self.training:
            keep_s2 += target

        # Remove voxels s2
        dec_s2 = self.pruning(dec_s2, keep_s2)

        ##################################################
        # Decoder 2 -> 1
        ##################################################
        dec_s1 = self.dec_block_s2s1(dec_s2)

        # Add encoder features
        dec_s1 = dec_s1 + enc_s1
        dec_s1_cls = self.dec_s1_cls(dec_s1)

        target = self.get_target(dec_s1, target_key)
        targets.append(target)
        out_cls.append(dec_s1_cls)
        keep_s1 = (dec_s1_cls.F > 0).squeeze()

        # Last layer does not require adding the target
        # if self.training:
        #     keep_s1 += target

        # Remove voxels s1
        dec_s1 = self.pruning(dec_s1, keep_s1)

        return out_cls, targets, dec_s1


def train(net, dataloader, device, config):
    optimizer = optim.SGD(
        net.parameters(),
        lr=config.lr,
        momentum=config.momentum,
        weight_decay=config.weight_decay,
    )
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.95)

    crit = nn.BCEWithLogitsLoss()

    net.train()
    train_iter = iter(dataloader)
    # val_iter = iter(val_dataloader)
    logging.info(f"LR: {scheduler.get_last_lr()}")
    for i in range(config.max_iter):

        s = time()
        data_dict = next(train_iter)
        d = time() - s

        optimizer.zero_grad()

        in_feat = torch.ones((len(data_dict["coords"]), 1))
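
        # The input features are all ones (an occupancy-only signal), so the
        # completion task is driven entirely by the voxel coordinates.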
        sin = ME.SparseTensor(
            features=in_feat,
            coordinates=data_dict["coords"],
            device=device,
        )

        # Generate target sparse tensor
        cm = sin.coordinate_manager
        target_key, _ = cm.insert_and_map(
            ME.utils.batched_coordinates(data_dict["xyzs"]).to(device),
            string_id="target",
        )

        # Complete the partial input
        out_cls, targets, sout = net(sin, target_key)
        num_layers, loss = len(out_cls), 0
        losses = []
        for out_cl, target in zip(out_cls, targets):
            curr_loss = crit(out_cl.F.squeeze(), target.type(out_cl.F.dtype).to(device))
            losses.append(curr_loss.item())
            loss += curr_loss / num_layers

        loss.backward()
        optimizer.step()
        t = time() - s

        if i % config.stat_freq == 0:
            logging.info(
                f"Iter: {i}, Loss: {loss.item():.3e}, Data Loading Time: {d:.3e}, Tot Time: {t:.3e}"
            )

        if i % config.val_freq == 0 and i > 0:
            torch.save(
                {
                    "state_dict": net.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                    "curr_iter": i,
                },
                config.weights,
            )

            scheduler.step()
            logging.info(f"LR: {scheduler.get_last_lr()}")

            net.train()


def visualize(net, dataloader, device, config):
    net.eval()
    crit = nn.BCEWithLogitsLoss()
    n_vis = 0

    for data_dict in dataloader:

        in_feat = torch.ones((len(data_dict["coords"]), 1))
        sin = ME.SparseTensor(
            features=in_feat,
            coordinates=data_dict["coords"],
            device=device,
        )

        # Generate target sparse tensor
        cm = sin.coordinate_manager
        target_key, _ = cm.insert_and_map(
            ME.utils.batched_coordinates(data_dict["xyzs"]).to(device),
            string_id="target",
        )

        # Complete the partial input
        out_cls, targets, sout = net(sin, target_key)
        num_layers, loss = len(out_cls), 0
        for out_cl, target in zip(out_cls, targets):
            loss += (
                crit(out_cl.F.squeeze(), target.type(out_cl.F.dtype).to(device))
                / num_layers
            )

        batch_coords, batch_feats = sout.decomposed_coordinates_and_features
        for b, (coords, feats) in enumerate(zip(batch_coords, batch_feats)):
            # Show the completion next to the cropped input.
            pcd = PointCloud(coords.cpu().numpy())
            pcd.estimate_normals()
            pcd.translate([0.6 * config.resolution, 0, 0])
            pcd.rotate(M, np.array([[0.0], [0.0], [0.0]]))
            opcd = PointCloud(data_dict["cropped_coords"][b])
            opcd.translate([-0.6 * config.resolution, 0, 0])
            opcd.estimate_normals()
            opcd.rotate(M, np.array([[0.0], [0.0], [0.0]]))
            o3d.visualization.draw_geometries([pcd, opcd])

            n_vis += 1
            if n_vis > config.max_visualization:
                return


if __name__ == "__main__":
    config = parser.parse_args()
    logging.info(config)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataloader = make_data_loader(
        "val",
        augment_data=True,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        repeat=True,
        config=config,
    )

    in_nchannel = len(dataloader.dataset)

    net = CompletionNet(config.resolution, in_nchannel=in_nchannel)
    net.to(device)

    logging.info(net)

    if not config.eval:
        train(net, dataloader, device, config)
    else:
        if not os.path.exists(config.weights):
            logging.info("Downloading pretrained weights. This might take a while...")
            urllib.request.urlretrieve(
                "https://bit.ly/36d9m1n", filename=config.weights
            )

        logging.info(f"Loading weights from {config.weights}")
        checkpoint = torch.load(config.weights)
        net.load_state_dict(checkpoint["state_dict"])

        visualize(net, dataloader, device, config)
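
# Example invocations (a sketch; assumes this script lives at
# examples/completion.py inside the MinkowskiEngine repository):
#
#   # Train from scratch on the pruned ModelNet40 data
#   python -m examples.completion
#
#   # Download pretrained weights (if missing) and visualize completions
#   python -m examples.completion --eval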