// Binding code for the CUDA no-repeat-ngram blocking kernel: exposes the
// kernel to Python as a PyTorch C++ extension.

#include <torch/extension.h>

#include <cassert>
#include <vector>

// Forward declaration of the CUDA entry point (defined in the accompanying
// CUDA source file).
torch::Tensor ngram_repeat_block_cuda_forward(torch::Tensor tokens,
                                              torch::Tensor lprobs, int bsz,
                                              int step, int beam_size,
                                              int no_repeat_ngram_size);
|
| | #define CHECK_CUDA(x) \ |
| | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") |
| | #define CHECK_CONTIGUOUS(x) \ |
| | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") |
| | #define CHECK_INPUT(x) \ |
| | CHECK_CUDA(x); \ |
| | CHECK_CONTIGUOUS(x) |
| |
|
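
// For illustration, CHECK_INPUT(tokens) expands to:
//   TORCH_CHECK(tokens.is_cuda(), "tokens" " must be a CUDA tensor");
//   TORCH_CHECK(tokens.is_contiguous(), "tokens" " must be contiguous");
// (the adjacent string literals are concatenated by the compiler).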

// Validates the inputs, then dispatches to the CUDA kernel that applies
// no-repeat-ngram blocking to `lprobs`.
torch::Tensor ngram_repeat_block_forward(torch::Tensor tokens,
                                         torch::Tensor lprobs, int bsz,
                                         int step, int beam_size,
                                         int no_repeat_ngram_size) {
  CHECK_INPUT(tokens);
  CHECK_INPUT(lprobs);
  // Basic sanity checks (compiled out when NDEBUG is defined).
  assert(bsz > 0);
  assert(step >= 0);
  assert(beam_size > 0);
  assert(no_repeat_ngram_size > 0);

  return ngram_repeat_block_cuda_forward(tokens, lprobs, bsz, step, beam_size,
                                         no_repeat_ngram_size);
}

// Python bindings: exposes the wrapper above as `forward` in the compiled
// extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &ngram_repeat_block_forward,
        "No Repeat Ngram Block forward (CUDA)");
}
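
// A minimal Python usage sketch. Assumptions: the source file names and the
// tensor shapes below are illustrative only; they are not fixed by this file.
//
//   from torch.utils.cpp_extension import load
//
//   module = load(
//       name="ngram_repeat_block",
//       sources=["ngram_repeat_block.cpp", "ngram_repeat_block_kernel.cu"],
//   )
//   # tokens: (bsz * beam_size, seq_len) int64 CUDA tensor of token ids
//   # lprobs: (bsz * beam_size, vocab_size) float CUDA tensor of log-probs
//   lprobs = module.forward(tokens, lprobs, bsz, step, beam_size, 3)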