Spaces:
Sleeping
Sleeping
import re
import collections
from typing import Dict, List, Tuple, Set
import json
from pathlib import Path
class TeluguBPE:
    """Byte-pair-encoding (BPE) tokenizer for Telugu text.

    Starts from a character-level vocabulary and repeatedly merges the most
    frequent adjacent symbol pair until ``vocab_size`` tokens exist (or no
    pairs remain).  Learned merges are replayed in the same order when
    encoding new text.
    """

    def __init__(self, vocab_size: int = 5000):
        # Target vocabulary size: initial characters + learned merges.
        self.vocab_size = vocab_size
        # Learned merges in learning order: (left, right) -> merged token.
        # Insertion order matters: encode() replays them in this order.
        self.merges: Dict[Tuple[str, str], str] = {}
        # Every known token (single characters and merged tokens).
        self.vocab: Set[str] = set()

    def preprocess_telugu_text(self, text: str) -> str:
        """Normalize raw text to Telugu script plus basic punctuation.

        Keeps the Telugu Unicode block (U+0C00-U+0C7F), whitespace, and the
        punctuation marks that the spacing rules below tokenize; everything
        else is stripped.  Returns the cleaned, space-normalized text.
        """
        # BUGFIX: the original strip regex removed danda (।/॥ are U+0964/65,
        # outside the Telugu block) and ASCII ,?! before the punctuation
        # rules below ran, making them dead code.  Keep those marks.
        text = re.sub(r'[^\u0C00-\u0C7F\s।॥,?!]', '', text)
        # Collapse all runs of whitespace to single spaces.
        text = re.sub(r'\s+', ' ', text)
        # Isolate digit runs (Telugu digits U+0C66-U+0C6F match \d).
        text = re.sub(r'(\d+)', r' \1 ', text)
        # Isolate punctuation so each mark becomes a standalone token.
        # (Subsumes the original's redundant "space after danda" rule.)
        text = re.sub(r'([।॥,?!])', r' \1 ', text)
        # Separate dependent vowel signs from their base consonant.
        text = re.sub(r'([\u0C3E-\u0C4C])', r' \1', text)
        # Re-collapse any double spaces introduced by the rules above.
        text = re.sub(r'\s+', ' ', text)
        return text.strip()

    def get_stats(self, words: List[List[str]]) -> Dict[Tuple[str, str], int]:
        """Count frequencies of adjacent symbol pairs across all words."""
        pairs: Dict[Tuple[str, str], int] = collections.defaultdict(int)
        for word in words:
            for i in range(len(word) - 1):
                pairs[(word[i], word[i + 1])] += 1
        return pairs

    def merge_vocab(self, words: List[List[str]], pair: Tuple[str, str]) -> List[List[str]]:
        """Merge every occurrence of ``pair`` into a single symbol."""
        first, second = pair
        merged = first + second  # hoisted: invariant across all words
        new_words = []
        for word in words:
            new_word = []
            i = 0
            while i < len(word):
                # Greedy left-to-right merge of non-overlapping pairs.
                if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
                    new_word.append(merged)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_words.append(new_word)
        return new_words

    def learn_bpe(self, text: str) -> None:
        """Learn BPE merges from whitespace-tokenized ``text``.

        Populates ``self.merges`` and ``self.vocab`` in place.
        """
        # Start at character level, one symbol list per word.
        words = [list(word) for word in text.split()]
        self.vocab = {char for word in words for char in word}
        # Guard against a character set that already meets the target size
        # (the original could compute a negative merge count).
        num_merges = max(0, self.vocab_size - len(self.vocab))
        for _ in range(num_merges):
            pairs = self.get_stats(words)
            if not pairs:
                break  # nothing left to merge
            best_pair = max(pairs.items(), key=lambda item: item[1])[0]
            self.merges[best_pair] = best_pair[0] + best_pair[1]
            self.vocab.add(self.merges[best_pair])
            words = self.merge_vocab(words, best_pair)
            if len(self.vocab) >= self.vocab_size:
                break

    def encode(self, text: str) -> List[str]:
        """Encode ``text`` into a flat token list using learned merges."""
        words = [list(word) for word in text.split()]
        # dict preserves insertion order (3.7+), so merges replay in the
        # order they were learned; the merged value itself is not needed.
        for pair in self.merges:
            words = self.merge_vocab(words, pair)
        return [token for word in words for token in word]

    def save_model(self, path: str) -> None:
        """Serialize vocab size, merges and vocabulary to JSON at ``path``."""
        model_data = {
            'vocab_size': self.vocab_size,
            # Pair keys are space-joined; tokens never contain spaces
            # because words are produced by str.split().
            'merges': {f'{k[0]} {k[1]}': v for k, v in self.merges.items()},
            'vocab': list(self.vocab)
        }
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(model_data, f, ensure_ascii=False, indent=2)

    def load_model(self, path: str) -> None:
        """Restore a model previously written by :meth:`save_model`."""
        with open(path, 'r', encoding='utf-8') as f:
            model_data = json.load(f)
        self.vocab_size = model_data['vocab_size']
        # Rebuild pair tuples; JSON objects keep key order, so the merge
        # order learned at training time is preserved.
        self.merges = {tuple(k.split()): v for k, v in model_data['merges'].items()}
        self.vocab = set(model_data['vocab'])
def main():
    """Train a Telugu BPE model from a text file, save it, and report stats."""
    input_file = "telugu_text.txt"
    model_file = "telugu_bpe_model.json"

    # Read the training corpus.
    with open(input_file, 'r', encoding='utf-8') as f:
        text = f.read()

    print("Started learning BPE")
    bpe = TeluguBPE(vocab_size=5000)

    # Clean the corpus before measuring or training.
    processed_text = bpe.preprocess_telugu_text(text)
    original_chars = len(processed_text)
    original_tokens = len(processed_text.split())

    # Learn BPE merges.
    bpe.learn_bpe(processed_text)

    # Encode the entire corpus to measure compression.
    encoded_text = bpe.encode(processed_text)
    encoded_length = len(encoded_text)
    # Guard against an empty corpus (original raised ZeroDivisionError).
    compression_ratio = original_chars / encoded_length if encoded_length else 0.0

    # Persist the trained model.
    bpe.save_model(model_file)

    # Print statistics.
    print("\nCompression Statistics:")
    print(f"Original characters: {original_chars}")
    print(f"Original tokens (words): {original_tokens}")
    print(f"Encoded tokens: {encoded_length}")
    print(f"Compression ratio: {compression_ratio:.2f}x")
    print(f"Vocabulary size: {len(bpe.vocab)}")

    # Demonstrate encoding on a short sample.
    sample_text = "నమస్కారం"  # "Hello" in Telugu
    encoded = bpe.encode(bpe.preprocess_telugu_text(sample_text))
    print("\nExample encoding:")
    print(f"Sample text: {sample_text}")
    print(f"Encoded text: {encoded}")


if __name__ == "__main__":
    main()