| |
| import os, glob |
| import numpy as np |
| from PIL import Image |
| import gradio as gr |
| import tensorflow as tf |
| from functools import lru_cache |
| from huggingface_hub import hf_hub_download |
|
|
# Hugging Face Hub repository that the model file is downloaded from.
HF_MODEL_ID = "Vedag812/xray_cnn"
# Display labels indexed by predicted class: index 1 is used when the model
# output exceeds 0.5, index 0 otherwise.
CLASS_NAMES = ["NORMAL", "PNEUMONIA"]
|
|
@lru_cache(maxsize=1)
def load_model():
    """Fetch ``xray_cnn.keras`` from the Hub repo and load it as a Keras model.

    The function takes no arguments, so the ``lru_cache`` decorator keeps the
    result of the first successful call and hands the same model object back
    on every later call instead of downloading and loading again.

    Returns:
        The model returned by ``tf.keras.models.load_model``, loaded with
        ``compile=False`` (this file only ever calls ``predict`` on it).
    """
    checkpoint = hf_hub_download(repo_id=HF_MODEL_ID, filename="xray_cnn.keras")
    return tf.keras.models.load_model(checkpoint, compile=False)
|
|
def preprocess(pil_img: Image.Image):
    """Turn a PIL image into a single-item float32 batch for the model.

    The image is converted to 8-bit grayscale, resized to 150x150, scaled
    from 0..255 down to 0..1, and given a leading batch axis and a trailing
    channel axis.

    Args:
        pil_img: Image to convert.

    Returns:
        A float32 NumPy array of shape ``(1, 150, 150, 1)``.
    """
    gray = pil_img.convert("L")
    resized = gray.resize((150, 150))
    pixels = np.asarray(resized).astype(np.float32)
    scaled = pixels / 255.0
    # Same effect as expand_dims(axis=(0, -1)): add batch and channel axes.
    return scaled[np.newaxis, ..., np.newaxis]
|
|
def predict_fn(pil_img: "Image.Image | None"):
    """Classify one X-ray and build the values for the two output widgets.

    Args:
        pil_img: Image supplied by the input component, or ``None`` when the
            component is empty (Predict clicked with nothing uploaded, or the
            image was just cleared and the change listener fired).

    Returns:
        A ``(probs, msg)`` tuple. ``probs`` maps each entry of ``CLASS_NAMES``
        to its probability and ``msg`` is a one-line summary of the predicted
        class and its confidence. When ``pil_img`` is ``None`` the result is
        ``(None, "")`` so both outputs are blanked instead of raising.
    """
    # Guard first: without an image there is nothing to preprocess, and
    # calling .convert() on None would raise AttributeError.
    if pil_img is None:
        return None, ""

    model = load_model()
    batch = preprocess(pil_img)
    # The first element of the first output row is read as the score for
    # CLASS_NAMES[1]; CLASS_NAMES[0] gets the complement.
    prob = float(model.predict(batch, verbose=0)[0][0])
    pred_idx = int(prob > 0.5)
    # Confidence is the probability of whichever class was picked.
    confidence = prob if pred_idx == 1 else 1 - prob
    probs = {CLASS_NAMES[0]: 1 - prob, CLASS_NAMES[1]: prob}
    msg = f"Prediction: {CLASS_NAMES[pred_idx]} | Confidence: {confidence*100:.2f}%"
    return probs, msg
|
|
def list_examples():
    """Collect the sample images shipped in the ``images/`` directory.

    Looks for ``.jpeg``, ``.jpg`` and ``.png`` files directly under
    ``images/`` (relative to the current working directory).

    Returns:
        A list of single-element lists, one per file path, ordered by path.
        The list is empty when nothing matches.
    """
    patterns = ("images/*.jpeg", "images/*.jpg", "images/*.png")
    found = [path for pattern in patterns for path in glob.glob(pattern)]
    found.sort()
    # Each example is wrapped in its own list: one row per sample.
    return [[path] for path in found]
|
|
# Page-level CSS overrides handed to gr.Blocks.
CUSTOM_CSS = """
.gradio-container {max-width: 980px !important; margin: auto;}
#title {text-align:center;}
.card {border:1px solid #e5e7eb; border-radius:16px; padding:16px;}
"""

# NOTE(review): the layout nesting below was reconstructed from a source whose
# indentation had been flattened -- confirm it matches the intended layout.
with gr.Blocks(css=CUSTOM_CSS) as demo:
    # Header: title and a one-line usage hint.
    gr.Markdown("<h1 id='title'>Chest X-Ray Classification</h1>")
    gr.Markdown("Upload an image or click a sample from the gallery. The model predicts NORMAL or PNEUMONIA.")

    with gr.Row():
        # Left column: image input, action buttons and the sample gallery.
        with gr.Column(scale=2):
            inp = gr.Image(type="pil", image_mode="L", label="Upload X-ray")
            with gr.Row():
                btn = gr.Button("Predict", variant="primary")
                # Clear resets only the image input.
                clr = gr.ClearButton(components=[inp], value="Clear")
            gr.Markdown("### Samples")
            gr.Examples(
                examples=list_examples(),
                inputs=inp,
                examples_per_page=12,
            )
        # Right column: class probabilities and the text summary.
        with gr.Column(scale=1):
            probs = gr.Label(num_top_classes=2, label="Class probabilities")
            out_text = gr.Markdown()

    # Run the classifier on an explicit button press ...
    btn.click(predict_fn, inputs=inp, outputs=[probs, out_text])
    # ... and also whenever the image input's value changes.
    inp.change(predict_fn, inputs=inp, outputs=[probs, out_text])


# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|
|