| import torch | |
| import torch.nn as nn | |
| class Model(nn.Module): | |
| def __init__(self): super().__init__() | |
| def forward(self, x): | |
| return x.to(torch.float8_e4m3fn).view(torch.uint8).to(torch.float32) | |
def get_inputs(*, n: int = 8192, device: str = "cuda"):
    """Return a one-element list holding a random 1-D float input tensor.

    Presumably consumed as ``Model()(*get_inputs())`` -- confirm against the
    harness. The defaults reproduce the previous hard-coded behaviour.

    Args:
        n: Number of elements in the tensor (default 8192).
        device: Device to allocate on (default ``"cuda"``). Pass ``"cpu"`` to
            build inputs on a machine without a GPU.

    Returns:
        ``[t]``, where ``t`` is a length-``n`` tensor of standard-normal samples
        from ``torch.randn``.
    """
    return [torch.randn(n, device=device)]
def get_init_inputs():
    """Return the constructor arguments for the model.

    The list is always empty because ``Model.__init__`` takes no arguments.
    A fresh list is built on each call.
    """
    return list()