"""Batch inference script: run the vessel model on every image in a folder.

Each file in ``img_root`` is read, resized to 512x512, pushed through the
model on the CPU, passed through a sigmoid, and written as an 8-bit PNG into
a sibling directory named ``<img_root>_vessel``.
"""
import os

import numpy as np
import cv2
import torch

from nets import U_Net_P, R2AttDecoder, build_backbone


def weighted_sigmoid(arr, w=1):
    """Return the element-wise logistic sigmoid of ``arr * w``.

    Args:
        arr: NumPy array (or scalar) of raw scores.
        w: Multiplier applied to ``arr`` before the sigmoid; ``1`` gives the
            standard sigmoid.

    Returns:
        Values in the range [0, 1] with the same shape as ``arr``.
    """
    return 1. / (1 + np.exp(-arr * w))


def _output_name(fname):
    """Return ``fname`` with its final extension replaced by ``.png``.

    ``os.path.splitext`` keeps the stem of names that have no extension
    (``'scan'`` -> ``'scan.png'``), where splitting on ``'.'`` by hand would
    drop the stem and yield just ``'png'``.
    """
    return os.path.splitext(fname)[0] + '.png'


# ---------------------------------------------------------------------------
# Model construction
# ---------------------------------------------------------------------------
state_dic_path = 'checkpoints/effi_b3_p_r2attunet_5.pkl'
channels = (24, 12, 40, 120, 384)

encoder = build_backbone('efficientnet_b3_p')
decoder = R2AttDecoder(channels=channels)
model = U_Net_P(encoder=encoder, decoder=decoder,
                output_ch=channels[0], num_classes=1)

# NOTE: torch.load unpickles the file, which can execute arbitrary code.
# Only point state_dic_path at checkpoints from a trusted source.
state_dict = torch.load(state_dic_path, map_location=torch.device('cpu'))
model.load_state_dict(state_dict=state_dict, strict=True)

# Inference only: switch to eval mode and freeze every parameter.
model.eval()
for param in model.parameters():
    param.requires_grad = False

# ---------------------------------------------------------------------------
# Inference over the input directory
# ---------------------------------------------------------------------------
img_root = 'test_images'
out_root = img_root.rstrip(os.sep) + '_vessel'
os.makedirs(out_root, exist_ok=True)

for fname in os.listdir(img_root):
    # Skip hidden entries such as .DS_Store.
    if fname.startswith('.'):
        continue

    src_path = os.path.join(img_root, fname)
    bgr = cv2.imread(src_path)
    if bgr is None:
        # cv2.imread signals failure (directory, non-image, corrupt file) by
        # returning None rather than raising; skip instead of crashing.
        print('Skipping {}: not a readable image'.format(src_path))
        continue

    im = bgr[:, :, [2, 1, 0]]                 # OpenCV loads BGR; reorder to RGB
    im = cv2.resize(im, (512, 512))
    im = np.multiply(im, 1 / 255.0)           # scale pixel values to [0, 1]
    im = np.transpose(im, (2, 0, 1))          # HWC -> CHW
    im = torch.FloatTensor(np.array([im]))    # add a leading batch axis

    # No gradients are needed at inference time.
    with torch.no_grad():
        # NOTE(review): the double [0] presumably selects the first batch
        # item and its single class channel, leaving a 2-D map - confirm
        # against U_Net_P.forward.
        vessel = model(im)[0][0]

    vessel = vessel.detach().numpy()
    vessel = weighted_sigmoid(vessel) * 255   # scores -> [0, 255]
    vessel = vessel.astype(np.uint8)

    dst_path = os.path.join(out_root, _output_name(fname))
    if not cv2.imwrite(dst_path, vessel):
        # cv2.imwrite reports failure through its boolean return value.
        print('Failed to write {}'.format(dst_path))