| from abc import ABC |
| import json, os, re, torch |
| from abc import abstractmethod |
| from transformers.utils import logging |
| from collections import Counter, defaultdict |
| from typing import TYPE_CHECKING, List, Optional, Tuple |
| from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer |
|
|
| logger = logging.get_logger(__name__) |
|
|
# Canonical on-disk file names for the tokenizer's four resource files.
VOCAB_FILES_NAMES = {
    "subtoken_reference_file": "subtoken_reference.json",
    "vocab_file": "vocab.json",
    "merges_file": "merges.json",
    "chars_file": "chars.txt",
}
|
|
# Map of resource-file key -> {model identifier -> download URL}.
# Fix: the "subtoken_reference_file" and "chars_file" entries were *set*
# literals (bare URLs with no model keys), inconsistent with the
# "vocab_file"/"merges_file" entries and unusable for lookup by model id.
# All four entries now share the same dict shape. URLs are unchanged.
# NOTE(review): the chars URLs end in "chars.json" while VOCAB_FILES_NAMES
# uses "chars.txt" — confirm which name the hosted repos actually use.
PRETRAINED_VOCAB_FILES_MAP = {
    "subtoken_reference_file": {
        "saffu-BBLM10M": "https://huggingface.co/saffu-BBLM10M/resolve/main/subtoken_reference.json",
        "saffu-BBLM100M": "https://huggingface.co/saffu-BBLM100M/resolve/main/subtoken_reference.json",
    },
    "chars_file": {
        "saffu-BBLM10M": "https://huggingface.co/saffu-BBLM10M/resolve/main/chars.json",
        "saffu-BBLM100M": "https://huggingface.co/saffu-BBLM100M/resolve/main/chars.json",
    },
    "vocab_file": {
        "saffu-BBLM10M": "https://huggingface.co/saffu-BBLM10M/resolve/main/vocab.json",
        "saffu-BBLM100M": "https://huggingface.co/saffu-BBLM100M/resolve/main/vocab.json",
    },
    "merges_file": {
        "saffu-BBLM10M": "https://huggingface.co/saffu-BBLM10M/resolve/main/merges.json",
        "saffu-BBLM100M": "https://huggingface.co/saffu-BBLM100M/resolve/main/merges.json",
    },
}
|
|
class SAFFUTokenizer(PreTrainedTokenizer):
    """
    Construct a SAFFU tokenizer. Based on rule-based pre-tokenization followed
    by Byte-Pair sub-word chunking.

    Args:
        vocab_file: JSON file mapping token -> integer id.
        subtoken_reference_file: JSON cache mapping known pre-tokens to their
            pre-computed sub-token sequences.
        merges_file: JSON serialization of a trained BPE action trace
            (consumed by `load_td` / replayed by `list_tokenize`).
        chars_file: text file whose (non-space) characters are treated as
            in-word characters during pre-tokenization.
        r: width of the left-context window built by `preprocess`.
        block_size: number of token positions per `preprocess` block.
        heads: number of head-context positions per token in `preprocess`.
        space: if True, keep whitespace tokens free-standing instead of
            sticking them onto neighboring word tokens.
        pad, oov, sod, eod, frg: surface forms of the special tokens
            (padding, out-of-vocabulary, start/end of document, fragment).
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    model_input_names = ["input_ids"]

    def __init__(
        self,
        vocab_file,
        subtoken_reference_file,
        merges_file,
        chars_file,
        r=2,
        block_size=100,
        heads=2,
        space=False,
        pad="<pad>",
        oov="<oov>",
        sod="<sod>",
        eod="<eod>",
        frg="<frg>",
        **kwargs,
    ):
        self._r = r
        self._space = space
        self._heads = heads
        self._block_size = block_size
        with open(merges_file, encoding="utf-8") as f:
            self._raw_td = json.load(f)
        self._td = load_td(path=merges_file)
        with open(chars_file, encoding="utf-8") as f:
            self._wordchars = re.sub(" ", "", f.read().strip())
        with open(vocab_file, encoding="utf-8") as f:
            self._vocabulary = json.load(f)
        # inverse mapping: id -> token
        self._index = {self._vocabulary[t]: t for t in self._vocabulary}
        with open(subtoken_reference_file, encoding="utf-8") as f:
            self._subtoken_reference = json.load(f)
        self._pad = pad
        self._oov = oov
        self._sod = sod
        self._eod = eod
        self._frg = frg
        # fixed pad-id runs reused by preprocess()
        self._padding = [self._vocabulary[self._pad]] * self._r
        self._masking = [self._vocabulary[self._pad]] * self._block_size
        self._heads_padding = [self._vocabulary[self._pad]] * self._heads
        # Fix: initialize the base class *after* our state is in place —
        # PreTrainedTokenizer.__init__ may call vocab_size/_tokenize, which
        # depend on the attributes assigned above.
        super().__init__(**kwargs)

    def save_vocabulary(self, save_directory: str,
                        filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write the four resource files into `save_directory`.

        Returns:
            (vocab, merges, subtoken_reference, chars) file paths, or None
            (with an error logged) when `save_directory` is not a directory.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        prefix = filename_prefix + "-" if filename_prefix else ""
        paths = {key: os.path.join(save_directory, prefix + VOCAB_FILES_NAMES[key])
                 for key in ("vocab_file", "merges_file",
                             "subtoken_reference_file", "chars_file")}
        with open(paths["vocab_file"], "w", encoding="utf-8") as f:
            f.write(json.dumps(self._vocabulary, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
        with open(paths["merges_file"], "w", encoding="utf-8") as f:
            f.write(json.dumps(self._raw_td, indent=2, sort_keys=True, ensure_ascii=False))
        with open(paths["subtoken_reference_file"], "w", encoding="utf-8") as f:
            f.write(json.dumps(self._subtoken_reference, indent=2, sort_keys=True, ensure_ascii=False))
        with open(paths["chars_file"], "w", encoding="utf-8") as f:
            f.write(self._wordchars)
        return (paths["vocab_file"], paths["merges_file"],
                paths["subtoken_reference_file"], paths["chars_file"])

    @property
    def vocab_size(self):
        """Number of entries in the token -> id vocabulary."""
        return len(self._vocabulary)

    def get_vocab(self):
        """Return a copy of the token -> id mapping (required by the base API)."""
        return dict(self._vocabulary)

    @staticmethod
    def word_tokenize(text, wordchars="a-zA-Z0-9-'"):
        """Split `text` into alternating word and non-word chunks, dropping empties."""
        return [token for token in re.split("([" + wordchars + "'-]+)", text) if token]

    @staticmethod
    def stick_spaces(stream):
        """Attach single-space separators to neighboring tokens so no interior
        token of the result is a bare space."""
        tokens = []
        for wi, w in enumerate(stream):
            if not tokens:
                tokens.append(w)
            elif w == ' ':
                # keep a lone space only strictly between tokens; fold
                # trailing/repeated spaces into the previous token
                if (tokens[-1][-1] != ' ') and (wi != len(stream) - 1):
                    tokens.append(w)
                else:
                    tokens[-1] = tokens[-1] + w
            else:
                if tokens[-1][-1] == ' ':
                    if tokens[-1] == ' ':
                        tokens[-1] = tokens[-1] + w
                    else:
                        # move the separating space onto the new token
                        tokens[-1] = tokens[-1][:-1]
                        tokens.append(' ' + w)
                else:
                    tokens.append(w)
        return tokens

    @staticmethod
    def sentence_tokenize(text, wordchars="a-zA-Z0-9-'", puncts=".?!;:\n|"):
        """Split `text` into sentences at whitespace following punctuation,
        re-attaching punctuation-only fragments to the preceding sentence."""
        sentences = []
        # fix: raw-string regex pieces (avoids the invalid-escape \s warning)
        for sentence in re.split(r"(\s+(?<=[" + puncts + r"][^" + wordchars + r"'-])\s*)", text):
            if not sentence:
                continue
            if not re.search("[" + wordchars + "'-]", sentence):
                # no word characters: glue onto the previous sentence
                if len(sentences):
                    if sentence[-1] == " ":
                        if len(sentence) > 1:
                            sentences[-1] = sentences[-1] + sentence[:-1]
                            sentences.append(sentence[-1])
                        else:
                            sentences.append(sentence)
                    else:
                        sentences[-1] = sentences[-1] + sentence
                else:
                    sentences.append(sentence)
            else:
                if len(sentences):
                    if len(sentences[-1]) == 1 and sentences[-1] == " ":
                        sentences[-1] = sentences[-1] + sentence
                    else:
                        sentences.append(sentence)
                else:
                    sentences.append(sentence)
        return sentences

    def bpe_tokenize(self, text):
        """Sub-tokenize one pre-token via the cached reference, falling back to
        replaying the trained BPE action trace."""
        # Fix: dict.get(text, list_tokenize(...)) evaluated the expensive
        # trace replay even on cache hits; only replay on a miss.
        if text in self._subtoken_reference:
            stream = self._subtoken_reference[text]
        else:
            stream = list_tokenize(text, td=self._td)
        return list(stream if self._space else self.stick_spaces(stream))

    def _tokenize(self, text):
        """Tokenize a string: sentences -> pre-tokens -> BPE sub-tokens."""
        # fix: removed an unreachable `return document` that followed this return
        return [sub for s in self.sentence_tokenize(text, wordchars=self._wordchars)
                for t in (self.word_tokenize(s, wordchars=self._wordchars) if self._space else
                          self.stick_spaces(self.word_tokenize(s, wordchars=self._wordchars)))
                for sub in self.bpe_tokenize(t)]

    def preprocess(self, input_ids=None):
        """Pack `input_ids` into fixed-size training blocks.

        Each block is [tokens, left_contexts, masked_prefixes, head_contexts]:
        for every position m, its token, its r-token left context (pad-filled
        at the start), the pad-masked block prefix, and its `heads` most
        recent tokens in reverse order. Block 0 is prefixed with <sod>,
        continuation blocks with <frg>; the document ends with <eod><pad>.
        """
        # fix: mutable default argument replaced with None sentinel
        if input_ids is None:
            input_ids = []
        document = input_ids + [self._vocabulary[self._eod], self._vocabulary[self._pad]]
        blocks = []
        docsize = len(document)
        for bi in range(int(docsize / self._block_size) + 1):
            start = bi * self._block_size
            if start > docsize:
                continue
            end = min([(bi + 1) * self._block_size, docsize])
            data = [self._vocabulary[self._frg if bi else self._sod]] + document[start:end]
            # build per-position tuples, transpose into columns, then group
            # the `heads` columns into a single trailing list
            block = (lambda x: x[:3] + [x[3:]])(
                list(map(list,
                         zip(*[(t,
                                self._padding[:self._r - m] + data[:m] if m - self._r < 0 else data[m - self._r:m],
                                data[:m] + self._masking[:self._block_size - m] if m < self._block_size else data[:self._block_size],
                                *(self._heads_padding[:self._heads - m] + data[:m] if m - self._heads < 0 else
                                  data[m - self._heads:m])[::-1])
                               for m, t in enumerate(data)]))))
            blocks.append(block)
        return blocks

    def _convert_token_to_id(self, t):
        """Converts a token (str) in an id using the vocab."""
        return self._vocabulary.get(t, self._vocabulary.get(self._oov))

    def _convert_id_to_token(self, i):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self._index.get(i, self._oov)
|
|
def load_td(data = None, path = ''):
    """Reconstruct a tokenizer-data dictionary from `data` (or the JSON file at `path`).

    Returns a dict holding the vocabulary maps, the parsed action trace, and
    fast token/pair -> action-index lookup tables; optional training state
    ('unigraph', 'param_method', ...) is carried over when present.
    """
    if data is None and path:
        with open(path) as handle:
            data = json.load(handle)
    td = {'tok2ind': data['tok2ind']}
    td['ind2tok'] = {index: token for token, index in td['tok2ind'].items()}
    # Each raw action is [pair, is_merge, count, score].
    td['action_trace'] = [
        {'pair': tuple(raw[0]),
         'type': 'merge' if raw[1] else 'split',
         'count': raw[2],
         'score': raw[3]}
        for raw in data['action_trace']
    ]
    # Lookup tables: which actions touch a token / which action merges a pair.
    td['tok2acts'] = defaultdict(list)
    td['pair2merge'] = dict()
    td['tok2splits'] = defaultdict(list)
    for position, action in enumerate(td['action_trace']):
        if action['type'] == 'split':
            joined = "".join(action['pair'])
            td['tok2acts'][joined].append(position)
            td['tok2splits'][joined].append(position)
        else:
            td['pair2merge'][tuple(action['pair'])] = position
            left, right = action['pair']
            td['tok2acts'][left].append(position)
            td['tok2acts'][right].append(position)
    td['maxtoklen'] = max(len(token) for token in td['tok2ind'])
    # Optional corpus statistics saved by BPE.save().
    if 'unigraph' in data:
        td['unigraph'] = Counter(data['unigraph'])
        td['digraph'] = Counter({(left, right): count for left, right, count in data['digraph']})
        td['doc_unigraph'] = defaultdict(Counter)
        for doc_key, counts in data['doc_unigraph'].items():
            td['doc_unigraph'][doc_key] = Counter(counts)
        td['init_method'] = data['init_method']
    # Optional regression/early-stop metadata.
    if 'param_method' in data:
        td['param_method'] = data['param_method']
        td['reg_model'] = data['reg_model']
        td['early_stop'] = data['early_stop']
    return td
|
|
def list_tokenize(text, td=None):
    """Tokenize `text` by replaying the trained action trace in `td`.

    Initializes a fresh BPE state over the single document `text`, then
    replays the recorded merge/split actions in trace order, at each step
    taking the earliest not-yet-applied action whose token/pair is present in
    the current segmentation. Returns the final tokens as a tuple, in
    left-to-right order.

    Raises:
        AssertionError: if `td` is missing or holds no trained action trace.
    """
    # Fix: the original used a mutable default (td={}) and then did
    # `assert td['action_trace']`, which raised KeyError instead of the
    # intended AssertionError message on an untrained/missing td.
    if not td or not td.get('action_trace'):
        raise AssertionError("Can't tokenize, no trained model!")
    mock = BPE()
    mock.init([text], method=td['init_method'], apply=True)
    prev_aix = -1
    available_action_indices = []
    observed = set()
    while True:
        # Newly applicable actions: splits for tokens now present, merges for
        # pairs now adjacent — skipping anything already observed.
        fresh = ([aix for tok in mock._unigraph
                  for aix in td['tok2splits'][tok] if tok not in observed] +
                 [td['pair2merge'][pair] for pair in mock._digraph
                  if pair not in observed and pair in td['pair2merge']])
        # Only actions strictly later in the trace than the last applied one.
        available_action_indices = sorted(
            [aix for aix in available_action_indices if aix > prev_aix] +
            [aix for aix in fresh if aix > prev_aix])
        observed = observed | set(mock._unigraph) | set(mock._digraph)
        if not available_action_indices:
            break
        aix = available_action_indices[0]
        prev_aix = aix
        action = td['action_trace'][aix]
        if action['type'] == 'merge':
            mock.merge(action['pair'])
        else:
            mock.split(action['pair'])
    # Flatten the final segmentation back into document order.
    placed = sorted(((tok, ix) for tok, idxs in mock._tok_idx.items() for ix in idxs),
                    key=lambda ti: ti[1])
    return tuple(tok for tok, _ in placed)
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
class Action:
    """A single recorded vocabulary edit in a BPE action trace.

    Attributes:
        pair: the (left, right) token pair the action applies to.
        type: 'merge' (join the pair) or 'split' (break the joined token).
        count: support count backing the action, coerced to int (-1 = unknown).
    """

    def __init__(self, pair, type='merge', count=-1):
        # NOTE: the `type` parameter shadows the builtin; kept for API compatibility.
        self.pair = pair
        self.type = type
        self.count = int(count)
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
class Tokenizer(ABC):
    """Abstract base for action-trace tokenizers.

    Maintains a contiguous token <-> index vocabulary plus a recorded
    sequence of merge/split actions (`_action_trace`) that concrete
    subclasses (e.g. BPE) learn in `fit` and replay at tokenization time.
    """

    def __init__(self, tok2ind=None):
        # token -> contiguous integer index
        if tok2ind is None:
            self._tok2ind = {}
        else:
            self._tok2ind = tok2ind
        # inverse mapping, kept in sync with _tok2ind
        self._ind2tok = {v: k for k, v in self._tok2ind.items()}
        # ordered record of merge/split actions applied during training
        self._action_trace = []

    def __len__(self):
        """Vocabulary size."""
        return len(self._tok2ind)

    def add_type(self, tok):
        """Register `tok` at the next free index (no-op if already known)."""
        if tok not in self._tok2ind:
            self._tok2ind[tok] = len(self._tok2ind)
            self._ind2tok[self._tok2ind[tok]] = tok

    def del_type(self, tok):
        """Remove `tok` and shift all higher indices down by one so the
        index space stays contiguous."""
        if tok in self._tok2ind:
            idx = self._tok2ind[tok]
            del self._ind2tok[idx]
            del self._tok2ind[tok]
            i = idx + 1
            while i in self._ind2tok:
                t = self._ind2tok[i]
                self._tok2ind[t] = i - 1
                self._ind2tok[i - 1] = t
                del self._ind2tok[i]
                # Fix: the original loop never advanced `i`, so only the
                # single next index was shifted and the index space was left
                # with a hole for any non-terminal deletion.
                i += 1

    def save(self, path, data=None):
        """Serialize the vocabulary and action trace as JSON at `path`.

        Subclasses pass a partially filled `data` dict to persist their own
        state alongside. Each action is stored as
        [pair, is_merge(0/1), count, score-or-None].
        """
        if data is None:
            data = {}
        data['tok2ind'] = self._tok2ind
        data['action_trace'] = [[a.pair, 1 if a.type == 'merge' else 0, a.count,
                                 a.score if hasattr(a, "score") else None]
                                for a in self._action_trace]
        # fix: close the file handle (was json.dump(data, open(path, 'w+')))
        with open(path, 'w+') as f:
            json.dump(data, f)

    def load(self, path):
        """Load state written by `save` and rebuild the action lookup tables.

        Returns the raw data dict so subclasses can consume their own fields.
        """
        with open(path) as f:
            data = json.load(f)
        self._tok2ind = data['tok2ind']
        self._ind2tok = {v: k for k, v in self._tok2ind.items()}
        # NOTE(review): ScoredAction is not defined in this module; it is
        # presumably provided elsewhere in the original package — verify.
        self._action_trace = [ScoredAction(tuple(a[0]), count=a[2], score=a[3],
                                           type='merge' if a[1] else 'split')
                              for a in data['action_trace']]
        # token/pair -> action-index lookups used when replaying the trace
        self._tok2acts = defaultdict(list)
        self._pair2merge = dict()
        self._tok2splits = defaultdict(list)
        for aix, a in enumerate(self._action_trace):
            if a.type == 'split':
                self._tok2acts["".join(a.pair)].append(aix)
                self._tok2splits["".join(a.pair)].append(aix)
            else:
                self._pair2merge[tuple(a.pair)] = aix
                self._tok2acts[a.pair[0]].append(aix)
                self._tok2acts[a.pair[1]].append(aix)
        self._maxtoklen = max(len(t) for t in self._tok2ind)
        return data

    @abstractmethod
    def init(self, docs, seed=None):
        """Index `docs` and build the initial segmentation state."""
        raise NotImplementedError

    @abstractmethod
    def fit(self, num_batches, batch_size=1, seed=None):
        """Learn an action trace (and vocabulary) from the indexed corpus."""
        raise NotImplementedError

    def encode(self, text):
        """Tokenize `text` and map the tokens to vocabulary indices."""
        return self.tokens_to_indices(self.tokenize(text))

    def tokenize(self, text, start=-1):
        """Tokenize `text` by replaying the trained action trace.

        `start` is accepted for interface compatibility but unused here.
        """
        assert self._action_trace, "Can't tokenize, no trained model!"
        return self.apply_action_trace(text)

    def update_action_indices(self, available_action_indices, model, prev_aix=-1, observed=set()):
        """Advance the queue of applicable action indices for `model`.

        Keeps still-pending indices later than `prev_aix`, adds splits for
        tokens and merges for pairs newly present in `model`, and folds the
        model's current tokens/pairs into `observed`. (The `observed=set()`
        default is never mutated in place, so the shared default is safe.)
        """
        available_action_indices = sorted(
            list(filter(lambda next_aix: next_aix > prev_aix, available_action_indices)) +
            [next_aix for next_aix in
             [aix for tok in model._unigraph
              for aix in self._tok2splits[tok] if tok not in observed] +
             [self._pair2merge[pair] for pair in model._digraph
              if pair not in observed and pair in self._pair2merge]
             if next_aix > prev_aix])
        observed = observed.union(set(model._unigraph.keys()).union(set(model._digraph.keys())))
        return available_action_indices, observed

    def apply_action_trace(self, text):
        """Replay the recorded actions over a fresh BPE state for `text`,
        returning the resulting tokens as a tuple in document order.

        NOTE(review): relies on `self._init_method`, which is only assigned by
        the BPE subclass's init()/load() — confirm this is never called on a
        bare Tokenizer.
        """
        mock = BPE()
        mock.init([text], method=self._init_method, apply=True)
        prev_aix = -1
        available_action_indices = []
        observed = set()
        while True:
            available_action_indices, observed = self.update_action_indices(
                available_action_indices, mock, prev_aix=prev_aix, observed=observed)
            if not available_action_indices:
                break
            aix = available_action_indices[0]
            prev_aix = aix
            action = self._action_trace[aix]
            if action.type == 'merge':
                mock.merge(action.pair)
            else:
                mock.split(action.pair)
        tks = []
        for t, idxs in mock._tok_idx.items():
            for ix in idxs:
                tks.append((t, ix))
        tks.sort(key=lambda ti: ti[1])
        tks, _ = zip(*tks)
        return tks

    def return_tokenization(self):
        """Flatten the current segmentation (`self._tok_idx`) back into
        document order and return the tokens as a tuple.

        NOTE(review): `_tok_idx` is defined by the BPE subclass, not here.
        """
        tks = []
        for t, idxs in self._tok_idx.items():
            for ix in idxs:
                tks.append((t, ix))
        tks.sort(key=lambda ti: ti[1])
        tks, _ = zip(*tks)
        return tks

    def decode(self, indices):
        """Map `indices` back to tokens and concatenate them into a string."""
        return ''.join(self.indices_to_tokens(indices))

    def tokens_to_indices(self, toks):
        """Map tokens to vocabulary indices (KeyError on unknown tokens)."""
        return [self._tok2ind[t] for t in toks]

    def indices_to_tokens(self, indices):
        """Map vocabulary indices to tokens (KeyError on unknown indices)."""
        return [self._ind2tok[i] for i in indices]
|
|
| |
| |
| |
| |
| |
| |
| |
| |
class BPE(Tokenizer):
    """Byte-Pair-Encoding tokenizer state over a character-indexed corpus.

    All training documents are laid out on a single global character axis
    (one gap offset separates consecutive documents). The segmentation is
    tracked incrementally through:

      _unigraph      token -> corpus count
      _digraph       (left, right) adjacent-pair -> corpus count
      _tok_idx       token -> set of character offsets where it starts
      _pair_idx      pair -> set of offsets where its left element starts
      _lefts[i]      the pair whose LEFT element starts at offset i
      _rights[i]     the pair whose RIGHT element starts at offset i
      _char2docidx   character offset -> document index
      _doc_unigraph  document index -> per-document token counts

    merge()/split() update all of these structures in place so that fit()
    never has to re-scan the corpus.
    """

    def __init__(self, tok2ind=None, covering_vocab = set()):
        # NOTE(review): mutable default `covering_vocab=set()` is shared
        # across instances; it appears never to be mutated here — confirm.
        self._covering_vocab = covering_vocab
        # cache for is_covered: token -> bool
        self._covered = {}
        # NOTE(review): used both as the is_covering cache (token -> bool)
        # and, after init(), as the char-offset -> segment-id map — the two
        # uses clash; see is_covering below.
        self._covering = {}
        if self._covering_vocab:
            # Fold the covering vocabulary into the initial type inventory.
            if tok2ind:
                tok2ind = {t: i for i, t in enumerate(set(list(tok2ind.keys())+list(self._covering_vocab)))}
            else:
                tok2ind = {t: i for i, t in enumerate(self._covering_vocab)}
        super().__init__(tok2ind=tok2ind)
        self._lefts = {}
        self._rights = {}
        self._unigraph = Counter()
        self._doc_unigraph = defaultdict(Counter)
        self._digraph = Counter()
        self._tok_idx = defaultdict(set)
        self._pair_idx = defaultdict(set)
        self._char2docidx = {}

    def save(self, path, data=None):
        """Persist corpus statistics alongside the base vocabulary/trace."""
        if data is None:
            data = {}
        data['unigraph'] = dict(self._unigraph)
        # pairs are stored flattened as [left, right, count] for JSON
        data['digraph'] = [[k[0], k[1], v] for k, v in self._digraph.items()]
        data['doc_unigraph'] = {k: dict(v) for k, v in self._doc_unigraph.items()}
        data['init_method'] = self._init_method
        super(BPE, self).save(path, data=data)

    def load(self, path):
        """Restore corpus statistics written by save()."""
        data = super(BPE, self).load(path)
        self._unigraph = Counter(data['unigraph'])
        self._digraph = Counter({(l, r): v for l, r, v in data['digraph']})
        self._doc_unigraph = defaultdict(Counter)
        for k, v in data['doc_unigraph'].items():
            self._doc_unigraph[k] = Counter(v)
        self._init_method = data['init_method']
        return data

    def init(self, docs, seed=None, method='char', apply=False, covering = [], action_protect = ''):
        """Index `docs` and build the initial segmentation graph.

        Args:
            docs: iterable of document strings (or Counter of doc -> count).
            seed: RNG seed for method='rand' initialization.
            method: 'char' | 'warm' | 'rand' initial segmentation (see _init_doc).
            apply: True when replaying a trace (suppresses the progress bar).
            covering: optional per-document segmentations that merges must respect.
            action_protect: regex fragment(s); matching actions are skipped in fit().
        """
        # Deduplicate documents, tracking multiplicity in _doc_counts.
        self._doc_counts = Counter(); ds = []; cs = []; doc_index = {}
        for di, doc in enumerate(docs):
            if doc not in doc_index:
                doc_index[doc] = len(ds)
                ds.append(doc)
                if covering:
                    cs.append(covering[di])
            self._doc_counts[doc_index[doc]] += docs[doc] if type(docs) == Counter else 1
        docs = ds; covering = cs
        self._init_method = method
        self._action_protect = action_protect
        self._covering = {}
        self._hascover = bool(covering)
        ix = 0
        if covering:
            # Map every covered character offset to its segment id; one gap
            # offset separates consecutive documents (mirroring offset += 1
            # in the indexing loop below).
            for doc, segmentation in zip(docs, covering):
                for s_ix, s in enumerate(segmentation):
                    for ch in s:
                        self._covering[ix] = s_ix
                        ix += 1
                ix += 1
            # Extend segment ids over any documents beyond the covered region.
            # NOTE(review): nesting reconstructed from a whitespace-mangled
            # source — verify against the original layout.
            d_ix = 0
            s_ix = max(self._covering.values()) if self._covering else -1
            for doc in docs:
                if d_ix + len(doc) > ix:
                    s_ix += 1
                for ch in doc:
                    if d_ix > ix:
                        self._covering[d_ix] = s_ix
                    d_ix += 1
                d_ix += 1
        if seed:
            # NOTE(review): `np` (numpy) is not imported in this module;
            # assumed available in the original package context — verify.
            np.random.seed(seed=seed)
        offset = 0
        # NOTE(review): `tqdm` is likewise not imported in this module.
        for doc_idx, doc in enumerate(docs if apply else tqdm(docs, desc=f'Initializing')):
            stream = self._init_doc(doc, method=method)
            # the initial segmentation must cover the document exactly
            assert (sum(map(len, stream)) == len(doc))
            for ix, tok in enumerate(stream):
                self._unigraph[tok] += self._doc_counts[doc_idx]
                self._tok_idx[tok].add(offset)
                for char_idx in range(offset, offset + len(tok)):
                    self._char2docidx[char_idx] = doc_idx
                self._doc_unigraph[doc_idx][tok] += self._doc_counts[doc_idx]
                # '' marks the document-start boundary for the first token
                tok_pair = (stream[ix - 1], tok) if ix else ('', tok)
                self._lefts[(offset - len(stream[ix - 1])) if ix else (offset - 1)] = tok_pair
                self._rights[offset] = tok_pair
                if ix:
                    self._digraph[tok_pair] += self._doc_counts[doc_idx]
                    self._pair_idx[tok_pair].add(offset - len(stream[ix - 1]))
                offset += len(tok)
            # terminal boundary pair for the document's last token
            tok_pair = (tok, '')
            self._lefts[offset - len(tok)] = tok_pair
            self._rights[offset] = tok_pair
            offset += 1

    @staticmethod
    def _init_doc(d, method='char'):
        """Produce the initial token stream for one document.

        'char': one token per character; 'warm': word/non-word chunks;
        'rand': random contiguous chunks (requires numpy and a seeded RNG).
        """
        if method == 'char':
            return d
        elif method == 'warm':
            return [token for token in re.split("([a-zA-Z0-9-']+)", d) if token]
        elif method == 'rand':
            topidx = sorted(set(
                [0] + sorted(np.random.choice(np.arange(1, len(d)), size=int(len(d) / 2), replace=False)) + [len(d)]))
            return [d[topidx[idx - 1]:topidx[idx]] for idx in range(1, len(topidx))]
        else:
            raise ValueError(f'Unrecognized document pre-processing method: {method}')

    def under_cover(self, pair):
        """Check whether merging `pair` would cross a covering-segment boundary.

        NOTE(review): returns on the FIRST decisive location — True as soon
        as one occurrence lies wholly outside the covering, None when the
        loop exhausts (callers only test truthiness). Confirm the early
        return is intended rather than checking every occurrence.
        """
        newtok = "".join(pair)
        skip_next = False
        for i in sorted(list(self._pair_idx[pair])):
            if skip_next:
                # second half of an overlapping self-pair (e.g. ('a','a') in "aaa")
                skip_next = False
                continue
            skip_next = True if pair[0] == pair[1] and pair[1] == self._lefts[i + len(pair[0])][1] else False
            if (i in self._covering) and (i+len(newtok)-1 in self._covering):
                if self._covering[i] != self._covering[i+len(newtok)-1]:
                    # merged token would span two covering segments
                    return False
            elif (i in self._covering) or ((i+len(newtok)-1) in self._covering):
                # merged token would straddle the covering's edge
                return False
            else:
                return True

    def split_under_cover(self, wpair):
        """Check whether splitting the joined token of `wpair` keeps both
        halves inside single covering segments.

        NOTE(review): like under_cover, decides from the first location only
        (and returns None when there are no locations) — confirm intended.
        """
        oldtok = "".join(wpair)
        locations = list(self._tok_idx[oldtok])
        for i in sorted(locations):
            if (self._covering[i] != self._covering[i+len(wpair[0])-1] or
                self._covering[i+len(wpair[0])] != self._covering[i+len(wpair[0])+len(wpair[1])-1]):
                return False
            else:
                return True

    def is_covered(self, newtok):
        """True if `newtok` is a substring of some covering-vocabulary token (cached)."""
        if newtok in self._covered:
            return self._covered[newtok]
        else:
            for cover_token in self._covering_vocab:
                if newtok in cover_token:
                    self._covered[newtok] = True
                    return self._covered[newtok]
            else:
                # for/else: loop completed without finding a containing token
                self._covered[newtok] = False
                return self._covered[newtok]

    def is_covering(self, newtok):
        """True if `newtok` contains some covering-vocabulary token (cached).

        NOTE(review): this caches results in `self._covering`, the same dict
        init() rebuilds as a char-offset -> segment-id map (int keys). The
        two uses collide when both covering and covering_vocab are active —
        confirm whether a separate cache dict was intended.
        """
        if newtok in self._covering:
            return self._covering[newtok]
        else:
            for cover_token in self._covering_vocab:
                if cover_token in newtok:
                    self._covering[newtok] = True
                    return self._covering[newtok]
            else:
                # for/else: loop completed without finding a contained token
                self._covering[newtok] = False
                return self._covering[newtok]

    def fit(self, num_batches, batch_size=1, actions_per_batch=None, seed=None):
        """Greedily apply ranked merge/split actions for up to `num_batches`
        batches, recording each applied action in the trace, then finalize
        the vocabulary and the action lookup tables.

        NOTE(review): relies on `self._early_stop` (defined by GreedyBPE) and
        on the module-level `tqdm`/`np` names — verify in package context.
        """
        if seed:
            np.random.seed(seed=seed)
        # cap actions applied per batch at the batch size
        if actions_per_batch is None:
            actions_per_batch = batch_size
        elif actions_per_batch > batch_size:
            actions_per_batch = batch_size
        pbar = tqdm(total=self._early_stop, desc = 'Fitting')
        for batch in range(num_batches):
            actions = self.rank_actions(self.get_actions(batch_size, actions_per_batch))
            for action in actions:
                vsize = len(self._unigraph)
                if action.type == 'merge':
                    newtok = "".join(action.pair)
                    # skip actions touching protected patterns / violating covering
                    if self._action_protect:
                        if re.search("("+"|".join(self._action_protect)+")", newtok): continue
                    if self._hascover:
                        if not self.under_cover(action.pair):
                            continue
                    if self._covering_vocab:
                        if (not self.is_covered(newtok)) and (not self.is_covering(newtok)):
                            continue
                    self.merge(action.pair)
                else:
                    if self._action_protect:
                        if (re.search("("+"|".join(self._action_protect)+")", action.pair[0]) or
                            re.search("("+"|".join(self._action_protect)+")", action.pair[1])):
                            continue
                    if self._hascover:
                        if not self.split_under_cover(action.pair):
                            continue
                    if self._covering_vocab:
                        if (((not self.is_covered(action.pair[0])) and (not self.is_covering(action.pair[0]))) or
                            ((not self.is_covered(action.pair[1])) and (not self.is_covering(action.pair[1])))):
                            continue
                    self.split(action.pair)
                self._action_trace.append(action)
                # progress = net change in vocabulary size
                pbar.update(len(self._unigraph) - vsize)
                if self.do_break_early() or not actions:
                    break
            if self.do_break_early() or not actions:
                break
        # Register the learned types, most frequent first.
        for k, v in sorted(self._unigraph.items(), key=lambda kv: kv[1], reverse=True):
            self.add_type(k)
        # Rebuild the token/pair -> action-index lookup tables (mirrors load()).
        self._tok2acts = defaultdict(list)
        self._pair2merge = dict()
        self._tok2splits = defaultdict(list)
        for aix, a in enumerate(self._action_trace):
            if a.type =='split':
                self._tok2acts["".join(a.pair)].append(aix)
                self._tok2splits["".join(a.pair)].append(aix)
            else:
                self._pair2merge[tuple(a.pair)] = aix
                self._tok2acts[a.pair[0]].append(aix)
                self._tok2acts[a.pair[1]].append(aix)
        self._maxtoklen = max([len(t) for t in self._tok2ind])
        print(f'Built a vocabulary of {len(self)} types')

    def merge(self, pair):
        """Merge every occurrence of `pair` into the concatenated token,
        incrementally updating all counts and offset indices."""
        newtok = "".join(pair)
        skip_next = False
        locations = list(self._pair_idx[pair])
        for i in sorted(locations):
            if skip_next:
                # overlapping occurrence of a self-pair already consumed
                skip_next = False
                continue
            # neighbors of the merged span on the global character axis
            lneighbor = self._rights[i][0]
            rneighbor = self._lefts[i + len(pair[0])][1]
            skip_next = True if pair[0] == pair[1] and pair[1] == rneighbor else False
            # drop the boundary entries interior to the merged span
            del (self._lefts[i])
            del (self._rights[i + len(pair[0])])
            # old neighbor pairs being destroyed
            lpair = (lneighbor, pair[0])
            rpair = (pair[1], rneighbor)
            # new neighbor pairs being created
            newlpair = (lneighbor, newtok)
            newrpair = (newtok, rneighbor)
            # remove the old boundary records around the span
            del (self._lefts[i - len(lneighbor) if lneighbor else i - 1])
            del (self._rights[i])
            del (self._lefts[i + len(pair[0])])
            del (self._rights[i + len(newtok)])
            # install the new boundary records
            self._lefts[i - len(lneighbor) if lneighbor else i - 1] = newlpair
            self._rights[i] = newlpair
            self._lefts[i] = newrpair
            self._rights[i + len(newtok)] = newrpair
            texti = self._char2docidx[i]
            # update left-neighbor pair statistics ('' neighbor = doc boundary)
            if lneighbor:
                self._digraph[newlpair] += self._doc_counts[texti]
                self._digraph[lpair] -= self._doc_counts[texti]
                self._pair_idx[newlpair].add(i - len(lneighbor))
                self._pair_idx[lpair].remove(i - len(lneighbor))
                if not self._digraph[lpair]:
                    del (self._digraph[lpair])
                if not self._pair_idx[lpair]:
                    del (self._pair_idx[lpair])
            # update right-neighbor pair statistics
            if rneighbor:
                self._digraph[newrpair] += self._doc_counts[texti]
                self._digraph[rpair] -= self._doc_counts[texti]
                self._pair_idx[newrpair].add(i)
                self._pair_idx[rpair].remove(i + len(pair[0]))
                if not self._digraph[rpair]:
                    del (self._digraph[rpair])
                if not self._pair_idx[rpair]:
                    del (self._pair_idx[rpair])
            # move unigram mass from the halves to the merged token
            self._unigraph[newtok] += self._doc_counts[texti]
            self._unigraph[pair[0]] -= self._doc_counts[texti]
            self._unigraph[pair[1]] -= self._doc_counts[texti]
            if not self._unigraph[pair[0]]:
                del (self._unigraph[pair[0]])
            if not self._unigraph[pair[1]]:
                del (self._unigraph[pair[1]])
            # same for the per-document counts
            self._doc_unigraph[texti][newtok] += self._doc_counts[texti]
            self._doc_unigraph[texti][pair[0]] -= self._doc_counts[texti]
            if not self._doc_unigraph[texti][pair[0]]:
                del (self._doc_unigraph[texti][pair[0]])
            self._doc_unigraph[texti][pair[1]] -= self._doc_counts[texti]
            if not self._doc_unigraph[texti][pair[1]]:
                del (self._doc_unigraph[texti][pair[1]])
            # update start-offset records
            self._tok_idx[newtok].add(i)
            self._tok_idx[pair[0]].remove(i)
            self._tok_idx[pair[1]].remove(i + len(pair[0]))
            if not self._tok_idx[pair[0]]:
                del (self._tok_idx[pair[0]])
            if not self._tok_idx[pair[1]]:
                del (self._tok_idx[pair[1]])
            # finally retire this occurrence of the merged pair itself
            self._digraph[pair] -= self._doc_counts[texti]
            self._pair_idx[pair].remove(i)
            if not self._pair_idx[pair]:
                del (self._pair_idx[pair])
            if not self._digraph[pair]:
                del (self._digraph[pair])

    def split(self, wpair):
        """Split every occurrence of the joined token of `wpair` back into the
        two halves, incrementally updating all counts and offset indices."""
        oldtok = "".join(wpair)
        locations = list(self._tok_idx[oldtok])
        for i in sorted(locations):
            # neighbors of the split span on the global character axis
            lneighbor = self._rights[i][0]
            rneighbor = self._lefts[i][1]
            lpair = (lneighbor, oldtok)
            rpair = (oldtok, rneighbor)
            newlpair = (lneighbor, wpair[0])
            newcpair = wpair
            newrpair = (wpair[1], rneighbor)
            texti = self._char2docidx[i]
            # the interior pair (left half, right half) comes into existence
            self._digraph[newcpair] += self._doc_counts[texti]
            self._pair_idx[newcpair].add(i)
            self._lefts[i] = wpair
            self._rights[i + len(wpair[0])] = wpair
            # rewire the left boundary ('' neighbor = doc boundary)
            del (self._rights[i])
            self._rights[i] = newlpair
            del (self._lefts[i - len(lneighbor) if lneighbor else i - 1])
            self._lefts[i - len(lneighbor) if lneighbor else i - 1] = newlpair
            if lneighbor:
                self._digraph[newlpair] += self._doc_counts[texti]
                self._digraph[lpair] -= self._doc_counts[texti]
                self._pair_idx[newlpair].add(i - len(lneighbor))
                self._pair_idx[lpair].remove(i - len(lneighbor))
                if not self._digraph[lpair]:
                    del self._digraph[lpair]
                if not self._pair_idx[lpair]:
                    del (self._pair_idx[lpair])
            # rewire the right boundary
            self._lefts[i + len(wpair[0])] = newrpair
            self._rights[i + len(oldtok)] = newrpair
            if rneighbor:
                self._digraph[newrpair] += self._doc_counts[texti]
                self._digraph[rpair] -= self._doc_counts[texti]
                self._pair_idx[newrpair].add(i + len(wpair[0]))
                self._pair_idx[rpair].remove(i)
                if not self._digraph[rpair]:
                    del (self._digraph[rpair])
                if not self._pair_idx[rpair]:
                    del (self._pair_idx[rpair])
            # move unigram mass from the joined token to the halves
            self._unigraph[oldtok] -= self._doc_counts[texti]
            self._unigraph[wpair[0]] += self._doc_counts[texti]
            self._unigraph[wpair[1]] += self._doc_counts[texti]
            if not self._unigraph[oldtok]:
                del self._unigraph[oldtok]
            # update start-offset records
            self._tok_idx[oldtok].remove(i)
            self._tok_idx[wpair[0]].add(i)
            self._tok_idx[wpair[1]].add(i + len(wpair[0]))
            if not self._tok_idx[oldtok]:
                del (self._tok_idx[oldtok])
            # same for the per-document counts
            self._doc_unigraph[texti][oldtok] -= self._doc_counts[texti]
            if not self._doc_unigraph[texti][oldtok]:
                del (self._doc_unigraph[texti][oldtok])
            self._doc_unigraph[texti][wpair[0]] += self._doc_counts[texti]
            self._doc_unigraph[texti][wpair[1]] += self._doc_counts[texti]

    def get_actions(self, batch_size, actions_per_batch):
        """Propose candidate actions for one batch — subclass responsibility
        (left non-abstract so BPE itself stays instantiable for trace replay)."""
        raise NotImplementedError

    def rank_actions(self, actions):
        """Order the proposed actions for application — subclass responsibility."""
        raise NotImplementedError

    def do_break_early(self):
        """Early-termination hook for fit(); the base never stops early."""
        return False
|
|
| |
| |
| |
| |
| |
class GreedyBPE(BPE):
    """Count-greedy BPE: always merge the currently most frequent adjacent pair."""

    def __init__(self, tok2ind=None, covering_vocab = set(), early_stop=1_000_000_000):
        super().__init__(tok2ind=tok2ind, covering_vocab = covering_vocab)
        # stop fitting once the vocabulary reaches this many types
        self._early_stop = early_stop

    def get_actions(self, batch_size, _):
        """Propose merges for the `batch_size` most frequent adjacent pairs.

        The second positional argument (actions-per-batch) is unused here.
        """
        return [Action(pair, type='merge', count=cnt) for pair, cnt in self._digraph.most_common(batch_size)]

    def rank_actions(self, actions):
        """Order proposed actions by descending support count."""
        return sorted(actions, reverse=True, key=lambda a: a.count)

    def do_break_early(self):
        """Stop when the vocabulary has hit `early_stop`, no pairs remain, or
        the best remaining pair occurs only once (no more useful merges)."""
        if self._early_stop and len(self._unigraph) >= self._early_stop:
            return True
        best = self.get_actions(1, 1)
        # Fix: the original indexed get_actions(1, 1)[0] unconditionally,
        # raising IndexError once the digraph was exhausted; an empty
        # proposal list now also signals "stop".
        return (not best) or best[0].count == 1