Spaces:
Sleeping
Sleeping
File size: 7,273 Bytes
2f33c28 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 | """
model.py — 3D U-Net for BraTS2020 Brain Tumor Segmentation
============================================================
Architecture: Encoder → Bottleneck → Decoder with skip connections.
Each level doubles/halves the feature maps and halves/doubles spatial dims.
Input: (B, 4, 128, 128, 128) — batch of 4-modality MRI volumes
Output: (B, 4, 128, 128, 128) — per-voxel class logits
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
# ─── Residual Block ───────────────────────────────────────────────────────────
# Two conv3d layers with a skip connection.
# InstanceNorm3d is used instead of BatchNorm because BraTS batch size is 1
# (one 128³ volume barely fits in VRAM), and BatchNorm is unstable at batch=1.
# LeakyReLU(0.01) avoids dead neurons better than standard ReLU.
class ResidualBlock(nn.Module):
    """Two 3x3x3 conv + InstanceNorm stages wrapped by a residual connection.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.

    The shortcut path is a bias-free 1x1x1 convolution when ``in_ch`` and
    ``out_ch`` differ, and ``nn.Identity`` (no parameters) when they match.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        # Main path: conv -> norm -> act -> conv -> norm.
        self.conv1 = nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.norm1 = nn.InstanceNorm3d(out_ch, affine=True)
        self.conv2 = nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.norm2 = nn.InstanceNorm3d(out_ch, affine=True)
        self.act = nn.LeakyReLU(0.01, inplace=True)

        # Shortcut path: only project when the channel counts disagree.
        if in_ch == out_ch:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv3d(in_ch, out_ch, kernel_size=1, bias=False)

    def forward(self, x):
        """Return ``act(main_path(x) + shortcut(x))``."""
        shortcut = self.skip(x)

        hidden = self.conv1(x)
        hidden = self.norm1(hidden)
        hidden = self.act(hidden)

        hidden = self.conv2(hidden)
        hidden = self.norm2(hidden)

        # The second activation is applied after the residual sum.
        return self.act(hidden + shortcut)
# ─── Encoder Block ────────────────────────────────────────────────────────────
# Each encoder level:
# 1. ResidualBlock to extract features at current resolution
# 2. Strided Conv3d (stride=2) to halve spatial dimensions
# Returns both the downsampled output AND the pre-downsample features (skip).
# The skip connection is later concatenated in the corresponding decoder level.
class DownBlock(nn.Module):
    """Encoder stage: residual feature extraction, then stride-2 downsampling.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of channels produced by both the residual block and
            the downsampling convolution.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.res = ResidualBlock(in_ch, out_ch)
        self.down = nn.Conv3d(
            out_ch, out_ch, kernel_size=3, stride=2, padding=1, bias=False
        )

    def forward(self, x):
        """Return ``(downsampled, features)``.

        ``features`` is the residual block output at the input resolution
        (kept for the decoder skip connection); ``downsampled`` is the same
        tensor after the stride-2 convolution.
        """
        features = self.res(x)
        return self.down(features), features
# ─── Decoder Block ────────────────────────────────────────────────────────────
# Each decoder level:
# 1. Trilinear upsample to double spatial dimensions
# 2. Concatenate with the skip connection from the matching encoder level
# 3. ResidualBlock to fuse upsampled + skip features
# The concat yields in_ch + skip_ch channels, which is what ResidualBlock takes.
class UpBlock(nn.Module):
    """Decoder stage: upsample, align with the encoder skip, concatenate, fuse.

    Args:
        in_ch: Channels of the tensor coming from the deeper level.
        skip_ch: Channels of the encoder skip tensor.
        out_ch: Channels produced by the fusing residual block.
    """

    def __init__(self, in_ch: int, skip_ch: int, out_ch: int):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="trilinear",
                              align_corners=True)
        # After concatenation the fused tensor has in_ch + skip_ch channels.
        self.res = ResidualBlock(in_ch + skip_ch, out_ch)

    def forward(self, x, skip):
        """Upsample ``x`` by 2, match it to ``skip`` spatially, and fuse both.

        Args:
            x: Tensor from the deeper level, laid out (B, in_ch, D, H, W).
            skip: Encoder tensor, laid out (B, skip_ch, D', H', W').

        Returns:
            Output of the residual block applied to ``cat([x, skip], dim=1)``.
        """
        x = self.up(x)
        # Compare spatial dims only. The channel dims of x and skip differ by
        # design, so comparing full shapes would always report a mismatch and
        # trigger a needless F.pad copy of the whole upsampled volume.
        if x.shape[2:] != skip.shape[2:]:
            # Sizes differ when an input dimension was odd; the spec may hold
            # negative entries, in which case F.pad crops instead of padding.
            x = F.pad(x, _pad_to_match(x, skip))
        return self.res(torch.cat([x, skip], dim=1))
# ─── Full 3D U-Net ────────────────────────────────────────────────────────────
# depth=4 means 4 encoder levels, 1 bottleneck, 4 decoder levels.
# base_filters=32 means the first level has 32 feature maps.
# Each subsequent level doubles: [32, 64, 128, 256] with bottleneck at 512.
# Total parameters with default settings: ~26M (see count_parameters()).
class UNet3D(nn.Module):
    """3D U-Net: ``depth`` encoder levels, a bottleneck, ``depth`` decoder levels.

    Args:
        in_channels: Channels of the input volume.
        out_channels: Number of output logit channels.
        base_filters: Feature maps at the first level; each deeper level
            doubles it (defaults give [32, 64, 128, 256] plus a 512 bottleneck).
        depth: Number of encoder/decoder levels. Each level halves the
            spatial size on the way down and doubles it on the way up.

    Raises:
        ValueError: If ``depth`` or ``base_filters`` is smaller than 1.
    """

    def __init__(self, in_channels=4, out_channels=4,
                 base_filters=32, depth=4):
        super().__init__()
        # depth == 0 would build an encoder with no matching decoder (output
        # at half resolution); negative values fail later with an opaque
        # IndexError. Reject both up front.
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")
        if base_filters < 1:
            raise ValueError(f"base_filters must be >= 1, got {base_filters}")

        self.depth = depth

        # Filter counts per level, bottleneck last: [32, 64, 128, 256, 512].
        filters = [base_filters * (2 ** i) for i in range(depth + 1)]

        # Encoder: `depth` DownBlocks; the first one consumes the raw input.
        self.encoders = nn.ModuleList()
        encoder_inputs = [in_channels] + filters[:depth - 1]
        for ch_in, ch_out in zip(encoder_inputs, filters[:depth]):
            self.encoders.append(DownBlock(ch_in, ch_out))

        # Bottleneck: single ResidualBlock at the lowest resolution.
        self.bottleneck = ResidualBlock(filters[depth - 1], filters[depth])

        # Decoder: `depth` UpBlocks, deepest level first (mirror of encoder).
        self.decoders = nn.ModuleList()
        for i in reversed(range(depth)):
            self.decoders.append(UpBlock(filters[i + 1], filters[i], filters[i]))

        # Final 1x1x1 conv: map feature maps to class logits.
        self.head = nn.Conv3d(filters[0], out_channels, kernel_size=1)

        self._init_weights()

    def forward(self, x):
        """Return per-voxel logits with ``out_channels`` channels.

        With the default stride-2 encoder, the output has the same spatial
        size as ``x`` (each decoder level is matched to its encoder skip).
        """
        # Encode, remembering the pre-downsample features of every level.
        skips = []
        for enc in self.encoders:
            x, skip = enc(x)
            skips.append(skip)

        x = self.bottleneck(x)

        # Decode, consuming the skips from deepest to shallowest.
        for dec, skip in zip(self.decoders, reversed(skips)):
            x = dec(x, skip)

        return self.head(x)

    def count_parameters(self):
        """Return the number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def _init_weights(self):
        """Kaiming-initialise convolutions; reset InstanceNorm affine params."""
        # kaiming_normal_ only accounts for the leaky slope when it is passed
        # as `a`; with the default a=0 the gain is the plain ReLU gain. Use
        # the slope of the network's own LeakyReLU (0.01 if none is found).
        slope = next(
            (m.negative_slope for m in self.modules()
             if isinstance(m, nn.LeakyReLU)),
            0.01,
        )
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, a=slope, mode="fan_out",
                                        nonlinearity="leaky_relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.InstanceNorm3d) and m.affine:
                # Identity affine transform: scale 1, shift 0.
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
# ─── Utility ──────────────────────────────────────────────────────────────────
# Computes the padding needed to make tensor x match the spatial dims of target.
# Only needed when input dimensions are odd, causing off-by-one after downsample.
# Negative padding values mean x is larger than target and must be cropped.
def _pad_to_match(x: torch.Tensor, target: torch.Tensor):
    """Return the ``F.pad`` spec that makes ``x``'s spatial dims equal ``target``'s.

    Dimensions 0 and 1 (batch, channels) are ignored; every later dimension
    contributes a ``(before, after)`` pair whose sum is ``target_size - x_size``.

    Args:
        x: Tensor to be resized.
        target: Tensor whose trailing (spatial) sizes should be matched.

    Returns:
        A flat list of ints ordered last dimension first, as ``F.pad``
        expects. Entries are negative where ``x`` is larger than ``target``,
        which makes ``F.pad`` crop instead of pad.

    Raises:
        ValueError: If ``x`` and ``target`` have a different number of dims.
    """
    # zip() would silently drop the unmatched dims and yield a wrong spec.
    if len(x.shape) != len(target.shape):
        raise ValueError(
            f"rank mismatch: x has {len(x.shape)} dims, "
            f"target has {len(target.shape)}"
        )

    diffs = [t - s for s, t in zip(x.shape[2:], target.shape[2:])]

    pad = []
    # F.pad reads pairs starting from the last dimension, hence the reversal.
    for d in reversed(diffs):
        before = d // 2
        pad.extend((before, d - before))
    return pad