# Kernels — activation/benchmarks/common/diff_engine.py
# Author: TaehyunKim — "Fix fused add rms norm (#4)", commit a1e5ca8 (unverified)
from abc import ABC, abstractmethod
from typing import Any, Dict, Sequence
import torch
class DiffCase(ABC):
    """Interface for a kernel correctness case compared by calculate_diff.

    A concrete case supplies the input payload, constructs the naive
    (reference) and CUDA implementations, runs a forward pass, and names
    which input tensors should receive gradients.
    """

    @abstractmethod
    def build_inputs(self, hidden: int, bs: int, sl: int, dtype: torch.dtype,
                     eps: float) -> Dict[str, Any]:
        """Return the input payload dict for the given problem size/dtype."""
        ...

    @abstractmethod
    def make_naive(self, I: Dict[str, Any]) -> Any:
        """Build the reference implementation from payload *I*."""
        ...

    @abstractmethod
    def make_cuda(self, I: Dict[str, Any]) -> Any:
        """Build the CUDA implementation from payload *I*."""
        ...

    @abstractmethod
    def forward(self, obj: Any, I: Dict[str, Any]) -> torch.Tensor:
        """Run a forward pass of *obj* on payload *I* and return its output."""
        ...

    @abstractmethod
    def grad_inputs(self, I: Dict[str, Any]) -> Sequence[torch.Tensor]:
        """Return the payload tensors to differentiate with respect to."""
        ...
def _clone_payload(d, device):
out = {}
for k, v in d.items():
if isinstance(v, torch.Tensor):
t = v.detach().clone().to(device)
t.requires_grad_(v.requires_grad)
out[k] = t
else:
out[k] = v
return out
def _unit_grad_like(y):
g = torch.randn_like(y)
n = g.norm()
return g if n == 0 else g / n
def calculate_diff(
    case: DiffCase,
    *,
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype=torch.bfloat16,
    eps: float = 1e-6,
    atol: float = 1e-2,
    rtol: float = 1e-2,
    device="cuda",
) -> None:
    """Assert the naive and CUDA implementations of *case* agree.

    One payload is built and cloned per implementation, so both see
    identical inputs on independent autograd graphs. Forward outputs are
    compared first, then gradients with respect to the case's declared
    grad inputs plus each module's parameters. Raises on any mismatch
    (via ``torch.testing.assert_close``) and prints a success line
    otherwise.
    """
    payload = case.build_inputs(hidden_size, batch_size, seq_len, dtype, eps)
    inputs_naive = _clone_payload(payload, device)
    inputs_cuda = _clone_payload(payload, device)

    impl_naive = case.make_naive(inputs_naive)
    impl_cuda = case.make_cuda(inputs_cuda)

    out_naive = case.forward(impl_naive, inputs_naive)
    out_cuda = case.forward(impl_cuda, inputs_cuda)
    torch.testing.assert_close(out_naive, out_cuda, atol=atol, rtol=rtol)

    # Differentiate w.r.t. the declared grad inputs plus every module parameter.
    wrt_naive = [*case.grad_inputs(inputs_naive), *impl_naive.parameters()]
    wrt_cuda = [*case.grad_inputs(inputs_cuda), *impl_cuda.parameters()]

    # One unit-norm cotangent per output tensor, shared by both backward
    # passes so the resulting gradients are directly comparable.
    outs = [out_naive] if isinstance(out_naive, torch.Tensor) else list(out_naive)
    cotangents = [_unit_grad_like(o).to(device) for o in outs]

    grads_naive = torch.autograd.grad(
        out_naive,
        wrt_naive,
        cotangents,
        retain_graph=False,
        create_graph=False,
        allow_unused=False,
    )
    grads_cuda = torch.autograd.grad(
        out_cuda,
        wrt_cuda,
        cotangents,
        retain_graph=False,
        create_graph=False,
        allow_unused=False,
    )
    torch.testing.assert_close(grads_naive, grads_cuda, atol=atol, rtol=rtol)
    print("✅ forward + backward match")