| | """ |
| | Ported from Paella |
| | """ |
| |
|
| | import torch |
| | from torch import nn |
| |
|
| | from diffusers.configuration_utils import ConfigMixin, register_to_config |
| | from diffusers.models.modeling_utils import ModelMixin |
| |
|
| |
|
| | |
class Discriminator(ModelMixin, ConfigMixin):
    """Strided convolutional discriminator producing a sigmoid score map.

    The encoder is a stack of spectrally-normalised 3x3 convolutions with
    stride 2, each followed by ``LeakyReLU(0.2)`` (and ``InstanceNorm2d`` for
    every layer after the first). A 1x1 convolution then maps the encoder
    features, optionally concatenated with a broadcast conditioning vector,
    to a single channel that is passed through a sigmoid.

    Args:
        in_channels: Number of channels of the input given to the first
            convolution.
        cond_channels: Number of conditioning channels concatenated to the
            encoder features before the 1x1 convolution. ``0`` means the
            head is sized for unconditioned use.
        hidden_channels: Channel width the encoder grows towards; each
            layer doubles the width until this value is reached.
        depth: Number of stride-2 convolution layers in the encoder
            (the first layer plus ``depth - 1`` further layers).
    """

    @register_to_config
    def __init__(self, in_channels=3, cond_channels=0, hidden_channels=512, depth=6):
        super().__init__()
        # Exponent of the initial channel reduction: the first layer emits
        # hidden_channels // 2**d channels, never fewer than three halvings.
        d = max(depth - 3, 3)
        # Width of the features currently leaving the encoder; updated as
        # layers are appended so the head below can be sized correctly.
        enc_channels = hidden_channels // (2**d)
        layers = [
            nn.utils.spectral_norm(
                nn.Conv2d(in_channels, enc_channels, kernel_size=3, stride=2, padding=1)
            ),
            nn.LeakyReLU(0.2),
        ]
        for i in range(depth - 1):
            # Each layer consumes the previous layer's width and doubles it,
            # saturating at hidden_channels once the exponent reaches zero.
            c_in = enc_channels
            c_out = hidden_channels // (2 ** max((d - 1 - i), 0))
            layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
            layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2))
            enc_channels = c_out
        self.encoder = nn.Sequential(*layers)
        # Size the head from the encoder's real output width. For depth >= 4
        # this equals hidden_channels (unchanged behaviour); for shallower
        # encoders the width is smaller and using hidden_channels here would
        # make forward() fail with a channel mismatch.
        self.shuffle = nn.Conv2d(
            (enc_channels + cond_channels) if cond_channels > 0 else enc_channels, 1, kernel_size=1
        )
        self.logits = nn.Sigmoid()

    def forward(self, x, cond=None):
        """Score ``x``, optionally conditioned on ``cond``.

        Args:
            x: Input tensor fed to the convolutional encoder
                (presumably ``(batch, in_channels, height, width)`` -
                confirm against callers).
            cond: Optional conditioning tensor. It is viewed as
                ``(cond.size(0), cond.size(1), 1, 1)`` and expanded over the
                spatial size of the encoder output, so it must hold exactly
                ``cond.size(0) * cond.size(1)`` elements.

        Returns:
            Single-channel tensor of sigmoid outputs, one value per spatial
            location of the encoder output.
        """
        x = self.encoder(x)
        if cond is not None:
            # Broadcast the per-sample conditioning vector to every spatial
            # location so it can be concatenated along the channel axis.
            cond = cond.view(
                cond.size(0),
                cond.size(1),
                1,
                1,
            ).expand(-1, -1, x.size(-2), x.size(-1))
            x = torch.cat([x, cond], dim=1)
        x = self.shuffle(x)
        x = self.logits(x)
        return x
| |
|