PyTorch
gpt2
gpt2-10M-syllitok-eng / tokenizer.py
achille-fusco's picture
Upload folder using huggingface_hub
064b576 verified
raw
history blame
3.58 kB
# tokenizer.py
from typing import List, Tuple, Optional, Union, Dict, Any
import os, json
from transformers import PreTrainedTokenizerFast
from .syllabic_pretokenizer import (
Preprocessor,
preprocess_and_segment_with_alignment,
remap_offsets_to_raw,
)
class SyllabicTokenizerWrapper(PreTrainedTokenizerFast):
    """
    A HF-compatible tokenizer that FIRST applies your syllabic segmentation,
    then delegates to the underlying fast tokenizer from tokenizer.json.
    Required files in the same directory:
    - tokenizer.json, tokenizer_config.json, special_tokens_map.json
    - preprocess_config.json (with the Preprocessor flags)
    """

    slow_tokenizer_class = None  # required by HF when no slow version exists

    def __init__(self, *args, **kwargs):
        """Load the fast tokenizer and the saved preprocessing config.

        Raises:
            FileNotFoundError: if tokenizer.json or preprocess_config.json
                cannot be located next to the model artifacts.
        """
        # Ensure we load the fast tokenizer directly (no slow->fast conversion).
        name_or_path = kwargs.get("name_or_path") or (
            args[0] if args and isinstance(args[0], str) else None
        )
        if "tokenizer_file" not in kwargs and name_or_path:
            tf = os.path.join(name_or_path, "tokenizer.json")
            if not os.path.isfile(tf):
                raise FileNotFoundError(f"Expected tokenizer.json at {tf}")
            kwargs["tokenizer_file"] = tf
        super().__init__(*args, **kwargs)

        # Resolve the directory where the artifacts live.
        hf_dir = kwargs.get("name_or_path", getattr(self, "name_or_path", None)) \
            or os.path.dirname(getattr(self, "tokenizer_file", "")) or "."

        # Load preprocessing flags saved during training; these configure the
        # syllabic Preprocessor so inference segmentation matches training.
        cfg_path = os.path.join(hf_dir, "preprocess_config.json")
        if not os.path.exists(cfg_path):
            raise FileNotFoundError(
                f"Missing preprocess_config.json in {hf_dir}. "
                f"Did you save it during tokenizer training?"
            )
        with open(cfg_path, "r", encoding="utf-8") as f:
            self.pre_cfg = json.load(f)
        self.preprocessor = Preprocessor(**self.pre_cfg)

    # --- core segmentation helpers ---
    def _segment_one(self, text: str) -> Tuple[str, List[Optional[int]]]:
        """Segment one raw string; returns (segmented_text, alignment_map)."""
        return preprocess_and_segment_with_alignment(text, self.preprocessor)

    # --- public API overrides ---
    def __call__(self, text: Union[str, List[str]], **kwargs) -> Dict[str, Any]:
        """
        Segments -> calls the fast tokenizer (super) with segmented text.

        Fix vs. original: `return_offset_mapping` used to be popped and
        silently dropped, so callers asking for offsets got none even though
        the alignment maps (and `remap_offsets_to_raw`) were available. We
        now forward the flag and remap the offsets — which super() computes
        against the *segmented* text — back onto the raw input.
        """
        want_offsets = kwargs.pop("return_offset_mapping", False)
        if want_offsets:
            kwargs["return_offset_mapping"] = True

        if isinstance(text, str):
            seg, seg_map = self._segment_one(text)
            enc = super().__call__(seg, **kwargs)
            if want_offsets and "offset_mapping" in enc:
                # NOTE(review): assumes remap_offsets_to_raw(offsets, map)
                # maps segmented-text offsets to raw-text offsets — confirm
                # against syllabic_pretokenizer.
                enc["offset_mapping"] = remap_offsets_to_raw(
                    enc["offset_mapping"], seg_map
                )
            return enc
        elif isinstance(text, (list, tuple)):
            segs: List[str] = []
            seg_maps: List[List[Optional[int]]] = []
            for t in text:
                seg, seg_map = self._segment_one(t)
                segs.append(seg)
                seg_maps.append(seg_map)  # kept (was discarded) for remapping
            enc = super().__call__(segs, **kwargs)
            if want_offsets and "offset_mapping" in enc:
                enc["offset_mapping"] = [
                    remap_offsets_to_raw(offsets, m)
                    for offsets, m in zip(enc["offset_mapping"], seg_maps)
                ]
            return enc
        else:
            raise TypeError("text must be str or List[str]")

    def tokenize(self, text: Union[str, List[str]], **kwargs):
        """
        Also intercept manual .tokenize() to ensure segmentation happens first.

        For a list input, the per-text token lists are concatenated into one
        flat list (original behavior, preserved).
        """
        if isinstance(text, str):
            seg, _ = self._segment_one(text)
            return super().tokenize(seg, **kwargs)
        elif isinstance(text, list):
            out: List[str] = []
            for t in text:
                seg, _ = self._segment_one(t)
                out.extend(super().tokenize(seg, **kwargs))
            return out
        else:
            raise TypeError("tokenize() expects str or List[str]")