# interactSpeech / GRPO / formatReward.py
# Student0809's picture
# Add files using upload-large-folder tool
# 51d5430 verified
class MultiModalFormatAccuracyORM(ORM):
    """Format reward: full credit only when a completion contains every required section.

    A completion is rewarded when it contains, anywhere in its text, all of:

    * a ``<response think>...</response think>`` section,
    * a ``<fluency think>...</fluency think>`` section, and
    * an ``<overall score>N</overall score>`` tag whose body is one or more digits.

    Any completion missing at least one of these gets ``NO_FORMAT_REWARD``.
    """

    # Patterns are compiled once at class creation instead of on every call.
    # DOTALL lets ``.*?`` span newlines inside the multi-line think sections.
    _RESPONSE_PATTERN = re.compile(r"<response think>.*?</response think>", re.DOTALL)
    _REACT_PATTERN = re.compile(r"<fluency think>.*?</fluency think>", re.DOTALL)
    _SCORE_PATTERN = re.compile(r"[*\s]*<overall score>(\d+)</overall score>[\s*]*", re.DOTALL)

    # Reward values kept as class attributes so a subclass can retune them without
    # copying the matching logic. Both are floats to match the ``List[float]`` contract.
    FORMAT_REWARD = 5.0
    NO_FORMAT_REWARD = 0.0

    def __call__(self, completions, **kwargs) -> List[float]:
        """Score each completion on whether it follows the required output format.

        Args:
            completions: Generated completions, one per sample. Each one is searched
                with ``re.search``, so each must be a ``str``. (Assumes the trainer
                passes plain strings, as ms-swift does for ORMs. TODO confirm.)
            **kwargs: Other batch fields supplied by the trainer. They are unused here.

        Returns:
            One reward per completion, in input order: ``FORMAT_REWARD`` when all three
            sections are present, otherwise ``NO_FORMAT_REWARD``.
        """
        return [
            self.FORMAT_REWARD if self._has_required_format(content) else self.NO_FORMAT_REWARD
            for content in completions
        ]

    @classmethod
    def _has_required_format(cls, content: str) -> bool:
        """Return True when ``content`` contains all three required sections."""
        return all(
            pattern.search(content) is not None
            for pattern in (cls._RESPONSE_PATTERN, cls._REACT_PATTERN, cls._SCORE_PATTERN)
        )
# Register the format reward under the key the training config refers to.
# ``orms`` is the reward-model registry defined or imported elsewhere in this file.
orms.update({'external_r1v_format_acc': MultiModalFormatAccuracyORM})