|
|
""" |
|
|
U-Net architecture for conditional diffusion on spatiotemporal PDE data. |
|
|
Supports non-square inputs, time conditioning, and skip connections. |
|
|
""" |
|
|
import math |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal positional embedding for the diffusion timestep.

    Maps a batch of scalar timesteps ``[B]`` to embeddings ``[B, dim]``
    using sin/cos features at geometrically spaced frequencies, as in
    the Transformer / DDPM papers.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        # Frequencies decay geometrically from 1 down to 1/10000.
        scale = -math.log(10000) / (half - 1)
        freqs = torch.exp(scale * torch.arange(half, device=t.device))
        angles = t.float().unsqueeze(-1) * freqs.unsqueeze(0)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
|
|
|
|
|
|
|
|
class ResBlock(nn.Module):
    """Residual block with group norm, SiLU, and time embedding injection.

    Pre-activation layout: norm -> SiLU -> conv, with the projected time
    embedding added channel-wise between the two convolutions. A 1x1 conv
    projects the residual path when in_ch != out_ch.

    Args:
        in_ch: input channel count.
        out_ch: output channel count.
        time_dim: dimensionality of the time embedding fed to forward().
        dropout: dropout rate applied before the second conv.
    """

    @staticmethod
    def _num_groups(ch):
        # Largest group count <= 32 that divides ch. The previous
        # min(32, ch) made nn.GroupNorm raise whenever ch was not a
        # multiple of the group count (e.g. ch == 48 with 32 groups);
        # for every previously-working width this returns the same value.
        g = min(32, ch)
        while ch % g != 0:
            g -= 1
        return g

    def __init__(self, in_ch, out_ch, time_dim, dropout=0.1):
        super().__init__()
        self.norm1 = nn.GroupNorm(self._num_groups(in_ch), in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        # Projects the time embedding to one bias per output channel.
        self.time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_dim, out_ch))
        self.norm2 = nn.GroupNorm(self._num_groups(out_ch), out_ch)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        # Identity skip unless the channel count changes.
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, t_emb):
        """Apply the block.

        Args:
            x: feature map [B, in_ch, H, W].
            t_emb: time embedding [B, time_dim].

        Returns:
            Feature map [B, out_ch, H, W].
        """
        h = F.silu(self.norm1(x))
        h = self.conv1(h)
        # Broadcast the per-channel time bias over the spatial dims.
        h = h + self.time_mlp(t_emb)[:, :, None, None]
        h = F.silu(self.norm2(h))
        h = self.dropout(h)
        h = self.conv2(h)
        return h + self.skip(x)
|
|
|
|
|
|
|
|
class SelfAttention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map."""

    def __init__(self, channels, num_heads=4):
        super().__init__()
        self.norm = nn.GroupNorm(min(32, channels), channels)
        self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)

    def forward(self, x):
        """Attend across all H*W positions; residual connection included."""
        batch, ch, height, width = x.shape
        # Flatten spatial dims into a sequence: [B, H*W, C].
        seq = self.norm(x).flatten(2).transpose(1, 2)
        attended, _ = self.attn(seq, seq, seq)
        # Restore the [B, C, H, W] layout.
        attended = attended.transpose(1, 2).reshape(batch, ch, height, width)
        return x + attended
|
|
|
|
|
|
|
|
class Downsample(nn.Module):
    """Halve spatial resolution with a strided 3x3 convolution.

    Output spatial size is ceil(H/2) x ceil(W/2) for odd inputs.
    """

    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=ch,
            out_channels=ch,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        return self.conv(x)
|
|
|
|
|
|
|
|
class Upsample(nn.Module):
    """Double spatial resolution: nearest-neighbour upsample then 3x3 conv."""

    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=ch,
            out_channels=ch,
            kernel_size=3,
            padding=1,
        )

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(upsampled)
|
|
|
|
|
|
|
|
class UNet(nn.Module):
    """U-Net for conditional diffusion.

    Condition (e.g. previous frame) is concatenated to the noisy input along
    the channel dimension *before* being passed to forward(). So set
    ``in_channels = output_channels + condition_channels``.

    Spatial sizes that are not divisible by ``2 ** (len(ch_mults) - 1)`` are
    handled by re-interpolating decoder features to the skip connection's
    size before concatenation (previously torch.cat raised on such inputs).

    Args:
        in_channels: noisy-target channels + condition channels.
        out_channels: channels to predict (same as target).
        base_ch: base channel width.
        ch_mults: per-level channel multipliers.
        n_res: residual blocks per level.
        attn_levels: which levels get self-attention (0-indexed).
        dropout: dropout rate.
        time_dim: timestep embedding dimension.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        base_ch=64,
        ch_mults=(1, 2, 4, 8),
        n_res=2,
        attn_levels=(3,),
        dropout=0.1,
        time_dim=256,
    ):
        super().__init__()
        self.n_res = n_res
        self.ch_mults = ch_mults

        # Timestep embedding: sinusoidal features refined by a small MLP.
        self.time_embed = nn.Sequential(
            SinusoidalPosEmb(time_dim),
            nn.Linear(time_dim, time_dim * 4),
            nn.SiLU(),
            nn.Linear(time_dim * 4, time_dim),
        )

        self.input_conv = nn.Conv2d(in_channels, base_ch, 3, padding=1)

        # ----- Encoder -----
        # skip_chs records the channel count of every tensor pushed onto
        # the skip stack in forward(), in push order (input conv first).
        self.downs = nn.ModuleList()
        ch = base_ch
        skip_chs = [ch]
        for lvl, mult in enumerate(ch_mults):
            out_ch = base_ch * mult
            for _ in range(n_res):
                self.downs.append(
                    nn.ModuleDict(
                        {
                            "res": ResBlock(ch, out_ch, time_dim, dropout),
                            **(
                                {"attn": SelfAttention(out_ch)}
                                if lvl in attn_levels
                                else {}
                            ),
                        }
                    )
                )
                ch = out_ch
                skip_chs.append(ch)
            # No downsample after the last (deepest) level.
            if lvl < len(ch_mults) - 1:
                self.downs.append(nn.ModuleDict({"down": Downsample(ch)}))
                skip_chs.append(ch)

        # ----- Bottleneck -----
        self.mid_res1 = ResBlock(ch, ch, time_dim, dropout)
        self.mid_attn = SelfAttention(ch)
        self.mid_res2 = ResBlock(ch, ch, time_dim, dropout)

        # ----- Decoder -----
        # n_res + 1 res blocks per level: the extra one consumes the skip
        # pushed by the input conv / downsample at that resolution.
        self.ups = nn.ModuleList()
        for lvl in reversed(range(len(ch_mults))):
            out_ch = base_ch * ch_mults[lvl]
            for _ in range(n_res + 1):
                skip_ch = skip_chs.pop()
                self.ups.append(
                    nn.ModuleDict(
                        {
                            "res": ResBlock(ch + skip_ch, out_ch, time_dim, dropout),
                            **(
                                {"attn": SelfAttention(out_ch)}
                                if lvl in attn_levels
                                else {}
                            ),
                        }
                    )
                )
                ch = out_ch
            if lvl > 0:
                self.ups.append(nn.ModuleDict({"up": Upsample(ch)}))

        self.out_norm = nn.GroupNorm(min(32, ch), ch)
        self.out_conv = nn.Conv2d(ch, out_channels, 3, padding=1)

    def forward(self, x, t, cond=None):
        """
        Args:
            x: noisy target [B, C_out, H, W]
            t: diffusion timestep [B] (int or float)
            cond: condition [B, C_cond, H, W] (optional, concatenated)

        Returns:
            predicted noise [B, C_out, H, W]
        """
        if cond is not None:
            x = torch.cat([x, cond], dim=1)

        t_emb = self.time_embed(t)
        h = self.input_conv(x)

        # Encoder: push a skip after every res block and every downsample.
        skips = [h]
        for block in self.downs:
            if "down" in block:
                h = block["down"](h)
                skips.append(h)
            else:
                h = block["res"](h, t_emb)
                if "attn" in block:
                    h = block["attn"](h)
                skips.append(h)

        h = self.mid_res1(h, t_emb)
        h = self.mid_attn(h)
        h = self.mid_res2(h, t_emb)

        # Decoder: pop skips in reverse order.
        for block in self.ups:
            if "up" in block:
                h = block["up"](h)
            else:
                s = skips.pop()
                # Fix: Downsample emits ceil(H/2), so after Upsample (x2)
                # the decoder feature can be one pixel larger than an
                # odd-sized skip. Re-interpolate to the skip's size so
                # torch.cat never fails on non-power-of-two inputs; this
                # is a no-op when the sizes already match.
                if h.shape[-2:] != s.shape[-2:]:
                    h = F.interpolate(h, size=s.shape[-2:], mode="nearest")
                h = torch.cat([h, s], dim=1)
                h = block["res"](h, t_emb)
                if "attn" in block:
                    h = block["attn"](h)

        h = F.silu(self.out_norm(h))
        return self.out_conv(h)
|
|
|