File size: 3,264 Bytes
51d5430
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import re
import math
import json
from datetime import datetime
from swift.plugin import ORM,orms
from typing import Dict, List, Union


class MultiModalAccuracyORM(ORM):
    """Reward a completion by comparing its ``<overall score>`` with the label.

    A score outside ``[MIN_SCORE, MAX_SCORE]`` (predicted or ground truth) is
    treated as missing and earns 0.0.
    """

    # Valid score range, inclusive, applied to both prediction and ground truth.
    MIN_SCORE = 1
    MAX_SCORE = 2
    # Reward for an exact match.
    EXACT_REWARD = 5.0
    # Reward when prediction and label differ by at most NEAR_TOLERANCE.
    # NOTE(review): with the default 1..2 range every valid mismatch differs by
    # exactly 1, so any wrong-but-valid prediction earns NEAR_REWARD; 0.0 is
    # only returned for missing/out-of-range scores. Confirm this is intended.
    NEAR_REWARD = 1.0
    NEAR_TOLERANCE = 1

    # Compiled once instead of on every completion.
    _SCORE_RE = re.compile(r"<overall score>(\d+)</overall score>")

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        """Score each completion against its ground-truth label.

        Args:
            completions (list[str]): Generated outputs; the predicted score is
                read from the first ``<overall score>N</overall score>`` tag.
            solution (list): Ground truths, paired with ``completions`` by
                position (``zip`` stops at the shorter of the two). Each item
                is either an int or an indexable whose first element is the
                label digit (e.g. the string ``"2"``).
            **kwargs: Ignored.

        Returns:
            list[float]: ``EXACT_REWARD`` for an exact match, ``NEAR_REWARD``
            when within ``NEAR_TOLERANCE``, otherwise 0.0 (including when
            either score is missing or out of range).
        """
        rewards = []
        for content, gt_raw in zip(completions, solution):
            pred_score = self._parse_pred_score(content)
            gt_score = self._parse_gt_score(gt_raw)
            rewards.append(self._score_reward(pred_score, gt_score))
        return rewards

    def _in_range(self, value):
        """Return True if ``value`` lies within the inclusive valid score range."""
        return self.MIN_SCORE <= value <= self.MAX_SCORE

    def _parse_pred_score(self, content):
        """Return the in-range score from the first score tag, else None."""
        match = self._SCORE_RE.search(content)
        if match is None:
            return None
        try:
            value = int(match.group(1))
        except ValueError:
            # Raised for digit strings longer than int()'s max-str-digits limit.
            return None
        return value if self._in_range(value) else None

    def _parse_gt_score(self, gt_raw):
        """Return the in-range ground-truth score, else None.

        Ints are used as-is. Otherwise the first element (e.g. the first
        character of a string) is converted with ``int``.
        """
        if isinstance(gt_raw, int) and not isinstance(gt_raw, bool):
            value = gt_raw
        else:
            try:
                value = int(gt_raw[0])
            except (TypeError, ValueError, LookupError, OverflowError):
                # Unindexable, empty, or non-numeric labels count as missing.
                return None
        return value if self._in_range(value) else None

    def _score_reward(self, pred_score, gt_score):
        """Map a (prediction, label) pair to a piecewise reward."""
        if pred_score is None or gt_score is None:
            return 0.0
        if pred_score == gt_score:
            return self.EXACT_REWARD
        if abs(pred_score - gt_score) <= self.NEAR_TOLERANCE:
            return self.NEAR_REWARD
        return 0.0
class MultiModalFormatAccuracyORM(ORM):
    """Reward completions that contain all required tagged sections.

    A completion earns ``FORMAT_REWARD`` only if it contains a
    ``<response think>…</response think>`` section, a
    ``<fluency think>…</fluency think>`` section, and an
    ``<overall score>N</overall score>`` tag; otherwise it earns 0.0.
    Order and score range are not checked here.
    """

    # Reward when every required section is present.
    FORMAT_REWARD = 5.0

    # re.DOTALL lets the think sections span multiple lines. Any '*'/whitespace
    # padding around the score tag is irrelevant to re.search, so none is
    # required in the pattern.
    _REQUIRED_PATTERNS = (
        re.compile(r"<response think>.*?</response think>", re.DOTALL),
        re.compile(r"<fluency think>.*?</fluency think>", re.DOTALL),
        re.compile(r"<overall score>(\d+)</overall score>", re.DOTALL),
    )

    def __call__(self, completions, **kwargs) -> List[float]:
        """Return one format reward per completion.

        Args:
            completions (list[str]): Generated outputs.
            **kwargs: Ignored.

        Returns:
            list[float]: ``FORMAT_REWARD`` or 0.0 for each completion (always
            floats, matching the annotated return type).
        """
        return [
            self.FORMAT_REWARD
            if all(pattern.search(content) for pattern in self._REQUIRED_PATTERNS)
            else 0.0
            for content in completions
        ]
# Register both reward classes in the plugin's name -> class table so they can
# be referenced by name (presumably from the training config — confirm).
orms.update({
    'external_r1v_format_acc': MultiModalFormatAccuracyORM,
    'external_r1v_acc': MultiModalAccuracyORM,
})