"""Rule-based reward functions: output format, answer accuracy and reasoning repetitiveness."""

import json
import re
from itertools import islice, zip_longest
from typing import Callable, Dict, List, Optional, Tuple, TypedDict

# Expected completion layout: <think>...</think> followed by <answer>D</answer>,
# where D is a single digit and nothing else may surround the two sections.
# NOTE(review): confirm these tag names match the prompt template used for rollouts.
_FORMAT_RE = re.compile(r'^<think>.*?</think>\s*<answer>\s*([0-9])\s*</answer>$', re.DOTALL)
_ANSWER_RE = re.compile(r'<answer>\s*([0-9])\s*</answer>')
_THINK_RE = re.compile(r'<think>(.*?)</think>', re.DOTALL)

# Default JSONL log appended to by compute_score; pass save_path=None to disable it.
DEFAULT_SAVE_PATH = "/nas/shared/kilab/wangyujia/check_rl/result-06170934.jsonl"


def _ranks(values):
    """Map each element of ``values`` to its dense rank among the distinct elements."""
    index = {v: i for i, v in enumerate(sorted(set(values)))}
    return [index[v] for v in values]


def _suffix_array(arr):
    """Build the suffix array of ``arr`` by prefix doubling.

    Returns ``(rank, sa)`` where ``rank[i]`` is the sorted position of the suffix
    starting at ``i`` and ``sa`` is its inverse (``sa[rank[i]] == i``).
    """
    n = len(arr)
    rank = _ranks(arr)
    k = 1
    # Ranks are dense (0..d-1), so max == n - 1 exactly when every suffix is unique.
    # Each pass doubles the compared prefix length; -1 pads past the string end so a
    # shorter suffix sorts before a longer one sharing its prefix.
    while max(rank) < n - 1:
        rank = _ranks(list(zip_longest(rank, islice(rank, k, None), fillvalue=-1)))
        k <<= 1
    sa = [0] * n
    for i, r in enumerate(rank):
        sa[r] = i
    return rank, sa


def _lcp_array(arr, sa, rank):
    """Kasai-style LCP: ``lcp[r]`` is the common-prefix length of suffixes ``sa[r]`` and ``sa[r + 1]``."""
    n = len(arr)
    lcp = [0] * n
    h = 0
    for i in range(n):
        r = rank[i]
        if r == n - 1:
            # The last suffix in sorted order has no successor; restart the match length.
            h = 0
            continue
        j = sa[r + 1]
        while i + h < n and j + h < n and arr[i + h] == arr[j + h]:
            h += 1
        lcp[r] = h
        # Dropping the first character can shorten the shared prefix by at most one.
        if h > 0:
            h -= 1
    return lcp


def repeatness_reward(s: str) -> float:
    """Score how non-repetitive ``s`` is.

    The score is the fraction of the n*(n+1)/2 substrings of ``s`` that are distinct.
    The sum of adjacent-suffix LCPs counts the repeated substrings, so
    ``1 - 2 * sum(lcp) / (n * (n + 1))`` equals ``distinct / total``.
    1.0 means no substring occurs twice. Strings of length 0 or 1 score 0.0.
    """
    arr = [ord(ch) for ch in s]
    n = len(arr)
    if n <= 1:
        return 0.0
    rank, sa = _suffix_array(arr)
    repeated = sum(_lcp_array(arr, sa, rank))
    return 1 - repeated * 2 / (n * (n + 1))


def format_reward(predict_str: str) -> float:
    """Return 1.0 iff the output is exactly ``<think>...</think><answer>D</answer>``.

    D must be a single digit. Only whitespace is allowed between the sections or
    around D; leading/trailing whitespace of the whole string is ignored.

    >>> format_reward("<think>Step-by-step logic</think><answer> 5 </answer>")
    1.0
    >>> format_reward("<answer>3</answer><think>Reasoning</think>")
    0.0
    """
    return 1.0 if _FORMAT_RE.fullmatch(predict_str.strip()) else 0.0


def acc_reward(predict_str: str, ground_truth: str) -> float:
    """Return 1.0 iff the single digit inside ``<answer>`` equals ``ground_truth``.

    ``ground_truth`` may be an int or a numeric string (surrounding whitespace is
    ignored). A string that is not an integer never matches.
    """
    match = _ANSWER_RE.search(predict_str)
    if not match:
        return 0.0
    predicted = int(match.group(1))
    expected = ground_truth
    if isinstance(ground_truth, str):
        # Comparing int to str is always False, so parse string labels first.
        try:
            expected = int(ground_truth.strip())
        except ValueError:
            return 0.0
    return 1.0 if predicted == expected else 0.0


def _score_one(solution_str: str, ground_truth) -> Dict[str, float]:
    """Compute the format, accuracy and repeat scores for one completion."""
    format_score = format_reward(solution_str)
    acc_score = acc_reward(solution_str, ground_truth)
    think_match = _THINK_RE.search(solution_str)
    think_str = think_match.group(1).strip() if think_match else ""
    repeat_score = repeatness_reward(think_str)
    return {
        "overall": format_score + acc_score + repeat_score,
        "format": format_score,
        "accuracy": acc_score,
        "repeat": repeat_score,
    }


def compute_score(
    predicts: List[str],
    ground_truths: List[str],
    format_weight: float = 0.1,
    *,
    save_path: Optional[str] = DEFAULT_SAVE_PATH,
) -> List[Dict[str, float]]:
    """Score a batch of completions against their ground-truth labels.

    Args:
        predicts: model completions.
        ground_truths: labels aligned one-to-one with ``predicts``.
        format_weight: currently unused; ``overall`` is the unweighted sum of the
            format, accuracy and repeat scores. Kept for interface compatibility.
        save_path: JSONL file that each scored row is appended to; ``None`` disables logging.

    Returns:
        One dict per completion with keys ``overall``, ``format``, ``accuracy``, ``repeat``.

    Raises:
        ValueError: if ``predicts`` and ``ground_truths`` differ in length.
    """
    predicts = list(predicts)
    ground_truths = list(ground_truths)
    if len(predicts) != len(ground_truths):
        # zip() would silently drop rows and misalign rewards with samples.
        raise ValueError(
            f"predicts ({len(predicts)}) and ground_truths ({len(ground_truths)}) differ in length"
        )

    scores = [_score_one(p, g) for p, g in zip(predicts, ground_truths)]

    if save_path is not None:
        # Append so that successive batches do not overwrite each other's log rows.
        with open(save_path, "a", encoding="utf-8") as f:
            for solution_str, ground_truth, score in zip(predicts, ground_truths, scores):
                record = {"solution_str": solution_str, "ground_truth": ground_truth, **score}
                f.write(json.dumps(record, ensure_ascii=False) + "\n")
    return scores


def check_rewards(jsonl_path: str) -> List[Dict[str, float]]:
    """Re-score every row of a JSONL log written by ``compute_score``.

    Each row needs ``solution_str`` and ``ground_truth`` keys; blank lines are skipped.
    Every recomputed result is printed as indented JSON and also returned.
    """
    results = []
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            data = json.loads(line)
            score = _score_one(data["solution_str"], data["ground_truth"])
            result = {
                "format": score["format"],
                "accuracy": score["accuracy"],
                "repeat": score["repeat"],
                "overall": score["overall"],
            }
            results.append(result)
            print(json.dumps(result, indent=2, ensure_ascii=False))
    return results


if __name__ == "__main__":
    check_rewards("/nas/shared/kilab/wangyujia/check_rl/check.jsonl")