File size: 815 Bytes
b024d42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include <algorithm>
#include <stdexcept>
#include <vector>

#include "gtest/gtest.h"
#include "test_softmax_util.hpp"

template <ck::index_t N>
using I = ck::Number<N>;
#ifdef CK_ENABLE_FP16
using F16 = ck::half_t;
#endif
using F32 = float;

template <typename Tuple>
class TestSoftmax : public ck::TestSoftmax<Tuple>
{
};

// clang-format off
using KernelTypes = ::testing::Types<
    //         InDataType, AccDataType, OutDataType, Rank
#ifdef CK_ENABLE_FP16
    std::tuple<       F16,         F32,         F16,    I<4>>,
#endif
    std::tuple<       F32,         F32,         F32,    I<4>>
    >;
// clang-format on

TYPED_TEST_SUITE(TestSoftmax, KernelTypes);

#include "test_softmax_ut_cases.inc"