mindchain commited on
Commit
786e916
Β·
verified Β·
1 Parent(s): 22c3bb1

Upload train_arithmetic_v3.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_arithmetic_v3.py +34 -8
train_arithmetic_v3.py CHANGED
@@ -53,12 +53,42 @@ def generate_arithmetic_samples(n_samples):
53
  return samples
54
 
55
  # ============================================================================
56
- # REWARD FUNCTION
57
  # ============================================================================
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  def reward_func(completions, prompts=None, **kwargs):
60
  """
61
- Reward function for arithmetic.
62
  """
63
  # Try multiple column names for ground truth
64
  answers = None
@@ -81,12 +111,8 @@ def reward_func(completions, prompts=None, **kwargs):
81
  else:
82
  text = str(completion)
83
 
84
- # Extract the last number
85
- numbers = re.findall(r'-?\d+\.?\d*', text)
86
- if numbers:
87
- predicted = numbers[-1].strip()
88
- else:
89
- predicted = ""
90
 
91
  # Exact match reward
92
  is_correct = predicted == str(truth).strip()
 
53
  return samples
54
 
55
  # ============================================================================
56
+ # REWARD FUNCTION (Improved)
57
  # ============================================================================
58
 
59
+ def extract_answer(text):
60
+ """
61
+ Extract the final answer from model output.
62
+ Priority:
63
+ 1. Number in $$...$$ LaTeX blocks (last one)
64
+ 2. Number after "Answer:" pattern
65
+ 3. Last standalone number (fallback)
66
+ """
67
+ # Try to find numbers in $$...$$ blocks first
68
+ latex_blocks = re.findall(r'\$\$(.*?)\$\$', text, re.DOTALL)
69
+ if latex_blocks:
70
+ # Get the last LaTeX block and extract number
71
+ last_block = latex_blocks[-1]
72
+ numbers = re.findall(r'-?\d+\.?\d*', last_block)
73
+ if numbers:
74
+ return numbers[-1].strip()
75
+
76
+ # Try to find number after "Answer:" pattern
77
+ answer_match = re.search(r'Answer:\s*(-?\d+\.?\d*)', text, re.IGNORECASE)
78
+ if answer_match:
79
+ return answer_match.group(1).strip()
80
+
81
+ # Fallback: last number in text
82
+ numbers = re.findall(r'-?\d+\.?\d*', text)
83
+ if numbers:
84
+ return numbers[-1].strip()
85
+
86
+ return ""
87
+
88
+
89
  def reward_func(completions, prompts=None, **kwargs):
90
  """
91
+ Reward function for arithmetic with improved extraction.
92
  """
93
  # Try multiple column names for ground truth
94
  answers = None
 
111
  else:
112
  text = str(completion)
113
 
114
+ # Extract answer using improved method
115
+ predicted = extract_answer(text)
 
 
 
 
116
 
117
  # Exact match reward
118
  is_correct = predicted == str(truth).strip()