#!/usr/bin/env python3
"""
Minimal example: Inject custom CUDA kernels into HuggingFace Transformers models.
This script demonstrates the essential pattern for integrating custom CUDA kernels
with transformers models like LLaMA, Mistral, and Qwen.
Key lessons:
1. Transformers RMSNorm modules always have weights (unlike some diffusers modules)
2. Use 'RMSNorm' substring match to catch LlamaRMSNorm, MistralRMSNorm, etc.
3. Check for 'variance_epsilon' (LLaMA) or 'eps' (others) for epsilon value
4. Use Flash Attention 2 for attention optimization instead of custom processors
Usage:
cd examples/ltx_video
uv pip install -e . # Build kernels first
python ../../.claude/skills/h100-diffusers-kernels/scripts/transformers_injection_example.py
"""
import sys
import time
import torch
import torch.nn as nn
# Make the locally built kernel package importable without installing it;
# assumes the script is run from examples/ltx_video (see module docstring).
sys.path.insert(0, "torch-ext")
# Custom CUDA RMSNorm kernel built by `uv pip install -e .`
from ltx_kernels import rmsnorm
def patch_rmsnorm_modules(model: nn.Module) -> int:
    """Rebind the forward of every *RMSNorm submodule to the custom CUDA kernel.

    Matches any module whose class name contains 'RMSNorm' (LlamaRMSNorm,
    MistralRMSNorm, Qwen2RMSNorm, ...). The epsilon is read from
    ``variance_epsilon`` (LLaMA-style) first, then ``eps``, defaulting to 1e-6.
    Modules without a weight tensor are skipped with a warning.

    Args:
        model: Root module; patched in place.

    Returns:
        The number of modules whose ``forward`` was replaced.
    """

    def _bind(norm_module, epsilon):
        # Factory closure: binds this module's weight/epsilon per iteration,
        # avoiding the classic late-binding-in-a-loop pitfall.
        def forward(hidden_states):
            return rmsnorm(hidden_states, norm_module.weight, eps=epsilon)
        return forward

    total = 0
    for name, sub in model.named_modules():
        if 'RMSNorm' not in type(sub).__name__:
            continue
        # LLaMA stores epsilon as `variance_epsilon`; other ports use `eps`.
        epsilon = getattr(sub, 'variance_epsilon', None)
        if epsilon is None:
            epsilon = getattr(sub, 'eps', 1e-6)
        if getattr(sub, 'weight', None) is not None:
            sub.forward = _bind(sub, epsilon)
            total += 1
        else:
            print(f"WARNING: {name} has no weight, skipping")
    return total
def inject_optimized_kernels(model) -> dict:
    """Inject custom CUDA kernels into a transformers model.

    Only RMSNorm layers are swapped; attention optimization should use
    Flash Attention 2 rather than custom processors (see module docstring).

    Args:
        model: The transformers model, patched in place.

    Returns:
        Dict of per-kernel patch counts (key: ``'rmsnorm_modules'``).
    """
    return {'rmsnorm_modules': patch_rmsnorm_modules(model)}
def main():
    """Load TinyLlama, patch its RMSNorm layers with the custom CUDA kernel,
    verify a forward pass, and benchmark greedy generation throughput.

    Requires a CUDA device and network access to download the model.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    print("=" * 60)
    print("Transformers Kernel Injection Example")
    print("=" * 60)

    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    print(f"\n1. Loading model: {model_id}...")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="cuda"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    rmsnorm_count = sum(1 for _, m in model.named_modules() if 'RMSNorm' in type(m).__name__)
    print(f" Found {rmsnorm_count} RMSNorm modules")

    print("\n2. Injecting optimized CUDA kernels...")
    stats = inject_optimized_kernels(model)
    print(f" RMSNorm modules patched: {stats['rmsnorm_modules']}")

    print("\n3. Verifying injection...")
    x = torch.randn(1, 10, model.config.hidden_size, device='cuda', dtype=torch.bfloat16)
    for name, module in model.named_modules():
        if 'RMSNorm' in type(module).__name__:
            out = module(x)
            print(f" RMSNorm forward pass: {x.shape} -> {out.shape}")
            break

    print("\n4. Running generation test...")
    prompt = "The capital of France is"
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

    # Warmup run so kernel compilation/caching is excluded from timing.
    # pad_token_id is set here too, to suppress the "no pad token" warning.
    with torch.inference_mode():
        _ = model.generate(**inputs, max_new_tokens=5, do_sample=False,
                           pad_token_id=tokenizer.eos_token_id)

    num_tokens = 50
    start_time = time.perf_counter()
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=num_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )
    # Ensure all queued GPU work is done before reading the clock.
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    elapsed = end_time - start_time

    # FIX: greedy decoding may stop at EOS before max_new_tokens, so count
    # the tokens actually produced instead of assuming `num_tokens`.
    generated_tokens = outputs.shape[1] - inputs["input_ids"].shape[1]
    tokens_per_second = generated_tokens / elapsed

    print(f" Prompt: {prompt}")
    print(f" Output: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
    print(f" Generated {generated_tokens} tokens in {elapsed:.2f}s ({tokens_per_second:.1f} tokens/s)")

    print("\n" + "=" * 60)
    print("Success! Custom kernels are being used.")
    print("=" * 60)


if __name__ == "__main__":
    main()