File size: 10,541 Bytes
e05eed1
98a67a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0


#include "quad_rectify_gpu.h"

#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.hpp>

#include "quad_rectify_shared.h"
#include "../half_ops.cuh"
#include "../geometry.h"

// Argument-validation helpers for the host entry points in this file.
// TORCH_CHECK is used instead of the deprecated AT_ASSERTM: AT_ASSERTM expands
// to TORCH_INTERNAL_ASSERT, which words a bad user argument as an internal
// PyTorch bug. Both throw c10::Error (RuntimeError on the Python side).
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
// NOTE: CHECK_INPUT validates only the device; contiguity is not enforced.
#define CHECK_INPUT(x) CHECK_CUDA(x)

// Kernel: one thread per quad. Each thread evaluates the shared calc_quad_width
// helper for quads[i] and stores the result, converted to int64, in output[i].
//
// Launch: 1-D grid over quads with any blockDim.x; threads whose index is past
// quads.size(0) do nothing.
template<typename scalar_t, typename quads_accessor_t, typename output_accessor_t>
__global__ void quad_rectify_device_calc_quad_width(quads_accessor_t quads,
                                                    output_accessor_t output,
                                                    const scalar_t outputHeight,
                                                    const scalar_t roundFactor,
                                                    const scalar_t maxWidth)
{
    const unsigned int quadCount = quads.size(0);
    const unsigned int i = threadIdx.x + blockDim.x * blockIdx.x;

    if (i < quadCount) {
        auto quad = quads[i];
        auto width = calc_quad_width<scalar_t>(quad, outputHeight, roundFactor, maxWidth);
        output[i] = Convert<scalar_t, int64_t>::LeftToRight(width);
    }
}

// Kernel: for every output cell (row, col) of every quad, writes the point
// returned by calc_rect_value into outputs[quad][row][col][0..1] as (X, Y).
//
// Launch layout:
//   * blockIdx.y * blockDim.y + threadIdx.y selects the quad;
//   * blockIdx.x * blockDim.x + threadIdx.x is a row-major index into the
//     (outputs.size(1), outputs.size(2)) cell grid.
// Threads whose quad or row falls outside the tensors exit without writing.
// When `isotropic` is set, the width passed to calc_rect_value comes from
// calc_quad_width (round factor 1, capped at the output width); otherwise it
// is the full output width.
template<typename scalar_t, typename quads_accessor_t, typename output_accessor_t>
__global__ void quad_rectify_device_forward(quads_accessor_t quads,
                                            output_accessor_t outputs,
                                            const scalar_t imageHeight,
                                            const scalar_t imageWidth,
                                            bool isotropic)
{
    const unsigned int quadCount = quads.size(0);
    const unsigned int q = threadIdx.y + blockDim.y * blockIdx.y;
    if (q >= quadCount) {
        return;
    }

    const unsigned int rows = outputs.size(1);
    const unsigned int cols = outputs.size(2);

    const unsigned int cell = threadIdx.x + blockDim.x * blockIdx.x;
    const unsigned int row = cell / cols;
    const unsigned int col = cell - row * cols;   // == cell % cols
    if (row >= rows) {
        return;
    }

    auto scRows = Convert<scalar_t, unsigned int>::RightToLeft(rows);
    auto scCols = Convert<scalar_t, unsigned int>::RightToLeft(cols);
    auto scOne = Convert<scalar_t, float>::RightToLeft(1);

    auto quad = quads[q];

    scalar_t width = scCols;
    if (isotropic) {
        width = calc_quad_width<scalar_t>(quad, scRows, scOne, scCols);
    }

    const Point_<scalar_t> pt = calc_rect_value<scalar_t>(quad,
                                                          width,
                                                          scRows,
                                                          col,
                                                          row,
                                                          imageWidth,
                                                          imageHeight);

    auto cellOut = outputs[q][row][col];
    cellOut[0] = pt.X;
    cellOut[1] = pt.Y;
}

// Kernel: backward pass matching quad_rectify_device_forward. It accumulates
// the gradient w.r.t. each quad's 8 values into gradInput.
//
// Launch contract (hard-coded in this kernel):
//   * blockDim must be (32, 1, 1): sharedFloats has 32 rows and the tree
//     reduction below starts at stride 16.
//   * blockIdx.y selects the quad; blockIdx.x * blockDim.x + threadIdx.x is a
//     row-major index into the (gradOutput.size(1), gradOutput.size(2)) cells.
//   * Uses 32 * 8 * sizeof(scalar_t) bytes of static shared memory.
//   * atomicAdd on double requires SM60+.
//
// Each thread turns its cell's incoming gradient (coord 0 scaled by
// 2 / imageWidth, coord 1 by 2 / imageHeight) into 8 partial derivatives using
// the bilinear weights of its normalized cell centre (fRow, fCol). The
// derivatives are stored at slot weightIdx * 2 + coord. Cells past the grid,
// past the quad width (fCol > 1), or belonging to an out-of-range quad
// contribute zeros. The block sum is added to gradInput with atomicAdd, so
// several blocks may contribute to the same quad.
template<typename scalar_t, typename quads_accessor_t, typename output_accessor_t>
__global__ void quad_rectify_device_backward(quads_accessor_t quads,
                                             output_accessor_t gradOutput,
                                             quads_accessor_t gradInput,
                                             const scalar_t imageHeight,
                                             const scalar_t imageWidth,
                                             bool isotropic)
{
    const unsigned int numQuads = quads.size(0);
    const int64_t quadIdx = blockIdx.y * blockDim.y + threadIdx.y;

    // Out-of-range quads are masked rather than returned early: every thread
    // of the block must still reach the __syncthreads() calls below.
    const bool quadInRange = quadIdx < numQuads;

    const int64_t offset = blockIdx.x * blockDim.x + threadIdx.x;

    const int64_t outputHeight = gradOutput.size(1);
    const int64_t outputWidth = gradOutput.size(2);

    const int64_t x = offset % outputWidth;
    const int64_t y = offset / outputWidth;

    auto scOutputHeight = Convert<scalar_t, unsigned int>::RightToLeft(outputHeight);
    auto scOutputWidth = Convert<scalar_t, unsigned int>::RightToLeft(outputWidth);
    auto scOne = Convert<scalar_t, float>::RightToLeft(1);
    const scalar_t scHalf = Convert<scalar_t, float>::RightToLeft(0.5f);

    __shared__ scalar_t sharedFloats[32][8];

    // Per-coordinate gradient scale factors: index 0 = X, index 1 = Y.
    scalar_t scale[2] = { Convert<scalar_t, float>::RightToLeft(2.0f) / imageWidth,
                          Convert<scalar_t, float>::RightToLeft(2.0f) / imageHeight };

    bool valid = false;
    if (quadInRange && y < outputHeight) {
        // The quad is only read when quadIdx is a valid index.
        auto currQuad = quads[quadIdx];
        scalar_t quadWidth = isotropic ? calc_quad_width<scalar_t>(currQuad, scOutputHeight, scOne, scOutputWidth) : scOutputWidth;

        // Normalized cell-centre coordinates.
        auto fRow = (scalar_t(y) + scHalf) / outputHeight;
        auto fCol = (scalar_t(x) + scHalf) / quadWidth;
        auto fRowCol = fRow * fCol;

        // Cells beyond the quad's width receive no gradient.
        if (fCol <= 1) {
            #pragma unroll 2
            for (int64_t i = 0; i < 2; ++i) {
                auto currGradOutput = gradOutput[quadIdx][y][x][i] * scale[i];

                sharedFloats[threadIdx.x][0 + i] = currGradOutput * (fRowCol - fCol - fRow + 1); // (1-r)(1-c)
                sharedFloats[threadIdx.x][2 + i] = currGradOutput * (fCol - fRowCol);            // c(1-r)
                sharedFloats[threadIdx.x][4 + i] = currGradOutput * fRowCol;                     // r*c
                sharedFloats[threadIdx.x][6 + i] = currGradOutput * (fRow - fRowCol);            // r(1-c)
            }
            valid = true;
        }
    }

    if (! valid) {
        #pragma unroll 8
        for (int64_t i = 0; i < 8; ++i) {
            sharedFloats[threadIdx.x][i] = 0;
        }
    }

    __syncthreads();

    // Tree reduction of the 32 rows into sharedFloats[0].
    for (unsigned int i = 16; i > 0; i /= 2) {
        if (threadIdx.x < i) {
            #pragma unroll 8
            for (unsigned int k = 0; k < 8; ++k) {
                sharedFloats[threadIdx.x][k] += sharedFloats[threadIdx.x + i][k];
            }
        }
        __syncthreads();
    }

    // Thread 0 publishes the block sum. Addressing goes through the accessor,
    // which honours gradInput's strides: zeros_like() keeps the strides of a
    // dense but non-contiguous `quads`, so raw `.data() + i` arithmetic could
    // hit the wrong elements. For a contiguous tensor, flat index i maps to
    // [i / innerSize][i % innerSize], the same element `.data() + i` addressed.
    if (threadIdx.x == 0 && quadInRange) {
        auto quadGrad = gradInput[quadIdx];
        const int64_t innerSize = gradInput.size(2);
        #pragma unroll 8
        for (int64_t i = 0; i < 8; ++i) {
            atomicAdd(&quadGrad[i / innerSize][i % innerSize], sharedFloats[0][i]);
        }
    }
}

// Computes the integer rectified width of every quad on the GPU.
//
// quads:        CUDA tensor (float / double / half), read as a 3-D tensor
//               (presumably (numQuads, 4, 2) -- TODO confirm).
// outputHeight, roundFactor, maxWidth:
//               converted to the quad dtype and forwarded to the
//               calc_quad_width device helper.
// Returns an int64 tensor of shape (numQuads,) on the same device as quads.
// Throws c10::Error if quads is not a CUDA tensor or the kernel launch fails.
torch::Tensor quad_rectify_gpu_calc_quad_width(torch::Tensor quads,
                                               int64_t outputHeight,
                                               int64_t roundFactor,
                                               float maxWidth)
{
    CHECK_INPUT(quads);

    const int64_t numQuads = quads.size(0);

    auto output = torch::empty({ numQuads },
                               quads.options().dtype(torch::kInt64));

    // A zero-sized grid is an invalid launch configuration; nothing to do.
    if (numQuads == 0) {
        return output;
    }

    // One thread per quad.
    dim3 dimBlock(32);
    dim3 dimGrid(div_up(numQuads, dimBlock.x));

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        quads.scalar_type(),
        "quad_rectify_calc_quad_width",
        ([&] {
            typedef typename remap_half<scalar_t>::type T;
            quad_rectify_device_calc_quad_width<T> KERNEL_ARG2(dimGrid, dimBlock) (
                quads.packed_accessor64<T, 3>(),
                output.packed_accessor64<int64_t, 1>(),
                Convert<T, int64_t>::RightToLeft(outputHeight),
                Convert<T, int64_t>::RightToLeft(roundFactor),
                Convert<T, float>::RightToLeft(maxWidth)
            );
        })
    );

    // Kernel launches do not return errors. Surface a failed launch here
    // rather than handing back the uninitialized `output`.
    const cudaError_t launchErr = cudaGetLastError();
    TORCH_CHECK(launchErr == cudaSuccess,
                "quad_rectify_gpu_calc_quad_width: kernel launch failed: ",
                cudaGetErrorString(launchErr));

    return output;
}

// Rectifies every quad into an (outputHeight, outputWidth) grid of sample
// points.
//
// quads:        CUDA tensor (float / double / half), read as a 3-D tensor
//               (presumably (numQuads, 4, 2) -- TODO confirm).
// imageHeight, imageWidth:   forwarded to calc_rect_value.
// outputHeight, outputWidth: size of the per-quad output grid.
// isotropic:    if true, each quad's used width comes from calc_quad_width.
// Returns a tensor of shape (numQuads, outputHeight, outputWidth, 2) with the
// dtype/device of quads.
// Throws c10::Error if quads is not on CUDA or a kernel launch fails.
torch::Tensor quad_rectify_gpu_forward(torch::Tensor quads,
                                       int64_t imageHeight,
                                       int64_t imageWidth,
                                       int64_t outputHeight,
                                       int64_t outputWidth,
                                       bool isotropic)
{
    CHECK_INPUT(quads);

    const int64_t numQuads = quads.size(0);
    const int64_t numCells = outputHeight * outputWidth;

    auto output = torch::empty({ numQuads, outputHeight, outputWidth, 2 },
                               quads.options());

    // Nothing to compute. This also avoids launching a zero-sized grid, which
    // CUDA rejects as an invalid configuration.
    if (numQuads == 0 || numCells == 0) {
        return output;
    }

    // Layout: blockIdx.x covers the cells of one quad, blockIdx.y selects the
    // quad. gridDim.y is capped at 65535 by CUDA, so larger batches are
    // launched in chunks of quads (narrow() keeps the chunks as views).
    constexpr int64_t kMaxQuadsPerLaunch = 65535;
    dim3 dimBlock(32);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        quads.scalar_type(),
        "quad_rectify_device_forward",
        ([&] {
            typedef typename remap_half<scalar_t>::type T;
            for (int64_t start = 0; start < numQuads; start += kMaxQuadsPerLaunch) {
                const int64_t remaining = numQuads - start;
                const int64_t count = remaining < kMaxQuadsPerLaunch ? remaining : kMaxQuadsPerLaunch;

                auto quadsChunk = quads.narrow(0, start, count);
                auto outputChunk = output.narrow(0, start, count);

                dim3 dimGrid(div_up(numCells, dimBlock.x),
                             static_cast<unsigned int>(count));

                quad_rectify_device_forward<T> KERNEL_ARG2(dimGrid, dimBlock) (
                    quadsChunk.packed_accessor64<T, 3>(),
                    outputChunk.packed_accessor64<T, 4>(),
                    Convert<T, int64_t>::RightToLeft(imageHeight),
                    Convert<T, int64_t>::RightToLeft(imageWidth),
                    isotropic
                );

                // Kernel launches do not return errors; check explicitly so a
                // failed launch cannot leave `output` silently uninitialized.
                const cudaError_t launchErr = cudaGetLastError();
                TORCH_CHECK(launchErr == cudaSuccess,
                            "quad_rectify_gpu_forward: kernel launch failed: ",
                            cudaGetErrorString(launchErr));
            }
        })
    );

    return output;
}

// Gradient of quad_rectify_gpu_forward with respect to `quads`.
//
// quads:      CUDA tensor (float / double), read as a 3-D tensor.
// gradOutput: CUDA tensor of shape (numQuads, outputHeight, outputWidth, 2)
//             with the same dtype as quads.
// imageHeight, imageWidth, isotropic: must match the forward call.
// Returns a tensor shaped like quads holding the accumulated gradient.
// Throws c10::Error on wrong device or shape, or on a failed kernel launch.
torch::Tensor quad_rectify_gpu_backward(torch::Tensor quads,
                                        torch::Tensor gradOutput,
                                        int64_t imageHeight,
                                        int64_t imageWidth,
                                        bool isotropic)
{
    CHECK_INPUT(quads);
    CHECK_INPUT(gradOutput);

    const int64_t numQuads = quads.size(0);

    // The kernel reads gradOutput[quad][y][x][0..1] for every quad, so the
    // leading and trailing dimensions must agree with quads / the forward output.
    TORCH_CHECK(gradOutput.dim() == 4 &&
                gradOutput.size(0) == numQuads &&
                gradOutput.size(3) == 2,
                "quad_rectify_gpu_backward: gradOutput must have shape "
                "(numQuads, outputHeight, outputWidth, 2), got ",
                gradOutput.sizes());

    const int64_t outputHeight = gradOutput.size(1);
    const int64_t outputWidth = gradOutput.size(2);
    const int64_t numCells = outputHeight * outputWidth;

    auto gradInput = torch::zeros_like(quads);

    // No cells contribute, so the zero gradient is already correct. This also
    // avoids launching a zero-sized grid.
    if (numQuads == 0 || numCells == 0) {
        return gradInput;
    }

    // blockDim.x must stay 32: the backward kernel's shared-memory reduction
    // is sized for exactly one 32-thread block. gridDim.y is capped at 65535
    // by CUDA, so quads are launched in chunks (narrow() keeps them as views).
    constexpr int64_t kMaxQuadsPerLaunch = 65535;
    dim3 dimBlock(32);

    AT_DISPATCH_FLOATING_TYPES(
        quads.scalar_type(),
        "quad_rectify_device_backward",
        ([&] {
            typedef typename remap_half<scalar_t>::type T;
            for (int64_t start = 0; start < numQuads; start += kMaxQuadsPerLaunch) {
                const int64_t remaining = numQuads - start;
                const int64_t count = remaining < kMaxQuadsPerLaunch ? remaining : kMaxQuadsPerLaunch;

                auto quadsChunk = quads.narrow(0, start, count);
                auto gradOutputChunk = gradOutput.narrow(0, start, count);
                auto gradInputChunk = gradInput.narrow(0, start, count);

                dim3 dimGrid(div_up(numCells, dimBlock.x),
                             static_cast<unsigned int>(count));

                quad_rectify_device_backward<T> KERNEL_ARG2(dimGrid, dimBlock) (
                    quadsChunk.packed_accessor64<T, 3>(),
                    gradOutputChunk.packed_accessor64<T, 4>(),
                    gradInputChunk.packed_accessor64<T, 3>(),
                    Convert<T, int64_t>::RightToLeft(imageHeight),
                    Convert<T, int64_t>::RightToLeft(imageWidth),
                    isotropic
                );

                // A failed launch would otherwise silently return zeros.
                const cudaError_t launchErr = cudaGetLastError();
                TORCH_CHECK(launchErr == cudaSuccess,
                            "quad_rectify_gpu_backward: kernel launch failed: ",
                            cudaGetErrorString(launchErr));
            }
        })
    );

    return gradInput;
}