File size: 845 Bytes
f6e23b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
import torch.nn as nn
import triton
import triton.language as tl
import sys
sys.path.append("/models/blitz/crates/blitz-kernels/src/cuda")
@triton.jit
def blitz_speed_kernel(X, Y, N, BLOCK_SIZE: tl.constexpr):
    """Cast ``N`` elements of ``X`` to FP8 (``tl.float8e4nv``) and store their bytes in ``Y``.

    Each program handles one contiguous tile of ``BLOCK_SIZE`` elements that
    starts at ``program_id(0) * BLOCK_SIZE``. Lanes at or past ``N`` are masked
    off for both the load and the store, so the last tile may be partial.

    Args:
        X: Pointer to the input elements. Presumably float32 (what the wrapper
            is fed in this file) — TODO confirm for other callers.
        Y: Pointer to an int8 output buffer with at least ``N`` elements.
        N: Number of valid elements.
        BLOCK_SIZE: Elements per program (compile-time constant).
    """
    # NOTE(review): float8e4nv conversions are, to my knowledge, only
    # supported by Triton on GPUs with compute capability >= 8.9 — confirm.
    block_start = tl.program_id(0) * BLOCK_SIZE
    idx = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < N
    vals = tl.load(X + idx, mask=in_bounds)
    # Round to FP8, then reinterpret the 8-bit pattern as int8 so it can be
    # written into the int8 output buffer unchanged.
    fp8_bits = vals.to(tl.float8e4nv).to(tl.int8, bitcast=True)
    tl.store(Y + idx, fp8_bits, mask=in_bounds)
class ModelNew(nn.Module):
    """Convert a tensor to FP8 (``tl.float8e4nv``) bit patterns with a Triton kernel.

    ``forward`` returns, for every input element, the byte of its FP8
    encoding read as an unsigned integer and widened to ``float32``. The
    result has the same shape as the input.
    """

    def __init__(self, block_size: int = 1024):
        """Create the module.

        Args:
            block_size: Number of elements handled by each Triton program.
                Must be a positive power of two because the kernel builds its
                tile with ``tl.arange(0, BLOCK_SIZE)``. The default keeps
                ``ModelNew()`` usable with no arguments.

        Raises:
            ValueError: If ``block_size`` is not a positive power of two.
        """
        super().__init__()
        if block_size <= 0 or block_size & (block_size - 1):
            raise ValueError(
                f"block_size must be a positive power of two, got {block_size}"
            )
        self.block_size = block_size

    def forward(self, x):
        """Return the FP8 byte of each element of ``x`` as ``float32``.

        Args:
            x: Input tensor. It must live on a device the Triton kernel can
                address (CUDA); ``get_inputs`` supplies a float32 CUDA vector.

        Returns:
            A ``float32`` tensor with ``x``'s shape whose entries are integers
            in ``[0, 255]`` (the raw FP8 bytes, read as unsigned).
        """
        # The kernel addresses X and Y with flat offsets, so it needs dense
        # row-major storage; this is a no-op for already-contiguous input.
        x = x.contiguous()
        n = x.numel()
        # Allocate on the input's own device instead of assuming the default
        # CUDA device.
        y = torch.empty(x.shape, device=x.device, dtype=torch.int8)
        if n > 0:
            # Tile the input over many programs with a fixed power-of-two
            # block. One program of size numel() breaks for non-power-of-two
            # or very large inputs and recompiles the kernel for every size.
            grid = ((n + self.block_size - 1) // self.block_size,)
            blitz_speed_kernel[grid](x, y, n, BLOCK_SIZE=self.block_size)
        return y.view(torch.uint8).to(torch.float32)
def get_inputs():
    """Return the forward() arguments: one 8192-element random normal vector on CUDA."""
    sample = torch.randn(8192, device="cuda")
    return [sample]
def get_init_inputs():
    """Return the positional arguments used to construct ``ModelNew``: none."""
    no_args = list()
    return no_args