|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
from .unet import SelfAttention
|
|
|
|
|
|
class VAEEncoder(nn.Module):
    """Convolutional VAE encoder that outputs mean and log-variance maps.

    Layout: ``conv_in`` -> one down stage per consecutive pair of entries in
    ``hidden_dims`` -> ``final``. Every down stage is
    ``GroupNorm(8) -> SiLU -> 3x3 Conv2d(stride=2)``, optionally preceded by a
    ``SelfAttention`` layer operating on the stage's input channels.

    Args:
        in_channels: Number of channels of the input image.
        latent_channels: Number of channels of ``mu`` and of ``logvar``.
        hidden_dims: Channel width per stage. ``GroupNorm(8, ...)`` is applied
            to each width, so every entry must be divisible by 8.
        attention_resolutions: Stage labels (see ``input_resolution``) at
            which a ``SelfAttention`` layer is inserted.
        input_resolution: Nominal side length used to label the stages:
            stage ``i`` is labelled ``input_resolution // 2**i``. The label is
            not derived from the actual tensor size at run time.

    Raises:
        ValueError: If ``hidden_dims`` is empty.

    NOTE(review): with the defaults (four widths, ``input_resolution=256``)
    the stage labels are 256, 128 and 64, none of which appears in
    ``(32, 16)``, so no attention layer is created. Confirm whether that is
    intended before changing the defaults; doing so alters the state dict.
    """

    def __init__(
        self,
        in_channels=1,
        latent_channels=4,
        hidden_dims=(64, 128, 256, 512),
        attention_resolutions=(32, 16),
        input_resolution=256,
    ):
        """Build the stem, the down-sampling stages and the output head."""
        super().__init__()

        widths = list(hidden_dims)
        if not widths:
            raise ValueError(
                "hidden_dims must contain at least one channel width"
            )

        # Stem: lift the image to the first hidden width, same spatial size.
        self.conv_in = nn.Conv2d(in_channels, widths[0], 3, padding=1)

        self.down_blocks = nn.ModuleList()
        for stage, (in_dim, out_dim) in enumerate(zip(widths, widths[1:])):
            # Nominal resolution of this stage's input feature map.
            resolution = input_resolution // (2 ** stage)

            layers = []
            if resolution in attention_resolutions:
                layers.append(SelfAttention(in_dim))

            # Kept as a nested Sequential so parameter names stay
            # ``down_blocks.<stage>.<k>.<layer>`` as in existing checkpoints.
            layers.append(nn.Sequential(
                nn.GroupNorm(8, in_dim),
                nn.SiLU(),
                nn.Conv2d(in_dim, out_dim, 3, stride=2, padding=1),
            ))

            self.down_blocks.append(nn.Sequential(*layers))

        # Head: emits mu and logvar stacked along the channel axis.
        self.final = nn.Sequential(
            nn.GroupNorm(8, widths[-1]),
            nn.SiLU(),
            nn.Conv2d(widths[-1], latent_channels * 2, 3, padding=1),
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Kaiming-normal init for conv/linear weights; zero their biases.

        Invoked through ``self.apply``, so it visits every sub-module,
        including those inside any ``SelfAttention`` layer.
        """
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Encode ``x`` into ``(mu, logvar)``.

        Args:
            x: Batched image tensor; the result is split along dim 1, so the
                channel axis must be dim 1.

        Returns:
            Tuple ``(mu, logvar)``: the two halves of the head output along
            dim 1, each with ``latent_channels`` channels.
        """
        x = self.conv_in(x)

        for block in self.down_blocks:
            x = block(x)

        x = self.final(x)

        mu, logvar = torch.chunk(x, 2, dim=1)
        return mu, logvar
|
|
|
|
|
|
class VAEDecoder(nn.Module):
    """Convolutional VAE decoder mapping a latent map back to image space.

    Layout: ``conv_in`` -> one up stage per consecutive pair of entries in
    ``hidden_dims`` -> ``final``. Every up stage is
    ``GroupNorm(8) -> SiLU -> 4x4 ConvTranspose2d(stride=2)``, optionally
    preceded by a ``SelfAttention`` layer operating on the stage's input
    channels.

    Args:
        latent_channels: Number of channels of the latent input.
        out_channels: Number of channels of the reconstructed image.
        hidden_dims: Channel width per stage. ``GroupNorm(8, ...)`` is applied
            to each width, so every entry must be divisible by 8.
        attention_resolutions: Stage labels (see ``latent_resolution``) at
            which a ``SelfAttention`` layer is inserted.
        latent_resolution: Nominal side length used to label the stages:
            stage ``i`` is labelled ``latent_resolution * 2**i``. The label is
            not derived from the actual tensor size at run time.

    Raises:
        ValueError: If ``hidden_dims`` is empty.
    """

    def __init__(
        self,
        latent_channels=4,
        out_channels=1,
        hidden_dims=(512, 256, 128, 64),
        attention_resolutions=(16, 32),
        latent_resolution=16,
    ):
        """Build the stem, the up-sampling stages and the output head."""
        super().__init__()

        widths = list(hidden_dims)
        if not widths:
            raise ValueError(
                "hidden_dims must contain at least one channel width"
            )

        # Stem: lift the latent to the first hidden width, same spatial size.
        self.conv_in = nn.Conv2d(latent_channels, widths[0], 3, padding=1)

        self.up_blocks = nn.ModuleList()
        for stage, (in_dim, out_dim) in enumerate(zip(widths, widths[1:])):
            # Nominal resolution of this stage's input feature map.
            resolution = latent_resolution * (2 ** stage)

            layers = []
            if resolution in attention_resolutions:
                layers.append(SelfAttention(in_dim))

            # Kept as a nested Sequential so parameter names stay
            # ``up_blocks.<stage>.<k>.<layer>`` as in existing checkpoints.
            layers.append(nn.Sequential(
                nn.GroupNorm(8, in_dim),
                nn.SiLU(),
                nn.ConvTranspose2d(in_dim, out_dim, 4, stride=2, padding=1),
            ))

            self.up_blocks.append(nn.Sequential(*layers))

        # Head: project the last hidden width to the output channels.
        self.final = nn.Sequential(
            nn.GroupNorm(8, widths[-1]),
            nn.SiLU(),
            nn.Conv2d(widths[-1], out_channels, 3, padding=1),
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Kaiming-normal init for conv/deconv/linear weights; zero biases.

        Invoked through ``self.apply``, so it visits every sub-module,
        including those inside any ``SelfAttention`` layer.
        """
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Decode the latent tensor ``x`` into an image tensor.

        Args:
            x: Latent tensor with ``latent_channels`` channels.

        Returns:
            Tensor with ``out_channels`` channels; no output activation is
            applied after the final convolution.
        """
        x = self.conv_in(x)

        for block in self.up_blocks:
            x = block(x)

        return self.final(x)
|
|
|
|
|
|
class MedicalVAE(nn.Module):
    """Variational autoencoder pairing ``VAEEncoder`` with ``VAEDecoder``.

    Args:
        in_channels: Channels of the input image (forwarded to the encoder).
        out_channels: Channels of the reconstruction (forwarded to the
            decoder).
        latent_channels: Channels of the latent map.
        hidden_dims: Encoder channel widths; the decoder receives the same
            widths in reverse order.
        attention_resolutions: Forwarded unchanged to both the encoder and
            the decoder.
    """

    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        latent_channels=4,
        hidden_dims=(64, 128, 256, 512),
        attention_resolutions=(16, 32),
    ):
        """Construct the encoder and the mirrored decoder."""
        super().__init__()

        self.encoder = VAEEncoder(
            in_channels=in_channels,
            latent_channels=latent_channels,
            hidden_dims=hidden_dims,
            attention_resolutions=attention_resolutions,
        )

        # The decoder walks the channel widths from widest to narrowest.
        self.decoder = VAEDecoder(
            latent_channels=latent_channels,
            out_channels=out_channels,
            hidden_dims=list(reversed(hidden_dims)),
            attention_resolutions=attention_resolutions,
        )

        self.latent_channels = latent_channels

    def encode(self, x):
        """Return whatever the encoder returns for ``x`` (``mu, logvar``)."""
        return self.encoder(x)

    def decode(self, z):
        """Return the decoder output for the latent tensor ``z``."""
        return self.decoder(z)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        Noise is drawn on every call; there is no train/eval distinction.

        Args:
            mu: Mean of the approximate posterior.
            logvar: Log-variance of the approximate posterior.

        Returns:
            A sample with the same shape as ``logvar`` broadcast against
            ``mu``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent, and decode it.

        Returns:
            Tuple ``(recon, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar