Spaces:
Configuration error
Configuration error
| import numpy as np | |
| import pandas as pd | |
| import wandb | |
| from .decorators import TryExcept | |
| def create_custom_wandb_metric( | |
| xs: list, | |
| ys: list, | |
| classes: list, | |
| title: str = "Precision Recall Curve", | |
| x_axis_title: str = "Recall", | |
| y_axis_title: str = "Precision", | |
| ): | |
| """Creates a custom wandb metric similar to default wandb.plot.pr_curve | |
| Args: | |
| xs: list of N values to plot on the x-axis | |
| ys: list of N values to plot on the y-axis | |
| classes: class labels for each point (list of N values) | |
| title: plot title | |
| Returns: | |
| wandb object to log | |
| """ | |
| df = pd.DataFrame( | |
| { | |
| "class": classes, | |
| "y": ys, | |
| "x": xs, | |
| } | |
| ).round(3) | |
| return wandb.plot_table( | |
| "wandb/area-under-curve/v0", | |
| wandb.Table(dataframe=df), | |
| {"x": "x", "y": "y", "class": "class"}, | |
| { | |
| "title": title, | |
| "x-axis-title": x_axis_title, | |
| "y-axis-title": y_axis_title, | |
| }, | |
| ) | |
| def plot_curve_wandb( | |
| xs: np.ndarray, | |
| ys: np.ndarray, | |
| names: list = [], | |
| id: str = "precision-recall", | |
| title: str = "Precision Recall Curve", | |
| x_axis_title: str = "Recall", | |
| y_axis_title: str = "Precision", | |
| num_xs: int = 100, | |
| only_mean: bool = True, | |
| ): | |
| """adds a metric curve to wandb | |
| Args: | |
| xs: np.array of N values | |
| ys: np.array of C by N values where C is the number of classes | |
| names: dict of class names | |
| id: log id in wandb | |
| title: plot title in wandb | |
| num_xs: number of points to interpolate to | |
| only_mean: if True, only the mean curve is plotted | |
| """ | |
| # create new xs | |
| xs_new = np.linspace(xs[0], xs[-1], num_xs) | |
| # create arrays for logging | |
| xs_log = xs_new.tolist() | |
| ys_log = np.interp(xs_new, xs, np.mean(ys, axis=0)).tolist() | |
| classes = ["mean"] * len(xs_log) | |
| if not only_mean and len(names) == len(ys): | |
| for i, y in enumerate(ys): | |
| # add new xs | |
| xs_log.extend(xs_new) | |
| # interpolate y to new xs | |
| ys_log.extend(np.interp(xs_new, xs, y)) | |
| # add class names | |
| classes.extend([names[i]] * len(xs_new)) | |
| wandb.log( | |
| { | |
| id: create_custom_wandb_metric( | |
| xs_log, | |
| ys_log, | |
| classes, | |
| title, | |
| x_axis_title, | |
| y_axis_title, | |
| ) | |
| }, | |
| commit=False, | |
| ) | |