Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from einops import einsum | |
| import numpy as np | |
| import pickle | |
| import glob | |
| import os | |
| # ========================================== | |
| # BLOCKS for VQVAE (Down, Mid, Up) | |
| # ========================================== | |
def get_time_embedding(time_steps, temb_dim):
    r"""
    Build sinusoidal time-step embeddings.

    :param time_steps: 1D tensor of shape (B,), one timestep per sample
    :param temb_dim: size D of the returned embedding (must be even)
    :return: (B, D) tensor — first half sin terms, second half cos terms
    """
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
    half_dim = temb_dim // 2
    # Transformer-style frequency schedule: 10000^(i / half_dim)
    exponent = torch.arange(start=0, end=half_dim, dtype=torch.float32,
                            device=time_steps.device) / half_dim
    freqs = 10000 ** exponent
    # (B,) -> (B, half_dim): each timestep divided by every frequency
    args = time_steps[:, None].repeat(1, half_dim) / freqs
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
class DownBlock(nn.Module):
    r"""
    Encoder / U-Net down block.

    Each of the ``num_layers`` stages runs:
      1. a resnet sub-block (GroupNorm -> SiLU -> Conv), optionally shifted
         by a projected time embedding,
      2. optional spatial self-attention,
      3. optional cross-attention against an external context.
    A single stride-2 conv halves the spatial resolution at the end when
    ``down_sample`` is True.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim,
                 down_sample, num_heads, num_layers, attn, norm_channels, cross_attn=False, context_dim=None):
        super().__init__()
        self.num_layers = num_layers
        self.down_sample = down_sample
        self.attn = attn
        self.context_dim = context_dim
        self.cross_attn = cross_attn
        self.t_emb_dim = t_emb_dim

        def stage_in(idx):
            # Only the very first stage sees the block's input channel count.
            return in_channels if idx == 0 else out_channels

        first_convs = []
        for idx in range(num_layers):
            first_convs.append(nn.Sequential(
                nn.GroupNorm(norm_channels, stage_in(idx)),
                nn.SiLU(),
                nn.Conv2d(stage_in(idx), out_channels,
                          kernel_size=3, stride=1, padding=1),
            ))
        self.resnet_conv_first = nn.ModuleList(first_convs)

        if self.t_emb_dim is not None:
            # One SiLU + Linear projection of the time embedding per stage.
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, out_channels))
                for _ in range(num_layers)
            ])

        self.resnet_conv_second = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(norm_channels, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            )
            for _ in range(num_layers)
        ])

        if self.attn:
            self.attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
            self.attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                 for _ in range(num_layers)])

        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
            self.cross_attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                 for _ in range(num_layers)])
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels) for _ in range(num_layers)])

        # 1x1 convs so the residual shortcut matches out_channels.
        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(stage_in(idx), out_channels, kernel_size=1)
            for idx in range(num_layers)
        ])
        self.down_sample_conv = (nn.Conv2d(out_channels, out_channels, 4, 2, 1)
                                 if self.down_sample else nn.Identity())

    def forward(self, x, t_emb=None, context=None):
        out = x
        for idx in range(self.num_layers):
            # Resnet stage with 1x1 shortcut.
            skip = out
            out = self.resnet_conv_first[idx](out)
            if self.t_emb_dim is not None:
                out = out + self.t_emb_layers[idx](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[idx](out)
            out = out + self.residual_input_conv[idx](skip)

            if self.attn:
                # Flatten the spatial grid into a token sequence for MHA.
                b, c, h, w = out.shape
                tokens = self.attention_norms[idx](out.reshape(b, c, h * w))
                tokens = tokens.transpose(1, 2)
                attn_out, _ = self.attentions[idx](tokens, tokens, tokens)
                out = out + attn_out.transpose(1, 2).reshape(b, c, h, w)

            if self.cross_attn:
                assert context is not None, "context cannot be None if cross attention layers are used"
                b, c, h, w = out.shape
                queries = self.cross_attention_norms[idx](out.reshape(b, c, h * w))
                queries = queries.transpose(1, 2)
                assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
                ctx = self.context_proj[idx](context)
                attn_out, _ = self.cross_attentions[idx](queries, ctx, ctx)
                out = out + attn_out.transpose(1, 2).reshape(b, c, h, w)

        # Halve H and W (or pass through when down_sample is False).
        return self.down_sample_conv(out)
class MidBlock(nn.Module):
    r"""
    Bottleneck block: one leading resnet stage, then ``num_layers``
    repetitions of (self-attention, optional cross-attention, resnet).
    Uses ``num_layers + 1`` resnet stages in total.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_channels, cross_attn=None, context_dim=None):
        super().__init__()
        self.num_layers = num_layers
        self.t_emb_dim = t_emb_dim
        self.context_dim = context_dim
        self.cross_attn = cross_attn

        def stage_in(idx):
            # Only the first stage sees the block's input channel count.
            return in_channels if idx == 0 else out_channels

        self.resnet_conv_first = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(norm_channels, stage_in(idx)),
                nn.SiLU(),
                nn.Conv2d(stage_in(idx), out_channels, kernel_size=3, stride=1, padding=1),
            )
            for idx in range(num_layers + 1)
        ])
        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                for _ in range(num_layers + 1)
            ])
        self.resnet_conv_second = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(norm_channels, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            )
            for _ in range(num_layers + 1)
        ])
        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
        self.attentions = nn.ModuleList(
            [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
             for _ in range(num_layers)])
        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
            self.cross_attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                 for _ in range(num_layers)])
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels) for _ in range(num_layers)])
        # 1x1 shortcut convs, one per resnet stage.
        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(stage_in(idx), out_channels, kernel_size=1)
            for idx in range(num_layers + 1)
        ])

    def _resnet_stage(self, out, stage, t_emb):
        # Conv pair with optional time shift in between and a 1x1 shortcut.
        skip = out
        out = self.resnet_conv_first[stage](out)
        if self.t_emb_dim is not None:
            out = out + self.t_emb_layers[stage](t_emb)[:, :, None, None]
        out = self.resnet_conv_second[stage](out)
        return out + self.residual_input_conv[stage](skip)

    def forward(self, x, t_emb=None, context=None):
        # Leading resnet stage (no attention).
        out = self._resnet_stage(x, 0, t_emb)
        for idx in range(self.num_layers):
            # Self-attention over flattened spatial positions.
            b, c, h, w = out.shape
            tokens = self.attention_norms[idx](out.reshape(b, c, h * w)).transpose(1, 2)
            attn_out, _ = self.attentions[idx](tokens, tokens, tokens)
            out = out + attn_out.transpose(1, 2).reshape(b, c, h, w)
            if self.cross_attn:
                assert context is not None, "context cannot be None if cross attention layers are used"
                b, c, h, w = out.shape
                queries = self.cross_attention_norms[idx](out.reshape(b, c, h * w)).transpose(1, 2)
                assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
                ctx = self.context_proj[idx](context)
                attn_out, _ = self.cross_attentions[idx](queries, ctx, ctx)
                out = out + attn_out.transpose(1, 2).reshape(b, c, h, w)
            # Trailing resnet stage for this repetition.
            out = self._resnet_stage(out, idx + 1, t_emb)
        return out
class UpBlock(nn.Module):
    r"""
    Decoder / U-Net up block: optional transpose-conv upsample, optional
    skip-connection concat, then ``num_layers`` resnet stages each
    optionally followed by spatial self-attention.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim,
                 up_sample, num_heads, num_layers, attn, norm_channels):
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample
        self.t_emb_dim = t_emb_dim
        self.attn = attn

        def stage_in(idx):
            # After concat, the first stage sees in_channels; later stages out_channels.
            return in_channels if idx == 0 else out_channels

        self.resnet_conv_first = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(norm_channels, stage_in(idx)),
                nn.SiLU(),
                nn.Conv2d(stage_in(idx), out_channels, kernel_size=3, stride=1, padding=1),
            )
            for idx in range(num_layers)
        ])
        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                for _ in range(num_layers)
            ])
        self.resnet_conv_second = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(norm_channels, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            )
            for _ in range(num_layers)
        ])
        if self.attn:
            self.attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
            self.attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                 for _ in range(num_layers)])
        # 1x1 shortcut convs so the residual matches out_channels.
        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(stage_in(idx), out_channels, kernel_size=1)
            for idx in range(num_layers)
        ])
        self.up_sample_conv = (nn.ConvTranspose2d(in_channels, in_channels, 4, 2, 1)
                               if self.up_sample else nn.Identity())

    def forward(self, x, out_down=None, t_emb=None):
        # Double H and W (or pass through when up_sample is False).
        x = self.up_sample_conv(x)
        # Concatenate the matching encoder feature map along channels.
        if out_down is not None:
            x = torch.cat([x, out_down], dim=1)
        out = x
        for idx in range(self.num_layers):
            # Resnet stage with 1x1 shortcut.
            skip = out
            out = self.resnet_conv_first[idx](out)
            if self.t_emb_dim is not None:
                out = out + self.t_emb_layers[idx](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[idx](out)
            out = out + self.residual_input_conv[idx](skip)
            if self.attn:
                # Self-attention over flattened spatial positions.
                b, c, h, w = out.shape
                tokens = self.attention_norms[idx](out.reshape(b, c, h * w)).transpose(1, 2)
                attn_out, _ = self.attentions[idx](tokens, tokens, tokens)
                out = out + attn_out.transpose(1, 2).reshape(b, c, h, w)
        return out
class UpBlockUnet(nn.Module):
    r"""
    U-Net decoder block with always-on self-attention.

    Sequence: upsample -> concat skip -> [resnet -> self-attn ->
    optional cross-attn] x num_layers.

    Note: the transpose conv upsamples ``in_channels // 2`` channels, i.e.
    the pre-concat feature map; after the skip concat each first stage sees
    the full ``in_channels``.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, up_sample,
                 num_heads, num_layers, norm_channels, cross_attn=False, context_dim=None):
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample
        self.t_emb_dim = t_emb_dim
        self.cross_attn = cross_attn
        self.context_dim = context_dim

        def stage_in(idx):
            return in_channels if idx == 0 else out_channels

        self.resnet_conv_first = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(norm_channels, stage_in(idx)),
                nn.SiLU(),
                nn.Conv2d(stage_in(idx), out_channels, kernel_size=3, stride=1, padding=1),
            )
            for idx in range(num_layers)
        ])
        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                for _ in range(num_layers)
            ])
        self.resnet_conv_second = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(norm_channels, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            )
            for _ in range(num_layers)
        ])
        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
        self.attentions = nn.ModuleList(
            [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
             for _ in range(num_layers)])
        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
            self.cross_attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                 for _ in range(num_layers)])
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels) for _ in range(num_layers)])
        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(stage_in(idx), out_channels, kernel_size=1)
            for idx in range(num_layers)
        ])
        # Upsample acts on the pre-concat half of the channels.
        self.up_sample_conv = (nn.ConvTranspose2d(in_channels // 2, in_channels // 2, 4, 2, 1)
                               if self.up_sample else nn.Identity())

    def forward(self, x, out_down=None, t_emb=None, context=None):
        x = self.up_sample_conv(x)
        if out_down is not None:
            # Concatenate the matching encoder feature map along channels.
            x = torch.cat([x, out_down], dim=1)
        out = x
        for idx in range(self.num_layers):
            # Resnet stage with 1x1 shortcut.
            skip = out
            out = self.resnet_conv_first[idx](out)
            if self.t_emb_dim is not None:
                out = out + self.t_emb_layers[idx](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[idx](out)
            out = out + self.residual_input_conv[idx](skip)
            # Self-attention over flattened spatial positions.
            b, c, h, w = out.shape
            tokens = self.attention_norms[idx](out.reshape(b, c, h * w)).transpose(1, 2)
            attn_out, _ = self.attentions[idx](tokens, tokens, tokens)
            out = out + attn_out.transpose(1, 2).reshape(b, c, h, w)
            if self.cross_attn:
                assert context is not None, "context cannot be None if cross attention layers are used"
                b, c, h, w = out.shape
                queries = self.cross_attention_norms[idx](out.reshape(b, c, h * w)).transpose(1, 2)
                assert len(context.shape) == 3, \
                    "Context shape does not match B,_,CONTEXT_DIM"
                assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim,\
                    "Context shape does not match B,_,CONTEXT_DIM"
                ctx = self.context_proj[idx](context)
                attn_out, _ = self.cross_attentions[idx](queries, ctx, ctx)
                out = out + attn_out.transpose(1, 2).reshape(b, c, h, w)
        return out
| # ========================================== | |
| # VQVAE Definition | |
| # ========================================== | |
class VQVAE(nn.Module):
    """
    Vector-Quantized VAE: conv encoder (DownBlocks + MidBlocks), a discrete
    codebook with straight-through quantization, and a mirrored conv decoder
    (MidBlocks + UpBlocks). All blocks run without time conditioning
    (t_emb_dim=None).
    """

    def __init__(self, im_channels, model_config):
        super().__init__()
        self.down_channels = model_config['down_channels']
        self.mid_channels = model_config['mid_channels']
        self.down_sample = model_config['down_sample']
        self.num_down_layers = model_config['num_down_layers']
        self.num_mid_layers = model_config['num_mid_layers']
        self.num_up_layers = model_config['num_up_layers']
        # Per-resolution switch for self-attention in encoder DownBlocks
        # and decoder UpBlocks.
        self.attns = model_config['attn_down']
        # Latent dimension
        self.z_channels = model_config['z_channels']
        self.codebook_size = model_config['codebook_size']
        self.norm_channels = model_config['norm_channels']
        self.num_heads = model_config['num_heads']

        # Sanity-check the channel configuration.
        assert self.mid_channels[0] == self.down_channels[-1]
        assert self.mid_channels[-1] == self.down_channels[-1]
        assert len(self.down_sample) == len(self.down_channels) - 1
        assert len(self.attns) == len(self.down_channels) - 1

        # Mirror the encoder's downsampling pattern in the decoder.
        self.up_sample = list(reversed(self.down_sample))

        ## Encoder ##
        self.encoder_conv_in = nn.Conv2d(im_channels, self.down_channels[0],
                                         kernel_size=3, padding=(1, 1))
        self.encoder_layers = nn.ModuleList([])
        for idx in range(len(self.down_channels) - 1):
            self.encoder_layers.append(
                DownBlock(self.down_channels[idx], self.down_channels[idx + 1],
                          t_emb_dim=None, down_sample=self.down_sample[idx],
                          num_heads=self.num_heads,
                          num_layers=self.num_down_layers,
                          attn=self.attns[idx],
                          norm_channels=self.norm_channels))
        self.encoder_mids = nn.ModuleList([])
        for idx in range(len(self.mid_channels) - 1):
            self.encoder_mids.append(
                MidBlock(self.mid_channels[idx], self.mid_channels[idx + 1],
                         t_emb_dim=None,
                         num_heads=self.num_heads,
                         num_layers=self.num_mid_layers,
                         norm_channels=self.norm_channels))
        self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1])
        self.encoder_conv_out = nn.Conv2d(self.down_channels[-1], self.z_channels,
                                          kernel_size=3, padding=1)
        # Pre-quantization 1x1 conv
        self.pre_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)

        # Discrete codebook of latent vectors.
        self.embedding = nn.Embedding(self.codebook_size, self.z_channels)

        ## Decoder ##
        # Post-quantization 1x1 conv
        self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
        self.decoder_conv_in = nn.Conv2d(self.z_channels, self.mid_channels[-1],
                                         kernel_size=3, padding=(1, 1))
        self.decoder_mids = nn.ModuleList([])
        for idx in reversed(range(1, len(self.mid_channels))):
            self.decoder_mids.append(
                MidBlock(self.mid_channels[idx], self.mid_channels[idx - 1],
                         t_emb_dim=None,
                         num_heads=self.num_heads,
                         num_layers=self.num_mid_layers,
                         norm_channels=self.norm_channels))
        self.decoder_layers = nn.ModuleList([])
        for idx in reversed(range(1, len(self.down_channels))):
            self.decoder_layers.append(
                UpBlock(self.down_channels[idx], self.down_channels[idx - 1],
                        t_emb_dim=None, up_sample=self.down_sample[idx - 1],
                        num_heads=self.num_heads,
                        num_layers=self.num_up_layers,
                        attn=self.attns[idx - 1],
                        norm_channels=self.norm_channels))
        self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0])
        self.decoder_conv_out = nn.Conv2d(self.down_channels[0], im_channels,
                                          kernel_size=3, padding=1)

    def quantize(self, x):
        """
        Snap encoder features to their nearest codebook vectors.

        :param x: (B, C, H, W) continuous latents
        :return: (quantized latents (B, C, H, W),
                  dict with 'codebook_loss' and 'commitment_loss',
                  nearest codebook indices (B, H, W))
        """
        B, C, H, W = x.shape
        # (B, C, H, W) -> (B, H*W, C): one latent vector per spatial site.
        flat = x.permute(0, 2, 3, 1)
        flat = flat.reshape(flat.size(0), -1, flat.size(-1))
        # Pairwise distances to every codebook entry: (B, H*W, K).
        codebook = self.embedding.weight[None, :].repeat((flat.size(0), 1, 1))
        dist = torch.cdist(flat, codebook)
        # (B, H*W) index of the nearest codebook row per site.
        min_encoding_indices = torch.argmin(dist, dim=-1)
        # Gather nearest codebook rows: (B*H*W, C).
        quant_out = torch.index_select(self.embedding.weight, 0,
                                       min_encoding_indices.view(-1))
        flat = flat.reshape((-1, flat.size(-1)))
        # Commitment pulls the encoder toward the codebook; codebook loss the reverse.
        commitment_loss = torch.mean((quant_out.detach() - flat) ** 2)
        codebook_loss = torch.mean((quant_out - flat.detach()) ** 2)
        quantize_losses = {
            'codebook_loss': codebook_loss,
            'commitment_loss': commitment_loss
        }
        # Straight-through estimator: gradients bypass the argmin.
        quant_out = flat + (quant_out - flat).detach()
        quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2)
        min_encoding_indices = min_encoding_indices.reshape(
            (-1, quant_out.size(-2), quant_out.size(-1)))
        return quant_out, quantize_losses, min_encoding_indices

    def encode(self, x):
        """Encode an image to quantized latents; returns (latents, loss dict)."""
        out = self.encoder_conv_in(x)
        for down in self.encoder_layers:
            out = down(out)
        for mid in self.encoder_mids:
            out = mid(out)
        out = self.encoder_norm_out(out)
        out = nn.SiLU()(out)
        out = self.encoder_conv_out(out)
        out = self.pre_quant_conv(out)
        out, quant_losses, _ = self.quantize(out)
        return out, quant_losses

    def decode(self, z):
        """Decode quantized latents back to image space."""
        out = self.post_quant_conv(z)
        out = self.decoder_conv_in(out)
        for mid in self.decoder_mids:
            out = mid(out)
        for up in self.decoder_layers:
            out = up(out)
        out = self.decoder_norm_out(out)
        out = nn.SiLU()(out)
        out = self.decoder_conv_out(out)
        return out

    def forward(self, x):
        """Full autoencode pass: returns (reconstruction, latents, loss dict)."""
        z, quant_losses = self.encode(x)
        out = self.decode(z)
        return out, z, quant_losses
| # ========================================== | |
| # SPADE Definitions | |
| # ========================================== | |
class SPADE(nn.Module):
    """
    Spatially-Adaptive (De)normalization layer.

    Normalizes the input with GroupNorm, then rescales/shifts it with
    per-pixel gamma/beta maps predicted from a segmentation mask.
    """

    def __init__(self, norm_nc, label_nc):
        super().__init__()
        # NOTE(review): named "param_free" but nn.GroupNorm defaults to
        # affine=True, so this norm does carry learned weights — confirm intent.
        self.param_free_norm = nn.GroupNorm(32, norm_nc)
        nhidden = 128
        # Shared trunk plus two heads mapping the mask to gamma/beta maps.
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)

    def forward(self, x, segmap):
        # Normalize first, modulate afterwards.
        normalized = self.param_free_norm(x)
        # Match the mask to x's spatial size before predicting modulation.
        if segmap.size()[2:] != x.size()[2:]:
            segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
        hidden = self.mlp_shared(segmap)
        gamma = self.mlp_gamma(hidden)
        beta = self.mlp_beta(hidden)
        return normalized * (1 + gamma) + beta
class SPADEResnetBlock(nn.Module):
    """
    Minimal SPADE stage: SPADE norm -> SiLU -> 3x3 conv.

    There is deliberately no internal shortcut: the Spade{Down,Mid}Block
    callers add their own 1x1 residual connection around a pair of these.
    """

    def __init__(self, in_channels, out_channels, label_nc):
        super().__init__()
        # Mask-conditioned normalization, then activation, then conv.
        self.norm1 = SPADE(in_channels, label_nc)
        self.act1 = nn.SiLU()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x, segmap):
        h = self.norm1(x, segmap)
        h = self.act1(h)
        return self.conv1(h)
| # ========================================== | |
| # BLOCKS (Down, Mid, Up) | |
| # ========================================== | |
def get_time_embedding(time_steps, temb_dim):
    """
    Sinusoidal timestep embedding for the SPADE section.

    Re-declared here and identical to the earlier definition in this file,
    which it shadows at import time.
    """
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
    half = temb_dim // 2
    # Frequencies: 10000^(i / half) for i in [0, half).
    freqs = 10000 ** (torch.arange(start=0, end=half, dtype=torch.float32,
                                   device=time_steps.device) / half)
    args = time_steps[:, None].repeat(1, half) / freqs
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
class SpadeDownBlock(nn.Module):
    """
    DownBlock variant whose resnet stages are SPADEResnetBlocks, so every
    stage is conditioned on a segmentation map in addition to the optional
    time embedding. Layout otherwise mirrors DownBlock:
    [resnet -> optional self-attn -> optional cross-attn] x num_layers,
    then a stride-2 downsample conv.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, down_sample, num_heads,
                 num_layers, attn, norm_channels, cross_attn=False, context_dim=None, label_nc=4):
        super().__init__()
        self.num_layers = num_layers
        self.down_sample = down_sample
        self.attn = attn
        self.context_dim = context_dim
        self.cross_attn = cross_attn
        self.t_emb_dim = t_emb_dim

        def stage_in(idx):
            return in_channels if idx == 0 else out_channels

        # SPADE-conditioned replacements for the plain GroupNorm/SiLU/Conv stages.
        self.resnet_conv_first = nn.ModuleList([
            SPADEResnetBlock(stage_in(idx), out_channels, label_nc)
            for idx in range(num_layers)
        ])
        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, out_channels))
                for _ in range(num_layers)
            ])
        self.resnet_conv_second = nn.ModuleList([
            SPADEResnetBlock(out_channels, out_channels, label_nc)
            for _ in range(num_layers)
        ])
        if self.attn:
            self.attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
            self.attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                 for _ in range(num_layers)])
        if self.cross_attn:
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
            self.cross_attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                 for _ in range(num_layers)])
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels) for _ in range(num_layers)])
        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(stage_in(idx), out_channels, kernel_size=1)
            for idx in range(num_layers)
        ])
        self.down_sample_conv = (nn.Conv2d(out_channels, out_channels, 4, 2, 1)
                                 if self.down_sample else nn.Identity())

    def forward(self, x, t_emb=None, context=None, segmap=None):
        out = x
        for idx in range(self.num_layers):
            skip = out
            # Both SPADE stages receive the segmentation map.
            out = self.resnet_conv_first[idx](out, segmap)
            if self.t_emb_dim is not None:
                out = out + self.t_emb_layers[idx](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[idx](out, segmap)
            # SPADEResnetBlock has no internal shortcut, so add the 1x1 one here.
            out = out + self.residual_input_conv[idx](skip)
            if self.attn:
                b, c, h, w = out.shape
                tokens = self.attention_norms[idx](out.reshape(b, c, h * w)).transpose(1, 2)
                attn_out, _ = self.attentions[idx](tokens, tokens, tokens)
                out = out + attn_out.transpose(1, 2).reshape(b, c, h, w)
            if self.cross_attn:
                b, c, h, w = out.shape
                queries = self.cross_attention_norms[idx](out.reshape(b, c, h * w)).transpose(1, 2)
                ctx = self.context_proj[idx](context)
                attn_out, _ = self.cross_attentions[idx](queries, ctx, ctx)
                out = out + attn_out.transpose(1, 2).reshape(b, c, h, w)
        return self.down_sample_conv(out)
class SpadeMidBlock(nn.Module):
    """
    MidBlock variant with SPADE-conditioned resnet stages: one leading
    resnet stage, then (self-attn, optional cross-attn, resnet) repeated
    ``num_layers`` times — ``num_layers + 1`` resnet stages in total.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_channels, cross_attn=None, context_dim=None, label_nc=4):
        super().__init__()
        self.num_layers = num_layers
        self.t_emb_dim = t_emb_dim
        self.context_dim = context_dim
        self.cross_attn = cross_attn

        def stage_in(idx):
            return in_channels if idx == 0 else out_channels

        # SPADE-conditioned replacements for the plain conv stages.
        self.resnet_conv_first = nn.ModuleList([
            SPADEResnetBlock(stage_in(idx), out_channels, label_nc)
            for idx in range(num_layers + 1)
        ])
        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                for _ in range(num_layers + 1)
            ])
        self.resnet_conv_second = nn.ModuleList([
            SPADEResnetBlock(out_channels, out_channels, label_nc)
            for _ in range(num_layers + 1)
        ])
        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
        self.attentions = nn.ModuleList(
            [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
             for _ in range(num_layers)])
        if self.cross_attn:
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
            self.cross_attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                 for _ in range(num_layers)])
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels) for _ in range(num_layers)])
        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(stage_in(idx), out_channels, kernel_size=1)
            for idx in range(num_layers + 1)
        ])

    def _resnet_stage(self, out, stage, t_emb, segmap):
        # SPADE stage pair with optional time shift and a 1x1 shortcut.
        skip = out
        out = self.resnet_conv_first[stage](out, segmap)
        if self.t_emb_dim is not None:
            out = out + self.t_emb_layers[stage](t_emb)[:, :, None, None]
        out = self.resnet_conv_second[stage](out, segmap)
        return out + self.residual_input_conv[stage](skip)

    def forward(self, x, t_emb=None, context=None, segmap=None):
        # Leading resnet stage (no attention).
        out = self._resnet_stage(x, 0, t_emb, segmap)
        for idx in range(self.num_layers):
            # Self-attention over flattened spatial positions.
            b, c, h, w = out.shape
            tokens = self.attention_norms[idx](out.reshape(b, c, h * w)).transpose(1, 2)
            attn_out, _ = self.attentions[idx](tokens, tokens, tokens)
            out = out + attn_out.transpose(1, 2).reshape(b, c, h, w)
            if self.cross_attn:
                b, c, h, w = out.shape
                queries = self.cross_attention_norms[idx](out.reshape(b, c, h * w)).transpose(1, 2)
                ctx = self.context_proj[idx](context)
                attn_out, _ = self.cross_attentions[idx](queries, ctx, ctx)
                out = out + attn_out.transpose(1, 2).reshape(b, c, h, w)
            # Trailing resnet stage for this repetition.
            out = self._resnet_stage(out, idx + 1, t_emb, segmap)
        return out
class SpadeUpBlock(nn.Module):
    """
    Up-sampling UNet block whose resnet normalizations are SPADE layers, so a
    segmentation map can modulate features at every stage.

    Per layer: SPADE resnet (with optional time embedding) -> self-attention
    -> optional cross-attention. The input is first upsampled by a transposed
    conv and concatenated with the matching down-block skip connection.
    """
    def __init__(self, in_channels, out_channels, t_emb_dim, up_sample, num_heads,
                 num_layers, norm_channels, cross_attn=False, context_dim=None, label_nc=4):
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample
        self.t_emb_dim = t_emb_dim
        self.cross_attn = cross_attn
        self.context_dim = context_dim
        # Only the first layer sees in_channels; later layers see out_channels.
        layer_in_channels = [in_channels if i == 0 else out_channels
                             for i in range(num_layers)]
        # SPADE resnet conv stacks (mask is injected inside SPADEResnetBlock).
        self.resnet_conv_first = nn.ModuleList(
            SPADEResnetBlock(ch, out_channels, label_nc) for ch in layer_in_channels
        )
        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList(
                nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                for _ in range(num_layers)
            )
        self.resnet_conv_second = nn.ModuleList(
            SPADEResnetBlock(out_channels, out_channels, label_nc)
            for _ in range(num_layers)
        )
        self.attention_norms = nn.ModuleList(
            nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)
        )
        self.attentions = nn.ModuleList(
            nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
            for _ in range(num_layers)
        )
        if self.cross_attn:
            self.cross_attention_norms = nn.ModuleList(
                nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)
            )
            self.cross_attentions = nn.ModuleList(
                nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                for _ in range(num_layers)
            )
            self.context_proj = nn.ModuleList(
                nn.Linear(context_dim, out_channels) for _ in range(num_layers)
            )
        # 1x1 convs for the residual shortcuts around each resnet stage.
        self.residual_input_conv = nn.ModuleList(
            nn.Conv2d(ch, out_channels, kernel_size=1) for ch in layer_in_channels
        )
        # in_channels already counts the skip concat, so the transposed conv
        # operates on half of it (the pre-concat feature width).
        if self.up_sample:
            self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, 4, 2, 1)
        else:
            self.up_sample_conv = nn.Identity()
    def forward(self, x, out_down=None, t_emb=None, context=None, segmap=None):
        """
        Upsample ``x``, concatenate the skip tensor ``out_down``, then run the
        per-layer resnet/attention stack. ``segmap`` is passed to every SPADE
        resnet; ``context`` feeds cross-attention when enabled.
        """
        out = self.up_sample_conv(x)
        if out_down is not None:
            out = torch.cat([out, out_down], dim=1)
        for i in range(self.num_layers):
            # SPADE resnet stage with residual shortcut.
            skip = out
            out = self.resnet_conv_first[i](out, segmap)
            if self.t_emb_dim is not None:
                out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[i](out, segmap)
            out = out + self.residual_input_conv[i](skip)
            # Self-attention over flattened spatial positions.
            b, c, h, w = out.shape
            tokens = self.attention_norms[i](out.reshape(b, c, h * w)).transpose(1, 2)
            attended, _ = self.attentions[i](tokens, tokens, tokens)
            out = out + attended.transpose(1, 2).reshape(b, c, h, w)
            if self.cross_attn:
                # Cross-attention against the projected context sequence.
                b, c, h, w = out.shape
                tokens = self.cross_attention_norms[i](out.reshape(b, c, h * w)).transpose(1, 2)
                ctx = self.context_proj[i](context)
                attended, _ = self.cross_attentions[i](tokens, ctx, ctx)
                out = out + attended.transpose(1, 2).reshape(b, c, h, w)
        return out
| # ========================================== | |
# Helper Functions
| # ========================================== | |
def validate_image_config(condition_config):
    """Assert that the config carries a complete image-conditioning section."""
    assert 'image_condition_config' in condition_config, "Image conditioning desired but config missing"
    image_cfg = condition_config['image_condition_config']
    assert 'image_condition_input_channels' in image_cfg, "Input channels missing"
    assert 'image_condition_output_channels' in image_cfg, "Output channels missing"
def validate_image_conditional_input(cond_input, x):
    """Assert an image condition is present and batch-aligned with ``x``."""
    assert 'image' in cond_input, "Model initialized with image conditioning but input missing"
    image_cond = cond_input['image']
    assert image_cond.shape[0] == x.shape[0], "Batch size mismatch"
def get_config_value(config, key, default_value):
    """
    Look up ``key`` in ``config``, falling back to ``default_value``.

    :param config: mapping of configuration entries
    :param key: entry name to fetch
    :param default_value: value returned when ``key`` is absent
    :return: ``config[key]`` if present, else ``default_value``
    """
    # dict.get performs the membership test and lookup in one step,
    # replacing the LBYL `key in config` double lookup.
    return config.get(key, default_value)
def get_time_embedding(time_steps, temb_dim):
    """
    Sinusoidal timestep embedding (Transformer/DDPM style).

    NOTE(review): this re-defines ``get_time_embedding`` from earlier in the
    file with identical behavior; consider removing one of the two copies.

    :param time_steps: 1D tensor of B timesteps
    :param temb_dim: embedding dimension (must be even)
    :return: (B, temb_dim) tensor of [sin | cos] features
    """
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
    # factor = 10000^(i / (temb_dim/2)) for i in [0, temb_dim/2)
    factor = 10000 ** (torch.arange(
        start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device
    ) / (temb_dim // 2))
    # Broadcasting (B, 1) against (temb_dim/2,) replaces the explicit
    # repeat(), avoiding an unnecessary materialized copy.
    t_emb = time_steps[:, None] / factor
    return torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
def drop_image_condition(image_condition, im, im_drop_prob):
    """
    Per-sample conditioning dropout (classifier-free guidance style): each
    sample's image condition is zeroed with probability ``im_drop_prob``.

    :param image_condition: conditioning tensor to (possibly) drop
    :param im: batch tensor used only for its batch size and device
    :param im_drop_prob: dropout probability; <= 0 disables dropping
    :return: conditioning tensor with some samples zeroed out
    """
    if im_drop_prob <= 0:
        return image_condition
    batch = im.shape[0]
    # One uniform draw per sample; keep the sample when the draw exceeds
    # the drop probability (bool mask broadcasts over C, H, W).
    scores = torch.zeros((batch, 1, 1, 1), device=im.device).float().uniform_(0, 1)
    keep_mask = scores > im_drop_prob
    return image_condition * keep_mask
| # ========================================== | |
| # UNET Definition | |
| # ========================================== | |
class Unet(nn.Module):
    """
    UNet noise-predictor with SPADE-normalized blocks for anatomical
    consistency.

    Instead of concatenating the segmentation mask to the input, the mask is
    handed to every Spade{Down,Mid,Up}Block, where SPADE layers use it to
    modulate normalized activations.

    :param im_channels: channels of the (latent) input/output image
    :param model_config: dict of architecture hyper-parameters (keys read in
        ``__init__``); may carry an optional 'condition_config' section
    """
    def __init__(self, im_channels, model_config):
        super().__init__()
        # Channel/layer layout straight from the model config.
        self.down_channels = model_config['down_channels']
        self.mid_channels = model_config['mid_channels']
        self.t_emb_dim = model_config['time_emb_dim']
        self.down_sample = model_config['down_sample']
        self.num_down_layers = model_config['num_down_layers']
        self.num_mid_layers = model_config['num_mid_layers']
        self.num_up_layers = model_config['num_up_layers']
        self.attns = model_config['attn_down']
        self.norm_channels = model_config['norm_channels']
        self.num_heads = model_config['num_heads']
        self.conv_out_channels = model_config['conv_out_channels']
        # Validate Config: mid blocks must dovetail with the down path, and
        # the per-level flag lists need one entry per down transition.
        assert self.mid_channels[0] == self.down_channels[-1]
        assert self.mid_channels[-1] == self.down_channels[-2]
        assert len(self.down_sample) == len(self.down_channels) - 1
        assert len(self.attns) == len(self.down_channels) - 1
        # Conditioning Setup
        self.image_cond = False
        self.condition_config = get_config_value(model_config, 'condition_config', None)
        # Default mask channels (usually 4: BG, LV, Myo, RV)
        self.im_cond_input_ch = 4
        if self.condition_config is not None:
            if 'image' in self.condition_config.get('condition_types', []):
                self.image_cond = True
                self.im_cond_input_ch = self.condition_config['image_condition_config']['image_condition_input_channels']
                self.im_cond_output_ch = self.condition_config['image_condition_config']['image_condition_output_channels']
        # Standard Input Conv
        # SPADE injects the mask later, so we just take the latent input here.
        self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=1)
        # Time Embedding: MLP applied on top of the sinusoidal embedding.
        self.t_proj = nn.Sequential(
            nn.Linear(self.t_emb_dim, self.t_emb_dim), nn.SiLU(), nn.Linear(self.t_emb_dim, self.t_emb_dim)
        )
        self.up_sample = list(reversed(self.down_sample))
        self.downs = nn.ModuleList([])
        # Pass label_nc to Blocks so SPADE layers know the mask channel count.
        for i in range(len(self.down_channels) - 1):
            self.downs.append(SpadeDownBlock(
                self.down_channels[i], self.down_channels[i + 1], self.t_emb_dim,
                down_sample=self.down_sample[i], num_heads=self.num_heads,
                num_layers=self.num_down_layers, attn=self.attns[i],
                norm_channels=self.norm_channels,
                label_nc=self.im_cond_input_ch  # SPADE needs this
            ))
        self.mids = nn.ModuleList([])
        for i in range(len(self.mid_channels) - 1):
            self.mids.append(SpadeMidBlock(
                self.mid_channels[i], self.mid_channels[i + 1], self.t_emb_dim,
                num_heads=self.num_heads, num_layers=self.num_mid_layers,
                norm_channels=self.norm_channels,
                label_nc=self.im_cond_input_ch  # SPADE needs this
            ))
        self.ups = nn.ModuleList([])
        # Up path mirrors the down path; input channels are doubled because
        # forward() concatenates the skip connection before each up block.
        for i in reversed(range(len(self.down_channels) - 1)):
            self.ups.append(SpadeUpBlock(
                self.down_channels[i] * 2, self.down_channels[i - 1] if i != 0 else self.conv_out_channels,
                self.t_emb_dim, up_sample=self.down_sample[i], num_heads=self.num_heads,
                num_layers=self.num_up_layers, norm_channels=self.norm_channels,
                label_nc=self.im_cond_input_ch  # SPADE needs this
            ))
        self.norm_out = nn.GroupNorm(self.norm_channels, self.conv_out_channels)
        self.conv_out = nn.Conv2d(self.conv_out_channels, im_channels, kernel_size=3, padding=1)
    def forward(self, x, t, cond_input=None):
        """
        Predict the noise in ``x`` at timestep ``t``.

        :param x: noisy (latent) image, shape (B, im_channels, H, W)
        :param t: per-sample timesteps; converted via torch.as_tensor(t).long()
            and then indexed as a 1D tensor by get_time_embedding, so a
            batch-length tensor is expected — TODO confirm against callers
        :param cond_input: dict with key 'image' (segmentation mask) when the
            model was built with image conditioning; ignored otherwise
        :return: predicted noise, same shape as ``x``
        """
        # 1. Validation
        if self.image_cond:
            validate_image_conditional_input(cond_input, x)
            # Get the mask, but don't concatenate yet
            im_cond = cond_input['image']
        else:
            im_cond = None
        # 2. Initial Conv (Standard)
        out = self.conv_in(x)
        # 3. Time Embedding
        t_emb = get_time_embedding(torch.as_tensor(t).long(), self.t_emb_dim)
        t_emb = self.t_proj(t_emb)
        # 4. Down Blocks (Pass segmap); stash pre-block activations for skips
        down_outs = []
        for down in self.downs:
            down_outs.append(out)
            # Inject Mask into Block
            out = down(out, t_emb, segmap=im_cond)
        # 5. Mid Blocks (Pass segmap)
        for mid in self.mids:
            # Inject Mask into Block
            out = mid(out, t_emb, segmap=im_cond)
        # 6. Up Blocks (Pass segmap), consuming skip connections LIFO
        for up in self.ups:
            down_out = down_outs.pop()
            # Inject Mask into Block
            out = up(out, down_out, t_emb, segmap=im_cond)
        out = self.norm_out(out)
        # NOTE(review): nn.SiLU()(out) builds a fresh module on every call;
        # F.silu(out) would be equivalent and cheaper.
        out = nn.SiLU()(out)
        out = self.conv_out(out)
        return out
| # ========================================== | |
# Noise Scheduler Definition
| # ========================================== | |
class LinearNoiseScheduler:
    """
    DDPM scheduler whose betas are linear in sqrt-space:
    ``linspace(beta_start**0.5, beta_end**0.5, T) ** 2``.

    Precomputes alphas, their cumulative product, and the square roots used
    by the forward (``add_noise``) and reverse (``sample_prev_timestep``)
    processes.
    """

    def __init__(self, num_timesteps, beta_start, beta_end):
        self.num_timesteps = num_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end
        sqrt_betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_timesteps)
        self.betas = sqrt_betas ** 2
        self.alphas = 1. - self.betas
        self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
        self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)

    def add_noise(self, original, noise, t):
        """
        Forward diffusion: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps.

        :param original: clean batch x_0, any rank with batch dim first
        :param noise: gaussian noise with the same shape as ``original``
        :param t: 1D tensor of per-sample timesteps
        :return: noised batch x_t
        """
        # Broadcast the per-sample scalars over all non-batch dimensions.
        bcast = (original.shape[0],) + (1,) * (original.dim() - 1)
        scale = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(bcast)
        sigma = self.sqrt_one_minus_alpha_cum_prod.to(original.device)[t].reshape(bcast)
        return scale * original + sigma * noise

    def sample_prev_timestep(self, xt, noise_pred, t):
        """
        One reverse-diffusion step: remove predicted noise to sample x_{t-1}.

        :param xt: current noisy batch, shape (B, C, H, W)
        :param noise_pred: model's noise estimate, same shape as ``xt``
        :param t: 1D tensor of timesteps; assumed uniform across the batch
            (only t[0] is checked for the final step)
        :return: (sample of x_{t-1}, clamped estimate of x_0)
        """
        device = xt.device
        sqrt_ab = self.sqrt_alpha_cum_prod.to(device)[t].view(-1, 1, 1, 1)
        sqrt_one_minus_ab = self.sqrt_one_minus_alpha_cum_prod.to(device)[t].view(-1, 1, 1, 1)
        beta = self.betas.to(device)[t].view(-1, 1, 1, 1)
        alpha = self.alphas.to(device)[t].view(-1, 1, 1, 1)
        # Estimate x0 from the noise prediction and clamp to the data range.
        x0 = (xt - sqrt_one_minus_ab * noise_pred) / sqrt_ab
        x0 = x0.clamp(-1., 1.)
        # Posterior mean of x_{t-1}.
        mean = (xt - beta * noise_pred / sqrt_one_minus_ab) / torch.sqrt(alpha)
        if t[0] == 0:
            # Final step: return the mean deterministically.
            return mean, x0
        # Otherwise add posterior noise scaled by the posterior std-dev.
        abar = self.alpha_cum_prod.to(device)
        variance = (1 - abar[t - 1]) / (1.0 - abar[t]) * self.betas.to(device)[t]
        sigma = (variance ** 0.5).view(-1, 1, 1, 1)
        z = torch.randn(xt.shape).to(device)
        return mean + sigma * z, x0