multi-chex / app.py
saba2000's picture
Update app.py
ccfcbd4 verified
raw
history blame
1.5 kB
from transformers import AutoModelForImageClassification, AutoImageProcessor
import torch
import torch.nn.functional as F
from PIL import Image
import gradio as gr
# Hub checkpoint id shared by the classifier and its image processor.
model_name = "Lucario-K17/biomedclip_radiology_diagnosis"
processor = AutoImageProcessor.from_pretrained(model_name)
# Module.eval() returns the module itself, so the load and the switch to
# inference mode can be chained into a single assignment.
model = AutoModelForImageClassification.from_pretrained(model_name).eval()
# The 14 disease label names. A label's position in this list is the index
# that predict() uses to read that label's score from the model output.
all_diseases = [
    "Atelectasis",
    "Cardiomegaly",
    "Effusion",
    "Infiltration",
    "Mass",
    "Nodule",
    "Pneumonia",
    "Pneumothorax",
    "Consolidation",
    "Edema",
    "Emphysema",
    "Fibrosis",
    "Pleural_Thickening",
    "Hernia",
]

# Subset of labels reported to the user, in display order.
target_diseases = ["Pneumonia", "Effusion", "Atelectasis"]
# Positions of the reported labels within all_diseases, same order as above.
target_idxs = list(map(all_diseases.index, target_diseases))
def predict(image):
    """Score one chest X-ray for each disease in ``target_diseases``.

    Args:
        image: PIL image handed over by the Gradio image input, or ``None``
            when the form is submitted without an upload.

    Returns:
        One ``"<disease>: <score>"`` line per target disease (two decimal
        places), joined with newlines; or a short prompt string when no
        image was supplied.
    """
    # Gradio passes None when nothing was uploaded; answer with a prompt
    # instead of failing on ``None.convert``.
    if image is None:
        return "Please upload a chest X-ray image."

    img = image.convert("RGB").resize((224, 224))
    inputs = processor(images=img, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    # Exactly one image is processed, so take batch row 0 explicitly. A bare
    # squeeze() would also drop the class axis if it happened to have size 1.
    row = logits[0]

    # A checkpoint whose config declares multi-label classification has
    # independent per-class outputs, so sigmoid is the matching activation;
    # every other case keeps the softmax over classes.
    if getattr(model.config, "problem_type", None) == "multi_label_classification":
        probs = torch.sigmoid(row)
    else:
        probs = F.softmax(row, dim=-1)

    return "\n".join(
        f"{d}: {probs[i].item():.2f}"
        for i, d in zip(target_idxs, target_diseases)
    )
# Single-image, text-output web UI around predict(); launch() starts
# serving it as soon as this module is executed.
iface = gr.Interface(
    title="Chest X-ray: Pneumonia / Effusion / Atelectasis",
    description=(
        "Upload a chest X-ray. "
        "Model predicts probability for Pneumonia, Effusion, and Atelectasis."
    ),
    inputs=gr.Image(type="pil"),
    outputs="text",
    fn=predict,
)
iface.launch()