import os
import re
import math
import json
from datetime import datetime
from swift.plugin import ORM, orms
from typing import Dict, List, Optional, Union


def _parse_score(text, min_score: int, max_score: int) -> Optional[int]:
    """Convert *text* to an int, keeping it only if it lies in [min_score, max_score].

    Returns None when *text* cannot be converted by ``int()`` or the value is
    outside the inclusive range.
    """
    try:
        value = int(text)
    except (TypeError, ValueError):
        return None
    return value if min_score <= value <= max_score else None


def _extract_gt_score(raw, min_score: int, max_score: int) -> Optional[int]:
    """Read a ground-truth score from one ``solution`` entry.

    A plain ``int`` is used directly. Anything else is read from its first
    element/character (e.g. ``"2"`` -> 2, ``["2"]`` -> 2). Returns None when
    the entry is empty, not indexable, not an integer, or out of range.
    """
    # bool is a subclass of int; keep treating it as an invalid label.
    if isinstance(raw, int) and not isinstance(raw, bool):
        candidate = raw
    else:
        try:
            candidate = raw[0]
        except (LookupError, TypeError):
            return None
    return _parse_score(candidate, min_score, max_score)


class MultiModalAccuracyORM(ORM):
    """Reward a completion by comparing the first integer it contains with the ground truth.

    The score range and reward values are class attributes so a subclass can
    override them without re-implementing ``__call__``.
    """

    # Inclusive range of valid scores; anything outside is treated as invalid.
    MIN_SCORE = 1
    MAX_SCORE = 2
    EXACT_REWARD = 5.0  # prediction equals ground truth
    NEAR_REWARD = 1.0   # prediction differs from ground truth by at most 1
    MISS_REWARD = 0.0   # larger difference, or either score missing/invalid
    # NOTE(review): with the default range [1, 2], any mismatch has a distance
    # of exactly 1, so a wrong prediction always earns NEAR_REWARD; MISS_REWARD
    # is only reached when a score is missing or invalid. Confirm this partial
    # credit is intended for a two-value scale.

    # NOTE(review): this matches the FIRST run of digits anywhere in the
    # completion, including digits inside any reasoning text. If completions
    # wrap the score in markup, anchor the pattern to it — confirm the format.
    _SCORE_RE = re.compile(r"(\d+)")

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        """Score each completion against its paired ground truth.

        Args:
            completions (list[str]): Generated outputs.
            solution (list): Ground truths, paired with ``completions`` by
                position; see ``_extract_gt_score`` for accepted forms.
            **kwargs: Unused.

        Returns:
            list[float]: One reward per (completion, solution) pair. As with
            ``zip``, trailing items of the longer argument are ignored.
        """
        rewards = []
        for content, gt_raw in zip(completions, solution):
            match = self._SCORE_RE.search(content)
            pred_score = (
                _parse_score(match.group(1), self.MIN_SCORE, self.MAX_SCORE)
                if match
                else None
            )
            gt_score = _extract_gt_score(gt_raw, self.MIN_SCORE, self.MAX_SCORE)
            rewards.append(self._reward(pred_score, gt_score))
        return rewards

    def _reward(self, pred_score: Optional[int], gt_score: Optional[int]) -> float:
        """Piecewise reward: exact match, near miss (|diff| <= 1), otherwise miss."""
        if pred_score is None or gt_score is None:
            return self.MISS_REWARD
        if pred_score == gt_score:
            return self.EXACT_REWARD
        if abs(pred_score - gt_score) <= 1:
            return self.NEAR_REWARD
        return self.MISS_REWARD


class MultiModalFormatAccuracyORM(ORM):
    """Reward completions that satisfy all three format patterns below."""

    # NOTE(review): the response and react patterns are bare ``.*?``, which
    # matches the empty string and therefore every completion; the score
    # pattern effectively reduces to "contains a digit". They are presumably
    # meant to be delimited by markup tags — confirm the intended format and
    # restore the delimiters.
    _RESPONSE_RE = re.compile(r".*?", re.DOTALL)
    _REACT_RE = re.compile(r".*?", re.DOTALL)
    _SCORE_RE = re.compile(r"[*\s]*(\d+)[\s*]*", re.DOTALL)
    FORMAT_REWARD = 5.0

    def __call__(self, completions, **kwargs) -> List[float]:
        """Check each completion against the response, react and score patterns.

        Args:
            completions (list[str]): Generated outputs.
            **kwargs: Unused.

        Returns:
            list[float]: ``FORMAT_REWARD`` for a completion matching all three
            patterns, otherwise ``0.0``.
        """
        patterns = (self._RESPONSE_RE, self._REACT_RE, self._SCORE_RE)
        return [
            self.FORMAT_REWARD
            if all(pattern.search(content) for pattern in patterns)
            else 0.0
            for content in completions
        ]


orms['external_r1v_format_acc'] = MultiModalFormatAccuracyORM
orms['external_r1v_acc'] = MultiModalAccuracyORM