| """
|
| /*****************************************************************************/
|
|
|
| BatchNorm2dSync: BatchNorm2d with multi-GPU statistics synchronization
|
|
|
| /*****************************************************************************/
|
| """
|
| from __future__ import absolute_import
|
| from __future__ import division
|
| from __future__ import print_function
|
|
|
| try:
|
|
|
| from queue import Queue
|
| except ImportError:
|
|
|
| from Queue import Queue
|
|
|
| import torch
|
| import torch.nn as nn
|
| from torch.nn import functional as F
|
| from torch.nn.parameter import Parameter
|
| from isegm.model.syncbn.modules.functional import batchnorm2d_sync
|
|
|
|
|
class _BatchNorm(nn.Module):
    """
    Customized BatchNorm derived from nn.BatchNorm.

    >> adds a ``freezed`` attribute to enable BN freezing: while it is True,
       normalization uses the running statistics and does not update them,
       even in training mode.

    Args:
        num_features: number of features/channels normalized (size of the
            ``weight``/``bias``/running-statistic vectors).
        eps: passed to ``F.batch_norm`` for numerical stability.
        momentum: passed to ``F.batch_norm`` as the running-statistics
            update factor.
        affine: if True, learnable ``weight`` and ``bias`` parameters are
            created; otherwise both are registered as None.
        track_running_stats: if True, ``running_mean``/``running_var``
            buffers are kept and used outside of training; if False they are
            None and batch statistics are always used.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(_BatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        # When True, the layer behaves as in eval mode (see forward).
        self.freezed = False
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Reset running statistics to (0, 1) and re-initialize affine params.

        NOTE: ``weight`` is drawn from U(0, 1) (legacy nn.BatchNorm
        behaviour), while ``bias`` is zeroed.
        """
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
        if self.affine:
            self.weight.data.uniform_()
            self.bias.data.zero_()

    def _check_input_dim(self, input):
        """Validate the input rank; the base class accepts anything and
        subclasses override this with a real check."""
        return NotImplemented

    def forward(self, input):
        """Normalize ``input`` with batch or running statistics.

        Running statistics are updated only when the module is training,
        not frozen, and tracking running stats. When no running buffers
        exist (``track_running_stats=False``), batch statistics are the only
        option and are always used.
        """
        self._check_input_dim(input)

        compute_stats = not self.freezed and \
            self.training and self.track_running_stats
        # F.batch_norm rejects None running buffers when training=False, so
        # without tracked stats we must normalize with the batch statistics
        # (mirrors nn.BatchNorm's own handling).
        use_batch_stats = compute_stats or not self.track_running_stats

        ret = F.batch_norm(input, self.running_mean, self.running_var,
                           self.weight, self.bias, use_batch_stats,
                           self.momentum, self.eps)
        return ret

    def extra_repr(self):
        return '{num_features}, eps={eps}, momentum={momentum}, '\
               'affine={affine}, ' \
               'track_running_stats={track_running_stats}'.format(
                   **self.__dict__)
|
|
|
|
|
class BatchNorm2dNoSync(_BatchNorm):
    """
    Unsynchronized 2D batch normalization; a drop-in equivalent of
    nn.BatchNorm2d that additionally honours the ``freezed`` flag.
    """

    def _check_input_dim(self, input):
        """Raise ValueError unless ``input`` is a 4-D tensor."""
        rank = input.dim()
        if rank == 4:
            return
        raise ValueError('expected 4D input (got {}D input)'.format(rank))
|
|
|
|
|
class BatchNorm2dSync(BatchNorm2dNoSync):
    """
    BatchNorm2d with automatic multi-GPU synchronization.

    When statistics are being computed (training, not frozen, tracking
    running stats), sync is enabled, the input is a CUDA tensor and more
    than one CUDA device is visible, normalization is delegated to
    ``batchnorm2d_sync``. The replica whose input lives on ``devices[0]``
    is passed ``is_master=True``; every other device is a worker. They
    communicate through ``master_queue`` / ``worker_queues``. In every
    other case the layer behaves exactly like ``BatchNorm2dNoSync``.

    NOTE(review): the queue protocol appears to assume one replica per
    visible device (e.g. nn.DataParallel over all GPUs); a single replica
    on a multi-GPU machine would likely block waiting for its peers —
    confirm against batchnorm2d_sync.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(BatchNorm2dSync, self).__init__(
            num_features, eps=eps, momentum=momentum, affine=affine,
            track_running_stats=track_running_stats)
        self.sync_enabled = True
        self.devices = list(range(torch.cuda.device_count()))
        if len(self.devices) > 1:
            # devices[0] is the master; every other device is a worker.
            self.worker_ids = self.devices[1:]
            # Master queue holds up to one item per worker.
            self.master_queue = Queue(len(self.worker_ids))
            # One single-slot queue per worker, indexed like worker_ids.
            self.worker_queues = [Queue(1) for _ in self.worker_ids]

    def forward(self, x):
        """Normalize ``x``, synchronizing statistics across GPUs if needed."""
        compute_stats = not self.freezed and \
            self.training and self.track_running_stats
        # Sync only makes sense for CUDA tensors: get_device() returns -1 for
        # CPU tensors, which is neither the master device nor a worker id and
        # previously crashed in worker_ids.index().
        use_sync = (self.sync_enabled and compute_stats and
                    len(self.devices) > 1 and x.is_cuda)
        if not use_sync:
            return super(BatchNorm2dSync, self).forward(x)

        # Same rank validation as the non-synchronized path.
        self._check_input_dim(x)
        device = x.get_device()
        if device == self.devices[0]:
            extra = {
                "is_master": True,
                "master_queue": self.master_queue,
                "worker_queues": self.worker_queues,
                "worker_ids": self.worker_ids
            }
        else:
            extra = {
                "is_master": False,
                "master_queue": self.master_queue,
                "worker_queue": self.worker_queues[
                    self.worker_ids.index(device)]
            }
        return batchnorm2d_sync(x, self.weight, self.bias,
                                self.running_mean, self.running_var,
                                extra, compute_stats, self.momentum,
                                self.eps)

    def __repr__(self):
        """Return a one-line summary including the visible devices."""
        rep = '{name}({num_features}, eps={eps}, momentum={momentum}, ' \
              'affine={affine}, ' \
              'track_running_stats={track_running_stats}, ' \
              'devices={devices})'
        return rep.format(name=self.__class__.__name__, **self.__dict__)
|
|
|
|
|
# Public alias: importing ``BatchNorm2d`` from this module yields the
# multi-GPU synchronized variant (BatchNorm2dSync).
BatchNorm2d = BatchNorm2dSync
|
|
|