import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision import transforms

from scripts.dynamic.blocks import *

# Per-dataset (mean, scale) pairs consumed by Normalize / Denormalize. ``None``
# marks a dataset for which Generator applies no (de)normalization at all.
# Keeping a single table guarantees the normalizer and denormalizer agree.
_DATASET_STATS = {
    "cifar10": ((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
    "mnist": ((0.5,), (0.5,)),
    "gtsrb": None,
}


class _PerChannelTransform:
    """Shared state for the per-channel affine transforms below.

    Args:
        opt: options object; only ``opt.input_channel`` is read.
        expected_values: per-channel offsets, one entry per input channel.
        variance: per-channel scale factors. Despite the name, each value is
            used directly as a divisor / multiplier (it is never squared or
            square-rooted); the CIFAR-10 values passed by Generator look like
            standard deviations -- TODO confirm.
    """

    def __init__(self, opt, expected_values, variance):
        self.n_channels = opt.input_channel
        self.expected_values = expected_values
        self.variance = variance
        assert self.n_channels == len(self.expected_values), "expected {} per-channel values, got {}".format(
            self.n_channels, len(self.expected_values)
        )


class Normalize(_PerChannelTransform):
    """Return a copy of ``x`` with ``x[:, c] -> (x[:, c] - expected_values[c]) / variance[c]``."""

    def __call__(self, x):
        # Write into a clone so the caller's tensor is never modified in place.
        x_clone = x.clone()
        for channel in range(self.n_channels):
            x_clone[:, channel] = (x[:, channel] - self.expected_values[channel]) / self.variance[channel]
        return x_clone


class Denormalize(_PerChannelTransform):
    """Inverse of Normalize: ``x[:, c] -> x[:, c] * variance[c] + expected_values[c]`` on a copy."""

    def __call__(self, x):
        # Write into a clone so the caller's tensor is never modified in place.
        x_clone = x.clone()
        for channel in range(self.n_channels):
            x_clone[:, channel] = x[:, channel] * self.variance[channel] + self.expected_values[channel]
        return x_clone


# ---------------------------- Generators ----------------------------#
class Generator(nn.Sequential):
    """Convolutional encoder/decoder whose output lies in the open interval (0, 1).

    The network is a stack of ``Conv2dBlock`` / ``DownSampleBlock`` /
    ``UpSampleBlock`` modules from ``scripts.dynamic.blocks`` (presumably the
    down/up blocks halve/double spatial size -- confirm there). MNIST uses 2
    down/up steps starting at 16 channels; every other dataset uses 3 steps
    starting at 32 channels.

    Submodule names are state-dict keys and ``forward`` runs the children in
    registration order, so neither the names nor the construction order may
    change without breaking saved checkpoints.

    Args:
        opt: options object; ``opt.dataset`` and ``opt.input_channel`` are read.
        out_channels: channel count of the final block; defaults to
            ``opt.input_channel``.

    Raises:
        ValueError: if ``opt.dataset`` is not ``"cifar10"``, ``"mnist"`` or
            ``"gtsrb"`` (raised while building the (de)normalizers).
    """

    def __init__(self, opt, out_channels=None):
        super(Generator, self).__init__()
        if opt.dataset == "mnist":
            channel_init = 16
            steps = 2
        else:
            channel_init = 32
            steps = 3

        # Encoder: two conv blocks then a downsample per step; the channel
        # count doubles between steps (but not after the last one).
        channel_current = opt.input_channel
        channel_next = channel_init
        for step in range(steps):
            self.add_module("convblock_down_{}".format(2 * step), Conv2dBlock(channel_current, channel_next))
            self.add_module("convblock_down_{}".format(2 * step + 1), Conv2dBlock(channel_next, channel_next))
            self.add_module("downsample_{}".format(step), DownSampleBlock())
            if step < steps - 1:
                channel_current = channel_next
                channel_next *= 2

        self.add_module("convblock_middle", Conv2dBlock(channel_next, channel_next))

        # Decoder: upsample then two conv blocks per step, halving channels.
        channel_current = channel_next
        channel_next = channel_current // 2
        for step in range(steps):
            self.add_module("upsample_{}".format(step), UpSampleBlock())
            self.add_module("convblock_up_{}".format(2 * step), Conv2dBlock(channel_current, channel_current))
            if step == steps - 1:
                # Final block is built with relu=False; its output is squashed
                # by tanh in forward() instead.
                self.add_module(
                    "convblock_up_{}".format(2 * step + 1), Conv2dBlock(channel_current, channel_next, relu=False)
                )
            else:
                self.add_module("convblock_up_{}".format(2 * step + 1), Conv2dBlock(channel_current, channel_next))
            channel_current = channel_next
            channel_next = channel_next // 2
            if step == steps - 2:
                # The next iteration builds the final block, which must emit the
                # requested number of output channels.
                channel_next = opt.input_channel if out_channels is None else out_channels

        # Keeps the tanh-based outputs strictly inside (0, 1).
        self._EPSILON = 1e-7
        self._normalizer = self._get_normalize(opt)
        self._denormalizer = self._get_denormalize(opt)

    @staticmethod
    def _dataset_stats(opt):
        """Return the ``(mean, scale)`` pair for ``opt.dataset``, or ``None`` for no normalization.

        Raises:
            ValueError: if ``opt.dataset`` is not a known dataset.
        """
        try:
            return _DATASET_STATS[opt.dataset]
        except KeyError:
            raise ValueError("Invalid dataset: {!r}".format(opt.dataset)) from None

    def _get_denormalize(self, opt):
        """Build the Denormalize transform for ``opt.dataset`` (``None`` for gtsrb)."""
        stats = self._dataset_stats(opt)
        if stats is None:
            return None
        mean, scale = stats
        # Fresh lists so the normalizer and denormalizer never share state.
        return Denormalize(opt, list(mean), list(scale))

    def _get_normalize(self, opt):
        """Build the Normalize transform for ``opt.dataset`` (``None`` for gtsrb)."""
        stats = self._dataset_stats(opt)
        if stats is None:
            return None
        mean, scale = stats
        # Fresh lists so the normalizer and denormalizer never share state.
        return Normalize(opt, list(mean), list(scale))

    def _squash(self, x):
        """Map ``x`` elementwise into the open interval (0, 1) with a scaled tanh."""
        return torch.tanh(x) / (2 + self._EPSILON) + 0.5

    def forward(self, x):
        """Run every child module in registration order, then squash into (0, 1)."""
        for module in self.children():
            x = module(x)
        return self._squash(x)

    def normalize_pattern(self, x):
        """Apply the dataset normalizer to ``x``; return ``x`` unchanged if there is none."""
        if self._normalizer is not None:
            x = self._normalizer(x)
        return x

    def denormalize_pattern(self, x):
        """Apply the dataset denormalizer to ``x``; return ``x`` unchanged if there is none."""
        if self._denormalizer is not None:
            x = self._denormalizer(x)
        return x

    def threshold(self, x):
        """Soft step centred at 0.5: ``tanh(20 * x - 10)`` rescaled into (0, 1)."""
        return self._squash(x * 20 - 10)


# ---------------------------- Classifiers ----------------------------#
class NetC_MNIST(nn.Module):
    """Small CNN classifier with 10 output logits.

    Expects single-channel inputs (``conv1`` takes 1 channel). The flatten
    size ``64 * 4 * 4`` follows from two unpadded 5x5 convs, each followed by
    2x2 pooling, on 28x28 inputs.

    ``forward`` applies the children in attribute-registration order, and the
    attribute names are state-dict keys: do not reorder or rename them, even
    though the numeric suffixes are not consecutive.
    """

    def __init__(self):
        super(NetC_MNIST, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, (5, 5), 1, 0)
        self.relu2 = nn.ReLU(inplace=True)
        self.dropout3 = nn.Dropout(0.1)

        self.maxpool4 = nn.MaxPool2d((2, 2))
        self.conv5 = nn.Conv2d(32, 64, (5, 5), 1, 0)
        self.relu6 = nn.ReLU(inplace=True)
        self.dropout7 = nn.Dropout(0.1)

        self.maxpool5 = nn.MaxPool2d((2, 2))
        self.flatten = nn.Flatten()
        self.linear6 = nn.Linear(64 * 4 * 4, 512)
        self.relu7 = nn.ReLU(inplace=True)
        self.dropout8 = nn.Dropout(0.1)
        self.linear9 = nn.Linear(512, 10)

    def forward(self, x):
        """Apply every child module in registration order."""
        for module in self.children():
            x = module(x)
        return x