File size: 556 Bytes
c1af2fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#pragma once
#include <ATen/native/Activation.h>
#include <cstdint>

// Forward declarations. The functions declared in at::native below take
// TensorIteratorBase only by reference, so an incomplete type suffices here.
// The class-keys (struct vs. class) must match the real definitions, since
// some ABIs (e.g. MSVC) encode them in mangled names.
// NOTE(review): TensorBase is not referenced anywhere in this header;
// presumably declared for the convenience of includers — confirm before removing.
namespace at {
class TensorBase;
struct TensorIteratorBase;
}  // namespace at

namespace at::native {

// Host-side entry points for activation kernels. This header only declares
// them; the definitions live in a separate translation unit not visible here,
// so every signature below must stay exactly in sync with that definition.

/// Presumably launches the GLU backward kernel over `iter`.
///
/// @param iter       Operand iterator; taken by const reference, so it is not
///                   modified through this parameter.
/// @param gI_stride  NOTE(review): presumably the stride of the gradient-input
///                   tensor along the GLU split dimension — confirm against
///                   the definition.
/// @param I_stride   NOTE(review): presumably the corresponding stride of the
///                   input tensor — confirm against the definition.
void launch_glu_backward_kernel(const TensorIteratorBase& iter,
                                int64_t gI_stride, int64_t I_stride);

/// Presumably launches the log-sigmoid forward kernel over `iter`.
/// Unlike the GLU launcher, `iter` is taken by non-const reference.
void launch_log_sigmoid_forward_kernel(TensorIteratorBase& iter);

/// Presumably the CUDA GELU forward implementation over `it`.
/// @param approximate  GeluType selector (GeluType is presumably declared in
///                     <ATen/native/Activation.h>, included above).
void GeluCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate);

/// Presumably the CUDA GELU backward implementation over `it`.
/// @param approximate  Same GeluType selector as the forward entry point.
void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate);

}  // namespace at::native