| |
| |
|
|
| from typing import List, Tuple, Optional, Union, Dict, Any |
| import os, json, re |
| from transformers import PreTrainedTokenizerFast |
| from huggingface_hub import hf_hub_download |
|
|
| def _get_repo_file(repo_id_or_path: str, filename: str, revision: Optional[str] = None) -> str: |
| """ |
| If `repo_id_or_path` is a local folder and the file exists there, return its path. |
| Otherwise download it from the Hub and return the cached path. |
| """ |
| local = os.path.join(repo_id_or_path, filename) |
| if os.path.exists(local): |
| return local |
| return hf_hub_download(repo_id=repo_id_or_path, filename=filename, revision=revision) |
|
|
| def _deserialize_suffixes_from_json(sfx_list): |
| out = set() |
| for item in sfx_list: |
| if isinstance(item, list): |
| |
| base, nested = item |
| out.add((base, frozenset(nested))) |
| else: |
| out.add(item) |
| return out |
|
|
def _load_paradigms_any(path):
    """Load paradigms from a JSON file, accepting several historical layouts.

    Returns a `(paradigms, meta)` pair where `paradigms` is a list of
    `(stems_set, suffixes_set)` tuples and `meta` is a (possibly empty) dict.

    Raises:
        ValueError: if the payload matches none of the known layouts.
    """
    # `json` is imported at module level; the old function-local import was
    # redundant and has been removed.
    with open(path, "r", encoding="utf-8") as f:
        payload = json.load(f)

    # Current format: {"paradigms": [{"stems": [...], "suffixes": [...]}, ...],
    #                  "meta": {...}}
    if isinstance(payload, dict) and "paradigms" in payload:
        paradigms = []
        for p in payload["paradigms"]:
            stems = set(p["stems"])
            suffixes = _deserialize_suffixes_from_json(p["suffixes"])
            paradigms.append((stems, suffixes))
        meta = payload.get("meta", {})
        return paradigms, meta

    # Legacy format: a bare list of [stems, suffixes] pairs.
    if isinstance(payload, list) and payload and isinstance(payload[0], list):
        paradigms = []
        for stems, suffixes in payload:
            stems = set(stems)
            norm = _deserialize_suffixes_from_json(suffixes)
            paradigms.append((stems, norm))
        return paradigms, {}

    # Already-materialized (stems, suffixes) pairs passed through unchanged.
    # NOTE(review): json.load never produces tuples, so for file input this
    # branch is shadowed by the list check above — confirm callers before
    # removing it.
    if isinstance(payload, list) and payload and isinstance(payload[0], (list, tuple)) and len(payload[0]) == 2:
        return payload, {}

    raise ValueError("Unrecognized paradigms.json format")
|
|
| |
| |
| |
class ParadigmFinderSegmenter:
    """Segments words into stem + suffix pieces using a list of paradigms.

    Each paradigm is a ``(stems, suffixes)`` pair: ``stems`` is a set of stem
    strings and ``suffixes`` is a set containing either plain suffix strings
    or ``(base, nested_suffixes)`` pairs for stacked suffixes.
    """

    def __init__(self, paradigms, lowercase=True, space_punct=True):
        self.paradigms = paradigms
        self.lowercase = lowercase      # lowercase input before segmenting
        self.space_punct = space_punct  # surround punctuation with spaces

    def _preprocess(self, text: str) -> str:
        """Lowercase and space out punctuation, collapsing whitespace runs."""
        s = text
        if self.lowercase:
            s = s.lower()
        if self.space_punct:
            # Anything that is not a word char, whitespace, or apostrophe is
            # treated as punctuation and padded with spaces.
            s = re.sub(r"([^\w\s'])", r" \1 ", s)
        s = re.sub(r"\s+", " ", s).strip()
        return s

    def _segment_word(self, word: str, fallback=True, top_k=20) -> List[str]:
        """Split a single word into ``[stem, suffix, ...]`` pieces.

        First tries an exact stem + suffix(-chain) match against every
        paradigm. If nothing matches and `fallback` is true, strips the
        longest suffix found among the first `top_k` paradigms. Returns
        ``[word]`` unchanged when no split is found.
        """
        def match_suffixes(suffixes, remainder):
            # Return the suffix pieces that consume `remainder` exactly,
            # or None when no suffix (chain) matches.
            for suffix in suffixes:
                if isinstance(suffix, (tuple, list)):
                    base, nested = suffix
                    if remainder.startswith(base):
                        sub = remainder[len(base):]
                        nested_result = match_suffixes(nested, sub)
                        if nested_result is not None:
                            return [base] + nested_result
                elif remainder == suffix:
                    # An empty suffix means the bare stem is a valid form.
                    return [suffix] if suffix else []
            return None

        for stems, suffixes in self.paradigms:
            for stem in stems:
                if word.startswith(stem):
                    remainder = word[len(stem):]
                    matched_suffix = match_suffixes(suffixes, remainder)
                    if matched_suffix is not None:
                        return [stem] + matched_suffix

        if fallback:
            candidates = self.paradigms[:top_k]
            longest = ""

            def collect_flat(sfx):
                # Flatten nested (base, nested) suffix entries into strings.
                for s in sfx:
                    if isinstance(s, (tuple, list)):
                        yield s[0]
                        yield from collect_flat(s[1])
                    else:
                        yield s

            for _, suffixes in candidates:
                for suffix in collect_flat(suffixes):
                    if word.endswith(suffix) and len(suffix) > len(longest):
                        longest = suffix
            if longest:
                stem = word[:-len(longest)]
                return [stem, longest]

        return [word]

    def segment_with_alignment(self, raw_text: str) -> Tuple[str, List[Optional[int]]]:
        """Preprocess and segment `raw_text`.

        Returns the segmented text and a per-character map from positions in
        the segmented text back to indices in `raw_text` (``None`` for
        characters that were inserted, e.g. separator spaces).
        """
        # (Removed dead locals `pre_chars, pre_map` — they were never used.)
        s = raw_text.lower() if self.lowercase else raw_text
        out, out_map = [], []

        # Pad punctuation with spaces; inserted spaces map to None while real
        # characters keep their raw index.
        for i, ch in enumerate(s):
            if self.space_punct and re.match(r"[^\w\s']", ch):
                out.append(" "); out_map.append(None)
                out.append(ch); out_map.append(i)
                out.append(" "); out_map.append(None)
            else:
                out.append(ch); out_map.append(i)

        # Collapse whitespace runs to single spaces (spaces always map to
        # None) and strip leading/trailing space.
        pre = []
        pre2raw = []
        prev_space = False
        for ch, m in zip(out, out_map):
            if ch.isspace():
                if not prev_space:
                    pre.append(" "); pre2raw.append(None)
                prev_space = True
            else:
                pre.append(ch); pre2raw.append(m); prev_space = False
        if pre and pre[0] == " ": pre.pop(0); pre2raw.pop(0)
        if pre and pre[-1] == " ": pre.pop(); pre2raw.pop()
        norm = "".join(pre)

        # Segment word by word, carrying the raw-index map along and
        # inserting a None-mapped space between segment pieces.
        seg_chars, seg_map = [], []
        i = 0
        n = len(norm)
        while i < n:
            while i < n and norm[i].isspace():
                i += 1
            if i >= n: break
            j = i
            while j < n and not norm[j].isspace():
                j += 1
            token = norm[i:j]
            token_map = pre2raw[i:j]
            parts = self._segment_word(token, fallback=True)

            pos = 0
            for p_index, part in enumerate(parts):
                L = len(part)
                # Clamp so segment pieces never run past the token (guards
                # against empty-stem fallback splits and length mismatches).
                L = min(L, len(token) - pos)
                if L <= 0: continue
                for k in range(L):
                    seg_chars.append(token[pos + k])
                    seg_map.append(token_map[pos + k])
                pos += L
                if p_index < len(parts) - 1:
                    seg_chars.append(" "); seg_map.append(None)

            i = j
            while i < n and norm[i].isspace():
                i += 1
            if i < n:
                seg_chars.append(" "); seg_map.append(None)

        # Final whitespace cleanup: collapse runs and strip the ends, keeping
        # the map in lockstep with the characters.
        final = []
        final_map = []
        prev_space = False
        for ch, m in zip(seg_chars, seg_map):
            if ch.isspace():
                if not prev_space:
                    final.append(" "); final_map.append(None); prev_space = True
            else:
                final.append(ch); final_map.append(m); prev_space = False
        if final and final[0] == " ": final.pop(0); final_map.pop(0)
        if final and final[-1] == " ": final.pop(); final_map.pop()

        return "".join(final), final_map
|
|
| |
| |
| |
def remap_offsets_to_raw(offsets: List[Tuple[int,int]], pre2raw: List[Optional[int]]) -> List[Tuple[int,int]]:
    """Map (start, end) offsets in segmented text back to raw-text offsets.

    For each span, the raw start is taken from the first position with a
    non-None mapping and the raw end from the last; spans covering only
    inserted (None-mapped) characters collapse to (0, 0).
    """
    total = len(pre2raw)
    mapped = []
    for start, end in offsets:
        # Clamp the span into the valid index range.
        start = max(0, min(start, total))
        end = max(0, min(end, total))
        raw_start = next(
            (pre2raw[t] for t in range(start, end) if pre2raw[t] is not None),
            None,
        )
        raw_end = next(
            (pre2raw[t] + 1 for t in range(end - 1, start - 1, -1) if pre2raw[t] is not None),
            None,
        )
        mapped.append((raw_start if raw_start is not None else 0,
                       raw_end if raw_end is not None else 0))
    return mapped
|
|
| def _coerce_to_str(x): |
| |
| if isinstance(x, str): |
| return x |
| if isinstance(x, dict): |
| for key in ("text", "sentence", "input", "prompt"): |
| if key in x and isinstance(x[key], str): |
| return x[key] |
| |
| vals = [v for v in x.values() if isinstance(v, str)] |
| if vals: |
| return " ".join(vals) |
| return str(x) |
| if isinstance(x, (list, tuple)): |
| |
| for pick in (0, -1): |
| try: |
| v = x[pick] |
| if isinstance(v, str): |
| return v |
| except Exception: |
| pass |
| |
| parts = [v for v in x if isinstance(v, str)] |
| if parts: |
| return " ".join(parts) |
| return str(x) |
| |
| return str(x) |
|
|
| |
| |
| |
class ParadigmTokenizerWrapper(PreTrainedTokenizerFast):
    """Fast-tokenizer wrapper that paradigm-segments text before tokenizing.

    On load it fetches ``paradigms.json`` and ``preprocess_config.json`` from
    the same repo/folder as the tokenizer, builds a
    :class:`ParadigmFinderSegmenter`, and runs every input through it before
    delegating to the underlying fast tokenizer.
    """

    slow_tokenizer_class = None

    def __init__(self, *args, **kwargs):
        # Recover the repo id / local path the tokenizer is being loaded from.
        name_or_path = kwargs.get("name_or_path", None)
        if name_or_path is None and len(args) > 0 and isinstance(args[0], str):
            name_or_path = args[0]

        # Point the base class at tokenizer.json when the caller passed a
        # directory instead of an explicit tokenizer file/object.
        if "tokenizer_file" not in kwargs and "tokenizer_object" not in kwargs and name_or_path is not None:
            tf = os.path.join(name_or_path, "tokenizer.json")
            if not os.path.isfile(tf):
                raise FileNotFoundError(f"Expected tokenizer.json at {tf}")
            kwargs["tokenizer_file"] = tf

        super().__init__(*args, **kwargs)

        repo_id_or_path = kwargs.get("name_or_path", getattr(self, "name_or_path", None)) \
            or os.path.dirname(getattr(self, "tokenizer_file", "")) or "."
        revision = kwargs.get("revision", None)

        # Defaults, overridden by preprocess_config.json below.
        cfg = {"lowercase": True, "space_punct": True}
        ppath = _get_repo_file(repo_id_or_path, "paradigms.json", revision)
        self.paradigms, self.paradigms_meta = _load_paradigms_any(ppath)

        # (Removed unused `cfg_path_exists` local — the file is opened
        # unconditionally, and _get_repo_file already raised if it was
        # missing from both the local folder and the Hub.)
        cpath = _get_repo_file(repo_id_or_path, "preprocess_config.json", revision)
        with open(cpath, "r", encoding="utf-8") as f:
            cfg.update(json.load(f))

        self.segmenter = ParadigmFinderSegmenter(
            paradigms=self.paradigms,
            lowercase=cfg.get("lowercase", True),
            space_punct=cfg.get("space_punct", True),
        )

    def _segment_one(self, item):
        """Coerce `item` to text and return its paradigm-segmented form.

        `_coerce_to_str` is a no-op on plain strings, so this is safe for
        str, dict, and any other single-example input alike.
        """
        seg, _ = self.segmenter.segment_with_alignment(_coerce_to_str(item))
        return seg

    def __call__(self, text, **kwargs):
        # str and dict are single examples (str is iterable but must not be
        # treated as a batch).
        if isinstance(text, (str, dict)):
            return super().__call__(self._segment_one(text), **kwargs)

        # Non-iterable inputs are coerced and handled as a single example.
        try:
            items = list(text)
        except TypeError:
            return super().__call__(self._segment_one(text), **kwargs)

        # Batch input: segment each element independently.
        return super().__call__([self._segment_one(t) for t in items], **kwargs)

    def tokenize(self, text, **kwargs):
        """Tokenize after segmenting; batches return one flat token list."""
        if isinstance(text, (str, dict)):
            return super().tokenize(self._segment_one(text), **kwargs)

        try:
            items = list(text)
        except TypeError:
            return super().tokenize(self._segment_one(text), **kwargs)

        tokens = []
        for t in items:
            tokens.extend(super().tokenize(self._segment_one(t), **kwargs))
        return tokens