|
|
import os |
|
|
import re |
|
|
import math |
|
|
import json |
|
|
from datetime import datetime |
|
|
from swift.plugin import ORM,orms |
|
|
from typing import Dict, List, Union |
|
|
|
|
|
|
|
|
class MultiModalAccuracyORM(ORM):
    """Reward completions whose predicted overall score matches the ground truth.

    Each completion is expected to contain ``<overall score>N</overall score>``;
    the first such tag is parsed and compared with the corresponding ground-truth
    score. A score outside ``[MIN_SCORE, MAX_SCORE]`` or one that cannot be parsed
    (on either side) is treated as missing and earns ``MISS_REWARD``.

    The scale and reward values are class attributes so a subclass can adapt them
    without re-implementing ``__call__``.
    """

    # Inclusive range of valid scores, applied to predictions and ground truth.
    MIN_SCORE = 1
    MAX_SCORE = 2
    # Reward for an exact match.
    EXACT_REWARD = 5.0
    # Reward when prediction and ground truth differ by at most NEAR_TOLERANCE.
    # NOTE: with the default 1..2 range every valid mismatch differs by exactly 1,
    # so any wrong-but-valid prediction earns NEAR_REWARD.
    NEAR_TOLERANCE = 1
    NEAR_REWARD = 1.0
    # Reward for a larger mismatch or for a missing / invalid score.
    MISS_REWARD = 0.0

    # Compiled once. The tag must contain only digits (no surrounding spaces).
    _SCORE_RE = re.compile(r"<overall score>(\d+)</overall score>")

    def _in_range(self, score: int) -> bool:
        """Return True if *score* lies within the configured valid range."""
        return self.MIN_SCORE <= score <= self.MAX_SCORE

    def _parse_prediction(self, content: str) -> Union[int, None]:
        """Extract the first in-range overall score from a completion, else None."""
        match = self._SCORE_RE.search(content)
        if match is None:
            return None
        try:
            score = int(match.group(1))
        except ValueError:
            # \d+ can match digit runs longer than int()'s max-str-digits limit.
            return None
        return score if self._in_range(score) else None

    def _parse_ground_truth(self, gt) -> Union[int, None]:
        """Parse one ground-truth entry into an in-range score, else None.

        Accepts an int score directly (bools are rejected). Otherwise keeps the
        original convention of reading the score from the first character of a
        string (leading/trailing whitespace ignored) or the first element of a
        sequence.
        """
        if isinstance(gt, int) and not isinstance(gt, bool):
            score = gt
        else:
            if isinstance(gt, str):
                gt = gt.strip()
            try:
                score = int(gt[0])
            except (IndexError, KeyError, TypeError, ValueError, OverflowError):
                # Empty, non-subscriptable, or non-numeric ground truth.
                return None
        return score if self._in_range(score) else None

    def _reward(self, pred_score: Union[int, None], gt_score: Union[int, None]) -> float:
        """Map a (prediction, ground truth) pair of parsed scores to a reward."""
        if pred_score is None or gt_score is None:
            return self.MISS_REWARD
        diff = abs(pred_score - gt_score)
        if diff == 0:
            return self.EXACT_REWARD
        if diff <= self.NEAR_TOLERANCE:
            return self.NEAR_REWARD
        return self.MISS_REWARD

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        """
        Reward function that checks if the completion's score is correct.

        Args:
            completions (list[str]): Generated outputs.
            solution (list): Ground truths, one per completion. Each is an int
                score, or a string / sequence whose first character / element
                is the score.
            **kwargs: Ignored; accepted for compatibility with the ORM interface.

        Returns:
            list[float]: One reward per (completion, solution) pair. ``zip`` is
            used, so extra items in the longer input are ignored.
        """
        return [
            self._reward(self._parse_prediction(content), self._parse_ground_truth(gt))
            for content, gt in zip(completions, solution)
        ]
|
|
class MultiModalFormatAccuracyORM(ORM):
    """Reward completions that contain all three required tagged sections.

    A completion earns ``FORMAT_REWARD`` when it contains, anywhere and in any
    order, a ``<response think>...</response think>`` block, a
    ``<fluency think>...</fluency think>`` block, and an
    ``<overall score>N</overall score>`` tag whose content is digits only.
    Otherwise it earns ``NO_FORMAT_REWARD``. The score value itself is not
    range-checked here.
    """

    FORMAT_REWARD = 5.0
    # A float, matching the List[float] return annotation.
    NO_FORMAT_REWARD = 0.0

    # Compiled once. DOTALL lets the think blocks span multiple lines.
    _REQUIRED_PATTERNS = (
        re.compile(r"<response think>.*?</response think>", re.DOTALL),
        re.compile(r"<fluency think>.*?</fluency think>", re.DOTALL),
        # re.search finds the tag anywhere, so no surrounding padding is needed.
        re.compile(r"<overall score>\d+</overall score>"),
    )

    def __call__(self, completions, **kwargs) -> List[float]:
        """Reward function that checks if the completion has a specific format.

        Args:
            completions (list[str]): Generated outputs.
            **kwargs: Ignored; accepted for compatibility with the ORM interface.

        Returns:
            list[float]: ``FORMAT_REWARD`` for each completion containing every
            required section, ``NO_FORMAT_REWARD`` otherwise.
        """
        return [
            self.FORMAT_REWARD
            if all(pattern.search(content) for pattern in self._REQUIRED_PATTERNS)
            else self.NO_FORMAT_REWARD
            for content in completions
        ]
|
|
# Register both reward classes in the `orms` mapping imported from swift.plugin,
# keyed by the names used to select them. Presumably these keys are what a
# training config refers to when choosing reward functions — confirm against
# the launcher.
orms['external_r1v_format_acc'] = MultiModalFormatAccuracyORM


orms['external_r1v_acc'] = MultiModalAccuracyORM