File size: 298 Bytes
f6e23b0 |
1 2 3 4 5 6 7 8 9 |
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self): super().__init__()
def forward(self, x):
return x.to(torch.float8_e4m3fn).view(torch.uint8).to(torch.float32)
def get_inputs(size: int = 8192, device: str = "cuda"):
    """Build the positional inputs for ``Model.forward``.

    Args:
        size: number of elements in the 1-D input tensor (default 8192).
        device: device on which to allocate the tensor (default "cuda").
            Pass "cpu" to run without a GPU.

    Returns:
        A one-element list holding a standard-normal float32 tensor of
        shape ``(size,)``.
    """
    return [torch.randn(size, device=device)]
def get_init_inputs():
    """Return the constructor arguments for ``Model`` (it takes none)."""
    init_args = list()
    return init_args
|