| from typing import List, Optional, Sequence |
| from pathlib import Path |
|
|
| import hydra |
| from omegaconf import OmegaConf, DictConfig |
| from pytorch_lightning import ( |
| Callback, |
| LightningDataModule, |
| LightningModule, |
| Trainer, |
| seed_everything, |
| ) |
| from pytorch_lightning.loggers import LightningLoggerBase |
|
|
| from src.utils import utils |
|
|
| log = utils.get_logger(__name__) |
|
|
|
|
def last_modification_time(path) -> Optional[float]:
    """Return the most recent modification time associated with ``path``.

    For a regular file this is the file's own ``st_mtime``. For a directory
    it is the newest ``st_mtime`` among the entries directly inside it (one
    level deep, no recursion); an empty directory falls back to the
    directory's own ``st_mtime``.

    Args:
        path: Anything accepted by ``pathlib.Path``.

    Returns:
        The modification time as a POSIX timestamp, or ``None`` when ``path``
        is neither an existing file nor an existing directory.
    """
    path = Path(path)
    if path.is_file():
        return path.stat().st_mtime
    if path.is_dir():
        # ``default`` keeps an empty directory from raising ValueError in max().
        return max(
            (child.stat().st_mtime for child in path.iterdir()),
            default=path.stat().st_mtime,
        )
    return None
|
|
|
|
def train(config: DictConfig) -> Optional[float]:
    """Run the training pipeline described by ``config``.

    Instantiates the task (model), callbacks, loggers and trainer from the
    config, optionally resumes from a checkpoint, fits the model, optionally
    runs the test loop, and finalizes through ``utils.finish``.

    Args:
        config (DictConfig): Configuration composed by Hydra. Keys read here:
            ``seed``, ``task``, ``callbacks``, ``logger``, ``resume``,
            ``trainer``, ``test_after_training`` and ``optimized_metric``.

    Returns:
        Optional[float]: The entry of ``trainer.callback_metrics`` named by
        ``config.optimized_metric`` (used for hyperparameter optimization),
        or ``None`` when ``optimized_metric`` is not set.
    """
    # Seed the random number generators when a seed is configured. Compare
    # against None so that a seed of 0 is honoured instead of being skipped.
    if config.get("seed") is not None:
        seed_everything(config.seed, workers=True)

    # Turn off struct mode so keys can be added to the config later
    # (``trainer.strategy`` is injected below for multi-device runs).
    OmegaConf.set_struct(config, False)

    # The task receives the whole config; ``_recursive_=False`` leaves nested
    # ``_target_`` entries for the task itself to instantiate.
    model: LightningModule = hydra.utils.instantiate(config.task, cfg=config, _recursive_=False)
    # The datamodule is taken from the task rather than built separately.
    datamodule: LightningDataModule = model._datamodule

    # Instantiate every callback entry that declares a ``_target_``.
    callbacks: List[Callback] = []
    if "callbacks" in config:
        for _, cb_conf in config.callbacks.items():
            if cb_conf is not None and "_target_" in cb_conf:
                log.info(f"Instantiating callback <{cb_conf._target_}>")
                callbacks.append(hydra.utils.instantiate(cb_conf))

    # Instantiate every logger entry that declares a ``_target_``.
    logger: List[LightningLoggerBase] = []
    if "logger" in config:
        for _, lg_conf in config.logger.items():
            if lg_conf is not None and "_target_" in lg_conf:
                log.info(f"Instantiating logger <{lg_conf._target_}>")
                logger.append(hydra.utils.instantiate(lg_conf))

    # Work out which checkpoint (if any) to resume from. Stays empty when
    # training should start from scratch.
    ckpt_cfg = {}
    if config.get('resume'):
        try:
            checkpoint_path = Path(config.callbacks.model_checkpoint.dirpath)
            if checkpoint_path.is_dir():
                last_ckpt = checkpoint_path / 'last.ckpt'
                autosave_ckpt = checkpoint_path / '.pl_auto_save.ckpt'
                if not (last_ckpt.exists() or autosave_ckpt.exists()):
                    raise FileNotFoundError("Resume requires either last.ckpt or .pl_auto_save.ckpt")
                # Prefer the auto-save checkpoint when it is the only one
                # present or when it is newer than last.ckpt.
                if ((not last_ckpt.exists())
                        or (autosave_ckpt.exists()
                            and last_modification_time(autosave_ckpt) > last_modification_time(last_ckpt))):
                    checkpoint_path = autosave_ckpt
                else:
                    checkpoint_path = last_ckpt
            # A checkpoint may be a file or a directory, so accept both.
            if checkpoint_path.is_file() or checkpoint_path.is_dir():
                ckpt_cfg = {'ckpt_path': str(checkpoint_path)}
            else:
                log.info(f'Checkpoint file {str(checkpoint_path)} not found. Will start training from scratch')
        except (KeyError, FileNotFoundError) as err:
            # Resuming is best-effort: fall back to a fresh run, but say why
            # instead of silently ignoring the resume request.
            log.info(f'Could not resolve a checkpoint to resume from ({err}). Will start training from scratch')

    # Default to DDP when more than one device is requested and no strategy
    # was configured explicitly.
    n_devices = config.trainer.get('devices', 1)
    # NOTE(review): a string value such as "0,1" or "auto" also counts as a
    # Sequence here and would be measured by character count - confirm that
    # only ints and lists are used for trainer.devices.
    if isinstance(n_devices, Sequence):
        n_devices = len(n_devices)
    if n_devices > 1 and config.trainer.get('strategy', None) is None:
        config.trainer.strategy = dict(
            _target_='pytorch_lightning.strategies.DDPStrategy',
            find_unused_parameters=False,
            gradient_as_bucket_view=True,
        )

    log.info(f"Instantiating trainer <{config.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        config.trainer, callbacks=callbacks, logger=logger)

    log.info("Starting training!")
    trainer.fit(model=model, datamodule=datamodule, **ckpt_cfg)

    # Skip the test loop for fast_dev_run sanity runs.
    if config.get("test_after_training") and not config.trainer.get("fast_dev_run"):
        log.info("Starting testing!")
        trainer.test(model=model, datamodule=datamodule)

    log.info("Finalizing!")
    utils.finish(
        config=config,
        model=model,
        datamodule=datamodule,
        trainer=trainer,
        callbacks=callbacks,
        logger=logger,
    )

    if not config.trainer.get("fast_dev_run"):
        # The trainer may have no checkpoint callback at all; do not crash at
        # the very end of a finished run just to print this path.
        checkpoint_callback = trainer.checkpoint_callback
        if checkpoint_callback is not None:
            log.info(f"Best model ckpt: {checkpoint_callback.best_model_path}")

    # Hand the requested metric back to the caller (e.g. a hyperparameter
    # sweeper); indexing fails with KeyError if the metric was never logged.
    optimized_metric = config.get("optimized_metric")
    if optimized_metric:
        return trainer.callback_metrics[optimized_metric]
    return None
|
|