File size: 1,199 Bytes
eeef81e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import os
import unittest

import torch
import torch.distributed as dist
from torch.multiprocessing import Process
import torch.nn as nn

from machina.optims import DistributedAdamW


def init_processes(rank, world_size,
                   function, backend='gloo'):
    """Join a local process group as ``rank`` and run ``function`` inside it.

    Meant to be used as the ``target`` of a child process: every rank sets the
    same rendezvous address (127.0.0.1:29500), joins the group, runs the
    payload and finally leaves the group again.

    Args:
        rank: Rank of the calling process, in ``[0, world_size)``.
        world_size: Total number of processes joining the group.
        function: Callable invoked as ``function(rank, world_size)`` once the
            process group is initialised.
        backend: ``torch.distributed`` backend name. Defaults to ``'gloo'``,
            the CPU backend; the former ``'tcp'`` backend was removed in
            PyTorch 1.0 and is rejected by ``init_process_group``.
    """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank,
                            world_size=world_size)
    try:
        function(rank, world_size)
    finally:
        # Release the group (and its sockets) even if the payload raises.
        dist.destroy_process_group()


def _adamw_step_worker(rank, world_size):
    """Run a single DistributedAdamW step on a tiny linear model.

    Executed once per rank, after ``init_processes`` has joined the process
    group. Defined at module level (not nested inside the test) so it can be
    pickled when the multiprocessing start method is ``spawn`` -- the default
    on Windows and, since Python 3.8, on macOS.
    """
    model = nn.Linear(10, 1)
    optimizer = DistributedAdamW(
        model.parameters())

    optimizer.zero_grad()
    # Linear(10, 1) on a length-10 input yields a one-element tensor, which
    # is used directly as the loss so that backward() needs no grad argument.
    loss = model(torch.ones(10).float())
    loss.backward()
    optimizer.step()


class TestDistributedAdamW(unittest.TestCase):
    """Smoke tests for ``DistributedAdamW`` across several local processes."""

    def test_step(self):
        """Every rank in a 4-process group completes one optimizer step."""
        processes = []
        world_size = 4
        for rank in range(world_size):
            p = Process(target=init_processes,
                        args=(rank,
                              world_size,
                              _adamw_step_worker))
            p.start()
            processes.append(p)
        for p in processes:
            p.join()
        # Exceptions raised in a child process do not propagate to the
        # parent, so check exit codes explicitly; without this the test
        # passes even when every rank crashes.
        for rank, p in enumerate(processes):
            self.assertEqual(
                p.exitcode, 0,
                'rank {} exited with code {}'.format(rank, p.exitcode))