class MultiModalFormatAccuracyORM(ORM):
    """Reward function that scores a completion purely on its output format.

    A completion earns the full reward only when all three format patterns
    (response, react, score) are found in it; otherwise it earns nothing.
    There is deliberately no partial credit.
    """

    def __call__(self, completions, **kwargs) -> List[float]:
        """Return one format reward per completion.

        Args:
            completions: Iterable of generated texts. Each item is passed
                directly to a regex search, so it must be a ``str``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A list with one float per completion, in input order: ``5.0``
            when every format pattern matches, ``0.0`` otherwise.
        """
        # NOTE(review): the response/react patterns are r".*?", which matches
        # any string (including the empty one), so in practice only the score
        # pattern (at least one digit) decides the reward. Presumably these
        # patterns were meant to be wrapped in delimiter tags -- TODO confirm
        # the intended format and tighten the expressions.
        required_patterns = (
            re.compile(r".*?", re.DOTALL),  # response section
            re.compile(r".*?", re.DOTALL),  # react section
            re.compile(r"[*\s]*(\d+)[\s*]*", re.DOTALL),  # numeric score
        )

        # All-or-nothing: every pattern must be present to earn the reward.
        return [
            5.0 if all(p.search(content) for p in required_patterns) else 0.0
            for content in completions
        ]
# Expose the format reward under the key 'external_r1v_format_acc' in `orms`
# (defined outside this section; presumably the registry used to look reward
# functions up by name -- verify against the code that reads it).
orms['external_r1v_format_acc'] = MultiModalFormatAccuracyORM