| """Helper functions for classification tasks.""" |
| import matplotlib.pyplot as plt |
| import pandas as pd |
| import numpy as np |
|
|
|
|
def plot_metric_curve(
    xvalues, yvalues, thresholds, title=None,
    figsize=(8, 7), show_thresholds=True, show_legend=True,
    # NOTE(review): the ylabel='X' / xlabel='Y' defaults look swapped; they are
    # kept unchanged so callers relying on the current defaults are unaffected.
    ylabel='X', xlabel='Y', ax=None, text_delta=0.01,
    label="Metric Curve", color="royalblue", show=False,
    fill=None,
):
    """Plot a metric curve, e.g., a PR curve or a ROC curve.

    The curve is drawn as a line with circle markers. Both axes are limited
    to ``(-0.08, 1.08)``, i.e. the unit square plus a small margin.

    Args:
        xvalues: x coordinates of the curve points.
        yvalues: y coordinates of the curve points, paired with ``xvalues``.
        thresholds: Values written (rounded to 2 decimals) next to each point
            when ``show_thresholds`` is true. They are paired with the points
            via ``zip``, so a shorter sequence only annotates the leading
            points. ``None`` disables the annotations.
        title: Axes title.
        figsize: Figure size; only used when a new figure is created
            (``ax is None``).
        show_thresholds: Whether to annotate the points with ``thresholds``.
        show_legend: Whether to draw the legend.
        ylabel: Y-axis label.
        xlabel: X-axis label.
        ax: Axes to draw on. If ``None``, a new figure with a single Axes is
            created.
        text_delta: Offset, in data units, added to both x and y of each
            threshold annotation.
        label: Legend label of the curve.
        color: Color of the curve, the shaded area and all text annotations.
        show: Whether to call ``plt.show()`` at the end.
        fill: If not ``None``, shade the area between the curve and ``y = 0``
            and write ``fill`` as text at data coordinates ``(0.4, 0.5)``
            (presumably a summary such as an area-under-curve value).

    Returns:
        The Axes the curve was drawn on.
    """
    limits = (-0.08, 1.08)

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ax.grid(alpha=0.3)
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)

    ax.plot(xvalues, yvalues, marker='o', label=label, color=color)
    ax.set_xlim(*limits)
    ax.set_ylim(*limits)

    if fill is not None:
        yticks = ax.get_yticks()
        # Shade between the curve and the baseline y = 0.
        ax.fill_between(xvalues, yvalues, 0, alpha=0.08, color=color)
        ax.text(0.4, 0.5, fill, color=color)
        # Re-apply the captured tick positions with one-decimal labels.
        # set_yticks may widen the view limits so every tick is visible,
        # hence the y-limits are restored right after.
        ax.set_yticks(yticks)
        ax.set_yticklabels([f"{y:.1f}" for y in yticks])
        ax.set_ylim(*limits)

    if show_thresholds and thresholds is not None:
        for x, y, t in zip(xvalues, yvalues, thresholds):
            ax.text(
                x + text_delta, y + text_delta, np.round(t, 2),
                color=color, alpha=0.5,
            )

    if show_legend:
        ax.legend()

    if show:
        plt.show()

    return ax
|
|