| import os |
| import re |
| import math |
| import json |
| from datetime import datetime |
| from swift.plugin import ORM,orms |
| from typing import Dict, List, Union |
|
|
|
|
class MultiModalAccuracyORM(ORM):
    """Reward a completion by comparing its predicted overall score to the ground truth.

    The predicted score is the integer inside the first
    ``<overall score>N</overall score>`` tag found in the completion. Both the
    prediction and the ground truth must lie in ``[MIN_SCORE, MAX_SCORE]``;
    anything missing, unparsable or out of range earns ``MISS_REWARD``.

    The class attributes can be overridden in a subclass to reuse the scorer
    with a different score scale or reward schedule.
    """

    # Inclusive range of valid scores for both prediction and ground truth.
    MIN_SCORE = 1
    MAX_SCORE = 2
    # Reward for an exact match.
    EXACT_REWARD = 5.0
    # Reward when the prediction is off by at most NEAR_TOLERANCE. With the
    # default 1..2 scale every valid mismatch falls in this band, so
    # MISS_REWARD is only reached for missing/invalid scores unless the
    # range is widened.
    NEAR_TOLERANCE = 1
    NEAR_REWARD = 1.0
    # Reward for a distant, missing, unparsable or out-of-range score.
    MISS_REWARD = 0.0

    _SCORE_RE = re.compile(r"<overall score>(\d+)</overall score>")

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        """Score each completion against its ground-truth solution.

        Args:
            completions (list[str]): Generated outputs.
            solution (list): Ground truths, one per completion. Each is either
                an ``int`` or an indexable value (typically a string) whose
                first element, after stripping whitespace for strings, is the
                score.

        Returns:
            list[float]: One reward per ``(completion, solution)`` pair; pairs
                beyond the shorter of the two inputs are ignored (``zip``).
        """
        return [
            self._reward(
                self._parse_prediction(content),
                self._parse_ground_truth(gt_score_orig),
            )
            for content, gt_score_orig in zip(completions, solution)
        ]

    def _in_range(self, score: int) -> Union[int, None]:
        """Return *score* if it lies in the valid range, else None."""
        return score if self.MIN_SCORE <= score <= self.MAX_SCORE else None

    def _parse_prediction(self, content) -> Union[int, None]:
        """Extract the in-range score from the first score tag in *content*."""
        match = self._SCORE_RE.search(content)
        if match is None:
            return None
        try:
            score = int(match.group(1))
        except ValueError:
            return None
        return self._in_range(score)

    def _parse_ground_truth(self, gt) -> Union[int, None]:
        """Read the in-range ground-truth score, or None if it cannot be read."""
        if isinstance(gt, int) and not isinstance(gt, bool):
            # A bare integer label cannot be indexed; use it directly.
            return self._in_range(gt)
        if isinstance(gt, str):
            # Tolerate stray leading whitespace such as " 2".
            gt = gt.strip()
        try:
            score = int(gt[0])
        except (TypeError, ValueError, IndexError, KeyError, OverflowError):
            return None
        return self._in_range(score)

    def _reward(self, pred: Union[int, None], gt: Union[int, None]) -> float:
        """Map a (prediction, ground truth) pair to a scalar reward."""
        if pred is None or gt is None:
            return self.MISS_REWARD
        if pred == gt:
            return self.EXACT_REWARD
        if abs(pred - gt) <= self.NEAR_TOLERANCE:
            return self.NEAR_REWARD
        return self.MISS_REWARD
class MultiModalFormatAccuracyORM(ORM):
    """Reward completions that contain all three required tagged sections.

    A completion earns ``FORMAT_REWARD`` when it contains, anywhere and in any
    order:

    * a ``<response think>...</response think>`` block,
    * a ``<fluency think>...</fluency think>`` block, and
    * an ``<overall score>N</overall score>`` tag with an integer ``N``.

    Otherwise it earns ``MISS_REWARD``. The score value itself is not
    range-checked here.
    """

    FORMAT_REWARD = 5.0
    # Kept a float so the returned list is homogeneous List[float] even when
    # every completion fails.
    MISS_REWARD = 0.0

    # DOTALL lets the think blocks span multiple lines.
    _RESPONSE_RE = re.compile(r"<response think>.*?</response think>", re.DOTALL)
    _REACT_RE = re.compile(r"<fluency think>.*?</fluency think>", re.DOTALL)
    # Optional `[*\s]*` runs around this tag can match empty, so they never
    # change whether a search succeeds; they are omitted.
    _SCORE_RE = re.compile(r"<overall score>(\d+)</overall score>")

    def __call__(self, completions, **kwargs) -> List[float]:
        """Return one format reward per completion.

        Args:
            completions (list[str]): Generated outputs.

        Returns:
            list[float]: ``FORMAT_REWARD`` for completions containing all
                three sections, ``MISS_REWARD`` otherwise.
        """
        required = (self._RESPONSE_RE, self._REACT_RE, self._SCORE_RE)
        return [
            self.FORMAT_REWARD
            if all(pattern.search(content) for pattern in required)
            else self.MISS_REWARD
            for content in completions
        ]
# Register both reward classes in swift's `orms` plugin registry, in the same
# order as before, under their external keys.
for _orm_key, _orm_cls in (
    ('external_r1v_format_acc', MultiModalFormatAccuracyORM),
    ('external_r1v_acc', MultiModalAccuracyORM),
):
    orms[_orm_key] = _orm_cls
# Drop the loop temporaries so the module namespace stays unchanged.
del _orm_key, _orm_cls