Spaces:
Runtime error
Runtime error
put comments
Browse files- pages/Model_Evaluation.py +33 -13
- training/training.ipynb +37 -51
pages/Model_Evaluation.py
CHANGED
|
@@ -161,64 +161,76 @@ with st.expander("ℹ️ **What is Model Evaluation?**", expanded=True):
|
|
| 161 |
|
| 162 |
|
| 163 |
# ---- Evaluation Logic ----
|
|
|
|
| 164 |
if st.session_state.trigger_eval:
|
| 165 |
st.markdown("### ⏱️ Evaluation Results")
|
| 166 |
|
|
|
|
| 167 |
start_time = time.time()
|
| 168 |
-
y_true = []
|
| 169 |
-
y_pred = []
|
| 170 |
-
y_score = []
|
| 171 |
-
misclassified_images = []
|
| 172 |
|
| 173 |
-
total_batches = len(test_loader)
|
| 174 |
-
progress_bar = st.progress(0)
|
| 175 |
-
status_text = st.empty()
|
| 176 |
-
stop_info = st.empty()
|
| 177 |
|
|
|
|
| 178 |
with torch.no_grad():
|
| 179 |
for i, (images, labels) in enumerate(test_loader):
|
|
|
|
| 180 |
if st.session_state.stop_eval:
|
| 181 |
stop_info.warning("🚩 Evaluation stopped by user.")
|
| 182 |
break
|
| 183 |
|
|
|
|
| 184 |
outputs = model(images)
|
| 185 |
-
_, predicted = torch.max(outputs, 1)
|
| 186 |
y_true.extend(labels.numpy())
|
| 187 |
y_pred.extend(predicted.numpy())
|
| 188 |
y_score.extend(outputs.detach().numpy())
|
| 189 |
|
|
|
|
| 190 |
for j in range(len(labels)):
|
| 191 |
if predicted[j] != labels[j]:
|
| 192 |
misclassified_images.append((images[j], predicted[j].item(), labels[j].item()))
|
| 193 |
|
|
|
|
| 194 |
percent_complete = (i + 1) / total_batches
|
| 195 |
progress_bar.progress(min(percent_complete, 1.0))
|
| 196 |
status_text.text(f"Evaluating on Test Set: {int(percent_complete * 100)}% | Batch {i+1}/{total_batches}")
|
| 197 |
-
time.sleep(0.1)
|
| 198 |
|
| 199 |
end_time = time.time()
|
| 200 |
-
eval_time = end_time - start_time
|
| 201 |
|
|
|
|
| 202 |
if not st.session_state.stop_eval:
|
| 203 |
st.session_state.evaluation_done = True
|
| 204 |
st.session_state.trigger_eval = False # ✅ Reset the trigger
|
| 205 |
st.success(f"✅ Evaluation completed in **{eval_time:.2f} seconds**")
|
| 206 |
|
|
|
|
| 207 |
report = classification_report(y_true, y_pred, target_names=class_names, output_dict=True)
|
| 208 |
report_df = pd.DataFrame(report).transpose()
|
| 209 |
st.dataframe(report_df.style.format("{:.2f}"))
|
| 210 |
|
|
|
|
| 211 |
pdf = FPDF()
|
| 212 |
pdf.add_page()
|
| 213 |
pdf.set_font("Arial", size=12)
|
| 214 |
pdf.cell(200, 10, txt=clean_text("Classification Report"), ln=True, align='C')
|
| 215 |
|
|
|
|
| 216 |
col_widths = [40, 40, 40, 40]
|
| 217 |
headers = ["Class", "Precision", "Recall", "F1-Score"]
|
| 218 |
for i, header in enumerate(headers):
|
| 219 |
pdf.cell(col_widths[i], 10, header, border=1)
|
| 220 |
pdf.ln()
|
| 221 |
|
|
|
|
| 222 |
for idx, row in report_df.iterrows():
|
| 223 |
if idx in ['accuracy', 'macro avg', 'weighted avg']:
|
| 224 |
continue
|
|
@@ -228,6 +240,7 @@ if st.session_state.trigger_eval:
|
|
| 228 |
pdf.cell(col_widths[3], 10, f"{row['f1-score']:.2f}", border=1)
|
| 229 |
pdf.ln()
|
| 230 |
|
|
|
|
| 231 |
cm = confusion_matrix(y_true, y_pred)
|
| 232 |
fig_cm, ax = plt.subplots()
|
| 233 |
sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names, cmap="Blues", ax=ax)
|
|
@@ -235,12 +248,15 @@ if st.session_state.trigger_eval:
|
|
| 235 |
ax.set_ylabel('True')
|
| 236 |
ax.set_title("Confusion Matrix")
|
| 237 |
st.pyplot(fig_cm)
|
|
|
|
|
|
|
| 238 |
cm_path = "confusion_matrix.png"
|
| 239 |
fig_cm.savefig(cm_path, format='png', dpi=300, bbox_inches='tight')
|
| 240 |
plt.close(fig_cm)
|
| 241 |
if os.path.exists(cm_path):
|
| 242 |
pdf.image(cm_path, x=10, y=None, w=180)
|
| 243 |
|
|
|
|
| 244 |
y_true_bin = label_binarize(y_true, classes=list(range(len(class_names))))
|
| 245 |
y_score_np = np.array(y_score)
|
| 246 |
fig_roc, ax = plt.subplots()
|
|
@@ -249,26 +265,30 @@ if st.session_state.trigger_eval:
|
|
| 249 |
roc_auc = auc(fpr, tpr)
|
| 250 |
ax.plot(fpr, tpr, label=f'{class_names[i]} (AUC = {roc_auc:.2f})')
|
| 251 |
|
| 252 |
-
ax.plot([0, 1], [0, 1], 'k--')
|
| 253 |
ax.set_xlabel('False Positive Rate')
|
| 254 |
ax.set_ylabel('True Positive Rate')
|
| 255 |
ax.set_title('Multi-class ROC Curve')
|
| 256 |
ax.legend(loc='lower right')
|
| 257 |
st.pyplot(fig_roc)
|
|
|
|
|
|
|
| 258 |
roc_path = "roc_curve.png"
|
| 259 |
fig_roc.savefig(roc_path, format='png', dpi=300, bbox_inches='tight')
|
| 260 |
plt.close(fig_roc)
|
| 261 |
if os.path.exists(roc_path):
|
| 262 |
pdf.image(roc_path, x=10, y=None, w=180)
|
| 263 |
|
|
|
|
| 264 |
st.markdown("### ❌ Misclassified Samples")
|
| 265 |
fig_mis, axs = plt.subplots(1, min(5, len(misclassified_images)), figsize=(15, 4))
|
| 266 |
for idx, (img, pred, true) in enumerate(misclassified_images[:5]):
|
| 267 |
-
axs[idx].imshow(img.permute(1, 2, 0))
|
| 268 |
axs[idx].set_title(f"True: {class_names[true]}\nPred: {class_names[pred]}")
|
| 269 |
axs[idx].axis('off')
|
| 270 |
st.pyplot(fig_mis)
|
| 271 |
|
|
|
|
| 272 |
output_pdf = "evaluation_report.pdf"
|
| 273 |
pdf.output(output_pdf)
|
| 274 |
with open(output_pdf, "rb") as f:
|
|
|
|
| 161 |
|
| 162 |
|
| 163 |
# ---- Evaluation Logic ----
|
| 164 |
+
# Check if evaluation should be triggered
|
| 165 |
if st.session_state.trigger_eval:
|
| 166 |
st.markdown("### ⏱️ Evaluation Results")
|
| 167 |
|
| 168 |
+
# Start timing the evaluation
|
| 169 |
start_time = time.time()
|
| 170 |
+
y_true = [] # Ground truth labels
|
| 171 |
+
y_pred = [] # Predicted labels
|
| 172 |
+
y_score = [] # Raw model outputs
|
| 173 |
+
misclassified_images = [] # List to store misclassified samples
|
| 174 |
|
| 175 |
+
total_batches = len(test_loader) # Total number of batches
|
| 176 |
+
progress_bar = st.progress(0) # Initialize progress bar
|
| 177 |
+
status_text = st.empty() # Placeholder for status updates
|
| 178 |
+
stop_info = st.empty() # Placeholder for stop message
|
| 179 |
|
| 180 |
+
# Disable gradient calculation for faster evaluation
|
| 181 |
with torch.no_grad():
|
| 182 |
for i, (images, labels) in enumerate(test_loader):
|
| 183 |
+
# Allow user to stop the evaluation
|
| 184 |
if st.session_state.stop_eval:
|
| 185 |
stop_info.warning("🚩 Evaluation stopped by user.")
|
| 186 |
break
|
| 187 |
|
| 188 |
+
# Run model on input images
|
| 189 |
outputs = model(images)
|
| 190 |
+
_, predicted = torch.max(outputs, 1) # Get predicted class
|
| 191 |
y_true.extend(labels.numpy())
|
| 192 |
y_pred.extend(predicted.numpy())
|
| 193 |
y_score.extend(outputs.detach().numpy())
|
| 194 |
|
| 195 |
+
# Store misclassified samples
|
| 196 |
for j in range(len(labels)):
|
| 197 |
if predicted[j] != labels[j]:
|
| 198 |
misclassified_images.append((images[j], predicted[j].item(), labels[j].item()))
|
| 199 |
|
| 200 |
+
# Update progress bar and status text
|
| 201 |
percent_complete = (i + 1) / total_batches
|
| 202 |
progress_bar.progress(min(percent_complete, 1.0))
|
| 203 |
status_text.text(f"Evaluating on Test Set: {int(percent_complete * 100)}% | Batch {i+1}/{total_batches}")
|
| 204 |
+
time.sleep(0.1) # Add delay for UI responsiveness
|
| 205 |
|
| 206 |
end_time = time.time()
|
| 207 |
+
eval_time = end_time - start_time # Total evaluation time
|
| 208 |
|
| 209 |
+
# Finalize evaluation if not stopped
|
| 210 |
if not st.session_state.stop_eval:
|
| 211 |
st.session_state.evaluation_done = True
|
| 212 |
st.session_state.trigger_eval = False # ✅ Reset the trigger
|
| 213 |
st.success(f"✅ Evaluation completed in **{eval_time:.2f} seconds**")
|
| 214 |
|
| 215 |
+
# Generate classification report and display as a DataFrame
|
| 216 |
report = classification_report(y_true, y_pred, target_names=class_names, output_dict=True)
|
| 217 |
report_df = pd.DataFrame(report).transpose()
|
| 218 |
st.dataframe(report_df.style.format("{:.2f}"))
|
| 219 |
|
| 220 |
+
# Initialize PDF report
|
| 221 |
pdf = FPDF()
|
| 222 |
pdf.add_page()
|
| 223 |
pdf.set_font("Arial", size=12)
|
| 224 |
pdf.cell(200, 10, txt=clean_text("Classification Report"), ln=True, align='C')
|
| 225 |
|
| 226 |
+
# Add table headers
|
| 227 |
col_widths = [40, 40, 40, 40]
|
| 228 |
headers = ["Class", "Precision", "Recall", "F1-Score"]
|
| 229 |
for i, header in enumerate(headers):
|
| 230 |
pdf.cell(col_widths[i], 10, header, border=1)
|
| 231 |
pdf.ln()
|
| 232 |
|
| 233 |
+
# Add metrics for each class
|
| 234 |
for idx, row in report_df.iterrows():
|
| 235 |
if idx in ['accuracy', 'macro avg', 'weighted avg']:
|
| 236 |
continue
|
|
|
|
| 240 |
pdf.cell(col_widths[3], 10, f"{row['f1-score']:.2f}", border=1)
|
| 241 |
pdf.ln()
|
| 242 |
|
| 243 |
+
# Create and display confusion matrix
|
| 244 |
cm = confusion_matrix(y_true, y_pred)
|
| 245 |
fig_cm, ax = plt.subplots()
|
| 246 |
sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names, cmap="Blues", ax=ax)
|
|
|
|
| 248 |
ax.set_ylabel('True')
|
| 249 |
ax.set_title("Confusion Matrix")
|
| 250 |
st.pyplot(fig_cm)
|
| 251 |
+
|
| 252 |
+
# Save confusion matrix to PDF
|
| 253 |
cm_path = "confusion_matrix.png"
|
| 254 |
fig_cm.savefig(cm_path, format='png', dpi=300, bbox_inches='tight')
|
| 255 |
plt.close(fig_cm)
|
| 256 |
if os.path.exists(cm_path):
|
| 257 |
pdf.image(cm_path, x=10, y=None, w=180)
|
| 258 |
|
| 259 |
+
# Create and display ROC curve for each class
|
| 260 |
y_true_bin = label_binarize(y_true, classes=list(range(len(class_names))))
|
| 261 |
y_score_np = np.array(y_score)
|
| 262 |
fig_roc, ax = plt.subplots()
|
|
|
|
| 265 |
roc_auc = auc(fpr, tpr)
|
| 266 |
ax.plot(fpr, tpr, label=f'{class_names[i]} (AUC = {roc_auc:.2f})')
|
| 267 |
|
| 268 |
+
ax.plot([0, 1], [0, 1], 'k--') # Diagonal reference line
|
| 269 |
ax.set_xlabel('False Positive Rate')
|
| 270 |
ax.set_ylabel('True Positive Rate')
|
| 271 |
ax.set_title('Multi-class ROC Curve')
|
| 272 |
ax.legend(loc='lower right')
|
| 273 |
st.pyplot(fig_roc)
|
| 274 |
+
|
| 275 |
+
# Save ROC curve to PDF
|
| 276 |
roc_path = "roc_curve.png"
|
| 277 |
fig_roc.savefig(roc_path, format='png', dpi=300, bbox_inches='tight')
|
| 278 |
plt.close(fig_roc)
|
| 279 |
if os.path.exists(roc_path):
|
| 280 |
pdf.image(roc_path, x=10, y=None, w=180)
|
| 281 |
|
| 282 |
+
# Show misclassified samples (up to 5)
|
| 283 |
st.markdown("### ❌ Misclassified Samples")
|
| 284 |
fig_mis, axs = plt.subplots(1, min(5, len(misclassified_images)), figsize=(15, 4))
|
| 285 |
for idx, (img, pred, true) in enumerate(misclassified_images[:5]):
|
| 286 |
+
axs[idx].imshow(img.permute(1, 2, 0)) # Convert tensor to image format
|
| 287 |
axs[idx].set_title(f"True: {class_names[true]}\nPred: {class_names[pred]}")
|
| 288 |
axs[idx].axis('off')
|
| 289 |
st.pyplot(fig_mis)
|
| 290 |
|
| 291 |
+
# Save PDF and provide download button
|
| 292 |
output_pdf = "evaluation_report.pdf"
|
| 293 |
pdf.output(output_pdf)
|
| 294 |
with open(output_pdf, "rb") as f:
|
training/training.ipynb
CHANGED
|
@@ -1017,7 +1017,7 @@
|
|
| 1017 |
},
|
| 1018 |
{
|
| 1019 |
"cell_type": "code",
|
| 1020 |
-
"execution_count":
|
| 1021 |
"id": "560a2a1b",
|
| 1022 |
"metadata": {},
|
| 1023 |
"outputs": [
|
|
@@ -1035,38 +1035,70 @@
|
|
| 1035 |
"source": [
|
| 1036 |
"from sklearn.metrics import roc_curve, auc\n",
|
| 1037 |
"from sklearn.preprocessing import label_binarize\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1038 |
"\n",
|
| 1039 |
-
"
|
|
|
|
| 1040 |
"y_true_bin = label_binarize(all_labels, classes=[0, 1, 2, 3, 4])\n",
|
|
|
|
|
|
|
| 1041 |
"y_scores = []\n",
|
| 1042 |
"\n",
|
|
|
|
| 1043 |
"model.eval()\n",
|
|
|
|
|
|
|
| 1044 |
"with torch.no_grad():\n",
|
| 1045 |
" for inputs, labels in test_loader:\n",
|
| 1046 |
" inputs = inputs.to(device)\n",
|
|
|
|
|
|
|
| 1047 |
" outputs = model(inputs)\n",
|
|
|
|
|
|
|
| 1048 |
" probs = torch.softmax(outputs, dim=1)\n",
|
|
|
|
|
|
|
| 1049 |
" y_scores.extend(probs.cpu().numpy())\n",
|
| 1050 |
"\n",
|
| 1051 |
-
"#
|
| 1052 |
-
"fpr, tpr, roc_auc = dict(), dict(), dict()\n",
|
| 1053 |
"y_scores = np.array(y_scores)\n",
|
| 1054 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1055 |
"for i in range(n_classes):\n",
|
|
|
|
| 1056 |
" fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], y_scores[:, i])\n",
|
|
|
|
|
|
|
| 1057 |
" roc_auc[i] = auc(fpr[i], tpr[i])\n",
|
| 1058 |
"\n",
|
| 1059 |
-
"# Plot
|
| 1060 |
"plt.figure(figsize=(10, 7))\n",
|
|
|
|
| 1061 |
"for i in range(n_classes):\n",
|
| 1062 |
" plt.plot(fpr[i], tpr[i], label=f'{class_names[i]} (AUC = {roc_auc[i]:.2f})')\n",
|
| 1063 |
"\n",
|
|
|
|
| 1064 |
"plt.plot([0, 1], [0, 1], 'k--')\n",
|
|
|
|
|
|
|
| 1065 |
"plt.title('Multi-class ROC Curve')\n",
|
| 1066 |
"plt.xlabel('False Positive Rate')\n",
|
| 1067 |
"plt.ylabel('True Positive Rate')\n",
|
|
|
|
|
|
|
| 1068 |
"plt.legend(loc='lower right')\n",
|
| 1069 |
"plt.grid(True)\n",
|
|
|
|
|
|
|
| 1070 |
"plt.show()\n"
|
| 1071 |
]
|
| 1072 |
},
|
|
@@ -1474,52 +1506,6 @@
|
|
| 1474 |
"print(f\"Predicted Class: {predicted_class}\")\n",
|
| 1475 |
"print(f\"Confidence: {confidence_percentage:.2f}%\")"
|
| 1476 |
]
|
| 1477 |
-
},
|
| 1478 |
-
{
|
| 1479 |
-
"cell_type": "code",
|
| 1480 |
-
"execution_count": 1,
|
| 1481 |
-
"id": "eb2308ed",
|
| 1482 |
-
"metadata": {},
|
| 1483 |
-
"outputs": [
|
| 1484 |
-
{
|
| 1485 |
-
"ename": "FileNotFoundError",
|
| 1486 |
-
"evalue": "[Errno 2] No such file or directory: 'D:\\\\DR_Classification\\\\dataset\\\\splits\\\\test_labels.csv'",
|
| 1487 |
-
"output_type": "error",
|
| 1488 |
-
"traceback": [
|
| 1489 |
-
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
| 1490 |
-
"\u001b[31mFileNotFoundError\u001b[39m Traceback (most recent call last)",
|
| 1491 |
-
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 12\u001b[39m\n\u001b[32m 9\u001b[39m new_dir = \u001b[33mr\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mD:\u001b[39m\u001b[33m\\\u001b[39m\u001b[33mDR_Classification\u001b[39m\u001b[33m\\\u001b[39m\u001b[33mdataset\u001b[39m\u001b[33m\\\u001b[39m\u001b[33msplitted-data\u001b[39m\u001b[33m\\\u001b[39m\u001b[33mtest\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 11\u001b[39m \u001b[38;5;66;03m# === Load the CSV ===\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m12\u001b[39m df = \u001b[43mpd\u001b[49m\u001b[43m.\u001b[49m\u001b[43mread_csv\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcsv_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 14\u001b[39m \u001b[38;5;66;03m# === Replace old path with new path in 'new_path' column ===\u001b[39;00m\n\u001b[32m 15\u001b[39m df[\u001b[33m'\u001b[39m\u001b[33mnew_path\u001b[39m\u001b[33m'\u001b[39m] = df[\u001b[33m'\u001b[39m\u001b[33mnew_path\u001b[39m\u001b[33m'\u001b[39m].str.replace(old_dir, new_dir, regex=\u001b[38;5;28;01mFalse\u001b[39;00m)\n",
|
| 1492 |
-
"\u001b[36mFile \u001b[39m\u001b[32md:\\DR_Classification\\.venv\\Lib\\site-packages\\pandas\\io\\parsers\\readers.py:1026\u001b[39m, in \u001b[36mread_csv\u001b[39m\u001b[34m(filepath_or_buffer, sep, delimiter, header, names, index_col, usecols, dtype, engine, converters, true_values, false_values, skipinitialspace, skiprows, skipfooter, nrows, na_values, keep_default_na, na_filter, verbose, skip_blank_lines, parse_dates, infer_datetime_format, keep_date_col, date_parser, date_format, dayfirst, cache_dates, iterator, chunksize, compression, thousands, decimal, lineterminator, quotechar, quoting, doublequote, escapechar, comment, encoding, encoding_errors, dialect, on_bad_lines, delim_whitespace, low_memory, memory_map, float_precision, storage_options, dtype_backend)\u001b[39m\n\u001b[32m 1013\u001b[39m kwds_defaults = _refine_defaults_read(\n\u001b[32m 1014\u001b[39m dialect,\n\u001b[32m 1015\u001b[39m delimiter,\n\u001b[32m (...)\u001b[39m\u001b[32m 1022\u001b[39m dtype_backend=dtype_backend,\n\u001b[32m 1023\u001b[39m )\n\u001b[32m 1024\u001b[39m kwds.update(kwds_defaults)\n\u001b[32m-> \u001b[39m\u001b[32m1026\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_read\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfilepath_or_buffer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwds\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 1493 |
-
"\u001b[36mFile \u001b[39m\u001b[32md:\\DR_Classification\\.venv\\Lib\\site-packages\\pandas\\io\\parsers\\readers.py:620\u001b[39m, in \u001b[36m_read\u001b[39m\u001b[34m(filepath_or_buffer, kwds)\u001b[39m\n\u001b[32m 617\u001b[39m _validate_names(kwds.get(\u001b[33m\"\u001b[39m\u001b[33mnames\u001b[39m\u001b[33m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m))\n\u001b[32m 619\u001b[39m \u001b[38;5;66;03m# Create the parser.\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m620\u001b[39m parser = \u001b[43mTextFileReader\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfilepath_or_buffer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwds\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 622\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m chunksize \u001b[38;5;129;01mor\u001b[39;00m iterator:\n\u001b[32m 623\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m parser\n",
|
| 1494 |
-
"\u001b[36mFile \u001b[39m\u001b[32md:\\DR_Classification\\.venv\\Lib\\site-packages\\pandas\\io\\parsers\\readers.py:1620\u001b[39m, in \u001b[36mTextFileReader.__init__\u001b[39m\u001b[34m(self, f, engine, **kwds)\u001b[39m\n\u001b[32m 1617\u001b[39m \u001b[38;5;28mself\u001b[39m.options[\u001b[33m\"\u001b[39m\u001b[33mhas_index_names\u001b[39m\u001b[33m\"\u001b[39m] = kwds[\u001b[33m\"\u001b[39m\u001b[33mhas_index_names\u001b[39m\u001b[33m\"\u001b[39m]\n\u001b[32m 1619\u001b[39m \u001b[38;5;28mself\u001b[39m.handles: IOHandles | \u001b[38;5;28;01mNone\u001b[39;00m = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m-> \u001b[39m\u001b[32m1620\u001b[39m \u001b[38;5;28mself\u001b[39m._engine = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_make_engine\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mengine\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 1495 |
-
"\u001b[36mFile \u001b[39m\u001b[32md:\\DR_Classification\\.venv\\Lib\\site-packages\\pandas\\io\\parsers\\readers.py:1880\u001b[39m, in \u001b[36mTextFileReader._make_engine\u001b[39m\u001b[34m(self, f, engine)\u001b[39m\n\u001b[32m 1878\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[33m\"\u001b[39m\u001b[33mb\u001b[39m\u001b[33m\"\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m mode:\n\u001b[32m 1879\u001b[39m mode += \u001b[33m\"\u001b[39m\u001b[33mb\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m-> \u001b[39m\u001b[32m1880\u001b[39m \u001b[38;5;28mself\u001b[39m.handles = \u001b[43mget_handle\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 1881\u001b[39m \u001b[43m \u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1882\u001b[39m \u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1883\u001b[39m \u001b[43m \u001b[49m\u001b[43mencoding\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43moptions\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mencoding\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1884\u001b[39m \u001b[43m \u001b[49m\u001b[43mcompression\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43moptions\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mcompression\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1885\u001b[39m \u001b[43m \u001b[49m\u001b[43mmemory_map\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43moptions\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mmemory_map\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1886\u001b[39m \u001b[43m \u001b[49m\u001b[43mis_text\u001b[49m\u001b[43m=\u001b[49m\u001b[43mis_text\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1887\u001b[39m \u001b[43m \u001b[49m\u001b[43merrors\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43moptions\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mencoding_errors\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mstrict\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1888\u001b[39m \u001b[43m \u001b[49m\u001b[43mstorage_options\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43moptions\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mstorage_options\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1889\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1890\u001b[39m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m.handles \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1891\u001b[39m f = \u001b[38;5;28mself\u001b[39m.handles.handle\n",
|
| 1496 |
-
"\u001b[36mFile \u001b[39m\u001b[32md:\\DR_Classification\\.venv\\Lib\\site-packages\\pandas\\io\\common.py:873\u001b[39m, in \u001b[36mget_handle\u001b[39m\u001b[34m(path_or_buf, mode, encoding, compression, memory_map, is_text, errors, storage_options)\u001b[39m\n\u001b[32m 868\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(handle, \u001b[38;5;28mstr\u001b[39m):\n\u001b[32m 869\u001b[39m \u001b[38;5;66;03m# Check whether the filename is to be opened in binary mode.\u001b[39;00m\n\u001b[32m 870\u001b[39m \u001b[38;5;66;03m# Binary mode does not support 'encoding' and 'newline'.\u001b[39;00m\n\u001b[32m 871\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m ioargs.encoding \u001b[38;5;129;01mand\u001b[39;00m \u001b[33m\"\u001b[39m\u001b[33mb\u001b[39m\u001b[33m\"\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m ioargs.mode:\n\u001b[32m 872\u001b[39m \u001b[38;5;66;03m# Encoding\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m873\u001b[39m handle = \u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\n\u001b[32m 874\u001b[39m \u001b[43m \u001b[49m\u001b[43mhandle\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 875\u001b[39m \u001b[43m \u001b[49m\u001b[43mioargs\u001b[49m\u001b[43m.\u001b[49m\u001b[43mmode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 876\u001b[39m \u001b[43m \u001b[49m\u001b[43mencoding\u001b[49m\u001b[43m=\u001b[49m\u001b[43mioargs\u001b[49m\u001b[43m.\u001b[49m\u001b[43mencoding\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 877\u001b[39m \u001b[43m \u001b[49m\u001b[43merrors\u001b[49m\u001b[43m=\u001b[49m\u001b[43merrors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 878\u001b[39m \u001b[43m \u001b[49m\u001b[43mnewline\u001b[49m\u001b[43m=\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m 879\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 880\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 881\u001b[39m \u001b[38;5;66;03m# Binary mode\u001b[39;00m\n\u001b[32m 882\u001b[39m handle = \u001b[38;5;28mopen\u001b[39m(handle, ioargs.mode)\n",
|
| 1497 |
-
"\u001b[31mFileNotFoundError\u001b[39m: [Errno 2] No such file or directory: 'D:\\\\DR_Classification\\\\dataset\\\\splits\\\\test_labels.csv'"
|
| 1498 |
-
]
|
| 1499 |
-
}
|
| 1500 |
-
],
|
| 1501 |
-
"source": [
|
| 1502 |
-
"import pandas as pd\n",
|
| 1503 |
-
"\n",
|
| 1504 |
-
"# === File paths ===\n",
|
| 1505 |
-
"csv_path = r\"D:\\DR_Classification\\dataset\\splits\\test_labels.csv\"\n",
|
| 1506 |
-
"output_csv_path = r\"D:\\DR_Classification\\dataset\\Splitted_data\\splits\\test_labels.csv\"\n",
|
| 1507 |
-
"\n",
|
| 1508 |
-
"# === Old and new base directory paths ===\n",
|
| 1509 |
-
"old_dir = r\"D:\\DR_Classification\\splits\\test\"\n",
|
| 1510 |
-
"new_dir = r\"D:\\DR_Classification\\dataset\\splitted-data\\test\"\n",
|
| 1511 |
-
"\n",
|
| 1512 |
-
"# === Load the CSV ===\n",
|
| 1513 |
-
"df = pd.read_csv(csv_path)\n",
|
| 1514 |
-
"\n",
|
| 1515 |
-
"# === Replace old path with new path in 'new_path' column ===\n",
|
| 1516 |
-
"df['new_path'] = df['new_path'].str.replace(old_dir, new_dir, regex=False)\n",
|
| 1517 |
-
"\n",
|
| 1518 |
-
"# === Save the updated CSV ===\n",
|
| 1519 |
-
"df.to_csv(output_csv_path, index=False)\n",
|
| 1520 |
-
"\n",
|
| 1521 |
-
"print(\"✅ CSV updated and saved at:\", output_csv_path)\n"
|
| 1522 |
-
]
|
| 1523 |
}
|
| 1524 |
],
|
| 1525 |
"metadata": {
|
|
|
|
| 1017 |
},
|
| 1018 |
{
|
| 1019 |
"cell_type": "code",
|
| 1020 |
+
"execution_count": null,
|
| 1021 |
"id": "560a2a1b",
|
| 1022 |
"metadata": {},
|
| 1023 |
"outputs": [
|
|
|
|
| 1035 |
"source": [
|
| 1036 |
"from sklearn.metrics import roc_curve, auc\n",
|
| 1037 |
"from sklearn.preprocessing import label_binarize\n",
|
| 1038 |
+
"import numpy as np\n",
|
| 1039 |
+
"import matplotlib.pyplot as plt\n",
|
| 1040 |
+
"import torch\n",
|
| 1041 |
+
"\n",
|
| 1042 |
+
"# Number of classes for Diabetic Retinopathy (DR) classification\n",
|
| 1043 |
+
"n_classes = 5 \n",
|
| 1044 |
"\n",
|
| 1045 |
+
"# Convert class labels to one-hot encoded format (needed for multi-class ROC)\n",
|
| 1046 |
+
"# Example: label 2 becomes [0, 0, 1, 0, 0]\n",
|
| 1047 |
"y_true_bin = label_binarize(all_labels, classes=[0, 1, 2, 3, 4])\n",
|
| 1048 |
+
"\n",
|
| 1049 |
+
"# Will hold predicted probabilities for each class\n",
|
| 1050 |
"y_scores = []\n",
|
| 1051 |
"\n",
|
| 1052 |
+
"# Set model to evaluation mode\n",
|
| 1053 |
"model.eval()\n",
|
| 1054 |
+
"\n",
|
| 1055 |
+
"# Disable gradient calculation for faster inference\n",
|
| 1056 |
"with torch.no_grad():\n",
|
| 1057 |
" for inputs, labels in test_loader:\n",
|
| 1058 |
" inputs = inputs.to(device)\n",
|
| 1059 |
+
" \n",
|
| 1060 |
+
" # Forward pass through the model\n",
|
| 1061 |
" outputs = model(inputs)\n",
|
| 1062 |
+
" \n",
|
| 1063 |
+
" # Apply softmax to get class probabilities\n",
|
| 1064 |
" probs = torch.softmax(outputs, dim=1)\n",
|
| 1065 |
+
" \n",
|
| 1066 |
+
" # Append the probabilities to y_scores list\n",
|
| 1067 |
" y_scores.extend(probs.cpu().numpy())\n",
|
| 1068 |
"\n",
|
| 1069 |
+
"# Convert the list of predictions to a NumPy array\n",
|
|
|
|
| 1070 |
"y_scores = np.array(y_scores)\n",
|
| 1071 |
"\n",
|
| 1072 |
+
"# Initialize dictionaries to store False Positive Rate (FPR), True Positive Rate (TPR), and AUC\n",
|
| 1073 |
+
"fpr, tpr, roc_auc = dict(), dict(), dict()\n",
|
| 1074 |
+
"\n",
|
| 1075 |
+
"# Compute ROC curve and AUC for each class\n",
|
| 1076 |
"for i in range(n_classes):\n",
|
| 1077 |
+
" # Calculate FPR and TPR for class `i`\n",
|
| 1078 |
" fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], y_scores[:, i])\n",
|
| 1079 |
+
" \n",
|
| 1080 |
+
" # Calculate Area Under Curve (AUC) for class `i`\n",
|
| 1081 |
" roc_auc[i] = auc(fpr[i], tpr[i])\n",
|
| 1082 |
"\n",
|
| 1083 |
+
"# Plot ROC curves for all classes\n",
|
| 1084 |
"plt.figure(figsize=(10, 7))\n",
|
| 1085 |
+
"\n",
|
| 1086 |
"for i in range(n_classes):\n",
|
| 1087 |
" plt.plot(fpr[i], tpr[i], label=f'{class_names[i]} (AUC = {roc_auc[i]:.2f})')\n",
|
| 1088 |
"\n",
|
| 1089 |
+
"# Plot the diagonal line (chance level)\n",
|
| 1090 |
"plt.plot([0, 1], [0, 1], 'k--')\n",
|
| 1091 |
+
"\n",
|
| 1092 |
+
"# Set plot title and axis labels\n",
|
| 1093 |
"plt.title('Multi-class ROC Curve')\n",
|
| 1094 |
"plt.xlabel('False Positive Rate')\n",
|
| 1095 |
"plt.ylabel('True Positive Rate')\n",
|
| 1096 |
+
"\n",
|
| 1097 |
+
"# Add legend and grid\n",
|
| 1098 |
"plt.legend(loc='lower right')\n",
|
| 1099 |
"plt.grid(True)\n",
|
| 1100 |
+
"\n",
|
| 1101 |
+
"# Show the plot\n",
|
| 1102 |
"plt.show()\n"
|
| 1103 |
]
|
| 1104 |
},
|
|
|
|
| 1506 |
"print(f\"Predicted Class: {predicted_class}\")\n",
|
| 1507 |
"print(f\"Confidence: {confidence_percentage:.2f}%\")"
|
| 1508 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1509 |
}
|
| 1510 |
],
|
| 1511 |
"metadata": {
|