Hanrui / sglang /scripts /ci /amd /amd_ci_warmup_aiter.py
Lekr0's picture
Add files using upload-large-folder tool
61ba51e verified
#!/usr/bin/env python3
"""
Warmup script to pre-build AITER JIT kernels.
This script triggers compilation of commonly used AITER kernels by importing
the relevant modules and calling functions with sample data. This avoids
timeouts during actual tests when kernels need to be compiled on first use.
Run this after clearing pre-built AITER kernels from the Docker image.
"""
import os
import sys
import time
# Ensure AITER is enabled
os.environ["SGLANG_USE_AITER"] = "1"
def warmup_aiter_kernels():
"""Trigger AITER JIT kernel compilation."""
import torch
if not torch.cuda.is_available():
print("CUDA/ROCm not available, skipping AITER warmup")
return
print("=" * 60)
print("AITER JIT Kernel Warmup")
print("=" * 60)
device = torch.device("cuda:0")
start_time = time.time()
# Warmup module_rmsnorm_quant (small module, ~2MB)
# Triggered by rmsnorm2d_fwd when hidden_size <= 8192
try:
print(
"\n[1/5] Warming up module_rmsnorm_quant (rmsnorm2d_fwd, hidden<=8192)..."
)
from aiter import rmsnorm2d_fwd
hidden_size = 4096
batch_size = 512 # Use larger batch to match CUDA graph capture
x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
weight = torch.ones(hidden_size, dtype=torch.bfloat16, device=device)
eps = 1e-6
# hidden_size=4096 <= 8192 -> takes rmsnorm() path -> compiles module_rmsnorm_quant
_ = rmsnorm2d_fwd(x, weight, eps)
torch.cuda.synchronize()
print(" module_rmsnorm_quant compiled successfully")
except Exception as e:
print(f" module_rmsnorm_quant warmup failed: {e}")
# Warmup module_rmsnorm (large CK module, ~159MB)
# Triggered by rmsnorm2d_fwd_with_add (always uses CK path)
# NOTE: rmsnorm2d_fwd_with_add signature is:
# rmsnorm2d_fwd_with_add(out, input, residual_in, residual_out, weight, epsilon)
try:
print("\n[2/5] Warming up module_rmsnorm (rmsnorm2d_fwd_with_add, CK path)...")
from aiter import rmsnorm2d_fwd_with_add
hidden_size = 4096
batch_size = 512
x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
residual_in = torch.randn(
batch_size, hidden_size, dtype=torch.bfloat16, device=device
)
output = torch.empty_like(x)
residual_out = torch.empty_like(x)
weight = torch.ones(hidden_size, dtype=torch.bfloat16, device=device)
eps = 1e-6
# This triggers JIT compilation of module_rmsnorm (CK kernels)
rmsnorm2d_fwd_with_add(output, x, residual_in, residual_out, weight, eps)
torch.cuda.synchronize()
print(" module_rmsnorm compiled successfully")
except Exception as e:
print(f" module_rmsnorm warmup failed: {e}")
# Warmup module_rmsnorm via rmsnorm2d_fwd with large hidden_size (CK path)
# When hidden_size > 8192, rmsnorm2d_fwd takes the rmsnorm2d_fwd_ck path
# which also uses module_rmsnorm (already compiled in step 2, but this
# ensures the CK rmsnorm2d_fwd path is exercised as well)
try:
print("\n[3/5] Warming up rmsnorm2d_fwd CK path (hidden>8192)...")
from aiter import rmsnorm2d_fwd
hidden_size = 16384 # > 8192 to trigger rmsnorm2d_fwd_ck (module_rmsnorm)
batch_size = 32
x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
weight = torch.ones(hidden_size, dtype=torch.bfloat16, device=device)
eps = 1e-6
_ = rmsnorm2d_fwd(x, weight, eps)
torch.cuda.synchronize()
print(" rmsnorm2d_fwd CK path compiled successfully")
except Exception as e:
print(f" rmsnorm2d_fwd CK path warmup skipped: {e}")
# Warmup rotary embedding kernel if available
try:
print("\n[4/5] Warming up rotary embedding kernel...")
from aiter import rotary_embedding
head_size = 128
seq_len = 32
num_heads = 32
positions = torch.arange(seq_len, device=device)
query = torch.randn(
seq_len, num_heads, head_size, dtype=torch.bfloat16, device=device
)
key = torch.randn(
seq_len, num_heads, head_size, dtype=torch.bfloat16, device=device
)
cos = torch.ones(seq_len, head_size // 2, dtype=torch.bfloat16, device=device)
sin = torch.zeros(seq_len, head_size // 2, dtype=torch.bfloat16, device=device)
_ = rotary_embedding(positions, query, key, head_size, cos, sin, True)
torch.cuda.synchronize()
print(" Rotary embedding kernel compiled successfully")
except Exception as e:
print(f" Rotary embedding warmup skipped (may not be available): {e}")
# Warmup activation kernels if available
try:
print("\n[5/5] Warming up activation kernels...")
from aiter import silu_and_mul
hidden_size = 4096
batch_size = 512
x = torch.randn(
batch_size, hidden_size * 2, dtype=torch.bfloat16, device=device
)
out = torch.empty(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
silu_and_mul(out, x)
torch.cuda.synchronize()
print(" Activation kernel compiled successfully")
except Exception as e:
print(f" Activation warmup skipped (may not be available): {e}")
elapsed = time.time() - start_time
print("\n" + "=" * 60)
print(f"AITER warmup completed in {elapsed:.1f}s")
print("=" * 60 + "\n")
if __name__ == "__main__":
warmup_aiter_kernels()