aauss committed on
Commit
0f334f5
·
1 Parent(s): cac0344

Improve typing, early input checks and format code.

Browse files
app.py CHANGED
@@ -3,4 +3,4 @@ from evaluate.utils import launch_gradio_widget
3
 
4
 
5
  module = evaluate.load("aauss/test_of_time_accuracy")
6
- launch_gradio_widget(module)
 
3
 
4
 
5
  module = evaluate.load("aauss/test_of_time_accuracy")
6
+ launch_gradio_widget(module)
test_of_time_accuracy.py CHANGED
@@ -20,24 +20,25 @@ from typing import Any, Literal, TypedDict
20
  import datasets
21
  import evaluate
22
 
23
- # Field names used throughout the metric
24
  FIELD_EXPLANATION = "explanation"
25
  FIELD_ANSWER = "answer"
26
  FIELD_AGE = "age"
27
  FIELD_ORDERED_LIST = "ordered_list"
28
  FIELD_UNORDERED_LIST = "unordered_list"
29
 
30
- # Subset names
31
  SUBSET_ARITHMETIC = "arithmetic"
32
  SUBSET_SEMANTIC = "semantic"
33
  VALID_SUBSETS = frozenset({SUBSET_ARITHMETIC, SUBSET_SEMANTIC})
34
 
 
 
35
  # Control character escape mappings for JSON string normalization
36
- CONTROL_CHAR_ESCAPES = {'\n': '\\n', '\r': '\\r', '\t': '\\t'}
37
 
38
 
39
  class AccuracyResult(TypedDict):
40
- accuracy: float | list[bool]
 
41
 
42
  _CITATION = """\
43
  @InProceedings{huggingface:module,
@@ -58,9 +59,9 @@ Args:
58
  predictions: list of predictions to score. Each prediction should be a string that contains a JSON object (e.g., generated by an LLM).
59
  references: list of reference answers.
60
  subset: The subset of the benchmark being evaluated. Must be one of "arithmetic" or "semantic".
61
- return_average: If True, returns the average accuracy. If False, returns a list of boolean scores (correct/incorrect) for each sample. Defaults to True.
62
  Returns:
63
- accuracy: The accuracy score (0.0 to 1.0) if return_average=True, or a list of booleans indicating correctness per sample if return_average=False.
64
  Examples:
65
  >>> import evaluate
66
  >>> metric = evaluate.load("aauss/test_of_time_accuracy")
@@ -122,7 +123,7 @@ class TestOfTimeAccuracy(evaluate.Metric):
122
  decoder = json.JSONDecoder()
123
  idx = 0
124
  while idx < len(text):
125
- if text[idx] == '{':
126
  try:
127
  obj, _ = decoder.raw_decode(text, idx)
128
  if isinstance(obj, dict):
@@ -145,7 +146,7 @@ class TestOfTimeAccuracy(evaluate.Metric):
145
  i = 0
146
  while i < len(text):
147
  char = text[i]
148
- if char == '\\' and in_string and i + 1 < len(text):
149
  # Preserve existing escape sequences
150
  result.append(char)
151
  result.append(text[i + 1])
@@ -158,7 +159,7 @@ class TestOfTimeAccuracy(evaluate.Metric):
158
  else:
159
  result.append(char)
160
  i += 1
161
- return ''.join(result)
162
 
163
  @staticmethod
164
  def _parse_reference_label(label_str: str) -> dict | None:
@@ -297,7 +298,9 @@ class TestOfTimeAccuracy(evaluate.Metric):
297
  # Process list fields regardless of key order
298
  for key in (FIELD_ORDERED_LIST, FIELD_UNORDERED_LIST):
299
  if key in data and isinstance(data[key], list):
300
- data[key] = [item.lower() for item in data[key] if isinstance(item, str)]
 
 
301
 
302
  return data
303
 
@@ -413,7 +416,7 @@ class TestOfTimeAccuracy(evaluate.Metric):
413
  # Semantic references are used as-is
414
  return raw_references
415
 
416
- def _compare_pair(self, prediction: Any, reference: Any, subset: str) -> bool:
417
  """
418
  Compares a single prediction-reference pair.
419
 
@@ -423,7 +426,7 @@ class TestOfTimeAccuracy(evaluate.Metric):
423
  subset: Either 'arithmetic' or 'semantic'
424
 
425
  Returns:
426
- True if prediction matches reference, False otherwise
427
  """
428
  if subset == SUBSET_ARITHMETIC:
429
  prediction, reference = self._process_arithmetic_prediction(
@@ -434,13 +437,13 @@ class TestOfTimeAccuracy(evaluate.Metric):
434
  prediction, reference
435
  )
436
 
437
- return prediction == reference
438
 
439
  def _compute(
440
  self,
441
  predictions: list[str],
442
  references: list[str],
443
- subset: Literal["arithmetic", "semantic"],
444
  return_average: bool = True,
445
  ) -> AccuracyResult:
446
  """
@@ -456,11 +459,22 @@ class TestOfTimeAccuracy(evaluate.Metric):
456
  Returns:
457
  Dictionary with 'accuracy' key containing either:
458
  - float: average accuracy (if return_average=True)
459
- - list[bool]: per-sample correctness (if return_average=False)
460
 
461
  Raises:
462
  ValueError: If subset is not 'arithmetic' or 'semantic'
 
 
463
  """
 
 
 
 
 
 
 
 
 
464
  # Validate subset
465
  if subset not in VALID_SUBSETS:
466
  raise ValueError(
 
20
  import datasets
21
  import evaluate
22
 
 
23
  FIELD_EXPLANATION = "explanation"
24
  FIELD_ANSWER = "answer"
25
  FIELD_AGE = "age"
26
  FIELD_ORDERED_LIST = "ordered_list"
27
  FIELD_UNORDERED_LIST = "unordered_list"
28
 
 
29
  SUBSET_ARITHMETIC = "arithmetic"
30
  SUBSET_SEMANTIC = "semantic"
31
  VALID_SUBSETS = frozenset({SUBSET_ARITHMETIC, SUBSET_SEMANTIC})
32
 
33
+ SubsetType = Literal["arithmetic", "semantic"]
34
+
35
  # Control character escape mappings for JSON string normalization
36
+ CONTROL_CHAR_ESCAPES = {"\n": "\\n", "\r": "\\r", "\t": "\\t"}
37
 
38
 
39
  class AccuracyResult(TypedDict):
40
+ accuracy: float | list[int]
41
+
42
 
43
  _CITATION = """\
44
  @InProceedings{huggingface:module,
 
59
  predictions: list of predictions to score. Each prediction should be a string that contains a JSON object (e.g., generated by an LLM).
60
  references: list of reference answers.
61
  subset: The subset of the benchmark being evaluated. Must be one of "arithmetic" or "semantic".
62
+ return_average: If True, returns the average accuracy. If False, returns a list of int scores (0 or 1) for each sample. Defaults to True.
63
  Returns:
64
+ accuracy: The accuracy score (0.0 to 1.0) if return_average=True, or a list of int (0 or 1) indicating correctness per sample if return_average=False.
65
  Examples:
66
  >>> import evaluate
67
  >>> metric = evaluate.load("aauss/test_of_time_accuracy")
 
123
  decoder = json.JSONDecoder()
124
  idx = 0
125
  while idx < len(text):
126
+ if text[idx] == "{":
127
  try:
128
  obj, _ = decoder.raw_decode(text, idx)
129
  if isinstance(obj, dict):
 
146
  i = 0
147
  while i < len(text):
148
  char = text[i]
149
+ if char == "\\" and in_string and i + 1 < len(text):
150
  # Preserve existing escape sequences
151
  result.append(char)
152
  result.append(text[i + 1])
 
159
  else:
160
  result.append(char)
161
  i += 1
162
+ return "".join(result)
163
 
164
  @staticmethod
165
  def _parse_reference_label(label_str: str) -> dict | None:
 
298
  # Process list fields regardless of key order
299
  for key in (FIELD_ORDERED_LIST, FIELD_UNORDERED_LIST):
300
  if key in data and isinstance(data[key], list):
301
+ data[key] = [
302
+ item.lower() for item in data[key] if isinstance(item, str)
303
+ ]
304
 
305
  return data
306
 
 
416
  # Semantic references are used as-is
417
  return raw_references
418
 
419
+ def _compare_pair(self, prediction: Any, reference: Any, subset: str) -> int:
420
  """
421
  Compares a single prediction-reference pair.
422
 
 
426
  subset: Either 'arithmetic' or 'semantic'
427
 
428
  Returns:
429
+ 1 if prediction matches reference, 0 otherwise
430
  """
431
  if subset == SUBSET_ARITHMETIC:
432
  prediction, reference = self._process_arithmetic_prediction(
 
437
  prediction, reference
438
  )
439
 
440
+ return int(prediction == reference)
441
 
442
  def _compute(
443
  self,
444
  predictions: list[str],
445
  references: list[str],
446
+ subset: SubsetType,
447
  return_average: bool = True,
448
  ) -> AccuracyResult:
449
  """
 
459
  Returns:
460
  Dictionary with 'accuracy' key containing either:
461
  - float: average accuracy (if return_average=True)
462
+ - list[int]: per-sample correctness (if return_average=False)
463
 
464
  Raises:
465
  ValueError: If subset is not 'arithmetic' or 'semantic'
466
+ ValueError: If predictions is empty
467
+ ValueError: If predictions and references have different lengths
468
  """
469
+ # Validate inputs
470
+ if not predictions:
471
+ raise ValueError("predictions cannot be empty")
472
+ if len(predictions) != len(references):
473
+ raise ValueError(
474
+ f"predictions and references must have same length, "
475
+ f"got {len(predictions)} and {len(references)}"
476
+ )
477
+
478
  # Validate subset
479
  if subset not in VALID_SUBSETS:
480
  raise ValueError(
tests/test_arithmetic_type_casting.py CHANGED
@@ -86,6 +86,7 @@ def test_ordered_list_type_casting():
86
  pred_cast = TestOfTimeAccuracy._cast_prediction_to_reference_types(ref, pred)
87
  assert ref == pred_cast
88
 
 
89
  # TODO: Check if I should treat float strings differently, e.g., int(float("18.0"))
90
  def test_abc_type_casting():
91
  references_abc_keys = [
 
86
  pred_cast = TestOfTimeAccuracy._cast_prediction_to_reference_types(ref, pred)
87
  assert ref == pred_cast
88
 
89
+
90
  # TODO: Check if I should treat float strings differently, e.g., int(float("18.0"))
91
  def test_abc_type_casting():
92
  references_abc_keys = [