|
|
|
|
|
import functools
import glob
import os
import traceback
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
# Keras 3 selects its backend when the package is imported, so this must run
# before `import keras` below. A KERAS_BACKEND already present in the
# environment is left untouched; "tensorflow" is only the fallback.
if "KERAS_BACKEND" not in os.environ:
    os.environ["KERAS_BACKEND"] = "tensorflow"
|
|
import keras |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
# Hugging Face Hub repository id that load_model() downloads the
# "xray_cnn.keras" file from.
HF_MODEL_ID = "Vedag812/xray_cnn"


# Label order matters: predict_fn picks index 1 (PNEUMONIA) when the model's
# score is > 0.5 and index 0 (NORMAL) otherwise, and reports the score as the
# PNEUMONIA probability.
CLASS_NAMES = ["NORMAL", "PNEUMONIA"]
|
|
|
|
|
|
|
|
@functools.lru_cache(maxsize=1)
def load_model():
    """Fetch the model file from the Hub and deserialize it, once per process.

    The result is memoized: every prediction used to call this function and
    rebuild the whole Keras model from disk; now the first call loads it and
    later calls return the same model object.

    Returns:
        The Keras model stored as ``xray_cnn.keras`` in ``HF_MODEL_ID``.
    """
    # hf_hub_download returns a local filesystem path to the requested file.
    model_path = hf_hub_download(HF_MODEL_ID, filename="xray_cnn.keras")

    # SECURITY NOTE(review): safe_mode=False permits deserializing arbitrary
    # Python code (e.g. lambda layers) embedded in the archive. This is only
    # acceptable because HF_MODEL_ID is a trusted repository; switch to
    # safe_mode=True if the saved model contains no such layers.
    # compile=False: the model is only used for inference here.
    model = keras.saving.load_model(model_path, compile=False, safe_mode=False)

    return model
|
|
|
|
|
def _infer_input_shape(model): |
|
|
|
|
|
try: |
|
|
shp = tuple(model.inputs[0].shape) |
|
|
except Exception: |
|
|
shp = getattr(model, "input_shape", None) |
|
|
if shp is None: |
|
|
return 150, 150, 1 |
|
|
if len(shp) < 4: |
|
|
return 150, 150, 1 |
|
|
H = int(shp[1]) if shp[1] is not None else 150 |
|
|
W = int(shp[2]) if shp[2] is not None else 150 |
|
|
C = int(shp[3]) if shp[3] is not None else 1 |
|
|
return H, W, C |
|
|
|
|
|
def preprocess(pil_img: Image.Image, target):
    """Convert a PIL image into a one-image float batch for the model.

    The image is converted to 8-bit grayscale, resized to the target
    width/height, scaled into [0, 1], and the gray plane is copied into every
    requested channel.

    Args:
        pil_img: Source image in any mode PIL can convert to ``"L"``.
        target: ``(height, width, channels)`` the model expects.

    Returns:
        A ``float32`` array of shape ``(1, height, width, channels)``.
    """
    height, width, channels = target

    grayscale = pil_img.convert("L").resize((width, height))
    plane = np.asarray(grayscale, dtype="float32") / 255.0

    # Add the batch axis in front and a channel axis at the end, then repeat
    # the single gray plane `channels` times (1 -> unchanged, 3 -> RGB-like).
    batched = plane[np.newaxis, :, :, np.newaxis]
    return np.repeat(batched, channels, axis=-1)
|
|
|
|
|
def predict_fn(pil_img: Image.Image):
    """Classify one chest X-ray and format the result for the Gradio outputs.

    Args:
        pil_img: Image from the ``gr.Image`` input, or ``None`` when the input
            is empty (e.g. right after it was cleared).

    Returns:
        A ``(probs, message)`` tuple. ``probs`` maps each name in
        ``CLASS_NAMES`` to a score for ``gr.Label``; ``message`` is Markdown
        text with either the prediction or an error description. Errors are
        not raised to Gradio: the traceback is printed to stderr and a short
        description is returned in ``message``.
    """
    # Placeholder scores shown whenever no real prediction is available.
    no_scores = {name: 0.0 for name in CLASS_NAMES}

    if pil_img is None:
        return no_scores, "Please upload an image or pick a sample."

    try:
        model = load_model()
        x = preprocess(pil_img, _infer_input_shape(model))
        y = model.predict(x, verbose=0)
        # Assumes the model emits one sigmoid score for PNEUMONIA (index 1 of
        # CLASS_NAMES); a multi-unit softmax head would be misread here.
        prob = float(np.ravel(y)[0])
    except Exception as e:
        # UI boundary: keep the app responsive, but don't lose the stack
        # trace — previously it was swallowed and only str(e) was shown.
        traceback.print_exc()
        tip = (
            "If this persists, make sure the Space has keras>=3 and tensorflow>=2.16."
        )
        err = f"⚠️ Error during prediction:\n\n{e}\n\n{tip}"
        return no_scores, err

    idx = int(prob > 0.5)
    conf = prob if idx == 1 else 1 - prob
    probs = {CLASS_NAMES[0]: 1 - prob, CLASS_NAMES[1]: prob}
    msg = f"Prediction: {CLASS_NAMES[idx]} | Confidence: {conf*100:.2f}%"
    return probs, msg
|
|
|
|
|
def list_examples(directory="images", extensions=(".jpeg", ".jpg", ".png")):
    """Collect sample images to feed ``gr.Examples``.

    Regular files directly inside ``directory`` whose name ends with one of
    ``extensions`` are returned. The extension match is case-insensitive, so
    ``scan.JPEG`` or ``scan.PNG`` are found too. The previous per-extension
    glob patterns missed them on case-sensitive filesystems.

    Args:
        directory: Folder to scan (relative paths resolve against the current
            working directory).
        extensions: Accepted filename suffixes, including the leading dot.

    Returns:
        Sorted ``[[path], ...]`` rows, one single-input example per file. The
        list is empty when the directory is missing or nothing matches.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    # Escape the directory so glob metacharacters in its name are literal.
    candidates = glob.glob(os.path.join(glob.escape(directory), "*"))
    files = [
        path
        for path in candidates
        if path.lower().endswith(suffixes) and os.path.isfile(path)
    ]
    return [[path] for path in sorted(files)]
|
|
|
|
|
with gr.Blocks(css="""
.gradio-container {max-width: 980px !important; margin: auto;}
#title {text-align:center;}
""") as demo:
    gr.Markdown("<h1 id='title'>Chest X-Ray Classification</h1>")
    gr.Markdown("Upload an image or click a sample. The model predicts NORMAL or PNEUMONIA.")

    with gr.Row():
        with gr.Column(scale=2):
            inp = gr.Image(type="pil", image_mode="L", label="Upload X-ray")
            with gr.Row():
                btn = gr.Button("Predict", variant="primary")
                gr.ClearButton(components=[inp])

            # Render the samples section only when sample images exist;
            # otherwise the page would show an empty "Samples" gallery.
            sample_rows = list_examples()
            if sample_rows:
                gr.Markdown("### Samples")
                gr.Examples(examples=sample_rows, inputs=inp, examples_per_page=12)

        with gr.Column(scale=1):
            probs = gr.Label(num_top_classes=2, label="Class probabilities")
            out_text = gr.Markdown()

    # Predict on an explicit click, and also automatically whenever the image
    # value changes (e.g. upload, sample selection, or clearing the input).
    btn.click(predict_fn, inputs=inp, outputs=[probs, out_text])
    inp.change(predict_fn, inputs=inp, outputs=[probs, out_text])


if __name__ == "__main__":
    demo.launch()
|
|
|