Kernels
File size: 2,509 Bytes
e5e2eeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1e5ca8
 
 
 
e5e2eeb
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from abc import ABC, abstractmethod
from typing import Any, Dict, Sequence

import torch


class DiffCase(ABC):
    """Interface for a forward/backward parity check between two implementations.

    A concrete case supplies an input payload plus two factory methods — a
    "naive" reference object and a "cuda" object — and `calculate_diff`
    (below in this file) runs both and compares outputs and gradients.
    """

    @abstractmethod
    def build_inputs(self, hidden: int, bs: int, sl: int, dtype: torch.dtype,
                     eps: float) -> Dict[str, Any]:
        """Build the input payload dict for the given sizes/dtype/eps.

        Tensor values in the dict are cloned per implementation by the
        caller; non-tensor values are passed through as-is.
        """
        ...

    @abstractmethod
    def make_naive(self, I: Dict[str, Any]) -> Any:
        """Construct the reference implementation object from payload ``I``."""
        ...

    @abstractmethod
    def make_cuda(self, I: Dict[str, Any]) -> Any:
        """Construct the CUDA-kernel implementation object from payload ``I``."""
        ...

    @abstractmethod
    def forward(self, obj: Any, I: Dict[str, Any]) -> torch.Tensor:
        """Run ``obj``'s forward pass on payload ``I`` and return its output.

        NOTE(review): despite the annotation, callers in this file also
        handle a sequence of tensors being returned — confirm intent.
        """
        ...

    @abstractmethod
    def grad_inputs(self, I: Dict[str, Any]) -> Sequence[torch.Tensor]:
        """Return the payload tensors whose gradients should be compared."""
        ...


def _clone_payload(d, device):
    out = {}
    for k, v in d.items():
        if isinstance(v, torch.Tensor):
            t = v.detach().clone().to(device)
            t.requires_grad_(v.requires_grad)
            out[k] = t
        else:
            out[k] = v
    return out


def _unit_grad_like(y):
    g = torch.randn_like(y)
    n = g.norm()
    return g if n == 0 else g / n


def calculate_diff(
    case: DiffCase,
    *,
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype=torch.bfloat16,
    eps: float = 1e-6,
    atol: float = 1e-2,
    rtol: float = 1e-2,
    device="cuda",
) -> None:
    """Assert forward and backward parity between *case*'s two implementations.

    Builds one payload, clones it per implementation so each owns its own
    autograd graph, compares forward outputs, then compares gradients of the
    case's declared grad inputs plus each object's parameters. Raises via
    ``torch.testing.assert_close`` on mismatch; prints a success line otherwise.
    """
    payload = case.build_inputs(hidden_size, batch_size, seq_len, dtype, eps)
    inputs_naive = _clone_payload(payload, device)
    inputs_cuda = _clone_payload(payload, device)

    impl_naive = case.make_naive(inputs_naive)
    impl_cuda = case.make_cuda(inputs_cuda)
    out_naive = case.forward(impl_naive, inputs_naive)
    out_cuda = case.forward(impl_cuda, inputs_cuda)
    torch.testing.assert_close(out_naive, out_cuda, atol=atol, rtol=rtol)

    # The same upstream gradient is fed to both graphs so their backward
    # passes are directly comparable.
    if isinstance(out_naive, torch.Tensor):
        upstream = [_unit_grad_like(out_naive).to(device)]
    else:
        upstream = [_unit_grad_like(t).to(device) for t in out_naive]

    leaves_naive = list(case.grad_inputs(inputs_naive)) + list(impl_naive.parameters())
    leaves_cuda = list(case.grad_inputs(inputs_cuda)) + list(impl_cuda.parameters())

    grads_naive = torch.autograd.grad(out_naive,
                                      leaves_naive,
                                      upstream,
                                      retain_graph=False,
                                      create_graph=False,
                                      allow_unused=False)
    grads_cuda = torch.autograd.grad(out_cuda,
                                     leaves_cuda,
                                     upstream,
                                     retain_graph=False,
                                     create_graph=False,
                                     allow_unused=False)
    torch.testing.assert_close(grads_naive, grads_cuda, atol=atol, rtol=rtol)
    print("✅ forward + backward match")