import math
from typing import Dict, Any, List, Tuple


class RewardSystem:
    """
    Dense, multi-component reward system for mathematical RL training.

    Key improvements over v1:
    1. All 7 original reward components now contribute to the final score
    2. Partial credit support (continuous C ∈ [0, 1] from verifier)
    3. Fixed history key mismatch (was breaking diversity detection)
    4. Adaptive efficiency curve that doesn't over-penalize reasonable lengths
    5. Removed random noise (adds variance without useful signal)
    6. Added format compliance reward for structured output

    Reward equation:
        R = α·C + β·Q + γ·P + δ·R_ref + η·D + ζ·E + λ·X + μ·F_fmt
    where every component except C is normalized to [0, 1] before weighting.

    Weights: α=0.30, β=0.12, γ=0.10, δ=0.05, η=0.13, ζ=0.08, λ=0.07, μ=0.15
    Sum = 1.0

    References:
    - arxiv:2408.10215 (Reward shaping for RL convergence)
    - arxiv:2601.19100 (Reward engineering for software/code tasks)
    - DeepSeek-R1 GRPO (graduated correctness)
    - GRPO-λ (credit assignment)
    """

    # Reward component weights (sum to 1.0)
    W_CORRECTNESS = 0.30  # α: Primary — correctness drives learning
    W_REASONING = 0.12    # β: Reasoning quality
    W_PROCESS = 0.10      # γ: Step-by-step process supervision
    W_REFLECTION = 0.05   # δ: Self-correction behavior
    W_DIVERSITY = 0.13    # η: Answer diversity (prevents repetition)
    W_EFFICIENCY = 0.08   # ζ: Token efficiency
    W_EXPLORATION = 0.07  # λ: Exploration bonus
    W_FORMAT = 0.15       # μ: Format compliance (model must learn structure)

    def __init__(self, max_len: int = 1000):
        # Maximum expected response length (not used directly by the reward components)
        self.max_len = max_len

    def compute_diversity(self, current_answer: str,
                          history: List[Dict[str, Any]]) -> float:
        """
        D = diversity (difference from past attempts).

        Graduated penalty instead of binary:
        - Exact repeat: -1.0 (steep penalty)
        - Similar to a past answer: -0.3
        - Unique: +1.0
        """
        if not history:
            return 1.0

        cur_ans_clean = current_answer.strip().lower()
        if not cur_ans_clean:
            return 0.0  # Empty answer gets no diversity credit

        for attempt in history:
            # BUGFIX: check both 'final_answer' and 'prediction' keys for compatibility
            prev_ans = attempt.get('final_answer', attempt.get('prediction', '')).strip().lower()

            if prev_ans == cur_ans_clean:
                return -1.0  # Exact repeat — strong penalty

            # Check for near-duplicates (edit distance heuristic)
            if prev_ans and cur_ans_clean:
                # Simple positional character overlap ratio
                overlap = sum(1 for a, b in zip(prev_ans, cur_ans_clean) if a == b)
                max_len = max(len(prev_ans), len(cur_ans_clean))
                if max_len > 0 and overlap / max_len > 0.85:
                    return -0.3  # Near-duplicate — moderate penalty

        return 1.0

    def compute_efficiency(self, action_string: str) -> float:
        """
        E = efficiency. Adaptive Gaussian penalty curve.

        Improved: wider optimal zone (roughly 30-120 tokens) to avoid penalizing
        legitimate mathematical reasoning that naturally needs more space.

        E ∈ [-1.0, 0.0] (always a penalty or neutral, never a bonus)
        """
        approx_tokens = len(action_string) / 4.0

        optimal_center = 80.0  # Center of the optimal zone, in approximate tokens
        optimal_width = 60.0   # Generous width

        # Gentle Gaussian — penalizes only extreme lengths
        ratio = (approx_tokens - optimal_center) / optimal_width
        e = math.exp(-(ratio ** 2)) - 1.0

        # Additional penalty for very long outputs (anti-rambling)
        if approx_tokens > 300:
            e -= 0.3 * (approx_tokens - 300) / 300

        return max(-1.0, e)

    def compute_exploration_bonus(self, action_string: str, times_seen: int) -> float:
        """
        [PAPER TRACEABILITY: Exploration via Entropy Bonus]

        X = entropy_bonus / sqrt(1 + times_seen_problem)

        Improved with better entropy estimation using word-level diversity.
        """
""" length = len(action_string) if length == 0: return 0.0 # Character-level entropy unique_ratio = len(set(action_string)) / length char_entropy = math.log1p(unique_ratio) # Word-level diversity bonus (rewards varied vocabulary) words = action_string.lower().split() if words: unique_word_ratio = len(set(words)) / len(words) word_entropy = math.log1p(unique_word_ratio) else: word_entropy = 0.0 combined = 0.6 * char_entropy + 0.4 * word_entropy return combined / math.sqrt(1.0 + times_seen) def compute_format_compliance(self, action_str: str, reasoning: str, final_answer: str) -> float: """ Format compliance reward — teaches the model to output structured responses. Rewards: - Having both reasoning and answer sections - Using mathematical notation - Proper structure (reasoning before answer) F ∈ [0, 1] """ score = 0.0 # Has non-empty reasoning if reasoning and len(reasoning.strip()) > 10: score += 0.3 # Has non-empty final answer if final_answer and len(final_answer.strip()) > 0: score += 0.3 # Answer contains mathematical content math_indicators = ['x', '=', '+', '-', '*', '/', '^', 'sin', 'cos', 'exp', 'log', '('] math_count = sum(1 for m in math_indicators if m in final_answer.lower()) if math_count >= 2: score += 0.2 elif math_count >= 1: score += 0.1 # Reasoning contains structured steps if any(marker in reasoning.lower() for marker in ['step', 'first', 'then', 'therefore', '=']): score += 0.2 return min(1.0, score) def detect_trivial_output(self, action_string: str) -> bool: """Anti-reward hacking: detect trivial constant outputs""" # If the output is just a single character repeated or very low entropy if len(action_string) < 2: return True unique_chars = len(set(action_string)) if unique_chars < 3 and len(action_string) > 10: return True # Detect repetitive patterns if len(action_string) > 20: # Check if a short pattern is repeated for plen in range(1, 6): pattern = action_string[:plen] if action_string == pattern * (len(action_string) // plen) + pattern[:len(action_string) % plen]: return True return False def compute_reward(self, correctness: float, reasoning_quality: float, process_supervision: float, reflection_score: float, action_str: str, final_answer: str, history: List[Dict[str, Any]], times_seen_problem: int, reasoning: str = "") -> Tuple[float, Dict[str, float]]: """ Dense composite reward using ALL 7 components + format compliance. R = α·C + β·Q_norm + γ·P_norm + δ·R_norm + η·D_norm + ζ·E_norm + λ·X + μ·F_fmt All components are normalized to [0, 1] before weighting. Final reward ∈ [0, 1]. 
""" if self.detect_trivial_output(action_str): components = { "total_reward": -0.5, "C_correctness": 0.0, "Q_reasoning": 0.0, "P_process_supervision": 0.0, "R_reflection": 0.0, "D_diversity": 0.0, "E_efficiency": -1.0, "X_exploration": 0.0, "F_format": 0.0, } return -0.5, components # --- Raw component computation --- c = correctness # Already ∈ [0, 1] with graduated scoring q = reasoning_quality d = self.compute_diversity(final_answer, history) e = self.compute_efficiency(action_str) x = self.compute_exploration_bonus(action_str, times_seen_problem) f_fmt = self.compute_format_compliance(action_str, reasoning, final_answer) # If repeated answer, reduce correctness credit (anti-hacking) if d < -0.5: c = c * 0.3 # Steep discount but not full zeroing # --- Normalize all components to [0, 1] --- q_norm = min(1.0, max(0.0, math.tanh(q))) p_norm = (process_supervision + 1.0) / 2.0 # [-1, 1] → [0, 1] r_norm = (reflection_score + 1.0) / 2.0 # [-1, 1] → [0, 1] d_norm = (d + 1.0) / 2.0 # [-1, 1] → [0, 1] e_norm = (e + 1.0) / 1.0 # [-1, 0] → [0, 1] e_norm = min(1.0, max(0.0, e_norm)) x_norm = min(1.0, max(0.0, x)) f_norm = min(1.0, max(0.0, f_fmt)) # --- Weighted composite --- total_r = ( self.W_CORRECTNESS * c + self.W_REASONING * q_norm + self.W_PROCESS * p_norm + self.W_REFLECTION * r_norm + self.W_DIVERSITY * d_norm + self.W_EFFICIENCY * e_norm + self.W_EXPLORATION * x_norm + self.W_FORMAT * f_norm ) # Clamp to [0, 1] total_r = min(1.0, max(0.0, total_r)) components = { "total_reward": total_r, "C_correctness": c, "Q_reasoning": q_norm, "P_process_supervision": process_supervision, "R_reflection": reflection_score, "D_diversity": d, "E_efficiency": e, "X_exploration": x, "F_format": f_fmt, # Weighted contributions (for debugging) "_w_C": self.W_CORRECTNESS * c, "_w_Q": self.W_REASONING * q_norm, "_w_P": self.W_PROCESS * p_norm, "_w_R": self.W_REFLECTION * r_norm, "_w_D": self.W_DIVERSITY * d_norm, "_w_E": self.W_EFFICIENCY * e_norm, "_w_X": self.W_EXPLORATION * x_norm, "_w_F": self.W_FORMAT * f_norm, } return total_r, components