| import torch.nn as nn |
| import torch |
|
|
class ResidualBlock(nn.Module):
    """Two 3x3 conv layers (Conv -> BatchNorm -> GELU) with an optional
    residual connection.

    When ``is_res`` is True the output is a residual sum scaled by 1/sqrt(2)
    to keep the activation variance roughly constant. When the channel counts
    match, the residual is taken from the *first* conv stage's output (x1);
    otherwise the input is projected with a learned 1x1 convolution.

    Args:
        in_channels: number of input feature channels.
        out_channels: number of output feature channels.
        is_res: enable the residual path.
    """

    def __init__(self, in_channels: int, out_channels: int, is_res: bool = False) -> None:
        super(ResidualBlock, self).__init__()

        self.same_channels = in_channels == out_channels
        self.is_res = is_res

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )

        # BUG FIX: the original built a brand-new, randomly initialized 1x1
        # conv inside forward() on every call, so the shortcut projection was
        # never registered as a parameter, never trained, and made the output
        # nondeterministic for identical inputs. Register it once here.
        if is_res and not self.same_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1, 1, 0)

    def forward(self, x):
        """Apply both conv stages; add the residual path if enabled."""
        x1 = self.conv1(x)
        x2 = self.conv2(x1)

        if not self.is_res:
            return x2

        # Matching channels: residual from the first stage's output (x1),
        # mirroring the original implementation's choice.
        residual = x1 if self.same_channels else self.shortcut(x)
        # 1/1.414 ~ 1/sqrt(2): normalize the variance of the sum.
        return (residual + x2) / 1.414
|
|
|
|
|
|
class UnetUp(nn.Module):
    """Decoder stage: concatenate with a skip connection, upsample 2x via a
    transposed convolution, then refine with two residual blocks.

    Args:
        in_channels: channels of the concatenated (x, skip) input.
        out_channels: channels produced by this stage.
    """

    def __init__(self, in_channels, out_channels):
        super(UnetUp, self).__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
            ResidualBlock(out_channels, out_channels),
            ResidualBlock(out_channels, out_channels),
        )

    def forward(self, x, skip):
        """Fuse x with its encoder skip tensor along channels and upsample."""
        fused = torch.cat((x, skip), 1)
        return self.model(fused)
class UnetDown(nn.Module):
    """Encoder stage: two residual blocks followed by a 2x max-pool,
    halving the spatial resolution.

    Args:
        input_channels: channels of the incoming feature map.
        out_channels: channels produced by this stage.
    """

    def __init__(self, input_channels, out_channels) -> None:
        super(UnetDown, self).__init__()
        self.model = nn.Sequential(
            ResidualBlock(input_channels, out_channels),
            ResidualBlock(out_channels, out_channels),
            nn.MaxPool2d(2),
        )

    def forward(self, x):
        """Downsample x by a factor of two after residual refinement."""
        return self.model(x)
| |
|
|
class EmbedFC(nn.Module):
    """Two-layer MLP that maps a flat input vector to an embedding.

    Inputs of any shape are flattened to (-1, input_dim) before the MLP,
    so e.g. a scalar timestep of shape (B, 1) or a one-hot context vector
    both work.

    Args:
        input_dim: size each input row is flattened to.
        embed_dm: dimensionality of the produced embedding.
    """

    def __init__(self, input_dim, embed_dm) -> None:
        super(EmbedFC, self).__init__()
        self.input_dim = input_dim
        self.model = nn.Sequential(
            nn.Linear(input_dim, embed_dm),
            nn.GELU(),
            nn.Linear(embed_dm, embed_dm),
        )

    def forward(self, x):
        """Flatten x to rows of input_dim and run it through the MLP."""
        flat = x.view(-1, self.input_dim)
        return self.model(flat)
|
|
|
|
class ContextUnet(nn.Module):
    """U-Net conditioned on a timestep and an optional context vector.

    Pipeline: initial residual conv -> two downsampling stages -> bottleneck
    vector -> two upsampling stages whose inputs are modulated channelwise by
    learned time/context embeddings -> conv head back to ``in_channels``.

    Args:
        in_channels: channels of the input image.
        n_feat: base number of feature maps.
        n_cfeat: dimensionality of the context vector.
        height: input spatial size; assumed square and divisible by 4
            (e.g. 28) — TODO confirm with callers.
    """

    def __init__(self, in_channels, n_feat=256, n_cfeat=10, height=28) -> None:
        super(ContextUnet, self).__init__()

        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_cfeat = n_cfeat
        self.h = height

        self.init_conv = ResidualBlock(in_channels, n_feat, is_res=True)

        # Encoder: each stage halves the spatial resolution.
        self.down1 = UnetDown(n_feat, n_feat)
        self.down2 = UnetDown(n_feat, n_feat * 2)

        # Collapse the bottleneck feature map into a 1x1 vector.
        self.to_vec = nn.Sequential(nn.AvgPool2d(4), nn.GELU())

        # One time embedding and one context embedding per decoder stage.
        self.timeembed1 = EmbedFC(1, 2 * n_feat)
        self.timeembed2 = EmbedFC(1, embed_dm=1 * n_feat)
        self.contextembed1 = EmbedFC(n_cfeat, 2 * n_feat)
        self.contextembed2 = EmbedFC(n_cfeat, 1 * n_feat)

        # Decoder entry: expand the 1x1 bottleneck back to (h/4, h/4).
        self.up0 = nn.Sequential(
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, self.h // 4, self.h // 4),
            nn.GroupNorm(8, 2 * n_feat),
            nn.ReLU(),
        )

        self.up1 = UnetUp(4 * n_feat, n_feat)
        self.up2 = UnetUp(2 * n_feat, n_feat)

        # Output head: fuse with the initial features, project to in_channels.
        self.out = nn.Sequential(
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
            nn.GroupNorm(8, n_feat),
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
        )

    def forward(self, x, t, c=None):
        """Predict the output for image batch x at timestep t with context c.

        Args:
            x: image batch, shape (B, in_channels, h, h).
            t: timestep input; flattened to (B, 1) by the embedding —
                presumably one scalar per sample, verify against caller.
            c: optional (B, n_cfeat) context; zeros are used when omitted.
        """
        feats = self.init_conv(x)
        skip1 = self.down1(feats)
        skip2 = self.down2(skip1)
        latent = self.to_vec(skip2)

        # No context supplied: condition on an all-zero context vector
        # (matching x's dtype and device).
        if c is None:
            c = torch.zeros(x.shape[0], self.n_cfeat).to(x)

        # Embeddings reshaped to (B, C, 1, 1) for channelwise broadcast.
        ctx1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1)
        tim1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
        ctx2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
        tim2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)

        dec = self.up0(latent)
        # Context scales the features, time shifts them, at each stage.
        dec = self.up1(ctx1 * dec + tim1, skip2)
        dec = self.up2(ctx2 * dec + tim2, skip1)
        return self.out(torch.cat((dec, feats), 1))