Spaces:
Runtime error
Runtime error
revert changes
Browse files- pages/Model_Evaluation.py +54 -32
pages/Model_Evaluation.py
CHANGED
|
@@ -16,7 +16,7 @@ import streamlit as st
|
|
| 16 |
import matplotlib.pyplot as plt
|
| 17 |
from fpdf import FPDF
|
| 18 |
from datasets import load_dataset
|
| 19 |
-
from huggingface_hub import hf_hub_download
|
| 20 |
|
| 21 |
# ---- Streamlit State Initialization ----
|
| 22 |
if 'stop_eval' not in st.session_state:
|
|
@@ -27,7 +27,7 @@ if 'trigger_eval' not in st.session_state:
|
|
| 27 |
st.session_state.trigger_eval = False
|
| 28 |
|
| 29 |
# ---- Streamlit Title ----
|
| 30 |
-
st.markdown("<h2 style='color: #2E86C1;'
|
| 31 |
|
| 32 |
# ---- Class Names & Label Mapping ----
|
| 33 |
class_names = ['No_DR', 'Mild', 'Moderate', 'Severe', 'Proliferative_DR']
|
|
@@ -93,13 +93,21 @@ val_transform = transforms.Compose([
|
|
| 93 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 94 |
])
|
| 95 |
|
| 96 |
-
# ---- Load Data ----
|
| 97 |
@st.cache_resource
|
| 98 |
-
def
|
| 99 |
-
dataset =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
return DataLoader(dataset, batch_size=32, shuffle=False)
|
| 101 |
|
| 102 |
-
# ---- Load Model ----
|
| 103 |
@st.cache_resource
|
| 104 |
def load_model():
|
| 105 |
model_path = hf_hub_download(repo_id="Ci-Dave/Densenet121", filename="Pretrained_Densenet-121.pth")
|
|
@@ -111,43 +119,55 @@ def load_model():
|
|
| 111 |
|
| 112 |
# ---- UI Buttons ----
|
| 113 |
model = load_model()
|
| 114 |
-
test_loader =
|
| 115 |
|
| 116 |
col1, col2 = st.columns([1, 1])
|
| 117 |
with col1:
|
| 118 |
-
if st.button("
|
| 119 |
st.session_state.stop_eval = False
|
| 120 |
st.session_state.evaluation_done = False
|
| 121 |
st.session_state.trigger_eval = True
|
| 122 |
with col2:
|
| 123 |
-
if st.button("
|
| 124 |
st.session_state.stop_eval = True
|
| 125 |
|
| 126 |
if st.session_state.evaluation_done:
|
| 127 |
reevaluate_col, download_col = st.columns([1, 1])
|
| 128 |
|
| 129 |
-
# ---- Model Evaluation
|
| 130 |
-
with st.expander("
|
| 131 |
st.markdown("""
|
| 132 |
<div style='font-size:16px;'>
|
| 133 |
-
The
|
| 134 |
|
| 135 |
-
#### What It Does:
|
| 136 |
-
- Loads the test dataset
|
| 137 |
- Runs the model to predict labels
|
| 138 |
- Compares predictions vs. true labels
|
| 139 |
- Computes:
|
| 140 |
-
- Classification Report
|
| 141 |
-
- Confusion Matrix
|
| 142 |
-
- ROC Curve
|
| 143 |
-
- Misclassified Samples
|
| 144 |
-
- Saves a downloadable PDF
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
</div>
|
| 146 |
""", unsafe_allow_html=True)
|
| 147 |
|
|
|
|
| 148 |
# ---- Evaluation Logic ----
|
| 149 |
if st.session_state.trigger_eval:
|
| 150 |
-
st.markdown("###
|
| 151 |
|
| 152 |
start_time = time.time()
|
| 153 |
y_true = []
|
|
@@ -163,14 +183,14 @@ if st.session_state.trigger_eval:
|
|
| 163 |
with torch.no_grad():
|
| 164 |
for i, (images, labels) in enumerate(test_loader):
|
| 165 |
if st.session_state.stop_eval:
|
| 166 |
-
stop_info.warning("
|
| 167 |
break
|
| 168 |
|
| 169 |
outputs = model(images)
|
| 170 |
_, predicted = torch.max(outputs, 1)
|
| 171 |
y_true.extend(labels.numpy())
|
| 172 |
y_pred.extend(predicted.numpy())
|
| 173 |
-
y_score.extend(outputs.numpy())
|
| 174 |
|
| 175 |
for j in range(len(labels)):
|
| 176 |
if predicted[j] != labels[j]:
|
|
@@ -178,14 +198,15 @@ if st.session_state.trigger_eval:
|
|
| 178 |
|
| 179 |
percent_complete = (i + 1) / total_batches
|
| 180 |
progress_bar.progress(min(percent_complete, 1.0))
|
| 181 |
-
status_text.text(f"Evaluating: {int(percent_complete * 100)}% | Batch {i+1}/{total_batches}")
|
|
|
|
| 182 |
|
| 183 |
end_time = time.time()
|
| 184 |
eval_time = end_time - start_time
|
| 185 |
|
| 186 |
if not st.session_state.stop_eval:
|
| 187 |
st.session_state.evaluation_done = True
|
| 188 |
-
st.session_state.trigger_eval = False
|
| 189 |
st.success(f"β
Evaluation completed in **{eval_time:.2f} seconds**")
|
| 190 |
|
| 191 |
report = classification_report(y_true, y_pred, target_names=class_names, output_dict=True)
|
|
@@ -195,20 +216,21 @@ if st.session_state.trigger_eval:
|
|
| 195 |
pdf = FPDF()
|
| 196 |
pdf.add_page()
|
| 197 |
pdf.set_font("Arial", size=12)
|
| 198 |
-
pdf.cell(200, 10, txt="Classification Report", ln=True, align='C')
|
| 199 |
|
|
|
|
| 200 |
headers = ["Class", "Precision", "Recall", "F1-Score"]
|
| 201 |
-
for header in headers:
|
| 202 |
-
pdf.cell(
|
| 203 |
pdf.ln()
|
| 204 |
|
| 205 |
for idx, row in report_df.iterrows():
|
| 206 |
if idx in ['accuracy', 'macro avg', 'weighted avg']:
|
| 207 |
continue
|
| 208 |
-
pdf.cell(
|
| 209 |
-
pdf.cell(
|
| 210 |
-
pdf.cell(
|
| 211 |
-
pdf.cell(
|
| 212 |
pdf.ln()
|
| 213 |
|
| 214 |
cm = confusion_matrix(y_true, y_pred)
|
|
@@ -257,4 +279,4 @@ if st.session_state.trigger_eval:
|
|
| 257 |
with open(output_pdf, "rb") as f:
|
| 258 |
reevaluate_col, download_col = st.columns([1, 1])
|
| 259 |
with download_col:
|
| 260 |
-
st.download_button("
|
|
|
|
| 16 |
import matplotlib.pyplot as plt
|
| 17 |
from fpdf import FPDF
|
| 18 |
from datasets import load_dataset
|
| 19 |
+
from huggingface_hub import hf_hub_download # β
NEW
|
| 20 |
|
| 21 |
# ---- Streamlit State Initialization ----
|
| 22 |
if 'stop_eval' not in st.session_state:
|
|
|
|
| 27 |
st.session_state.trigger_eval = False
|
| 28 |
|
| 29 |
# ---- Streamlit Title ----
|
| 30 |
+
st.markdown("<h2 style='color: #2E86C1;'>π Model Evaluation</h2>", unsafe_allow_html=True)
|
| 31 |
|
| 32 |
# ---- Class Names & Label Mapping ----
|
| 33 |
class_names = ['No_DR', 'Mild', 'Moderate', 'Severe', 'Proliferative_DR']
|
|
|
|
| 93 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 94 |
])
|
| 95 |
|
| 96 |
+
# ---- Load Data from Hugging Face (cached) ----
|
| 97 |
@st.cache_resource
|
| 98 |
+
def load_test_data_from_huggingface():
|
| 99 |
+
dataset = load_dataset(
|
| 100 |
+
"Ci-Dave/DDR_dataset_train_test",
|
| 101 |
+
data_files={"test": "splits/test_labels.csv"},
|
| 102 |
+
split="test"
|
| 103 |
+
)
|
| 104 |
+
df = dataset.to_pandas()
|
| 105 |
+
csv_path = "test_labels_temp.csv"
|
| 106 |
+
df.to_csv(csv_path, index=False)
|
| 107 |
+
dataset = DDRDataset(csv_path=csv_path, transform=val_transform)
|
| 108 |
return DataLoader(dataset, batch_size=32, shuffle=False)
|
| 109 |
|
| 110 |
+
# ---- Load Model from Hugging Face (cached) ----
|
| 111 |
@st.cache_resource
|
| 112 |
def load_model():
|
| 113 |
model_path = hf_hub_download(repo_id="Ci-Dave/Densenet121", filename="Pretrained_Densenet-121.pth")
|
|
|
|
| 119 |
|
| 120 |
# ---- UI Buttons ----
|
| 121 |
model = load_model()
|
| 122 |
+
test_loader = load_test_data_from_huggingface()
|
| 123 |
|
| 124 |
col1, col2 = st.columns([1, 1])
|
| 125 |
with col1:
|
| 126 |
+
if st.button("π Start Evaluation"):
|
| 127 |
st.session_state.stop_eval = False
|
| 128 |
st.session_state.evaluation_done = False
|
| 129 |
st.session_state.trigger_eval = True
|
| 130 |
with col2:
|
| 131 |
+
if st.button("π© Stop Evaluation"):
|
| 132 |
st.session_state.stop_eval = True
|
| 133 |
|
| 134 |
if st.session_state.evaluation_done:
|
| 135 |
reevaluate_col, download_col = st.columns([1, 1])
|
| 136 |
|
| 137 |
+
# ---- Description for Model Evaluation ----
|
| 138 |
+
with st.expander("βΉοΈ **What is Model Evaluation?**", expanded=True):
|
| 139 |
st.markdown("""
|
| 140 |
<div style='font-size:16px;'>
|
| 141 |
+
The **Model Evaluation** section tests how well the trained AI model performs on the unseen <strong>test set</strong> of retinal images. This provides insights into the reliability and performance of the model when deployed in real scenarios.
|
| 142 |
|
| 143 |
+
#### π What It Does:
|
| 144 |
+
- Loads the test dataset of labeled retinal images
|
| 145 |
- Runs the model to predict labels
|
| 146 |
- Compares predictions vs. true labels
|
| 147 |
- Computes:
|
| 148 |
+
- π **Classification Report** (Precision, Recall, F1-Score)
|
| 149 |
+
- π§ **Confusion Matrix**
|
| 150 |
+
- π **Multi-class ROC Curve**
|
| 151 |
+
- β **Misclassified Image Samples**
|
| 152 |
+
- Saves the full report as a downloadable PDF
|
| 153 |
+
|
| 154 |
+
#### π§ How to Use:
|
| 155 |
+
1. Click **π Start Evaluation** to begin analyzing the modelβs performance.
|
| 156 |
+
2. Wait for the evaluation to finish (shows progress bar and batch updates).
|
| 157 |
+
3. Once done:
|
| 158 |
+
- Check performance scores for each DR class
|
| 159 |
+
- View visual summaries like confusion matrix and ROC curve
|
| 160 |
+
- See the top 5 misclassified examples
|
| 161 |
+
4. Optionally, download the full evaluation report via **π Download PDF**
|
| 162 |
+
|
| 163 |
+
β οΈ <i>Note: This evaluation runs on the full test set and might take several seconds depending on hardware.</i>
|
| 164 |
</div>
|
| 165 |
""", unsafe_allow_html=True)
|
| 166 |
|
| 167 |
+
|
| 168 |
# ---- Evaluation Logic ----
|
| 169 |
if st.session_state.trigger_eval:
|
| 170 |
+
st.markdown("### β±οΈ Evaluation Results")
|
| 171 |
|
| 172 |
start_time = time.time()
|
| 173 |
y_true = []
|
|
|
|
| 183 |
with torch.no_grad():
|
| 184 |
for i, (images, labels) in enumerate(test_loader):
|
| 185 |
if st.session_state.stop_eval:
|
| 186 |
+
stop_info.warning("π© Evaluation stopped by user.")
|
| 187 |
break
|
| 188 |
|
| 189 |
outputs = model(images)
|
| 190 |
_, predicted = torch.max(outputs, 1)
|
| 191 |
y_true.extend(labels.numpy())
|
| 192 |
y_pred.extend(predicted.numpy())
|
| 193 |
+
y_score.extend(outputs.detach().numpy())
|
| 194 |
|
| 195 |
for j in range(len(labels)):
|
| 196 |
if predicted[j] != labels[j]:
|
|
|
|
| 198 |
|
| 199 |
percent_complete = (i + 1) / total_batches
|
| 200 |
progress_bar.progress(min(percent_complete, 1.0))
|
| 201 |
+
status_text.text(f"Evaluating on Test Set: {int(percent_complete * 100)}% | Batch {i+1}/{total_batches}")
|
| 202 |
+
time.sleep(0.1)
|
| 203 |
|
| 204 |
end_time = time.time()
|
| 205 |
eval_time = end_time - start_time
|
| 206 |
|
| 207 |
if not st.session_state.stop_eval:
|
| 208 |
st.session_state.evaluation_done = True
|
| 209 |
+
st.session_state.trigger_eval = False # β
Reset the trigger
|
| 210 |
st.success(f"β
Evaluation completed in **{eval_time:.2f} seconds**")
|
| 211 |
|
| 212 |
report = classification_report(y_true, y_pred, target_names=class_names, output_dict=True)
|
|
|
|
| 216 |
pdf = FPDF()
|
| 217 |
pdf.add_page()
|
| 218 |
pdf.set_font("Arial", size=12)
|
| 219 |
+
pdf.cell(200, 10, txt=clean_text("Classification Report"), ln=True, align='C')
|
| 220 |
|
| 221 |
+
col_widths = [40, 40, 40, 40]
|
| 222 |
headers = ["Class", "Precision", "Recall", "F1-Score"]
|
| 223 |
+
for i, header in enumerate(headers):
|
| 224 |
+
pdf.cell(col_widths[i], 10, header, border=1)
|
| 225 |
pdf.ln()
|
| 226 |
|
| 227 |
for idx, row in report_df.iterrows():
|
| 228 |
if idx in ['accuracy', 'macro avg', 'weighted avg']:
|
| 229 |
continue
|
| 230 |
+
pdf.cell(col_widths[0], 10, str(idx), border=1)
|
| 231 |
+
pdf.cell(col_widths[1], 10, f"{row['precision']:.2f}", border=1)
|
| 232 |
+
pdf.cell(col_widths[2], 10, f"{row['recall']:.2f}", border=1)
|
| 233 |
+
pdf.cell(col_widths[3], 10, f"{row['f1-score']:.2f}", border=1)
|
| 234 |
pdf.ln()
|
| 235 |
|
| 236 |
cm = confusion_matrix(y_true, y_pred)
|
|
|
|
| 279 |
with open(output_pdf, "rb") as f:
|
| 280 |
reevaluate_col, download_col = st.columns([1, 1])
|
| 281 |
with download_col:
|
| 282 |
+
st.download_button("π Download Full Evaluation PDF", f, file_name="evaluation_report.pdf")
|