#include <tiny-cuda-nn/networks/fully_fused_mlp.h>

#include <tiny-cuda-nn/common_device.h>
#include <tiny-cuda-nn/cutlass_matmul.h>
#include <tiny-cuda-nn/multi_stream.h>

#include <mma.h>

TCNN_NAMESPACE_BEGIN

void check_shmem_error(cudaError_t error) {
    if (error != cudaSuccess) {
        throw std::runtime_error{"FullyFusedMLP: insufficient shared memory available on the GPU. Reduce `n_neurons` or use `CutlassMLP` (better compatibility but slower) instead."};
    }
}

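// threadblock_layer evaluates one WIDTH x WIDTH layer of the MLP for a thread block's
// 16*N_ITERS rows of the batch: activations live in shared memory, the weight matrix
// is loaded into wmma fragments, and the product is computed 16x16 tile by tile on
// the tensor cores. When BACKWARD is true, the same routine multiplies by the
// transposed weights and applies the activation's backward function using
// `activation_aux` (the stored forward activations).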
template <int WIDTH, int N_ITERS, typename OUT_T, bool BACKWARD=false>
__device__ void threadblock_layer(Activation activation, __half* __restrict__ act_shmem, const __half* __restrict__ weights_this_layer, OUT_T* __restrict__ out_intermediate_threadblock_this_layer, const OUT_T* __restrict__ activation_aux = nullptr) {
    // act_shmem contains the intermediate activations (shared memory) of the thread block's chunk of the batch.
    //           Can be forward or backward activations, depending on the caller.
    // weights_this_layer points to the weight matrix of the current layer.
    // out_intermediate_threadblock_this_layer points to the location where intermediate activations produced by the
    //           thread block should be written out. May be nullptr if nothing should be written.
    // activation_aux points to the hidden forward activations when computing backward activations (otherwise unused).

    // SKEW pads each shared-memory row to avoid bank conflicts.
    constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0;
    constexpr uint32_t N_BLOCKS = WIDTH / 16;

    using namespace nvcuda;

    // In the backward pass, the weight matrix must be loaded in transposed form,
    // which is achieved by interpreting its memory as row_major instead of col_major.
    using weights_layout_t = std::conditional_t<BACKWARD, wmma::row_major, wmma::col_major>;

    // Fragments
    wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> act_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, weights_layout_t> weights_frag[N_BLOCKS];
    wmma::fragment<wmma::accumulator, 16, 16, 16, OUT_T> result_frag[N_ITERS];

    // Indices
    const uint32_t li = threadIdx.x; // index in warp ("lane index")
    const uint32_t wi = threadIdx.y; // index in block ("warp index")

    const uint32_t lane_offset = (8 * li) % WIDTH;
    const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH;

    const uint32_t weights_col = 16 * wi;

    __syncthreads();

    // Load N_BLOCKS chunks of the weight matrix from global memory into registers.
    TCNN_PRAGMA_UNROLL
    for (uint32_t i = 0; i < N_BLOCKS; ++i) {
        if (BACKWARD) {
            // If we're performing the backward pass, additional index swizzling is needed to
            // load the weights in transposed form.
            wmma::load_matrix_sync(weights_frag[i], weights_this_layer + 16 * i * WIDTH + weights_col, WIDTH);
        } else {
            wmma::load_matrix_sync(weights_frag[i], weights_this_layer + 16 * i + weights_col * WIDTH, WIDTH);
        }
    }

    TCNN_PRAGMA_UNROLL
    for (int l = 0; l < N_ITERS; ++l) {
        wmma::fill_fragment(result_frag[l], 0.0f);

        TCNN_PRAGMA_UNROLL
        for (uint32_t i = 0; i < N_BLOCKS; ++i) {
            // Load a chunk of intermediate activations from shared memory and multiply with the chunk of weights.
            wmma::load_matrix_sync(act_frag, act_shmem + 16 * i + (16 * l) * (WIDTH + SKEW), WIDTH + SKEW);
            wmma::mma_sync(result_frag[l], act_frag, weights_frag[i], result_frag[l]);
        }

        // Activation
        if (BACKWARD) {
            // Load the stored forward activations for the activation's backward transfer.
            wmma::load_matrix_sync(act_frag, activation_aux + weights_col + l * 16 * WIDTH, WIDTH);
            warp_activation_backward<__half>(activation, result_frag[l], act_frag, result_frag[l]);
        } else {
            warp_activation<__half>(activation, result_frag[l], result_frag[l]);
        }
    }

    __syncthreads();

    TCNN_PRAGMA_UNROLL
    for (int l = 0; l < N_ITERS; ++l) {
        wmma::store_matrix_sync(act_shmem + weights_col + l * 16 * (WIDTH + SKEW), result_frag[l], WIDTH + SKEW, wmma::mem_row_major);
    }

    if (out_intermediate_threadblock_this_layer != nullptr) {
        __syncthreads();

        TCNN_PRAGMA_UNROLL
        for (int l = 0; l < N_ITERS; ++l) {
            *(int4*)&out_intermediate_threadblock_this_layer[lane_offset + (row + 16 * l) * WIDTH] = *(int4*)&act_shmem[lane_offset + (row + 16 * l) * (WIDTH + SKEW)];
        }
    }
}

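// threadblock_load_input_static copies the thread block's 16*N_ITERS x WIDTH input
// chunk from global to shared memory. Each thread moves 8 __half values (16 bytes)
// per iteration via a single int4 load/store, which is why WIDTH must be a multiple
// of 16 and rows are padded by SKEW in shared memory.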
template <int WIDTH, int N_ITERS>
__device__ void threadblock_load_input_static(__half* __restrict__ act_shmem, const __half* __restrict__ input_threadblock) {
    // input_threadblock points to the thread block's chunk of the input batch in global memory.

    constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0;

    // Indices
    const uint32_t li = threadIdx.x; // index in warp ("lane index")
    const uint32_t wi = threadIdx.y; // index in block ("warp index")

    const uint32_t lane_offset = (8 * li) % WIDTH;
    const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH;

    TCNN_PRAGMA_UNROLL
    for (int i = 0; i < N_ITERS; ++i) {
        *(int4*)&act_shmem[lane_offset + (row + 16 * i) * (WIDTH + SKEW)] = *(int4*)&input_threadblock[lane_offset + (row + 16 * i) * WIDTH];
    }
}

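// kernel_mlp_fused_backward propagates dL_doutput backward through all hidden layers
// in a single kernel launch. It first multiplies the output gradients by the
// (transposed) last-layer weights and the activation's backward function, then walks
// the hidden layers in reverse via threadblock_layer<..., BACKWARD=true>, writing
// each layer's activation gradients to `out_intermediate` for the subsequent
// weight-gradient matmuls that run on the host side.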
template <int WIDTH, int N_ITERS, Activation ACTIVATION, typename OUTPUT_LAYOUT>
__global__ void kernel_mlp_fused_backward(
    const __half* __restrict__ dL_doutput,
    const __half* __restrict__ weights,
    __half* __restrict__ out_intermediate,
    const __half* __restrict__ forward,
    __half* __restrict__ dL_dinput,
    const __half* __restrict__ weights_first_layer,
    const uint32_t output_stride,
    const uint32_t batch_size,
    const uint32_t out_width,
    const uint32_t n_hidden_matmuls
) {
    // `dL_doutput` and the stored `forward` activations are read; gradients w.r.t.
    // the intermediate activations are written to `out_intermediate` and, optionally,
    // gradients w.r.t. the input to `dL_dinput`.

    constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0;

    // Indices
    const uint32_t li = threadIdx.x; // index in warp ("lane index")
    const uint32_t wi = threadIdx.y; // index in block ("warp index")
    const uint32_t bi = blockIdx.x;  // block index

    // Shared memory contains the intermediate activation gradients of the thread block's chunk of the batch.
    // A skew is applied to the matrix structure to avoid bank conflicts.
    extern __shared__ __half shmem[];
    __half* act_shmem = shmem;

    const uint32_t lane_offset = (8 * li) % WIDTH;
    const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH;

    // Each block operates on a chunk of 16*N_ITERS batch elements.
    const uint32_t elem_idx_base = 16 * bi * N_ITERS;
    const uint32_t elem_idx = elem_idx_base;

    const uint32_t weights_stride = WIDTH * WIDTH;
    const uint32_t layer_stride = WIDTH * batch_size;

    // Backprop through the last layer (only fused if the output is at most 16 wide).
    if (out_width <= 16) {
        using namespace nvcuda;

        // Fragments
        wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, OUTPUT_LAYOUT> act_frag;
        wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::row_major> weights_frag;
        wmma::fragment<wmma::accumulator, 16, 16, 16, __half> result_frag[N_ITERS];

        // Load the relevant chunk of the last layer's weight matrix from global memory into registers.
        const uint32_t weights_col = 16 * wi;

        wmma::load_matrix_sync(weights_frag, weights + weights_stride * n_hidden_matmuls + weights_col, WIDTH);

        TCNN_PRAGMA_UNROLL
        for (int l = 0; l < N_ITERS; ++l) {
            wmma::fill_fragment(result_frag[l], 0.0f);

            // Load a chunk of output gradients and multiply with the previously loaded weights.
            if (std::is_same<OUTPUT_LAYOUT, wmma::row_major>::value) {
                wmma::load_matrix_sync(act_frag, dL_doutput + (elem_idx + 16 * l) * output_stride, output_stride);
            } else {
                wmma::load_matrix_sync(act_frag, dL_doutput + (elem_idx + 16 * l), output_stride);
            }

            // Note: the _output_ activation's backward transfer is applied _before_ this kernel
            // (see backward_impl), because the transferred gradients are also needed to compute
            // the weight gradient of the last layer.
            wmma::mma_sync(result_frag[l], act_frag, weights_frag, result_frag[l]);

            // Load the stored forward activations for the activation's backward transfer.
            wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> forward_frag;
            wmma::load_matrix_sync(forward_frag, forward + layer_stride * n_hidden_matmuls + weights_col + (elem_idx + l * 16) * WIDTH, WIDTH);

            warp_activation_backward<__half>(ACTIVATION, result_frag[l], forward_frag, result_frag[l]);
        }

        __syncthreads();

        TCNN_PRAGMA_UNROLL
        for (int l = 0; l < N_ITERS; ++l) {
            wmma::store_matrix_sync(act_shmem + weights_col + (16 * l) * (WIDTH + SKEW), result_frag[l], WIDTH + SKEW, wmma::mem_row_major);
        }

        __syncthreads();

        TCNN_PRAGMA_UNROLL
        for (int i = 0; i < N_ITERS; ++i) {
            *(int4*)&out_intermediate[lane_offset + (row + elem_idx + i * 16) * WIDTH] = *(int4*)&act_shmem[lane_offset + (row + 16 * i) * (WIDTH + SKEW)];
        }
    } else {
        // If the output is wider than 16, the last layer was already backpropagated by a
        // separate matmul; simply load its result from global memory.
        threadblock_load_input_static<WIDTH, N_ITERS>(act_shmem, out_intermediate + elem_idx * WIDTH);
    }

    // Backprop through the hidden layers (in reverse order).
    for (uint32_t k = 0; k < n_hidden_matmuls; ++k) {
        threadblock_layer<WIDTH, N_ITERS, __half, true>(ACTIVATION, act_shmem, weights + weights_stride * (n_hidden_matmuls - k - 1), out_intermediate + layer_stride * (k + 1) + elem_idx_base * WIDTH, forward + layer_stride * (n_hidden_matmuls - k - 1) + elem_idx_base * WIDTH);
    }

    // Compute loss gradients w.r.t. the input if desired.
    // This code assumes that the input width equals the network width;
    // callers must pass a null dL_dinput otherwise.
    if (dL_dinput != nullptr) {
        threadblock_layer<WIDTH, N_ITERS, __half, true>(Activation::None, act_shmem, weights_first_layer, dL_dinput + elem_idx_base * WIDTH);
    }
}

template <int WIDTH, typename T, Activation ACTIVATION>
std::enable_if_t<!std::is_same<__half, T>::value> mlp_fused_backward(
    cudaStream_t stream,
    const GPUMatrix<T, RM>& weights_first_layer,
    const GPUMatrix<T, RM>& weights,
    const GPUMatrixDynamic<T>& dL_doutput,
    GPUMatrix<T>& temporaries,
    const GPUMatrix<T>& forward,
    GPUMatrixDynamic<T>* dL_dinput,
    const uint32_t n_hidden_matmuls
) {
    throw std::runtime_error{"The fully fused backward pass only supports __half precision."};
}

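// Host-side launcher for the fused backward pass. Note the layout inversion when
// picking the kernel template below: a row-major (RM) GPUMatrix maps to
// wmma::col_major because the kernels operate on transposed layouts compared with
// the host-side matrices, minimizing loads/stores of the network weights.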
template <int WIDTH, typename T, Activation ACTIVATION>
std::enable_if_t<std::is_same<__half, T>::value> mlp_fused_backward(
    cudaStream_t stream,
    const GPUMatrix<T, RM>& weights_first_layer,
    const GPUMatrix<T, RM>& weights,
    const GPUMatrixDynamic<T>& dL_doutput,
    GPUMatrix<T>& temporaries,
    const GPUMatrix<T>& forward,
    GPUMatrixDynamic<T>* dL_dinput,
    const uint32_t n_hidden_matmuls
) {
    const uint32_t batch_size = dL_doutput.cols();
    const uint32_t out_width = dL_doutput.rows();
    constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0;
    constexpr uint32_t N_BLOCKS = WIDTH / 16;

    const int N_ITERS = WIDTH >= 256 ? 2 : 8;

    CHECK_THROW(forward.cols() == batch_size);
    CHECK_THROW(batch_size % (16 * N_ITERS) == 0);
    CHECK_THROW(!dL_dinput || dL_dinput->layout() == RM || dL_dinput->stride() == dL_dinput->m());

    const dim3 threads = { 32u, N_BLOCKS, 1 }; // 32 threads = 1 warp; one warp per 16-wide column block

    uint32_t n_elems_per_block = 16 * N_ITERS;
    uint32_t n_blocks = div_round_up(batch_size, n_elems_per_block);

    int shmem_size = sizeof(__half) * ((16 * N_ITERS) * (WIDTH + SKEW)); // 16*N_ITERS rows of (WIDTH + SKEW) halves
    const dim3 blocks = { n_blocks, 1u, 1u };

    if (dL_doutput.layout() == RM) {
        check_shmem_error(cudaFuncSetAttribute(kernel_mlp_fused_backward<WIDTH, N_ITERS, ACTIVATION, nvcuda::wmma::col_major>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
        kernel_mlp_fused_backward<WIDTH, N_ITERS, ACTIVATION, nvcuda::wmma::col_major><<<blocks, threads, shmem_size, stream>>>(dL_doutput.data(), weights.data(), temporaries.data(), forward.data(), dL_dinput ? dL_dinput->data() : nullptr, weights_first_layer.data(), dL_doutput.stride(), batch_size, out_width, n_hidden_matmuls);
    } else {
        check_shmem_error(cudaFuncSetAttribute(kernel_mlp_fused_backward<WIDTH, N_ITERS, ACTIVATION, nvcuda::wmma::row_major>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
        kernel_mlp_fused_backward<WIDTH, N_ITERS, ACTIVATION, nvcuda::wmma::row_major><<<blocks, threads, shmem_size, stream>>>(dL_doutput.data(), weights.data(), temporaries.data(), forward.data(), dL_dinput ? dL_dinput->data() : nullptr, weights_first_layer.data(), dL_doutput.stride(), batch_size, out_width, n_hidden_matmuls);
    }
}

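// threadblock_input_layer_forward_dynamic handles the first layer, whose input width
// is only known at runtime (`in_width`). The WIDTH x in_width weight matrix is staged
// in shared memory behind the activations, and the matmul accumulates over
// in_width/16 tensor-core ops per 16x16 output tile. INPUT_LAYOUT selects whether
// input rows are staged through shared memory (row_major) or read directly from
// global memory (col_major).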
template <int WIDTH, int N_ITERS, typename OUT_T, typename INPUT_LAYOUT>
__device__ void threadblock_input_layer_forward_dynamic(Activation activation, __half* __restrict__ act_shmem, const __half* __restrict__ input_threadblock, const __half* __restrict__ weights_this_layer, OUT_T* __restrict__ out_intermediate_threadblock_this_layer, const uint32_t in_width, const uint32_t batch_size) {
    // act_shmem contains the intermediate activations (shared memory) of the thread block's chunk of the batch.
    // input_threadblock points to the thread block's chunk of the input batch in global memory.
    // weights_this_layer points to the weight matrix of the current layer.
    // out_intermediate_threadblock_this_layer points to the location where intermediate activations produced by the
    //           thread block should be written out. May be nullptr if nothing should be written.
    // in_width is the dynamic width of the input layer.

    constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0;
    constexpr uint32_t INPUT_SKEW = 8;
    constexpr uint32_t N_BLOCKS = WIDTH / 16;

    using namespace nvcuda;

    // Fragments
    wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, INPUT_LAYOUT> act_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::col_major> weights_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, OUT_T> result_frag[N_ITERS];

    // Indices
    const uint32_t li = threadIdx.x; // index in warp ("lane index")
    const uint32_t wi = threadIdx.y; // index in block ("warp index")

    const uint32_t lane_offset = (8 * li) % WIDTH;
    const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH;

    const uint32_t weights_col = 16 * wi;

    __half* __restrict__ weights_shmem = act_shmem + 16 * (in_width + INPUT_SKEW);

    // Load the weight matrix into shared memory for the thread block to use.
    // Each thread loads 8 elements at a time via a single int4.
    const uint32_t n_elems_per_load = N_BLOCKS * 32 * 8;
    const uint32_t thread_elem_idx = (li + wi * 32) * 8;

    const uint32_t n_elems_b = WIDTH * in_width;

    TCNN_PRAGMA_UNROLL
    for (uint32_t idx = thread_elem_idx; idx < n_elems_b; idx += n_elems_per_load) {
        const uint32_t idx_skewed = idx + idx / in_width * INPUT_SKEW;
        *(int4*)&weights_shmem[idx_skewed] = *(int4*)&weights_this_layer[idx];
    }

    const uint32_t n_tensor_ops = in_width / 16;

    if (std::is_same<INPUT_LAYOUT, wmma::col_major>::value) {
        __syncthreads();
    }

    TCNN_PRAGMA_UNROLL
    for (int l = 0; l < N_ITERS; ++l) {
        if (std::is_same<INPUT_LAYOUT, wmma::row_major>::value) {
            // Load a chunk of the input into shared memory. This is faster than loading
            // from global memory directly into registers, because we can load 8 elements
            // at a time via int4.
            const uint32_t n_elems_a = 16 * in_width;

            TCNN_PRAGMA_UNROLL
            for (uint32_t idx = thread_elem_idx; idx < n_elems_a; idx += n_elems_per_load) {
                const uint32_t idx_skewed = idx + idx / in_width * INPUT_SKEW;
                *(int4*)&act_shmem[idx_skewed] = *(int4*)&input_threadblock[l * n_elems_a + idx];
            }

            __syncthreads();
        }

        wmma::fill_fragment(result_frag[l], 0.0f);
        TCNN_PRAGMA_UNROLL
        for (uint32_t i = 0; i < n_tensor_ops; ++i) {
            // Load chunks of the input and the weights from shared memory and multiply them.
            if (std::is_same<INPUT_LAYOUT, wmma::row_major>::value) {
                wmma::load_matrix_sync(act_frag, act_shmem + 16 * i, in_width + INPUT_SKEW);
            } else {
                wmma::load_matrix_sync(act_frag, input_threadblock + 16 * i * batch_size + 16 * l, batch_size);
            }
            wmma::load_matrix_sync(weights_frag, weights_shmem + 16 * i + weights_col * (in_width + INPUT_SKEW), in_width + INPUT_SKEW);
            wmma::mma_sync(result_frag[l], act_frag, weights_frag, result_frag[l]);
        }

        if (std::is_same<INPUT_LAYOUT, wmma::row_major>::value) {
            __syncthreads();
        }

        warp_activation<__half>(activation, result_frag[l], result_frag[l]);
    }

    if (std::is_same<INPUT_LAYOUT, wmma::col_major>::value) {
        __syncthreads();
    }

    TCNN_PRAGMA_UNROLL
    for (int l = 0; l < N_ITERS; ++l) {
        wmma::store_matrix_sync(act_shmem + weights_col + (16 * l) * (WIDTH + SKEW), result_frag[l], WIDTH + SKEW, wmma::mem_row_major);
    }

    if (out_intermediate_threadblock_this_layer != nullptr) {
        __syncthreads();

        TCNN_PRAGMA_UNROLL
        for (int i = 0; i < N_ITERS; ++i) {
            *(int4*)&out_intermediate_threadblock_this_layer[lane_offset + (row + 16 * i) * WIDTH] = *(int4*)&act_shmem[lane_offset + (row + 16 * i) * (WIDTH + SKEW)];
        }
    }
}

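// threadblock_last_layer_forward computes the final (at most 16-wide) output layer.
// The last layer's weights are staged in shared memory behind the activations so
// they can be loaded 8 halves at a time, and each warp produces a disjoint subset
// of the N_ITERS 16-row output chunks, stored directly to global memory in the
// requested output layout.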
template <int WIDTH, int N_ITERS, typename OUT_T>
__device__ void threadblock_last_layer_forward(Activation activation, __half* __restrict__ act_shmem, const __half* __restrict__ weights_this_layer, OUT_T* __restrict__ out, const uint32_t output_stride, const nvcuda::wmma::layout_t output_layout) {
    // act_shmem contains the intermediate activations (shared memory) of the thread block's chunk of the batch.
    // weights_this_layer points to the weight matrix of the current layer.
    // out points to the location where the result produced by the thread block should be written out.

    constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0;
    constexpr uint32_t N_BLOCKS = WIDTH / 16;

    using namespace nvcuda;

    // Fragments
    wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> act_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::col_major> weights_frag[N_BLOCKS];
    wmma::fragment<wmma::accumulator, 16, 16, 16, OUT_T> result_frag;

    // Indices
    const uint32_t li = threadIdx.x; // index in warp ("lane index")
    const uint32_t wi = threadIdx.y; // index in block ("warp index")

    __half* __restrict__ weights_shmem = act_shmem + N_ITERS * 16 * (WIDTH + SKEW);

    const uint32_t weights_row = (8 * li) % WIDTH;
    const uint32_t weights_col = (8 * li + 8 * 32 * wi) / WIDTH;

    // Load the weight matrix into shared memory for the thread block to use.
    // Staging through shared memory (rather than loading directly into registers)
    // is faster because each thread can load 8 elements at a time via int4.
    *(int4*)&weights_shmem[weights_row + weights_col * (WIDTH + SKEW)] = *(int4*)&weights_this_layer[weights_row + weights_col * WIDTH];

    __syncthreads();

    TCNN_PRAGMA_UNROLL
    for (uint32_t i = 0; i < N_BLOCKS; ++i)
        wmma::load_matrix_sync(weights_frag[i], weights_shmem + 16 * i, WIDTH + SKEW);

    // Perform the last layer by parallelizing over iterations.
    for (uint32_t idx = wi; idx < N_ITERS; idx += N_BLOCKS) {
        wmma::fill_fragment(result_frag, 0.0f);
        TCNN_PRAGMA_UNROLL
        for (uint32_t i = 0; i < N_BLOCKS; ++i) {
            // Load a chunk of intermediate activations from shared memory and multiply with the chunk of the weight matrix.
            wmma::load_matrix_sync(act_frag, act_shmem + 16 * i + (16 * idx) * (WIDTH + SKEW), WIDTH + SKEW);
            wmma::mma_sync(result_frag, act_frag, weights_frag[i], result_frag);
        }

        warp_activation<__half>(activation, result_frag, result_frag);

        if (output_layout == wmma::mem_row_major) {
            wmma::store_matrix_sync(out + idx * 16 * output_stride, result_frag, output_stride, output_layout);
        } else {
            wmma::store_matrix_sync(out + idx * 16, result_frag, output_stride, output_layout);
        }
    }
}

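// threadblock_write_output_static is the mirror image of threadblock_load_input_static:
// it copies the thread block's final activations from (skewed) shared memory back to
// (contiguous) global memory, again 8 halves per thread per iteration via int4.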
template <int WIDTH, int N_ITERS>
__device__ void threadblock_write_output_static(const __half* __restrict__ act_shmem, __half* __restrict__ output_threadblock) {
    // output_threadblock points to the thread block's chunk of the output buffer in global memory.

    constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0;

    // Indices
    const uint32_t li = threadIdx.x; // index in warp ("lane index")
    const uint32_t wi = threadIdx.y; // index in block ("warp index")

    const uint32_t lane_offset = (8 * li) % WIDTH;
    const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH;

    __syncthreads();

    TCNN_PRAGMA_UNROLL
    for (int i = 0; i < N_ITERS; ++i) {
        *(int4*)&output_threadblock[lane_offset + (row + 16 * i) * WIDTH] = *(int4*)&act_shmem[lane_offset + (row + 16 * i) * (WIDTH + SKEW)];
    }
}

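// kernel_mlp_fused evaluates the whole MLP in one launch: the input layer (dynamic
// or static path), n_hidden_matmuls hidden layers, and, if the output is at most 16
// wide, the output layer as well. Intermediate activations never leave shared memory
// except when training (INFERENCE == false), in which case they are also written to
// `out_intermediate` for use in the backward pass.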
template <int WIDTH, int N_ITERS, typename OUT_T, Activation ACTIVATION, bool INFERENCE>
__global__ void kernel_mlp_fused(const Activation output_activation, const __half* __restrict__ input, const __half* __restrict__ weights, OUT_T* __restrict__ out_intermediate, OUT_T* __restrict__ out, const uint32_t output_stride, const uint32_t batch_size, const uint32_t in_width, const uint32_t out_width, const uint32_t n_hidden_matmuls, const nvcuda::wmma::layout_t input_layout, const nvcuda::wmma::layout_t output_layout) {
    // `input` points to the input matrix. Can be any width.
    // `weights` points to the weight matrices (contiguous in memory).
    // `out_intermediate` receives the per-layer activations when training; during
    //           inference it is only used as scratch for the final activations when out_width > 16.
    // `out` points to the memory where the network output should be written
    //           (used when the output layer is fused, i.e. out_width <= 16).

    // Shared memory contains the intermediate activations of the thread block's chunk of the batch.
    // A skew is applied to the matrix structure to avoid bank conflicts.
    extern __shared__ __half shmem[];
    __half* act_shmem = shmem;

    // Each block operates on a chunk of 16*N_ITERS batch elements.
    const uint32_t elem_idx = 16 * blockIdx.x * N_ITERS;

    // First layer
    if (input_layout == nvcuda::wmma::mem_col_major || in_width != WIDTH) {
        if (input_layout == nvcuda::wmma::mem_row_major) {
            threadblock_input_layer_forward_dynamic<WIDTH, N_ITERS, OUT_T, nvcuda::wmma::row_major>(ACTIVATION, act_shmem, input + elem_idx * in_width, weights, !INFERENCE ? (out_intermediate + elem_idx * WIDTH) : nullptr, in_width, batch_size);
        } else {
            threadblock_input_layer_forward_dynamic<WIDTH, N_ITERS, OUT_T, nvcuda::wmma::col_major>(ACTIVATION, act_shmem, input + elem_idx, weights, !INFERENCE ? (out_intermediate + elem_idx * WIDTH) : nullptr, in_width, batch_size);
        }
    } else {
        // If the input has the same width & layout as the hidden layers, we can use the
        // network's regular layer routine.
        threadblock_load_input_static<WIDTH, N_ITERS>(act_shmem, input + elem_idx * WIDTH);
        threadblock_layer<WIDTH, N_ITERS, OUT_T>(ACTIVATION, act_shmem, weights, !INFERENCE ? (out_intermediate + elem_idx * WIDTH) : nullptr);
    }

    const uint32_t first_weights_stride = WIDTH * in_width;
    const uint32_t weights_stride = WIDTH * WIDTH;
    const uint32_t layer_stride = WIDTH * batch_size;

    // Hidden layers
    for (uint32_t k = 0; k < n_hidden_matmuls; ++k) {
        threadblock_layer<WIDTH, N_ITERS, OUT_T>(ACTIVATION, act_shmem, weights + first_weights_stride + weights_stride * k, !INFERENCE ? (out_intermediate + layer_stride * (k + 1) + elem_idx * WIDTH) : nullptr);
    }

    if (out_width > 16) {
        // The last layer is handled outside of this kernel. In the training forward pass,
        // the intermediate activations have already been written out above.
        if (INFERENCE) {
            threadblock_write_output_static<WIDTH, N_ITERS>(act_shmem, out_intermediate + elem_idx * WIDTH);
        }
    } else if (out) {
        // Last layer
        if (output_layout == nvcuda::wmma::mem_row_major) {
            threadblock_last_layer_forward<WIDTH, N_ITERS, OUT_T>(output_activation, act_shmem, weights + first_weights_stride + weights_stride * n_hidden_matmuls, out + elem_idx * output_stride, output_stride, output_layout);
        } else {
            threadblock_last_layer_forward<WIDTH, N_ITERS, OUT_T>(output_activation, act_shmem, weights + first_weights_stride + weights_stride * n_hidden_matmuls, out + elem_idx, output_stride, output_layout);
        }
    }
}

template <int WIDTH, typename T, Activation ACTIVATION, bool INFERENCE>
std::enable_if_t<!std::is_same<__half, T>::value> mlp_fused_forward(
    cudaStream_t stream,
    Activation output_activation,
    const GPUMatrix<T, RM>& weights,
    const GPUMatrixDynamic<T>& input,
    GPUMatrix<T>& output_intermediate,
    GPUMatrixDynamic<T>* output,
    const uint32_t n_hidden_layers
) {
    throw std::runtime_error{"The fully fused forward pass only supports __half precision."};
}

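// Host-side launcher for the fused forward pass. Shared memory must hold 16*N_ITERS
// rows of activations plus 16 rows for the last layer's weights (each row padded by
// SKEW); when the first layer is "dynamic", enough space for 16 input rows plus the
// WIDTH x in_width weight matrix (padded by INPUT_SKEW) is needed as well, whichever
// is larger.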
template <int WIDTH, typename T, Activation ACTIVATION, bool INFERENCE>
std::enable_if_t<std::is_same<__half, T>::value> mlp_fused_forward(
    cudaStream_t stream,
    Activation output_activation,
    const GPUMatrix<T, RM>& weights,
    const GPUMatrixDynamic<T>& input,
    GPUMatrix<T>& output_intermediate,
    GPUMatrixDynamic<T>* output,
    const uint32_t n_hidden_layers
) {
    const uint32_t batch_size = input.cols();
    const uint32_t in_width = input.rows();

    constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0;
    constexpr uint32_t INPUT_SKEW = 8;
    constexpr uint32_t N_BLOCK_ROWS = WIDTH / 16;

    static_assert(WIDTH % 16 == 0, "Width must be a multiple of 16.");

    CHECK_THROW(in_width % 16 == 0);
    CHECK_THROW(weights.rows() == WIDTH);
    CHECK_THROW(weights.cols() % 16 == 0);
    CHECK_THROW(output_intermediate.cols() == batch_size);
    CHECK_THROW(!output || output->cols() == batch_size);
    CHECK_THROW(input.layout() == RM || input.stride() == input.m());

    const int N_ITERS = WIDTH >= 256 ? 2 : 8;

    if (batch_size % (16 * N_ITERS) != 0) {
        throw std::runtime_error{fmt::format("Batch size must be a multiple of {}.", 16 * N_ITERS)};
    }

    const dim3 threads = { 32u, N_BLOCK_ROWS, 1 }; // 32 threads = 1 warp; one warp per 16-row block

    uint32_t n_elems_per_block = 16 * N_ITERS;
    uint32_t n_blocks = div_round_up(batch_size, n_elems_per_block);

    size_t shmem_size = sizeof(__half) * (16 + 16 * N_ITERS) * (WIDTH + SKEW); // 16*N_ITERS rows of activations plus 16 rows of weights for the last layer
    if (in_width != WIDTH || input.layout() == RM) {
        // If the input width is dynamic, the input weight matrix as well as part of the input
        // need to fit in shared memory as well.
        shmem_size = std::max(shmem_size, sizeof(__half) * (WIDTH + 16) * (in_width + INPUT_SKEW));
    }

    const dim3 blocks = { n_blocks, 1u, 1u };

    check_shmem_error(cudaFuncSetAttribute(kernel_mlp_fused<WIDTH, N_ITERS, __half, ACTIVATION, INFERENCE>, cudaFuncAttributeMaxDynamicSharedMemorySize, (int)shmem_size));
    kernel_mlp_fused<WIDTH, N_ITERS, __half, ACTIVATION, INFERENCE><<<blocks, threads, shmem_size, stream>>>(
        output_activation,
        input.data(),
        weights.data(),
        output_intermediate.data(),
        output ? output->data() : nullptr,
        output ? output->stride() : 0,
        batch_size,
        in_width,
        output ? output->rows() : 0,
        n_hidden_layers,
        // The kernels operate with transposed layouts compared with the host-side matrices
        // to minimize load/store operations on the network weights.
        input.layout() == RM ? nvcuda::wmma::mem_col_major : nvcuda::wmma::mem_row_major,
        output && output->layout() == RM ? nvcuda::wmma::mem_col_major : nvcuda::wmma::mem_row_major
    );
}

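// The parameter buffer is laid out as a sequence of matrices: one WIDTH x input_width
// input matrix, (n_hidden_layers - 1) WIDTH x WIDTH hidden matrices, and one
// padded_output_width x WIDTH output matrix. The GPUMatrix objects constructed below
// are views with nullptr data; actual pointers are supplied later via set_params_impl.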
template <typename T, int WIDTH>
FullyFusedMLP<T, WIDTH>::FullyFusedMLP(
    uint32_t input_width,
    uint32_t output_width,
    uint32_t n_hidden_layers,
    Activation activation,
    Activation output_activation
) :
    m_input_width{input_width},
    m_network_width{WIDTH},
    m_output_width{output_width},
    m_n_hidden_layers{n_hidden_layers},
    m_activation{activation},
    m_output_activation{output_activation}
{
    if (m_n_hidden_layers <= 0) {
        throw std::runtime_error("FullyFusedMLP requires at least 1 hidden layer (3 layers in total).");
    }

    m_n_hidden_matmuls = n_hidden_layers - 1;

    m_padded_output_width = next_multiple(m_output_width, REQUIRED_ALIGNMENT());

    // Create matrix views of the weights & their gradients. Memory is assigned later.
    m_weight_matrices.emplace_back(nullptr, m_network_width, m_input_width);
    m_weight_matrices_inference.emplace_back(nullptr, m_network_width, m_input_width);
    m_gradient_matrices.emplace_back(nullptr, m_network_width, m_input_width);

    for (uint32_t i = 0; i < m_n_hidden_matmuls; ++i) {
        m_weight_matrices.emplace_back(nullptr, m_network_width, m_network_width);
        m_weight_matrices_inference.emplace_back(nullptr, m_network_width, m_network_width);
        m_gradient_matrices.emplace_back(nullptr, m_network_width, m_network_width);
    }

    m_weight_matrices.emplace_back(nullptr, m_padded_output_width, m_network_width);
    m_weight_matrices_inference.emplace_back(nullptr, m_padded_output_width, m_network_width);
    m_gradient_matrices.emplace_back(nullptr, m_padded_output_width, m_network_width);

    // Determine the total number of parameters from the matrix sizes.
    m_total_n_params = 0;
    for (const auto& m : m_weight_matrices) {
        m_total_n_params += m.n_elements();
    }
}

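// Inference dispatches on the (runtime) hidden activation to a compile-time
// specialization of mlp_fused_forward. When the output is wider than 16, the fused
// kernel stops at the last hidden layer and the output layer is computed by a
// separate CUTLASS matmul below.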
template <typename T, int WIDTH>
void FullyFusedMLP<T, WIDTH>::inference_mixed_precision_impl(cudaStream_t stream, const GPUMatrixDynamic<T>& input, GPUMatrixDynamic<T>& output, bool use_inference_params) {
    uint32_t batch_size = input.n();

    // If the output is at most 16 wide, the fused kernel writes it directly and the temporary buffer stays empty.
    GPUMatrix<T> inference_tmp = m_output_width > 16 ? GPUMatrix<T>{m_network_width, batch_size, stream} : GPUMatrix<T>{nullptr, m_network_width, batch_size};

    // ASSUMPTION: the weight matrices are contiguous in memory.
    switch (m_activation) {
        case Activation::None: mlp_fused_forward<WIDTH, T, Activation::None, true>( stream, m_output_activation, input_weight_matrix(use_inference_params), input, inference_tmp, &output, m_n_hidden_matmuls); break;
        case Activation::Exponential: mlp_fused_forward<WIDTH, T, Activation::Exponential, true>(stream, m_output_activation, input_weight_matrix(use_inference_params), input, inference_tmp, &output, m_n_hidden_matmuls); break;
        case Activation::Sigmoid: mlp_fused_forward<WIDTH, T, Activation::Sigmoid, true>( stream, m_output_activation, input_weight_matrix(use_inference_params), input, inference_tmp, &output, m_n_hidden_matmuls); break;
        case Activation::ReLU: mlp_fused_forward<WIDTH, T, Activation::ReLU, true>( stream, m_output_activation, input_weight_matrix(use_inference_params), input, inference_tmp, &output, m_n_hidden_matmuls); break;
        case Activation::Squareplus: mlp_fused_forward<WIDTH, T, Activation::Squareplus, true>( stream, m_output_activation, input_weight_matrix(use_inference_params), input, inference_tmp, &output, m_n_hidden_matmuls); break;
        case Activation::Softplus: mlp_fused_forward<WIDTH, T, Activation::Softplus, true>( stream, m_output_activation, input_weight_matrix(use_inference_params), input, inference_tmp, &output, m_n_hidden_matmuls); break;
        case Activation::Tanh: mlp_fused_forward<WIDTH, T, Activation::Tanh, true>( stream, m_output_activation, input_weight_matrix(use_inference_params), input, inference_tmp, &output, m_n_hidden_matmuls); break;
        default: throw std::runtime_error{"Unsupported activation."};
    }

    // If the output is wider than 16, the last layer was not computed by the fully
    // fused kernel and must be handled here.
    if (m_output_width > 16) {
        fc_multiply<LastLayer>(stream, output_weight_matrix(use_inference_params), inference_tmp, output, m_output_activation);
    }
}

template <typename T, int WIDTH>
std::unique_ptr<Context> FullyFusedMLP<T, WIDTH>::forward_impl(cudaStream_t stream, const GPUMatrixDynamic<T>& input, GPUMatrixDynamic<T>* output, bool use_inference_params, bool prepare_input_gradients) {
    // Allocate the buffers that hold each layer's forward activations for use in the backward pass.
    uint32_t batch_size = input.n();
    auto forward = allocate_forward_buffers(stream, batch_size);

    // ASSUMPTION: the weight matrices are contiguous in memory.
    switch (m_activation) {
        case Activation::None: mlp_fused_forward<WIDTH, T, Activation::None, false>( stream, m_output_activation, input_weight_matrix(use_inference_params), input, forward->hidden.at(0), output, m_n_hidden_matmuls); break;
        case Activation::Exponential: mlp_fused_forward<WIDTH, T, Activation::Exponential, false>(stream, m_output_activation, input_weight_matrix(use_inference_params), input, forward->hidden.at(0), output, m_n_hidden_matmuls); break;
        case Activation::Sigmoid: mlp_fused_forward<WIDTH, T, Activation::Sigmoid, false>( stream, m_output_activation, input_weight_matrix(use_inference_params), input, forward->hidden.at(0), output, m_n_hidden_matmuls); break;
        case Activation::ReLU: mlp_fused_forward<WIDTH, T, Activation::ReLU, false>( stream, m_output_activation, input_weight_matrix(use_inference_params), input, forward->hidden.at(0), output, m_n_hidden_matmuls); break;
        case Activation::Squareplus: mlp_fused_forward<WIDTH, T, Activation::Squareplus, false>( stream, m_output_activation, input_weight_matrix(use_inference_params), input, forward->hidden.at(0), output, m_n_hidden_matmuls); break;
        case Activation::Softplus: mlp_fused_forward<WIDTH, T, Activation::Softplus, false>( stream, m_output_activation, input_weight_matrix(use_inference_params), input, forward->hidden.at(0), output, m_n_hidden_matmuls); break;
        case Activation::Tanh: mlp_fused_forward<WIDTH, T, Activation::Tanh, false>( stream, m_output_activation, input_weight_matrix(use_inference_params), input, forward->hidden.at(0), output, m_n_hidden_matmuls); break;
        default: throw std::runtime_error{"Unsupported activation."};
    }

    // If the output is wider than 16, the last layer was not computed by the fully
    // fused kernel and must be handled here.
    if (output && m_output_width > 16) {
        fc_multiply<LastLayer>(stream, output_weight_matrix(use_inference_params), forward->hidden.back(), *output, m_output_activation);
    }

    return forward;
}

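// The backward pass interleaves two kinds of work: the fused kernel computes the
// activation gradients layer by layer, while split-k CUTLASS matmuls on secondary
// streams (SyncedMultiStream) turn each layer's activation gradients and stored
// forward activations into weight gradients. split_k_factor partitions the batch
// dimension so that small batches still fill the GPU.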
template <typename T, int WIDTH>
void FullyFusedMLP<T, WIDTH>::backward_impl(
    cudaStream_t stream,
    const Context& ctx,
    const GPUMatrixDynamic<T>& input,
    const GPUMatrixDynamic<T>& output,
    const GPUMatrixDynamic<T>& dL_doutput,
    GPUMatrixDynamic<T>* dL_dinput,
    bool use_inference_params,
    EGradientMode param_gradients_mode
) {
    uint32_t batch_size = dL_doutput.n();

    // Temporary buffers for the activation gradients of each layer, backed by a single allocation.
    std::vector<GPUMatrix<T>> backward_tmp(num_forward_activations());
    for (uint32_t i = 0; i < num_forward_activations(); ++i) {
        backward_tmp[i].set_size_unsafe(m_network_width, batch_size);
    }
    auto backward_tmp_alloc = GPUMatrixBase::allocate_shared_memory(stream, backward_tmp);

    // If there is an output activation, backpropagate through it first. The result is
    // needed both for the last layer's weight gradient and as input to the fused kernel.
    GPUMatrixDynamic<T> backward_output_tmp;
    if (m_output_activation != Activation::None) {
        backward_output_tmp = {m_padded_output_width, batch_size, stream, dL_doutput.layout()};
        activation_backward_output_gpu(stream, dL_doutput.n_elements(), m_output_activation, output.data(), dL_doutput.data(), backward_output_tmp.data());
    }

    // Backprop:
    // - weight_gradient.T = activation * output_gradient.T
    // - input_gradient = weights.T * output_gradient
    // - activation backward: pre_activation_gradient = f'(pre_activation) * post_activation_gradient

    const float param_gradient_beta = param_gradients_mode == EGradientMode::Accumulate ? 1.0f : 0.0f;

    std::vector<SyncedMultiStream> multi_streams;

    const auto& forward = dynamic_cast<const ForwardContext&>(ctx);

    int split_k_factor = batch_size / std::min((uint32_t)(1 << 12), batch_size);

    const GPUMatrixDynamic<T>& tmp_dL_doutput = m_output_activation == Activation::None ? dL_doutput : backward_output_tmp;

    uint32_t tmp_idx = m_n_hidden_matmuls;
    uint32_t backward_tmp_idx = 0;

    // Output layer
    if (param_gradients_mode != EGradientMode::Ignore) {
        multi_streams.emplace_back(stream, 2);
        fc_multiply_split_k<LastLayerK>(multi_streams.back().get(1), tmp_dL_doutput, forward.hidden.at(tmp_idx).transposed(), output_gradient_matrix(), split_k_factor, param_gradient_beta);
    }

    // If the output is wider than 16, the last layer lies outside of the fused kernel
    // and must be backpropagated through separately.
    if (m_output_width > 16) {
        fc_multiply<FullLayer>(stream, output_weight_matrix(use_inference_params).transposed(), tmp_dL_doutput, forward.hidden.at(tmp_idx), backward_tmp.at(backward_tmp_idx), m_activation, true);
    }

    // The fused kernel can also compute input gradients, but only if the input has the same width & layout as the hidden layers.
    auto dL_dinput_fused = input.m() == forward.hidden.at(0).m() && input.layout() == CM ? dL_dinput : nullptr;

    // ASSUMPTION: the weight matrices and forward-activation buffers are contiguous in memory.
    switch (m_activation) {
        case Activation::None: mlp_fused_backward<WIDTH, T, Activation::None>( stream, input_weight_matrix(use_inference_params), weight_matrix_at(use_inference_params, 0), tmp_dL_doutput, backward_tmp.at(backward_tmp_idx), forward.hidden.at(0), dL_dinput_fused, m_n_hidden_matmuls); break;
        case Activation::Exponential: mlp_fused_backward<WIDTH, T, Activation::Exponential>(stream, input_weight_matrix(use_inference_params), weight_matrix_at(use_inference_params, 0), tmp_dL_doutput, backward_tmp.at(backward_tmp_idx), forward.hidden.at(0), dL_dinput_fused, m_n_hidden_matmuls); break;
        case Activation::Sigmoid: mlp_fused_backward<WIDTH, T, Activation::Sigmoid>( stream, input_weight_matrix(use_inference_params), weight_matrix_at(use_inference_params, 0), tmp_dL_doutput, backward_tmp.at(backward_tmp_idx), forward.hidden.at(0), dL_dinput_fused, m_n_hidden_matmuls); break;
        case Activation::ReLU: mlp_fused_backward<WIDTH, T, Activation::ReLU>( stream, input_weight_matrix(use_inference_params), weight_matrix_at(use_inference_params, 0), tmp_dL_doutput, backward_tmp.at(backward_tmp_idx), forward.hidden.at(0), dL_dinput_fused, m_n_hidden_matmuls); break;
        case Activation::Squareplus: mlp_fused_backward<WIDTH, T, Activation::Squareplus>( stream, input_weight_matrix(use_inference_params), weight_matrix_at(use_inference_params, 0), tmp_dL_doutput, backward_tmp.at(backward_tmp_idx), forward.hidden.at(0), dL_dinput_fused, m_n_hidden_matmuls); break;
        case Activation::Softplus: mlp_fused_backward<WIDTH, T, Activation::Softplus>( stream, input_weight_matrix(use_inference_params), weight_matrix_at(use_inference_params, 0), tmp_dL_doutput, backward_tmp.at(backward_tmp_idx), forward.hidden.at(0), dL_dinput_fused, m_n_hidden_matmuls); break;
        case Activation::Tanh: mlp_fused_backward<WIDTH, T, Activation::Tanh>( stream, input_weight_matrix(use_inference_params), weight_matrix_at(use_inference_params, 0), tmp_dL_doutput, backward_tmp.at(backward_tmp_idx), forward.hidden.at(0), dL_dinput_fused, m_n_hidden_matmuls); break;
        default: throw std::runtime_error{"Unsupported activation."};
    }

    tmp_idx -= 1;
    ++backward_tmp_idx;

    // Hidden layers
    for (uint32_t i = 0; i < m_n_hidden_matmuls; ++i) {
        uint32_t matrix_idx = m_n_hidden_matmuls - i - 1;

        if (param_gradients_mode != EGradientMode::Ignore) {
            multi_streams.emplace_back(stream, 2);
            fc_multiply_split_k<FullLayerK>(multi_streams.back().get(1), backward_tmp.at(backward_tmp_idx-1), forward.hidden.at(tmp_idx).transposed(), gradient_matrix_at(matrix_idx), split_k_factor, param_gradient_beta);
        }

        tmp_idx -= 1;
        ++backward_tmp_idx;
    }

    if (param_gradients_mode != EGradientMode::Ignore) {
        multi_streams.emplace_back(stream, 2);
        fc_multiply_split_k<FullLayerK>(multi_streams.back().get(1), backward_tmp.at(backward_tmp_idx-1), input.transposed(), input_gradient_matrix(), split_k_factor, param_gradient_beta);
    }

    // If requested, and if the fully fused kernel didn't already take care of it,
    // compute the loss gradients w.r.t. the input.
    if (dL_dinput && !dL_dinput_fused) {
        fc_multiply<FullLayer>(stream, input_weight_matrix(use_inference_params).transposed(), backward_tmp.at(backward_tmp_idx-1), *dL_dinput);
    }
}

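// The per-layer activation buffers are sized with set_size_unsafe (no allocation of
// their own) and then backed by one contiguous allocation shared among the matrices
// (note: "shared memory" here refers to a shared allocation, not CUDA __shared__
// memory). The fused kernels rely on this contiguity when indexing across layers
// with a single base pointer.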
template <typename T, int WIDTH>
std::unique_ptr<typename FullyFusedMLP<T, WIDTH>::ForwardContext> FullyFusedMLP<T, WIDTH>::allocate_forward_buffers(cudaStream_t stream, uint32_t batch_size) {
    auto forward = std::make_unique<ForwardContext>();

    forward->hidden.resize(num_forward_activations());
    for (uint32_t i = 0; i < num_forward_activations(); ++i) {
        forward->hidden[i].set_size_unsafe(m_network_width, batch_size);
    }

    forward->alloc = GPUMatrixBase::allocate_shared_memory(stream, forward->hidden);

    return forward;
}

template <typename T, int WIDTH>
void FullyFusedMLP<T, WIDTH>::set_params_impl(T* params, T* inference_params, T* gradients) {
    size_t current_pos = 0;
    for (size_t i = 0; i < m_weight_matrices.size(); ++i) {
        m_weight_matrices[i].set_data_unsafe(params + current_pos);
        m_weight_matrices_inference[i].set_data_unsafe(inference_params + current_pos);
        m_gradient_matrices[i].set_data_unsafe(gradients + current_pos);
        current_pos += m_weight_matrices[i].n_elements();
    }
}

template <typename T, int WIDTH>
void FullyFusedMLP<T, WIDTH>::initialize_params(pcg32& rnd, float* params_full_precision, float scale) {
    // Construct full-precision matrix views over the given parameter buffer,
    // mirroring the layout established in the constructor.
    std::vector<GPUMatrix<float, RM>> weight_matrices_full_precision;
    weight_matrices_full_precision.emplace_back(params_full_precision, m_network_width, m_input_width);
    params_full_precision += weight_matrices_full_precision.back().n_elements();

    for (uint32_t i = 0; i < m_n_hidden_matmuls; ++i) {
        weight_matrices_full_precision.emplace_back(params_full_precision, m_network_width, m_network_width);
        params_full_precision += weight_matrices_full_precision.back().n_elements();
    }

    weight_matrices_full_precision.emplace_back(params_full_precision, m_padded_output_width, m_network_width);

    // Initialize the params: SIREN-style uniform init for sine activations, Xavier uniform otherwise.
    for (size_t i = 0; i < weight_matrices_full_precision.size(); ++i) {
        if (m_activation == Activation::Sine) {
            if (i == 0) {
                weight_matrices_full_precision[i].initialize_siren_uniform_first(rnd, scale);
            } else {
                weight_matrices_full_precision[i].initialize_siren_uniform(rnd, scale);
            }
        } else {
            weight_matrices_full_precision[i].initialize_xavier_uniform(rnd, scale);
        }
    }
}

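// Explicit instantiations for the supported network widths. Other widths require an
// additional instantiation here (and remain subject to the WIDTH % 16 == 0 constraint).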
template class FullyFusedMLP<network_precision_t, 128>;
template class FullyFusedMLP<network_precision_t, 64>;
template class FullyFusedMLP<network_precision_t, 32>;
template class FullyFusedMLP<network_precision_t, 16>;

TCNN_NAMESPACE_END