| from __future__ import annotations |
|
|
| import json |
| import re |
| from functools import lru_cache |
| from pathlib import Path |
|
|
| from .registry import register_reward |
|
|
# Signed decimal literal with optional ",ddd" thousands groups and an
# optional fractional part.
NUMBER_RE = re.compile(r"[-+]?\d+(?:,\d{3})*(?:\.\d+)?")

# Whole-string integer fraction such as "3/4"; whitespace around "/" is allowed.
FRAC_RE = re.compile(r"^([-+]?\d+)\s*/\s*([-+]?\d+)$")

# Whole-string LaTeX fraction "\frac{a}{b}" whose two arguments contain no braces.
LATEX_FRAC_RE = re.compile(r"^\\frac\{([^{}]+)\}\{([^{}]+)\}$")

# Completion that starts with <think>...</think> and ends with \boxed{...}.
# DOTALL lets both captured groups span newlines.
STRICT_FORMAT_RE = re.compile(r"^\s*<think>(.*?)</think>\s*\\boxed\{(.+?)\}\s*$", flags=re.DOTALL)

# Opening of a \boxed{ command (the payload is located by brace counting elsewhere).
BOXED_RE = re.compile(r"\\boxed\{")

# Any <think>...</think> section, shortest match, newlines allowed.
THINK_RE = re.compile(r"<think>(.*?)</think>", flags=re.DOTALL)
|
|
|
|
def _extract_last_boxed(text: str) -> str:
    r"""Return the payload of the last ``\boxed{`` occurrence in *text*.

    The payload runs from just after the opening brace to the brace that
    balances it; nested ``{``/``}`` pairs are kept verbatim. The result is
    whitespace-stripped. An empty string is returned when *text* has no
    ``\boxed{`` or when the last one is never closed.
    """
    opener = None
    for opener in BOXED_RE.finditer(text):
        pass  # keep only the final match
    if opener is None:
        return ""

    begin = opener.end()
    open_braces = 1  # the brace consumed by the \boxed{ match itself
    for pos, ch in enumerate(text[begin:], start=begin):
        if ch == "{":
            open_braces += 1
        elif ch == "}":
            open_braces -= 1
            if open_braces == 0:
                return text[begin:pos].strip()
    # Ran off the end of the text without finding the balancing brace.
    return ""
|
|
|
|
def _extract_numbers(text: str) -> list[float]:
    """Return every ``NUMBER_RE`` match in *text* as a float, in order.

    Thousands-separator commas are dropped before conversion; a match that
    still fails to convert is skipped.
    """
    found: list[float] = []
    for hit in NUMBER_RE.finditer(text):
        literal = hit.group(0).replace(",", "")
        try:
            value = float(literal)
        except ValueError:
            continue
        found.append(value)
    return found
|
|
|
|
def _extract_reference_target(reference: str) -> float | None:
    """Return the numeric target encoded in *reference*, or ``None``.

    The reference answer text is first parsed as a single numeric value; if
    that fails, the last number found anywhere in the answer text is used.
    """
    candidate = _extract_reference_answer_text(reference)
    direct = _parse_numeric(candidate)
    if direct is not None:
        return direct
    loose = _extract_numbers(candidate)
    return loose[-1] if loose else None
|
|
|
|
| def _normalize_answer_text(text: str) -> str: |
| normalized = text.strip() |
| normalized = normalized.replace("$", "") |
| normalized = normalized.replace("\\left", "").replace("\\right", "") |
| normalized = re.sub(r"\s+", "", normalized) |
| return normalized |
|
|
|
|
def _parse_numeric(text: str) -> float | None:
    r"""Parse *text* as a single numeric value.

    Supported forms, tried in order after normalisation:

    * integer fractions ``a/b``;
    * LaTeX fractions ``\frac{a}{b}`` whose arguments parse recursively;
    * anything ``float()`` accepts, optionally with thousands-separator
      commas (``1,234.5``).

    Commas are stripped only when they form valid ``,ddd`` groups in the
    leading number. Text such as ``1,2`` or ``12,345,67`` is rejected
    instead of being silently collapsed into ``12`` / ``1234567``, which
    would let list-like answers match an unrelated reference number.

    Returns:
        The parsed value, or ``None`` when *text* is empty, is not a single
        number, or is a fraction with a zero denominator.
    """
    source = _normalize_answer_text(text)
    if not source:
        return None

    frac_match = FRAC_RE.match(source)
    if frac_match:
        num = float(frac_match.group(1))
        den = float(frac_match.group(2))
        return None if den == 0 else num / den

    latex_frac_match = LATEX_FRAC_RE.match(source)
    if latex_frac_match:
        num = _parse_numeric(latex_frac_match.group(1))
        den = _parse_numeric(latex_frac_match.group(2))
        if num is None or den is None or den == 0:
            return None
        return num / den

    if "," in source:
        # NUMBER_RE only consumes commas followed by exactly three digits,
        # so any comma left after the leading match is a misplaced one.
        head = NUMBER_RE.match(source)
        if head is None or "," in source[head.end():]:
            return None
        source = source.replace(",", "")

    try:
        return float(source)
    except ValueError:
        return None
|
|
|
|
def _extract_reference_answer_text(reference: str) -> str:
    r"""Return the answer portion of a reference solution.

    Preference order:
      1. the payload of the last ``\boxed{...}``;
      2. the rest of the line after a ``####`` marker;
      3. the whole reference, stripped.
    """
    boxed_payload = _extract_last_boxed(reference)
    if boxed_payload:
        return boxed_payload

    marked = re.search(r"####\s*([^\n\r]+)", reference)
    return marked.group(1).strip() if marked else reference.strip()
|
|
|
|
def _extract_predicted_answer_text(completion: str) -> str:
    r"""Return the model's final answer: the last ``\boxed{...}`` payload.

    Yields an empty string when the completion has no usable boxed answer.
    """
    boxed_payload = _extract_last_boxed(completion)
    return boxed_payload
|
|
|
|
| def _is_close(a: float, b: float) -> bool: |
| |
| return abs(a - b) <= max(1e-3, 1e-3 * max(abs(a), abs(b))) |
|
|
|
|
def _has_strict_format(completion: str) -> bool:
    r"""Check that *completion* is exactly ``<think>…</think>\boxed{…}``.

    Beyond matching ``STRICT_FORMAT_RE`` with non-empty think and answer
    parts, two extra conditions are enforced because the pattern's lazy
    groups (under DOTALL) can backtrack over additional structure:

    * the think section must not itself contain ``<think>`` / ``</think>``
      (otherwise several think blocks with stray text between them pass);
    * the boxed payload must be brace-balanced (otherwise
      ``\boxed{1} extra text \boxed{2}`` passes as one answer).

    Returns:
        True only for a single think block followed by a single,
        well-formed boxed answer and nothing else but whitespace.
    """
    match = STRICT_FORMAT_RE.match(completion)
    if match is None:
        return False

    think_content = match.group(1).strip()
    answer_content = match.group(2).strip()
    if not think_content or not answer_content:
        return False

    if "<think>" in think_content or "</think>" in think_content:
        return False

    # A "}" that closes more than was opened means the captured payload
    # ran past the end of the first \boxed{...}.
    depth = 0
    for ch in answer_content:
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth < 0:
                return False
    return depth == 0
|
|
|
|
| def _completion_length_tokens(text: str) -> int: |
| |
| return len(text.split()) |
|
|
|
|
@lru_cache(maxsize=4)
def _cached_tokenizer(
    tokenizer_name: str,
    cache_dir: str | None,
    trust_remote_code: bool,
):
    """Load a tokenizer via ``AutoTokenizer.from_pretrained`` and memoise it.

    Up to four distinct argument combinations are kept by ``lru_cache``.
    ``local_files_only=True`` is always passed to ``from_pretrained``.
    """
    # Imported lazily so the module can be imported without transformers
    # unless a tokenizer is actually requested.
    from transformers import AutoTokenizer

    load_options = {
        "cache_dir": cache_dir,
        "trust_remote_code": trust_remote_code,
        "local_files_only": True,
    }
    return AutoTokenizer.from_pretrained(tokenizer_name, **load_options)
|
|
|
|
def _think_length_tokens(
    text: str,
    tokenizer_name: str | None = None,
    cache_dir: str | None = None,
    trust_remote_code: bool = False,
) -> int:
    """Measure the length of the think section of a strictly formatted completion.

    Returns 0 when *text* does not match ``STRICT_FORMAT_RE`` or its think
    section is blank. Otherwise the stripped think text is measured with the
    named tokenizer (without special tokens) when ``tokenizer_name`` is
    given, or by whitespace-separated word count when it is not.
    """
    parsed = STRICT_FORMAT_RE.match(text)
    reasoning = parsed.group(1).strip() if parsed is not None else ""
    if not reasoning:
        return 0

    if not tokenizer_name:
        # No tokenizer configured: fall back to a plain word count.
        return len(reasoning.split())

    tokenizer = _cached_tokenizer(
        tokenizer_name=tokenizer_name,
        cache_dir=cache_dir,
        trust_remote_code=trust_remote_code,
    )
    return len(tokenizer.encode(reasoning, add_special_tokens=False))
|
|
|
|
def _length_penalty_scores_for_group(
    completions: list[str],
    tokenizer_name: str | None = None,
    cache_dir: str | None = None,
    trust_remote_code: bool = False,
) -> list[float]:
    """Score each completion by how far its think length falls below the group mean.

    A completion with think length ``n`` in a group with mean ``m`` receives
    ``max(0, 1 - n / m)``; completions at or above the mean get 0.0. When
    every length is zero, every score is 0.0. An empty group yields ``[]``.

    NOTE(review): a completion whose think length is 0 scores 1.0 whenever
    the group mean is positive -- confirm this is intended for completions
    that fail the strict format.
    """
    if not completions:
        return []

    think_lengths: list[int] = []
    for completion in completions:
        think_lengths.append(
            _think_length_tokens(
                completion,
                tokenizer_name=tokenizer_name,
                cache_dir=cache_dir,
                trust_remote_code=trust_remote_code,
            )
        )

    mean_length = sum(think_lengths) / len(think_lengths)
    if mean_length <= 0:
        return [0.0] * len(completions)

    scores: list[float] = []
    for n_tokens in think_lengths:
        shortfall = 1.0 - (n_tokens / mean_length)
        scores.append(max(0.0, shortfall))
    return scores
|
|
|
|
@register_reward("format_tag_reward")
def format_tag_reward(
    prompts: list[str],
    completions: list[str],
    references: list[str],
    metadata: list[dict],
) -> list[float]:
    """Give 1.0 to each completion that passes ``_has_strict_format``, else 0.0.

    ``prompts``, ``references`` and ``metadata`` are accepted for the common
    reward signature and ignored.
    """
    del prompts, references, metadata
    verdicts = map(_has_strict_format, completions)
    return [float(ok) for ok in verdicts]
|
|
|
|
@register_reward("length_penalty_reward")
def length_penalty_reward(
    prompts: list[str],
    completions: list[str],
    references: list[str],
    metadata: list[dict],
    group_size: int | None = None,
    tokenizer_name: str | None = None,
    cache_dir: str | None = None,
    trust_remote_code: bool = False,
) -> list[float]:
    """Reward completions whose think section is shorter than their group's mean.

    Grouping:
      * with a positive ``group_size``, completions are cut into consecutive
        chunks of that size (the last chunk may be shorter);
      * otherwise, consecutive completions whose prompts compare equal form
        one group.

    Each group is scored independently by
    ``_length_penalty_scores_for_group``; scores are returned in completion
    order. ``references`` and ``metadata`` are ignored.
    """
    del references, metadata
    if not completions:
        return []

    total = len(completions)

    def score_slice(lo: int, hi: int) -> list[float]:
        return _length_penalty_scores_for_group(
            completions[lo:hi],
            tokenizer_name=tokenizer_name,
            cache_dir=cache_dir,
            trust_remote_code=trust_remote_code,
        )

    scores: list[float] = []

    if group_size is not None and group_size > 0:
        for lo in range(0, total, group_size):
            scores.extend(score_slice(lo, lo + group_size))
        return scores

    # No usable group size: treat each run of identical adjacent prompts
    # as one group.
    lo = 0
    while lo < total:
        hi = lo + 1
        while hi < total and prompts[hi] == prompts[lo]:
            hi += 1
        scores.extend(score_slice(lo, hi))
        lo = hi
    return scores
|
|
|
|
@register_reward("gsm8k_correctness_reward")
def gsm8k_correctness_reward(
    prompts: list[str],
    completions: list[str],
    references: list[str],
    metadata: list[dict],
) -> list[float]:
    """Score each completion 1.0 if its boxed answer matches the reference, else 0.0.

    A pair matches when the normalised answer strings are identical, or when
    both parse as numbers that ``_is_close`` accepts. A missing predicted or
    reference answer scores 0.0. ``prompts`` and ``metadata`` are ignored.

    Raises:
        ValueError: if ``completions`` and ``references`` differ in length
            (from ``zip(..., strict=True)``).
    """
    del prompts, metadata

    def score_pair(completion: str, reference: str) -> float:
        predicted = _extract_predicted_answer_text(completion)
        expected = _extract_reference_answer_text(reference)
        if not (predicted and expected):
            return 0.0

        # Cheap path: exact match after normalisation.
        pred_key = _normalize_answer_text(predicted)
        ref_key = _normalize_answer_text(expected)
        if pred_key and pred_key == ref_key:
            return 1.0

        # Numeric path: both sides must parse before they can be compared.
        pred_num = _parse_numeric(predicted)
        ref_num = _parse_numeric(expected)
        if pred_num is None or ref_num is None:
            return 0.0
        return 1.0 if _is_close(pred_num, ref_num) else 0.0

    return [
        score_pair(completion, reference)
        for completion, reference in zip(completions, references, strict=True)
    ]
|
|
|
|
| @lru_cache(maxsize=8) |
| def _load_zeroshot_correctness_by_index(results_jsonl_path: str) -> dict[int, bool]: |
| path = Path(results_jsonl_path) |
| if not path.exists(): |
| raise FileNotFoundError(f"Zero-shot results file not found: {path}") |
|
|
| mapping: dict[int, bool] = {} |
| with path.open("r", encoding="utf-8") as handle: |
| for raw_line in handle: |
| line = raw_line.strip() |
| if not line: |
| continue |
| row = json.loads(line) |
| idx = row.get("sample_index") |
| if idx is None: |
| continue |
|
|
| if "passed" in row: |
| passed = bool(row["passed"]) |
| elif "correctness" in row: |
| passed = float(row["correctness"]) >= 0.5 |
| else: |
| passed = False |
| mapping[int(idx)] = passed |
| return mapping |
|
|
|
|
@register_reward("token_utilisation_reward")
def token_utilisation_reward(
    prompts: list[str],
    completions: list[str],
    references: list[str],
    metadata: list[dict],
    results_jsonl_path: str,
) -> list[float]:
    """Reward improvement over, and penalise regression from, a zero-shot baseline.

    Reward logic:
      - if zero-shot was correct on this sample:
          - training correct   ->  0.0
          - training incorrect -> -1.0
      - if zero-shot was incorrect:
          - training correct   -> +1.0
          - training incorrect ->  0.0

    The zero-shot outcome is looked up by ``metadata[i]["sample_index"]`` in
    the JSONL file at ``results_jsonl_path``. A sample with no index, or an
    index absent from the file, is treated as zero-shot incorrect. Training
    correctness comes from ``gsm8k_correctness_reward`` (score >= 0.5).
    ``prompts`` is ignored.

    Raises:
        ValueError: if ``completions``, ``references`` and ``metadata`` do
            not all have the same length.
        FileNotFoundError: if the results file does not exist.
    """
    del prompts
    if len(completions) != len(references) or len(completions) != len(metadata):
        raise ValueError("completions, references, and metadata must align.")

    zeroshot_pass = _load_zeroshot_correctness_by_index(results_jsonl_path)
    train_scores = gsm8k_correctness_reward(
        prompts=["" for _ in completions],
        completions=completions,
        references=references,
        metadata=metadata,
    )

    rewards: list[float] = []
    # Lengths were validated above, so scores and metadata pair one-to-one.
    for train_score, meta in zip(train_scores, metadata, strict=True):
        sample_index = meta.get("sample_index")
        zero_shot_correct = (
            bool(zeroshot_pass.get(int(sample_index), False))
            if sample_index is not None
            else False
        )
        train_correct = float(train_score) >= 0.5

        if zero_shot_correct:
            rewards.append(0.0 if train_correct else -1.0)
        else:
            rewards.append(1.0 if train_correct else 0.0)
    return rewards
|
|