File size: 1,411 Bytes
3b6ded8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import re
from tqdm import tqdm
from consts import REASONING_END, REASONING_START, SOLUTION_START, SOLUTION_END


def formatting_reward_func(completions, **kwargs) -> list[float]:
    """Score each completion on whether it has exactly one reasoning span and one solution span.

    A completion earns 1.0 for containing exactly one
    ``REASONING_START ... REASONING_END`` span, and another 1.0 for containing
    exactly one ``SOLUTION_START ... SOLUTION_END`` span. Each score is
    therefore 0.0, 1.0 or 2.0. Spans may cross newlines (``re.DOTALL``) and are
    matched non-greedily, so a repeated delimiter pair counts as several spans
    and forfeits that point.

    Args:
        completions: Completion texts to score. Each item is searched with
            ``re.Pattern.findall``, so it must be a ``str``.
        **kwargs: Unused; accepted so callers may pass extra keyword arguments.

    Returns:
        One float score per completion, in input order.
    """
    # Escape the delimiters so they are matched as literal text, even if they
    # contain regex metacharacters such as "[", "(", "|" or ".".
    thinking_re = re.compile(
        f"{re.escape(REASONING_START)}(.*?){re.escape(REASONING_END)}", re.DOTALL
    )
    answer_re = re.compile(
        f"{re.escape(SOLUTION_START)}(.*?){re.escape(SOLUTION_END)}", re.DOTALL
    )

    scores = []
    for completion in tqdm(completions, desc="Computing formatting reward"):
        # Start from 0.0 (not 0) so every returned score is a float, even
        # when neither span is present.
        score = 0.0
        if len(thinking_re.findall(completion)) == 1:
            score += 1.0
        if len(answer_re.findall(completion)) == 1:
            score += 1.0
        scores.append(score)
    return scores


def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward completions whose single solution span equals the reference answer.

    For each ``(completion, reference)`` pair, the text between
    ``SOLUTION_START`` and ``SOLUTION_END`` is extracted (across newlines,
    non-greedy). A pair scores 2.0 only when both of these hold:

    - exactly one such span exists;
    - that span equals the reference with ``==``, after all newlines are
      removed and surrounding whitespace is trimmed.

    Every other pair scores 0.0.

    As a side effect, the first prompt, reference answer and completion of a
    non-empty batch are printed for inspection.

    Args:
        prompts: Prompts for the batch. Only ``prompts[0]`` is used, and only
            for the debug print.
        completions: Completion texts (``str``) to grade.
        answer: Reference answers aligned with ``completions``. They are
            compared as-is against a ``str``, so presumably they should be
            strings — confirm with the caller.
        **kwargs: Unused; accepted so callers may pass extra keyword arguments.

    Returns:
        One float per ``(completion, answer)`` pair, in input order. ``zip``
        stops at the shorter sequence, so mismatched lengths yield fewer
        scores than completions.
    """
    # Escape the delimiters so they are matched as literal text, even if they
    # contain regex metacharacters.
    answer_re = re.compile(
        f"{re.escape(SOLUTION_START)}(.*?){re.escape(SOLUTION_END)}", re.DOTALL
    )

    responses = [
        answer_re.findall(completion)
        for completion in tqdm(completions, desc="Extracting responses for correctness")
    ]

    # Only log when there is something to show; indexing [0] on an empty
    # batch would otherwise raise IndexError.
    if len(prompts) > 0 and len(answer) > 0 and len(completions) > 0:
        print(
            "-" * 20,
            f"Question:\n{prompts[0]}",
            f"\nAnswer:\n{answer[0]}",
            f"\nResponse:{completions[0]}",
        )

    return [
        # Remove newlines and trim edge whitespace (spaces, tabs, a stray "\r"
        # from CRLF text). This keeps formatting around the tags from turning
        # a correct answer into a miss.
        2.0 if len(r) == 1 and a == r[0].replace("\n", "").strip() else 0.0
        for r, a in tqdm(
            zip(responses, answer), desc="Checking correctness", total=len(responses)
        )
    ]