# Timerns's picture
# Upload folder using huggingface_hub
# 984cdba verified
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import logging
import math
from os.path import join as pjoin
import torch
import torch.nn as nn
import numpy as np
import os
from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage
from ._swin_transformer_unet_skip_expand_decoder_sys import SwinTransformerSys
logger = logging.getLogger(__name__)
class model(nn.Module):
    """Swin-UNet-style segmentation model wrapping ``SwinTransformerSys``.

    On construction this builds a swin-tiny encoder/decoder and immediately
    loads pretrained weights from
    ``pretrained_weights/swin_tiny_patch4_window7_224.pth`` located next to
    this file (see ``load_from``).

    Args:
        img_size: Input spatial size fed to the backbone (square).
        in_channels: Number of input channels passed to the patch embedding.
        freeze_encoder: If True, freeze the encoder-side parameters after
            loading pretrained weights (previously this flag was accepted but
            silently ignored — that was a bug).
        num_classes: Number of output channels/classes of the decoder head.
        zero_head: Stored but not otherwise used in this class — presumably
            consumed by callers or a future head-reinit path; TODO confirm.
        vis: Accepted for interface compatibility; not used here.
    """

    def __init__(self, img_size=224, in_channels=3, freeze_encoder=False, num_classes=21843, zero_head=False, vis=False):
        super(model, self).__init__()
        self.img_size = img_size
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.swin_unet = SwinTransformerSys(img_size=self.img_size,
                                            patch_size=4,
                                            in_chans=self.in_channels,
                                            num_classes=self.num_classes,
                                            embed_dim=96,
                                            depths=[2, 2, 6, 2],
                                            num_heads=[3, 6, 12, 24],
                                            window_size=7,
                                            mlp_ratio=4.,
                                            qkv_bias=True,
                                            qk_scale=None,
                                            drop_rate=0.,
                                            drop_path_rate=0.1,
                                            ape=False,
                                            patch_norm=True,
                                            use_checkpoint=False)
        base_path = os.path.dirname(os.path.abspath(__file__))
        self.load_from(os.path.join(base_path, 'pretrained_weights/swin_tiny_patch4_window7_224.pth'))
        if freeze_encoder:
            # Fix: this flag used to do nothing. The encoder side of
            # SwinTransformerSys is the "patch_embed"/"layers." subtree
            # (load_from below maps "layers." -> "layers_up." for the
            # decoder). NOTE(review): confirm the module naming against
            # SwinTransformerSys before relying on this in training.
            for name, param in self.swin_unet.named_parameters():
                if name.startswith(("patch_embed", "layers.")):
                    param.requires_grad = False

    def forward(self, x):
        """Run the backbone; grayscale input is replicated to 3 channels.

        Args:
            x: Image batch of shape (B, C, H, W); C == 1 is tiled to C == 3.

        Returns:
            Logits tensor produced by ``SwinTransformerSys``.
        """
        if x.size()[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        logits = self.swin_unet(x)
        return logits

    def load_from(self, pretrained_path):
        """Load (and remap) pretrained weights into ``self.swin_unet``.

        Handles two checkpoint layouts:
          * No top-level "model" key: keys are assumed to carry a 17-char
            prefix which is stripped, "output"-named entries are dropped,
            and the result is loaded non-strictly.
          * A "model" key: encoder keys "layers.<i>..." are additionally
            mirrored to decoder keys "layers_up.<3-i>..." and any entry
            whose shape mismatches the current model is dropped.

        Args:
            pretrained_path: Filesystem path to a .pth checkpoint, or None
                to skip loading entirely.
        """
        if pretrained_path is not None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            # SECURITY NOTE(review): torch.load unpickles arbitrary objects;
            # only load checkpoints from trusted sources. Consider
            # weights_only=True on torch versions that support it.
            pretrained_dict = torch.load(pretrained_path, map_location=device)
            if "model" not in pretrained_dict:
                # Flat checkpoint: strip the (assumed) 17-character key
                # prefix — TODO confirm prefix length against the actual
                # checkpoint format.
                pretrained_dict = {k[17:]: v for k, v in pretrained_dict.items()}
                for k in list(pretrained_dict.keys()):
                    if "output" in k:
                        del pretrained_dict[k]
                msg = self.swin_unet.load_state_dict(pretrained_dict, strict=False)
                # print(msg)
                return
            pretrained_dict = pretrained_dict['model']
            model_dict = self.swin_unet.state_dict()
            full_dict = copy.deepcopy(pretrained_dict)
            for k, v in pretrained_dict.items():
                if "layers." in k:
                    # Mirror encoder stage i onto decoder stage (3 - i).
                    current_layer_num = 3 - int(k[7:8])
                    current_k = "layers_up." + str(current_layer_num) + k[8:]
                    full_dict.update({current_k: v})
            for k in list(full_dict.keys()):
                if k in model_dict:
                    if full_dict[k].shape != model_dict[k].shape:
                        # Drop incompatible tensors so strict=False loading
                        # keeps the randomly initialized counterpart.
                        del full_dict[k]
            msg = self.swin_unet.load_state_dict(full_dict, strict=False)
            # print(msg)
        else:
            print("none pretrain")