# File size: 1,308 Bytes
# ffba4ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import os
import numpy as np
import cv2
import torch
from nets import U_Net_P, R2AttDecoder, build_backbone

def weighted_sigmoid(arr, w=1):
    """Return the logistic sigmoid of ``arr * w``.

    Args:
        arr: Scalar or NumPy array of values to squash.
        w: Multiplier applied to ``arr`` before the sigmoid; larger values
            give a steeper curve. Defaults to 1 (plain sigmoid).

    Returns:
        ``1 / (1 + exp(-arr * w))``, computed elementwise for array input.
    """
    exponent = -arr * w
    denominator = 1 + np.exp(exponent)
    return 1. / denominator

# --- Model construction and checkpoint restore ------------------------------
# Path of the checkpoint file whose state dict is loaded into the model below.
state_dic_path = 'checkpoints/effi_b3_p_r2attunet_5.pkl'
# Channel tuple handed to the decoder; its first entry doubles as the
# `output_ch` argument of U_Net_P.
channels = (24, 12, 40, 120, 384)

encoder = build_backbone('efficientnet_b3_p')
decoder = R2AttDecoder(channels=channels)
model = U_Net_P(
    encoder=encoder,
    decoder=decoder,
    output_ch=channels[0],
    num_classes=1,
)

# Map every stored tensor onto the CPU so no GPU is needed to load the file.
state_dict = torch.load(state_dic_path, map_location='cpu')
# strict=True: missing or unexpected keys in the checkpoint raise an error.
model.load_state_dict(state_dict, strict=True)

# Inference only: switch to eval mode and freeze every parameter.
model.eval()
for weight in model.parameters():
    weight.requires_grad_(False)


# --- Batch inference ---------------------------------------------------------
# Every readable image in `img_root` is run through `model`; the result is
# written as an 8-bit PNG with the same base name into a sibling directory
# named "<img_root>_vessel".
img_root = 'test_images'
out_root = img_root.rstrip(os.sep) + '_vessel'
os.makedirs(out_root, exist_ok=True)

for fname in os.listdir(img_root):
    if fname.startswith('.'):
        # Skip hidden entries (names beginning with a dot).
        continue

    src_path = os.path.join(img_root, fname)
    im = cv2.imread(src_path)
    if im is None:
        # cv2.imread returns None instead of raising when the entry cannot be
        # decoded (sub-directory, non-image file, corrupt image). Skip it so a
        # single bad entry does not abort the whole batch.
        print('Skipping unreadable image: ' + src_path)
        continue

    im = im[:, :, [2, 1, 0]]  # reverse the channel order (OpenCV loads BGR)
    im = cv2.resize(im, (512, 512))
    im = np.multiply(im, 1 / 255.0)  # 8-bit pixel values -> floats in [0, 1]
    im = np.transpose(im, (2, 0, 1))  # HWC -> CHW
    im = torch.FloatTensor(np.array([im]))  # add a leading batch axis

    # Pure inference: no autograd bookkeeping is needed for the forward pass.
    with torch.no_grad():
        # NOTE(review): [0][0] presumably selects the first batch item and its
        # single class channel -- confirm against U_Net_P's return value.
        vessel = model(im)[0][0]
    vessel = vessel.detach().numpy()
    vessel = weighted_sigmoid(vessel) * 255  # squash to (0, 1), scale to 8-bit
    vessel = vessel.astype(np.uint8)

    # os.path.splitext keeps the full base name even when the input has no
    # extension (splitting on '.' would reduce such names to just "png").
    out_name = os.path.splitext(fname)[0] + '.png'
    cv2.imwrite(os.path.join(out_root, out_name), vessel)