# tokenizer.py
# Wrapper for ParadigmFinder segmentation + portable HF tokenizer

from typing import List, Tuple, Optional
import os, json, re

from transformers import PreTrainedTokenizerFast
from huggingface_hub import hf_hub_download


def _get_repo_file(repo_id_or_path: str, filename: str, revision: Optional[str] = None) -> str:
    """
    If `repo_id_or_path` is a local folder and the file exists there, return
    its path. Otherwise download it from the Hub and return the cached path.
    """
    local = os.path.join(repo_id_or_path, filename)
    if os.path.exists(local):
        return local
    return hf_hub_download(repo_id=repo_id_or_path, filename=filename, revision=revision)


def _deserialize_suffixes_from_json(sfx_list):
    """Turn JSON suffix entries into the in-memory form used by the segmenter."""
    out = set()
    for item in sfx_list:
        if isinstance(item, list):  # JSON nested pair: [base, nested_list]
            base, nested = item
            out.add((base, frozenset(nested)))
        else:
            out.add(item)  # plain string like "", "ing", "s"
    return out


def _load_paradigms_any(path):
    """Load paradigms.json, accepting both the new and the legacy schema."""
    with open(path, "r", encoding="utf-8") as f:
        payload = json.load(f)

    # Case A: new schema with top-level dict {"paradigms": [...]}
    if isinstance(payload, dict) and "paradigms" in payload:
        paradigms = []
        for p in payload["paradigms"]:
            stems = set(p["stems"])
            suffixes = _deserialize_suffixes_from_json(p["suffixes"])
            paradigms.append((stems, suffixes))
        meta = payload.get("meta", {})
        return paradigms, meta

    # Case B: older "list of pairs" JSON [[stems, suffixes], ...]
    if isinstance(payload, list) and payload and isinstance(payload[0], list):
        paradigms = []
        for stems, suffixes in payload:
            stems = set(stems)
            # suffixes may be ["", ["er", ["", "s"]], "ing"] or already plain strings
            norm = _deserialize_suffixes_from_json(suffixes)
            paradigms.append((stems, norm))
        return paradigms, {}

    # Case C: already a Python-native structure (rare when not loading JSON)
    if isinstance(payload, list) and payload and isinstance(payload[0], (list, tuple)) and len(payload[0]) == 2:
        return payload, {}

    raise ValueError("Unrecognized paradigms.json format")
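
# Illustrative sketch (hypothetical data, not part of the public API): the
# normalized in-memory form _load_paradigms_any produces. A schema-A entry
# {"stems": ["walk", "talk"], "suffixes": ["", "ing", ["er", ["", "s"]]]}
# loads to the pair below, where the nested ("er", frozenset({"", "s"}))
# entry means "er" may itself be followed by "" or "s" (walker, walkers).
def _example_paradigms():
    return [({"walk", "talk"}, {"", "ing", ("er", frozenset({"", "s"}))})]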
# ----------------------------
# Paradigm-based segmenter
# ----------------------------
class ParadigmFinderSegmenter:
    def __init__(self, paradigms, lowercase=True, space_punct=True):
        self.paradigms = paradigms
        self.lowercase = lowercase
        self.space_punct = space_punct

    def _preprocess(self, text: str) -> str:
        s = text
        if self.lowercase:
            s = s.lower()
        if self.space_punct:
            s = re.sub(r"([^\w\s'])", r" \1 ", s)
        s = re.sub(r"\s+", " ", s).strip()
        return s

    # faithful to the original segmentation logic
    def _segment_word(self, word: str, fallback=True, top_k=20) -> List[str]:
        def match_suffixes(suffixes, remainder):
            for suffix in suffixes:
                if isinstance(suffix, (tuple, list)):
                    base, nested = suffix
                    if remainder.startswith(base):
                        sub = remainder[len(base):]
                        nested_result = match_suffixes(nested, sub)
                        if nested_result is not None:
                            return [base] + nested_result
                elif remainder == suffix:
                    return [suffix] if suffix else []
            return None

        for stems, suffixes in self.paradigms:
            for stem in stems:
                if word.startswith(stem):
                    remainder = word[len(stem):]
                    matched_suffix = match_suffixes(suffixes, remainder)
                    if matched_suffix is not None:
                        return [stem] + matched_suffix

        if fallback:
            candidates = self.paradigms[:top_k]
            longest = ""

            def collect_flat(sfx):
                for s in sfx:
                    if isinstance(s, (tuple, list)):
                        yield s[0]
                        yield from collect_flat(s[1])
                    else:
                        yield s

            for _, suffixes in candidates:
                for suffix in collect_flat(suffixes):
                    # require a proper suffix (shorter than the word) so the
                    # fallback never emits an empty stem
                    if word.endswith(suffix) and len(longest) < len(suffix) < len(word):
                        longest = suffix
            if longest:
                stem = word[:-len(longest)]
                return [stem, longest]
        return [word]

    def segment_with_alignment(self, raw_text: str) -> Tuple[str, List[Optional[int]]]:
        """
        Preprocess + segment; return the segmented text and a char map from
        segmented-text indices back to raw-text indices.
        """
        # 1) Preprocess with alignment
        # NOTE: str.lower() can change string length for a few Unicode chars
        # (e.g. "İ"); the alignment below assumes it does not.
        s = raw_text.lower() if self.lowercase else raw_text
        out, out_map = [], []
        # insert spaces around punctuation (if enabled), tracking alignment
        for i, ch in enumerate(s):
            if self.space_punct and re.match(r"[^\w\s']", ch):
                out.append(" "); out_map.append(None)
                out.append(ch); out_map.append(i)
                out.append(" "); out_map.append(None)
            else:
                out.append(ch); out_map.append(i)

        # collapse/strip spaces
        pre = []
        pre2raw = []
        prev_space = False
        for ch, m in zip(out, out_map):
            if ch.isspace():
                if not prev_space:
                    pre.append(" "); pre2raw.append(None)
                prev_space = True
            else:
                pre.append(ch); pre2raw.append(m); prev_space = False
        if pre and pre[0] == " ":
            pre.pop(0); pre2raw.pop(0)
        if pre and pre[-1] == " ":
            pre.pop(); pre2raw.pop()
        norm = "".join(pre)

        # 2) Segment by paradigms, preserving alignment
        seg_chars, seg_map = [], []
        i = 0
        n = len(norm)
        while i < n:
            while i < n and norm[i].isspace():
                i += 1
            if i >= n:
                break
            j = i
            while j < n and not norm[j].isspace():
                j += 1
            token = norm[i:j]
            token_map = pre2raw[i:j]
            parts = self._segment_word(token, fallback=True)

            # robust emission: consume all chars exactly once
            pos = 0
            for p_index, part in enumerate(parts):
                L = len(part)
                # clamp to remaining length
                L = min(L, len(token) - pos)
                if L <= 0:
                    continue
                for k in range(L):
                    seg_chars.append(token[pos + k])
                    seg_map.append(token_map[pos + k])
                pos += L
                if p_index < len(parts) - 1:
                    seg_chars.append(" "); seg_map.append(None)

            # inter-token space
            i = j
            while i < n and norm[i].isspace():
                i += 1
            if i < n:
                seg_chars.append(" "); seg_map.append(None)

        # final collapse (defensive)
        final = []
        final_map = []
        prev_space = False
        for ch, m in zip(seg_chars, seg_map):
            if ch.isspace():
                if not prev_space:
                    final.append(" "); final_map.append(None); prev_space = True
            else:
                final.append(ch); final_map.append(m); prev_space = False
        if final and final[0] == " ":
            final.pop(0); final_map.pop(0)
        if final and final[-1] == " ":
            final.pop(); final_map.pop()

        return "".join(final), final_map


# ----------------------------
# Offset remapping helper
# ----------------------------
def remap_offsets_to_raw(offsets: List[Tuple[int, int]], pre2raw: List[Optional[int]]) -> List[Tuple[int, int]]:
    """Map (start, end) offsets over the segmented text back to raw-text offsets."""
    mapped = []
    L = len(pre2raw)
    for s, e in offsets:
        s = max(0, min(s, L)); e = max(0, min(e, L))
        rs = re_ = None
        t = s
        while t < e and rs is None:
            if pre2raw[t] is not None:
                rs = pre2raw[t]
            t += 1
        t = e - 1
        while t >= s and re_ is None:
            if pre2raw[t] is not None:
                re_ = pre2raw[t] + 1
            t -= 1
        mapped.append((rs if rs is not None else 0, re_ if re_ is not None else 0))
    return mapped


def _coerce_to_str(x):
    """Best-effort coercion of common batch/record shapes to a single string."""
    # common cases first
    if isinstance(x, str):
        return x
    if isinstance(x, dict):
        for key in ("text", "sentence", "input", "prompt"):
            if key in x and isinstance(x[key], str):
                return x[key]
        # fallback: join any stringy values
        vals = [v for v in x.values() if isinstance(v, str)]
        if vals:
            return " ".join(vals)
        return str(x)
    if isinstance(x, (list, tuple)):
        # prefer first/last string element if present
        for pick in (0, -1):
            try:
                v = x[pick]
                if isinstance(v, str):
                    return v
            except Exception:
                pass
        # else join all string elements
        parts = [v for v in x if isinstance(v, str)]
        if parts:
            return " ".join(parts)
        return str(x)
    # final fallback
    return str(x)
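
# Minimal usage sketch (illustrative, not part of the public API): recover
# raw-text offsets for tokens produced over the segmented text. `tok` is any
# HF fast tokenizer used on segmented text; the names here are hypothetical.
def _example_offset_roundtrip(segmenter: ParadigmFinderSegmenter,
                              tok: PreTrainedTokenizerFast,
                              raw_text: str) -> List[Tuple[int, int]]:
    segmented, seg2raw = segmenter.segment_with_alignment(raw_text)
    enc = tok(segmented, return_offsets_mapping=True, add_special_tokens=False)
    return remap_offsets_to_raw(enc["offset_mapping"], seg2raw)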
# ----------------------------
# Public wrapper
# ----------------------------
class ParadigmTokenizerWrapper(PreTrainedTokenizerFast):
    slow_tokenizer_class = None

    def __init__(self, *args, **kwargs):
        name_or_path = kwargs.get("name_or_path", None)
        if name_or_path is None and len(args) > 0 and isinstance(args[0], str):
            name_or_path = args[0]
        if "tokenizer_file" not in kwargs and "tokenizer_object" not in kwargs and name_or_path is not None:
            tf = os.path.join(name_or_path, "tokenizer.json")
            if not os.path.isfile(tf):
                raise FileNotFoundError(f"Expected tokenizer.json at {tf}")
            kwargs["tokenizer_file"] = tf
        super().__init__(*args, **kwargs)

        repo_id_or_path = kwargs.get("name_or_path", getattr(self, "name_or_path", None)) \
            or os.path.dirname(getattr(self, "tokenizer_file", "")) or "."
        revision = kwargs.get("revision", None)

        cfg = {"lowercase": True, "space_punct": True}
        ppath = _get_repo_file(repo_id_or_path, "paradigms.json", revision)
        self.paradigms, self.paradigms_meta = _load_paradigms_any(ppath)

        # preprocess_config.json is optional: fall back to the defaults above
        # if it is missing locally and on the Hub
        try:
            cpath = _get_repo_file(repo_id_or_path, "preprocess_config.json", revision)
            with open(cpath, "r", encoding="utf-8") as f:
                cfg.update(json.load(f))
        except Exception:
            pass

        self.segmenter = ParadigmFinderSegmenter(
            paradigms=self.paradigms,
            lowercase=cfg.get("lowercase", True),
            space_punct=cfg.get("space_punct", True),
        )

    def __call__(self, text, **kwargs):
        # 1) fast path: already a plain string
        if isinstance(text, str):
            seg, _ = self.segmenter.segment_with_alignment(text)
            return super().__call__(seg, **kwargs)

        # 2) dicts: coerce to a single string (don't iterate keys!)
        if isinstance(text, dict):
            s = _coerce_to_str(text)
            seg, _ = self.segmenter.segment_with_alignment(s)
            return super().__call__(seg, **kwargs)

        # 3) sequences (list/tuple/etc.): coerce each element to a string
        try:
            items = list(text)
        except TypeError:
            s = _coerce_to_str(text)
            seg, _ = self.segmenter.segment_with_alignment(s)
            return super().__call__(seg, **kwargs)

        segs = []
        for t in items:
            s = _coerce_to_str(t)
            seg, _ = self.segmenter.segment_with_alignment(s)
            segs.append(seg)
        return super().__call__(segs, **kwargs)

    def tokenize(self, text, **kwargs):
        if isinstance(text, str):
            seg, _ = self.segmenter.segment_with_alignment(text)
            return super().tokenize(seg, **kwargs)
        if isinstance(text, dict):
            s = _coerce_to_str(text)
            seg, _ = self.segmenter.segment_with_alignment(s)
            return super().tokenize(seg, **kwargs)
        try:
            items = list(text)
        except TypeError:
            s = _coerce_to_str(text)
            seg, _ = self.segmenter.segment_with_alignment(s)
            return super().tokenize(seg, **kwargs)
        out = []
        for t in items:
            s = _coerce_to_str(t)
            seg, _ = self.segmenter.segment_with_alignment(s)
            out.extend(super().tokenize(seg, **kwargs))
        return out
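
# Usage sketch (hypothetical path): point the wrapper at a local folder or
# Hub repo containing tokenizer.json, paradigms.json and, optionally,
# preprocess_config.json, then use it like any HF tokenizer.
if __name__ == "__main__":
    tok = ParadigmTokenizerWrapper.from_pretrained("path/to/paradigm-tokenizer")
    enc = tok("The walkers were talking.")
    print(tok.convert_ids_to_tokens(enc["input_ids"]))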