| |
|
| | import math
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | from einops import rearrange
|
| |
|
def timestep_embedding(timesteps, dim, max_period=10000):
    """Build sinusoidal embeddings for diffusion timesteps.

    Args:
        timesteps: 1-D tensor of timestep values, one per batch element.
        dim: Width of the returned embedding.
        max_period: Controls the lowest sinusoid frequency.

    Returns:
        Tensor of shape ``(len(timesteps), dim)`` with cosine features in the
        first half and sine features in the second half.
    """
    half = dim // 2
    exponent = torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
    freqs = torch.exp(-math.log(max_period) * exponent / half)
    angles = timesteps.float()[:, None] * freqs[None]
    embedding = torch.cat((torch.cos(angles), torch.sin(angles)), dim=-1)
    if dim % 2 == 1:
        # Pad a zero column so the output width is exactly `dim`.
        pad = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat((embedding, pad), dim=-1)
    return embedding
|
| |
|
class TimeEmbedding(nn.Module):
    """Maps scalar diffusion timesteps to a learned embedding vector.

    Sinusoidal features of width `dim` are passed through a two-layer MLP
    (4x hidden expansion, SiLU activation) producing `dim_out` features.
    """

    def __init__(self, dim, dim_out=None):
        """Set up the sinusoidal width and the projection MLP.

        Args:
            dim: Width of the sinusoidal features (and MLP input).
            dim_out: Output width; defaults to `dim` when omitted.
        """
        super().__init__()
        dim_out = dim if dim_out is None else dim_out

        self.dim = dim
        hidden = dim * 4
        self.main = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, dim_out),
        )

    def forward(self, time):
        """Embed `time` sinusoidally, then project through the MLP."""
        return self.main(timestep_embedding(time, self.dim))
|
| |
|
class SelfAttention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    Applies GroupNorm, computes scaled dot-product attention across all
    H*W positions, projects the result, and adds it back to the input
    (residual connection).
    """

    def __init__(self, channels, num_heads=8):
        """Initialize projections and normalization.

        Args:
            channels: Number of input/output feature channels.
            num_heads: Number of attention heads; must divide `channels`.
        """
        super().__init__()
        # Fixed: the original message was an f-string with no placeholders
        # (lint F541) and gave no diagnostic detail.
        assert channels % num_heads == 0, (
            f"channels ({channels}) must be divisible by num_heads ({num_heads})"
        )

        self.num_heads = num_heads
        self.head_dim = channels // num_heads
        self.scale = self.head_dim ** -0.5

        # A single 1x1 conv produces Q, K and V in one pass.
        self.to_qkv = nn.Conv2d(channels, channels * 3, 1, bias=False)
        self.to_out = nn.Conv2d(channels, channels, 1)

        self.norm = nn.GroupNorm(8, channels)

    def forward(self, x):
        """Attend over spatial positions; returns a tensor shaped like `x`."""
        b, c, h, w = x.shape

        x_norm = self.norm(x)

        qkv = self.to_qkv(x_norm).chunk(3, dim=1)
        # (b, c, h, w) -> (b, heads, h*w, head_dim); native ops replace the
        # einops rearrange 'b (h d) x y -> b h (x y) d' with identical layout.
        q, k, v = (
            t.reshape(b, self.num_heads, self.head_dim, h * w).transpose(-1, -2)
            for t in qkv
        )

        attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = attn.softmax(dim=-1)

        out = torch.matmul(attn, v)
        # (b, heads, h*w, head_dim) -> (b, c, h, w)
        out = out.transpose(-1, -2).reshape(b, c, h, w)

        out = self.to_out(out)

        # Residual connection around the attention block.
        return out + x
|
| |
|
class CrossAttention(nn.Module):
    """Multi-head cross-attention from image features to a text context.

    Queries come from the (normalized) feature map; keys and values come
    from the text-embedding sequence. The attended output is projected and
    added back to the input (residual connection).
    """

    def __init__(self, channels, text_dim, num_heads=8):
        """Initialize projections and normalization.

        Args:
            channels: Number of image feature channels.
            text_dim: Width of the text-context embeddings.
            num_heads: Number of attention heads; must divide `channels`.
        """
        super().__init__()
        # Fixed: the original message was an f-string with no placeholders
        # (lint F541) and gave no diagnostic detail.
        assert channels % num_heads == 0, (
            f"channels ({channels}) must be divisible by num_heads ({num_heads})"
        )

        self.num_heads = num_heads
        self.head_dim = channels // num_heads
        self.scale = self.head_dim ** -0.5

        # Queries from pixels (1x1 conv); keys/values from text tokens (linear).
        self.to_q = nn.Conv2d(channels, channels, 1, bias=False)
        self.to_k = nn.Linear(text_dim, channels, bias=False)
        self.to_v = nn.Linear(text_dim, channels, bias=False)

        self.to_out = nn.Conv2d(channels, channels, 1)

        self.norm = nn.GroupNorm(8, channels)

    def forward(self, x, context):
        """Attend from spatial positions of `x` to tokens in `context`.

        Args:
            x: Feature map of shape (b, c, h, w).
            context: Text embeddings of shape (b, tokens, text_dim).

        Returns:
            Tensor shaped like `x`.
        """
        b, c, h, w = x.shape

        x_norm = self.norm(x)

        # Queries: (b, c, h, w) -> (b, heads, h*w, head_dim); native ops
        # replacing the einops rearrange pair with an identical layout.
        q = self.to_q(x_norm)
        q = q.reshape(b, self.num_heads, self.head_dim, h * w).transpose(-1, -2)

        # Keys/values: (b, n, text_dim) -> (b, heads, n, head_dim).
        k = self.to_k(context)
        v = self.to_v(context)
        k = k.reshape(b, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(b, -1, self.num_heads, self.head_dim).transpose(1, 2)

        attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = attn.softmax(dim=-1)

        out = torch.matmul(attn, v)
        # (b, heads, h*w, head_dim) -> (b, c, h, w)
        out = out.transpose(-1, -2).reshape(b, c, h, w)

        out = self.to_out(out)

        # Residual connection around the attention block.
        return out + x
|
| |
|
class ResnetBlock(nn.Module):
    """GroupNorm/SiLU/Conv residual block with additive time-embedding injection.

    Optionally applies self- or cross-attention after the second conv stack.
    The residual shortcut is a 1x1 conv when in/out channel counts differ,
    otherwise an identity.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        time_channels,
        dropout=0.0,
        use_attention=False,
        attention_type="self",
        text_dim=None
    ):
        """Build the conv stacks, time projection, optional attention and shortcut.

        Args:
            in_channels: Channels of the incoming feature map.
            out_channels: Channels produced by the block.
            time_channels: Width of the time-embedding vector.
            dropout: Dropout probability inside the second conv stack.
            use_attention: Whether to append an attention layer.
            attention_type: "self" or "cross" (only read when use_attention).
            text_dim: Context width, required for cross-attention.
        """
        super().__init__()

        self.block1 = nn.Sequential(
            nn.GroupNorm(8, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
        )

        # Projects the time embedding to a per-channel additive bias.
        self.time_emb = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_channels, out_channels),
        )

        self.block2 = nn.Sequential(
            nn.GroupNorm(8, out_channels),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
        )

        self.use_attention = use_attention
        if use_attention:
            if attention_type == "self":
                self.attention = SelfAttention(out_channels)
            elif attention_type == "cross":
                assert text_dim is not None, "Text dimension required for cross-attention"
                self.attention = CrossAttention(out_channels, text_dim)
            else:
                raise ValueError(f"Unknown attention type: {attention_type}")

        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x, time_emb, context=None):
        """Run the block; `context` is consumed only by cross-attention."""
        residual = self.shortcut(x)

        out = self.block1(x)

        # Broadcast the projected time embedding over the spatial dims.
        out = out + self.time_emb(time_emb)[:, :, None, None]

        out = self.block2(out)

        if self.use_attention:
            if context is not None and isinstance(self.attention, CrossAttention):
                out = self.attention(out, context)
            else:
                out = self.attention(out)

        return out + residual
|
| |
|
class Downsample(nn.Module):
    """Halves spatial resolution via a strided conv or average pooling."""

    def __init__(self, channels, use_conv=True):
        """Choose a learnable strided conv (default) or fixed 2x2 avg-pooling."""
        super().__init__()
        if use_conv:
            op = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
        else:
            op = nn.AvgPool2d(2, stride=2)
        self.downsample = op

    def forward(self, x):
        """Return `x` downsampled by a factor of 2 in each spatial dim."""
        return self.downsample(x)
|
| |
|
class Upsample(nn.Module):
    """Doubles spatial resolution with a transposed conv, optionally refined.

    A ConvTranspose2d (kernel 4, stride 2) always performs the 2x upsampling
    — it is created regardless of `use_conv`; when `use_conv` is True a 3x3
    conv is applied afterwards to refine the result.
    """

    def __init__(self, channels, use_conv=True):
        """Create the transposed conv and, if requested, the refinement conv."""
        super().__init__()
        self.upsample = nn.ConvTranspose2d(channels, channels, 4, stride=2, padding=1)
        self.use_conv = use_conv
        if use_conv:
            self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        """Return `x` upsampled by a factor of 2 in each spatial dim."""
        out = self.upsample(x)
        return self.conv(out) if self.use_conv else out
|
| |
|
class DiffusionUNet(nn.Module):
    """UNet model for diffusion process with cross-attention for text conditioning.

    Structure: an input conv, an encoder of ResnetBlocks with Downsamples
    between levels, a two-block middle, and a decoder of ResnetBlocks with
    Upsamples that consumes skip connections from every encoder block.
    Cross-attention on the text `context` is inserted wherever the running
    downsample factor appears in `attention_resolutions`.
    """

    def __init__(
        self,
        in_channels=4,
        model_channels=64,
        out_channels=4,
        num_res_blocks=2,
        attention_resolutions=(8, 16, 32),
        dropout=0.0,
        channel_mult=(1, 2, 4, 8),
        context_dim=768
    ):
        """Initialize UNet model.

        Args:
            in_channels: Channels of the input tensor.
            model_channels: Base width; per-level widths are
                model_channels * channel_mult[level].
            out_channels: Channels of the predicted output.
            num_res_blocks: ResnetBlocks per resolution level.
            attention_resolutions: Downsample factors at which cross-attention
                is enabled. NOTE(review): with the default 4 levels the factor
                only reaches 8 (ds doubles 3 times from 1), so entries 16 and
                32 can never trigger — confirm this is intended.
            dropout: Dropout probability inside each ResnetBlock.
            channel_mult: Channel multiplier per resolution level.
            context_dim: Width of the text-embedding context vectors.
        """
        super().__init__()

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.context_dim = context_dim

        # Time embedding is expanded to 4x the base width.
        time_embed_dim = model_channels * 4
        self.time_embed = TimeEmbedding(model_channels, time_embed_dim)

        # Encoder starts with a conv lifting the input to model_channels.
        self.input_blocks = nn.ModuleList([
            nn.Conv2d(in_channels, model_channels, 3, padding=1)
        ])

        # Per-module output channel counts, consumed later by the decoder's
        # skip connections (popped in reverse order).
        input_block_channels = [model_channels]
        ch = model_channels
        ds = 1  # running downsample factor; doubles after each Downsample

        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                # Cross-attention only at the configured downsample factors.
                use_attention = ds in attention_resolutions

                block = ResnetBlock(
                    ch,
                    model_channels * mult,
                    time_embed_dim,
                    dropout,
                    use_attention,
                    "cross" if use_attention else None,
                    context_dim if use_attention else None
                )

                self.input_blocks.append(block)

                ch = model_channels * mult
                input_block_channels.append(ch)

            # Downsample between levels, but not after the deepest one.
            if level != len(channel_mult) - 1:
                self.input_blocks.append(Downsample(ch))
                input_block_channels.append(ch)
                ds *= 2

        # Bottleneck: a cross-attention block followed by a plain block.
        self.middle_block = nn.ModuleList([
            ResnetBlock(
                ch, ch, time_embed_dim, dropout, True, "cross", context_dim
            ),
            ResnetBlock(
                ch, ch, time_embed_dim, dropout, False
            )
        ])

        # Decoder mirrors the encoder; each ResnetBlock consumes one skip.
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                # Channels of the skip tensor concatenated at this block.
                skip_ch = input_block_channels.pop()

                use_attention = ds in attention_resolutions

                block = ResnetBlock(
                    ch + skip_ch,
                    model_channels * mult,
                    time_embed_dim,
                    dropout,
                    use_attention,
                    "cross" if use_attention else None,
                    context_dim if use_attention else None
                )

                self.output_blocks.append(block)

                ch = model_channels * mult

                # Upsample after the extra (num_res_blocks-th) block of each
                # level except the final full-resolution one.
                if level != 0 and i == num_res_blocks:
                    self.output_blocks.append(Upsample(ch))
                    ds //= 2

        # Final normalization and projection to out_channels.
        self.out = nn.Sequential(
            nn.GroupNorm(8, ch),
            nn.SiLU(),
            nn.Conv2d(ch, out_channels, 3, padding=1)
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-initialize Conv2d/Linear weights and zero their biases.

        NOTE(review): ConvTranspose2d layers (inside Upsample) are left at
        PyTorch's default initialization — confirm that is intentional.
        """
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x, timesteps, context=None):
        """Run the UNet on noisy input `x` at the given diffusion timesteps.

        Args:
            x: Input tensor of shape (batch, in_channels, H, W).
            timesteps: 1-D tensor of timestep values, one per batch item.
            context: Optional text embeddings (batch, tokens, context_dim)
                consumed by the cross-attention blocks.

        Returns:
            Tensor of shape (batch, out_channels, H, W).
        """
        t_emb = self.time_embed(timesteps)

        h = x
        # Skip stack. The raw input is pushed first but is never popped by the
        # decoder (the pop count matches the number of encoder modules).
        hs = [h]

        for module in self.input_blocks:
            if isinstance(module, ResnetBlock):
                h = module(h, t_emb, context)
            else:
                h = module(h)
            hs.append(h)

        for module in self.middle_block:
            h = module(h, t_emb, context) if isinstance(module, ResnetBlock) else module(h)

        for module in self.output_blocks:
            if isinstance(module, ResnetBlock):
                # Concatenate the matching encoder activation channel-wise
                # before running the block.
                h = torch.cat([h, hs.pop()], dim=1)
                h = module(h, t_emb, context)
            else:
                h = module(h)

        return self.out(h)