from transformers import AutoModelForImageClassification, AutoImageProcessor import torch import torch.nn.functional as F from PIL import Image import gradio as gr # ----------------------------- # 1. Load the pretrained model # ----------------------------- model_name = "microsoft/resnet-50" # fine-tuned for chest x-ray multi-disease model = AutoModelForImageClassification.from_pretrained(model_name) processor = AutoImageProcessor.from_pretrained(model_name) model.eval() # Example disease list (adjust depending on model config) diseases = ["Pneumonia", "Effusion", "Atelectasis"] # ----------------------------- # 2. Prediction function # ----------------------------- def predict(image): img = image.convert("RGB").resize((224, 224)) inputs = processor(images=img, return_tensors="pt") with torch.no_grad(): logits = model(**inputs).logits probs = F.softmax(logits, dim=1).squeeze() # Get top-3 predictions top_probs, top_idxs = torch.topk(probs, k=3) results = [] for idx, prob in zip(top_idxs, top_probs): disease_name = diseases[idx] if idx < len(diseases) else f"Class {idx.item()}" results.append(f"{disease_name}: {prob.item():.2f}") return "\n".join(results) # ----------------------------- # 3. Gradio interface # ----------------------------- iface = gr.Interface( fn=predict, inputs=gr.Image(type="pil"), outputs="text", title="Chest X-ray Detector", description="