import os
import warnings
from typing import Any, Dict, Optional, Type
from lightning_fabric.utilities.cloud_io import get_filesystem
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.cli import (
LightningArgumentParser,
LightningCLI,
LRSchedulerTypeUnion,
ReduceLROnPlateau,
SaveConfigCallback,
)
from pytorch_lightning.loggers import Logger, WandbLogger
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import CyclicLR, OneCycleLR
class WandbSaveConfigCallback(SaveConfigCallback):
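    """SaveConfigCallback that writes the CLI config into the logger's name/version subdirectory
    of the log dir (when available) and additionally logs the optimizer and lr scheduler
    configuration to every attached logger as hyperparameters."""
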
def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
if self.already_saved:
return
log_dir = trainer.log_dir # this broadcasts the directory
if trainer.logger is not None and trainer.logger.name is not None and trainer.logger.version is not None:
log_dir = os.path.join(log_dir, trainer.logger.name, str(trainer.logger.version))
config_path = os.path.join(log_dir, self.config_filename)
fs = get_filesystem(log_dir)
if not self.overwrite:
# check if the file exists on rank 0
file_exists = fs.isfile(config_path) if trainer.is_global_zero else False
# broadcast whether to fail to all ranks
file_exists = trainer.strategy.broadcast(file_exists)
if file_exists:
raise RuntimeError(
f"{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting"
" results of a previous run. You can delete the previous config file,"
" set `LightningCLI(save_config_callback=None)` to disable config saving,"
' or set `LightningCLI(save_config_kwargs={"overwrite": True})` to overwrite the config file.'
)
# save the file on rank 0
if trainer.is_global_zero:
# save only on rank zero to avoid race conditions.
# the `log_dir` needs to be created as we rely on the logger to do it usually
# but it hasn't logged anything at this point
fs.makedirs(log_dir, exist_ok=True)
self.parser.save(
self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
)
self.already_saved = True
# save optimizer and lr scheduler config
for _logger in trainer.loggers:
if isinstance(_logger, Logger):
config = {}
optimizer_config = self.config.get("optimizer", None)
if optimizer_config is not None:
if isinstance(optimizer_config, list):
config["optimizer"] = [
{k.replace("init_args.", ""): v for k, v in opt_conf.items()}
for opt_conf in optimizer_config if opt_conf is not None
]
else:
config["optimizer"] = {
k.replace("init_args.", ""): v for k, v in optimizer_config.items()
}
lr_scheduler_config = self.config.get("lr_scheduler", None)
if lr_scheduler_config is not None:
if isinstance(lr_scheduler_config, list):
config["lr_scheduler"] = [
{k.replace("init_args.", ""): v for k, v in sch_conf.items()}
for sch_conf in lr_scheduler_config if sch_conf is not None
]
else:
config["lr_scheduler"] = {
k.replace("init_args.", ""): v for k, v in lr_scheduler_config.items()
}
_logger.log_hyperparams(config)
# broadcast so that all ranks are in sync on future calls to .setup()
self.already_saved = trainer.strategy.broadcast(self.already_saved)
class CustomLightningCLI(LightningCLI):
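    """LightningCLI with project-specific defaults: registers configs/default.yaml as the default
    config file, adds extra flags (warning suppression, git commit before fit, step-based
    validation interval, weight-only checkpoint loading), and tests the best checkpoint after
    fitting."""
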
def __init__(
self,
save_config_callback: Optional[Type[SaveConfigCallback]] = WandbSaveConfigCallback,
parser_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
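        # Register configs/default.yaml as a default config file for every subcommand.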
new_parser_kwargs = {
sub_command: dict(default_config_files=[os.path.join("configs", "default.yaml")])
for sub_command in ["fit", "validate", "test", "predict"]
}
new_parser_kwargs.update(parser_kwargs or {})
super().__init__(save_config_callback=save_config_callback, parser_kwargs=new_parser_kwargs, **kwargs)
def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
parser.add_argument("--ignore_warnings", default=False, type=bool, help="Ignore warnings")
parser.add_argument("--git_commit_before_fit", default=False, type=bool, help="Git commit before training")
parser.add_argument(
"--test_after_fit", default=True, type=bool, help="Run test on the best checkpoint after training"
)
parser.add_argument("--validate_every_n_steps", default=10000, type=int, help="Run validation every N steps (0 to disable)")
parser.add_argument("--test_every_epoch", default=False, type=bool, help="Run test after every epoch")
parser.add_argument(
"--load_weights_path",
type=str,
default=None,
help="If set, load model weights from this checkpoint path (but do NOT load optimizer state)."
)
def before_instantiate_classes(self) -> None:
if self.config[self.subcommand].get("ignore_warnings"):
warnings.filterwarnings("ignore")
    def after_instantiate_classes(self) -> None:
        """Called after the arguments have been parsed and the LightningModule (``self.model``),
        the datamodule (if any), and the trainer have been instantiated, but before fitting.
        Used here to load only the model weights when ``--load_weights_path`` is provided.
        """
        load_path = self.config[self.subcommand].get("load_weights_path")
        if load_path:
            ckpt = torch.load(load_path, map_location=lambda storage, loc: storage)
            # Only load the model's state_dict; optimizer / lr scheduler state is untouched.
            self.model.load_state_dict(ckpt["state_dict"])
def before_fit(self) -> None:
if self.config.fit.get("git_commit_before_fit") and not os.environ.get("DEBUG", False):
logger = self.trainer.logger
if isinstance(logger, WandbLogger):
                version = logger.version
                name = logger._name
message = "Commit Message"
if name and version:
message = f"{name}_{version}"
elif name:
message = name
elif version:
message = version
os.system(f'git commit -am "{message}"')
# Setup validation interval if specified
validate_every_n_steps = self.config.fit.get("validate_every_n_steps", 0)
if validate_every_n_steps > 0:
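            # Override the computed validation check interval so validation runs every
            # `validate_every_n_steps` training batches.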
self.trainer.fit_loop.epoch_loop.val_check_batch = validate_every_n_steps
def after_fit(self) -> None:
if self.config.fit.get("test_after_fit") and not os.environ.get("DEBUG", False):
self._run_subcommand("test")
def after_train_epoch(self) -> None:
if self.config.fit.get("test_every_epoch") and not os.environ.get("DEBUG", False):
self._run_subcommand("test")
def before_test(self) -> None:
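        """Default ``ckpt_path`` to the best checkpoint from fitting unless one was given explicitly."""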
if self.trainer.checkpoint_callback and self.trainer.checkpoint_callback.best_model_path:
tested_ckpt_path = self.trainer.checkpoint_callback.best_model_path
elif self.config_init[self.config_init["subcommand"]]["ckpt_path"]:
return
else:
tested_ckpt_path = None
self.config_init[self.config_init["subcommand"]]["ckpt_path"] = tested_ckpt_path
def _prepare_subcommand_kwargs(self, subcommand: str) -> Dict[str, Any]:
"""Prepares the keyword arguments to pass to the subcommand to run."""
fn_kwargs = {
k: v
for k, v in self.config_init[self.config_init["subcommand"]].items()
if k in self._subcommand_method_arguments[subcommand]
}
fn_kwargs["model"] = self.model
if self.datamodule is not None:
fn_kwargs["datamodule"] = self.datamodule
return fn_kwargs
@staticmethod
def configure_optimizers(
lightning_module: LightningModule, optimizer: Optimizer, lr_scheduler: Optional[LRSchedulerTypeUnion] = None
) -> Any:
"""Override to customize the :meth:`~pytorch_lightning.core.LightningModule.configure_optimizers` method.
Args:
lightning_module: A reference to the model.
optimizer: The optimizer.
lr_scheduler: The learning rate scheduler (if used).
"""
if lr_scheduler is None:
return optimizer
if isinstance(lr_scheduler, ReduceLROnPlateau):
return {
"optimizer": optimizer,
"lr_scheduler": {"scheduler": lr_scheduler, "monitor": lr_scheduler.monitor},
}
if isinstance(lr_scheduler, (OneCycleLR, CyclicLR)):
            # CyclicLR and OneCycleLR are step-based schedulers, but Lightning's default
            # scheduler interval is "epoch", so set the interval to "step" explicitly.
return {"optimizer": optimizer, "lr_scheduler": {"scheduler": lr_scheduler, "interval": "step"}}
return [optimizer], [lr_scheduler]
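
# Illustrative entry point (a minimal sketch, assuming the model and datamodule are selected via
# `class_path` entries in the YAML config, e.g. the configs/default.yaml registered above):
def cli_main() -> None:
    CustomLightningCLI()


if __name__ == "__main__":
    cli_main()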