# multi-chex / app.py (by saba2000)
# Last commit: 0ddd5ad "Update app.py" (1.56 kB)
from transformers import AutoModelForImageClassification, AutoImageProcessor
import torch
import torch.nn.functional as F
from PIL import Image
import gradio as gr
# -----------------------------
# 1. Load pretrained model
# -----------------------------
# NOTE(review): confirm this checkpoint id actually exists on the Hugging Face
# Hub; from_pretrained raises at startup if it cannot be resolved.
model_name = "microsoft/resnet-50-finetuned-chestxray14"

# Image-classification model. nn.Module.eval() switches to inference mode and
# returns the module itself, so it can be chained directly onto the load.
model = AutoModelForImageClassification.from_pretrained(model_name).eval()

# Preprocessing (resize / normalization) settings stored with the checkpoint.
processor = AutoImageProcessor.from_pretrained(model_name)

# Class-index -> class-name mapping shipped in the model config.
id2label = model.config.id2label

# The three findings this app reports; names are compared against id2label.
target_diseases = ["Pneumonia", "Effusion", "Atelectasis"]
# -----------------------------
# 2. Prediction function
# -----------------------------
# Probabilities strictly above this value are reported as YES.
DECISION_THRESHOLD = 0.5


def scores_from_logits(logits, problem_type=None):
    """Turn one image's raw logits into per-label probabilities.

    ChestX-ray14 is a multi-label task: a single film can show several
    findings at once, so each label gets an independent sigmoid score.
    Softmax is used only when the model config explicitly declares
    single-label classification -- with softmax the scores sum to 1 across
    all classes, so at most one label could ever exceed a 0.5 threshold.

    Args:
        logits: 1-D tensor of raw class scores for a single image.
        problem_type: the model config's ``problem_type`` value (may be None).

    Returns:
        Tensor of the same shape holding probabilities in [0, 1].
    """
    if problem_type == "single_label_classification":
        return F.softmax(logits, dim=-1)
    return torch.sigmoid(logits)


def format_predictions(probs, labels, targets, threshold=DECISION_THRESHOLD):
    """Build a one-line-per-disease YES/NO report.

    Args:
        probs: indexable per-class probabilities (tensor or sequence).
        labels: mapping of class index -> class name (like ``config.id2label``).
        targets: disease names to report, in output order; matched
            case-insensitively against ``labels``.
        threshold: probabilities strictly above this are reported as YES.

    Returns:
        Newline-joined report. A target the model has no class for is listed
        as "not available" instead of being silently dropped.
    """
    index_by_name = {name.casefold(): idx for idx, name in labels.items()}
    lines = []
    for disease in targets:
        idx = index_by_name.get(disease.casefold())
        if idx is None:
            lines.append(f"{disease}: not available in this model")
            continue
        prob = float(probs[idx])
        verdict = "YES" if prob > threshold else "NO"
        lines.append(f"{disease}: {verdict} ({prob:.2f})")
    return "\n".join(lines)


def predict(image):
    """Classify a chest X-ray and report YES/NO for each target disease.

    Args:
        image: PIL image from the Gradio input, or None if nothing was uploaded.

    Returns:
        Plain-text report, one line per entry in ``target_diseases``.
    """
    # Gradio passes None when Submit is pressed without an upload.
    if image is None:
        return "Please upload a chest X-ray image."
    # Let the processor apply the checkpoint's own resize/crop/normalize
    # settings; squashing to 224x224 first distorted the aspect ratio of
    # non-square films and discarded resolution.
    inputs = processor(images=image.convert("RGB"), return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    # Take the single batch item explicitly; squeeze() would also collapse
    # the class axis for a one-class head.
    probs = scores_from_logits(logits[0], model.config.problem_type)
    return format_predictions(probs, id2label, target_diseases)
# -----------------------------
# 3. Gradio interface
# -----------------------------
# One image in, a plain-text report out. type="pil" hands `predict` a
# PIL.Image, which it calls .convert() on.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Chest X-ray: Pneumonia / Effusion / Atelectasis",
    description="Upload a chest X-ray. Model predicts YES/NO with probabilities for Pneumonia, Effusion, and Atelectasis.",
)

# Start the web server only when run as a script (e.g. `python app.py`), so
# importing this module -- for tests or to reuse `predict` -- does not block.
if __name__ == "__main__":
    iface.launch()