# test/skill_example/scripts/transformers_injection_example.py
# Author: Jack-Khuu — demo (commit 88a1dd2)
#!/usr/bin/env python3
"""
Minimal example: Inject custom CUDA kernels into HuggingFace Transformers models.
This script demonstrates the essential pattern for integrating custom CUDA kernels
with transformers models like LLaMA, Mistral, and Qwen.
Key lessons:
1. Transformers RMSNorm modules always have weights (unlike some diffusers modules)
2. Use 'RMSNorm' substring match to catch LlamaRMSNorm, MistralRMSNorm, etc.
3. Check for 'variance_epsilon' (LLaMA) or 'eps' (others) for epsilon value
4. Use Flash Attention 2 for attention optimization instead of custom processors
Usage:
cd examples/ltx_video
uv pip install -e . # Build kernels first
python ../../.claude/skills/h100-diffusers-kernels/scripts/transformers_injection_example.py
"""
import sys
import time
import torch
import torch.nn as nn
sys.path.insert(0, "torch-ext")
from ltx_kernels import rmsnorm
def patch_rmsnorm_modules(model: nn.Module) -> int:
    """Swap the forward of every RMSNorm-style submodule for the custom CUDA kernel.

    Any module whose class name contains 'RMSNorm' is matched, which covers
    LlamaRMSNorm, MistralRMSNorm, Qwen2RMSNorm, etc. The epsilon is taken
    from ``variance_epsilon`` (LLaMA naming) when present, else ``eps``,
    defaulting to 1e-6.

    Returns:
        Number of modules whose forward was replaced.
    """

    def _kernel_forward(norm_module, epsilon):
        # Factory binds the module and epsilon eagerly, so each patched
        # forward keeps its own values (avoids the late-binding closure trap).
        def forward(hidden_states):
            return rmsnorm(hidden_states, norm_module.weight, eps=epsilon)

        return forward

    patched = 0
    for name, module in model.named_modules():
        if 'RMSNorm' not in type(module).__name__:
            continue
        epsilon = getattr(module, 'variance_epsilon', None)
        if epsilon is None:
            epsilon = getattr(module, 'eps', 1e-6)
        # Transformers RMSNorm modules are expected to carry a weight;
        # skip (with a warning) anything that does not.
        if getattr(module, 'weight', None) is None:
            print(f"WARNING: {name} has no weight, skipping")
            continue
        module.forward = _kernel_forward(module, epsilon)
        patched += 1
    return patched
def inject_optimized_kernels(model) -> dict:
    """Apply every available custom-kernel patch to *model*.

    Currently only RMSNorm replacement is implemented. The returned dict
    maps each optimization name to the number of modules it touched.
    """
    return {'rmsnorm_modules': patch_rmsnorm_modules(model)}
def main():
    """End-to-end demo: load a small LLaMA-family model, patch its RMSNorm
    modules with the custom CUDA kernel, run a sanity forward pass, and time
    a short greedy generation.

    NOTE(review): requires a CUDA device and the ltx_kernels extension to be
    built (see module docstring) — the script hard-codes device "cuda".
    """
    # Imported lazily so the module can be imported without transformers.
    from transformers import AutoModelForCausalLM, AutoTokenizer
    print("=" * 60)
    print("Transformers Kernel Injection Example")
    print("=" * 60)
    # Small model keeps download/runtime short while still using LlamaRMSNorm.
    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    print(f"\n1. Loading model: {model_id}...")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="cuda"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Count candidate modules before patching so the patched count can be compared.
    rmsnorm_count = sum(1 for _, m in model.named_modules() if 'RMSNorm' in type(m).__name__)
    print(f"    Found {rmsnorm_count} RMSNorm modules")
    print("\n2. Injecting optimized CUDA kernels...")
    stats = inject_optimized_kernels(model)
    print(f"    RMSNorm modules patched: {stats['rmsnorm_modules']}")
    print("\n3. Verifying injection...")
    # Dummy activation shaped (batch=1, seq=10, hidden) to exercise one
    # patched RMSNorm directly; dtype must match the model's bfloat16 weights.
    x = torch.randn(1, 10, model.config.hidden_size, device='cuda', dtype=torch.bfloat16)
    for name, module in model.named_modules():
        if 'RMSNorm' in type(module).__name__:
            out = module(x)
            print(f"    RMSNorm forward pass: {x.shape} -> {out.shape}")
            break
    print("\n4. Running generation test...")
    prompt = "The capital of France is"
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    # Warmup generation: triggers CUDA kernel compilation/caching so the
    # timed run below is not skewed by first-call overhead.
    with torch.inference_mode():
        _ = model.generate(**inputs, max_new_tokens=5, do_sample=False)
    num_tokens = 50
    start_time = time.perf_counter()
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=num_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )
    end_time = time.perf_counter()
    elapsed = end_time - start_time
    # NOTE(review): assumes exactly num_tokens new tokens were generated;
    # greedy decoding may stop early at EOS, which would inflate tokens/s.
    tokens_per_second = num_tokens / elapsed
    print(f"    Prompt: {prompt}")
    print(f"    Output: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
    print(f"    Generated {num_tokens} tokens in {elapsed:.2f}s ({tokens_per_second:.1f} tokens/s)")
    print("\n" + "=" * 60)
    print("Success! Custom kernels are being used.")
    print("=" * 60)
if __name__ == "__main__":
main()