| | import gradio as gr |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torchvision import transforms, models |
| | from PIL import Image |
| | import os |
| | import time |
| |
|
| | |
| | |
| | |
# Preprocessing pipeline applied to every uploaded image: resize to the
# 224x224 input size expected by ResNet-50, convert to a float tensor,
# and normalize with the standard ImageNet channel statistics.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
| |
|
| | |
| | |
| | |
class FineTunedResNet(nn.Module):
    """ResNet-50 backbone with a custom 4-way classification head.

    The ImageNet-pretrained convolutional trunk is kept as-is; only the
    final fully connected layer is replaced by a deeper MLP head with
    batch normalization and dropout for regularization.
    """

    def __init__(self, num_classes=4):
        super().__init__()
        # Start from ImageNet-pretrained weights.
        self.resnet = models.resnet50(
            weights=models.ResNet50_Weights.DEFAULT
        )

        # Build the replacement head: three hidden blocks of
        # Linear -> BatchNorm -> ReLU -> Dropout, then the final logits layer.
        head = []
        in_features = self.resnet.fc.in_features
        for width in (1024, 512, 256):
            head += [
                nn.Linear(in_features, width),
                nn.BatchNorm1d(width),
                nn.ReLU(),
                nn.Dropout(0.5),
            ]
            in_features = width
        head.append(nn.Linear(in_features, num_classes))
        self.resnet.fc = nn.Sequential(*head)

    def forward(self, x):
        """Run a batch of images through the full network."""
        return self.resnet(x)
| |
|
| | |
| | |
| | |
# Path to the fine-tuned checkpoint (a state_dict saved via torch.save).
MODEL_PATH = "models/final_fine_tuned_resnet50.pth"

if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model not found: {MODEL_PATH}")

model = FineTunedResNet(num_classes=4)
# weights_only=True restricts unpickling to tensor data, guarding against
# arbitrary-code execution from a tampered checkpoint file (a state_dict
# contains only tensors, so this is safe here).
model.load_state_dict(
    torch.load(MODEL_PATH, map_location="cpu", weights_only=True)
)
model.eval()  # inference mode: disables dropout and batch-norm updates
model.to("cpu")

# Class labels in the order of the model's output logits.
CLASSES = ["🦠 COVID", "🫁 Normal", "🦠 Pneumonia", "🦠 TB"]
| |
|
| | |
| | |
| | |
def predict(image: Image.Image) -> str:
    """Classify a chest X-ray image and return the top-3 predictions.

    Args:
        image: The image uploaded through the Gradio UI.

    Returns:
        A formatted string listing the three most probable classes with
        their softmax probabilities, plus the elapsed inference time.
    """
    start = time.time()

    # X-rays are frequently grayscale (and some PNGs carry an alpha
    # channel); force 3 channels so Normalize's 3-channel mean/std
    # does not fail on a 1- or 4-channel tensor.
    batch = transform(image.convert("RGB")).unsqueeze(0)  # add batch dim

    with torch.no_grad():  # inference only — skip autograd bookkeeping
        output = model(batch)
        probs = F.softmax(output, dim=1)[0]
        top_probs, top_idxs = torch.topk(probs, 3)

    elapsed = time.time() - start

    result = "Top Predictions:\n\n"
    for prob, idx in zip(top_probs, top_idxs):
        result += f"{CLASSES[idx]} → {prob.item():.4f}\n"

    result += f"\n⏱️ Prediction Time: {elapsed:.2f} seconds"
    return result
| |
|
| | |
| | |
| | |
# Bundled example X-rays shown below the upload widget, grouped by class.
_example_files = {
    "Pneumonia": [
        "02009view1_frontal.jpg",
        "02055view1_frontal.jpg",
        "03152view1_frontal.jpg",
    ],
    "COVID": [
        "11547_2020_1200_Fig3_HTML-a.png",
        "11547_2020_1200_Fig3_HTML-b.png",
        "11547_2020_1203_Fig1_HTML-b.png",
    ],
    "Normal": [
        "06bc1cfe-23a0-43a4-a01b-dfa10314bbb0.jpg",
        "08ae6c0b-d044-4de2-a410-b3cf8dc65868.jpg",
        "IM-0178-0001.jpeg",
    ],
}
# Gradio expects a list of [input] rows; flatten in declaration order.
examples = [
    [f"examples/{folder}/{filename}"]
    for folder, filenames in _example_files.items()
    for filename in filenames
]
| |
|
| | |
| | |
| | |
# Static performance plots shown in the second tab, in display order
# (pictures/1.png through pictures/5.png).
visualization_images = [f"pictures/{i}.png" for i in range(1, 6)]
| |
|
def display_visualizations():
    """Load the saved performance plots and return them as PIL images."""
    return list(map(Image.open, visualization_images))
| |
|
| | |
| | |
| | |
# Tab 1 — interactive classifier: a PIL image in, a formatted text
# prediction (top-3 classes + timing) out via `predict`.
prediction_interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Chest X-ray"),
    outputs=gr.Textbox(label="Prediction Result"),
    examples=examples,
    cache_examples=False,  # examples are re-run on click, not precomputed
    title="Lung Disease Detection XVI",
    description="""
    Upload a chest X-ray image to detect:
    🦠 COVID-19 • 🦠 Pneumonia • 🫁 Normal • 🦠 Tuberculosis
    """
)
| |
|
# Tab 2 — static gallery: one gr.Image output per saved performance plot,
# populated by `display_visualizations` (no inputs).
visualization_interface = gr.Interface(
    fn=display_visualizations,
    inputs=None,
    outputs=[
        gr.Image(type="pil", label=f"Visualization {i+1}")
        for i in range(len(visualization_images))
    ],
    title="Model Performance Visualizations"
)
| |
|
# Combine both interfaces into a tabbed app and start the Gradio server.
app = gr.TabbedInterface(
    interface_list=[prediction_interface, visualization_interface],
    tab_names=["Predict", "Model Performance"]
)

# NOTE(review): launch() at module top level blocks; fine for a demo script,
# but consider an `if __name__ == "__main__":` guard if this module is
# ever imported.
app.launch()
| |
|