Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | |
| # | |
| # This source code is licensed under the BSD license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import pytest | |
| import torch | |
| from xformers.components import NormalizationType, PreNorm | |
class Passthrough(torch.nn.Module):
    """Identity sublayer used to observe what a wrapper module forwards to it.

    ``forward`` hands back its positional arguments as a tuple, untouched, so a
    test can inspect exactly which tensors (and which tensor objects) it received.
    """

    def forward(self, *received):
        # No computation: the tuple of positional arguments is the output.
        return received
@pytest.mark.parametrize("normalization", list(NormalizationType))
def test_pre_norm(normalization):
    """Check that PreNorm normalizes a repeated input once and supports backprop.

    The same tensor is passed to ``PreNorm`` three times through the
    ``Passthrough`` sublayer, which returns whatever it receives. All outputs
    are expected to be the very same object. That shows the normalization was
    computed once and reused, not recomputed per input.

    Args:
        normalization: one ``NormalizationType`` member, supplied by the
            ``parametrize`` decorator. Without the decorator, pytest has no
            value for this argument and reports a missing fixture.
    """
    # Check that passing the same tensor a bunch of times skips the extra normalizations
    x = torch.rand((3, 3), requires_grad=True)
    wrap = PreNorm(
        d_norm=3, sublayer=Passthrough(), normalization=normalization, use_triton=False
    )
    outputs = wrap(inputs=[x, x, x])

    # Object identity (not just equal values) across every position means a
    # single normalized tensor was shared by all three inputs.
    assert outputs[0] is outputs[1]
    assert outputs[1] is outputs[2]

    # Check the BW pass
    torch.sum(outputs[0]).backward()