Varshithdharmajv commited on
Commit
99f7550
·
verified ·
1 Parent(s): 5d4095b

Upload consensus_fusion.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. consensus_fusion.py +27 -7
consensus_fusion.py CHANGED
@@ -1,14 +1,26 @@
1
  from typing import List, Dict, Any
2
  import re
3
 
4
- def _normalize_answer(ans: str) -> str:
5
- """Normalize an answer string for comparison (remove spaces, lowercase, strip LaTeX wrappers)."""
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  s = str(ans).strip()
7
  s = re.sub(r'\$', '', s)
8
  s = re.sub(r'\\(?:approx|approx|cdot|,|;|\s)', ' ', s)
9
  s = s.replace("\\", "").replace("{", "").replace("}", "")
10
  s = s.replace(" ", "").lower()
11
- # Normalize floats: "3.0" == "3"
12
  try:
13
  f = float(s)
14
  s = str(int(f)) if f == int(f) else str(round(f, 6))
@@ -23,10 +35,18 @@ def normalize_answers(answers: List[str]) -> Dict[str, List[int]]:
23
  clean = _normalize_answer(ans)
24
  matched = False
25
  for key in list(normalized_groups.keys()):
26
- if _normalize_answer(key) == clean:
27
- normalized_groups[key].append(idx)
28
- matched = True
29
- break
 
 
 
 
 
 
 
 
30
  if not matched:
31
  normalized_groups[ans] = [idx]
32
  return normalized_groups
 
1
  from typing import List, Dict, Any
2
  import re
3
 
4
+ try:
5
+ from math_verify import parse, verify
6
+ MATH_VERIFY_AVAILABLE = True
7
+ except ImportError:
8
+ MATH_VERIFY_AVAILABLE = False
9
+
10
+ def _normalize_answer(ans: str) -> Any:
11
+ """Uses math_verify to parse the answer for robust comparison."""
12
+ if MATH_VERIFY_AVAILABLE:
13
+ try:
14
+ return parse(str(ans))
15
+ except:
16
+ return str(ans)
17
+
18
+ # Legacy fallback
19
  s = str(ans).strip()
20
  s = re.sub(r'\$', '', s)
21
  s = re.sub(r'\\(?:approx|approx|cdot|,|;|\s)', ' ', s)
22
  s = s.replace("\\", "").replace("{", "").replace("}", "")
23
  s = s.replace(" ", "").lower()
 
24
  try:
25
  f = float(s)
26
  s = str(int(f)) if f == int(f) else str(round(f, 6))
 
35
  clean = _normalize_answer(ans)
36
  matched = False
37
  for key in list(normalized_groups.keys()):
38
+ key_clean = _normalize_answer(key)
39
+
40
+ if MATH_VERIFY_AVAILABLE:
41
+ if verify(clean, key_clean):
42
+ normalized_groups[key].append(idx)
43
+ matched = True
44
+ break
45
+ else:
46
+ if key_clean == clean:
47
+ normalized_groups[key].append(idx)
48
+ matched = True
49
+ break
50
  if not matched:
51
  normalized_groups[ans] = [idx]
52
  return normalized_groups