#include <torch/extension.h>

template <typename scalar_t, typename bound_t>
__device__ __forceinline__ scalar_t clamp(const scalar_t v, const bound_t lo, const bound_t hi) {
  return min(max(v, lo), hi);
}
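// clamp(d, -1.f, 1.f) is exactly the derivative of the Huber penalty with
// delta = 1 (d^2/2 for |d| <= 1, |d| - 1/2 otherwise), so the kernel below
// adds the gradient of a Huber-smoothed total-variation term; each per-edge
// gradient contribution is bounded to [-1, 1].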
template <typename scalar_t, bool dense_mode>
__global__ void total_variation_add_grad_cuda_kernel(
    const scalar_t* __restrict__ param,
    scalar_t* __restrict__ grad,
    float wx, float wy, float wz,
    const size_t sz_i, const size_t sz_j, const size_t sz_k, const size_t N) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  // In sparse mode, skip voxels whose incoming gradient is zero.
  if(index<N && (dense_mode || grad[index]!=0)) {
    // Unflatten the linear index into (i, j, k) voxel coordinates;
    // k is the fastest-varying axis.
    const size_t k = index % sz_k;
    const size_t j = index / sz_k % sz_j;
    const size_t i = index / sz_k / sz_j % sz_i;
    float grad_to_add = 0;
    // One clamped finite-difference term per existing neighbor along each
    // axis: wz weights the k axis, wy the j axis, wx the i axis.
    grad_to_add += (k==0      ? 0 : wz * clamp(param[index]-param[index-1], -1.f, 1.f));
    grad_to_add += (k==sz_k-1 ? 0 : wz * clamp(param[index]-param[index+1], -1.f, 1.f));
    grad_to_add += (j==0      ? 0 : wy * clamp(param[index]-param[index-sz_k], -1.f, 1.f));
    grad_to_add += (j==sz_j-1 ? 0 : wy * clamp(param[index]-param[index+sz_k], -1.f, 1.f));
    grad_to_add += (i==0      ? 0 : wx * clamp(param[index]-param[index-sz_k*sz_j], -1.f, 1.f));
    grad_to_add += (i==sz_i-1 ? 0 : wx * clamp(param[index]-param[index+sz_k*sz_j], -1.f, 1.f));
    grad[index] += grad_to_add;
  }
}
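// A minimal host-side reference of the same update, handy for unit-testing
// the kernel against. This is a sketch, not part of the original source: the
// helper name and the float-only, dense-mode-only signature are assumptions.
static void total_variation_add_grad_cpu_reference(
    const float* param, float* grad,
    float wx, float wy, float wz,
    size_t sz_i, size_t sz_j, size_t sz_k, size_t N) {
  auto clamp1 = [](float v) { return v < -1.f ? -1.f : (v > 1.f ? 1.f : v); };
  for (size_t index = 0; index < N; ++index) {
    const size_t k = index % sz_k;
    const size_t j = index / sz_k % sz_j;
    const size_t i = index / sz_k / sz_j % sz_i;
    float g = 0.f;
    if (k > 0)        g += wz * clamp1(param[index] - param[index - 1]);
    if (k + 1 < sz_k) g += wz * clamp1(param[index] - param[index + 1]);
    if (j > 0)        g += wy * clamp1(param[index] - param[index - sz_k]);
    if (j + 1 < sz_j) g += wy * clamp1(param[index] - param[index + sz_k]);
    if (i > 0)        g += wx * clamp1(param[index] - param[index - sz_k * sz_j]);
    if (i + 1 < sz_i) g += wx * clamp1(param[index] - param[index + sz_k * sz_j]);
    grad[index] += g;  // dense-mode behavior: update every voxel
  }
}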
void total_variation_add_grad_cuda(torch::Tensor param, torch::Tensor grad, float wx, float wy, float wz, bool dense_mode) {
  const size_t N = param.numel();
  // The spatial grid occupies the last three dims of the 5-D tensor (dims 2..4).
  const size_t sz_i = param.size(2);
  const size_t sz_j = param.size(3);
  const size_t sz_k = param.size(4);
  const int threads = 256;
  const int blocks = (N + threads - 1) / threads;
  // Each voxel contributes up to six neighbor terms, so normalize the weights by 6.
  wx /= 6;
  wy /= 6;
  wz /= 6;
  if(dense_mode) {
    AT_DISPATCH_FLOATING_TYPES(param.scalar_type(), "total_variation_add_grad_cuda", ([&] {
      total_variation_add_grad_cuda_kernel<scalar_t,true><<<blocks, threads>>>(
        param.data_ptr<scalar_t>(),
        grad.data_ptr<scalar_t>(),
        wx, wy, wz,
        sz_i, sz_j, sz_k, N);
    }));
  }
  else {
    AT_DISPATCH_FLOATING_TYPES(param.scalar_type(), "total_variation_add_grad_cuda", ([&] {
      total_variation_add_grad_cuda_kernel<scalar_t,false><<<blocks, threads>>>(
        param.data_ptr<scalar_t>(),
        grad.data_ptr<scalar_t>(),
        wx, wy, wz,
        sz_i, sz_j, sz_k, N);
    }));
  }
}
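// A minimal sketch of how this entry point could be exposed to Python as a
// PyTorch extension. The registration below is illustrative and not part of
// the original source; the exported function name is an assumption.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("total_variation_add_grad", &total_variation_add_grad_cuda,
        "Add the clamped total-variation gradient into grad in place (CUDA)");
}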