Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | |
| # | |
| # This source code is licensed under the BSD license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import pytest | |
| import torch | |
| from xformers.components import PatchEmbeddingConfig, build_patch_embedding | |
| from xformers.components.positional_embedding import ( | |
| POSITION_EMBEDDING_REGISTRY, | |
| build_positional_embedding, | |
| ) | |
# Shared sizes for the tests below.
MODEL = 384  # passed as "dim_model" and used as the trailing size of 3-D inputs
SEQ = 512  # passed as "seq_len" and used as the second size of the inputs
BATCH = 20  # leading size of every input tensor

# Evaluated at import time: an empty registry means there is nothing to test.
assert POSITION_EMBEDDING_REGISTRY.keys(), (
    "Positional encoding layers should have been registered"
)
def test_dimensions(encoding_name: str, dropout: float):
    """Smoke-test that a positional embedding can be built and run forward.

    The test only checks that construction and the forward pass do not raise;
    the outputs themselves are discarded.

    Args:
        encoding_name: Stored under the "name" key of the config passed to
            ``build_positional_embedding``.
        dropout: Stored under the "dropout" key of the same config.

    NOTE(review): no ``pytest.mark.parametrize`` decorator is visible on this
    function, so pytest has no source for these two arguments. Confirm whether
    the decorators were lost and restore them if so.
    """
    test_config = {
        "name": encoding_name,
        "dim_model": MODEL,
        "vocab_size": 32,
        "dropout": dropout,
        "seq_len": SEQ,
    }

    # dummy, just check construction and dimensions in the FW pass
    encoding = build_positional_embedding(test_config)
    inputs = (torch.rand(BATCH, SEQ) * 10).abs().to(torch.int)
    _ = encoding(inputs)

    # Test that inputs having an embedding dimension would also work out.
    # This must compare the parameter: the previous `"name" == "sine"` compared
    # two different string literals, was always False, and so this branch
    # never ran.
    if encoding_name == "sine":
        inputs = (torch.rand(BATCH, SEQ, MODEL) * 10).abs().to(torch.int)
        _ = encoding(inputs)
def test_patch_embedding():
    """Build a patch embedding and check the last dimension of its output.

    Two random inputs with different layouts are passed through the same
    module; in both cases the trailing output dimension must equal the
    configured ``out_channels``.
    """
    out_channels = 64

    # dummy, just check construction and dimensions in the FW pass
    config = PatchEmbeddingConfig(
        in_channels=3,
        out_channels=out_channels,
        kernel_size=7,
        stride=4,
        padding=2,
    )
    embed = build_patch_embedding(config)

    input_shapes = (
        (BATCH, 32 * 32, 3),  # the "BHWC" case
        (BATCH, 3, 32, 32),  # the "BCHW" case
    )
    for shape in input_shapes:
        result = embed(torch.rand(*shape))
        assert result.shape[-1] == out_channels