import streamlit as st
import numpy as np
import tensorflow as tf
from tensorflow.keras import models, layers
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt
import cv2
import pickle


class BrainTumorExplainer:
    """Classify brain MRI images and explain predictions with Grad-CAM."""

    def __init__(self, model_path):
        """
        Initialize the Brain Tumor Explainer.

        Args:
            model_path (str): Path to the saved model weights (a pickled
                list of arrays accepted by ``Model.set_weights``).
        """
        # Order must match the 3 units of the final softmax layer.
        self.tumor_types = ["Glioma", "Meningioma", "Pituitary Tumor"]
        # None when loading failed (the error is reported via st.error).
        self.model = self.load_model(model_path)

    def load_model(self, weights_path):
        """
        Rebuild the model architecture and load its weights.

        Args:
            weights_path (str): Path to a pickle file holding the list of
                weight arrays for the architecture built below.

        Returns:
            tf.keras.Model | None: Compiled model with weights set, or None
            if anything went wrong (the error is shown in the UI).
        """

        def rebuild_model():
            """Rebuild the original model architecture"""
            model = models.Sequential([
                layers.Conv2D(32, (3, 3), activation='relu', input_shape=(224, 224, 1)),
                layers.BatchNormalization(),
                layers.MaxPooling2D((2, 2)),
                layers.Conv2D(64, (3, 3), activation='relu'),
                layers.BatchNormalization(),
                layers.MaxPooling2D((2, 2)),
                layers.Conv2D(128, (3, 3), activation='relu'),
                layers.BatchNormalization(),
                layers.MaxPooling2D((2, 2)),
                layers.Conv2D(256, (3, 3), activation='relu'),
                layers.BatchNormalization(),
                layers.MaxPooling2D((2, 2)),
                layers.Flatten(),
                layers.Dropout(0.5),
                layers.Dense(512, activation='relu'),
                layers.BatchNormalization(),
                layers.Dropout(0.3),
                layers.Dense(3, activation='softmax')
            ])
            model.compile(
                optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy']
            )
            return model

        # Broad catch is deliberate: this is the UI boundary, and any failure
        # (missing file, corrupt pickle, shape mismatch) is reported to the user.
        try:
            # Rebuild the model architecture
            model = rebuild_model()

            # SECURITY: pickle.load can execute arbitrary code; only load a
            # weights file from a trusted source, never a user upload.
            with open(weights_path, 'rb') as f:
                weights_list = pickle.load(f)

            # Set the weights to the model
            model.set_weights(weights_list)

            return model
        except Exception as e:
            st.error(f"Error loading model weights: {e}")
            return None

    def preprocess_image(self, image):
        """
        Preprocess the uploaded image for prediction.

        Args:
            image (PIL.Image): Input image

        Returns:
            numpy.ndarray: float32 array of shape (1, 224, 224, 1) scaled to [0, 1]
        """
        # Resize, then convert to single-channel grayscale
        img = image.resize((224, 224)).convert("L")
        img_array = np.array(img).astype("float32") / 255.0
        img_array = np.expand_dims(img_array, axis=-1)  # channel axis
        img_array = np.expand_dims(img_array, axis=0)   # batch axis
        return img_array

    def grad_cam(self, img_array, layer_name=None):
        """
        Generate Grad-CAM visualization.

        Args:
            img_array (numpy.ndarray): Preprocessed input image, shape
                (1, H, W, 1) with values in [0, 1].
            layer_name (str, optional): Specific layer for Grad-CAM.
                Defaults to last convolutional layer.

        Returns:
            tuple: (heatmap, superimposed_img), both uint8 RGB arrays of
            shape (H, W, 3).
        """
        # If no layer specified, find the last convolutional layer
        if layer_name is None:
            layer_name = [layer.name for layer in self.model.layers
                          if isinstance(layer, layers.Conv2D)][-1]

        # Create model that outputs the chosen conv layer and the predictions
        grad_model = Model(
            inputs=self.model.inputs,
            outputs=[self.model.get_layer(layer_name).output, self.model.output]
        )

        # Compute gradients of the top class score w.r.t. the conv activations
        with tf.GradientTape() as tape:
            conv_outputs, predictions = grad_model(img_array, training=False)
            predicted_class = tf.argmax(predictions[0])
            loss = predictions[0][predicted_class]

        gradients = tape.gradient(loss, conv_outputs)

        # Per-channel importance: global average pooling of the gradients
        pooled_gradients = tf.reduce_mean(gradients, axis=(0, 1, 2))

        # Weighted combination of the conv layer outputs
        conv_outputs = conv_outputs[0]
        heatmap = tf.reduce_mean(
            tf.multiply(pooled_gradients, conv_outputs), axis=-1
        ).numpy()

        # Keep only positive evidence, then scale to [0, 1]. If nothing is
        # positive the max is 0; skip the division to avoid NaNs (which would
        # turn into garbage after the uint8 cast below).
        heatmap = np.maximum(heatmap, 0)
        max_value = np.max(heatmap)
        if max_value > 0:
            heatmap /= max_value

        # Resize heatmap to the input image size (cv2 takes (width, height))
        height, width = img_array.shape[1:3]
        heatmap = cv2.resize(heatmap, (width, height))
        heatmap = np.uint8(255 * heatmap)
        # applyColorMap returns BGR; convert so it matches the RGB original
        # below and is displayed correctly by st.image (RGB by default).
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

        # Convert original image for overlay
        original_img = (img_array[0, :, :, 0] * 255).astype("uint8")
        original_img = cv2.cvtColor(original_img, cv2.COLOR_GRAY2RGB)

        # Overlay heatmap on original image
        superimposed_img = cv2.addWeighted(original_img, 0.6, heatmap, 0.4, 0)

        return heatmap, superimposed_img

    def predict_and_explain(self, img_array):
        """
        Predict tumor type and generate explanation.

        Args:
            img_array (numpy.ndarray): Preprocessed input image

        Returns:
            tuple: Prediction, probabilities, heatmap, superimposed image
        """
        # Predict (verbose=0 keeps the Keras progress bar out of server logs)
        prediction = self.model.predict(img_array, verbose=0)
        predicted_class_index = np.argmax(prediction)
        predicted_class = self.tumor_types[predicted_class_index]

        # Generate Grad-CAM visualization
        heatmap, superimposed_img = self.grad_cam(img_array)

        return predicted_class, prediction[0], heatmap, superimposed_img


@st.cache_resource
def _load_explainer(model_path):
    """Build the explainer once per process instead of on every Streamlit rerun."""
    return BrainTumorExplainer(model_path)


def main():
    """Streamlit entry point: upload an MRI, classify it, and show Grad-CAM."""
    st.title("🧠 Brain Tumor Classification with Explainable AI")

    # Initialize the explainer (cached across reruns)
    explainer = _load_explainer("model.pkl")
    if explainer.model is None:
        # The load error was already shown. Drop the cached failure so a
        # fixed weights file is picked up on the next rerun, and stop here
        # rather than crashing later on "Analyze".
        _load_explainer.clear()
        st.stop()

    # File uploader
    uploaded_file = st.file_uploader(
        "Upload Brain MRI Image",
        type=["jpg", "jpeg", "png"]
    )

    if uploaded_file is not None:
        from PIL import Image

        # Read the image; decode eagerly so corrupt uploads fail here.
        # (PIL's UnidentifiedImageError is a subclass of OSError.)
        try:
            image = Image.open(uploaded_file)
            image.load()
        except OSError as err:
            st.error(f"Could not read the uploaded image: {err}")
            return

        # Display original image
        st.subheader("Original Image")
        st.image(image, use_container_width=True)

        # Preprocess the image
        img_array = explainer.preprocess_image(image)

        # Predict and explain
        if st.button("Analyze and Explain"):
            # Get prediction and explanation
            predicted_class, probabilities, heatmap, superimposed_img = \
                explainer.predict_and_explain(img_array)

            # Display prediction results
            st.subheader("Prediction Results")
            st.write(f"**Detected Tumor Type:** {predicted_class}")

            # Show prediction probabilities
            st.write("Prediction Probabilities:")
            for tumor, prob in zip(explainer.tumor_types, probabilities):
                st.write(f"{tumor}: {prob:.2%}")

            # Display Grad-CAM visualizations
            col1, col2 = st.columns(2)

            with col1:
                st.subheader("Grad-CAM Heatmap")
                st.image(heatmap, use_container_width=True,
                         caption="Areas of model's focus (red = high importance)")

            with col2:
                st.subheader("Heatmap Overlay")
                st.image(superimposed_img, use_container_width=True,
                         caption="Heatmap superimposed on original image")

            # Explanation of the visualization
            st.info(
                "**Interpretation:**\n"
                "- The heatmap shows which regions of the image "
                "the model considers most important for its classification.\n"
                "- Warmer colors (red, yellow) indicate higher importance.\n"
                "- This helps understand how the AI makes its decision."
            )


if __name__ == "__main__":
    main()