#include "warp.h" namespace { // Specialized accumulation functions for common type sizes template void fixed_len_sum(const T *val, T *sum, int value_size) { for (int i = 0; i < N; ++i, ++val, ++sum) { *sum += *val; } } template void dyn_len_sum(const T *val, T *sum, int value_size) { for (int i = 0; i < value_size; ++i, ++val, ++sum) { *sum += *val; } } template void fixed_len_inner(const T *a, const T *b, T *dot, int value_size) { for (int i = 0; i < N; ++i, ++a, ++b) { *dot += *a * *b; } } template void dyn_len_inner(const T *a, const T *b, T *dot, int value_size) { for (int i = 0; i < value_size; ++i, ++a, ++b) { *dot += *a * *b; } } } // namespace template void array_inner_host(const T *ptr_a, const T *ptr_b, T *ptr_out, int count, int byte_stride_a, int byte_stride_b, int type_length) { assert((byte_stride_a % sizeof(T)) == 0); assert((byte_stride_b % sizeof(T)) == 0); const int stride_a = byte_stride_a / sizeof(T); const int stride_b = byte_stride_b / sizeof(T); void (*inner_func)(const T *, const T *, T *, int); switch (type_length) { case 1: inner_func = fixed_len_inner<1, T>; break; case 2: inner_func = fixed_len_inner<2, T>; break; case 3: inner_func = fixed_len_inner<3, T>; break; case 4: inner_func = fixed_len_inner<4, T>; break; default: inner_func = dyn_len_inner; } *ptr_out = 0.0f; for (int i = 0; i < count; ++i) { inner_func(ptr_a + i * stride_a, ptr_b + i * stride_b, ptr_out, type_length); } } template void array_sum_host(const T *ptr_a, T *ptr_out, int count, int byte_stride, int type_length) { assert((byte_stride % sizeof(T)) == 0); const int stride = byte_stride / sizeof(T); void (*accumulate_func)(const T *, T *, int); switch (type_length) { case 1: accumulate_func = fixed_len_sum<1, T>; break; case 2: accumulate_func = fixed_len_sum<2, T>; break; case 3: accumulate_func = fixed_len_sum<3, T>; break; case 4: accumulate_func = fixed_len_sum<4, T>; break; default: accumulate_func = dyn_len_sum; } memset(ptr_out, 0, sizeof(T)*type_length); for (int i = 0; i < count; ++i) accumulate_func(ptr_a + i * stride, ptr_out, type_length); } void array_inner_float_host(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b, int type_length) { const float *ptr_a = (const float *)(a); const float *ptr_b = (const float *)(b); float *ptr_out = (float *)(out); array_inner_host(ptr_a, ptr_b, ptr_out, count, byte_stride_a, byte_stride_b, type_length); } void array_inner_double_host(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b, int type_length) { const double *ptr_a = (const double *)(a); const double *ptr_b = (const double *)(b); double *ptr_out = (double *)(out); array_inner_host(ptr_a, ptr_b, ptr_out, count, byte_stride_a, byte_stride_b, type_length); } void array_sum_float_host(uint64_t a, uint64_t out, int count, int byte_stride_a, int type_length) { const float *ptr_a = (const float *)(a); float *ptr_out = (float *)(out); array_sum_host(ptr_a, ptr_out, count, byte_stride_a, type_length); } void array_sum_double_host(uint64_t a, uint64_t out, int count, int byte_stride_a, int type_length) { const double *ptr_a = (const double *)(a); double *ptr_out = (double *)(out); array_sum_host(ptr_a, ptr_out, count, byte_stride_a, type_length); } #if !WP_ENABLE_CUDA void array_inner_float_device(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b, int type_length) { } void array_inner_double_device(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b, int type_length) { } void array_sum_float_device(uint64_t a, uint64_t out, int count, int byte_stride_a, int type_length) { } void array_sum_double_device(uint64_t a, uint64_t out, int count, int byte_stride_a, int type_length) { } #endif