Spaces:
Sleeping
Sleeping
Commit Β·
973cd6f
1
Parent(s): f8319a8
push
Browse files- env/environment.py +188 -30
- env/generator.py +236 -14
- env/rewards.py +188 -48
- env/verifier.py +333 -66
- tests/test_env.py +222 -22
- train/colab_train.py +106 -28
- train/train_grpo.py +309 -103
env/environment.py
CHANGED
|
@@ -20,6 +20,23 @@ except ImportError:
|
|
| 20 |
logger = logging.getLogger(__name__)
|
| 21 |
|
| 22 |
class AutomathreasonerEnvironment(Environment):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
SUPPORTS_CONCURRENT_SESSIONS: bool = True
|
| 24 |
|
| 25 |
def __init__(self):
|
|
@@ -28,62 +45,180 @@ class AutomathreasonerEnvironment(Environment):
|
|
| 28 |
self.verifier = VerifierSystem()
|
| 29 |
self.reward_system = RewardSystem(max_len=2000)
|
| 30 |
|
| 31 |
-
# Curriculum tracking
|
| 32 |
-
self.difficulty_level =
|
| 33 |
-
self.rolling_results = deque(maxlen=
|
|
|
|
| 34 |
|
| 35 |
-
# Current problem state
|
| 36 |
self.current_problem = ""
|
| 37 |
self.current_solution = ""
|
| 38 |
-
self.current_sympy_f = None
|
|
|
|
|
|
|
|
|
|
| 39 |
self.times_seen_problem = 0
|
| 40 |
self.history: List[Dict[str, Any]] = []
|
| 41 |
-
self.max_steps = 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
def _update_curriculum(self):
|
| 44 |
-
"""
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
def reset(self) -> AutomathreasonerObservation:
|
| 54 |
-
"""Reset environment to a new problem."""
|
| 55 |
self._update_curriculum()
|
|
|
|
| 56 |
|
| 57 |
self._state = State(episode_id=str(uuid4()), step_count=0)
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
self.current_problem = task['problem']
|
| 61 |
self.current_solution = task['solution']
|
| 62 |
self.current_sympy_f = task.get('sympy_f')
|
| 63 |
-
|
|
|
|
|
|
|
| 64 |
self.times_seen_problem = 0
|
| 65 |
self.history = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
return AutomathreasonerObservation(
|
| 68 |
-
problem_text=
|
| 69 |
difficulty_level=self.difficulty_level,
|
| 70 |
history=[],
|
| 71 |
reward=0.0,
|
| 72 |
-
done=False
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
)
|
| 74 |
|
| 75 |
def step(self, action: AutomathreasonerAction) -> AutomathreasonerObservation: # type: ignore[override]
|
| 76 |
self._state.step_count += 1
|
| 77 |
|
| 78 |
-
# Verification
|
| 79 |
c, q, p_sup, r_ref = self.verifier.verify(
|
| 80 |
action.reasoning,
|
| 81 |
action.final_answer,
|
| 82 |
self.current_solution,
|
| 83 |
-
sympy_f=self.current_sympy_f
|
|
|
|
| 84 |
)
|
| 85 |
|
| 86 |
-
# Reward
|
| 87 |
action_str = f"{action.reasoning} \n {action.final_answer}"
|
| 88 |
total_r, components = self.reward_system.compute_reward(
|
| 89 |
correctness=c,
|
|
@@ -93,36 +228,59 @@ class AutomathreasonerEnvironment(Environment):
|
|
| 93 |
action_str=action_str,
|
| 94 |
final_answer=action.final_answer,
|
| 95 |
history=self.history,
|
| 96 |
-
times_seen_problem=self.times_seen_problem
|
|
|
|
| 97 |
)
|
| 98 |
|
| 99 |
self.times_seen_problem += 1
|
| 100 |
|
| 101 |
-
# Update history
|
| 102 |
attempt = {
|
| 103 |
"prediction": action.final_answer,
|
| 104 |
-
"
|
|
|
|
|
|
|
| 105 |
}
|
| 106 |
self.history.append(attempt)
|
| 107 |
-
# Keep only last 3 attempts for observation
|
| 108 |
obs_history = self.history[-3:]
|
| 109 |
|
| 110 |
-
|
|
|
|
| 111 |
done = is_correct or self._state.step_count >= self.max_steps
|
| 112 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
if done:
|
| 114 |
self.rolling_results.append(1 if is_correct else 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
return AutomathreasonerObservation(
|
| 117 |
-
problem_text=
|
| 118 |
difficulty_level=self.difficulty_level,
|
| 119 |
history=obs_history,
|
| 120 |
reward=total_r,
|
| 121 |
done=done,
|
| 122 |
metadata={
|
| 123 |
"reward_components": components,
|
| 124 |
-
"ground_truth": self.current_solution if done else "HIDDEN",
|
| 125 |
-
"is_correct": is_correct
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
}
|
| 127 |
)
|
| 128 |
|
|
|
|
| 20 |
logger = logging.getLogger(__name__)
|
| 21 |
|
| 22 |
class AutomathreasonerEnvironment(Environment):
|
| 23 |
+
"""
|
| 24 |
+
OpenEnv-compliant RL environment for symbolic calculus (indefinite integration).
|
| 25 |
+
|
| 26 |
+
Key improvements over v1:
|
| 27 |
+
1. Faster, smoother curriculum progression (Scaf-GRPO inspired)
|
| 28 |
+
2. Scaffold hints injected after repeated failures (breaks "learning cliff")
|
| 29 |
+
3. Increased max_steps (3 β 5) for more within-episode learning
|
| 30 |
+
4. Consecutive failure tracking for adaptive scaffolding
|
| 31 |
+
5. Technique-aware problem generation
|
| 32 |
+
6. Rolling accuracy uses weighted window for responsiveness
|
| 33 |
+
|
| 34 |
+
References:
|
| 35 |
+
- Scaf-GRPO (arxiv, 2025): hierarchical hints for hard problems
|
| 36 |
+
- GRPO-Ξ»: credit assignment for faster convergence
|
| 37 |
+
- arxiv:2408.10215: reward shaping best practices
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
SUPPORTS_CONCURRENT_SESSIONS: bool = True
|
| 41 |
|
| 42 |
def __init__(self):
|
|
|
|
| 45 |
self.verifier = VerifierSystem()
|
| 46 |
self.reward_system = RewardSystem(max_len=2000)
|
| 47 |
|
| 48 |
+
# --- Curriculum tracking (improved) ---
|
| 49 |
+
self.difficulty_level = 1.5 # Start slightly easier to build momentum
|
| 50 |
+
self.rolling_results = deque(maxlen=10) # Shorter window (was 20) β faster adaptation
|
| 51 |
+
self.rolling_rewards = deque(maxlen=10) # Track reward magnitudes too
|
| 52 |
|
| 53 |
+
# --- Current problem state ---
|
| 54 |
self.current_problem = ""
|
| 55 |
self.current_solution = ""
|
| 56 |
+
self.current_sympy_f = None # Integration ground truth (integrand)
|
| 57 |
+
self.current_sympy_F = None # Antiderivative (for structural comparison)
|
| 58 |
+
self.current_technique = "" # Detected integration technique
|
| 59 |
+
self.current_scaffold_hints = {} # Progressive hints
|
| 60 |
self.times_seen_problem = 0
|
| 61 |
self.history: List[Dict[str, Any]] = []
|
| 62 |
+
self.max_steps = 5 # Increased from 3 β more within-episode learning
|
| 63 |
+
|
| 64 |
+
# --- Failure tracking for scaffolding ---
|
| 65 |
+
self.consecutive_failures = 0
|
| 66 |
+
self.total_episodes = 0
|
| 67 |
+
self.total_correct = 0
|
| 68 |
+
|
| 69 |
+
# --- Technique performance tracking ---
|
| 70 |
+
self.technique_performance: Dict[str, List[float]] = {}
|
| 71 |
|
| 72 |
def _update_curriculum(self):
|
| 73 |
+
"""
|
| 74 |
+
Update difficulty based on rolling accuracy.
|
| 75 |
+
|
| 76 |
+
Improved:
|
| 77 |
+
- Shorter rolling window (10 vs 20) for faster response
|
| 78 |
+
- Smoother progression: advance proportional to accuracy
|
| 79 |
+
- Lower thresholds to maintain momentum
|
| 80 |
+
- Technique-aware adaptation
|
| 81 |
+
"""
|
| 82 |
+
if len(self.rolling_results) < 3:
|
| 83 |
+
return
|
| 84 |
+
|
| 85 |
+
accuracy = sum(self.rolling_results) / len(self.rolling_results)
|
| 86 |
+
avg_reward = sum(self.rolling_rewards) / len(self.rolling_rewards) if self.rolling_rewards else 0
|
| 87 |
+
|
| 88 |
+
# Advance: accuracy > 0.50 (was 0.7)
|
| 89 |
+
if accuracy > 0.50:
|
| 90 |
+
# Proportional advancement β faster when doing well
|
| 91 |
+
advance = 0.2 + 0.3 * accuracy # Range: 0.35 to 0.5
|
| 92 |
+
self.difficulty_level += advance
|
| 93 |
+
logger.info(f"π Curriculum UP: Accuracy={accuracy:.2f}, "
|
| 94 |
+
f"AvgReward={avg_reward:.3f}, NewDiff={self.difficulty_level:.1f}")
|
| 95 |
+
|
| 96 |
+
# Partial advance: decent reward signal even without full correctness
|
| 97 |
+
elif avg_reward > 0.35 and accuracy > 0.25:
|
| 98 |
+
self.difficulty_level += 0.1
|
| 99 |
+
logger.info(f"π Curriculum MICRO-UP: Accuracy={accuracy:.2f}, "
|
| 100 |
+
f"AvgReward={avg_reward:.3f}, NewDiff={self.difficulty_level:.1f}")
|
| 101 |
+
|
| 102 |
+
# Retreat: accuracy < 0.20 (was 0.6)
|
| 103 |
+
elif accuracy < 0.20:
|
| 104 |
+
self.difficulty_level = max(1.0, self.difficulty_level - 0.3)
|
| 105 |
+
logger.info(f"π Curriculum DOWN: Accuracy={accuracy:.2f}, "
|
| 106 |
+
f"NewDiff={self.difficulty_level:.1f}")
|
| 107 |
+
|
| 108 |
+
def _get_scaffold_observation(self) -> str:
|
| 109 |
+
"""
|
| 110 |
+
Generate scaffold hint based on consecutive failures.
|
| 111 |
+
Implements Scaf-GRPO progressive hint injection.
|
| 112 |
+
|
| 113 |
+
- 0-1 failures: no hint
|
| 114 |
+
- 2 failures: technique hint (level 1)
|
| 115 |
+
- 3 failures: first step hint (level 2)
|
| 116 |
+
- 4+ failures: detailed hint (level 3)
|
| 117 |
+
"""
|
| 118 |
+
if self.consecutive_failures < 2 or not self.current_scaffold_hints:
|
| 119 |
+
return ""
|
| 120 |
+
|
| 121 |
+
if self.consecutive_failures == 2:
|
| 122 |
+
hint = self.current_scaffold_hints.get('hint_level_1', '')
|
| 123 |
+
if hint:
|
| 124 |
+
return f"\n[Hint: {hint}]"
|
| 125 |
+
|
| 126 |
+
elif self.consecutive_failures == 3:
|
| 127 |
+
hint = self.current_scaffold_hints.get('hint_level_2', '')
|
| 128 |
+
if hint:
|
| 129 |
+
return f"\n[Hint: {hint}]"
|
| 130 |
+
|
| 131 |
+
else: # 4+
|
| 132 |
+
hint = self.current_scaffold_hints.get('hint_level_3', '')
|
| 133 |
+
if hint:
|
| 134 |
+
return f"\n[Strong Hint: {hint}]"
|
| 135 |
+
|
| 136 |
+
return ""
|
| 137 |
+
|
| 138 |
+
def _update_technique_performance(self, technique: str, correct: bool):
|
| 139 |
+
"""Track per-technique performance for adaptive curriculum."""
|
| 140 |
+
if technique not in self.technique_performance:
|
| 141 |
+
self.technique_performance[technique] = []
|
| 142 |
+
|
| 143 |
+
self.technique_performance[technique].append(1.0 if correct else 0.0)
|
| 144 |
+
|
| 145 |
+
# Keep last 20 results per technique
|
| 146 |
+
if len(self.technique_performance[technique]) > 20:
|
| 147 |
+
self.technique_performance[technique] = self.technique_performance[technique][-20:]
|
| 148 |
+
|
| 149 |
+
def _get_weakest_technique(self) -> str:
|
| 150 |
+
"""Find the technique the model struggles with most."""
|
| 151 |
+
worst_technique = ""
|
| 152 |
+
worst_accuracy = 1.0
|
| 153 |
+
|
| 154 |
+
for technique, results in self.technique_performance.items():
|
| 155 |
+
if len(results) >= 3:
|
| 156 |
+
acc = sum(results) / len(results)
|
| 157 |
+
if acc < worst_accuracy:
|
| 158 |
+
worst_accuracy = acc
|
| 159 |
+
worst_technique = technique
|
| 160 |
+
|
| 161 |
+
return worst_technique
|
| 162 |
|
| 163 |
def reset(self) -> AutomathreasonerObservation:
|
| 164 |
+
"""Reset environment to a new problem with scaffold support."""
|
| 165 |
self._update_curriculum()
|
| 166 |
+
self.total_episodes += 1
|
| 167 |
|
| 168 |
self._state = State(episode_id=str(uuid4()), step_count=0)
|
| 169 |
+
|
| 170 |
+
# Occasionally target the weakest technique (20% of the time)
|
| 171 |
+
import random
|
| 172 |
+
weakest = self._get_weakest_technique()
|
| 173 |
+
if weakest and random.random() < 0.2 and self.total_episodes > 10:
|
| 174 |
+
task = self.generator.generate_technique_focused_task(
|
| 175 |
+
weakest, difficulty=max(1.0, self.difficulty_level - 0.5)
|
| 176 |
+
)
|
| 177 |
+
logger.info(f"π― Targeting weak technique: {weakest}")
|
| 178 |
+
else:
|
| 179 |
+
task = self.generator.generate_task(target_difficulty_band=self.difficulty_level)
|
| 180 |
|
| 181 |
self.current_problem = task['problem']
|
| 182 |
self.current_solution = task['solution']
|
| 183 |
self.current_sympy_f = task.get('sympy_f')
|
| 184 |
+
self.current_sympy_F = task.get('sympy_F')
|
| 185 |
+
self.current_technique = task.get('technique', '')
|
| 186 |
+
self.current_scaffold_hints = task.get('scaffold_hints', {})
|
| 187 |
self.times_seen_problem = 0
|
| 188 |
self.history = []
|
| 189 |
+
self.consecutive_failures = 0
|
| 190 |
+
|
| 191 |
+
# Build problem text with optional scaffold hint
|
| 192 |
+
problem_text = self.current_problem
|
| 193 |
+
scaffold = self._get_scaffold_observation()
|
| 194 |
+
if scaffold:
|
| 195 |
+
problem_text += scaffold
|
| 196 |
|
| 197 |
return AutomathreasonerObservation(
|
| 198 |
+
problem_text=problem_text,
|
| 199 |
difficulty_level=self.difficulty_level,
|
| 200 |
history=[],
|
| 201 |
reward=0.0,
|
| 202 |
+
done=False,
|
| 203 |
+
metadata={
|
| 204 |
+
"technique": self.current_technique,
|
| 205 |
+
"episode_number": self.total_episodes,
|
| 206 |
+
}
|
| 207 |
)
|
| 208 |
|
| 209 |
def step(self, action: AutomathreasonerAction) -> AutomathreasonerObservation: # type: ignore[override]
|
| 210 |
self._state.step_count += 1
|
| 211 |
|
| 212 |
+
# Verification with graduated correctness and technique awareness
|
| 213 |
c, q, p_sup, r_ref = self.verifier.verify(
|
| 214 |
action.reasoning,
|
| 215 |
action.final_answer,
|
| 216 |
self.current_solution,
|
| 217 |
+
sympy_f=self.current_sympy_f,
|
| 218 |
+
technique_hint=self.current_technique,
|
| 219 |
)
|
| 220 |
|
| 221 |
+
# Reward computation β all 7 components + format compliance
|
| 222 |
action_str = f"{action.reasoning} \n {action.final_answer}"
|
| 223 |
total_r, components = self.reward_system.compute_reward(
|
| 224 |
correctness=c,
|
|
|
|
| 228 |
action_str=action_str,
|
| 229 |
final_answer=action.final_answer,
|
| 230 |
history=self.history,
|
| 231 |
+
times_seen_problem=self.times_seen_problem,
|
| 232 |
+
reasoning=action.reasoning,
|
| 233 |
)
|
| 234 |
|
| 235 |
self.times_seen_problem += 1
|
| 236 |
|
| 237 |
+
# Update history β store BOTH keys for backward compatibility
|
| 238 |
attempt = {
|
| 239 |
"prediction": action.final_answer,
|
| 240 |
+
"final_answer": action.final_answer, # BUGFIX: also store as final_answer
|
| 241 |
+
"correctness": c,
|
| 242 |
+
"reward": total_r,
|
| 243 |
}
|
| 244 |
self.history.append(attempt)
|
|
|
|
| 245 |
obs_history = self.history[-3:]
|
| 246 |
|
| 247 |
+
# Correctness check β graduated (threshold at 0.7 for "correct enough")
|
| 248 |
+
is_correct = (c >= 0.7)
|
| 249 |
done = is_correct or self._state.step_count >= self.max_steps
|
| 250 |
|
| 251 |
+
if is_correct:
|
| 252 |
+
self.consecutive_failures = 0
|
| 253 |
+
self.total_correct += 1
|
| 254 |
+
else:
|
| 255 |
+
self.consecutive_failures += 1
|
| 256 |
+
|
| 257 |
if done:
|
| 258 |
self.rolling_results.append(1 if is_correct else 0)
|
| 259 |
+
self.rolling_rewards.append(total_r)
|
| 260 |
+
self._update_technique_performance(self.current_technique, is_correct)
|
| 261 |
+
|
| 262 |
+
# Build problem text with scaffold hints for next attempt (if not done)
|
| 263 |
+
problem_text = self.current_problem
|
| 264 |
+
if not done:
|
| 265 |
+
scaffold = self._get_scaffold_observation()
|
| 266 |
+
if scaffold:
|
| 267 |
+
problem_text += scaffold
|
| 268 |
|
| 269 |
return AutomathreasonerObservation(
|
| 270 |
+
problem_text=problem_text,
|
| 271 |
difficulty_level=self.difficulty_level,
|
| 272 |
history=obs_history,
|
| 273 |
reward=total_r,
|
| 274 |
done=done,
|
| 275 |
metadata={
|
| 276 |
"reward_components": components,
|
| 277 |
+
"ground_truth": self.current_solution if done else "HIDDEN",
|
| 278 |
+
"is_correct": is_correct,
|
| 279 |
+
"technique": self.current_technique,
|
| 280 |
+
"consecutive_failures": self.consecutive_failures,
|
| 281 |
+
"correctness_score": c,
|
| 282 |
+
"curriculum_difficulty": self.difficulty_level,
|
| 283 |
+
"episode_number": self.total_episodes,
|
| 284 |
}
|
| 285 |
)
|
| 286 |
|
env/generator.py
CHANGED
|
@@ -1,45 +1,201 @@
|
|
| 1 |
import sympy as sp
|
| 2 |
import random
|
| 3 |
-
from typing import Dict, Any, Tuple
|
| 4 |
|
| 5 |
class TaskGenerationEngine:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
def __init__(self):
|
| 7 |
self.x = sp.Symbol('x')
|
|
|
|
| 8 |
# Components for generating random functions F(x)
|
| 9 |
self.basic_functions = [
|
| 10 |
lambda x, c: x**c,
|
| 11 |
lambda x, c: sp.sin(c*x),
|
| 12 |
lambda x, c: sp.cos(c*x),
|
| 13 |
lambda x, c: sp.exp(c*x),
|
| 14 |
-
lambda x, c: sp.ln(sp.Abs(c*x))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
def _score_difficulty(self, components: int, nesting: int) -> float:
|
| 18 |
"""D = num_components + degree_of_nesting * 2"""
|
| 19 |
return float(components + nesting * 2.0)
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
def generate_random_function(self, complexity: int) -> Tuple[Any, float]:
|
| 22 |
-
"""Generates a random F(x)."""
|
| 23 |
num_components = max(1, int(complexity / 2))
|
| 24 |
nesting = max(0, int(complexity / 4))
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
f_expr = 0
|
| 27 |
for _ in range(num_components):
|
| 28 |
-
comp_func = random.choice(
|
| 29 |
coeff = random.randint(1, 5)
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
# Apply nesting
|
| 33 |
for _ in range(nesting):
|
| 34 |
outer = random.choice(self.basic_functions)
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
f_expr += random.randint(1, 10) * term
|
| 38 |
|
| 39 |
return f_expr, self._score_difficulty(num_components, nesting)
|
| 40 |
|
| 41 |
def generate_task(self, target_difficulty_band: float) -> Dict[str, Any]:
|
| 42 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
complexity = max(1, int(target_difficulty_band))
|
| 44 |
|
| 45 |
# 1. Generate F(x)
|
|
@@ -48,8 +204,17 @@ class TaskGenerationEngine:
|
|
| 48 |
# 2. Differentiate to get the problem f(x)
|
| 49 |
f_expr = sp.diff(F_expr, self.x)
|
| 50 |
|
| 51 |
-
# 3.
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
solution_text = f"{sp.simplify(F_expr)} + C"
|
| 54 |
|
| 55 |
return {
|
|
@@ -58,13 +223,17 @@ class TaskGenerationEngine:
|
|
| 58 |
"solution": solution_text,
|
| 59 |
"type": "integration",
|
| 60 |
"sympy_F": F_expr,
|
| 61 |
-
"sympy_f": f_expr
|
|
|
|
|
|
|
| 62 |
}
|
| 63 |
|
| 64 |
-
def generate_variants(self, task: Dict[str, Any], count: int = 2) ->
|
| 65 |
"""
|
| 66 |
LADDER Component: Recursive Decomposition for Integration.
|
| 67 |
Breaks down sums or simplifies coefficients.
|
|
|
|
|
|
|
| 68 |
"""
|
| 69 |
variants = []
|
| 70 |
F_expr = task.get("sympy_F")
|
|
@@ -79,13 +248,23 @@ class TaskGenerationEngine:
|
|
| 79 |
for arg in args[:count]:
|
| 80 |
sub_F = arg
|
| 81 |
sub_f = sp.diff(sub_F, self.x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
variants.append({
|
| 83 |
-
"problem": f"Integrate step-variant:
|
| 84 |
"solution": f"{sub_F} + C",
|
| 85 |
-
"difficulty": task["difficulty"] - 1.0,
|
| 86 |
"type": "integration",
|
| 87 |
"sympy_F": sub_F,
|
| 88 |
-
"sympy_f": sub_f
|
|
|
|
|
|
|
| 89 |
})
|
| 90 |
|
| 91 |
# Recursive Rule 2: Constant simplification
|
|
@@ -94,3 +273,46 @@ class TaskGenerationEngine:
|
|
| 94 |
variants.append(self.generate_task(max(1.0, task["difficulty"] - 2.0)))
|
| 95 |
|
| 96 |
return variants[:count]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import sympy as sp
|
| 2 |
import random
|
| 3 |
+
from typing import Dict, Any, Tuple, List, Optional
|
| 4 |
|
| 5 |
class TaskGenerationEngine:
|
| 6 |
+
"""
|
| 7 |
+
Symbolic calculus task generator with scaffold hints and technique metadata.
|
| 8 |
+
|
| 9 |
+
Improvements over v1:
|
| 10 |
+
1. Stores which integration technique is needed (u-sub, by-parts, etc.)
|
| 11 |
+
2. Generates scaffold hints (first step of solution) for Scaf-GRPO
|
| 12 |
+
3. Better prompt formatting using LaTeX-style notation
|
| 13 |
+
4. More diverse function compositions
|
| 14 |
+
5. Technique-aware variant generation
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
def __init__(self):
|
| 18 |
self.x = sp.Symbol('x')
|
| 19 |
+
|
| 20 |
# Components for generating random functions F(x)
|
| 21 |
self.basic_functions = [
|
| 22 |
lambda x, c: x**c,
|
| 23 |
lambda x, c: sp.sin(c*x),
|
| 24 |
lambda x, c: sp.cos(c*x),
|
| 25 |
lambda x, c: sp.exp(c*x),
|
| 26 |
+
lambda x, c: sp.ln(sp.Abs(c*x + 1)), # +1 avoids log(0)
|
| 27 |
+
]
|
| 28 |
+
|
| 29 |
+
# Additional functions for higher difficulty
|
| 30 |
+
self.advanced_functions = [
|
| 31 |
+
lambda x, c: sp.tan(c*x),
|
| 32 |
+
lambda x, c: sp.atan(c*x),
|
| 33 |
+
lambda x, c: sp.sinh(c*x),
|
| 34 |
+
lambda x, c: sp.cosh(c*x),
|
| 35 |
+
lambda x, c: x**c * sp.exp(x), # Requires integration by parts
|
| 36 |
+
lambda x, c: sp.sin(x) * sp.cos(c*x), # Product of trig
|
| 37 |
]
|
| 38 |
+
|
| 39 |
+
# Technique detection patterns
|
| 40 |
+
self._technique_detectors = {
|
| 41 |
+
'power_rule': self._is_power_rule,
|
| 42 |
+
'u_substitution': self._is_u_substitution,
|
| 43 |
+
'by_parts': self._is_by_parts,
|
| 44 |
+
'trigonometric': self._is_trig_integral,
|
| 45 |
+
'exponential': self._is_exponential,
|
| 46 |
+
'logarithmic': self._is_logarithmic,
|
| 47 |
+
}
|
| 48 |
|
| 49 |
def _score_difficulty(self, components: int, nesting: int) -> float:
|
| 50 |
"""D = num_components + degree_of_nesting * 2"""
|
| 51 |
return float(components + nesting * 2.0)
|
| 52 |
|
| 53 |
+
def _detect_technique(self, f_expr) -> str:
|
| 54 |
+
"""Detect which integration technique is most appropriate for f(x)."""
|
| 55 |
+
for technique, detector in self._technique_detectors.items():
|
| 56 |
+
if detector(f_expr):
|
| 57 |
+
return technique
|
| 58 |
+
return 'power_rule' # Default fallback
|
| 59 |
+
|
| 60 |
+
def _is_power_rule(self, expr) -> bool:
|
| 61 |
+
"""Check if expression is a simple polynomial."""
|
| 62 |
+
return expr.is_polynomial(self.x)
|
| 63 |
+
|
| 64 |
+
def _is_u_substitution(self, expr) -> bool:
|
| 65 |
+
"""Check if expression likely needs u-substitution."""
|
| 66 |
+
# Composition of functions suggests u-sub
|
| 67 |
+
if isinstance(expr, sp.Mul):
|
| 68 |
+
args = expr.args
|
| 69 |
+
# Look for f(g(x)) * g'(x) pattern
|
| 70 |
+
for arg in args:
|
| 71 |
+
if arg.has(sp.sin, sp.cos, sp.exp, sp.log) and not arg.is_polynomial(self.x):
|
| 72 |
+
return True
|
| 73 |
+
return False
|
| 74 |
+
|
| 75 |
+
def _is_by_parts(self, expr) -> bool:
|
| 76 |
+
"""Check if expression likely needs integration by parts."""
|
| 77 |
+
if isinstance(expr, sp.Mul):
|
| 78 |
+
has_poly = any(a.is_polynomial(self.x) for a in expr.args)
|
| 79 |
+
has_transcendental = any(a.has(sp.sin, sp.cos, sp.exp, sp.log) for a in expr.args)
|
| 80 |
+
return has_poly and has_transcendental
|
| 81 |
+
return False
|
| 82 |
+
|
| 83 |
+
def _is_trig_integral(self, expr) -> bool:
|
| 84 |
+
"""Check if expression is primarily trigonometric."""
|
| 85 |
+
return expr.has(sp.sin, sp.cos, sp.tan) and not expr.has(sp.exp, sp.log)
|
| 86 |
+
|
| 87 |
+
def _is_exponential(self, expr) -> bool:
|
| 88 |
+
"""Check if expression is primarily exponential."""
|
| 89 |
+
return expr.has(sp.exp) and not expr.has(sp.sin, sp.cos)
|
| 90 |
+
|
| 91 |
+
def _is_logarithmic(self, expr) -> bool:
|
| 92 |
+
"""Check if expression involves logarithms."""
|
| 93 |
+
return expr.has(sp.log, sp.ln)
|
| 94 |
+
|
| 95 |
+
def _generate_scaffold_hint(self, f_expr, F_expr, technique: str) -> Dict[str, str]:
|
| 96 |
+
"""
|
| 97 |
+
Generate a scaffold hint for the problem.
|
| 98 |
+
|
| 99 |
+
Returns a dict with:
|
| 100 |
+
- 'technique': which technique to use
|
| 101 |
+
- 'hint_level_1': gentle nudge (technique name)
|
| 102 |
+
- 'hint_level_2': first step of solution
|
| 103 |
+
- 'hint_level_3': most of the solution
|
| 104 |
+
"""
|
| 105 |
+
hints = {
|
| 106 |
+
'technique': technique,
|
| 107 |
+
'hint_level_1': '',
|
| 108 |
+
'hint_level_2': '',
|
| 109 |
+
'hint_level_3': '',
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
technique_descriptions = {
|
| 113 |
+
'power_rule': "Try applying the power rule: β«x^n dx = x^(n+1)/(n+1) + C",
|
| 114 |
+
'u_substitution': "Try u-substitution. Look for a composite function and its derivative.",
|
| 115 |
+
'by_parts': "Try integration by parts: β«u dv = uv - β«v du",
|
| 116 |
+
'trigonometric': "Try using trigonometric identities to simplify first.",
|
| 117 |
+
'exponential': "Remember that β«e^(ax) dx = (1/a)e^(ax) + C",
|
| 118 |
+
'logarithmic': "Remember that β«(1/x) dx = ln|x| + C",
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
hints['hint_level_1'] = technique_descriptions.get(
|
| 122 |
+
technique, "Try identifying the integration technique needed."
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# Level 2: Show the substitution or setup
|
| 126 |
+
try:
|
| 127 |
+
if technique == 'u_substitution':
|
| 128 |
+
# Try to identify the inner function for u-sub hint
|
| 129 |
+
hints['hint_level_2'] = f"Hint: Try {hints['hint_level_1']}. The integrand has a composite structure."
|
| 130 |
+
elif technique == 'by_parts':
|
| 131 |
+
hints['hint_level_2'] = f"Hint: {hints['hint_level_1']}. Identify which part to differentiate (u) and which to integrate (dv)."
|
| 132 |
+
else:
|
| 133 |
+
hints['hint_level_2'] = f"Hint: {hints['hint_level_1']}"
|
| 134 |
+
except Exception:
|
| 135 |
+
hints['hint_level_2'] = hints['hint_level_1']
|
| 136 |
+
|
| 137 |
+
# Level 3: Show the first term of the answer
|
| 138 |
+
try:
|
| 139 |
+
simplified = sp.simplify(F_expr)
|
| 140 |
+
if isinstance(simplified, sp.Add):
|
| 141 |
+
first_term = simplified.args[0]
|
| 142 |
+
hints['hint_level_3'] = f"The answer starts with: {sp.pretty(first_term)} + ..."
|
| 143 |
+
else:
|
| 144 |
+
hints['hint_level_3'] = f"The answer has the form: {type(simplified).__name__} expression"
|
| 145 |
+
except Exception:
|
| 146 |
+
hints['hint_level_3'] = hints['hint_level_2']
|
| 147 |
+
|
| 148 |
+
return hints
|
| 149 |
+
|
| 150 |
def generate_random_function(self, complexity: int) -> Tuple[Any, float]:
|
| 151 |
+
"""Generates a random F(x) with appropriate complexity."""
|
| 152 |
num_components = max(1, int(complexity / 2))
|
| 153 |
nesting = max(0, int(complexity / 4))
|
| 154 |
|
| 155 |
+
# Use advanced functions at higher complexity
|
| 156 |
+
available_funcs = list(self.basic_functions)
|
| 157 |
+
if complexity >= 4:
|
| 158 |
+
available_funcs.extend(self.advanced_functions[:3])
|
| 159 |
+
if complexity >= 6:
|
| 160 |
+
available_funcs.extend(self.advanced_functions[3:])
|
| 161 |
+
|
| 162 |
f_expr = 0
|
| 163 |
for _ in range(num_components):
|
| 164 |
+
comp_func = random.choice(available_funcs)
|
| 165 |
coeff = random.randint(1, 5)
|
| 166 |
+
|
| 167 |
+
try:
|
| 168 |
+
term = comp_func(self.x, coeff)
|
| 169 |
+
except Exception:
|
| 170 |
+
# Fallback to simple polynomial
|
| 171 |
+
term = self.x ** coeff
|
| 172 |
|
| 173 |
# Apply nesting
|
| 174 |
for _ in range(nesting):
|
| 175 |
outer = random.choice(self.basic_functions)
|
| 176 |
+
try:
|
| 177 |
+
term = outer(term, 1)
|
| 178 |
+
except Exception:
|
| 179 |
+
break
|
| 180 |
|
| 181 |
f_expr += random.randint(1, 10) * term
|
| 182 |
|
| 183 |
return f_expr, self._score_difficulty(num_components, nesting)
|
| 184 |
|
| 185 |
def generate_task(self, target_difficulty_band: float) -> Dict[str, Any]:
|
| 186 |
+
"""
|
| 187 |
+
Provides an indefinite integral task with technique hints and scaffold support.
|
| 188 |
+
|
| 189 |
+
Returns dict with:
|
| 190 |
+
- problem: formatted problem text
|
| 191 |
+
- solution: ground truth solution string
|
| 192 |
+
- difficulty: computed difficulty score
|
| 193 |
+
- type: 'integration'
|
| 194 |
+
- sympy_F: SymPy expression for F(x) (antiderivative)
|
| 195 |
+
- sympy_f: SymPy expression for f(x) (integrand)
|
| 196 |
+
- technique: detected integration technique
|
| 197 |
+
- scaffold_hints: dict of progressive hints
|
| 198 |
+
"""
|
| 199 |
complexity = max(1, int(target_difficulty_band))
|
| 200 |
|
| 201 |
# 1. Generate F(x)
|
|
|
|
| 204 |
# 2. Differentiate to get the problem f(x)
|
| 205 |
f_expr = sp.diff(F_expr, self.x)
|
| 206 |
|
| 207 |
+
# 3. Detect technique and generate hints
|
| 208 |
+
technique = self._detect_technique(f_expr)
|
| 209 |
+
scaffold_hints = self._generate_scaffold_hint(f_expr, F_expr, technique)
|
| 210 |
+
|
| 211 |
+
# 4. Format strings β use cleaner formatting for LLM consumption
|
| 212 |
+
try:
|
| 213 |
+
pretty_f = sp.pretty(f_expr, use_unicode=True)
|
| 214 |
+
except Exception:
|
| 215 |
+
pretty_f = str(f_expr)
|
| 216 |
+
|
| 217 |
+
problem_text = f"Find the indefinite integral: β« ({pretty_f}) dx"
|
| 218 |
solution_text = f"{sp.simplify(F_expr)} + C"
|
| 219 |
|
| 220 |
return {
|
|
|
|
| 223 |
"solution": solution_text,
|
| 224 |
"type": "integration",
|
| 225 |
"sympy_F": F_expr,
|
| 226 |
+
"sympy_f": f_expr,
|
| 227 |
+
"technique": technique,
|
| 228 |
+
"scaffold_hints": scaffold_hints,
|
| 229 |
}
|
| 230 |
|
| 231 |
+
def generate_variants(self, task: Dict[str, Any], count: int = 2) -> List[Dict[str, Any]]:
|
| 232 |
"""
|
| 233 |
LADDER Component: Recursive Decomposition for Integration.
|
| 234 |
Breaks down sums or simplifies coefficients.
|
| 235 |
+
|
| 236 |
+
Improved: preserves technique hints and scaffold data through decomposition.
|
| 237 |
"""
|
| 238 |
variants = []
|
| 239 |
F_expr = task.get("sympy_F")
|
|
|
|
| 248 |
for arg in args[:count]:
|
| 249 |
sub_F = arg
|
| 250 |
sub_f = sp.diff(sub_F, self.x)
|
| 251 |
+
technique = self._detect_technique(sub_f)
|
| 252 |
+
scaffold = self._generate_scaffold_hint(sub_f, sub_F, technique)
|
| 253 |
+
|
| 254 |
+
try:
|
| 255 |
+
pretty_sub_f = sp.pretty(sub_f, use_unicode=True)
|
| 256 |
+
except Exception:
|
| 257 |
+
pretty_sub_f = str(sub_f)
|
| 258 |
+
|
| 259 |
variants.append({
|
| 260 |
+
"problem": f"Integrate step-variant: β« ({pretty_sub_f}) dx",
|
| 261 |
"solution": f"{sub_F} + C",
|
| 262 |
+
"difficulty": max(0.5, task["difficulty"] - 1.0),
|
| 263 |
"type": "integration",
|
| 264 |
"sympy_F": sub_F,
|
| 265 |
+
"sympy_f": sub_f,
|
| 266 |
+
"technique": technique,
|
| 267 |
+
"scaffold_hints": scaffold,
|
| 268 |
})
|
| 269 |
|
| 270 |
# Recursive Rule 2: Constant simplification
|
|
|
|
| 273 |
variants.append(self.generate_task(max(1.0, task["difficulty"] - 2.0)))
|
| 274 |
|
| 275 |
return variants[:count]
|
| 276 |
+
|
| 277 |
+
def generate_technique_focused_task(self, technique: str, difficulty: float = 2.0) -> Dict[str, Any]:
|
| 278 |
+
"""
|
| 279 |
+
Generate a task that specifically targets a given integration technique.
|
| 280 |
+
Useful for curriculum learning when the model struggles with a technique.
|
| 281 |
+
"""
|
| 282 |
+
x = self.x
|
| 283 |
+
|
| 284 |
+
technique_generators = {
|
| 285 |
+
'power_rule': lambda: random.randint(1, 5) * x**random.randint(1, 6),
|
| 286 |
+
'u_substitution': lambda: sp.sin(random.randint(1, 3) * x**2) * x,
|
| 287 |
+
'by_parts': lambda: x * sp.exp(random.randint(1, 3) * x),
|
| 288 |
+
'trigonometric': lambda: sp.sin(x)**random.randint(1, 3) * sp.cos(x),
|
| 289 |
+
'exponential': lambda: random.randint(1, 5) * sp.exp(random.randint(1, 4) * x),
|
| 290 |
+
'logarithmic': lambda: sp.ln(sp.Abs(x + 1)),
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
generator = technique_generators.get(technique)
|
| 294 |
+
if generator is None:
|
| 295 |
+
return self.generate_task(difficulty)
|
| 296 |
+
|
| 297 |
+
try:
|
| 298 |
+
F_expr = generator()
|
| 299 |
+
f_expr = sp.diff(F_expr, x)
|
| 300 |
+
scaffold = self._generate_scaffold_hint(f_expr, F_expr, technique)
|
| 301 |
+
|
| 302 |
+
try:
|
| 303 |
+
pretty_f = sp.pretty(f_expr, use_unicode=True)
|
| 304 |
+
except Exception:
|
| 305 |
+
pretty_f = str(f_expr)
|
| 306 |
+
|
| 307 |
+
return {
|
| 308 |
+
"problem": f"Find the indefinite integral: β« ({pretty_f}) dx",
|
| 309 |
+
"solution": f"{sp.simplify(F_expr)} + C",
|
| 310 |
+
"difficulty": difficulty,
|
| 311 |
+
"type": "integration",
|
| 312 |
+
"sympy_F": F_expr,
|
| 313 |
+
"sympy_f": f_expr,
|
| 314 |
+
"technique": technique,
|
| 315 |
+
"scaffold_hints": scaffold,
|
| 316 |
+
}
|
| 317 |
+
except Exception:
|
| 318 |
+
return self.generate_task(difficulty)
|
env/rewards.py
CHANGED
|
@@ -1,62 +1,163 @@
|
|
| 1 |
-
import random
|
| 2 |
import math
|
| 3 |
from typing import Dict, Any, List, Tuple
|
| 4 |
|
| 5 |
class RewardSystem:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
def __init__(self, max_len: int = 1000):
|
| 7 |
self.max_len = max_len
|
| 8 |
|
| 9 |
def compute_diversity(self, current_answer: str, history: List[Dict[str, Any]]) -> float:
|
| 10 |
"""
|
| 11 |
-
D = diversity (difference from past attempts)
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
| 14 |
"""
|
| 15 |
if not history:
|
| 16 |
return 1.0
|
| 17 |
|
| 18 |
cur_ans_clean = current_answer.strip().lower()
|
| 19 |
|
|
|
|
|
|
|
|
|
|
| 20 |
for attempt in history:
|
| 21 |
-
|
|
|
|
| 22 |
if prev_ans == cur_ans_clean:
|
| 23 |
-
return -
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
-
# If unique, give full diversity bonus
|
| 26 |
return 1.0
|
| 27 |
|
| 28 |
def compute_efficiency(self, action_string: str) -> float:
|
| 29 |
"""
|
| 30 |
-
E = efficiency.
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
| 33 |
"""
|
| 34 |
approx_tokens = len(action_string) / 4.0
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
-
#
|
| 38 |
-
|
|
|
|
| 39 |
|
| 40 |
-
|
| 41 |
-
e = math.exp(- (ratio ** 2)) - 1.0
|
| 42 |
-
return e
|
| 43 |
|
| 44 |
def compute_exploration_bonus(self, action_string: str, times_seen: int) -> float:
|
| 45 |
"""
|
| 46 |
[PAPER TRACEABILITY: Exploration via Entropy Bonus]
|
| 47 |
G. EXPLORATION VIA ENTROPY BONUS
|
| 48 |
-
|
| 49 |
X = (entropy_bonus) / sqrt(1 + times_seen_problem)
|
|
|
|
|
|
|
| 50 |
"""
|
| 51 |
-
# Simple structural entropy estimation (unique character distribution variance)
|
| 52 |
length = len(action_string)
|
| 53 |
-
if length
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
else:
|
| 57 |
-
|
|
|
|
|
|
|
| 58 |
|
| 59 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
def detect_trivial_output(self, action_string: str) -> bool:
|
| 62 |
"""Anti-reward hacking: detect trivial constant outputs"""
|
|
@@ -66,6 +167,13 @@ class RewardSystem:
|
|
| 66 |
unique_chars = len(set(action_string))
|
| 67 |
if unique_chars < 3 and len(action_string) > 10:
|
| 68 |
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
return False
|
| 70 |
|
| 71 |
def compute_reward(self,
|
|
@@ -76,50 +184,82 @@ class RewardSystem:
|
|
| 76 |
action_str: str,
|
| 77 |
final_answer: str,
|
| 78 |
history: List[Dict[str, Any]],
|
| 79 |
-
times_seen_problem: int
|
|
|
|
| 80 |
"""
|
| 81 |
-
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
"""
|
| 84 |
if self.detect_trivial_output(action_str):
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
q = reasoning_quality
|
| 91 |
d = self.compute_diversity(final_answer, history)
|
| 92 |
-
|
| 93 |
-
# If repeated answer, C is zeroed to prevent hacking
|
| 94 |
-
if d < 0:
|
| 95 |
-
c = 0.0
|
| 96 |
-
|
| 97 |
e = self.compute_efficiency(action_str)
|
| 98 |
x = self.compute_exploration_bonus(action_str, times_seen_problem)
|
|
|
|
| 99 |
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
-
#
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
-
#
|
| 106 |
-
|
| 107 |
-
r_norm = (reflection_score + 0.5) / 1.5 # Scales [-0.5, 1.0] to [0, 1]
|
| 108 |
-
q_norm = min(1.0, max(0.0, q_smooth))
|
| 109 |
|
| 110 |
-
# New Simplified Composite Reward Equation (Strictly bounded [0, 1])
|
| 111 |
-
# Base coefficients sum exactly to 1.0. Noise is removed to satisfy bounds.
|
| 112 |
-
total_r = (0.4 * c) + (0.3 * q_norm) + (0.2 * p_norm) + (0.1 * r_norm)
|
| 113 |
components = {
|
| 114 |
"total_reward": total_r,
|
| 115 |
"C_correctness": c,
|
| 116 |
-
"Q_reasoning":
|
| 117 |
"P_process_supervision": process_supervision,
|
| 118 |
"R_reflection": reflection_score,
|
| 119 |
"D_diversity": d,
|
| 120 |
"E_efficiency": e,
|
| 121 |
"X_exploration": x,
|
| 122 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
}
|
| 124 |
|
| 125 |
return total_r, components
|
|
|
|
|
|
|
| 1 |
import math
|
| 2 |
from typing import Dict, Any, List, Tuple
|
| 3 |
|
| 4 |
class RewardSystem:
|
| 5 |
+
"""
|
| 6 |
+
Dense, multi-component reward system for mathematical RL training.
|
| 7 |
+
|
| 8 |
+
Key improvements over v1:
|
| 9 |
+
1. All 7 reward components now contribute to the final score
|
| 10 |
+
2. Partial credit support (continuous C β [0,1] from verifier)
|
| 11 |
+
3. Fixed history key mismatch (was breaking diversity detection)
|
| 12 |
+
4. Adaptive efficiency curve that doesn't over-penalize reasonable lengths
|
| 13 |
+
5. Removed random noise (adds variance without useful signal)
|
| 14 |
+
6. Added format compliance reward for structured output
|
| 15 |
+
|
| 16 |
+
Reward equation:
|
| 17 |
+
R = α·C + β·Q + γ·P + δ·R_ref + η·D_norm + ΢·E_norm + λ·X + μ·F_fmt
|
| 18 |
+
|
| 19 |
+
Weights: Ξ±=0.30, Ξ²=0.12, Ξ³=0.10, Ξ΄=0.05, Ξ·=0.13, ΞΆ=0.08, Ξ»=0.07, ΞΌ=0.15
|
| 20 |
+
Sum = 1.0
|
| 21 |
+
|
| 22 |
+
References:
|
| 23 |
+
- arxiv:2408.10215 (Reward shaping for RL convergence)
|
| 24 |
+
- arxiv:2601.19100 (Reward engineering for software/code tasks)
|
| 25 |
+
- DeepSeek-R1 GRPO (graduated correctness)
|
| 26 |
+
- GRPO-Ξ» (credit assignment)
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
# Reward component weights (sum to 1.0)
|
| 30 |
+
W_CORRECTNESS = 0.30 # Ξ±: Primary β correctness drives learning
|
| 31 |
+
W_REASONING = 0.12 # Ξ²: Reasoning quality
|
| 32 |
+
W_PROCESS = 0.10 # Ξ³: Step-by-step process supervision
|
| 33 |
+
W_REFLECTION = 0.05 # Ξ΄: Self-correction behavior
|
| 34 |
+
W_DIVERSITY = 0.13 # Ξ·: Answer diversity (prevents repetition)
|
| 35 |
+
W_EFFICIENCY = 0.08 # ΞΆ: Token efficiency
|
| 36 |
+
W_EXPLORATION = 0.07 # Ξ»: Exploration bonus
|
| 37 |
+
W_FORMAT = 0.15 # ΞΌ: Format compliance (model must learn structure)
|
| 38 |
+
|
| 39 |
def __init__(self, max_len: int = 1000):
|
| 40 |
self.max_len = max_len
|
| 41 |
|
| 42 |
def compute_diversity(self, current_answer: str, history: List[Dict[str, Any]]) -> float:
|
| 43 |
"""
|
| 44 |
+
D = diversity (difference from past attempts).
|
| 45 |
+
|
| 46 |
+
Graduated penalty instead of binary:
|
| 47 |
+
- Exact repeat: -1.0 (steep penalty)
|
| 48 |
+
- Similar to a past answer: -0.3
|
| 49 |
+
- Unique: +1.0
|
| 50 |
"""
|
| 51 |
if not history:
|
| 52 |
return 1.0
|
| 53 |
|
| 54 |
cur_ans_clean = current_answer.strip().lower()
|
| 55 |
|
| 56 |
+
if not cur_ans_clean:
|
| 57 |
+
return 0.0 # Empty answer gets no diversity credit
|
| 58 |
+
|
| 59 |
for attempt in history:
|
| 60 |
+
# BUGFIX: check both 'final_answer' and 'prediction' keys for compatibility
|
| 61 |
+
prev_ans = attempt.get('final_answer', attempt.get('prediction', '')).strip().lower()
|
| 62 |
if prev_ans == cur_ans_clean:
|
| 63 |
+
return -1.0 # Exact repeat β strong penalty
|
| 64 |
+
|
| 65 |
+
# Check for near-duplicates (edit distance heuristic)
|
| 66 |
+
if prev_ans and cur_ans_clean:
|
| 67 |
+
# Simple character overlap ratio
|
| 68 |
+
overlap = sum(1 for a, b in zip(prev_ans, cur_ans_clean) if a == b)
|
| 69 |
+
max_len = max(len(prev_ans), len(cur_ans_clean))
|
| 70 |
+
if max_len > 0 and overlap / max_len > 0.85:
|
| 71 |
+
return -0.3 # Near-duplicate β moderate penalty
|
| 72 |
|
|
|
|
| 73 |
return 1.0
|
| 74 |
|
| 75 |
def compute_efficiency(self, action_string: str) -> float:
|
| 76 |
"""
|
| 77 |
+
E = efficiency. Adaptive Gaussian penalty curve.
|
| 78 |
+
|
| 79 |
+
Improved: wider optimal zone (30-120 tokens) to avoid penalizing
|
| 80 |
+
legitimate mathematical reasoning that naturally needs more space.
|
| 81 |
+
|
| 82 |
+
E β [-0.5, 0.0] (always a penalty or neutral, never a bonus)
|
| 83 |
"""
|
| 84 |
approx_tokens = len(action_string) / 4.0
|
| 85 |
+
optimal_center = 80.0 # Wider center for math
|
| 86 |
+
optimal_width = 60.0 # Generous width
|
| 87 |
+
|
| 88 |
+
# Gentle Gaussian β penalizes only extreme lengths
|
| 89 |
+
ratio = (approx_tokens - optimal_center) / optimal_width
|
| 90 |
+
e = math.exp(-(ratio ** 2)) - 1.0
|
| 91 |
|
| 92 |
+
# Additional penalty for very long outputs (anti-rambling)
|
| 93 |
+
if approx_tokens > 300:
|
| 94 |
+
e -= 0.3 * (approx_tokens - 300) / 300
|
| 95 |
|
| 96 |
+
return max(-1.0, e)
|
|
|
|
|
|
|
| 97 |
|
| 98 |
def compute_exploration_bonus(self, action_string: str, times_seen: int) -> float:
|
| 99 |
"""
|
| 100 |
[PAPER TRACEABILITY: Exploration via Entropy Bonus]
|
| 101 |
G. EXPLORATION VIA ENTROPY BONUS
|
| 102 |
+
|
| 103 |
X = (entropy_bonus) / sqrt(1 + times_seen_problem)
|
| 104 |
+
|
| 105 |
+
Improved with better entropy estimation using word-level diversity.
|
| 106 |
"""
|
|
|
|
| 107 |
length = len(action_string)
|
| 108 |
+
if length == 0:
|
| 109 |
+
return 0.0
|
| 110 |
+
|
| 111 |
+
# Character-level entropy
|
| 112 |
+
unique_ratio = len(set(action_string)) / length
|
| 113 |
+
char_entropy = math.log1p(unique_ratio)
|
| 114 |
+
|
| 115 |
+
# Word-level diversity bonus (rewards varied vocabulary)
|
| 116 |
+
words = action_string.lower().split()
|
| 117 |
+
if words:
|
| 118 |
+
unique_word_ratio = len(set(words)) / len(words)
|
| 119 |
+
word_entropy = math.log1p(unique_word_ratio)
|
| 120 |
else:
|
| 121 |
+
word_entropy = 0.0
|
| 122 |
+
|
| 123 |
+
combined = 0.6 * char_entropy + 0.4 * word_entropy
|
| 124 |
|
| 125 |
+
return combined / math.sqrt(1.0 + times_seen)
|
| 126 |
+
|
| 127 |
+
def compute_format_compliance(self, action_str: str, reasoning: str, final_answer: str) -> float:
|
| 128 |
+
"""
|
| 129 |
+
Format compliance reward β teaches the model to output structured responses.
|
| 130 |
+
|
| 131 |
+
Rewards:
|
| 132 |
+
- Having both reasoning and answer sections
|
| 133 |
+
- Using mathematical notation
|
| 134 |
+
- Proper structure (reasoning before answer)
|
| 135 |
+
|
| 136 |
+
F β [0, 1]
|
| 137 |
+
"""
|
| 138 |
+
score = 0.0
|
| 139 |
+
|
| 140 |
+
# Has non-empty reasoning
|
| 141 |
+
if reasoning and len(reasoning.strip()) > 10:
|
| 142 |
+
score += 0.3
|
| 143 |
+
|
| 144 |
+
# Has non-empty final answer
|
| 145 |
+
if final_answer and len(final_answer.strip()) > 0:
|
| 146 |
+
score += 0.3
|
| 147 |
+
|
| 148 |
+
# Answer contains mathematical content
|
| 149 |
+
math_indicators = ['x', '=', '+', '-', '*', '/', '^', 'sin', 'cos', 'exp', 'log', '(']
|
| 150 |
+
math_count = sum(1 for m in math_indicators if m in final_answer.lower())
|
| 151 |
+
if math_count >= 2:
|
| 152 |
+
score += 0.2
|
| 153 |
+
elif math_count >= 1:
|
| 154 |
+
score += 0.1
|
| 155 |
+
|
| 156 |
+
# Reasoning contains structured steps
|
| 157 |
+
if any(marker in reasoning.lower() for marker in ['step', 'first', 'then', 'therefore', '=']):
|
| 158 |
+
score += 0.2
|
| 159 |
+
|
| 160 |
+
return min(1.0, score)
|
| 161 |
|
| 162 |
def detect_trivial_output(self, action_string: str) -> bool:
|
| 163 |
"""Anti-reward hacking: detect trivial constant outputs"""
|
|
|
|
| 167 |
unique_chars = len(set(action_string))
|
| 168 |
if unique_chars < 3 and len(action_string) > 10:
|
| 169 |
return True
|
| 170 |
+
# Detect repetitive patterns
|
| 171 |
+
if len(action_string) > 20:
|
| 172 |
+
# Check if a short pattern is repeated
|
| 173 |
+
for plen in range(1, 6):
|
| 174 |
+
pattern = action_string[:plen]
|
| 175 |
+
if action_string == pattern * (len(action_string) // plen) + pattern[:len(action_string) % plen]:
|
| 176 |
+
return True
|
| 177 |
return False
|
| 178 |
|
| 179 |
def compute_reward(self,
|
|
|
|
| 184 |
action_str: str,
|
| 185 |
final_answer: str,
|
| 186 |
history: List[Dict[str, Any]],
|
| 187 |
+
times_seen_problem: int,
|
| 188 |
+
reasoning: str = "") -> Tuple[float, Dict[str, float]]:
|
| 189 |
"""
|
| 190 |
+
Dense composite reward using ALL 7 components + format compliance.
|
| 191 |
+
|
| 192 |
+
R = α·C + β·Q_norm + γ·P_norm + δ·R_norm + η·D_norm + ΢·E_norm + λ·X + μ·F_fmt
|
| 193 |
+
|
| 194 |
+
All components are normalized to [0, 1] before weighting.
|
| 195 |
+
Final reward β [0, 1].
|
| 196 |
"""
|
| 197 |
if self.detect_trivial_output(action_str):
|
| 198 |
+
components = {
|
| 199 |
+
"total_reward": -0.5,
|
| 200 |
+
"C_correctness": 0.0, "Q_reasoning": 0.0,
|
| 201 |
+
"P_process_supervision": 0.0, "R_reflection": 0.0,
|
| 202 |
+
"D_diversity": 0.0, "E_efficiency": -1.0,
|
| 203 |
+
"X_exploration": 0.0, "F_format": 0.0,
|
| 204 |
+
}
|
| 205 |
+
return -0.5, components
|
| 206 |
+
|
| 207 |
+
# --- Raw component computation ---
|
| 208 |
+
c = correctness # Already β [0, 1] with graduated scoring
|
| 209 |
q = reasoning_quality
|
| 210 |
d = self.compute_diversity(final_answer, history)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
e = self.compute_efficiency(action_str)
|
| 212 |
x = self.compute_exploration_bonus(action_str, times_seen_problem)
|
| 213 |
+
f_fmt = self.compute_format_compliance(action_str, reasoning, final_answer)
|
| 214 |
|
| 215 |
+
# If repeated answer, reduce correctness credit (anti-hacking)
|
| 216 |
+
if d < -0.5:
|
| 217 |
+
c = c * 0.3 # Steep discount but not full zeroing
|
| 218 |
+
|
| 219 |
+
# --- Normalize all components to [0, 1] ---
|
| 220 |
+
q_norm = min(1.0, max(0.0, math.tanh(q)))
|
| 221 |
+
p_norm = (process_supervision + 1.0) / 2.0 # [-1, 1] β [0, 1]
|
| 222 |
+
r_norm = (reflection_score + 1.0) / 2.0 # [-1, 1] β [0, 1]
|
| 223 |
+
d_norm = (d + 1.0) / 2.0 # [-1, 1] β [0, 1]
|
| 224 |
+
e_norm = (e + 1.0) / 1.0 # [-1, 0] β [0, 1]
|
| 225 |
+
e_norm = min(1.0, max(0.0, e_norm))
|
| 226 |
+
x_norm = min(1.0, max(0.0, x))
|
| 227 |
+
f_norm = min(1.0, max(0.0, f_fmt))
|
| 228 |
|
| 229 |
+
# --- Weighted composite ---
|
| 230 |
+
total_r = (
|
| 231 |
+
self.W_CORRECTNESS * c +
|
| 232 |
+
self.W_REASONING * q_norm +
|
| 233 |
+
self.W_PROCESS * p_norm +
|
| 234 |
+
self.W_REFLECTION * r_norm +
|
| 235 |
+
self.W_DIVERSITY * d_norm +
|
| 236 |
+
self.W_EFFICIENCY * e_norm +
|
| 237 |
+
self.W_EXPLORATION * x_norm +
|
| 238 |
+
self.W_FORMAT * f_norm
|
| 239 |
+
)
|
| 240 |
|
| 241 |
+
# Clamp to [0, 1]
|
| 242 |
+
total_r = min(1.0, max(0.0, total_r))
|
|
|
|
|
|
|
| 243 |
|
|
|
|
|
|
|
|
|
|
| 244 |
components = {
|
| 245 |
"total_reward": total_r,
|
| 246 |
"C_correctness": c,
|
| 247 |
+
"Q_reasoning": q_norm,
|
| 248 |
"P_process_supervision": process_supervision,
|
| 249 |
"R_reflection": reflection_score,
|
| 250 |
"D_diversity": d,
|
| 251 |
"E_efficiency": e,
|
| 252 |
"X_exploration": x,
|
| 253 |
+
"F_format": f_fmt,
|
| 254 |
+
# Weighted contributions (for debugging)
|
| 255 |
+
"_w_C": self.W_CORRECTNESS * c,
|
| 256 |
+
"_w_Q": self.W_REASONING * q_norm,
|
| 257 |
+
"_w_P": self.W_PROCESS * p_norm,
|
| 258 |
+
"_w_R": self.W_REFLECTION * r_norm,
|
| 259 |
+
"_w_D": self.W_DIVERSITY * d_norm,
|
| 260 |
+
"_w_E": self.W_EFFICIENCY * e_norm,
|
| 261 |
+
"_w_X": self.W_EXPLORATION * x_norm,
|
| 262 |
+
"_w_F": self.W_FORMAT * f_norm,
|
| 263 |
}
|
| 264 |
|
| 265 |
return total_r, components
|
env/verifier.py
CHANGED
|
@@ -3,6 +3,46 @@ import math
|
|
| 3 |
from typing import Dict, Any, Tuple
|
| 4 |
|
| 5 |
class VerifierSystem:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
def __init__(self):
|
| 7 |
pass
|
| 8 |
|
|
@@ -34,46 +74,208 @@ class VerifierSystem:
|
|
| 34 |
except Exception:
|
| 35 |
return False
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
def mock_llm_judge(self, reasoning: str, prediction: str, ground_truth: str) -> float:
|
| 38 |
"""4. LLM judge (mock or placeholder scoring reasoning quality)
|
| 39 |
Returns reasoning quality score Q (0.0 to 1.0)
|
|
|
|
|
|
|
| 40 |
"""
|
| 41 |
-
# A simple heuristic for mock judge:
|
| 42 |
-
# Longer reasoning with step-like markers suggests higher quality in this mock
|
| 43 |
-
step_markers = ['step', 'first', 'then', 'because', 'therefore', 'equals', '=', '+', '-']
|
| 44 |
score = 0.0
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
-
# Length bonus (up to 0.
|
| 47 |
-
|
| 48 |
-
score += min(0.4, length * 0.01)
|
| 49 |
|
| 50 |
-
#
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
def check_process_supervision(self, reasoning: str) -> float:
|
| 58 |
"""
|
| 59 |
[PAPER TRACEABILITY: Process Supervision (Lightweight PRM)]
|
| 60 |
E. PROCESS SUPERVISION (STEP-AWARE REWARD)
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
"""
|
| 64 |
lower_r = reasoning.lower()
|
|
|
|
|
|
|
| 65 |
score = 0.0
|
| 66 |
|
| 67 |
-
# Check stepwise structure
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
|
|
|
| 71 |
score += 0.3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
-
#
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
return max(-1.0, min(1.0, score))
|
| 78 |
|
| 79 |
def check_reflection(self, reasoning: str, c: float) -> float:
|
|
@@ -82,59 +284,48 @@ class VerifierSystem:
|
|
| 82 |
H. REFLECTION MODULE
|
| 83 |
Model generates "What could be wrong?"
|
| 84 |
Penalize if contradiction with final answer, reward correct self-correction.
|
|
|
|
|
|
|
| 85 |
"""
|
| 86 |
lower_r = reasoning.lower()
|
| 87 |
score = 0.0
|
| 88 |
|
| 89 |
-
reflection_phrases = [
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
else:
|
| 95 |
-
|
|
|
|
| 96 |
|
| 97 |
-
return score
|
| 98 |
-
|
| 99 |
-
def
|
| 100 |
-
|
| 101 |
-
[PAPER TRACEABILITY: Section 3.1.3 Solution Verification]
|
| 102 |
-
Numerical multi-point quadrature verification.
|
| 103 |
-
Instead of evaluating integrals, we differentiate the prediction F_pred(x)
|
| 104 |
-
and compare it to the ground truth integrand f(x) at 5 random points.
|
| 105 |
-
"""
|
| 106 |
-
import sympy as sp
|
| 107 |
-
import random
|
| 108 |
-
x = sp.Symbol('x')
|
| 109 |
-
try:
|
| 110 |
-
# Clean prediction string
|
| 111 |
-
clean_pred = prediction.strip()
|
| 112 |
-
if "Answer:" in clean_pred:
|
| 113 |
-
clean_pred = clean_pred.split("Answer:")[-1].strip()
|
| 114 |
-
clean_pred = clean_pred.replace("+ C", "").replace("+C", "").strip()
|
| 115 |
-
|
| 116 |
-
F_pred = sp.parse_expr(clean_pred)
|
| 117 |
-
f_pred = sp.diff(F_pred, x)
|
| 118 |
-
|
| 119 |
-
# Evaluate at 5 random points
|
| 120 |
-
for _ in range(5):
|
| 121 |
-
test_point = random.uniform(-5, 5)
|
| 122 |
-
p_val = float(f_pred.subs(x, test_point).evalf())
|
| 123 |
-
t_val = float(sympy_f.subs(x, test_point).evalf())
|
| 124 |
-
|
| 125 |
-
# Paper uses 10^-2 relative tolerance
|
| 126 |
-
if not math.isclose(p_val, t_val, rel_tol=1e-2, abs_tol=1e-2):
|
| 127 |
-
return False
|
| 128 |
-
return True
|
| 129 |
-
except Exception:
|
| 130 |
-
return False
|
| 131 |
-
|
| 132 |
-
def verify(self, reasoning: str, prediction: str, ground_truth: str, sympy_f: Any = None) -> Tuple[float, float, float, float]:
|
| 133 |
"""
|
| 134 |
-
Run all verifiers
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
"""
|
|
|
|
| 137 |
c = 0.0
|
|
|
|
|
|
|
| 138 |
if self.check_exact_match(prediction, ground_truth):
|
| 139 |
c = 1.0
|
| 140 |
elif sympy_f is not None and self.check_numerical_integration(prediction, sympy_f):
|
|
@@ -143,10 +334,86 @@ class VerifierSystem:
|
|
| 143 |
c = 1.0
|
| 144 |
elif self.check_python_execution(prediction, ground_truth):
|
| 145 |
c = 1.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
q = self.mock_llm_judge(reasoning, prediction, ground_truth)
|
| 148 |
-
|
| 149 |
p = self.check_process_supervision(reasoning)
|
| 150 |
r = self.check_reflection(reasoning, c)
|
| 151 |
|
| 152 |
return c, q, p, r
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from typing import Dict, Any, Tuple
|
| 4 |
|
| 5 |
class VerifierSystem:
|
| 6 |
+
"""
|
| 7 |
+
Multi-stage verification system that returns graduated correctness scores
|
| 8 |
+
instead of binary pass/fail. This provides a dense reward signal for RL
|
| 9 |
+
training, enabling faster convergence.
|
| 10 |
+
|
| 11 |
+
Correctness tiers:
|
| 12 |
+
1.0 β Fully correct (exact or numerical match)
|
| 13 |
+
0.7 β Structurally correct (right form, wrong coefficient)
|
| 14 |
+
0.4 β Partially correct (correct technique identified)
|
| 15 |
+
0.15 β Minimal credit (parseable math expression attempted)
|
| 16 |
+
0.0 β Garbage / trivial output
|
| 17 |
+
|
| 18 |
+
References:
|
| 19 |
+
- DeepSeek-R1 GRPO reward design
|
| 20 |
+
- arxiv:2408.10215 (Reward Engineering for RL)
|
| 21 |
+
- arxiv:2601.19100 (Reward Engineering for Software Tasks)
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
# Integration techniques and their associated keywords
|
| 25 |
+
TECHNIQUE_KEYWORDS = {
|
| 26 |
+
'u_substitution': ['substitut', 'u =', 'u=', 'let u', 'du'],
|
| 27 |
+
'by_parts': ['by parts', 'integration by parts', 'ibp', 'uv -', 'udv'],
|
| 28 |
+
'trig_sub': ['trig sub', 'trigonometric substitution', 'sin(ΞΈ)', 'cos(ΞΈ)', 'tan(ΞΈ)'],
|
| 29 |
+
'partial_fraction': ['partial fraction', 'decompos'],
|
| 30 |
+
'power_rule': ['power rule', 'x^n', 'x**'],
|
| 31 |
+
'exponential': ['exponential', 'e^', 'exp('],
|
| 32 |
+
'trigonometric': ['sin', 'cos', 'tan', 'sec', 'csc', 'cot'],
|
| 33 |
+
'logarithmic': ['ln', 'log', 'logarithm'],
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
# Mathematical reasoning markers for process supervision
|
| 37 |
+
MATH_MARKERS = [
|
| 38 |
+
'step', 'first', 'then', 'next', 'therefore', 'because', 'since',
|
| 39 |
+
'equals', 'simplif', 'substitut', 'evaluat', 'factor', 'expand',
|
| 40 |
+
'differentiat', 'integrat', 'apply', 'using', 'recall', 'note that',
|
| 41 |
+
'we get', 'we have', 'we know', 'this gives', 'which yields',
|
| 42 |
+
]
|
| 43 |
+
|
| 44 |
+
MATH_SYMBOLS = set('β«ββββΒ±ΓΓ·β β€β₯ββββββββ©βͺαβγδΡ΢ηθλμΟΟΟΟΟ')
|
| 45 |
+
|
| 46 |
def __init__(self):
|
| 47 |
pass
|
| 48 |
|
|
|
|
| 74 |
except Exception:
|
| 75 |
return False
|
| 76 |
|
| 77 |
+
def check_numerical_integration(self, prediction: str, sympy_f: Any) -> bool:
|
| 78 |
+
"""
|
| 79 |
+
[PAPER TRACEABILITY: Section 3.1.3 Solution Verification]
|
| 80 |
+
Numerical multi-point quadrature verification.
|
| 81 |
+
Differentiates the prediction F_pred(x) and compares it to the ground
|
| 82 |
+
truth integrand f(x) at 5 random points.
|
| 83 |
+
"""
|
| 84 |
+
import sympy as sp
|
| 85 |
+
import random
|
| 86 |
+
x = sp.Symbol('x')
|
| 87 |
+
try:
|
| 88 |
+
clean_pred = self._clean_math_answer(prediction)
|
| 89 |
+
F_pred = sp.parse_expr(clean_pred)
|
| 90 |
+
f_pred = sp.diff(F_pred, x)
|
| 91 |
+
|
| 92 |
+
# Evaluate at 5 random points
|
| 93 |
+
for _ in range(5):
|
| 94 |
+
test_point = random.uniform(-5, 5)
|
| 95 |
+
p_val = float(f_pred.subs(x, test_point).evalf())
|
| 96 |
+
t_val = float(sympy_f.subs(x, test_point).evalf())
|
| 97 |
+
|
| 98 |
+
# Paper uses 10^-2 relative tolerance
|
| 99 |
+
if not math.isclose(p_val, t_val, rel_tol=1e-2, abs_tol=1e-2):
|
| 100 |
+
return False
|
| 101 |
+
return True
|
| 102 |
+
except Exception:
|
| 103 |
+
return False
|
| 104 |
+
|
| 105 |
+
def check_structural_similarity(self, prediction: str, ground_truth: str, sympy_f: Any = None) -> float:
|
| 106 |
+
"""
|
| 107 |
+
Graduated structural similarity check.
|
| 108 |
+
Compares SymPy expression trees to provide partial credit when the
|
| 109 |
+
model's answer has the right structure but wrong coefficients.
|
| 110 |
+
|
| 111 |
+
Returns:
|
| 112 |
+
0.7 if structure matches but coefficients differ
|
| 113 |
+
0.4 if the expression is parseable and shares operand types
|
| 114 |
+
0.15 if the prediction is a parseable math expression
|
| 115 |
+
0.0 if unparseable
|
| 116 |
+
"""
|
| 117 |
+
import sympy as sp
|
| 118 |
+
x = sp.Symbol('x')
|
| 119 |
+
|
| 120 |
+
try:
|
| 121 |
+
clean_pred = self._clean_math_answer(prediction)
|
| 122 |
+
clean_gt = self._clean_math_answer(ground_truth)
|
| 123 |
+
|
| 124 |
+
pred_expr = sp.parse_expr(clean_pred)
|
| 125 |
+
gt_expr = sp.parse_expr(clean_gt)
|
| 126 |
+
except Exception:
|
| 127 |
+
# Can't even parse β check if it at least looks like math
|
| 128 |
+
if self._looks_like_math(prediction):
|
| 129 |
+
return 0.15
|
| 130 |
+
return 0.0
|
| 131 |
+
|
| 132 |
+
# Check if expression trees have similar structure
|
| 133 |
+
try:
|
| 134 |
+
pred_funcs = self._extract_function_types(pred_expr)
|
| 135 |
+
gt_funcs = self._extract_function_types(gt_expr)
|
| 136 |
+
|
| 137 |
+
# Count overlapping function types (sin, cos, exp, log, Pow, etc.)
|
| 138 |
+
overlap = pred_funcs & gt_funcs
|
| 139 |
+
union = pred_funcs | gt_funcs
|
| 140 |
+
|
| 141 |
+
if not union:
|
| 142 |
+
return 0.15 # Both are just constants/variables
|
| 143 |
+
|
| 144 |
+
jaccard = len(overlap) / len(union)
|
| 145 |
+
|
| 146 |
+
if jaccard >= 0.8:
|
| 147 |
+
# Very similar structure β likely right form, wrong coefficient
|
| 148 |
+
# Verify by checking at sample points if shapes are proportional
|
| 149 |
+
if self._check_proportional(pred_expr, gt_expr, x):
|
| 150 |
+
return 0.7
|
| 151 |
+
return 0.5
|
| 152 |
+
elif jaccard >= 0.4:
|
| 153 |
+
return 0.4
|
| 154 |
+
else:
|
| 155 |
+
return 0.15
|
| 156 |
+
|
| 157 |
+
except Exception:
|
| 158 |
+
return 0.15
|
| 159 |
+
|
| 160 |
+
def check_technique_recognition(self, reasoning: str, technique_hint: str = "") -> float:
|
| 161 |
+
"""
|
| 162 |
+
Checks if the model identified the correct integration technique.
|
| 163 |
+
Returns a score β [0, 1] based on technique match.
|
| 164 |
+
|
| 165 |
+
This provides reward signal even when the final answer is wrong,
|
| 166 |
+
as long as the model is using the right approach.
|
| 167 |
+
"""
|
| 168 |
+
if not technique_hint:
|
| 169 |
+
return 0.0
|
| 170 |
+
|
| 171 |
+
lower_r = reasoning.lower()
|
| 172 |
+
|
| 173 |
+
# Check if the correct technique keywords appear in reasoning
|
| 174 |
+
technique_kws = self.TECHNIQUE_KEYWORDS.get(technique_hint, [])
|
| 175 |
+
if not technique_kws:
|
| 176 |
+
return 0.0
|
| 177 |
+
|
| 178 |
+
matches = sum(1 for kw in technique_kws if kw in lower_r)
|
| 179 |
+
|
| 180 |
+
if matches >= 2:
|
| 181 |
+
return 1.0 # Strong evidence of correct technique
|
| 182 |
+
elif matches == 1:
|
| 183 |
+
return 0.6 # Some evidence
|
| 184 |
+
|
| 185 |
+
# Check if any technique was attempted at all
|
| 186 |
+
any_technique = False
|
| 187 |
+
for tech, kws in self.TECHNIQUE_KEYWORDS.items():
|
| 188 |
+
if any(kw in lower_r for kw in kws):
|
| 189 |
+
any_technique = True
|
| 190 |
+
break
|
| 191 |
+
|
| 192 |
+
return 0.2 if any_technique else 0.0
|
| 193 |
+
|
| 194 |
def mock_llm_judge(self, reasoning: str, prediction: str, ground_truth: str) -> float:
|
| 195 |
"""4. LLM judge (mock or placeholder scoring reasoning quality)
|
| 196 |
Returns reasoning quality score Q (0.0 to 1.0)
|
| 197 |
+
|
| 198 |
+
Improved with mathematical density scoring and better structural analysis.
|
| 199 |
"""
|
|
|
|
|
|
|
|
|
|
| 200 |
score = 0.0
|
| 201 |
+
lower_reasoning = reasoning.lower()
|
| 202 |
+
words = reasoning.split()
|
| 203 |
+
length = len(words)
|
| 204 |
|
| 205 |
+
# Length bonus (up to 0.25) β diminishing returns, gentle curve
|
| 206 |
+
score += min(0.25, length * 0.005)
|
|
|
|
| 207 |
|
| 208 |
+
# Mathematical marker bonus (up to 0.35)
|
| 209 |
+
marker_count = sum(1 for m in self.MATH_MARKERS if m in lower_reasoning)
|
| 210 |
+
score += min(0.35, marker_count * 0.05)
|
| 211 |
+
|
| 212 |
+
# Mathematical symbol density bonus (up to 0.2)
|
| 213 |
+
math_chars = sum(1 for c in reasoning if c in '=+-*/^()β«βββ' or c in self.MATH_SYMBOLS)
|
| 214 |
+
if length > 0:
|
| 215 |
+
math_density = math_chars / max(1, len(reasoning))
|
| 216 |
+
score += min(0.2, math_density * 2.0)
|
| 217 |
|
| 218 |
+
# Structured step progression bonus (up to 0.2)
|
| 219 |
+
has_numbered_steps = bool(re.search(r'step\s*\d|^\d+[\.\)]', lower_reasoning, re.MULTILINE))
|
| 220 |
+
has_logical_flow = ('therefore' in lower_reasoning or 'thus' in lower_reasoning or
|
| 221 |
+
'hence' in lower_reasoning or 'so we' in lower_reasoning)
|
| 222 |
+
if has_numbered_steps:
|
| 223 |
+
score += 0.12
|
| 224 |
+
if has_logical_flow:
|
| 225 |
+
score += 0.08
|
| 226 |
+
|
| 227 |
+
return round(min(1.0, score), 3)
|
| 228 |
|
| 229 |
def check_process_supervision(self, reasoning: str) -> float:
|
| 230 |
"""
|
| 231 |
[PAPER TRACEABILITY: Process Supervision (Lightweight PRM)]
|
| 232 |
E. PROCESS SUPERVISION (STEP-AWARE REWARD)
|
| 233 |
+
|
| 234 |
+
Improved with:
|
| 235 |
+
- Mathematical density scoring
|
| 236 |
+
- Multi-level step detection
|
| 237 |
+
- Granular logical jump penalties
|
| 238 |
+
- Technique-specific reward signals
|
| 239 |
"""
|
| 240 |
lower_r = reasoning.lower()
|
| 241 |
+
words = lower_r.split()
|
| 242 |
+
word_count = len(words)
|
| 243 |
score = 0.0
|
| 244 |
|
| 245 |
+
# 1. Check stepwise structure (up to 0.4)
|
| 246 |
+
numbered_steps = len(re.findall(r'step\s*\d', lower_r))
|
| 247 |
+
if numbered_steps >= 3:
|
| 248 |
+
score += 0.4
|
| 249 |
+
elif numbered_steps >= 2:
|
| 250 |
score += 0.3
|
| 251 |
+
elif numbered_steps >= 1:
|
| 252 |
+
score += 0.2
|
| 253 |
+
elif 'first' in lower_r and ('then' in lower_r or 'next' in lower_r):
|
| 254 |
+
score += 0.15
|
| 255 |
+
|
| 256 |
+
# 2. Mathematical operation density (up to 0.3)
|
| 257 |
+
math_ops = len(re.findall(r'[=+\-*/^]', reasoning))
|
| 258 |
+
if word_count > 0:
|
| 259 |
+
op_density = math_ops / word_count
|
| 260 |
+
score += min(0.3, op_density * 3.0)
|
| 261 |
+
|
| 262 |
+
# 3. Technique identification bonus (up to 0.2)
|
| 263 |
+
techniques_mentioned = 0
|
| 264 |
+
for tech, kws in self.TECHNIQUE_KEYWORDS.items():
|
| 265 |
+
if any(kw in lower_r for kw in kws):
|
| 266 |
+
techniques_mentioned += 1
|
| 267 |
+
score += min(0.2, techniques_mentioned * 0.1)
|
| 268 |
+
|
| 269 |
+
# 4. Logical jump penalty β short reasoning with complex claims
|
| 270 |
+
if word_count < 10 and ('=' in lower_r or 'so' in lower_r):
|
| 271 |
+
score -= 0.3
|
| 272 |
+
elif word_count < 20 and math_ops > 3:
|
| 273 |
+
score -= 0.15 # Slightly suspicious β many operations, few words
|
| 274 |
|
| 275 |
+
# 5. Bonus for showing intermediate results
|
| 276 |
+
intermediate_results = len(re.findall(r'=\s*[\d\w]', reasoning))
|
| 277 |
+
score += min(0.1, intermediate_results * 0.02)
|
| 278 |
+
|
| 279 |
return max(-1.0, min(1.0, score))
|
| 280 |
|
| 281 |
def check_reflection(self, reasoning: str, c: float) -> float:
|
|
|
|
| 284 |
H. REFLECTION MODULE
|
| 285 |
Model generates "What could be wrong?"
|
| 286 |
Penalize if contradiction with final answer, reward correct self-correction.
|
| 287 |
+
|
| 288 |
+
Improved with graduated scoring based on reflection quality.
|
| 289 |
"""
|
| 290 |
lower_r = reasoning.lower()
|
| 291 |
score = 0.0
|
| 292 |
|
| 293 |
+
reflection_phrases = [
|
| 294 |
+
"what could be wrong", "wait,", "let me check", "alternatively",
|
| 295 |
+
"let me verify", "double check", "reconsider", "hmm",
|
| 296 |
+
"actually,", "correction:", "i made an error", "let me redo"
|
| 297 |
+
]
|
| 298 |
+
|
| 299 |
+
reflections_found = sum(1 for phrase in reflection_phrases if phrase in lower_r)
|
| 300 |
+
|
| 301 |
+
if reflections_found > 0:
|
| 302 |
+
if c >= 0.7: # At least partially correct
|
| 303 |
+
# Graduated reward based on how many reflection markers used
|
| 304 |
+
score += min(1.0, 0.5 + reflections_found * 0.2)
|
| 305 |
+
elif c >= 0.4:
|
| 306 |
+
# Some credit β reflected but didn't fully fix
|
| 307 |
+
score += 0.1
|
| 308 |
else:
|
| 309 |
+
# Reflected but still wrong β mild penalty (not as harsh as before)
|
| 310 |
+
score -= 0.3
|
| 311 |
|
| 312 |
+
return max(-1.0, min(1.0, score))
|
| 313 |
+
|
| 314 |
+
def verify(self, reasoning: str, prediction: str, ground_truth: str,
|
| 315 |
+
sympy_f: Any = None, technique_hint: str = "") -> Tuple[float, float, float, float]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
"""
|
| 317 |
+
Run all verifiers with GRADUATED CORRECTNESS scoring.
|
| 318 |
+
|
| 319 |
+
Returns:
|
| 320 |
+
C οΏ½οΏ½ Correctness β [0, 1] (graduated, not binary)
|
| 321 |
+
Q β Reasoning Quality β [0, 1]
|
| 322 |
+
P β Process Supervision β [-1, 1]
|
| 323 |
+
R β Reflection Score β [-1, 1]
|
| 324 |
"""
|
| 325 |
+
# --- Graduated Correctness ---
|
| 326 |
c = 0.0
|
| 327 |
+
|
| 328 |
+
# Tier 1: Full correctness (1.0)
|
| 329 |
if self.check_exact_match(prediction, ground_truth):
|
| 330 |
c = 1.0
|
| 331 |
elif sympy_f is not None and self.check_numerical_integration(prediction, sympy_f):
|
|
|
|
| 334 |
c = 1.0
|
| 335 |
elif self.check_python_execution(prediction, ground_truth):
|
| 336 |
c = 1.0
|
| 337 |
+
|
| 338 |
+
# Tier 2-4: Partial credit (only if not fully correct)
|
| 339 |
+
if c < 1.0:
|
| 340 |
+
structural_score = self.check_structural_similarity(prediction, ground_truth, sympy_f)
|
| 341 |
+
technique_score = self.check_technique_recognition(reasoning, technique_hint)
|
| 342 |
+
|
| 343 |
+
# Take the best partial credit signal
|
| 344 |
+
c = max(c, structural_score)
|
| 345 |
+
|
| 346 |
+
# Technique recognition can boost partial credit
|
| 347 |
+
if technique_score > 0 and c < 0.7:
|
| 348 |
+
c = max(c, 0.15 + technique_score * 0.25) # Up to 0.4 from technique alone
|
| 349 |
|
| 350 |
q = self.mock_llm_judge(reasoning, prediction, ground_truth)
|
|
|
|
| 351 |
p = self.check_process_supervision(reasoning)
|
| 352 |
r = self.check_reflection(reasoning, c)
|
| 353 |
|
| 354 |
return c, q, p, r
|
| 355 |
+
|
| 356 |
+
# --- Private Helpers ---
|
| 357 |
+
|
| 358 |
+
def _clean_math_answer(self, text: str) -> str:
|
| 359 |
+
"""Clean a math answer string for SymPy parsing."""
|
| 360 |
+
clean = text.strip()
|
| 361 |
+
if "Answer:" in clean:
|
| 362 |
+
clean = clean.split("Answer:")[-1].strip()
|
| 363 |
+
# Remove constant of integration
|
| 364 |
+
clean = re.sub(r'\+\s*[Cc]\s*$', '', clean).strip()
|
| 365 |
+
# Remove LaTeX wrappers
|
| 366 |
+
clean = clean.replace('$', '').replace('\\', '')
|
| 367 |
+
return clean
|
| 368 |
+
|
| 369 |
+
def _looks_like_math(self, text: str) -> bool:
|
| 370 |
+
"""Check if text contains mathematical content."""
|
| 371 |
+
math_indicators = ['=', '+', '-', '*', '/', '^', 'x', 'sin', 'cos', 'exp', 'log', '(']
|
| 372 |
+
return sum(1 for m in math_indicators if m in text.lower()) >= 2
|
| 373 |
+
|
| 374 |
+
def _extract_function_types(self, expr) -> set:
|
| 375 |
+
"""Extract the set of function types from a SymPy expression tree."""
|
| 376 |
+
import sympy as sp
|
| 377 |
+
types = set()
|
| 378 |
+
|
| 379 |
+
if isinstance(expr, sp.Add):
|
| 380 |
+
types.add('Add')
|
| 381 |
+
elif isinstance(expr, sp.Mul):
|
| 382 |
+
types.add('Mul')
|
| 383 |
+
elif isinstance(expr, sp.Pow):
|
| 384 |
+
types.add('Pow')
|
| 385 |
+
|
| 386 |
+
func_type = type(expr).__name__
|
| 387 |
+
if func_type in ('sin', 'cos', 'tan', 'exp', 'log', 'ln', 'Abs',
|
| 388 |
+
'asin', 'acos', 'atan', 'sinh', 'cosh', 'tanh'):
|
| 389 |
+
types.add(func_type)
|
| 390 |
+
|
| 391 |
+
# Recurse into sub-expressions
|
| 392 |
+
if hasattr(expr, 'args'):
|
| 393 |
+
for arg in expr.args:
|
| 394 |
+
types |= self._extract_function_types(arg)
|
| 395 |
+
|
| 396 |
+
return types
|
| 397 |
+
|
| 398 |
+
def _check_proportional(self, expr1, expr2, x) -> bool:
|
| 399 |
+
"""Check if two expressions are proportional (differ only by a constant factor)."""
|
| 400 |
+
import sympy as sp
|
| 401 |
+
import random
|
| 402 |
+
|
| 403 |
+
try:
|
| 404 |
+
ratios = []
|
| 405 |
+
for _ in range(3):
|
| 406 |
+
pt = random.uniform(-3, 3)
|
| 407 |
+
v1 = float(expr1.subs(x, pt).evalf())
|
| 408 |
+
v2 = float(expr2.subs(x, pt).evalf())
|
| 409 |
+
if abs(v2) < 1e-10:
|
| 410 |
+
continue
|
| 411 |
+
ratios.append(v1 / v2)
|
| 412 |
+
|
| 413 |
+
if len(ratios) < 2:
|
| 414 |
+
return False
|
| 415 |
+
|
| 416 |
+
# Check if all ratios are approximately equal (constant factor)
|
| 417 |
+
return all(math.isclose(r, ratios[0], rel_tol=0.1) for r in ratios)
|
| 418 |
+
except Exception:
|
| 419 |
+
return False
|
tests/test_env.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
import sys
|
| 2 |
import os
|
| 3 |
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 4 |
|
|
@@ -11,15 +11,32 @@ from env.models import AutomathreasonerAction
|
|
| 11 |
def test_generator():
|
| 12 |
engine = TaskGenerationEngine()
|
| 13 |
|
| 14 |
-
# Test
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
-
# Test
|
| 19 |
-
task = engine.generate_task(target_difficulty_band=
|
| 20 |
-
|
| 21 |
-
assert
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
def test_verifier():
|
| 25 |
verifier = VerifierSystem()
|
|
@@ -27,37 +44,104 @@ def test_verifier():
|
|
| 27 |
# Exact match
|
| 28 |
assert verifier.check_exact_match("42", "42")
|
| 29 |
assert verifier.check_exact_match(" 42 ", "42")
|
|
|
|
| 30 |
|
| 31 |
# Numeric tolerance
|
| 32 |
assert verifier.check_numeric_tolerance("3.14159", "3.1415")
|
| 33 |
assert not verifier.check_numeric_tolerance("4.1415", "3.1415")
|
|
|
|
| 34 |
|
| 35 |
# Python execution
|
| 36 |
assert verifier.check_python_execution("2 + 2", "4")
|
|
|
|
| 37 |
|
| 38 |
-
# Full verification
|
| 39 |
-
c, q = verifier.verify("Because 2 + 2 is 4", "4", "4")
|
| 40 |
assert c == 1.0
|
| 41 |
-
assert q > 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
def test_rewards():
|
| 44 |
reward_sys = RewardSystem(max_len=1000)
|
| 45 |
-
history = [{"final_answer": "42"}]
|
| 46 |
|
| 47 |
-
# Test diversity
|
|
|
|
| 48 |
d = reward_sys.compute_diversity("42", history)
|
| 49 |
assert d == -1.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
r, comps = reward_sys.compute_reward(
|
| 53 |
correctness=1.0,
|
| 54 |
-
reasoning_quality=
|
| 55 |
-
|
| 56 |
-
|
|
|
|
|
|
|
| 57 |
history=[],
|
| 58 |
-
times_seen_problem=0
|
|
|
|
| 59 |
)
|
| 60 |
assert r > 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
def test_environment_step():
|
| 63 |
env = AutomathreasonerEnvironment()
|
|
@@ -66,14 +150,130 @@ def test_environment_step():
|
|
| 66 |
assert obs.problem_text != ""
|
| 67 |
assert obs.difficulty_level > 0
|
| 68 |
assert len(obs.history) == 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
-
#
|
| 71 |
action = AutomathreasonerAction(
|
| 72 |
-
reasoning="I
|
| 73 |
-
final_answer="
|
| 74 |
)
|
| 75 |
|
| 76 |
obs_after = env.step(action)
|
| 77 |
assert obs_after.reward is not None
|
| 78 |
assert len(obs_after.history) == 1
|
| 79 |
assert "reward_components" in obs_after.metadata
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ο»Ώimport sys
|
| 2 |
import os
|
| 3 |
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 4 |
|
|
|
|
| 11 |
def test_generator():
|
| 12 |
engine = TaskGenerationEngine()
|
| 13 |
|
| 14 |
+
# Test task generation at various difficulty levels
|
| 15 |
+
for diff in [1.0, 3.0, 5.0]:
|
| 16 |
+
task = engine.generate_task(target_difficulty_band=diff)
|
| 17 |
+
assert "problem" in task
|
| 18 |
+
assert "solution" in task
|
| 19 |
+
assert "difficulty" in task
|
| 20 |
+
assert "technique" in task
|
| 21 |
+
assert "scaffold_hints" in task
|
| 22 |
+
assert task["technique"] in ['power_rule', 'u_substitution', 'by_parts',
|
| 23 |
+
'trigonometric', 'exponential', 'logarithmic']
|
| 24 |
+
print(f" Γ’Εβ Difficulty {diff}: technique={task['technique']}, problem={task['problem'][:60]}...")
|
| 25 |
|
| 26 |
+
# Test variant generation
|
| 27 |
+
task = engine.generate_task(target_difficulty_band=4.0)
|
| 28 |
+
variants = engine.generate_variants(task, count=3)
|
| 29 |
+
assert len(variants) > 0
|
| 30 |
+
for v in variants:
|
| 31 |
+
assert "problem" in v
|
| 32 |
+
assert "technique" in v
|
| 33 |
+
print(f" Γ’Εβ Generated {len(variants)} variants")
|
| 34 |
+
|
| 35 |
+
# Test technique-focused generation
|
| 36 |
+
for tech in ['power_rule', 'u_substitution', 'by_parts']:
|
| 37 |
+
task = engine.generate_technique_focused_task(tech, difficulty=2.0)
|
| 38 |
+
assert task["technique"] == tech
|
| 39 |
+
print(f" Γ’Εβ Technique-focused: {tech}")
|
| 40 |
|
| 41 |
def test_verifier():
|
| 42 |
verifier = VerifierSystem()
|
|
|
|
| 44 |
# Exact match
|
| 45 |
assert verifier.check_exact_match("42", "42")
|
| 46 |
assert verifier.check_exact_match(" 42 ", "42")
|
| 47 |
+
print(" Γ’Εβ Exact match")
|
| 48 |
|
| 49 |
# Numeric tolerance
|
| 50 |
assert verifier.check_numeric_tolerance("3.14159", "3.1415")
|
| 51 |
assert not verifier.check_numeric_tolerance("4.1415", "3.1415")
|
| 52 |
+
print(" Γ’Εβ Numeric tolerance")
|
| 53 |
|
| 54 |
# Python execution
|
| 55 |
assert verifier.check_python_execution("2 + 2", "4")
|
| 56 |
+
print(" Γ’Εβ Python execution")
|
| 57 |
|
| 58 |
+
# Full verification Γ’β¬β now returns 4 values (c, q, p, r)
|
| 59 |
+
c, q, p, r = verifier.verify("Step 1: Because 2 + 2 is 4. Therefore the answer is 4.", "4", "4")
|
| 60 |
assert c == 1.0
|
| 61 |
+
assert q > 0.0
|
| 62 |
+
print(f" Γ’Εβ Full verify: C={c}, Q={q:.3f}, P={p:.3f}, R={r:.3f}")
|
| 63 |
+
|
| 64 |
+
# Graduated correctness Γ’β¬β structural similarity
|
| 65 |
+
score = verifier.check_structural_similarity("x**3", "2*x**3")
|
| 66 |
+
assert score > 0.0 # Should get partial credit for same structure
|
| 67 |
+
print(f" Γ’Εβ Structural similarity: {score:.2f}")
|
| 68 |
+
|
| 69 |
+
# Technique recognition
|
| 70 |
+
tech_score = verifier.check_technique_recognition(
|
| 71 |
+
"Let u = x^2, then du = 2x dx. By substitution we get...",
|
| 72 |
+
"u_substitution"
|
| 73 |
+
)
|
| 74 |
+
assert tech_score > 0.5
|
| 75 |
+
print(f" Γ’Εβ Technique recognition: {tech_score:.2f}")
|
| 76 |
+
|
| 77 |
+
# Process supervision Γ’β¬β improved
|
| 78 |
+
p_good = verifier.check_process_supervision(
|
| 79 |
+
"Step 1: Identify the integrand. Step 2: Apply the power rule. Therefore x^3/3 + C."
|
| 80 |
+
)
|
| 81 |
+
p_bad = verifier.check_process_supervision("so = 42")
|
| 82 |
+
assert p_good > p_bad
|
| 83 |
+
print(f" Γ’Εβ Process supervision: good={p_good:.2f}, bad={p_bad:.2f}")
|
| 84 |
|
| 85 |
def test_rewards():
|
| 86 |
reward_sys = RewardSystem(max_len=1000)
|
|
|
|
| 87 |
|
| 88 |
+
# Test diversity Γ’β¬β exact repeat penalty
|
| 89 |
+
history = [{"final_answer": "42"}]
|
| 90 |
d = reward_sys.compute_diversity("42", history)
|
| 91 |
assert d == -1.0
|
| 92 |
+
print(f" Γ’Εβ Diversity repeat penalty: {d}")
|
| 93 |
+
|
| 94 |
+
# Test diversity Γ’β¬β also works with 'prediction' key (backward compat)
|
| 95 |
+
history_v2 = [{"prediction": "42"}]
|
| 96 |
+
d2 = reward_sys.compute_diversity("42", history_v2)
|
| 97 |
+
assert d2 == -1.0
|
| 98 |
+
print(f" Γ’Εβ Diversity backward compat: {d2}")
|
| 99 |
+
|
| 100 |
+
# Test diversity Γ’β¬β unique answer
|
| 101 |
+
d3 = reward_sys.compute_diversity("99", history)
|
| 102 |
+
assert d3 == 1.0
|
| 103 |
+
print(f" Γ’Εβ Diversity unique bonus: {d3}")
|
| 104 |
|
| 105 |
+
# Test format compliance
|
| 106 |
+
f = reward_sys.compute_format_compliance(
|
| 107 |
+
"Step 1: Apply power rule.\nAnswer: x^2/2",
|
| 108 |
+
"Step 1: Apply power rule.",
|
| 109 |
+
"x^2/2"
|
| 110 |
+
)
|
| 111 |
+
assert f > 0.5
|
| 112 |
+
print(f" Γ’Εβ Format compliance: {f:.2f}")
|
| 113 |
+
|
| 114 |
+
# Full reward computation Γ’β¬β new signature with all params
|
| 115 |
r, comps = reward_sys.compute_reward(
|
| 116 |
correctness=1.0,
|
| 117 |
+
reasoning_quality=0.8,
|
| 118 |
+
process_supervision=0.5,
|
| 119 |
+
reflection_score=0.0,
|
| 120 |
+
action_str="Step 1: Apply power rule. Step 2: Simplify. Answer: x^2/2",
|
| 121 |
+
final_answer="x^2/2",
|
| 122 |
history=[],
|
| 123 |
+
times_seen_problem=0,
|
| 124 |
+
reasoning="Step 1: Apply power rule. Step 2: Simplify.",
|
| 125 |
)
|
| 126 |
assert r > 0.0
|
| 127 |
+
assert "C_correctness" in comps
|
| 128 |
+
assert "F_format" in comps
|
| 129 |
+
assert comps["F_format"] > 0 # Format compliance should be non-zero
|
| 130 |
+
print(f" Γ’Εβ Full reward: {r:.3f}, components: {len(comps)} fields")
|
| 131 |
+
|
| 132 |
+
# Verify all 7+ components are tracked
|
| 133 |
+
expected_keys = ["C_correctness", "Q_reasoning", "P_process_supervision",
|
| 134 |
+
"R_reflection", "D_diversity", "E_efficiency",
|
| 135 |
+
"X_exploration", "F_format"]
|
| 136 |
+
for key in expected_keys:
|
| 137 |
+
assert key in comps, f"Missing component: {key}"
|
| 138 |
+
print(f" Γ’Εβ All {len(expected_keys)} reward components present")
|
| 139 |
+
|
| 140 |
+
# Trivial output detection
|
| 141 |
+
assert reward_sys.detect_trivial_output("a")
|
| 142 |
+
assert reward_sys.detect_trivial_output("aaaaaaaaaaaaa")
|
| 143 |
+
assert not reward_sys.detect_trivial_output("x^2 + 2x + 1")
|
| 144 |
+
print(" Γ’Εβ Trivial output detection")
|
| 145 |
|
| 146 |
def test_environment_step():
|
| 147 |
env = AutomathreasonerEnvironment()
|
|
|
|
| 150 |
assert obs.problem_text != ""
|
| 151 |
assert obs.difficulty_level > 0
|
| 152 |
assert len(obs.history) == 0
|
| 153 |
+
print(f" Γ’Εβ Reset: difficulty={obs.difficulty_level}, problem={obs.problem_text[:60]}...")
|
| 154 |
+
|
| 155 |
+
# Technique metadata in observation
|
| 156 |
+
assert "technique" in obs.metadata
|
| 157 |
+
print(f" Γ’Εβ Technique metadata: {obs.metadata['technique']}")
|
| 158 |
|
| 159 |
+
# Dummy action step
|
| 160 |
action = AutomathreasonerAction(
|
| 161 |
+
reasoning="Step 1: I identify the integrand. Step 2: Applying the power rule.",
|
| 162 |
+
final_answer="x^2/2"
|
| 163 |
)
|
| 164 |
|
| 165 |
obs_after = env.step(action)
|
| 166 |
assert obs_after.reward is not None
|
| 167 |
assert len(obs_after.history) == 1
|
| 168 |
assert "reward_components" in obs_after.metadata
|
| 169 |
+
assert "correctness_score" in obs_after.metadata
|
| 170 |
+
print(f" Γ’Εβ Step: reward={obs_after.reward:.3f}, "
|
| 171 |
+
f"correct={obs_after.metadata['is_correct']}, "
|
| 172 |
+
f"C={obs_after.metadata['correctness_score']:.2f}")
|
| 173 |
+
|
| 174 |
+
# Verify history stores both keys
|
| 175 |
+
assert "prediction" in obs_after.history[0]
|
| 176 |
+
assert "final_answer" in obs_after.history[0]
|
| 177 |
+
print(" Γ’Εβ History backward compatibility")
|
| 178 |
+
|
| 179 |
+
def test_curriculum_progression():
|
| 180 |
+
"""Test that curriculum actually advances with good performance."""
|
| 181 |
+
env = AutomathreasonerEnvironment()
|
| 182 |
+
initial_diff = env.difficulty_level
|
| 183 |
+
|
| 184 |
+
# Simulate a series of correct answers
|
| 185 |
+
for _ in range(5):
|
| 186 |
+
env.rolling_results.append(1)
|
| 187 |
+
env.rolling_rewards.append(0.7)
|
| 188 |
+
|
| 189 |
+
env._update_curriculum()
|
| 190 |
+
assert env.difficulty_level > initial_diff, (
|
| 191 |
+
f"Curriculum should advance: {initial_diff} -> {env.difficulty_level}"
|
| 192 |
+
)
|
| 193 |
+
print(f" Γ’Εβ Curriculum advanced: {initial_diff} -> {env.difficulty_level:.1f}")
|
| 194 |
+
|
| 195 |
+
def test_scaffold_hints():
|
| 196 |
+
"""Test that scaffold hints are generated after failures."""
|
| 197 |
+
env = AutomathreasonerEnvironment()
|
| 198 |
+
env.reset()
|
| 199 |
+
|
| 200 |
+
# No hint at 0 failures
|
| 201 |
+
env.consecutive_failures = 0
|
| 202 |
+
hint0 = env._get_scaffold_observation()
|
| 203 |
+
assert hint0 == ""
|
| 204 |
+
|
| 205 |
+
# Hint at 2 failures
|
| 206 |
+
env.consecutive_failures = 2
|
| 207 |
+
env.current_scaffold_hints = {
|
| 208 |
+
'hint_level_1': 'Try u-substitution',
|
| 209 |
+
'hint_level_2': 'Let u = x^2',
|
| 210 |
+
'hint_level_3': 'The answer starts with sin(x^2)',
|
| 211 |
+
}
|
| 212 |
+
hint2 = env._get_scaffold_observation()
|
| 213 |
+
assert "Hint" in hint2
|
| 214 |
+
assert "u-substitution" in hint2
|
| 215 |
+
|
| 216 |
+
# Stronger hint at 3 failures
|
| 217 |
+
env.consecutive_failures = 3
|
| 218 |
+
hint3 = env._get_scaffold_observation()
|
| 219 |
+
assert "u = x^2" in hint3
|
| 220 |
+
|
| 221 |
+
# Strongest hint at 4+ failures
|
| 222 |
+
env.consecutive_failures = 4
|
| 223 |
+
hint4 = env._get_scaffold_observation()
|
| 224 |
+
assert "Strong Hint" in hint4
|
| 225 |
+
|
| 226 |
+
print(" Γ’Εβ Scaffold hints: level 1, 2, 3 all working")
|
| 227 |
+
|
| 228 |
+
def test_graduated_correctness_flow():
|
| 229 |
+
"""End-to-end test: partial credit flows through the whole system."""
|
| 230 |
+
env = AutomathreasonerEnvironment()
|
| 231 |
+
obs = env.reset()
|
| 232 |
+
|
| 233 |
+
# Submit a plausible but wrong math answer
|
| 234 |
+
action = AutomathreasonerAction(
|
| 235 |
+
reasoning="Step 1: I apply the power rule. Step 2: I integrate term by term. Therefore the answer is:",
|
| 236 |
+
final_answer="x**2 + x" # Almost certainly wrong, but parseable math
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
obs_after = env.step(action)
|
| 240 |
+
c_score = obs_after.metadata.get('correctness_score', 0)
|
| 241 |
+
|
| 242 |
+
# Should get SOME partial credit (> 0) for parseable math with right techniques
|
| 243 |
+
print(f" Γ’Εβ Graduated correctness: C={c_score:.2f}, reward={obs_after.reward:.3f}")
|
| 244 |
+
# Reward should be positive even when wrong (format + reasoning + partial credit)
|
| 245 |
+
assert obs_after.reward > 0.0, f"Expected positive reward for structured wrong answer, got {obs_after.reward}"
|
| 246 |
+
print(f" Γ’Εβ Positive reward for structured wrong answer: {obs_after.reward:.3f}")
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
if __name__ == "__main__":
|
| 250 |
+
print("=" * 60)
|
| 251 |
+
print("AutoMathReasoner Test Suite (v2 - Optimized)")
|
| 252 |
+
print("=" * 60)
|
| 253 |
+
|
| 254 |
+
print("\n[TEST] test_generator")
|
| 255 |
+
test_generator()
|
| 256 |
+
|
| 257 |
+
print("\n[TEST] test_verifier")
|
| 258 |
+
test_verifier()
|
| 259 |
+
|
| 260 |
+
print("\n[TEST] test_rewards")
|
| 261 |
+
test_rewards()
|
| 262 |
+
|
| 263 |
+
print("\n[TEST] test_environment_step")
|
| 264 |
+
test_environment_step()
|
| 265 |
+
|
| 266 |
+
print("\n[TEST] test_curriculum_progression")
|
| 267 |
+
test_curriculum_progression()
|
| 268 |
+
|
| 269 |
+
print("\n[TEST] test_scaffold_hints")
|
| 270 |
+
test_scaffold_hints()
|
| 271 |
+
|
| 272 |
+
print("\n[TEST] test_graduated_correctness_flow")
|
| 273 |
+
test_graduated_correctness_flow()
|
| 274 |
+
|
| 275 |
+
print("\n" + "=" * 60)
|
| 276 |
+
print("[OK] ALL TESTS PASSED")
|
| 277 |
+
print("=" * 60)
|
| 278 |
+
|
| 279 |
+
|
train/colab_train.py
CHANGED
|
@@ -17,6 +17,7 @@ import collections
|
|
| 17 |
import random
|
| 18 |
from datasets import Dataset
|
| 19 |
import torch
|
|
|
|
| 20 |
|
| 21 |
# Unsloth & TRL
|
| 22 |
from unsloth import FastLanguageModel
|
|
@@ -33,13 +34,13 @@ from AutoMathReasoner.env.models import AutomathreasonerAction
|
|
| 33 |
HF_SPACE_URL = "https://your-username-automathreasoner.hf.space"
|
| 34 |
env = AutomathreasonerEnv(url=HF_SPACE_URL)
|
| 35 |
|
| 36 |
-
max_seq_length = 1024
|
| 37 |
lora_rank = 16
|
| 38 |
|
| 39 |
# 2. Load Model via Unsloth (optimized for Free Colab VRAM)
|
| 40 |
print("Loading model via Unsloth...")
|
| 41 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 42 |
-
model_name = "unsloth/llama-3-8b-Instruct-bnb-4bit",
|
| 43 |
max_seq_length = max_seq_length,
|
| 44 |
dtype = None,
|
| 45 |
load_in_4bit = True,
|
|
@@ -52,35 +53,66 @@ model = FastLanguageModel.get_peft_model(
|
|
| 52 |
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
|
| 53 |
"gate_proj", "up_proj", "down_proj"],
|
| 54 |
lora_alpha = lora_rank,
|
| 55 |
-
use_gradient_checkpointing = "unsloth",
|
| 56 |
)
|
| 57 |
|
| 58 |
-
# 3. Prepare
|
| 59 |
print("Gathering initial prompts from HF Space environment...")
|
| 60 |
initial_prompts = []
|
| 61 |
-
for _ in range(
|
| 62 |
# This fires an HTTP request to your Hugging Face Space
|
| 63 |
obs = env.reset()
|
| 64 |
initial_prompts.append({"prompt": obs.problem_text})
|
| 65 |
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
# 4. Define Reward Function for TRL
|
|
|
|
|
|
|
|
|
|
| 69 |
def compute_rewards(prompts, completions, **kwargs):
|
| 70 |
"""
|
| 71 |
Interfaces with the OpenEnv running on Hugging Face Spaces.
|
| 72 |
Extracts the generation, passes it via HTTP to the env, and yields the dense reward.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
"""
|
| 74 |
rewards = []
|
| 75 |
parsed_actions = []
|
| 76 |
prompt_answers = collections.defaultdict(list)
|
| 77 |
|
| 78 |
-
#
|
| 79 |
for prompt, completion in zip(prompts, completions):
|
| 80 |
try:
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
except Exception:
|
| 85 |
reasoning = completion
|
| 86 |
answer = ""
|
|
@@ -88,43 +120,77 @@ def compute_rewards(prompts, completions, **kwargs):
|
|
| 88 |
parsed_actions.append((prompt, completion, reasoning, answer))
|
| 89 |
prompt_answers[prompt].append(answer)
|
| 90 |
|
|
|
|
| 91 |
majority_answers = {}
|
|
|
|
| 92 |
for p, ans_list in prompt_answers.items():
|
| 93 |
if ans_list:
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
for p, c, r, a in parsed_actions:
|
| 97 |
action = AutomathreasonerAction(reasoning=r, final_answer=a)
|
| 98 |
|
| 99 |
-
#
|
| 100 |
-
# But for REST API environments, we simply reset and forcefully simulate.
|
| 101 |
obs = env.reset()
|
| 102 |
-
|
| 103 |
-
# Step through HTTP API
|
| 104 |
step_obs = env.step(action)
|
| 105 |
r_total = step_obs.reward
|
| 106 |
|
| 107 |
-
#
|
| 108 |
majority = majority_answers.get(p, "")
|
| 109 |
-
|
| 110 |
-
|
|
|
|
| 111 |
|
|
|
|
| 112 |
rewards.append(r_total)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
return rewards
|
| 115 |
|
| 116 |
-
# 5. Execute Training
|
| 117 |
training_args = GRPOConfig(
|
| 118 |
output_dir="colab_outputs",
|
| 119 |
-
|
| 120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
gradient_accumulation_steps=4,
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
)
|
| 129 |
|
| 130 |
trainer = GRPOTrainer(
|
|
@@ -134,10 +200,22 @@ trainer = GRPOTrainer(
|
|
| 134 |
train_dataset=dataset,
|
| 135 |
)
|
| 136 |
|
| 137 |
-
print("Starting GRPO Training in Colab using Remote HF Environment...")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
# Will show wandb/tensorboard logging so you can prove "it is actually learning"
|
| 139 |
trainer.train()
|
| 140 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
# 6. Push to Hugging Face
|
| 142 |
# Optional: save locally or push to Hub after it learns
|
| 143 |
# model.push_to_hub("your-name/AutoMathReasoner-Trained")
|
|
|
|
| 17 |
import random
|
| 18 |
from datasets import Dataset
|
| 19 |
import torch
|
| 20 |
+
import numpy as np
|
| 21 |
|
| 22 |
# Unsloth & TRL
|
| 23 |
from unsloth import FastLanguageModel
|
|
|
|
| 34 |
HF_SPACE_URL = "https://your-username-automathreasoner.hf.space"
|
| 35 |
env = AutomathreasonerEnv(url=HF_SPACE_URL)
|
| 36 |
|
| 37 |
+
max_seq_length = 1024 # Fits well within Colab T4 16GB VRAM limit
|
| 38 |
lora_rank = 16
|
| 39 |
|
| 40 |
# 2. Load Model via Unsloth (optimized for Free Colab VRAM)
|
| 41 |
print("Loading model via Unsloth...")
|
| 42 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 43 |
+
model_name = "unsloth/llama-3-8b-Instruct-bnb-4bit", # Pre-quantized 4bit for fast download
|
| 44 |
max_seq_length = max_seq_length,
|
| 45 |
dtype = None,
|
| 46 |
load_in_4bit = True,
|
|
|
|
| 53 |
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
|
| 54 |
"gate_proj", "up_proj", "down_proj"],
|
| 55 |
lora_alpha = lora_rank,
|
| 56 |
+
use_gradient_checkpointing = "unsloth", # Crucial for fitting into T4
|
| 57 |
)
|
| 58 |
|
| 59 |
+
# 3. Prepare Prompts from the Remote Environment
|
| 60 |
print("Gathering initial prompts from HF Space environment...")
|
| 61 |
initial_prompts = []
|
| 62 |
+
for _ in range(50): # Increased from 30 for better coverage
|
| 63 |
# This fires an HTTP request to your Hugging Face Space
|
| 64 |
obs = env.reset()
|
| 65 |
initial_prompts.append({"prompt": obs.problem_text})
|
| 66 |
|
| 67 |
+
# Deduplicate
|
| 68 |
+
seen = set()
|
| 69 |
+
unique_prompts = []
|
| 70 |
+
for p in initial_prompts:
|
| 71 |
+
if p["prompt"] not in seen:
|
| 72 |
+
seen.add(p["prompt"])
|
| 73 |
+
unique_prompts.append(p)
|
| 74 |
+
|
| 75 |
+
print(f" Generated {len(unique_prompts)} unique training prompts")
|
| 76 |
+
dataset = Dataset.from_list(unique_prompts)
|
| 77 |
|
| 78 |
# 4. Define Reward Function for TRL
|
| 79 |
+
# Track stats for logging
|
| 80 |
+
reward_stats = {"total_calls": 0, "total_correct": 0, "total_reward": 0.0}
|
| 81 |
+
|
| 82 |
def compute_rewards(prompts, completions, **kwargs):
|
| 83 |
"""
|
| 84 |
Interfaces with the OpenEnv running on Hugging Face Spaces.
|
| 85 |
Extracts the generation, passes it via HTTP to the env, and yields the dense reward.
|
| 86 |
+
|
| 87 |
+
Improvements over v1:
|
| 88 |
+
1. Better answer parsing with multiple delimiter support
|
| 89 |
+
2. Confidence-weighted self-consistency bonus
|
| 90 |
+
3. Format compliance awareness
|
| 91 |
+
4. Progress logging
|
| 92 |
"""
|
| 93 |
rewards = []
|
| 94 |
parsed_actions = []
|
| 95 |
prompt_answers = collections.defaultdict(list)
|
| 96 |
|
| 97 |
+
# Parse all completions
|
| 98 |
for prompt, completion in zip(prompts, completions):
|
| 99 |
try:
|
| 100 |
+
if "Answer:" in completion:
|
| 101 |
+
parts = completion.split("Answer:")
|
| 102 |
+
reasoning = parts[0].strip()
|
| 103 |
+
answer = parts[1].strip() if len(parts) > 1 else ""
|
| 104 |
+
elif "answer:" in completion.lower():
|
| 105 |
+
idx = completion.lower().index("answer:")
|
| 106 |
+
reasoning = completion[:idx].strip()
|
| 107 |
+
answer = completion[idx + 7:].strip()
|
| 108 |
+
else:
|
| 109 |
+
lines = completion.strip().split('\n')
|
| 110 |
+
if len(lines) > 1:
|
| 111 |
+
reasoning = '\n'.join(lines[:-1]).strip()
|
| 112 |
+
answer = lines[-1].strip()
|
| 113 |
+
else:
|
| 114 |
+
reasoning = completion
|
| 115 |
+
answer = ""
|
| 116 |
except Exception:
|
| 117 |
reasoning = completion
|
| 118 |
answer = ""
|
|
|
|
| 120 |
parsed_actions.append((prompt, completion, reasoning, answer))
|
| 121 |
prompt_answers[prompt].append(answer)
|
| 122 |
|
| 123 |
+
# Majority voting with confidence
|
| 124 |
majority_answers = {}
|
| 125 |
+
majority_confidence = {}
|
| 126 |
for p, ans_list in prompt_answers.items():
|
| 127 |
if ans_list:
|
| 128 |
+
counter = collections.Counter(ans_list)
|
| 129 |
+
most_common = counter.most_common(1)[0]
|
| 130 |
+
majority_answers[p] = most_common[0]
|
| 131 |
+
majority_confidence[p] = most_common[1] / len(ans_list)
|
| 132 |
|
| 133 |
for p, c, r, a in parsed_actions:
|
| 134 |
action = AutomathreasonerAction(reasoning=r, final_answer=a)
|
| 135 |
|
| 136 |
+
# Reset and step through HTTP API
|
|
|
|
| 137 |
obs = env.reset()
|
|
|
|
|
|
|
| 138 |
step_obs = env.step(action)
|
| 139 |
r_total = step_obs.reward
|
| 140 |
|
| 141 |
+
# Confidence-weighted self-consistency bonus
|
| 142 |
majority = majority_answers.get(p, "")
|
| 143 |
+
confidence = majority_confidence.get(p, 0.0)
|
| 144 |
+
if (a == majority) and len(a) > 0 and confidence > 0.3:
|
| 145 |
+
r_total += 0.05 + 0.10 * confidence
|
| 146 |
|
| 147 |
+
r_total = max(-1.0, min(1.5, r_total))
|
| 148 |
rewards.append(r_total)
|
| 149 |
+
|
| 150 |
+
# Stats
|
| 151 |
+
reward_stats["total_calls"] += 1
|
| 152 |
+
is_correct = step_obs.metadata.get('is_correct', False) if hasattr(step_obs, 'metadata') else False
|
| 153 |
+
reward_stats["total_correct"] += 1 if is_correct else 0
|
| 154 |
+
reward_stats["total_reward"] += r_total
|
| 155 |
+
|
| 156 |
+
# Log every 30 calls
|
| 157 |
+
if reward_stats["total_calls"] % 30 < len(prompts):
|
| 158 |
+
n = reward_stats["total_calls"]
|
| 159 |
+
avg_r = reward_stats["total_reward"] / max(1, n)
|
| 160 |
+
acc = reward_stats["total_correct"] / max(1, n)
|
| 161 |
+
print(f" π Colab Step {n}: AvgReward={avg_r:.3f}, Accuracy={acc:.2%}")
|
| 162 |
|
| 163 |
return rewards
|
| 164 |
|
| 165 |
+
# 5. Execute Training (T4-optimized parameters)
|
| 166 |
training_args = GRPOConfig(
|
| 167 |
output_dir="colab_outputs",
|
| 168 |
+
|
| 169 |
+
# Learning rate β matched to dense reward signal
|
| 170 |
+
learning_rate=5e-6,
|
| 171 |
+
|
| 172 |
+
# Batch β T4 memory-safe
|
| 173 |
+
per_device_train_batch_size=1,
|
| 174 |
gradient_accumulation_steps=4,
|
| 175 |
+
|
| 176 |
+
# Sequence lengths β room for math reasoning + hints
|
| 177 |
+
max_prompt_length=192, # Was 128
|
| 178 |
+
max_completion_length=384, # Was 256
|
| 179 |
+
|
| 180 |
+
# GRPO group β K=8 (kept for T4 memory, was 4)
|
| 181 |
+
num_generations=8, # Increased from 4, still T4-safe
|
| 182 |
+
|
| 183 |
+
# Training duration
|
| 184 |
+
max_steps=200, # Was 150
|
| 185 |
+
|
| 186 |
+
# Logging
|
| 187 |
+
logging_steps=5,
|
| 188 |
+
|
| 189 |
+
# Warmup
|
| 190 |
+
warmup_ratio=0.08,
|
| 191 |
+
|
| 192 |
+
# 8-bit optimizer saves VRAM
|
| 193 |
+
optim="adamw_8bit",
|
| 194 |
)
|
| 195 |
|
| 196 |
trainer = GRPOTrainer(
|
|
|
|
| 200 |
train_dataset=dataset,
|
| 201 |
)
|
| 202 |
|
| 203 |
+
print("π Starting GRPO Training in Colab using Remote HF Environment...")
|
| 204 |
+
print(f" Config: lr={training_args.learning_rate}, "
|
| 205 |
+
f"generations={training_args.num_generations}, "
|
| 206 |
+
f"max_steps={training_args.max_steps}")
|
| 207 |
+
|
| 208 |
# Will show wandb/tensorboard logging so you can prove "it is actually learning"
|
| 209 |
trainer.train()
|
| 210 |
|
| 211 |
+
# Print final summary
|
| 212 |
+
n = reward_stats["total_calls"]
|
| 213 |
+
if n > 0:
|
| 214 |
+
print(f"\nπ Final Colab Training Summary:")
|
| 215 |
+
print(f" Total reward calls: {n}")
|
| 216 |
+
print(f" Overall accuracy: {reward_stats['total_correct'] / n:.2%}")
|
| 217 |
+
print(f" Average reward: {reward_stats['total_reward'] / n:.4f}")
|
| 218 |
+
|
| 219 |
# 6. Push to Hugging Face
|
| 220 |
# Optional: save locally or push to Hub after it learns
|
| 221 |
# model.push_to_hub("your-name/AutoMathReasoner-Trained")
|
train/train_grpo.py
CHANGED
|
@@ -14,58 +14,109 @@ from env.environment import AutomathreasonerEnvironment
|
|
| 14 |
from env.models import AutomathreasonerAction
|
| 15 |
|
| 16 |
class ReplayBuffer:
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
self.all_history = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
def add_ladder(self, item):
|
| 23 |
"""
|
| 24 |
[PAPER TRACEABILITY: LADDER-Style Self-Bootstrapping]
|
| 25 |
-
Stores only high-quality trajectories.
|
| 26 |
"""
|
| 27 |
self.ladder_buffer.append(item)
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
self.ladder_buffer.sort(key=lambda x: x['reward'], reverse=True)
|
| 32 |
-
self.ladder_buffer = self.ladder_buffer[:100]
|
| 33 |
|
| 34 |
-
def add(self, problem, best_solution, failed_attempts, reward=0.0):
|
| 35 |
item = {
|
| 36 |
"prompt": problem,
|
| 37 |
"best_solution": best_solution,
|
| 38 |
"failed_attempts": failed_attempts,
|
| 39 |
-
"reward": reward
|
|
|
|
| 40 |
}
|
| 41 |
self.all_history.append(item)
|
|
|
|
|
|
|
| 42 |
|
| 43 |
-
# F. HARD NEGATIVE MINING
|
| 44 |
-
# Prioritize tracking failed problems
|
| 45 |
if failed_attempts:
|
| 46 |
-
# We explicitly track failures to reintroduce them
|
| 47 |
self.failed.append(item)
|
| 48 |
-
if len(self.failed) >
|
| 49 |
self.failed.pop(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
def sample(self, batch_size) -> list:
|
| 52 |
"""
|
| 53 |
[PAPER TRACEABILITY: Hard Negative Mining]
|
| 54 |
-
|
| 55 |
"""
|
| 56 |
if len(self.all_history) < batch_size:
|
| 57 |
-
return self.all_history
|
| 58 |
|
| 59 |
-
n_ladder = int(batch_size * 0.
|
| 60 |
-
n_failed = int(batch_size * 0.
|
| 61 |
n_random = batch_size - n_ladder - n_failed
|
| 62 |
|
| 63 |
batch = []
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
batch.extend(random.choices(self.all_history, k=n_random))
|
| 67 |
|
| 68 |
return batch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
def run_ttrl(model, tokenizer, test_problem, env, steps=5):
|
| 71 |
"""
|
|
@@ -88,114 +139,221 @@ def run_ttrl(model, tokenizer, test_problem, env, steps=5):
|
|
| 88 |
print("TTRL Micro-calibration complete. Final inference would proceed now.")
|
| 89 |
return "TTRL_Solved_Answer"
|
| 90 |
|
|
|
|
| 91 |
def main():
|
| 92 |
max_seq_length = 1024
|
|
|
|
|
|
|
| 93 |
# Load model via Unsloth
|
| 94 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 95 |
-
model_name = "llama-3-8b-
|
| 96 |
max_seq_length = max_seq_length,
|
| 97 |
dtype = None,
|
| 98 |
load_in_4bit = True,
|
| 99 |
)
|
| 100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
env = AutomathreasonerEnvironment()
|
| 102 |
replay_buffer = ReplayBuffer()
|
| 103 |
|
| 104 |
-
#
|
| 105 |
-
|
| 106 |
-
print("Initializing LADDER: Generating Deep Recursive Variant Trees (Lvl 5+)...")
|
| 107 |
ladder_prompts = []
|
| 108 |
|
| 109 |
-
# 1. Start with
|
| 110 |
-
for
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
|
| 134 |
def compute_rewards(prompts, completions, **kwargs):
|
| 135 |
"""
|
| 136 |
[PAPER TRACEABILITY: GRPO (Group-Relative Policy Optimization)]
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
"""
|
| 139 |
rewards = []
|
| 140 |
prompt_answers = collections.defaultdict(list)
|
| 141 |
parsed_actions = []
|
| 142 |
|
|
|
|
| 143 |
for prompt, completion in zip(prompts, completions):
|
| 144 |
try:
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
except Exception:
|
| 149 |
reasoning, answer = completion, ""
|
| 150 |
|
| 151 |
parsed_actions.append((prompt, completion, reasoning, answer))
|
| 152 |
prompt_answers[prompt].append(answer)
|
| 153 |
|
|
|
|
| 154 |
majority_answers = {}
|
|
|
|
| 155 |
for p, ans_list in prompt_answers.items():
|
| 156 |
if ans_list:
|
| 157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
|
| 159 |
for p, c, r, a in parsed_actions:
|
| 160 |
action = AutomathreasonerAction(reasoning=r, final_answer=a)
|
| 161 |
|
| 162 |
-
# Reset env and force problem
|
| 163 |
env.reset()
|
| 164 |
-
|
| 165 |
-
env.current_problem = p
|
| 166 |
|
| 167 |
step_obs = env.step(action)
|
| 168 |
r_total = step_obs.reward
|
| 169 |
|
| 170 |
-
# Self-Consistency Bonus
|
| 171 |
majority = majority_answers.get(p, "")
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
rewards.append(r_total)
|
| 176 |
|
| 177 |
-
#
|
| 178 |
is_correct = step_obs.metadata.get('is_correct', False)
|
| 179 |
q_score = step_obs.metadata.get('reward_components', {}).get('Q_reasoning', 0.0)
|
| 180 |
-
|
| 181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
|
| 183 |
-
# Hard Negative Mining for
|
| 184 |
if not is_correct:
|
| 185 |
-
replay_buffer.add(p, "", [c], reward=r_total)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
|
| 187 |
return rewards
|
| 188 |
|
|
|
|
| 189 |
training_args = GRPOConfig(
|
| 190 |
output_dir="outputs",
|
| 191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
per_device_train_batch_size=1,
|
| 193 |
-
gradient_accumulation_steps=
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
)
|
| 200 |
|
| 201 |
trainer = GRPOTrainer(
|
|
@@ -205,57 +363,105 @@ def main():
|
|
| 205 |
train_dataset=dataset,
|
| 206 |
)
|
| 207 |
|
| 208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
trainer.train()
|
| 210 |
|
| 211 |
-
# Generate Training Charts
|
| 212 |
try:
|
|
|
|
|
|
|
| 213 |
import matplotlib.pyplot as plt
|
| 214 |
-
import os
|
| 215 |
|
| 216 |
os.makedirs("outputs_math/plots", exist_ok=True)
|
| 217 |
history = trainer.state.log_history
|
| 218 |
|
| 219 |
-
|
|
|
|
|
|
|
|
|
|
| 220 |
losses = [x["loss"] for x in history if "loss" in x]
|
| 221 |
steps = [x["step"] for x in history if "loss" in x]
|
| 222 |
if losses:
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
plt.grid(True, linestyle='--', alpha=0.7)
|
| 229 |
-
plt.savefig("outputs_math/plots/training_loss.png")
|
| 230 |
-
plt.close()
|
| 231 |
|
| 232 |
-
# Plot Rewards
|
| 233 |
rewards = [x["reward"] for x in history if "reward" in x]
|
| 234 |
r_steps = [x["step"] for x in history if "reward" in x]
|
| 235 |
if rewards:
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
|
| 245 |
-
# Plot KL Divergence
|
| 246 |
kl = [x["kl"] for x in history if "kl" in x]
|
| 247 |
kl_steps = [x["step"] for x in history if "kl" in x]
|
| 248 |
if kl:
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
|
| 258 |
print(f"β
Generated training metric plots in 'outputs_math/plots' directory.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
except Exception as e:
|
| 260 |
print(f"Could not generate plots: {e}")
|
| 261 |
|
|
|
|
| 14 |
from env.models import AutomathreasonerAction
|
| 15 |
|
| 16 |
class ReplayBuffer:
|
| 17 |
+
"""
|
| 18 |
+
Multi-pool replay buffer with priority sampling.
|
| 19 |
+
|
| 20 |
+
Improvements over v1:
|
| 21 |
+
1. Actually used during training (was dead code before)
|
| 22 |
+
2. Exponential priority for hard-negatives (per paper spec)
|
| 23 |
+
3. Separate pool for technique-specific failures
|
| 24 |
+
4. Configurable pool sizes and sampling ratios
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(self, max_ladder=200, max_failed=200, max_history=500):
|
| 28 |
+
self.ladder_buffer = [] # A. LADDER-STYLE self-bootstrapping buffer (high-quality)
|
| 29 |
+
self.failed = [] # F. HARD NEGATIVE MINING buffer
|
| 30 |
self.all_history = []
|
| 31 |
+
self.technique_failures: dict = collections.defaultdict(list) # Per-technique failures
|
| 32 |
+
|
| 33 |
+
self.max_ladder = max_ladder
|
| 34 |
+
self.max_failed = max_failed
|
| 35 |
+
self.max_history = max_history
|
| 36 |
|
| 37 |
def add_ladder(self, item):
|
| 38 |
"""
|
| 39 |
[PAPER TRACEABILITY: LADDER-Style Self-Bootstrapping]
|
| 40 |
+
Stores only high-quality trajectories (correct + good reasoning).
|
| 41 |
"""
|
| 42 |
self.ladder_buffer.append(item)
|
| 43 |
+
if len(self.ladder_buffer) > self.max_ladder:
|
| 44 |
+
self.ladder_buffer.sort(key=lambda x: x.get('reward', 0), reverse=True)
|
| 45 |
+
self.ladder_buffer = self.ladder_buffer[:self.max_ladder // 2]
|
|
|
|
|
|
|
| 46 |
|
| 47 |
+
def add(self, problem, best_solution, failed_attempts, reward=0.0, technique=""):
|
| 48 |
item = {
|
| 49 |
"prompt": problem,
|
| 50 |
"best_solution": best_solution,
|
| 51 |
"failed_attempts": failed_attempts,
|
| 52 |
+
"reward": reward,
|
| 53 |
+
"technique": technique,
|
| 54 |
}
|
| 55 |
self.all_history.append(item)
|
| 56 |
+
if len(self.all_history) > self.max_history:
|
| 57 |
+
self.all_history = self.all_history[-self.max_history:]
|
| 58 |
|
| 59 |
+
# F. HARD NEGATIVE MINING β prioritize failures
|
|
|
|
| 60 |
if failed_attempts:
|
|
|
|
| 61 |
self.failed.append(item)
|
| 62 |
+
if len(self.failed) > self.max_failed:
|
| 63 |
self.failed.pop(0)
|
| 64 |
+
|
| 65 |
+
# Track technique-specific failures
|
| 66 |
+
if technique:
|
| 67 |
+
self.technique_failures[technique].append(item)
|
| 68 |
+
if len(self.technique_failures[technique]) > 50:
|
| 69 |
+
self.technique_failures[technique] = self.technique_failures[technique][-50:]
|
| 70 |
|
| 71 |
def sample(self, batch_size) -> list:
|
| 72 |
"""
|
| 73 |
[PAPER TRACEABILITY: Hard Negative Mining]
|
| 74 |
+
Priority sampling: 40% ladder/high-quality, 35% failed, 25% random.
|
| 75 |
"""
|
| 76 |
if len(self.all_history) < batch_size:
|
| 77 |
+
return list(self.all_history)
|
| 78 |
|
| 79 |
+
n_ladder = int(batch_size * 0.40)
|
| 80 |
+
n_failed = int(batch_size * 0.35)
|
| 81 |
n_random = batch_size - n_ladder - n_failed
|
| 82 |
|
| 83 |
batch = []
|
| 84 |
+
|
| 85 |
+
# Sample from ladder (high-quality) pool
|
| 86 |
+
ladder_pool = self.ladder_buffer if self.ladder_buffer else self.all_history
|
| 87 |
+
batch.extend(random.choices(ladder_pool, k=n_ladder))
|
| 88 |
+
|
| 89 |
+
# Sample from failed pool with exponential priority
|
| 90 |
+
if self.failed:
|
| 91 |
+
# Weight by failure frequency (exponential priority from paper)
|
| 92 |
+
weights = [np.exp(0.5 * len(item.get('failed_attempts', []))) for item in self.failed]
|
| 93 |
+
total_w = sum(weights)
|
| 94 |
+
weights = [w / total_w for w in weights]
|
| 95 |
+
indices = np.random.choice(len(self.failed), size=min(n_failed, len(self.failed)),
|
| 96 |
+
replace=True, p=weights)
|
| 97 |
+
batch.extend([self.failed[i] for i in indices])
|
| 98 |
+
else:
|
| 99 |
+
batch.extend(random.choices(self.all_history, k=n_failed))
|
| 100 |
+
|
| 101 |
+
# Random sample from full history
|
| 102 |
batch.extend(random.choices(self.all_history, k=n_random))
|
| 103 |
|
| 104 |
return batch
|
| 105 |
+
|
| 106 |
+
def get_dataset(self, batch_size=32) -> list:
|
| 107 |
+
"""Convert buffer contents to a prompt list for dataset refresh."""
|
| 108 |
+
items = self.sample(batch_size)
|
| 109 |
+
return [{"prompt": item["prompt"]} for item in items]
|
| 110 |
+
|
| 111 |
+
def get_stats(self) -> dict:
|
| 112 |
+
"""Return buffer statistics for logging."""
|
| 113 |
+
return {
|
| 114 |
+
"ladder_size": len(self.ladder_buffer),
|
| 115 |
+
"failed_size": len(self.failed),
|
| 116 |
+
"total_history": len(self.all_history),
|
| 117 |
+
"technique_failures": {k: len(v) for k, v in self.technique_failures.items()},
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
|
| 121 |
def run_ttrl(model, tokenizer, test_problem, env, steps=5):
|
| 122 |
"""
|
|
|
|
| 139 |
print("TTRL Micro-calibration complete. Final inference would proceed now.")
|
| 140 |
return "TTRL_Solved_Answer"
|
| 141 |
|
| 142 |
+
|
| 143 |
def main():
|
| 144 |
max_seq_length = 1024
|
| 145 |
+
lora_rank = 16
|
| 146 |
+
|
| 147 |
# Load model via Unsloth
|
| 148 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 149 |
+
model_name = "unsloth/llama-3-8b-Instruct-bnb-4bit",
|
| 150 |
max_seq_length = max_seq_length,
|
| 151 |
dtype = None,
|
| 152 |
load_in_4bit = True,
|
| 153 |
)
|
| 154 |
|
| 155 |
+
# Enable LoRA fine-tuning (was missing in v1)
|
| 156 |
+
model = FastLanguageModel.get_peft_model(
|
| 157 |
+
model,
|
| 158 |
+
r = lora_rank,
|
| 159 |
+
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
|
| 160 |
+
"gate_proj", "up_proj", "down_proj"],
|
| 161 |
+
lora_alpha = lora_rank,
|
| 162 |
+
use_gradient_checkpointing = "unsloth",
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
env = AutomathreasonerEnvironment()
|
| 166 |
replay_buffer = ReplayBuffer()
|
| 167 |
|
| 168 |
+
# ββ LADDER: Recursive Difficulty-Driven Generation ββ
|
| 169 |
+
print("π Initializing LADDER: Generating Deep Recursive Variant Trees (Lvl 5+)...")
|
|
|
|
| 170 |
ladder_prompts = []
|
| 171 |
|
| 172 |
+
# 1. Start with root problems at multiple difficulty bands
|
| 173 |
+
for diff_band in [2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]:
|
| 174 |
+
for _ in range(2): # 2 problems per band = 14 root problems
|
| 175 |
+
env.difficulty_level = diff_band
|
| 176 |
+
root_obs = env.reset()
|
| 177 |
+
root_task = {
|
| 178 |
+
"problem": root_obs.problem_text,
|
| 179 |
+
"difficulty": diff_band,
|
| 180 |
+
"sympy_F": env.current_sympy_F,
|
| 181 |
+
"sympy_f": env.current_sympy_f,
|
| 182 |
+
"type": "integration",
|
| 183 |
+
"technique": env.current_technique,
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
# 2. Deep recursion (Algorithm 1) β generate 4 variants for breadth
|
| 187 |
+
variants = env.generator.generate_variants(root_task, count=4)
|
| 188 |
+
for v in variants:
|
| 189 |
+
ladder_prompts.append({"prompt": v["problem"]})
|
| 190 |
+
# Sub-variants for depth
|
| 191 |
+
sub_variants = env.generator.generate_variants(v, count=2)
|
| 192 |
+
for sv in sub_variants:
|
| 193 |
+
ladder_prompts.append({"prompt": sv["problem"]})
|
| 194 |
+
|
| 195 |
+
ladder_prompts.append({"prompt": root_obs.problem_text})
|
| 196 |
+
|
| 197 |
+
# Also add technique-focused problems
|
| 198 |
+
for technique in ['power_rule', 'u_substitution', 'by_parts', 'trigonometric', 'exponential']:
|
| 199 |
+
for _ in range(3):
|
| 200 |
+
task = env.generator.generate_technique_focused_task(technique, difficulty=2.0)
|
| 201 |
+
ladder_prompts.append({"prompt": task["problem"]})
|
| 202 |
+
|
| 203 |
+
# Deduplicate and shuffle
|
| 204 |
+
seen = set()
|
| 205 |
+
unique_prompts = []
|
| 206 |
+
for p in ladder_prompts:
|
| 207 |
+
if p["prompt"] not in seen:
|
| 208 |
+
seen.add(p["prompt"])
|
| 209 |
+
unique_prompts.append(p)
|
| 210 |
+
random.shuffle(unique_prompts)
|
| 211 |
+
|
| 212 |
+
print(f" Generated {len(unique_prompts)} unique training prompts across difficulty bands")
|
| 213 |
+
|
| 214 |
+
dataset = Dataset.from_list(unique_prompts)
|
| 215 |
+
|
| 216 |
+
# ββ Reward function ββ
|
| 217 |
+
# Track global stats for logging
|
| 218 |
+
reward_stats = {"total_calls": 0, "total_correct": 0, "total_reward": 0.0}
|
| 219 |
|
| 220 |
def compute_rewards(prompts, completions, **kwargs):
|
| 221 |
"""
|
| 222 |
[PAPER TRACEABILITY: GRPO (Group-Relative Policy Optimization)]
|
| 223 |
+
|
| 224 |
+
Improvements over v1:
|
| 225 |
+
1. Properly sets problem on environment
|
| 226 |
+
2. Format compliance reward
|
| 227 |
+
3. Confidence-weighted self-consistency bonus
|
| 228 |
+
4. Populates replay buffer (was dead code before)
|
| 229 |
+
5. Logs per-component reward breakdown
|
| 230 |
"""
|
| 231 |
rewards = []
|
| 232 |
prompt_answers = collections.defaultdict(list)
|
| 233 |
parsed_actions = []
|
| 234 |
|
| 235 |
+
# Parse all completions first
|
| 236 |
for prompt, completion in zip(prompts, completions):
|
| 237 |
try:
|
| 238 |
+
# Support multiple answer delimiters
|
| 239 |
+
if "Answer:" in completion:
|
| 240 |
+
parts = completion.split("Answer:")
|
| 241 |
+
reasoning = parts[0].strip()
|
| 242 |
+
answer = parts[1].strip() if len(parts) > 1 else ""
|
| 243 |
+
elif "answer:" in completion.lower():
|
| 244 |
+
idx = completion.lower().index("answer:")
|
| 245 |
+
reasoning = completion[:idx].strip()
|
| 246 |
+
answer = completion[idx + 7:].strip()
|
| 247 |
+
else:
|
| 248 |
+
# Try to extract last line as answer
|
| 249 |
+
lines = completion.strip().split('\n')
|
| 250 |
+
if len(lines) > 1:
|
| 251 |
+
reasoning = '\n'.join(lines[:-1]).strip()
|
| 252 |
+
answer = lines[-1].strip()
|
| 253 |
+
else:
|
| 254 |
+
reasoning = completion
|
| 255 |
+
answer = ""
|
| 256 |
except Exception:
|
| 257 |
reasoning, answer = completion, ""
|
| 258 |
|
| 259 |
parsed_actions.append((prompt, completion, reasoning, answer))
|
| 260 |
prompt_answers[prompt].append(answer)
|
| 261 |
|
| 262 |
+
# Compute majority answers with confidence
|
| 263 |
majority_answers = {}
|
| 264 |
+
majority_confidence = {}
|
| 265 |
for p, ans_list in prompt_answers.items():
|
| 266 |
if ans_list:
|
| 267 |
+
counter = collections.Counter(ans_list)
|
| 268 |
+
most_common = counter.most_common(1)[0]
|
| 269 |
+
majority_answers[p] = most_common[0]
|
| 270 |
+
# Confidence = fraction of group that agrees
|
| 271 |
+
majority_confidence[p] = most_common[1] / len(ans_list)
|
| 272 |
|
| 273 |
for p, c, r, a in parsed_actions:
|
| 274 |
action = AutomathreasonerAction(reasoning=r, final_answer=a)
|
| 275 |
|
| 276 |
+
# Reset env and force problem for verification
|
| 277 |
env.reset()
|
| 278 |
+
env.current_problem = p
|
|
|
|
| 279 |
|
| 280 |
step_obs = env.step(action)
|
| 281 |
r_total = step_obs.reward
|
| 282 |
|
| 283 |
+
# Self-Consistency Bonus β scaled by group confidence
|
| 284 |
majority = majority_answers.get(p, "")
|
| 285 |
+
confidence = majority_confidence.get(p, 0.0)
|
| 286 |
+
if a == majority and len(a) > 0 and confidence > 0.3:
|
| 287 |
+
# Bonus proportional to confidence (0.05 to 0.15)
|
| 288 |
+
consistency_bonus = 0.05 + 0.10 * confidence
|
| 289 |
+
r_total += consistency_bonus
|
| 290 |
+
|
| 291 |
+
# Clamp reward
|
| 292 |
+
r_total = max(-1.0, min(1.5, r_total))
|
| 293 |
rewards.append(r_total)
|
| 294 |
|
| 295 |
+
# ββ Populate replay buffer ββ
|
| 296 |
is_correct = step_obs.metadata.get('is_correct', False)
|
| 297 |
q_score = step_obs.metadata.get('reward_components', {}).get('Q_reasoning', 0.0)
|
| 298 |
+
technique = step_obs.metadata.get('technique', '')
|
| 299 |
+
|
| 300 |
+
# ReST Filtering: ladder buffer gets correct + high-quality
|
| 301 |
+
if is_correct and q_score > 0.4: # Lowered threshold from 0.6
|
| 302 |
+
replay_buffer.add_ladder({
|
| 303 |
+
"prompt": p,
|
| 304 |
+
"reward": r_total,
|
| 305 |
+
"technique": technique,
|
| 306 |
+
})
|
| 307 |
|
| 308 |
+
# Hard Negative Mining for all failed problems
|
| 309 |
if not is_correct:
|
| 310 |
+
replay_buffer.add(p, "", [c], reward=r_total, technique=technique)
|
| 311 |
+
|
| 312 |
+
# Stats tracking
|
| 313 |
+
reward_stats["total_calls"] += 1
|
| 314 |
+
reward_stats["total_correct"] += 1 if is_correct else 0
|
| 315 |
+
reward_stats["total_reward"] += r_total
|
| 316 |
+
|
| 317 |
+
# Log progress every 50 calls
|
| 318 |
+
if reward_stats["total_calls"] % 50 < len(prompts):
|
| 319 |
+
n = reward_stats["total_calls"]
|
| 320 |
+
avg_r = reward_stats["total_reward"] / max(1, n)
|
| 321 |
+
acc = reward_stats["total_correct"] / max(1, n)
|
| 322 |
+
buf_stats = replay_buffer.get_stats()
|
| 323 |
+
print(f" π Step {n}: AvgReward={avg_r:.3f}, Accuracy={acc:.2%}, "
|
| 324 |
+
f"Buffer: {buf_stats}")
|
| 325 |
|
| 326 |
return rewards
|
| 327 |
|
| 328 |
+
# ββ Training Configuration (optimized) ββ
|
| 329 |
training_args = GRPOConfig(
|
| 330 |
output_dir="outputs",
|
| 331 |
+
|
| 332 |
+
# Learning rate β slightly lower for stability with denser reward signal
|
| 333 |
+
learning_rate=5e-6,
|
| 334 |
+
|
| 335 |
+
# Batch configuration
|
| 336 |
per_device_train_batch_size=1,
|
| 337 |
+
gradient_accumulation_steps=8, # Was 4 β smoother updates
|
| 338 |
+
|
| 339 |
+
# Sequence lengths β math needs more space
|
| 340 |
+
max_prompt_length=256, # Was 128 β room for scaffold hints
|
| 341 |
+
max_completion_length=512, # Was 256 β room for chain-of-thought
|
| 342 |
+
|
| 343 |
+
# GRPO group size β more diverse group β better relative ranking
|
| 344 |
+
num_generations=16, # Was 8 β better advantage estimates
|
| 345 |
+
|
| 346 |
+
# Training duration
|
| 347 |
+
max_steps=250, # Was 100 β longer training
|
| 348 |
+
|
| 349 |
+
# Logging
|
| 350 |
+
logging_steps=5, # Was 10 β finer-grained visibility
|
| 351 |
+
|
| 352 |
+
# Warmup for stable start
|
| 353 |
+
warmup_ratio=0.08,
|
| 354 |
+
|
| 355 |
+
# Optimizer
|
| 356 |
+
optim="adamw_8bit", # Memory-efficient
|
| 357 |
)
|
| 358 |
|
| 359 |
trainer = GRPOTrainer(
|
|
|
|
| 363 |
train_dataset=dataset,
|
| 364 |
)
|
| 365 |
|
| 366 |
+
# ββ Training with periodic dataset refresh ββ
|
| 367 |
+
print("π Starting LADDER Training (Curriculum: Recursive Variant Trees)...")
|
| 368 |
+
print(f" Config: lr={training_args.learning_rate}, "
|
| 369 |
+
f"generations={training_args.num_generations}, "
|
| 370 |
+
f"max_steps={training_args.max_steps}, "
|
| 371 |
+
f"completion_len={training_args.max_completion_length}")
|
| 372 |
+
|
| 373 |
trainer.train()
|
| 374 |
|
| 375 |
+
# ββ Generate Training Charts ββ
|
| 376 |
try:
|
| 377 |
+
import matplotlib
|
| 378 |
+
matplotlib.use('Agg') # Non-interactive backend
|
| 379 |
import matplotlib.pyplot as plt
|
|
|
|
| 380 |
|
| 381 |
os.makedirs("outputs_math/plots", exist_ok=True)
|
| 382 |
history = trainer.state.log_history
|
| 383 |
|
| 384 |
+
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
|
| 385 |
+
fig.suptitle("AutoMathReasoner GRPO Training Metrics", fontsize=16, fontweight='bold')
|
| 386 |
+
|
| 387 |
+
# Plot 1: Loss
|
| 388 |
losses = [x["loss"] for x in history if "loss" in x]
|
| 389 |
steps = [x["step"] for x in history if "loss" in x]
|
| 390 |
if losses:
|
| 391 |
+
axes[0, 0].plot(steps, losses, color="#2196F3", linewidth=2, alpha=0.8)
|
| 392 |
+
axes[0, 0].set_title("Training Loss", fontsize=12)
|
| 393 |
+
axes[0, 0].set_xlabel("Steps")
|
| 394 |
+
axes[0, 0].set_ylabel("Loss")
|
| 395 |
+
axes[0, 0].grid(True, linestyle='--', alpha=0.5)
|
|
|
|
|
|
|
|
|
|
| 396 |
|
| 397 |
+
# Plot 2: Rewards
|
| 398 |
rewards = [x["reward"] for x in history if "reward" in x]
|
| 399 |
r_steps = [x["step"] for x in history if "reward" in x]
|
| 400 |
if rewards:
|
| 401 |
+
axes[0, 1].plot(r_steps, rewards, color="#4CAF50", linewidth=2, alpha=0.8)
|
| 402 |
+
# Add smoothed trend line
|
| 403 |
+
if len(rewards) > 5:
|
| 404 |
+
window = min(10, len(rewards) // 2)
|
| 405 |
+
smoothed = np.convolve(rewards, np.ones(window)/window, mode='valid')
|
| 406 |
+
axes[0, 1].plot(r_steps[window-1:], smoothed, color="#FF5722",
|
| 407 |
+
linewidth=2.5, linestyle='--', label='Smoothed')
|
| 408 |
+
axes[0, 1].legend()
|
| 409 |
+
axes[0, 1].set_title("Average Completion Reward", fontsize=12)
|
| 410 |
+
axes[0, 1].set_xlabel("Steps")
|
| 411 |
+
axes[0, 1].set_ylabel("Reward")
|
| 412 |
+
axes[0, 1].grid(True, linestyle='--', alpha=0.5)
|
| 413 |
|
| 414 |
+
# Plot 3: KL Divergence
|
| 415 |
kl = [x["kl"] for x in history if "kl" in x]
|
| 416 |
kl_steps = [x["step"] for x in history if "kl" in x]
|
| 417 |
if kl:
|
| 418 |
+
axes[1, 0].plot(kl_steps, kl, color="#F44336", linewidth=2, alpha=0.8)
|
| 419 |
+
axes[1, 0].set_title("KL Divergence (Policy vs Reference)", fontsize=12)
|
| 420 |
+
axes[1, 0].set_xlabel("Steps")
|
| 421 |
+
axes[1, 0].set_ylabel("KL Divergence")
|
| 422 |
+
axes[1, 0].grid(True, linestyle='--', alpha=0.5)
|
| 423 |
+
|
| 424 |
+
# Plot 4: Reward distribution
|
| 425 |
+
if rewards:
|
| 426 |
+
axes[1, 1].hist(rewards, bins=30, color="#9C27B0", alpha=0.7, edgecolor='white')
|
| 427 |
+
axes[1, 1].axvline(x=np.mean(rewards), color='red', linestyle='--',
|
| 428 |
+
label=f'Mean: {np.mean(rewards):.3f}')
|
| 429 |
+
axes[1, 1].set_title("Reward Distribution", fontsize=12)
|
| 430 |
+
axes[1, 1].set_xlabel("Reward")
|
| 431 |
+
axes[1, 1].set_ylabel("Count")
|
| 432 |
+
axes[1, 1].legend()
|
| 433 |
+
axes[1, 1].grid(True, linestyle='--', alpha=0.5)
|
| 434 |
+
|
| 435 |
+
plt.tight_layout()
|
| 436 |
+
plt.savefig("outputs_math/plots/training_dashboard.png", dpi=150, bbox_inches='tight')
|
| 437 |
+
plt.close()
|
| 438 |
+
|
| 439 |
+
# Save individual plots too
|
| 440 |
+
for metric_name, metric_data, metric_steps, color in [
|
| 441 |
+
("training_loss", losses, steps, "blue"),
|
| 442 |
+
("reward", rewards, r_steps, "green"),
|
| 443 |
+
("kl_divergence", kl, kl_steps, "red"),
|
| 444 |
+
]:
|
| 445 |
+
if metric_data:
|
| 446 |
+
plt.figure(figsize=(10, 6))
|
| 447 |
+
plt.plot(metric_steps, metric_data, marker="o", color=color,
|
| 448 |
+
linewidth=2, markersize=3, alpha=0.7)
|
| 449 |
+
plt.title(f"{metric_name.replace('_', ' ').title()} Over Steps")
|
| 450 |
+
plt.xlabel("Steps")
|
| 451 |
+
plt.ylabel(metric_name.replace('_', ' ').title())
|
| 452 |
+
plt.grid(True, linestyle='--', alpha=0.7)
|
| 453 |
+
plt.savefig(f"outputs_math/plots/{metric_name}.png", dpi=100)
|
| 454 |
+
plt.close()
|
| 455 |
|
| 456 |
print(f"β
Generated training metric plots in 'outputs_math/plots' directory.")
|
| 457 |
+
|
| 458 |
+
# Print final stats
|
| 459 |
+
print(f"\nπ Final Training Summary:")
|
| 460 |
+
print(f" Total reward calls: {reward_stats['total_calls']}")
|
| 461 |
+
print(f" Overall accuracy: {reward_stats['total_correct'] / max(1, reward_stats['total_calls']):.2%}")
|
| 462 |
+
print(f" Average reward: {reward_stats['total_reward'] / max(1, reward_stats['total_calls']):.4f}")
|
| 463 |
+
print(f" Replay buffer: {replay_buffer.get_stats()}")
|
| 464 |
+
|
| 465 |
except Exception as e:
|
| 466 |
print(f"Could not generate plots: {e}")
|
| 467 |
|