// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

#include "geometry_api.h"

#include "../geometry.h"
#include "../cuda_intellisense.cuh"
#include "geometry_api_common.h"

#include <trove/ptr.h>
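// trove (https://github.com/bryancatanzaro/trove) provides coalesced_ptr,
// which loads and stores whole structs with coalesced memory transactions;
// the kernels below rely on it in release (NDEBUG) builds.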


// A rotated rectangle packed as 5 scalars, matching the size-5 innermost
// dimension of the input tensor; the parameter order is whatever
// cvt_rrect_to_quad expects. Kept trivially copyable so trove can move a
// whole RRect_ in one vectorized transaction.
template<typename T>
struct RRect_ {
    T Data[5];

    template<typename index_t>
    __device__
    const T &operator[](index_t i) const { return Data[i]; }
    template<typename index_t>
    __device__
    T &operator[](index_t i) { return Data[i]; }
};

// One thread per rotated rectangle: load the 5-parameter rrect belonging to
// grid cell (row, col), convert it to a 4-corner quad, and store the result.
template<typename T>
__global__
void device_rrect_to_quads_gpu(torch::PackedTensorAccessor64<T, 2> rrectAccess,
                               torch::PackedTensorAccessor64<T, 3> quadsAccess,
                               int64_t numRows, int64_t numCols,
                               T cellSize)
{
    typedef RRect_<T> RRectf;
    typedef InPlaceQuad_<T> Quadf;
    constexpr T TWO = 2;

    const int64_t jobIdx = blockIdx.x * blockDim.x + threadIdx.x;

    if (jobIdx >= rrectAccess.size(0)) {
        return;
    }

    // Recover the cell coordinates inside each numRows x numCols grid; the
    // modulo strips the leading batch dimension folded into the flat index.
    int64_t row = jobIdx / numCols;
    const int64_t col = jobIdx - (row * numCols);
    row = row % numRows;

    auto rawRRect = reinterpret_cast<RRectf*>(rrectAccess.data());
    auto rawQuad = reinterpret_cast<Quadf*>(quadsAccess.data());
#if defined(NDEBUG)
    // Release builds move whole structs through trove::coalesced_ptr so the
    // array-of-structures reads and writes stay coalesced across the warp.
    trove::coalesced_ptr<RRectf> pRRect(rawRRect);
    trove::coalesced_ptr<Quadf> pQuad(rawQuad);
#else
    // Debug builds use plain pointers, which are simpler to step through.
    auto pRRect = rawRRect;
    auto pQuad = rawQuad;
#endif

    RRectf rrect = pRRect[jobIdx];

    // Half-cell offset, presumably from a cell's origin to its center.
    T cellOff = cellSize / TWO;
    Quadf cvQuad = cvt_rrect_to_quad<T>(rrect, cellSize, cellOff, col, row);

    pQuad[jobIdx] = cvQuad;
}

// Converts a batch of rotated rectangles, shaped (N, H, W, 5), into quad
// corner points shaped (N, H, W, 4, 2) on the GPU.
torch::Tensor rrect_to_quads_gpu(torch::Tensor rrects, float cellSize)
{
    // The kernel reinterprets the raw buffer as packed RRect_ structs, which
    // is only valid for densely packed storage.
    if (!rrects.is_contiguous()) {
        throw std::runtime_error("Expected the rrects to be contiguous!");
    }

    torch::Tensor quads = torch::empty({ rrects.size(0), rrects.size(1), rrects.size(2), 4, 2 }, rrects.options());

    // Collapse (N, H, W) into one flat job dimension: one thread per rrect.
    auto rrFlat = rrects.flatten(0, 2);
    auto qFlat = quads.flatten(0, 2);

    dim3 blockSize(96);
    dim3 gridSize(div_up(qFlat.size(0), blockSize.x));

    if (quads.numel() > 0) {
        AT_DISPATCH_FLOATING_TYPES(
            quads.scalar_type(),
            "cuda_rrect_to_quads",
            ([&] {

                device_rrect_to_quads_gpu<scalar_t> KERNEL_ARG2(gridSize, blockSize) (
                    rrFlat.packed_accessor64<scalar_t, 2>(),
                    qFlat.packed_accessor64<scalar_t, 3>(),
                    rrects.size(1), rrects.size(2),
                    cellSize
                );

            })
        );
    }

    return quads;
}
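
// Hypothetical call site (shapes as verified above; names are illustrative):
//   auto rrects = torch::rand({N, H, W, 5}, torch::device(torch::kCUDA)).contiguous();
//   auto quads  = rrect_to_quads_gpu(rrects, /*cellSize=*/1.0f);  // (N, H, W, 4, 2)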

// Backward pass: one thread per rrect maps the incoming quad-corner gradient
// (4 points x 2 coordinates each) onto the 5 rrect parameters via
// assign_grad_rrect_to_quad.
template<typename scalar_t>
__global__
void device_rrect_to_quads_backward_gpu(torch::PackedTensorAccessor64<scalar_t, 2> rrect,
                                        torch::PackedTensorAccessor64<scalar_t, 3> gradOutput,
                                        torch::PackedTensorAccessor64<scalar_t, 2> gradInput)
{
    const int64_t jobIdx = blockIdx.x * blockDim.x + threadIdx.x;

    if (jobIdx >= rrect.size(0)) return;

    assign_grad_rrect_to_quad<scalar_t>(rrect[jobIdx], gradOutput[jobIdx], gradInput[jobIdx]);
}


// Backward of rrect_to_quads_gpu: routes quad-corner gradients, shaped
// (N, H, W, 4, 2), back to rrect-parameter gradients shaped (N, H, W, 5).
torch::Tensor rrect_to_quads_backward_gpu(torch::Tensor rrects, torch::Tensor gradOutput)
{
    auto gradInput = torch::empty_like(rrects);

    auto flatRRects = rrects.reshape({ -1, 5 });
    auto flatGradOutput = gradOutput.reshape({ -1, 4, 2 });
    // view() rather than reshape(): the kernel writes through flatGradInput,
    // so it must alias gradInput's storage. A silent reshape() copy (possible
    // for non-contiguous layouts) would discard the computed gradients; the
    // read-only reshapes above are allowed to copy.
    auto flatGradInput = gradInput.view({ -1, 5 });

    dim3 blockSize(32);
    // One thread per rrect across the flattened (N, H, W) job dimension.
    dim3 gridSize(div_up(rrects.size(0) * rrects.size(1) * rrects.size(2), blockSize.x));

    if (rrects.numel() > 0) {
        AT_DISPATCH_FLOATING_TYPES(
            rrects.scalar_type(),
            "cuda_rrect_to_quads_backward",
            ([&] {
                device_rrect_to_quads_backward_gpu KERNEL_ARG2(gridSize, blockSize) (
                    flatRRects.packed_accessor64<scalar_t, 2>(),
                    flatGradOutput.packed_accessor64<scalar_t, 3>(),
                    flatGradInput.packed_accessor64<scalar_t, 2>()
                );
            })
        );
    }

    return gradInput;
}
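
// Note: the autograd glue pairing rrect_to_quads_gpu with
// rrect_to_quads_backward_gpu (e.g. a torch::autograd::Function subclass) is
// assumed to live elsewhere; this file only provides the kernels and their
// host-side launchers.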