# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn


class bn_track_stats:
    """Context manager that optionally suspends BatchNorm running-stat tracking.

    When ``condition`` is False, every ``BatchNorm1d`` / ``BatchNorm2d``
    submodule of ``module`` gets ``track_running_stats = False`` on entry, and
    each of those layers is restored to the value it had before entry on exit.
    When ``condition`` is True the context manager does nothing.

    Restoring the saved value (instead of forcing True) keeps nested uses
    correct and leaves layers that were already non-tracking untouched.

    Args:
        module: network whose BatchNorm layers are affected.
        condition: if True, leave the layers alone; if False, disable
            running-stat tracking for the duration of the ``with`` block.
    """

    # Layer types whose ``track_running_stats`` flag is toggled.
    _BN_TYPES = (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)

    def __init__(self, module: nn.Module, condition: bool = True):
        self.module = module
        self.enable = condition
        # (layer, previous track_running_stats) pairs recorded by __enter__.
        self._saved = []

    def __enter__(self):
        if not self.enable:
            # Snapshot the current flags first so __exit__ can undo exactly
            # what we change, even when contexts are nested.
            self._saved = [
                (m, m.track_running_stats)
                for m in self.module.modules()
                if isinstance(m, self._BN_TYPES)
            ]
            for m, _ in self._saved:
                m.track_running_stats = False
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Runs on both normal and exceptional exit; returning None lets any
        # exception raised inside the block propagate unchanged.
        if not self.enable:
            for m, previous in self._saved:
                m.track_running_stats = previous
            self._saved = []