|
|
|
|
|
from __future__ import absolute_import
|
|
|
from __future__ import division
|
|
|
from __future__ import print_function
|
|
|
|
|
|
import copy
|
|
|
import logging
|
|
|
import math
|
|
|
|
|
|
from os.path import join as pjoin
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import numpy as np
|
|
|
import os
|
|
|
|
|
|
from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
|
|
|
from torch.nn.modules.utils import _pair
|
|
|
from scipy import ndimage
|
|
|
from ._swin_transformer_unet_skip_expand_decoder_sys import SwinTransformerSys
|
|
|
|
|
|
# Module-level logger (standard per-module logger; not yet used within this file).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class model(nn.Module):
    """Swin-Unet style segmentation network.

    Wraps a ``SwinTransformerSys`` encoder-decoder (Swin-Tiny configuration:
    embed_dim=96, depths [2, 2, 6, 2], heads [3, 6, 12, 24], window 7) and
    loads pretrained Swin weights from a path relative to this file at
    construction time.

    Note: the class name ``model`` is lowercase for backward compatibility
    with existing callers; do not rename.
    """

    def __init__(self, img_size=224, in_channels=3, freeze_encoder=False, num_classes=21843, zero_head=False, vis=False):
        """Build the Swin-Unet backbone and load pretrained weights.

        Args:
            img_size: Input spatial size (square), in pixels.
            in_channels: Number of input channels fed to the patch embedding.
            freeze_encoder: Retained for API compatibility. NOTE(review): this
                flag was previously accepted but silently ignored; it is now
                stored on the instance, but actual parameter freezing is still
                not implemented here — confirm intended semantics with callers.
            num_classes: Number of output classes of the segmentation head.
            zero_head: Stored but not used by this wrapper (consumed, if at
                all, by downstream code).
            vis: Accepted for signature compatibility; unused here.
        """
        super(model, self).__init__()
        self.img_size = img_size
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.zero_head = zero_head
        # Previously dropped on the floor; keep it so callers can inspect it.
        self.freeze_encoder = freeze_encoder

        self.swin_unet = SwinTransformerSys(img_size=self.img_size,
                                            patch_size=4,
                                            in_chans=self.in_channels,
                                            num_classes=self.num_classes,
                                            embed_dim=96,
                                            depths=[2, 2, 6, 2],
                                            num_heads=[3, 6, 12, 24],
                                            window_size=7,
                                            mlp_ratio=4.,
                                            qkv_bias=True,
                                            qk_scale=None,
                                            drop_rate=0.,
                                            drop_path_rate=0.1,
                                            ape=False,
                                            patch_norm=True,
                                            use_checkpoint=False)

        # Pretrained weights ship alongside this module; resolve the path
        # relative to this file so it works regardless of the CWD.
        base_path = os.path.dirname(os.path.abspath(__file__))
        self.load_from(os.path.join(base_path, 'pretrained_weights/swin_tiny_patch4_window7_224.pth'))

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Input tensor of shape (batch, channels, H, W). A single-channel
               input is replicated to 3 channels to match the pretrained
               RGB patch embedding.

        Returns:
            Logits tensor produced by the Swin-Unet backbone.
        """
        if x.size()[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        return self.swin_unet(x)

    def load_from(self, pretrained_path):
        """Load pretrained Swin weights into ``self.swin_unet``.

        Handles two checkpoint layouts:
          * a raw state dict (no ``"model"`` key): key prefixes are stripped,
            output-head entries dropped, and the result loaded non-strictly;
          * an official Swin checkpoint (``{"model": state_dict}``): encoder
            stage weights are additionally mirrored into the symmetric
            decoder stages (``layers.i`` -> ``layers_up.(3-i)``) before a
            non-strict load.

        Args:
            pretrained_path: Filesystem path to the checkpoint, or None to
                skip loading entirely.
        """
        if pretrained_path is None:
            print("none pretrain")
            return

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # SECURITY NOTE(review): torch.load unpickles arbitrary Python
        # objects — only load checkpoints from trusted sources.
        pretrained_dict = torch.load(pretrained_path, map_location=device)

        if "model" not in pretrained_dict:
            # Raw state dict: strip the leading 17 characters of every key
            # (presumably a wrapper prefix such as "module.swin_unet." —
            # TODO confirm against the actual checkpoint key format).
            pretrained_dict = {k[17:]: v for k, v in pretrained_dict.items()}
            # Drop the output head — its class count may not match ours.
            for k in list(pretrained_dict.keys()):
                if "output" in k:
                    del pretrained_dict[k]
            self.swin_unet.load_state_dict(pretrained_dict, strict=False)
            return

        pretrained_dict = pretrained_dict['model']

        model_dict = self.swin_unet.state_dict()
        full_dict = copy.deepcopy(pretrained_dict)
        # Mirror encoder weights into the decoder: checkpoint key
        # "layers.<i>..." also initializes "layers_up.<3-i>..." (the
        # skip-symmetric up-sampling stage). k[7:8] is the stage digit.
        for k, v in pretrained_dict.items():
            if "layers." in k:
                current_layer_num = 3 - int(k[7:8])
                current_k = "layers_up." + str(current_layer_num) + k[8:]
                full_dict.update({current_k: v})
        # Discard entries whose shapes disagree with the target model
        # (e.g. classification head with a different num_classes).
        for k in list(full_dict.keys()):
            if k in model_dict and full_dict[k].shape != model_dict[k].shape:
                del full_dict[k]

        self.swin_unet.load_state_dict(full_dict, strict=False)
|
|
|
|