| import torch | |
| import torch.nn as nn | |
| import triton | |
| import triton.language as tl | |
| import sys | |
| sys.path.append("/models/blitz/crates/blitz-kernels/src/cuda") | |
@triton.jit
def blitz_speed_kernel(X, Y, N, BLOCK_SIZE: tl.constexpr):
    """Convert a flat buffer to float8 (``tl.float8e4nv``) and store the raw bytes.

    Each program instance handles one tile of ``BLOCK_SIZE`` consecutive
    elements: it loads them from ``X``, converts them to ``tl.float8e4nv``
    and writes the resulting bit patterns, reinterpreted as int8, to ``Y``.

    The ``@triton.jit`` decorator is required: the ``kernel[grid](...)``
    launch syntax and the ``tl.*`` operations only work on a JIT-compiled
    function, not on a plain Python function.

    Args:
        X: Pointer to the input elements, addressed with flat offsets.
        Y: Pointer to the int8 output buffer, addressed with the same offsets.
        N: Number of valid elements; offsets at or beyond ``N`` are masked out.
        BLOCK_SIZE: Compile-time number of elements per program instance.
    """
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Guard the tail tile so lanes at or past N neither read nor write.
    mask = offsets < N
    x = tl.load(X + offsets, mask=mask)
    y = x.to(tl.float8e4nv)
    # The bitcast keeps the fp8 bit pattern intact; only the element type
    # changes so the value can be stored through the int8 output pointer.
    tl.store(Y + offsets, y.to(tl.int8, bitcast=True), mask=mask)
class ModelNew(nn.Module):
    """Module that runs ``blitz_speed_kernel`` over its input.

    ``forward`` returns, for every input element, the byte of its float8
    encoding (read as an unsigned 8-bit integer) widened to float32. It is
    the raw encoding, not a dequantized value.
    """

    # Elements handled per program instance. ``tl.arange`` needs a
    # power-of-two extent, so the tile width is a fixed constant rather
    # than being derived from the input size.
    _BLOCK_SIZE = 1024

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Encode ``x`` element-wise and return the encoded bytes as float32.

        Args:
            x: Input tensor of any shape. Non-empty inputs are handed to the
                Triton kernel, so they are expected to live on a CUDA device.

        Returns:
            A float32 tensor with the same shape and device as ``x`` whose
            values are the unsigned byte values written by the kernel.
        """
        # Allocate on the input's device so the kernel never mixes devices.
        y = torch.empty(x.shape, device=x.device, dtype=torch.int8)
        n_elements = x.numel()
        if n_elements > 0:
            # The kernel addresses its input with flat offsets, which is only
            # valid for contiguous memory.
            x = x.contiguous()
            # One program per tile; the kernel masks the ragged last tile.
            grid = (triton.cdiv(n_elements, self._BLOCK_SIZE),)
            blitz_speed_kernel[grid](
                x, y, n_elements, BLOCK_SIZE=self._BLOCK_SIZE
            )
        # Reinterpret the stored bytes as unsigned before widening to float32.
        return y.view(torch.uint8).to(torch.float32)
def get_inputs():
    """Return a one-element list holding a random 8192-element CUDA tensor."""
    sample = torch.randn(8192, device="cuda")
    return [sample]
def get_init_inputs():
    """Return a fresh empty list (presumably the model's constructor arguments)."""
    init_args = []
    return init_args