Spaces:
Runtime error
Runtime error
Upload train_arithmetic_v3.py with huggingface_hub
Browse files- train_arithmetic_v3.py +34 -8
train_arithmetic_v3.py
CHANGED
|
@@ -53,12 +53,42 @@ def generate_arithmetic_samples(n_samples):
|
|
| 53 |
return samples
|
| 54 |
|
| 55 |
# ============================================================================
|
| 56 |
-
# REWARD FUNCTION
|
| 57 |
# ============================================================================
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
def reward_func(completions, prompts=None, **kwargs):
|
| 60 |
"""
|
| 61 |
-
Reward function for arithmetic.
|
| 62 |
"""
|
| 63 |
# Try multiple column names for ground truth
|
| 64 |
answers = None
|
|
@@ -81,12 +111,8 @@ def reward_func(completions, prompts=None, **kwargs):
|
|
| 81 |
else:
|
| 82 |
text = str(completion)
|
| 83 |
|
| 84 |
-
# Extract
|
| 85 |
-
|
| 86 |
-
if numbers:
|
| 87 |
-
predicted = numbers[-1].strip()
|
| 88 |
-
else:
|
| 89 |
-
predicted = ""
|
| 90 |
|
| 91 |
# Exact match reward
|
| 92 |
is_correct = predicted == str(truth).strip()
|
|
|
|
| 53 |
return samples
|
| 54 |
|
| 55 |
# ============================================================================
|
| 56 |
+
# REWARD FUNCTION (Improved)
|
| 57 |
# ============================================================================
|
| 58 |
|
| 59 |
+
def extract_answer(text):
    """
    Extract the final numeric answer from model output.

    Priority:
    1. Last number inside the last $$...$$ LaTeX block
    2. Number directly after the first "Answer:" marker (case-insensitive)
    3. Last number anywhere in the text (fallback)

    A decimal point only counts as part of the number when digits follow it,
    so a sentence-ending period ("... is 42.") yields "42" rather than "42.",
    which would otherwise break an exact string comparison with the truth.

    Args:
        text: Raw completion text produced by the model.

    Returns:
        The extracted number as a string, or "" when no number is found.
    """
    # Optional sign, integer part, and a fractional part only if digits follow
    # the dot (prevents capturing trailing punctuation such as "42.").
    number = r'-?\d+(?:\.\d+)?'

    # 1. Prefer the last $$...$$ block; DOTALL lets a block span several lines.
    latex_blocks = re.findall(r'\$\$(.*?)\$\$', text, re.DOTALL)
    if latex_blocks:
        numbers = re.findall(number, latex_blocks[-1])
        if numbers:
            return numbers[-1]
        # A LaTeX block without digits falls through to the next strategies.

    # 2. Number immediately following an "Answer:" marker.
    answer_match = re.search(r'Answer:\s*(' + number + r')', text, re.IGNORECASE)
    if answer_match:
        return answer_match.group(1)

    # 3. Fallback: the last number appearing anywhere in the text.
    numbers = re.findall(number, text)
    if numbers:
        return numbers[-1]

    return ""
|
| 87 |
+
|
| 88 |
+
|
| 89 |
def reward_func(completions, prompts=None, **kwargs):
|
| 90 |
"""
|
| 91 |
+
Reward function for arithmetic with improved extraction.
|
| 92 |
"""
|
| 93 |
# Try multiple column names for ground truth
|
| 94 |
answers = None
|
|
|
|
| 111 |
else:
|
| 112 |
text = str(completion)
|
| 113 |
|
| 114 |
+
# Extract answer using improved method
|
| 115 |
+
predicted = extract_answer(text)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
|
| 117 |
# Exact match reward
|
| 118 |
is_correct = predicted == str(truth).strip()
|