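#pragma once
#include <torch/torch.h>
#include <vector>

// Torch-facing wrapper: computes class numbers for a tensor of discriminants (defined elsewhere in the extension).
std::vector<torch::Tensor> compute_class_numbers_torch(torch::Tensor discriminants);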
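The header only declares `compute_class_numbers_torch` and does not say what the returned tensors contain. A small CPU reference can help spot-check its results. The sketch below makes two assumptions: the inputs are negative discriminants D with D ≡ 0 or 1 (mod 4), and "class number" means h(D), the number of reduced primitive positive definite binary quadratic forms of discriminant D. The names `class_number_reference` and `class_numbers_reference` are hypothetical and are not part of this repository. The code needs C++17 for `std::gcd`.

```cpp
#include <cstdint>
#include <cstdlib>
#include <numeric>
#include <stdexcept>
#include <vector>

// Hypothetical CPU reference (not the extension's implementation): h(D) for a
// negative discriminant D, computed by counting reduced primitive positive
// definite forms (a, b, c) with b^2 - 4ac = D, |b| <= a <= c, and b >= 0
// whenever |b| == a or a == c.
std::int64_t class_number_reference(std::int64_t D) {
    if (D >= 0 || ((D % 4) + 4) % 4 > 1) {
        throw std::invalid_argument("D must be negative and congruent to 0 or 1 mod 4");
    }
    std::int64_t h = 0;
    const std::int64_t absD = -D;
    // Reduced forms satisfy |D| = 4ac - b^2 >= 3a^2, which bounds the search over a.
    for (std::int64_t a = 1; 3 * a * a <= absD; ++a) {
        // b ranges over (-a, a]; skipping b = -a enforces b >= 0 when |b| == a.
        for (std::int64_t b = -a + 1; b <= a; ++b) {
            const std::int64_t four_ac = b * b + absD;  // b^2 - D == 4ac
            if (four_ac % (4 * a) != 0) continue;       // c must be an integer
            const std::int64_t c = four_ac / (4 * a);
            if (c < a) continue;                        // reduced: a <= c
            if (a == c && b < 0) continue;              // tie-break when a == c
            if (std::gcd(std::gcd(a, std::abs(b)), c) != 1) continue;  // primitive forms only
            ++h;
        }
    }
    return h;
}

// Hypothetical batch helper mirroring the `discriminants` argument of the binding.
std::vector<std::int64_t> class_numbers_reference(const std::vector<std::int64_t>& discriminants) {
    std::vector<std::int64_t> out;
    out.reserve(discriminants.size());
    for (std::int64_t D : discriminants) {
        out.push_back(class_number_reference(D));
    }
    return out;
}
```

For example, `class_numbers_reference({-3, -4, -23})` yields `{1, 1, 3}`. The search is O(|D|) per discriminant, so it suits spot-checking small inputs rather than replacing the compiled extension.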