nano6626 commited on
Commit
7da32d6
·
verified ·
1 Parent(s): e34361f

Update unet.py

Browse files
Files changed (1) hide show
  1. unet.py +189 -129
unet.py CHANGED
@@ -1,129 +1,189 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
- from huggingface_hub import PyTorchModelHubMixin
6
-
7
- #based on: https://github.com/milesial/Pytorch-UNet/tree/master/unet
8
-
9
- class CBR(nn.Module):
10
-
11
- def __init__(self, in_channels, out_channels):
12
-
13
- super().__init__()
14
-
15
- self.conv1=nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
16
- self.bnorm1=nn.BatchNorm2d(out_channels)
17
- self.relu1=nn.ReLU(inplace=True)
18
-
19
- def forward(self, x):
20
-
21
- return self.relu1(self.bnorm1(self.conv1(x)))
22
-
23
- class NCBR(nn.Module):
24
-
25
- def __init__(self, in_channels, out_channels,N,skip=False,skipcat=False):
26
-
27
- super().__init__()
28
-
29
- assert N>1
30
- self.skip=skip
31
- self.skipcat=skipcat
32
- channels=[]
33
- channels.append(in_channels)
34
- for i in range(N):
35
- channels.append(out_channels) #len(channels) == N+1
36
-
37
- self.layers=nn.ModuleList()
38
- for i in range(N):
39
- self.layers.append(CBR(channels[i],channels[i+1]))
40
-
41
- def forward(self, x):
42
- for i,layer in enumerate(self.layers):
43
- if i==0:
44
- x=layer(x)
45
- x1=x
46
- else:
47
- x=layer(x)
48
- if self.skip:
49
- if self.skipcat:
50
- x=torch.cat([x,x1],dim=1)
51
- else:
52
- x=x+x1
53
- return x
54
-
55
- class DownNCBR(nn.Module):
56
- """Downscaling with maxpool then NCBR"""
57
- def __init__(self, in_channels, out_channels,N,skip=False,skipcat=False):
58
- super().__init__()
59
- self.maxpool=nn.MaxPool2d(2)
60
- self.ncbr=NCBR(in_channels, out_channels,N=N,skip=skip,skipcat=skipcat)
61
-
62
- def forward(self, x):
63
- return self.ncbr(self.maxpool(x))
64
-
65
- class UpNCBR(nn.Module):
66
- """Upscaling then double conv"""
67
- def __init__(self, in_channels, out_channels,N,skip=False,skipcat=False):
68
- super().__init__()
69
- self.upconv = nn.ConvTranspose2d(in_channels, in_channels//2, kernel_size=2, stride=2)
70
- self.ncbr = NCBR(in_channels, out_channels,N=N,skip=skip,skipcat=skipcat)
71
-
72
- def forward(self, x1, x2):
73
- x1 = self.upconv(x1)
74
- # input is CHW
75
- diffY = x2.size()[2] - x1.size()[2]
76
- diffX = x2.size()[3] - x1.size()[3]
77
-
78
- x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
79
- diffY // 2, diffY - diffY // 2])
80
- x = torch.cat([x2, x1], dim=1)
81
- return self.ncbr(x)
82
-
83
- class OutConv(nn.Module):
84
- def __init__(self, in_channels, out_channels,kernel_size=1):
85
- super(OutConv, self).__init__()
86
- padding=kernel_size//2
87
- self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,padding=padding)
88
-
89
- def forward(self, x):
90
- return self.conv(x)
91
-
92
- class UNet(nn.Module, PyTorchModelHubMixin):
93
- def __init__(self,n_channels,n_classes,N,width,skip,skipcat,catorig,outker):
94
- super(UNet, self).__init__()
95
- self.n_channels = n_channels
96
- self.n_classes = n_classes
97
-
98
- self.catorig=catorig
99
-
100
- self.inc = NCBR(self.n_channels, 2*width,N=N,skip=skip)
101
- self.down1 = DownNCBR(2*width, 2*width if skipcat else 4*width,N=N,skip=skip,skipcat=skipcat)
102
- self.down2 = DownNCBR(4*width, 4*width if skipcat else 8*width,N=N,skip=skip,skipcat=skipcat)
103
- self.down3 = DownNCBR(8*width, 8*width if skipcat else 16*width,N=N,skip=skip,skipcat=skipcat)
104
- self.down4 = DownNCBR(16*width, 16*width if skipcat else 32*width,N=N,skip=skip,skipcat=skipcat)
105
- self.up1 = UpNCBR(32*width, 8*width if skipcat else 16*width,N=N,skip=skip,skipcat=skipcat)
106
- self.up2 = UpNCBR(16*width, 4*width if skipcat else 8*width,N=N,skip=skip,skipcat=skipcat)
107
- self.up3 = UpNCBR(8*width, 2*width if skipcat else 4*width,N=N,skip=skip,skipcat=skipcat)
108
- self.up4 = UpNCBR(4*width, width if skipcat else 2*width,N=N,skip=skip,skipcat=skipcat)
109
- if self.catorig:
110
- self.outc = OutConv(2*width+self.n_channels, self.n_classes,kernel_size=outker)
111
- else:
112
- self.outc = OutConv(2*width, self.n_classes,kernel_size=outker)
113
-
114
- def forward(self, x):
115
- orig=x
116
- x1 = self.inc(x)
117
- x2 = self.down1(x1)
118
- x3 = self.down2(x2)
119
- x4 = self.down3(x3)
120
- x5 = self.down4(x4)
121
- x = self.up1(x5, x4)
122
- x = self.up2(x, x3)
123
- x = self.up3(x, x2)
124
- x = self.up4(x, x1)
125
- if self.catorig:
126
- logits=self.outc(torch.cat([x,orig],axis=1))
127
- else:
128
- logits=self.outc(x)
129
- return logits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from transformers import PreTrainedModel, PretrainedConfig
6
+
7
+
8
+ # -------- building blocks (same as your original) --------
9
+
10
+ class CBR(nn.Module):
11
+ def __init__(self, in_channels, out_channels):
12
+ super().__init__()
13
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
14
+ self.bnorm1 = nn.BatchNorm2d(out_channels)
15
+ self.relu1 = nn.ReLU(inplace=True)
16
+
17
+ def forward(self, x):
18
+ return self.relu1(self.bnorm1(self.conv1(x)))
19
+
20
+
21
+ class NCBR(nn.Module):
22
+ def __init__(self, in_channels, out_channels, N, skip=False, skipcat=False):
23
+ super().__init__()
24
+
25
+ assert N > 1
26
+ self.skip = skip
27
+ self.skipcat = skipcat
28
+
29
+ channels = [in_channels] + [out_channels] * N # len(channels) == N+1
30
+
31
+ self.layers = nn.ModuleList()
32
+ for i in range(N):
33
+ self.layers.append(CBR(channels[i], channels[i + 1]))
34
+
35
+ def forward(self, x):
36
+ for i, layer in enumerate(self.layers):
37
+ if i == 0:
38
+ x = layer(x)
39
+ x1 = x
40
+ else:
41
+ x = layer(x)
42
+ if self.skip:
43
+ if self.skipcat:
44
+ x = torch.cat([x, x1], dim=1)
45
+ else:
46
+ x = x + x1
47
+ return x
48
+
49
+
50
+ class DownNCBR(nn.Module):
51
+ """Downscaling with maxpool then NCBR"""
52
+ def __init__(self, in_channels, out_channels, N, skip=False, skipcat=False):
53
+ super().__init__()
54
+ self.maxpool = nn.MaxPool2d(2)
55
+ self.ncbr = NCBR(in_channels, out_channels, N=N, skip=skip, skipcat=skipcat)
56
+
57
+ def forward(self, x):
58
+ return self.ncbr(self.maxpool(x))
59
+
60
+
61
+ class UpNCBR(nn.Module):
62
+ """Upscaling then NCBR"""
63
+ def __init__(self, in_channels, out_channels, N, skip=False, skipcat=False):
64
+ super().__init__()
65
+ self.upconv = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
66
+ self.ncbr = NCBR(in_channels, out_channels, N=N, skip=skip, skipcat=skipcat)
67
+
68
+ def forward(self, x1, x2):
69
+ x1 = self.upconv(x1)
70
+ # input is CHW
71
+ diffY = x2.size()[2] - x1.size()[2]
72
+ diffX = x2.size()[3] - x1.size()[3]
73
+
74
+ x1 = F.pad(
75
+ x1,
76
+ [diffX // 2, diffX - diffX // 2,
77
+ diffY // 2, diffY - diffY // 2]
78
+ )
79
+ x = torch.cat([x2, x1], dim=1)
80
+ return self.ncbr(x)
81
+
82
+
83
+ class OutConv(nn.Module):
84
+ def __init__(self, in_channels, out_channels, kernel_size=1):
85
+ super().__init__()
86
+ padding = kernel_size // 2
87
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
88
+
89
+ def forward(self, x):
90
+ return self.conv(x)
91
+
92
+
93
+ # -------- HF config class --------
94
+
95
+ class UNetConfig(PretrainedConfig):
96
+ """
97
+ Config for the custom UNet. This is what AutoConfig will load
98
+ when trust_remote_code=True.
99
+ """
100
+ model_type = "custom_unet"
101
+
102
+ def __init__(
103
+ self,
104
+ n_channels=1,
105
+ n_classes=1,
106
+ N=2,
107
+ width=32,
108
+ skip=True,
109
+ skipcat=False,
110
+ catorig=False,
111
+ outker=1,
112
+ **kwargs,
113
+ ):
114
+ super().__init__(**kwargs)
115
+
116
+ self.n_channels = n_channels
117
+ self.n_classes = n_classes
118
+ self.N = N
119
+ self.width = width
120
+ self.skip = skip
121
+ self.skipcat = skipcat
122
+ self.catorig = catorig
123
+ self.outker = outker
124
+
125
+ # This is also written into config.json when saving from code,
126
+ # but having it here makes it explicit.
127
+ self.auto_map = {
128
+ "AutoConfig": "unet.UNetConfig",
129
+ "AutoModel": "unet.UNet",
130
+ }
131
+
132
+
133
+ # -------- HF model wrapper around your architecture --------
134
+
135
+ class UNet(PreTrainedModel):
136
+ """
137
+ PreTrainedModel wrapper so AutoModel can construct and load this.
138
+ """
139
+ config_class = UNetConfig
140
+
141
+ def __init__(self, config: UNetConfig):
142
+ super().__init__(config)
143
+
144
+ n_channels = config.n_channels
145
+ n_classes = config.n_classes
146
+ N = config.N
147
+ width = config.width
148
+ skip = config.skip
149
+ skipcat = config.skipcat
150
+ self.catorig = config.catorig
151
+ outker = config.outker
152
+
153
+ self.n_channels = n_channels
154
+ self.n_classes = n_classes
155
+
156
+ self.inc = NCBR(self.n_channels, 2 * width, N=N, skip=skip)
157
+ self.down1 = DownNCBR(2 * width, 2 * width if skipcat else 4 * width, N=N, skip=skip, skipcat=skipcat)
158
+ self.down2 = DownNCBR(4 * width, 4 * width if skipcat else 8 * width, N=N, skip=skip, skipcat=skipcat)
159
+ self.down3 = DownNCBR(8 * width, 8 * width if skipcat else 16 * width, N=N, skip=skip, skipcat=skipcat)
160
+ self.down4 = DownNCBR(16 * width, 16 * width if skipcat else 32 * width, N=N, skip=skip, skipcat=skipcat)
161
+ self.up1 = UpNCBR(32 * width, 8 * width if skipcat else 16 * width, N=N, skip=skip, skipcat=skipcat)
162
+ self.up2 = UpNCBR(16 * width, 4 * width if skipcat else 8 * width, N=N, skip=skip, skipcat=skipcat)
163
+ self.up3 = UpNCBR(8 * width, 2 * width if skipcat else 4 * width, N=N, skip=skip, skipcat=skipcat)
164
+ self.up4 = UpNCBR(4 * width, width if skipcat else 2 * width, N=N, skip=skip, skipcat=skipcat)
165
+
166
+ if self.catorig:
167
+ self.outc = OutConv(2 * width + self.n_channels, self.n_classes, kernel_size=outker)
168
+ else:
169
+ self.outc = OutConv(2 * width, self.n_classes, kernel_size=outker)
170
+
171
+ # HF weight init hook
172
+ self.post_init()
173
+
174
+ def forward(self, x):
175
+ orig = x
176
+ x1 = self.inc(x)
177
+ x2 = self.down1(x1)
178
+ x3 = self.down2(x2)
179
+ x4 = self.down3(x3)
180
+ x5 = self.down4(x4)
181
+ x = self.up1(x5, x4)
182
+ x = self.up2(x, x3)
183
+ x = self.up3(x, x2)
184
+ x = self.up4(x, x1)
185
+ if self.catorig:
186
+ logits = self.outc(torch.cat([x, orig], dim=1))
187
+ else:
188
+ logits = self.outc(x)
189
+ return logits