| """PUA-aware decoding.
|
|
|
| After the underlying ByteLevel BPE decoder reconstructs a string, we must
|
| substitute every PUA character back to its original word. This is a single
|
| linear scan with a dict lookup per character — O(n).
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| from ._accel_loader import USE_RUST, accel, prepare_mapping
|
| from .pua import PUAMapping, is_pua_char
|
|
|
|
|
def reverse_pua_batch(texts: list[str], mapping: PUAMapping) -> list[str]:
    """Batched reverse PUA substitution.

    Applies the same substitution as ``reverse_pua_substitute`` to every
    string in ``texts``. Returns a new list with the same length and order.

    Dispatch:
      * The Rust accelerator is active and exposes ``reverse_pua_batch``:
        the whole batch goes through a single FFI call.
      * The accelerator is active but this build lacks ``reverse_pua_batch``:
        the mapping is converted with ``prepare_mapping`` once and reused for
        every per-text ``accel.reverse_pua_substitute`` call. It is not
        re-converted per text.
      * Otherwise: the pure-Python ``reverse_pua_substitute`` runs per text.

    Args:
        texts: Decoded strings that may contain PUA characters.
        mapping: The PUA mapping whose entries are substituted back.

    Returns:
        A new list of strings with mapped PUA characters replaced by their
        original words.
    """
    if USE_RUST:
        # Convert the mapping exactly once, whichever Rust entry point is used.
        prepared = prepare_mapping(mapping)
        if hasattr(accel, "reverse_pua_batch"):
            return list(accel.reverse_pua_batch(texts, prepared))
        # Accelerator builds without the batch entry point: loop in Python but
        # still reuse the single prepared mapping.
        return [accel.reverse_pua_substitute(t, prepared) for t in texts]
    return [reverse_pua_substitute(t, mapping) for t in texts]
|
|
|
|
|
def reverse_pua_substitute(text: str, mapping: PUAMapping) -> str:
    """Expand PUA characters in ``text`` back into their mapped words.

    Each character that is a PUA character and has an entry in
    ``mapping.pua_to_word`` is replaced by that entry. Every other character
    is copied through unchanged, including PUA characters with no mapping
    entry, so unmapped PUA characters round-trip as themselves. When the Rust
    accelerator is active, the work is delegated to it.

    Args:
        text: A decoded string, possibly containing PUA characters.
        mapping: The PUA mapping to invert.

    Returns:
        The substituted string. On the Python path, ``text`` itself is
        returned when the mapping table is empty or ``text`` contains no
        PUA characters.
    """
    if USE_RUST:
        return accel.reverse_pua_substitute(text, prepare_mapping(mapping))

    table = mapping.pua_to_word
    # Fast exits: skip the per-character rebuild when there is nothing to do.
    if not table or not any(map(is_pua_char, text)):
        return text

    return "".join(
        table.get(ch, ch) if is_pua_char(ch) else ch
        for ch in text
    )
|
|
|
|
|
# Public API of this module: both the single-string and batched entry points.
__all__ = ["reverse_pua_batch", "reverse_pua_substitute"]
|
|
|