|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import sys |
|
|
import subprocess |
|
|
import argparse |
|
|
import logging |
|
|
import numpy as np |
|
|
from time import time |
|
|
import urllib |
|
|
|
|
|
|
|
|
try: |
|
|
import open3d as o3d |
|
|
except ImportError: |
|
|
raise ImportError("Please install open3d with `pip install open3d`.") |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.utils.data |
|
|
import torch.optim as optim |
|
|
|
|
|
import MinkowskiEngine as ME |
|
|
|
|
|
from examples.reconstruction import ModelNet40Dataset, InfSampler |
|
|
|
|
|
# Fixed 3x3 matrix passed to `rotate` when point clouds are displayed in
# visualize() (presumably a viewing rotation; its rows are ~unit length).
M = np.array(
    [
        [0.80656762, -0.5868724, -0.07091862],
        [0.3770505, 0.418344, 0.82632997],
        [-0.45528188, -0.6932309, 0.55870326],
    ]
)

# Compare (major, minor) as a tuple so that e.g. a 1.x release also passes.
# An explicit raise (unlike `assert`) is not stripped under `python -O`.
_O3D_VERSION = tuple(int(part) for part in o3d.__version__.split(".")[:2])
if _O3D_VERSION < (0, 8):
    raise ImportError(
        f"Requires open3d version >= 0.8, the current version is {o3d.__version__}"
    )

if not os.path.exists("ModelNet40"):
    # print instead of logging.info: the module-level logging functions call
    # basicConfig() when the root logger has no handlers, which would turn the
    # stdout/format configuration done further below into a no-op.
    print("Downloading the pruned ModelNet40 dataset...")
    # check=True: fail here if the download script fails, rather than later
    # with a confusing missing-dataset error.
    subprocess.run(["sh", "./examples/download_modelnet40.sh"], check=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def PointCloud(points, colors=None):
    """Build an Open3D point cloud from ``points`` and, optionally, ``colors``.

    Both arguments are converted with ``o3d.utility.Vector3dVector``; when
    ``colors`` is None the cloud is returned without a color attribute.
    """
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(points)
    if colors is None:
        return cloud
    cloud.colors = o3d.utility.Vector3dVector(colors)
    return cloud
|
|
|
|
|
|
|
|
class CollationAndTransformation:
    """DataLoader ``collate_fn`` that crops each sample and batches the result."""

    def __init__(self, resolution):
        # Grid resolution; the crop threshold is derived from it.
        self.resolution = resolution

    def random_crop(self, coords_list):
        """Keep only the rows whose first coordinate is below ``resolution / 3``.

        Despite the name the crop is deterministic: the same cut is applied to
        every array in ``coords_list``. Returns a new list of filtered arrays.
        """
        limit = self.resolution / 3
        return [coords[coords[:, 0] < limit] for coords in coords_list]

    def __call__(self, list_data):
        """Collate a list of ``(coords, feats, label)`` samples into a dict.

        Returns a dict with the batched cropped coordinates (``"coords"``), the
        per-sample ``feats`` as float tensors (``"xyzs"``; train() uses these as
        target coordinates), the cropped per-sample coordinates
        (``"cropped_coords"``) and the labels as a LongTensor (``"labels"``).
        """
        coords, feats, labels = zip(*list_data)
        cropped = self.random_crop(coords)
        return {
            "coords": ME.utils.batched_coordinates(cropped),
            "xyzs": [torch.from_numpy(feat).float() for feat in feats],
            "cropped_coords": cropped,
            "labels": torch.LongTensor(labels),
        }
|
|
|
|
|
|
|
|
def make_data_loader(
    phase, augment_data, batch_size, shuffle, num_workers, repeat, config
):
    """Create a DataLoader over ``ModelNet40Dataset(phase, config=config)``.

    Batches are collated by ``CollationAndTransformation(config.resolution)``.
    With ``repeat`` set, an ``InfSampler(dataset, shuffle)`` drives sampling;
    otherwise ``shuffle`` is forwarded to the DataLoader directly.
    ``augment_data`` is accepted but not used here.
    """
    dataset = ModelNet40Dataset(phase, config=config)

    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=CollationAndTransformation(config.resolution),
        pin_memory=False,
        drop_last=False,
    )
    # A custom sampler and `shuffle` are mutually exclusive in DataLoader, so
    # exactly one of them is set.
    if repeat:
        loader_kwargs["sampler"] = InfSampler(dataset, shuffle)
    else:
        loader_kwargs["shuffle"] = shuffle

    return torch.utils.data.DataLoader(dataset, **loader_kwargs)
|
|
|
|
|
|
|
|
# Route log records to stdout with a timestamp prefix; root level is INFO.
ch = logging.StreamHandler(sys.stdout)
logging.getLogger().setLevel(logging.INFO)
logging.basicConfig(
    handlers=[ch],
    datefmt="%m/%d %H:%M:%S",
    format="%(asctime)s %(message)s",
)

# Command-line options, registered in this order (which is also the --help
# order). NOTE(review): --load_optimizer is parsed but not read anywhere in the
# visible part of this file.
parser = argparse.ArgumentParser()
for _flag, _options in (
    ("--resolution", {"type": int, "default": 128}),
    ("--max_iter", {"type": int, "default": 30000}),
    ("--val_freq", {"type": int, "default": 1000}),
    ("--batch_size", {"type": int, "default": 16}),
    ("--lr", {"type": float, "default": 1e-2}),
    ("--momentum", {"type": float, "default": 0.9}),
    ("--weight_decay", {"type": float, "default": 1e-4}),
    ("--num_workers", {"type": int, "default": 1}),
    ("--stat_freq", {"type": int, "default": 50}),
    ("--weights", {"type": str, "default": "modelnet_completion.pth"}),
    ("--load_optimizer", {"type": str, "default": "true"}),
    ("--eval", {"action": "store_true"}),
    ("--max_visualization", {"type": int, "default": 4}),
):
    parser.add_argument(_flag, **_options)
del _flag, _options
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CompletionNet(nn.Module):
    """Sparse 3D completion network built from MinkowskiEngine layers.

    The encoder downsamples the input six times by a factor of 2 (tensor
    stride 1 -> 64). The decoder upsamples back to stride 1 with generative
    transposed convolutions. At each stride it adds the encoder tensor of the
    same stride and predicts one logit per voxel with a kernel_size=1
    convolution. Voxels whose logit is <= 0 are pruned. While training, voxels
    that match the target are also kept, except at the final (stride-1) level.

    Attribute names and the order in which submodules are registered match
    the original hand-written layout. Existing checkpoints (``state_dict``
    keys) and optimizer states (parameter order) therefore still load.
    """

    ENC_CHANNELS = [16, 32, 64, 128, 256, 512, 1024]
    DEC_CHANNELS = [16, 32, 64, 128, 256, 512, 1024]

    @staticmethod
    def _down_block(in_ch, out_ch):
        """Stride-2 conv then a kernel_size=3 conv, each followed by BN + ELU."""
        return nn.Sequential(
            ME.MinkowskiConvolution(
                in_ch, out_ch, kernel_size=2, stride=2, dimension=3
            ),
            ME.MinkowskiBatchNorm(out_ch),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(out_ch, out_ch, kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(out_ch),
            ME.MinkowskiELU(),
        )

    @staticmethod
    def _up_block(in_ch, out_ch, kernel_size=2):
        """Stride-2 generative transposed conv then a kernel_size=3 conv, each + BN + ELU."""
        return nn.Sequential(
            ME.MinkowskiGenerativeConvolutionTranspose(
                in_ch,
                out_ch,
                kernel_size=kernel_size,
                stride=2,
                dimension=3,
            ),
            ME.MinkowskiBatchNorm(out_ch),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(out_ch, out_ch, kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(out_ch),
            ME.MinkowskiELU(),
        )

    @staticmethod
    def _cls_layer(in_ch):
        """kernel_size=1 convolution producing one logit per voxel."""
        return ME.MinkowskiConvolution(
            in_ch, 1, kernel_size=1, bias=True, dimension=3
        )

    def __init__(self, resolution, in_nchannel=512):
        """
        Args:
            resolution: voxel grid resolution; stored as ``self.resolution``
                but not used by the layers defined here.
            in_nchannel: accepted for backward compatibility and unused (the
                first convolution always takes one input channel).
        """
        nn.Module.__init__(self)

        self.resolution = resolution

        enc_ch = self.ENC_CHANNELS
        dec_ch = self.DEC_CHANNELS

        # Encoder. Registration order determines parameter order; keep it.
        self.enc_block_s1 = nn.Sequential(
            ME.MinkowskiConvolution(1, enc_ch[0], kernel_size=3, stride=1, dimension=3),
            ME.MinkowskiBatchNorm(enc_ch[0]),
            ME.MinkowskiELU(),
        )
        self.enc_block_s1s2 = self._down_block(enc_ch[0], enc_ch[1])
        self.enc_block_s2s4 = self._down_block(enc_ch[1], enc_ch[2])
        self.enc_block_s4s8 = self._down_block(enc_ch[2], enc_ch[3])
        self.enc_block_s8s16 = self._down_block(enc_ch[3], enc_ch[4])
        self.enc_block_s16s32 = self._down_block(enc_ch[4], enc_ch[5])
        self.enc_block_s32s64 = self._down_block(enc_ch[5], enc_ch[6])

        # Decoder: each upsampling block is followed by its classifier. Every
        # block after the first consumes decoder features, hence dec_ch inputs.
        self.dec_block_s64s32 = self._up_block(enc_ch[6], dec_ch[5], kernel_size=4)
        self.dec_s32_cls = self._cls_layer(dec_ch[5])

        self.dec_block_s32s16 = self._up_block(dec_ch[5], dec_ch[4])
        self.dec_s16_cls = self._cls_layer(dec_ch[4])

        self.dec_block_s16s8 = self._up_block(dec_ch[4], dec_ch[3])
        self.dec_s8_cls = self._cls_layer(dec_ch[3])

        self.dec_block_s8s4 = self._up_block(dec_ch[3], dec_ch[2])
        self.dec_s4_cls = self._cls_layer(dec_ch[2])

        self.dec_block_s4s2 = self._up_block(dec_ch[2], dec_ch[1])
        self.dec_s2_cls = self._cls_layer(dec_ch[1])

        self.dec_block_s2s1 = self._up_block(dec_ch[1], dec_ch[0])
        self.dec_s1_cls = self._cls_layer(dec_ch[0])

        self.pruning = ME.MinkowskiPruning()

    def get_target(self, out, target_key, kernel_size=1):
        """Return a bool mask over ``out``'s voxels that hit the target map.

        The target coordinate map is first strided to ``out.tensor_stride[0]``.
        A voxel is marked when the kernel map between ``out`` and that strided
        map (with ``kernel_size``) lists it as an input.
        """
        with torch.no_grad():
            target = torch.zeros(len(out), dtype=torch.bool, device=out.device)
            cm = out.coordinate_manager
            strided_target_key = cm.stride(
                target_key, out.tensor_stride[0],
            )
            kernel_map = cm.kernel_map(
                out.coordinate_map_key,
                strided_target_key,
                kernel_size=kernel_size,
                region_type=1,
            )
            for curr_in in kernel_map.values():
                target[curr_in[0].long()] = 1
        return target

    def valid_batch_map(self, batch_map):
        """Return False if any entry of ``batch_map`` is empty, else True."""
        for b in batch_map:
            if len(b) == 0:
                return False
        return True

    def _decode_level(self, dec, cls_layer, target_key, out_cls, targets,
                      keep_target=True):
        """Classify one decoder level, record (logits, target), and prune it.

        ``keep_target`` controls whether target voxels are forced to survive
        pruning while training.
        """
        dec_cls = cls_layer(dec)
        target = self.get_target(dec, target_key)
        targets.append(target)
        out_cls.append(dec_cls)

        # squeeze(1) rather than squeeze(): stays 1-D even for a single voxel.
        keep = (dec_cls.F > 0).squeeze(1)
        if keep_target and self.training:
            keep |= target
        return self.pruning(dec, keep)

    def forward(self, partial_in, target_key):
        """Encode ``partial_in`` and decode it level by level with pruning.

        Args:
            partial_in: input ME.SparseTensor with one feature channel (the
                first encoder convolution takes 1 input channel).
            target_key: coordinate map key of the target coordinates inside
                ``partial_in``'s coordinate manager.

        Returns:
            ``(out_cls, targets, dec_s1)`` where ``out_cls`` holds the
            classifier outputs from stride 32 down to stride 1, ``targets`` the
            matching bool masks, and ``dec_s1`` the pruned stride-1 tensor.
        """
        out_cls, targets = [], []

        enc_s1 = self.enc_block_s1(partial_in)
        enc_s2 = self.enc_block_s1s2(enc_s1)
        enc_s4 = self.enc_block_s2s4(enc_s2)
        enc_s8 = self.enc_block_s4s8(enc_s4)
        enc_s16 = self.enc_block_s8s16(enc_s8)
        enc_s32 = self.enc_block_s16s32(enc_s16)
        enc_s64 = self.enc_block_s32s64(enc_s32)

        # Each level: upsample, add the encoder skip tensor, classify, prune.
        dec_s32 = self._decode_level(
            self.dec_block_s64s32(enc_s64) + enc_s32,
            self.dec_s32_cls, target_key, out_cls, targets,
        )
        dec_s16 = self._decode_level(
            self.dec_block_s32s16(dec_s32) + enc_s16,
            self.dec_s16_cls, target_key, out_cls, targets,
        )
        dec_s8 = self._decode_level(
            self.dec_block_s16s8(dec_s16) + enc_s8,
            self.dec_s8_cls, target_key, out_cls, targets,
        )
        dec_s4 = self._decode_level(
            self.dec_block_s8s4(dec_s8) + enc_s4,
            self.dec_s4_cls, target_key, out_cls, targets,
        )
        dec_s2 = self._decode_level(
            self.dec_block_s4s2(dec_s4) + enc_s2,
            self.dec_s2_cls, target_key, out_cls, targets,
        )
        # The last level does not force target voxels to survive pruning.
        dec_s1 = self._decode_level(
            self.dec_block_s2s1(dec_s2) + enc_s1,
            self.dec_s1_cls, target_key, out_cls, targets,
            keep_target=False,
        )

        return out_cls, targets, dec_s1
|
|
|
|
|
|
|
|
def train(net, dataloader, device, config):
    """Train ``net`` for ``config.max_iter`` iterations with SGD.

    Each iteration draws one batch and builds the input sparse tensor from
    ``data_dict["coords"]``. It inserts the batched ``data_dict["xyzs"]`` into
    the same coordinate manager as the ``"target"`` map, then minimizes the
    BCE-with-logits loss averaged over all decoder levels. Every
    ``config.val_freq`` iterations (excluding iteration 0) it saves a
    checkpoint to ``config.weights`` and decays the learning rate by 0.95.

    Args:
        net: module returning ``(out_cls, targets, sout)``, e.g. CompletionNet.
        dataloader: must yield at least ``config.max_iter`` batches (for
            example via an infinite sampler); otherwise ``next`` raises
            StopIteration.
        device: torch device the sparse tensors are created on.
        config: parsed command-line namespace.
    """
    optimizer = optim.SGD(
        net.parameters(),
        lr=config.lr,
        momentum=config.momentum,
        weight_decay=config.weight_decay,
    )
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.95)

    crit = nn.BCEWithLogitsLoss()

    net.train()
    train_iter = iter(dataloader)

    # get_last_lr() is the supported accessor. get_lr() is meant for use inside
    # step(); called outside it, PyTorch warns and can report a value one decay
    # step off.
    logging.info(f"LR: {scheduler.get_last_lr()}")
    for i in range(config.max_iter):

        s = time()
        # Iterators implement __next__; the `.next()` alias is not available in
        # recent PyTorch DataLoader iterators.
        data_dict = next(train_iter)
        d = time() - s

        optimizer.zero_grad()

        in_feat = torch.ones((len(data_dict["coords"]), 1))

        sin = ME.SparseTensor(
            features=in_feat,
            coordinates=data_dict["coords"],
            device=device,
        )

        # Register the complete shape as the "target" coordinate map so the
        # network can label which of its generated voxels are correct.
        cm = sin.coordinate_manager
        target_key, _ = cm.insert_and_map(
            ME.utils.batched_coordinates(data_dict["xyzs"]).to(device),
            string_id="target",
        )

        out_cls, targets, _ = net(sin, target_key)
        num_layers, loss = len(out_cls), 0
        for out_cl, target in zip(out_cls, targets):
            # squeeze(1) keeps shape (N,) even when N == 1, matching `target`.
            logits = out_cl.F.squeeze(1)
            loss += crit(logits, target.type(logits.dtype).to(device)) / num_layers

        loss.backward()
        optimizer.step()
        t = time() - s

        if i % config.stat_freq == 0:
            logging.info(
                f"Iter: {i}, Loss: {loss.item():.3e}, Data Loading Time: {d:.3e}, Tot Time: {t:.3e}"
            )

        if i % config.val_freq == 0 and i > 0:
            torch.save(
                {
                    "state_dict": net.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                    "curr_iter": i,
                },
                config.weights,
            )

            scheduler.step()
            logging.info(f"LR: {scheduler.get_last_lr()}")

            net.train()
|
|
|
|
|
|
|
|
def visualize(net, dataloader, device, config):
    """Display up to ``config.max_visualization`` completions with Open3D.

    For each sample the predicted stride-1 output of ``net`` (shifted by
    +0.6 * resolution along x) is shown next to the cropped input (shifted by
    -0.6 * resolution), both rotated by ``M``. The averaged per-level
    BCE-with-logits loss of each batch is logged.

    Uses the same MinkowskiEngine API as ``train`` (``features=``,
    ``coordinates=``, ``coordinate_manager.insert_and_map``).
    """
    if config.max_visualization <= 0:
        return

    net.eval()
    crit = nn.BCEWithLogitsLoss()
    n_vis = 0

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for data_dict in dataloader:
            in_feat = torch.ones((len(data_dict["coords"]), 1))

            sin = ME.SparseTensor(
                features=in_feat,
                coordinates=data_dict["coords"],
                device=device,
            )

            # Register the complete shape as the "target" coordinate map.
            cm = sin.coordinate_manager
            target_key, _ = cm.insert_and_map(
                ME.utils.batched_coordinates(data_dict["xyzs"]).to(device),
                string_id="target",
            )

            out_cls, targets, sout = net(sin, target_key)
            num_layers, loss = len(out_cls), 0
            for out_cl, target in zip(out_cls, targets):
                logits = out_cl.F.squeeze(1)
                loss += crit(logits, target.type(logits.dtype).to(device)) / num_layers
            logging.info(f"Loss: {loss.item():.3e}")

            batch_coords, _ = sout.decomposed_coordinates_and_features
            for b, coords in enumerate(batch_coords):
                # Open3D needs host-side float64 arrays, not (GPU) torch tensors.
                pcd = PointCloud(coords.cpu().numpy().astype(np.float64))
                pcd.estimate_normals()
                pcd.translate([0.6 * config.resolution, 0, 0])
                pcd.rotate(M, np.array([[0.0], [0.0], [0.0]]))

                opcd = PointCloud(
                    np.asarray(data_dict["cropped_coords"][b], dtype=np.float64)
                )
                opcd.translate([-0.6 * config.resolution, 0, 0])
                opcd.estimate_normals()
                opcd.rotate(M, np.array([[0.0], [0.0], [0.0]]))
                o3d.visualization.draw_geometries([pcd, opcd])

                # Stop after exactly `max_visualization` windows.
                n_vis += 1
                if n_vis >= config.max_visualization:
                    return
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    config = parser.parse_args()
    logging.info(config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataloader = make_data_loader(
        "val",
        augment_data=True,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        repeat=True,
        config=config,
    )

    # CompletionNet does not use `in_nchannel`; the dataset length that used to
    # be passed here was never a channel count.
    net = CompletionNet(config.resolution)
    net.to(device)

    logging.info(net)

    if not config.eval:
        train(net, dataloader, device, config)
    else:
        if not os.path.exists(config.weights):
            # A bare `import urllib` does not guarantee that the `request`
            # submodule is loaded, so import it explicitly.
            import urllib.request

            logging.info("Downloading pretrained weights. This might take a while...")
            urllib.request.urlretrieve(
                "https://bit.ly/36d9m1n", filename=config.weights
            )

        logging.info(f"Loading weights from {config.weights}")
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        checkpoint = torch.load(config.weights, map_location=device)
        net.load_state_dict(checkpoint["state_dict"])

        visualize(net, dataloader, device, config)
|
|
|