|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include <torch/types.h> |
|
|
#include "edit_dist.h" |
|
|
|
|
|
// Compatibility shim: when the included torch headers do not provide
// TORCH_CHECK, alias it to AT_CHECK so the validation macros below can use a
// single spelling.
#if !defined(TORCH_CHECK)
#define TORCH_CHECK AT_CHECK
#endif  // !defined(TORCH_CHECK)
|
|
|
|
|
// Input-validation helpers for the bindings in this file.
//
// CHECK_CUDA(x)       - fails via TORCH_CHECK unless `x.is_cuda()` is true.
// CHECK_CONTIGUOUS(x) - fails via TORCH_CHECK unless `x.is_contiguous()` is
//                       true.
// CHECK_INPUT(x)      - applies both checks, CUDA first.
//
// In every case the failure message is the spelling of the argument followed
// by a fixed suffix (e.g. "source must be a CUDA tensor").
//
// `is_cuda()` is queried on the tensor directly rather than through the
// deprecated `type()` accessor. The argument is parenthesised so expression
// arguments bind as intended.
#define CHECK_CUDA(x) \
  TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")

#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")

// Wrapped in do { ... } while (0) so the two checks behave as one statement;
// without it, `if (cond) CHECK_INPUT(t);` would guard only the first check.
#define CHECK_INPUT(x)   \
  do {                   \
    CHECK_CUDA(x);       \
    CHECK_CONTIGUOUS(x); \
  } while (0)
|
|
|
|
|
// Python-facing entry point for the Levenshtein distance kernel.
//
// Each of the four arguments is validated with CHECK_INPUT (must be a
// contiguous CUDA tensor); the first failing check reports through
// TORCH_CHECK. On success the arguments are forwarded unchanged, in the same
// order, to LevenshteinDistanceCuda and its result is returned as-is.
torch::Tensor LevenshteinDistance(
    torch::Tensor source,
    torch::Tensor target,
    torch::Tensor source_length,
    torch::Tensor target_length) {
  // Validate in argument order so the first offending tensor is the one
  // named in the error message.
  CHECK_INPUT(source);
  CHECK_INPUT(target);
  CHECK_INPUT(source_length);
  CHECK_INPUT(target_length);

  torch::Tensor result =
      LevenshteinDistanceCuda(source, target, source_length, target_length);
  return result;
}
|
|
|
|
|
// Python-facing entry point for deletion-label generation.
//
// Both `source` and `operations` are validated with CHECK_INPUT (must be
// contiguous CUDA tensors), then handed unchanged to
// GenerateDeletionLabelCuda, whose result is returned as-is.
torch::Tensor GenerateDeletionLabel(
    torch::Tensor source,
    torch::Tensor operations) {
  CHECK_INPUT(source);
  CHECK_INPUT(operations);

  torch::Tensor labels = GenerateDeletionLabelCuda(source, operations);
  return labels;
}
|
|
|
|
|
// Python-facing entry point for insertion-label generation.
//
// Both `target` and `operations` are validated with CHECK_INPUT (must be
// contiguous CUDA tensors), then handed unchanged to
// GenerateInsertionLabelCuda. The pair of tensors it produces is returned
// as-is.
std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabel(
    torch::Tensor target,
    torch::Tensor operations) {
  CHECK_INPUT(target);
  CHECK_INPUT(operations);

  std::pair<torch::Tensor, torch::Tensor> labels =
      GenerateInsertionLabelCuda(target, operations);
  return labels;
}
|
|
|
|
|
// Python module definition: exposes the three validating wrappers above
// under snake_case names, each with a short docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // levenshtein_distance -> LevenshteinDistance
  m.def("levenshtein_distance", &LevenshteinDistance, "Levenshtein distance");

  // generate_deletion_labels -> GenerateDeletionLabel
  m.def("generate_deletion_labels", &GenerateDeletionLabel,
        "Generate Deletion Label");

  // generate_insertion_labels -> GenerateInsertionLabel
  m.def("generate_insertion_labels", &GenerateInsertionLabel,
        "Generate Insertion Label");
}
|
|
|