# interactSpeech / GRPO / Reward.py
# Uploaded by Student0809 ("Add files using upload-large-folder tool"),
# commit 3438cdb (verified).
import os
import re
import math
import json
from datetime import datetime
from swift.plugin import ORM,orms
from typing import Dict, List, Union
class MultiModalAccuracyORM(ORM):
    """Accuracy reward comparing a predicted overall score with the ground truth.

    The completion must contain ``<overall score>N</overall score>``. The first
    such ``N`` is compared with the ground-truth score taken from ``solution``.
    Scores outside ``[MIN_SCORE, MAX_SCORE]`` are treated as missing.

    Reward tiers (defaults):
        * exact match                           -> ``EXACT_REWARD`` (5.0)
        * ``|pred - gt| <= NEAR_TOLERANCE``     -> ``NEAR_REWARD`` (1.0)
        * otherwise, or a missing/invalid score -> ``MISS_REWARD`` (0.0)

    NOTE(review): with the default 1..2 range, every valid but wrong prediction
    is within ``NEAR_TOLERANCE`` of the truth. It therefore earns
    ``NEAR_REWARD``, and the 0.0 tier is only reached for missing or invalid
    scores. This matches the original behaviour; confirm that partial credit
    on a two-level label is intended.
    """

    # Inclusive range of valid scores, applied to both prediction and ground truth.
    MIN_SCORE = 1
    MAX_SCORE = 2
    # Piecewise reward schedule.
    EXACT_REWARD = 5.0
    NEAR_REWARD = 1.0
    NEAR_TOLERANCE = 1
    MISS_REWARD = 0.0

    # Compiled once at class creation instead of once per completion.
    _SCORE_RE = re.compile(r"<overall score>(\d+)</overall score>")

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        """Reward function that checks if the completion's score is correct.

        Args:
            completions (list[str]): Generated outputs, one string per sample.
            solution (list): Ground truths, paired with ``completions`` via
                ``zip``. Extra items on either side are ignored.

        Returns:
            list[float]: One reward per (completion, solution) pair.
        """
        rewards = []
        for content, gt_score_orig in zip(completions, solution):
            pred_score = self._parse_prediction(content)
            gt_score = self._parse_ground_truth(gt_score_orig)
            rewards.append(self._tiered_reward(pred_score, gt_score))
        return rewards

    def _in_range(self, score):
        """Return ``score`` if it lies in ``[MIN_SCORE, MAX_SCORE]``, else ``None``."""
        return score if self.MIN_SCORE <= score <= self.MAX_SCORE else None

    def _parse_prediction(self, content):
        """Extract the predicted score from a completion string, or ``None``."""
        match = self._SCORE_RE.search(content)
        if match is None:
            return None
        try:
            value = int(match.group(1))
        except ValueError:
            # Python 3.11+ rejects int() on very long digit strings.
            return None
        return self._in_range(value)

    def _parse_ground_truth(self, gt_score_orig):
        """Extract the ground-truth score, or ``None`` if it cannot be parsed.

        Strings (and other indexables) are read from their first element, as
        before, so ``"2"`` gives 2. Plain ints are used directly; previously
        ``int(gt_score_orig[0])`` raised on them and silently yielded reward 0.
        """
        if isinstance(gt_score_orig, int) and not isinstance(gt_score_orig, bool):
            value = gt_score_orig
        else:
            try:
                value = int(gt_score_orig[0])
            except (TypeError, ValueError, IndexError, KeyError, OverflowError):
                return None
        return self._in_range(value)

    def _tiered_reward(self, pred_score, gt_score) -> float:
        """Map a (prediction, ground truth) pair to its piecewise reward."""
        if pred_score is None or gt_score is None:
            return self.MISS_REWARD
        if pred_score == gt_score:
            return self.EXACT_REWARD
        if abs(pred_score - gt_score) <= self.NEAR_TOLERANCE:
            return self.NEAR_REWARD
        return self.MISS_REWARD
class MultiModalFormatAccuracyORM(ORM):
    """Format reward: all three required tagged sections must be present.

    A completion earns ``FORMAT_REWARD`` (5.0) when it contains all of:
        * a ``<response think>...</response think>`` span,
        * a ``<fluency think>...</fluency think>`` span, and
        * an ``<overall score>N</overall score>`` tag, where N is one or more digits.
    Otherwise it earns ``NO_FORMAT_REWARD`` (0.0).

    Each tag is searched for independently, so neither the section order nor
    the score's value range is checked here.
    """

    # DOTALL lets the think sections span multiple lines.
    _RESPONSE_RE = re.compile(r"<response think>.*?</response think>", re.DOTALL)
    _REACT_RE = re.compile(r"<fluency think>.*?</fluency think>", re.DOTALL)
    # The previous pattern wrapped this in `[*\s]*` on both sides. With
    # re.search those runs may match zero characters, so they never changed
    # whether a match was found.
    _SCORE_RE = re.compile(r"<overall score>\d+</overall score>")

    FORMAT_REWARD = 5.0
    # Float, matching the List[float] annotation and FORMAT_REWARD (was int 0).
    NO_FORMAT_REWARD = 0.0

    def __call__(self, completions, **kwargs) -> List[float]:
        """Reward function that checks if the completion has a specific format.

        Args:
            completions (list[str]): Generated outputs, one string per sample.

        Returns:
            list[float]: ``FORMAT_REWARD`` for well-formed completions,
            ``NO_FORMAT_REWARD`` otherwise.
        """
        required = (self._RESPONSE_RE, self._REACT_RE, self._SCORE_RE)
        return [
            self.FORMAT_REWARD
            if all(pattern.search(content) for pattern in required)
            else self.NO_FORMAT_REWARD
            for content in completions
        ]
# Register both reward classes in swift's `orms` plugin registry under these
# names (presumably referenced by name from the GRPO training config — confirm).
# Targets are assigned left to right, so insertion order is unchanged.
orms['external_r1v_format_acc'], orms['external_r1v_acc'] = (
    MultiModalFormatAccuracyORM,
    MultiModalAccuracyORM,
)