File size: 4,204 Bytes
88a1dd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#!/usr/bin/env python3
"""
Minimal example: Inject custom CUDA kernels into HuggingFace Transformers models.

This script demonstrates the essential pattern for integrating custom CUDA kernels
with transformers models like LLaMA, Mistral, and Qwen.

Key lessons:
1. Transformers RMSNorm modules normally have weights (unlike some diffusers modules); any without one are skipped with a warning
2. Use 'RMSNorm' substring match to catch LlamaRMSNorm, MistralRMSNorm, etc.
3. Check for 'variance_epsilon' (LLaMA) or 'eps' (others) for epsilon value
4. For attention, prefer Flash Attention 2 (attn_implementation='flash_attention_2' in from_pretrained) over custom processors; this script does not enable it

Usage:
    cd examples/ltx_video
    uv pip install -e .  # Build kernels first
    python ../../.claude/skills/h100-diffusers-kernels/scripts/transformers_injection_example.py
"""

import sys
import time

import torch
import torch.nn as nn

# "torch-ext" is resolved relative to the current working directory, which is
# why Usage says to run from examples/ltx_video. It presumably holds the built
# ltx_kernels package imported below — confirm against the build layout.
sys.path.insert(0, "torch-ext")

from ltx_kernels import rmsnorm


def _resolve_rmsnorm_eps(module: nn.Module):
    """Return the module's epsilon, or None if it defers to the input dtype's eps.

    ``variance_epsilon`` (LLaMA-style) wins when present and not None, then
    ``eps``; if neither attribute exists, 1e-6 is used.
    """
    eps = getattr(module, 'variance_epsilon', None)
    if eps is None:
        eps = getattr(module, 'eps', 1e-6)
    return eps


def _make_patched_forward(mod: nn.Module, epsilon):
    """Build a replacement forward that runs the custom RMSNorm kernel for ``mod``."""
    def patched_forward(hidden_states):
        # A None epsilon (e.g. torch.nn.RMSNorm(eps=None)) means "use the
        # machine epsilon of the input dtype", which is only known per call.
        eps = epsilon if epsilon is not None else torch.finfo(hidden_states.dtype).eps
        # mod.weight is looked up on every call, not captured once.
        return rmsnorm(hidden_states, mod.weight, eps=eps)
    return patched_forward


def patch_rmsnorm_modules(model: nn.Module) -> int:
    """Patch all RMSNorm modules to use the custom CUDA kernel.

    Every submodule whose class name contains ``'RMSNorm'`` and whose
    ``weight`` is not None gets its ``forward`` replaced on the instance by a
    closure calling ``rmsnorm(hidden_states, weight, eps=...)``. Modules
    without a weight are skipped with a printed warning.

    Epsilon comes from ``variance_epsilon`` if set, else ``eps``, else 1e-6.
    If ``eps`` exists but is None, the dtype's ``torch.finfo(...).eps`` is
    used at call time instead of passing None to the kernel.

    Args:
        model: Module tree to patch in place.

    Returns:
        Number of modules whose forward was replaced.
    """
    patched_count = 0

    for name, module in model.named_modules():
        if 'RMSNorm' not in type(module).__name__:
            continue

        # NOTE(review): a substring match also catches RMSNorm variants whose
        # math may differ from the kernel's (e.g. scaling by (1 + weight));
        # confirm each matched class is compatible before trusting outputs.
        if getattr(module, 'weight', None) is None:
            print(f"WARNING: {name} has no weight, skipping")
            continue

        module.forward = _make_patched_forward(module, _resolve_rmsnorm_eps(module))
        patched_count += 1

    return patched_count


def inject_optimized_kernels(model) -> dict:
    """Apply every available custom-kernel patch to ``model`` in place.

    Only RMSNorm modules are patched at present.

    Returns:
        A dict with key ``'rmsnorm_modules'`` holding the number of RMSNorm
        modules that were patched.
    """
    rmsnorm_patched = patch_rmsnorm_modules(model)
    return {'rmsnorm_modules': rmsnorm_patched}


def main():
    """Run the end-to-end demo on TinyLlama.

    Loads the model in bfloat16 on CUDA, patches its RMSNorm modules via
    ``inject_optimized_kernels``, compares one RMSNorm's patched output with
    its original forward, and times greedy generation. Requires the
    ``transformers`` package and a CUDA device.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    print("=" * 60)
    print("Transformers Kernel Injection Example")
    print("=" * 60)

    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

    print(f"\n1. Loading model: {model_id}...")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="cuda"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    rmsnorm_modules = [
        m for _, m in model.named_modules() if 'RMSNorm' in type(m).__name__
    ]
    print(f"   Found {len(rmsnorm_modules)} RMSNorm modules")

    # Record the original forward's output *before* patching so step 3 can
    # check numerics against it, not only the output shape.
    x = torch.randn(1, 10, model.config.hidden_size, device='cuda', dtype=torch.bfloat16)
    probe = rmsnorm_modules[0] if rmsnorm_modules else None
    reference = None
    if probe is not None:
        with torch.inference_mode():
            reference = probe(x)

    print("\n2. Injecting optimized CUDA kernels...")
    stats = inject_optimized_kernels(model)
    print(f"   RMSNorm modules patched: {stats['rmsnorm_modules']}")

    print("\n3. Verifying injection...")
    if probe is not None:
        with torch.inference_mode():
            out = probe(x)
        max_diff = (out.float() - reference.float()).abs().max().item()
        print(f"   RMSNorm forward pass: {x.shape} -> {out.shape}")
        print(f"   Max abs diff vs. original forward: {max_diff:.3e}")
    else:
        print("   No RMSNorm module found to verify")

    print("\n4. Running generation test...")
    prompt = "The capital of France is"
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    prompt_len = inputs["input_ids"].shape[1]

    # Warm-up run so one-time initialization cost is excluded from the timing.
    with torch.inference_mode():
        _ = model.generate(
            **inputs,
            max_new_tokens=5,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )

    max_new_tokens = 50
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )
    # Wait for all queued GPU work so the timer covers completed generation.
    torch.cuda.synchronize()
    end_time = time.perf_counter()

    elapsed = end_time - start_time
    # max_new_tokens is only an upper bound (generation may stop early), so
    # count the tokens actually appended after the prompt.
    generated_tokens = outputs.shape[1] - prompt_len
    tokens_per_second = generated_tokens / elapsed

    print(f"   Prompt: {prompt}")
    print(f"   Output: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
    print(f"   Generated {generated_tokens} tokens in {elapsed:.2f}s ({tokens_per_second:.1f} tokens/s)")

    print("\n" + "=" * 60)
    if stats['rmsnorm_modules'] > 0:
        print("Success! Custom kernels are being used.")
    else:
        print("WARNING: no RMSNorm modules were patched; custom kernels are NOT in use.")
    print("=" * 60)


if __name__ == "__main__":
    # Run the demo only when executed as a script, not when this module is imported.
    main()