from typing import List, Tuple, Optional, Union, Dict, Any

import os
import json
import re

from transformers import PreTrainedTokenizerFast
from huggingface_hub import hf_hub_download

def _get_repo_file(repo_id_or_path: str, filename: str, revision: Optional[str] = None) -> str:
    """
    If `repo_id_or_path` is a local folder and the file exists there, return its path.
    Otherwise download it from the Hub and return the cached path.
    """
    local = os.path.join(repo_id_or_path, filename)
    if os.path.exists(local):
        return local
    return hf_hub_download(repo_id=repo_id_or_path, filename=filename, revision=revision)

def _deserialize_suffixes_from_json(sfx_list):
    """Rebuild the suffix set from JSON: plain string suffixes plus (base, frozenset(nested)) pairs."""
    out = set()
    for item in sfx_list:
        if isinstance(item, list):
            # A (base, nested) pair is serialized as a two-element JSON list; nested
            # suffixes are deserialized recursively so deeper nesting stays hashable.
            base, nested = item
            out.add((base, frozenset(_deserialize_suffixes_from_json(nested))))
        else:
            out.add(item)
    return out

def _load_paradigms_any(path):
    """Load paradigms from JSON, accepting any of the formats below."""
    with open(path, "r", encoding="utf-8") as f:
        payload = json.load(f)

    # Format 1: {"paradigms": [{"stems": [...], "suffixes": [...]}, ...], "meta": {...}}
    if isinstance(payload, dict) and "paradigms" in payload:
        paradigms = []
        for p in payload["paradigms"]:
            stems = set(p["stems"])
            suffixes = _deserialize_suffixes_from_json(p["suffixes"])
            paradigms.append((stems, suffixes))
        meta = payload.get("meta", {})
        return paradigms, meta

    # Format 2: a bare list of [stems, suffixes] pairs.
    if isinstance(payload, list) and payload and isinstance(payload[0], list):
        paradigms = []
        for stems, suffixes in payload:
            stems = set(stems)
            norm = _deserialize_suffixes_from_json(suffixes)
            paradigms.append((stems, norm))
        return paradigms, {}

    # Format 3: already-normalized (stems, suffixes) pairs; returned unchanged.
    if isinstance(payload, list) and payload and isinstance(payload[0], (list, tuple)) and len(payload[0]) == 2:
        return payload, {}

    raise ValueError("Unrecognized paradigms.json format")

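# Illustrative shape of paradigms.json (example data, not shipped with the repo).
# The dict format accepted by _load_paradigms_any looks like:
#
#   {
#     "meta": {"notes": "example"},
#     "paradigms": [
#       {"stems": ["walk", "talk"],
#        "suffixes": ["", "s", "ed", ["ing", ["ly"]]]}
#     ]
#   }
#
# A plain-string suffix must match the whole remainder of a word; a ["base", [...]]
# pair matches the base followed by one of its nested suffixes.
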
class ParadigmFinderSegmenter:
    def __init__(self, paradigms, lowercase=True, space_punct=True):
        self.paradigms = paradigms
        self.lowercase = lowercase
        self.space_punct = space_punct

    def _preprocess(self, text: str) -> str:
        s = text
        if self.lowercase:
            s = s.lower()
        if self.space_punct:
            s = re.sub(r"([^\w\s'])", r" \1 ", s)
            s = re.sub(r"\s+", " ", s).strip()
        return s

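    # For example (with the default settings):
    #   _preprocess("Hello, World!") -> "hello , world !"
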
    def _segment_word(self, word: str, fallback=True, top_k=20) -> List[str]:
        def match_suffixes(suffixes, remainder):
            # Return the list of suffix pieces that exactly cover `remainder`, or None.
            for suffix in suffixes:
                if isinstance(suffix, (tuple, list)):
                    # (base, nested) pair: the base must be followed by one of the nested suffixes.
                    base, nested = suffix
                    if remainder.startswith(base):
                        sub = remainder[len(base):]
                        nested_result = match_suffixes(nested, sub)
                        if nested_result is not None:
                            return [base] + nested_result
                elif remainder == suffix:
                    return [suffix] if suffix else []
            return None

        # Exact match: a known stem followed by a suffix chain from the same paradigm.
        for stems, suffixes in self.paradigms:
            for stem in stems:
                if word.startswith(stem):
                    remainder = word[len(stem):]
                    matched_suffix = match_suffixes(suffixes, remainder)
                    if matched_suffix is not None:
                        return [stem] + matched_suffix

        # Fallback: strip the longest suffix seen in the first `top_k` paradigms.
        if fallback:
            candidates = self.paradigms[:top_k]
            longest = ""

            def collect_flat(sfx):
                for s in sfx:
                    if isinstance(s, (tuple, list)):
                        yield s[0]
                        yield from collect_flat(s[1])
                    else:
                        yield s

            for _, suffixes in candidates:
                for suffix in collect_flat(suffixes):
                    if word.endswith(suffix) and len(suffix) > len(longest):
                        longest = suffix
            if longest:
                stem = word[:-len(longest)]
                return [stem, longest]

        return [word]

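    # Toy example (hypothetical paradigm): with
    #   paradigms = [({"walk", "talk"}, {"", "s", "ed", ("ing", frozenset({"ly"}))})]
    # _segment_word("walkingly") returns ["walk", "ing", "ly"]; a word with no matching
    # stem falls back to splitting off the longest known suffix, or is returned whole.
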
    def segment_with_alignment(self, raw_text: str) -> Tuple[str, List[Optional[int]]]:
        """
        Preprocess + segment; return segmented text and a char map from segmented
        text back to raw indices.
        """
        s = raw_text.lower() if self.lowercase else raw_text
        out, out_map = [], []

        # Pad punctuation with spaces, tracking each character's raw index (None for inserted spaces).
        for i, ch in enumerate(s):
            if self.space_punct and re.match(r"[^\w\s']", ch):
                out.append(" "); out_map.append(None)
                out.append(ch); out_map.append(i)
                out.append(" "); out_map.append(None)
            else:
                out.append(ch); out_map.append(i)

        # Collapse runs of whitespace and trim, keeping the char map aligned.
        pre = []
        pre2raw = []
        prev_space = False
        for ch, m in zip(out, out_map):
            if ch.isspace():
                if not prev_space:
                    pre.append(" "); pre2raw.append(None)
                prev_space = True
            else:
                pre.append(ch); pre2raw.append(m); prev_space = False
        if pre and pre[0] == " ": pre.pop(0); pre2raw.pop(0)
        if pre and pre[-1] == " ": pre.pop(); pre2raw.pop()
        norm = "".join(pre)

        # Segment each whitespace-delimited token and splice the pieces back with spaces.
        seg_chars, seg_map = [], []
        i = 0
        n = len(norm)
        while i < n:
            while i < n and norm[i].isspace():
                i += 1
            if i >= n: break
            j = i
            while j < n and not norm[j].isspace():
                j += 1
            token = norm[i:j]
            token_map = pre2raw[i:j]
            parts = self._segment_word(token, fallback=True)

            pos = 0
            for p_index, part in enumerate(parts):
                L = len(part)
                L = min(L, len(token) - pos)  # never run past the end of the token
                if L <= 0: continue
                for k in range(L):
                    seg_chars.append(token[pos + k])
                    seg_map.append(token_map[pos + k])
                pos += L
                if p_index < len(parts) - 1:
                    seg_chars.append(" "); seg_map.append(None)

            i = j
            while i < n and norm[i].isspace():
                i += 1
            if i < n:
                seg_chars.append(" "); seg_map.append(None)

        # Final whitespace cleanup, again keeping the char map aligned.
        final = []
        final_map = []
        prev_space = False
        for ch, m in zip(seg_chars, seg_map):
            if ch.isspace():
                if not prev_space:
                    final.append(" "); final_map.append(None); prev_space = True
            else:
                final.append(ch); final_map.append(m); prev_space = False
        if final and final[0] == " ": final.pop(0); final_map.pop(0)
        if final and final[-1] == " ": final.pop(); final_map.pop()

        return "".join(final), final_map

def remap_offsets_to_raw(offsets: List[Tuple[int, int]], pre2raw: List[Optional[int]]) -> List[Tuple[int, int]]:
    """Map (start, end) offsets computed on the segmented text back to raw-text offsets via the char map."""
    mapped = []
    L = len(pre2raw)
    for s, e in offsets:
        s = max(0, min(s, L)); e = max(0, min(e, L))
        rs = re_ = None
        # First raw index covered by the span...
        t = s
        while t < e and rs is None:
            if pre2raw[t] is not None: rs = pre2raw[t]
            t += 1
        # ...and one past the last raw index covered.
        t = e - 1
        while t >= s and re_ is None:
            if pre2raw[t] is not None: re_ = pre2raw[t] + 1
            t -= 1
        mapped.append((rs if rs is not None else 0, re_ if re_ is not None else 0))
    return mapped

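# Sketch of how the char map and token offsets fit together (illustrative only;
# `segmenter`, `tokenizer` and `raw_text` are assumed to exist):
#   seg_text, char_map = segmenter.segment_with_alignment(raw_text)
#   enc = tokenizer(seg_text, return_offsets_mapping=True, add_special_tokens=False)
#   raw_spans = remap_offsets_to_raw(enc["offset_mapping"], char_map)
#   # raw_spans[i] is the (start, end) span in raw_text covered by token i.
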
def _coerce_to_str(x):
    """Best-effort conversion of dataset items (str / dict / list / tuple) to a single string."""
    if isinstance(x, str):
        return x
    if isinstance(x, dict):
        # Prefer common text-bearing keys, then fall back to any string values.
        for key in ("text", "sentence", "input", "prompt"):
            if key in x and isinstance(x[key], str):
                return x[key]
        vals = [v for v in x.values() if isinstance(v, str)]
        if vals:
            return " ".join(vals)
        return str(x)
    if isinstance(x, (list, tuple)):
        # Prefer the first or last element if it is a string, then any string elements.
        for pick in (0, -1):
            try:
                v = x[pick]
                if isinstance(v, str):
                    return v
            except Exception:
                pass
        parts = [v for v in x if isinstance(v, str)]
        if parts:
            return " ".join(parts)
        return str(x)
    return str(x)

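# For example:
#   _coerce_to_str({"text": "hi there", "label": 1}) -> "hi there"
#   _coerce_to_str(("a question?", "an answer.")) -> "a question?"
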
class ParadigmTokenizerWrapper(PreTrainedTokenizerFast):
    slow_tokenizer_class = None

    def __init__(self, *args, **kwargs):
        name_or_path = kwargs.get("name_or_path", None)
        if name_or_path is None and len(args) > 0 and isinstance(args[0], str):
            name_or_path = args[0]

        # Point the fast tokenizer at the local tokenizer.json unless one was passed explicitly.
        if "tokenizer_file" not in kwargs and "tokenizer_object" not in kwargs and name_or_path is not None:
            tf = os.path.join(name_or_path, "tokenizer.json")
            if not os.path.isfile(tf):
                raise FileNotFoundError(f"Expected tokenizer.json at {tf}")
            kwargs["tokenizer_file"] = tf

        super().__init__(*args, **kwargs)

        repo_id_or_path = kwargs.get("name_or_path", getattr(self, "name_or_path", None)) \
            or os.path.dirname(getattr(self, "tokenizer_file", "")) or "."
        revision = kwargs.get("revision", None)

        # Load the paradigm inventory and (optional) preprocessing config shipped next to the tokenizer.
        cfg = {"lowercase": True, "space_punct": True}
        ppath = _get_repo_file(repo_id_or_path, "paradigms.json", revision)
        self.paradigms, self.paradigms_meta = _load_paradigms_any(ppath)

        cpath = _get_repo_file(repo_id_or_path, "preprocess_config.json", revision)
        if os.path.exists(cpath):
            with open(cpath, "r", encoding="utf-8") as f:
                cfg.update(json.load(f))

        self.segmenter = ParadigmFinderSegmenter(
            paradigms=self.paradigms,
            lowercase=cfg.get("lowercase", True),
            space_punct=cfg.get("space_punct", True),
        )

    def __call__(self, text, **kwargs):
        """Segment the input(s) with the paradigm segmenter, then delegate to the fast tokenizer."""
        if isinstance(text, str):
            seg, _ = self.segmenter.segment_with_alignment(text)
            return super().__call__(seg, **kwargs)

        # A single dict-like example (e.g. a datasets row).
        if isinstance(text, dict):
            s = _coerce_to_str(text)
            seg, _ = self.segmenter.segment_with_alignment(s)
            return super().__call__(seg, **kwargs)

        # Not iterable: coerce to a string and treat as a single example.
        try:
            items = list(text)
        except TypeError:
            s = _coerce_to_str(text)
            seg, _ = self.segmenter.segment_with_alignment(s)
            return super().__call__(seg, **kwargs)

        # Batch: segment each item, then tokenize the whole batch in one call.
        segs = []
        for t in items:
            s = _coerce_to_str(t)
            seg, _ = self.segmenter.segment_with_alignment(s)
            segs.append(seg)
        return super().__call__(segs, **kwargs)

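    # Note: tokenization runs on the segmented text, so any offsets returned via
    # return_offsets_mapping refer to that segmented string; to relate them to the raw
    # input, call self.segmenter.segment_with_alignment(...) yourself and pass the
    # resulting char map to remap_offsets_to_raw.
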
    def tokenize(self, text, **kwargs):
        if isinstance(text, str):
            seg, _ = self.segmenter.segment_with_alignment(text)
            return super().tokenize(seg, **kwargs)

        if isinstance(text, dict):
            s = _coerce_to_str(text)
            seg, _ = self.segmenter.segment_with_alignment(s)
            return super().tokenize(seg, **kwargs)

        try:
            items = list(text)
        except TypeError:
            s = _coerce_to_str(text)
            seg, _ = self.segmenter.segment_with_alignment(s)
            return super().tokenize(seg, **kwargs)

        # Batch input: tokens from all items are flattened into a single list.
        out = []
        for t in items:
            s = _coerce_to_str(t)
            seg, _ = self.segmenter.segment_with_alignment(s)
            out.extend(super().tokenize(seg, **kwargs))
        return out
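

# Minimal usage sketch (assumes a local folder, here called "./paradigm-tokenizer",
# that contains tokenizer.json, paradigms.json and preprocess_config.json; the folder
# name is hypothetical):
#   tok = ParadigmTokenizerWrapper("./paradigm-tokenizer")
#   print(tok.tokenize("The walkers walked quickly."))
#   batch = tok(["first sentence.", "second sentence."], padding=True)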