File size: 4,045 Bytes
1d197a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import numpy as np
import tensorflow as tf
import gc

from ..config import resolve_lumen_model_dir

# Prevent TensorFlow from pre-allocating most GPU memory.
for gpu in tf.config.list_physical_devices("GPU"):
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except Exception:
        # Best effort: a GPU that cannot be switched to on-demand growth must
        # not prevent this module from being imported.
        pass

# Offset subtracted from every input image by cast_and_center().
# NOTE(review): presumably the mean intensity of the training data - confirm.
IMG_MEAN = tf.constant([60.3486], dtype=tf.float32)

# Location of the SavedModel, as resolved by the project configuration.
model_path = str(resolve_lumen_model_dir())

try:
    model = tf.saved_model.load(model_path)
except Exception as load_error:
    # Keep the module importable when the weights are missing or unreadable,
    # but always bind ``model`` so later code sees None (the same state
    # release_resources() leaves behind) instead of an undefined name.
    # ``except Exception`` rather than a bare ``except`` lets
    # KeyboardInterrupt / SystemExit propagate.
    model = None
    warning = (
        "Warning: No saved weights have been found, segmentation will be unsuccessful, "
        f"check that weights are saved in {model_path}"
        f" ({type(load_error).__name__}: {load_error})"
    )
    print(warning)
   
def cast_and_center(image):
    """Return ``image`` cast to float32 with the module-level ``IMG_MEAN`` subtracted."""
    return tf.cast(image, dtype=tf.float32) - IMG_MEAN

def set_input_channels(images, channels=3):
    """Give ``images`` an explicit channel axis holding ``channels`` entries.

    Args:
        images: Tensor of images. If it has fewer than 4 dimensions a new
            axis is inserted at position 3 to act as the channel axis.
        channels: Desired size of the last axis.

    Returns:
        The input with a last axis added when needed and, if the size of
        the last axis differs from ``channels``, repeated ``channels`` times
        along that axis.
    """
    if len(images.get_shape()) < 4:
        images = tf.expand_dims(images, axis=3)
    # Read the shape *after* the optional expand_dims: checking the shape taken
    # beforehand would compare the image width (not the new size-1 channel
    # axis) against ``channels`` and skip the tiling whenever they coincide.
    if images.get_shape()[-1] != channels:
        images = tf.tile(images, [1, 1, 1, channels])
    return images
        
def _mean_class_confidence(probs, pred_batch, class_index):
    """Return, per frame, the mean probability of ``class_index`` over the pixels predicted as it.

    Frames in which no pixel was assigned to ``class_index`` yield NaN.
    """
    class_mask = tf.equal(pred_batch, class_index)
    class_probs = probs[..., class_index]
    prob_sum = tf.reduce_sum(
        tf.where(class_mask, class_probs, tf.zeros_like(class_probs)),
        axis=[1, 2],
    )
    pixel_count = tf.reduce_sum(tf.cast(class_mask, tf.float32), axis=[1, 2])
    return tf.where(
        pixel_count > 0,
        prob_sum / pixel_count,
        tf.constant(np.nan, dtype=tf.float32),
    )


def predict(images, return_confidence=False, return_class_confidence=False, batch_size=64):
    """Runs Convolutional Neural Network to predict image pixel class.

    The frames are mean-centred, batched, given ``set_input_channels``'
    default channel count, and passed through the module-level ``model``.
    The model output is resized to the spatial size of the input batch
    before the per-pixel class is taken as the argmax of the softmax.

    Args:
        images: Array-like stack of frames with the frame index on axis 0;
            must expose ``.shape``.
        return_confidence: If True, also return a per-frame confidence,
            defined as the mean of the per-pixel max softmax probability.
        return_class_confidence: If True, also return a dict with keys
            ``"lumen"`` (class index 1) and ``"plaque"`` (class index 2).
            Each value holds, per frame, the mean softmax probability of
            that class over the pixels predicted as that class, or NaN
            when the frame has no such pixel.
        batch_size: Number of frames sent through the model at once.

    Returns:
        ``pred`` alone, or a tuple ``(pred[, confidences][, class_conf])``
        containing the optional items that were requested, in that order.
        ``pred`` is an int32 array with one class index per pixel. When
        ``images`` contains no frames, empty arrays are returned.
    """
    dataset = tf.data.Dataset.from_tensor_slices(images)
    dataset = dataset.map(cast_and_center).batch(batch_size)
    num_batches = int(np.ceil(images.shape[0] / batch_size))

    pred = []
    confidences = []
    class_conf_lumen = []
    class_conf_plaque = []
    for batch_number, batch in enumerate(dataset, start=1):
        batch = set_input_channels(batch)
        logits = model(batch, training=False)
        # Bring the model output back to the spatial size of the input frames.
        logits = tf.image.resize(logits, (tf.shape(batch)[1], tf.shape(batch)[2]))
        probs = tf.nn.softmax(logits, axis=-1)
        pred_batch = tf.argmax(probs, axis=-1, output_type=tf.dtypes.int32)
        # Convert to NumPy per batch so the accumulated results do not keep
        # TensorFlow tensors alive for the whole run.
        pred.append(pred_batch.numpy())
        if return_confidence:
            per_pixel_conf = tf.reduce_max(probs, axis=-1)
            confidences.append(tf.reduce_mean(per_pixel_conf, axis=[1, 2]).numpy())
        if return_class_confidence:
            class_conf_lumen.append(_mean_class_confidence(probs, pred_batch, 1).numpy())
            class_conf_plaque.append(_mean_class_confidence(probs, pred_batch, 2).numpy())
        print(f'Batch {batch_number} of {num_batches} completed')

    if pred:
        pred = np.concatenate(pred)
        if return_confidence:
            confidences = np.concatenate(confidences)
        if return_class_confidence:
            class_conf = {
                "lumen": np.concatenate(class_conf_lumen),
                "plaque": np.concatenate(class_conf_plaque),
            }
    else:
        # No frames were supplied: np.concatenate rejects an empty list, so
        # build empty results of the matching rank explicitly.
        pred = np.empty((0, *tuple(images.shape[1:3])), dtype=np.int32)
        # NOTE(review): float32 assumes the model emits float32 outputs - confirm.
        confidences = np.empty((0,), dtype=np.float32)
        class_conf = {
            "lumen": np.empty((0,), dtype=np.float32),
            "plaque": np.empty((0,), dtype=np.float32),
        }

    if return_confidence and return_class_confidence:
        return pred, confidences, class_conf
    if return_confidence:
        return pred, confidences
    if return_class_confidence:
        return pred, class_conf
    return pred


def release_resources():
    """Drop the loaded model and ask TensorFlow/Python to free what they can.

    Intended to be called once segmentation is finished so that other
    frameworks can use the GPU memory. After this call the module-level
    ``model`` is ``None``.
    """
    global model
    # Remove the module's reference to the loaded model first.
    model = None
    try:
        keras_backend = tf.keras.backend
        keras_backend.clear_session()
    except Exception:
        # Best-effort cleanup: a failure to clear the Keras session is ignored.
        pass
    # Run a garbage-collection pass after the reference has been dropped.
    gc.collect()