|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Batch norm variants |
|
|
AlphaBatchNorm builds upon: https://github.com/bethgelab/robustness/blob/main/robusta/batchnorm/bn.py |
|
|
""" |
|
|
|
|
|
from torch import nn |
|
|
from torch.nn import functional as F |
|
|
|
|
|
|
|
|
class AlphaBatchNorm(nn.Module):
    """Use the source statistics as a prior on the target statistics.

    Wraps a frozen source ``nn.BatchNorm2d`` layer and, at forward time,
    normalizes the input with a convex combination of the source layer's
    running statistics and the current batch's statistics:

        stats = (1 - alpha) * source_stats + alpha * batch_stats
    """

    @staticmethod
    def find_bns(parent, alpha):
        """Recursively collect all ``nn.BatchNorm2d`` descendants of ``parent``.

        Returns a list of ``(parent, name, replacement)`` triples where
        ``replacement`` is an ``AlphaBatchNorm`` wrapping the found layer.
        Returns ``[]`` when ``parent`` is ``None``.
        """
        replace_mods = []
        if parent is None:
            return []
        for name, child in parent.named_children():
            if isinstance(child, nn.BatchNorm2d):
                module = AlphaBatchNorm(child, alpha)
                replace_mods.append((parent, name, module))
            else:
                # Recurse into non-BN submodules to find nested BN layers.
                replace_mods.extend(AlphaBatchNorm.find_bns(child, alpha))

        return replace_mods

    @staticmethod
    def adapt_model(model, alpha):
        """Replace every ``nn.BatchNorm2d`` in ``model`` with an
        ``AlphaBatchNorm`` (in place) and return the modified model."""
        replace_mods = AlphaBatchNorm.find_bns(model, alpha)
        print(f"| Found {len(replace_mods)} modules to be replaced.")
        for (parent, name, child) in replace_mods:
            setattr(parent, name, child)
        return model

    def __init__(self, layer, alpha):
        """
        Args:
            layer: the source ``nn.BatchNorm2d`` layer to wrap; it is put in
                eval mode so its own running statistics stay frozen.
            alpha: interpolation weight in [0, 1]; 0 uses only the source
                statistics, 1 uses only the current batch statistics.

        Raises:
            ValueError: if ``alpha`` is outside [0, 1].
        """
        # Validate with an explicit exception: `assert` is stripped under
        # `python -O`, which would silently accept invalid alphas.
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {alpha}")

        super().__init__()
        self.layer = layer
        self.layer.eval()
        self.alpha = alpha

        # momentum=1.0 makes self.norm's running stats equal to the
        # statistics of the most recent batch only (no averaging over time).
        self.norm = nn.BatchNorm2d(self.layer.num_features, affine=False, momentum=1.0)

    def forward(self, input):
        # Side effect only: refresh self.norm's running stats with this
        # batch (effective while this module is in training mode); the
        # output of this call is discarded.
        self.norm(input)

        # Convex combination of frozen source stats and current batch stats.
        running_mean = ((1 - self.alpha) * self.layer.running_mean + self.alpha * self.norm.running_mean)
        running_var = ((1 - self.alpha) * self.layer.running_var + self.alpha * self.norm.running_var)

        return F.batch_norm(
            input,
            running_mean,
            running_var,
            self.layer.weight,
            self.layer.bias,
            False,  # training=False: use the interpolated stats, do not update them
            0,
            self.layer.eps,
        )
|
|
|
|
|
|
|
|
class EMABatchNorm(nn.Module):
    """Refresh batch-norm running statistics with each incoming batch.

    Every forward call first runs the wrapped model in training mode so its
    batch-norm layers fold the current batch into their running averages,
    then re-evaluates the same batch in eval mode with the updated stats.
    """

    @staticmethod
    def adapt_model(model):
        """Return ``model`` wrapped in an ``EMABatchNorm``."""
        return EMABatchNorm(model)

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        # Pass 1 (train mode): update running statistics; output discarded.
        self.model.train()
        self.model(x)
        # Pass 2 (eval mode): normalize using the refreshed statistics.
        self.model.eval()
        return self.model(x)
|
|
|