Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class Generator(nn.Module):
    """2D convolutional generator.

    The network is built from ``len(gk)`` layers. Every layer whose index is
    below ``len(gk) - 2`` is a ``ConvTranspose2d``; the remaining (at most two)
    layers are ``Conv2d`` with reflect padding. In ``forward`` the first of
    those plain convolutions is followed by a bilinear resize, and the very
    last convolution produces the output, which is passed through a softmax
    or a sigmoid.

    :param config: object whose ``image_type`` attribute is read in
        ``forward``; the value ``"n-phase"`` selects a softmax over dim 1,
        any other value selects a sigmoid
    :param gk: per-layer kernel sizes
    :param gs: per-layer strides
    :param gf: channel counts; layer ``i`` maps ``gf[i]`` to ``gf[i + 1]``
        channels, so one more entry than ``gk`` is needed
    :param gp: per-layer paddings
    """

    def __init__(self, config, gk, gs, gf, gp):
        super().__init__()
        self.config = config
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.no_layers = len(gk)
        # Index of the first layer that is a plain (non-transposed) convolution.
        first_plain_conv = self.no_layers - 2
        for lay, (k, s, p) in enumerate(zip(gk, gs, gp)):
            if lay < first_plain_conv:
                conv = nn.ConvTranspose2d(
                    gf[lay], gf[lay + 1], k, s, p, bias=False
                )
            else:
                conv = nn.Conv2d(
                    gf[lay],
                    gf[lay + 1],
                    k,
                    s,
                    p,
                    bias=False,
                    padding_mode="reflect",
                )
            self.convs.append(conv)
            # One BatchNorm2d per layer. The last one is never applied in
            # forward(); it is still created so that the module layout and
            # state_dict keys stay exactly as they were.
            self.bns.append(nn.BatchNorm2d(gf[lay + 1]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the generator on ``x``.

        :param x: input tensor fed to the first transposed convolution
        :return: output of the last convolution after softmax (``"n-phase"``)
            or sigmoid (any other image type)
        """
        # Pair each hidden conv with its batch norm; the final conv is handled
        # separately below. Using a plain list avoids building a new
        # nn.ModuleList (what slicing a ModuleList does) on every call.
        hidden = list(zip(self.convs, self.bns))[:-1]
        first_plain_conv = self.no_layers - 2
        for idx, (conv, bn) in enumerate(hidden):
            x = conv(x)
            if idx >= first_plain_conv:
                # Plain-conv stage: resize to twice the spatial size plus 2
                # before normalisation.
                x = F.interpolate(
                    x,
                    [x.shape[-2] * 2 + 2, x.shape[-1] * 2 + 2],
                    mode="bilinear",
                    align_corners=False,
                )
            x = F.relu_(bn(x))

        final = self.convs[-1](x)
        if self.config.image_type == "n-phase":
            return torch.softmax(final, dim=1)
        # All layers are 2D, so the result is bs x n x height x width.
        return torch.sigmoid(final)
class Discriminator(nn.Module):
    """Stack of bias-free 2D convolutions with ReLU between them.

    :param dk: per-layer kernel sizes
    :param ds: per-layer strides
    :param dp: per-layer paddings
    :param df: channel counts; layer ``i`` maps ``df[i]`` to ``df[i + 1]``
        channels
    """

    def __init__(self, dk, ds, dp, df):
        super().__init__()
        layers = [
            nn.Conv2d(df[idx], df[idx + 1], kernel, stride, pad, bias=False)
            for idx, (kernel, stride, pad) in enumerate(zip(dk, ds, dp))
        ]
        self.convs = nn.ModuleList(layers)

    def forward(self, x):
        """Apply every convolution in order; all but the last are followed by ReLU.

        :param x: input tensor for the first convolution
        :return: raw (un-activated) output of the last convolution
        """
        layers = list(self.convs)
        hidden, head = layers[:-1], layers[-1]
        for layer in hidden:
            x = F.relu_(layer(x))
        return head(x)
def make_nets(config, training=True):
    """Build the Discriminator and Generator described by ``config``.

    Calls ``config.save()`` when ``training`` is true and ``config.load()``
    otherwise, then reads the layer parameters from
    ``config.get_net_params()``.

    :param config: a Config class object
    :type config: Config
    :param training: selects ``config.save()`` (True) or ``config.load()``
        (False) before the parameters are read, defaults to True
    :type training: bool, optional
    :return: Discriminator and Generator class objects, in that order
    :rtype: Discriminator, Generator
    """
    # Persist the params when training, otherwise restore them.
    sync_params = config.save if training else config.load
    sync_params()

    dk, ds, df, dp, gk, gs, gf, gp = config.get_net_params()

    # The discriminator is constructed first, then the generator.
    discriminator = Discriminator(dk, ds, dp, df)
    generator = Generator(config, gk, gs, gf, gp)
    return discriminator, generator