Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
class Upsample(nn.Module):
    """Double the spatial resolution of a feature map.

    The input is resized by a factor of 2 with nearest-neighbour
    interpolation.  When ``with_conv`` is true the result is then passed
    through a 3x3 convolution (stride 1, padding 1) that keeps both the
    channel count and the spatial size unchanged.

    Args:
        in_channels: Number of channels of the input; the optional
            convolution uses it for both its input and output channels.
        with_conv: Whether to apply the convolution after interpolation.
    """

    def __init__(self, in_channels: int, with_conv: bool) -> None:
        super().__init__()
        self.with_conv = with_conv
        # The conv submodule is only registered when requested, so a module
        # built with ``with_conv=False`` has no parameters at all.
        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels,
                in_channels,
                kernel_size=3,
                stride=1,
                padding=1,
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` upsampled by 2, optionally convolved.

        ``x`` is expected to be laid out as (N, C, H, W) -- required by
        ``Conv2d`` when ``with_conv`` is true.
        """
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x
class Downsample(nn.Module):
    """Halve the spatial resolution of a feature map.

    Two modes are available:

    * ``with_conv=True``: the input is zero-padded by one pixel on the right
      and bottom edges only, then passed through a 3x3 convolution with
      stride 2 and no padding of its own.
    * ``with_conv=False``: 2x2 average pooling with stride 2.

    In both modes an input of height/width ``H`` (``H >= 2``) yields an
    output of size ``H // 2``.

    Args:
        in_channels: Number of channels of the input; the optional
            convolution uses it for both its input and output channels.
        with_conv: Select the learned strided convolution instead of
            average pooling.
    """

    def __init__(self, in_channels: int, with_conv: bool) -> None:
        super().__init__()
        self.with_conv = with_conv
        # The conv submodule is only registered when requested, so a module
        # built with ``with_conv=False`` has no parameters at all.
        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels,
                in_channels,
                kernel_size=3,
                stride=2,
                padding=0,
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` downsampled by 2 along its last two dimensions."""
        if self.with_conv:
            # F.pad order is (left, right, top, bottom) for the last two
            # dims.  Padding is done by hand because Conv2d's ``padding``
            # argument pads both sides of a dimension equally, whereas here
            # only the right and bottom edges get one row/column of zeros.
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x