# myte-parity-sweep / rewrite_bytes.py
# Uploaded by almaghrabima via huggingface_hub (commit bd87f55, verified)
"""
ByteRewriter class for efficient morpheme matching and replacement.
Based on the MYTE implementation using a hash tree (trie) structure.
"""
import json
from typing import Dict, List, Union
class ByteRewriter:
    """
    Efficient byte-level rewriter using a hash tree (trie) for morpheme matching.

    This class builds a prefix tree from morpheme mappings and uses it to
    efficiently find and replace morpheme sequences in byte streams. Matching
    is greedy longest-match: at each position the longest rule that completes
    is applied; unmatched bytes pass through unchanged.
    """

    # Sentinel trie key marking the end of a complete rule. Trie edge keys
    # are ints (byte values 0-255), so this string key can never collide.
    LEAF = '__value__'

    def __init__(self, rewriting_rules: Union[str, Dict]):
        """
        Initialize ByteRewriter with rewriting rules.

        Args:
            rewriting_rules: Either a path to a JSON file or a dict mapping
                morpheme source sequences to target codes.
        """
        if isinstance(rewriting_rules, str):
            with open(rewriting_rules, 'r', encoding='utf-8') as f:
                self.rules = json.load(f)
        else:
            self.rules = rewriting_rules
        # Forward trie for encoding (source -> target).
        self.hash_tree = self.construct_hash_tree(self.rules)
        # Reverse mapping for decoding (target -> source).
        # NOTE(review): if multiple sources map to the same target, only the
        # last one survives this inversion — confirm rules are injective.
        self.reverse_rules = {v: k for k, v in self.rules.items()}
        self.reverse_hash_tree = self.construct_hash_tree(self.reverse_rules)

    def construct_hash_tree(self, rules: Dict) -> Dict:
        """
        Construct a hash tree (trie) from rewriting rules.

        Args:
            rules: Dictionary mapping source sequences to target sequences.
                String sources are UTF-8 encoded; other sources are assumed
                to already be iterables of byte values.

        Returns:
            Hash tree as nested dictionaries keyed by byte value, with the
            target stored under the LEAF sentinel key at each rule's end.
        """
        tree: Dict = {}
        for source, target in rules.items():
            # Normalize the source to a byte sequence.
            source_bytes = source.encode('utf-8') if isinstance(source, str) else source
            # Walk/create the path for this source.
            node = tree
            for byte in source_bytes:
                node = node.setdefault(byte, {})
            # Mark end of sequence with the sentinel key.
            node[self.LEAF] = target
        return tree

    def _longest_match(self, tree: Dict, byte_list: List[int], start: int):
        """
        Find the longest rule matching byte_list at position `start`.

        Walks the trie greedily, remembering the deepest node that carried a
        complete rule so that a longer path failing midway still falls back
        to the best shorter match.

        Returns:
            (value, length) of the longest complete match, or (None, 0)
            when no rule matches at `start`.
        """
        node = tree
        best_value, best_length = None, 0
        pos = start
        end = len(byte_list)
        while pos < end and byte_list[pos] in node:
            node = node[byte_list[pos]]
            pos += 1
            if self.LEAF in node:
                best_value = node[self.LEAF]
                best_length = pos - start
        return best_value, best_length

    def rewrite_bytes(self, in_bytes: Union[str, bytes, List], reverse: bool = False) -> List[str]:
        """
        Rewrite byte sequence using the hash tree.

        Args:
            in_bytes: Input byte sequence (string, bytes, or list of hex
                strings like ['41', '42']; a list of non-strings is assumed
                to already contain byte values).
            reverse: If True, use reverse mapping (for decoding).

        Returns:
            List of two-digit lowercase hex strings representing the
            rewritten bytes.

        Raises:
            ValueError: If `in_bytes` is not one of the supported types.
        """
        # Normalize input to a list of ints.
        if isinstance(in_bytes, str):
            byte_list = list(in_bytes.encode('utf-8'))
        elif isinstance(in_bytes, bytes):
            byte_list = list(in_bytes)
        elif isinstance(in_bytes, list):
            if all(isinstance(x, str) for x in in_bytes):
                byte_list = [int(x, 16) for x in in_bytes]
            else:
                byte_list = in_bytes
        else:
            raise ValueError(f"Unsupported input type: {type(in_bytes)}")

        tree = self.reverse_hash_tree if reverse else self.hash_tree

        result: List[str] = []
        i = 0
        while i < len(byte_list):
            match_value, match_length = self._longest_match(tree, byte_list, i)
            if match_value is not None:
                # Found a match — emit the target sequence as hex strings.
                if isinstance(match_value, str):
                    result.extend(f"{b:02x}" for b in match_value.encode('utf-8'))
                elif isinstance(match_value, bytes):
                    result.extend(f"{b:02x}" for b in match_value)
                else:
                    # Assume it's already in the right format.
                    result.append(match_value)
                i += match_length
            else:
                # No match — keep the original byte.
                result.append(f"{byte_list[i]:02x}")
                i += 1
        return result

    def rewrite_text(self, text: str, reverse: bool = False) -> str:
        """
        Rewrite text by applying byte-level morpheme replacements.

        Works directly with bytes to avoid hex string round-trips.

        Args:
            text: Input text.
            reverse: If True, use reverse mapping (for decoding).

        Returns:
            Rewritten text (undecodable UTF-8 is replaced; on total failure
            the original text is returned unchanged).
        """
        if not text:
            return text

        byte_list = list(text.encode('utf-8'))
        tree = self.reverse_hash_tree if reverse else self.hash_tree

        result_bytes: List[int] = []
        i = 0
        while i < len(byte_list):
            match_value, match_length = self._longest_match(tree, byte_list, i)
            if match_value is not None:
                # Found a match — emit the target as raw byte values.
                if isinstance(match_value, str):
                    result_bytes.extend(match_value.encode('utf-8'))
                elif isinstance(match_value, bytes):
                    result_bytes.extend(match_value)
                else:
                    # Fallback for non-str/bytes targets (e.g. ints).
                    try:
                        if isinstance(match_value, int):
                            result_bytes.append(match_value)
                        else:
                            result_bytes.extend(str(match_value).encode('utf-8'))
                    except Exception:  # narrow from bare except: don't eat KeyboardInterrupt
                        result_bytes.append(byte_list[i])
                i += match_length
            else:
                # No match — keep the original byte.
                result_bytes.append(byte_list[i])
                i += 1

        try:
            return bytes(result_bytes).decode('utf-8', errors='replace')
        except Exception as e:
            print(f"Warning: Could not decode result: {e}")
            return text

    def get_stats(self) -> Dict:
        """Get statistics about the rewriter (rule counts and average byte lengths)."""
        return {
            'num_rules': len(self.rules),
            'num_reverse_rules': len(self.reverse_rules),
            'avg_source_length': sum(len(k.encode('utf-8') if isinstance(k, str) else k)
                                     for k in self.rules.keys()) / len(self.rules) if self.rules else 0,
            'avg_target_length': sum(len(v.encode('utf-8') if isinstance(v, str) else v)
                                     for v in self.rules.values()) / len(self.rules) if self.rules else 0,
        }
class MorphemeEncoder:
    """
    High-level morpheme encoder that combines decompose and merge rewriters.

    This mirrors the MyT5Tokenizer morphological_encode method.
    """

    def __init__(self, decompose_map: Union[Dict, None] = None,
                 merge_map: Union[Dict, None] = None):
        """
        Initialize morpheme encoder with optional decompose and merge maps.

        Args:
            decompose_map: Mapping for decomposition (normalization), or
                None/empty to skip that stage.
            merge_map: Mapping for merging (morphological segmentation), or
                None/empty to skip that stage.
        """
        # A missing (or empty) map simply disables that rewriting stage.
        self.decompose_rewriter = ByteRewriter(decompose_map) if decompose_map else None
        self.merge_rewriter = ByteRewriter(merge_map) if merge_map else None

    def encode(self, text: str) -> List[str]:
        """
        Encode text using morphological preprocessing.

        Args:
            text: Input text.

        Returns:
            List of token indices (as two-digit lowercase hex strings).
        """
        # Represent the UTF-8 bytes of the text as hex strings.
        indices = [f"{b:02x}" for b in text.encode('utf-8')]
        # Apply decomposition (normalization) first, then merging
        # (morphological segmentation).
        if self.decompose_rewriter:
            indices = self.decompose_rewriter.rewrite_bytes(indices, reverse=False)
        if self.merge_rewriter:
            indices = self.merge_rewriter.rewrite_bytes(indices, reverse=False)
        return indices

    def decode(self, indices: List[str]) -> str:
        """
        Decode token indices back to text.

        Args:
            indices: List of token indices (as hex strings).

        Returns:
            Decoded text, or "" if the hex sequence cannot be parsed.
        """
        # Undo the two encode stages in reverse order: merge, then decompose.
        if self.merge_rewriter:
            indices = self.merge_rewriter.rewrite_bytes(indices, reverse=True)
        if self.decompose_rewriter:
            indices = self.decompose_rewriter.rewrite_bytes(indices, reverse=True)
        # Convert hex bytes back to text.
        try:
            byte_result = bytes([int(h, 16) for h in indices])
            return byte_result.decode('utf-8', errors='replace')
        except Exception as e:
            print(f"Warning: Could not decode: {e}")
            return ""
# -----------------------------------------------------------------------------
# Testing and demonstration
if __name__ == '__main__':
    # Example morpheme table; real mappings would come from Morfessor training.
    demo_rules = {
        'ال': '\x01',  # Arabic definite article
        'و': '\x02',  # Arabic conjunction
        'ب': '\x03',  # Arabic preposition
        'the': '\x04',  # English definite article
    }
    divider = "=" * 80

    print(divider)
    print("ByteRewriter Test")
    print(divider)

    demo_rewriter = ByteRewriter(demo_rules)

    def _round_trip(label: str, sample: str) -> None:
        """Encode then decode `sample` with demo_rewriter, printing each stage."""
        print(f"\nOriginal {label}: {sample}")
        encoded = demo_rewriter.rewrite_text(sample, reverse=False)
        print(f"Rewritten: {repr(encoded)}")
        restored = demo_rewriter.rewrite_text(encoded, reverse=True)
        print(f"Reversed: {restored}")

    _round_trip("Arabic", "والمدرسة")
    _round_trip("English", "the school")

    print(f"\nRewriter stats: {demo_rewriter.get_stats()}")

    print("\n" + divider)
    print("MorphemeEncoder Test")
    print(divider)

    demo_encoder = MorphemeEncoder(merge_map=demo_rules)
    sample_text = "والمدرسة the school"
    print(f"\nOriginal: {sample_text}")
    token_ids = demo_encoder.encode(sample_text)
    print(f"Encoded: {token_ids[:20]}...")  # Show first 20 tokens
    print(f"Decoded: {demo_encoder.decode(token_ids)}")

    print("\nTest complete!")