File size: 7,022 Bytes
9440cb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183



from itertools import islice, zip_longest
from typing import Callable, Dict, List, Optional, Tuple, TypedDict
import json

def repeatness_reward(s: str) -> float:
    """Score how little a string repeats itself, using a suffix array.

    The LCP (longest common prefix) lengths of lexicographically adjacent
    suffixes are summed, divided by ``n * (n + 1) / 2`` and subtracted from 1.
    A string whose characters are all distinct therefore scores 1.0, and a
    larger LCP sum (more repeated material) gives a lower score.

    Args:
        s: Text to analyse; characters are compared by code point.

    Returns:
        ``0.0`` when ``s`` has at most one character, otherwise
        ``1 - 2 * sum(lcp) / (n * (n + 1))`` as a float.
    """

    def ranks(values):
        # Dense ranking: each value maps to its index among the sorted
        # distinct values, so equal values share a rank.
        order = {v: i for i, v in enumerate(sorted(set(values)))}
        return [order[v] for v in values]

    def suffix_array(codes):
        # Prefix doubling. After a pass with step ``k`` every suffix is ranked
        # by its first 2*k symbols; the -1 fill makes a suffix that ends early
        # sort before any longer suffix sharing the same prefix.
        n = len(codes)
        rank = ranks(codes)
        k = 1
        # Continue until 2*k >= n so that all ranks are guaranteed distinct.
        # (Stopping at ``k < n - 1`` can leave suffixes 0 and 1 tied when all
        # characters are identical.)
        while k < n:
            rank = ranks(list(zip_longest(rank, islice(rank, k, None), fillvalue=-1)))
            k <<= 1
        # Invert the ranking: sa[r] is the start index of the r-th smallest suffix.
        sa = [0] * n
        for start, r in enumerate(rank):
            sa[r] = start
        return rank, sa

    def lcp(codes, sa, rank):
        # Kasai-style pass: out[r] is the common-prefix length of the r-th and
        # (r + 1)-th smallest suffixes. ``k`` is carried between iterations and
        # only drops by one, which avoids re-comparing matched characters.
        n = len(codes)
        out = [0] * n
        k = 0
        for i in range(n):
            if rank[i] == n - 1:
                # The largest suffix has no successor to compare against.
                k = 0
                continue
            j = sa[rank[i] + 1]
            while i + k < n and j + k < n and codes[i + k] == codes[j + k]:
                k += 1
            out[rank[i]] = k
            if k > 0:
                k -= 1
        return out

    codes = [ord(ch) for ch in s]
    n = len(codes)
    if n <= 1:
        # Too short to measure repetition; keep the result a float like the
        # main path so callers always receive the same type.
        return 0.0
    rank, sa = suffix_array(codes)
    repeated = sum(lcp(codes, sa, rank))

    return 1 - repeated * 2 / (n * (n + 1))

import re

def format_reward(predict_str: str) -> float:
    """Format reward: 1.0 if the output follows the required layout, else 0.0.

    After stripping leading/trailing whitespace, the whole string must match
    ``<think>...</think>`` followed, after optional whitespace, by an
    ``<answer>`` element wrapping a single digit 0-9 (the digit may be padded
    with whitespace). Nothing may precede ``<think>`` or follow ``</answer>``.
    """
    pattern = r'^<think>.*?</think>\s*<answer>\s*([0-9])\s*</answer>$'
    candidate = predict_str.strip()
    # DOTALL lets the reasoning inside <think> span multiple lines.
    if re.fullmatch(pattern, candidate, re.DOTALL) is None:
        return 0.0
    return 1.0

def acc_reward(predict_str: str, ground_truth: str) -> float:
    """Accuracy reward: 1.0 if the predicted digit equals the label, else 0.0.

    The prediction is the single digit 0-9 found in the first
    ``<answer>...</answer>`` element of ``predict_str`` (whitespace around the
    digit is ignored). It is compared numerically with ``ground_truth``.

    Args:
        predict_str: Raw model output.
        ground_truth: Expected label. A string label is parsed as an integer
            before comparing; any other value is compared to the predicted
            integer with ``==`` as-is.

    Returns:
        1.0 on a match; 0.0 when there is no single-digit answer, when a
        string label is not an integer, or when the values differ.
    """
    match = re.search(r'<answer>\s*([0-9])\s*</answer>', predict_str)
    if not match:
        return 0.0
    predicted = int(match.group(1))

    expected = ground_truth
    if isinstance(expected, str):
        # An int never compares equal to a str, so a label such as "5" would
        # otherwise always score 0.0; convert it first.
        try:
            expected = int(expected.strip())
        except ValueError:
            return 0.0

    return 1.0 if predicted == expected else 0.0

# def compute_score( solution_str: str, ground_truth: str, extra_info):
#     """
#     Composite scoring function
#     """
def compute_score(
    predicts: List[str],
    ground_truths: List[str],
    format_weight: float = 0.1,
    save_path: Optional[str] = "/nas/shared/kilab/wangyujia/check_rl/result-06170934.jsonl",
) -> List[Dict[str, float]]:
    """Score a batch of model outputs and optionally log them as JSONL.

    For each (prediction, label) pair the format, accuracy and repetition
    rewards are computed; ``overall`` is their plain, unweighted sum.

    Args:
        predicts: Raw model outputs.
        ground_truths: Expected labels, paired positionally with ``predicts``
            (pairs are formed with ``zip``, so extra items on the longer side
            are ignored).
        format_weight: Accepted for interface compatibility; it is not used
            in the score computation.
        save_path: File that receives one JSON object per sample. The file is
            opened in write mode, so each call replaces its previous content.
            Pass ``None`` to skip logging.

    Returns:
        One dict per pair with the keys ``overall``, ``format``, ``accuracy``
        and ``repeat``.
    """
    scores = []
    records = []
    for solution_str, ground_truth in zip(predicts, ground_truths):
        format_score = format_reward(solution_str)
        acc_score = acc_reward(solution_str, ground_truth)

        # Repetition is measured only on the text inside the first
        # <think>...</think> pair; a missing pair is scored as an empty string.
        think_match = re.search(r'<think>(.*?)</think>', solution_str, re.DOTALL)
        think_str = think_match.group(1).strip() if think_match else ""
        repeat_score = repeatness_reward(think_str)

        score = {
            "overall": format_score + acc_score + repeat_score,
            "format": format_score,
            "accuracy": acc_score,
            "repeat": repeat_score,
        }
        scores.append(score)
        # The log record carries the inputs followed by the same score fields.
        records.append(
            {
                "solution_str": solution_str,
                "ground_truth": ground_truth,
                **score,
            }
        )

    if save_path is not None:
        with open(save_path, "w", encoding="utf-8") as f:
            for record in records:
                f.write(json.dumps(record, ensure_ascii=False) + "\n")

    return scores


# print(format_reward("<think>Step-by-step logic</think>   <answer> 5 </answer>"))
# print(format_reward("<think>Something\nacross lines</think>\n<answer> 0 </answer>"))

# print(format_reward("No tags here"))
# print(format_reward("<think>OK</think><answer>12</answer>"))  # multi-digit number
# print(format_reward("<think>OK</think><answer>A</answer>"))   # letters are not allowed
# print(format_reward("<think>Yes</think><answer> </answer>"))  # empty answer
# print(format_reward("<think>OK</think><answer>3</answer>extra"))  # trailing extra content
# print(format_reward("<answer>3</answer><think>Reasoning</think>"))  # tags in the wrong order

# print(acc_reward("<think>Step-by-step logic</think>   <answer> 5 </answer>",'5'))
# print(acc_reward("<think>Something\nacross lines</think>\n<answer> 0 </answer>",'1'))


# str_="<think>\nThe protein name is P32783, the protein amino acid sequence is MSTKPEKPIWMSQEDYDRQYGSITGDESSTVSKKDSKVTANAPGDGNGSLPVLQSSSILTSKVSDLPIEAESGFKIQKRRHERYDQEERLRKQRAQKLREEQLKRHEIEMTANRSINVDQIVREHYNERTIIANRAKRNLSPIIKLRNFNNAIKYMLIDKYTKPGDVVLELGCGKGGDLRKYGAAGISQFIGIDISNASIQEAHKRYRSMRNLDYQVVLITGDCFGESLGVAVEPFPDCRFPCDIVSTQFCLHYAFETEEKARRALLNVAKSLKIGGHFFGTIPDSEFIRYKLNKFPKEVEKPSWGNSIYKVTFENNSYQKNDYEFTSPYGQMYTYWLEDAIDNVPEYVVPFETLRSLADEYGLELVSQMPFNKFFVQEIPKWIERFSPKMREGLQRSDGRYGVEGDEKEAASYFYTMFAFRKVKQYIEPESVKPN, the protein localization prediction for P32783 is Cell.membrane,M, so the location label is 4. Therefore, option 4 is the correct answer.\n</think>\n<answer>\n4\n</answer>"
# print(format_reward(str_))



def check_rewards(jsonl_path: str) -> List[Dict[str, float]]:
    """Re-score every record of a JSONL log and print each result.

    Each non-blank line must be a JSON object with the keys ``solution_str``
    and ``ground_truth``. The format, accuracy and repetition rewards are
    recomputed from those two fields and summed into ``overall``.

    Args:
        jsonl_path: Path of the JSONL file to read (UTF-8).

    Returns:
        One dict per record, in file order, with the keys ``format``,
        ``accuracy``, ``repeat`` and ``overall``. Each dict is also printed
        as indented JSON.
    """
    results = []
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # A blank line (e.g. a trailing newline) is not a record.
                continue
            data = json.loads(line)
            solution_str = data["solution_str"]
            ground_truth = data["ground_truth"]

            # Recompute the three component scores from the raw fields.
            format_score = format_reward(solution_str)
            acc_score = acc_reward(solution_str, ground_truth)
            think_match = re.search(r'<think>(.*?)</think>', solution_str, re.DOTALL)
            think_str = think_match.group(1).strip() if think_match else ""
            repeat_score = repeatness_reward(think_str)

            result = {
                "format": format_score,
                "accuracy": acc_score,
                "repeat": repeat_score,
                "overall": format_score + acc_score + repeat_score,
            }
            results.append(result)

            print(json.dumps(result, indent=2, ensure_ascii=False))

    return results

if __name__ == "__main__":
    # Run the manual re-scoring check only when this file is executed as a
    # script, so importing the module for its reward functions does no file
    # I/O and cannot fail on a missing log file.
    check_rewards("/nas/shared/kilab/wangyujia/check_rl/check.jsonl")