from __future__ import annotations
import re
import torch
from transformers import PreTrainedTokenizerBase
# Inner reasoning body only (tags excluded from the masked span).
# Group 1 is the span whose character offsets are read via ``match.span(1)``;
# ``re.DOTALL`` lets ``.`` cross newlines so a multi-line body is one span.
# NOTE(review): as written the pattern has no literal delimiters around the
# lazy group, so ``search`` always matches the empty string at offset 0 and
# ``span(1)`` is ``(0, 0)``. Presumably the opening/closing tag literals were
# lost (e.g. stripped as markup) -- TODO confirm and restore them.
THINK_INNER_RE = re.compile(
r"(.*?)", re.DOTALL
)
def redacted_thinking_kl_scale(
    completion_ids: torch.Tensor,
    completion_mask: torch.Tensor,
    tokenizer: PreTrainedTokenizerBase,
    inner_kl_weight: float,
) -> torch.Tensor:
    """Return a per-completion-token multiplier for the KL term in GRPO.

    Every position starts at 1.0. For each row, the unmasked prefix is decoded
    to text, ``THINK_INNER_RE`` is searched in it, and every token whose
    character span overlaps capture group 1 (the inner body) is set to
    ``inner_kl_weight``. Only the first match per row is used.

    Args:
        completion_ids: Token ids, indexed as ``(batch, seq)``.
        completion_mask: Mask indexed like ``completion_ids``; zero entries are
            never rescaled. The row is sliced as ``[:mask.sum()]``, so valid
            tokens are assumed to form a left-aligned prefix.
        tokenizer: Used to decode each row and re-encode it with offsets.
        inner_kl_weight: Multiplier applied to tokens inside the inner body.

    Returns:
        A float32 tensor shaped like ``completion_ids`` on the same device.

    A row is left at 1.0 (best effort, no error) when it has no valid tokens,
    the pattern does not match, the tokenizer cannot produce an offset mapping,
    or re-encoding the decoded text yields a different number of tokens (the
    offsets would then not line up with the original positions).

    If ``inner_kl_weight == 1.0`` the result is all ones and nothing is
    decoded (fast path).
    """
    bsz, seqlen = completion_ids.shape
    scale = torch.ones(
        (bsz, seqlen), device=completion_ids.device, dtype=torch.float32
    )
    if inner_kl_weight == 1.0:
        return scale

    weight = float(inner_kl_weight)
    for b in range(bsz):
        # One host transfer per row instead of an ``.item()`` sync per token.
        mask_row = completion_mask[b].tolist()
        valid_len = int(sum(mask_row))
        if valid_len <= 0:
            continue

        row_ids = completion_ids[b, :valid_len].tolist()
        text = tokenizer.decode(row_ids, skip_special_tokens=True)
        match = THINK_INNER_RE.search(text)
        if match is None:
            continue
        inner_start, inner_end = match.span(1)

        try:
            enc = tokenizer(
                text,
                add_special_tokens=False,
                return_offsets_mapping=True,
            )
        except (TypeError, NotImplementedError):
            # Slow (pure-Python) tokenizers raise NotImplementedError for
            # ``return_offsets_mapping``; custom callables may reject the
            # keyword with TypeError. Either way the row keeps scale 1.0.
            continue

        offsets = enc.offset_mapping
        if len(enc.input_ids) != valid_len or len(offsets) != valid_len:
            # Decode/encode round trip changed the token count: the offsets no
            # longer correspond one-to-one with the original positions.
            continue

        # ``zip`` stops at the shorter of the mask row and the offsets, so the
        # index can never run past ``seqlen``.
        inner_positions = [
            ti
            for ti, (keep, (cs, ce)) in enumerate(zip(mask_row, offsets))
            if int(keep) != 0
            and ce > cs  # skip zero-width / degenerate spans
            and cs < inner_end
            and ce > inner_start
        ]
        if inner_positions:
            # Single indexed write per row rather than one write per token.
            scale[b, inner_positions] = weight
    return scale