Spaces:
Running
on
Zero
Running
on
Zero
File size: 9,962 Bytes
3386f25 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 |
/*
StyleForge - Fused Instance Normalization Kernel
Fuses: Mean → Variance → Normalize → Affine Transform
Key Optimizations:
- Single kernel launch for entire InstanceNorm operation
- Warp-level reductions for mean/variance computation
- Fused affine transform (gamma * normalized + beta)
- Efficient shared memory usage
Performance Target: 3-5x speedup over PyTorch nn.InstanceNorm2d
*/
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <limits>
// ============================================
// CUDA Error Checking
// ============================================
// Evaluates a CUDA runtime call exactly once. On any status other than
// cudaSuccess it prints the file, line and cudaGetErrorString() text to stderr
// and aborts the process. Host code only. In the Python-facing entry point,
// prefer TORCH_CHECK so failures surface as Python exceptions instead of
// terminating the interpreter.
#define CUDA_CHECK(call)                                                      \
    do {                                                                      \
        const cudaError_t cuda_check_err_ = (call);                           \
        if (cuda_check_err_ != cudaSuccess) {                                 \
            std::fprintf(stderr, "CUDA error at %s:%d: %s\n", __FILE__,       \
                         __LINE__, cudaGetErrorString(cuda_check_err_));      \
            std::abort();                                                     \
        }                                                                     \
    } while (0)
// ============================================
// Configuration
// ============================================
// Lanes per warp. The shuffle-based reductions in this file pass the full
// 32-lane mask 0xffffffff, so this value must stay 32.
#define WARP_SIZE 32
// Upper bound on threads per block (the CUDA hardware limit of 1024).
#define MAX_BLOCK_SIZE 1024
// ============================================
// Warp-Level Primitives
// ============================================
// Warp-wide sum using __shfl_down_sync with a full 32-lane mask.
// Each step halves the shuffle distance; afterwards lane 0 holds the sum of
// all 32 lanes' inputs, while the other lanes hold partial sums that callers
// should ignore. All 32 lanes of the warp must reach this call.
__device__ __forceinline__ float warp_reduce_sum(float val) {
    #pragma unroll
    for (int delta = WARP_SIZE >> 1; delta != 0; delta >>= 1) {
        const float neighbor = __shfl_down_sync(0xffffffffu, val, delta);
        val += neighbor;
    }
    return val;
}
// Warp-wide maximum using __shfl_down_sync with a full 32-lane mask.
// After the halving steps lane 0 holds the maximum over all 32 lanes; the
// other lanes hold partial maxima. All 32 lanes of the warp must reach this call.
__device__ __forceinline__ float warp_reduce_max(float val) {
    #pragma unroll
    for (int delta = WARP_SIZE >> 1; delta != 0; delta >>= 1) {
        const float neighbor = __shfl_down_sync(0xffffffffu, val, delta);
        val = fmaxf(val, neighbor);
    }
    return val;
}
// ============================================
// Fused Instance Norm Kernel
// ============================================
/**
 * Fused InstanceNorm forward kernel (scalar loads/stores).
 *
 * Launch layout: gridDim = (channels, batch_size), blockDim.x == BLOCK_SIZE.
 * Each block owns one (batch, channel) plane of a contiguous [B, C, H, W]
 * tensor and makes three block-stride passes over it:
 *   1. sum                           -> mean
 *   2. sum of squared deviations     -> biased variance -> rsqrt(var + eps)
 *   3. gamma[c] * (x - mean) * inv_std + beta[c] written to output
 * Only static shared memory is used (one partial per warp plus mean/inv_std).
 *
 * @param input       contiguous [B, C, H, W] floats
 * @param gamma       per-channel scale, `channels` floats
 * @param beta        per-channel shift, `channels` floats
 * @param output      contiguous [B, C, H, W]; must not alias input (__restrict__)
 * @param batch_size  not read by the kernel; the batch index is blockIdx.y
 * @param eps         added to the variance before rsqrtf
 */
template<int BLOCK_SIZE>
__global__ void fused_instance_norm_kernel(
    const float* __restrict__ input,   // [B, C, H, W]
    const float* __restrict__ gamma,   // [C]
    const float* __restrict__ beta,    // [C]
    float* __restrict__ output,        // [B, C, H, W]
    int batch_size,
    int channels,
    int height,
    int width,
    float eps
) {
    // warp_reduce_sum uses a full 32-lane mask, so every warp must be fully
    // populated; at most MAX_BLOCK_SIZE / WARP_SIZE (= 32) warp partials exist.
    static_assert(BLOCK_SIZE % WARP_SIZE == 0,
                  "BLOCK_SIZE must be a multiple of WARP_SIZE");
    static_assert(BLOCK_SIZE > 0 && BLOCK_SIZE <= MAX_BLOCK_SIZE,
                  "BLOCK_SIZE must be in [WARP_SIZE, MAX_BLOCK_SIZE]");
    constexpr int kNumWarps = BLOCK_SIZE / WARP_SIZE;

    // One block per (batch, channel) instance.
    const int batch_idx = blockIdx.y;
    const int channel_idx = blockIdx.x;
    const int tid = threadIdx.x;
    const int warp_id = tid / WARP_SIZE;
    const int lane_id = tid % WARP_SIZE;
    const int spatial_size = height * width;

    // Per-warp partial sums (reused by both reductions) and broadcast results.
    __shared__ float s_warp_sums[kNumWarps];
    __shared__ float s_mean;
    __shared__ float s_inv_std;

    // 64-bit start of this plane: B * C * H * W can exceed INT_MAX.
    const int64_t channel_offset =
        ((int64_t)batch_idx * channels + channel_idx) * spatial_size;

    // ============================================
    // Stage 1: Compute Mean
    // ============================================
    float sum = 0.0f;
    for (int i = tid; i < spatial_size; i += BLOCK_SIZE) {
        sum += input[channel_offset + i];
    }
    sum = warp_reduce_sum(sum);  // lane 0 of each warp holds the warp total
    if (lane_id == 0) {
        s_warp_sums[warp_id] = sum;
    }
    __syncthreads();
    if (tid == 0) {
        float total = 0.0f;
        for (int w = 0; w < kNumWarps; ++w) {
            total += s_warp_sums[w];
        }
        s_mean = total / spatial_size;
    }
    __syncthreads();
    const float mean = s_mean;

    // ============================================
    // Stage 2: Compute Variance (two-pass, around the mean above)
    // ============================================
    // The barrier above also guarantees thread 0 has finished reading
    // s_warp_sums before the warps overwrite it here.
    float var_sum = 0.0f;
    for (int i = tid; i < spatial_size; i += BLOCK_SIZE) {
        const float diff = input[channel_offset + i] - mean;
        var_sum += diff * diff;
    }
    var_sum = warp_reduce_sum(var_sum);
    if (lane_id == 0) {
        s_warp_sums[warp_id] = var_sum;
    }
    __syncthreads();
    if (tid == 0) {
        float total = 0.0f;
        for (int w = 0; w < kNumWarps; ++w) {
            total += s_warp_sums[w];
        }
        const float variance = total / spatial_size;  // biased (divide by N)
        s_inv_std = rsqrtf(variance + eps);
    }
    __syncthreads();
    const float inv_std = s_inv_std;

    // ============================================
    // Stage 3: Normalize & Affine Transform (Fused)
    // ============================================
    const float gamma_val = gamma[channel_idx];
    const float beta_val = beta[channel_idx];
    for (int i = tid; i < spatial_size; i += BLOCK_SIZE) {
        // Keep the full 64-bit index; narrowing to int would wrap for
        // tensors with more than INT_MAX elements.
        const int64_t idx = channel_offset + i;
        const float normalized = (input[idx] - mean) * inv_std;
        output[idx] = gamma_val * normalized + beta_val;
    }
}
// ============================================
// Vectorized Instance Norm (float4)
// ============================================
/**
 * Fused InstanceNorm forward kernel using float4 (4 floats per access).
 *
 * Launch layout: gridDim = (channels, batch_size), blockDim.x == BLOCK_SIZE.
 * One block per (batch, channel) plane of a contiguous [B, C, H, W] tensor,
 * three block-stride passes: mean, biased variance -> rsqrt(var + eps), then
 * gamma[c] * ((x - mean) * inv_std) + beta[c].
 *
 * Preconditions (not checked here; the caller must guarantee them):
 *  - height * width is a multiple of 4. Otherwise the last (H*W % 4) elements
 *    of each plane are never read or written, and plane starts would not fall
 *    on float4 boundaries.
 *  - input and output are 16-byte aligned, as float4 accesses require.
 *
 * @param batch_size  not read by the kernel; the batch index is blockIdx.y
 * @param eps         added to the variance before rsqrtf
 */
template<int BLOCK_SIZE>
__global__ void fused_instance_norm_kernel_vec4(
    const float* __restrict__ input,
    const float* __restrict__ gamma,
    const float* __restrict__ beta,
    float* __restrict__ output,
    int batch_size,
    int channels,
    int height,
    int width,
    float eps
) {
    // Full-mask shuffles need fully populated warps; at most 32 warp partials.
    static_assert(BLOCK_SIZE % WARP_SIZE == 0,
                  "BLOCK_SIZE must be a multiple of WARP_SIZE");
    static_assert(BLOCK_SIZE > 0 && BLOCK_SIZE <= MAX_BLOCK_SIZE,
                  "BLOCK_SIZE must be in [WARP_SIZE, MAX_BLOCK_SIZE]");
    constexpr int kNumWarps = BLOCK_SIZE / WARP_SIZE;

    // View the buffers as float4 so each thread moves 4 pixels per access.
    const float4* input_vec = reinterpret_cast<const float4*>(input);
    float4* output_vec = reinterpret_cast<float4*>(output);

    const int batch_idx = blockIdx.y;
    const int channel_idx = blockIdx.x;
    const int tid = threadIdx.x;
    const int warp_id = tid / WARP_SIZE;
    const int lane_id = tid % WARP_SIZE;
    const int spatial_size = height * width;
    const int vec_size = spatial_size / 4;  // float4 elements per plane

    __shared__ float s_warp_sums[kNumWarps];
    __shared__ float s_mean;
    __shared__ float s_inv_std;

    // Plane start in float4 units, 64-bit to avoid overflow on large tensors.
    const int64_t channel_offset =
        ((int64_t)batch_idx * channels + channel_idx) * vec_size;

    // Stage 1: mean.
    float sum = 0.0f;
    for (int i = tid; i < vec_size; i += BLOCK_SIZE) {
        const float4 v = input_vec[channel_offset + i];
        sum += v.x + v.y + v.z + v.w;
    }
    sum = warp_reduce_sum(sum);  // lane 0 of each warp holds the warp total
    if (lane_id == 0) {
        s_warp_sums[warp_id] = sum;
    }
    __syncthreads();
    if (tid == 0) {
        float total = 0.0f;
        for (int w = 0; w < kNumWarps; ++w) {
            total += s_warp_sums[w];
        }
        s_mean = total / spatial_size;
    }
    __syncthreads();
    const float mean = s_mean;

    // Stage 2: biased variance around the mean. The barrier above ensures
    // thread 0 is done reading s_warp_sums before it is overwritten.
    float var_sum = 0.0f;
    for (int i = tid; i < vec_size; i += BLOCK_SIZE) {
        const float4 v = input_vec[channel_offset + i];
        const float dx = v.x - mean;
        const float dy = v.y - mean;
        const float dz = v.z - mean;
        const float dw = v.w - mean;
        var_sum += dx * dx + dy * dy + dz * dz + dw * dw;
    }
    var_sum = warp_reduce_sum(var_sum);
    if (lane_id == 0) {
        s_warp_sums[warp_id] = var_sum;
    }
    __syncthreads();
    if (tid == 0) {
        float total = 0.0f;
        for (int w = 0; w < kNumWarps; ++w) {
            total += s_warp_sums[w];
        }
        const float variance = total / spatial_size;
        s_inv_std = rsqrtf(variance + eps);
    }
    __syncthreads();
    const float inv_std = s_inv_std;

    // Stage 3: normalize, then affine. Same operation order as the scalar
    // kernel, so both paths round identically per element.
    const float gamma_val = gamma[channel_idx];
    const float beta_val = beta[channel_idx];
    for (int i = tid; i < vec_size; i += BLOCK_SIZE) {
        const float4 v = input_vec[channel_offset + i];
        float4 result;
        result.x = gamma_val * ((v.x - mean) * inv_std) + beta_val;
        result.y = gamma_val * ((v.y - mean) * inv_std) + beta_val;
        result.z = gamma_val * ((v.z - mean) * inv_std) + beta_val;
        result.w = gamma_val * ((v.w - mean) * inv_std) + beta_val;
        output_vec[channel_offset + i] = result;
    }
}
// ============================================
// Launcher Function
// ============================================
/**
 * Host entry point: fused InstanceNorm over a 4D float32 CUDA tensor.
 *
 * For every (b, c) plane, computes the mean and biased variance over H*W and
 * writes gamma[c] * (x - mean) * rsqrt(var + eps) + beta[c].
 *
 * @param input           (B, C, H, W) float32 CUDA tensor. Non-contiguous
 *                        inputs are first copied to dense NCHW, which the
 *                        kernels require.
 * @param gamma           per-channel scale: C float32 elements, same device.
 * @param beta            per-channel shift: C float32 elements, same device.
 * @param eps             added to the variance before rsqrt.
 * @param use_vectorized  request the float4 kernel. It is only used when H*W
 *                        is a multiple of 4 and both buffers are 16-byte
 *                        aligned; otherwise the scalar kernel runs.
 * @return a new contiguous tensor with the same shape as input.
 * @throws c10::Error (via TORCH_CHECK) on invalid arguments or launch failure.
 *
 * NOTE(review): kernels are launched without a stream argument, i.e. on the
 * legacy default stream, not on PyTorch's current stream.
 */
torch::Tensor fused_instance_norm_forward(
    torch::Tensor input,
    torch::Tensor gamma,
    torch::Tensor beta,
    float eps,
    bool use_vectorized
) {
    TORCH_CHECK(input.device().is_cuda(), "Input must be on CUDA");
    TORCH_CHECK(input.dtype() == torch::kFloat32, "Input must be float32");
    TORCH_CHECK(input.dim() == 4, "Input must be 4D (B, C, H, W)");
    TORCH_CHECK(gamma.device() == input.device() && beta.device() == input.device(),
                "gamma and beta must be on the same device as input");
    TORCH_CHECK(gamma.dtype() == torch::kFloat32 && beta.dtype() == torch::kFloat32,
                "gamma and beta must be float32");

    const int64_t batch_size = input.size(0);
    const int64_t channels = input.size(1);
    const int64_t height = input.size(2);
    const int64_t width = input.size(3);
    const int64_t spatial_size = height * width;

    TORCH_CHECK(gamma.numel() == channels && beta.numel() == channels,
                "gamma and beta must each have C = ", channels, " elements");
    // The kernels index within a plane with 32-bit ints, put channels on
    // gridDim.x (max 2^31 - 1) and batch on gridDim.y (max 65535).
    TORCH_CHECK(spatial_size <= std::numeric_limits<int>::max(),
                "H * W too large for 32-bit indexing: ", spatial_size);
    TORCH_CHECK(channels <= std::numeric_limits<int>::max(),
                "too many channels: ", channels);
    TORCH_CHECK(batch_size <= 65535,
                "batch size exceeds the gridDim.y limit of 65535: ", batch_size);

    // Kernels assume dense NCHW: plane (b, c) starts at (b * C + c) * H * W.
    const torch::Tensor x = input.contiguous();
    const torch::Tensor g = gamma.contiguous();
    const torch::Tensor bt = beta.contiguous();
    // The kernels write every element, so no zero-fill is needed.
    torch::Tensor output = torch::empty(x.sizes(), x.options());

    // A zero-sized grid is an invalid launch configuration; nothing to do.
    if (x.numel() == 0) {
        return output;
    }

    const float* x_ptr = x.data_ptr<float>();
    float* out_ptr = output.data_ptr<float>();
    // float4 accesses need 16-byte alignment (a storage offset can break it).
    const bool aligned16 =
        reinterpret_cast<std::uintptr_t>(x_ptr) % 16 == 0 &&
        reinterpret_cast<std::uintptr_t>(out_ptr) % 16 == 0;
    const bool use_vec4 = use_vectorized && (spatial_size % 4 == 0) && aligned16;

    // Launch on the tensor's device, then restore the caller's current device.
    int prev_device = 0;
    cudaError_t err = cudaGetDevice(&prev_device);
    TORCH_CHECK(err == cudaSuccess, "cudaGetDevice failed: ", cudaGetErrorString(err));
    const int target_device = static_cast<int>(x.get_device());
    if (target_device != prev_device) {
        err = cudaSetDevice(target_device);
        TORCH_CHECK(err == cudaSuccess, "cudaSetDevice(", target_device, ") failed: ",
                    cudaGetErrorString(err));
    }

    constexpr int kBlockSize = 256;
    const dim3 block(kBlockSize);
    const dim3 grid(static_cast<unsigned int>(channels),
                    static_cast<unsigned int>(batch_size));

    if (use_vec4) {
        fused_instance_norm_kernel_vec4<kBlockSize><<<grid, block>>>(
            x_ptr,
            g.data_ptr<float>(),
            bt.data_ptr<float>(),
            out_ptr,
            static_cast<int>(batch_size),
            static_cast<int>(channels),
            static_cast<int>(height),
            static_cast<int>(width),
            eps
        );
    } else {
        fused_instance_norm_kernel<kBlockSize><<<grid, block>>>(
            x_ptr,
            g.data_ptr<float>(),
            bt.data_ptr<float>(),
            out_ptr,
            static_cast<int>(batch_size),
            static_cast<int>(channels),
            static_cast<int>(height),
            static_cast<int>(width),
            eps
        );
    }

    // Capture the launch status first; then restore the device (best effort)
    // before possibly throwing, so the caller's device is not left switched.
    err = cudaGetLastError();
    if (target_device != prev_device) {
        (void)cudaSetDevice(prev_device);
    }
    // Raise a Python-visible error instead of aborting the interpreter.
    TORCH_CHECK(err == cudaSuccess, "fused_instance_norm kernel launch failed: ",
                cudaGetErrorString(err));
    return output;
}
// ============================================
// Pybind11 Module
// ============================================
// Python binding: exposes fused_instance_norm_forward as `forward`.
// Named arguments allow keyword calls. Positional calls with all five
// arguments behave exactly as before. eps defaults to 1e-5 (the same default
// as torch.nn.InstanceNorm2d), and the float4 path is requested by default.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &fused_instance_norm_forward, "Fused InstanceNorm (CUDA)",
          pybind11::arg("input"),
          pybind11::arg("gamma"),
          pybind11::arg("beta"),
          pybind11::arg("eps") = 1e-5f,
          pybind11::arg("use_vectorized") = true);
}
|