import torch
import torch.nn as nn


class DownBlock(nn.Module):
    """Encoder block: a stride-2 convolution that halves the spatial size,
    optionally followed by instance norm, then LeakyReLU."""

    def __init__(self, in_filters, out_filters, normal=True):
        super().__init__()
        # The conv bias is redundant when affine instance norm follows.
        layers = [nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2,
                            padding=1, padding_mode='reflect', bias=not normal)]
        if normal:
            layers.append(nn.InstanceNorm2d(out_filters, affine=True))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)


class UpBlock(nn.Module):
    """Decoder block: a stride-2 transposed convolution that doubles the
    spatial size, with instance norm, ReLU, and optional dropout; the output
    is concatenated with the matching encoder skip tensor."""

    def __init__(self, in_filters, out_filters, dropout=0.0):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(in_filters, out_filters, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(out_filters, affine=True),
            nn.ReLU(inplace=True),
        ]
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip_input):
        x = self.model(x)
        # U-Net skip connection: stack along the channel dimension.
        x = torch.cat((x, skip_input), 1)
        return x
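
# Added note (not in the original source): UpBlock.forward returns its
# upsampled output concatenated with the skip tensor, so every decoder stage
# after the first receives out_filters plus the skip's channels as input;
# that is where the `input_channels * 2` factors in Generator.__init__ come from.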


class Generator(nn.Module):
    """Pix2Pix-style U-Net generator: seven DownBlocks, a bottleneck, and
    seven UpBlocks with skip connections, sized for 256x256 inputs."""

    def __init__(self, input_channels, features=(64, 128, 256, 512, 512, 512, 512)):
        super().__init__()
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        for idx, feature in enumerate(features):
            if idx == 0:
                # The first block skips normalization, per the pix2pix recipe.
                self.encoder.append(DownBlock(input_channels, feature, normal=False))
            else:
                self.encoder.append(DownBlock(input_channels, feature))
            input_channels = feature
        # 2x2 -> 1x1, with no norm over a single spatial position.
        self.bottleneck = DownBlock(features[-1], features[-1], normal=False)
        self.final = nn.Sequential(
            # The last skip concat leaves features[0] * 2 = 128 channels; map to RGB.
            nn.ConvTranspose2d(features[0] * 2, 3, 4, 2, 1),
            nn.Tanh(),
        )
        input_channels = features[-1]
        for idx, feature in enumerate(reversed(features)):
            if idx == 0:
                # The first up block takes the raw bottleneck output.
                self.decoder.append(UpBlock(input_channels, feature, dropout=0.5))
            elif idx < 3:
                self.decoder.append(UpBlock(input_channels * 2, feature, dropout=0.5))
            else:
                self.decoder.append(UpBlock(input_channels * 2, feature))
            input_channels = feature

    def forward(self, x):
        skips = []
        for layer in self.encoder:
            x = layer(x)
            skips.append(x)
        x = self.bottleneck(x)
        # The decoder consumes skip tensors deepest-first.
        skips = skips[::-1]
        for idx, layer in enumerate(self.decoder):
            x = layer(x, skips[idx])
        return self.final(x)
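

# Minimal smoke test (an addition, not part of the original file). With the
# default feature list, seven stride-2 DownBlocks plus the bottleneck reduce a
# 256x256 input to 1x1, and the decoder plus `final` mirror it back up, so the
# output should match the input resolution, squashed into [-1, 1] by Tanh.
if __name__ == "__main__":
    gen = Generator(input_channels=3)
    dummy = torch.randn(2, 3, 256, 256)  # a batch of two RGB images
    out = gen(dummy)
    print(out.shape)  # expected: torch.Size([2, 3, 256, 256])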