import math
from typing import Any, Dict, List, Tuple


class RewardSystem:
    """
    Dense, multi-component reward system for mathematical RL training.

    Key improvements over v1:
      1. All seven original reward components now contribute to the final score
      2. Partial credit support (continuous C ∈ [0, 1] from the verifier)
      3. Fixed history key mismatch (was breaking diversity detection)
      4. Adaptive efficiency curve that doesn't over-penalize reasonable lengths
      5. Removed random noise (adds variance without useful signal)
      6. Added format compliance reward for structured output

    Reward equation:
        R = α·C + β·Q + γ·P + δ·R_ref + η·D_norm + ζ·E_norm + λ·X + μ·F_fmt
    Weights: α=0.30, β=0.12, γ=0.10, δ=0.05, η=0.13, ζ=0.08, λ=0.07, μ=0.15
    Sum = 1.0

    References:
      - arXiv:2408.10215 (reward shaping for RL convergence)
      - arXiv:2601.19100 (reward engineering for software/code tasks)
      - DeepSeek-R1 GRPO (graduated correctness)
      - GRPO-λ (credit assignment)
    """

    # Reward component weights (sum to 1.0)
    W_CORRECTNESS = 0.30  # α: primary signal, correctness drives learning
    W_REASONING = 0.12    # β: reasoning quality
    W_PROCESS = 0.10      # γ: step-by-step process supervision
    W_REFLECTION = 0.05   # δ: self-correction behavior
    W_DIVERSITY = 0.13    # η: answer diversity (prevents repetition)
    W_EFFICIENCY = 0.08   # ζ: token efficiency
    W_EXPLORATION = 0.07  # λ: exploration bonus
    W_FORMAT = 0.15       # μ: format compliance (model must learn structure)
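
    # Illustrative sanity arithmetic for the documented weights:
    #   0.30 + 0.12 + 0.10 + 0.05 + 0.13 + 0.08 + 0.07 + 0.15 = 1.00
    # so a composite of components each normalized to [0, 1] also stays
    # in [0, 1] before the final clamp.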

    def __init__(self, max_len: int = 1000):
        self.max_len = max_len

    def compute_diversity(self, current_answer: str, history: List[Dict[str, Any]]) -> float:
        """
        D = diversity (difference from past attempts).
        Graduated penalty instead of binary:
          - Exact repeat:             -1.0 (steep penalty)
          - Similar to a past answer: -0.3
          - Unique:                   +1.0
        """
        if not history:
            return 1.0
        cur_ans_clean = current_answer.strip().lower()
        if not cur_ans_clean:
            return 0.0  # Empty answer gets no diversity credit
        near_duplicate = False
        for attempt in history:
            # BUGFIX: check both 'final_answer' and 'prediction' keys for compatibility
            prev_ans = attempt.get('final_answer', attempt.get('prediction', '')).strip().lower()
            if prev_ans == cur_ans_clean:
                return -1.0  # Exact repeat: strong penalty
            # Check for near-duplicates via positional character overlap
            if prev_ans:
                overlap = sum(1 for a, b in zip(prev_ans, cur_ans_clean) if a == b)
                max_len = max(len(prev_ans), len(cur_ans_clean))
                if overlap / max_len > 0.85:
                    # Near-duplicate: moderate penalty, but keep scanning so a
                    # later exact repeat still triggers the full -1.0
                    near_duplicate = True
        return -0.3 if near_duplicate else 1.0
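
    # Illustrative behavior (hypothetical answers; history entries assume the
    # {'final_answer': ...} shape used above):
    #   compute_diversity("42", [])                               -> +1.0  (no history)
    #   compute_diversity("42", [{'final_answer': '42'}])         -> -1.0  (exact repeat)
    #   compute_diversity("3.14158", [{'final_answer': '3.14159'}])
    #                                                             -> -0.3  (6/7 chars match > 0.85)
    #   compute_diversity("42", [{'final_answer': '17'}])         -> +1.0  (unique)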

    def compute_efficiency(self, action_string: str) -> float:
        """
        E = efficiency. Adaptive Gaussian penalty curve.
        Improved: a wide optimal zone centered near 80 tokens, to avoid
        penalizing legitimate mathematical reasoning that naturally needs
        more space.
        E ∈ [-1.0, 0.0] (always a penalty or neutral, never a bonus).
        """
        # Rough heuristic: ~4 characters per token
        approx_tokens = len(action_string) / 4.0
        optimal_center = 80.0  # Wider center for math
        optimal_width = 60.0   # Generous width (one "sigma")
        # Gentle near the center; the penalty grows toward extreme lengths
        ratio = (approx_tokens - optimal_center) / optimal_width
        e = math.exp(-(ratio ** 2)) - 1.0
        # Additional penalty for very long outputs (anti-rambling)
        if approx_tokens > 300:
            e -= 0.3 * (approx_tokens - 300) / 300
        return max(-1.0, e)
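
    # Worked examples of the curve (token counts approximated as len/4):
    #   320-char output  -> ~80 tokens,  ratio = 0  -> e = exp(0) - 1  =  0.0   (optimal)
    #   80-char output   -> ~20 tokens,  ratio = -1 -> e = exp(-1) - 1 ≈ -0.632
    #   800-char output  -> ~200 tokens, ratio = 2  -> e = exp(-4) - 1 ≈ -0.982
    #   1600-char output -> ~400 tokens: Gaussian term ≈ -1.0, plus the
    #     anti-rambling term -0.3 * (400 - 300) / 300 = -0.1, clamped to -1.0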

    def compute_exploration_bonus(self, action_string: str, times_seen: int) -> float:
        """
        [PAPER TRACEABILITY: G. Exploration via Entropy Bonus]
        X = entropy_bonus / sqrt(1 + times_seen_problem)
        Improved with a better entropy estimate that mixes character-level
        and word-level diversity.
        """
        length = len(action_string)
        if length == 0:
            return 0.0
        # Character-level diversity: log1p of the unique-character ratio
        # (a cheap entropy proxy, not true Shannon entropy)
        unique_ratio = len(set(action_string)) / length
        char_entropy = math.log1p(unique_ratio)
        # Word-level diversity bonus (rewards varied vocabulary)
        words = action_string.lower().split()
        if words:
            unique_word_ratio = len(set(words)) / len(words)
            word_entropy = math.log1p(unique_word_ratio)
        else:
            word_entropy = 0.0
        combined = 0.6 * char_entropy + 0.4 * word_entropy
        # Bonus decays as the same problem is seen more often
        return combined / math.sqrt(1.0 + times_seen)
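
    # Worked example (hypothetical string): for "solve for x", 9 of 11
    # characters are unique, so char_entropy = log1p(9/11) ≈ 0.598; all three
    # words are distinct, so word_entropy = log1p(1.0) ≈ 0.693. Combined:
    #   0.6 * 0.598 + 0.4 * 0.693 ≈ 0.636, then divided by sqrt(1 + times_seen):
    #   times_seen = 0 -> ≈ 0.636;  times_seen = 3 -> ≈ 0.318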

    def compute_format_compliance(self, action_str: str, reasoning: str, final_answer: str) -> float:
        """
        Format compliance reward: teaches the model to output structured responses.
        Rewards:
          - Having both reasoning and answer sections
          - Using mathematical notation
          - Proper structure (reasoning before answer)
        F ∈ [0, 1]
        """
        score = 0.0
        # Has non-empty reasoning
        if reasoning and len(reasoning.strip()) > 10:
            score += 0.3
        # Has non-empty final answer
        if final_answer and len(final_answer.strip()) > 0:
            score += 0.3
        # Answer contains mathematical content
        math_indicators = ['x', '=', '+', '-', '*', '/', '^', 'sin', 'cos', 'exp', 'log', '(']
        math_count = sum(1 for m in math_indicators if m in final_answer.lower())
        if math_count >= 2:
            score += 0.2
        elif math_count >= 1:
            score += 0.1
        # Reasoning contains structured steps
        if any(marker in reasoning.lower() for marker in ['step', 'first', 'then', 'therefore', '=']):
            score += 0.2
        return min(1.0, score)
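
    # Illustrative scoring (hypothetical strings):
    #   reasoning    = "First, factor the quadratic; then solve each factor."
    #   final_answer = "x = 2 or x = -2"
    # gives 0.3 (reasoning > 10 chars) + 0.3 (non-empty answer)
    #     + 0.2 ('x', '=', '-' all appear, so math_count >= 2)
    #     + 0.2 ('first' and 'then' are step markers)  -> F = 1.0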

    def detect_trivial_output(self, action_string: str) -> bool:
        """Anti-reward-hacking: detect trivial constant outputs."""
        # Single-character (or empty) output
        if len(action_string) < 2:
            return True
        # Very low character variety over a non-trivial length
        unique_chars = len(set(action_string))
        if unique_chars < 3 and len(action_string) > 10:
            return True
        # Detect repetitive patterns: is the string a short pattern repeated?
        if len(action_string) > 20:
            for plen in range(1, 6):
                pattern = action_string[:plen]
                repeats, remainder = divmod(len(action_string), plen)
                if action_string == pattern * repeats + pattern[:remainder]:
                    return True
        return False
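
    # Illustrative detections (hypothetical outputs):
    #   detect_trivial_output("7")                      -> True  (too short)
    #   detect_trivial_output("aaaaaaaaaaaa")           -> True  (< 3 unique chars)
    #   detect_trivial_output("ababababababababababab") -> True  (repeated "ab")
    #   detect_trivial_output("x = 2 or x = -2")        -> False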

    def compute_reward(self,
                       correctness: float,
                       reasoning_quality: float,
                       process_supervision: float,
                       reflection_score: float,
                       action_str: str,
                       final_answer: str,
                       history: List[Dict[str, Any]],
                       times_seen_problem: int,
                       reasoning: str = "") -> Tuple[float, Dict[str, float]]:
        """
        Dense composite reward using all seven components + format compliance.
        R = α·C + β·Q_norm + γ·P_norm + δ·R_norm + η·D_norm + ζ·E_norm + λ·X + μ·F_fmt
        All components are normalized to [0, 1] before weighting, so the
        composite reward ∈ [0, 1]. Exception: trivial outputs short-circuit
        to a flat -0.5 penalty.
        """
        if self.detect_trivial_output(action_str):
            components = {
                "total_reward": -0.5,
                "C_correctness": 0.0, "Q_reasoning": 0.0,
                "P_process_supervision": 0.0, "R_reflection": 0.0,
                "D_diversity": 0.0, "E_efficiency": -1.0,
                "X_exploration": 0.0, "F_format": 0.0,
            }
            return -0.5, components

        # --- Raw component computation ---
        c = correctness  # Already ∈ [0, 1] with graduated scoring
        q = reasoning_quality
        d = self.compute_diversity(final_answer, history)
        e = self.compute_efficiency(action_str)
        x = self.compute_exploration_bonus(action_str, times_seen_problem)
        f_fmt = self.compute_format_compliance(action_str, reasoning, final_answer)

        # If the answer repeats a past attempt, reduce correctness credit (anti-hacking)
        if d < -0.5:
            c = c * 0.3  # Steep discount, but not full zeroing

        # --- Normalize all components to [0, 1] ---
        q_norm = min(1.0, max(0.0, math.tanh(q)))
        p_norm = (process_supervision + 1.0) / 2.0  # [-1, 1] → [0, 1]
        r_norm = (reflection_score + 1.0) / 2.0     # [-1, 1] → [0, 1]
        d_norm = (d + 1.0) / 2.0                    # [-1, 1] → [0, 1]
        e_norm = min(1.0, max(0.0, e + 1.0))        # [-1, 0] → [0, 1]
        x_norm = min(1.0, max(0.0, x))
        f_norm = min(1.0, max(0.0, f_fmt))

        # --- Weighted composite ---
        total_r = (
            self.W_CORRECTNESS * c +
            self.W_REASONING * q_norm +
            self.W_PROCESS * p_norm +
            self.W_REFLECTION * r_norm +
            self.W_DIVERSITY * d_norm +
            self.W_EFFICIENCY * e_norm +
            self.W_EXPLORATION * x_norm +
            self.W_FORMAT * f_norm
        )
        # Clamp to [0, 1]
        total_r = min(1.0, max(0.0, total_r))
        components = {
            "total_reward": total_r,
            # Raw component values (note: C is post-discount, Q is post-normalization)
            "C_correctness": c,
            "Q_reasoning": q_norm,
            "P_process_supervision": process_supervision,
            "R_reflection": reflection_score,
            "D_diversity": d,
            "E_efficiency": e,
            "X_exploration": x,
            "F_format": f_fmt,
            # Weighted contributions (for debugging)
            "_w_C": self.W_CORRECTNESS * c,
            "_w_Q": self.W_REASONING * q_norm,
            "_w_P": self.W_PROCESS * p_norm,
            "_w_R": self.W_REFLECTION * r_norm,
            "_w_D": self.W_DIVERSITY * d_norm,
            "_w_E": self.W_EFFICIENCY * e_norm,
            "_w_X": self.W_EXPLORATION * x_norm,
            "_w_F": self.W_FORMAT * f_norm,
        }
        return total_r, components
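

# Minimal end-to-end usage sketch. The scalar inputs (verifier correctness,
# reasoning quality, process supervision, reflection) and the history entry
# are hypothetical illustration values; only the RewardSystem API above is
# taken from this module.
if __name__ == "__main__":
    rs = RewardSystem()
    total, parts = rs.compute_reward(
        correctness=0.9,           # graduated verifier score in [0, 1]
        reasoning_quality=0.7,
        process_supervision=0.5,   # in [-1, 1]
        reflection_score=0.0,      # in [-1, 1]
        action_str="First, factor x^2 - 4 = (x - 2)(x + 2). Therefore x = 2 or x = -2.",
        final_answer="x = 2 or x = -2",
        history=[{"final_answer": "x = 2"}],
        times_seen_problem=1,
        reasoning="First, factor the difference of squares; then solve each factor.",
    )
    print(f"total reward = {total:.3f}")
    for name, value in parts.items():
        print(f"  {name:>24s} = {value:.3f}")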