|
|
|
|
|
from typing import List, Tuple, Optional, Union, Dict, Any |
|
|
import os, json |
|
|
from transformers import PreTrainedTokenizerFast |
|
|
from huggingface_hub import hf_hub_download |
|
|
from .syllabic_pretokenizer import ( |
|
|
Preprocessor, |
|
|
preprocess_and_segment_with_alignment, |
|
|
remap_offsets_to_raw, |
|
|
) |
|
|
|
|
|
def _get_repo_file(repo_id_or_path: str, filename: str, revision: Optional[str] = None) -> str: |
|
|
|
|
|
local = os.path.join(repo_id_or_path, filename) |
|
|
if os.path.exists(local): |
|
|
return local |
|
|
return hf_hub_download(repo_id=repo_id_or_path, filename=filename, revision=revision) |
|
|
|
|
|
def _coerce_to_str(x): |
|
|
|
|
|
if isinstance(x, str): |
|
|
return x |
|
|
if isinstance(x, dict): |
|
|
for key in ("text", "sentence", "input", "prompt"): |
|
|
if key in x and isinstance(x[key], str): |
|
|
return x[key] |
|
|
|
|
|
vals = [v for v in x.values() if isinstance(v, str)] |
|
|
if vals: |
|
|
return " ".join(vals) |
|
|
return str(x) |
|
|
if isinstance(x, (list, tuple)): |
|
|
|
|
|
for pick in (0, -1): |
|
|
try: |
|
|
v = x[pick] |
|
|
if isinstance(v, str): |
|
|
return v |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
parts = [v for v in x if isinstance(v, str)] |
|
|
if parts: |
|
|
return " ".join(parts) |
|
|
return str(x) |
|
|
|
|
|
return str(x) |
|
|
|
|
|
class SyllabicTokenizerWrapper(PreTrainedTokenizerFast):
    """Fast tokenizer that syllabically pre-segments text before encoding.

    Each input is coerced to a string, run through the configured
    :class:`Preprocessor` (loaded from the repo's ``preprocess_config.json``),
    and the segmented text is handed to the underlying fast tokenizer.
    When ``return_offsets_mapping=True`` is passed, the offsets computed on
    the segmented text are additionally remapped back onto the raw input and
    exposed under the extra key ``raw_offset_mapping``.
    """

    # No slow (pure-Python) tokenizer counterpart exists for this wrapper.
    slow_tokenizer_class = None

    def __init__(self, *args, **kwargs):
        # Resolve where we are loading from: explicit kwarg, or a positional
        # path/repo-id string.
        name_or_path = kwargs.get("name_or_path") or (args[0] if args and isinstance(args[0], str) else None)
        if "tokenizer_file" not in kwargs and name_or_path:
            # NOTE(review): this assumes name_or_path is a *local* directory;
            # loading by Hub repo id without an explicit tokenizer_file kwarg
            # will raise here — confirm that is the intended contract.
            tf = os.path.join(name_or_path, "tokenizer.json")
            if not os.path.isfile(tf):
                raise FileNotFoundError(f"Expected tokenizer.json at {tf}")
            kwargs["tokenizer_file"] = tf
        super().__init__(*args, **kwargs)

        # Locate the companion preprocessing config next to the tokenizer
        # (local dir, or fetched from the Hub by _get_repo_file).
        hf_dir = kwargs.get("name_or_path", getattr(self, "name_or_path", None)) \
            or os.path.dirname(getattr(self, "tokenizer_file", "")) or "."
        revision = kwargs.get("revision", None)

        cfg_path = _get_repo_file(hf_dir, "preprocess_config.json", revision)
        if not os.path.exists(cfg_path):
            raise FileNotFoundError(f"Missing preprocess_config.json in {hf_dir}.")
        with open(cfg_path, "r", encoding="utf-8") as f:
            # Raw config dict, kept around for introspection/serialization.
            self.pre_cfg = json.load(f)

        # Preprocessor instance driving the syllabic segmentation.
        self.preprocessor = Preprocessor(**self.pre_cfg)

    def _segment_one(self, text: str) -> Tuple[str, List[Optional[int]]]:
        """Segment *text*; return (segmented_text, segmented→raw char map)."""
        return preprocess_and_segment_with_alignment(text, self.preprocessor)

    def _encode_one(self, item: Any, want_offsets: bool, **kwargs) -> Dict[str, Any]:
        """Coerce, segment and encode a single example, remapping offsets
        back to the raw text when they were requested."""
        seg, seg_map = self._segment_one(_coerce_to_str(item))
        enc = super().__call__(seg, **kwargs)
        if want_offsets and "offset_mapping" in enc:
            enc["raw_offset_mapping"] = remap_offsets_to_raw(enc["offset_mapping"], seg_map)
        return enc

    def __call__(self, text: Union[str, List[str]], **kwargs) -> Dict[str, Any]:
        """Encode a single string/mapping, or a batch of examples.

        Non-string, non-dict inputs are treated as a batch if iterable,
        otherwise coerced and encoded as a single example.
        """
        want_offsets = kwargs.get("return_offsets_mapping", False)

        # Single-example paths: plain string or mapping (dicts are iterable,
        # so this check must come before the batch probe below).
        if isinstance(text, (str, dict)):
            return self._encode_one(text, want_offsets, **kwargs)

        try:
            items = list(text)
        except TypeError:
            # Not iterable — best-effort single-example encoding.
            return self._encode_one(text, want_offsets, **kwargs)

        # Batch path: segment every item, encode together, remap per item.
        segs, maps = [], []
        for t in items:
            seg, seg_map = self._segment_one(_coerce_to_str(t))
            segs.append(seg)
            maps.append(seg_map)

        enc = super().__call__(segs, **kwargs)
        if want_offsets and "offset_mapping" in enc:
            enc["raw_offset_mapping"] = [
                remap_offsets_to_raw(off, m) for off, m in zip(enc["offset_mapping"], maps)
            ]
        return enc

    def tokenize(self, text: Union[str, List[str]], **kwargs):
        """Tokenize after segmentation.

        A batch input yields a single flat list of tokens (all items
        concatenated), matching the original behavior.
        """
        if isinstance(text, (str, dict)):
            seg, _ = self._segment_one(_coerce_to_str(text))
            return super().tokenize(seg, **kwargs)

        try:
            items = list(text)
        except TypeError:
            # Not iterable — coerce and tokenize as a single example.
            seg, _ = self._segment_one(_coerce_to_str(text))
            return super().tokenize(seg, **kwargs)

        out: List[str] = []
        for t in items:
            seg, _ = self._segment_one(_coerce_to_str(t))
            out.extend(super().tokenize(seg, **kwargs))
        return out
|
|
|