import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import os
import time
# =========================
# Image preprocessing
# =========================
# ImageNet channel statistics — required because the backbone ships with
# ImageNet-pretrained weights (ResNet50_Weights.DEFAULT).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline applied to every uploaded X-ray before inference.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # ResNet-50 expects 224x224 input
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])
# =========================
# Model definition
# =========================
class FineTunedResNet(nn.Module):
    """ResNet-50 backbone whose final FC layer is replaced by a deeper
    classification head (1024 -> 512 -> 256 -> num_classes) with
    BatchNorm, ReLU and Dropout(0.5) between the linear layers."""

    def __init__(self, num_classes=4):
        super().__init__()
        backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        # Rebuild the classifier head: three shrinking Linear/BN/ReLU/Dropout
        # stages followed by the final projection to `num_classes` logits.
        head = []
        width_in = backbone.fc.in_features
        for width_out in (1024, 512, 256):
            head.extend([
                nn.Linear(width_in, width_out),
                nn.BatchNorm1d(width_out),
                nn.ReLU(),
                nn.Dropout(0.5),
            ])
            width_in = width_out
        head.append(nn.Linear(width_in, num_classes))
        backbone.fc = nn.Sequential(*head)
        self.resnet = backbone

    def forward(self, x):
        """Run the full backbone + custom head; returns raw logits."""
        return self.resnet(x)
# =========================
# Load model
# =========================
# Path to the fine-tuned checkpoint; fail fast with a clear error if the
# weights were not bundled with the app.
MODEL_PATH = "models/final_fine_tuned_resnet50.pth"
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model not found: {MODEL_PATH}")

# Build the architecture, then restore the trained weights on CPU
# (Spaces CPU hardware — no GPU assumed).
model = FineTunedResNet(num_classes=4)
state_dict = torch.load(MODEL_PATH, map_location="cpu")
model.load_state_dict(state_dict)
model.to("cpu")
model.eval()  # inference mode: freezes BatchNorm stats, disables Dropout

# Index order must match the label order used at training time.
CLASSES = ["🦠 COVID", "🫁 Normal", "🦠 Pneumonia", "🦠 TB"]
# =========================
# Prediction function
# =========================
def predict(image: Image.Image) -> str:
    """Classify a chest X-ray and format the top-3 predictions.

    Args:
        image: Uploaded X-ray as a PIL image (any mode), or None when the
            user submits without selecting a file.

    Returns:
        A human-readable string listing the three most likely classes with
        their softmax probabilities, plus the wall-clock inference time.
    """
    if image is None:
        # Gradio passes None when no file was uploaded; return a friendly
        # message instead of crashing inside the transform pipeline.
        return "⚠️ Please upload an image first."
    start = time.time()
    # Chest X-rays are frequently grayscale ("L" mode) or RGBA; Normalize
    # expects exactly 3 channels, so force RGB before preprocessing.
    batch = transform(image.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits, dim=1)[0]
    top_probs, top_idxs = torch.topk(probs, 3)
    elapsed = time.time() - start
    result = "Top Predictions:\n\n"
    for prob, idx in zip(top_probs, top_idxs):
        result += f"{CLASSES[idx]} → {prob.item():.4f}\n"
    result += f"\n⏱️ Prediction Time: {elapsed:.2f} seconds"
    return result
# =========================
# Example images
# =========================
# Sample X-rays bundled with the app, grouped by class folder; Gradio's
# `examples` expects one [path] row per example.
_EXAMPLE_FILES = {
    "Pneumonia": [
        "02009view1_frontal.jpg",
        "02055view1_frontal.jpg",
        "03152view1_frontal.jpg",
    ],
    "COVID": [
        "11547_2020_1200_Fig3_HTML-a.png",
        "11547_2020_1200_Fig3_HTML-b.png",
        "11547_2020_1203_Fig1_HTML-b.png",
    ],
    "Normal": [
        "06bc1cfe-23a0-43a4-a01b-dfa10314bbb0.jpg",
        "08ae6c0b-d044-4de2-a410-b3cf8dc65868.jpg",
        "IM-0178-0001.jpeg",
    ],
}
examples = [
    [f"examples/{folder}/{fname}"]
    for folder, fnames in _EXAMPLE_FILES.items()
    for fname in fnames
]
# =========================
# Visualization images
# =========================
# Pre-rendered performance charts shipped with the app: pictures/1.png
# through pictures/5.png.
visualization_images = [f"pictures/{idx}.png" for idx in range(1, 6)]
def display_visualizations():
    """Open every bundled performance chart and return them as PIL images."""
    return list(map(Image.open, visualization_images))
# =========================
# Gradio interfaces
# =========================
# Description shown above the upload widget on the Predict tab.
_PREDICT_DESCRIPTION = """
Upload a chest X-ray image to detect:
🦠 COVID-19 • 🦠 Pneumonia • 🫁 Normal • 🦠 Tuberculosis
"""

# Tab 1: upload an X-ray, get the formatted top-3 prediction text.
prediction_interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Chest X-ray"),
    outputs=gr.Textbox(label="Prediction Result"),
    examples=examples,
    cache_examples=False,  # IMPORTANT for HF Spaces
    title="Lung Disease Detection XVI",
    description=_PREDICT_DESCRIPTION,
)

# Tab 2: static gallery of model-performance charts, one output slot per image.
_viz_outputs = [
    gr.Image(type="pil", label=f"Visualization {n}")
    for n in range(1, len(visualization_images) + 1)
]
visualization_interface = gr.Interface(
    fn=display_visualizations,
    inputs=None,
    outputs=_viz_outputs,
    title="Model Performance Visualizations",
)

# Combine both interfaces into a tabbed app.
app = gr.TabbedInterface(
    interface_list=[prediction_interface, visualization_interface],
    tab_names=["Predict", "Model Performance"],
)

# Launch without custom server args so HF Spaces can manage host/port.
app.launch()