aauss committed on
Commit
e6bd448
·
1 Parent(s): dbb13b7

Improve regex, add tests, and fail early with wrong input lengths.

Browse files
Files changed (4) hide show
  1. .gitignore +2 -0
  2. tests.py +0 -17
  3. tests/test_metric.py +53 -0
  4. tram_accuracy.py +9 -2
.gitignore CHANGED
@@ -8,3 +8,5 @@ wheels/
8
 
9
  # Virtual environments
10
  .venv
 
 
 
8
 
9
  # Virtual environments
10
  .venv
11
+
12
+ .DS_Store
tests.py DELETED
@@ -1,17 +0,0 @@
1
- test_cases = [
2
- {
3
- "predictions": [0, 0],
4
- "references": [1, 1],
5
- "result": {"metric_score": 0}
6
- },
7
- {
8
- "predictions": [1, 1],
9
- "references": [1, 1],
10
- "result": {"metric_score": 1}
11
- },
12
- {
13
- "predictions": [1, 0],
14
- "references": [1, 1],
15
- "result": {"metric_score": 0.5}
16
- }
17
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/test_metric.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from tram_accuracy import TRAMAccuracy
2
 
3
 
@@ -16,3 +17,55 @@ def test_tram_accuracy():
16
  "accuracy"
17
  ]
18
  assert accuracy == 1.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
  from tram_accuracy import TRAMAccuracy
3
 
4
 
 
17
  "accuracy"
18
  ]
19
  assert accuracy == 1.0
20
+
21
+
22
+ def test_empty_predictions():
23
+ """Empty predictions should raise ValueError."""
24
+ with pytest.raises(ValueError, match="predictions cannot be empty"):
25
+ TRAMAccuracy()._compute(predictions=[], references=[])
26
+
27
+
28
+ def test_mismatched_lengths():
29
+ """Mismatched lengths should raise ValueError."""
30
+ with pytest.raises(ValueError, match="must have same length"):
31
+ TRAMAccuracy()._compute(
32
+ predictions=["The final answer is (A)."],
33
+ references=["A", "B"],
34
+ )
35
+
36
+
37
+ def test_no_regex_match():
38
+ """Predictions without the expected format should be marked incorrect."""
39
+ result = TRAMAccuracy()._compute(
40
+ predictions=["I think the answer is A", "The final answer is (B)."],
41
+ references=["A", "B"],
42
+ return_average=False,
43
+ )
44
+ assert result["accuracy"] == [0, 1]
45
+
46
+
47
+ def test_partial_accuracy():
48
+ """Test partial accuracy calculation."""
49
+ result = TRAMAccuracy()._compute(
50
+ predictions=[
51
+ "The final answer is (A).",
52
+ "The final answer is (B).",
53
+ "The final answer is (C).",
54
+ ],
55
+ references=["A", "C", "C"],
56
+ return_average=True,
57
+ )
58
+ assert result["accuracy"] == pytest.approx(2 / 3)
59
+
60
+
61
+ def test_case_variations():
62
+ """Both 'The' and 'the' should be matched."""
63
+ result = TRAMAccuracy()._compute(
64
+ predictions=[
65
+ "The final answer is (A).",
66
+ "the final answer is (B).",
67
+ ],
68
+ references=["A", "B"],
69
+ return_average=False,
70
+ )
71
+ assert result["accuracy"] == [1, 1]
tram_accuracy.py CHANGED
@@ -37,14 +37,14 @@ Args:
37
  predictions: list of predictions to score. Each prediction
38
  should be a string with the model's response, which contains the final answer.
39
  references: list of reference for each prediction. Each
40
- reference a single letter respresenting the correct answer.
41
  return_average: whether to return the average accuracy or the accuracy for each prediction.
42
  Returns:
43
  accuracy: the accuracy for the TRAM datasets.
44
  """
45
 
46
 
47
- TRAM_ANSWER_REGEX = re.compile(r"[Tt]he final answer is .([A-D]).")
48
 
49
 
50
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
@@ -75,6 +75,13 @@ class TRAMAccuracy(evaluate.Metric):
75
 
76
  def _compute(self, predictions, references, return_average=True):
77
  """Returns the accuracy for the (multiple choice) TRAM datasets."""
 
 
 
 
 
 
 
78
  predictions_matches = [
79
  TRAM_ANSWER_REGEX.search(prediction) for prediction in predictions
80
  ]
 
37
  predictions: list of predictions to score. Each prediction
38
  should be a string with the model's response, which contains the final answer.
39
  references: list of reference for each prediction. Each
40
+ reference a single letter representing the correct answer.
41
  return_average: whether to return the average accuracy or the accuracy for each prediction.
42
  Returns:
43
  accuracy: the accuracy for the TRAM datasets.
44
  """
45
 
46
 
47
+ TRAM_ANSWER_REGEX = re.compile(r"[Tt]he final answer is \(([A-D])\)")
48
 
49
 
50
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
 
75
 
76
  def _compute(self, predictions, references, return_average=True):
77
  """Returns the accuracy for the (multiple choice) TRAM datasets."""
78
+ if len(predictions) == 0:
79
+ raise ValueError("predictions cannot be empty")
80
+ if len(predictions) != len(references):
81
+ raise ValueError(
82
+ f"predictions and references must have same length, "
83
+ f"got {len(predictions)} and {len(references)}"
84
+ )
85
  predictions_matches = [
86
  TRAM_ANSWER_REGEX.search(prediction) for prediction in predictions
87
  ]