# Diffusers Pipeline Integration Guide

## Overview

This guide covers the complete process of integrating custom CUDA kernels into HuggingFace Diffusers pipelines. It includes model architecture analysis, custom processor creation, kernel injection, and verification.

## Model Architecture Analysis

Before writing any kernels, you must understand the target model's architecture. The key operations to identify are:

1. **Normalization layers** (RMSNorm, LayerNorm, GroupNorm)
2. **Activation functions** (GELU, GEGLU, SiLU)
3. **Attention mechanisms** (self-attention, cross-attention)
4. **Linear projections** (QKV projections, feed-forward)
5. **Positional encodings** (RoPE, learned, sinusoidal)

### Inspecting a Diffusers Model

```python
from diffusers import LTXPipeline
import torch

pipe = LTXPipeline.from_pretrained(
    "Lightricks/LTX-Video",
    torch_dtype=torch.bfloat16
)

# List all module types
module_types = set()
for name, module in pipe.transformer.named_modules():
    module_types.add(type(module).__name__)

print("Module types found:")
for t in sorted(module_types):
    print(f" - {t}")

# Count occurrences of each type
from collections import Counter
type_counts = Counter(
    type(m).__name__ for _, m in pipe.transformer.named_modules()
)
print("\nModule counts:")
for name, count in type_counts.most_common():
    print(f" {name}: {count}")
```

### Identifying Optimization Targets

Look for operations that are:

1. **Frequently called** -- operations inside the main transformer blocks
2. **Memory-bound** -- normalization, activations, and element-wise operations
3. **Fusible** -- adjacent operations that can be combined into a single kernel

```python
# Trace the model to understand the call graph
# (the dummy inputs below live on the GPU, so pipe.transformer must already be on "cuda")
with torch.no_grad():
    # Create dummy inputs matching the model's expected shapes
    sample = torch.randn(1, 16, 32, 32, dtype=torch.bfloat16, device="cuda")
    timestep = torch.tensor([500.0], device="cuda")
    encoder_hidden_states = torch.randn(1, 77, 2048, dtype=torch.bfloat16, device="cuda")

    # Profile with PyTorch profiler
    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CUDA],
        record_shapes=True,
    ) as prof:
        output = pipe.transformer(sample, timestep, encoder_hidden_states)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
```

## LTX-Video Architecture

LTX-Video is a video generation model based on a Transformer architecture. Here is its structure:

```
LTXVideoTransformer3DModel
├── patch_embed              # Patch embedding (Conv3D)
├── time_embed               # Timestep embedding
├── transformer_blocks       # Main transformer blocks (N layers)
│   ├── norm1                # RMSNorm (pre-attention)
│   ├── attn1                # Self-attention
│   │   ├── to_q             # Linear (query projection)
│   │   ├── to_k             # Linear (key projection)
│   │   ├── to_v             # Linear (value projection)
│   │   ├── to_out[0]        # Linear (output projection)
│   │   └── processor        # Attention processor
│   ├── norm2                # RMSNorm (pre-FFN)
│   ├── ff                   # Feed-forward network
│   │   ├── net[0]           # GELU activation (NOT GEGLU!)
│   │   └── net[2]           # Linear (down projection)
│   ├── norm3                # RMSNorm (pre-cross-attention, if present)
│   └── attn2                # Cross-attention (if present)
├── norm_out                 # Final RMSNorm
└── proj_out                 # Output projection
```

### Key Observations for LTX-Video

| Feature | Detail |
|---|---|
| Normalization | RMSNorm (diffusers version) |
| Activation | GELU (not GEGLU!) |
| Attention | Standard scaled dot-product |
| Positional encoding | RoPE (Rotary Position Embedding) |
| Precision | BF16 preferred |
| Hidden sizes | Varies by model variant (1024, 2048, etc.) |

## Custom Attention Processor

Diffusers uses a processor pattern for attention. You can replace the default processor with a custom one that uses your optimized kernels.

### OptimizedLTXVideoAttnProcessor

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

# NOTE: `cuda_apply_rope` (the custom kernel binding) and `apply_rotary_emb`
# (the PyTorch fallback) are not defined in this guide; import them from your
# kernel package / reference implementation before using this class.


class OptimizedLTXVideoAttnProcessor:
    """
    Custom attention processor for LTX-Video that uses optimized CUDA kernels.

    This replaces the default attention computation with:
    1. Fused QKV projection (optional)
    2. RoPE application via custom kernel
    3. Flash Attention or custom attention kernel
    4. Fused output projection (optional)
    """

    def __init__(
        self,
        use_custom_rope: bool = True,
        use_custom_softmax: bool = False,
    ):
        self.use_custom_rope = use_custom_rope
        # Stored for configuration only; __call__ below never reads it because
        # the softmax happens inside F.scaled_dot_product_attention.
        self.use_custom_softmax = use_custom_softmax

    def __call__(
        self,
        attn,  # The Attention module
        hidden_states: torch.Tensor,  # [batch, seq_len, hidden_dim]
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Determine if this is self-attention or cross-attention
        is_cross_attention = encoder_hidden_states is not None
        input_states = encoder_hidden_states if is_cross_attention else hidden_states

        batch_size = hidden_states.shape[0]

        # QKV projections
        query = attn.to_q(hidden_states)
        key = attn.to_k(input_states)
        value = attn.to_v(input_states)

        # Optional query/key normalization. NOTE(review): diffusers appears to
        # configure LTX-Video attention with RMSNorm on the full projection
        # width (norm_q / norm_k) - confirm against the stock processor. The
        # getattr guard keeps this a no-op on modules without these attributes.
        norm_q = getattr(attn, "norm_q", None)
        norm_k = getattr(attn, "norm_k", None)
        if norm_q is not None:
            query = norm_q(query)
        if norm_k is not None:
            key = norm_k(key)

        # Reshape for multi-head attention.
        # The diffusers Attention module exposes `heads` but no `head_dim`
        # attribute, so derive the per-head width from the projection output.
        num_heads = attn.heads
        head_dim = query.shape[-1] // num_heads

        query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

        # Apply RoPE if provided
        if image_rotary_emb is not None:
            if self.use_custom_rope:
                # Use custom CUDA RoPE kernel
                query = cuda_apply_rope(query, image_rotary_emb)
                key = cuda_apply_rope(key, image_rotary_emb)
            else:
                # Fallback to PyTorch implementation
                query = apply_rotary_emb(query, image_rotary_emb)
                key = apply_rotary_emb(key, image_rotary_emb)

        # Scaled dot-product attention
        # PyTorch's SDPA will use Flash Attention on H100 when available
        attn_output = F.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
        )

        # Reshape back
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, num_heads * head_dim)

        # Output projection
        attn_output = attn.to_out[0](attn_output)

        # Dropout (if any)
        if len(attn.to_out) > 1:
            attn_output = attn.to_out[1](attn_output)

        return attn_output
```

### Setting the Processor

```python
from diffusers.models.attention_processor import AttnProcessor


def set_custom_attention_processors(model):
    """Replace attention processors in all transformer blocks.

    Each attention module gets its own processor through `set_processor`.
    Passing a dict to `model.set_attn_processor` is avoided on purpose: that
    API expects keys of the form "<module name>.processor", so a dict keyed
    by the bare module names is rejected.

    Returns:
        The same model, modified in place.
    """
    count = 0
    for _, module in model.named_modules():
        if hasattr(module, 'set_processor'):
            module.set_processor(
                OptimizedLTXVideoAttnProcessor(
                    use_custom_rope=True,
                    use_custom_softmax=False,
                )
            )
            count += 1

    print(f"Set {count} custom attention processors")
    return model
```

## RMSNorm Module Patcher

The RMSNorm patcher replaces the forward method of all RMSNorm modules with a custom CUDA implementation.

```python
import torch
import torch.nn as nn
from diffusers.models.normalization import RMSNorm
from functools import wraps


class RMSNormPatcher:
    """
    Patches RMSNorm modules in a model to use custom CUDA kernels.

    Original forwards are remembered by module name, so use one patcher
    instance per model.

    Usage:
        patcher = RMSNormPatcher()
        patcher.patch(model)
        # ... run inference ...
        patcher.unpatch(model)  # Restore original behavior
    """

    def __init__(self, cuda_rmsnorm_fn=None, cuda_rmsnorm_no_weight_fn=None):
        """
        Args:
            cuda_rmsnorm_fn: Custom CUDA function for weighted RMSNorm.
                Signature: (input: Tensor, weight: Tensor, eps: float) -> Tensor
            cuda_rmsnorm_no_weight_fn: Custom CUDA function for unweighted RMSNorm.
                Signature: (input: Tensor, eps: float) -> Tensor
        """
        self.cuda_rmsnorm = cuda_rmsnorm_fn
        self.cuda_rmsnorm_no_weight = cuda_rmsnorm_no_weight_fn
        self._original_forwards = {}

    def patch(self, model: nn.Module) -> int:
        """
        Patch all RMSNorm modules in the model.

        Modules that this patcher has already patched are skipped, so calling
        patch() twice never overwrites the saved original forward with a
        patched one (which would make unpatch() a no-op).

        Returns:
            Number of modules patched by this call.
        """
        count = 0
        for name, module in model.named_modules():
            if not isinstance(module, RMSNorm):
                continue
            if name in self._original_forwards:
                # Already patched; keep the genuine original we saved earlier.
                continue
            # Save original forward
            self._original_forwards[name] = module.forward
            # Create patched forward
            module.forward = self._make_patched_forward(module)
            count += 1
        return count

    def unpatch(self, model: nn.Module) -> int:
        """Restore original forward methods.

        Returns:
            Number of modules restored.
        """
        count = 0
        for name, module in model.named_modules():
            if name in self._original_forwards:
                module.forward = self._original_forwards[name]
                count += 1
        self._original_forwards.clear()
        return count

    def _make_patched_forward(self, module):
        """Create a patched forward function for an RMSNorm module."""
        cuda_fn = self.cuda_rmsnorm
        cuda_fn_no_weight = self.cuda_rmsnorm_no_weight

        def reference_rmsnorm(hidden_states: torch.Tensor) -> torch.Tensor:
            # Pure-PyTorch fallback. The mean of squares is accumulated in
            # float32 and the result cast back, so low-precision inputs
            # (BF16/FP16) do not lose accuracy in the reduction.
            input_dtype = hidden_states.dtype
            hidden_states = hidden_states.to(torch.float32)
            variance = hidden_states.pow(2).mean(-1, keepdim=True)
            normed = hidden_states * torch.rsqrt(variance + module.eps)
            return normed.to(input_dtype)

        def patched_forward(hidden_states: torch.Tensor) -> torch.Tensor:
            # Handle the case where weight is None
            if module.weight is None:
                if cuda_fn_no_weight is not None:
                    return cuda_fn_no_weight(hidden_states, module.eps)
                # Fallback: manual RMSNorm without weight
                return reference_rmsnorm(hidden_states)

            if cuda_fn is not None:
                return cuda_fn(hidden_states, module.weight, module.eps)
            # Fallback: manual RMSNorm with weight
            return reference_rmsnorm(hidden_states) * module.weight

        return patched_forward
```

## Kernel Injection Function

The main injection function ties everything together:

```python
import torch
import torch.nn as nn
from typing import Optional, Dict, Any

# NOTE: `cuda_rmsnorm`, `cuda_rmsnorm_no_weight` and `cuda_gelu` are the
# custom kernel bindings; they are not defined in this guide and must be
# imported from your compiled extension.


def inject_custom_kernels(
    model: nn.Module,
    kernel_config: Optional[Dict[str, Any]] = None,
) -> nn.Module:
    """
    Inject optimized CUDA kernels into a diffusers model.

    This function patches:
    1. RMSNorm layers -> custom CUDA RMSNorm
    2. Attention processors -> custom attention processor
    3. Activation functions -> custom CUDA activations

    Args:
        model: A diffusers model (e.g., pipe.transformer)
        kernel_config: Optional configuration dict with keys:
            - 'rmsnorm': bool (default True)
            - 'attention': bool (default True)
            - 'activation': bool (default True)
            - 'rope': bool (default True)

    Returns:
        The patched model (modified in place)
    """
    config = {
        'rmsnorm': True,
        'attention': True,
        'activation': True,
        'rope': True,
    }
    if kernel_config:
        config.update(kernel_config)

    stats = {'rmsnorm': 0, 'attention': 0, 'activation': 0}

    # Step 1: Patch RMSNorm
    if config['rmsnorm']:
        patcher = RMSNormPatcher(
            cuda_rmsnorm_fn=cuda_rmsnorm,
            cuda_rmsnorm_no_weight_fn=cuda_rmsnorm_no_weight,
        )
        stats['rmsnorm'] = patcher.patch(model)

    # Step 2: Patch attention processors
    if config['attention']:
        for name, module in model.named_modules():
            if hasattr(module, 'set_processor'):
                processor = OptimizedLTXVideoAttnProcessor(
                    use_custom_rope=config['rope'],
                )
                module.set_processor(processor)
                stats['attention'] += 1

    # Step 3: Patch activation functions
    if config['activation']:
        stats['activation'] = _patch_activations(model)

    print("Kernel injection complete:")
    print(f" RMSNorm layers patched: {stats['rmsnorm']}")
    print(f" Attention processors patched: {stats['attention']}")
    print(f" Activation functions patched: {stats['activation']}")

    return model


def _patch_activations(model: nn.Module) -> int:
    """Patch activation functions with custom CUDA kernels.

    Only `torch.nn.GELU` instances are matched.
    NOTE(review): diffusers feed-forward blocks appear to use their own GELU
    module rather than `nn.GELU`; if this returns 0 for your model, check the
    module types listed in "Inspecting a Diffusers Model" and target that
    class instead.

    Returns:
        Number of activation modules patched.
    """
    count = 0
    for _, module in model.named_modules():
        # Check for GELU activation
        if isinstance(module, nn.GELU):
            def forward(input):
                return cuda_gelu(input)

            module.forward = forward
            count += 1
    return count
```

## Model-Specific Differences

Different diffusion models have different architectures.
Here are the key differences:

### LTX-Video

```python
# Architecture: Transformer-based video model
# Normalization: RMSNorm (diffusers)
# Activation: GELU (plain, NOT GEGLU)
# Attention: Standard multi-head with RoPE
# Positional encoding: 3D RoPE (spatial + temporal)
# Weight: RMSNorm weight MAY be None

def inject_ltx_video(pipe):
    model = pipe.transformer
    inject_custom_kernels(model, {
        'rmsnorm': True,
        'attention': True,
        'activation': True,  # Patches GELU
        'rope': True,
    })
```

### Stable Diffusion 3 (SD3)

```python
# Architecture: MMDiT (Multi-Modal Diffusion Transformer)
# Normalization: RMSNorm and AdaLayerNorm
# Activation: GEGLU (gated, NOT plain GELU)
# Attention: Joint attention (text + image in same sequence)
# Positional encoding: Learned + RoPE

def inject_sd3(pipe):
    model = pipe.transformer
    inject_custom_kernels(model, {
        'rmsnorm': True,
        # OptimizedLTXVideoAttnProcessor implements standard (single-stream)
        # attention only. Installing it on joint-attention blocks would
        # replace the joint text+image computation, so leave the stock
        # processors in place until a joint-attention processor is written.
        'attention': False,
        # _patch_activations only matches nn.GELU modules; GEGLU layers are
        # left untouched until a dedicated GEGLU kernel is wired in.
        'activation': True,
        'rope': True,  # Only consulted when 'attention' is enabled
    })
    # SD3-specific: Also patch AdaLayerNorm if supported.
    # NOTE: _patch_adalayernorm is not defined in this guide; provide your
    # own implementation or remove this call.
    _patch_adalayernorm(model)
```

### FLUX

```python
# Architecture: Similar to SD3 (MMDiT variant)
# Normalization: RMSNorm
# Activation: GEGLU
# Attention: Joint attention with different block structure
# Positional encoding: RoPE
# Note: FLUX has single-stream and double-stream blocks

def inject_flux(pipe):
    model = pipe.transformer
    inject_custom_kernels(model, {
        'rmsnorm': True,
        # Joint attention: the LTX-specific processor does not apply here
        # (see the SD3 example), so keep the stock processors.
        'attention': False,
        # _patch_activations only matches nn.GELU modules; GEGLU layers are
        # left untouched until a dedicated GEGLU kernel is wired in.
        'activation': True,
        'rope': True,  # Only consulted when 'attention' is enabled
    })
```

### Comparison Table

| Feature | LTX-Video | SD3 | FLUX |
|---|---|---|---|
| Normalization | RMSNorm | RMSNorm + AdaLN | RMSNorm |
| Activation | GELU | GEGLU | GEGLU |
| Attention type | Standard | Joint (MMDiT) | Joint (MMDiT) |
| RoPE | 3D (spatial+temporal) | 2D (spatial) | 2D (spatial) |
| Weight may be None | Yes | Rare | No |
| `set_processor` | Yes | Yes | Yes |
| Block structure | Uniform | Uniform | Single + Double stream |
| Hidden sizes | 1024-2048 | 1536-4096 | 3072 |

## Verification Steps

After injecting kernels, verify correctness and
performance.

### Step 1: Numerical Correctness

```python
import torch


def verify_correctness(pipe, rtol=1e-2, atol=1e-3):
    """
    Verify that custom kernels produce numerically correct output.

    Runs a fresh, unpatched pipeline and the patched `pipe` with the same
    prompt and seed, then compares the resulting latents.

    NOTE: `load_pipeline` is not defined in this guide; it must return a
    fresh, unpatched copy of the pipeline under test.

    Returns:
        True if the two outputs agree within (rtol, atol).
    """
    device = "cuda"

    def run(p):
        with torch.no_grad():
            result = p(
                "a cat sitting on a mat",
                num_inference_steps=2,
                output_type="latent",
                generator=torch.Generator(device).manual_seed(42),
            )
        # Video pipelines such as LTXPipeline return `.frames`;
        # image pipelines return `.images`.
        return result.frames if hasattr(result, "frames") else result.images

    # Generate reference output without custom kernels
    torch.manual_seed(42)
    pipe_ref = load_pipeline()  # Fresh pipeline
    pipe_ref.to(device)
    ref_output = run(pipe_ref)

    # Generate output with custom kernels
    torch.manual_seed(42)
    inject_custom_kernels(pipe.transformer)
    test_output = run(pipe)

    # Compare
    max_diff = (ref_output - test_output).abs().max().item()
    mean_diff = (ref_output - test_output).abs().mean().item()

    print(f"Max absolute difference: {max_diff:.6f}")
    print(f"Mean absolute difference: {mean_diff:.6f}")

    is_close = torch.allclose(ref_output, test_output, rtol=rtol, atol=atol)
    print(f"Numerically close (rtol={rtol}, atol={atol}): {is_close}")

    return is_close
```

### Step 2: Module-Level Verification

```python
def verify_rmsnorm(hidden_size=2048, batch_size=4, eps=1e-6):
    """Verify custom RMSNorm against PyTorch reference."""
    from diffusers.models.normalization import RMSNorm

    # Create reference module
    ref_norm = RMSNorm(hidden_size, eps=eps).cuda().to(torch.bfloat16)

    # Test input
    x = torch.randn(batch_size, 128, hidden_size, dtype=torch.bfloat16, device="cuda")

    # Reference output
    with torch.no_grad():
        ref_out = ref_norm(x)

    # Custom kernel output
    with torch.no_grad():
        custom_out = cuda_rmsnorm(x, ref_norm.weight, eps)

    max_diff = (ref_out - custom_out).abs().max().item()
    print(f"RMSNorm max diff: {max_diff:.8f}")
    assert max_diff < 1e-2, f"RMSNorm verification failed: max_diff={max_diff}"
    print("RMSNorm verification PASSED")
```

### Step 3: Performance Verification

```python
import time
import torch


def verify_performance(pipe, num_runs=5, num_steps=20):
    """
    Verify that custom kernels provide a speedup.

    Returns:
        Average wall-clock time per run, in seconds. Compare it with the
        value measured on an unpatched pipeline to obtain the speedup.
    """
    prompt = "A beautiful sunset over the ocean, 4K, detailed"
    generator = torch.Generator("cuda").manual_seed(42)

    # Warmup
    pipe(prompt, num_inference_steps=2, generator=generator)

    # Benchmark
    times = []
    for _ in range(num_runs):
        # Synchronize before and after so the timing covers all queued GPU work
        torch.cuda.synchronize()
        start = time.perf_counter()

        pipe(
            prompt,
            num_inference_steps=num_steps,
            generator=torch.Generator("cuda").manual_seed(42),
        )

        torch.cuda.synchronize()
        end = time.perf_counter()
        times.append(end - start)

    avg_time = sum(times) / len(times)
    std_time = (sum((t - avg_time) ** 2 for t in times) / len(times)) ** 0.5

    print(f"Average time: {avg_time:.3f}s +/- {std_time:.3f}s")
    print(f"Per-step time: {avg_time / num_steps * 1000:.1f}ms")

    return avg_time
```

### Step 4: Edge Case Testing

```python
def test_edge_cases():
    """Test edge cases that commonly cause issues."""

    # Test 1: RMSNorm with None weight
    print("Test 1: RMSNorm with None weight...")
    x = torch.randn(2, 64, 2048, dtype=torch.bfloat16, device="cuda")
    try:
        result = cuda_rmsnorm_no_weight(x, 1e-6)
        assert result.shape == x.shape
        print(" PASSED")
    except Exception as e:
        print(f" FAILED: {e}")

    # Test 2: Non-contiguous input
    print("Test 2: Non-contiguous input...")
    x = torch.randn(2, 2048, 64, dtype=torch.bfloat16, device="cuda").transpose(1, 2)
    # After the transpose the last dimension is 2048, so the weight must match it
    w = torch.randn(2048, dtype=torch.bfloat16, device="cuda")
    assert not x.is_contiguous()
    try:
        result = cuda_rmsnorm(x.contiguous(), w, 1e-6)
        assert result.shape == x.shape
        print(" PASSED")
    except Exception as e:
        print(f" FAILED: {e}")

    # Test 3: Very small hidden size
    print("Test 3: Small hidden size (64)...")
    x = torch.randn(2, 128, 64, dtype=torch.bfloat16, device="cuda")
    w = torch.randn(64, dtype=torch.bfloat16, device="cuda")
    try:
        result = cuda_rmsnorm(x, w, 1e-6)
        assert result.shape == x.shape
        print(" PASSED")
    except Exception as e:
        print(f" FAILED: {e}")

    # Test 4: Very large hidden size
    print("Test 4: Large hidden size (8192)...")
    x = torch.randn(1, 32, 8192, dtype=torch.bfloat16, device="cuda")
    w = torch.randn(8192, dtype=torch.bfloat16, device="cuda")
    try:
        result = cuda_rmsnorm(x, w, 1e-6)
        assert result.shape == x.shape
        print(" PASSED")
    except Exception as e:
        print(f" FAILED: {e}")

    # Test 5: Single element batch
    print("Test 5: Single element batch...")
    x = torch.randn(1, 1, 2048, dtype=torch.bfloat16, device="cuda")
    w = torch.randn(2048, dtype=torch.bfloat16, device="cuda")
    try:
        result = cuda_rmsnorm(x, w, 1e-6)
        assert result.shape == x.shape
        print(" PASSED")
    except Exception as e:
        print(f" FAILED: {e}")

    print("\nAll edge case tests complete.")
```

## Complete Integration Example

```python
import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video

# NOTE: `inject_custom_kernels` is the function from the
# "Kernel Injection Function" section; import it from wherever you placed it.


def main():
    # Load pipeline
    pipe = LTXPipeline.from_pretrained(
        "Lightricks/LTX-Video",
        torch_dtype=torch.bfloat16,
    )

    # Step 1: Inject custom kernels BEFORE moving to device or enabling offloading
    inject_custom_kernels(pipe.transformer)

    # Step 2: Move to device or enable offloading
    pipe.enable_model_cpu_offload()

    # Step 3: Optionally compile for additional speedup
    pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead")

    # Step 4: Run inference
    output = pipe(
        "A time-lapse of a flower blooming",
        num_inference_steps=30,
        num_frames=16,
    )

    # Step 5: Save output
    # output.frames[0] is the list of frames of the first generated video;
    # a list has no .save(), so encode it with diffusers' video export helper.
    export_to_video(output.frames[0], "output.mp4", fps=24)


if __name__ == "__main__":
    main()
```