import logging
import random
from uuid import uuid4
from collections import deque
from typing import Dict, Any, List

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State

# Support both package-relative and flat (script-style) import layouts
try:
    from .models import AutomathreasonerAction, AutomathreasonerObservation
    from .generator import TaskGenerationEngine
    from .verifier import VerifierSystem
    from .rewards import RewardSystem
except ImportError:
    from env.models import AutomathreasonerAction, AutomathreasonerObservation
    from env.generator import TaskGenerationEngine
    from env.verifier import VerifierSystem
    from env.rewards import RewardSystem

logger = logging.getLogger(__name__)

class AutomathreasonerEnvironment(Environment):
    """
    OpenEnv-compliant RL environment for symbolic calculus (indefinite integration).
    
    Key improvements over v1:
    1. Faster, smoother curriculum progression (Scaf-GRPO inspired)
    2. Scaffold hints injected after repeated failures (breaks "learning cliff")
    3. Increased max_steps (3 → 5) for more within-episode learning
    4. Consecutive failure tracking for adaptive scaffolding
    5. Technique-aware problem generation
    6. Rolling accuracy uses a short 10-episode window for responsiveness
    
    References:
        - Scaf-GRPO (arXiv, 2025): hierarchical hints for hard problems
        - GRPO-λ: credit assignment for faster convergence
        - arXiv:2408.10215: reward shaping best practices
    """
    
    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self.generator = TaskGenerationEngine()
        self.verifier = VerifierSystem()
        self.reward_system = RewardSystem(max_len=2000)
        
        # --- Curriculum tracking (improved) ---
        self.difficulty_level = 1.5          # Start slightly easier to build momentum
        self.rolling_results = deque(maxlen=10)  # Shorter window (was 20) → faster adaptation
        self.rolling_rewards = deque(maxlen=10)   # Track reward magnitudes too
        
        # --- Current problem state ---
        self.current_problem = ""
        self.current_solution = ""
        self.current_sympy_f = None       # Integration ground truth (integrand)
        self.current_sympy_F = None       # Antiderivative (for structural comparison)
        self.current_technique = ""       # Detected integration technique
        self.current_scaffold_hints = {}  # Progressive hints
        self.times_seen_problem = 0
        self.history: List[Dict[str, Any]] = []
        self.max_steps = 5                # Increased from 3 to allow more within-episode learning
        
        # --- Failure tracking for scaffolding ---
        self.consecutive_failures = 0
        self.total_episodes = 0
        self.total_correct = 0
        
        # --- Technique performance tracking ---
        self.technique_performance: Dict[str, List[float]] = {}

    def _update_curriculum(self):
        """
        Update difficulty based on rolling accuracy.
        
        Improved:
        - Shorter rolling window (10 vs 20) for faster response
        - Smoother progression: advance proportional to accuracy
        - Lower thresholds to maintain momentum
        - Technique-aware adaptation
        """
        if len(self.rolling_results) < 3:
            return
            
        accuracy = sum(self.rolling_results) / len(self.rolling_results)
        avg_reward = sum(self.rolling_rewards) / len(self.rolling_rewards) if self.rolling_rewards else 0
        
        # Advance: accuracy > 0.50 (was 0.7)
        if accuracy > 0.50:
            # Proportional advancement: faster when doing well
            advance = 0.2 + 0.3 * accuracy  # Range: 0.35 to 0.5 (since accuracy > 0.50 here)
            self.difficulty_level += advance
            logger.info(f"πŸ“ˆ Curriculum UP: Accuracy={accuracy:.2f}, "
                       f"AvgReward={avg_reward:.3f}, NewDiff={self.difficulty_level:.1f}")
        
        # Partial advance: decent reward signal even without full correctness
        elif avg_reward > 0.35 and accuracy > 0.25:
            self.difficulty_level += 0.1
            logger.info(f"πŸ“Š Curriculum MICRO-UP: Accuracy={accuracy:.2f}, "
                       f"AvgReward={avg_reward:.3f}, NewDiff={self.difficulty_level:.1f}")
        
        # Retreat: accuracy < 0.20 (was 0.6)
        elif accuracy < 0.20:
            self.difficulty_level = max(1.0, self.difficulty_level - 0.3)
            logger.info(f"πŸ“‰ Curriculum DOWN: Accuracy={accuracy:.2f}, "
                       f"NewDiff={self.difficulty_level:.1f}")
    
    def _get_scaffold_observation(self) -> str:
        """
        Generate scaffold hint based on consecutive failures.
        Implements Scaf-GRPO progressive hint injection.
        
        - 0-1 failures: no hint
        - 2 failures: technique hint (level 1)
        - 3 failures: first step hint (level 2)  
        - 4+ failures: detailed hint (level 3)
        """
        if self.consecutive_failures < 2 or not self.current_scaffold_hints:
            return ""
        
        if self.consecutive_failures == 2:
            hint = self.current_scaffold_hints.get('hint_level_1', '')
            if hint:
                return f"\n[Hint: {hint}]"
        
        elif self.consecutive_failures == 3:
            hint = self.current_scaffold_hints.get('hint_level_2', '')
            if hint:
                return f"\n[Hint: {hint}]"
        
        else:  # 4+
            hint = self.current_scaffold_hints.get('hint_level_3', '')
            if hint:
                return f"\n[Strong Hint: {hint}]"
        
        return ""
    
    def _update_technique_performance(self, technique: str, correct: bool):
        """Track per-technique performance for adaptive curriculum."""
        if technique not in self.technique_performance:
            self.technique_performance[technique] = []
        
        self.technique_performance[technique].append(1.0 if correct else 0.0)
        
        # Keep last 20 results per technique
        if len(self.technique_performance[technique]) > 20:
            self.technique_performance[technique] = self.technique_performance[technique][-20:]
    
    def _get_weakest_technique(self) -> str:
        """Find the technique the model struggles with most."""
        worst_technique = ""
        worst_accuracy = 1.0
        
        for technique, results in self.technique_performance.items():
            if len(results) >= 3:
                acc = sum(results) / len(results)
                if acc < worst_accuracy:
                    worst_accuracy = acc
                    worst_technique = technique
        
        return worst_technique

    def reset(self) -> AutomathreasonerObservation:
        """Reset environment to a new problem with scaffold support."""
        self._update_curriculum()
        self.total_episodes += 1
        
        self._state = State(episode_id=str(uuid4()), step_count=0)
        
        # Occasionally target the weakest technique (20% of the time)
        weakest = self._get_weakest_technique()
        if weakest and random.random() < 0.2 and self.total_episodes > 10:
            task = self.generator.generate_technique_focused_task(
                weakest, difficulty=max(1.0, self.difficulty_level - 0.5)
            )
            logger.info(f"🎯 Targeting weak technique: {weakest}")
        else:
            task = self.generator.generate_task(target_difficulty_band=self.difficulty_level)
        
        self.current_problem = task['problem']
        self.current_solution = task['solution']
        self.current_sympy_f = task.get('sympy_f')
        self.current_sympy_F = task.get('sympy_F')
        self.current_technique = task.get('technique', '')
        self.current_scaffold_hints = task.get('scaffold_hints', {})
        self.times_seen_problem = 0
        self.history = []
        self.consecutive_failures = 0
        
        # Build problem text with optional scaffold hint. On a fresh problem
        # this yields no hint (consecutive_failures was just reset to 0);
        # hints surface in observations returned by step() after failures.
        problem_text = self.current_problem
        scaffold = self._get_scaffold_observation()
        if scaffold:
            problem_text += scaffold
        
        return AutomathreasonerObservation(
            problem_text=problem_text,
            difficulty_level=self.difficulty_level,
            history=[],
            reward=0.0,
            done=False,
            metadata={
                "technique": self.current_technique,
                "episode_number": self.total_episodes,
            }
        )

    def step(self, action: AutomathreasonerAction) -> AutomathreasonerObservation:  # type: ignore[override]
        self._state.step_count += 1
        
        # Verification with graduated correctness and technique awareness
        c, q, p_sup, r_ref = self.verifier.verify(
            action.reasoning, 
            action.final_answer, 
            self.current_solution,
            sympy_f=self.current_sympy_f,
            technique_hint=self.current_technique,
        )
        
        # Reward computation: all 7 components + format compliance
        action_str = f"{action.reasoning} \n {action.final_answer}"
        total_r, components = self.reward_system.compute_reward(
            correctness=c,
            reasoning_quality=q,
            process_supervision=p_sup,
            reflection_score=r_ref,
            action_str=action_str,
            final_answer=action.final_answer,
            history=self.history,
            times_seen_problem=self.times_seen_problem,
            reasoning=action.reasoning,
        )
        
        self.times_seen_problem += 1
        
        # Update history: store BOTH keys for backward compatibility
        attempt = {
            "prediction": action.final_answer,
            "final_answer": action.final_answer,  # BUGFIX: also store as final_answer
            "correctness": c,
            "reward": total_r,
        }
        self.history.append(attempt)
        obs_history = self.history[-3:]
        
        # Correctness check: graduated (threshold at 0.7 for "correct enough")
        is_correct = (c >= 0.7)
        done = is_correct or self._state.step_count >= self.max_steps
        
        if is_correct:
            self.consecutive_failures = 0
            self.total_correct += 1
        else:
            self.consecutive_failures += 1
        
        if done:
            self.rolling_results.append(1 if is_correct else 0)
            self.rolling_rewards.append(total_r)
            self._update_technique_performance(self.current_technique, is_correct)
        
        # Build problem text with scaffold hints for next attempt (if not done)
        problem_text = self.current_problem
        if not done:
            scaffold = self._get_scaffold_observation()
            if scaffold:
                problem_text += scaffold
            
        return AutomathreasonerObservation(
            problem_text=problem_text,
            difficulty_level=self.difficulty_level,
            history=obs_history,
            reward=total_r,
            done=done,
            metadata={
                "reward_components": components,
                "ground_truth": self.current_solution if done else "HIDDEN",
                "is_correct": is_correct,
                "technique": self.current_technique,
                "consecutive_failures": self.consecutive_failures,
                "correctness_score": c,
                "curriculum_difficulty": self.difficulty_level,
                "episode_number": self.total_episodes,
            }
        )

    @property
    def state(self) -> State:
        return self._state
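

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative, not part of the environment):
    # drives one episode with a fixed placeholder answer. It assumes
    # AutomathreasonerAction accepts `reasoning` and `final_answer` keyword
    # arguments, matching how step() reads those fields above; the canned
    # answer is almost certainly wrong, so the episode should terminate
    # after max_steps attempts, with scaffold hints appearing along the way.
    logging.basicConfig(level=logging.INFO)
    env = AutomathreasonerEnvironment()
    obs = env.reset()
    print(f"Problem: {obs.problem_text} (difficulty {obs.difficulty_level:.1f})")
    while not obs.done:
        action = AutomathreasonerAction(
            reasoning="Try a direct antiderivative of the integrand.",
            final_answer="x**2/2 + C",  # placeholder guess
        )
        obs = env.step(action)
        print(f"step={env.state.step_count} reward={obs.reward:.3f} "
              f"correct={obs.metadata.get('is_correct')} done={obs.done}")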