File size: 4,422 Bytes
f06f310
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#include <cuda_fp16.h>
#include <cuda_runtime.h>


#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Parallel.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

// Edge length of the square 2-D thread block used by the sampler launches
// (BLOCK x BLOCK threads per block, one thread per output pixel).
constexpr int BLOCK = 16;
static_assert(BLOCK > 0 && BLOCK * BLOCK <= 1024,
              "BLOCK*BLOCK must not exceed the 1024 threads-per-block CUDA limit");

// Device-side range test: true exactly when 0 <= h < H and 0 <= w < W,
// i.e. (h, w) addresses a cell of an H x W grid.
__forceinline__ __device__ bool within_bounds(int h, int w, int H, int W) {
  const bool row_ok = (0 <= h) && (h < H);
  const bool col_ok = (0 <= w) && (w < W);
  return row_ok && col_ok;
}

// Forward kernel of the 1-D sampler.
//
// One thread per (n, y, x) position. The thread reads the sampling position
// x0 = coords[n][0][y][x] and, for k in [0, 2r], writes
//   corr[n][k][y][x] = (1 - dx) * volume[n][y][x][x_base + k]
//                    +      dx  * volume[n][y][x][x_base + k + 1]
// with x_base = floor(x0) - r and dx = x0 - floor(x0). That is, volume's last
// axis is linearly interpolated at x0 - r + k. Taps whose index falls outside
// [0, volume.size(3)) contribute zero.
//
// Launch layout: 2-D thread blocks tile (x, y) over
// (volume.size(2), volume.size(1)); blockIdx.z is the batch index n.
// Shapes as indexed here: volume (N, H, W1, W2), coords (N, C, H, W1) with
// C >= 1, corr (N, 2r+1, H, W1). Every corr entry of an in-range thread is
// overwritten. Arithmetic is done in at::acc_type<scalar_t, true> (float for
// half inputs) and rounded to scalar_t once per output element.
template <typename scalar_t>
__global__ void sampler_forward_kernel(
    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> volume,
    const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> coords,
    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> corr,
    int r)
{
  using acc_t = at::acc_type<scalar_t, true>;

  const int x = blockIdx.x * blockDim.x + threadIdx.x;
  const int y = blockIdx.y * blockDim.y + threadIdx.y;
  const int n = blockIdx.z;

  const int h1 = volume.size(1);
  const int w1 = volume.size(2);
  const int w2 = volume.size(3);

  if (!within_bounds(y, x, h1, w1)) {
    return;
  }

  // Only coords channel 0 (the position along volume's last axis) is used.
  const float x0 = coords[n][0][y][x];
  const float x0_floor = floorf(x0);
  const float dx = x0 - x0_floor;

  // Interpolation weights of the left (floor) and right (floor + 1) taps.
  const acc_t w_left = static_cast<acc_t>(1.0f - dx);
  const acc_t w_right = static_cast<acc_t>(dx);

  // Index of the left tap for output k = 0.
  const int x_base = static_cast<int>(x0_floor) - r;
  const int rd = 2 * r + 1;

  const auto vol_row = volume[n][y][x];

  // Out-of-range taps read as zero.
  acc_t left = within_bounds(0, x_base, 1, w2)
      ? static_cast<acc_t>(vol_row[x_base]) : acc_t(0);

  for (int k = 0; k < rd; ++k) {
    const int x_right = x_base + k + 1;
    const acc_t right = within_bounds(0, x_right, 1, w2)
        ? static_cast<acc_t>(vol_row[x_right]) : acc_t(0);
    corr[n][k][y][x] = static_cast<scalar_t>(w_left * left + w_right * right);
    // The right tap of output k is the left tap of output k + 1, so each
    // volume element is loaded only once.
    left = right;
  }
}


// Backward kernel of the 1-D sampler (gradient w.r.t. volume only; nothing
// is computed for coords).
//
// One thread per (n, y, x) position. With x0 = coords[n][0][y][x],
// dx = x0 - floor(x0) and x_base = floor(x0) - r, tap i in [0, 2r+1] at
// x1 = x_base + i receives
//   dx       * corr_grad[n][i-1][y][x]   (when i > 0)
//   (1 - dx) * corr_grad[n][i][y][x]     (when i < 2r+1)
// added to volume_grad[n][y][x][x1]. Taps outside [0, volume_grad.size(3))
// are skipped. These are the same weights the forward kernel uses.
//
// Launch layout: 2-D thread blocks tile (x, y) over
// (volume_grad.size(2), volume_grad.size(1)); blockIdx.z is the batch index.
// Shapes as indexed here: coords (N, C, H, W1) with C >= 1,
// corr_grad (N, 2r+1, H, W1), volume_grad (N, H, W1, W2).
// Each thread writes only its own row volume_grad[n][y][x][:], so no atomics
// are needed. Accumulation uses at::acc_type<scalar_t, true>.
template <typename scalar_t>
__global__ void sampler_backward_kernel(
    const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> coords,
    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> corr_grad,
    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> volume_grad,
    int r)
{
  using acc_t = at::acc_type<scalar_t, true>;

  const int x = blockIdx.x * blockDim.x + threadIdx.x;
  const int y = blockIdx.y * blockDim.y + threadIdx.y;
  const int n = blockIdx.z;

  const int h1 = volume_grad.size(1);
  const int w1 = volume_grad.size(2);
  const int w2 = volume_grad.size(3);

  if (!within_bounds(y, x, h1, w1)) {
    return;
  }

  // Only coords channel 0 (the position along volume's last axis) is used.
  const float x0 = coords[n][0][y][x];
  const float x0_floor = floorf(x0);
  const float dx = x0 - x0_floor;

  const acc_t w_left = static_cast<acc_t>(1.0f - dx);
  const acc_t w_right = static_cast<acc_t>(dx);

  const int x_base = static_cast<int>(x0_floor) - r;
  const int rd = 2 * r + 1;

  auto grad_row = volume_grad[n][y][x];

  // Tap i is the right tap of output i - 1 and the left tap of output i.
  for (int i = 0; i <= rd; ++i) {
    const int x1 = x_base + i;
    if (!within_bounds(0, x1, 1, w2)) {
      continue;
    }

    acc_t g = acc_t(0);
    if (i > 0) {
      g += static_cast<acc_t>(corr_grad[n][i - 1][y][x]) * w_right;
    }
    if (i < rd) {
      g += static_cast<acc_t>(corr_grad[n][i][y][x]) * w_left;
    }

    // Distinct i give distinct x1 inside this thread's private row.
    grad_row[x1] += static_cast<scalar_t>(g);
  }
}

// Host entry point for the sampler forward pass.
//
// Validates the inputs, allocates a zero-initialised `corr` of shape
// (N, 2*radius+1, H, W1) with volume's dtype/device, and launches
// sampler_forward_kernel on PyTorch's current CUDA stream of volume's device.
//
// volume : CUDA tensor indexed as (N, H, W1, W2); float, double or half.
// coords : float32 tensor on volume's device, shape (N, C, H, W1) with C >= 1;
//          channel 0 holds the sampling position along volume's last axis.
// radius : non-negative sampling radius (2*radius+1 output channels).
//
// Returns {corr}. Throws c10::Error on invalid input or a failed launch.
std::vector<torch::Tensor> sampler_cuda_forward(
    torch::Tensor volume,
    torch::Tensor coords,
    int radius)
{
  TORCH_CHECK(volume.is_cuda(), "sampler_cuda_forward: volume must be a CUDA tensor");
  TORCH_CHECK(coords.device() == volume.device(),
              "sampler_cuda_forward: coords must be on the same device as volume");
  TORCH_CHECK(volume.dim() == 4,
              "sampler_cuda_forward: volume must be 4-D, got ", volume.dim(), "-D");
  TORCH_CHECK(coords.dim() == 4,
              "sampler_cuda_forward: coords must be 4-D, got ", coords.dim(), "-D");
  TORCH_CHECK(coords.scalar_type() == torch::kFloat,
              "sampler_cuda_forward: coords must be float32");
  TORCH_CHECK(radius >= 0,
              "sampler_cuda_forward: radius must be non-negative, got ", radius);

  const auto batch_size = volume.size(0);
  const auto ht = volume.size(1);
  const auto wd = volume.size(2);

  TORCH_CHECK(coords.size(0) == batch_size && coords.size(1) >= 1 &&
              coords.size(2) == ht && coords.size(3) == wd,
              "sampler_cuda_forward: coords must have shape (N, C>=1, H, W1) matching "
              "volume (N, H, W1, W2); got coords ", coords.sizes(),
              " and volume ", volume.sizes());
  // The batch index is carried by gridDim.z, which CUDA caps at 65535.
  TORCH_CHECK(batch_size <= 65535,
              "sampler_cuda_forward: batch size ", batch_size,
              " exceeds the CUDA grid z-dimension limit of 65535");

  // Allocate and launch on volume's device, not whatever device is current.
  const c10::cuda::CUDAGuard device_guard(volume.device());

  torch::Tensor corr = torch::zeros(
    {batch_size, 2 * radius + 1, ht, wd}, volume.options());

  // A zero-sized grid dimension is an invalid launch configuration.
  if (corr.numel() == 0) {
    return {corr};
  }

  const dim3 blocks((wd + BLOCK - 1) / BLOCK,
                    (ht + BLOCK - 1) / BLOCK,
                    batch_size);
  const dim3 threads(BLOCK, BLOCK);

  // Run on PyTorch's current stream so ordering with surrounding ops holds.
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(volume.scalar_type(), "sampler_forward_kernel", ([&] {
    sampler_forward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
      volume.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
      coords.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
      corr.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
      radius);
  }));

  const cudaError_t launch_err = cudaGetLastError();
  TORCH_CHECK(launch_err == cudaSuccess,
              "sampler_forward_kernel launch failed: ", cudaGetErrorString(launch_err));

  return {corr};
}

// Host entry point for the sampler backward pass.
//
// Allocates volume_grad = zeros_like(volume) and launches
// sampler_backward_kernel on PyTorch's current CUDA stream of volume's
// device to scatter corr_grad into it. No gradient is produced for coords.
//
// volume    : CUDA tensor indexed as (N, H, W1, W2); float, double or half.
//             Only its shape, dtype and device are used.
// coords    : float32 tensor on volume's device, shape (N, C, H, W1), C >= 1.
// corr_grad : tensor on volume's device with volume's dtype and shape
//             (N, 2*radius+1, H, W1).
// radius    : non-negative sampling radius.
//
// Returns {volume_grad}. Throws c10::Error on invalid input or a failed launch.
std::vector<torch::Tensor> sampler_cuda_backward(
  torch::Tensor volume,
  torch::Tensor coords,
  torch::Tensor corr_grad,
  int radius)
{
  TORCH_CHECK(volume.is_cuda(), "sampler_cuda_backward: volume must be a CUDA tensor");
  TORCH_CHECK(coords.device() == volume.device(),
              "sampler_cuda_backward: coords must be on the same device as volume");
  TORCH_CHECK(corr_grad.device() == volume.device(),
              "sampler_cuda_backward: corr_grad must be on the same device as volume");
  TORCH_CHECK(volume.dim() == 4,
              "sampler_cuda_backward: volume must be 4-D, got ", volume.dim(), "-D");
  TORCH_CHECK(coords.dim() == 4,
              "sampler_cuda_backward: coords must be 4-D, got ", coords.dim(), "-D");
  TORCH_CHECK(corr_grad.dim() == 4,
              "sampler_cuda_backward: corr_grad must be 4-D, got ", corr_grad.dim(), "-D");
  TORCH_CHECK(coords.scalar_type() == torch::kFloat,
              "sampler_cuda_backward: coords must be float32");
  TORCH_CHECK(corr_grad.scalar_type() == volume.scalar_type(),
              "sampler_cuda_backward: corr_grad must have the same dtype as volume");
  TORCH_CHECK(radius >= 0,
              "sampler_cuda_backward: radius must be non-negative, got ", radius);

  const auto batch_size = volume.size(0);
  const auto ht = volume.size(1);
  const auto wd = volume.size(2);

  TORCH_CHECK(coords.size(0) == batch_size && coords.size(1) >= 1 &&
              coords.size(2) == ht && coords.size(3) == wd,
              "sampler_cuda_backward: coords must have shape (N, C>=1, H, W1) matching "
              "volume (N, H, W1, W2); got coords ", coords.sizes(),
              " and volume ", volume.sizes());
  // The kernel reads 2*radius+1 channels of corr_grad at every (n, y, x).
  TORCH_CHECK(corr_grad.size(0) == batch_size && corr_grad.size(1) == 2 * radius + 1 &&
              corr_grad.size(2) == ht && corr_grad.size(3) == wd,
              "sampler_cuda_backward: corr_grad must have shape (N, 2*radius+1, H, W1); got ",
              corr_grad.sizes());
  // The batch index is carried by gridDim.z, which CUDA caps at 65535.
  TORCH_CHECK(batch_size <= 65535,
              "sampler_cuda_backward: batch size ", batch_size,
              " exceeds the CUDA grid z-dimension limit of 65535");

  // Allocate and launch on volume's device, not whatever device is current.
  const c10::cuda::CUDAGuard device_guard(volume.device());

  auto volume_grad = torch::zeros_like(volume);

  // A zero-sized grid dimension is an invalid launch configuration.
  if (volume_grad.numel() == 0) {
    return {volume_grad};
  }

  const dim3 blocks((wd + BLOCK - 1) / BLOCK,
                    (ht + BLOCK - 1) / BLOCK,
                    batch_size);
  const dim3 threads(BLOCK, BLOCK);

  // Run on PyTorch's current stream so ordering with surrounding ops holds.
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(volume.scalar_type(), "sampler_backward_kernel", ([&] {
    sampler_backward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
      coords.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
      corr_grad.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
      volume_grad.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
      radius);
  }));

  const cudaError_t launch_err = cudaGetLastError();
  TORCH_CHECK(launch_err == cudaSuccess,
              "sampler_backward_kernel launch failed: ", cudaGetErrorString(launch_err));

  return {volume_grad};
}