Spaces:
Sleeping
Sleeping
| import os | |
| import numpy as np | |
| import cv2 | |
| import torch | |
| from nets import U_Net_P, R2AttDecoder, build_backbone | |
def weighted_sigmoid(arr, w=1):
    """Return the logistic sigmoid of ``arr * w``, element-wise.

    ``w`` scales the logits before squashing, so a larger ``w`` gives a
    sharper transition around 0.

    Computed as ``0.5 * (1 + tanh(x / 2))``, which is mathematically identical
    to ``1 / (1 + exp(-x))`` but never overflows. The exp form produces an
    ``inf`` intermediate and an overflow RuntimeWarning once ``-x`` exceeds
    roughly 88 for float32 (roughly 709 for float64) inputs.

    Args:
        arr: Scalar or array-like of logits.
        w: Multiplier applied to ``arr`` before the sigmoid.

    Returns:
        Values in [0, 1] with the same shape as ``arr`` (NumPy array or
        NumPy scalar). Float32 input stays float32.
    """
    return 0.5 * (1.0 + np.tanh(0.5 * w * np.asarray(arr)))
# --- Model construction -----------------------------------------------------
# Checkpoint holding the trained weights for this encoder/decoder pairing.
state_dic_path = 'checkpoints/effi_b3_p_r2attunet_5.pkl'
# Feature widths handed to R2AttDecoder; the first entry is also used as the
# U_Net_P output_ch. (Presumably one width per encoder stage -- confirm in nets.)
channels = (24, 12, 40, 120, 384)

encoder = build_backbone('efficientnet_b3_p')
decoder = R2AttDecoder(channels=channels)
model = U_Net_P(
    encoder=encoder,
    decoder=decoder,
    output_ch=channels[0],
    num_classes=1,
)

# map_location pins every tensor to the CPU whatever device it was saved from;
# strict=True turns any missing or unexpected key into an error.
state_dict = torch.load(state_dic_path, map_location=torch.device('cpu'))
model.load_state_dict(state_dict=state_dict, strict=True)

# Inference only: put the module in eval mode and freeze all parameters.
model.eval()
model.requires_grad_(False)
# --- Batch inference --------------------------------------------------------
# Segment every image in img_root and write one 8-bit vessel map per input
# to "<img_root>_vessel/<stem>.png".
img_root = 'test_images'
out_root = img_root.rstrip(os.sep) + '_vessel'
os.makedirs(out_root, exist_ok=True)

for fname in sorted(os.listdir(img_root)):
    # Skip hidden entries such as .DS_Store.
    if fname.startswith('.'):
        continue
    src_path = os.path.join(img_root, fname)

    # cv2.imread signals failure by returning None (sub-directories, non-image
    # files, unreadable paths) rather than raising. Skip such entries instead
    # of crashing the whole batch with a TypeError on the next line.
    bgr = cv2.imread(src_path)
    if bgr is None:
        print(f'Skipping {src_path}: not a readable image')
        continue

    # OpenCV decodes to BGR; the network is fed RGB at a fixed 512x512.
    im = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    im = cv2.resize(im, (512, 512))
    # HWC uint8 -> CHW float32 in [0, 1], then add a leading batch axis of 1.
    im = np.transpose(im, (2, 0, 1)).astype(np.float32) / 255.0
    batch = torch.from_numpy(np.ascontiguousarray(im)).unsqueeze(0)

    # Inference only: don't let autograd record the forward pass.
    with torch.no_grad():
        # [0][0] takes the first sample's first output channel; presumably the
        # output is (1, 1, H, W) given num_classes=1 -- confirm against U_Net_P.
        logits = model(batch)[0][0].numpy()
    vessel = (weighted_sigmoid(logits) * 255).astype(np.uint8)

    # splitext keeps the stem even for extension-less names; the previous
    # split('.')-based name turned e.g. "scan01" into a bare "png", so such
    # files overwrote each other.
    out_path = os.path.join(out_root, os.path.splitext(fname)[0] + '.png')
    # cv2.imwrite reports failure via its return value, not an exception.
    if not cv2.imwrite(out_path, vessel):
        raise OSError(f'Failed to write vessel map to {out_path}')