# paradigm_utils.py
import time
from collections import defaultdict
from tqdm import tqdm
import os
import math
import json
def _serialize_suffixes(sfx_set):
flat = []
for s in sfx_set:
if isinstance(s, tuple):
base, nested = s
flat.append([base, sorted(list(nested))]) # JSON-safe pair
else:
flat.append(s) # plain string
# stable order: strings first, then pairs; then lexicographic
def key(x):
return (0, x) if isinstance(x, str) else (1, x[0], tuple(x[1]))
return sorted(flat, key=key)
def paradigms_to_json(paradigms):
out = []
for stems, suffixes in paradigms:
out.append({
"stems": sorted(list(stems)),
"suffixes": _serialize_suffixes(suffixes),
})
return out
def save_paradigms_json(paradigms, path, meta=None):
payload = {
"schema_version": 1,
"created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
"meta": meta or {},
"paradigms": paradigms_to_json(paradigms),
}
with open(path, "w", encoding="utf-8") as f:
json.dump(payload, f, ensure_ascii=False, indent=2)
def _deserialize_suffixes(sfx_list):
out = set()
for item in sfx_list:
if isinstance(item, list): # [base, nested_list]
base, nested = item
out.add((base, frozenset(nested)))
else:
out.add(item)
return out
def load_paradigms_json(path):
with open(path, "r", encoding="utf-8") as f:
payload = json.load(f)
paradigms = []
for p in payload["paradigms"]:
stems = set(p["stems"])
suffixes = _deserialize_suffixes(p["suffixes"])
paradigms.append((stems, suffixes))
meta = payload.get("meta", {})
return paradigms, meta
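# A minimal round-trip sketch (the path "paradigms.json" and the meta dict are
# illustrative only, not required by the pipeline):
#
#     save_paradigms_json(paradigms, "paradigms.json", meta={"source": "toy vocab"})
#     paradigms, meta = load_paradigms_json("paradigms.json")
#
# Nested suffixes survive the trip as (base, frozenset(...)) tuples because
# _serialize_suffixes / _deserialize_suffixes convert them to and from
# JSON-safe [base, sorted_list] pairs.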
### -----------------------------
### 1. Extract (stem, suffix) pairs from vocabulary
### -----------------------------
def extract_stem_suffix_pairs(vocab):
"""Return a mapping from stems to all suffixes they occur with, including null suffix."""
stem_to_suffixes = defaultdict(set)
for word in tqdm(vocab, desc="[1/7] Extracting stem-suffix pairs"):
for i in range(0, len(word) + 1): # include empty suffix
stem, suffix = word[:i], word[i:]
stem_to_suffixes[stem].add(suffix)
return stem_to_suffixes
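# For example, extract_stem_suffix_pairs(["walk", "walks"]) maps the stem
# "walk" to {"", "s"}: the empty suffix comes from "walk" itself and "s" from
# "walks". Every split point is recorded, so shorter stems such as "wal"
# (with {"k", "ks"}) and the empty stem also appear in the mapping.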
### -----------------------------
### 2. Group stems by shared suffix sets and normalize by common prefix
### -----------------------------
def group_stems_by_suffixes(stem_to_suffixes, min_shared_stems=2, min_suffixes=2):
    """Group stems sharing an identical suffix set; shift any common prefix of the non-empty suffixes onto the stems."""
suffix_to_stems = defaultdict(set)
for stem, suffixes in stem_to_suffixes.items():
suffix_key = frozenset(suffixes)
suffix_to_stems[suffix_key].add(stem)
normalized_suffix_map = defaultdict(set)
for suffixes, stems in tqdm(suffix_to_stems.items(), desc="[2/7] Grouping and normalizing"):
non_empty_suffixes = [s for s in suffixes if s]
if len(stems) >= min_shared_stems and len(suffixes) >= min_suffixes:
common_prefix = os.path.commonprefix(non_empty_suffixes) if non_empty_suffixes else ""
if common_prefix:
normalized_stems = {stem + common_prefix for stem in stems}
adjusted_suffixes = {s[len(common_prefix):] if s.startswith(common_prefix) else s for s in suffixes}
else:
normalized_stems = stems
adjusted_suffixes = suffixes
if len(adjusted_suffixes) >= min_suffixes:
suffix_key = frozenset(adjusted_suffixes)
normalized_suffix_map[suffix_key].update(normalized_stems)
paradigms = [(stems, set(suffixes)) for suffixes, stems in normalized_suffix_map.items()]
return paradigms
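# Illustrative example: with vocab {"walk", "walks", "walked", "talk",
# "talks", "talked"}, the stems "walk" and "talk" both carry the suffix set
# {"", "s", "ed"} and form one paradigm. The shorter stems "wal" and "tal"
# share {"k", "ks", "ked"}; their common prefix "k" is shifted onto the
# stems, so they merge into the same ({"walk", "talk"}, {"", "s", "ed"})
# paradigm.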
### -----------------------------
### 3. Expand stem sets based on suffix set coverage
### -----------------------------
def stem_set_expansion(paradigms, stem_to_suffixes):
    """Add each stem to every paradigm whose suffix set is a subset of that stem's observed suffixes."""
updated = 0
suffix_to_stems = {frozenset(suffixes): set(stems) for stems, suffixes in paradigms}
for stem, suffixes in tqdm(stem_to_suffixes.items(), desc="[3/7] Expanding stem sets"):
added = False
for paradigm_suffixes in sorted(suffix_to_stems.keys(), key=lambda x: (-len(x), tuple(sorted(x)))):
if paradigm_suffixes.issubset(suffixes):
if stem not in suffix_to_stems[paradigm_suffixes]:
suffix_to_stems[paradigm_suffixes].add(stem)
updated += 1
added = True
        if not added and stem == 'design':
            # leftover single-stem debug trace
            print(f"[DEBUG] No suitable paradigm for 'design' with suffixes {suffixes}")
enriched = [(stems, set(suffixes)) for suffixes, stems in suffix_to_stems.items()]
print(f"✅ Added {updated} stems via stem set expansion.")
return enriched
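# Sketch of the effect: if the extracted paradigms include the suffix set
# {"", "s", "ed"} and the vocabulary also contains "jump", "jumps", "jumped",
# "jumping", then the stem "jump" (observed suffixes {"", "s", "ed", "ing",
# ...}) is added to that paradigm, since {"", "s", "ed"} is a subset of its
# observed suffixes.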
### -----------------------------
### 4. Expand suffix sets based on partial compatibility
### -----------------------------
def harmonic_number(n):
    """Return the n-th harmonic number H_n = 1 + 1/2 + ... + 1/n."""
    return sum(1.0 / i for i in range(1, n + 1))
def suffix_set_expansion(paradigms):
    """Merge the stems of a paradigm into any paradigm with a strictly larger suffix set when the stem overlap is high enough."""
    base = paradigms[:]  # snapshot
    merged = [ (set(stems), set(suffixes)) for stems, suffixes in base ]
    enriched_count = 0
    # Sort once up front so the index-based identity check below stays valid
    # even as stem sets are mutated during iteration; re-sorting inside the
    # loops could reorder entries and make i == j compare the wrong pairs.
    ordered = sort_paradigms(merged)
    for i, (stems_i, suffixes_i) in enumerate(ordered):
        for j, (stems_j, suffixes_j) in enumerate(ordered):
if i == j:
continue
if suffixes_i > suffixes_j:
intersection = stems_i & stems_j
denom = max(1, len(stems_j)) # guard
if (len(stems_j) - len(intersection)) < (len(stems_j) / harmonic_number(denom)):
stems_i |= stems_j
enriched_count += 1
# do not mutate stems_j/suffixes_j further
print(f"\n✅ Enriched {enriched_count} paradigms via suffix set expansion.")
# Return back in original tuple-of-sets form
return [ (set(st), set(sf)) for st, sf in sort_paradigms(merged) ]
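# The absorption threshold in numbers: for len(stems_j) == 4,
# harmonic_number(4) ≈ 2.083, so stems_j is absorbed only when fewer than
# 4 / 2.083 ≈ 1.92 of its stems (i.e. at most one) lie outside the
# intersection with stems_i. Since n / H_n grows sublinearly, large paradigms
# must overlap almost completely before they are merged.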
### -----------------------------
### 5. Prune subsumed stems
### -----------------------------
def prune_subsumed_stems(paradigms):
    """Remove a stem from a paradigm when another paradigm with a superset of its suffixes also covers that stem."""
pruned_paradigms = []
for i, (stems_i, suffixes_i) in enumerate(paradigms):
pruned_stems = set(stems_i)
for j, (stems_j, suffixes_j) in enumerate(paradigms):
if i == j:
continue
if suffixes_j >= suffixes_i:
pruned_stems -= (stems_j & stems_i)
if pruned_stems:
pruned_paradigms.append((pruned_stems, suffixes_i))
print(f"✅ Pruned to {len(pruned_paradigms)} paradigms after removing subsumed stems.")
return sort_paradigms(pruned_paradigms)
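# E.g. if paradigm A is ({"walk", "talk"}, {"", "s", "ed"}) and paradigm B is
# ({"walk"}, {"", "s", "ed", "ing"}), then B's suffix set is a superset of
# A's, so "walk" is pruned from A and kept only in the richer paradigm B.
# (The sets here are illustrative.)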
### -----------------------------
### 6. Sort paradigms by size
### -----------------------------
def sort_paradigms(paradigms):
"""
Primary: log(len(stems)) * log(len(suffixes)) (DESC)
Ties: (-len(stems), -len(suffixes), lexicographic stems, lexicographic suffix heads)
"""
def score(p):
stems, suffixes = p
if stems and suffixes:
return math.log(len(stems)) * math.log(len(suffixes))
return 0.0
def tie_key(p):
stems, suffixes = p
sfx_heads = []
for s in suffixes:
sfx_heads.append(s[0] if isinstance(s, tuple) else s)
return (-len(stems), -len(suffixes),
" ".join(sorted(stems)),
" ".join(sorted(sfx_heads)))
return sorted(paradigms, key=lambda p: (-score(p), tie_key(p)))
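# Score illustration: a paradigm with 100 stems and 10 suffixes scores
# math.log(100) * math.log(10) ≈ 4.61 * 2.30 ≈ 10.6, outranking one with
# 1000 stems but a single suffix, which scores 0 because log(1) == 0. The
# product rewards paradigms that are large along both axes.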
def sort_paradigms_by_suffix_count(paradigms):
def score(p):
stem_count = len(p[0])
suffix_count = len(p[1])
if stem_count > 0 and suffix_count > 0:
return suffix_count
return 0
return sorted(paradigms, key=score, reverse=True)
def nest_suffixes_from_paradigms(paradigms):
    """Fold suffix pairs like ("ing", "ings") into nested (base, frozenset) tuples when the residue pairs with an existing paradigm key."""
    print("[7/7] Nesting suffixes based on reusable paradigms...")
    suffix_set_index = {frozenset(suffixes) for _, suffixes in paradigms}
nested_paradigms = []
for stems, suffixes in paradigms:
suffixes_list = list(suffixes)
nested_suffixes = set()
used = set()
# deterministic nested pairing
for i, s1 in enumerate(sorted(suffixes_list)):
for j, s2 in enumerate(sorted(suffixes_list)):
if i == j or s2 in used or not isinstance(s1, str) or not isinstance(s2, str):
continue
if s2.startswith(s1) and s1 != '':
remainder = s2[len(s1):]
if remainder and frozenset({'', remainder}) in suffix_set_index:
nested_suffixes.add((s1, frozenset({'', remainder})))
used.add(s2)
used.add(s1)
break
for s in suffixes_list:
if s not in used:
nested_suffixes.add(s)
nested_paradigms.append((set(stems), nested_suffixes))
print(f"✅ Nested structure created for {len(nested_paradigms)} paradigms.")
return sort_paradigms(nested_paradigms)
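# Illustrative nesting: if a paradigm's suffixes include "ing" and "ings",
# and some paradigm in the list has exactly the suffix set {"", "s"}, then
# both plain entries are replaced by the nested pair
# ("ing", frozenset({"", "s"})), whose "" branch still covers bare "ing".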
def refine_nested_stem_conflicts(paradigms):
"""
Remove stems from higher-ranked paradigms if they are fully explained by nested structures
in lower-ranked paradigms.
Args:
paradigms: list of (stem_set, suffix_set), where suffix_set may contain nested (str, frozenset) tuples
Returns:
Refined list of paradigms with redundant derived stems removed
"""
refined_paradigms = paradigms[:]
all_suffix_sets = {frozenset(suffixes) for _, suffixes in paradigms}
# Build a mapping from nested suffix sets to their parent prefixes
derived_stems = set()
for stems, suffixes in paradigms:
for sfx in suffixes:
if isinstance(sfx, tuple):
base, nested_suffixes = sfx
if frozenset(nested_suffixes) in all_suffix_sets:
for stem in stems:
derived_stems.add(stem + base)
    # Remove derived stems from paradigms with simple suffix sets (like ['', 's'])
    removed = 0
    updated_paradigms = []
    for stems, suffixes in refined_paradigms:
        cleaned_stems = stems - derived_stems
        removed += len(stems) - len(cleaned_stems)
        updated_paradigms.append((cleaned_stems, suffixes))
    print(f"✅ Removed {removed} derived stems explained by nested paradigms.")
return updated_paradigms
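# Sketch: if one paradigm is ({"walk"}, {("ing", frozenset({"", "s"}))}) and
# {"", "s"} occurs as a suffix set elsewhere, the derived stem "walking" is
# recorded and stripped from every paradigm that lists it directly, since the
# nested structure already accounts for "walking" and "walkings".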
### -----------------------------
### 7. Segment word based on ranked paradigms
### -----------------------------
def recursive_fallback(word, suffix_set):
    """Greedily strip the longest matching plain-string suffix, recursing on the remainder."""
    for suffix in sorted(suffix_set, key=lambda s: -len(s)):
        if suffix and word.endswith(suffix):
            stem_candidate = word[:-len(suffix)]
            rest = recursive_fallback(stem_candidate, suffix_set)
            return rest + [suffix]
    return [word]  # fall back to the whole word if nothing matches
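# E.g. recursive_fallback("walkings", {"s", "ing"}) strips "s", then "ing",
# and returns ["walk", "ing", "s"]. The suffix set is assumed to contain
# plain strings only; nested (base, frozenset) tuples are not handled here.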
### -----------------------------
### Main runner
### -----------------------------
def run_paradigm_extraction(vocab, min_shared_stems=2, min_suffixes=2, enrich_suffix_sets=True):
start = time.time()
stem_to_suffixes = extract_stem_suffix_pairs(vocab)
paradigms = group_stems_by_suffixes(stem_to_suffixes, min_shared_stems, min_suffixes)
paradigms = stem_set_expansion(paradigms, stem_to_suffixes)
paradigms = sort_paradigms(paradigms)
paradigms = prune_subsumed_stems(paradigms)
paradigms = sort_paradigms(paradigms)
paradigms = nest_suffixes_from_paradigms(paradigms)
paradigms = refine_nested_stem_conflicts(paradigms)
paradigms = sort_paradigms(paradigms)
if enrich_suffix_sets:
print("[4/7] Expanding suffix sets based on partial compatibility...")
paradigms = suffix_set_expansion(paradigms)
paradigms = sort_paradigms(paradigms)
paradigms = prune_subsumed_stems(paradigms)
paradigms = sort_paradigms(paradigms)
    # Fallback paradigm for unassigned full words (kept disabled):
    # vocab_words = set(vocab)
    # assigned_words = set()
    # for stems, suffixes in paradigms:
    #     for stem in stems:
    #         for suffix in suffixes:
    #             if isinstance(suffix, tuple):
    #                 base, _ = suffix
    #                 assigned_words.add(stem + base)
    #             else:
    #                 assigned_words.add(stem + suffix)
    # unassigned_words = vocab_words - assigned_words
    # if unassigned_words:
    #     print(f"✅ {len(unassigned_words)} full words were not assigned to any paradigm, added fallback paradigm.")
    #     paradigms.append((set(unassigned_words), frozenset({""})))
    #     paradigms = sort_paradigms(paradigms)
print(f"\n✅ Extracted {len(paradigms)} paradigms.")
print(f"⏱️ Finished in {time.time() - start:.2f} seconds.")
return paradigms
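# Minimal end-to-end sketch (the toy vocabulary is illustrative only):
#
#     vocab = ["walk", "walks", "walked", "talk", "talks", "talked"]
#     paradigms = run_paradigm_extraction(vocab)
#     for stems, suffixes in paradigms[:3]:
#         print(sorted(stems), suffixes)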
def segment_word_from_nested_paradigms(word, paradigms, fallback=True, top_k=300):
"""
Segment a word based on nested paradigms with optional fallback.
Parameters:
word (str): The word to segment.
paradigms (list): A list of tuples (stems, suffixes) with optional nesting.
fallback (bool): Whether to fall back on longest suffix match from top_k paradigms.
top_k (int): Number of top paradigms to consider in fallback.
Returns:
List[str]: Segmented pieces of the word.
"""
def match_suffixes(suffixes, remainder):
"""Recursive helper to match nested suffix structures."""
for suffix in suffixes:
if isinstance(suffix, tuple):
base, nested = suffix
if remainder.startswith(base):
sub = remainder[len(base):]
nested_result = match_suffixes(nested, sub)
if nested_result is not None:
return [base] + nested_result
elif remainder == suffix:
return [suffix] if suffix else []
return None
# First pass: try full nested match
for stems, suffixes in paradigms:
for stem in stems:
if word.startswith(stem):
remainder = word[len(stem):]
matched_suffix = match_suffixes(suffixes, remainder)
if matched_suffix is not None:
return [stem] + matched_suffix
# Fallback strategy: longest suffix among top_k paradigms
if fallback:
seen_suffixes = set()
def collect_suffixes(suffixes):
for s in suffixes:
if isinstance(s, tuple):
seen_suffixes.add(s[0])
collect_suffixes(s[1])
else:
seen_suffixes.add(s)
for _, suffixes in paradigms[:top_k]:
collect_suffixes(suffixes)
# Try matching the longest suffix first
for suffix in sorted(seen_suffixes, key=lambda s: -len(s)):
if suffix and word.endswith(suffix):
stem = word[:-len(suffix)]
return [stem, suffix]
return [word]
return [word]
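# Example: with paradigms = [({"walk", "talk"}, {"", "s", "ed"})],
# segment_word_from_nested_paradigms("walked", paradigms) returns
# ["walk", "ed"], while the unseen "jogged" matches no stem and falls back
# to the longest collected suffix, yielding ["jogg", "ed"].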
def segment_word_from_paradigms(word, paradigms, top_k=20):
"""
Simpler fallback-only version: match longest suffix among top_k paradigms.
Parameters:
word (str): Word to segment.
paradigms (list): Paradigm structures.
top_k (int): How many paradigms to consider.
Returns:
List[str]: Segmentation result.
"""
    candidates = paradigms[:top_k]
    best_split = None
    best_suffix_len = -1
    for stems, suffixes in candidates:
        for suffix in sorted(suffixes, key=lambda s: -len(s) if isinstance(s, str) else -len(s[0])):
            if isinstance(suffix, tuple):
                suffix = suffix[0]  # ignore nesting in this fallback
            if word.endswith(suffix):
                stem_candidate = word[:-len(suffix)] if suffix else word
                # track the matched suffix length explicitly; comparing against
                # best_split[-1] would wrongly measure the stem when the best
                # split so far used the empty suffix
                if stem_candidate in stems and len(suffix) > best_suffix_len:
                    best_split = [stem_candidate, suffix] if suffix else [stem_candidate]
                    best_suffix_len = len(suffix)
    return best_split or [word]
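# Example: with paradigms = [({"walk", "talk"}, {"", "s", "ed"})],
# segment_word_from_paradigms("talks", paradigms) returns ["talk", "s"];
# a word whose stem is in no top_k paradigm comes back unsegmented.
# A minimal smoke test, guarded so importing this module stays side-effect
# free. The toy vocabulary below is illustrative only.
if __name__ == "__main__":
    toy_vocab = ["walk", "walks", "walked", "talk", "talks", "talked"]
    toy_paradigms = run_paradigm_extraction(toy_vocab)
    for test_word in ["walked", "talks", "jogged"]:
        print(test_word, "->", segment_word_from_nested_paradigms(test_word, toy_paradigms))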