import logging

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


def get_time_embedding(time_steps, temb_dim):
    r"""
    Convert time steps tensor into an embedding using the sinusoidal time
    embedding formula.

    :param time_steps: 1D tensor of length batch size
    :param temb_dim: Dimension of the embedding (must be even)
    :return: BxD embedding representation of B time steps
    """
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"

    # factor = 10000^(2i/d_model) for i in [0, temb_dim / 2)
    factor = 10000 ** (
        torch.arange(
            start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device
        )
        / (temb_dim // 2)
    )

    # pos / factor: timesteps (B,) -> (B, 1) -> (B, temb_dim // 2)
    t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor
    # Concatenate sin/cos halves -> (B, temb_dim)
    t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
    return t_emb


class DownBlock(nn.Module):
    r"""
    DownBlock for Diffusion model:
        a) Block
            Time embedding -> [SiLU -> FC]
                                    ↓
            1) Resnet Block    :- [Norm -> SiLU -> Conv] x num_layers
            2) Self Attention  :- [Norm -> SA]
            3) Cross Attention :- [Norm -> CA]
        b) DownSample : halve the spatial dimension (when down_sample=True)
    """

    def __init__(
        self,
        num_heads,
        num_layers,
        cross_attn,
        input_dim,
        output_dim,
        t_emb_dim,
        cond_dim,
        norm_channels,
        self_attn,
        down_sample,
    ) -> None:
        """
        :param num_heads: attention heads for self/cross attention
        :param num_layers: number of resnet (+attention) layers
        :param cross_attn: whether to build cross-attention layers
        :param input_dim: channels of the incoming feature map
        :param output_dim: channels produced by every layer
        :param t_emb_dim: time embedding dim (None disables time conditioning)
        :param cond_dim: last dim of the conditioning context (cross attn only)
        :param norm_channels: number of GroupNorm groups
        :param self_attn: whether to build self-attention layers
        :param down_sample: whether to downsample spatially at the end
        """
        super().__init__()
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.cross_attn = cross_attn
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.cond_dim = cond_dim
        self.norm_channels = norm_channels
        self.t_emb_dim = t_emb_dim
        self.attn = self_attn
        self.down_sample = down_sample

        # 1x1 convs matching channels for the residual (skip) connection of
        # each resnet block.
        self.resnet_in = nn.ModuleList(
            [
                nn.Conv2d(
                    self.input_dim if i == 0 else self.output_dim,
                    self.output_dim,
                    kernel_size=1,
                )
                for i in range(self.num_layers)
            ]
        )
        # First half of each resnet block: Norm -> SiLU -> 3x3 Conv.
        self.resnet_one = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(
                        self.norm_channels,
                        self.input_dim if i == 0 else self.output_dim,
                    ),
                    nn.SiLU(),
                    nn.Conv2d(
                        self.input_dim if i == 0 else self.output_dim,
                        self.output_dim,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                )
                for i in range(self.num_layers)
            ]
        )
        if self.t_emb_dim is not None:
            # Projects the time embedding to per-channel biases.
            self.t_emb_layers = nn.ModuleList(
                [
                    nn.Sequential(
                        nn.SiLU(), nn.Linear(self.t_emb_dim, self.output_dim)
                    )
                    for _ in range(self.num_layers)
                ]
            )
        # Second half of each resnet block: Norm -> SiLU -> 3x3 Conv.
        self.resnet_two = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(self.norm_channels, self.output_dim),
                    nn.SiLU(),
                    nn.Conv2d(
                        self.output_dim,
                        self.output_dim,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                )
                for _ in range(self.num_layers)
            ]
        )
        if self.attn:
            self.attention_norms = nn.ModuleList(
                [
                    nn.GroupNorm(self.norm_channels, self.output_dim)
                    for _ in range(self.num_layers)
                ]
            )
            self.attentions = nn.ModuleList(
                [
                    nn.MultiheadAttention(
                        self.output_dim, self.num_heads, batch_first=True
                    )
                    for _ in range(self.num_layers)
                ]
            )
        if self.cross_attn:
            self.cross_attn_norms = nn.ModuleList(
                [
                    nn.GroupNorm(self.norm_channels, self.output_dim)
                    for _ in range(self.num_layers)
                ]
            )
            self.cross_attentions = nn.ModuleList(
                [
                    nn.MultiheadAttention(
                        self.output_dim, self.num_heads, batch_first=True
                    )
                    for _ in range(self.num_layers)
                ]
            )
            # Projects the conditioning context to the channel dim so it can
            # act as key/value for cross attention.
            self.context_proj = nn.ModuleList(
                [
                    nn.Linear(self.cond_dim, self.output_dim)
                    for _ in range(self.num_layers)
                ]
            )
        # Stride-2 conv halves H and W; Identity when no downsampling.
        self.down_sample_conv = (
            nn.Conv2d(self.output_dim, self.output_dim, 4, 2, 1)
            if self.down_sample
            else nn.Identity()
        )

    def forward(self, x, t_emb=None, context=None):
        """
        :param x: (B, input_dim, H, W) feature map
        :param t_emb: (B, t_emb_dim) time embedding (required if t_emb_dim set)
        :param context: (B, seq, cond_dim) conditioning (required if cross_attn)
        :return: (B, output_dim, H', W') with H'/W' halved when down_sample
        """
        out = x
        for i in range(self.num_layers):
            # Resnet block: conv path plus 1x1-projected skip connection.
            logger.debug(
                "Input to Resnet Block in Down Block Layer %d : %s", i, out.shape
            )
            resnet_input = out
            out = self.resnet_one[i](out)
            logger.debug(
                "Output of Resnet Sub Block 1 of Down Block Layer %d : %s",
                i,
                out.shape,
            )
            if self.t_emb_dim is not None:
                logger.debug(
                    "Adding t_emb of shape %s to output of shape: %s of Down Block Layer %d",
                    self.t_emb_dim,
                    out.shape,
                    i,
                )
                # Broadcast the (B, C) projection over the spatial dims.
                out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
            out = self.resnet_two[i](out)
            logger.debug(
                "Output of Resnet Sub Block 2 of Down Block Layer: %d with output_shape:%s",
                i,
                out.shape,
            )
            out = out + self.resnet_in[i](resnet_input)
            logger.debug(
                "Residual connection of the input to out : %s in Down Block Layer %d",
                out.shape,
                i,
            )

            if self.attn:
                # Self attention over flattened spatial positions.
                logger.debug("Going into the attention Block in Down Block Layer %d", i)
                batch_size, channels, h, w = out.shape
                in_attn = out.reshape(batch_size, channels, h * w)
                in_attn = self.attention_norms[i](in_attn)
                in_attn = in_attn.transpose(1, 2)
                out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
                out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
                out = out + out_attn
                logger.debug(
                    "Out of the Self Attention Block with out : %s in Down Block Layer %d",
                    out.shape,
                    i,
                )

            if self.cross_attn:
                assert context is not None, (
                    "context cannot be None if cross attention layers are used"
                )
                logger.debug(
                    "Going into the Cross Attention Block in Down Block Layer %d", i
                )
                batch_size, channels, h, w = out.shape
                in_attn = out.reshape(batch_size, channels, h * w)
                in_attn = self.cross_attn_norms[i](in_attn)
                in_attn = in_attn.transpose(1, 2)
                # Fix: attribute is cond_dim (self.context_dim was never set).
                assert (
                    context.shape[0] == x.shape[0]
                    and context.shape[-1] == self.cond_dim
                ), "Context shape does not match B,_,COND_DIM"
                logger.debug(
                    "Calculating context projection for Cross Attn in Down Block Layer : %d",
                    i,
                )
                context_proj = self.context_proj[i](context)
                # Queries come from the image, keys/values from the context.
                out_attn, _ = self.cross_attentions[i](
                    in_attn, context_proj, context_proj
                )
                out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
                out = out + out_attn
                logger.debug(
                    "Out of the Cross Attention Block with out : %s in Down Block Layer %d",
                    out.shape,
                    i,
                )

        # DownSample to a x2 smaller spatial dimension (Identity otherwise).
        out = self.down_sample_conv(out)
        logger.debug("Down Sampling out to : %s in Down Block", out.shape)
        return out


class MidBlock(nn.Module):
    r"""
    MidBlock for Diffusion model:
        Time embedding -> [SiLU -> FC]
                                ↓
        1) Resnet Block    :- [Norm -> SiLU -> Conv]
        then, per layer:
        2) Self Attention  :- [Norm -> SA]
        3) Cross Attention :- [Norm -> CA]
        Time embedding -> [SiLU -> FC]
                                ↓
        4) Resnet Block    :- [Norm -> SiLU -> Conv]

    Note: num_layers + 1 resnet blocks are built (one leading block plus one
    per attention layer). The down_sample argument is accepted for interface
    parity with DownBlock but is unused here.
    """

    def __init__(
        self,
        num_heads,
        num_layers,
        cross_attn,
        input_dim,
        output_dim,
        t_emb_dim,
        cond_dim,
        norm_channels,
        self_attn,
        down_sample,
    ) -> None:
        """
        :param num_heads: attention heads for self/cross attention
        :param num_layers: number of attention layers (resnets = num_layers + 1)
        :param cross_attn: whether to build cross-attention layers
        :param input_dim: channels of the incoming feature map
        :param output_dim: channels produced by every layer
        :param t_emb_dim: time embedding dim (None disables time conditioning)
        :param cond_dim: last dim of the conditioning context (cross attn only)
        :param norm_channels: number of GroupNorm groups
        :param self_attn: whether to build self-attention layers
        :param down_sample: unused; kept for interface parity with DownBlock
        """
        super().__init__()
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.cross_attn = cross_attn
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.cond_dim = cond_dim
        self.norm_channels = norm_channels
        self.t_emb_dim = t_emb_dim
        self.attn = self_attn
        self.down_sample = down_sample

        # First half of each resnet block; num_layers + 1 of them.
        self.resnet_one = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(
                        self.norm_channels,
                        self.input_dim if i == 0 else self.output_dim,
                    ),
                    nn.SiLU(),
                    nn.Conv2d(
                        self.input_dim if i == 0 else self.output_dim,
                        self.output_dim,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                )
                for i in range(self.num_layers + 1)
            ]
        )
        if self.t_emb_dim is not None:
            # Projects the time embedding to per-channel biases.
            self.t_emb_layers = nn.ModuleList(
                [
                    nn.Sequential(
                        nn.SiLU(), nn.Linear(self.t_emb_dim, self.output_dim)
                    )
                    for _ in range(self.num_layers + 1)
                ]
            )
        # Second half of each resnet block.
        self.resnet_two = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(self.norm_channels, self.output_dim),
                    nn.SiLU(),
                    nn.Conv2d(
                        self.output_dim,
                        self.output_dim,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                )
                for _ in range(self.num_layers + 1)
            ]
        )
        if self.attn:
            self.attention_norms = nn.ModuleList(
                [
                    nn.GroupNorm(self.norm_channels, self.output_dim)
                    for _ in range(self.num_layers)
                ]
            )
            self.attentions = nn.ModuleList(
                [
                    nn.MultiheadAttention(
                        self.output_dim, self.num_heads, batch_first=True
                    )
                    for _ in range(self.num_layers)
                ]
            )
        if self.cross_attn:
            self.cross_attn_norms = nn.ModuleList(
                [
                    nn.GroupNorm(self.norm_channels, self.output_dim)
                    for _ in range(self.num_layers)
                ]
            )
            self.cross_attentions = nn.ModuleList(
                [
                    nn.MultiheadAttention(
                        self.output_dim, self.num_heads, batch_first=True
                    )
                    for _ in range(self.num_layers)
                ]
            )
            # Projects the conditioning context to the channel dim.
            self.context_proj = nn.ModuleList(
                [
                    nn.Linear(self.cond_dim, self.output_dim)
                    for _ in range(self.num_layers)
                ]
            )
        # 1x1 convs for the residual connection of each resnet block.
        self.resnet_in = nn.ModuleList(
            [
                nn.Conv2d(
                    self.input_dim if i == 0 else self.output_dim,
                    self.output_dim,
                    kernel_size=1,
                )
                for i in range(self.num_layers + 1)
            ]
        )

    def forward(self, x, t_emb=None, context=None):
        """
        :param x: (B, input_dim, H, W) feature map
        :param t_emb: (B, t_emb_dim) time embedding (required if t_emb_dim set)
        :param context: (B, seq, cond_dim) conditioning (required if cross_attn)
        :return: (B, output_dim, H, W)
        """
        out = x

        # Leading resnet block (index 0).
        logger.debug("Input to First Resnet Block in Mid Block")
        resnet_input = out
        out = self.resnet_one[0](out)
        logger.debug("Output of Resnet Sub Block 1 of Mid Block Layer: %s", out.shape)
        if self.t_emb_dim is not None:
            out = out + self.t_emb_layers[0](t_emb)[:, :, None, None]
            logger.debug(
                "Adding t_emb of shape %s to output of shape: %s",
                self.t_emb_dim,
                out.shape,
            )
        out = self.resnet_two[0](out)
        logger.debug("Output of Resnet Sub Block 2 with output_shape:%s", out.shape)
        out = out + self.resnet_in[0](resnet_input)
        logger.debug(
            "Residual connection of the input to out : %s in Mid Block", out.shape
        )

        for i in range(self.num_layers):
            if self.attn:
                # Fix: guard on self.attn — the attention modules are only
                # constructed when self_attn is True.
                logger.debug("Going into the attention Block in Mid Block Layer %d", i)
                batch_size, channels, h, w = out.shape
                in_attn = out.reshape(batch_size, channels, h * w)
                in_attn = self.attention_norms[i](in_attn)
                in_attn = in_attn.transpose(1, 2)
                out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
                out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
                out = out + out_attn
                logger.debug(
                    "Out of the Self Attention Block with out : %s in Mid Block Layer %d",
                    out.shape,
                    i,
                )

            if self.cross_attn:
                assert context is not None, (
                    "context cannot be None if cross attention layers are used"
                )
                logger.debug(
                    "Going into the Cross Attention Block in Mid Block Layer %d", i
                )
                batch_size, channels, h, w = out.shape
                in_attn = out.reshape(batch_size, channels, h * w)
                in_attn = self.cross_attn_norms[i](in_attn)
                in_attn = in_attn.transpose(1, 2)
                # Fix: attribute is cond_dim (self.context_dim was never set).
                assert (
                    context.shape[0] == x.shape[0]
                    and context.shape[-1] == self.cond_dim
                ), "Context shape does not match B,_,COND_DIM"
                logger.debug(
                    "Calculating context projection for Cross Attn in Mid Block Layer : %d",
                    i,
                )
                context_proj = self.context_proj[i](context)
                out_attn, _ = self.cross_attentions[i](
                    in_attn, context_proj, context_proj
                )
                out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
                out = out + out_attn
                logger.debug(
                    "Out of the Cross Attention Block with out : %s in Mid Block Layer %d",
                    out.shape,
                    i,
                )

            # Resnet block i + 1 after each attention stage.
            # Fix: index with i + 1 — the module lists hold num_layers + 1
            # entries; the original reused index 0, which both crashes when
            # input_dim != output_dim and leaves layers 1..num_layers unused.
            logger.debug(
                "Last Resnet Block input : %s of Mid Block Layer %d", out.shape, i
            )
            resnet_input = out
            out = self.resnet_one[i + 1](out)
            logger.debug(
                "Output of Resnet Sub Block 1 of Mid Block Layer %d of shape : %s",
                i,
                out.shape,
            )
            if self.t_emb_dim is not None:
                out = out + self.t_emb_layers[i + 1](t_emb)[:, :, None, None]
                logger.debug(
                    "Adding t_emb of shape %s to output of shape: %s of Mid Block Layer %d",
                    self.t_emb_dim,
                    out.shape,
                    i,
                )
            out = self.resnet_two[i + 1](out)
            logger.debug(
                "Output of Resnet Sub Block 2 with output_shape:%s of Mid Block Layer %d",
                out.shape,
                i,
            )
            out = out + self.resnet_in[i + 1](resnet_input)
            logger.debug(
                "Residual connection of the input to out : %s in Mid Block Layer %d",
                out.shape,
                i,
            )

        return out


class UpBlockUnet(nn.Module):
    r"""
    Up conv block with attention. Sequence of following blocks
        1. Upsample
        2. Concatenate Down block output
        3. Resnet block with time embedding
        4. Attention Block (self, plus optional cross attention)
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        t_emb_dim,
        up_sample,
        num_heads,
        num_layers,
        norm_channels,
        cross_attn=False,
        context_dim=None,
    ):
        """
        :param in_channels: channels AFTER concatenation with the down output
            (the upsample conv itself operates on in_channels // 2)
        :param out_channels: channels produced by every layer
        :param t_emb_dim: time embedding dim (None disables time conditioning)
        :param up_sample: whether to upsample spatially at the start
        :param num_heads: attention heads
        :param num_layers: number of resnet + attention layers
        :param norm_channels: number of GroupNorm groups
        :param cross_attn: whether to build cross-attention layers
        :param context_dim: last dim of the conditioning context (cross attn)
        """
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample
        self.t_emb_dim = t_emb_dim
        self.cross_attn = cross_attn
        self.context_dim = context_dim

        # First half of each resnet block: Norm -> SiLU -> 3x3 Conv.
        self.resnet_conv_first = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
                    nn.SiLU(),
                    nn.Conv2d(
                        in_channels if i == 0 else out_channels,
                        out_channels,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                )
                for i in range(num_layers)
            ]
        )
        if self.t_emb_dim is not None:
            # Projects the time embedding to per-channel biases.
            self.t_emb_layers = nn.ModuleList(
                [
                    nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                    for _ in range(num_layers)
                ]
            )
        # Second half of each resnet block.
        self.resnet_conv_second = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(norm_channels, out_channels),
                    nn.SiLU(),
                    nn.Conv2d(
                        out_channels, out_channels, kernel_size=3, stride=1, padding=1
                    ),
                )
                for _ in range(num_layers)
            ]
        )
        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]
        )
        self.attentions = nn.ModuleList(
            [
                nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                for _ in range(num_layers)
            ]
        )
        if self.cross_attn:
            assert context_dim is not None, (
                "Context Dimension must be passed for cross attention"
            )
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]
            )
            self.cross_attentions = nn.ModuleList(
                [
                    nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                    for _ in range(num_layers)
                ]
            )
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels) for _ in range(num_layers)]
            )
        # 1x1 convs for the residual connection of each resnet block.
        self.residual_input_conv = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels if i == 0 else out_channels, out_channels, kernel_size=1
                )
                for i in range(num_layers)
            ]
        )
        # Transposed conv doubles H and W; input is the pre-concat feature
        # map, which carries in_channels // 2 channels.
        self.up_sample_conv = (
            nn.ConvTranspose2d(in_channels // 2, in_channels // 2, 4, 2, 1)
            if self.up_sample
            else nn.Identity()
        )

    def forward(self, x, out_down=None, t_emb=None, context=None):
        """
        :param x: (B, in_channels // 2, H, W) feature map from the layer below
        :param out_down: matching down-block output to concatenate (optional)
        :param t_emb: (B, t_emb_dim) time embedding (required if t_emb_dim set)
        :param context: (B, seq, context_dim) conditioning (if cross_attn)
        :return: (B, out_channels, H', W') with H'/W' doubled when up_sample
        """
        x = self.up_sample_conv(x)
        if out_down is not None:
            # Skip connection from the encoder: concat along channels.
            x = torch.cat([x, out_down], dim=1)

        out = x
        for i in range(self.num_layers):
            # Resnet
            resnet_input = out
            out = self.resnet_conv_first[i](out)
            if self.t_emb_dim is not None:
                out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[i](out)
            out = out + self.residual_input_conv[i](resnet_input)

            # Self Attention over flattened spatial positions
            batch_size, channels, h, w = out.shape
            in_attn = out.reshape(batch_size, channels, h * w)
            in_attn = self.attention_norms[i](in_attn)
            in_attn = in_attn.transpose(1, 2)
            out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
            out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
            out = out + out_attn

            # Cross Attention
            if self.cross_attn:
                assert context is not None, (
                    "context cannot be None if cross attention layers are used"
                )
                batch_size, channels, h, w = out.shape
                in_attn = out.reshape(batch_size, channels, h * w)
                in_attn = self.cross_attention_norms[i](in_attn)
                in_attn = in_attn.transpose(1, 2)
                assert len(context.shape) == 3, (
                    "Context shape does not match B,_,CONTEXT_DIM"
                )
                assert (
                    context.shape[0] == x.shape[0]
                    and context.shape[-1] == self.context_dim
                ), "Context shape does not match B,_,CONTEXT_DIM"
                context_proj = self.context_proj[i](context)
                out_attn, _ = self.cross_attentions[i](
                    in_attn, context_proj, context_proj
                )
                out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
                out = out + out_attn
        return out