|
|
|
|
|
import torch |
|
|
import gradio as gr |
|
|
from PIL import Image |
|
|
from transformers import SwinForImageClassification, ViTImageProcessor |
|
|
|
|
|
|
|
|
# Hugging Face checkpoint whose architecture and preprocessing config are reused.
MODEL_NAME = "microsoft/swin-tiny-patch4-window7-224"

# Fine-tuned weights (a state_dict) loaded on top of the MODEL_NAME architecture.
MODEL_PATH = "best_model_swin.pth"

# Output labels; classifier output index i maps to CLASS_NAMES[i].
# NOTE(review): presumably this matches the label order used at training time — confirm.
CLASS_NAMES = ['COVID19', 'NORMAL', 'PNEUMONIA']

# Derived from CLASS_NAMES so the classifier-head size and the label list
# cannot drift apart (a mismatch would break the index -> name lookup).
NUM_CLASSES = len(CLASS_NAMES)

# Inference runs on CPU only.
device = torch.device("cpu")

# Minimum top-class softmax probability for a prediction to be reported as a
# normal class distribution; anything below is flagged as invalid/low confidence.
CONFIDENCE_THRESHOLD = 0.90
|
|
|
|
|
# Preprocessing (resize / normalization settings) taken from the base checkpoint.
processor = ViTImageProcessor.from_pretrained(MODEL_NAME)

# Build the Swin architecture with a NUM_CLASSES-way classification head.
# The base checkpoint's head has a different output size, so
# ignore_mismatched_sizes=True lets it be freshly initialized; those weights are
# then replaced by the fine-tuned state_dict loaded below (strict loading).
model = SwinForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=NUM_CLASSES,
    ignore_mismatched_sizes=True,
)

# The checkpoint is consumed by load_state_dict, i.e. it only needs to hold
# tensors. weights_only=True restricts unpickling to tensor data instead of
# executing arbitrary pickled objects (this is also the default from PyTorch 2.6).
state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)

model.to(device)
# Inference mode: disables dropout and similar training-only behavior.
model.eval()
|
|
|
|
|
|
|
|
def classify_image(input_image: Image.Image):
    """Classify a chest X-ray image for display in a ``gr.Label`` component.

    Args:
        input_image: The uploaded PIL image, or ``None`` when nothing was uploaded.

    Returns:
        * ``"Please upload an image."`` when ``input_image`` is ``None``.
        * A single-entry dict ``{"Invalid Image or Low Confidence (<class>)": p}``
          when the top softmax probability ``p`` is below ``CONFIDENCE_THRESHOLD``.
        * Otherwise a dict mapping every name in ``CLASS_NAMES`` to its softmax
          probability.
    """
    if input_image is None:
        return "Please upload an image."

    # The processor is fed 3-channel images; grayscale/RGBA uploads are converted.
    if input_image.mode == "RGB":
        rgb_image = input_image
    else:
        rgb_image = input_image.convert("RGB")

    encoded = processor(images=rgb_image, return_tensors="pt")
    batch = encoded['pixel_values'].to(device)

    with torch.no_grad():
        logits = model(batch).logits

    # Per-class probabilities for the (single-image) batch.
    probs = torch.softmax(logits, dim=1)

    best_prob, best_index = probs.max(dim=1)
    best_score = best_prob.item()
    best_label = CLASS_NAMES[best_index.item()]

    # Treat uncertain predictions as "not a usable X-ray" rather than a diagnosis.
    if best_score < CONFIDENCE_THRESHOLD:
        return {f"Invalid Image or Low Confidence ({best_label})": best_score}

    return {CLASS_NAMES[idx]: p.item() for idx, p in enumerate(probs[0])}
|
|
|
|
|
|
|
|
# Gradio UI: one PIL image in, a label/confidence widget out.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil", label="Upload Chest X-Ray"),
    # Show every class; tied to CLASS_NAMES instead of a hard-coded count so the
    # widget stays in sync if the label list changes.
    outputs=gr.Label(num_top_classes=len(CLASS_NAMES), label="Predictions"),
    title="Swin Transformer Chest X-Ray Classifier",
    description="Upload an X-ray image to classify it as COVID-19, Normal, or Pneumonia."
)

# Start the web server (runs at import time, as in the original script).
iface.launch()