# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
| import torch | |
| from fairseq.data import encoders | |
def get_whole_word_mask(args, dictionary):
    """Build a per-token mask marking which vocabulary entries start a word.

    Args:
        args: config namespace passed to ``encoders.build_bpe`` to construct
            the BPE encoder.
        dictionary: fairseq dictionary; indexed to recover token strings.

    Returns:
        A ``torch.ByteTensor`` of length ``len(dictionary)`` whose entry ``i``
        is 1 if token ``i`` begins a word under the configured BPE scheme,
        or ``None`` when no BPE encoder is configured for ``args``.
    """
    bpe = encoders.build_bpe(args)
    if bpe is None:
        return None

    def _starts_word(idx):
        # Special symbols (pad/eos/unk/...) always count as word beginnings.
        if idx < dictionary.nspecial:
            return True
        token = dictionary[idx]
        # Dictionary padding placeholders are treated as standalone words.
        if token.startswith("madeupword"):
            return True
        try:
            return bpe.is_beginning_of_word(token)
        except ValueError:
            # Best-effort: if the BPE scheme cannot classify this token,
            # fall back to treating it as a word start.
            return True

    return torch.ByteTensor(
        [_starts_word(idx) for idx in range(len(dictionary))]
    )