File size: 1,083 Bytes
fc16538
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# TRI-VIDAR - Copyright 2022 Toyota Research Institute.  All rights reserved.

import os

import fire
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from vidar.core.trainer import Trainer
from vidar.core.wrapper import Wrapper
from vidar.utils.config import read_config


def train(cfg, **kwargs):
    """Launch single-node DDP training with one spawned process per visible GPU.

    Parameters
    ----------
    cfg
        Configuration argument handed to ``read_config`` (presumably a config
        file path -- confirm against ``vidar.utils.config.read_config``).
    **kwargs
        Extra keyword arguments forwarded unchanged to ``read_config``.

    Raises
    ------
    RuntimeError
        If no CUDA device is visible. Without this check ``mp.spawn`` would be
        asked for zero processes and return immediately, so the run would
        silently do nothing.
    """
    num_gpus = torch.cuda.device_count()
    if num_gpus < 1:
        raise RuntimeError(
            'train() requires at least one CUDA device for NCCL-based DDP, '
            'but torch.cuda.device_count() returned 0.')

    # Mark the run as DDP before the configuration is read.
    os.environ['DIST_MODE'] = 'ddp'

    cfg = read_config(cfg, **kwargs)

    # mp.spawn calls main_worker(process_index, cfg) once per GPU and,
    # with join=True, blocks until every worker has finished.
    mp.spawn(main_worker,
             nprocs=num_gpus,
             args=(cfg,), join=True)


def main_worker(gpu, cfg):
    """Run training inside one DDP worker process.

    ``mp.spawn`` invokes this as ``main_worker(process_index, cfg)``. The
    process index is used both as the CUDA device index and as the global
    rank, which is only valid for single-node training.

    Parameters
    ----------
    gpu : int
        Index of this worker; selects the CUDA device and is the DDP rank.
    cfg
        Configuration object passed through from the launcher.
    """
    torch.cuda.set_device(gpu)
    world_size = torch.cuda.device_count()

    # Rendezvous settings read by the default ``env://`` init method.
    # NOTE(review): the fixed MASTER_PORT means two concurrent runs on the
    # same host would collide on port 12355.
    os.environ['RANK'] = str(gpu)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    os.environ['DIST_MODE'] = 'ddp'

    dist.init_process_group(backend='nccl', world_size=world_size, rank=gpu)

    # Tear the process group down even if setup or training raises, so the
    # NCCL communicator is not leaked on the error path.
    try:
        wrapper = Wrapper(cfg, verbose=True)
        trainer = Trainer(cfg)
        trainer.learn(wrapper)
    finally:
        dist.destroy_process_group()


if __name__ == '__main__':
    # Expose train() as a command-line entry point via python-fire; CLI
    # arguments are mapped onto train's ``cfg`` and ``**kwargs`` parameters.
    fire.Fire(component=train)