File size: 3,143 Bytes
c5f4ee2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6735c2f
c5f4ee2
 
 
 
 
 
 
6735c2f
c5f4ee2
 
 
 
 
 
6735c2f
c5f4ee2
 
 
 
 
6735c2f
c5f4ee2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# Example code for running inference on a pre-trained model
import os
import json
import numpy as np
import cv2
import torch
from models import build_model


# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# To restrict which GPUs are visible, set the environment variable before this
# line, e.g. os.environ['CUDA_VISIBLE_DEVICES'] = "0,1".

device = torch.device('cuda') if torch.cuda.is_available() else torch.device("cpu")

def sigmoid(arr):
    """Return the element-wise logistic sigmoid ``1 / (1 + exp(-arr))``.

    For very negative inputs ``exp(-arr)`` overflows to ``inf`` and the result
    correctly saturates to 0. NumPy would still emit an overflow
    RuntimeWarning for every such call, so overflow reporting is disabled for
    this computation only. The returned values are unchanged.
    """
    with np.errstate(over='ignore'):
        return 1. / (1 + np.exp(-arr))

class Inference(object):
    """Run a pre-trained multi-modality model on single images.

    ``model_path`` is a checkpoint directory containing ``config.json``
    (with ``model_params`` and ``modality_mapping`` keys) and ``model.pkl``
    (the weights, loaded through the model's own ``load_model``).
    """

    def __init__(self, model_path):
        """Read ``config.json`` from *model_path* and build/load the model.

        Args:
            model_path: Checkpoint directory holding ``config.json`` and
                ``model.pkl``.
        """
        self.model_path = model_path
        config_path = os.path.join(model_path, 'config.json')
        # JSON is UTF-8; don't depend on the platform's default locale encoding.
        with open(config_path, encoding='utf-8') as fin:
            params = json.load(fin)
        self.model_params = params['model_params']
        # Maps a modality name to the dataset index passed to the model.
        self.modality_mapping = params['modality_mapping']
        self.model = self.load_model()

    def inference(self, image, modality):
        """Predict a per-pixel probability map for one image.

        Args:
            image: Path to an image file, or an already-loaded H x W x C array
                (assumed RGB, the order produced when loading from a path).
            modality: A key of ``modality_mapping``; selects the dataset index
                given to the model.

        Returns:
            A uint8 array, resized back to the input's original height and
            width, holding ``sigmoid(logit) * 255`` truncated to an integer.

        Raises:
            ValueError: If *modality* is not in ``modality_mapping``.
            FileNotFoundError: If *image* is a path that OpenCV cannot read.
        """
        # An explicit check rather than ``assert``, so validation survives ``python -O``.
        if modality not in self.modality_mapping:
            raise ValueError("Modality '{}' not supported".format(modality))

        image, raw_h, raw_w = self.load_image(image)
        modality_idx = torch.tensor([self.modality_mapping[modality]])
        with torch.no_grad():
            output = self.model.predict(x=image, device=device, dataset_idx=modality_idx)
        # Take the first batch item, first channel.
        # NOTE(review): assumes output is (N, C, h, w) - confirm against model.predict.
        output = output.data.cpu().numpy()[0][0]
        output = (sigmoid(output) * 255).astype(np.uint8)
        # cv2.resize takes (width, height), the reverse of numpy's shape order.
        return cv2.resize(output, (raw_w, raw_h))

    def load_image(self, image):
        """Load an image if a path is given, then preprocess it into a model input tensor.

        Args:
            image: A file path (``str`` or ``os.PathLike``), which is read with
                OpenCV and converted from BGR to RGB. Otherwise an
                already-loaded H x W x C array, which is not colour-converted.

        Returns:
            ``(tensor, raw_h, raw_w)``. ``tensor`` is a float32 tensor of
            shape (1, C, size_h, size_w) holding pixel values divided by 255
            (i.e. in [0, 1] for uint8 input). ``raw_h`` and ``raw_w`` are the
            original image height and width.

        Raises:
            FileNotFoundError: If *image* is a path that OpenCV cannot read.
        """
        if isinstance(image, (str, os.PathLike)):
            path = os.fspath(image)
            image = cv2.imread(path)
            # cv2.imread signals failure by returning None instead of raising.
            if image is None:
                raise FileNotFoundError("Could not read image '{}'".format(path))
            image = image[:, :, [2, 1, 0]]  # BGR -> RGB
        raw_h, raw_w = image.shape[:2]
        # cv2.resize takes the target size as (width, height).
        image = cv2.resize(image, (self.model_params['size_w'], self.model_params['size_h']))
        image = image.astype(np.float32) / 255.0
        image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
        image = np.expand_dims(image, axis=0)   # add the batch dimension
        return torch.tensor(image), raw_h, raw_w

    def load_model(self):
        """Build the network described by ``model_params`` and load its weights.

        Returns:
            The object returned by ``build_model``, moved to ``device``, with
            weights loaded from ``<model_path>/model.pkl`` and set to eval mode.
        """
        print('Loading model from {}'.format(self.model_path))
        model = build_model(model_name=self.model_params['net'],
                            model_params=self.model_params,
                            training=False,
                            dataset_idx=list(self.modality_mapping.values()),
                            pretrained=False)
        model.set_device(device)
        # NOTE(review): a .pkl checkpoint is likely unpickled; only load
        # checkpoints from trusted sources.
        model.load_model(os.path.join(self.model_path, 'model.pkl'))
        model.set_mode('eval')
        return model


if __name__ == '__main__':
    # Demo: run every bundled sample image through the checkpoint and write
    # one PNG probability map per modality into output_root.
    model_path = 'checkpoints/UNet_DCP_1024'
    output_root = 'output_images'

    # Each input image is listed next to its modality, so paths and modalities
    # cannot drift out of sync the way two parallel lists can.
    samples = [
        ('images/FFA.bmp', 'FFA'),
        ('images/CFP.jpg', 'CFP'),
        ('images/SLO.jpg', 'SLO'),
        ('images/UWF.jpg', 'UWF'),
        ('images/OCTA.png', 'OCTA'),
    ]

    os.makedirs(output_root, exist_ok=True)
    runner = Inference(model_path)

    for image_path, modality in samples:
        prob_map = runner.inference(image_path, modality)
        target = os.path.join(output_root, '{}.png'.format(modality))
        cv2.imwrite(target, prob_map)