File size: 1,759 Bytes
e05eed1 98a67a0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
#include "sparse_select.h"
#include <algorithm>
#include "../common.h"
using namespace std;
// Selects rows from a batched sparse layout and recomputes the per-batch row counts.
//
// The rows of all batch elements are treated as concatenated along dimension 0: batch `b` owns the
// flat row range [offsets[b], offsets[b + 1]), where `offsets` is the exclusive prefix sum of
// `sparseCounts`.
//
// Parameters:
//   sparseCounts  - 1-D int64 tensor; entry `b` is the number of rows owned by batch `b`.
//   sparseTensors - tensors that are indexed along dimension 0 with the selected flat row indices.
//   selectIndices - 1-D int64 tensor of flat row indices to keep; each must lie in
//                   [0, sum(sparseCounts)).
//
// Returns a tuple of
//   - the new per-batch counts (same shape as `sparseCounts`; converted back to the device/dtype of
//     `sparseCounts` when that tensor is a CUDA tensor, otherwise a CPU tensor), and
//   - one tensor per entry of `sparseTensors`, gathered at the *sorted* selection indices, so the
//     selected rows stay grouped by batch in ascending batch order.
//
// Throws c10::Error if a count is negative and c10::IndexError if a selection index is out of range.
std::tuple<torch::Tensor, std::vector<torch::Tensor>> sparse_select(torch::Tensor sparseCounts,
                                                                    const std::vector<torch::Tensor> sparseTensors,
                                                                    torch::Tensor selectIndices)
{
    const bool is_gpu = sparseCounts.is_cuda();

    // Host copy of the counts so they can be read element by element.
    auto sparseCountsCPU = sparseCounts.cpu();
    auto sparseCtAccess = sparseCountsCPU.accessor<int64_t, 1>();
    const int64_t numBatches = sparseCtAccess.size(0);

    // offsets[b] is the flat index of the first row of batch b; offsets.back() is the total row count.
    std::vector<int64_t> offsets(numBatches + 1);
    for (int64_t b = 0; b < numBatches; ++b) {
        const int64_t count = sparseCtAccess[b];
        // A negative count would make `offsets` non-monotonic and break the binary search below.
        TORCH_CHECK(count >= 0, "sparse_select: sparseCounts[", b, "] is negative (", count, ")");
        offsets[b + 1] = offsets[b] + count;
    }
    const int64_t totalRows = offsets.back();

    // Sorting keeps the selected rows grouped by batch in the gathered output.
    auto sortedSelect = std::get<0>(torch::sort(selectIndices));

    // The element accessor dereferences host memory, so read the indices through a CPU copy
    // (a no-op when the indices already live on the CPU).
    auto sortedSelectCPU = sortedSelect.cpu();
    auto idxAccess = sortedSelectCPU.accessor<int64_t, 1>();
    const int64_t numSelected = idxAccess.size(0);

    // The indices are sorted, so validating the smallest and largest one covers all of them.
    // Without this check an out-of-range index would write outside of `retCounts`.
    if (numSelected > 0) {
        const int64_t minIdx = idxAccess[0];
        const int64_t maxIdx = idxAccess[numSelected - 1];
        TORCH_CHECK_INDEX(minIdx >= 0 && maxIdx < totalRows,
                          "sparse_select: selectIndices must lie in [0, ", totalRows,
                          "), but got values in [", minIdx, ", ", maxIdx, "]");
    }

    std::vector<torch::Tensor> retTensors;
    retTensors.reserve(sparseTensors.size());
    for (const torch::Tensor &t : sparseTensors) {
        retTensors.push_back(t.index({sortedSelect}));
    }

    // Count how many selected rows fall into each batch.
    auto retCounts = torch::zeros_like(sparseCountsCPU);
    auto retCtAccess = retCounts.accessor<int64_t, 1>();
    for (int64_t i = 0; i < numSelected; ++i) {
        // The first offset strictly greater than the index marks the end of the owning batch;
        // using upper_bound (not lower_bound) skips over empty batches that share an offset.
        const auto batchEnd = std::upper_bound(offsets.begin(), offsets.end(), idxAccess[i]);
        const int64_t batchIdx = (batchEnd - offsets.begin()) - 1;
        retCtAccess[batchIdx] += 1;
    }

    if (is_gpu) {
        retCounts = retCounts.to(sparseCounts);
    }

    return std::make_tuple(retCounts, retTensors);
}
|