File size: 1,329 Bytes
a567fa4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 | // Copyright (c) Facebook, Inc. and its affiliates.
#pragma once
#include <torch/types.h>
namespace tensormask {

// GPU implementations of SwapAlign2Nat. They are only declared when the
// extension is built with a GPU backend (WITH_CUDA or WITH_HIP); their
// definitions are not in this header.
#if defined(WITH_CUDA) || defined(WITH_HIP)
at::Tensor SwapAlign2Nat_forward_cuda(
    const at::Tensor& X,
    const int lambda_val,
    const float pad_val);

at::Tensor SwapAlign2Nat_backward_cuda(
    const at::Tensor& gY,
    const int lambda_val,
    const int batch_size,
    const int channel,
    const int height,
    const int width);
#endif

/// @brief Forward pass of SwapAlign2Nat, dispatched on the device of @p X.
///
/// Only a GPU implementation exists: a CUDA tensor is forwarded to
/// SwapAlign2Nat_forward_cuda when the extension was built with GPU support,
/// and every other case fails with an error.
///
/// @param X          Input tensor. Its layout/shape contract is defined by the
///                   CUDA kernel, which is not visible here (TODO confirm).
/// @param lambda_val Integer op parameter, passed through unchanged to the GPU
///                   kernel (presumably the alignment window size; verify).
/// @param pad_val    Float passed through unchanged to the GPU kernel
///                   (presumably the fill value for padded positions; verify).
/// @return The tensor produced by SwapAlign2Nat_forward_cuda.
/// @throws c10::Error (via TORCH_CHECK) if @p X is not a CUDA tensor, or if the
///         extension was compiled without WITH_CUDA / WITH_HIP.
inline at::Tensor SwapAlign2Nat_forward(
    const at::Tensor& X,
    const int lambda_val,
    const float pad_val) {
  // Tensor::is_cuda() replaces the deprecated X.type().is_cuda().
  if (X.is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
    return SwapAlign2Nat_forward_cuda(X, lambda_val, pad_val);
#else
    TORCH_CHECK(false, "Not compiled with GPU support");
#endif
  }
  TORCH_CHECK(false, "Not implemented on the CPU");
}

/// @brief Backward pass of SwapAlign2Nat, dispatched on the device of @p gY.
///
/// Mirrors SwapAlign2Nat_forward: only the GPU path is implemented.
///
/// @param gY         Incoming gradient tensor (presumably the gradient w.r.t.
///                   the forward output; TODO confirm).
/// @param lambda_val Same integer op parameter that was given to the forward.
/// @param batch_size Passed through to the GPU kernel; presumably dimensions
/// @param channel    of the forward input X, used to size the returned
/// @param height     gradient (verify against SwapAlign2Nat_backward_cuda).
/// @param width
/// @return The tensor produced by SwapAlign2Nat_backward_cuda.
/// @throws c10::Error (via TORCH_CHECK) if @p gY is not a CUDA tensor, or if the
///         extension was compiled without WITH_CUDA / WITH_HIP.
inline at::Tensor SwapAlign2Nat_backward(
    const at::Tensor& gY,
    const int lambda_val,
    const int batch_size,
    const int channel,
    const int height,
    const int width) {
  // Tensor::is_cuda() replaces the deprecated gY.type().is_cuda().
  if (gY.is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
    return SwapAlign2Nat_backward_cuda(
        gY, lambda_val, batch_size, channel, height, width);
#else
    TORCH_CHECK(false, "Not compiled with GPU support");
#endif
  }
  TORCH_CHECK(false, "Not implemented on the CPU");
}

} // namespace tensormask
|