File size: 1,128 Bytes
f6e23b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import triton
import triton.language as tl
import time

@triton.jit
def blitz_scan_kernel(X, Y, N, BLOCK_SIZE: tl.constexpr):
    # Inclusive prefix sum over one BLOCK_SIZE-wide tile of X, written to Y.
    # Each program instance scans only its own tile: nothing is carried across
    # tiles, so a whole-array scan needs a single program with BLOCK_SIZE >= N.
    # NOTE(review): tl.arange requires BLOCK_SIZE to be a power of two.
    tile_start = tl.program_id(0) * BLOCK_SIZE
    idx = tile_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < N
    # Lanes with idx >= N are not loaded (their values are unspecified) and are
    # not stored. Because idx grows with the lane number, those lanes all sit
    # after the valid ones, so they cannot affect the valid inclusive-scan outputs.
    vals = tl.load(X + idx, mask=in_bounds)
    # Simplified artisan scan simulation
    running = tl.cumsum(vals, axis=0)
    tl.store(Y + idx, running, mask=in_bounds)

def benchmark_blitz(size, iters=100):
    """Time ``blitz_scan_kernel`` on a random float32 CUDA vector and print the result.

    The whole vector is scanned by a single program (grid of 1) whose
    ``BLOCK_SIZE`` equals ``size``. Each distinct ``size`` therefore compiles
    its own kernel; an untimed warmup launch absorbs that compilation cost.

    Args:
        size: Number of float32 elements. Must be a positive power of two,
            because it is passed straight through as ``BLOCK_SIZE`` to
            ``tl.arange`` inside the kernel.
        iters: Number of timed launches to average over (default 100).

    Returns:
        ``(avg_ms, throughput)``: the mean wall-clock time per launch in
        milliseconds, and the effective bandwidth in GB/s counting both the
        bytes read from ``X`` and the bytes written to ``Y``.

    Raises:
        ValueError: if ``size`` is not a positive power of two, or ``iters`` < 1.
    """
    # Validate before touching the GPU so bad sizes fail with a clear message
    # instead of an opaque Triton compilation error.
    if size <= 0 or size & (size - 1):
        raise ValueError(f"size must be a positive power of two, got {size}")
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")

    X = torch.randn(size, device="cuda", dtype=torch.float32)
    Y = torch.empty_like(X)
    grid = (1,)

    # Warmup: triggers JIT compilation for this BLOCK_SIZE so it is not timed.
    blitz_scan_kernel[grid](X, Y, size, BLOCK_SIZE=size)

    # Kernel launches are asynchronous. Synchronize on both sides of the timed
    # region so the interval covers all queued GPU work. perf_counter is a
    # monotonic, high-resolution clock, unlike time.time().
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        blitz_scan_kernel[grid](X, Y, size, BLOCK_SIZE=size)
    torch.cuda.synchronize()
    avg_ms = (time.perf_counter() - start) / iters * 1000

    # A scan reads every input element and writes every output element, so
    # both directions count toward memory traffic.
    bytes_moved = X.numel() * X.element_size() + Y.numel() * Y.element_size()
    throughput = bytes_moved / (avg_ms / 1000) / 1e9
    print(f"Size: {size}, Time: {avg_ms:.4f}ms, Throughput: {throughput:.2f} GB/s")
    return avg_ms, throughput

if __name__ == "__main__":
    print("--- Blitz Artisan Kernel Benchmark (H200) ---")
    # Sweep sizes 2**10 .. 2**13 (1024, 2048, 4096, 8192 elements). Each size
    # is used directly as the kernel's BLOCK_SIZE, hence the powers of two.
    for exponent in range(10, 14):
        benchmark_blitz(1 << exponent)