File size: 2,509 Bytes
e5e2eeb a1e5ca8 e5e2eeb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
from abc import ABC, abstractmethod
from typing import Any, Dict, Sequence
import torch
class DiffCase(ABC):
    """Interface for one naive-vs-CUDA comparison case.

    Implementations supply the inputs, construct both variants of the
    module under test, run the forward pass, and declare which input
    tensors gradients should be checked against.
    """
    @abstractmethod
    def build_inputs(self, hidden: int, bs: int, sl: int, dtype: torch.dtype,
                     eps: float) -> Dict[str, Any]:
        """Build the shared input payload (tensors plus scalar config)."""
        ...
    @abstractmethod
    def make_naive(self, I: Dict[str, Any]) -> Any:
        """Construct the reference (naive) implementation from inputs *I*."""
        ...
    @abstractmethod
    def make_cuda(self, I: Dict[str, Any]) -> Any:
        """Construct the CUDA implementation from inputs *I*."""
        ...
    @abstractmethod
    def forward(self, obj: Any, I: Dict[str, Any]) -> torch.Tensor:
        """Run *obj*'s forward pass on inputs *I* and return its output."""
        ...
    @abstractmethod
    def grad_inputs(self, I: Dict[str, Any]) -> Sequence[torch.Tensor]:
        """Input tensors from *I* whose gradients should be compared."""
        ...
def _clone_payload(d, device):
out = {}
for k, v in d.items():
if isinstance(v, torch.Tensor):
t = v.detach().clone().to(device)
t.requires_grad_(v.requires_grad)
out[k] = t
else:
out[k] = v
return out
def _unit_grad_like(y):
g = torch.randn_like(y)
n = g.norm()
return g if n == 0 else g / n
def calculate_diff(
    case: DiffCase,
    *,
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype=torch.bfloat16,
    eps: float = 1e-6,
    atol: float = 1e-2,
    rtol: float = 1e-2,
    device="cuda",
) -> None:
    """Check that *case*'s naive and CUDA paths agree forward and backward.

    One input payload is built and cloned separately for each path so the
    two implementations see independent tensors. Forward outputs must match
    within (atol, rtol); then a shared unit-norm random upstream gradient is
    back-propagated through both graphs and the gradients of the declared
    grad inputs plus module parameters are compared. Mismatches raise via
    ``torch.testing.assert_close``.
    """
    base = case.build_inputs(hidden_size, batch_size, seq_len, dtype, eps)
    inputs_naive = _clone_payload(base, device)
    inputs_cuda = _clone_payload(base, device)

    naive = case.make_naive(inputs_naive)
    cuda = case.make_cuda(inputs_cuda)

    out_naive = case.forward(naive, inputs_naive)
    out_cuda = case.forward(cuda, inputs_cuda)
    torch.testing.assert_close(out_naive, out_cuda, atol=atol, rtol=rtol)

    # Differentiate w.r.t. the declared grad inputs plus module parameters.
    # NOTE(review): assumes make_naive/make_cuda return nn.Module-like
    # objects exposing .parameters() — confirm against case implementations.
    leaves_naive = [*case.grad_inputs(inputs_naive), *naive.parameters()]
    leaves_cuda = [*case.grad_inputs(inputs_cuda), *cuda.parameters()]

    # The same upstream gradient feeds both graphs so the comparison is fair.
    if isinstance(out_naive, torch.Tensor):
        upstream = [_unit_grad_like(out_naive).to(device)]
    else:
        upstream = [_unit_grad_like(t).to(device) for t in out_naive]

    grads_naive = torch.autograd.grad(
        out_naive,
        leaves_naive,
        upstream,
        retain_graph=False,
        create_graph=False,
        allow_unused=False,
    )
    grads_cuda = torch.autograd.grad(
        out_cuda,
        leaves_cuda,
        upstream,
        retain_graph=False,
        create_graph=False,
        allow_unused=False,
    )
    torch.testing.assert_close(grads_naive, grads_cuda, atol=atol, rtol=rtol)
    print("✅ forward + backward match")
|