Spaces: Running on Zero
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import os | |
| from collections import Counter | |
| from multiprocessing import Pool | |
| import torch | |
| from fairseq import utils | |
| from fairseq.data import data_utils | |
| from fairseq.file_chunker_utils import Chunker, find_offsets | |
| from fairseq.file_io import PathManager | |
| from fairseq.tokenizer import tokenize_line | |
class Dictionary:
    """A mapping from symbols to consecutive integers.

    ``symbols[i]`` is the symbol with index ``i``, ``count[i]`` its frequency,
    and ``indices`` maps each symbol back to its index. When special symbols
    are enabled they are added first (bos, pad, eos, unk, then any extras), so
    they occupy indices ``0 .. nspecial - 1``.
    """

    def __init__(
        self,
        *,  # begin keyword-only arguments
        bos="<s>",
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
        extra_special_symbols=None,
        add_special_symbols=True,
    ):
        """Create a dictionary, optionally pre-seeded with special symbols.

        Args:
            bos, pad, eos, unk: surface strings of the special symbols. They
                are always stored as ``*_word`` attributes but only receive an
                index (``*_index``) when ``add_special_symbols`` is true.
            extra_special_symbols: optional iterable of additional symbols
                added right after the four built-in ones (only when
                ``add_special_symbols`` is true).
            add_special_symbols: when false the dictionary starts empty and
                ``bos_index``/``pad_index``/``eos_index``/``unk_index`` are
                not defined.
        """
        self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
        self.symbols = []
        self.count = []
        self.indices = {}
        if add_special_symbols:
            self.bos_index = self.add_symbol(bos)
            self.pad_index = self.add_symbol(pad)
            self.eos_index = self.add_symbol(eos)
            self.unk_index = self.add_symbol(unk)
            if extra_special_symbols:
                for s in extra_special_symbols:
                    self.add_symbol(s)
        # Everything added so far is "special"; finalize() keeps these in place.
        self.nspecial = len(self.symbols)

    def __eq__(self, other):
        # Equality compares only the symbol -> index mapping, not the counts.
        return self.indices == other.indices

    def __getitem__(self, idx):
        """Return the symbol at ``idx``; indices ``>= len(self)`` yield ``unk_word``."""
        if idx < len(self.symbols):
            return self.symbols[idx]
        return self.unk_word

    def get_count(self, idx):
        """Return the stored frequency of the symbol at ``idx``."""
        return self.count[idx]

    def __len__(self):
        """Returns the number of symbols in the dictionary"""
        return len(self.symbols)

    def __contains__(self, sym):
        return sym in self.indices

    def index(self, sym):
        """Returns the index of the specified symbol, or ``unk_index`` if absent."""
        assert isinstance(sym, str)
        if sym in self.indices:
            return self.indices[sym]
        return self.unk_index

    def string(
        self,
        tensor,
        bpe_symbol=None,
        escape_unk=False,
        extra_symbols_to_ignore=None,
        unk_string=None,
        include_eos=False,
        separator=" ",
    ):
        """Helper for converting a tensor of token indices to a string.

        Can optionally remove BPE symbols or escape <unk> words.

        Args:
            tensor: indices to render. A 2-D tensor is rendered row by row,
                one output line per row.
            bpe_symbol: passed to ``data_utils.post_process`` on the joined
                string.
            escape_unk: render unknown tokens via ``unk_string(escape=True)``.
            extra_symbols_to_ignore: indices to drop from the output.
            unk_string: if given, used verbatim for unknown tokens (takes
                precedence over ``escape_unk``).
            include_eos: keep the eos index instead of dropping it.
            separator: string placed between rendered tokens.
        """
        if torch.is_tensor(tensor) and tensor.dim() == 2:
            # Forward *every* rendering option so each row is formatted exactly
            # as a 1-D call with the same arguments would format it.
            return "\n".join(
                self.string(
                    t,
                    bpe_symbol,
                    escape_unk,
                    extra_symbols_to_ignore,
                    unk_string=unk_string,
                    include_eos=include_eos,
                    separator=separator,
                )
                for t in tensor
            )

        extra_symbols_to_ignore = set(extra_symbols_to_ignore or [])
        if not include_eos:
            extra_symbols_to_ignore.add(self.eos())

        def token_string(i):
            if i == self.unk():
                if unk_string is not None:
                    return unk_string
                else:
                    return self.unk_string(escape_unk)
            else:
                return self[i]

        # bos is always dropped when this dictionary defines one.
        if hasattr(self, "bos_index"):
            extra_symbols_to_ignore.add(self.bos())

        sent = separator.join(
            token_string(i)
            for i in tensor
            if utils.item(i) not in extra_symbols_to_ignore
        )
        return data_utils.post_process(sent, bpe_symbol)

    def unk_string(self, escape=False):
        """Return unknown string, optionally escaped as: <<unk>>"""
        if escape:
            return "<{}>".format(self.unk_word)
        else:
            return self.unk_word

    def add_symbol(self, word, n=1, overwrite=False):
        """Adds a word to the dictionary and returns its index.

        If ``word`` is already present and ``overwrite`` is false, its count is
        increased by ``n``. Otherwise the word is appended at a new index with
        count ``n``. With ``overwrite=True`` an existing entry is re-pointed to
        the new index, and the old slot remains in ``symbols``/``count``.
        """
        if word in self.indices and not overwrite:
            idx = self.indices[word]
            self.count[idx] = self.count[idx] + n
            return idx
        else:
            idx = len(self.symbols)
            self.indices[word] = idx
            self.symbols.append(word)
            self.count.append(n)
            return idx

    def update(self, new_dict):
        """Updates counts from new dictionary.

        Words already present have ``new_dict``'s count added to theirs. New
        words are appended in ``new_dict.symbols`` order.
        """
        for word in new_dict.symbols:
            idx2 = new_dict.indices[word]
            if word in self.indices:
                idx = self.indices[word]
                self.count[idx] = self.count[idx] + new_dict.count[idx2]
            else:
                idx = len(self.symbols)
                self.indices[word] = idx
                self.symbols.append(word)
                self.count.append(new_dict.count[idx2])

    def finalize(self, threshold=-1, nwords=-1, padding_factor=8):
        """Sort symbols by frequency in descending order, ignoring special ones.

        Args:
            - threshold defines the minimum word count
            - nwords defines the total number of words in the final dictionary,
                including special symbols
            - padding_factor can be used to pad the dictionary size to be a
                multiple of 8, which is important on some hardware (e.g., Nvidia
                Tensor Cores).
        """
        if nwords <= 0:
            nwords = len(self)

        # Special symbols keep their original positions.
        new_indices = dict(zip(self.symbols[: self.nspecial], range(self.nspecial)))
        new_symbols = self.symbols[: self.nspecial]
        new_count = self.count[: self.nspecial]

        # Building the Counter from symbol-sorted pairs makes the order of
        # equal-count symbols in most_common() deterministic.
        c = Counter(
            dict(
                sorted(zip(self.symbols[self.nspecial :], self.count[self.nspecial :]))
            )
        )
        for symbol, count in c.most_common(nwords - self.nspecial):
            if count >= threshold:
                new_indices[symbol] = len(new_symbols)
                new_symbols.append(symbol)
                new_count.append(count)
            else:
                # most_common() is descending, so every later count is lower too.
                break

        assert len(new_symbols) == len(new_indices)

        self.count = list(new_count)
        self.symbols = list(new_symbols)
        self.indices = new_indices

        self.pad_to_multiple_(padding_factor)

    def pad_to_multiple_(self, padding_factor):
        """Pad Dictionary size to be a multiple of *padding_factor*.

        Filler symbols are named ``madeupword0000``, ``madeupword0001``, ...
        and added with a count of 0.
        """
        if padding_factor > 1:
            i = 0
            while len(self) % padding_factor != 0:
                symbol = "madeupword{:04d}".format(i)
                self.add_symbol(symbol, n=0)
                i += 1

    def bos(self):
        """Helper to get index of beginning-of-sentence symbol"""
        return self.bos_index

    def pad(self):
        """Helper to get index of pad symbol"""
        return self.pad_index

    def eos(self):
        """Helper to get index of end-of-sentence symbol"""
        return self.eos_index

    def unk(self):
        """Helper to get index of unk symbol"""
        return self.unk_index

    @classmethod
    def load(cls, f, add_special_symbols=True):
        """Loads the dictionary from a text file with the format:

        ```
        <symbol0> <count0>
        <symbol1> <count1>
        ...
        ```

        ``f`` may be a path or an already-open text file object.
        """
        d = cls(add_special_symbols=add_special_symbols)
        d.add_from_file(f)
        return d

    def add_from_file(self, f):
        """
        Loads a pre-existing dictionary from a text file and adds its symbols
        to this instance.

        Each entry line is ``<token> <count>``, optionally followed by the flag
        ``#fairseq:overwrite`` to allow re-defining an existing token.

        Raises:
            Exception: the file at path ``f`` is not valid UTF-8.
            RuntimeError: a token is duplicated without the overwrite flag.
            ValueError: a line does not match the expected format.
        """
        if isinstance(f, str):
            try:
                with open(PathManager.get_local_path(f), "r", encoding="utf-8") as fd:
                    self.add_from_file(fd)
            except UnicodeError as err:
                raise Exception(
                    "Incorrect encoding detected in {}, please "
                    "rebuild the dataset".format(f)
                ) from err
            return

        lines = f.readlines()
        # Subclasses may consume leading metadata lines (see _load_meta).
        indices_start_line = self._load_meta(lines)

        for raw_line in lines[indices_start_line:]:
            try:
                line, field = raw_line.rstrip().rsplit(" ", 1)
                if field == "#fairseq:overwrite":
                    overwrite = True
                    line, field = line.rsplit(" ", 1)
                else:
                    overwrite = False
                count = int(field)
                word = line
                if word in self and not overwrite:
                    raise RuntimeError(
                        "Duplicate word found when loading Dictionary: '{}'. "
                        "Duplicate words can overwrite earlier ones by adding the "
                        "#fairseq:overwrite flag at the end of the corresponding row "
                        "in the dictionary file. If using the Camembert model, please "
                        "download an updated copy of the model file.".format(word)
                    )
                self.add_symbol(word, n=count, overwrite=overwrite)
            except ValueError as err:
                # Report the full offending input line, not the partially split one.
                raise ValueError(
                    f"Incorrect dictionary format, expected '<token> <cnt> [flags]': \"{raw_line.rstrip()}\""
                ) from err

    def _save(self, f, kv_iterator):
        """Write ``key value`` pairs to ``f``, one per line.

        If ``f`` is a path, its parent directory is created and ``save`` is
        re-invoked on the opened file. In that case ``kv_iterator`` is unused,
        because ``save`` rebuilds it.
        """
        if isinstance(f, str):
            dirname = os.path.dirname(f)
            # A bare filename such as "dict.txt" has an empty dirname; there is
            # no directory to create in that case.
            if dirname:
                PathManager.mkdirs(dirname)
            with PathManager.open(f, "w", encoding="utf-8") as fd:
                return self.save(fd)
        for k, v in kv_iterator:
            print("{} {}".format(k, v), file=f)

    def _get_meta(self):
        # Extra (keys, values) written before the symbols. There are none here;
        # presumably a hook for subclasses.
        return [], []

    def _load_meta(self, lines):
        # Number of leading metadata lines to skip when loading. There are none
        # here; presumably a hook for subclasses.
        return 0

    def save(self, f):
        """Stores dictionary into a text file.

        Special symbols are not written. Only metadata entries and the
        non-special ``symbol count`` pairs are emitted.
        """
        ex_keys, ex_vals = self._get_meta()
        self._save(
            f,
            zip(
                ex_keys + self.symbols[self.nspecial :],
                ex_vals + self.count[self.nspecial :],
            ),
        )

    def dummy_sentence(self, length):
        """Return a ``length``-long LongTensor of random non-special ids ending in eos."""
        t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
        t[-1] = self.eos()
        return t

    def encode_line(
        self,
        line,
        line_tokenizer=tokenize_line,
        add_if_not_exist=True,
        consumer=None,
        append_eos=True,
        reverse_order=False,
    ) -> torch.IntTensor:
        """Tokenize ``line`` and map each token to its index.

        Args:
            line_tokenizer: callable that splits ``line`` into tokens.
            add_if_not_exist: add unseen tokens via ``add_symbol``; otherwise
                look them up with ``index`` (unseen tokens map to unk).
            consumer: optional callback invoked as ``consumer(word, idx)``
                for every token.
            append_eos: append ``eos_index`` after the last token.
            reverse_order: reverse the token order before encoding.
        """
        words = line_tokenizer(line)
        if reverse_order:
            words = list(reversed(words))
        nwords = len(words)
        ids = torch.IntTensor(nwords + 1 if append_eos else nwords)

        for i, word in enumerate(words):
            if add_if_not_exist:
                idx = self.add_symbol(word)
            else:
                idx = self.index(word)
            if consumer is not None:
                consumer(word, idx)
            ids[i] = idx
        if append_eos:
            ids[nwords] = self.eos_index
        return ids

    @staticmethod
    def _add_file_to_dictionary_single_worker(
        filename,
        tokenize,
        eos_word,
        start_offset,
        end_offset,
    ):
        """Count tokens (plus one ``eos_word`` per line) in a slice of ``filename``.

        The slice is read through ``Chunker(filename, start_offset, end_offset)``.
        Offsets presumably come from ``find_offsets``.
        """
        counter = Counter()
        with Chunker(filename, start_offset, end_offset) as line_iterator:
            for line in line_iterator:
                for word in tokenize(line):
                    counter.update([word])
                counter.update([eos_word])
        return counter

    @staticmethod
    def add_file_to_dictionary(filename, dict, tokenize, num_workers):
        """Count tokens in ``filename`` and add them to ``dict``.

        ``dict`` (name kept for backward compatibility; it shadows the builtin)
        is the target ``Dictionary``. With ``num_workers > 1`` the file is split
        at ``find_offsets`` boundaries and each chunk is counted in a worker
        process. Counts are merged in sorted symbol order.
        """

        def merge_result(counter):
            for w, c in sorted(counter.items()):
                dict.add_symbol(w, c)

        local_file = PathManager.get_local_path(filename)
        offsets = find_offsets(local_file, num_workers)
        if num_workers > 1:
            chunks = zip(offsets, offsets[1:])
            # The context manager terminates the pool if anything below raises,
            # so worker processes are not leaked.
            with Pool(processes=num_workers) as pool:
                results = [
                    pool.apply_async(
                        Dictionary._add_file_to_dictionary_single_worker,
                        (
                            local_file,
                            tokenize,
                            dict.eos_word,
                            start_offset,
                            end_offset,
                        ),
                    )
                    for start_offset, end_offset in chunks
                ]
                pool.close()
                pool.join()
                for r in results:
                    merge_result(r.get())
        else:
            merge_result(
                Dictionary._add_file_to_dictionary_single_worker(
                    local_file, tokenize, dict.eos_word, offsets[0], offsets[1]
                )
            )
class TruncatedDictionary(object):
    """View of ``wrapped_dict`` that reports at most ``length`` symbols.

    The instance's class is replaced by a dynamic subclass of both this class
    and the wrapped dictionary's class, so the wrapped class's methods remain
    available, with ``__len__``/``__getitem__`` overridden here.

    NOTE: ``__dict__`` is the *same object* as ``wrapped_dict.__dict__``. Any
    attribute set on this view (including ``wrapped_dict`` and ``length``
    below) is therefore also visible on the wrapped dictionary.
    """

    def __init__(self, wrapped_dict, length):
        self.__class__ = type(
            wrapped_dict.__class__.__name__,
            (self.__class__, wrapped_dict.__class__),
            {},
        )
        self.__dict__ = wrapped_dict.__dict__
        self.wrapped_dict = wrapped_dict
        # Never report more symbols than the wrapped dictionary actually has.
        self.length = min(len(self.wrapped_dict), length)

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        """Return the symbol at ``i``, or the unknown *symbol* when ``i >= len(self)``.

        Like ``Dictionary.__getitem__``, this always returns a symbol string
        rather than the unk *index*, so callers that join symbols keep working.
        """
        if i < self.length:
            return self.wrapped_dict[i]
        return self.wrapped_dict.unk_word