|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from omegaconf import DictConfig |
|
|
from typing import Any, Dict, Tuple |
|
|
from utils import instantiate |
|
|
import cv2 |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
|
|
|
class ResidualBlock(nn.Module):
    """Residual block computing ``x + conv2(relu(conv1(relu(x))))``.

    ``conv1`` is a 3x3 convolution with stride 1 and padding 1, and ``conv2``
    is a 1x1 convolution; both map ``dim`` channels to ``dim`` channels.
    Because they preserve both channel count and spatial size, the branch
    output can be added directly to the input.

    Args:
        dim: Number of input (and output) channels.
    """

    def __init__(self, dim):
        super().__init__()
        # Creation order (relu, conv1, conv2) is kept so parameter
        # initialisation and state_dict ordering stay the same.
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(dim, dim, 3, 1, 1)  # 3x3, shape-preserving
        self.conv2 = nn.Conv2d(dim, dim, 1)  # 1x1 channel mixing

    def forward(self, x):
        """Return ``x`` plus the ReLU -> 3x3 conv -> ReLU -> 1x1 conv branch.

        The ReLU is not in-place, so ``x`` itself is left unmodified.
        """
        branch = self.conv2(self.relu(self.conv1(self.relu(x))))
        return x + branch
|
|
|
|
|
|
|
|
class Encoder2D(nn.Module):
    """Convolutional encoder that maps a 3-channel image to a feature map.

    Five stride-2, kernel-4, padding-1 convolutions (each followed by ReLU)
    each halve the spatial size (for even sizes), giving an overall
    downsampling factor of 32. A 3x3 convolution and two residual blocks
    then refine the features at that resolution.

    Args:
        output_channels: Channel width of every layer and of the output map.
    """

    def __init__(self, output_channels=512):
        super().__init__()

        width = output_channels
        # First stage lifts the 3 input channels to ``width``.
        layers = [nn.Conv2d(3, width, 4, 2, 1), nn.ReLU()]
        # Four further stride-2 stages (five in total).
        for _ in range(4):
            layers += [nn.Conv2d(width, width, 4, 2, 1), nn.ReLU()]
        layers += [
            nn.Conv2d(width, width, 3, 1, 1),
            ResidualBlock(width),
            ResidualBlock(width),
        ]
        # A single Sequential named ``block`` keeps state_dict keys
        # (block.0 ... block.12) and parameter creation order unchanged.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x``.

        Assumes ``x`` is a batch of 3-channel images, e.g. ``(N, 3, H, W)``
        (TODO confirm with callers); the result has ``output_channels``
        channels.
        """
        return self.block(x)
|
|
|
|
|
|
|
|
class Decoder2D(nn.Module):
    """Convolutional decoder that maps a feature map to a 3-channel output.

    A 3x3 convolution and two residual blocks process the input. Five
    kernel-4, stride-2, padding-1 transposed convolutions then each double
    the spatial size, for an overall upsampling factor of 32. The last one
    projects to 3 channels and has no activation after it.

    Args:
        input_dim: Channel width of the input map and of the hidden layers.
    """

    def __init__(self, input_dim=512):
        super().__init__()

        # Side length of the square input map. Not read by this class's own
        # forward; the Decoder subclass uses it to reshape a vector into a map.
        self.fea_map_size = 16

        dim = input_dim
        stages = [
            nn.Conv2d(dim, dim, 3, 1, 1),
            ResidualBlock(dim),
            ResidualBlock(dim),
        ]
        # Four hidden upsampling stages with ReLU ...
        for _ in range(4):
            stages.extend((nn.ConvTranspose2d(dim, dim, 4, 2, 1), nn.ReLU()))
        # ... and a final one that projects to 3 output channels.
        stages.append(nn.ConvTranspose2d(dim, 3, 4, 2, 1))
        # Keeps state_dict keys block.0 ... block.11 and creation order intact.
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Decode the feature map ``x`` into a 3-channel output."""
        return self.block(x)
|
|
|
|
|
|
|
|
class Encoder(Encoder2D):
    """``Encoder2D`` followed by global average pooling.

    The output is an ``(N, output_channels, 1, 1)`` tensor: the spatially
    averaged feature map produced by the inherited convolutional block.

    Args:
        output_channels: Channel width passed to ``Encoder2D``.
    """

    def __init__(self, output_channels=512):
        super().__init__(output_channels)
        # Collapse whatever spatial size the block produces to 1x1.
        self.pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

    def forward(self, x):
        """Run the convolutional block, then average-pool to a 1x1 map."""
        return self.pool(self.block(x))
|
|
|
|
|
|
|
|
class Decoder(Decoder2D):
    """Decode a flat latent vector into an image using ``Decoder2D``.

    A linear layer expands each latent of size ``input_dim`` to
    ``input_dim * fea_map_size**2`` values. These are reshaped into an
    ``(N, input_dim, fea_map_size, fea_map_size)`` map and passed through
    the inherited convolutional block.

    Args:
        input_dim: Latent size, which is also the channel width of the map.
    """

    def __init__(self, input_dim=512):
        super().__init__(input_dim)

        side = self.fea_map_size
        self.fc = nn.Linear(input_dim, input_dim * side * side)

    def forward(self, x):
        """Decode a batch of latents.

        Args:
            x: Latent batch whose per-sample elements flatten to
                ``input_dim`` values. For example, this accepts the
                ``(N, C, 1, 1)`` output of ``Encoder``.

        Returns:
            The output of the inherited ``Decoder2D`` block for the
            reshaped feature map.
        """
        side = self.fea_map_size
        # reshape (not view) so non-contiguous latents are accepted too.
        x = self.fc(x.reshape(x.size(0), -1))
        # Infer the channel count from fc's output instead of hard-coding 512.
        # The old hard-coded value crashed for any input_dim != 512.
        x = x.view(x.size(0), -1, side, side)
        return self.block(x)
|
|
|
|
|
|