# codewraith / data /source_files /clean /2d1c1112bd6d.py
# slenk's picture
# Upload folder using huggingface_hub
# eeef81e verified
import os
import pytest
import torch
from hivemind import RemoteExpert
from hivemind.moe.server import background_server
# Path to the module handed to background_server(custom_module_path=...) by the
# tests in this file; resolved relative to the directory containing this file.
CUSTOM_EXPERTS_PATH = os.path.join(
    os.path.dirname(__file__),
    "test_utils",
    "custom_networks.py",
)
@pytest.mark.forked
def test_custom_expert(hid_dim=16):
    """Smoke-test the custom "perceptron" expert class served by a background server.

    Two experts are hosted on CPU with two handlers, no DHT, and the expert
    class loaded from ``CUSTOM_EXPERTS_PATH``. For batch sizes 1 and 4, a random
    ``(batch_size, hid_dim)`` batch is sent forward through both experts, and
    then the sum of each output is backpropagated. There are no explicit
    assertions: the test passes as long as no call raises.
    """
    server_kwargs = dict(
        expert_cls="perceptron",
        num_experts=2,
        device="cpu",
        hidden_dim=hid_dim,
        num_handlers=2,
        no_dht=True,
        custom_module_path=CUSTOM_EXPERTS_PATH,
    )
    with background_server(**server_kwargs) as (server_endpoint, _):
        experts = [RemoteExpert(uid, server_endpoint) for uid in ("expert.0", "expert.1")]
        for batch_size in (1, 4):
            batch = torch.randn(batch_size, hid_dim)
            # Every forward call is issued before any backward call.
            outputs = [expert(batch) for expert in experts]
            for output in outputs:
                output.sum().backward()
@pytest.mark.forked
def test_multihead_expert(hid_dim=16):
    """Smoke-test the custom "multihead" expert class, which is fed three input tensors.

    Two experts are hosted on CPU with two handlers, no DHT, and the expert
    class loaded from ``CUSTOM_EXPERTS_PATH``. For batch sizes 1 and 4, three
    random tensors of widths ``hid_dim``, ``2 * hid_dim`` and ``3 * hid_dim``
    are passed positionally to both experts, and then the sum of each output
    is backpropagated. There are no explicit assertions: the test passes as
    long as no call raises.
    """
    server_kwargs = dict(
        expert_cls="multihead",
        num_experts=2,
        device="cpu",
        hidden_dim=hid_dim,
        num_handlers=2,
        no_dht=True,
        custom_module_path=CUSTOM_EXPERTS_PATH,
    )
    with background_server(**server_kwargs) as (server_endpoint, _):
        experts = [RemoteExpert(uid, server_endpoint) for uid in ("expert.0", "expert.1")]
        for batch_size in (1, 4):
            # Inputs are generated in width order 1x, 2x, 3x hid_dim.
            inputs = tuple(
                torch.randn(batch_size, width_factor * hid_dim)
                for width_factor in (1, 2, 3)
            )
            # Every forward call is issued before any backward call.
            outputs = [expert(*inputs) for expert in experts]
            for output in outputs:
                output.sum().backward()