File size: 9,962 Bytes
3386f25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
/*
StyleForge - Fused Instance Normalization Kernel

Fuses: Mean → Variance → Normalize → Affine Transform

Key Optimizations:
- Single kernel launch for entire InstanceNorm operation
- Warp-level reductions for mean/variance computation
- Fused affine transform (gamma * normalized + beta)
- Efficient shared memory usage

Performance Target: 3-5x speedup over PyTorch nn.InstanceNorm2d
*/

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <limits>

// ============================================
// CUDA Error Checking
// ============================================
// Evaluates a CUDA runtime call exactly once. On any status other than
// cudaSuccess it reports file, line and the runtime's error string, then
// aborts the process.
//  - The argument is parenthesized so arbitrary expressions expand safely.
//  - The status lives in a uniquely named local so that a caller expression
//    mentioning its own `err` is not captured by the macro's variable.
//  - The message goes to stderr (unbuffered): std::abort() does not flush
//    stdout, so a printf() diagnostic could be lost.
#define CUDA_CHECK(call) \
    do { \
        const cudaError_t cuda_check_status_ = (call); \
        if (cuda_check_status_ != cudaSuccess) { \
            fprintf(stderr, "CUDA error at %s:%d: %s\n", __FILE__, __LINE__, \
                    cudaGetErrorString(cuda_check_status_)); \
            std::abort(); \
        } \
    } while (0)

// ============================================
// Configuration
// ============================================
// Lanes per warp. Sets the starting shuffle offset of the warp reductions and
// is the divisor used to split threadIdx.x into a warp id and a lane id.
#define WARP_SIZE 32
// NOTE(review): presumably the largest block size the kernels are meant to be
// instantiated with (1024 threads = 32 warps) -- confirm against the launcher.
#define MAX_BLOCK_SIZE 1024

// ============================================
// Warp-Level Primitives
// ============================================

// Adds `val` across the lanes of the calling warp with a shuffle-down tree
// (offsets 16, 8, 4, 2, 1). Every lane of the warp must call this together
// because the shuffle names all 32 lanes in its mask. Lane 0 ends up with the
// sum over the whole warp; the other lanes hold partial sums only.
__device__ __forceinline__ float warp_reduce_sum(float val) {
    constexpr unsigned int kAllLanes = 0xffffffff;
    float acc = val;
    #pragma unroll
    for (int delta = WARP_SIZE >> 1; delta >= 1; delta >>= 1) {
        const float from_above = __shfl_down_sync(kAllLanes, acc, delta);
        acc += from_above;
    }
    return acc;
}

// Takes the maximum of `val` across the lanes of the calling warp with a
// shuffle-down tree (offsets 16, 8, 4, 2, 1). Every lane of the warp must call
// this together because the shuffle names all 32 lanes in its mask. Lane 0
// ends up with the warp-wide maximum; the other lanes hold partial results.
__device__ __forceinline__ float warp_reduce_max(float val) {
    constexpr unsigned int kAllLanes = 0xffffffff;
    float best = val;
    #pragma unroll
    for (int delta = WARP_SIZE >> 1; delta >= 1; delta >>= 1) {
        const float from_above = __shfl_down_sync(kAllLanes, best, delta);
        best = fmaxf(best, from_above);
    }
    return best;
}

// ============================================
// Fused Instance Norm Kernel
// ============================================

/**
 * Fused InstanceNorm forward with scalar loads:
 * mean -> variance -> normalize -> affine, all in one launch.
 *
 * Launch layout: one block per (batch, channel) plane with
 * blockIdx.x = channel and blockIdx.y = batch. blockDim.x must equal
 * BLOCK_SIZE because the strided loops advance by BLOCK_SIZE rather than
 * blockDim.x. Only static shared memory is used.
 *
 * Preconditions: `input` and `output` are dense NCHW buffers (the plane offset
 * is (b * channels + c) * height * width); `gamma` and `beta` hold one value
 * per channel. The variance is the biased estimate (sum of squared deviations
 * divided by H * W). `batch_size` is not read by the kernel.
 */
template<int BLOCK_SIZE>
__global__ void fused_instance_norm_kernel(
    const float* __restrict__ input,   // [B, C, H, W]
    const float* __restrict__ gamma,   // [C]
    const float* __restrict__ beta,    // [C]
    float* __restrict__ output,        // [B, C, H, W]
    int batch_size,
    int channels,
    int height,
    int width,
    float eps
) {
    // The warp reductions shuffle with a full 32-lane mask, so every warp of
    // the block must be complete; the per-warp scratch array is sized from
    // BLOCK_SIZE, which therefore has to stay within the block-size limit.
    static_assert(BLOCK_SIZE > 0 && BLOCK_SIZE <= MAX_BLOCK_SIZE,
                  "BLOCK_SIZE must be in (0, MAX_BLOCK_SIZE]");
    static_assert(BLOCK_SIZE % WARP_SIZE == 0,
                  "BLOCK_SIZE must be a multiple of WARP_SIZE");
    constexpr int kNumWarps = BLOCK_SIZE / WARP_SIZE;

    // One block per (batch, channel) instance
    const int batch_idx = blockIdx.y;
    const int channel_idx = blockIdx.x;
    const int tid = threadIdx.x;
    const int spatial_size = height * width;

    // Shared memory for reductions
    __shared__ float s_warp_sums[kNumWarps];  // one partial per warp
    __shared__ float s_mean;
    __shared__ float s_inv_std;

    // Offset of this (batch, channel) plane; 64-bit so that tensors with more
    // than 2^31 - 1 elements are addressed correctly.
    const int64_t channel_offset =
        ((int64_t)batch_idx * channels + channel_idx) * spatial_size;

    const int warp_id = tid / WARP_SIZE;
    const int lane_id = tid % WARP_SIZE;

    // ============================================
    // Stage 1: Compute Mean
    // ============================================

    float sum = 0.0f;
    for (int i = tid; i < spatial_size; i += BLOCK_SIZE) {
        sum += input[channel_offset + i];
    }

    // Warp-level reduction; lane 0 of each warp publishes its partial.
    sum = warp_reduce_sum(sum);
    if (lane_id == 0) {
        s_warp_sums[warp_id] = sum;
    }
    __syncthreads();

    // Final reduction across warps
    if (tid == 0) {
        float total = 0.0f;
        for (int i = 0; i < kNumWarps; i++) {
            total += s_warp_sums[i];
        }
        s_mean = total / spatial_size;
    }
    // Publishes s_mean and also orders the reads of s_warp_sums above before
    // the stage-2 writes to the same array.
    __syncthreads();

    const float mean = s_mean;

    // ============================================
    // Stage 2: Compute Variance
    // ============================================

    float var_sum = 0.0f;
    for (int i = tid; i < spatial_size; i += BLOCK_SIZE) {
        const float diff = input[channel_offset + i] - mean;
        var_sum += diff * diff;
    }

    // Warp-level reduction
    var_sum = warp_reduce_sum(var_sum);
    if (lane_id == 0) {
        s_warp_sums[warp_id] = var_sum;
    }
    __syncthreads();

    // Final reduction across warps
    if (tid == 0) {
        float total = 0.0f;
        for (int i = 0; i < kNumWarps; i++) {
            total += s_warp_sums[i];
        }
        const float variance = total / spatial_size;
        s_inv_std = rsqrtf(variance + eps);
    }
    __syncthreads();

    const float inv_std = s_inv_std;

    // ============================================
    // Stage 3: Normalize & Affine Transform (Fused)
    // ============================================

    const float gamma_val = gamma[channel_idx];
    const float beta_val = beta[channel_idx];

    for (int i = tid; i < spatial_size; i += BLOCK_SIZE) {
        // Keep the element index 64-bit: narrowing channel_offset + i to int
        // wraps once the tensor exceeds 2^31 - 1 elements.
        const int64_t idx = channel_offset + i;

        // Normalize: (x - mean) / std
        const float normalized = (input[idx] - mean) * inv_std;

        // Affine transform: gamma * x + beta
        output[idx] = gamma_val * normalized + beta_val;
    }
}

// ============================================
// Vectorized Instance Norm (float4)
// ============================================

/**
 * Fused InstanceNorm forward using 128-bit (float4) loads and stores.
 *
 * Launch layout: one block per (batch, channel) plane with
 * blockIdx.x = channel and blockIdx.y = batch. blockDim.x must equal
 * BLOCK_SIZE because the strided loops advance by BLOCK_SIZE. Only static
 * shared memory is used.
 *
 * Preconditions (the caller must guarantee them; the kernel does not check):
 *  - height * width is a multiple of 4. Otherwise vec_size truncates, the
 *    tail elements are skipped and the per-plane float4 offsets are wrong.
 *  - `input` and `output` are dense NCHW buffers whose base addresses are
 *    16-byte aligned, as float4 accesses require.
 *  - `gamma` and `beta` hold one value per channel.
 * The variance is the biased estimate (sum of squared deviations divided by
 * H * W). `batch_size` is not read by the kernel.
 */
template<int BLOCK_SIZE>
__global__ void fused_instance_norm_kernel_vec4(
    const float* __restrict__ input,
    const float* __restrict__ gamma,
    const float* __restrict__ beta,
    float* __restrict__ output,
    int batch_size,
    int channels,
    int height,
    int width,
    float eps
) {
    // The warp reductions shuffle with a full 32-lane mask, so every warp of
    // the block must be complete; the per-warp scratch array is sized from
    // BLOCK_SIZE, which therefore has to stay within the block-size limit.
    static_assert(BLOCK_SIZE > 0 && BLOCK_SIZE <= MAX_BLOCK_SIZE,
                  "BLOCK_SIZE must be in (0, MAX_BLOCK_SIZE]");
    static_assert(BLOCK_SIZE % WARP_SIZE == 0,
                  "BLOCK_SIZE must be a multiple of WARP_SIZE");
    constexpr int kNumWarps = BLOCK_SIZE / WARP_SIZE;

    // Vectorized loads using float4 (4 pixels at once)
    const float4* input_vec = reinterpret_cast<const float4*>(input);
    float4* output_vec = reinterpret_cast<float4*>(output);

    const int batch_idx = blockIdx.y;
    const int channel_idx = blockIdx.x;
    const int tid = threadIdx.x;
    const int spatial_size = height * width;
    const int vec_size = spatial_size / 4;

    __shared__ float s_warp_sums[kNumWarps];  // one partial per warp
    __shared__ float s_mean;
    __shared__ float s_inv_std;

    // Plane offset measured in float4 units (hence vec_size, not spatial_size).
    const int64_t channel_offset =
        ((int64_t)batch_idx * channels + channel_idx) * vec_size;

    const int warp_id = tid / WARP_SIZE;
    const int lane_id = tid % WARP_SIZE;

    // Compute mean using vectorized loads
    float sum = 0.0f;
    for (int i = tid; i < vec_size; i += BLOCK_SIZE) {
        float4 vec = input_vec[channel_offset + i];
        sum += vec.x + vec.y + vec.z + vec.w;
    }

    sum = warp_reduce_sum(sum);

    if (lane_id == 0) {
        s_warp_sums[warp_id] = sum;
    }
    __syncthreads();

    if (tid == 0) {
        float total = 0.0f;
        for (int i = 0; i < kNumWarps; i++) {
            total += s_warp_sums[i];
        }
        // Divide by the element count, not the float4 count.
        s_mean = total / spatial_size;
    }
    // Publishes s_mean and orders the reads of s_warp_sums above before the
    // variance stage overwrites the same array.
    __syncthreads();

    const float mean = s_mean;

    // Compute variance
    float var_sum = 0.0f;
    for (int i = tid; i < vec_size; i += BLOCK_SIZE) {
        float4 vec = input_vec[channel_offset + i];
        float4 diff;
        diff.x = vec.x - mean;
        diff.y = vec.y - mean;
        diff.z = vec.z - mean;
        diff.w = vec.w - mean;
        var_sum += diff.x * diff.x + diff.y * diff.y + diff.z * diff.z + diff.w * diff.w;
    }

    var_sum = warp_reduce_sum(var_sum);

    if (lane_id == 0) {
        s_warp_sums[warp_id] = var_sum;
    }
    __syncthreads();

    if (tid == 0) {
        float total = 0.0f;
        for (int i = 0; i < kNumWarps; i++) {
            total += s_warp_sums[i];
        }
        const float variance = total / spatial_size;
        s_inv_std = rsqrtf(variance + eps);
    }
    __syncthreads();

    const float inv_std = s_inv_std;
    const float gamma_val = gamma[channel_idx];
    const float beta_val = beta[channel_idx];

    // Normalize and apply affine transform
    for (int i = tid; i < vec_size; i += BLOCK_SIZE) {
        float4 vec = input_vec[channel_offset + i];
        float4 result;
        result.x = gamma_val * (vec.x - mean) * inv_std + beta_val;
        result.y = gamma_val * (vec.y - mean) * inv_std + beta_val;
        result.z = gamma_val * (vec.z - mean) * inv_std + beta_val;
        result.w = gamma_val * (vec.w - mean) * inv_std + beta_val;
        output_vec[channel_offset + i] = result;
    }
}

// ============================================
// Launcher Function
// ============================================

/**
 * Host entry point for the fused InstanceNorm forward pass.
 *
 * @param input           float32 CUDA tensor of shape (B, C, H, W). Any stride
 *                        layout is accepted; it is densified to NCHW first.
 * @param gamma           float32 CUDA tensor with C elements (per-channel scale).
 * @param beta            float32 CUDA tensor with C elements (per-channel shift).
 * @param eps             added to the variance before the reciprocal sqrt.
 * @param use_vectorized  allow the float4 kernel; it is only taken when H * W
 *                        is a multiple of 4 and both buffers are 16-byte aligned.
 * @return a new dense NCHW float32 tensor with the same shape as `input`.
 *
 * Argument problems raise through TORCH_CHECK. A failed kernel launch is
 * reported through CUDA_CHECK. Kernels are enqueued on PyTorch's current
 * stream for the input's device, and no host synchronization is performed.
 */
torch::Tensor fused_instance_norm_forward(
    torch::Tensor input,
    torch::Tensor gamma,
    torch::Tensor beta,
    float eps,
    bool use_vectorized
) {
    TORCH_CHECK(input.device().is_cuda(), "Input must be on CUDA");
    TORCH_CHECK(input.dtype() == torch::kFloat32, "Input must be float32");
    TORCH_CHECK(input.dim() == 4, "Input must be 4D (B, C, H, W)");

    const int64_t batch_size64 = input.size(0);
    const int64_t channels64 = input.size(1);
    const int64_t height64 = input.size(2);
    const int64_t width64 = input.size(3);

    // The kernels dereference gamma/beta as raw device float arrays of length
    // C, so anything else would read out of bounds or fault on the device.
    TORCH_CHECK(gamma.defined() && beta.defined(), "gamma and beta must be defined tensors");
    TORCH_CHECK(gamma.device() == input.device() && beta.device() == input.device(),
                "gamma and beta must be on the same CUDA device as input");
    TORCH_CHECK(gamma.dtype() == torch::kFloat32 && beta.dtype() == torch::kFloat32,
                "gamma and beta must be float32");
    TORCH_CHECK(gamma.numel() == channels64 && beta.numel() == channels64,
                "gamma and beta must each have C = ", channels64, " elements (got ",
                gamma.numel(), " and ", beta.numel(), ")");

    // Kernel arguments and loop counters are int; gridDim.y is capped at 65535.
    constexpr int64_t kIntMax = std::numeric_limits<int>::max();
    constexpr int64_t kMaxGridY = 65535;
    TORCH_CHECK(batch_size64 <= kMaxGridY,
                "Batch size ", batch_size64, " exceeds the CUDA gridDim.y limit of ", kMaxGridY);
    TORCH_CHECK(channels64 <= kIntMax, "Channel count ", channels64, " is too large");
    TORCH_CHECK(height64 <= kIntMax && width64 <= kIntMax &&
                (height64 == 0 || width64 <= kIntMax / height64),
                "H * W must fit in a 32-bit int");

    const int batch_size = static_cast<int>(batch_size64);
    const int channels = static_cast<int>(channels64);
    const int height = static_cast<int>(height64);
    const int width = static_cast<int>(width64);
    const int spatial_size = height * width;

    // The kernels index planes as dense NCHW, so strided or channels_last
    // inputs must be densified first (a no-op for already-contiguous tensors).
    input = input.contiguous();
    gamma = gamma.contiguous();
    beta = beta.contiguous();

    // Make allocation and launch target the input's device even when another
    // device is current in the calling thread.
    const c10::cuda::CUDAGuard device_guard(input.device());

    // Every element is written by the kernel, so no zero fill is needed; an
    // explicit-size allocation guarantees the dense NCHW layout the kernel writes.
    auto output = torch::empty(input.sizes(), input.options());

    // A zero-sized grid is an invalid launch configuration; nothing to compute.
    if (input.numel() == 0) {
        return output;
    }

    constexpr int kBlockSize = 256;
    const dim3 block(kBlockSize);
    const dim3 grid(channels, batch_size);

    const float* in_ptr = input.data_ptr<float>();
    const float* gamma_ptr = gamma.data_ptr<float>();
    const float* beta_ptr = beta.data_ptr<float>();
    float* out_ptr = output.data_ptr<float>();

    // float4 accesses need 16-byte aligned addresses. With H * W % 4 == 0 every
    // plane offset is a multiple of 16 bytes, so checking the bases suffices.
    const bool aligned16 =
        (reinterpret_cast<std::uintptr_t>(in_ptr) % 16 == 0) &&
        (reinterpret_cast<std::uintptr_t>(out_ptr) % 16 == 0);

    // Use vectorized kernel if spatial size is multiple of 4 and buffers are aligned
    const bool use_vec4 = use_vectorized && (spatial_size % 4 == 0) && aligned16;

    // Launch on PyTorch's current stream so the work is ordered with the ops
    // that produced `input` and those that will consume `output`.
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    if (use_vec4) {
        fused_instance_norm_kernel_vec4<kBlockSize><<<grid, block, 0, stream>>>(
            in_ptr, gamma_ptr, beta_ptr, out_ptr,
            batch_size, channels, height, width, eps
        );
    } else {
        fused_instance_norm_kernel<kBlockSize><<<grid, block, 0, stream>>>(
            in_ptr, gamma_ptr, beta_ptr, out_ptr,
            batch_size, channels, height, width, eps
        );
    }

    // Surfaces launch-configuration errors; in-kernel faults appear at the
    // next synchronizing call.
    CUDA_CHECK(cudaGetLastError());

    return output;
}

// ============================================
// Pybind11 Module
// ============================================

// Python bindings: the extension module exposes a single function, `forward`,
// which dispatches to fused_instance_norm_forward.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "forward",
        &fused_instance_norm_forward,
        "Fused InstanceNorm (CUDA)"
    );
}