# Uploaded with huggingface_hub by AlexWortega (commit 54c8086, verified).
"""
U-Net architecture for conditional diffusion on spatiotemporal PDE data.
Supports non-square inputs, time conditioning, and skip connections.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal positional embedding for the diffusion timestep.

    Maps a batch of scalar timesteps ``[B]`` to embeddings ``[B, dim]``
    using sin/cos features at geometrically spaced frequencies, as in
    "Attention Is All You Need" / DDPM.

    Args:
        dim: embedding dimension. Odd dims are supported by zero-padding
            the final channel; even dims behave exactly as before.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        # Frequencies decay geometrically from 1 down to 1/10000.
        # max(half - 1, 1) guards dim in {2, 3} against division by zero.
        scale = math.log(10000) / max(half - 1, 1)
        freqs = torch.exp(torch.arange(half, device=t.device) * -scale)
        args = t[:, None].float() * freqs[None, :]
        emb = torch.cat([args.sin(), args.cos()], dim=-1)
        if self.dim % 2 == 1:
            # Odd dim: sin+cos gives 2*(dim//2) channels; pad one zero so
            # the output is exactly `dim` wide.
            emb = F.pad(emb, (0, 1))
        return emb
class ResBlock(nn.Module):
    """Pre-activation residual block (GN -> SiLU -> conv) with the timestep
    embedding injected channel-wise between the two convolutions.

    Args:
        in_ch: input channel count.
        out_ch: output channel count (a 1x1 skip projection is used when it
            differs from ``in_ch``).
        time_dim: dimensionality of the timestep embedding.
        dropout: dropout probability applied before the second conv.
    """

    def __init__(self, in_ch, out_ch, time_dim, dropout=0.1):
        super().__init__()
        self.norm1 = nn.GroupNorm(min(32, in_ch), in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_dim, out_ch))
        self.norm2 = nn.GroupNorm(min(32, out_ch), out_ch)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        # Project the residual only when the channel count changes.
        if in_ch == out_ch:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x, t_emb):
        out = self.conv1(F.silu(self.norm1(x)))
        # Broadcast the per-sample time embedding over the spatial grid.
        out = out + self.time_mlp(t_emb).unsqueeze(-1).unsqueeze(-1)
        out = self.conv2(self.dropout(F.silu(self.norm2(out))))
        return out + self.skip(x)
class SelfAttention(nn.Module):
    """Spatial multi-head self-attention with a residual connection.

    Args:
        channels: feature channel count (also the attention embed dim).
        num_heads: number of attention heads; must divide ``channels``.
    """

    def __init__(self, channels, num_heads=4):
        super().__init__()
        self.norm = nn.GroupNorm(min(32, channels), channels)
        self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)

    def forward(self, x):
        batch, ch, height, width = x.shape
        # Flatten the spatial grid into a token sequence: [B, H*W, C].
        tokens = self.norm(x).flatten(2).transpose(1, 2)
        attended, _ = self.attn(tokens, tokens, tokens)
        # Restore the [B, C, H, W] layout and add the residual.
        return x + attended.transpose(1, 2).reshape(batch, ch, height, width)
class Downsample(nn.Module):
    """Halve the spatial resolution with a stride-2 3x3 convolution.

    Output size per axis is ``ceil(n / 2)`` (odd inputs round up).
    """

    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)
class Upsample(nn.Module):
    """Double the spatial resolution (nearest-neighbour) then smooth with a
    3x3 convolution."""

    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, kernel_size=3, padding=1)

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(upsampled)
class UNet(nn.Module):
    """U-Net for conditional diffusion.

    Condition (e.g. previous frame) is concatenated to the noisy input along
    the channel dimension — either by the caller before ``forward()`` or via
    the ``cond`` argument. So set
    ``in_channels = output_channels + condition_channels``.

    Args:
        in_channels: noisy-target channels + condition channels.
        out_channels: channels to predict (same as target).
        base_ch: base channel width.
        ch_mults: per-level channel multipliers.
        n_res: residual blocks per level.
        attn_levels: which levels get self-attention (0-indexed).
        dropout: dropout rate.
        time_dim: timestep embedding dimension.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        base_ch=64,
        ch_mults=(1, 2, 4, 8),
        n_res=2,
        attn_levels=(3,),
        dropout=0.1,
        time_dim=256,
    ):
        super().__init__()
        self.n_res = n_res
        self.ch_mults = ch_mults
        # --- time embedding: sinusoidal features -> 2-layer MLP ---
        self.time_embed = nn.Sequential(
            SinusoidalPosEmb(time_dim),
            nn.Linear(time_dim, time_dim * 4),
            nn.SiLU(),
            nn.Linear(time_dim * 4, time_dim),
        )
        # --- input projection ---
        self.input_conv = nn.Conv2d(in_channels, base_ch, 3, padding=1)
        # --- downsampling path ---
        self.downs = nn.ModuleList()
        ch = base_ch
        # Channel dim of every activation that forward() pushes onto the skip
        # stack; popped in reverse while building the up path so each up-block
        # knows its concatenated input width.
        skip_chs = [ch]
        for lvl, mult in enumerate(ch_mults):
            out_ch = base_ch * mult
            for _ in range(n_res):
                self.downs.append(
                    nn.ModuleDict(
                        {
                            "res": ResBlock(ch, out_ch, time_dim, dropout),
                            **(
                                {"attn": SelfAttention(out_ch)}
                                if lvl in attn_levels
                                else {}
                            ),
                        }
                    )
                )
                ch = out_ch
                skip_chs.append(ch)
            if lvl < len(ch_mults) - 1:  # no downsample after the deepest level
                self.downs.append(nn.ModuleDict({"down": Downsample(ch)}))
                skip_chs.append(ch)
        # --- middle (bottleneck) ---
        self.mid_res1 = ResBlock(ch, ch, time_dim, dropout)
        self.mid_attn = SelfAttention(ch)
        self.mid_res2 = ResBlock(ch, ch, time_dim, dropout)
        # --- upsampling path ---
        self.ups = nn.ModuleList()
        for lvl in reversed(range(len(ch_mults))):
            out_ch = base_ch * ch_mults[lvl]
            for _ in range(n_res + 1):  # +1 to consume the downsample skip
                skip_ch = skip_chs.pop()
                self.ups.append(
                    nn.ModuleDict(
                        {
                            "res": ResBlock(ch + skip_ch, out_ch, time_dim, dropout),
                            **(
                                {"attn": SelfAttention(out_ch)}
                                if lvl in attn_levels
                                else {}
                            ),
                        }
                    )
                )
                ch = out_ch
            if lvl > 0:
                self.ups.append(nn.ModuleDict({"up": Upsample(ch)}))
        # --- output projection ---
        self.out_norm = nn.GroupNorm(min(32, ch), ch)
        self.out_conv = nn.Conv2d(ch, out_channels, 3, padding=1)

    def forward(self, x, t, cond=None):
        """
        Args:
            x: noisy target [B, C_out, H, W]
            t: diffusion timestep [B] (int or float)
            cond: condition [B, C_cond, H, W] (optional, concatenated)
        Returns:
            predicted noise [B, C_out, H, W]
        """
        if cond is not None:
            x = torch.cat([x, cond], dim=1)
        t_emb = self.time_embed(t)
        h = self.input_conv(x)
        # --- down: record every activation for the skip connections ---
        skips = [h]
        for block in self.downs:
            if "down" in block:
                h = block["down"](h)
                skips.append(h)
            else:
                h = block["res"](h, t_emb)
                if "attn" in block:
                    h = block["attn"](h)
                skips.append(h)
        # --- middle ---
        h = self.mid_res1(h, t_emb)
        h = self.mid_attn(h)
        h = self.mid_res2(h, t_emb)
        # --- up ---
        for block in self.ups:
            if "up" in block:
                h = block["up"](h)
            else:
                s = skips.pop()
                # Odd spatial sizes round UP through the stride-2 downsample
                # (ceil(n/2)) and then double through the upsample, so h can
                # overshoot the skip by one pixel per axis; resize h to the
                # skip's exact size before concatenating. No-op for even sizes.
                if h.shape[-2:] != s.shape[-2:]:
                    h = F.interpolate(h, size=s.shape[-2:], mode="nearest")
                h = torch.cat([h, s], dim=1)
                h = block["res"](h, t_emb)
                if "attn" in block:
                    h = block["attn"](h)
        h = F.silu(self.out_norm(h))
        return self.out_conv(h)