Spaces:
Sleeping
Sleeping
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from typing import Any, Dict | |
| from fairseq.distributed import utils | |
| try: | |
| from fairscale.optim import OSS | |
| _has_fairscale = True | |
| except ImportError: | |
| _has_fairscale = False | |
def shard_(optimizer, group):
    """Replace ``optimizer.optimizer`` in place with a fairscale OSS wrapper.

    A ``FairseqOSS`` instance is built from the parameter groups and the
    class of the currently wrapped optimizer, and is stored back on
    ``optimizer.optimizer``. Nothing is returned.

    Args:
        optimizer: object exposing an ``optimizer`` attribute (the optimizer
            to wrap) and an ``optimizer_config`` mapping, which is forwarded
            to the OSS constructor as keyword arguments.
        group: forwarded to the OSS constructor as ``group``.

    Raises:
        ImportError: if the ``fairscale`` package could not be imported.
    """
    if not _has_fairscale:
        install_hint = (
            "\n\nPlease install the fairscale package:" "\n\n pip install fairscale"
        )
        raise ImportError(install_hint)

    class FairseqOSS(OSS):
        """OSS subclass carrying the extra hooks defined below."""

        def disable_mem_eff_fp16_loading_hack(self):
            """Return ``True`` unconditionally."""
            return True

        def __getattr__(self, name):
            """Forward ``supports*`` attributes to the wrapped optimizer.

            Python only calls this after normal attribute lookup has failed.
            Any name that does not start with ``supports``, or that
            ``self.optim`` does not have, raises ``AttributeError``.
            """
            forwardable = name.startswith("supports") and hasattr(self.optim, name)
            if not forwardable:
                raise AttributeError(
                    f"'FairseqOSS' object has no attribute {name!r}"
                )
            return getattr(self.optim, name)

        def broadcast_global_state_dict(
            self, state_dict: Dict[str, Any]
        ) -> Dict[str, Any]:
            """
            Broadcast the entire state_dict from rank 0 to all other ranks
            in ``self.group``; each rank is responsible for loading its own
            partition of the data.
            """
            return utils.broadcast_object(state_dict, src_rank=0, group=self.group)

    wrapped = optimizer.optimizer
    optimizer.optimizer = FairseqOSS(
        wrapped.param_groups,
        type(wrapped),
        group=group,
        **optimizer.optimizer_config
    )