import re from tqdm import tqdm from consts import REASONING_END, REASONING_START, SOLUTION_START, SOLUTION_END def formatting_reward_func(completions, **kwargs): thinking_pattern = f"{REASONING_START}(.*?){REASONING_END}" answer_pattern = f"{SOLUTION_START}(.*?){SOLUTION_END}" scores = [] for completion in tqdm(completions, desc="Computing formatting reward"): score = 0 thinking_matches = re.findall(thinking_pattern, completion, re.DOTALL) answer_matches = re.findall(answer_pattern, completion, re.DOTALL) if len(thinking_matches) == 1: score += 1.0 if len(answer_matches) == 1: score += 1.0 scores.append(score) return scores def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]: answer_pattern = f"{SOLUTION_START}(.*?){SOLUTION_END}" responses = [ re.findall(answer_pattern, completion, re.DOTALL) for completion in tqdm(completions, desc="Extracting responses for correctness") ] q = prompts[0] print( "-" * 20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:{completions[0]}", ) return [ 2.0 if len(r) == 1 and a == r[0].replace("\n", "") else 0.0 for r, a in tqdm( zip(responses, answer), desc="Checking correctness", total=len(responses) ) ]