File size: 3,700 Bytes
429fe88 ae73f85 429fe88 ae73f85 fddd228 429fe88 fddd228 429fe88 ae73f85 fddd228 ae73f85 fddd228 ae73f85 fddd228 ae73f85 fddd228 ae73f85 fddd228 ae73f85 429fe88 fddd228 ae73f85 fddd228 ae73f85 fddd228 ae73f85 fddd228 ae73f85 fddd228 ae73f85 429fe88 ae73f85 fddd228 ae73f85 fddd228 ae73f85 fddd228 ae73f85 fddd228 ae73f85 fddd228 429fe88 fddd228 429fe88 fddd228 429fe88 fddd228 429fe88 fddd228 429fe88 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
# app.py
import functools
import glob
import os
import traceback

import numpy as np
from PIL import Image
import gradio as gr

# Use Keras 3 with the TensorFlow backend (must be set before importing keras)
os.environ.setdefault("KERAS_BACKEND", "tensorflow")
import keras  # Keras 3.x
from huggingface_hub import hf_hub_download
# Hugging Face Hub repository that hosts the trained Keras model file.
HF_MODEL_ID = "Vedag812/xray_cnn"

# Output labels. Order matters: index 1 is the class that receives the
# model's sigmoid probability, index 0 receives its complement.
CLASS_NAMES = [
    "NORMAL",
    "PNEUMONIA",
]
# Gradio has no `cache_resource` decorator (that is Streamlit's API); use the
# standard-library memoiser so the model is downloaded and built only once.
@functools.lru_cache(maxsize=1)
def load_model():
    """Download the Keras model from the Hugging Face Hub and load it.

    The result is memoised with ``functools.lru_cache``: the first call
    downloads ``xray_cnn.keras`` from ``HF_MODEL_ID`` and deserialises it,
    and every later call returns that same model object. If loading raises,
    nothing is cached and the next call retries.

    Returns:
        The loaded Keras model, uncompiled (``compile=False``) because it is
        only used for inference.
    """
    model_path = hf_hub_download(HF_MODEL_ID, filename="xray_cnn.keras")
    # SECURITY(review): safe_mode=False permits deserialising arbitrary Python
    # lambdas stored in the model file, i.e. it can execute code shipped with
    # the remote repository. Keep it only if the model genuinely needs it and
    # the repo is trusted; otherwise switch to safe_mode=True.
    model = keras.saving.load_model(model_path, compile=False, safe_mode=False)
    return model
def _infer_input_shape(model):
# returns (H, W, C)
try:
shp = tuple(model.inputs[0].shape)
except Exception:
shp = getattr(model, "input_shape", None)
if shp is None:
return 150, 150, 1
if len(shp) < 4:
return 150, 150, 1
H = int(shp[1]) if shp[1] is not None else 150
W = int(shp[2]) if shp[2] is not None else 150
C = int(shp[3]) if shp[3] is not None else 1
return H, W, C
def preprocess(pil_img: Image.Image, target):
    """Convert a PIL image into a normalised single-item batch for the model.

    The image is converted to PIL mode ``"L"`` (grayscale) and resized to the
    target width and height. Pixels are scaled by 1/255 into float32, and the
    single grey channel is then replicated ``C`` times.

    Args:
        pil_img: Source image in any PIL mode.
        target: ``(H, W, C)`` triple describing the model input.

    Returns:
        A float32 numpy array of shape ``(1, H, W, C)``.
    """
    height, width, channels = target
    grey = pil_img.convert("L").resize((width, height))  # PIL takes (W, H)
    pixels = np.array(grey).astype("float32") / 255.0  # (H, W)
    # Repeating along a new trailing axis handles C == 1, C == 3 and any other
    # channel count with a single expression.
    stacked = np.repeat(pixels[..., np.newaxis], channels, axis=-1)  # (H, W, C)
    return stacked[np.newaxis, ...]  # (1, H, W, C)
def predict_fn(pil_img: Image.Image):
    """Classify one chest X-ray and format the result for the Gradio outputs.

    Args:
        pil_img: The input image, or ``None`` when the input is empty.

    Returns:
        A ``(probs, msg)`` pair. ``probs`` maps every name in ``CLASS_NAMES``
        to a probability; all values are ``0.0`` when there is no image or an
        error occurred. ``msg`` is a Markdown status string.

    Errors raised while loading or running the model are not propagated. They
    are reported in ``msg``, and their traceback is printed to stderr.
    """
    empty = {name: 0.0 for name in CLASS_NAMES}
    if pil_img is None:
        return empty, "Please upload an image or pick a sample."

    try:
        model = load_model()
        x = preprocess(pil_img, _infer_input_shape(model))
        y = np.asarray(model.predict(x, verbose=0))
        if y.ndim >= 2 and y.shape[-1] == len(CLASS_NAMES):
            # One output unit per class (assumed ordered like CLASS_NAMES):
            # read the PNEUMONIA column instead of the NORMAL one.
            prob = float(y.reshape(-1, y.shape[-1])[0, 1])
        else:
            prob = float(np.ravel(y)[0])  # single sigmoid unit = P(PNEUMONIA)
        idx = int(prob > 0.5)
        conf = prob if idx == 1 else 1 - prob
        probs = {CLASS_NAMES[0]: 1 - prob, CLASS_NAMES[1]: prob}
        msg = f"Prediction: {CLASS_NAMES[idx]} | Confidence: {conf*100:.2f}%"
        return probs, msg
    except Exception as e:
        # Top-level UI boundary: show the error instead of crashing the event
        # handler, but keep the full stack trace in the Space logs.
        traceback.print_exc()
        tip = (
            "If this persists, make sure the Space has keras>=3 and tensorflow>=2.16."
        )
        err = f"⚠️ Error during prediction:\n\n{e}\n\n{tip}"
        return empty, err
def list_examples(directory="images", extensions=(".jpeg", ".jpg", ".png")):
    """Collect sample image paths for ``gr.Examples``.

    Regular files directly inside *directory* are included when their suffix
    matches one of *extensions*. The comparison is case-insensitive, so
    ``scan.JPG`` is found as well as ``scan.jpg``. Sub-directories are
    ignored even if their names end in an image suffix.

    Args:
        directory: Folder to scan. A relative path is resolved against the
            current working directory. Defaults to ``"images"``.
        extensions: Accepted file suffixes, including the leading dot.

    Returns:
        A sorted list of single-element lists ``[[path], ...]``, one row per
        example for a single-input ``gr.Examples``. The list is empty when
        *directory* does not exist or has no matching files.
    """
    wanted = {ext.lower() for ext in extensions}
    # Escape the directory so glob metacharacters in its name are literal.
    pattern = os.path.join(glob.escape(directory), "*")
    files = [
        path
        for path in glob.glob(pattern)
        if os.path.isfile(path) and os.path.splitext(path)[1].lower() in wanted
    ]
    return [[path] for path in sorted(files)]
# Build the Gradio UI. A row holds two columns: the left one has the image
# input, Predict/Clear buttons and sample gallery; the right one has the
# class-probability label and a status message.
with gr.Blocks(css="""
.gradio-container {max-width: 980px !important; margin: auto;}
#title {text-align:center;}
""") as demo:
    gr.Markdown("<h1 id='title'>Chest X-Ray Classification</h1>")
    gr.Markdown("Upload an image or click a sample. The model predicts NORMAL or PNEUMONIA.")
    with gr.Row():
        # scale=2 vs scale=1 below: the left column takes the larger share of the row.
        with gr.Column(scale=2):
            # type="pil" passes a PIL image to predict_fn; image_mode="L" asks
            # Gradio to deliver it as grayscale.
            inp = gr.Image(type="pil", image_mode="L", label="Upload X-ray")
            with gr.Row():
                btn = gr.Button("Predict", variant="primary")
                gr.ClearButton(components=[inp])
            gr.Markdown("### Samples")
            # The sample list is computed once, when the UI is constructed.
            gr.Examples(examples=list_examples(), inputs=inp, examples_per_page=12)
        with gr.Column(scale=1):
            probs = gr.Label(num_top_classes=2, label="Class probabilities")
            out_text = gr.Markdown()
    # Run a prediction on button click and whenever the input value changes
    # (upload, sample pick, or clear; predict_fn handles a None image).
    btn.click(predict_fn, inputs=inp, outputs=[probs, out_text])
    inp.change(predict_fn, inputs=inp, outputs=[probs, out_text])

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|