# nih-chest-xray / test.py
# Uploaded by Jamshaid89 in commit 728dfff ("initial model").
import gradio as gr
import numpy as np
from PIL import Image
# Label names for the NIH Chest X-ray classifier, keyed by the *string* form
# of each integer class index ("0" through "14"): fourteen findings plus the
# "No Finding" label at index "10".
CLASS_NAMES = {
    str(index): label
    for index, label in enumerate(
        (
            "Atelectasis",
            "Cardiomegaly",
            "Consolidation",
            "Edema",
            "Effusion",
            "Emphysema",
            "Fibrosis",
            "Hernia",
            "Infiltration",
            "Mass",
            "No Finding",
            "Nodule",
            "Pleural_Thickening",
            "Pneumonia",
            "Pneumothorax",
        )
    )
}
# Test function that preprocesses the image to the desired size
def test_predict_xray(image: np.ndarray):
# Resize image to (224, 224) if needed
if image.shape[:2] != (224, 224):
image = Image.fromarray(image).resize((224, 224))
image = np.array(image)
return {
"Atelectasis": 0.1,
"Cardiomegaly": 0.05,
"Consolidation": 0.2,
"Edema": 0.01,
"Effusion": 0.3,
"Emphysema": 0.02,
"Fibrosis": 0.08,
"Hernia": 0.005,
"Infiltration": 0.15,
"Mass": 0.07,
"Nodule": 0.12,
"Pleural_Thickening": 0.09,
"Pneumonia": 0.25,
"Pneumothorax": 0.18,
}
# Build the Gradio UI: a single image upload, handed to the callback as a
# numpy array, and a label panel showing the 14 highest-scoring entries.
interface = gr.Interface(
    test_predict_xray,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Label(num_top_classes=14, label="Predicted Probabilities"),
    description="Upload a chest X-ray. The model outputs probabilities for 14 findings (using a test function).",
    title="NIH Chest X-ray Multi‐Label Classifier (Test)",
)

# Only start the web server when executed as a script, not on import.
if __name__ == "__main__":
    interface.launch()