| import re
|
| import six
|
| from six.moves import range
|
|
|
# Special vocabulary symbols used throughout this module.
PAD = "<pad>"
EOS = "<EOS>"
UNK = "<UNK>"
SEG = "|"

# Order matters: a token's position in this list is its reserved id.
RESERVED_TOKENS = [PAD, EOS, UNK]
NUM_RESERVED_TOKENS = len(RESERVED_TOKENS)

# Reserved ids are looked up from the list so they always follow its order.
PAD_ID, EOS_ID, UNK_ID = (
    RESERVED_TOKENS.index(token) for token in (PAD, EOS, UNK)
)
|
|
|
# Byte-string forms of the reserved tokens, indexed by reserved id.
# On Python 2 `str` is already a byte string, so the list is reused as-is.
# On Python 3 every entry of RESERVED_TOKENS is converted, so the two lists
# always have the same length; a hand-written list that omitted a token would
# make ByteTextEncoder.decode raise IndexError for that token's id.
if six.PY2:
    RESERVED_TOKENS_BYTES = RESERVED_TOKENS
else:
    RESERVED_TOKENS_BYTES = [bytes(token, "ascii") for token in RESERVED_TOKENS]
|
|
|
|
|
|
|
|
|
|
|
# Matches one of three escape sequences: a backslash followed by "u", a
# doubled backslash, or a backslash followed by decimal digits and a ";"
# (the digits are captured in group 1).
_UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);")

# The individual characters that appear in the escape sequences above:
# backslash, underscore, "u", ";" and the ten decimal digits.
_ESCAPE_CHARS = set(u"\\_u;0123456789")
|
|
|
|
|
def strip_ids(ids, ids_to_strip):
    """Return a list copy of `ids` without its trailing `ids_to_strip` members.

    Only the tail is trimmed: scanning stops at the first element (from the
    end) that is not contained in `ids_to_strip`.

    Args:
      ids: iterable of ids.
      ids_to_strip: container of ids that should be removed from the end.

    Returns:
      A new list; the input is not modified.
    """
    kept = list(ids)
    end = len(kept)
    # Walk backwards to find where the strippable tail begins.
    while end > 0 and kept[end - 1] in ids_to_strip:
        end -= 1
    return kept[:end]
|
|
|
|
|
class TextEncoder(object):
    """Base encoder that maps between integer ids and readable strings.

    This base implementation treats the text as whitespace-separated integers
    and shifts them by the number of reserved ids.
    """

    def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS):
        """Store how many low ids are set aside for reserved tokens."""
        self._num_reserved_ids = num_reserved_ids

    @property
    def num_reserved_ids(self):
        """Number of ids, starting at 0, that are reserved."""
        return self._num_reserved_ids

    def encode(self, s):
        """Convert a string of whitespace-separated integers into ids.

        Each integer is shifted up by `num_reserved_ids`, so ids
        [0, num_reserved_ids) stay reserved. EOS is not appended.

        Args:
          s: human-readable string to be converted.

        Returns:
          ids: list of integers.
        """
        offset = self._num_reserved_ids
        ids = []
        for word in s.split():
            ids.append(int(word) + offset)
        return ids

    def decode(self, ids, strip_extraneous=False):
        """Convert a sequence of ids into a single space-joined string.

        Args:
          ids: list of integers to be converted.
          strip_extraneous: bool; when true, trailing ids that fall in the
            reserved range [0, num_reserved_ids) are removed before decoding.

        Returns:
          s: human-readable string.
        """
        if strip_extraneous:
            reserved_ids = list(range(self._num_reserved_ids or 0))
            ids = strip_ids(ids, reserved_ids)
        pieces = self.decode_list(ids)
        return " ".join(pieces)

    def decode_list(self, ids):
        """Convert each id into its own string.

        Ids inside the reserved range become the matching entry of
        RESERVED_TOKENS; every other id is shifted down by `num_reserved_ids`
        and rendered with `str`.

        Args:
          ids: list of integers to be converted.

        Returns:
          strs: list of human-readable strings, one per id.
        """
        strs = []
        for id_ in ids:
            if 0 <= id_ < self._num_reserved_ids:
                strs.append(str(RESERVED_TOKENS[int(id_)]))
            else:
                strs.append(str(id_ - self._num_reserved_ids))
        return strs

    @property
    def vocab_size(self):
        """Size of the vocabulary; not implemented in the base class."""
        raise NotImplementedError()
|
|
|
|
|
class ByteTextEncoder(TextEncoder):
    """Byte-level encoder: every byte becomes one id. For 8-bit strings only."""

    def encode(self, s):
        """Return one id per byte of `s`, shifted up by the reserved-id count.

        On Python 3 `s` is encoded as UTF-8 first. On Python 2 a `unicode`
        value is encoded as UTF-8 and a byte string is used as-is.
        """
        offset = self._num_reserved_ids
        if not six.PY2:
            # Iterating bytes on Python 3 already yields integers.
            return [byte + offset for byte in s.encode("utf-8")]
        if isinstance(s, unicode):
            s = s.encode("utf-8")
        return [ord(char) + offset for char in s]

    def decode(self, ids, strip_extraneous=False):
        """Turn ids back into a string.

        Args:
          ids: sequence of integer ids.
          strip_extraneous: bool; when true, trailing ids in the reserved range
            [0, num_reserved_ids) are removed before decoding.

        Returns:
          The joined byte pieces; on Python 3 they are decoded as UTF-8 with
          invalid sequences replaced.
        """
        if strip_extraneous:
            reserved_ids = list(range(self._num_reserved_ids or 0))
            ids = strip_ids(ids, reserved_ids)
        offset = self._num_reserved_ids
        to_byte = six.int2byte
        # Reserved ids map to their token bytes, all others to a single byte.
        chunks = [
            RESERVED_TOKENS_BYTES[int(id_)] if 0 <= id_ < offset
            else to_byte(id_ - offset)
            for id_ in ids
        ]
        if not six.PY2:
            return b"".join(chunks).decode("utf-8", "replace")
        return "".join(chunks)

    def decode_list(self, ids):
        """Return the byte-string piece for each id, without joining them."""
        offset = self._num_reserved_ids
        to_byte = six.int2byte
        return [
            RESERVED_TOKENS_BYTES[int(id_)] if 0 <= id_ < offset
            else to_byte(id_ - offset)
            for id_ in ids
        ]

    @property
    def vocab_size(self):
        """256 possible byte values plus the reserved ids."""
        return 2**8 + self._num_reserved_ids
|
|
|
|
|
class ByteTextEncoderWithEos(ByteTextEncoder):
    """Byte-level encoder whose `encode` output always ends with EOS_ID."""

    def encode(self, s):
        """Encode `s` byte by byte and return the ids followed by EOS_ID."""
        byte_ids = super(ByteTextEncoderWithEos, self).encode(s)
        return byte_ids + [EOS_ID]
|
|
|
|
|
class TokenTextEncoder(TextEncoder):
    """Encoder backed by an explicit vocabulary supplied as a file or a list."""

    def __init__(self,
                 vocab_filename,
                 reverse=False,
                 vocab_list=None,
                 replace_oov=None,
                 num_reserved_ids=NUM_RESERVED_TOKENS):
        """Build the vocabulary from a file (one token per line) or a list.

        Reserved tokens are handled as follows:
        - When initializing from a list, reserved tokens are added to the vocab.
        - When initializing from a file, reserved tokens are not added; the
          file is used exactly as written.
        - When saving vocab files, reserved tokens are written to the file.

        After loading, the ids of PAD, EOS and UNK are looked up directly, so
        all three must be present in the vocabulary. The id of SEG falls back
        to the EOS id when SEG is not in the vocabulary.

        Args:
          vocab_filename: If truthy, the full filename to read vocab from. In
            that case vocab_list is ignored.
          reverse: Boolean indicating if tokens should be reversed during
            encoding and decoding.
          vocab_list: List of vocabulary elements; required (must not be None)
            when vocab_filename is falsy.
          replace_oov: If not None, every out-of-vocabulary token seen when
            encoding will be replaced by this string (which must be in vocab).
          num_reserved_ids: Number of IDs to save for reserved tokens like
            <EOS>.
        """
        super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
        self._reverse = reverse
        self._replace_oov = replace_oov
        if not vocab_filename:
            assert vocab_list is not None
            self._init_vocab_from_list(vocab_list)
        else:
            self._init_vocab_from_file(vocab_filename)
        token_to_id = self._token_to_id
        self.pad_index = token_to_id[PAD]
        self.eos_index = token_to_id[EOS]
        self.unk_index = token_to_id[UNK]
        self.seg_index = token_to_id.get(SEG, self.eos_index)

    def encode(self, s):
        """Map a whitespace-separated string of tokens to a list of ids.

        Unknown tokens are swapped for `replace_oov` when it was configured;
        otherwise looking them up raises KeyError. The result is reversed when
        the encoder was created with reverse=True.
        """
        words = s.strip().split()
        lookup = self._token_to_id
        oov = self._replace_oov
        if oov is not None:
            words = [word if word in lookup else oov for word in words]
        ids = [lookup[word] for word in words]
        if self._reverse:
            return ids[::-1]
        return ids

    def decode(self, ids, strip_eos=False, strip_padding=False):
        """Map ids back to a space-joined string of tokens.

        Args:
          ids: sequence of integer ids.
          strip_eos: when true, drop the first EOS id and everything after it.
          strip_padding: when true, drop the first PAD id and everything after
            it. Padding is trimmed before EOS.

        Returns:
          The decoded tokens joined by single spaces.
        """
        if strip_padding:
            pad_id = self.pad()
            as_list = list(ids)
            if pad_id in as_list:
                ids = ids[:as_list.index(pad_id)]
        if strip_eos:
            eos_id = self.eos()
            as_list = list(ids)
            if eos_id in as_list:
                ids = ids[:as_list.index(eos_id)]
        pieces = self.decode_list(ids)
        return " ".join(pieces)

    def decode_list(self, ids):
        """Return the token for each id, in reversed order if reverse=True."""
        ordered = ids if not self._reverse else reversed(ids)
        return list(map(self._safe_id_to_token, ordered))

    @property
    def vocab_size(self):
        """Number of entries in the id-to-token table."""
        return len(self._id_to_token)

    def __len__(self):
        """Same value as `vocab_size`."""
        return self.vocab_size

    def _safe_id_to_token(self, idx):
        """Return the token for `idx`, or the placeholder "ID_<idx>" if unknown."""
        placeholder = "ID_%d" % idx
        return self._id_to_token.get(idx, placeholder)

    def _init_vocab_from_file(self, filename):
        """Load the vocabulary from a file with one token per line.

        Each line is stripped of surrounding whitespace. Reserved tokens are
        not added here.

        Args:
          filename: The file to load vocabulary from.
        """
        with open(filename) as f:
            tokens = [line.strip() for line in f]
        self._init_vocab(iter(tokens), add_reserved_tokens=False)

    def _init_vocab_from_list(self, vocab_list):
        """Initialize the vocabulary from a list of tokens.

        Reserved tokens may appear in vocab_list; they are filtered out here
        and added back at the front by `_init_vocab`. The tokens in vocab_list
        should be unique.

        Args:
          vocab_list: A list of tokens.
        """
        self._init_vocab(
            token for token in vocab_list if token not in RESERVED_TOKENS)

    def _init_vocab(self, token_generator, add_reserved_tokens=True):
        """Build the id<->token tables from the tokens in token_generator.

        When add_reserved_tokens is true, RESERVED_TOKENS occupy the first ids
        and the generated tokens are numbered after them; otherwise numbering
        starts at 0.
        """
        id_to_token = {}
        first_free_id = 0
        if add_reserved_tokens:
            id_to_token.update(enumerate(RESERVED_TOKENS))
            first_free_id = len(RESERVED_TOKENS)
        for token_id, token in enumerate(token_generator, start=first_free_id):
            id_to_token[token_id] = token
        self._id_to_token = id_to_token
        # Inverse table for encoding.
        self._token_to_id = {
            token: token_id for token_id, token in six.iteritems(id_to_token)
        }

    def pad(self):
        """Id of the PAD token."""
        return self.pad_index

    def eos(self):
        """Id of the EOS token."""
        return self.eos_index

    def unk(self):
        """Id of the UNK token."""
        return self.unk_index

    def seg(self):
        """Id of the SEG token, or the EOS id when SEG is not in the vocab."""
        return self.seg_index

    def store_to_file(self, filename):
        """Write the vocabulary to disk.

        Vocab files have one token per line, written in id order, and the file
        ends in a newline. Reserved tokens held in the vocabulary are written
        as well.

        Args:
          filename: Full path of the file to store the vocab to.
        """
        table = self._id_to_token
        with open(filename, "w") as f:
            f.writelines(table[i] + "\n" for i in range(len(table)))

    def sil_phonemes(self):
        """Return every vocabulary token whose first character is not a letter.

        Tokens are inspected via their first character, so an empty token in
        the vocabulary raises IndexError.
        """
        non_alpha = []
        for token in self._id_to_token.values():
            if token[0].isalpha():
                continue
            non_alpha.append(token)
        return non_alpha
|
|
|