# NOTE(review): the lines below replace non-code residue from a Hugging Face
# Spaces page scrape (status banner, file size, commit-hash gutter, line
# numbers) that made the file invalid Python.
# Source: Hugging Face Spaces page (status at scrape time: Runtime error)
# File size: 11,583 Bytes
import sys
import os
import time
import cv2
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
from sklearn.preprocessing import label_binarize
import streamlit as st
import matplotlib.pyplot as plt
from fpdf import FPDF
# ---- Streamlit State Initialization ----
# Create every session flag up front so later widget logic can read them
# without existence checks. Existing values survive reruns untouched.
for _flag, _initial in (
    ("stop_eval", False),        # user pressed Stop during a run
    ("evaluation_done", False),  # a full evaluation finished
    ("trigger_eval", False),     # a new evaluation was requested
):
    if _flag not in st.session_state:
        st.session_state[_flag] = _initial

# ---- Streamlit Title ----
st.markdown("<h2 style='color: #2E86C1;'>π Model Evaluation</h2>", unsafe_allow_html=True)

# ---- Class Names & Label Mapping ----
# Five diabetic-retinopathy severity grades; index order defines the label ids.
class_names = ['No_DR', 'Mild', 'Moderate', 'Severe', 'Proliferative_DR']
label_map = {label: idx for idx, label in enumerate(class_names)}
# ---- Text Cleaning Function for PDF ----
def clean_text(text):
    """Drop characters that FPDF's built-in latin-1 fonts cannot encode.

    Bug fix: the original ``text.encode('utf-8', 'ignore').decode('utf-8')``
    was a no-op for any valid ``str`` (every str round-trips through UTF-8),
    so emoji and other non-latin-1 characters still reached FPDF and could
    raise ``UnicodeEncodeError``. Round-tripping through latin-1 with
    ``errors='ignore'`` actually removes them.

    Args:
        text: Arbitrary string destined for a ``pdf.cell`` call.

    Returns:
        The string with all non-latin-1 characters removed.
    """
    return text.encode('latin-1', 'ignore').decode('latin-1')
# ---- Preprocessing Functions ----
def apply_median_filter(image):
    """Denoise *image* with a 5x5 median blur (good for salt-and-pepper noise)."""
    denoised = cv2.medianBlur(image, 5)
    return denoised
def apply_clahe(image):
    """Boost local contrast by running CLAHE on the lightness channel.

    The RGB image is converted to LAB, CLAHE (clip limit 2.0) is applied to
    the L channel only, and the result is converted back to RGB so colors
    are preserved.
    """
    lab_image = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
    lightness, chroma_a, chroma_b = cv2.split(lab_image)
    equalizer = cv2.createCLAHE(clipLimit=2.0)
    enhanced_l = equalizer.apply(lightness)
    recombined = cv2.merge((enhanced_l, chroma_a, chroma_b))
    return cv2.cvtColor(recombined, cv2.COLOR_LAB2RGB)
def apply_gamma_correction(image, gamma=1.2):
    """Apply gamma correction through a precomputed 256-entry lookup table.

    Each intensity v is mapped to 255 * (v / 255) ** (1 / gamma); the
    default gamma of 1.2 mildly brightens mid-tones.
    """
    exponent = 1.0 / gamma
    lut = ((np.arange(256) / 255.0) ** exponent * 255).astype("uint8")
    return cv2.LUT(image, lut)
def apply_gaussian_filter(image, kernel_size=(5, 5), sigma=1.0):
    """Smooth *image* with a Gaussian blur of the given kernel size and sigma."""
    smoothed = cv2.GaussianBlur(image, kernel_size, sigma)
    return smoothed
# ---- Custom Dataset ----
class DDRDataset(Dataset):
    """Dataset of retinal fundus images described by a CSV manifest.

    The CSV must provide a ``new_path`` column (path to the image file) and
    a ``label`` column (integer class index). Each image is preprocessed
    (median filter -> CLAHE -> gamma correction -> Gaussian blur) before the
    optional torchvision transform is applied.
    """

    def __init__(self, csv_path, transform=None):
        self.data = pd.read_csv(csv_path)
        self.image_paths = self.data['new_path'].tolist()
        self.labels = self.data['label'].tolist()
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = int(self.labels[idx])
        image = cv2.imread(img_path)
        if image is None:
            # Bug fix: cv2.imread silently returns None for a missing or
            # corrupt file, which previously surfaced as a cryptic cvtColor
            # error. Fail with a clear message instead.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Preprocessing pipeline: denoise, local contrast, gamma, smoothing.
        image = apply_median_filter(image)
        image = apply_clahe(image)
        image = apply_gamma_correction(image)
        image = apply_gaussian_filter(image)
        image = Image.fromarray(image)
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(label, dtype=torch.long)
# ---- Image Transforms ----
# Standard ImageNet-style evaluation pipeline: resize to the DenseNet input
# size, convert to a tensor, then normalize with ImageNet channel statistics.
_val_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
val_transform = transforms.Compose(_val_steps)
# ---- Load Data (with caching) ----
@st.cache_resource
def load_test_data(csv_path):
    """Build a non-shuffled test DataLoader (batch size 32), cached by Streamlit."""
    test_set = DDRDataset(csv_path=csv_path, transform=val_transform)
    loader = DataLoader(test_set, batch_size=32, shuffle=False)
    return loader
# ---- Load Model (with caching) ----
@st.cache_resource
def load_model():
    """Load the fine-tuned DenseNet-121 classifier on CPU, cached by Streamlit.

    The classifier head is replaced to match the number of DR classes before
    the fine-tuned weights are loaded; the model is returned in eval mode.
    """
    net = models.densenet121(pretrained=False)
    in_features = net.classifier.in_features
    net.classifier = nn.Linear(in_features, len(class_names))
    state_dict = torch.load("./Model/Pretrained_Densenet-121.pth", map_location=torch.device('cpu'))
    net.load_state_dict(state_dict)
    net.eval()
    return net
# ---- Main UI Buttons ----
csv_path = "https://huggingface.co/datasets/Ci-Dave/DDR_dataset_train_test/raw/main/splits/test_labels.csv"
model = load_model()
test_loader = load_test_data(csv_path)

col1, col2 = st.columns([1, 1])
with col1:
    # Request a fresh evaluation run on the next pass of the script.
    if st.button("π Start Evaluation"):
        st.session_state.stop_eval = False
        st.session_state.evaluation_done = False
        st.session_state.trigger_eval = True
with col2:
    # Cooperative stop: the evaluation loop checks this flag once per batch.
    if st.button("π© Stop Evaluation"):
        st.session_state.stop_eval = True

if st.session_state.evaluation_done:
    reevaluate_col, download_col = st.columns([1, 1])
# ---- Description for Model Evaluation ----
# Collapsible help panel explaining what this page does and how to use it.
# NOTE(review): the stray "π"/"β" characters inside the markdown are
# mojibake left over from garbled emoji in the original text; the prose
# itself is intact.
with st.expander("βΉοΈ **What is Model Evaluation?**", expanded=True):
    st.markdown("""
<div style='font-size:16px;'>
The **Model Evaluation** section tests how well the trained AI model performs on the unseen <strong>test set</strong> of retinal images. This provides insights into the reliability and performance of the model when deployed in real scenarios.
#### π What It Does:
- Loads the test dataset of labeled retinal images
- Runs the model to predict labels
- Compares predictions vs. true labels
- Computes:
- π **Classification Report** (Precision, Recall, F1-Score)
- π§ **Confusion Matrix**
- π **Multi-class ROC Curve**
- β **Misclassified Image Samples**
- Saves the full report as a downloadable PDF
#### π§ How to Use:
1. Click **π Start Evaluation** to begin analyzing the modelβs performance.
2. Wait for the evaluation to finish (shows progress bar and batch updates).
3. Once done:
- Check performance scores for each DR class
- View visual summaries like confusion matrix and ROC curve
- See the top 5 misclassified examples
4. Optionally, download the full evaluation report via **π Download PDF**
β οΈ <i>Note: This evaluation runs on the full test set and might take several seconds depending on hardware.</i>
</div>
""", unsafe_allow_html=True)
# ---- Evaluation Logic ----
# Runs once per "Start Evaluation" click: evaluates the model on the full
# test set, renders metrics/plots, and builds a downloadable PDF report.
if st.session_state.trigger_eval:
    st.markdown("### β±οΈ Evaluation Results")

    start_time = time.time()

    y_true = []                 # Ground-truth labels
    y_pred = []                 # Predicted labels
    y_score = []                # Raw model outputs (logits), used for ROC
    misclassified_images = []   # (image tensor, predicted idx, true idx)

    total_batches = len(test_loader)
    progress_bar = st.progress(0)   # Progress bar widget
    status_text = st.empty()        # Placeholder for per-batch status
    stop_info = st.empty()          # Placeholder for the stop message

    # Inference only: disable gradient tracking for speed and memory.
    with torch.no_grad():
        for i, (images, labels) in enumerate(test_loader):
            # Cooperative stop requested via the Stop button.
            if st.session_state.stop_eval:
                stop_info.warning("π© Evaluation stopped by user.")
                break

            outputs = model(images)
            _, predicted = torch.max(outputs, 1)  # Class with the highest logit

            y_true.extend(labels.numpy())
            y_pred.extend(predicted.numpy())
            y_score.extend(outputs.detach().numpy())

            # Collect samples the model got wrong for later display.
            for j in range(len(labels)):
                if predicted[j] != labels[j]:
                    misclassified_images.append((images[j], predicted[j].item(), labels[j].item()))

            percent_complete = (i + 1) / total_batches
            progress_bar.progress(min(percent_complete, 1.0))
            status_text.text(f"Evaluating on Test Set: {int(percent_complete * 100)}% | Batch {i+1}/{total_batches}")
            time.sleep(0.1)  # Small delay so the UI stays responsive

    eval_time = time.time() - start_time  # Total evaluation wall time

    # Finalize only if the run was not stopped mid-way.
    if not st.session_state.stop_eval:
        st.session_state.evaluation_done = True
        # Bug fix: this assignment, its trailing comment, and the success
        # message below were split mid-token across lines in the original
        # (a garbled emoji broke the lines), making the file unparseable.
        st.session_state.trigger_eval = False  # Reset the trigger
        st.success(f"β Evaluation completed in **{eval_time:.2f} seconds**")

        # Classification report as an on-screen DataFrame.
        report = classification_report(y_true, y_pred, target_names=class_names, output_dict=True)
        report_df = pd.DataFrame(report).transpose()
        st.dataframe(report_df.style.format("{:.2f}"))

        # Begin the PDF report.
        pdf = FPDF()
        pdf.add_page()
        pdf.set_font("Arial", size=12)
        pdf.cell(200, 10, txt=clean_text("Classification Report"), ln=True, align='C')

        # Metrics table header row.
        col_widths = [40, 40, 40, 40]
        headers = ["Class", "Precision", "Recall", "F1-Score"]
        for i, header in enumerate(headers):
            pdf.cell(col_widths[i], 10, header, border=1)
        pdf.ln()

        # One table row per class; skip the aggregate summary rows.
        for idx, row in report_df.iterrows():
            if idx in ['accuracy', 'macro avg', 'weighted avg']:
                continue
            pdf.cell(col_widths[0], 10, str(idx), border=1)
            pdf.cell(col_widths[1], 10, f"{row['precision']:.2f}", border=1)
            pdf.cell(col_widths[2], 10, f"{row['recall']:.2f}", border=1)
            pdf.cell(col_widths[3], 10, f"{row['f1-score']:.2f}", border=1)
            pdf.ln()

        # Confusion matrix heatmap.
        cm = confusion_matrix(y_true, y_pred)
        fig_cm, ax = plt.subplots()
        sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names, cmap="Blues", ax=ax)
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')
        ax.set_title("Confusion Matrix")
        st.pyplot(fig_cm)

        # Embed the confusion matrix image into the PDF.
        cm_path = "confusion_matrix.png"
        fig_cm.savefig(cm_path, format='png', dpi=300, bbox_inches='tight')
        plt.close(fig_cm)
        if os.path.exists(cm_path):
            pdf.image(cm_path, x=10, y=None, w=180)

        # One-vs-rest ROC curve per class, scored on the raw logits.
        y_true_bin = label_binarize(y_true, classes=list(range(len(class_names))))
        y_score_np = np.array(y_score)
        fig_roc, ax = plt.subplots()
        for i in range(len(class_names)):
            fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_score_np[:, i])
            roc_auc = auc(fpr, tpr)
            ax.plot(fpr, tpr, label=f'{class_names[i]} (AUC = {roc_auc:.2f})')
        ax.plot([0, 1], [0, 1], 'k--')  # Chance-level diagonal
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        ax.set_title('Multi-class ROC Curve')
        ax.legend(loc='lower right')
        st.pyplot(fig_roc)

        # Embed the ROC curve image into the PDF.
        roc_path = "roc_curve.png"
        fig_roc.savefig(roc_path, format='png', dpi=300, bbox_inches='tight')
        plt.close(fig_roc)
        if os.path.exists(roc_path):
            pdf.image(roc_path, x=10, y=None, w=180)

        # Show up to five misclassified samples.
        st.markdown("### β Misclassified Samples")
        if misclassified_images:
            # Bug fix: the original crashed when there were 0 misclassified
            # samples (subplots(1, 0) is invalid) or exactly 1 (plt.subplots
            # returns a bare Axes, not an indexable array, when ncols == 1).
            n_show = min(5, len(misclassified_images))
            fig_mis, axs = plt.subplots(1, n_show, figsize=(15, 4))
            if n_show == 1:
                axs = [axs]
            for idx, (img, pred, true) in enumerate(misclassified_images[:n_show]):
                axs[idx].imshow(img.permute(1, 2, 0))  # CHW tensor -> HWC image
                axs[idx].set_title(f"True: {class_names[true]}\nPred: {class_names[pred]}")
                axs[idx].axis('off')
            st.pyplot(fig_mis)
        else:
            st.info("No misclassified samples.")

        # Write the PDF to disk and offer it for download.
        output_pdf = "evaluation_report.pdf"
        pdf.output(output_pdf)
        with open(output_pdf, "rb") as f:
            reevaluate_col, download_col = st.columns([1, 1])
            with download_col:
                st.download_button("π Download Full Evaluation PDF", f, file_name="evaluation_report.pdf")