import logging

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


def get_time_embedding(time_steps, temb_dim):
    r"""
    Convert time steps tensor into an embedding using the sinusoidal time embedding formula
    :param time_steps: 1D tensor of length batch size
    :param temb_dim: Dimension of the embedding
    :return: BxD embedding representation of B time steps
    """
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"

    # factor = 10000^(2i/d_model)
    factor = 10000 ** (
        torch.arange(
            start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device
        )
        / (temb_dim // 2)
    )

    # pos / factor
    # timesteps B -> B, 1 -> B, temb_dim
    t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor
    t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
    return t_emb


class DownBlock(nn.Module):
    r"""
    DownBlock for Diffusion model:
    a) Block
        Time embedding -> [Silu -> FC]
                              ↓
        1) Resnet Block :- [Norm -> Silu -> Conv] x num_layers
        2) Self Attention :- [Norm -> SA]
    b) DownSample : downsample the spatial dimension
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        t_emb_dim,
        down_sample=True,
        num_heads=4,
        num_layers=1,
    ) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.down_sample = down_sample
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.t_emb_dim = t_emb_dim
        self.resnet_one = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(8, self.input_dim if i == 0 else self.output_dim),
                    nn.SiLU(),
                    nn.Conv2d(
                        self.input_dim if i == 0 else self.output_dim,
                        self.output_dim,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                )
                for i in range(self.num_layers)
            ]
        )
        self.t_emb_layers = nn.ModuleList(
            [
                nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, self.output_dim))
                for _ in range(self.num_layers)
            ]
        )
        self.resnet_two = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(8, self.output_dim),
                    nn.SiLU(),
                    nn.Conv2d(
                        self.output_dim,
                        self.output_dim,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                )
                for _ in range(self.num_layers)
            ]
        )
        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(8, self.output_dim) for _ in range(self.num_layers)]
        )
        # One attention layer per resnet layer, matching attention_norms
        self.attentions = nn.ModuleList(
            [
                nn.MultiheadAttention(self.output_dim, self.num_heads, batch_first=True)
                for _ in range(self.num_layers)
            ]
        )
        self.resnet_in = nn.ModuleList(
            [
                nn.Conv2d(
                    self.input_dim if i == 0 else self.output_dim,
                    self.output_dim,
                    kernel_size=1,
                )
                for i in range(self.num_layers)
            ]
        )
        self.down_sample_conv = (
            nn.Conv2d(self.output_dim, self.output_dim, 4, 2, 1)
            if self.down_sample
            else nn.Identity()
        )

    def forward(self, x, t_emb):
        out = x
        logger.debug(f"Input of shape: {out.shape} to Down Block")
        for i in range(self.num_layers):
            # Resnet Block with time embedding
            resnet_input = out
            logger.debug(f"Input to Resnet Block : {resnet_input.shape}")
            out = self.resnet_one[i](out)
            out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
            logger.debug(
                f"Concatenated time embeddings to Resnet Sub Block 1: {out.shape} of Down Block Layer {i}"
            )
            out = self.resnet_two[i](out)
            out = out + self.resnet_in[i](resnet_input)
            logger.debug(
                f"Adding Residual connection : {out.shape} to Down Block Layer {i}"
            )

            # Attention Block
            batch_size, channels, h, w = out.shape
            in_attn = out.reshape(batch_size, channels, h * w)
            in_attn = self.attention_norms[i](in_attn)
            logger.debug(f"Attention Norm: {in_attn.shape} in Down Block Layer : {i}")
            in_attn = in_attn.transpose(1, 2)
            logger.debug(
                f"Passing Norm : {in_attn.shape} to Attention Layer in Down Block Layer : {i}"
            )
            out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
            out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
            out = out + out_attn
            logger.debug(
                f"Added Attention score to output: {out.shape} in Down Block Layer {i}"
            )

        out = self.down_sample_conv(out)
        logger.debug(f"Down sampled to : {out.shape}")
        return out


class MidBlock(nn.Module):
    r"""
    MidBlock for Diffusion model:
        Time embedding -> [Silu -> FC]
                              ↓
        1) Resnet Block :- [Norm -> Silu -> Conv] x num_layers
        2) Self Attention :- [Norm -> SA]
        Time embedding -> [Silu -> FC]
                              ↓
        3) Resnet Block :- [Norm -> Silu -> Conv] x num_layers
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        t_emb_dim,
        num_heads=4,
        num_layers=1,
    ) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.t_emb_dim = t_emb_dim
        self.resnet_one = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(8, self.input_dim if i == 0 else self.output_dim),
                    nn.SiLU(),
                    nn.Conv2d(
                        self.input_dim if i == 0 else self.output_dim,
                        self.output_dim,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                )
                for i in range(self.num_layers + 1)
            ]
        )
        self.t_emb_layers = nn.ModuleList(
            [
                nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, self.output_dim))
                for _ in range(self.num_layers + 1)
            ]
        )
        self.resnet_two = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(8, self.output_dim),
                    nn.SiLU(),
                    nn.Conv2d(
                        self.output_dim,
                        self.output_dim,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                )
                for _ in range(self.num_layers + 1)
            ]
        )
        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(8, self.output_dim) for _ in range(self.num_layers)]
        )
        self.attentions = nn.ModuleList(
            [
                nn.MultiheadAttention(self.output_dim, self.num_heads, batch_first=True)
                for _ in range(self.num_layers)
            ]
        )
        self.resnet_in = nn.ModuleList(
            [
                nn.Conv2d(
                    self.input_dim if i == 0 else self.output_dim,
                    self.output_dim,
                    kernel_size=1,
                )
                for i in range(self.num_layers + 1)
            ]
        )

    def forward(self, x, t_emb):
        out = x
        logger.debug(f"Input of shape: {out.shape} to Mid Block")

        # First Resnet Block
        resnet_input = out
        logger.debug(
            f"Input to Resnet Block : {resnet_input.shape} in Mid Block Layer 0"
        )
        out = self.resnet_one[0](out)
        out = out + self.t_emb_layers[0](t_emb)[:, :, None, None]
        logger.debug(
            f"Concatenated time embeddings to Resnet Sub Block 1: {out.shape} of Mid Block Layer 0"
        )
        out = self.resnet_two[0](out)
        out = out + self.resnet_in[0](resnet_input)
        logger.debug(f"Adding Residual connection : {out.shape} to Mid Block Layer 0")

        for i in range(self.num_layers):
            # Attention Block
            batch_size, channels, h, w = out.shape
            in_attn = out.reshape(batch_size, channels, h * w)
            in_attn = self.attention_norms[i](in_attn)
            logger.debug(f"Attention Norm: {in_attn.shape} in Mid Block Layer : {i}")
            in_attn = in_attn.transpose(1, 2)
            out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
            out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
            out = out + out_attn
            logger.debug(
                f"Added Attention score to output: {out.shape} in Mid Block Layer {i}"
            )

            # Resnet Block
            resnet_input = out
            logger.debug(
                f"Input to Resnet Block : {resnet_input.shape} in Mid Block Layer {i}"
            )
            out = self.resnet_one[i + 1](out)
            out = out + self.t_emb_layers[i + 1](t_emb)[:, :, None, None]
            logger.debug(
                f"Concatenated time embeddings to Resnet Sub Block 1: {out.shape} of Mid Block Layer {i}"
            )
            out = self.resnet_two[i + 1](out)
            out = out + self.resnet_in[i + 1](resnet_input)
            logger.debug(
                f"Adding Residual connection : {out.shape} to Mid Block Layer {i}"
            )

        return out


class UpBlock(nn.Module):
    r"""
    UpBlock for Diffusion model:
    1. Upsample
    2. Concatenate Down block output
    3. Resnet block with time embedding
    4. Attention Block
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        t_emb_dim,
        up_sample=True,
        num_heads=4,
        num_layers=1,
    ) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.up_sample = up_sample
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.t_emb_dim = t_emb_dim
        self.resnet_one = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(8, self.input_dim if i == 0 else self.output_dim),
                    nn.SiLU(),
                    nn.Conv2d(
                        self.input_dim if i == 0 else self.output_dim,
                        self.output_dim,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                )
                for i in range(self.num_layers)
            ]
        )
        self.t_emb_layers = nn.ModuleList(
            [
                nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, self.output_dim))
                for _ in range(self.num_layers)
            ]
        )
        self.resnet_two = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(8, self.output_dim),
                    nn.SiLU(),
                    nn.Conv2d(
                        self.output_dim,
                        self.output_dim,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                )
                for _ in range(self.num_layers)
            ]
        )
        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(8, self.output_dim) for _ in range(self.num_layers)]
        )
        self.attentions = nn.ModuleList(
            [
                nn.MultiheadAttention(self.output_dim, self.num_heads, batch_first=True)
                for _ in range(self.num_layers)
            ]
        )
        self.resnet_in = nn.ModuleList(
            [
                nn.Conv2d(
                    self.input_dim if i == 0 else self.output_dim,
                    self.output_dim,
                    kernel_size=1,
                )
                for i in range(self.num_layers)
            ]
        )
        # Upsampling keeps the channel count (input_dim // 2) so that concatenating
        # the skip connection restores input_dim channels for the first resnet block.
        self.up_sample_conv = (
            nn.ConvTranspose2d(self.input_dim // 2, self.input_dim // 2, 4, 2, 1)
            if self.up_sample
            else nn.Identity()
        )

    def forward(self, x, out_down, t_emb):
        logger.debug(f"Input of shape: {x.shape} to Up Block")
        out = x
        out = self.up_sample_conv(out)
        logger.debug(f"Up sampled to : {out.shape}")

        # Concatenate Down Block output
        out = torch.cat([out, out_down], dim=1)
        logger.debug(f"Concatenated Down Block output: {out.shape}")

        for i in range(self.num_layers):
            # Resnet Block with time embedding
            resnet_input = out
            logger.debug(
                f"Input to Resnet Block : {resnet_input.shape} in Up Block Layer {i}"
            )
            out = self.resnet_one[i](out)
            out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
            logger.debug(
                f"Concatenated time embeddings to Resnet Sub Block 1: {out.shape} of Up Block Layer {i}"
            )
            out = self.resnet_two[i](out)
            out = out + self.resnet_in[i](resnet_input)
            logger.debug(
                f"Adding Residual connection : {out.shape} to Up Block Layer {i}"
            )

            # Attention Block
            batch_size, channels, h, w = out.shape
            in_attn = out.reshape(batch_size, channels, h * w)
            in_attn = self.attention_norms[i](in_attn)
            logger.debug(f"Attention Norm: {in_attn.shape} in Up Block Layer : {i}")
            in_attn = in_attn.transpose(1, 2)
            logger.debug(
                f"Passing Norm : {in_attn.shape} to Attention Layer in Up Block Layer : {i}"
            )
            out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
            out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
            out = out + out_attn
            logger.debug(
                f"Added Attention score to output: {out.shape} in Up Block Layer {i}"
            )

        return out


class UNet(nn.Module):
    r"""
    UNet backbone consisting of Down Blocks, Mid Blocks, and Up Blocks
    """

    def __init__(self, model_config, use_up=True):
        super().__init__()
        im_channels = model_config["im_channels"]
        self.down_channels = model_config["down_channels"]
        self.mid_channels = model_config["mid_channels"]
        self.t_emb_dim = model_config["t_emb_dim"]
        self.down_sample = model_config["down_sample"]
        self.num_down_layers = model_config["num_down_layers"]
        self.num_mid_layers = model_config["num_mid_layers"]
        self.num_up_layers = model_config["num_up_layers"]

        assert self.mid_channels[0] == self.down_channels[-1]
        assert self.mid_channels[-1] == self.down_channels[-2]
        assert len(self.down_sample) == len(self.down_channels) - 1

        self.t_proj = nn.Sequential(
            nn.Linear(self.t_emb_dim, self.t_emb_dim),
            nn.SiLU(),
            nn.Linear(self.t_emb_dim, self.t_emb_dim),
        )
        self.up_sample = list(reversed(self.down_sample))
        self.conv_in = nn.Conv2d(
            im_channels, self.down_channels[0], kernel_size=3, padding=1
        )

        self.downs = nn.ModuleList([])
        for i in range(len(self.down_channels) - 1):
            self.downs.append(
                DownBlock(
                    self.down_channels[i],
                    self.down_channels[i + 1],
                    self.t_emb_dim,
                    down_sample=self.down_sample[i],
                    num_layers=self.num_down_layers,
                )
            )

        self.mids = nn.ModuleList([])
        for i in range(len(self.mid_channels) - 1):
            self.mids.append(
                MidBlock(
                    self.mid_channels[i],
                    self.mid_channels[i + 1],
                    self.t_emb_dim,
                    num_layers=self.num_mid_layers,
                )
            )

        if use_up:
            self.ups = nn.ModuleList([])
            for i in reversed(range(len(self.down_channels) - 1)):
                self.ups.append(
                    UpBlock(
                        self.down_channels[i] * 2,
                        self.down_channels[i - 1] if i != 0 else 16,
                        self.t_emb_dim,
                        up_sample=self.down_sample[i],
                        num_layers=self.num_up_layers,
                    )
                )

        self.norm_out = nn.GroupNorm(8, 16)
        self.conv_out = nn.Conv2d(16, im_channels, kernel_size=3, padding=1)

    def forward(self, x, t):
        t_emb = get_time_embedding(torch.as_tensor(t).long(), self.t_emb_dim)
        t_emb = self.t_proj(t_emb)
        logger.debug(f"Time embedding shape: {t_emb.shape} to UNet")

        out = self.conv_in(x)
        logger.debug(f"Output of conv_in : {out.shape} in UNet")

        down_outs = []
        for idx, down in enumerate(self.downs):
            down_outs.append(out)
            out = down(out, t_emb)
            logger.debug(f"Output of Down Block {idx} : {out.shape} in UNet")

        for idx, mid in enumerate(self.mids):
            out = mid(out, t_emb)
            logger.debug(f"Output of Mid Block {idx} : {out.shape} in UNet")

        for idx, up in enumerate(self.ups):
            out = up(out, down_outs.pop(), t_emb)
            logger.debug(f"Output of Up Block {idx} : {out.shape} in UNet")

        out = self.norm_out(out)
        out = self.conv_out(out)
        logger.debug(f"Output of UNet : {out.shape}")
        return out
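

if __name__ == "__main__":
    # Minimal smoke test (not part of the module's API): builds a UNet from an
    # assumed example config and checks that a forward pass preserves the input
    # shape. The channel lists, t_emb_dim, and the 1x32x32 input below are
    # illustrative values chosen to satisfy the asserts in UNet.__init__, not
    # values prescribed by this repo.
    example_config = {
        "im_channels": 1,
        "down_channels": [32, 64, 128, 256],
        "mid_channels": [256, 256, 128],
        "t_emb_dim": 128,
        "down_sample": [True, True, True],
        "num_down_layers": 1,
        "num_mid_layers": 1,
        "num_up_layers": 1,
    }
    model = UNet(example_config)
    x = torch.randn(2, 1, 32, 32)  # batch of 2 single-channel 32x32 images
    t = torch.randint(0, 1000, (2,))  # one diffusion timestep per sample
    with torch.no_grad():
        out = model(x, t)
    assert out.shape == x.shape, f"expected {x.shape}, got {out.shape}"
    print(f"UNet forward OK: {tuple(x.shape)} -> {tuple(out.shape)}")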