yermandy's picture
init
c29babb
import numpy as np
import pandas as pd
import wandb
from .decorators import TryExcept
@TryExcept()
def create_custom_wandb_metric(
xs: list,
ys: list,
classes: list,
title: str = "Precision Recall Curve",
x_axis_title: str = "Recall",
y_axis_title: str = "Precision",
):
"""Creates a custom wandb metric similar to default wandb.plot.pr_curve
Args:
xs: list of N values to plot on the x-axis
ys: list of N values to plot on the y-axis
classes: class labels for each point (list of N values)
title: plot title
Returns:
wandb object to log
"""
df = pd.DataFrame(
{
"class": classes,
"y": ys,
"x": xs,
}
).round(3)
return wandb.plot_table(
"wandb/area-under-curve/v0",
wandb.Table(dataframe=df),
{"x": "x", "y": "y", "class": "class"},
{
"title": title,
"x-axis-title": x_axis_title,
"y-axis-title": y_axis_title,
},
)
@TryExcept()
def plot_curve_wandb(
xs: np.ndarray,
ys: np.ndarray,
names: list = [],
id: str = "precision-recall",
title: str = "Precision Recall Curve",
x_axis_title: str = "Recall",
y_axis_title: str = "Precision",
num_xs: int = 100,
only_mean: bool = True,
):
"""adds a metric curve to wandb
Args:
xs: np.array of N values
ys: np.array of C by N values where C is the number of classes
names: dict of class names
id: log id in wandb
title: plot title in wandb
num_xs: number of points to interpolate to
only_mean: if True, only the mean curve is plotted
"""
# create new xs
xs_new = np.linspace(xs[0], xs[-1], num_xs)
# create arrays for logging
xs_log = xs_new.tolist()
ys_log = np.interp(xs_new, xs, np.mean(ys, axis=0)).tolist()
classes = ["mean"] * len(xs_log)
if not only_mean and len(names) == len(ys):
for i, y in enumerate(ys):
# add new xs
xs_log.extend(xs_new)
# interpolate y to new xs
ys_log.extend(np.interp(xs_new, xs, y))
# add class names
classes.extend([names[i]] * len(xs_new))
wandb.log(
{
id: create_custom_wandb_metric(
xs_log,
ys_log,
classes,
title,
x_axis_title,
y_axis_title,
)
},
commit=False,
)