File size: 2,908 Bytes
e05eed1
98a67a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <torch/torch.h>

#include "nms_common.h"
#include "../geometry.h"

/**
 * \brief Result type for non-maximal suppression.
 *
 * Results for every example in the batch are flattened together. Use the third
 * element (region counts) to work out which batch example each quad belongs to.
 *
 * With N = total number of quads for the entire batch and B = batch size:
 *
 *   std::get<0> - quads        - Nx4x2
 *   std::get<1> - confidence   - N
 *   std::get<2> - regionCounts - B (such that sum(regionCounts) == N)
 */
using nms_result_t = std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>;

/**
 * \brief Quad non-maximal suppression (see nms_result_t for the output layout).
 *
 * \param quads          Candidate quads. NOTE(review): exact shape is not visible here — confirm with the implementation.
 * \param probs          Probabilities associated with the candidates in `quads`.
 * \param probThreshold  Probability threshold; presumably candidates below it are ignored — TODO confirm.
 * \param iouThreshold   IoU threshold used when deciding whether two quads overlap.
 * \param kernelHeight   Presumably the height of a local grouping window — TODO confirm.
 * \param kernelWidth    Presumably the width of a local grouping window — TODO confirm.
 * \param maxRegions     Presumably a cap on the number of regions kept — TODO confirm.
 * \param verbose        Presumably toggles diagnostic output; defaults to false.
 */
nms_result_t quad_non_maximal_suppression(torch::Tensor quads,
                                          torch::Tensor probs,
                                          float probThreshold,
                                          float iouThreshold,
                                          int64_t kernelHeight,
                                          int64_t kernelWidth,
                                          int64_t maxRegions,
                                          bool verbose = false);


/**
 * \brief An InPlaceQuad_ that also remembers the image-grid position it came from.
 *
 * The position is stored as Point_<int64_t>(col, row). The (row, col)
 * constructor performs that reordering so callers can pass row first.
 */
template<typename T>
struct TrackedInPlaceQuad_ : InPlaceQuad_<T> {
    /// Source position, constructed as Point_(col, row).
    Point_<int64_t> ImgCoords;

    // NOTE(review): this constructor is implicit, so a Point_<int64_t> converts
    // silently to a TrackedInPlaceQuad_. Consider `explicit` once callers are audited.
    TrackedInPlaceQuad_(Point_<int64_t> imgCoords)
        : ImgCoords(std::move(imgCoords))
    {
    }

    /// Convenience overload taking (row, col); forwards to the Point_ constructor.
    TrackedInPlaceQuad_(int64_t row, int64_t col)
        : TrackedInPlaceQuad_(Point_<int64_t>(col, row))
    {
    }
};

/**
 * \brief An EmbedQuad_ that also records the image-grid coordinates of every
 *        tracked quad merged into it.
 *
 * Each Append() adds the source coordinates to `ImgCoords` before forwarding to
 * the matching EmbedQuad_<T>::Append overload. Reset() clears them together
 * with the base state.
 */
template<typename T>
struct TrackedEmbedQuad_ : EmbedQuad_<T> {
    /// Source positions of the merged quads, in the order they were appended.
    std::vector<Point_<int64_t>> ImgCoords;

    TrackedEmbedQuad_(T confidence = 0): EmbedQuad_<T>(confidence) {}
    TrackedEmbedQuad_(const TrackedEmbedQuad_ &other) = default;

    /// Exchanges both the tracked coordinates and the base EmbedQuad_ state.
    void swap(TrackedEmbedQuad_ &other) noexcept {
        using std::swap;

        swap(ImgCoords, other.ImgCoords);

        EmbedQuad_<T>::swap(other);
    }

    /// Move construction: start from a default object (confidence 0, no coords)
    /// and swap, which leaves `other` in that default state.
    ///
    /// Marked noexcept so std::vector<TrackedEmbedQuad_> moves elements when it
    /// reallocates. Without it, std::move_if_noexcept selects the copy
    /// constructor and deep-copies every ImgCoords vector on growth. This
    /// assumes EmbedQuad_<T>(T) does not throw, which matches the assumption
    /// swap() already makes about EmbedQuad_<T>::swap.
    TrackedEmbedQuad_(TrackedEmbedQuad_ &&other) noexcept : TrackedEmbedQuad_() {
        other.swap(*this);
    }

    /// Unified copy/move assignment (copy-and-swap). The caller constructs the
    /// by-value parameter; the body only performs a noexcept swap.
    TrackedEmbedQuad_ &operator=(TrackedEmbedQuad_ other) noexcept {
        other.swap(*this);
        return *this;
    }

    /// Merges a single tracked quad: records its image coordinates, then
    /// forwards `q`, `conf` and `numQuads` to EmbedQuad_<T>::Append.
    void Append(const TrackedInPlaceQuad_<T> &q, T conf, T numQuads = 1) {
        ImgCoords.push_back(q.ImgCoords);

        EmbedQuad_<T>::Append(q, conf, numQuads);
    }

    /// Merges another tracked embed quad: concatenates its coordinate list,
    /// then forwards to EmbedQuad_<T>::Append.
    void Append(const TrackedEmbedQuad_<T> &other) {
        if (&other == this) {
            // vector::insert(pos, first, last) requires that [first, last) does not
            // point into the destination vector, so self-append goes via a copy.
            const auto snapshot = ImgCoords;
            ImgCoords.insert(end(ImgCoords), begin(snapshot), end(snapshot));
        } else {
            ImgCoords.insert(end(ImgCoords), begin(other.ImgCoords), end(other.ImgCoords));
        }

        EmbedQuad_<T>::Append(other);
    }

    /// Clears the tracked coordinates and resets the base EmbedQuad_ state.
    void Reset() {
        ImgCoords.clear();

        EmbedQuad_<T>::Reset();
    }
};

// Single-precision aliases for the tracked quad types.
using TIPQuad = TrackedInPlaceQuad_<float>;  // tracked in-place quad
using TEFQuad = TrackedEmbedQuad_<float>;    // tracked embed quad


/**
 * \brief Quad non-maximal suppression over an already-collected list of tracked quads.
 *
 * \param rowQuads      Candidate quads, each carrying its image-grid coordinates.
 * \param iouThreshold  IoU threshold used when deciding whether two quads overlap.
 * \param imageHeight   Presumably the height of the grid the coordinates index into — TODO confirm.
 * \param imageWidth    Presumably the width of the grid the coordinates index into — TODO confirm.
 * \return The resulting tracked embed quads.
 */
std::vector<TEFQuad> reduced_quad_non_maximal_suppression(const std::vector<TIPQuad> &rowQuads,
                                                          float iouThreshold,
                                                          int64_t imageHeight,
                                                          int64_t imageWidth);

/**
 * \brief Backward pass for quad non-maximal suppression.
 *
 * Takes the forward inputs plus the incoming gradients for the quads and
 * probability outputs, and returns a list of gradient tensors.
 * NOTE(review): the count and order of the returned tensors are not visible
 * here (presumably {grad_quads, grad_probs}). Confirm against the implementation.
 */
std::vector<torch::Tensor> quad_non_maximal_suppression_backward(torch::Tensor quads,
                                                                 torch::Tensor probs,
                                                                 torch::Tensor gradOutQuads,
                                                                 torch::Tensor gradOutProbs);

/**
 * \brief CUDA counterpart (by name) of quad_non_maximal_suppression.
 *
 * Takes the same parameter list and returns the same nms_result_t. Unlike the
 * CPU declaration, `verbose` has no default here and must be passed explicitly.
 */
nms_result_t cuda_quad_non_maximal_suppression(torch::Tensor quads,
                                               torch::Tensor probs,
                                               float probThreshold,
                                               float iouThreshold,
                                               int64_t kernelHeight,
                                               int64_t kernelWidth,
                                               int64_t maxRegions,
                                               bool verbose);