from __future__ import annotations

import re

import torch
from transformers import PreTrainedTokenizerBase

# Inner reasoning body only (tags excluded from the masked span).
# NOTE(review): the pattern previously read r"(.*?)", which lazily matches the
# empty string at offset 0 and therefore never selected any token. The
# delimiters are restored here as <think>...</think>, inferred from this
# constant's name -- confirm they match the tags the model actually emits.
THINK_INNER_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)


def redacted_thinking_kl_scale(
    completion_ids: torch.Tensor,
    completion_mask: torch.Tensor,
    tokenizer: PreTrainedTokenizerBase,
    inner_kl_weight: float,
) -> torch.Tensor:
    """Return a per-token multiplier for the KL term in GRPO.

    - Default / outside the inner ``<think>...</think>`` body: 1.0
    - Tokens whose character span overlaps the inner body: ``inner_kl_weight``

    If ``inner_kl_weight == 1.0``, returns ones without decoding (fast path).

    Each row is decoded to text, searched for the first reasoning span, and
    re-tokenized with character offsets so the span can be mapped back to
    token positions. The function is best-effort: a row keeps a scale of 1.0
    everywhere when it has no valid tokens, contains no reasoning span, the
    tokenizer cannot produce offsets, or the re-tokenized length differs from
    the number of valid tokens (decode/encode did not round-trip).

    Args:
        completion_ids: 2-D tensor of token ids, unpacked as ``(bsz, seqlen)``.
        completion_mask: Mask indexed as ``[row, token]``; zero marks tokens to
            ignore. The first ``sum(mask_row)`` ids of each row are decoded,
            which assumes valid tokens form a prefix of the row (right
            padding) -- TODO confirm with the caller.
        tokenizer: Tokenizer used to decode and re-encode each row.
        inner_kl_weight: Multiplier applied to tokens inside the inner body.

    Returns:
        A float32 tensor of shape ``(bsz, seqlen)`` on ``completion_ids``'
        device.
    """
    device = completion_ids.device
    bsz, seqlen = completion_ids.shape
    scale = torch.ones((bsz, seqlen), device=device, dtype=torch.float32)
    if inner_kl_weight == 1.0:
        return scale

    weight = float(inner_kl_weight)

    # Move ids and mask to host once; per-token ``.item()`` calls would force
    # a device synchronisation for every token.
    ids_by_row = completion_ids.tolist()
    mask_by_row = completion_mask.tolist()

    for b, (id_row, mask_row) in enumerate(zip(ids_by_row, mask_by_row)):
        valid_len = int(sum(mask_row))
        if valid_len <= 0:
            continue

        # NOTE(review): if the reasoning tags are registered as special tokens,
        # ``skip_special_tokens=True`` would strip them from ``text`` and the
        # search below would never match -- verify against the tokenizer.
        text = tokenizer.decode(id_row[:valid_len], skip_special_tokens=True)
        match = THINK_INNER_RE.search(text)
        if not match:
            continue
        inner_start, inner_end = match.span(1)

        try:
            enc = tokenizer(
                text,
                add_special_tokens=False,
                return_offsets_mapping=True,
            )
        except (TypeError, NotImplementedError):
            # Offsets are only available from fast tokenizers; pure-Python
            # tokenizers raise NotImplementedError for this option. Leave the
            # row at 1.0 rather than failing the training step.
            continue

        offsets = enc.offset_mapping
        if len(enc.input_ids) != valid_len or len(offsets) != valid_len:
            # Re-tokenization does not line up with the original ids, so the
            # offsets cannot be trusted to index this row.
            continue

        # zip stops at min(seqlen, valid_len): mask_row has seqlen entries and
        # offsets has valid_len entries.
        inner_positions = [
            ti
            for ti, (keep, (cs, ce)) in enumerate(zip(mask_row, offsets))
            if int(keep) != 0  # masked-out token
            and ce > cs  # skip zero-width offsets
            and cs < inner_end
            and ce > inner_start
        ]
        if inner_positions:
            # Single indexed write instead of one device write per token.
            scale[b, inner_positions] = weight

    return scale