#!/usr/bin/env python3
"""
Minimal example: Inject custom CUDA kernels into LTX-Video pipeline.

This script demonstrates the essential pattern for integrating custom CUDA
kernels with diffusers pipelines. For full usage, see
examples/ltx_video/generate_video.py.

Key lessons:
1. Check if RMSNorm has weight (elementwise_affine may be False)
2. Use type(module).__name__ not isinstance() to detect diffusers modules
3. LTX-Video uses GELU, not GEGLU - check your target model
4. Inject kernels AFTER loading to CUDA, BEFORE CPU offloading

Usage:
    cd examples/ltx_video
    uv pip install -e .  # Build kernels first
    python ../../.claude/skills/h100-diffusers-kernels/references/ltx_kernel_injection_example.py
"""

import sys
from typing import Optional, Tuple

import torch
import torch.nn as nn

# Make the locally built kernel extension importable before the import below.
sys.path.insert(0, "torch-ext")

from ltx_kernels import rmsnorm


class OptimizedLTXVideoAttnProcessor:
    """
    Attention processor using custom rmsnorm kernel for Q/K normalization.

    NOTE: attn.norm_q and attn.norm_k HAVE weights (elementwise_affine=True).
    """

    def __call__(
        self,
        attn,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Run LTX-Video attention, replacing Q/K RMSNorm with the custom kernel.

        Mirrors diffusers' LTX attention processor: project Q/K/V, RMS-normalize
        Q and K (custom kernel), apply rotary embeddings, split heads, dispatch
        attention, then merge heads and apply the output projection.

        Args:
            attn: The diffusers Attention module being processed (supplies
                to_q/to_k/to_v, norm_q/norm_k, heads, to_out, and
                prepare_attention_mask).
            hidden_states: Input sequence, shape (batch, seq, inner_dim).
            encoder_hidden_states: Optional cross-attention context; defaults
                to hidden_states (self-attention).
            attention_mask: Optional mask, reshaped per head before dispatch.
            image_rotary_emb: Optional (cos, sin) rotary embedding tensors.

        Returns:
            Attention output with the same (batch, seq, inner_dim) layout.
        """
        # Deferred imports: diffusers internals, only needed when called.
        from diffusers.models.transformers.transformer_ltx import apply_rotary_emb
        from diffusers.models.attention_dispatch import dispatch_attention_fn

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Custom CUDA kernel replaces attn.norm_q / attn.norm_k forward.
        # Both norms carry weights in LTX-Video (elementwise_affine=True).
        query = rmsnorm(query, attn.norm_q.weight, eps=attn.norm_q.eps)
        key = rmsnorm(key, attn.norm_k.weight, eps=attn.norm_k.eps)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        # (batch, seq, inner) -> (batch, seq, heads, head_dim)
        query = query.unflatten(2, (attn.heads, -1))
        key = key.unflatten(2, (attn.heads, -1))
        value = value.unflatten(2, (attn.heads, -1))

        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
        )
        # Merge heads back and restore query dtype (backend may upcast).
        hidden_states = hidden_states.flatten(2, 3).to(query.dtype)

        hidden_states = attn.to_out[0](hidden_states)  # output projection
        hidden_states = attn.to_out[1](hidden_states)  # dropout
        return hidden_states


def patch_rmsnorm_modules(model: nn.Module) -> int:
    """Patch all RMSNorm modules in `model` to use the custom CUDA kernel.

    Detection is by class name ('RMSNorm') rather than isinstance(), so it
    works across diffusers versions. Each module's bound ``forward`` is
    replaced by a closure that calls the custom ``rmsnorm`` kernel. Closure
    factories are used so the loop variables are captured by value, not by
    late binding.

    Args:
        model: Root module to scan (e.g. the transformer).

    Returns:
        Number of RMSNorm modules patched.
    """
    patched_count = 0
    for module in model.modules():
        if type(module).__name__ == 'RMSNorm':
            eps = getattr(module, 'eps', 1e-6)
            # elementwise_affine=False leaves weight as None — handle both.
            has_weight = hasattr(module, 'weight') and module.weight is not None

            if has_weight:
                def make_patched_forward_with_weight(mod, epsilon):
                    def patched_forward(x):
                        return rmsnorm(x, mod.weight, eps=epsilon)
                    return patched_forward

                module.forward = make_patched_forward_with_weight(module, eps)
            else:
                def make_patched_forward_no_weight(epsilon):
                    def patched_forward(x):
                        # Kernel requires a weight tensor; synthesize identity
                        # weights on each call (input device/dtype may vary).
                        weight = torch.ones(
                            x.shape[-1], device=x.device, dtype=x.dtype
                        )
                        return rmsnorm(x, weight, eps=epsilon)
                    return patched_forward

                module.forward = make_patched_forward_no_weight(eps)

            patched_count += 1
    return patched_count


def inject_optimized_kernels(pipe) -> dict:
    """Inject custom CUDA kernels into the LTX-Video pipeline.

    Replaces every attention processor on the transformer with
    OptimizedLTXVideoAttnProcessor and patches all RMSNorm modules.

    Args:
        pipe: A diffusers pipeline expected to expose ``pipe.transformer``.

    Returns:
        Dict with counts: {'attention_processors': int, 'rmsnorm_modules': int}.
    """
    stats = {'attention_processors': 0, 'rmsnorm_modules': 0}

    if not hasattr(pipe, 'transformer'):
        print("WARNING: Pipeline has no 'transformer' attribute!")
        return stats

    transformer = pipe.transformer

    # Attention modules are identified by the set_processor/processor API.
    for module in transformer.modules():
        if hasattr(module, 'set_processor') and hasattr(module, 'processor'):
            module.set_processor(OptimizedLTXVideoAttnProcessor())
            stats['attention_processors'] += 1

    stats['rmsnorm_modules'] = patch_rmsnorm_modules(transformer)
    return stats


def main():
    """Load LTX-Video, inject kernels, verify, and generate a short test clip."""
    from diffusers import LTXPipeline
    from diffusers.utils import export_to_video

    print("=" * 60)
    print("LTX-Video Kernel Injection Example")
    print("=" * 60)

    print("\n1. Loading pipeline...")
    pipe = LTXPipeline.from_pretrained(
        "Lightricks/LTX-Video",
        torch_dtype=torch.bfloat16,
    )
    # Inject AFTER moving to CUDA, BEFORE enabling CPU offload (key lesson 4).
    pipe.to("cuda")

    print("\n2. Injecting optimized CUDA kernels...")
    stats = inject_optimized_kernels(pipe)
    print(f"   Attention processors replaced: {stats['attention_processors']}")
    print(f"   RMSNorm modules patched: {stats['rmsnorm_modules']}")

    print("\n3. Verifying injection...")
    for module in pipe.transformer.modules():
        if hasattr(module, 'processor'):
            processor_name = type(module.processor).__name__
            assert processor_name == 'OptimizedLTXVideoAttnProcessor', \
                f"Expected OptimizedLTXVideoAttnProcessor, got {processor_name}"
            print(f"   Attention processor: {processor_name}")
            break

    for module in pipe.transformer.modules():
        if type(module).__name__ == 'RMSNorm':
            # Derive the feature dim from the module itself instead of
            # hard-coding it, so the smoke test works for any RMSNorm size.
            weight = getattr(module, 'weight', None)
            dim = weight.shape[-1] if weight is not None else 2048
            x = torch.randn(1, 10, dim, device='cuda', dtype=torch.bfloat16)
            out = module(x)
            print(f"   RMSNorm forward: {x.shape} -> {out.shape}")
            break

    print("\n4. Enabling CPU offloading...")
    pipe.enable_model_cpu_offload()

    print("\n5. Generating test video (9 frames, 5 steps)...")
    output = pipe(
        prompt="A cat sleeping in warm sunlight",
        num_frames=9,
        height=480,
        width=704,
        num_inference_steps=5,
        generator=torch.Generator(device="cuda").manual_seed(42),
    )

    output_path = "test_output.mp4"
    export_to_video(output.frames[0], output_path, fps=8)
    print(f"\nVideo saved to: {output_path}")

    print("\n" + "=" * 60)
    print("Success! Custom kernels are being used.")
    print("=" * 60)


if __name__ == "__main__":
    main()