| import torch |
| import torch.nn as nn |
|
|
|
|
def get_time_embedding(time_steps, temb_dim):
    r"""
    Convert time steps tensor into an embedding using the
    sinusoidal time embedding formula.

    The first half of the embedding holds ``sin(t / factor)`` and the
    second half ``cos(t / factor)``, with
    ``factor = 10000 ** (i / (temb_dim // 2))`` for ``i`` in
    ``[0, temb_dim // 2)``.

    :param time_steps: 1D tensor of length batch size
    :param temb_dim: Dimension of the embedding (must be even)
    :return: BxD embedding representation of B time steps
    :raises AssertionError: if ``temb_dim`` is not divisible by 2
    """
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"

    half_dim = temb_dim // 2

    # factor = 10000^(i / half_dim), created on the same device as the input
    exponent = torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=time_steps.device
    ) / half_dim
    factor = 10000 ** exponent

    # (B, 1) / (half_dim,) broadcasts to (B, half_dim); no explicit
    # repeat of the time steps is needed.
    t_emb = time_steps[:, None] / factor
    return torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
|
|
|
|
class DownBlock(nn.Module):
    r"""
    Down conv block with optional attention.
    Sequence of following blocks, repeated ``num_layers`` times
    1. Resnet block with time embedding
    2. Self attention block (only if ``attn``)
    3. Cross attention block over ``context`` (only if ``cross_attn``)
    followed by a single downsample (only if ``down_sample``).

    :param in_channels: Channels of the input feature map
    :param out_channels: Channels produced by every layer of the block
    :param t_emb_dim: Time embedding dimension, or None to skip the
        time embedding projection entirely
    :param down_sample: Whether to apply the strided conv at the end
    :param num_heads: Number of heads of every attention layer
    :param num_layers: Number of resnet (+ attention) layers
    :param attn: Whether each layer gets a self attention
    :param norm_channels: Number of groups used by every GroupNorm
    :param cross_attn: Whether each layer gets a cross attention
    :param context_dim: Size of the last dimension of ``context``;
        must be given when ``cross_attn`` is set
    """

    def __init__(self, in_channels, out_channels, t_emb_dim,
                 down_sample, num_heads, num_layers, attn, norm_channels, cross_attn=False, context_dim=None):
        super().__init__()
        self.num_layers = num_layers
        self.down_sample = down_sample
        self.attn = attn
        self.context_dim = context_dim
        self.cross_attn = cross_attn
        self.t_emb_dim = t_emb_dim

        def layer_in(i):
            # Only the first layer sees the block's input width.
            return in_channels if i == 0 else out_channels

        # NOTE: submodules are registered in a fixed order and under fixed
        # names so that state_dict keys and seeded initialisation stay stable.
        self.resnet_conv_first = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(norm_channels, layer_in(i)),
                    nn.SiLU(),
                    nn.Conv2d(layer_in(i), out_channels,
                              kernel_size=3, stride=1, padding=1),
                )
                for i in range(num_layers)
            ]
        )
        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(
                    nn.SiLU(),
                    nn.Linear(self.t_emb_dim, out_channels)
                )
                for _ in range(num_layers)
            ])
        self.resnet_conv_second = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(num_groups=norm_channels, num_channels=out_channels),
                    nn.SiLU(),
                    nn.Conv2d(out_channels, out_channels,
                              kernel_size=3, stride=1, padding=1),
                )
                for _ in range(num_layers)
            ]
        )

        if self.attn:
            self.attention_norms = nn.ModuleList(
                [nn.GroupNorm(num_groups=norm_channels, num_channels=out_channels)
                 for _ in range(num_layers)]
            )

            self.attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                 for _ in range(num_layers)]
            )

        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels)
                 for _ in range(num_layers)]
            )
            self.cross_attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                 for _ in range(num_layers)]
            )
            # Projects the context to the attention width (out_channels)
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels)
                 for _ in range(num_layers)]
            )

        # 1x1 conv so the residual matches out_channels
        self.residual_input_conv = nn.ModuleList(
            [
                nn.Conv2d(layer_in(i), out_channels, kernel_size=1)
                for i in range(num_layers)
            ]
        )
        # Kernel 4, stride 2, padding 1 conv when downsampling is requested
        self.down_sample_conv = nn.Conv2d(out_channels, out_channels,
                                          4, 2, 1) if self.down_sample else nn.Identity()

    def _resnet_layer(self, i, x, t_emb):
        # Resnet layer i: conv -> (+ time embedding) -> conv, plus projected skip.
        out = self.resnet_conv_first[i](x)
        if self.t_emb_dim is not None:
            # Broadcast the per-sample embedding over the spatial dims
            out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
        out = self.resnet_conv_second[i](out)
        return out + self.residual_input_conv[i](x)

    def _self_attn_layer(self, i, x):
        # Self attention over the h*w spatial positions, added residually.
        batch_size, channels, h, w = x.shape
        tokens = x.reshape(batch_size, channels, h * w)
        tokens = self.attention_norms[i](tokens)
        tokens = tokens.transpose(1, 2)  # (B, h*w, C) for batch_first attention
        attended, _ = self.attentions[i](tokens, tokens, tokens)
        attended = attended.transpose(1, 2).reshape(batch_size, channels, h, w)
        return x + attended

    def _cross_attn_layer(self, i, x, context):
        # Cross attention: spatial positions query the projected context.
        assert context is not None, "context cannot be None if cross attention layers are used"
        batch_size, channels, h, w = x.shape
        tokens = x.reshape(batch_size, channels, h * w)
        tokens = self.cross_attention_norms[i](tokens)
        tokens = tokens.transpose(1, 2)
        assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
        context_proj = self.context_proj[i](context)
        attended, _ = self.cross_attentions[i](tokens, context_proj, context_proj)
        attended = attended.transpose(1, 2).reshape(batch_size, channels, h, w)
        return x + attended

    def forward(self, x, t_emb=None, context=None):
        r"""
        :param x: B x in_channels x H x W feature map
        :param t_emb: B x t_emb_dim time embedding; used only when the
            block was built with a ``t_emb_dim``
        :param context: Conditioning tensor whose first dimension is B and
            last dimension is ``context_dim``; required when ``cross_attn``
        :return: Feature map with ``out_channels`` channels, passed through
            the downsample conv if ``down_sample`` was set
        """
        out = x
        for i in range(self.num_layers):
            out = self._resnet_layer(i, out, t_emb)
            if self.attn:
                out = self._self_attn_layer(i, out)
            if self.cross_attn:
                out = self._cross_attn_layer(i, out, context)

        return self.down_sample_conv(out)
|
|
|
|
class MidBlock(nn.Module):
    r"""
    Mid conv block with attention.
    Sequence of following blocks
    1. Resnet block with time embedding
    2. Attention block (plus cross attention block if ``cross_attn``)
    3. Resnet block with time embedding
    where steps 2 and 3 are repeated ``num_layers`` times, so the block
    holds ``num_layers + 1`` resnet layers and ``num_layers`` attentions.

    :param in_channels: Channels of the input feature map
    :param out_channels: Channels produced by every layer of the block
    :param t_emb_dim: Time embedding dimension, or None to skip the
        time embedding projection entirely
    :param num_heads: Number of heads of every attention layer
    :param num_layers: Number of attention + resnet repetitions
    :param norm_channels: Number of groups used by every GroupNorm
    :param cross_attn: Truthy to add a cross attention after each
        self attention
    :param context_dim: Size of the last dimension of ``context``;
        must be given when ``cross_attn`` is set
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_channels, cross_attn=None, context_dim=None):
        super().__init__()
        self.num_layers = num_layers
        self.t_emb_dim = t_emb_dim
        self.context_dim = context_dim
        self.cross_attn = cross_attn

        def layer_in(i):
            # Only the first resnet sees the block's input width.
            return in_channels if i == 0 else out_channels

        # NOTE: submodules are registered in a fixed order and under fixed
        # names so that state_dict keys and seeded initialisation stay stable.
        self.resnet_conv_first = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(norm_channels, layer_in(i)),
                    nn.SiLU(),
                    nn.Conv2d(layer_in(i), out_channels, kernel_size=3, stride=1,
                              padding=1),
                )
                for i in range(num_layers + 1)
            ]
        )

        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(
                    nn.SiLU(),
                    nn.Linear(t_emb_dim, out_channels)
                )
                for _ in range(num_layers + 1)
            ])
        self.resnet_conv_second = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(norm_channels, out_channels),
                    nn.SiLU(),
                    nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
                )
                for _ in range(num_layers + 1)
            ]
        )

        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(norm_channels, out_channels)
             for _ in range(num_layers)]
        )

        self.attentions = nn.ModuleList(
            [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
             for _ in range(num_layers)]
        )
        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels)
                 for _ in range(num_layers)]
            )
            self.cross_attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                 for _ in range(num_layers)]
            )
            # Projects the context to the attention width (out_channels)
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels)
                 for _ in range(num_layers)]
            )
        # 1x1 conv so the residual matches out_channels
        self.residual_input_conv = nn.ModuleList(
            [
                nn.Conv2d(layer_in(i), out_channels, kernel_size=1)
                for i in range(num_layers + 1)
            ]
        )

    def _resnet_layer(self, i, x, t_emb):
        # Resnet layer i: conv -> (+ time embedding) -> conv, plus projected skip.
        out = self.resnet_conv_first[i](x)
        if self.t_emb_dim is not None:
            # Broadcast the per-sample embedding over the spatial dims
            out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
        out = self.resnet_conv_second[i](out)
        return out + self.residual_input_conv[i](x)

    def _self_attn_layer(self, i, x):
        # Self attention over the h*w spatial positions, added residually.
        batch_size, channels, h, w = x.shape
        tokens = x.reshape(batch_size, channels, h * w)
        tokens = self.attention_norms[i](tokens)
        tokens = tokens.transpose(1, 2)  # (B, h*w, C) for batch_first attention
        attended, _ = self.attentions[i](tokens, tokens, tokens)
        attended = attended.transpose(1, 2).reshape(batch_size, channels, h, w)
        return x + attended

    def _cross_attn_layer(self, i, x, context):
        # Cross attention: spatial positions query the projected context.
        assert context is not None, "context cannot be None if cross attention layers are used"
        batch_size, channels, h, w = x.shape
        tokens = x.reshape(batch_size, channels, h * w)
        tokens = self.cross_attention_norms[i](tokens)
        tokens = tokens.transpose(1, 2)
        assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
        context_proj = self.context_proj[i](context)
        attended, _ = self.cross_attentions[i](tokens, context_proj, context_proj)
        attended = attended.transpose(1, 2).reshape(batch_size, channels, h, w)
        return x + attended

    def forward(self, x, t_emb=None, context=None):
        r"""
        :param x: B x in_channels x H x W feature map
        :param t_emb: B x t_emb_dim time embedding; used only when the
            block was built with a ``t_emb_dim``
        :param context: Conditioning tensor whose first dimension is B and
            last dimension is ``context_dim``; required when ``cross_attn``
        :return: B x out_channels x H x W feature map
        """
        # First resnet block
        out = self._resnet_layer(0, x, t_emb)

        for i in range(self.num_layers):
            out = self._self_attn_layer(i, out)
            if self.cross_attn:
                out = self._cross_attn_layer(i, out, context)
            # Resnet i + 1 closes each attention stage
            out = self._resnet_layer(i + 1, out, t_emb)

        return out
|
|
|
|
class UpBlock(nn.Module):
    r"""
    Up conv block with optional attention.
    Sequence of following blocks
    1. Upsample (only if ``up_sample``)
    2. Concatenate Down block output (only if ``out_down`` is passed)
    3. Resnet block with time embedding
    4. Attention block (only if ``attn``)
    where steps 3 and 4 are repeated ``num_layers`` times.

    The first resnet layer is built for ``in_channels`` channels, so the
    tensor reaching it (after the optional upsample and concatenation)
    must have exactly ``in_channels`` channels.

    :param in_channels: Channels expected by the upsample conv and by the
        first resnet layer
    :param out_channels: Channels produced by every layer of the block
    :param t_emb_dim: Time embedding dimension, or None to skip the
        time embedding projection entirely
    :param up_sample: Whether to apply the transposed conv at the start
    :param num_heads: Number of heads of every attention layer
    :param num_layers: Number of resnet (+ attention) layers
    :param attn: Whether each layer gets a self attention
    :param norm_channels: Number of groups used by every GroupNorm
    """

    def __init__(self, in_channels, out_channels, t_emb_dim,
                 up_sample, num_heads, num_layers, attn, norm_channels):
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample
        self.t_emb_dim = t_emb_dim
        self.attn = attn

        def layer_in(i):
            # Only the first layer sees the block's input width.
            return in_channels if i == 0 else out_channels

        # NOTE: submodules are registered in a fixed order and under fixed
        # names so that state_dict keys and seeded initialisation stay stable.
        self.resnet_conv_first = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(norm_channels, layer_in(i)),
                    nn.SiLU(),
                    nn.Conv2d(layer_in(i), out_channels, kernel_size=3, stride=1,
                              padding=1),
                )
                for i in range(num_layers)
            ]
        )

        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(
                    nn.SiLU(),
                    nn.Linear(t_emb_dim, out_channels)
                )
                for _ in range(num_layers)
            ])

        self.resnet_conv_second = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(norm_channels, out_channels),
                    nn.SiLU(),
                    nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
                )
                for _ in range(num_layers)
            ]
        )
        if self.attn:
            self.attention_norms = nn.ModuleList(
                [
                    nn.GroupNorm(norm_channels, out_channels)
                    for _ in range(num_layers)
                ]
            )

            self.attentions = nn.ModuleList(
                [
                    nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                    for _ in range(num_layers)
                ]
            )

        # 1x1 conv so the residual matches out_channels
        self.residual_input_conv = nn.ModuleList(
            [
                nn.Conv2d(layer_in(i), out_channels, kernel_size=1)
                for i in range(num_layers)
            ]
        )
        # Kernel 4, stride 2, padding 1 transposed conv when upsampling
        self.up_sample_conv = nn.ConvTranspose2d(in_channels, in_channels,
                                                 4, 2, 1) \
            if self.up_sample else nn.Identity()

    def _resnet_layer(self, i, x, t_emb):
        # Resnet layer i: conv -> (+ time embedding) -> conv, plus projected skip.
        out = self.resnet_conv_first[i](x)
        if self.t_emb_dim is not None:
            # Broadcast the per-sample embedding over the spatial dims
            out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
        out = self.resnet_conv_second[i](out)
        return out + self.residual_input_conv[i](x)

    def _self_attn_layer(self, i, x):
        # Self attention over the h*w spatial positions, added residually.
        batch_size, channels, h, w = x.shape
        tokens = x.reshape(batch_size, channels, h * w)
        tokens = self.attention_norms[i](tokens)
        tokens = tokens.transpose(1, 2)  # (B, h*w, C) for batch_first attention
        attended, _ = self.attentions[i](tokens, tokens, tokens)
        attended = attended.transpose(1, 2).reshape(batch_size, channels, h, w)
        return x + attended

    def forward(self, x, out_down=None, t_emb=None):
        r"""
        :param x: B x C x H x W feature map
        :param out_down: Optional feature map concatenated to ``x`` along
            the channel dimension after the upsample step
        :param t_emb: B x t_emb_dim time embedding; used only when the
            block was built with a ``t_emb_dim``
        :return: Feature map with ``out_channels`` channels
        """
        # Upsample
        x = self.up_sample_conv(x)

        # Concat with Downblock output
        if out_down is not None:
            x = torch.cat([x, out_down], dim=1)

        out = x
        for i in range(self.num_layers):
            out = self._resnet_layer(i, out, t_emb)
            if self.attn:
                out = self._self_attn_layer(i, out)
        return out
|
|
|
|
class UpBlockUnet(nn.Module):
    r"""
    Up conv block with attention.
    Sequence of following blocks
    1. Upsample (only if ``up_sample``)
    2. Concatenate Down block output (only if ``out_down`` is passed)
    3. Resnet block with time embedding
    4. Attention block
    5. Cross attention block over ``context`` (only if ``cross_attn``)
    where steps 3 to 5 are repeated ``num_layers`` times.

    The upsample conv is built for ``in_channels // 2`` channels while the
    first resnet layer is built for ``in_channels`` channels, so the
    tensor after concatenation must have exactly ``in_channels`` channels.

    :param in_channels: Channels expected by the first resnet layer
    :param out_channels: Channels produced by every layer of the block
    :param t_emb_dim: Time embedding dimension, or None to skip the
        time embedding projection entirely
    :param up_sample: Whether to apply the transposed conv at the start
    :param num_heads: Number of heads of every attention layer
    :param num_layers: Number of resnet + attention layers
    :param norm_channels: Number of groups used by every GroupNorm
    :param cross_attn: Whether each layer gets a cross attention
    :param context_dim: Size of the last dimension of ``context``;
        must be given when ``cross_attn`` is set
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, up_sample,
                 num_heads, num_layers, norm_channels, cross_attn=False, context_dim=None):
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample
        self.t_emb_dim = t_emb_dim
        self.cross_attn = cross_attn
        self.context_dim = context_dim

        def layer_in(i):
            # Only the first layer sees the block's input width.
            return in_channels if i == 0 else out_channels

        # NOTE: submodules are registered in a fixed order and under fixed
        # names so that state_dict keys and seeded initialisation stay stable.
        self.resnet_conv_first = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(norm_channels, layer_in(i)),
                    nn.SiLU(),
                    nn.Conv2d(layer_in(i), out_channels, kernel_size=3, stride=1,
                              padding=1),
                )
                for i in range(num_layers)
            ]
        )

        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(
                    nn.SiLU(),
                    nn.Linear(t_emb_dim, out_channels)
                )
                for _ in range(num_layers)
            ])

        self.resnet_conv_second = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(norm_channels, out_channels),
                    nn.SiLU(),
                    nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
                )
                for _ in range(num_layers)
            ]
        )

        self.attention_norms = nn.ModuleList(
            [
                nn.GroupNorm(norm_channels, out_channels)
                for _ in range(num_layers)
            ]
        )

        self.attentions = nn.ModuleList(
            [
                nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                for _ in range(num_layers)
            ]
        )

        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels)
                 for _ in range(num_layers)]
            )
            self.cross_attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                 for _ in range(num_layers)]
            )
            # Projects the context to the attention width (out_channels)
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels)
                 for _ in range(num_layers)]
            )
        # 1x1 conv so the residual matches out_channels
        self.residual_input_conv = nn.ModuleList(
            [
                nn.Conv2d(layer_in(i), out_channels, kernel_size=1)
                for i in range(num_layers)
            ]
        )
        # Kernel 4, stride 2, padding 1 transposed conv when upsampling
        self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2,
                                                 4, 2, 1) \
            if self.up_sample else nn.Identity()

    def _resnet_layer(self, i, x, t_emb):
        # Resnet layer i: conv -> (+ time embedding) -> conv, plus projected skip.
        out = self.resnet_conv_first[i](x)
        if self.t_emb_dim is not None:
            # Broadcast the per-sample embedding over the spatial dims
            out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
        out = self.resnet_conv_second[i](out)
        return out + self.residual_input_conv[i](x)

    def _self_attn_layer(self, i, x):
        # Self attention over the h*w spatial positions, added residually.
        batch_size, channels, h, w = x.shape
        tokens = x.reshape(batch_size, channels, h * w)
        tokens = self.attention_norms[i](tokens)
        tokens = tokens.transpose(1, 2)  # (B, h*w, C) for batch_first attention
        attended, _ = self.attentions[i](tokens, tokens, tokens)
        attended = attended.transpose(1, 2).reshape(batch_size, channels, h, w)
        return x + attended

    def _cross_attn_layer(self, i, x, context):
        # Cross attention: spatial positions query the projected context.
        assert context is not None, "context cannot be None if cross attention layers are used"
        batch_size, channels, h, w = x.shape
        tokens = x.reshape(batch_size, channels, h * w)
        tokens = self.cross_attention_norms[i](tokens)
        tokens = tokens.transpose(1, 2)
        assert len(context.shape) == 3, \
            "Context shape does not match B,_,CONTEXT_DIM"
        assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim,\
            "Context shape does not match B,_,CONTEXT_DIM"
        context_proj = self.context_proj[i](context)
        attended, _ = self.cross_attentions[i](tokens, context_proj, context_proj)
        attended = attended.transpose(1, 2).reshape(batch_size, channels, h, w)
        return x + attended

    def forward(self, x, out_down=None, t_emb=None, context=None):
        r"""
        :param x: B x C x H x W feature map
        :param out_down: Optional feature map concatenated to ``x`` along
            the channel dimension after the upsample step
        :param t_emb: B x t_emb_dim time embedding; used only when the
            block was built with a ``t_emb_dim``
        :param context: B x seq_len x context_dim conditioning tensor;
            required when ``cross_attn``
        :return: Feature map with ``out_channels`` channels
        """
        x = self.up_sample_conv(x)
        if out_down is not None:
            x = torch.cat([x, out_down], dim=1)

        out = x
        for i in range(self.num_layers):
            out = self._resnet_layer(i, out, t_emb)
            out = self._self_attn_layer(i, out)
            if self.cross_attn:
                out = self._cross_attn_layer(i, out, context)

        return out
|
|
|
|