Spaces:
Running
Running
| import torch | |
| import torch.nn.functional as F | |
| import torchvision | |
| from torch import nn | |
| from torchvision import transforms | |
| from scripts.dynamic.blocks import * | |
class Normalize:
    """Per-channel standardisation: ``(x - mean) / scale`` on the leading channels.

    Only the first ``opt.input_channel`` channels (axis 1) are transformed; any
    further channels are copied through unchanged. The input is never modified
    in place — a clone is returned.

    Args:
        opt: options object; only ``opt.input_channel`` is read.
        expected_values: per-channel means, indexable by channel.
        variance: per-channel divisors. Despite the name, the value is divided
            in directly (i.e. it acts as a standard deviation, not a variance).
    """

    def __init__(self, opt, expected_values, variance):
        self.n_channels = opt.input_channel
        self.expected_values, self.variance = expected_values, variance
        # One mean per normalised channel is required.
        assert self.n_channels == len(self.expected_values)

    def __call__(self, x):
        out = x.clone()
        for c in range(self.n_channels):
            mean_c = self.expected_values[c]
            scale_c = self.variance[c]
            out[:, c] = (x[:, c] - mean_c) / scale_c
        return out
class Denormalize:
    """Inverse of per-channel standardisation: ``x * scale + mean`` on the leading channels.

    Only the first ``opt.input_channel`` channels (axis 1) are transformed; any
    further channels are copied through unchanged. The input is never modified
    in place — a clone is returned.

    Args:
        opt: options object; only ``opt.input_channel`` is read.
        expected_values: per-channel means, indexable by channel.
        variance: per-channel multipliers. Despite the name, the value is used
            directly as a scale (a standard deviation, not a variance).
    """

    def __init__(self, opt, expected_values, variance):
        self.n_channels = opt.input_channel
        self.expected_values, self.variance = expected_values, variance
        # One mean per de-normalised channel is required.
        assert self.n_channels == len(self.expected_values)

    def __call__(self, x):
        out = x.clone()
        for c in range(self.n_channels):
            scale_c = self.variance[c]
            mean_c = self.expected_values[c]
            out[:, c] = x[:, c] * scale_c + mean_c
        return out
| # ---------------------------- Generators ----------------------------# | |
class Generator(nn.Sequential):
    """Convolutional encoder/decoder that maps an image batch to a pattern in (0, 1).

    The network is a plain ``nn.Sequential`` (no skip connections), built as:

    * ``steps`` encoder stages of ``Conv2dBlock -> Conv2dBlock -> DownSampleBlock``,
      doubling the channel count after every stage except the last;
    * one middle ``Conv2dBlock``;
    * ``steps`` decoder stages of ``UpSampleBlock -> Conv2dBlock -> Conv2dBlock``,
      halving the channel count; the very last block is built with
      ``relu=False`` and emits ``out_channels`` channels
      (``opt.input_channel`` when ``out_channels`` is None).

    MNIST uses 16 base channels and 2 stages; every other dataset uses 32 base
    channels and 3 stages. NOTE(review): this assumes ``UpSampleBlock`` undoes
    ``DownSampleBlock`` spatially — confirm in ``scripts.dynamic.blocks``.

    Args:
        opt: options object; ``opt.dataset`` and ``opt.input_channel`` are read.
        out_channels: number of output channels; defaults to ``opt.input_channel``.

    Raises:
        ValueError: if ``opt.dataset`` is not one of "cifar10", "mnist", "gtsrb".
    """

    # Per-dataset (mean, std) used to build the (de)normalizers. ``None`` means
    # the dataset is used without any normalisation (identity).
    _DATASET_STATS = {
        "cifar10": ((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
        "mnist": ((0.5,), (0.5,)),
        "gtsrb": None,
    }

    def __init__(self, opt, out_channels=None):
        super(Generator, self).__init__()
        # MNIST gets a narrower, shallower network than the other datasets.
        if opt.dataset == "mnist":
            channel_init = 16
            steps = 2
        else:
            channel_init = 32
            steps = 3

        # Encoder: two convs then a downsample per stage; widen between stages.
        channel_current = opt.input_channel
        channel_next = channel_init
        for step in range(steps):
            self.add_module("convblock_down_{}".format(2 * step), Conv2dBlock(channel_current, channel_next))
            self.add_module("convblock_down_{}".format(2 * step + 1), Conv2dBlock(channel_next, channel_next))
            self.add_module("downsample_{}".format(step), DownSampleBlock())
            if step < steps - 1:
                channel_current = channel_next
                channel_next *= 2

        self.add_module("convblock_middle", Conv2dBlock(channel_next, channel_next))

        # Decoder: upsample then two convs per stage; narrow between stages.
        channel_current = channel_next
        channel_next = channel_current // 2
        for step in range(steps):
            self.add_module("upsample_{}".format(step), UpSampleBlock())
            self.add_module("convblock_up_{}".format(2 * step), Conv2dBlock(channel_current, channel_current))
            if step == steps - 1:
                # Last layer is built without ReLU, presumably so the tanh in
                # ``forward`` receives signed activations.
                self.add_module(
                    "convblock_up_{}".format(2 * step + 1), Conv2dBlock(channel_current, channel_next, relu=False)
                )
            else:
                self.add_module("convblock_up_{}".format(2 * step + 1), Conv2dBlock(channel_current, channel_next))
            channel_current = channel_next
            channel_next = channel_next // 2
            if step == steps - 2:
                # The final stage must emit the requested number of channels.
                channel_next = opt.input_channel if out_channels is None else out_channels

        # Shrinks the tanh range slightly so outputs stay strictly inside (0, 1).
        self._EPSILON = 1e-7
        self._normalizer = self._get_normalize(opt)
        self._denormalizer = self._get_denormalize(opt)

    def _dataset_stats(self, opt):
        """Return ``(mean, std)`` lists for ``opt.dataset``, or None if it is not normalised.

        Raises:
            ValueError: if the dataset is unknown.
        """
        try:
            stats = self._DATASET_STATS[opt.dataset]
        except KeyError:
            raise ValueError("Invalid dataset") from None
        if stats is None:
            return None
        mean, std = stats
        # Fresh lists so no two (de)normalizers share mutable state.
        return list(mean), list(std)

    def _get_denormalize(self, opt):
        """Build the ``Denormalize`` for ``opt.dataset`` (None for gtsrb).

        Raises:
            ValueError: if the dataset is unknown.
        """
        stats = self._dataset_stats(opt)
        return None if stats is None else Denormalize(opt, *stats)

    def _get_normalize(self, opt):
        """Build the ``Normalize`` for ``opt.dataset`` (None for gtsrb).

        Raises:
            ValueError: if the dataset is unknown.
        """
        stats = self._dataset_stats(opt)
        return None if stats is None else Normalize(opt, *stats)

    def forward(self, x):
        """Run all registered blocks in order, then squash the result into (0, 1)."""
        for module in self.children():
            x = module(x)
        return torch.tanh(x) / (2 + self._EPSILON) + 0.5

    def normalize_pattern(self, x):
        """Apply the dataset normaliser if there is one; otherwise return ``x`` unchanged."""
        if self._normalizer is not None:
            x = self._normalizer(x)
        return x

    def denormalize_pattern(self, x):
        """Apply the dataset de-normaliser if there is one; otherwise return ``x`` unchanged."""
        if self._denormalizer is not None:
            x = self._denormalizer(x)
        return x

    def threshold(self, x):
        """Soft step at 0.5: ``tanh(20 * (x - 0.5))`` rescaled into (0, 1)."""
        return torch.tanh(x * 20 - 10) / (2 + self._EPSILON) + 0.5
| # ---------------------------- Classifiers ----------------------------# | |
class NetC_MNIST(nn.Module):
    """Small CNN classifier with 1 input channel and 10 output logits.

    ``linear6`` expects ``64 * 4 * 4`` features, which is what two
    (5x5 valid conv -> 2x2 max-pool) stages produce from a 28x28 input.
    ``forward`` simply runs the child modules in registration order.
    """

    def __init__(self):
        super(NetC_MNIST, self).__init__()
        # Names are kept verbatim because they are the state_dict keys, and the
        # tuple order is both the construction order and the execution order.
        layers = (
            ("conv1", nn.Conv2d(1, 32, (5, 5), 1, 0)),
            ("relu2", nn.ReLU(inplace=True)),
            ("dropout3", nn.Dropout(0.1)),
            ("maxpool4", nn.MaxPool2d((2, 2))),
            ("conv5", nn.Conv2d(32, 64, (5, 5), 1, 0)),
            ("relu6", nn.ReLU(inplace=True)),
            ("dropout7", nn.Dropout(0.1)),
            ("maxpool5", nn.MaxPool2d((2, 2))),
            ("flatten", nn.Flatten()),
            ("linear6", nn.Linear(64 * 4 * 4, 512)),
            ("relu7", nn.ReLU(inplace=True)),
            ("dropout8", nn.Dropout(0.1)),
            ("linear9", nn.Linear(512, 10)),
        )
        for layer_name, layer in layers:
            self.add_module(layer_name, layer)

    def forward(self, x):
        out = x
        for layer in self.children():
            out = layer(out)
        return out