| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | #include "edit_dist.h" |
| | #include <torch/types.h> |
| |
|
// Compatibility shim: when the included headers do not provide TORCH_CHECK,
// alias it to AT_CHECK so the validation macros below can always use the
// TORCH_CHECK spelling. (Presumably this targets older PyTorch releases that
// only ship AT_CHECK -- confirm against the supported version range.)
#if !defined(TORCH_CHECK)
#  define TORCH_CHECK AT_CHECK
#endif
| |
|
// Input-validation helpers. Each macro takes a tensor expression `x`; the
// failure message starts with the argument's source text (via `#x`).

// Fails via TORCH_CHECK unless `x` reports a CUDA device. Queries
// `device().is_cuda()` instead of going through the deprecated
// `Tensor::type()` accessor.
#define CHECK_CUDA(x) \
  TORCH_CHECK((x).device().is_cuda(), #x " must be a CUDA tensor")

// Fails via TORCH_CHECK unless `x` is contiguous.
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")

// Runs the CUDA check followed by the contiguity check. The do/while(0)
// wrapper makes the expansion a single statement, so the macro stays correct
// when used as the body of an unbraced `if`/`else`/loop. Call sites still
// terminate it with a semicolon: `CHECK_INPUT(t);`.
#define CHECK_INPUT(x)   \
  do {                   \
    CHECK_CUDA(x);       \
    CHECK_CONTIGUOUS(x); \
  } while (0)
| |
|
| |
|
| | torch::Tensor LevenshteinDistance( |
| | torch::Tensor source, |
| | torch::Tensor target, |
| | torch::Tensor source_length, |
| | torch::Tensor target_length) { |
| |
|
| | CHECK_INPUT(source); |
| | CHECK_INPUT(target); |
| | CHECK_INPUT(source_length); |
| | CHECK_INPUT(target_length); |
| | return LevenshteinDistanceCuda(source, target, source_length, target_length); |
| | } |
| |
|
| | torch::Tensor GenerateDeletionLabel( |
| | torch::Tensor source, |
| | torch::Tensor operations) { |
| |
|
| | CHECK_INPUT(source); |
| | CHECK_INPUT(operations); |
| | return GenerateDeletionLabelCuda(source, operations); |
| | } |
| |
|
| | std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabel( |
| | torch::Tensor target, |
| | torch::Tensor operations) { |
| |
|
| | CHECK_INPUT(target); |
| | CHECK_INPUT(operations); |
| | return GenerateInsertionLabelCuda(target, operations); |
| | } |
| |
|
| | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { |
| | m.def("levenshtein_distance", &LevenshteinDistance, "Levenshtein distance"); |
| | m.def("generate_deletion_labels", &GenerateDeletionLabel, "Generate Deletion Label"); |
| | m.def("generate_insertion_labels", &GenerateInsertionLabel, "Generate Insertion Label"); |
| | } |
| |
|