# predict.py — inference engine: TF | ONNX FP32 | Dynamic INT8 | Static INT8
import os
import numpy as np
import cv2
import tensorflow as tf
import onnxruntime as ort
from tensorflow.keras.preprocessing.image import load_img
from src.utils import get_logger, get_gradcam_heatmap, get_last_conv_layer
logger = get_logger("predict")
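
# Model artifacts looked up by BrainTumorPredictor._load (directories come from
# config.yaml; file names are the ones hard-coded below):
#   <models.save_dir>/ft_best.h5               - Keras model (tensorflow backend)
#   <models.onnx_dir>/model_fp32.onnx          - ONNX FP32 export
#   <models.onnx_dir>/model_dynamic_int8.onnx  - ONNX dynamically quantized INT8
#   <models.onnx_dir>/model_static_int8.onnx   - ONNX statically quantized INT8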
class BrainTumorPredictor:
    """
    Unified predictor supporting TF model, ONNX FP32,
    ONNX Dynamic INT8, and ONNX Static INT8.
    """

    BACKENDS = ["tensorflow", "onnx_fp32", "onnx_dynamic", "onnx_static"]
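
    # Typical usage (image path illustrative; load_config comes from src.utils):
    #   cfg = load_config("config.yaml")
    #   predictor = BrainTumorPredictor(cfg, backend="onnx_fp32")
    #   result = predictor.predict("path/to/mri_slice.jpg")
    #   print(result["predicted_class"], result["confidence"])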
    def __init__(self, cfg: dict, backend: str = "tensorflow"):
        if backend not in self.BACKENDS:
            raise ValueError(f"backend must be one of {self.BACKENDS}")
        self.backend = backend
        self.image_size = tuple(cfg["data"]["image_size"])
        self.class_names = cfg["data"]["classes"]
        self.save_dir = cfg["models"]["save_dir"]
        self.onnx_dir = cfg["models"]["onnx_dir"]
        self.tf_model = None
        self.ort_session = None
        self._load(backend)
    def _load(self, backend: str):
        if backend == "tensorflow":
            path = os.path.join(self.save_dir, "ft_best.h5")
            logger.info(f"Loading TF model from {path}")
            self.tf_model = tf.keras.models.load_model(path, compile=False)
        elif backend == "onnx_fp32":
            path = os.path.join(self.onnx_dir, "model_fp32.onnx")
            logger.info(f"Loading ONNX FP32 from {path}")
            self.ort_session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
        elif backend == "onnx_dynamic":
            path = os.path.join(self.onnx_dir, "model_dynamic_int8.onnx")
            logger.info(f"Loading ONNX Dynamic INT8 from {path}")
            try:
                self.ort_session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
            except Exception as e:
                raise RuntimeError(
                    f"ONNX Dynamic INT8 model is not supported in this ONNX Runtime build: {e}"
                )
        elif backend == "onnx_static":
            path = os.path.join(self.onnx_dir, "model_static_int8.onnx")
            logger.info(f"Loading ONNX Static INT8 from {path}")
            try:
                self.ort_session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
            except Exception as e:
                raise RuntimeError(
                    f"ONNX Static INT8 model is not supported in this ONNX Runtime build: {e}"
                )
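
    # The ONNX files above are produced by a separate export/quantization step,
    # not by this module. A minimal sketch of that step (paths illustrative),
    # assuming ONNX Runtime's quantization helpers and an existing FP32 export:
    #
    #   from onnxruntime.quantization import QuantType, quantize_dynamic
    #   quantize_dynamic("model_fp32.onnx", "model_dynamic_int8.onnx",
    #                    weight_type=QuantType.QInt8)
    #
    # Static INT8 (quantize_static) additionally needs a CalibrationDataReader
    # fed with representative preprocessed images.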
    def preprocess(self, image_path: str) -> tuple:
        """Load an image, resize to the configured size, scale to [0, 1], and
        return (PIL image, float HWC array, batched float32 NHWC input)."""
        img = load_img(image_path, target_size=self.image_size)
        arr = np.array(img) / 255.0
        img_input = np.expand_dims(arr, axis=0).astype(np.float32)
        return img, arr, img_input
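
    # Note: the same preprocessed tensor is fed to every backend, so it must
    # match what the ONNX graphs were exported with (RGB, resized to
    # image_size, scaled to [0, 1], NHWC float32).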
    def predict(self, image_path: str) -> dict:
        """Run a single-image prediction and return the predicted class,
        confidence (%), per-class probabilities (%), and the backend used."""
        _, _, img_input = self.preprocess(image_path)
        if self.backend == "tensorflow":
            probs = self.tf_model.predict(img_input, verbose=0)[0]
        else:
            inp_name = self.ort_session.get_inputs()[0].name
            out_name = self.ort_session.get_outputs()[0].name
            probs = self.ort_session.run([out_name], {inp_name: img_input})[0][0]
        pred_idx = int(np.argmax(probs))
        pred_class = self.class_names[pred_idx]
        confidence = float(probs[pred_idx]) * 100
        all_probs = {cls: float(p) * 100 for cls, p in zip(self.class_names, probs)}
        return {
            "predicted_class": pred_class,
            "confidence": round(confidence, 2),
            "all_probabilities": all_probs,
            "backend": self.backend,
        }
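
    # Example of the returned dict (class names and values are illustrative;
    # the real labels come from cfg["data"]["classes"]):
    #   {"predicted_class": "glioma", "confidence": 97.42,
    #    "all_probabilities": {"glioma": 97.42, ...}, "backend": "onnx_fp32"}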
    def predict_with_gradcam(self, image_path: str) -> dict:
        """Predict and additionally return a Grad-CAM heatmap and RGB overlay
        (TensorFlow backend only, since gradients are required)."""
        if self.backend != "tensorflow":
            raise RuntimeError("Grad-CAM is only supported with the tensorflow backend.")
        result = self.predict(image_path)
        _, arr, img_input = self.preprocess(image_path)
        last_conv = get_last_conv_layer(self.tf_model)
        heatmap, _ = get_gradcam_heatmap(self.tf_model, img_input, last_conv)
        heatmap_resized = cv2.resize(heatmap, self.image_size)
        heatmap_colored = cv2.cvtColor(
            cv2.applyColorMap(np.uint8(255 * heatmap_resized), cv2.COLORMAP_JET),
            cv2.COLOR_BGR2RGB,
        )
        overlay = cv2.addWeighted(np.uint8(255 * arr), 0.6, heatmap_colored, 0.4, 0)
        result["gradcam_overlay"] = overlay
        result["heatmap"] = heatmap_resized
        return result
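
    # The overlay is returned as an RGB uint8 array; if it is written to disk
    # with cv2.imwrite (which expects BGR), convert it first, e.g.
    #   cv2.imwrite("overlay.png", cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))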
if __name__ == "__main__":
    import argparse
    import matplotlib.pyplot as plt
    from src.utils import load_config

    parser = argparse.ArgumentParser(description="Brain Tumor MRI Predictor")
    parser.add_argument("--image", required=True)
    parser.add_argument("--backend", default="tensorflow", choices=BrainTumorPredictor.BACKENDS)
    parser.add_argument("--gradcam", action="store_true")
    args = parser.parse_args()

    cfg = load_config("config.yaml")
    predictor = BrainTumorPredictor(cfg, backend=args.backend)

    if args.gradcam and args.backend == "tensorflow":
        result = predictor.predict_with_gradcam(args.image)
        fig, axes = plt.subplots(1, 3, figsize=(13, 4))
        img = load_img(args.image, target_size=tuple(cfg["data"]["image_size"]))
        axes[0].imshow(img)
        axes[0].set_title("Input MRI")
        axes[0].axis("off")
        axes[1].imshow(result["heatmap"], cmap="jet")
        axes[1].set_title("Grad-CAM")
        axes[1].axis("off")
        axes[2].imshow(result["gradcam_overlay"])
        axes[2].set_title(f"Pred: {result['predicted_class']} ({result['confidence']:.1f}%)")
        axes[2].axis("off")
        plt.tight_layout()
        plt.show()
    else:
        result = predictor.predict(args.image)
        print("\n" + "=" * 42)
        print(f" PREDICTION : {result['predicted_class'].upper()}")
        print(f" CONFIDENCE : {result['confidence']:.2f}%")
        print(f" BACKEND : {result['backend']}")
        print("=" * 42)
        print(" All probabilities:")
        for cls, prob in sorted(result["all_probabilities"].items(), key=lambda x: -x[1]):
            bar = "█" * int(prob / 4)
            marker = " ← predicted" if cls == result["predicted_class"] else ""
            print(f" {cls:<15} {prob:5.1f}% {bar}{marker}")