# Diffusers Pipeline Integration Guide
## Overview
This guide covers the complete process of integrating custom CUDA kernels into HuggingFace Diffusers pipelines. It includes model architecture analysis, custom processor creation, kernel injection, and verification.
## Model Architecture Analysis
Before writing any kernels, you must understand the target model's architecture. The key operations to identify are:
1. **Normalization layers** (RMSNorm, LayerNorm, GroupNorm)
2. **Activation functions** (GELU, GEGLU, SiLU)
3. **Attention mechanisms** (self-attention, cross-attention)
4. **Linear projections** (QKV projections, feed-forward)
5. **Positional encodings** (RoPE, learned, sinusoidal)
### Inspecting a Diffusers Model
```python
from diffusers import LTXPipeline
import torch
pipe = LTXPipeline.from_pretrained(
    "Lightricks/LTX-Video",
    torch_dtype=torch.bfloat16
)

# List all module types
module_types = set()
for name, module in pipe.transformer.named_modules():
    module_types.add(type(module).__name__)

print("Module types found:")
for t in sorted(module_types):
    print(f"  - {t}")

# Count occurrences of each type
from collections import Counter

type_counts = Counter(
    type(m).__name__ for _, m in pipe.transformer.named_modules()
)
print("\nModule counts:")
for name, count in type_counts.most_common():
    print(f"  {name}: {count}")
```
### Identifying Optimization Targets
Look for operations that are:
1. **Frequently called** -- operations inside the main transformer blocks
2. **Memory-bound** -- normalization, activations, and element-wise operations
3. **Fusible** -- adjacent operations that can be combined into a single kernel
```python
# Trace the model to understand the call graph
with torch.no_grad():
    # Create dummy inputs matching the model's expected shapes
    sample = torch.randn(1, 16, 32, 32, dtype=torch.bfloat16, device="cuda")
    timestep = torch.tensor([500.0], device="cuda")
    encoder_hidden_states = torch.randn(1, 77, 2048, dtype=torch.bfloat16, device="cuda")

    # Profile with PyTorch profiler
    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CUDA],
        record_shapes=True,
    ) as prof:
        output = pipe.transformer(sample, timestep, encoder_hidden_states)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
```
## LTX-Video Architecture
LTX-Video is a video generation model based on a Transformer architecture. Here is its structure:
```
LTXVideoTransformer3DModel
β”œβ”€β”€ patch_embed # Patch embedding (Conv3D)
β”œβ”€β”€ time_embed # Timestep embedding
β”œβ”€β”€ transformer_blocks # Main transformer blocks (N layers)
β”‚ β”œβ”€β”€ norm1 # RMSNorm (pre-attention)
β”‚ β”œβ”€β”€ attn1 # Self-attention
β”‚ β”‚ β”œβ”€β”€ to_q # Linear (query projection)
β”‚ β”‚ β”œβ”€β”€ to_k # Linear (key projection)
β”‚ β”‚ β”œβ”€β”€ to_v # Linear (value projection)
β”‚ β”‚ β”œβ”€β”€ to_out[0] # Linear (output projection)
β”‚ β”‚ └── processor # Attention processor
β”‚ β”œβ”€β”€ norm2 # RMSNorm (pre-FFN)
β”‚ β”œβ”€β”€ ff # Feed-forward network
β”‚ β”‚ β”œβ”€β”€ net[0] # GELU activation (NOT GEGLU!)
β”‚ β”‚ └── net[2] # Linear (down projection)
β”‚ β”œβ”€β”€ norm3 # RMSNorm (pre-cross-attention, if present)
β”‚ └── attn2 # Cross-attention (if present)
β”œβ”€β”€ norm_out # Final RMSNorm
└── proj_out # Output projection
```
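The tree above can be spot-checked programmatically with `nn.Module.get_submodule`. A minimal sketch (the helper name `check_submodule_paths` is ours, and the stand-in model below only mimics part of the real structure; run it against `pipe.transformer` with paths from the tree above):

```python
import torch.nn as nn

def check_submodule_paths(model: nn.Module, paths):
    """Return the subset of dotted module paths that exist on `model`."""
    found = []
    for path in paths:
        try:
            model.get_submodule(path)
            found.append(path)
        except AttributeError:
            # get_submodule raises AttributeError for missing paths
            pass
    return found

# Demo with a tiny stand-in model mirroring part of the tree above
model = nn.Module()
model.transformer_blocks = nn.ModuleList([nn.Module()])
model.transformer_blocks[0].norm1 = nn.LayerNorm(16)

present = check_submodule_paths(
    model, ["transformer_blocks.0.norm1", "transformer_blocks.0.attn1"]
)
print(present)  # ['transformer_blocks.0.norm1']
```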
### Key Observations for LTX-Video
| Feature | Detail |
|---|---|
| Normalization | RMSNorm (diffusers version) |
| Activation | GELU (not GEGLU!) |
| Attention | Standard scaled dot-product |
| Positional encoding | RoPE (Rotary Position Embedding) |
| Precision | BF16 preferred |
| Hidden sizes | Varies by model variant (1024, 2048, etc.) |
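As a numerical baseline for the kernels discussed below, the diffusers-style RMSNorm is a few lines of plain PyTorch. This is a reference sketch (the name `rmsnorm_reference` is ours, not a diffusers API); it mirrors the variance-then-rsqrt formulation used in the fallback paths later in this guide:

```python
import torch

def rmsnorm_reference(x: torch.Tensor, weight=None, eps: float = 1e-6) -> torch.Tensor:
    """RMSNorm over the last dim: x / sqrt(mean(x^2) + eps), optionally scaled by weight."""
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    if weight is not None:
        # Note: for LTX-Video, weight MAY be None (see table above)
        x = x * weight
    return x

x = torch.randn(2, 4, 8)
out = rmsnorm_reference(x, weight=torch.ones(8))
print(out.shape)  # torch.Size([2, 4, 8])
```

A custom CUDA kernel should match this reference to within BF16 tolerance before it is injected.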
## Custom Attention Processor
Diffusers uses a processor pattern for attention. You can replace the default processor with a custom one that uses your optimized kernels.
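Before swapping processors in, it helps to see which modules accept one. A small sketch (the helper name `find_processor_modules` is ours) using the same `set_processor` attribute check the injection code below relies on; the dummy class stands in for a real diffusers `Attention` module:

```python
import torch.nn as nn

def find_processor_modules(model: nn.Module):
    """List names of submodules that expose a `set_processor` method."""
    return [
        name for name, module in model.named_modules()
        if hasattr(module, "set_processor")
    ]

# Demo with a dummy module standing in for a diffusers Attention layer
class DummyAttention(nn.Module):
    def set_processor(self, processor):
        self.processor = processor

block = nn.Module()
block.attn1 = DummyAttention()
block.norm1 = nn.LayerNorm(8)
print(find_processor_modules(block))  # ['attn1']
```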
### OptimizedLTXVideoAttnProcessor
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
class OptimizedLTXVideoAttnProcessor:
    """
    Custom attention processor for LTX-Video that uses optimized CUDA kernels.

    This replaces the default attention computation with:
    1. Fused QKV projection (optional)
    2. RoPE application via custom kernel
    3. Flash Attention or custom attention kernel
    4. Fused output projection (optional)
    """

    def __init__(
        self,
        use_custom_rope: bool = True,
        use_custom_softmax: bool = False,
    ):
        self.use_custom_rope = use_custom_rope
        self.use_custom_softmax = use_custom_softmax

    def __call__(
        self,
        attn,  # The Attention module
        hidden_states: torch.Tensor,  # [batch, seq_len, hidden_dim]
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Determine if this is self-attention or cross-attention
        is_cross_attention = encoder_hidden_states is not None
        input_states = encoder_hidden_states if is_cross_attention else hidden_states
        batch_size, seq_len, _ = hidden_states.shape

        # QKV projections
        query = attn.to_q(hidden_states)
        key = attn.to_k(input_states)
        value = attn.to_v(input_states)

        # Reshape for multi-head attention
        # (derive head_dim from the projection width; Attention modules do not
        # reliably expose a head_dim attribute across diffusers versions)
        num_heads = attn.heads
        head_dim = query.shape[-1] // num_heads
        query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

        # Apply RoPE if provided
        if image_rotary_emb is not None:
            if self.use_custom_rope:
                # Use custom CUDA RoPE kernel
                query = cuda_apply_rope(query, image_rotary_emb)
                key = cuda_apply_rope(key, image_rotary_emb)
            else:
                # Fallback to PyTorch implementation
                query = apply_rotary_emb(query, image_rotary_emb)
                key = apply_rotary_emb(key, image_rotary_emb)

        # Scaled dot-product attention
        # PyTorch's SDPA will use Flash Attention on H100 when available
        attn_output = F.scaled_dot_product_attention(
            query, key, value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
        )

        # Reshape back
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, num_heads * head_dim)

        # Output projection
        attn_output = attn.to_out[0](attn_output)
        # Dropout (if any)
        if len(attn.to_out) > 1:
            attn_output = attn.to_out[1](attn_output)

        return attn_output
```
### Setting the Processor
```python
def set_custom_attention_processors(model):
    """Replace attention processors in all attention modules."""
    count = 0
    for name, module in model.named_modules():
        if hasattr(module, 'set_processor'):
            # Set the processor directly on each module; set_attn_processor
            # would require dict keys ending in ".processor", not module names
            module.set_processor(OptimizedLTXVideoAttnProcessor(
                use_custom_rope=True,
                use_custom_softmax=False,
            ))
            count += 1
    print(f"Set {count} custom attention processors")
    return model
```
## RMSNorm Module Patcher
The RMSNorm patcher replaces the forward method of all RMSNorm modules with a custom CUDA implementation.
```python
import torch
import torch.nn as nn
from diffusers.models.normalization import RMSNorm

class RMSNormPatcher:
    """
    Patches RMSNorm modules in a model to use custom CUDA kernels.

    Usage:
        patcher = RMSNormPatcher()
        patcher.patch(model)
        # ... run inference ...
        patcher.unpatch(model)  # Restore original behavior
    """

    def __init__(self, cuda_rmsnorm_fn=None, cuda_rmsnorm_no_weight_fn=None):
        """
        Args:
            cuda_rmsnorm_fn: Custom CUDA function for weighted RMSNorm.
                Signature: (input: Tensor, weight: Tensor, eps: float) -> Tensor
            cuda_rmsnorm_no_weight_fn: Custom CUDA function for unweighted RMSNorm.
                Signature: (input: Tensor, eps: float) -> Tensor
        """
        self.cuda_rmsnorm = cuda_rmsnorm_fn
        self.cuda_rmsnorm_no_weight = cuda_rmsnorm_no_weight_fn
        self._original_forwards = {}

    def patch(self, model: nn.Module) -> int:
        """
        Patch all RMSNorm modules in the model.

        Returns:
            Number of modules patched.
        """
        count = 0
        for name, module in model.named_modules():
            if isinstance(module, RMSNorm):
                # Save original forward
                self._original_forwards[name] = module.forward
                # Create patched forward
                module.forward = self._make_patched_forward(module)
                count += 1
        return count

    def unpatch(self, model: nn.Module) -> int:
        """Restore original forward methods."""
        count = 0
        for name, module in model.named_modules():
            if name in self._original_forwards:
                module.forward = self._original_forwards[name]
                count += 1
        self._original_forwards.clear()
        return count

    def _make_patched_forward(self, module):
        """Create a patched forward function for an RMSNorm module."""
        cuda_fn = self.cuda_rmsnorm
        cuda_fn_no_weight = self.cuda_rmsnorm_no_weight

        def patched_forward(hidden_states: torch.Tensor) -> torch.Tensor:
            # Handle the case where weight is None
            if module.weight is None:
                if cuda_fn_no_weight is not None:
                    return cuda_fn_no_weight(hidden_states, module.eps)
                else:
                    # Fallback: manual RMSNorm without weight
                    variance = hidden_states.pow(2).mean(-1, keepdim=True)
                    return hidden_states * torch.rsqrt(variance + module.eps)
            else:
                if cuda_fn is not None:
                    return cuda_fn(hidden_states, module.weight, module.eps)
                else:
                    # Fallback: manual RMSNorm with weight
                    variance = hidden_states.pow(2).mean(-1, keepdim=True)
                    hidden_states = hidden_states * torch.rsqrt(variance + module.eps)
                    return hidden_states * module.weight

        return patched_forward
```
## Kernel Injection Function
The main injection function ties everything together:
```python
import torch
import torch.nn as nn
from typing import Optional, Dict, Any

def inject_custom_kernels(
    model: nn.Module,
    kernel_config: Optional[Dict[str, Any]] = None,
) -> nn.Module:
    """
    Inject optimized CUDA kernels into a diffusers model.

    This function patches:
    1. RMSNorm layers -> custom CUDA RMSNorm
    2. Attention processors -> custom attention processor
    3. Activation functions -> custom CUDA activations

    Args:
        model: A diffusers model (e.g., pipe.transformer)
        kernel_config: Optional configuration dict with keys:
            - 'rmsnorm': bool (default True)
            - 'attention': bool (default True)
            - 'activation': bool (default True)
            - 'rope': bool (default True)

    Returns:
        The patched model (modified in place)
    """
    config = {
        'rmsnorm': True,
        'attention': True,
        'activation': True,
        'rope': True,
    }
    if kernel_config:
        config.update(kernel_config)

    stats = {'rmsnorm': 0, 'attention': 0, 'activation': 0}

    # Step 1: Patch RMSNorm
    if config['rmsnorm']:
        patcher = RMSNormPatcher(
            cuda_rmsnorm_fn=cuda_rmsnorm,
            cuda_rmsnorm_no_weight_fn=cuda_rmsnorm_no_weight,
        )
        stats['rmsnorm'] = patcher.patch(model)

    # Step 2: Patch attention processors
    if config['attention']:
        for name, module in model.named_modules():
            if hasattr(module, 'set_processor'):
                processor = OptimizedLTXVideoAttnProcessor(
                    use_custom_rope=config['rope'],
                )
                module.set_processor(processor)
                stats['attention'] += 1

    # Step 3: Patch activation functions
    if config['activation']:
        stats['activation'] = _patch_activations(model)

    print("Kernel injection complete:")
    print(f"  RMSNorm layers patched: {stats['rmsnorm']}")
    print(f"  Attention processors patched: {stats['attention']}")
    print(f"  Activation functions patched: {stats['activation']}")
    return model

def _patch_activations(model: nn.Module) -> int:
    """Patch activation functions with custom CUDA kernels."""
    count = 0
    for name, module in model.named_modules():
        # Check for GELU activation; replace forward with the custom kernel
        if isinstance(module, nn.GELU):
            module.forward = cuda_gelu
            count += 1
    return count
```
## Model-Specific Differences
Different diffusion models have different architectures. Here are the key differences:
### LTX-Video
```python
# Architecture: Transformer-based video model
# Normalization: RMSNorm (diffusers)
# Activation: GELU (plain, NOT GEGLU)
# Attention: Standard multi-head with RoPE
# Positional encoding: 3D RoPE (spatial + temporal)
# Weight: RMSNorm weight MAY be None

def inject_ltx_video(pipe):
    model = pipe.transformer
    inject_custom_kernels(model, {
        'rmsnorm': True,
        'attention': True,
        'activation': True,  # Patches GELU
        'rope': True,
    })
```
### Stable Diffusion 3 (SD3)
```python
# Architecture: MMDiT (Multi-Modal Diffusion Transformer)
# Normalization: RMSNorm and AdaLayerNorm
# Activation: GEGLU (gated, NOT plain GELU)
# Attention: Joint attention (text + image in same sequence)
# Positional encoding: Learned + RoPE

def inject_sd3(pipe):
    model = pipe.transformer
    inject_custom_kernels(model, {
        'rmsnorm': True,
        'attention': True,
        'activation': True,  # Patches GEGLU
        'rope': True,
    })
    # SD3-specific: Also patch AdaLayerNorm if supported
    _patch_adalayernorm(model)
```
### FLUX
```python
# Architecture: Similar to SD3 (MMDiT variant)
# Normalization: RMSNorm
# Activation: GEGLU
# Attention: Joint attention with different block structure
# Positional encoding: RoPE
# Note: FLUX has single-stream and double-stream blocks

def inject_flux(pipe):
    model = pipe.transformer
    inject_custom_kernels(model, {
        'rmsnorm': True,
        'attention': True,
        'activation': True,  # Patches GEGLU
        'rope': True,
    })
```
### Comparison Table
| Feature | LTX-Video | SD3 | FLUX |
|---|---|---|---|
| Normalization | RMSNorm | RMSNorm + AdaLN | RMSNorm |
| Activation | GELU | GEGLU | GEGLU |
| Attention type | Standard | Joint (MMDiT) | Joint (MMDiT) |
| RoPE | 3D (spatial+temporal) | 2D (spatial) | 2D (spatial) |
| Weight may be None | Yes | Rare | No |
| `set_processor` | Yes | Yes | Yes |
| Block structure | Uniform | Uniform | Single + Double stream |
| Hidden sizes | 1024-2048 | 1536-4096 | 3072 |
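Given these differences, a thin dispatcher can pick the right injection routine from the pipeline's class name. A sketch, assuming the `inject_*` helpers from the snippets above are in scope (the registry contents and fallback behavior here are illustrative):

```python
def inject_for_pipeline(pipe, registry=None):
    """Dispatch to a model-specific injection routine by pipeline class name."""
    # Default registry built from the helpers defined above; pass your own
    # mapping to cover additional pipelines
    registry = registry or {
        "LTXPipeline": inject_ltx_video,
        "StableDiffusion3Pipeline": inject_sd3,
        "FluxPipeline": inject_flux,
    }
    name = type(pipe).__name__
    if name not in registry:
        raise ValueError(f"No kernel injection registered for {name}")
    return registry[name](pipe)
```

Usage: `inject_for_pipeline(pipe)` after loading, or `inject_for_pipeline(pipe, registry={...})` to override the mapping.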
## Verification Steps
After injecting kernels, verify correctness and performance.
### Step 1: Numerical Correctness
```python
import torch

def verify_correctness(pipe, rtol=1e-2, atol=1e-3):
    """
    Verify that custom kernels produce numerically correct output.
    """
    device = "cuda"
    dtype = torch.bfloat16

    # Generate reference output without custom kernels
    torch.manual_seed(42)
    pipe_ref = load_pipeline()  # Fresh pipeline
    pipe_ref.to(device)
    with torch.no_grad():
        ref_output = pipe_ref(
            "a cat sitting on a mat",
            num_inference_steps=2,
            output_type="latent",
            generator=torch.Generator(device).manual_seed(42),
        ).images

    # Generate output with custom kernels
    torch.manual_seed(42)
    inject_custom_kernels(pipe.transformer)
    with torch.no_grad():
        test_output = pipe(
            "a cat sitting on a mat",
            num_inference_steps=2,
            output_type="latent",
            generator=torch.Generator(device).manual_seed(42),
        ).images

    # Compare
    max_diff = (ref_output - test_output).abs().max().item()
    mean_diff = (ref_output - test_output).abs().mean().item()
    print(f"Max absolute difference: {max_diff:.6f}")
    print(f"Mean absolute difference: {mean_diff:.6f}")

    is_close = torch.allclose(ref_output, test_output, rtol=rtol, atol=atol)
    print(f"Numerically close (rtol={rtol}, atol={atol}): {is_close}")
    return is_close
```
### Step 2: Module-Level Verification
```python
def verify_rmsnorm(hidden_size=2048, batch_size=4, eps=1e-6):
    """Verify custom RMSNorm against PyTorch reference."""
    from diffusers.models.normalization import RMSNorm

    # Create reference module
    ref_norm = RMSNorm(hidden_size, eps=eps).cuda().to(torch.bfloat16)

    # Test input
    x = torch.randn(batch_size, 128, hidden_size, dtype=torch.bfloat16, device="cuda")

    # Reference output
    with torch.no_grad():
        ref_out = ref_norm(x)

    # Custom kernel output
    with torch.no_grad():
        custom_out = cuda_rmsnorm(x, ref_norm.weight, eps)

    max_diff = (ref_out - custom_out).abs().max().item()
    print(f"RMSNorm max diff: {max_diff:.8f}")
    assert max_diff < 1e-2, f"RMSNorm verification failed: max_diff={max_diff}"
    print("RMSNorm verification PASSED")
```
### Step 3: Performance Verification
```python
import time
import torch

def verify_performance(pipe, num_runs=5, num_steps=20):
    """
    Verify that custom kernels provide a speedup.
    """
    prompt = "A beautiful sunset over the ocean, 4K, detailed"
    generator = torch.Generator("cuda").manual_seed(42)

    # Warmup
    pipe(prompt, num_inference_steps=2, generator=generator)

    # Benchmark
    times = []
    for i in range(num_runs):
        torch.cuda.synchronize()
        start = time.perf_counter()
        pipe(
            prompt,
            num_inference_steps=num_steps,
            generator=torch.Generator("cuda").manual_seed(42),
        )
        torch.cuda.synchronize()
        end = time.perf_counter()
        times.append(end - start)

    avg_time = sum(times) / len(times)
    std_time = (sum((t - avg_time) ** 2 for t in times) / len(times)) ** 0.5
    print(f"Average time: {avg_time:.3f}s +/- {std_time:.3f}s")
    print(f"Per-step time: {avg_time / num_steps * 1000:.1f}ms")
    return avg_time
```
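To turn timings from `verify_performance` into a speedup figure, run it once on a baseline pipeline and once on a patched one, then compare. A minimal sketch (the helper name `summarize_speedup` is ours):

```python
from statistics import mean, pstdev

def summarize_speedup(baseline_times, optimized_times):
    """Report mean times and the speedup factor from two timing lists (seconds)."""
    base, opt = mean(baseline_times), mean(optimized_times)
    speedup = base / opt
    print(f"Baseline:  {base:.3f}s +/- {pstdev(baseline_times):.3f}s")
    print(f"Optimized: {opt:.3f}s +/- {pstdev(optimized_times):.3f}s")
    print(f"Speedup:   {speedup:.2f}x")
    return speedup

# Example with made-up timings
speedup = summarize_speedup([2.0, 2.1, 1.9], [1.0, 1.1, 0.9])
print(round(speedup, 2))  # 2.0
```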
### Step 4: Edge Case Testing
```python
def test_edge_cases():
    """Test edge cases that commonly cause issues."""
    # Test 1: RMSNorm with None weight
    print("Test 1: RMSNorm with None weight...")
    x = torch.randn(2, 64, 2048, dtype=torch.bfloat16, device="cuda")
    try:
        result = cuda_rmsnorm_no_weight(x, 1e-6)
        assert result.shape == x.shape
        print("  PASSED")
    except Exception as e:
        print(f"  FAILED: {e}")

    # Test 2: Non-contiguous input
    print("Test 2: Non-contiguous input...")
    x = torch.randn(2, 2048, 64, dtype=torch.bfloat16, device="cuda").transpose(1, 2)
    w = torch.randn(2048, dtype=torch.bfloat16, device="cuda")
    assert not x.is_contiguous()
    try:
        result = cuda_rmsnorm(x.contiguous(), w, 1e-6)
        print("  PASSED")
    except Exception as e:
        print(f"  FAILED: {e}")

    # Test 3: Very small hidden size
    print("Test 3: Small hidden size (64)...")
    x = torch.randn(2, 128, 64, dtype=torch.bfloat16, device="cuda")
    w = torch.randn(64, dtype=torch.bfloat16, device="cuda")
    try:
        result = cuda_rmsnorm(x, w, 1e-6)
        assert result.shape == x.shape
        print("  PASSED")
    except Exception as e:
        print(f"  FAILED: {e}")

    # Test 4: Very large hidden size
    print("Test 4: Large hidden size (8192)...")
    x = torch.randn(1, 32, 8192, dtype=torch.bfloat16, device="cuda")
    w = torch.randn(8192, dtype=torch.bfloat16, device="cuda")
    try:
        result = cuda_rmsnorm(x, w, 1e-6)
        assert result.shape == x.shape
        print("  PASSED")
    except Exception as e:
        print(f"  FAILED: {e}")

    # Test 5: Single element batch
    print("Test 5: Single element batch...")
    x = torch.randn(1, 1, 2048, dtype=torch.bfloat16, device="cuda")
    w = torch.randn(2048, dtype=torch.bfloat16, device="cuda")
    try:
        result = cuda_rmsnorm(x, w, 1e-6)
        assert result.shape == x.shape
        print("  PASSED")
    except Exception as e:
        print(f"  FAILED: {e}")

    print("\nAll edge case tests complete.")
```
## Complete Integration Example
```python
import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video

def main():
    # Load pipeline
    pipe = LTXPipeline.from_pretrained(
        "Lightricks/LTX-Video",
        torch_dtype=torch.bfloat16,
    )

    # Step 1: Inject custom kernels BEFORE moving to device or enabling offloading
    inject_custom_kernels(pipe.transformer)

    # Step 2: Move to device or enable offloading
    pipe.enable_model_cpu_offload()

    # Step 3: Optionally compile for additional speedup
    pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead")

    # Step 4: Run inference
    output = pipe(
        "A time-lapse of a flower blooming",
        num_inference_steps=30,
        num_frames=16,
    )

    # Step 5: Save output (frames are a list of images, not a video object)
    export_to_video(output.frames[0], "output.mp4", fps=8)

if __name__ == "__main__":
    main()
```