|
|
import torch.nn as nn |
|
|
from methods.source import Source |
|
|
from methods.bn import AlphaBatchNorm, EMABatchNorm |
|
|
from utils.registry import ADAPTATION_REGISTRY |
|
|
|
|
|
|
|
|
@ADAPTATION_REGISTRY.register()
class BNTest(Source):
    """Adaptation method that re-enables batch statistics in BatchNorm layers.

    All parameters stay frozen and every module is put in eval mode, except
    the BatchNorm layers selected in ``configure_model``, which are switched
    back to train mode.
    """

    def __init__(self, cfg, model, num_classes):
        # Kept as an explicit forwarder so the constructor signature stays
        # exactly (cfg, model, num_classes) for keyword callers.
        super().__init__(cfg, model, num_classes)

    def configure_model(self):
        """Freeze the model and put the eligible BatchNorm layers in train mode.

        BatchNorm2d and BatchNorm3d layers are always switched to train mode.
        BatchNorm1d layers are switched only when ``self.batch_size > 1``.
        NOTE(review): this is presumably because train-mode BatchNorm1d cannot
        form batch statistics from a single (N=1, C) sample; confirm.
        """
        self.model.eval()
        self.model.requires_grad_(False)

        for module in self.model.modules():
            # `self.batch_size` is only consulted for BatchNorm1d layers,
            # thanks to the short-circuiting `and`.
            spatial_bn = isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d))
            usable_bn1d = isinstance(module, nn.BatchNorm1d) and self.batch_size > 1
            if spatial_bn or usable_bn1d:
                # In PyTorch, train-mode BatchNorm normalizes with the
                # current batch's statistics instead of the running buffers.
                module.train()
|
|
|
|
|
|
|
|
@ADAPTATION_REGISTRY.register()
class BNAlpha(Source):
    """Adaptation method that replaces BatchNorm handling with ``AlphaBatchNorm``.

    The mixing weight is read from ``cfg.BN.ALPHA``.
    NOTE(review): the exact normalization rule lives in
    ``methods.bn.AlphaBatchNorm``; presumably it blends source and test-batch
    statistics by alpha; confirm there.
    """

    def __init__(self, cfg, model, num_classes):
        # Explicit forwarder; keeps the (cfg, model, num_classes) signature.
        super().__init__(cfg, model, num_classes)

    def configure_model(self):
        """Freeze the model, then wrap it with ``AlphaBatchNorm`` on ``self.device``."""
        self.model.eval()
        self.model.requires_grad_(False)

        alpha = self.cfg.BN.ALPHA
        adapted_model = AlphaBatchNorm.adapt_model(self.model, alpha=alpha)
        # The adapted model replaces the original and is moved to the
        # method's device.
        self.model = adapted_model.to(self.device)
|
|
|
|
|
|
|
|
@ADAPTATION_REGISTRY.register()
class BNEMA(Source):
    """Adaptation method that replaces BatchNorm handling with ``EMABatchNorm``.

    NOTE(review): the update rule is defined in ``methods.bn.EMABatchNorm``;
    the name suggests an exponential moving average of statistics; confirm
    there.
    """

    def __init__(self, cfg, model, num_classes):
        # Explicit forwarder; keeps the (cfg, model, num_classes) signature.
        super().__init__(cfg, model, num_classes)

    def configure_model(self):
        """Freeze the model, then wrap it with ``EMABatchNorm`` on ``self.device``."""
        self.model.eval()
        self.model.requires_grad_(False)

        adapted_model = EMABatchNorm.adapt_model(self.model)
        # The adapted model replaces the original and is moved to the
        # method's device.
        self.model = adapted_model.to(self.device)
|
|
|