| import torch |
| import torch.nn as nn |
|
|
|
|
def get_time_embedding(time_steps, temb_dim):
    r"""
    Convert time steps tensor into an embedding using the
    sinusoidal time embedding formula.

    :param time_steps: 1D tensor of length batch size B
    :param temb_dim: Dimension D of the embedding (must be even)
    :return: BxD embedding representation of the B time steps
    """
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"

    # Per-dimension frequency divisor of the transformer sinusoidal
    # embedding: factor[i] = 10000^(i / (D/2)) for i in [0, D/2).
    factor = 10000 ** (
        torch.arange(
            start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device
        )
        / (temb_dim // 2)
    )

    # Broadcasting (B, 1) / (D/2,) -> (B, D/2); this replaces the
    # original explicit repeat(), which materialized a needless copy.
    t_emb = time_steps[:, None] / factor
    # Concatenate the sin half and the cos half -> (B, D).
    t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
    return t_emb
|
|
|
|
| class DownBlock(nn.Module): |
| r""" |
| DownBlock for Diffusion model: |
| a) Block Time embedding -> [Silu -> FC] |
| ↓ |
| 1) Resnet Block :- [Norm-> Silu -> Conv] x num_layers |
| 2) Self Attention :- [Norm -> SA] |
| 3) Cross Attention :- [Norm -> CA] |
| b) DownSample : DownSample the dimension |
| """ |
|
|
| def __init__( |
| self, |
| num_heads, |
| num_layers, |
| cross_attn, |
| input_dim, |
| output_dim, |
| t_emb_dim, |
| cond_dim, |
| norm_channels, |
| self_attn, |
| down_sample, |
| ) -> None: |
| super().__init__() |
| self.num_heads = num_heads |
| self.num_layers = num_layers |
| self.cross_attn = cross_attn |
| self.input_dim = input_dim |
| self.output_dim = output_dim |
| self.cond_dim = cond_dim |
| self.norm_channels = norm_channels |
| self.t_emb_dim = t_emb_dim |
| self.attn = self_attn |
| self.down_sample = down_sample |
|
|
| self.resnet_in = nn.ModuleList( |
| [ |
| nn.Conv2d( |
| self.input_dim if i == 0 else self.output_dim, |
| self.output_dim, |
| kernel_size=1, |
| ) |
| for i in range(self.num_layers) |
| ] |
| ) |
| self.resnet_one = nn.ModuleList( |
| [ |
| nn.Sequential( |
| nn.GroupNorm( |
| self.norm_channels, |
| self.input_dim if i == 0 else self.output_dim, |
| ), |
| nn.SiLU(), |
| nn.Conv2d( |
| self.input_dim if i == 0 else self.output_dim, |
| self.output_dim, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| ) |
| for i in range(self.num_layers) |
| ] |
| ) |
|
|
| if self.t_emb_dim is not None: |
| self.t_emb_layers = nn.ModuleList( |
| [ |
| nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, self.output_dim)) |
| for _ in range(self.num_layers) |
| ] |
| ) |
|
|
| self.resnet_two = nn.ModuleList( |
| [ |
| nn.Sequential( |
| nn.GroupNorm( |
| self.norm_channels, |
| self.output_dim, |
| ), |
| nn.SiLU(), |
| nn.Conv2d( |
| self.output_dim, |
| self.output_dim, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| ) |
| for _ in range(self.num_layers) |
| ] |
| ) |
|
|
| if self.attn: |
| self.attention_norms = nn.ModuleList( |
| [ |
| nn.GroupNorm(self.norm_channels, self.output_dim) |
| for _ in range(num_layers) |
| ] |
| ) |
| self.attentions = nn.ModuleList( |
| [ |
| nn.MultiheadAttention( |
| self.output_dim, self.num_heads, batch_first=True |
| ) |
| for _ in range(self.num_layers) |
| ] |
| ) |
|
|
| if self.cross_attn: |
| self.cross_attn_norms = nn.ModuleList( |
| [ |
| nn.GroupNorm(self.norm_channels, self.output_dim) |
| for _ in range(self.num_layers) |
| ] |
| ) |
| self.cross_attentions = nn.ModuleList( |
| [ |
| nn.MultiheadAttention( |
| self.output_dim, self.num_heads, batch_first=True |
| ) |
| for _ in range(self.num_layers) |
| ] |
| ) |
|
|
| self.context_proj = nn.ModuleList( |
| [ |
| nn.Linear(self.cond_dim, self.output_dim) |
| for _ in range(self.num_layers) |
| ] |
| ) |
|
|
| self.down_sample_conv = ( |
| nn.Conv2d(self.output_dim, self.output_dim, 4, 2, 1) |
| if self.down_sample |
| else nn.Identity() |
| ) |
|
|
| def forward(self, x, t_emb=None, context=None): |
| out = x |
| for i in range(self.num_layers): |
| |
| resnet_input = out |
| out = self.resnet_one[i](out) |
| if t_emb is not None: |
| out = out + self.t_emb_layers[i](t_emb)[:, :, None, None] |
| out = self.resnet_two[i](out) |
| out = out + self.resnet_in[i](resnet_input) |
|
|
| if self.attn: |
| |
| batch_size, channels, h, w = out.shape |
| in_attn = out.reshape(batch_size, channels, h * w) |
| in_attn = self.attention_norms[i](in_attn) |
| in_attn = in_attn.transpose(1, 2) |
|
|