"""Gradio demo that classifies brain MRI scans with a trained Keras model."""

import logging

import cv2
import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image

logger = logging.getLogger(__name__)

# Load your trained model
model = tf.keras.models.load_model('brain_tumor_model.h5')  # or .keras

# Define class labels (adjust based on your model's classes).
# The order must match the order of the model's output units.
CLASS_LABELS = [
    'glioma_tumor',
    'meningioma_tumor',
    'no_tumor',
    'pituitary_tumor'
]

# Spatial input size used for the Xception model, as (width, height).
INPUT_SIZE = (299, 299)

DESCRIPTION = """
Upload an MRI brain scan image to classify tumor types.

**Model:** Sequential Xception Architecture
**Accuracy:** 99% (on validation set)

**Classes:**
- Glioma Tumor
- Meningioma Tumor
- No Tumor
- Pituitary Tumor

⚠️ **Disclaimer:** For research/educational purposes only.
Not for medical diagnosis.
"""


def _to_rgb_array(image):
    """Return *image* as an (H, W, 3) array.

    PIL images are converted with ``Image.convert("RGB")`` so that palette,
    1-bit, grayscale and RGBA inputs all end up with three channels.
    Array-like inputs are accepted for backward compatibility.

    Raises:
        ValueError: If an array input has an unsupported shape.
    """
    if isinstance(image, Image.Image):
        return np.array(image.convert("RGB"))

    img_array = np.asarray(image)
    # Collapse a trailing singleton channel so it is treated as grayscale.
    if img_array.ndim == 3 and img_array.shape[2] == 1:
        img_array = img_array[:, :, 0]

    if img_array.ndim == 2:
        return cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB)
    if img_array.ndim == 3 and img_array.shape[2] == 4:
        # Drop the alpha channel; the model expects exactly three channels.
        return cv2.cvtColor(img_array, cv2.COLOR_RGBA2RGB)
    if img_array.ndim == 3 and img_array.shape[2] == 3:
        return img_array
    raise ValueError(f"Unsupported image shape: {img_array.shape}")


def _to_probabilities(scores):
    """Turn one row of model output into a probability vector.

    If *scores* already looks like a probability distribution (non-negative
    and summing to ~1) it is returned unchanged; applying softmax a second
    time would flatten the distribution and understate the confidence.
    Otherwise the values are treated as logits and passed through a
    numerically stable softmax.

    Raises:
        ValueError: If *scores* is empty.
    """
    scores = np.asarray(scores, dtype="float64").ravel()
    if scores.size == 0:
        raise ValueError("Model returned an empty prediction")

    already_normalized = bool(
        np.all(scores >= 0.0) and np.isclose(scores.sum(), 1.0, atol=1e-3)
    )
    if already_normalized:
        return scores

    # Subtract the max before exponentiating to avoid overflow.
    exp_scores = np.exp(scores - scores.max())
    return exp_scores / exp_scores.sum()


def preprocess_image(image):
    """Preprocess image for Xception model.

    Args:
        image: A PIL image (what the Gradio input provides) or an array-like
            image in grayscale, RGB or RGBA layout.

    Returns:
        A float32 array of shape (1, 299, 299, 3) scaled to [0, 1] for
        8-bit input.
    """
    img_array = _to_rgb_array(image)

    # Resize to 299x299 (Xception input size)
    img_resized = cv2.resize(img_array, INPUT_SIZE)

    # Normalize pixel values to [0, 1]
    img_normalized = img_resized.astype('float32') / 255.0

    # Add batch dimension
    return np.expand_dims(img_normalized, axis=0)


def predict(image):
    """Make prediction on uploaded image.

    Args:
        image: The uploaded image, or ``None`` when nothing was uploaded.

    Returns:
        A dict mapping each human-readable class name to its probability,
        or a message string when no image was given or processing failed.
    """
    if image is None:
        return "Please upload an image"

    try:
        processed_image = preprocess_image(image)

        # verbose=0 keeps the per-request progress bar out of the server log.
        predictions = model.predict(processed_image, verbose=0)

        probabilities = _to_probabilities(predictions[0])
        if len(probabilities) != len(CLASS_LABELS):
            raise ValueError(
                f"Model returned {len(probabilities)} outputs but "
                f"{len(CLASS_LABELS)} class labels are defined"
            )

        return {
            label.replace('_', ' ').title(): float(probability)
            for label, probability in zip(CLASS_LABELS, probabilities)
        }
    except Exception as e:
        # UI boundary: keep the app responsive, but record the traceback.
        logger.exception("Prediction failed")
        return f"Error processing image: {str(e)}"


# Create Gradio interface
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Brain MRI Scan"),
    outputs=gr.Label(num_top_classes=len(CLASS_LABELS), label="Prediction"),
    title="🧠 Brain Tumor Classification - Xception Model",
    description=DESCRIPTION,
    examples=[
        # Add example images if you have them
    ],
    theme=gr.themes.Soft(),
    analytics_enabled=False
)

if __name__ == "__main__":
    demo.launch()