| from sklearn.metrics import roc_curve, precision_recall_curve | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
def eval_binary_classification(pred: np.ndarray, true: np.ndarray) -> None:
    """Plot precision-recall and ROC curves side by side, then show them.

    Opens a new 12x6-inch figure and delegates to ``eval_roc_curve`` (right
    panel) and ``eval_pr_curve`` (left panel). It then tightens the layout
    and calls ``plt.show()``.

    Args:
        pred: Predicted scores for the positive class, one per sample.
            These can be probabilities or any confidence score accepted by
            scikit-learn's ``roc_curve`` / ``precision_recall_curve``.
        true: Ground-truth binary labels, aligned element-wise with ``pred``.
    """
    plt.figure(figsize=(12, 6))
    # Each helper selects its own subplot slot: ROC goes to (1, 2, 2) on the
    # right and PR to (1, 2, 1) on the left. The call order therefore does
    # not affect where each panel ends up.
    eval_roc_curve(pred, true)
    eval_pr_curve(pred, true)
    plt.tight_layout()
    plt.show()
def eval_pr_curve(pred: np.ndarray, true: np.ndarray) -> None:
    """Draw the precision-recall curve in the left panel of a 1x2 grid.

    Plots onto ``plt.subplot(1, 2, 1)`` of the current figure. It does not
    create, lay out, or show a figure itself.

    Args:
        pred: Predicted scores for the positive class, one per sample.
        true: Ground-truth binary labels, aligned element-wise with ``pred``.
    """
    # scikit-learn takes (y_true, scores), the reverse of this signature.
    precision, recall, _ = precision_recall_curve(true, pred)
    plt.subplot(1, 2, 1)
    plt.plot(recall, precision, label="Precision-Recall Curve", color="red")
    # Pin only the lower bound so precision is shown from 0. The upper
    # limit is left to autoscaling.
    plt.ylim(0)
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title("Precision-Recall Curve")
    plt.legend(loc="lower right")
def eval_roc_curve(pred: np.ndarray, true: np.ndarray) -> None:
    """Draw the ROC curve against the chance diagonal in the right panel.

    Plots onto ``plt.subplot(1, 2, 2)`` of the current figure. It does not
    create, lay out, or show a figure itself.

    Args:
        pred: Predicted scores for the positive class, one per sample.
        true: Ground-truth binary labels, aligned element-wise with ``pred``.
    """
    # scikit-learn takes (y_true, scores), the reverse of this signature.
    false_pos_rate, true_pos_rate, _ = roc_curve(true, pred)
    plt.subplot(1, 2, 2)
    plt.plot(false_pos_rate, true_pos_rate, label="ROC Curve")
    # The y = x diagonal is the reference ROC of a classifier that ranks
    # samples at random.
    plt.plot([0, 1], [0, 1], linestyle="--", label="Random Guessing Model")
    plt.title("ROC Curve vs. Random")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.legend(loc="lower right")