File size: 3,762 Bytes
29658b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import unittest

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from accelerate.utils import set_seed

from specforge.distributed import init_distributed
from specforge.layers import ParallelLMHead, VocabParallelEmbedding
from tests.utils import get_available_port


def _assert_heads_match(native_lm_head, sf_lm_head, data, bias):
    """Assert that ``sf_lm_head`` reproduces ``native_lm_head`` on ``data``.

    Args:
        native_lm_head: reference ``torch.nn.Linear`` layer.
        sf_lm_head: ``ParallelLMHead`` under test; it is called with
            ``gather_output=True``.
        data: input tensor fed to both layers.
        bias: whether the layers were built with a bias; only used in the
            failure message.

    Raises:
        AssertionError: if the outputs differ beyond ``rtol=1e-5, atol=1e-5``.
    """
    native_output = native_lm_head(data)
    sf_output = sf_lm_head(data, gather_output=True)
    assert torch.allclose(
        native_output, sf_output, rtol=1e-5, atol=1e-5
    ), f"bias: {bias}, native_output: \n{native_output}, \nsf_output: \n{sf_output}"


def run_lm_head(rank, world_size, port):
    """Worker entry point: compare ``ParallelLMHead`` with ``torch.nn.Linear``.

    Configures the distributed environment variables, initialises the
    process group with ``tp_size=world_size`` and, for ``bias`` in
    ``[True, False]``, checks three scenarios:

    1. the output vocab size (512) is divisible by the TP size,
    2. the output vocab size (377) is not divisible by the TP size,
    3. the LM head weight is tied to an embedding (only when ``bias`` is
       False, because the embedding layer has no bias).

    Args:
        rank: rank of this process, exported as ``RANK``.
        world_size: number of processes, exported as ``WORLD_SIZE`` and used
            as the tensor-parallel size.
        port: rendezvous port, exported as ``MASTER_PORT``.

    Raises:
        AssertionError: if any parallel output deviates from the native one.
    """
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)
    init_distributed(tp_size=world_size)

    # The process group is torn down in ``finally`` so that a failing
    # assertion does not leave it initialised in this worker.
    try:
        # NOTE(review): seeding presumably keeps weights and inputs identical
        # on every rank -- confirm against set_seed's semantics.
        set_seed(42)

        for bias in [True, False]:
            # ===============================
            # Case 1: the output vocab size is divisible by the TP size
            # Case 2: the output vocab size is not divisible by the TP size
            # ===============================
            for vocab_size in (512, 377):
                # every case draws its own input instead of reusing a tensor
                # left over from a previous case/iteration
                data = torch.rand(1, 128, 256).cuda()

                # create layers and copy the reference weights over
                native_lm_head = torch.nn.Linear(256, vocab_size, bias=bias).cuda()
                sf_lm_head = ParallelLMHead(256, vocab_size, bias=bias).cuda()
                sf_lm_head.load_state_dict(native_lm_head.state_dict())

                # forward + check
                _assert_heads_match(native_lm_head, sf_lm_head, data, bias)

            # ===============================
            # Case 3: tie word embedding
            # ===============================
            if bias:
                # there is no bias in the embedding layer so we skip when bias is True
                continue

            # create data
            data = torch.rand(128, 256).cuda()

            # create native layers; the head shares the embedding's weight
            native_embedding = torch.nn.Embedding(512, 256).cuda()
            native_lm_head = torch.nn.Linear(256, 512, bias=bias).cuda()
            native_lm_head.weight = native_embedding.weight

            # create specforge layers, tied the same way
            sf_embedding = VocabParallelEmbedding(512, 256).cuda()
            sf_embedding.load_state_dict(native_embedding.state_dict())
            sf_lm_head = ParallelLMHead(256, 512, bias=bias).cuda()
            sf_lm_head.weight = sf_embedding.weight

            # forward + check
            _assert_heads_match(native_lm_head, sf_lm_head, data, bias)
    finally:
        dist.destroy_process_group()


class TestLMHead(unittest.TestCase):
    """Spawns ``run_lm_head`` workers for several world sizes."""

    def test_lm_head(self):
        """Run the LM head comparison with two processes, then with one."""
        for world_size in (2, 1):
            # request a fresh port for every spawn
            free_port = get_available_port()
            mp.spawn(
                run_lm_head,
                args=(world_size, free_port),
                nprocs=world_size,
            )


if __name__ == "__main__":
    # unittest.makeSuite was deprecated in Python 3.11 and removed in 3.13,
    # so the test case is collected through the default TestLoader instead.
    suite = unittest.TestSuite()
    suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestLMHead))
    runner = unittest.TextTestRunner(verbosity=2)
    runner.run(suite)