| |
|
| |
|
| |
|
| | from itertools import islice, zip_longest |
| | from typing import Callable, Dict, List, Optional, Tuple, TypedDict |
| | import json |
| |
|
def repeatness_reward(s: str) -> float:
    """Score how little ``s`` repeats itself, using its suffix-array LCP sum.

    The string is converted to code points, a suffix array is built by
    prefix doubling, and the longest-common-prefix (LCP) length between
    every pair of lexicographically adjacent suffixes is summed.

    Args:
        s: Text to score.

    Returns:
        ``0.0`` when ``s`` has fewer than two characters; otherwise
        ``1 - 2 * lcp_sum / (n * (n + 1))`` with ``n = len(s)``.  The value
        is ``1.0`` when adjacent suffixes share no prefix at all and gets
        smaller as the shared prefixes get longer.
    """

    def ranks(values):
        """Replace each item by its dense rank among the distinct items."""
        index = {v: i for i, v in enumerate(sorted(set(values)))}
        return [index[v] for v in values]

    def suffix_array(codes):
        """Return ``(rank_of_suffix, suffix_array)`` for ``codes``."""
        size = len(codes)
        rank = ranks(codes)
        step = 1
        # Keep doubling until at least ``size`` symbols have been compared.
        # Stopping at ``step < size - 1`` can leave two suffixes with the
        # same rank (e.g. "aa"), so ``rank`` would not be a permutation.
        while step < size:
            # Pair each rank with the rank ``step`` positions ahead; -1 pads
            # past the end so that a shorter suffix sorts first.
            rank = ranks(
                list(zip_longest(rank, islice(rank, step, None), fillvalue=-1))
            )
            step <<= 1
        order = [0] * size
        for position, r in enumerate(rank):
            order[r] = position
        return rank, order

    def lcp(arr, suffix_arr, inv_suff):
        """Return LCP lengths between each suffix and its successor in order.

        ``ans[r]`` is the common-prefix length of the suffixes ranked ``r``
        and ``r + 1``; the entry for the last rank stays 0.
        """
        size = len(arr)
        ans = [0] * size
        matched = 0

        for i in range(size):
            if inv_suff[i] == size - 1:
                # The last suffix in sorted order has no successor.
                matched = 0
                continue

            j = suffix_arr[inv_suff[i] + 1]
            while (
                i + matched < size
                and j + matched < size
                and arr[i + matched] == arr[j + matched]
            ):
                matched += 1

            ans[inv_suff[i]] = matched
            # Carry the match length forward, minus the dropped first symbol.
            if matched > 0:
                matched -= 1

        return ans

    arr = [ord(ch) for ch in s]
    n = len(arr)
    if n <= 1:
        return 0.0
    inverse, order = suffix_array(arr)
    cnt = sum(lcp(arr, order, inverse))

    return 1 - cnt * 2 / (n * (n + 1))
| |
|
| | import re |
| |
|
def format_reward(predict_str: str) -> float:
    """Return 1.0 when the output is exactly one think block then one answer.

    After stripping surrounding whitespace, the whole string must be
    ``<think>...</think><answer>D</answer>`` where ``D`` is a single digit.
    Only whitespace may appear between the two blocks and around the digit.

    Args:
        predict_str: Raw model output.

    Returns:
        ``1.0`` if the format matches, ``0.0`` otherwise.
    """
    # The think body must not contain "</think>".  A plain lazy ".*?" under
    # DOTALL + fullmatch can stretch across an earlier closing tag, which
    # would let extra text or duplicated blocks between the tags pass.
    # fullmatch already anchors both ends, so no ^ / $ is needed.
    pattern = (
        r'<think>(?:(?!</think>).)*</think>'
        r'\s*<answer>\s*([0-9])\s*</answer>'
    )
    return 1.0 if re.fullmatch(pattern, predict_str.strip(), re.DOTALL) else 0.0
| |
|
def acc_reward(predict_str: str, ground_truth: str) -> float:
    """Return 1.0 when the digit in the first ``<answer>`` tag is correct.

    Args:
        predict_str: Raw model output; the first
            ``<answer>D</answer>`` (single digit ``D``, optional
            surrounding whitespace) is used.
        ground_truth: Expected answer.  A string is stripped and parsed as
            an integer; any non-string value (e.g. an ``int`` loaded from
            JSON) is compared to the predicted digit as-is.

    Returns:
        ``1.0`` on a match; ``0.0`` when there is no answer tag, the
        ground truth string is not an integer, or the values differ.
    """
    match = re.search(r'<answer>\s*([0-9])\s*</answer>', predict_str)
    if not match:
        return 0.0
    predicted = int(match.group(1))

    # Comparing an int with a str is always False in Python, so a string
    # label must be converted before the comparison.
    if isinstance(ground_truth, str):
        try:
            expected = int(ground_truth.strip())
        except ValueError:
            return 0.0
    else:
        expected = ground_truth

    return 1.0 if predicted == expected else 0.0
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
def compute_score(
    predicts: List[str],
    ground_truths: List[str],
    format_weight: float = 0.1,
    save_path: Optional[str] = "/nas/shared/kilab/wangyujia/check_rl/result-06170934.jsonl",
) -> List[Dict[str, float]]:
    """Score each prediction and optionally log the results as JSONL.

    For every ``(prediction, ground_truth)`` pair the format, accuracy and
    repeatness rewards are computed; repeatness is measured on the stripped
    text of the first ``<think>...</think>`` block (empty string if absent).

    Args:
        predicts: Model outputs.
        ground_truths: Expected answers, paired with ``predicts`` by
            position (pairing stops at the shorter list).
        format_weight: Accepted for interface compatibility; it is not used,
            and ``overall`` is the plain sum of the three component scores.
        save_path: File that is overwritten with one JSON object per pair.
            Pass ``None`` to skip writing the log.

    Returns:
        One dict per pair with keys ``overall``, ``format``, ``accuracy``
        and ``repeat``.
    """
    scores = []
    log_lines = []
    for solution_str, ground_truth in zip(predicts, ground_truths):
        format_score = format_reward(solution_str)
        acc_score = acc_reward(solution_str, ground_truth)

        think_match = re.search(r'<think>(.*?)</think>', solution_str, re.DOTALL)
        think_str = think_match.group(1).strip() if think_match else ""
        repeat_score = repeatness_reward(think_str)

        score = {
            "overall": format_score + acc_score + repeat_score,
            "format": format_score,
            "accuracy": acc_score,
            "repeat": repeat_score,
        }
        scores.append(score)

        # The log record carries the inputs followed by the same score fields.
        record = {
            "solution_str": solution_str,
            "ground_truth": ground_truth,
            **score,
        }
        log_lines.append(json.dumps(record, ensure_ascii=False) + "\n")

    if save_path is not None:
        with open(save_path, "w", encoding="utf-8") as f:
            f.writelines(log_lines)

    return scores
| |
|
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| |
|
| | |
| | |
| |
|
| |
|
| |
|
def check_rewards(jsonl_path: str) -> List[Dict[str, float]]:
    """Re-score every record of a JSONL file, printing and returning results.

    Each line must be a JSON object with ``solution_str`` and
    ``ground_truth`` keys.  The repeatness reward is computed on the
    stripped text of the first ``<think>...</think>`` block (empty string
    if absent).

    Args:
        jsonl_path: Path of the UTF-8 JSONL file to read.

    Returns:
        One dict per line with keys ``format``, ``accuracy``, ``repeat``
        and ``overall`` (the sum of the three component scores).
    """
    results = []
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            data = json.loads(line)
            solution_str = data["solution_str"]
            ground_truth = data["ground_truth"]

            format_score = format_reward(solution_str)
            acc_score = acc_reward(solution_str, ground_truth)
            think_match = re.search(r'<think>(.*?)</think>', solution_str, re.DOTALL)
            think_str = think_match.group(1).strip() if think_match else ""
            repeat_score = repeatness_reward(think_str)

            result = {
                "format": format_score,
                "accuracy": acc_score,
                "repeat": repeat_score,
                "overall": format_score + acc_score + repeat_score,
            }
            # Collect the result so the declared return value is delivered.
            results.append(result)

            print(json.dumps(result, indent=2, ensure_ascii=False))

    return results
| |
|
if __name__ == "__main__":
    # Run the check only when executed as a script, so that importing this
    # module for its reward functions does not read a file as a side effect.
    check_rewards("/nas/shared/kilab/wangyujia/check_rl/check.jsonl")