Spaces:
Sleeping
Sleeping
| """ | |
| Utility helpers for multilingual text preprocessing and lexicon lookup. | |
| This module provides two groups of functionality used by the maze pipeline: | |
| 1) Text and punctuation utilities: | |
| - language-aware punctuation sets (`get_punctuation`) | |
| - sentence/word-list conversion helpers | |
| - punctuation stripping/reattachment for candidate-token handling | |
| - candidate normalization (trim + de-duplicate) | |
| 2) Lexicon access and neighborhood retrieval: | |
| - `Lexicon` loads a frequency-ranked lexicon file (word, length, rank) | |
| - words are bucketed by `(length, frequency_bin)` for fast lookup | |
| - `get_neighbor` returns words with similar length and nearby | |
| frequency-bin groups, which is useful for generating replacement | |
| candidates with comparable lexical difficulty. | |
| """ | |
| import pandas as pd | |
| import random | |
| import csv | |
| import os | |
| from dataclasses import dataclass | |
| from typing import Dict, List, Optional, Set, Tuple, Iterable, Sequence | |
| from functools import lru_cache | |
| import yaml | |
| import string | |
| random.seed(42) | |
# Full-width punctuation used in CJK text. Written as escapes so the code
# points cannot be silently folded to their ASCII look-alikes by an editor or
# a text-normalisation step.
_FULLWIDTH_PUNCT = {
    "\uff0c",  # fullwidth comma
    "\uff01",  # fullwidth exclamation mark
    "\uff1f",  # fullwidth question mark
    "\uff1a",  # fullwidth colon
    "\uff1b",  # fullwidth semicolon
    "\uff08",  # fullwidth left parenthesis
    "\uff09",  # fullwidth right parenthesis
}

# Punctuation inventory per script family. The CJK families keep the ASCII
# forms as well, because real-world text mixes half- and full-width marks.
_BASE_PUNCT = {
    "latin": set(string.punctuation) | {"“", "”", "‘", "’", "—", "–", "…"},
    "zh": {"。", ",", "!", "?", ":", ";", "、", "(", ")", "《", "》",
           "「", "」", "『", "』", "【", "】", "“", "”", "‘", "’", "—", "…"}
          | _FULLWIDTH_PUNCT,
    # NOTE(review): "ー" is the katakana prolonged-sound mark and occurs inside
    # ordinary words; kept for backward compatibility, confirm it is intended.
    "ja": {"。", "、", "!", "?", "「", "」", "『", "』", "・", "ー", "…",
           "\uff01", "\uff1f"},
    "ko": set(string.punctuation) | {"…", "·", "“", "”"},
    "arabic": set(string.punctuation) | {
        "،", "؛", "؟", "٪", "٫", "٬",
        "«", "»", "“", "”", "‘", "’",
        "…", "—", "–",
    },
}

# Maps language codes/names to a key of ``_BASE_PUNCT``.
_LANG_MAP = {
    "en": "latin", "de": "latin", "fr": "latin", "es": "latin",
    "zh": "zh", "zh-cn": "zh", "zh-hans": "zh", "zh-hant": "zh",
    "chinese": "zh",
    "ja": "ja", "japanese": "ja", "ko": "ko",
    "ar": "arabic", "fa": "arabic", "ur": "arabic",
}


def get_punctuation(
    language_code: str = "latin",
    extra: Optional[Iterable[str]] = None,
    remove: Optional[Iterable[str]] = None,
) -> Set[str]:
    """Return a fresh punctuation set for a language.

    Args:
        language_code: A language code or script-family name. Matching is
            case-insensitive, ignores surrounding whitespace and treats ``_``
            like ``-``. If the full code is unknown, its primary subtag is
            tried (``"zh-TW"`` -> ``"zh"``). Unknown languages (and ``None``)
            fall back to the Latin set.
        extra: Characters to add to the base set.
        remove: Characters to drop from the result (applied after ``extra``).

    Returns:
        A new ``set``; mutating it does not affect the module-level tables.
    """
    code = str(language_code or "latin").strip().lower().replace("_", "-")

    key = _LANG_MAP.get(code)
    if key is None and code not in _BASE_PUNCT:
        # Region/script-tagged codes resolve through their primary subtag.
        primary = code.split("-", 1)[0]
        key = _LANG_MAP.get(primary, primary)
    if key is None:
        key = code

    punct = set(_BASE_PUNCT.get(key, _BASE_PUNCT["latin"]))
    if extra:
        punct |= set(extra)
    if remove:
        punct -= set(remove)
    return punct
def load_config(path="config.yaml"):
    """Parse the YAML file at ``path`` and return its content.

    The file is read as UTF-8 and parsed with ``yaml.safe_load``. When the
    parsed document is falsy (for example an empty file), an empty dict is
    returned instead.
    """
    with open(path, "r", encoding="utf-8") as handle:
        parsed = yaml.safe_load(handle)
    if not parsed:
        return {}
    return parsed
def load_punctuation(punctuation_list):
    """Freeze an iterable of punctuation marks into a ``frozenset``.

    A falsy argument (``None``, empty list, empty string, ...) yields an
    empty ``frozenset``.
    """
    return frozenset(punctuation_list) if punctuation_list else frozenset()
# Languages whose words are written without separating spaces.
_NO_SPACE_LANGS = {"chinese", "zh", "japanese", "ja", "thai", "th"}


def _default_separator(language, fallback=" "):
    """Return the word separator to use when joining tokens of ``language``.

    Args:
        language: Language code or name. Matching is case-insensitive, ignores
            surrounding whitespace and treats ``_`` like ``-``. Region/script
            variants are recognised through their primary subtag, so
            ``"zh-CN"`` and ``"ja_JP"`` behave like ``"zh"`` and ``"ja"``.
        fallback: Separator returned for every other language, and when
            ``language`` is falsy.

    Returns:
        ``""`` for languages listed in ``_NO_SPACE_LANGS``, else ``fallback``.
    """
    if not language:
        return fallback
    normalized = str(language).strip().lower().replace("_", "-")
    primary = normalized.split("-", 1)[0]
    if normalized in _NO_SPACE_LANGS or primary in _NO_SPACE_LANGS:
        return ""
    return fallback
def _split_sentence(sentence, split_on=None):
    """Split one sentence into a list of tokens.

    Args:
        sentence: The sentence to split.
        split_on: Separator passed to ``sentence.split``. When it is ``None``
            or empty, the sentence is split into its individual elements
            (characters, for a string).

    Raises:
        ValueError: If ``sentence`` is ``None``.
    """
    if sentence is None:
        raise ValueError("There is no sentence to split.")
    if not split_on:
        # An empty separator cannot be handed to str.split, so it is treated
        # like "no separator" instead of yielding no result at all.
        return list(sentence)
    return sentence.split(split_on)


def sentences_to_word_lists(sentences, split_on=None):
    """Convert sentences into word lists, one list per sentence.

    Returns an empty list when ``sentences`` is falsy. See
    ``_split_sentence`` for the meaning of ``split_on``.
    """
    if not sentences:
        return []
    return [_split_sentence(sentence, split_on) for sentence in sentences]
def _combine_words(words, join_with=" ", language=None):
    """Join ``words`` into a single string.

    ``join_with`` is the separator unless ``language`` is given, in which
    case ``_default_separator`` picks it (with ``join_with`` as its fallback).

    Raises:
        ValueError: If ``words`` is ``None``.
    """
    if words is None:
        raise ValueError("There is no words to combine.")
    if language is None:
        separator = join_with
    else:
        separator = _default_separator(language, fallback=join_with)
    return separator.join(words)
# Marks that attach to the preceding token (no space in front of them).
# Full-width forms are written as escapes so they cannot be folded to ASCII.
_NO_SPACE_BEFORE = {
    ".", ",", "!", "?", ":", ";",
    "。", "、",
    "\uff0c", "\uff01", "\uff1f", "\uff1a", "\uff1b",
    "،", "؛", "؟", "٪", "٫", "٬",
    ")", "]", "}", "\uff09", "】", "》", "」", "』",
    "»", "”", "’",
}

# Opening marks that attach to the following token (no space after them).
_NO_SPACE_AFTER = {
    "(", "[", "{", "\uff08", "【", "《", "「", "『",
    "«", "“", "‘",
}


def join_tokens(tokens: Sequence[str], join_with: str = " ", puncts: Optional[Set[str]] = None) -> str:
    """Join tokens into a sentence with optional punctuation-aware spacing.

    Args:
        tokens: Tokens to join; each one is converted with ``str``.
        join_with: Separator. Punctuation-aware spacing is only applied when
            this is a single space.
        puncts: Characters regarded as punctuation. When ``None``, tokens are
            joined plainly with ``join_with``.

    Returns:
        The joined string (``""`` for empty input). In punctuation-aware mode
        the space is omitted
        - before a token made up solely of characters that are in both
          ``puncts`` and ``_NO_SPACE_BEFORE``, and
        - after a token made up solely of ``_NO_SPACE_AFTER`` characters.
    """
    if not tokens:
        return ""
    pieces = [str(tok) for tok in tokens]
    if join_with != " " or puncts is None:
        return join_with.join(pieces)

    # Collect fragments and join once instead of growing a string in a loop.
    parts = [pieces[0]]
    prev = pieces[0]
    for tok in pieces[1:]:
        attaches_left = bool(tok) and all(
            ch in puncts and ch in _NO_SPACE_BEFORE for ch in tok
        )
        attaches_right = bool(prev) and all(ch in _NO_SPACE_AFTER for ch in prev)
        if not (attaches_left or attaches_right):
            parts.append(" ")
        parts.append(tok)
        prev = tok
    return "".join(parts)
def word_lists_to_sentences(word_lists: list[list[str]], join_with=" ", language=None):
    """Turn each word list into one sentence string.

    Every list is joined by ``_combine_words`` with the given ``join_with``
    and ``language``.

    Raises:
        ValueError: If ``word_lists`` is falsy.
    """
    if not word_lists:
        raise ValueError("There is no word lists to combine.")
    sentences = []
    for words in word_lists:
        sentences.append(_combine_words(words, join_with, language=language))
    return sentences
def _read_lines_from_txt(path_to_data):
    """Read a UTF-8 text file and return its non-blank lines, stripped."""
    kept = []
    with open(path_to_data, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            stripped = raw_line.strip()
            if stripped:
                kept.append(stripped)
    return kept
def _read_rows_from_csv(path_to_data):
    """Read a UTF-8 CSV file and return its non-empty rows as lists of cells."""
    # newline="" lets the csv module handle line endings itself, which is
    # required for quoted fields that contain embedded newlines.
    with open(path_to_data, "r", encoding="utf-8", newline="") as file:
        return [row for row in csv.reader(file) if row]


def read_sentences_input(data_input, split_on=None):
    """Load sentences from a file or an in-memory list as word lists.

    Args:
        data_input: One of
            - a path (``str`` or ``os.PathLike``) to a ``.txt`` file with one
              sentence per line, or to a ``.csv`` file with one sentence per
              row;
            - a non-empty list of word lists, returned unchanged;
            - a non-empty list of sentence strings.
        split_on: Separator used to split sentences into words. For ``.txt``
            files and string lists it is forwarded to
            ``sentences_to_word_lists``. For ``.csv`` files the cells of a row
            are the words when ``split_on`` is ``None`` or empty; otherwise
            every cell is split on ``split_on`` and the pieces of a row are
            concatenated in order.

    Returns:
        A list of word lists.

    Raises:
        ValueError: If the path does not exist, the file extension is not
            supported, or ``data_input`` has an unsupported type or content
            (including an empty list).
    """
    if isinstance(data_input, os.PathLike):
        data_input = os.fspath(data_input)

    if isinstance(data_input, str):
        if not os.path.exists(data_input):
            raise ValueError(f"Input file does not exist: {data_input}")
        ext = os.path.splitext(data_input)[1].lower()
        if ext == ".txt":
            return sentences_to_word_lists(_read_lines_from_txt(data_input), split_on=split_on)
        if ext == ".csv":
            rows = _read_rows_from_csv(data_input)
            if not split_on:
                return [list(row) for row in rows]
            # Rows are lists of cells, not strings, so split cell by cell.
            return [
                [piece for cell in row for piece in cell.split(split_on)]
                for row in rows
            ]
        raise ValueError(f"Unsupported file type: {ext}")

    if isinstance(data_input, list):
        # list of word lists
        if data_input and all(isinstance(x, list) for x in data_input):
            return data_input
        # list of sentences (strings)
        if data_input and all(isinstance(x, str) for x in data_input):
            return sentences_to_word_lists(data_input, split_on=split_on)
        raise ValueError("List input must be list[str] or list[list[str]].")

    raise ValueError("data_input must be a file path or a list input.")
def strip_punctuation(word: str, puncts: Set[str]) -> Tuple[str, str, str]:
    """Split ``word`` into ``(prefix, core, suffix)``.

    ``prefix`` is the leading run of characters found in ``puncts``,
    ``suffix`` the trailing run, and ``core`` everything in between
    (punctuation inside the core is kept). A word made up only of punctuation
    is returned entirely as ``prefix``.
    """
    lead = 0
    for ch in word:
        if ch not in puncts:
            break
        lead += 1

    # Scan backwards over what is left after the prefix.
    tail = len(word)
    for ch in reversed(word[lead:]):
        if ch not in puncts:
            break
        tail -= 1

    return word[:lead], word[lead:tail], word[tail:]
def attach_punctuation(core: str, prefix: str, suffix: str) -> str:
    """Return ``prefix``, ``core`` and ``suffix`` concatenated in that order."""
    return "{}{}{}".format(prefix, core, suffix)
def normalize_candidates(words: Sequence[str]) -> list[str]:
    """Trim candidates, drop empty ones and remove duplicates.

    Falsy entries and entries that are blank after stripping are skipped.
    The first occurrence of each remaining word determines its position.
    """
    # A dict keeps insertion order, so its keys are the ordered unique words.
    ordered = {}
    for raw in words:
        cleaned = raw.strip() if raw else ""
        if cleaned:
            ordered.setdefault(cleaned, None)
    return list(ordered)
@dataclass
class Lexicon:
    """Frequency-ranked lexicon bucketed by word length and frequency bin.

    Declared as a dataclass so that ``Lexicon(path, rank_bin)`` fills the two
    fields and then runs ``__post_init__`` to load the file.

    Fields:
        path_to_lexicon: Path to a delimited text file with a header row that
            contains at least the columns ``word`` and ``frequency_rank``
            (``length`` is optional). Column names are matched
            case-insensitively.
        rank_bin: Number of consecutive ranks grouped into one frequency bin.

    Attributes set while loading:
        df: The cleaned table, including ``length`` and ``freq_group``.
        max_freq_group, max_frequency_rank, min_length, max_length: Extremes
            of the cleaned table.
        group_to_words: ``(length, freq_group) -> set of words``.
        word_to_group: ``word -> (length, freq_group)`` of the word's
            best-ranked row.
        word_to_rank: ``word -> smallest frequency rank``.
    """

    path_to_lexicon: str
    rank_bin: int = 100

    def __post_init__(self) -> None:
        """Load, clean and index the lexicon file.

        Raises:
            ValueError: If ``rank_bin`` is not positive, a required column is
                missing, or no usable row remains after cleaning.
        """
        if self.rank_bin <= 0:
            raise ValueError("rank_bin must be a positive integer.")

        df = pd.read_csv(
            self.path_to_lexicon,
            sep=self._infer_sep(self.path_to_lexicon),
            engine="python",
            encoding="utf-8",
        )
        df.columns = [str(c).lower().strip() for c in df.columns]
        if "word" not in df.columns or "frequency_rank" not in df.columns:
            raise ValueError("Lexicon must contain columns: 'word' and 'frequency_rank'.")

        # Normalize words early so len() is always safe and malformed rows are removed.
        df["word"] = df["word"].fillna("").astype(str).str.strip()
        df = df[df["word"] != ""]
        # Guard against common stringified-null artifacts from CSV/Arrow parsing.
        df = df[~df["word"].str.lower().isin({"nan", "none", "<na>"})]

        # Work on an explicit copy: the frame above is a filtered selection
        # and the column assignments below must not target a view.
        df = df.copy()
        df["frequency_rank"] = pd.to_numeric(df["frequency_rank"], errors="coerce")
        df = df.dropna(subset=["frequency_rank"]).copy()
        if df.empty:
            raise ValueError("Lexicon contains no usable rows after cleaning.")
        df["frequency_rank"] = df["frequency_rank"].astype(int)

        word_length = df["word"].map(len)
        if "length" not in df.columns:
            df["length"] = word_length
        else:
            # Unparseable length cells fall back to the actual word length.
            df["length"] = (
                pd.to_numeric(df["length"], errors="coerce").fillna(word_length).astype(int)
            )
        df["freq_group"] = ((df["frequency_rank"] - 1) // self.rank_bin).astype(int)

        self.df = df
        self.max_freq_group = int(df["freq_group"].max())
        self.max_frequency_rank = int(df["frequency_rank"].max())
        self.min_length = int(df["length"].min())
        self.max_length = int(df["length"].max())

        group_to_words: Dict[Tuple[int, int], Set[str]] = {}
        word_to_group: Dict[str, Tuple[int, int]] = {}
        word_to_rank: Dict[str, int] = {}
        rows = zip(
            df["word"].tolist(),
            df["length"].tolist(),
            df["freq_group"].tolist(),
            df["frequency_rank"].tolist(),
        )
        for word, length, group, rank in rows:
            group_to_words.setdefault((length, group), set()).add(word)
            # Keep the best (smallest) rank for each word, and take the
            # word's group from that same row so both maps agree.
            best = word_to_rank.get(word)
            if best is None or rank < best:
                word_to_rank[word] = rank
                word_to_group[word] = (length, group)

        self.group_to_words = group_to_words
        self.word_to_group = word_to_group
        self.word_to_rank = word_to_rank

    @staticmethod
    def _infer_sep(path: str) -> str:
        """Guess the column separator from the first line of the file.

        Returns a tab if the line contains one, else a comma if it contains
        one, else the regular expression for runs of whitespace.
        """
        with open(path, "r", encoding="utf-8") as f:
            header = f.readline()
        if "\t" in header:
            return "\t"
        if "," in header:
            return ","
        return r"\s+"

    @staticmethod
    def _limit(candidates: Set[str], max_size: Optional[int]) -> List[str]:
        """Return the candidates as a list, randomly capped at ``max_size``."""
        # Sort first: set iteration order of strings differs between
        # interpreter runs, which would defeat a seeded ``random``.
        neighbors = sorted(candidates)
        if max_size is not None and len(neighbors) > max_size:
            random.shuffle(neighbors)
            neighbors = neighbors[:max_size]
        return neighbors

    def get_neighbor(self, word: str, min_size: int = 10, max_size: Optional[int] = None) -> List[str]:
        """Return words of the same length from nearby frequency bins.

        The search starts at the bin of ``word`` and widens by one bin on
        each side per step until at least ``min_size`` words are collected or
        both sides have left the range of existing bins. A word that is not
        in the lexicon is located by its own length and the highest bin.

        Args:
            word: Pivot word; surrounding whitespace is ignored. The pivot is
                never part of the result.
            min_size: Number of neighbors at which widening stops.
            max_size: If given and exceeded, a random subset of this size is
                returned.

        Returns:
            The neighbors (possibly fewer than ``min_size``); ``[]`` for a
            blank ``word``.
        """
        word = word.strip()
        if not word:
            return []
        w_len, fg = self.word_to_group.get(word, (len(word), self.max_freq_group))

        out: Set[str] = set()
        delta = 0
        while len(out) < min_size:
            left, right = fg - delta, fg + delta
            if left < 0 and right > self.max_freq_group:
                break
            if 0 <= left <= self.max_freq_group:
                out |= self.group_to_words.get((w_len, left), set())
            if right != left and 0 <= right <= self.max_freq_group:
                out |= self.group_to_words.get((w_len, right), set())
            out.discard(word)
            delta += 1
        return self._limit(out, max_size)

    def get_rank(self, word: str, default_to_max: bool = True) -> Optional[int]:
        """Return the frequency rank of ``word``.

        For a word that is not in the lexicon, the largest rank in the
        lexicon is returned when ``default_to_max`` is true, else ``None``.
        """
        rank = self.word_to_rank.get(word)
        if rank is not None:
            return int(rank)
        if default_to_max:
            return int(self.max_frequency_rank)
        return None

    def get_neighbor_by_profile(
        self,
        *,
        target_length: int,
        target_rank: int,
        min_size: int = 10,
        max_size: Optional[int] = None,
        exclude_words: Optional[Iterable[str]] = None,
        max_length_delta: int = 3,
    ) -> List[str]:
        """
        Retrieve neighbors by target length/rank profile (without a pivot word).
        Useful for controlled items where target words differ across conditions.

        ``target_length`` and ``target_rank`` are clamped to the ranges present
        in the lexicon. The search widens ring by ring over frequency bins;
        each ring covers every length within ``max_length_delta`` of the target
        length. It stops after the first ring that brings the number of
        collected words (``exclude_words`` removed) to ``min_size`` or more.
        If ``max_size`` is given and exceeded, a random subset of that size is
        returned.
        """
        out: Set[str] = set()
        excludes = set(exclude_words or [])
        t_len = max(self.min_length, min(self.max_length, int(target_length)))
        rank = max(1, min(self.max_frequency_rank, int(target_rank)))
        fg = (rank - 1) // self.rank_bin

        for freq_delta in range(self.max_freq_group + 1):
            ring = [fg] if freq_delta == 0 else [fg - freq_delta, fg + freq_delta]
            groups = [g for g in ring if 0 <= g <= self.max_freq_group]
            for len_delta in range(max_length_delta + 1):
                lengths = [t_len] if len_delta == 0 else [t_len - len_delta, t_len + len_delta]
                for length in lengths:
                    if length < self.min_length or length > self.max_length:
                        continue
                    for g in groups:
                        out |= self.group_to_words.get((length, g), set())
            out -= excludes
            if len(out) >= min_size:
                break
        return self._limit(out, max_size)
if __name__ == "__main__":
    import sys

    # Small manual demo. The lexicon path may be given as the first
    # command-line argument; otherwise the original development path is used.
    default_lexicon_path = "/swdata/yin/Cui/LLM-MAZE/llmmaze/data/lexicon/lexicon_zh.txt"
    lexicon_path = sys.argv[1] if len(sys.argv) > 1 else default_lexicon_path

    lexicon = Lexicon(lexicon_path)
    neighbors = lexicon.get_neighbor("人权", max_size=20)
    print(neighbors)

    # Further usage examples (kept disabled):
    # puncts_zh = get_punctuation("zh")
    # puncts_en = get_punctuation("en", extra={"★"}, remove={"'"})
    # print("Punctuations for Chinese:")
    # print(puncts_zh)
    # print("Punctuations for English:")
    # print(puncts_en)
    # word = "人权,。?"
    # prefix, core, suffix = strip_punctuation(word, puncts_zh)
    # print(prefix, "->", core, "->", suffix)
    # word = "‘’rights,."
    # prefix, core, suffix = strip_punctuation(word, puncts_en)
    # print(prefix, "->", core, "->", suffix)
    # word = "rig.hts,★"
    # prefix, core, suffix = strip_punctuation(word, puncts_en)
    # print(prefix, "->", core, "->", suffix)