Spaces:
Sleeping
Sleeping
| import os | |
| import pytest | |
| import torch | |
| from hivemind import RemoteExpert | |
| from hivemind.moe.server import background_server | |
# Absolute-or-relative path (depending on how this module was imported) to the
# custom expert module that the tests hand to background_server through its
# custom_module_path argument; located in test_utils/ next to this file.
CUSTOM_EXPERTS_PATH = os.path.join(
    os.path.dirname(__file__),
    os.path.join("test_utils", "custom_networks.py"),
)
def test_custom_expert(hid_dim=16):
    """Run forward and backward passes through two remote "perceptron" experts.

    A background server (CPU, no DHT, two handlers) is started with
    ``custom_module_path=CUSTOM_EXPERTS_PATH`` so that the ``"perceptron"``
    expert class can be resolved from that module. For each batch size the test
    checks that the outputs keep the batch dimension and that backward through
    both experts delivers a gradient for the input tensor.

    Args:
        hid_dim: width of the input features fed to the experts (also passed
            to the server as ``hidden_dim``).
    """
    with background_server(
        expert_cls="perceptron",
        num_experts=2,
        device="cpu",
        hidden_dim=hid_dim,
        num_handlers=2,
        no_dht=True,
        custom_module_path=CUSTOM_EXPERTS_PATH,
    ) as (server_endpoint, _):
        expert0 = RemoteExpert("expert.0", server_endpoint)
        expert1 = RemoteExpert("expert.1", server_endpoint)

        # Cover both a single-sample batch and a multi-sample batch.
        for batch_size in (1, 4):
            # requires_grad=True so the gradients returned by the remote
            # backward pass are kept and can be checked below; without it they
            # would be discarded and a broken backward would go unnoticed.
            batch = torch.randn(batch_size, hid_dim, requires_grad=True)

            output0 = expert0(batch)
            output1 = expert1(batch)
            # Each remote call must return one output row per input row.
            assert output0.shape[0] == output1.shape[0] == batch_size

            output0.sum().backward()
            output1.sum().backward()

            # Gradients from both experts accumulate into batch.grad.
            assert batch.grad is not None
            assert batch.grad.shape == batch.shape
def test_multihead_expert(hid_dim=16):
    """Run forward and backward passes through two remote "multihead" experts.

    A background server (CPU, no DHT, two handlers) is started with
    ``custom_module_path=CUSTOM_EXPERTS_PATH`` so that the ``"multihead"``
    expert class can be resolved from that module. Each expert is called with
    three input tensors of widths ``hid_dim``, ``2 * hid_dim`` and
    ``3 * hid_dim``. For each batch size the test checks that the outputs keep
    the batch dimension and that backward through both experts delivers a
    gradient for every input tensor.

    Args:
        hid_dim: base feature width; the three inputs use 1x, 2x and 3x this
            width (also passed to the server as ``hidden_dim``).
    """
    with background_server(
        expert_cls="multihead",
        num_experts=2,
        device="cpu",
        hidden_dim=hid_dim,
        num_handlers=2,
        no_dht=True,
        custom_module_path=CUSTOM_EXPERTS_PATH,
    ) as (server_endpoint, _):
        expert0 = RemoteExpert("expert.0", server_endpoint)
        expert1 = RemoteExpert("expert.1", server_endpoint)

        # Cover both a single-sample batch and a multi-sample batch.
        for batch_size in (1, 4):
            # requires_grad=True so the gradients returned by the remote
            # backward pass are kept and can be checked below.
            batch = (
                torch.randn(batch_size, hid_dim, requires_grad=True),
                torch.randn(batch_size, 2 * hid_dim, requires_grad=True),
                torch.randn(batch_size, 3 * hid_dim, requires_grad=True),
            )

            output0 = expert0(*batch)
            output1 = expert1(*batch)
            # Each remote call must return one output row per input row.
            assert output0.shape[0] == output1.shape[0] == batch_size

            output0.sum().backward()
            output1.sum().backward()

            # Every input head should receive a gradient of its own shape
            # (accumulated over both experts).
            for head in batch:
                assert head.grad is not None
                assert head.grad.shape == head.shape