import re
import sys

import six
from six.moves import range
| |
|
| | PAD = "<pad>" |
| | EOS = "<EOS>" |
| | UNK = "<UNK>" |
| | SEG = "|" |
| | RESERVED_TOKENS = [PAD, EOS, UNK] |
| | NUM_RESERVED_TOKENS = len(RESERVED_TOKENS) |
| | PAD_ID = RESERVED_TOKENS.index(PAD) |
| | EOS_ID = RESERVED_TOKENS.index(EOS) |
| | UNK_ID = RESERVED_TOKENS.index(UNK) |
| |
|
| | if six.PY2: |
| | RESERVED_TOKENS_BYTES = RESERVED_TOKENS |
| | else: |
| | RESERVED_TOKENS_BYTES = [bytes(PAD, "ascii"), bytes(EOS, "ascii")] |
| |
|
| | |
| | |
| | |
| | |
| | _UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);") |
| | _ESCAPE_CHARS = set(u"\\_u;0123456789") |
| |
|
| |
|
def strip_ids(ids, ids_to_strip):
  """Return a copy of `ids` with trailing ids in `ids_to_strip` removed."""
  seq = list(ids)
  end = len(seq)
  while end and seq[end - 1] in ids_to_strip:
    end -= 1
  return seq[:end]
| |
|
| |
|
class TextEncoder(object):
  """Base class for converting from ints to/from human readable strings."""

  def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS):
    self._num_reserved_ids = num_reserved_ids

  @property
  def num_reserved_ids(self):
    return self._num_reserved_ids

  def encode(self, s):
    """Transform a human-readable string into a sequence of int ids.

    The ids should be in the range [num_reserved_ids, vocab_size). Ids [0,
    num_reserved_ids) are reserved.

    EOS is not appended.

    Args:
      s: human-readable string to be converted.

    Returns:
      ids: list of integers
    """
    offset = self._num_reserved_ids
    return [int(token) + offset for token in s.split()]

  def decode(self, ids, strip_extraneous=False):
    """Transform a sequence of int ids into a human-readable string.

    EOS is not expected in ids.

    Args:
      ids: list of integers to be converted.
      strip_extraneous: bool, whether to strip off extraneous tokens
        (EOS and PAD).

    Returns:
      s: human-readable string.
    """
    if strip_extraneous:
      ids = strip_ids(ids, list(range(self._num_reserved_ids or 0)))
    return " ".join(self.decode_list(ids))

  def decode_list(self, ids):
    """Transform a sequence of int ids into their string versions.

    Reserved ids map to their reserved-token strings; all other ids are
    shifted down by num_reserved_ids and stringified.

    Args:
      ids: list of integers to be converted.

    Returns:
      strs: list of human-readable strings.
    """
    offset = self._num_reserved_ids
    return [
        RESERVED_TOKENS[int(idx)] if 0 <= idx < offset else str(idx - offset)
        for idx in ids
    ]

  @property
  def vocab_size(self):
    raise NotImplementedError()
| |
|
| |
|
class ByteTextEncoder(TextEncoder):
  """Encodes each byte to an id. For 8-bit strings only."""

  def encode(self, s):
    """Encode the UTF-8 bytes of `s`, shifted past the reserved ids."""
    offset = self._num_reserved_ids
    if six.PY2:
      if isinstance(s, unicode):
        s = s.encode("utf-8")
      return [ord(ch) + offset for ch in s]
    # Python 3: iterating over an encoded string yields ints directly.
    return [b + offset for b in s.encode("utf-8")]

  def decode(self, ids, strip_extraneous=False):
    """Decode ids back into a string, replacing undecodable bytes."""
    if strip_extraneous:
      ids = strip_ids(ids, list(range(self._num_reserved_ids or 0)))
    pieces = self.decode_list(ids)
    if six.PY2:
      return "".join(pieces)
    return b"".join(pieces).decode("utf-8", "replace")

  def decode_list(self, ids):
    """Map each id to its reserved-token bytes or its raw byte value."""
    offset = self._num_reserved_ids
    to_byte = six.int2byte
    pieces = []
    for idx in ids:
      if 0 <= idx < offset:
        pieces.append(RESERVED_TOKENS_BYTES[int(idx)])
      else:
        pieces.append(to_byte(idx - offset))
    return pieces

  @property
  def vocab_size(self):
    # 256 possible byte values plus the reserved slots.
    return 2**8 + self._num_reserved_ids
| |
|
| |
|
class ByteTextEncoderWithEos(ByteTextEncoder):
  """Encodes each byte to an id and appends the EOS token."""

  def encode(self, s):
    """Byte-encode `s`, then append the reserved EOS id."""
    ids = super(ByteTextEncoderWithEos, self).encode(s)
    ids.append(EOS_ID)
    return ids
| |
|
| |
|
class TokenTextEncoder(TextEncoder):
  """Encoder based on a user-supplied vocabulary (file or list)."""

  def __init__(self,
               vocab_filename,
               reverse=False,
               vocab_list=None,
               replace_oov=None,
               num_reserved_ids=NUM_RESERVED_TOKENS):
    """Initialize from a file or list, one token per line.

    Handling of reserved tokens works as follows:
    - When initializing from a list, we add reserved tokens to the vocab.
    - When initializing from a file, we do not add reserved tokens to the vocab.
    - When saving vocab files, we save reserved tokens to the file.

    Args:
      vocab_filename: If not None, the full filename to read vocab from. If this
        is not None, then vocab_list should be None.
      reverse: Boolean indicating if tokens should be reversed during encoding
        and decoding.
      vocab_list: If not None, a list of elements of the vocabulary. If this is
        not None, then vocab_filename should be None.
      replace_oov: If not None, every out-of-vocabulary token seen when
        encoding will be replaced by this string (which must be in vocab).
      num_reserved_ids: Number of IDs to save for reserved tokens like <EOS>.
    """
    super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
    self._reverse = reverse
    self._replace_oov = replace_oov
    if vocab_filename:
      self._init_vocab_from_file(vocab_filename)
    else:
      assert vocab_list is not None
      self._init_vocab_from_list(vocab_list)
    # Cache ids of the special tokens. PAD/EOS/UNK must be in the vocab
    # (KeyError otherwise); SEG is optional and falls back to the EOS id.
    self.pad_index = self._token_to_id[PAD]
    self.eos_index = self._token_to_id[EOS]
    self.unk_index = self._token_to_id[UNK]
    self.seg_index = self._token_to_id.get(SEG, self.eos_index)

  def encode(self, s):
    """Converts a space-separated string of tokens to a list of ids.

    Args:
      s: space-separated string of tokens. Tokens not in the vocab raise
        KeyError unless `replace_oov` was given at construction.

    Returns:
      List of int ids (reversed if `reverse=True`).
    """
    tokens = s.strip().split()
    if self._replace_oov is not None:
      tokens = [t if t in self._token_to_id else self._replace_oov
                for t in tokens]
    ret = [self._token_to_id[tok] for tok in tokens]
    return ret[::-1] if self._reverse else ret

  def decode(self, ids, strip_eos=False, strip_padding=False):
    """Converts ids back to a space-separated token string.

    Note: truncates at the FIRST pad/eos id found (padding first, then EOS),
    unlike the base class, which strips only trailing reserved ids.

    Args:
      ids: sequence of int ids.
      strip_eos: if True, drop everything from the first EOS id onward.
      strip_padding: if True, drop everything from the first PAD id onward.

    Returns:
      Space-joined decoded tokens.
    """
    ids = list(ids)  # convert once; `ids` may be any sequence type
    if strip_padding and self.pad() in ids:
      ids = ids[:ids.index(self.pad())]
    if strip_eos and self.eos() in ids:
      ids = ids[:ids.index(self.eos())]
    return " ".join(self.decode_list(ids))

  def decode_list(self, ids):
    """Converts each id to its token string (reversed if `reverse=True`)."""
    seq = reversed(ids) if self._reverse else ids
    return [self._safe_id_to_token(i) for i in seq]

  @property
  def vocab_size(self):
    return len(self._id_to_token)

  def __len__(self):
    return self.vocab_size

  def _safe_id_to_token(self, idx):
    # Unknown ids decode to a placeholder instead of raising.
    return self._id_to_token.get(idx, "ID_%d" % idx)

  def _init_vocab_from_file(self, filename):
    """Load vocab from a file, one token per line.

    Reserved tokens are NOT added; the file is expected to already
    contain them.

    Args:
      filename: The file to load vocabulary from.
    """
    # NOTE(review): opened with the platform default encoding (Python 2
    # `open` has no encoding kwarg); vocab files are presumably
    # ASCII/UTF-8 -- confirm before changing.
    with open(filename) as f:
      tokens = [line.strip() for line in f]
    self._init_vocab(iter(tokens), add_reserved_tokens=False)

  def _init_vocab_from_list(self, vocab_list):
    """Initialize tokens from a list of tokens.

    It is ok if reserved tokens appear in the vocab list. They will be
    removed. The set of tokens in vocab_list should be unique.

    Args:
      vocab_list: A list of tokens.
    """
    self._init_vocab(
        token for token in vocab_list if token not in RESERVED_TOKENS)

  def _init_vocab(self, token_generator, add_reserved_tokens=True):
    """Build the id<->token maps from token_generator.

    Args:
      token_generator: iterable of vocabulary tokens, assumed unique.
      add_reserved_tokens: if True, ids [0, len(RESERVED_TOKENS)) are
        assigned to the reserved tokens before the generated ones.
    """
    self._id_to_token = {}
    start_index = 0
    if add_reserved_tokens:
      self._id_to_token.update(enumerate(RESERVED_TOKENS))
      start_index = len(RESERVED_TOKENS)
    self._id_to_token.update(
        enumerate(token_generator, start=start_index))
    # Reverse mapping; duplicate tokens would silently keep the last id.
    self._token_to_id = dict(
        (v, k) for k, v in six.iteritems(self._id_to_token))

  def pad(self):
    """Id of the PAD token."""
    return self.pad_index

  def eos(self):
    """Id of the EOS token."""
    return self.eos_index

  def unk(self):
    """Id of the UNK token."""
    return self.unk_index

  def seg(self):
    """Id of the SEG token (EOS id if SEG is not in the vocab)."""
    return self.seg_index

  def store_to_file(self, filename):
    """Write vocab file to disk.

    Vocab files have one token per line. The file ends in a newline. Reserved
    tokens are written to the vocab file as well.

    Args:
      filename: Full path of the file to store the vocab to.
    """
    with open(filename, "w") as f:
      for i in range(len(self._id_to_token)):
        f.write(self._id_to_token[i] + "\n")

  def sil_phonemes(self):
    """Tokens that do not start with a letter (punctuation/silence marks).

    BUG FIX: empty tokens (e.g. from blank vocab-file lines) previously
    raised IndexError on `p[0]`; they are now classified as non-alphabetic.
    """
    return [p for p in self._id_to_token.values()
            if not (p and p[0].isalpha())]
| |
|