# yuccaaa's picture
# Add files using upload-large-folder tool
# 052f594 verified
from itertools import islice, zip_longest
from typing import Callable, Dict, List, Optional, Tuple, TypedDict
import json
def repeatness_reward(s: str):
    """Score how little *s* repeats itself, using suffix-array LCP lengths.

    The code points of *s* are ranked by prefix doubling, the lengths of the
    longest common prefixes (LCP) between neighbouring suffixes in that
    ranking are summed, and the sum is normalised by ``n * (n + 1) / 2``.

    Returns:
        ``1 - 2 * lcp_sum / (n * (n + 1))`` for inputs of length >= 2, and
        the integer ``0`` for inputs of length 0 or 1.
    """

    def dense_ranks(values):
        # Replace every value by its index among the sorted distinct values.
        position = {}
        for order, value in enumerate(sorted(set(values))):
            position[value] = order
        return [position[value] for value in values]

    def build_suffix_array(codes):
        # Returns (rank of each suffix start, suffix start at each rank).
        size = len(codes)
        rank = dense_ranks(codes)
        suffix_at = [0] * size
        step = 1
        while step < size - 1:
            # Pair each rank with the rank `step` positions later; -1 pads
            # the pairs that would read past the end of the sequence.
            shifted = islice(rank, step, None)
            rank = dense_ranks(list(zip_longest(rank, shifted, fillvalue=-1)))
            step <<= 1
        for start, position in enumerate(rank):
            suffix_at[position] = start
        return rank, suffix_at

    def adjacent_lcp(codes, suffix_at, rank_of):
        # For each rank, the LCP length with the suffix at the next rank.
        size = len(codes)
        lengths = [0] * size
        matched = 0
        for start in range(size):
            position = rank_of[start]
            if position == size - 1:
                # The top-ranked suffix has no successor to compare with.
                matched = 0
                continue
            neighbour = suffix_at[position + 1]
            while (
                start + matched < size
                and neighbour + matched < size
                and codes[start + matched] == codes[neighbour + matched]
            ):
                matched += 1
            lengths[position] = matched
            # Carry the match length over to the next start, minus one.
            matched = max(matched - 1, 0)
        return lengths

    codes = [ord(ch) for ch in s]
    length = len(codes)
    if length <= 1:
        return 0
    rank_of, suffix_at = build_suffix_array(codes)
    lcp_total = sum(adjacent_lcp(codes, suffix_at, rank_of))
    return 1 - lcp_total * 2 / (length * (length + 1))
import re
def format_reward(predict_str: str) -> float:
    """Return 1.0 when *predict_str* strictly follows the expected layout.

    After stripping surrounding whitespace, the whole response must be
    exactly ``<think>...</think><answer>D</answer>`` where ``D`` is a single
    digit 0-9. Whitespace is allowed between the two blocks and around the
    digit; any other content before, between or after them yields 0.0.

    Args:
        predict_str: The raw model response.

    Returns:
        1.0 if the response matches the layout, otherwise 0.0.
    """
    # The think body uses a tempered dot so it cannot run past the first
    # ``</think>``. A plain lazy ``.*?`` under fullmatch + DOTALL backtracks
    # over closing tags and would accept extra content (even a second
    # think/answer pair) between ``</think>`` and the final ``<answer>``.
    pattern = (
        r'<think>(?:(?!</think>).)*</think>'
        r'\s*<answer>\s*([0-9])\s*</answer>'
    )
    matched = re.fullmatch(pattern, predict_str.strip(), re.DOTALL)
    return 1.0 if matched else 0.0
def acc_reward(predict_str: str, ground_truth: str) -> float:
    """Return 1.0 when the predicted digit equals *ground_truth*.

    The first ``<answer>D</answer>`` in *predict_str* (a single digit 0-9,
    optionally surrounded by whitespace) is converted to ``int`` and
    compared with the label.

    Args:
        predict_str: The raw model response.
        ground_truth: The expected label. A string label is parsed as an
            integer before comparing; any other value is compared as is.

    Returns:
        1.0 on a match; 0.0 when there is no answer tag, the label string is
        not an integer, or the values differ.
    """
    match = re.search(r'<answer>\s*([0-9])\s*</answer>', predict_str)
    if not match:
        return 0.0
    predicted = int(match.group(1))

    expected = ground_truth
    if isinstance(ground_truth, str):
        # Comparing an int with a str is always False, so a string label
        # (which the signature declares) must be parsed before comparing.
        try:
            expected = int(ground_truth.strip())
        except ValueError:
            return 0.0

    return 1.0 if predicted == expected else 0.0
# return 1.0 if answer_content == ground_truth else 0.0
# match = re.search(r'<answer>(.*?)</answer>', predict_str, re.DOTALL)
# if not match:
# return 0.0
# answer_content = match.group(1).strip()
# return 1.0 if answer_content == ground_truth else 0.0
# def compute_score( solution_str: str, ground_truth: str, extra_info):
# """
# 综合评分函数
# """
def compute_score(
    predicts: List[str],
    ground_truths: List[str],
    format_weight: float = 0.1,
    save_path: Optional[str] = "/nas/shared/kilab/wangyujia/check_rl/result-06170934.jsonl",
) -> List[Dict[str, float]]:
    """Score each prediction and optionally dump the results as JSONL.

    For every (prediction, label) pair this computes the format reward, the
    accuracy reward and the repeatness reward of the ``<think>`` body. The
    "overall" score is their plain, unweighted sum.

    Args:
        predicts: Model responses.
        ground_truths: Expected labels, paired with *predicts* by position.
        format_weight: Accepted for interface compatibility; not used in the
            score computation.
        save_path: File that receives one JSON object per scored pair. It is
            opened in write mode, so an existing file is overwritten on every
            call. Pass ``None`` to skip the dump.

    Returns:
        One dict per pair with the keys "overall", "format", "accuracy" and
        "repeat".
    """
    scores = []
    sink = None
    if save_path is not None:
        sink = open(save_path, "w", encoding="utf-8")
    try:
        for solution_str, ground_truth in zip(predicts, ground_truths):
            format_score = format_reward(solution_str)
            acc_score = acc_reward(solution_str, ground_truth)

            # Extract the <think> body; an absent block scores as "".
            think_match = re.search(r'<think>(.*?)</think>', solution_str, re.DOTALL)
            think_str = think_match.group(1).strip() if think_match else ""
            repeat_score = repeatness_reward(think_str)

            score = {
                "overall": format_score + acc_score + repeat_score,
                "format": format_score,
                "accuracy": acc_score,
                "repeat": repeat_score,
            }
            scores.append(score)

            if sink is not None:
                # One JSONL record per pair: the inputs followed by the scores.
                record = {
                    "solution_str": solution_str,
                    "ground_truth": ground_truth,
                    **score,
                }
                sink.write(json.dumps(record, ensure_ascii=False) + "\n")
    finally:
        if sink is not None:
            sink.close()
    return scores
# print(format_reward("<think>Step-by-step logic</think> <answer> 5 </answer>"))
# print(format_reward("<think>Something\nacross lines</think>\n<answer> 0 </answer>"))
# print(format_reward("No tags here"))
# print(format_reward("<think>OK</think><answer>12</answer>")) # multi-digit answer
# print(format_reward("<think>OK</think><answer>A</answer>")) # letters are not allowed
# print(format_reward("<think>Yes</think><answer> </answer>")) # empty answer
# print(format_reward("<think>OK</think><answer>3</answer>extra")) # trailing extra content
# print(format_reward("<answer>3</answer><think>Reasoning</think>")) # tags in the wrong order
# print(acc_reward("<think>Step-by-step logic</think> <answer> 5 </answer>",'5'))
# print(acc_reward("<think>Something\nacross lines</think>\n<answer> 0 </answer>",'1'))
# str_="<think>\nThe protein name is P32783, the protein amino acid sequence is MSTKPEKPIWMSQEDYDRQYGSITGDESSTVSKKDSKVTANAPGDGNGSLPVLQSSSILTSKVSDLPIEAESGFKIQKRRHERYDQEERLRKQRAQKLREEQLKRHEIEMTANRSINVDQIVREHYNERTIIANRAKRNLSPIIKLRNFNNAIKYMLIDKYTKPGDVVLELGCGKGGDLRKYGAAGISQFIGIDISNASIQEAHKRYRSMRNLDYQVVLITGDCFGESLGVAVEPFPDCRFPCDIVSTQFCLHYAFETEEKARRALLNVAKSLKIGGHFFGTIPDSEFIRYKLNKFPKEVEKPSWGNSIYKVTFENNSYQKNDYEFTSPYGQMYTYWLEDAIDNVPEYVVPFETLRSLADEYGLELVSQMPFNKFFVQEIPKWIERFSPKMREGLQRSDGRYGVEGDEKEAASYFYTMFAFRKVKQYIEPESVKPN, the protein localization prediction for P32783 is Cell.membrane,M, so the location label is 4. Therefore, option 4 is the correct answer.\n</think>\n<answer>\n4\n</answer>"
# print(format_reward(str_))
def check_rewards(jsonl_path: str) -> List[Dict[str, float]]:
    """Recompute and print the rewards for every record of a JSONL file.

    Each non-blank line must be a JSON object with the keys "solution_str"
    and "ground_truth". The three rewards are recomputed from those fields,
    printed as indented JSON, and collected.

    Args:
        jsonl_path: Path of the JSONL file to read.

    Returns:
        One dict per record with the keys "format", "accuracy", "repeat"
        and "overall" (the sum of the other three).
    """
    results = []
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # Blank lines (e.g. a trailing newline) carry no record.
                continue
            data = json.loads(line)
            solution_str = data["solution_str"]
            ground_truth = data["ground_truth"]

            # Recompute the three scores from the stored inputs.
            format_score = format_reward(solution_str)
            acc_score = acc_reward(solution_str, ground_truth)
            think_match = re.search(r'<think>(.*?)</think>', solution_str, re.DOTALL)
            think_str = think_match.group(1).strip() if think_match else ""
            repeat_score = repeatness_reward(think_str)

            result = {
                "format": format_score,
                "accuracy": acc_score,
                "repeat": repeat_score,
                "overall": format_score + acc_score + repeat_score,
            }
            results.append(result)
            print(json.dumps(result, indent=2, ensure_ascii=False))
    return results
# NOTE(review): this call is a module-level statement, so it runs whenever the
# file is executed or imported and reads the hard-coded path below. Consider
# moving it under an `if __name__ == "__main__":` guard — confirm with callers.
check_rewards("/nas/shared/kilab/wangyujia/check_rl/check.jsonl")