# tokenizer.py
import json
import os
from typing import Any, Dict, List, Optional, Tuple, Union

from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizerFast

from .syllabic_pretokenizer import (
    Preprocessor,
    preprocess_and_segment_with_alignment,
    remap_offsets_to_raw,
)


def _get_repo_file(repo_id_or_path: str, filename: str, revision: Optional[str] = None) -> str:
    """Resolve `filename` from a local directory, falling back to the Hub."""
    local = os.path.join(repo_id_or_path, filename)
    if os.path.exists(local):
        return local
    return hf_hub_download(repo_id=repo_id_or_path, filename=filename, revision=revision)


def _coerce_to_str(x) -> str:
    """Best-effort extraction of a text string from common input shapes."""
    # Common cases first.
    if isinstance(x, str):
        return x
    if isinstance(x, dict):
        for key in ("text", "sentence", "input", "prompt"):
            if key in x and isinstance(x[key], str):
                return x[key]
        # Fallback: join any string values.
        vals = [v for v in x.values() if isinstance(v, str)]
        if vals:
            return " ".join(vals)
        return str(x)
    if isinstance(x, (list, tuple)):
        # Prefer the first/last string element if present.
        for pick in (0, -1):
            try:
                v = x[pick]
                if isinstance(v, str):
                    return v
            except Exception:
                pass
        # Otherwise join all string elements.
        parts = [v for v in x if isinstance(v, str)]
        if parts:
            return " ".join(parts)
        return str(x)
    # Final fallback.
    return str(x)


class SyllabicTokenizerWrapper(PreTrainedTokenizerFast):
    """Fast tokenizer that runs syllabic pre-segmentation before encoding and
    can remap offsets back onto the raw, unsegmented input."""

    slow_tokenizer_class = None

    def __init__(self, *args, **kwargs):
        name_or_path = kwargs.get("name_or_path") or (
            args[0] if args and isinstance(args[0], str) else None
        )
        if "tokenizer_file" not in kwargs and name_or_path:
            tf = os.path.join(name_or_path, "tokenizer.json")
            if not os.path.isfile(tf):
                raise FileNotFoundError(f"Expected tokenizer.json at {tf}")
            kwargs["tokenizer_file"] = tf
        super().__init__(*args, **kwargs)
        hf_dir = (
            kwargs.get("name_or_path", getattr(self, "name_or_path", None))
            or os.path.dirname(getattr(self, "tokenizer_file", ""))
            or "."
        )
        revision = kwargs.get("revision", None)
        cfg_path = _get_repo_file(hf_dir, "preprocess_config.json", revision)
        if not os.path.exists(cfg_path):
            raise FileNotFoundError(f"Missing preprocess_config.json in {hf_dir}.")
        with open(cfg_path, "r", encoding="utf-8") as f:
            self.pre_cfg = json.load(f)
        self.preprocessor = Preprocessor(**self.pre_cfg)

    def _segment_one(self, text: str) -> Tuple[str, List[Optional[int]]]:
        """Segment one raw string; returns (segmented_text, raw-offset map)."""
        return preprocess_and_segment_with_alignment(text, self.preprocessor)

    def __call__(self, text: Union[str, List[str]], **kwargs) -> Dict[str, Any]:
        # NOTE: `return_offsets_mapping` is the correct HF kwarg spelling.
        want_offsets = kwargs.get("return_offsets_mapping", False)
        if isinstance(text, str):
            seg, seg_map = self._segment_one(text)
            enc = super().__call__(seg, **kwargs)
            if want_offsets and "offset_mapping" in enc:
                enc["raw_offset_mapping"] = remap_offsets_to_raw(enc["offset_mapping"], seg_map)
            return enc
        if isinstance(text, dict):
            s = _coerce_to_str(text)
            seg, seg_map = self._segment_one(s)
            enc = super().__call__(seg, **kwargs)
            if want_offsets and "offset_mapping" in enc:
                enc["raw_offset_mapping"] = remap_offsets_to_raw(enc["offset_mapping"], seg_map)
            return enc
        # Sequences / other iterables.
        try:
            items = list(text)
        except TypeError:
            s = _coerce_to_str(text)
            seg, seg_map = self._segment_one(s)
            enc = super().__call__(seg, **kwargs)
            if want_offsets and "offset_mapping" in enc:
                enc["raw_offset_mapping"] = remap_offsets_to_raw(enc["offset_mapping"], seg_map)
            return enc
        segs, maps = [], []
        for t in items:
            s = _coerce_to_str(t)
            seg, seg_map = self._segment_one(s)
            segs.append(seg)
            maps.append(seg_map)
        enc = super().__call__(segs, **kwargs)
        if want_offsets and "offset_mapping" in enc:
            enc["raw_offset_mapping"] = [
                remap_offsets_to_raw(off, m)
                for off, m in zip(enc["offset_mapping"], maps)
            ]
        return enc

    def tokenize(self, text: Union[str, List[str]], **kwargs):
        if isinstance(text, str):
            seg, _ = self._segment_one(text)
            return super().tokenize(seg, **kwargs)
        if isinstance(text, dict):
            s = _coerce_to_str(text)
            seg, _ = self._segment_one(s)
            return super().tokenize(seg, **kwargs)
        try:
            items = list(text)
        except TypeError:
            s = _coerce_to_str(text)
            seg, _ = self._segment_one(s)
            return super().tokenize(seg, **kwargs)
        out: List[str] = []
        for t in items:
            s = _coerce_to_str(t)
            seg, _ = self._segment_one(s)
            out.extend(super().tokenize(seg, **kwargs))
        return out
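

if __name__ == "__main__":
    # Usage sketch, not part of the library surface. "your-org/syllabic-tokenizer"
    # is a placeholder id: any local directory or Hub repo that ships both
    # tokenizer.json and preprocess_config.json should behave the same way.
    tok = SyllabicTokenizerWrapper.from_pretrained("your-org/syllabic-tokenizer")
    enc = tok("An example sentence.", return_offsets_mapping=True)
    print(enc["input_ids"])
    print(enc["offset_mapping"])      # offsets into the segmented text
    print(enc["raw_offset_mapping"])  # offsets remapped onto the raw input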