3v324v23 commited on
Commit
8e7b0b6
·
1 Parent(s): 0e30fdd
Files changed (1) hide show
  1. pages/Model_Evaluation.py +31 -53
pages/Model_Evaluation.py CHANGED
@@ -16,7 +16,7 @@ import streamlit as st
16
  import matplotlib.pyplot as plt
17
  from fpdf import FPDF
18
  from datasets import load_dataset
19
- from huggingface_hub import hf_hub_download # ✅ NEW
20
 
21
  # ---- Streamlit State Initialization ----
22
  if 'stop_eval' not in st.session_state:
@@ -27,7 +27,7 @@ if 'trigger_eval' not in st.session_state:
27
  st.session_state.trigger_eval = False
28
 
29
  # ---- Streamlit Title ----
30
- st.markdown("<h2 style='color: #2E86C1;'>📈 Model Evaluation</h2>", unsafe_allow_html=True)
31
 
32
  # ---- Class Names & Label Mapping ----
33
  class_names = ['No_DR', 'Mild', 'Moderate', 'Severe', 'Proliferative_DR']
@@ -93,21 +93,13 @@ val_transform = transforms.Compose([
93
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
94
  ])
95
 
96
- # ---- Load Data from Hugging Face (cached) ----
97
  @st.cache_resource
98
- def load_test_data_from_huggingface():
99
- dataset = load_dataset(
100
- "Ci-Dave/DDR_dataset_train_test",
101
- data_files={"test": "splits/test_labels.csv"},
102
- split="test"
103
- )
104
- df = dataset.to_pandas()
105
- csv_path = "test_labels_temp.csv"
106
- df.to_csv(csv_path, index=False)
107
- dataset = DDRDataset(csv_path=csv_path, transform=val_transform)
108
  return DataLoader(dataset, batch_size=32, shuffle=False)
109
 
110
- # ---- Load Model from Hugging Face (cached) ----
111
  @st.cache_resource
112
  def load_model():
113
  model_path = hf_hub_download(repo_id="Ci-Dave/Densenet121", filename="Pretrained_Densenet-121.pth")
@@ -119,55 +111,43 @@ def load_model():
119
 
120
  # ---- UI Buttons ----
121
  model = load_model()
122
- test_loader = load_test_data_from_huggingface()
123
 
124
  col1, col2 = st.columns([1, 1])
125
  with col1:
126
- if st.button("🚀 Start Evaluation"):
127
  st.session_state.stop_eval = False
128
  st.session_state.evaluation_done = False
129
  st.session_state.trigger_eval = True
130
  with col2:
131
- if st.button("🚩 Stop Evaluation"):
132
  st.session_state.stop_eval = True
133
 
134
  if st.session_state.evaluation_done:
135
  reevaluate_col, download_col = st.columns([1, 1])
136
 
137
- # ---- Description for Model Evaluation ----
138
- with st.expander("ℹ️ **What is Model Evaluation?**", expanded=True):
139
  st.markdown("""
140
  <div style='font-size:16px;'>
141
- The **Model Evaluation** section tests how well the trained AI model performs on the unseen <strong>test set</strong> of retinal images. This provides insights into the reliability and performance of the model when deployed in real scenarios.
142
 
143
- #### 🔍 What It Does:
144
- - Loads the test dataset of labeled retinal images
145
  - Runs the model to predict labels
146
  - Compares predictions vs. true labels
147
  - Computes:
148
- - 📋 **Classification Report** (Precision, Recall, F1-Score)
149
- - 🧊 **Confusion Matrix**
150
- - 📈 **Multi-class ROC Curve**
151
- - ❌ **Misclassified Image Samples**
152
- - Saves the full report as a downloadable PDF
153
-
154
- #### 🧭 How to Use:
155
- 1. Click **🚀 Start Evaluation** to begin analyzing the model’s performance.
156
- 2. Wait for the evaluation to finish (shows progress bar and batch updates).
157
- 3. Once done:
158
- - Check performance scores for each DR class
159
- - View visual summaries like confusion matrix and ROC curve
160
- - See the top 5 misclassified examples
161
- 4. Optionally, download the full evaluation report via **📄 Download PDF**
162
-
163
- ⚠️ <i>Note: This evaluation runs on the full test set and might take several seconds depending on hardware.</i>
164
  </div>
165
  """, unsafe_allow_html=True)
166
 
167
-
168
  # ---- Evaluation Logic ----
169
  if st.session_state.trigger_eval:
170
- st.markdown("### ⏱️ Evaluation Results")
171
 
172
  start_time = time.time()
173
  y_true = []
@@ -183,14 +163,14 @@ if st.session_state.trigger_eval:
183
  with torch.no_grad():
184
  for i, (images, labels) in enumerate(test_loader):
185
  if st.session_state.stop_eval:
186
- stop_info.warning("🚩 Evaluation stopped by user.")
187
  break
188
 
189
  outputs = model(images)
190
  _, predicted = torch.max(outputs, 1)
191
  y_true.extend(labels.numpy())
192
  y_pred.extend(predicted.numpy())
193
- y_score.extend(outputs.detach().numpy())
194
 
195
  for j in range(len(labels)):
196
  if predicted[j] != labels[j]:
@@ -198,15 +178,14 @@ if st.session_state.trigger_eval:
198
 
199
  percent_complete = (i + 1) / total_batches
200
  progress_bar.progress(min(percent_complete, 1.0))
201
- status_text.text(f"Evaluating on Test Set: {int(percent_complete * 100)}% | Batch {i+1}/{total_batches}")
202
- time.sleep(0.1)
203
 
204
  end_time = time.time()
205
  eval_time = end_time - start_time
206
 
207
  if not st.session_state.stop_eval:
208
  st.session_state.evaluation_done = True
209
- st.session_state.trigger_eval = False # ✅ Reset the trigger
210
  st.success(f"✅ Evaluation completed in **{eval_time:.2f} seconds**")
211
 
212
  report = classification_report(y_true, y_pred, target_names=class_names, output_dict=True)
@@ -218,19 +197,18 @@ if st.session_state.trigger_eval:
218
  pdf.set_font("Arial", size=12)
219
  pdf.cell(200, 10, txt=clean_text("Classification Report"), ln=True, align='C')
220
 
221
- col_widths = [40, 40, 40, 40]
222
  headers = ["Class", "Precision", "Recall", "F1-Score"]
223
- for i, header in enumerate(headers):
224
- pdf.cell(col_widths[i], 10, header, border=1)
225
  pdf.ln()
226
 
227
  for idx, row in report_df.iterrows():
228
  if idx in ['accuracy', 'macro avg', 'weighted avg']:
229
  continue
230
- pdf.cell(col_widths[0], 10, str(idx), border=1)
231
- pdf.cell(col_widths[1], 10, f"{row['precision']:.2f}", border=1)
232
- pdf.cell(col_widths[2], 10, f"{row['recall']:.2f}", border=1)
233
- pdf.cell(col_widths[3], 10, f"{row['f1-score']:.2f}", border=1)
234
  pdf.ln()
235
 
236
  cm = confusion_matrix(y_true, y_pred)
@@ -279,4 +257,4 @@ if st.session_state.trigger_eval:
279
  with open(output_pdf, "rb") as f:
280
  reevaluate_col, download_col = st.columns([1, 1])
281
  with download_col:
282
- st.download_button("📄 Download Full Evaluation PDF", f, file_name="evaluation_report.pdf")
 
16
  import matplotlib.pyplot as plt
17
  from fpdf import FPDF
18
  from datasets import load_dataset
19
+ from huggingface_hub import hf_hub_download
20
 
21
  # ---- Streamlit State Initialization ----
22
  if 'stop_eval' not in st.session_state:
 
27
  st.session_state.trigger_eval = False
28
 
29
  # ---- Streamlit Title ----
30
+ st.markdown("<h2 style='color: #2E86C1;'>📈 Model Evaluation</h2>", unsafe_allow_html=True)
31
 
32
  # ---- Class Names & Label Mapping ----
33
  class_names = ['No_DR', 'Mild', 'Moderate', 'Severe', 'Proliferative_DR']
 
93
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
94
  ])
95
 
96
+ # ---- Load Data ----
97
  @st.cache_resource
98
+ def load_test_data():
99
+ dataset = DDRDataset(csv_path="splits/test_labels_with_paths.csv", transform=val_transform)
 
 
 
 
 
 
 
 
100
  return DataLoader(dataset, batch_size=32, shuffle=False)
101
 
102
+ # ---- Load Model ----
103
  @st.cache_resource
104
  def load_model():
105
  model_path = hf_hub_download(repo_id="Ci-Dave/Densenet121", filename="Pretrained_Densenet-121.pth")
 
111
 
112
  # ---- UI Buttons ----
113
  model = load_model()
114
+ test_loader = load_test_data()
115
 
116
  col1, col2 = st.columns([1, 1])
117
  with col1:
118
+ if st.button("🚀 Start Evaluation"):
119
  st.session_state.stop_eval = False
120
  st.session_state.evaluation_done = False
121
  st.session_state.trigger_eval = True
122
  with col2:
123
+ if st.button("🚩 Stop Evaluation"):
124
  st.session_state.stop_eval = True
125
 
126
  if st.session_state.evaluation_done:
127
  reevaluate_col, download_col = st.columns([1, 1])
128
 
129
+ # ---- Model Evaluation Explanation ----
130
+ with st.expander("ℹ️ **What is Model Evaluation?**", expanded=True):
131
  st.markdown("""
132
  <div style='font-size:16px;'>
133
+ The <strong>Model Evaluation</strong> section tests how well the trained AI model performs on the test set of retinal images.
134
 
135
+ #### What It Does:
136
+ - Loads the test dataset
137
  - Runs the model to predict labels
138
  - Compares predictions vs. true labels
139
  - Computes:
140
+ - Classification Report
141
+ - Confusion Matrix
142
+ - ROC Curve
143
+ - Misclassified Samples
144
+ - Saves a downloadable PDF report
 
 
 
 
 
 
 
 
 
 
 
145
  </div>
146
  """, unsafe_allow_html=True)
147
 
 
148
  # ---- Evaluation Logic ----
149
  if st.session_state.trigger_eval:
150
+ st.markdown("### ⏱️ Evaluation Results")
151
 
152
  start_time = time.time()
153
  y_true = []
 
163
  with torch.no_grad():
164
  for i, (images, labels) in enumerate(test_loader):
165
  if st.session_state.stop_eval:
166
+ stop_info.warning("🚩 Evaluation stopped by user.")
167
  break
168
 
169
  outputs = model(images)
170
  _, predicted = torch.max(outputs, 1)
171
  y_true.extend(labels.numpy())
172
  y_pred.extend(predicted.numpy())
173
+ y_score.extend(outputs.numpy())
174
 
175
  for j in range(len(labels)):
176
  if predicted[j] != labels[j]:
 
178
 
179
  percent_complete = (i + 1) / total_batches
180
  progress_bar.progress(min(percent_complete, 1.0))
181
+ status_text.text(f"Evaluating: {int(percent_complete * 100)}% | Batch {i+1}/{total_batches}")
 
182
 
183
  end_time = time.time()
184
  eval_time = end_time - start_time
185
 
186
  if not st.session_state.stop_eval:
187
  st.session_state.evaluation_done = True
188
+ st.session_state.trigger_eval = False
189
  st.success(f"✅ Evaluation completed in **{eval_time:.2f} seconds**")
190
 
191
  report = classification_report(y_true, y_pred, target_names=class_names, output_dict=True)
 
197
  pdf.set_font("Arial", size=12)
198
  pdf.cell(200, 10, txt=clean_text("Classification Report"), ln=True, align='C')
199
 
 
200
  headers = ["Class", "Precision", "Recall", "F1-Score"]
201
+ for header in headers:
202
+ pdf.cell(40, 10, header, border=1)
203
  pdf.ln()
204
 
205
  for idx, row in report_df.iterrows():
206
  if idx in ['accuracy', 'macro avg', 'weighted avg']:
207
  continue
208
+ pdf.cell(40, 10, str(idx), border=1)
209
+ pdf.cell(40, 10, f"{row['precision']:.2f}", border=1)
210
+ pdf.cell(40, 10, f"{row['recall']:.2f}", border=1)
211
+ pdf.cell(40, 10, f"{row['f1-score']:.2f}", border=1)
212
  pdf.ln()
213
 
214
  cm = confusion_matrix(y_true, y_pred)
 
257
  with open(output_pdf, "rb") as f:
258
  reevaluate_col, download_col = st.columns([1, 1])
259
  with download_col:
260
+ st.download_button("📄 Download Full Evaluation PDF", f, file_name="evaluation_report.pdf")