import os
import warnings
from typing import Any, Dict, Optional, Type

import torch
from lightning_fabric.utilities.cloud_io import get_filesystem
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.cli import (
    LightningArgumentParser,
    LightningCLI,
    LRSchedulerTypeUnion,
    ReduceLROnPlateau,
    SaveConfigCallback,
)
from pytorch_lightning.loggers import Logger, WandbLogger
from torch.optim import Optimizer
from torch.optim.lr_scheduler import CyclicLR, OneCycleLR


class WandbSaveConfigCallback(SaveConfigCallback):
    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        if self.already_saved:
            return

        log_dir = trainer.log_dir  # this broadcasts the directory
        if trainer.logger is not None and trainer.logger.name is not None and trainer.logger.version is not None:
            log_dir = os.path.join(log_dir, trainer.logger.name, str(trainer.logger.version))
        config_path = os.path.join(log_dir, self.config_filename)
        fs = get_filesystem(log_dir)

        if not self.overwrite:
            # check if the file exists on rank 0
            file_exists = fs.isfile(config_path) if trainer.is_global_zero else False
            # broadcast whether to fail to all ranks
            file_exists = trainer.strategy.broadcast(file_exists)
            if file_exists:
                raise RuntimeError(
                    f"{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting"
                    " results of a previous run. You can delete the previous config file,"
                    " set `LightningCLI(save_config_callback=None)` to disable config saving,"
                    ' or set `LightningCLI(save_config_kwargs={"overwrite": True})` to overwrite the config file.'
                )

        # save the file on rank 0
        if trainer.is_global_zero:
            # save only on rank zero to avoid race conditions.
            # the `log_dir` needs to be created as we rely on the logger to do it usually
            # but it hasn't logged anything at this point
            fs.makedirs(log_dir, exist_ok=True)
            self.parser.save(
                self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
            )
            self.already_saved = True
            # save optimizer and lr scheduler config
            for _logger in trainer.loggers:
                if isinstance(_logger, Logger):
                    config = {}
                    optimizer_config = self.config.get("optimizer", None)
                    if optimizer_config is not None:
                        if isinstance(optimizer_config, list):
                            config["optimizer"] = [
                                {k.replace("init_args.", ""): v for k, v in opt_conf.items()}
                                for opt_conf in optimizer_config
                                if opt_conf is not None
                            ]
                        else:
                            config["optimizer"] = {
                                k.replace("init_args.", ""): v for k, v in optimizer_config.items()
                            }
                    lr_scheduler_config = self.config.get("lr_scheduler", None)
                    if lr_scheduler_config is not None:
                        if isinstance(lr_scheduler_config, list):
                            config["lr_scheduler"] = [
                                {k.replace("init_args.", ""): v for k, v in sch_conf.items()}
                                for sch_conf in lr_scheduler_config
                                if sch_conf is not None
                            ]
                        else:
                            config["lr_scheduler"] = {
                                k.replace("init_args.", ""): v for k, v in lr_scheduler_config.items()
                            }
                    _logger.log_hyperparams(config)

        # broadcast so that all ranks are in sync on future calls to .setup()
        self.already_saved = trainer.strategy.broadcast(self.already_saved)
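

# Usage sketch: the callback also works with the stock LightningCLI. ``MyModel`` and
# ``MyDataModule`` below are hypothetical placeholders, not defined in this module.
#
#     cli = LightningCLI(
#         MyModel,
#         MyDataModule,
#         save_config_callback=WandbSaveConfigCallback,
#         save_config_kwargs={"overwrite": True},
#     )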


class CustomLightningCLI(LightningCLI):
    def __init__(
        self,
        save_config_callback: Optional[Type[SaveConfigCallback]] = WandbSaveConfigCallback,
        parser_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        new_parser_kwargs = {
            sub_command: dict(default_config_files=[os.path.join("configs", "default.yaml")])
            for sub_command in ["fit", "validate", "test", "predict"]
        }
        new_parser_kwargs.update(parser_kwargs or {})
        super().__init__(save_config_callback=save_config_callback, parser_kwargs=new_parser_kwargs, **kwargs)
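
    # A minimal ``configs/default.yaml`` picked up through ``default_config_files``
    # might look like this (sketch; the exact keys depend on your model and trainer):
    #
    #     trainer:
    #       max_epochs: 10
    #     ignore_warnings: true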

    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        parser.add_argument("--ignore_warnings", default=False, type=bool, help="Ignore warnings")
        parser.add_argument("--git_commit_before_fit", default=False, type=bool, help="Git commit before training")
        parser.add_argument(
            "--test_after_fit", default=True, type=bool, help="Run test on the best checkpoint after training"
        )
        parser.add_argument(
            "--validate_every_n_steps", default=10000, type=int, help="Run validation every N steps (0 to disable)"
        )
        parser.add_argument("--test_every_epoch", default=False, type=bool, help="Run test after every epoch")
        parser.add_argument(
            "--load_weights_path",
            type=str,
            default=None,
            help="If set, load model weights from this checkpoint path (but do NOT load optimizer state).",
        )
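
    # Example invocation (sketch; the script name and paths are hypothetical):
    #
    #     python main.py fit --config configs/experiment.yaml \
    #         --load_weights_path checkpoints/pretrained.ckpt --test_after_fit true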

    def before_instantiate_classes(self) -> None:
        if self.config[self.subcommand].get("ignore_warnings"):
            warnings.filterwarnings("ignore")

    def after_instantiate_classes(self) -> None:
        """This is called after:

        - arguments have been parsed,
        - your LightningModule (``self.model``) has been created,
        - the datamodule (if any) has been created,
        - and the trainer has been instantiated (but not yet fitted).

        We use it to load just the model weights if ``--load_weights_path`` was provided.
        """
        load_path = self.config[self.subcommand].get("load_weights_path", None)
        if load_path:
            ckpt = torch.load(load_path, map_location=lambda storage, loc: storage)
            # Only load the model's state_dict; optimizer / lr_scheduler state is untouched.
            self.model.load_state_dict(ckpt["state_dict"])
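
    # Note: ``load_state_dict`` is strict by default; if the checkpoint's architecture
    # differs from the current model, passing ``strict=False`` is a common (lossy)
    # workaround, not done here.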

    def before_fit(self) -> None:
        if self.config.fit.get("git_commit_before_fit") and not os.environ.get("DEBUG", False):
            logger = self.trainer.logger
            if isinstance(logger, WandbLogger):
                version = getattr(logger, "version")
                name = getattr(logger, "_name")
                message = "Commit Message"
                if name and version:
                    message = f"{name}_{version}"
                elif name:
                    message = name
                elif version:
                    message = version
                os.system(f'git commit -am "{message}"')

        # Set up the validation interval if specified
        validate_every_n_steps = self.config.fit.get("validate_every_n_steps", 0)
        if validate_every_n_steps > 0:
            self.trainer.fit_loop.epoch_loop.val_check_batch = validate_every_n_steps
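
    # ``val_check_batch`` is an internal attribute of the training epoch loop; the
    # public knob is ``Trainer(val_check_interval=N)``, but setting it here keeps the
    # interval configurable from the CLI after the trainer has been instantiated.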

    def after_fit(self) -> None:
        if self.config.fit.get("test_after_fit") and not os.environ.get("DEBUG", False):
            self._run_subcommand("test")

    def after_train_epoch(self) -> None:
        if self.config.fit.get("test_every_epoch") and not os.environ.get("DEBUG", False):
            self._run_subcommand("test")
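
    # Checkpoint resolution for testing: the checkpoint callback's best model wins;
    # an explicitly provided ``ckpt_path`` is left untouched; otherwise the in-memory
    # weights are tested (``ckpt_path=None``).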
    def before_test(self) -> None:
        if self.trainer.checkpoint_callback and self.trainer.checkpoint_callback.best_model_path:
            tested_ckpt_path = self.trainer.checkpoint_callback.best_model_path
        elif self.config_init[self.config_init["subcommand"]]["ckpt_path"]:
            return
        else:
            tested_ckpt_path = None
        self.config_init[self.config_init["subcommand"]]["ckpt_path"] = tested_ckpt_path
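
    # Unlike the upstream ``_prepare_subcommand_kwargs``, the override below reads from
    # the *active* subcommand's config (e.g. the "fit" config when ``after_fit`` triggers
    # "test"), filtered to the arguments that the target subcommand's method accepts.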
    def _prepare_subcommand_kwargs(self, subcommand: str) -> Dict[str, Any]:
        """Prepares the keyword arguments to pass to the subcommand to run."""
        fn_kwargs = {
            k: v
            for k, v in self.config_init[self.config_init["subcommand"]].items()
            if k in self._subcommand_method_arguments[subcommand]
        }
        fn_kwargs["model"] = self.model
        if self.datamodule is not None:
            fn_kwargs["datamodule"] = self.datamodule
        return fn_kwargs

    @staticmethod
    def configure_optimizers(
        lightning_module: LightningModule, optimizer: Optimizer, lr_scheduler: Optional[LRSchedulerTypeUnion] = None
    ) -> Any:
        """Override to customize the :meth:`~pytorch_lightning.core.LightningModule.configure_optimizers` method.

        Args:
            lightning_module: A reference to the model.
            optimizer: The optimizer.
            lr_scheduler: The learning rate scheduler (if used).
        """
        if lr_scheduler is None:
            return optimizer
        if isinstance(lr_scheduler, ReduceLROnPlateau):
            return {
                "optimizer": optimizer,
                "lr_scheduler": {"scheduler": lr_scheduler, "monitor": lr_scheduler.monitor},
            }
        if isinstance(lr_scheduler, (OneCycleLR, CyclicLR)):
            # CyclicLR and OneCycleLR are step-based schedulers, so override Lightning's
            # default interval ("epoch") with "step".
            return {"optimizer": optimizer, "lr_scheduler": {"scheduler": lr_scheduler, "interval": "step"}}
        return [optimizer], [lr_scheduler]
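

# Entry-point sketch (assumes hypothetical ``MyModel`` / ``MyDataModule`` classes):
#
#     if __name__ == "__main__":
#         CustomLightningCLI(MyModel, MyDataModule)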