File size: 3,682 Bytes
429fe88
ae73f85
429fe88
 
 
ae73f85
1698179
 
 
 
429fe88
 
 
 
8ea9e10
429fe88
ae73f85
1698179
 
 
ae73f85
 
1698179
ae73f85
1698179
ae73f85
1698179
 
 
 
ae73f85
1698179
 
 
ae73f85
429fe88
1698179
 
ae73f85
1698179
ae73f85
1698179
ae73f85
1698179
ae73f85
1698179
ae73f85
429fe88
 
ae73f85
1698179
 
ae73f85
 
 
1698179
 
 
 
ae73f85
1698179
ae73f85
 
 
1698179
ae73f85
1698179
 
 
 
429fe88
 
 
1698179
 
 
429fe88
 
 
 
 
 
1698179
429fe88
 
 
 
 
 
1698179
429fe88
1698179
429fe88
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# app.py
import os, glob, traceback
import numpy as np
from PIL import Image
import gradio as gr

# Use Keras 3 with the TensorFlow backend
os.environ.setdefault("KERAS_BACKEND", "tensorflow")
import keras  # Keras 3.x
from huggingface_hub import hf_hub_download

# Hugging Face Hub repository id; load_model() passes it to hf_hub_download.
HF_MODEL_ID = "Vedag812/xray_cnn"
# Display labels. predict_fn() picks index 1 when the model score is > 0.5
# and index 0 otherwise, so the order here must match that convention.
CLASS_NAMES = ["NORMAL", "PNEUMONIA"]


# Process-wide cache so the model is deserialized once, not once per request.
_MODEL_CACHE = {}


def load_model():
    """Return the Keras classifier, loading it on first use only.

    The first successful call fetches ``xray_cnn.keras`` from the
    ``HF_MODEL_ID`` repository via ``hf_hub_download`` and deserializes it;
    later calls return the same in-memory model object. A failed load is not
    cached, so the next call retries.

    Returns:
        The loaded Keras model (loaded with ``compile=False``).

    Raises:
        Whatever ``hf_hub_download`` or ``keras.saving.load_model`` raise
        when the download or deserialization fails.
    """
    if "model" not in _MODEL_CACHE:
        model_path = hf_hub_download(HF_MODEL_ID, filename="xray_cnn.keras")
        # safe_mode=False allows loading models saved with older or custom
        # configs.
        # SECURITY NOTE(review): disabling safe mode also permits unsafe
        # deserialization (e.g. lambda layers), so this is only appropriate
        # while the HF_MODEL_ID repository is trusted.
        _MODEL_CACHE["model"] = keras.saving.load_model(
            model_path, compile=False, safe_mode=False
        )
    return _MODEL_CACHE["model"]

def _infer_input_shape(model):
    """Work out the ``(H, W, C)`` image shape the model expects.

    The shape of the model's first input is used; if that cannot be read, the
    model's ``input_shape`` attribute is tried instead. Index 0 of the shape
    is skipped and indices 1-3 are read as height, width and channels.

    Args:
        model: Object exposing ``inputs[0].shape`` or ``input_shape``.

    Returns:
        A ``(H, W, C)`` tuple of ints. Any dimension that is ``None`` falls
        back to 150 / 150 / 1 respectively, and the whole tuple falls back to
        ``(150, 150, 1)`` when no shape with at least four entries is found.
    """
    fallback = (150, 150, 1)

    try:
        dims = tuple(model.inputs[0].shape)
    except Exception:
        # First input not available; try the model-level attribute instead.
        dims = getattr(model, "input_shape", None)
        if dims is None:
            return fallback

    if len(dims) < 4:
        return fallback

    # Pair each of the three image dims with its default and substitute the
    # default wherever the model leaves the dimension unspecified.
    height, width, channels = (
        default if dim is None else int(dim)
        for dim, default in zip((dims[1], dims[2], dims[3]), fallback)
    )
    return height, width, channels

def preprocess(pil_img: Image.Image, target):
    """Convert a PIL image into a one-image float32 batch for the model.

    The image is converted to mode ``"L"`` (grayscale), resized to the target
    width and height, divided by 255, and the single gray plane is copied
    into each of the requested channels.

    Args:
        pil_img: Source image.
        target: ``(H, W, C)`` triple giving output height, width and channel
            count.

    Returns:
        A float32 array of shape ``(1, H, W, C)``.
    """
    height, width, channels = target
    # PIL's resize takes (width, height), the reverse of the array layout.
    gray = pil_img.convert("L").resize((width, height))
    plane = np.array(gray).astype("float32") / 255.0  # (H, W)
    # Add batch and channel axes, then replicate the gray plane along the
    # channel axis; with channels == 1 this is just the (1, H, W, 1) plane.
    batch = np.repeat(plane[None, ..., None], channels, axis=-1)
    return batch

def predict_fn(pil_img: Image.Image):
    """Classify an X-ray image and format the result for the UI outputs.

    Args:
        pil_img: Image supplied by the UI, or ``None`` when no image is set.

    Returns:
        A ``(probs, message)`` pair. ``probs`` maps each name in
        ``CLASS_NAMES`` to a float score and ``message`` is the text shown
        next to it. When no image is given, or when prediction fails, every
        score is 0.0 and ``message`` carries the prompt or the error text.
        Exceptions are never propagated to the caller.
    """
    empty_scores = dict.fromkeys(CLASS_NAMES, 0.0)
    if pil_img is None:
        return empty_scores, "Please upload an image or pick a sample."

    try:
        model = load_model()
        x = preprocess(pil_img, _infer_input_shape(model))
        y = model.predict(x, verbose=0)
        # The first output value is read as the score of CLASS_NAMES[1].
        # NOTE(review): this assumes a single sigmoid output unit; confirm
        # against the model architecture.
        prob = float(np.ravel(y)[0])
        idx = int(prob > 0.5)
        # Confidence is the score of whichever class was chosen.
        conf = prob if idx == 1 else 1 - prob
        probs = {CLASS_NAMES[0]: 1 - prob, CLASS_NAMES[1]: prob}
        msg = f"Prediction: {CLASS_NAMES[idx]} | Confidence: {conf*100:.2f}%"
        return probs, msg
    except Exception as e:
        # UI boundary: report the failure to the user instead of raising, but
        # keep the full stack trace in the process logs so it can be diagnosed.
        traceback.print_exc()
        tip = (
            "If this persists, make sure the Space has keras>=3 and tensorflow>=2.16."
        )
        err = f"⚠️ Error during prediction:\n\n{e}\n\n{tip}"
        return empty_scores, err

def list_examples():
    """Collect sample image paths for the examples gallery.

    Looks for ``.jpeg``, ``.jpg`` and ``.png`` files directly inside the
    ``images`` directory, resolved relative to the current working directory.

    Returns:
        A list of one-element lists ``[[path], ...]`` sorted by path; empty
        when nothing matches.
    """
    patterns = ("images/*.jpeg", "images/*.jpg", "images/*.png")
    matches = [path for pattern in patterns for path in glob.glob(pattern)]
    matches.sort()
    # One inner list per example row, since each example has a single input.
    return [[path] for path in matches]

# Build the Gradio UI: a two-column page with the image input, action buttons
# and sample gallery on the left, and the prediction outputs on the right.
# The css argument caps the page width and centers the title heading.
with gr.Blocks(css="""
.gradio-container {max-width: 980px !important; margin: auto;}
#title {text-align:center;}
""") as demo:
    gr.Markdown("<h1 id='title'>Chest X-Ray Classification</h1>")
    gr.Markdown("Upload an image or click a sample. The model predicts NORMAL or PNEUMONIA.")

    with gr.Row():
        # Left column (scale=2): input image, buttons and sample images.
        with gr.Column(scale=2):
            # type="pil" requests a PIL image, which is what predict_fn expects.
            inp = gr.Image(type="pil", image_mode="L", label="Upload X-ray")
            with gr.Row():
                btn = gr.Button("Predict", variant="primary")
                gr.ClearButton(components=[inp])
            gr.Markdown("### Samples")
            # list_examples() runs once here, while the UI is being built.
            gr.Examples(examples=list_examples(), inputs=inp, examples_per_page=12)
        # Right column (scale=1): class scores and the status/result message.
        with gr.Column(scale=1):
            probs = gr.Label(num_top_classes=2, label="Class probabilities")
            out_text = gr.Markdown()

    # predict_fn is wired to both the button and the image's change event; it
    # returns a (scores, message) pair matching the two outputs listed here.
    btn.click(predict_fn, inputs=inp, outputs=[probs, out_text])
    inp.change(predict_fn, inputs=inp, outputs=[probs, out_text])

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()