import inspect
from typing import Any, List

import torch
import hydra
from pytorch_lightning import LightningModule, LightningDataModule
from torchmetrics import MetricCollection
from einops import rearrange
from omegaconf import OmegaConf

from src.utils.utils import get_logger
from src.optim.param_grouping import group_parameters_for_optimizer
from src.utils.checkpoint import load_checkpoint

logger = get_logger(__name__)


class SequenceModel(LightningModule):

    def __init__(self, cfg, model_cfg=None):
        """If model_cfg is passed, it takes precedence over cfg.model."""
        super().__init__()
        # This line ensures that params passed to the LightningModule are saved to the ckpt.
        # It also makes the params accessible through the 'self.hparams' attribute.
        self.save_hyperparameters(cfg)
        self.cfg = cfg
        self.model_cfg = model_cfg or self.cfg.model

        self.instantiate_datamodule()
        self.instantiate_model()
        self.warmstart()
        self.instantiate_loss()
        self.instantiate_metrics()

    def instantiate_datamodule(self):
        logger.info(f"Instantiating datamodule <{self.cfg.datamodule._target_}>")
        # Naming this attribute self.datamodule would interfere with PL,
        # since PL also assigns self.datamodule.
        self._datamodule: LightningDataModule = hydra.utils.instantiate(self.cfg.datamodule)
        self._datamodule.prepare_data()
        self._datamodule.setup()
        OmegaConf.clear_resolver('datamodule')
        OmegaConf.register_new_resolver('datamodule', lambda attr: getattr(self._datamodule, attr))
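
    # The resolver registered above lets the rest of the config read attributes of
    # the instantiated datamodule via interpolation. A hypothetical sketch (the
    # 'num_classes' attribute is an assumption about the datamodule, mirroring the
    # commented-out logic in instantiate_model below):
    #
    #     model:
    #       num_classes: ${datamodule:num_classes}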

    def instantiate_model(self):
        # if hasattr(self._datamodule, 'num_classes'):
        #     self.model_cfg.num_classes = self._datamodule.num_classes
        # if (hasattr(self._datamodule, 'vocab_size')
        #         and self.model_cfg.get('embedding_cfg', None) is not None
        #         and self.model_cfg.embedding_cfg._target_ == "torch.nn.Embedding"):
        #     self.model_cfg.embedding_cfg.num_embeddings = self._datamodule.vocab_size
        logger.info(f"Instantiating model <{self.model_cfg._target_}>")
        recursive = getattr(self.model_cfg, '_recursive_', False)
        self.model = hydra.utils.instantiate(self.model_cfg, _recursive_=recursive)

    def instantiate_loss(self):
        loss_fn_cfg = self.cfg.train.get('loss_fn')
        if loss_fn_cfg is None:
            loss_fn_cfg = {'_target_': 'torch.nn.CrossEntropyLoss'}
        self.loss_fn = hydra.utils.instantiate(loss_fn_cfg)
        loss_fn_val_cfg = self.cfg.train.get('loss_fn_val', loss_fn_cfg)
        self.loss_fn_val = hydra.utils.instantiate(loss_fn_val_cfg)

    def instantiate_metrics(self):
        # Use separate metric instances for the train, val and test steps
        # to ensure a proper reduction over each epoch.
        if 'eval' in self.cfg and 'metrics' in self.cfg.eval:
            metrics_cfg = self.cfg.eval.metrics
        else:
            metrics_cfg = {'acc': {'_target_': 'torchmetrics.Accuracy'}}
        metrics = MetricCollection({name: hydra.utils.instantiate(cfg)
                                    for name, cfg in metrics_cfg.items()})
        self.train_metrics = metrics.clone(prefix='train/')
        self.val_metrics = metrics.clone(prefix='val/')
        self.test_metrics = metrics.clone(prefix='test/')
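
    # A sketch of the expected metrics config; the metric names and the second
    # target are illustrative assumptions, any torchmetrics class reachable via
    # '_target_' should work:
    #
    #     eval:
    #       metrics:
    #         acc:
    #           _target_: torchmetrics.Accuracy
    #         f1:
    #           _target_: torchmetrics.F1Score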

    def warmstart(self):
        if self.cfg.train.get('warmstart', None) is not None:
            logger.info(f"Warm-starting with weights from {self.cfg.train.warmstart.path}")
            strict = self.cfg.train.warmstart.get('strict', True)
            state_dict = load_checkpoint(self.cfg.train.warmstart.path)
            if self.cfg.train.warmstart.get('post_process', None) is not None:
                state_dict = hydra.utils.instantiate(self.cfg.train.warmstart.post_process,
                                                     state_dict)
            load_return = self.model.load_state_dict(state_dict, strict=strict)
            logger.info(load_return)
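
    # 'post_process' is instantiated with the loaded state_dict as its first
    # positional argument, so it can point at any callable that rewrites the dict
    # before loading. A hypothetical config sketch (the target below is a
    # placeholder, not part of this repo):
    #
    #     train:
    #       warmstart:
    #         path: /path/to/checkpoint.ckpt
    #         post_process:
    #           _target_: my_project.utils.remap_state_dict_keys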

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def step(self, batch: Any, is_train=True):
        # Batches are either (x, y, lengths) for variable-length inputs, or (x, y)
        try:
            x, y, lengths = batch
        except ValueError:
            x, y = batch
            lengths = None
        output = self.forward(x) if lengths is None else self.forward(x, lengths=lengths)
        loss = self.loss_fn(output, y) if is_train else self.loss_fn_val(output, y)
        return loss, output, y

    def shared_step(self, batch: Any, batch_idx: int, phase='train'):
        loss, output, targets = self.step(batch, is_train=(phase == 'train'))
        metrics = getattr(self, f'{phase}_metrics')
        metrics(output, targets)
        log_on_step = 'eval' in self.cfg and self.cfg.eval.get('log_on_step', False) and phase == 'train'
        self.log(f"{phase}/loss", loss, on_step=log_on_step, on_epoch=True,
                 prog_bar=False, sync_dist=True)
        # https://pytorch-lightning.readthedocs.io/en/stable/visualize/logging_advanced.html#enable-metrics-for-distributed-training
        # We need to log the Metric objects, not the metric results, since otherwise
        # pytorch-lightning will use torch.mean to reduce them.
        # This would be wrong for perplexity, for example.
        self.log_dict(metrics, on_step=log_on_step, on_epoch=True, prog_bar=True, sync_dist=True)
        return {"loss": loss, "output": output, "targets": targets}

    def training_step(self, batch: Any, batch_idx: int):
        return self.shared_step(batch, batch_idx, phase='train')

    def validation_step(self, batch: Any, batch_idx: int):
        return self.shared_step(batch, batch_idx, phase='val')

    def test_step(self, batch: Any, batch_idx: int):
        return self.shared_step(batch, batch_idx, phase='test')

    def configure_optimizers(self):
        if 'optimizer_param_grouping' in self.cfg.train:  # Set zero weight decay for some params
            parameters = group_parameters_for_optimizer(self.model, self.cfg.train.optimizer,
                                                        **self.cfg.train.optimizer_param_grouping)
        else:
            # parameters = self.model.parameters()
            # [21-09-08] AG: self.parameters() also trains task-specific parameters
            # such as the Retrieval head for AAN
            parameters = self.parameters()
        optimizer = hydra.utils.instantiate(self.cfg.train.optimizer, parameters)

        # Log optimizer info
        for i, g in enumerate(optimizer.param_groups):
            ntensors = len(g['params'])
            nparams = sum(p.numel() for p in g['params'])
            hparams = {k: v for k, v in g.items() if k != 'params'}
            logger.info(f'Optimizer group {i}: {ntensors} tensors, {nparams} parameters, {hparams}')

        if 'scheduler' not in self.cfg.train:
            return optimizer
        else:
            # lr_scheduler should be called either every step (default) or every epoch
            lr_scheduler = hydra.utils.instantiate(self.cfg.train.scheduler, optimizer)
            return [optimizer], {'scheduler': lr_scheduler,
                                 'interval': self.cfg.train.get('scheduler_interval', 'step'),
                                 'monitor': self.cfg.train.get('scheduler_monitor', 'val/loss')}
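
    # A sketch of a matching train config (the values are illustrative
    # assumptions, not defaults of this repo). Hydra instantiates the scheduler
    # with the optimizer passed as its first positional argument:
    #
    #     train:
    #       optimizer:
    #         _target_: torch.optim.AdamW
    #         lr: 3e-4
    #         weight_decay: 0.1
    #       scheduler:
    #         _target_: torch.optim.lr_scheduler.CosineAnnealingLR
    #         T_max: 100000
    #       scheduler_interval: step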

    def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx):
        # https://pytorch-lightning.readthedocs.io/en/latest/guides/speed.html#set-grads-to-none
        # TD [2022-04-30]: DeepSpeed optimizer uses the kwarg set_grad_to_none instead of set_to_none
        if 'set_to_none' in inspect.signature(optimizer.zero_grad).parameters:
            optimizer.zero_grad(set_to_none=True)
        else:
            optimizer.zero_grad()

    def on_save_checkpoint(self, checkpoint):
        # TD [2022-08-07]: ['epoch_loop.batch_progress']['total']['completed'] is 1 iteration
        # behind, so we're using the optimizer's progress instead.
        fit_loop = checkpoint['loops']['fit_loop']
        optim_step = fit_loop['epoch_loop.batch_loop.optimizer_loop.optim_progress']['optimizer']['step']
        batch_progress = fit_loop['epoch_loop.batch_progress']
        batch_progress['total']['completed'] = (optim_step['total']['completed']
                                                * self.trainer.accumulate_grad_batches)
        batch_progress['current']['completed'] = (optim_step['current']['completed']
                                                  * self.trainer.accumulate_grad_batches)
        # _batches_that_stepped tracks the number of global steps, not the number
        # of local steps, so we don't multiply by self.trainer.accumulate_grad_batches here.
        fit_loop['epoch_loop.state_dict']['_batches_that_stepped'] = optim_step['total']['completed']


class SequenceLMModel(SequenceModel):

    def step(self, batch: Any, is_train=True):
        x, y = batch
        output = self.forward(x).logits
        # Flatten all batch and sequence dimensions so the loss sees
        # (batch * seqlen, vocab_size) logits against (batch * seqlen,) targets
        output = rearrange(output, '... C -> (...) C')
        y = rearrange(y, '... -> (...)')
        loss = self.loss_fn(output, y) if is_train else self.loss_fn_val(output, y)
        return loss, output, y

    def shared_step(self, batch: Any, batch_idx: int, phase='train'):
        loss, output, targets = self.step(batch, is_train=(phase == 'train'))
        # Pass the loss to the perplexity metric to avoid recomputing it
        metrics = getattr(self, f'{phase}_metrics')
        metrics(output, targets, loss=loss)
        log_on_step = 'eval' in self.cfg and self.cfg.eval.get('log_on_step', False) and phase == 'train'
        self.log(f"{phase}/loss", loss, on_step=log_on_step, on_epoch=True,
                 prog_bar=False, sync_dist=True)
        # https://pytorch-lightning.readthedocs.io/en/stable/visualize/logging_advanced.html#enable-metrics-for-distributed-training
        # We need to log the Metric objects, not the metric results, since otherwise
        # pytorch-lightning will use torch.mean to reduce them.
        # This would be wrong for perplexity, for example.
        self.log_dict(metrics, on_step=log_on_step, on_epoch=True, prog_bar=True, sync_dist=True)
        return {"loss": loss, "output": output, "targets": targets}