File size: 298 Bytes
f6e23b0
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
import torch
import torch.nn as nn
class Model(nn.Module):
    def __init__(self): super().__init__()
    def forward(self, x):
        return x.to(torch.float8_e4m3fn).view(torch.uint8).to(torch.float32)
def get_inputs(size: int = 8192, device="cuda") -> list:
    """Build the positional inputs for ``Model.forward``.

    Args:
        size: Number of elements in the 1-D input tensor (default 8192, the
            original hard-coded size).
        device: Device on which to allocate the tensor (default ``"cuda"``,
            the original hard-coded device). Pass ``"cpu"`` to run without a GPU.

    Returns:
        A one-element list holding a 1-D tensor of standard-normal samples
        from ``torch.randn``.
    """
    return [torch.randn(size, device=device)]
def get_init_inputs():
    """Return the constructor arguments for ``Model`` (it takes none)."""
    init_args = list()
    return init_args