import glob
import os
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from einops import einsum
except ImportError:
    # einsum is not referenced by the model code in this part of the file,
    # so a missing einops install should not stop the blocks from importing.
    einsum = None


# ==========================================
# BLOCKS for VQVAE (Down, Mid, Up)
# ==========================================

def get_time_embedding(time_steps, temb_dim):
    r"""
    Convert time steps tensor into an embedding using the
    sinusoidal time embedding formula

    :param time_steps: 1D tensor of length batch size. A 0-d tensor is
        treated as a batch containing a single time step.
    :param temb_dim: Dimension of the embedding (must be even)
    :return: BxD embedding representation of B time steps, laid out as
        [sin(pos / factor), cos(pos / factor)] along the last dimension
    """
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
    if time_steps.dim() == 0:
        # A scalar time step cannot be indexed with [:, None] below.
        time_steps = time_steps[None]

    half_dim = temb_dim // 2
    # factor = 10000^(2i/d_model)
    factor = 10000 ** (
        torch.arange(start=0, end=half_dim, dtype=torch.float32, device=time_steps.device) / half_dim
    )
    # pos / factor: (B, 1) broadcast against (half_dim,) -> (B, half_dim)
    t_emb = time_steps[:, None] / factor
    return torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)


def _norm_act_conv_list(in_channels, out_channels, norm_channels, count):
    """Build ``count`` GroupNorm -> SiLU -> 3x3 conv stacks; only the first one takes ``in_channels``."""
    stacks = []
    for i in range(count):
        channels = in_channels if i == 0 else out_channels
        stacks.append(nn.Sequential(
            nn.GroupNorm(norm_channels, channels),
            nn.SiLU(),
            nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1),
        ))
    return nn.ModuleList(stacks)


def _time_projection_list(t_emb_dim, out_channels, count):
    """Build ``count`` SiLU -> Linear projections from the time embedding to ``out_channels``."""
    return nn.ModuleList([
        nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
        for _ in range(count)
    ])


def _group_norm_list(norm_channels, out_channels, count):
    """Build ``count`` GroupNorm layers over ``out_channels`` channels."""
    return nn.ModuleList([nn.GroupNorm(norm_channels, out_channels) for _ in range(count)])


def _attention_list(out_channels, num_heads, count):
    """Build ``count`` batch-first multi-head attention layers of width ``out_channels``."""
    return nn.ModuleList([
        nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
        for _ in range(count)
    ])


def _residual_conv_list(in_channels, out_channels, count):
    """Build ``count`` 1x1 convs that match the resnet input to ``out_channels`` for the skip path."""
    return nn.ModuleList([
        nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
        for i in range(count)
    ])


def _add_cross_attention(block, out_channels, num_heads, num_layers, norm_channels, context_dim):
    """Attach the cross-attention norms, attentions and context projections to ``block``."""
    assert context_dim is not None, "Context Dimension must be passed for cross attention"
    block.cross_attention_norms = _group_norm_list(norm_channels, out_channels, num_layers)
    block.cross_attentions = _attention_list(out_channels, num_heads, num_layers)
    block.context_proj = nn.ModuleList(
        [nn.Linear(context_dim, out_channels) for _ in range(num_layers)]
    )


def _resnet_forward(block, index, out, t_emb):
    """Apply resnet sub-block ``index`` of ``block``: conv, optional time embedding, conv, 1x1 skip."""
    resnet_input = out
    out = block.resnet_conv_first[index](out)
    if block.t_emb_dim is not None:
        # (B, C) -> (B, C, 1, 1) so the projection broadcasts over H and W
        out = out + block.t_emb_layers[index](t_emb)[:, :, None, None]
    out = block.resnet_conv_second[index](out)
    return out + block.residual_input_conv[index](resnet_input)


def _self_attention(block, index, out):
    """Apply self-attention layer ``index`` of ``block`` over the H*W positions and add it residually."""
    batch_size, channels, h, w = out.shape
    # B, C, H, W -> B, C, H*W -> (norm) -> B, H*W, C
    tokens = block.attention_norms[index](out.reshape(batch_size, channels, h * w)).transpose(1, 2)
    attended, _ = block.attentions[index](tokens, tokens, tokens)
    return out + attended.transpose(1, 2).reshape(batch_size, channels, h, w)


def _cross_attention(block, index, out, context):
    """Apply cross-attention layer ``index`` of ``block`` against ``context`` and add it residually."""
    assert context is not None, "context cannot be None if cross attention layers are used"
    assert len(context.shape) == 3, \
        "Context shape does not match B,_,CONTEXT_DIM"
    assert context.shape[0] == out.shape[0] and context.shape[-1] == block.context_dim, \
        "Context shape does not match B,_,CONTEXT_DIM"
    batch_size, channels, h, w = out.shape
    queries = block.cross_attention_norms[index](out.reshape(batch_size, channels, h * w)).transpose(1, 2)
    # Project the context to the block width so it can serve as keys and values
    context_proj = block.context_proj[index](context)
    attended, _ = block.cross_attentions[index](queries, context_proj, context_proj)
    return out + attended.transpose(1, 2).reshape(batch_size, channels, h, w)


class DownBlock(nn.Module):
    r"""
    Down conv block with attention.
    Sequence of following block
    1. Resnet block with time embedding
    2. Attention block (self attention if ``attn``, cross attention if ``cross_attn``)
    3. Downsample

    Steps 1 and 2 are repeated ``num_layers`` times before the single
    downsample (a stride-2 conv, or identity when ``down_sample`` is false).
    The time embedding is only used when ``t_emb_dim`` is not None.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim,
                 down_sample, num_heads, num_layers, attn, norm_channels,
                 cross_attn=False, context_dim=None):
        super().__init__()
        self.num_layers = num_layers
        self.down_sample = down_sample
        self.attn = attn
        self.context_dim = context_dim
        self.cross_attn = cross_attn
        self.t_emb_dim = t_emb_dim
        self.resnet_conv_first = _norm_act_conv_list(in_channels, out_channels, norm_channels, num_layers)
        if self.t_emb_dim is not None:
            self.t_emb_layers = _time_projection_list(self.t_emb_dim, out_channels, num_layers)
        self.resnet_conv_second = _norm_act_conv_list(out_channels, out_channels, norm_channels, num_layers)
        if self.attn:
            self.attention_norms = _group_norm_list(norm_channels, out_channels, num_layers)
            self.attentions = _attention_list(out_channels, num_heads, num_layers)
        if self.cross_attn:
            _add_cross_attention(self, out_channels, num_heads, num_layers, norm_channels, context_dim)
        self.residual_input_conv = _residual_conv_list(in_channels, out_channels, num_layers)
        self.down_sample_conv = nn.Conv2d(out_channels, out_channels,
                                          4, 2, 1) if self.down_sample else nn.Identity()

    def forward(self, x, t_emb=None, context=None):
        r"""
        :param x: B x in_channels x H x W feature map
        :param t_emb: time embedding, required when ``t_emb_dim`` is not None
        :param context: B x _ x context_dim tensor, required when ``cross_attn`` is set
        :return: B x out_channels feature map, downsampled if ``down_sample``
        """
        out = x
        for i in range(self.num_layers):
            # Resnet block of Unet
            out = _resnet_forward(self, i, out, t_emb)
            if self.attn:
                # Attention block of Unet
                out = _self_attention(self, i, out)
            if self.cross_attn:
                out = _cross_attention(self, i, out, context)
        # Downsample
        return self.down_sample_conv(out)


class MidBlock(nn.Module):
    r"""
    Mid conv block with attention.
    Sequence of following blocks
    1. Resnet block with time embedding
    2. Attention block
    3. Resnet block with time embedding

    Steps 2 and 3 are repeated ``num_layers`` times, so the block holds
    ``num_layers + 1`` resnet blocks and ``num_layers`` attention layers.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_channels,
                 cross_attn=None, context_dim=None):
        super().__init__()
        self.num_layers = num_layers
        self.t_emb_dim = t_emb_dim
        self.context_dim = context_dim
        self.cross_attn = cross_attn
        self.resnet_conv_first = _norm_act_conv_list(in_channels, out_channels, norm_channels, num_layers + 1)
        if self.t_emb_dim is not None:
            self.t_emb_layers = _time_projection_list(t_emb_dim, out_channels, num_layers + 1)
        self.resnet_conv_second = _norm_act_conv_list(out_channels, out_channels, norm_channels, num_layers + 1)
        self.attention_norms = _group_norm_list(norm_channels, out_channels, num_layers)
        self.attentions = _attention_list(out_channels, num_heads, num_layers)
        if self.cross_attn:
            _add_cross_attention(self, out_channels, num_heads, num_layers, norm_channels, context_dim)
        self.residual_input_conv = _residual_conv_list(in_channels, out_channels, num_layers + 1)

    def forward(self, x, t_emb=None, context=None):
        r"""
        :param x: B x in_channels x H x W feature map
        :param t_emb: time embedding, required when ``t_emb_dim`` is not None
        :param context: B x _ x context_dim tensor, required when ``cross_attn`` is set
        :return: B x out_channels x H x W feature map
        """
        # First resnet block
        out = _resnet_forward(self, 0, x, t_emb)
        for i in range(self.num_layers):
            # Attention Block
            out = _self_attention(self, i, out)
            if self.cross_attn:
                out = _cross_attention(self, i, out, context)
            # Resnet Block
            out = _resnet_forward(self, i + 1, out, t_emb)
        return out


class UpBlock(nn.Module):
    r"""
    Up conv block with attention.
    Sequence of following blocks
    1. Upsample
    2. Concatenate Down block output (only when ``out_down`` is given)
    3. Resnet block with time embedding
    4. Attention Block (only when ``attn`` is set)

    The upsampling transposed conv keeps all ``in_channels`` channels.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim,
                 up_sample, num_heads, num_layers, attn, norm_channels):
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample
        self.t_emb_dim = t_emb_dim
        self.attn = attn
        self.resnet_conv_first = _norm_act_conv_list(in_channels, out_channels, norm_channels, num_layers)
        if self.t_emb_dim is not None:
            self.t_emb_layers = _time_projection_list(t_emb_dim, out_channels, num_layers)
        self.resnet_conv_second = _norm_act_conv_list(out_channels, out_channels, norm_channels, num_layers)
        if self.attn:
            self.attention_norms = _group_norm_list(norm_channels, out_channels, num_layers)
            self.attentions = _attention_list(out_channels, num_heads, num_layers)
        self.residual_input_conv = _residual_conv_list(in_channels, out_channels, num_layers)
        self.up_sample_conv = nn.ConvTranspose2d(in_channels, in_channels,
                                                 4, 2, 1) \
            if self.up_sample else nn.Identity()

    def forward(self, x, out_down=None, t_emb=None):
        r"""
        :param x: B x in_channels x H x W feature map
        :param out_down: optional tensor concatenated on the channel axis after upsampling
        :param t_emb: time embedding, required when ``t_emb_dim`` is not None
        :return: B x out_channels feature map, upsampled if ``up_sample``
        """
        # Upsample
        x = self.up_sample_conv(x)
        # Concat with Downblock output
        if out_down is not None:
            x = torch.cat([x, out_down], dim=1)

        out = x
        for i in range(self.num_layers):
            # Resnet Block
            out = _resnet_forward(self, i, out, t_emb)
            # Self Attention
            if self.attn:
                out = _self_attention(self, i, out)
        return out


class UpBlockUnet(nn.Module):
    r"""
    Up conv block with attention.
    Sequence of following blocks
    1. Upsample
    2. Concatenate Down block output (only when ``out_down`` is given)
    3. Resnet block with time embedding
    4. Attention Block (self attention always, cross attention if ``cross_attn``)

    The upsampling transposed conv operates on ``in_channels // 2`` channels;
    the resnet blocks see ``in_channels`` channels after the concatenation.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, up_sample,
                 num_heads, num_layers, norm_channels, cross_attn=False, context_dim=None):
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample
        self.t_emb_dim = t_emb_dim
        self.cross_attn = cross_attn
        self.context_dim = context_dim
        self.resnet_conv_first = _norm_act_conv_list(in_channels, out_channels, norm_channels, num_layers)
        if self.t_emb_dim is not None:
            self.t_emb_layers = _time_projection_list(t_emb_dim, out_channels, num_layers)
        self.resnet_conv_second = _norm_act_conv_list(out_channels, out_channels, norm_channels, num_layers)
        self.attention_norms = _group_norm_list(norm_channels, out_channels, num_layers)
        self.attentions = _attention_list(out_channels, num_heads, num_layers)
        if self.cross_attn:
            _add_cross_attention(self, out_channels, num_heads, num_layers, norm_channels, context_dim)
        self.residual_input_conv = _residual_conv_list(in_channels, out_channels, num_layers)
        self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2,
                                                 4, 2, 1) \
            if self.up_sample else nn.Identity()

    def forward(self, x, out_down=None, t_emb=None, context=None):
        r"""
        :param x: B x (in_channels // 2) x H x W feature map when ``up_sample`` is set
        :param out_down: optional tensor concatenated on the channel axis after upsampling
        :param t_emb: time embedding, required when ``t_emb_dim`` is not None
        :param context: B x _ x context_dim tensor, required when ``cross_attn`` is set
        :return: B x out_channels feature map
        """
        x = self.up_sample_conv(x)
        if out_down is not None:
            x = torch.cat([x, out_down], dim=1)

        out = x
        for i in range(self.num_layers):
            # Resnet
            out = _resnet_forward(self, i, out, t_emb)
            # Self Attention
            out = _self_attention(self, i, out)
            # Cross Attention
            if self.cross_attn:
                out = _cross_attention(self, i, out, context)
        return out


# ==========================================
# VQVAE
# Definition
# ==========================================

class VQVAE(nn.Module):
    r"""
    Vector-quantised autoencoder built from DownBlock / MidBlock / UpBlock.

    The encoder maps an image to a ``z_channels`` latent, every latent
    position is replaced by its nearest codebook vector, and the decoder
    mirrors the encoder to map the quantised latent back to image space.

    :param im_channels: Number of channels of the input/output image
    :param model_config: dict providing 'down_channels', 'mid_channels',
        'down_sample', 'num_down_layers', 'num_mid_layers', 'num_up_layers',
        'attn_down', 'z_channels', 'codebook_size', 'norm_channels' and
        'num_heads'
    """

    def __init__(self, im_channels, model_config):
        super().__init__()
        self.down_channels = model_config['down_channels']
        self.mid_channels = model_config['mid_channels']
        self.down_sample = model_config['down_sample']
        self.num_down_layers = model_config['num_down_layers']
        self.num_mid_layers = model_config['num_mid_layers']
        self.num_up_layers = model_config['num_up_layers']

        # To disable attention in Downblock of Encoder and Upblock of Decoder
        self.attns = model_config['attn_down']

        # Latent Dimension
        self.z_channels = model_config['z_channels']
        self.codebook_size = model_config['codebook_size']
        self.norm_channels = model_config['norm_channels']
        self.num_heads = model_config['num_heads']

        # Assertion to validate the channel information
        assert self.mid_channels[0] == self.down_channels[-1]
        assert self.mid_channels[-1] == self.down_channels[-1]
        assert len(self.down_sample) == len(self.down_channels) - 1
        assert len(self.attns) == len(self.down_channels) - 1

        # Wherever we use downsampling in encoder correspondingly use
        # upsampling in decoder
        self.up_sample = list(reversed(self.down_sample))

        ## Encoder ##
        self.encoder_conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1))

        # Downblock + Midblock
        self.encoder_layers = nn.ModuleList([])
        for i in range(len(self.down_channels) - 1):
            self.encoder_layers.append(DownBlock(self.down_channels[i], self.down_channels[i + 1],
                                                 t_emb_dim=None, down_sample=self.down_sample[i],
                                                 num_heads=self.num_heads,
                                                 num_layers=self.num_down_layers,
                                                 attn=self.attns[i],
                                                 norm_channels=self.norm_channels))

        self.encoder_mids = nn.ModuleList([])
        for i in range(len(self.mid_channels) - 1):
            self.encoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1],
                                              t_emb_dim=None,
                                              num_heads=self.num_heads,
                                              num_layers=self.num_mid_layers,
                                              norm_channels=self.norm_channels))

        self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1])
        self.encoder_conv_out = nn.Conv2d(self.down_channels[-1], self.z_channels, kernel_size=3, padding=1)

        # Pre Quantization Convolution
        self.pre_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)

        # Codebook
        self.embedding = nn.Embedding(self.codebook_size, self.z_channels)

        ## Decoder ##

        # Post Quantization Convolution
        self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
        self.decoder_conv_in = nn.Conv2d(self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1))

        # Midblock + Upblock
        self.decoder_mids = nn.ModuleList([])
        for i in reversed(range(1, len(self.mid_channels))):
            self.decoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i - 1],
                                              t_emb_dim=None,
                                              num_heads=self.num_heads,
                                              num_layers=self.num_mid_layers,
                                              norm_channels=self.norm_channels))

        self.decoder_layers = nn.ModuleList([])
        for i in reversed(range(1, len(self.down_channels))):
            self.decoder_layers.append(UpBlock(self.down_channels[i], self.down_channels[i - 1],
                                               t_emb_dim=None, up_sample=self.down_sample[i - 1],
                                               num_heads=self.num_heads,
                                               num_layers=self.num_up_layers,
                                               attn=self.attns[i - 1],
                                               norm_channels=self.norm_channels))

        self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0])
        self.decoder_conv_out = nn.Conv2d(self.down_channels[0], im_channels, kernel_size=3, padding=1)

    def quantize(self, x):
        r"""
        Replace every latent vector by its nearest codebook entry.

        :param x: B x C x H x W encoder output
        :return: tuple of (quantised latent B x C x H x W with straight-through
            gradients, dict with 'codebook_loss' and 'commitment_loss',
            codebook indices of shape B x H x W)
        """
        B, C, H, W = x.shape

        # B, C, H, W -> B, H, W, C
        x = x.permute(0, 2, 3, 1)

        # B, H, W, C -> B, H*W, C
        x = x.reshape(x.size(0), -1, x.size(-1))

        # Find nearest embedding/codebook vector
        # dist between (B, H*W, C) and (B, K, C) -> (B, H*W, K)
        dist = torch.cdist(x, self.embedding.weight[None, :].repeat((x.size(0), 1, 1)))
        # (B, H*W)
        min_encoding_indices = torch.argmin(dist, dim=-1)

        # Replace encoder output with nearest codebook
        # quant_out -> B*H*W, C
        quant_out = torch.index_select(self.embedding.weight, 0, min_encoding_indices.view(-1))

        # x -> B*H*W, C
        x = x.reshape((-1, x.size(-1)))
        # Commitment loss moves the encoder output towards the (fixed) codes,
        # codebook loss moves the codes towards the (fixed) encoder output.
        commitment_loss = torch.mean((quant_out.detach() - x) ** 2)
        codebook_loss = torch.mean((quant_out - x.detach()) ** 2)
        quantize_losses = {
            'codebook_loss': codebook_loss,
            'commitment_loss': commitment_loss
        }
        # Straight through estimation
        quant_out = x + (quant_out - x).detach()

        # quant_out -> B, C, H, W
        quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2)
        min_encoding_indices = min_encoding_indices.reshape((-1, quant_out.size(-2), quant_out.size(-1)))
        return quant_out, quantize_losses, min_encoding_indices

    def encode(self, x):
        r"""
        :param x: B x im_channels x H x W image batch
        :return: tuple of (quantised latent, quantisation loss dict)
        """
        out = self.encoder_conv_in(x)
        for down in self.encoder_layers:
            out = down(out)
        for mid in self.encoder_mids:
            out = mid(out)
        out = self.encoder_norm_out(out)
        out = F.silu(out)
        out = self.encoder_conv_out(out)
        out = self.pre_quant_conv(out)
        out, quant_losses, _ = self.quantize(out)
        return out, quant_losses

    def decode(self, z):
        r"""
        :param z: B x z_channels x h x w (quantised) latent
        :return: decoded B x im_channels image batch
        """
        out = z
        out = self.post_quant_conv(out)
        out = self.decoder_conv_in(out)
        for mid in self.decoder_mids:
            out = mid(out)
        for up in self.decoder_layers:
            out = up(out)

        out = self.decoder_norm_out(out)
        out = F.silu(out)
        out = self.decoder_conv_out(out)
        return out

    def forward(self, x):
        r"""
        :param x: B x im_channels x H x W image batch
        :return: tuple of (reconstruction, quantised latent, quantisation loss dict)
        """
        z, quant_losses = self.encode(x)
        out = self.decode(z)
        return out, z, quant_losses


# ==========================================
# SPADE Definitions
# ==========================================

class SPADE(nn.Module):
    r"""
    Spatially-adaptive normalisation: group-normalise ``x`` and modulate it
    with a per-pixel scale and shift predicted from a segmentation map.

    :param norm_nc: Number of channels of the feature map being normalised
        (the GroupNorm uses 32 groups, so this must be divisible by 32)
    :param label_nc: Number of channels of the segmentation map
    """

    def __init__(self, norm_nc, label_nc):
        super().__init__()
        self.param_free_norm = nn.GroupNorm(32, norm_nc)

        nhidden = 128
        # Convolutions to generate modulation parameters from the mask
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)

    def forward(self, x, segmap):
        r"""
        :param x: B x norm_nc x H x W feature map
        :param segmap: B x label_nc x H' x W' mask, or None for no modulation
        :return: modulated feature map with the shape of ``x``
        """
        # 1. Normalize
        normalized = self.param_free_norm(x)

        # Without a mask there is nothing to modulate with (the Unet passes
        # None when image conditioning is disabled), so the layer reduces to
        # the plain normalisation instead of failing on ``None.size()``.
        if segmap is None:
            return normalized

        # 2. Resize mask to match x's resolution
        if segmap.size()[2:] != x.size()[2:]:
            segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')

        # 3. Generate params
        actv = self.mlp_shared(segmap)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # 4. Modulate
        out = normalized * (1 + gamma) + beta
        return out


class SPADEResnetBlock(nn.Module):
    """
    Simplified SPADE Block: Norm -> Act -> Conv
    (We removed the internal shortcut because DownBlock/MidBlock handles the residual connection)
    """

    def __init__(self, in_channels, out_channels, label_nc):
        super().__init__()
        # 1. SPADE Normalization (Uses Mask)
        self.norm1 = SPADE(in_channels, label_nc)
        # 2. Activation
        self.act1 = nn.SiLU()
        # 3. Convolution
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x, segmap):
        # Apply SPADE Norm -> Act -> Conv
        h = self.norm1(x, segmap)
        h = self.act1(h)
        h = self.conv1(h)
        return h


# ==========================================
# BLOCKS (Down, Mid, Up)
# ==========================================

def _spade_resnet(block, index, out, t_emb, segmap):
    """Apply SPADE resnet sub-block ``index`` of ``block``: conv, optional time embedding, conv, 1x1 skip."""
    resnet_input = out
    out = block.resnet_conv_first[index](out, segmap)
    if block.t_emb_dim is not None:
        # (B, C) -> (B, C, 1, 1) so the projection broadcasts over H and W
        out = out + block.t_emb_layers[index](t_emb)[:, :, None, None]
    out = block.resnet_conv_second[index](out, segmap)
    # SPADEResnetBlock has no shortcut of its own; the residual connection
    # around both convolutions is added here.
    return out + block.residual_input_conv[index](resnet_input)


def _attend_self(block, index, out):
    """Apply self-attention layer ``index`` of ``block`` over the H*W positions and add it residually."""
    batch_size, channels, h, w = out.shape
    tokens = block.attention_norms[index](out.reshape(batch_size, channels, h * w)).transpose(1, 2)
    attended, _ = block.attentions[index](tokens, tokens, tokens)
    return out + attended.transpose(1, 2).reshape(batch_size, channels, h, w)


def _attend_context(block, index, out, context):
    """Apply cross-attention layer ``index`` of ``block`` against ``context`` and add it residually."""
    assert context is not None, "context cannot be None if cross attention layers are used"
    batch_size, channels, h, w = out.shape
    queries = block.cross_attention_norms[index](out.reshape(batch_size, channels, h * w)).transpose(1, 2)
    context_proj = block.context_proj[index](context)
    attended, _ = block.cross_attentions[index](queries, context_proj, context_proj)
    return out + attended.transpose(1, 2).reshape(batch_size, channels, h, w)


class SpadeDownBlock(nn.Module):
    r"""
    Down block whose resnet convolutions use SPADE normalisation.
    Per layer: SPADE resnet block, optional self attention (``attn``),
    optional cross attention (``cross_attn``); then one optional downsample.

    :param label_nc: Number of channels of the segmentation map fed to SPADE
    """

    def __init__(self, in_channels, out_channels, t_emb_dim,
                 down_sample, num_heads, num_layers, attn, norm_channels,
                 cross_attn=False, context_dim=None, label_nc=4):
        super().__init__()
        self.num_layers = num_layers
        self.down_sample = down_sample
        self.attn = attn
        self.context_dim = context_dim
        self.cross_attn = cross_attn
        self.t_emb_dim = t_emb_dim

        # REPLACED nn.Sequential with SPADEResnetBlock
        self.resnet_conv_first = nn.ModuleList([
            SPADEResnetBlock(in_channels if i == 0 else out_channels, out_channels, label_nc)
            for i in range(num_layers)
        ])

        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, out_channels))
                for _ in range(num_layers)
            ])

        # REPLACED nn.Sequential with SPADEResnetBlock
        self.resnet_conv_second = nn.ModuleList([
            SPADEResnetBlock(out_channels, out_channels, label_nc)
            for _ in range(num_layers)
        ])

        if self.attn:
            self.attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
            self.attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers)])

        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
            self.cross_attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers)])
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels) for _ in range(num_layers)])

        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
            for i in range(num_layers)
        ])
        self.down_sample_conv = nn.Conv2d(out_channels, out_channels, 4, 2, 1) \
            if self.down_sample else nn.Identity()

    def forward(self, x, t_emb=None, context=None, segmap=None):
        out = x
        for i in range(self.num_layers):
            # Both SPADE convolutions of the resnet block receive the segmap
            out = _spade_resnet(self, i, out, t_emb, segmap)
            if self.attn:
                out = _attend_self(self, i, out)
            if self.cross_attn:
                out = _attend_context(self, i, out, context)
        return self.down_sample_conv(out)


class SpadeMidBlock(nn.Module):
    r"""
    Mid block whose resnet convolutions use SPADE normalisation.
    One SPADE resnet block, followed ``num_layers`` times by self attention,
    optional cross attention and another SPADE resnet block.

    :param label_nc: Number of channels of the segmentation map fed to SPADE
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_channels,
                 cross_attn=None, context_dim=None, label_nc=4):
        super().__init__()
        self.num_layers = num_layers
        self.t_emb_dim = t_emb_dim
        self.context_dim = context_dim
        self.cross_attn = cross_attn

        # REPLACED with SPADE
        self.resnet_conv_first = nn.ModuleList([
            SPADEResnetBlock(in_channels if i == 0 else out_channels, out_channels, label_nc)
            for i in range(num_layers + 1)
        ])

        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                for _ in range(num_layers + 1)
            ])

        # REPLACED with SPADE
        self.resnet_conv_second = nn.ModuleList([
            SPADEResnetBlock(out_channels, out_channels, label_nc)
            for _ in range(num_layers + 1)
        ])

        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
        self.attentions = nn.ModuleList(
            [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers)])

        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
            self.cross_attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers)])
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels) for _ in range(num_layers)])

        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
            for i in range(num_layers + 1)
        ])

    def forward(self, x, t_emb=None, context=None, segmap=None):
        # First Block (No Attention)
        out = _spade_resnet(self, 0, x, t_emb, segmap)
        for i in range(self.num_layers):
            # Attention
            out = _attend_self(self, i, out)
            if self.cross_attn:
                out = _attend_context(self, i, out, context)
            # Next Resnet Block
            out = _spade_resnet(self, i + 1, out, t_emb, segmap)
        return out


class SpadeUpBlock(nn.Module):
    r"""
    Up block whose resnet convolutions use SPADE normalisation.
    Upsample, concatenate the skip tensor when given, then per layer:
    SPADE resnet block, self attention and optional cross attention.

    The upsampling transposed conv operates on ``in_channels // 2`` channels;
    the resnet blocks see ``in_channels`` channels after the concatenation.

    :param label_nc: Number of channels of the segmentation map fed to SPADE
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, up_sample,
                 num_heads, num_layers, norm_channels, cross_attn=False, context_dim=None, label_nc=4):
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample
        self.t_emb_dim = t_emb_dim
        self.cross_attn = cross_attn
        self.context_dim = context_dim

        # REPLACED with SPADE
        self.resnet_conv_first = nn.ModuleList([
            SPADEResnetBlock(in_channels if i == 0 else out_channels, out_channels, label_nc)
            for i in range(num_layers)
        ])

        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                for _ in range(num_layers)
            ])

        # REPLACED with SPADE
        self.resnet_conv_second = nn.ModuleList([
            SPADEResnetBlock(out_channels, out_channels, label_nc)
            for _ in range(num_layers)
        ])

        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
        self.attentions = nn.ModuleList(
            [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers)])

        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
            self.cross_attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers)])
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels) for _ in range(num_layers)])

        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
            for i in range(num_layers)
        ])
        self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, 4, 2, 1) \
            if self.up_sample else nn.Identity()

    def forward(self, x, out_down=None, t_emb=None, context=None, segmap=None):
        x = self.up_sample_conv(x)
        if out_down is not None:
            x = torch.cat([x, out_down], dim=1)

        out = x
        for i in range(self.num_layers):
            out = _spade_resnet(self, i, out, t_emb, segmap)
            out = _attend_self(self, i, out)
            if self.cross_attn:
                out = _attend_context(self, i, out, context)
        return out


# ==========================================
# Helper Functions
# ==========================================

def validate_image_config(condition_config):
    """Assert that ``condition_config`` carries the image-conditioning channel settings."""
    assert 'image_condition_config' in condition_config, "Image conditioning desired but config missing"
    assert 'image_condition_input_channels' in condition_config['image_condition_config'], "Input channels missing"
    assert 'image_condition_output_channels' in condition_config['image_condition_config'], "Output channels missing"


def validate_image_conditional_input(cond_input, x):
    """Assert that ``cond_input`` holds an 'image' entry whose batch size matches ``x``."""
    # Check for None first so a missing cond_input reports the intended
    # message instead of a TypeError from the membership test.
    assert cond_input is not None and 'image' in cond_input, \
        "Model initialized with image conditioning but input missing"
    assert cond_input['image'].shape[0] == x.shape[0], "Batch size mismatch"


def get_config_value(config, key, default_value):
    """Return ``config[key]`` when the key is present, otherwise ``default_value``."""
    return config[key] if key in config else default_value


def get_time_embedding(time_steps, temb_dim):
    r"""
    Convert time steps tensor into an embedding using the
    sinusoidal time embedding formula

    :param time_steps: 1D tensor of length batch size. A 0-d tensor is
        treated as a batch containing a single time step.
    :param temb_dim: Dimension of the embedding (must be even)
    :return: BxD embedding representation of B time steps
    """
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
    if time_steps.dim() == 0:
        # A scalar time step cannot be indexed with [:, None] below.
        time_steps = time_steps[None]

    half_dim = temb_dim // 2
    # factor = 10000^(2i/d_model)
    factor = 10000 ** (
        torch.arange(start=0, end=half_dim, dtype=torch.float32, device=time_steps.device) / half_dim
    )
    # pos / factor: (B, 1) broadcast against (half_dim,) -> (B, half_dim)
    t_emb = time_steps[:, None] / factor
    return torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)


def drop_image_condition(image_condition, im, im_drop_prob):
    """
    Randomly zero the conditioning of whole samples.

    A uniform value is drawn per sample of ``im``; the sample's condition is
    kept where the draw exceeds ``im_drop_prob`` and multiplied by zero
    otherwise. With ``im_drop_prob <= 0`` the input is returned unchanged.
    """
    if im_drop_prob > 0:
        im_drop_mask = torch.zeros((im.shape[0], 1, 1, 1), device=im.device).float().uniform_(0, 1) > im_drop_prob
        return image_condition * im_drop_mask
    else:
        return image_condition


# ==========================================
# UNET Definition
# ==========================================

class Unet(nn.Module):
    r"""
    Unet model with SPADE integration for anatomical consistency.

    :param im_channels: Number of channels of the input and of the prediction
    :param model_config: dict providing 'down_channels', 'mid_channels',
        'time_emb_dim', 'down_sample', 'num_down_layers', 'num_mid_layers',
        'num_up_layers', 'attn_down', 'norm_channels', 'num_heads',
        'conv_out_channels' and optionally 'condition_config'
    """

    def __init__(self, im_channels, model_config):
        super().__init__()
        self.down_channels = model_config['down_channels']
        self.mid_channels = model_config['mid_channels']
        self.t_emb_dim = model_config['time_emb_dim']
        self.down_sample = model_config['down_sample']
        self.num_down_layers = model_config['num_down_layers']
        self.num_mid_layers = model_config['num_mid_layers']
        self.num_up_layers = model_config['num_up_layers']
        self.attns = model_config['attn_down']
        self.norm_channels = model_config['norm_channels']
        self.num_heads = model_config['num_heads']
        self.conv_out_channels = model_config['conv_out_channels']

        # Validate Config
        assert self.mid_channels[0] == self.down_channels[-1]
        assert self.mid_channels[-1] == self.down_channels[-2]
        assert len(self.down_sample) == len(self.down_channels) - 1
        assert len(self.attns) == len(self.down_channels) - 1

        # Conditioning Setup
        self.image_cond = False
        self.condition_config = get_config_value(model_config, 'condition_config', None)

        # Default mask channels (usually 4: BG, LV, Myo, RV)
        self.im_cond_input_ch = 4

        if self.condition_config is not None:
            if 'image' in self.condition_config.get('condition_types', []):
                # Report a readable message for incomplete configs before
                # the keys are read below.
                validate_image_config(self.condition_config)
                self.image_cond = True
                image_config = self.condition_config['image_condition_config']
                self.im_cond_input_ch = image_config['image_condition_input_channels']
                self.im_cond_output_ch = image_config['image_condition_output_channels']

        # Standard Input Conv
        # SPADE injects the mask later, so we just take the latent input here.
        self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=1)

        # Time Embedding
        self.t_proj = nn.Sequential(
            nn.Linear(self.t_emb_dim, self.t_emb_dim),
            nn.SiLU(),
            nn.Linear(self.t_emb_dim, self.t_emb_dim)
        )

        self.up_sample = list(reversed(self.down_sample))
        self.downs = nn.ModuleList([])

        # Pass label_nc to Blocks
        for i in range(len(self.down_channels) - 1):
            self.downs.append(SpadeDownBlock(
                self.down_channels[i], self.down_channels[i + 1], self.t_emb_dim,
                down_sample=self.down_sample[i],
                num_heads=self.num_heads,
                num_layers=self.num_down_layers,
                attn=self.attns[i],
                norm_channels=self.norm_channels,
                label_nc=self.im_cond_input_ch  # SPADE needs this
            ))

        self.mids = nn.ModuleList([])
        for i in range(len(self.mid_channels) - 1):
            self.mids.append(SpadeMidBlock(
                self.mid_channels[i], self.mid_channels[i + 1], self.t_emb_dim,
                num_heads=self.num_heads,
                num_layers=self.num_mid_layers,
                norm_channels=self.norm_channels,
                label_nc=self.im_cond_input_ch  # SPADE needs this
            ))

        self.ups = nn.ModuleList([])
        for i in reversed(range(len(self.down_channels) - 1)):
            self.ups.append(SpadeUpBlock(
                self.down_channels[i] * 2,
                self.down_channels[i - 1] if i != 0 else self.conv_out_channels,
                self.t_emb_dim,
                up_sample=self.down_sample[i],
                num_heads=self.num_heads,
                num_layers=self.num_up_layers,
                norm_channels=self.norm_channels,
                label_nc=self.im_cond_input_ch  # SPADE needs this
            ))

        self.norm_out = nn.GroupNorm(self.norm_channels, self.conv_out_channels)
        self.conv_out = nn.Conv2d(self.conv_out_channels, im_channels, kernel_size=3, padding=1)

    def forward(self, x, t, cond_input=None):
        r"""
        :param x: B x im_channels x H x W input
        :param t: time steps; a 1D tensor of length B, or a scalar shared by the batch
        :param cond_input: dict with an 'image' mask, required when the model
            was configured with image conditioning
        :return: B x im_channels x H x W prediction
        """
        # 1. Validation
        if self.image_cond:
            validate_image_conditional_input(cond_input, x)
            # Get the mask, but don't concatenate yet
            im_cond = cond_input['image']
        else:
            im_cond = None

        # 2. Initial Conv (Standard)
        out = self.conv_in(x)

        # 3. Time Embedding
        # Build the time steps on x's device so that Python numbers and CPU
        # tensors also work when the model lives on an accelerator.
        time_steps = torch.as_tensor(t, device=x.device).long()
        t_emb = get_time_embedding(time_steps, self.t_emb_dim)
        t_emb = self.t_proj(t_emb)

        # 4. Down Blocks (Pass segmap)
        down_outs = []
        for down in self.downs:
            down_outs.append(out)
            # Inject Mask into Block
            out = down(out, t_emb, segmap=im_cond)

        # 5. Mid Blocks (Pass segmap)
        for mid in self.mids:
            # Inject Mask into Block
            out = mid(out, t_emb, segmap=im_cond)

        # 6. Up Blocks (Pass segmap)
        for up in self.ups:
            down_out = down_outs.pop()
            # Inject Mask into Block
            out = up(out, down_out, t_emb, segmap=im_cond)

        out = self.norm_out(out)
        out = F.silu(out)
        out = self.conv_out(out)
        return out


# ==========================================
# Noise Scheduler Definition
# ==========================================

class LinearNoiseScheduler:
    def __init__(self, num_timesteps, beta_start, beta_end):
        self.num_timesteps = num_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end

        self.betas = (torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_timesteps) ** 2)
        self.alphas = 1. - self.betas
        self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
        self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)

    def add_noise(self, original, noise, t):
        original_shape = original.shape
        batch_size = original_shape[0]

        sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
        sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(original.device)[t].reshape(batch_size)

        for _ in range(len(original_shape) - 1):
            sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1)
            sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1)

        return (sqrt_alpha_cum_prod * original
                + sqrt_one_minus_alpha_cum_prod * noise)

    def sample_prev_timestep(self, xt, noise_pred, t):
        """
        Reverse diffusion process: Remove noise to get x_{t-1}
        """
        sqrt_one_minus_alpha_bar = self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t].view(-1, 1, 1, 1)
        sqrt_alpha_bar = self.sqrt_alpha_cum_prod.to(xt.device)[t].view(-1, 1, 1, 1)
        beta_t = self.betas.to(xt.device)[t].view(-1, 1, 1, 1)
alpha_t = self.alphas.to(xt.device)[t].view(-1, 1, 1, 1) # 1. Estimate x0 (Original image) x0 = (xt - (sqrt_one_minus_alpha_bar * noise_pred)) / sqrt_alpha_bar x0 = torch.clamp(x0, -1., 1.) # 2. Calculate Mean of x_{t-1} mean = (xt - (beta_t * noise_pred) / sqrt_one_minus_alpha_bar) / torch.sqrt(alpha_t) # 3. Add Noise (if not last step) if t[0] == 0: return mean, x0 else: # Reshape variance to [Batch, 1, 1, 1] too variance = ((1 - self.alpha_cum_prod.to(xt.device)[t - 1]) / (1.0 - self.alpha_cum_prod.to(xt.device)[t])) * self.betas.to(xt.device)[t] sigma = (variance ** 0.5).view(-1, 1, 1, 1) z = torch.randn(xt.shape).to(xt.device) return mean + sigma * z, x0 # 1. Estimate x0 (Original image) # x0 = ((xt - (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t] * noise_pred)) / # torch.sqrt(self.alpha_cum_prod.to(xt.device)[t])) # x0 = torch.clamp(x0, -1., 1.) # # 2. Calculate Mean of x_{t-1} # mean = xt - ((self.betas.to(xt.device)[t]) * noise_pred) / (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t]) # mean = mean / torch.sqrt(self.alphas.to(xt.device)[t]) # # 3. Add Noise (if not last step) # if t == 0: # return mean, x0 # else: # variance = (1 - self.alpha_cum_prod.to(xt.device)[t - 1]) / (1.0 - self.alpha_cum_prod.to(xt.device)[t]) # variance = variance * self.betas.to(xt.device)[t] # sigma = variance ** 0.5 # z = torch.randn(xt.shape).to(xt.device) # return mean + sigma * z, x0