Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Minimal example: Inject custom CUDA kernels into LTX-Video pipeline. | |
| This script demonstrates the essential pattern for integrating custom CUDA kernels | |
| with diffusers pipelines. For full usage, see examples/ltx_video/generate_video.py. | |
| Key lessons: | |
| 1. Check if RMSNorm has weight (elementwise_affine may be False) | |
| 2. Use type(module).__name__ not isinstance() to detect diffusers modules | |
| 3. LTX-Video uses GELU, not GEGLU - check your target model | |
| 4. Inject kernels AFTER loading to CUDA, BEFORE CPU offloading | |
| Usage: | |
| cd examples/ltx_video | |
| uv pip install -e . # Build kernels first | |
| python ../../.claude/skills/h100-diffusers-kernels/references/ltx_kernel_injection_example.py | |
| """ | |
| import sys | |
| from typing import Optional, Tuple | |
| import torch | |
| import torch.nn as nn | |
| sys.path.insert(0, "torch-ext") | |
| from ltx_kernels import rmsnorm | |
class OptimizedLTXVideoAttnProcessor:
    """
    Attention processor that normalizes Q/K with the custom ``rmsnorm`` kernel.

    The processor reads ``attn.norm_q.weight`` / ``attn.norm_k.weight`` and the
    matching ``eps`` values directly, so it relies on those norms carrying
    weights (elementwise_affine=True).
    """

    def __call__(
        self,
        attn,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Compute attention for ``hidden_states`` using module ``attn``.

        Args:
            attn: Attention module supplying the projections (``to_q``,
                ``to_k``, ``to_v``, ``to_out``), the Q/K norms, ``heads`` and
                ``prepare_attention_mask``.
            hidden_states: Input used for the query projection.
            encoder_hidden_states: Source for keys/values; when ``None`` the
                keys/values are projected from ``hidden_states`` instead.
            attention_mask: Optional mask, prepared and reshaped per head.
            image_rotary_emb: Optional rotary embedding applied to Q and K.

        Returns:
            The attention output after both ``attn.to_out`` layers.
        """
        # Imported lazily so the module can be loaded without diffusers.
        from diffusers.models.transformers.transformer_ltx import apply_rotary_emb
        from diffusers.models.attention_dispatch import dispatch_attention_fn

        # Keys/values come from the encoder states when given, otherwise
        # from the hidden states themselves.
        context = (
            hidden_states if encoder_hidden_states is None
            else encoder_hidden_states
        )
        batch_size, sequence_length, _ = context.shape

        if attention_mask is not None:
            prepared_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            attention_mask = prepared_mask.view(
                batch_size, attn.heads, -1, prepared_mask.shape[-1]
            )

        query = attn.to_q(hidden_states)
        key = attn.to_k(context)
        value = attn.to_v(context)

        # Custom kernel takes the place of attn.norm_q / attn.norm_k.
        norm_q, norm_k = attn.norm_q, attn.norm_k
        query = rmsnorm(query, norm_q.weight, eps=norm_q.eps)
        key = rmsnorm(key, norm_k.weight, eps=norm_k.eps)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        # Split dim 2 into (heads, per-head features) for all three tensors.
        query, key, value = (
            tensor.unflatten(2, (attn.heads, -1))
            for tensor in (query, key, value)
        )

        attended = dispatch_attention_fn(
            query, key, value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
        )
        # Merge the head dims back together and restore the query dtype.
        attended = attended.flatten(2, 3).to(query.dtype)

        first_out, second_out = attn.to_out[0], attn.to_out[1]
        return second_out(first_out(attended))
def patch_rmsnorm_modules(model: nn.Module) -> int:
    """Patch all RMSNorm modules to use the custom CUDA kernel.

    Submodules are matched by class name (``type(module).__name__ ==
    'RMSNorm'``), so any class with that name is patched. Each match gets an
    instance-level ``forward`` override that calls ``rmsnorm``.

    The override handles three module configurations:
      * no weight (``weight`` missing or ``None``): a ones tensor matching the
        input's last dimension, device and dtype is passed as the weight;
      * ``eps`` of ``None``: resolved per call to ``torch.finfo(x.dtype).eps``;
      * a non-``None`` ``bias`` attribute: added to the kernel output so the
        patched module does not silently lose its bias.

    Args:
        model: Module tree to scan via ``named_modules()``.

    Returns:
        Number of modules whose ``forward`` was replaced.
    """

    def build_forward(mod, epsilon):
        # Bind module and eps per instance so each override keeps its own
        # values (avoids late-binding of the loop variable).
        def patched_forward(x):
            # Weight is looked up at call time so later changes to the
            # module's parameter (e.g. device moves) are picked up.
            weight = getattr(mod, 'weight', None)
            if weight is None:
                weight = torch.ones(x.shape[-1], device=x.device, dtype=x.dtype)
            eps_value = torch.finfo(x.dtype).eps if epsilon is None else epsilon
            out = rmsnorm(x, weight, eps=eps_value)
            bias = getattr(mod, 'bias', None)
            return out if bias is None else out + bias
        return patched_forward

    patched_count = 0
    for _, module in model.named_modules():
        if type(module).__name__ != 'RMSNorm':
            continue
        module.forward = build_forward(module, getattr(module, 'eps', 1e-6))
        patched_count += 1
    return patched_count
def inject_optimized_kernels(pipe) -> dict:
    """Swap the pipeline transformer's attention processors and RMSNorm forwards.

    Args:
        pipe: Pipeline object; only its ``transformer`` attribute is used.

    Returns:
        A dict with the keys ``'attention_processors'`` (number of modules that
        received an ``OptimizedLTXVideoAttnProcessor``) and
        ``'rmsnorm_modules'`` (count returned by ``patch_rmsnorm_modules``).
        Both are 0 when ``pipe`` has no ``transformer`` attribute, in which
        case a warning is printed and nothing is modified.
    """
    stats = {'attention_processors': 0, 'rmsnorm_modules': 0}

    try:
        transformer = pipe.transformer
    except AttributeError:
        print("WARNING: Pipeline has no 'transformer' attribute!")
        return stats

    replaced = 0
    for _, submodule in transformer.named_modules():
        # Attention modules are recognised by exposing both the setter and
        # the current processor.
        is_attention = (
            hasattr(submodule, 'set_processor')
            and hasattr(submodule, 'processor')
        )
        if not is_attention:
            continue
        submodule.set_processor(OptimizedLTXVideoAttnProcessor())
        replaced += 1

    stats['attention_processors'] = replaced
    stats['rmsnorm_modules'] = patch_rmsnorm_modules(transformer)
    return stats
def main():
    """Run the demo end to end: load, inject, verify, offload, generate.

    The kernels are injected after ``pipe.to("cuda")`` and before
    ``pipe.enable_model_cpu_offload()``. A short test video is written to
    ``test_output.mp4`` in the current directory.
    """
    from diffusers import LTXPipeline
    from diffusers.utils import export_to_video

    banner = "=" * 60
    print(banner)
    print("LTX-Video Kernel Injection Example")
    print(banner)

    print("\n1. Loading pipeline...")
    pipe = LTXPipeline.from_pretrained(
        "Lightricks/LTX-Video",
        torch_dtype=torch.bfloat16,
    )
    pipe.to("cuda")

    print("\n2. Injecting optimized CUDA kernels...")
    stats = inject_optimized_kernels(pipe)
    print(f" Attention processors replaced: {stats['attention_processors']}")
    print(f" RMSNorm modules patched: {stats['rmsnorm_modules']}")

    print("\n3. Verifying injection...")
    # Only the first module exposing a processor is checked.
    first_attention = next(
        (mod for _, mod in pipe.transformer.named_modules()
         if hasattr(mod, 'processor')),
        None,
    )
    if first_attention is not None:
        processor_name = type(first_attention.processor).__name__
        assert processor_name == 'OptimizedLTXVideoAttnProcessor', \
            f"Expected OptimizedLTXVideoAttnProcessor, got {processor_name}"
        print(f" Attention processor: {processor_name}")

    # Smoke-test the first patched RMSNorm with a random bfloat16 input.
    x = torch.randn(1, 10, 2048, device='cuda', dtype=torch.bfloat16)
    first_norm = next(
        (mod for _, mod in pipe.transformer.named_modules()
         if type(mod).__name__ == 'RMSNorm'),
        None,
    )
    if first_norm is not None:
        out = first_norm(x)
        print(f" RMSNorm forward: {x.shape} -> {out.shape}")

    print("\n4. Enabling CPU offloading...")
    pipe.enable_model_cpu_offload()

    print("\n5. Generating test video (9 frames, 5 steps)...")
    seeded_generator = torch.Generator(device="cuda").manual_seed(42)
    output = pipe(
        prompt="A cat sleeping in warm sunlight",
        num_frames=9,
        height=480,
        width=704,
        num_inference_steps=5,
        generator=seeded_generator,
    )

    output_path = "test_output.mp4"
    export_to_video(output.frames[0], output_path, fps=8)
    print(f"\nVideo saved to: {output_path}")

    print("\n" + banner)
    print("Success! Custom kernels are being used.")
    print(banner)
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()