# Page header captured together with the file (not program code):
#   Vedag812's picture
#   Update app.py
#   8ea9e10 verified
# app.py
import os, glob, traceback
import numpy as np
from PIL import Image
import gradio as gr
# Use Keras 3 with the TensorFlow backend
os.environ.setdefault("KERAS_BACKEND", "tensorflow")
import keras # Keras 3.x
from huggingface_hub import hf_hub_download
# Hugging Face Hub repository from which load_model() downloads "xray_cnn.keras".
HF_MODEL_ID = "Vedag812/xray_cnn"
# Output labels in index order: predict_fn reports index 1 (PNEUMONIA) when the
# model score is above 0.5 and index 0 (NORMAL) otherwise.
CLASS_NAMES = ["NORMAL", "PNEUMONIA"]
# Holds the deserialized model after the first successful load_model() call so
# that later predictions reuse it instead of reloading it from disk.
_MODEL_CACHE = None


def load_model():
    """Return the X-ray classifier, loading it at most once per process.

    On the first call the file ``xray_cnn.keras`` is fetched from the
    ``HF_MODEL_ID`` repository with ``hf_hub_download`` and deserialized with
    Keras. The resulting model object is kept in ``_MODEL_CACHE`` and returned
    directly by every later call.

    Returns:
        The loaded (uncompiled) Keras model.

    Raises:
        Whatever ``hf_hub_download`` or ``keras.saving.load_model`` raise; a
        failed load leaves the cache empty, so the next call retries.
    """
    global _MODEL_CACHE
    if _MODEL_CACHE is None:
        model_path = hf_hub_download(HF_MODEL_ID, filename="xray_cnn.keras")
        # safe_mode=False allows loading models saved with older or custom
        # configs.
        # NOTE(security): with safe_mode=False Keras will deserialize unsafe
        # constructs (e.g. lambdas), which can execute arbitrary code. Only
        # keep this setting while HF_MODEL_ID points at a trusted repository.
        _MODEL_CACHE = keras.saving.load_model(
            model_path, compile=False, safe_mode=False
        )
    return _MODEL_CACHE
def _infer_input_shape(model, *, default=(150, 150, 1)):
    """Return the ``(H, W, C)`` input size the model expects.

    The shape is read from ``model.inputs[0].shape``; if that lookup fails for
    any reason, ``model.input_shape`` is tried instead. Index 0 of the shape is
    skipped and indices 1, 2 and 3 are taken as height, width and channels.

    Args:
        model: Object exposing ``inputs`` and/or ``input_shape``.
        default: 3-tuple ``(H, W, C)`` used when no shape is available, when
            the shape has fewer than four entries, or (per entry) when a
            dimension is ``None``.

    Returns:
        A tuple ``(H, W, C)`` of ints.
    """
    try:
        shape = tuple(model.inputs[0].shape)
    except Exception:
        # Deliberately broad: any failure to read `inputs` simply means we try
        # the alternative attribute below.
        shape = getattr(model, "input_shape", None)

    if shape is None or len(shape) < 4:
        return tuple(default)

    # Dimensions reported as None are replaced by the matching default entry.
    return tuple(
        int(shape[axis]) if shape[axis] is not None else fallback
        for axis, fallback in zip((1, 2, 3), default)
    )
def preprocess(pil_img: Image.Image, target):
    """Turn an image into a single-item float32 batch for the model.

    The image is converted to 8-bit grayscale, resized to the requested size,
    scaled to the range [0, 1], and given a channel axis plus a leading batch
    axis.

    Args:
        pil_img: Source image.
        target: ``(H, W, C)`` triple describing the desired output size.

    Returns:
        Array of shape ``(1, H, W, C)``; for ``C > 1`` every channel holds the
        same grayscale plane.
    """
    height, width, channels = target

    # PIL's resize takes (width, height), hence the swapped order.
    gray = pil_img.convert("L").resize((width, height))
    plane = np.array(gray).astype("float32") / 255.0  # (H, W)

    if channels == 1:
        stacked = plane[..., np.newaxis]  # (H, W, 1)
    elif channels == 3:
        stacked = np.stack((plane, plane, plane), axis=-1)  # (H, W, 3)
    else:
        stacked = np.repeat(plane[..., np.newaxis], channels, axis=-1)

    return stacked[np.newaxis, ...]  # prepend the batch axis
def predict_fn(pil_img: Image.Image):
    """Classify one X-ray image and format the result for the UI.

    Args:
        pil_img: Image supplied by the Gradio input, or ``None`` when the
            input is empty.

    Returns:
        A pair ``(scores, message)``. ``scores`` maps each entry of
        ``CLASS_NAMES`` to a float; ``message`` is a human-readable summary.
        When no image is given, or when prediction fails, every score is
        ``0.0`` and ``message`` carries the prompt or the error text. Errors
        are reported through the return value rather than raised.
    """
    blank_scores = {name: 0.0 for name in CLASS_NAMES}

    if pil_img is None:
        return blank_scores, "Please upload an image or pick a sample."

    try:
        model = load_model()
        x = preprocess(pil_img, _infer_input_shape(model))
        y = model.predict(x, verbose=0)

        # The first value of the flattened output is used as the score for
        # CLASS_NAMES[1] (sigmoid output); CLASS_NAMES[0] gets the complement.
        prob = float(np.ravel(y)[0])
        idx = int(prob > 0.5)
        conf = prob if idx == 1 else 1 - prob
        scores = {CLASS_NAMES[0]: 1 - prob, CLASS_NAMES[1]: prob}
        msg = f"Prediction: {CLASS_NAMES[idx]} | Confidence: {conf*100:.2f}%"
        return scores, msg
    except Exception as e:
        # UI boundary: keep the app responsive, but write the full stack trace
        # to the process logs so the failure can be diagnosed.
        traceback.print_exc()
        tip = (
            "If this persists, make sure the Space has keras>=3 and tensorflow>=2.16."
        )
        err = f"⚠️ Error during prediction:\n\n{e}\n\n{tip}"
        return blank_scores, err
def list_examples():
    """Collect sample images for the examples gallery.

    Looks in the ``images`` directory (relative to the current working
    directory) for ``.jpeg``, ``.jpg`` and ``.png`` files.

    Returns:
        A list of single-element lists, one per matching path, ordered by
        path. Empty when nothing matches.
    """
    patterns = ("images/*.jpeg", "images/*.jpg", "images/*.png")
    matches = [path for pattern in patterns for path in glob.glob(pattern)]
    matches.sort()
    # Each row is a one-item list: one value per example.
    return [[path] for path in matches]
# Page-level CSS: caps the container width, centres it, and centres the title.
_PAGE_CSS = """
.gradio-container {max-width: 980px !important; margin: auto;}
#title {text-align:center;}
"""

with gr.Blocks(css=_PAGE_CSS) as demo:
    gr.Markdown("<h1 id='title'>Chest X-Ray Classification</h1>")
    gr.Markdown("Upload an image or click a sample. The model predicts NORMAL or PNEUMONIA.")

    with gr.Row():
        # Left column: image input, action buttons, and the sample gallery.
        with gr.Column(scale=2):
            inp = gr.Image(type="pil", image_mode="L", label="Upload X-ray")
            with gr.Row():
                btn = gr.Button("Predict", variant="primary")
                gr.ClearButton(components=[inp])
            gr.Markdown("### Samples")
            gr.Examples(examples=list_examples(), inputs=inp, examples_per_page=12)

        # Right column: class scores and the textual summary.
        with gr.Column(scale=1):
            probs = gr.Label(num_top_classes=2, label="Class probabilities")
            out_text = gr.Markdown()

    # The same handler runs when the button is clicked and whenever the image
    # input changes.
    for trigger in (btn.click, inp.change):
        trigger(predict_fn, inputs=inp, outputs=[probs, out_text])

if __name__ == "__main__":
    demo.launch()