|
|
import torch |
|
|
import triton |
|
|
import triton.testing |
|
|
import sys |
|
|
from torch.utils.benchmark import Timer |
|
|
|
|
|
# NOTE(review): hard-coded absolute path so the Triton kernel source can be
# imported as a plain module — brittle outside this machine; consider an
# env var or packaging the kernel properly.
sys.path.append("/models/blitz/crates/blitz-kernels/src/cuda")


# Project-local Triton kernel under test (lives in the path appended above).
from ghost_quant import ghost_quant_fp8_kernel
|
|
|
|
|
def run_rigorous_quant():
    """Benchmark the Blitz ghost-quant FP8 kernel against a torch.compile baseline.

    Quantizes 16M fp32 values to float8_e4m3fn (stored as int8 bit patterns),
    checks numerical agreement with the eager PyTorch cast, then times both
    implementations with ``triton.testing.do_bench`` and prints a speedup
    report. Requires a CUDA device.
    """
    N = 1024 * 1024 * 16  # 16M elements
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(N, BLOCK_SIZE),)  # hoisted: same launch grid used twice below

    X = torch.randn(N, device="cuda", dtype=torch.float32)
    Y_blitz = torch.empty(N, device="cuda", dtype=torch.int8)
    # Fixed seed for reproducible runs; presumably the kernel uses it for
    # stochastic rounding ("ghost" quant) — TODO confirm against kernel source.
    seed = 42

    def ref_fn(x):
        # Eager reference: cast to fp8 e4m3 and reinterpret the bytes as int8.
        return x.to(torch.float8_e4m3fn).view(torch.int8)

    # --- Correctness check ---
    ghost_quant_fp8_kernel[grid](X, Y_blitz, seed, N, BLOCK_SIZE=BLOCK_SIZE)
    y_ref = ref_fn(X)

    # FIX: the int8 outputs are raw fp8 *bit patterns*. Subtracting them as
    # integers measures distance in encoding space, not value space (adjacent
    # fp8 values can differ by large int8 amounts, and sign handling is wrong).
    # Decode both back to fp32 so the reported diff is numerically meaningful.
    y_blitz_vals = Y_blitz.view(torch.float8_e4m3fn).float()
    y_ref_vals = y_ref.view(torch.float8_e4m3fn).float()
    diff = (y_blitz_vals - y_ref_vals).abs().mean()
    print(f"Correctness (Mean Diff): {diff:.6f}")

    # --- Timing: Blitz kernel (do_bench handles warmup and synchronization) ---
    ms_blitz = triton.testing.do_bench(
        lambda: ghost_quant_fp8_kernel[grid](X, Y_blitz, seed, N, BLOCK_SIZE=BLOCK_SIZE)
    )

    # --- Timing: torch.compile (Inductor) baseline ---
    compiled_ref = torch.compile(ref_fn, mode="max-autotune")
    compiled_ref(X)  # warm-up call so one-time compilation is excluded from timing
    ms_inductor = triton.testing.do_bench(lambda: compiled_ref(X))

    print("--- RIGOROUS RECEIPT: GHOST QUANT (16M Tokens) ---")
    print(f"H200 Inductor Latency: {ms_inductor:.4f} ms")
    print(f"Blitz Artisan Latency: {ms_blitz:.4f} ms")
    print(f"REAL SPEEDUP: {ms_inductor/ms_blitz:.2f}x")
|
|
|
|
|
# Script entry point: run the benchmark only when executed directly,
# not when this module is imported.
if __name__ == "__main__":


    run_rigorous_quant()
|
|
|