"""Gradio demo that classifies a chest X-ray as pneumonia / no pneumonia.

Uses the pretrained Hugging Face checkpoint
``dima806/chest_xray_pneumonia_detection``.
"""

import gradio as gr
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

# -----------------------------
# 1. Load the pretrained model
# -----------------------------
model_name = "dima806/chest_xray_pneumonia_detection"
model = AutoModelForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)
model.eval()  # evaluation mode: no dropout / batch-norm statistic updates


def _find_pneumonia_index(id2label, default=1):
    """Return the class index whose label is "pneumonia" (case-insensitive).

    Falls back to *default* (the original hard-coded assumption that index 1
    means pneumonia) when no label matches exactly, e.g. when the config only
    carries generic names such as ``LABEL_0`` / ``LABEL_1``.
    """
    for idx, label in (id2label or {}).items():
        if str(label).strip().lower() == "pneumonia":
            return int(idx)
    return default


# Resolve the positive class once from the model's own label mapping rather
# than trusting a fixed index, so a different label order cannot silently
# invert the prediction.
PNEUMONIA_IDX = _find_pneumonia_index(model.config.id2label)


# -----------------------------
# 2. Prediction function
# -----------------------------
def predict(image):
    """Classify a single chest X-ray image.

    Args:
        image: PIL image from the Gradio ``Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        ``"Pneumonia: YES"`` or ``"Pneumonia: NO"``; a short prompt asking
        for an upload when no image was provided.
    """
    if image is None:
        return "Please upload a chest X-ray image."

    # Force 3 channels: uploads may be grayscale or carry an alpha channel,
    # while the processor/model pipeline expects RGB input.
    img = image.convert("RGB")

    # Preprocess image into model-ready tensors.
    inputs = processor(images=img, return_tensors="pt")

    # Forward pass without building an autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_class_idx = logits.argmax(-1).item()
    if predicted_class_idx == PNEUMONIA_IDX:
        return "Pneumonia: YES"
    return "Pneumonia: NO"


# -----------------------------
# 3. Gradio interface
# -----------------------------
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="State-of-the-Art Pneumonia Detector",
    description="Upload a chest X-ray. The model predicts if pneumonia is present (YES/NO).",
)

if __name__ == "__main__":
    iface.launch()