from typing import Any, Dict

from fairseq.distributed import utils
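
# fairscale is an optional dependency: the import failure is recorded here
# and only surfaces as an error when shard_() is actually called.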
try:
    from fairscale.optim import OSS

    _has_fairscale = True
except ImportError:
    _has_fairscale = False

def shard_(optimizer, group):
    """Shard *optimizer*'s state across the ranks of *group* in place using fairscale's OSS."""
    if not _has_fairscale:
        raise ImportError(
            "\n\nPlease install the fairscale package:" "\n\n pip install fairscale"
        )

    class FairseqOSS(OSS):
        @property
        def disable_mem_eff_fp16_loading_hack(self):
            # Signal fairseq's memory-efficient fp16 wrapper to skip its
            # state-dict loading hack when the optimizer state is sharded.
            return True
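
        # Forward capability flags (attribute names starting with "supports",
        # e.g. ``supports_memory_efficient_fp16``) to the wrapped torch
        # optimizer so fairseq's feature checks work through the OSS wrapper.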
        def __getattr__(self, name):
            if name.startswith("supports") and hasattr(self.optim, name):
                return getattr(self.optim, name)
            raise AttributeError(
                "'FairseqOSS' object has no attribute {0!r}".format(name)
            )

        def broadcast_global_state_dict(
            self, state_dict: Dict[str, Any]
        ) -> Dict[str, Any]:
            """
            Broadcast the entire state_dict from rank 0 to all other ranks;
            each rank is then responsible for loading its own partition of
            the data.
            """
            return utils.broadcast_object(
                state_dict,
                src_rank=0,
                group=self.group,
            )
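
    # Swap the inner torch optimizer for the sharded wrapper: OSS keeps the
    # original parameter groups and optimizer class but partitions the
    # optimizer state across the ranks of ``group``.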
    torch_optimizer = optimizer.optimizer
    optim_cls = type(torch_optimizer)

    optimizer.optimizer = FairseqOSS(
        torch_optimizer.param_groups,
        optim_cls,
        group=group,
        **optimizer.optimizer_config
    )
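

# Usage sketch (illustrative only, not fairseq's actual call site;
# ``build_optimizer`` and ``cfg`` are hypothetical stand-ins):
#
#   from fairseq.distributed import utils as dist_utils
#
#   optimizer = build_optimizer(cfg, model.parameters())  # a FairseqOptimizer
#   group = dist_utils.get_data_parallel_group()
#   shard_(optimizer, group)
#   # optimizer.optimizer is now a FairseqOSS wrapping the original optimizer
#   # class, with its state sharded across ``group``.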