| """UNet model for DDPM noise prediction. |
| |
| Architecture follows the DDPM paper (Ho et al. 2020) with: |
| - Sinusoidal time-step embeddings |
| - ResNet blocks with group-normalisation |
| - Self-attention at selected resolutions |
| - Skip connections from encoder to decoder |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from config import ModelConfig |
|
|
|
|
| |
| |
| |
|
|
class SinusoidalPositionEmbedding(nn.Module):
    """Transformer-style sinusoidal position encoding for diffusion timesteps.

    The first ``dim // 2`` output features are sines and the next ``dim // 2``
    are cosines of the timestep scaled by geometrically spaced frequencies
    (from 1 down to 1/10000). When ``dim`` is odd, one trailing zero feature
    is appended so the output width is always exactly ``dim``.

    Args:
        dim: Width of the returned embedding.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed timesteps.

        Args:
            t: Tensor of shape (B,) holding the timesteps; it is cast to
                float before use.

        Returns:
            Tensor of shape (B, dim) on the same device as ``t``.
        """
        half_dim = self.dim // 2
        # With a single frequency (dim of 2 or 3) ``half_dim - 1`` is zero;
        # clamp the divisor so the lone frequency becomes exp(0) == 1 instead
        # of raising ZeroDivisionError.
        scale = math.log(10_000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -scale)
        args = t.float()[:, None] * freqs[None, :]
        emb = torch.cat([args.sin(), args.cos()], dim=-1)
        if self.dim % 2 == 1:
            # sin/cos halves only cover 2 * (dim // 2) features; pad the last
            # one with zeros so downstream layers sized for ``dim`` still fit.
            emb = F.pad(emb, (0, 1))
        return emb
|
|
|
|
| |
| |
| |
|
|
class ResBlock(nn.Module):
    """Residual block with time-embedding conditioning.

    Layout: GroupNorm -> SiLU -> 3x3 conv -> add projected time embedding ->
    GroupNorm -> SiLU -> dropout -> 3x3 conv, summed with a shortcut of the
    input (a 1x1 conv when the channel count changes, identity otherwise).

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        time_emb_dim: Width of the time embedding passed to ``forward``.
        dropout: Dropout probability applied before the second convolution.
    """

    def __init__(self, in_ch: int, out_ch: int, time_emb_dim: int, dropout: float = 0.0):
        super().__init__()
        self.norm1 = nn.GroupNorm(self._num_groups(in_ch), in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.time_proj = nn.Linear(time_emb_dim, out_ch)

        self.norm2 = nn.GroupNorm(self._num_groups(out_ch), out_ch)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)

        self.shortcut = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    @staticmethod
    def _num_groups(channels: int, max_groups: int = 32) -> int:
        """Return the largest group count <= ``max_groups`` dividing ``channels``.

        ``nn.GroupNorm`` requires the channel count to be divisible by the
        group count. ``min(32, channels)`` violates that for widths such as
        48 (e.g. a 32-channel feature map concatenated with a 16-channel
        skip). For widths <= 32 or multiples of 32 the result is the same as
        ``min(32, channels)``.
        """
        for groups in range(min(max_groups, channels), 0, -1):
            if channels % groups == 0:
                return groups
        return 1

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        """Apply the block.

        Args:
            x: Feature map of shape (B, in_ch, H, W).
            t_emb: Time embedding of shape (B, time_emb_dim).

        Returns:
            Feature map of shape (B, out_ch, H, W).
        """
        h = self.norm1(x)
        h = F.silu(h)
        h = self.conv1(h)

        # Broadcast the per-sample, per-channel time bias over H and W.
        h = h + self.time_proj(F.silu(t_emb))[:, :, None, None]

        h = self.norm2(h)
        h = F.silu(h)
        h = self.dropout(h)
        h = self.conv2(h)

        return h + self.shortcut(x)
|
|
|
|
class AttentionBlock(nn.Module):
    """Multi-head self-attention with residual connection.

    The feature map is group-normalised, flattened so that every spatial
    position becomes one token of width ``channels``, passed through
    ``nn.MultiheadAttention`` as query, key and value, reshaped back, and
    added to the un-normalised input.

    Args:
        channels: Number of channels of the input (and output) feature map;
            also the attention embedding width.
        num_heads: Number of attention heads handed to
            ``nn.MultiheadAttention``.
    """

    def __init__(self, channels: int, num_heads: int = 4):
        super().__init__()
        self.norm = nn.GroupNorm(self._num_groups(channels), channels)
        self.attn = nn.MultiheadAttention(
            channels, num_heads, batch_first=True
        )

    @staticmethod
    def _num_groups(channels: int, max_groups: int = 32) -> int:
        """Return the largest group count <= ``max_groups`` dividing ``channels``.

        ``nn.GroupNorm`` requires the channel count to be divisible by the
        group count, which ``min(32, channels)`` does not guarantee (e.g. 48
        channels). For widths <= 32 or multiples of 32 the result is the same
        as ``min(32, channels)``.
        """
        for groups in range(min(max_groups, channels), 0, -1):
            if channels % groups == 0:
                return groups
        return 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (B, C, H, W) → (B, C, H, W)"""
        B, C, H, W = x.shape
        # (B, C, H, W) -> (B, H*W, C): one token per spatial location.
        tokens = self.norm(x).reshape(B, C, H * W).permute(0, 2, 1)
        # The attention map itself is never used, so skip computing the
        # head-averaged weights.
        attended, _ = self.attn(tokens, tokens, tokens, need_weights=False)
        attended = attended.permute(0, 2, 1).reshape(B, C, H, W)
        return attended + x
|
|
|
|
| |
| |
| |
|
|
class DownBlock(nn.Module):
    """One encoder stage: ResBlocks + optional attention + down-sample.

    Args:
        in_ch: Channels entering the stage.
        out_ch: Channels produced by every ResBlock of the stage.
        time_emb_dim: Width of the time embedding given to each ResBlock.
        num_res_blocks: How many ResBlock/attention pairs to stack.
        dropout: Dropout probability forwarded to each ResBlock.
        use_attention: If true, each ResBlock is followed by an
            AttentionBlock; otherwise by an identity.
        num_heads: Head count for the attention blocks.
        downsample: If true, finish with a stride-2 3x3 convolution;
            otherwise with an identity.
    """

    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        time_emb_dim: int,
        num_res_blocks: int,
        dropout: float,
        use_attention: bool,
        num_heads: int,
        downsample: bool,
    ):
        super().__init__()
        self.res_blocks = nn.ModuleList()
        self.attentions = nn.ModuleList()

        # Channel width before the first block and after each block;
        # consecutive pairs are the (input, output) widths of the ResBlocks.
        widths = [in_ch] + [out_ch] * num_res_blocks
        for src_ch, dst_ch in zip(widths, widths[1:]):
            self.res_blocks.append(ResBlock(src_ch, dst_ch, time_emb_dim, dropout))
            if use_attention:
                attention = AttentionBlock(dst_ch, num_heads)
            else:
                attention = nn.Identity()
            self.attentions.append(attention)

        final_ch = widths[-1]
        if downsample:
            self.downsample = nn.Conv2d(final_ch, final_ch, 3, stride=2, padding=1)
        else:
            self.downsample = nn.Identity()

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the stage.

        Returns:
            ``(down, skip)`` where ``skip`` is the feature map after the last
            ResBlock/attention pair and ``down`` is ``skip`` passed through
            the down-sampling layer.
        """
        features = x
        for idx, res_block in enumerate(self.res_blocks):
            features = self.attentions[idx](res_block(features, t_emb))
        return self.downsample(features), features
|
|
|
|
class UpBlock(nn.Module):
    """One decoder stage: ResBlocks + optional attention + up-sample.

    Args:
        in_ch: Channels of the incoming decoder feature map.
        out_ch: Channels produced by every ResBlock of the stage.
        skip_ch: Channels of the encoder skip tensor concatenated on entry.
        time_emb_dim: Width of the time embedding given to each ResBlock.
        num_res_blocks: How many ResBlock/attention pairs to stack.
        dropout: Dropout probability forwarded to each ResBlock.
        use_attention: If true, each ResBlock is followed by an
            AttentionBlock; otherwise by an identity.
        num_heads: Head count for the attention blocks.
        upsample: If true, finish with nearest-neighbour x2 up-sampling
            followed by a 3x3 convolution; otherwise with an identity.
    """

    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        skip_ch: int,
        time_emb_dim: int,
        num_res_blocks: int,
        dropout: float,
        use_attention: bool,
        num_heads: int,
        upsample: bool,
    ):
        super().__init__()
        self.res_blocks = nn.ModuleList()
        self.attentions = nn.ModuleList()

        # Only the first ResBlock sees the concatenated skip channels; every
        # later one takes the stage's own output width.
        input_widths = [in_ch + skip_ch] + [out_ch] * (num_res_blocks - 1)
        for block_in in input_widths[: max(num_res_blocks, 0)]:
            self.res_blocks.append(ResBlock(block_in, out_ch, time_emb_dim, dropout))
            if use_attention:
                attention = AttentionBlock(out_ch, num_heads)
            else:
                attention = nn.Identity()
            self.attentions.append(attention)

        # Width after the ResBlocks (unchanged input width if there are none).
        tail_ch = out_ch if num_res_blocks > 0 else in_ch
        if upsample:
            self.upsample = nn.Sequential(
                nn.Upsample(scale_factor=2, mode="nearest"),
                nn.Conv2d(tail_ch, tail_ch, 3, padding=1),
            )
        else:
            self.upsample = nn.Identity()

    def forward(self, x: torch.Tensor, skip: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        """Concatenate ``skip`` onto ``x`` along channels, then run the stage."""
        features = torch.cat((x, skip), dim=1)
        for res_block, attention in zip(self.res_blocks, self.attentions):
            features = attention(res_block(features, t_emb))
        return self.upsample(features)
|
|
|
|
| |
| |
| |
|
|
class UNet(nn.Module):
    """DDPM U-Net for ε-prediction on 32×32 CIFAR-10.

    The network is assembled from ``config``, reading these attributes:
    ``base_channels``, ``channel_multipliers``, ``in_channels``,
    ``out_channels``, ``image_size``, ``attention_resolutions``,
    ``num_res_blocks``, ``num_heads`` and ``dropout``.

    Level ``i`` has ``base_channels * channel_multipliers[i]`` channels and
    nominal resolution ``image_size // 2**i``. Attention is enabled on a
    level when that resolution is listed in ``attention_resolutions``. Every
    encoder level except the last down-samples, and every decoder level
    except the last up-samples.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()

        # Time embedding: sinusoidal features followed by a two-layer MLP.
        time_emb_dim = config.base_channels * 4
        self.time_embed = nn.Sequential(
            SinusoidalPositionEmbedding(config.base_channels),
            nn.Linear(config.base_channels, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )

        # Stem convolution lifting the image to the base width.
        self.init_conv = nn.Conv2d(config.in_channels, config.base_channels, 3, padding=1)

        # Per-level channel widths and nominal spatial resolutions.
        chs = [config.base_channels * m for m in config.channel_multipliers]
        resolutions = [config.image_size // (2 ** i) for i in range(len(chs))]
        last_level = len(chs) - 1

        # Encoder.
        self.down_blocks = nn.ModuleList()
        in_ch = config.base_channels
        for i, (out_ch, res) in enumerate(zip(chs, resolutions)):
            self.down_blocks.append(
                DownBlock(
                    in_ch=in_ch,
                    out_ch=out_ch,
                    time_emb_dim=time_emb_dim,
                    num_res_blocks=config.num_res_blocks,
                    dropout=config.dropout,
                    use_attention=res in config.attention_resolutions,
                    num_heads=config.num_heads,
                    downsample=i != last_level,
                )
            )
            in_ch = out_ch

        # Bottleneck.
        mid_ch = chs[-1]
        self.mid_block1 = ResBlock(mid_ch, mid_ch, time_emb_dim, config.dropout)
        self.mid_attn = AttentionBlock(mid_ch, config.num_heads)
        self.mid_block2 = ResBlock(mid_ch, mid_ch, time_emb_dim, config.dropout)

        # Decoder: walks the levels in reverse. The skip tensor of a level
        # has that level's width, so skip_ch equals out_ch.
        self.up_blocks = nn.ModuleList()
        in_ch = mid_ch
        for i, (out_ch, res) in enumerate(zip(reversed(chs), reversed(resolutions))):
            self.up_blocks.append(
                UpBlock(
                    in_ch=in_ch,
                    out_ch=out_ch,
                    skip_ch=out_ch,
                    time_emb_dim=time_emb_dim,
                    num_res_blocks=config.num_res_blocks,
                    dropout=config.dropout,
                    use_attention=res in config.attention_resolutions,
                    num_heads=config.num_heads,
                    upsample=i != last_level,
                )
            )
            in_ch = out_ch

        # Output head.
        final_ch = chs[0]
        self.out_norm = nn.GroupNorm(self._num_groups(final_ch), final_ch)
        self.out_conv = nn.Conv2d(final_ch, config.out_channels, 3, padding=1)

    @staticmethod
    def _num_groups(channels: int, max_groups: int = 32) -> int:
        """Return the largest group count <= ``max_groups`` dividing ``channels``.

        ``nn.GroupNorm`` requires the channel count to be divisible by the
        group count, which ``min(32, channels)`` does not guarantee (e.g. a
        48-channel first level). For widths <= 32 or multiples of 32 the
        result is the same as ``min(32, channels)``.
        """
        for groups in range(min(max_groups, channels), 0, -1):
            if channels % groups == 0:
                return groups
        return 1

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Predict the noise for a batch of noisy images.

        Args:
            x: Images of shape (B, in_channels, H, W). Each down-sampling
                level halves H and W (rounding up) and each up-sampling level
                doubles them, so H and W must be divisible by
                ``2 ** (num_levels - 1)`` for the skip tensors to line up.
            t: Timesteps of shape (B,).

        Returns:
            Tensor of shape (B, out_channels, H, W).
        """
        t_emb = self.time_embed(t)

        h = self.init_conv(x)

        # Encoder: keep one skip tensor per level (taken before down-sampling).
        skips: list[torch.Tensor] = []
        for down in self.down_blocks:
            h, skip = down(h, t_emb)
            skips.append(skip)

        # Bottleneck.
        h = self.mid_block1(h, t_emb)
        h = self.mid_attn(h)
        h = self.mid_block2(h, t_emb)

        # Decoder: consume the skips deepest-first.
        for up in self.up_blocks:
            h = up(h, skips.pop(), t_emb)

        # Output head.
        h = self.out_norm(h)
        h = F.silu(h)
        return self.out_conv(h)
|
|