import os
import time

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import models, transforms

# =========================
# Image preprocessing
# =========================
# Resize to 224x224, convert to a [C, H, W] float tensor, and normalize with
# the standard ImageNet per-channel mean/std. The 3-element mean/std means the
# input image must have exactly 3 channels (RGB).
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # standard ResNet50 input size
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


# =========================
# Model definition
# =========================
class FineTunedResNet(nn.Module):
    """ResNet50 backbone whose final layer is replaced by a deeper MLP head.

    Args:
        num_classes: Number of output logits produced by the head.
        pretrained: If True (the default, matching the original behavior),
            initialize the backbone from torchvision's default ImageNet
            weights, which may require a download. Pass False when every
            weight will be replaced via ``load_state_dict`` anyway.
    """

    def __init__(self, num_classes=4, pretrained=True):
        super().__init__()
        weights = models.ResNet50_Weights.DEFAULT if pretrained else None
        self.resnet = models.resnet50(weights=weights)
        # Replace the stock fc layer with an
        # in_features -> 1024 -> 512 -> 256 -> num_classes MLP head.
        self.resnet.fc = nn.Sequential(
            nn.Linear(self.resnet.fc.in_features, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return raw (pre-softmax) class logits for a batch of images."""
        return self.resnet(x)


# =========================
# Load model
# =========================
MODEL_PATH = "models/final_fine_tuned_resnet50.pth"

# Display labels indexed by the model's output position. The order must match
# the label order the checkpoint was trained with.
CLASSES = ["🦠 COVID", "🫁 Normal", "🦠 Pneumonia", "🦠 TB"]

if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model not found: {MODEL_PATH}")

# pretrained=False: the strict load_state_dict below overwrites every parameter
# and buffer, so fetching ImageNet weights first would be wasted startup work.
model = FineTunedResNet(num_classes=len(CLASSES), pretrained=False)
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# If the installed torch supports it (>= 1.13), weights_only=True is a safer
# choice for what is presumably a plain state dict.
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
model.eval()  # inference mode: BatchNorm uses running stats, Dropout is off
model.to("cpu")


# =========================
# Prediction function
# =========================
def predict(image: Image.Image) -> str:
    """Classify a chest X-ray and format the top predictions as text.

    Args:
        image: PIL image from the upload widget. ``None`` (nothing uploaded)
            returns a prompt instead of raising. Any image mode (grayscale,
            RGBA, ...) is converted to RGB to fit the 3-channel normalization.

    Returns:
        Multi-line text listing up to three classes with their softmax
        probabilities, followed by the preprocessing + inference time.
    """
    if image is None:
        return "Please upload a chest X-ray image."

    start = time.perf_counter()  # monotonic clock, intended for intervals
    # Force 3 channels, then add a batch dimension -> [1, 3, 224, 224].
    batch = transform(image.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits, dim=1)[0]
        top_probs, top_idxs = torch.topk(probs, min(3, len(CLASSES)))

    elapsed = time.perf_counter() - start

    rows = "".join(
        f"{CLASSES[idx]} → {prob:.4f}\n"
        for prob, idx in zip(top_probs.tolist(), top_idxs.tolist())
    )
    return f"Top Predictions:\n\n{rows}\n⏱️ Prediction Time: {elapsed:.2f} seconds"
# =========================
# Example images
# =========================
examples = [
    ["examples/Pneumonia/02009view1_frontal.jpg"],
    ["examples/Pneumonia/02055view1_frontal.jpg"],
    ["examples/Pneumonia/03152view1_frontal.jpg"],
    ["examples/COVID/11547_2020_1200_Fig3_HTML-a.png"],
    ["examples/COVID/11547_2020_1200_Fig3_HTML-b.png"],
    ["examples/COVID/11547_2020_1203_Fig1_HTML-b.png"],
    ["examples/Normal/06bc1cfe-23a0-43a4-a01b-dfa10314bbb0.jpg"],
    ["examples/Normal/08ae6c0b-d044-4de2-a410-b3cf8dc65868.jpg"],
    ["examples/Normal/IM-0178-0001.jpeg"],
]

# =========================
# Visualization images
# =========================
visualization_images = [
    "pictures/1.png",
    "pictures/2.png",
    "pictures/3.png",
    "pictures/4.png",
    "pictures/5.png",
]


def display_visualizations():
    """Load every image in ``visualization_images`` and return them in order.

    Returns:
        A list of fully loaded PIL images, one per path. The underlying files
        are already closed when this returns.
    """
    images = []
    for path in visualization_images:
        # Image.open is lazy and holds the file handle until pixels are read;
        # copy() forces the load, and the with-block then closes the file.
        with Image.open(path) as img:
            images.append(img.copy())
    return images


# =========================
# Gradio interfaces
# =========================
prediction_interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Chest X-ray"),
    outputs=gr.Textbox(label="Prediction Result"),
    examples=examples,
    cache_examples=False,  # IMPORTANT for HF Spaces
    title="Lung Disease Detection XVI",
    description="""
Upload a chest X-ray image to detect:
🦠 COVID-19 • 🦠 Pneumonia • 🫁 Normal • 🦠 Tuberculosis
""",
)

# One output image slot per visualization file, labelled from 1.
visualization_interface = gr.Interface(
    fn=display_visualizations,
    inputs=None,
    outputs=[
        gr.Image(type="pil", label=f"Visualization {number}")
        for number in range(1, len(visualization_images) + 1)
    ],
    title="Model Performance Visualizations",
)

app = gr.TabbedInterface(
    interface_list=[prediction_interface, visualization_interface],
    tab_names=["Predict", "Model Performance"],
)

# =========================
# Launch (HF Spaces safe)
# =========================
# Guarded so that importing this module does not start a server as a side
# effect; running this file as a script still launches the app.
if __name__ == "__main__":
    app.launch()