Twitch-BPE / src /pretokenizer.py
Soldier-Boy's picture
create: src files
c6e5251 verified
from __future__ import annotations
import re
from typing import List, Tuple
from . import config as CFG
# Rough emoji + ZWJ sequence pattern (approximation)
EMOJI_SEQ_RE = re.compile(r"(?:[\U0001F1E6-\U0001FAD6\U0001F300-\U0001FAFF][\u200D\uFE0F]?)+", flags=re.UNICODE)
PUNCT_RUN_RE = re.compile(r"[!?]{2,}")
# Repeated chat slang runs (not strictly used yet but reserved for protection)
LOL_RE = re.compile(r"(?i)l+o+l+")
KEKW_RE = re.compile(r"(?i)k+e+k+w+")
URL_RE = re.compile(CFG.PROTECT_REGEX["url"])
USER_RE = re.compile(CFG.PROTECT_REGEX["user"])
TIME_RE = re.compile(CFG.PROTECT_REGEX["time"])
NUMBER_RE = re.compile(CFG.PROTECT_REGEX["number"])
EMOTE_RE = re.compile(CFG.PROTECT_REGEX["emote"])
# Canonical base token serializers
def byte_sym(b: int) -> str:
# 2 hex digits; pattern <b:HH>
return f"<b:{b:02X}>"
def macro_sym(text: str) -> str:
# Keep placeholder macros as their literal tokens
return f"<m:{text}>"
def to_bytes_symbols(text: str) -> List[str]:
return [byte_sym(b) for b in text.encode('utf-8')]
def pretokenize_line(line: str) -> List[str]:
"""Pretokenize a line into a list of base symbols according to granularity.
Modes:
- byte: fall back to per UTF-8 byte symbols (<b:HH>) for all non-macros.
- char: treat each Unicode character as a symbol (macros kept intact).
- word: split raw text into tokens as consecutive non-space plus trailing spaces (macros kept intact).
Macros (placeholders, emoji sequences) are always wrapped as <m:...> to remain atomic.
"""
s = line
# URL and USER already replaced in cleaning, but double-guard here
s = URL_RE.sub(CFG.PLACEHOLDERS["URL"], s)
s = USER_RE.sub(CFG.PLACEHOLDERS["USER"], s)
# Collect macro spans (emoji, placeholders)
token_spans: List[Tuple[int, int, str | None]] = []
for m in EMOJI_SEQ_RE.finditer(s):
token_spans.append((m.start(), m.end(), m.group(0)))
for rx in (URL_RE, USER_RE):
for m in rx.finditer(s):
token_spans.append((m.start(), m.end(), m.group(0)))
gran = getattr(CFG, 'TOKEN_GRANULARITY', 'byte')
token_spans.sort(key=lambda x: x[0])
out: List[str] = []
if gran == 'byte':
i = 0
j = 0
while i < len(s):
if j < len(token_spans) and i == token_spans[j][0]:
_, e, lit = token_spans[j]
out.append(macro_sym(lit))
i = e
j += 1
continue
out.extend(to_bytes_symbols(s[i]))
i += 1
while j < len(token_spans) and i >= token_spans[j][1]:
j += 1
return out
if gran == 'char':
i = 0
j = 0
while i < len(s):
if j < len(token_spans) and i == token_spans[j][0]:
_, e, lit = token_spans[j]
out.append(macro_sym(lit))
i = e
j += 1
continue
out.append(s[i])
i += 1
while j < len(token_spans) and i >= token_spans[j][1]:
j += 1
return out
if gran == 'word':
# Build segments alternating raw substrings and macro tokens
segments: List[str] = []
last = 0
for (start, end, lit) in token_spans:
if start > last:
segments.append(s[last:start])
segments.append(macro_sym(lit))
last = end
if last < len(s):
segments.append(s[last:])
for seg in segments:
if seg.startswith('<m:'):
out.append(seg)
else:
# Emit words and whitespace as separate tokens to allow cross-word merges
for m in re.finditer(r"\S+|\s+", seg, flags=re.UNICODE):
out.append(m.group(0))
return out
# Fallback to byte behavior if mode unknown
i = 0
j = 0
while i < len(s):
if j < len(token_spans) and i == token_spans[j][0]:
_, e, lit = token_spans[j]
out.append(macro_sym(lit))
i = e
j += 1
continue
out.extend(to_bytes_symbols(s[i]))
i += 1
while j < len(token_spans) and i >= token_spans[j][1]:
j += 1
return out