Sprite-Generation / src /model_parts.py
Efesasa0's picture
b6b6742
import numpy as np
import torch
import torch.nn as nn
class ResidualDoubleConv(nn.Module):
    """Two stacked ``Conv2d -> BatchNorm2d -> GELU`` stages with an optional skip path.

    Both convolutions use a 3x3 kernel with padding 1 and stride 1, so the
    spatial size of the input is preserved; only the channel count changes
    (``in_channels`` -> ``out_channels``).

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by both convolutions.
        is_residual: When true, the input is added to the convolution output
            and the sum is divided by ``sqrt(2)``. If the channel counts
            differ, the input is first projected with a bias-free 1x1
            convolution stored in ``self.shortcut``.
    """

    def __init__(self, in_channels, out_channels, is_residual=False):
        super().__init__()

        # Build the two stages in order; only the first one changes the
        # channel count.
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.append(
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1, stride=1)
            )
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.GELU())
        self.conv = nn.Sequential(*stages)

        self.is_same_channels = in_channels == out_channels
        self.is_residual = is_residual

        # A projection is only required when the skip path is enabled and the
        # input cannot be added to the output as-is.
        needs_projection = is_residual and not self.is_same_channels
        self.shortcut = (
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            )
            if needs_projection
            else None
        )

    def forward(self, x):
        """Apply the double convolution, adding the (projected) input if residual."""
        features = self.conv(x)
        if not self.is_residual:
            return features

        skip = x if self.is_same_channels else self.shortcut(x)
        features += skip
        # Normalizing residual flow: rescale the sum of the two branches.
        return features / np.sqrt(2)
class UpSample(nn.Module):
    """Decoder stage: merge a skip tensor, upsample, then refine with two conv blocks.

    The input and the skip tensor are concatenated along the channel axis
    (dim 1), so ``in_channels`` must equal the combined channel count of the
    two tensors. A ``ConvTranspose2d`` with kernel 2 and stride 2 then doubles
    the spatial size and maps to ``out_channels``; two ``ResidualDoubleConv``
    blocks (built with their default ``is_residual``) follow.

    Args:
        in_channels: Channel count of the concatenated ``(x, skip)`` tensor.
        out_channels: Channel count of the output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            ResidualDoubleConv(out_channels, out_channels),
            ResidualDoubleConv(out_channels, out_channels),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x, skip):
        """Concatenate ``x`` with ``skip`` on dim 1 and run the upsampling stack."""
        fused = torch.cat((x, skip), 1)
        return self.conv(fused)
class DownSample(nn.Module):
    """Encoder stage: two conv blocks followed by 2x2 max pooling.

    The first ``ResidualDoubleConv`` maps ``in_channels`` to ``out_channels``,
    the second keeps ``out_channels``; ``MaxPool2d(2)`` then halves the
    spatial size. Any residual handling lives inside ``ResidualDoubleConv``;
    both blocks here are constructed with its default ``is_residual`` value.

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Channel count of the output feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            ResidualDoubleConv(in_channels, out_channels),
            ResidualDoubleConv(out_channels, out_channels),
            nn.MaxPool2d(2),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run the conv blocks and the pooling layer on ``x``."""
        pooled = self.conv(x)
        return pooled
class EmbedFC(nn.Module):
    """Small fully connected network that embeds a flat feature vector.

    The network is ``Linear(input_dim, embed_dim) -> GELU ->
    Linear(embed_dim, embed_dim)``.

    Args:
        input_dim: Size of each flattened input vector.
        embed_dim: Size of the produced embedding.
    """

    def __init__(self, input_dim, embed_dim):
        super().__init__()
        self.input_dim = input_dim
        self.fc = nn.Sequential(
            nn.Linear(input_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim),
        )

    def forward(self, x):
        """Flatten ``x`` to ``(-1, input_dim)`` and return its embedding.

        Args:
            x: Tensor whose total element count is a multiple of
                ``input_dim``; it may have any shape and memory layout.

        Returns:
            Tensor of shape ``(N, embed_dim)`` where
            ``N = x.numel() // input_dim``.
        """
        # reshape() behaves like view() for contiguous inputs but also accepts
        # non-contiguous ones (e.g. transposed/permuted tensors), where view()
        # would raise a RuntimeError.
        flat = x.reshape(-1, self.input_dim)
        return self.fc(flat)