#!/usr/bin/env python3
"""
Minimal example: Inject custom CUDA kernels into LTX-Video pipeline.
This script demonstrates the essential pattern for integrating custom CUDA kernels
with diffusers pipelines. For full usage, see examples/ltx_video/generate_video.py.
Key lessons:
1. Check if RMSNorm has weight (elementwise_affine may be False)
2. Use type(module).__name__ not isinstance() to detect diffusers modules
3. LTX-Video uses GELU, not GEGLU - check your target model
4. Inject kernels AFTER loading to CUDA, BEFORE CPU offloading
Usage:
cd examples/ltx_video
uv pip install -e . # Build kernels first
python ../../.claude/skills/h100-diffusers-kernels/references/ltx_kernel_injection_example.py
"""
import sys
from typing import Optional, Tuple
import torch
import torch.nn as nn
sys.path.insert(0, "torch-ext")
from ltx_kernels import rmsnorm
class OptimizedLTXVideoAttnProcessor:
    """
    Attention processor using custom rmsnorm kernel for Q/K normalization.

    Drop-in replacement for the stock LTX-Video attention processor: same
    call signature, but Q/K RMSNorm goes through the custom CUDA kernel
    instead of the modules' own forward.

    NOTE: attn.norm_q and attn.norm_k HAVE weights (elementwise_affine=True).
    """
    def __call__(
        self,
        attn,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        # Imported lazily so this module can be imported without pulling in
        # diffusers at module-load time.
        from diffusers.models.transformers.transformer_ltx import apply_rotary_emb
        from diffusers.models.attention_dispatch import dispatch_attention_fn
        # Sequence length comes from the cross-attention context when one is
        # provided, otherwise from the hidden states (self-attention).
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # Expand the mask to a per-head view for the attention dispatch.
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )
        # Self-attention: key/value context is the hidden states themselves.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        # Custom CUDA kernel replaces attn.norm_q(...) / attn.norm_k(...);
        # both norms carry a weight (see class NOTE above).
        query = rmsnorm(query, attn.norm_q.weight, eps=attn.norm_q.eps)
        key = rmsnorm(key, attn.norm_k.weight, eps=attn.norm_k.eps)
        # Rotary embeddings are applied on the flat (pre-head-split) Q/K,
        # mirroring the stock LTX processor's ordering.
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)
        # Split the channel dim into heads; presumably dispatch_attention_fn
        # expects (batch, seq, heads, head_dim) layout -- TODO confirm against
        # the installed diffusers version.
        query = query.unflatten(2, (attn.heads, -1))
        key = key.unflatten(2, (attn.heads, -1))
        value = value.unflatten(2, (attn.heads, -1))
        hidden_states = dispatch_attention_fn(
            query, key, value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
        )
        # Merge heads back and restore the query dtype before the output proj.
        hidden_states = hidden_states.flatten(2, 3).to(query.dtype)
        # to_out = [Linear, Dropout] in diffusers attention modules.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
def patch_rmsnorm_modules(model: nn.Module) -> int:
    """Patch every RMSNorm module in *model* to use the custom CUDA kernel.

    Modules are matched by class name (``'RMSNorm'``) rather than
    ``isinstance()`` so diffusers' internal RMSNorm class is caught without
    importing it. Each matched module's ``forward`` is replaced in place.

    Args:
        model: Module tree to patch (modified in place).

    Returns:
        The number of RMSNorm modules that were patched.
    """
    patched_count = 0
    for module in model.modules():
        if type(module).__name__ != 'RMSNorm':
            continue
        # diffusers RMSNorm stores eps; fall back to a conventional default.
        eps = getattr(module, 'eps', 1e-6)
        if getattr(module, 'weight', None) is not None:
            # Factory function binds the current module/eps, avoiding the
            # classic late-binding-in-a-loop closure bug.
            def _make_forward_with_weight(mod, epsilon):
                def patched_forward(x):
                    return rmsnorm(x, mod.weight, eps=epsilon)
                return patched_forward
            module.forward = _make_forward_with_weight(module, eps)
        else:
            # elementwise_affine=False: synthesize a unit weight per call so
            # device and dtype always match the incoming activation.
            def _make_forward_no_weight(epsilon):
                def patched_forward(x):
                    weight = torch.ones(x.shape[-1], device=x.device, dtype=x.dtype)
                    return rmsnorm(x, weight, eps=epsilon)
                return patched_forward
            module.forward = _make_forward_no_weight(eps)
        patched_count += 1
    return patched_count
def inject_optimized_kernels(pipe) -> dict:
    """Inject custom CUDA kernels into the LTX-Video pipeline.

    Swaps the attention processor on every attention module of the
    pipeline's transformer and patches its RMSNorm layers in place.

    Args:
        pipe: A diffusers pipeline expected to expose a ``transformer``.

    Returns:
        Dict with counts: ``attention_processors`` replaced and
        ``rmsnorm_modules`` patched.
    """
    stats = {'attention_processors': 0, 'rmsnorm_modules': 0}
    # Guard: nothing to patch if the pipeline has no transformer at all.
    if not hasattr(pipe, 'transformer'):
        print("WARNING: Pipeline has no 'transformer' attribute!")
        return stats
    transformer = pipe.transformer
    # Attention modules are identified duck-typed: anything exposing both
    # set_processor() and a .processor attribute.
    attention_modules = [
        module for _, module in transformer.named_modules()
        if hasattr(module, 'set_processor') and hasattr(module, 'processor')
    ]
    for attention in attention_modules:
        attention.set_processor(OptimizedLTXVideoAttnProcessor())
    stats['attention_processors'] = len(attention_modules)
    stats['rmsnorm_modules'] = patch_rmsnorm_modules(transformer)
    return stats
def main():
    """End-to-end smoke test: load LTX-Video, inject kernels, render a short clip.

    Requires a CUDA device; the model is loaded in bfloat16 and kernels are
    injected while all modules are resident on the GPU.
    """
    from diffusers import LTXPipeline
    from diffusers.utils import export_to_video
    print("=" * 60)
    print("LTX-Video Kernel Injection Example")
    print("=" * 60)
    # 1) Load the pipeline straight onto CUDA. Injection must happen while
    #    the modules live on the GPU (see module docstring, lesson 4).
    print("\n1. Loading pipeline...")
    pipe = LTXPipeline.from_pretrained(
        "Lightricks/LTX-Video",
        torch_dtype=torch.bfloat16,
    )
    pipe.to("cuda")
    # 2) Swap in the optimized attention processors and RMSNorm forwards.
    print("\n2. Injecting optimized CUDA kernels...")
    stats = inject_optimized_kernels(pipe)
    print(f"  Attention processors replaced: {stats['attention_processors']}")
    print(f"  RMSNorm modules patched: {stats['rmsnorm_modules']}")
    # 3) Sanity-check the injection: first attention module must carry the
    #    custom processor, and a patched RMSNorm must still run forward.
    print("\n3. Verifying injection...")
    for name, module in pipe.transformer.named_modules():
        if hasattr(module, 'processor'):
            processor_name = type(module.processor).__name__
            assert processor_name == 'OptimizedLTXVideoAttnProcessor', \
                f"Expected OptimizedLTXVideoAttnProcessor, got {processor_name}"
            print(f"  Attention processor: {processor_name}")
            break
    # Shape-only check of the patched RMSNorm forward (2048 matches the
    # transformer's inner dim here -- presumably; confirm against config).
    x = torch.randn(1, 10, 2048, device='cuda', dtype=torch.bfloat16)
    for name, module in pipe.transformer.named_modules():
        if type(module).__name__ == 'RMSNorm':
            out = module(x)
            print(f"  RMSNorm forward: {x.shape} -> {out.shape}")
            break
    # 4) Offloading is enabled only AFTER injection (lesson 4).
    print("\n4. Enabling CPU offloading...")
    pipe.enable_model_cpu_offload()
    # 5) Tiny generation (9 frames, 5 steps) keeps the smoke test fast.
    print("\n5. Generating test video (9 frames, 5 steps)...")
    output = pipe(
        prompt="A cat sleeping in warm sunlight",
        num_frames=9,
        height=480,
        width=704,
        num_inference_steps=5,
        generator=torch.Generator(device="cuda").manual_seed(42),
    )
    output_path = "test_output.mp4"
    export_to_video(output.frames[0], output_path, fps=8)
    print(f"\nVideo saved to: {output_path}")
    print("\n" + "=" * 60)
    print("Success! Custom kernels are being used.")
    print("=" * 60)

if __name__ == "__main__":
    main()