Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
class CNNBlock(nn.Module):
    """Convolution -> batch normalisation -> LeakyReLU building block.

    Args:
        in_channels: Number of channels in the incoming feature map.
        out_channels: Number of channels produced by the convolution.
        stride: Stride of the 4x4 convolution.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int) -> None:
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=4,
                stride=stride,
                padding=1,
                # No conv bias: the BatchNorm2d right after it is affine by
                # default and therefore already provides a learnable shift.
                bias=False,
                padding_mode="reflect",
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the conv/norm/activation stack and return the result."""
        return self.conv(x)
class Discriminator(nn.Module):
    """Convolutional discriminator that scores a pair of images.

    The two inputs are concatenated along the channel axis, passed through an
    initial strided convolution, a stack of ``CNNBlock`` layers, and a final
    convolution that emits a single-channel score map.

    Args:
        in_channels: Channel count of EACH input image; the first convolution
            therefore receives ``in_channels * 2`` channels.
        features: Output channel widths. ``features[0]`` is the width of the
            initial convolution; every following entry adds one ``CNNBlock``.
            Any sequence is accepted (the default is an immutable tuple so it
            cannot be shared and mutated across instances).
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        # Snapshot into a private list so indexing/slicing works for any
        # sequence type and the caller's object is never touched.
        features = list(features)

        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels * 2,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2),
        )

        layers = []
        prev_channels = features[0]
        last_index = len(features) - 1
        for index, width in enumerate(features[1:], start=1):
            # Only the block at the LAST position uses stride 1. Deciding by
            # position (not by comparing widths) keeps earlier blocks at
            # stride 2 even when a width repeats, e.g. (64, 128, 128).
            block_stride = 1 if index == last_index else 2
            layers.append(CNNBlock(prev_channels, width, stride=block_stride))
            prev_channels = width

        # Final projection down to a single output channel.
        layers.append(
            nn.Conv2d(
                prev_channels,
                1,
                kernel_size=4,
                stride=1,
                padding=1,
                padding_mode="reflect",
            ),
        )
        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        """Score the pair ``(x, y)``.

        Both tensors are concatenated on ``dim=1`` and so must agree on every
        other dimension.
        """
        x = torch.cat([x, y], dim=1)
        x = self.initial(x)
        x = self.model(x)
        return x
def test():
    """Smoke-test: push two random 1x3x256x256 tensors through a Discriminator.

    Prints the module structure followed by the shape of its output.
    """
    sample_shape = (1, 3, 256, 256)
    # Draw both inputs before building the network, first then second.
    first = torch.randn(sample_shape)
    second = torch.randn(sample_shape)

    net = Discriminator(in_channels=3)
    scores = net(first, second)

    for item in (net, scores.shape):
        print(item)
# Run the smoke test only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    test()