Spaces:
Runtime error
Runtime error
| import evaluate | |
| from rapidfuzz.distance.Levenshtein import distance, normalized_similarity | |
| import config | |
BLEU = evaluate.load("saridormi/b_norm", cache_dir=config.CACHE_DIR)


def bleu_fn(pred, ref, **kwargs):
    """Score ``pred`` with the ``saridormi/b_norm`` metric and return its ``"b_norm"`` value.

    If ``kwargs`` contains ``"refs"``, ``pred`` is repeated once per element of
    ``kwargs["refs"]``. Each reference therefore forms its own
    (prediction, reference) pair, and all pairs go to a single ``compute`` call.
    Otherwise only the pair ``(pred, ref)`` is scored.
    """
    references = kwargs["refs"] if "refs" in kwargs else [ref]
    predictions = [pred] * len(references)
    result = BLEU.compute(predictions=predictions, references=references)
    return result["b_norm"]
METEOR = evaluate.load("meteor", cache_dir=config.CACHE_DIR)


def meteor_fn(pred, ref, **kwargs):
    """Return the ``"meteor"`` value computed by the METEOR metric for ``pred``.

    Without ``"refs"`` in ``kwargs``, ``pred`` is scored against ``ref`` alone.
    With it, ``pred`` is paired with every element of ``kwargs["refs"]``. All
    pairs go to one ``compute`` call, and the single ``"meteor"`` value it
    reports for that batch is returned.
    """
    if "refs" not in kwargs:
        return METEOR.compute(predictions=[pred], references=[ref])["meteor"]
    all_refs = kwargs["refs"]
    batch_preds = [pred] * len(all_refs)
    return METEOR.compute(predictions=batch_preds, references=all_refs)["meteor"]
ROUGE = evaluate.load("rouge", cache_dir=config.CACHE_DIR)


def _rouge_score(rouge_type, pred, ref, kwargs):
    """Return the ``rouge_type`` score that ``ROUGE.compute`` reports for ``pred``.

    Passing ``rouge_types=[rouge_type]`` restricts ``compute`` to the one
    variant actually read. The previous wrappers let it compute the metric's
    whole default set and discarded everything but one key.

    Args:
        rouge_type: ROUGE variant to compute, which is also the result key
            returned, e.g. ``"rouge1"``.
        pred: Predicted text.
        ref: Reference text. Used only when ``kwargs`` has no ``"refs"`` entry.
        kwargs: The public wrapper's keyword arguments. If they contain
            ``"refs"``, ``pred`` is paired with each element of
            ``kwargs["refs"]`` and all pairs are scored in one ``compute``
            call, using the metric's default aggregation.

    Returns:
        The value stored under ``rouge_type`` in the ``compute`` result.
    """
    if "refs" in kwargs:
        references = kwargs["refs"]
        predictions = [pred] * len(references)
    else:
        references = [ref]
        predictions = [pred]
    result = ROUGE.compute(predictions=predictions, references=references, rouge_types=[rouge_type])
    return result[rouge_type]


def rouge1_fn(pred, ref, **kwargs):
    """Return the ROUGE-1 score of ``pred`` against ``ref`` (or ``kwargs["refs"]``)."""
    return _rouge_score("rouge1", pred, ref, kwargs)


def rouge2_fn(pred, ref, **kwargs):
    """Return the ROUGE-2 score of ``pred`` against ``ref`` (or ``kwargs["refs"]``)."""
    return _rouge_score("rouge2", pred, ref, kwargs)


def rougeL_fn(pred, ref, **kwargs):
    """Return the ROUGE-L score of ``pred`` against ``ref`` (or ``kwargs["refs"]``)."""
    return _rouge_score("rougeL", pred, ref, kwargs)
BERTSCORE = evaluate.load("bertscore", cache_dir=config.CACHE_DIR)


def bertscore_fn(pred, ref, **kwargs):
    """Return the BERTScore F1 of the single prediction ``pred``.

    The score is computed with ``model_type="distilbert-base-uncased"``.
    There is always exactly one prediction. Its reference entry is the whole
    ``kwargs["refs"]`` collection when ``"refs"`` is given, or just ``ref``
    otherwise. Either way it is wrapped in a one-element outer list.
    """
    reference_entry = kwargs["refs"] if "refs" in kwargs else ref
    scores = BERTSCORE.compute(
        predictions=[pred],
        references=[reference_entry],
        model_type="distilbert-base-uncased",
    )
    # One prediction in, so the per-prediction F1 list has exactly one entry.
    return scores["f1"][0]
# Load with the same cache_dir as every other metric in this module, so the
# chrF files are stored under config.CACHE_DIR rather than the default cache.
CHRF = evaluate.load("chrf", cache_dir=config.CACHE_DIR)


def chrf_fn(pred, ref, **kwargs):
    """Return the chrF ``"score"`` for the single prediction ``pred``.

    The metric receives one prediction and a nested list of references:
    ``[kwargs["refs"]]`` (all references for that prediction) when ``"refs"``
    is given, otherwise ``[[ref]]``.
    """
    if "refs" in kwargs:
        return CHRF.compute(predictions=[pred], references=[kwargs["refs"]])["score"]
    return CHRF.compute(predictions=[pred], references=[[ref]])["score"]
def edit_distance_fn(pred, ref, **kwargs):
    """Return the Levenshtein ``distance`` (rapidfuzz) between ``pred`` and its reference(s).

    Without ``"refs"`` in ``kwargs``, this is simply ``distance(pred, ref)``.
    With it, the distance to every element of ``kwargs["refs"]`` is computed
    and their arithmetic mean is returned. An empty ``refs`` raises
    ``ZeroDivisionError``.
    """
    if "refs" not in kwargs:
        return distance(pred, ref)
    per_ref = [distance(pred, candidate) for candidate in kwargs["refs"]]
    return sum(per_ref) / len(per_ref)
def edit_distance_norm_fn(pred, ref, **kwargs):
    """Return rapidfuzz's Levenshtein ``normalized_similarity`` of ``pred`` to its reference(s).

    Without ``"refs"`` in ``kwargs``, this is ``normalized_similarity(pred, ref)``.
    With it, the similarity to each element of ``kwargs["refs"]`` is computed
    and their arithmetic mean is returned. An empty ``refs`` raises
    ``ZeroDivisionError``.
    """
    if "refs" not in kwargs:
        return normalized_similarity(pred, ref)
    similarities = [normalized_similarity(pred, candidate) for candidate in kwargs["refs"]]
    return sum(similarities) / len(similarities)
# Metric name -> scoring function. Every function here has the signature
# (pred, ref, **kwargs) and accepts an optional "refs" keyword argument.
AGGR_METRICS = dict(
    editdist=edit_distance_fn,
    editsim=edit_distance_norm_fn,
    bleu=bleu_fn,
    meteor=meteor_fn,
    rouge1=rouge1_fn,
    rouge2=rouge2_fn,
    rougeL=rougeL_fn,
    bertscore=bertscore_fn,
    chrF=chrf_fn,
)
# Metric name -> scoring function; currently only the raw Levenshtein distance.
REL_METRICS = dict(editdist=edit_distance_fn)