Spaces:
Sleeping
Sleeping
File size: 1,811 Bytes
eeef81e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 | import os
import pytest
import torch
from hivemind import RemoteExpert
from hivemind.moe.server import background_server
# Module handed to background_server(custom_module_path=...) by the tests in
# this file; the "perceptron" and "multihead" expert classes they request are
# expected to be provided by it. The path is built from the directory that
# contains this file.
_TESTS_DIR = os.path.dirname(__file__)
CUSTOM_EXPERTS_PATH = os.path.join(_TESTS_DIR, "test_utils", "custom_networks.py")
@pytest.mark.forked
def test_custom_expert(hid_dim=16):
    """Run forward and backward passes through the custom "perceptron" expert.

    Starts a background server on CPU, without a DHT, hosting two experts
    (``expert.0`` and ``expert.1``) of the class named "perceptron" loaded from
    ``CUSTOM_EXPERTS_PATH``. Each expert is then called through
    ``RemoteExpert`` for batch sizes 1 and 4. Every output is summed and
    back-propagated.

    :param hid_dim: feature width of the random input batch; also passed to
        the server as ``hidden_dim``.
    """
    num_experts = 2
    with background_server(
        expert_cls="perceptron",
        num_experts=num_experts,
        device="cpu",
        hidden_dim=hid_dim,
        num_handlers=2,
        no_dht=True,
        custom_module_path=CUSTOM_EXPERTS_PATH,
    ) as (server_endpoint, _):
        experts = [RemoteExpert(f"expert.{i}", server_endpoint) for i in range(num_experts)]
        for batch_size in (1, 4):
            batch = torch.randn(batch_size, hid_dim)
            # Run all forward passes first, then all backward passes.
            outputs = [expert(batch) for expert in experts]
            for output in outputs:
                # Minimal sanity check so the test verifies more than "no exception":
                # NOTE(review): assumes the expert returns one row per input row — confirm
                # against test_utils/custom_networks.py.
                assert output.shape[0] == batch_size
                output.sum().backward()
@pytest.mark.forked
def test_multihead_expert(hid_dim=16):
    """Run forward and backward passes through the custom "multihead" expert.

    Starts a background server on CPU, without a DHT, hosting two experts
    (``expert.0`` and ``expert.1``) of the class named "multihead" loaded from
    ``CUSTOM_EXPERTS_PATH``. Each expert is called through ``RemoteExpert``
    with three input tensors of widths ``hid_dim``, ``2 * hid_dim`` and
    ``3 * hid_dim``, for batch sizes 1 and 4. Every output is summed and
    back-propagated.

    :param hid_dim: base feature width of the inputs; also passed to the
        server as ``hidden_dim``.
    """
    num_experts = 2
    with background_server(
        expert_cls="multihead",
        num_experts=num_experts,
        device="cpu",
        hidden_dim=hid_dim,
        num_handlers=2,
        no_dht=True,
        custom_module_path=CUSTOM_EXPERTS_PATH,
    ) as (server_endpoint, _):
        experts = [RemoteExpert(f"expert.{i}", server_endpoint) for i in range(num_experts)]
        for batch_size in (1, 4):
            batch = (
                torch.randn(batch_size, hid_dim),
                torch.randn(batch_size, 2 * hid_dim),
                torch.randn(batch_size, 3 * hid_dim),
            )
            # Run all forward passes first, then all backward passes.
            outputs = [expert(*batch) for expert in experts]
            for output in outputs:
                # Minimal sanity check so the test verifies more than "no exception":
                # NOTE(review): assumes the expert returns one row per input row — confirm
                # against test_utils/custom_networks.py.
                assert output.shape[0] == batch_size
                output.sum().backward()
|