Spaces:
Sleeping
Sleeping
| import math | |
| import pytest | |
| import torch | |
| import torch.nn.functional as F | |
| from einops import rearrange, repeat | |
| from flash_attn.ops.layer_norm import ( | |
| DropoutAddLayerNorm, | |
| dropout_add_layer_norm, | |
| dropout_add_layer_norm_parallel_residual, | |
| dropout_add_layer_norm_subset, | |
| ) | |
| from flash_attn.ops.rms_norm import ( | |
| DropoutAddRMSNorm, | |
| dropout_add_rms_norm, | |
| dropout_add_rms_norm_parallel_residual, | |
| dropout_add_rms_norm_subset, | |
| ) | |
# Apex is optional: the RMSNorm tests call pytest.skip() when these names are None.
try:
    from apex.normalization import FusedRMSNorm
    from apex.normalization.fused_layer_norm import fused_rms_norm_affine
except ImportError:
    # Only a missing/unimportable Apex should disable the RMSNorm comparisons; a bare
    # `except:` would also swallow KeyboardInterrupt/SystemExit and unrelated errors.
    FusedRMSNorm, fused_rms_norm_affine = None, None

# True when the current CUDA device has compute capability major version >= 8.
# Guarded by is_available() so that importing this module on a machine without CUDA
# does not raise at import time.
# NOTE(review): not referenced in the visible code; presumably used to gate bf16 cases.
is_sm8x = torch.cuda.is_available() and torch.cuda.get_device_capability("cuda")[0] >= 8
# @pytest.mark.parametrize('has_colscale', [False])
# @pytest.mark.parametrize('has_rowscale', [True])
# @pytest.mark.parametrize('has_residual', [False])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_training(
    hidden_size,
    input_dtype,
    residual_dtype,
    weight_dtype,
    dropout_p,
    has_residual,
    has_rowscale,
    has_colscale,
    is_rms_norm,
):
    """Check fused dropout + residual-add + LayerNorm/RMSNorm in training mode.

    Three computations are compared:
      * ours   -- ``dropout_add_layer_norm`` / ``dropout_add_rms_norm``;
      * ``_pt``  -- ``torch.nn.LayerNorm`` / Apex ``FusedRMSNorm`` run in the test dtypes,
        replaying the dropout mask returned by the fused op;
      * ``_ref`` -- the same module run entirely in float32.
    For the forward output and every gradient, the fused op's max abs error against
    ``_ref`` must be within a small multiple of the ``_pt`` error plus an absolute slack.

    NOTE(review): only commented-out single-value ``parametrize`` lines are visible above;
    without active decorators pytest resolves these arguments as fixtures -- confirm the
    real decorators are present.
    """
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    if is_rms_norm and FusedRMSNorm is None:
        pytest.skip()  # We need Apex's FusedRMSNorm to test
    layer_norm_cls = torch.nn.LayerNorm if not is_rms_norm else FusedRMSNorm
    our_layer_norm_cls = DropoutAddLayerNorm if not is_rms_norm else DropoutAddRMSNorm
    our_layer_norm_func = dropout_add_layer_norm if not is_rms_norm else dropout_add_rms_norm
    device = "cuda"
    # set seed (the order of the random calls below must stay fixed for reproducibility)
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None
    if has_rowscale:
        # Each (batch, position) row is kept with probability survival_rate and rescaled
        # by 1 / survival_rate, or zeroed otherwise.
        rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
        survival_rate = 0.87
        rowscale = rowscale.bernoulli_(survival_rate) / survival_rate
        x0_scaled_pt = x0_pt * rearrange(rowscale, "... -> ... 1")
        x0_scaled_ref = x0_ref * rearrange(rowscale, "... -> ... 1")
    else:
        rowscale = None
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref
    if has_colscale:
        x0_scaled_pt = x0_scaled_pt * colscale_pt
        x0_scaled_ref = x0_scaled_ref * colscale_ref
    model_pt = layer_norm_cls(hidden_size).to(device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    if not is_rms_norm:
        torch.nn.init.normal_(model_pt.bias)
    model_ref = layer_norm_cls(hidden_size).to(device=device, dtype=torch.float32)
    model = our_layer_norm_cls(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
    with torch.no_grad():
        # All three variants share the same randomly initialised affine parameters.
        model.weight.copy_(model_pt.weight)
        model_ref.weight.copy_(model_pt.weight)
        if not is_rms_norm:
            model.bias.copy_(model_pt.bias)
            model_ref.bias.copy_(model_pt.bias)
    # NOTE(review): fp32 residual is only requested when no residual tensor is passed;
    # presumably the op otherwise follows res.dtype -- confirm against the op.
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, dmask = our_layer_norm_func(
        x0,
        res,
        model.weight,
        model.bias,
        model.p,
        model.eps,
        rowscale=rowscale,
        layerscale=colscale,
        residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True,
    )
    assert out.dtype == input_dtype
    print(f"Actual dropout fraction: {1 - dmask.float().mean().item()}")
    # Rebuild the residual stream for the reference paths using the fused op's dropout mask.
    if has_residual:
        residual_pt = (
            (x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + res_pt.float()
        ).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + res_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(
            dtype=residual_dtype
        )
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)
    out_ref = model_ref(residual_ref)
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4

    g = torch.randn_like(out) / batch_size
    out_pt.backward(g)
    out.backward(g)
    out_ref.backward(g)
    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (
            res_pt.grad - res_ref.grad
        ).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 3 * (
        model_pt.weight.grad - model_ref.weight.grad
    ).abs().max() + 3e-5
    if not is_rms_norm:
        assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (
            model_pt.bias.grad - model_ref.bias.grad
        ).abs().max() + 3e-5
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (
            colscale_pt.grad - colscale_ref.grad
        ).abs().max() + 2e-4
def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype):
    """Check ``DropoutAddLayerNorm`` in eval mode against ``torch.nn.LayerNorm``.

    All modules are put in ``eval()`` and the reference residual is computed as plain
    ``x0 + res``. The test therefore asserts that the configured dropout (p=0.37) has no
    effect in eval mode. The fused module's max abs error against the float32 reference
    must be within 4x the error of ``torch.nn.LayerNorm`` run in the test dtypes, plus 1e-4.

    NOTE(review): no active ``parametrize`` decorators are visible for this test -- confirm
    they exist, otherwise pytest resolves the arguments as fixtures.
    """
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = "cuda"
    dropout_p = 0.37
    # set seed
    torch.random.manual_seed(0)
    batch_size = 32
    seqlen = 512
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
    res = res_pt.detach().clone().requires_grad_()
    res_ref = res_pt.detach().clone().float().requires_grad_()
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    with torch.no_grad():
        # Share the same randomly initialised affine parameters across all three variants.
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)
    model_pt.eval()
    model.eval()
    model_ref.eval()
    out = model(x0, res)
    residual_pt = (x0_pt.float() + res_pt.float()).to(dtype=residual_dtype)
    residual_ref = x0_ref + res_ref
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
    out_ref = model_ref(residual_ref)
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_rowscale', [False])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_prenorm_training(
    hidden_size,
    input_dtype,
    residual_dtype,
    weight_dtype,
    dropout_p,
    has_residual,
    has_rowscale,
    has_colscale,
    is_rms_norm,
):
    """Check the fused dropout + add + LayerNorm/RMSNorm op with ``prenorm=True`` (training).

    With ``prenorm=True`` the fused op also returns the pre-norm residual stream. Both
    outputs are compared against:
      * ``_pt``  -- ``torch.nn.LayerNorm`` / Apex ``FusedRMSNorm`` run in the test dtypes,
        replaying the dropout mask returned by the fused op;
      * ``_ref`` -- the same module run entirely in float32.
    The fused op's max abs error against ``_ref`` (outputs and gradients) must be within a
    small multiple of the ``_pt`` error plus an absolute slack.

    NOTE(review): only commented-out single-value ``parametrize`` lines are visible above;
    confirm the real decorators are present, otherwise pytest resolves these as fixtures.
    """
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    if is_rms_norm and FusedRMSNorm is None:
        pytest.skip()  # We need Apex's FusedRMSNorm to test
    layer_norm_cls = torch.nn.LayerNorm if not is_rms_norm else FusedRMSNorm
    our_layer_norm_cls = DropoutAddLayerNorm if not is_rms_norm else DropoutAddRMSNorm
    our_layer_norm_func = dropout_add_layer_norm if not is_rms_norm else dropout_add_rms_norm
    device = "cuda"
    # set seed (the order of the random calls below must stay fixed for reproducibility)
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None
    if has_rowscale:
        # Each (batch, position) row is kept with probability survival_rate and rescaled
        # by 1 / survival_rate, or zeroed otherwise.
        rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
        survival_rate = 0.87
        rowscale = rowscale.bernoulli_(survival_rate) / survival_rate
        x0_scaled_pt = x0_pt * rearrange(rowscale, "... -> ... 1")
        x0_scaled_ref = x0_ref * rearrange(rowscale, "... -> ... 1")
    else:
        rowscale = None
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref
    if has_colscale:
        x0_scaled_pt = x0_scaled_pt * colscale_pt
        x0_scaled_ref = x0_scaled_ref * colscale_ref
    model_pt = layer_norm_cls(hidden_size).to(device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    if not is_rms_norm:
        torch.nn.init.normal_(model_pt.bias)
    model_ref = layer_norm_cls(hidden_size).to(device=device, dtype=torch.float32)
    model = our_layer_norm_cls(
        hidden_size, prenorm=True, p=dropout_p, device=device, dtype=weight_dtype
    )
    with torch.no_grad():
        # All three variants share the same randomly initialised affine parameters.
        model.weight.copy_(model_pt.weight)
        model_ref.weight.copy_(model_pt.weight)
        if not is_rms_norm:
            model.bias.copy_(model_pt.bias)
            model_ref.bias.copy_(model_pt.bias)
    # NOTE(review): fp32 residual is only requested when no residual tensor is passed;
    # presumably the op otherwise follows res.dtype -- confirm against the op.
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, residual, dmask = our_layer_norm_func(
        x0,
        res,
        model.weight,
        model.bias,
        model.p,
        model.eps,
        rowscale=rowscale,
        layerscale=colscale,
        prenorm=True,
        residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True,
    )
    print(f"Actual dropout fraction: {1 - dmask.float().mean().item()}")
    # Rebuild the residual stream for the reference paths using the fused op's dropout mask.
    if has_residual:
        residual_pt = (
            (x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + res_pt.float()
        ).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + res_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(
            dtype=residual_dtype
        )
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)
    out_ref = model_ref(residual_ref)
    assert out.dtype == input_dtype
    assert residual.dtype == residual_dtype
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (
        residual_pt - residual_ref
    ).abs().max() + 1e-4

    g = torch.randn_like(out) / batch_size
    # Mix the returned residual into the loss so that the backward pass also propagates
    # through the residual output, not only through the normalized output.
    (out_pt * F.sigmoid(residual_pt)).backward(g)
    (out * F.sigmoid(residual)).backward(g)
    (out_ref * F.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g)
    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (
            res_pt.grad - res_ref.grad
        ).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (
        model_pt.weight.grad - model_ref.weight.grad
    ).abs().max() + 2e-4
    if not is_rms_norm:
        assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (
            model_pt.bias.grad - model_ref.bias.grad
        ).abs().max() + 2e-4
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (
            colscale_pt.grad - colscale_ref.grad
        ).abs().max() + 2e-4
def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype):
    """Check ``DropoutAddLayerNorm(prenorm=True)`` in eval mode against ``torch.nn.LayerNorm``.

    All modules are put in ``eval()`` and the reference residual is computed as plain
    ``x0 + res``. The test therefore asserts that the configured dropout (p=0.37) has no
    effect in eval mode. Both the normalized output and the returned residual must have a
    max abs error against the float32 reference within 4x the error of the same computation
    done in the test dtypes, plus 1e-4.

    NOTE(review): no active ``parametrize`` decorators are visible for this test -- confirm
    they exist, otherwise pytest resolves the arguments as fixtures.
    """
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = "cuda"
    dropout_p = 0.37
    # set seed
    torch.random.manual_seed(0)
    batch_size = 32
    seqlen = 512
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
    res = res_pt.detach().clone().requires_grad_()
    res_ref = res_pt.detach().clone().float().requires_grad_()
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model = DropoutAddLayerNorm(
        hidden_size, prenorm=True, p=dropout_p, device=device, dtype=weight_dtype
    )
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    with torch.no_grad():
        # Share the same randomly initialised affine parameters across all three variants.
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)
    model_pt.eval()
    model.eval()
    model_ref.eval()
    out, residual = model(x0, res)
    residual_pt = (x0_pt.float() + res_pt.float()).to(dtype=residual_dtype)
    residual_ref = x0_ref + res_ref
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
    out_ref = model_ref(residual_ref)
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (
        residual_pt - residual_ref
    ).abs().max() + 1e-4
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_subset_training(
    hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p, has_residual, has_colscale
):
    """Check ``dropout_add_layer_norm_subset`` (drop-path row subsets, post-norm) in training.

    Whole sequences of ``x0`` and of the output are dropped with probability 0.4 and kept rows
    are rescaled by ``1 / (1 - 0.4)``. ``x0`` is passed to the fused op with only its kept
    sequences, and the op writes only the kept output rows. The references are computed on
    the full batch with dropped ``x0`` rows zeroed, then sliced to the kept output rows. The
    fused op's max abs error against the float32 reference must be within a small multiple of
    the error of ``torch.nn.LayerNorm`` run in the test dtypes, plus an absolute slack.

    NOTE(review): only commented-out single-value ``parametrize`` lines are visible above;
    confirm the real decorators are present, otherwise pytest resolves these as fixtures.
    """
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = "cuda"
    # set seed (the order of the random calls below must stay fixed for reproducibility)
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    drop_path_rate = 0.4
    drop_path_scale = 1 / (1 - drop_path_rate)

    def generate_droppath_masks(batch_size, seqlen, drop_path_rate, device):
        """Sample a per-sequence keep mask and the matching row-subset index.

        Returns ``(mask_batch, numrows, subset)``:
          * ``mask_batch`` -- bool tensor of shape (batch_size,) on ``device``;
          * ``numrows``    -- number of kept sequences times ``seqlen`` (a Python int);
          * ``subset``     -- int32 tensor of shape (batch_size, seqlen) holding, for each kept
            row, its 1-based position among the kept rows (row-major order), and 0 for
            dropped rows.
        """
        # Do it on CPU so we can get the numrows (with .item()) without GPU-CPU sync
        mask_batch = torch.rand(batch_size) < 1 - drop_path_rate
        numrows = (mask_batch).sum().item() * seqlen
        mask_batch = mask_batch.to(device=device, non_blocking=True)
        mask_batch_seqlen = repeat(mask_batch, "b -> (b s)", s=seqlen)
        subset = torch.cumsum(mask_batch_seqlen, dim=0, dtype=torch.int32).masked_fill_(
            ~mask_batch_seqlen, 0
        )
        return mask_batch, numrows, rearrange(subset, "(b s) -> b s", b=batch_size)

    x0_mask_batch, _, x0_subset = generate_droppath_masks(
        batch_size, seqlen, drop_path_rate, device
    )
    out_mask_batch, out_numrows, out_subset = generate_droppath_masks(
        batch_size, seqlen, drop_path_rate, device
    )
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    # The fused op only receives the kept sequences of x0.
    x0 = x0_pt.detach().clone()[x0_mask_batch].requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        res_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None
    if has_colscale:
        x0_scaled_pt = x0_pt * colscale_pt
        x0_scaled_ref = x0_ref * colscale_ref
    else:
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    model = DropoutAddLayerNorm(
        hidden_size, prenorm=False, p=dropout_p, device=device, dtype=weight_dtype
    )
    with torch.no_grad():
        # All three variants share the same randomly initialised affine parameters.
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)
    # NOTE(review): fp32 residual is only requested when no residual tensor is passed;
    # presumably the op otherwise follows res.dtype -- confirm against the op.
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, dmask = dropout_add_layer_norm_subset(
        x0,
        res,
        model.weight,
        model.bias,
        model.p,
        model.eps,
        layerscale=colscale,
        x0_subset=x0_subset,
        out_subset=out_subset,
        rowscale_const=drop_path_scale,
        out_numrows=out_numrows,
        prenorm=False,
        residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True,
    )
    print(f"Actual dropout fraction: {1 - dmask.float().mean().item()}")
    # Reference: zero the dropped x0 sequences and rescale the kept ones by drop_path_scale.
    x0_scaled_pt = (
        x0_scaled_pt.masked_fill(repeat(~x0_mask_batch, "b -> b s d", s=seqlen, d=hidden_size), 0)
        * drop_path_scale
    )
    x0_scaled_ref = (
        x0_scaled_ref.masked_fill(repeat(~x0_mask_batch, "b -> b s d", s=seqlen, d=hidden_size), 0)
        * drop_path_scale
    )
    # Scatter the dropout mask of the kept sequences back into a full-batch mask.
    dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
    dmask_expanded[x0_mask_batch] = dmask
    if has_residual:
        residual_pt = (
            (x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + res_pt.float()
        ).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + res_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(
            dtype=residual_dtype
        )
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)[out_mask_batch]
    out_ref = model_ref(residual_ref)[out_mask_batch]
    assert out.dtype == input_dtype
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4

    g = torch.randn_like(out) / batch_size
    out_pt.backward(g)
    out.backward(g)
    out_ref.backward(g)
    assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[
        x0_mask_batch
    ].abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (
            res_pt.grad - res_ref.grad
        ).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (
        model_pt.weight.grad - model_ref.weight.grad
    ).abs().max() + 2e-4
    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (
        model_pt.bias.grad - model_ref.bias.grad
    ).abs().max() + 2e-4
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (
            colscale_pt.grad - colscale_ref.grad
        ).abs().max() + 2e-4
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_subset_prenorm_training(
    hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p, has_residual, has_colscale
):
    """Check ``dropout_add_layer_norm_subset`` with ``prenorm=True`` in training mode.

    Whole sequences of ``x0`` and of the output are dropped with probability 0.4 and kept rows
    are rescaled by ``1 / (1 - 0.4)``. ``x0`` is passed to the fused op with only its kept
    sequences; the op returns the kept output rows plus the full pre-norm residual stream.
    The references are computed on the full batch with dropped ``x0`` rows zeroed. Outputs
    and gradients must have a max abs error against the float32 reference within a small
    multiple of the error of ``torch.nn.LayerNorm`` run in the test dtypes, plus a slack.

    NOTE(review): only commented-out single-value ``parametrize`` lines are visible above;
    confirm the real decorators are present, otherwise pytest resolves these as fixtures.
    """
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = "cuda"
    # set seed (the order of the random calls below must stay fixed for reproducibility)
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    drop_path_rate = 0.4
    drop_path_scale = 1 / (1 - drop_path_rate)

    def generate_droppath_masks(batch_size, seqlen, drop_path_rate, device):
        """Sample a per-sequence keep mask and the matching row-subset index.

        Returns ``(mask_batch, numrows, subset)``:
          * ``mask_batch`` -- bool tensor of shape (batch_size,) on ``device``;
          * ``numrows``    -- number of kept sequences times ``seqlen`` (a Python int);
          * ``subset``     -- int32 tensor of shape (batch_size, seqlen) holding, for each kept
            row, its 1-based position among the kept rows (row-major order), and 0 for
            dropped rows.
        """
        # Do it on CPU so we can get the numrows (with .item()) without GPU-CPU sync
        mask_batch = torch.rand(batch_size) < 1 - drop_path_rate
        numrows = (mask_batch).sum().item() * seqlen
        mask_batch = mask_batch.to(device=device, non_blocking=True)
        mask_batch_seqlen = repeat(mask_batch, "b -> (b s)", s=seqlen)
        subset = torch.cumsum(mask_batch_seqlen, dim=0, dtype=torch.int32).masked_fill_(
            ~mask_batch_seqlen, 0
        )
        return mask_batch, numrows, rearrange(subset, "(b s) -> b s", b=batch_size)

    x0_mask_batch, _, x0_subset = generate_droppath_masks(
        batch_size, seqlen, drop_path_rate, device
    )
    out_mask_batch, out_numrows, out_subset = generate_droppath_masks(
        batch_size, seqlen, drop_path_rate, device
    )
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    # The fused op only receives the kept sequences of x0.
    x0 = x0_pt.detach().clone()[x0_mask_batch].requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        res_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None
    if has_colscale:
        x0_scaled_pt = x0_pt * colscale_pt
        x0_scaled_ref = x0_ref * colscale_ref
    else:
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    model = DropoutAddLayerNorm(
        hidden_size, prenorm=True, p=dropout_p, device=device, dtype=weight_dtype
    )
    with torch.no_grad():
        # All three variants share the same randomly initialised affine parameters.
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)
    # NOTE(review): fp32 residual is only requested when no residual tensor is passed;
    # presumably the op otherwise follows res.dtype -- confirm against the op.
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, residual, dmask = dropout_add_layer_norm_subset(
        x0,
        res,
        model.weight,
        model.bias,
        model.p,
        model.eps,
        layerscale=colscale,
        x0_subset=x0_subset,
        out_subset=out_subset,
        rowscale_const=drop_path_scale,
        out_numrows=out_numrows,
        prenorm=True,
        residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True,
    )
    print(f"Actual dropout fraction: {1 - dmask.float().mean().item()}")
    # Reference: zero the dropped x0 sequences and rescale the kept ones by drop_path_scale.
    x0_scaled_pt = (
        x0_scaled_pt.masked_fill(repeat(~x0_mask_batch, "b -> b s d", s=seqlen, d=hidden_size), 0)
        * drop_path_scale
    )
    x0_scaled_ref = (
        x0_scaled_ref.masked_fill(repeat(~x0_mask_batch, "b -> b s d", s=seqlen, d=hidden_size), 0)
        * drop_path_scale
    )
    # Scatter the dropout mask of the kept sequences back into a full-batch mask.
    dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
    dmask_expanded[x0_mask_batch] = dmask
    if has_residual:
        residual_pt = (
            (x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + res_pt.float()
        ).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + res_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(
            dtype=residual_dtype
        )
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)[out_mask_batch]
    out_ref = model_ref(residual_ref)[out_mask_batch]
    assert out.dtype == input_dtype
    assert residual.dtype == residual_dtype
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (
        residual_pt - residual_ref
    ).abs().max() + 1e-4

    g = torch.randn_like(out) / batch_size
    # The sigmoid(residual) factor sends gradient through the residual rows of the kept
    # outputs; the batch-mean term sends gradient into every residual row, kept or not.
    (out_pt * F.sigmoid(residual_pt[out_mask_batch]) + residual_pt.mean(0, keepdim=True)).backward(
        g
    )
    (out * F.sigmoid(residual[out_mask_batch]) + residual.mean(0, keepdim=True)).backward(g)
    (
        out_ref * F.sigmoid(residual_ref[out_mask_batch].to(dtype=residual_dtype))
        + residual_ref.mean(0, keepdim=True)
    ).backward(g)
    assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[
        x0_mask_batch
    ].abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (
            res_pt.grad - res_ref.grad
        ).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (
        model_pt.weight.grad - model_ref.weight.grad
    ).abs().max() + 2e-4
    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (
        model_pt.bias.grad - model_ref.bias.grad
    ).abs().max() + 2e-4
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (
            colscale_pt.grad - colscale_ref.grad
        ).abs().max() + 2e-4
| # @pytest.mark.parametrize('is_rms_norm', [False]) | |
| # @pytest.mark.parametrize('tied_norm', [False]) | |
| # @pytest.mark.parametrize('has_residual', [False]) | |
| # @pytest.mark.parametrize('has_x1', [True]) | |
| # @pytest.mark.parametrize('dropout_p', [0.0]) | |
| # @pytest.mark.parametrize('weight_dtype', [torch.float16]) | |
| # @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)]) | |
| # @pytest.mark.parametrize('hidden_size', [256]) | |
| def test_dropout_layer_norm_parallel_residual_training( | |
| hidden_size, | |
| input_dtype, | |
| residual_dtype, | |
| weight_dtype, | |
| dropout_p, | |
| has_x1, | |
| has_residual, | |
| tied_norm, | |
| is_rms_norm, | |
| ): | |
| if weight_dtype == torch.float16 and input_dtype == torch.bfloat16: | |
| pytest.skip() # Not supported | |
| if is_rms_norm and fused_rms_norm_affine is None: | |
| pytest.skip() # We need Apex's FusedRMSNorm to test | |
| our_layer_norm_func = ( | |
| dropout_add_layer_norm_parallel_residual | |
| if not is_rms_norm | |
| else dropout_add_rms_norm_parallel_residual | |
| ) | |
| device = "cuda" | |
| # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4) | |
| rtol, atol = (1e-3, 1e-4) | |
| # set seed | |
| torch.random.manual_seed(0) | |
| batch_size = 8 | |
| seqlen = 512 | |
| x0_pt = torch.randn( | |
| batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True | |
| ) | |
| x0 = x0_pt.detach().clone().requires_grad_() | |
| x0_ref = x0_pt.detach().clone().float().requires_grad_() | |
| if has_x1: | |
| x1_pt = torch.randn( | |
| batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True | |
| ) | |
| x1 = x1_pt.detach().clone().requires_grad_() | |
| x1_ref = x1_pt.detach().clone().float().requires_grad_() | |
| else: | |
| x1 = None | |
| if has_residual: | |
| res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True) | |
| res = res_pt.detach().clone().requires_grad_() | |
| res_ref = res_pt.detach().clone().float().requires_grad_() | |
| else: | |
| res = None | |
| weight0 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True) | |
| bias0 = ( | |
| torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True) | |
| if not is_rms_norm | |
| else None | |
| ) | |
| weight0_pt = weight0.detach().clone().requires_grad_() | |
| weight0_ref = weight0.detach().clone().float().requires_grad_() | |
| bias0_pt = bias0.detach().clone().requires_grad_() if bias0 is not None else None | |
| bias0_ref = bias0.detach().clone().float().requires_grad_() if bias0 is not None else None | |
| if not tied_norm: | |
| weight1 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True) | |
| bias1 = ( | |
| torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True) | |
| if not is_rms_norm | |
| else None | |
| ) | |
| weight1_pt = weight1.detach().clone().requires_grad_() | |
| weight1_ref = weight1.detach().clone().float().requires_grad_() | |
| bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None | |
| bias1_ref = bias1.detach().clone().float().requires_grad_() if bias1 is not None else None | |
| else: | |
| weight1, bias1 = None, None | |
| epsilon = 1e-5 | |
| residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32 | |
| out0, out1, dmask0, dmask1 = our_layer_norm_func( | |
| x0, | |
| x1, | |
| res, | |
| weight0, | |
| bias0, | |
| weight1, | |
| bias1, | |
| dropout_p, | |
| epsilon, | |
| residual_in_fp32=residual_in_fp32, | |
| return_dropout_mask=True, | |
| ) | |
| assert out0.dtype == input_dtype | |
| if not tied_norm: | |
| assert out1.dtype == input_dtype | |
| print(f"Actual dropout fraction: {1 - dmask0.float().mean().item()}") | |
| if has_residual: | |
| if has_x1: | |
| residual_pt = ( | |
| (x0_pt.float() * dmask0.float()) / (1 - dropout_p) | |
| + (x1_pt.float() * dmask1.float()) / (1 - dropout_p) | |
| + res_pt.float() | |
| ).to(dtype=residual_dtype) | |
| residual_ref = ( | |
| (x0_ref * dmask0.float()) / (1 - dropout_p) | |
| + (x1_ref * dmask1.float()) / (1 - dropout_p) | |
| ) + res_ref | |
| else: | |
| residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p) + res_pt.float()).to( | |
| dtype=residual_dtype | |
| ) | |
| residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + res_ref | |
| else: | |
| if has_x1: | |
| residual_pt = ( | |
| (x0_pt.float() * dmask0.float()) / (1 - dropout_p) | |
| + (x1_pt.float() * dmask1.float()) / (1 - dropout_p) | |
| ).to(dtype=residual_dtype) | |
| residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + ( | |
| x1_ref * dmask1.float() | |
| ) / (1 - dropout_p) | |
| else: | |
| residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)).to( | |
| dtype=residual_dtype | |
| ) | |
| residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) | |
| if not is_rms_norm: | |
| out0_pt = F.layer_norm( | |
| residual_pt.to(dtype=weight_dtype), (hidden_size,), weight0_pt, bias0_pt, eps=epsilon | |
| ).to(dtype=input_dtype) | |
| out0_ref = F.layer_norm(residual_ref, (hidden_size,), weight0_ref, bias0_ref, eps=epsilon) | |
| if not tied_norm: | |
| out1_pt = F.layer_norm( | |
| residual_pt.to(dtype=weight_dtype), | |
| (hidden_size,), | |
| weight1_pt, | |
| bias1_pt, | |
| eps=epsilon, | |
| ).to(dtype=input_dtype) | |
| out1_ref = F.layer_norm( | |
| residual_ref, (hidden_size,), weight1_ref, bias1_ref, eps=epsilon | |
| ) | |
| else: | |
| out0_pt = fused_rms_norm_affine( | |
| residual_pt.to(dtype=weight_dtype), weight0_pt, (hidden_size,), eps=epsilon | |
| ).to(dtype=input_dtype) | |
| out0_ref = fused_rms_norm_affine(residual_ref, weight0_ref, (hidden_size,), eps=epsilon) | |
| if not tied_norm: | |
| out1_pt = fused_rms_norm_affine( | |
| residual_pt.to(dtype=weight_dtype), weight1_pt, (hidden_size,), eps=epsilon | |
| ).to(dtype=input_dtype) | |
| out1_ref = fused_rms_norm_affine(residual_ref, weight1_ref, (hidden_size,), eps=epsilon) | |
| assert (out0 - out0_ref).abs().max() <= 4 * (out0_pt - out0_ref).abs().max() + 1e-4 | |
| if not tied_norm: | |
| assert (out1 - out1_ref).abs().max() <= 4 * (out1_pt - out1_ref).abs().max() + 1e-4 | |
| g0 = torch.randn_like(out0) / batch_size | |
| if tied_norm: | |
| out0.backward(g0) | |
| out0_pt.backward(g0) | |
| out0_ref.backward(g0) | |
| else: | |
| g1 = torch.randn_like(out1) / batch_size | |
| (out0 * g0 + out1 * g1).sum().backward() | |
| (out0_pt * g0 + out1_pt * g1).sum().backward() | |
| (out0_ref * g0 + out1_ref * g1).sum().backward() | |
| assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4 | |
| if has_x1: | |
| assert (x1.grad - x1_ref.grad).abs().max() <= 4 * ( | |
| x1_pt.grad - x1_ref.grad | |
| ).abs().max() + 1e-4 | |
| if has_residual: | |
| assert (res.grad - res_ref.grad).abs().max() <= 4 * ( | |
| res_pt.grad - res_ref.grad | |
| ).abs().max() + 1e-4 | |
| assert (weight0.grad - weight0_ref.grad).abs().max() <= 3 * ( | |
| weight0_pt.grad - weight0_ref.grad | |
| ).abs().max() + 3e-5 | |
| if not is_rms_norm: | |
| assert (bias0.grad - bias0_ref.grad).abs().max() <= 2 * ( | |
| bias0_pt.grad - bias0_ref.grad | |
| ).abs().max() + 3e-5 | |
| if not tied_norm: | |
| assert (weight1.grad - weight1_ref.grad).abs().max() <= 3 * ( | |
| weight1_pt.grad - weight1_ref.grad | |
| ).abs().max() + 3e-5 | |
| if not is_rms_norm: | |
| assert (bias1.grad - bias1_ref.grad).abs().max() <= 2 * ( | |
| bias1_pt.grad - bias1_ref.grad | |
| ).abs().max() + 3e-5 | |
| # @pytest.mark.parametrize('is_rms_norm', [False]) | |
| # @pytest.mark.parametrize('tied_norm', [False]) | |
| # @pytest.mark.parametrize('has_residual', [False]) | |
| # @pytest.mark.parametrize('has_x1', [True]) | |
| # @pytest.mark.parametrize('dropout_p', [0.0]) | |
| # @pytest.mark.parametrize('weight_dtype', [torch.float16]) | |
| # @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)]) | |
| # @pytest.mark.parametrize('hidden_size', [256]) | |
def test_dropout_layer_norm_parallel_residual_prenorm_training(
    hidden_size,
    input_dtype,
    residual_dtype,
    weight_dtype,
    dropout_p,
    has_x1,
    has_residual,
    tied_norm,
    is_rms_norm,
):
    """Test the fused parallel-residual dropout + add + norm op with ``prenorm=True``.

    The fused op is called on ``x0`` (optionally also ``x1`` and an incoming residual
    ``res``) and returns ``out0``, ``out1``, the pre-norm ``residual`` and the dropout
    masks. A reference residual ``dropout(x0) [+ dropout(x1)] [+ res]`` is rebuilt from
    those same masks and normalized with ``(weight0, bias0)`` and, unless ``tied_norm``,
    with ``(weight1, bias1)``. Outputs and gradients are compared against two references:

    * ``*_pt``: the same computation in plain PyTorch at the test's dtypes;
    * ``*_ref``: the same computation upcast to float32.

    The fused result may be at most a small multiple of the plain-PyTorch error away from
    the float32 reference, plus a small absolute slack.

    Args:
        hidden_size: size of the normalized (last) dimension.
        input_dtype: dtype of ``x0``/``x1``; ``out0``/``out1`` must come back in this dtype.
        residual_dtype: dtype of ``res`` and of the plain-PyTorch residual.
        weight_dtype: dtype of the norm weights and biases.
        dropout_p: dropout probability passed to the fused op.
        has_x1: whether a second input ``x1`` is added.
        has_residual: whether an incoming residual ``res`` is added.
        tied_norm: if True, only ``(weight0, bias0)`` is passed and ``out1`` is not checked.
        is_rms_norm: test the RMSNorm variant (reference via Apex) instead of LayerNorm.

    NOTE(review): the arguments are expected to come from ``@pytest.mark.parametrize``
    decorators; only commented-out ones are visible before this test here -- confirm the
    active decorators are present, otherwise pytest will look for fixtures of these names.
    """
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    if is_rms_norm and fused_rms_norm_affine is None:
        pytest.skip()  # We need Apex's FusedRMSNorm to test
    our_layer_norm_func = (
        dropout_add_layer_norm_parallel_residual
        if not is_rms_norm
        else dropout_add_rms_norm_parallel_residual
    )
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    epsilon = 1e-5

    def _clones(t):
        # Detached leaf copies of ``t``: one in its own dtype, one upcast to float32.
        if t is None:
            return None, None
        return t.detach().clone().requires_grad_(), t.detach().clone().float().requires_grad_()

    def _norm_params():
        # One (weight, bias) pair in weight_dtype; RMSNorm has no bias.
        weight = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        bias = (
            torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
            if not is_rms_norm
            else None
        )
        return weight, bias

    def _dropped(x, dmask):
        # Re-apply the fused op's own dropout mask with 1 / (1 - p) rescaling.
        return (x * dmask.float()) / (1 - dropout_p)

    def _norm(x, weight, bias):
        # Reference normalization: Apex RMSNorm or torch LayerNorm.
        if is_rms_norm:
            return fused_rms_norm_affine(x, weight, (hidden_size,), eps=epsilon)
        return F.layer_norm(x, (hidden_size,), weight, bias, eps=epsilon)

    def _assert_within(ours, pt, ref, factor, slack):
        # Our error vs. the fp32 reference must be <= factor * (PyTorch error) + slack.
        assert (ours - ref).abs().max() <= factor * (pt - ref).abs().max() + slack

    # Inputs and parameters. The random draws happen in a fixed order
    # (x0, x1, res, weight0, bias0, weight1, bias1) so results stay seed-reproducible.
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0, x0_ref = _clones(x0_pt)
    if has_x1:
        x1_pt = torch.randn(
            batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
        )
        x1, x1_ref = _clones(x1_pt)
    else:
        x1 = x1_pt = x1_ref = None
    if has_residual:
        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res, res_ref = _clones(res_pt)
    else:
        res = res_pt = res_ref = None
    weight0, bias0 = _norm_params()
    weight0_pt, weight0_ref = _clones(weight0)
    bias0_pt, bias0_ref = _clones(bias0)
    if not tied_norm:
        weight1, bias1 = _norm_params()
        weight1_pt, weight1_ref = _clones(weight1)
        bias1_pt, bias1_ref = _clones(bias1)
    else:
        weight1, bias1 = None, None

    # fp32 residual is requested only when there is no incoming ``res``
    # (presumably the residual dtype otherwise follows ``res``).
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out0, out1, residual, dmask0, dmask1 = our_layer_norm_func(
        x0,
        x1,
        res,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        prenorm=True,
        residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True,
    )
    assert out0.dtype == input_dtype
    if not tied_norm:
        assert out1.dtype == input_dtype
    print(f"Actual dropout fraction: {1 - dmask0.float().mean().item()}")

    # Reference residual = dropout(x0) [+ dropout(x1)] [+ res], accumulated left to right.
    residual_pt = _dropped(x0_pt.float(), dmask0)
    residual_ref = _dropped(x0_ref, dmask0)
    if has_x1:
        residual_pt = residual_pt + _dropped(x1_pt.float(), dmask1)
        residual_ref = residual_ref + _dropped(x1_ref, dmask1)
    if has_residual:
        residual_pt = residual_pt + res_pt.float()
        residual_ref = residual_ref + res_ref
    residual_pt = residual_pt.to(dtype=residual_dtype)

    # Each norm gets its own .to(weight_dtype) cast, so gradients flowing back into
    # residual_pt are accumulated per branch (do not hoist the cast).
    out0_pt = _norm(residual_pt.to(dtype=weight_dtype), weight0_pt, bias0_pt).to(
        dtype=input_dtype
    )
    out0_ref = _norm(residual_ref, weight0_ref, bias0_ref)
    if not tied_norm:
        out1_pt = _norm(residual_pt.to(dtype=weight_dtype), weight1_pt, bias1_pt).to(
            dtype=input_dtype
        )
        out1_ref = _norm(residual_ref, weight1_ref, bias1_ref)

    _assert_within(out0, out0_pt, out0_ref, 4, 1e-4)
    if not tied_norm:
        _assert_within(out1, out1_pt, out1_ref, 4, 1e-4)
    _assert_within(residual, residual_pt, residual_ref, 4, 1e-4)

    # Backprop through out0 * sigmoid(residual) so the returned pre-norm residual
    # also receives gradient.
    g0 = torch.randn_like(out0) / batch_size
    if tied_norm:
        (out0 * torch.sigmoid(residual)).backward(g0)
        (out0_pt * torch.sigmoid(residual_pt)).backward(g0)
        (out0_ref * torch.sigmoid(residual_ref)).backward(g0)
    else:
        g1 = torch.randn_like(out1) / batch_size
        (out0 * torch.sigmoid(residual) * g0 + out1 * g1).sum().backward()
        (out0_pt * torch.sigmoid(residual_pt) * g0 + out1_pt * g1).sum().backward()
        (out0_ref * torch.sigmoid(residual_ref) * g0 + out1_ref * g1).sum().backward()

    _assert_within(x0.grad, x0_pt.grad, x0_ref.grad, 4, 1e-4)
    if has_x1:
        _assert_within(x1.grad, x1_pt.grad, x1_ref.grad, 4, 1e-4)
    if has_residual:
        _assert_within(res.grad, res_pt.grad, res_ref.grad, 4, 1e-4)
    _assert_within(weight0.grad, weight0_pt.grad, weight0_ref.grad, 3, 3e-5)
    if not is_rms_norm:
        _assert_within(bias0.grad, bias0_pt.grad, bias0_ref.grad, 2, 3e-5)
    if not tied_norm:
        _assert_within(weight1.grad, weight1_pt.grad, weight1_ref.grad, 3, 3e-5)
        if not is_rms_norm:
            _assert_within(bias1.grad, bias1_pt.grad, bias1_ref.grad, 2, 3e-5)
def test_dropout_layer_norm_randomness():
    """Check that the fused op's dropout mask is driven by torch's global RNG seed.

    Two consecutive calls without reseeding are expected to sample different masks,
    while reseeding with the same value must reproduce the first mask exactly.
    """
    hidden_size, dtype, dropout_p, device = 256, torch.float32, 0.1, "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size, seqlen = 8, 512
    x0 = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=dtype, requires_grad=True
    )
    res = torch.randn_like(x0, dtype=dtype, requires_grad=True)
    model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=dtype)

    def draw_mask():
        # Run the fused op once on the fixed inputs and keep only the sampled mask.
        _, mask = dropout_add_layer_norm(
            x0, res, model.weight, model.bias, model.p, model.eps, return_dropout_mask=True
        )
        return mask

    torch.random.manual_seed(42)
    first = draw_mask()
    second = draw_mask()  # no reseed in between: expected to differ from `first`
    torch.random.manual_seed(42)
    replay = draw_mask()  # same seed as `first`: expected to match it exactly
    assert not torch.equal(first, second)
    assert torch.equal(first, replay)