| from __future__ import annotations |
|
|
| import re |
|
|
| import torch |
| from transformers import PreTrainedTokenizerBase |
|
|
| |
THINK_INNER_RE = re.compile(
    r"<redacted_thinking>(.*?)</redacted_thinking>", re.DOTALL
)


def _redacted_thinking_char_offsets(
    tokenizer: PreTrainedTokenizerBase,
    text: str,
    row_ids: list[int],
) -> torch.Tensor | None:
    """Return per-token ``(char_start, char_end)`` offsets of ``row_ids`` in ``text``.

    The offsets come from re-encoding ``text``, so they are only trustworthy
    when that re-encoding reproduces ``row_ids`` token for token.

    Returns:
        A ``(len(row_ids), 2)`` long tensor, or ``None`` when the tokenizer
        cannot produce offsets or the round trip does not reproduce ``row_ids``.
    """
    try:
        enc = tokenizer(
            text,
            add_special_tokens=False,
            return_offsets_mapping=True,
        )
    except (TypeError, NotImplementedError):
        # Python ("slow") tokenizers raise NotImplementedError for
        # return_offsets_mapping; treat that like an unsupported kwarg.
        return None
    offsets = enc.offset_mapping
    # Compare ids, not just lengths: an equal-length but different
    # tokenization would attribute character spans to the wrong tokens.
    if list(enc.input_ids) != row_ids or len(offsets) != len(row_ids):
        return None
    return torch.tensor(offsets, dtype=torch.long).reshape(-1, 2)


def redacted_thinking_kl_scale(
    completion_ids: torch.Tensor,
    completion_mask: torch.Tensor,
    tokenizer: PreTrainedTokenizerBase,
    inner_kl_weight: float,
) -> torch.Tensor:
    """
    Per completion token, return a multiplier for the KL term in GRPO.

    - Default / outside the inner ``<redacted_thinking>...</redacted_thinking>`` body: 1.0
    - Tokens whose character span overlaps the inner body: ``inner_kl_weight``

    Only the first complete (opened *and* closed) block in each row is used;
    a row whose block is never closed keeps 1.0 everywhere.

    Args:
        completion_ids: ``(batch, seq_len)`` token ids.
        completion_mask: same shape as ``completion_ids``; nonzero marks a
            real token. The first ``count_nonzero(mask[b])`` positions of row
            ``b`` are treated as its completion (assumes right padding —
            confirm with the caller). Masked-out positions inside that prefix
            always keep 1.0.
        tokenizer: used to decode each row and to map tokens back to
            character offsets via ``return_offsets_mapping``.
        inner_kl_weight: multiplier applied to tokens inside the thinking body.

    Returns:
        ``float32`` tensor of shape ``(batch, seq_len)`` on
        ``completion_ids.device``.

    If ``inner_kl_weight == 1.0``, returns ones without decoding (fast path).
    Rows are left at 1.0 (best effort) when the tokenizer raises ``TypeError``
    or ``NotImplementedError`` for offset mapping, or when the decoded text
    does not re-encode to exactly the same ids.
    """
    device = completion_ids.device
    bsz, seqlen = completion_ids.shape
    if inner_kl_weight == 1.0:
        return torch.ones((bsz, seqlen), device=device, dtype=torch.float32)

    # Work on host copies: per-element `.item()` reads and indexed writes on a
    # GPU tensor would force a sync / kernel launch for every single token.
    ids_cpu = completion_ids.detach().cpu()
    mask_cpu = completion_mask.detach().cpu() != 0
    scale = torch.ones((bsz, seqlen), dtype=torch.float32)
    weight = float(inner_kl_weight)

    for b in range(bsz):
        valid_len = int(mask_cpu[b].sum())
        if valid_len <= 0:
            continue
        row_ids = ids_cpu[b, :valid_len].tolist()
        # Keep special tokens (e.g. a trailing EOS covered by the mask) and
        # skip whitespace cleanup so re-encoding can reproduce row_ids exactly;
        # otherwise the token/offset alignment check below always fails.
        text = tokenizer.decode(
            row_ids,
            skip_special_tokens=False,
            clean_up_tokenization_spaces=False,
        )
        match = THINK_INNER_RE.search(text)
        if match is None:
            continue
        inner_start, inner_end = match.span(1)

        offsets = _redacted_thinking_char_offsets(tokenizer, text, row_ids)
        if offsets is None:
            continue
        char_start, char_end = offsets[:, 0], offsets[:, 1]
        in_body = (
            (char_end > char_start)  # zero-width tokens cover no text
            & (char_start < inner_end)
            & (char_end > inner_start)
            & mask_cpu[b, :valid_len]
        )
        # The row slice is a view, so the in-place fill writes into `scale`.
        scale[b, :valid_len].masked_fill_(in_body, weight)

    return scale.to(device)
|
|