# mediaportal's picture
# Update app.py
# 66d9598 verified
import os
# Forces the use of Keras 2 logic, preventing common 'recursion' errors in new environments
os.environ["TF_USE_LEGACY_KERAS"] = "1"
import gradio as gr
import tensorflow as tf
import tf_keras as keras
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
# --- CONFIGURATION ---
# Hub repository that hosts the trained weights:
# https://huggingface.co/mediaportal/Braintumor-MRI-detection
REPO_ID = "mediaportal/Braintumor-MRI-detection"
# Name of the weights file inside that repository.
MODEL_FILENAME = "BraintumorMRI99.h5"

# Access token for a PRIVATE repo; set it in the Space Secrets as 'HF_TOKEN'.
# Stays None when the variable is not defined.
hf_token = os.environ.get("HF_TOKEN")

# Populated by load_model_with_progress(); None until the model is loaded.
model = None
def load_model_with_progress(progress=gr.Progress(track_tqdm=True)):
    """Download the model file from the Hub and load it into the global ``model``.

    This function is registered with ``demo.load``, which Gradio triggers each
    time the page is opened in a browser. To avoid re-reading the weights for
    every visitor, it returns immediately when ``model`` is already set.

    Args:
        progress: Gradio progress tracker used to report the download and
            load stages to the UI.

    Returns:
        A status string for the UI: "Model Loaded Successfully." on success,
        or "❌ Error: <message>" when the download or load raised.
    """
    global model

    # A previous page load already populated the model; nothing to do.
    if model is not None:
        return "Model Loaded Successfully."

    try:
        progress(0, desc="Downloading model from Hugging Face...")
        path = hf_hub_download(
            repo_id=REPO_ID,
            filename=MODEL_FILENAME,
            token=hf_token,
        )
        progress(0.7, desc="Loading weights into Xception architecture...")
        # compile=False avoids loading the optimizer state from the training session
        model = keras.models.load_model(path, compile=False)
        progress(1.0, desc="✅ Model Ready!")
        return "Model Loaded Successfully."
    except Exception as e:
        # UI boundary: surface the failure in the status box instead of
        # letting the event handler crash.
        return f"❌ Error: {e}"
def predict(img):
    """Classify an uploaded MRI image with the globally loaded model.

    Args:
        img: The image as a NumPy array (as delivered by the ``gr.Image``
            input), or None when nothing was uploaded.

    Returns:
        A dict mapping each class label to its score as a float, or a plain
        message string when the model is not loaded yet or no image was given.
    """
    if model is None:
        return "System is still initializing. Please wait."
    if img is None:
        return "No image provided."

    # Let Pillow infer the mode from the array shape, then normalise to RGB so
    # that 2-D grayscale and 4-channel RGBA arrays are handled as well as RGB.
    # Target size 299x299 follows the Xception preprocessing in the notebook.
    pil_img = Image.fromarray(img.astype("uint8")).convert("RGB").resize((299, 299))

    # Rescale 1/255 as used in the notebook's ImageDataGenerator
    batch = np.asarray(pil_img, dtype="float32") / 255.0
    batch = np.expand_dims(batch, axis=0)  # add the leading batch axis

    scores = model.predict(batch)[0]

    # Class labels identified from tr_df in the notebook.
    # NOTE(review): order must match the training class indices
    # (glioma, meningioma, notumor, pituitary) — confirm against the notebook.
    labels = ["Glioma", "Meningioma", "No Tumor", "Pituitary"]
    return {label: float(score) for label, score in zip(labels, scores)}
# --- GRADIO INTERFACE ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Header, short description, and a status line that the model loader
    # overwrites with its result.
    gr.Markdown("# 🧠 Brain Tumor MRI Classification")
    gr.Markdown("Identify Glioma, Meningioma, Pituitary tumors, or Healthy scans.")
    status_box = gr.Markdown(f"⏳ Initializing... Checking access to {REPO_ID}")

    # Two columns: upload + trigger button on one side, prediction on the other.
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(label="Upload MRI Scan")
            btn = gr.Button("Run Diagnosis", variant="primary")
        with gr.Column():
            output_label = gr.Label(num_top_classes=4, label="Prediction Result")

    # Kick off the model load when the app is opened; its status string is
    # written into status_box.
    demo.load(fn=load_model_with_progress, outputs=status_box)
    # Run the classifier on the uploaded image when the button is pressed.
    btn.click(predict, inputs=input_img, outputs=output_label)
if __name__ == "__main__":
    # Enable request queuing, then start the Gradio server.
    queued_app = demo.queue()
    queued_app.launch()