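"""Language model utilities for Chainer-based LM training.

This module provides dataset loading with HDF5 caching, token reading and
counting helpers, a sentence-batch iterator, and training extensions.
"""
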
import logging
import os
import random

import chainer
import h5py
import numpy as np
import six
from chainer.training import extension
from tqdm import tqdm


def load_dataset(path, label_dict, outdir=None):
    """Load a text dataset for LM, optionally cached as HDF5 with its stats.

    Args:
        path (str): The path of an input text dataset file.
        label_dict (dict[str, int]):
            Dictionary that maps a token label string to its ID number.
        outdir (str): The path of an output directory for the HDF5 cache.
            If None, no cache is read or written.

    Returns:
        tuple[list[np.ndarray], int, int]: Tuple of
            token ID sequences in np.int32 converted by `read_tokens`,
            the number of tokens counted by `count_tokens`,
            and the number of OOVs counted by `count_tokens`.
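
    Example (a minimal sketch with a hypothetical vocabulary and file):

        >>> label_dict = {"<unk>": 0, "<eos>": 1, "hello": 2}
        >>> data, n_tokens, n_oovs = load_dataset(
        ...     "train.txt", label_dict, outdir="exp/lm_cache"
        ... )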
| """ |
    if outdir is not None:
        os.makedirs(outdir, exist_ok=True)
        filename = os.path.join(outdir, os.path.basename(path) + ".h5")
        if os.path.exists(filename):
            logging.info(f"loading binary dataset: {filename}")
            with h5py.File(filename, "r") as f:
                return f["data"][:], f["n_tokens"][()], f["n_oovs"][()]
    else:
        logging.info("skip dump/load HDF5 because the output dir is not specified")
    logging.info(f"reading text dataset: {path}")
    ret = read_tokens(path, label_dict)
    n_tokens, n_oovs = count_tokens(ret, label_dict["<unk>"])
    if outdir is not None:
        logging.info(f"saving binary dataset: {filename}")
        with h5py.File(filename, "w") as f:
            # store sentences as variable-length int32 sequences
            data = f.create_dataset(
                "data", (len(ret),), dtype=h5py.special_dtype(vlen=np.int32)
            )
            data[:] = ret
            f["n_tokens"] = n_tokens
            f["n_oovs"] = n_oovs
    return ret, n_tokens, n_oovs


def read_tokens(filename, label_dict):
    """Read tokens as a sequence of sentences.

    Args:
        filename (str): The path of the input text file.
        label_dict (dict[str, int]):
            Dictionary that maps a token label string to its ID number.

    Returns:
        list[np.ndarray]: List of token ID sequences, one np.int32 array
            per sentence; tokens missing from label_dict map to <unk>.
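
    Example (a minimal sketch; assumes ``text.txt`` contains "hello world"):

        >>> read_tokens("text.txt", {"<unk>": 0, "hello": 1})
        [array([1, 0], dtype=int32)]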
| """ |
    data = []
    unk = label_dict["<unk>"]
    with open(filename, "r", encoding="utf-8") as f:
        for ln in tqdm(f):
            # map each whitespace-separated token to its ID, or <unk>
            data.append(
                np.array(
                    [label_dict.get(label, unk) for label in ln.split()],
                    dtype=np.int32,
                )
            )
    return data


def count_tokens(data, unk_id=None):
    """Count tokens and OOVs in token ID sequences.

    Args:
        data (list[np.ndarray]): List of token ID sequences.
        unk_id (int): ID of the unknown token. If None, OOVs are not
            counted.

    Returns:
        tuple[int, int]: Tuple of the total number of tokens and the
            number of OOV (unknown) tokens.
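
    Example (a minimal sketch with made-up IDs, where 0 is <unk>):

        >>> count_tokens([np.array([1, 0, 2], dtype=np.int32)], unk_id=0)
        (3, 1)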
| """ |
    n_tokens = 0
    n_oovs = 0
    for sentence in data:
        n_tokens += len(sentence)
        if unk_id is not None:
            n_oovs += np.count_nonzero(sentence == unk_id)
    return n_tokens, n_oovs


def compute_perplexity(result):
    """Compute and add the perplexity to the LogReport.

    Args:
        result (dict): The current observations.
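
    Example (a minimal sketch with made-up observation values):

        >>> result = {"main/loss": 230.0, "main/count": 100}
        >>> compute_perplexity(result)
        >>> round(float(result["perplexity"]), 3)  # exp(230.0 / 100)
        9.974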
| """ |
    # perplexity = exp(average cross-entropy per token)
    result["perplexity"] = np.exp(result["main/loss"] / result["main/count"])
    if "validation/main/loss" in result:
        # the validation loss is assumed to be already averaged per token
        result["val_perplexity"] = np.exp(result["validation/main/loss"])


class ParallelSentenceIterator(chainer.dataset.Iterator):
    """Dataset iterator to create a batch of sentences.

    This iterator returns a list of sentence pairs, where the target is the
    input shifted by one token, like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'.
    Sentence batches are made in descending order of sentence length and
    then randomly shuffled.
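
    Example (a minimal sketch with made-up token IDs):

        >>> data = [np.array([3, 4], dtype=np.int32)]
        >>> it = ParallelSentenceIterator(data, batch_size=1, sos=1, eos=2)
        >>> next(it)
        [(array([1, 3, 4]), array([3, 4, 2]))]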
| """ |

    def __init__(
        self, dataset, batch_size, max_length=0, sos=0, eos=0, repeat=True, shuffle=True
    ):
        self.dataset = dataset
        self.batch_size = batch_size
        # Number of completed sweeps over the dataset. It is incremented
        # when every sentence has been visited at least once after the
        # last increment.
        self.epoch = 0
        # True if the epoch was incremented at the last call of __next__.
        self.is_new_epoch = False
        self.repeat = repeat
        length = len(dataset)
        self.batch_indices = []
        # make mini-batches
        if batch_size > 1:
            # sort sentences by descending length so that each mini-batch
            # contains sentences of similar length
            indices = sorted(range(len(dataset)), key=lambda i: -len(dataset[i]))
            bs = 0
            while bs < length:
                be = min(bs + batch_size, length)
                # the batch size is automatically reduced if the sentence
                # length exceeds max_length
                if max_length > 0:
                    sent_length = len(dataset[indices[bs]])
                    be = min(
                        be, bs + max(batch_size // (sent_length // max_length + 1), 1)
                    )
                self.batch_indices.append(np.array(indices[bs:be]))
                bs = be
            if shuffle:
                # shuffle batches
                random.shuffle(self.batch_indices)
        else:
            self.batch_indices = [np.array([i]) for i in six.moves.range(length)]

        # NOTE: this is not a count of parameter updates; it is just a count
        # of calls of ``__next__``.
        self.iteration = 0
        self.sos = sos
        self.eos = eos
        # use -1 instead of None internally
        self._previous_epoch_detail = -1.0

    def __next__(self):
        # This iterator returns a list representing a mini-batch. Each item
        # is a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
        # represented by token IDs.
        n_batches = len(self.batch_indices)
        if not self.repeat and self.iteration >= n_batches:
            # If not self.repeat, this iterator stops at the end of the
            # first epoch (i.e., when all sentences have been visited once).
            raise StopIteration

        batch = []
        for idx in self.batch_indices[self.iteration % n_batches]:
            batch.append(
                (
                    np.append([self.sos], self.dataset[idx]),
                    np.append(self.dataset[idx], [self.eos]),
                )
            )

        self._previous_epoch_detail = self.epoch_detail
        self.iteration += 1

        epoch = self.iteration // n_batches
        self.is_new_epoch = self.epoch < epoch
        if self.is_new_epoch:
            self.epoch = epoch

        return batch

    def start_shuffle(self):
        """Shuffle the order of mini-batches."""
        random.shuffle(self.batch_indices)

    @property
    def epoch_detail(self):
        # floating-point version of epoch
        return self.iteration / len(self.batch_indices)

    @property
    def previous_epoch_detail(self):
        # -1 means "undefined" (i.e., before the first iteration)
        if self._previous_epoch_detail < 0:
            return None
        return self._previous_epoch_detail

    def serialize(self, serializer):
        # It is important to serialize the state to be recovered on resume.
        self.iteration = serializer("iteration", self.iteration)
        self.epoch = serializer("epoch", self.epoch)
        try:
            self._previous_epoch_detail = serializer(
                "previous_epoch_detail", self._previous_epoch_detail
            )
        except KeyError:
            # guess previous_epoch_detail for snapshots that lack it;
            # epoch_detail is iteration / n_batches, so step one call back
            self._previous_epoch_detail = (self.iteration - 1) / len(
                self.batch_indices
            )
            if self.epoch_detail > 0:
                self._previous_epoch_detail = max(self._previous_epoch_detail, 0.0)
            else:
                self._previous_epoch_detail = -1.0


class MakeSymlinkToBestModel(extension.Extension):
    """Extension that makes a symbolic link to the best model.

    Args:
        key (str): Key of the observed value to monitor (e.g., a loss).
        prefix (str): Prefix of model files and the link target.
        suffix (str): Suffix of the link target.
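
    Example (a minimal sketch; assumes the trainer snapshots models as
    ``model.<epoch>`` under ``trainer.out``):

        >>> trainer.extend(MakeSymlinkToBestModel("validation/main/loss"))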
| """ |

    def __init__(self, key, prefix="model", suffix="best"):
        super(MakeSymlinkToBestModel, self).__init__()
        # epoch number of the best model so far; -1 means "no model yet"
        self.best_model = -1
        self.min_loss = 0.0
        self.key = key
        self.prefix = prefix
        self.suffix = suffix

    def __call__(self, trainer):
        observation = trainer.observation
        if self.key in observation:
            loss = observation[self.key]
            if self.best_model == -1 or loss < self.min_loss:
                self.min_loss = loss
                self.best_model = trainer.updater.epoch
                src = "%s.%d" % (self.prefix, self.best_model)
                dest = os.path.join(trainer.out, "%s.%s" % (self.prefix, self.suffix))
                # replace an existing link so it always points at the best model
                if os.path.lexists(dest):
                    os.remove(dest)
                os.symlink(src, dest)
                logging.info("best model is " + src)

    def serialize(self, serializer):
        if isinstance(serializer, chainer.serializer.Serializer):
            # save the state
            serializer("_best_model", self.best_model)
            serializer("_min_loss", self.min_loss)
            serializer("_key", self.key)
            serializer("_prefix", self.prefix)
            serializer("_suffix", self.suffix)
        else:
            # load the state
            self.best_model = serializer("_best_model", -1)
            self.min_loss = serializer("_min_loss", 0.0)
            self.key = serializer("_key", "")
            self.prefix = serializer("_prefix", "model")
            self.suffix = serializer("_suffix", "best")


def make_lexical_tree(word_dict, subword_dict, word_unk):
    """Make a lexical tree to compute word-level probabilities.

    Each node is a list ``[successors, word_id, range]``, where
    ``successors`` maps a subword ID to the next node, ``word_id`` is the
    word ID if the node ends a word (-1 otherwise), and ``range`` is the
    half-open word ID range covered by the subtree (None at the root).
    """
    root = [{}, -1, None]
    for w, wid in word_dict.items():
        # skip the word with ID 0, the unknown word, and words containing
        # a subword missing from subword_dict
        if wid > 0 and wid != word_unk:
            if any(c not in subword_dict for c in w):
                continue
            succ = root[0]  # successor nodes of the current node
            for i, c in enumerate(w):
                cid = subword_dict[c]
                if cid not in succ:
                    succ[cid] = [{}, -1, (wid - 1, wid)]
                else:
                    # extend the word ID range of an existing node
                    prev = succ[cid][2]
                    succ[cid][2] = (min(prev[0], wid - 1), max(prev[1], wid))
                if i == len(w) - 1:
                    # mark the end of a word
                    succ[cid][1] = wid
                succ = succ[cid][0]
    return root
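

# A minimal usage sketch for make_lexical_tree with hypothetical
# dictionaries (word ID 0 is skipped, so real words start at ID 1):
#
#     word_dict = {"<unk>": 0, "ab": 1, "ac": 2}
#     subword_dict = {"a": 0, "b": 1, "c": 2}
#     root = make_lexical_tree(word_dict, subword_dict, word_dict["<unk>"])
#     node_a = root[0][subword_dict["a"]]
#     assert node_a[2] == (0, 2)  # subtree under "a" covers word IDs 1..2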