| |
| |
|
|
| #include <vector> |
| #include <iostream> |
| #include <numeric> |
| #include <cassert> |
| #include <cstdlib> |
| #include <iostream> |
| #include <time.h> |
| #include <unordered_set> |
|
|
| #include "ck_tile/core.hpp" |
|
|
| #ifndef TEST_SCATTER_GATHER_VERBOSE |
| #define TEST_SCATTER_GATHER_VERBOSE 1 |
| #endif |
|
|
| #define HIP_CALL(call) \ |
| do \ |
| { \ |
| hipError_t err = call; \ |
| if(err != hipSuccess) \ |
| { \ |
| printf("[hiperror](%d) fail to call %s", static_cast<int>(err), #call); \ |
| exit(0); \ |
| } \ |
| } while(0) |
|
|
| |
| |
| |
| |
| |
| template <ck_tile::index_t ROW_TILE_SIZE = 8, |
| ck_tile::index_t COL_TILE_SIZE = 32 * 8, |
| ck_tile::index_t BLOCK_SIZE = 256, |
| ck_tile::index_t ALIGNMENT = 8, |
| typename INDEX_BUF_TYPE = ck_tile::index_t, |
| typename DATA_TYPE = ck_tile::fp16_t> |
| __global__ void row_scatter_gather(const INDEX_BUF_TYPE* src_row_idx_ptr, |
| const INDEX_BUF_TYPE* dst_row_idx_ptr, |
| const DATA_TYPE* src_ptr, |
| DATA_TYPE* dst_ptr, |
| ck_tile::index_t n_row_total, |
| ck_tile::index_t , |
| ck_tile::index_t n_cols) |
| { |
| using namespace ck_tile; |
|
|
| |
| constexpr index_t vec = ALIGNMENT; |
| static_assert(COL_TILE_SIZE % vec == 0); |
| constexpr index_t col_lanes = COL_TILE_SIZE / vec; |
| constexpr index_t warp_size = ck_tile::get_warp_size(); |
| static_assert(warp_size % col_lanes == 0); |
| constexpr index_t row_lanes = warp_size / col_lanes; |
| constexpr index_t num_warps = BLOCK_SIZE / warp_size; |
| static_assert(ROW_TILE_SIZE % (num_warps * row_lanes) == 0); |
| constexpr index_t row_repeat = ROW_TILE_SIZE / (num_warps * row_lanes); |
| static_assert( |
| row_repeat == 1, |
| "currently indexing not support(and would be not performant) if row_repeat has more"); |
|
|
| |
| index_t tile_col_idx = 0; |
| index_t tile_row_idx = blockIdx.x * ROW_TILE_SIZE; |
|
|
| |
| constexpr auto src_dist = make_static_tile_distribution( |
| tile_distribution_encoding< |
| sequence<1>, |
| tuple<sequence<row_repeat, num_warps, row_lanes>, sequence<col_lanes, vec>>, |
| tuple<sequence<1>, sequence<1, 2>>, |
| tuple<sequence<1>, sequence<2, 0>>, |
| sequence<1, 2>, |
| sequence<0, 1>>{}); |
| const auto coord = src_dist.calculate_index(); |
| const auto row_coord = coord[number<0>{}] + tile_row_idx; |
|
|
| |
| INDEX_BUF_TYPE src_row_id = src_row_idx_ptr[row_coord]; |
| INDEX_BUF_TYPE dst_row_id = dst_row_idx_ptr[row_coord]; |
|
|
| |
| |
|
|
| const auto src_view = |
| make_naive_tensor_view<address_space_enum::global>(src_ptr, |
| make_tuple(n_row_total, n_cols), |
| make_tuple(n_cols, 1), |
| number<vec>{}, |
| number<1>{}); |
|
|
| const auto src_gather_view = transform_tensor_view( |
| src_view, |
| make_tuple(make_indexing_transform( |
| n_row_total, |
| src_row_id), |
| make_pass_through_transform(n_cols)), |
| make_tuple(sequence<0>{}, sequence<1>{}), |
| make_tuple(sequence<0>{}, sequence<1>{})); |
|
|
| auto src_tile = make_tile_window(src_gather_view, |
| make_tuple(number<ROW_TILE_SIZE>{}, number<COL_TILE_SIZE>{}), |
| {tile_row_idx, tile_col_idx}, |
| src_dist); |
|
|
| const auto dst_view = |
| make_naive_tensor_view<address_space_enum::global>(dst_ptr, |
| make_tuple(n_row_total, n_cols), |
| make_tuple(n_cols, 1), |
| number<vec>{}, |
| number<1>{}); |
|
|
| const auto dst_scatter_view = transform_tensor_view( |
| dst_view, |
| make_tuple(make_indexing_transform( |
| n_row_total, |
| dst_row_id), |
| make_pass_through_transform(n_cols)), |
| make_tuple(sequence<0>{}, sequence<1>{}), |
| make_tuple(sequence<0>{}, sequence<1>{})); |
|
|
| auto dst_tile = make_tile_window(dst_scatter_view, |
| make_tuple(number<ROW_TILE_SIZE>{}, number<COL_TILE_SIZE>{}), |
| {tile_row_idx, tile_col_idx}, |
| src_dist ); |
|
|
| |
| for(auto i = 0; i < n_cols; i += COL_TILE_SIZE) |
| { |
| |
| |
| auto data = load_tile(src_tile); |
| store_tile(dst_tile, data); |
|
|
| move_tile_window(src_tile, {number<0>{}, number<COL_TILE_SIZE>{}}); |
| move_tile_window(dst_tile, {number<0>{}, number<COL_TILE_SIZE>{}}); |
| } |
| } |
|
|
| union pixel |
| { |
| struct __attribute__((packed)) |
| { |
| unsigned int r : 6; |
| unsigned int c : 10; |
| }; |
| ushort data; |
| }; |
|
|
| struct unique_linear_rand |
| { |
| unique_linear_rand(int capacity_) : capacity(capacity_) {} |
| std::unordered_set<int> set; |
| int gen() |
| { |
| if(static_cast<int>(set.size()) >= capacity) |
| { |
| printf("overflow, but will give you an number as well\n"); |
| return std::rand() % capacity; |
| } |
| while(1) |
| { |
| int r = std::rand() % capacity; |
| if(set.count(r) == 1) |
| { |
| continue; |
| } |
| set.insert(r); |
| return r; |
| } |
| } |
|
|
| int capacity; |
| }; |
|
|
| int main() |
| { |
| int row_total = 64; |
| int row_select = 8 * 2; |
| int col = 256 * 2; |
| using fp16_t = ck_tile::fp16_t; |
|
|
| constexpr int row_tile = 8; |
| constexpr int col_tile = 256; |
|
|
| fp16_t* src = reinterpret_cast<fp16_t*>(malloc(row_total * col * sizeof(fp16_t))); |
| for(int i_r = 0; i_r < row_total; i_r++) |
| { |
| for(int i_c = 0; i_c < col; i_c++) |
| { |
| int i = i_r * col + i_c; |
| pixel p; |
| p.r = i_r; |
| p.c = i_c; |
| ushort d = p.data; |
| src[i] = ck_tile::bit_cast<fp16_t>(d); |
| } |
| } |
|
|
| fp16_t* dst = reinterpret_cast<fp16_t*>(malloc(row_total * col * sizeof(fp16_t))); |
| int* src_idx = reinterpret_cast<int*>(malloc(row_select * sizeof(int))); |
| int* dst_idx = reinterpret_cast<int*>(malloc(row_select * sizeof(int))); |
| |
| |
| std::srand(std::time(nullptr)); |
| auto src_gen = unique_linear_rand(row_total); |
| auto dst_gen = unique_linear_rand(row_total); |
| for(int i_r = 0; i_r < row_select; i_r++) |
| { |
| src_idx[i_r] = src_gen.gen(); |
| dst_idx[i_r] = dst_gen.gen(); |
| } |
|
|
| void* dev_src; |
| void* dev_dst; |
| void* dev_src_idx; |
| void* dev_dst_idx; |
| HIP_CALL(hipMalloc(&dev_src, row_total * col * sizeof(fp16_t))); |
| HIP_CALL(hipMalloc(&dev_dst, row_total * col * sizeof(fp16_t))); |
| HIP_CALL(hipMalloc(&dev_src_idx, row_select * sizeof(int))); |
| HIP_CALL(hipMalloc(&dev_dst_idx, row_select * sizeof(int))); |
|
|
| HIP_CALL(hipMemcpy(dev_src, src, row_total * col * sizeof(fp16_t), hipMemcpyHostToDevice)); |
| HIP_CALL(hipMemcpy(dev_src_idx, src_idx, row_select * sizeof(int), hipMemcpyHostToDevice)); |
| HIP_CALL(hipMemcpy(dev_dst_idx, dst_idx, row_select * sizeof(int), hipMemcpyHostToDevice)); |
|
|
| constexpr int bdim = 256; |
| int gdim = (row_select + row_tile - 1) / row_tile; |
| row_scatter_gather<row_tile, col_tile><<<gdim, bdim>>>(reinterpret_cast<int*>(dev_src_idx), |
| reinterpret_cast<int*>(dev_dst_idx), |
| reinterpret_cast<fp16_t*>(dev_src), |
| reinterpret_cast<fp16_t*>(dev_dst), |
| row_total, |
| row_select, |
| col); |
|
|
| HIP_CALL(hipMemcpy(dst, dev_dst, row_total * col * sizeof(fp16_t), hipMemcpyDeviceToHost)); |
|
|
| #if TEST_SCATTER_GATHER_VERBOSE |
| printf("select row:"); |
| for(int i_r = 0; i_r < row_select; i_r++) |
| { |
| printf("%d->%d->%d ", i_r, src_idx[i_r], dst_idx[i_r]); |
| } |
| printf("\n"); |
| #endif |
|
|
| int err_cnt = 0; |
| for(int i_r = 0; i_r < row_select; i_r++) |
| { |
| for(int i_c = 0; i_c < col; i_c++) |
| { |
| int i = dst_idx[i_r] * col + i_c; |
| pixel p = ck_tile::bit_cast<pixel>(dst[i]); |
| bool is_ok = p.r == src_idx[i_r] && p.c == i_c; |
| if(!is_ok) |
| { |
| if(i_c == 0) |
| printf("(%d)pixel: %dx%d -> %d\n", i_r, p.r, p.c, dst_idx[i_r]); |
| err_cnt++; |
| } |
| } |
| } |
| #if TEST_SCATTER_GATHER_VERBOSE |
| printf("err:%d\n", err_cnt); |
| #endif |
|
|
| free(src); |
| free(dst); |
| free(src_idx); |
| free(dst_idx); |
| return err_cnt == 0 ? 0 : -1; |
| } |
|
|