Rom89823974978 committed on
Commit 4ab9c98 · 1 Parent(s): e32216e

Resolved error in generation metrics

evaluation/metrics/generation_metrics.py CHANGED
@@ -1,29 +1,61 @@
-"""Generationlevel metrics using the `evaluate` library."""
+"""Generation-level metrics using the `evaluate` library."""
 
 from __future__ import annotations
 from typing import Sequence, Mapping, Any
 import functools
-import evaluate  # type: ignore[import]
+
+# Attempt to import the `evaluate` package; if missing, set to None.
+try:
+    import evaluate  # type: ignore[import]
+except ImportError:
+    evaluate = None
 
 
 def _load(metric_name: str):
-    """Cache metric loading to avoid redownloads."""
+    """Cache metric loading to avoid re-downloads."""
+    if evaluate is None:
+        return None
     return functools.lru_cache()(lambda: evaluate.load(metric_name))()
 
 
 def bleu(predictions: Sequence[str], references: Sequence[str]) -> float:
+    """Compute BLEU via sacrebleu. If `evaluate` is missing, return 0.0."""
+    if evaluate is None:
+        return 0.0
     metric = _load("sacrebleu")
-    result: Mapping[str, Any] = metric.compute(predictions=predictions, references=[[r] for r in references])
+    if metric is None:
+        return 0.0
+    result: Mapping[str, Any] = metric.compute(
+        predictions=predictions,
+        references=[[r] for r in references],
+    )
     return result["score"] / 100.0
 
 
 def rouge_l(predictions: Sequence[str], references: Sequence[str]) -> float:
+    """Compute ROUGE-L via `evaluate`. If `evaluate` is missing, return 0.0."""
+    if evaluate is None:
+        return 0.0
     metric = _load("rouge")
-    result = metric.compute(predictions=predictions, references=references, rouge_types=["rougeL"])
-    return result["rougeL"]
+    if metric is None:
+        return 0.0
+    result = metric.compute(
+        predictions=predictions,
+        references=references,
+        rouge_types=["rougeL"],
+    )
+    return result.get("rougeL", 0.0)
 
 
 def bert_score(predictions: Sequence[str], references: Sequence[str]) -> float:
+    """Compute BERTScore via `evaluate`. If `evaluate` is missing, return 0.0."""
+    if evaluate is None:
+        return 0.0
     metric = _load("bertscore")
+    if metric is None:
+        return 0.0
     result = metric.compute(predictions=predictions, references=references, lang="en")
-    return float(sum(result["f1"]) / len(result["f1"]))
+    f1_scores = result.get("f1", [])
+    if not f1_scores:
+        return 0.0
+    return float(sum(f1_scores) / len(f1_scores))
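
For reference, a minimal smoke test for the patched module (not part of this commit; it assumes the repo root is on PYTHONPATH and, for the happy path, that `evaluate` and its metric backends are installed):

# Hypothetical smoke test for the patched metrics (not part of this commit).
from evaluation.metrics import generation_metrics as gm

preds = ["the cat sat on the mat"]
refs = ["the cat sat on the mat"]

# With `evaluate` installed, bleu() returns sacrebleu's 0-100 score scaled
# to [0, 1], so an exact match should print ~1.0.
print(gm.bleu(preds, refs))

# With `evaluate` missing, every metric now degrades to 0.0 instead of
# raising ImportError at import time.
print(gm.rouge_l(preds, refs))
print(gm.bert_score(preds, refs))

One design caveat worth flagging: `_load` wraps a fresh lambda in `functools.lru_cache()` on every call, so each invocation builds a new cache and the metric object is re-created anyway; decorating `_load` itself with `@functools.lru_cache` would actually memoize across calls.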
requirements.txt CHANGED
@@ -8,6 +8,7 @@ sentence-transformers>=2.7
 langchain>=0.1.0
 ragas>=0.1.0
 trulens-eval>=0.21.0
+evaluate
 
 # Data & science
 pandas>=2.2
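
One environment assumption worth noting (not part of this commit): `evaluate.load()` pulls in a per-metric backend package at load time, so the unpinned `evaluate` entry alone may not cover the three metrics above. A sketch of the extra pins they typically need:

evaluate>=0.4
sacrebleu      # backend for evaluate.load("sacrebleu")
rouge-score    # backend for evaluate.load("rouge")
bert-score     # backend for evaluate.load("bertscore")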