from __future__ import annotations
import json
import re
from functools import lru_cache
from pathlib import Path
from .registry import register_reward
# Plain numeric literal: optional sign, digits, optional ",ddd" groups of
# exactly three digits, and an optional decimal fraction.
NUMBER_RE = re.compile(r"[-+]?\d+(?:,\d{3})*(?:\.\d+)?")
# Integer ratio such as "3/4" spanning the whole string; whitespace around the
# slash is tolerated. Group 1 is the numerator, group 2 the denominator.
FRAC_RE = re.compile(r"^([-+]?\d+)\s*/\s*([-+]?\d+)$")
# "\frac{a}{b}" spanning the whole string, where neither argument contains a
# brace (so nested fractions do not match).
LATEX_FRAC_RE = re.compile(r"^\\frac\{([^{}]+)\}\{([^{}]+)\}$")
# Whole-completion layout: group 1 is the text that precedes a "\boxed{",
# group 2 is the text up to a closing "}" that is followed only by whitespace.
# DOTALL lets "." cross newlines.
STRICT_FORMAT_RE = re.compile(
    r"^\s*(.*?)\s*\\boxed\{(.+?)\}\s*$",
    re.DOTALL,
)
# Matches only the opening "\boxed{"; brace balancing is done by the caller.
BOXED_RE = re.compile(r"\\boxed\{")
# NOTE(review): this pattern has no delimiters, so it lazily matches the empty
# string at any position. The name suggests surrounding markers may have been
# lost - confirm. It is not referenced by the code visible in this file.
THINK_RE = re.compile(r"(.*?)", re.DOTALL)
def _extract_last_boxed(text: str) -> str:
starts = [match.start() for match in BOXED_RE.finditer(text)]
if not starts:
return ""
start = starts[-1]
idx = start + len("\\boxed{")
depth = 1
content_chars: list[str] = []
while idx < len(text):
ch = text[idx]
if ch == "{":
depth += 1
content_chars.append(ch)
elif ch == "}":
depth -= 1
if depth == 0:
return "".join(content_chars).strip()
content_chars.append(ch)
else:
content_chars.append(ch)
idx += 1
return ""
def _extract_numbers(text: str) -> list[float]:
    """Collect every numeric literal that NUMBER_RE finds in *text*.

    Commas inside a match are dropped before conversion, so ``"1,234"``
    becomes ``1234.0``. Values are returned in order of appearance.
    """
    values: list[float] = []
    for hit in NUMBER_RE.finditer(text):
        digits = hit.group().replace(",", "")
        try:
            value = float(digits)
        except ValueError:
            # Skip anything float() rejects rather than failing the whole scan.
            continue
        values.append(value)
    return values
def _extract_reference_target(reference: str) -> float | None:
    """Return the numeric target of a reference solution, if one can be found.

    The reference answer text is parsed as a single number first; when that
    fails, the last number appearing in the answer text is used instead.
    Returns ``None`` when the answer text contains no number at all.
    """
    answer = _extract_reference_answer_text(reference)
    target = _parse_numeric(answer)
    if target is None:
        # Fall back to the final number mentioned in the answer text.
        candidates = _extract_numbers(answer)
        target = candidates[-1] if candidates else None
    return target
def _normalize_answer_text(text: str) -> str:
normalized = text.strip()
normalized = normalized.replace("$", "")
normalized = normalized.replace("\\left", "").replace("\\right", "")
normalized = re.sub(r"\s+", "", normalized)
return normalized
def _parse_numeric(text: str) -> float | None:
    r"""Parse an answer string as a single numeric value.

    The text is normalized with ``_normalize_answer_text`` and then tried, in
    order, as:

    1. an integer ratio ``a/b`` (FRAC_RE),
    2. a LaTeX fraction ``\frac{a}{b}`` whose arguments are themselves parsed
       recursively (LATEX_FRAC_RE),
    3. anything ``float()`` accepts, after removing thousands separators.

    Commas are treated as thousands separators only when each one is followed
    by exactly three digits. Text such as ``"1,2"`` is therefore rejected
    rather than being silently read as ``12``.

    Args:
        text: Raw answer text.

    Returns:
        The parsed value, or ``None`` when the text is empty, is not a single
        number, or is a fraction with a zero denominator.
    """
    source = _normalize_answer_text(text)
    if not source:
        return None

    frac_match = FRAC_RE.match(source)
    if frac_match:
        denominator = float(frac_match.group(2))
        if denominator == 0:
            return None
        return float(frac_match.group(1)) / denominator

    latex_frac_match = LATEX_FRAC_RE.match(source)
    if latex_frac_match:
        num = _parse_numeric(latex_frac_match.group(1))
        den = _parse_numeric(latex_frac_match.group(2))
        if num is None or den is None or den == 0:
            return None
        return num / den

    if "," in source:
        # A comma that is not followed by exactly three digits is a list or
        # tuple separator, not digit grouping, so this is not one number.
        if re.search(r",(?!\d{3}(?!\d))", source):
            return None
        source = source.replace(",", "")

    try:
        return float(source)
    except ValueError:
        return None
def _extract_reference_answer_text(reference: str) -> str:
    r"""Pull the final-answer text out of a reference solution.

    Sources are tried in this order:

    1. the content of the last ``\boxed{...}`` (Hendrycks MATH style),
    2. the rest of the line after a ``####`` marker (GSM8K style),
    3. the whole reference, stripped.
    """
    boxed_answer = _extract_last_boxed(reference)
    if boxed_answer:
        return boxed_answer

    # GSM8K references put the final answer on the line after "####".
    marker = re.search(r"####\s*([^\n\r]+)", reference)
    if marker is None:
        return reference.strip()
    return marker.group(1).strip()
def _extract_predicted_answer_text(completion: str) -> str:
    r"""Return the model's answer: the content of the last ``\boxed{...}``.

    An empty string means the completion has no usable boxed answer.
    """
    boxed_answer = _extract_last_boxed(completion)
    return boxed_answer
def _is_close(a: float, b: float) -> bool:
# Allow small rounding differences for decimal answers.
return abs(a - b) <= max(1e-3, 1e-3 * max(abs(a), abs(b)))
def _has_strict_format(completion: str) -> bool:
    """Check that *completion* follows the layout required by STRICT_FORMAT_RE.

    Returns ``True`` only when the pattern matches and both captured parts
    (the text before the box and the boxed content) are non-empty after
    stripping whitespace.
    """
    parsed = STRICT_FORMAT_RE.match(completion)
    if parsed is None:
        return False
    leading_text, boxed_text = (part.strip() for part in parsed.groups())
    return bool(leading_text) and bool(boxed_text)
def _completion_length_tokens(text: str) -> int:
# Count approximate generation length as whitespace-separated tokens.
return len(text.split())
@lru_cache(maxsize=4)
def _cached_tokenizer(
    tokenizer_name: str, cache_dir: str | None, trust_remote_code: bool
):
    """Load a tokenizer via ``AutoTokenizer.from_pretrained`` and memoize it.

    Up to four distinct argument combinations are kept by the LRU cache.
    ``local_files_only=True`` is always passed to the loader.
    """
    # Imported here so transformers is only required when a tokenizer is
    # actually requested.
    from transformers import AutoTokenizer

    load_options = {
        "cache_dir": cache_dir,
        "trust_remote_code": trust_remote_code,
        "local_files_only": True,
    }
    return AutoTokenizer.from_pretrained(tokenizer_name, **load_options)
def _think_length_tokens(
    text: str,
    tokenizer_name: str | None = None,
    cache_dir: str | None = None,
    trust_remote_code: bool = False,
) -> int:
    """Measure the reasoning part of a strictly formatted completion.

    The reasoning part is group 1 of STRICT_FORMAT_RE, i.e. the text before
    the boxed answer. Returns 0 when the completion does not match the
    pattern or the reasoning part is empty.

    Args:
        text: The completion to measure.
        tokenizer_name: When given, the length is the number of ids produced
            by that tokenizer without special tokens. Otherwise a
            whitespace word count is used.
        cache_dir: Forwarded to the tokenizer loader.
        trust_remote_code: Forwarded to the tokenizer loader.
    """
    parsed = STRICT_FORMAT_RE.match(text)
    reasoning = parsed.group(1).strip() if parsed is not None else ""
    if not reasoning:
        return 0

    if not tokenizer_name:
        # Fallback (should rarely be used): approximate by whitespace words.
        return len(reasoning.split())

    tokenizer = _cached_tokenizer(
        tokenizer_name=tokenizer_name,
        cache_dir=cache_dir,
        trust_remote_code=trust_remote_code,
    )
    return len(tokenizer.encode(reasoning, add_special_tokens=False))
def _length_penalty_scores_for_group(
    completions: list[str],
    tokenizer_name: str | None = None,
    cache_dir: str | None = None,
    trust_remote_code: bool = False,
) -> list[float]:
    """Score each completion by how much shorter it is than the group mean.

    Each score is ``max(0.0, 1.0 - length / mean_length)``, where lengths come
    from ``_think_length_tokens``. A completion at or above the mean scores
    0.0. If the mean length is not positive, every score is 0.0. An empty
    group yields an empty list.

    NOTE(review): a completion whose measured length is 0 receives the
    maximum score of 1.0 whenever the group mean is positive.
    """
    if not completions:
        return []

    think_lengths = [
        _think_length_tokens(
            completion,
            tokenizer_name=tokenizer_name,
            cache_dir=cache_dir,
            trust_remote_code=trust_remote_code,
        )
        for completion in completions
    ]
    mean_length = sum(think_lengths) / len(think_lengths)
    if mean_length <= 0:
        return [0.0] * len(completions)
    return [max(0.0, 1.0 - (count / mean_length)) for count in think_lengths]
@register_reward("format_tag_reward")
def format_tag_reward(
    prompts: list[str],
    completions: list[str],
    references: list[str],
    metadata: list[dict],
) -> list[float]:
    """Give 1.0 to each completion that passes ``_has_strict_format``, else 0.0.

    ``prompts``, ``references`` and ``metadata`` are accepted for the shared
    reward signature but are not used.
    """
    del prompts, references, metadata
    return [float(_has_strict_format(completion)) for completion in completions]
@register_reward("length_penalty_reward")
def length_penalty_reward(
    prompts: list[str],
    completions: list[str],
    references: list[str],
    metadata: list[dict],
    group_size: int | None = None,
    tokenizer_name: str | None = None,
    cache_dir: str | None = None,
    trust_remote_code: bool = False,
) -> list[float]:
    """Reward completions that are shorter than their group's mean length.

    Completions are split into groups and each group is scored independently
    by ``_length_penalty_scores_for_group``.

    Args:
        prompts: Prompt for each completion; only consulted when
            ``group_size`` is not a positive integer.
        completions: Generated completions, with group members adjacent.
        references: Unused.
        metadata: Unused.
        group_size: When positive, completions are chunked into consecutive
            groups of this size (the last group may be shorter). Otherwise
            groups are runs of consecutive identical prompts.
        tokenizer_name: Forwarded to the length measurement.
        cache_dir: Forwarded to the length measurement.
        trust_remote_code: Forwarded to the length measurement.

    Returns:
        One score per completion, in input order.
    """
    del references, metadata
    if not completions:
        return []

    def score_group(group: list[str]) -> list[float]:
        return _length_penalty_scores_for_group(
            group,
            tokenizer_name=tokenizer_name,
            cache_dir=cache_dir,
            trust_remote_code=trust_remote_code,
        )

    total = len(completions)

    if group_size is not None and group_size > 0:
        chunked_scores: list[float] = []
        for offset in range(0, total, group_size):
            chunked_scores.extend(score_group(completions[offset : offset + group_size]))
        return chunked_scores

    # Fallback: infer groups by contiguous prompt text.
    scores = [0.0] * total
    begin = 0
    while begin < total:
        stop = begin + 1
        while stop < total and prompts[stop] == prompts[begin]:
            stop += 1
        scores[begin:stop] = score_group(completions[begin:stop])
        begin = stop
    return scores
@register_reward("gsm8k_correctness_reward")
def gsm8k_correctness_reward(
    prompts: list[str],
    completions: list[str],
    references: list[str],
    metadata: list[dict],
) -> list[float]:
    """Score each completion 1.0 if its boxed answer matches the reference.

    A completion is correct when its normalized answer text equals the
    normalized reference answer, or when both parse as numbers that are
    close according to ``_is_close``. Everything else, including a missing
    answer or non-matching symbolic forms, scores 0.0.

    ``prompts`` and ``metadata`` are unused.

    Raises:
        ValueError: If ``completions`` and ``references`` differ in length.
    """
    del prompts, metadata

    def grade(completion: str, reference: str) -> float:
        predicted = _extract_predicted_answer_text(completion)
        expected = _extract_reference_answer_text(reference)
        if not (predicted and expected):
            return 0.0

        predicted_norm = _normalize_answer_text(predicted)
        expected_norm = _normalize_answer_text(expected)
        if predicted_norm and predicted_norm == expected_norm:
            return 1.0

        predicted_value = _parse_numeric(predicted)
        expected_value = _parse_numeric(expected)
        if predicted_value is None or expected_value is None:
            # If symbolic forms don't exactly match, give no correctness credit.
            return 0.0
        return 1.0 if _is_close(predicted_value, expected_value) else 0.0

    return [
        grade(completion, reference)
        for completion, reference in zip(completions, references, strict=True)
    ]
@lru_cache(maxsize=8)
def _load_zeroshot_correctness_by_index(results_jsonl_path: str) -> dict[int, bool]:
path = Path(results_jsonl_path)
if not path.exists():
raise FileNotFoundError(f"Zero-shot results file not found: {path}")
mapping: dict[int, bool] = {}
with path.open("r", encoding="utf-8") as handle:
for raw_line in handle:
line = raw_line.strip()
if not line:
continue
row = json.loads(line)
idx = row.get("sample_index")
if idx is None:
continue
if "passed" in row:
passed = bool(row["passed"])
elif "correctness" in row:
passed = float(row["correctness"]) >= 0.5
else:
passed = False
mapping[int(idx)] = passed
return mapping
@register_reward("token_utilisation_reward")
def token_utilisation_reward(
    prompts: list[str],
    completions: list[str],
    references: list[str],
    metadata: list[dict],
    results_jsonl_path: str,
) -> list[float]:
    """Reward training completions relative to a zero-shot baseline.

    Reward logic:
    - if zero-shot was correct on this sample:
        - training correct   -> 0.0
        - training incorrect -> -1.0
    - if zero-shot was incorrect:
        - training correct   -> +1.0
        - training incorrect -> 0.0

    Training correctness is ``gsm8k_correctness_reward(...) >= 0.5``. The
    zero-shot outcome is looked up by ``metadata[i]["sample_index"]`` in the
    results file. A sample with no ``sample_index``, or one absent from the
    file, is treated as zero-shot incorrect.

    Args:
        prompts: Unused.
        completions: Generated completions.
        references: Reference solutions, aligned with ``completions``.
        metadata: Per-sample dicts, aligned with ``completions``.
        results_jsonl_path: Path to the zero-shot results JSONL file.

    Returns:
        One reward per completion, in input order.

    Raises:
        ValueError: If the three input lists differ in length.
        FileNotFoundError: If the results file does not exist.
    """
    del prompts
    if len(completions) != len(references) or len(completions) != len(metadata):
        raise ValueError("completions, references, and metadata must align.")

    zeroshot_pass = _load_zeroshot_correctness_by_index(results_jsonl_path)
    train_scores = gsm8k_correctness_reward(
        prompts=["" for _ in completions],
        completions=completions,
        references=references,
        metadata=metadata,
    )

    rewards: list[float] = []
    # Lengths were validated above, so metadata and scores pair up one-to-one.
    for meta, train_score in zip(metadata, train_scores, strict=True):
        sample_index = meta.get("sample_index")
        zero_shot_correct = sample_index is not None and bool(
            zeroshot_pass.get(int(sample_index), False)
        )
        train_correct = float(train_score) >= 0.5
        if zero_shot_correct:
            rewards.append(0.0 if train_correct else -1.0)
        else:
            rewards.append(1.0 if train_correct else 0.0)
    return rewards