File size: 1,241 Bytes
51d5430
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class MultiModalFormatAccuracyORM(ORM):
    """Format reward that checks a completion contains every required tagged section.

    A completion earns the full reward only when it contains all of:

    * a ``<response think>...</response think>`` section,
    * a ``<fluency think>...</fluency think>`` section,
    * an ``<overall score>N</overall score>`` tag, where ``N`` is one or more digits.

    Each section is located with an unanchored search, so the sections may
    appear anywhere in the text and in any order. There is no partial credit:
    a completion missing any one section scores zero.
    """

    _FULL_REWARD = 5.0
    _NO_REWARD = 0.0

    # Compiled once at class creation instead of on every call. DOTALL lets
    # ``.`` span newlines so multi-line think sections still match.
    _REQUIRED_PATTERNS = (
        re.compile(r"<response think>.*?</response think>", re.DOTALL),
        re.compile(r"<fluency think>.*?</fluency think>", re.DOTALL),
        re.compile(r"[*\s]*<overall score>(\d+)</overall score>[\s*]*", re.DOTALL),
    )

    def __call__(self, completions, **kwargs) -> List[float]:
        """Score each completion on whether it follows the expected format.

        Args:
            completions: Iterable of completion texts. Each item is passed
                directly to a regex search, so it must be a ``str``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            One reward per completion, in input order: ``5.0`` when all three
            required sections are present, otherwise ``0.0``.
        """
        patterns = self._REQUIRED_PATTERNS
        return [
            self._FULL_REWARD
            if all(pattern.search(content) for pattern in patterns)
            else self._NO_REWARD
            for content in completions
        ]
# Register the reward class in `orms` under the key 'external_r1v_format_acc'.
# NOTE(review): `orms` is not defined in the visible code; presumably it is the
# reward-function registry that this key is looked up in -- confirm against the
# file's imports.
orms['external_r1v_format_acc'] = MultiModalFormatAccuracyORM