| import torch | |
| from common.diff_engine import DiffCase | |
| import activation | |
class PolyNorm(torch.nn.Module):
    """Pure-PyTorch reference PolyNorm, used as the "naive" side of the diff test.

    Computes ``w[0] * norm(x**3) + w[1] * norm(x**2) + w[2] * norm(x) + b``.
    ``norm`` divides by the root-mean-square over the last dimension, with
    ``eps`` added inside the square root. All arithmetic is done in float32,
    and the result is cast back to the dtype of the input.

    Args:
        eps: small constant added to the mean square before the square root.
        dtype: dtype used to create ``weight`` and ``bias``.
    """

    def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
        super().__init__()
        # Three mixing weights (cubic, quadratic, linear term), each 1/3.
        self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
        # One bias value of shape (1,), starting at zero.
        self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype))
        self.eps = eps

    def _norm(self, x):
        """Return ``x`` divided by its RMS over the last dimension (plus eps)."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x / torch.sqrt(mean_square + self.eps)

    def forward(self, x):
        """Apply PolyNorm to ``x`` in float32; return a tensor of ``x.dtype``."""
        input_dtype = x.dtype
        xf = x.to(torch.float32)
        w = self.weight
        cubic = w[0] * self._norm(xf**3)
        quadratic = w[1] * self._norm(xf**2)
        linear = w[2] * self._norm(xf)
        # Summed left to right, i.e. ((cubic + quadratic) + linear) + bias.
        combined = cubic + quadratic + linear + self.bias
        return combined.to(input_dtype)
class Poly(DiffCase):
    """Diff case comparing the reference ``PolyNorm`` with ``activation.layers.PolyNorm``.

    Both modules receive independent copies of the same weight and bias values
    and are run on the same input ``x``. ``grad_inputs`` lists only ``x``;
    presumably the diff engine compares gradients only for the tensors returned
    there (confirm against ``DiffCase``).
    """

    def build_inputs(self, bs, sl, hidden, dtype, eps):
        """Build the shared input dict for one test configuration.

        Args:
            bs, sl, hidden: sizes of the random input ``x``, shape ``(bs, sl, hidden)``.
            dtype: dtype of ``x``, ``weight`` and ``bias``.
            eps: epsilon forwarded to both PolyNorm implementations.

        Returns:
            dict with keys ``x``, ``weight``, ``bias``, ``dim``, ``eps`` and ``dtype``.
        """
        return {
            "x": torch.randn(bs, sl, hidden, dtype=dtype, requires_grad=True),
            # Distinct per-power weights. With identical weights (e.g. all ones),
            # a kernel that applies weight[i] to the wrong power term would give
            # the same output and gradients, and the bug would go unnoticed.
            "weight": torch.tensor([0.5, 1.0, 1.5], dtype=dtype),
            # Non-zero bias (the module default is zero), so adding the bias
            # actually changes the result.
            "bias": torch.ones(1, dtype=dtype),
            "dim": hidden,
            "eps": eps,
            "dtype": dtype,
        }

    @staticmethod
    def _load_params(module, I):
        """Replace ``module.weight`` and ``module.bias`` with fresh copies from ``I``."""
        # detach().clone() gives each module its own leaf parameters, so the two
        # implementations never share storage or autograd history.
        module.weight = torch.nn.Parameter(I["weight"].detach().clone())
        module.bias = torch.nn.Parameter(I["bias"].detach().clone())
        return module

    def make_naive(self, I):
        """Return the pure-PyTorch reference module, loaded with ``I``'s parameters."""
        return self._load_params(PolyNorm(I["eps"], dtype=I["dtype"]), I)

    def make_cuda(self, I):
        """Return the ``activation.layers.PolyNorm`` module, loaded with ``I``'s parameters."""
        return self._load_params(
            activation.layers.PolyNorm(I["eps"], dtype=I["dtype"]), I)

    def forward(self, obj, I):
        """Run module ``obj`` on the shared input ``x``."""
        return obj(I["x"])

    def grad_inputs(self, I):
        """Return the tensors whose gradients are reported (only ``x``)."""
        return [I["x"]]


CASE = Poly()