Kernels
File size: 1,754 Bytes
e5e2eeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
from common.diff_engine import DiffCase

import activation


class PolyNorm(torch.nn.Module):
    """Pure-PyTorch reference implementation of PolyNorm.

    Computes::

        w[0] * rms(x**3) + w[1] * rms(x**2) + w[2] * rms(x) + bias

    where ``rms(t) = t / sqrt(mean(t**2, dim=-1) + eps)``. All arithmetic is
    done in float32, and the result is cast back to the input's dtype.

    Args:
        eps: Constant added under the square root for numerical stability.
        dtype: dtype of the ``weight`` and ``bias`` parameters.
    """

    def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
        super().__init__()
        # One coefficient per polynomial term, each starting at 1/3.
        initial_coeffs = torch.ones(3, dtype=dtype) / 3
        self.weight = torch.nn.Parameter(initial_coeffs)
        self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype))
        self.eps = eps

    def _norm(self, x):
        """RMS-normalise ``x`` over its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x / torch.sqrt(mean_square + self.eps)

    def forward(self, x):
        """Apply PolyNorm to ``x`` and return a tensor of ``x``'s dtype."""
        in_dtype = x.dtype
        # Upcast so low-precision inputs are normalised in float32.
        xf = x.to(torch.float32)
        coeffs = self.weight
        cubic = coeffs[0] * self._norm(xf**3)
        quadratic = coeffs[1] * self._norm(xf**2)
        linear = coeffs[2] * self._norm(xf)
        # Summed left-to-right: cubic, quadratic, linear, then bias.
        result = cubic + quadratic + linear + self.bias
        return result.to(in_dtype)


class Poly(DiffCase):
    """Diff case pairing the naive ``PolyNorm`` in this module with
    ``activation.layers.PolyNorm``.

    Both modules get fresh copies of the same ``weight``/``bias`` tensors, so
    the two sides start from identical parameters. Presumably the diff engine
    runs ``forward`` on each and compares the results and the gradients of
    ``grad_inputs`` — confirm against ``common.diff_engine``.
    """

    def build_inputs(self, bs, sl, hidden, dtype, eps):
        """Create a random ``(bs, sl, hidden)`` input plus shared parameters."""
        x = torch.randn(bs, sl, hidden, dtype=dtype, requires_grad=True)
        return dict(
            x=x,
            weight=torch.ones(3, dtype=dtype),
            bias=torch.ones(1, dtype=dtype),
            dim=hidden,
            eps=eps,
            dtype=dtype,
        )

    @staticmethod
    def _install_params(module, I):
        """Replace ``module``'s weight/bias with detached copies from ``I``."""
        module.weight = torch.nn.Parameter(I["weight"].detach().clone())
        module.bias = torch.nn.Parameter(I["bias"].detach().clone())
        return module

    def make_naive(self, I):
        """Build the pure-PyTorch reference module."""
        reference = PolyNorm(I["eps"], dtype=I["dtype"])
        return self._install_params(reference, I)

    def make_cuda(self, I):
        """Build the ``activation.layers.PolyNorm`` module under test."""
        candidate = activation.layers.PolyNorm(I["eps"], dtype=I["dtype"])
        return self._install_params(candidate, I)

    def forward(self, obj, I):
        """Run ``obj`` (either implementation) on the input tensor ``I["x"]``."""
        return obj(I["x"])

    def grad_inputs(self, I):
        """Return the tensors whose gradients are of interest (only ``x``)."""
        return [I["x"]]


# Module-level instance of this diff case. Presumably the diff-engine harness
# looks it up by the name ``CASE`` — confirm against common.diff_engine.
CASE = Poly()