# DDPM-6param / src / unet_conditional.py
# collins909's picture
# Upload 6-parameter conditional DDPM (HI emulation, CAMELS LH params_6, best checkpoint)
# c46900a verified
"""
Conditional U-Net Architecture for Diffusion Model
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class TimeEmbedding(nn.Module):
    """Sinusoidal embedding of diffusion timesteps.

    Each timestep is multiplied by ``dim // 2`` geometrically spaced
    frequencies (ratio governed by ``log(10000)``); the sines and cosines of
    those products are concatenated along the last axis.

    Args:
        dim: Number of features in the returned embedding.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Embed a batch of timesteps.

        Args:
            time: Tensor indexed as ``time[:, None]`` below, i.e. expected to
                be 1-D with one timestep per batch element.

        Returns:
            Tensor whose last axis has exactly ``self.dim`` features: the sine
            half, then the cosine half, then (only for odd ``dim``) a single
            zero column.
        """
        half_dim = self.dim // 2
        # Clamp the denominator: with dim in (2, 3) half_dim is 1 and the
        # unclamped ``half_dim - 1`` would raise ZeroDivisionError.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=time.device) * -log_step)
        angles = time[:, None] * freqs[None, :]
        embedding = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        if self.dim % 2 == 1:
            # sin/cos only supply 2 * (dim // 2) features; pad one zero column
            # so layers sized for ``dim`` inputs still accept the result.
            embedding = F.pad(embedding, (0, 1))
        return embedding
class LabelEmbedding(nn.Module):
    """Two-layer MLP that projects a conditioning label vector.

    The network is Linear -> SiLU -> Linear, mapping ``label_dim`` input
    features to ``emb_dim`` output features.

    Args:
        label_dim: Number of features in each label vector.
        emb_dim: Width of the hidden layer and of the output.
    """

    def __init__(self, label_dim, emb_dim):
        super().__init__()
        layers = [
            nn.Linear(label_dim, emb_dim),
            nn.SiLU(),
            nn.Linear(emb_dim, emb_dim),
        ]
        # Positions in the Sequential (0 and 2) are part of the state-dict keys.
        self.mlp = nn.Sequential(*layers)

    def forward(self, labels):
        """Return the MLP projection of ``labels`` (last axis -> ``emb_dim``)."""
        projected = self.mlp(labels)
        return projected
class ResidualBlock(nn.Module):
    """Residual convolution block modulated by a conditioning embedding.

    Two (GroupNorm -> SiLU -> 3x3 conv) stages; a per-channel bias projected
    from the embedding is added between them, and dropout precedes the second
    convolution. The input is added back through a 1x1 convolution when the
    channel count changes, otherwise unchanged.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
        time_emb_dim: Feature size of the embedding passed to ``forward``.
        dropout: Dropout probability used in the second stage.

    Note:
        Both GroupNorm layers use 8 groups, so ``in_channels`` and
        ``out_channels`` have to be divisible by 8.
    """

    def __init__(self, in_channels, out_channels, time_emb_dim, dropout=0.1):
        super().__init__()
        first_stage = [
            nn.GroupNorm(8, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        ]
        self.conv1 = nn.Sequential(*first_stage)

        # Projects the embedding to one additive bias per output channel.
        self.time_emb = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, out_channels))

        second_stage = [
            nn.GroupNorm(8, out_channels),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        ]
        self.conv2 = nn.Sequential(*second_stage)

        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            # 1x1 conv so the skip path matches the new channel count.
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x, emb):
        """Apply the block.

        Args:
            x: Input feature map fed to the 2-D convolutions.
            emb: Conditioning embedding; its projection is indexed as
                ``[:, :, None, None]`` so it broadcasts over the two trailing
                (spatial) axes.

        Returns:
            Feature map with ``out_channels`` channels.
        """
        channel_bias = self.time_emb(emb)
        hidden = self.conv1(x)
        hidden = hidden + channel_bias[:, :, None, None]
        hidden = self.conv2(hidden)
        skip = self.shortcut(x)
        return hidden + skip
class AttentionBlock(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    The input is group-normalised, projected to queries/keys/values by a 1x1
    convolution, attended over the flattened ``H * W`` positions, projected
    back by another 1x1 convolution and added to the input.

    Args:
        channels: Number of channels of the input (and output) feature map.
        num_heads: Number of attention heads; must divide ``channels``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``channels``.

    Note:
        The GroupNorm uses 8 groups, so ``channels`` also has to be divisible
        by 8.
    """

    def __init__(self, channels, num_heads=4):
        super().__init__()
        # Fail at construction time: otherwise the mismatch only surfaces in
        # forward() as an opaque shape error from ``view``.
        if num_heads <= 0 or channels % num_heads != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.channels = channels
        self.num_heads = num_heads
        self.norm = nn.GroupNorm(8, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, kernel_size=1)
        self.proj = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        """Apply residual self-attention.

        Args:
            x: 4-D tensor unpacked as ``(B, C, H, W)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        B, C, H, W = x.shape
        h = self.norm(x)
        q, k, v = self.qkv(h).chunk(3, dim=1)
        head_dim = C // self.num_heads
        # (B, C, H, W) -> (B, heads, H*W, head_dim): one token per pixel.
        q = q.view(B, self.num_heads, head_dim, H * W).transpose(2, 3)
        k = k.view(B, self.num_heads, head_dim, H * W).transpose(2, 3)
        v = v.view(B, self.num_heads, head_dim, H * W).transpose(2, 3)
        h = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
        # Back to channel-first layout before the output projection.
        h = h.transpose(2, 3).reshape(B, C, H, W)
        return x + self.proj(h)
class ConditionalUNet(nn.Module):
    """U-Net noise predictor conditioned on a timestep and an optional label vector.

    Encoder: per resolution level, two residual blocks (each optionally
    followed by attention), then a stride-2 convolution except at the last
    level. Bottleneck: residual -> attention -> residual. Decoder: per level
    in reverse, three residual blocks (each consuming one encoder skip and
    optionally followed by attention), then a stride-2 transposed convolution
    except at level 0.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the prediction.
        label_dim: Number of features in the conditioning label vector.
        base_channels: Channel width at level 0; other levels scale it.
        channel_multipliers: Per-level multiplier applied to ``base_channels``.
        attention_levels: Level indices whose residual blocks are followed by
            an attention block.
        dropout: Dropout probability forwarded to every residual block.
        time_emb_dim: Size of the timestep embedding and of the conditioning
            vector handed to the residual blocks.
        label_emb_dim: Size of the label embedding.
    """

    def __init__(self, in_channels=1, out_channels=1, label_dim=2,
                 base_channels=64, channel_multipliers=(1,2,4,8),
                 attention_levels=(2,3), dropout=0.1, time_emb_dim=256, label_emb_dim=256):
        super().__init__()
        self.label_dim = label_dim

        # --- conditioning networks -------------------------------------
        hidden_dim = time_emb_dim * 4
        self.time_embedding = TimeEmbedding(time_emb_dim)
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, time_emb_dim),
        )
        self.label_embedding = LabelEmbedding(label_dim, label_emb_dim)
        self.combined_emb_dim = time_emb_dim + label_emb_dim
        # Fuses [time, label] back down to time_emb_dim so residual blocks see
        # the same embedding size with or without labels.
        self.combined_mlp = nn.Sequential(
            nn.Linear(self.combined_emb_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, time_emb_dim),
        )

        self.conv_in = nn.Conv2d(in_channels, base_channels, kernel_size=3, padding=1)

        # --- encoder ---------------------------------------------------
        self.down_blocks = nn.ModuleList()
        # Channel count of every tensor forward() will push as a skip, in
        # push order; the decoder pops from the end.
        skip_widths = [base_channels]
        width = base_channels
        deepest = len(channel_multipliers) - 1
        for level, mult in enumerate(channel_multipliers):
            level_width = base_channels * mult
            for _ in range(2):
                self.down_blocks.append(ResidualBlock(width, level_width, time_emb_dim, dropout))
                if level in attention_levels:
                    self.down_blocks.append(AttentionBlock(level_width))
                width = level_width
                skip_widths.append(width)
            if level != deepest:
                # Stride-2 conv halves the spatial resolution.
                self.down_blocks.append(nn.Conv2d(width, width, kernel_size=3, stride=2, padding=1))
                skip_widths.append(width)

        # --- bottleneck ------------------------------------------------
        bottleneck = [
            ResidualBlock(width, width, time_emb_dim, dropout),
            AttentionBlock(width),
            ResidualBlock(width, width, time_emb_dim, dropout),
        ]
        self.middle = nn.ModuleList(bottleneck)

        # --- decoder ---------------------------------------------------
        self.up_blocks = nn.ModuleList()
        levels = list(enumerate(channel_multipliers))
        for level, mult in levels[::-1]:
            level_width = base_channels * mult
            for _ in range(3):
                # Input = current features concatenated with one encoder skip.
                merged_width = width + skip_widths.pop()
                self.up_blocks.append(ResidualBlock(merged_width, level_width, time_emb_dim, dropout))
                if level in attention_levels:
                    self.up_blocks.append(AttentionBlock(level_width))
                width = level_width
            if level != 0:
                # Transposed conv (k=4, s=2, p=1) doubles the resolution.
                self.up_blocks.append(nn.ConvTranspose2d(width, width, kernel_size=4, stride=2, padding=1))

        self.conv_out = nn.Sequential(
            nn.GroupNorm(8, width),
            nn.SiLU(),
            nn.Conv2d(width, out_channels, kernel_size=3, padding=1),
        )

    def _conditioning(self, t, labels):
        """Build the embedding passed to every residual block."""
        time_vec = self.time_mlp(self.time_embedding(t))
        if labels is None:
            # Unconditional path: timestep embedding only.
            return time_vec
        label_vec = self.label_embedding(labels)
        return self.combined_mlp(torch.cat([time_vec, label_vec], dim=-1))

    def forward(self, x, t, labels=None):
        """Run the U-Net.

        Args:
            x: Input image batch for the 2-D convolutions.
            t: Timesteps, one per batch element.
            labels: Optional conditioning vectors; when ``None`` only the
                timestep embedding conditions the network.

        Returns:
            Output of the final convolution, with ``out_channels`` channels.
        """
        emb = self._conditioning(t, labels)

        h = self.conv_in(x)
        skips = [h]
        for layer in self.down_blocks:
            if isinstance(layer, AttentionBlock):
                # Attention refines h in place; the skip recorded for this
                # step is the residual block's output pushed just before.
                h = layer(h)
                continue
            h = layer(h, emb) if isinstance(layer, ResidualBlock) else layer(h)
            skips.append(h)

        for layer in self.middle:
            h = layer(h, emb) if isinstance(layer, ResidualBlock) else layer(h)

        for layer in self.up_blocks:
            if isinstance(layer, ResidualBlock):
                h = layer(torch.cat([h, skips.pop()], dim=1), emb)
            else:
                h = layer(h)

        return self.conv_out(h)