File size: 11,583 Bytes
6a5be80
 
 
 
 
 
 
97c84ef
6a5be80
97c84ef
 
 
6a5be80
 
 
 
 
97c84ef
28649a0
6a5be80
 
177b675
6a5be80
 
177b675
28649a0
 
 
6a5be80
b925d8f
97c84ef
6a5be80
97c84ef
 
 
177b675
6a5be80
 
 
 
97c84ef
 
 
 
 
 
6a5be80
97c84ef
 
 
 
6a5be80
97c84ef
6a5be80
97c84ef
 
 
 
 
6a5be80
177b675
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a5be80
 
 
 
 
 
177b675
 
332cebc
177b675
332cebc
97c84ef
 
177b675
6a5be80
 
 
 
ddfabeb
97c84ef
6a5be80
97c84ef
177b675
27c50f4
6a5be80
177b675
97c84ef
6a5be80
177b675
6a5be80
b925d8f
6a5be80
 
28649a0
177b675
6a5be80
b925d8f
6a5be80
97c84ef
6a5be80
 
97c84ef
b925d8f
 
466431f
 
b925d8f
466431f
b925d8f
 
466431f
 
 
b925d8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466431f
 
 
b925d8f
28649a0
f519c0b
28649a0
b925d8f
97c84ef
f519c0b
6a5be80
f519c0b
 
 
 
97c84ef
f519c0b
 
 
 
97c84ef
f519c0b
6a5be80
 
f519c0b
6a5be80
b925d8f
6a5be80
 
f519c0b
6a5be80
f519c0b
6a5be80
 
b925d8f
6a5be80
f519c0b
6a5be80
 
 
 
f519c0b
6a5be80
 
b925d8f
f519c0b
6a5be80
 
f519c0b
6a5be80
f519c0b
6a5be80
 
b925d8f
6a5be80
 
f519c0b
6a5be80
 
 
 
f519c0b
6a5be80
 
 
b925d8f
6a5be80
f519c0b
b925d8f
6a5be80
b925d8f
 
6a5be80
 
f519c0b
6a5be80
 
 
b925d8f
 
 
 
6a5be80
 
f519c0b
6a5be80
 
 
 
 
 
 
f519c0b
 
6a5be80
 
 
 
 
 
f519c0b
6a5be80
 
 
 
 
 
 
 
f519c0b
6a5be80
 
 
 
 
f519c0b
 
6a5be80
 
 
 
 
 
f519c0b
6a5be80
 
 
f519c0b
6a5be80
 
 
 
f519c0b
6a5be80
 
 
 
 
b925d8f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
import sys
import os
import time
import cv2
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
from sklearn.preprocessing import label_binarize
import streamlit as st
import matplotlib.pyplot as plt
from fpdf import FPDF

# ---- Streamlit State Initialization ----
# Streamlit re-executes this whole script on every widget interaction, so
# control flags live in st.session_state and are defaulted only on first run.
if 'stop_eval' not in st.session_state:
    st.session_state.stop_eval = False  # user pressed Stop; eval loop should break

if 'evaluation_done' not in st.session_state:
    st.session_state.evaluation_done = False  # a full evaluation has completed

if 'trigger_eval' not in st.session_state:
    st.session_state.trigger_eval = False  # evaluation should run on this pass

# ---- Streamlit Title ----
st.markdown("<h2 style='color: #2E86C1;'>πŸ“ˆ Model Evaluation</h2>", unsafe_allow_html=True)

# ---- Class Names & Label Mapping ----
# Ordered DR severity grades; a class's index in this list is its model label id.
class_names = ['No_DR', 'Mild', 'Moderate', 'Severe', 'Proliferative_DR']
label_map = {label: idx for idx, label in enumerate(class_names)}

# ---- Text Cleaning Function for PDF ----
def clean_text(text):
    """Drop characters that FPDF's built-in Latin-1 fonts cannot render.

    The previous UTF-8 round-trip was a no-op: every Python str encodes
    losslessly to UTF-8, so nothing was ever removed. Encoding to Latin-1
    with errors='ignore' actually strips unsupported characters (e.g.
    emoji) before the text reaches FPDF's core fonts.

    Args:
        text: Arbitrary string destined for a PDF cell.

    Returns:
        The string with all non-Latin-1 characters removed.
    """
    return text.encode('latin-1', 'ignore').decode('latin-1')

# ---- Preprocessing Functions ----
def apply_median_filter(image, ksize=5):
    """Suppress salt-and-pepper noise with a median blur.

    Args:
        image: uint8 image array (any channel layout cv2 accepts).
        ksize: Odd aperture size for the median filter. Defaults to 5,
            the value previously hard-coded, so existing callers are
            unaffected.

    Returns:
        The median-filtered image, same shape and dtype as the input.
    """
    return cv2.medianBlur(image, ksize)

def apply_clahe(image):
    """Boost local contrast with CLAHE on the lightness channel only.

    The RGB image is moved to LAB space so that contrast-limited adaptive
    histogram equalization touches luminance (L) while leaving the color
    channels (A, B) intact, then converted back to RGB.
    """
    lab_img = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
    lightness, chan_a, chan_b = cv2.split(lab_img)
    equalized = cv2.createCLAHE(clipLimit=2.0).apply(lightness)
    recombined = cv2.merge((equalized, chan_a, chan_b))
    return cv2.cvtColor(recombined, cv2.COLOR_LAB2RGB)

def apply_gamma_correction(image, gamma=1.2):
    """Apply gamma correction to a uint8 image via a 256-entry lookup table.

    Args:
        image: uint8 image array.
        gamma: Gamma value; > 1 brightens mid-tones (default 1.2).

    Returns:
        Gamma-corrected uint8 image of the same shape.
    """
    inv_gamma = 1.0 / gamma
    # Build the LUT in one vectorized numpy expression instead of a
    # per-value Python list comprehension; values are identical
    # ((i/255)**invGamma * 255, truncated to uint8).
    table = ((np.arange(256) / 255.0) ** inv_gamma * 255).astype("uint8")
    return cv2.LUT(image, table)

def apply_gaussian_filter(image, kernel_size=(5, 5), sigma=1.0):
    """Smooth the image with a Gaussian blur of the given kernel size and sigma."""
    smoothed = cv2.GaussianBlur(image, kernel_size, sigma)
    return smoothed

# ---- Custom Dataset ----
class DDRDataset(Dataset):
    """Retinal-image dataset driven by a CSV of file paths and labels.

    The CSV must contain a 'new_path' column (image file path) and a
    'label' column (integer class id). Each image is read with OpenCV,
    pushed through the denoising/contrast preprocessing pipeline, and
    optionally transformed (e.g. resize + normalize) before being
    returned with its label tensor.
    """

    def __init__(self, csv_path, transform=None):
        """Load the CSV index; no image I/O happens until __getitem__."""
        self.data = pd.read_csv(csv_path)
        self.image_paths = self.data['new_path'].tolist()
        self.labels = self.data['label'].tolist()
        self.transform = transform

    def __len__(self):
        """Number of samples listed in the CSV."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return (preprocessed image, label tensor) for sample `idx`."""
        img_path = self.image_paths[idx]
        label = int(self.labels[idx])

        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread silently returns None for missing or unreadable
            # files; fail loudly here instead of with a cryptic cvtColor error.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Apply preprocessing (denoise, local contrast, gamma, smoothing)
        image = apply_median_filter(image)
        image = apply_clahe(image)
        image = apply_gamma_correction(image)
        image = apply_gaussian_filter(image)

        image = Image.fromarray(image)
        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.long)

# ---- Image Transforms ----
# Test-time preprocessing: resize to DenseNet-121's 224x224 input, convert
# to a float tensor, and normalize with ImageNet statistics (matching the
# distribution the pretrained backbone was trained on).
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# ---- Load Data (with caching) ----
@st.cache_resource
def load_test_data(csv_path):
    """Build the test DataLoader once; Streamlit reruns reuse the cached object."""
    test_set = DDRDataset(csv_path=csv_path, transform=val_transform)
    loader = DataLoader(test_set, batch_size=32, shuffle=False)
    return loader

# ---- Load Model (with caching) ----
@st.cache_resource
def load_model():
    """Load the fine-tuned DenseNet-121 checkpoint for 5-class DR grading.

    The classifier head is replaced to output one logit per DR class, the
    checkpoint is loaded onto CPU, and the network is put in eval mode.
    Cached so Streamlit reruns do not reload the weights from disk.
    """
    net = models.densenet121(pretrained=False)
    net.classifier = nn.Linear(net.classifier.in_features, len(class_names))
    state = torch.load("./Model/Pretrained_Densenet-121.pth", map_location=torch.device('cpu'))
    net.load_state_dict(state)
    net.eval()
    return net

# ---- Main UI Buttons ----
# Remote CSV listing the test images' paths and ground-truth labels.
csv_path = "https://huggingface.co/datasets/Ci-Dave/DDR_dataset_train_test/raw/main/splits/test_labels.csv"
model = load_model()
test_loader = load_test_data(csv_path)

col1, col2 = st.columns([1, 1])

with col1:
    # Start: clear flags from any previous run and arm the evaluation trigger.
    if st.button("πŸš€ Start Evaluation"):
        st.session_state.stop_eval = False
        st.session_state.evaluation_done = False
        st.session_state.trigger_eval = True

with col2:
    # Stop: cooperative flag, polled once per batch inside the eval loop.
    if st.button("🚩 Stop Evaluation"):
        st.session_state.stop_eval = True

if st.session_state.evaluation_done:
    # NOTE(review): these columns are created but nothing is ever rendered
    # into them here; the download button later builds its own columns.
    # This looks like dead layout code — confirm before removing.
    reevaluate_col, download_col = st.columns([1, 1])

    # ---- Description for Model Evaluation ----
# Always-visible (expanded=True) explainer panel describing what the
# evaluation does and how to use the Start/Stop/Download controls.
with st.expander("ℹ️ **What is Model Evaluation?**", expanded=True):
    st.markdown("""
    <div style='font-size:16px;'>
    The **Model Evaluation** section tests how well the trained AI model performs on the unseen <strong>test set</strong> of retinal images. This provides insights into the reliability and performance of the model when deployed in real scenarios.

    #### πŸ” What It Does:
    - Loads the test dataset of labeled retinal images
    - Runs the model to predict labels
    - Compares predictions vs. true labels
    - Computes:
        - πŸ“‹ **Classification Report** (Precision, Recall, F1-Score)
        - 🧊 **Confusion Matrix**
        - πŸ“ˆ **Multi-class ROC Curve**
        - ❌ **Misclassified Image Samples**
    - Saves the full report as a downloadable PDF

    #### 🧭 How to Use:
    1. Click **πŸš€ Start Evaluation** to begin analyzing the model’s performance.
    2. Wait for the evaluation to finish (shows progress bar and batch updates).
    3. Once done:
        - Check performance scores for each DR class
        - View visual summaries like confusion matrix and ROC curve
        - See the top 5 misclassified examples
    4. Optionally, download the full evaluation report via **πŸ“„ Download PDF**

    ⚠️ <i>Note: This evaluation runs on the full test set and might take several seconds depending on hardware.</i>
    </div>
    """, unsafe_allow_html=True)


# ---- Evaluation Logic ----
# Check if evaluation should be triggered
if st.session_state.trigger_eval:
    st.markdown("### ⏱️ Evaluation Results")

    # Start timing the evaluation
    start_time = time.time()
    y_true = []  # Ground truth labels
    y_pred = []  # Predicted labels
    y_score = []  # Raw model outputs
    misclassified_images = []  # List to store misclassified samples

    total_batches = len(test_loader)  # Total number of batches
    progress_bar = st.progress(0)  # Initialize progress bar
    status_text = st.empty()  # Placeholder for status updates
    stop_info = st.empty()  # Placeholder for stop message

    # Disable gradient calculation for faster evaluation
    with torch.no_grad():
        for i, (images, labels) in enumerate(test_loader):
            # Allow user to stop the evaluation
            if st.session_state.stop_eval:
                stop_info.warning("🚩 Evaluation stopped by user.")
                # Reset the trigger, otherwise the next Streamlit rerun
                # (any widget interaction) silently restarts the evaluation.
                st.session_state.trigger_eval = False
                break

            # Run model on input images
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)  # Get predicted class
            y_true.extend(labels.numpy())
            y_pred.extend(predicted.numpy())
            y_score.extend(outputs.detach().numpy())

            # Store misclassified samples
            for j in range(len(labels)):
                if predicted[j] != labels[j]:
                    misclassified_images.append((images[j], predicted[j].item(), labels[j].item()))

            # Update progress bar and status text
            percent_complete = (i + 1) / total_batches
            progress_bar.progress(min(percent_complete, 1.0))
            status_text.text(f"Evaluating on Test Set: {int(percent_complete * 100)}% | Batch {i+1}/{total_batches}")
            time.sleep(0.1)  # Add delay for UI responsiveness

    end_time = time.time()
    eval_time = end_time - start_time  # Total evaluation time

    # Finalize evaluation if not stopped
    if not st.session_state.stop_eval:
        st.session_state.evaluation_done = True
        st.session_state.trigger_eval = False  # βœ… Reset the trigger
        st.success(f"βœ… Evaluation completed in **{eval_time:.2f} seconds**")

        # Generate classification report and display as a DataFrame
        report = classification_report(y_true, y_pred, target_names=class_names, output_dict=True)
        report_df = pd.DataFrame(report).transpose()
        st.dataframe(report_df.style.format("{:.2f}"))

        # Initialize PDF report
        pdf = FPDF()
        pdf.add_page()
        pdf.set_font("Arial", size=12)
        pdf.cell(200, 10, txt=clean_text("Classification Report"), ln=True, align='C')

        # Add table headers
        col_widths = [40, 40, 40, 40]
        headers = ["Class", "Precision", "Recall", "F1-Score"]
        for i, header in enumerate(headers):
            pdf.cell(col_widths[i], 10, header, border=1)
        pdf.ln()

        # Add metrics for each class (skip the aggregate rows)
        for idx, row in report_df.iterrows():
            if idx in ['accuracy', 'macro avg', 'weighted avg']:
                continue
            pdf.cell(col_widths[0], 10, str(idx), border=1)
            pdf.cell(col_widths[1], 10, f"{row['precision']:.2f}", border=1)
            pdf.cell(col_widths[2], 10, f"{row['recall']:.2f}", border=1)
            pdf.cell(col_widths[3], 10, f"{row['f1-score']:.2f}", border=1)
            pdf.ln()

        # Create and display confusion matrix
        cm = confusion_matrix(y_true, y_pred)
        fig_cm, ax = plt.subplots()
        sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names, cmap="Blues", ax=ax)
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')
        ax.set_title("Confusion Matrix")
        st.pyplot(fig_cm)

        # Save confusion matrix to PDF
        cm_path = "confusion_matrix.png"
        fig_cm.savefig(cm_path, format='png', dpi=300, bbox_inches='tight')
        plt.close(fig_cm)
        if os.path.exists(cm_path):
            pdf.image(cm_path, x=10, y=None, w=180)

        # Create and display ROC curve for each class (one-vs-rest)
        y_true_bin = label_binarize(y_true, classes=list(range(len(class_names))))
        y_score_np = np.array(y_score)
        fig_roc, ax = plt.subplots()
        for i in range(len(class_names)):
            fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_score_np[:, i])
            roc_auc = auc(fpr, tpr)
            ax.plot(fpr, tpr, label=f'{class_names[i]} (AUC = {roc_auc:.2f})')

        ax.plot([0, 1], [0, 1], 'k--')  # Diagonal reference line
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        ax.set_title('Multi-class ROC Curve')
        ax.legend(loc='lower right')
        st.pyplot(fig_roc)

        # Save ROC curve to PDF
        roc_path = "roc_curve.png"
        fig_roc.savefig(roc_path, format='png', dpi=300, bbox_inches='tight')
        plt.close(fig_roc)
        if os.path.exists(roc_path):
            pdf.image(roc_path, x=10, y=None, w=180)

        # Show misclassified samples (up to 5)
        st.markdown("### ❌ Misclassified Samples")
        if misclassified_images:
            # Undo the val_transform Normalize step so images render with
            # natural colors instead of clipped, normalized values.
            norm_mean = np.array([0.485, 0.456, 0.406])
            norm_std = np.array([0.229, 0.224, 0.225])
            n_show = min(5, len(misclassified_images))
            fig_mis, axs = plt.subplots(1, n_show, figsize=(15, 4))
            # plt.subplots returns a bare Axes (not an array) when n_show == 1,
            # and crashed outright when n_show == 0 — both handled here.
            axs = np.atleast_1d(axs)
            for idx, (img, pred, true) in enumerate(misclassified_images[:n_show]):
                img_np = img.permute(1, 2, 0).numpy() * norm_std + norm_mean
                axs[idx].imshow(np.clip(img_np, 0.0, 1.0))
                axs[idx].set_title(f"True: {class_names[true]}\nPred: {class_names[pred]}")
                axs[idx].axis('off')
            st.pyplot(fig_mis)
        else:
            st.info("No misclassified samples — every prediction was correct.")

        # Save PDF and provide download button
        output_pdf = "evaluation_report.pdf"
        pdf.output(output_pdf)
        with open(output_pdf, "rb") as f:
            reevaluate_col, download_col = st.columns([1, 1])
            with download_col:
                st.download_button("πŸ“„ Download Full Evaluation PDF", f, file_name="evaluation_report.pdf")