import gc
import json

import regex as re
import yaml
from tqdm import tqdm


def load_config(config_file_path: str = "config.yml") -> dict:
    """Load the YAML configuration file."""
    with open(config_file_path, "r") as f:
        config = yaml.safe_load(f)
    return config

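# Illustrative config.yml layout, inferred from the keys accessed in this file.
# The paths and values below are placeholders, not the actual configuration:
#
#   input_file_info:
#     file_path: "data/hindi_corpus.txt"   # hypothetical path
#     input_file_limit: 10000
#     print_text: true
#   output_file_info:
#     file_path: "tokenizer.json"          # hypothetical path
#   regex_string: ""                       # empty -> use the byte-level default tokenizer
#   vocab_size: 5000
#   test_text: "नमस्ते"                     # hypothetical sample
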
def get_input_text(config: dict) -> str:
    """Read the input corpus and keep only the first `input_file_limit` lines."""
    with open(config["input_file_info"]["file_path"], 'r', encoding='utf-8') as _f:
        hi_text = [line.strip() for line in _f.readlines()]
    hi_text_abridged = hi_text[:int(config["input_file_info"]["input_file_limit"])]
    hi_text_abridged = '\n'.join(hi_text_abridged)
    if config["input_file_info"]["print_text"]:
        print("Sample text: ", hi_text_abridged[:10])
    return hi_text_abridged

def get_stats(ids, counts=None):
    """Count how often each adjacent pair of token ids occurs."""
    counts = {} if counts is None else counts
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

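# Illustrative example (not executed): on a toy id sequence,
#   get_stats([1, 2, 2, 3, 2, 2])  ->  {(1, 2): 1, (2, 2): 2, (2, 3): 1, (3, 2): 1}
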
def merge(ids, pair, idx):
    """Replace every occurrence of `pair` in `ids` with the new token `idx`."""
    newids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids

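# Illustrative example (not executed): merging the pair (2, 2) into token 256,
#   merge([1, 2, 2, 3, 2, 2], (2, 2), 256)  ->  [1, 256, 3, 256]
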
def stoi(text: str, config: dict) -> list:
    """Tokenize text into a flat list of UTF-8 byte values."""
    if config.get("regex_string"):
        print("Using regex string: ", config["regex_string"])
        tokens = re.findall(config["regex_string"], text)
        # Convert tokens to bytes and then to integers
        return [b for token in tokens for b in token.encode('utf-8')]
    else:
        print("Using default tokenizer")
        # No splitting: spaces are preserved by encoding every character directly
        return [b for ch in text for b in ch.encode('utf-8')]

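# Illustrative example (not executed): with the default tokenizer, each character
# becomes its UTF-8 bytes, e.g. stoi("क", {"regex_string": ""}) -> [224, 164, 149]
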
def encode(text, merges, config: dict):
    """
    Encode text into tokens using the learned merges
    """
    ids = stoi(text, config)
    sorted_merges = sorted(merges.items(), key=lambda x: x[1])
    for (p1, p2), idx in sorted_merges:
        ids = merge(ids, (p1, p2), idx)
    return ids

def decode(ids, merges, config: dict):
    """
    Decode tokens back to text using the learned merges
    """
    # Create reverse mapping from token to pair
    reverse_merges = {idx: pair for pair, idx in merges.items()}

    # Expand all tokens recursively
    def expand_token(token):
        if token < 256:  # Base case: token is a byte
            return bytes([token])
        # Recursive case: expand the token into its constituent pair
        pair = reverse_merges[token]
        return expand_token(pair[0]) + expand_token(pair[1])

    # Expand all tokens and concatenate
    bytes_list = [expand_token(id) for id in ids]
    bytes_data = b''.join(bytes_list)
    # Convert bytes back to text
    try:
        return bytes_data.decode('utf-8')
    except UnicodeDecodeError:
        return "[DECODE_ERROR]"

class Tokenizer:
    def __init__(self, merges=None, config: dict = None):
        self.merges = merges or {}
        self.config = config

    def save(self, file_path):
        # Convert tuple keys to strings for JSON serialization
        serializable_merges = {f"{k[0]},{k[1]}": v for k, v in self.merges.items()}
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(serializable_merges, f)

    @classmethod
    def load(cls, file_path):
        with open(file_path, 'r', encoding='utf-8') as f:
            serialized_merges = json.load(f)
        # Convert string keys back to tuples
        merges = {tuple(map(int, k.split(','))): v
                  for k, v in serialized_merges.items()}
        return cls(merges)

    def encode(self, text):
        return encode(text, self.merges, self.config)

    def decode(self, ids):
        return decode(ids, self.merges, self.config)

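# Illustrative usage (not executed); "tokenizer.json" is a placeholder path:
#   tok = Tokenizer(merges, config)
#   tok.save("tokenizer.json")
#   tok2 = Tokenizer.load("tokenizer.json")   # note: load() does not restore config
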
def train_tokenizer(config: dict) -> dict:
    # get input text
    hi_text = get_input_text(config)
    # convert string to tokens
    tokens = stoi(hi_text, config)
    initial_len = len(tokens)
    print("Tokens length (initial): ", initial_len, " tokens unique: ", len(set(tokens)))
    print("Example tokens: ", ord('क'), chr(2325), ord("।"), chr(2404))

    print("Training tokenizer....")
    num_merges = config["vocab_size"] - 256
    merges = {}
    pbar = tqdm(range(num_merges), desc="Training tokenizer")
    output_file = config["output_file_info"]["file_path"]
    for i in pbar:
        # Get statistics of the tokens
        stats = get_stats(tokens)
        # Get the most frequent pair
        pair = max(stats, key=stats.get)
        # Get the index of the new token
        idx = 256 + i
        # Merge the pair
        tokens = merge(tokens, pair, idx)
        merges[pair] = idx
        # Show progress
        if (i + 1) % 100 == 0:
            current_ratio = initial_len / len(tokens)
            pbar.write(f"Iteration {i+1}: compression ratio: {current_ratio:.2f}X")
        if (i + 1) % 1000 == 0:
            # Garbage collection periodically
            gc.collect()
            # Save intermediate merges
            temp_tokenizer = Tokenizer(merges)
            temp_tokenizer.save(f"{output_file}.checkpoint")

    print("Training tokenizer completed")
    final_tokenizer = Tokenizer(merges)
    final_tokenizer.save(output_file)

    print("\n=== Final Statistics ===")
    print(f"Vocabulary size: {config['vocab_size']}")
    print(f"Initial tokens: {initial_len:,}")
    print(f"Final tokens: {len(tokens):,}")
    # Byte figures assume 4 bytes per stored token id
    print(f"Initial bytes: {initial_len * 4:,}")
    print(f"Final bytes: {len(tokens) * 4:,}")
    print(f"Token compression ratio: {initial_len / len(tokens):.2f}X")
    print(f"Byte compression ratio: {(initial_len * 4) / (len(tokens) * 4):.2f}X")
    print(f"Saved tokenizer to: {output_file}")
    return merges

def load_tokenizer(config: dict) -> Tokenizer:
    """Load the tokenizer from the JSON file."""
    with open(config["output_file_info"]["file_path"], 'r', encoding='utf-8') as f:
        serialized_merges = json.load(f)
    merges = {tuple(map(int, k.split(','))): v
              for k, v in serialized_merges.items()}
    return Tokenizer(merges, config)

if __name__ == "__main__":
    # TRAIN TOKENIZER
    config = load_config()
    merges = train_tokenizer(config)
    print("Merges: ", merges)

    # USE TOKENIZER
    # tokenizer = load_tokenizer(config)
    # test_text = config["test_text"]
    # print("Test text: ", test_text)
    # print("Encoded text: ", tokenizer.encode(test_text))
    # decoded = tokenizer.decode(tokenizer.encode(test_text))
    # print("Decoded text: ", decoded)
    # print(f"Successful roundtrip: {test_text == decoded}")