import logging
import random
from uuid import uuid4
from collections import deque
from typing import Dict, Any, List

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State

try:
    # Package-relative imports when installed as a package; fall back to the
    # flat layout used when running from the repo root.
    from .models import AutomathreasonerAction, AutomathreasonerObservation
    from .generator import TaskGenerationEngine
    from .verifier import VerifierSystem
    from .rewards import RewardSystem
except ImportError:
    from env.models import AutomathreasonerAction, AutomathreasonerObservation
    from env.generator import TaskGenerationEngine
    from env.verifier import VerifierSystem
    from env.rewards import RewardSystem

logger = logging.getLogger(__name__)
class AutomathreasonerEnvironment(Environment):
    """
    OpenEnv-compliant RL environment for symbolic calculus (indefinite integration).

    Key improvements over v1:
    1. Faster, smoother curriculum progression (Scaf-GRPO inspired)
    2. Scaffold hints injected after repeated failures (breaks the "learning cliff")
    3. Increased max_steps (3 → 5) for more within-episode learning
    4. Consecutive-failure tracking for adaptive scaffolding
    5. Technique-aware problem generation
    6. Shorter rolling-accuracy window (10 vs. 20) for responsiveness

    References:
    - Scaf-GRPO (arXiv, 2025): hierarchical hints for hard problems
    - GRPO-λ: credit assignment for faster convergence
    - arXiv:2408.10215: reward-shaping best practices
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True
    def __init__(self):
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self.generator = TaskGenerationEngine()
        self.verifier = VerifierSystem()
        self.reward_system = RewardSystem(max_len=2000)

        # --- Curriculum tracking (improved) ---
        self.difficulty_level = 1.5  # Start slightly easier to build momentum
        self.rolling_results = deque(maxlen=10)  # Shorter window (was 20) → faster adaptation
        self.rolling_rewards = deque(maxlen=10)  # Track reward magnitudes too

        # --- Current problem state ---
        self.current_problem = ""
        self.current_solution = ""
        self.current_sympy_f = None  # Integration ground truth (integrand)
        self.current_sympy_F = None  # Antiderivative (for structural comparison)
        self.current_technique = ""  # Detected integration technique
        self.current_scaffold_hints = {}  # Progressive hints
        self.times_seen_problem = 0
        self.history: List[Dict[str, Any]] = []
        self.max_steps = 5  # Increased from 3 → more within-episode learning

        # --- Failure tracking for scaffolding ---
        self.consecutive_failures = 0
        self.total_episodes = 0
        self.total_correct = 0

        # --- Technique performance tracking ---
        self.technique_performance: Dict[str, List[float]] = {}
    def _update_curriculum(self):
        """
        Update difficulty based on rolling accuracy.

        Improved:
        - Shorter rolling window (10 vs 20) for faster response
        - Smoother progression: advance proportional to accuracy
        - Lower thresholds to maintain momentum
        - Technique-aware adaptation
        """
        if len(self.rolling_results) < 3:
            return

        accuracy = sum(self.rolling_results) / len(self.rolling_results)
        avg_reward = sum(self.rolling_rewards) / len(self.rolling_rewards) if self.rolling_rewards else 0.0

        # Advance: accuracy > 0.50 (was 0.7)
        if accuracy > 0.50:
            # Proportional advancement — faster when doing well
            advance = 0.2 + 0.3 * accuracy  # Range: 0.35 to 0.5
            self.difficulty_level += advance
            logger.info(f"📈 Curriculum UP: Accuracy={accuracy:.2f}, "
                        f"AvgReward={avg_reward:.3f}, NewDiff={self.difficulty_level:.1f}")
        # Partial advance: decent reward signal even without full correctness
        elif avg_reward > 0.35 and accuracy > 0.25:
            self.difficulty_level += 0.1
            logger.info(f"📊 Curriculum MICRO-UP: Accuracy={accuracy:.2f}, "
                        f"AvgReward={avg_reward:.3f}, NewDiff={self.difficulty_level:.1f}")
        # Retreat: accuracy < 0.20 (was 0.6)
        elif accuracy < 0.20:
            self.difficulty_level = max(1.0, self.difficulty_level - 0.3)
            logger.info(f"📉 Curriculum DOWN: Accuracy={accuracy:.2f}, "
                        f"NewDiff={self.difficulty_level:.1f}")
    def _get_scaffold_observation(self) -> str:
        """
        Generate a scaffold hint based on consecutive failures.

        Implements Scaf-GRPO progressive hint injection:
        - 0-1 failures: no hint
        - 2 failures: technique hint (level 1)
        - 3 failures: first-step hint (level 2)
        - 4+ failures: detailed hint (level 3)
        """
        if self.consecutive_failures < 2 or not self.current_scaffold_hints:
            return ""
        if self.consecutive_failures == 2:
            hint = self.current_scaffold_hints.get('hint_level_1', '')
            if hint:
                return f"\n[Hint: {hint}]"
        elif self.consecutive_failures == 3:
            hint = self.current_scaffold_hints.get('hint_level_2', '')
            if hint:
                return f"\n[Hint: {hint}]"
        else:  # 4+ failures
            hint = self.current_scaffold_hints.get('hint_level_3', '')
            if hint:
                return f"\n[Strong Hint: {hint}]"
        return ""
    def _update_technique_performance(self, technique: str, correct: bool):
        """Track per-technique performance for the adaptive curriculum."""
        if technique not in self.technique_performance:
            self.technique_performance[technique] = []
        self.technique_performance[technique].append(1.0 if correct else 0.0)
        # Keep only the last 20 results per technique
        if len(self.technique_performance[technique]) > 20:
            self.technique_performance[technique] = self.technique_performance[technique][-20:]
    def _get_weakest_technique(self) -> str:
        """Find the technique the model struggles with most."""
        worst_technique = ""
        worst_accuracy = 1.0
        for technique, results in self.technique_performance.items():
            if len(results) >= 3:
                acc = sum(results) / len(results)
                if acc < worst_accuracy:
                    worst_accuracy = acc
                    worst_technique = technique
        return worst_technique
    def reset(self) -> AutomathreasonerObservation:
        """Reset the environment to a new problem with scaffold support."""
        self._update_curriculum()
        self.total_episodes += 1
        self._state = State(episode_id=str(uuid4()), step_count=0)

        # Occasionally target the weakest technique (20% of the time)
        weakest = self._get_weakest_technique()
        if weakest and random.random() < 0.2 and self.total_episodes > 10:
            task = self.generator.generate_technique_focused_task(
                weakest, difficulty=max(1.0, self.difficulty_level - 0.5)
            )
            logger.info(f"🎯 Targeting weak technique: {weakest}")
        else:
            task = self.generator.generate_task(target_difficulty_band=self.difficulty_level)

        self.current_problem = task['problem']
        self.current_solution = task['solution']
        self.current_sympy_f = task.get('sympy_f')
        self.current_sympy_F = task.get('sympy_F')
        self.current_technique = task.get('technique', '')
        self.current_scaffold_hints = task.get('scaffold_hints', {})
        self.times_seen_problem = 0
        self.history = []

        # Build problem text with an optional scaffold hint. BUGFIX: compute
        # the scaffold from the failure streak carried over from previous
        # episodes BEFORE clearing it; zeroing the counter first made this
        # hint unreachable.
        problem_text = self.current_problem
        scaffold = self._get_scaffold_observation()
        if scaffold:
            problem_text += scaffold
        self.consecutive_failures = 0

        return AutomathreasonerObservation(
            problem_text=problem_text,
            difficulty_level=self.difficulty_level,
            history=[],
            reward=0.0,
            done=False,
            metadata={
                "technique": self.current_technique,
                "episode_number": self.total_episodes,
            },
        )
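    # Illustrative shape of the task dict consumed above. The keys mirror the
    # lookups in reset(); the sample values are hypothetical:
    #   {
    #       "problem": "Evaluate: ∫ x*exp(x) dx",
    #       "solution": "(x - 1)*exp(x) + C",
    #       "sympy_f": x*exp(x),        # SymPy integrand (ground truth)
    #       "sympy_F": (x - 1)*exp(x),  # SymPy antiderivative
    #       "technique": "integration_by_parts",
    #       "scaffold_hints": {"hint_level_1": "...", "hint_level_2": "...", "hint_level_3": "..."},
    #   }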
    def step(self, action: AutomathreasonerAction) -> AutomathreasonerObservation:  # type: ignore[override]
        self._state.step_count += 1

        # Verification with graduated correctness and technique awareness
        c, q, p_sup, r_ref = self.verifier.verify(
            action.reasoning,
            action.final_answer,
            self.current_solution,
            sympy_f=self.current_sympy_f,
            technique_hint=self.current_technique,
        )

        # Reward computation — all 7 components + format compliance
        action_str = f"{action.reasoning} \n {action.final_answer}"
        total_r, components = self.reward_system.compute_reward(
            correctness=c,
            reasoning_quality=q,
            process_supervision=p_sup,
            reflection_score=r_ref,
            action_str=action_str,
            final_answer=action.final_answer,
            history=self.history,
            times_seen_problem=self.times_seen_problem,
            reasoning=action.reasoning,
        )
        self.times_seen_problem += 1

        # Update history — store BOTH keys for backward compatibility
        attempt = {
            "prediction": action.final_answer,
            "final_answer": action.final_answer,  # BUGFIX: also store as final_answer
            "correctness": c,
            "reward": total_r,
        }
        self.history.append(attempt)
        obs_history = self.history[-3:]

        # Correctness check — graduated (threshold at 0.7 for "correct enough")
        is_correct = (c >= 0.7)
        done = is_correct or self._state.step_count >= self.max_steps

        if is_correct:
            self.consecutive_failures = 0
            self.total_correct += 1
        else:
            self.consecutive_failures += 1

        if done:
            self.rolling_results.append(1 if is_correct else 0)
            self.rolling_rewards.append(total_r)
            self._update_technique_performance(self.current_technique, is_correct)

        # Build problem text with scaffold hints for the next attempt (if not done)
        problem_text = self.current_problem
        if not done:
            scaffold = self._get_scaffold_observation()
            if scaffold:
                problem_text += scaffold

        return AutomathreasonerObservation(
            problem_text=problem_text,
            difficulty_level=self.difficulty_level,
            history=obs_history,
            reward=total_r,
            done=done,
            metadata={
                "reward_components": components,
                "ground_truth": self.current_solution if done else "HIDDEN",
                "is_correct": is_correct,
                "technique": self.current_technique,
                "consecutive_failures": self.consecutive_failures,
                "correctness_score": c,
                "curriculum_difficulty": self.difficulty_level,
                "episode_number": self.total_episodes,
            },
        )
    def state(self) -> State:
        return self._state
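
# Minimal smoke test (an illustrative sketch, not part of the environment API).
# It assumes AutomathreasonerAction accepts `reasoning` and `final_answer`
# keyword arguments, matching the attribute access in step(); the sample
# answer is arbitrary and will almost certainly be scored incorrect.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    env = AutomathreasonerEnvironment()
    obs = env.reset()
    print(f"Problem: {obs.problem_text} (difficulty {obs.difficulty_level:.1f})")

    action = AutomathreasonerAction(reasoning="Integrate term by term.", final_answer="x + C")
    obs = env.step(action)
    print(f"Reward: {obs.reward:.3f}, done: {obs.done}, correct: {obs.metadata['is_correct']}")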