| """Author: Yunpeng Chen.""" | |
| import logging | |
| from collections import OrderedDict | |
| import torch | |
| from torch import nn | |
class BN_AC_CONV3D(nn.Module):
    """Pre-activation block: BatchNorm3d -> ReLU -> Conv3d."""

    def __init__(
        self,
        num_in,
        num_filter,
        kernel=(1, 1, 1),
        pad=(0, 0, 0),
        stride=(1, 1, 1),
        g=1,
        bias=False,
    ):
        super().__init__()
        self.bn = nn.BatchNorm3d(num_in)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv3d(
            num_in,
            num_filter,
            kernel_size=kernel,
            padding=pad,
            stride=stride,
            groups=g,
            bias=bias,
        )

    def forward(self, x):
        h = self.relu(self.bn(x))
        return self.conv(h)
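
# Usage sketch (not part of the original file): BN_AC_CONV3D is pre-activation,
# i.e. BatchNorm -> ReLU runs before the convolution, so the input must carry
# num_in channels and the output carries num_filter channels, e.g.:
#
#   layer = BN_AC_CONV3D(num_in=16, num_filter=32, kernel=(1, 3, 3), pad=(0, 1, 1))
#   layer(torch.randn(2, 16, 8, 56, 56)).shape  # -> torch.Size([2, 32, 8, 56, 56])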
class MF_UNIT(nn.Module):
    """Multi-fiber unit: a lightweight 1x1x1 bottleneck on the input, a grouped
    spatio-temporal main path, and a residual connection."""

    def __init__(
        self,
        num_in,
        num_mid,
        num_out,
        g=1,
        stride=(1, 1, 1),
        first_block=False,
        use_3d=True,
    ):
        super().__init__()
        num_ix = num_mid // 4
        kt, pt = (3, 1) if use_3d else (1, 0)
        # prepare input: bottleneck that mixes information across fibers
        self.conv_i1 = BN_AC_CONV3D(
            num_in=num_in, num_filter=num_ix, kernel=(1, 1, 1), pad=(0, 0, 0)
        )
        self.conv_i2 = BN_AC_CONV3D(
            num_in=num_ix, num_filter=num_in, kernel=(1, 1, 1), pad=(0, 0, 0)
        )
        # main part: grouped spatio-temporal convolutions
        self.conv_m1 = BN_AC_CONV3D(
            num_in=num_in,
            num_filter=num_mid,
            kernel=(kt, 3, 3),
            pad=(pt, 1, 1),
            stride=stride,
            g=g,
        )
        if first_block:
            self.conv_m2 = BN_AC_CONV3D(
                num_in=num_mid, num_filter=num_out, kernel=(1, 1, 1), pad=(0, 0, 0)
            )
        else:
            self.conv_m2 = BN_AC_CONV3D(
                num_in=num_mid, num_filter=num_out, kernel=(1, 3, 3), pad=(0, 1, 1), g=g
            )
        # adapter: projects the shortcut when the shape changes at the first block
        if first_block:
            self.conv_w1 = BN_AC_CONV3D(
                num_in=num_in,
                num_filter=num_out,
                kernel=(1, 1, 1),
                pad=(0, 0, 0),
                stride=stride,
            )

    def forward(self, x):
        h = self.conv_i1(x)
        x_in = x + self.conv_i2(h)
        h = self.conv_m1(x_in)
        h = self.conv_m2(h)
        # project the residual path only when an adapter was created
        if hasattr(self, "conv_w1"):
            x = self.conv_w1(x)
        return h + x
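
# Usage sketch (hedged, illustrative values): with first_block=True and a
# spatial stride, the conv_w1 adapter reshapes the shortcut so the residual
# addition still matches the main path, e.g.:
#
#   unit = MF_UNIT(num_in=96, num_mid=192, num_out=192, g=16,
#                  stride=(1, 2, 2), first_block=True)
#   unit(torch.randn(1, 96, 8, 56, 56)).shape  # -> torch.Size([1, 192, 8, 28, 28])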
class MFNET_3D(nn.Module):
    """Original code: https://github.com/cypw/PyTorch-MFNet."""

    def __init__(
        self,
        **_kwargs,
    ):
        super().__init__()
        groups = 16
        k_sec = {2: 3, 3: 4, 4: 6, 5: 3}

        # conv1 - x224 (x16)
        conv1_num_out = 16
        self.conv1 = nn.Sequential(
            OrderedDict(
                [
                    (
                        "conv",
                        nn.Conv3d(
                            3,
                            conv1_num_out,
                            kernel_size=(3, 5, 5),
                            padding=(1, 2, 2),
                            stride=(1, 2, 2),
                            bias=False,
                        ),
                    ),
                    ("bn", nn.BatchNorm3d(conv1_num_out)),
                    ("relu", nn.ReLU(inplace=True)),
                ]
            )
        )
        self.maxpool = nn.MaxPool3d(
            kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)
        )

        # conv2 - x56 (x8)
        num_mid = 96
        conv2_num_out = 96
        self.conv2 = nn.Sequential(
            OrderedDict(
                [
                    (
                        "B%02d" % i,
                        MF_UNIT(
                            num_in=conv1_num_out if i == 1 else conv2_num_out,
                            num_mid=num_mid,
                            num_out=conv2_num_out,
                            stride=(2, 1, 1) if i == 1 else (1, 1, 1),
                            g=groups,
                            first_block=(i == 1),
                        ),
                    )
                    for i in range(1, k_sec[2] + 1)
                ]
            )
        )

        # conv3 - x28 (x8)
        num_mid *= 2
        conv3_num_out = 2 * conv2_num_out
        self.conv3 = nn.Sequential(
            OrderedDict(
                [
                    (
                        "B%02d" % i,
                        MF_UNIT(
                            num_in=conv2_num_out if i == 1 else conv3_num_out,
                            num_mid=num_mid,
                            num_out=conv3_num_out,
                            stride=(1, 2, 2) if i == 1 else (1, 1, 1),
                            g=groups,
                            first_block=(i == 1),
                        ),
                    )
                    for i in range(1, k_sec[3] + 1)
                ]
            )
        )

        # conv4 - x14 (x8)
        num_mid *= 2
        conv4_num_out = 2 * conv3_num_out
        self.conv4 = nn.Sequential(
            OrderedDict(
                [
                    (
                        "B%02d" % i,
                        MF_UNIT(
                            num_in=conv3_num_out if i == 1 else conv4_num_out,
                            num_mid=num_mid,
                            num_out=conv4_num_out,
                            stride=(1, 2, 2) if i == 1 else (1, 1, 1),
                            g=groups,
                            first_block=(i == 1),
                        ),
                    )
                    for i in range(1, k_sec[4] + 1)
                ]
            )
        )

        # conv5 - x7 (x8)
        num_mid *= 2
        conv5_num_out = 2 * conv4_num_out
        self.conv5 = nn.Sequential(
            OrderedDict(
                [
                    (
                        "B%02d" % i,
                        MF_UNIT(
                            num_in=conv4_num_out if i == 1 else conv5_num_out,
                            num_mid=num_mid,
                            num_out=conv5_num_out,
                            stride=(1, 2, 2) if i == 1 else (1, 1, 1),
                            g=groups,
                            first_block=(i == 1),
                        ),
                    )
                    for i in range(1, k_sec[5] + 1)
                ]
            )
        )

        # final
        self.tail = nn.Sequential(
            OrderedDict(
                [("bn", nn.BatchNorm3d(conv5_num_out)), ("relu", nn.ReLU(inplace=True))]
            )
        )
        self.globalpool = nn.Sequential(
            OrderedDict(
                [
                    ("avg", nn.AvgPool3d(kernel_size=(1, 7, 7), stride=(1, 1, 1))),
                    ("dropout", nn.Dropout(p=0.5)),  # only for fine-tuning
                ]
            )
        )
        # self.classifier = nn.Linear(conv5_num_out, num_classes)

    def forward(self, x):
        # assert x.shape[2] == 16
        h = self.conv1(x)  # x224 -> x112
        h = self.maxpool(h)  # x112 -> x56
        h = self.conv2(h)  # x56 -> x56
        h = self.conv3(h)  # x56 -> x28
        h = self.conv4(h)  # x28 -> x14
        h = self.conv5(h)  # x14 -> x7
        h = self.tail(h)
        h = self.globalpool(h)
        h = h.view(h.shape[0], -1)
        # h = self.classifier(h)
        # h = h.view(h.shape[0], -1)
        return h
    def load_state(self, state_dict):
        # Customized partial-load function: copies only parameters whose
        # names and shapes match the current model.
        checkpoint = torch.load(state_dict, map_location=torch.device("cpu"))
        pretrained = checkpoint["state_dict"]
        own_state = self.state_dict()
        missing = set(own_state.keys())
        for name, param in pretrained.items():
            name = name.replace("module.", "")  # strip DataParallel prefix
            if name in own_state and param.shape == own_state[name].shape:
                own_state[name].copy_(param)
                missing.discard(name)
        # indicate keys that were not restored from the checkpoint
        if missing:
            logging.warning(f">> Failed to load: {sorted(missing)}")
        return self
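

# Minimal smoke test (a sketch, not part of the original module): builds the
# backbone and checks the feature shape for a 16-frame 224x224 RGB clip.
if __name__ == "__main__":
    net = MFNET_3D()
    clip = torch.randn(1, 3, 16, 224, 224)
    feat = net(clip)
    # conv2 halves the 16 input frames to 8 and conv5 reaches 7x7 spatially,
    # so the pooled, flattened feature is 768 * 8 = 6144 values per clip.
    print(feat.shape)  # torch.Size([1, 6144])
    # To restore pretrained weights, pass a checkpoint path (hypothetical):
    # net.load_state("path/to/mfnet3d_checkpoint.pth")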