// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 #include "sparse_select.h" #include #include "../common.h" using namespace std; std::tuple> sparse_select(torch::Tensor sparseCounts, const std::vector sparseTensors, torch::Tensor selectIndices) { bool is_gpu = sparseCounts.is_cuda(); auto sparseCountsCPU = sparseCounts.cpu(); auto sortedSelect = get<0>(torch::sort(selectIndices)); vector retTensors; for (const torch::Tensor &t : sparseTensors) { retTensors.push_back(t.index({sortedSelect})); } vector offsets(1 + sparseCountsCPU.size(0)); auto sparseCtAccess = sparseCountsCPU.accessor(); for (int64_t i = 0; i < sparseCountsCPU.size(0); ++i) { offsets[i + 1] = sparseCtAccess[i] + offsets[i]; } // cout << "Offsets: " << offsets << endl; auto retCounts = torch::zeros_like(sparseCountsCPU); auto retCtAccess = retCounts.accessor(); auto idxAccess = sortedSelect.accessor(); for (int64_t i = 0; i < idxAccess.size(0); ++i) { int64_t idx = idxAccess[i]; int64_t batchIdx = std::upper_bound(begin(offsets), end(offsets), idx) - begin(offsets) - 1; // cout << "Index: " << idx << ", Batch Index: " << batchIdx << endl; retCtAccess[batchIdx] += 1; } if (is_gpu) { retCounts = retCounts.to(sparseCounts); } return make_tuple(retCounts, retTensors); }