Update app/model.py
Browse files- app/model.py +170 -176
app/model.py
CHANGED
|
@@ -1,176 +1,170 @@
|
|
| 1 |
-
#
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
import
|
| 5 |
-
import
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
import
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
"""
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
arr
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
"""
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
confidence
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
heatmap = heatmap
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
#
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
def gradcam(img: Image.Image, interpolant=0.5):
|
| 172 |
-
return compute_gradcam_overlay(img, interpolant=interpolant)
|
| 173 |
-
|
| 174 |
-
# ---------------------------
|
| 175 |
-
# End of app/model.py
|
| 176 |
-
# ---------------------------
|
|
|
|
| 1 |
+
# File: app/model.py
|
| 2 |
+
import os
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import tensorflow as tf
|
| 6 |
+
import cv2
|
| 7 |
+
from tensorflow.keras.models import Model
|
| 8 |
+
from tensorflow.keras.layers import Conv2D
|
| 9 |
+
|
| 10 |
+
# GPU setup
# Try to enable memory growth for GPUs to avoid TF pre-allocating all memory.
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    try:
        for g in gpus:
            tf.config.experimental.set_memory_growth(g, True)
    except Exception as e:
        # If setting memory growth fails, just print a warning and continue.
        # (Memory growth must be set before the GPUs are initialized, so this
        # can raise RuntimeError if TF has already touched a device.)
        print("Warning: Could not set memory growth:", e)

print("Num GPUs Available:", len(gpus))
print("TensorFlow version:", tf.__version__)
|
| 23 |
+
|
| 24 |
+
# Load model
# MODEL_PATH can be overridden via the environment; defaults to the bundled
# InceptionV3 brain-tumor classifier in HDF5 format. Loading happens at import
# time, so importing this module is expensive by design (done once per process).
MODEL_PATH = os.getenv("MODEL_PATH", "saved_model/InceptionV3_Brain_Tumor_MRI.h5")
print("Loading model from:", MODEL_PATH)
model = tf.keras.models.load_model(MODEL_PATH)
# Inference-only module: freeze weights so accidental training calls are no-ops.
model.trainable = False
|
| 29 |
+
|
| 30 |
+
# Find last conv layer and build grad_model once (at import time), so every
# Grad-CAM call reuses the same sub-model instead of rebuilding it per request.
# Find the last Conv2D layer — Grad-CAM weights the deepest spatial feature maps.
last_conv_layer = None
for layer in reversed(model.layers):
    if isinstance(layer, Conv2D):
        last_conv_layer = layer
        break
if last_conv_layer is None:
    # Without any Conv2D layer there are no spatial activations to visualize;
    # fail fast at import rather than at the first Grad-CAM request.
    raise RuntimeError("No Conv2D layer found in the model; cannot build Grad-CAM.")

# grad_model maps the input image to (target conv activations, class predictions).
target_layer = model.get_layer(last_conv_layer.name)
grad_model = Model(inputs=model.inputs, outputs=[target_layer.output, model.output])
print("Built grad_model with target layer:", target_layer.name)
|
| 43 |
+
|
| 44 |
+
# Labels
# Index order maps model output positions to human-readable class names.
# NOTE(review): assumed to match the label order used at training time — verify.
CLASS_NAMES = ["glioma", "meningioma", "notumor", "pituitary"]
|
| 46 |
+
|
| 47 |
+
# Preprocessing
# NOTE(review): InceptionV3's canonical input size is 299x299, but this
# pipeline is configured for 512x512 — pass target_size=(299, 299) if the
# loaded weights expect the canonical size; confirm against the saved model.
def preprocess_image_pil(img: Image.Image, target_size=(512, 512)):
    """
    Accepts PIL.Image, returns float32 numpy array shaped (1,H,W,3) with values in [0,1].

    The image is forced to RGB (handles grayscale/RGBA inputs), bilinearly
    resized to target_size (width, height), scaled to [0, 1], and given a
    leading batch dimension.
    """
    rgb = img.convert("RGB").resize(target_size, resample=Image.BILINEAR)
    scaled = np.asarray(rgb, dtype="float32") / 255.0
    return scaled[np.newaxis, ...]
|
| 57 |
+
|
| 58 |
+
def pil_to_tf_tensor(img: Image.Image, target_size=(512, 512)):
    """
    Convert PIL image to a TF tensor float32 (1,H,W,3) scaled to [0,1].

    Delegates the resize/normalize work to preprocess_image_pil, then wraps
    the numpy batch in a TF tensor so downstream ops can run on the GPU.
    """
    batch = preprocess_image_pil(img, target_size=target_size)
    return tf.convert_to_tensor(batch, dtype=tf.float32)
|
| 65 |
+
|
| 66 |
+
# Prediction helper
def predict(img: Image.Image):
    """
    Classify a brain-MRI image.

    Returns (label, confidence, prob_dict): the predicted class name, its
    probability, and a mapping of every class name to its probability.
    """
    batch = preprocess_image_pil(img)  # numpy (1,H,W,3)
    # Direct model call (not model.predict) — works for most Keras models and
    # avoids the per-call overhead of the predict pipeline.
    outputs = model(batch, training=False)
    probs = outputs.numpy()[0]
    best = int(np.argmax(probs))
    prob_dict = dict(zip(CLASS_NAMES, (float(p) for p in probs)))
    return CLASS_NAMES[best], float(probs[best]), prob_dict
|
| 79 |
+
|
| 80 |
+
# Compiled Grad-CAM compute function
# We create a tf.function that computes conv features and gradients for a given
# input and class index; tf.function traces/caches the graph so repeated
# Grad-CAM calls avoid eager-mode overhead.
@tf.function
def _compute_conv_and_grads(img_input, class_index):
    """Differentiate the selected class score w.r.t. the last conv activations.

    Args:
        img_input: float32 tensor (1, H, W, 3) — assumes values in [0, 1] as
            produced by pil_to_tf_tensor; TODO confirm for other callers.
        class_index: scalar integer tensor selecting the class column.

    Returns:
        (conv_outputs, grads, preds): target-layer activations, the gradient of
        the class score w.r.t. those activations (None if disconnected), and
        the full prediction tensor.
    """
    with tf.GradientTape() as tape:
        conv_outputs, preds = grad_model(img_input)

        # preds is probably a list -> convert it
        # (a functional Model with list outputs can return [tensor]).
        if isinstance(preds, (list, tuple)):
            preds = preds[0]  # take the actual tensor

        # Score of the class we want to explain; must be computed inside the
        # tape so its dependence on conv_outputs is recorded.
        class_logits = preds[:, class_index]

    grads = tape.gradient(class_logits, conv_outputs)
    return conv_outputs, grads, preds
|
| 95 |
+
|
| 96 |
+
def compute_gradcam_overlay(img: Image.Image, interpolant=0.5, target_size=(512,512)):
    """
    High-level wrapper:
     -> builds input tensor
     -> obtains predicted class index (fast forward)
     -> calls compiled grad function to get conv features + grads
     -> computes heatmap and overlay efficiently
    Returns: overlay as uint8 HxWx3 numpy array

    Args:
        img: input PIL image.
        interpolant: blend weight of the ORIGINAL image; the heatmap gets
            (1 - interpolant). 0.5 = equal blend.
        target_size: (width, height) the image is resized to before inference.
    """
    # Build tensor
    input_tf = pil_to_tf_tensor(img, target_size=target_size)  # (1,H,W,3), float32

    # Fast predict to get class index (cheap forward pass)
    # NOTE(review): this is a second forward pass on top of the one inside
    # _compute_conv_and_grads — kept for simplicity, costs one extra inference.
    preds = model(input_tf, training=False)
    pred_np = preds.numpy()[0]
    class_idx = int(np.argmax(pred_np))

    # Use compiled function to compute conv features and grads for that class
    conv_out, grads, _ = _compute_conv_and_grads(input_tf, tf.constant(class_idx, dtype=tf.int64))

    # Convert to numpy and handle shapes robustly
    conv_out_np = conv_out.numpy()
    grads_np = grads.numpy() if grads is not None else None

    if grads_np is None:
        # Fallback: gradients None --> return original image as overlay (no heatmap)
        H = input_tf.shape[1]
        W = input_tf.shape[2]
        original_img = np.array(img.resize((W, H))).astype("uint8")
        if original_img.ndim == 2:
            # grayscale source: replicate channel so the contract (HxWx3) holds
            original_img = np.stack([original_img]*3, axis=-1)
        return original_img

    # conv_out_np shape (1,Hf,Wf,C) --> take first batch
    if conv_out_np.ndim == 4 and conv_out_np.shape[0] == 1:
        conv_out_np = conv_out_np[0]
    # grads_np shape (1,Hf,Wf,C) --> same batch squeeze
    if grads_np.ndim == 4 and grads_np.shape[0] == 1:
        grads_np = grads_np[0]

    # Global average pooling of gradients over spatial dims (Hf,Wf) —
    # these are the Grad-CAM channel importance weights.
    pooled_grads = np.mean(grads_np, axis=(0,1))  # shape (C,)

    # Weighted sum of conv feature maps, then ReLU (keep positive evidence only)
    heatmap = np.sum(conv_out_np * pooled_grads[np.newaxis, np.newaxis, :], axis=-1)  # (Hf,Wf)
    heatmap = np.maximum(heatmap, 0.0)
    max_val = np.max(heatmap) if heatmap.size else 0.0
    if max_val > 0:
        # Normalize to [0,1]; epsilon guards against division blow-up.
        heatmap = heatmap / (max_val + 1e-9)
    else:
        # No positive contribution anywhere: keep a blank heatmap.
        heatmap = np.zeros_like(heatmap, dtype=np.float32)

    # Resize heatmap to original image size
    H = input_tf.shape[1]
    W = input_tf.shape[2]
    original_img = np.array(img.resize((W, H))).astype("float32")
    if original_img.ndim == 2:
        original_img = np.stack([original_img]*3, axis=-1)

    # cv2.resize takes (width, height); colormap output is BGR, convert to RGB.
    heatmap_resized = cv2.resize((heatmap * 255.0).astype("uint8"), (W, H))
    heatmap_color = cv2.applyColorMap(heatmap_resized, cv2.COLORMAP_JET)  # BGR
    heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB).astype("float32")

    # Ensure original image is in uint8 (0,255)
    orig_uint8 = np.clip(original_img, 0, 255).astype("uint8")

    # Combine using interpolant: (interpolant * original + (1-interpolant) * heatmap_color)
    overlay = np.clip(orig_uint8.astype("float32") * interpolant + heatmap_color * (1.0 - interpolant), 0, 255).astype("uint8")
    return overlay
|
| 165 |
+
|
| 166 |
+
# Expose functions for main.py
# Fix: "gradcam" (the backwards-compatible alias defined below, which main.py
# expects) was missing from __all__, so `from app.model import *` dropped it.
# Adding it is backward-compatible for all existing star-importers.
__all__ = ["model", "grad_model", "predict", "compute_gradcam_overlay", "gradcam", "CLASS_NAMES"]
|
| 168 |
+
# Backwards-compatible function name expected by main.py
def gradcam(img: Image.Image, interpolant=0.5):
    """Alias kept for callers using the old name; see compute_gradcam_overlay."""
    return compute_gradcam_overlay(img, interpolant=interpolant)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|