"""
ByteRewriter class for efficient morpheme matching and replacement.

Based on the MYTE implementation using a hash tree (trie) structure.
"""

import json
from typing import Dict, List, Optional, Union


class ByteRewriter:
    """
    Efficient byte-level rewriter using a hash tree (trie) for morpheme matching.

    This class builds a prefix tree from morpheme mappings and uses it to
    efficiently find and replace morpheme sequences in byte streams.

    Matching is greedy: at each input position the longest rule that matches
    is applied, and values that start no rule are copied through unchanged.

    Attributes:
        rules: Forward mapping (source -> target) as supplied or loaded.
        hash_tree: Trie built from ``rules``.
        reverse_rules: Inverted mapping (target -> source).
        reverse_hash_tree: Trie built from ``reverse_rules``.
    """

    def __init__(self, rewriting_rules: Union[str, Dict]):
        """
        Initialize ByteRewriter with rewriting rules.

        Args:
            rewriting_rules: Either a path to a JSON file or a dict mapping
                morphemes to codes.

        Raises:
            TypeError: If a rule target is unhashable or ``None`` (it cannot
                become a key of the reverse mapping).
        """
        if isinstance(rewriting_rules, str):
            with open(rewriting_rules, 'r', encoding='utf-8') as f:
                self.rules = json.load(f)
        else:
            self.rules = rewriting_rules

        self.hash_tree = self.construct_hash_tree(self.rules)

        # When several sources share one target, the last one seen wins here.
        self.reverse_rules = {v: k for k, v in self.rules.items()}
        self.reverse_hash_tree = self.construct_hash_tree(self.reverse_rules)

    @staticmethod
    def _as_byte_values(value) -> List[int]:
        """
        Normalise one side of a rule into a list of integer byte values.

        The same conversion is applied to sources (trie paths) and targets
        (emitted output), so a target written by the forward pass is exactly
        the path the reverse trie was built from.

        Args:
            value: ``str`` (UTF-8 encoded), ``int`` (a single value),
                ``bytes``/``bytearray``/``list``/``tuple`` (taken element by
                element), or any other object (its ``str()`` form, UTF-8
                encoded).

        Returns:
            List of integer values.

        Raises:
            TypeError: If ``value`` is ``None``.
        """
        if value is None:
            raise TypeError("rule sources and targets must not be None")
        if isinstance(value, str):
            return list(value.encode('utf-8'))
        if isinstance(value, int):
            return [value]
        if isinstance(value, (bytes, bytearray, list, tuple)):
            return list(value)
        return list(str(value).encode('utf-8'))

    def construct_hash_tree(self, rules: Dict) -> Dict:
        """
        Construct a hash tree (trie) from rewriting rules.

        Args:
            rules: Dictionary mapping source sequences to target sequences.

        Returns:
            Hash tree as nested dictionaries. Child nodes are keyed by integer
            byte value; a node that ends a rule stores the rule's raw target
            under the string key ``'__value__'``.
        """
        tree: Dict = {}

        for source, target in rules.items():
            node = tree
            for byte in self._as_byte_values(source):
                node = node.setdefault(byte, {})
            # A str key cannot collide with the int keys used for children.
            node['__value__'] = target

        return tree

    def _rewrite_values(self, values: List[int], reverse: bool) -> List[int]:
        """
        Apply greedy longest-match replacement to a list of integer values.

        Shared by ``rewrite_bytes`` and ``rewrite_text`` so both methods treat
        every kind of rule target identically.
        """
        tree = self.reverse_hash_tree if reverse else self.hash_tree
        total = len(values)
        output: List[int] = []
        pos = 0

        while pos < total:
            node = tree
            match_end = pos
            match_value = None

            # Walk as deep as the input allows, remembering the most recent
            # node that terminates a rule.
            cursor = pos
            while cursor < total and values[cursor] in node:
                node = node[values[cursor]]
                cursor += 1
                if '__value__' in node:
                    match_end = cursor
                    match_value = node['__value__']

            if match_value is not None:
                output.extend(self._as_byte_values(match_value))
                pos = match_end
            else:
                # No rule starts here: copy one value through and move on.
                output.append(values[pos])
                pos += 1

        return output

    def rewrite_bytes(self, in_bytes: Union[str, bytes, List], reverse: bool = False) -> List[str]:
        """
        Rewrite byte sequence using the hash tree.

        Args:
            in_bytes: Input byte sequence: a string (UTF-8 encoded), a
                bytes-like object, or a list whose items are hex strings
                and/or integers.
            reverse: If True, use reverse mapping (for decoding).

        Returns:
            List of hex strings (``02x`` format) representing rewritten bytes.

        Raises:
            ValueError: If ``in_bytes`` is of an unsupported type, or a list
                item is a string that is not valid hexadecimal.
        """
        if isinstance(in_bytes, str):
            byte_list = list(in_bytes.encode('utf-8'))
        elif isinstance(in_bytes, (bytes, bytearray)):
            byte_list = list(in_bytes)
        elif isinstance(in_bytes, list):
            # Convert item by item so mixed hex-string / int lists also work.
            byte_list = [int(x, 16) if isinstance(x, str) else x for x in in_bytes]
        else:
            raise ValueError(f"Unsupported input type: {type(in_bytes)}")

        return [f"{b:02x}" for b in self._rewrite_values(byte_list, reverse)]

    def rewrite_text(self, text: str, reverse: bool = False) -> str:
        """
        Rewrite text by applying byte-level morpheme replacements.

        Works directly on integer byte values, avoiding the hex-string round
        trip that ``rewrite_bytes`` performs.

        Args:
            text: Input text.
            reverse: If True, use reverse mapping (for decoding).

        Returns:
            Rewritten text. Output bytes that are not valid UTF-8 are decoded
            with ``errors='replace'``. If a rule produced a value outside the
            byte range, a warning is printed and ``text`` is returned
            unchanged.
        """
        if not text:
            return text

        rewritten = self._rewrite_values(list(text.encode('utf-8')), reverse)

        try:
            return bytes(rewritten).decode('utf-8', errors='replace')
        except (ValueError, OverflowError) as e:
            # Only reachable when an integer rule target is not in 0..255.
            print(f"Warning: Could not decode result: {e}")
            return text

    def get_stats(self) -> Dict:
        """
        Get statistics about the rewriter.

        Returns:
            Dict with ``num_rules``, ``num_reverse_rules`` and the average
            source/target lengths in bytes (``0`` when there are no rules).
        """
        num_rules = len(self.rules)
        if num_rules:
            avg_source = sum(len(self._as_byte_values(k)) for k in self.rules) / num_rules
            avg_target = sum(len(self._as_byte_values(v)) for v in self.rules.values()) / num_rules
        else:
            avg_source = avg_target = 0

        return {
            'num_rules': num_rules,
            'num_reverse_rules': len(self.reverse_rules),
            'avg_source_length': avg_source,
            'avg_target_length': avg_target,
        }


class MorphemeEncoder:
    """
    High-level morpheme encoder that combines decompose and merge rewriters.

    This mirrors the MyT5Tokenizer morphological_encode method.

    Attributes:
        decompose_rewriter: ``ByteRewriter`` built from ``decompose_map``, or
            ``None`` when no (or an empty) map was given.
        merge_rewriter: ``ByteRewriter`` built from ``merge_map``, or ``None``
            when no (or an empty) map was given.
    """

    def __init__(self, decompose_map: Optional[Dict] = None, merge_map: Optional[Dict] = None):
        """
        Initialize morpheme encoder with optional decompose and merge maps.

        Args:
            decompose_map: Mapping for decomposition (normalization).
            merge_map: Mapping for merging (morphological segmentation).
        """
        self.decompose_rewriter = ByteRewriter(decompose_map) if decompose_map else None
        self.merge_rewriter = ByteRewriter(merge_map) if merge_map else None

    def encode(self, text: str) -> List[str]:
        """
        Encode text using morphological preprocessing.

        The text is turned into UTF-8 byte tokens, then passed through the
        decompose rewriter and the merge rewriter, in that order, when each
        is configured.

        Args:
            text: Input text.

        Returns:
            List of token indices (as hex strings).
        """
        indices = [f"{b:02x}" for b in text.encode('utf-8')]

        if self.decompose_rewriter:
            indices = self.decompose_rewriter.rewrite_bytes(indices, reverse=False)

        if self.merge_rewriter:
            indices = self.merge_rewriter.rewrite_bytes(indices, reverse=False)

        return indices

    def decode(self, indices: List[str]) -> str:
        """
        Decode token indices back to text.

        The rewriters are undone in the opposite order to ``encode``: merge
        first, then decompose.

        Args:
            indices: List of token indices (as hex strings).

        Returns:
            Decoded text. Bytes that are not valid UTF-8 are decoded with
            ``errors='replace'``. If a token is not a valid hex byte, a
            warning is printed and an empty string is returned.
        """
        if self.merge_rewriter:
            indices = self.merge_rewriter.rewrite_bytes(indices, reverse=True)

        if self.decompose_rewriter:
            indices = self.decompose_rewriter.rewrite_bytes(indices, reverse=True)

        try:
            byte_result = bytes([int(h, 16) for h in indices])
        except (TypeError, ValueError) as e:
            # Best effort: a malformed or out-of-range token yields "".
            print(f"Warning: Could not decode: {e}")
            return ""

        return byte_result.decode('utf-8', errors='replace')


def _run_demo() -> None:
    """Run ByteRewriter and MorphemeEncoder on sample text and print the results."""
    example_rules = {
        'ال': '\x01',
        'و': '\x02',
        'ب': '\x03',
        'the': '\x04',
    }

    print("=" * 80)
    print("ByteRewriter Test")
    print("=" * 80)

    rewriter = ByteRewriter(example_rules)

    # Forward rewrite followed by the reverse rewrite should give the input back.
    arabic_text = "والمدرسة"
    print(f"\nOriginal Arabic: {arabic_text}")

    rewritten = rewriter.rewrite_text(arabic_text, reverse=False)
    print(f"Rewritten: {repr(rewritten)}")

    reversed_text = rewriter.rewrite_text(rewritten, reverse=True)
    print(f"Reversed: {reversed_text}")

    english_text = "the school"
    print(f"\nOriginal English: {english_text}")

    rewritten = rewriter.rewrite_text(english_text, reverse=False)
    print(f"Rewritten: {repr(rewritten)}")

    reversed_text = rewriter.rewrite_text(rewritten, reverse=True)
    print(f"Reversed: {reversed_text}")

    print(f"\nRewriter stats: {rewriter.get_stats()}")

    print("\n" + "=" * 80)
    print("MorphemeEncoder Test")
    print("=" * 80)

    encoder = MorphemeEncoder(merge_map=example_rules)

    test_text = "والمدرسة the school"
    print(f"\nOriginal: {test_text}")

    encoded = encoder.encode(test_text)
    print(f"Encoded: {encoded[:20]}...")

    decoded = encoder.decode(encoded)
    print(f"Decoded: {decoded}")

    print("\nTest complete!")


if __name__ == '__main__':
    _run_demo()