File size: 1,557 Bytes
0ddd5ad
39ea0a2
 
 
 
 
fbf1393
0ddd5ad
fbf1393
0ddd5ad
39ea0a2
 
 
 
0ddd5ad
fbf1393
ccfcbd4
0ddd5ad
ccfcbd4
39ea0a2
fbf1393
 
 
39ea0a2
 
 
fbf1393
39ea0a2
 
fbf1393
39ea0a2
fbf1393
39ea0a2
fbf1393
 
0ddd5ad
 
fbf1393
39ea0a2
 
fbf1393
 
 
39ea0a2
 
 
 
ccfcbd4
0ddd5ad
929016b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from transformers import AutoModelForImageClassification, AutoImageProcessor
import torch
import torch.nn.functional as F
from PIL import Image
import gradio as gr

# -----------------------------
# 1. Load pretrained model
# -----------------------------
model_name = "microsoft/resnet-50-finetuned-chestxray14"
model = AutoModelForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)
# Inference only: put layers such as dropout / batch-norm in evaluation mode.
model.eval()

# Class index -> label name, as stored in the checkpoint's config.
id2label = model.config.id2label

# Focus only on 3 diseases. These strings are matched exactly against the
# checkpoint's label names, so a typo or casing difference would otherwise make
# a disease silently disappear from predict()'s output.
target_diseases = ["Pneumonia", "Effusion", "Atelectasis"]

# Fail fast at startup instead of serving an app that quietly omits findings.
_missing_targets = sorted(set(target_diseases) - set(id2label.values()))
if _missing_targets:
    raise ValueError(
        f"Target labels {_missing_targets} not found in the config of "
        f"{model_name}; available labels: {sorted(id2label.values())}"
    )

# -----------------------------
# 2. Prediction function
# -----------------------------
def predict(image, *, threshold: float = 0.5) -> str:
    """Classify a chest X-ray and report each target disease as YES/NO.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.
        threshold: Probability strictly above which a finding is reported as
            YES. Keyword-only so Gradio still calls ``predict(image)``.

    Returns:
        One line per target disease, such as ``"Effusion: YES (0.73)"``, in the
        order the labels appear in ``id2label``. If ``image`` is ``None``,
        returns a prompt asking the user to upload an image.
    """
    # Gradio passes None when Submit is pressed with no upload.
    if image is None:
        return "Please upload a chest X-ray image."

    # X-rays are often single-channel; convert to the RGB input the processor
    # expects. Resizing/cropping/normalisation are left to the processor so
    # they follow the checkpoint's own preprocessing config rather than first
    # squashing the image to a fixed square.
    inputs = processor(images=image.convert("RGB"), return_tensors="pt")

    with torch.no_grad():
        logits = model(**inputs).logits

    # ChestX-ray14 findings are not mutually exclusive (one film can show
    # several), so each label gets an independent sigmoid probability.
    # Softmax would force all class scores to sum to 1, so at most one label
    # could ever clear a 0.5 threshold. Only honour softmax if the checkpoint
    # explicitly declares itself single-label.
    if model.config.problem_type == "single_label_classification":
        probs = torch.softmax(logits, dim=-1)[0]
    else:
        probs = torch.sigmoid(logits)[0]  # [0]: the single image in the batch

    results = []
    for idx, label in id2label.items():
        if label in target_diseases:
            prob = probs[idx].item()
            verdict = "YES" if prob > threshold else "NO"
            results.append(f"{label}: {verdict} ({prob:.2f})")

    return "\n".join(results)

# -----------------------------
# 3. Gradio interface
# -----------------------------
# Single image in, plain text out; predict() receives the upload as a PIL image.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Chest X-ray: Pneumonia / Effusion / Atelectasis",
    description="Upload a chest X-ray. Model predicts YES/NO with probabilities for Pneumonia, Effusion, and Atelectasis."
)

# Only start the (blocking) web server when run as a script, so the module can
# be imported, e.g. to reuse predict() or iface, without launching the app.
if __name__ == "__main__":
    iface.launch()