import os
import sys
import tempfile
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import streamlit as st
import torch
import torch.nn as nn
from fpdf import FPDF
from PIL import Image
from sklearn.metrics import auc, classification_report, confusion_matrix, roc_curve
from sklearn.preprocessing import label_binarize
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
# ---- Streamlit State Initialization ----
# Seed the flags that drive the evaluation workflow the first time a session
# runs the script; on later reruns whatever the buttons stored is kept.
_SESSION_DEFAULTS = {
    "stop_eval": False,
    "evaluation_done": False,
    "trigger_eval": False,
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# ---- Streamlit Title ----
# Page heading rendered as raw HTML (hence unsafe_allow_html=True).
st.markdown(
    "<h1>📊 Model Evaluation</h1>",
    unsafe_allow_html=True,
)
# ---- Class Names & Label Mapping ----
# Class names in label-index order: the position of a name in this list is the
# integer label the model predicts for it, and label_map is the inverse lookup.
class_names = ['No_DR', 'Mild', 'Moderate', 'Severe', 'Proliferative_DR']
label_map = dict(zip(class_names, range(len(class_names))))
# ---- Text Cleaning Function for PDF ----
def clean_text(text, encoding="latin-1"):
    """Return *text* reduced to the characters that *encoding* can represent.

    FPDF's built-in core fonts (e.g. Arial) only cover Latin-1. Characters
    outside that set, such as emoji, are therefore dropped here rather than
    failing when the PDF is written. A UTF-8 round trip, by contrast, leaves
    ordinary strings unchanged and so cleans nothing.

    Args:
        text: Value to clean; non-string values are converted with ``str()``.
        encoding: Target character set; unrepresentable characters are removed.

    Returns:
        The cleaned string.
    """
    return str(text).encode(encoding, "ignore").decode(encoding)
# ---- Preprocessing Functions ----
def apply_median_filter(image, kernel_size=5):
    """Apply an OpenCV median blur to *image*.

    Args:
        image: Image array accepted by ``cv2.medianBlur`` (the dataset passes
            the RGB array obtained from ``cv2.imread``).
        kernel_size: Aperture size. OpenCV requires an odd integer greater
            than 1; defaults to 5, the value previously hard-coded.

    Returns:
        The blurred image.

    Raises:
        ValueError: If *kernel_size* is not odd or is smaller than 3.
    """
    # Validate up front so a bad size fails with a clear message instead of
    # an OpenCV assertion error.
    if kernel_size < 3 or kernel_size % 2 != 1:
        raise ValueError(f"kernel_size must be an odd integer > 1, got {kernel_size!r}")
    return cv2.medianBlur(image, kernel_size)
def apply_clahe(image, clip_limit=2.0, tile_grid_size=(8, 8)):
    """Equalise local contrast of an RGB image with CLAHE on its L channel.

    The image is converted to LAB and CLAHE is applied to the L channel only.
    The result is then converted back to RGB, so the A/B channels pass
    through unchanged.

    Args:
        image: RGB image array (the dataset passes ``cv2.imread`` output
            converted to RGB).
        clip_limit: CLAHE contrast-limiting threshold (previously hard-coded
            to 2.0).
        tile_grid_size: Grid of tiles CLAHE works on. ``(8, 8)`` is OpenCV's
            default, which the previous implementation used implicitly.

    Returns:
        The contrast-equalised RGB image.
    """
    lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(image, cv2.COLOR_RGB2LAB))
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    merged = cv2.merge((clahe.apply(lightness), chan_a, chan_b))
    return cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)
def apply_gamma_correction(image, gamma=1.2):
invGamma = 1.0 / gamma
table = np.array([(i / 255.0) ** invGamma * 255 for i in np.arange(0, 256)]).astype("uint8")
return cv2.LUT(image, table)
def apply_gaussian_filter(image, kernel_size=(5, 5), sigma=1.0):
    """Smooth *image* with an OpenCV Gaussian blur.

    Args:
        image: Image array accepted by ``cv2.GaussianBlur``.
        kernel_size: ``(width, height)`` of the Gaussian kernel.
        sigma: Standard deviation along X; OpenCV uses the same value for Y
            because no separate Y sigma is passed.

    Returns:
        The blurred image.
    """
    blurred = cv2.GaussianBlur(image, kernel_size, sigma)
    return blurred
# ---- Custom Dataset ----
class DDRDataset(Dataset):
    """Dataset of (image, label) pairs listed in a CSV file.

    The CSV must provide a ``new_path`` column (image location passed to
    ``cv2.imread``) and a ``label`` column convertible with ``int()``.
    Each image goes through the median -> CLAHE -> gamma -> Gaussian
    preprocessing chain before the optional torchvision ``transform``.
    """

    def __init__(self, csv_path, transform=None):
        """Read the CSV at *csv_path* (any source ``pandas.read_csv`` accepts).

        Args:
            csv_path: Path or URL of the CSV.
            transform: Optional callable applied to the preprocessed PIL image.
        """
        self.data = pd.read_csv(csv_path)
        self.image_paths = self.data['new_path'].tolist()
        self.labels = self.data['label'].tolist()
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row *idx*.

        The image is the preprocessed (and, if set, transformed) picture and
        the label is a ``torch.long`` scalar tensor.

        Raises:
            FileNotFoundError: If ``cv2.imread`` cannot read the image.
                ``imread`` returns ``None`` instead of raising, which
                previously surfaced as an opaque ``cv2.error`` from
                ``cvtColor``.
        """
        img_path = self.image_paths[idx]
        label = int(self.labels[idx])
        # NOTE(review): cv2.imread reads from the local filesystem only; confirm
        # that the new_path values are local paths rather than URLs.
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image at {img_path!r}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Apply preprocessing in a fixed order.
        for step in (apply_median_filter, apply_clahe, apply_gamma_correction, apply_gaussian_filter):
            image = step(image)
        image = Image.fromarray(image)
        if self.transform is not None:
            image = self.transform(image)
        return image, torch.tensor(label, dtype=torch.long)
# ---- Image Transforms ----
# Evaluation-time pipeline:
# 1. resize to 224x224;
# 2. convert the PIL image to a float tensor;
# 3. normalise each channel.
# The mean/std values appear to be the standard ImageNet statistics used with
# torchvision models.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]
_eval_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
val_transform = transforms.Compose(_eval_steps)
# ---- Load Data (with caching) ----
@st.cache_resource
def load_test_data(csv_path):
    """Return an unshuffled DataLoader over the test CSV at *csv_path*.

    Streamlit caches the result per *csv_path*, so the CSV is read once per
    process rather than on every rerun.
    """
    return DataLoader(
        DDRDataset(csv_path=csv_path, transform=val_transform),
        batch_size=32,
        shuffle=False,
    )
# ---- Load Model (with caching) ----
@st.cache_resource
def load_model(weights_path="./Model/Pretrained_Densenet-121.pth"):
    """Build the DenseNet-121 classifier, load its fine-tuned weights and return it.

    Args:
        weights_path: State-dict file to load. Defaults to the previously
            hard-coded location, resolved relative to the working directory.

    Returns:
        The model on CPU, in evaluation mode, with a ``len(class_names)``-way
        linear head.
    """
    # ``weights=None`` replaces the deprecated ``pretrained=False``. No
    # ImageNet weights are fetched; every parameter is overwritten below
    # (``load_state_dict`` is strict by default).
    model = models.densenet121(weights=None)
    model.classifier = nn.Linear(model.classifier.in_features, len(class_names))
    # weights_only=True refuses to unpickle arbitrary objects; a plain state
    # dict, which load_state_dict expects, loads fine under it.
    state_dict = torch.load(weights_path, map_location=torch.device('cpu'), weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model
# ---- Main UI Buttons ----
csv_path = "https://huggingface.co/datasets/Ci-Dave/DDR_dataset_train_test/raw/main/splits/test_labels.csv"
model = load_model()
test_loader = load_test_data(csv_path)

col1, col2 = st.columns([1, 1])
with col1:
    # Arm a fresh run: clear earlier stop/completion state and request evaluation.
    if st.button("🚀 Start Evaluation"):
        st.session_state.stop_eval = False
        st.session_state.evaluation_done = False
        st.session_state.trigger_eval = True
with col2:
    # A click reruns the script; the evaluation loop checks this flag and stops.
    if st.button("🚩 Stop Evaluation"):
        st.session_state.stop_eval = True

if st.session_state.evaluation_done:
    # NOTE(review): these columns are not used by the visible code (the
    # download button creates its own pair); confirm before removing.
    reevaluate_col, download_col = st.columns([1, 1])
# ---- Description for Model Evaluation ----
# Static help text; Streamlit dedents the markdown body, so the indentation
# below is only relative (4 extra spaces mark nested list items).
with st.expander("ℹ️ **What is Model Evaluation?**", expanded=True):
    st.markdown("""
    The **Model Evaluation** section tests how well the trained AI model performs on the unseen test set of retinal images. This provides insights into the reliability and performance of the model when deployed in real scenarios.

    #### 🔍 What It Does:
    - Loads the test dataset of labeled retinal images
    - Runs the model to predict labels
    - Compares predictions vs. true labels
    - Computes:
        - 📋 **Classification Report** (Precision, Recall, F1-Score)
        - 🧠 **Confusion Matrix**
        - 📈 **Multi-class ROC Curve**
        - ❌ **Misclassified Image Samples**
    - Saves the full report as a downloadable PDF

    #### 🔧 How to Use:
    1. Click **🚀 Start Evaluation** to begin analyzing the model’s performance.
    2. Wait for the evaluation to finish (shows progress bar and batch updates).
    3. Once done:
        - Check performance scores for each DR class
        - View visual summaries like confusion matrix and ROC curve
        - See up to 5 misclassified examples
    4. Optionally, download the full evaluation report via **📄 Download PDF**

    ⚠️ Note: This evaluation runs on the full test set and might take several seconds depending on hardware.
    """, unsafe_allow_html=True)
# ---- Evaluation Logic ----
# Only this many misclassified samples are kept in memory and plotted.
MAX_MISCLASSIFIED_SHOWN = 5


def _to_displayable(img):
    """Turn a normalised ``(C, H, W)`` image tensor into an ``(H, W, C)`` array in [0, 1].

    The ``Normalize`` step of ``val_transform`` is inverted (when present), so
    samples are drawn with their real colours. ``imshow`` would otherwise clip
    the normalised values and show distorted images.
    """
    norm = next(
        (t for t in val_transform.transforms if isinstance(t, transforms.Normalize)),
        None,
    )
    if norm is not None:
        mean = torch.as_tensor(norm.mean, dtype=img.dtype).view(-1, 1, 1)
        std = torch.as_tensor(norm.std, dtype=img.dtype).view(-1, 1, 1)
        img = img * std + mean
    return img.clamp(0.0, 1.0).permute(1, 2, 0).numpy()


def _run_inference(progress_bar, status_text, stop_info):
    """Run ``model`` over every batch of ``test_loader``, updating the progress widgets.

    Returns:
        ``(y_true, y_pred, y_score, misclassified)``:
        - ``y_true``: ground-truth labels;
        - ``y_pred``: predicted labels;
        - ``y_score``: per-class softmax probabilities, one row per sample;
        - ``misclassified``: up to ``MAX_MISCLASSIFIED_SHOWN``
          ``(image, predicted, true)`` tuples.
        The lists are partial if the user stopped the run.
    """
    y_true, y_pred, y_score, misclassified = [], [], [], []
    total_batches = len(test_loader)
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for i, (images, labels) in enumerate(test_loader):
            # The Stop button sets this flag on the rerun its click triggers.
            if st.session_state.stop_eval:
                stop_info.warning("🚩 Evaluation stopped by user.")
                break
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            y_true.extend(labels.numpy())
            y_pred.extend(predicted.numpy())
            # Softmax probabilities, rather than raw logits, are the usual
            # per-class scores for one-vs-rest ROC curves.
            y_score.extend(torch.softmax(outputs, dim=1).numpy())
            # Keep only as many misclassified samples as will be displayed.
            for img, pred, true in zip(images, predicted, labels):
                if len(misclassified) >= MAX_MISCLASSIFIED_SHOWN:
                    break
                if pred != true:
                    misclassified.append((img, pred.item(), true.item()))
            fraction = (i + 1) / total_batches
            progress_bar.progress(min(fraction, 1.0))
            status_text.text(
                f"Evaluating on Test Set: {int(fraction * 100)}% | Batch {i+1}/{total_batches}"
            )
    return y_true, y_pred, y_score, misclassified


def _plot_confusion_matrix(y_true, y_pred):
    """Return a heatmap figure of the confusion matrix over all classes."""
    # Passing ``labels`` keeps the matrix len(class_names) x len(class_names),
    # matching the tick labels, even when a class never occurs in either list.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))
    fig, ax = plt.subplots()
    sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names, cmap="Blues", ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title("Confusion Matrix")
    return fig


def _plot_roc(y_true, y_score):
    """Return a figure with one one-vs-rest ROC curve per class."""
    y_true_bin = label_binarize(y_true, classes=list(range(len(class_names))))
    scores = np.asarray(y_score)
    fig, ax = plt.subplots()
    for i, name in enumerate(class_names):
        positives = y_true_bin[:, i]
        # A ROC curve is undefined unless the class has both positive and
        # negative samples; skip it instead of plotting NaNs.
        if positives.min() == positives.max():
            continue
        fpr, tpr, _ = roc_curve(positives, scores[:, i])
        ax.plot(fpr, tpr, label=f'{name} (AUC = {auc(fpr, tpr):.2f})')
    ax.plot([0, 1], [0, 1], 'k--')  # Diagonal reference line
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Multi-class ROC Curve')
    ax.legend(loc='lower right')
    return fig


def _plot_misclassified(samples):
    """Return a one-row figure of *samples*, or ``None`` if there are none."""
    if not samples:
        return None
    # squeeze=False keeps ``axs`` 2-D, so a single sample works too.
    fig, axs = plt.subplots(1, len(samples), figsize=(15, 4), squeeze=False)
    for ax, (img, pred, true) in zip(axs[0], samples):
        ax.imshow(_to_displayable(img))
        ax.set_title(f"True: {class_names[true]}\nPred: {class_names[pred]}")
        ax.axis('off')
    return fig


def _show_and_save(fig, path):
    """Display *fig* in the app, save it as a PNG at *path*, then close it."""
    st.pyplot(fig)
    fig.savefig(path, format='png', dpi=300, bbox_inches='tight')
    plt.close(fig)


def _build_pdf(report_df, image_paths, pdf_path):
    """Write the per-class metrics table, followed by each existing image, to *pdf_path*."""
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Arial", size=12)
    pdf.cell(200, 10, txt=clean_text("Classification Report"), ln=True, align='C')
    col_widths = [40, 40, 40, 40]
    headers = ["Class", "Precision", "Recall", "F1-Score"]
    for width, header in zip(col_widths, headers):
        pdf.cell(width, 10, header, border=1)
    pdf.ln()
    for idx, row in report_df.iterrows():
        # Per-class rows only; summary rows (accuracy / averages) are skipped.
        if idx not in class_names:
            continue
        pdf.cell(col_widths[0], 10, clean_text(str(idx)), border=1)
        pdf.cell(col_widths[1], 10, f"{row['precision']:.2f}", border=1)
        pdf.cell(col_widths[2], 10, f"{row['recall']:.2f}", border=1)
        pdf.cell(col_widths[3], 10, f"{row['f1-score']:.2f}", border=1)
        pdf.ln()
    for path in image_paths:
        if os.path.exists(path):
            pdf.image(path, x=10, y=None, w=180)
    pdf.output(pdf_path)


def _show_results(y_true, y_pred, y_score, misclassified):
    """Render the report, plots and PDF download button for a finished run.

    Returns:
        The ``(reevaluate_col, download_col)`` columns created for the
        download button.
    """
    report = classification_report(
        y_true,
        y_pred,
        labels=list(range(len(class_names))),  # keep every class even if absent
        target_names=class_names,
        output_dict=True,
        zero_division=0,
    )
    report_df = pd.DataFrame(report).transpose()
    st.dataframe(report_df.style.format("{:.2f}"))

    # Per-run temporary directory: fixed file names in the working directory
    # would be overwritten by concurrent sessions of the app.
    with tempfile.TemporaryDirectory() as tmp_dir:
        cm_path = os.path.join(tmp_dir, "confusion_matrix.png")
        roc_path = os.path.join(tmp_dir, "roc_curve.png")
        pdf_path = os.path.join(tmp_dir, "evaluation_report.pdf")

        _show_and_save(_plot_confusion_matrix(y_true, y_pred), cm_path)
        _show_and_save(_plot_roc(y_true, y_score), roc_path)

        st.markdown("### ❌ Misclassified Samples")
        fig_mis = _plot_misclassified(misclassified)
        if fig_mis is None:
            st.info("No misclassified samples were found.")
        else:
            st.pyplot(fig_mis)
            plt.close(fig_mis)

        _build_pdf(report_df, [cm_path, roc_path], pdf_path)
        # Read the bytes before the temporary directory is removed.
        with open(pdf_path, "rb") as f:
            pdf_bytes = f.read()

    reevaluate_col, download_col = st.columns([1, 1])
    with download_col:
        st.download_button("📄 Download Full Evaluation PDF", pdf_bytes, file_name="evaluation_report.pdf")
    return reevaluate_col, download_col


# Check if evaluation should be triggered
if st.session_state.trigger_eval:
    st.markdown("### ⏱️ Evaluation Results")
    start_time = time.time()
    progress_bar = st.progress(0)  # Initialize progress bar
    status_text = st.empty()  # Placeholder for status updates
    stop_info = st.empty()  # Placeholder for stop message
    y_true, y_pred, y_score, misclassified_images = _run_inference(progress_bar, status_text, stop_info)
    eval_time = time.time() - start_time  # Total evaluation time
    # Reset the trigger whether the run completed or was stopped. Otherwise a
    # stopped run would re-enter this branch on every later rerun.
    st.session_state.trigger_eval = False

    # Finalize evaluation if not stopped
    if not st.session_state.stop_eval:
        st.session_state.evaluation_done = True
        st.success(f"✅ Evaluation completed in **{eval_time:.2f} seconds**")
        if y_true:
            reevaluate_col, download_col = _show_results(y_true, y_pred, y_score, misclassified_images)
        else:
            st.warning("The test set is empty, so there are no results to report.")