# tokenizer.py
# Wrapper for ParadigmFinder segmentation + portable HF tokenizer
from typing import List, Tuple, Optional, Union, Dict, Any
import os, json, re
from transformers import PreTrainedTokenizerFast
from huggingface_hub import hf_hub_download
def _get_repo_file(repo_id_or_path: str, filename: str, revision: Optional[str] = None) -> str:
"""
If `repo_id_or_path` is a local folder and the file exists there, return its path.
Otherwise download it from the Hub and return the cached path.
"""
local = os.path.join(repo_id_or_path, filename)
if os.path.exists(local):
return local
return hf_hub_download(repo_id=repo_id_or_path, filename=filename, revision=revision)
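# Illustrative resolution behavior (both arguments below are hypothetical):
#   _get_repo_file("/path/to/local/checkout", "paradigms.json") -> that local file, if present
#   _get_repo_file("some-user/some-repo", "paradigms.json")     -> cached path from hf_hub_download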
def _deserialize_suffixes_from_json(sfx_list):
    out = set()
    for item in sfx_list:
        if isinstance(item, list):
            # JSON nested: [base, nested_list]; recurse so deeper nesting stays hashable
            base, nested = item
            out.add((base, frozenset(_deserialize_suffixes_from_json(nested))))
        else:
            out.add(item)  # plain string like "", "ing", "s"
    return out
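# Illustration of the encoding handled above (example values mirror the shapes described
# in the comments of this file):
#   _deserialize_suffixes_from_json(["", "ing", ["er", ["", "s"]]])
#   -> {"", "ing", ("er", frozenset({"", "s"}))}
# Plain strings stay strings; a nested [base, nested_list] pair becomes a hashable
# (base, frozenset_of_suffixes) tuple so it can live inside the outer set.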
def _load_paradigms_any(path):
with open(path, "r", encoding="utf-8") as f:
payload = json.load(f)
# Case A: new schema with top-level dict {"paradigms": [...]}
if isinstance(payload, dict) and "paradigms" in payload:
paradigms = []
for p in payload["paradigms"]:
stems = set(p["stems"])
suffixes = _deserialize_suffixes_from_json(p["suffixes"])
paradigms.append((stems, suffixes))
meta = payload.get("meta", {})
return paradigms, meta
# Case B: older “list of pairs” JSON [[stems, suffixes], ...]
if isinstance(payload, list) and payload and isinstance(payload[0], list):
paradigms = []
for stems, suffixes in payload:
stems = set(stems)
# suffixes may be ["", ["er", ["", "s"]], "ing"] or already strings
norm = _deserialize_suffixes_from_json(suffixes)
paradigms.append((stems, norm))
return paradigms, {}
# Case C: already python-native structure (rare if not using JSON)
if isinstance(payload, list) and payload and isinstance(payload[0], (list, tuple)) and len(payload[0]) == 2:
return payload, {}
raise ValueError("Unrecognized paradigms.json format")
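# Sketch of the two JSON layouts accepted above (field values are illustrative only):
#
#   Case A, dict schema:
#     {"meta": {...},
#      "paradigms": [{"stems": ["walk", "talk"], "suffixes": ["", "ing", "s"]}, ...]}
#
#   Case B, legacy list-of-pairs schema:
#     [[["walk", "talk"], ["", "ing", "s"]], ...]
#
# Both are normalized to a list of (set_of_stems, set_of_suffixes) pairs.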
# ----------------------------
# Paradigm-based segmenter
# ----------------------------
class ParadigmFinderSegmenter:
def __init__(self, paradigms, lowercase=True, space_punct=True):
self.paradigms = paradigms
self.lowercase = lowercase
self.space_punct = space_punct
def _preprocess(self, text: str) -> str:
s = text
if self.lowercase:
s = s.lower()
if self.space_punct:
s = re.sub(r"([^\w\s'])", r" \1 ", s)
s = re.sub(r"\s+", " ", s).strip()
return s
    # Faithful port of the original ParadigmFinder segmentation logic: try stem + suffix
    # paths first, then fall back to the longest known suffix among the first `top_k` paradigms.
def _segment_word(self, word: str, fallback=True, top_k=20) -> List[str]:
def match_suffixes(suffixes, remainder):
for suffix in suffixes:
if isinstance(suffix, (tuple, list)):
base, nested = suffix
if remainder.startswith(base):
sub = remainder[len(base):]
nested_result = match_suffixes(nested, sub)
if nested_result is not None:
return [base] + nested_result
elif remainder == suffix:
return [suffix] if suffix else []
return None
for stems, suffixes in self.paradigms:
for stem in stems:
if word.startswith(stem):
remainder = word[len(stem):]
matched_suffix = match_suffixes(suffixes, remainder)
if matched_suffix is not None:
return [stem] + matched_suffix
if fallback:
candidates = self.paradigms[:top_k]
longest = ""
def collect_flat(sfx):
for s in sfx:
if isinstance(s, (tuple, list)):
yield s[0]
yield from collect_flat(s[1])
else:
yield s
for _, suffixes in candidates:
for suffix in collect_flat(suffixes):
if word.endswith(suffix) and len(suffix) > len(longest):
longest = suffix
if longest:
stem = word[:-len(longest)]
return [stem, longest]
return [word]
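    # Illustrative behavior (hypothetical paradigm data, not shipped in this file):
    #   with a paradigm ({"walk"}, {"", "ing", "s"}):
    #     _segment_word("walking") -> ["walk", "ing"]   (stem + matched suffix)
    #     _segment_word("walk")    -> ["walk"]          (empty suffix matches)
    #   for a word with an unknown stem, e.g. "jumping", the fallback splits on the
    #   longest suffix seen in the first `top_k` paradigms -> ["jump", "ing"],
    #   and otherwise returns the word unsegmented.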
def segment_with_alignment(self, raw_text: str) -> Tuple[str, List[Optional[int]]]:
"""
Preprocess + segment; return segmented text and a char map from segmented
text back to raw indices.
"""
# 1) Preprocess with alignment
s = raw_text.lower() if self.lowercase else raw_text
out, out_map = [], []
# insert spaces around punctuation (if enabled), tracking alignment
for i, ch in enumerate(s):
if self.space_punct and re.match(r"[^\w\s']", ch):
out.append(" "); out_map.append(None)
out.append(ch); out_map.append(i)
out.append(" "); out_map.append(None)
else:
out.append(ch); out_map.append(i)
# collapse/strip spaces
pre = []
pre2raw = []
prev_space = False
for ch, m in zip(out, out_map):
if ch.isspace():
if not prev_space:
pre.append(" "); pre2raw.append(None)
prev_space = True
else:
pre.append(ch); pre2raw.append(m); prev_space = False
if pre and pre[0] == " ": pre.pop(0); pre2raw.pop(0)
if pre and pre[-1] == " ": pre.pop(); pre2raw.pop()
norm = "".join(pre)
# 2) Segment by paradigms, preserving alignment
seg_chars, seg_map = [], []
i = 0
n = len(norm)
while i < n:
while i < n and norm[i].isspace():
i += 1
if i >= n: break
j = i
while j < n and not norm[j].isspace():
j += 1
token = norm[i:j]
token_map = pre2raw[i:j]
parts = self._segment_word(token, fallback=True)
# robust emission: consume all chars exactly once
pos = 0
for p_index, part in enumerate(parts):
L = len(part)
# clamp to remaining length
L = min(L, len(token) - pos)
if L <= 0: continue
for k in range(L):
seg_chars.append(token[pos + k])
seg_map.append(token_map[pos + k])
pos += L
if p_index < len(parts) - 1:
seg_chars.append(" "); seg_map.append(None)
# inter-token space
i = j
while i < n and norm[i].isspace():
i += 1
if i < n:
seg_chars.append(" "); seg_map.append(None)
# final collapse (defensive)
final = []
final_map = []
prev_space = False
for ch, m in zip(seg_chars, seg_map):
if ch.isspace():
if not prev_space:
final.append(" "); final_map.append(None); prev_space = True
else:
final.append(ch); final_map.append(m); prev_space = False
if final and final[0] == " ": final.pop(0); final_map.pop(0)
if final and final[-1] == " ": final.pop(); final_map.pop()
return "".join(final), final_map
# ----------------------------
# Offset remapping helper
# ----------------------------
def remap_offsets_to_raw(offsets: List[Tuple[int,int]], pre2raw: List[Optional[int]]) -> List[Tuple[int,int]]:
mapped = []
L = len(pre2raw)
for s,e in offsets:
s = max(0, min(s, L)); e = max(0, min(e, L))
rs = re_ = None
t = s
while t < e and rs is None:
if pre2raw[t] is not None: rs = pre2raw[t]
t += 1
t = e - 1
while t >= s and re_ is None:
if pre2raw[t] is not None: re_ = pre2raw[t] + 1
t -= 1
mapped.append((rs if rs is not None else 0, re_ if re_ is not None else 0))
return mapped
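# Continuing the worked example above: if the fast tokenizer emits offsets (0, 4),
# (5, 8) and (9, 10) over the segmented text "walk ing !", remapping through the
# char map [0, 1, 2, 3, None, 4, 5, 6, None, 7] recovers spans over the raw "Walking!":
#   remap_offsets_to_raw([(0, 4), (5, 8), (9, 10)], char_map) -> [(0, 4), (4, 7), (7, 8)]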
def _coerce_to_str(x):
# common cases first
if isinstance(x, str):
return x
if isinstance(x, dict):
for key in ("text", "sentence", "input", "prompt"):
if key in x and isinstance(x[key], str):
return x[key]
# fallback: join any stringy values
vals = [v for v in x.values() if isinstance(v, str)]
if vals:
return " ".join(vals)
return str(x)
if isinstance(x, (list, tuple)):
# prefer first/last string element if present
for pick in (0, -1):
try:
v = x[pick]
if isinstance(v, str):
return v
except Exception:
pass
# else join all string elements
parts = [v for v in x if isinstance(v, str)]
if parts:
return " ".join(parts)
return str(x)
# final fallback
return str(x)
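# Illustrative coercions (keys and values are hypothetical):
#   _coerce_to_str("hello")                    -> "hello"
#   _coerce_to_str({"text": "hello", "id": 3}) -> "hello"   (preferred key)
#   _coerce_to_str(["hello", "world"])         -> "hello"   (first string element)
#   _coerce_to_str(42)                         -> "42"      (str() fallback)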
# ----------------------------
# Public wrapper
# ----------------------------
class ParadigmTokenizerWrapper(PreTrainedTokenizerFast):
slow_tokenizer_class = None
def __init__(self, *args, **kwargs):
name_or_path = kwargs.get("name_or_path", None)
if name_or_path is None and len(args) > 0 and isinstance(args[0], str):
name_or_path = args[0]
if "tokenizer_file" not in kwargs and "tokenizer_object" not in kwargs and name_or_path is not None:
tf = os.path.join(name_or_path, "tokenizer.json")
if not os.path.isfile(tf):
raise FileNotFoundError(f"Expected tokenizer.json at {tf}")
kwargs["tokenizer_file"] = tf
super().__init__(*args, **kwargs)
        repo_id_or_path = kwargs.get("name_or_path", getattr(self, "name_or_path", None)) \
            or os.path.dirname(kwargs.get("tokenizer_file", "")) or "."
revision = kwargs.get("revision", None)
cfg = {"lowercase": True, "space_punct": True}
ppath = _get_repo_file(repo_id_or_path, "paradigms.json", revision)
self.paradigms, self.paradigms_meta = _load_paradigms_any(ppath)
        cpath = _get_repo_file(repo_id_or_path, "preprocess_config.json", revision)
        # preprocess_config.json is optional; keep the defaults above if no local file was found
        if os.path.exists(cpath):
            with open(cpath, "r", encoding="utf-8") as f:
                cfg.update(json.load(f))
self.segmenter = ParadigmFinderSegmenter(
paradigms=self.paradigms,
lowercase=cfg.get("lowercase", True),
space_punct=cfg.get("space_punct", True),
)
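    # Note: besides tokenizer.json, the wrapper expects paradigms.json next to it, and
    # optionally preprocess_config.json holding {"lowercase": ..., "space_punct": ...}
    # overrides, either in a local folder or in the Hub repo it was loaded from.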
def __call__(self, text, **kwargs):
# 1) fast path: already a plain string
if isinstance(text, str):
seg, _ = self.segmenter.segment_with_alignment(text)
return super().__call__(seg, **kwargs)
# 2) dicts: coerce to a single string (don't iterate keys!)
if isinstance(text, dict):
s = _coerce_to_str(text)
seg, _ = self.segmenter.segment_with_alignment(s)
return super().__call__(seg, **kwargs)
# 3) sequences (list/tuple/etc.): coerce each element to a string
try:
items = list(text)
except TypeError:
s = _coerce_to_str(text)
seg, _ = self.segmenter.segment_with_alignment(s)
return super().__call__(seg, **kwargs)
segs = []
for t in items:
s = _coerce_to_str(t)
seg, _ = self.segmenter.segment_with_alignment(s)
segs.append(seg)
return super().__call__(segs, **kwargs)
def tokenize(self, text, **kwargs):
if isinstance(text, str):
            seg, _ = self.segmenter.segment_with_alignment(text)
return super().tokenize(seg, **kwargs)
if isinstance(text, dict):
s = _coerce_to_str(text)
seg, _ = self.segmenter.segment_with_alignment(s)
return super().tokenize(seg, **kwargs)
try:
items = list(text)
except TypeError:
s = _coerce_to_str(text)
seg, _ = self.segmenter.segment_with_alignment(s)
return super().tokenize(seg, **kwargs)
out = []
for t in items:
s = _coerce_to_str(t)
seg, _ = self.segmenter.segment_with_alignment(s)
out.extend(super().tokenize(seg, **kwargs))
return out
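

# ----------------------------
# Usage sketch
# ----------------------------
# Minimal smoke test, assuming a local folder (the path below is illustrative) that
# contains tokenizer.json, paradigms.json and, optionally, preprocess_config.json.
# If the Hub repo's tokenizer_config.json maps AutoTokenizer to this class, loading with
# AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True) should yield the same wrapper.
if __name__ == "__main__":
    tok = ParadigmTokenizerWrapper(name_or_path=".")  # hypothetical folder with the files above
    sample = "Walking home!"
    print(tok.tokenize(sample))        # paradigm-segmented subword tokens
    print(tok(sample)["input_ids"])    # ids from the underlying fast tokenizer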