""" Conditional U-Net Architecture for Diffusion Model """ import torch import torch.nn as nn import torch.nn.functional as F import math class TimeEmbedding(nn.Module): def __init__(self, dim): super().__init__() self.dim = dim def forward(self, time): device = time.device half_dim = self.dim // 2 embeddings = math.log(10000) / (half_dim - 1) embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) embeddings = time[:, None] * embeddings[None, :] return torch.cat([torch.sin(embeddings), torch.cos(embeddings)], dim=-1) class LabelEmbedding(nn.Module): def __init__(self, label_dim, emb_dim): super().__init__() self.mlp = nn.Sequential( nn.Linear(label_dim, emb_dim), nn.SiLU(), nn.Linear(emb_dim, emb_dim) ) def forward(self, labels): return self.mlp(labels) class ResidualBlock(nn.Module): def __init__(self, in_channels, out_channels, time_emb_dim, dropout=0.1): super().__init__() self.conv1 = nn.Sequential( nn.GroupNorm(8, in_channels), nn.SiLU(), nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) ) self.time_emb = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, out_channels)) self.conv2 = nn.Sequential( nn.GroupNorm(8, out_channels), nn.SiLU(), nn.Dropout(dropout), nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) ) self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity() def forward(self, x, emb): h = self.conv1(x) h = h + self.time_emb(emb)[:, :, None, None] h = self.conv2(h) return h + self.shortcut(x) class AttentionBlock(nn.Module): def __init__(self, channels, num_heads=4): super().__init__() self.channels = channels self.num_heads = num_heads self.norm = nn.GroupNorm(8, channels) self.qkv = nn.Conv2d(channels, channels * 3, kernel_size=1) self.proj = nn.Conv2d(channels, channels, kernel_size=1) def forward(self, x): B, C, H, W = x.shape h = self.norm(x) q, k, v = self.qkv(h).chunk(3, dim=1) head_dim = C // self.num_heads q = q.view(B, self.num_heads, head_dim, H*W).transpose(2, 3) k = k.view(B, self.num_heads, head_dim, H*W).transpose(2, 3) v = v.view(B, self.num_heads, head_dim, H*W).transpose(2, 3) h = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0) h = h.transpose(2, 3).reshape(B, C, H, W) return x + self.proj(h) class ConditionalUNet(nn.Module): def __init__(self, in_channels=1, out_channels=1, label_dim=2, base_channels=64, channel_multipliers=(1,2,4,8), attention_levels=(2,3), dropout=0.1, time_emb_dim=256, label_emb_dim=256): super().__init__() self.label_dim = label_dim self.time_embedding = TimeEmbedding(time_emb_dim) self.time_mlp = nn.Sequential(nn.Linear(time_emb_dim, time_emb_dim*4), nn.SiLU(), nn.Linear(time_emb_dim*4, time_emb_dim)) self.label_embedding = LabelEmbedding(label_dim, label_emb_dim) self.combined_emb_dim = time_emb_dim + label_emb_dim self.combined_mlp = nn.Sequential(nn.Linear(self.combined_emb_dim, time_emb_dim*4), nn.SiLU(), nn.Linear(time_emb_dim*4, time_emb_dim)) self.conv_in = nn.Conv2d(in_channels, base_channels, kernel_size=3, padding=1) self.down_blocks = nn.ModuleList() channels = [base_channels] now_channels = base_channels for i, mult in enumerate(channel_multipliers): out_ch = base_channels * mult for _ in range(2): self.down_blocks.append(ResidualBlock(now_channels, out_ch, time_emb_dim, dropout)) if i in attention_levels: self.down_blocks.append(AttentionBlock(out_ch)) now_channels = out_ch channels.append(now_channels) if i != len(channel_multipliers) - 1: self.down_blocks.append(nn.Conv2d(now_channels, now_channels, kernel_size=3, stride=2, padding=1)) channels.append(now_channels) self.middle = nn.ModuleList([ ResidualBlock(now_channels, now_channels, time_emb_dim, dropout), AttentionBlock(now_channels), ResidualBlock(now_channels, now_channels, time_emb_dim, dropout) ]) self.up_blocks = nn.ModuleList() for i, mult in reversed(list(enumerate(channel_multipliers))): out_ch = base_channels * mult for _ in range(3): self.up_blocks.append(ResidualBlock(now_channels + channels.pop(), out_ch, time_emb_dim, dropout)) if i in attention_levels: self.up_blocks.append(AttentionBlock(out_ch)) now_channels = out_ch if i != 0: self.up_blocks.append(nn.ConvTranspose2d(now_channels, now_channels, kernel_size=4, stride=2, padding=1)) self.conv_out = nn.Sequential( nn.GroupNorm(8, now_channels), nn.SiLU(), nn.Conv2d(now_channels, out_channels, kernel_size=3, padding=1) ) def forward(self, x, t, labels=None): t_emb = self.time_embedding(t) t_emb = self.time_mlp(t_emb) if labels is not None: label_emb = self.label_embedding(labels) combined = torch.cat([t_emb, label_emb], dim=-1) emb = self.combined_mlp(combined) else: emb = t_emb h = self.conv_in(x) skips = [h] for module in self.down_blocks: if isinstance(module, ResidualBlock): h = module(h, emb) skips.append(h) elif isinstance(module, AttentionBlock): h = module(h) else: h = module(h) skips.append(h) for module in self.middle: if isinstance(module, ResidualBlock): h = module(h, emb) else: h = module(h) for module in self.up_blocks: if isinstance(module, ResidualBlock): h = torch.cat([h, skips.pop()], dim=1) h = module(h, emb) elif isinstance(module, AttentionBlock): h = module(h) else: h = module(h) return self.conv_out(h)