Spaces:
Runtime error
Runtime error
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
// Per-pixel ray setup for N views of H x W pixels each.
//
// For pixel (w, h) of view n, the kernel writes one value to each output
// buffer. The write goes to flat index (n * H + h) * W + w.
//   rayposim  - viewposim[n] divided by volradius. This is the same value
//               for every pixel of a view.
//   raydirim  - normalized ray direction. It is built from the pixel
//               coordinate, the principal point and focal length, and
//               view n's three viewrotim vectors.
//   tminmaxim - (max(tmin, 0), tmax). These come from a slab test of the
//               ray against the axis-aligned cube [-1, 1]^3. Barring NaN
//               edge cases, a ray that misses the cube gives x > y. So does
//               a ray whose intersection lies entirely behind the origin.
//
// Expected launch layout: a 2D grid.
//   - grid x spans the W columns.
//   - grid y spans the N * H (view, row) pairs, with row index n * H + h.
//   - Threads outside the image exit immediately. No shared memory is used.
//
// Input layouts (as implied by the indexing below):
//   viewposim     - N float3
//   viewrotim     - N * 3 float3 (three vectors per view)
//   focalim       - N float2
//   princptim     - N float2
//   pixelcoordsim - N * H * W float2. May be null; the integer pixel
//                   position (w, h) is then used instead.
//
// Relies on float2/float3 operator overloads and normalize() declared
// elsewhere (presumably helper_math.h; TODO confirm).
__global__ void compute_raydirs_forward_kernel(
        int N, int H, int W,
        float3 * viewposim,
        float3 * viewrotim,
        float2 * focalim,
        float2 * princptim,
        float2 * pixelcoordsim,
        float volradius,
        float3 * rayposim,
        float3 * raydirim,
        float2 * tminmaxim
        ) {
    const int w = blockIdx.x * blockDim.x + threadIdx.x;
    // Grid y enumerates (view, row) pairs: ny = n * H + h.
    const int ny = blockIdx.y * blockDim.y + threadIdx.y;
    const int h = ny % H;
    const int n = ny / H;
    if (w >= W || h >= H || n >= N) {
        return;
    }

    // Flat pixel index (n * H + h) * W + w. It is computed in 64 bits
    // because N * H * W can exceed INT_MAX for large batches of large images.
    const size_t pix = ((size_t)n * H + h) * W + w;

    // Ray origin: the view position divided by volradius.
    const float3 raypos = viewposim[n] / volradius;

    // Pixel coordinate: the explicit per-pixel coordinate if one is given,
    // otherwise the integer pixel position. It is then shifted by the
    // principal point and divided by the focal length.
    float2 pixelcoord = pixelcoordsim ? pixelcoordsim[pix] : make_float2(w, h);
    pixelcoord = (pixelcoord - princptim[n]) / focalim[n];

    // Direction: weighted sum of view n's three rotation vectors, with
    // weights (x, y, 1), then normalized.
    // NOTE(review): whether viewrotim stores rows or columns of the camera
    // rotation is not visible here.
    float3 raydir = make_float3(pixelcoord, 1.f);
    raydir = viewrotim[n * 3 + 0] * raydir.x
           + viewrotim[n * 3 + 1] * raydir.y
           + viewrotim[n * 3 + 2] * raydir.z;
    raydir = normalize(raydir);

    // Slab test against [-1, 1]^3.
    //   - t1 / t2 hold the ray parameters at the -1 and +1 planes of each axis.
    //   - The entry point is the latest per-axis near hit.
    //   - The exit point is the earliest per-axis far hit.
    const float3 t1 = (-1.f - raypos) / raydir;
    const float3 t2 = ( 1.f - raypos) / raydir;
    const float tmin = fmaxf(fminf(t1.x, t2.x), fmaxf(fminf(t1.y, t2.y), fminf(t1.z, t2.z)));
    const float tmax = fminf(fmaxf(t1.x, t2.x), fminf(fmaxf(t1.y, t2.y), fmaxf(t1.z, t2.z)));

    rayposim[pix] = raypos;
    raydirim[pix] = raydir;
    // Clamp the entry to 0 so the interval never starts behind the ray origin.
    tminmaxim[pix] = make_float2(fmaxf(tmin, 0.f), tmax);
}
// Backward-pass counterpart of compute_raydirs_forward_kernel.
//
// NOTE(review): as written, this kernel recomputes the forward quantities
// (ray origin, direction, clamped entry/exit parameters) into locals only.
// It never writes to any buffer:
//   - grad_viewposim, grad_viewrotim, grad_focalim and grad_princptim are
//     left untouched.
//   - rayposim, raydirim and tminmaxim are not read.
// It looks like an unimplemented stub. Confirm before relying on gradients
// from it.
//
// Launch layout matches the forward kernel:
//   - grid x spans the W columns.
//   - grid y spans the N * H (view, row) pairs.
__global__ void compute_raydirs_backward_kernel(
        int N, int H, int W,
        float3 * viewposim,
        float3 * viewrotim,
        float2 * focalim,
        float2 * princptim,
        float2 * pixelcoordsim,
        float volradius,
        float3 * rayposim,
        float3 * raydirim,
        float2 * tminmaxim,
        float3 * grad_viewposim,
        float3 * grad_viewrotim,
        float2 * grad_focalim,
        float2 * grad_princptim
        ) {
    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    const int stacked = blockIdx.y * blockDim.y + threadIdx.y;
    const int row = stacked % H;
    const int view = stacked / H;

    const bool inside = (col < W) && (row < H) && (view < N);
    if (!inside) {
        return;
    }

    // Ray origin and the three per-view rotation vectors.
    const float3 origin = viewposim[view] / volradius;
    const float3 * rot = viewrotim + view * 3;

    // Normalized pixel coordinate.
    float2 uv;
    if (pixelcoordsim != nullptr) {
        uv = pixelcoordsim[view * H * W + row * W + col];
    } else {
        uv = make_float2(col, row);
    }
    uv = (uv - princptim[view]) / focalim[view];

    // Rotated, normalized ray direction (z weight is 1).
    const float3 local = make_float3(uv, 1.f);
    const float3 dir = normalize(rot[0] * local.x + rot[1] * local.y + rot[2] * local.z);

    // Slab intersection with [-1, 1]^3.
    const float3 tlo = (-1.f - origin) / dir;
    const float3 thi = ( 1.f - origin) / dir;
    const float tnear = fmaxf(fminf(tlo.x, thi.x), fmaxf(fminf(tlo.y, thi.y), fminf(tlo.z, thi.z)));
    const float tfar = fminf(fmaxf(tlo.x, thi.x), fminf(fmaxf(tlo.y, thi.y), fmaxf(tlo.z, thi.z)));
    const float2 tnearfar = make_float2(fmaxf(tnear, 0.f), tfar);

    // The interval is currently unused (see the NOTE above).
    (void)tnearfar;
}
// Host launcher for compute_raydirs_forward_kernel. The launch is
// asynchronous on `stream`.
//
// All pointers are device float buffers, reinterpreted as packed
// float2/float3 arrays (sizes implied by the kernel's indexing):
//   viewposim     - N * 3 floats
//   viewrotim     - N * 9 floats
//   focalim       - N * 2 floats
//   princptim     - N * 2 floats
//   pixelcoordsim - N * H * W * 2 floats, or null
//   rayposim      - N * H * W * 3 floats
//   raydirim      - N * H * W * 3 floats
//   tminmaxim     - N * H * W * 2 floats
//
// Empty inputs (N, H or W <= 0) are a no-op.
//
// The kernel maps the N * H (view, row) pairs onto gridDim.y, which CUDA
// caps at 65535. So large batches are processed in several launches over
// consecutive chunks of views. A batch that fits gets one launch, identical
// to the unchunked one.
//
// This function does not check for launch errors.
void compute_raydirs_forward_cuda(
        int N, int H, int W,
        float * viewposim,
        float * viewrotim,
        float * focalim,
        float * princptim,
        float * pixelcoordsim,
        float volradius,
        float * rayposim,
        float * raydirim,
        float * tminmaxim,
        cudaStream_t stream) {
    // A grid with a zero dimension is an invalid launch configuration, so
    // there is nothing to launch for empty inputs.
    if (N <= 0 || H <= 0 || W <= 0) {
        return;
    }

    const int blocksizex = 16;
    const int blocksizey = 16;
    const dim3 blocksize(blocksizex, blocksizey);

    // Largest number of views per launch such that
    // ceil(views * H / blocksizey) <= 65535, the gridDim.y limit.
    // If even one view is too tall, the launch still fails, as it did before.
    const long long maxgridy = 65535;
    long long viewsperlaunch = maxgridy * blocksizey / H;
    if (viewsperlaunch < 1) {
        viewsperlaunch = 1;
    }

    // Offset a float buffer by `count` packed vectors, keeping null as null
    // (pixelcoordsim is optional).
    auto asfloat2 = [](float * p, size_t count) -> float2 * {
        return p ? reinterpret_cast<float2 *>(p) + count : nullptr;
    };
    auto asfloat3 = [](float * p, size_t count) -> float3 * {
        return p ? reinterpret_cast<float3 *>(p) + count : nullptr;
    };

    for (int n0 = 0; n0 < N; ) {
        const int nchunk = (N - n0 < viewsperlaunch) ? N - n0 : (int)viewsperlaunch;
        // Offset of view n0's first pixel, in packed-vector units.
        const size_t pixoffset = (size_t)n0 * H * W;
        const dim3 gridsize(
            (W + blocksizex - 1) / blocksizex,
            (nchunk * H + blocksizey - 1) / blocksizey);
        compute_raydirs_forward_kernel<<<gridsize, blocksize, 0, stream>>>(
            nchunk, H, W,
            asfloat3(viewposim, n0),
            asfloat3(viewrotim, (size_t)n0 * 3),
            asfloat2(focalim, n0),
            asfloat2(princptim, n0),
            asfloat2(pixelcoordsim, pixoffset),
            volradius,
            asfloat3(rayposim, pixoffset),
            asfloat3(raydirim, pixoffset),
            asfloat2(tminmaxim, pixoffset));
        n0 += nchunk;
    }
}
// Host launcher for compute_raydirs_backward_kernel. The launch is
// asynchronous on `stream`.
//
// Buffer layouts for the forward inputs and outputs match
// compute_raydirs_forward_cuda:
//   - viewposim:     N * 3 floats
//   - viewrotim:     N * 9 floats
//   - focalim:       N * 2 floats
//   - princptim:     N * 2 floats
//   - pixelcoordsim: N * H * W * 2 floats, or null
//   - rayposim / raydirim / tminmaxim: N * H * W * 3 / 3 / 2 floats
//
// The grad_* buffers are assumed to mirror viewposim, viewrotim, focalim and
// princptim (TODO confirm). The kernel currently does not access them.
// NOTE(review): the kernel as written writes nothing, so this call produces
// no gradients.
//
// Empty inputs (N, H or W <= 0) are a no-op. Because gridDim.y is capped at
// 65535, large batches are launched in chunks of views. This mirrors the
// forward launcher.
//
// This function does not check for launch errors.
void compute_raydirs_backward_cuda(
        int N, int H, int W,
        float * viewposim,
        float * viewrotim,
        float * focalim,
        float * princptim,
        float * pixelcoordsim,
        float volradius,
        float * rayposim,
        float * raydirim,
        float * tminmaxim,
        float * grad_viewposim,
        float * grad_viewrotim,
        float * grad_focalim,
        float * grad_princptim,
        cudaStream_t stream) {
    // A grid with a zero dimension is an invalid launch configuration.
    if (N <= 0 || H <= 0 || W <= 0) {
        return;
    }

    const int blocksizex = 16;
    const int blocksizey = 16;
    const dim3 blocksize(blocksizex, blocksizey);

    // Largest number of views per launch keeping gridDim.y <= 65535.
    const long long maxgridy = 65535;
    long long viewsperlaunch = maxgridy * blocksizey / H;
    if (viewsperlaunch < 1) {
        viewsperlaunch = 1;
    }

    // Offset a float buffer by `count` packed vectors, keeping null as null.
    auto asfloat2 = [](float * p, size_t count) -> float2 * {
        return p ? reinterpret_cast<float2 *>(p) + count : nullptr;
    };
    auto asfloat3 = [](float * p, size_t count) -> float3 * {
        return p ? reinterpret_cast<float3 *>(p) + count : nullptr;
    };

    for (int n0 = 0; n0 < N; ) {
        const int nchunk = (N - n0 < viewsperlaunch) ? N - n0 : (int)viewsperlaunch;
        const size_t pixoffset = (size_t)n0 * H * W;
        const dim3 gridsize(
            (W + blocksizex - 1) / blocksizex,
            (nchunk * H + blocksizey - 1) / blocksizey);
        compute_raydirs_backward_kernel<<<gridsize, blocksize, 0, stream>>>(
            nchunk, H, W,
            asfloat3(viewposim, n0),
            asfloat3(viewrotim, (size_t)n0 * 3),
            asfloat2(focalim, n0),
            asfloat2(princptim, n0),
            asfloat2(pixelcoordsim, pixoffset),
            volradius,
            asfloat3(rayposim, pixoffset),
            asfloat3(raydirim, pixoffset),
            asfloat2(tminmaxim, pixoffset),
            asfloat3(grad_viewposim, n0),
            asfloat3(grad_viewrotim, (size_t)n0 * 3),
            asfloat2(grad_focalim, n0),
            asfloat2(grad_princptim, n0));
        n0 += nchunk;
    }
}