Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import tensorflow as tf | |
| import numpy as np | |
| from PIL import Image | |
| import plotly.graph_objects as go | |
| from plotly.subplots import make_subplots | |
# Human-readable labels for the classifier outputs. Callers index this list
# with the argmax of a model's output row, so position i names output unit i.
# NOTE(review): presumably this must match the label order used in training
# — confirm against the training pipeline.
CLASS_NAMES = [
    'Glioma',
    'Meningioma',
    'No Tumor',
    'Pituitary',
]
# Input resolution shared by every configured network.
_DEFAULT_IMG_SIZE = (224, 224)

# Registry of selectable classifiers, keyed by the name shown in the UI
# (insertion order is the order of the model dropdown). Each entry holds:
#   path          - Keras model file given to tf.keras.models.load_model
#   preprocessing - the tf.keras.applications preprocess_input to apply
#   img_size      - (width, height) passed to PIL's Image.resize
#   description   - short blurb displayed for the selected model
MODELS = {
    "MobileNetV2": dict(
        path="best_mobilenetv2.keras",
        preprocessing=tf.keras.applications.mobilenet_v2.preprocess_input,
        img_size=_DEFAULT_IMG_SIZE,
        description="MobileNetV2 - Lightweight and efficient",
    ),
    "DenseNet121": dict(
        path="best_densenet121.keras",
        preprocessing=tf.keras.applications.densenet.preprocess_input,
        img_size=_DEFAULT_IMG_SIZE,
        description="DenseNet121 - Dense connections for better gradient flow",
    ),
    "EfficientNetV2S": dict(
        path="best_efficientnetv2s.keras",
        preprocessing=tf.keras.applications.efficientnet_v2.preprocess_input,
        img_size=_DEFAULT_IMG_SIZE,
        description="EfficientNetV2S - State-of-the-art efficiency",
    ),
}
# Cache of Keras models that load_model() has loaded successfully, keyed by
# MODELS name. It starts empty and is filled on demand (including by the
# import-time preload).
loaded_models = dict()
def load_model(model_name):
    """Return the Keras model registered under ``model_name``, loading it once.

    A successful load is stored in ``loaded_models`` and reused afterwards.
    On any error (including an unknown ``model_name``) the error is printed
    and ``None`` is returned. Failures are not cached, so a later call tries
    again.
    """
    # Fast path: already in the cache.
    if model_name in loaded_models:
        return loaded_models[model_name]

    # NOTE(review): the leading characters of both messages look mis-encoded
    # (probably emoji originally); confirm the intended glyphs.
    try:
        path = MODELS[model_name]["path"]
        loaded_models[model_name] = tf.keras.models.load_model(path)
        print(f"β Loaded {model_name}")
    except Exception as err:
        print(f"β Failed to load {model_name}: {str(err)}")
        return None
    return loaded_models[model_name]
def _preload_all_models():
    """Fill the ``loaded_models`` cache by loading every model in ``MODELS``.

    A model that fails to load is reported by ``load_model`` (which prints
    the error and returns ``None``) and is simply skipped here.
    """
    for name in MODELS:
        load_model(name)


# Preload all models at import time so the first request does not pay the
# loading cost. The loop lives in a function so it does not leak a
# module-level ``model_name`` that would be shadowed by every function
# parameter of the same name.
_preload_all_models()
def preprocess_image(image, model_name):
    """Convert a PIL image into a one-image batch ready for ``model_name``.

    The image is forced to 3-channel RGB, resized to the model's
    ``img_size``, given a leading batch axis, and passed through the
    model-specific ``preprocessing`` function from ``MODELS``.

    Args:
        image: A ``PIL.Image.Image`` in any colour mode.
        model_name: Key into ``MODELS``.

    Returns:
        A float array of shape ``(1, height, width, 3)`` after the model's
        preprocessing function has been applied.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODELS``.
    """
    config = MODELS[model_name]
    img_size = config["img_size"]
    preprocessing_fn = config["preprocessing"]

    # Normalise the colour mode *before* resizing. Image.convert("RGB")
    # correctly handles L, LA, P, 1, CMYK, RGBA, etc. Inspecting the array
    # shape afterwards mis-handled palette images (indices read as grey
    # levels), two-channel "LA" images (2 channels reached the model), CMYK
    # (first three channels are not RGB), and "1" images (0/1 instead of
    # 0/255). For L, RGB and RGBA the result is unchanged.
    rgb = image.convert("RGB").resize(img_size)

    # Cast explicitly so the rescaling preprocess functions always see floats.
    img_array = np.asarray(rgb, dtype=np.float32)

    # Add the batch dimension expected by model.predict.
    img_array = np.expand_dims(img_array, axis=0)

    return preprocessing_fn(img_array)
def create_prediction_plot(predictions, class_names):
    """Build a horizontal bar chart of per-class confidence.

    Args:
        predictions: One score per class, in the same order as
            ``class_names`` (callers pass a model's output row).
        class_names: Bar labels.

    Returns:
        A ``plotly.graph_objects.Figure`` with the x-axis fixed to [0, 1].
    """
    fig = go.Figure(data=[
        go.Bar(
            x=predictions,
            y=class_names,
            orientation='h',
            marker=dict(
                color=predictions,
                colorscale='RdYlGn',
                # Pin the colour scale to the same [0, 1] range as the x-axis.
                # Without cmin/cmax Plotly stretches the colours over the min/max
                # of the current values, so even a weak top prediction would be
                # painted full green and the "Confidence" colourbar would lie.
                cmin=0,
                cmax=1,
                showscale=True,
                colorbar=dict(title="Confidence")
            ),
            text=[f'{p:.2%}' for p in predictions],
            textposition='auto',
        )
    ])
    fig.update_layout(
        title="Prediction Confidence Distribution",
        xaxis_title="Confidence Score",
        yaxis_title="Tumor Type",
        height=400,
        xaxis=dict(range=[0, 1]),
        template="plotly_white"
    )
    return fig
def create_comparison_plot(all_predictions, model_names):
    """Build a grouped bar chart comparing class confidences across models.

    Args:
        all_predictions: One score sequence per model, each indexed in
            ``CLASS_NAMES`` order.
        model_names: Model labels for the x-axis, parallel to
            ``all_predictions``.

    Returns:
        A ``plotly.graph_objects.Figure`` with one bar trace per class.
    """
    fig = go.Figure()
    colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#FFA07A']
    for i, class_name in enumerate(CLASS_NAMES):
        fig.add_trace(go.Bar(
            name=class_name,
            x=model_names,
            y=[pred[i] for pred in all_predictions],
            # Cycle the palette so a longer CLASS_NAMES cannot raise IndexError.
            marker_color=colors[i % len(colors)]
        ))
    fig.update_layout(
        title="Model Comparison - Prediction Confidence",
        xaxis_title="Model",
        yaxis_title="Confidence Score",
        barmode='group',
        height=450,
        template="plotly_white",
        # Fixed [0, 1] scale, matching the single-model confidence plot, so
        # bars are comparable between requests.
        yaxis=dict(range=[0, 1]),
        legend=dict(
            title="Tumor Type",
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )
    return fig
def predict(image, model_name):
    """Classify one MRI image with the selected model.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when nothing has been uploaded.
        model_name: Key into ``MODELS`` selecting the network to run.

    Returns:
        ``(markdown_text, plotly_figure, probabilities)``. When there is no
        image, the model cannot be loaded, or any step raises, the text is
        an error message and the other two items are ``None``.
    """
    if image is None:
        return "Please upload an image", None, None
    try:
        model = load_model(model_name)
        if model is None:
            return f"Error: Could not load {model_name}", None, None

        processed_img = preprocess_image(image, model_name)
        # predict() returns a batch; take the row for our single image.
        predictions = model.predict(processed_img, verbose=0)[0]

        predicted_idx = int(np.argmax(predictions))
        predicted_class = CLASS_NAMES[predicted_idx]
        confidence = predictions[predicted_idx]

        # Markdown only starts a new line/paragraph after a blank line, so
        # every entry is separated by "\n\n" (the same convention
        # compare_models uses). A single "\n" collapsed all class
        # probabilities into one run-on paragraph.
        # NOTE(review): the emoji literals look mis-encoded; confirm glyphs.
        lines = [
            f"### π¬ Diagnosis Result ({model_name})",
            f"**Predicted Class:** {predicted_class}",
            f"**Confidence:** {confidence:.2%}",
            "#### All Class Probabilities:",
        ]
        for i, (class_name, prob) in enumerate(zip(CLASS_NAMES, predictions)):
            emoji = "π―" if i == predicted_idx else "π"
            lines.append(f"{emoji} **{class_name}:** {prob:.2%}")
        result_text = "\n\n".join(lines)

        plot = create_prediction_plot(predictions, CLASS_NAMES)
        return result_text, plot, predictions
    except Exception as e:
        return f"Error during prediction: {str(e)}", None, None
def compare_models(image):
    """Run every model in ``MODELS`` on one image and summarise the results.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``.

    Returns:
        ``(markdown_text, plotly_figure)``. Models that fail to load are
        listed in the text rather than silently omitted. If no model could
        be loaded, or any step raises, the text is an error message and the
        figure is ``None``.
    """
    if image is None:
        return "Please upload an image", None
    try:
        all_predictions = []
        model_names = []
        result_text = "### π Model Comparison Results\n\n"
        for model_name in MODELS:
            model = load_model(model_name)
            if model is None:
                # Make the skip visible instead of dropping the model silently.
                result_text += f"**{model_name}:** could not be loaded\n\n"
                continue
            processed_img = preprocess_image(image, model_name)
            predictions = model.predict(processed_img, verbose=0)[0]
            all_predictions.append(predictions)
            model_names.append(model_name)
            predicted_idx = int(np.argmax(predictions))
            predicted_class = CLASS_NAMES[predicted_idx]
            confidence = predictions[predicted_idx]
            result_text += f"**{model_name}:** {predicted_class} ({confidence:.2%})\n\n"
        if not all_predictions:
            return "Error: No models could be loaded", None
        plot = create_comparison_plot(all_predictions, model_names)
        return result_text, plot
    except Exception as e:
        return f"Error during comparison: {str(e)}", None
# Stylesheet passed to gr.Blocks(css=...). It sets the app font, gives the
# primary and secondary buttons gradient backgrounds, and styles Markdown
# output panels as cards.
# NOTE(review): selectors such as .gr-button-primary and .output-markdown look
# like Gradio 3.x class names; confirm they match the installed Gradio version,
# otherwise these rules have no effect.
_CSS_RULES = (
    ".gradio-container {",
    "font-family: 'Arial', sans-serif;",
    "}",
    ".gr-button-primary {",
    "background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;",
    "border: none !important;",
    "}",
    ".gr-button-secondary {",
    "background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%) !important;",
    "border: none !important;",
    "}",
    ".output-markdown {",
    "background-color: #f8f9fa;",
    "padding: 20px;",
    "border-radius: 10px;",
    "box-shadow: 0 2px 4px rgba(0,0,0,0.1);",
    "}",
)
# Leading/trailing newlines reproduce the layout of a triple-quoted block.
custom_css = "\n" + "\n".join(_CSS_RULES) + "\n"
# Create Gradio interface.
# Inside gr.Blocks the order in which components are created determines the
# on-screen layout, so the statements below are order-sensitive.
# NOTE(review): emoji in the labels and Markdown below look mis-encoded
# (e.g. "π§", "βοΈ"); confirm the intended glyphs.
with gr.Blocks(css=custom_css, title="Brain Tumor Classification") as app:
    # Page header shown above the tabs.
    gr.Markdown(
        """
# π§ Brain Tumor MRI Classification System
Upload an MRI scan to classify brain tumors using state-of-the-art deep learning models.
**Tumor Types:** Glioma, Meningioma, No Tumor, Pituitary
"""
    )
    with gr.Tabs():
        # Single Model Prediction Tab: one image, one model picked from MODELS.
        with gr.TabItem("π Single Model Prediction"):
            with gr.Row():
                with gr.Column(scale=1):
                    input_image = gr.Image(type="pil", label="Upload MRI Scan")
                    # NOTE(review): the default "MobileNetV2" is hard-coded here
                    # and again for model_info below; both must remain a key of
                    # MODELS.
                    model_dropdown = gr.Dropdown(
                        choices=list(MODELS.keys()),
                        value="MobileNetV2",
                        label="Select Model",
                        info="Choose which model to use for prediction"
                    )
                    # Model description; refreshed whenever the dropdown changes.
                    model_info = gr.Markdown(MODELS["MobileNetV2"]["description"])
                    def update_model_info(model_name):
                        """Return the MODELS description for the selected model."""
                        return MODELS[model_name]["description"]
                    model_dropdown.change(
                        fn=update_model_info,
                        inputs=model_dropdown,
                        outputs=model_info
                    )
                    predict_btn = gr.Button("π¬ Analyze", variant="primary")
                with gr.Column(scale=2):
                    output_text = gr.Markdown(label="Prediction Results")
                    output_plot = gr.Plot(label="Confidence Distribution")
            # predict() returns (markdown, figure, probabilities). The anonymous
            # gr.State() receives the probabilities; nothing else in this UI
            # reads it, but it keeps the output count equal to predict()'s
            # return arity.
            predict_btn.click(
                fn=predict,
                inputs=[input_image, model_dropdown],
                outputs=[output_text, output_plot, gr.State()]
            )
        # Model Comparison Tab: runs every model in MODELS on the same image.
        with gr.TabItem("π Compare All Models"):
            with gr.Row():
                with gr.Column(scale=1):
                    compare_image = gr.Image(type="pil", label="Upload MRI Scan")
                    compare_btn = gr.Button("βοΈ Compare Models", variant="secondary")
                    # NOTE(review): this list is written by hand, not generated
                    # from MODELS; keep the two in sync.
                    gr.Markdown(
                        """
### Available Models:
- **MobileNetV2**: Fast and efficient
- **DenseNet121**: Deep dense connections
- **EfficientNetV2S**: Latest efficiency improvements
"""
                    )
                with gr.Column(scale=2):
                    compare_text = gr.Markdown(label="Comparison Results")
                    compare_plot = gr.Plot(label="Model Comparison Visualization")
            # compare_models() returns (markdown, figure).
            compare_btn.click(
                fn=compare_models,
                inputs=compare_image,
                outputs=[compare_text, compare_plot]
            )
        # Information Tab: static description of the classes and models.
        with gr.TabItem("βΉοΈ About"):
            # NOTE(review): the ">99% test accuracy" claim below cannot be
            # verified from this file; keep it in line with actual evaluation
            # results.
            gr.Markdown(
                """
## About This Application
This application uses deep learning models trained on brain MRI scans to classify different types of brain tumors.
### Tumor Types:
1. **Glioma**: A tumor that occurs in the brain and spinal cord
2. **Meningioma**: A tumor that forms on membranes covering the brain and spinal cord
3. **Pituitary**: A tumor in the pituitary gland
4. **No Tumor**: Healthy brain tissue
### Models:
- **MobileNetV2**: Lightweight architecture ideal for mobile deployment
- **DenseNet121**: Dense connections improve feature propagation
- **EfficientNetV2S**: Optimized for both accuracy and efficiency
### Image Requirements:
- Format: PNG, JPG, JPEG
- The models automatically resize images to 224x224 pixels
- Grayscale images are automatically converted to RGB
### Performance:
All models achieve >99% test accuracy on the brain tumor dataset.
---
**Note**: This is a demonstration system and should not be used for actual medical diagnosis.
Always consult with qualified healthcare professionals for medical advice.
"""
            )
def _launch():
    """Start the Gradio server for ``app`` with a public share link."""
    # NOTE(review): share=True asks Gradio for a public tunnel URL when run
    # locally. Confirm that exposure is intended; Gradio's docs indicate share
    # links are not supported when hosted on Hugging Face Spaces.
    app.launch(share=True)


# Launch only when executed as a script, not when imported.
if __name__ == "__main__":
    _launch()