# multi-chex / app.py
# saba2000's picture
# Rename app.py7.txt to app.py
# 3835968 verified
# raw
# history blame
# 1.59 kB
from transformers import AutoModelForImageClassification, AutoImageProcessor
import torch
import torch.nn.functional as F
from PIL import Image
import gradio as gr
# -----------------------------
# 1. Load the pretrained model
# -----------------------------
# NOTE: "microsoft/resnet-50" is the generic ResNet-50 checkpoint pre-trained
# on ImageNet-1k (per its Hugging Face model card); it has NOT been fine-tuned
# on chest X-rays. Swap in a checkpoint trained on chest X-ray data to get
# clinically meaningful labels.
model_name = "microsoft/resnet-50"
model = AutoModelForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)
model.eval()  # inference mode (see torch.nn.Module.eval)
# Human-readable name for every output index, read from the checkpoint's own
# config so names always line up with the logits. A hand-written list such as
# ["Pneumonia", "Effusion", "Atelectasis"] would silently relabel whatever
# classes happen to sit at indices 0..2 of the model head.
diseases = [
    model.config.id2label.get(i, f"Class {i}")
    for i in range(model.config.num_labels)
]
# -----------------------------
# 2. Prediction function
# -----------------------------
def predict(image, top_k=3):
    """Classify an uploaded image and return the top-scoring labels as text.

    Args:
        image: PIL image from the Gradio ``Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.
        top_k: Maximum number of labels to report; clamped to the number of
            classes the model outputs.

    Returns:
        One ``"<label>: <score>"`` line per reported class, highest score
        first, joined with newlines; or a prompt string when no image given.
    """
    if image is None:
        return "Please upload an image."

    # Let the processor do resizing / cropping / normalisation exactly as the
    # checkpoint expects. A manual resize to 224x224 beforehand would distort
    # the aspect ratio before the processor's own resize + center crop.
    inputs = processor(images=image.convert("RGB"), return_tensors="pt")
    with torch.no_grad():
        # One image per call -> take row 0 instead of squeeze(), which would
        # also collapse a single-class head into a 0-dim tensor.
        logits = model(**inputs).logits[0]

    # Multi-label checkpoints score each class independently (sigmoid);
    # single-label checkpoints use a softmax across all classes.
    if getattr(model.config, "problem_type", None) == "multi_label_classification":
        scores = torch.sigmoid(logits)
    else:
        scores = F.softmax(logits, dim=-1)

    # topk raises if k exceeds the number of classes, so clamp it.
    k = min(top_k, scores.numel())
    top_scores, top_idxs = torch.topk(scores, k=k)

    # Use the checkpoint's own index -> name mapping so labels match outputs.
    id2label = model.config.id2label
    lines = [
        f"{id2label.get(idx, f'Class {idx}')}: {score:.2f}"
        for idx, score in zip(top_idxs.tolist(), top_scores.tolist())
    ]
    return "\n".join(lines)
# -----------------------------
# 3. Gradio interface
# -----------------------------
# Single image in, plain-text list of top-scoring labels out.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Chest X-ray Detector",
    # The description must not promise labels the loaded checkpoint does not
    # have; the output names come from the model's own label set.
    description=(
        "Upload a chest X-ray. The app shows the model's top-3 labels with "
        "their scores. Demo only - not for medical diagnosis."
    ),
)

# Start the web server only when executed as a script (e.g. `python app.py`),
# not as a side effect of importing this module.
if __name__ == "__main__":
    iface.launch()