# app.py
"""Gradio demo that classifies chest X-rays as NORMAL or PNEUMONIA.

The Keras model is downloaded from the Hugging Face Hub on first use and kept
in-process, so later predictions reuse the already-loaded model.
"""
import functools
import glob
import os
import traceback

import numpy as np
from PIL import Image
import gradio as gr

# Use Keras 3 with the TensorFlow backend. This must run before `import keras`,
# because Keras picks its backend when it is first imported.
os.environ.setdefault("KERAS_BACKEND", "tensorflow")

import keras  # noqa: E402  (Keras 3.x; imported after the backend is chosen)
from huggingface_hub import hf_hub_download  # noqa: E402

HF_MODEL_ID = "Vedag812/xray_cnn"
CLASS_NAMES = ["NORMAL", "PNEUMONIA"]
# Fallback (H, W, C) used when the model does not report a usable input shape.
DEFAULT_INPUT_SHAPE = (150, 150, 1)
# Sample images offered in the UI (relative to the working directory).
EXAMPLE_DIR = "images"
EXAMPLE_EXTENSIONS = {".jpeg", ".jpg", ".png"}


@functools.lru_cache(maxsize=1)
def load_model():
    """Download (if not already cached) and load the Keras model once per process.

    Cached so each prediction does not re-download and re-deserialize the
    model. A failed load raises and is not cached, so the next call retries.
    """
    model_path = hf_hub_download(HF_MODEL_ID, filename="xray_cnn.keras")
    # safe_mode=False allows loading models saved with older or custom configs.
    # NOTE(security): it also allows deserializing arbitrary Python code (e.g.
    # Lambda layers) from the file; only point HF_MODEL_ID at a trusted repo.
    return keras.saving.load_model(model_path, compile=False, safe_mode=False)


def _infer_input_shape(model):
    """Return the (H, W, C) input size the model expects.

    Reads the first input's shape, falling back to ``model.input_shape`` and
    then to DEFAULT_INPUT_SHAPE when the shape is missing, has fewer than four
    dimensions, or has ``None`` entries.
    """
    try:
        shape = tuple(model.inputs[0].shape)
    except Exception:
        shape = getattr(model, "input_shape", None)
    if shape is None or len(shape) < 4:
        return DEFAULT_INPUT_SHAPE
    # shape is (batch, H, W, C); replace unknown dims with the defaults.
    return tuple(
        int(dim) if dim is not None else default
        for dim, default in zip(shape[1:4], DEFAULT_INPUT_SHAPE)
    )


def preprocess(pil_img: Image.Image, target):
    """Turn a PIL image into a (1, H, W, C) float32 batch scaled to [0, 1].

    The image is converted to grayscale and resized to ``(W, H)``. The single
    channel is then replicated ``C`` times (e.g. C=3 for RGB-trained models).
    """
    height, width, channels = target
    gray = pil_img.convert("L").resize((width, height))
    arr = np.asarray(gray, dtype="float32") / 255.0  # (H, W)
    # Add batch and channel axes, then tile the channel to what the model expects.
    return np.repeat(arr[np.newaxis, ..., np.newaxis], channels, axis=-1)


def _pneumonia_probability(y):
    """Extract P(PNEUMONIA) from the model output for a batch of one image.

    A single-unit (sigmoid) head gives the probability directly. For a
    multi-unit (e.g. softmax) head, the unit at PNEUMONIA's index in
    CLASS_NAMES is used. This assumes the output units are ordered like
    CLASS_NAMES — confirm when swapping in a new model.
    """
    out = np.ravel(y)
    if out.size >= len(CLASS_NAMES):
        return float(out[CLASS_NAMES.index("PNEUMONIA")])
    return float(out[0])


def _empty_scores():
    """Return the all-zero score dict shown when there is no valid prediction."""
    return {name: 0.0 for name in CLASS_NAMES}


def predict_fn(pil_img: Image.Image):
    """Classify one X-ray and return ``(label scores dict, markdown message)``.

    Never raises. Any failure is printed to the server log and reported to the
    user as a message, alongside all-zero scores.
    """
    if pil_img is None:
        return _empty_scores(), "Please upload an image or pick a sample."
    try:
        model = load_model()
        x = preprocess(pil_img, _infer_input_shape(model))
        prob = _pneumonia_probability(model.predict(x, verbose=0))
    except Exception as e:  # UI boundary: report instead of crashing the handler
        traceback.print_exc()  # full stack trace to the Space logs
        tip = (
            "If this persists, make sure the Space has keras>=3 and tensorflow>=2.16."
        )
        err = f"⚠️ Error during prediction:\n\n{e}\n\n{tip}"
        return _empty_scores(), err

    idx = int(prob > 0.5)
    conf = prob if idx == 1 else 1 - prob
    scores = {CLASS_NAMES[0]: 1 - prob, CLASS_NAMES[1]: prob}
    msg = f"Prediction: {CLASS_NAMES[idx]} | Confidence: {conf*100:.2f}%"
    return scores, msg


def list_examples():
    """Return the sample images in EXAMPLE_DIR as ``[[path], ...]`` for gr.Examples.

    Matches .jpeg/.jpg/.png extensions case-insensitively and returns the
    paths sorted.
    """
    paths = (
        p
        for p in glob.glob(os.path.join(EXAMPLE_DIR, "*"))
        if os.path.splitext(p)[1].lower() in EXAMPLE_EXTENSIONS and os.path.isfile(p)
    )
    return [[p] for p in sorted(paths)]


CSS = """
.gradio-container {max-width: 980px !important; margin: auto;}
#title {text-align:center;}
"""

with gr.Blocks(css=CSS) as demo:
    gr.Markdown("<h1 id='title'>Chest X-Ray Classification</h1>")
    gr.Markdown("Upload an image or click a sample. The model predicts NORMAL or PNEUMONIA.")
    with gr.Row():
        with gr.Column(scale=2):
            inp = gr.Image(type="pil", image_mode="L", label="Upload X-ray")
            with gr.Row():
                btn = gr.Button("Predict", variant="primary")
                gr.ClearButton(components=[inp])
            examples = list_examples()
            if examples:  # skip the Samples section when no images are present
                gr.Markdown("### Samples")
                gr.Examples(examples=examples, inputs=inp, examples_per_page=12)
        with gr.Column(scale=1):
            probs = gr.Label(num_top_classes=2, label="Class probabilities")
            out_text = gr.Markdown()

    btn.click(predict_fn, inputs=inp, outputs=[probs, out_text])
    # Also predict automatically on upload/sample click (and reset on clear).
    inp.change(predict_fn, inputs=inp, outputs=[probs, out_text])

if __name__ == "__main__":
    demo.launch()