| #include "cutlass_preprocessors.h" |
| #include "cuda_utils.h" |
| #include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" |
|
|
| #include <vector> |
|
|
| namespace fastertransformer { |
|
|
| int get_bits_in_quant_type(QuantType quant_type) { |
| switch (quant_type) { |
| case QuantType::INT8_WEIGHT_ONLY: |
| return 8; |
| case QuantType::PACKED_INT4_WEIGHT_ONLY: |
| return 4; |
| default: |
| return -1; |
| } |
| } |
|
|
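
// Describes how the mixed-type GEMM kernels expect the quantized B (weight)
// matrix to be laid out in memory: the base layout, its tiling/interleaving
// parameters, and whether the arch uses the interleaved-LDSM MMA path (which
// additionally requires a row permutation of B).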
struct LayoutDetails {
    enum class Layout {
        UNKNOWN,
        ROW_MAJOR,
        COLUMN_MAJOR
    };

    Layout layoutB = Layout::UNKNOWN;
    int rows_per_column_tile = 1;
    int columns_interleaved = 1;

    bool uses_imma_ldsm = false;
};

template<typename Layout>
struct getLayoutDetails {
};

template<>
struct getLayoutDetails<cutlass::layout::RowMajor> {
    LayoutDetails operator()()
    {
        LayoutDetails layout_details;
        layout_details.layoutB = LayoutDetails::Layout::ROW_MAJOR;
        return layout_details;
    }
};

template<>
struct getLayoutDetails<cutlass::layout::ColumnMajor> {
    LayoutDetails operator()()
    {
        LayoutDetails layout_details;
        layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR;
        return layout_details;
    }
};

template<int RowsPerTile, int ColumnsInterleaved>
struct getLayoutDetails<cutlass::layout::ColumnMajorTileInterleave<RowsPerTile, ColumnsInterleaved>> {
    LayoutDetails operator()()
    {
        LayoutDetails layout_details;
        layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR;
        layout_details.rows_per_column_tile = RowsPerTile;
        layout_details.columns_interleaved = ColumnsInterleaved;
        return layout_details;
    }
};
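
// Reads the compile-time layout traits that CUTLASS defines for this
// (arch, weight type) pair and converts them into a runtime LayoutDetails.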
template<typename cutlassArch, typename TypeB>
LayoutDetails getLayoutDetailsForArchAndQuantType()
{
    using CompileTraits = cutlass::gemm::kernel::LayoutDetailsB<TypeB, cutlassArch>;
    using LayoutB = typename CompileTraits::Layout;
    using MmaOperator = typename CompileTraits::Operator;
    LayoutDetails details = getLayoutDetails<LayoutB>()();
    details.uses_imma_ldsm = std::is_same<MmaOperator, cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA>::value;
    return details;
}

template<typename cutlassArch>
LayoutDetails getLayoutDetailsForArch(QuantType quant_type)
{
    LayoutDetails details;
    if (quant_type == QuantType::INT8_WEIGHT_ONLY) {
        details = getLayoutDetailsForArchAndQuantType<cutlassArch, uint8_t>();
    }
    else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) {
        details = getLayoutDetailsForArchAndQuantType<cutlassArch, cutlass::uint4b_t>();
    }
    else {
        FT_CHECK_WITH_INFO(false, "Unsupported quantization type");
    }
    return details;
}
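
// Maps a numeric SM version to the corresponding CUTLASS arch tag and returns
// the layout that the preprocessing passes below must produce.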
LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch)
{
    if (arch >= 70 && arch < 75) {
        return getLayoutDetailsForArch<cutlass::arch::Sm70>(quant_type);
    }
    else if (arch >= 75 && arch < 80) {
        return getLayoutDetailsForArch<cutlass::arch::Sm75>(quant_type);
    }
    else if (arch >= 80 && arch < 90) {
        return getLayoutDetailsForArch<cutlass::arch::Sm80>(quant_type);
    }
    else {
        FT_CHECK_WITH_INFO(false, "Unsupported Arch");
        return LayoutDetails();
    }
}

// Permutes the rows of B so that the data matches the thread ownership pattern
// of the ldmatrix-based mixed GEMM kernels on Turing and Ampere.
// For int8, each group of 16 rows is permuted using the map below:
//   0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15
// For int4, each group of 32 rows is permuted using the map below:
//   0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22 23 30 31
void permute_B_rows_for_mixed_gemm(int8_t *permuted_quantized_tensor,
                                   const int8_t *quantized_tensor,
                                   const std::vector<size_t> &shape,
                                   QuantType quant_type,
                                   const int64_t arch_version) {
    const size_t num_rows = shape[0];
    const size_t num_cols = shape[1];

    const int BITS_PER_ELT = get_bits_in_quant_type(quant_type);
    const int K = 16 / BITS_PER_ELT;
    const int ELTS_PER_REG = 32 / BITS_PER_ELT;

    const uint32_t *input_byte_ptr =
        reinterpret_cast<const uint32_t *>(quantized_tensor);
    uint32_t *output_byte_ptr =
        reinterpret_cast<uint32_t *>(permuted_quantized_tensor);

    int MMA_SHAPE_N = 8;
    int B_ROWS_PER_MMA = 8 * K;
    const int elts_in_int32 = 32 / BITS_PER_ELT;

    const int num_vec_cols = num_cols / elts_in_int32;

    FT_CHECK_WITH_INFO(arch_version >= 75,
                       "Unsupported Arch. Pre-Volta architectures are not "
                       "supported, and column interleaving is not needed on Volta.");

    FT_CHECK_WITH_INFO(num_rows % B_ROWS_PER_MMA == 0,
                       fmtstr("Invalid shape for quantized tensor. Number of "
                              "rows of quantized matrix must be a multiple of %d",
                              B_ROWS_PER_MMA));

    FT_CHECK_WITH_INFO(
        num_cols % MMA_SHAPE_N == 0,
        fmtstr("Invalid shape for quantized tensor. On Turing/Ampere, the number "
               "of cols must be a multiple of %d.",
               MMA_SHAPE_N));

    // Copy one 32-bit word at a time so the same permutation loop handles both
    // int8 and packed int4 data.
    for (size_t base_row = 0; base_row < num_rows; base_row += B_ROWS_PER_MMA) {
        for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) {
            for (int write_col = 0; write_col < num_vec_cols; ++write_col) {
                const int write_row = base_row + tile_row;
                const int tile_read_row = 8 * (((tile_row % ELTS_PER_REG) / 2)) +
                                          tile_row % 2 + 2 * (tile_row / ELTS_PER_REG);
                const int read_row = base_row + tile_read_row;
                const int read_col = write_col;

                const int64_t read_offset = int64_t(read_row) * num_vec_cols + read_col;
                const int64_t write_offset =
                    int64_t(write_row) * num_vec_cols + write_col;

                output_byte_ptr[write_offset] = input_byte_ptr[read_offset];
            }
        }
    }
}

// We need a transpose that correctly handles packed int4 as well as int8 data.
// This code is cache-tiled because the trivial transpose loops took a
// substantial amount of time, leading to long preprocessing for large models.
template <QuantType quant_type>
void subbyte_transpose_impl(int8_t *transposed_quantized_tensor,
                            const int8_t *quantized_tensor,
                            const std::vector<size_t> &shape) {
    const int bits_per_elt = get_bits_in_quant_type(quant_type);
    const size_t num_rows = shape[0];
    const size_t num_cols = shape[1];

    const size_t col_bytes = num_cols * bits_per_elt / 8;
    const size_t col_bytes_trans = num_rows * bits_per_elt / 8;

    const uint8_t *input_byte_ptr =
        reinterpret_cast<const uint8_t *>(quantized_tensor);
    uint8_t *output_byte_ptr =
        reinterpret_cast<uint8_t *>(transposed_quantized_tensor);

    static constexpr int ELTS_PER_BYTE =
        quant_type == QuantType::INT8_WEIGHT_ONLY ? 1 : 2;

    static constexpr int M_TILE_L1 = 64;
    static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE;
    uint8_t cache_buf[M_TILE_L1][N_TILE_L1];

    static constexpr int VECTOR_WIDTH = std::min(32, N_TILE_L1);

    // We assume the dims are a multiple of the vector width. Our kernels only
    // handle dims which are multiples of 64 for weight-only quantization, so
    // this is a reasonable tradeoff that lets the compiler emit vector loads.
    FT_CHECK_WITH_INFO(
        !(col_bytes_trans % VECTOR_WIDTH) && !(col_bytes % VECTOR_WIDTH),
        fmtstr("Number of bytes for rows and cols must be a multiple of %d. "
               "However, num_rows_bytes = %ld and num_col_bytes = %ld.",
               VECTOR_WIDTH, col_bytes_trans, col_bytes));

    for (size_t row_tile_start = 0; row_tile_start < num_rows;
         row_tile_start += M_TILE_L1) {
        for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes;
             col_tile_start_byte += N_TILE_L1) {

            const int row_limit = std::min(row_tile_start + M_TILE_L1, num_rows);
            const int col_limit =
                std::min(col_tile_start_byte + N_TILE_L1, col_bytes);

            for (int ii = 0; ii < M_TILE_L1; ++ii) {
                const int row = row_tile_start + ii;

                for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) {
                    const int col = col_tile_start_byte + jj;

                    const size_t logical_src_offset = row * col_bytes + col;

                    if (row < row_limit && col < col_limit) {
                        for (int v = 0; v < VECTOR_WIDTH; ++v) {
                            cache_buf[ii][jj + v] = input_byte_ptr[logical_src_offset + v];
                        }
                    }
                }
            }

            if (quant_type == QuantType::INT8_WEIGHT_ONLY) {
                for (int ii = 0; ii < M_TILE_L1; ++ii) {
                    for (int jj = ii + 1; jj < N_TILE_L1; ++jj) {
                        std::swap(cache_buf[ii][jj], cache_buf[jj][ii]);
                    }
                }
            } else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) {
                for (int ii = 0; ii < M_TILE_L1; ++ii) {
                    // Using M_TILE_L1 here is deliberate since we assume the cache
                    // tile is square in the number of elements (not necessarily in
                    // the number of bytes).
                    for (int jj = ii + 1; jj < M_TILE_L1; ++jj) {
                        const int ii_byte = ii / ELTS_PER_BYTE;
                        const int ii_bit_offset = ii % ELTS_PER_BYTE;

                        const int jj_byte = jj / ELTS_PER_BYTE;
                        const int jj_bit_offset = jj % ELTS_PER_BYTE;

                        uint8_t src_elt =
                            0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset));
                        uint8_t tgt_elt =
                            0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset));

                        cache_buf[ii][jj_byte] &= (0xF0 >> (4 * jj_bit_offset));
                        cache_buf[jj][ii_byte] &= (0xF0 >> (4 * ii_bit_offset));

                        cache_buf[ii][jj_byte] |= (tgt_elt << (4 * jj_bit_offset));
                        cache_buf[jj][ii_byte] |= (src_elt << (4 * ii_bit_offset));
                    }
                }
            } else {
                FT_CHECK_WITH_INFO(false, "Unsupported quantization type.");
            }

            const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE;
            const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE;

            const int row_limit_trans =
                std::min(row_tile_start_trans + M_TILE_L1, num_cols);
            const int col_limit_trans =
                std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans);

            for (int ii = 0; ii < M_TILE_L1; ++ii) {
                const int row = row_tile_start_trans + ii;
                for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) {
                    const int col = col_tile_start_byte_trans + jj;

                    const size_t logical_tgt_offset = row * col_bytes_trans + col;

                    if (row < row_limit_trans && col < col_limit_trans) {
                        for (int v = 0; v < VECTOR_WIDTH; ++v) {
                            output_byte_ptr[logical_tgt_offset + v] = cache_buf[ii][jj + v];
                        }
                    }
                }
            }
        }
    }
}

void subbyte_transpose(int8_t *transposed_quantized_tensor,
                       const int8_t *quantized_tensor,
                       const std::vector<size_t> &shape, QuantType quant_type) {
    if (quant_type == QuantType::INT8_WEIGHT_ONLY) {
        subbyte_transpose_impl<QuantType::INT8_WEIGHT_ONLY>(
            transposed_quantized_tensor, quantized_tensor, shape);
    } else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) {
        subbyte_transpose_impl<QuantType::PACKED_INT4_WEIGHT_ONLY>(
            transposed_quantized_tensor, quantized_tensor, shape);
    } else {
        FT_CHECK_WITH_INFO(false, "Invalid quant_type");
    }
}
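
// Makes int8 weights unsigned by adding 128, then swaps the two middle bytes
// of every 4-byte group so the register layout matches the int4 path below.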
void add_bias_and_interleave_int8s_inplace(int8_t *int8_tensor,
                                           const size_t num_elts) {
    // Step 1: transform the int8s to be unsigned by adding 128.
    for (size_t ii = 0; ii < num_elts; ++ii) {
        int8_tensor[ii] = int8_t(int(int8_tensor[ii]) + 128);
    }

    // Step 2 will transform the layout of a 32-bit register in CUDA in order to
    // match the int4 layout. This has no performance benefit and is purely so
    // that int4 and int8 have the same layout. Pictorially, the loop below takes
    // a 32 bit register with layout:
    // bit 32                                                      0
    //      [elt_3  elt_2  elt_1  elt_0] (each elt occupies 8 bits)
    //
    // and rearranges it to the following:
    // bit 32                                                      0
    //      [elt_3  elt_1  elt_2  elt_0] (each elt occupies 8 bits)
    FT_CHECK_WITH_INFO(num_elts % 4 == 0, "Dimensions of int8 tensor must be a "
                                          "multiple of 4 for register relayout");
    for (size_t base = 0; base < num_elts; base += 4) {
        std::swap(int8_tensor[base + 1], int8_tensor[base + 2]);
    }
}

void add_bias_and_interleave_int4s_inplace(int8_t *packed_int4_tensor,
                                           const size_t num_elts) {
    const size_t num_bytes = num_elts / 2;

    // Step 1: transform the int4s to be unsigned by adding 8, so [-8, 7] maps
    // to [0, 15]. This makes the dequantize as cheap as possible in CUDA.
    for (size_t ii = 0; ii < num_bytes; ++ii) {
        int8_t transformed_packed_int4s = 0;
        // Sign-extend the low nibble before adding the bias.
        int8_t transformed_first_elt = (int8_t(packed_int4_tensor[ii] << 4) >> 4) + 8;
        int8_t transformed_second_elt = (packed_int4_tensor[ii] >> 4) + 8;

        FT_CHECK_WITH_INFO(transformed_first_elt >= 0 &&
                               transformed_first_elt <= 15,
                           "Illegal result for int4 transform (first elt)");
        FT_CHECK_WITH_INFO(transformed_second_elt >= 0 &&
                               transformed_second_elt <= 15,
                           "Illegal result for int4 transform (second elt)");

        // We don't need to mask in these ops since everything should be in the
        // range 0-15.
        transformed_packed_int4s |= transformed_first_elt;
        transformed_packed_int4s |= (transformed_second_elt << 4);
        packed_int4_tensor[ii] = transformed_packed_int4s;
    }

    // Step 2 will transform the layout of a 32-bit register in CUDA in order to
    // minimize the number of shift & logical instructions needed to extract the
    // int4s in the GEMM main loop. Pictorially, the loop below takes a 32 bit
    // register with layout:
    // bit 32                                                      0
    //      [elt_7  elt_6  elt_5  elt_4  elt_3  elt_2  elt_1  elt_0] (each elt occupies 4 bits)
    //
    // and rearranges it to the following:
    // bit 32                                                      0
    //      [elt_7  elt_5  elt_3  elt_1  elt_6  elt_4  elt_2  elt_0] (each elt occupies 4 bits)
    FT_CHECK_WITH_INFO(num_bytes % 4 == 0, "Dimensions of int4 tensor must be a "
                                           "multiple of 8 for register relayout");
    const size_t num_registers = num_bytes / 4;

    uint32_t *register_ptr = reinterpret_cast<uint32_t *>(packed_int4_tensor);
    for (size_t ii = 0; ii < num_registers; ++ii) {
        const uint32_t current_register = register_ptr[ii];
        uint32_t transformed_register = 0;

        for (int dest_idx = 0; dest_idx < 8; ++dest_idx) {
            const int src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1;
            const int src_shift = 4 * src_idx;
            const int dest_shift = 4 * dest_idx;

            const uint32_t src_bits = (current_register >> src_shift) & 0xF;
            transformed_register |= (src_bits << dest_shift);
        }
        register_ptr[ii] = transformed_register;
    }
}

void add_bias_and_interleave_quantized_tensor_inplace(int8_t *tensor,
                                                      const size_t num_elts,
                                                      QuantType quant_type) {
    if (quant_type == QuantType::INT8_WEIGHT_ONLY) {
        add_bias_and_interleave_int8s_inplace(tensor, num_elts);
    } else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) {
        add_bias_and_interleave_int4s_inplace(tensor, num_elts);
    } else {
        FT_CHECK_WITH_INFO(false, "Invalid quantization type for interleaving.");
    }
}
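
// Interleaves `details.columns_interleaved` columns of a column-major tensor
// so they share one contiguous tile of `details.rows_per_column_tile` rows,
// matching the ColumnMajorTileInterleave layout the GEMM kernels read.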
void interleave_column_major_tensor(int8_t *interleaved_quantized_tensor,
                                    const int8_t *quantized_tensor,
                                    const std::vector<size_t> &shape,
                                    QuantType quant_type,
                                    LayoutDetails details) {
    // We only want to run this step for weight-only quantization.
    FT_CHECK(quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY ||
             quant_type == QuantType::INT8_WEIGHT_ONLY);
    FT_CHECK_WITH_INFO(shape.size() == 2, "Shape must be 2-D");

    const size_t num_rows = shape[0];
    const size_t num_cols = shape[1];

    const int BITS_PER_ELT = get_bits_in_quant_type(quant_type);
    const int elts_in_int32 = 32 / BITS_PER_ELT;

    const int rows_per_tile = details.rows_per_column_tile;

    FT_CHECK_WITH_INFO(!(num_rows % elts_in_int32),
                       fmtstr("The number of rows must be a multiple of %d but "
                              "the number of rows is %ld.",
                              elts_in_int32, num_rows));

    FT_CHECK_WITH_INFO(!(num_rows % rows_per_tile),
                       fmtstr("The number of rows must be a multiple of %d but "
                              "the number of rows is %ld.",
                              rows_per_tile, num_rows));

    const uint32_t *input_byte_ptr =
        reinterpret_cast<const uint32_t *>(quantized_tensor);
    uint32_t *output_byte_ptr =
        reinterpret_cast<uint32_t *>(interleaved_quantized_tensor);

    const int num_vec_rows = num_rows / elts_in_int32;
    const int vec_rows_per_tile = rows_per_tile / elts_in_int32;
    const int interleave = details.columns_interleaved;

    for (size_t read_col = 0; read_col < num_cols; ++read_col) {
        const auto write_col = read_col / interleave;
        for (int base_vec_row = 0; base_vec_row < num_vec_rows;
             base_vec_row += vec_rows_per_tile) {
            for (int vec_read_row = base_vec_row;
                 vec_read_row <
                 std::min(num_vec_rows, base_vec_row + vec_rows_per_tile);
                 ++vec_read_row) {
                const int64_t vec_write_row =
                    interleave * base_vec_row +
                    vec_rows_per_tile * (read_col % interleave) +
                    vec_read_row % vec_rows_per_tile;

                const int64_t read_offset =
                    int64_t(read_col) * num_vec_rows + vec_read_row;
                const int64_t write_offset =
                    int64_t(write_col) * num_vec_rows * interleave + vec_write_row;
                output_byte_ptr[write_offset] = input_byte_ptr[read_offset];
            }
        }
    }
}
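
// Runs the full preprocessing pipeline on row-major quantized weights:
// optional row permutation for the LDSM path, transpose to column-major,
// column interleaving, and finally the bias-add/register interleave.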
void preprocess_weights_for_mixed_gemm(int8_t *preprocessed_quantized_weight,
                                       const int8_t *row_major_quantized_weight,
                                       const std::vector<size_t> &shape,
                                       QuantType quant_type, int arch) {
    LayoutDetails details = getLayoutDetailsForTransform(quant_type, arch);

    FT_CHECK_WITH_INFO(shape.size() == 2, "Shape must be 2-D");

    size_t num_elts = 1;
    for (const auto &dim : shape) {
        num_elts *= dim;
    }

    const size_t num_bytes = num_elts * get_bits_in_quant_type(quant_type) / 8;

    std::vector<int8_t> src_buf(num_bytes);
    std::vector<int8_t> dst_buf(num_bytes);
    std::copy(row_major_quantized_weight, row_major_quantized_weight + num_bytes, src_buf.begin());

    // The permutation below works on row-major data, so issue it first.
    if (details.uses_imma_ldsm) {
        permute_B_rows_for_mixed_gemm(dst_buf.data(), src_buf.data(), shape, quant_type, arch);
        src_buf.swap(dst_buf);
    }

    if (details.layoutB == LayoutDetails::Layout::COLUMN_MAJOR) {
        subbyte_transpose(dst_buf.data(), src_buf.data(), shape, quant_type);
        src_buf.swap(dst_buf);
    }

    if (details.columns_interleaved > 1) {
        interleave_column_major_tensor(dst_buf.data(), src_buf.data(), shape, quant_type, details);
        src_buf.swap(dst_buf);
    }

    add_bias_and_interleave_quantized_tensor_inplace(src_buf.data(), num_elts, quant_type);
    std::copy(src_buf.begin(), src_buf.end(), preprocessed_quantized_weight);
}

void preprocess_weights(int8_t *preprocessed_quantized_weight,
                        const int8_t *row_major_quantized_weight, size_t rows,
                        size_t cols, bool is_int4, int arch) {
    QuantType qtype = is_int4 ? QuantType::PACKED_INT4_WEIGHT_ONLY
                              : QuantType::INT8_WEIGHT_ONLY;
    preprocess_weights_for_mixed_gemm(preprocessed_quantized_weight,
                                      row_major_quantized_weight, {rows, cols},
                                      qtype, arch);
}
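
// Illustrative usage sketch (not part of the library; buffer names and sizes
// below are hypothetical). Preparing packed int4 weights for an SM80 device:
//
//   std::vector<int8_t> quantized(rows * cols / 2);     // packed int4 input
//   std::vector<int8_t> preprocessed(quantized.size());
//   preprocess_weights(preprocessed.data(), quantized.data(), rows, cols,
//                      /*is_int4=*/true, /*arch=*/80);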

/*
    Arguments:
      processed_quantized_weight - the quantized AND preprocessed weight, ready for the
                                   mixed GEMM. Must be allocated by the caller.

      unprocessed_quantized_weight - if non-null, also receives the quantized (but not
                                     yet preprocessed) weight in row-major order. May be
                                     null, in which case an internal buffer is used.

      scale_ptr - receives the dequantization scales. Must be allocated by the caller;
                  one scale is produced per column (per expert for 3-D inputs).

      input_weight_ptr - the weight tensor to quantize. Must be 2-D [m, n] or
                         3-D [experts, m, n].

      quant_type - the type of the quantized output weight.

    This function performs symmetric, zero-point-free quantization along the last axis
    of the tensor, using the full integer range. For 3-D inputs it operates in "batched"
    mode, treating the tensor as a stack of matrices with one set of scales per matrix.
    For int4, two elements are packed into each output byte. After quantizing, it runs
    the full mixed-GEMM preprocessing pipeline on the result.
*/
template<typename ComputeType, typename WeightType>
void symmetric_quantize(int8_t* processed_quantized_weight,
                        int8_t* unprocessed_quantized_weight,
                        ComputeType* scale_ptr,
                        const WeightType* input_weight_ptr,
                        const std::vector<size_t>& shape,
                        QuantType quant_type)
{
    FT_CHECK_WITH_INFO(processed_quantized_weight, "Processed quantized tensor is NULL");
    FT_CHECK_WITH_INFO(scale_ptr, "Scale output pointer is NULL");
    FT_CHECK_WITH_INFO(input_weight_ptr, "Input weight pointer is NULL");

    FT_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D");
    const size_t num_experts = shape.size() == 2 ? 1 : shape[0];
    const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
    const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];

    const int bits_in_type = get_bits_in_quant_type(quant_type);
    const int bytes_per_out_col = num_cols * bits_in_type / 8;

    std::vector<int8_t> weight_buf;
    if (unprocessed_quantized_weight == nullptr) {
        weight_buf.resize(num_experts * num_rows * num_cols);
        unprocessed_quantized_weight = weight_buf.data();
    }

    const int input_mat_size = num_rows * num_cols;
    const int quantized_mat_size = num_rows * bytes_per_out_col;
    const float quant_range_scale = 1.f / float(1 << (bits_in_type - 1));

    std::vector<float> per_col_max(num_cols);

    for (int expert = 0; expert < num_experts; ++expert) {
        const WeightType* current_weight = input_weight_ptr + expert * input_mat_size;
        int8_t* current_quantized_weight = unprocessed_quantized_weight + expert * quantized_mat_size;

        // First, compute the absolute max of each column.
        for (int jj = 0; jj < num_cols; ++jj) {
            per_col_max[jj] = 0.f;
        }

        for (int ii = 0; ii < num_rows; ++ii) {
            const WeightType* current_weight_row = current_weight + ii * num_cols;
            for (int jj = 0; jj < num_cols; ++jj) {
                per_col_max[jj] = std::max(per_col_max[jj], std::abs(float(current_weight_row[jj])));
            }
        }

        // Then, construct the scale for each column from its max.
        ComputeType* current_scales = scale_ptr + expert * num_cols;
        for (int jj = 0; jj < num_cols; ++jj) {
            per_col_max[jj] *= quant_range_scale;
            current_scales[jj] = ComputeType(per_col_max[jj]);
        }

        // Finally, quantize every row against those scales.
        for (int ii = 0; ii < num_rows; ++ii) {
            int8_t* current_quantized_weight_row = current_quantized_weight + ii * bytes_per_out_col;
            const WeightType* current_weight_row = current_weight + ii * num_cols;
            for (int jj = 0; jj < bytes_per_out_col; ++jj) {

                if (quant_type == QuantType::INT8_WEIGHT_ONLY) {
                    const float col_scale = per_col_max[jj];
                    const float weight_elt = float(current_weight_row[jj]);
                    const float scaled_weight = round(weight_elt / col_scale);
                    const int8_t clipped_weight = int8_t(std::max(-128.f, std::min(127.f, scaled_weight)));
                    current_quantized_weight_row[jj] = clipped_weight;
                }
                else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) {
                    // We will pack two int4 elements per iteration of the inner loop.
                    int8_t packed_int4s = 0;
                    for (int packed_idx = 0; packed_idx < 2; ++packed_idx) {
                        const int input_idx = 2 * jj + packed_idx;
                        if (input_idx < num_cols) {
                            const float col_scale = per_col_max[input_idx];
                            const float weight_elt = float(current_weight_row[input_idx]);
                            const float scaled_weight = round(weight_elt / col_scale);
                            int int_weight = int(scaled_weight);
                            const int8_t clipped_weight = std::max(-8, std::min(7, int_weight));

                            // Kill the sign extension bits (hence the 0x0F mask), shift
                            // to the upper nibble if packing the second int4, and OR the
                            // bits into the final result.
                            packed_int4s |= ((clipped_weight & 0x0F) << (4 * packed_idx));
                        }
                    }
                    current_quantized_weight_row[jj] = packed_int4s;
                }
                else {
                    FT_CHECK_WITH_INFO(false, "Unsupported quantization type");
                }
            }
        }
    }
    const int arch = fastertransformer::getSMVersion();
    preprocess_weights_for_mixed_gemm(processed_quantized_weight, unprocessed_quantized_weight, shape, quant_type, arch);
}

template void
symmetric_quantize<half, float>(int8_t*, int8_t*, half*, const float*, const std::vector<size_t>&, QuantType);

template void
symmetric_quantize<half, half>(int8_t*, int8_t*, half*, const half*, const std::vector<size_t>&, QuantType);

template<typename ComputeType, typename WeightType>
void symmetric_quantize(int8_t* processed_quantized_weight,
                        ComputeType* scale_ptr,
                        const WeightType* input_weight_ptr,
                        const std::vector<size_t>& shape,
                        QuantType quant_type)
{
    symmetric_quantize(processed_quantized_weight, nullptr, scale_ptr, input_weight_ptr, shape, quant_type);
}

template void symmetric_quantize<float, float>(int8_t*, float*, const float*, const std::vector<size_t>&, QuantType);

template void symmetric_quantize<half, float>(int8_t*, half*, const float*, const std::vector<size_t>&, QuantType);

template void symmetric_quantize<half, half>(int8_t*, half*, const half*, const std::vector<size_t>&, QuantType);
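
// Illustrative usage sketch (not part of the library; names are hypothetical).
// Quantizing an [m, n] fp16 weight matrix to packed int4 with per-column scales:
//
//   std::vector<half>   weights(m * n);      // filled by the caller
//   std::vector<half>   scales(n);
//   std::vector<int8_t> processed(m * n / 2);
//   symmetric_quantize<half, half>(processed.data(), scales.data(),
//                                  weights.data(), {m, n},
//                                  QuantType::PACKED_INT4_WEIGHT_ONLY);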

}  // namespace fastertransformer