File size: 4,205 Bytes
87c4954
 
d10d7b3
87c4954
d10d7b3
87c4954
bc62586
87c4954
 
e30af19
 
 
87c4954
e30af19
87c4954
e30af19
 
 
 
87c4954
 
e30af19
 
 
5761260
 
e30af19
 
 
 
5761260
 
e30af19
5761260
 
 
e30af19
 
5761260
 
 
e30af19
 
5761260
 
 
e30af19
 
5761260
 
 
 
 
e30af19
 
 
 
bc62586
e30af19
 
bc62586
e30af19
 
87c4954
e30af19
 
 
 
 
 
 
 
 
 
 
87c4954
 
 
e30af19
 
 
 
 
 
 
 
 
 
87c4954
 
e30af19
 
 
87c4954
e30af19
 
 
 
 
 
 
 
 
87c4954
 
e30af19
 
 
1699b35
7a15187
 
 
 
 
1699b35
 
 
e30af19
1699b35
e30af19
 
 
1699b35
87c4954
e30af19
 
87c4954
e30af19
87c4954
e30af19
 
 
 
87c4954
 
1699b35
 
 
e30af19
 
 
 
 
1699b35
 
 
 
 
 
 
e30af19
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import os
import time

# =========================
# Image preprocessing
# =========================
# Per-channel normalization with the standard ImageNet statistics used by
# torchvision's pretrained models (the backbone below starts from those).
_imagenet_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Full preprocessing applied to every input image before inference:
# fixed-size resize -> float tensor -> per-channel normalization.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),  # Required for ResNet50
        transforms.ToTensor(),
        _imagenet_normalize,
    ]
)

# =========================
# Model definition
# =========================
class FineTunedResNet(nn.Module):
    """ResNet-50 backbone with a deeper MLP classification head.

    The backbone's final ``fc`` layer is replaced by three
    Linear -> BatchNorm1d -> ReLU -> Dropout(0.5) stages (1024, 512 and 256
    units) followed by a Linear layer that produces ``num_classes`` logits.

    Args:
        num_classes: Number of output logits.
        pretrained: If True (the default, as before), initialize the backbone
            with torchvision's default ImageNet weights, which may require a
            download. Pass False when a complete fine-tuned ``state_dict`` is
            loaded afterwards, because those weights would be overwritten.
    """

    def __init__(self, num_classes=4, pretrained=True):
        super().__init__()
        weights = models.ResNet50_Weights.DEFAULT if pretrained else None
        self.resnet = models.resnet50(weights=weights)

        self.resnet.fc = nn.Sequential(
            nn.Linear(self.resnet.fc.in_features, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.5),

            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),

            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.5),

            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        """Return raw (pre-softmax) class logits for a batch ``x``."""
        # NOTE: the BatchNorm1d layers in the head need batch size > 1 in
        # training mode; inference in this app runs after model.eval().
        return self.resnet(x)

# =========================
# Load model
# =========================
MODEL_PATH = "models/final_fine_tuned_resnet50.pth"

if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model not found: {MODEL_PATH}")

# Output labels, indexed by the classifier's logit position.
CLASSES = ["🦠 COVID", "🫁 Normal", "🦠 Pneumonia", "🦠 TB"]

# load_state_dict is strict by default, so every backbone weight comes from
# the checkpoint; skipping the ImageNet weights avoids a pointless download
# (which would also fail on an offline host).
model = FineTunedResNet(num_classes=len(CLASSES), pretrained=False)
# NOTE(review): on torch>=1.13, weights_only=True would avoid unpickling
# arbitrary objects; left off here for compatibility with older torch.
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
model.eval()
model.to("cpu")

# =========================
# Prediction function
# =========================
def predict(image: Image.Image) -> str:
    """Classify a chest X-ray and format the top predictions as text.

    Args:
        image: PIL image from the Gradio upload widget, or ``None`` when the
            user submits without uploading anything.

    Returns:
        A multi-line string listing up to three classes with their softmax
        probabilities (highest first), followed by the elapsed inference time.
    """
    if image is None:
        return "Please upload a chest X-ray image."

    start = time.perf_counter()

    # The normalization step expects exactly 3 channels; grayscale ("L") or
    # RGBA images would otherwise fail inside transforms.Normalize.
    batch = transform(image.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits, dim=1)[0]
        # Never request more entries than there are classes.
        top_probs, top_idxs = torch.topk(probs, min(3, probs.numel()))

    elapsed = time.perf_counter() - start

    lines = [
        f"{CLASSES[idx]}: {prob:.4f}"
        for prob, idx in zip(top_probs.tolist(), top_idxs.tolist())
    ]
    return (
        "Top Predictions:\n\n"
        + "\n".join(lines)
        + f"\n\n⏱️ Prediction Time: {elapsed:.2f} seconds"
    )

# =========================
# Example images
# =========================
# Sample X-rays listed beneath the upload widget. Gradio expects one inner
# list per example holding one value per input component (here a single
# image path); files are grouped by the sub-folder they live in.
examples = [
    [f"examples/{folder}/{filename}"]
    for folder, filenames in {
        "Pneumonia": (
            "02009view1_frontal.jpg",
            "02055view1_frontal.jpg",
            "03152view1_frontal.jpg",
        ),
        "COVID": (
            "11547_2020_1200_Fig3_HTML-a.png",
            "11547_2020_1200_Fig3_HTML-b.png",
            "11547_2020_1203_Fig1_HTML-b.png",
        ),
        "Normal": (
            "06bc1cfe-23a0-43a4-a01b-dfa10314bbb0.jpg",
            "08ae6c0b-d044-4de2-a410-b3cf8dc65868.jpg",
            "IM-0178-0001.jpeg",
        ),
    }.items()
    for filename in filenames
]

# =========================
# Visualization images
# =========================
# Static performance plots shown on the "Model Performance" tab, in display
# order.
visualization_images = [
    "pictures/1.png",
    "pictures/2.png",
    "pictures/3.png",
    "pictures/4.png",
    "pictures/5.png"
]

def display_visualizations():
    """Load every image in ``visualization_images`` into memory.

    Returns:
        A list of PIL images, one per path, in the same order.

    Raises:
        FileNotFoundError: if one of the paths does not exist.
    """
    images = []
    for path in visualization_images:
        # Image.open is lazy and keeps the file handle open until the pixels
        # are read or the object is garbage-collected. copy() loads the data
        # into an independent image, so the file can be closed right away.
        with Image.open(path) as img:
            images.append(img.copy())
    return images

# =========================
# Gradio interfaces
# =========================
# Tab 1: upload an X-ray (or pick an example) and get the top predictions.
prediction_interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Chest X-ray"),
    outputs=gr.Textbox(label="Prediction Result"),
    title="Lung Disease Detection XVI",
    description="""
Upload a chest X-ray image to detect:
🦠 COVID-19 • 🦠 Pneumonia • 🫁 Normal • 🦠 Tuberculosis
""",
    examples=examples,
    # Do not pre-compute example outputs at startup.
    cache_examples=False,  # IMPORTANT for HF Spaces
)

# Tab 2: no inputs; shows one image panel per entry in visualization_images,
# labelled "Visualization 1", "Visualization 2", ...
visualization_interface = gr.Interface(
    fn=display_visualizations,
    inputs=None,
    outputs=[
        gr.Image(type="pil", label=f"Visualization {number}")
        for number, _ in enumerate(visualization_images, start=1)
    ],
    title="Model Performance Visualizations",
)

# Combine both interfaces into one tabbed app; tab_names pairs with
# interface_list by position.
app = gr.TabbedInterface(
    interface_list=[prediction_interface, visualization_interface],
    tab_names=["Predict", "Model Performance"],
)

# =========================
# Launch (HF Spaces safe)
# =========================
# Start the Gradio server for the tabbed app. This call sits at module top
# level (there is no `if __name__ == "__main__":` guard), so importing this
# module also launches the server.
app.launch()