import contextlib
import copy
import itertools

import torch
from pytorch_lightning import Callback
from torch.distributed.fsdp import FullyShardedDataParallel


class EMACallback(Callback):
    """Maintains an exponential moving average (EMA) copy of a submodule of the LightningModule."""

    def __init__(
        self,
        module_attr_name,
        ema_module_attr_name,
        decay=0.999,
        start_ema_step=0,
        init_ema_random=True,
    ):
        super().__init__()
        self.decay = decay
        self.module_attr_name = module_attr_name
        self.ema_module_attr_name = ema_module_attr_name
        self.start_ema_step = start_ema_step
        self.init_ema_random = init_ema_random

    def on_train_start(self, trainer, pl_module):
        if pl_module.global_step == 0:
            if not hasattr(pl_module, self.module_attr_name):
                raise ValueError(
                    f"Module {pl_module} does not have attribute {self.module_attr_name}"
                )
            # Register the EMA copy on first use: a frozen, eval-mode deep copy
            # of the source module.
            if not hasattr(pl_module, self.ema_module_attr_name):
                pl_module.add_module(
                    self.ema_module_attr_name,
                    copy.deepcopy(getattr(pl_module, self.module_attr_name))
                    .eval()
                    .requires_grad_(False),
                )
            self.reset_ema(pl_module)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if pl_module.global_step == self.start_ema_step:
            # Hard-sync (or re-initialize) the EMA weights when averaging begins.
            self.reset_ema(pl_module)
        elif (
            pl_module.global_step < self.start_ema_step
            and pl_module.global_step % 100 == 0
        ):
            # Before the start step, loosely track the online weights every
            # 100 steps with a much smaller decay.
            self.update_ema(pl_module, decay=0.9)
        elif pl_module.global_step > self.start_ema_step:
            self.update_ema(pl_module, decay=self.decay)

    def update_ema(self, pl_module, decay=0.999):
        ema_module = getattr(pl_module, self.ema_module_attr_name)
        module = getattr(pl_module, self.module_attr_name)
        context_manager = self.get_model_context_manager(module)
        with context_manager:
            with torch.no_grad():
                ema_params = ema_module.state_dict()
                for name, param in itertools.chain(
                    module.named_parameters(), module.named_buffers()
                ):
                    if name in ema_params:
                        # Only trainable parameters are averaged; buffers
                        # (requires_grad=False) are synced by reset_ema instead.
                        if param.requires_grad:
                            # ema = decay * ema + (1 - decay) * param, i.e. the
                            # lerp weight toward the online weights is 1 - decay.
                            ema_params[name].copy_(
                                ema_params[name]
                                .detach()
                                .lerp(param.detach(), 1.0 - decay)
                            )

    def get_model_context_manager(self, module):
        # Under FSDP the parameters are sharded across ranks, so gather the
        # full parameters before reading them. summon_full_params is a static
        # method on the FSDP class, which also covers the case where only
        # submodules of `module` are FSDP-wrapped.
        fsdp_enabled = is_model_fsdp(module)
        model_context_manager = contextlib.nullcontext()
        if fsdp_enabled:
            model_context_manager = FullyShardedDataParallel.summon_full_params(module)
        return model_context_manager

    def reset_ema(self, pl_module):
        ema_module = getattr(pl_module, self.ema_module_attr_name)
        if self.init_ema_random:
            # The wrapped module is expected to expose an init_weights() method.
            ema_module.init_weights()
        else:
            # Hard-copy the current weights and buffers into the EMA module.
            module = getattr(pl_module, self.module_attr_name)
            context_manager = self.get_model_context_manager(module)
            with context_manager:
                ema_params = ema_module.state_dict()
                for name, param in itertools.chain(
                    module.named_parameters(), module.named_buffers()
                ):
                    if name in ema_params:
                        ema_params[name].copy_(param.detach())


def is_model_fsdp(model: torch.nn.Module) -> bool:
    # FullyShardedDataParallel is imported unconditionally at the top of this
    # module, so no ImportError guard is needed here.
    if isinstance(model, FullyShardedDataParallel):
        return True
    # FSDP is often applied to submodules rather than to the root module.
    for _, obj in model.named_children():
        if isinstance(obj, FullyShardedDataParallel):
            return True
    return False
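# -----------------------------------------------------------------------------
# Usage sketch (illustrative only): `LitDiffusion` and the attribute names
# below are hypothetical; any LightningModule works as long as the named
# submodule exists (and exposes init_weights() when init_ema_random=True).
#
#   import pytorch_lightning as pl
#
#   ema_callback = EMACallback(
#       module_attr_name="net",          # submodule to average
#       ema_module_attr_name="net_ema",  # attribute that will hold the EMA copy
#       decay=0.999,                     # ema = 0.999 * ema + 0.001 * net
#       start_ema_step=1000,             # hard-sync at step 1000, then average
#       init_ema_random=False,           # copy weights instead of re-initializing
#   )
#   trainer = pl.Trainer(max_steps=10_000, callbacks=[ema_callback])
#   trainer.fit(LitDiffusion())
# -----------------------------------------------------------------------------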