File size: 1,811 Bytes
eeef81e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os

import pytest
import torch

from hivemind import RemoteExpert
from hivemind.moe.server import background_server

# Path to the module that the server is told to load custom expert classes from
# (passed as ``custom_module_path`` below); resolved relative to this test file.
CUSTOM_EXPERTS_PATH = os.path.join(os.path.dirname(__file__), "test_utils", "custom_networks.py")


@pytest.mark.forked
def test_custom_expert(hid_dim=16):
    """Smoke-test a "perceptron" expert class loaded from a custom module.

    Starts a CPU-only background server without a DHT, hosting two experts whose class
    is resolved through ``custom_module_path=CUSTOM_EXPERTS_PATH`` (presumably where
    "perceptron" is registered — confirm in test_utils/custom_networks.py). For batch
    sizes 1 and 4, a random ``(batch, hid_dim)`` input is sent through both remote
    experts, and the sum of each output is back-propagated. The test only checks that
    none of these calls raise; output values are not compared against anything.
    """
    server_kwargs = dict(
        expert_cls="perceptron",
        num_experts=2,
        device="cpu",
        hidden_dim=hid_dim,
        num_handlers=2,
        no_dht=True,
        custom_module_path=CUSTOM_EXPERTS_PATH,
    )
    with background_server(**server_kwargs) as (endpoint, _):
        remote_experts = [RemoteExpert(uid, endpoint) for uid in ("expert.0", "expert.1")]

        for size in (1, 4):
            inputs = torch.randn(size, hid_dim)

            # Run every forward pass first, then call backward on each output in turn.
            outputs = [expert(inputs) for expert in remote_experts]
            for out in outputs:
                out.sum().backward()


@pytest.mark.forked
def test_multihead_expert(hid_dim=16):
    """Smoke-test a "multihead" expert class that takes three positional tensors.

    Starts a CPU-only background server without a DHT, hosting two experts whose class
    is resolved through ``custom_module_path=CUSTOM_EXPERTS_PATH`` (presumably where
    "multihead" is registered — confirm in test_utils/custom_networks.py). For batch
    sizes 1 and 4, three random inputs of widths ``hid_dim``, ``2 * hid_dim`` and
    ``3 * hid_dim`` are passed positionally to both remote experts, and the sum of each
    output is back-propagated. The test only checks that none of these calls raise.
    """
    server_kwargs = dict(
        expert_cls="multihead",
        num_experts=2,
        device="cpu",
        hidden_dim=hid_dim,
        num_handlers=2,
        no_dht=True,
        custom_module_path=CUSTOM_EXPERTS_PATH,
    )
    with background_server(**server_kwargs) as (endpoint, _):
        remote_experts = [RemoteExpert(uid, endpoint) for uid in ("expert.0", "expert.1")]

        for size in (1, 4):
            # One tensor per head; widths scale as 1x, 2x, 3x the hidden size.
            inputs = tuple(torch.randn(size, factor * hid_dim) for factor in (1, 2, 3))

            # Run every forward pass first, then call backward on each output in turn.
            outputs = [expert(*inputs) for expert in remote_experts]
            for out in outputs:
                out.sum().backward()