# ivus-segmentation/deepivus/inference/tensorflow_predictor.py
# Uploaded by Aditya2162 via huggingface_hub (commit 1d197a4, verified).
import numpy as np
import tensorflow as tf
import gc
from ..config import resolve_lumen_model_dir
# Prevent TensorFlow from pre-allocating most GPU memory: ask it to grow its
# allocation on demand on every visible GPU, so other frameworks used in the
# same process can still get GPU memory.
for gpu in tf.config.list_physical_devices("GPU"):
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except Exception as exc:
        # Best-effort: TensorFlow rejects this call e.g. once the GPU runtime
        # has already been initialized. Don't fail the import, but don't hide
        # it either -- the process may then reserve most of the GPU memory.
        print(f"Warning: could not enable GPU memory growth for {gpu.name}: {exc}")
# Constant offset subtracted from every input pixel by cast_and_center();
# presumably the mean intensity of the training data -- TODO confirm.
IMG_MEAN = tf.constant([60.3486], dtype=tf.float32)

# Directory holding the lumen-segmentation SavedModel.
model_path = str(resolve_lumen_model_dir())
try:
    model = tf.saved_model.load(model_path)
except Exception as exc:
    # Catch Exception rather than using a bare ``except:`` so KeyboardInterrupt
    # and SystemExit still propagate. Keep the module importable without
    # weights, but bind ``model`` to None so callers can detect the failure
    # instead of hitting a NameError later.
    model = None
    warning = (
        "Warning: No saved weights have been found, segmentation will be unsuccessful, "
        f"check that weights are saved in {model_path}"
    )
    print(warning)
    # Loading can fail for reasons other than missing files (e.g. an
    # incompatible SavedModel), so surface the underlying error too.
    print(f"Reason: {exc!r}")
def cast_and_center(image):
    """Return ``image`` converted to float32 with ``IMG_MEAN`` subtracted.

    Used as the per-element ``tf.data`` map function in ``predict``.
    """
    return tf.cast(image, tf.float32) - IMG_MEAN
def set_input_channels(images, channels=3):
    """Give a batch of frames the channel count the model expects.

    A batch without a channel axis (rank < 4, presumably (N, H, W)) gets a
    trailing axis of size 1. A single-channel batch is then tiled along the last
    axis to ``channels`` copies. A batch already carrying ``channels`` channels
    is returned unchanged.

    Args:
        images: TensorFlow tensor holding a batch of frames.
        channels: number of channels to produce along the last axis.

    Returns:
        The tensor with ``channels`` entries on its last axis.

    Raises:
        ValueError: if the last axis has a size other than 1 or ``channels``
            (tiling such input would produce a wrong channel count).
    """
    if len(images.shape) < 4:
        images = tf.expand_dims(images, axis=-1)
    # Read the channel count *after* expand_dims. Using the pre-expansion shape
    # would compare the image width for (N, H, W) input.
    num_channels = images.shape[-1]
    if num_channels == channels:
        return images
    # ``None`` means the channel size is not statically known; tile as before.
    if num_channels not in (1, None):
        raise ValueError(
            f"expected 1 or {channels} input channels, got {num_channels}"
        )
    multiples = [1] * (len(images.shape) - 1) + [channels]
    return tf.tile(images, multiples)
_LUMEN_CLASS = 1  # class index reported under the "lumen" key of class_conf
_PLAQUE_CLASS = 2  # class index reported under the "plaque" key of class_conf


def _class_mean_confidence(probs, pred_batch, class_id):
    """Per-frame mean probability of ``class_id`` over the pixels predicted as it.

    Frames in which no pixel is predicted as ``class_id`` get NaN.
    """
    is_class = tf.equal(pred_batch, class_id)
    class_probs = probs[..., class_id]
    prob_sum = tf.reduce_sum(
        tf.where(is_class, class_probs, tf.zeros_like(class_probs)), axis=[1, 2]
    )
    count = tf.reduce_sum(tf.cast(is_class, tf.float32), axis=[1, 2])
    return tf.where(count > 0, prob_sum / count, tf.constant(np.nan, dtype=tf.float32))


def _concat_or_empty(parts, trailing_shape, dtype):
    """Concatenate per-batch arrays, or return an empty ``(0, *trailing_shape)`` array."""
    if not parts:
        return np.zeros((0, *trailing_shape), dtype=dtype)
    return np.concatenate(parts)


def predict(images, return_confidence=False, return_class_confidence=False, batch_size=64):
    """Runs Convolutional Neural Network to predict image pixel class.

    Each frame is cast to float32 and centred (``cast_and_center``). Frames are
    batched and brought to 3 channels (``set_input_channels``), then fed to the
    loaded model. The model output is resized back to the input height/width
    and softmaxed, and the arg-max class is taken per pixel.

    Args:
        images: stack of frames indexed along axis 0 -- presumably (N, H, W) or
            (N, H, W, C). Must expose ``.shape`` and be accepted by
            ``tf.data.Dataset.from_tensor_slices``.
        return_confidence: also return a per-frame confidence, defined as the
            mean of the per-pixel max softmax probability.
        return_class_confidence: also return ``{"lumen": ..., "plaque": ...}``.
            Each entry is a per-frame array holding the mean softmax probability
            of class 1 (lumen) / class 2 (plaque) over the pixels predicted as
            that class, or NaN when no pixel is.
        batch_size: number of frames per forward pass.

    Returns:
        ``pred``, an int32 array of class indices with shape (N, H, W), alone.
        Or, depending on the flags, a tuple ``(pred, confidences)``,
        ``(pred, class_conf)`` or ``(pred, confidences, class_conf)``.
        With zero input frames the arrays are empty instead of raising.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
        RuntimeError: if no model is loaded (weights were missing at import
            time, or ``release_resources()`` has been called).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    # ``model`` is a module global. It is unbound/None if loading failed at
    # import time, and None after release_resources(). Report both clearly.
    net = globals().get("model")
    if net is None:
        raise RuntimeError(
            "Segmentation model is not loaded (weights missing or "
            "release_resources() was called)."
        )

    num_frames = images.shape[0]
    num_batches = -(-num_frames // batch_size)  # ceiling division
    dataset = (
        tf.data.Dataset.from_tensor_slices(images)
        .map(cast_and_center)
        .batch(batch_size)
    )

    preds = []
    confidences = []
    conf_lumen = []
    conf_plaque = []
    for i, batch in enumerate(dataset):
        batch = set_input_channels(batch)
        logits = net(batch, training=False)
        # Bring the network output back to the input frame size before softmax.
        logits = tf.image.resize(logits, (tf.shape(batch)[1], tf.shape(batch)[2]))
        probs = tf.nn.softmax(logits, axis=-1)
        pred_batch = tf.argmax(probs, axis=-1, output_type=tf.dtypes.int32)
        preds.append(pred_batch.numpy())
        if return_confidence:
            per_pixel_conf = tf.reduce_max(probs, axis=-1)
            confidences.append(tf.reduce_mean(per_pixel_conf, axis=[1, 2]).numpy())
        if return_class_confidence:
            conf_lumen.append(_class_mean_confidence(probs, pred_batch, _LUMEN_CLASS).numpy())
            conf_plaque.append(_class_mean_confidence(probs, pred_batch, _PLAQUE_CLASS).numpy())
        print('Batch {} of {} completed'.format(i + 1, num_batches))

    # np.concatenate([]) raises, so empty input is handled explicitly.
    frame_shape = tuple(images.shape[1:3])
    pred = _concat_or_empty(preds, frame_shape, np.int32)
    results = [pred]
    if return_confidence:
        results.append(_concat_or_empty(confidences, (), np.float32))
    if return_class_confidence:
        results.append({
            "lumen": _concat_or_empty(conf_lumen, (), np.float32),
            "plaque": _concat_or_empty(conf_plaque, (), np.float32),
        })
    return pred if len(results) == 1 else tuple(results)
def release_resources():
    """Drop TensorFlow state once segmentation is finished.

    Sets the module-level ``model`` to None, asks Keras to clear its global
    session state, and runs a garbage-collection pass. This is intended to let
    other frameworks use the GPU memory afterwards.

    Any exception raised while clearing the Keras session is deliberately
    ignored: this is best-effort cleanup and must not fail the caller.
    """
    global model
    model = None  # drop this module's reference to the loaded SavedModel

    try:
        backend = tf.keras.backend
        backend.clear_session()
    except Exception:  # best-effort cleanup, see docstring
        pass

    gc.collect()