# Copyright 2020-2021 Evgenia Rusak, Steffen Schneider, George Pachitariu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# ---
# This licence notice applies to all originally written code by the
# authors. Code taken from other open-source projects is indicated.
# See NOTICE for a list of all third-party licences used in the project.
"""Batch norm variants
AlphaBatchNorm builds upon: https://github.com/bethgelab/robustness/blob/main/robusta/batchnorm/bn.py
"""
from torch import nn
from torch.nn import functional as F


class AlphaBatchNorm(nn.Module):
    """Use the source statistics as a prior on the target statistics."""

    @staticmethod
    def find_bns(parent, alpha):
        """Recursively collect every nn.BatchNorm2d module under `parent`."""
        replace_mods = []
        if parent is None:
            return []
        for name, child in parent.named_children():
            if isinstance(child, nn.BatchNorm2d):
                module = AlphaBatchNorm(child, alpha)
                replace_mods.append((parent, name, module))
            else:
                replace_mods.extend(AlphaBatchNorm.find_bns(child, alpha))
        return replace_mods

    @staticmethod
    def adapt_model(model, alpha):
        """Replace all nn.BatchNorm2d layers in `model` with AlphaBatchNorm."""
        replace_mods = AlphaBatchNorm.find_bns(model, alpha)
        print(f"| Found {len(replace_mods)} modules to be replaced.")
        for (parent, name, child) in replace_mods:
            setattr(parent, name, child)
        return model

    def __init__(self, layer, alpha):
        assert 0 <= alpha <= 1
        super().__init__()
        self.layer = layer
        self.layer.eval()
        self.alpha = alpha
        # With momentum=1.0, the running statistics of self.norm equal the
        # statistics of the most recent batch, i.e. the target statistics.
        self.norm = nn.BatchNorm2d(self.layer.num_features, affine=False, momentum=1.0)

    def forward(self, input):
        # Re-estimate the target statistics from the current batch;
        # the normalized output itself is discarded.
        self.norm(input)

        # Interpolate between the source statistics (self.layer, the prior)
        # and the target statistics (self.norm); alpha weights the target.
        running_mean = (1 - self.alpha) * self.layer.running_mean + self.alpha * self.norm.running_mean
        running_var = (1 - self.alpha) * self.layer.running_var + self.alpha * self.norm.running_var

        return F.batch_norm(
            input,
            running_mean,
            running_var,
            self.layer.weight,
            self.layer.bias,
            False,  # training=False: normalize with the blended statistics
            0,  # momentum is unused when training=False
            self.layer.eps,
        )


class EMABatchNorm(nn.Module):
    """Update the batch norm statistics by their exponential moving average at test time."""

    @staticmethod
    def adapt_model(model):
        model = EMABatchNorm(model)
        return model

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        # First pass in train mode: update the BN running statistics
        # with the current batch, but discard the result.
        self.model.train()
        self.model(x)
        # Second pass in eval mode: normalize with the updated statistics.
        self.model.eval()
        return self.model(x)
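
# Minimal smoke test (a sketch, not part of the original file): builds a tiny
# BatchNorm network and runs both adaptation variants on random inputs. The
# architecture and tensor shapes below are arbitrary illustrations.
if __name__ == "__main__":
    import torch

    def make_model():
        return nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(),
        )

    x = torch.randn(4, 3, 32, 32)  # batch of 4 RGB 32x32 images

    # AlphaBatchNorm: blend source and target statistics (alpha weights the target).
    alpha_model = AlphaBatchNorm.adapt_model(make_model().eval(), alpha=0.1)
    print(alpha_model(x).shape)  # torch.Size([4, 8, 32, 32])

    # EMABatchNorm: re-estimate the running statistics on each incoming batch.
    ema_model = EMABatchNorm.adapt_model(make_model())
    print(ema_model(x).shape)  # torch.Size([4, 8, 32, 32])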