"""Benchmark the Blitz ``ghost_quant_fp8_kernel`` Triton kernel against Inductor.

The script quantizes a float32 CUDA vector to FP8 (e4m3fn bit patterns stored
in an int8 buffer) with the custom Triton kernel, compares the result to a
plain PyTorch cast, and then times the custom kernel against a
``torch.compile``-d (Inductor, ``max-autotune``) version of that cast.
"""
import sys

import torch
import triton
import triton.testing
from torch.utils.benchmark import Timer  # noqa: F401  (currently unused here)

sys.path.append("/models/blitz/crates/blitz-kernels/src/cuda")
from ghost_quant import ghost_quant_fp8_kernel  # noqa: E402  (needs sys.path entry above)


def _ref_fn(x: torch.Tensor) -> torch.Tensor:
    """Cast ``x`` to float8_e4m3fn with PyTorch and return the raw bits as int8."""
    return x.to(torch.float8_e4m3fn).view(torch.int8)


def _fp8_bits_to_float(bits: torch.Tensor) -> torch.Tensor:
    """Reinterpret int8 storage as float8_e4m3fn codes and widen them to float32."""
    return bits.view(torch.float8_e4m3fn).float()


def run_rigorous_quant(
    num_elements: int = 1024 * 1024 * 16,
    block_size: int = 1024,
    seed: int = 42,
) -> None:
    """Check and time ``ghost_quant_fp8_kernel`` against an Inductor-compiled cast.

    Args:
        num_elements: Length of the random float32 input vector (default 16M).
        block_size: ``BLOCK_SIZE`` passed to the Triton kernel; also sets the
            1-D launch grid to ``cdiv(num_elements, block_size)`` programs.
        seed: Seed forwarded to the kernel (original comment describes the
            kernel as using stochastic rounding).

    Raises:
        RuntimeError: If no CUDA device is available.
        ValueError: If ``num_elements`` or ``block_size`` is not positive.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("run_rigorous_quant requires a CUDA device")
    if num_elements <= 0 or block_size <= 0:
        raise ValueError("num_elements and block_size must be positive")

    n = num_elements
    x = torch.randn(n, device="cuda", dtype=torch.float32)
    # Holds raw FP8 bit patterns. NOTE(review): assumes the kernel writes
    # e4m3fn encodings, i.e. the same encoding _ref_fn produces — confirm.
    y_blitz = torch.empty(n, device="cuda", dtype=torch.int8)
    grid = (triton.cdiv(n, block_size),)

    def launch_blitz() -> None:
        ghost_quant_fp8_kernel[grid](x, y_blitz, seed, n, BLOCK_SIZE=block_size)

    # 1. Correctness check. The first launch also triggers Triton JIT compilation.
    launch_blitz()
    y_ref = _ref_fn(x)

    # Compare decoded values, not raw bit patterns: as int8, +0.0 (0x00) and
    # -0.0 (0x80 == -128) differ by 128 although they are equal, and the
    # integer distance between codes does not track their numeric distance.
    blitz_vals = _fp8_bits_to_float(y_blitz)
    ref_vals = _fp8_bits_to_float(y_ref)
    err = blitz_vals - ref_vals
    # Small per-element differences are expected (stochastic rounding vs. a
    # deterministic cast); the signed mean exposes any systematic bias.
    print(f"Correctness (Mean Abs Diff):    {err.abs().mean().item():.6f}")
    print(f"Correctness (Max Abs Diff):     {err.abs().max().item():.6f}")
    print(f"Correctness (Mean Signed Diff): {err.mean().item():.6f}")
    print(f"Correctness (Mismatch Rate):    {(blitz_vals != ref_vals).float().mean().item():.4%}")

    # 2. Timing. do_bench runs its own warmup and repeated measurements; whether
    # it reports the mean or the median depends on the installed Triton version.
    ms_blitz = triton.testing.do_bench(launch_blitz)

    # PyTorch Inductor (the real competitor).
    compiled_ref = torch.compile(_ref_fn, mode="max-autotune")
    compiled_ref(x)  # compile/autotune outside the timed region
    ms_inductor = triton.testing.do_bench(lambda: compiled_ref(x))

    # Effective bandwidth: one float32 read plus one int8 write per element.
    bytes_moved = n * (x.element_size() + y_blitz.element_size())

    def gbps(ms: float) -> float:
        return bytes_moved / (ms * 1e-3) / 1e9

    device_name = torch.cuda.get_device_name()
    print(f"--- RIGOROUS RECEIPT: GHOST QUANT ({n / 2**20:g}M elements, {device_name}) ---")
    print(f"Inductor Latency:      {ms_inductor:.4f} ms ({gbps(ms_inductor):.1f} GB/s)")
    print(f"Blitz Artisan Latency: {ms_blitz:.4f} ms ({gbps(ms_blitz):.1f} GB/s)")
    print(f"REAL SPEEDUP: {ms_inductor / ms_blitz:.2f}x")


if __name__ == "__main__":
    run_rigorous_quant()