HusseinEid's picture
Super-squash branch 'main' using huggingface_hub
68a4c53
"""PUA-aware decoding.
After the underlying ByteLevel BPE decoder reconstructs a string, we must
substitute every PUA character back to its original word. In the pure-Python
fallback this is a linear pass over the text with one dict lookup per PUA
character — O(n) in the length of the text.
"""
from __future__ import annotations
from ._accel_loader import USE_RUST, accel, prepare_mapping
from .pua import PUAMapping, is_pua_char
def reverse_pua_batch(texts: list[str], mapping: PUAMapping) -> list[str]:
    """Reverse the PUA substitution for every string in `texts`.

    When the Rust accelerator is enabled and exposes a ``reverse_pua_batch``
    entry point, the whole list is handed to it in a single call and the
    result is materialised as a list. Otherwise each text is processed
    individually through `reverse_pua_substitute`.

    Args:
        texts: Decoded strings that may contain PUA characters.
        mapping: The PUA mapping used to restore the original words.

    Returns:
        A new list with one restored string per input, in the same order.
    """
    accelerated = USE_RUST and hasattr(accel, "reverse_pua_batch")
    if not accelerated:
        # Fallback: per-text substitution (which may itself still use the
        # accelerator's single-text entry point).
        restored: list[str] = []
        for item in texts:
            restored.append(reverse_pua_substitute(item, mapping))
        return restored

    batch_result = accel.reverse_pua_batch(texts, prepare_mapping(mapping))
    return list(batch_result)
def reverse_pua_substitute(text: str, mapping: PUAMapping) -> str:
    """Restore the original word for each mapped PUA character in `text`.

    With the Rust accelerator enabled the work is delegated to
    ``accel.reverse_pua_substitute``. In the pure-Python path, every
    character recognised by `is_pua_char` is looked up in
    ``mapping.pua_to_word``; a PUA character with no entry is emitted
    unchanged, as is every non-PUA character.

    Args:
        text: Decoded string that may contain PUA characters.
        mapping: The PUA mapping providing the ``pua_to_word`` table.

    Returns:
        The restored string. `text` itself is returned when the table is
        empty or when `text` contains no PUA characters.
    """
    if USE_RUST:
        return accel.reverse_pua_substitute(text, prepare_mapping(mapping))

    table = mapping.pua_to_word
    # Nothing to do: either there is no table, or the text carries no PUA
    # characters at all (cheap pre-check that avoids rebuilding the string).
    if not table or all(not is_pua_char(c) for c in text):
        return text

    pieces = [table.get(ch, ch) if is_pua_char(ch) else ch for ch in text]
    return "".join(pieces)
# Public API of this module: the single-text and the batched entry points.
# `reverse_pua_batch` is a public, documented function defined here, so it is
# exported alongside `reverse_pua_substitute` for `from module import *`.
__all__ = ["reverse_pua_batch", "reverse_pua_substitute"]