|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import pytorch_lightning as pl |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ResidualBlock(nn.Module):
    """Basic two-convolution residual block (ResNet style).

    Main path: 3x3 conv (optionally strided) -> BN -> ReLU -> 3x3 conv -> BN.
    The input is added back before the final ReLU.  When the block changes
    spatial resolution (stride != 1) or channel count, the skip connection
    is a 1x1 strided conv + BN projection; otherwise it is the identity.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # --- main path -------------------------------------------------
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3,
            stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3,
            stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # --- skip connection -------------------------------------------
        # A projection is only needed when the residual and the main path
        # would otherwise have mismatched shapes.
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels, kernel_size=1,
                    stride=stride, bias=False,
                ),
                nn.BatchNorm2d(out_channels),
            )
        else:
            # Empty Sequential acts as the identity mapping.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply the block; output has `out_channels` channels."""
        residual = self.shortcut(x)
        y = F.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        return F.relu(y + residual)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ResidualConvAutoencoder(pl.LightningModule):
    """Convolutional autoencoder built from residual blocks.

    The encoder downsamples 3-channel images by 8x via three stride-2
    residual blocks, then flattens into a `latent_dim` vector; the decoder
    mirrors this with nearest-neighbour upsampling and ends in a Sigmoid,
    so reconstructions lie in [0, 1].

    NOTE: the hard-coded `512 * 4 * 4` linear layers assume 32x32 inputs
    (32 -> 16 -> 8 -> 4); other spatial sizes will fail at Flatten/Linear.
    """

    def __init__(self, latent_dim=512, dropout_rate=0.2):
        """
        Args:
            latent_dim: size of the bottleneck vector.
            dropout_rate: dropout applied to the latent vector as
                regularization during training.
        """
        super().__init__()
        self.latent_dim = latent_dim

        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            ResidualBlock(64, 128, stride=2),    # 32 -> 16
            ResidualBlock(128, 256, stride=2),   # 16 -> 8
            ResidualBlock(256, 512, stride=2),   # 8 -> 4
            nn.Flatten(),
            nn.Linear(512 * 4 * 4, self.latent_dim),
            nn.Dropout(dropout_rate),
        )

        self.decoder = nn.Sequential(
            nn.Linear(self.latent_dim, 512 * 4 * 4),
            nn.Unflatten(1, (512, 4, 4)),
            ResidualBlock(512, 256),
            nn.Upsample(scale_factor=2, mode='nearest'),  # 4 -> 8
            ResidualBlock(256, 128),
            nn.Upsample(scale_factor=2, mode='nearest'),  # 8 -> 16
            ResidualBlock(128, 64),
            nn.Upsample(scale_factor=2, mode='nearest'),  # 16 -> 32
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode then decode; returns the reconstruction of `x`."""
        z = self.encoder(x)
        recon = self.decoder(z)
        return recon

    def training_step(self, batch, batch_idx):
        """Compute and return the reconstruction loss for one batch.

        BUG FIX: this previously returned a constant ``torch.tensor(0.0)``,
        which carries no gradient — the model never trained.  We now compute
        a real MSE reconstruction loss.  Batches may be a bare image tensor
        or an (image, label) tuple (e.g. from torchvision datasets).
        """
        x = batch[0] if isinstance(batch, (tuple, list)) else batch
        recon = self(x)
        loss = F.mse_loss(recon, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Adam with its default learning rate (1e-3) over all parameters."""
        return torch.optim.Adam(self.parameters())