#include <torch/extension.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda_runtime.h>

#include <cublas_v2.h>

#include <cudnn.h>

#include <cmath>
#include <limits>
|
|
|
|
|
// Persistent tiled SGEMM with optional bias:
//   C[m, n] = sum_k A[m, k] * B[k, n]  (+ bias[n])
//
// Launch layout / preconditions:
//   * blockDim must be (16, 16, 1), because the shared-memory staging tiles are
//     16 x 16. BLOCK_SIZE_M and BLOCK_SIZE_N must be positive. The kernel traps
//     (error reported at the next synchronizing call) if either is violated.
//   * Any 1-D grid size works: block b handles output tiles b, b + gridDim.x,
//     b + 2*gridDim.x, ... so every tile is covered regardless of grid size.
//   * Each output tile is BLOCK_SIZE_M x BLOCK_SIZE_N. The block sweeps it in
//     16 x 16 sub-tiles, so the tile sizes need not be multiples of 16.
//   * Tiles are visited in groups of GROUP_SIZE_M tile-rows; a non-positive
//     GROUP_SIZE_M is treated as 1.
//
// Parameters:
//   a_ptr, b_ptr, c_ptr  A (M x K), B (K x N), C (M x N), addressed through the
//                        given element strides (strided/transposed views work).
//   bias_ptr             added per output column (bias_ptr[n], unit stride) when
//                        HAS_BIAS is true and the pointer is non-null.
//   BLOCK_SIZE_K         accepted for interface compatibility. K is always staged
//                        through shared memory in 16-wide slices.
//   NUM_SMS              accepted for interface compatibility. The tile stride is
//                        gridDim.x.
__global__ void matmul_kernel_persistent(
    const float *a_ptr,
    const float *b_ptr,
    float *c_ptr,
    const float *bias_ptr,
    int M, int N, int K,
    int stride_am, int stride_ak,
    int stride_bk, int stride_bn,
    int stride_cm, int stride_cn,
    int BLOCK_SIZE_M, int BLOCK_SIZE_N, int BLOCK_SIZE_K,
    int GROUP_SIZE_M, int NUM_SMS,
    bool HAS_BIAS)
{
    constexpr int TILE = 16;

    (void)BLOCK_SIZE_K;
    (void)NUM_SMS;

    // Block-uniform preconditions, so trapping cannot split the block.
    if (blockDim.x != TILE || blockDim.y != TILE || blockDim.z != 1 ||
        BLOCK_SIZE_M <= 0 || BLOCK_SIZE_N <= 0)
    {
        __trap();
    }

    __shared__ float As[TILE][TILE];
    __shared__ float Bs[TILE][TILE];

    const int tx = threadIdx.x; // column offset inside a 16 x 16 sub-tile
    const int ty = threadIdx.y; // row offset inside a 16 x 16 sub-tile

    const int num_pid_m = (M + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M;
    const int num_pid_n = (N + BLOCK_SIZE_N - 1) / BLOCK_SIZE_N;
    const int num_tiles = num_pid_m * num_pid_n;
    const int group_rows = GROUP_SIZE_M > 0 ? GROUP_SIZE_M : 1; // avoid division by zero
    const int num_pid_in_group = group_rows * num_pid_n;
    const bool add_bias = HAS_BIAS && bias_ptr != nullptr;

    // Every loop bound below depends only on block-uniform values, so all
    // threads of the block reach each __syncthreads() together.
    for (int tile_id = blockIdx.x; tile_id < num_tiles; tile_id += static_cast<int>(gridDim.x))
    {
        const int group_id = tile_id / num_pid_in_group;
        const int first_pid_m = group_id * group_rows;
        const int group_size_m = min(num_pid_m - first_pid_m, group_rows);
        const int pid_m = first_pid_m + (tile_id % group_size_m);
        const int pid_n = (tile_id % num_pid_in_group) / group_size_m;

        const int start_m = pid_m * BLOCK_SIZE_M;
        const int start_n = pid_n * BLOCK_SIZE_N;

        // Sweep the BLOCK_SIZE_M x BLOCK_SIZE_N tile in 16 x 16 pieces, skipping
        // pieces that start outside C.
        for (int sub_m = 0; sub_m < BLOCK_SIZE_M && start_m + sub_m < M; sub_m += TILE)
        {
            for (int sub_n = 0; sub_n < BLOCK_SIZE_N && start_n + sub_n < N; sub_n += TILE)
            {
                const int row = start_m + sub_m + ty;
                const int col = start_n + sub_n + tx;
                // Stay inside both C and this tile; a tile size that is not a
                // multiple of 16 must not spill into the neighbouring tile.
                const bool row_ok = (sub_m + ty < BLOCK_SIZE_M) && row < M;
                const bool col_ok = (sub_n + tx < BLOCK_SIZE_N) && col < N;

                float acc = 0.0f;
                for (int k0 = 0; k0 < K; k0 += TILE)
                {
                    // Stage A[row, k0 + tx] and B[k0 + ty, col]. Out-of-range slots
                    // are zero-filled so the fixed-length product needs no checks.
                    As[ty][tx] = (row_ok && k0 + tx < K)
                                     ? a_ptr[static_cast<long long>(row) * stride_am +
                                             static_cast<long long>(k0 + tx) * stride_ak]
                                     : 0.0f;
                    Bs[ty][tx] = (col_ok && k0 + ty < K)
                                     ? b_ptr[static_cast<long long>(k0 + ty) * stride_bk +
                                             static_cast<long long>(col) * stride_bn]
                                     : 0.0f;
                    __syncthreads();

#pragma unroll
                    for (int k = 0; k < TILE; ++k)
                    {
                        acc += As[ty][k] * Bs[k][tx];
                    }
                    // The next iteration overwrites As/Bs, so wait for all readers.
                    __syncthreads();
                }

                if (row_ok && col_ok)
                {
                    if (add_bias)
                    {
                        acc += bias_ptr[col];
                    }
                    c_ptr[static_cast<long long>(row) * stride_cm +
                          static_cast<long long>(col) * stride_cn] = acc;
                }
            }
        }
    }
}
|
|
|
|
|
// Block-wide reduction helper for log_softmax_kernel.
// Every thread contributes `value` and every thread receives the `op`-folded
// result. It works for any blockDim.x, power of two or not. `scratch` must hold
// at least blockDim.x floats. It contains __syncthreads(), so all threads of the
// block must call it.
template <typename Op>
__device__ __forceinline__ float log_softmax_block_reduce(float *scratch, float value, Op op)
{
    const unsigned int tid = threadIdx.x;
    scratch[tid] = value;
    __syncthreads();

    // Fold the upper part of the active range onto the lower part. Rounding the
    // half up keeps the middle element of an odd-sized range instead of dropping it.
    for (unsigned int active = blockDim.x; active > 1;)
    {
        const unsigned int half = (active + 1) / 2;
        if (tid < active - half)
        {
            scratch[tid] = op(scratch[tid], scratch[tid + half]);
        }
        __syncthreads();
        active = half;
    }

    const float result = scratch[0];
    // Keep scratch[0] stable until every thread has read it, so the caller may
    // reuse the buffer immediately.
    __syncthreads();
    return result;
}

// Row-wise log-softmax: out[r, c] = x[r, c] - max_r - log(sum_c exp(x[r, c] - max_r)).
//
// Launch layout: one block per row (blockIdx.x = row), any 1-D block size up to
// the CUDA limit of 1024 threads. Columns are visited with a blockDim.x stride.
// Rows are addressed with the given row strides; columns are assumed unit-stride.
// BLOCK_SIZE is accepted for interface compatibility; the kernel uses blockDim.x.
__global__ void log_softmax_kernel(
    const float *input_ptr,
    float *output_ptr,
    int input_row_stride,
    int output_row_stride,
    int n_cols,
    int BLOCK_SIZE)
{
    // One scratch slot per thread; 1024 is the maximum threads per block.
    __shared__ float scratch[1024];

    (void)BLOCK_SIZE;

    const int tid = threadIdx.x;
    const int step = blockDim.x;
    // 64-bit row offsets: row * stride can exceed INT_MAX on large inputs.
    const float *row_in = input_ptr + static_cast<long long>(blockIdx.x) * input_row_stride;
    float *row_out = output_ptr + static_cast<long long>(blockIdx.x) * output_row_stride;

    // Pass 1: the row maximum, subtracted before exponentiating to avoid overflow.
    float local_max = -INFINITY;
    for (int col = tid; col < n_cols; col += step)
    {
        local_max = fmaxf(local_max, row_in[col]);
    }
    const float row_max = log_softmax_block_reduce(
        scratch, local_max, [](float x, float y) { return fmaxf(x, y); });

    // Pass 2: sum of shifted exponentials.
    float local_sum = 0.0f;
    for (int col = tid; col < n_cols; col += step)
    {
        local_sum += expf(row_in[col] - row_max);
    }
    const float row_sum = log_softmax_block_reduce(
        scratch, local_sum, [](float x, float y) { return x + y; });
    const float log_sum_exp = logf(row_sum);

    // Pass 3: write the normalized log-probabilities.
    for (int col = tid; col < n_cols; col += step)
    {
        row_out[col] = row_in[col] - row_max - log_sum_exp;
    }
}
|
|
|
|
|
// Mean over the middle axis of a strided (M, N, K) view:
//   output[m, k] = (sum_{n < N} input[m, n, k]) / N
// The input is addressed with (input_stride0, input_stride1, input_stride2) and
// the output with (output_stride0, output_stride1).
//
// Launch layout: any 1-D grid / 1-D block. Each thread handles output elements
// idx, idx + gridDim.x * blockDim.x, ... so full coverage does not depend on the
// grid size. Index arithmetic is 64-bit so inputs larger than 2^31 elements do
// not overflow. N == 0 yields 0 / 0 = NaN, as in the original formulation.
// BLOCK_SIZE is accepted for interface compatibility; the kernel uses blockDim.x.
__global__ void mean_kernel(
    const float *input_ptr,
    float *output_ptr,
    int input_stride0, int input_stride1, int input_stride2,
    int output_stride0, int output_stride1,
    int M, int N, int K,
    int BLOCK_SIZE)
{
    (void)BLOCK_SIZE;

    const long long total = static_cast<long long>(M) * K;
    const long long step = static_cast<long long>(gridDim.x) * blockDim.x;

    for (long long idx = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
         idx < total;
         idx += step)
    {
        const long long m_idx = idx / K;
        const long long k_idx = idx % K;

        // Walk the reduced axis of this (m, k) column.
        const float *column = input_ptr + m_idx * input_stride0 + k_idx * input_stride2;
        float acc = 0.0f;
        for (int n = 0; n < N; ++n)
        {
            acc += column[static_cast<long long>(n) * input_stride1];
        }

        output_ptr[m_idx * output_stride0 + k_idx * output_stride1] = acc / N;
    }
}
|
|
|
|
|
// Host launcher for matmul_kernel_persistent: c = a @ b (+ bias), float32.
//
//   a    : CUDA float32 tensor of shape (M, K); any strides.
//   b    : CUDA float32 tensor of shape (K, N) on a's device; any strides.
//   c    : preallocated CUDA float32 tensor of shape (M, N) on a's device. It is
//          written through its strides, so non-contiguous outputs work.
//   bias : optional float32 tensor of shape (N,) on a's device; pass an
//          undefined tensor for no bias.
//
// The kernel is enqueued on PyTorch's current CUDA stream for a's device and
// runs asynchronously. Launch-configuration errors raise immediately; in-kernel
// faults surface at the next synchronizing call.
void matmul_persistent_cuda(
    torch::Tensor const &a,
    torch::Tensor const &b,
    torch::Tensor &c,
    torch::Tensor const &bias)
{
    TORCH_CHECK(a.is_cuda() && b.is_cuda() && c.is_cuda(),
                "matmul_persistent_cuda: a, b and c must be CUDA tensors");
    TORCH_CHECK(b.device() == a.device() && c.device() == a.device(),
                "matmul_persistent_cuda: a, b and c must be on the same device");
    TORCH_CHECK(a.scalar_type() == torch::kFloat32 && b.scalar_type() == torch::kFloat32 &&
                    c.scalar_type() == torch::kFloat32,
                "matmul_persistent_cuda: a, b and c must be float32");
    TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2,
                "matmul_persistent_cuda: a, b and c must be 2-D");

    const int64_t M = a.size(0);
    const int64_t K = a.size(1);
    const int64_t N = b.size(1);
    TORCH_CHECK(b.size(0) == K,
                "matmul_persistent_cuda: inner dimensions differ, a is ", a.sizes(),
                " and b is ", b.sizes());
    TORCH_CHECK(c.size(0) == M && c.size(1) == N,
                "matmul_persistent_cuda: c must have shape (", M, ", ", N, "), got ", c.sizes());
    constexpr int64_t kIntMax = std::numeric_limits<int>::max();
    TORCH_CHECK(M <= kIntMax && N <= kIntMax && K <= kIntMax,
                "matmul_persistent_cuda: dimensions must fit in a 32-bit int");

    // The kernel reads bias[col] with unit stride, so pass a contiguous tensor.
    torch::Tensor bias_contig;
    if (bias.defined())
    {
        TORCH_CHECK(bias.device() == a.device() && bias.scalar_type() == torch::kFloat32,
                    "matmul_persistent_cuda: bias must be a float32 tensor on a's device");
        TORCH_CHECK(bias.dim() == 1 && bias.size(0) == N,
                    "matmul_persistent_cuda: bias must have shape (", N, ",), got ", bias.sizes());
        bias_contig = bias.contiguous();
    }

    // An empty C needs no work, and a zero-block grid is an invalid launch.
    if (M == 0 || N == 0)
    {
        return;
    }

    // Make a's device current so the SM count and the stream below refer to it.
    const c10::cuda::CUDAGuard device_guard(a.device());
    // Cached device properties (cudaGetDeviceProperties is expensive per call).
    const int NUM_SMS = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;

    const int BLOCK_SIZE_M = 128;
    const int BLOCK_SIZE_N = 128;
    const int BLOCK_SIZE_K = 64;
    const int GROUP_SIZE_M = 8;

    const int64_t num_pid_m = (M + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M;
    const int64_t num_pid_n = (N + BLOCK_SIZE_N - 1) / BLOCK_SIZE_N;
    const int64_t num_tiles = num_pid_m * num_pid_n;
    // Persistent launch: at most one block per SM, never more blocks than tiles.
    const int grid_size = static_cast<int>(num_tiles < NUM_SMS ? num_tiles : NUM_SMS);

    // 16 x 16 threads, matching the kernel's shared-memory tile edge.
    const dim3 block(16, 16);
    const dim3 grid_dim(grid_size);
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    matmul_kernel_persistent<<<grid_dim, block, 0, stream>>>(
        a.data_ptr<float>(),
        b.data_ptr<float>(),
        c.data_ptr<float>(),
        bias_contig.defined() ? bias_contig.data_ptr<float>() : nullptr,
        static_cast<int>(M), static_cast<int>(N), static_cast<int>(K),
        static_cast<int>(a.stride(0)), static_cast<int>(a.stride(1)),
        static_cast<int>(b.stride(0)), static_cast<int>(b.stride(1)),
        static_cast<int>(c.stride(0)), static_cast<int>(c.stride(1)),
        BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K,
        GROUP_SIZE_M, NUM_SMS,
        bias_contig.defined());

    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "matmul_kernel_persistent launch failed: ", cudaGetErrorString(launch_status));
}
|
|
|
// Host launcher for log_softmax_kernel: output = log_softmax(input, dim=-1).
//
//   input  : CUDA float32 tensor with at least one dimension. It is flattened to
//            (rows, input.size(-1)).
//   output : preallocated CUDA float32 tensor on input's device with the same
//            number of elements and the same last dimension. Non-contiguous
//            outputs are handled by computing into a temporary and copying back.
//
// Enqueued on PyTorch's current CUDA stream for input's device (asynchronous).
void log_softmax_cuda(
    torch::Tensor const &input,
    torch::Tensor &output)
{
    TORCH_CHECK(input.is_cuda() && output.is_cuda(),
                "log_softmax_cuda: input and output must be CUDA tensors");
    TORCH_CHECK(output.device() == input.device(),
                "log_softmax_cuda: input and output must be on the same device");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32 && output.scalar_type() == torch::kFloat32,
                "log_softmax_cuda: input and output must be float32");
    TORCH_CHECK(input.dim() >= 1 && output.dim() >= 1,
                "log_softmax_cuda: input and output must have at least one dimension");
    TORCH_CHECK(output.numel() == input.numel() && output.size(-1) == input.size(-1),
                "log_softmax_cuda: output ", output.sizes(), " is incompatible with input ",
                input.sizes());

    // Nothing to compute; this also avoids reshaping with a zero-size last dim.
    if (input.numel() == 0)
    {
        return;
    }

    const int64_t n_cols = input.size(-1);
    const int64_t n_rows = input.numel() / n_cols;
    constexpr int64_t kIntMax = std::numeric_limits<int>::max();
    TORCH_CHECK(n_rows <= kIntMax && n_cols <= kIntMax,
                "log_softmax_cuda: row/column counts must fit in a 32-bit int");

    const c10::cuda::CUDAGuard device_guard(input.device());

    const auto input_2d = input.contiguous().view({n_rows, n_cols});
    // reshape() on a non-contiguous output would return a copy and the results
    // would be lost, so write into a contiguous buffer and copy back if needed.
    const bool write_in_place = output.is_contiguous();
    auto output_2d = write_in_place ? output.view({n_rows, n_cols})
                                    : torch::empty({n_rows, n_cols}, output.options());

    const int BLOCK_SIZE = 256;
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    // One block per row.
    log_softmax_kernel<<<static_cast<unsigned int>(n_rows), BLOCK_SIZE, 0, stream>>>(
        input_2d.data_ptr<float>(),
        output_2d.data_ptr<float>(),
        static_cast<int>(input_2d.stride(0)),
        static_cast<int>(output_2d.stride(0)),
        static_cast<int>(n_cols),
        BLOCK_SIZE);

    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "log_softmax_kernel launch failed: ", cudaGetErrorString(launch_status));

    if (!write_in_place)
    {
        output.copy_(output_2d.view(output.sizes()));
    }
}
|
|
|
// Host launcher for mean_kernel: output = input.mean(dim).
//
//   input  : CUDA float32 tensor with at least one dimension.
//   output : preallocated CUDA float32 tensor on input's device with
//            input.numel() / input.size(dim) elements (for example input's shape
//            with `dim` removed or kept as size 1). Non-contiguous outputs are
//            handled by computing into a temporary and copying back.
//   dim    : reduced dimension; negative values count from the end, as in PyTorch.
//
// The input is viewed as (M, N, K) with N = input.size(dim). M is the product of
// the leading sizes and K the product of the trailing sizes. Enqueued on
// PyTorch's current CUDA stream for input's device (asynchronous).
void mean_dim_cuda(
    torch::Tensor const &input,
    torch::Tensor &output,
    int dim)
{
    TORCH_CHECK(input.is_cuda() && output.is_cuda(),
                "mean_dim_cuda: input and output must be CUDA tensors");
    TORCH_CHECK(output.device() == input.device(),
                "mean_dim_cuda: input and output must be on the same device");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32 && output.scalar_type() == torch::kFloat32,
                "mean_dim_cuda: input and output must be float32");

    const int64_t ndim = input.dim();
    TORCH_CHECK(ndim >= 1, "mean_dim_cuda: input must have at least one dimension");
    const int64_t axis = dim < 0 ? dim + ndim : dim;
    TORCH_CHECK(axis >= 0 && axis < ndim,
                "mean_dim_cuda: dim ", dim, " is out of range for a ", ndim, "-D tensor");

    const auto shape = input.sizes();
    int64_t M = 1;
    for (int64_t i = 0; i < axis; ++i)
    {
        M *= shape[i];
    }
    const int64_t N = shape[axis];
    int64_t K = 1;
    for (int64_t i = axis + 1; i < ndim; ++i)
    {
        K *= shape[i];
    }

    TORCH_CHECK(output.numel() == M * K,
                "mean_dim_cuda: output must have ", M * K, " elements, got ", output.numel());
    constexpr int64_t kIntMax = std::numeric_limits<int>::max();
    TORCH_CHECK(N <= kIntMax && M * K <= kIntMax,
                "mean_dim_cuda: sizes must fit in a 32-bit int");

    // Empty output: nothing to write, and a zero-block grid is an invalid launch.
    if (M * K == 0)
    {
        return;
    }

    const c10::cuda::CUDAGuard device_guard(input.device());

    // reshape() may return a copy; that is fine for the read-only input.
    const auto input_3d = input.reshape({M, N, K});
    // For the output a copy would silently drop the results, so write into a
    // contiguous buffer and copy back when the output is not contiguous.
    const bool write_in_place = output.is_contiguous();
    auto output_2d = write_in_place ? output.view({M, K})
                                    : torch::empty({M, K}, output.options());

    const int BLOCK_SIZE = 256;
    const int64_t grid_size = (M * K + BLOCK_SIZE - 1) / BLOCK_SIZE;
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    mean_kernel<<<static_cast<unsigned int>(grid_size), BLOCK_SIZE, 0, stream>>>(
        input_3d.data_ptr<float>(),
        output_2d.data_ptr<float>(),
        static_cast<int>(input_3d.stride(0)),
        static_cast<int>(input_3d.stride(1)),
        static_cast<int>(input_3d.stride(2)),
        static_cast<int>(output_2d.stride(0)),
        static_cast<int>(output_2d.stride(1)),
        static_cast<int>(M), static_cast<int>(N), static_cast<int>(K),
        BLOCK_SIZE);

    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "mean_kernel launch failed: ", cudaGetErrorString(launch_status));

    if (!write_in_place)
    {
        output.copy_(output_2d.view(output.sizes()));
    }
}