File size: 7,273 Bytes
2f33c28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
"""
model.py β€” 3D U-Net for BraTS2020 Brain Tumor Segmentation
============================================================
Architecture: Encoder β†’ Bottleneck β†’ Decoder with skip connections.
Each level doubles/halves the feature maps and halves/doubles spatial dims.

Input:  (B, 4,  128, 128, 128)  β€” batch of 4-modality MRI volumes
Output: (B, 4,  128, 128, 128)  β€” per-voxel class logits
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


# ─── Residual Block ───────────────────────────────────────────────────────────
# Two conv3d layers with a skip connection.
# InstanceNorm3d is used instead of BatchNorm because BraTS batch size is 1
# (one 128Β³ volume barely fits in VRAM), and BatchNorm is unstable at batch=1.
# LeakyReLU(0.01) avoids dead neurons better than standard ReLU.

class ResidualBlock(nn.Module):
    """Two 3x3x3 convolutions with a residual shortcut.

    InstanceNorm3d is used rather than BatchNorm because training runs at
    batch size 1, where batch statistics are unreliable; LeakyReLU(0.01)
    keeps a small gradient for negative pre-activations.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.conv1 = nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.norm1 = nn.InstanceNorm3d(out_ch, affine=True)
        self.conv2 = nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.norm2 = nn.InstanceNorm3d(out_ch, affine=True)
        self.act = nn.LeakyReLU(0.01, inplace=True)

        # The shortcut must deliver out_ch channels. When the channel count
        # already matches, use a parameter-free identity; otherwise project
        # with a 1x1x1 convolution.
        if in_ch == out_ch:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv3d(in_ch, out_ch, kernel_size=1, bias=False)

    def forward(self, x):
        shortcut = self.skip(x)
        h = self.act(self.norm1(self.conv1(x)))
        h = self.norm2(self.conv2(h))
        # Merge the two paths first, then apply the non-linearity once.
        return self.act(h + shortcut)


# ─── Encoder Block ────────────────────────────────────────────────────────────
# Each encoder level:
#   1. ResidualBlock to extract features at current resolution
#   2. Strided Conv3d (stride=2) to halve spatial dimensions
# Returns both the downsampled output AND the pre-downsample features (skip).
# The skip connection is later concatenated in the corresponding decoder level.

class DownBlock(nn.Module):
    """One encoder level: feature extraction, then 2x spatial downsampling.

    forward() returns a pair ``(downsampled, skip)`` — the strided-conv
    output fed to the next level, and the full-resolution features kept
    for the matching decoder level's concatenation.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.res = ResidualBlock(in_ch, out_ch)
        # Strided convolution halves every spatial dimension.
        self.down = nn.Conv3d(out_ch, out_ch, kernel_size=3,
                              stride=2, padding=1, bias=False)

    def forward(self, x):
        features = self.res(x)
        return self.down(features), features


# ─── Decoder Block ────────────────────────────────────────────────────────────
# Each decoder level:
#   1. Trilinear upsample to double spatial dimensions
#   2. Concatenate with the skip connection from the matching encoder level
#   3. ResidualBlock to fuse upsampled + skip features
# The concat doubles the channel count, so ResidualBlock takes in_ch + skip_ch.

class UpBlock(nn.Module):
    """One decoder level: upsample, merge the encoder skip, then fuse.

    Channel bookkeeping: the concatenation yields in_ch + skip_ch channels,
    which the ResidualBlock reduces to out_ch.
    """

    def __init__(self, in_ch: int, skip_ch: int, out_ch: int):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="trilinear",
                              align_corners=True)
        self.res = ResidualBlock(in_ch + skip_ch, out_ch)

    def forward(self, x, skip):
        upsampled = self.up(x)

        # Odd input sizes can leave the upsampled tensor a voxel short of
        # the skip tensor; zero-pad so the concatenation lines up.
        if upsampled.shape != skip.shape:
            upsampled = F.pad(upsampled, _pad_to_match(upsampled, skip))

        merged = torch.cat([upsampled, skip], dim=1)
        return self.res(merged)


# ─── Full 3D U-Net ────────────────────────────────────────────────────────────
# depth=4 means 4 encoder levels, 1 bottleneck, 4 decoder levels.
# base_filters=32 means the first level has 32 feature maps.
# Each subsequent level doubles: [32, 64, 128, 256] with bottleneck at 512.
# Total parameters with default settings: ~19M

class UNet3D(nn.Module):
    """3D U-Net: ``depth`` encoder levels, one bottleneck, ``depth`` decoder levels.

    With the defaults (base_filters=32, depth=4) the per-level feature
    counts are [32, 64, 128, 256] with a 512-channel bottleneck.

    Args:
        in_channels:  input modalities per voxel (4 for BraTS MRI).
        out_channels: number of output classes (per-voxel logits).
        base_filters: feature maps at the first (full-resolution) level.
        depth:        number of downsampling / upsampling stages.
    """

    def __init__(self, in_channels=4, out_channels=4,
                 base_filters=32, depth=4):
        super().__init__()
        self.depth = depth

        # Feature-map counts per level, doubling each stage:
        # e.g. [32, 64, 128, 256, 512] (last entry is the bottleneck).
        filters = [base_filters * (2 ** i) for i in range(depth + 1)]

        # Encoder: `depth` DownBlocks, each halving spatial dims.
        self.encoders = nn.ModuleList()
        self.encoders.append(DownBlock(in_channels, filters[0]))
        for i in range(1, depth):
            self.encoders.append(DownBlock(filters[i - 1], filters[i]))

        # Bottleneck: single ResidualBlock at the lowest resolution.
        self.bottleneck = ResidualBlock(filters[depth - 1], filters[depth])

        # Decoder: `depth` UpBlocks, mirroring the encoder.
        self.decoders = nn.ModuleList()
        for i in range(depth - 1, -1, -1):
            self.decoders.append(UpBlock(filters[i + 1], filters[i], filters[i]))

        # Final 1x1x1 conv: map feature maps to class logits.
        self.head = nn.Conv3d(filters[0], out_channels, kernel_size=1)

        self._init_weights()

    def forward(self, x):
        """Run one volume batch through the network; returns raw logits."""
        # Encode — collect skip connections at each resolution.
        skips = []
        for enc in self.encoders:
            x, skip = enc(x)
            skips.append(skip)

        # Bottleneck at the lowest resolution.
        x = self.bottleneck(x)

        # Decode — consume skip connections in reverse (deepest first).
        for i, dec in enumerate(self.decoders):
            x = dec(x, skips[-(i + 1)])

        return self.head(x)   # (B, out_channels, D, H, W) logits

    def count_parameters(self):
        """Number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def _init_weights(self):
        # Kaiming init for conv layers, matched to the network's activation.
        # BUG FIX: kaiming_normal_ defaults to a=0 (plain-ReLU gain) even
        # when nonlinearity="leaky_relu" is given; pass a=0.01 so the gain
        # matches the LeakyReLU(0.01) used throughout the blocks.
        # Ones/zeros for InstanceNorm affine parameters (standard).
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, a=0.01, mode="fan_out",
                                        nonlinearity="leaky_relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.InstanceNorm3d) and m.affine:
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)


# ─── Utility ──────────────────────────────────────────────────────────────────
# Computes the padding needed to make tensor x match the spatial dims of target.
# Only needed when input dimensions are odd, causing off-by-one after downsample.

def _pad_to_match(x: torch.Tensor, target: torch.Tensor):
    diffs = [t - s for s, t in zip(x.shape[2:], target.shape[2:])]
    pad = []
    for d in reversed(diffs):
        pad += [d // 2, d - d // 2]
    return pad