File size: 495 Bytes
e05eed1
98a67a0
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <torch/extension.h>

std::tuple<torch::Tensor, std::vector<torch::Tensor>> sparse_select(torch::Tensor sparseCounts,
                                                                    const std::vector<torch::Tensor> sparseTensors,
                                                                    torch::Tensor selectIndices);