Spaces:
Sleeping
Sleeping
| namespace | |
| { | |
| // Specialized accumulation functions for common type sizes | |
| template <int N, typename T> void fixed_len_sum(const T *val, T *sum, int value_size) | |
| { | |
| for (int i = 0; i < N; ++i, ++val, ++sum) | |
| { | |
| *sum += *val; | |
| } | |
| } | |
| template <typename T> void dyn_len_sum(const T *val, T *sum, int value_size) | |
| { | |
| for (int i = 0; i < value_size; ++i, ++val, ++sum) | |
| { | |
| *sum += *val; | |
| } | |
| } | |
| template <int N, typename T> void fixed_len_inner(const T *a, const T *b, T *dot, int value_size) | |
| { | |
| for (int i = 0; i < N; ++i, ++a, ++b) | |
| { | |
| *dot += *a * *b; | |
| } | |
| } | |
| template <typename T> void dyn_len_inner(const T *a, const T *b, T *dot, int value_size) | |
| { | |
| for (int i = 0; i < value_size; ++i, ++a, ++b) | |
| { | |
| *dot += *a * *b; | |
| } | |
| } | |
| } // namespace | |
| template <typename T> | |
| void array_inner_host(const T *ptr_a, const T *ptr_b, T *ptr_out, int count, int byte_stride_a, int byte_stride_b, | |
| int type_length) | |
| { | |
| assert((byte_stride_a % sizeof(T)) == 0); | |
| assert((byte_stride_b % sizeof(T)) == 0); | |
| const int stride_a = byte_stride_a / sizeof(T); | |
| const int stride_b = byte_stride_b / sizeof(T); | |
| void (*inner_func)(const T *, const T *, T *, int); | |
| switch (type_length) | |
| { | |
| case 1: | |
| inner_func = fixed_len_inner<1, T>; | |
| break; | |
| case 2: | |
| inner_func = fixed_len_inner<2, T>; | |
| break; | |
| case 3: | |
| inner_func = fixed_len_inner<3, T>; | |
| break; | |
| case 4: | |
| inner_func = fixed_len_inner<4, T>; | |
| break; | |
| default: | |
| inner_func = dyn_len_inner<T>; | |
| } | |
| *ptr_out = 0.0f; | |
| for (int i = 0; i < count; ++i) | |
| { | |
| inner_func(ptr_a + i * stride_a, ptr_b + i * stride_b, ptr_out, type_length); | |
| } | |
| } | |
| template <typename T> void array_sum_host(const T *ptr_a, T *ptr_out, int count, int byte_stride, int type_length) | |
| { | |
| assert((byte_stride % sizeof(T)) == 0); | |
| const int stride = byte_stride / sizeof(T); | |
| void (*accumulate_func)(const T *, T *, int); | |
| switch (type_length) | |
| { | |
| case 1: | |
| accumulate_func = fixed_len_sum<1, T>; | |
| break; | |
| case 2: | |
| accumulate_func = fixed_len_sum<2, T>; | |
| break; | |
| case 3: | |
| accumulate_func = fixed_len_sum<3, T>; | |
| break; | |
| case 4: | |
| accumulate_func = fixed_len_sum<4, T>; | |
| break; | |
| default: | |
| accumulate_func = dyn_len_sum<T>; | |
| } | |
| memset(ptr_out, 0, sizeof(T)*type_length); | |
| for (int i = 0; i < count; ++i) | |
| accumulate_func(ptr_a + i * stride, ptr_out, type_length); | |
| } | |
| void array_inner_float_host(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b, | |
| int type_length) | |
| { | |
| const float *ptr_a = (const float *)(a); | |
| const float *ptr_b = (const float *)(b); | |
| float *ptr_out = (float *)(out); | |
| array_inner_host(ptr_a, ptr_b, ptr_out, count, byte_stride_a, byte_stride_b, type_length); | |
| } | |
| void array_inner_double_host(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b, | |
| int type_length) | |
| { | |
| const double *ptr_a = (const double *)(a); | |
| const double *ptr_b = (const double *)(b); | |
| double *ptr_out = (double *)(out); | |
| array_inner_host(ptr_a, ptr_b, ptr_out, count, byte_stride_a, byte_stride_b, type_length); | |
| } | |
| void array_sum_float_host(uint64_t a, uint64_t out, int count, int byte_stride_a, int type_length) | |
| { | |
| const float *ptr_a = (const float *)(a); | |
| float *ptr_out = (float *)(out); | |
| array_sum_host(ptr_a, ptr_out, count, byte_stride_a, type_length); | |
| } | |
| void array_sum_double_host(uint64_t a, uint64_t out, int count, int byte_stride_a, int type_length) | |
| { | |
| const double *ptr_a = (const double *)(a); | |
| double *ptr_out = (double *)(out); | |
| array_sum_host(ptr_a, ptr_out, count, byte_stride_a, type_length); | |
| } | |
| void array_inner_float_device(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b, | |
| int type_length) | |
| { | |
| } | |
| void array_inner_double_device(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b, | |
| int type_length) | |
| { | |
| } | |
| void array_sum_float_device(uint64_t a, uint64_t out, int count, int byte_stride_a, int type_length) | |
| { | |
| } | |
| void array_sum_double_device(uint64_t a, uint64_t out, int count, int byte_stride_a, int type_length) | |
| { | |
| } | |