"""Performance tests for PolyNorm.

Checks activation.layers.PolyNorm against the pure-PyTorch poly_norm reference,
then times forward and backward passes of both on CUDA.
"""

import random
from dataclasses import dataclass

import pytest
import torch

import activation

from .test_poly_norm import poly_norm
from .utils import assert_close

# Benchmark cases: (input shape, dtype).
CASES = [
    ((1, 2048, 8192), torch.bfloat16),
    ((1, 2048, 16384), torch.bfloat16),
    ((1, 16384, 8192), torch.bfloat16),
    ((1, 16384, 16384), torch.bfloat16),
]
# Timed iterations per measurement.
NUM_REP = 100
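
# Note: these benchmarks run only when the `perf` marker is selected, e.g. with
# `pytest -m perf` (assuming the marker is registered in the project's pytest
# configuration).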


@dataclass
class PerfResult:
    type: str
    shape: tuple
    dtype: torch.dtype
    kernel_time_ms: float
    torch_time_ms: float

    @property
    def speedup(self) -> float:
        return self.torch_time_ms / self.kernel_time_ms


# Collected across all parametrized benchmark cases.
PERF_RESULTS: list[PerfResult] = []


@pytest.mark.parametrize("cases", CASES)
@pytest.mark.perf
def test_poly_norm(
    cases: tuple,
    do_plot: bool,
) -> None:
    random.seed(12345)
    torch.manual_seed(12345)

    torch.set_default_device("cuda")

    shape, dtype = cases
    x = torch.randn(shape, dtype=dtype, requires_grad=True)
    weight = torch.randn(3, dtype=dtype, requires_grad=True)
    bias = torch.randn(1, dtype=dtype, requires_grad=True)
    eps = 1e-05

    x.retain_grad()
    weight.retain_grad()
    bias.retain_grad()

    # Detached clones so the reference path accumulates its own gradients.
    x_ref = x.detach().clone().requires_grad_(True)
    weight_ref = weight.detach().clone().requires_grad_(True)
    bias_ref = bias.detach().clone().requires_grad_(True)

    torch_fn = poly_norm
    layer = activation.layers.PolyNorm(eps)
    layer.weight = torch.nn.Parameter(weight)
    layer.bias = torch.nn.Parameter(bias)

    # Forward correctness check before any timing.
    mod_out = layer(x)
    ref_out = torch_fn(x_ref, weight_ref, bias_ref, eps)
    assert_close(mod_out, ref_out)

    # Unit-norm upstream gradient for the backward check.
    out_grad = torch.rand_like(ref_out)
    out_grad = out_grad / out_grad.norm()

    ref_out.backward(out_grad, retain_graph=True)
    mod_out.backward(out_grad, retain_graph=True)

    # Backward correctness check; parameter grads get a looser 5% tolerance.
    assert_close(x.grad, x_ref.grad)
    assert_close(layer.bias.grad, bias_ref.grad, rtol=0.05)
    assert_close(layer.weight.grad, weight_ref.grad, rtol=0.05)

    def time_cuda(fn):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        # Warm up, then average NUM_REP timed calls between CUDA events.
        for _ in range(5):
            fn()
        start.record()
        for _ in range(NUM_REP):
            fn()
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / NUM_REP

    kernel_time_ms = time_cuda(lambda: layer(x))
    torch_fn_time = time_cuda(lambda: torch_fn(x_ref, weight_ref, bias_ref, eps))

    PERF_RESULTS.append(
        PerfResult(
            type="forward",
            shape=shape,
            dtype=dtype,
            kernel_time_ms=kernel_time_ms,
            torch_time_ms=torch_fn_time,
        )
    )

    kernel_time_ms = time_cuda(lambda: mod_out.backward(out_grad, retain_graph=True))
    torch_fn_time = time_cuda(lambda: ref_out.backward(out_grad, retain_graph=True))

    PERF_RESULTS.append(
        PerfResult(
            type="backward",
            shape=shape,
            dtype=dtype,
            kernel_time_ms=kernel_time_ms,
            torch_time_ms=torch_fn_time,
        )
    )
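

# Illustrative sketch (an assumption, not part of the test flow above): one way
# the collected PerfResult entries could be summarized after a perf run.
def report_perf_results() -> None:
    for r in PERF_RESULTS:
        print(
            f"{r.type} {r.shape} {r.dtype}: "
            f"kernel {r.kernel_time_ms:.3f} ms, "
            f"torch {r.torch_time_ms:.3f} ms, "
            f"speedup {r.speedup:.2f}x"
        )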