File size: 2,335 Bytes
4ac2dc3
6e41f2b
 
 
 
 
 
 
 
 
 
4ac2dc3
6e41f2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5cfe6ab
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# app.py
import torch
import gradio as gr
from PIL import Image
from transformers import SwinForImageClassification, ViTImageProcessor

# --- 1. Load Model & Processor ---
MODEL_NAME = "microsoft/swin-tiny-patch4-window7-224"  # Hugging Face checkpoint for backbone + processor config
MODEL_PATH = "best_model_swin.pth"  # fine-tuned weights; must contain a state_dict (see load_state_dict below)
CLASS_NAMES = ['COVID19', 'NORMAL', 'PNEUMONIA']
# Derived from CLASS_NAMES so the classifier head size and the label list can
# never drift apart: output logit i is reported as CLASS_NAMES[i].
NUM_CLASSES = len(CLASS_NAMES)
device = torch.device("cpu")

# We will reject any prediction where the model's top guess is below 90% confidence.
CONFIDENCE_THRESHOLD = 0.90

processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
# ignore_mismatched_sizes lets from_pretrained swap the checkpoint's
# classification head for a freshly initialised NUM_CLASSES-way head
# (presumably the original head has a different size). Those random head
# weights are replaced by the fine-tuned ones loaded just below.
model = SwinForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=NUM_CLASSES,
    ignore_mismatched_sizes=True,
)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code when it is loaded.
# load_state_dict is strict by default, so every parameter must be present.
model.load_state_dict(
    torch.load(MODEL_PATH, map_location=device, weights_only=True)
)
model.to(device)
model.eval()  # inference mode for dropout and similar layers

# --- 2. Define Prediction Function ---
def classify_image(input_image: Image.Image):
    """Classify one uploaded image and shape the result for ``gr.Label``.

    Args:
        input_image: Image handed over by the Gradio upload widget, or
            ``None`` when the user submitted without uploading anything.

    Returns:
        - ``"Please upload an image."`` when ``input_image`` is ``None``.
        - ``{"Invalid Image or Low Confidence (<top class>)": p}`` when the
          highest softmax probability ``p`` is below ``CONFIDENCE_THRESHOLD``.
        - Otherwise a dict mapping each class name to its softmax probability.
    """
    # Guard clause: nothing was uploaded.
    if input_image is None:
        return "Please upload an image."

    # Any non-RGB mode is converted to 3-channel RGB before preprocessing.
    rgb_image = input_image if input_image.mode == "RGB" else input_image.convert("RGB")

    encoded = processor(images=rgb_image, return_tensors="pt")
    batch = encoded['pixel_values'].to(device)

    # Pure inference: no gradient tracking needed.
    with torch.no_grad():
        logits = model(batch).logits

    # Softmax across the class dimension turns logits into probabilities.
    probabilities = torch.nn.functional.softmax(logits, dim=1)

    best_prob, best_index = probabilities.max(dim=1)
    best_score = best_prob.item()
    best_label = CLASS_NAMES[best_index.item()]

    if best_score < CONFIDENCE_THRESHOLD:
        # Collapse the output to a single warning entry that still names the
        # model's best guess and its score.
        return {f"Invalid Image or Low Confidence ({best_label})": best_score}

    # Confident enough: report the full per-class distribution.
    return {
        CLASS_NAMES[idx]: score
        for idx, score in enumerate(probabilities[0].tolist())
    }

# --- 3. Create the Gradio Interface ---
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil", label="Upload Chest X-Ray"),
    # Tied to NUM_CLASSES rather than a literal so every class is shown;
    # classify_image returns at most one entry per class.
    outputs=gr.Label(num_top_classes=NUM_CLASSES, label="Predictions"),
    title="Swin Transformer Chest X-Ray Classifier",
    description="Upload an X-ray image to classify it as COVID-19, Normal, or Pneumonia."
)

# --- 4. Launch the app ---
# Guarded so that importing this module (e.g. from tests or another script)
# builds the interface without starting a web server; `python app.py` still
# launches it.
if __name__ == "__main__":
    iface.launch()