""" U-Net architecture for conditional diffusion on spatiotemporal PDE data. Supports non-square inputs, time conditioning, and skip connections. """ import math import torch import torch.nn as nn import torch.nn.functional as F class SinusoidalPosEmb(nn.Module): """Sinusoidal positional embedding for diffusion timestep.""" def __init__(self, dim): super().__init__() self.dim = dim def forward(self, t): half = self.dim // 2 emb = math.log(10000) / (half - 1) emb = torch.exp(torch.arange(half, device=t.device) * -emb) emb = t[:, None].float() * emb[None, :] return torch.cat([emb.sin(), emb.cos()], dim=-1) class ResBlock(nn.Module): """Residual block with group norm, SiLU, and time embedding injection.""" def __init__(self, in_ch, out_ch, time_dim, dropout=0.1): super().__init__() self.norm1 = nn.GroupNorm(min(32, in_ch), in_ch) self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1) self.time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_dim, out_ch)) self.norm2 = nn.GroupNorm(min(32, out_ch), out_ch) self.dropout = nn.Dropout(dropout) self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1) self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity() def forward(self, x, t_emb): h = F.silu(self.norm1(x)) h = self.conv1(h) h = h + self.time_mlp(t_emb)[:, :, None, None] h = F.silu(self.norm2(h)) h = self.dropout(h) h = self.conv2(h) return h + self.skip(x) class SelfAttention(nn.Module): """Multi-head self-attention on spatial features.""" def __init__(self, channels, num_heads=4): super().__init__() self.norm = nn.GroupNorm(min(32, channels), channels) self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True) def forward(self, x): B, C, H, W = x.shape h = self.norm(x).reshape(B, C, H * W).permute(0, 2, 1) h, _ = self.attn(h, h, h) h = h.permute(0, 2, 1).reshape(B, C, H, W) return x + h class Downsample(nn.Module): def __init__(self, ch): super().__init__() self.conv = nn.Conv2d(ch, ch, 3, stride=2, padding=1) def forward(self, x): return self.conv(x) class Upsample(nn.Module): def __init__(self, ch): super().__init__() self.conv = nn.Conv2d(ch, ch, 3, padding=1) def forward(self, x): x = F.interpolate(x, scale_factor=2, mode="nearest") return self.conv(x) class UNet(nn.Module): """U-Net for conditional diffusion. Condition (e.g. previous frame) is concatenated to the noisy input along the channel dimension *before* being passed to forward(). So set ``in_channels = output_channels + condition_channels``. Args: in_channels: noisy-target channels + condition channels. out_channels: channels to predict (same as target). base_ch: base channel width. ch_mults: per-level channel multipliers. n_res: residual blocks per level. attn_levels: which levels get self-attention (0-indexed). dropout: dropout rate. time_dim: timestep embedding dimension. """ def __init__( self, in_channels, out_channels, base_ch=64, ch_mults=(1, 2, 4, 8), n_res=2, attn_levels=(3,), dropout=0.1, time_dim=256, ): super().__init__() self.n_res = n_res self.ch_mults = ch_mults # --- time embedding --- self.time_embed = nn.Sequential( SinusoidalPosEmb(time_dim), nn.Linear(time_dim, time_dim * 4), nn.SiLU(), nn.Linear(time_dim * 4, time_dim), ) # --- input projection --- self.input_conv = nn.Conv2d(in_channels, base_ch, 3, padding=1) # --- downsampling path --- self.downs = nn.ModuleList() ch = base_ch skip_chs = [ch] # track channel dims for skip connections for lvl, mult in enumerate(ch_mults): out_ch = base_ch * mult for _ in range(n_res): self.downs.append( nn.ModuleDict( { "res": ResBlock(ch, out_ch, time_dim, dropout), **( {"attn": SelfAttention(out_ch)} if lvl in attn_levels else {} ), } ) ) ch = out_ch skip_chs.append(ch) if lvl < len(ch_mults) - 1: self.downs.append(nn.ModuleDict({"down": Downsample(ch)})) skip_chs.append(ch) # --- middle --- self.mid_res1 = ResBlock(ch, ch, time_dim, dropout) self.mid_attn = SelfAttention(ch) self.mid_res2 = ResBlock(ch, ch, time_dim, dropout) # --- upsampling path --- self.ups = nn.ModuleList() for lvl in reversed(range(len(ch_mults))): out_ch = base_ch * ch_mults[lvl] for _ in range(n_res + 1): # +1 to consume downsample skip skip_ch = skip_chs.pop() self.ups.append( nn.ModuleDict( { "res": ResBlock(ch + skip_ch, out_ch, time_dim, dropout), **( {"attn": SelfAttention(out_ch)} if lvl in attn_levels else {} ), } ) ) ch = out_ch if lvl > 0: self.ups.append(nn.ModuleDict({"up": Upsample(ch)})) # --- output projection --- self.out_norm = nn.GroupNorm(min(32, ch), ch) self.out_conv = nn.Conv2d(ch, out_channels, 3, padding=1) def forward(self, x, t, cond=None): """ Args: x: noisy target [B, C_out, H, W] t: diffusion timestep [B] (int or float) cond: condition [B, C_cond, H, W] (optional, concatenated) Returns: predicted noise [B, C_out, H, W] """ if cond is not None: x = torch.cat([x, cond], dim=1) t_emb = self.time_embed(t) h = self.input_conv(x) # --- down --- skips = [h] for block in self.downs: if "down" in block: h = block["down"](h) skips.append(h) else: h = block["res"](h, t_emb) if "attn" in block: h = block["attn"](h) skips.append(h) # --- middle --- h = self.mid_res1(h, t_emb) h = self.mid_attn(h) h = self.mid_res2(h, t_emb) # --- up --- for block in self.ups: if "up" in block: h = block["up"](h) else: s = skips.pop() h = torch.cat([h, s], dim=1) h = block["res"](h, t_emb) if "attn" in block: h = block["attn"](h) h = F.silu(self.out_norm(h)) return self.out_conv(h)