File size: 5,512 Bytes
78a0ca9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import re
from typing import List, Optional, Any, Union

class RewardFunctions:
    """Stateless reward functions for scoring model completions.

    Every function takes a list of completion strings and returns one float
    per completion, in the same order. Extra keyword arguments are accepted
    (and ignored unless a function names them) so that all functions can be
    called with one shared ``**kwargs``.
    """

    @staticmethod
    def format_reward(completions: List[str], **kwargs) -> List[float]:
        """Checks for <reasoning>...</reasoning><answer>...</answer> format.

        Returns 1.0 for a completion containing a ``<reasoning>`` block
        followed (after optional whitespace) by an ``<answer>`` block, else
        0.0. The match may appear anywhere in the completion; tag names are
        matched case-sensitively.
        """
        pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
        # DOTALL lets both sections span multiple lines.
        return [1.0 if re.search(pattern, c, re.DOTALL) else 0.0 for c in completions]

    @staticmethod
    def accuracy_reward(completions: List[str], output: Optional[Union[str, List[str]]] = None, **kwargs) -> List[float]:
        """Compares model completions to the reference output.

        Args:
            completions: Model outputs. If a completion contains
                ``<answer>...</answer>``, only the first tagged span is graded;
                otherwise the whole completion is.
            output: Reference answer(s). A single scalar (``str``, ``int`` or
                ``float``) is used for every completion; a list is paired with
                completions by position. References that are missing (list too
                short) or ``None`` score 0.0. References may themselves be
                wrapped in ``<answer>`` tags.

        Returns:
            Exactly ``len(completions)`` scores:
            1.0 for an exact match after normalization; partial credit when one
            normalized answer contains the other without splitting a number or
            word (``0.5 * length_ratio`` if the ratio exceeds 0.5, else 0.2);
            0.0 otherwise, including when either normalized answer is empty.
        """
        if output is None:
            return [0.0] * len(completions)

        if isinstance(output, (str, int, float)):
            # Broadcast a single reference to every completion.
            references = [output] * len(completions)
        else:
            references = list(output)

        def normalize(text: str) -> str:
            # Remove <answer> tags if they still exist
            text = re.sub(r"</?answer>", "", text, flags=re.IGNORECASE)
            # Lowercase
            text = text.lower().strip()
            # Remove punctuation at the end (ASCII and full-width CJK forms)
            text = re.sub(r'[.\u3002?!\uff01\uff1f]+$', '', text)
            # Normalize whitespace
            text = " ".join(text.split())
            # Remove common "The answer is" prefix
            text = re.sub(r'^(the answer is|answer:|result:)\s*', '', text)
            return text

        def extract(text: str) -> str:
            # Grade only the first <answer>...</answer> span when present.
            match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL | re.IGNORECASE)
            return match.group(1).strip() if match else text.strip()

        def contains(haystack: str, needle: str) -> bool:
            # Substring test that refuses to split an ASCII word/number or a
            # decimal, so "4" is not found in "42" and "3" not in "3.14".
            # CJK text has no such boundaries and still matches as a substring.
            pattern = (
                r"(?<![0-9a-z])(?<!\d\.)"
                + re.escape(needle)
                + r"(?![0-9a-z])(?!\.\d)"
            )
            return re.search(pattern, haystack) is not None

        rewards = []
        for i, c in enumerate(completions):
            ref = references[i] if i < len(references) else None
            if ref is None:
                # No reference to grade against.
                rewards.append(0.0)
                continue

            norm_c = normalize(extract(c))
            norm_ref = normalize(extract(str(ref)))

            if not norm_c or not norm_ref:
                # An empty answer (or empty reference) is never a match.
                rewards.append(0.0)
            elif norm_c == norm_ref:
                rewards.append(1.0)
            elif contains(norm_c, norm_ref) or contains(norm_ref, norm_c):
                # Partial credit scaled by how much of the longer string the
                # shorter one covers (e.g. "42" vs "42 kg").
                ratio = min(len(norm_c), len(norm_ref)) / max(len(norm_c), len(norm_ref))
                rewards.append(0.5 * ratio if ratio > 0.5 else 0.2)
            else:
                rewards.append(0.0)
        return rewards

    @staticmethod
    def reasoning_reward(completions: List[str], **kwargs) -> List[float]:
        """Rewards presence and quality of reasoning steps.

        Only the first ``<reasoning>...</reasoning>`` section (tags matched
        case-insensitively) is scored; completions without it score 0.0.
        Components, summed and capped at 1.0:

        * length: +0.4 if over 200 characters, else +0.2 if over 50;
        * step markers ("step N", list numbers such as "1.", first/second/
          third/finally): +0.1 each, at most 0.3;
        * logical connectors (therefore, thus, because, ...): +0.05 each, at
          most 0.2;
        * thought markers (let's, we can, if we, then, assume): +0.02 each,
          at most 0.1.

        A section shorter than 20 characters scores a flat 0.1.
        """
        rewards = []
        for c in completions:
            match = re.search(r"<reasoning>(.*?)</reasoning>", c, re.DOTALL | re.IGNORECASE)
            if match:
                reasoning = match.group(1).strip()

                # Check for step markers. "\d+\.(?!\d)" matches list numbers
                # like "1." but not the integer part of decimals like "3.14".
                step_markers = len(re.findall(r"(?:step\s*\d+)|(?:\d+\.(?!\d))|(?:\bfirst\b|\bsecond\b|\bthird\b|\bfinally\b)", reasoning, re.I))

                # Check for logical connectors
                logical_connectors = len(re.findall(r"(?:\btherefore\b|\bthus\b|\bbecause\b|\bhence\b|\bso\b|\bsince\b|\bconsequently\b)", reasoning, re.I))

                # Check for "thought" markers; accept both ASCII and
                # typographic (U+2019) apostrophes in "let's".
                thought_markers = len(re.findall(r"(?:\blet['\u2019]s\b|\bwe can\b|\bif we\b|\bthen\b|\bassume\b)", reasoning, re.I))

                # Base score on length and diversity
                score = 0.0
                if len(reasoning) > 200:
                    score += 0.4
                elif len(reasoning) > 50:
                    score += 0.2

                # Bonus for steps and logic
                score += min(0.3, step_markers * 0.1)
                score += min(0.2, logical_connectors * 0.05)
                score += min(0.1, thought_markers * 0.02)

                # Penalty for very short reasoning with tags
                if len(reasoning) < 20:
                    score = 0.1

                rewards.append(min(1.0, score))
            else:
                rewards.append(0.0)
        return rewards

    @staticmethod
    def length_penalty(completions: List[str], max_len: int = 1000, **kwargs) -> List[float]:
        """Penalizes excessively long completions.

        Returns 1.0 for completions of at most ``max_len`` characters. Longer
        completions decay linearly, reaching 0.0 at ``2 * max_len``
        characters and staying there beyond that.

        Raises:
            ValueError: If ``max_len`` is not positive.
        """
        if max_len <= 0:
            raise ValueError(f"max_len must be positive, got {max_len}")
        return [
            # Penalize only the overflow beyond max_len, not the whole length.
            max(0.0, 1.0 - (len(c) - max_len) / max_len) if len(c) > max_len else 1.0
            for c in completions
        ]

    @staticmethod
    def combined_reward(completions: List[str], **kwargs) -> List[float]:
        """Combines format, accuracy, reasoning, and length rewards.

        All keyword arguments are forwarded to every component: ``output``
        is used by ``accuracy_reward`` and ``max_len`` by ``length_penalty``.
        The weighted sum uses 15% format, 55% accuracy, 20% reasoning and
        10% length.
        """
        f_rewards = RewardFunctions.format_reward(completions, **kwargs)
        a_rewards = RewardFunctions.accuracy_reward(completions, **kwargs)
        r_rewards = RewardFunctions.reasoning_reward(completions, **kwargs)
        l_rewards = RewardFunctions.length_penalty(completions, **kwargs)

        # Weight: 15% format, 55% accuracy, 20% reasoning, 10% length
        return [
            f * 0.15 + a * 0.55 + r * 0.2 + l * 0.1
            for f, a, r, l in zip(f_rewards, a_rewards, r_rewards, l_rewards)
        ]