"""Gradio demo: chest X-ray screening for Pneumonia, Effusion and Atelectasis.

Loads a Hugging Face image-classification checkpoint fine-tuned on ChestX-ray14
and serves a web UI that reports, for each target finding, a YES/NO decision
together with the model's probability for that finding.
"""

from transformers import AutoModelForImageClassification, AutoImageProcessor
import torch
import torch.nn.functional as F
from PIL import Image
import gradio as gr

# -----------------------------
# 1. Load pretrained model
# -----------------------------
model_name = "microsoft/resnet-50-finetuned-chestxray14"
model = AutoModelForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)
model.eval()  # inference mode (e.g. dropout disabled)

# Get labels from config (class index -> label name)
id2label = model.config.id2label

# Focus only on 3 diseases; this list also fixes the order of the output lines.
target_diseases = ["Pneumonia", "Effusion", "Atelectasis"]

# Probability strictly above which a finding is reported as "YES".
DECISION_THRESHOLD = 0.5

# Case-insensitive label-name -> class-index lookup, so a target such as
# "Effusion" still matches a checkpoint that spells it "effusion".
_label_to_index = {str(label).casefold(): int(idx) for idx, label in id2label.items()}

# ChestX-ray14 is a multi-label dataset: several findings can be present in the
# same image, so every class needs an independent probability (sigmoid).
# A softmax would make all classes compete for a total mass of 1, and a
# per-class 0.5 threshold would then almost never fire. Softmax is used only if
# the checkpoint explicitly declares itself single-label.
_use_softmax = getattr(model.config, "problem_type", None) == "single_label_classification"


# -----------------------------
# 2. Prediction function
# -----------------------------
def predict(image):
    """Classify a chest X-ray for each disease in ``target_diseases``.

    Args:
        image: PIL image supplied by the Gradio input, or ``None`` when the
            user submits without uploading anything.

    Returns:
        A newline-separated string with one line per target disease, in the
        order of ``target_diseases``: ``"<label>: YES|NO (<prob>)"``. A target
        label that the checkpoint does not define is reported explicitly rather
        than being silently dropped.
    """
    if image is None:
        return "Please upload a chest X-ray image."

    # The processor applies the checkpoint's own resize/crop/normalisation;
    # a manual resize to 224x224 beforehand would distort the aspect ratio.
    inputs = processor(images=image.convert("RGB"), return_tensors="pt")
    with torch.no_grad():
        # Exactly one image per call: take batch element 0 rather than
        # squeeze(), which would also collapse a single-class head to a scalar.
        logits = model(**inputs).logits[0]

    if _use_softmax:
        probs = F.softmax(logits, dim=-1)
    else:
        probs = torch.sigmoid(logits)

    results = []
    for disease in target_diseases:
        idx = _label_to_index.get(disease.casefold())
        if idx is None:
            results.append(f"{disease}: not available in this model")
            continue
        prob = probs[idx].item()
        verdict = "YES" if prob > DECISION_THRESHOLD else "NO"
        results.append(f"{disease}: {verdict} ({prob:.2f})")
    return "\n".join(results)


# -----------------------------
# 3. Gradio interface
# -----------------------------
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Chest X-ray: Pneumonia / Effusion / Atelectasis",
    description="Upload a chest X-ray. Model predicts YES/NO with probabilities for Pneumonia, Effusion, and Atelectasis.",
)

if __name__ == "__main__":
    # Only start the web server when run as a script, not on import.
    iface.launch()