| |
| |
| |
| |
| """isort:skip_file""" |
|
|
| import importlib |
| import os |
|
|
| from fairseq import registry |
| from fairseq.optim.bmuf import FairseqBMUF |
| from fairseq.optim.fairseq_optimizer import ( |
| FairseqOptimizer, |
| LegacyFairseqOptimizer, |
| ) |
| from fairseq.optim.amp_optimizer import AMPOptimizer |
| from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer |
| from fairseq.optim.shard import shard_ |
| from omegaconf import DictConfig |
|
|
# Names re-exported by ``from fairseq.optim import *``.
__all__ = [
    "AMPOptimizer",
    "FairseqOptimizer",
    "FP16Optimizer",
    "MemoryEfficientFP16Optimizer",
    "shard_",
]


# Create the registry behind the ``--optimizer`` option. Registered classes are
# declared to derive from FairseqOptimizer, and a choice is required. The call
# yields four objects, unpacked in this order:
#   * the private builder that ``build_optimizer`` delegates to,
#   * ``register_optimizer`` (presumably the registration decorator),
#   * ``OPTIMIZER_REGISTRY`` and ``OPTIMIZER_DATACLASS_REGISTRY``
#     (presumably name -> class and name -> config-dataclass lookups;
#     confirm against ``fairseq.registry``).
_build_optimizer, register_optimizer, OPTIMIZER_REGISTRY, OPTIMIZER_DATACLASS_REGISTRY = (
    registry.setup_registry(
        "--optimizer",
        base_class=FairseqOptimizer,
        required=True,
    )
)
|
|
|
|
def build_optimizer(cfg: DictConfig, params, *extra_args, **extra_kwargs):
    """Build the optimizer selected by ``cfg`` over the trainable parameters.

    Args:
        cfg: optimizer config, forwarded unchanged to the registry builder.
        params: an iterable of parameters, or an iterable of dicts whose
            values are parameters. When every item is a dict, the dict values
            are flattened into one sequence. Any iterable is accepted,
            including one-shot iterators such as generators.
        *extra_args: forwarded unchanged to the registry builder.
        **extra_kwargs: forwarded unchanged to the registry builder.

    Returns:
        Whatever the registry builder (``_build_optimizer``) returns for
        ``cfg``, called with only the parameters whose ``requires_grad``
        is true.
    """
    # Materialize once. The ``all()`` check below consumes items, so a
    # one-shot iterator would otherwise lose elements before the filtering
    # step.
    params = list(params)
    if all(isinstance(p, dict) for p in params):
        params = [t for p in params for t in p.values()]
    # Hand only parameters that require gradients to the optimizer.
    params = [p for p in params if p.requires_grad]
    return _build_optimizer(cfg, params, *extra_args, **extra_kwargs)
|
|
|
|
| |
# Import every public ``.py`` module in this package directory, in sorted
# order. Names starting with "_" (including ``__init__.py``) are skipped.
# Presumably this runs each optimizer module's registration side effects
# (``register_optimizer``) so the registry is populated on package import.
for file in sorted(os.listdir(os.path.dirname(__file__))):
    if file.endswith(".py") and not file.startswith("_"):
        # Strip exactly the trailing ".py" guaranteed by ``endswith`` above.
        # ``str.find`` would cut at the first ".py" anywhere in the name.
        file_name = file[: -len(".py")]
        importlib.import_module("fairseq.optim." + file_name)
|
|