'''
Code from "Robust Overfitting may be mitigated by properly learned smoothening".

URL: https://github.com/VITA-Group/Alleviate-Robust-Overfitting
'''
import torch


def _is_batchnorm(module):
    """Return True if ``module`` is an instance of a BatchNorm layer class."""
    return isinstance(module, torch.nn.modules.batchnorm._BatchNorm)


def moving_average(net1, net2, alpha=1):
    """
    Blend the parameters of ``net2`` into ``net1`` in place.

    Every parameter of ``net1`` becomes
    ``(1 - alpha) * param1 + alpha * param2``. Parameters are paired
    positionally with ``zip``, so pairing stops at the shorter parameter
    list. Only ``parameters()`` are visited; buffers are left untouched.

    :param net1: model whose parameters are overwritten with the blend.
    :param net2: model whose parameters are mixed in.
    :param alpha: weight given to ``net2`` (the default of 1 copies ``net2``).
    :return: None
    """
    for param1, param2 in zip(net1.parameters(), net2.parameters()):
        param1.data *= (1.0 - alpha)
        param1.data += param2.data * alpha


def _check_bn(module, flag):
    """Set ``flag[0]`` to True if ``module`` is a BatchNorm layer."""
    if _is_batchnorm(module):
        flag[0] = True


def check_bn(model):
    """
    Tell whether ``model`` contains at least one BatchNorm layer.

    :param model: module to inspect (itself and all of its submodules).
    :return: True if any visited module is a BatchNorm layer, else False.
    """
    return any(_is_batchnorm(module) for module in model.modules())


def reset_bn(module):
    """
    Reset the running statistics of a BatchNorm layer.

    ``running_mean`` is replaced by zeros and ``running_var`` by ones.
    Modules that are not BatchNorm layers are ignored, as are BatchNorm
    layers that keep no running statistics (their buffers are ``None``).

    :param module: any module; intended for use with ``model.apply``.
    :return: None
    """
    if not _is_batchnorm(module):
        return
    # Layers created with track_running_stats=False have no buffers to reset.
    if module.running_mean is None or module.running_var is None:
        return
    module.running_mean = torch.zeros_like(module.running_mean)
    module.running_var = torch.ones_like(module.running_var)


def _get_momenta(module, momenta):
    """Record ``module.momentum`` in ``momenta`` if it is a BatchNorm layer."""
    if _is_batchnorm(module):
        momenta[module] = module.momentum


def _set_momenta(module, momenta):
    """Restore ``module.momentum`` from ``momenta`` if it is a BatchNorm layer."""
    if _is_batchnorm(module):
        module.momentum = momenta[module]


def _infer_device(model):
    """Return the device of the model's first parameter or buffer, or None."""
    tensor = next(model.parameters(), None)
    if tensor is None:
        tensor = next(model.buffers(), None)
    return None if tensor is None else tensor.device


def bn_update(loader, model, device=None):
    """
    BatchNorm buffers update (if any).

    Performs one pass over ``loader`` to re-estimate the running statistics
    of every BatchNorm layer in ``model``. The statistics are first reset,
    then each batch is fed through the model with the BatchNorm momentum set
    to ``batch_size / samples_seen_so_far`` so that the buffers end up as a
    sample-weighted average over all batches. The original momentum values
    are restored afterwards, even if the pass is interrupted by an error.

    The model is switched to train mode and is left in train mode. If the
    model contains no BatchNorm layer, nothing is done at all.

    :param loader: train dataset loader for buffers average estimation;
        must yield ``(input, target)`` pairs, only the input is used.
    :param model: model being updated.
    :param device: device the inputs are moved to before the forward pass.
        When None (default) it is inferred from the model's first parameter
        or buffer, so the function also works for models kept on the CPU.
    :return: None
    """
    if not check_bn(model):
        return
    model.train()
    momenta = {}
    model.apply(reset_bn)
    model.apply(lambda module: _get_momenta(module, momenta))
    if device is None:
        device = _infer_device(model)
    n = 0
    try:
        # Only the BatchNorm buffers matter here; no gradients are needed.
        with torch.no_grad():
            for inputs, _ in loader:
                if device is not None:
                    inputs = inputs.to(device)
                b = inputs.size(0)
                if b == 0:
                    # An empty batch carries no statistics (and would divide
                    # by zero below when it comes first).
                    continue
                # Running-average weight of this batch among all samples seen.
                momentum = b / (n + b)
                for module in momenta:
                    module.momentum = momentum
                model(inputs)
                n += b
    finally:
        model.apply(lambda module: _set_momenta(module, momenta))