Spaces:
Sleeping
Sleeping
Commit
·
4ab9c98
1
Parent(s):
e32216e
Resolved error generation metrics
Browse files- evaluation/metrics/generation_metrics.py +39 -7
- requirements.txt +1 -0
evaluation/metrics/generation_metrics.py
CHANGED
|
@@ -1,29 +1,61 @@
|
|
| 1 |
-
"""Generation
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
from typing import Sequence, Mapping, Any
|
| 5 |
import functools
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
def _load(metric_name: str):
|
| 10 |
-
"""Cache metric loading to avoid re
|
|
|
|
|
|
|
| 11 |
return functools.lru_cache()(lambda: evaluate.load(metric_name))()
|
| 12 |
|
| 13 |
|
| 14 |
def bleu(predictions: Sequence[str], references: Sequence[str]) -> float:
|
|
|
|
|
|
|
|
|
|
| 15 |
metric = _load("sacrebleu")
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
return result["score"] / 100.0
|
| 18 |
|
| 19 |
|
| 20 |
def rouge_l(predictions: Sequence[str], references: Sequence[str]) -> float:
|
|
|
|
|
|
|
|
|
|
| 21 |
metric = _load("rouge")
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
def bert_score(predictions: Sequence[str], references: Sequence[str]) -> float:
|
|
|
|
|
|
|
|
|
|
| 27 |
metric = _load("bertscore")
|
|
|
|
|
|
|
| 28 |
result = metric.compute(predictions=predictions, references=references, lang="en")
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Generation-level metrics using the `evaluate` library."""
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
from typing import Sequence, Mapping, Any
|
| 5 |
import functools
|
| 6 |
+
|
| 7 |
+
# Attempt to import the `evaluate` package; if missing, set to None.
|
| 8 |
+
try:
|
| 9 |
+
import evaluate # type: ignore[import]
|
| 10 |
+
except ImportError:
|
| 11 |
+
evaluate = None
|
| 12 |
|
| 13 |
|
| 14 |
def _load(metric_name: str):
|
| 15 |
+
"""Cache metric loading to avoid re-downloads."""
|
| 16 |
+
if evaluate is None:
|
| 17 |
+
return None
|
| 18 |
return functools.lru_cache()(lambda: evaluate.load(metric_name))()
|
| 19 |
|
| 20 |
|
| 21 |
def bleu(predictions: Sequence[str], references: Sequence[str]) -> float:
    """Corpus-level BLEU via sacrebleu, normalised to the [0, 1] range.

    Falls back to 0.0 when the optional `evaluate` package (or the
    sacrebleu metric) is unavailable.
    """
    if evaluate is None:
        return 0.0
    scorer = _load("sacrebleu")
    if scorer is None:
        return 0.0
    # sacrebleu expects one *list* of reference strings per prediction,
    # so wrap each single reference in its own list.
    wrapped_refs = [[ref] for ref in references]
    scores: Mapping[str, Any] = scorer.compute(
        predictions=predictions, references=wrapped_refs
    )
    # sacrebleu reports BLEU on a 0-100 scale; rescale to 0-1.
    return scores["score"] / 100.0
|
| 33 |
|
| 34 |
|
| 35 |
def rouge_l(predictions: Sequence[str], references: Sequence[str]) -> float:
    """ROUGE-L score via the `evaluate` library's rouge metric.

    Falls back to 0.0 when the optional `evaluate` package (or the
    rouge metric) is unavailable, or the score key is missing.
    """
    if evaluate is None:
        return 0.0
    scorer = _load("rouge")
    if scorer is None:
        return 0.0
    compute_args = {
        "predictions": predictions,
        "references": references,
        # Restrict computation to the longest-common-subsequence variant.
        "rouge_types": ["rougeL"],
    }
    return scorer.compute(**compute_args).get("rougeL", 0.0)
|
| 48 |
|
| 49 |
|
| 50 |
def bert_score(predictions: Sequence[str], references: Sequence[str]) -> float:
    """Mean BERTScore F1 over all prediction/reference pairs (English).

    Falls back to 0.0 when the optional `evaluate` package is missing,
    the metric cannot be loaded, or no per-pair F1 values are returned.
    """
    if evaluate is None:
        return 0.0
    scorer = _load("bertscore")
    if scorer is None:
        return 0.0
    outputs = scorer.compute(
        predictions=predictions, references=references, lang="en"
    )
    per_pair_f1 = outputs.get("f1", [])
    # Average the per-pair F1 values, guarding against an empty list.
    return float(sum(per_pair_f1) / len(per_pair_f1)) if per_pair_f1 else 0.0
|
requirements.txt
CHANGED
|
@@ -8,6 +8,7 @@ sentence-transformers>=2.7
|
|
| 8 |
langchain>=0.1.0
|
| 9 |
ragas>=0.1.0
|
| 10 |
trulens-eval>=0.21.0
|
|
|
|
| 11 |
|
| 12 |
# Data & science
|
| 13 |
pandas>=2.2
|
|
|
|
| 8 |
langchain>=0.1.0
|
| 9 |
ragas>=0.1.0
|
| 10 |
trulens-eval>=0.21.0
|
| 11 |
+
evaluate>=0.4
|
| 12 |
|
| 13 |
# Data & science
|
| 14 |
pandas>=2.2
|