# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
class bn_track_stats:
    """Context manager that can temporarily switch off BatchNorm running-stat tracking.

    When ``condition`` is falsy, entering the context sets
    ``track_running_stats = False`` on every ``BatchNorm1d``/``BatchNorm2d``
    submodule of ``module``. Exiting restores each layer's *previous* value.
    When ``condition`` is truthy, the context leaves every layer untouched.

    Restoring the previous value means two things:
      * layers constructed with ``track_running_stats=False`` are left as they were;
      * nested contexts behave correctly.

    Args:
        module: network whose BatchNorm1d/BatchNorm2d submodules are affected.
        condition: if True (default), do nothing; if False, disable
            running-statistics tracking for the duration of the ``with`` block.
    """

    _BN_TYPES = (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)

    def __init__(self, module: nn.Module, condition=True):
        self.module = module
        self.enable = condition
        # Stack with one entry per active __enter__. Each entry is a list of
        # (layer, previous track_running_stats) pairs. A stack makes re-entering
        # the same instance restore states in the correct order.
        self._saved = []

    def __enter__(self):
        saved = []
        if not self.enable:
            for m in self.module.modules():
                if isinstance(m, self._BN_TYPES):
                    saved.append((m, m.track_running_stats))
                    m.track_running_stats = False
        self._saved.append(saved)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Put back each layer's own prior setting rather than forcing True.
        for m, previous in self._saved.pop():
            m.track_running_stats = previous
        # Never suppress an exception raised inside the with-block.
        return False