Spaces:
Running
Running
| """ | |
| Baseline: based on most-common answer | |
| """ | |
| import pandas as pd | |
| import numpy as np | |
| from tqdm import tqdm | |
| from .metrics import mapk, rank_biased_overlap | |
| from .plots import plot_ranks | |
| import logging | |
| from typing import List, Callable, Optional | |
| from rouge_score import rouge_scorer as rs | |
| from collections import Counter | |
| import random | |
# Module-level logger, named after the module per the standard logging convention.
logger = logging.getLogger(__name__)
# Numeric tolerance constant. NOTE(review): not referenced anywhere in this
# chunk — confirm it is used elsewhere in the package before removing.
tol = 0.001
class MCARank:
    """
    Baseline ranking method: rank models by how often each one produces
    the most common answer (MCA) across a set of benchmark instances.
    """

    def __init__(
        self,
        MODELS: List,
        evaluator: Callable,
        true_ranking: Optional[List] = None,
        show_progress: Optional[bool] = False,
    ):
        """
        MODELS: names of the models to be ranked (must match `df` columns in `fit`)
        evaluator: scoring callable (stored for interface compatibility)
        true_ranking: optional ground-truth ranking used by `measure`/`plot`
        show_progress: whether to display progress output
        """
        self.MODELS = MODELS
        self.N = len(MODELS)
        self.evaluate = evaluator
        self.true_ranking = true_ranking
        self.show_progress = show_progress

    def fit(self, df: pd.DataFrame, measure: Optional[str] = 'equality', p: float = 0):
        """
        Estimate a ranking (best to worst) from benchmark outputs.

        df: DataFrame with one row per benchmark instance and one column per
            model, holding that model's output.
        measure: how agreement with the most common answer is computed:
            'equality', 'noisy_equality', or 'rouge'.
        p: noise level in [0, 1]; only used for 'noisy_equality'.

        Returns the estimated ranking as a list of model names, best first.
        Raises ValueError on a model/column mismatch or an unknown measure.
        """
        # Raise instead of assert: input validation must survive `python -O`.
        if set(self.MODELS) != set(df.columns):
            raise ValueError("Benchmark data models inconsistent with models to be ranked.")
        if measure == 'equality':
            self.ranking = self._rank_by_equality(df)
        elif measure == 'noisy_equality':
            self.ranking = self._rank_by_noisy_equality(df, p)
        elif measure == 'rouge':
            self.ranking = self._rank_by_rouge(df)
        else:
            raise ValueError(f"Measure {measure} not understood.")
        log = logging.getLogger(__name__)
        log.info(f"Estimated ranks (best to worst): {self.ranking}")
        # Only report agreement with the ground truth when one was provided;
        # the previous version called self.measure() unconditionally, which
        # raised and crashed fit() whenever true_ranking was None.
        if self.true_ranking is not None:
            log.info(f"True ranking: {self.true_ranking}")
            log.info(f"RBO measure: {self.measure()}")
        return self.ranking  # Best to worst

    @staticmethod
    def _rank_by_equality(df: pd.DataFrame) -> List:
        """Rank by exact-match agreement with each row's most common answer."""
        # Most common answer per question (first mode breaks ties).
        mca = df.mode(axis=1).iloc[:, 0]
        # Count, per model, how often it produced the most common answer.
        wins = df.eq(mca, axis=0).astype(int)
        return wins.sum().sort_values(ascending=False).index.to_list()

    @staticmethod
    def _rank_by_noisy_equality(df: pd.DataFrame, p: float) -> List:
        """Like equality, but each agreement bit is flipped with probability p."""
        mca = df.mode(axis=1).iloc[:, 0]

        def noisy_match(col):
            # Elementwise agreement with the MCA, randomly flipped with prob p.
            return (col == mca).apply(
                lambda hit: (not hit) if random.random() <= p else hit
            )

        # axis=0 applies per model column; the original `axis='rows'` is not a
        # valid pandas axis alias and raises on current pandas versions.
        wins = df.apply(noisy_match, axis=0)
        return wins.sum().sort_values(ascending=False).index.to_list()

    @staticmethod
    def _rank_by_rouge(df: pd.DataFrame) -> List:
        """Rank by ROUGE-2 f-measure against a bigram-level consensus answer."""
        models = df.columns.to_list()
        TOP_K = 256  # number of consensus bigrams kept per instance

        def consensus(row):
            # Most common answer as the top-k bigrams pooled across all outputs.
            counts = [rs._create_ngrams(row[m], n=2) for m in models]
            pooled = sum(counts, Counter())
            return Counter(dict(pooled.most_common(TOP_K)))

        def score_row(row):
            # ROUGE f-measure of each model's output vs. the row's consensus.
            target = consensus(row)
            return pd.Series(
                {m: rs._score_ngrams(target, rs._create_ngrams(row[m], n=2)).fmeasure
                 for m in models}
            )

        # Computed without writing an 'mca' column into the caller's df
        # (the previous implementation mutated its input in place).
        win_rates = df.apply(score_row, axis=1).idxmax(axis=1).value_counts()
        ranked = win_rates.index.tolist()
        # Models with no wins go at the bottom; sorted for determinism
        # (set difference order is otherwise arbitrary).
        return ranked + sorted(set(models) - set(ranked))

    def measure(self, metric='rbo', k=5, p=0.95) -> float:
        """
        Report agreement between the estimated and true rankings.

        metric: 'rbo' (rank-biased overlap with weight p) or 'mapk' (MAP@k).
        Raises ValueError if `fit` has not been run, no true ranking is
        available, or the metric name is unsupported.
        """
        if metric not in ('rbo', 'mapk'):
            raise ValueError(f"Metric {metric} not supported (use 'rbo'/'mapk').")
        if not hasattr(self, 'ranking'):
            raise ValueError("Ranking not estimated. Run 'fit' first.")
        if self.true_ranking is None:
            raise ValueError("True ranking not available for metric calculation.")
        if metric == 'mapk':
            if k > len(self.true_ranking):
                logging.getLogger(__name__).warning(
                    f"MAPk metric is for k={len(self.true_ranking)}, and not k={k}."
                )
            return mapk([self.true_ranking[:k]], [self.ranking[:k]], k=k)
        return rank_biased_overlap(self.true_ranking, self.ranking, p=p)

    def plot(self, caselabel="output"):
        """Plot estimated vs. true ranks when both are available."""
        # Boolean `and` (short-circuits) instead of bitwise `&`.
        if hasattr(self, 'ranking') and self.true_ranking is not None:
            plot_ranks(self.true_ranking, self.ranking, "actual", "estimated", caselabel)