// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

// GPU implementation of "quad rectification": given a batch of quadrilaterals,
// compute per-quad rectified widths, a forward sampling grid mapping rectified
// output cells to source-image coordinates, and the backward gradient w.r.t.
// the quad corner points.
//
// NOTE(review): this file appears to have been damaged by text extraction —
// every template parameter/argument list (`template<...>`, `Convert<...>`,
// `packed_accessor64<...>`, `Point_<...>`, `remap_half<...>`) and the targets
// of three system #includes have been stripped, and the code was collapsed
// onto a few physical lines. All tokens below are preserved verbatim; only
// formatting and comments have been added. Recover the stripped `<...>`
// contents from version control before attempting to compile.

#include "quad_rectify_gpu.h"
// NOTE(review): the next three #include directives lost their <...> targets
// during extraction (presumably the torch extension and CUDA runtime headers
// — TODO confirm against the original file).
#include
#include
#include
#include "quad_rectify_shared.h"
#include "../half_ops.cuh"
#include "../geometry.h"

// Input-validation helpers for tensors handed in from the Python side.
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
// NOTE(review): CHECK_INPUT only checks device placement; CHECK_CONTIGUOUS is
// defined but never used. Possibly a contiguity check was lost from this
// macro during extraction, or it was intentionally relaxed — TODO confirm.
#define CHECK_INPUT(x) CHECK_CUDA(x)

// Kernel: one thread per quad (1-D launch, tail-guarded). Computes each
// quad's rectified width via calc_quad_width() (declared in
// quad_rectify_shared.h) and stores it, converted to the output tensor's
// element type, into output[quadIdx].
template
__global__ void quad_rectify_device_calc_quad_width(quads_accessor_t quads,
                                                    output_accessor_t output,
                                                    const scalar_t outputHeight,
                                                    const scalar_t roundFactor,
                                                    const scalar_t maxWidth)
{
    const unsigned int quadIdx = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int numQuads = quads.size(0);

    // Tail guard: the grid may overshoot the number of quads.
    if (quadIdx >= numQuads)
    {
        return;
    }

    auto currQuad = quads[quadIdx];
    auto quadWidth = calc_quad_width(currQuad, outputHeight, roundFactor, maxWidth);

    // Convert from the kernel's scalar type to the output element type.
    output[quadIdx] = Convert::LeftToRight(quadWidth);
}

// Forward kernel: fills the sampling grid. The launch is 2-D — the y
// dimension selects the quad, and the x dimension is a flattened
// (row, column) offset into the outputHeight x outputWidth grid. Each thread
// writes one (X, Y) source-image coordinate pair into
// outputs[quadIdx][y][x][0..1].
template
__global__ void quad_rectify_device_forward(quads_accessor_t quads,
                                            output_accessor_t outputs,
                                            const scalar_t imageHeight,
                                            const scalar_t imageWidth,
                                            bool isotropic)
{
    typedef Point_ Point_t;

    const unsigned int quadIdx = blockIdx.y * blockDim.y + threadIdx.y;
    const unsigned int numQuads = quads.size(0);
    if (quadIdx >= numQuads)
    {
        return;
    }

    const unsigned int outputHeight = outputs.size(1);
    const unsigned int outputWidth = outputs.size(2);

    // De-flatten the x launch dimension into an (x, y) output cell, and guard
    // the flattened tail.
    const unsigned int offset = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int x = offset % outputWidth;
    const unsigned int y = offset / outputWidth;
    if (y >= outputHeight)
    {
        return;
    }

    auto quad = quads[quadIdx];
    auto output = outputs[quadIdx][y][x];

    auto scOutputHeight = Convert::RightToLeft(outputHeight);
    auto scOutputWidth = Convert::RightToLeft(outputWidth);
    auto scOne = Convert::RightToLeft(1);

    // Isotropic mode derives the effective sampling width from the quad
    // itself (passing scOutputWidth as the max-width argument); otherwise the
    // full output width is used.
    scalar_t quadWidth = isotropic ?
        calc_quad_width(quad, scOutputHeight, scOne, scOutputWidth) :
        scOutputWidth;

    Point_t outputPoint = calc_rect_value(quad, quadWidth, scOutputHeight, x, y, imageWidth, imageHeight);

    output[0] = outputPoint.X;
    output[1] = outputPoint.Y;
}

// Backward kernel: accumulates the gradient w.r.t. the quad corner points
// (8 values per quad: 4 corners x 2 coordinates). The launch mirrors the
// forward kernel: y dimension selects the quad, x dimension is a flattened
// output cell. Each 32-lane block tree-reduces its per-lane partials in
// shared memory, then lane 0 issues one atomicAdd per value.
template
__global__ void quad_rectify_device_backward(quads_accessor_t quads,
                                             output_accessor_t gradOutput,
                                             quads_accessor_t gradInput,
                                             const scalar_t imageHeight,
                                             const scalar_t imageWidth,
                                             bool isotropic)
{
    const unsigned int numQuads = quads.size(0);

    int64_t quadIdx = blockIdx.y * blockDim.y + threadIdx.y;
    int64_t offset = blockIdx.x * blockDim.x + threadIdx.x;

    const int64_t outputHeight = gradOutput.size(1);
    const int64_t outputWidth = gradOutput.size(2);

    int64_t x = offset % outputWidth;
    int64_t y = offset / outputWidth;

    auto scOutputHeight = Convert::RightToLeft(outputHeight);
    auto scOutputWidth = Convert::RightToLeft(outputWidth);
    auto scOne = Convert::RightToLeft(1);
    const scalar_t scHalf = Convert::RightToLeft(0.5);

    // NOTE(review): quads[quadIdx] is read before the `quadIdx < numQuads`
    // check below. This is safe only because the host wrapper launches with
    // blockDim.y == 1 and gridDim.y == numQuads, so quadIdx < numQuads always
    // holds — revisit if the launch configuration ever changes.
    auto currQuad = quads[quadIdx];
    scalar_t quadWidth = isotropic ?
        calc_quad_width(currQuad, scOutputHeight, scOne, scOutputWidth) :
        scOutputWidth;

    // One row per lane, 8 partial gradients per lane. Requires
    // blockDim.x == 32 (set by the host wrapper).
    __shared__ scalar_t sharedFloats[32][8];

    // Per-coordinate gradient scale; presumably maps pixel units to the
    // [-1, 1] normalized-grid convention — TODO confirm against
    // calc_rect_value.
    scalar_t scale[2] = {
        Convert::RightToLeft(2.0f) / imageWidth,
        Convert::RightToLeft(2.0f) / imageHeight
    };

    bool valid = false;
    if (quadIdx < numQuads && y < outputHeight)
    {
        // Normalized cell-center coordinates within the rectified output.
        auto fRow = (scalar_t(y) + scHalf) / outputHeight;
        auto fCol = (scalar_t(x) + scHalf) / quadWidth;
        // auto fRow = scalar_t(y) / (outputHeight - scOne);
        // auto fCol = scalar_t(x) / (quadWidth - scOne);
        auto fRowCol = fRow * fCol;

        // Cells beyond the (possibly narrower) isotropic quad width carry no
        // gradient.
        if (fCol <= 1)
        {
            #pragma unroll 2
            for (int64_t i = 0; i < 2; ++i)
            {
                auto currGradOutput = gradOutput[quadIdx][y][x][i] * scale[i];
                // Bilinear weights of the four corners, expanded:
                //   fRowCol - fCol - fRow + 1 == (1 - fRow) * (1 - fCol)
                //   fCol - fRowCol            == (1 - fRow) * fCol
                //   fRowCol                   == fRow * fCol
                //   fRow - fRowCol            == fRow * (1 - fCol)
                sharedFloats[threadIdx.x][0 + i] = currGradOutput * (fRowCol - fCol - fRow + 1);
                sharedFloats[threadIdx.x][2 + i] = currGradOutput * (fCol - fRowCol);
                sharedFloats[threadIdx.x][4 + i] = currGradOutput * fRowCol;
                sharedFloats[threadIdx.x][6 + i] = currGradOutput * (fRow - fRowCol);
            }
            valid = true;
        }
    }

    // Lanes with no contribution still zero their shared-memory row so the
    // reduction below can run unconditionally across all 32 lanes (and so the
    // __syncthreads() barriers are reached by every lane).
    if (! valid)
    {
        #pragma unroll 8
        for (int64_t i = 0; i < 8; ++i)
        {
            sharedFloats[threadIdx.x][i] = 0;
        }
    }

    __syncthreads();

    // Now accumulate over the shared memory
    // (tree reduction over the 32 lanes; a barrier is needed each step
    // because lanes read rows written by other lanes).
    for (unsigned int i = 16; i > 0; i /= 2)
    {
        if (threadIdx.x < i)
        {
            #pragma unroll 8
            for (unsigned int k = 0; k < 8; ++k)
            {
                sharedFloats[threadIdx.x][k] += sharedFloats[threadIdx.x + i][k];
            }
        }
        __syncthreads();
    }

    auto pGradInput = gradInput[quadIdx].data();

    // Finally, write the values
    // (atomics because multiple blocks along the x grid dimension accumulate
    // into the same quad's gradient).
    if (threadIdx.x == 0)
    {
        #pragma unroll 8
        for (int64_t i = 0; i < 8; ++i)
        {
            atomicAdd(pGradInput + i, sharedFloats[0][i]);
        }
    }
}

// Host entry point: returns an int64 tensor of shape [numQuads] holding each
// quad's computed rectified width. `quads` must be a CUDA tensor of a
// floating (or half) dtype.
torch::Tensor quad_rectify_gpu_calc_quad_width(torch::Tensor quads, int64_t outputHeight, int64_t roundFactor, float maxWidth)
{
    CHECK_INPUT(quads);

    const int64_t numQuads = quads.size(0);

    // One thread per quad.
    dim3 dimBlock(32);
    dim3 dimGrid(div_up(numQuads, dimBlock.x));

    auto output = torch::empty({ numQuads }, quads.options().dtype(torch::kInt64));

    if (numQuads > 0)
    {
        AT_DISPATCH_FLOATING_TYPES_AND_HALF(
            quads.scalar_type(), "quad_rectify_calc_quad_width", ([&] {
                // remap_half presumably maps at::Half to the device-side half
                // type — see ../half_ops.cuh.
                typedef typename remap_half::type T;
                quad_rectify_device_calc_quad_width KERNEL_ARG2(dimGrid, dimBlock) (
                    quads.packed_accessor64(),
                    output.packed_accessor64(),
                    Convert::RightToLeft(outputHeight),
                    Convert::RightToLeft(roundFactor),
                    Convert::RightToLeft(maxWidth)
                );
            })
        );
    }

    return output;
}

// Host entry point for the forward pass: allocates and fills a sampling grid
// of shape [numQuads, outputHeight, outputWidth, 2] in the quads' dtype.
torch::Tensor quad_rectify_gpu_forward(torch::Tensor quads, int64_t imageHeight, int64_t imageWidth, int64_t outputHeight, int64_t outputWidth, bool isotropic)
{
    CHECK_INPUT(quads);

    const int64_t numQuads = quads.size(0);
    const int64_t numCells = outputHeight * outputWidth;

    // x grid dimension covers the flattened output cells; y dimension is one
    // block row per quad.
    dim3 dimBlock(32);
    dim3 dimGrid(div_up(numCells, dimBlock.x), numQuads);

    auto output = torch::empty({ numQuads, outputHeight, outputWidth, 2 }, quads.options());

    if (numQuads > 0)
    {
        AT_DISPATCH_FLOATING_TYPES_AND_HALF(
            quads.scalar_type(), "quad_rectify_device_forward", ([&] {
                typedef typename remap_half::type T;
                quad_rectify_device_forward KERNEL_ARG2(dimGrid, dimBlock) (
                    quads.packed_accessor64(),
                    output.packed_accessor64(),
                    Convert::RightToLeft(imageHeight),
                    Convert::RightToLeft(imageWidth),
                    isotropic
                );
            })
        );
    }

    return output;
}

// Host entry point for the backward pass: given upstream gradients of shape
// [numQuads, outputHeight, outputWidth, 2], returns the gradient w.r.t. the
// quads tensor (same shape/dtype as `quads`; zero-initialized here and
// accumulated into by the kernel via atomicAdd).
torch::Tensor quad_rectify_gpu_backward(torch::Tensor quads, torch::Tensor gradOutput, int64_t imageHeight, int64_t imageWidth, bool isotropic)
{
    CHECK_INPUT(quads);
    CHECK_INPUT(gradOutput);

    const int64_t numQuads = quads.size(0);
    const int64_t outputHeight = gradOutput.size(1);
    const int64_t outputWidth = gradOutput.size(2);
    const int64_t numCells = outputHeight * outputWidth;

    // Launch geometry must mirror the forward pass (and the backward kernel's
    // blockDim.x == 32 shared-memory assumption).
    dim3 dimBlock(32);
    dim3 dimGrid(div_up(numCells, dimBlock.x), numQuads);

    auto gradInput = torch::zeros_like(quads);

    if (numQuads > 0)
    {
        // NOTE(review): no half dispatch here, unlike the forward pass —
        // presumably because of the kernel's atomicAdd accumulation; confirm
        // before extending to half.
        AT_DISPATCH_FLOATING_TYPES(
            quads.scalar_type(), "quad_rectify_device_backward", ([&] {
                typedef typename remap_half::type T;
                quad_rectify_device_backward KERNEL_ARG2(dimGrid, dimBlock) (
                    quads.packed_accessor64(),
                    gradOutput.packed_accessor64(),
                    gradInput.packed_accessor64(),
                    Convert::RightToLeft(imageHeight),
                    Convert::RightToLeft(imageWidth),
                    isotropic
                );
            })
        );
    }

    return gradInput;
}