import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """Two Conv(3x3)-BatchNorm-GELU stages, optionally with a residual add.

    When ``is_res`` is True the input is added back to the output of the
    second conv.  If the channel counts differ, a 1x1 projection conv maps
    the input to ``out_channels`` first.

    Args:
        in_channels: number of input feature channels.
        out_channels: number of output feature channels.
        is_res: enable the residual connection.
    """

    def __init__(self, in_channels: int, out_channels: int, is_res: bool = False) -> None:
        super().__init__()
        self.same_channels = in_channels == out_channels  # fixed typo: was "same_channesls"
        self.is_res = is_res
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )
        # BUG FIX: the original built a fresh nn.Conv2d inside forward() on
        # every call, so the shortcut projection had new random weights each
        # pass — untrained, unregistered, and absent from state_dict.
        # Register it once here so it is trained and checkpointed.
        if is_res and not self.same_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1, 1, 0)
        else:
            self.shortcut = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = self.conv1(x)
        x2 = self.conv2(x1)
        if not self.is_res:
            return x2
        residual = x1 if self.same_channels else self.shortcut(x)
        # 1.414 ≈ sqrt(2): rescales the sum so its variance stays roughly unit.
        return (residual + x2) / 1.414


class UnetUp(nn.Module):
    """Upsampling block: ConvTranspose2d (x2 spatial) + two residual conv blocks.

    ``forward`` concatenates the incoming feature map with a skip connection
    along the channel axis before upsampling, so ``in_channels`` must equal
    the sum of the two channel counts.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
            ResidualBlock(out_channels, out_channels),
            ResidualBlock(out_channels, out_channels),
        )

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        # Concatenate input and skip tensors along the channel dimension,
        # then upsample through the sequential stack.
        return self.model(torch.cat((x, skip), 1))


class UnetDown(nn.Module):
    """Downsampling block: two residual conv blocks followed by 2x2 max-pool."""

    def __init__(self, input_channels: int, out_channels: int) -> None:
        super().__init__()
        self.model = nn.Sequential(
            ResidualBlock(input_channels, out_channels),
            ResidualBlock(out_channels, out_channels),
            nn.MaxPool2d(2),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)


class EmbedFC(nn.Module):
    """Two-layer MLP that embeds a flat input vector into ``embed_dm`` dims.

    NOTE(review): the misspelled parameter name ``embed_dm`` is kept because
    callers pass it by keyword; renaming it would break them.
    """

    def __init__(self, input_dim: int, embed_dm: int) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.model = nn.Sequential(
            nn.Linear(input_dim, embed_dm),
            nn.GELU(),
            nn.Linear(embed_dm, embed_dm),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Flatten any input to (batch, input_dim) before the MLP.
        return self.model(x.view(-1, self.input_dim))


class ContextUnet(nn.Module):
    """U-Net conditioned on a timestep ``t`` and an optional context vector ``c``.

    Args:
        in_channels: channels of the input image.
        n_feat: base feature width; deeper levels use multiples of it.
        n_cfeat: dimensionality of the context (class/condition) vector.
        height: input spatial size; assumed square and divisible by 4
            (the two downsampling stages each halve it) — TODO confirm
            against callers.
    """

    def __init__(self, in_channels: int, n_feat: int = 256, n_cfeat: int = 10, height: int = 28) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_cfeat = n_cfeat
        self.h = height

        # Encoder path.
        self.init_conv = ResidualBlock(in_channels, n_feat, is_res=True)
        self.down1 = UnetDown(n_feat, n_feat)
        self.down2 = UnetDown(n_feat, n_feat * 2)
        # Collapse the (h/4 x h/4) map to a 1x1 bottleneck vector.
        self.to_vec = nn.Sequential(nn.AvgPool2d(4), nn.GELU())

        # Timestep and context embeddings, one pair per decoder level.
        self.timeembed1 = EmbedFC(1, 2 * n_feat)
        self.timeembed2 = EmbedFC(1, embed_dm=1 * n_feat)
        self.contextembed1 = EmbedFC(n_cfeat, 2 * n_feat)
        self.contextembed2 = EmbedFC(n_cfeat, 1 * n_feat)

        # Decoder path: up0 expands the 1x1 bottleneck back to (h/4 x h/4).
        self.up0 = nn.Sequential(
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, self.h // 4, self.h // 4),
            nn.GroupNorm(8, 2 * n_feat),
            nn.ReLU(),
        )
        self.up1 = UnetUp(4 * n_feat, n_feat)
        self.up2 = UnetUp(2 * n_feat, n_feat)
        self.out = nn.Sequential(
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
            nn.GroupNorm(8, n_feat),
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor, c: torch.Tensor = None) -> torch.Tensor:
        """Predict an output image from ``x`` given timestep ``t`` and context ``c``.

        When ``c`` is None an all-zero context (unconditional) is used.
        """
        x = self.init_conv(x)
        down1 = self.down1(x)
        down2 = self.down2(down1)
        hidden_vec = self.to_vec(down2)

        if c is None:
            # Unconditional pass: zero context, matching x's device and dtype.
            c = torch.zeros(x.shape[0], self.n_cfeat).to(x)

        # Reshape embeddings to (B, C, 1, 1) so they broadcast over the map.
        cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1)
        temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
        cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)

        up1 = self.up0(hidden_vec)
        # FiLM-style conditioning: multiply by context, add timestep embedding.
        up2 = self.up1(cemb1 * up1 + temb1, down2)
        up3 = self.up2(cemb2 * up2 + temb2, down1)
        return self.out(torch.cat((up3, x), 1))