# SynCMRIApp / diffusion.py
# Ishan Kumarasinghe
# Update app file and requirements
# 29da2fa
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import einsum
import numpy as np
import pickle
import glob
import os
# ==========================================
# BLOCKS for VQVAE (Down, Mid, Up)
# ==========================================
def get_time_embedding(time_steps, temb_dim):
    r"""
    Map a batch of time steps to sinusoidal embeddings.

    For frequency index ``i`` in ``[0, temb_dim // 2)`` the denominator is
    ``10000 ** (i / (temb_dim // 2))``. The first half of each row holds the
    sines of ``t / denominator`` and the second half holds the cosines.

    :param time_steps: 1D tensor of length batch size
    :param temb_dim: Dimension of the embedding (must be even)
    :return: BxD embedding representation of B time steps
    """
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
    half_dim = temb_dim // 2
    exponents = torch.arange(start=0, end=half_dim, dtype=torch.float32,
                             device=time_steps.device) / half_dim
    denominators = 10000 ** exponents
    # (B,) -> (B, 1) -> (B, half_dim), each row divided by the denominators
    angles = time_steps[:, None].repeat(1, half_dim) / denominators
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
class DownBlock(nn.Module):
    r"""
    Down conv block with optional self- and cross-attention.

    Each of the ``num_layers`` layers runs, in order:
      1. a ResNet block: GroupNorm -> SiLU -> 3x3 conv, optional time-embedding
         add, GroupNorm -> SiLU -> 3x3 conv, plus a 1x1-projected skip,
      2. self-attention over the H*W positions (if ``attn``),
      3. cross-attention against ``context`` (if ``cross_attn``).
    The block ends with a stride-2 4x4 conv when ``down_sample`` is truthy,
    otherwise with an identity.

    :param in_channels: channels of the block input
    :param out_channels: channels produced by every layer
    :param t_emb_dim: time-embedding size, or None to disable time conditioning
    :param down_sample: whether to halve the spatial resolution at the end
    :param num_heads: attention heads for self/cross attention
    :param num_layers: number of ResNet(+attention) layers
    :param attn: enable self-attention
    :param norm_channels: number of GroupNorm groups
    :param cross_attn: enable cross-attention
    :param context_dim: last-dim size of ``context`` (required if ``cross_attn``)
    """

    def __init__(self, in_channels, out_channels, t_emb_dim,
                 down_sample, num_heads, num_layers, attn, norm_channels, cross_attn=False, context_dim=None):
        super().__init__()
        self.num_layers = num_layers
        self.down_sample = down_sample
        self.attn = attn
        self.context_dim = context_dim
        self.cross_attn = cross_attn
        self.t_emb_dim = t_emb_dim

        # Only the first layer sees the block input width; later layers are out -> out.
        widths_in = [in_channels if i == 0 else out_channels for i in range(num_layers)]

        def conv_unit(ch_in):
            # GroupNorm -> SiLU -> 3x3 "same" conv into out_channels
            return nn.Sequential(
                nn.GroupNorm(norm_channels, ch_in),
                nn.SiLU(),
                nn.Conv2d(ch_in, out_channels, kernel_size=3, stride=1, padding=1),
            )

        # Submodules are created in the same order as before so that seeded
        # initialisation and state_dict keys are unchanged.
        self.resnet_conv_first = nn.ModuleList([conv_unit(ch) for ch in widths_in])
        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, out_channels))
                for _ in range(num_layers)
            ])
        self.resnet_conv_second = nn.ModuleList([conv_unit(out_channels) for _ in range(num_layers)])
        if self.attn:
            self.attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
            self.attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers)])
        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
            self.cross_attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers)])
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels) for _ in range(num_layers)])
        self.residual_input_conv = nn.ModuleList(
            [nn.Conv2d(ch, out_channels, kernel_size=1) for ch in widths_in])
        if self.down_sample:
            self.down_sample_conv = nn.Conv2d(out_channels, out_channels, 4, 2, 1)
        else:
            self.down_sample_conv = nn.Identity()

    def _resnet_step(self, i, h, t_emb):
        """Apply ResNet layer ``i`` (with optional time embedding) plus its 1x1 skip."""
        y = self.resnet_conv_first[i](h)
        if self.t_emb_dim is not None:
            # broadcast the (B, C) projection over every spatial position
            y = y + self.t_emb_layers[i](t_emb)[:, :, None, None]
        y = self.resnet_conv_second[i](y)
        return y + self.residual_input_conv[i](h)

    @staticmethod
    def _attention_step(h, norm, attention, memory=None):
        """Residual attention over the H*W positions of ``h``; self-attention when ``memory`` is None."""
        b, c, height, width = h.shape
        # (B, C, H, W) -> (B, C, H*W) -> GroupNorm -> (B, H*W, C) tokens
        tokens = norm(h.reshape(b, c, height * width)).transpose(1, 2)
        kv = tokens if memory is None else memory
        attended, _ = attention(tokens, kv, kv)
        return h + attended.transpose(1, 2).reshape(b, c, height, width)

    def forward(self, x, t_emb=None, context=None):
        out = x
        for i in range(self.num_layers):
            out = self._resnet_step(i, out, t_emb)
            if self.attn:
                out = self._attention_step(out, self.attention_norms[i], self.attentions[i])
            if self.cross_attn:
                assert context is not None, "context cannot be None if cross attention layers are used"
                assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
                out = self._attention_step(out, self.cross_attention_norms[i], self.cross_attentions[i],
                                           memory=self.context_proj[i](context))
        return self.down_sample_conv(out)
class MidBlock(nn.Module):
    r"""
    Mid conv block with attention.

    Layout: ResNet_0, then for each of the ``num_layers`` layers
    self-attention -> (cross-attention if ``cross_attn``) -> ResNet_{i+1}.
    There are therefore ``num_layers + 1`` ResNet blocks and ``num_layers``
    attention layers. Self-attention is always enabled in this block.

    :param in_channels: channels of the block input
    :param out_channels: channels produced by every ResNet block
    :param t_emb_dim: time-embedding size, or None to disable time conditioning
    :param num_heads: attention heads
    :param num_layers: number of attention layers
    :param norm_channels: number of GroupNorm groups
    :param cross_attn: enable cross-attention
    :param context_dim: last-dim size of ``context`` (required if ``cross_attn``)
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_channels, cross_attn=None, context_dim=None):
        super().__init__()
        self.num_layers = num_layers
        self.t_emb_dim = t_emb_dim
        self.context_dim = context_dim
        self.cross_attn = cross_attn

        num_resnets = num_layers + 1
        widths_in = [in_channels if i == 0 else out_channels for i in range(num_resnets)]

        def conv_unit(ch_in):
            # GroupNorm -> SiLU -> 3x3 "same" conv into out_channels
            return nn.Sequential(
                nn.GroupNorm(norm_channels, ch_in),
                nn.SiLU(),
                nn.Conv2d(ch_in, out_channels, kernel_size=3, stride=1, padding=1),
            )

        # Construction order kept identical (seeded init, state_dict keys).
        self.resnet_conv_first = nn.ModuleList([conv_unit(ch) for ch in widths_in])
        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                for _ in range(num_resnets)
            ])
        self.resnet_conv_second = nn.ModuleList([conv_unit(out_channels) for _ in range(num_resnets)])
        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
        self.attentions = nn.ModuleList(
            [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers)])
        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
            self.cross_attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers)])
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels) for _ in range(num_layers)])
        self.residual_input_conv = nn.ModuleList(
            [nn.Conv2d(ch, out_channels, kernel_size=1) for ch in widths_in])

    def _resnet_step(self, i, h, t_emb):
        """Apply ResNet block ``i`` (with optional time embedding) plus its 1x1 skip."""
        y = self.resnet_conv_first[i](h)
        if self.t_emb_dim is not None:
            y = y + self.t_emb_layers[i](t_emb)[:, :, None, None]
        y = self.resnet_conv_second[i](y)
        return y + self.residual_input_conv[i](h)

    @staticmethod
    def _attention_step(h, norm, attention, memory=None):
        """Residual attention over the H*W positions of ``h``; self-attention when ``memory`` is None."""
        b, c, height, width = h.shape
        tokens = norm(h.reshape(b, c, height * width)).transpose(1, 2)
        kv = tokens if memory is None else memory
        attended, _ = attention(tokens, kv, kv)
        return h + attended.transpose(1, 2).reshape(b, c, height, width)

    def forward(self, x, t_emb=None, context=None):
        out = self._resnet_step(0, x, t_emb)
        for i in range(self.num_layers):
            out = self._attention_step(out, self.attention_norms[i], self.attentions[i])
            if self.cross_attn:
                assert context is not None, "context cannot be None if cross attention layers are used"
                assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
                out = self._attention_step(out, self.cross_attention_norms[i], self.cross_attentions[i],
                                           memory=self.context_proj[i](context))
            out = self._resnet_step(i + 1, out, t_emb)
        return out
class UpBlock(nn.Module):
    r"""
    Up conv block with optional self-attention.

    Sequence:
      1. Upsample ``x`` with a stride-2 4x4 transposed conv (if ``up_sample``),
      2. Concatenate the matching down-block output along channels (if given),
      3. ``num_layers`` x (ResNet block with optional time embedding,
         then self-attention if ``attn``).

    :param in_channels: channels entering the first ResNet block
    :param out_channels: channels produced by every layer
    :param t_emb_dim: time-embedding size, or None to disable time conditioning
    :param up_sample: whether to double the spatial resolution first
    :param num_heads: attention heads
    :param num_layers: number of ResNet(+attention) layers
    :param attn: enable self-attention
    :param norm_channels: number of GroupNorm groups
    """

    def __init__(self, in_channels, out_channels, t_emb_dim,
                 up_sample, num_heads, num_layers, attn, norm_channels):
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample
        self.t_emb_dim = t_emb_dim
        self.attn = attn

        widths_in = [in_channels if i == 0 else out_channels for i in range(num_layers)]

        def conv_unit(ch_in):
            # GroupNorm -> SiLU -> 3x3 "same" conv into out_channels
            return nn.Sequential(
                nn.GroupNorm(norm_channels, ch_in),
                nn.SiLU(),
                nn.Conv2d(ch_in, out_channels, kernel_size=3, stride=1, padding=1),
            )

        # Construction order kept identical (seeded init, state_dict keys).
        self.resnet_conv_first = nn.ModuleList([conv_unit(ch) for ch in widths_in])
        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                for _ in range(num_layers)
            ])
        self.resnet_conv_second = nn.ModuleList([conv_unit(out_channels) for _ in range(num_layers)])
        if self.attn:
            self.attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
            self.attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers)])
        self.residual_input_conv = nn.ModuleList(
            [nn.Conv2d(ch, out_channels, kernel_size=1) for ch in widths_in])
        if self.up_sample:
            self.up_sample_conv = nn.ConvTranspose2d(in_channels, in_channels, 4, 2, 1)
        else:
            self.up_sample_conv = nn.Identity()

    def _resnet_step(self, i, h, t_emb):
        """Apply ResNet layer ``i`` (with optional time embedding) plus its 1x1 skip."""
        y = self.resnet_conv_first[i](h)
        if self.t_emb_dim is not None:
            y = y + self.t_emb_layers[i](t_emb)[:, :, None, None]
        y = self.resnet_conv_second[i](y)
        return y + self.residual_input_conv[i](h)

    def _self_attention_step(self, i, h):
        """Residual self-attention layer ``i`` over the H*W positions of ``h``."""
        b, c, height, width = h.shape
        tokens = self.attention_norms[i](h.reshape(b, c, height * width)).transpose(1, 2)
        attended, _ = self.attentions[i](tokens, tokens, tokens)
        return h + attended.transpose(1, 2).reshape(b, c, height, width)

    def forward(self, x, out_down=None, t_emb=None):
        x = self.up_sample_conv(x)
        if out_down is not None:
            # skip connection from the encoder, joined along channels
            x = torch.cat([x, out_down], dim=1)
        out = x
        for i in range(self.num_layers):
            out = self._resnet_step(i, out, t_emb)
            if self.attn:
                out = self._self_attention_step(i, out)
        return out
class UpBlockUnet(nn.Module):
    r"""
    UNet up block with self-attention and optional cross-attention.

    Sequence:
      1. Upsample ``x`` (which has ``in_channels // 2`` channels) with a
         stride-2 4x4 transposed conv (if ``up_sample``),
      2. Concatenate the matching down-block output along channels (if given),
      3. ``num_layers`` x (ResNet block with optional time embedding,
         self-attention, cross-attention if ``cross_attn``).

    :param in_channels: channels after concatenation (entering the first ResNet)
    :param out_channels: channels produced by every layer
    :param t_emb_dim: time-embedding size, or None to disable time conditioning
    :param up_sample: whether to double the spatial resolution first
    :param num_heads: attention heads
    :param num_layers: number of layers
    :param norm_channels: number of GroupNorm groups
    :param cross_attn: enable cross-attention
    :param context_dim: last-dim size of ``context`` (required if ``cross_attn``)
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, up_sample,
                 num_heads, num_layers, norm_channels, cross_attn=False, context_dim=None):
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample
        self.t_emb_dim = t_emb_dim
        self.cross_attn = cross_attn
        self.context_dim = context_dim

        widths_in = [in_channels if i == 0 else out_channels for i in range(num_layers)]

        def conv_unit(ch_in):
            # GroupNorm -> SiLU -> 3x3 "same" conv into out_channels
            return nn.Sequential(
                nn.GroupNorm(norm_channels, ch_in),
                nn.SiLU(),
                nn.Conv2d(ch_in, out_channels, kernel_size=3, stride=1, padding=1),
            )

        # Construction order kept identical (seeded init, state_dict keys).
        self.resnet_conv_first = nn.ModuleList([conv_unit(ch) for ch in widths_in])
        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                for _ in range(num_layers)
            ])
        self.resnet_conv_second = nn.ModuleList([conv_unit(out_channels) for _ in range(num_layers)])
        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
        self.attentions = nn.ModuleList(
            [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers)])
        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
            self.cross_attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers)])
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels) for _ in range(num_layers)])
        self.residual_input_conv = nn.ModuleList(
            [nn.Conv2d(ch, out_channels, kernel_size=1) for ch in widths_in])
        if self.up_sample:
            # only the upsampled half of the concatenated channels goes through here
            self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, 4, 2, 1)
        else:
            self.up_sample_conv = nn.Identity()

    def _resnet_step(self, i, h, t_emb):
        """Apply ResNet layer ``i`` (with optional time embedding) plus its 1x1 skip."""
        y = self.resnet_conv_first[i](h)
        if self.t_emb_dim is not None:
            y = y + self.t_emb_layers[i](t_emb)[:, :, None, None]
        y = self.resnet_conv_second[i](y)
        return y + self.residual_input_conv[i](h)

    @staticmethod
    def _attention_step(h, norm, attention, memory=None):
        """Residual attention over the H*W positions of ``h``; self-attention when ``memory`` is None."""
        b, c, height, width = h.shape
        tokens = norm(h.reshape(b, c, height * width)).transpose(1, 2)
        kv = tokens if memory is None else memory
        attended, _ = attention(tokens, kv, kv)
        return h + attended.transpose(1, 2).reshape(b, c, height, width)

    def forward(self, x, out_down=None, t_emb=None, context=None):
        x = self.up_sample_conv(x)
        if out_down is not None:
            x = torch.cat([x, out_down], dim=1)
        out = x
        for i in range(self.num_layers):
            out = self._resnet_step(i, out, t_emb)
            out = self._attention_step(out, self.attention_norms[i], self.attentions[i])
            if self.cross_attn:
                assert context is not None, "context cannot be None if cross attention layers are used"
                assert len(context.shape) == 3, \
                    "Context shape does not match B,_,CONTEXT_DIM"
                assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim, \
                    "Context shape does not match B,_,CONTEXT_DIM"
                out = self._attention_step(out, self.cross_attention_norms[i], self.cross_attentions[i],
                                           memory=self.context_proj[i](context))
        return out
# ==========================================
# VQVAE Definition
# ==========================================
class VQVAE(nn.Module):
    r"""
    Vector-quantized autoencoder.

    Encoder: conv_in -> DownBlocks -> MidBlocks -> GroupNorm/SiLU/conv_out ->
    1x1 pre-quant conv -> nearest-codebook quantization.
    Decoder: 1x1 post-quant conv -> conv_in -> MidBlocks -> UpBlocks ->
    GroupNorm/SiLU/conv_out.

    :param im_channels: channels of the input / reconstructed image
    :param model_config: dict with keys 'down_channels', 'mid_channels',
        'down_sample', 'num_down_layers', 'num_mid_layers', 'num_up_layers',
        'attn_down', 'z_channels', 'codebook_size', 'norm_channels', 'num_heads'
    """

    def __init__(self, im_channels, model_config):
        super().__init__()
        self.down_channels = model_config['down_channels']
        self.mid_channels = model_config['mid_channels']
        self.down_sample = model_config['down_sample']
        self.num_down_layers = model_config['num_down_layers']
        self.num_mid_layers = model_config['num_mid_layers']
        self.num_up_layers = model_config['num_up_layers']
        # Per-stage flags enabling attention in encoder DownBlocks / decoder UpBlocks
        self.attns = model_config['attn_down']
        # Latent dimension
        self.z_channels = model_config['z_channels']
        self.codebook_size = model_config['codebook_size']
        self.norm_channels = model_config['norm_channels']
        self.num_heads = model_config['num_heads']
        # Validate the channel configuration
        assert self.mid_channels[0] == self.down_channels[-1]
        assert self.mid_channels[-1] == self.down_channels[-1]
        assert len(self.down_sample) == len(self.down_channels) - 1
        assert len(self.attns) == len(self.down_channels) - 1
        # Decoder upsamples wherever the encoder downsampled (mirrored order).
        # NOTE: the decoder loop below indexes self.down_sample directly; this
        # attribute is kept for callers that may read it.
        self.up_sample = list(reversed(self.down_sample))

        ## Encoder ##
        self.encoder_conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1))
        # Downblock + Midblock
        self.encoder_layers = nn.ModuleList([])
        for i in range(len(self.down_channels) - 1):
            self.encoder_layers.append(DownBlock(self.down_channels[i], self.down_channels[i + 1],
                                                 t_emb_dim=None, down_sample=self.down_sample[i],
                                                 num_heads=self.num_heads,
                                                 num_layers=self.num_down_layers,
                                                 attn=self.attns[i],
                                                 norm_channels=self.norm_channels))
        self.encoder_mids = nn.ModuleList([])
        for i in range(len(self.mid_channels) - 1):
            self.encoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1],
                                              t_emb_dim=None,
                                              num_heads=self.num_heads,
                                              num_layers=self.num_mid_layers,
                                              norm_channels=self.norm_channels))
        self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1])
        self.encoder_conv_out = nn.Conv2d(self.down_channels[-1], self.z_channels, kernel_size=3, padding=1)
        # Pre-quantization convolution
        self.pre_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
        # Codebook: codebook_size vectors of size z_channels
        self.embedding = nn.Embedding(self.codebook_size, self.z_channels)

        ## Decoder ##
        # Post-quantization convolution
        self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
        self.decoder_conv_in = nn.Conv2d(self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1))
        # Midblock + Upblock
        self.decoder_mids = nn.ModuleList([])
        for i in reversed(range(1, len(self.mid_channels))):
            self.decoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i - 1],
                                              t_emb_dim=None,
                                              num_heads=self.num_heads,
                                              num_layers=self.num_mid_layers,
                                              norm_channels=self.norm_channels))
        self.decoder_layers = nn.ModuleList([])
        for i in reversed(range(1, len(self.down_channels))):
            self.decoder_layers.append(UpBlock(self.down_channels[i], self.down_channels[i - 1],
                                               t_emb_dim=None, up_sample=self.down_sample[i - 1],
                                               num_heads=self.num_heads,
                                               num_layers=self.num_up_layers,
                                               attn=self.attns[i - 1],
                                               norm_channels=self.norm_channels))
        self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0])
        self.decoder_conv_out = nn.Conv2d(self.down_channels[0], im_channels, kernel_size=3, padding=1)

    def quantize(self, x):
        r"""
        Replace every spatial feature vector of ``x`` by its nearest codebook entry.

        :param x: (B, C, H, W) tensor with C == z_channels
        :return: tuple of
            - quantized tensor (B, C, H, W) with straight-through gradients to ``x``,
            - dict with 'codebook_loss' and 'commitment_loss' (scalar tensors),
            - (B, H, W) tensor of selected codebook indices
        """
        B, C, H, W = x.shape
        # (B, C, H, W) -> (B, H, W, C) -> (B, H*W, C)
        x = x.permute(0, 2, 3, 1).reshape(B, -1, C)
        # Distances (B, H*W, C) vs (B, K, C) -> (B, H*W, K). expand() is a
        # broadcast view, so the codebook is not copied once per batch item
        # (the previous .repeat() materialised B copies).
        codebook = self.embedding.weight[None, :].expand(B, -1, -1)
        dist = torch.cdist(x, codebook)
        min_encoding_indices = torch.argmin(dist, dim=-1)  # (B, H*W)
        # (B*H*W, C) nearest codebook vectors
        quant_out = torch.index_select(self.embedding.weight, 0, min_encoding_indices.view(-1))
        x = x.reshape(-1, C)  # (B*H*W, C)
        # Commitment pulls encoder outputs to the codebook; codebook loss moves
        # the codebook towards the (frozen) encoder outputs.
        commitment_loss = torch.mean((quant_out.detach() - x) ** 2)
        codebook_loss = torch.mean((quant_out - x.detach()) ** 2)
        quantize_losses = {
            'codebook_loss': codebook_loss,
            'commitment_loss': commitment_loss
        }
        # Straight-through estimator: forward uses quant_out, gradient flows to x
        quant_out = x + (quant_out - x).detach()
        quant_out = quant_out.reshape(B, H, W, C).permute(0, 3, 1, 2)
        min_encoding_indices = min_encoding_indices.reshape(-1, H, W)
        return quant_out, quantize_losses, min_encoding_indices

    def encode(self, x):
        """Encode images ``x`` to quantized latents; returns (latents, quantization losses)."""
        out = self.encoder_conv_in(x)
        for down in self.encoder_layers:
            out = down(out)
        for mid in self.encoder_mids:
            out = mid(out)
        # Functional SiLU avoids instantiating a new nn.SiLU module per call.
        out = F.silu(self.encoder_norm_out(out))
        out = self.encoder_conv_out(out)
        out = self.pre_quant_conv(out)
        out, quant_losses, _ = self.quantize(out)
        return out, quant_losses

    def decode(self, z):
        """Decode latents ``z`` (B, z_channels, h, w) back to image space."""
        out = self.post_quant_conv(z)
        out = self.decoder_conv_in(out)
        for mid in self.decoder_mids:
            out = mid(out)
        for up in self.decoder_layers:
            out = up(out)
        out = F.silu(self.decoder_norm_out(out))
        out = self.decoder_conv_out(out)
        return out

    def forward(self, x):
        """Return (reconstruction, quantized latents, quantization losses)."""
        z, quant_losses = self.encode(x)
        out = self.decode(z)
        return out, z, quant_losses
# ==========================================
# SPADE Definitions
# ==========================================
class SPADE(nn.Module):
    r"""
    Spatially-adaptive normalization (SPADE).

    Normalizes ``x`` with GroupNorm, then modulates it with a per-pixel scale
    (``gamma``) and shift (``beta``) predicted from ``segmap`` by a small conv
    net: ``out = norm(x) * (1 + gamma) + beta``.

    :param norm_nc: channels of the feature map being normalized
    :param label_nc: channels of the conditioning map ``segmap``
    :param num_groups: GroupNorm groups; ``norm_nc`` must be divisible by it (default 32)
    :param nhidden: width of the shared conv layer (default 128)
    """

    def __init__(self, norm_nc, label_nc, num_groups=32, nhidden=128):
        super().__init__()
        # NOTE: despite the name, nn.GroupNorm defaults to affine=True, so this
        # layer has a learnable weight/bias. Switching to affine=False would drop
        # param_free_norm.weight/bias from the state_dict, so it is left as-is.
        self.param_free_norm = nn.GroupNorm(num_groups, norm_nc)
        # Convolutions that generate the modulation parameters from the mask
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)

    def forward(self, x, segmap):
        """
        :param x: (B, norm_nc, H, W) feature map
        :param segmap: (B, label_nc, h, w) conditioning map; resized to (H, W)
            with nearest-neighbour interpolation when the sizes differ
        :raises ValueError: if ``segmap`` is None
        """
        if segmap is None:
            raise ValueError("SPADE requires a segmentation map (segmap), got None")
        # 1. Normalize
        normalized = self.param_free_norm(x)
        # 2. Resize mask to match x's resolution
        if segmap.size()[2:] != x.size()[2:]:
            segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
        # 3. Predict modulation parameters
        actv = self.mlp_shared(segmap)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)
        # 4. Modulate (1 + gamma keeps the identity when gamma is zero)
        return normalized * (1 + gamma) + beta
class SPADEResnetBlock(nn.Module):
    """
    SPADE-normalized conv unit: SPADE(x, segmap) -> SiLU -> 3x3 conv.

    Despite the name there is no internal shortcut. The enclosing block adds
    the residual connection itself.
    """

    def __init__(self, in_channels, out_channels, label_nc):
        super().__init__()
        # mask-conditioned normalization over the input channels
        self.norm1 = SPADE(in_channels, label_nc)
        self.act1 = nn.SiLU()
        # 3x3 "same" conv from in_channels to out_channels
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x, segmap):
        modulated = self.norm1(x, segmap)
        return self.conv1(self.act1(modulated))
# ==========================================
# SPADE-conditioned BLOCKS (Down, Mid, Up)
# ==========================================
def get_time_embedding(time_steps, temb_dim):
    """
    Sinusoidal embedding of a batch of time steps.

    Returns ``[sin(t / f_0), ..., sin(t / f_{k-1}), cos(t / f_0), ..., cos(t / f_{k-1})]``
    per row, with ``k = temb_dim // 2`` and ``f_i = 10000 ** (i / k)``.

    :param time_steps: 1D tensor of B time steps
    :param temb_dim: even embedding size D
    :return: (B, D) tensor
    """
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
    n_freq = temb_dim // 2
    freq_idx = torch.arange(start=0, end=n_freq, dtype=torch.float32, device=time_steps.device)
    denom = 10000 ** (freq_idx / n_freq)
    phase = time_steps[:, None].repeat(1, n_freq) / denom
    sin_part = torch.sin(phase)
    cos_part = torch.cos(phase)
    return torch.cat([sin_part, cos_part], dim=-1)
class SpadeDownBlock(nn.Module):
    r"""
    Down block whose ResNet convolutions are SPADE-modulated by ``segmap``.

    Each of the ``num_layers`` layers runs:
      1. SPADEResnetBlock(., segmap) -> (+ projected time embedding) ->
         SPADEResnetBlock(., segmap), plus a 1x1-conv skip from the layer input,
      2. self-attention over the H*W positions (if ``attn``),
      3. cross-attention against ``context`` (if ``cross_attn``).
    A stride-2 4x4 conv follows when ``down_sample`` is truthy.

    :param in_channels: channels of the block input
    :param out_channels: channels produced by every layer
    :param t_emb_dim: time-embedding size, or None to disable time conditioning
    :param down_sample: whether to halve the spatial resolution at the end
    :param num_heads: attention heads
    :param num_layers: number of layers
    :param attn: enable self-attention
    :param norm_channels: GroupNorm groups for the attention norms
    :param cross_attn: enable cross-attention
    :param context_dim: last-dim size of ``context`` (required if ``cross_attn``)
    :param label_nc: channels of ``segmap`` (default 4)
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, down_sample, num_heads,
                 num_layers, attn, norm_channels, cross_attn=False, context_dim=None, label_nc=4):
        super().__init__()
        self.num_layers = num_layers
        self.down_sample = down_sample
        self.attn = attn
        self.context_dim = context_dim
        self.cross_attn = cross_attn
        self.t_emb_dim = t_emb_dim
        # SPADE-modulated conv units replace the plain GroupNorm/SiLU/conv stacks
        self.resnet_conv_first = nn.ModuleList([
            SPADEResnetBlock(in_channels if i == 0 else out_channels, out_channels, label_nc)
            for i in range(num_layers)
        ])
        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, out_channels))
                for _ in range(num_layers)
            ])
        self.resnet_conv_second = nn.ModuleList([
            SPADEResnetBlock(out_channels, out_channels, label_nc)
            for _ in range(num_layers)
        ])
        if self.attn:
            self.attention_norms = nn.ModuleList([nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
            self.attentions = nn.ModuleList([nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers)])
        if self.cross_attn:
            # Fail early with a clear message (as DownBlock does) rather than a
            # TypeError from nn.Linear(None, ...).
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList([nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
            self.cross_attentions = nn.ModuleList([nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers)])
            self.context_proj = nn.ModuleList([nn.Linear(context_dim, out_channels) for _ in range(num_layers)])
        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
            for i in range(num_layers)
        ])
        self.down_sample_conv = nn.Conv2d(out_channels, out_channels, 4, 2, 1) if self.down_sample else nn.Identity()

    def forward(self, x, t_emb=None, context=None, segmap=None):
        r"""
        :param x: (B, in_channels, H, W) feature map
        :param t_emb: (B, t_emb_dim) time embedding; required when t_emb_dim is not None
        :param context: conditioning tokens, expected (B, L, context_dim); required when cross_attn
        :param segmap: mask with label_nc channels, passed to every SPADE layer
        """
        out = x
        for i in range(self.num_layers):
            resnet_input = out
            out = self.resnet_conv_first[i](out, segmap)
            if self.t_emb_dim is not None:
                out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[i](out, segmap)
            # SPADEResnetBlock has no shortcut of its own, so this 1x1-projected
            # skip from the layer input is the only residual connection here.
            out = out + self.residual_input_conv[i](resnet_input)
            if self.attn:
                batch_size, channels, h, w = out.shape
                in_attn = out.reshape(batch_size, channels, h * w)
                in_attn = self.attention_norms[i](in_attn)
                in_attn = in_attn.transpose(1, 2)
                out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
                out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
                out = out + out_attn
            if self.cross_attn:
                assert context is not None, "context cannot be None if cross attention layers are used"
                assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim, \
                    "Context shape does not match B,_,CONTEXT_DIM"
                batch_size, channels, h, w = out.shape
                in_attn = out.reshape(batch_size, channels, h * w)
                in_attn = self.cross_attention_norms[i](in_attn)
                in_attn = in_attn.transpose(1, 2)
                context_proj = self.context_proj[i](context)
                out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)
                out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
                out = out + out_attn
        out = self.down_sample_conv(out)
        return out
class SpadeMidBlock(nn.Module):
    r"""
    Mid block whose ResNet convolutions are SPADE-modulated by ``segmap``.

    Layout: SPADE-ResNet_0, then for each of the ``num_layers`` layers
    self-attention -> (cross-attention if ``cross_attn``) -> SPADE-ResNet_{i+1}.
    Each SPADE-ResNet is SPADEResnetBlock -> (+ time embedding) ->
    SPADEResnetBlock plus a 1x1-conv skip.

    :param in_channels: channels of the block input
    :param out_channels: channels produced by every ResNet block
    :param t_emb_dim: time-embedding size, or None to disable time conditioning
    :param num_heads: attention heads
    :param num_layers: number of attention layers
    :param norm_channels: GroupNorm groups for the attention norms
    :param cross_attn: enable cross-attention
    :param context_dim: last-dim size of ``context`` (required if ``cross_attn``)
    :param label_nc: channels of ``segmap`` (default 4)
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_channels, cross_attn=None, context_dim=None, label_nc=4):
        super().__init__()
        self.num_layers = num_layers
        self.t_emb_dim = t_emb_dim
        self.context_dim = context_dim
        self.cross_attn = cross_attn
        self.resnet_conv_first = nn.ModuleList([
            SPADEResnetBlock(in_channels if i == 0 else out_channels, out_channels, label_nc)
            for i in range(num_layers + 1)
        ])
        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                for _ in range(num_layers + 1)
            ])
        self.resnet_conv_second = nn.ModuleList([
            SPADEResnetBlock(out_channels, out_channels, label_nc)
            for _ in range(num_layers + 1)
        ])
        self.attention_norms = nn.ModuleList([nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
        self.attentions = nn.ModuleList([nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers)])
        if self.cross_attn:
            # Fail early with a clear message (as MidBlock does).
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList([nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
            self.cross_attentions = nn.ModuleList([nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers)])
            self.context_proj = nn.ModuleList([nn.Linear(context_dim, out_channels) for _ in range(num_layers)])
        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
            for i in range(num_layers + 1)
        ])

    def forward(self, x, t_emb=None, context=None, segmap=None):
        r"""
        :param x: (B, in_channels, H, W) feature map
        :param t_emb: (B, t_emb_dim) time embedding; required when t_emb_dim is not None
        :param context: conditioning tokens, expected (B, L, context_dim); required when cross_attn
        :param segmap: mask with label_nc channels, passed to every SPADE layer
        """
        out = x
        # First SPADE-ResNet block (no attention before it)
        resnet_input = out
        out = self.resnet_conv_first[0](out, segmap)
        if self.t_emb_dim is not None:
            out = out + self.t_emb_layers[0](t_emb)[:, :, None, None]
        out = self.resnet_conv_second[0](out, segmap)
        out = out + self.residual_input_conv[0](resnet_input)
        for i in range(self.num_layers):
            # Self-attention
            batch_size, channels, h, w = out.shape
            in_attn = out.reshape(batch_size, channels, h * w)
            in_attn = self.attention_norms[i](in_attn)
            in_attn = in_attn.transpose(1, 2)
            out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
            out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
            out = out + out_attn
            if self.cross_attn:
                assert context is not None, "context cannot be None if cross attention layers are used"
                assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim, \
                    "Context shape does not match B,_,CONTEXT_DIM"
                batch_size, channels, h, w = out.shape
                in_attn = out.reshape(batch_size, channels, h * w)
                in_attn = self.cross_attention_norms[i](in_attn)
                in_attn = in_attn.transpose(1, 2)
                context_proj = self.context_proj[i](context)
                out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)
                out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
                out = out + out_attn
            # Next SPADE-ResNet block; SPADEResnetBlock has no shortcut of its
            # own, so the 1x1 skip below is the only residual path.
            resnet_input = out
            out = self.resnet_conv_first[i + 1](out, segmap)
            if self.t_emb_dim is not None:
                out = out + self.t_emb_layers[i + 1](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[i + 1](out, segmap)
            out = out + self.residual_input_conv[i + 1](resnet_input)
        return out
class SpadeUpBlock(nn.Module):
    r"""
    UNet up block whose ResNet convolutions are SPADE-modulated by ``segmap``.

    Sequence:
      1. Upsample ``x`` (``in_channels // 2`` channels) with a stride-2 4x4
         transposed conv (if ``up_sample``),
      2. Concatenate the matching down-block output along channels (if given),
      3. ``num_layers`` x (SPADE-ResNet with optional time embedding,
         self-attention, cross-attention if ``cross_attn``).

    :param in_channels: channels after concatenation (entering the first ResNet)
    :param out_channels: channels produced by every layer
    :param t_emb_dim: time-embedding size, or None to disable time conditioning
    :param up_sample: whether to double the spatial resolution first
    :param num_heads: attention heads
    :param num_layers: number of layers
    :param norm_channels: GroupNorm groups for the attention norms
    :param cross_attn: enable cross-attention
    :param context_dim: last-dim size of ``context`` (required if ``cross_attn``)
    :param label_nc: channels of ``segmap`` (default 4)
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, up_sample, num_heads,
                 num_layers, norm_channels, cross_attn=False, context_dim=None, label_nc=4):
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample
        self.t_emb_dim = t_emb_dim
        self.cross_attn = cross_attn
        self.context_dim = context_dim
        self.resnet_conv_first = nn.ModuleList([
            SPADEResnetBlock(in_channels if i == 0 else out_channels, out_channels, label_nc)
            for i in range(num_layers)
        ])
        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                for _ in range(num_layers)
            ])
        self.resnet_conv_second = nn.ModuleList([
            SPADEResnetBlock(out_channels, out_channels, label_nc)
            for _ in range(num_layers)
        ])
        self.attention_norms = nn.ModuleList([nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
        self.attentions = nn.ModuleList([nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers)])
        if self.cross_attn:
            # Fail early with a clear message (as UpBlockUnet does).
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList([nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
            self.cross_attentions = nn.ModuleList([nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers)])
            self.context_proj = nn.ModuleList([nn.Linear(context_dim, out_channels) for _ in range(num_layers)])
        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
            for i in range(num_layers)
        ])
        # Only the upsampled half of the concatenated channels passes through here.
        self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, 4, 2, 1) if self.up_sample else nn.Identity()

    def forward(self, x, out_down=None, t_emb=None, context=None, segmap=None):
        r"""
        :param x: (B, in_channels // 2, H, W) feature map from the previous stage
        :param out_down: skip tensor concatenated along channels after upsampling
        :param t_emb: (B, t_emb_dim) time embedding; required when t_emb_dim is not None
        :param context: conditioning tokens, expected (B, L, context_dim); required when cross_attn
        :param segmap: mask with label_nc channels, passed to every SPADE layer
        """
        x = self.up_sample_conv(x)
        if out_down is not None:
            x = torch.cat([x, out_down], dim=1)
        out = x
        for i in range(self.num_layers):
            resnet_input = out
            out = self.resnet_conv_first[i](out, segmap)
            if self.t_emb_dim is not None:
                out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[i](out, segmap)
            # SPADEResnetBlock has no shortcut of its own; this is the residual path.
            out = out + self.residual_input_conv[i](resnet_input)
            # Self-attention
            batch_size, channels, h, w = out.shape
            in_attn = out.reshape(batch_size, channels, h * w)
            in_attn = self.attention_norms[i](in_attn)
            in_attn = in_attn.transpose(1, 2)
            out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
            out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
            out = out + out_attn
            if self.cross_attn:
                assert context is not None, "context cannot be None if cross attention layers are used"
                assert len(context.shape) == 3, \
                    "Context shape does not match B,_,CONTEXT_DIM"
                assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim, \
                    "Context shape does not match B,_,CONTEXT_DIM"
                batch_size, channels, h, w = out.shape
                in_attn = out.reshape(batch_size, channels, h * w)
                in_attn = self.cross_attention_norms[i](in_attn)
                in_attn = in_attn.transpose(1, 2)
                context_proj = self.context_proj[i](context)
                out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)
                out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
                out = out + out_attn
        return out
# ==========================================
# Helper Functions
# ==========================================
def validate_image_config(condition_config):
    """
    Check that a conditioning config describes image conditioning completely.

    :param condition_config: dict-like conditioning configuration
    :raises AssertionError: if 'image_condition_config' is absent, or it lacks
        'image_condition_input_channels' or 'image_condition_output_channels'
    """
    assert 'image_condition_config' in condition_config, "Image conditioning desired but config missing"
    image_cfg = condition_config['image_condition_config']
    assert 'image_condition_input_channels' in image_cfg, "Input channels missing"
    assert 'image_condition_output_channels' in image_cfg, "Output channels missing"
def validate_image_conditional_input(cond_input, x):
    """
    Check that ``cond_input`` carries an 'image' entry batched like ``x``.

    :param cond_input: dict-like conditioning inputs
    :param x: model input; only ``x.shape[0]`` (batch size) is read
    :raises AssertionError: if 'image' is missing or its batch size differs
    """
    assert 'image' in cond_input, "Model initialized with image conditioning but input missing"
    image = cond_input['image']
    assert image.shape[0] == x.shape[0], "Batch size mismatch"
def get_config_value(config, key, default_value):
    """
    Return ``config[key]`` if ``key`` is present, otherwise ``default_value``.

    Uses membership plus indexing (not ``.get``), so any object supporting
    ``in`` and ``[]`` works. A stored ``None`` is returned as-is.
    """
    if key in config:
        return config[key]
    return default_value
def get_time_embedding(time_steps, temb_dim):
    """
    Build (B, temb_dim) sinusoidal embeddings for B time steps.

    With ``half = temb_dim // 2`` and ``scale_i = 10000 ** (i / half)``, each
    row is ``sin(t / scale_i)`` for all i followed by ``cos(t / scale_i)``.

    :param time_steps: 1D tensor of time steps
    :param temb_dim: even embedding size
    :return: (B, temb_dim) tensor
    :raises AssertionError: if ``temb_dim`` is odd
    """
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
    half = temb_dim // 2
    scale = 10000 ** (
        torch.arange(start=0, end=half, dtype=torch.float32, device=time_steps.device) / half
    )
    args = time_steps[:, None].repeat(1, half) / scale
    pieces = [torch.sin(args), torch.cos(args)]
    return torch.cat(pieces, dim=-1)
def drop_image_condition(image_condition, im, im_drop_prob):
    r"""
    Randomly zero out the conditioning tensor for a subset of the batch.

    One uniform draw in [0, 1) is made per sample of ``im``. A sample keeps its
    conditioning when its draw is strictly greater than ``im_drop_prob``. The
    mask has shape (B, 1, 1, 1), so ``image_condition`` is presumably 4-D
    (B, C, H, W) — TODO confirm.

    :param image_condition: conditioning tensor to mask
    :param im: tensor whose batch size (dim 0) and device define the mask
    :param im_drop_prob: drop probability; when it is not ``> 0`` (including NaN)
        ``image_condition`` is returned unchanged and no random numbers are drawn
    :return: ``image_condition`` multiplied by the boolean keep-mask
    """
    if not (im_drop_prob > 0):
        return image_condition
    batch = im.shape[0]
    draws = torch.zeros((batch, 1, 1, 1), device=im.device).float().uniform_(0, 1)
    keep_mask = draws > im_drop_prob
    return image_condition * keep_mask
# ==========================================
# UNET Definition
# ==========================================
class Unet(nn.Module):
    r"""
    UNet with SPADE integration for anatomical consistency.

    The conditioning mask is *not* concatenated to the input. It is forwarded as
    ``segmap`` to every ``SpadeDownBlock``, ``SpadeMidBlock`` and ``SpadeUpBlock``
    (defined elsewhere in this file). Each of those receives
    ``label_nc=self.im_cond_input_ch``.

    :param im_channels: channel count of the input ``x`` and of the output
    :param model_config: dict with keys ``down_channels``, ``mid_channels``,
        ``time_emb_dim``, ``down_sample``, ``num_down_layers``, ``num_mid_layers``,
        ``num_up_layers``, ``attn_down``, ``norm_channels``, ``num_heads``,
        ``conv_out_channels`` and optionally ``condition_config``
    """

    def __init__(self, im_channels, model_config):
        super().__init__()
        self.down_channels = model_config['down_channels']
        self.mid_channels = model_config['mid_channels']
        self.t_emb_dim = model_config['time_emb_dim']
        self.down_sample = model_config['down_sample']
        self.num_down_layers = model_config['num_down_layers']
        self.num_mid_layers = model_config['num_mid_layers']
        self.num_up_layers = model_config['num_up_layers']
        self.attns = model_config['attn_down']
        self.norm_channels = model_config['norm_channels']
        self.num_heads = model_config['num_heads']
        self.conv_out_channels = model_config['conv_out_channels']

        # Validate config: the mid stack starts at the deepest down width and
        # ends at the second-deepest one, where the up path begins.
        assert self.mid_channels[0] == self.down_channels[-1]
        assert self.mid_channels[-1] == self.down_channels[-2]
        assert len(self.down_sample) == len(self.down_channels) - 1
        assert len(self.attns) == len(self.down_channels) - 1

        # Conditioning setup
        self.image_cond = False
        self.condition_config = get_config_value(model_config, 'condition_config', None)
        # Default mask channels (usually 4: BG, LV, Myo, RV); used as SPADE label_nc
        self.im_cond_input_ch = 4
        if self.condition_config is not None:
            if 'image' in self.condition_config.get('condition_types', []):
                # Fail with a descriptive assertion instead of a bare KeyError below.
                validate_image_config(self.condition_config)
                self.image_cond = True
                image_cfg = self.condition_config['image_condition_config']
                self.im_cond_input_ch = image_cfg['image_condition_input_channels']
                self.im_cond_output_ch = image_cfg['image_condition_output_channels']

        # Standard input conv: SPADE injects the mask later, so only the
        # (latent) input is convolved here.
        self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=1)

        # Time-embedding projection MLP
        self.t_proj = nn.Sequential(
            nn.Linear(self.t_emb_dim, self.t_emb_dim),
            nn.SiLU(),
            nn.Linear(self.t_emb_dim, self.t_emb_dim),
        )

        # Not referenced in this class; the up blocks index self.down_sample directly.
        self.up_sample = list(reversed(self.down_sample))

        self.downs = nn.ModuleList([])
        for i in range(len(self.down_channels) - 1):
            self.downs.append(SpadeDownBlock(
                self.down_channels[i], self.down_channels[i + 1], self.t_emb_dim,
                down_sample=self.down_sample[i], num_heads=self.num_heads,
                num_layers=self.num_down_layers, attn=self.attns[i],
                norm_channels=self.norm_channels,
                label_nc=self.im_cond_input_ch,  # SPADE needs this
            ))

        self.mids = nn.ModuleList([])
        for i in range(len(self.mid_channels) - 1):
            self.mids.append(SpadeMidBlock(
                self.mid_channels[i], self.mid_channels[i + 1], self.t_emb_dim,
                num_heads=self.num_heads, num_layers=self.num_mid_layers,
                norm_channels=self.norm_channels,
                label_nc=self.im_cond_input_ch,  # SPADE needs this
            ))

        self.ups = nn.ModuleList([])
        for i in reversed(range(len(self.down_channels) - 1)):
            # in_channels is doubled because each up block also receives a skip
            # tensor (presumably concatenated inside SpadeUpBlock).
            self.ups.append(SpadeUpBlock(
                self.down_channels[i] * 2,
                self.down_channels[i - 1] if i != 0 else self.conv_out_channels,
                self.t_emb_dim, up_sample=self.down_sample[i], num_heads=self.num_heads,
                num_layers=self.num_up_layers, norm_channels=self.norm_channels,
                label_nc=self.im_cond_input_ch,  # SPADE needs this
            ))

        self.norm_out = nn.GroupNorm(self.norm_channels, self.conv_out_channels)
        self.conv_out = nn.Conv2d(self.conv_out_channels, im_channels, kernel_size=3, padding=1)

    def forward(self, x, t, cond_input=None):
        r"""
        Run the UNet on ``x`` at timestep(s) ``t``.

        :param x: input tensor, presumably (B, im_channels, H, W) latents — TODO confirm
        :param t: timestep(s): a python int, a 0-d tensor, or a length-B tensor
            (cast to long). A single timestep is broadcast over the batch.
        :param cond_input: dict holding ``'image'`` (the segmentation map) when the
            model was built with image conditioning; otherwise ignored
        :return: output of ``conv_out`` with ``im_channels`` channels
        :raises AssertionError: if image conditioning is enabled and ``cond_input``
            is missing/invalid
        """
        # 1. Validation. The mask is not concatenated; it is passed to every block.
        if self.image_cond:
            validate_image_conditional_input(cond_input, x)
            im_cond = cond_input['image']
        else:
            im_cond = None

        # 2. Initial conv
        out = self.conv_in(x)

        # 3. Time embedding. Build the timesteps on x's device so a python int or
        # CPU tensor does not collide with t_proj's parameters; expand a single
        # timestep to the batch (a 0-d tensor would otherwise crash when embedded).
        t = torch.as_tensor(t, device=x.device).long().reshape(-1)
        if t.numel() == 1:
            t = t.expand(x.shape[0])
        t_emb = self.t_proj(get_time_embedding(t, self.t_emb_dim))

        # 4. Down blocks: remember each block's input as the skip for the up path
        down_outs = []
        for down in self.downs:
            down_outs.append(out)
            out = down(out, t_emb, segmap=im_cond)

        # 5. Mid blocks
        for mid in self.mids:
            out = mid(out, t_emb, segmap=im_cond)

        # 6. Up blocks consume skips in reverse (LIFO) order
        for up in self.ups:
            out = up(out, down_outs.pop(), t_emb, segmap=im_cond)

        out = self.norm_out(out)
        out = F.silu(out)  # functional form: no module allocated per call
        return self.conv_out(out)
# ==========================================
# Noise Scheduler Definition
# ==========================================
class LinearNoiseScheduler:
    r"""
    DDPM noise scheduler whose betas are linear in sqrt-space.

    ``betas = linspace(sqrt(beta_start), sqrt(beta_end), num_timesteps) ** 2``.
    Cumulative products and their square roots are precomputed on the CPU and
    moved to the data's device on each call.

    :param num_timesteps: number of diffusion steps
    :param beta_start: first beta value
    :param beta_end: last beta value
    """

    def __init__(self, num_timesteps, beta_start, beta_end):
        self.num_timesteps = num_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_timesteps) ** 2
        self.alphas = 1. - self.betas
        self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
        self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)

    @staticmethod
    def _gather(values, t, x):
        r"""
        Index the schedule ``values`` at ``t`` and reshape to (B, 1, ..., 1) for ``x``.

        ``t`` may be a python int, a 0-d tensor, or a length-B tensor (any dtype;
        cast to long). A single timestep is broadcast to the batch; any other
        length that differs from ``x.shape[0]`` raises in ``reshape``.
        """
        idx = torch.as_tensor(t, device=x.device).long()
        out = values.to(x.device)[idx].reshape(-1)
        if out.numel() == 1:
            out = out.expand(x.shape[0])
        return out.reshape(x.shape[0], *([1] * (x.dim() - 1)))

    def add_noise(self, original, noise, t):
        r"""
        Forward diffusion: ``sqrt(alpha_bar_t) * original + sqrt(1 - alpha_bar_t) * noise``.

        :param original: clean batch (dim 0 is the batch)
        :param noise: noise tensor, same shape as ``original``
        :param t: timestep(s): python int, 0-d tensor, or length-B tensor
        :return: noised batch with the shape of ``original``
        """
        sqrt_alpha_bar = self._gather(self.sqrt_alpha_cum_prod, t, original)
        sqrt_one_minus_alpha_bar = self._gather(self.sqrt_one_minus_alpha_cum_prod, t, original)
        return sqrt_alpha_bar * original + sqrt_one_minus_alpha_bar * noise

    def sample_prev_timestep(self, xt, noise_pred, t):
        r"""
        Reverse diffusion: remove noise from ``xt`` to sample x_{t-1}.

        :param xt: noisy batch at timestep(s) ``t``
        :param noise_pred: predicted noise, same shape as ``xt``
        :param t: timestep(s): python int, 0-d tensor, or length-B tensor. Mixed
            timesteps are handled per sample.
        :return: ``(x_prev, x0)``, where ``x0`` is the x0 estimate clamped to
            [-1, 1]. Samples at ``t == 0`` get the deterministic mean with no
            added noise. If every sample is at 0, no random numbers are drawn.
        """
        t_idx = torch.as_tensor(t, device=xt.device).long().reshape(-1)

        sqrt_one_minus_alpha_bar = self._gather(self.sqrt_one_minus_alpha_cum_prod, t_idx, xt)
        sqrt_alpha_bar = self._gather(self.sqrt_alpha_cum_prod, t_idx, xt)
        beta_t = self._gather(self.betas, t_idx, xt)
        alpha_t = self._gather(self.alphas, t_idx, xt)

        # 1. Estimate x0 (original image)
        x0 = (xt - sqrt_one_minus_alpha_bar * noise_pred) / sqrt_alpha_bar
        x0 = torch.clamp(x0, -1., 1.)

        # 2. Mean of x_{t-1}
        mean = (xt - (beta_t * noise_pred) / sqrt_one_minus_alpha_bar) / torch.sqrt(alpha_t)

        # 3. Add noise, except for samples already at the last step.
        is_last = t_idx == 0
        if bool(is_last.all()):
            return mean, x0

        # Clamp t-1 so t == 0 entries of a mixed batch do not wrap to index -1;
        # their sigma is zeroed by the mask below.
        prev_idx = (t_idx - 1).clamp(min=0)
        alpha_bar_prev = self._gather(self.alpha_cum_prod, prev_idx, xt)
        alpha_bar_t = self._gather(self.alpha_cum_prod, t_idx, xt)
        variance = (1 - alpha_bar_prev) / (1.0 - alpha_bar_t) * beta_t
        add_noise_mask = (~is_last).reshape(-1, *([1] * (xt.dim() - 1)))
        # torch.where (not multiply) so a NaN variance at t == 0 cannot leak through.
        sigma = torch.where(add_noise_mask, variance.sqrt(), torch.zeros_like(variance))
        z = torch.randn_like(xt)
        return mean + sigma * z, x0