// Host-side sum and inner-product reductions over strided float/double arrays,
// with empty device stubs for builds where WP_ENABLE_CUDA is off.
#include "warp.h"
namespace
{
// Per-element accumulation kernels used by the host reductions in this file.
// Each family shares one signature so a caller can select a kernel at runtime
// through a function pointer. The fixed-length variants ignore the trailing
// length argument because their length is the template parameter N.

// Adds val[k] into sum[k] for every k in [0, N).
template <int N, typename T> void fixed_len_sum(const T *val, T *sum, int /*value_size*/)
{
    for (int k = 0; k < N; ++k)
        sum[k] += val[k];
}

// Adds val[k] into sum[k] for every k in [0, value_size); no-op if value_size <= 0.
template <typename T> void dyn_len_sum(const T *val, T *sum, int value_size)
{
    for (int k = 0; k < value_size; ++k)
        sum[k] += val[k];
}

// Accumulates a[k] * b[k] for k in [0, N) into the single scalar *dot.
template <int N, typename T> void fixed_len_inner(const T *a, const T *b, T *dot, int /*value_size*/)
{
    for (int k = 0; k < N; ++k)
        *dot += a[k] * b[k];
}

// Accumulates a[k] * b[k] for k in [0, value_size) into the single scalar *dot.
template <typename T> void dyn_len_inner(const T *a, const T *b, T *dot, int value_size)
{
    for (int k = 0; k < value_size; ++k)
        *dot += a[k] * b[k];
}
} // namespace
// Computes the inner product of two strided arrays on the host.
//
// Element i of `ptr_a` starts byte_stride_a * i bytes past `ptr_a` (likewise
// for `ptr_b`). Each element holds `type_length` scalars of type T. The result
// is the single scalar sum, over all elements and components, of a * b. It is
// written to *ptr_out, overwriting any previous value.
//
// Both byte strides must be multiples of sizeof(T); this is checked with assert.
template <typename T>
void array_inner_host(const T *ptr_a, const T *ptr_b, T *ptr_out, int count, int byte_stride_a, int byte_stride_b,
                      int type_length)
{
    // Keep stride arithmetic signed. Mixing int with sizeof(T) (size_t) would
    // promote a negative stride to a huge unsigned value before dividing.
    const int elem_size = static_cast<int>(sizeof(T));
    assert((byte_stride_a % elem_size) == 0);
    assert((byte_stride_b % elem_size) == 0);
    const int stride_a = byte_stride_a / elem_size;
    const int stride_b = byte_stride_b / elem_size;

    // Fixed-length kernels give the common vector widths a compile-time trip
    // count; any other width uses the generic kernel.
    void (*inner_func)(const T *, const T *, T *, int);
    switch (type_length)
    {
    case 1:
        inner_func = fixed_len_inner<1, T>;
        break;
    case 2:
        inner_func = fixed_len_inner<2, T>;
        break;
    case 3:
        inner_func = fixed_len_inner<3, T>;
        break;
    case 4:
        inner_func = fixed_len_inner<4, T>;
        break;
    default:
        inner_func = dyn_len_inner<T>;
        break;
    }

    *ptr_out = T(0);
    for (int i = 0; i < count; ++i)
    {
        // Widen before multiplying. i * stride can exceed INT_MAX on large
        // arrays even though both factors fit in an int; signed overflow is UB.
        const int64_t offset_a = static_cast<int64_t>(i) * stride_a;
        const int64_t offset_b = static_cast<int64_t>(i) * stride_b;
        inner_func(ptr_a + offset_a, ptr_b + offset_b, ptr_out, type_length);
    }
}
// Sums a strided array on the host, component by component.
//
// Element i starts byte_stride * i bytes past `ptr_a` and holds `type_length`
// scalars of type T. `ptr_out` must have room for `type_length` scalars. It is
// zeroed first and then receives the per-component totals across all `count`
// elements.
//
// byte_stride must be a multiple of sizeof(T); this is checked with assert.
template <typename T> void array_sum_host(const T *ptr_a, T *ptr_out, int count, int byte_stride, int type_length)
{
    // Keep stride arithmetic signed (see the int/size_t promotion pitfall).
    const int elem_size = static_cast<int>(sizeof(T));
    assert((byte_stride % elem_size) == 0);
    const int stride = byte_stride / elem_size;

    // Fixed-length kernels for the common vector widths; generic fallback otherwise.
    void (*accumulate_func)(const T *, T *, int);
    switch (type_length)
    {
    case 1:
        accumulate_func = fixed_len_sum<1, T>;
        break;
    case 2:
        accumulate_func = fixed_len_sum<2, T>;
        break;
    case 3:
        accumulate_func = fixed_len_sum<3, T>;
        break;
    case 4:
        accumulate_func = fixed_len_sum<4, T>;
        break;
    default:
        accumulate_func = dyn_len_sum<T>;
        break;
    }

    // Zero the output with typed stores. A memset of sizeof(T) * type_length
    // bytes would wrap to an enormous size if type_length were negative.
    for (int k = 0; k < type_length; ++k)
        ptr_out[k] = T(0);

    for (int i = 0; i < count; ++i)
    {
        // Widen before multiplying to avoid signed int overflow on large arrays.
        const int64_t offset = static_cast<int64_t>(i) * stride;
        accumulate_func(ptr_a + offset, ptr_out, type_length);
    }
}
// Inner product of two float arrays whose addresses arrive as integers.
// `a`, `b` and `out` are reinterpreted as float pointers and dereferenced on
// the CPU by array_inner_host. The scalar result is written to *out.
void array_inner_float_host(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b,
                            int type_length)
{
    array_inner_host(reinterpret_cast<const float *>(a), reinterpret_cast<const float *>(b),
                     reinterpret_cast<float *>(out), count, byte_stride_a, byte_stride_b, type_length);
}
// Inner product of two double arrays whose addresses arrive as integers.
// `a`, `b` and `out` are reinterpreted as double pointers and dereferenced on
// the CPU by array_inner_host. The scalar result is written to *out.
void array_inner_double_host(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b,
                             int type_length)
{
    array_inner_host(reinterpret_cast<const double *>(a), reinterpret_cast<const double *>(b),
                     reinterpret_cast<double *>(out), count, byte_stride_a, byte_stride_b, type_length);
}
// Component-wise sum of a float array whose addresses arrive as integers.
// `a` and `out` are reinterpreted as float pointers. `out` receives
// `type_length` totals.
void array_sum_float_host(uint64_t a, uint64_t out, int count, int byte_stride_a, int type_length)
{
    array_sum_host(reinterpret_cast<const float *>(a), reinterpret_cast<float *>(out), count, byte_stride_a,
                   type_length);
}
// Component-wise sum of a double array whose addresses arrive as integers.
// `a` and `out` are reinterpreted as double pointers. `out` receives
// `type_length` totals.
void array_sum_double_host(uint64_t a, uint64_t out, int count, int byte_stride_a, int type_length)
{
    array_sum_host(reinterpret_cast<const double *>(a), reinterpret_cast<double *>(out), count, byte_stride_a,
                   type_length);
}
#if !WP_ENABLE_CUDA
// Builds without CUDA get empty definitions of the device-side reductions.
// Each stub returns immediately: it neither reads the inputs nor writes `out`,
// and it reports no error to the caller.
void array_inner_float_device(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b,
                              int type_length)
{
}

void array_inner_double_device(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b,
                               int type_length)
{
}

void array_sum_float_device(uint64_t a, uint64_t out, int count, int byte_stride_a, int type_length)
{
}

void array_sum_double_device(uint64_t a, uint64_t out, int count, int byte_stride_a, int type_length)
{
}
#endif