#!/usr/bin/env python3
"""
Minimal example: Inject custom CUDA kernels into HuggingFace Transformers models.

This script demonstrates the essential pattern for integrating custom CUDA
kernels with transformers models like LLaMA, Mistral, and Qwen.

Key lessons:
1. Transformers RMSNorm modules always have weights (unlike some diffusers modules)
2. Use 'RMSNorm' substring match to catch LlamaRMSNorm, MistralRMSNorm, etc.
3. Check for 'variance_epsilon' (LLaMA) or 'eps' (others) for epsilon value
4. Use Flash Attention 2 for attention optimization instead of custom processors

Usage:
    cd examples/ltx_video
    uv pip install -e .  # Build kernels first
    python ../../.claude/skills/h100-diffusers-kernels/scripts/transformers_injection_example.py
"""

import sys
import time

import torch
import torch.nn as nn

sys.path.insert(0, "torch-ext")
from ltx_kernels import rmsnorm


def _make_patched_forward(module: nn.Module, epsilon: float):
    """Build a forward replacement that routes through the custom rmsnorm kernel.

    Hoisted to module level (rather than being redefined inside the patching
    loop) so each closure cleanly captures its own (module, epsilon) pair.
    """
    def patched_forward(hidden_states):
        return rmsnorm(hidden_states, module.weight, eps=epsilon)
    return patched_forward


def patch_rmsnorm_modules(model: nn.Module) -> int:
    """Patch all RMSNorm modules in-place to use the custom CUDA kernel.

    Matches any module whose class name contains 'RMSNorm' (LlamaRMSNorm,
    MistralRMSNorm, Qwen2RMSNorm, ...). The epsilon is read from
    'variance_epsilon' (LLaMA naming) first, then 'eps', defaulting to 1e-6.

    Returns:
        Number of modules whose forward was replaced.
    """
    patched_count = 0
    for name, module in model.named_modules():
        if 'RMSNorm' not in type(module).__name__:
            continue
        eps = getattr(module, 'variance_epsilon', None)
        if eps is None:
            eps = getattr(module, 'eps', 1e-6)
        if getattr(module, 'weight', None) is not None:
            # Assigning to the instance attribute shadows the class-level
            # forward; nn.Module.__call__ then dispatches to the patched one.
            module.forward = _make_patched_forward(module, eps)
            patched_count += 1
        else:
            print(f"WARNING: {name} has no weight, skipping")
    return patched_count


def inject_optimized_kernels(model) -> dict:
    """Inject custom CUDA kernels into a transformers model.

    Returns:
        Dict of injection statistics (currently just 'rmsnorm_modules').
    """
    return {'rmsnorm_modules': patch_rmsnorm_modules(model)}


def main():
    from transformers import AutoModelForCausalLM, AutoTokenizer

    print("=" * 60)
    print("Transformers Kernel Injection Example")
    print("=" * 60)

    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    print(f"\n1. Loading model: {model_id}...")
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map="cuda"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    rmsnorm_count = sum(
        1 for _, m in model.named_modules() if 'RMSNorm' in type(m).__name__
    )
    print(f" Found {rmsnorm_count} RMSNorm modules")

    print("\n2. Injecting optimized CUDA kernels...")
    stats = inject_optimized_kernels(model)
    print(f" RMSNorm modules patched: {stats['rmsnorm_modules']}")

    print("\n3. Verifying injection...")
    # Smoke-test one patched module with a dummy activation of the model's
    # hidden size to confirm the kernel path runs and preserves shape.
    x = torch.randn(
        1, 10, model.config.hidden_size, device='cuda', dtype=torch.bfloat16
    )
    for name, module in model.named_modules():
        if 'RMSNorm' in type(module).__name__:
            out = module(x)
            print(f" RMSNorm forward pass: {x.shape} -> {out.shape}")
            break

    print("\n4. Running generation test...")
    prompt = "The capital of France is"
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

    # Warmup pass (kernel launch / cache priming). pad_token_id is supplied
    # here too, so models without a pad token don't emit a warning.
    with torch.inference_mode():
        _ = model.generate(
            **inputs,
            max_new_tokens=5,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    num_tokens = 50
    start_time = time.perf_counter()
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=num_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    end_time = time.perf_counter()

    elapsed = end_time - start_time
    # BUGFIX: count tokens actually produced -- greedy decoding may stop early
    # at EOS, so dividing the requested max_new_tokens by the elapsed time
    # would inflate the reported throughput.
    generated_tokens = outputs.shape[1] - inputs["input_ids"].shape[1]
    tokens_per_second = generated_tokens / elapsed

    print(f" Prompt: {prompt}")
    print(f" Output: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
    print(
        f" Generated {generated_tokens} tokens in {elapsed:.2f}s "
        f"({tokens_per_second:.1f} tokens/s)"
    )

    print("\n" + "=" * 60)
    print("Success! Custom kernels are being used.")
    print("=" * 60)


if __name__ == "__main__":
    main()