# AutoMathReasoner / tests / test_env.py
# (Hugging Face viewer residue preserved as a comment: uploaded by
# HarshitShri026, commit message "push", commit 973cd6f)
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from env.generator import TaskGenerationEngine
from env.verifier import VerifierSystem
from env.rewards import RewardSystem
from env.environment import AutomathreasonerEnvironment
from env.models import AutomathreasonerAction
def test_generator():
    """Exercise task generation, variant creation, and technique-focused tasks."""
    engine = TaskGenerationEngine()

    # Every generated task must carry the full metadata set, and its
    # technique must come from the known catalogue.
    required_fields = ("problem", "solution", "difficulty", "technique", "scaffold_hints")
    known_techniques = {
        'power_rule', 'u_substitution', 'by_parts',
        'trigonometric', 'exponential', 'logarithmic',
    }
    for diff in (1.0, 3.0, 5.0):
        task = engine.generate_task(target_difficulty_band=diff)
        for field in required_fields:
            assert field in task
        assert task["technique"] in known_techniques
        print(f" ✓ Difficulty {diff}: technique={task['technique']}, problem={task['problem'][:60]}...")

    # Variants derived from a base task keep the core fields.
    base_task = engine.generate_task(target_difficulty_band=4.0)
    variants = engine.generate_variants(base_task, count=3)
    assert len(variants) > 0
    for variant in variants:
        assert "problem" in variant
        assert "technique" in variant
    print(f" ✓ Generated {len(variants)} variants")

    # Technique-focused generation must honour the requested technique.
    for tech in ('power_rule', 'u_substitution', 'by_parts'):
        focused = engine.generate_technique_focused_task(tech, difficulty=2.0)
        assert focused["technique"] == tech
        print(f" ✓ Technique-focused: {tech}")
def test_verifier():
    """Cover each verification channel exposed by VerifierSystem."""
    checker = VerifierSystem()

    # Exact match should tolerate surrounding whitespace.
    assert checker.check_exact_match("42", "42")
    assert checker.check_exact_match(" 42 ", "42")
    print(" ✓ Exact match")

    # Numeric comparison accepts values within tolerance and rejects
    # genuinely different ones.
    assert checker.check_numeric_tolerance("3.14159", "3.1415")
    assert not checker.check_numeric_tolerance("4.1415", "3.1415")
    print(" ✓ Numeric tolerance")

    # Expressions evaluated via Python must match the expected result.
    assert checker.check_python_execution("2 + 2", "4")
    print(" ✓ Python execution")

    # verify() returns four components: correctness, quality, process, reflection.
    c, q, p, r = checker.verify("Step 1: Because 2 + 2 is 4. Therefore the answer is 4.", "4", "4")
    assert c == 1.0
    assert q > 0.0
    print(f" ✓ Full verify: C={c}, Q={q:.3f}, P={p:.3f}, R={r:.3f}")

    # Graduated correctness: partial credit for structurally similar answers.
    score = checker.check_structural_similarity("x**3", "2*x**3")
    assert score > 0.0  # same structure should earn partial credit
    print(f" ✓ Structural similarity: {score:.2f}")

    # Technique keywords in the reasoning should be recognised.
    tech_score = checker.check_technique_recognition(
        "Let u = x^2, then du = 2x dx. By substitution we get...",
        "u_substitution",
    )
    assert tech_score > 0.5
    print(f" ✓ Technique recognition: {tech_score:.2f}")

    # Well-structured step-by-step reasoning must score above degenerate output.
    p_good = checker.check_process_supervision(
        "Step 1: Identify the integrand. Step 2: Apply the power rule. Therefore x^3/3 + C."
    )
    p_bad = checker.check_process_supervision("so = 42")
    assert p_good > p_bad
    print(f" ✓ Process supervision: good={p_good:.2f}, bad={p_bad:.2f}")
def test_rewards():
    """Check diversity scoring, format compliance, the full reward breakdown,
    and trivial-output detection in RewardSystem."""
    rewards = RewardSystem(max_len=1000)

    # Repeating an answer already in history is penalised (-1.0).
    prior = [{"final_answer": "42"}]
    d = rewards.compute_diversity("42", prior)
    assert d == -1.0
    print(f" ✓ Diversity repeat penalty: {d}")

    # The legacy 'prediction' history key must still be honoured.
    prior_legacy = [{"prediction": "42"}]
    d2 = rewards.compute_diversity("42", prior_legacy)
    assert d2 == -1.0
    print(f" ✓ Diversity backward compat: {d2}")

    # A never-seen answer earns the full diversity bonus (+1.0).
    d3 = rewards.compute_diversity("99", prior)
    assert d3 == 1.0
    print(f" ✓ Diversity unique bonus: {d3}")

    # Well-formatted output (steps + answer) should score above 0.5.
    f = rewards.compute_format_compliance(
        "Step 1: Apply power rule.\nAnswer: x^2/2",
        "Step 1: Apply power rule.",
        "x^2/2"
    )
    assert f > 0.5
    print(f" ✓ Format compliance: {f:.2f}")

    # Full reward computation with the complete keyword signature.
    r, comps = rewards.compute_reward(
        correctness=1.0,
        reasoning_quality=0.8,
        process_supervision=0.5,
        reflection_score=0.0,
        action_str="Step 1: Apply power rule. Step 2: Simplify. Answer: x^2/2",
        final_answer="x^2/2",
        history=[],
        times_seen_problem=0,
        reasoning="Step 1: Apply power rule. Step 2: Simplify.",
    )
    assert r > 0.0
    assert "C_correctness" in comps
    assert "F_format" in comps
    assert comps["F_format"] > 0  # Format compliance should be non-zero
    print(f" ✓ Full reward: {r:.3f}, components: {len(comps)} fields")

    # All eight tracked components must be present in the breakdown.
    expected_keys = ("C_correctness", "Q_reasoning", "P_process_supervision",
                     "R_reflection", "D_diversity", "E_efficiency",
                     "X_exploration", "F_format")
    for key in expected_keys:
        assert key in comps, f"Missing component: {key}"
    print(f" ✓ All {len(expected_keys)} reward components present")

    # Degenerate outputs (too short / repeated chars) are flagged; real math is not.
    assert rewards.detect_trivial_output("a")
    assert rewards.detect_trivial_output("aaaaaaaaaaaaa")
    assert not rewards.detect_trivial_output("x^2 + 2x + 1")
    print(" ✓ Trivial output detection")
def test_environment_step():
    """Run one reset/step cycle and validate the observation contract."""
    env = AutomathreasonerEnvironment()

    # reset() must yield a non-empty problem, a positive difficulty, and
    # an empty history.
    obs = env.reset()
    assert obs.problem_text != ""
    assert obs.difficulty_level > 0
    assert len(obs.history) == 0
    print(f" ✓ Reset: difficulty={obs.difficulty_level}, problem={obs.problem_text[:60]}...")

    # The observation metadata must expose the task's technique.
    assert "technique" in obs.metadata
    print(f" ✓ Technique metadata: {obs.metadata['technique']}")

    # Step with a dummy structured action and inspect the result.
    dummy = AutomathreasonerAction(
        reasoning="Step 1: I identify the integrand. Step 2: Applying the power rule.",
        final_answer="x^2/2"
    )
    obs_after = env.step(dummy)
    assert obs_after.reward is not None
    assert len(obs_after.history) == 1
    assert "reward_components" in obs_after.metadata
    assert "correctness_score" in obs_after.metadata
    print(f" ✓ Step: reward={obs_after.reward:.3f}, "
          f"correct={obs_after.metadata['is_correct']}, "
          f"C={obs_after.metadata['correctness_score']:.2f}")

    # History entries carry both the new and the legacy answer keys.
    assert "prediction" in obs_after.history[0]
    assert "final_answer" in obs_after.history[0]
    print(" ✓ History backward compatibility")
def test_curriculum_progression():
    """Test that curriculum actually advances with good performance."""
    env = AutomathreasonerEnvironment()
    initial_diff = env.difficulty_level

    # Seed the rolling windows with five correct, well-rewarded episodes,
    # then trigger a curriculum update.
    env.rolling_results.extend([1] * 5)
    env.rolling_rewards.extend([0.7] * 5)
    env._update_curriculum()

    assert env.difficulty_level > initial_diff, (
        f"Curriculum should advance: {initial_diff} -> {env.difficulty_level}"
    )
    print(f" ✓ Curriculum advanced: {initial_diff} -> {env.difficulty_level:.1f}")
def test_scaffold_hints():
    """Test that scaffold hints are generated after failures."""
    env = AutomathreasonerEnvironment()
    env.reset()

    # With zero failures the scaffold observation is empty.
    env.consecutive_failures = 0
    assert env._get_scaffold_observation() == ""

    # Install a three-level hint ladder, then check escalation.
    env.current_scaffold_hints = {
        'hint_level_1': 'Try u-substitution',
        'hint_level_2': 'Let u = x^2',
        'hint_level_3': 'The answer starts with sin(x^2)',
    }

    # Two failures -> level-1 hint.
    env.consecutive_failures = 2
    level1 = env._get_scaffold_observation()
    assert "Hint" in level1
    assert "u-substitution" in level1

    # Three failures -> level-2 hint.
    env.consecutive_failures = 3
    level2 = env._get_scaffold_observation()
    assert "u = x^2" in level2

    # Four or more failures -> strongest hint.
    env.consecutive_failures = 4
    level3 = env._get_scaffold_observation()
    assert "Strong Hint" in level3

    print(" ✓ Scaffold hints: level 1, 2, 3 all working")
def test_graduated_correctness_flow():
    """End-to-end test: partial credit flows through the whole system.

    Fix: the original bound the reset observation to an unused local
    (``obs = env.reset()``); the binding is dropped since only the
    post-step observation is inspected.
    """
    env = AutomathreasonerEnvironment()
    env.reset()
    # Submit a plausible but wrong math answer
    action = AutomathreasonerAction(
        reasoning="Step 1: I apply the power rule. Step 2: I integrate term by term. Therefore the answer is:",
        final_answer="x**2 + x"  # Almost certainly wrong, but parseable math
    )
    obs_after = env.step(action)
    c_score = obs_after.metadata.get('correctness_score', 0)
    # Should get SOME partial credit (> 0) for parseable math with right techniques
    print(f" ✓ Graduated correctness: C={c_score:.2f}, reward={obs_after.reward:.3f}")
    # Reward should be positive even when wrong (format + reasoning + partial credit)
    assert obs_after.reward > 0.0, f"Expected positive reward for structured wrong answer, got {obs_after.reward}"
    print(f" ✓ Positive reward for structured wrong answer: {obs_after.reward:.3f}")
if __name__ == "__main__":
    # Run the full suite in order, printing the same banner/output the
    # original hand-unrolled script produced.
    banner = "=" * 60
    print(banner)
    print("AutoMathReasoner Test Suite (v2 - Optimized)")
    print(banner)
    suite = [
        ("test_generator", test_generator),
        ("test_verifier", test_verifier),
        ("test_rewards", test_rewards),
        ("test_environment_step", test_environment_step),
        ("test_curriculum_progression", test_curriculum_progression),
        ("test_scaffold_hints", test_scaffold_hints),
        ("test_graduated_correctness_flow", test_graduated_correctness_flow),
    ]
    for test_name, test_fn in suite:
        print(f"\n[TEST] {test_name}")
        test_fn()
    print("\n" + banner)
    print("[OK] ALL TESTS PASSED")
    print(banner)