# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
from collections import Counter
from multiprocessing import Pool

import torch

from fairseq import utils
from fairseq.data import data_utils
from fairseq.file_chunker_utils import Chunker, find_offsets
from fairseq.file_io import PathManager
from fairseq.tokenizer import tokenize_line


class Dictionary:
    """A mapping from symbols to consecutive integers"""

    def __init__(
        self,
        *,  # begin keyword-only arguments
        bos="<s>",
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
        extra_special_symbols=None,
    ):
        self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
        self.symbols = []
        self.count = []
        self.indices = {}
        self.bos_index = self.add_symbol(bos)
        self.pad_index = self.add_symbol(pad)
        self.eos_index = self.add_symbol(eos)
        self.unk_index = self.add_symbol(unk)
        if extra_special_symbols:
            for s in extra_special_symbols:
                self.add_symbol(s)
        self.nspecial = len(self.symbols)
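
    # Usage sketch (illustrative; not part of the original file): with the
    # defaults above, the special symbols occupy the first indices:
    #
    #   d = Dictionary()
    #   assert (d.bos(), d.pad(), d.eos(), d.unk()) == (0, 1, 2, 3)
    #   assert d.nspecial == 4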

    def __eq__(self, other):
        return self.indices == other.indices

    def __getitem__(self, idx):
        if idx < len(self.symbols):
            return self.symbols[idx]
        return self.unk_word

    def get_count(self, idx):
        return self.count[idx]

    def __len__(self):
        """Returns the number of symbols in the dictionary"""
        return len(self.symbols)

    def __contains__(self, sym):
        return sym in self.indices

    def index(self, sym):
        """Returns the index of the specified symbol"""
        assert isinstance(sym, str)
        if sym in self.indices:
            return self.indices[sym]
        return self.unk_index

    def string(
        self,
        tensor,
        bpe_symbol=None,
        escape_unk=False,
        extra_symbols_to_ignore=None,
        unk_string=None,
        include_eos=False,
        separator=" ",
    ):
        """Helper for converting a tensor of token indices to a string.

        Can optionally remove BPE symbols or escape <unk> words.
        """
        if torch.is_tensor(tensor) and tensor.dim() == 2:
            # Decode each row separately, forwarding all formatting options
            # so that 2-D inputs honor unk_string and separator as well.
            return "\n".join(
                self.string(
                    t,
                    bpe_symbol,
                    escape_unk,
                    extra_symbols_to_ignore,
                    unk_string=unk_string,
                    include_eos=include_eos,
                    separator=separator,
                )
                for t in tensor
            )

        extra_symbols_to_ignore = set(extra_symbols_to_ignore or [])
        if not include_eos:
            extra_symbols_to_ignore.add(self.eos())

        def token_string(i):
            if i == self.unk():
                if unk_string is not None:
                    return unk_string
                else:
                    return self.unk_string(escape_unk)
            else:
                return self[i]

        if hasattr(self, "bos_index"):
            extra_symbols_to_ignore.add(self.bos())

        sent = separator.join(
            token_string(i)
            for i in tensor
            if utils.item(i) not in extra_symbols_to_ignore
        )

        return data_utils.post_process(sent, bpe_symbol)
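
    # Example (a sketch; indices assume the default special symbols):
    #
    #   d = Dictionary()
    #   for w in ["hello", "world"]:
    #       d.add_symbol(w)
    #   t = d.encode_line("hello world", add_if_not_exist=False)
    #   d.string(t)  # -> "hello world" (eos is stripped by default)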

    def unk_string(self, escape=False):
        """Return unknown string, optionally escaped as: <<unk>>"""
        if escape:
            return "<{}>".format(self.unk_word)
        else:
            return self.unk_word

    def add_symbol(self, word, n=1, overwrite=False):
        """Adds a word to the dictionary"""
        if word in self.indices and not overwrite:
            idx = self.indices[word]
            self.count[idx] = self.count[idx] + n
            return idx
        else:
            idx = len(self.symbols)
            self.indices[word] = idx
            self.symbols.append(word)
            self.count.append(n)
            return idx
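
    # Example (illustrative): re-adding an existing symbol only bumps its
    # count; the index is stable unless overwrite=True is passed.
    #
    #   d = Dictionary()
    #   i = d.add_symbol("cat", n=2)
    #   assert d.add_symbol("cat", n=3) == i
    #   assert d.get_count(i) == 5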

    def update(self, new_dict):
        """Updates counts from new dictionary."""
        for word in new_dict.symbols:
            idx2 = new_dict.indices[word]
            if word in self.indices:
                idx = self.indices[word]
                self.count[idx] = self.count[idx] + new_dict.count[idx2]
            else:
                idx = len(self.symbols)
                self.indices[word] = idx
                self.symbols.append(word)
                self.count.append(new_dict.count[idx2])

    def finalize(self, threshold=-1, nwords=-1, padding_factor=8):
        """Sort symbols by frequency in descending order, ignoring special ones.

        Args:
            - threshold defines the minimum word count
            - nwords defines the total number of words in the final dictionary,
                including special symbols
            - padding_factor can be used to pad the dictionary size to be a
                multiple of 8, which is important on some hardware (e.g., Nvidia
                Tensor Cores).
        """
        if nwords <= 0:
            nwords = len(self)

        new_indices = dict(zip(self.symbols[: self.nspecial], range(self.nspecial)))
        new_symbols = self.symbols[: self.nspecial]
        new_count = self.count[: self.nspecial]

        # Sort alphabetically first so that frequency ties are broken
        # deterministically before taking the most common symbols.
        c = Counter(
            dict(
                sorted(zip(self.symbols[self.nspecial :], self.count[self.nspecial :]))
            )
        )
        for symbol, count in c.most_common(nwords - self.nspecial):
            if count >= threshold:
                new_indices[symbol] = len(new_symbols)
                new_symbols.append(symbol)
                new_count.append(count)
            else:
                break

        assert len(new_symbols) == len(new_indices)

        self.count = list(new_count)
        self.symbols = list(new_symbols)
        self.indices = new_indices

        self.pad_to_multiple_(padding_factor)

    def pad_to_multiple_(self, padding_factor):
        """Pad Dictionary size to be a multiple of *padding_factor*."""
        if padding_factor > 1:
            i = 0
            while len(self) % padding_factor != 0:
                symbol = "madeupword{:04d}".format(i)
                self.add_symbol(symbol, n=0)
                i += 1
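
    # Example (illustrative): finalize() keeps special symbols in place,
    # reorders the rest by frequency, and pads with "madeupword0000"-style
    # fillers until the size is a multiple of padding_factor:
    #
    #   d = Dictionary()
    #   for w, n in [("rare", 1), ("common", 10)]:
    #       d.add_symbol(w, n=n)
    #   d.finalize(padding_factor=8)
    #   assert d[d.nspecial] == "common" and len(d) % 8 == 0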

    def bos(self):
        """Helper to get index of beginning-of-sentence symbol"""
        return self.bos_index

    def pad(self):
        """Helper to get index of pad symbol"""
        return self.pad_index

    def eos(self):
        """Helper to get index of end-of-sentence symbol"""
        return self.eos_index

    def unk(self):
        """Helper to get index of unk symbol"""
        return self.unk_index

    @classmethod
    def load(cls, f):
        """Loads the dictionary from a text file with the format:

        ```
        <symbol0> <count0>
        <symbol1> <count1>
        ...
        ```
        """
        d = cls()
        d.add_from_file(f)
        return d

    def add_from_file(self, f):
        """
        Loads a pre-existing dictionary from a text file and adds its symbols
        to this instance.
        """
        if isinstance(f, str):
            try:
                with open(PathManager.get_local_path(f), "r", encoding="utf-8") as fd:
                    self.add_from_file(fd)
            except FileNotFoundError:
                raise
            except UnicodeError:
                raise Exception(
                    "Incorrect encoding detected in {}, please "
                    "rebuild the dataset".format(f)
                )
            return

        lines = f.readlines()
        indices_start_line = self._load_meta(lines)

        for line in lines[indices_start_line:]:
            try:
                line, field = line.rstrip().rsplit(" ", 1)
                if field == "#fairseq:overwrite":
                    overwrite = True
                    line, field = line.rsplit(" ", 1)
                else:
                    overwrite = False
                count = int(field)
                word = line
                if word in self and not overwrite:
                    raise RuntimeError(
                        "Duplicate word found when loading Dictionary: '{}'. "
                        "Duplicate words can overwrite earlier ones by adding the "
                        "#fairseq:overwrite flag at the end of the corresponding row "
                        "in the dictionary file. If using the Camembert model, please "
                        "download an updated copy of the model file.".format(word)
                    )
                self.add_symbol(word, n=count, overwrite=overwrite)
            except ValueError:
                raise ValueError(
                    f"Incorrect dictionary format, expected '<token> <cnt> [flags]': \"{line}\""
                )
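
    # Expected on-disk format (one symbol per line, as parsed above):
    #
    #   hello 42
    #   world 17
    #   hello 3 #fairseq:overwrite
    #
    # Without the #fairseq:overwrite flag, the repeated "hello" on the third
    # line would trigger the RuntimeError above.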

    def _save(self, f, kv_iterator):
        if isinstance(f, str):
            PathManager.mkdirs(os.path.dirname(f))
            with PathManager.open(f, "w", encoding="utf-8") as fd:
                return self.save(fd)
        for k, v in kv_iterator:
            print("{} {}".format(k, v), file=f)

    def _get_meta(self):
        return [], []

    def _load_meta(self, lines):
        return 0

    def save(self, f):
        """Stores dictionary into a text file"""
        ex_keys, ex_vals = self._get_meta()
        self._save(
            f,
            zip(
                ex_keys + self.symbols[self.nspecial :],
                ex_vals + self.count[self.nspecial :],
            ),
        )

    def dummy_sentence(self, length):
        t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
        t[-1] = self.eos()
        return t

    def encode_line(
        self,
        line,
        line_tokenizer=tokenize_line,
        add_if_not_exist=True,
        consumer=None,
        append_eos=True,
        reverse_order=False,
    ) -> torch.IntTensor:
        words = line_tokenizer(line)
        if reverse_order:
            words = list(reversed(words))
        nwords = len(words)
        ids = torch.IntTensor(nwords + 1 if append_eos else nwords)

        for i, word in enumerate(words):
            if add_if_not_exist:
                idx = self.add_symbol(word)
            else:
                idx = self.index(word)
            if consumer is not None:
                consumer(word, idx)
            ids[i] = idx
        if append_eos:
            ids[nwords] = self.eos_index
        return ids
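
    # Example (illustrative): out-of-vocabulary words map to unk() when
    # add_if_not_exist=False, and eos is appended by default:
    #
    #   d = Dictionary()
    #   d.add_symbol("hello")
    #   ids = d.encode_line("hello oov", add_if_not_exist=False)
    #   assert ids.tolist() == [d.index("hello"), d.unk(), d.eos()]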

    @staticmethod
    def _add_file_to_dictionary_single_worker(
        filename,
        tokenize,
        eos_word,
        start_offset,
        end_offset,
    ):
        counter = Counter()
        with Chunker(filename, start_offset, end_offset) as line_iterator:
            for line in line_iterator:
                for word in tokenize(line):
                    counter.update([word])
                # Count one eos per line so its frequency matches line count.
                counter.update([eos_word])
        return counter

    @staticmethod
    def add_file_to_dictionary(filename, dict, tokenize, num_workers):
        # NOTE: `dict` is a Dictionary instance here (the parameter name
        # shadows the builtin, kept for API compatibility).
        def merge_result(counter):
            for w, c in sorted(counter.items()):
                dict.add_symbol(w, c)

        local_file = PathManager.get_local_path(filename)
        offsets = find_offsets(local_file, num_workers)
        if num_workers > 1:
            chunks = zip(offsets, offsets[1:])
            pool = Pool(processes=num_workers)
            results = []
            for (start_offset, end_offset) in chunks:
                results.append(
                    pool.apply_async(
                        Dictionary._add_file_to_dictionary_single_worker,
                        (
                            local_file,
                            tokenize,
                            dict.eos_word,
                            start_offset,
                            end_offset,
                        ),
                    )
                )
            pool.close()
            pool.join()
            for r in results:
                merge_result(r.get())
        else:
            merge_result(
                Dictionary._add_file_to_dictionary_single_worker(
                    local_file, tokenize, dict.eos_word, offsets[0], offsets[1]
                )
            )
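
    # Example (a sketch; "corpus.txt" is a hypothetical whitespace-tokenized
    # text file): build a dictionary from a file with 4 worker processes,
    # then prune and pad it:
    #
    #   d = Dictionary()
    #   Dictionary.add_file_to_dictionary("corpus.txt", d, tokenize_line, 4)
    #   d.finalize(threshold=2, padding_factor=8)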


class TruncatedDictionary(object):
    def __init__(self, wrapped_dict, length):
        # Dynamically subclass the wrapped dictionary's class so that
        # isinstance() checks against it still succeed on the truncated view.
        self.__class__ = type(
            wrapped_dict.__class__.__name__,
            (self.__class__, wrapped_dict.__class__),
            {},
        )
        self.__dict__ = wrapped_dict.__dict__
        self.wrapped_dict = wrapped_dict
        self.length = min(len(self.wrapped_dict), length)

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        if i < self.length:
            return self.wrapped_dict[i]
        return self.wrapped_dict.unk()
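
    # Example (illustrative): a TruncatedDictionary exposes only the first
    # `length` entries; everything past the cutoff maps to unk:
    #
    #   d = Dictionary()
    #   i = d.add_symbol("kept")
    #   j = d.add_symbol("dropped")
    #   td = TruncatedDictionary(d, length=i + 1)  # keep up through "kept"
    #   assert len(td) == i + 1
    #   assert td[i] == "kept" and td[j] == d.unk()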