# Source: tan200224 — "Update model.py" (commit 933aa93, verified)
"""
VAE model definition.
Input: (B, 4, 256, 256) — 4 slices (3D CT label/mask).
Output: decode(z) -> (B, 4, 256, 256) — 4 reconstructed slices.
"""
import torch
import torch.nn as nn
class Conv(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU as a single reusable unit.

    The convolution is bias-free because the BatchNorm that follows it
    makes a separate bias term redundant.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: int = 0):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply convolution, normalization, and activation to ``x``."""
        out = self.conv(x)
        return out
class ConvTranspose(nn.Module):
    """ConvTranspose2d -> BatchNorm2d -> LeakyReLU as a single upsampling unit.

    The transposed convolution is bias-free because the BatchNorm that
    follows it makes a separate bias term redundant.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: int = 0):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply transposed convolution, normalization, and activation to ``x``."""
        out = self.conv(x)
        return out
class VAE(nn.Module):
    """Convolutional VAE mapping slice stacks to a 1x1-spatial latent and back.

    Encodes a ``(B, in_channels, 256, 256)`` tensor to a ``(B, 32*base, 1, 1)``
    Gaussian posterior and decodes back to ``(B, in_channels, 256, 256)``.
    The final Sigmoid constrains outputs to ``[0, 1]``, suitable for
    mask/label reconstruction.

    Args:
        base: Channel-width multiplier for every conv layer.
        in_channels: Number of input (and reconstructed output) channels.
            Defaults to 4 to match the original 4-slice CT mask setup, so
            existing callers are unaffected.
    """

    def __init__(self, base: int = 64, in_channels: int = 4):
        super().__init__()
        self.base = base
        self.in_channels = in_channels
        # Five stride-2 stages halve 256 -> 128 -> 64 -> 32 -> 16 -> 8;
        # the final kernel-8 conv collapses the 8x8 map to 1x1.
        self.encoder = nn.Sequential(
            Conv(in_channels, base, 3, stride=2, padding=1),
            Conv(base, 2 * base, 3, padding=1),
            Conv(2 * base, 2 * base, 3, stride=2, padding=1),
            Conv(2 * base, 2 * base, 3, padding=1),
            Conv(2 * base, 2 * base, 3, stride=2, padding=1),
            Conv(2 * base, 4 * base, 3, padding=1),
            Conv(4 * base, 4 * base, 3, stride=2, padding=1),
            Conv(4 * base, 4 * base, 3, padding=1),
            Conv(4 * base, 4 * base, 3, stride=2, padding=1),
            nn.Conv2d(4 * base, 64 * base, 8),
            nn.LeakyReLU(inplace=True),
        )
        # 1x1 convolution heads producing the posterior parameters.
        self.encoder_mu = nn.Conv2d(64 * base, 32 * base, 1)
        self.encoder_logvar = nn.Conv2d(64 * base, 32 * base, 1)
        # Mirror of the encoder: a kernel-8 transpose restores the 8x8 map,
        # then five stride-2 transposes restore 256x256.
        self.decoder = nn.Sequential(
            nn.Conv2d(32 * base, 64 * base, 1),
            ConvTranspose(64 * base, 4 * base, 8),
            Conv(4 * base, 4 * base, 3, padding=1),
            ConvTranspose(4 * base, 4 * base, 4, stride=2, padding=1),
            Conv(4 * base, 4 * base, 3, padding=1),
            ConvTranspose(4 * base, 4 * base, 4, stride=2, padding=1),
            Conv(4 * base, 2 * base, 3, padding=1),
            ConvTranspose(2 * base, 2 * base, 4, stride=2, padding=1),
            Conv(2 * base, 2 * base, 3, padding=1),
            ConvTranspose(2 * base, 2 * base, 4, stride=2, padding=1),
            Conv(2 * base, base, 3, padding=1),
            ConvTranspose(base, base, 4, stride=2, padding=1),
            nn.Conv2d(base, in_channels, 3, padding=1),
            nn.Sigmoid(),
        )

    def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the posterior parameters ``(mu, logvar)`` for input ``x``."""
        h = self.encoder(x)
        return self.encoder_mu(h), self.encoder_logvar(h)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample ``z = mu + std * eps`` (standard VAE reparameterization).

        In eval mode the mean is returned directly, so inference is
        deterministic.
        """
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + std * eps

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent ``z`` to a reconstruction in ``[0, 1]``."""
        return self.decoder(z)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(reconstruction, mu, logvar)`` for input ``x``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar