HusseinEid's picture
Super-squash branch 'main' using huggingface_hub
68a4c53
"""PUA candidate selection.
Two strategies are provided:
* `select_by_savings` — the production path. Ranks candidates by
`frequency * (baseline_token_cost - 1)` so PUA slots go to tokens that
actually save tokens vs the baseline (cl100k). Tokens whose baseline
cost is 1 (single byte/char) score 0 — byte fallback already handles
them optimally.
* `select_by_coverage` — the legacy v0 strategy. Ranks by raw frequency.
Kept for callers that haven't migrated; the `select_by_coverage_deprecated`
shim delegates to it and emits a `DeprecationWarning` on use.
The savings strategy also applies a tier-aware penalty: tokens beyond
`PUA_BMP_SIZE` would be assigned a 4-byte supplementary-plane PUA char
(vs 3 bytes for BMP). For tokens whose token-savings == 1, the byte cost
of substitution can erase the saving, so by default we cap the budget at
`PUA_BMP_SIZE` unless the caller opts into supplementary planes.
"""
from __future__ import annotations
import warnings
from collections import Counter
import regex as _re
from .baseline import BaselineTokenizer
from .pua import PUA_BMP_SIZE
# Default upper bound on token length (in characters, compared via `len`)
# used by the selection functions and `is_good_token` in this module.
DEFAULT_MAX_LEN = 50
# Heuristic regex for hash / UUID / base64-blob shapes — these waste PUA slots.
# Keep conservative: false positives just demote a token from PUA to byte
# fallback, which is fine.
# Hex-digest shape: 16 or more hex digits, case-insensitive.
_HEX_HASH_RE = _re.compile(r"^[0-9a-f]{16,}$", _re.IGNORECASE)
# Hyphenated 8-4-4-4-12 hex-digit UUID shape, case-insensitive.
_UUID_RE = _re.compile(
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", _re.IGNORECASE
)
# 24 or more characters drawn only from letters, digits, and `+ / = _ -`
# (the standard and URL-safe base64 alphabets plus padding).
_BASE64_BLOB_RE = _re.compile(r"^[A-Za-z0-9+/=_\-]{24,}$")
# A "blob" must be drawn entirely from the base64 alphabet and lack
# recognizable lowercase-word substructure. We approximate: if the token is
# at least 24 characters long, matches the base64 alphabet, and contains fewer
# than two runs of three or more consecutive lowercase letters, treat it as a
# blob.
def _looks_like_secret_or_hash(tok: str) -> bool:
    """Return True when `tok` has the shape of a UUID, hex hash, or base64 blob.

    The UUID and hex-hash patterns are decisive on their own. A token that
    matches the base64-blob pattern (24+ characters from the base64
    alphabet) is only reported as a blob when it contains fewer than two
    runs of three or more consecutive lowercase letters; tokens with two or
    more such runs are treated as "wordy" identifiers and let through.
    """
    if _UUID_RE.match(tok) or _HEX_HASH_RE.match(tok):
        return True

    # Anything short, or containing characters outside the base64 alphabet,
    # cannot be a blob.
    if len(tok) < 24 or not _BASE64_BLOB_RE.match(tok):
        return False

    # A real identifier has long-ish lowercase runs; a base64 blob tends to
    # interleave upper/lower/digit, so it rarely has two such runs.
    wordy_runs = _re.findall(r"[a-z]{3,}", tok)
    return len(wordy_runs) < 2
def is_good_token(tok: str, max_len: int = DEFAULT_MAX_LEN) -> bool:
    """Decide whether `tok` is eligible for a PUA slot.

    A token is rejected when any of the following holds:

    - It is empty, whitespace-only, or longer than `max_len` characters.
    - It encodes to at most one UTF-8 byte (byte fallback handles such
      tokens optimally, so a PUA assignment is at best a wash).
    - It looks like a hash, UUID, or long base64 blob (these waste PUA
      slots; near-zero cross-document reuse).

    Returns True only when none of the rejection rules apply.
    """
    # Structural rejections first: nothing to encode, or too long to bother.
    if not tok or tok.isspace() or len(tok) > max_len:
        return False

    utf8_size = len(tok.encode("utf-8"))
    if utf8_size <= 1:
        return False

    if _looks_like_secret_or_hash(tok):
        return False
    return True
def compute_candidate_score(
    token: str,
    frequency: int,
    baseline: BaselineTokenizer,
) -> float:
    """Score a candidate as `frequency` times its per-occurrence token saving.

    The per-occurrence saving is `baseline.count_tokens(token) - 1`, floored
    at zero. A token the baseline already encodes as one token therefore
    scores 0: replacing it would consume a PUA slot without shortening the
    token sequence. A token costing 5 baseline tokens that occurs 1000 times
    scores 4000.

    The score deliberately ignores PUA tiers (BMP vs supplementary plane):
    which tier a candidate lands in depends on how many other candidates
    outrank it, so tier handling is left to `select_by_savings`.
    """
    tokens_saved = baseline.count_tokens(token) - 1
    # Floor at zero so a zero-cost baseline result never yields a negative score.
    per_occurrence = tokens_saved if tokens_saved > 0 else 0
    return frequency * per_occurrence
def select_by_savings(
    freq: Counter[str],
    baseline: BaselineTokenizer,
    *,
    vocab_budget: int,
    max_len: int = DEFAULT_MAX_LEN,
    min_score: float = 1.0,
    allow_supplementary_pua: bool = False,
) -> list[str]:
    """Pick PUA candidates ranked by net token savings.

    Parameters
    ----------
    freq
        Mapping of token to occurrence count.
    baseline
        Baseline tokenizer passed to `compute_candidate_score` for each
        eligible token.
    vocab_budget
        Hard upper bound on the number of selected tokens. A value of zero
        or less yields an empty list. Capped at `PUA_BMP_SIZE` unless
        `allow_supplementary_pua` is True.
    max_len
        Forwarded to `is_good_token`; longer tokens are rejected.
    min_score
        Candidates scoring below this threshold are dropped. Default 1.0
        (i.e., at least one token saved across the entire corpus).
    allow_supplementary_pua
        If True, the `PUA_BMP_SIZE` cap is skipped and `vocab_budget` is
        used as given. Off by default because supplementary-plane PUA chars
        are 4 bytes each and net savings can go negative for
        `baseline_count==2` candidates.
        NOTE(review): no overall PUA-capacity limit is enforced in this
        function — confirm the caller bounds `vocab_budget` accordingly.

    Returns
    -------
    Tokens in deterministic order: descending score, ascending lex for ties.
    """
    if vocab_budget <= 0:
        return []

    limit = (
        vocab_budget
        if allow_supplementary_pua
        else min(vocab_budget, PUA_BMP_SIZE)
    )

    def _scored_candidates():
        # Yield (score, token) for every eligible token that clears min_score.
        for candidate, occurrences in freq.items():
            if not is_good_token(candidate, max_len=max_len):
                continue
            value = compute_candidate_score(candidate, occurrences, baseline)
            if value < min_score:
                continue
            yield value, candidate

    # Negated score gives descending order; the token itself breaks ties.
    ranked = sorted(_scored_candidates(), key=lambda pair: (-pair[0], pair[1]))
    return [candidate for _, candidate in ranked[:limit]]
def select_by_coverage(
    freq: Counter[str],
    coverage_target: float = 0.90,
    max_len: int = DEFAULT_MAX_LEN,
    max_tokens: int | None = None,
) -> list[str]:
    """Frequency-based selection (deprecated, retained for backward compat).

    Tokens are visited in descending frequency order (ascending lexical
    order for ties); tokens rejected by `is_good_token` are skipped.
    Selection stops at the first of: cumulative coverage ≥
    `coverage_target`, `max_tokens` selected, or input exhausted. New code
    should use `select_by_savings`.

    Parameters
    ----------
    freq
        Mapping of token to occurrence count.
    coverage_target
        Fraction of the total frequency mass to cover, in (0, 1]. The total
        includes tokens that `is_good_token` rejects, so the target may be
        unreachable; in that case every eligible token is returned.
    max_len
        Forwarded to `is_good_token`; longer tokens are rejected.
    max_tokens
        Optional cap on the number of selected tokens. `None` means no cap;
        zero or a negative value selects nothing.

    Returns
    -------
    Selected tokens in visiting order.

    Raises
    ------
    ValueError
        If `coverage_target` is not in (0, 1].
    """
    if not 0.0 < coverage_target <= 1.0:
        raise ValueError(f"coverage_target must be in (0,1], got {coverage_target}")
    # The in-loop cap check only runs after a token has been appended, so a
    # non-positive cap must be handled up front or one token would leak out.
    if max_tokens is not None and max_tokens <= 0:
        return []
    total = sum(freq.values())
    if total == 0:
        return []
    sorted_items = sorted(freq.items(), key=lambda kv: (-kv[1], kv[0]))
    selected: list[str] = []
    cumulative = 0
    threshold = coverage_target * total
    for token, count in sorted_items:
        if not is_good_token(token, max_len=max_len):
            continue
        selected.append(token)
        cumulative += count
        if max_tokens is not None and len(selected) >= max_tokens:
            break
        if cumulative >= threshold:
            break
    return selected
def coverage_of(freq: Counter[str], tokens: list[str]) -> float:
    """Fraction of total frequency covered by `tokens`. Used for diagnostics.

    Each distinct token is counted once, so a list with repeated entries
    cannot inflate the result. Tokens absent from `freq` contribute zero.
    Returns 0.0 when the total frequency is zero.
    """
    total = sum(freq.values())
    if total == 0:
        return 0.0
    # dict.fromkeys drops duplicates while keeping first-seen order, so the
    # summation order stays deterministic.
    distinct = dict.fromkeys(tokens)
    covered = sum(freq.get(tok, 0) for tok in distinct)
    return covered / total
def select_by_coverage_deprecated(
    freq: Counter[str],
    coverage_target: float = 0.90,
    max_len: int = DEFAULT_MAX_LEN,
    max_tokens: int | None = None,
) -> list[str]:
    """Warn that coverage-based selection is deprecated, then run it.

    Issues a `DeprecationWarning` and returns whatever `select_by_coverage`
    returns for the same arguments.
    """
    message = (
        "select_by_coverage is deprecated; use select_by_savings with a "
        "BaselineTokenizer for production builds."
    )
    # stacklevel=2 attributes the warning to this shim's caller.
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return select_by_coverage(
        freq,
        coverage_target=coverage_target,
        max_len=max_len,
        max_tokens=max_tokens,
    )
# Public API of this module. The underscore-prefixed helper and regex
# constants are intentionally not listed.
__all__ = [
"DEFAULT_MAX_LEN",
"compute_candidate_score",
"coverage_of",
"is_good_token",
"select_by_coverage",
"select_by_coverage_deprecated",
"select_by_savings",
]