Spaces:
Running
Running
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
class ResidualDoubleConv(nn.Module):
    """Two 3x3 Conv-BN-GELU stages with an optional normalized residual skip.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by both conv stages.
        is_residual: when True, add the input back onto the output and
            rescale by 1/sqrt(2) to keep the activation variance stable.
    """

    # sqrt(2), precomputed in pure Python — no need to call numpy for a scalar.
    _SQRT2 = 2.0 ** 0.5

    def __init__(self, in_channels, out_channels, is_residual=False):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )
        self.is_same_channels = in_channels == out_channels
        self.is_residual = is_residual
        # A 1x1 projection is only needed when the skip path must change
        # channel count; otherwise the input is added through unchanged.
        if is_residual and not self.is_same_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1,
                                      stride=1, padding=0, bias=False)
        else:
            self.shortcut = None

    def forward(self, x):
        out = self.conv(x)
        if not self.is_residual:
            return out
        # Out-of-place addition: the original `out += x` mutates a tensor on
        # the autograd graph, which can trip version tracking when the operand
        # is saved for backward.
        skip = x if self.is_same_channels else self.shortcut(x)
        return (out + skip) / self._SQRT2  # normalize residual flow
class UpSample(nn.Module):
    """Decoder stage: concatenate with the skip tensor, transpose-conv
    upsample by 2x, then refine with two double-conv blocks."""

    def __init__(self, in_channels, out_channels):
        super(UpSample, self).__init__()
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            ResidualDoubleConv(out_channels, out_channels),
            ResidualDoubleConv(out_channels, out_channels),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x, skip):
        # Channel-wise concat happens BEFORE upsampling, so `in_channels`
        # must equal x.channels + skip.channels.
        merged = torch.cat((x, skip), 1)
        return self.conv(merged)
class DownSample(nn.Module):
    """Encoder stage: two double-conv blocks followed by 2x max-pooling.

    Residual connections, when enabled, live inside ResidualDoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super(DownSample, self).__init__()
        stages = (
            ResidualDoubleConv(in_channels, out_channels),
            ResidualDoubleConv(out_channels, out_channels),
            nn.MaxPool2d(2),
        )
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)
class EmbedFC(nn.Module):
    """Project an input of `input_dim` features into an embedding vector
    via a two-layer GELU MLP."""

    def __init__(self, input_dim, embed_dim):
        super(EmbedFC, self).__init__()
        self.input_dim = input_dim
        self.fc = nn.Sequential(
            nn.Linear(input_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim),
        )

    def forward(self, x):
        # Collapse any leading shape down to (batch, input_dim) so callers
        # may pass multi-dimensional conditioning tensors.
        flat = x.view(-1, self.input_dim)
        return self.fc(flat)