# Tests that the flash-attn ViT implementation matches timm's reference ViT.
import re

import pytest
import torch
from timm.models.vision_transformer import vit_base_patch16_224

from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224
@pytest.mark.parametrize("fused_mlp", [False, True])
@pytest.mark.parametrize("optimized", [False, True])
def test_vit(optimized, fused_mlp):
    """Check that our implementation of ViT matches timm's implementation:
    the output of our forward pass in fp16 should be around the same as
    timm's forward pass in fp16, when compared to timm's forward pass in fp32.

    Parameters (pytest-parametrized):
        optimized: enable the fused flash-attention / fused-bias / fused
            dropout-add-LN kernels in the flash-attn model.
        fused_mlp: enable the fused MLP kernel in the flash-attn model.
    """
    dtype = torch.float16
    device = "cuda"
    kwargs = {}
    if optimized:
        kwargs = dict(use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True)
    kwargs["fused_mlp"] = fused_mlp
    model = flash_vit_base_patch16_224(**kwargs).to(device=device, dtype=dtype)
    # fp32 timm model is the ground-truth reference; fp16 timm model sets the
    # numerical-error baseline we allow our fp16 implementation to be judged against.
    model_ref = vit_base_patch16_224(pretrained=True).to(device=device)
    model_timm = vit_base_patch16_224(pretrained=True).to(device=device, dtype=dtype)
    # Same pretrained weights in all three models so outputs are comparable.
    model.load_state_dict(model_ref.state_dict())
    model.eval()
    model_ref.eval()
    model_timm.eval()

    torch.manual_seed(0)
    batch_size = 2
    x = torch.randn(batch_size, 3, 224, 224, device=device, dtype=dtype)
    out = model(x)
    out_timm = model_timm(x)
    out_ref = model_ref(x.float())

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}")
    print(f"timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}")

    # Our fp16 error (vs fp32 reference) must stay within a small multiple of
    # timm's own fp16 error; fused MLP changes the op order, so allow more slack.
    rtol = 2 if not fused_mlp else 8
    assert (out - out_ref).abs().max().item() < rtol * (out_timm - out_ref).abs().max().item()