|
|
|
|
|
import torch
|
|
|
import gradio as gr
|
|
|
from PIL import Image
|
|
|
from transformers import SwinForImageClassification, ViTImageProcessor
|
|
|
|
|
|
|
|
|
# Hugging Face Hub checkpoint providing the Swin-Tiny architecture and the
# image-preprocessing configuration used by `processor`.
MODEL_NAME = "microsoft/swin-tiny-patch4-window7-224"

# Local file holding the fine-tuned weights, loaded as a state dict.
MODEL_PATH = "best_model_swin.pth"

# Output labels, indexed in the same order as the classifier's logits.
# NOTE(review): this order must match the label encoding used at training
# time — confirm against the training script.
CLASS_NAMES = ['COVID19', 'NORMAL', 'PNEUMONIA']

# Derived from CLASS_NAMES (rather than hard-coded) so the size of the
# classification head and the label list can never disagree.
NUM_CLASSES = len(CLASS_NAMES)

# All inference runs on the CPU.
device = torch.device("cpu")
|
|
|
|
|
|
# Image preprocessor configured from the base checkpoint; turns a PIL image
# into the `pixel_values` tensor consumed by the model.
processor = ViTImageProcessor.from_pretrained(MODEL_NAME)

# Build the Swin architecture with a NUM_CLASSES-way classification head.
# ignore_mismatched_sizes=True lets from_pretrained build a freshly sized head
# when the base checkpoint's classifier has a different number of outputs;
# those weights are then replaced by the fine-tuned state dict below.
model = SwinForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=NUM_CLASSES,
    ignore_mismatched_sizes=True,
)

# weights_only=True restricts unpickling to tensors and plain containers, so
# loading the checkpoint cannot execute arbitrary code; a saved state dict
# needs nothing more. It also avoids the FutureWarning newer PyTorch
# releases emit when the flag is omitted.
model.load_state_dict(
    torch.load(MODEL_PATH, map_location=device, weights_only=True)
)

model.to(device)

# Put the model in evaluation mode for inference.
model.eval()
|
|
|
|
|
|
|
|
|
def classify_image(input_image: Image.Image):
    """Classify a chest X-ray and return per-class probabilities.

    Args:
        input_image: The uploaded picture, or ``None`` if no image was
            supplied.

    Returns:
        A dict mapping each entry of ``CLASS_NAMES`` to its softmax
        probability as a Python float, or the prompt string
        ``"Please upload an image."`` when ``input_image`` is ``None``.
    """
    # Nothing to classify: ask the user for an image instead.
    if input_image is None:
        return "Please upload an image."

    # Normalize the colour mode (e.g. grayscale "L" or "RGBA" uploads) to RGB
    # before handing the image to the preprocessor.
    rgb_image = input_image if input_image.mode == "RGB" else input_image.convert("RGB")

    batch = processor(images=rgb_image, return_tensors="pt")
    pixels = batch['pixel_values'].to(device)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        logits = model(pixels).logits

    # Softmax across the class dimension, then report the scores for the
    # first (single) image in the batch.
    scores = torch.nn.functional.softmax(logits, dim=1)[0].tolist()
    return {CLASS_NAMES[idx]: score for idx, score in enumerate(scores)}
|
|
|
|
|
|
|
|
|
# Gradio UI: a single PIL image in, a class-probability label out.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil", label="Upload Chest X-Ray"),
    # Display every class; tied to NUM_CLASSES instead of a literal so the
    # widget stays in sync with CLASS_NAMES and the model head.
    outputs=gr.Label(num_top_classes=NUM_CLASSES, label="Predictions"),
    title="Swin Transformer Chest X-Ray Classifier",
    description="Upload an X-ray image to classify it as COVID-19, Normal, or Pneumonia."
)

# Start the Gradio web server.
iface.launch()