Spaces:
Paused
Paused
| from typing import Any | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from einops import rearrange | |
| from collections import defaultdict | |
| import torch as th | |
| import numpy as np | |
| import math | |
# Map activation-name strings to shared module instances.
# Unknown names fall back to SiLU via the defaultdict factory.
str_to_act = defaultdict(
    nn.SiLU,
    {
        "relu": nn.ReLU(),
        "silu": nn.SiLU(),
        "gelu": nn.GELU(),
    },
)
class SinusoidalPositionalEncoding(nn.Module):
    """Transformer-style sinusoidal embedding of scalar timesteps.

    Maps a 1-D tensor ``t`` of timesteps to a ``(len(t), dim)`` tensor laid
    out as ``[sin(t*f_0..f_{d/2-1}), cos(t*f_0..f_{d/2-1})]`` where
    ``f_k = 10000^(-2k/dim)``.
    """

    def __init__(self, dim):
        super().__init__()
        # dim is assumed even; dim // 2 frequencies are generated.
        self.dim = dim

    def forward(self, t):
        exponents = torch.arange(0, self.dim, 2, device=t.device).float() / self.dim
        inv_freq = 1.0 / (10000 ** exponents)
        # broadcasting (len(t), 1) * (dim//2,) replaces the explicit repeat
        angles = t.unsqueeze(-1) * inv_freq
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
class TimeEmbedding(nn.Module):
    """Two-layer MLP projecting a sinusoidal encoding into the time-embedding space.

    ``(batch, model_dim) -> (batch, emb_dim)``.
    """

    def __init__(self, model_dim: int, emb_dim: int, act="silu"):
        super().__init__()
        # attribute names kept for state_dict compatibility
        self.lin = nn.Linear(model_dim, emb_dim)
        self.act = str_to_act[act]
        self.lin2 = nn.Linear(emb_dim, emb_dim)

    def forward(self, x):
        # linear -> nonlinearity -> linear
        return self.lin2(self.act(self.lin(x)))
class ConvBlock(nn.Module):
    """GroupNorm -> activation -> (optional dropout) -> 3x3 conv; spatial size preserved.

    Args:
        in_channels: channels of the input map; must be divisible by 32 for GroupNorm.
        out_channels: channels produced by the convolution.
        act: activation name looked up in ``str_to_act`` (unknown names fall back to SiLU).
        dropout: dropout probability; ``None`` disables dropout.
        zero: when True, zero-initialize the conv weight *and* bias so the block
            initially outputs zeros (as in IDDPM's ``zero_module``).
    """

    def __init__(self, in_channels, out_channels, act="silu", dropout=None, zero=False):
        super().__init__()
        self.norm = nn.GroupNorm(
            num_groups=32,
            num_channels=in_channels,
        )
        self.act = str_to_act[act]
        # Always define self.dropout: nn.Identity keeps forward() branch-free.
        # Neither Dropout nor Identity has parameters, so checkpoints stay
        # compatible with the old hasattr-based version.
        self.dropout = nn.Dropout(dropout) if dropout is not None else nn.Identity()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=1,
        )
        if zero:
            # Fix: previously only the weight was zeroed, so the bias kept its
            # random init and "zeroed" blocks still produced non-zero outputs.
            nn.init.zeros_(self.conv.weight)
            nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        x = self.norm(x)
        x = self.act(x)
        x = self.dropout(x)
        return self.conv(x)
class EmbeddingBlock(nn.Module):
    """Projects a time embedding ``(batch, emb_dim)`` down to ``(batch, channels)``."""

    def __init__(self, channels: int, emb_dim: int, act="silu"):
        super().__init__()
        self.act = str_to_act[act]
        self.lin = nn.Linear(emb_dim, channels)

    def forward(self, x):
        # nonlinearity first, then the linear projection
        return self.lin(self.act(x))
class ResBlock(nn.Module):
    """Residual block with an additive time embedding and optional channel change.

    The skip connection is a 1x1 conv when the channel count changes, and an
    identity otherwise. The second ConvBlock is zero-initialized so the block
    starts out close to an identity mapping.
    """

    def __init__(self, channels: int, emb_dim: int, dropout: float = 0, out_channels=None):
        super().__init__()
        out_channels = channels if out_channels is None else out_channels
        self.conv1 = ConvBlock(channels, out_channels)
        self.emb = EmbeddingBlock(out_channels, emb_dim)
        self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout, zero=True)
        if channels == out_channels:
            self.skip_connection = nn.Identity()
        else:
            self.skip_connection = nn.Conv2d(channels, out_channels, kernel_size=1)

    def forward(self, x, t):
        skip = self.skip_connection(x)
        h = self.conv1(x)
        # Project t to (batch, out_channels) and broadcast it over the spatial
        # dims instead of materializing a repeated tensor.
        h = h + self.emb(t)[:, :, None, None]
        h = self.conv2(h)
        return h + skip
class SelfAttentionBlock(nn.Module):
    """Pre-norm multi-head self-attention over spatial positions, with a residual add.

    Each spatial location of the (b, c, h, w) feature map is treated as one
    token of width ``channels``.
    """

    def __init__(self, channels, num_heads=1):
        super().__init__()
        self.channels = channels
        self.num_heads = num_heads
        self.norm = nn.GroupNorm(32, channels)
        self.attention = nn.MultiheadAttention(
            embed_dim=channels,
            num_heads=num_heads,
            dropout=0,
            batch_first=True,
            bias=True,
        )

    def forward(self, x):
        b, c, h, w = x.shape
        normed = self.norm(x)
        # (b, c, h, w) -> (b, h*w, c): sequence of spatial tokens
        tokens = normed.flatten(2).transpose(1, 2)
        attended, _ = self.attention(tokens, tokens, tokens)
        # (b, h*w, c) -> (b, c, h, w)
        out = attended.transpose(1, 2).reshape(b, c, h, w)
        return out + x
class Downsample(nn.Module):
    """Halves spatial resolution with a stride-2 3x3 convolution.

    DDPM used max-pooling here; IDDPM switched to a strided conv, which is
    what this implementation follows.
    """

    def __init__(self, channels):
        super().__init__()
        self.down = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.down(x)
class DownBlock(nn.Module):
    """One stage of the U-Net contracting path.

    Runs ``width`` ResBlocks (each optionally followed by self-attention) and,
    unless this is the final stage, downsamples afterwards. Returns both the
    (possibly downsampled) output and the pre-downsample activations, which
    the expansive path uses as a skip connection.

    From the U-Net paper: 'The contracting path follows the typical
    architecture of a convolutional network. It consists of the repeated
    application of two 3x3 convolutions (unpadded convolutions), each followed
    by a rectified linear unit (ReLU) and a 2x2 max pooling operation with
    stride 2 for downsampling. At each downsampling step we double the number
    of feature channels.'
    """

    def __init__(self, in_channels, out_channels, time_embedding_dim, use_attn=False, dropout=0, downsample=True, width=1):
        """in_channels will typically be half of out_channels"""
        super().__init__()
        self.width = width
        self.use_attn = use_attn
        self.do_downsample = downsample
        self.blocks = nn.ModuleList()
        channels = in_channels
        for _ in range(width):
            self.blocks.append(ResBlock(
                channels=channels,
                out_channels=out_channels,
                emb_dim=time_embedding_dim,
                dropout=dropout,
            ))
            if self.use_attn:
                self.blocks.append(SelfAttentionBlock(channels=out_channels))
            channels = out_channels
        if self.do_downsample:
            self.downsample = Downsample(out_channels)

    def forward(self, x, t):
        for block in self.blocks:
            # only ResBlocks consume the time embedding
            x = block(x, t) if isinstance(block, ResBlock) else block(x)
        residual = x
        if self.do_downsample:
            return self.downsample(x), residual
        return x, residual
class Upsample(nn.Module):
    """Doubles spatial resolution (nearest-neighbor) and refines with a 3x3 conv."""

    def __init__(self, channels):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2)
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(self.upsample(x))
class UpBlock(nn.Module):
    """One stage of the U-Net expansive path.

    Runs ``width`` ResBlocks (each optionally followed by self-attention) on
    the channel-wise concatenation of the previous stage's output with the
    matching skip connection, then upsamples unless this is the last stage.

    From the U-Net paper: 'Every step in the expansive path consists of an
    upsampling of the feature map followed by a 2x2 convolution
    ("up-convolution") that halves the number of feature channels, a
    concatenation with the correspondingly cropped feature map from the
    contracting path, and two 3x3 convolutions, each followed by a ReLU.'
    """

    def __init__(self, in_channels, out_channels, time_embedding_dim, use_attn=False, dropout=0, upsample=True, width=1):
        """in_channels will typically be double of out_channels"""
        super().__init__()
        self.use_attn = use_attn
        self.do_upsample = upsample
        self.blocks = nn.ModuleList()
        channels = in_channels
        for _ in range(width):
            self.blocks.append(ResBlock(
                channels=channels,
                out_channels=out_channels,
                emb_dim=time_embedding_dim,
                dropout=dropout,
            ))
            if self.use_attn:
                self.blocks.append(SelfAttentionBlock(channels=out_channels))
            channels = out_channels
        if self.do_upsample:
            self.upsample = Upsample(out_channels)

    def forward(self, x, t):
        for block in self.blocks:
            # only ResBlocks consume the time embedding
            x = block(x, t) if isinstance(block, ResBlock) else block(x)
        if self.do_upsample:
            x = self.upsample(x)
        return x
class Bottleneck(nn.Module):
    """U-Net bottleneck: ResBlock -> self-attention -> ResBlock at constant channel count."""

    def __init__(self, channels, dropout, time_embedding_dim):
        super().__init__()
        self.resblock_1 = ResBlock(
            channels=channels,
            out_channels=channels,
            dropout=dropout,
            emb_dim=time_embedding_dim,
        )
        self.attention_block = SelfAttentionBlock(channels=channels)
        self.resblock_2 = ResBlock(
            channels=channels,
            out_channels=channels,
            dropout=dropout,
            emb_dim=time_embedding_dim,
        )

    def forward(self, x, t):
        return self.resblock_2(self.attention_block(self.resblock_1(x, t)), t)
class Unet(nn.Module):
    """Diffusion U-Net (DDPM/IDDPM style) with time conditioning.

    Predicts a per-pixel output (e.g. noise) from a noisy image ``x`` and a
    batch of timesteps ``t``. Timesteps are sinusoidally encoded, passed
    through an MLP, and injected into every ResBlock.

    Args:
        image_channels: channels of both the input image and the output.
        res_block_width: number of ResBlocks per resolution stage.
        starting_channels: base channel count; must be divisible by 32
            (GroupNorm constraint).
        dropout: dropout probability inside the ResBlocks.
        channel_mults: per-stage multipliers of ``starting_channels``.
        attention_layers: per-stage self-attention flags; same length as
            ``channel_mults``.
    """

    def __init__(
        self,
        image_channels=3,
        res_block_width=2,
        starting_channels=128,
        dropout=0,
        channel_mults=(1, 2, 2, 4, 4),
        attention_layers=(False, False, False, True, False)
    ):
        super().__init__()
        self.is_conditional = False
        self.image_channels = image_channels
        self.starting_channels = starting_channels
        time_embedding_dim = 4 * starting_channels
        self.time_encoding = SinusoidalPositionalEncoding(dim=starting_channels)
        self.time_embedding = TimeEmbedding(model_dim=starting_channels, emb_dim=time_embedding_dim)
        # Fix: this conv previously hard-coded 3 input channels, breaking any
        # image_channels != 3 (e.g. grayscale data).
        self.input = nn.Conv2d(image_channels, starting_channels, kernel_size=3, padding=1)
        current_channel_count = starting_channels
        input_channel_counts = []
        self.contracting_path = nn.ModuleList([])
        for i, channel_multiplier in enumerate(channel_mults):
            is_last_layer = i == len(channel_mults) - 1
            next_channel_count = channel_multiplier * starting_channels
            self.contracting_path.append(DownBlock(
                in_channels=current_channel_count,
                out_channels=next_channel_count,
                time_embedding_dim=time_embedding_dim,
                use_attn=attention_layers[i],
                dropout=dropout,
                # the deepest stage keeps its resolution
                downsample=not is_last_layer,
                width=res_block_width,
            ))
            current_channel_count = next_channel_count
            input_channel_counts.append(current_channel_count)
        self.bottleneck = Bottleneck(
            channels=current_channel_count,
            time_embedding_dim=time_embedding_dim,
            dropout=dropout,
        )
        self.expansive_path = nn.ModuleList([])
        for i, channel_multiplier in enumerate(reversed(channel_mults)):
            next_channel_count = channel_multiplier * starting_channels
            self.expansive_path.append(UpBlock(
                # each up stage consumes the previous output concatenated with
                # the matching skip connection from the contracting path
                in_channels=current_channel_count + input_channel_counts.pop(),
                out_channels=next_channel_count,
                time_embedding_dim=time_embedding_dim,
                use_attn=tuple(reversed(attention_layers))[i],
                dropout=dropout,
                upsample=i != len(channel_mults) - 1,
                width=res_block_width,
            ))
            current_channel_count = next_channel_count
        last_conv = nn.Conv2d(
            in_channels=starting_channels,
            out_channels=image_channels,
            kernel_size=3,
            padding=1,
        )
        # zero-init the output conv weight so the network starts near zero output
        last_conv.weight.data.zero_()
        self.head = nn.Sequential(
            nn.GroupNorm(32, starting_channels),
            nn.SiLU(),
            last_conv,
        )

    def forward(self, x, t):
        """x: (batch, image_channels, H, W); t: (batch,) timesteps."""
        t = self.time_encoding(t)
        return self._forward(x, t)

    def _forward(self, x, t):
        # Shared trunk, also called by ConditionalUnet after it injects class
        # information into the already-encoded ``t``.
        t = self.time_embedding(t)
        x = self.input(x)
        residuals = []
        for contracting_block in self.contracting_path:
            x, residual = contracting_block(x, t)
            residuals.append(residual)
        x = self.bottleneck(x, t)
        for expansive_block in self.expansive_path:
            # concatenate the matching skip connection along the channel axis
            x = torch.cat([x, residuals.pop()], dim=1)
            x = expansive_block(x, t)
        return self.head(x)
class ConditionalUnet(nn.Module):
    """Wraps a Unet to condition on a set of class labels.

    Labels are embedded (index 0 is a padding/null class whose embedding row
    is held at zero), summed over the label dimension, and added to the time
    encoding before it enters the wrapped Unet's trunk.
    """

    def __init__(self, unet, num_classes):
        super().__init__()
        self.is_conditional = True
        self.unet = unet
        self.num_classes = num_classes
        # +1 row for the padding index; padding_idx=0 keeps that row at zero,
        # so label 0 acts as "no conditioning".
        self.class_embedding = nn.Embedding(num_classes + 1, unet.starting_channels, padding_idx=0)

    def to(self, device):
        # remember the target device so callers can query it later
        self.device = device
        return super().to(device)

    def forward(self, x, t, cond=None):
        """Run the wrapped Unet, optionally conditioned on class labels.

        Args:
            x: input image batch.
            t: (batch,) timesteps.
            cond: optional (batch, n) integer labels; n labels per sample are
                embedded and summed into one conditioning vector.

        Fix: this method previously overwrote ``cond`` with a hard-coded
        all-zeros tensor of length 14 (debug leftover), silently discarding
        the caller's labels and crashing when ``self.device`` was unset.
        """
        t = self.unet.time_encoding(t)
        if cond is not None:
            cond = self.class_embedding(cond)
            # sum across the classes so we get a single vector representing
            # the set of classes
            cond = cond.sum(dim=1)
            t = t + cond
        return self.unet._forward(x, t)