H100 Hopper Architecture Deep Dive
Overview
The NVIDIA H100 GPU, based on the Hopper architecture (sm_90), is the current flagship for AI training and inference. This guide provides a comprehensive deep dive into optimizing CUDA kernels for H100 hardware, with specific focus on patterns used in HuggingFace Diffusers and Transformers models.
H100 Hardware Specifications
| Specification | Value |
|---|---|
| Architecture | Hopper (sm_90 / sm_90a) |
| Streaming Multiprocessors (SMs) | 132 |
| HBM3 Bandwidth | 3.35 TB/s |
| Shared Memory per SM | 228 KB (max configurable) |
| Typical Shared Memory Config | 196 KB shared + 32 KB L1 |
| L2 Cache | 50 MB |
| FP32 CUDA Cores per SM | 128 |
| Total FP32 CUDA Cores | 16,896 |
| Tensor Cores (4th gen) per SM | 4 |
| Total Tensor Cores | 528 |
| Memory | 80 GB HBM3 |
| Memory Bus Width | 5120-bit |
| TDP | 700W (SXM5) |
| Max Threads per SM | 2048 |
| Max Threads per Block | 1024 |
| Max Warps per SM | 64 |
| Warp Size | 32 |
| Register File per SM | 256 KB |
| Max Registers per Thread | 255 |
Memory Hierarchy
Understanding the H100 memory hierarchy is essential for writing high-performance kernels.
┌───────────────────────────────────────────────┐
│                     HBM3                      │
│               80 GB @ 3.35 TB/s               │
│             (~400 cycles latency)             │
├───────────────────────────────────────────────┤
│                   L2 Cache                    │
│               50 MB @ ~12 TB/s                │
│             (~200 cycles latency)             │
├───────────────────────────────────────────────┤
│      L1 Cache / Shared Memory (per SM)        │
│       228 KB total (configurable split)       │
│      Shared Memory: ~19 TB/s effective        │
│            (~20-30 cycles latency)            │
├───────────────────────────────────────────────┤
│            Register File (per SM)             │
│              256 KB @ ~80 TB/s                │
│              (~1 cycle latency)               │
└───────────────────────────────────────────────┘
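A quick roofline calculation makes the hierarchy numbers concrete: dividing FP32 peak FLOPs by HBM bandwidth gives the arithmetic intensity (FLOPs per byte) at which a kernel stops being memory-bound. This is a sketch; the boost clock below is an assumption (check `nvidia-smi` for your part), and peak FLOPs counts an FMA as 2 FLOPs.

```cpp
// Back-of-envelope roofline math for H100 SXM.
// kBoostClockHz is an assumed value, not a spec from this document.
constexpr double kHbmBytesPerSec = 3.35e12;  // HBM3 bandwidth (spec table)
constexpr double kFp32Cores      = 16896.0;  // total FP32 cores (spec table)
constexpr double kBoostClockHz   = 1.98e9;   // assumed SXM5 boost clock
constexpr double kFp32PeakFlops  = kFp32Cores * 2.0 * kBoostClockHz; // FMA = 2 FLOPs

// Arithmetic intensity (FLOPs/byte) needed to be FP32 compute-bound.
inline double fp32_ridge_point() {
    return kFp32PeakFlops / kHbmBytesPerSec;
}

// A kernel doing `flops` per element while moving `bytes` per element
// sits left of the ridge point when it is memory-bound.
inline bool is_memory_bound(double flops_per_element, double bytes_per_element) {
    return (flops_per_element / bytes_per_element) < fp32_ridge_point();
}
```

Under these assumptions the ridge point is roughly 20 FLOPs/byte, so an elementwise BF16 op (1 FLOP per 4 bytes moved) is deeply memory-bound, which is why the rest of this guide focuses on bandwidth.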
Shared Memory / L1 Cache Configuration
The H100 allows flexible partitioning of the 228 KB per-SM memory between shared memory and L1 cache. The carveout is chosen from a discrete set of sizes (0, 8, 16, 32, 64, 100, 132, 164, 196, or 228 KB of shared memory):
| Shared Memory | L1 Cache | Best For |
|---|---|---|
| 0 KB | 228 KB | Pure streaming workloads |
| 32 KB | 196 KB | Simple kernels with small working sets |
| 100 KB | 128 KB | Moderate data reuse |
| 164 KB | 64 KB | Large tile-based algorithms |
| 196 KB | 32 KB | Maximum shared memory (typical for deep learning) |
| 228 KB | 0 KB | Maximum shared memory (extreme cases) |
// Request specific shared memory size
// (per-block dynamic shared memory is capped at 227 KB on H100;
//  1 KB per block is reserved by the driver)
cudaFuncSetAttribute(
kernel_fn,
cudaFuncAttributeMaxDynamicSharedMemorySize,
192 * 1024 // 192 KB -- recommended for most DL kernels
);
// Set preferred shared memory carveout
cudaFuncSetAttribute(
kernel_fn,
cudaFuncAttributePreferredSharedMemoryCarveout,
cudaSharedmemCarveoutMaxShared
);
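To get a feel for what a shared-memory budget buys, a quick back-of-envelope check is useful: how many square BF16 tiles of a given edge length fit in a 192 KB per-block allocation like the one requested above. This is illustrative arithmetic only, not an API.

```cpp
#include <cstddef>

// How many square bf16 tiles of a given edge length fit in a
// shared-memory budget? 192 KB matches the request above.
constexpr std::size_t kSharedBudget = 192 * 1024; // bytes
constexpr std::size_t kBf16Bytes    = 2;

constexpr std::size_t tiles_that_fit(std::size_t tile_edge) {
    return kSharedBudget / (tile_edge * tile_edge * kBf16Bytes);
}
```

For example, six 128x128 BF16 tiles fit, which is enough for double-buffered A/B/C operands in a tiled GEMM.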
Vectorized Memory Access
Vectorized loads and stores are critical for saturating the H100's 3.35 TB/s bandwidth. The memory bus is 5120 bits wide, so maximizing bytes per transaction is essential.
BF16 Vectorized Access (bf16x2)
#include <cuda_bf16.h>
// Load 2 BF16 values as a single 32-bit operation
// (requires ptr to be 4-byte aligned and idx to be even)
__device__ __forceinline__ __nv_bfloat162 load_bf16x2(
const __nv_bfloat16* ptr,
int idx
) {
return reinterpret_cast<const __nv_bfloat162*>(ptr)[idx / 2];
}
// Store 2 BF16 values as a single 32-bit operation
__device__ __forceinline__ void store_bf16x2(
__nv_bfloat16* ptr,
int idx,
__nv_bfloat162 val
) {
reinterpret_cast<__nv_bfloat162*>(ptr)[idx / 2] = val;
}
// Arithmetic on bf16x2 (operates on both elements simultaneously)
__device__ __forceinline__ __nv_bfloat162 bf16x2_add(
__nv_bfloat162 a,
__nv_bfloat162 b
) {
return __hadd2(a, b);
}
__device__ __forceinline__ __nv_bfloat162 bf16x2_mul(
__nv_bfloat162 a,
__nv_bfloat162 b
) {
return __hmul2(a, b);
}
FP16 Vectorized Access (half2)
#include <cuda_fp16.h>
// Load 2 FP16 values as a single 32-bit operation
__device__ __forceinline__ half2 load_half2(
const half* ptr,
int idx
) {
return reinterpret_cast<const half2*>(ptr)[idx / 2];
}
// half2 fused multiply-add
__device__ __forceinline__ half2 half2_fma(
half2 a, half2 b, half2 c
) {
return __hfma2(a, b, c);
}
128-bit Vectorized Access (float4)
// Load 16 bytes at once (4 floats, 8 BF16, or 8 FP16 values)
// Requires the address to be 16-byte aligned
__device__ __forceinline__ float4 load_128bit(
const void* ptr,
int idx // index in units of float4
) {
return reinterpret_cast<const float4*>(ptr)[idx];
}
// Example: Process 8 BF16 values from a single 128-bit load
__device__ void process_bf16_vectorized(
const __nv_bfloat16* input,
__nv_bfloat16* output,
int n
) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
// Each thread processes 8 BF16 values (128 bits)
int vec_idx = tid;
int elem_idx = vec_idx * 8;
if (elem_idx + 7 < n) {
float4 packed = reinterpret_cast<const float4*>(input)[vec_idx];
__nv_bfloat162* pairs = reinterpret_cast<__nv_bfloat162*>(&packed);
float4 result;
__nv_bfloat162* out_pairs = reinterpret_cast<__nv_bfloat162*>(&result);
#pragma unroll
for (int i = 0; i < 4; i++) {
float lo = __bfloat162float(__low2bfloat16(pairs[i]));
float hi = __bfloat162float(__high2bfloat16(pairs[i]));
// Process (example: ReLU)
lo = fmaxf(lo, 0.0f);
hi = fmaxf(hi, 0.0f);
out_pairs[i] = __halves2bfloat162(
__float2bfloat16(lo),
__float2bfloat16(hi)
);
}
reinterpret_cast<float4*>(output)[vec_idx] = result;
}
}
Memory Coalescing Rules
For optimal memory throughput on H100:
| Access Pattern | Throughput | Recommendation |
|---|---|---|
| Coalesced 128-byte | 100% | Ideal -- 32 threads each load 4 bytes |
| Coalesced 64-byte | ~90% | Acceptable for half/bf16 |
| Strided (stride=2) | ~50% | Avoid or restructure data layout |
| Random | ~10-25% | Use shared memory staging |
// GOOD: Coalesced access -- consecutive threads access consecutive memory
int idx = blockIdx.x * blockDim.x + threadIdx.x;
float val = input[idx]; // Thread 0 reads [0], thread 1 reads [1], etc.
// BAD: Strided access -- threads skip elements
int idx = threadIdx.x * stride; // Non-coalesced if stride > 1
float val = input[idx];
// FIX: Use shared memory to restructure access
__shared__ float smem[BLOCK_SIZE * stride];
// Load coalesced into shared memory first, then access with stride from smem
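The throughput numbers in the table above follow directly from sector counting: global memory moves in 32-byte sectors, so a warp's efficiency is the ratio of useful bytes to sectors fetched. A small host-side model (assuming 4-byte elements) shows how stride inflates the sector count:

```cpp
#include <set>

// Counts how many distinct 32-byte sectors a 32-thread warp touches
// for a given element stride. Assumes 4-byte elements starting at
// offset 0; illustrative model, not a hardware query.
int sectors_touched(int stride_elems, int elem_bytes = 4) {
    std::set<long> sectors;
    for (int lane = 0; lane < 32; ++lane) {
        long byte_addr = static_cast<long>(lane) * stride_elems * elem_bytes;
        sectors.insert(byte_addr / 32);  // 32-byte sector index
    }
    return static_cast<int>(sectors.size());
}
```

Stride 1 touches 4 sectors (128 bytes, all useful); stride 2 touches 8 sectors for the same 128 useful bytes (~50% efficiency, matching the table); stride 8 puts every lane in its own sector.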
Warp Shuffles
Warp shuffle operations allow threads within a warp to exchange data without shared memory, providing much lower latency.
Basic Warp Reduction
// Sum reduction within a warp (32 threads)
__device__ __forceinline__ float warp_reduce_sum(float val) {
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1) {
val += __shfl_xor_sync(0xffffffff, val, offset);
}
return val;
}
// Max reduction within a warp
__device__ __forceinline__ float warp_reduce_max(float val) {
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1) {
val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, offset));
}
return val;
}
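The XOR-butterfly pattern above is easy to misread, so here is a host-side simulation of what `__shfl_xor_sync` does across the 32 lanes: at each step, lane `i` combines its value with lane `i ^ offset`. A useful property falls out of the simulation: every lane, not just lane 0, ends up with the full reduction.

```cpp
#include <array>

// Host-side simulation of the XOR-butterfly warp reduction:
// at each step, lane i exchanges with lane i ^ offset, which is
// exactly what __shfl_xor_sync performs across a 32-lane warp.
template <typename Op>
std::array<float, 32> simulate_warp_reduce(std::array<float, 32> lanes, Op op) {
    for (int offset = 16; offset > 0; offset >>= 1) {
        std::array<float, 32> next = lanes;
        for (int lane = 0; lane < 32; ++lane) {
            next[lane] = op(lanes[lane], lanes[lane ^ offset]);
        }
        lanes = next;
    }
    return lanes; // every lane holds the full reduction
}
```

This all-lanes property is why the block reduction below can read the warp result from lane 0 without an extra broadcast.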
Block-Level Reduction Using Warp Shuffles
template<int BLOCK_SIZE>
__device__ float block_reduce_sum(float val) {
static_assert(BLOCK_SIZE % 32 == 0, "BLOCK_SIZE must be a multiple of 32");
constexpr int NUM_WARPS = BLOCK_SIZE / 32;
__shared__ float warp_sums[NUM_WARPS];
int warp_id = threadIdx.x / 32;
int lane_id = threadIdx.x % 32;
// Step 1: Reduce within each warp
val = warp_reduce_sum(val);
// Step 2: First thread of each warp writes to shared memory
if (lane_id == 0) {
warp_sums[warp_id] = val;
}
__syncthreads();
// Step 3: First warp reduces across warps
if (warp_id == 0) {
val = (lane_id < NUM_WARPS) ? warp_sums[lane_id] : 0.0f;
val = warp_reduce_sum(val);
}
return val; // Valid in all lanes of warp 0; other warps return stale values
}
Warp Shuffle for Data Exchange
// Broadcast a value from one thread to all threads in the warp
__device__ __forceinline__ float warp_broadcast(float val, int src_lane) {
return __shfl_sync(0xffffffff, val, src_lane);
}
// Shift values down within a warp (useful for prefix sums)
__device__ __forceinline__ float warp_shift_down(float val, int delta) {
return __shfl_down_sync(0xffffffff, val, delta);
}
// Shift values up within a warp
__device__ __forceinline__ float warp_shift_up(float val, int delta) {
return __shfl_up_sync(0xffffffff, val, delta);
}
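The shift-up primitive is the building block for warp-level prefix sums. A host-side sketch of the Kogge-Stone inclusive scan shows how: at each step, lane `i` adds the value `__shfl_up_sync` would deliver from lane `i - delta` (lanes below `delta` keep their value unchanged, just as the intrinsic leaves them).

```cpp
#include <array>

// Host-side sketch of a Kogge-Stone inclusive prefix sum built on
// __shfl_up_sync semantics: lane i adds the value from lane i - delta
// when that lane exists; lower lanes are left unchanged.
std::array<int, 32> simulate_warp_inclusive_scan(std::array<int, 32> lanes) {
    for (int delta = 1; delta < 32; delta <<= 1) {
        std::array<int, 32> next = lanes;
        for (int lane = 0; lane < 32; ++lane) {
            if (lane >= delta) next[lane] = lanes[lane - delta] + lanes[lane];
        }
        lanes = next;
    }
    return lanes; // lane i holds the sum of lanes 0..i
}
```

Five steps (delta = 1, 2, 4, 8, 16) cover all 32 lanes, so the device version needs only five shuffles.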
Register Optimization
The H100 has 256 KB of registers per SM (65,536 32-bit registers). Register pressure directly impacts occupancy.
Register Usage Guidelines
| Registers/Thread | Max Threads/SM | Occupancy | Recommendation |
|---|---|---|---|
| 32 | 2048 | 100% | Ideal for simple kernels |
| 64 | 1024 | 50% | Good for moderate complexity |
| 128 | 512 | 25% | Acceptable for compute-bound |
| 255 | 256 | 12.5% | Only for very compute-heavy |
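The table rows can be reproduced with a little arithmetic. One detail is needed to make the 255-register row come out right: registers are allocated with per-warp granularity (assumed here to be 256 registers per warp, i.e. 8 per thread), so 255 registers/thread consumes the same budget as 256.

```cpp
#include <algorithm>

// Reproduces the register-limit column of the table above.
// Per-warp allocation granularity of 256 registers (8/thread) is an
// assumption about the hardware allocator, stated in the lead-in.
constexpr int kRegsPerSm       = 65536; // 256 KB of 32-bit registers
constexpr int kMaxThreadsPerSm = 2048;

constexpr int round_up8(int r) { return (r + 7) / 8 * 8; }

constexpr int max_threads_for_regs(int regs_per_thread) {
    return std::min(kMaxThreadsPerSm, kRegsPerSm / round_up8(regs_per_thread));
}

constexpr double occupancy_for_regs(int regs_per_thread) {
    return static_cast<double>(max_threads_for_regs(regs_per_thread))
         / kMaxThreadsPerSm;
}
```

At 32 registers/thread the register file is not the limiter (2048 threads fit); at 255 it caps the SM at 256 threads, i.e. 12.5% occupancy.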
Controlling Register Usage
// Limit register usage per thread
// This increases occupancy at the cost of register spilling
__global__ __launch_bounds__(256, 4) // max 256 threads/block, min 4 blocks/SM
void my_kernel(...) {
// Kernel code
}
// Or use compiler flag:
// --maxrregcount=128
Tips for Reducing Register Pressure
// 1. Use local scope to limit variable lifetime
{
float temp = compute_something();
use(temp);
} // temp can be reused by compiler after this scope
// 2. Prefer recomputation over storing intermediate results
// BAD: wastes a register
float x_squared = x * x;
float x_cubed = x_squared * x;
float result = x_cubed + x_squared + x;
// ALTERNATIVE: recompute inline (the compiler may CSE this back into
// temporaries -- verify actual counts with --ptxas-options=-v)
float result = x * x * x + x * x + x;
// 3. Use shared memory for values needed across synchronization points
__shared__ float shared_val;
// Instead of: float local_val = compute(); __syncthreads(); use(local_val);
Occupancy Tuning for 132 SMs
The H100 has 132 SMs. Grid and block sizes must be tuned to maximize utilization.
Grid Sizing
// Query SM count at runtime (portable across GPU models)
int num_sms;
cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, 0);
// H100: num_sms = 132
// Minimum grid size for full SM utilization
int min_grid = num_sms; // 132
// Recommended: 2-4 blocks per SM for latency hiding
int grid_2x = num_sms * 2; // 264 blocks
int grid_4x = num_sms * 4; // 528 blocks
// For occupancy-limited kernels, use cudaOccupancyMaxActiveBlocksPerMultiprocessor
int max_blocks_per_sm;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_blocks_per_sm, kernel_fn, block_size, shared_mem_bytes
);
int optimal_grid = num_sms * max_blocks_per_sm;
Block Size Selection
// Auto-tune block size using CUDA occupancy API
int min_grid_size, optimal_block_size;
cudaOccupancyMaxPotentialBlockSize(
&min_grid_size,
&optimal_block_size,
kernel_fn,
shared_mem_bytes,
0 // max block size (0 = no limit)
);
printf("Optimal block size: %d\n", optimal_block_size);
printf("Minimum grid size: %d\n", min_grid_size);
Occupancy Calculator
void analyze_occupancy(void* kernel, int block_size, size_t shared_mem) {
cudaFuncAttributes attr;
cudaFuncGetAttributes(&attr, kernel);
printf("Registers per thread: %d\n", attr.numRegs);
printf("Shared memory (static): %zu bytes\n", attr.sharedSizeBytes);
printf("Max threads per block: %d\n", attr.maxThreadsPerBlock);
int max_active_blocks;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, kernel, block_size, shared_mem
);
float occupancy = (float)(max_active_blocks * block_size) / 2048.0f;
printf("Active blocks/SM: %d\n", max_active_blocks);
printf("Occupancy: %.1f%%\n", occupancy * 100.0f);
printf("Active warps/SM: %d / 64\n", max_active_blocks * block_size / 32);
}
Precision: BF16 vs FP16
Format Comparison
| Property | BF16 | FP16 |
|---|---|---|
| Total bits | 16 | 16 |
| Sign bits | 1 | 1 |
| Exponent bits | 8 | 5 |
| Mantissa bits | 7 | 10 |
| Dynamic range | ~1e-38 to ~3e38 | ~6e-5 to 65504 |
| Precision | ~2.4 decimal digits | ~3.3 decimal digits |
| Overflow risk | Very low | Moderate |
| H100 Tensor Core | Yes | Yes |
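The table's trade-offs follow from the bit layout: BF16 is simply the top 16 bits of an IEEE-754 float (same 8-bit exponent, 7-bit mantissa). A portable sketch of the conversion, without CUDA headers, makes both the range and precision behavior visible:

```cpp
#include <cstdint>
#include <cstring>

// BF16 as the top 16 bits of an IEEE-754 float, with round-to-nearest-even.
// Portable sketch; on device you would use __float2bfloat16 instead.
uint16_t float_to_bf16_bits(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    uint32_t rounding = 0x7FFF + ((u >> 16) & 1); // round to nearest even
    return static_cast<uint16_t>((u + rounding) >> 16);
}

float bf16_bits_to_float(uint16_t b) {
    uint32_t u = static_cast<uint32_t>(b) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}
```

Because the exponent field is untouched, 65536.0f (which overflows FP16's 65504 max) round-trips through BF16 exactly, while values needing more than 7 mantissa bits, such as 1 + 2^-8, get rounded away.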
When to Use Each
// BF16: Preferred for most deep learning workloads
// - Wider dynamic range avoids overflow in attention scores
// - Safe for gradient accumulation
// - Native support on H100
__nv_bfloat16 bf16_val = __float2bfloat16(1.0f);
// FP16: Use when precision matters more than range
// - Image pixel values (0-1 range)
// - Already-normalized values
// - When exact reproducibility with FP16 models is needed
half fp16_val = __float2half(1.0f);
Type Conversion Helpers
#include <cuda_bf16.h>
#include <cuda_fp16.h>
// BF16 to float
__device__ __forceinline__ float bf16_to_float(__nv_bfloat16 val) {
return __bfloat162float(val);
}
// Float to BF16
__device__ __forceinline__ __nv_bfloat16 float_to_bf16(float val) {
return __float2bfloat16(val);
}
// FP16 to float
__device__ __forceinline__ float fp16_to_float(half val) {
return __half2float(val);
}
// Float to FP16
__device__ __forceinline__ half float_to_fp16(float val) {
return __float2half(val);
}
// BF16 to FP16 (via float intermediate)
__device__ __forceinline__ half bf16_to_fp16(__nv_bfloat16 val) {
return __float2half(__bfloat162float(val));
}
// FP16 to BF16 (via float intermediate)
__device__ __forceinline__ __nv_bfloat16 fp16_to_bf16(half val) {
return __float2bfloat16(__half2float(val));
}
Online Softmax
The online softmax algorithm computes softmax in a single pass, which is critical for memory-efficient attention implementations.
Algorithm
The standard softmax requires two passes:
- Compute max over all elements
- Compute sum of exp(x - max) and normalize
The online algorithm combines these into a single pass:
// Online softmax: single-pass algorithm
// Maintains running max and running sum, correcting as new max is found
struct OnlineSoftmaxState {
float max_val;
float sum_exp;
};
__device__ __forceinline__ OnlineSoftmaxState online_softmax_init() {
return {-INFINITY, 0.0f};
}
__device__ __forceinline__ OnlineSoftmaxState online_softmax_update(
OnlineSoftmaxState state,
float new_val
) {
float new_max = fmaxf(state.max_val, new_val);
float correction = expf(state.max_val - new_max);
float new_sum = state.sum_exp * correction + expf(new_val - new_max);
return {new_max, new_sum};
}
__device__ __forceinline__ OnlineSoftmaxState online_softmax_merge(
OnlineSoftmaxState a,
OnlineSoftmaxState b
) {
float new_max = fmaxf(a.max_val, b.max_val);
float new_sum = a.sum_exp * expf(a.max_val - new_max)
+ b.sum_exp * expf(b.max_val - new_max);
return {new_max, new_sum};
}
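A host-side reference makes it easy to convince yourself the correction term is right: the single-pass online result must match the classic two-pass softmax. This sketch mirrors the device update above in plain fp32 (the `State` struct here is a host stand-in for `OnlineSoftmaxState`):

```cpp
#include <cmath>
#include <vector>

struct State { float max_val; float sum_exp; };

// Host mirror of online_softmax_update: rescale the running sum by
// exp(old_max - new_max) whenever a new maximum is found.
State online_update(State s, float x) {
    float m = std::fmax(s.max_val, x);
    return {m, s.sum_exp * std::exp(s.max_val - m) + std::exp(x - m)};
}

// Softmax of v using the single-pass online statistics.
std::vector<float> online_softmax(const std::vector<float>& v) {
    State s{-INFINITY, 0.0f};
    for (float x : v) s = online_update(s, x);
    std::vector<float> out;
    for (float x : v) out.push_back(std::exp(x - s.max_val) / s.sum_exp);
    return out;
}
```

The outputs sum to 1 and preserve ordering, exactly as the two-pass algorithm would, which is the invariant the merge function must also maintain.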
Warp-Level Online Softmax
__device__ OnlineSoftmaxState warp_online_softmax_reduce(
OnlineSoftmaxState state
) {
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1) {
OnlineSoftmaxState other;
other.max_val = __shfl_xor_sync(0xffffffff, state.max_val, offset);
other.sum_exp = __shfl_xor_sync(0xffffffff, state.sum_exp, offset);
state = online_softmax_merge(state, other);
}
return state;
}
Full Softmax Kernel Using Online Algorithm
template<int BLOCK_SIZE>
__global__ void online_softmax_kernel(
const __nv_bfloat16* __restrict__ input,
__nv_bfloat16* __restrict__ output,
const int seq_len
) {
const int row = blockIdx.x;
const int tid = threadIdx.x;
const __nv_bfloat16* row_in = input + row * seq_len;
__nv_bfloat16* row_out = output + row * seq_len;
// Phase 1: Online reduction to find max and sum
OnlineSoftmaxState state = online_softmax_init();
for (int i = tid; i < seq_len; i += BLOCK_SIZE) {
float val = __bfloat162float(row_in[i]);
state = online_softmax_update(state, val);
}
// Warp reduction
state = warp_online_softmax_reduce(state);
// Block reduction
__shared__ OnlineSoftmaxState warp_states[BLOCK_SIZE / 32];
int warp_id = tid / 32;
int lane_id = tid % 32;
if (lane_id == 0) warp_states[warp_id] = state;
__syncthreads();
if (warp_id == 0) {
state = (lane_id < BLOCK_SIZE / 32)
? warp_states[lane_id]
: online_softmax_init();
state = warp_online_softmax_reduce(state);
}
__shared__ float s_max, s_sum_inv;
if (tid == 0) {
s_max = state.max_val;
s_sum_inv = 1.0f / state.sum_exp;
}
__syncthreads();
// Phase 2: Apply softmax
for (int i = tid; i < seq_len; i += BLOCK_SIZE) {
float val = __bfloat162float(row_in[i]);
float result = expf(val - s_max) * s_sum_inv;
row_out[i] = __float2bfloat16(result);
}
}
Diffusers-Specific Optimizations
Fused RMSNorm + Linear
A common pattern in transformer blocks is RMSNorm followed by a linear layer. Fusing these avoids an extra memory round-trip:
// Instead of:
// normalized = rmsnorm(x) // Read x, write normalized
// output = linear(normalized) // Read normalized, write output
//
// Fused version: reads x once, applies both operations
template<int BLOCK_SIZE, int HIDDEN_SIZE>
__global__ void fused_rmsnorm_linear_kernel(
const __nv_bfloat16* __restrict__ input, // [batch, seq, hidden]
const __nv_bfloat16* __restrict__ norm_weight, // [hidden]
const __nv_bfloat16* __restrict__ linear_weight, // [out_features, hidden]
const __nv_bfloat16* __restrict__ linear_bias, // [out_features] or nullptr
__nv_bfloat16* __restrict__ output, // [batch, seq, out_features]
const int out_features,
const float epsilon
) {
// Implementation would combine RMSNorm and matrix-vector multiply
// This is most beneficial when out_features is small (e.g., QKV projection)
// ...
}
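Since the kernel body is elided, a host reference pins down the math the fusion must reproduce: RMSNorm over a row, then a dot product against one row of the linear weight. This is a hypothetical fp32 helper for validation, not the device implementation:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Host reference for the fused pattern: RMSNorm(x) . lin_w, for one
// input row and one linear-weight row. fp32 for clarity; the device
// kernel would do this in bf16 with an fp32 accumulator.
float rmsnorm_linear_row(const std::vector<float>& x,
                         const std::vector<float>& norm_w,
                         const std::vector<float>& lin_w,
                         float eps) {
    float ss = 0.0f;
    for (float v : x) ss += v * v;
    float inv_rms = 1.0f / std::sqrt(ss / x.size() + eps);
    float acc = 0.0f;
    for (std::size_t i = 0; i < x.size(); ++i)
        acc += (x[i] * inv_rms * norm_w[i]) * lin_w[i];
    return acc;
}
```

The fused kernel should match this reference to within bf16 rounding; checking against it is the cheapest way to catch a wrong `inv_rms` placement.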
Fused GEGLU
GEGLU splits the input tensor, applies GELU to one half, and multiplies:
__global__ void fused_geglu_kernel(
const __nv_bfloat16* __restrict__ input,
__nv_bfloat16* __restrict__ output,
const int batch_size,
const int seq_len,
const int hidden_size // Full hidden size (2x output)
) {
const int half_hidden = hidden_size / 2;
const int total_elements = batch_size * seq_len * half_hidden;
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
// Process 8 elements per thread for vectorized access
// (assumes half_hidden % 8 == 0 so a float4 never straddles a row)
const int vec_idx = idx * 8;
if (vec_idx + 7 < total_elements) {
int linear_idx = vec_idx;
int row = linear_idx / half_hidden;
int col = linear_idx % half_hidden;
// Load x and gate from interleaved or split layout
// (depends on model -- LTX-Video splits along last dim)
float4 x_packed = reinterpret_cast<const float4*>(
input + row * hidden_size + col
)[0];
float4 g_packed = reinterpret_cast<const float4*>(
input + row * hidden_size + half_hidden + col
)[0];
__nv_bfloat162* x_pairs = reinterpret_cast<__nv_bfloat162*>(&x_packed);
__nv_bfloat162* g_pairs = reinterpret_cast<__nv_bfloat162*>(&g_packed);
float4 result;
__nv_bfloat162* out_pairs = reinterpret_cast<__nv_bfloat162*>(&result);
#pragma unroll
for (int i = 0; i < 4; i++) {
float x0 = __bfloat162float(__low2bfloat16(x_pairs[i]));
float x1 = __bfloat162float(__high2bfloat16(x_pairs[i]));
float g0 = __bfloat162float(__low2bfloat16(g_pairs[i]));
float g1 = __bfloat162float(__high2bfloat16(g_pairs[i]));
// Apply GELU to gate
g0 = g0 * 0.5f * (1.0f + tanhf(0.7978845608f * (g0 + 0.044715f * g0 * g0 * g0)));
g1 = g1 * 0.5f * (1.0f + tanhf(0.7978845608f * (g1 + 0.044715f * g1 * g1 * g1)));
out_pairs[i] = __halves2bfloat162(
__float2bfloat16(x0 * g0),
__float2bfloat16(x1 * g1)
);
}
reinterpret_cast<float4*>(output + row * half_hidden + col)[0] = result;
}
}
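The per-element math buried in the unrolled loop is just the tanh-approximation GELU applied to the gate half, multiplied by the value half. A scalar reference isolates it for testing:

```cpp
#include <cmath>

// Scalar reference for the kernel above: tanh-approximation GELU
// (0.7978845608 = sqrt(2/pi)) applied to the gate, times the value.
float gelu_tanh(float g) {
    return g * 0.5f
         * (1.0f + std::tanh(0.7978845608f * (g + 0.044715f * g * g * g)));
}

float geglu(float x, float gate) { return x * gelu_tanh(gate); }
```

GELU is ~0 for large negative inputs and ~identity for large positive ones, so the gate smoothly interpolates between suppressing and passing the value half.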
Profiling with nsys and ncu
nsys (Nsight Systems) -- Timeline Profiling
# Basic timeline profile
nsys profile -o h100_trace python inference.py
# With detailed CUDA and NVTX tracing
nsys profile \
--trace=cuda,cudnn,cublas,nvtx,osrt \
--cuda-memory-usage=true \
--stats=true \
-o h100_detailed \
python inference.py
# Focus on specific time range
nsys profile \
--delay=5 --duration=10 \
-o h100_window \
python inference.py
ncu (Nsight Compute) -- Kernel Profiling
# Full kernel analysis
ncu --set full \
--target-processes all \
--launch-skip 10 --launch-count 5 \
-o h100_kernel_profile \
python inference.py
# Memory throughput analysis
ncu --metrics \
dram__bytes_read.sum,\
dram__bytes_write.sum,\
dram__throughput.avg.pct_of_peak_sustained_elapsed,\
l1tex__t_bytes_pipe_lsu_mem_global_op_ld.sum,\
lts__t_sector_hit_rate.pct \
python inference.py
# Occupancy analysis
ncu --metrics \
sm__warps_active.avg.pct_of_peak_sustained_active,\
sm__maximum_warps_per_active_cycle,\
launch__occupancy_limit_registers \
python inference.py
# Instruction mix analysis
ncu --metrics \
smsp__inst_executed.sum,\
smsp__inst_executed_pipe_fp32.sum,\
smsp__inst_executed_pipe_fp16.sum,\
smsp__inst_executed_pipe_tensor.sum \
python inference.py
Key Metrics to Monitor
| Metric | Target | Description |
|---|---|---|
| DRAM Throughput | > 70% of 3.35 TB/s | Memory bandwidth utilization |
| SM Occupancy | > 50% | Active warps vs maximum |
| L2 Hit Rate | > 50% | Data reuse in L2 cache |
| Shared Memory Throughput | > 60% | Shared memory bandwidth utilization |
| Achieved FLOPs | > 40% of peak | Compute utilization |
| Warp Stall Reasons | < 20% any single | Identify bottlenecks |
Profiling Workflow
- Start with nsys to get the big picture (timeline, kernel durations)
- Identify hotspot kernels from the nsys output
- Profile hotspots with ncu for detailed analysis
- Check memory throughput -- most DL kernels are memory-bound
- Check occupancy -- ensure you are not register/shared-memory limited
- Iterate: optimize, reprofile, compare
Summary
When optimizing for H100:
- Leverage 132 SMs -- size grids as multiples of 132
- Saturate 3.35 TB/s bandwidth -- vectorized 128-bit loads/stores are essential
- Use 192 KB shared memory for maximum data reuse in tiled algorithms
- Prefer BF16 for its wider dynamic range, matching modern model precision
- Use warp shuffles for intra-warp communication to avoid shared memory overhead
- Implement online softmax for memory-efficient attention
- Fuse operations (RMSNorm+linear, GEGLU) to reduce memory round-trips
- Profile with nsys then ncu to identify and fix bottlenecks
- Balance occupancy vs. register usage -- 50% occupancy is often sufficient for compute-bound kernels