Spaces:
Sleeping
Sleeping
refactor: add implementations
Browse files- res/impl/DeepLabV3Plus.py +404 -0
- res/impl/FCN.py +130 -0
- res/impl/HRNetV2.py +378 -0
- res/impl/PSPNet.py +240 -0
- res/impl/SETR.py +291 -0
- res/impl/SegFormer.py +294 -0
- res/impl/UNet3PlusDeepSup.py +241 -0
- res/models/hrnetv2/best_config.json +0 -151
res/impl/DeepLabV3Plus.py
ADDED
|
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
paper: https://arxiv.org/abs/1802.02611
|
| 3 |
+
ref:
|
| 4 |
+
- https://github.com/tensorflow/models/tree/master/research/deeplab
|
| 5 |
+
- https://github.com/VainF/DeepLabV3Plus-Pytorch
|
| 6 |
+
- https://github.com/Hyunjulie/KR-Reading-Computer-Vision-Papers/blob/master/DeepLabv3%2B/deeplabv3p.py
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import math
|
| 10 |
+
import torch
|
| 11 |
+
from torch import nn
|
| 12 |
+
from torch.functional import F
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class AtrousSeparableConv1d(nn.Module):
    """Atrous (dilated) depthwise-separable 1-D convolution.

    A depthwise conv (one filter per input channel) extracts temporal
    features, then a 1x1 pointwise conv mixes channels.  The input is
    padded in ``forward`` so that, for stride 1, the output length equals
    the input length regardless of the dilation rate.
    """

    def __init__(
        self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False
    ):
        super(AtrousSeparableConv1d, self).__init__()

        # padding=0 here on purpose: forward() applies fixed padding itself.
        self.depthwise = nn.Conv1d(
            inplanes,
            inplanes,
            kernel_size,
            stride,
            0,
            dilation,
            groups=inplanes,
            bias=bias,
        )
        # 1x1 projection onto `planes` output channels.
        self.pointwise = nn.Conv1d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)

    def forward(self, x):
        k = self.depthwise.kernel_size[0]
        d = self.depthwise.dilation[0]
        padded = self.apply_fixed_padding(x, k, rate=d)
        return self.pointwise(self.depthwise(padded))

    def apply_fixed_padding(self, inputs, kernel_size, rate):
        """Pad `inputs` so the depthwise conv preserves the input length.

        The amount of padding depends on the (dilation) rate and the kernel
        size.  With stride >= 2 the output may still be shorter than the
        input; the network's final upsample step reconciles the lengths.
        """
        # Effective kernel size of a dilated conv: (k - 1) * rate + 1.
        effective = (kernel_size - 1) * rate + 1
        total = effective - 1
        left = total // 2
        return F.pad(inputs, (left, total - left))
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class Block(nn.Module):
    """Xception-style residual block built from AtrousSeparableConv1d layers.

    The main path stacks ``reps`` (ReLU -> separable conv -> BN) units and a
    skip connection is added on top.  When channels or stride change, the
    skip path is a strided 1x1 conv + BN; otherwise it is the identity.

    Args:
        inplanes: input channel count.
        planes: output channel count.
        reps: number of separable convs in the main path.
        kernel_size: kernel size of every separable conv.
        stride: 1 keeps the length, 2 halves it; anything else raises.
        dilation: dilation rate of the separable convs.
        start_with_relu: when False, drop the leading ReLU (used right after
            the stem, whose output was already activated).
        grow_first: widen inplanes -> planes with the first conv (True) or
            with the last conv (False).
        is_last: when stride == 1, append one extra stride-1 separable conv.
    """

    def __init__(
        self,
        inplanes,
        planes,
        reps,
        kernel_size=3,
        stride=1,
        dilation=1,
        start_with_relu=True,
        grow_first=True,
        is_last=False,
    ):
        super(Block, self).__init__()

        # Skip path: only project/downsample when the shapes differ.
        if planes != inplanes or stride != 1:
            self.skip = nn.Conv1d(inplanes, planes, 1, stride=stride, bias=False)
            self.skipbn = nn.BatchNorm1d(planes)
        else:
            self.skip = None

        self.relu = nn.ReLU(inplace=True)
        rep = []

        filters = inplanes
        if grow_first:
            # Widen to `planes` immediately; the remaining reps keep it.
            rep.append(self.relu)
            rep.append(
                AtrousSeparableConv1d(
                    inplanes, planes, kernel_size, stride=1, dilation=dilation
                )
            )
            rep.append(nn.BatchNorm1d(planes))
            filters = planes

        # Middle reps keep the channel count (`filters`) unchanged.
        for _ in range(reps - 1):
            rep.append(self.relu)
            rep.append(
                AtrousSeparableConv1d(
                    filters, filters, kernel_size, stride=1, dilation=dilation
                )
            )
            rep.append(nn.BatchNorm1d(filters))

        if not grow_first:
            # Widen to `planes` with the final conv instead of the first.
            rep.append(self.relu)
            rep.append(
                AtrousSeparableConv1d(
                    inplanes, planes, kernel_size, stride=1, dilation=dilation
                )
            )
            rep.append(nn.BatchNorm1d(planes))

        if not start_with_relu:
            rep = rep[1:]  # drop the leading ReLU

        if stride == 2:
            # Downsample with one strided separable conv at the end.
            rep.append(AtrousSeparableConv1d(planes, planes, kernel_size, stride=2))
        elif stride == 1:
            if is_last:
                rep.append(AtrousSeparableConv1d(planes, planes, kernel_size, stride=1))
        else:
            raise NotImplementedError("stride must be 1 or 2 in Block.")

        self.rep = nn.Sequential(*rep)

    def forward(self, inp):
        x = self.rep(inp)

        if self.skip is not None:
            skip = self.skip(inp)
            skip = self.skipbn(skip)
        else:
            skip = inp

        # Residual addition (in place on the main-path output).
        x += skip

        return x
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class Xception(nn.Module):
    """Modified Aligned Xception backbone (1-D variant).

    Entry flow -> middle flow (``middle_repeat`` identical blocks) ->
    exit flow.  ``forward`` returns both the final 2048-channel feature map
    and the ``entry1`` output (128 channels), which DeepLabV3+ consumes as
    the low-level decoder feature.

    Args:
        inplanes: input channels (1 for a single-lead signal).
        output_stride: overall downsampling factor, 16 or 8; only controls
            the stride of the third entry block here.
        kernel_size: kernel size used throughout the backbone.
        middle_repeat: number of middle-flow blocks.
        middle_block_rate: dilation rate of the middle-flow blocks.
        exit_block_rates: (exit-block dilation, exit-conv dilation).
    """

    def __init__(
        self,
        inplanes=1,
        output_stride=16,
        kernel_size=3,
        middle_repeat=16,
        middle_block_rate=1,
        exit_block_rates=(1, 2),
    ):
        super(Xception, self).__init__()

        # With output_stride 8 the last entry block stops downsampling.
        if output_stride == 16:
            entry3_stride = 2
        elif output_stride == 8:
            entry3_stride = 1
        else:
            raise NotImplementedError

        # Stem: first conv halves the length, second keeps it.
        self.conv1 = nn.Conv1d(
            inplanes,
            32,
            kernel_size,
            stride=2,
            padding=(kernel_size - 1) // 2,
            bias=False,
        )
        self.bn1 = nn.BatchNorm1d(32)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv1d(
            32, 64, kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=False
        )
        self.bn2 = nn.BatchNorm1d(64)

        # Entry flow: 64 -> 128 -> 256 -> 728 channels.
        self.entry1 = Block(
            64, 128, reps=2, kernel_size=kernel_size, stride=2, start_with_relu=False
        )
        self.entry2 = Block(
            128,
            256,
            reps=2,
            kernel_size=kernel_size,
            stride=2,
            start_with_relu=True,
            grow_first=True,
        )
        self.entry3 = Block(
            256,
            728,
            reps=2,
            kernel_size=kernel_size,
            stride=entry3_stride,
            start_with_relu=True,
            grow_first=True,
            is_last=True,
        )

        # Middle flow: `middle_repeat` identical 728-channel blocks.
        self.middle = nn.Sequential(
            *[
                Block(
                    728,
                    728,
                    reps=3,
                    kernel_size=kernel_size,
                    stride=1,
                    dilation=middle_block_rate,
                    start_with_relu=True,
                    grow_first=True,
                )
                for _ in range(middle_repeat)
            ]
        )

        # Exit flow: one block then three separable convs up to 2048 channels.
        self.exit = Block(
            728,
            1024,
            reps=2,
            kernel_size=kernel_size,
            stride=1,
            dilation=exit_block_rates[0],
            start_with_relu=True,
            grow_first=False,
            is_last=True,
        )

        self.conv3 = AtrousSeparableConv1d(
            1024, 1536, kernel_size, stride=1, dilation=exit_block_rates[1]
        )
        self.bn3 = nn.BatchNorm1d(1536)

        self.conv4 = AtrousSeparableConv1d(
            1536, 1536, kernel_size, stride=1, dilation=exit_block_rates[1]
        )
        self.bn4 = nn.BatchNorm1d(1536)

        self.conv5 = AtrousSeparableConv1d(
            1536, 2048, kernel_size, stride=1, dilation=exit_block_rates[1]
        )
        self.bn5 = nn.BatchNorm1d(2048)

    def forward(self, x: torch.Tensor):
        """Return (high-level features, low-level ``entry1`` features)."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        # Keep the entry1 output as the low-level feature for the decoder.
        low_level = x = self.entry1(x)

        x = self.entry2(x)
        x = self.entry3(x)

        x = self.middle(x)

        x = self.exit(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)

        x = self.conv4(x)
        x = self.bn4(x)
        x = self.relu(x)

        x = self.conv5(x)
        x = self.bn5(x)
        x = self.relu(x)

        return x, low_level
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
class ASPP(nn.Module):
    """One Atrous Spatial Pyramid Pooling branch: dilated conv -> BN -> ReLU.

    Args:
        inplanes: input channel count.
        planes: output channel count.
        rate: dilation rate of the branch; a rate of 1 degenerates to a
            plain 1x1 projection.
        kernel_size: kernel size used when rate > 1.
    """

    def __init__(self, inplanes, planes, rate, kernel_size=3):
        super(ASPP, self).__init__()
        if rate == 1:
            # Rate-1 branch is a 1x1 conv, so no padding is needed.
            kernel_size, padding = 1, 0
        else:
            # Pad so the dilated conv preserves the sequence length.
            padding = rate * (kernel_size - 1) // 2
        self.atrous_convolution = nn.Conv1d(
            inplanes,
            planes,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            dilation=rate,
            bias=False,
        )
        self.bn = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.relu(self.bn(self.atrous_convolution(x)))
        return out
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
class DeepLabV3Plus(nn.Module):
    """DeepLabV3+ for 1-D segmentation: Xception backbone + ASPP + decoder.

    ``config`` must provide: output_stride, kernel_size, middle_block_rate,
    exit_block_rates, middle_repeat, interpolate_mode, aspp_channel,
    aspp_rate (four rates), output_size.
    """

    def __init__(self, config):
        super(DeepLabV3Plus, self).__init__()

        self.config = config
        # output_stride: (input's spatial resolution / output's resolution)
        output_stride = int(config.output_stride)
        kernel_size = int(config.kernel_size)
        middle_block_rate = int(config.middle_block_rate)
        exit_block_rates: list = config.exit_block_rates
        middle_repeat = int(config.middle_repeat)
        self.interpolate_mode = str(config.interpolate_mode)
        aspp_channel = int(config.aspp_channel)
        aspp_rate: list = config.aspp_rate
        output_size = config.output_size  # 3(p, qrs, t)

        self.xception_features = Xception(
            output_stride=output_stride,
            kernel_size=kernel_size,
            middle_repeat=middle_repeat,
            middle_block_rate=middle_block_rate,
            exit_block_rates=exit_block_rates,
        )

        # ASPP: four parallel branches at different dilation rates,
        # all reading the 2048-channel backbone output.
        self.aspp1 = ASPP(
            2048, aspp_channel, rate=aspp_rate[0], kernel_size=kernel_size
        )
        self.aspp2 = ASPP(
            2048, aspp_channel, rate=aspp_rate[1], kernel_size=kernel_size
        )
        self.aspp3 = ASPP(
            2048, aspp_channel, rate=aspp_rate[2], kernel_size=kernel_size
        )
        self.aspp4 = ASPP(
            2048, aspp_channel, rate=aspp_rate[3], kernel_size=kernel_size
        )

        self.relu = nn.ReLU()

        # Fifth ASPP branch: image-level (global) pooling.
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(2048, aspp_channel, 1, stride=1, bias=False),
            nn.BatchNorm1d(aspp_channel),
            nn.ReLU(),
        )

        # Fuse the five concatenated ASPP branches back to aspp_channel.
        self.conv1 = nn.Conv1d(aspp_channel * 5, aspp_channel, 1, bias=False)
        self.bn1 = nn.BatchNorm1d(aspp_channel)

        # adopt [1x1, 48] for channel reduction.
        # NOTE: 128 matches the channel count of Xception's entry1 output.
        self.conv2 = nn.Conv1d(128, 48, 1, bias=False)
        self.bn2 = nn.BatchNorm1d(48)

        # Decoder head: two convs on the fused features, then the classifier.
        self.last_conv = nn.Sequential(
            nn.Conv1d(
                aspp_channel + 48,
                256,
                kernel_size=kernel_size,
                stride=1,
                padding=(kernel_size - 1) // 2,
                bias=False,
            ),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Conv1d(
                256,
                256,
                kernel_size=kernel_size,
                stride=1,
                padding=(kernel_size - 1) // 2,
                bias=False,
            ),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Conv1d(256, output_size, kernel_size=1, stride=1),
        )

    def forward(self, input):
        """Return per-position class logits of shape (N, output_size, L)."""
        x, low_level_features = self.xception_features(input)

        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        # Broadcast the pooled branch back to the spatial size of the others.
        x5 = F.interpolate(x5, size=x4.shape[2:], mode=self.interpolate_mode)

        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        # Upsample to the low-level feature resolution (input length / 4).
        # NOTE(review): assumes ceil(L/4) equals the entry1 output length —
        # confirm for the input lengths used in training.
        x = F.interpolate(
            x, size=int(math.ceil(input.shape[-1] / 4)), mode=self.interpolate_mode
        )

        low_level_features = self.conv2(low_level_features)
        low_level_features = self.bn2(low_level_features)
        low_level_features = self.relu(low_level_features)

        # Decoder: fuse upsampled ASPP output with low-level features,
        # refine, then upsample to the original input length.
        x = torch.cat((x, low_level_features), dim=1)
        x = self.last_conv(x)
        x = F.interpolate(x, size=input.shape[2:], mode=self.interpolate_mode)

        return x
|
res/impl/FCN.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
paper: https://arxiv.org/abs/1605.06211
|
| 3 |
+
ref: https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/voc-fcn8s/net.py
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class FCN(nn.Module):
    """Fully Convolutional Network (FCN-8s style) adapted to 1-D signals.

    An encoder of ``num_layers`` conv+pool stages feeds a 4096-channel head;
    intermediate pool outputs (from ``combine_until`` onward) are scored,
    upsampled, and summed skip-connection style, then the result is
    upsampled back to the input length and center-cropped.

    ``config`` must provide: kernel_size, last_layer_kernel_size, inplanes,
    combine_conf (dict with "num_layers" in {4, 5, 6} and "combine_until"),
    num_convs, dilation, dropout, output_size.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.kernel_size = int(config.kernel_size)
        last_layer_kernel_size = int(config.last_layer_kernel_size)
        inplanes = int(config.inplanes)
        combine_conf: dict = config.combine_conf
        self.num_layers = int(combine_conf["num_layers"])
        # Large first-layer padding (FCN's pad-100 trick) sized per depth so
        # later crops stay valid.
        self.first_padding = {6: 240, 5: 130, 4: 80}[self.num_layers]
        self.num_convs = int(config.num_convs)
        self.dilation = int(config.dilation)
        self.combine_until = int(combine_conf["combine_until"])
        assert self.combine_until < self.num_layers
        dropout = float(config.dropout)
        output_size = config.output_size  # 3(p, qrs, t)

        # Encoder: each stage doubles the channels and halves the length.
        self.layers = nn.ModuleList()
        for i in range(self.num_layers):
            self.layers.append(
                self._make_layer(
                    1 if i == 0 else inplanes * (2 ** (i - 1)),
                    inplanes * (2 ** (i)),
                    is_first=True if i == 0 else False,
                )
            )
        # Final conv stage without pooling: unlike the other stages it has a
        # fixed structure (two convs, 4096 channels) and applies dropout.
        # NOTE(review): relies on `i` leaking from the loop above
        # (== num_layers - 1); NameErrors if num_layers == 0.
        self.layers.append(
            nn.Sequential(
                nn.Conv1d(inplanes * (2 ** (i)), 4096, last_layer_kernel_size),
                nn.BatchNorm1d(4096),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Conv1d(4096, 4096, 1),
                nn.BatchNorm1d(4096),
                nn.ReLU(),
                nn.Dropout(dropout),
            )
        )
        self.score_convs = []
        self.up_convs = []
        for i in range(self.combine_until, self.num_layers - 1):
            # One score conv + one 2x up-conv per combined pool output.
            self.score_convs.append(
                nn.Conv1d(inplanes * (2 ** (i)), output_size, kernel_size=1, bias=False)
            )
            self.up_convs.append(
                nn.ConvTranspose1d(output_size, output_size, kernel_size=4, stride=2)
            )
        # Score conv applied to the pool-less final stage's output.
        # score_convs always has exactly one more entry than up_convs.
        self.score_convs.append(nn.Conv1d(4096, output_size, kernel_size=1, bias=False))

        # Reverse so score_convs[0] matches the deepest features first.
        # (up_convs are all identical, so their order does not matter.)
        self.score_convs.reverse()
        self.score_convs = nn.ModuleList(self.score_convs)
        self.up_convs = nn.ModuleList(self.up_convs)
        # Final upsample covering the remaining 2**(combine_until+1) factor.
        self.last_up_convs = nn.ConvTranspose1d(
            output_size,
            output_size,
            kernel_size=2 ** (self.combine_until + 1) * 2,  # stride * 2
            stride=2 ** (self.combine_until + 1),
        )

    def _make_layer(
        self,
        in_channel: int,
        out_channel: int,
        is_first: bool = False,
    ):
        """Build one encoder stage: num_convs x (conv-BN-ReLU) + 2x max-pool.

        The very first conv of the network uses the large ``first_padding``;
        all other convs use length-preserving padding.
        """
        layer = []
        plane = in_channel
        for idx in range(self.num_convs):
            layer.append(
                nn.Conv1d(
                    plane,
                    out_channel,
                    kernel_size=self.kernel_size,
                    padding=self.first_padding
                    if idx == 0 and is_first
                    else (self.dilation * (self.kernel_size - 1)) // 2,
                    dilation=self.dilation,
                    bias=False,
                )
            )
            layer.append(nn.BatchNorm1d(out_channel))
            layer.append(nn.ReLU())
            plane = out_channel

        layer.append(nn.MaxPool1d(2, 2, ceil_mode=True))
        return nn.Sequential(*layer)

    def forward(self, input: torch.Tensor, y=None):
        """Return logits cropped to the input length; `y` is unused here."""
        output: torch.Tensor = input

        # Run the encoder, keeping the pool outputs that will be combined.
        pools = []
        for idx, layer in enumerate(self.layers):
            output = layer(output)
            if self.combine_until <= idx < (self.num_layers - 1):
                pools.append(output)
        pools.reverse()  # deepest first, matching score_convs order

        output = self.score_convs[0](output)
        if len(pools) > 0:
            output = self.up_convs[0](output)
        for i in range(len(pools)):
            # Score the pooled features, center-crop to the current output
            # length, and fuse by addition (FCN skip connection).
            score_pool = self.score_convs[i + 1](pools[i])
            offset = (score_pool.shape[2] - output.shape[2]) // 2
            cropped_score_pool = torch.tensor_split(
                score_pool, (offset, offset + output.shape[2]), dim=2
            )[1]
            output = torch.add(cropped_score_pool, output)
            if i < len(pools) - 1:  # the final upsample uses last_up_convs
                output = self.up_convs[i + 1](output)
        output = self.last_up_convs(output)

        # Center-crop back to the original input length.
        offset = (output.shape[2] - input.shape[2]) // 2
        cropped_score_pool = torch.tensor_split(
            output, (offset, offset + input.shape[2]), dim=2
        )[1]
        return cropped_score_pool
|
res/impl/HRNetV2.py
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
paper: https://arxiv.org/abs/1904.04514
|
| 3 |
+
ref: https://github.com/HRNet/HRNet-Semantic-Segmentation/blob/HRNet-OCR/lib/models/seg_hrnet.py
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from torch.functional import F
|
| 9 |
+
import math
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _gen_same_length_conv(in_channel, out_channel, kernel_size=1, dilation=1):
|
| 13 |
+
"""길이가 변하지 않는 conv 생성, block 내에서 feature 를 추출하는 convolution 에서 사용"""
|
| 14 |
+
return nn.Conv1d(
|
| 15 |
+
in_channel,
|
| 16 |
+
out_channel,
|
| 17 |
+
kernel_size=kernel_size,
|
| 18 |
+
stride=1,
|
| 19 |
+
padding=(dilation * (kernel_size - 1)) // 2,
|
| 20 |
+
dilation=dilation,
|
| 21 |
+
bias=False,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _gen_downsample(in_channel, out_channel):
|
| 26 |
+
"""kernel_size:3, stride:2, padding:1 인 2배 downsample 하는 conv 생성"""
|
| 27 |
+
return nn.Conv1d(
|
| 28 |
+
in_channel, out_channel, kernel_size=3, stride=2, padding=1, bias=False
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _gen_channel_change_conv(in_channel, out_channel):
|
| 33 |
+
"""kernel_size:1, stride:1 인 channel 변경하는 conv 생성"""
|
| 34 |
+
return nn.Conv1d(in_channel, out_channel, kernel_size=1, stride=1, bias=False)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class BasicBlock(nn.Module):
    """ResNet basic block (1-D): two convs with a residual connection.

    Channel change is inplanes -> planes; when they differ the residual
    path uses a 1x1 projection, otherwise the identity.
    """

    expansion = 1  # output channels == planes * expansion

    def __init__(self, inplanes, planes, kernel_size=3, dilation=1):
        super().__init__()
        self.conv1 = _gen_same_length_conv(inplanes, planes, kernel_size, dilation)
        self.bn1 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU()
        self.conv2 = _gen_same_length_conv(planes, planes, kernel_size, dilation)
        self.bn2 = nn.BatchNorm1d(planes)
        # Residual path: project channels only when they differ.
        if inplanes != planes:
            self.make_residual = _gen_channel_change_conv(inplanes, planes)
        else:
            self.make_residual = nn.Identity()

    def forward(self, x):
        residual = self.make_residual(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Add the shortcut, then apply the final activation.
        return self.relu(out + residual)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class Bottleneck(nn.Module):
    """ResNet bottleneck block (1-D): 1x1 reduce -> conv -> 1x1 expand.

    Channel change is inplanes -> planes * 4; when that differs from the
    input, the residual path uses a 1x1 projection, otherwise the identity.
    """

    expansion = 4  # output channels == planes * expansion

    def __init__(self, inplanes, planes, kernel_size=3, dilation=1):
        super().__init__()
        self.conv1 = _gen_same_length_conv(inplanes, planes)
        self.bn1 = nn.BatchNorm1d(planes)
        self.conv2 = _gen_same_length_conv(planes, planes, kernel_size, dilation)
        self.bn2 = nn.BatchNorm1d(planes)
        self.conv3 = _gen_same_length_conv(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm1d(planes * self.expansion)
        self.relu = nn.ReLU()
        # Residual path: project channels only when they differ.
        if inplanes != planes * self.expansion:
            self.make_residual = _gen_channel_change_conv(
                inplanes, planes * self.expansion
            )
        else:
            self.make_residual = nn.Identity()

    def forward(self, x):
        residual = self.make_residual(x)

        out = self.relu(self.bn1(self.conv1(x)))  # 1x1 reduce
        out = self.relu(self.bn2(self.conv2(out)))  # k-sized conv
        out = self.bn3(self.conv3(out))  # 1x1 expand

        # Add the shortcut, then apply the final activation.
        return self.relu(out + residual)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class HRModule(nn.Module):
    """One HRNet stage module: parallel branches plus cross-branch fusions.

    Builds ``stage_idx + 1`` branches (each a stack of ``num_blocks``
    residual blocks at its own resolution) and, for every output branch, a
    list of fusion layers that bring every input branch to that branch's
    channel count and length (upsample for higher-index inputs, identity or
    1x1 projection for the same index, repeated stride-2 downsamples for
    lower-index inputs).

    NOTE(review): no forward() is defined in this class as shown here —
    presumably the branches/fusions are driven by the enclosing network;
    confirm against the rest of the file.
    """

    def __init__(
        self,
        stage_idx,
        num_blocks,
        block_type_by_stage,
        in_channels_by_stage,
        out_channels_by_stage,
        data_len_by_branch,
        kernel_size,
        dilation,
        interpolate_mode,
    ):
        super().__init__()

        self.branches = nn.ModuleList()
        self.fusions = nn.ModuleList()

        block_type: type[BasicBlock | Bottleneck] = block_type_by_stage[stage_idx]
        in_channels = in_channels_by_stage[stage_idx]
        for i in range(stage_idx + 1):  # build one branch per resolution
            blocks_by_branch = []
            _channels = in_channels[i]
            blocks_by_branch.append(
                block_type(_channels, _channels, kernel_size, dilation)
            )
            for _ in range(1, num_blocks):
                # Later blocks take the expanded channel count as input
                # (expansion is 1 for BasicBlock, 4 for Bottleneck).
                blocks_by_branch.append(
                    block_type(
                        _channels * block_type.expansion,
                        _channels,
                        kernel_size,
                        dilation,
                    )
                )
            self.branches.append(nn.Sequential(*blocks_by_branch))

        out_channels = out_channels_by_stage[stage_idx]
        for i in range(stage_idx + 1):  # fusion targets (output branches)
            fusion_by_branch = nn.ModuleList()
            for j in range(stage_idx + 1):  # fusion sources (input branches)
                if i < j:
                    # Lower-resolution source: project channels, then
                    # upsample to this branch's length.
                    fusion_by_branch.append(
                        nn.Sequential(
                            _gen_channel_change_conv(out_channels[j], in_channels[i]),
                            nn.BatchNorm1d(in_channels[i]),
                            nn.Upsample(
                                size=data_len_by_branch[i], mode=interpolate_mode
                            ),
                        )
                    )
                elif i == j:
                    # Same resolution: identity, or a 1x1 projection when the
                    # channel counts differ.
                    if out_channels[i] != in_channels[j]:
                        fusion_by_branch.append(
                            nn.Sequential(
                                _gen_channel_change_conv(
                                    out_channels[i], in_channels[j]
                                ),
                                nn.BatchNorm1d(in_channels[j]),
                                nn.ReLU(),
                            )
                        )
                    else:
                        fusion_by_branch.append(nn.Identity())
                else:
                    # Higher-resolution source: downsample by 2x once per
                    # branch gap, matching channels to this branch's input.
                    downsamples = [
                        _gen_downsample(out_channels[j], in_channels[i]),
                        nn.BatchNorm1d(in_channels[i]),
                    ]
                    for _ in range(1, i - j):
                        downsamples.extend(
                            [
                                nn.ReLU(),
                                _gen_downsample(in_channels[i], in_channels[i]),
                                nn.BatchNorm1d(in_channels[i]),
                            ]
                        )
                    fusion_by_branch.append(nn.Sequential(*downsamples))
            self.fusions.append(fusion_by_branch)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
class HRNetV2(nn.Module):
    """1-D HRNetV2 for ECG segmentation.

    Maintains parallel multi-resolution branches; each stage adds one more
    (lower-resolution) branch via a transition, and modules within a stage
    repeatedly run the branches and fuse information across resolutions.
    The head concatenates all branch outputs (upsampled to the highest
    resolution) and classifies per time step.
    """

    def __init__(self, config):
        """Build stem, stages (HRModules), transitions and classifier from `config`.

        Assumes `config` carries numeric fields plus per-stage lists
        (`num_modules`, `use_bottleneck`) — TODO confirm against the config schema.
        """
        super().__init__()

        self.config = config
        data_len = int(config.data_len)  # matched to ECGPQRSTDataset.second * hz
        kernel_size = int(config.kernel_size)
        dilation = int(config.dilation)
        num_stages = int(config.num_stages)
        num_blocks = int(config.num_blocks)
        self.num_modules = config.num_modules  # modules per stage, e.g. [1, 1, 4, 3, ...]
        assert num_stages <= len(self.num_modules)
        use_bottleneck = config.use_bottleneck  # 1 -> Bottleneck, 0 -> BasicBlock, per stage, e.g. [1, 0, 0, 0, ...]
        assert num_stages <= len(use_bottleneck)
        stage1_channels = int(config.stage1_channels)  # e.g. 64, 128
        num_channels_init = int(config.num_channels_init)  # base width of branch 0, e.g. 18, 32, 48
        self.interpolate_mode = config.interpolate_mode
        output_size = config.output_size  # number of classes, e.g. 3 (p, qrs, t)

        # stem: two stride-2 convs -> overall 1/4 temporal resolution
        self.stem = nn.Sequential(
            nn.Conv1d(
                1, stage1_channels, kernel_size=3, stride=2, padding=1, bias=False
            ),
            nn.BatchNorm1d(stage1_channels),
            nn.Conv1d(
                stage1_channels,
                stage1_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm1d(stage1_channels),
            nn.ReLU(),
        )
        for _ in range(2):  # compute the data length after passing through the stem
            data_len = math.floor((data_len - 1) / 2 + 1)

        # create meta: build per-stage in/out channel (and per-branch length) info before constructing the network
        in_channels_by_stage = []
        out_channels_by_stage = []
        block_type_by_stage = []
        for stage_idx in range(num_stages):
            block_type_each_stage = (
                Bottleneck if use_bottleneck[stage_idx] == 1 else BasicBlock
            )
            if stage_idx == 0:
                in_channels_each_stage = [stage1_channels]
                out_channels_each_stage = [
                    stage1_channels * block_type_each_stage.expansion
                ]
                data_len_by_branch = [data_len]
            else:
                # branch i of stage s has num_channels_init * 2**i channels
                in_channels_each_stage = [
                    num_channels_init * 2**idx for idx in range(stage_idx + 1)
                ]
                out_channels_each_stage = [
                    (num_channels_init * 2**idx) * block_type_each_stage.expansion
                    for idx in range(stage_idx + 1)
                ]
                # each new branch halves the temporal length of the previous one
                data_len_by_branch.append(
                    math.floor((data_len_by_branch[-1] - 1) / 2 + 1)
                )

            block_type_by_stage.append(block_type_each_stage)
            in_channels_by_stage.append(in_channels_each_stage)
            out_channels_by_stage.append(out_channels_each_stage)

        # create stages: stage s holds num_modules[s] HRModules over s+1 branches
        self.stages = nn.ModuleList()
        for stage_idx in range(num_stages):
            modules_by_stage = nn.ModuleList()
            for _ in range(self.num_modules[stage_idx]):
                modules_by_stage.append(
                    HRModule(
                        stage_idx,
                        num_blocks,
                        block_type_by_stage,
                        in_channels_by_stage,
                        out_channels_by_stage,
                        data_len_by_branch,
                        kernel_size,
                        dilation,
                        self.interpolate_mode,
                    )
                )
            self.stages.append(modules_by_stage)

        # create transitions
        self.transitions = nn.ModuleList()
        for stage_idx in range(num_stages - 1):
            # here stage_idx refers to the *previous* stage; a transition adapts
            # channels between stages or creates the new (extra) branch
            transition_by_stage = nn.ModuleList()
            psc = in_channels_by_stage[stage_idx]  # psc: prev_stage_channels
            nsc = in_channels_by_stage[stage_idx + 1]  # nsc: next_stage_channels
            for nsbi in range(stage_idx + 2):  # nsbi: next_stage_branch_idx
                if nsbi < stage_idx + 1:  # same branch level as an existing branch
                    if psc[nsbi] != nsc[nsbi]:
                        transition_by_stage.append(
                            nn.Sequential(
                                _gen_channel_change_conv(psc[nsbi], nsc[nsbi]),
                                nn.BatchNorm1d(nsc[nsbi]),
                                nn.ReLU(),
                            )
                        )
                    else:
                        transition_by_stage.append(nn.Identity())
                else:  # create new branch from exists branches
                    transition_from_branches = nn.ModuleList()
                    for psbi in range(nsbi):
                        # psbi: prev_stage_branch_idx; downsample 2x per level gap
                        transition_from_one_branch = [
                            _gen_downsample(psc[psbi], nsc[nsbi]),
                            nn.BatchNorm1d(nsc[nsbi]),
                        ]
                        for _ in range(1, nsbi - psbi):
                            transition_from_one_branch.extend(
                                [
                                    nn.ReLU(),
                                    _gen_downsample(nsc[nsbi], nsc[nsbi]),
                                    nn.BatchNorm1d(nsc[nsbi]),
                                ]
                            )
                        transition_from_branches.append(
                            nn.Sequential(*transition_from_one_branch)
                        )
                    transition_by_stage.append(transition_from_branches)
            self.transitions.append(transition_by_stage)

        # head input: concatenation of all final-stage branch channels
        # (in_channels_each_stage still holds the last stage's list here)
        self.cls = nn.Conv1d(sum(in_channels_each_stage), output_size, 1, bias=False)

    def forward(self, input: torch.Tensor, y=None):
        """Run stem -> stages (branches + fusion) -> transitions -> concat head.

        `input` is expected to be (batch, 1, length); `y` is unused here —
        presumably kept for a shared training interface, TODO confirm.
        """
        output: torch.Tensor = input

        output = self.stem(output)

        outputs = [output]
        for stage_idx, stage in enumerate(self.stages):
            for module_idx in range(self.num_modules[stage_idx]):
                # run each branch's block sequence
                for branch_idx in range(stage_idx + 1):
                    outputs[branch_idx] = stage[module_idx].branches[branch_idx](
                        outputs[branch_idx]
                    )
                # fuse: each output branch sums contributions from all branches
                fusion_outputs = []
                for next in range(stage_idx + 1):
                    fusion_output_from_branches = []
                    for prev in range(stage_idx + 1):
                        fusion_output_from_branch: torch.Tensor = stage[
                            module_idx
                        ].fusions[next][prev](outputs[prev])
                        fusion_output_from_branches.append(fusion_output_from_branch)
                    fusion_outputs.append(sum(fusion_output_from_branches))
                outputs = fusion_outputs

            if stage_idx < len(self.stages) - 1:
                transition_outputs = []
                for trans_idx, transition in enumerate(self.transitions[stage_idx]):
                    # a transition holds one Sequential/ModuleList per branch of the next stage:
                    # the leading Sequentials only adapt channels (or pass through via Identity);
                    # the final ModuleList builds the new branch by downsampling and summing
                    # the fused outputs of all existing branches
                    if trans_idx < stage_idx + 1:
                        transition_outputs.append(transition(outputs[trans_idx]))
                    else:
                        transition_outputs.append(
                            sum(
                                [
                                    transition_from_each_branch(output)
                                    for transition_from_each_branch, output in zip(
                                        transition, outputs
                                    )
                                ]
                            )
                        )
                outputs = transition_outputs

        # HRNetV2 head: upsample every branch to the highest resolution and concat
        outputs = [
            F.interpolate(output, size=outputs[0].shape[-1], mode=self.interpolate_mode)
            for output in outputs
        ]
        output = torch.cat(outputs, dim=1)

        # classify, then restore the original input length
        return F.interpolate(
            self.cls(output), size=input.shape[-1], mode=self.interpolate_mode
        )
|
res/impl/PSPNet.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
paper: https://arxiv.org/abs/1612.01105
|
| 3 |
+
ref:
|
| 4 |
+
- https://github.com/hszhao/PSPNet
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
from torch.functional import F
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class PPM(nn.Module):
    """Pyramid Pooling Module (1-D).

    For every bin size, adaptively average-pools the input to that length,
    reduces channels with a 1x1 conv (+BN+ReLU), re-interpolates back to the
    input length, and concatenates all pooled maps with the input along the
    channel axis.
    """

    def __init__(self, in_dim, reduction_dim, bins, interplate_mode):
        super(PPM, self).__init__()
        # One pooling pyramid level per requested bin size.
        self.features = nn.ModuleList(
            nn.Sequential(
                nn.AdaptiveAvgPool1d(bin_size),
                nn.Conv1d(in_dim, reduction_dim, kernel_size=1, bias=False),
                nn.BatchNorm1d(reduction_dim),
                nn.ReLU(),
            )
            for bin_size in bins
        )
        self.interplate_mode = interplate_mode

    def forward(self, x: torch.Tensor):
        # Restore every pyramid level to the input's temporal length.
        length = x.size()[2]
        pooled = [
            F.interpolate(level(x), length, mode=self.interplate_mode)
            for level in self.features
        ]
        # Output channels: in_dim + len(bins) * reduction_dim.
        return torch.cat([x] + pooled, dim=1)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class Bottleneck(nn.Module):
    """Residual bottleneck block for 1-D signals.

    1x1 reduce -> kxk conv (optionally strided/dilated) -> 1x1 expand, with a
    shortcut connection. `downsample`, when given, projects the shortcut to
    match the main path's shape.
    """

    def __init__(
        self,
        inplanes,
        planes,
        expansion=4,
        kernel_size=3,
        stride=1,
        dilation=1,
        padding=1,
        downsample=None,
    ):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv1d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)
        self.conv2 = nn.Conv1d(
            planes,
            planes,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
            bias=False,
        )
        self.bn2 = nn.BatchNorm1d(planes)
        self.conv3 = nn.Conv1d(planes, planes * expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm1d(planes * expansion)
        self.relu = nn.ReLU()
        self.downsample = downsample

    def forward(self, x):
        # Identity shortcut, or a projected one when shapes must change.
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Final activation applied after the residual addition.
        return self.relu(out + shortcut)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class PSPNet(nn.Module):
    """1-D PSPNet for ECG segmentation.

    ResNet-style encoder (stem + bottleneck layers) followed by a Pyramid
    Pooling Module and a conv classifier; during training an auxiliary head
    taps an intermediate layer and its prediction is blended into the output.
    """

    def __init__(self, config):
        super(PSPNet, self).__init__()

        self.config = config
        self.kernel_size = int(config.kernel_size)
        self.padding = (self.kernel_size - 1) // 2  # "same" padding for odd kernels
        self.expansion = int(config.expansion)
        self.inplanes = int(config.inplanes)  # running channel count, mutated by _make_layer
        num_layers = int(config.num_layers)
        self.num_bottlenecks = int(config.num_bottlenecks)
        self.interpolate_mode = str(config.interpolate_mode)
        self.dilation = int(config.dilation)
        ppm_bins: list = config.ppm_bins
        self.aux_idx = int(config.aux_idx)  # which encoder layer feeds the aux head
        assert self.aux_idx < num_layers
        self.aux_ratio = float(config.aux_ratio)  # blend weight of the aux prediction
        dropout = float(config.dropout)
        output_size = config.output_size  # number of classes, e.g. 3 (p, qrs, t)

        # stem downsamples by 1/4 (stride-2 conv + stride-2 maxpool) before the layers
        self.stem = nn.Sequential(
            *[
                nn.Conv1d(
                    1,
                    self.inplanes,
                    self.kernel_size,
                    stride=2,
                    padding=self.padding,
                    bias=False,
                ),
                nn.BatchNorm1d(self.inplanes),
                nn.ReLU(),
                nn.MaxPool1d(self.kernel_size, stride=2, padding=self.padding),
            ]
        )

        # encoder: layer i works at width plane * 2**i (before expansion)
        self.layers = []
        plane = self.inplanes
        for i in range(num_layers):
            self.layers.append(self._make_layer(plane * (2 ** (i))))
        self.layers = nn.ModuleList(self.layers)

        # after the loop self.inplanes holds the final encoder channel count
        encode_dim = self.inplanes
        self.ppm = PPM(
            encode_dim,
            int(encode_dim / len(ppm_bins)),
            ppm_bins,
            self.interpolate_mode,
        )
        # PPM concatenates len(bins) * (encode_dim/len(bins)) channels onto the input -> 2x
        encode_dim *= 2
        self.cls = nn.Sequential(
            nn.Conv1d(
                encode_dim,
                512,
                kernel_size=self.kernel_size,
                padding=self.padding,
                bias=False,
            ),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout1d(dropout),
            nn.Conv1d(512, output_size, kernel_size=1),
        )
        self.aux_branch = nn.Sequential(
            # in-channels must match the tapped layer's output (layer self.aux_idx)
            nn.Conv1d(
                plane * self.expansion * (2**self.aux_idx),
                256,
                kernel_size=self.kernel_size,
                padding=self.padding,
                bias=False,
            ),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout1d(0.1),
            nn.Conv1d(256, output_size, kernel_size=1),
        )

    def _make_layer(self, planes: int):
        """Return a layer of self.num_bottlenecks Bottleneck blocks.

        The first bottleneck downsamples by 2 (with a projected shortcut);
        the remaining bottlenecks use dilated convs with self.dilation.
        Side effect: advances self.inplanes to planes * self.expansion.
        """
        downsample = nn.Sequential(
            nn.Conv1d(
                self.inplanes,
                planes * self.expansion,
                kernel_size=1,
                stride=2,
                bias=False,
            ),
            nn.BatchNorm1d(planes * self.expansion),
        )

        bottlenecks = []
        bottlenecks.append(
            Bottleneck(
                self.inplanes,
                planes,
                expansion=self.expansion,
                kernel_size=self.kernel_size,
                stride=2,
                dilation=1,
                padding=self.padding,
                downsample=downsample,
            )
        )
        self.inplanes = planes * self.expansion
        for _ in range(1, self.num_bottlenecks):
            bottlenecks.append(
                Bottleneck(
                    self.inplanes,
                    planes,
                    expansion=self.expansion,
                    kernel_size=self.kernel_size,
                    stride=1,
                    dilation=self.dilation,
                    # padding chosen so the dilated conv preserves length
                    padding=(self.dilation * (self.kernel_size - 1)) // 2,
                )
            )

        return nn.Sequential(*bottlenecks)

    def forward(self, input: torch.Tensor, y=None):
        """Encode, pool, classify; blend in the auxiliary prediction when training.

        `y` is unused here — presumably kept for a shared training interface,
        TODO confirm.
        """
        output: torch.Tensor = input
        output = self.stem(output)
        for i, _layer in enumerate(self.layers):
            output = _layer(output)
            if i == self.aux_idx:
                aux = output  # tap the intermediate feature map for the aux head

        output = self.ppm(output)
        output = self.cls(output)
        # restore the original input length
        output = F.interpolate(
            output,
            input.shape[2],
            mode=self.interpolate_mode,
        )
        if self.training:
            aux = self.aux_branch(aux)
            aux = F.interpolate(
                aux,
                input.shape[2],
                mode=self.interpolate_mode,
            )
            # NOTE(review): blends predictions rather than losses — unusual PSPNet
            # aux usage, but intentional here judging by aux_ratio's role
            return torch.add(output * (1 - self.aux_ratio), aux * self.aux_ratio)
        else:
            return output
|
res/impl/SETR.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
paper: https://arxiv.org/abs/2012.15840
|
| 3 |
+
- ref
|
| 4 |
+
- encoder:
|
| 5 |
+
- https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/backbones/vit.py
|
| 6 |
+
- https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit_1d.py
|
| 7 |
+
- decoder:
|
| 8 |
+
- https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/models/decode_heads/setr_up_head.py
|
| 9 |
+
- https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/models/decode_heads/setr_mla_head.py
|
| 10 |
+
|
| 11 |
+
- encoder: ViT 와 구조가 동일하며, PatchEmbed 의 경우 patch_size를 kernel_size와 stride 로 하는 Conv1d를 사용
|
| 12 |
+
- decoder: upsample 하는 방식으로 다음 두가지를 사용 (scale_factor: 특정 배수만큼 upsample / size: 특정 크기와 동일한 크기로 upsample)
|
| 13 |
+
- naive: 원본 길이로 size 방식 upsample
|
| 14 |
+
- pup: scale_factor 방식으로 수행하다가 마지막에 원본 길이로 size 방식으로 upsample
|
| 15 |
+
- mla: 총 두 단계로 수행하며, 첫번째 단계에서 transformer block 의 결과들을 scale_factor 방식으로 수행하고 두번째 단계에서 첫번째 결과들을 concat 한 후 size 방식으로 upsample
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import math
|
| 19 |
+
import torch
|
| 20 |
+
from torch import nn
|
| 21 |
+
from einops import rearrange
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class FeedForward(nn.Module):
    """Pre-norm transformer MLP: LayerNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        # The whole pass is the sequential pipeline; shape (B, N, dim) is preserved.
        out = self.net(x)
        return out
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class Attention(nn.Module):
    """Pre-norm multi-head self-attention.

    Projects the normalized input to packed Q/K/V, computes scaled dot-product
    attention per head, and projects the concatenated heads back to `dim`
    (skipped only in the degenerate single-head, dim_head == dim case).
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        needs_projection = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head**-0.5  # 1/sqrt(d_k)

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        if needs_projection:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        x = self.norm(x)
        b, n, _ = x.shape
        # Split packed QKV and reshape each to (b, heads, n, dim_head);
        # equivalent to einops rearrange "b n (h d) -> b h n d".
        q, k, v = (
            t.view(b, n, self.heads, -1).transpose(1, 2)
            for t in self.to_qkv(x).chunk(3, dim=-1)
        )

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))

        context = torch.matmul(weights, v)
        # Merge heads back: "b h n d -> b n (h d)".
        context = context.transpose(1, 2).reshape(b, n, -1)
        return self.to_out(context)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# ========== 여기까지 https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit_1d.py 차용 ==========
|
| 77 |
+
# ========== 아래부터 setr 원본 참고 https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/backbones/vit.py ==========
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class TransformerBlock(nn.Module):
    """One pre-norm transformer encoder layer: self-attention then FFN, each with a residual connection."""

    def __init__(
        self,
        dim,
        num_attn_heads,
        attn_head_dim,
        mlp_dim,
        attn_dropout=0.0,
        ffn_dropout=0.0,
    ):
        super().__init__()
        self.attn = Attention(
            dim, heads=num_attn_heads, dim_head=attn_head_dim, dropout=attn_dropout
        )
        self.ffn = FeedForward(dim, mlp_dim, dropout=ffn_dropout)

    def forward(self, x):
        # Residual around attention, then residual around the feed-forward net.
        x = x + self.attn(x)
        x = x + self.ffn(x)
        return x
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class PatchEmbed(nn.Module):
    """Embed a 1-channel signal as a sequence of non-overlapping patch vectors.

    A Conv1d whose kernel size equals its stride is exactly a per-patch linear
    projection; the result is transposed to (batch, num_patches, embed_dim).
    """

    def __init__(
        self,
        embed_dim=1024,
        kernel_size=16,
        bias=False,
    ):
        super().__init__()

        self.projection = nn.Conv1d(
            in_channels=1,
            out_channels=embed_dim,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=bias,
        )

    def forward(self, x: torch.Tensor):
        embedded = self.projection(x)  # (B, embed_dim, num_patches)
        return embedded.transpose(1, 2)  # (B, num_patches, embed_dim)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class SETR(nn.Module):
    """1-D SETR (SEgmentation TRansformer) for ECG segmentation.

    ViT-style encoder (patch embed + positional embed + transformer blocks)
    with one of three upsampling decoders selected by config:
    - naive: single size-based upsample back to the input length
    - pup:   progressive conv + scale_factor upsamples, final size upsample
    - mla:   conv-upsample features tapped every `output_step` blocks, concat,
             then a final size upsample
    """

    def __init__(self, config):
        super().__init__()

        embed_dim = int(config.embed_dim)
        data_len = int(config.data_len)  # matched to ECGPQRSTDataset.second * hz
        patch_size = int(config.patch_size)
        assert data_len % patch_size == 0
        num_patches = data_len // patch_size
        patch_bias = bool(config.patch_bias)
        dropout = float(config.dropout)
        # pos_dropout_p: float = config.pos_dropout_p  # too many hyperparameters; a single dropout value is used for now
        num_layers = int(config.num_layers)  # number of transformer blocks
        num_attn_heads = int(config.num_attn_heads)
        attn_head_dim = int(config.attn_head_dim)
        mlp_dim = int(config.mlp_dim)
        # attn_dropout: float = config.attn_dropout
        # ffn_dropout: float = config.ffn_dropout
        interpolate_mode = str(config.interpolate_mode)
        dec_conf: dict = config.dec_conf  # single-key dict: {mode: params}
        assert len(dec_conf) == 1
        self.dec_mode: str = list(dec_conf.keys())[0]
        assert self.dec_mode in ["naive", "pup", "mla"]
        self.dec_param: dict = dec_conf[self.dec_mode]
        output_size = int(config.output_size)

        # patch embedding
        self.patch_embed = PatchEmbed(
            embed_dim=embed_dim,
            kernel_size=patch_size,
            bias=patch_bias,
        )

        # positional embedding (learned)
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, embed_dim))
        self.pos_dropout = nn.Dropout(p=dropout)

        # transformer encoder
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                TransformerBlock(
                    dim=embed_dim,
                    num_attn_heads=num_attn_heads,
                    attn_head_dim=attn_head_dim,
                    mlp_dim=mlp_dim,
                    attn_dropout=dropout,
                    ffn_dropout=dropout,
                )
            )

        # decoder
        self.dec_layers = nn.ModuleList()
        if self.dec_mode == "naive":
            self.dec_layers.append(nn.Upsample(size=data_len, mode=interpolate_mode))
            dec_out_channel = embed_dim
        elif self.dec_mode == "pup":
            self.dec_layers.append(nn.LayerNorm(embed_dim))
            dec_up_scale = int(self.dec_param["up_scale"])
            available_up_count = int(
                math.log(data_len // num_patches, dec_up_scale)
            )  # how many scale_factor upsample steps fit; the remainder is covered by the final size upsample
            pup_channels = int(self.dec_param["channels"])
            dec_in_channel = embed_dim
            dec_out_channel = pup_channels
            dec_kernel_size = int(self.dec_param["kernel_size"])
            dec_num_convs_by_layer = int(self.dec_param["num_convs_by_layer"])
            assert dec_kernel_size in [1, 3]  # same restriction as the reference code
            for i in range(available_up_count + 1):
                for _ in range(dec_num_convs_by_layer):
                    self.dec_layers.append(
                        nn.Conv1d(
                            dec_in_channel,
                            dec_out_channel,
                            kernel_size=dec_kernel_size,
                            stride=1,
                            padding=(dec_kernel_size - 1) // 2,
                        )
                    )
                    dec_in_channel = dec_out_channel
                if i < available_up_count:
                    self.dec_layers.append(
                        nn.Upsample(scale_factor=dec_up_scale, mode=interpolate_mode)
                    )
                else:  # last upsample: snap exactly to the input length
                    self.dec_layers.append(
                        nn.Upsample(size=data_len, mode=interpolate_mode)
                    )
        else:  # mla
            dec_up_scale = int(self.dec_param["up_scale"])
            assert (
                data_len >= dec_up_scale * num_patches
            )  # intermediate features upsampled by up_scale must stay shorter than the input, otherwise the final upsample is pointless
            dec_output_step = int(self.dec_param["output_step"])
            assert num_layers % dec_output_step == 0
            dec_num_convs_by_layer = int(self.dec_param["num_convs_by_layer"])
            dec_kernel_size = int(self.dec_param["kernel_size"])
            mid_feature_cnt = num_layers // dec_output_step
            mla_channel = int(self.dec_param["channels"])
            for _ in range(mid_feature_cnt):
                # conv-upsample applied to the feature map tapped every output_step transformer blocks
                dec_in_channel = embed_dim
                dec_layers_each_upsample = []
                for _ in range(dec_num_convs_by_layer):
                    dec_layers_each_upsample.append(
                        nn.Conv1d(
                            dec_in_channel,
                            mla_channel,
                            kernel_size=dec_kernel_size,
                            stride=1,
                            padding=(dec_kernel_size - 1) // 2,
                        )
                    )
                    dec_in_channel = mla_channel
                dec_layers_each_upsample.append(
                    nn.Upsample(scale_factor=dec_up_scale, mode=interpolate_mode)
                )
                self.dec_layers.append(nn.Sequential(*dec_layers_each_upsample))
            # last decoder layer: concat the intermediate feature maps, then upsample to the input length
            self.dec_layers.append(nn.Upsample(size=data_len, mode=interpolate_mode))

            dec_out_channel = (
                mla_channel * mid_feature_cnt
            )  # the per-step feature maps are concatenated channel-wise, so the classifier below must accept mid_feature_cnt times mla_channel channels

        self.cls = nn.Conv1d(dec_out_channel, output_size, 1, bias=False)

    def forward(self, input: torch.Tensor, y=None):
        """Encode patches, run the transformer, decode per dec_mode, classify.

        `y` is unused here — presumably kept for a shared training interface,
        TODO confirm.
        """
        output = input

        # patch embedding: (B, 1, L) -> (B, num_patches, embed_dim)
        output = self.patch_embed(output)

        # positional embedding
        output += self.pos_embed
        output = self.pos_dropout(output)

        outputs = []
        # transformer encoder; mla taps features every output_step blocks
        for i, layer in enumerate(self.layers):
            output = layer(output)
            if self.dec_mode == "mla":
                if (i + 1) % int(self.dec_param["output_step"]) == 0:
                    outputs.append(output.transpose(1, 2))
        if self.dec_mode != "mla":  # for mla the taps were already collected above
            outputs.append(output.transpose(1, 2))

        # decoder
        if self.dec_mode == "naive":
            assert len(outputs) == 1
            output = outputs[0]
            output = self.dec_layers[0](output)
        elif self.dec_mode == "pup":
            assert len(outputs) == 1
            output = outputs[0]
            # LayerNorm expects channels last, so transpose around it
            pup_norm = self.dec_layers[0]
            output = pup_norm(output.transpose(1, 2)).transpose(1, 2)
            for i, dec_layer in enumerate(self.dec_layers[1:]):
                output = dec_layer(output)
        else:  # mla
            dec_output_step = int(self.dec_param["output_step"])
            mid_feature_cnt = len(self.layers) // dec_output_step
            assert len(outputs) == mid_feature_cnt
            for i in range(len(outputs)):
                outputs[i] = self.dec_layers[i](outputs[i])
            output = torch.cat(outputs, dim=1)
            output = self.dec_layers[-1](output)

        return self.cls(output)
|
res/impl/SegFormer.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
paper: https://arxiv.org/abs/2105.15203
|
| 3 |
+
- ref:
|
| 4 |
+
- encoder:
|
| 5 |
+
- https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/backbones/mix_transformer.py
|
| 6 |
+
- https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/backbones/mit.py
|
| 7 |
+
- decoder:
|
| 8 |
+
- https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/decode_heads/segformer_head.py
|
| 9 |
+
- https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/segformer_head.py
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from torch import nn
|
| 15 |
+
from torch.functional import F
|
| 16 |
+
import math
|
| 17 |
+
from einops import rearrange
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class MixFFN(nn.Module):
|
| 21 |
+
def __init__(self, embed_dim, channels, dropout=0.0):
|
| 22 |
+
super().__init__()
|
| 23 |
+
|
| 24 |
+
self.layers = nn.Sequential(
|
| 25 |
+
nn.Conv1d( # fc1
|
| 26 |
+
in_channels=embed_dim, out_channels=channels, kernel_size=1, stride=1
|
| 27 |
+
),
|
| 28 |
+
nn.Conv1d( # position embed (depthwise-separable)
|
| 29 |
+
in_channels=channels,
|
| 30 |
+
out_channels=channels,
|
| 31 |
+
kernel_size=3,
|
| 32 |
+
stride=1,
|
| 33 |
+
padding=1,
|
| 34 |
+
groups=channels,
|
| 35 |
+
),
|
| 36 |
+
nn.GELU(),
|
| 37 |
+
nn.Dropout(dropout),
|
| 38 |
+
nn.Conv1d( # fc2
|
| 39 |
+
in_channels=channels, out_channels=embed_dim, kernel_size=1
|
| 40 |
+
),
|
| 41 |
+
nn.Dropout(dropout),
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
def forward(self, x):
|
| 45 |
+
out = x.transpose(1, 2)
|
| 46 |
+
out = self.layers(out)
|
| 47 |
+
out = out.transpose(1, 2)
|
| 48 |
+
return out
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class EfficientMultiheadAttention(nn.Module):
|
| 52 |
+
"""
|
| 53 |
+
PVT(Pyramid Vision Transformer)에서 사용한 Spatial-Reduction Attention 을 차용
|
| 54 |
+
변수명 중 sr 은 Spatial-Reduction 의 약어
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
def __init__(
|
| 58 |
+
self, embed_dim, num_heads=8, attn_drop=0.0, proj_drop=0.0, sr_ratio=1
|
| 59 |
+
):
|
| 60 |
+
super().__init__()
|
| 61 |
+
|
| 62 |
+
assert (
|
| 63 |
+
embed_dim % num_heads == 0
|
| 64 |
+
), f"dim {embed_dim} should be divided by num_heads {num_heads}."
|
| 65 |
+
|
| 66 |
+
self.num_heads = num_heads
|
| 67 |
+
head_dim = embed_dim // num_heads
|
| 68 |
+
self.scale = head_dim**-0.5
|
| 69 |
+
|
| 70 |
+
self.q = nn.Linear(embed_dim, embed_dim, bias=False)
|
| 71 |
+
self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=False)
|
| 72 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 73 |
+
self.proj = nn.Linear(embed_dim, embed_dim)
|
| 74 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 75 |
+
|
| 76 |
+
self.sr_ratio = sr_ratio
|
| 77 |
+
if sr_ratio > 1:
|
| 78 |
+
self.sr = nn.Conv1d(
|
| 79 |
+
embed_dim, embed_dim, kernel_size=sr_ratio, stride=sr_ratio
|
| 80 |
+
)
|
| 81 |
+
self.norm = nn.LayerNorm(embed_dim)
|
| 82 |
+
|
| 83 |
+
def forward(self, x):
|
| 84 |
+
B, N, C = x.shape
|
| 85 |
+
q = self.q(x)
|
| 86 |
+
q = rearrange(q, "b n (h c) -> b h n c", h=self.num_heads)
|
| 87 |
+
|
| 88 |
+
if self.sr_ratio > 1:
|
| 89 |
+
x_ = x.transpose(1, 2)
|
| 90 |
+
x_ = self.sr(x_).transpose(1, 2)
|
| 91 |
+
x_ = self.norm(x_)
|
| 92 |
+
kv = self.kv(x_)
|
| 93 |
+
kv = rearrange(
|
| 94 |
+
kv,
|
| 95 |
+
"b n (two_heads h c) -> two_heads b h n c",
|
| 96 |
+
two_heads=2,
|
| 97 |
+
h=self.num_heads,
|
| 98 |
+
)
|
| 99 |
+
else:
|
| 100 |
+
kv = self.kv(x)
|
| 101 |
+
kv = rearrange(
|
| 102 |
+
kv,
|
| 103 |
+
"b n (two_heads h c) -> two_heads b h n c",
|
| 104 |
+
two_heads=2,
|
| 105 |
+
h=self.num_heads,
|
| 106 |
+
)
|
| 107 |
+
k, v = kv[0], kv[1]
|
| 108 |
+
|
| 109 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
| 110 |
+
attn = attn.softmax(dim=-1)
|
| 111 |
+
attn = self.attn_drop(attn)
|
| 112 |
+
|
| 113 |
+
x = (attn @ v).transpose(1, 2)
|
| 114 |
+
x = x.reshape(B, N, C)
|
| 115 |
+
x = self.proj(x)
|
| 116 |
+
x = self.proj_drop(x)
|
| 117 |
+
|
| 118 |
+
return x
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class TransformerBlock(nn.Module):
|
| 122 |
+
def __init__(self, embed_dim, num_heads, ffn_channels, dropout=0.2, sr_ratio=1):
|
| 123 |
+
super().__init__()
|
| 124 |
+
|
| 125 |
+
self.attn = nn.Sequential(
|
| 126 |
+
nn.LayerNorm(embed_dim),
|
| 127 |
+
EfficientMultiheadAttention(
|
| 128 |
+
embed_dim=embed_dim,
|
| 129 |
+
num_heads=num_heads,
|
| 130 |
+
attn_drop=dropout,
|
| 131 |
+
proj_drop=dropout,
|
| 132 |
+
sr_ratio=sr_ratio,
|
| 133 |
+
),
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
self.ffn = nn.Sequential(
|
| 137 |
+
nn.LayerNorm(embed_dim),
|
| 138 |
+
MixFFN(embed_dim=embed_dim, channels=ffn_channels, dropout=dropout),
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
def forward(self, x):
|
| 142 |
+
x = x + self.attn(x)
|
| 143 |
+
x = x + self.ffn(x)
|
| 144 |
+
return x
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class PatchEmbed(nn.Module):
|
| 148 |
+
def __init__(
|
| 149 |
+
self,
|
| 150 |
+
in_channels=1,
|
| 151 |
+
embed_dim=1024,
|
| 152 |
+
kernel_size=7,
|
| 153 |
+
stride=4,
|
| 154 |
+
padding=3,
|
| 155 |
+
bias=False,
|
| 156 |
+
):
|
| 157 |
+
super().__init__()
|
| 158 |
+
|
| 159 |
+
self.projection = nn.Conv1d(
|
| 160 |
+
in_channels=in_channels,
|
| 161 |
+
out_channels=embed_dim,
|
| 162 |
+
kernel_size=kernel_size,
|
| 163 |
+
stride=stride,
|
| 164 |
+
padding=padding,
|
| 165 |
+
bias=bias,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
def forward(self, x: torch.Tensor):
|
| 169 |
+
return self.projection(x).transpose(1, 2)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class MiT(nn.Module):
|
| 173 |
+
"""MixVisionTransformer"""
|
| 174 |
+
|
| 175 |
+
def __init__(
|
| 176 |
+
self,
|
| 177 |
+
embed_dim=512,
|
| 178 |
+
num_blocks=[2, 2, 6, 2],
|
| 179 |
+
num_heads=[1, 2, "ceil"],
|
| 180 |
+
sr_ratios=[1, 2, "ceil"],
|
| 181 |
+
mlp_ratio=4,
|
| 182 |
+
dropout=0.2,
|
| 183 |
+
):
|
| 184 |
+
super().__init__()
|
| 185 |
+
|
| 186 |
+
num_stages = len(num_blocks)
|
| 187 |
+
round_func = getattr(math, num_heads[2]) # math.ceil or match.floor
|
| 188 |
+
num_heads = [
|
| 189 |
+
round_func((num_heads[0] * math.pow(num_heads[1], itr)))
|
| 190 |
+
for itr in range(num_stages)
|
| 191 |
+
]
|
| 192 |
+
round_func = getattr(math, sr_ratios[2]) # math.ceil or match.floor
|
| 193 |
+
sr_ratios = [
|
| 194 |
+
round_func(sr_ratios[0] * math.pow(sr_ratios[1], itr))
|
| 195 |
+
for itr in range(num_stages)
|
| 196 |
+
]
|
| 197 |
+
sr_ratios.reverse()
|
| 198 |
+
|
| 199 |
+
self.embed_dims = [embed_dim * num_head for num_head in num_heads]
|
| 200 |
+
patch_kernel_sizes = [7] # [7, 3, 3, ..]
|
| 201 |
+
patch_kernel_sizes.extend([3] * (num_stages - 1))
|
| 202 |
+
patch_strides = [4] # [4, 2, 2, ..]
|
| 203 |
+
patch_strides.extend([2] * (num_stages - 1))
|
| 204 |
+
patch_paddings = [3] # [3, 1, 1, ..]
|
| 205 |
+
patch_paddings.extend([1] * (num_stages - 1))
|
| 206 |
+
|
| 207 |
+
in_channels = 1
|
| 208 |
+
self.stages = nn.ModuleList()
|
| 209 |
+
for i, num_block in enumerate(num_blocks):
|
| 210 |
+
patch_embed = PatchEmbed(
|
| 211 |
+
in_channels=in_channels,
|
| 212 |
+
embed_dim=self.embed_dims[i],
|
| 213 |
+
kernel_size=patch_kernel_sizes[i],
|
| 214 |
+
stride=patch_strides[i],
|
| 215 |
+
padding=patch_paddings[i],
|
| 216 |
+
)
|
| 217 |
+
blocks = nn.ModuleList(
|
| 218 |
+
[
|
| 219 |
+
TransformerBlock(
|
| 220 |
+
embed_dim=self.embed_dims[i],
|
| 221 |
+
num_heads=num_heads[i],
|
| 222 |
+
ffn_channels=mlp_ratio * self.embed_dims[i],
|
| 223 |
+
dropout=dropout,
|
| 224 |
+
sr_ratio=sr_ratios[i],
|
| 225 |
+
)
|
| 226 |
+
for _ in range(num_block)
|
| 227 |
+
]
|
| 228 |
+
)
|
| 229 |
+
in_channels = self.embed_dims[i]
|
| 230 |
+
norm = nn.LayerNorm(self.embed_dims[i])
|
| 231 |
+
self.stages.append(nn.ModuleList([patch_embed, blocks, norm]))
|
| 232 |
+
|
| 233 |
+
def forward(self, x):
|
| 234 |
+
outs = []
|
| 235 |
+
|
| 236 |
+
for stage in self.stages:
|
| 237 |
+
x = stage[0](x) # patch embed
|
| 238 |
+
for block in stage[1]: # transformer blocks
|
| 239 |
+
x = block(x)
|
| 240 |
+
x = stage[2](x) # norm
|
| 241 |
+
x = x.transpose(1, 2)
|
| 242 |
+
outs.append(x)
|
| 243 |
+
|
| 244 |
+
return outs
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class SegFormer(nn.Module):
|
| 248 |
+
def __init__(self, config):
|
| 249 |
+
super().__init__()
|
| 250 |
+
|
| 251 |
+
embed_dim = int(config.embed_dim)
|
| 252 |
+
num_blocks = config.num_blocks
|
| 253 |
+
num_heads = config.num_heads
|
| 254 |
+
assert len(num_heads) == 3 and num_heads[2] in ["floor", "ceil"]
|
| 255 |
+
sr_ratios = config.sr_ratios
|
| 256 |
+
assert len(sr_ratios) == 3 and sr_ratios[2] in ["floor", "ceil"]
|
| 257 |
+
mlp_ratio = int(config.mlp_ratio)
|
| 258 |
+
dropout = float(config.dropout)
|
| 259 |
+
decoder_channels = int(config.decoder_channels)
|
| 260 |
+
self.interpolate_mode = str(config.interpolate_mode)
|
| 261 |
+
output_size = int(config.output_size)
|
| 262 |
+
|
| 263 |
+
self.MiT = MiT(embed_dim, num_blocks, num_heads, sr_ratios, mlp_ratio, dropout)
|
| 264 |
+
|
| 265 |
+
num_stages = len(num_blocks)
|
| 266 |
+
self.decode_mlps = nn.ModuleList(
|
| 267 |
+
[
|
| 268 |
+
nn.Conv1d(self.MiT.embed_dims[i], decoder_channels, 1, bias=False)
|
| 269 |
+
for i in range(num_stages)
|
| 270 |
+
]
|
| 271 |
+
)
|
| 272 |
+
self.decode_fusion = nn.Conv1d(
|
| 273 |
+
decoder_channels * num_stages, decoder_channels, 1, bias=False
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
self.cls = nn.Conv1d(decoder_channels, output_size, 1, bias=False)
|
| 277 |
+
|
| 278 |
+
def forward(self, input: torch.Tensor, y=None):
|
| 279 |
+
output = input
|
| 280 |
+
|
| 281 |
+
output = self.MiT(output)
|
| 282 |
+
for i, (_output, decode_mlp) in enumerate(zip(output, self.decode_mlps)):
|
| 283 |
+
_output = decode_mlp(_output)
|
| 284 |
+
if i != 0:
|
| 285 |
+
_output = F.interpolate(
|
| 286 |
+
_output, size=output[0].shape[2], mode=self.interpolate_mode
|
| 287 |
+
)
|
| 288 |
+
output[i] = _output
|
| 289 |
+
|
| 290 |
+
output = torch.concat(output, dim=1)
|
| 291 |
+
output = self.decode_fusion(output)
|
| 292 |
+
output = self.cls(output)
|
| 293 |
+
|
| 294 |
+
return F.interpolate(output, size=input.shape[2], mode=self.interpolate_mode)
|
res/impl/UNet3PlusDeepSup.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
paper: https://arxiv.org/abs/2004.08790
|
| 3 |
+
ref: https://github.com/ZJUGiveLab/UNet-Version/blob/master/models/UNet_3Plus.py
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch import nn
|
| 8 |
+
from torch.functional import F
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class UNetConv(nn.Module):
|
| 12 |
+
def __init__(
|
| 13 |
+
self,
|
| 14 |
+
in_size,
|
| 15 |
+
out_size,
|
| 16 |
+
is_batchnorm=True,
|
| 17 |
+
num_layers=2,
|
| 18 |
+
kernel_size=3,
|
| 19 |
+
stride=1,
|
| 20 |
+
padding=1,
|
| 21 |
+
):
|
| 22 |
+
super().__init__()
|
| 23 |
+
self.num_layers = num_layers
|
| 24 |
+
|
| 25 |
+
for i in range(num_layers):
|
| 26 |
+
seq = [nn.Conv1d(in_size, out_size, kernel_size, stride, padding)]
|
| 27 |
+
if is_batchnorm:
|
| 28 |
+
seq.append(nn.BatchNorm1d(out_size))
|
| 29 |
+
seq.append(nn.ReLU())
|
| 30 |
+
conv = nn.Sequential(*seq)
|
| 31 |
+
setattr(self, "conv%d" % i, conv)
|
| 32 |
+
in_size = out_size
|
| 33 |
+
|
| 34 |
+
def forward(self, inputs):
|
| 35 |
+
x = inputs
|
| 36 |
+
for i in range(self.num_layers):
|
| 37 |
+
conv = getattr(self, "conv%d" % i)
|
| 38 |
+
x = conv(x)
|
| 39 |
+
|
| 40 |
+
return x
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class UNet3PlusDeepSup(nn.Module):
|
| 44 |
+
def __init__(self, config):
|
| 45 |
+
super().__init__()
|
| 46 |
+
|
| 47 |
+
self.config = config
|
| 48 |
+
inplanes = int(config.inplanes)
|
| 49 |
+
kernel_size = int(config.kernel_size)
|
| 50 |
+
padding = (kernel_size - 1) // 2
|
| 51 |
+
num_encoder_layers = int(config.num_encoder_layers)
|
| 52 |
+
encoder_batchnorm = bool(config.encoder_batchnorm)
|
| 53 |
+
self.num_depths = int(config.num_depths)
|
| 54 |
+
self.interpolate_mode = str(config.interpolate_mode)
|
| 55 |
+
dropout = float(config.dropout)
|
| 56 |
+
self.use_cgm = bool(config.use_cgm)
|
| 57 |
+
# sum_of_sup == True: 모든 sup 을 elementwise sum 하여 하나의 dense map 을 만들어 label 과 loss 를 구함
|
| 58 |
+
# sum_of_sup == False: 각 sup 과 label의 loss 를 각각 구하여 하나의 loss 에 저장
|
| 59 |
+
self.sum_of_sup = bool(config.sum_of_sup)
|
| 60 |
+
# TrialSetup._init_network_params 에서 설정됨
|
| 61 |
+
self.output_size: int = config.output_size
|
| 62 |
+
|
| 63 |
+
# Encoder
|
| 64 |
+
self.encoders = torch.nn.ModuleList()
|
| 65 |
+
for i in range(self.num_depths):
|
| 66 |
+
"""(MaxPool - UNetConv) 를 수행하는 것이 하나의 depth 이고, 예외적으로 첫번째 depth 의 encode 결과는 (UNetConv)만 수행한 것"""
|
| 67 |
+
_encoders = []
|
| 68 |
+
if i != 0:
|
| 69 |
+
_encoders.append(nn.MaxPool1d(2))
|
| 70 |
+
_encoders.append(
|
| 71 |
+
UNetConv(
|
| 72 |
+
1 if i == 0 else (inplanes * (2 ** (i - 1))),
|
| 73 |
+
inplanes * (2**i),
|
| 74 |
+
is_batchnorm=encoder_batchnorm,
|
| 75 |
+
num_layers=num_encoder_layers,
|
| 76 |
+
kernel_size=kernel_size,
|
| 77 |
+
stride=1,
|
| 78 |
+
padding=padding,
|
| 79 |
+
)
|
| 80 |
+
)
|
| 81 |
+
self.encoders.append(nn.Sequential(*_encoders))
|
| 82 |
+
|
| 83 |
+
# CGM: Classification-Guided Module
|
| 84 |
+
if self.use_cgm:
|
| 85 |
+
self.cls = nn.Sequential(
|
| 86 |
+
nn.Dropout(dropout),
|
| 87 |
+
nn.Conv1d(
|
| 88 |
+
inplanes * (2 ** (self.num_depths - 1)), 2 * self.output_size, 1
|
| 89 |
+
),
|
| 90 |
+
nn.AdaptiveMaxPool1d(1),
|
| 91 |
+
nn.Sigmoid(),
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
# Decoder
|
| 95 |
+
self.up_channels = inplanes * self.num_depths
|
| 96 |
+
|
| 97 |
+
self.decoders = torch.nn.ModuleList()
|
| 98 |
+
for i in reversed(range(self.num_depths - 1)):
|
| 99 |
+
"""
|
| 100 |
+
각 decoder 는 각 encode 결과를 MaxPool 하거나 그대로(Conv,BatchNorm,Relu 만) 사용하거나 Upsample 된 결과를 수행하고 concat 하여 (Conv,BatchNorm,Relu)를 수행할 수 있도록 구성
|
| 101 |
+
다만, Upsample 은 encode 결과와 size 를 맞추기 간편하도록 forward 단계에서 torch.functional.interpolate() 로 수행
|
| 102 |
+
"""
|
| 103 |
+
# 각 단계별 decoder 는 항상 num_depths 만큼 구성되고 내부적으로 MaxPool/그대로/Upsample 수행할지가 달라짐
|
| 104 |
+
_decoders = torch.nn.ModuleList()
|
| 105 |
+
for j in range(self.num_depths):
|
| 106 |
+
_each_decoders = []
|
| 107 |
+
if j < i:
|
| 108 |
+
_each_decoders.append(nn.MaxPool1d(2 ** (i - j), ceil_mode=True))
|
| 109 |
+
if i < j < self.num_depths - 1:
|
| 110 |
+
_each_decoders.append(
|
| 111 |
+
nn.Conv1d(
|
| 112 |
+
inplanes * self.num_depths,
|
| 113 |
+
inplanes,
|
| 114 |
+
kernel_size,
|
| 115 |
+
padding=padding,
|
| 116 |
+
)
|
| 117 |
+
)
|
| 118 |
+
else:
|
| 119 |
+
_each_decoders.append(
|
| 120 |
+
nn.Conv1d(
|
| 121 |
+
inplanes * (2**j), inplanes, kernel_size, padding=padding
|
| 122 |
+
)
|
| 123 |
+
)
|
| 124 |
+
_each_decoders.append(nn.BatchNorm1d(inplanes))
|
| 125 |
+
_each_decoders.append(nn.ReLU())
|
| 126 |
+
_decoders.append(nn.Sequential(*_each_decoders))
|
| 127 |
+
_decoders.append(
|
| 128 |
+
nn.Sequential(
|
| 129 |
+
nn.Conv1d(
|
| 130 |
+
self.up_channels, self.up_channels, kernel_size, padding=padding
|
| 131 |
+
),
|
| 132 |
+
nn.BatchNorm1d(self.up_channels),
|
| 133 |
+
nn.ReLU(),
|
| 134 |
+
)
|
| 135 |
+
)
|
| 136 |
+
self.decoders.append(_decoders)
|
| 137 |
+
|
| 138 |
+
# 앞 conv 들은 in channel 이 up_channels(inplanes*num_depths(원본에서는 320)), 마지막 conv 는 마지막 encoder 결과의 output_channel 과 맞춤
|
| 139 |
+
self.sup_conv = torch.nn.ModuleList()
|
| 140 |
+
for i in range(self.num_depths - 1):
|
| 141 |
+
self.sup_conv.append(
|
| 142 |
+
nn.Sequential(
|
| 143 |
+
nn.Conv1d(
|
| 144 |
+
self.up_channels, self.output_size, kernel_size, padding=padding
|
| 145 |
+
),
|
| 146 |
+
nn.BatchNorm1d(self.output_size),
|
| 147 |
+
nn.ReLU(),
|
| 148 |
+
)
|
| 149 |
+
)
|
| 150 |
+
self.sup_conv.append(
|
| 151 |
+
nn.Sequential(
|
| 152 |
+
nn.Conv1d(
|
| 153 |
+
inplanes * (2 ** (self.num_depths - 1)),
|
| 154 |
+
self.output_size,
|
| 155 |
+
kernel_size,
|
| 156 |
+
padding=padding,
|
| 157 |
+
),
|
| 158 |
+
nn.BatchNorm1d(self.output_size),
|
| 159 |
+
nn.ReLU(),
|
| 160 |
+
)
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
def forward(self, input: torch.Tensor, y=None):
|
| 164 |
+
# Encoder
|
| 165 |
+
output = input
|
| 166 |
+
enc_features = [] # X1Ee, X2Ee, .. , X5Ee
|
| 167 |
+
dec_features = [] # X5Ee, X4De, .. , X1De
|
| 168 |
+
for encoder in self.encoders:
|
| 169 |
+
output = encoder(output)
|
| 170 |
+
enc_features.append(output)
|
| 171 |
+
dec_features.append(output)
|
| 172 |
+
|
| 173 |
+
# CGM
|
| 174 |
+
cls_branch_max = None
|
| 175 |
+
if self.use_cgm:
|
| 176 |
+
# (B, 2*3(output_size), 1)
|
| 177 |
+
cls_branch: torch.Tensor = self.cls(enc_features[-1])
|
| 178 |
+
# (B, 3(output_size))
|
| 179 |
+
cls_branch_max = cls_branch.view(
|
| 180 |
+
input.shape[0], self.output_size, 2
|
| 181 |
+
).argmax(2)
|
| 182 |
+
|
| 183 |
+
# Decoder
|
| 184 |
+
for i in reversed(range(self.num_depths - 1)):
|
| 185 |
+
_each_dec_feature = []
|
| 186 |
+
for j in range(self.num_depths):
|
| 187 |
+
if j <= i:
|
| 188 |
+
_each_enc = enc_features[j]
|
| 189 |
+
else:
|
| 190 |
+
_each_enc = F.interpolate(
|
| 191 |
+
dec_features[self.num_depths - j - 1],
|
| 192 |
+
enc_features[i].shape[2],
|
| 193 |
+
mode=self.interpolate_mode,
|
| 194 |
+
)
|
| 195 |
+
_each_dec_feature.append(
|
| 196 |
+
self.decoders[self.num_depths - i - 2][j](_each_enc)
|
| 197 |
+
)
|
| 198 |
+
dec_features.append(
|
| 199 |
+
self.decoders[self.num_depths - i - 2][-1](
|
| 200 |
+
torch.cat(_each_dec_feature, dim=1)
|
| 201 |
+
)
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
sup = []
|
| 205 |
+
for i, (dec_feature, sup_conv) in enumerate(
|
| 206 |
+
zip(dec_features, reversed(self.sup_conv))
|
| 207 |
+
):
|
| 208 |
+
if i < self.num_depths - 1:
|
| 209 |
+
sup.append(
|
| 210 |
+
F.interpolate(
|
| 211 |
+
sup_conv(dec_feature),
|
| 212 |
+
input.shape[2],
|
| 213 |
+
mode=self.interpolate_mode,
|
| 214 |
+
)
|
| 215 |
+
)
|
| 216 |
+
else:
|
| 217 |
+
sup.append(sup_conv(dec_feature))
|
| 218 |
+
|
| 219 |
+
if self.use_cgm:
|
| 220 |
+
if self.sum_of_sup:
|
| 221 |
+
return torch.sigmoid(
|
| 222 |
+
sum(
|
| 223 |
+
[
|
| 224 |
+
torch.einsum("ijk,ij->ijk", [_sup, cls_branch_max])
|
| 225 |
+
for _sup in reversed(sup)
|
| 226 |
+
]
|
| 227 |
+
)
|
| 228 |
+
)
|
| 229 |
+
else:
|
| 230 |
+
return [
|
| 231 |
+
torch.sigmoid(
|
| 232 |
+
torch.einsum("ijk,ij->ijk", [_sup, cls_branch_max])
|
| 233 |
+
for _sup in reversed(sup)
|
| 234 |
+
)
|
| 235 |
+
]
|
| 236 |
+
|
| 237 |
+
else:
|
| 238 |
+
if self.sum_of_sup:
|
| 239 |
+
return torch.sigmoid(sum(sup))
|
| 240 |
+
else:
|
| 241 |
+
return [torch.sigmoid(_sup) for _sup in reversed(sup)]
|
res/models/hrnetv2/best_config.json
DELETED
|
@@ -1,151 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"train": {
|
| 3 |
-
"progress": true,
|
| 4 |
-
"random_seed": 2407041220,
|
| 5 |
-
"resume_dir": [],
|
| 6 |
-
"checkpoint_dir": "/bfai/nfs_export/workspace/share/result/wogh/hrnet/train-240704_123013",
|
| 7 |
-
"checkpoint_save_freq": 1,
|
| 8 |
-
"working_dir": "",
|
| 9 |
-
"user": "wogh",
|
| 10 |
-
"name": "hrnet",
|
| 11 |
-
"exp_name": "wogh:hrnet",
|
| 12 |
-
"type": "supervised",
|
| 13 |
-
"task": "segmentation",
|
| 14 |
-
"epochs": 501,
|
| 15 |
-
"batch_size": 64,
|
| 16 |
-
"hpo": {
|
| 17 |
-
"num_samples": 256,
|
| 18 |
-
"criteria": {
|
| 19 |
-
"jaccard_avg": 1
|
| 20 |
-
},
|
| 21 |
-
"scheduler": {
|
| 22 |
-
"ASHAScheduler": {
|
| 23 |
-
"grace_period": 200,
|
| 24 |
-
"max_t": 501
|
| 25 |
-
}
|
| 26 |
-
}
|
| 27 |
-
},
|
| 28 |
-
"label": {
|
| 29 |
-
"num_labels": 3,
|
| 30 |
-
"path": [
|
| 31 |
-
"/bfai/nfs_export/workspace/share/labels/pqrst/ludb/train.csv",
|
| 32 |
-
"/bfai/nfs_export/workspace/share/labels/pqrst/ludb/valid.csv",
|
| 33 |
-
"/bfai/nfs_export/workspace/share/labels/pqrst/ludb/test.csv"
|
| 34 |
-
],
|
| 35 |
-
"target": [
|
| 36 |
-
"p_onoffs",
|
| 37 |
-
"qrs_onoffs",
|
| 38 |
-
"t_onoffs"
|
| 39 |
-
],
|
| 40 |
-
"split_ratio": [
|
| 41 |
-
1,
|
| 42 |
-
1,
|
| 43 |
-
1
|
| 44 |
-
]
|
| 45 |
-
},
|
| 46 |
-
"resource_per_trial": {
|
| 47 |
-
"num_workers": 1,
|
| 48 |
-
"num_gpus_per_worker": 1,
|
| 49 |
-
"num_cpus_per_worker": 16
|
| 50 |
-
},
|
| 51 |
-
"comment": "",
|
| 52 |
-
"tracking": true,
|
| 53 |
-
"available_resources": {
|
| 54 |
-
"available_gpus": 16.0
|
| 55 |
-
}
|
| 56 |
-
},
|
| 57 |
-
"solver": {
|
| 58 |
-
"SolverPQRST": {
|
| 59 |
-
"mixed_precision": true,
|
| 60 |
-
"gradient_clip": 0.1
|
| 61 |
-
}
|
| 62 |
-
},
|
| 63 |
-
"datasets": [
|
| 64 |
-
{
|
| 65 |
-
"ECGPQRSTDataset": {
|
| 66 |
-
"lead_type": [
|
| 67 |
-
"I",
|
| 68 |
-
"II",
|
| 69 |
-
"III",
|
| 70 |
-
"aVR",
|
| 71 |
-
"aVL",
|
| 72 |
-
"aVF",
|
| 73 |
-
"V1",
|
| 74 |
-
"V2",
|
| 75 |
-
"V3",
|
| 76 |
-
"V4",
|
| 77 |
-
"V5",
|
| 78 |
-
"V6"
|
| 79 |
-
],
|
| 80 |
-
"aux_data": [],
|
| 81 |
-
"normalization": "z_norm",
|
| 82 |
-
"second": 10,
|
| 83 |
-
"hz": 500
|
| 84 |
-
}
|
| 85 |
-
}
|
| 86 |
-
],
|
| 87 |
-
"models": [
|
| 88 |
-
{
|
| 89 |
-
"network": {
|
| 90 |
-
"HRNetV2": {
|
| 91 |
-
"data_len": 5000,
|
| 92 |
-
"kernel_size": 5,
|
| 93 |
-
"dilation": 1,
|
| 94 |
-
"num_stages": 3,
|
| 95 |
-
"num_blocks": 6,
|
| 96 |
-
"num_modules": [
|
| 97 |
-
1,
|
| 98 |
-
1,
|
| 99 |
-
1,
|
| 100 |
-
4,
|
| 101 |
-
3
|
| 102 |
-
],
|
| 103 |
-
"use_bottleneck": [
|
| 104 |
-
1,
|
| 105 |
-
0,
|
| 106 |
-
0,
|
| 107 |
-
0,
|
| 108 |
-
0
|
| 109 |
-
],
|
| 110 |
-
"stage1_channels": 128,
|
| 111 |
-
"num_channels_init": 48,
|
| 112 |
-
"interpolate_mode": "linear",
|
| 113 |
-
"task": "segmentation",
|
| 114 |
-
"num_leads": 12,
|
| 115 |
-
"num_aux": 0,
|
| 116 |
-
"output_size": 3,
|
| 117 |
-
"aux_output_size": 0
|
| 118 |
-
}
|
| 119 |
-
},
|
| 120 |
-
"optimizer": [
|
| 121 |
-
{
|
| 122 |
-
"SGD": {
|
| 123 |
-
"lr": 0.0983058839402403,
|
| 124 |
-
"momentum": 0.9,
|
| 125 |
-
"weight_decay": 0.0003850652731758502,
|
| 126 |
-
"sharpness_min": false
|
| 127 |
-
}
|
| 128 |
-
}
|
| 129 |
-
],
|
| 130 |
-
"scheduler": [
|
| 131 |
-
{
|
| 132 |
-
"PolynomialLR": {
|
| 133 |
-
"total_iters": 501,
|
| 134 |
-
"power": 0.0
|
| 135 |
-
}
|
| 136 |
-
}
|
| 137 |
-
]
|
| 138 |
-
}
|
| 139 |
-
],
|
| 140 |
-
"loss_fns": [
|
| 141 |
-
{
|
| 142 |
-
"BCEWithLogitsLoss": {}
|
| 143 |
-
}
|
| 144 |
-
],
|
| 145 |
-
"cur_epoch": 358,
|
| 146 |
-
"cutoff": [
|
| 147 |
-
0.001163482666015625,
|
| 148 |
-
0.15087890625,
|
| 149 |
-
-0.587890625
|
| 150 |
-
]
|
| 151 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|