File size: 5,632 Bytes
61ba51e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
#!/usr/bin/env python3
"""
Warmup script to pre-build AITER JIT kernels.

This script triggers compilation of commonly used AITER kernels by importing
the relevant modules and calling functions with sample data. This avoids
timeouts during actual tests when kernels need to be compiled on first use.

Run this after clearing pre-built AITER kernels from the Docker image.
"""

import os
import sys
import time

# Ensure AITER is enabled
os.environ["SGLANG_USE_AITER"] = "1"


def warmup_aiter_kernels():
    """Trigger AITER JIT kernel compilation."""
    import torch

    if not torch.cuda.is_available():
        print("CUDA/ROCm not available, skipping AITER warmup")
        return

    print("=" * 60)
    print("AITER JIT Kernel Warmup")
    print("=" * 60)

    device = torch.device("cuda:0")
    start_time = time.time()

    # Warmup module_rmsnorm_quant (small module, ~2MB)
    # Triggered by rmsnorm2d_fwd when hidden_size <= 8192
    try:
        print(
            "\n[1/5] Warming up module_rmsnorm_quant (rmsnorm2d_fwd, hidden<=8192)..."
        )
        from aiter import rmsnorm2d_fwd

        hidden_size = 4096
        batch_size = 512  # Use larger batch to match CUDA graph capture
        x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
        weight = torch.ones(hidden_size, dtype=torch.bfloat16, device=device)
        eps = 1e-6

        # hidden_size=4096 <= 8192 -> takes rmsnorm() path -> compiles module_rmsnorm_quant
        _ = rmsnorm2d_fwd(x, weight, eps)
        torch.cuda.synchronize()
        print("   module_rmsnorm_quant compiled successfully")
    except Exception as e:
        print(f"   module_rmsnorm_quant warmup failed: {e}")

    # Warmup module_rmsnorm (large CK module, ~159MB)
    # Triggered by rmsnorm2d_fwd_with_add (always uses CK path)
    # NOTE: rmsnorm2d_fwd_with_add signature is:
    #   rmsnorm2d_fwd_with_add(out, input, residual_in, residual_out, weight, epsilon)
    try:
        print("\n[2/5] Warming up module_rmsnorm (rmsnorm2d_fwd_with_add, CK path)...")
        from aiter import rmsnorm2d_fwd_with_add

        hidden_size = 4096
        batch_size = 512
        x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
        residual_in = torch.randn(
            batch_size, hidden_size, dtype=torch.bfloat16, device=device
        )
        output = torch.empty_like(x)
        residual_out = torch.empty_like(x)
        weight = torch.ones(hidden_size, dtype=torch.bfloat16, device=device)
        eps = 1e-6

        # This triggers JIT compilation of module_rmsnorm (CK kernels)
        rmsnorm2d_fwd_with_add(output, x, residual_in, residual_out, weight, eps)
        torch.cuda.synchronize()
        print("   module_rmsnorm compiled successfully")
    except Exception as e:
        print(f"   module_rmsnorm warmup failed: {e}")

    # Warmup module_rmsnorm via rmsnorm2d_fwd with large hidden_size (CK path)
    # When hidden_size > 8192, rmsnorm2d_fwd takes the rmsnorm2d_fwd_ck path
    # which also uses module_rmsnorm (already compiled in step 2, but this
    # ensures the CK rmsnorm2d_fwd path is exercised as well)
    try:
        print("\n[3/5] Warming up rmsnorm2d_fwd CK path (hidden>8192)...")
        from aiter import rmsnorm2d_fwd

        hidden_size = 16384  # > 8192 to trigger rmsnorm2d_fwd_ck (module_rmsnorm)
        batch_size = 32
        x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
        weight = torch.ones(hidden_size, dtype=torch.bfloat16, device=device)
        eps = 1e-6

        _ = rmsnorm2d_fwd(x, weight, eps)
        torch.cuda.synchronize()
        print("   rmsnorm2d_fwd CK path compiled successfully")
    except Exception as e:
        print(f"   rmsnorm2d_fwd CK path warmup skipped: {e}")

    # Warmup rotary embedding kernel if available
    try:
        print("\n[4/5] Warming up rotary embedding kernel...")
        from aiter import rotary_embedding

        head_size = 128
        seq_len = 32
        num_heads = 32
        positions = torch.arange(seq_len, device=device)
        query = torch.randn(
            seq_len, num_heads, head_size, dtype=torch.bfloat16, device=device
        )
        key = torch.randn(
            seq_len, num_heads, head_size, dtype=torch.bfloat16, device=device
        )
        cos = torch.ones(seq_len, head_size // 2, dtype=torch.bfloat16, device=device)
        sin = torch.zeros(seq_len, head_size // 2, dtype=torch.bfloat16, device=device)

        _ = rotary_embedding(positions, query, key, head_size, cos, sin, True)
        torch.cuda.synchronize()
        print("   Rotary embedding kernel compiled successfully")
    except Exception as e:
        print(f"   Rotary embedding warmup skipped (may not be available): {e}")

    # Warmup activation kernels if available
    try:
        print("\n[5/5] Warming up activation kernels...")
        from aiter import silu_and_mul

        hidden_size = 4096
        batch_size = 512
        x = torch.randn(
            batch_size, hidden_size * 2, dtype=torch.bfloat16, device=device
        )
        out = torch.empty(batch_size, hidden_size, dtype=torch.bfloat16, device=device)

        silu_and_mul(out, x)
        torch.cuda.synchronize()
        print("   Activation kernel compiled successfully")
    except Exception as e:
        print(f"   Activation warmup skipped (may not be available): {e}")

    elapsed = time.time() - start_time
    print("\n" + "=" * 60)
    print(f"AITER warmup completed in {elapsed:.1f}s")
    print("=" * 60 + "\n")


if __name__ == "__main__":
    warmup_aiter_kernels()