File size: 6,895 Bytes
88a1dd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
#!/usr/bin/env python3
"""
Minimal example: Inject custom CUDA kernels into LTX-Video pipeline.

This script demonstrates the essential pattern for integrating custom CUDA kernels
with diffusers pipelines. For full usage, see examples/ltx_video/generate_video.py.

Key lessons:
1. Check if RMSNorm has weight (elementwise_affine may be False)
2. Use type(module).__name__ not isinstance() to detect diffusers modules
3. LTX-Video uses GELU, not GEGLU - check your target model
4. Inject kernels AFTER loading to CUDA, BEFORE CPU offloading

Usage:
    cd examples/ltx_video
    uv pip install -e .  # Build kernels first
    python ../../.claude/skills/h100-diffusers-kernels/references/ltx_kernel_injection_example.py
"""

import sys
from typing import Optional, Tuple

import torch
import torch.nn as nn

sys.path.insert(0, "torch-ext")

from ltx_kernels import rmsnorm


class OptimizedLTXVideoAttnProcessor:
    """
    Attention processor that routes Q/K normalization through the custom
    ``rmsnorm`` CUDA kernel instead of the stock diffusers RMSNorm modules.

    NOTE: ``attn.norm_q`` and ``attn.norm_k`` carry learned weights
    (elementwise_affine=True), so their ``.weight`` is passed straight
    through to the kernel.
    """

    def __call__(
        self,
        attn,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        # Imported lazily, matching the original: diffusers is only needed
        # when the processor actually runs.
        from diffusers.models.attention_dispatch import dispatch_attention_fn
        from diffusers.models.transformers.transformer_ltx import apply_rotary_emb

        # Cross-attention masks are sized by the encoder sequence length.
        if encoder_hidden_states is not None:
            batch_size, sequence_length, _ = encoder_hidden_states.shape
        else:
            batch_size, sequence_length, _ = hidden_states.shape

        if attention_mask is not None:
            prepared = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            attention_mask = prepared.view(
                batch_size, attn.heads, -1, prepared.shape[-1]
            )

        # Self-attention: K/V come from the same stream as Q.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        # Q/K normalization goes through the custom CUDA kernel.
        q = rmsnorm(attn.to_q(hidden_states), attn.norm_q.weight, eps=attn.norm_q.eps)
        k = rmsnorm(attn.to_k(encoder_hidden_states), attn.norm_k.weight, eps=attn.norm_k.eps)
        v = attn.to_v(encoder_hidden_states)

        if image_rotary_emb is not None:
            q = apply_rotary_emb(q, image_rotary_emb)
            k = apply_rotary_emb(k, image_rotary_emb)

        # Split the joint projection dim into (heads, head_dim) for dispatch.
        q = q.unflatten(2, (attn.heads, -1))
        k = k.unflatten(2, (attn.heads, -1))
        v = v.unflatten(2, (attn.heads, -1))

        out = dispatch_attention_fn(
            q, k, v,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
        )
        # Merge heads back and restore the query dtype.
        out = out.flatten(2, 3).to(q.dtype)

        # Apply the two to_out modules in order (projection, then the
        # second module — presumably dropout; same call order as before).
        out = attn.to_out[0](out)
        out = attn.to_out[1](out)
        return out


def patch_rmsnorm_modules(model: nn.Module) -> int:
    """Patch every RMSNorm module in *model* to use the custom CUDA kernel.

    Modules are detected by class name (``type(module).__name__``) rather
    than ``isinstance`` because diffusers defines its own RMSNorm class.
    Each match has its ``forward`` replaced with a closure calling the
    custom ``rmsnorm`` kernel.

    Args:
        model: Module tree to patch in place.

    Returns:
        Number of RMSNorm modules that were patched.
    """

    def _make_forward(mod: nn.Module, epsilon: float, has_weight: bool):
        """Build a patched forward; the factory binds loop variables early
        to avoid Python's late-binding-closure pitfall."""
        def patched_forward(x: torch.Tensor) -> torch.Tensor:
            if has_weight:
                weight = mod.weight
            else:
                # elementwise_affine=False: the kernel still needs a weight
                # tensor, so synthesize an identity (all-ones) one.
                # NOTE(review): allocated per call — cache if this shows up
                # in profiles.
                weight = torch.ones(x.shape[-1], device=x.device, dtype=x.dtype)
            return rmsnorm(x, weight, eps=epsilon)
        return patched_forward

    patched_count = 0
    for module in model.modules():  # names unused; .modules() suffices
        if type(module).__name__ == 'RMSNorm':
            eps = getattr(module, 'eps', 1e-6)
            has_weight = getattr(module, 'weight', None) is not None
            module.forward = _make_forward(module, eps, has_weight)
            patched_count += 1

    return patched_count


def inject_optimized_kernels(pipe) -> dict:
    """Inject custom CUDA kernels into the LTX-Video pipeline.

    Two injection points:
      1. Every attention module (detected via the ``set_processor`` /
         ``processor`` duck-typed interface) gets an
         ``OptimizedLTXVideoAttnProcessor``.
      2. Every RMSNorm module is patched via ``patch_rmsnorm_modules``.

    Args:
        pipe: A diffusers pipeline expected to expose a ``transformer``.

    Returns:
        dict with counts: ``attention_processors`` and ``rmsnorm_modules``.
    """
    stats = {'attention_processors': 0, 'rmsnorm_modules': 0}

    # Guard: nothing to do if this pipeline has no transformer backbone.
    if not hasattr(pipe, 'transformer'):
        print("WARNING: Pipeline has no 'transformer' attribute!")
        return stats

    transformer = pipe.transformer

    for module in transformer.modules():  # names unused; .modules() suffices
        # Duck-typed check: any module exposing the diffusers attention
        # processor interface gets the optimized processor.
        if hasattr(module, 'set_processor') and hasattr(module, 'processor'):
            module.set_processor(OptimizedLTXVideoAttnProcessor())
            stats['attention_processors'] += 1

    stats['rmsnorm_modules'] = patch_rmsnorm_modules(transformer)

    return stats


def main():
    """End-to-end demo: load LTX-Video, inject kernels, verify, generate.

    Order matters: kernels are injected AFTER moving the pipeline to CUDA
    and BEFORE enabling CPU offloading (see module docstring).
    """
    from diffusers import LTXPipeline
    from diffusers.utils import export_to_video

    banner = "=" * 60
    print(banner)
    print("LTX-Video Kernel Injection Example")
    print(banner)

    print("\n1. Loading pipeline...")
    pipe = LTXPipeline.from_pretrained(
        "Lightricks/LTX-Video",
        torch_dtype=torch.bfloat16,
    )
    pipe.to("cuda")

    print("\n2. Injecting optimized CUDA kernels...")
    stats = inject_optimized_kernels(pipe)
    print(f"   Attention processors replaced: {stats['attention_processors']}")
    print(f"   RMSNorm modules patched: {stats['rmsnorm_modules']}")

    print("\n3. Verifying injection...")
    # Check the first attention module actually carries our processor.
    for module in pipe.transformer.modules():
        if hasattr(module, 'processor'):
            processor_name = type(module.processor).__name__
            assert processor_name == 'OptimizedLTXVideoAttnProcessor', \
                f"Expected OptimizedLTXVideoAttnProcessor, got {processor_name}"
            print(f"   Attention processor: {processor_name}")
            break

    # Smoke-test the first patched RMSNorm with a dummy activation.
    x = torch.randn(1, 10, 2048, device='cuda', dtype=torch.bfloat16)
    for module in pipe.transformer.modules():
        if type(module).__name__ == 'RMSNorm':
            out = module(x)
            print(f"   RMSNorm forward: {x.shape} -> {out.shape}")
            break

    print("\n4. Enabling CPU offloading...")
    pipe.enable_model_cpu_offload()

    print("\n5. Generating test video (9 frames, 5 steps)...")
    output = pipe(
        prompt="A cat sleeping in warm sunlight",
        num_frames=9,
        height=480,
        width=704,
        num_inference_steps=5,
        generator=torch.Generator(device="cuda").manual_seed(42),
    )

    output_path = "test_output.mp4"
    export_to_video(output.frames[0], output_path, fps=8)
    print(f"\nVideo saved to: {output_path}")

    print("\n" + banner)
    print("Success! Custom kernels are being used.")
    print(banner)

if __name__ == "__main__":
    main()