# chexnet-demo / app.py
# Author: saba2000
# Commit c7efd22 (verified): "Rename app.py3.txt to app.py"
from transformers import AutoModelForImageClassification, AutoImageProcessor
import torch
from PIL import Image
import gradio as gr
# -----------------------------
# 1. Load the pretrained model
# -----------------------------
# Classifier weights and the matching image preprocessor both come from the
# same Hugging Face Hub repository.
model_name = "dima806/chest_xray_pneumonia_detection"
# nn.Module.eval() switches to evaluation mode and returns the module itself,
# so it can be chained directly onto the load call.
model = AutoModelForImageClassification.from_pretrained(model_name).eval()
processor = AutoImageProcessor.from_pretrained(model_name)
# -----------------------------
# 2. Prediction function
# -----------------------------
def _is_pneumonia(class_idx, id2label=None):
    """Decide whether a predicted class index corresponds to pneumonia.

    Prefers the checkpoint's own label names (``config.id2label``) so the
    answer does not silently depend on the order in which classes were indexed.
    Falls back to the assumption index 1 = pneumonia when the labels are
    missing or generic (e.g. ``"LABEL_1"``).

    Args:
        class_idx: Integer class index produced by ``argmax`` over the logits.
        id2label: Optional mapping from class index to label name.

    Returns:
        True if the class means pneumonia, False otherwise.
    """
    label = str((id2label or {}).get(class_idx, "")).strip().lower()
    if label == "pneumonia":
        return True
    if label == "normal":
        return False
    # Unknown/generic label names: keep the original hard-coded convention.
    return class_idx == 1


def predict(image):
    """Classify a chest X-ray as showing pneumonia or not.

    Uses the module-level ``model`` and ``processor``.

    Args:
        image: A ``PIL.Image`` from the Gradio image component, or ``None``
            when the user submits without uploading an image.

    Returns:
        ``"Pneumonia: YES"`` or ``"Pneumonia: NO"``; a prompt asking for an
        upload when ``image`` is ``None``.
    """
    if image is None:
        # Gradio passes None for an empty input; calling .convert would crash.
        return "Please upload a chest X-ray image."
    # Ensure image is in RGB (some models require 3 channels).
    img = image.convert("RGB")
    inputs = processor(images=img, return_tensors="pt")
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    id2label = getattr(model.config, "id2label", None)
    if _is_pneumonia(predicted_class_idx, id2label):
        return "Pneumonia: YES"
    return "Pneumonia: NO"
# -----------------------------
# 3. Gradio interface
# -----------------------------
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="State-of-the-Art Pneumonia Detector",
    description=(
        "Upload a chest X-ray. The model predicts if pneumonia is present (YES/NO). "
        "Research demo only - not a medical diagnosis."
    ),
)

# Start the web server only when executed as a script, so importing this
# module (e.g. from tests or another app) does not block on launch().
if __name__ == "__main__":
    iface.launch()