import logging
import math
import random
import re
import string
from collections.abc import Iterable
from typing import List

import numpy as np
import sacrebleu

from lm_eval.api.registry import register_aggregation, register_metric


eval_logger = logging.getLogger("lm-eval")


# Register aggregation functions first, then the passthrough metric functions that feed them.
@register_aggregation("bypass")
def bypass_agg(arr):
    # Sentinel value; "bypass" results are never aggregated meaningfully.
    return 999


@register_aggregation("mean")
def mean(arr):
    return sum(arr) / len(arr)


@register_aggregation("median")
def median(arr):
    # Sort before indexing; the middle element of an unsorted list is not the median.
    return sorted(arr)[len(arr) // 2]


# Certain metrics must be calculated across all documents in a benchmark.
# They are registered as aggregations here and paired with no-op passthrough metric functions below.
@register_aggregation("perplexity")
def perplexity(items):
    return math.exp(-mean(items))


@register_aggregation("weighted_perplexity")
def weighted_perplexity(items):
    return math.exp(-weighted_mean(items))


@register_aggregation("bits_per_byte")
def bits_per_byte(items):
    return -weighted_mean(items) / math.log(2)


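# Illustrative sketch (not part of this module): the weighted aggregations above
# expect `items` to be (log-likelihood, weight) pairs, where the weight is a word
# or byte count supplied by the task. Hypothetical values:
#
#     items = [(-20.0, 10), (-35.0, 14)]   # (summed logprob, num words/bytes)
#     weighted_perplexity(items)           # exp(55 / 24)  ~= 9.9
#     bits_per_byte(items)                 # 55 / (24 * ln 2)  ~= 3.3

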
@register_aggregation("f1")
def f1_score(items):
    from sklearn.metrics import f1_score

    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    fscore = f1_score(golds, preds)

    return np.max(fscore)


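# Illustrative sketch (not part of this module): the "f1" aggregation receives
# (gold, pred) label pairs and scores the whole split with sklearn's binary F1.
# Hypothetical values:
#
#     items = [(1, 1), (0, 1), (1, 0), (1, 1)]
#     f1_score(items)   # precision 2/3, recall 2/3 -> F1 ~= 0.667

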
@register_aggregation("matthews_corrcoef")
def matthews_corrcoef(items):
    from sklearn.metrics import matthews_corrcoef

    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    return matthews_corrcoef(golds, preds)


@register_aggregation("bleu")
def bleu(items):
    """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
    for evaluating a generated sentence against a reference sentence. It counts
    matching n-grams between the candidate translation and the reference text,
    where a 1-gram (unigram) is a single token and a bigram is a word pair. The
    comparison is made regardless of word order.
    Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
    Paper: https://www.aclweb.org/anthology/P02-1040/

    Higher is better
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_bleu(preds, refs).score


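# Illustrative sketch (not part of this module): the translation aggregations
# receive one (references, prediction) pair per document and score the corpus
# as a whole with sacrebleu. Hypothetical values:
#
#     items = [(["the cat sat on the mat"], "the cat sat on the mat"),
#              (["a dog ran in the park"], "a dog walked in the park")]
#     bleu(items)   # corpus BLEU on a 0-100 scale; higher is better

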
@register_aggregation("chrf")
def chrf(items):
    """chrF++ is a tool for automatic evaluation of machine translation output
    based on character n-gram precision and recall enhanced with word n-grams.
    Source: https://github.com/m-popovic/chrF
    Paper: https://www.aclweb.org/anthology/W15-3049.pdf

    Higher is better
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_chrf(preds, refs).score


@register_aggregation("ter")
def ter(items):
    """Translation Error Rate is an error metric for machine translation that
    measures the number of edits required to change a system output into one
    of the references.
    Source: http://www.cs.umd.edu/~snover/tercom/
    Paper: http://mt-archive.info/AMTA-2006-Snover.pdf

    Lower is better
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_ter(preds, refs).score


@register_aggregation("brier_score")
def brier_score(items):
    # items are (gold class index, per-class probability vector) pairs
    gold, predictions = list(zip(*items))
    bs, num_class = np.array(predictions).shape

    gold = list(gold)
    gold_one_hot = np.eye(num_class)[gold]
    # mean squared distance between the probability vector and the one-hot gold label
    return np.mean(np.sum((predictions - gold_one_hot) ** 2, axis=1))


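# Illustrative sketch (not part of this module) of brier_score, with hypothetical values:
#
#     items = [(0, [0.8, 0.1, 0.1]),   # (0.2^2 + 0.1^2 + 0.1^2) = 0.06
#              (2, [0.3, 0.3, 0.4])]   # (0.3^2 + 0.3^2 + 0.6^2) = 0.54
#     brier_score(items)               # mean -> 0.30; lower is better

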
@register_metric(
    metric="brier_score",
    higher_is_better=False,
    output_type=["multiple_choice"],
    aggregation="brier_score",
)
def brier_score_fn(items):
    return items


@register_metric(
    metric="acc",
    higher_is_better=True,
    output_type=["loglikelihood", "multiple_choice"],
    aggregation="mean",
)
def acc_fn(items):
    return items


@register_metric(
    metric="acc_norm",
    higher_is_better=True,
    output_type=["loglikelihood", "multiple_choice"],
    aggregation="mean",
)
def acc_norm_fn(items):
    return items


@register_metric(
    metric="acc_mutual_info",
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="mean",
)
def acc_mutual_info_fn(items):
    return items


# The exact_match implementation below follows the `exact_match` metric from the
# HuggingFace `evaluate` library (Apache-2.0 licensed).
def exact_match_hf_evaluate(
    predictions,
    references,
    regexes_to_ignore=None,
    ignore_case=False,
    ignore_punctuation=False,
    ignore_numbers=False,
):
    # Optionally strip any substrings matching the given regexes before comparing.
    if regexes_to_ignore is not None:
        for s in regexes_to_ignore:
            predictions = np.array([re.sub(s, "", x) for x in predictions])
            references = np.array([re.sub(s, "", x) for x in references])
    else:
        predictions = np.asarray(predictions)
        references = np.asarray(references)

    if ignore_case:
        predictions = np.char.lower(predictions)
        references = np.char.lower(references)

    if ignore_punctuation:
        repl_table = string.punctuation.maketrans("", "", string.punctuation)
        predictions = np.char.translate(predictions, table=repl_table)
        references = np.char.translate(references, table=repl_table)

    if ignore_numbers:
        repl_table = string.digits.maketrans("", "", string.digits)
        predictions = np.char.translate(predictions, table=repl_table)
        references = np.char.translate(references, table=repl_table)

    # Element-wise equality after normalization, averaged over the batch.
    score_list = predictions == references

    return {"exact_match": np.mean(score_list)}


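# Illustrative sketch (not part of this module): the normalization flags of
# exact_match_hf_evaluate make the comparison progressively more forgiving.
# Hypothetical values:
#
#     exact_match_hf_evaluate(["The answer is 42."], ["the answer is 42"])
#     # -> {"exact_match": 0.0}
#     exact_match_hf_evaluate(["The answer is 42."], ["the answer is 42"],
#                             ignore_case=True, ignore_punctuation=True)
#     # -> {"exact_match": 1.0}

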
@register_metric(
    metric="exact_match",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="mean",
)
def exact_match_fn(**kwargs):
    return exact_match_hf_evaluate(**kwargs)


@register_metric(
    metric="perplexity",
    higher_is_better=False,
    output_type="loglikelihood",
    aggregation="perplexity",
)
def perplexity_fn(items):
    return items


@register_metric(
    metric="word_perplexity",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
    aggregation="weighted_perplexity",
)
def word_perplexity_fn(items):
    return items


@register_metric(
    metric="byte_perplexity",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
    aggregation="weighted_perplexity",
)
def byte_perplexity_fn(items):
    return items


@register_metric(
    metric="bits_per_byte",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
    aggregation="bits_per_byte",
)
def bits_per_byte_fn(items):
    return items


def pop_stddev(arr):
    mu = mean(arr)
    return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))


def sample_stddev(arr):
    mu = mean(arr)
    return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1))


def mean_stderr(arr):
    return sample_stddev(arr) / math.sqrt(len(arr))


@register_metric(
    metric="bypass",
    higher_is_better=True,
    output_type=["loglikelihood", "multiple_choice", "generate_until"],
    aggregation="bypass",
)
def bypass(items):
    return None


@register_metric(
    metric="mcc",
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="matthews_corrcoef",
)
def mcc_fn(items):
    return items


@register_metric(
    metric="f1",
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="f1",
)
def f1_fn(items):
    return items


@register_metric(
    metric="bleu",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="bleu",
)
def bleu_fn(items):
    return items


@register_metric(
    metric="chrf",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="chrf",
)
def chrf_fn(items):
    return items


@register_metric(
    metric="ter",
    higher_is_better=False,  # TER counts edits, so lower is better (see the aggregation docstring)
    output_type="generate_until",
    aggregation="ter",
)
def ter_fn(items):
    return items


@register_metric(
    metric="acc_all",
    higher_is_better=True,
    output_type="loglikelihood",
    aggregation="mean",
)
def acc_all(items):
    # Only count a question as correct if every one of its answer candidates is
    # classified correctly.
    question_scoring_dict = {}
    preds = list(zip(*items))[0]
    docs = list(zip(*items))[1]

    for doc, pred in zip(docs, preds):
        paragraph_id = doc["idx"]["paragraph"]
        question_id = doc["idx"]["question"]
        if (paragraph_id, question_id) not in question_scoring_dict:
            question_scoring_dict[(paragraph_id, question_id)] = []

        gold_label = doc["label"] == 1

        question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
    acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
    return acc


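# Illustrative sketch (not part of this module): acc_all is meant for tasks like
# MultiRC, where each question has several candidate answers. Hypothetical
# (pred, doc) items:
#
#     items = [
#         (True,  {"idx": {"paragraph": 0, "question": 0}, "label": 1}),
#         (False, {"idx": {"paragraph": 0, "question": 0}, "label": 0}),
#         (False, {"idx": {"paragraph": 0, "question": 1}, "label": 1}),
#     ]
#     acc_all(items)   # question (0, 0) is fully correct, (0, 1) is not -> 0.5

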
def acc_all_stderr(items):
    # Standard-error counterpart of acc_all; answers are grouped by question id only.
    question_scoring_dict = {}
    preds = list(zip(*items))[0]
    docs = list(zip(*items))[1]

    for doc, pred in zip(docs, preds):
        question_id = doc["idx"]["question"]
        if question_id not in question_scoring_dict:
            question_scoring_dict[question_id] = []

        gold_label = doc["label"] == 1
        question_scoring_dict[question_id].append(gold_label == pred)

    acc = mean_stderr([int(all(x)) for x in question_scoring_dict.values()])
    return acc


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Compute max metric between prediction and each ground truth."""
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)


def weighted_mean(items):
    # items are (value, weight) pairs, e.g. (summed log-likelihood, word/byte count)
    a, b = zip(*items)
    return sum(a) / sum(b)


def is_non_str_iterable(obj):
    return isinstance(obj, Iterable) and not isinstance(obj, str)


def _sacreformat(refs, preds):
    """Format refs and preds for sacrebleu corpus calculation. It is very particular."""
    # sacrebleu's corpus_* functions expect preds as List[str] and refs as
    # List[List[str]], where the outer refs list is indexed by reference stream:
    # refs[i][j] is the i-th reference for the j-th prediction.

    # refs may arrive as List[str] or List[List[str]] with one inner list per pred;
    # transpose so that each inner list holds one reference per prediction.
    if not is_non_str_iterable(refs):
        refs = list(refs)
    if not is_non_str_iterable(refs[0]):
        refs = [[ref] for ref in refs]
    refs = list(zip(*refs))

    # preds may arrive as List[str] or as List[List[str]] with singleton inner lists;
    # flatten to List[str].
    if not is_non_str_iterable(preds):
        preds = list(preds)
    if is_non_str_iterable(preds[0]):
        assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
        preds = [pred[0] for pred in preds]

    return refs, preds
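# Illustrative sketch (not part of this module) of the reshaping done by
# _sacreformat, with hypothetical values:
#
#     refs  = [["ref A1", "ref A2"], ["ref B1", "ref B2"]]   # one list per pred
#     preds = ["pred A", "pred B"]
#     _sacreformat(refs, preds)
#     # -> ([("ref A1", "ref B1"), ("ref A2", "ref B2")], ["pred A", "pred B"])

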

# Standard-error (stderr) helpers.


class _bootstrap_internal:
    """Draw `n` bootstrap resamples of `xs` (seeded by chunk index) and apply `f` to each."""

    def __init__(self, f, n) -> None:
        self.f = f
        self.n = n

    def __call__(self, v):
        i, xs = v
        rnd = random.Random()
        rnd.seed(i)
        res = []
        for _ in range(self.n):
            # sample with replacement, same size as the original data
            res.append(self.f(rnd.choices(xs, k=len(xs))))
        return res


def bootstrap_stderr(f, xs, iters):
    import multiprocessing as mp

    pool = mp.Pool(mp.cpu_count())

    # Split the `iters` resamples into chunks of up to 1000 and farm them out to
    # worker processes; each chunk is seeded by its index for reproducibility.
    res = []
    chunk_size = min(1000, iters)
    from tqdm import tqdm

    print("bootstrapping for stddev:", f.__name__)
    for bootstrap in tqdm(
        pool.imap(
            _bootstrap_internal(f, chunk_size),
            [(i, xs) for i in range(iters // chunk_size)],
        ),
        total=iters // chunk_size,
    ):
        # each worker returns a list of statistics computed on resampled data
        res.extend(bootstrap)

    pool.close()
    # The stderr estimate is the sample std-dev of the bootstrap statistics.
    return sample_stddev(res)
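# Illustrative sketch (not part of this module): resampling per-document scores
# and taking the std-dev of the resampled means approximates the standard error.
# Hypothetical values:
#
#     scores = [0.0, 1.0, 1.0, 0.0, 1.0]
#     bootstrap_stderr(mean, scores, iters=1000)   # roughly mean_stderr(scores)

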
def stderr_for_metric(metric, bootstrap_iters: int):
    if bootstrap_iters <= 0:
        # return no function (don't compute stderr) if bootstrap_iters is 0 or negative
        return None

    bootstrappable = [
        median,
        matthews_corrcoef,
        f1_score,
        perplexity,
        bleu,
        chrf,
        ter,
    ]

    if metric in bootstrappable:
        return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)

    stderr = {mean: mean_stderr, acc_all: acc_all_stderr}

    return stderr.get(metric, None)
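# Illustrative sketch (not part of this module): tasks look up the stderr function
# that matches their aggregation and apply it to the per-document values.
# Hypothetical values:
#
#     stderr_fn = stderr_for_metric(mean, bootstrap_iters=100000)
#     if stderr_fn is not None:
#         err = stderr_fn([0.0, 1.0, 1.0, 0.0, 1.0])   # mean_stderr of the scores

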
def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
    # Aggregate sample stderrs from several subtasks into a single stderr for the
    # group, weighting each subtask by its number of samples (pooled variance).

    assert len(stderrs) == len(sizes)

    # Each subtask's sample variance is stderr**2 * size; pool them with
    # (size - 1) degrees of freedom per subtask, then convert the pooled
    # variance back into a standard error for the combined sample.
    pooled_sample_var = (
        sum([(size - 1) * stderr**2 * size for size, stderr in zip(sizes, stderrs)])
    ) / (sum(sizes) - len(sizes))

    return np.sqrt(pooled_sample_var / sum(sizes))
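# Illustrative sketch (not part of this module): two subtasks with stderrs 0.02 and
# 0.03 over 100 and 50 samples have sample variances 0.04 and 0.045. Pooling gives
# (99 * 0.04 + 49 * 0.045) / 148 ~= 0.0417, and sqrt(0.0417 / 150) ~= 0.0167.
#
#     pooled_sample_stderr([0.02, 0.03], [100, 50])   # ~= 0.0167

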
def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None):
    assert (
        metrics is not None
    ), "Need to pass a list of each subtask's metric for this stderr aggregation"
    assert len(stderrs) == len(sizes) and len(sizes) == len(metrics)

    # Combine subtask variances pairwise: fold each subtask into a running group,
    # updating the running mean score and the combined variance. Unlike
    # pooled_sample_stderr, this formula also accounts for the spread between the
    # subtasks' mean scores.

    variance = stderrs[0] ** 2
    curr_size = sizes[0]
    curr_score = metrics[0]

    for stderr, size, score in zip(stderrs[1:], sizes[1:], metrics[1:]):
        # running mean of the combined group (assumes the aggregation is a mean)
        curr_score = ((curr_score * curr_size) + (score * size)) / (
            curr_size + size
        )

        variance = ((curr_size - 1) * variance + (size - 1) * (stderr**2)) / (
            curr_size + size - 1
        ) + curr_size * size / ((curr_size + size) * (curr_size + size - 1)) * (
            curr_score - score
        ) ** 2

    return np.sqrt(variance)


def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True):
    # Helper used to aggregate subtask scores into a single group score.
    # By default each subtask is weighted by its number of samples; with
    # weight_by_size=False all subtasks count equally.
    if not weight_by_size:
        sizes = [1] * len(sizes)

    assert len(metrics) == len(sizes)

    return sum([metric * size for metric, size in zip(metrics, sizes)]) / sum(sizes)
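
# Illustrative sketch (not part of this module), with hypothetical values:
#
#     aggregate_subtask_metrics([0.8, 0.6], [100, 50])                        # -> 0.733...
#     aggregate_subtask_metrics([0.8, 0.6], [100, 50], weight_by_size=False)  # -> 0.7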