File size: 6,296 Bytes
f16e7e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# predict.py — inference engine: TF | ONNX FP32 | Dynamic INT8 | Static INT8

import os
import numpy as np
import cv2
import tensorflow as tf
import onnxruntime as ort
from tensorflow.keras.preprocessing.image import load_img

from src.utils import get_logger, get_gradcam_heatmap, get_last_conv_layer

logger = get_logger("predict")


class BrainTumorPredictor:
    """
    Unified predictor supporting TF model, ONNX FP32,
    ONNX Dynamic INT8, and ONNX Static INT8.

    Parameters
    ----------
    cfg : dict
        Parsed project config; reads data.image_size, data.classes,
        models.save_dir and models.onnx_dir.
    backend : str
        One of BACKENDS; selects which serialized model is loaded.

    Raises
    ------
    ValueError
        If *backend* is not in BACKENDS.
    """

    BACKENDS = ["tensorflow", "onnx_fp32", "onnx_dynamic", "onnx_static"]

    # backend -> (filename under onnx_dir, human-readable label used in logs/errors)
    _ONNX_VARIANTS = {
        "onnx_fp32": ("model_fp32.onnx", "ONNX FP32"),
        "onnx_dynamic": ("model_dynamic_int8.onnx", "ONNX Dynamic INT8"),
        "onnx_static": ("model_static_int8.onnx", "ONNX Static INT8"),
    }

    def __init__(self, cfg: dict, backend: str = "tensorflow"):
        if backend not in self.BACKENDS:
            raise ValueError(f"backend must be one of {self.BACKENDS}")

        self.backend = backend
        self.image_size = tuple(cfg["data"]["image_size"])
        self.class_names = cfg["data"]["classes"]
        self.save_dir = cfg["models"]["save_dir"]
        self.onnx_dir = cfg["models"]["onnx_dir"]

        self.tf_model = None      # populated only for the "tensorflow" backend
        self.ort_session = None   # populated for any ONNX backend
        self._load(backend)

    def _load(self, backend: str) -> None:
        """Load the serialized model for *backend* into memory."""
        if backend == "tensorflow":
            path = os.path.join(self.save_dir, "ft_best.h5")
            logger.info(f"Loading TF model from {path}")
            self.tf_model = tf.keras.models.load_model(path, compile=False)
            return

        filename, label = self._ONNX_VARIANTS[backend]
        path = os.path.join(self.onnx_dir, filename)
        logger.info(f"Loading {label} from {path}")
        if backend == "onnx_fp32":
            # FP32 kernels exist in every ONNX Runtime build; let ORT errors
            # propagate unchanged, matching the original behavior.
            self.ort_session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
        else:
            # INT8 kernels may be absent from some ORT builds; surface a clear
            # error, chained with `from e` so the root cause is preserved.
            try:
                self.ort_session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
            except Exception as e:
                raise RuntimeError(
                    f"{label} model is not supported in this ONNX Runtime build: {e}"
                ) from e

    def preprocess(self, image_path: str) -> tuple:
        """
        Load an image and scale pixels to [0, 1].

        Returns a tuple of (PIL image, float array, batched float32 tensor
        of shape (1, H, W, C) ready for inference).
        """
        img = load_img(image_path, target_size=self.image_size)
        arr = np.array(img) / 255.0
        img_input = np.expand_dims(arr, axis=0).astype(np.float32)
        return img, arr, img_input

    def _run_inference(self, img_input: np.ndarray) -> np.ndarray:
        """Forward pass on the active backend; returns the 1-D probability vector."""
        if self.backend == "tensorflow":
            return self.tf_model.predict(img_input, verbose=0)[0]
        inp_name = self.ort_session.get_inputs()[0].name
        out_name = self.ort_session.get_outputs()[0].name
        return self.ort_session.run([out_name], {inp_name: img_input})[0][0]

    def _format_result(self, probs: np.ndarray) -> dict:
        """Convert a probability vector into the public result dict (percentages)."""
        pred_idx = int(np.argmax(probs))
        confidence = float(probs[pred_idx]) * 100
        return {
            "predicted_class": self.class_names[pred_idx],
            "confidence": round(confidence, 2),
            "all_probabilities": {cls: float(p) * 100 for cls, p in zip(self.class_names, probs)},
            "backend": self.backend,
        }

    def predict(self, image_path: str) -> dict:
        """Classify the image at *image_path*; see _format_result for the dict shape."""
        _, _, img_input = self.preprocess(image_path)
        return self._format_result(self._run_inference(img_input))

    def predict_with_gradcam(self, image_path: str) -> dict:
        """
        Classify and attach a Grad-CAM heatmap/overlay.

        Raises RuntimeError for ONNX backends, which expose no gradients.
        """
        if self.backend != "tensorflow":
            raise RuntimeError("Grad-CAM is only supported with tensorflow backend.")

        # Preprocess once and reuse the tensor for both the forward pass and
        # Grad-CAM (previously the image was loaded and normalized twice).
        _, arr, img_input = self.preprocess(image_path)
        result = self._format_result(self._run_inference(img_input))

        last_conv = get_last_conv_layer(self.tf_model)
        heatmap, _ = get_gradcam_heatmap(self.tf_model, img_input, last_conv)

        # cv2.resize takes dsize as (width, height), while image_size follows
        # Keras' (height, width) convention for load_img(target_size=...) —
        # swap the axes so non-square sizes resize correctly.
        h, w = self.image_size
        heatmap_resized = cv2.resize(heatmap, (w, h))
        heatmap_colored = cv2.cvtColor(
            cv2.applyColorMap(np.uint8(255 * heatmap_resized), cv2.COLORMAP_JET),
            cv2.COLOR_BGR2RGB
        )
        # Blend: 60% original image, 40% colored heatmap.
        overlay = cv2.addWeighted(np.uint8(255 * arr), 0.6, heatmap_colored, 0.4, 0)

        result["gradcam_overlay"] = overlay
        result["heatmap"] = heatmap_resized
        return result


if __name__ == "__main__":
    import argparse
    import matplotlib.pyplot as plt
    from src.utils import load_config

    parser = argparse.ArgumentParser(description="Brain Tumor MRI Predictor")
    parser.add_argument("--image", required=True)
    parser.add_argument("--backend", default="tensorflow", choices=BrainTumorPredictor.BACKENDS)
    parser.add_argument("--gradcam", action="store_true")
    args = parser.parse_args()

    cfg = load_config("config.yaml")
    predictor = BrainTumorPredictor(cfg, backend=args.backend)

    if args.gradcam and args.backend == "tensorflow":
        result = predictor.predict_with_gradcam(args.image)

        fig, axes = plt.subplots(1, 3, figsize=(13, 4))
        img = load_img(args.image, target_size=tuple(cfg["data"]["image_size"]))
        axes[0].imshow(img)
        axes[0].set_title("Input MRI")
        axes[0].axis("off")

        axes[1].imshow(result["heatmap"], cmap="jet")
        axes[1].set_title("Grad-CAM")
        axes[1].axis("off")

        axes[2].imshow(result["gradcam_overlay"])
        axes[2].set_title(f"Pred: {result['predicted_class']} ({result['confidence']:.1f}%)")
        axes[2].axis("off")

        plt.tight_layout()
        plt.show()
    else:
        result = predictor.predict(args.image)

    print("\n" + "=" * 42)
    print(f"  PREDICTION : {result['predicted_class'].upper()}")
    print(f"  CONFIDENCE : {result['confidence']:.2f}%")
    print(f"  BACKEND    : {result['backend']}")
    print("=" * 42)
    print("  All probabilities:")
    for cls, prob in sorted(result["all_probabilities"].items(), key=lambda x: -x[1]):
        bar = "█" * int(prob / 4)
        marker = " ← predicted" if cls == result["predicted_class"] else ""
        print(f"  {cls:<15} {prob:5.1f}%  {bar}{marker}")