# Reward functions for math completions whose final answer is wrapped in
# \boxed{...}: a format check, a group-relative length penalty, an answer
# correctness check, and a reward measured relative to zero-shot results.
from __future__ import annotations

import json
import re
from functools import lru_cache
from pathlib import Path

from .registry import register_reward

# Signed integer or decimal, optionally with thousands separators ("1,234.5").
NUMBER_RE = re.compile(r"[-+]?\d+(?:,\d{3})*(?:\.\d+)?")
# Plain integer fraction such as "3/4" (whitespace around the slash allowed).
FRAC_RE = re.compile(r"^([-+]?\d+)\s*/\s*([-+]?\d+)$")
# LaTeX fraction whose numerator and denominator contain no nested braces.
LATEX_FRAC_RE = re.compile(r"^\\frac\{([^{}]+)\}\{([^{}]+)\}$")
# Whole-completion pattern: group 1 is the text before "\boxed{", group 2 is
# the boxed content; only whitespace may follow the closing brace.
# NOTE(review): the names used around this pattern ("think", "format tag")
# suggest it may once have required explicit reasoning tags that were lost
# from the pattern -- confirm against the training prompt format.
STRICT_FORMAT_RE = re.compile(
    r"^\s*(.*?)\s*\\boxed\{(.+?)\}\s*$",
    re.DOTALL,
)
# Opening of a boxed answer; the matching close brace is found by scanning.
BOXED_RE = re.compile(r"\\boxed\{")
# NOTE(review): not referenced in this module, and as written it matches the
# empty string at every position -- presumably tag delimiters are missing.
THINK_RE = re.compile(r"(.*?)", re.DOTALL)


def _extract_last_boxed(text: str) -> str:
    """Return the stripped content of the last ``\\boxed{...}`` in *text*.

    Nested braces inside the box are kept. Returns ``""`` when there is no
    ``\\boxed{`` or when the last one is never closed.
    """
    starts = [match.start() for match in BOXED_RE.finditer(text)]
    if not starts:
        return ""

    content_start = starts[-1] + len("\\boxed{")
    depth = 1  # we are already inside the opening brace of \boxed{
    for pos in range(content_start, len(text)):
        ch = text[pos]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[content_start:pos].strip()
    # Ran off the end without balancing the braces.
    return ""


def _extract_numbers(text: str) -> list[float]:
    """Return every number matched by ``NUMBER_RE`` in *text*, in order."""
    numbers: list[float] = []
    for raw in NUMBER_RE.findall(text):
        try:
            numbers.append(float(raw.replace(",", "")))
        except ValueError:
            # Defensive: skip anything float() rejects.
            continue
    return numbers


def _extract_reference_target(reference: str) -> float | None:
    """Return the numeric value of a reference answer, or ``None``.

    The answer text is parsed as a single number/fraction first; if that
    fails, the last number appearing in the answer text is used.
    """
    answer_text = _extract_reference_answer_text(reference)
    parsed = _parse_numeric(answer_text)
    if parsed is not None:
        return parsed
    nums = _extract_numbers(answer_text)
    return nums[-1] if nums else None


def _normalize_answer_text(text: str) -> str:
    """Strip ``$``, ``\\left``/``\\right`` and all whitespace from *text*."""
    normalized = text.strip().replace("$", "")
    normalized = normalized.replace("\\left", "").replace("\\right", "")
    return re.sub(r"\s+", "", normalized)


def _parse_numeric(text: str) -> float | None:
    """Parse *text* as one number, ``a/b`` fraction or ``\\frac{a}{b}``.

    Returns ``None`` for empty input, a zero denominator, or text that is not
    a single numeric value.
    """
    source = _normalize_answer_text(text)
    if not source:
        return None

    frac_match = FRAC_RE.match(source)
    if frac_match:
        num = float(frac_match.group(1))
        den = float(frac_match.group(2))
        if den == 0:
            return None
        return num / den

    latex_frac_match = LATEX_FRAC_RE.match(source)
    if latex_frac_match:
        # Numerator and denominator may themselves be fractions or decimals.
        num = _parse_numeric(latex_frac_match.group(1))
        den = _parse_numeric(latex_frac_match.group(2))
        if num is None or den in (None, 0.0):
            return None
        return num / den

    nums = _extract_numbers(source)
    # Accept a lone number whose comma-free text round-trips through float.
    if len(nums) == 1 and source.replace(",", "") == str(nums[0]).rstrip("0").rstrip("."):
        return nums[0]
    try:
        return float(source.replace(",", ""))
    except ValueError:
        return None


def _extract_reference_answer_text(reference: str) -> str:
    """Return the final-answer text of a reference solution.

    Tries, in order: the last ``\\boxed{...}``, the text after a ``####``
    marker, and finally the whole stripped reference.
    """
    # Prefer boxed final answers in Hendrycks MATH.
    boxed = _extract_last_boxed(reference)
    if boxed:
        return boxed
    # GSM8K references put the final answer after a "####" marker.
    marker_match = re.search(r"####\s*([^\n\r]+)", reference)
    if marker_match:
        return marker_match.group(1).strip()
    return reference.strip()


def _extract_predicted_answer_text(completion: str) -> str:
    """Return the content of the last ``\\boxed{...}`` in *completion*."""
    return _extract_last_boxed(completion)


def _is_close(a: float, b: float) -> bool:
    """Return True when *a* and *b* agree within 1e-3 absolute or relative."""
    # Allow small rounding differences for decimal answers.
    return abs(a - b) <= max(1e-3, 1e-3 * max(abs(a), abs(b)))


def _has_strict_format(completion: str) -> bool:
    """Return True when *completion* is non-empty text then a final box.

    Both the text before ``\\boxed{`` and the boxed content must be non-empty
    after stripping whitespace.
    """
    match = STRICT_FORMAT_RE.match(completion)
    if match is None:
        return False
    think_content = match.group(1).strip()
    answer_content = match.group(2).strip()
    return bool(think_content and answer_content)


def _completion_length_tokens(text: str) -> int:
    """Return the number of whitespace-separated tokens in *text*."""
    # Count approximate generation length as whitespace-separated tokens.
    return len(text.split())


@lru_cache(maxsize=4)
def _cached_tokenizer(
    tokenizer_name: str, cache_dir: str | None, trust_remote_code: bool
):
    """Load a tokenizer via ``AutoTokenizer.from_pretrained``, memoized.

    ``local_files_only=True`` is passed, so the tokenizer files are expected
    to be available locally already.
    """
    # Imported lazily so the module can be imported without transformers.
    from transformers import AutoTokenizer

    return AutoTokenizer.from_pretrained(
        tokenizer_name,
        cache_dir=cache_dir,
        trust_remote_code=trust_remote_code,
        local_files_only=True,
    )


def _think_length_tokens(
    text: str,
    tokenizer_name: str | None = None,
    cache_dir: str | None = None,
    trust_remote_code: bool = False,
) -> int:
    """Return the length of the text preceding the boxed answer.

    The length is counted in tokenizer tokens when *tokenizer_name* is given,
    otherwise in whitespace-separated words. Completions that do not match
    ``STRICT_FORMAT_RE``, or whose leading text is empty, count as 0.
    """
    match = STRICT_FORMAT_RE.match(text)
    if match is None:
        return 0
    think_content = match.group(1).strip()
    if not think_content:
        return 0

    if tokenizer_name:
        tokenizer = _cached_tokenizer(
            tokenizer_name=tokenizer_name,
            cache_dir=cache_dir,
            trust_remote_code=trust_remote_code,
        )
        token_ids = tokenizer.encode(think_content, add_special_tokens=False)
        return len(token_ids)

    # Fallback (should rarely be used): approximate by whitespace words.
    return len(think_content.split())


def _length_penalty_scores_for_group(
    completions: list[str],
    tokenizer_name: str | None = None,
    cache_dir: str | None = None,
    trust_remote_code: bool = False,
) -> list[float]:
    """Score each completion by how much shorter it is than the group mean.

    The score is ``max(0, 1 - length / mean_length)``: completions at or
    above the mean get 0.0. A group whose mean length is 0 scores all zeros.
    """
    if not completions:
        return []

    lengths = [
        _think_length_tokens(
            c,
            tokenizer_name=tokenizer_name,
            cache_dir=cache_dir,
            trust_remote_code=trust_remote_code,
        )
        for c in completions
    ]
    avg_len = sum(lengths) / len(lengths)
    if avg_len <= 0:
        return [0.0] * len(completions)
    return [max(0.0, 1.0 - (length / avg_len)) for length in lengths]


@register_reward("format_tag_reward")
def format_tag_reward(
    prompts: list[str],
    completions: list[str],
    references: list[str],
    metadata: list[dict],
) -> list[float]:
    """Return 1.0 for each completion passing ``_has_strict_format``, else 0.0."""
    del prompts, references, metadata
    return [1.0 if _has_strict_format(c) else 0.0 for c in completions]


@register_reward("length_penalty_reward")
def length_penalty_reward(
    prompts: list[str],
    completions: list[str],
    references: list[str],
    metadata: list[dict],
    group_size: int | None = None,
    tokenizer_name: str | None = None,
    cache_dir: str | None = None,
    trust_remote_code: bool = False,
) -> list[float]:
    """Return group-relative length scores for *completions*.

    Completions are split into consecutive groups of *group_size* when it is
    a positive integer; otherwise groups are runs of adjacent completions
    sharing the same prompt text. Each group is scored independently by
    ``_length_penalty_scores_for_group``.
    """
    del references, metadata
    if not completions:
        return []

    scores: list[float] = []

    if group_size is not None and group_size > 0:
        for start in range(0, len(completions), group_size):
            scores.extend(
                _length_penalty_scores_for_group(
                    completions[start : start + group_size],
                    tokenizer_name=tokenizer_name,
                    cache_dir=cache_dir,
                    trust_remote_code=trust_remote_code,
                )
            )
        return scores

    # Fallback: infer groups by contiguous prompt text.
    start = 0
    while start < len(completions):
        end = start + 1
        while end < len(completions) and prompts[end] == prompts[start]:
            end += 1
        scores.extend(
            _length_penalty_scores_for_group(
                completions[start:end],
                tokenizer_name=tokenizer_name,
                cache_dir=cache_dir,
                trust_remote_code=trust_remote_code,
            )
        )
        start = end
    return scores


@register_reward("gsm8k_correctness_reward")
def gsm8k_correctness_reward(
    prompts: list[str],
    completions: list[str],
    references: list[str],
    metadata: list[dict],
) -> list[float]:
    """Return 1.0 where the boxed prediction matches the reference, else 0.0.

    A prediction matches when its normalized text equals the reference's, or
    when both parse as numbers that agree per ``_is_close``. A missing boxed
    answer or an unparseable, non-identical form scores 0.0.

    Raises:
        ValueError: if *completions* and *references* differ in length
            (from ``zip(..., strict=True)``).
    """
    del prompts, metadata
    scores: list[float] = []
    for completion, reference in zip(completions, references, strict=True):
        pred_text = _extract_predicted_answer_text(completion)
        ref_text = _extract_reference_answer_text(reference)
        if not pred_text or not ref_text:
            scores.append(0.0)
            continue

        # Exact textual agreement after normalization counts as correct.
        pred_norm = _normalize_answer_text(pred_text)
        ref_norm = _normalize_answer_text(ref_text)
        if pred_norm and ref_norm and pred_norm == ref_norm:
            scores.append(1.0)
            continue

        pred_value = _parse_numeric(pred_text)
        ref_value = _parse_numeric(ref_text)
        if pred_value is not None and ref_value is not None:
            scores.append(1.0 if _is_close(pred_value, ref_value) else 0.0)
            continue

        # If symbolic forms don't exactly match, give no correctness credit.
        scores.append(0.0)
    return scores


@lru_cache(maxsize=8)
def _load_zeroshot_correctness_by_index(results_jsonl_path: str) -> dict[int, bool]:
    """Load a JSONL results file into ``{sample_index: passed}``, memoized.

    Each non-blank line is a JSON object. Rows without ``sample_index`` are
    skipped. ``passed`` is used when present, else ``correctness >= 0.5``,
    else the row counts as not passed. Later rows overwrite earlier ones
    with the same index.

    The returned dict is the cached object itself; callers must not mutate it.

    Raises:
        FileNotFoundError: if *results_jsonl_path* does not exist.
    """
    path = Path(results_jsonl_path)
    if not path.exists():
        raise FileNotFoundError(f"Zero-shot results file not found: {path}")

    mapping: dict[int, bool] = {}
    with path.open("r", encoding="utf-8") as handle:
        for raw_line in handle:
            line = raw_line.strip()
            if not line:
                continue
            row = json.loads(line)
            idx = row.get("sample_index")
            if idx is None:
                continue
            if "passed" in row:
                passed = bool(row["passed"])
            elif "correctness" in row:
                passed = float(row["correctness"]) >= 0.5
            else:
                passed = False
            mapping[int(idx)] = passed
    return mapping


@register_reward("token_utilisation_reward")
def token_utilisation_reward(
    prompts: list[str],
    completions: list[str],
    references: list[str],
    metadata: list[dict],
    results_jsonl_path: str,
) -> list[float]:
    """Reward training correctness relative to the zero-shot outcome.

    Reward logic:
      - if zero-shot was correct on this sample:
          - training correct   ->  0.0
          - training incorrect -> -1.0
      - if zero-shot was incorrect:
          - training correct   -> +1.0
          - training incorrect ->  0.0

    The zero-shot outcome is looked up by ``metadata[i]["sample_index"]`` in
    the file at *results_jsonl_path*; a missing index, or one absent from the
    file, is treated as zero-shot incorrect.

    Raises:
        ValueError: if completions, references and metadata differ in length.
        FileNotFoundError: if the results file does not exist.
    """
    del prompts
    if len(completions) != len(references) or len(completions) != len(metadata):
        raise ValueError("completions, references, and metadata must align.")

    zeroshot_pass = _load_zeroshot_correctness_by_index(results_jsonl_path)
    train_scores = gsm8k_correctness_reward(
        prompts=["" for _ in completions],
        completions=completions,
        references=references,
        metadata=metadata,
    )

    rewards: list[float] = []
    # Lengths were validated above, so metadata and scores pair up one-to-one.
    for meta, train_score in zip(metadata, train_scores):
        sample_index = meta.get("sample_index")
        zero_shot_correct = (
            bool(zeroshot_pass.get(int(sample_index), False))
            if sample_index is not None
            else False
        )
        train_correct = float(train_score) >= 0.5
        if zero_shot_correct:
            rewards.append(0.0 if train_correct else -1.0)
        else:
            rewards.append(1.0 if train_correct else 0.0)
    return rewards