| r""" | |
| - basic bpe-tokenizer that doesn't uses byte pairing, insted uses set of initial unique characters | |
| to train the new vocab | |
| - set of initial characters = ["\n", "A", "C", "G", "T", " "] that can be present in a file or are | |
| needed for the tokenizer | |
| - save and load functions, saves two files, '.model' and 'vocab.json' and only '.model' file is loaded | |
| 'vocab.json' is just for human interpretation | |
| """ | |
from tqdm import tqdm
import json
import os

# work relative to this file's directory so model files are read/written predictably
current_dir = os.path.dirname(os.path.realpath(__file__))
os.chdir(current_dir)
class DNAtokenizer:
    def __init__(self):
        """
        initial variables:
        - chars = set of unique characters that could be present in the file or are needed
        - merges, vocab = empty dictionaries to store future merges and the final vocab
        - vocab_size = initially equal to len(chars) i.e. 6, updated later
        - string_to_index, index_to_string = lookup tables mapping chars to indices and back
        """
        self.chars = ["\n", "A", "C", "G", "T", " "]
        self.vocab_size = len(self.chars)
        self.merges = {}
        self.vocab = {}
        self.string_to_index = {char: idx for idx, char in enumerate(self.chars)}
        self.index_to_string = {idx: char for idx, char in enumerate(self.chars)}
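        # resulting base mappings (for reference):
        #   string_to_index = {'\n': 0, 'A': 1, 'C': 2, 'G': 3, 'T': 4, ' ': 5}
        #   index_to_string = {0: '\n', 1: 'A', 2: 'C', 3: 'G', 4: 'T', 5: ' '}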
    def _encode(self, string):
        """
        encoder: takes a string, returns a list of integers
        eg. AATGC --> [1, 1, 4, 3, 2]
        """
        encoded = [self.string_to_index[char] for char in string]
        return encoded
    def _decode(self, integers):
        """
        decoder: takes a list of integers, returns a string
        eg. [1, 1, 4, 3, 2] --> AATGC
        """
        decoded = ''.join([self.index_to_string[i] for i in integers])
        return decoded
    def _get_stats(self, ids, counts=None):
        """
        takes a list of integers and returns a dictionary of counts of consecutive pairs
        eg: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
        optionally updates an existing dictionary of counts passed in as `counts`
        """
        counts = {} if counts is None else counts
        for pair in zip(ids, ids[1:]):
            counts[pair] = counts.get(pair, 0) + 1
        return counts
    def _merge(self, ids, pair, idx):
        """
        in the list of integers, replaces all consecutive occurrences of pair
        with the new integer token idx
        eg: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
        """
        new_ids = []
        i = 0
        while i < len(ids):
            # if the pair matches at position i (and i is not the last element), merge it
            if i+1 < len(ids) and ids[i] == pair[0] and ids[i+1] == pair[1]:
                new_ids.append(idx)
                i += 2
            else:
                new_ids.append(ids[i])
                i += 1
        return new_ids
    def _build_vocab(self):
        """
        builds the base vocab mapping each initial index to its character;
        train() and load_model() start from this base before applying merges
        """
        return {i: char for i, char in enumerate(self.chars)}
    def train(self, train_data, target_vocab):
        """
        - takes in the data and encodes it using the _encode() function, converting
          each unique char to its index
          eg. AATGC --> [1, 1, 4, 3, 2]
        - performs n_merges iterations, where n_merges = target_vocab - self.vocab_size
        - each iteration builds a dictionary of counts of consecutive pairs, then
          merges the most frequent pair into a new token
        - at the end uses the recorded merges to build the final vocab
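        worked example of one iteration:
          ids = [1, 1, 4, 4, 1, 1]
          stats = {(1, 1): 2, (1, 4): 1, (4, 4): 1, (4, 1): 1}
          max pair (1, 1) is merged as new token 6 -> ids = [6, 4, 4, 6]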
        Args:
            train_data (str): a big file containing lots of DNA sequence
            target_vocab (int): target size of the final vocabulary
        """
        vocab = self._build_vocab()
        tokens = self._encode(train_data)
        ids = list(tokens)
        merges = {}
        n_merges = target_vocab - self.vocab_size
        for i in tqdm(range(n_merges), desc='Training the tokenizer\t'):
            stats = self._get_stats(ids)
            pair = max(stats, key=stats.get)
            idx = self.vocab_size + i
            ids = self._merge(ids, pair, idx)
            merges[pair] = idx
        for (p0, p1), idx in merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        self.vocab = vocab
        self.merges = merges
        self.vocab_size = len(vocab)
    def continue_train(self, train_data, n_merges):
        """
        - takes in the data and performs n_merges further iterations
        - continues from the last index of the loaded merges
        - each iteration builds a dictionary of counts of consecutive pairs, then
          merges the most frequent pair into a new token (same as train())
        - at the end uses the merges to build the final vocab
        Args:
            train_data (str): a big file containing lots of DNA sequence
            n_merges (int): number of additional merges
        """
        # encode with the already-learned merges applied; otherwise the first new
        # iterations would re-learn existing pairs and overwrite their indices
        ids = self.encode(train_data)
        for i in tqdm(range(n_merges), desc='Training continue'):
            stats = self._get_stats(ids)
            pair = max(stats, key=stats.get)
            idx = self.vocab_size + i
            ids = self._merge(ids, pair, idx)
            self.merges[pair] = idx
        for (p0, p1), idx in self.merges.items():
            self.vocab[idx] = self.vocab[p0] + self.vocab[p1]
        self.vocab_size = len(self.vocab)
    def encode(self, text):
        """
        - takes in the input string and encodes it using the initial-vocab '_encode()' function
        - then repeatedly applies the trained or loaded merges from self.merges,
          always merging the pair with the lowest merge index first, so the
          order of application matches the order learned in training
        Args:
            text (str): string of DNA sequence
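        eg. with merges = {(1, 1): 6}: "AATT" --> [1, 1, 4, 4] --> [6, 4, 4]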
| """ | |
| tokens = self._encode(text) | |
| ids = list(tokens) | |
| while len(ids) >= 2: | |
| stats = self._get_stats(ids) | |
| pair = min(stats, key=lambda p: self.merges.get(p, float('inf'))) | |
| if pair not in self.merges: | |
| break | |
| idx = self.merges[pair] | |
| ids = self._merge(ids, pair, idx) | |
| return ids | |
    def decode(self, de_text):
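        """
        decoder: maps each token id through the final vocab and joins into a string
        eg. [6, 4, 4] with vocab {..., 6: 'AA'} --> 'AATT'
        """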
        tokens = [self.vocab[idx] for idx in de_text]
        text = ''.join(tokens)
        return text
    def save_model(self, model_prefix):
        """
        - basic save_model() function, saves two files: '.model' & '_vocab.json'
        - '.model' contains all the final merges, one pair per line, in the order
          they were learned
        - '_vocab.json' contains the final vocab, for human interpretation
        Args:
            model_prefix (str): file prefix, optionally including a path
        """
        model_file = model_prefix + '.model'
        with open(model_file, 'w', encoding='utf-8') as fwrite:
            for ids1, ids2 in self.merges:
                fwrite.write(f"{ids1} {ids2}\n")
        vocab_file = model_prefix + '_vocab.json'
        with open(vocab_file, 'w') as f:
            json.dump(self.vocab, f)
        print('model file saved successfully!')
    def load_model(self, model_path):
        """
        - loads the '.model' file
        - rebuilds the merges dict, assigning indices in file order (the same
          order the merges were learned and saved in)
        - builds the vocab again for further use
        Args:
            model_path (str): path to the '.model' file
        """
        assert model_path.endswith('.model')
        merges = {}
        idx = self.vocab_size
        with open(model_path, 'r', encoding='utf-8') as fread:
            for line in fread:
                idx1, idx2 = map(int, line.split())
                merges[(idx1, idx2)] = idx
                idx += 1
        vocab = self._build_vocab()
        for (p0, p1), idx in merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        self.merges = merges
        self.vocab = vocab
        self.vocab_size = len(self.vocab)
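
# Minimal usage sketch (hedged: the sample sequence, target_vocab value, and the
# 'dna' file prefix below are illustrative assumptions, not part of the tokenizer):
if __name__ == '__main__':
    sample = "AATTGGCC\n" * 50          # toy stand-in for a real DNA file
    tok = DNAtokenizer()
    tok.train(sample, target_vocab=10)  # learns merges on top of the 6 base chars
    ids = tok.encode("AATTGGCC")
    assert tok.decode(ids) == "AATTGGCC"    # encode/decode round-trips
    tok.save_model('dna')               # writes 'dna.model' and 'dna_vocab.json'
    fresh = DNAtokenizer()
    fresh.load_model('dna.model')       # reloading reproduces the same encoding
    assert fresh.encode("AATTGGCC") == ids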