import os
import re

import gradio as gr
import numpy as np
import tensorflow as tf
# Saved Keras model file. The path is relative, so it is resolved against the
# process's current working directory.
_MODEL_PATH = "alzimers_model.h5"
# Loaded once at import time; the same instance serves every prediction.
model = tf.keras.models.load_model(_MODEL_PATH)
# Output labels, indexed by the argmax position of the model's prediction.
# Presumably this is the alphabetical class-folder order used at training
# time — TODO confirm against the training pipeline.
class_names = [
    'MildDemented',
    'ModerateDemented',
    'NonDemented',
    'VeryMildDemented',
]

# Hex display colour per class, keyed by the same label strings and in the
# same order as class_names.
class_colors = dict(
    zip(class_names, ['#FFD700', '#FF8C00', '#32CD32', '#1E90FF'])
)
# Zero-width split points in CamelCase labels: between a lowercase letter and
# the uppercase letter that starts the next word.
_CAMEL_BOUNDARY = re.compile(r"(?<=[a-z])(?=[A-Z])")


def _display_name(label):
    """Return a CamelCase class label with its words separated by spaces.

    >>> _display_name("VeryMildDemented")
    'Very Mild Demented'
    """
    return _CAMEL_BOUNDARY.sub(" ", label)


def stage_badges(names=None):
    """Return every stage label as one space-separated, human-readable line.

    Each label is converted to words (e.g. "VeryMildDemented" ->
    "Very Mild Demented") and followed by a single space, so a non-empty
    result ends with a trailing space.

    Args:
        names: Labels to render, in display order. Defaults to the
            module-level ``class_names``.

    Returns:
        The concatenated labels, or ``""`` when ``names`` is empty.
    """
    if names is None:
        names = class_names
    return "".join(f"{_display_name(name)} " for name in names)


def predict_image(img):
    """Classify one uploaded MRI image and return the result as Markdown.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when nothing has been uploaded.

    Returns:
        A Markdown string with the predicted stage and the highest model
        output (a probability if the final layer is softmax — TODO confirm),
        or a prompt to upload an image when ``img`` is ``None``.
    """
    if img is None:
        return "Please upload an image first."
    # Scale pixels to [0, 1] and add a batch axis of size 1. Assumes the
    # network was trained on 64x64 inputs scaled the same way — TODO confirm.
    img_array = np.array(img.resize((64, 64))) / 255.0
    img_array = np.expand_dims(img_array, axis=0)
    # verbose=0: otherwise Keras prints a progress bar on every request.
    # [0] selects the single row of the batch of one.
    scores = model.predict(img_array, verbose=0)[0]
    pred_idx = int(np.argmax(scores))
    confidence = float(scores[pred_idx])
    # The blank line makes Markdown render label and score as separate
    # paragraphs instead of joining them onto one line.
    return f"{_display_name(class_names[pred_idx])}\n\nConfidence: {confidence:.2f}"
# Markdown shown above the upload widget. Markdown merges consecutive
# non-blank lines into a single paragraph, so sections are separated by
# blank lines and the title is an explicit heading.
description_md = f"""
# Alzheimer's Detection AI Model

Deep learning model for Alzheimer's stage detection from MRI scans.
Upload an MRI image to predict the stage of Alzheimer's disease.

**Stages:** {stage_badges()}

**Instructions:**

- Click 'Upload' or drag an MRI image.
- Click 'Submit' to see prediction.
- The model analyzes brain MRI scans to detect Alzheimer's progression.
"""

# Footer: a horizontal rule followed by a credit line.
footer_md = """
---
Powered by TensorFlow & Gradio
"""
# Page layout, top to bottom: description, MRI upload, Submit button,
# prediction result, footer. Gradio renders components in creation order.
# Clicking Submit runs predict_image on the uploaded image and writes the
# returned Markdown into the result block.
with gr.Blocks(theme=gr.themes.Monochrome(primary_hue="blue", secondary_hue="purple")) as demo:
    gr.Markdown(description_md)
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload MRI Image", elem_id="centered_image", show_label=True)
    # Created right after the upload widget so it renders there. Creating it
    # after the footer would put the credit line above the button.
    submit_btn = gr.Button("Submit", elem_id="submit_btn", variant="primary")
    gr.Markdown("## Prediction Result")
    output = gr.Markdown(label="", elem_id="prediction_block")
    gr.Markdown(footer_md)
    submit_btn.click(fn=predict_image, inputs=image_input, outputs=output)
def _main():
    """Start the Gradio server for the demo UI."""
    demo.launch()


if __name__ == "__main__":
    _main()