# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

import argparse
import copy
import json
import logging
import os
import shutil
import tempfile

import numpy as np
import torch


# * -------------------- training iterator related -------------------- *


class CompareValueTrigger(object):
    """Trigger invoked when the value of a key becomes bigger or smaller than before.

    Args:
        key (str) : Key of value.
        compare_fn ((float, float) -> bool) : Function to compare the values.
        trigger (tuple(int, str)) : Trigger that decides the comparison interval.

    """

    def __init__(self, key, compare_fn, trigger=(1, "epoch")):
        from chainer import training

        self._key = key
        self._best_value = None
        self._interval_trigger = training.util.get_trigger(trigger)
        self._init_summary()
        self._compare_fn = compare_fn

    def __call__(self, trainer):
        """Get value related to the key and compare with current value."""
        observation = trainer.observation
        summary = self._summary
        key = self._key
        if key in observation:
            summary.add({key: observation[key]})

        if not self._interval_trigger(trainer):
            return False

        stats = summary.compute_mean()
        value = float(stats[key])  # copy to CPU
        self._init_summary()

        if self._best_value is None:
            # initialize best value
            self._best_value = value
            return False
        elif self._compare_fn(self._best_value, value):
            return True
        else:
            self._best_value = value
            return False

    def _init_summary(self):
        import chainer

        self._summary = chainer.reporter.DictSummary()
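

# The helper below is an illustrative usage sketch, not part of the original module:
# it shows how CompareValueTrigger is typically paired with an eps-decay extension so
# that the decay fires only when the monitored value stops improving.  The reporter
# key "validation/main/loss" is an assumed example and is not defined in this file.
def _example_compare_value_trigger(trainer):
    """Attach an eps-decay extension gated by CompareValueTrigger (usage sketch)."""
    trainer.extend(
        adadelta_eps_decay(0.5),
        trigger=CompareValueTrigger(
            "validation/main/loss",
            # fire when the current mean loss is worse (larger) than the best so far
            lambda best_value, current_value: best_value < current_value,
        ),
    )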


try:
    from chainer.training import extension
except ImportError:
    PlotAttentionReport = None
else:

    class PlotAttentionReport(extension.Extension):
        """Plot attention reporter.

        Args:
            att_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_attentions):
                Function of attention visualization.
            data (list[tuple(str, dict[str, list[Any]])]): List of json utterance key items.
            outdir (str): Directory to save figures.
            converter (espnet.asr.*_backend.asr.CustomConverter):
                Function to convert data.
            transform (callable): Transform applied to the data before conversion.
            device (int | torch.device): Device.
            reverse (bool): If True, input and output length are reversed.
            ikey (str): Key to access input
                (for ASR/ST ikey="input", for MT ikey="output".)
            iaxis (int): Dimension to access input
                (for ASR/ST iaxis=0, for MT iaxis=1.)
            okey (str): Key to access output
                (for ASR/ST okey="output", for MT okey="output".)
            oaxis (int): Dimension to access output
                (for ASR/ST oaxis=0, for MT oaxis=0.)
            subsampling_factor (int): Subsampling factor in the encoder.

        """

        def __init__(
            self,
            att_vis_fn,
            data,
            outdir,
            converter,
            transform,
            device,
            reverse=False,
            ikey="input",
            iaxis=0,
            okey="output",
            oaxis=0,
            subsampling_factor=1,
        ):
            self.att_vis_fn = att_vis_fn
            self.data = copy.deepcopy(data)
            self.data_dict = {k: v for k, v in copy.deepcopy(data)}
            # key is utterance ID
            self.outdir = outdir
            self.converter = converter
            self.transform = transform
            self.device = device
            self.reverse = reverse
            self.ikey = ikey
            self.iaxis = iaxis
            self.okey = okey
            self.oaxis = oaxis
            self.factor = subsampling_factor
            if not os.path.exists(self.outdir):
                os.makedirs(self.outdir)

        def __call__(self, trainer):
            """Plot and save image file of att_ws matrix."""
            att_ws, uttid_list = self.get_attention_weights()
            if isinstance(att_ws, list):  # multi-encoder case
                num_encs = len(att_ws) - 1
                # atts
                for i in range(num_encs):
                    for idx, att_w in enumerate(att_ws[i]):
                        filename = "%s/%s.ep.{.updater.epoch}.att%d.png" % (
                            self.outdir,
                            uttid_list[idx],
                            i + 1,
                        )
                        att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                        np_filename = "%s/%s.ep.{.updater.epoch}.att%d.npy" % (
                            self.outdir,
                            uttid_list[idx],
                            i + 1,
                        )
                        np.save(np_filename.format(trainer), att_w)
                        self._plot_and_save_attention(att_w, filename.format(trainer))
                # han
                for idx, att_w in enumerate(att_ws[num_encs]):
                    filename = "%s/%s.ep.{.updater.epoch}.han.png" % (
                        self.outdir,
                        uttid_list[idx],
                    )
                    att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                    np_filename = "%s/%s.ep.{.updater.epoch}.han.npy" % (
                        self.outdir,
                        uttid_list[idx],
                    )
                    np.save(np_filename.format(trainer), att_w)
                    self._plot_and_save_attention(
                        att_w, filename.format(trainer), han_mode=True
                    )
            else:
                for idx, att_w in enumerate(att_ws):
                    filename = "%s/%s.ep.{.updater.epoch}.png" % (
                        self.outdir,
                        uttid_list[idx],
                    )
                    att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                    np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
                        self.outdir,
                        uttid_list[idx],
                    )
                    np.save(np_filename.format(trainer), att_w)
                    self._plot_and_save_attention(att_w, filename.format(trainer))

        def log_attentions(self, logger, step):
            """Add image files of att_ws matrix to the tensorboard."""
            att_ws, uttid_list = self.get_attention_weights()
            if isinstance(att_ws, list):  # multi-encoder case
                num_encs = len(att_ws) - 1
                # atts
                for i in range(num_encs):
                    for idx, att_w in enumerate(att_ws[i]):
                        att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                        plot = self.draw_attention_plot(att_w)
                        logger.add_figure(
                            "%s_att%d" % (uttid_list[idx], i + 1),
                            plot.gcf(),
                            step,
                        )
                # han
                for idx, att_w in enumerate(att_ws[num_encs]):
                    att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                    plot = self.draw_han_plot(att_w)
                    logger.add_figure(
                        "%s_han" % (uttid_list[idx]),
                        plot.gcf(),
                        step,
                    )
            else:
                for idx, att_w in enumerate(att_ws):
                    att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                    plot = self.draw_attention_plot(att_w)
                    logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)

        def get_attention_weights(self):
            """Return attention weights.

            Returns:
                numpy.ndarray: Attention weights (float). The shape depends on the
                backend:
                * pytorch-> 1) multi-head case => (B, H, Lmax, Tmax), 2)
                  other case => (B, Lmax, Tmax).
                * chainer-> (B, Lmax, Tmax)

            """
            return_batch, uttid_list = self.transform(self.data, return_uttid=True)
            batch = self.converter([return_batch], self.device)
            if isinstance(batch, tuple):
                att_ws = self.att_vis_fn(*batch)
            else:
                att_ws = self.att_vis_fn(**batch)
            return att_ws, uttid_list

        def trim_attention_weight(self, uttid, att_w):
            """Transform attention matrix with regard to self.reverse."""
            if self.reverse:
                enc_key, enc_axis = self.okey, self.oaxis
                dec_key, dec_axis = self.ikey, self.iaxis
            else:
                enc_key, enc_axis = self.ikey, self.iaxis
                dec_key, dec_axis = self.okey, self.oaxis
            dec_len = int(self.data_dict[uttid][dec_key][dec_axis]["shape"][0])
            enc_len = int(self.data_dict[uttid][enc_key][enc_axis]["shape"][0])
            if self.factor > 1:
                enc_len //= self.factor
            if len(att_w.shape) == 3:
                att_w = att_w[:, :dec_len, :enc_len]
            else:
                att_w = att_w[:dec_len, :enc_len]
            return att_w

        def draw_attention_plot(self, att_w):
            """Plot the att_w matrix.

            Returns:
                matplotlib.pyplot: pyplot object with attention matrix image.

            """
            import matplotlib

            matplotlib.use("Agg")
            import matplotlib.pyplot as plt

            plt.clf()
            att_w = att_w.astype(np.float32)
            if len(att_w.shape) == 3:
                for h, aw in enumerate(att_w, 1):
                    plt.subplot(1, len(att_w), h)
                    plt.imshow(aw, aspect="auto")
                    plt.xlabel("Encoder Index")
                    plt.ylabel("Decoder Index")
            else:
                plt.imshow(att_w, aspect="auto")
                plt.xlabel("Encoder Index")
                plt.ylabel("Decoder Index")
            plt.tight_layout()
            return plt

        def draw_han_plot(self, att_w):
            """Plot the att_w matrix for hierarchical attention.

            Returns:
                matplotlib.pyplot: pyplot object with attention matrix image.

            """
            import matplotlib

            matplotlib.use("Agg")
            import matplotlib.pyplot as plt

            plt.clf()
            if len(att_w.shape) == 3:
                for h, aw in enumerate(att_w, 1):
                    legends = []
                    plt.subplot(1, len(att_w), h)
                    for i in range(aw.shape[1]):
                        plt.plot(aw[:, i])
                        legends.append("Att{}".format(i))
                    plt.ylim([0, 1.0])
                    plt.xlim([0, aw.shape[0]])
                    plt.grid(True)
                    plt.ylabel("Attention Weight")
                    plt.xlabel("Decoder Index")
                    plt.legend(legends)
            else:
                legends = []
                for i in range(att_w.shape[1]):
                    plt.plot(att_w[:, i])
                    legends.append("Att{}".format(i))
                plt.ylim([0, 1.0])
                plt.xlim([0, att_w.shape[0]])
                plt.grid(True)
                plt.ylabel("Attention Weight")
                plt.xlabel("Decoder Index")
                plt.legend(legends)
            plt.tight_layout()
            return plt

        def _plot_and_save_attention(self, att_w, filename, han_mode=False):
            if han_mode:
                plt = self.draw_han_plot(att_w)
            else:
                plt = self.draw_attention_plot(att_w)
            plt.savefig(filename)
            plt.close()


try:
    from chainer.training import extension
except ImportError:
    PlotCTCReport = None
else:

    class PlotCTCReport(extension.Extension):
        """Plot CTC reporter.

        Args:
            ctc_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_ctc_probs):
                Function of CTC visualization.
            data (list[tuple(str, dict[str, list[Any]])]): List of json utterance key items.
            outdir (str): Directory to save figures.
            converter (espnet.asr.*_backend.asr.CustomConverter):
                Function to convert data.
            transform (callable): Transform applied to the data before conversion.
            device (int | torch.device): Device.
            reverse (bool): If True, input and output length are reversed.
            ikey (str): Key to access input
                (for ASR/ST ikey="input", for MT ikey="output".)
            iaxis (int): Dimension to access input
                (for ASR/ST iaxis=0, for MT iaxis=1.)
            okey (str): Key to access output
                (for ASR/ST okey="output", for MT okey="output".)
            oaxis (int): Dimension to access output
                (for ASR/ST oaxis=0, for MT oaxis=0.)
            subsampling_factor (int): Subsampling factor in the encoder.

        """

        def __init__(
            self,
            ctc_vis_fn,
            data,
            outdir,
            converter,
            transform,
            device,
            reverse=False,
            ikey="input",
            iaxis=0,
            okey="output",
            oaxis=0,
            subsampling_factor=1,
        ):
            self.ctc_vis_fn = ctc_vis_fn
            self.data = copy.deepcopy(data)
            self.data_dict = {k: v for k, v in copy.deepcopy(data)}
            # key is utterance ID
            self.outdir = outdir
            self.converter = converter
            self.transform = transform
            self.device = device
            self.reverse = reverse
            self.ikey = ikey
            self.iaxis = iaxis
            self.okey = okey
            self.oaxis = oaxis
            self.factor = subsampling_factor
            if not os.path.exists(self.outdir):
                os.makedirs(self.outdir)

        def __call__(self, trainer):
            """Plot and save image file of ctc prob."""
            ctc_probs, uttid_list = self.get_ctc_probs()
            if isinstance(ctc_probs, list):  # multi-encoder case
                num_encs = len(ctc_probs) - 1
                for i in range(num_encs):
                    for idx, ctc_prob in enumerate(ctc_probs[i]):
                        filename = "%s/%s.ep.{.updater.epoch}.ctc%d.png" % (
                            self.outdir,
                            uttid_list[idx],
                            i + 1,
                        )
                        ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
                        np_filename = "%s/%s.ep.{.updater.epoch}.ctc%d.npy" % (
                            self.outdir,
                            uttid_list[idx],
                            i + 1,
                        )
                        np.save(np_filename.format(trainer), ctc_prob)
                        self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
            else:
                for idx, ctc_prob in enumerate(ctc_probs):
                    filename = "%s/%s.ep.{.updater.epoch}.png" % (
                        self.outdir,
                        uttid_list[idx],
                    )
                    ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
                    np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
                        self.outdir,
                        uttid_list[idx],
                    )
                    np.save(np_filename.format(trainer), ctc_prob)
                    self._plot_and_save_ctc(ctc_prob, filename.format(trainer))

        def log_ctc_probs(self, logger, step):
            """Add image files of ctc probs to the tensorboard."""
            ctc_probs, uttid_list = self.get_ctc_probs()
            if isinstance(ctc_probs, list):  # multi-encoder case
                num_encs = len(ctc_probs) - 1
                for i in range(num_encs):
                    for idx, ctc_prob in enumerate(ctc_probs[i]):
                        ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
                        plot = self.draw_ctc_plot(ctc_prob)
                        logger.add_figure(
                            "%s_ctc%d" % (uttid_list[idx], i + 1),
                            plot.gcf(),
                            step,
                        )
            else:
                for idx, ctc_prob in enumerate(ctc_probs):
                    ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
                    plot = self.draw_ctc_plot(ctc_prob)
                    logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)

        def get_ctc_probs(self):
            """Return CTC probs.

            Returns:
                numpy.ndarray: CTC probs (float) of shape (B, Tmax, vocab); the exact
                type depends on the backend.

            """
            return_batch, uttid_list = self.transform(self.data, return_uttid=True)
            batch = self.converter([return_batch], self.device)
            if isinstance(batch, tuple):
                probs = self.ctc_vis_fn(*batch)
            else:
                probs = self.ctc_vis_fn(**batch)
            return probs, uttid_list

        def trim_ctc_prob(self, uttid, prob):
            """Trim CTC posteriors according to input lengths."""
            enc_len = int(self.data_dict[uttid][self.ikey][self.iaxis]["shape"][0])
            if self.factor > 1:
                enc_len //= self.factor
            prob = prob[:enc_len]
            return prob

        def draw_ctc_plot(self, ctc_prob):
            """Plot the ctc_prob matrix.

            Returns:
                matplotlib.pyplot: pyplot object with CTC prob matrix image.

            """
            import matplotlib

            matplotlib.use("Agg")
            import matplotlib.pyplot as plt

            ctc_prob = ctc_prob.astype(np.float32)

            plt.clf()
            topk_ids = np.argsort(ctc_prob, axis=1)
            n_frames, vocab = ctc_prob.shape
            times_probs = np.arange(n_frames)

            plt.figure(figsize=(20, 8))

            # NOTE: index 0 is reserved for blank
            for idx in set(topk_ids.reshape(-1).tolist()):
                if idx == 0:
                    plt.plot(
                        times_probs, ctc_prob[:, 0], ":", label="<blank>", color="grey"
                    )
                else:
                    plt.plot(times_probs, ctc_prob[:, idx])
            plt.xlabel("Input [frame]", fontsize=12)
            plt.ylabel("Posteriors", fontsize=12)
            plt.xticks(list(range(0, int(n_frames) + 1, 10)))
            plt.yticks(list(range(0, 2, 1)))
            plt.tight_layout()
            return plt

        def _plot_and_save_ctc(self, ctc_prob, filename):
            plt = self.draw_ctc_plot(ctc_prob)
            plt.savefig(filename)
            plt.close()


def restore_snapshot(model, snapshot, load_fn=None):
    """Extension to restore snapshot.

    Returns:
        An extension function.

    """
    import chainer
    from chainer import training

    if load_fn is None:
        load_fn = chainer.serializers.load_npz

    # register as a chainer extension invoked once per epoch
    @training.make_extension(trigger=(1, "epoch"))
    def restore_snapshot(trainer):
        _restore_snapshot(model, snapshot, load_fn)

    return restore_snapshot


def _restore_snapshot(model, snapshot, load_fn=None):
    if load_fn is None:
        import chainer

        load_fn = chainer.serializers.load_npz
    load_fn(snapshot, model)
    logging.info("restored from " + str(snapshot))


def adadelta_eps_decay(eps_decay):
    """Extension to perform adadelta eps decay.

    Args:
        eps_decay (float): Decay rate of eps.

    Returns:
        An extension function.

    """
    from chainer import training

    # register as a chainer extension invoked once per epoch
    @training.make_extension(trigger=(1, "epoch"))
    def adadelta_eps_decay(trainer):
        _adadelta_eps_decay(trainer, eps_decay)

    return adadelta_eps_decay


def _adadelta_eps_decay(trainer, eps_decay):
    optimizer = trainer.updater.get_optimizer("main")
    # for chainer
    if hasattr(optimizer, "eps"):
        current_eps = optimizer.eps
        setattr(optimizer, "eps", current_eps * eps_decay)
        logging.info("adadelta eps decayed to " + str(optimizer.eps))
    # pytorch
    else:
        for p in optimizer.param_groups:
            p["eps"] *= eps_decay
            logging.info("adadelta eps decayed to " + str(p["eps"]))


def adam_lr_decay(eps_decay):
    """Extension to perform adam lr decay.

    Args:
        eps_decay (float): Decay rate of lr.

    Returns:
        An extension function.

    """
    from chainer import training

    # register as a chainer extension invoked once per epoch
    @training.make_extension(trigger=(1, "epoch"))
    def adam_lr_decay(trainer):
        _adam_lr_decay(trainer, eps_decay)

    return adam_lr_decay


def _adam_lr_decay(trainer, eps_decay):
    optimizer = trainer.updater.get_optimizer("main")
    # for chainer
    if hasattr(optimizer, "lr"):
        current_lr = optimizer.lr
        setattr(optimizer, "lr", current_lr * eps_decay)
        logging.info("adam lr decayed to " + str(optimizer.lr))
    # pytorch
    else:
        for p in optimizer.param_groups:
            p["lr"] *= eps_decay
            logging.info("adam lr decayed to " + str(p["lr"]))


def torch_snapshot(savefun=torch.save, filename="snapshot.ep.{.updater.epoch}"):
    """Extension to take snapshot of the trainer for pytorch.

    Returns:
        An extension function.

    """
    from chainer.training import extension

    # register as a chainer extension invoked once per epoch
    @extension.make_extension(trigger=(1, "epoch"), priority=-100)
    def torch_snapshot(trainer):
        _torch_snapshot_object(trainer, trainer, filename.format(trainer), savefun)

    return torch_snapshot


def _torch_snapshot_object(trainer, target, filename, savefun):
    from chainer.serializers import DictionarySerializer

    # make snapshot_dict dictionary
    s = DictionarySerializer()
    s.save(trainer)
    if hasattr(trainer.updater.model, "model"):
        # (for TTS)
        if hasattr(trainer.updater.model.model, "module"):
            model_state_dict = trainer.updater.model.model.module.state_dict()
        else:
            model_state_dict = trainer.updater.model.model.state_dict()
    else:
        # (for ASR)
        if hasattr(trainer.updater.model, "module"):
            model_state_dict = trainer.updater.model.module.state_dict()
        else:
            model_state_dict = trainer.updater.model.state_dict()
    snapshot_dict = {
        "trainer": s.target,
        "model": model_state_dict,
        "optimizer": trainer.updater.get_optimizer("main").state_dict(),
    }

    # save snapshot dictionary
    fn = filename.format(trainer)
    prefix = "tmp" + fn
    tmpdir = tempfile.mkdtemp(prefix=prefix, dir=trainer.out)
    tmppath = os.path.join(tmpdir, fn)
    try:
        savefun(snapshot_dict, tmppath)
        shutil.move(tmppath, os.path.join(trainer.out, fn))
    finally:
        shutil.rmtree(tmpdir)


def add_gradient_noise(model, iteration, duration=100, eta=1.0, scale_factor=0.55):
    """Add noise from a standard normal distribution to the gradients.

    The standard deviation (`sigma`) is controlled by the three hyper-parameters below.
    `sigma` goes to zero (no noise) with more iterations.

    Args:
        model (torch.nn.Module): Model.
        iteration (int): Number of iterations.
        duration (int) {100, 1000}:
            Number of durations to control the interval of the `sigma` change.
        eta (float) {0.01, 0.3, 1.0}: The magnitude of `sigma`.
        scale_factor (float) {0.55}: The scale of `sigma`.

    """
    interval = (iteration // duration) + 1
    sigma = eta / interval**scale_factor
    for param in model.parameters():
        if param.grad is not None:
            _shape = param.grad.size()
            noise = sigma * torch.randn(_shape).to(param.device)
            param.grad += noise
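

# Worked example (values computed from the formula above, not measured): with the
# default hyper-parameters the injected noise anneals towards zero as training runs.
def _example_gradient_noise_sigma():
    """Print how the noise scale decays with the default hyper-parameters (sketch)."""
    duration, eta, scale_factor = 100, 1.0, 0.55
    for iteration in (0, 100, 200, 1000):
        interval = (iteration // duration) + 1
        sigma = eta / interval**scale_factor
        # iteration 0 -> 1.00, 100 -> 0.68, 200 -> 0.55, 1000 -> 0.27 (approximately)
        print(iteration, round(sigma, 2))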


# * -------------------- general -------------------- *


def get_model_conf(model_path, conf_path=None):
    """Get model config information by reading a model config file (model.json).

    Args:
        model_path (str): Model path.
        conf_path (str): Optional model config path.

    Returns:
        tuple[int, int, argparse.Namespace] or argparse.Namespace: Config information
        loaded from the json file; a single Namespace is returned for LM configs.

    """
    if conf_path is None:
        model_conf = os.path.dirname(model_path) + "/model.json"
    else:
        model_conf = conf_path
    with open(model_conf, "rb") as f:
        logging.info("reading a config file from " + model_conf)
        confs = json.load(f)
    if isinstance(confs, dict):
        # for lm
        args = confs
        return argparse.Namespace(**args)
    else:
        # for asr, tts, mt
        idim, odim, args = confs
        return idim, odim, argparse.Namespace(**args)
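

# Usage sketch (illustrative; the path "exp/train/results/model.acc.best" is an assumed
# example): for ASR/TTS/MT configs the helper returns the input size, output size and
# the training arguments as an argparse.Namespace.
def _example_get_model_conf(model_path="exp/train/results/model.acc.best"):
    """Load idim/odim/args stored next to a trained model (usage sketch)."""
    idim, odim, train_args = get_model_conf(model_path)
    logging.info("#input dims: %d, #output dims: %d", idim, odim)
    return idim, odim, train_args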


def chainer_load(path, model):
    """Load chainer model parameters.

    Args:
        path (str): Model path or snapshot file path to be loaded.
        model (chainer.Chain): Chainer model.

    """
    import chainer

    if "snapshot" in os.path.basename(path):
        chainer.serializers.load_npz(path, model, path="updater/model:main/")
    else:
        chainer.serializers.load_npz(path, model)


def torch_save(path, model):
    """Save torch model states.

    Args:
        path (str): Model path to be saved.
        model (torch.nn.Module): Torch model.

    """
    if hasattr(model, "module"):
        torch.save(model.module.state_dict(), path)
    else:
        torch.save(model.state_dict(), path)


def snapshot_object(target, filename):
    """Return a trainer extension to take snapshots of a given object.

    Args:
        target (model): Object to serialize.
        filename (str): Name of the file into which the object is serialized. It can
            be a format string, where the trainer object is passed to
            the :meth:`str.format` method. For example,
            ``'snapshot_{.updater.iteration}'`` is converted to
            ``'snapshot_10000'`` at the 10,000th iteration.

    Returns:
        An extension function.

    """
    from chainer.training import extension

    # register as a chainer extension invoked once per epoch
    @extension.make_extension(trigger=(1, "epoch"), priority=-100)
    def snapshot_object(trainer):
        torch_save(os.path.join(trainer.out, filename.format(trainer)), target)

    return snapshot_object


def torch_load(path, model):
    """Load torch model states.

    Args:
        path (str): Model path or snapshot file path to be loaded.
        model (torch.nn.Module): Torch model.

    """
    if "snapshot" in os.path.basename(path):
        model_state_dict = torch.load(path, map_location=lambda storage, loc: storage)[
            "model"
        ]
    else:
        model_state_dict = torch.load(path, map_location=lambda storage, loc: storage)

    if hasattr(model, "module"):
        model.module.load_state_dict(model_state_dict)
    else:
        model.load_state_dict(model_state_dict)

    del model_state_dict
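

# Usage sketch (illustrative file name, not from this module): torch_save stores only
# the state_dict, so the matching model object has to be constructed before torch_load.
def _example_torch_save_load(model, path="exp/train/results/model.loss.best"):
    """Round-trip model parameters through torch_save/torch_load (usage sketch)."""
    torch_save(path, model)  # writes model.state_dict() (unwraps DataParallel first)
    torch_load(path, model)  # loads the parameters back into the same architecture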


def torch_resume(snapshot_path, trainer):
    """Resume from snapshot for pytorch.

    Args:
        snapshot_path (str): Snapshot file path.
        trainer (chainer.training.Trainer): Chainer's trainer instance.

    """
    from chainer.serializers import NpzDeserializer

    # load snapshot
    snapshot_dict = torch.load(snapshot_path, map_location=lambda storage, loc: storage)

    # restore trainer states
    d = NpzDeserializer(snapshot_dict["trainer"])
    d.load(trainer)

    # restore model states
    if hasattr(trainer.updater.model, "model"):
        # (for TTS model)
        if hasattr(trainer.updater.model.model, "module"):
            trainer.updater.model.model.module.load_state_dict(snapshot_dict["model"])
        else:
            trainer.updater.model.model.load_state_dict(snapshot_dict["model"])
    else:
        # (for ASR model)
        if hasattr(trainer.updater.model, "module"):
            trainer.updater.model.module.load_state_dict(snapshot_dict["model"])
        else:
            trainer.updater.model.load_state_dict(snapshot_dict["model"])

    # restore optimizer states
    trainer.updater.get_optimizer("main").load_state_dict(snapshot_dict["optimizer"])

    # delete opened snapshot
    del snapshot_dict


# * ------------------ recognition related ------------------ *


def parse_hypothesis(hyp, char_list):
    """Parse hypothesis.

    Args:
        hyp (dict[str, Any]): Recognition hypothesis.
        char_list (list[str]): List of characters.

    Returns:
        tuple(str, str, str, float)

    """
    # remove sos and get results
    tokenid_as_list = list(map(int, hyp["yseq"][1:]))
    token_as_list = [char_list[idx] for idx in tokenid_as_list]
    score = float(hyp["score"])

    # convert to string
    tokenid = " ".join([str(idx) for idx in tokenid_as_list])
    token = " ".join(token_as_list)
    text = "".join(token_as_list).replace("<space>", " ")

    return text, token, tokenid, score


def add_results_to_json(nbest_hyps, char_list):
    """Return the 1-best recognition result as text.

    Args:
        nbest_hyps (list[dict[str, Any]]): List of hypotheses (only 1-best supported).
        char_list (list[str]): List of characters.

    Returns:
        str: 1-best result.

    """
    assert len(nbest_hyps) == 1, "only 1-best result is supported."

    # parse hypothesis
    rec_text, rec_token, rec_tokenid, score = parse_hypothesis(nbest_hyps[0], char_list)

    return rec_text
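

# Illustrative sketch (the toy hypothesis below is hand-made, not produced by this
# module): a decoder hypothesis carries the token-id sequence "yseq" (with <sos> at
# index 0) and a score; parse_hypothesis strips <sos> and maps ids through char_list.
def _example_parse_hypothesis():
    """Parse a toy 1-best hypothesis (usage sketch)."""
    char_list = ["<blank>", "<space>", "a", "b", "<eos>"]
    hyp = {"yseq": [4, 2, 1, 3], "score": -1.23}  # yseq[0] is the <sos>/<eos> id
    text, token, tokenid, score = parse_hypothesis(hyp, char_list)
    # text == "a b", token == "a <space> b", tokenid == "2 1 3", score == -1.23
    return text, token, tokenid, score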


def plot_spectrogram(
    plt,
    spec,
    mode="db",
    fs=None,
    frame_shift=None,
    bottom=True,
    left=True,
    right=True,
    top=False,
    labelbottom=True,
    labelleft=True,
    labelright=True,
    labeltop=False,
    cmap="inferno",
):
    """Plot spectrogram using matplotlib.

    Args:
        plt (matplotlib.pyplot): pyplot object.
        spec (numpy.ndarray): Input stft (Freq, Time)
        mode (str): db or linear.
        fs (int): Sample frequency. To convert y-axis to kHz unit.
        frame_shift (int): The frame shift of stft. To convert x-axis to second unit.
        bottom (bool): Whether to draw the respective ticks.
        left (bool):
        right (bool):
        top (bool):
        labelbottom (bool): Whether to draw the respective tick labels.
        labelleft (bool):
        labelright (bool):
        labeltop (bool):
        cmap (str): Colormap defined in matplotlib.

    """
    spec = np.abs(spec)
    if mode == "db":
        x = 20 * np.log10(spec + np.finfo(spec.dtype).eps)
    elif mode == "linear":
        x = spec
    else:
        raise ValueError(mode)

    if fs is not None:
        ytop = fs / 2000
        ylabel = "kHz"
    else:
        ytop = x.shape[0]
        ylabel = "bin"

    if frame_shift is not None and fs is not None:
        xtop = x.shape[1] * frame_shift / fs
        xlabel = "s"
    else:
        xtop = x.shape[1]
        xlabel = "frame"

    extent = (0, xtop, 0, ytop)
    plt.imshow(x[::-1], cmap=cmap, extent=extent)

    if labelbottom:
        plt.xlabel("time [{}]".format(xlabel))
    if labelleft:
        plt.ylabel("freq [{}]".format(ylabel))
    plt.colorbar().set_label("{}".format(mode))

    plt.tick_params(
        bottom=bottom,
        left=left,
        right=right,
        top=top,
        labelbottom=labelbottom,
        labelleft=labelleft,
        labelright=labelright,
        labeltop=labeltop,
    )
    plt.axis("auto")


# * ------------------ multi-encoder related ------------------ *


def format_mulenc_args(args):
    """Format args for multi-encoder setup.

    It deals with the following situations (when args.num_encs=2):
    1. args.elayers = None -> args.elayers = [4, 4];
    2. args.elayers = 4 -> args.elayers = [4, 4];
    3. args.elayers = [4, 4, 4] -> args.elayers = [4, 4].

    """
    # default values when None is assigned.
    default_dict = {
        "etype": "blstmp",
        "elayers": 4,
        "eunits": 300,
        "subsample": "1",
        "dropout_rate": 0.0,
        "atype": "dot",
        "adim": 320,
        "awin": 5,
        "aheads": 4,
        "aconv_chans": -1,
        "aconv_filts": 100,
    }
    for k in default_dict.keys():
        if isinstance(vars(args)[k], list):
            if len(vars(args)[k]) != args.num_encs:
                logging.warning(
                    "Length mismatch {}: Convert {} to {}.".format(
                        k, vars(args)[k], vars(args)[k][: args.num_encs]
                    )
                )
                vars(args)[k] = vars(args)[k][: args.num_encs]
        else:
            if not vars(args)[k]:
                # assign default value if it is None
                vars(args)[k] = default_dict[k]
                logging.warning(
                    "{} is not specified, use default value {}.".format(
                        k, default_dict[k]
                    )
                )
            # duplicate the scalar value for every encoder
            logging.warning(
                "Type mismatch {}: Convert {} to {}.".format(
                    k, vars(args)[k], [vars(args)[k] for _ in range(args.num_encs)]
                )
            )
            vars(args)[k] = [vars(args)[k] for _ in range(args.num_encs)]
    return args
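

# Usage sketch (assumed argparse setup; the option names mirror the keys handled
# above): with num_encs=2, scalar or missing encoder/attention options are expanded
# to one entry per encoder.
def _example_format_mulenc_args():
    """Expand single-encoder options to a two-encoder setup (usage sketch)."""
    args = argparse.Namespace(
        num_encs=2,
        etype="blstmp",
        elayers=4,
        eunits=300,
        subsample="1",
        dropout_rate=0.0,
        atype="dot",
        adim=320,
        awin=5,
        aheads=4,
        aconv_chans=-1,
        aconv_filts=100,
    )
    args = format_mulenc_args(args)
    # e.g. args.elayers == [4, 4] and args.etype == ["blstmp", "blstmp"]
    return args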