# File size: 3,543 bytes; revision 372980e (web-viewer header; line-number gutter dropped).
import torch
import torch.nn as nn
from torch import Tensor
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolution stages, each followed by batch norm and ReLU.

    Both convolutions use ``kernel_size=3`` with ``padding=1`` and no bias
    term (the following batch norm supplies the affine shift).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the second convolution.
        mid_channels: Channels between the two convolutions. Falls back to
            ``out_channels`` when left as ``None``.
    """

    def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None):
        super().__init__()
        hidden = out_channels if mid_channels is None else mid_channels
        # Channel widths seen by the two stages: in -> hidden -> out.
        widths = (in_channels, hidden, out_channels)
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.extend(
                (
                    nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                    nn.BatchNorm2d(c_out),
                    nn.ReLU(inplace=True),
                )
            )
        # Positional Sequential keeps the sub-module names "0".."5".
        self.conv = nn.Sequential(*stages)

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through both conv / norm / activation stages."""
        return self.conv(x)
class Down(nn.Module):
    """Downsampling stage: 2x2 max pooling followed by a ``DoubleConv``.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the ``DoubleConv``.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # Pool first so the convolutions operate on the reduced grid.
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Pool ``x`` and then apply the convolution stack."""
        pooled_and_convolved = self.maxpool_conv(x)
        return pooled_and_convolved
class Up(nn.Module):
    """Upsampling stage: a 2x upsampler followed by a ``DoubleConv``.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the ``DoubleConv``.
        bilinear: When true, upsample with ``nn.Upsample`` (bilinear,
            ``align_corners=True``) and give the ``DoubleConv`` a hidden
            width of ``in_channels // 2``. Otherwise upsample with a
            stride-2 ``nn.ConvTranspose2d`` that keeps the channel count.
    """

    def __init__(self, in_channels: int, out_channels: int, bilinear: bool = False):
        super().__init__()
        if bilinear:
            upsampler = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            refine = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            upsampler = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
            refine = DoubleConv(in_channels, out_channels)
        # Register in the same order as construction: upsampler, then convs.
        self.up = upsampler
        self.conv = refine

    def forward(self, x):
        """Upsample ``x`` and then apply the convolution stack."""
        return self.conv(self.up(x))
class Encoder(nn.Module):
    """Convolutional encoder that ends in a ``z_channels``-wide 1x1 projection.

    Layers are stored, in application order, in ``self.encoder``:

    1. ``DoubleConv(in_channels, channels)``
    2. one ``Down`` stage (``MaxPool2d(2)`` + ``DoubleConv``) for every
       multiplier except the last
    3. a plain ``DoubleConv`` for the last multiplier (no pooling)
    4. a 1x1 ``nn.Conv2d`` projecting to ``z_channels``

    Args:
        z_channels: Channels of the encoder output.
        in_channels: Channels of the input tensor.
        channels: Base channel width; every stage width is this value
            times the matching multiplier.
        channels_mult: Per-level width multipliers. Any non-empty iterable
            of ints is accepted.
        **ignore_kwargs: Extra keyword arguments are accepted and discarded.

    Raises:
        ValueError: If ``channels_mult`` is empty.
    """

    def __init__(self, z_channels: int, in_channels: int, channels: int, channels_mult: list[int], **ignore_kwargs):
        super().__init__()
        # Materialise once so generators / config list types also work.
        mults = tuple(channels_mult)
        if not mults:
            raise ValueError("channels_mult must contain at least one multiplier")

        self.encoder = nn.ModuleList()
        self.encoder.append(DoubleConv(in_channels, channels))

        # Each level consumes the previous level's width; the first level
        # consumes the stem's width (multiplier 1).
        input_mults = (1,) + mults[:-1]
        last_level = len(mults) - 1
        for i_level, (mult_in, mult_out) in enumerate(zip(input_mults, mults)):
            block_in = channels * mult_in
            block_out = channels * mult_out
            # Only the final level skips the pooling step.
            stage_cls = Down if i_level != last_level else DoubleConv
            self.encoder.append(stage_cls(block_in, block_out))

        self.encoder.append(nn.Conv2d(channels * mults[-1], z_channels, kernel_size=(1, 1)))

    def forward(self, x: Tensor) -> Tensor:
        """Apply every encoder layer to ``x`` in order and return the result.

        NOTE(review): presumably ``x`` is ``(N, in_channels, H, W)`` —
        confirm against callers.
        """
        for layer in self.encoder:
            x = layer(x)
        return x
class Decoder(nn.Module):
    """Convolutional decoder that expands a ``z_channels`` tensor to ``out_channels``.

    Layers stored in ``self.decoder`` are applied in order, followed by
    ``self.final_conv``:

    1. a 1x1 ``nn.Conv2d`` from ``z_channels`` to the widest level
    2. walking the multipliers from last to first: an ``Up`` stage
       (constructed with its default, non-bilinear upsampler) for every
       level except level 0, which gets a plain ``DoubleConv``
    3. ``final_conv``: a 1x1 ``nn.Conv2d`` to ``out_channels``

    Args:
        z_channels: Channels of the latent input.
        out_channels: Channels of the decoder output.
        channels: Base channel width; every stage width is this value
            times the matching multiplier.
        channels_mult: Per-level width multipliers. Any non-empty iterable
            of ints is accepted.
        **ignore_kwargs: Extra keyword arguments are accepted and discarded.

    Raises:
        ValueError: If ``channels_mult`` is empty.
    """

    def __init__(self, z_channels: int, out_channels: int, channels: int, channels_mult: list[int], **ignore_kwargs):
        super().__init__()
        # Materialise once so generators / config list types also work.
        mults = tuple(channels_mult)
        if not mults:
            raise ValueError("channels_mult must contain at least one multiplier")

        self.decoder = nn.ModuleList()
        # Start at the widest level.
        block_in = channels * mults[-1]
        self.decoder.append(nn.Conv2d(z_channels, block_in, kernel_size=(1, 1)))

        for i_level in reversed(range(len(mults))):
            block_out = channels * mults[i_level]
            if i_level != 0:
                self.decoder.append(Up(block_in, block_out))
            else:
                # Level 0 only refines channels; no further upsampling.
                self.decoder.append(DoubleConv(block_in, block_out))
            # The next stage consumes what this one produced.
            block_in = block_out

        self.final_conv = nn.Conv2d(block_in, out_channels, kernel_size=1)

    def forward(self, x: Tensor) -> Tensor:
        """Apply every decoder layer to ``x`` in order, then ``final_conv``.

        NOTE(review): presumably ``x`` is ``(N, z_channels, H, W)`` —
        confirm against callers.
        """
        for layer in self.decoder:
            x = layer(x)
        return self.final_conv(x)
# End of listing.