import sys
import os
import time
import cv2
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
from sklearn.preprocessing import label_binarize
import streamlit as st
import matplotlib.pyplot as plt
from fpdf import FPDF

# ---- Streamlit State Initialization ----
# Streamlit re-executes this script on every widget interaction; these flags
# live in session_state so they survive those reruns.
if 'stop_eval' not in st.session_state:
    st.session_state.stop_eval = False
if 'evaluation_done' not in st.session_state:
    st.session_state.evaluation_done = False
if 'trigger_eval' not in st.session_state:
    st.session_state.trigger_eval = False

# ---- Streamlit Title ----
st.markdown("<h2>📈 Model Evaluation</h2>", unsafe_allow_html=True)

# ---- Class Names & Label Mapping ----
class_names = ['No_DR', 'Mild', 'Moderate', 'Severe', 'Proliferative_DR']
label_map = {label: idx for idx, label in enumerate(class_names)}
# Integer label for every class; passed to the sklearn metrics so their output
# always has one entry per class, even if a class never occurs in the test run.
CLASS_INDICES = list(range(len(class_names)))

# Per-channel statistics used by val_transform's Normalize (the standard
# ImageNet values); also needed to undo the normalization for display.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Only this many misclassified samples are displayed, so only this many are
# kept in memory during evaluation.
MAX_MISCLASSIFIED_SHOWN = 5

EVALUATION_HELP_MD = """
The **Model Evaluation** section tests how well the trained AI model performs on the unseen test set of retinal images. This provides insights into the reliability and performance of the model when deployed in real scenarios.

#### 🔍 What It Does:
- Loads the test dataset of labeled retinal images
- Runs the model to predict labels
- Compares predictions vs. true labels
- Computes:
    - 📋 **Classification Report** (Precision, Recall, F1-Score)
    - 🧊 **Confusion Matrix**
    - 📈 **Multi-class ROC Curve**
    - ❌ **Misclassified Image Samples**
- Saves the full report as a downloadable PDF

#### 🧭 How to Use:
1. Click **🚀 Start Evaluation** to begin analyzing the model’s performance.
2. Wait for the evaluation to finish (shows progress bar and batch updates).
3. Once done:
    - Check performance scores for each DR class
    - View visual summaries like confusion matrix and ROC curve
    - See the first 5 misclassified examples
4. Optionally, download the full evaluation report via **📄 Download Full Evaluation PDF**

⚠️ Note: This evaluation runs on the full test set and might take several seconds or longer depending on hardware.
"""


# ---- Text Cleaning Function for PDF ----
def clean_text(text):
    """Return ``text`` as a string limited to characters FPDF's core fonts can encode.

    The report uses the built-in "Arial" font, which only covers Latin-1;
    characters outside that range (emoji, CJK, ...) are dropped so that
    ``pdf.output`` does not fail on them.
    """
    return str(text).encode('latin-1', 'ignore').decode('latin-1')


# ---- Preprocessing Functions ----
# Each helper takes and returns the uint8 RGB array produced in
# DDRDataset.__getitem__ (cv2.imread + BGR->RGB conversion).
def apply_median_filter(image):
    """Denoise with a 5x5 median blur."""
    return cv2.medianBlur(image, 5)


def apply_clahe(image):
    """Boost local contrast by running CLAHE on the L channel in LAB space."""
    lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
    l, a, b = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=2.0)
    cl = clahe.apply(l)
    merged = cv2.merge((cl, a, b))
    return cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)


def apply_gamma_correction(image, gamma=1.2):
    """Apply gamma correction through a 256-entry lookup table.

    Each intensity ``i`` maps to ``(i / 255) ** (1 / gamma) * 255``, so
    ``gamma > 1`` brightens mid-tones.
    """
    invGamma = 1.0 / gamma
    table = np.array([(i / 255.0) ** invGamma * 255 for i in np.arange(0, 256)]).astype("uint8")
    return cv2.LUT(image, table)


def apply_gaussian_filter(image, kernel_size=(5, 5), sigma=1.0):
    """Smooth the image with a Gaussian blur."""
    return cv2.GaussianBlur(image, kernel_size, sigma)


# ---- Custom Dataset ----
class DDRDataset(Dataset):
    """Retinal images listed in a CSV with ``new_path`` and ``label`` columns.

    Every image passes through the fixed preprocessing chain (median filter ->
    CLAHE -> gamma correction -> Gaussian blur) before ``transform``.
    """

    def __init__(self, csv_path, transform=None):
        self.data = pd.read_csv(csv_path)
        self.image_paths = self.data['new_path'].tolist()
        self.labels = self.data['label'].tolist()
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        img_path = self.image_paths[idx]
        label = int(self.labels[idx])

        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread reports a missing/unreadable file by returning None,
            # which would otherwise surface as an opaque cvtColor error.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR

        # Apply preprocessing
        image = apply_median_filter(image)
        image = apply_clahe(image)
        image = apply_gamma_correction(image)
        image = apply_gaussian_filter(image)

        image = Image.fromarray(image)
        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.long)


# ---- Image Transforms ----
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])


# ---- Load Data (with caching) ----
@st.cache_resource
def load_test_data(csv_path):
    """Build an unshuffled DataLoader (batch size 32) over the test CSV."""
    dataset = DDRDataset(csv_path=csv_path, transform=val_transform)
    return DataLoader(dataset, batch_size=32, shuffle=False)


# ---- Load Model (with caching) ----
@st.cache_resource
def load_model():
    """Load DenseNet-121 with a len(class_names)-way head from the local checkpoint (CPU, eval mode)."""
    model = models.densenet121(pretrained=False)
    model.classifier = nn.Linear(model.classifier.in_features, len(class_names))
    model.load_state_dict(torch.load("./Model/Pretrained_Densenet-121.pth", map_location=torch.device('cpu')))
    model.eval()
    return model


# ---- Evaluation Helpers ----
def _denormalize(tensor):
    """Undo val_transform's Normalize on a (C, H, W) tensor and clamp to [0, 1] for imshow."""
    mean = torch.tensor(NORM_MEAN).view(-1, 1, 1)
    std = torch.tensor(NORM_STD).view(-1, 1, 1)
    return (tensor * std + mean).clamp(0.0, 1.0)


def _run_inference(net, loader):
    """Run ``net`` over every batch of ``loader`` while updating a progress bar.

    Returns:
        tuple: ``(y_true, y_pred, y_score, misclassified, stopped)``.
        ``y_score`` holds the raw model outputs (logits) per sample;
        ``misclassified`` holds at most MAX_MISCLASSIFIED_SHOWN
        ``(image_tensor, predicted_idx, true_idx)`` tuples; ``stopped`` is
        True when the user's stop flag ended the loop early.
    """
    y_true, y_pred, y_score, misclassified = [], [], [], []
    total_batches = len(loader)
    progress_bar = st.progress(0)
    status_text = st.empty()
    stop_info = st.empty()

    # Disable gradient tracking: inference only.
    with torch.no_grad():
        for i, (images, labels) in enumerate(loader):
            # Allow user to stop the evaluation
            if st.session_state.stop_eval:
                stop_info.warning("🚩 Evaluation stopped by user.")
                return y_true, y_pred, y_score, misclassified, True

            outputs = net(images)
            _, predicted = torch.max(outputs, 1)  # predicted class per sample

            y_true.extend(labels.numpy())
            y_pred.extend(predicted.numpy())
            y_score.extend(outputs.detach().numpy())

            # Keep only the few samples that will be displayed; clone so the
            # stored image does not keep the whole batch tensor alive.
            for j in (predicted != labels).nonzero(as_tuple=True)[0].tolist():
                if len(misclassified) >= MAX_MISCLASSIFIED_SHOWN:
                    break
                misclassified.append((images[j].clone(), predicted[j].item(), labels[j].item()))

            percent_complete = (i + 1) / total_batches
            progress_bar.progress(min(percent_complete, 1.0))
            status_text.text(f"Evaluating on Test Set: {int(percent_complete * 100)}% | Batch {i+1}/{total_batches}")
            # Short pause between batches for UI responsiveness (included in the reported eval time).
            time.sleep(0.1)

    return y_true, y_pred, y_score, misclassified, False


def _add_report_table_to_pdf(pdf, report_df):
    """Write the per-class precision / recall / F1 rows of ``report_df`` as a bordered PDF table."""
    col_widths = [40, 40, 40, 40]
    headers = ["Class", "Precision", "Recall", "F1-Score"]
    for width, header in zip(col_widths, headers):
        pdf.cell(width, 10, clean_text(header), border=1)
    pdf.ln()

    for idx, row in report_df.iterrows():
        # Summary rows are shown in the app's dataframe but not in this table.
        if idx in ('accuracy', 'macro avg', 'weighted avg'):
            continue
        pdf.cell(col_widths[0], 10, clean_text(idx), border=1)
        pdf.cell(col_widths[1], 10, f"{row['precision']:.2f}", border=1)
        pdf.cell(col_widths[2], 10, f"{row['recall']:.2f}", border=1)
        pdf.cell(col_widths[3], 10, f"{row['f1-score']:.2f}", border=1)
        pdf.ln()


def _add_figure_to_pdf(pdf, fig, path):
    """Save ``fig`` as a 300-dpi PNG at ``path``, close it, and embed it in ``pdf`` at width 180."""
    fig.savefig(path, format='png', dpi=300, bbox_inches='tight')
    plt.close(fig)
    if os.path.exists(path):
        pdf.image(path, x=10, y=None, w=180)


def _plot_confusion_matrix(y_true, y_pred):
    """Return a heatmap figure of the confusion matrix over all classes."""
    # labels= keeps the matrix len(class_names) x len(class_names) so it
    # always matches the tick labels, even if a class never appears.
    cm = confusion_matrix(y_true, y_pred, labels=CLASS_INDICES)
    fig, ax = plt.subplots()
    sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names, cmap="Blues", ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title("Confusion Matrix")
    return fig


def _plot_roc_curves(y_true, y_score):
    """Return a figure with one-vs-rest ROC curves (and AUCs) per class.

    NOTE(review): scores are raw logits, not softmax probabilities; the two
    can rank samples differently for a given class — confirm intended.
    """
    y_true_bin = label_binarize(y_true, classes=CLASS_INDICES)
    y_score_np = np.array(y_score)
    fig, ax = plt.subplots()
    for i, name in enumerate(class_names):
        if not y_true_bin[:, i].any():
            # No positive samples for this class: ROC/AUC are undefined.
            continue
        fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_score_np[:, i])
        ax.plot(fpr, tpr, label=f'{name} (AUC = {auc(fpr, tpr):.2f})')
    ax.plot([0, 1], [0, 1], 'k--')  # chance-level reference line
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Multi-class ROC Curve')
    ax.legend(loc='lower right')
    return fig


def _show_misclassified(samples):
    """Display up to MAX_MISCLASSIFIED_SHOWN misclassified images with true/predicted labels."""
    st.markdown("### ❌ Misclassified Samples")
    if not samples:
        st.info("No misclassified samples to show.")
        return
    # squeeze=False always yields a 2-D array of Axes, even for a single sample.
    fig, axs = plt.subplots(1, len(samples), figsize=(15, 4), squeeze=False)
    for ax, (img, pred, true) in zip(axs[0], samples):
        # Images are stored normalized; convert back to displayable (H, W, C) in [0, 1].
        ax.imshow(_denormalize(img).permute(1, 2, 0).numpy())
        ax.set_title(f"True: {class_names[true]}\nPred: {class_names[pred]}")
        ax.axis('off')
    st.pyplot(fig)
    plt.close(fig)


# ---- Main UI Buttons ----
csv_path = "https://huggingface.co/datasets/Ci-Dave/DDR_dataset_train_test/raw/main/splits/test_labels.csv"
model = load_model()
test_loader = load_test_data(csv_path)

col1, col2 = st.columns([1, 1])
with col1:
    if st.button("🚀 Start Evaluation"):
        st.session_state.stop_eval = False
        st.session_state.evaluation_done = False
        st.session_state.trigger_eval = True
with col2:
    if st.button("🚩 Stop Evaluation"):
        st.session_state.stop_eval = True

# ---- Description for Model Evaluation ----
with st.expander("ℹ️ **What is Model Evaluation?**", expanded=True):
    st.markdown(EVALUATION_HELP_MD, unsafe_allow_html=True)

# ---- Evaluation Logic ----
if st.session_state.trigger_eval:
    st.markdown("### ⏱️ Evaluation Results")

    start_time = time.time()
    y_true, y_pred, y_score, misclassified_images, stopped = _run_inference(model, test_loader)
    eval_time = time.time() - start_time  # total evaluation time

    # Reset the trigger in every outcome (including a user stop) so later
    # reruns don't re-enter this branch; "Start Evaluation" sets it again.
    st.session_state.trigger_eval = False

    if stopped:
        pass  # the stop warning was already shown by _run_inference
    elif not y_true:
        st.warning("No test samples were evaluated.")
    else:
        st.session_state.evaluation_done = True
        st.success(f"✅ Evaluation completed in **{eval_time:.2f} seconds**")

        # Classification report; labels= guarantees one row per class name
        # (target_names must match the label set or sklearn raises).
        report = classification_report(
            y_true, y_pred, labels=CLASS_INDICES, target_names=class_names, output_dict=True
        )
        report_df = pd.DataFrame(report).transpose()
        st.dataframe(report_df.style.format("{:.2f}"))

        # PDF report: table first, then the figures below it.
        pdf = FPDF()
        pdf.add_page()
        pdf.set_font("Arial", size=12)
        pdf.cell(200, 10, txt=clean_text("Classification Report"), ln=True, align='C')
        _add_report_table_to_pdf(pdf, report_df)

        fig_cm = _plot_confusion_matrix(y_true, y_pred)
        st.pyplot(fig_cm)
        _add_figure_to_pdf(pdf, fig_cm, "confusion_matrix.png")

        fig_roc = _plot_roc_curves(y_true, y_score)
        st.pyplot(fig_roc)
        _add_figure_to_pdf(pdf, fig_roc, "roc_curve.png")

        _show_misclassified(misclassified_images)

        # Save PDF and provide download button
        output_pdf = "evaluation_report.pdf"
        pdf.output(output_pdf)
        with open(output_pdf, "rb") as f:
            pdf_bytes = f.read()
        _, download_col = st.columns([1, 1])
        with download_col:
            st.download_button(
                "📄 Download Full Evaluation PDF",
                pdf_bytes,
                file_name="evaluation_report.pdf",
                mime="application/pdf",
            )