pyamy's picture
Upload 31 files
0a0f923 verified
# xray_generator/models/vae.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from .unet import SelfAttention
class VAEEncoder(nn.Module):
    """Convolutional VAE encoder with optional self-attention.

    The input is projected to ``hidden_dims[0]`` channels, then passed through
    ``len(hidden_dims) - 1`` downsampling blocks (GroupNorm -> SiLU -> stride-2
    3x3 conv), each optionally preceded by a ``SelfAttention`` layer. A final
    GroupNorm -> SiLU -> 3x3 conv produces ``2 * latent_channels`` channels,
    which are split into ``mu`` and ``logvar``.
    """

    def __init__(
        self,
        in_channels=1,
        latent_channels=4,
        hidden_dims=(64, 128, 256, 512),
        attention_resolutions=(32, 16),
        image_size=256,
    ):
        """Build the encoder.

        Args:
            in_channels: Number of channels of the input image.
            latent_channels: Number of channels of each of ``mu`` and ``logvar``.
            hidden_dims: Channel widths; block ``i`` maps ``hidden_dims[i]`` to
                ``hidden_dims[i + 1]``. Every entry must be divisible by 8
                because GroupNorm uses 8 groups.
            attention_resolutions: Nominal spatial sizes at which a
                ``SelfAttention`` layer is inserted before a downsampling block.
            image_size: Assumed input height/width. It is used only to compute
                each block's nominal resolution for ``attention_resolutions``.
                It defaults to 256, the value that was previously hard-coded.
        """
        super().__init__()

        # Lift the input to the first hidden width without changing spatial size.
        self.conv_in = nn.Conv2d(in_channels, hidden_dims[0], 3, padding=1)

        self.down_blocks = nn.ModuleList()
        pairs = zip(hidden_dims[:-1], hidden_dims[1:])
        for i, (in_dim, out_dim) in enumerate(pairs):
            # Nominal spatial size of this block's input; each block halves it.
            # NOTE(review): with the defaults (image_size=256, 4 hidden dims),
            # the nominal sizes are 256/128/64, so attention_resolutions=(32, 16)
            # never matches and no attention is added here.
            resolution = image_size // (2 ** i)

            layers = []
            if resolution in attention_resolutions:
                layers.append(SelfAttention(in_dim))
            layers.append(nn.Sequential(
                nn.GroupNorm(8, in_dim),
                nn.SiLU(),
                nn.Conv2d(in_dim, out_dim, 3, stride=2, padding=1),
            ))
            self.down_blocks.append(nn.Sequential(*layers))

        # Project to mean and log-variance stacked along the channel axis.
        self.final = nn.Sequential(
            nn.GroupNorm(8, hidden_dims[-1]),
            nn.SiLU(),
            nn.Conv2d(hidden_dims[-1], latent_channels * 2, 3, padding=1),
        )

        # Module.apply visits every submodule, including any SelfAttention
        # internals, so their Conv2d/Linear layers are re-initialized here too.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Kaiming-normal init for Conv2d/Linear weights and zero biases.

        NOTE(review): the gain is computed for ReLU although the blocks use
        SiLU; this is kept as-is.
        """
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Encode ``x`` into posterior parameters.

        Args:
            x: Image batch, presumably ``(B, in_channels, H, W)``. TODO: confirm.

        Returns:
            Tuple ``(mu, logvar)``. Each has ``latent_channels`` channels, and
            the spatial size is halved (rounded up) once per downsampling block.
        """
        h = self.conv_in(x)
        for block in self.down_blocks:
            h = block(h)
        h = self.final(h)
        # The first half of the channels is mu, the second half is logvar.
        mu, logvar = torch.chunk(h, 2, dim=1)
        return mu, logvar
class VAEDecoder(nn.Module):
    """Convolutional VAE decoder with optional self-attention.

    The latent is projected to ``hidden_dims[0]`` channels, then passed through
    ``len(hidden_dims) - 1`` upsampling blocks (GroupNorm -> SiLU -> 4x4
    stride-2 transposed conv, which exactly doubles H and W), each optionally
    preceded by a ``SelfAttention`` layer. A final GroupNorm -> SiLU -> 3x3 conv
    maps to ``out_channels``, with no output activation.
    """

    def __init__(
        self,
        latent_channels=4,
        out_channels=1,
        hidden_dims=(512, 256, 128, 64),
        attention_resolutions=(16, 32),
        latent_resolution=16,
    ):
        """Build the decoder.

        Args:
            latent_channels: Number of channels of the latent input.
            out_channels: Number of channels of the reconstructed image.
            hidden_dims: Channel widths; block ``i`` maps ``hidden_dims[i]`` to
                ``hidden_dims[i + 1]``. Every entry must be divisible by 8
                because GroupNorm uses 8 groups.
            attention_resolutions: Nominal spatial sizes at which a
                ``SelfAttention`` layer is inserted before an upsampling block.
            latent_resolution: Nominal spatial size of the latent, used only to
                compute each block's nominal resolution. It defaults to 16, the
                value that was previously hard-coded.

                NOTE(review): the default encoder (256 input, three stride-2
                convs) yields a 32x32 latent, so with this default the nominal
                sizes are one octave below the actual feature maps. The default
                is kept so the layer layout and checkpoints stay compatible.
        """
        super().__init__()

        # Lift the latent to the first hidden width without changing spatial size.
        self.conv_in = nn.Conv2d(latent_channels, hidden_dims[0], 3, padding=1)

        self.up_blocks = nn.ModuleList()
        pairs = zip(hidden_dims[:-1], hidden_dims[1:])
        for i, (in_dim, out_dim) in enumerate(pairs):
            # Nominal spatial size of this block's input; each block doubles it.
            resolution = latent_resolution * (2 ** i)

            layers = []
            if resolution in attention_resolutions:
                layers.append(SelfAttention(in_dim))
            layers.append(nn.Sequential(
                nn.GroupNorm(8, in_dim),
                nn.SiLU(),
                # k=4, s=2, p=1 gives an output size of exactly 2 * input size.
                nn.ConvTranspose2d(in_dim, out_dim, 4, stride=2, padding=1),
            ))
            self.up_blocks.append(nn.Sequential(*layers))

        self.final = nn.Sequential(
            nn.GroupNorm(8, hidden_dims[-1]),
            nn.SiLU(),
            nn.Conv2d(hidden_dims[-1], out_channels, 3, padding=1),
        )

        # Module.apply visits every submodule, including any SelfAttention
        # internals, so their Conv/Linear layers are re-initialized here too.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Kaiming-normal init for (transposed) conv/Linear weights and zero biases.

        NOTE(review): the gain is computed for ReLU although the blocks use
        SiLU; this is kept as-is.
        """
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Decode latent ``x`` into an image.

        Args:
            x: Latent batch, presumably ``(B, latent_channels, h, w)``.
                TODO: confirm.

        Returns:
            Tensor with ``out_channels`` channels, whose spatial size is doubled
            once per upsampling block.
        """
        h = self.conv_in(x)
        for block in self.up_blocks:
            h = block(h)
        return self.final(h)
class MedicalVAE(nn.Module):
    """Encoder/decoder VAE for medical images.

    ``forward`` encodes to ``(mu, logvar)``, samples ``z`` with the
    reparameterization trick, and decodes ``z`` back to image space.
    """

    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        latent_channels=4,
        hidden_dims=(64, 128, 256, 512),
        attention_resolutions=(16, 32),
    ):
        """Build the encoder and the mirrored decoder.

        Args:
            in_channels: Number of channels of the input image.
            out_channels: Number of channels of the reconstruction.
            latent_channels: Number of channels of the latent.
            hidden_dims: Encoder channel widths; the decoder uses them reversed.
            attention_resolutions: Passed unchanged to both encoder and decoder.
                NOTE(review): each side computes its own nominal resolutions,
                so with the default sub-module settings these values only
                trigger attention in the decoder. Confirm that this is intended.
        """
        super().__init__()

        self.encoder = VAEEncoder(
            in_channels=in_channels,
            latent_channels=latent_channels,
            hidden_dims=hidden_dims,
            attention_resolutions=attention_resolutions,
        )
        # The decoder mirrors the encoder's channel progression.
        self.decoder = VAEDecoder(
            latent_channels=latent_channels,
            out_channels=out_channels,
            hidden_dims=list(reversed(hidden_dims)),
            attention_resolutions=attention_resolutions,
        )

        self.latent_channels = latent_channels

    def encode(self, x):
        """Return ``(mu, logvar)`` for input ``x``."""
        return self.encoder(x)

    def decode(self, z):
        """Decode latent ``z`` to image space."""
        return self.decoder(z)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Run encode -> sample -> decode.

        Returns:
            Tuple ``(recon, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar