| import torch | |
| import triton | |
| import triton.language as tl | |
| import time | |
# The kernel must be JIT-compiled: it is launched with the ``bw_kernel[grid](...)``
# syntax, which only works on the object returned by ``triton.jit``.
@triton.jit
def bw_kernel(A, B, N, BLOCK_SIZE: tl.constexpr):
    """Copy ``N`` elements from ``B`` to ``A``, one ``BLOCK_SIZE`` tile per program.

    Args:
        A: Destination the loaded elements are stored to.
        B: Source the elements are loaded from.
        N: Number of elements to copy; lanes at or beyond ``N`` are masked off.
        BLOCK_SIZE: Compile-time number of elements handled by each program.
    """
    pid = tl.program_id(0)
    # NOTE(review): offsets are derived from the program id without widening;
    # confirm the index type is wide enough if N ever approaches 2**31.
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # The last tile may run past N, so guard both the load and the store.
    mask = offsets < N
    b = tl.load(B + offsets, mask=mask)
    tl.store(A + offsets, b, mask=mask)
def run_bw():
    """Benchmark ``bw_kernel`` on the CUDA device and print the bandwidth in TB/s.

    Allocates two float32 CUDA tensors of ``N`` elements, launches the copy
    kernel repeatedly, and reports the average throughput. Each element is
    counted as one 4-byte read plus one 4-byte write.
    """
    N = 1024 * 1024 * 512
    iters = 100
    A = torch.empty(N, device="cuda", dtype=torch.float32)
    B = torch.randn(N, device="cuda", dtype=torch.float32)

    # Use huge block size for Sm_90
    BLOCK_SIZE = 16384
    grid = (triton.cdiv(N, BLOCK_SIZE),)

    # Warm-up launch: the first call can carry one-time costs (kernel
    # compilation, context/module setup) that must not be counted as
    # memory-transfer time.
    bw_kernel[grid](A, B, N, BLOCK_SIZE=BLOCK_SIZE)
    # Kernel launches are asynchronous; drain the queue before starting the clock.
    torch.cuda.synchronize()

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    for _ in range(iters):
        bw_kernel[grid](A, B, N, BLOCK_SIZE=BLOCK_SIZE)
    # Wait for all queued launches to finish before stopping the clock.
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    # bytes moved per launch (read + write, 4 bytes each) / seconds per launch
    bw = (2 * N * 4) / (elapsed / iters) / 1e12
    print(f"H200 HBM3e (Artisan): {bw:.2f} TB/s")
# Script entry point: run the benchmark only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    run_bw()