|
|
import streamlit as st |
|
|
import numpy as np |
|
|
import tensorflow as tf |
|
|
from tensorflow.keras import models, layers |
|
|
from tensorflow.keras.models import Model |
|
|
import matplotlib.pyplot as plt |
|
|
import cv2 |
|
|
import pickle |
|
|
|
|
|
class BrainTumorExplainer:
    """Classify brain MRI images and explain predictions with Grad-CAM.

    Rebuilds a small Keras CNN in code, applies weights read from a pickle
    file, and exposes preprocessing, prediction and Grad-CAM helpers used by
    the Streamlit front end.
    """

    # Side length in pixels of the square, single-channel input the network
    # is built for (see ``input_shape`` in ``load_model``).
    IMAGE_SIZE = 224

    def __init__(self, model_path):
        """
        Initialize the Brain Tumor Explainer.

        Args:
            model_path (str): Path to the pickled list of model weights.
        """
        # Class names indexed by output position of the final Dense layer;
        # presumably the label order used at training time -- verify.
        self.tumor_types = ["Glioma", "Meningioma", "Pituitary Tumor"]
        # None when loading failed (the error is reported via st.error).
        self.model = self.load_model(model_path)

    def load_model(self, weights_path):
        """
        Rebuild the model architecture and load its weights.

        Args:
            weights_path (str): Path to a pickle file holding the list of
                weight arrays accepted by ``Model.set_weights``.

        Returns:
            tf.keras.Model or None: The compiled model with weights applied,
            or None if anything failed (the error is shown via ``st.error``).
        """
        size = self.IMAGE_SIZE

        def rebuild_model():
            """Rebuild the original model architecture.

            ``set_weights`` assigns arrays by position, so layer order and
            shapes here must match the saved weights exactly.
            """
            model = models.Sequential([
                layers.Conv2D(32, (3, 3), activation='relu', input_shape=(size, size, 1)),
                layers.BatchNormalization(),
                layers.MaxPooling2D((2, 2)),
                layers.Conv2D(64, (3, 3), activation='relu'),
                layers.BatchNormalization(),
                layers.MaxPooling2D((2, 2)),
                layers.Conv2D(128, (3, 3), activation='relu'),
                layers.BatchNormalization(),
                layers.MaxPooling2D((2, 2)),
                layers.Conv2D(256, (3, 3), activation='relu'),
                layers.BatchNormalization(),
                layers.MaxPooling2D((2, 2)),
                layers.Flatten(),
                layers.Dropout(0.5),
                layers.Dense(512, activation='relu'),
                layers.BatchNormalization(),
                layers.Dropout(0.3),
                layers.Dense(3, activation='softmax'),
            ])

            model.compile(
                optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'],
            )
            return model

        try:
            model = rebuild_model()

            # SECURITY: pickle.load can execute arbitrary code. Only load
            # weight files from a trusted source -- never user uploads.
            with open(weights_path, 'rb') as f:
                weights_list = pickle.load(f)

            model.set_weights(weights_list)
            return model
        except Exception as e:
            # UI boundary: report the failure to the user instead of crashing.
            st.error(f"Error loading model weights: {e}")
            return None

    def preprocess_image(self, image):
        """
        Preprocess the uploaded image for prediction.

        Args:
            image (PIL.Image): Input image.

        Returns:
            numpy.ndarray: float32 array of shape
            (1, IMAGE_SIZE, IMAGE_SIZE, 1) with values scaled to [0, 1].
        """
        size = self.IMAGE_SIZE
        img = image.resize((size, size)).convert("L")
        img_array = np.array(img).astype("float32") / 255.0
        img_array = np.expand_dims(img_array, axis=-1)  # channel axis
        img_array = np.expand_dims(img_array, axis=0)   # batch axis
        return img_array

    def grad_cam(self, img_array, layer_name=None):
        """
        Generate a Grad-CAM visualization for the predicted class.

        Args:
            img_array (numpy.ndarray): Preprocessed input of shape
                (1, H, W, 1), as returned by ``preprocess_image``.
            layer_name (str, optional): Layer to explain. Defaults to the
                last ``Conv2D`` layer of the model.

        Returns:
            tuple: ``(heatmap, superimposed_img)``, both uint8 RGB arrays of
            shape (H, W, 3). The heatmap uses the JET colormap
            (red = high importance).
        """
        if layer_name is None:
            layer_name = [
                layer.name for layer in self.model.layers
                if isinstance(layer, layers.Conv2D)
            ][-1]

        # Auxiliary model exposing both the chosen conv activations and the
        # final predictions, so gradients can flow between them.
        grad_model = Model(
            inputs=self.model.inputs,
            outputs=[self.model.get_layer(layer_name).output, self.model.output],
        )

        with tf.GradientTape() as tape:
            conv_outputs, predictions = grad_model(img_array)
            predicted_class = tf.argmax(predictions[0])
            loss = predictions[0][predicted_class]

        gradients = tape.gradient(loss, conv_outputs)

        # One importance weight per feature-map channel.
        pooled_gradients = tf.reduce_mean(gradients, axis=(0, 1, 2))

        conv_outputs = conv_outputs[0]
        heatmap = tf.reduce_mean(
            tf.multiply(pooled_gradients, conv_outputs),
            axis=-1,
        ).numpy()

        # Keep only positive evidence, then scale to [0, 1]. If nothing is
        # positive, leave an all-zero map instead of dividing by zero (NaN).
        heatmap = np.maximum(heatmap, 0)
        max_value = np.max(heatmap)
        if max_value > 0:
            heatmap = heatmap / max_value

        height, width = img_array.shape[1], img_array.shape[2]
        heatmap = cv2.resize(heatmap, (width, height))  # cv2 takes (w, h)
        heatmap = np.uint8(255 * heatmap)
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        # applyColorMap returns BGR; convert to RGB so it blends correctly
        # with the RGB original below and displays with true colors.
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

        original_img = (img_array[0, :, :, 0] * 255).astype("uint8")
        original_img = cv2.cvtColor(original_img, cv2.COLOR_GRAY2RGB)

        superimposed_img = cv2.addWeighted(original_img, 0.6, heatmap, 0.4, 0)

        return heatmap, superimposed_img

    def predict_and_explain(self, img_array):
        """
        Predict the tumor type and generate a Grad-CAM explanation.

        Args:
            img_array (numpy.ndarray): Preprocessed input of shape
                (1, H, W, 1), as returned by ``preprocess_image``.

        Returns:
            tuple: ``(predicted_class, probabilities, heatmap,
            superimposed_img)`` where ``probabilities`` holds one score per
            entry of ``self.tumor_types``.
        """
        # verbose=0 suppresses the per-call progress bar in server logs.
        prediction = self.model.predict(img_array, verbose=0)
        probabilities = prediction[0]
        predicted_class = self.tumor_types[int(np.argmax(probabilities))]

        heatmap, superimposed_img = self.grad_cam(img_array)

        return predicted_class, probabilities, heatmap, superimposed_img
|
|
|
|
|
@st.cache_resource
def _load_explainer(model_path):
    """Build the explainer once per process and reuse it across reruns.

    Streamlit re-executes the script on every interaction, so without caching
    the model would be rebuilt and the weights re-read each time. Raising on
    failure keeps a failed load out of the cache so a later rerun can retry.
    """
    explainer = BrainTumorExplainer(model_path)
    if explainer.model is None:
        raise RuntimeError(f"Model could not be loaded from {model_path!r}")
    return explainer


def main():
    """Streamlit page: upload an MRI image, classify it, and show Grad-CAM."""
    st.title("🧠 Brain Tumor Classification with Explainable AI")

    try:
        explainer = _load_explainer("model.pkl")
    except RuntimeError:
        # load_model has already shown the underlying error via st.error;
        # without a model, nothing below can run.
        st.stop()

    uploaded_file = st.file_uploader(
        "Upload Brain MRI Image",
        type=["jpg", "jpeg", "png"],
    )
    if uploaded_file is None:
        return

    from PIL import Image, UnidentifiedImageError

    try:
        image = Image.open(uploaded_file)
    except UnidentifiedImageError:
        st.error("The uploaded file is not a readable image.")
        return

    st.subheader("Original Image")
    st.image(image, use_container_width=True)

    img_array = explainer.preprocess_image(image)

    if not st.button("Analyze and Explain"):
        return

    predicted_class, probabilities, heatmap, superimposed_img = \
        explainer.predict_and_explain(img_array)

    st.subheader("Prediction Results")
    st.write(f"**Detected Tumor Type:** {predicted_class}")

    st.write("Prediction Probabilities:")
    for tumor, prob in zip(explainer.tumor_types, probabilities):
        st.write(f"{tumor}: {prob:.2%}")

    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Grad-CAM Heatmap")
        st.image(heatmap, use_container_width=True,
                 caption="Areas of model's focus (red = high importance)")

    with col2:
        st.subheader("Heatmap Overlay")
        st.image(superimposed_img, use_container_width=True,
                 caption="Heatmap superimposed on original image")

    st.info(
        "**Interpretation:**\n"
        "- The heatmap shows which regions of the image "
        "the model considers most important for its classification.\n"
        "- Warmer colors (red, yellow) indicate higher importance.\n"
        "- This helps understand how the AI makes its decision."
    )
|
|
|
|
|
# Run the Streamlit app when this file is executed as the main script.
if __name__ == "__main__":
    main()