|
|
from abc import ABC, abstractmethod |
|
|
from typing import Any, Dict, Sequence |
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
class DiffCase(ABC):
    """Abstract interface describing one naive-vs-CUDA comparison case.

    A subclass bundles everything ``calculate_diff`` needs: input
    construction, the reference ("naive") and CUDA objects under test,
    a forward invocation, and the tensors gradients are taken against.
    """

    @abstractmethod
    def build_inputs(self, hidden: int, bs: int, sl: int, dtype: torch.dtype,
                     eps: float) -> Dict[str, Any]:
        """Return the raw input payload for sizes (bs, sl, hidden).

        The dict may mix tensors and plain values; callers clone it per
        implementation before use (see ``_clone_payload``).
        """
        ...

    @abstractmethod
    def make_naive(self, I: Dict[str, Any]) -> Any:
        """Build the reference implementation from payload *I*."""
        ...

    @abstractmethod
    def make_cuda(self, I: Dict[str, Any]) -> Any:
        """Build the CUDA implementation from payload *I*."""
        ...

    @abstractmethod
    def forward(self, obj: Any, I: Dict[str, Any]) -> torch.Tensor:
        """Run *obj*'s forward pass on payload *I* and return its output."""
        ...

    @abstractmethod
    def grad_inputs(self, I: Dict[str, Any]) -> Sequence[torch.Tensor]:
        """Return the tensors from *I* to differentiate with respect to."""
        ...
|
|
|
|
|
|
|
|
def _clone_payload(d, device): |
|
|
out = {} |
|
|
for k, v in d.items(): |
|
|
if isinstance(v, torch.Tensor): |
|
|
t = v.detach().clone().to(device) |
|
|
t.requires_grad_(v.requires_grad) |
|
|
out[k] = t |
|
|
else: |
|
|
out[k] = v |
|
|
return out |
|
|
|
|
|
|
|
|
def _unit_grad_like(y): |
|
|
g = torch.randn_like(y) |
|
|
n = g.norm() |
|
|
return g if n == 0 else g / n |
|
|
|
|
|
|
|
|
def calculate_diff(
    case: DiffCase,
    *,
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype=torch.bfloat16,
    eps: float = 1e-6,
    atol: float = 1e-2,
    rtol: float = 1e-2,
    device="cuda",
) -> None:
    """Assert that a case's naive and CUDA implementations agree.

    Builds one input payload, clones it independently for each
    implementation, compares the forward outputs, then compares the
    gradients w.r.t. the case's declared grad inputs plus each object's
    parameters — driving both backward passes with the same random
    unit-norm output gradient. Raises via ``torch.testing.assert_close``
    on any mismatch beyond (atol, rtol).
    """
    payload = case.build_inputs(hidden_size, batch_size, seq_len, dtype, eps)
    inputs_naive = _clone_payload(payload, device)
    inputs_cuda = _clone_payload(payload, device)

    impl_naive = case.make_naive(inputs_naive)
    impl_cuda = case.make_cuda(inputs_cuda)

    out_naive = case.forward(impl_naive, inputs_naive)
    out_cuda = case.forward(impl_cuda, inputs_cuda)
    torch.testing.assert_close(out_naive, out_cuda, atol=atol, rtol=rtol)

    # Differentiate w.r.t. declared inputs plus module parameters.
    # NOTE(review): assumes make_naive/make_cuda return objects exposing
    # .parameters() (e.g. nn.Module) — confirm for non-module cases.
    wrt_naive = list(case.grad_inputs(inputs_naive)) + list(impl_naive.parameters())
    wrt_cuda = list(case.grad_inputs(inputs_cuda)) + list(impl_cuda.parameters())

    # One shared upstream gradient (built from the naive output) keeps the
    # two backward passes directly comparable.
    if isinstance(out_naive, torch.Tensor):
        grad_outs = [_unit_grad_like(out_naive).to(device)]
    else:
        grad_outs = [_unit_grad_like(part).to(device) for part in out_naive]

    grads_naive = torch.autograd.grad(
        out_naive,
        wrt_naive,
        grad_outs,
        retain_graph=False,
        create_graph=False,
        allow_unused=False,
    )
    grads_cuda = torch.autograd.grad(
        out_cuda,
        wrt_cuda,
        grad_outs,
        retain_graph=False,
        create_graph=False,
        allow_unused=False,
    )
    torch.testing.assert_close(grads_naive, grads_cuda, atol=atol, rtol=rtol)
    print("✅ forward + backward match")
|
|
|