Spaces:
Sleeping
Sleeping
| import math | |
| from functools import partial | |
| import pytest | |
| import torch | |
| import torch.nn.functional as F | |
| from einops import rearrange | |
| from flash_attn.ops.fused_dense import FusedDense, FusedMLP | |
| def test_fused_linear_bias(in_features, out_features, has_bias, return_residual, dtype): | |
| device = "cuda" | |
| rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3) | |
| # set seed | |
| torch.random.manual_seed(0) | |
| batch_size = 8 | |
| seqlen = 512 | |
| x_pt = torch.randn( | |
| batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True | |
| ) | |
| x = x_pt.detach().clone().requires_grad_() | |
| model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype) | |
| model = FusedDense( | |
| in_features, | |
| out_features, | |
| bias=has_bias, | |
| return_residual=return_residual, | |
| device=device, | |
| dtype=dtype, | |
| ) | |
| with torch.no_grad(): | |
| model.weight.copy_(model_pt.weight) | |
| if has_bias: | |
| model.bias.copy_(model_pt.bias) | |
| out_pt = model_pt(x_pt) | |
| if not return_residual: | |
| out = model(x) | |
| else: | |
| out, x_copy = model(x) | |
| x_copy = ( | |
| x_copy[..., :out_features] | |
| if out_features < in_features | |
| else F.pad(x_copy, (0, out_features - in_features)) | |
| ) | |
| x_pt_copy = ( | |
| x_pt[..., :out_features] | |
| if out_features < in_features | |
| else F.pad(x_pt, (0, out_features - in_features)) | |
| ) | |
| # Just add some random function of the residual | |
| out_pt = out_pt + F.gelu(x_pt_copy) | |
| out = out + F.gelu(x_copy) | |
| # with torch.no_grad(): | |
| # out_fl = F.linear(x_pt.float(), model.weight.float(), model.bias.float()).half() | |
| assert torch.allclose(out, out_pt, rtol=rtol, atol=atol) | |
| # If we don't divide by batch_size, the gradient gets a bit too large. | |
| g = torch.randn_like(out) / 32 | |
| out_pt.backward(g) | |
| out.backward(g) | |
| assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol) | |
| # The error for d_weight and d_bias is quite a bit higher | |
| assert torch.allclose(model.weight.grad, model_pt.weight.grad, rtol=rtol, atol=atol * 10) | |
| if has_bias: | |
| assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5) | |
| # @pytest.mark.parametrize('dtype', [torch.float16]) | |
| # @pytest.mark.parametrize('heuristic', ['auto']) | |
| # @pytest.mark.parametrize('checkpoint_lvl', [1]) | |
| # @pytest.mark.parametrize('return_residual', [False]) | |
| # @pytest.mark.parametrize('has_bias2', [True]) | |
| # @pytest.mark.parametrize('has_bias1', [True]) | |
| # @pytest.mark.parametrize('activation', ['relu']) | |
| # @pytest.mark.parametrize('out_features', [4096]) | |
| # @pytest.mark.parametrize('in_features', [1024]) | |
| def test_fused_mlp( | |
| in_features, | |
| out_features, | |
| activation, | |
| has_bias1, | |
| has_bias2, | |
| return_residual, | |
| checkpoint_lvl, | |
| heuristic, | |
| dtype, | |
| ): | |
| device = "cuda" | |
| rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3) | |
| # set seed | |
| torch.random.manual_seed(0) | |
| batch_size = 8 | |
| seqlen = 512 | |
| x_pt = torch.randn( | |
| batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True | |
| ) | |
| x = x_pt.detach().clone().requires_grad_() | |
| model_pt_fc1 = torch.nn.Linear( | |
| in_features, out_features, bias=has_bias1, device=device, dtype=dtype | |
| ) | |
| model_pt_fc2 = torch.nn.Linear( | |
| out_features, in_features, bias=has_bias2, device=device, dtype=dtype | |
| ) | |
| model = FusedMLP( | |
| in_features, | |
| out_features, | |
| in_features, | |
| activation=activation, | |
| bias1=has_bias1, | |
| bias2=has_bias2, | |
| return_residual=return_residual, | |
| checkpoint_lvl=checkpoint_lvl, | |
| heuristic=heuristic, | |
| device=device, | |
| dtype=dtype, | |
| ) | |
| with torch.no_grad(): | |
| model.fc1.weight.copy_(model_pt_fc1.weight) | |
| if has_bias1: | |
| model.fc1.bias.copy_(model_pt_fc1.bias) | |
| model.fc2.weight.copy_(model_pt_fc2.weight) | |
| if has_bias2: | |
| model.fc2.bias.copy_(model_pt_fc2.bias) | |
| activation_fn = ( | |
| partial(F.gelu, approximate="tanh") | |
| if activation == "gelu_approx" | |
| else partial(F.relu, inplace=True) | |
| ) | |
| out_pt = model_pt_fc2(activation_fn(model_pt_fc1(x_pt))) | |
| if not return_residual: | |
| out = model(x) | |
| else: | |
| out, x_copy = model(x) | |
| # Just add some random function of the residual | |
| out_pt = out_pt + F.gelu(x_pt) | |
| out = out + F.gelu(x_copy) | |
| assert torch.allclose(out, out_pt, rtol=rtol, atol=atol) | |
| # If we don't divide by batch_size, the gradient gets a bit too large. | |
| g = torch.randn_like(out) / 32 | |
| out_pt.backward(g) | |
| out.backward(g) | |
| # The error for relu is higher still | |
| if activation == "relu": | |
| atol = 1e-1 if dtype == torch.bfloat16 else 5e-2 | |
| assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol) | |
| # The error for d_weight and d_bias is quite a bit higher | |
| assert torch.allclose( | |
| model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10 | |
| ) | |
| if has_bias1: | |
| assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5) | |
| assert torch.allclose( | |
| model.fc2.weight.grad, model_pt_fc2.weight.grad, rtol=rtol, atol=atol * 10 | |
| ) | |
| if has_bias2: | |
| assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5) | |