/*
StyleForge - Test CUDA Kernels
Simple kernels for verifying CUDA compilation and testing
optimization techniques.
*/
#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <iostream>
#include <stdexcept>
// -------------------------------------------------------------------------
// Error checking macro
// -------------------------------------------------------------------------
// Checks the cudaError_t returned by a runtime API call (or by
// cudaGetLastError() after a kernel launch) and throws on failure.
// - (call) is parenthesized so operator precedence cannot split the
//   expression when the macro argument is compound.
// - The local uses an unlikely name to avoid shadowing caller variables.
#define CUDA_CHECK(call) \
do { \
    cudaError_t cuda_check_err_ = (call); \
    if (cuda_check_err_ != cudaSuccess) { \
        std::cerr << "CUDA error at " << __FILE__ << ":" << __LINE__ \
                  << ": " << cudaGetErrorString(cuda_check_err_) << std::endl; \
        throw std::runtime_error(cudaGetErrorString(cuda_check_err_)); \
    } \
} while(0)
// -------------------------------------------------------------------------
// Kernel 1: Simple element-wise multiplication
// -------------------------------------------------------------------------
// One thread per element: c[i] = a[i] * b[i].
// Launch with enough threads to cover `size`; out-of-range threads exit early.
__global__ void multiply_kernel(
    const float* __restrict__ a,
    const float* __restrict__ b,
    float* __restrict__ c,
    int size
) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= size) {
        return;  // grid tail: nothing to do
    }
    c[i] = a[i] * b[i];
}
// Element-wise multiply of two float32 CUDA tensors: returns a * b.
// Inputs must be on CUDA, float32, contiguous, and have the same numel.
// Launch errors are surfaced via CUDA_CHECK(cudaGetLastError()).
torch::Tensor multiply_cuda(torch::Tensor a, torch::Tensor b) {
    TORCH_CHECK(a.device().is_cuda(), "Input a must be on CUDA");
    TORCH_CHECK(b.device().is_cuda(), "Input b must be on CUDA");
    TORCH_CHECK(a.dtype() == torch::kFloat32, "Input a must be float32");
    TORCH_CHECK(b.dtype() == torch::kFloat32, "Input b must be float32");
    TORCH_CHECK(a.numel() == b.numel(), "Inputs must have the same number of elements");
    TORCH_CHECK(a.is_contiguous() && b.is_contiguous(), "Inputs must be contiguous");

    // empty_like suffices: the kernel writes every element.
    auto c = torch::empty_like(a);
    const int size = static_cast<int>(a.numel());
    if (size == 0) {
        return c;  // a <<<0, threads>>> launch is an invalid-configuration error
    }

    const int threads = 256;
    const int blocks = (size + threads - 1) / threads;  // ceil-div
    multiply_kernel<<<blocks, threads>>>(
        a.data_ptr<float>(),
        b.data_ptr<float>(),
        c.data_ptr<float>(),
        size
    );
    CUDA_CHECK(cudaGetLastError());
    return c;
}
// -------------------------------------------------------------------------
// Kernel 2: Vectorized element-wise multiplication (float4)
// -------------------------------------------------------------------------
// Element-wise multiply with float4 vectorized 128-bit loads/stores.
// Each thread handles 4 consecutive floats. The final (partial) group of
// fewer than 4 elements is handled scalar-wise — the previous version
// silently skipped it, leaving the tail of c uncomputed when size % 4 != 0.
// NOTE(review): the float4 path assumes a, b, c are 16-byte aligned — true
// for freshly allocated contiguous torch tensors, but confirm for views.
__global__ void multiply_vectorized_kernel(
    const float* __restrict__ a,
    const float* __restrict__ b,
    float* __restrict__ c,
    int size
) {
    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
    if (idx + 3 < size) {
        // Full group: one 128-bit load per operand, one 128-bit store.
        float4 a4 = reinterpret_cast<const float4*>(a)[idx / 4];
        float4 b4 = reinterpret_cast<const float4*>(b)[idx / 4];
        float4 c4;
        c4.x = a4.x * b4.x;
        c4.y = a4.y * b4.y;
        c4.z = a4.z * b4.z;
        c4.w = a4.w * b4.w;
        reinterpret_cast<float4*>(c)[idx / 4] = c4;
    } else {
        // Tail group: 1-3 remaining elements, processed one float at a time.
        for (int i = idx; i < size; ++i) {
            c[i] = a[i] * b[i];
        }
    }
}
// Element-wise multiply using the float4-vectorized kernel.
// Inputs must be on CUDA, float32, contiguous, and have the same numel.
torch::Tensor multiply_vectorized_cuda(torch::Tensor a, torch::Tensor b) {
    TORCH_CHECK(a.device().is_cuda(), "Input a must be on CUDA");
    TORCH_CHECK(b.device().is_cuda(), "Input b must be on CUDA");
    TORCH_CHECK(a.dtype() == torch::kFloat32, "Input a must be float32");
    TORCH_CHECK(b.dtype() == torch::kFloat32, "Input b must be float32");
    TORCH_CHECK(a.numel() == b.numel(), "Inputs must have the same number of elements");
    TORCH_CHECK(a.is_contiguous() && b.is_contiguous(), "Inputs must be contiguous");

    // zeros_like (not empty_like) so any element the kernel does not write
    // is at least deterministic.
    auto c = torch::zeros_like(a);
    const int size = static_cast<int>(a.numel());
    if (size == 0) {
        return c;  // avoid an invalid 0-block launch
    }

    const int threads = 256;
    // Each thread covers 4 elements; round the group count UP so the thread
    // responsible for a partial tail group is actually launched.
    // (The previous floor division size/4 dropped the tail thread entirely.)
    const int groups = (size + 3) / 4;
    const int blocks = (groups + threads - 1) / threads;
    multiply_vectorized_kernel<<<blocks, threads>>>(
        a.data_ptr<float>(),
        b.data_ptr<float>(),
        c.data_ptr<float>(),
        size
    );
    CUDA_CHECK(cudaGetLastError());
    return c;
}
// -------------------------------------------------------------------------
// Kernel 3: Shared memory reduction (sum)
// -------------------------------------------------------------------------
// Block-level tree reduction in shared memory.
// Each block sums BLOCK_SIZE consecutive input elements (zero-padding past
// `size`) and writes one partial sum to output[blockIdx.x].
// Preconditions: blockDim.x == BLOCK_SIZE, and BLOCK_SIZE is a power of two
// (the halving loop only covers all slots for powers of two).
template<int BLOCK_SIZE>
__global__ void sum_kernel(
    const float* __restrict__ input,
    float* __restrict__ output,
    int size
) {
    __shared__ float smem[BLOCK_SIZE];

    const int tid = threadIdx.x;
    const int gid = blockIdx.x * blockDim.x + threadIdx.x;

    // Stage one element per thread into shared memory; out-of-range lanes
    // contribute the additive identity.
    smem[tid] = (gid < size) ? input[gid] : 0.0f;
    __syncthreads();

    // Pairwise tree reduction: halve the active range each step.
    #pragma unroll
    for (int stride = BLOCK_SIZE / 2; stride > 0; stride >>= 1) {
        if (tid < stride) {
            smem[tid] += smem[tid + stride];
        }
        __syncthreads();  // kept outside the divergent if: all threads reach it
    }

    // Lane 0 holds the block's total.
    if (tid == 0) {
        output[blockIdx.x] = smem[0];
    }
}
// Sum-reduce a float32 CUDA tensor to a scalar tensor.
// First level: one partial sum per block via sum_kernel; final level:
// torch's own .sum() over the (small) partial-sum vector.
torch::Tensor sum_cuda(torch::Tensor input) {
    TORCH_CHECK(input.device().is_cuda(), "Input must be on CUDA");
    TORCH_CHECK(input.dtype() == torch::kFloat32, "Input must be float32");
    TORCH_CHECK(input.is_contiguous(), "Input must be contiguous");

    const int size = static_cast<int>(input.numel());
    // Empty input: the sum is 0, and launching with 0 blocks would be an
    // invalid-configuration error.
    if (size == 0) {
        return torch::zeros({}, torch::dtype(torch::kFloat32).device(input.device()));
    }

    constexpr int BLOCK_SIZE = 256;  // must match the kernel template argument
    const int blocks = (size + BLOCK_SIZE - 1) / BLOCK_SIZE;

    // One partial sum per block.
    auto partial_sums = torch::zeros({blocks}, torch::dtype(torch::kFloat32).device(input.device()));

    sum_kernel<BLOCK_SIZE><<<blocks, BLOCK_SIZE>>>(
        input.data_ptr<float>(),
        partial_sums.data_ptr<float>(),
        size
    );
    CUDA_CHECK(cudaGetLastError());

    // Final reduction via torch (could be a second kernel pass instead).
    return partial_sums.sum();
}
// -------------------------------------------------------------------------
// Kernel 4: Fused multiply-add (a * b + c)
// -------------------------------------------------------------------------
// Fused multiply-add: d[i] = a[i] * b[i] + c[i], one thread per element.
// fmaf() guarantees a single fused FMA instruction regardless of the
// compiler's -fmad setting; the previous `a*b + c` form only got FMA when
// the compiler chose to contract it.
__global__ void multiply_add_kernel(
    const float* __restrict__ a,
    const float* __restrict__ b,
    const float* __restrict__ c,
    float* __restrict__ d,
    int size
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < size) {
        d[idx] = fmaf(a[idx], b[idx], c[idx]);  // one FMA instruction, guaranteed
    }
}
// Fused multiply-add of three float32 CUDA tensors: returns a * b + c.
// Inputs must be on CUDA, float32, contiguous, and share the same numel.
// (dtype checks were previously missing here, inconsistent with the other
// wrappers — a float64 input would have been reinterpreted as float32.)
torch::Tensor multiply_add_cuda(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
    TORCH_CHECK(a.device().is_cuda(), "Input a must be on CUDA");
    TORCH_CHECK(b.device().is_cuda(), "Input b must be on CUDA");
    TORCH_CHECK(c.device().is_cuda(), "Input c must be on CUDA");
    TORCH_CHECK(a.dtype() == torch::kFloat32, "Input a must be float32");
    TORCH_CHECK(b.dtype() == torch::kFloat32, "Input b must be float32");
    TORCH_CHECK(c.dtype() == torch::kFloat32, "Input c must be float32");
    TORCH_CHECK(a.numel() == b.numel() && a.numel() == c.numel(),
                "Inputs must have the same number of elements");
    TORCH_CHECK(a.is_contiguous() && b.is_contiguous() && c.is_contiguous(),
                "Inputs must be contiguous");

    // empty_like suffices: the kernel writes every element.
    auto d = torch::empty_like(a);
    const int size = static_cast<int>(a.numel());
    if (size == 0) {
        return d;  // avoid an invalid 0-block launch
    }

    const int threads = 256;
    const int blocks = (size + threads - 1) / threads;  // ceil-div
    multiply_add_kernel<<<blocks, threads>>>(
        a.data_ptr<float>(),
        b.data_ptr<float>(),
        c.data_ptr<float>(),
        d.data_ptr<float>(),
        size
    );
    CUDA_CHECK(cudaGetLastError());
    return d;
}
// -------------------------------------------------------------------------
// Pybind11 module definition
// -------------------------------------------------------------------------
// Python bindings. py::arg(...) names enable keyword-argument calls from
// Python (e.g. multiply(a=x, b=y)); positional calls keep working unchanged.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("multiply", &multiply_cuda, "Element-wise multiply (CUDA)",
          py::arg("a"), py::arg("b"));
    m.def("multiply_vectorized", &multiply_vectorized_cuda,
          "Element-wise multiply with float4 vectorization",
          py::arg("a"), py::arg("b"));
    m.def("sum", &sum_cuda, "Sum reduction using shared memory",
          py::arg("input"));
    m.def("multiply_add", &multiply_add_cuda, "Fused multiply-add (a * b + c)",
          py::arg("a"), py::arg("b"), py::arg("c"));
}