|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import importlib |
|
|
import os |
|
|
|
|
|
from fairseq import registry |
|
|
from fairseq.optim.fairseq_optimizer import FairseqOptimizer |
|
|
from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer |
|
|
from fairseq.optim.bmuf import FairseqBMUF |
|
|
|
|
|
|
|
|
__all__ = [ |
|
|
'FairseqOptimizer', |
|
|
'FP16Optimizer', |
|
|
'MemoryEfficientFP16Optimizer', |
|
|
] |
|
|
|
|
|
|
|
|
build_optimizer, register_optimizer, OPTIMIZER_REGISTRY = registry.setup_registry( |
|
|
'--optimizer', |
|
|
base_class=FairseqOptimizer, |
|
|
default='nag', |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
for file in os.listdir(os.path.dirname(__file__)): |
|
|
if file.endswith('.py') and not file.startswith('_'): |
|
|
module = file[:file.find('.py')] |
|
|
importlib.import_module('fairseq.optim.' + module) |
|
|
|