Spaces:
Running
Running
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from typing import List | |
class bcolors:
    """ANSI escape sequences for colouring and styling terminal output.

    Wrap text as ``f"{bcolors.RED}text{bcolors.ENDC}"``; ``ENDC`` is the SGR
    reset sequence that returns the terminal to its default attributes.
    """

    # Bright foreground colours (SGR 90-97 range).
    PURPLE: str = "\033[95m"
    BLUE: str = "\033[94m"
    GREEN: str = "\033[92m"
    WARNING: str = "\033[93m"  # bright yellow
    RED: str = "\033[91m"
    # Reset and text attributes.
    ENDC: str = "\033[0m"
    BOLD: str = "\033[1m"
    UNDERLINE: str = "\033[4m"
def plot_ranks(r1: List, r2: List, r1_label: str, r2_label: str, output: str) -> plt.Figure:
    """Draw a bump chart comparing two rankings and save it as ``<output>.png``.

    Every item that appears in either ranking becomes one line. Its rank is its
    1-based position (first occurrence) in ``r1`` on the left and in ``r2`` on
    the right. An item missing from a ranking gets a NaN rank there, so no
    point is drawn on that side.

    e.g.:
        fig = plot_ranks(true_ranking, ranking, "actual", "predicted", "output")

    Args:
        r1: First ranking, best item first; drawn on the left.
        r2: Second ranking, best item first; drawn on the right.
        r1_label: Name of the first ranking, used as its x-axis category.
        r2_label: Name of the second ranking; must differ from ``r1_label``.
        output: Output path without extension; ``.png`` is appended.

    Returns:
        The matplotlib Figure holding the chart. It is left open for the caller.
    """

    def _positions(ranking):
        # 1-based position of each item's first occurrence, matching what
        # list.index(item) + 1 would give, but with O(1) lookups.
        positions = {}
        for idx, item in enumerate(ranking):
            positions.setdefault(item, idx + 1)
        return positions

    items = set(r1) | set(r2)
    labelled_ranks = ((r1_label, _positions(r1)), (r2_label, _positions(r2)))
    records = [
        {"item": item, "version": label, "rank": ranks.get(item, np.nan)}
        for item in items
        for label, ranks in labelled_ranks
    ]
    df = (
        pd.DataFrame(records)
        .pivot(index="item", columns="version", values="rank")
        .T
        # pivot sorts the version labels alphabetically; restore the caller's
        # order so r1 is always the left-hand side of the chart.
        .loc[[r1_label, r2_label]]
    )

    fig = plt.figure(figsize=(5, 10))
    bumpchart(
        df,
        show_rank_axis=False,
        scatter=True,
        ax=fig.gca(),
        holes=False,
        line_args={"linewidth": 5, "alpha": 0.5},
        scatter_args={"s": 100, "alpha": 0.8},
    )
    # Save this specific figure rather than whatever pyplot considers current.
    fig.savefig(f"{output}.png", dpi=150, bbox_inches="tight")
    return fig
def bumpchart(
    df,
    show_rank_axis=True,
    rank_axis_distance=1.1,
    ax=None,
    scatter=False,
    holes=False,
    line_args=None,
    scatter_args=None,
    hole_args=None,
):
    """Draw a bump chart: one line per column of ``df`` tracking its rank.

    Args:
        df: Each row is one x position (``df.index`` values are the x
            coordinates). Each column is one line whose values are its rank at
            each x, with rank 1 drawn at the top. The tick labels assume ranks
            are 1..len(df.columns); NaN means "not ranked here".
        show_rank_axis: Add a third y-axis on the far right that shows the
            numeric ranks 1..len(df.columns).
        rank_axis_distance: Horizontal position of the rank axis spine, in
            axes coordinates (1.0 is the right edge of the plot).
        ax: Axes to draw on; defaults to ``plt.gca()``.
        scatter: Also draw a marker at every point.
        holes: Draw markers in the axes background colour on top of the
            points for a see-through "hole" look (size them via ``hole_args``).
        line_args: Extra keyword arguments for ``Axes.plot``; may override
            the default ``solid_capstyle="round"``.
        scatter_args: Extra keyword arguments for the marker ``Axes.scatter``.
        hole_args: Extra keyword arguments for the hole ``Axes.scatter``; may
            override the default background ``color``.

    Returns:
        ``[left_axis, right_axis]``, plus the far-right rank axis when
        ``show_rank_axis`` is true.
    """
    # None defaults avoid sharing one mutable dict across calls.
    line_args = {} if line_args is None else line_args
    scatter_args = {} if scatter_args is None else scatter_args
    hole_args = {} if hole_args is None else hole_args

    left_yaxis = plt.gca() if ax is None else ax
    # Twin axes share the x-axis with the left one and carry extra y labels.
    right_yaxis = left_yaxis.twinx()
    axes = [left_yaxis, right_yaxis]
    if show_rank_axis:
        far_right_yaxis = left_yaxis.twinx()
        axes.append(far_right_yaxis)

    # Defaults go first so callers can override them; previously, passing the
    # same key in line_args/hole_args raised a duplicate-keyword TypeError.
    line_kwargs = {"solid_capstyle": "round", **line_args}
    hole_kwargs = {"color": left_yaxis.get_facecolor(), **hole_args}

    x = df.index.values
    for col in df.columns:
        y = df[col]
        # Invisible copies on the twin axes so they line up with the left axis.
        for axis in axes[1:]:
            axis.plot(x, y, alpha=0)
        left_yaxis.plot(x, y, **line_kwargs)
        if scatter:
            left_yaxis.scatter(x, y, **scatter_args)
        if holes:
            left_yaxis.scatter(x, y, **hole_kwargs)

    lines = len(df.columns)
    y_ticks = list(range(1, lines + 1))
    for axis in axes:
        axis.set_yticks(y_ticks)
        # Explicit (bottom, top) limits put rank 1 at the top. This already
        # inverts the axis, so no separate invert_yaxis() call is needed.
        axis.set_ylim((lines + 0.5, 0.5))

    def _labels_by_rank(row):
        # Column names in rank order, one per tick. NaN ranks sort last; they
        # get a blank label instead of naming an item with no point there.
        ordered = row.sort_values()
        return [
            name if ranked else ""
            for name, ranked in zip(ordered.index, ordered.notna())
        ]

    left_yaxis.set_yticklabels(_labels_by_rank(df.iloc[0]))
    right_yaxis.set_yticklabels(_labels_by_rank(df.iloc[-1]))

    if show_rank_axis:
        # Move the rank axis outwards so it doesn't overlap the right axis.
        far_right_yaxis.spines["right"].set_position(("axes", rank_axis_distance))
    return axes