import re
import json

import ftfy
import spacy
import torch
from tqdm import tqdm


def load_existing_data_loader(data_loader, path):
    """Copy the attributes of a previously saved data loader onto `data_loader`."""
    old_data_loader = torch.load(path)
    for attr in data_loader.__dict__.keys():
        if attr not in old_data_loader.__dict__.keys():
            continue
        setattr(data_loader, attr, getattr(old_data_loader, attr))
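
# Hedged usage sketch (the loader type and path here are assumptions, not part
# of this module):
#
#     data_loader = SomeDataLoader(opt)
#     load_existing_data_loader(data_loader, 'data/loader.pt')
#
# Attributes absent from the saved loader are skipped, so fields added after
# the checkpoint was written keep their freshly initialized values.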

################################################################################
#
# Code below taken from the HuggingFace pytorch-openai-transformer-lm repository
#
################################################################################


def get_pairs(word):
    """
    Return set of symbol pairs in a word.
    word is represented as tuple of symbols (symbols being variable-length strings)
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs
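
# For example, get_pairs(('l', 'o', 'w', 'er</w>')) returns
# {('l', 'o'), ('o', 'w'), ('w', 'er</w>')}; bpe() below repeatedly merges
# the highest-ranked of these pairs.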


def text_standardize(text):
    """
    fixes some issues the spacy tokenizer had on books corpus
    also does some whitespace standardization
    """
    text = text.replace('—', '-')
    text = text.replace('–', '-')
    text = text.replace('―', '-')
    text = text.replace('…', '...')
    text = text.replace('´', "'")
    text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text)
    text = re.sub(r'\s*\n\s*', ' \n ', text)
    text = re.sub(r'[^\S\n]+', ' ', text)
    return text.strip()
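
# For example (with '´', '—', and '…' as the raw unicode characters):
#
#     text_standardize('don´t — stop…')  # -> "don't - stop..."
#
# The dash is padded with spaces by the punctuation regex, while '...' stays
# attached to the word because '.' is not in the padded set.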


class TextEncoder(object):
    """
    mostly a wrapper for a public python bpe tokenizer
    """

    def __init__(self, encoder_path, bpe_path):
        # 'en' is the spacy 2.x shortcut name; newer spacy versions expect a
        # full model name such as 'en_core_web_sm'
        self.nlp = spacy.load(
            'en', disable=['parser', 'tagger', 'ner', 'textcat'])
        with open(encoder_path) as f:
            self.encoder = json.load(f)
        self.decoder = {v: k for k, v in self.encoder.items()}
        # the first line of the bpe file is a version header, the last is empty
        with open(bpe_path, encoding='utf-8') as f:
            merges = f.read().split('\n')[1:-1]
        merges = [tuple(merge.split()) for merge in merges]
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {}

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        # append the end-of-word marker to the last symbol
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)
        if not pairs:
            return token + '</w>'
        while True:
            # merge the lowest-ranked (i.e. earliest-learned) pair first
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(
                pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break
                if (word[i] == first and i < len(word) - 1 and
                        word[i + 1] == second):
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)
        word = ' '.join(word)
        if word == '\n </w>':
            word = '\n</w>'
        self.cache[token] = word
        return word
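
    # With a toy merge table {('e', 'r</w>'): 0, ('l', 'o'): 1}, bpe('lower')
    # returns 'lo w er</w>': ('e', 'r</w>') merges first, then ('l', 'o'),
    # and no remaining pair has a rank. Real output depends on the merges
    # loaded from bpe_path.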

    def encode(self, texts, verbose=True):
        texts_tokens = []
        iterator = tqdm(texts, ncols=80, leave=False) if verbose else texts
        for text in iterator:
            text = self.nlp(text_standardize(ftfy.fix_text(text)))
            text_tokens = []
            for token in text:
                # BPE pieces missing from the vocabulary fall back to id 0
                text_tokens.extend(
                    [self.encoder.get(t, 0) for t in
                     self.bpe(token.text.lower()).split(' ')])
            texts_tokens.append(text_tokens)
        return texts_tokens
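

# Usage sketch (the file names below are assumptions; substitute the vocabulary
# and merge files shipped with the pretrained model):
#
#     encoder = TextEncoder('model/encoder_bpe_40000.json',
#                           'model/vocab_40000.bpe')
#     ids = encoder.encode(['Hello world!'], verbose=False)
#
# `ids` holds one list of BPE token ids per input text.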