# https://gist.github.com/mrsteyk/74ad3ec2f6f823111ae4c90e168505ac
import torch
import torch.nn.functional as F
import torch.nn as nn

from models import register


def _scale_and_shift(proj, t):
    """Project a conditioning vector into a per-channel (scale, shift) pair.

    Args:
        proj: Linear layer whose output width is twice the channel count of
            the feature map to be modulated.
        t: Conditioning tensor; it is split along dim 1, so it must have at
            least two dimensions (presumably (batch, t_dim) -- confirm with
            callers).

    Returns:
        ``(scale, shift)``, each with two trailing singleton dimensions
        appended so they broadcast over a 4-D feature map. ``scale`` is
        offset by 1, so a zero projection leaves the features unscaled.
    """
    scale, shift = proj(F.silu(t)).chunk(2, dim=1)
    scale = scale.unsqueeze(2).unsqueeze(3) + 1
    shift = shift.unsqueeze(2).unsqueeze(3)
    return scale, shift


class TimestepEmbedding(nn.Module):
    """Learned lookup-table embedding followed by a two-layer SiLU MLP."""

    def __init__(self, n_time=1024, n_emb=320, n_out=1280) -> None:
        super().__init__()
        self.emb = nn.Embedding(n_time, n_emb)
        self.f_1 = nn.Linear(n_emb, n_out)
        self.f_2 = nn.Linear(n_out, n_out)

    def forward(self, x) -> torch.Tensor:
        """Embed integer indices ``x`` (valid range ``[0, n_time)``)."""
        x = self.emb(x)
        x = self.f_1(x)
        x = F.silu(x)
        return self.f_2(x)


class PositionalEmbedding(nn.Module):
    """Sinusoidal embedding of a scalar per sample, followed by an MLP.

    Args:
        pe_dim: Width of the sinusoidal feature vector (half cosines, half
            sines) and input width of the first linear layer.
        out_dim: Output width of the MLP.
        max_positions: Base of the geometric frequency progression.
        endpoint: If true, the frequency exponents are normalised by
            ``pe_dim // 2 - 1`` instead of ``pe_dim // 2``, so the last
            exponent is exactly 1.
    """

    def __init__(self, pe_dim=320, out_dim=1280, max_positions=10000, endpoint=True):
        super().__init__()
        self.num_channels = pe_dim
        self.max_positions = max_positions
        self.endpoint = endpoint
        self.f_1 = nn.Linear(pe_dim, out_dim)
        self.f_2 = nn.Linear(out_dim, out_dim)

    def forward(self, x):
        """Embed a 1-D tensor ``x``; the result has one row per element of ``x``."""
        half = self.num_channels // 2
        freqs = torch.arange(start=0, end=half, dtype=torch.float32, device=x.device)
        freqs = freqs / (half - (1 if self.endpoint else 0))
        freqs = (1 / self.max_positions) ** freqs
        # torch.outer is the non-deprecated spelling of Tensor.ger.
        x = torch.outer(x, freqs.to(x.dtype))
        x = torch.cat([x.cos(), x.sin()], dim=1)
        x = self.f_1(x)
        x = F.silu(x)
        return self.f_2(x)


class ImageEmbedding(nn.Module):
    """Single 3x3 convolution (padding 1) lifting the input to ``out_channels``."""

    def __init__(self, in_channels, out_channels=320) -> None:
        super().__init__()
        self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x) -> torch.Tensor:
        return self.f(x)


class ImageUnembedding(nn.Module):
    """Output head: GroupNorm(32 groups) -> SiLU -> 3x3 convolution."""

    def __init__(self, in_channels=320, out_channels=3) -> None:
        super().__init__()
        self.gn = nn.GroupNorm(32, in_channels)
        self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x) -> torch.Tensor:
        return self.f(F.silu(self.gn(x)))


class ConvResblock(nn.Module):
    """Residual block of two 3x3 convolutions, modulated by a conditioning vector.

    The conditioning vector is projected to a per-channel scale and shift that
    are applied after the second group norm. When ``in_features`` differs from
    ``out_features`` the skip path uses a 1x1 convolution, otherwise identity.
    Both channel counts go through ``GroupNorm(32, ...)``.
    """

    def __init__(self, in_features, out_features, t_dim) -> None:
        super().__init__()
        self.f_t = nn.Linear(t_dim, out_features * 2)

        self.gn_1 = nn.GroupNorm(32, in_features)
        self.f_1 = nn.Conv2d(in_features, out_features, kernel_size=3, padding=1)

        self.gn_2 = nn.GroupNorm(32, out_features)
        self.f_2 = nn.Conv2d(out_features, out_features, kernel_size=3, padding=1)

        skip_conv = in_features != out_features
        self.f_s = (
            nn.Conv2d(in_features, out_features, kernel_size=1, padding=0)
            if skip_conv
            else nn.Identity()
        )

    def forward(self, x, t):
        """Return ``skip(x) + f_2(silu(gn_2(f_1(silu(gn_1(x)))) * scale + shift))``."""
        scale, shift = _scale_and_shift(self.f_t, t)
        h = self.f_1(F.silu(self.gn_1(x)))
        h = self.gn_2(h)
        return self.f_s(x) + self.f_2(F.silu(h * scale + shift))


# Also ConvResblock
class Downsample(nn.Module):
    """Conditioned residual block that halves spatial size with 2x2 average pooling.

    Same layout as ``ConvResblock`` with equal input/output channels; both the
    main path (after the first norm + SiLU) and the skip path are pooled.
    """

    def __init__(self, in_channels, t_dim) -> None:
        super().__init__()
        self.f_t = nn.Linear(t_dim, in_channels * 2)

        self.gn_1 = nn.GroupNorm(32, in_channels)
        self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.gn_2 = nn.GroupNorm(32, in_channels)

        self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x, t) -> torch.Tensor:
        scale, shift = _scale_and_shift(self.f_t, t)

        h = F.silu(self.gn_1(x))
        # stride=None makes the stride equal to the kernel size.
        h = F.avg_pool2d(h, kernel_size=(2, 2), stride=None)
        h = self.gn_2(self.f_1(h))
        h = self.f_2(F.silu(shift + (scale * h)))

        return h + F.avg_pool2d(x, kernel_size=(2, 2), stride=None)


# Also ConvResblock
class Upsample(nn.Module):
    """Conditioned residual block that doubles spatial size (nearest neighbour).

    Same layout as ``ConvResblock`` with equal input/output channels; both the
    main path (after the first norm + SiLU) and the skip path are upsampled.
    """

    def __init__(self, in_channels, t_dim) -> None:
        super().__init__()
        self.f_t = nn.Linear(t_dim, in_channels * 2)

        self.gn_1 = nn.GroupNorm(32, in_channels)
        self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.gn_2 = nn.GroupNorm(32, in_channels)

        self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x, t) -> torch.Tensor:
        scale, shift = _scale_and_shift(self.f_t, t)

        h = F.silu(self.gn_1(x))
        # F.interpolate(mode='nearest') is the documented replacement for the
        # deprecated F.upsample_nearest.
        h = F.interpolate(h, scale_factor=2, mode='nearest')
        h = self.gn_2(self.f_1(h))
        h = self.f_2(F.silu(shift + (scale * h)))

        return h + F.interpolate(x, scale_factor=2, mode='nearest')


@register('consistency_decoder_unet')
class ConsistencyDecoderUNet(nn.Module):
    """U-Net with four resolution levels built from conditioned residual blocks.

    Args:
        in_channels: Channels of the image input ``x``.
        z_dec_channels: Channels of the optional conditioning map ``z_dec``
            that is concatenated to ``x``; ``None`` means no extra channels
            are reserved for it.
        c0, c1, c2: Channel widths. Level 0 uses ``c0``, level 1 uses ``c1``,
            levels 2 and 3 and the middle blocks use ``c2``.
        pe_dim: Width of the sinusoidal time features.
        t_dim: Width of the time embedding fed to every block.
    """

    def __init__(self, in_channels=3, z_dec_channels=None, c0=320, c1=640, c2=1024, pe_dim=320, t_dim=1280) -> None:
        super().__init__()
        if z_dec_channels is not None:
            in_channels += z_dec_channels
        self.embed_image = ImageEmbedding(in_channels=in_channels, out_channels=c0)
        self.embed_time = PositionalEmbedding(pe_dim=pe_dim, out_dim=t_dim)

        # Encoder: three residual blocks per level, plus a Downsample on all
        # but the last level.
        down_0 = nn.ModuleList([
            ConvResblock(c0, c0, t_dim),
            ConvResblock(c0, c0, t_dim),
            ConvResblock(c0, c0, t_dim),
            Downsample(c0, t_dim),
        ])
        down_1 = nn.ModuleList([
            ConvResblock(c0, c1, t_dim),
            ConvResblock(c1, c1, t_dim),
            ConvResblock(c1, c1, t_dim),
            Downsample(c1, t_dim),
        ])
        down_2 = nn.ModuleList([
            ConvResblock(c1, c2, t_dim),
            ConvResblock(c2, c2, t_dim),
            ConvResblock(c2, c2, t_dim),
            Downsample(c2, t_dim),
        ])
        down_3 = nn.ModuleList([
            ConvResblock(c2, c2, t_dim),
            ConvResblock(c2, c2, t_dim),
            ConvResblock(c2, c2, t_dim),
        ])
        self.down = nn.ModuleList([
            down_0,
            down_1,
            down_2,
            down_3,
        ])

        self.mid = nn.ModuleList([
            ConvResblock(c2, c2, t_dim),
            ConvResblock(c2, c2, t_dim),
        ])

        # Decoder: every residual block consumes the running features
        # concatenated with one stored encoder activation, hence the summed
        # input widths.
        up_3 = nn.ModuleList([
            ConvResblock(c2 * 2, c2, t_dim),
            ConvResblock(c2 * 2, c2, t_dim),
            ConvResblock(c2 * 2, c2, t_dim),
            ConvResblock(c2 * 2, c2, t_dim),
            Upsample(c2, t_dim),
        ])
        up_2 = nn.ModuleList([
            ConvResblock(c2 * 2, c2, t_dim),
            ConvResblock(c2 * 2, c2, t_dim),
            ConvResblock(c2 * 2, c2, t_dim),
            ConvResblock(c2 + c1, c2, t_dim),
            Upsample(c2, t_dim),
        ])
        up_1 = nn.ModuleList([
            ConvResblock(c2 + c1, c1, t_dim),
            ConvResblock(c1 * 2, c1, t_dim),
            ConvResblock(c1 * 2, c1, t_dim),
            ConvResblock(c0 + c1, c1, t_dim),
            Upsample(c1, t_dim),
        ])
        up_0 = nn.ModuleList([
            ConvResblock(c0 + c1, c0, t_dim),
            ConvResblock(c0 * 2, c0, t_dim),
            ConvResblock(c0 * 2, c0, t_dim),
            ConvResblock(c0 * 2, c0, t_dim),
        ])
        # Stored lowest level first; forward() walks this list in reverse.
        self.up = nn.ModuleList([
            up_0,
            up_1,
            up_2,
            up_3,
        ])

        self.output = ImageUnembedding(in_channels=c0)

    def get_last_layer_weight(self):
        """Return the weight of the final output convolution."""
        return self.output.f.weight

    def forward(self, x, t=None, z_dec=None) -> torch.Tensor:
        """Run the U-Net.

        Args:
            x: Image batch; channels are on dim 1 and the last two dims are
                spatial.
            t: 1-D tensor with one time value per sample. Defaults to zeros.
            z_dec: Optional conditioning map concatenated to ``x`` on dim 1.
                If its spatial size differs from ``x`` it is upsampled
                (nearest neighbour) by the integer ratio of the heights.

        Returns:
            Output of the unembedding head applied to the final features.
        """
        if z_dec is not None:
            if z_dec.shape[-2] != x.shape[-2] or z_dec.shape[-1] != x.shape[-1]:
                ratio_h = x.shape[-2] // z_dec.shape[-2]
                ratio_w = x.shape[-1] // z_dec.shape[-1]
                assert ratio_h == ratio_w, (
                    'x and z_dec must differ by the same integer factor in '
                    f'height and width, got {tuple(x.shape)} and {tuple(z_dec.shape)}'
                )
                z_dec = F.interpolate(z_dec, scale_factor=ratio_h, mode='nearest')
            x = torch.cat([x, z_dec], dim=1)
        x = self.embed_image(x)

        if t is None:
            t = torch.zeros(x.shape[0], device=x.device)
        t = self.embed_time(t)

        # Keep the embedded input and every encoder block output for the
        # decoder's skip connections.
        skips = [x]
        for down in self.down:
            for block in down:
                x = block(x, t)
                skips.append(x)

        for mid in self.mid:
            x = mid(x, t)

        # reversed() iterates in place; slicing with [::-1] would build a new
        # ModuleList on every call.
        for up in reversed(self.up):
            for block in up:
                # Only residual blocks take a skip; Upsample blocks do not.
                if isinstance(block, ConvResblock):
                    x = torch.concat([x, skips.pop()], dim=1)
                x = block(x, t)

        return self.output(x)