""" ByteRewriter class for efficient morpheme matching and replacement. Based on the MYTE implementation using a hash tree (trie) structure. """ import json from typing import Dict, List, Union class ByteRewriter: """ Efficient byte-level rewriter using a hash tree (trie) for morpheme matching. This class builds a prefix tree from morpheme mappings and uses it to efficiently find and replace morpheme sequences in byte streams. """ def __init__(self, rewriting_rules: Union[str, Dict]): """ Initialize ByteRewriter with rewriting rules. Args: rewriting_rules: Either a path to JSON file or dict mapping morphemes to codes """ if isinstance(rewriting_rules, str): with open(rewriting_rules, 'r', encoding='utf-8') as f: self.rules = json.load(f) else: self.rules = rewriting_rules # Build hash tree for efficient prefix matching self.hash_tree = self.construct_hash_tree(self.rules) # Build reverse mapping for decoding self.reverse_rules = {v: k for k, v in self.rules.items()} self.reverse_hash_tree = self.construct_hash_tree(self.reverse_rules) def construct_hash_tree(self, rules: Dict) -> Dict: """ Construct a hash tree (trie) from rewriting rules. Args: rules: Dictionary mapping source sequences to target sequences Returns: Hash tree as nested dictionaries """ tree = {} for source, target in rules.items(): # Convert source to byte sequence if it's a string if isinstance(source, str): source_bytes = source.encode('utf-8') else: source_bytes = source # Build tree path for this source current_node = tree for byte in source_bytes: if byte not in current_node: current_node[byte] = {} current_node = current_node[byte] # Mark end of sequence with special key current_node['__value__'] = target return tree def rewrite_bytes(self, in_bytes: Union[str, bytes, List], reverse: bool = False) -> List[str]: """ Rewrite byte sequence using the hash tree. Args: in_bytes: Input byte sequence (string, bytes, or list of hex strings) reverse: If True, use reverse mapping (for decoding) Returns: List of hex strings representing rewritten bytes """ # Convert input to uniform format (list of bytes) if isinstance(in_bytes, str): byte_list = list(in_bytes.encode('utf-8')) elif isinstance(in_bytes, bytes): byte_list = list(in_bytes) elif isinstance(in_bytes, list): # Assume list of hex strings like ['41', '42', '43'] if all(isinstance(x, str) for x in in_bytes): byte_list = [int(x, 16) for x in in_bytes] else: byte_list = in_bytes else: raise ValueError(f"Unsupported input type: {type(in_bytes)}") # Select tree based on direction tree = self.reverse_hash_tree if reverse else self.hash_tree # Rewrite using hash tree result = [] i = 0 while i < len(byte_list): # Try to match longest prefix in tree current_node = tree match_length = 0 match_value = None j = i while j < len(byte_list) and byte_list[j] in current_node: current_node = current_node[byte_list[j]] j += 1 # Check if this is a complete match if '__value__' in current_node: match_length = j - i match_value = current_node['__value__'] if match_value is not None: # Found a match - replace with target sequence if isinstance(match_value, str): # Convert target string to hex bytes target_bytes = match_value.encode('utf-8') result.extend([f"{b:02x}" for b in target_bytes]) elif isinstance(match_value, bytes): result.extend([f"{b:02x}" for b in match_value]) else: # Assume it's already in the right format result.append(match_value) i += match_length else: # No match - keep original byte result.append(f"{byte_list[i]:02x}") i += 1 return result def rewrite_text(self, text: str, reverse: bool = False) -> str: """ Rewrite text by applying byte-level morpheme replacements. OPTIMIZED: Works directly with bytes to avoid inefficient hex string conversions. Args: text: Input text reverse: If True, use reverse mapping (for decoding) Returns: Rewritten text """ if not text: return text # Convert text to bytes directly input_bytes = text.encode('utf-8') byte_list = list(input_bytes) # Select tree based on direction tree = self.reverse_hash_tree if reverse else self.hash_tree # Rewrite using hash tree - work directly with bytes result_bytes = [] i = 0 while i < len(byte_list): # Try to match longest prefix in tree current_node = tree match_length = 0 match_value = None j = i while j < len(byte_list) and byte_list[j] in current_node: current_node = current_node[byte_list[j]] j += 1 # Check if this is a complete match if '__value__' in current_node: match_length = j - i match_value = current_node['__value__'] if match_value is not None: # Found a match - replace with target sequence if isinstance(match_value, str): # Convert target string to bytes directly result_bytes.extend(match_value.encode('utf-8')) elif isinstance(match_value, bytes): result_bytes.extend(match_value) else: # Fallback: convert to bytes if needed try: if isinstance(match_value, int): result_bytes.append(match_value) else: result_bytes.extend(str(match_value).encode('utf-8')) except: result_bytes.append(byte_list[i]) i += match_length else: # No match - keep original byte result_bytes.append(byte_list[i]) i += 1 # Convert result bytes back to string try: return bytes(result_bytes).decode('utf-8', errors='replace') except Exception as e: print(f"Warning: Could not decode result: {e}") return text def get_stats(self) -> Dict: """Get statistics about the rewriter.""" return { 'num_rules': len(self.rules), 'num_reverse_rules': len(self.reverse_rules), 'avg_source_length': sum(len(k.encode('utf-8') if isinstance(k, str) else k) for k in self.rules.keys()) / len(self.rules) if self.rules else 0, 'avg_target_length': sum(len(v.encode('utf-8') if isinstance(v, str) else v) for v in self.rules.values()) / len(self.rules) if self.rules else 0, } class MorphemeEncoder: """ High-level morpheme encoder that combines decompose and merge rewriters. This mirrors the MyT5Tokenizer morphological_encode method. """ def __init__(self, decompose_map: Dict = None, merge_map: Dict = None): """ Initialize morpheme encoder with optional decompose and merge maps. Args: decompose_map: Mapping for decomposition (normalization) merge_map: Mapping for merging (morphological segmentation) """ self.decompose_rewriter = ByteRewriter(decompose_map) if decompose_map else None self.merge_rewriter = ByteRewriter(merge_map) if merge_map else None def encode(self, text: str) -> List[str]: """ Encode text using morphological preprocessing. Args: text: Input text Returns: List of token indices (as hex strings) """ # Convert text to hex bytes indices = [f"{b:02x}" for b in text.encode('utf-8')] # Apply decomposition (normalization) if self.decompose_rewriter: indices = self.decompose_rewriter.rewrite_bytes(indices, reverse=False) # Apply merging (morphological segmentation) if self.merge_rewriter: indices = self.merge_rewriter.rewrite_bytes(indices, reverse=False) return indices def decode(self, indices: List[str]) -> str: """ Decode token indices back to text. Args: indices: List of token indices (as hex strings) Returns: Decoded text """ # Apply reverse merging if self.merge_rewriter: indices = self.merge_rewriter.rewrite_bytes(indices, reverse=True) # Apply reverse decomposition if self.decompose_rewriter: indices = self.decompose_rewriter.rewrite_bytes(indices, reverse=True) # Convert hex bytes back to text try: byte_result = bytes([int(h, 16) for h in indices]) return byte_result.decode('utf-8', errors='replace') except Exception as e: print(f"Warning: Could not decode: {e}") return "" # ----------------------------------------------------------------------------- # Testing and demonstration if __name__ == '__main__': # Example: Simple morpheme mapping # In practice, this would come from Morfessor training example_rules = { 'ال': '\x01', # Arabic definite article 'و': '\x02', # Arabic conjunction 'ب': '\x03', # Arabic preposition 'the': '\x04', # English definite article } print("=" * 80) print("ByteRewriter Test") print("=" * 80) # Create rewriter rewriter = ByteRewriter(example_rules) # Test Arabic text arabic_text = "والمدرسة" print(f"\nOriginal Arabic: {arabic_text}") # Rewrite rewritten = rewriter.rewrite_text(arabic_text, reverse=False) print(f"Rewritten: {repr(rewritten)}") # Reverse reversed_text = rewriter.rewrite_text(rewritten, reverse=True) print(f"Reversed: {reversed_text}") # Test English text english_text = "the school" print(f"\nOriginal English: {english_text}") rewritten = rewriter.rewrite_text(english_text, reverse=False) print(f"Rewritten: {repr(rewritten)}") reversed_text = rewriter.rewrite_text(rewritten, reverse=True) print(f"Reversed: {reversed_text}") # Print stats print(f"\nRewriter stats: {rewriter.get_stats()}") print("\n" + "=" * 80) print("MorphemeEncoder Test") print("=" * 80) # Test morpheme encoder encoder = MorphemeEncoder(merge_map=example_rules) test_text = "والمدرسة the school" print(f"\nOriginal: {test_text}") encoded = encoder.encode(test_text) print(f"Encoded: {encoded[:20]}...") # Show first 20 tokens decoded = encoder.decode(encoded) print(f"Decoded: {decoded}") print("\nTest complete!")