| import argparse |
| import logging |
| from typing import Callable |
| from typing import Collection |
| from typing import Dict |
| from typing import List |
| from typing import Optional |
| from typing import Tuple |
|
|
| import numpy as np |
| import torch |
| from typeguard import check_argument_types |
| from typeguard import check_return_type |
|
|
| from espnet2.lm.abs_model import AbsLM |
| from espnet2.lm.espnet_model import ESPnetLanguageModel |
| from espnet2.lm.seq_rnn_lm import SequentialRNNLM |
| from espnet2.lm.transformer_lm import TransformerLM |
| from espnet2.tasks.abs_task import AbsTask |
| from espnet2.torch_utils.initialize import initialize |
| from espnet2.train.class_choices import ClassChoices |
| from espnet2.train.collate_fn import CommonCollateFn |
| from espnet2.train.preprocessor import CommonPreprocessor |
| from espnet2.train.trainer import Trainer |
| from espnet2.utils.get_default_kwargs import get_default_kwargs |
| from espnet2.utils.nested_dict_action import NestedDictAction |
| from espnet2.utils.types import str2bool |
| from espnet2.utils.types import str_or_none |
|
|
|
|
# Registry of the available LM architectures, exposed on the CLI as
# --lm {seq_rnn,transformer} with per-class keyword args via --lm_conf.
lm_choices = ClassChoices(
    "lm",
    classes={
        "seq_rnn": SequentialRNNLM,
        "transformer": TransformerLM,
    },
    type_check=AbsLM,
    default="seq_rnn",
)
|
|
|
|
class LMTask(AbsTask):
    """Language-model task definition for ESPnet2.

    Provides CLI argument registration, collate/preprocess function
    builders, the required/optional data-stream names, and model
    construction for LM training.
    """

    # Single-objective training: exactly one optimizer is used.
    num_optimizers: int = 1

    # Registered class choices; contributes the --lm / --lm_conf options.
    class_choices_list = [lm_choices]

    # Training-loop driver; subclasses may override this attribute.
    trainer = Trainer

    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        """Register task-related and preprocess-related CLI arguments.

        Args:
            parser: Parser to extend in place.

        Returns:
            The same parser, to allow chaining.
        """
        assert check_argument_types()
        group = parser.add_argument_group(description="Task related")

        # "required" is tracked via a parser default (rather than
        # argparse's required=True) so that --print_config-style modes
        # can still run without the argument being given.
        required = parser.get_default("required")
        required += ["token_list"]

        group.add_argument(
            "--token_list",
            type=str_or_none,
            default=None,
            help="A text mapping int-id to token",
        )
        group.add_argument(
            "--init",
            # Lower-case before parsing so e.g. "Chainer" is accepted.
            type=lambda x: str_or_none(x.lower()),
            default=None,
            help="The initialization method",
            choices=[
                "chainer",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
                None,
            ],
        )
        group.add_argument(
            "--model_conf",
            action=NestedDictAction,
            default=get_default_kwargs(ESPnetLanguageModel),
            help="The keyword arguments for model class.",
        )

        group = parser.add_argument_group(description="Preprocess related")
        group.add_argument(
            "--use_preprocessor",
            type=str2bool,
            default=True,
            help="Apply preprocessing to data or not",
        )
        group.add_argument(
            "--token_type",
            type=str,
            default="bpe",
            choices=["bpe", "char", "word"],
            help="The text will be tokenized in the specified level token",
        )
        group.add_argument(
            "--bpemodel",
            type=str_or_none,
            default=None,
            help="The model file of sentencepiece",
        )
        # FIX: the following three options were previously added on
        # ``parser`` directly, which placed them outside the
        # "Preprocess related" help section; parsing behavior is unchanged.
        group.add_argument(
            "--non_linguistic_symbols",
            type=str_or_none,
            default=None,
            help="non_linguistic_symbols file path",
        )
        group.add_argument(
            "--cleaner",
            type=str_or_none,
            choices=[None, "tacotron", "jaconv", "vietnamese"],
            default=None,
            help="Apply text cleaning",
        )
        group.add_argument(
            "--g2p",
            type=str_or_none,
            choices=[None, "g2p_en", "pyopenjtalk", "pyopenjtalk_kana"],
            default=None,
            help="Specify g2p method if --token_type=phn",
        )

        # Add --lm/--lm_conf (and any future registered choice groups).
        for class_choices in cls.class_choices_list:
            class_choices.add_arguments(group)

        assert check_return_type(parser)
        return parser

    @classmethod
    def build_collate_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Callable[
        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
        Tuple[List[str], Dict[str, torch.Tensor]],
    ]:
        """Return a collate function that pads token-id arrays with 0."""
        assert check_argument_types()
        return CommonCollateFn(int_pad_value=0)

    @classmethod
    def build_preprocess_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.ndarray]], Dict[str, np.ndarray]]]:
        """Return the text preprocessing function, or None if disabled.

        Note: the annotation previously said ``np.array`` (a function,
        not a type); ``np.ndarray`` is the intended type, matching
        ``build_collate_fn``.
        """
        assert check_argument_types()
        if args.use_preprocessor:
            retval = CommonPreprocessor(
                train=train,
                token_type=args.token_type,
                token_list=args.token_list,
                bpemodel=args.bpemodel,
                text_cleaner=args.cleaner,
                g2p_type=args.g2p,
                non_linguistic_symbols=args.non_linguistic_symbols,
            )
        else:
            retval = None
        assert check_return_type(retval)
        return retval

    @classmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """LM training consumes only the "text" data stream."""
        retval = ("text",)
        return retval

    @classmethod
    def optional_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """This task defines no optional data streams."""
        retval = ()
        return retval

    @classmethod
    def build_model(cls, args: argparse.Namespace) -> ESPnetLanguageModel:
        """Build the LM selected by ``args.lm`` wrapped in ESPnetLanguageModel.

        Args:
            args: Parsed namespace; uses ``token_list``, ``lm``,
                ``lm_conf``, ``model_conf``, and ``init``.

        Raises:
            RuntimeError: If ``args.token_list`` is neither a str (file
                path) nor a list/tuple of tokens.
        """
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]

            # Overwrite args.token_list with the expanded list so the
            # dumped config is self-contained (no dependency on the file).
            args.token_list = token_list.copy()
        elif isinstance(args.token_list, (tuple, list)):
            # FIX: was ``args.token_list.copy()``, which raises
            # AttributeError for tuples even though tuples are accepted.
            token_list = list(args.token_list)
        else:
            # FIX: message previously said "str or dict", contradicting
            # the (tuple, list) branch above.
            raise RuntimeError("token_list must be str or list")

        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. Instantiate the selected LM architecture.
        lm_class = lm_choices.get_class(args.lm)
        lm = lm_class(vocab_size=vocab_size, **args.lm_conf)

        # 2. Wrap it in the task-level model.
        model = ESPnetLanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf)

        # 3. Optionally initialize parameters with the requested scheme.
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model
|
|