Spaces:
Sleeping
Sleeping
| """ | |
| VAE model definition. | |
| Input: (B, 4, 256, 256) — 4 slices (3D CT label/mask). | |
| Output: decode(z) -> (B, 4, 256, 256) — 4 reconstructed slices. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
class Conv(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU block.

    The convolution is created with ``bias=False`` because the following
    BatchNorm2d provides the learnable shift, making a conv bias redundant.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: int = 0):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv -> batch-norm -> leaky-relu to ``x``."""
        out = self.conv(x)
        return out
class ConvTranspose(nn.Module):
    """ConvTranspose2d -> BatchNorm2d -> LeakyReLU block.

    Mirror of ``Conv`` for upsampling paths; ``bias=False`` on the transposed
    convolution since BatchNorm2d supplies the learnable shift.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: int = 0):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply transposed-conv -> batch-norm -> leaky-relu to ``x``."""
        out = self.conv(x)
        return out
class VAE(nn.Module):
    """VAE: 4-channel input (4 slices) -> latent -> 4-channel output (4 slices).

    The latent produced by `encode` is a (B, 32 * base, 1, 1) tensor (mu and
    logvar each). The bottleneck uses a fixed kernel-8 convolution on an 8x8
    feature map, so the spatial input size must be 256x256 (five stride-2
    downsamples: 256 -> 128 -> 64 -> 32 -> 16 -> 8) — other sizes will fail
    or produce a non-1x1 latent. Output is passed through a Sigmoid, so
    reconstructions lie in [0, 1].
    """
    def __init__(self, base: int = 64):
        # base: channel-width multiplier; encoder widths go 4 -> base -> 2*base -> 4*base,
        # the bottleneck is 64*base channels and the latent 32*base channels.
        super().__init__()
        self.base = base
        # Encoder: alternating stride-2 downsample convs and stride-1 refinement
        # convs; spatial sizes annotated assuming a 256x256 input.
        self.encoder = nn.Sequential(
            Conv(4, base, 3, stride=2, padding=1),              # (B, base,   128, 128)
            Conv(base, 2 * base, 3, padding=1),                 # (B, 2*base, 128, 128)
            Conv(2 * base, 2 * base, 3, stride=2, padding=1),   # (B, 2*base,  64,  64)
            Conv(2 * base, 2 * base, 3, padding=1),             # (B, 2*base,  64,  64)
            Conv(2 * base, 2 * base, 3, stride=2, padding=1),   # (B, 2*base,  32,  32)
            Conv(2 * base, 4 * base, 3, padding=1),             # (B, 4*base,  32,  32)
            Conv(4 * base, 4 * base, 3, stride=2, padding=1),   # (B, 4*base,  16,  16)
            Conv(4 * base, 4 * base, 3, padding=1),             # (B, 4*base,  16,  16)
            Conv(4 * base, 4 * base, 3, stride=2, padding=1),   # (B, 4*base,   8,   8)
            # Kernel-8 conv with no padding collapses the 8x8 map to 1x1.
            nn.Conv2d(4 * base, 64 * base, 8),                  # (B, 64*base,  1,   1)
            nn.LeakyReLU(inplace=True),
        )
        # 1x1 heads projecting the 64*base bottleneck to the latent moments.
        self.encoder_mu = nn.Conv2d(64 * base, 32 * base, 1)
        self.encoder_logvar = nn.Conv2d(64 * base, 32 * base, 1)
        # Decoder: mirror of the encoder — kernel-8 transpose conv expands
        # 1x1 -> 8x8, then five stride-2 transpose convs upsample back to 256.
        self.decoder = nn.Sequential(
            nn.Conv2d(32 * base, 64 * base, 1),                     # (B, 64*base,  1,   1)
            ConvTranspose(64 * base, 4 * base, 8),                  # (B, 4*base,   8,   8)
            Conv(4 * base, 4 * base, 3, padding=1),                 # (B, 4*base,   8,   8)
            ConvTranspose(4 * base, 4 * base, 4, stride=2, padding=1),  # (B, 4*base,  16, 16)
            Conv(4 * base, 4 * base, 3, padding=1),                 # (B, 4*base,  16,  16)
            ConvTranspose(4 * base, 4 * base, 4, stride=2, padding=1),  # (B, 4*base,  32, 32)
            Conv(4 * base, 2 * base, 3, padding=1),                 # (B, 2*base,  32,  32)
            ConvTranspose(2 * base, 2 * base, 4, stride=2, padding=1),  # (B, 2*base,  64, 64)
            Conv(2 * base, 2 * base, 3, padding=1),                 # (B, 2*base,  64,  64)
            ConvTranspose(2 * base, 2 * base, 4, stride=2, padding=1),  # (B, 2*base, 128, 128)
            Conv(2 * base, base, 3, padding=1),                     # (B, base,   128, 128)
            ConvTranspose(base, base, 4, stride=2, padding=1),      # (B, base,   256, 256)
            nn.Conv2d(base, 4, 3, padding=1),                       # (B, 4,      256, 256)
            nn.Sigmoid(),  # reconstructions in [0, 1]
        )
    def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Returns (mu, logvar), each (B, 32*base, 1, 1) for a 256x256 input.
        x = self.encoder(x)
        return self.encoder_mu(x), self.encoder_logvar(x)
    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Standard VAE reparameterization: z = mu + std * eps. For inference (eval mode), return mu for deterministic output."""
        if not self.training:
            return mu
        # std = exp(logvar / 2); sampling via eps ~ N(0, I) keeps the draw
        # differentiable w.r.t. mu and logvar.
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + std * eps
    def decode(self, z: torch.Tensor) -> torch.Tensor:
        # z: (B, 32*base, 1, 1) latent -> (B, 4, 256, 256) reconstruction.
        return self.decoder(z)
    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Full pass: returns (reconstruction, mu, logvar) so the caller can
        # compute both the reconstruction loss and the KL term.
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar