# Author: azeemkhan417
# Commit message: Update app.py
# Commit: 04e2557 (verified)
import streamlit as st
import numpy as np
import tensorflow as tf
from tensorflow.keras import models, layers
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt
import cv2
import pickle
class BrainTumorExplainer:
    """Classify brain MRI images and explain predictions with Grad-CAM.

    The Keras CNN architecture is rebuilt in code and its weights are loaded
    from a pickle file. The model takes a (224, 224, 1) grayscale image and
    outputs softmax probabilities over ``tumor_types``.
    """

    def __init__(self, model_path):
        """
        Initialize the Brain Tumor Explainer.

        Args:
            model_path (str): Path to a pickle file holding the list of weight
                arrays, in the order expected by ``Model.set_weights``.
        """
        # Order must match the 3-unit softmax output of the model.
        self.tumor_types = ["Glioma", "Meningioma", "Pituitary Tumor"]
        # None if the weights failed to load (see load_model).
        self.model = self.load_model(model_path)

    def load_model(self, weights_path):
        """
        Rebuild the model architecture and load its weights.

        Args:
            weights_path (str): Path to a pickle file containing the list of
                weight arrays passed to ``Model.set_weights``.

        Returns:
            tf.keras.Model | None: The compiled model, or ``None`` if building
            the model or loading the weights failed (the error is shown with
            ``st.error``).
        """
        def rebuild_model():
            """Rebuild the original model architecture (must match the saved weights)."""
            model = models.Sequential([
                layers.Conv2D(32, (3, 3), activation='relu', input_shape=(224, 224, 1)),
                layers.BatchNormalization(),
                layers.MaxPooling2D((2, 2)),
                layers.Conv2D(64, (3, 3), activation='relu'),
                layers.BatchNormalization(),
                layers.MaxPooling2D((2, 2)),
                layers.Conv2D(128, (3, 3), activation='relu'),
                layers.BatchNormalization(),
                layers.MaxPooling2D((2, 2)),
                layers.Conv2D(256, (3, 3), activation='relu'),
                layers.BatchNormalization(),
                layers.MaxPooling2D((2, 2)),
                layers.Flatten(),
                layers.Dropout(0.5),
                layers.Dense(512, activation='relu'),
                layers.BatchNormalization(),
                layers.Dropout(0.3),
                layers.Dense(3, activation='softmax')
            ])
            model.compile(
                optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy']
            )
            return model

        try:
            model = rebuild_model()
            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file. Only load a weights file from a trusted source.
            with open(weights_path, 'rb') as f:
                weights_list = pickle.load(f)
            model.set_weights(weights_list)
            return model
        except Exception as e:
            # UI boundary: report the failure to the user instead of crashing.
            st.error(f"Error loading model weights: {e}")
            return None

    def preprocess_image(self, image):
        """
        Preprocess the uploaded image for prediction.

        Args:
            image (PIL.Image): Input image.

        Returns:
            numpy.ndarray: float32 array of shape (1, 224, 224, 1) with values
            scaled to [0, 1].
        """
        # Resize, then convert to single-channel grayscale.
        img = image.resize((224, 224)).convert("L")
        img_array = np.array(img).astype("float32") / 255.0
        img_array = np.expand_dims(img_array, axis=-1)  # channel axis
        img_array = np.expand_dims(img_array, axis=0)   # batch axis
        return img_array

    @staticmethod
    def _normalize_heatmap(heatmap):
        """Clip negative values to 0 and scale the map to [0, 1].

        If no location has positive importance, the all-zero map is returned
        unchanged instead of producing NaNs from 0 / 0.
        """
        heatmap = np.maximum(heatmap, 0)
        max_value = np.max(heatmap)
        if max_value > 0:
            heatmap = heatmap / max_value
        return heatmap

    def grad_cam(self, img_array, layer_name=None, class_index=None):
        """
        Generate a Grad-CAM visualization.

        Args:
            img_array (numpy.ndarray): Preprocessed batch containing one image,
                shaped (1, H, W, 1) as returned by ``preprocess_image``.
            layer_name (str, optional): Convolutional layer to explain.
                Defaults to the last ``Conv2D`` layer in the model.
            class_index (int, optional): Index of the class to explain.
                Defaults to the class with the highest predicted probability.

        Returns:
            tuple: ``(heatmap, superimposed_img)``. Both are uint8 RGB arrays
            of shape (H, W, 3): the colorized heatmap, and the heatmap blended
            over the input image.
        """
        # If no layer is specified, use the last convolutional layer.
        if layer_name is None:
            layer_name = [layer.name for layer in self.model.layers
                          if isinstance(layer, layers.Conv2D)][-1]

        # Model that outputs both the chosen conv activations and the predictions.
        grad_model = Model(
            inputs=self.model.inputs,
            outputs=[self.model.get_layer(layer_name).output, self.model.output]
        )

        with tf.GradientTape() as tape:
            conv_outputs, predictions = grad_model(img_array)
            if class_index is None:
                class_index = tf.argmax(predictions[0])
            loss = predictions[0][class_index]

        # Gradient of the class score with respect to the conv activations.
        gradients = tape.gradient(loss, conv_outputs)
        # Per-channel importance: average the gradients over batch and spatial axes.
        pooled_gradients = tf.reduce_mean(gradients, axis=(0, 1, 2))

        # Weighted combination of the conv feature maps.
        conv_outputs = conv_outputs[0]
        heatmap = tf.reduce_mean(
            tf.multiply(pooled_gradients, conv_outputs),
            axis=-1
        ).numpy()
        heatmap = self._normalize_heatmap(heatmap)

        # Upsample to the input resolution. cv2.resize takes (width, height).
        height, width = img_array.shape[1:3]
        heatmap = cv2.resize(heatmap, (width, height))
        heatmap = np.uint8(255 * heatmap)
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        # applyColorMap returns BGR, while st.image expects RGB by default.
        # Without this conversion, high-importance "red" regions render blue.
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

        # Convert the grayscale input to 3-channel uint8 for blending.
        original_img = (img_array[0, :, :, 0] * 255).astype("uint8")
        original_img = cv2.cvtColor(original_img, cv2.COLOR_GRAY2RGB)

        superimposed_img = cv2.addWeighted(original_img, 0.6, heatmap, 0.4, 0)
        return heatmap, superimposed_img

    def predict_and_explain(self, img_array):
        """
        Predict the tumor type and generate a Grad-CAM explanation.

        Args:
            img_array (numpy.ndarray): Preprocessed batch containing one image.

        Returns:
            tuple: ``(predicted_class, probabilities, heatmap,
            superimposed_img)``. ``predicted_class`` is a name from
            ``tumor_types``, and ``probabilities`` is the model output for
            the first image in the batch.
        """
        prediction = self.model.predict(img_array)
        probabilities = prediction[0]
        predicted_class_index = int(np.argmax(probabilities))
        predicted_class = self.tumor_types[predicted_class_index]

        # Explain exactly the class being reported.
        heatmap, superimposed_img = self.grad_cam(
            img_array, class_index=predicted_class_index
        )
        return predicted_class, probabilities, heatmap, superimposed_img
@st.cache_resource
def _load_explainer(model_path):
    """Construct the explainer once and reuse it across Streamlit reruns.

    Streamlit re-executes the script on every widget interaction. Without
    caching, the model would be rebuilt and its weights re-read each time.
    """
    return BrainTumorExplainer(model_path)


def main():
    """Render the app: upload an MRI image, classify it, and show Grad-CAM."""
    st.title("🧠 Brain Tumor Classification with Explainable AI")

    # Initialize the explainer (cached across reruns).
    explainer = _load_explainer("model.pkl")
    if explainer.model is None:
        # load_model has already shown the error via st.error. Evict the
        # failed instance so a later rerun retries, and stop here rather than
        # crashing with an AttributeError once the model is used.
        _load_explainer.clear()
        st.stop()

    # File uploader
    uploaded_file = st.file_uploader(
        "Upload Brain MRI Image",
        type=["jpg", "jpeg", "png"]
    )
    if uploaded_file is None:
        return

    from PIL import Image

    try:
        image = Image.open(uploaded_file)
        # Image.open is lazy; force decoding so corrupt files fail here.
        image.load()
    except OSError as err:  # PIL.UnidentifiedImageError subclasses OSError
        st.error(f"Could not read the uploaded image: {err}")
        return

    # Display original image
    st.subheader("Original Image")
    st.image(image, use_container_width=True)

    img_array = explainer.preprocess_image(image)

    if not st.button("Analyze and Explain"):
        return

    predicted_class, probabilities, heatmap, superimposed_img = \
        explainer.predict_and_explain(img_array)

    # Display prediction results
    st.subheader("Prediction Results")
    st.write(f"**Detected Tumor Type:** {predicted_class}")

    st.write("Prediction Probabilities:")
    for tumor, prob in zip(explainer.tumor_types, probabilities):
        st.write(f"{tumor}: {prob:.2%}")

    # Display Grad-CAM visualizations side by side.
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Grad-CAM Heatmap")
        st.image(heatmap, use_container_width=True,
                 caption="Areas of model's focus (red = high importance)")
    with col2:
        st.subheader("Heatmap Overlay")
        st.image(superimposed_img, use_container_width=True,
                 caption="Heatmap superimposed on original image")

    # Explanation of the visualization
    st.info(
        "**Interpretation:**\n"
        "- The heatmap shows which regions of the image "
        "the model considers most important for its classification.\n"
        "- Warmer colors (red, yellow) indicate higher importance.\n"
        "- This helps understand how the AI makes its decision."
    )
if __name__ == "__main__":
    # Script entry point (e.g. when launched via `streamlit run`).
    main()