# Source: DR_Classification/pages/Model_Evaluation.py
# (author: 3v324v23; commit 7c7bcd8; 10.4 kB)
import sys
import os
import time
import cv2
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
from sklearn.preprocessing import label_binarize
import streamlit as st
import matplotlib.pyplot as plt
from fpdf import FPDF
# ---- Streamlit State Initialization ----
# Flags that must survive Streamlit reruns; each starts out False:
#   stop_eval       - set by the Stop button, checked before every batch
#   evaluation_done - set once a run finishes without being stopped
#   trigger_eval    - set by the Start button, cleared by the evaluation section
for _flag_name in ("stop_eval", "evaluation_done", "trigger_eval"):
    if _flag_name not in st.session_state:
        st.session_state[_flag_name] = False
# ---- Streamlit Title ----
# Page heading rendered as raw HTML so it can carry a custom color.
_PAGE_TITLE_HTML = "<h2 style='color: #2E86C1;'>📈 Model Evaluation</h2>"
st.markdown(_PAGE_TITLE_HTML, unsafe_allow_html=True)
# ---- Class Names & Label Mapping ----
# DR grades in model-output order: position i of the classifier output (and of
# every class index used on this page) corresponds to class_names[i].
class_names = ['No_DR', 'Mild', 'Moderate', 'Severe', 'Proliferative_DR']
# Reverse lookup: class name -> integer index.
label_map = dict(zip(class_names, range(len(class_names))))
# ---- Text Cleaning Function for PDF ----
def clean_text(text):
    """Return *text* reduced to characters the PDF's core font can encode.

    FPDF's built-in core fonts (such as the "Arial" font used for the report)
    only support Latin-1 text, so characters outside Latin-1 (emoji, most
    non-Western scripts) would break PDF generation. Such characters are
    dropped; everything Latin-1-encodable passes through unchanged.

    Args:
        text: The string to sanitize.

    Returns:
        The string with all non-Latin-1 characters removed.
    """
    # A UTF-8 encode/decode round trip is a no-op for any valid str, so it
    # would never remove anything; Latin-1 is the encoding the core fonts need.
    return text.encode('latin-1', 'ignore').decode('latin-1')
# ---- Preprocessing Functions ----
def apply_median_filter(image):
    """Denoise *image* with OpenCV's median blur using a 5x5 aperture."""
    aperture = 5
    filtered = cv2.medianBlur(image, aperture)
    return filtered
def apply_clahe(image):
    """Boost local contrast of an RGB image with CLAHE on its lightness channel.

    The image is converted RGB -> LAB, CLAHE (clip limit 2.0, OpenCV's default
    tile grid) is applied to the L channel only, and the result is converted
    back to RGB, so the A/B color channels are left untouched.
    """
    lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(image, cv2.COLOR_RGB2LAB))
    equalizer = cv2.createCLAHE(clipLimit=2.0)
    relit = cv2.merge((equalizer.apply(lightness), chan_a, chan_b))
    return cv2.cvtColor(relit, cv2.COLOR_LAB2RGB)
def apply_gamma_correction(image, gamma=1.2):
    """Apply gamma correction through a 256-entry lookup table.

    Each intensity v maps to (v / 255) ** (1 / gamma) * 255, truncated to
    uint8. With gamma > 1 this raises mid-tones (brightens the image).

    Args:
        image: 8-bit image passed to cv2.LUT.
        gamma: Correction factor; the table uses its reciprocal as exponent.
    """
    exponent = 1.0 / gamma
    lut_entries = [(level / 255.0) ** exponent * 255 for level in np.arange(0, 256)]
    lookup = np.array(lut_entries).astype(np.uint8)
    return cv2.LUT(image, lookup)
def apply_gaussian_filter(image, kernel_size=(5, 5), sigma=1.0):
    """Smooth *image* with a Gaussian blur of the given kernel size and sigma."""
    smoothed = cv2.GaussianBlur(image, kernel_size, sigma)
    return smoothed
# ---- Custom Dataset ----
class DDRDataset(Dataset):
    """Image/label dataset indexed by a CSV file.

    The CSV must have a ``new_path`` column (image file path readable by
    OpenCV) and a ``label`` column (convertible to ``int``; presumably the
    class index into ``class_names`` -- confirm against the split files).
    Each item is loaded, passed through the preprocessing chain
    (median filter -> CLAHE -> gamma correction -> Gaussian blur), converted to
    a PIL image, optionally transformed, and returned with a ``torch.long``
    label tensor.
    """

    def __init__(self, csv_path, transform=None):
        """Read the index CSV.

        Args:
            csv_path: Path to the CSV listing image paths and labels.
            transform: Optional callable applied to the PIL image per item.
        """
        self.data = pd.read_csv(csv_path)
        self.image_paths = self.data['new_path'].tolist()
        self.labels = self.data['label'].tolist()
        self.transform = transform

    def __len__(self):
        """Return the number of images listed in the CSV."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, preprocess and return sample *idx* as ``(image, label)``.

        Raises:
            OSError: If OpenCV cannot read the image file (missing or unreadable).
        """
        img_path = self.image_paths[idx]
        label = int(self.labels[idx])
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising;
            # fail here with the offending path rather than with an opaque
            # cv2.error from cvtColor below.
            raise OSError(f"Could not read image (missing or unreadable): {img_path!r}")
        # OpenCV loads BGR; the preprocessing helpers work on RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Apply preprocessing
        image = apply_median_filter(image)
        image = apply_clahe(image)
        image = apply_gamma_correction(image)
        image = apply_gaussian_filter(image)
        image = Image.fromarray(image)
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(label, dtype=torch.long)
# ---- Image Transforms ----
# Evaluation-time pipeline: fixed-size resize, PIL -> tensor conversion, and
# per-channel normalization with the standard ImageNet mean/std values.
_EVAL_IMAGE_SIZE = (224, 224)
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
val_transform = transforms.Compose([
    transforms.Resize(_EVAL_IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])
# ---- Load Data (with caching) ----
@st.cache_resource
def load_test_data(csv_path):
    """Return a DataLoader over the test split listed in *csv_path*.

    Samples go through ``val_transform`` and are served in CSV order
    (no shuffling) in batches of 32. ``st.cache_resource`` makes reruns reuse
    the same loader for the same ``csv_path``.
    """
    return DataLoader(
        DDRDataset(csv_path=csv_path, transform=val_transform),
        batch_size=32,
        shuffle=False,
    )
# ---- Load Model (with caching) ----
@st.cache_resource
def load_model():
    """Build DenseNet-121 with a len(class_names)-way head and load trained weights.

    The classifier layer is replaced by a Linear layer sized to ``class_names``.
    The state dict is then loaded on CPU from
    ``training/Pretrained_Densenet-121.pth`` (relative to the working
    directory), and the model is switched to eval mode. Cached with
    ``st.cache_resource`` so reruns reuse the same model object.
    """
    model = models.densenet121(pretrained=False)
    model.classifier = nn.Linear(model.classifier.in_features, len(class_names))
    # os.path.join keeps the path portable: a hard-coded backslash is a literal
    # filename character on Linux/macOS, not a directory separator.
    weights_path = os.path.join("training", "Pretrained_Densenet-121.pth")
    state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return model
# ---- Main UI Buttons ----
# os.path.join keeps the split path portable: a hard-coded backslash is a
# literal filename character on Linux/macOS, not a directory separator.
csv_path = os.path.join("splits", "test_labels.csv")
model = load_model()
test_loader = load_test_data(csv_path)
col1, col2 = st.columns([1, 1])
with col1:
    if st.button("🚀 Start Evaluation"):
        # Begin a fresh run: clear any earlier stop/done state and arm the trigger.
        st.session_state.stop_eval = False
        st.session_state.evaluation_done = False
        st.session_state.trigger_eval = True
with col2:
    if st.button("🚩 Stop Evaluation"):
        # Pressing a button reruns the script; the evaluation loop checks this
        # flag before each batch and stops.
        st.session_state.stop_eval = True
# ---- Description for Model Evaluation ----
# Static help text for the page, rendered as markdown with raw HTML allowed.
# The misclassified gallery shows the first samples encountered, not a ranked
# "top" selection, and the wording says so.
with st.expander("ℹ️ **What is Model Evaluation?**", expanded=True):
    st.markdown("""
<div style='font-size:16px;'>
The **Model Evaluation** section tests how well the trained AI model performs on the unseen <strong>test set</strong> of retinal images. This provides insights into the reliability and performance of the model when deployed in real scenarios.
#### 🔍 What It Does:
- Loads the test dataset of labeled retinal images
- Runs the model to predict labels
- Compares predictions vs. true labels
- Computes:
- 📋 **Classification Report** (Precision, Recall, F1-Score)
- 🧊 **Confusion Matrix**
- 📈 **Multi-class ROC Curve**
- ❌ **Misclassified Image Samples**
- Saves the full report as a downloadable PDF
#### 🧭 How to Use:
1. Click **🚀 Start Evaluation** to begin analyzing the model’s performance.
2. Wait for the evaluation to finish (shows progress bar and batch updates).
3. Once done:
- Check performance scores for each DR class
- View visual summaries like confusion matrix and ROC curve
- See the first 5 misclassified examples
4. Optionally, download the full evaluation report via **📄 Download PDF**
⚠️ <i>Note: This evaluation runs on the full test set and might take several seconds depending on hardware.</i>
</div>
""", unsafe_allow_html=True)
# ---- Evaluation Logic ----
# Only this many misclassified samples are displayed, so collection stops there
# instead of holding every misclassified image tensor in memory.
MAX_MISCLASSIFIED_SHOWN = 5
# Must match the mean/std given to transforms.Normalize in val_transform; used
# to undo that normalization so misclassified images display in true colors.
_DISPLAY_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_DISPLAY_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def _to_displayable(img):
    """Convert a normalized (C, H, W) image tensor to an (H, W, C) array in [0, 1].

    imshow expects float RGB data in [0, 1]; the normalized tensor holds values
    outside that range, which imshow would clip into distorted colors.
    """
    restored = img * _DISPLAY_STD + _DISPLAY_MEAN
    return restored.clamp(0.0, 1.0).permute(1, 2, 0).numpy()


if st.session_state.trigger_eval:
    st.markdown("### ⏱️ Evaluation Results")
    start_time = time.time()
    y_true = []
    y_pred = []
    y_score = []
    misclassified_images = []
    total_batches = len(test_loader)
    progress_bar = st.progress(0)
    status_text = st.empty()
    stop_info = st.empty()
    with torch.no_grad():
        for i, (images, labels) in enumerate(test_loader):
            # The Stop button sets this flag on a rerun; checked once per batch.
            if st.session_state.stop_eval:
                stop_info.warning("🚩 Evaluation stopped by user.")
                break
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            y_true.extend(labels.numpy())
            y_pred.extend(predicted.numpy())
            # Raw logits serve as per-class decision scores for the ROC curves.
            y_score.extend(outputs.detach().numpy())
            wrong_idx = (predicted != labels).nonzero(as_tuple=True)[0].tolist()
            remaining = MAX_MISCLASSIFIED_SHOWN - len(misclassified_images)
            for j in wrong_idx[:remaining]:
                # clone() so a stored sample does not keep its whole batch alive.
                misclassified_images.append((images[j].clone(), predicted[j].item(), labels[j].item()))
            percent_complete = (i + 1) / total_batches
            progress_bar.progress(min(percent_complete, 1.0))
            status_text.text(f"Evaluating on Test Set: {int(percent_complete * 100)}% | Batch {i+1}/{total_batches}")
            # NOTE(review): fixed per-batch pause, presumably so the progress UI
            # visibly updates -- confirm it is still wanted.
            time.sleep(0.1)
    end_time = time.time()
    eval_time = end_time - start_time
    # Clear the trigger whether or not the run was stopped. If it stayed set
    # after a stop, every later rerun would re-enter this section.
    st.session_state.trigger_eval = False
    if not st.session_state.stop_eval:
        st.session_state.evaluation_done = True
        st.success(f"✅ Evaluation completed in **{eval_time:.2f} seconds**")

        # Passing every class index keeps the report and confusion matrix
        # 5-class even if some class is absent from this split; otherwise
        # target_names and the heatmap tick labels would not line up.
        label_ids = list(range(len(class_names)))
        report = classification_report(
            y_true, y_pred, labels=label_ids, target_names=class_names, output_dict=True
        )
        report_df = pd.DataFrame(report).transpose()
        st.dataframe(report_df.style.format("{:.2f}"))

        # ---- PDF: per-class table ----
        pdf = FPDF()
        pdf.add_page()
        pdf.set_font("Arial", size=12)
        pdf.cell(200, 10, txt=clean_text("Classification Report"), ln=True, align='C')
        col_widths = [40, 40, 40, 40]
        headers = ["Class", "Precision", "Recall", "F1-Score"]
        for width, header in zip(col_widths, headers):
            pdf.cell(width, 10, header, border=1)
        pdf.ln()
        for idx, row in report_df.iterrows():
            # Summary rows appear on the page table but are left out of the PDF.
            if idx in ['accuracy', 'macro avg', 'weighted avg']:
                continue
            pdf.cell(col_widths[0], 10, str(idx), border=1)
            pdf.cell(col_widths[1], 10, f"{row['precision']:.2f}", border=1)
            pdf.cell(col_widths[2], 10, f"{row['recall']:.2f}", border=1)
            pdf.cell(col_widths[3], 10, f"{row['f1-score']:.2f}", border=1)
            pdf.ln()

        # ---- Confusion matrix ----
        cm = confusion_matrix(y_true, y_pred, labels=label_ids)
        fig_cm, ax_cm = plt.subplots()
        sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names, cmap="Blues", ax=ax_cm)
        ax_cm.set_xlabel('Predicted')
        ax_cm.set_ylabel('True')
        ax_cm.set_title("Confusion Matrix")
        st.pyplot(fig_cm)
        cm_path = "confusion_matrix.png"
        fig_cm.savefig(cm_path, format='png', dpi=300, bbox_inches='tight')
        plt.close(fig_cm)
        if os.path.exists(cm_path):
            pdf.image(cm_path, x=10, y=None, w=180)

        # ---- One-vs-rest ROC curves ----
        y_true_bin = label_binarize(y_true, classes=label_ids)
        y_score_np = np.array(y_score)
        fig_roc, ax_roc = plt.subplots()
        for class_idx, class_name in enumerate(class_names):
            fpr, tpr, _ = roc_curve(y_true_bin[:, class_idx], y_score_np[:, class_idx])
            roc_auc = auc(fpr, tpr)
            ax_roc.plot(fpr, tpr, label=f'{class_name} (AUC = {roc_auc:.2f})')
        ax_roc.plot([0, 1], [0, 1], 'k--')
        ax_roc.set_xlabel('False Positive Rate')
        ax_roc.set_ylabel('True Positive Rate')
        ax_roc.set_title('Multi-class ROC Curve')
        ax_roc.legend(loc='lower right')
        st.pyplot(fig_roc)
        roc_path = "roc_curve.png"
        fig_roc.savefig(roc_path, format='png', dpi=300, bbox_inches='tight')
        plt.close(fig_roc)
        if os.path.exists(roc_path):
            pdf.image(roc_path, x=10, y=None, w=180)

        # ---- Misclassified samples (first few encountered) ----
        st.markdown("### ❌ Misclassified Samples")
        if misclassified_images:
            # squeeze=False always yields a 2-D Axes array, so indexing works
            # even when only one sample is plotted.
            fig_mis, axs = plt.subplots(1, len(misclassified_images), figsize=(15, 4), squeeze=False)
            for ax_mis, (img, pred, true) in zip(axs[0], misclassified_images):
                ax_mis.imshow(_to_displayable(img))
                ax_mis.set_title(f"True: {class_names[true]}\nPred: {class_names[pred]}")
                ax_mis.axis('off')
            st.pyplot(fig_mis)
            plt.close(fig_mis)
        else:
            st.info("No misclassified samples in the test set.")

        # ---- PDF download ----
        output_pdf = "evaluation_report.pdf"
        pdf.output(output_pdf)
        with open(output_pdf, "rb") as f:
            pdf_bytes = f.read()
        _, download_col = st.columns([1, 1])
        with download_col:
            st.download_button(
                "📄 Download Full Evaluation PDF",
                pdf_bytes,
                file_name="evaluation_report.pdf",
                mime="application/pdf",
            )