Upload 19 files
Browse files- __init__.py +0 -0
- app.py +49 -0
- checkpoints/epoch=99-step=6000.ckpt +3 -0
- files/bunny_n1_hi_50.obj +0 -0
- files/child_n2_80.obj +0 -0
- files/eight_n3_70.obj +0 -0
- models/SAP/__init__.py +0 -0
- models/SAP/__pycache__/__init__.cpython-39.pyc +0 -0
- models/SAP/__pycache__/dpsr.cpython-39.pyc +0 -0
- models/SAP/__pycache__/model.cpython-39.pyc +0 -0
- models/SAP/__pycache__/utils.cpython-39.pyc +0 -0
- models/SAP/dpsr.py +65 -0
- models/SAP/model.py +129 -0
- models/SAP/utils.py +526 -0
- models/__init__.py +0 -0
- models/__pycache__/__init__.cpython-39.pyc +0 -0
- models/__pycache__/model.cpython-39.pyc +0 -0
- models/model.py +181 -0
- requirements.txt +298 -0
__init__.py
ADDED
|
File without changes
|
app.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import pytorch_lightning as pl
|
| 5 |
+
import torch as th
|
| 6 |
+
import open3d as o3d
|
| 7 |
+
import numpy as np
|
| 8 |
+
import trimesh as tm
|
| 9 |
+
|
| 10 |
+
from models.model import Model
|
| 11 |
+
|
| 12 |
+
model = Model()
|
| 13 |
+
ckpg = th.load("./checkpoints/epoch=99-step=6000.ckpt")
|
| 14 |
+
model.load_state_dict(ckpg["state_dict"])
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def process_mesh(mesh_file_name):
|
| 18 |
+
|
| 19 |
+
mesh = tm.load_mesh(mesh_file_name)
|
| 20 |
+
|
| 21 |
+
v = th.tensor(mesh.vertices, dtype=th.float)
|
| 22 |
+
n = th.tensor(mesh.vertex_normals, dtype=th.float)
|
| 23 |
+
|
| 24 |
+
with th.no_grad():
|
| 25 |
+
v, f, n, _ = model(v.unsqueeze(0), n.unsqueeze(0))
|
| 26 |
+
|
| 27 |
+
mesh = tm.Trimesh(vertices=v.squeeze(0),
|
| 28 |
+
faces=f.squeeze(0),
|
| 29 |
+
vertex_normals=n.squeeze(0))
|
| 30 |
+
obj_path = "./sample.obj"
|
| 31 |
+
mesh.export(obj_path)
|
| 32 |
+
|
| 33 |
+
return obj_path
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
demo = gr.Interface(
|
| 37 |
+
fn=process_mesh,
|
| 38 |
+
inputs=gr.Model3D(),
|
| 39 |
+
outputs=gr.Model3D(
|
| 40 |
+
clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model"),
|
| 41 |
+
examples=[
|
| 42 |
+
[os.path.join(os.path.dirname(__file__), "files\\bunny_n1_hi_50.obj")],
|
| 43 |
+
[os.path.join(os.path.dirname(__file__), "files\\child_n2_80.obj")],
|
| 44 |
+
[os.path.join(os.path.dirname(__file__), "files\\eight_n3_70.obj")],
|
| 45 |
+
],
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
if __name__ == "__main__":
|
| 49 |
+
demo.launch()
|
checkpoints/epoch=99-step=6000.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d3a025dfa88edbf34bf7cf2b69c554cb01ab7d61b4d8cc699a2a6753e14dbdea
|
| 3 |
+
size 4308343
|
files/bunny_n1_hi_50.obj
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
files/child_n2_80.obj
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
files/eight_n3_70.obj
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/SAP/__init__.py
ADDED
|
File without changes
|
models/SAP/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (133 Bytes). View file
|
|
|
models/SAP/__pycache__/dpsr.cpython-39.pyc
ADDED
|
Binary file (2.47 kB). View file
|
|
|
models/SAP/__pycache__/model.cpython-39.pyc
ADDED
|
Binary file (4.08 kB). View file
|
|
|
models/SAP/__pycache__/utils.cpython-39.pyc
ADDED
|
Binary file (15 kB). View file
|
|
|
models/SAP/dpsr.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from .utils import spec_gaussian_filter, fftfreqs, img, grid_interp, point_rasterize
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch.fft
|
| 6 |
+
|
| 7 |
+
class DPSR(nn.Module):
|
| 8 |
+
def __init__(self, res, sig=10, scale=True, shift=True):
|
| 9 |
+
"""
|
| 10 |
+
:param res: tuple of output field resolution. eg., (128,128)
|
| 11 |
+
:param sig: degree of gaussian smoothing
|
| 12 |
+
"""
|
| 13 |
+
super(DPSR, self).__init__()
|
| 14 |
+
self.res = res
|
| 15 |
+
self.sig = sig
|
| 16 |
+
self.dim = len(res)
|
| 17 |
+
self.denom = np.prod(res)
|
| 18 |
+
G = spec_gaussian_filter(res=res, sig=sig).float()
|
| 19 |
+
# self.G.requires_grad = False # True, if we also make sig a learnable parameter
|
| 20 |
+
self.omega = fftfreqs(res, dtype=torch.float32)
|
| 21 |
+
self.scale = scale
|
| 22 |
+
self.shift = shift
|
| 23 |
+
self.register_buffer("G", G)
|
| 24 |
+
|
| 25 |
+
def forward(self, V, N):
|
| 26 |
+
"""
|
| 27 |
+
:param V: (batch, nv, 2 or 3) tensor for point cloud coordinates
|
| 28 |
+
:param N: (batch, nv, 2 or 3) tensor for point normals
|
| 29 |
+
:return phi: (batch, res, res, ...) tensor of output indicator function field
|
| 30 |
+
"""
|
| 31 |
+
assert(V.shape == N.shape) # [b, nv, ndims]
|
| 32 |
+
ras_p = point_rasterize(V, N, self.res) # [b, n_dim, dim0, dim1, dim2]
|
| 33 |
+
|
| 34 |
+
ras_s = torch.fft.rfftn(ras_p, dim=(2,3,4))
|
| 35 |
+
ras_s = ras_s.permute(*tuple([0]+list(range(2, self.dim+1))+[self.dim+1, 1]))
|
| 36 |
+
N_ = ras_s[..., None] * self.G # [b, dim0, dim1, dim2/2+1, n_dim, 1]
|
| 37 |
+
|
| 38 |
+
omega = fftfreqs(self.res, dtype=torch.float32).unsqueeze(-1) # [dim0, dim1, dim2/2+1, n_dim, 1]
|
| 39 |
+
omega *= 2 * np.pi # normalize frequencies
|
| 40 |
+
omega = omega.to(V.device)
|
| 41 |
+
|
| 42 |
+
DivN = torch.sum(-img(torch.view_as_real(N_[..., 0])) * omega, dim=-2)
|
| 43 |
+
|
| 44 |
+
Lap = -torch.sum(omega**2, -2) # [dim0, dim1, dim2/2+1, 1]
|
| 45 |
+
Phi = DivN / (Lap+1e-6) # [b, dim0, dim1, dim2/2+1, 2]
|
| 46 |
+
Phi = Phi.permute(*tuple([list(range(1,self.dim+2)) + [0]])) # [dim0, dim1, dim2/2+1, 2, b]
|
| 47 |
+
Phi[tuple([0] * self.dim)] = 0
|
| 48 |
+
Phi = Phi.permute(*tuple([[self.dim+1] + list(range(self.dim+1))])) # [b, dim0, dim1, dim2/2+1, 2]
|
| 49 |
+
|
| 50 |
+
phi = torch.fft.irfftn(torch.view_as_complex(Phi), s=self.res, dim=(1,2,3))
|
| 51 |
+
|
| 52 |
+
if self.shift or self.scale:
|
| 53 |
+
# ensure values at points are zero
|
| 54 |
+
fv = grid_interp(phi.unsqueeze(-1), V, batched=True).squeeze(-1) # [b, nv]
|
| 55 |
+
if self.shift: # offset points to have mean of 0
|
| 56 |
+
offset = torch.mean(fv, dim=-1) # [b,]
|
| 57 |
+
phi -= offset.view(*tuple([-1] + [1] * self.dim))
|
| 58 |
+
|
| 59 |
+
phi = phi.permute(*tuple([list(range(1,self.dim+1)) + [0]]))
|
| 60 |
+
fv0 = phi[tuple([0] * self.dim)] # [b,]
|
| 61 |
+
phi = phi.permute(*tuple([[self.dim] + list(range(self.dim))]))
|
| 62 |
+
|
| 63 |
+
if self.scale:
|
| 64 |
+
phi = -phi / torch.abs(fv0.view(*tuple([-1]+[1] * self.dim))) *0.5
|
| 65 |
+
return phi
|
models/SAP/model.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import time
|
| 4 |
+
from .utils import point_rasterize, grid_interp, mc_from_psr, \
|
| 5 |
+
calc_inters_points
|
| 6 |
+
from .dpsr import DPSR
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
|
| 9 |
+
class PSR2Mesh(torch.autograd.Function):
|
| 10 |
+
@staticmethod
|
| 11 |
+
def forward(ctx, psr_grid):
|
| 12 |
+
"""
|
| 13 |
+
In the forward pass we receive a Tensor containing the input and return
|
| 14 |
+
a Tensor containing the output. ctx is a context object that can be used
|
| 15 |
+
to stash information for backward computation. You can cache arbitrary
|
| 16 |
+
objects for use in the backward pass using the ctx.save_for_backward method.
|
| 17 |
+
"""
|
| 18 |
+
verts, faces, normals = mc_from_psr(psr_grid, pytorchify=True)
|
| 19 |
+
verts = verts.unsqueeze(0)
|
| 20 |
+
faces = faces.unsqueeze(0)
|
| 21 |
+
normals = normals.unsqueeze(0)
|
| 22 |
+
|
| 23 |
+
res = torch.tensor(psr_grid.detach().shape[2])
|
| 24 |
+
ctx.save_for_backward(verts, normals, res)
|
| 25 |
+
|
| 26 |
+
return verts, faces, normals
|
| 27 |
+
|
| 28 |
+
@staticmethod
|
| 29 |
+
def backward(ctx, dL_dVertex, dL_dFace, dL_dNormals):
|
| 30 |
+
"""
|
| 31 |
+
In the backward pass we receive a Tensor containing the gradient of the loss
|
| 32 |
+
with respect to the output, and we need to compute the gradient of the loss
|
| 33 |
+
with respect to the input.
|
| 34 |
+
"""
|
| 35 |
+
vert_pts, normals, res = ctx.saved_tensors
|
| 36 |
+
res = (res.item(), res.item(), res.item())
|
| 37 |
+
# matrix multiplication between dL/dV and dV/dPSR
|
| 38 |
+
# dV/dPSR = - normals
|
| 39 |
+
grad_vert = torch.matmul(dL_dVertex.permute(1, 0, 2), -normals.permute(1, 2, 0))
|
| 40 |
+
grad_grid = point_rasterize(vert_pts, grad_vert.permute(1, 0, 2), res) # b x 1 x res x res x res
|
| 41 |
+
|
| 42 |
+
return grad_grid
|
| 43 |
+
|
| 44 |
+
class PSR2SurfacePoints(torch.autograd.Function):
|
| 45 |
+
@staticmethod
|
| 46 |
+
def forward(ctx, psr_grid, poses, img_size, uv, psr_grad, mask_sample):
|
| 47 |
+
verts, faces, normals = mc_from_psr(psr_grid, pytorchify=True)
|
| 48 |
+
verts = verts * 2. - 1. # within the range of [-1, 1]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
p_all, n_all, mask_all = [], [], []
|
| 52 |
+
|
| 53 |
+
for i in range(len(poses)):
|
| 54 |
+
pose = poses[i]
|
| 55 |
+
if mask_sample is not None:
|
| 56 |
+
p_inters, mask, _, _ = calc_inters_points(verts, faces, pose, img_size, mask_gt=mask_sample[i])
|
| 57 |
+
else:
|
| 58 |
+
p_inters, mask, _, _ = calc_inters_points(verts, faces, pose, img_size)
|
| 59 |
+
|
| 60 |
+
n_inters = grid_interp(psr_grad[None], (p_inters[None].detach() + 1) / 2).squeeze()
|
| 61 |
+
p_all.append(p_inters)
|
| 62 |
+
n_all.append(n_inters)
|
| 63 |
+
mask_all.append(mask)
|
| 64 |
+
p_inters_all = torch.cat(p_all, dim=0)
|
| 65 |
+
n_inters_all = torch.cat(n_all, dim=0)
|
| 66 |
+
mask_visible = torch.stack(mask_all, dim=0)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
res = torch.tensor(psr_grid.detach().shape[2])
|
| 70 |
+
ctx.save_for_backward(p_inters_all, n_inters_all, res)
|
| 71 |
+
|
| 72 |
+
return p_inters_all, mask_visible
|
| 73 |
+
|
| 74 |
+
@staticmethod
|
| 75 |
+
def backward(ctx, dL_dp, dL_dmask):
|
| 76 |
+
pts, pts_n, res = ctx.saved_tensors
|
| 77 |
+
res = (res.item(), res.item(), res.item())
|
| 78 |
+
|
| 79 |
+
# grad from the p_inters via MLP renderer
|
| 80 |
+
grad_pts = torch.matmul(dL_dp[:, None], -pts_n[..., None])
|
| 81 |
+
grad_grid_pts = point_rasterize((pts[None]+1)/2, grad_pts.permute(1, 0, 2), res) # b x 1 x res x res x res
|
| 82 |
+
|
| 83 |
+
return grad_grid_pts, None, None, None, None, None
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# Resnet Blocks from https://github.com/autonomousvision/shape_as_points/blob/12757682f1075d83738b52f96747463b77343caf/src/network/utils.py
|
| 87 |
+
class ResnetBlockFC(nn.Module):
|
| 88 |
+
''' Fully connected ResNet Block class.
|
| 89 |
+
Args:
|
| 90 |
+
size_in (int): input dimension
|
| 91 |
+
size_out (int): output dimension
|
| 92 |
+
size_h (int): hidden dimension
|
| 93 |
+
'''
|
| 94 |
+
|
| 95 |
+
def __init__(self, size_in, size_out=None, size_h=None, siren=False):
|
| 96 |
+
super().__init__()
|
| 97 |
+
# Attributes
|
| 98 |
+
if size_out is None:
|
| 99 |
+
size_out = size_in
|
| 100 |
+
|
| 101 |
+
if size_h is None:
|
| 102 |
+
size_h = min(size_in, size_out)
|
| 103 |
+
|
| 104 |
+
self.size_in = size_in
|
| 105 |
+
self.size_h = size_h
|
| 106 |
+
self.size_out = size_out
|
| 107 |
+
# Submodules
|
| 108 |
+
self.fc_0 = nn.Linear(size_in, size_h)
|
| 109 |
+
self.fc_1 = nn.Linear(size_h, size_out)
|
| 110 |
+
self.actvn = nn.ReLU()
|
| 111 |
+
|
| 112 |
+
if size_in == size_out:
|
| 113 |
+
self.shortcut = None
|
| 114 |
+
else:
|
| 115 |
+
self.shortcut = nn.Linear(size_in, size_out, bias=False)
|
| 116 |
+
# Initialization
|
| 117 |
+
nn.init.zeros_(self.fc_1.weight)
|
| 118 |
+
|
| 119 |
+
def forward(self, x):
|
| 120 |
+
net = self.fc_0(self.actvn(x))
|
| 121 |
+
dx = self.fc_1(self.actvn(net))
|
| 122 |
+
|
| 123 |
+
if self.shortcut is not None:
|
| 124 |
+
x_s = self.shortcut(x)
|
| 125 |
+
else:
|
| 126 |
+
x_s = x
|
| 127 |
+
|
| 128 |
+
return x_s + dx
|
| 129 |
+
|
models/SAP/utils.py
ADDED
|
@@ -0,0 +1,526 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import io, os, logging, urllib
|
| 3 |
+
import yaml
|
| 4 |
+
import trimesh
|
| 5 |
+
import imageio
|
| 6 |
+
import numbers
|
| 7 |
+
import math
|
| 8 |
+
import numpy as np
|
| 9 |
+
from collections import OrderedDict
|
| 10 |
+
from plyfile import PlyData
|
| 11 |
+
from torch import nn
|
| 12 |
+
from torch.nn import functional as F
|
| 13 |
+
from torch.utils import model_zoo
|
| 14 |
+
from skimage import measure, img_as_float32
|
| 15 |
+
from igl import adjacency_matrix, connected_components
|
| 16 |
+
|
| 17 |
+
##################################################
|
| 18 |
+
# Below are functions for DPSR
|
| 19 |
+
|
| 20 |
+
def fftfreqs(res, dtype=torch.float32, exact=True):
|
| 21 |
+
"""
|
| 22 |
+
Helper function to return frequency tensors
|
| 23 |
+
:param res: n_dims int tuple of number of frequency modes
|
| 24 |
+
:return:
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
n_dims = len(res)
|
| 28 |
+
freqs = []
|
| 29 |
+
for dim in range(n_dims - 1):
|
| 30 |
+
r_ = res[dim]
|
| 31 |
+
freq = np.fft.fftfreq(r_, d=1/r_)
|
| 32 |
+
freqs.append(torch.tensor(freq, dtype=dtype))
|
| 33 |
+
r_ = res[-1]
|
| 34 |
+
if exact:
|
| 35 |
+
freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_), dtype=dtype))
|
| 36 |
+
else:
|
| 37 |
+
freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_)[:-1], dtype=dtype))
|
| 38 |
+
omega = torch.meshgrid(freqs)
|
| 39 |
+
omega = list(omega)
|
| 40 |
+
omega = torch.stack(omega, dim=-1)
|
| 41 |
+
|
| 42 |
+
return omega
|
| 43 |
+
|
| 44 |
+
def img(x, deg=1): # imaginary of tensor (assume last dim: real/imag)
|
| 45 |
+
"""
|
| 46 |
+
multiply tensor x by i ** deg
|
| 47 |
+
"""
|
| 48 |
+
deg %= 4
|
| 49 |
+
if deg == 0:
|
| 50 |
+
res = x
|
| 51 |
+
elif deg == 1:
|
| 52 |
+
res = x[..., [1, 0]]
|
| 53 |
+
res[..., 0] = -res[..., 0]
|
| 54 |
+
elif deg == 2:
|
| 55 |
+
res = -x
|
| 56 |
+
elif deg == 3:
|
| 57 |
+
res = x[..., [1, 0]]
|
| 58 |
+
res[..., 1] = -res[..., 1]
|
| 59 |
+
return res
|
| 60 |
+
|
| 61 |
+
def spec_gaussian_filter(res, sig):
|
| 62 |
+
omega = fftfreqs(res, dtype=torch.float64) # [dim0, dim1, dim2, d]
|
| 63 |
+
dis = torch.sqrt(torch.sum(omega ** 2, dim=-1))
|
| 64 |
+
filter_ = torch.exp(-0.5*((sig*2*dis/res[0])**2)).unsqueeze(-1).unsqueeze(-1)
|
| 65 |
+
filter_.requires_grad = False
|
| 66 |
+
|
| 67 |
+
return filter_
|
| 68 |
+
|
| 69 |
+
def grid_interp(grid, pts, batched=True):
|
| 70 |
+
"""
|
| 71 |
+
:param grid: tensor of shape (batch, *size, in_features)
|
| 72 |
+
:param pts: tensor of shape (batch, num_points, dim) within range (0, 1)
|
| 73 |
+
:return values at query points
|
| 74 |
+
"""
|
| 75 |
+
if not batched:
|
| 76 |
+
grid = grid.unsqueeze(0)
|
| 77 |
+
pts = pts.unsqueeze(0)
|
| 78 |
+
dim = pts.shape[-1]
|
| 79 |
+
bs = grid.shape[0]
|
| 80 |
+
size = torch.tensor(grid.shape[1:-1]).to(grid.device).type(pts.dtype)
|
| 81 |
+
cubesize = 1.0 / size
|
| 82 |
+
|
| 83 |
+
ind0 = torch.floor(pts / cubesize).long() # (batch, num_points, dim)
|
| 84 |
+
ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long() # periodic wrap-around
|
| 85 |
+
ind01 = torch.stack((ind0, ind1), dim=0) # (2, batch, num_points, dim)
|
| 86 |
+
tmp = torch.tensor([0,1],dtype=torch.long)
|
| 87 |
+
com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim)
|
| 88 |
+
dim_ = torch.arange(dim).repeat(com_.shape[0], 1) # (2**dim, dim)
|
| 89 |
+
ind_ = ind01[com_, ..., dim_] # (2**dim, dim, batch, num_points)
|
| 90 |
+
ind_n = ind_.permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
|
| 91 |
+
ind_b = torch.arange(bs).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim)
|
| 92 |
+
# latent code on neighbor nodes
|
| 93 |
+
if dim == 2:
|
| 94 |
+
lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1]] # (batch, num_points, 2**dim, in_features)
|
| 95 |
+
else:
|
| 96 |
+
lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1], ind_n[..., 2]] # (batch, num_points, 2**dim, in_features)
|
| 97 |
+
|
| 98 |
+
# weights of neighboring nodes
|
| 99 |
+
xyz0 = ind0.type(cubesize.dtype) * cubesize # (batch, num_points, dim)
|
| 100 |
+
xyz1 = (ind0.type(cubesize.dtype) + 1) * cubesize # (batch, num_points, dim)
|
| 101 |
+
xyz01 = torch.stack((xyz0, xyz1), dim=0) # (2, batch, num_points, dim)
|
| 102 |
+
pos = xyz01[com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
|
| 103 |
+
pos_ = xyz01[1-com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
|
| 104 |
+
pos_ = pos_.type(pts.dtype)
|
| 105 |
+
dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize # (batch, num_points, 2**dim, dim)
|
| 106 |
+
weights = torch.prod(dxyz_, dim=-1, keepdim=False) # (batch, num_points, 2**dim)
|
| 107 |
+
query_values = torch.sum(lat * weights.unsqueeze(-1), dim=-2) # (batch, num_points, in_features)
|
| 108 |
+
if not batched:
|
| 109 |
+
query_values = query_values.squeeze(0)
|
| 110 |
+
|
| 111 |
+
return query_values
|
| 112 |
+
|
| 113 |
+
def scatter_to_grid(inds, vals, size):
|
| 114 |
+
"""
|
| 115 |
+
Scatter update values into empty tensor of size size.
|
| 116 |
+
:param inds: (#values, dims)
|
| 117 |
+
:param vals: (#values)
|
| 118 |
+
:param size: tuple for size. len(size)=dims
|
| 119 |
+
"""
|
| 120 |
+
dims = inds.shape[1]
|
| 121 |
+
assert(inds.shape[0] == vals.shape[0])
|
| 122 |
+
assert(len(size) == dims)
|
| 123 |
+
dev = vals.device
|
| 124 |
+
# result = torch.zeros(*size).view(-1).to(dev).type(vals.dtype) # flatten
|
| 125 |
+
# # flatten inds
|
| 126 |
+
result = torch.zeros(*size, device=dev).view(-1).type(vals.dtype) # flatten
|
| 127 |
+
# flatten inds
|
| 128 |
+
fac = [np.prod(size[i+1:]) for i in range(len(size)-1)] + [1]
|
| 129 |
+
fac = torch.tensor(fac, device=dev).type(inds.dtype)
|
| 130 |
+
inds_fold = torch.sum(inds*fac, dim=-1) # [#values,]
|
| 131 |
+
result.scatter_add_(0, inds_fold, vals)
|
| 132 |
+
result = result.view(*size)
|
| 133 |
+
return result
|
| 134 |
+
|
| 135 |
+
def point_rasterize(pts, vals, size):
|
| 136 |
+
"""
|
| 137 |
+
:param pts: point coords, tensor of shape (batch, num_points, dim) within range (0, 1)
|
| 138 |
+
:param vals: point values, tensor of shape (batch, num_points, features)
|
| 139 |
+
:param size: len(size)=dim tuple for grid size
|
| 140 |
+
:return rasterized values (batch, features, res0, res1, res2)
|
| 141 |
+
"""
|
| 142 |
+
dim = pts.shape[-1]
|
| 143 |
+
assert(pts.shape[:2] == vals.shape[:2])
|
| 144 |
+
assert(pts.shape[2] == dim)
|
| 145 |
+
size_list = list(size)
|
| 146 |
+
size = torch.tensor(size).to(pts.device).float()
|
| 147 |
+
cubesize = 1.0 / size
|
| 148 |
+
bs = pts.shape[0]
|
| 149 |
+
nf = vals.shape[-1]
|
| 150 |
+
npts = pts.shape[1]
|
| 151 |
+
dev = pts.device
|
| 152 |
+
|
| 153 |
+
ind0 = torch.floor(pts / cubesize).long() # (batch, num_points, dim)
|
| 154 |
+
ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long() # periodic wrap-around
|
| 155 |
+
ind01 = torch.stack((ind0, ind1), dim=0) # (2, batch, num_points, dim)
|
| 156 |
+
tmp = torch.tensor([0,1],dtype=torch.long)
|
| 157 |
+
com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim)
|
| 158 |
+
dim_ = torch.arange(dim).repeat(com_.shape[0], 1) # (2**dim, dim)
|
| 159 |
+
ind_ = ind01[com_, ..., dim_] # (2**dim, dim, batch, num_points)
|
| 160 |
+
ind_n = ind_.permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
|
| 161 |
+
# ind_b = torch.arange(bs).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim)
|
| 162 |
+
ind_b = torch.arange(bs, device=dev).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim)
|
| 163 |
+
|
| 164 |
+
# weights of neighboring nodes
|
| 165 |
+
xyz0 = ind0.type(cubesize.dtype) * cubesize # (batch, num_points, dim)
|
| 166 |
+
xyz1 = (ind0.type(cubesize.dtype) + 1) * cubesize # (batch, num_points, dim)
|
| 167 |
+
xyz01 = torch.stack((xyz0, xyz1), dim=0) # (2, batch, num_points, dim)
|
| 168 |
+
pos = xyz01[com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
|
| 169 |
+
pos_ = xyz01[1-com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
|
| 170 |
+
pos_ = pos_.type(pts.dtype)
|
| 171 |
+
dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize # (batch, num_points, 2**dim, dim)
|
| 172 |
+
weights = torch.prod(dxyz_, dim=-1, keepdim=False) # (batch, num_points, 2**dim)
|
| 173 |
+
|
| 174 |
+
ind_b = ind_b.unsqueeze(-1).unsqueeze(-1) # (batch, num_points, 2**dim, 1, 1)
|
| 175 |
+
ind_n = ind_n.unsqueeze(-2) # (batch, num_points, 2**dim, 1, dim)
|
| 176 |
+
ind_f = torch.arange(nf, device=dev).view(1, 1, 1, nf, 1) # (1, 1, 1, nf, 1)
|
| 177 |
+
# ind_f = torch.arange(nf).view(1, 1, 1, nf, 1) # (1, 1, 1, nf, 1)
|
| 178 |
+
|
| 179 |
+
ind_b = ind_b.expand(bs, npts, 2**dim, nf, 1)
|
| 180 |
+
ind_n = ind_n.expand(bs, npts, 2**dim, nf, dim).to(dev)
|
| 181 |
+
ind_f = ind_f.expand(bs, npts, 2**dim, nf, 1)
|
| 182 |
+
inds = torch.cat([ind_b, ind_f, ind_n], dim=-1) # (batch, num_points, 2**dim, nf, 1+1+dim)
|
| 183 |
+
|
| 184 |
+
# weighted values
|
| 185 |
+
vals = weights.unsqueeze(-1) * vals.unsqueeze(-2) # (batch, num_points, 2**dim, nf)
|
| 186 |
+
|
| 187 |
+
inds = inds.view(-1, dim+2).permute(1, 0).long() # (1+dim+1, bs*npts*2**dim*nf)
|
| 188 |
+
vals = vals.reshape(-1) # (bs*npts*2**dim*nf)
|
| 189 |
+
tensor_size = [bs, nf] + size_list
|
| 190 |
+
raster = scatter_to_grid(inds.permute(1, 0), vals, [bs, nf] + size_list)
|
| 191 |
+
|
| 192 |
+
return raster # [batch, nf, res, res, res]
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
##################################################
|
| 197 |
+
# Below are the utilization functions in general
|
| 198 |
+
|
| 199 |
+
class AverageMeter(object):
|
| 200 |
+
"""Computes and stores the average and current value"""
|
| 201 |
+
def __init__(self):
|
| 202 |
+
self.reset()
|
| 203 |
+
|
| 204 |
+
def reset(self):
|
| 205 |
+
self.val = 0
|
| 206 |
+
self.n = 0
|
| 207 |
+
self.avg = 0
|
| 208 |
+
self.sum = 0
|
| 209 |
+
self.count = 0
|
| 210 |
+
|
| 211 |
+
def update(self, val, n=1):
|
| 212 |
+
self.val = val
|
| 213 |
+
self.n = n
|
| 214 |
+
self.sum += val * n
|
| 215 |
+
self.count += n
|
| 216 |
+
self.avg = self.sum / self.count
|
| 217 |
+
|
| 218 |
+
@property
|
| 219 |
+
def valcavg(self):
|
| 220 |
+
return self.val.sum().item() / (self.n != 0).sum().item()
|
| 221 |
+
|
| 222 |
+
@property
|
| 223 |
+
def avgcavg(self):
|
| 224 |
+
return self.avg.sum().item() / (self.count != 0).sum().item()
|
| 225 |
+
|
| 226 |
+
def load_model_manual(state_dict, model):
|
| 227 |
+
new_state_dict = OrderedDict()
|
| 228 |
+
is_model_parallel = isinstance(model, torch.nn.DataParallel)
|
| 229 |
+
for k, v in state_dict.items():
|
| 230 |
+
if k.startswith('module.') != is_model_parallel:
|
| 231 |
+
if k.startswith('module.'):
|
| 232 |
+
# remove module
|
| 233 |
+
k = k[7:]
|
| 234 |
+
else:
|
| 235 |
+
# add module
|
| 236 |
+
k = 'module.' + k
|
| 237 |
+
|
| 238 |
+
new_state_dict[k]=v
|
| 239 |
+
|
| 240 |
+
model.load_state_dict(new_state_dict)
|
| 241 |
+
|
| 242 |
+
def mc_from_psr(psr_grid, pytorchify=False, real_scale=False, zero_level=0):
|
| 243 |
+
'''
|
| 244 |
+
Run marching cubes from PSR grid
|
| 245 |
+
'''
|
| 246 |
+
batch_size = psr_grid.shape[0]
|
| 247 |
+
s = psr_grid.shape[-1] # size of psr_grid
|
| 248 |
+
psr_grid_numpy = psr_grid.squeeze().detach().cpu().numpy()
|
| 249 |
+
|
| 250 |
+
if batch_size>1:
|
| 251 |
+
verts, faces, normals = [], [], []
|
| 252 |
+
for i in range(batch_size):
|
| 253 |
+
verts_cur, faces_cur, normals_cur, values = measure.marching_cubes(psr_grid_numpy[i], level=0)
|
| 254 |
+
verts.append(verts_cur)
|
| 255 |
+
faces.append(faces_cur)
|
| 256 |
+
normals.append(normals_cur)
|
| 257 |
+
verts = np.stack(verts, axis = 0)
|
| 258 |
+
faces = np.stack(faces, axis = 0)
|
| 259 |
+
normals = np.stack(normals, axis = 0)
|
| 260 |
+
else:
|
| 261 |
+
try:
|
| 262 |
+
verts, faces, normals, values = measure.marching_cubes(psr_grid_numpy, level=zero_level)
|
| 263 |
+
except:
|
| 264 |
+
verts, faces, normals, values = measure.marching_cubes(psr_grid_numpy)
|
| 265 |
+
if real_scale:
|
| 266 |
+
verts = verts / (s-1) # scale to range [0, 1]
|
| 267 |
+
else:
|
| 268 |
+
verts = verts / s # scale to range [0, 1)
|
| 269 |
+
|
| 270 |
+
if pytorchify:
|
| 271 |
+
device = psr_grid.device
|
| 272 |
+
verts = torch.Tensor(np.ascontiguousarray(verts)).to(device)
|
| 273 |
+
faces = torch.Tensor(np.ascontiguousarray(faces)).to(device)
|
| 274 |
+
normals = torch.Tensor(np.ascontiguousarray(-normals)).to(device)
|
| 275 |
+
|
| 276 |
+
return verts, faces, normals
|
| 277 |
+
|
| 278 |
+
def calc_inters_points(verts, faces, pose, img_size, mask_gt=None):
|
| 279 |
+
verts = verts.squeeze()
|
| 280 |
+
faces = faces.squeeze()
|
| 281 |
+
pix_to_face, w, mask = mesh_rasterization(verts, faces, pose, img_size)
|
| 282 |
+
if mask_gt is not None:
|
| 283 |
+
#! only evaluate within the intersection
|
| 284 |
+
mask = mask & mask_gt
|
| 285 |
+
# find 3D points intesected on the mesh
|
| 286 |
+
if True:
|
| 287 |
+
w_masked = w[mask]
|
| 288 |
+
f_p = faces[pix_to_face[mask]].long() # cooresponding faces for each pixel
|
| 289 |
+
# corresponding vertices for p_closest
|
| 290 |
+
v_a, v_b, v_c = verts[f_p[..., 0]], verts[f_p[..., 1]], verts[f_p[..., 2]]
|
| 291 |
+
|
| 292 |
+
# calculate the intersection point of each pixel and the mesh
|
| 293 |
+
p_inters = w_masked[..., 0, None] * v_a + \
|
| 294 |
+
w_masked[..., 1, None] * v_b + \
|
| 295 |
+
w_masked[..., 2, None] * v_c
|
| 296 |
+
else:
|
| 297 |
+
# backproject ndc to world coordinates using z-buffer
|
| 298 |
+
W, H = img_size[1], img_size[0]
|
| 299 |
+
xy = uv.to(mask.device)[mask]
|
| 300 |
+
x_ndc = 1 - (2*xy[:, 0]) / (W - 1)
|
| 301 |
+
y_ndc = 1 - (2*xy[:, 1]) / (H - 1)
|
| 302 |
+
z = zbuf.squeeze().reshape(H * W)[mask]
|
| 303 |
+
xy_depth = torch.stack((x_ndc, y_ndc, z), dim=1)
|
| 304 |
+
|
| 305 |
+
p_inters = pose.unproject_points(xy_depth, world_coordinates=True)
|
| 306 |
+
|
| 307 |
+
# if there are outlier points, we should remove it
|
| 308 |
+
if (p_inters.max()>1) | (p_inters.min()<-1):
|
| 309 |
+
mask_bound = (p_inters>=-1) & (p_inters<=1)
|
| 310 |
+
mask_bound = (mask_bound.sum(dim=-1)==3)
|
| 311 |
+
mask[mask==True] = mask_bound
|
| 312 |
+
p_inters = p_inters[mask_bound]
|
| 313 |
+
print('!!!!!find outlier!')
|
| 314 |
+
|
| 315 |
+
return p_inters, mask, f_p, w_masked
|
| 316 |
+
|
| 317 |
+
def mesh_rasterization(verts, faces, pose, img_size):
|
| 318 |
+
'''
|
| 319 |
+
Use PyTorch3D to rasterize the mesh given a camera
|
| 320 |
+
'''
|
| 321 |
+
transformed_v = pose.transform_points(verts.detach()) # world -> ndc coordinate system
|
| 322 |
+
if isinstance(pose, PerspectiveCameras):
|
| 323 |
+
transformed_v[..., 2] = 1/transformed_v[..., 2]
|
| 324 |
+
# find p_closest on mesh of each pixel via rasterization
|
| 325 |
+
transformed_mesh = Meshes(verts=[transformed_v], faces=[faces])
|
| 326 |
+
pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
|
| 327 |
+
transformed_mesh,
|
| 328 |
+
image_size=img_size,
|
| 329 |
+
blur_radius=0,
|
| 330 |
+
faces_per_pixel=1,
|
| 331 |
+
perspective_correct=False
|
| 332 |
+
)
|
| 333 |
+
pix_to_face = pix_to_face.reshape(1, -1) # B x reso x reso -> B x (reso x reso)
|
| 334 |
+
mask = pix_to_face.clone() != -1
|
| 335 |
+
mask = mask.squeeze()
|
| 336 |
+
pix_to_face = pix_to_face.squeeze()
|
| 337 |
+
w = bary_coords.reshape(-1, 3)
|
| 338 |
+
|
| 339 |
+
return pix_to_face, w, mask
|
| 340 |
+
|
| 341 |
+
def verts_on_largest_mesh(verts, faces):
|
| 342 |
+
'''
|
| 343 |
+
verts: Numpy array or Torch.Tensor (N, 3)
|
| 344 |
+
faces: Numpy array (N, 3)
|
| 345 |
+
'''
|
| 346 |
+
if torch.is_tensor(faces):
|
| 347 |
+
verts = verts.squeeze().detach().cpu().numpy()
|
| 348 |
+
faces = faces.squeeze().int().detach().cpu().numpy()
|
| 349 |
+
|
| 350 |
+
A = adjacency_matrix(faces)
|
| 351 |
+
num, conn_idx, conn_size = connected_components(A)
|
| 352 |
+
if num == 0:
|
| 353 |
+
v_large, f_large = verts, faces
|
| 354 |
+
else:
|
| 355 |
+
max_idx = conn_size.argmax() # find the index of the largest component
|
| 356 |
+
v_large = verts[conn_idx==max_idx] # keep points on the largest component
|
| 357 |
+
|
| 358 |
+
if True:
|
| 359 |
+
mesh_largest = trimesh.Trimesh(verts, faces)
|
| 360 |
+
connected_comp = mesh_largest.split(only_watertight=False)
|
| 361 |
+
mesh_largest = connected_comp[max_idx]
|
| 362 |
+
v_large, f_large = mesh_largest.vertices, mesh_largest.faces
|
| 363 |
+
v_large = v_large.astype(np.float32)
|
| 364 |
+
return v_large, f_large
|
| 365 |
+
|
| 366 |
+
def update_recursive(dict1, dict2):
|
| 367 |
+
''' Update two config dictionaries recursively.
|
| 368 |
+
|
| 369 |
+
Args:
|
| 370 |
+
dict1 (dict): first dictionary to be updated
|
| 371 |
+
dict2 (dict): second dictionary which entries should be used
|
| 372 |
+
|
| 373 |
+
'''
|
| 374 |
+
for k, v in dict2.items():
|
| 375 |
+
if k not in dict1:
|
| 376 |
+
dict1[k] = dict()
|
| 377 |
+
if isinstance(v, dict):
|
| 378 |
+
update_recursive(dict1[k], v)
|
| 379 |
+
else:
|
| 380 |
+
dict1[k] = v
|
| 381 |
+
|
| 382 |
+
def scale2onet(p, scale=1.2):
|
| 383 |
+
'''
|
| 384 |
+
Scale the point cloud from SAP to ONet range
|
| 385 |
+
'''
|
| 386 |
+
return (p - 0.5) * scale
|
| 387 |
+
|
| 388 |
+
def update_optimizer(inputs, cfg, epoch, model=None, schedule=None):
|
| 389 |
+
if model is not None:
|
| 390 |
+
if schedule is not None:
|
| 391 |
+
optimizer = torch.optim.Adam([
|
| 392 |
+
{"params": model.parameters(),
|
| 393 |
+
"lr": schedule[0].get_learning_rate(epoch)},
|
| 394 |
+
{"params": inputs,
|
| 395 |
+
"lr": schedule[1].get_learning_rate(epoch)}])
|
| 396 |
+
elif 'lr' in cfg['train']:
|
| 397 |
+
optimizer = torch.optim.Adam([
|
| 398 |
+
{"params": model.parameters(),
|
| 399 |
+
"lr": float(cfg['train']['lr'])},
|
| 400 |
+
{"params": inputs,
|
| 401 |
+
"lr": float(cfg['train']['lr_pcl'])}])
|
| 402 |
+
else:
|
| 403 |
+
raise Exception('no known learning rate')
|
| 404 |
+
else:
|
| 405 |
+
if schedule is not None:
|
| 406 |
+
optimizer = torch.optim.Adam([inputs], lr=schedule[0].get_learning_rate(epoch))
|
| 407 |
+
else:
|
| 408 |
+
optimizer = torch.optim.Adam([inputs], lr=float(cfg['train']['lr_pcl']))
|
| 409 |
+
|
| 410 |
+
return optimizer
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def is_url(url):
|
| 414 |
+
scheme = urllib.parse.urlparse(url).scheme
|
| 415 |
+
return scheme in ('http', 'https')
|
| 416 |
+
|
| 417 |
+
def load_url(url):
|
| 418 |
+
'''Load a module dictionary from url.
|
| 419 |
+
|
| 420 |
+
Args:
|
| 421 |
+
url (str): url to saved model
|
| 422 |
+
'''
|
| 423 |
+
print(url)
|
| 424 |
+
print('=> Loading checkpoint from url...')
|
| 425 |
+
state_dict = model_zoo.load_url(url, progress=True)
|
| 426 |
+
|
| 427 |
+
return state_dict
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
class GaussianSmoothing(nn.Module):
|
| 431 |
+
"""
|
| 432 |
+
Apply gaussian smoothing on a
|
| 433 |
+
1d, 2d or 3d tensor. Filtering is performed seperately for each channel
|
| 434 |
+
in the input using a depthwise convolution.
|
| 435 |
+
Arguments:
|
| 436 |
+
channels (int, sequence): Number of channels of the input tensors. Output will have this number of channels as well.
|
| 437 |
+
kernel_size (int, sequence): Size of the gaussian kernel.
|
| 438 |
+
sigma (float, sequence): Standard deviation of the gaussian kernel.
|
| 439 |
+
dim (int, optional): The number of dimensions of the data.
|
| 440 |
+
Default value is 2 (spatial).
|
| 441 |
+
"""
|
| 442 |
+
def __init__(self, channels, kernel_size, sigma, dim=3):
|
| 443 |
+
super(GaussianSmoothing, self).__init__()
|
| 444 |
+
if isinstance(kernel_size, numbers.Number):
|
| 445 |
+
kernel_size = [kernel_size] * dim
|
| 446 |
+
if isinstance(sigma, numbers.Number):
|
| 447 |
+
sigma = [sigma] * dim
|
| 448 |
+
|
| 449 |
+
# The gaussian kernel is the product of the
|
| 450 |
+
# gaussian function of each dimension.
|
| 451 |
+
kernel = 1
|
| 452 |
+
meshgrids = torch.meshgrid(
|
| 453 |
+
[
|
| 454 |
+
torch.arange(size, dtype=torch.float32)
|
| 455 |
+
for size in kernel_size
|
| 456 |
+
]
|
| 457 |
+
)
|
| 458 |
+
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
|
| 459 |
+
mean = (size - 1) / 2
|
| 460 |
+
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
|
| 461 |
+
torch.exp(-((mgrid - mean) / std) ** 2 / 2)
|
| 462 |
+
|
| 463 |
+
# Make sure sum of values in gaussian kernel equals 1.
|
| 464 |
+
kernel = kernel / torch.sum(kernel)
|
| 465 |
+
|
| 466 |
+
# Reshape to depthwise convolutional weight
|
| 467 |
+
kernel = kernel.view(1, 1, *kernel.size())
|
| 468 |
+
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
|
| 469 |
+
|
| 470 |
+
self.register_buffer('weight', kernel)
|
| 471 |
+
self.groups = channels
|
| 472 |
+
|
| 473 |
+
if dim == 1:
|
| 474 |
+
self.conv = F.conv1d
|
| 475 |
+
elif dim == 2:
|
| 476 |
+
self.conv = F.conv2d
|
| 477 |
+
elif dim == 3:
|
| 478 |
+
self.conv = F.conv3d
|
| 479 |
+
else:
|
| 480 |
+
raise RuntimeError(
|
| 481 |
+
'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
|
| 482 |
+
)
|
| 483 |
+
|
| 484 |
+
def forward(self, input):
|
| 485 |
+
"""
|
| 486 |
+
Apply gaussian filter to input.
|
| 487 |
+
Arguments:
|
| 488 |
+
input (torch.Tensor): Input to apply gaussian filter on.
|
| 489 |
+
Returns:
|
| 490 |
+
filtered (torch.Tensor): Filtered output.
|
| 491 |
+
"""
|
| 492 |
+
return self.conv(input, weight=self.weight, groups=self.groups)
|
| 493 |
+
|
| 494 |
+
# Originally from https://github.com/amosgropp/IGR/blob/0db06b1273/code/utils/general.py
|
| 495 |
+
def get_learning_rate_schedules(schedule_specs):
|
| 496 |
+
|
| 497 |
+
schedules = []
|
| 498 |
+
|
| 499 |
+
for key in schedule_specs.keys():
|
| 500 |
+
schedules.append(StepLearningRateSchedule(
|
| 501 |
+
schedule_specs[key]['initial'],
|
| 502 |
+
schedule_specs[key]["interval"],
|
| 503 |
+
schedule_specs[key]["factor"],
|
| 504 |
+
schedule_specs[key]["final"]))
|
| 505 |
+
return schedules
|
| 506 |
+
|
| 507 |
+
class LearningRateSchedule:
|
| 508 |
+
def get_learning_rate(self, epoch):
|
| 509 |
+
pass
|
| 510 |
+
class StepLearningRateSchedule(LearningRateSchedule):
|
| 511 |
+
def __init__(self, initial, interval, factor, final=1e-6):
|
| 512 |
+
self.initial = float(initial)
|
| 513 |
+
self.interval = interval
|
| 514 |
+
self.factor = factor
|
| 515 |
+
self.final = float(final)
|
| 516 |
+
|
| 517 |
+
def get_learning_rate(self, epoch):
|
| 518 |
+
lr = np.maximum(self.initial * (self.factor ** (epoch // self.interval)), 5.0e-6)
|
| 519 |
+
if lr > self.final:
|
| 520 |
+
return lr
|
| 521 |
+
else:
|
| 522 |
+
return self.final
|
| 523 |
+
|
| 524 |
+
def adjust_learning_rate(lr_schedules, optimizer, epoch):
|
| 525 |
+
for i, param_group in enumerate(optimizer.param_groups):
|
| 526 |
+
param_group["lr"] = lr_schedules[i].get_learning_rate(epoch)
|
models/__init__.py
ADDED
|
File without changes
|
models/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (129 Bytes). View file
|
|
|
models/__pycache__/model.cpython-39.pyc
ADDED
|
Binary file (5.56 kB). View file
|
|
|
models/model.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from statistics import mean
|
| 3 |
+
from typing import List, Tuple
|
| 4 |
+
|
| 5 |
+
import torch as th
|
| 6 |
+
import pytorch_lightning as pl
|
| 7 |
+
from jaxtyping import Float, Int
|
| 8 |
+
import numpy as np
|
| 9 |
+
from torch_geometric.nn.conv import GATv2Conv
|
| 10 |
+
|
| 11 |
+
from models.SAP.dpsr import DPSR
|
| 12 |
+
from models.SAP.model import PSR2Mesh
|
| 13 |
+
|
| 14 |
+
# Constants
|
| 15 |
+
|
| 16 |
+
th.manual_seed(0)
|
| 17 |
+
np.random.seed(0)
|
| 18 |
+
|
| 19 |
+
BATCH_SIZE = 1 # BS
|
| 20 |
+
|
| 21 |
+
IN_DIM = 1
|
| 22 |
+
OUT_DIM = 1
|
| 23 |
+
LATENT_DIM = 32
|
| 24 |
+
|
| 25 |
+
DROPOUT_PROB = 0.1
|
| 26 |
+
GRID_SIZE = 128
|
| 27 |
+
|
| 28 |
+
def generate_grid_edge_list(gs: int = 128):
|
| 29 |
+
grid_edge_list = []
|
| 30 |
+
|
| 31 |
+
for k in range(gs):
|
| 32 |
+
for j in range(gs):
|
| 33 |
+
for i in range(gs):
|
| 34 |
+
current_idx = i + gs*j + k*gs*gs
|
| 35 |
+
if (i - 1) >= 0:
|
| 36 |
+
grid_edge_list.append([current_idx, i-1 + gs*j + k*gs*gs])
|
| 37 |
+
if (i + 1) < gs:
|
| 38 |
+
grid_edge_list.append([current_idx, i+1 + gs*j + k*gs*gs])
|
| 39 |
+
if (j - 1) >= 0:
|
| 40 |
+
grid_edge_list.append([current_idx, i + gs*(j-1) + k*gs*gs])
|
| 41 |
+
if (j + 1) < gs:
|
| 42 |
+
grid_edge_list.append([current_idx, i + gs*(j+1) + k*gs*gs])
|
| 43 |
+
if (k - 1) >= 0:
|
| 44 |
+
grid_edge_list.append([current_idx, i + gs*j + (k-1)*gs*gs])
|
| 45 |
+
if (k + 1) < gs:
|
| 46 |
+
grid_edge_list.append([current_idx, i + gs*j + (k+1)*gs*gs])
|
| 47 |
+
return grid_edge_list
|
| 48 |
+
|
| 49 |
+
GRID_EDGE_LIST = generate_grid_edge_list(GRID_SIZE)
|
| 50 |
+
GRID_EDGE_LIST = th.tensor(GRID_EDGE_LIST, dtype=th.int)
|
| 51 |
+
GRID_EDGE_LIST = GRID_EDGE_LIST.T
|
| 52 |
+
# GRID_EDGE_LIST = GRID_EDGE_LIST.to(th.device("cuda"))
|
| 53 |
+
GRID_EDGE_LIST.requires_grad = False # Do not forget to delete it if train
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class FormOptimizer(th.nn.Module):
|
| 57 |
+
def __init__(self) -> None:
|
| 58 |
+
super().__init__()
|
| 59 |
+
|
| 60 |
+
layers = []
|
| 61 |
+
|
| 62 |
+
self.gconv1 = GATv2Conv(in_channels=IN_DIM, out_channels=LATENT_DIM, heads=1, dropout=DROPOUT_PROB)
|
| 63 |
+
self.gconv2 = GATv2Conv(in_channels=LATENT_DIM, out_channels=LATENT_DIM, heads=1, dropout=DROPOUT_PROB)
|
| 64 |
+
|
| 65 |
+
self.actv = th.nn.Sigmoid()
|
| 66 |
+
self.head = th.nn.Linear(in_features=LATENT_DIM, out_features=OUT_DIM)
|
| 67 |
+
|
| 68 |
+
def forward(self,
|
| 69 |
+
field: Float[th.Tensor, "GS GS GS"]) -> Float[th.Tensor, "GS GS GS"]:
|
| 70 |
+
"""
|
| 71 |
+
Args:
|
| 72 |
+
field (Tensor [GS, GS, GS]): vertices and normals tensor.
|
| 73 |
+
"""
|
| 74 |
+
vertex_features = field.clone()
|
| 75 |
+
vertex_features = vertex_features.reshape(GRID_SIZE*GRID_SIZE*GRID_SIZE, IN_DIM)
|
| 76 |
+
|
| 77 |
+
vertex_features = self.gconv1(x=vertex_features, edge_index=GRID_EDGE_LIST)
|
| 78 |
+
vertex_features = self.gconv2(x=vertex_features, edge_index=GRID_EDGE_LIST)
|
| 79 |
+
field_delta = self.head(self.actv(vertex_features))
|
| 80 |
+
|
| 81 |
+
field_delta = field_delta.reshape(BATCH_SIZE, GRID_SIZE, GRID_SIZE, GRID_SIZE)
|
| 82 |
+
field_delta += field # field_delta carries the gradient
|
| 83 |
+
field_delta = th.clamp(field_delta, min=-0.5, max=0.5)
|
| 84 |
+
|
| 85 |
+
return field_delta
|
| 86 |
+
|
| 87 |
+
class Model(pl.LightningModule):
|
| 88 |
+
def __init__(self):
|
| 89 |
+
super().__init__()
|
| 90 |
+
self.form_optimizer = FormOptimizer()
|
| 91 |
+
|
| 92 |
+
self.dpsr = DPSR([GRID_SIZE, GRID_SIZE, GRID_SIZE], sig=0.0)
|
| 93 |
+
self.field2mesh = PSR2Mesh().apply
|
| 94 |
+
|
| 95 |
+
self.metric = th.nn.MSELoss()
|
| 96 |
+
|
| 97 |
+
self.val_losses = []
|
| 98 |
+
self.train_losses = []
|
| 99 |
+
|
| 100 |
+
def log_h5(self, points, normals):
|
| 101 |
+
dset = self.log_points_file.create_dataset(
|
| 102 |
+
name=str(self.h5_frame),
|
| 103 |
+
shape=points.shape,
|
| 104 |
+
dtype=np.float16,
|
| 105 |
+
compression="gzip")
|
| 106 |
+
dset[:] = points
|
| 107 |
+
dset = self.log_normals_file.create_dataset(
|
| 108 |
+
name=str(self.h5_frame),
|
| 109 |
+
shape=normals.shape,
|
| 110 |
+
dtype=np.float16,
|
| 111 |
+
compression="gzip")
|
| 112 |
+
dset[:] = normals
|
| 113 |
+
self.h5_frame += 1
|
| 114 |
+
|
| 115 |
+
def forward(self,
|
| 116 |
+
v: Float[th.Tensor, "BS N 3"],
|
| 117 |
+
n: Float[th.Tensor, "BS N 3"]) -> Tuple[Float[th.Tensor, "BS N 3"], # v - vertices
|
| 118 |
+
Int[th.Tensor, "2 E"], # f - faces
|
| 119 |
+
Float[th.Tensor, "BS N 3"], # n - vertices normals
|
| 120 |
+
Float[th.Tensor, "BS GR GR GR"]]: # field:
|
| 121 |
+
field = self.dpsr(v, n)
|
| 122 |
+
field = self.form_optimizer(field)
|
| 123 |
+
v, f, n = self.field2mesh(field)
|
| 124 |
+
return v, f, n, field
|
| 125 |
+
|
| 126 |
+
def training_step(self, batch, batch_idx) -> Float[th.Tensor, "1"]:
|
| 127 |
+
vertices, vertices_normals, vertices_gt, vertices_normals_gt, field_gt, adj = batch
|
| 128 |
+
|
| 129 |
+
mask = th.rand((vertices.shape[1], ), device=th.device("cuda")) < (random.random() / 2.0 + 0.5)
|
| 130 |
+
vertices = vertices[:, mask]
|
| 131 |
+
vertices_normals = vertices_normals[:, mask]
|
| 132 |
+
|
| 133 |
+
vr, fr, nr, field_r = model(vertices, vertices_normals)
|
| 134 |
+
|
| 135 |
+
loss = self.metric(field_r, field_gt)
|
| 136 |
+
train_per_step_loss = loss.item()
|
| 137 |
+
self.train_losses.append(train_per_step_loss)
|
| 138 |
+
|
| 139 |
+
return loss
|
| 140 |
+
|
| 141 |
+
def on_train_epoch_end(self):
|
| 142 |
+
mean_train_per_epoch_loss = mean(self.train_losses)
|
| 143 |
+
self.log("mean_train_per_epoch_loss", mean_train_per_epoch_loss, on_step=False, on_epoch=True)
|
| 144 |
+
self.train_losses = []
|
| 145 |
+
|
| 146 |
+
def validation_step(self, batch, batch_idx):
|
| 147 |
+
vertices, vertices_normals, vertices_gt, vertices_normals_gt, field_gt, adj = batch
|
| 148 |
+
|
| 149 |
+
vr, fr, nr, field_r = model(vertices, vertices_normals)
|
| 150 |
+
|
| 151 |
+
loss = self.metric(field_r, field_gt)
|
| 152 |
+
val_per_step_loss = loss.item()
|
| 153 |
+
self.val_losses.append(val_per_step_loss)
|
| 154 |
+
return loss
|
| 155 |
+
|
| 156 |
+
def on_validation_epoch_end(self):
|
| 157 |
+
mean_val_per_epoch_loss = mean(self.val_losses)
|
| 158 |
+
self.log("mean_val_per_epoch_loss", mean_val_per_epoch_loss, on_step=False, on_epoch=True)
|
| 159 |
+
self.val_losses = []
|
| 160 |
+
|
| 161 |
+
def configure_optimizers(self):
|
| 162 |
+
optimizer = th.optim.Adam(self.parameters(), lr=LR)
|
| 163 |
+
scheduler = th.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
|
| 164 |
+
|
| 165 |
+
return {
|
| 166 |
+
"optimizer": optimizer,
|
| 167 |
+
"lr_scheduler": {
|
| 168 |
+
"scheduler": scheduler,
|
| 169 |
+
"monitor": "mean_val_per_epoch_loss",
|
| 170 |
+
"interval": "epoch",
|
| 171 |
+
"frequency": 1,
|
| 172 |
+
# If set to `True`, will enforce that the value specified 'monitor'
|
| 173 |
+
# is available when the scheduler is updated, thus stopping
|
| 174 |
+
# training if not found. If set to `False`, it will only produce a warning
|
| 175 |
+
"strict": True,
|
| 176 |
+
# If using the `LearningRateMonitor` callback to monitor the
|
| 177 |
+
# learning rate progress, this keyword can be used to specify
|
| 178 |
+
# a custom logged name
|
| 179 |
+
"name": None,
|
| 180 |
+
}
|
| 181 |
+
}
|
requirements.txt
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
aiohttp==3.8.4
|
| 2 |
+
aiosignal==1.3.1
|
| 3 |
+
ansicon==1.89.0
|
| 4 |
+
anyio==3.6.2
|
| 5 |
+
arrow==1.2.3
|
| 6 |
+
asttokens==2.2.1
|
| 7 |
+
async-timeout==4.0.2
|
| 8 |
+
attrs==23.1.0
|
| 9 |
+
backcall==0.2.0
|
| 10 |
+
beautifulsoup4==4.12.2
|
| 11 |
+
blessed==1.20.0
|
| 12 |
+
blinker==1.6.2
|
| 13 |
+
certifi==2023.5.7
|
| 14 |
+
charset-normalizer==3.1.0
|
| 15 |
+
click==8.1.3
|
| 16 |
+
colorama==0.4.6
|
| 17 |
+
comm==0.1.3
|
| 18 |
+
ConfigArgParse==1.5.3
|
| 19 |
+
croniter==1.3.14
|
| 20 |
+
dash==2.9.3
|
| 21 |
+
dash-core-components==2.0.0
|
| 22 |
+
dash-html-components==2.0.0
|
| 23 |
+
dash-table==5.0.0
|
| 24 |
+
dateutils==0.6.12
|
| 25 |
+
debugpy==1.6.7
|
| 26 |
+
decorator==5.1.1
|
| 27 |
+
deepdiff==6.3.0
|
| 28 |
+
executing==1.2.0
|
| 29 |
+
fastapi==0.88.0
|
| 30 |
+
fastjsonschema==2.16.3
|
| 31 |
+
Flask==2.3.2
|
| 32 |
+
frozenlist==1.3.3
|
| 33 |
+
fsspec==2023.5.0
|
| 34 |
+
fvcore==0.1.5.post20221221
|
| 35 |
+
h11==0.14.0
|
| 36 |
+
idna==3.4
|
| 37 |
+
imageio==2.28.1
|
| 38 |
+
importlib-metadata==6.6.0
|
| 39 |
+
inquirer==3.1.3
|
| 40 |
+
iopath==0.1.10
|
| 41 |
+
ipykernel==6.23.1
|
| 42 |
+
ipython==8.13.2
|
| 43 |
+
ipywidgets==8.0.6
|
| 44 |
+
itsdangerous==2.1.2
|
| 45 |
+
jaxtyping==0.2.19
|
| 46 |
+
jedi==0.18.2
|
| 47 |
+
Jinja2==3.1.2
|
| 48 |
+
jinxed==1.2.0
|
| 49 |
+
joblib==1.2.0
|
| 50 |
+
jsonschema==4.17.3
|
| 51 |
+
jupyter_client==8.2.0
|
| 52 |
+
jupyter_core==5.3.0
|
| 53 |
+
jupyterlab-widgets==3.0.7
|
| 54 |
+
lazy_loader==0.2
|
| 55 |
+
libigl==2.4.1
|
| 56 |
+
lightning==2.0.2
|
| 57 |
+
lightning-cloud==0.5.36
|
| 58 |
+
lightning-utilities==0.8.0
|
| 59 |
+
markdown-it-py==2.2.0
|
| 60 |
+
MarkupSafe==2.1.2
|
| 61 |
+
matplotlib-inline==0.1.6
|
| 62 |
+
mdurl==0.1.2
|
| 63 |
+
multidict==6.0.4
|
| 64 |
+
nbformat==5.7.0
|
| 65 |
+
nest-asyncio==1.5.6
|
| 66 |
+
networkx==3.1
|
| 67 |
+
numpy==1.24.3
|
| 68 |
+
open3d==0.17.0
|
| 69 |
+
ordered-set==4.1.0
|
| 70 |
+
packaging==23.1
|
| 71 |
+
parso==0.8.3
|
| 72 |
+
pickleshare==0.7.5
|
| 73 |
+
Pillow==9.5.0
|
| 74 |
+
platformdirs==3.5.1
|
| 75 |
+
plotly==5.14.1
|
| 76 |
+
plyfile==0.9
|
| 77 |
+
portalocker==2.7.0
|
| 78 |
+
prompt-toolkit==3.0.38
|
| 79 |
+
psutil==5.9.5
|
| 80 |
+
pure-eval==0.2.2
|
| 81 |
+
pydantic==1.10.7
|
| 82 |
+
Pygments==2.15.1
|
| 83 |
+
PyJWT==2.7.0
|
| 84 |
+
pyparsing==3.0.9
|
| 85 |
+
pyrsistent==0.19.3
|
| 86 |
+
PySimpleGUI==4.60.4
|
| 87 |
+
python-dateutil==2.8.2
|
| 88 |
+
python-editor==1.0.4
|
| 89 |
+
python-multipart==0.0.6
|
| 90 |
+
pytorch-lightning==2.0.2
|
| 91 |
+
pytz==2023.3
|
| 92 |
+
PyWavelets==1.4.1
|
| 93 |
+
pywin32==306
|
| 94 |
+
PyYAML==6.0
|
| 95 |
+
pyzmq==25.0.2
|
| 96 |
+
readchar==4.0.5
|
| 97 |
+
requests==2.30.0
|
| 98 |
+
rich==13.3.5
|
| 99 |
+
scikit-image==0.20.0
|
| 100 |
+
scikit-learn==1.2.2
|
| 101 |
+
scipy==1.9.1
|
| 102 |
+
six==1.16.0
|
| 103 |
+
sniffio==1.3.0
|
| 104 |
+
soupsieve==2.4.1
|
| 105 |
+
stack-data==0.6.2
|
| 106 |
+
starlette==0.22.0
|
| 107 |
+
starsessions==1.3.0
|
| 108 |
+
tabulate==0.9.0
|
| 109 |
+
tenacity==8.2.2
|
| 110 |
+
termcolor==2.3.0
|
| 111 |
+
threadpoolctl==3.1.0
|
| 112 |
+
tifffile==2023.4.12
|
| 113 |
+
torch==1.13.1+cu116
|
| 114 |
+
torch-cluster==1.6.1+pt113cu116
|
| 115 |
+
torch-geometric==2.3.1
|
| 116 |
+
torch-scatter==2.1.1+pt113cu116
|
| 117 |
+
torch-sparse==0.6.17+pt113cu116
|
| 118 |
+
torch-spline-conv==1.2.2+pt113cu116
|
| 119 |
+
torchaudio==0.13.1
|
| 120 |
+
torchmetrics==0.11.4
|
| 121 |
+
torchvision==0.14.1+cu116
|
| 122 |
+
tornado==6.3.2
|
| 123 |
+
tqdm==4.65.0
|
| 124 |
+
traitlets==5.9.0
|
| 125 |
+
trimesh==3.21.6
|
| 126 |
+
typeguard==4.0.0
|
| 127 |
+
typing_extensions==4.5.0
|
| 128 |
+
urllib3==2.0.2
|
| 129 |
+
uvicorn==0.22.0
|
| 130 |
+
wcwidth==0.2.6
|
| 131 |
+
websocket-client==1.5.1
|
| 132 |
+
websockets==11.0.3
|
| 133 |
+
Werkzeug==2.3.4
|
| 134 |
+
widgetsnbextension==4.0.7
|
| 135 |
+
yacs==0.1.8
|
| 136 |
+
yarl==1.9.2
|
| 137 |
+
zipp==3.15.0
|
| 138 |
+
aiofiles==23.1.0
|
| 139 |
+
aiohttp==3.8.4
|
| 140 |
+
aiosignal==1.3.1
|
| 141 |
+
altair==5.0.0
|
| 142 |
+
ansicon==1.89.0
|
| 143 |
+
anyio==3.6.2
|
| 144 |
+
arrow==1.2.3
|
| 145 |
+
asttokens==2.2.1
|
| 146 |
+
async-timeout==4.0.2
|
| 147 |
+
attrs==23.1.0
|
| 148 |
+
backcall==0.2.0
|
| 149 |
+
beautifulsoup4==4.12.2
|
| 150 |
+
blessed==1.20.0
|
| 151 |
+
blinker==1.6.2
|
| 152 |
+
certifi==2023.5.7
|
| 153 |
+
charset-normalizer==3.1.0
|
| 154 |
+
click==8.1.3
|
| 155 |
+
colorama==0.4.6
|
| 156 |
+
comm==0.1.3
|
| 157 |
+
ConfigArgParse==1.5.3
|
| 158 |
+
contourpy==1.0.7
|
| 159 |
+
croniter==1.3.14
|
| 160 |
+
cycler==0.11.0
|
| 161 |
+
dash==2.9.3
|
| 162 |
+
dash-core-components==2.0.0
|
| 163 |
+
dash-html-components==2.0.0
|
| 164 |
+
dash-table==5.0.0
|
| 165 |
+
dateutils==0.6.12
|
| 166 |
+
debugpy==1.6.7
|
| 167 |
+
decorator==5.1.1
|
| 168 |
+
deepdiff==6.3.0
|
| 169 |
+
executing==1.2.0
|
| 170 |
+
fastapi==0.88.0
|
| 171 |
+
fastjsonschema==2.16.3
|
| 172 |
+
ffmpy==0.3.0
|
| 173 |
+
filelock==3.12.0
|
| 174 |
+
Flask==2.3.2
|
| 175 |
+
fonttools==4.39.4
|
| 176 |
+
frozenlist==1.3.3
|
| 177 |
+
fsspec==2023.5.0
|
| 178 |
+
fvcore==0.1.5.post20221221
|
| 179 |
+
gradio==3.30.0
|
| 180 |
+
gradio_client==0.2.5
|
| 181 |
+
h11==0.14.0
|
| 182 |
+
httpcore==0.17.0
|
| 183 |
+
httpx==0.24.0
|
| 184 |
+
huggingface-hub==0.14.1
|
| 185 |
+
idna==3.4
|
| 186 |
+
imageio==2.28.1
|
| 187 |
+
importlib-metadata==6.6.0
|
| 188 |
+
importlib-resources==5.12.0
|
| 189 |
+
inquirer==3.1.3
|
| 190 |
+
iopath==0.1.10
|
| 191 |
+
ipykernel==6.23.1
|
| 192 |
+
ipython==8.13.2
|
| 193 |
+
ipywidgets==8.0.6
|
| 194 |
+
itsdangerous==2.1.2
|
| 195 |
+
jaxtyping==0.2.19
|
| 196 |
+
jedi==0.18.2
|
| 197 |
+
Jinja2==3.1.2
|
| 198 |
+
jinxed==1.2.0
|
| 199 |
+
joblib==1.2.0
|
| 200 |
+
jsonschema==4.17.3
|
| 201 |
+
jupyter_client==8.2.0
|
| 202 |
+
jupyter_core==5.3.0
|
| 203 |
+
jupyterlab-widgets==3.0.7
|
| 204 |
+
kiwisolver==1.4.4
|
| 205 |
+
lazy_loader==0.2
|
| 206 |
+
libigl==2.4.1
|
| 207 |
+
lightning==2.0.2
|
| 208 |
+
lightning-cloud==0.5.36
|
| 209 |
+
lightning-utilities==0.8.0
|
| 210 |
+
linkify-it-py==2.0.2
|
| 211 |
+
markdown-it-py==2.2.0
|
| 212 |
+
MarkupSafe==2.1.2
|
| 213 |
+
matplotlib==3.7.1
|
| 214 |
+
matplotlib-inline==0.1.6
|
| 215 |
+
mdit-py-plugins==0.3.3
|
| 216 |
+
mdurl==0.1.2
|
| 217 |
+
multidict==6.0.4
|
| 218 |
+
nbformat==5.7.0
|
| 219 |
+
nest-asyncio==1.5.6
|
| 220 |
+
networkx==3.1
|
| 221 |
+
numpy==1.24.3
|
| 222 |
+
open3d==0.17.0
|
| 223 |
+
ordered-set==4.1.0
|
| 224 |
+
orjson==3.8.12
|
| 225 |
+
packaging==23.1
|
| 226 |
+
pandas==2.0.1
|
| 227 |
+
parso==0.8.3
|
| 228 |
+
pickleshare==0.7.5
|
| 229 |
+
Pillow==9.5.0
|
| 230 |
+
platformdirs==3.5.1
|
| 231 |
+
plotly==5.14.1
|
| 232 |
+
plyfile==0.9
|
| 233 |
+
portalocker==2.7.0
|
| 234 |
+
prompt-toolkit==3.0.38
|
| 235 |
+
psutil==5.9.5
|
| 236 |
+
pure-eval==0.2.2
|
| 237 |
+
pydantic==1.10.7
|
| 238 |
+
pydub==0.25.1
|
| 239 |
+
Pygments==2.15.1
|
| 240 |
+
PyJWT==2.7.0
|
| 241 |
+
pyparsing==3.0.9
|
| 242 |
+
pyrsistent==0.19.3
|
| 243 |
+
PySimpleGUI==4.60.4
|
| 244 |
+
python-dateutil==2.8.2
|
| 245 |
+
python-editor==1.0.4
|
| 246 |
+
python-multipart==0.0.6
|
| 247 |
+
pytorch-lightning==2.0.2
|
| 248 |
+
pytz==2023.3
|
| 249 |
+
PyWavelets==1.4.1
|
| 250 |
+
pywin32==306
|
| 251 |
+
PyYAML==6.0
|
| 252 |
+
pyzmq==25.0.2
|
| 253 |
+
readchar==4.0.5
|
| 254 |
+
requests==2.30.0
|
| 255 |
+
rich==13.3.5
|
| 256 |
+
scikit-image==0.20.0
|
| 257 |
+
scikit-learn==1.2.2
|
| 258 |
+
scipy==1.9.1
|
| 259 |
+
semantic-version==2.10.0
|
| 260 |
+
six==1.16.0
|
| 261 |
+
sniffio==1.3.0
|
| 262 |
+
soupsieve==2.4.1
|
| 263 |
+
stack-data==0.6.2
|
| 264 |
+
starlette==0.22.0
|
| 265 |
+
starsessions==1.3.0
|
| 266 |
+
tabulate==0.9.0
|
| 267 |
+
tenacity==8.2.2
|
| 268 |
+
termcolor==2.3.0
|
| 269 |
+
threadpoolctl==3.1.0
|
| 270 |
+
tifffile==2023.4.12
|
| 271 |
+
toolz==0.12.0
|
| 272 |
+
torch==1.13.1+cu116
|
| 273 |
+
torch-cluster==1.6.1+pt113cu116
|
| 274 |
+
torch-geometric==2.3.1
|
| 275 |
+
torch-scatter==2.1.1+pt113cu116
|
| 276 |
+
torch-sparse==0.6.17+pt113cu116
|
| 277 |
+
torch-spline-conv==1.2.2+pt113cu116
|
| 278 |
+
torchaudio==0.13.1
|
| 279 |
+
torchmetrics==0.11.4
|
| 280 |
+
torchvision==0.14.1+cu116
|
| 281 |
+
tornado==6.3.2
|
| 282 |
+
tqdm==4.65.0
|
| 283 |
+
traitlets==5.9.0
|
| 284 |
+
trimesh==3.21.6
|
| 285 |
+
typeguard==4.0.0
|
| 286 |
+
typing_extensions==4.5.0
|
| 287 |
+
tzdata==2023.3
|
| 288 |
+
uc-micro-py==1.0.2
|
| 289 |
+
urllib3==2.0.2
|
| 290 |
+
uvicorn==0.22.0
|
| 291 |
+
wcwidth==0.2.6
|
| 292 |
+
websocket-client==1.5.1
|
| 293 |
+
websockets==11.0.3
|
| 294 |
+
Werkzeug==2.3.4
|
| 295 |
+
widgetsnbextension==4.0.7
|
| 296 |
+
yacs==0.1.8
|
| 297 |
+
yarl==1.9.2
|
| 298 |
+
zipp==3.15.0
|