| | """ Distributed training/validation utils |
| | |
| | Hacked together by / Copyright 2020 Ross Wightman |
| | """ |
| | import torch |
| | from torch import distributed as dist |
| |
|
| | from .model import unwrap_model |
| |
|
| |
|
def reduce_tensor(tensor, n):
    """ Sum a tensor across all distributed processes, then divide by n (typically the world size). """
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= n
    return rt
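

# Usage sketch (an assumption, not part of this module): averaging a per-rank
# scalar such as a loss inside a training loop, once the default process group
# has been initialized. `loss` is a hypothetical tensor on the current device:
#
#     loss_avg = reduce_tensor(loss.detach(), dist.get_world_size())
#
# The clone keeps the caller's tensor untouched; the SUM all-reduce followed
# by the divide turns the result into a mean across ranks.

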
def distribute_bn(model, world_size, reduce=False):
    # ensure every rank ends up with the same BN running stats
    for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True):
        if ('running_mean' in bn_name) or ('running_var' in bn_name):
            if reduce:
                # average BN stats across the whole process group
                dist.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
                bn_buf /= float(world_size)
            else:
                # broadcast BN stats from rank 0 to all other ranks
                dist.broadcast(bn_buf, 0)
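

# Usage sketch (an assumption): synchronizing BN buffers across ranks at the
# end of each training epoch. `args.dist_bn` is a hypothetical config flag
# selecting between averaging stats ('reduce') and copying rank 0's stats
# ('broadcast'):
#
#     if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
#         distribute_bn(model, args.world_size, args.dist_bn == 'reduce')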