# xray_generator/models/unet.py

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


def timestep_embedding(timesteps, dim, max_period=10000):
    """Create sinusoidal timestep embeddings."""
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
        / half
    )
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


class TimeEmbedding(nn.Module):
    """Time embedding module for diffusion models."""

    def __init__(self, dim, dim_out=None):
        """Initialize time embedding."""
        super().__init__()
        if dim_out is None:
            dim_out = dim
        self.dim = dim

        # Linear layers for time embedding
        self.main = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim_out)
        )

    def forward(self, time):
        """Forward pass through time embedding."""
        time_emb = timestep_embedding(time, self.dim)
        return self.main(time_emb)


class SelfAttention(nn.Module):
    """Self-attention module for VAE and UNet."""

    def __init__(self, channels, num_heads=8):
        """Initialize self-attention module."""
        super().__init__()
        assert channels % num_heads == 0, (
            f"channels ({channels}) must be divisible by num_heads ({num_heads})"
        )

        self.num_heads = num_heads
        self.head_dim = channels // num_heads
        self.scale = self.head_dim ** -0.5

        # QKV projection
        self.to_qkv = nn.Conv2d(channels, channels * 3, 1, bias=False)
        self.to_out = nn.Conv2d(channels, channels, 1)

        # Normalization
        self.norm = nn.GroupNorm(8, channels)

    def forward(self, x):
        """Forward pass through self-attention."""
        b, c, h, w = x.shape

        # Apply normalization
        x_norm = self.norm(x)

        # Get QKV
        qkv = self.to_qkv(x_norm).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h=self.num_heads),
            qkv
        )

        # Attention
        attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = attn.softmax(dim=-1)

        # Combine
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w)

        # Project to output
        out = self.to_out(out)

        # Add residual
        return out + x


class CrossAttention(nn.Module):
    """Cross-attention module for conditioning on text."""

    def __init__(self, channels, text_dim, num_heads=8):
        """Initialize cross-attention module."""
        super().__init__()
        assert channels % num_heads == 0, (
            f"channels ({channels}) must be divisible by num_heads ({num_heads})"
        )

        self.num_heads = num_heads
        self.head_dim = channels // num_heads
        self.scale = self.head_dim ** -0.5

        # Query from image features
        self.to_q = nn.Conv2d(channels, channels, 1, bias=False)

        # Key and value from text
        self.to_k = nn.Linear(text_dim, channels, bias=False)
        self.to_v = nn.Linear(text_dim, channels, bias=False)

        self.to_out = nn.Conv2d(channels, channels, 1)

        # Normalization
        self.norm = nn.GroupNorm(8, channels)

    def forward(self, x, context):
        """Forward pass through cross-attention."""
        b, c, h, w = x.shape

        # Apply normalization
        x_norm = self.norm(x)

        # Get query from image features
        q = self.to_q(x_norm)
        q = rearrange(q, 'b c h w -> b (h w) c')
        q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads)

        # Get key and value from text context
        k = self.to_k(context)
        v = self.to_v(context)
        k = rearrange(k, 'b n (h d) -> b h n d', h=self.num_heads)
        v = rearrange(v, 'b n (h d) -> b h n d', h=self.num_heads)

        # Attention
        attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = attn.softmax(dim=-1)

        # Combine
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w)

        # Project to output
        out = self.to_out(out)

        # Add residual
        return out + x


class ResnetBlock(nn.Module):
    """Residual block with time embedding and optional attention."""

    def __init__(
        self,
        in_channels,
        out_channels,
        time_channels,
        dropout=0.0,
        use_attention=False,
        attention_type="self",
        text_dim=None
    ):
        """Initialize residual block."""
        super().__init__()

        # First convolution block
        self.block1 = nn.Sequential(
            nn.GroupNorm(8, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, 3, padding=1)
        )

        # Time embedding
        self.time_emb = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_channels, out_channels)
        )

        # Second convolution block
        self.block2 = nn.Sequential(
            nn.GroupNorm(8, out_channels),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_channels, out_channels, 3, padding=1)
        )

        # Attention
        self.use_attention = use_attention
        if use_attention:
            if attention_type == "self":
                self.attention = SelfAttention(out_channels)
            elif attention_type == "cross":
                assert text_dim is not None, "Text dimension required for cross-attention"
                self.attention = CrossAttention(out_channels, text_dim)
            else:
                raise ValueError(f"Unknown attention type: {attention_type}")

        # Shortcut connection
        self.shortcut = (
            nn.Conv2d(in_channels, out_channels, 1)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x, time_emb, context=None):
        """Forward pass through residual block."""
        # Shortcut
        shortcut = self.shortcut(x)

        # Block 1
        h = self.block1(x)

        # Add time embedding
        h += self.time_emb(time_emb)[:, :, None, None]

        # Block 2
        h = self.block2(h)

        # Apply attention
        if self.use_attention:
            if isinstance(self.attention, CrossAttention) and context is not None:
                h = self.attention(h, context)
            else:
                h = self.attention(h)

        # Add shortcut
        return h + shortcut


class Downsample(nn.Module):
    """Downsampling layer for UNet."""

    def __init__(self, channels, use_conv=True):
        """Initialize downsampling layer."""
        super().__init__()
        if use_conv:
            self.downsample = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
        else:
            self.downsample = nn.AvgPool2d(2, stride=2)

    def forward(self, x):
        """Forward pass through downsampling layer."""
        return self.downsample(x)


class Upsample(nn.Module):
    """Upsampling layer for UNet."""

    def __init__(self, channels, use_conv=True):
        """Initialize upsampling layer."""
        super().__init__()
        self.upsample = nn.ConvTranspose2d(channels, channels, 4, stride=2, padding=1)
        self.use_conv = use_conv
        if use_conv:
            self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        """Forward pass through upsampling layer."""
        x = self.upsample(x)
        if self.use_conv:
            x = self.conv(x)
        return x


class DiffusionUNet(nn.Module):
    """UNet model for diffusion process with cross-attention for text conditioning."""

    def __init__(
        self,
        in_channels=4,
        model_channels=64,
        out_channels=4,
        num_res_blocks=2,
        attention_resolutions=(8, 16, 32),
        dropout=0.0,
        channel_mult=(1, 2, 4, 8),
        context_dim=768
    ):
        """Initialize UNet model."""
        super().__init__()

        # Parameters
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.context_dim = context_dim

        # Time embedding
        time_embed_dim = model_channels * 4
        self.time_embed = TimeEmbedding(model_channels, time_embed_dim)

        # Input block
        self.input_blocks = nn.ModuleList([
            nn.Conv2d(in_channels, model_channels, 3, padding=1)
        ])

        # Keep track of channels for skip connections
        input_block_channels = [model_channels]
        ch = model_channels
        ds = 1  # Downsampling factor

        # Downsampling blocks
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                # Use cross-attention if at an attention resolution
                use_attention = ds in attention_resolutions

                # Create block
                block = ResnetBlock(
                    ch,
                    model_channels * mult,
                    time_embed_dim,
                    dropout,
                    use_attention,
                    "cross" if use_attention else None,
                    context_dim if use_attention else None
                )

                # Add to input blocks
                self.input_blocks.append(block)

                # Update channels
                ch = model_channels * mult
                input_block_channels.append(ch)

            # Add downsampling except for last level
            if level != len(channel_mult) - 1:
                self.input_blocks.append(Downsample(ch))
                input_block_channels.append(ch)
                ds *= 2

        # Middle blocks (bottleneck) with cross-attention
        self.middle_block = nn.ModuleList([
            ResnetBlock(
                ch, ch, time_embed_dim, dropout, True, "cross", context_dim
            ),
            ResnetBlock(
                ch, ch, time_embed_dim, dropout, False
            )
        ])

        # Upsampling blocks
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                # Combine with skip connection
                skip_ch = input_block_channels.pop()

                # Use cross-attention if at an attention resolution
                use_attention = ds in attention_resolutions

                # Create block
                block = ResnetBlock(
                    ch + skip_ch,
                    model_channels * mult,
                    time_embed_dim,
                    dropout,
                    use_attention,
                    "cross" if use_attention else None,
                    context_dim if use_attention else None
                )

                # Add to output blocks
                self.output_blocks.append(block)

                # Update channels
                ch = model_channels * mult

                # Add upsampling except for last block of last level
                if level != 0 and i == num_res_blocks:
                    self.output_blocks.append(Upsample(ch))
                    ds //= 2

        # Final layers
        self.out = nn.Sequential(
            nn.GroupNorm(8, ch),
            nn.SiLU(),
            nn.Conv2d(ch, out_channels, 3, padding=1)
        )

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize weights."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x, timesteps, context=None):
        """Forward pass through UNet."""
        # Time embedding
        t_emb = self.time_embed(timesteps)

        # Input blocks (downsampling)
        h = x
        hs = [h]  # Store intermediate activations for skip connections
        for module in self.input_blocks:
            if isinstance(module, ResnetBlock):
                h = module(h, t_emb, context)
            else:
                h = module(h)
            hs.append(h)

        # Middle block
        for module in self.middle_block:
            h = module(h, t_emb, context) if isinstance(module, ResnetBlock) else module(h)

        # Output blocks (upsampling)
        for module in self.output_blocks:
            if isinstance(module, ResnetBlock):
                # Add skip connection
                h = torch.cat([h, hs.pop()], dim=1)
                h = module(h, t_emb, context)
            else:
                h = module(h)

        # Final output
        return self.out(h)
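

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline).
# It assumes the default configuration above: 4-channel latents, a 768-dim
# text-encoder context (e.g. CLIP/BERT hidden states), and a 32x32 latent.
# The batch size, sequence length, and latent resolution below are arbitrary
# example values; running the file directly does a single forward pass as a
# shape sanity check.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = DiffusionUNet(
        in_channels=4,
        model_channels=64,
        out_channels=4,
        context_dim=768,
    )

    batch = 2
    latents = torch.randn(batch, 4, 32, 32)        # noisy latents x_t
    timesteps = torch.randint(0, 1000, (batch,))   # diffusion timesteps
    context = torch.randn(batch, 77, 768)          # text-encoder hidden states

    with torch.no_grad():
        noise_pred = model(latents, timesteps, context)

    # Output should match the input latent shape: (batch, 4, 32, 32)
    print(noise_pred.shape)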