|
|
import gradio as gr |
|
|
import tensorflow as tf |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
# Load the trained classifier once at import time; every request below reuses it.
# compile=False: this app only calls model.predict, so no optimizer/loss is needed.
# NOTE(review): the path is relative to the current working directory, not to
# this script's location.
model = tf.keras.models.load_model(
    filepath='model.hdf5',
    compile=False,
)
|
|
|
|
|
|
|
|
# Class names, matched by position to the scores in the model's output vector.
# NOTE(review): this order must match the one used when the model was trained — confirm.
LABELS = [
    'NORMAL',
    'TUBERCULOSIS',
    'PNEUMONIA',
    'COVID19',
]
|
|
|
|
|
|
|
|
def predict_input_image(img):
    """Classify an uploaded image with the loaded Keras model.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` component, or
            ``None`` when Submit is clicked before any image is uploaded.

    Returns:
        A dict mapping each name in ``LABELS`` to the model's score for that
        class (the format ``gr.Label`` displays), or ``None`` when there is
        no image, so the label output is simply left empty instead of the
        callback raising.
    """
    if img is None:
        return None
    # The preprocessing below builds a (1, 128, 128, 3) batch, so force three
    # channels first: grayscale ("L") or RGBA uploads would otherwise break it.
    img = img.convert("RGB").resize((128, 128))
    # Scale pixel values from [0, 255] to [0, 1] and add a leading batch axis.
    img_array = np.expand_dims(np.asarray(img) / 255.0, axis=0)
    prediction = model.predict(img_array)[0]
    # Pair names with scores positionally instead of hard-coding the class count.
    return {label: float(score) for label, score in zip(LABELS, prediction)}
|
|
|
|
|
|
|
|
def clear_image():
    """Callback for the Clear button.

    Takes no inputs and returns ``None``, the empty value for whatever
    output component the callback is bound to.
    """
    return None  # empty value -> bound output is reset
|
|
|
|
|
|
|
|
# Build the web UI: an image upload with Clear/Submit buttons and a label
# output that shows the per-class scores returned by predict_input_image.
with gr.Blocks(title="Lung Disease Classification") as demo:
    gr.Markdown("## Lung Disease Classification Model\nUpload a chest X-ray to predict disease class.")

    with gr.Row():
        with gr.Column(scale=1, min_width=600):
            image = gr.Image(type="pil", label="Upload image")
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", variant='primary')
            # Show a score for every known class rather than a hard-coded count.
            label = gr.Label(num_top_classes=len(LABELS))

    # Clear resets the upload and also the previous prediction, so a stale
    # result is not left displayed next to an empty image.
    clear_btn.click(fn=clear_image, inputs=[], outputs=image)
    clear_btn.click(fn=clear_image, inputs=[], outputs=label)
    submit_btn.click(fn=predict_input_image, inputs=image, outputs=label)


# Start the server only when executed as a script, so importing this module
# builds `demo` without blocking on launch().
if __name__ == "__main__":
    demo.launch()
|
|
|