Yaning1001's picture
Add files using upload-large-folder tool
64470b3 verified
'''
Code from the paper "Robust Overfitting may be mitigated by properly learned smoothening".
URL: https://github.com/VITA-Group/Alleviate-Robust-Overfitting
'''
import torch
def moving_average(net1, net2, alpha=1):
    """
    Blend the parameters of ``net2`` into ``net1`` in place.

    Every parameter of ``net1`` becomes
    ``(1 - alpha) * net1_param + alpha * net2_param``. Parameters are paired
    positionally via ``zip``, so iteration stops at the shorter of the two
    parameter sequences.

    :param net1: model whose parameters are overwritten with the blend.
    :param net2: model whose parameters are mixed in (read only).
    :param alpha: weight given to ``net2``; the default of 1 replaces the
        values of ``net1`` with those of ``net2``.
    :return: None
    """
    keep_weight = 1.0 - alpha
    for target, source in zip(net1.parameters(), net2.parameters()):
        # ``.data`` aliases the parameter storage, so the in-place ops below
        # modify the parameter itself without being tracked by autograd.
        averaged = target.data
        averaged.mul_(keep_weight)
        averaged.add_(source.data * alpha)
def _check_bn(module, flag):
    """
    Record in ``flag`` whether ``module`` is a BatchNorm layer.

    :param module: module to inspect.
    :param flag: one-element mutable list; ``flag[0]`` is set to ``True`` when
        ``module`` is an instance of ``_BatchNorm`` and is left unchanged
        otherwise.
    :return: None
    """
    batchnorm_base = torch.nn.modules.batchnorm._BatchNorm
    if not isinstance(module, batchnorm_base):
        return
    flag[0] = True
def check_bn(model):
    """
    Tell whether ``model`` contains at least one BatchNorm layer.

    :param model: module tree to search; every submodule (and ``model``
        itself) is visited through ``model.apply``.
    :return: ``True`` if any visited module is a ``_BatchNorm`` instance,
        ``False`` otherwise.
    """
    # A one-element list lets the visitor report its result by mutation.
    found = [False]

    def visit(submodule):
        _check_bn(submodule, found)

    model.apply(visit)
    return found[0]
def reset_bn(module):
    """
    Reset the running statistics of a BatchNorm layer.

    ``running_mean`` is replaced by zeros and ``running_var`` by ones, each
    with the shape, dtype and device of the buffer it replaces. Modules that
    are not ``_BatchNorm`` instances are ignored, as are BatchNorm layers
    whose running buffers are ``None`` (this is the case for layers built
    with ``track_running_stats=False``), since they have nothing to reset.

    :param module: module to reset; intended to be used with ``model.apply``.
    :return: None
    """
    if not isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        return
    # Layers that do not track running statistics keep these buffers as None;
    # ``zeros_like(None)`` would raise, so leave them alone.
    if module.running_mean is None or module.running_var is None:
        return
    module.running_mean = torch.zeros_like(module.running_mean)
    module.running_var = torch.ones_like(module.running_var)
def _get_momenta(module, momenta):
    """
    Save the momentum of a BatchNorm layer.

    :param module: module to inspect; anything that is not a ``_BatchNorm``
        instance is ignored.
    :param momenta: dict updated in place with ``module -> module.momentum``.
    :return: None
    """
    is_batchnorm = isinstance(module, torch.nn.modules.batchnorm._BatchNorm)
    if is_batchnorm:
        momenta.update({module: module.momentum})
def _set_momenta(module, momenta):
    """
    Restore the momentum of a BatchNorm layer from a saved mapping.

    :param module: module to update; anything that is not a ``_BatchNorm``
        instance is ignored.
    :param momenta: dict mapping BatchNorm modules to the momentum value to
        assign; a BatchNorm module missing from it raises ``KeyError``.
    :return: None
    """
    if not isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        return
    saved_momentum = momenta[module]
    module.momentum = saved_momentum
def bn_update(loader, model):
    """
    BatchNorm buffers update (if any).
    Performs 1 epoch to estimate buffers average using train dataset.

    The model is switched to train mode, the running statistics of every
    BatchNorm layer are reset, and one pass over ``loader`` is run. Before
    each batch the momentum of every BatchNorm layer is set to
    ``batch_size / samples_seen_so_far`` (counting the current batch). The
    original momenta are put back afterwards, even if the pass raises. The
    model is left in train mode. Nothing is done when the model has no
    BatchNorm layer.

    :param loader: train dataset loader for buffers average estimation; each
        item must unpack into two values, the first being the input batch
        tensor (the second is ignored).
    :param model: model being updated.
    :return: None
    """
    if not check_bn(model):
        return

    model.train()
    momenta = {}
    model.apply(reset_bn)
    model.apply(lambda module: _get_momenta(module, momenta))

    # Run the batches on the device the model actually lives on instead of
    # unconditionally calling ``.cuda()``, which breaks on CPU-only hosts and
    # for models placed on a non-default GPU. Fall back to buffers for models
    # without parameters (e.g. BatchNorm layers built with ``affine=False``).
    device = next((param.device for param in model.parameters()), None)
    if device is None:
        device = next((buf.device for buf in model.buffers()), None)

    seen = 0
    try:
        # Only the BatchNorm buffers matter here; outputs are discarded, so
        # there is no reason to record an autograd graph.
        with torch.no_grad():
            for batch, _ in loader:
                if device is not None:
                    batch = batch.to(device)
                batch_size = batch.size(0)

                # Weight of the current batch in the running average.
                momentum = batch_size / (seen + batch_size)
                for module in momenta:
                    module.momentum = momentum

                model(batch)
                seen += batch_size
    finally:
        # Always hand the layers back with their original momenta.
        model.apply(lambda module: _set_momenta(module, momenta))