Spaces:
Sleeping
Sleeping
File size: 4,880 Bytes
e4d5fc0 96e2c6c 667a4ff 96e2c6c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import ast
import json
import re
import time
from collections import defaultdict

import pandas as pd
from tqdm import tqdm

import encoder_parallel_telugu as encode_parallel
def load_and_encode_tokens():
    """Load the Telugu corpus and encode it into tokens, printing diagnostics.

    Loading and parallel encoding are delegated to the project module
    ``encoder_parallel_telugu``; this function only times the run and prints
    summary statistics.

    Returns:
        The encoded token sequence. Elements appear to be UTF-8 byte strings,
        since they are ``.decode('utf-8')``-ed below — confirm against
        ``encoder_parallel_telugu``.
    """
    tokens = encode_parallel.load_telugu_texts()
    start_time = time.time()
    # chunk_size / max_workers control how the corpus is split across workers;
    # their exact semantics live in encoder_parallel_telugu.
    encoded_tokens = encode_parallel.encode_tokens_parallel(tokens, chunk_size=1_000_000, max_workers=10)
    print('encoded_tokens:', encoded_tokens[:100])
    print(len(encoded_tokens))
    end_time = time.time()
    print(f"Time taken to encode and process tokens in parallel: {end_time - start_time:.4f} seconds")
    print('length of encoded_text:', len(encoded_tokens))
    # Distinct tokens, shown decoded and as a count.
    print('unique characters in decoded_text:', {token.decode('utf-8') for token in set(encoded_tokens)})
    # print('unique characters in encoded_text:', set(encoded_tokens))
    print('unique characters in encoded_text:', len(set(encoded_tokens)))
    return encoded_tokens
def get_stats(ids):
    """Count how often each adjacent token pair occurs in *ids*.

    Args:
        ids: sequence of token ids.

    Returns:
        Mapping of (left_token, right_token) -> occurrence count.
    """
    pair_counts = defaultdict(int)
    for left, right in zip(ids, ids[1:]):
        pair_counts[(left, right)] += 1
    return pair_counts
def merge(ids, pair, idx):
    """Return a copy of *ids* with every occurrence of *pair* replaced by *idx*.

    Args:
        ids: sequence of token ids.
        pair: (left, right) pair of adjacent tokens to collapse.
        idx: replacement token id for each matched pair.

    Returns:
        New list with each non-overlapping match of *pair* merged into *idx*.
    """
    merged = []
    pos = 0
    last = len(ids) - 1
    while pos <= last:
        # Only look ahead when a full pair can still fit.
        if pos < last and (ids[pos], ids[pos + 1]) == pair:
            merged.append(idx)
            pos += 2  # consume both members of the pair
        else:
            merged.append(ids[pos])
            pos += 1
    return merged
def bpe_process(encoded_tokens, vocab_size=500, encoded_tokens_length=10_00_000):
    """Learn byte-pair-encoding merges over *encoded_tokens*.

    Repeatedly replaces the most frequent adjacent token pair with a new
    token id (starting at 256) until ``vocab_size - 256`` merges have been
    performed or no adjacent pair remains.

    Args:
        encoded_tokens: sequence of integer token ids.
        vocab_size: target vocabulary size; the number of merges performed
            is ``vocab_size - 256``.
        encoded_tokens_length: cap on how many tokens to train on.

    Returns:
        dict mapping (token_a, token_b) -> new merged token id.
    """
    num_merges = vocab_size - 256  # our unique tokens are 194, for our sample text.
    encoded_tokens = encoded_tokens[:encoded_tokens_length]
    ids = list(encoded_tokens)  # copy so we don't destroy the original list
    merges = {}  # (int, int) -> int
    for i in tqdm(range(num_merges), desc="Merging tokens"):
        stats = get_stats(ids)
        if not stats:
            # Fewer than two tokens left — nothing to merge. Without this
            # guard, max() raises ValueError on the empty dict.
            break
        pair = max(stats, key=stats.get)
        idx = 256 + i
        ids = merge(ids, pair, idx)
        merges[pair] = idx  # record pair -> new token id
    print("tokens length:", len(encoded_tokens))
    print("ids length:", len(ids))
    print("by paired tokens length:", len(set(ids)))
    # Guard against empty input so the ratio print cannot divide by zero.
    compression = len(encoded_tokens) / len(ids) if ids else 1.0
    print(f"compression ratio: {compression:.2f}X")
    # print(f"token size: {len(set(encoded_tokens))}")
    return merges
def build_vocabulary(merges):
    """Build a vocabulary from BPE merges plus the base Telugu characters
    and persist both to ``merges_vocab.json``.

    Args:
        merges: dict mapping (token_a, token_b) -> merged token id, as
            produced by ``bpe_process``.

    Side effects:
        Writes ``merges_vocab.json`` in the working directory with keys
        stringified via ``str()`` (JSON keys must be strings).
    """
    # Merged pairs keep the ids (>= 256) assigned during BPE.
    vocab = {pair: idx for pair, idx in merges.items()}
    # Base vocabulary: each Telugu character (UTF-8 encoded) mapped to a small
    # index. The Telugu Unicode block 0x0C00-0x0C7E has only 127 code points,
    # so every index stays well below 256.
    for idx, char in enumerate(chr(cp).encode('utf-8') for cp in range(0x0C00, 0x0C7F)):
        vocab[char] = idx
    # Explicit codes for space and period, placed just under the 256 boundary.
    vocab[b' '] = 255
    vocab[b'.'] = 254
    with open('merges_vocab.json', 'w') as f:
        json.dump({'merges': {str(k): v for k, v in merges.items()},
                   'vocab': {str(k): v for k, v in vocab.items()}}, f)
def read_vocab_from_file():
    """Load the vocabulary written by ``build_vocabulary`` and restore keys.

    Keys in the JSON were serialized with ``str()``: either a bytes literal
    like ``"b' '"`` or a tuple like ``"(97, 98)"``. Single-token keys are
    wrapped into a 1-tuple so every key is uniformly a tuple.

    Returns:
        dict mapping tuple keys (of bytes or ints) -> token id.
    """
    with open('merges_vocab.json', 'r') as f:
        data = json.load(f)
    formatted_vocab = {}
    for key, value in data['vocab'].items():
        # literal_eval safely parses the bytes/tuple literals back into
        # Python objects; unlike eval it cannot execute arbitrary code
        # smuggled into the file.
        if ',' not in key:
            formatted_vocab[(ast.literal_eval(key),)] = value
        else:
            formatted_vocab[ast.literal_eval(key)] = value
    return formatted_vocab
def expand_vocab(inverted_vocab):
    """Recursively expand every vocabulary entry down to a flat byte tuple.

    Args:
        inverted_vocab: dict mapping token id -> tuple whose elements are
            either raw ``bytes`` or further token ids present in this dict.

    Returns:
        dict mapping each token id to a fully-flattened tuple of bytes
        (unresolvable ids are reported and become ``None`` entries).
    """
    def _resolve(token):
        # bytes are terminal — already fully expanded.
        if isinstance(token, bytes):
            return token
        if token in inverted_vocab:
            return _flatten(inverted_vocab[token])
        print(f'value not found in inverted_vocab: {token}')
        return None

    def _flatten(children):
        flat = []
        for child in children:
            resolved = _resolve(child)
            if isinstance(resolved, tuple):
                flat.extend(resolved)
            else:
                flat.append(resolved)
        return tuple(flat)

    decoder_map = {token_id: _flatten(children) for token_id, children in inverted_vocab.items()}
    print("sample decoder map:", {k: decoder_map[k] for k in list(decoder_map)[:5]})
    return decoder_map
# Main execution: full pipeline from raw corpus to an id -> bytes decoder map.
if __name__ == "__main__":
    # 1. Load and encode tokens (delegates to encoder_parallel_telugu).
    encoded_tokens = load_and_encode_tokens()
    # 2. Learn BPE merges on the first 2,000,000 tokens with a 5000-entry
    #    target vocabulary (overrides bpe_process defaults).
    merges = bpe_process(encoded_tokens,vocab_size=5000, encoded_tokens_length=20_00_000)
    # 3. Build the vocabulary and persist it to merges_vocab.json.
    build_vocabulary(merges)
    # 4. Read the vocabulary back; every key becomes a tuple.
    formatted_vocab = read_vocab_from_file()
    # 5. Invert vocabulary: token id -> tuple key.
    inverted_vocab = {v: k for k, v in formatted_vocab.items()}
    # 6. Expand vocabulary so each id maps to a flat tuple of bytes.
    decoder_map = expand_vocab(inverted_vocab)
    # 7. Invert back again: flat byte-tuple -> token id.
    #    NOTE(review): the comprehension unpacks as `for v, k`, so despite the
    #    `{k: v}` spelling this builds {value: key} of decoder_map — it works,
    #    but the swapped names are easy to misread; confirm intent.
    re_inverted_vocab = {k: v for v, k in decoder_map.items()}
    # print(re_inverted_vocab)
|