File size: 3,105 Bytes
e05eed1
98a67a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <torch/torch.h>

#include "quad_rectify_cpu.h"
#include "quad_rectify_gpu.h"

/// Validates a quad tensor and forwards it to the device-specific
/// `calc_quad_width` implementation.
///
/// @param quads         Tensor whose trailing two dimensions are 4x2. It may be
///                      2-D (a single 4x2 quad) or 3-D (a leading batch
///                      dimension followed by 4x2).
/// @param outputHeight  Forwarded unchanged to the implementation.
/// @param roundFactor   Forwarded unchanged to the implementation.
/// @param maxWidth      Forwarded unchanged to the implementation.
/// @return Whatever the selected CPU/GPU implementation returns. A 2-D input
///         is given a leading dimension of size 1 before the call, and the
///         result is returned as-is (no dimension is removed afterwards).
/// @throws std::runtime_error if `quads` is not 2-D/3-D, or if its last two
///         dimensions are not 4x2.
inline
torch::Tensor quad_rectify_calc_quad_width(torch::Tensor quads,
                                           int64_t outputHeight,
                                           int64_t roundFactor,
                                           float maxWidth)
{
    const auto rank = quads.dim();
    if (rank != 2 && rank != 3) {
        throw std::runtime_error("Invalid quads dimensions.");
    }

    const bool hasQuadShape = quads.size(-2) == 4 && quads.size(-1) == 2;
    if (!hasQuadShape) {
        throw std::runtime_error("The final 2 quad dimensions must be 4x2.");
    }

    // The implementations are always handed a 3-D tensor, so promote a lone
    // quad to a batch of one.
    torch::Tensor batched = quads;
    if (rank == 2) {
        batched = quads.unsqueeze(0);
    }

    if (!batched.is_cuda()) {
        return quad_rectify_cpu_calc_quad_width(batched, outputHeight, roundFactor, maxWidth);
    }
    return quad_rectify_gpu_calc_quad_width(batched, outputHeight, roundFactor, maxWidth);
}

/// Validates a quad tensor and runs the device-specific forward pass.
///
/// @param quads         Tensor whose trailing two dimensions are 4x2. It may be
///                      2-D (a single 4x2 quad) or 3-D (a leading batch
///                      dimension followed by 4x2).
/// @param imageHeight   Forwarded unchanged to the implementation.
/// @param imageWidth    Forwarded unchanged to the implementation.
/// @param outputHeight  Forwarded unchanged to the implementation.
/// @param outputWidth   Forwarded unchanged to the implementation.
/// @param isotropic     Forwarded unchanged to the implementation.
/// @return The implementation's result. When `quads` was passed as a single
///         2-D quad, the leading (batch) entry of the result is returned so
///         the output mirrors the un-batched input.
/// @throws std::runtime_error if `quads` is not 2-D/3-D, or if its last two
///         dimensions are not 4x2.
inline
torch::Tensor quad_rectify_forward(torch::Tensor quads,
                                   int64_t imageHeight,
                                   int64_t imageWidth,
                                   int64_t outputHeight,
                                   int64_t outputWidth,
                                   bool isotropic)
{
    const auto rank = quads.dim();
    if (rank != 2 && rank != 3) {
        throw std::runtime_error("Invalid quads dimensions.");
    }

    const bool hasQuadShape = quads.size(-2) == 4 && quads.size(-1) == 2;
    if (!hasQuadShape) {
        throw std::runtime_error("The final 2 quad dimensions must be 4x2.");
    }

    // Picks the CUDA or CPU implementation based on where the batch lives.
    const auto runForward = [&](torch::Tensor batch) -> torch::Tensor {
        if (batch.is_cuda()) {
            return quad_rectify_gpu_forward(batch, imageHeight, imageWidth, outputHeight, outputWidth, isotropic);
        }
        return quad_rectify_cpu_forward(batch, imageHeight, imageWidth, outputHeight, outputWidth, isotropic);
    };

    if (rank == 2) {
        // Single quad: wrap it in a batch of one, then strip that batch
        // dimension from the result again.
        return runForward(quads.unsqueeze(0))[0];
    }

    return runForward(quads);
}

/// Validates the inputs of the backward pass and forwards them to the
/// device-specific implementation.
///
/// Unlike the forward entry point, this function does not accept an
/// un-batched 2-D `quads` tensor: it must already be 3-D.
///
/// @param quads        3-D tensor shaped Nx4x2.
/// @param gradOutput   4-D tensor shaped Nx<outputHeight>x<outputWidth>x2.
///                     Must have the same leading (batch) size as `quads`.
/// @param imageHeight  Forwarded unchanged to the implementation.
/// @param imageWidth   Forwarded unchanged to the implementation.
/// @param isotropic    Forwarded unchanged to the implementation.
/// @return Whatever the selected CPU/GPU implementation returns.
/// @throws std::runtime_error if exactly one of the tensors is a CUDA tensor,
///         if either tensor has an unexpected shape, or if the batch sizes of
///         `quads` and `gradOutput` disagree.
inline
torch::Tensor quad_rectify_backward(torch::Tensor quads, torch::Tensor gradOutput,
                                    int64_t imageHeight, int64_t imageWidth,
                                    bool isotropic)
{
    if (quads.is_cuda() != gradOutput.is_cuda()) {
        throw std::runtime_error("Either both 'quads' and 'gradOutput' must be cuda, or neither.");
    }

    if (quads.dim() != 3 || quads.size(-2) != 4 || quads.size(-1) != 2) {
        throw std::runtime_error("Expected quads to be 3 dimensional. Nx4x2.");
    }

    if (gradOutput.dim() != 4 ||
        gradOutput.size(3) != 2) {
        throw std::runtime_error("Expected 'gradOutput' to be 4d: Nx<outputHeight>x<outputWidth>x2.");
    }

    // Both tensors are documented as sharing the leading dimension N. Reject a
    // mismatch here rather than handing inconsistent shapes to the
    // implementation.
    if (gradOutput.size(0) != quads.size(0)) {
        throw std::runtime_error("Expected 'quads' and 'gradOutput' to have the same batch size.");
    }

    if (quads.is_cuda()) {
        return quad_rectify_gpu_backward(quads, gradOutput, imageHeight, imageWidth, isotropic);
    }
    else {
        return quad_rectify_cpu_backward(quads, gradOutput, imageHeight, imageWidth, isotropic);
    }
}