# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from xformers.components import Activation
from xformers.components.feedforward import FEEDFORWARD_REGISTRY, build_feedforward
from xformers.components.feedforward.mixture_of_experts import GateConfig
from xformers.helpers.test_utils import init_torch_distributed_local
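
# Smoke tests for the registered feedforward blocks: every entry in
# FEEDFORWARD_REGISTRY is built from a config dict and run through a forward pass,
# and the MixtureOfExperts block is additionally exercised forward and backward.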

BATCH = 4
SEQ = 256
EMBD = 16
LATENT = 128
DROPOUT = 0.5

DEVICES = (
    [torch.device("cpu")] if not torch.cuda.is_available() else [torch.device("cuda")]
)

assert FEEDFORWARD_REGISTRY.keys(), "Feedforward layers should have been registered"


@pytest.mark.parametrize("feedforward_name", FEEDFORWARD_REGISTRY.keys())
@pytest.mark.parametrize("activation", list(Activation))
@pytest.mark.parametrize("device", DEVICES)
def test_feedforward(
    feedforward_name: str, activation: Activation, device: torch.device
):
    test_config = {
        "name": feedforward_name,
        "dim_model": LATENT,
        "dropout": DROPOUT,
        "activation": activation,
        "hidden_layer_multiplier": 4,
        "number_of_experts": 4,  # MoE
        "gate": "top_2",  # MoE
    }
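
    # MixtureOfExperts routes tokens through torch.distributed, so even this
    # single-process smoke test needs a (local) process group before the block
    # can be constructed.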
    if feedforward_name == "MixtureOfExperts":
        init_torch_distributed_local()

    # dummy config, just check construction and dimensions in the forward pass
    ffw = build_feedforward(test_config)

    if ffw.requires_cuda and device.type != "cuda":
        # pyre-fixme[29]: The library function `pytest.skip` is not supported by Pyre.
        pytest.skip("This MLP requires CUDA and the current device does not match")

    inputs = torch.rand(BATCH, SEQ, LATENT, device=device)
    ffw = ffw.to(device)
    _ = ffw(inputs)
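

# Custom expert constructor handed to the MoE config below; the test also passes
# None to check that the block falls back to its default expert in that case.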
def get_expert():
    return torch.nn.Linear(LATENT, LATENT, bias=False)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="MoE requires CUDA")
@pytest.mark.parametrize("gate", [g.value for g in GateConfig])
@pytest.mark.parametrize("number_of_local_experts", [None, 4])
@pytest.mark.parametrize("expert_constructor", [None, get_expert])
def test_moe(gate, number_of_local_experts, expert_constructor):
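    # Exercise the full MoE code path: number_of_experts is the global expert count,
    # number_of_local_experts (when set) controls how many experts this rank hosts,
    # and expert_constructor lets the caller provide a custom expert module.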
    test_config = {
        "name": "MixtureOfExperts",
        "dim_model": LATENT,
        "dropout": DROPOUT,
        "activation": Activation.ReLU,
        "hidden_layer_multiplier": 4,
        "number_of_experts": 4,
        "number_of_local_experts": number_of_local_experts,
        "gate": gate,
        "expert_constructor": expert_constructor,
    }
    init_torch_distributed_local()

    # dummy config, check construction plus a full forward and backward pass
    ffw = build_feedforward(test_config)
    inputs = torch.rand(BATCH, SEQ, LATENT, device=torch.device("cuda"))
    ffw = ffw.to(torch.device("cuda"))

    outputs = ffw(inputs)
    loss = torch.sum(outputs)
    loss.backward()
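

if __name__ == "__main__":
    # Minimal standalone sketch (not collected by pytest): build one feedforward
    # block from a config dict and run a single forward pass on the default device.
    # The "MLP" name is assumed to be one of the registered feedforward components;
    # pick any other key from FEEDFORWARD_REGISTRY if it is not available here.
    example_config = {
        "name": "MLP",
        "dim_model": LATENT,
        "dropout": DROPOUT,
        "activation": Activation.ReLU,
        "hidden_layer_multiplier": 4,
    }
    example_ffw = build_feedforward(example_config)
    example_out = example_ffw(torch.rand(BATCH, SEQ, LATENT))
    # Feedforward blocks preserve the model dimension, so the output shape
    # matches the input: (BATCH, SEQ, LATENT).
    print(example_out.shape)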