Agentic_A-Maze_Studio / utils /components.py
CuiD's picture
Upload folder using huggingface_hub
8dbd05b verified
Raw
History Blame Contribute Delete
13.7 kB
"""
Utility helpers for multilingual text preprocessing and lexicon lookup.
This module provides two groups of functionality used by the maze pipeline:
1) Text and punctuation utilities:
- language-aware punctuation sets (`get_punctuation`)
- sentence/word-list conversion helpers
- punctuation stripping/reattachment for candidate-token handling
- candidate normalization (trim + de-duplicate)
2) Lexicon access and neighborhood retrieval:
- `Lexicon` loads a frequency-ranked lexicon file (word, length, rank)
- words are bucketed by `(length, frequency_bin)` for fast lookup
- `get_neighbor` returns words with similar length and nearby
frequency-bin groups, which is useful for generating replacement
candidates with comparable lexical difficulty.
"""
import pandas as pd
import random
import csv
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple, Iterable, Sequence
from functools import lru_cache
import yaml
import string
# Seed the process-wide RNG once at import time. `Lexicon.get_neighbor` and
# `Lexicon.get_neighbor_by_profile` call `random.shuffle` when truncating
# their results to `max_size`, so this makes that sampling start from a fixed
# state. NOTE: this reseeds the global `random` module for every importer.
random.seed(42)
# Punctuation inventories keyed by script family. `get_punctuation` always
# hands out a copy, so these sets are never mutated through its return value.
_BASE_PUNCT = {
    "latin": set(string.punctuation) | {"“", "”", "‘", "’", "—", "–", "…"},
    "zh": {
        "。", ",", "!", "?", ":", ";", "、", "(", ")", "《", "》",
        "「", "」", "『", "』", "【", "】", "“", "”", "‘", "’", "—", "…",
        # Fullwidth forms used in CJK text: , ! ? : ; ( )
        "\uff0c", "\uff01", "\uff1f", "\uff1a", "\uff1b", "\uff08", "\uff09",
    },
    "ja": {
        "。", "、", "!", "?", "「", "」", "『", "』", "・", "ー", "…",
        # Fullwidth ! and ?
        "\uff01", "\uff1f",
    },
    "ko": set(string.punctuation) | {"…", "·", "“", "”"},
    "arabic": set(string.punctuation) | {
        "،", "؛", "؟", "٪", "٫", "٬",
        "«", "»", "“", "”", "‘", "’",
        "…", "—", "–",
    },
}

# Language code -> key of `_BASE_PUNCT`.
_LANG_MAP = {
    "en": "latin", "de": "latin", "fr": "latin", "es": "latin",
    "zh": "zh", "zh-cn": "zh", "zh-hans": "zh", "zh-hant": "zh",
    "ja": "ja", "ko": "ko",
    "ar": "arabic", "fa": "arabic", "ur": "arabic",
}


def get_punctuation(
    language_code: str = "latin",
    extra: Optional[Iterable[str]] = None,
    remove: Optional[Iterable[str]] = None,
) -> Set[str]:
    """Return a fresh, mutable punctuation set for *language_code*.

    Args:
        language_code: A language code (``"en"``, ``"zh-CN"``, ``"ja_JP"``) or
            a script-family key of ``_BASE_PUNCT`` (``"latin"``, ``"arabic"``).
            Matching ignores case, surrounding whitespace and ``_`` vs ``-``.
            ``None`` or an empty string is treated as ``"latin"``.
        extra: Characters to add to the base set.
        remove: Characters to drop from the result (applied after *extra*).

    Returns:
        A new ``set``; modifying it does not affect later calls.

    Unknown codes fall back to the Latin set.
    """
    code = str(language_code or "latin").strip().lower().replace("_", "-")
    key = _LANG_MAP.get(code, code)
    if key not in _BASE_PUNCT:
        # Region/script-qualified codes such as "zh-tw" or "ja-jp": retry with
        # the primary subtag instead of silently using Latin punctuation.
        primary = code.split("-", 1)[0]
        key = _LANG_MAP.get(primary, primary)

    punct = set(_BASE_PUNCT.get(key, _BASE_PUNCT["latin"]))
    if extra:
        punct |= set(extra)
    if remove:
        punct -= set(remove)
    return punct
@lru_cache(maxsize=None)
def load_config(path="config.yaml"):
    """Parse the YAML file at *path* and return its content.

    Results are memoised per *path* argument by ``lru_cache``, so repeated
    calls with the same path return the same object without re-reading the
    file. A document that parses to a falsy value (e.g. an empty file) is
    returned as ``{}``.
    """
    with open(path, "r", encoding="utf-8") as handle:
        parsed = yaml.safe_load(handle)
    if parsed:
        return parsed
    return {}
def load_punctuation(punctuation_list):
    """Return *punctuation_list* as a ``frozenset``; falsy input gives an empty one."""
    return frozenset(punctuation_list) if punctuation_list else frozenset()
# Language names/codes whose words are joined without a separator.
_NO_SPACE_LANGS = {
    "chinese", "zh",
    "japanese", "ja",
    "thai", "th",
}


def _default_separator(language, fallback=" "):
    """Return the word separator to use for *language*.

    The language is compared case-insensitively, ignoring surrounding
    whitespace. Languages listed in ``_NO_SPACE_LANGS`` get ``""``; anything
    else, including a falsy *language*, gets *fallback*.
    """
    if language:
        tag = str(language).strip().lower()
        if tag in _NO_SPACE_LANGS:
            return ""
    return fallback
def _split_sentence(sentence, split_on=None):
    """Split *sentence* into a list of tokens.

    Args:
        sentence: The text to split.
        split_on: Separator passed to ``str.split``. ``None`` or ``""`` means
            "no separator": the sentence is split into its individual
            elements (characters, for a string) via ``list(sentence)``.

    Returns:
        A list of tokens.

    Raises:
        ValueError: If *sentence* is ``None``.
    """
    if sentence is None:
        raise ValueError("There is no sentence to split.")
    # An empty separator used to match neither branch and fall off the end of
    # the function (returning None); treat it like None instead.
    if not split_on:
        return list(sentence)
    return sentence.split(split_on)
def sentences_to_word_lists(sentences, split_on=None):
    """Split every sentence with ``_split_sentence`` and collect the results.

    Falsy input (``None``, an empty list, ...) yields an empty list.
    """
    if sentences:
        return [_split_sentence(item, split_on) for item in sentences]
    return []
def _combine_words(words, join_with=" ", language=None):
    """Join *words* into one string.

    When *language* is given, the separator is chosen by
    ``_default_separator`` (with *join_with* as its fallback); otherwise
    *join_with* is used as is.

    Raises:
        ValueError: If *words* is ``None``.
    """
    if words is None:
        raise ValueError("There is no words to combine.")
    if language is None:
        separator = join_with
    else:
        separator = _default_separator(language, fallback=join_with)
    return separator.join(words)
# Marks that attach to the preceding token (no space before them).
_NO_SPACE_BEFORE = {
    ".", ",", "!", "?", ":", ";",
    "。", "、",
    "\uff0c", "\uff01", "\uff1f", "\uff1a", "\uff1b",  # fullwidth , ! ? : ;
    "،", "؛", "؟", "٪", "٫", "٬",
    ")", "]", "}", "\uff09", "】", "》", "」", "』",  # \uff09 = fullwidth )
    "»", "”", "’",
}

# Marks that attach to the following token (no space after them).
_NO_SPACE_AFTER = {
    "(", "[", "{", "\uff08", "【", "《", "「", "『",  # \uff08 = fullwidth (
    "«", "“", "‘",
}


def join_tokens(tokens: Sequence[str], join_with: str = " ", puncts: Optional[Set[str]] = None) -> str:
    """Join tokens into a sentence with optional punctuation-aware spacing.

    Args:
        tokens: Tokens to join; each one is converted with ``str()``.
        join_with: Separator. Punctuation-aware spacing is only applied when
            this is exactly ``" "``; any other value is a plain join.
        puncts: Characters regarded as punctuation. When ``None`` a plain
            join is performed.

    Returns:
        The joined string (``""`` for empty *tokens*).

    With ``join_with == " "`` and *puncts* given, the space is dropped

    - before a token made up entirely of characters that are both in *puncts*
      and in ``_NO_SPACE_BEFORE``, and
    - after a token made up entirely of ``_NO_SPACE_AFTER`` characters (this
      check does not consult *puncts*).
    """
    if not tokens:
        return ""

    # Coerce once so every code path accepts the same inputs.
    words = [str(tok) for tok in tokens]
    if join_with != " " or puncts is None:
        return join_with.join(words)

    pieces = [words[0]]
    prev = words[0]
    for tok in words[1:]:
        closes = bool(tok) and all(ch in puncts and ch in _NO_SPACE_BEFORE for ch in tok)
        opens = bool(prev) and all(ch in _NO_SPACE_AFTER for ch in prev)
        if not (closes or opens):
            pieces.append(" ")
        pieces.append(tok)
        prev = tok
    return "".join(pieces)
def word_lists_to_sentences(word_lists: list[list[str]], join_with=" ", language=None):
    """Join each word list into a sentence string via ``_combine_words``.

    Raises:
        ValueError: If *word_lists* is empty or otherwise falsy.
    """
    if word_lists:
        return [_combine_words(tokens, join_with, language=language) for tokens in word_lists]
    raise ValueError("There is no word lists to combine.")
def _read_lines_from_txt(path_to_data):
    """Read a UTF-8 text file and return its non-blank lines, whitespace-trimmed."""
    with open(path_to_data, "r", encoding="utf-8") as handle:
        trimmed = (raw.strip() for raw in handle)
        return [text for text in trimmed if text]
def _read_rows_from_csv(path_to_data):
    """Read a UTF-8 CSV file and return its non-empty rows as lists of cells.

    The file is opened with ``newline=""`` as the ``csv`` module requires;
    without it, line breaks embedded in quoted fields are rewritten by
    universal-newline translation before the reader sees them.
    """
    with open(path_to_data, "r", encoding="utf-8", newline="") as file:
        return [row for row in csv.reader(file) if row]
def read_sentences_input(data_input, split_on=None):
    """Normalise the supported input forms into a list of word lists.

    Accepted inputs:

    - a path (``str`` or ``os.PathLike``) to an existing ``.txt`` or ``.csv``
      file, whose lines/rows are passed to ``sentences_to_word_lists``;
    - a non-empty ``list`` of lists, returned unchanged;
    - a non-empty ``list`` of strings, split via ``sentences_to_word_lists``.

    Raises:
        ValueError: For a missing file, an unsupported extension, an empty or
            mixed list, or any other input type.
    """
    if isinstance(data_input, os.PathLike):
        data_input = os.fspath(data_input)

    if isinstance(data_input, str):
        if not os.path.exists(data_input):
            raise ValueError(f"Input file does not exist: {data_input}")
        suffix = os.path.splitext(data_input)[1].lower()
        if suffix == ".txt":
            raw = _read_lines_from_txt(data_input)
        elif suffix == ".csv":
            # NOTE(review): CSV rows arrive as lists of cells, so a non-None
            # `split_on` will call `.split` on a list downstream - confirm the
            # intended CSV layout with the callers.
            raw = _read_rows_from_csv(data_input)
        else:
            raise ValueError(f"Unsupported file type: {suffix}")
        return sentences_to_word_lists(raw, split_on=split_on)

    if not isinstance(data_input, list):
        raise ValueError("data_input must be a file path or a list input.")

    if data_input:
        # Already tokenised: list of word lists.
        if all(isinstance(item, list) for item in data_input):
            return data_input
        # Raw sentences: list of strings.
        if all(isinstance(item, str) for item in data_input):
            return sentences_to_word_lists(data_input, split_on=split_on)
    raise ValueError("List input must be list[str] or list[list[str]].")
def strip_punctuation(word: str, puncts: Set[str]) -> Tuple[str, str, str]:
    """Split *word* into ``(prefix, core, suffix)``.

    *prefix* is the leading run of characters found in *puncts*, *suffix* the
    trailing run, and *core* everything in between (inner punctuation is left
    alone). A word made up only of punctuation ends up entirely in *prefix*.
    """
    size = len(word)
    # Index of the first non-punctuation character (or `size` if there is none).
    head = next((idx for idx, ch in enumerate(word) if ch not in puncts), size)
    # Walk back from the end, never crossing `head`.
    tail = size
    while tail > head and word[tail - 1] in puncts:
        tail -= 1
    return word[:head], word[head:tail], word[tail:]
def attach_punctuation(core: str, prefix: str, suffix: str) -> str:
    """Return ``prefix + core + suffix`` as one string (note: *core* is the first argument)."""
    return "{}{}{}".format(prefix, core, suffix)
def normalize_candidates(words: Sequence[str]) -> list[str]:
    """Trim candidates, drop empty ones and remove duplicates.

    Falsy entries and entries that are blank after ``strip()`` are skipped.
    The first occurrence of each trimmed word is kept, in input order.
    """
    trimmed = (word.strip() for word in words if word)
    # dict keys preserve insertion order, which gives an ordered de-duplication.
    return list(dict.fromkeys(text for text in trimmed if text))
@dataclass
class Lexicon:
    """Frequency-ranked lexicon bucketed by ``(length, frequency_bin)``.

    The lexicon file must have a header row containing at least the columns
    ``word`` and ``frequency_rank`` (matched case-insensitively); a ``length``
    column is optional and defaults to ``len(word)``. The column separator is
    inferred from the header line (tab, comma, or runs of whitespace).

    Raises:
        ValueError: If ``rank_bin`` is not positive, a required column is
            missing, or no usable row remains after cleaning.
    """

    # Path of the lexicon file to load.
    path_to_lexicon: str
    # Number of consecutive frequency ranks that share one frequency bin.
    rank_bin: int = 100

    def __post_init__(self) -> None:
        """Load and clean the lexicon, then build the lookup tables."""
        if self.rank_bin <= 0:
            raise ValueError("rank_bin must be a positive integer.")

        df = pd.read_csv(
            self.path_to_lexicon,
            sep=self._infer_sep(self.path_to_lexicon),
            engine="python",
            encoding="utf-8",
        )
        df.columns = [str(c).lower().strip() for c in df.columns]
        if "word" not in df.columns or "frequency_rank" not in df.columns:
            raise ValueError("Lexicon must contain columns: 'word' and 'frequency_rank'.")

        # Normalize words early so len() is always safe and malformed rows are removed.
        df["word"] = df["word"].fillna("").astype(str).str.strip()
        # Guard against common stringified-null artifacts from CSV/Arrow parsing.
        null_tokens = {"nan", "none", "<na>"}
        usable = (df["word"] != "") & ~df["word"].str.lower().isin(null_tokens)
        # Copy after filtering so the column assignments below act on an
        # independent frame rather than on a slice of the original.
        df = df[usable].copy()

        df["frequency_rank"] = pd.to_numeric(df["frequency_rank"], errors="coerce")
        df = df.dropna(subset=["frequency_rank"]).copy()
        if df.empty:
            raise ValueError("Lexicon contains no usable (word, frequency_rank) rows.")
        df["frequency_rank"] = df["frequency_rank"].astype(int)

        word_length = df["word"].map(len)
        if "length" in df.columns:
            # Fall back to the computed length where the column is not numeric.
            df["length"] = pd.to_numeric(df["length"], errors="coerce").fillna(word_length).astype(int)
        else:
            df["length"] = word_length

        # Ranks are treated as 1-based; ranks <= 0 are clipped into bin 0 so
        # such words stay reachable instead of landing in a negative bin.
        df["freq_group"] = ((df["frequency_rank"] - 1) // self.rank_bin).clip(lower=0).astype(int)

        self.df = df
        self.max_freq_group = int(df["freq_group"].max())
        self.max_frequency_rank = int(df["frequency_rank"].max())
        self.min_length = int(df["length"].min())
        self.max_length = int(df["length"].max())

        buckets = df.groupby(["length", "freq_group"])["word"].apply(set).to_dict()
        self.group_to_words: Dict[Tuple[int, int], Set[str]] = {
            (int(length), int(group)): words for (length, group), words in buckets.items()
        }

        # For duplicated words keep the best (smallest) rank, and take the
        # bucket from that same row so rank and bucket always agree.
        best = df.sort_values("frequency_rank", kind="stable").drop_duplicates(subset="word", keep="first")
        self.word_to_group: Dict[str, Tuple[int, int]] = {
            word: (int(length), int(group))
            for word, length, group in zip(best["word"], best["length"], best["freq_group"])
        }
        self.word_to_rank: Dict[str, int] = {
            word: int(rank) for word, rank in zip(best["word"], best["frequency_rank"])
        }

    @staticmethod
    def _infer_sep(path: str) -> str:
        """Guess the column separator from the first line of the file."""
        with open(path, "r", encoding="utf-8") as f:
            header = f.readline()
        if "\t" in header:
            return "\t"
        if "," in header:
            return ","
        # Regex separator: any run of whitespace.
        return r"\s+"

    @staticmethod
    def _finalize(pool: Set[str], max_size: Optional[int]) -> List[str]:
        """Turn a candidate pool into the returned list.

        The pool is sorted first because set iteration order for strings can
        differ between interpreter runs; starting from a fixed order makes
        the subsequent ``random.shuffle`` reproducible for a given RNG state.
        """
        neighbors = sorted(pool)
        if max_size is not None and len(neighbors) > max_size:
            random.shuffle(neighbors)
            neighbors = neighbors[:max(max_size, 0)]
        return neighbors

    def get_neighbor(self, word: str, min_size: int = 10, max_size: Optional[int] = None) -> List[str]:
        """Return words of the same length from nearby frequency bins.

        The search starts in the word's own bin and widens one bin in each
        direction per step until at least *min_size* candidates are found or
        both directions have left the valid bin range. A word that is not in
        the lexicon is looked up with ``len(word)`` and the last bin.

        Args:
            word: Pivot word (surrounding whitespace is ignored); it is never
                part of the result.
            min_size: Stop widening once this many candidates are collected.
            max_size: If given and exceeded, a random subset of this size is
                returned.

        Returns:
            The candidates in sorted order, or a shuffled subset when
            truncated to *max_size*. Empty for a blank *word*.
        """
        word = word.strip()
        if not word:
            return []

        w_len, fg = self.word_to_group.get(word, (len(word), self.max_freq_group))
        out: Set[str] = set()
        delta = 0
        while len(out) < min_size:
            left, right = fg - delta, fg + delta
            if left < 0 and right > self.max_freq_group:
                break  # both directions exhausted
            if 0 <= left <= self.max_freq_group:
                out |= self.group_to_words.get((w_len, left), set())
            if right != left and 0 <= right <= self.max_freq_group:
                out |= self.group_to_words.get((w_len, right), set())
            out.discard(word)
            delta += 1
        return self._finalize(out, max_size)

    def get_rank(self, word: str, default_to_max: bool = True) -> Optional[int]:
        """Return the best frequency rank of *word*.

        For a word that is not in the lexicon, returns the largest rank in
        the lexicon when *default_to_max* is true, otherwise ``None``.
        """
        rank = self.word_to_rank.get(word)
        if rank is not None:
            return int(rank)
        if default_to_max:
            return int(self.max_frequency_rank)
        return None

    def get_neighbor_by_profile(
        self,
        *,
        target_length: int,
        target_rank: int,
        min_size: int = 10,
        max_size: Optional[int] = None,
        exclude_words: Optional[Iterable[str]] = None,
        max_length_delta: int = 3,
    ) -> List[str]:
        """
        Retrieve neighbors by target length/rank profile (without a pivot word).
        Useful for controlled items where target words differ across conditions.

        *target_length* and *target_rank* are clamped to the ranges present in
        the lexicon. Buckets are visited from nearest to farthest: for each
        frequency-bin distance, lengths at distance ``0..max_length_delta``
        are added in turn. The search stops as soon as at least *min_size*
        words (after removing *exclude_words*) have been collected.

        Returns:
            The candidates in sorted order, or a shuffled subset when
            truncated to *max_size*.
        """
        out: Set[str] = set()
        excludes = set(exclude_words or [])
        t_len = max(self.min_length, min(self.max_length, int(target_length)))
        rank = max(1, min(self.max_frequency_rank, int(target_rank)))
        fg = (rank - 1) // self.rank_bin

        for freq_delta in range(self.max_freq_group + 1):
            groups = [fg] if freq_delta == 0 else [fg - freq_delta, fg + freq_delta]
            groups = [g for g in groups if 0 <= g <= self.max_freq_group]
            for len_delta in range(max_length_delta + 1):
                lengths = [t_len] if len_delta == 0 else [t_len - len_delta, t_len + len_delta]
                for length in lengths:
                    if length < self.min_length or length > self.max_length:
                        continue
                    for g in groups:
                        out |= self.group_to_words.get((length, g), set())
                out -= excludes
                if len(out) >= min_size:
                    # Enough candidates: end the whole search, not just the
                    # current length sweep, so farther bins are not mixed in.
                    return self._finalize(out, max_size)
        return self._finalize(out, max_size)
if __name__ == "__main__":
    import sys

    # Small manual demo: print up to 20 lexicon neighbours of a word.
    # Usage: python components.py [LEXICON_PATH] [WORD]
    # Both arguments are optional; the defaults below are used otherwise.
    demo_path = sys.argv[1] if len(sys.argv) > 1 else "/swdata/yin/Cui/LLM-MAZE/llmmaze/data/lexicon/lexicon_zh.txt"
    demo_word = sys.argv[2] if len(sys.argv) > 2 else "人权"

    lexicon = Lexicon(demo_path)
    neighbors = lexicon.get_neighbor(demo_word, max_size=20)
    print(neighbors)

    # Further manual checks for the punctuation helpers (kept disabled):
    # puncts_zh = get_punctuation("zh")
    # puncts_en = get_punctuation("en", extra={"★"}, remove={"'"})
    # print("Punctuations for Chinese:")
    # print(puncts_zh)
    # print("Punctuations for English:")
    # print(puncts_en)
    # word = "人权,。?"
    # prefix, core, suffix = strip_punctuation(word, puncts_zh)
    # print(prefix, "->", core, "->", suffix)
    # word = "‘’rights,."
    # prefix, core, suffix = strip_punctuation(word, puncts_en)
    # print(prefix, "->", core, "->", suffix)
    # word = "rig.hts,★"
    # prefix, core, suffix = strip_punctuation(word, puncts_en)
    # print(prefix, "->", core, "->", suffix)