# Diffusion-Sprite / models.py
# Author: YashNagraj75 — commit 8a6ed33 ("Add checkpoints its still not clear")
import torch.nn as nn
import torch
class ResidualBlock(nn.Module):
    """Double 3x3 conv block (Conv -> BatchNorm -> GELU, twice) with an
    optional residual connection.

    When ``is_res`` is True the block output is ``(residual + conv2(conv1(x))) / 1.414``,
    where the residual is the first conv's output if channel counts match,
    or a learned 1x1 projection of the input otherwise. The 1/1.414 factor
    keeps the variance of the summed activations roughly constant.
    """

    def __init__(self, in_channels: int, out_channels: int, is_res: bool = False) -> None:
        super(ResidualBlock, self).__init__()
        self.same_channels = in_channels == out_channels
        self.is_res = is_res
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )
        # Bug fix: the 1x1 shortcut used to be created inside forward(), which
        # gave it fresh random weights on every call and kept it out of the
        # optimizer's parameter list. Registering it here makes it trainable
        # and the forward pass deterministic. NOTE: this adds a `shortcut.*`
        # entry to the state_dict, so checkpoints saved with the old code lack it.
        if is_res and not self.same_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1, 1, 0)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x1)
        if not self.is_res:
            return x2
        # Preserves the original behavior: when channels match, the residual
        # taken is conv1's output (x1), not the raw input x.
        residual = x1 if self.same_channels else self.shortcut(x)
        return (residual + x2) / 1.414
class UnetUp(nn.Module):
    """Decoder (upsampling) stage of the U-Net.

    Concatenates the incoming feature map with its encoder skip connection,
    then doubles the spatial resolution with a stride-2 transposed conv and
    refines the result with two residual blocks.
    """

    def __init__(self, in_channels, out_channels):
        super(UnetUp, self).__init__()
        # in_channels counts the channels AFTER concatenation with the skip.
        self.model = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
            ResidualBlock(out_channels, out_channels),
            ResidualBlock(out_channels, out_channels),
        )

    def forward(self, x, skip):
        # Fuse decoder features with the matching encoder skip along channels,
        # then upsample and refine.
        merged = torch.cat((x, skip), dim=1)
        return self.model(merged)
class UnetDown(nn.Module):
    """Encoder (downsampling) stage of the U-Net.

    Two residual blocks followed by a 2x2 max-pool, halving the spatial
    resolution while mapping input_channels -> out_channels.
    """

    def __init__(self, input_channels, out_channels) -> None:
        super(UnetDown, self).__init__()
        stages = [
            ResidualBlock(input_channels, out_channels),
            ResidualBlock(out_channels, out_channels),
            nn.MaxPool2d(2),
        ]
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        return self.model(x)
class EmbedFC(nn.Module):
    """Two-layer MLP that embeds a flat input vector into embed_dm dimensions.

    The input is reshaped to (-1, input_dim) before the linear layers, so any
    tensor whose trailing elements total input_dim per sample is accepted.
    """

    def __init__(self, input_dim, embed_dm) -> None:
        super(EmbedFC, self).__init__()
        self.input_dim = input_dim
        layers = [
            nn.Linear(input_dim, embed_dm),
            nn.GELU(),
            nn.Linear(embed_dm, embed_dm),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # Flatten each sample to a vector of length input_dim.
        flat = x.view(-1, self.input_dim)
        return self.model(flat)
class ContextUnet(nn.Module):
    """U-Net denoiser conditioned on a timestep ``t`` and an optional
    context vector ``c`` (e.g. class labels), as used in diffusion training.

    Context and time are embedded at two scales and injected into the decoder
    as a per-channel scale (context) and shift (time) on the feature maps.
    """

    def __init__(self, in_channels, n_feat=256, n_cfeat=10, height=28) -> None:
        super(ContextUnet, self).__init__()
        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_cfeat = n_cfeat
        # NOTE(review): spatial plumbing assumes height is divisible by 4
        # (two 2x downsamples, then a 4x4 average pool) — confirm with callers.
        self.h = height

        # --- Encoder ---
        self.init_conv = ResidualBlock(in_channels, n_feat, is_res=True)
        self.down1 = UnetDown(n_feat, n_feat)
        self.down2 = UnetDown(n_feat, n_feat * 2)
        # Collapse the h/4 x h/4 map to a 1x1 latent vector.
        self.to_vec = nn.Sequential(nn.AvgPool2d(4), nn.GELU())

        # --- Conditioning embeddings (one pair per decoder scale) ---
        self.timeembed1 = EmbedFC(1, 2 * n_feat)
        self.timeembed2 = EmbedFC(1, embed_dm=1 * n_feat)
        self.contextembed1 = EmbedFC(n_cfeat, 2 * n_feat)
        self.contextembed2 = EmbedFC(n_cfeat, 1 * n_feat)

        # --- Decoder ---
        # up0 expands the 1x1 latent back to an (h/4 x h/4) map in one shot.
        self.up0 = nn.Sequential(
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, self.h // 4, self.h // 4),
            nn.GroupNorm(8, 2 * n_feat),
            nn.ReLU(),
        )
        self.up1 = UnetUp(4 * n_feat, n_feat)
        self.up2 = UnetUp(2 * n_feat, n_feat)
        # Final head: fuse with the init_conv features, project to in_channels.
        self.out = nn.Sequential(
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
            nn.GroupNorm(8, n_feat),
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
        )

    def forward(self, x, t, c=None):
        # Encoder path.
        x = self.init_conv(x)
        d1 = self.down1(x)
        d2 = self.down2(d1)
        latent = self.to_vec(d2)

        # Unconditional case: substitute an all-zero context vector
        # (matching x's dtype/device via .to(x)).
        if c is None:
            c = torch.zeros(x.shape[0], self.n_cfeat).to(x)

        # Broadcastable (N, C, 1, 1) embeddings for each decoder scale.
        cemb1 = self.contextembed1(c).view(-1, 2 * self.n_feat, 1, 1)
        temb1 = self.timeembed1(t).view(-1, 2 * self.n_feat, 1, 1)
        cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)

        # Decoder path: context multiplies, time shifts, at each scale.
        u0 = self.up0(latent)
        u1 = self.up1(cemb1 * u0 + temb1, d2)
        u2 = self.up2(cemb2 * u1 + temb2, d1)
        return self.out(torch.cat((u2, x), dim=1))