# File size: 1,582 Bytes
# 66b6912
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
import triton
import triton.testing
import sys
from torch.utils.benchmark import Timer

sys.path.append("/models/blitz/crates/blitz-kernels/src/cuda")
from ghost_quant import ghost_quant_fp8_kernel

def run_rigorous_quant(n=1024 * 1024 * 16, block_size=1024, seed=42):
    """Check ghost_quant_fp8_kernel against a plain FP8 cast, then benchmark it vs. Inductor.

    Draws ``n`` float32 values with ``torch.randn`` on CUDA. It launches
    ``ghost_quant_fp8_kernel`` into an int8 output buffer and compares the result
    with ``x.to(torch.float8_e4m3fn).view(torch.int8)``. It then times the kernel
    and a ``torch.compile(mode="max-autotune")`` build of that reference with
    ``triton.testing.do_bench``, and prints a latency/speedup summary.

    Args:
        n: Number of elements to quantize (default 16M, the original hard-coded size).
        block_size: ``BLOCK_SIZE`` passed to the kernel; the grid is ``cdiv(n, block_size)``.
        seed: Seed argument forwarded to the kernel.
    """
    X = torch.randn(n, device="cuda", dtype=torch.float32)
    Y_blitz = torch.empty(n, device="cuda", dtype=torch.int8)
    grid = (triton.cdiv(n, block_size),)

    def launch_blitz():
        # Single definition of the launch so the correctness run and the
        # benchmark run use exactly the same configuration.
        ghost_quant_fp8_kernel[grid](X, Y_blitz, seed, n, BLOCK_SIZE=block_size)

    # 1. Correctness check
    def ref_fn(x):
        return x.to(torch.float8_e4m3fn).view(torch.int8)

    launch_blitz()
    y_ref = ref_fn(X)

    # The reference holds FP8 E4M3 bit patterns reinterpreted as int8, and the
    # comparison assumes the kernel writes the same encoding into Y_blitz.
    # Subtracting the int8 codes directly is NOT a numeric error (a sign-bit flip
    # or exponent change jumps the integer arbitrarily), so decode both buffers
    # back to real values before differencing.
    # Differences are expected: the kernel presumably applies stochastic rounding
    # (it takes a seed), while the reference is a deterministic cast.
    blitz_vals = Y_blitz.view(torch.float8_e4m3fn).float()
    ref_vals = y_ref.view(torch.float8_e4m3fn).float()
    abs_err = (blitz_vals - ref_vals).abs()
    mismatch_frac = (Y_blitz != y_ref).float().mean().item()
    print(f"Correctness (Mean Diff): {abs_err.mean().item():.6f}")
    print(f"Correctness (Max Diff): {abs_err.max().item():.6f}")
    print(f"Correctness (Mismatched FP8 codes): {mismatch_frac:.4%}")

    # 2. Timing
    # do_bench runs its own warmup/repetitions and returns milliseconds; which
    # statistic it reports depends on the installed Triton's return_mode/quantiles.
    ms_blitz = triton.testing.do_bench(launch_blitz)

    # PyTorch Inductor baseline. Note: compiled_ref allocates a fresh output on
    # every call, whereas the kernel writes into the preallocated Y_blitz.
    compiled_ref = torch.compile(ref_fn, mode="max-autotune")
    compiled_ref(X)  # trigger compilation/autotuning outside the timed region
    ms_inductor = triton.testing.do_bench(lambda: compiled_ref(X))

    # Report the actual device instead of a hard-coded GPU model.
    gpu_name = torch.cuda.get_device_name()
    print(f"--- RIGOROUS RECEIPT: GHOST QUANT ({n:,} elements) ---")
    print(f"{gpu_name} Inductor Latency: {ms_inductor:.4f} ms")
    print(f"Blitz Artisan Latency: {ms_blitz:.4f} ms")
    print(f"REAL SPEEDUP: {ms_inductor / ms_blitz:.2f}x")

if __name__ == "__main__":
    # Script entry point: run the correctness check and the benchmark with the
    # default problem size.
    run_rigorous_quant()