martinbadrous commited on
Commit
1a360c4
·
verified ·
1 Parent(s): 5b36243

Create models/unet.py

Browse files
Files changed (1) hide show
  1. models/unet.py +85 -0
models/unet.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Implementation of a lightweight U-Net architecture used for segmentation."""
2
+ from __future__ import annotations
3
+ from typing import Iterable, Sequence
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ class DoubleConv(nn.Module):
8
+ """(convolution => [BN] => ReLU) * 2 block used throughout U-Net."""
9
+ def __init__(self, in_channels: int, out_channels: int, mid_channels: int | None = None) -> None:
10
+ super().__init__()
11
+ if mid_channels is None:
12
+ mid_channels = out_channels
13
+ self.block = nn.Sequential(
14
+ nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
15
+ nn.BatchNorm2d(mid_channels),
16
+ nn.ReLU(inplace=True),
17
+ nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
18
+ nn.BatchNorm2d(out_channels),
19
+ nn.ReLU(inplace=True),
20
+ )
21
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
22
+ return self.block(x)
23
+
24
+ class UNet(nn.Module):
25
+ """Standard U-Net implementation with configurable encoder width."""
26
+ def __init__(self, in_channels: int = 1, out_channels: int = 1, features: Sequence[int] | Iterable[int] = (64, 128, 256, 512), bilinear: bool = True) -> None:
27
+ super().__init__()
28
+ self.features = tuple(int(f) for f in features)
29
+ if len(self.features) < 2:
30
+ raise ValueError("`features` must contain at least two stages")
31
+ self.bilinear = bilinear
32
+ self.downs = nn.ModuleList()
33
+ self.ups = nn.ModuleList()
34
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
35
+ current_channels = in_channels
36
+ for feature in self.features:
37
+ self.downs.append(DoubleConv(current_channels, feature))
38
+ current_channels = feature
39
+ factor = 2 if bilinear else 1
40
+ self.bottleneck = DoubleConv(self.features[-1], self.features[-1] * factor)
41
+ reversed_features = list(reversed(self.features))
42
+ prev_channels = self.features[-1] * factor
43
+ for feature in reversed_features:
44
+ if bilinear:
45
+ self.ups.append(nn.Sequential(
46
+ nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
47
+ nn.Conv2d(prev_channels, feature, kernel_size=1),
48
+ ))
49
+ else:
50
+ self.ups.append(nn.ConvTranspose2d(prev_channels, feature, kernel_size=2, stride=2))
51
+ self.ups.append(DoubleConv(feature * 2, feature))
52
+ prev_channels = feature
53
+ self.out_conv = nn.Conv2d(self.features[0], out_channels, kernel_size=1)
54
+ self.apply(self._init_weights)
55
+ @staticmethod
56
+ def _init_weights(module: nn.Module) -> None:
57
+ if isinstance(module, nn.Conv2d):
58
+ nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
59
+ if module.bias is not None:
60
+ nn.init.zeros_(module.bias)
61
+ elif isinstance(module, nn.BatchNorm2d):
62
+ nn.init.ones_(module.weight)
63
+ nn.init.zeros_(module.bias)
64
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
65
+ skip_connections = []
66
+ for down in self.downs:
67
+ x = down(x)
68
+ skip_connections.append(x)
69
+ x = self.pool(x)
70
+ x = self.bottleneck(x)
71
+ skip_connections = skip_connections[::-1]
72
+ for idx in range(0, len(self.ups), 2):
73
+ upsample = self.ups[idx]
74
+ conv = self.ups[idx + 1]
75
+ x = upsample(x)
76
+ skip = skip_connections[idx // 2]
77
+ if x.shape[-2:] != skip.shape[-2:]:
78
+ diff_y = skip.size(2) - x.size(2)
79
+ diff_x = skip.size(3) - x.size(3)
80
+ x = nn.functional.pad(x, [diff_x // 2, diff_x - diff_x // 2, diff_y // 2, diff_y - diff_y // 2])
81
+ x = torch.cat([skip, x], dim=1)
82
+ x = conv(x)
83
+ return self.out_conv(x)
84
+
85
+ __all__ = ["UNet"]