# NOTE(review): removed copy/paste artifacts ("Spaces:", "Runtime error" x2)
# that were not Python and broke the file at import time.
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import numpy as np | |
| import time | |
| import torch | |
| import torch.nn as nn | |
| from torch.autograd import Function | |
| import torch.nn.functional as F | |
# Compiled CUDA extension: package-relative import when loaded as part of
# the package, plain module import when run as a standalone script.
# Narrowed from a bare `except:` so real errors (e.g. KeyboardInterrupt,
# a broken extension raising at import) are not silently swallowed.
try:
    from . import utilslib
except ImportError:
    import utilslib
class ComputeRaydirs(Function):
    """Autograd Function wrapping the CUDA ray-generation kernel.

    Given camera pose/intrinsics, produces per-pixel ray origins,
    normalized ray directions, and the entry/exit distances (tmin, tmax)
    of each ray through the modeling volume.

    Fixes vs. the previous version:
    * `pixelcoords` may be a (W, H) tuple (regular grid), but the old
      contiguity-assert loop called `.is_contiguous()` on it before the
      tuple check, crashing with AttributeError. It is now only checked
      when it is actually a tensor.
    * `backward` previously accepted two grad arguments although
      `forward` returns three outputs, which raised a TypeError whenever
      autograd called it. It now takes one grad per output.
    * Converted to the modern static-method / ctx style; the
      `ComputeRaydirs.apply(...)` call site is unchanged.
    """

    @staticmethod
    def forward(ctx, viewpos, viewrot, focal, princpt, pixelcoords, volradius):
        # The extension reads raw data pointers, so inputs must be contiguous.
        for tensor in [viewpos, viewrot, focal, princpt]:
            assert tensor.is_contiguous()

        N = viewpos.size(0)
        if isinstance(pixelcoords, tuple):
            # Regular pixel grid: the kernel synthesizes coordinates itself.
            W, H = pixelcoords
            pixelcoords = None
        else:
            assert pixelcoords.is_contiguous()
            H = pixelcoords.size(1)
            W = pixelcoords.size(2)

        # Output buffers filled in-place by the CUDA kernel.
        raypos = torch.empty((N, H, W, 3), device=viewpos.device)
        raydirs = torch.empty((N, H, W, 3), device=viewpos.device)
        tminmax = torch.empty((N, H, W, 2), device=viewpos.device)
        utilslib.compute_raydirs_forward(viewpos, viewrot, focal, princpt,
                pixelcoords, W, H, volradius, raypos, raydirs, tminmax)

        return raypos, raydirs, tminmax

    @staticmethod
    def backward(ctx, grad_raypos, grad_raydirs, grad_tminmax):
        # Ray generation is treated as non-differentiable w.r.t. all inputs.
        return None, None, None, None, None, None
def compute_raydirs(viewpos, viewrot, focal, princpt, pixelcoords, volradius):
    """Functional front-end for ComputeRaydirs.

    Returns (raypos, raydirs, tminmax) for the given camera parameters;
    see ComputeRaydirs.forward for the argument conventions.
    """
    return ComputeRaydirs.apply(viewpos, viewrot, focal, princpt,
                                pixelcoords, volradius)
class Rodrigues(nn.Module):
    """Convert batched axis-angle vectors to 3x3 rotation matrices.

    Input: rvec of shape (N, 3) whose direction is the rotation axis and
    whose magnitude is the rotation angle. Output: (N, 3, 3) rotation
    matrices via Rodrigues' formula.
    """

    def __init__(self):
        super(Rodrigues, self).__init__()

    def forward(self, rvec):
        # Rotation angle; the 1e-5 floor keeps the division and the
        # gradient finite for (near-)zero rotations.
        theta = torch.sqrt(1e-5 + torch.sum(rvec ** 2, dim=1))
        axis = rvec / theta[:, None]

        costh = torch.cos(theta)
        sinth = torch.sin(theta)
        omc = 1. - costh  # hoisted (1 - cos) factor, used six times

        x, y, z = axis[:, 0], axis[:, 1], axis[:, 2]

        # Row-major entries of R = cos*I + (1-cos)*rr^T + sin*[r]_x.
        entries = (
            x * x + (1. - x * x) * costh,
            x * y * omc - z * sinth,
            x * z * omc + y * sinth,
            x * y * omc + z * sinth,
            y * y + (1. - y * y) * costh,
            y * z * omc - x * sinth,
            x * z * omc - y * sinth,
            y * z * omc + x * sinth,
            z * z + (1. - z * z) * costh,
        )
        return torch.stack(entries, dim=1).view(-1, 3, 3)
def gradcheck():
    """Compare the pure-PyTorch ray computation against the CUDA kernel.

    Builds random camera parameters, computes normalized ray directions
    both in plain PyTorch and via `compute_raydirs`, then prints max abs
    difference / normalized dot product for the forward output and for
    the gradient of each camera parameter, plus wall-clock timings.

    Requires a CUDA device and the compiled `utilslib` extension.

    Fixes vs. the previous version: `tf`/`tb` in the timing printout were
    never defined (NameError); timestamps were read *before*
    `torch.cuda.synchronize()`, mis-attributing async kernel time; the
    slab-intersection tensor `t1` was shadowed by a `time.time()` float;
    unused locals (k3, K, M, max_len, _stepsize) removed.
    """
    N = 2
    H = 64
    W = 64
    volradius = 1.

    # generate random inputs
    torch.manual_seed(1113)

    rodrigues = Rodrigues()

    _viewpos = torch.tensor([[-0.0, 0.0, -4.] for n in range(N)], device="cuda") + torch.randn(N, 3, device="cuda") * 0.1
    viewrvec = torch.randn(N, 3, device="cuda") * 0.01
    _viewrot = rodrigues(viewrvec)

    _focal = torch.tensor([[W*4.0, W*4.0] for n in range(N)], device="cuda")
    _princpt = torch.tensor([[W*0.5, H*0.5] for n in range(N)], device="cuda")
    # default "ij" meshgrid indexing: pixely varies over rows, pixelx over cols
    pixely, pixelx = torch.meshgrid(torch.arange(H, device="cuda").float(), torch.arange(W, device="cuda").float())
    _pixelcoords = torch.stack([pixelx, pixely], dim=-1)[None, :, :, :].repeat(N, 1, 1, 1)

    # re-create inputs as contiguous leaf tensors so autograd populates .grad
    _viewpos = _viewpos.contiguous().detach().clone()
    _viewpos.requires_grad = True
    _viewrot = _viewrot.contiguous().detach().clone()
    _viewrot.requires_grad = True
    _focal = _focal.contiguous().detach().clone()
    _focal.requires_grad = True
    _princpt = _princpt.contiguous().detach().clone()
    _princpt.requires_grad = True
    _pixelcoords = _pixelcoords.contiguous().detach().clone()
    _pixelcoords.requires_grad = True

    params = [_viewpos, _viewrot, _focal, _princpt]
    paramnames = ["viewpos", "viewrot", "focal", "princpt"]

    ########################### run pytorch version ###########################

    viewpos = _viewpos
    viewrot = _viewrot
    focal = _focal
    princpt = _princpt
    pixelcoords = _pixelcoords

    # camera-space directions, rotated to world space and normalized
    raydir = (pixelcoords - princpt[:, None, None, :]) / focal[:, None, None, :]
    raydir = torch.cat([raydir, torch.ones_like(raydir[:, :, :, 0:1])], dim=-1)
    raydir = torch.sum(viewrot[:, None, None, :, :] * raydir[:, :, :, :, None], dim=-2)
    raydir = raydir / torch.sqrt(torch.sum(raydir ** 2, dim=-1, keepdim=True))

    # slab intersection with the [-1, 1]^3 cube (reference tminmax; not
    # compared below but kept as documentation of the kernel's contract)
    tnear = (-1. - viewpos[:, None, None, :]) / raydir
    tfar = ( 1. - viewpos[:, None, None, :]) / raydir
    tmin = torch.max(torch.min(tnear[..., 0], tfar[..., 0]),
           torch.max(torch.min(tnear[..., 1], tfar[..., 1]),
                     torch.min(tnear[..., 2], tfar[..., 2]))).clamp(min=0.)
    tmax = torch.min(torch.max(tnear[..., 0], tfar[..., 0]),
           torch.min(torch.max(tnear[..., 1], tfar[..., 1]),
                     torch.max(tnear[..., 2], tfar[..., 2])))
    tminmax = torch.stack([tmin, tmax], dim=-1)

    sample0 = raydir

    sample0.backward(torch.ones_like(sample0))

    grads0 = [p.grad.detach().clone() if p.grad is not None else None for p in params]

    # clear grads before the CUDA run so accumulation starts fresh
    for p in params:
        if p.grad is not None:
            p.grad.detach_()
            p.grad.zero_()

    ############################## run cuda version ###########################

    niter = 1

    torch.cuda.synchronize()
    t0 = time.time()
    sample1 = compute_raydirs(viewpos, viewrot, focal, princpt, pixelcoords, volradius)[1]
    torch.cuda.synchronize()
    t1 = time.time()

    print("-----------------------------------------------------------------")
    print("{:>10} {:>10} {:>10} {:>10} {:>10} {:>10}".format("", "maxabsdiff", "dp", "index", "py", "cuda"))
    ind = torch.argmax(torch.abs(sample0 - sample1))
    print("{:<10} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format(
        "fwd",
        torch.max(torch.abs(sample0 - sample1)).item(),
        (torch.sum(sample0 * sample1) / torch.sqrt(torch.sum(sample0 * sample0) * torch.sum(sample1 * sample1))).item(),
        ind.item(),
        sample0.view(-1)[ind].item(),
        sample1.view(-1)[ind].item()))

    sample1.backward(torch.ones_like(sample1), retain_graph=True)
    torch.cuda.synchronize()
    t2 = time.time()

    # forward / backward wall-clock times (previously undefined -> NameError)
    tf = t1 - t0
    tb = t2 - t1
    print("{:<10} {:10.5} {:10.5} {:10.5}".format("time", tf / niter, tb / niter, (tf + tb) / niter))

    grads1 = [p.grad.detach().clone() if p.grad is not None else None for p in params]

    ############# compare results #############

    for p, g0, g1 in zip(paramnames, grads0, grads1):
        ind = torch.argmax(torch.abs(g0 - g1))
        print("{:<10} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format(
            p,
            torch.max(torch.abs(g0 - g1)).item(),
            (torch.sum(g0 * g1) / torch.sqrt(torch.sum(g0 * g0) * torch.sum(g1 * g1))).item(),
            ind.item(),
            g0.view(-1)[ind].item(),
            g1.view(-1)[ind].item()))
| if __name__ == "__main__": | |
| gradcheck() | |