| | """Chainer optimizer builders.""" |
| | import argparse |
| |
|
| | import chainer |
| | from chainer.optimizer_hooks import WeightDecay |
| |
|
| | from espnet.optimizer.factory import OptimizerFactoryInterface |
| | from espnet.optimizer.parser import adadelta |
| | from espnet.optimizer.parser import adam |
| | from espnet.optimizer.parser import sgd |
| |
|
| |
|
class AdamFactory(OptimizerFactoryInterface):
    """Build a Chainer Adam optimizer from command-line options."""

    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Attach the Adam-related options to ``parser`` and return the result.

        Delegates to ``espnet.optimizer.parser.adam``.
        """
        extended = adam(parser)
        return extended

    @staticmethod
    def from_args(target, args: argparse.Namespace):
        """Create an Adam optimizer bound to ``target``.

        Args:
            target: object handed to ``optimizer.setup`` -- for chainer this
                is the ``model`` (the shared interface uses
                ``model.parameters()`` for the pytorch backend)
            args (argparse.Namespace): parsed options; ``lr``, ``beta1``,
                ``beta2`` and ``weight_decay`` are read from it

        Returns:
            chainer.optimizers.Adam: optimizer set up on ``target`` with a
            ``WeightDecay`` hook of rate ``args.weight_decay`` attached.

        """
        # Chainer's Adam calls its step size ``alpha``; ``args.lr`` feeds it.
        hyperparams = dict(alpha=args.lr, beta1=args.beta1, beta2=args.beta2)
        optimizer = chainer.optimizers.Adam(**hyperparams)
        optimizer.setup(target)
        decay = WeightDecay(args.weight_decay)
        optimizer.add_hook(decay)
        return optimizer
| |
|
| |
|
class SGDFactory(OptimizerFactoryInterface):
    """Build a Chainer plain-SGD optimizer from command-line options."""

    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Attach the SGD-related options to ``parser`` and return the result.

        Delegates to ``espnet.optimizer.parser.sgd``.
        """
        extended = sgd(parser)
        return extended

    @staticmethod
    def from_args(target, args: argparse.Namespace):
        """Create an SGD optimizer bound to ``target``.

        Args:
            target: object handed to ``optimizer.setup`` -- for chainer this
                is the ``model`` (the shared interface uses
                ``model.parameters()`` for the pytorch backend)
            args (argparse.Namespace): parsed options; ``lr`` and
                ``weight_decay`` are read from it

        Returns:
            chainer.optimizers.SGD: optimizer set up on ``target`` with a
            ``WeightDecay`` hook of rate ``args.weight_decay`` attached.

        """
        optimizer = chainer.optimizers.SGD(lr=args.lr)
        optimizer.setup(target)
        decay_rate = args.weight_decay
        optimizer.add_hook(WeightDecay(decay_rate))
        return optimizer
| |
|
| |
|
class AdadeltaFactory(OptimizerFactoryInterface):
    """Build a Chainer AdaDelta optimizer from command-line options."""

    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Attach the AdaDelta-related options to ``parser`` and return the result.

        Delegates to ``espnet.optimizer.parser.adadelta``.
        """
        extended = adadelta(parser)
        return extended

    @staticmethod
    def from_args(target, args: argparse.Namespace):
        """Create an AdaDelta optimizer bound to ``target``.

        Args:
            target: object handed to ``optimizer.setup`` -- for chainer this
                is the ``model`` (the shared interface uses
                ``model.parameters()`` for the pytorch backend)
            args (argparse.Namespace): parsed options; ``rho``, ``eps`` and
                ``weight_decay`` are read from it (no learning rate is used)

        Returns:
            chainer.optimizers.AdaDelta: optimizer set up on ``target`` with a
            ``WeightDecay`` hook of rate ``args.weight_decay`` attached.

        """
        rho, eps = args.rho, args.eps
        optimizer = chainer.optimizers.AdaDelta(rho=rho, eps=eps)
        optimizer.setup(target)
        decay = WeightDecay(args.weight_decay)
        optimizer.add_hook(decay)
        return optimizer
| |
|
| |
|
# Lookup table from optimizer name to the factory class that builds the
# corresponding Chainer optimizer.
OPTIMIZER_FACTORY_DICT = dict(
    adam=AdamFactory,
    sgd=SGDFactory,
    adadelta=AdadeltaFactory,
)
| |
|