// File size: 1,103 Bytes
// 9dd3461
#pragma once
#include <ATen/native/Activation.h>
#include <cstdint>

// Forward declarations only. Every declaration in this header takes these
// types by reference, so their complete definitions are not required here.
// NOTE(review): presumably this keeps the heavy TensorIterator / TensorBase
// headers out of every includer; confirm intent before adding a full include.
// The class-keys (`class` vs `struct`) must match the real definitions.
namespace at {
class TensorBase;
struct TensorIteratorBase;
}  // namespace at

namespace at { namespace native {

// Declarations of the CUDA activation kernel launchers. Only the declarations
// are here; the definitions live in another translation unit (presumably the
// matching .cu file). These signatures must stay in sync with those
// definitions or linking will fail.

// ---- GELU ------------------------------------------------------------------
// `approximate` selects the GeluType variant. GeluType is presumably provided
// by ATen/native/Activation.h, the only project include in this header.
void GeluCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate);
void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate);

// ---- GLU -------------------------------------------------------------------
// Unlike the other launchers, this one takes the iterator by const reference.
// NOTE(review): gI_stride / I_stride are presumably the grad-input and input
// strides along the GLU split dimension. TODO confirm against the definition.
void launch_glu_backward_kernel(const TensorIteratorBase& iter,
                                int64_t gI_stride, int64_t I_stride);

// ---- LogSigmoid ------------------------------------------------------------
void launch_log_sigmoid_forward_kernel(TensorIteratorBase& iter);

// ---- PReLU -----------------------------------------------------------------
// Forward and backward each come in two variants: "share_weights" (driven by a
// TensorIteratorBase plus the weight) and "multi_weights" (operating on explicit
// tensors). NOTE(review): presumably "share" means a single weight shared by all
// channels and "multi" means one weight per channel; confirm in the definitions.
void launch_prelu_cuda_kernel_share_weights(
    TensorIteratorBase& iter, const TensorBase& weight);
void launch_prelu_cuda_kernel_multi_weights(
    const TensorBase& result, const TensorBase& input, const TensorBase& weight);

void launch_prelu_cuda_backward_kernel_share_weights(
    TensorIteratorBase& iter, const TensorBase& weight);
void launch_prelu_cuda_backward_kernel_multi_weights(
    const TensorBase& input, const TensorBase& weight, const TensorBase& grad_out,
    const TensorBase& input_grad, const TensorBase& weight_grad_collector);

}}  // namespace at::native