# AutoMathReasoner/env/rewards.py
import math
from typing import Dict, Any, List, Tuple
class RewardSystem:
"""
Dense, multi-component reward system for mathematical RL training.
Key improvements over v1:
1. All 7 reward components now contribute to the final score
2. Partial credit support (continuous C ∈ [0,1] from verifier)
3. Fixed history key mismatch (was breaking diversity detection)
4. Adaptive efficiency curve that doesn't over-penalize reasonable lengths
5. Removed random noise (adds variance without useful signal)
6. Added format compliance reward for structured output
Reward equation:
R = α·C + β·Q + γ·P + δ·R_ref + η·D_norm + ζ·E_norm + λ·X + μ·F_fmt
Weights: α=0.30, β=0.12, γ=0.10, δ=0.05, η=0.13, ζ=0.08, λ=0.07, μ=0.15
Sum = 1.0
References:
- arxiv:2408.10215 (Reward shaping for RL convergence)
- arxiv:2601.19100 (Reward engineering for software/code tasks)
- DeepSeek-R1 GRPO (graduated correctness)
- GRPO-λ (credit assignment)
"""
# Reward component weights (sum to 1.0)
W_CORRECTNESS = 0.30 # α: Primary — correctness drives learning
W_REASONING = 0.12 # β: Reasoning quality
W_PROCESS = 0.10 # γ: Step-by-step process supervision
W_REFLECTION = 0.05 # δ: Self-correction behavior
W_DIVERSITY = 0.13 # η: Answer diversity (prevents repetition)
W_EFFICIENCY = 0.08 # ζ: Token efficiency
W_EXPLORATION = 0.07 # λ: Exploration bonus
W_FORMAT = 0.15 # μ: Format compliance (model must learn structure)
def __init__(self, max_len: int = 1000):
self.max_len = max_len
def compute_diversity(self, current_answer: str, history: List[Dict[str, Any]]) -> float:
"""
D = diversity (difference from past attempts).
Graduated penalty instead of binary:
- Exact repeat: -1.0 (steep penalty)
- Similar to a past answer: -0.3
- Unique: +1.0
"""
if not history:
return 1.0
cur_ans_clean = current_answer.strip().lower()
if not cur_ans_clean:
return 0.0 # Empty answer gets no diversity credit
for attempt in history:
# BUGFIX: check both 'final_answer' and 'prediction' keys for compatibility
prev_ans = attempt.get('final_answer', attempt.get('prediction', '')).strip().lower()
if prev_ans == cur_ans_clean:
return -1.0 # Exact repeat — strong penalty
# Check for near-duplicates (edit distance heuristic)
if prev_ans and cur_ans_clean:
# Simple character overlap ratio
overlap = sum(1 for a, b in zip(prev_ans, cur_ans_clean) if a == b)
max_len = max(len(prev_ans), len(cur_ans_clean))
if max_len > 0 and overlap / max_len > 0.85:
return -0.3 # Near-duplicate — moderate penalty
return 1.0
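    # Illustrative values (they follow directly from the rules above; the
    # history entries shown are hypothetical):
    #   compute_diversity("x = 4", [{"final_answer": "x = 4"}])                 -> -1.0  exact repeat
    #   compute_diversity("answer is 125", [{"final_answer": "answer is 124"}]) -> -0.3  near-duplicate
    #   compute_diversity("y = 7", [{"final_answer": "x = 4"}])                 ->  1.0  unique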
def compute_efficiency(self, action_string: str) -> float:
"""
E = efficiency. Adaptive Gaussian penalty curve.
Improved: wider optimal zone (30-120 tokens) to avoid penalizing
legitimate mathematical reasoning that naturally needs more space.
        E ∈ [-1.0, 0.0] (always a penalty or neutral, never a bonus)
"""
        approx_tokens = len(action_string) / 4.0  # rough heuristic: ~4 characters per token
optimal_center = 80.0 # Wider center for math
optimal_width = 60.0 # Generous width
# Gentle Gaussian — penalizes only extreme lengths
ratio = (approx_tokens - optimal_center) / optimal_width
e = math.exp(-(ratio ** 2)) - 1.0
# Additional penalty for very long outputs (anti-rambling)
if approx_tokens > 300:
e -= 0.3 * (approx_tokens - 300) / 300
return max(-1.0, e)
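    # Illustrative penalties (approximate, using the ~4 chars/token heuristic):
    #   ~320-char output  (~80 tokens, at the optimal center) -> e ≈  0.0
    #   ~4-char output    (~1 token, far too terse)           -> e ≈ -0.82
    #   ~2000-char output (~500 tokens, rambling)             -> e  = -1.0 (clamped)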
def compute_exploration_bonus(self, action_string: str, times_seen: int) -> float:
"""
        [PAPER TRACEABILITY: Sec. G, Exploration via Entropy Bonus]
X = (entropy_bonus) / sqrt(1 + times_seen_problem)
Improved with better entropy estimation using word-level diversity.
"""
length = len(action_string)
if length == 0:
return 0.0
# Character-level entropy
unique_ratio = len(set(action_string)) / length
char_entropy = math.log1p(unique_ratio)
# Word-level diversity bonus (rewards varied vocabulary)
words = action_string.lower().split()
if words:
unique_word_ratio = len(set(words)) / len(words)
word_entropy = math.log1p(unique_word_ratio)
else:
word_entropy = 0.0
combined = 0.6 * char_entropy + 0.4 * word_entropy
return combined / math.sqrt(1.0 + times_seen)
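    # Illustrative values for "solve for x: x = 4" (approximate):
    #   times_seen = 0 -> bonus ≈ 0.58  (fresh problem, full bonus)
    #   times_seen = 3 -> bonus ≈ 0.29  (decays as 1 / sqrt(1 + times_seen))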
def compute_format_compliance(self, action_str: str, reasoning: str, final_answer: str) -> float:
"""
Format compliance reward — teaches the model to output structured responses.
Rewards:
- Having both reasoning and answer sections
- Using mathematical notation
- Proper structure (reasoning before answer)
F ∈ [0, 1]
"""
score = 0.0
# Has non-empty reasoning
if reasoning and len(reasoning.strip()) > 10:
score += 0.3
# Has non-empty final answer
if final_answer and len(final_answer.strip()) > 0:
score += 0.3
# Answer contains mathematical content
math_indicators = ['x', '=', '+', '-', '*', '/', '^', 'sin', 'cos', 'exp', 'log', '(']
math_count = sum(1 for m in math_indicators if m in final_answer.lower())
if math_count >= 2:
score += 0.2
elif math_count >= 1:
score += 0.1
# Reasoning contains structured steps
if any(marker in reasoning.lower() for marker in ['step', 'first', 'then', 'therefore', '=']):
score += 0.2
return min(1.0, score)
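    # Illustrative values (hypothetical inputs, scored by the rules above):
    #   reasoning="Step 1: isolate x, then divide by 2.", final_answer="x = 4" -> 1.0
    #   reasoning="",                                      final_answer="yes"  -> 0.3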
def detect_trivial_output(self, action_string: str) -> bool:
"""Anti-reward hacking: detect trivial constant outputs"""
# If the output is just a single character repeated or very low entropy
if len(action_string) < 2:
return True
unique_chars = len(set(action_string))
if unique_chars < 3 and len(action_string) > 10:
return True
# Detect repetitive patterns
if len(action_string) > 20:
# Check if a short pattern is repeated
for plen in range(1, 6):
pattern = action_string[:plen]
if action_string == pattern * (len(action_string) // plen) + pattern[:len(action_string) % plen]:
return True
return False
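    # Illustrative checks:
    #   detect_trivial_output("7")                        -> True  (too short)
    #   detect_trivial_output("aaaaaaaaaaaa")             -> True  (low character variety)
    #   detect_trivial_output("abcabcabcabcabcabcabcabc") -> True  (repeated short pattern)
    #   detect_trivial_output("x = 4, because 2x = 8")    -> False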
def compute_reward(self,
correctness: float,
reasoning_quality: float,
process_supervision: float,
reflection_score: float,
action_str: str,
final_answer: str,
history: List[Dict[str, Any]],
times_seen_problem: int,
reasoning: str = "") -> Tuple[float, Dict[str, float]]:
"""
Dense composite reward using ALL 7 components + format compliance.
R = α·C + β·Q_norm + γ·P_norm + δ·R_norm + η·D_norm + ζ·E_norm + λ·X + μ·F_fmt
All components are normalized to [0, 1] before weighting.
        Final reward ∈ [0, 1], except a flat -0.5 penalty when a trivial output is detected.
"""
if self.detect_trivial_output(action_str):
components = {
"total_reward": -0.5,
"C_correctness": 0.0, "Q_reasoning": 0.0,
"P_process_supervision": 0.0, "R_reflection": 0.0,
"D_diversity": 0.0, "E_efficiency": -1.0,
"X_exploration": 0.0, "F_format": 0.0,
}
return -0.5, components
# --- Raw component computation ---
c = correctness # Already ∈ [0, 1] with graduated scoring
q = reasoning_quality
d = self.compute_diversity(final_answer, history)
e = self.compute_efficiency(action_str)
x = self.compute_exploration_bonus(action_str, times_seen_problem)
f_fmt = self.compute_format_compliance(action_str, reasoning, final_answer)
# If repeated answer, reduce correctness credit (anti-hacking)
if d < -0.5:
c = c * 0.3 # Steep discount but not full zeroing
# --- Normalize all components to [0, 1] ---
q_norm = min(1.0, max(0.0, math.tanh(q)))
p_norm = (process_supervision + 1.0) / 2.0 # [-1, 1] → [0, 1]
r_norm = (reflection_score + 1.0) / 2.0 # [-1, 1] → [0, 1]
d_norm = (d + 1.0) / 2.0 # [-1, 1] → [0, 1]
        e_norm = min(1.0, max(0.0, e + 1.0))  # [-1, 0] → [0, 1]
x_norm = min(1.0, max(0.0, x))
f_norm = min(1.0, max(0.0, f_fmt))
# --- Weighted composite ---
total_r = (
self.W_CORRECTNESS * c +
self.W_REASONING * q_norm +
self.W_PROCESS * p_norm +
self.W_REFLECTION * r_norm +
self.W_DIVERSITY * d_norm +
self.W_EFFICIENCY * e_norm +
self.W_EXPLORATION * x_norm +
self.W_FORMAT * f_norm
)
# Clamp to [0, 1]
total_r = min(1.0, max(0.0, total_r))
components = {
"total_reward": total_r,
"C_correctness": c,
"Q_reasoning": q_norm,
"P_process_supervision": process_supervision,
"R_reflection": reflection_score,
"D_diversity": d,
"E_efficiency": e,
"X_exploration": x,
"F_format": f_fmt,
# Weighted contributions (for debugging)
"_w_C": self.W_CORRECTNESS * c,
"_w_Q": self.W_REASONING * q_norm,
"_w_P": self.W_PROCESS * p_norm,
"_w_R": self.W_REFLECTION * r_norm,
"_w_D": self.W_DIVERSITY * d_norm,
"_w_E": self.W_EFFICIENCY * e_norm,
"_w_X": self.W_EXPLORATION * x_norm,
"_w_F": self.W_FORMAT * f_norm,
}
return total_r, components
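# Minimal usage sketch. The correctness / reasoning-quality / process /
# reflection scores would normally come from the environment's verifier and
# judge models; the values and the sample solution below are hypothetical.
if __name__ == "__main__":
    rs = RewardSystem()
    action = "Step 1: 2x + 3 = 11, so 2x = 8. Step 2: divide by 2, therefore x = 4."
    total, parts = rs.compute_reward(
        correctness=1.0,           # verifier: answer is exactly right
        reasoning_quality=0.8,     # judge score before tanh normalization
        process_supervision=0.5,   # in [-1, 1]
        reflection_score=0.0,      # in [-1, 1]
        action_str=action,
        final_answer="x = 4",
        history=[],                # first attempt on this problem
        times_seen_problem=0,
        reasoning=action,
    )
    print(f"total reward: {total:.3f}")
    for name, value in parts.items():
        print(f"  {name}: {value:+.3f}")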