| """Training/decoding definition for the text translation task.""" |
|
|
| import json |
| import logging |
| import os |
| import sys |
|
|
| from chainer import training |
| from chainer.training import extensions |
| import numpy as np |
| from tensorboardX import SummaryWriter |
| import torch |
|
|
| from espnet.asr.asr_utils import adadelta_eps_decay |
| from espnet.asr.asr_utils import adam_lr_decay |
| from espnet.asr.asr_utils import add_results_to_json |
| from espnet.asr.asr_utils import CompareValueTrigger |
| from espnet.asr.asr_utils import restore_snapshot |
| from espnet.asr.asr_utils import snapshot_object |
| from espnet.asr.asr_utils import torch_load |
| from espnet.asr.asr_utils import torch_resume |
| from espnet.asr.asr_utils import torch_snapshot |
| from espnet.nets.mt_interface import MTInterface |
| from espnet.nets.pytorch_backend.e2e_asr import pad_list |
| from espnet.utils.dataset import ChainerDataLoader |
| from espnet.utils.dataset import TransformDataset |
| from espnet.utils.deterministic_utils import set_deterministic_pytorch |
| from espnet.utils.dynamic_import import dynamic_import |
| from espnet.utils.io_utils import LoadInputsAndTargets |
| from espnet.utils.training.batchfy import make_batchset |
| from espnet.utils.training.iterators import ShufflingEnabler |
| from espnet.utils.training.tensorboard_logger import TensorboardLogger |
| from espnet.utils.training.train_utils import check_early_stop |
| from espnet.utils.training.train_utils import set_early_stop |
|
|
| from espnet.asr.pytorch_backend.asr import CustomEvaluator |
| from espnet.asr.pytorch_backend.asr import CustomUpdater |
| from espnet.asr.pytorch_backend.asr import load_trained_model |
|
|
| import matplotlib |
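
# use a non-interactive matplotlib backend so plots (loss, accuracy, attention)
# can be saved to files on machines without a display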
matplotlib.use("Agg")

if sys.version_info[0] == 2:
    from itertools import izip_longest as zip_longest
else:
    from itertools import zip_longest


class CustomConverter(object):
    """Custom batch converter for PyTorch."""

    def __init__(self):
        """Construct a CustomConverter object."""
        self.ignore_id = -1
        self.pad = 0

    def __call__(self, batch, device=torch.device("cpu")):
        """Transform a batch and send it to a device.

        Args:
            batch (list): The batch to transform.
            device (torch.device): The device to send to.

        Returns:
            tuple(torch.Tensor, torch.Tensor, torch.Tensor)

        """
        assert len(batch) == 1
        xs, ys = batch[0]
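
        # get the lengths of the input token sequences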
        ilens = np.array([x.shape[0] for x in xs])
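
        # pad the sequences and move everything to the target device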
        xs_pad = pad_list([torch.from_numpy(x).long() for x in xs], self.pad).to(device)
        ilens = torch.from_numpy(ilens).to(device)
        ys_pad = pad_list([torch.from_numpy(y).long() for y in ys], self.ignore_id).to(
            device
        )

        return xs_pad, ilens, ys_pad
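
# Usage sketch (hypothetical token ids and shapes): the converter consumes one
# pre-collated (xs, ys) minibatch of numpy arrays and returns padded tensors:
#   converter = CustomConverter()
#   xs = [np.array([4, 5, 6]), np.array([7, 8])]
#   ys = [np.array([9, 10]), np.array([11])]
#   xs_pad, ilens, ys_pad = converter([(xs, ys)])
#   # xs_pad: LongTensor of shape (2, 3) padded with 0; ilens: tensor([3, 2]);
#   # ys_pad: LongTensor of shape (2, 2) padded with -1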


def train(args):
    """Train with the given args.

    Args:
        args (namespace): The program arguments.

    """
    set_deterministic_pytorch(args)
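
    # check cuda availability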
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]["output"][1]["shape"][1])
    odim = int(valid_json[utts[0]]["output"][0]["shape"][1])
    logging.info("#input dims : " + str(idim))
    logging.info("#output dims: " + str(odim))
    model_class = dynamic_import(args.model_module)
    model = model_class(idim, odim, args)
    assert isinstance(model, MTInterface)
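
    # write the model config to the output directory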
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps(
                (idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    reporter = model.reporter
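
    # scale the batch size when training data-parallel over multiple GPUs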
    if args.ngpu > 1:
        if args.batch_size != 0:
            logging.warning(
                "batch size is automatically increased (%d -> %d)"
                % (args.batch_size, args.batch_size * args.ngpu)
            )
            args.batch_size *= args.ngpu
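
    # set torch device and training dtype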
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        dtype = torch.float32
    model = model.to(device=device, dtype=dtype)

    logging.warning(
        "num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format(
            sum(p.numel() for p in model.parameters()),
            sum(p.numel() for p in model.parameters() if p.requires_grad),
            sum(p.numel() for p in model.parameters() if p.requires_grad)
            * 100.0
            / sum(p.numel() for p in model.parameters()),
        )
    )
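
    # Setup an optimizer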
    if args.opt == "adadelta":
        optimizer = torch.optim.Adadelta(
            model.parameters(), rho=0.95, eps=args.eps, weight_decay=args.weight_decay
        )
    elif args.opt == "adam":
        optimizer = torch.optim.Adam(
            model.parameters(), lr=args.lr, weight_decay=args.weight_decay
        )
    elif args.opt == "noam":
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

        optimizer = get_std_opt(
            model.parameters(),
            args.adim,
            args.transformer_warmup_steps,
            args.transformer_lr,
        )
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)
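
    # setup apex.amp for mixed-precision training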
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError as e:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux"
            )
            raise e
        if args.opt == "noam":
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=args.train_dtype
            )
        else:
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=args.train_dtype
            )
        use_apex = True
    else:
        use_apex = False
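
    # attach the reporter to the optimizer so the Chainer trainer can treat it
    # as a serializable target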
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))
    converter = CustomConverter()
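
    # read json data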
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
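    # make variable-length minibatch sets for training and validation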
    train = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        mt=True,
        iaxis=1,
        oaxis=0,
    )
    valid = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        mt=True,
        iaxis=1,
        oaxis=0,
    )

    load_tr = LoadInputsAndTargets(mode="mt", load_output=True)
    load_cv = LoadInputsAndTargets(mode="mt", load_output=True)
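    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    # default collate function converts numpy array to pytorch tensor
    # we use an empty collate function instead which returns a list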
    train_iter = ChainerDataLoader(
        dataset=TransformDataset(train, lambda data: converter([load_tr(data)])),
        batch_size=1,
        num_workers=args.n_iter_processes,
        shuffle=not use_sortagrad,
        collate_fn=lambda x: x[0],
    )
    valid_iter = ChainerDataLoader(
        dataset=TransformDataset(valid, lambda data: converter([load_cv(data)])),
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: x[0],
        num_workers=args.n_iter_processes,
    )
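
    # Set up a trainer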
    updater = CustomUpdater(
        model,
        args.grad_clip,
        {"main": train_iter},
        optimizer,
        device,
        args.ngpu,
        False,
        args.accum_grad,
        use_apex=use_apex,
    )
    trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
        )
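
    # Resume from a snapshot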
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)
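
    # evaluate on the validation set every save_interval_iters iterations,
    # or once per epoch by default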
    if args.save_interval_iters > 0:
        trainer.extend(
            CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(
            CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu)
        )
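
    # Save attention weights each epoch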
    if args.num_save_attention > 0:
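        # sort utterances by output length, longest first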
        data = sorted(
            list(valid_json.items())[: args.num_save_attention],
            key=lambda x: int(x[1]["output"][0]["shape"][0]),
            reverse=True,
        )
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(
            att_vis_fn,
            data,
            args.outdir + "/att_ws",
            converter=converter,
            transform=load_cv,
            device=device,
            ikey="output",
            iaxis=1,
        )
        trainer.extend(att_reporter, trigger=(1, "epoch"))
    else:
        att_reporter = None
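
    # Make plots of training and validation metrics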
    trainer.extend(
        extensions.PlotReport(
            ["main/loss", "validation/main/loss"], "epoch", file_name="loss.png"
        )
    )
    trainer.extend(
        extensions.PlotReport(
            ["main/acc", "validation/main/acc"], "epoch", file_name="acc.png"
        )
    )
    trainer.extend(
        extensions.PlotReport(
            ["main/ppl", "validation/main/ppl"], "epoch", file_name="ppl.png"
        )
    )
    trainer.extend(
        extensions.PlotReport(
            ["main/bleu", "validation/main/bleu"], "epoch", file_name="bleu.png"
        )
    )
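
    # Save the models with the best validation loss and accuracy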
    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger("validation/main/loss"),
    )
    trainer.extend(
        snapshot_object(model, "model.acc.best"),
        trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
    )
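
    # save a snapshot (model, optimizer, and trainer state) for resuming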
    if args.save_interval_iters > 0:
        trainer.extend(
            torch_snapshot(filename="snapshot.iter.{.updater.iteration}"),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(torch_snapshot(), trigger=(1, "epoch"))
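
    # when the validation criterion stops improving, restore the best model
    # and decay adadelta's eps or adam's learning rate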
    if args.opt == "adadelta":
        if args.criterion == "acc":
            trainer.extend(
                restore_snapshot(
                    model, args.outdir + "/model.acc.best", load_fn=torch_load
                ),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
        elif args.criterion == "loss":
            trainer.extend(
                restore_snapshot(
                    model, args.outdir + "/model.loss.best", load_fn=torch_load
                ),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )
    elif args.opt == "adam":
        if args.criterion == "acc":
            trainer.extend(
                restore_snapshot(
                    model, args.outdir + "/model.acc.best", load_fn=torch_load
                ),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
            trainer.extend(
                adam_lr_decay(args.lr_decay),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
        elif args.criterion == "loss":
            trainer.extend(
                restore_snapshot(
                    model, args.outdir + "/model.loss.best", load_fn=torch_load
                ),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )
            trainer.extend(
                adam_lr_decay(args.lr_decay),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )
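
    # Write a log of evaluation statistics at each reporting interval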
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))
    )
    report_keys = [
        "epoch",
        "iteration",
        "main/loss",
        "validation/main/loss",
        "main/acc",
        "validation/main/acc",
        "main/ppl",
        "validation/main/ppl",
        "elapsed_time",
    ]
    if args.opt == "adadelta":
        trainer.extend(
            extensions.observe_value(
                "eps",
                lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][
                    "eps"
                ],
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("eps")
    elif args.opt in ["adam", "noam"]:
        trainer.extend(
            extensions.observe_value(
                "lr",
                lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][
                    "lr"
                ],
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("lr")
    if args.report_bleu:
        report_keys.append("main/bleu")
        report_keys.append("validation/main/bleu")
    trainer.extend(
        extensions.PrintReport(report_keys),
        trigger=(args.report_interval_iters, "iteration"),
    )

    trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        trainer.extend(
            TensorboardLogger(SummaryWriter(args.tensorboard_dir), att_reporter),
            trigger=(args.report_interval_iters, "iteration"),
        )
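
    # Run the training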
    trainer.run()
    check_early_stop(trainer, args.epochs)


def trans(args):
    """Decode with the given args.

    Args:
        args (namespace): The program arguments.

    """
    set_deterministic_pytorch(args)
    model, train_args = load_trained_model(args.model)
    assert isinstance(model, MTInterface)
    model.trans_args = args
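
    # move the model to the GPU for single-GPU decoding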
    if args.ngpu == 1:
        gpu_id = list(range(args.ngpu))
        logging.info("gpu id: " + str(gpu_id))
        model.cuda()
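
    # read json data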
    with open(args.trans_json, "rb") as f:
        js = json.load(f)["utts"]
    new_js = {}
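
    # remove empty utterances; with multilingual models the first output token
    # is a language ID, so length-1 sequences are treated as empty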
    if train_args.multilingual:
        js = {
            k: v
            for k, v in js.items()
            if v["output"][0]["shape"][0] > 1 and v["output"][1]["shape"][0] > 1
        }
    else:
        js = {
            k: v
            for k, v in js.items()
            if v["output"][0]["shape"][0] > 0 and v["output"][1]["shape"][0] > 0
        }
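
    # decode one utterance at a time when batchsize == 0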
    if args.batchsize == 0:
        with torch.no_grad():
            for idx, name in enumerate(js.keys(), 1):
                logging.info("(%d/%d) decoding " + name, idx, len(js.keys()))
                feat = [js[name]["output"][1]["tokenid"].split()]
                nbest_hyps = model.translate(feat, args, train_args.char_list)
                new_js[name] = add_results_to_json(
                    js[name], nbest_hyps, train_args.char_list
                )
    else:

        def grouper(n, iterable, fillvalue=None):
            kargs = [iter(iterable)] * n
            return zip_longest(*kargs, fillvalue=fillvalue)
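
        # sort utterances by source length, longest first, for batching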
        keys = list(js.keys())
        feat_lens = [js[key]["output"][1]["shape"][0] for key in keys]
        sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i])
        keys = [keys[i] for i in sorted_index]

        with torch.no_grad():
            for names in grouper(args.batchsize, keys, None):
                names = [name for name in names if name]
                feats = [
                    np.fromiter(
                        map(int, js[name]["output"][1]["tokenid"].split()),
                        dtype=np.int64,
                    )
                    for name in names
                ]
                nbest_hyps = model.translate_batch(
                    feats,
                    args,
                    train_args.char_list,
                )

                for i, nbest_hyp in enumerate(nbest_hyps):
                    name = names[i]
                    new_js[name] = add_results_to_json(
                        js[name], nbest_hyp, train_args.char_list
                    )
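
    # write the decoding results to args.result_label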
    with open(args.result_label, "wb") as f:
        f.write(
            json.dumps(
                {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )