# app.py
"""Gradio demo that classifies a chest X-ray as COVID-19, Normal, or Pneumonia.

A Swin-Tiny image classifier (``microsoft/swin-tiny-patch4-window7-224``) is
built with a ``NUM_CLASSES``-way head, its weights are loaded from the local
state dict ``best_model_swin.pth``, and it is served through a Gradio
interface. Predictions whose top-class probability falls below
``CONFIDENCE_THRESHOLD`` are reported as "Invalid Image or Low Confidence".
"""
from typing import Dict, Union

import gradio as gr
import torch
from PIL import Image
from transformers import SwinForImageClassification, ViTImageProcessor

# --- 1. Load Model & Processor ---
MODEL_NAME = "microsoft/swin-tiny-patch4-window7-224"
MODEL_PATH = "best_model_swin.pth"
# Order must match the label indices used when the weights were trained.
CLASS_NAMES = ["COVID19", "NORMAL", "PNEUMONIA"]
# Derived from CLASS_NAMES so the head size and the label list cannot drift apart.
NUM_CLASSES = len(CLASS_NAMES)
device = torch.device("cpu")

# We will reject any prediction where the model's top guess is below 90% confidence.
CONFIDENCE_THRESHOLD = 0.90

processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
# ignore_mismatched_sizes=True lets the hub checkpoint's classifier head
# (presumably sized for a different number of classes) be replaced by a fresh
# NUM_CLASSES-way head; its weights are then overwritten by the local state dict.
model = SwinForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=NUM_CLASSES,
    ignore_mismatched_sizes=True,
)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code when it is loaded.
state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()


# --- 2. Define Prediction Function ---
def classify_image(input_image: Image.Image) -> Union[str, Dict[str, float]]:
    """Classify one uploaded X-ray image.

    Args:
        input_image: The uploaded image as a PIL image, or ``None`` when the
            user submitted without uploading anything.

    Returns:
        * ``"Please upload an image."`` when ``input_image`` is ``None``.
        * A single-entry dict ``{"Invalid Image or Low Confidence (<class>)":
          p}`` when the top-class probability ``p`` is below
          ``CONFIDENCE_THRESHOLD``.
        * Otherwise a dict mapping every name in ``CLASS_NAMES`` to its
          softmax probability, as consumed by ``gr.Label``.
    """
    if input_image is None:
        return "Please upload an image."

    # Normalise to 3-channel RGB (e.g. grayscale "L" or RGBA uploads) before
    # handing the image to the processor.
    if input_image.mode != "RGB":
        input_image = input_image.convert("RGB")

    inputs = processor(images=input_image, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(device)

    with torch.no_grad():
        outputs = model(pixel_values)
    # Batch of one image: probabilities has one row of NUM_CLASSES values.
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)

    # Get the top class and its confidence score.
    top_confidence, top_idx = torch.max(probabilities, dim=1)
    top_confidence_score = top_confidence.item()
    top_class_name = CLASS_NAMES[top_idx.item()]

    # Below the threshold, report a single warning label instead of a ranking.
    if top_confidence_score < CONFIDENCE_THRESHOLD:
        return {f"Invalid Image or Low Confidence ({top_class_name})": top_confidence_score}

    return {name: prob.item() for name, prob in zip(CLASS_NAMES, probabilities[0])}


# --- 3. Create the Gradio Interface ---
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil", label="Upload Chest X-Ray"),
    outputs=gr.Label(num_top_classes=NUM_CLASSES, label="Predictions"),
    title="Swin Transformer Chest X-Ray Classifier",
    description="Upload an X-ray image to classify it as COVID-19, Normal, or Pneumonia.",
)

# --- 4. Launch the app ---
# Guarded so that importing this module (e.g. from tests) does not start a server.
if __name__ == "__main__":
    iface.launch()