3v324v23 committed
Commit b925d8f · 1 parent: da6f0a0

revert changes

Files changed (1)
  1. pages/Model_Evaluation.py +54 -32
pages/Model_Evaluation.py CHANGED
@@ -16,7 +16,7 @@ import streamlit as st
 import matplotlib.pyplot as plt
 from fpdf import FPDF
 from datasets import load_dataset
-from huggingface_hub import hf_hub_download
+from huggingface_hub import hf_hub_download  # ✅ NEW
 
 # ---- Streamlit State Initialization ----
 if 'stop_eval' not in st.session_state:
@@ -27,7 +27,7 @@ if 'trigger_eval' not in st.session_state:
     st.session_state.trigger_eval = False
 
 # ---- Streamlit Title ----
-st.markdown("<h2 style='color: #2E86C1;'>📈 Model Evaluation</h2>", unsafe_allow_html=True)
+st.markdown("<h2 style='color: #2E86C1;'>📈 Model Evaluation</h2>", unsafe_allow_html=True)
 
 # ---- Class Names & Label Mapping ----
 class_names = ['No_DR', 'Mild', 'Moderate', 'Severe', 'Proliferative_DR']
@@ -93,13 +93,21 @@ val_transform = transforms.Compose([
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 ])
 
-# ---- Load Data ----
+# ---- Load Data from Hugging Face (cached) ----
 @st.cache_resource
-def load_test_data():
-    dataset = DDRDataset(csv_path="splits/test_labels_with_paths.csv", transform=val_transform)
+def load_test_data_from_huggingface():
+    dataset = load_dataset(
+        "Ci-Dave/DDR_dataset_train_test",
+        data_files={"test": "splits/test_labels.csv"},
+        split="test"
+    )
+    df = dataset.to_pandas()
+    csv_path = "test_labels_temp.csv"
+    df.to_csv(csv_path, index=False)
+    dataset = DDRDataset(csv_path=csv_path, transform=val_transform)
     return DataLoader(dataset, batch_size=32, shuffle=False)
 
-# ---- Load Model ----
+# ---- Load Model from Hugging Face (cached) ----
 @st.cache_resource
 def load_model():
     model_path = hf_hub_download(repo_id="Ci-Dave/Densenet121", filename="Pretrained_Densenet-121.pth")
@@ -111,43 +119,55 @@ def load_model():
 
 # ---- UI Buttons ----
 model = load_model()
-test_loader = load_test_data()
+test_loader = load_test_data_from_huggingface()
 
 col1, col2 = st.columns([1, 1])
 with col1:
-    if st.button("🚀 Start Evaluation"):
+    if st.button("🚀 Start Evaluation"):
         st.session_state.stop_eval = False
         st.session_state.evaluation_done = False
         st.session_state.trigger_eval = True
 with col2:
-    if st.button("🚩 Stop Evaluation"):
+    if st.button("🚩 Stop Evaluation"):
         st.session_state.stop_eval = True
 
 if st.session_state.evaluation_done:
     reevaluate_col, download_col = st.columns([1, 1])
 
-# ---- Model Evaluation Explanation ----
-with st.expander("ℹ️ **What is Model Evaluation?**", expanded=True):
+# ---- Description for Model Evaluation ----
+with st.expander("ℹ️ **What is Model Evaluation?**", expanded=True):
     st.markdown("""
     <div style='font-size:16px;'>
-    The <strong>Model Evaluation</strong> section tests how well the trained AI model performs on the test set of retinal images.
+    The **Model Evaluation** section tests how well the trained AI model performs on the unseen <strong>test set</strong> of retinal images. This provides insights into the reliability and performance of the model when deployed in real scenarios.
 
-    #### What It Does:
-    - Loads the test dataset
+    #### 🔍 What It Does:
+    - Loads the test dataset of labeled retinal images
     - Runs the model to predict labels
     - Compares predictions vs. true labels
    - Computes:
-        - Classification Report
-        - Confusion Matrix
-        - ROC Curve
-        - Misclassified Samples
-    - Saves a downloadable PDF report
+        - 📋 **Classification Report** (Precision, Recall, F1-Score)
+        - 🧊 **Confusion Matrix**
+        - 📈 **Multi-class ROC Curve**
+        - ❌ **Misclassified Image Samples**
+    - Saves the full report as a downloadable PDF
+
+    #### 🧭 How to Use:
+    1. Click **🚀 Start Evaluation** to begin analyzing the model’s performance.
+    2. Wait for the evaluation to finish (shows progress bar and batch updates).
+    3. Once done:
+        - Check performance scores for each DR class
+        - View visual summaries like confusion matrix and ROC curve
+        - See the top 5 misclassified examples
+    4. Optionally, download the full evaluation report via **📄 Download PDF**
+
+    ⚠️ <i>Note: This evaluation runs on the full test set and might take several seconds depending on hardware.</i>
     </div>
     """, unsafe_allow_html=True)
 
+
 # ---- Evaluation Logic ----
 if st.session_state.trigger_eval:
-    st.markdown("### ⏱️ Evaluation Results")
+    st.markdown("### ⏱️ Evaluation Results")
 
     start_time = time.time()
     y_true = []
@@ -163,14 +183,14 @@ if st.session_state.trigger_eval:
     with torch.no_grad():
         for i, (images, labels) in enumerate(test_loader):
             if st.session_state.stop_eval:
-                stop_info.warning("🚩 Evaluation stopped by user.")
+                stop_info.warning("🚩 Evaluation stopped by user.")
                 break
 
             outputs = model(images)
             _, predicted = torch.max(outputs, 1)
             y_true.extend(labels.numpy())
             y_pred.extend(predicted.numpy())
-            y_score.extend(outputs.numpy())
+            y_score.extend(outputs.detach().numpy())
 
             for j in range(len(labels)):
                 if predicted[j] != labels[j]:
@@ -178,14 +198,15 @@ if st.session_state.trigger_eval:
 
             percent_complete = (i + 1) / total_batches
             progress_bar.progress(min(percent_complete, 1.0))
-            status_text.text(f"Evaluating: {int(percent_complete * 100)}% | Batch {i+1}/{total_batches}")
+            status_text.text(f"Evaluating on Test Set: {int(percent_complete * 100)}% | Batch {i+1}/{total_batches}")
+            time.sleep(0.1)
 
     end_time = time.time()
     eval_time = end_time - start_time
 
     if not st.session_state.stop_eval:
         st.session_state.evaluation_done = True
-        st.session_state.trigger_eval = False
+        st.session_state.trigger_eval = False  # ✅ Reset the trigger
         st.success(f"✅ Evaluation completed in **{eval_time:.2f} seconds**")
 
         report = classification_report(y_true, y_pred, target_names=class_names, output_dict=True)
@@ -195,20 +216,21 @@ if st.session_state.trigger_eval:
         pdf = FPDF()
         pdf.add_page()
         pdf.set_font("Arial", size=12)
-        pdf.cell(200, 10, txt="Classification Report", ln=True, align='C')
+        pdf.cell(200, 10, txt=clean_text("Classification Report"), ln=True, align='C')
 
+        col_widths = [40, 40, 40, 40]
         headers = ["Class", "Precision", "Recall", "F1-Score"]
-        for header in headers:
-            pdf.cell(40, 10, header, border=1)
+        for i, header in enumerate(headers):
+            pdf.cell(col_widths[i], 10, header, border=1)
         pdf.ln()
 
         for idx, row in report_df.iterrows():
             if idx in ['accuracy', 'macro avg', 'weighted avg']:
                 continue
-            pdf.cell(40, 10, str(idx), border=1)
-            pdf.cell(40, 10, f"{row['precision']:.2f}", border=1)
-            pdf.cell(40, 10, f"{row['recall']:.2f}", border=1)
-            pdf.cell(40, 10, f"{row['f1-score']:.2f}", border=1)
+            pdf.cell(col_widths[0], 10, str(idx), border=1)
+            pdf.cell(col_widths[1], 10, f"{row['precision']:.2f}", border=1)
+            pdf.cell(col_widths[2], 10, f"{row['recall']:.2f}", border=1)
+            pdf.cell(col_widths[3], 10, f"{row['f1-score']:.2f}", border=1)
         pdf.ln()
 
         cm = confusion_matrix(y_true, y_pred)
@@ -257,4 +279,4 @@ if st.session_state.trigger_eval:
     with open(output_pdf, "rb") as f:
         reevaluate_col, download_col = st.columns([1, 1])
         with download_col:
-            st.download_button("📄 Download Full Evaluation PDF", f, file_name="evaluation_report.pdf")
+            st.download_button("📄 Download Full Evaluation PDF", f, file_name="evaluation_report.pdf")