| """ |
| |
| Notes: |
| |
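The script reads PRS (positive class, label 1) and RRS (negative class, label 0) probability values from the first and second columns of a CSV file whose first row is a header. Cells may be left blank, so the two columns can contain different numbers of values.
It combines the probabilities and labels, computes the ROC curve, calculates the AUC (Area Under the Curve), finds the threshold with the best F1 score, and plots the threshold-colored ROC curve using Matplotlib.
To read PRS and RRS values from multiple files instead, modify the read_probabilities_from_csv function and the command-line arguments parsed in main().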
| Dependencies: |
| |
| Python 3 |
| numpy |
| matplotlib |
| scikit-learn |
| |
| Explanation of the Area Under the Curve (AUC): |
| |
| The Area Under the Curve (AUC) refers to the area under the Receiver Operating Characteristic (ROC) curve. The ROC curve is a graphical representation of a classifier's performance across all classification thresholds. It plots the True Positive Rate (TPR) against the False Positive Rate (FPR) at various threshold settings. |
| |
| True Positive Rate (TPR), also known as Sensitivity or Recall, is the proportion of actual positives that are correctly identified. |
| False Positive Rate (FPR) is the proportion of actual negatives that are incorrectly identified as positives. |
| The AUC provides a single scalar value that summarizes the performance of the classifier: |
| |
| An AUC of 1.0 indicates a perfect classifier. |
| An AUC of 0.5 suggests no discriminative ability (equivalent to random guessing). |
| An AUC between 0.5 and 1.0 indicates the degree to which the classifier can distinguish between the positive and negative classes. |
| Why is AUC important? |
| |
| More on interpreting the ROC Curve: |
| |
| The ROC curve plots the TPR against the FPR at various threshold levels. |
| The closer the curve follows the left-hand border and then the top border of the ROC space, the better the classifier. |
| The diagonal line represents the performance of a classifier that makes random guesses. |
| Understanding AUC Values: |
| |
| AUC = 0.90-1.00: Excellent |
| AUC = 0.80-0.90: Good |
| AUC = 0.70-0.80: Fair |
| AUC = 0.60-0.70: Poor |
| AUC = 0.50-0.60: Fail |
| |
| |
| |
| Threshold-Independent: AUC measures the classifier's ability to rank predictions without being dependent on a specific threshold. |
| Performance Metric: It provides a comprehensive measure of performance across all possible classification thresholds. |
| In summary, the AUC quantifies the overall ability of the model to discriminate between positive and negative classes. A higher AUC indicates better model performance. |
| |
| |
| |
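F1 = 2 * (precision * recall) / (precision + recall)

precision = TP / (TP + FP)

recall = TP / (TP + FN)

Here TP, FP and FN are the numbers of true positives, false positives and false negatives when every score >= the threshold is predicted positive. PRS values are the positive class and RRS values the negative class. The script reports the ROC threshold with the highest F1.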
| |
| |
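You can adjust the decimal precision of the reported values by changing the ".3f" format specifier in the f-strings that print them. For example, changing {best_thresh:.3f} to {best_thresh:.6f} in the legend text and in the final "Best F1 Score" message shows the best-F1 threshold with six decimals.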
| |
| |
| """ |
| |
|
|
| |
|
|
| |
|
|
| |
|
|
| |
| import argparse |
| import numpy as np |
| import matplotlib.pyplot as plt |
| from sklearn.metrics import roc_curve, auc, f1_score |
| import csv |
|
|
def read_probabilities_from_csv(filename):
    """Read PRS and RRS probability values from a CSV file.

    The first row is treated as a header and skipped. In every following row,
    column 1 holds a PRS value and column 2 an RRS value. Either cell may be
    blank, and a row may stop after column 1 (e.g. ``0.7`` with no trailing
    comma), so the two columns can hold different numbers of values. Blank
    lines and any columns beyond the second are ignored.

    Args:
        filename: Path of the CSV file to read.

    Returns:
        A tuple ``(prs_probs, rrs_probs)`` of lists of floats, in file order.
        Both lists are empty for an empty file.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a non-blank cell in column 1 or 2 is not a number; the
            message names the file, line and column.
    """
    prs_probs = []
    rrs_probs = []
    # newline='' lets the csv module handle line endings and quoted fields itself.
    with open(filename, 'r', newline='') as csvfile:
        reader = csv.reader(csvfile)
        # Skip the header row; the default keeps an empty file from raising StopIteration.
        next(reader, None)
        for row in reader:
            # zip() stops at the shorter side, so a one-cell row still yields its
            # PRS value and a blank line (row == []) yields nothing.
            for column, (cell, target) in enumerate(zip(row[:2], (prs_probs, rrs_probs)), start=1):
                value = cell.strip()
                if not value:
                    continue
                try:
                    target.append(float(value))
                except ValueError as err:
                    raise ValueError(
                        f"{filename}, line {reader.line_num}, column {column}: "
                        f"{value!r} is not a number"
                    ) from err
    return prs_probs, rrs_probs
|
|
def _best_f1_index(fpr, tpr, n_pos, n_neg):
    """Return ``(index, f1)`` for the ROC point with the highest F1 score.

    ROC point ``i`` from ``sklearn.metrics.roc_curve`` describes predicting
    "positive" for every sample whose score is ``>= thresholds[i]``. The
    confusion-matrix counts can therefore be recovered from the rates instead
    of re-scoring every sample once per threshold:
    ``TP = tpr * n_pos``, ``FP = fpr * n_neg``, ``FN = n_pos - TP``.
    From these, ``F1 = 2*TP / (2*TP + FP + FN)``, which equals
    ``2 * precision * recall / (precision + recall)``.

    Ties resolve to the lowest index (the highest threshold), the same result
    as a scan that only replaces the best on a strictly greater F1.

    Args:
        fpr: False positive rates, one per ROC point.
        tpr: True positive rates, same length as ``fpr``.
        n_pos: Number of positive samples; must be > 0.
        n_neg: Number of negative samples.
    """
    # Rounding undoes floating-point error from the count/total divisions behind the rates.
    tp = np.rint(np.asarray(tpr, dtype=float) * n_pos)
    fp = np.rint(np.asarray(fpr, dtype=float) * n_neg)
    fn = n_pos - tp
    # Denominator == TP + FP + n_pos >= n_pos > 0, so there is no division by zero.
    # A point with no true positives scores 0.
    f1 = 2.0 * tp / (2.0 * tp + fp + fn)
    best = int(np.argmax(f1))
    return best, float(f1[best])


def main():
    """Command-line entry point: ROC curve, AUC and best-F1 threshold.

    Reads PRS (positive class, label 1) and RRS (negative class, label 0)
    probability values from ``--input_csv``. It then computes the ROC curve
    and its AUC and finds the ROC threshold with the highest F1 score. Finally
    it saves a ROC plot whose segments are coloured by threshold to
    ``--output_file`` (PNG, 300 dpi) and shows the plot.

    Exits with an argparse usage error if either class has no values, because
    the ROC curve and F1 are undefined in that case.
    """
    parser = argparse.ArgumentParser(description='Compute ROC curve, best F1 score, and annotate thresholds.')
    parser.add_argument('--input_csv', required=True, help='CSV file containing PRS and RRS probability values')
    parser.add_argument('--output_file', default='roc_curve.png', help='Output file name for ROC curve plot')
    args = parser.parse_args()

    prs_probs, rrs_probs = read_probabilities_from_csv(args.input_csv)
    if not prs_probs or not rrs_probs:
        parser.error(f'{args.input_csv} must contain at least one PRS value (column 1) and one '
                     f'RRS value (column 2); found {len(prs_probs)} PRS and {len(rrs_probs)} RRS values')

    # PRS values form the positive class (1), RRS values the negative class (0).
    n_pos, n_neg = len(prs_probs), len(rrs_probs)
    probs = np.array(prs_probs + rrs_probs)
    labels = np.array([1] * n_pos + [0] * n_neg)

    fpr, tpr, thresholds = roc_curve(labels, probs)
    # The AUC uses the full curve, before the sentinel point is removed below.
    roc_auc = auc(fpr, tpr)

    # roc_curve prepends a sentinel threshold that no sample reaches (np.inf in
    # recent scikit-learn, max(score) + 1 in older releases). Drop it so the
    # colour scale, annotations and F1 search only use real score values.
    # For probabilities in [0, 1] this keeps exactly the thresholds in [0, 1].
    # Unlike a fixed [0, 1] window, it also works for scores outside that range.
    real = np.isfinite(thresholds) & (thresholds <= probs.max())
    fpr, tpr, thresholds = fpr[real], tpr[real], thresholds[real]

    best_idx, best_f1 = _best_f1_index(fpr, tpr, n_pos, n_neg)
    best_thresh = thresholds[best_idx]
    best_fpr, best_tpr = fpr[best_idx], tpr[best_idx]

    # Prefer Arial, but fall back to matplotlib's other sans-serif fonts
    # (instead of findfont warnings) where Arial is not installed.
    plt.rcParams['font.family'] = 'sans-serif'
    plt.rcParams['font.sans-serif'] = ['Arial'] + [
        name for name in plt.rcParams['font.sans-serif'] if name != 'Arial']

    fig, ax = plt.subplots(figsize=(10, 8))
    norm = plt.Normalize(vmin=thresholds.min(), vmax=thresholds.max())
    cmap = plt.cm.viridis

    # Draw the curve one segment at a time so that each segment takes the
    # colour of the threshold at its starting point.
    roc_handle = None
    for i in range(len(fpr) - 1):
        (segment,) = ax.plot(fpr[i:i + 2], tpr[i:i + 2], color=cmap(norm(thresholds[i])), lw=2.5)
        if roc_handle is None:
            roc_handle = segment

    # Chance-level reference diagonal.
    ax.plot([0, 1], [0, 1], color='gray', lw=2, linestyle='--')

    # Highlight the operating point that maximizes F1.
    best_marker = ax.scatter([best_fpr], [best_tpr], s=150, facecolors='none',
                             edgecolors='red', linewidths=2, zorder=3)

    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax)
    cbar.set_label('Threshold', fontsize=16)
    cbar.ax.tick_params(labelsize=14)

    # Label up to 10 evenly spaced thresholds. np.unique prevents stacking
    # duplicate labels when the curve has fewer than 10 points.
    num_thresholds_to_annotate = 10
    idxs = np.unique(np.linspace(0, len(thresholds) - 1, num_thresholds_to_annotate).astype(int))
    for idx in idxs:
        ax.annotate(f'{thresholds[idx]:.2f}', xy=(fpr[idx], tpr[idx]),
                    textcoords='offset points', xytext=(0, 10),
                    ha='center', fontsize=12, color='blue')

    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate', fontsize=16)
    ax.set_ylabel('True Positive Rate', fontsize=16)
    ax.set_title('Receiver Operating Characteristic (ROC)', fontsize=18)
    ax.tick_params(axis='both', which='major', labelsize=14)
    ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.7)

    legend_text = (f'ROC curve (AUC = {roc_auc:.3f}, '
                   f'Best F1 = {best_f1:.3f}, '
                   f'Best F1 Threshold = {best_thresh:.3f})')
    # Pass handles explicitly so the curve text is bound to the ROC curve.
    # roc_handle stays None only for a single-point curve (e.g. all scores equal).
    legend_entries = []
    if roc_handle is not None:
        legend_entries.append((roc_handle, legend_text))
    legend_entries.append((best_marker, f'Best F1 point (FPR = {best_fpr:.3f}, TPR = {best_tpr:.3f})'))
    handles, texts = zip(*legend_entries)
    ax.legend(handles, texts, loc="lower right", fontsize=12)

    fig.tight_layout()
    fig.savefig(args.output_file, dpi=300, format='png')

    # Report before plt.show(), which blocks on interactive backends until the window is closed.
    print(f"ROC curve saved to {args.output_file}")
    print(f"Best F1 Score: {best_f1:.3f} at threshold {best_thresh:.3f}")
    print(f"AUC: {roc_auc:.3f}")
    plt.show()


if __name__ == '__main__':
    main()
|
|