#include "warp.h"

#include <cassert> // assert
#include <cstdint> // uint64_t
#include <cstring> // memset

namespace
{

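// Specialized accumulation functions for common vector lengths. The fixed_len_*
// variants take the component count N as a template argument, so the loop bound
// is a compile-time constant; the dyn_len_* variants fall back to the runtime
// value_size. All variants share one signature (value_size is ignored by the
// fixed-length versions) so the callers below can select one through a function
// pointer.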
template <int N, typename T> void fixed_len_sum(const T *val, T *sum, int value_size)
{
    for (int i = 0; i < N; ++i, ++val, ++sum)
    {
        *sum += *val;
    }
}

template <typename T> void dyn_len_sum(const T *val, T *sum, int value_size)
{
    for (int i = 0; i < value_size; ++i, ++val, ++sum)
    {
        *sum += *val;
    }
}

template <int N, typename T> void fixed_len_inner(const T *a, const T *b, T *dot, int value_size)
{
    for (int i = 0; i < N; ++i, ++a, ++b)
    {
        *dot += *a * *b;
    }
}

template <typename T> void dyn_len_inner(const T *a, const T *b, T *dot, int value_size)
{
    for (int i = 0; i < value_size; ++i, ++a, ++b)
    {
        *dot += *a * *b;
    }
}

} // namespace

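// Computes the inner product of two strided arrays whose elements each hold
// type_length scalars of type T, writing a single scalar to *ptr_out:
//   *ptr_out = sum_i sum_j a[i][j] * b[i][j]
// For example, count = 2 and type_length = 3 yields dot(a0, b0) + dot(a1, b1),
// i.e. two vec3 dot products added together. Byte strides must be multiples of
// sizeof(T).
template <typename T>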
void array_inner_host(const T *ptr_a, const T *ptr_b, T *ptr_out, int count, int byte_stride_a, int byte_stride_b,
                      int type_length)
{
    assert((byte_stride_a % sizeof(T)) == 0);
    assert((byte_stride_b % sizeof(T)) == 0);
    const int stride_a = byte_stride_a / sizeof(T);
    const int stride_b = byte_stride_b / sizeof(T);

    void (*inner_func)(const T *, const T *, T *, int);
    switch (type_length)
    {
    case 1:
        inner_func = fixed_len_inner<1, T>;
        break;
    case 2:
        inner_func = fixed_len_inner<2, T>;
        break;
    case 3:
        inner_func = fixed_len_inner<3, T>;
        break;
    case 4:
        inner_func = fixed_len_inner<4, T>;
        break;
    default:
        inner_func = dyn_len_inner<T>;
    }

    // The result is a single scalar regardless of type_length.
    *ptr_out = T(0);
    for (int i = 0; i < count; ++i)
    {
        inner_func(ptr_a + i * stride_a, ptr_b + i * stride_b, ptr_out, type_length);
    }
}

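// Sums a strided array of type_length-component elements component-wise,
// writing type_length scalars to ptr_out: ptr_out[j] = sum_i a[i][j].
template <typename T> void array_sum_host(const T *ptr_a, T *ptr_out, int count, int byte_stride, int type_length)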
{
    assert((byte_stride % sizeof(T)) == 0);
    const int stride = byte_stride / sizeof(T);

    void (*accumulate_func)(const T *, T *, int);
    switch (type_length)
    {
    case 1:
        accumulate_func = fixed_len_sum<1, T>;
        break;
    case 2:
        accumulate_func = fixed_len_sum<2, T>;
        break;
    case 3:
        accumulate_func = fixed_len_sum<3, T>;
        break;
    case 4:
        accumulate_func = fixed_len_sum<4, T>;
        break;
    default:
        accumulate_func = dyn_len_sum<T>;
    }

    // Zero every output component (all-bits-zero is 0.0 for IEEE 754 float and double).
    memset(ptr_out, 0, sizeof(T) * type_length);
    for (int i = 0; i < count; ++i)
    {
        accumulate_func(ptr_a + i * stride, ptr_out, type_length);
    }
}

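// Type-erased host entry points: array addresses arrive as uint64_t and are cast
// back to typed pointers before dispatching to the templated implementations.
void array_inner_float_host(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b,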
                            int type_length)
{
    const float *ptr_a = (const float *)(a);
    const float *ptr_b = (const float *)(b);
    float *ptr_out = (float *)(out);

    array_inner_host(ptr_a, ptr_b, ptr_out, count, byte_stride_a, byte_stride_b, type_length);
}

void array_inner_double_host(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b,
                             int type_length)
{
    const double *ptr_a = (const double *)(a);
    const double *ptr_b = (const double *)(b);
    double *ptr_out = (double *)(out);

    array_inner_host(ptr_a, ptr_b, ptr_out, count, byte_stride_a, byte_stride_b, type_length);
}

void array_sum_float_host(uint64_t a, uint64_t out, int count, int byte_stride_a, int type_length)
{
    const float *ptr_a = (const float *)(a);
    float *ptr_out = (float *)(out);
    array_sum_host(ptr_a, ptr_out, count, byte_stride_a, type_length);
}

void array_sum_double_host(uint64_t a, uint64_t out, int count, int byte_stride_a, int type_length)
{
    const double *ptr_a = (const double *)(a);
    double *ptr_out = (double *)(out);
    array_sum_host(ptr_a, ptr_out, count, byte_stride_a, type_length);
}

// CUDA is disabled in this build, so the device entry points are empty stubs.
#if !WP_ENABLE_CUDA
void array_inner_float_device(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b,
                              int type_length)
{
}

void array_inner_double_device(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b,
                               int type_length)
{
}

void array_sum_float_device(uint64_t a, uint64_t out, int count, int byte_stride_a, int type_length)
{
}

void array_sum_double_device(uint64_t a, uint64_t out, int count, int byte_stride_a, int type_length)
{
}
#endif