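# Source: interactSpeech/GRPO/Reward.py
# Hosting-page residue, kept as comments so the module remains importable:
#   "Student0809's picture" / "Add files using upload-large-folder tool" /
#   "3438cdb verified" / "raw" / "history blame" / "3.26 kB"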
import os
import re
import math
import json
from datetime import datetime
from swift.plugin import ORM,orms
from typing import Dict, List, Union
class MultiModalAccuracyORM(ORM):
    """Reward that compares a predicted overall score with the ground truth.

    The prediction is read from the first ``<overall score>N</overall score>``
    tag found in each completion. Both the prediction and the ground truth
    must lie in ``[MIN_SCORE, MAX_SCORE]``; otherwise the sample scores 0.0.
    """

    # Inclusive range of scores considered valid.
    MIN_SCORE = 1
    MAX_SCORE = 2

    # Compiled once at class creation instead of on every completion.
    _SCORE_RE = re.compile(r"<overall score>(\d+)</overall score>")

    @classmethod
    def _to_valid_score(cls, raw) -> Union[int, None]:
        """Return ``raw`` as an int inside the valid range, else ``None``."""
        try:
            value = int(raw)
        except (TypeError, ValueError, OverflowError):
            # Only conversion failures are expected here; anything else
            # (e.g. KeyboardInterrupt) must propagate.
            return None
        if cls.MIN_SCORE <= value <= cls.MAX_SCORE:
            return value
        return None

    @staticmethod
    def _first_item(ground_truth):
        """Return ``ground_truth[0]``, or the value itself if it is a scalar.

        Subscriptable inputs (``"2"``, ``["2"]``, ``[2]``) keep the original
        "take the first element" behaviour. A non-subscriptable value such as
        a bare ``2`` is used as-is instead of silently yielding no score.
        """
        try:
            return ground_truth[0]
        except TypeError:
            return ground_truth
        except (IndexError, KeyError):
            # Empty sequence / mapping without key 0: no usable label.
            return None

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        """Score each completion against its ground-truth label.

        Args:
            completions (list[str]): Generated outputs.
            solution (list): Ground truths, paired positionally with
                ``completions``. The first element of each entry is used
                (or the entry itself when it is not subscriptable).

        Returns:
            list[float]: One reward per (completion, solution) pair:
            5.0 for an exact match, 1.0 when the scores differ by at most
            one, 0.0 otherwise or when either score is missing/invalid.
        """
        rewards = []
        for content, gt_score_orig in zip(completions, solution):
            score_match = self._SCORE_RE.search(content)
            pred_score = (
                self._to_valid_score(score_match.group(1))
                if score_match
                else None
            )
            gt_score = self._to_valid_score(self._first_item(gt_score_orig))

            # Tiered reward logic.
            if pred_score is None or gt_score is None:
                reward = 0.0
            elif pred_score == gt_score:
                reward = 5.0
            elif abs(pred_score - gt_score) <= 1:
                # With the current 1..2 range every mismatch lands here.
                reward = 1.0
            else:
                # Reachable only if the valid range is widened beyond 2 values.
                reward = 0.0
            rewards.append(reward)
        return rewards
class MultiModalFormatAccuracyORM(ORM):
    """Reward that checks whether a completion follows the expected layout.

    A completion is well-formed when it contains all three of:
    a ``<response think>...</response think>`` section, a
    ``<fluency think>...</fluency think>`` section, and an
    ``<overall score>N</overall score>`` tag with a numeric ``N``.
    """

    # Compiled once at class creation; DOTALL lets ``.`` span newlines so the
    # think sections may be multi-line.
    _RESPONSE_RE = re.compile(
        r"<response think>.*?</response think>", re.DOTALL)
    _FLUENCY_RE = re.compile(
        r"<fluency think>.*?</fluency think>", re.DOTALL)
    _SCORE_RE = re.compile(
        r"[*\s]*<overall score>(\d+)</overall score>[\s*]*", re.DOTALL)

    def __call__(self, completions, **kwargs) -> List[float]:
        """Return 5.0 for each completion with all required tags, else 0.0.

        Args:
            completions (list[str]): Generated outputs.

        Returns:
            list[float]: One reward per completion. Always floats, so the
            result matches the declared return type.
        """
        required = (self._RESPONSE_RE, self._FLUENCY_RE, self._SCORE_RE)
        rewards = []
        for content in completions:
            well_formed = all(
                pattern.search(content) is not None for pattern in required
            )
            rewards.append(5.0 if well_formed else 0.0)
        return rewards
# Register both reward classes in the ``orms`` object imported from
# ``swift.plugin`` under the keys below.
# NOTE(review): presumably these keys are the names by which the training
# configuration selects reward functions -- confirm against the swift docs.
orms['external_r1v_format_acc'] = MultiModalFormatAccuracyORM
orms['external_r1v_acc'] = MultiModalAccuracyORM