Spaces:
Sleeping
Sleeping
| import os | |
| import unittest | |
| import torch | |
| import torch.distributed as dist | |
| from torch.multiprocessing import Process | |
| import torch.nn as nn | |
| from machina.optims import DistributedAdamW | |
def init_processes(rank, world_size,
                   function, backend='gloo'):
    """Join a local process group, run ``function``, then tear the group down.

    Args:
        rank: Rank of the calling process within the group.
        world_size: Total number of processes in the group.
        function: Callable invoked as ``function(rank, world_size)`` once the
            process group has been initialized.
        backend: ``torch.distributed`` backend name. The legacy ``'tcp'``
            backend was removed from PyTorch in 1.0; it is accepted here as
            an alias for ``'gloo'`` so existing callers keep working.
    """
    if backend == 'tcp':
        backend = 'gloo'
    # Rendezvous through environment variables (the default env:// init
    # method); every rank must agree on this address and port.
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank,
                            world_size=world_size)
    try:
        function(rank, world_size)
    finally:
        # Release the default group so its resources are freed and the
        # process can initialize a new group later.
        dist.destroy_process_group()
class TestDistributedAdamW(unittest.TestCase):
    """Multi-process smoke test for ``DistributedAdamW``."""

    @staticmethod
    def _run(rank, world_size):
        """Run one zero_grad/forward/backward/step cycle on this rank.

        This is a static method rather than a closure inside the test so
        that it can be pickled when workers are started with the 'spawn'
        method (the default on macOS and Windows).
        """
        model = nn.Linear(10, 1)
        optimizer = DistributedAdamW(
            model.parameters())
        optimizer.zero_grad()
        # nn.Linear(10, 1) on a length-10 input yields a one-element tensor;
        # sum() makes it an explicit scalar for backward().
        loss = model(torch.ones(10).float()).sum()
        loss.backward()
        optimizer.step()

    def test_step(self):
        """Every rank must complete one optimizer step and exit cleanly."""
        world_size = 4
        # Upper bound (seconds) on each join, so a hung rendezvous fails the
        # test instead of blocking it forever.
        join_timeout = 120
        processes = []
        try:
            for rank in range(world_size):
                p = Process(target=init_processes,
                            args=(rank,
                                  world_size,
                                  self._run))
                p.start()
                processes.append(p)
            for p in processes:
                p.join(join_timeout)
        finally:
            # Never leave stray workers behind, even if starting one failed.
            for p in processes:
                if p.is_alive():
                    p.terminate()
                    p.join()
        # A worker that raised, or was terminated above, has a non-zero exit
        # code. Without this check, failures in child processes went unnoticed.
        exit_codes = [p.exitcode for p in processes]
        self.assertEqual(exit_codes, [0] * world_size)