# FGDemo / future_work / reward.py
# KarishGupta — "New Slate Mistakes Were Made" (commit 3b6ded8)
import re
from tqdm import tqdm
from consts import REASONING_END, REASONING_START, SOLUTION_START, SOLUTION_END
def formatting_reward_func(completions, **kwargs) -> list[float]:
    """Score each completion on whether it has exactly one reasoning and one solution span.

    A completion earns 1.0 for containing exactly one
    ``REASONING_START ... REASONING_END`` span and another 1.0 for containing
    exactly one ``SOLUTION_START ... SOLUTION_END`` span, so every score is
    0.0, 1.0 or 2.0. Spans may cross newlines (``re.DOTALL``).

    Args:
        completions: Generated completions. Each is searched with ``re``, so
            they are presumably plain strings (TODO confirm for chat-format
            completions).
        **kwargs: Accepted and ignored, so extra keyword arguments from the
            caller do not break the call.

    Returns:
        One float score per completion, in input order.
    """
    # Escape the delimiters so the tag text is matched literally even if it
    # contains regex metacharacters such as "[", "(", "|" or "*".
    # Compile once instead of re-parsing the pattern for every completion.
    thinking_re = re.compile(
        f"{re.escape(REASONING_START)}(.*?){re.escape(REASONING_END)}", re.DOTALL
    )
    answer_re = re.compile(
        f"{re.escape(SOLUTION_START)}(.*?){re.escape(SOLUTION_END)}", re.DOTALL
    )

    scores = []
    for completion in tqdm(completions, desc="Computing formatting reward"):
        # Start from 0.0 (not int 0) so every returned score is a float.
        score = 0.0
        if len(thinking_re.findall(completion)) == 1:
            score += 1.0
        if len(answer_re.findall(completion)) == 1:
            score += 1.0
        scores.append(score)
    return scores


def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward completions whose single solution span matches the reference answer.

    A completion scores 2.0 when it contains exactly one
    ``SOLUTION_START ... SOLUTION_END`` span whose text equals the paired
    reference answer. Before comparison, newlines are removed from the span,
    and surrounding whitespace is stripped from both the span and the
    reference. Any other completion scores 0.0. The first prompt, reference
    answer and completion of the batch are printed for inspection.

    Args:
        prompts: Prompts for the batch; only ``prompts[0]`` is used (printed).
        completions: Generated completions, presumably plain strings
            (TODO confirm for chat-format completions).
        answer: Reference answers, paired positionally with ``completions``.
            Non-string answers are compared via ``str()``.
        **kwargs: Accepted and ignored.

    Returns:
        One float reward per (completion, answer) pair, in input order. As
        with ``zip``, the result stops at the shorter of ``completions`` and
        ``answer``. An empty ``completions`` yields ``[]``.
    """
    if not completions:
        # Nothing to score; also avoids IndexError on prompts[0] / answer[0].
        return []

    # Literal (escaped) delimiters, compiled once for the whole batch.
    answer_re = re.compile(
        f"{re.escape(SOLUTION_START)}(.*?){re.escape(SOLUTION_END)}", re.DOTALL
    )
    responses = [
        answer_re.findall(completion)
        for completion in tqdm(
            completions, desc="Extracting responses for correctness"
        )
    ]

    # Debug sample: print the first example of the batch.
    q = prompts[0]
    print(
        "-" * 20,
        f"Question:\n{q}",
        f"\nAnswer:\n{answer[0]}",
        f"\nResponse:{completions[0]}",
    )

    rewards = []
    for matches, reference in tqdm(
        zip(responses, answer), desc="Checking correctness", total=len(responses)
    ):
        # Exactly one solution span is required; zero or several score 0.0.
        if len(matches) != 1:
            rewards.append(0.0)
            continue
        # Remove newlines (so an answer written on its own line between the
        # tags still matches) and strip surrounding whitespace on both sides,
        # so padding inside the tags or in the dataset answer does not cost
        # the reward. Anything that matched exactly still matches.
        predicted = matches[0].replace("\n", "").strip()
        rewards.append(2.0 if predicted == str(reference).strip() else 0.0)
    return rewards