# AutoMathReasoner — env/environment.py
import logging
import random
from collections import deque
from typing import Dict, Any, List
from uuid import uuid4

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
try:
from .models import AutomathreasonerAction, AutomathreasonerObservation
from .generator import TaskGenerationEngine
from .verifier import VerifierSystem
from .rewards import RewardSystem
except ImportError:
from env.models import AutomathreasonerAction, AutomathreasonerObservation
from env.generator import TaskGenerationEngine
from env.verifier import VerifierSystem
from env.rewards import RewardSystem
logger = logging.getLogger(__name__)
class AutomathreasonerEnvironment(Environment):
    """
    OpenEnv-compliant RL environment for symbolic calculus (indefinite integration).

    Key improvements over v1:
    1. Faster, smoother curriculum progression (Scaf-GRPO inspired)
    2. Scaffold hints injected after repeated failures (breaks "learning cliff")
    3. Increased max_steps (3 → 5) for more within-episode learning
    4. Consecutive failure tracking for adaptive scaffolding
    5. Technique-aware problem generation
    6. Rolling accuracy uses weighted window for responsiveness

    References:
    - Scaf-GRPO (arxiv, 2025): hierarchical hints for hard problems
    - GRPO-λ: credit assignment for faster convergence
    - arxiv:2408.10215: reward shaping best practices
    """
    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        """Initialize subsystems (generator/verifier/rewards) and all curriculum,
        per-problem, and failure-tracking state."""
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self.generator = TaskGenerationEngine()
        self.verifier = VerifierSystem()
        self.reward_system = RewardSystem(max_len=2000)
        # --- Curriculum tracking (improved) ---
        self.difficulty_level = 1.5  # Start slightly easier to build momentum
        self.rolling_results = deque(maxlen=10)  # Shorter window (was 20) → faster adaptation
        self.rolling_rewards = deque(maxlen=10)  # Track reward magnitudes too
        # --- Current problem state ---
        self.current_problem = ""
        self.current_solution = ""
        self.current_sympy_f = None   # Integration ground truth (integrand)
        self.current_sympy_F = None   # Antiderivative (for structural comparison)
        self.current_technique = ""   # Detected integration technique
        self.current_scaffold_hints: Dict[str, str] = {}  # Progressive hints keyed by 'hint_level_N'
        self.times_seen_problem = 0
        self.history: List[Dict[str, Any]] = []
        self.max_steps = 5  # Increased from 3 → more within-episode learning
        # --- Failure tracking for scaffolding ---
        self.consecutive_failures = 0
        self.total_episodes = 0
        self.total_correct = 0
        # --- Technique performance tracking ---
        self.technique_performance: Dict[str, List[float]] = {}

    def _update_curriculum(self) -> None:
        """
        Update ``difficulty_level`` based on rolling accuracy and reward.

        Improved:
        - Shorter rolling window (10 vs 20) for faster response
        - Smoother progression: advance proportional to accuracy
        - Lower thresholds to maintain momentum
        - Technique-aware adaptation

        No-op until at least 3 episodes have been recorded, so early noise
        does not move the curriculum.
        """
        if len(self.rolling_results) < 3:
            return
        accuracy = sum(self.rolling_results) / len(self.rolling_results)
        # rolling_rewards may lag rolling_results; guard against an empty deque.
        avg_reward = sum(self.rolling_rewards) / len(self.rolling_rewards) if self.rolling_rewards else 0
        # Advance: accuracy > 0.50 (was 0.7)
        if accuracy > 0.50:
            # Proportional advancement — faster when doing well
            advance = 0.2 + 0.3 * accuracy  # Range: 0.35 to 0.5
            self.difficulty_level += advance
            logger.info(f"📈 Curriculum UP: Accuracy={accuracy:.2f}, "
                        f"AvgReward={avg_reward:.3f}, NewDiff={self.difficulty_level:.1f}")
        # Partial advance: decent reward signal even without full correctness
        elif avg_reward > 0.35 and accuracy > 0.25:
            self.difficulty_level += 0.1
            logger.info(f"📊 Curriculum MICRO-UP: Accuracy={accuracy:.2f}, "
                        f"AvgReward={avg_reward:.3f}, NewDiff={self.difficulty_level:.1f}")
        # Retreat: accuracy < 0.20 (was 0.6); never drop below the base level 1.0
        elif accuracy < 0.20:
            self.difficulty_level = max(1.0, self.difficulty_level - 0.3)
            logger.info(f"📉 Curriculum DOWN: Accuracy={accuracy:.2f}, "
                        f"NewDiff={self.difficulty_level:.1f}")

    def _get_scaffold_observation(self) -> str:
        """
        Generate a scaffold hint string based on consecutive failures.

        Implements Scaf-GRPO progressive hint injection:
        - 0-1 failures: no hint
        - 2 failures: technique hint (level 1)
        - 3 failures: first step hint (level 2)
        - 4+ failures: detailed hint (level 3)

        Returns the formatted hint (with leading newline) or "" when no hint
        applies or the matching hint level is missing from the task.
        """
        if self.consecutive_failures < 2 or not self.current_scaffold_hints:
            return ""
        if self.consecutive_failures == 2:
            hint = self.current_scaffold_hints.get('hint_level_1', '')
            if hint:
                return f"\n[Hint: {hint}]"
        elif self.consecutive_failures == 3:
            hint = self.current_scaffold_hints.get('hint_level_2', '')
            if hint:
                return f"\n[Hint: {hint}]"
        else:  # 4+
            hint = self.current_scaffold_hints.get('hint_level_3', '')
            if hint:
                return f"\n[Strong Hint: {hint}]"
        return ""

    def _update_technique_performance(self, technique: str, correct: bool) -> None:
        """Track per-technique performance (1.0/0.0 outcomes) for adaptive curriculum."""
        if technique not in self.technique_performance:
            self.technique_performance[technique] = []
        self.technique_performance[technique].append(1.0 if correct else 0.0)
        # Keep last 20 results per technique
        if len(self.technique_performance[technique]) > 20:
            self.technique_performance[technique] = self.technique_performance[technique][-20:]

    def _get_weakest_technique(self) -> str:
        """Return the technique with the lowest accuracy over its recorded window.

        Techniques with fewer than 3 recorded attempts are ignored; returns ""
        when no technique qualifies.
        """
        worst_technique = ""
        worst_accuracy = 1.0
        for technique, results in self.technique_performance.items():
            if len(results) >= 3:
                acc = sum(results) / len(results)
                if acc < worst_accuracy:
                    worst_accuracy = acc
                    worst_technique = technique
        return worst_technique

    def reset(self) -> AutomathreasonerObservation:
        """Reset environment to a new problem with scaffold support.

        Updates the curriculum from rolling stats, occasionally targets the
        weakest technique, generates a fresh task, and clears per-episode state.
        """
        self._update_curriculum()
        self.total_episodes += 1
        self._state = State(episode_id=str(uuid4()), step_count=0)
        # Occasionally target the weakest technique (20% of the time),
        # at slightly reduced difficulty, once enough episodes have run.
        weakest = self._get_weakest_technique()
        if weakest and random.random() < 0.2 and self.total_episodes > 10:
            task = self.generator.generate_technique_focused_task(
                weakest, difficulty=max(1.0, self.difficulty_level - 0.5)
            )
            logger.info(f"🎯 Targeting weak technique: {weakest}")
        else:
            task = self.generator.generate_task(target_difficulty_band=self.difficulty_level)
        self.current_problem = task['problem']
        self.current_solution = task['solution']
        self.current_sympy_f = task.get('sympy_f')
        self.current_sympy_F = task.get('sympy_F')
        self.current_technique = task.get('technique', '')
        self.current_scaffold_hints = task.get('scaffold_hints', {})
        self.times_seen_problem = 0
        self.history = []
        self.consecutive_failures = 0
        # BUGFIX: the previous version appended a scaffold hint here, but since
        # consecutive_failures was just reset to 0 above, _get_scaffold_observation()
        # always returned "" — the branch was dead code. Scaffold hints are only
        # injected in step(), after repeated failures on the current problem.
        return AutomathreasonerObservation(
            problem_text=self.current_problem,
            difficulty_level=self.difficulty_level,
            history=[],
            reward=0.0,
            done=False,
            metadata={
                "technique": self.current_technique,
                "episode_number": self.total_episodes,
            }
        )

    def step(self, action: AutomathreasonerAction) -> AutomathreasonerObservation:  # type: ignore[override]
        """Verify the agent's attempt, compute shaped reward, and advance the episode.

        An episode ends when the answer is "correct enough" (correctness >= 0.7)
        or when max_steps is reached; terminal outcomes feed the curriculum.
        """
        self._state.step_count += 1
        # Verification with graduated correctness and technique awareness:
        # c = correctness, q = reasoning quality, p_sup = process supervision,
        # r_ref = reflection score.
        c, q, p_sup, r_ref = self.verifier.verify(
            action.reasoning,
            action.final_answer,
            self.current_solution,
            sympy_f=self.current_sympy_f,
            technique_hint=self.current_technique,
        )
        # Reward computation — all 7 components + format compliance
        action_str = f"{action.reasoning} \n {action.final_answer}"
        total_r, components = self.reward_system.compute_reward(
            correctness=c,
            reasoning_quality=q,
            process_supervision=p_sup,
            reflection_score=r_ref,
            action_str=action_str,
            final_answer=action.final_answer,
            history=self.history,
            times_seen_problem=self.times_seen_problem,
            reasoning=action.reasoning,
        )
        self.times_seen_problem += 1
        # Update history — store BOTH keys for backward compatibility
        attempt = {
            "prediction": action.final_answer,
            "final_answer": action.final_answer,  # BUGFIX: also store as final_answer
            "correctness": c,
            "reward": total_r,
        }
        self.history.append(attempt)
        # Only the last 3 attempts are surfaced to the agent.
        obs_history = self.history[-3:]
        # Correctness check — graduated (threshold at 0.7 for "correct enough")
        is_correct = (c >= 0.7)
        done = is_correct or self._state.step_count >= self.max_steps
        if is_correct:
            self.consecutive_failures = 0
            self.total_correct += 1
        else:
            self.consecutive_failures += 1
        # Curriculum statistics are updated once per episode, at termination.
        if done:
            self.rolling_results.append(1 if is_correct else 0)
            self.rolling_rewards.append(total_r)
            self._update_technique_performance(self.current_technique, is_correct)
        # Build problem text with scaffold hints for next attempt (if not done)
        problem_text = self.current_problem
        if not done:
            scaffold = self._get_scaffold_observation()
            if scaffold:
                problem_text += scaffold
        return AutomathreasonerObservation(
            problem_text=problem_text,
            difficulty_level=self.difficulty_level,
            history=obs_history,
            reward=total_r,
            done=done,
            metadata={
                "reward_components": components,
                # Ground truth stays hidden until the episode terminates.
                "ground_truth": self.current_solution if done else "HIDDEN",
                "is_correct": is_correct,
                "technique": self.current_technique,
                "consecutive_failures": self.consecutive_failures,
                "correctness_score": c,
                "curriculum_difficulty": self.difficulty_level,
                "episode_number": self.total_episodes,
            }
        )

    @property
    def state(self) -> State:
        """Current OpenEnv ``State`` (episode id + step count)."""
        return self._state