File size: 6,643 Bytes
c46900a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 | """
Conditional U-Net Architecture for Diffusion Model
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class TimeEmbedding(nn.Module):
    """Sinusoidal embedding of diffusion timesteps.

    Every timestep is multiplied by a geometric series of frequencies that
    starts at 1 and decays towards 1/10000; the sines and cosines of those
    products are concatenated along the last dimension.

    Args:
        dim: Width of the embedding returned by ``forward``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Embed a batch of timesteps.

        Args:
            time: Tensor indexed as ``time[:, None]``, i.e. one timestep per
                batch element.

        Returns:
            Tensor of shape ``(len(time), dim)`` laid out as
            ``[sin terms | cos terms]``. When ``dim`` is odd a single zero
            column is appended so the width is still exactly ``dim``.
        """
        half_dim = self.dim // 2
        # Clamp the divisor: with dim in (2, 3) half_dim - 1 is 0 and the
        # unclamped division raises ZeroDivisionError.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        frequencies = torch.exp(
            torch.arange(half_dim, device=time.device) * -log_step
        )
        angles = time[:, None] * frequencies[None, :]
        embedding = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        if self.dim % 2 == 1:
            # sin/cos only fill 2 * half_dim columns; pad so that layers sized
            # with `dim` input features receive the width they expect.
            embedding = F.pad(embedding, (0, 1))
        return embedding
class LabelEmbedding(nn.Module):
    """Two-layer MLP that projects a label vector into an embedding.

    Args:
        label_dim: Size of the last dimension of the incoming label tensor.
        emb_dim: Width of the hidden layer and of the returned embedding.
    """

    def __init__(self, label_dim, emb_dim):
        super().__init__()
        # Linear -> SiLU -> Linear, kept under the attribute name `mlp`.
        stages = [
            nn.Linear(label_dim, emb_dim),
            nn.SiLU(),
            nn.Linear(emb_dim, emb_dim),
        ]
        self.mlp = nn.Sequential(*stages)

    def forward(self, labels):
        """Return the embedding of ``labels`` (last dimension ``emb_dim``)."""
        projected = self.mlp(labels)
        return projected
class ResidualBlock(nn.Module):
    """Two-convolution residual block modulated by a conditioning embedding.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the produced feature map.
        time_emb_dim: Width of the conditioning embedding given to ``forward``.
        dropout: Dropout probability applied before the second convolution.
    """

    def __init__(self, in_channels, out_channels, time_emb_dim, dropout=0.1):
        super().__init__()

        # First stage: normalise, activate, then change the channel count.
        self.conv1 = nn.Sequential(
            nn.GroupNorm(8, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        )

        # Maps the embedding to one additive offset per output channel.
        self.time_emb = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, out_channels),
        )

        # Second stage keeps the channel count and applies dropout.
        self.conv2 = nn.Sequential(
            nn.GroupNorm(8, out_channels),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        )

        # The skip path needs a 1x1 projection only when the widths differ.
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x, emb):
        """Apply the block.

        Args:
            x: Feature map with ``in_channels`` channels.
            emb: Conditioning embedding; its projection is indexed with
                ``[:, :, None, None]`` so it is added per channel across all
                spatial positions.

        Returns:
            Feature map with ``out_channels`` channels.
        """
        channel_offset = self.time_emb(emb)[:, :, None, None]
        features = self.conv1(x) + channel_offset
        features = self.conv2(features)
        return features + self.shortcut(x)
class AttentionBlock(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    Args:
        channels: Channels of the incoming (and returned) feature map.
        num_heads: Number of attention heads the channels are split into.
    """

    def __init__(self, channels, num_heads=4):
        super().__init__()
        self.channels = channels
        self.num_heads = num_heads
        self.norm = nn.GroupNorm(8, channels)
        # A single 1x1 convolution produces queries, keys and values at once.
        self.qkv = nn.Conv2d(channels, channels * 3, kernel_size=1)
        self.proj = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        """Return ``x`` plus the projected attention output (same shape)."""
        batch, chans, height, width = x.shape
        heads = self.num_heads
        per_head = chans // heads
        tokens = height * width

        fused = self.qkv(self.norm(x))
        # Split into q/k/v and lay each out as (batch, heads, tokens, per_head).
        q, k, v = (
            part.view(batch, heads, per_head, tokens).transpose(2, 3)
            for part in fused.chunk(3, dim=1)
        )

        attended = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
        # Undo the head split to get back to (batch, channels, height, width).
        attended = attended.transpose(2, 3).reshape(batch, chans, height, width)
        return x + self.proj(attended)
class ConditionalUNet(nn.Module):
    """U-Net conditioned on a diffusion timestep and an optional label vector.

    The encoder applies two residual blocks per resolution level (each
    optionally followed by attention) and a stride-2 convolution between
    levels. The decoder mirrors it with three residual blocks per level, each
    consuming one stored skip tensor, and a transposed convolution between
    levels.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the returned tensor.
        label_dim: Size of the label vector fed to the label MLP.
        base_channels: Channel width of the first level; the other levels use
            ``base_channels * multiplier``. Every normalisation layer uses 8
            groups, so the resulting widths must be divisible by 8.
        channel_multipliers: One width multiplier per resolution level.
        attention_levels: Indices of the levels whose residual blocks are
            followed by an attention block.
        dropout: Dropout probability handed to every residual block.
        time_emb_dim: Width of the timestep embedding and of the conditioning
            vector given to the residual blocks.
        label_emb_dim: Width of the label embedding.
    """

    def __init__(self, in_channels=1, out_channels=1, label_dim=2,
                 base_channels=64, channel_multipliers=(1, 2, 4, 8),
                 attention_levels=(2, 3), dropout=0.1, time_emb_dim=256,
                 label_emb_dim=256):
        super().__init__()
        self.label_dim = label_dim

        # --- conditioning ---------------------------------------------------
        self.time_embedding = TimeEmbedding(time_emb_dim)
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, time_emb_dim * 4),
            nn.SiLU(),
            nn.Linear(time_emb_dim * 4, time_emb_dim),
        )
        self.label_embedding = LabelEmbedding(label_dim, label_emb_dim)
        self.combined_emb_dim = time_emb_dim + label_emb_dim
        # Fuses [time | label] back down to time_emb_dim so the residual
        # blocks see the same width with or without labels.
        self.combined_mlp = nn.Sequential(
            nn.Linear(self.combined_emb_dim, time_emb_dim * 4),
            nn.SiLU(),
            nn.Linear(time_emb_dim * 4, time_emb_dim),
        )

        # --- encoder --------------------------------------------------------
        self.conv_in = nn.Conv2d(in_channels, base_channels, kernel_size=3, padding=1)
        self.down_blocks = nn.ModuleList()
        # `channels` records the width of every tensor forward() will push
        # onto its skip stack, in the same order.
        channels = [base_channels]
        now_channels = base_channels
        last_level = len(channel_multipliers) - 1
        for i, mult in enumerate(channel_multipliers):
            out_ch = base_channels * mult
            for _ in range(2):
                self.down_blocks.append(
                    ResidualBlock(now_channels, out_ch, time_emb_dim, dropout)
                )
                if i in attention_levels:
                    self.down_blocks.append(AttentionBlock(out_ch))
                now_channels = out_ch
                channels.append(now_channels)
            if i != last_level:
                self.down_blocks.append(
                    nn.Conv2d(now_channels, now_channels, kernel_size=3, stride=2, padding=1)
                )
                channels.append(now_channels)

        # --- bottleneck -----------------------------------------------------
        self.middle = nn.ModuleList([
            ResidualBlock(now_channels, now_channels, time_emb_dim, dropout),
            AttentionBlock(now_channels),
            ResidualBlock(now_channels, now_channels, time_emb_dim, dropout),
        ])

        # --- decoder --------------------------------------------------------
        self.up_blocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(channel_multipliers))):
            out_ch = base_channels * mult
            # Three blocks per level: two for the encoder's residual outputs
            # and one for the tensor that entered the level.
            for _ in range(3):
                self.up_blocks.append(
                    ResidualBlock(now_channels + channels.pop(), out_ch, time_emb_dim, dropout)
                )
                if i in attention_levels:
                    self.up_blocks.append(AttentionBlock(out_ch))
                now_channels = out_ch
            if i != 0:
                self.up_blocks.append(
                    nn.ConvTranspose2d(now_channels, now_channels, kernel_size=4, stride=2, padding=1)
                )

        self.conv_out = nn.Sequential(
            nn.GroupNorm(8, now_channels),
            nn.SiLU(),
            nn.Conv2d(now_channels, out_channels, kernel_size=3, padding=1),
        )

    def forward(self, x, t, labels=None):
        """Run the network.

        Args:
            x: Input tensor with ``in_channels`` channels.
            t: Timesteps, one per batch element (indexed as ``t[:, None]`` by
                the time embedding).
            labels: Optional label tensor whose last dimension is
                ``label_dim``. When ``None`` only the timestep conditions the
                network and ``combined_mlp`` is not used.

        Returns:
            Tensor with ``out_channels`` channels and the spatial size of
            ``x``.
        """
        t_emb = self.time_mlp(self.time_embedding(t))
        if labels is not None:
            label_emb = self.label_embedding(labels)
            emb = self.combined_mlp(torch.cat([t_emb, label_emb], dim=-1))
        else:
            emb = t_emb

        h = self.conv_in(x)
        skips = [h]
        for module in self.down_blocks:
            if isinstance(module, ResidualBlock):
                h = module(h, emb)
                skips.append(h)
            elif isinstance(module, AttentionBlock):
                # Attention output is not stored; the skip taken above is the
                # residual block's output.
                h = module(h)
            else:
                # Stride-2 downsampling convolution.
                h = module(h)
                skips.append(h)

        for module in self.middle:
            if isinstance(module, ResidualBlock):
                h = module(h, emb)
            else:
                h = module(h)

        for module in self.up_blocks:
            if isinstance(module, ResidualBlock):
                skip = skips.pop()
                if h.shape[-2:] != skip.shape[-2:]:
                    # The stride-2 conv rounds odd sizes up while the
                    # transposed conv exactly doubles, so inputs whose size is
                    # not divisible by 2 ** (levels - 1) arrive here one pixel
                    # too large. Resize to the skip so the concat succeeds.
                    h = F.interpolate(h, size=skip.shape[-2:], mode="nearest")
                h = torch.cat([h, skip], dim=1)
                h = module(h, emb)
            else:
                # Attention block or transposed-convolution upsampler.
                h = module(h)

        return self.conv_out(h)
|