File size: 2,777 Bytes
e05eed1 98a67a0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
#include "geometry_api.h"
#include "geometry_api_common.h"

#include <cstdint>

// NOTE(review): file-scope `using namespace std;` is an anti-pattern (name
// collisions with ATen/torch symbols); kept for compatibility, nothing below
// appears to rely on it — consider removing.
using namespace std;
// CUDA implementation, defined in the companion .cu translation unit.
torch::Tensor rrect_to_quads_gpu(torch::Tensor rrects, float cellSize);

/// CPU kernel: converts a grid of rotated rectangles into quads.
///
/// @param rrects   Rotated rectangles, shape (B, H, W, 5) — accessed via a
///                 rank-4 accessor; the meaning of the 5 channels is defined
///                 by `assign_rrect_to_quad` (see geometry_api_common.h).
/// @param cellSize Size of one grid cell; each rrect is interpreted relative
///                 to its (x, y) cell position.
/// @return Tensor of shape (B, H, W, 4, 2): four (x, y) corners per cell.
template<typename T>
torch::Tensor rrect_to_quads_impl(torch::Tensor rrects, T cellSize)
{
    // BHW(5)
    auto rrectAccess = rrects.accessor<T, 4>();
    T cellOff = cellSize / 2;
    auto quads = torch::empty({ rrects.size(0), rrects.size(1), rrects.size(2), 4, 2 }, rrects.options());
    auto quadsAccess = quads.accessor<T, 5>();
    // Use int64_t (not `long`): Tensor::size() returns int64_t, and `long`
    // is only 32 bits on LLP64 platforms such as Windows/MSVC.
    const int64_t B = rrects.size(0);
    const int64_t H = rrects.size(1);
    const int64_t W = rrects.size(2);
    for (int64_t b = 0; b < B; ++b) {
        for (int64_t y = 0; y < H; ++y) {
            for (int64_t x = 0; x < W; ++x) {
                auto rrect = rrectAccess[b][y][x];
                auto quad = quadsAccess[b][y][x];
                // (x, y) are passed so the helper can offset the quad into
                // this cell's position on the grid.
                assign_rrect_to_quad(rrect, quad, cellSize, cellOff,
                                     static_cast<T>(x),
                                     static_cast<T>(y));
            }
        }
    }
    return quads;
}
/// Public entry point: converts rotated rectangles to quads, dispatching to
/// the CUDA kernel for GPU tensors and to the templated CPU implementation
/// (instantiated per floating-point dtype) otherwise.
torch::Tensor rrect_to_quads(torch::Tensor rrects, float cellSize)
{
    // GPU path is handled entirely by the .cu translation unit.
    if (rrects.is_cuda()) {
        return rrect_to_quads_gpu(rrects, cellSize);
    }
    torch::Tensor result;
    AT_DISPATCH_FLOATING_TYPES(
        rrects.scalar_type(), "rrect_to_quads_impl", ([&] {
            result = rrect_to_quads_impl<scalar_t>(rrects, static_cast<scalar_t>(cellSize));
        }));
    return result;
}
/// CPU kernel for the backward pass of rrect→quad conversion.
///
/// @param rrects     Forward-pass input, shape (B, H, W, 5).
/// @param gradOutput Gradient w.r.t. the quads, shape (B, H, W, 4, 2).
/// @return Gradient w.r.t. `rrects`, same shape/dtype/layout as `rrects`.
template<typename T>
torch::Tensor rrect_to_quads_backward_impl(torch::Tensor rrects, torch::Tensor gradOutput)
{
    // BHW(5)
    auto gradInput = torch::empty_like(rrects);
    auto rrectAccess = rrects.accessor<T, 4>();
    // BHW42
    auto gradOutputAccess = gradOutput.accessor<T, 5>();
    auto gradInputAccess = gradInput.accessor<T, 4>();
    // Use int64_t (not `long`): Tensor::size() returns int64_t, and `long`
    // is only 32 bits on LLP64 platforms such as Windows/MSVC.
    const int64_t B = rrects.size(0);
    const int64_t H = rrects.size(1);
    const int64_t W = rrects.size(2);
    for (int64_t b = 0; b < B; ++b) {
        for (int64_t y = 0; y < H; ++y) {
            for (int64_t x = 0; x < W; ++x) {
                // Per-cell chain rule, implemented by the shared helper
                // (see geometry_api_common.h).
                assign_grad_rrect_to_quad<T>(rrectAccess[b][y][x], gradOutputAccess[b][y][x], gradInputAccess[b][y][x]);
            }
        }
    }
    return gradInput;
}
// CUDA implementation, defined in the companion .cu translation unit.
torch::Tensor rrect_to_quads_backward_gpu(torch::Tensor rrects, torch::Tensor gradOutput);

/// Public backward entry point: routes GPU tensors to the CUDA kernel and
/// everything else to the per-dtype CPU implementation.
torch::Tensor rrect_to_quads_backward(torch::Tensor rrects, torch::Tensor gradOutput)
{
    if (rrects.is_cuda()) {
        return rrect_to_quads_backward_gpu(rrects, gradOutput);
    }
    torch::Tensor result;
    AT_DISPATCH_FLOATING_TYPES(
        rrects.scalar_type(), "rrect_to_quads_backward_impl", ([&] {
            result = rrect_to_quads_backward_impl<scalar_t>(rrects, gradOutput);
        }));
    return result;
}
|