| |
|
|
| #include <torch/torch.h> |
|
|
| |
| |
| torch::Tensor batch_mm( |
| torch::Tensor x, |
| torch::Tensor weights, |
| torch::Tensor batch_sizes, |
| torch::Tensor output, |
| bool trans_b) { |
| |
| TORCH_CHECK(x.is_cuda(), "x must be on CUDA"); |
| TORCH_CHECK(weights.is_cuda(), "weights must be on CUDA"); |
| TORCH_CHECK(batch_sizes.is_cuda(), "batch_sizes must be on CUDA"); |
|
|
| TORCH_CHECK(x.ndimension() == 3, "x must be 3D tensor"); |
| TORCH_CHECK(weights.ndimension() == 3, |
| "weights must be 3D tensor"); |
| TORCH_CHECK(batch_sizes.ndimension() == 1, |
| "batch_sizes must be 1D tensor"); |
|
|
| TORCH_CHECK(x.size(0) == weights.size(0) && x.size(0) == batch_sizes.size(0)); |
| TORCH_CHECK(x.size(2) == weights.size(1)); |
|
|
| |
| |
| torch::bmm_out(output, x, weights); |
| return output; |
| } |
|
|