# tokenizer.py
from typing import List, Tuple, Optional, Union, Dict, Any
import os, json
from transformers import PreTrainedTokenizerFast
from huggingface_hub import hf_hub_download
from .syllabic_pretokenizer import (
    Preprocessor,
    preprocess_and_segment_with_alignment,
    remap_offsets_to_raw,
)


def _get_repo_file(repo_id_or_path: str, filename: str, revision: Optional[str] = None) -> str:
    local = os.path.join(repo_id_or_path, filename)
    if os.path.exists(local):
        return local
    return hf_hub_download(repo_id=repo_id_or_path, filename=filename, revision=revision)


def _coerce_to_str(x):
    # common cases first
    if isinstance(x, str):
        return x
    if isinstance(x, dict):
        for key in ("text", "sentence", "input", "prompt"):
            if key in x and isinstance(x[key], str):
                return x[key]
        # fallback: join any stringy values
        vals = [v for v in x.values() if isinstance(v, str)]
        if vals:
            return " ".join(vals)
        return str(x)
    if isinstance(x, (list, tuple)):
        # prefer first/last string element if present
        for pick in (0, -1):
            try:
                v = x[pick]
                if isinstance(v, str):
                    return v
            except Exception:
                pass
        # else join all string elements
        parts = [v for v in x if isinstance(v, str)]
        if parts:
            return " ".join(parts)
        return str(x)
    # final fallback
    return str(x)
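
# A quick illustration of the coercion rules above (hypothetical inputs):
#   _coerce_to_str("hello")                  -> "hello"
#   _coerce_to_str({"text": "hi", "id": 3})  -> "hi"
#   _coerce_to_str(["first", 2, "last"])     -> "first"
#   _coerce_to_str(42)                       -> "42"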


class SyllabicTokenizerWrapper(PreTrainedTokenizerFast):
    slow_tokenizer_class = None

    def __init__(self, *args, **kwargs):
        name_or_path = kwargs.get("name_or_path") or (args[0] if args and isinstance(args[0], str) else None)
        if "tokenizer_file" not in kwargs and name_or_path:
            tf = os.path.join(name_or_path, "tokenizer.json")
            if not os.path.isfile(tf):
                raise FileNotFoundError(f"Expected tokenizer.json at {tf}")
            kwargs["tokenizer_file"] = tf
        super().__init__(*args, **kwargs)
        hf_dir = kwargs.get("name_or_path", getattr(self, "name_or_path", None)) \
            or os.path.dirname(getattr(self, "tokenizer_file", "")) or "."
        revision = kwargs.get("revision", None)
        cfg_path = _get_repo_file(hf_dir, "preprocess_config.json", revision)
        if not os.path.exists(cfg_path):
            raise FileNotFoundError(f"Missing preprocess_config.json in {hf_dir}.")
        with open(cfg_path, "r", encoding="utf-8") as f:
            self.pre_cfg = json.load(f)
        self.preprocessor = Preprocessor(**self.pre_cfg)

    def _segment_one(self, text: str) -> Tuple[str, List[Optional[int]]]:
        return preprocess_and_segment_with_alignment(text, self.preprocessor)

    def __call__(self, text: Union[str, List[str]], **kwargs) -> Dict[str, Any]:
        # NOTE: correct HF kwarg spelling
        want_offsets = kwargs.get("return_offsets_mapping", False)
        if isinstance(text, str):
            seg, seg_map = self._segment_one(text)
            enc = super().__call__(seg, **kwargs)
            if want_offsets and "offset_mapping" in enc:
                enc["raw_offset_mapping"] = remap_offsets_to_raw(enc["offset_mapping"], seg_map)
            return enc
        if isinstance(text, dict):
            s = _coerce_to_str(text)
            seg, seg_map = self._segment_one(s)
            enc = super().__call__(seg, **kwargs)
            if want_offsets and "offset_mapping" in enc:
                enc["raw_offset_mapping"] = remap_offsets_to_raw(enc["offset_mapping"], seg_map)
            return enc
        # sequences / other iterables
        try:
            items = list(text)
        except TypeError:
            s = _coerce_to_str(text)
            seg, seg_map = self._segment_one(s)
            enc = super().__call__(seg, **kwargs)
            if want_offsets and "offset_mapping" in enc:
                enc["raw_offset_mapping"] = remap_offsets_to_raw(enc["offset_mapping"], seg_map)
            return enc
        segs, maps = [], []
        for t in items:
            s = _coerce_to_str(t)
            seg, seg_map = self._segment_one(s)
            segs.append(seg)
            maps.append(seg_map)
        enc = super().__call__(segs, **kwargs)
        if want_offsets and "offset_mapping" in enc:
            enc["raw_offset_mapping"] = [
                remap_offsets_to_raw(off, m) for off, m in zip(enc["offset_mapping"], maps)
            ]
        return enc

    def tokenize(self, text: Union[str, List[str]], **kwargs):
        if isinstance(text, str):
            seg, _ = self._segment_one(text)
            return super().tokenize(seg, **kwargs)
        if isinstance(text, dict):
            s = _coerce_to_str(text)
            seg, _ = self._segment_one(s)
            return super().tokenize(seg, **kwargs)
        try:
            items = list(text)
        except TypeError:
            s = _coerce_to_str(text)
            seg, _ = self._segment_one(s)
            return super().tokenize(seg, **kwargs)
        out: List[str] = []
        for t in items:
            s = _coerce_to_str(t)
            seg, _ = self._segment_one(s)
            out.extend(super().tokenize(seg, **kwargs))
        return out
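

# A minimal usage sketch, not part of this module's API. It assumes the repo's
# tokenizer_config.json maps AutoTokenizer to SyllabicTokenizerWrapper via
# auto_map; the repo id below is inferred from the repo name and may differ.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(
        "achille-fusco/gpt2-10M-syllitok-eng", trust_remote_code=True
    )
    enc = tok("An example sentence.", return_offsets_mapping=True)
    print(enc["input_ids"])
    # offsets into the segmented text, plus the same spans mapped back to the raw input
    print(enc["offset_mapping"])
    print(enc["raw_offset_mapping"])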