diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/imageio/resources/images/stent.npz b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/imageio/resources/images/stent.npz new file mode 100644 index 0000000000000000000000000000000000000000..1df2db57f947caad53c626f29d3ba54c4302036e --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/imageio/resources/images/stent.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:60a83d2296b51ee6a53153e9ba96ba9020391b0c8952895d9d60a0a629ac6bb6 +size 824612 diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/__init__.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f47bc115ece5152c331ff832f2504c7d0d9eb5bb --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/__init__.py @@ -0,0 +1,56 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.callbacks.device_stats_monitor import DeviceStatsMonitor +from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from pytorch_lightning.callbacks.finetuning import BackboneFinetuning, BaseFinetuning +from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor +from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler +from pytorch_lightning.callbacks.lambda_function import LambdaCallback +from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.callbacks.model_summary import ModelSummary +from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter +from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase, RichProgressBar, TQDMProgressBar +from pytorch_lightning.callbacks.pruning import ModelPruning +from pytorch_lightning.callbacks.quantization import QuantizationAwareTraining +from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary +from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging +from pytorch_lightning.callbacks.timer import Timer +from pytorch_lightning.callbacks.xla_stats_monitor import XLAStatsMonitor + +__all__ = [ + "BackboneFinetuning", + "BaseFinetuning", + "Callback", + "DeviceStatsMonitor", + "EarlyStopping", + "GPUStatsMonitor", + "XLAStatsMonitor", + "GradientAccumulationScheduler", + "LambdaCallback", + "LearningRateMonitor", + "ModelCheckpoint", + "ModelPruning", + "ModelSummary", + "BasePredictionWriter", + "ProgressBar", + "ProgressBarBase", + "QuantizationAwareTraining", + "RichModelSummary", + "RichProgressBar", + "StochasticWeightAveraging", + "Timer", + "TQDMProgressBar", +] diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/base.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/base.py new file mode 100644 index 0000000000000000000000000000000000000000..4dc4cf222e9d6ace3abda5d03367eb83b018004e --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/base.py @@ -0,0 +1,368 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +r""" +Base class used to build new callbacks. + +""" + +from typing import Any, Dict, List, Optional, Type + +import torch +from torch.optim import Optimizer + +import pytorch_lightning as pl +from pytorch_lightning.utilities.types import STEP_OUTPUT + + +class Callback: + r""" + Abstract base class used to build new callbacks. + + Subclass this class and override any of the relevant hooks + """ + + @property + def state_key(self) -> str: + """Identifier for the state of the callback. + + Used to store and retrieve a callback's state from the checkpoint dictionary by + ``checkpoint["callbacks"][state_key]``. Implementations of a callback need to provide a unique state key if 1) + the callback has state and 2) it is desired to maintain the state of multiple instances of that callback. + """ + return self.__class__.__qualname__ + + @property + def _legacy_state_key(self) -> Type["Callback"]: + """State key for checkpoints saved prior to version 1.5.0.""" + return type(self) + + def _generate_state_key(self, **kwargs: Any) -> str: + """Formats a set of key-value pairs into a state key string with the callback class name prefixed. Useful + for defining a :attr:`state_key`. + + Args: + **kwargs: A set of key-value pairs. Must be serializable to :class:`str`. + """ + return f"{self.__class__.__qualname__}{repr(kwargs)}" + + def on_configure_sharded_model(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + r""" + .. deprecated:: v1.6 + This callback hook was deprecated in v1.6 and will be removed in v1.8. Use `setup()` instead. + + Called before configure sharded model. + """ + + def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + r""" + .. deprecated:: v1.6 + This callback hook was deprecated in v1.6 and will be removed in v1.8. Use ``setup()`` instead. + + Called before accelerator is being setup. + """ + + def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: + """Called when fit, validate, test, predict, or tune begins.""" + + def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: + """Called when fit, validate, test, predict, or tune ends.""" + + def on_init_start(self, trainer: "pl.Trainer") -> None: + r""" + .. deprecated:: v1.6 + This callback hook was deprecated in v1.6 and will be removed in v1.8. + + Called when the trainer initialization begins, model has not yet been set. + """ + + def on_init_end(self, trainer: "pl.Trainer") -> None: + r""" + .. deprecated:: v1.6 + This callback hook was deprecated in v1.6 and will be removed in v1.8. + + Called when the trainer initialization ends, model has not yet been set. + """ + + def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Called when fit begins.""" + + def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Called when fit ends.""" + + def on_sanity_check_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Called when the validation sanity check starts.""" + + def on_sanity_check_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Called when the validation sanity check ends.""" + + def on_train_batch_start( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + batch: Any, + batch_idx: int, + unused: int = 0, + ) -> None: + """Called when the train batch begins.""" + + def on_train_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: STEP_OUTPUT, + batch: Any, + batch_idx: int, + unused: int = 0, + ) -> None: + """Called when the train batch ends.""" + + def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Called when the train epoch begins.""" + + def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Called when the train epoch ends. + + To access all batch outputs at the end of the epoch, either: + + 1. Implement `training_epoch_end` in the `LightningModule` and access outputs via the module OR + 2. Cache data across train batch hooks inside the callback implementation to post-process in this hook. + """ + + def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Called when the val epoch begins.""" + + def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Called when the val epoch ends.""" + + def on_test_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Called when the test epoch begins.""" + + def on_test_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Called when the test epoch ends.""" + + def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Called when the predict epoch begins.""" + + def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: List[Any]) -> None: + """Called when the predict epoch ends.""" + + def on_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + r""" + .. deprecated:: v1.6 + This callback hook was deprecated in v1.6 and will be removed in v1.8. Use + ``on__epoch_start`` instead. + + Called when either of train/val/test epoch begins. + """ + + def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + r""" + .. deprecated:: v1.6 + This callback hook was deprecated in v1.6 and will be removed in v1.8. Use + ``on__epoch_end`` instead. + + Called when either of train/val/test epoch ends. + """ + + def on_batch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + r""" + .. deprecated:: v1.6 + This callback hook was deprecated in v1.6 and will be removed in v1.8. Use + ``on_train_batch_start`` instead. + + Called when the training batch begins. + """ + + def on_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + r""" + .. deprecated:: v1.6 + This callback hook was deprecated in v1.6 and will be removed in v1.8. Use + ``on_train_batch_end`` instead. + + Called when the training batch ends. + """ + + def on_validation_batch_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int + ) -> None: + """Called when the validation batch begins.""" + + def on_validation_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: Optional[STEP_OUTPUT], + batch: Any, + batch_idx: int, + dataloader_idx: int, + ) -> None: + """Called when the validation batch ends.""" + + def on_test_batch_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int + ) -> None: + """Called when the test batch begins.""" + + def on_test_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: Optional[STEP_OUTPUT], + batch: Any, + batch_idx: int, + dataloader_idx: int, + ) -> None: + """Called when the test batch ends.""" + + def on_predict_batch_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int + ) -> None: + """Called when the predict batch begins.""" + + def on_predict_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int, + ) -> None: + """Called when the predict batch ends.""" + + def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Called when the train begins.""" + + def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Called when the train ends.""" + + def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + r""" + .. deprecated:: v1.6 + + This callback hook was deprecated in v1.6 and will be removed in v1.8. Use ``on_fit_start`` instead. + + Called when the pretrain routine begins. + """ + + def on_pretrain_routine_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + r""" + .. deprecated:: v1.6 + + This callback hook was deprecated in v1.6 and will be removed in v1.8. Use ``on_fit_start`` instead. + + Called when the pretrain routine ends. + """ + + def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Called when the validation loop begins.""" + + def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Called when the validation loop ends.""" + + def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Called when the test begins.""" + + def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Called when the test ends.""" + + def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Called when the predict begins.""" + + def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Called when predict ends.""" + + def on_keyboard_interrupt(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + r""" + .. deprecated:: v1.5 + This callback hook was deprecated in v1.5 in favor of `on_exception` and will be removed in v1.7. + + Called when any trainer execution is interrupted by KeyboardInterrupt. + """ + + def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None: + """Called when any trainer execution is interrupted by an exception.""" + + def state_dict(self) -> Dict[str, Any]: + """Called when saving a checkpoint, implement to generate callback's ``state_dict``. + + Returns: + A dictionary containing callback state. + """ + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """Called when loading a checkpoint, implement to reload callback state given callback's ``state_dict``. + + Args: + state_dict: the callback state returned by ``state_dict``. + """ + pass + + def on_save_checkpoint( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] + ) -> Optional[dict]: + r""" + Called when saving a checkpoint to give you a chance to store anything else you might want to save. + + Args: + trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance. + pl_module: the current :class:`~pytorch_lightning.core.lightning.LightningModule` instance. + checkpoint: the checkpoint dictionary that will be saved. + + Returns: + None or the callback state. Support for returning callback state will be removed in v1.8. + + .. deprecated:: v1.6 + Returning a value from this method was deprecated in v1.6 and will be removed in v1.8. + Implement ``Callback.state_dict`` instead to return state. + In v1.8 ``Callback.on_save_checkpoint`` can only return None. + """ + + def on_load_checkpoint( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any] + ) -> None: + r""" + Called when loading a model checkpoint, use to reload state. + + Args: + trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance. + pl_module: the current :class:`~pytorch_lightning.core.lightning.LightningModule` instance. + callback_state: the callback state returned by ``on_save_checkpoint``. + + Note: + The ``on_load_checkpoint`` won't be called with an undefined state. + If your ``on_load_checkpoint`` hook behavior doesn't rely on a state, + you will still need to override ``on_save_checkpoint`` to return a ``dummy state``. + + .. deprecated:: v1.6 + This callback hook will change its signature and behavior in v1.8. + If you wish to load the state of the callback, use ``Callback.load_state_dict`` instead. + In v1.8 ``Callback.on_load_checkpoint(checkpoint)`` will receive the entire loaded + checkpoint dictionary instead of only the callback state from the checkpoint. + """ + + def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", loss: torch.Tensor) -> None: + """Called before ``loss.backward()``.""" + + def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Called after ``loss.backward()`` and before optimizers are stepped.""" + + def on_before_optimizer_step( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer, opt_idx: int + ) -> None: + """Called before ``optimizer.step()``.""" + + def on_before_zero_grad(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer) -> None: + """Called before ``optimizer.zero_grad()``.""" diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/device_stats_monitor.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/device_stats_monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..0929358cf0f74d524217263f071dc883d2b834d4 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/device_stats_monitor.py @@ -0,0 +1,104 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Device Stats Monitor +==================== + +Monitors and logs device stats during training. + +""" +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import STEP_OUTPUT +from pytorch_lightning.utilities.warnings import rank_zero_deprecation + + +class DeviceStatsMonitor(Callback): + r""" + Automatically monitors and logs device stats during training stage. ``DeviceStatsMonitor`` + is a special callback as it requires a ``logger`` to passed as argument to the ``Trainer``. + + Raises: + MisconfigurationException: + If ``Trainer`` has no logger. + + Example: + >>> from pytorch_lightning import Trainer + >>> from pytorch_lightning.callbacks import DeviceStatsMonitor + >>> device_stats = DeviceStatsMonitor() # doctest: +SKIP + >>> trainer = Trainer(callbacks=[device_stats]) # doctest: +SKIP + """ + + def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: + if not trainer.loggers: + raise MisconfigurationException("Cannot use DeviceStatsMonitor callback with Trainer that has no logger.") + + def on_train_batch_start( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + batch: Any, + batch_idx: int, + unused: int = 0, + ) -> None: + if not trainer.loggers: + raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.") + + if not trainer._logger_connector.should_update_logs: + return + + device = trainer.strategy.root_device + device_stats = trainer.accelerator.get_device_stats(device) + for logger in trainer.loggers: + separator = logger.group_separator + prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_start", separator) + logger.log_metrics(prefixed_device_stats, step=trainer.fit_loop.epoch_loop._batches_that_stepped) + + def on_train_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: STEP_OUTPUT, + batch: Any, + batch_idx: int, + unused: int = 0, + ) -> None: + if not trainer.loggers: + raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.") + + if not trainer._logger_connector.should_update_logs: + return + + device = trainer.strategy.root_device + device_stats = trainer.accelerator.get_device_stats(device) + for logger in trainer.loggers: + separator = logger.group_separator + prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_end", separator) + logger.log_metrics(prefixed_device_stats, step=trainer.fit_loop.epoch_loop._batches_that_stepped) + + +def _prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]: + return {prefix + separator + k: v for k, v in metrics_dict.items()} + + +def prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str) -> Dict[str, float]: + rank_zero_deprecation( + "`pytorch_lightning.callbacks.device_stats_monitor.prefix_metrics`" + " is deprecated in v1.6 and will be removed in v1.8." + ) + sep = "" + return _prefix_metric_keys(metrics_dict, prefix, sep) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/early_stopping.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/early_stopping.py new file mode 100644 index 0000000000000000000000000000000000000000..16b1bfce152adb0efe1a583aa4038ac8a00116af --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/early_stopping.py @@ -0,0 +1,261 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +r""" +Early Stopping +^^^^^^^^^^^^^^ + +Monitor a metric and stop training when it stops improving. + +""" +import logging +from typing import Any, Callable, Dict, Optional, Tuple + +import numpy as np +import torch + +import pytorch_lightning as pl +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.rank_zero import rank_zero_warn + +log = logging.getLogger(__name__) + + +class EarlyStopping(Callback): + r""" + Monitor a metric and stop training when it stops improving. + + Args: + monitor: quantity to be monitored. + min_delta: minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute + change of less than or equal to `min_delta`, will count as no improvement. + patience: number of checks with no improvement + after which training will be stopped. Under the default configuration, one check happens after + every training epoch. However, the frequency of validation can be modified by setting various parameters on + the ``Trainer``, for example ``check_val_every_n_epoch`` and ``val_check_interval``. + + .. note:: + + It must be noted that the patience parameter counts the number of validation checks with + no improvement, and not the number of training epochs. Therefore, with parameters + ``check_val_every_n_epoch=10`` and ``patience=3``, the trainer will perform at least 40 training + epochs before being stopped. + + verbose: verbosity mode. + mode: one of ``'min'``, ``'max'``. In ``'min'`` mode, training will stop when the quantity + monitored has stopped decreasing and in ``'max'`` mode it will stop when the quantity + monitored has stopped increasing. + strict: whether to crash the training if `monitor` is not found in the validation metrics. + check_finite: When set ``True``, stops training when the monitor becomes NaN or infinite. + stopping_threshold: Stop training immediately once the monitored quantity reaches this threshold. + divergence_threshold: Stop training as soon as the monitored quantity becomes worse than this threshold. + check_on_train_epoch_end: whether to run early stopping at the end of the training epoch. + If this is ``False``, then the check runs at the end of the validation. + + Raises: + MisconfigurationException: + If ``mode`` is none of ``"min"`` or ``"max"``. + RuntimeError: + If the metric ``monitor`` is not available. + + Example:: + + >>> from pytorch_lightning import Trainer + >>> from pytorch_lightning.callbacks import EarlyStopping + >>> early_stopping = EarlyStopping('val_loss') + >>> trainer = Trainer(callbacks=[early_stopping]) + + .. tip:: Saving and restoring multiple early stopping callbacks at the same time is supported under variation in the + following arguments: + + *monitor, mode* + + Read more: :ref:`Persisting Callback State` + """ + mode_dict = {"min": torch.lt, "max": torch.gt} + + order_dict = {"min": "<", "max": ">"} + + def __init__( + self, + monitor: str, + min_delta: float = 0.0, + patience: int = 3, + verbose: bool = False, + mode: str = "min", + strict: bool = True, + check_finite: bool = True, + stopping_threshold: Optional[float] = None, + divergence_threshold: Optional[float] = None, + check_on_train_epoch_end: Optional[bool] = None, + ): + super().__init__() + self.monitor = monitor + self.min_delta = min_delta + self.patience = patience + self.verbose = verbose + self.mode = mode + self.strict = strict + self.check_finite = check_finite + self.stopping_threshold = stopping_threshold + self.divergence_threshold = divergence_threshold + self.wait_count = 0 + self.stopped_epoch = 0 + self._check_on_train_epoch_end = check_on_train_epoch_end + + if self.mode not in self.mode_dict: + raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}") + + self.min_delta *= 1 if self.monitor_op == torch.gt else -1 + torch_inf = torch.tensor(np.Inf) + self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf + + @property + def state_key(self) -> str: + return self._generate_state_key(monitor=self.monitor, mode=self.mode) + + def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: + if self._check_on_train_epoch_end is None: + # if the user runs validation multiple times per training epoch or multiple training epochs without + # validation, then we run after validation instead of on train epoch end + self._check_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1 + + def _validate_condition_metric(self, logs: Dict[str, float]) -> bool: + monitor_val = logs.get(self.monitor) + + error_msg = ( + f"Early stopping conditioned on metric `{self.monitor}` which is not available." + " Pass in or modify your `EarlyStopping` callback to use any of the following:" + f' `{"`, `".join(list(logs.keys()))}`' + ) + + if monitor_val is None: + if self.strict: + raise RuntimeError(error_msg) + if self.verbose > 0: + rank_zero_warn(error_msg, category=RuntimeWarning) + + return False + + return True + + @property + def monitor_op(self) -> Callable: + return self.mode_dict[self.mode] + + def state_dict(self) -> Dict[str, Any]: + return { + "wait_count": self.wait_count, + "stopped_epoch": self.stopped_epoch, + "best_score": self.best_score, + "patience": self.patience, + } + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self.wait_count = state_dict["wait_count"] + self.stopped_epoch = state_dict["stopped_epoch"] + self.best_score = state_dict["best_score"] + self.patience = state_dict["patience"] + + def _should_skip_check(self, trainer: "pl.Trainer") -> bool: + from pytorch_lightning.trainer.states import TrainerFn + + return trainer.state.fn != TrainerFn.FITTING or trainer.sanity_checking + + def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if not self._check_on_train_epoch_end or self._should_skip_check(trainer): + return + self._run_early_stopping_check(trainer) + + def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self._check_on_train_epoch_end or self._should_skip_check(trainer): + return + self._run_early_stopping_check(trainer) + + def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None: + """Checks whether the early stopping condition is met and if so tells the trainer to stop the training.""" + logs = trainer.callback_metrics + + if trainer.fast_dev_run or not self._validate_condition_metric( # disable early_stopping with fast_dev_run + logs + ): # short circuit if metric not present + return + + current = logs[self.monitor].squeeze() + should_stop, reason = self._evaluate_stopping_criteria(current) + + # stop every ddp process if any world process decides to stop + should_stop = trainer.strategy.reduce_boolean_decision(should_stop) + trainer.should_stop = trainer.should_stop or should_stop + if should_stop: + self.stopped_epoch = trainer.current_epoch + if reason and self.verbose: + self._log_info(trainer, reason) + + def _evaluate_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, Optional[str]]: + should_stop = False + reason = None + if self.check_finite and not torch.isfinite(current): + should_stop = True + reason = ( + f"Monitored metric {self.monitor} = {current} is not finite." + f" Previous best value was {self.best_score:.3f}. Signaling Trainer to stop." + ) + elif self.stopping_threshold is not None and self.monitor_op(current, self.stopping_threshold): + should_stop = True + reason = ( + "Stopping threshold reached:" + f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.stopping_threshold}." + " Signaling Trainer to stop." + ) + elif self.divergence_threshold is not None and self.monitor_op(-current, -self.divergence_threshold): + should_stop = True + reason = ( + "Divergence threshold reached:" + f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.divergence_threshold}." + " Signaling Trainer to stop." + ) + elif self.monitor_op(current - self.min_delta, self.best_score.to(current.device)): + should_stop = False + reason = self._improvement_message(current) + self.best_score = current + self.wait_count = 0 + else: + self.wait_count += 1 + if self.wait_count >= self.patience: + should_stop = True + reason = ( + f"Monitored metric {self.monitor} did not improve in the last {self.wait_count} records." + f" Best score: {self.best_score:.3f}. Signaling Trainer to stop." + ) + + return should_stop, reason + + def _improvement_message(self, current: torch.Tensor) -> str: + """Formats a log message that informs the user about an improvement in the monitored score.""" + if torch.isfinite(self.best_score): + msg = ( + f"Metric {self.monitor} improved by {abs(self.best_score - current):.3f} >=" + f" min_delta = {abs(self.min_delta)}. New best score: {current:.3f}" + ) + else: + msg = f"Metric {self.monitor} improved. New best score: {current:.3f}" + return msg + + @staticmethod + def _log_info(trainer: Optional["pl.Trainer"], message: str) -> None: + if trainer is not None and trainer.world_size > 1: + log.info(f"[rank: {trainer.global_rank}] {message}") + else: + log.info(message) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/finetuning.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/finetuning.py new file mode 100644 index 0000000000000000000000000000000000000000..c01df9437851460ab03601291188feed79b668c6 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/finetuning.py @@ -0,0 +1,417 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +r""" +Finetuning Callback +^^^^^^^^^^^^^^^^^^^^ +Freeze and unfreeze models for finetuning purposes +""" +import logging +from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union + +import torch +from torch.nn import Module, ModuleDict +from torch.nn.modules.batchnorm import _BatchNorm +from torch.optim.optimizer import Optimizer + +import pytorch_lightning as pl +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.rank_zero import rank_zero_warn + +log = logging.getLogger(__name__) + + +def multiplicative(epoch): + return 2 + + +class BaseFinetuning(Callback): + r""" + This class implements the base logic for writing your own Finetuning Callback. + + Override ``freeze_before_training`` and ``finetune_function`` methods with your own logic. + + ``freeze_before_training``: This method is called before ``configure_optimizers`` + and should be used to freeze any modules parameters. + + ``finetune_function``: This method is called on every train epoch start and should be used to + ``unfreeze`` any parameters. Those parameters needs to be added in a new ``param_group`` + within the optimizer. + + .. note:: Make sure to filter the parameters based on ``requires_grad``. + + Example:: + + >>> from torch.optim import Adam + >>> class MyModel(pl.LightningModule): + ... def configure_optimizer(self): + ... # Make sure to filter the parameters based on `requires_grad` + ... return Adam(filter(lambda p: p.requires_grad, self.parameters())) + ... + >>> class FeatureExtractorFreezeUnfreeze(BaseFinetuning): + ... def __init__(self, unfreeze_at_epoch=10): + ... super().__init__() + ... self._unfreeze_at_epoch = unfreeze_at_epoch + ... + ... def freeze_before_training(self, pl_module): + ... # freeze any module you want + ... # Here, we are freezing `feature_extractor` + ... self.freeze(pl_module.feature_extractor) + ... + ... def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx): + ... # When `current_epoch` is 10, feature_extractor will start training. + ... if current_epoch == self._unfreeze_at_epoch: + ... self.unfreeze_and_add_param_group( + ... modules=pl_module.feature_extractor, + ... optimizer=optimizer, + ... train_bn=True, + ... ) + """ + + def __init__(self): + self._internal_optimizer_metadata: Dict[int, List[Dict[str, Any]]] = {} + self._restarting = False + + def state_dict(self) -> Dict[str, Any]: + return { + "internal_optimizer_metadata": self._internal_optimizer_metadata, + } + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self._restarting = True + if "internal_optimizer_metadata" in state_dict: + self._internal_optimizer_metadata = state_dict["internal_optimizer_metadata"] + else: + # compatibility to load from old checkpoints before PR #11887 + self._internal_optimizer_metadata = state_dict + + def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + # restore the param_groups created during the previous training. + if self._restarting: + named_parameters = dict(pl_module.named_parameters()) + for opt_idx, optimizer in enumerate(trainer.optimizers): + param_groups = self._apply_mapping_to_param_groups( + self._internal_optimizer_metadata[opt_idx], named_parameters + ) + optimizer.param_groups = param_groups + self._restarting = False + + @staticmethod + def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]: + """This function is used to flatten a module or an iterable of modules into a list of its leaf modules + (modules with no children) and parent modules that have parameters directly themselves. + + Args: + modules: A given module or an iterable of modules + + Returns: + List of modules + """ + if isinstance(modules, ModuleDict): + modules = modules.values() + + if isinstance(modules, Iterable): + _modules = [] + for m in modules: + _modules.extend(BaseFinetuning.flatten_modules(m)) + + else: + _modules = modules.modules() + + # Capture all leaf modules as well as parent modules that have parameters directly themselves + return [m for m in _modules if not list(m.children()) or m._parameters] + + @staticmethod + def filter_params( + modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn: bool = True, requires_grad: bool = True + ) -> Generator: + """Yields the `requires_grad` parameters of a given module or list of modules. + + Args: + modules: A given module or an iterable of modules + train_bn: Whether to train BatchNorm module + requires_grad: Whether to create a generator for trainable or non-trainable parameters. + Returns: + Generator + """ + modules = BaseFinetuning.flatten_modules(modules) + for mod in modules: + if isinstance(mod, _BatchNorm) and not train_bn: + continue + # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it + for param in mod.parameters(recurse=False): + if param.requires_grad == requires_grad: + yield param + + @staticmethod + def make_trainable(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> None: + """Unfreezes the parameters of the provided modules. + + Args: + modules: A given module or an iterable of modules + """ + modules = BaseFinetuning.flatten_modules(modules) + for module in modules: + # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it + for param in module.parameters(recurse=False): + param.requires_grad = True + + @staticmethod + def freeze(modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn: bool = True) -> None: + """Freezes the parameters of the provided modules. + + Args: + modules: A given module or an iterable of modules + train_bn: If True, leave the BatchNorm layers in training mode + + Returns: + None + """ + modules = BaseFinetuning.flatten_modules(modules) + for mod in modules: + if isinstance(mod, _BatchNorm) and train_bn: + BaseFinetuning.make_trainable(mod) + else: + # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it + for param in mod.parameters(recurse=False): + param.requires_grad = False + + @staticmethod + def filter_on_optimizer(optimizer: Optimizer, params: Iterable) -> List: + """This function is used to exclude any parameter which already exists in this optimizer. + + Args: + optimizer: Optimizer used for parameter exclusion + params: Iterable of parameters used to check against the provided optimizer + + Returns: + List of parameters not contained in this optimizer param groups + """ + out_params = [] + removed_params = [] + for param in params: + if not any(torch.equal(p, param) for group in optimizer.param_groups for p in group["params"]): + out_params.append(param) + else: + removed_params.append(param) + + if removed_params: + rank_zero_warn( + "The provided params to be frozen already exist within another group of this optimizer." + " Those parameters will be skipped.\n" + "HINT: Did you init your optimizer in `configure_optimizer` as such:\n" + f" {type(optimizer)}(filter(lambda p: p.requires_grad, self.parameters()), ...) ", + ) + return out_params + + @staticmethod + def unfreeze_and_add_param_group( + modules: Union[Module, Iterable[Union[Module, Iterable]]], + optimizer: Optimizer, + lr: Optional[float] = None, + initial_denom_lr: float = 10.0, + train_bn: bool = True, + ) -> None: + """Unfreezes a module and adds its parameters to an optimizer. + + Args: + modules: A module or iterable of modules to unfreeze. + Their parameters will be added to an optimizer as a new param group. + optimizer: The provided optimizer will receive new parameters and will add them to + `add_param_group` + lr: Learning rate for the new param group. + initial_denom_lr: If no lr is provided, the learning from the first param group will be used + and divided by `initial_denom_lr`. + train_bn: Whether to train the BatchNormalization layers. + """ + BaseFinetuning.make_trainable(modules) + params_lr = optimizer.param_groups[0]["lr"] if lr is None else float(lr) + denom_lr = initial_denom_lr if lr is None else 1.0 + params = BaseFinetuning.filter_params(modules, train_bn=train_bn, requires_grad=True) + params = BaseFinetuning.filter_on_optimizer(optimizer, params) + if params: + optimizer.add_param_group({"params": params, "lr": params_lr / denom_lr}) + + def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: + self.freeze_before_training(pl_module) + + @staticmethod + def _apply_mapping_to_param_groups(param_groups: List[Dict[str, Any]], mapping: dict) -> List[Dict[str, Any]]: + output = [] + for g in param_groups: + # skip params to save memory + group_state = {k: v for k, v in g.items() if k != "params"} + group_state["params"] = [mapping[p] for p in g["params"]] + output.append(group_state) + return output + + def _store( + self, + pl_module: "pl.LightningModule", + opt_idx: int, + num_param_groups: int, + current_param_groups: List[Dict[str, Any]], + ) -> None: + mapping = {p: n for n, p in pl_module.named_parameters()} + if opt_idx not in self._internal_optimizer_metadata: + self._internal_optimizer_metadata[opt_idx] = self._apply_mapping_to_param_groups( + current_param_groups, mapping + ) + elif num_param_groups != len(current_param_groups): + # save new param_groups possibly created by the users. + self._internal_optimizer_metadata[opt_idx].extend( + self._apply_mapping_to_param_groups(current_param_groups[num_param_groups:], mapping) + ) + + def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Called when the epoch begins.""" + # import is here to avoid circular imports + from pytorch_lightning.loops.utilities import _get_active_optimizers + + for opt_idx, optimizer in _get_active_optimizers(trainer.optimizers, trainer.optimizer_frequencies): + num_param_groups = len(optimizer.param_groups) + self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx) + current_param_groups = optimizer.param_groups + self._store(pl_module, opt_idx, num_param_groups, current_param_groups) + + def finetune_function( + self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer, opt_idx: int + ) -> None: + """Override to add your unfreeze logic.""" + raise NotImplementedError + + def freeze_before_training(self, pl_module: "pl.LightningModule") -> None: + """Override to add your freeze logic.""" + raise NotImplementedError + + +class BackboneFinetuning(BaseFinetuning): + r"""Finetune a backbone model based on a learning rate user-defined scheduling. + + When the backbone learning rate reaches the current model learning rate + and ``should_align`` is set to True, it will align with it for the rest of the training. + + Args: + unfreeze_backbone_at_epoch: Epoch at which the backbone will be unfreezed. + lambda_func: Scheduling function for increasing backbone learning rate. + backbone_initial_ratio_lr: + Used to scale down the backbone learning rate compared to rest of model + backbone_initial_lr: Optional, Initial learning rate for the backbone. + By default, we will use ``current_learning / backbone_initial_ratio_lr`` + should_align: Whether to align with current learning rate when backbone learning + reaches it. + initial_denom_lr: When unfreezing the backbone, the initial learning rate will + ``current_learning_rate / initial_denom_lr``. + train_bn: Whether to make Batch Normalization trainable. + verbose: Display current learning rate for model and backbone + rounding: Precision for displaying learning rate + + Example:: + + >>> from pytorch_lightning import Trainer + >>> from pytorch_lightning.callbacks import BackboneFinetuning + >>> multiplicative = lambda epoch: 1.5 + >>> backbone_finetuning = BackboneFinetuning(200, multiplicative) + >>> trainer = Trainer(callbacks=[backbone_finetuning]) + + """ + + def __init__( + self, + unfreeze_backbone_at_epoch: int = 10, + lambda_func: Callable = multiplicative, + backbone_initial_ratio_lr: float = 10e-2, + backbone_initial_lr: Optional[float] = None, + should_align: bool = True, + initial_denom_lr: float = 10.0, + train_bn: bool = True, + verbose: bool = False, + rounding: int = 12, + ) -> None: + super().__init__() + + self.unfreeze_backbone_at_epoch: int = unfreeze_backbone_at_epoch + self.lambda_func: Callable = lambda_func + self.backbone_initial_ratio_lr: float = backbone_initial_ratio_lr + self.backbone_initial_lr: Optional[float] = backbone_initial_lr + self.should_align: bool = should_align + self.initial_denom_lr: float = initial_denom_lr + self.train_bn: bool = train_bn + self.verbose: bool = verbose + self.rounding: int = rounding + self.previous_backbone_lr: Optional[float] = None + + def state_dict(self) -> Dict[str, Any]: + return { + "internal_optimizer_metadata": self._internal_optimizer_metadata, + "previous_backbone_lr": self.previous_backbone_lr, + } + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self.previous_backbone_lr = state_dict["previous_backbone_lr"] + super().load_state_dict(state_dict) + + def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """ + Raises: + MisconfigurationException: + If LightningModule has no nn.Module `backbone` attribute. + """ + if hasattr(pl_module, "backbone") and isinstance(pl_module.backbone, Module): + return super().on_fit_start(trainer, pl_module) + raise MisconfigurationException("The LightningModule should have a nn.Module `backbone` attribute") + + def freeze_before_training(self, pl_module: "pl.LightningModule") -> None: + self.freeze(pl_module.backbone) + + def finetune_function( + self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer, opt_idx: int + ) -> None: + """Called when the epoch begins.""" + if epoch == self.unfreeze_backbone_at_epoch: + current_lr = optimizer.param_groups[0]["lr"] + initial_backbone_lr = ( + self.backbone_initial_lr + if self.backbone_initial_lr is not None + else current_lr * self.backbone_initial_ratio_lr + ) + self.previous_backbone_lr = initial_backbone_lr + self.unfreeze_and_add_param_group( + pl_module.backbone, + optimizer, + initial_backbone_lr, + train_bn=self.train_bn, + initial_denom_lr=self.initial_denom_lr, + ) + if self.verbose: + log.info( + f"Current lr: {round(current_lr, self.rounding)}, " + f"Backbone lr: {round(initial_backbone_lr, self.rounding)}" + ) + + elif epoch > self.unfreeze_backbone_at_epoch: + current_lr = optimizer.param_groups[0]["lr"] + next_current_backbone_lr = self.lambda_func(epoch + 1) * self.previous_backbone_lr + next_current_backbone_lr = ( + current_lr + if (self.should_align and next_current_backbone_lr > current_lr) + else next_current_backbone_lr + ) + optimizer.param_groups[-1]["lr"] = next_current_backbone_lr + self.previous_backbone_lr = next_current_backbone_lr + if self.verbose: + log.info( + f"Current lr: {round(current_lr, self.rounding)}, " + f"Backbone lr: {round(next_current_backbone_lr, self.rounding)}" + ) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/gpu_stats_monitor.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/gpu_stats_monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..36b4006e37a2a83a05da705f01b3e0b869fada71 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/gpu_stats_monitor.py @@ -0,0 +1,262 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +GPU Stats Monitor +================= + +Monitor and logs GPU stats during training. + +""" + +import os +import shutil +import subprocess +import time +from typing import Any, Dict, List, Optional, Tuple + +import torch + +import pytorch_lightning as pl +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.parsing import AttributeDict +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only +from pytorch_lightning.utilities.types import STEP_OUTPUT + + +class GPUStatsMonitor(Callback): + r""" + .. deprecated:: v1.5 + The `GPUStatsMonitor` callback was deprecated in v1.5 and will be removed in v1.7. + Please use the `DeviceStatsMonitor` callback instead. + + Automatically monitors and logs GPU stats during training stage. ``GPUStatsMonitor`` + is a callback and in order to use it you need to assign a logger in the ``Trainer``. + + Args: + memory_utilization: Set to ``True`` to monitor used, free and percentage of memory + utilization at the start and end of each step. Default: ``True``. + gpu_utilization: Set to ``True`` to monitor percentage of GPU utilization + at the start and end of each step. Default: ``True``. + intra_step_time: Set to ``True`` to monitor the time of each step. Default: ``False``. + inter_step_time: Set to ``True`` to monitor the time between the end of one step + and the start of the next step. Default: ``False``. + fan_speed: Set to ``True`` to monitor percentage of fan speed. Default: ``False``. + temperature: Set to ``True`` to monitor the memory and gpu temperature in degree Celsius. + Default: ``False``. + + Raises: + MisconfigurationException: + If NVIDIA driver is not installed, not running on GPUs, or ``Trainer`` has no logger. + + Example:: + + >>> from pytorch_lightning import Trainer + >>> from pytorch_lightning.callbacks import GPUStatsMonitor + >>> gpu_stats = GPUStatsMonitor() # doctest: +SKIP + >>> trainer = Trainer(callbacks=[gpu_stats]) # doctest: +SKIP + + GPU stats are mainly based on `nvidia-smi --query-gpu` command. The description of the queries is as follows: + + - **fan.speed** – The fan speed value is the percent of maximum speed that the device's fan is currently + intended to run at. It ranges from 0 to 100 %. Note: The reported speed is the intended fan speed. + If the fan is physically blocked and unable to spin, this output will not match the actual fan speed. + Many parts do not report fan speeds because they rely on cooling via fans in the surrounding enclosure. + - **memory.used** – Total memory allocated by active contexts. + - **memory.free** – Total free memory. + - **utilization.gpu** – Percent of time over the past sample period during which one or more kernels was + executing on the GPU. The sample period may be between 1 second and 1/6 second depending on the product. + - **utilization.memory** – Percent of time over the past sample period during which global (device) memory was + being read or written. The sample period may be between 1 second and 1/6 second depending on the product. + - **temperature.gpu** – Core GPU temperature, in degrees C. + - **temperature.memory** – HBM memory temperature, in degrees C. + + """ + + def __init__( + self, + memory_utilization: bool = True, + gpu_utilization: bool = True, + intra_step_time: bool = False, + inter_step_time: bool = False, + fan_speed: bool = False, + temperature: bool = False, + ): + super().__init__() + + rank_zero_deprecation( + "The `GPUStatsMonitor` callback was deprecated in v1.5 and will be removed in v1.7." + " Please use the `DeviceStatsMonitor` callback instead." + ) + + if shutil.which("nvidia-smi") is None: + raise MisconfigurationException( + "Cannot use GPUStatsMonitor callback because NVIDIA driver is not installed." + ) + + self._log_stats = AttributeDict( + { + "memory_utilization": memory_utilization, + "gpu_utilization": gpu_utilization, + "intra_step_time": intra_step_time, + "inter_step_time": inter_step_time, + "fan_speed": fan_speed, + "temperature": temperature, + } + ) + + # The logical device IDs for selected devices + self._device_ids: List[int] = [] # will be assigned later in setup() + + # The unmasked real GPU IDs + self._gpu_ids: List[str] = [] # will be assigned later in setup() + + def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: + if not trainer.loggers: + raise MisconfigurationException("Cannot use GPUStatsMonitor callback with Trainer that has no logger.") + + if trainer.strategy.root_device.type != "cuda": + raise MisconfigurationException( + "You are using GPUStatsMonitor but are not running on GPU." + f" The root device type is {trainer.strategy.root_device.type}." + ) + + # The logical device IDs for selected devices + self._device_ids = sorted(set(trainer.device_ids)) + + # The unmasked real GPU IDs + self._gpu_ids = self._get_gpu_ids(self._device_ids) + + def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self._snap_intra_step_time: Optional[float] = None + self._snap_inter_step_time: Optional[float] = None + + @rank_zero_only + def on_train_batch_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int + ) -> None: + if self._log_stats.intra_step_time: + self._snap_intra_step_time = time.time() + + if not trainer._logger_connector.should_update_logs: + return + + gpu_stat_keys = self._get_gpu_stat_keys() + gpu_stats = self._get_gpu_stats([k for k, _ in gpu_stat_keys]) + logs = self._parse_gpu_stats(self._device_ids, gpu_stats, gpu_stat_keys) + + if self._log_stats.inter_step_time and self._snap_inter_step_time: + # First log at beginning of second step + logs["batch_time/inter_step (ms)"] = (time.time() - self._snap_inter_step_time) * 1000 + + for logger in trainer.loggers: + logger.log_metrics(logs, step=trainer.fit_loop.epoch_loop._batches_that_stepped) + + @rank_zero_only + def on_train_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: STEP_OUTPUT, + batch: Any, + batch_idx: int, + ) -> None: + if self._log_stats.inter_step_time: + self._snap_inter_step_time = time.time() + + if not trainer._logger_connector.should_update_logs: + return + + gpu_stat_keys = self._get_gpu_stat_keys() + self._get_gpu_device_stat_keys() + gpu_stats = self._get_gpu_stats([k for k, _ in gpu_stat_keys]) + logs = self._parse_gpu_stats(self._device_ids, gpu_stats, gpu_stat_keys) + + if self._log_stats.intra_step_time and self._snap_intra_step_time: + logs["batch_time/intra_step (ms)"] = (time.time() - self._snap_intra_step_time) * 1000 + + for logger in trainer.loggers: + logger.log_metrics(logs, step=trainer.fit_loop.epoch_loop._batches_that_stepped) + + @staticmethod + def _get_gpu_ids(device_ids: List[int]) -> List[str]: + """Get the unmasked real GPU IDs.""" + # All devices if `CUDA_VISIBLE_DEVICES` unset + default = ",".join(str(i) for i in range(torch.cuda.device_count())) + cuda_visible_devices: List[str] = os.getenv("CUDA_VISIBLE_DEVICES", default=default).split(",") + return [cuda_visible_devices[device_id].strip() for device_id in device_ids] + + def _get_gpu_stats(self, queries: List[str]) -> List[List[float]]: + if not queries: + return [] + + """Run nvidia-smi to get the gpu stats""" + gpu_query = ",".join(queries) + format = "csv,nounits,noheader" + gpu_ids = ",".join(self._gpu_ids) + result = subprocess.run( + [ + # it's ok to suppress the warning here since we ensure nvidia-smi exists during init + shutil.which("nvidia-smi"), # type: ignore + f"--query-gpu={gpu_query}", + f"--format={format}", + f"--id={gpu_ids}", + ], + encoding="utf-8", + capture_output=True, + check=True, + ) + + def _to_float(x: str) -> float: + try: + return float(x) + except ValueError: + return 0.0 + + stats = [[_to_float(x) for x in s.split(", ")] for s in result.stdout.strip().split(os.linesep)] + return stats + + @staticmethod + def _parse_gpu_stats( + device_ids: List[int], stats: List[List[float]], keys: List[Tuple[str, str]] + ) -> Dict[str, float]: + """Parse the gpu stats into a loggable dict.""" + logs = {} + for i, device_id in enumerate(device_ids): + for j, (x, unit) in enumerate(keys): + logs[f"device_id: {device_id}/{x} ({unit})"] = stats[i][j] + return logs + + def _get_gpu_stat_keys(self) -> List[Tuple[str, str]]: + """Get the GPU stats keys.""" + stat_keys = [] + + if self._log_stats.gpu_utilization: + stat_keys.append(("utilization.gpu", "%")) + + if self._log_stats.memory_utilization: + stat_keys.extend([("memory.used", "MB"), ("memory.free", "MB"), ("utilization.memory", "%")]) + + return stat_keys + + def _get_gpu_device_stat_keys(self) -> List[Tuple[str, str]]: + """Get the device stats keys.""" + stat_keys = [] + + if self._log_stats.fan_speed: + stat_keys.append(("fan.speed", "%")) + + if self._log_stats.temperature: + stat_keys.extend([("temperature.gpu", "°C"), ("temperature.memory", "°C")]) + + return stat_keys diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/lambda_function.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/lambda_function.py new file mode 100644 index 0000000000000000000000000000000000000000..1813e7d19090f42eccff2d3cd89c7728c51b719a --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/lambda_function.py @@ -0,0 +1,96 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +r""" +Lambda Callback +^^^^^^^^^^^^^^^ + +Create a simple callback on the fly using lambda functions. + +""" + +from typing import Callable, Optional + +from pytorch_lightning.callbacks.base import Callback + + +class LambdaCallback(Callback): + r""" + Create a simple callback on the fly using lambda functions. + + Args: + **kwargs: hooks supported by :class:`~pytorch_lightning.callbacks.base.Callback` + + Example:: + + >>> from pytorch_lightning import Trainer + >>> from pytorch_lightning.callbacks import LambdaCallback + >>> trainer = Trainer(callbacks=[LambdaCallback(setup=lambda *args: print('setup'))]) + """ + + def __init__( + self, + on_before_accelerator_backend_setup: Optional[Callable] = None, + setup: Optional[Callable] = None, + on_configure_sharded_model: Optional[Callable] = None, + teardown: Optional[Callable] = None, + on_init_start: Optional[Callable] = None, + on_init_end: Optional[Callable] = None, + on_fit_start: Optional[Callable] = None, + on_fit_end: Optional[Callable] = None, + on_sanity_check_start: Optional[Callable] = None, + on_sanity_check_end: Optional[Callable] = None, + on_train_batch_start: Optional[Callable] = None, + on_train_batch_end: Optional[Callable] = None, + on_train_epoch_start: Optional[Callable] = None, + on_train_epoch_end: Optional[Callable] = None, + on_validation_epoch_start: Optional[Callable] = None, + on_validation_epoch_end: Optional[Callable] = None, + on_test_epoch_start: Optional[Callable] = None, + on_test_epoch_end: Optional[Callable] = None, + on_epoch_start: Optional[Callable] = None, + on_epoch_end: Optional[Callable] = None, + on_batch_start: Optional[Callable] = None, + on_validation_batch_start: Optional[Callable] = None, + on_validation_batch_end: Optional[Callable] = None, + on_test_batch_start: Optional[Callable] = None, + on_test_batch_end: Optional[Callable] = None, + on_batch_end: Optional[Callable] = None, + on_train_start: Optional[Callable] = None, + on_train_end: Optional[Callable] = None, + on_pretrain_routine_start: Optional[Callable] = None, + on_pretrain_routine_end: Optional[Callable] = None, + on_validation_start: Optional[Callable] = None, + on_validation_end: Optional[Callable] = None, + on_test_start: Optional[Callable] = None, + on_test_end: Optional[Callable] = None, + on_keyboard_interrupt: Optional[Callable] = None, + on_exception: Optional[Callable] = None, + on_save_checkpoint: Optional[Callable] = None, + on_load_checkpoint: Optional[Callable] = None, + on_before_backward: Optional[Callable] = None, + on_after_backward: Optional[Callable] = None, + on_before_optimizer_step: Optional[Callable] = None, + on_before_zero_grad: Optional[Callable] = None, + on_predict_start: Optional[Callable] = None, + on_predict_end: Optional[Callable] = None, + on_predict_batch_start: Optional[Callable] = None, + on_predict_batch_end: Optional[Callable] = None, + on_predict_epoch_start: Optional[Callable] = None, + on_predict_epoch_end: Optional[Callable] = None, + ): + for k, v in locals().items(): + if k == "self": + continue + if v is not None: + setattr(self, k, v) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/lr_monitor.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/lr_monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..b14985857511880e935f8f4b957116dfbff59a65 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/lr_monitor.py @@ -0,0 +1,354 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +r""" + +Learning Rate Monitor +===================== + +Monitor and logs learning rate for lr schedulers during training. + +""" +import itertools +from collections import defaultdict +from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple, Type + +from torch.optim.optimizer import Optimizer + +import pytorch_lightning as pl +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn +from pytorch_lightning.utilities.types import LRSchedulerConfig + + +class LearningRateMonitor(Callback): + r""" + Automatically monitor and logs learning rate for learning rate schedulers during training. + + Args: + logging_interval: set to ``'epoch'`` or ``'step'`` to log ``lr`` of all optimizers + at the same interval, set to ``None`` to log at individual interval + according to the ``interval`` key of each scheduler. Defaults to ``None``. + log_momentum: option to also log the momentum values of the optimizer, if the optimizer + has the ``momentum`` or ``betas`` attribute. Defaults to ``False``. + + Raises: + MisconfigurationException: + If ``logging_interval`` is none of ``"step"``, ``"epoch"``, or ``None``. + + Example:: + + >>> from pytorch_lightning import Trainer + >>> from pytorch_lightning.callbacks import LearningRateMonitor + >>> lr_monitor = LearningRateMonitor(logging_interval='step') + >>> trainer = Trainer(callbacks=[lr_monitor]) + + Logging names are automatically determined based on optimizer class name. + In case of multiple optimizers of same type, they will be named ``Adam``, + ``Adam-1`` etc. If a optimizer has multiple parameter groups they will + be named ``Adam/pg1``, ``Adam/pg2`` etc. To control naming, pass in a + ``name`` keyword in the construction of the learning rate schedulers. + A ``name`` keyword can also be used for parameter groups in the + construction of the optimizer. + + Example:: + + def configure_optimizer(self): + optimizer = torch.optim.Adam(...) + lr_scheduler = { + 'scheduler': torch.optim.lr_scheduler.LambdaLR(optimizer, ...) + 'name': 'my_logging_name' + } + return [optimizer], [lr_scheduler] + + Example:: + + def configure_optimizer(self): + optimizer = torch.optim.SGD( + [{ + 'params': [p for p in self.parameters()], + 'name': 'my_parameter_group_name' + }], + lr=0.1 + ) + lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, ...) + return [optimizer], [lr_scheduler] + + """ + + def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = False) -> None: + if logging_interval not in (None, "step", "epoch"): + raise MisconfigurationException("logging_interval should be `step` or `epoch` or `None`.") + + self.logging_interval = logging_interval + self.log_momentum = log_momentum + self.lrs: Dict[str, List[float]] = {} + self._lr_sch_names: List[str] = [] + + def on_train_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: + """Called before training, determines unique names for all lr schedulers in the case of multiple of the + same type or in the case of multiple parameter groups. + + Raises: + MisconfigurationException: + If ``Trainer`` has no ``logger``. + """ + if not trainer.loggers: + raise MisconfigurationException( + "Cannot use `LearningRateMonitor` callback with `Trainer` that has no logger." + ) + + if self.log_momentum: + + def _check_no_key(key: str) -> bool: + if trainer.lr_scheduler_configs: + return any( + key not in config.scheduler.optimizer.defaults for config in trainer.lr_scheduler_configs + ) + + return any(key not in optimizer.defaults for optimizer in trainer.optimizers) + + if _check_no_key("momentum") and _check_no_key("betas"): + rank_zero_warn( + "You have set log_momentum=True, but some optimizers do not" + " have momentum. This will log a value 0 for the momentum.", + category=RuntimeWarning, + ) + + # Find names for schedulers + names: List[List[str]] = [] + ( + sched_hparam_keys, + optimizers_with_scheduler, + optimizers_with_scheduler_types, + ) = self._find_names_from_schedulers(trainer.lr_scheduler_configs) + names.extend(sched_hparam_keys) + + # Find names for leftover optimizers + optimizer_hparam_keys, _ = self._find_names_from_optimizers( + trainer.optimizers, + seen_optimizers=optimizers_with_scheduler, + seen_optimizer_types=optimizers_with_scheduler_types, + ) + names.extend(optimizer_hparam_keys) + + # Initialize for storing values + names_flatten = list(itertools.chain.from_iterable(names)) + self.lrs = {name: [] for name in names_flatten} + self.last_momentum_values = {name + "-momentum": None for name in names_flatten} + + def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: + if not trainer._logger_connector.should_update_logs: + return + + if self.logging_interval != "epoch": + interval = "step" if self.logging_interval is None else "any" + latest_stat = self._extract_stats(trainer, interval) + + if latest_stat: + for logger in trainer.loggers: + logger.log_metrics(latest_stat, step=trainer.fit_loop.epoch_loop._batches_that_stepped) + + def on_train_epoch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: + if self.logging_interval != "step": + interval = "epoch" if self.logging_interval is None else "any" + latest_stat = self._extract_stats(trainer, interval) + + if latest_stat: + for logger in trainer.loggers: + logger.log_metrics(latest_stat, step=trainer.fit_loop.epoch_loop._batches_that_stepped) + + def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, float]: + latest_stat = {} + + ( + scheduler_hparam_keys, + optimizers_with_scheduler, + optimizers_with_scheduler_types, + ) = self._find_names_from_schedulers(trainer.lr_scheduler_configs, add_lr_sch_names=False) + self._remap_keys(scheduler_hparam_keys) + + for name, config in zip(scheduler_hparam_keys, trainer.lr_scheduler_configs): + if interval in [config.interval, "any"]: + opt = config.scheduler.optimizer + current_stat = self._get_lr_momentum_stat(opt, name) + latest_stat.update(current_stat) + + optimizer_hparam_keys, optimizers_without_scheduler = self._find_names_from_optimizers( + trainer.optimizers, + seen_optimizers=optimizers_with_scheduler, + seen_optimizer_types=optimizers_with_scheduler_types, + add_lr_sch_names=False, + ) + self._remap_keys(optimizer_hparam_keys) + + for opt, names in zip(optimizers_without_scheduler, optimizer_hparam_keys): + current_stat = self._get_lr_momentum_stat(opt, names) + latest_stat.update(current_stat) + + return latest_stat + + def _get_lr_momentum_stat(self, optimizer: Optimizer, names: List[str]) -> Dict[str, float]: + lr_momentum_stat = {} + param_groups = optimizer.param_groups + use_betas = "betas" in optimizer.defaults + + for pg, name in zip(param_groups, names): + lr = self._extract_lr(pg, name) + lr_momentum_stat.update(lr) + momentum = self._extract_momentum( + param_group=pg, name=name.replace(name, f"{name}-momentum"), use_betas=use_betas + ) + lr_momentum_stat.update(momentum) + + return lr_momentum_stat + + def _extract_lr(self, param_group: Dict[str, Any], name: str) -> Dict[str, Any]: + lr = param_group["lr"] + self.lrs[name].append(lr) + return {name: lr} + + def _remap_keys(self, names: List[List[str]], token: str = "/pg1") -> None: + """This function is used the remap the keys if param groups for a given optimizer increased.""" + for group_new_names in names: + for new_name in group_new_names: + old_name = new_name.replace(token, "") + if token in new_name and old_name in self.lrs: + self.lrs[new_name] = self.lrs.pop(old_name) + elif new_name not in self.lrs: + self.lrs[new_name] = [] + + def _extract_momentum(self, param_group: Dict[str, List], name: str, use_betas: bool) -> Dict[str, float]: + if not self.log_momentum: + return {} + + momentum = param_group["betas"][0] if use_betas else param_group.get("momentum", 0) + self.last_momentum_values[name] = momentum + return {name: momentum} + + def _add_prefix( + self, name: str, optimizer_cls: Type[Optimizer], seen_optimizer_types: DefaultDict[Type[Optimizer], int] + ) -> str: + if optimizer_cls not in seen_optimizer_types: + return name + count = seen_optimizer_types[optimizer_cls] + return name + f"-{count - 1}" if count > 1 else name + + def _add_suffix(self, name: str, param_groups: List[Dict], param_group_index: int, use_names: bool = True) -> str: + if len(param_groups) > 1: + if not use_names: + return f"{name}/pg{param_group_index+1}" + pg_name = param_groups[param_group_index].get("name", f"pg{param_group_index+1}") + return f"{name}/{pg_name}" + elif use_names: + pg_name = param_groups[param_group_index].get("name") + return f"{name}/{pg_name}" if pg_name else name + return name + + def _duplicate_param_group_names(self, param_groups: List[Dict]) -> Set[str]: + names = [pg.get("name", f"pg{i}") for i, pg in enumerate(param_groups, start=1)] + unique = set(names) + if len(names) == len(unique): + return set() + return {n for n in names if names.count(n) > 1} + + def _find_names_from_schedulers( + self, lr_scheduler_configs: List[LRSchedulerConfig], add_lr_sch_names: bool = True + ) -> Tuple[List[List[str]], List[Optimizer], DefaultDict[Type[Optimizer], int]]: + # Create unique names in the case we have multiple of the same learning + # rate scheduler + multiple parameter groups + names = [] + seen_optimizers: List[Optimizer] = [] + seen_optimizer_types: DefaultDict[Type[Optimizer], int] = defaultdict(int) + for config in lr_scheduler_configs: + sch = config.scheduler + if config.name is not None: + name = config.name + else: + name = "lr-" + sch.optimizer.__class__.__name__ + + updated_names = self._check_duplicates_and_update_name( + sch.optimizer, name, seen_optimizers, seen_optimizer_types, config, add_lr_sch_names + ) + names.append(updated_names) + + return names, seen_optimizers, seen_optimizer_types + + def _find_names_from_optimizers( + self, + optimizers: List[Any], + seen_optimizers: List[Optimizer], + seen_optimizer_types: DefaultDict[Type[Optimizer], int], + add_lr_sch_names: bool = True, + ) -> Tuple[List[List[str]], List[Optimizer]]: + names = [] + optimizers_without_scheduler = [] + + for optimizer in optimizers: + # Deepspeed optimizer wraps the native optimizer + optimizer = optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer + if optimizer in seen_optimizers: + continue + + name = "lr-" + optimizer.__class__.__name__ + updated_names = self._check_duplicates_and_update_name( + optimizer, name, seen_optimizers, seen_optimizer_types, None, add_lr_sch_names + ) + names.append(updated_names) + optimizers_without_scheduler.append(optimizer) + + return names, optimizers_without_scheduler + + def _check_duplicates_and_update_name( + self, + optimizer: Optimizer, + name: str, + seen_optimizers: List[Optimizer], + seen_optimizer_types: DefaultDict[Type[Optimizer], int], + lr_scheduler_config: Optional[LRSchedulerConfig], + add_lr_sch_names: bool = True, + ) -> List[str]: + seen_optimizers.append(optimizer) + optimizer_cls = type(optimizer) + if lr_scheduler_config is not None and lr_scheduler_config.name is None: + seen_optimizer_types[optimizer_cls] += 1 + elif lr_scheduler_config is None: + seen_optimizer_types[optimizer_cls] += 1 + + # Multiple param groups for the same optimizer + param_groups = optimizer.param_groups + duplicates = self._duplicate_param_group_names(param_groups) + if duplicates: + raise MisconfigurationException( + "A single `Optimizer` cannot have multiple parameter groups with identical " + f"`name` values. {name} has duplicated parameter group names {duplicates}" + ) + + name = self._add_prefix(name, optimizer_cls, seen_optimizer_types) + name_list = [self._add_suffix(name, param_groups, i) for i in range(len(param_groups))] + + if add_lr_sch_names: + self._lr_sch_names.append(name) + + return name_list + + @property + def lr_sch_names(self) -> List[str]: + # TODO remove `lr_sch_names` and `add_lr_sch_names` argument in v1.7.0 + rank_zero_deprecation( + "`LearningRateMonitor.lr_sch_names` has been deprecated in v1.5 and will be removed in 1.7." + " Consider accessing them using `LearningRateMonitor.lrs.keys()` which will return" + " the names of all the optimizers, even those without a scheduler." + ) + return self._lr_sch_names diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..0621ea8eb2cc3aea8b96cb47ec18ccdfe2034384 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py @@ -0,0 +1,720 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Model Checkpointing +=================== + +Automatically save model checkpoints during training. + +""" +import logging +import os +import re +import time +import warnings +from copy import deepcopy +from datetime import timedelta +from typing import Any, Dict, Optional +from weakref import proxy + +import numpy as np +import torch +import yaml + +import pytorch_lightning as pl +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.utilities.cloud_io import get_filesystem +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.logger import _name, _version +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn +from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT +from pytorch_lightning.utilities.warnings import WarningCache + +log = logging.getLogger(__name__) +warning_cache = WarningCache() + + +class ModelCheckpoint(Callback): + r""" + Save the model periodically by monitoring a quantity. Every metric logged with + :meth:`~pytorch_lightning.core.lightning.log` or :meth:`~pytorch_lightning.core.lightning.log_dict` in + LightningModule is a candidate for the monitor key. For more information, see + :ref:`checkpointing`. + + After training finishes, use :attr:`best_model_path` to retrieve the path to the + best checkpoint file and :attr:`best_model_score` to retrieve its score. + + Args: + dirpath: directory to save the model file. + + Example:: + + # custom path + # saves a file like: my/path/epoch=0-step=10.ckpt + >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/') + + By default, dirpath is ``None`` and will be set at runtime to the location + specified by :class:`~pytorch_lightning.trainer.trainer.Trainer`'s + :paramref:`~pytorch_lightning.trainer.trainer.Trainer.default_root_dir` or + :paramref:`~pytorch_lightning.trainer.trainer.Trainer.weights_save_path` arguments, + and if the Trainer uses a logger, the path will also contain logger name and version. + + filename: checkpoint filename. Can contain named formatting options to be auto-filled. + + Example:: + + # save any arbitrary metrics like `val_loss`, etc. in name + # saves a file like: my/path/epoch=2-val_loss=0.02-other_metric=0.03.ckpt + >>> checkpoint_callback = ModelCheckpoint( + ... dirpath='my/path', + ... filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}' + ... ) + + By default, filename is ``None`` and will be set to ``'{epoch}-{step}'``. + monitor: quantity to monitor. By default it is ``None`` which saves a checkpoint only for the last epoch. + verbose: verbosity mode. Default: ``False``. + save_last: When ``True``, saves an exact copy of the checkpoint to a file `last.ckpt` whenever a checkpoint + file gets saved. This allows accessing the latest checkpoint in a deterministic manner. Default: ``None``. + save_top_k: if ``save_top_k == k``, + the best k models according to + the quantity monitored will be saved. + if ``save_top_k == 0``, no models are saved. + if ``save_top_k == -1``, all models are saved. + Please note that the monitors are checked every ``every_n_epochs`` epochs. + if ``save_top_k >= 2`` and the callback is called multiple + times inside an epoch, the name of the saved file will be + appended with a version count starting with ``v1``. + mode: one of {min, max}. + If ``save_top_k != 0``, the decision to overwrite the current save file is made + based on either the maximization or the minimization of the monitored quantity. + For ``'val_acc'``, this should be ``'max'``, for ``'val_loss'`` this should be ``'min'``, etc. + auto_insert_metric_name: When ``True``, the checkpoints filenames will contain the metric name. + For example, ``filename='checkpoint_{epoch:02d}-{acc:02.0f}`` with epoch ``1`` and acc ``1.12`` will resolve + to ``checkpoint_epoch=01-acc=01.ckpt``. Is useful to set it to ``False`` when metric names contain ``/`` + as this will result in extra folders. + save_weights_only: if ``True``, then only the model's weights will be + saved. Otherwise, the optimizer states, lr-scheduler states, etc are added in the checkpoint too. + every_n_train_steps: Number of training steps between checkpoints. + If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training. + To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative. + This must be mutually exclusive with ``train_time_interval`` and ``every_n_epochs``. + train_time_interval: Checkpoints are monitored at the specified time interval. + For all practical purposes, this cannot be smaller than the amount + of time it takes to process a single training batch. This is not + guaranteed to execute at the exact time specified, but should be close. + This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``. + every_n_epochs: Number of epochs between checkpoints. + This value must be ``None`` or non-negative. + To disable saving top-k checkpoints, set ``every_n_epochs = 0``. + This argument does not impact the saving of ``save_last=True`` checkpoints. + If all of ``every_n_epochs``, ``every_n_train_steps`` and + ``train_time_interval`` are ``None``, we save a checkpoint at the end of every epoch + (equivalent to ``every_n_epochs = 1``). + If ``every_n_epochs == None`` and either ``every_n_train_steps != None`` or ``train_time_interval != None``, + saving at the end of each epoch is disabled + (equivalent to ``every_n_epochs = 0``). + This must be mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``. + Setting both ``ModelCheckpoint(..., every_n_epochs=V, save_on_train_epoch_end=False)`` and + ``Trainer(max_epochs=N, check_val_every_n_epoch=M)`` + will only save checkpoints at epochs 0 < E <= N + where both values for ``every_n_epochs`` and ``check_val_every_n_epoch`` evenly divide E. + save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch. + If this is ``False``, then the check runs at the end of the validation. + + Note: + For extra customization, ModelCheckpoint includes the following attributes: + + - ``CHECKPOINT_JOIN_CHAR = "-"`` + - ``CHECKPOINT_NAME_LAST = "last"`` + - ``FILE_EXTENSION = ".ckpt"`` + - ``STARTING_VERSION = 1`` + + For example, you can change the default last checkpoint name by doing + ``checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"`` + + If you want to checkpoint every N hours, every M train batches, and/or every K val epochs, + then you should create multiple ``ModelCheckpoint`` callbacks. + + If the checkpoint's ``dirpath`` changed from what it was before while resuming the training, + only ``best_model_path`` will be reloaded and a warning will be issued. + + Raises: + MisconfigurationException: + If ``save_top_k`` is smaller than ``-1``, + if ``monitor`` is ``None`` and ``save_top_k`` is none of ``None``, ``-1``, and ``0``, or + if ``mode`` is none of ``"min"`` or ``"max"``. + ValueError: + If ``trainer.save_checkpoint`` is ``None``. + + Example:: + + >>> from pytorch_lightning import Trainer + >>> from pytorch_lightning.callbacks import ModelCheckpoint + + # saves checkpoints to 'my/path/' at every epoch + >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/') + >>> trainer = Trainer(callbacks=[checkpoint_callback]) + + # save epoch and val_loss in name + # saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt + >>> checkpoint_callback = ModelCheckpoint( + ... monitor='val_loss', + ... dirpath='my/path/', + ... filename='sample-mnist-{epoch:02d}-{val_loss:.2f}' + ... ) + + # save epoch and val_loss in name, but specify the formatting yourself (e.g. to avoid problems with Tensorboard + # or Neptune, due to the presence of characters like '=' or '/') + # saves a file like: my/path/sample-mnist-epoch02-val_loss0.32.ckpt + >>> checkpoint_callback = ModelCheckpoint( + ... monitor='val/loss', + ... dirpath='my/path/', + ... filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}', + ... auto_insert_metric_name=False + ... ) + + # retrieve the best checkpoint after training + checkpoint_callback = ModelCheckpoint(dirpath='my/path/') + trainer = Trainer(callbacks=[checkpoint_callback]) + model = ... + trainer.fit(model) + checkpoint_callback.best_model_path + + .. tip:: Saving and restoring multiple checkpoint callbacks at the same time is supported under variation in the + following arguments: + + *monitor, mode, every_n_train_steps, every_n_epochs, train_time_interval, save_on_train_epoch_end* + + Read more: :ref:`Persisting Callback State` + """ + + CHECKPOINT_JOIN_CHAR = "-" + CHECKPOINT_NAME_LAST = "last" + FILE_EXTENSION = ".ckpt" + STARTING_VERSION = 1 + + def __init__( + self, + dirpath: Optional[_PATH] = None, + filename: Optional[str] = None, + monitor: Optional[str] = None, + verbose: bool = False, + save_last: Optional[bool] = None, + save_top_k: int = 1, + save_weights_only: bool = False, + mode: str = "min", + auto_insert_metric_name: bool = True, + every_n_train_steps: Optional[int] = None, + train_time_interval: Optional[timedelta] = None, + every_n_epochs: Optional[int] = None, + save_on_train_epoch_end: Optional[bool] = None, + ): + super().__init__() + self.monitor = monitor + self.verbose = verbose + self.save_last = save_last + self.save_top_k = save_top_k + self.save_weights_only = save_weights_only + self.auto_insert_metric_name = auto_insert_metric_name + self._save_on_train_epoch_end = save_on_train_epoch_end + self._last_global_step_saved = 0 # no need to save when no steps were taken + self._last_time_checked: Optional[float] = None + self.current_score = None + self.best_k_models = {} + self.kth_best_model_path = "" + self.best_model_score = None + self.best_model_path = "" + self.last_model_path = "" + + self.__init_monitor_mode(mode) + self.__init_ckpt_dir(dirpath, filename) + self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval) + self.__validate_init_configuration() + + @property + def state_key(self) -> str: + return self._generate_state_key( + monitor=self.monitor, + mode=self.mode, + every_n_train_steps=self._every_n_train_steps, + every_n_epochs=self._every_n_epochs, + train_time_interval=self._train_time_interval, + save_on_train_epoch_end=self._save_on_train_epoch_end, + ) + + def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: + self.__resolve_ckpt_dir(trainer) + if trainer.is_global_zero and stage == "fit": + self.__warn_if_dir_not_empty(self.dirpath) + + # NOTE: setting these attributes needs to happen as early as possible BEFORE reloading callback states, + # because the attributes are part of the state_key which needs to be fully defined before reloading. + if self._save_on_train_epoch_end is None: + # if the user runs validation multiple times per training epoch or multiple training epochs without + # validation, then we run after validation instead of on train epoch end + self._save_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1 + + def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self._last_time_checked = time.monotonic() + + def on_train_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: STEP_OUTPUT, + batch: Any, + batch_idx: int, + ) -> None: + """Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`""" + if self._should_skip_saving_checkpoint(trainer): + return + skip_batch = self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0) + + train_time_interval = self._train_time_interval + skip_time = True + now = time.monotonic() + if train_time_interval: + prev_time_check = self._last_time_checked + skip_time = prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds() + # in case we have time differences across ranks + # broadcast the decision on whether to checkpoint from rank 0 to avoid possible hangs + skip_time = trainer.strategy.broadcast(skip_time) + + if skip_batch and skip_time: + return + if not skip_time: + self._last_time_checked = now + + monitor_candidates = self._monitor_candidates(trainer) + self._save_topk_checkpoint(trainer, monitor_candidates) + self._save_last_checkpoint(trainer, monitor_candidates) + + def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Save a checkpoint at the end of the training epoch.""" + if not self._should_skip_saving_checkpoint(trainer) and self._save_on_train_epoch_end: + monitor_candidates = self._monitor_candidates(trainer) + if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0: + self._save_topk_checkpoint(trainer, monitor_candidates) + self._save_last_checkpoint(trainer, monitor_candidates) + + def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Save a checkpoint at the end of the validation stage.""" + if not self._should_skip_saving_checkpoint(trainer) and not self._save_on_train_epoch_end: + monitor_candidates = self._monitor_candidates(trainer) + if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0: + self._save_topk_checkpoint(trainer, monitor_candidates) + self._save_last_checkpoint(trainer, monitor_candidates) + + def state_dict(self) -> Dict[str, Any]: + return { + "monitor": self.monitor, + "best_model_score": self.best_model_score, + "best_model_path": self.best_model_path, + "current_score": self.current_score, + "dirpath": self.dirpath, + "best_k_models": self.best_k_models, + "kth_best_model_path": self.kth_best_model_path, + "kth_value": self.kth_value, + "last_model_path": self.last_model_path, + } + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + dirpath_from_ckpt = state_dict.get("dirpath", self.dirpath) + + if self.dirpath == dirpath_from_ckpt: + self.best_model_score = state_dict["best_model_score"] + self.kth_best_model_path = state_dict.get("kth_best_model_path", self.kth_best_model_path) + self.kth_value = state_dict.get("kth_value", self.kth_value) + self.best_k_models = state_dict.get("best_k_models", self.best_k_models) + self.last_model_path = state_dict.get("last_model_path", self.last_model_path) + else: + warnings.warn( + f"The dirpath has changed from {dirpath_from_ckpt!r} to {self.dirpath!r}," + " therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and" + " `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded." + ) + + self.best_model_path = state_dict["best_model_path"] + + def save_checkpoint(self, trainer: "pl.Trainer") -> None: # pragma: no-cover + """Performs the main logic around saving a checkpoint. + + This method runs on all ranks. It is the responsibility of `trainer.save_checkpoint` to correctly handle the + behaviour in distributed training, i.e., saving only on rank 0 for data parallel use cases. + """ + rank_zero_deprecation( + f"`{self.__class__.__name__}.save_checkpoint()` was deprecated in v1.6 and will be removed in v1.8." + " Instead, you can use `trainer.save_checkpoint()` to manually save a checkpoint." + ) + monitor_candidates = self._monitor_candidates(trainer) + self._save_topk_checkpoint(trainer, monitor_candidates) + self._save_last_checkpoint(trainer, monitor_candidates) + + def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None: + if self.save_top_k == 0: + return + + # validate metric + if self.monitor is not None: + if self.monitor not in monitor_candidates: + m = ( + f"`ModelCheckpoint(monitor={self.monitor!r})` could not find the monitored key in the returned" + f" metrics: {list(monitor_candidates)}." + f" HINT: Did you call `log({self.monitor!r}, value)` in the `LightningModule`?" + ) + if trainer.fit_loop.epoch_loop.val_loop._has_run: + raise MisconfigurationException(m) + warning_cache.warn(m) + self._save_monitor_checkpoint(trainer, monitor_candidates) + else: + self._save_none_monitor_checkpoint(trainer, monitor_candidates) + + def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None: + trainer.save_checkpoint(filepath, self.save_weights_only) + + self._last_global_step_saved = trainer.global_step + + # notify loggers + if trainer.is_global_zero: + for logger in trainer.loggers: + logger.after_save_checkpoint(proxy(self)) + + def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool: + from pytorch_lightning.trainer.states import TrainerFn + + return ( + trainer.fast_dev_run # disable checkpointing with fast_dev_run + or trainer.state.fn != TrainerFn.FITTING # don't save anything during non-fit + or trainer.sanity_checking # don't save anything during sanity check + or self._last_global_step_saved == trainer.global_step # already saved at the last step + ) + + def __validate_init_configuration(self) -> None: + if self.save_top_k < -1: + raise MisconfigurationException(f"Invalid value for save_top_k={self.save_top_k}. Must be >= -1") + if self._every_n_train_steps < 0: + raise MisconfigurationException( + f"Invalid value for every_n_train_steps={self._every_n_train_steps}. Must be >= 0" + ) + if self._every_n_epochs < 0: + raise MisconfigurationException(f"Invalid value for every_n_epochs={self._every_n_epochs}. Must be >= 0") + + every_n_train_steps_triggered = self._every_n_train_steps >= 1 + every_n_epochs_triggered = self._every_n_epochs >= 1 + train_time_interval_triggered = self._train_time_interval is not None + if every_n_train_steps_triggered + every_n_epochs_triggered + train_time_interval_triggered > 1: + raise MisconfigurationException( + f"Combination of parameters every_n_train_steps={self._every_n_train_steps}, " + f"every_n_epochs={self._every_n_epochs} and train_time_interval={self._train_time_interval} " + "should be mutually exclusive." + ) + + if self.monitor is None: + # -1: save all epochs, 0: nothing is saved, 1: save last epoch + if self.save_top_k not in (-1, 0, 1): + raise MisconfigurationException( + f"ModelCheckpoint(save_top_k={self.save_top_k}, monitor=None) is not a valid" + " configuration. No quantity for top_k to track." + ) + + if self.save_top_k == -1 and self.save_last: + rank_zero_info( + "ModelCheckpoint(save_last=True, save_top_k=-1, monitor=None)" + " will duplicate the last checkpoint saved." + ) + + def __init_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) -> None: + self._fs = get_filesystem(dirpath if dirpath else "") + + if dirpath and self._fs.protocol == "file": + dirpath = os.path.realpath(dirpath) + + self.dirpath = dirpath + self.filename = filename + + def __init_monitor_mode(self, mode: str) -> None: + torch_inf = torch.tensor(np.Inf) + mode_dict = {"min": (torch_inf, "min"), "max": (-torch_inf, "max")} + + if mode not in mode_dict: + raise MisconfigurationException(f"`mode` can be {', '.join(mode_dict.keys())} but got {mode}") + + self.kth_value, self.mode = mode_dict[mode] + + def __init_triggers( + self, + every_n_train_steps: Optional[int], + every_n_epochs: Optional[int], + train_time_interval: Optional[timedelta], + ) -> None: + + # Default to running once after each validation epoch if neither + # every_n_train_steps nor every_n_epochs is set + if every_n_train_steps is None and every_n_epochs is None and train_time_interval is None: + every_n_epochs = 1 + every_n_train_steps = 0 + log.debug("Both every_n_train_steps and every_n_epochs are not set. Setting every_n_epochs=1") + else: + every_n_epochs = every_n_epochs or 0 + every_n_train_steps = every_n_train_steps or 0 + + self._train_time_interval: Optional[timedelta] = train_time_interval + self._every_n_epochs: int = every_n_epochs + self._every_n_train_steps: int = every_n_train_steps + + @property + def every_n_epochs(self) -> Optional[int]: + return self._every_n_epochs + + def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[torch.Tensor] = None) -> bool: + if current is None: + return False + + if self.save_top_k == -1: + return True + + less_than_k_models = len(self.best_k_models) < self.save_top_k + if less_than_k_models: + return True + + monitor_op = {"min": torch.lt, "max": torch.gt}[self.mode] + should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path]) + + # If using multiple devices, make sure all processes are unanimous on the decision. + should_update_best_and_save = trainer.strategy.reduce_boolean_decision(should_update_best_and_save) + + return should_update_best_and_save + + @classmethod + def _format_checkpoint_name( + cls, + filename: Optional[str], + metrics: Dict[str, _METRIC], + prefix: str = "", + auto_insert_metric_name: bool = True, + ) -> str: + if not filename: + # filename is not set, use default name + filename = "{epoch}" + cls.CHECKPOINT_JOIN_CHAR + "{step}" + + # check and parse user passed keys in the string + groups = re.findall(r"(\{.*?)[:\}]", filename) + if len(groups) >= 0: + for group in groups: + name = group[1:] + + if auto_insert_metric_name: + filename = filename.replace(group, name + "={" + name) + + # support for dots: https://stackoverflow.com/a/7934969 + filename = filename.replace(group, f"{{0[{name}]") + + if name not in metrics: + metrics[name] = 0 + filename = filename.format(metrics) + + if prefix: + filename = cls.CHECKPOINT_JOIN_CHAR.join([prefix, filename]) + + return filename + + def format_checkpoint_name( + self, metrics: Dict[str, _METRIC], filename: Optional[str] = None, ver: Optional[int] = None + ) -> str: + """Generate a filename according to the defined template. + + Example:: + + >>> tmpdir = os.path.dirname(__file__) + >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}') + >>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=0))) + 'epoch=0.ckpt' + >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch:03d}') + >>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=5))) + 'epoch=005.ckpt' + >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}') + >>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.123456))) + 'epoch=2-val_loss=0.12.ckpt' + >>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.12), filename='{epoch:d}')) + 'epoch=2.ckpt' + >>> ckpt = ModelCheckpoint(dirpath=tmpdir, + ... filename='epoch={epoch}-validation_loss={val_loss:.2f}', + ... auto_insert_metric_name=False) + >>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.123456))) + 'epoch=2-validation_loss=0.12.ckpt' + >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}') + >>> os.path.basename(ckpt.format_checkpoint_name({})) + 'missing=0.ckpt' + >>> ckpt = ModelCheckpoint(filename='{step}') + >>> os.path.basename(ckpt.format_checkpoint_name(dict(step=0))) + 'step=0.ckpt' + """ + filename = filename or self.filename + filename = self._format_checkpoint_name(filename, metrics, auto_insert_metric_name=self.auto_insert_metric_name) + + if ver is not None: + filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}")) + + ckpt_name = f"{filename}{self.FILE_EXTENSION}" + return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name + + def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None: + """Determines model checkpoint save directory at runtime. References attributes from the trainer's logger + to determine where to save checkpoints. The base path for saving weights is set in this priority: + + 1. Checkpoint callback's path (if passed in) + 2. The default_root_dir from trainer if trainer has no logger + 3. The weights_save_path from trainer, if user provides it (deprecated) + 4. User provided weights_saved_path + + The base path gets extended with logger name and version (if these are available) + and subfolder "checkpoints". + """ + if self.dirpath is not None: + return # short circuit + + # TODO: Remove weights_save_path logic here in v1.8 + if trainer.loggers: + if trainer._weights_save_path_internal != trainer.default_root_dir: + # the user has changed weights_save_path, it overrides anything + save_dir = trainer._weights_save_path_internal + elif len(trainer.loggers) == 1: + save_dir = trainer.logger.save_dir or trainer.default_root_dir + else: + save_dir = trainer.default_root_dir + + name = _name(trainer.loggers) + version = _version(trainer.loggers) + version = version if isinstance(version, str) else f"version_{version}" + + ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints") + else: + ckpt_path = os.path.join(trainer._weights_save_path_internal, "checkpoints") + + ckpt_path = trainer.strategy.broadcast(ckpt_path) + + self.dirpath = ckpt_path + + def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None: + if self.save_top_k != 0 and self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0: + rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.") + + def _get_metric_interpolated_filepath_name( + self, monitor_candidates: Dict[str, _METRIC], trainer: "pl.Trainer", del_filepath: Optional[str] = None + ) -> str: + filepath = self.format_checkpoint_name(monitor_candidates) + + version_cnt = self.STARTING_VERSION + while self.file_exists(filepath, trainer) and filepath != del_filepath: + filepath = self.format_checkpoint_name(monitor_candidates, ver=version_cnt) + version_cnt += 1 + + return filepath + + def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, _METRIC]: + monitor_candidates = deepcopy(trainer.callback_metrics) + # cast to int if necessary because `self.log("epoch", 123)` will convert it to float. if it's not a tensor + # or does not exist we overwrite it as it's likely an error + epoch = monitor_candidates.get("epoch") + monitor_candidates["epoch"] = ( + epoch.int() if isinstance(epoch, torch.Tensor) else torch.tensor(trainer.current_epoch) + ) + step = monitor_candidates.get("step") + monitor_candidates["step"] = step.int() if isinstance(step, torch.Tensor) else torch.tensor(trainer.global_step) + return monitor_candidates + + def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None: + if not self.save_last: + return + + filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST) + # set the last model path before saving because it will be part of the state. + previous, self.last_model_path = self.last_model_path, filepath + self._save_checkpoint(trainer, filepath) + if previous and previous != filepath: + trainer.strategy.remove_checkpoint(previous) + + def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None: + current = monitor_candidates.get(self.monitor) + if self.check_monitor_top_k(trainer, current): + self._update_best_and_save(current, trainer, monitor_candidates) + elif self.verbose: + epoch = monitor_candidates["epoch"] + step = monitor_candidates["step"] + rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} was not in top {self.save_top_k}") + + def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None: + filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer) + # set the best model path before saving because it will be part of the state. + previous, self.best_model_path = self.best_model_path, filepath + self._save_checkpoint(trainer, filepath) + if self.save_top_k == 1 and previous and previous != filepath: + trainer.strategy.remove_checkpoint(previous) + + def _update_best_and_save( + self, current: torch.Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC] + ) -> None: + k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k + + del_filepath = None + if len(self.best_k_models) == k and k > 0: + del_filepath = self.kth_best_model_path + self.best_k_models.pop(del_filepath) + + # do not save nan, replace with +/- inf + if isinstance(current, torch.Tensor) and torch.isnan(current): + current = torch.tensor(float("inf" if self.mode == "min" else "-inf"), device=current.device) + + filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer, del_filepath) + + # save the current score + self.current_score = current + self.best_k_models[filepath] = current + + if len(self.best_k_models) == k: + # monitor dict has reached k elements + _op = max if self.mode == "min" else min + self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get) + self.kth_value = self.best_k_models[self.kth_best_model_path] + + _op = min if self.mode == "min" else max + self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get) + self.best_model_score = self.best_k_models[self.best_model_path] + + if self.verbose: + epoch = monitor_candidates["epoch"] + step = monitor_candidates["step"] + rank_zero_info( + f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} reached {current:0.5f}" + f" (best {self.best_model_score:0.5f}), saving model to {filepath!r} as top {k}" + ) + self._save_checkpoint(trainer, filepath) + + if del_filepath is not None and filepath != del_filepath: + trainer.strategy.remove_checkpoint(del_filepath) + + def to_yaml(self, filepath: Optional[_PATH] = None) -> None: + """Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML + file.""" + best_k = {k: v.item() for k, v in self.best_k_models.items()} + if filepath is None: + filepath = os.path.join(self.dirpath, "best_k_models.yaml") + with self._fs.open(filepath, "w") as fp: + yaml.dump(best_k, fp) + + def file_exists(self, filepath: _PATH, trainer: "pl.Trainer") -> bool: + """Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal + state to diverge between ranks.""" + exists = self._fs.exists(filepath) + return trainer.strategy.broadcast(exists) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_summary.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_summary.py new file mode 100644 index 0000000000000000000000000000000000000000..e659ddd057ace04da81d7721fe9b2ea5743615d2 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_summary.py @@ -0,0 +1,73 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Model Summary +============= + +Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`. + +The string representation of this summary prints a table with columns containing +the name, type and number of parameters for each layer. + +""" +import logging +from typing import List, Tuple + +import pytorch_lightning as pl +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.utilities.model_summary import _format_summary_table, summarize + +log = logging.getLogger(__name__) + + +class ModelSummary(Callback): + r""" + Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`. + + Args: + max_depth: The maximum depth of layer nesting that the summary will include. A value of 0 turns the + layer summary off. + + Example:: + + >>> from pytorch_lightning import Trainer + >>> from pytorch_lightning.callbacks import ModelSummary + >>> trainer = Trainer(callbacks=[ModelSummary(max_depth=1)]) + """ + + def __init__(self, max_depth: int = 1) -> None: + self._max_depth: int = max_depth + + def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if not self._max_depth: + return None + + model_summary = summarize(pl_module, max_depth=self._max_depth) + summary_data = model_summary._get_summary_data() + total_parameters = model_summary.total_parameters + trainable_parameters = model_summary.trainable_parameters + model_size = model_summary.model_size + + if trainer.is_global_zero: + self.summarize(summary_data, total_parameters, trainable_parameters, model_size) + + @staticmethod + def summarize( + summary_data: List[Tuple[str, List[str]]], + total_parameters: int, + trainable_parameters: int, + model_size: float, + ) -> None: + summary_table = _format_summary_table(total_parameters, trainable_parameters, model_size, *summary_data) + log.info("\n" + summary_table) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/prediction_writer.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/prediction_writer.py new file mode 100644 index 0000000000000000000000000000000000000000..c7461a82894a1979d1ff2b9c37ac9f4d75a4fb10 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/prediction_writer.py @@ -0,0 +1,119 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +r""" +BasePredictionWriter +==================== + +Aids in saving predictions +""" +from typing import Any, Optional, Sequence + +import pytorch_lightning as pl +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.utilities import LightningEnum +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +class WriteInterval(LightningEnum): + BATCH = "batch" + EPOCH = "epoch" + BATCH_AND_EPOCH = "batch_and_epoch" + + @property + def on_batch(self) -> bool: + return self in (self.BATCH, self.BATCH_AND_EPOCH) + + @property + def on_epoch(self) -> bool: + return self in (self.EPOCH, self.BATCH_AND_EPOCH) + + +class BasePredictionWriter(Callback): + """Base class to implement how the predictions should be stored. + + Args: + write_interval: When to write. + + Example:: + + import torch + from pytorch_lightning.callbacks import BasePredictionWriter + + class CustomWriter(BasePredictionWriter): + + def __init__(self, output_dir: str, write_interval: str): + super().__init__(write_interval) + self.output_dir + + def write_on_batch_end( + self, trainer, pl_module: 'LightningModule', prediction: Any, batch_indices: List[int], batch: Any, + batch_idx: int, dataloader_idx: int + ): + torch.save(prediction, os.path.join(self.output_dir, dataloader_idx, f"{batch_idx}.pt")) + + def write_on_epoch_end( + self, trainer, pl_module: 'LightningModule', predictions: List[Any], batch_indices: List[Any] + ): + torch.save(predictions, os.path.join(self.output_dir, "predictions.pt")) + """ + + def __init__(self, write_interval: str = "batch") -> None: + if write_interval not in list(WriteInterval): + raise MisconfigurationException(f"`write_interval` should be one of {[i.value for i in WriteInterval]}.") + self.interval = WriteInterval(write_interval) + + def write_on_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + prediction: Any, + batch_indices: Optional[Sequence[int]], + batch: Any, + batch_idx: int, + dataloader_idx: int, + ) -> None: + """Override with the logic to write a single batch.""" + raise NotImplementedError() + + def write_on_epoch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + predictions: Sequence[Any], + batch_indices: Optional[Sequence[Any]], + ) -> None: + """Override with the logic to write all batches.""" + raise NotImplementedError() + + def on_predict_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int, + ) -> None: + if not self.interval.on_batch: + return + batch_indices = trainer.predict_loop.epoch_loop.current_batch_indices + self.write_on_batch_end(trainer, pl_module, outputs, batch_indices, batch, batch_idx, dataloader_idx) + + def on_predict_epoch_end( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: Sequence[Any] + ) -> None: + if not self.interval.on_epoch: + return + epoch_batch_indices = trainer.predict_loop.epoch_batch_indices + self.write_on_epoch_end(trainer, pl_module, trainer.predict_loop.predictions, epoch_batch_indices) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/pruning.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/pruning.py new file mode 100644 index 0000000000000000000000000000000000000000..4fdc872a0020f69b8c6da39139d3aff19eac24ac --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/pruning.py @@ -0,0 +1,486 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +r""" +ModelPruning +^^^^^^^^^^^^ +""" +import inspect +import logging +from copy import deepcopy +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn.utils.prune as pytorch_prune +from torch import nn +from typing_extensions import TypedDict + +import pytorch_lightning as pl +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_only + +log = logging.getLogger(__name__) + +_PYTORCH_PRUNING_FUNCTIONS = { + "ln_structured": pytorch_prune.ln_structured, + "l1_unstructured": pytorch_prune.l1_unstructured, + "random_structured": pytorch_prune.random_structured, + "random_unstructured": pytorch_prune.random_unstructured, +} + +_PYTORCH_PRUNING_METHOD = { + "ln_structured": pytorch_prune.LnStructured, + "l1_unstructured": pytorch_prune.L1Unstructured, + "random_structured": pytorch_prune.RandomStructured, + "random_unstructured": pytorch_prune.RandomUnstructured, +} + +_PARAM_TUPLE = Tuple[nn.Module, str] +_PARAM_LIST = Sequence[_PARAM_TUPLE] +_MODULE_CONTAINERS = (LightningModule, nn.Sequential, nn.ModuleList, nn.ModuleDict) + + +class _LayerRef(TypedDict): + data: nn.Module + names: List[Tuple[int, str]] + + +class ModelPruning(Callback): + PARAMETER_NAMES = ("weight", "bias") + + def __init__( + self, + pruning_fn: Union[Callable, str], + parameters_to_prune: _PARAM_LIST = (), + parameter_names: Optional[List[str]] = None, + use_global_unstructured: bool = True, + amount: Union[int, float, Callable[[int], Union[int, float]]] = 0.5, + apply_pruning: Union[bool, Callable[[int], bool]] = True, + make_pruning_permanent: bool = True, + use_lottery_ticket_hypothesis: Union[bool, Callable[[int], bool]] = True, + resample_parameters: bool = False, + pruning_dim: Optional[int] = None, + pruning_norm: Optional[int] = None, + verbose: int = 0, + prune_on_train_epoch_end: bool = True, + ) -> None: + """Model pruning Callback, using PyTorch's prune utilities. This callback is responsible of pruning + networks parameters during training. + + To learn more about pruning with PyTorch, please take a look at + `this tutorial `_. + + .. warning:: ``ModelPruning`` is in beta and subject to change. + + .. code-block:: python + + parameters_to_prune = [(model.mlp_1, "weight"), (model.mlp_2, "weight")] + + trainer = Trainer( + callbacks=[ + ModelPruning( + pruning_fn="l1_unstructured", + parameters_to_prune=parameters_to_prune, + amount=0.01, + use_global_unstructured=True, + ) + ] + ) + + When ``parameters_to_prune`` is ``None``, ``parameters_to_prune`` will contain all parameters from the model. + The user can override ``filter_parameters_to_prune`` to filter any ``nn.Module`` to be pruned. + + Args: + + pruning_fn: Function from torch.nn.utils.prune module or your own PyTorch ``BasePruningMethod`` subclass. + Can also be string e.g. `"l1_unstructured"`. See pytorch docs for more details. + + parameters_to_prune: List of tuples ``(nn.Module, "parameter_name_string")``. + + parameter_names: List of parameter names to be pruned from the nn.Module. + Can either be ``"weight"`` or ``"bias"``. + + use_global_unstructured: Whether to apply pruning globally on the model. + If ``parameters_to_prune`` is provided, global unstructured will be restricted on them. + + amount: Quantity of parameters to prune: + + - ``float``. Between 0.0 and 1.0. Represents the fraction of parameters to prune. + - ``int``. Represents the absolute number of parameters to prune. + - ``Callable``. For dynamic values. Will be called every epoch. Should return a value. + + apply_pruning: Whether to apply pruning. + + - ``bool``. Always apply it or not. + - ``Callable[[epoch], bool]``. For dynamic values. Will be called every epoch. + + make_pruning_permanent: Whether to remove all reparametrization pre-hooks and apply masks + when training ends or the model is saved. + + use_lottery_ticket_hypothesis: See `The lottery ticket hypothesis `_: + + - ``bool``. Whether to apply it or not. + - ``Callable[[epoch], bool]``. For dynamic values. Will be called every epoch. + + resample_parameters: Used with ``use_lottery_ticket_hypothesis``. If True, the model parameters will + be resampled, otherwise, the exact original parameters will be used. + + pruning_dim: If you are using a structured pruning method you need to specify the dimension. + + pruning_norm: If you are using ``ln_structured`` you need to specify the norm. + + verbose: Verbosity level. 0 to disable, 1 to log overall sparsity, 2 to log per-layer sparsity + + prune_on_train_epoch_end: whether to apply pruning at the end of the training epoch. + If this is ``False``, then the check runs at the end of the validation epoch. + + Raises: + MisconfigurationException: + If ``parameter_names`` is neither ``"weight"`` nor ``"bias"``, + if the provided ``pruning_fn`` is not supported, + if ``pruning_dim`` is not provided when ``"unstructured"``, + if ``pruning_norm`` is not provided when ``"ln_structured"``, + if ``pruning_fn`` is neither ``str`` nor :class:`torch.nn.utils.prune.BasePruningMethod`, or + if ``amount`` is none of ``int``, ``float`` and ``Callable``. + """ + + self._use_global_unstructured = use_global_unstructured + self._parameters_to_prune = parameters_to_prune + self._use_lottery_ticket_hypothesis = use_lottery_ticket_hypothesis + self._resample_parameters = resample_parameters + self._prune_on_train_epoch_end = prune_on_train_epoch_end + self._parameter_names = parameter_names or self.PARAMETER_NAMES + self._global_kwargs: Dict[str, Any] = {} + self._original_layers: Optional[Dict[int, _LayerRef]] = None + self._pruning_method_name: Optional[str] = None + + for name in self._parameter_names: + if name not in self.PARAMETER_NAMES: + raise MisconfigurationException( + f"The provided `parameter_names` name: {name} isn't in {self.PARAMETER_NAMES}" + ) + + if isinstance(pruning_fn, str): + pruning_kwargs = {} + pruning_fn = pruning_fn.lower() + if pruning_fn not in _PYTORCH_PRUNING_FUNCTIONS: + raise MisconfigurationException( + f"The provided `pruning_fn` {pruning_fn} isn't available in PyTorch's" + f" built-in functions: {list(_PYTORCH_PRUNING_FUNCTIONS.keys())} " + ) + if pruning_fn.endswith("_structured"): + if pruning_dim is None: + raise MisconfigurationException( + "When requesting `structured` pruning, the `pruning_dim` should be provided." + ) + if pruning_fn == "ln_structured": + if pruning_norm is None: + raise MisconfigurationException( + "When requesting `ln_structured` pruning, the `pruning_norm` should be provided." + ) + pruning_kwargs["n"] = pruning_norm + pruning_kwargs["dim"] = pruning_dim + pruning_fn = self._create_pruning_fn(pruning_fn, **pruning_kwargs) + elif self._is_pruning_method(pruning_fn): + if not use_global_unstructured: + raise MisconfigurationException( + "PyTorch `BasePruningMethod` is currently only supported with `use_global_unstructured=True`." + ) + else: + raise MisconfigurationException( + f"`pruning_fn` is expected to be a str in {list(_PYTORCH_PRUNING_FUNCTIONS.keys())}" + f" or a PyTorch `BasePruningMethod`. Found: {pruning_fn}." + " HINT: if passing a `BasePruningMethod`, pass the the class, not an instance" + ) + + # need to ignore typing here since pytorch base class does not define the PRUNING_TYPE attribute + if use_global_unstructured and pruning_fn.PRUNING_TYPE != "unstructured": # type: ignore + raise MisconfigurationException( + 'Only the "unstructured" PRUNING_TYPE is supported with `use_global_unstructured=True`.' # type: ignore + f" Found method {pruning_fn} of type {pruning_fn.PRUNING_TYPE}. " + ) + + self.pruning_fn = pruning_fn + self._apply_pruning = apply_pruning + self._make_pruning_permanent = make_pruning_permanent + + if not (isinstance(amount, (int, float)) or callable(amount)): + raise MisconfigurationException( + "`amount` should be provided and be either an int, a float or Callable function." + ) + + self.amount = amount + + if verbose not in (0, 1, 2): + raise MisconfigurationException("`verbose` must be any of (0, 1, 2)") + + self._verbose = verbose + + def filter_parameters_to_prune(self, parameters_to_prune: _PARAM_LIST = ()) -> _PARAM_LIST: + """This function can be overridden to control which module to prune.""" + return parameters_to_prune + + def _create_pruning_fn(self, pruning_fn: str, **kwargs: Any) -> Union[Callable, pytorch_prune.BasePruningMethod]: + """This function takes `pruning_fn`, a function name. + + IF use_global_unstructured, pruning_fn will be resolved into its associated ``PyTorch BasePruningMethod`` ELSE, + pruning_fn will be resolved into its function counterpart from `torch.nn.utils.prune`. + """ + pruning_meth = ( + _PYTORCH_PRUNING_METHOD[pruning_fn] + if self._use_global_unstructured + else _PYTORCH_PRUNING_FUNCTIONS[pruning_fn] + ) + assert callable(pruning_meth), "Selected pruning method is not callable" + if self._use_global_unstructured: + self._global_kwargs = kwargs + # save the function __name__ now because partial does not include it + # and there are issues setting the attribute manually in ddp. + self._pruning_method_name = pruning_meth.__name__ + if self._use_global_unstructured: + return pruning_meth + return ModelPruning._wrap_pruning_fn(pruning_meth, **kwargs) + + @staticmethod + def _wrap_pruning_fn(pruning_fn: Callable, **kwargs: Any) -> Callable: + return partial(pruning_fn, **kwargs) + + def make_pruning_permanent(self, module: nn.Module) -> None: + """Removes pruning buffers from any pruned modules. + + Adapted from https://github.com/pytorch/pytorch/blob/1.7.1/torch/nn/utils/prune.py#L1176-L1180 + """ + for _, module in module.named_modules(): + for k in list(module._forward_pre_hooks): + hook = module._forward_pre_hooks[k] + if isinstance(hook, pytorch_prune.BasePruningMethod): + hook.remove(module) + del module._forward_pre_hooks[k] + + @staticmethod + def _copy_param(new: nn.Module, old: nn.Module, name: str) -> None: + dst = getattr(new, name) + src = getattr(old, name) + if dst is None or src is None or not isinstance(dst, torch.Tensor) or not isinstance(src, torch.Tensor): + return + dst.data = src.data.to(dst.device) + + def apply_lottery_ticket_hypothesis(self) -> None: + r""" + Lottery ticket hypothesis algorithm (see page 2 of the paper): + + 1. Randomly initialize a neural network :math:`f(x; \theta_0)` (where :math:`\theta_0 \sim \mathcal{D}_\theta`). + 2. Train the network for :math:`j` iterations, arriving at parameters :math:`\theta_j`. + 3. Prune :math:`p\%` of the parameters in :math:`\theta_j`, creating a mask :math:`m`. + 4. Reset the remaining parameters to their values in :math:`\theta_0`, creating the winning ticket :math:`f(x; m \odot \theta_0)`. + + This function implements the step 4. + + The ``resample_parameters`` argument can be used to reset the parameters with a new :math:`\theta_z \sim \mathcal{D}_\theta` + """ # noqa: E501 + assert self._original_layers is not None + for d in self._original_layers.values(): + copy = d["data"] + names = d["names"] + if self._resample_parameters and hasattr(copy, "reset_parameters") and callable(copy.reset_parameters): + copy = deepcopy(copy) # keep the original parameters + copy.reset_parameters() + for i, name in names: + new, new_name = self._parameters_to_prune[i] + self._copy_param(new, copy, name) + + def _apply_local_pruning(self, amount: float) -> None: + for module, name in self._parameters_to_prune: + self.pruning_fn(module, name=name, amount=amount) + + def _resolve_global_kwargs(self, amount: float) -> Dict[str, Any]: + self._global_kwargs["amount"] = amount + params = set(inspect.signature(self.pruning_fn).parameters) + params.discard("self") + return {k: v for k, v in self._global_kwargs.items() if k in params} + + def _apply_global_pruning(self, amount: float) -> None: + pytorch_prune.global_unstructured( + self._parameters_to_prune, pruning_method=self.pruning_fn, **self._resolve_global_kwargs(amount) + ) + + @staticmethod + def _get_pruned_stats(module: nn.Module, name: str) -> Tuple[int, int]: + attr = f"{name}_mask" + if not hasattr(module, attr): + return 0, 1 + mask = getattr(module, attr) + return (mask == 0).sum().item(), mask.numel() + + def apply_pruning(self, amount: Union[int, float]) -> None: + """Applies pruning to ``parameters_to_prune``.""" + if self._verbose: + prev_stats = [self._get_pruned_stats(m, n) for m, n in self._parameters_to_prune] + + if self._use_global_unstructured: + self._apply_global_pruning(amount) + else: + self._apply_local_pruning(amount) + + if self._verbose: + curr_stats = [self._get_pruned_stats(m, n) for m, n in self._parameters_to_prune] + self._log_sparsity_stats(prev_stats, curr_stats, amount=amount) + + @rank_zero_only + def _log_sparsity_stats( + self, prev: List[Tuple[int, int]], curr: List[Tuple[int, int]], amount: Union[int, float] = 0 + ) -> None: + total_params = sum(p.numel() for layer, _ in self._parameters_to_prune for p in layer.parameters()) + prev_total_zeros = sum(zeros for zeros, _ in prev) + curr_total_zeros = sum(zeros for zeros, _ in curr) + log.info( + f"Applied `{self._pruning_method_name}`. Pruned:" + f" {prev_total_zeros}/{total_params} ({prev_total_zeros / total_params:.2%}) ->" + f" {curr_total_zeros}/{total_params} ({curr_total_zeros / total_params:.2%})" + ) + if self._verbose == 2: + for i, (module, name) in enumerate(self._parameters_to_prune): + prev_mask_zeros, prev_mask_size = prev[i] + curr_mask_zeros, curr_mask_size = curr[i] + log.info( + f"Applied `{self._pruning_method_name}` to `{module!r}.{name}` with amount={amount}. Pruned:" + f" {prev_mask_zeros} ({prev_mask_zeros / prev_mask_size:.2%}) ->" + f" {curr_mask_zeros} ({curr_mask_zeros / curr_mask_size:.2%})" + ) + + def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: + parameters_to_prune = self.sanitize_parameters_to_prune( + pl_module, self._parameters_to_prune, parameter_names=self._parameter_names + ) + + self._parameters_to_prune = self.filter_parameters_to_prune(parameters_to_prune) + + if self._use_lottery_ticket_hypothesis: + # group modules by id. Each entry has a copy of the initial data + # and a list of the associated parameter names to prune + self._original_layers = {} + for i, (module, name) in enumerate(self._parameters_to_prune): + id_ = id(module) + self._original_layers.setdefault(id_, _LayerRef(data=deepcopy(module), names=[])) + self._original_layers[id_]["names"].append((i, name)) + + def _run_pruning(self, current_epoch: int) -> None: + prune = self._apply_pruning(current_epoch) if callable(self._apply_pruning) else self._apply_pruning + amount = self.amount(current_epoch) if callable(self.amount) else self.amount + if not prune or not amount: + return + self.apply_pruning(amount) + + if ( + self._use_lottery_ticket_hypothesis(current_epoch) + if callable(self._use_lottery_ticket_hypothesis) + else self._use_lottery_ticket_hypothesis + ): + self.apply_lottery_ticket_hypothesis() + + def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: LightningModule) -> None: + if self._prune_on_train_epoch_end: + rank_zero_debug("`ModelPruning.on_train_epoch_end`. Applying pruning") + self._run_pruning(pl_module.current_epoch) + + def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if not trainer.sanity_checking and not self._prune_on_train_epoch_end: + rank_zero_debug("`ModelPruning.on_validation_epoch_end`. Applying pruning") + self._run_pruning(pl_module.current_epoch) + + def on_train_end(self, trainer: "pl.Trainer", pl_module: LightningModule) -> None: + if self._make_pruning_permanent: + rank_zero_debug("`ModelPruning.on_train_end`. Pruning is made permanent for this checkpoint") + self.make_pruning_permanent(pl_module) + + def _make_pruning_permanent_on_state_dict(self, pl_module: LightningModule) -> Dict[str, Any]: + state_dict = pl_module.state_dict() + + # find the mask and the original weights. + map_pruned_params = {k.replace("_mask", "") for k in state_dict.keys() if k.endswith("_mask")} + for tensor_name in map_pruned_params: + orig = state_dict.pop(tensor_name + "_orig") + mask = state_dict.pop(tensor_name + "_mask") + # make weights permanent + state_dict[tensor_name] = mask.to(dtype=orig.dtype) * orig + + def move_to_cpu(tensor: torch.Tensor) -> torch.Tensor: + # each tensor and move them on cpu + return tensor.cpu() + + return apply_to_collection(state_dict, torch.Tensor, move_to_cpu) + + def on_save_checkpoint( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] + ) -> Optional[dict]: + if self._make_pruning_permanent: + rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint") + # manually prune the weights so training can keep going with the same buffers + checkpoint["state_dict"] = self._make_pruning_permanent_on_state_dict(pl_module) + + @staticmethod + def sanitize_parameters_to_prune( + pl_module: LightningModule, parameters_to_prune: _PARAM_LIST = (), parameter_names: Sequence[str] = () + ) -> _PARAM_LIST: + """This function is responsible of sanitizing ``parameters_to_prune`` and ``parameter_names``. If + ``parameters_to_prune is None``, it will be generated with all parameters of the model. + + Raises: + MisconfigurationException: + If ``parameters_to_prune`` doesn't exist in the model, or + if ``parameters_to_prune`` is neither a list nor a tuple. + """ + parameters = parameter_names or ModelPruning.PARAMETER_NAMES + + current_modules = [m for m in pl_module.modules() if not isinstance(m, _MODULE_CONTAINERS)] + + if not parameters_to_prune: + parameters_to_prune = [ + (m, p) for p in parameters for m in current_modules if getattr(m, p, None) is not None + ] + elif ( + isinstance(parameters_to_prune, (list, tuple)) + and len(parameters_to_prune) > 0 + and all(len(p) == 2 for p in parameters_to_prune) + and all(isinstance(a, nn.Module) and isinstance(b, str) for a, b in parameters_to_prune) + ): + missing_modules, missing_parameters = [], [] + for module, name in parameters_to_prune: + if module not in current_modules: + missing_modules.append(module) + continue + if not hasattr(module, name): + missing_parameters.append(name) + + if missing_modules or missing_parameters: + raise MisconfigurationException( + "Some provided `parameters_to_tune` don't exist in the model." + f" Found missing modules: {missing_modules} and missing parameters: {missing_parameters}" + ) + else: + raise MisconfigurationException( + "The provided `parameters_to_prune` should either be list of tuple" + " with 2 elements: (nn.Module, parameter_name_to_prune) or None" + ) + + return parameters_to_prune + + @staticmethod + def _is_pruning_method(method: Any) -> bool: + if not inspect.isclass(method): + return False + return issubclass(method, pytorch_prune.BasePruningMethod) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/quantization.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/quantization.py new file mode 100644 index 0000000000000000000000000000000000000000..2ae1262eb25d99aea399452140335bf75f923c27 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/quantization.py @@ -0,0 +1,344 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +r""" +Quantization +^^^^^^^^^^^^ + +""" +import copy +import functools +from typing import Any, Callable, Dict, Optional, Sequence, Union + +import torch +from torch import Tensor +from torch.quantization import FakeQuantizeBase + +import pytorch_lightning as pl +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, _TORCH_GREATER_EQUAL_1_11 +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +if _TORCH_GREATER_EQUAL_1_10: + from torch.ao.quantization.qconfig import QConfig +else: + from torch.quantization import QConfig + +if _TORCH_GREATER_EQUAL_1_11: + from torch.ao.quantization import fuse_modules_qat as fuse_modules +else: + from torch.quantization import fuse_modules + + +def wrap_qat_forward_context( + quant_cb, model: "pl.LightningModule", func: Callable, trigger_condition: Optional[Union[Callable, int]] = None +) -> Callable: + """Decorator to wrap forward path as it is needed to quantize inputs and dequantize outputs for in/out + compatibility Moreover this version has the (de)quantization conditional as it may not be needed for the + training all the time.""" + # todo: consider using registering hook before/after forward + @functools.wraps(func) + def wrapper(data) -> Any: + _is_func_true = isinstance(trigger_condition, Callable) and trigger_condition(model.trainer) + _is_count_true = isinstance(trigger_condition, int) and quant_cb._forward_calls < trigger_condition + _quant_run = trigger_condition is None or _is_func_true or _is_count_true + # apply custom trigger + if _quant_run: + quant_cb._forward_calls += 1 + data = model.quant(data) + data = func(data) + # apply custom trigger + if _quant_run: + data = model.dequant(data) + return data + + return wrapper + + +def wrap_quantize_forward_context(model: "pl.LightningModule", func: Callable) -> Callable: + """Decorator to wrap forward path as it is needed to quantize inputs and dequantize outputs for in/out + compatibility.""" + # todo: consider using registering hook before/after forward + @functools.wraps(func) + def wrapper(data) -> Any: + data = model.quant(data) + data = func(data) + data = model.dequant(data) + return data + + return wrapper + + +def _recursive_hasattr(obj: Any, attribs: str, state: bool = True) -> bool: + """recursive check if model has some layers denoted with '.'.""" + if "." in attribs: + attrib, attribs = attribs.split(".", 1) + if hasattr(obj, attrib): + return _recursive_hasattr(getattr(obj, attrib), attribs, state) + return False + return state and hasattr(obj, attribs) + + +class QuantizationAwareTraining(Callback): + """Quantization allows speeding up inference and decreasing memory requirements by performing computations and + storing tensors at lower bitwidths (such as INT8 or FLOAT16) than floating point precision. We use native + PyTorch API so for more information see `PyTorch Quantization`_. + + .. warning:: ``QuantizationAwareTraining`` is in beta and subject to change. + + The ``LightningModule`` is prepared for QAT training in the ``on_fit_start`` hook. Checkpoints saved during training + include already collected stats to perform the Quantization conversion, but it doesn't contain the quantized or + fused model/layers. The quantization is performed in the ``on_fit_end`` hook so the model needs to be saved after + training finishes if quantization is desired. + + Args: + + qconfig: quantization configuration: + + - 'fbgemm' for server inference. + - 'qnnpack' for mobile inference. + - a custom `torch.quantization.QConfig`_. + + observer_type: allows switching between ``MovingAverageMinMaxObserver`` as "average" (default) + and ``HistogramObserver`` as "histogram" which is more computationally expensive. + + collect_quantization: count or custom function to collect quantization statistics: + + - ``None`` (default). The quantization observer is called in each module forward + (useful for collecting extended statistic when using image/data augmentation). + - ``int``. Use to set a fixed number of calls, starting from the beginning. + - ``Callable``. Custom function with single trainer argument. + See this example to trigger only the last epoch: + + .. code-block:: python + + def custom_trigger_last(trainer): + return trainer.current_epoch == (trainer.max_epochs - 1) + + + QuantizationAwareTraining(collect_quantization=custom_trigger_last) + + modules_to_fuse: allows you fuse a few layers together as shown in + `diagram `_ + to find which layer types can be fused, check https://github.com/pytorch/pytorch/pull/43286. + + input_compatible: preserve quant/dequant layers. This allows to feat any input as to the original model, + but break compatibility to torchscript and export with ``torch.save``. + + quantize_on_fit_end: perform the quantization in `on_fit_end`. + Note that once converted, the model cannot be put in training mode again. + + observer_enabled_stages: allow fake-quantization modules' observers to do calibration during provided stages: + + - ``'train'``: the observers can do calibration during training. + - ``'validate'``: the observers can do calibration during validating. + Note that we don't disable observers during the sanity check as the model hasn't been calibrated with + training data yet. After the sanity check, the fake-quantization modules are restored to initial states. + - ``'test'``: the observers can do calibration during testing. + - ``'predict'``: the observers can do calibration during predicting. + + Note that we only handle observers belonging to fake-quantization modules. When ``qconfig`` is a ``str`` and + ``observer_type`` is ``'histogram'``, the observers won't belong to any fake-quantization modules and will + not be controlled by the callback. + + .. _PyTorch Quantization: https://pytorch.org/docs/stable/quantization.html#quantization-aware-training + .. _torch.quantization.QConfig: https://pytorch.org/docs/stable/torch.quantization.html#torch.quantization.QConfig + """ + + OBSERVER_TYPES = ("histogram", "average") + OBSERVER_STAGES = ("train", "validate", "test", "predict") + + def __init__( + self, + qconfig: Union[str, QConfig] = "fbgemm", + observer_type: str = "average", + collect_quantization: Optional[Union[int, Callable]] = None, + modules_to_fuse: Optional[Sequence] = None, + input_compatible: bool = True, + quantize_on_fit_end: bool = True, + observer_enabled_stages: Sequence[str] = ("train",), + ) -> None: + _valid_qconf_str = isinstance(qconfig, str) and qconfig in torch.backends.quantized.supported_engines + if not isinstance(qconfig, QConfig) and not _valid_qconf_str: + raise MisconfigurationException( + f"Unsupported qconfig: f{qconfig}.\nTry one of defaults: {torch.backends.quantized.supported_engines}" + ) + self._qconfig = qconfig + + if observer_type not in self.OBSERVER_TYPES: + raise MisconfigurationException( + f'Unsupported observer type "{observer_type}", allowed are {self.OBSERVER_TYPES}.' + ) + self._observer_type = observer_type + + if collect_quantization is not None and not isinstance(collect_quantization, (int, Callable)): + raise MisconfigurationException( + f'Unsupported `collect_quantization` "{collect_quantization}", allowed are `int` or `Callable`.' + ) + self._collect_quantization = collect_quantization + + self._modules_to_fuse = modules_to_fuse + self._input_compatible = input_compatible + self._convert_on_fit_end = quantize_on_fit_end + + observer_enabled_stages = set(observer_enabled_stages) + unsupported_stages = observer_enabled_stages - set(self.OBSERVER_STAGES) + if unsupported_stages: + raise MisconfigurationException( + f'Unsupported stages "{tuple(sorted(unsupported_stages))}", allowed are {self.OBSERVER_STAGES}.' + ) + self._observer_disabled_stages = set(self.OBSERVER_STAGES) - observer_enabled_stages + + self._forward_calls = 0 + self._fake_quant_to_initial_state_dict = {} + self._last_fake_quant_to_observer_enabled = {} + self._module_prepared = False + + def _check_feasible_fuse(self, model: "pl.LightningModule") -> bool: + if not self._modules_to_fuse: + return False + for group in self._modules_to_fuse: + if not all(_recursive_hasattr(model, m) for m in group): + raise MisconfigurationException( + f"You have requested to fuse {group} but one or more of them is not your model attributes" + ) + return True + + def _collect_observer_enabled(self) -> Dict[FakeQuantizeBase, Tensor]: + return { + fake_quant: fake_quant.observer_enabled.clone() for fake_quant in self._fake_quant_to_initial_state_dict + } + + def _disable_observer(self, pl_module: "pl.LightningModule") -> None: + self._last_fake_quant_to_observer_enabled = self._collect_observer_enabled() + pl_module.apply(torch.quantization.disable_observer) + + def _restore_last_observer_enabled(self) -> None: + for fake_quant, observer_enabled in self._last_fake_quant_to_observer_enabled.items(): + fake_quant.observer_enabled.copy_(observer_enabled) + + def _prepare_model(self, model: torch.nn.Module) -> None: + if self._module_prepared: + return + # QuantStub converts tensors from floating point to quantized + model.quant = torch.quantization.QuantStub() + # DeQuantStub converts tensors from quantized to floating point + model.dequant = torch.quantization.DeQuantStub() + # manually specify where tensors will be converted from quantized + # to floating point in the quantized model + self.__module_forward = model.forward + model.forward = wrap_qat_forward_context( + quant_cb=self, model=model, func=model.forward, trigger_condition=self._collect_quantization + ) + + # attach a global qconfig, which contains information about what kind + # of observers to attach. Use 'fbgemm' for server inference + if isinstance(self._qconfig, str): + if self._observer_type == "histogram": + model.qconfig = torch.quantization.get_default_qconfig(self._qconfig) + elif self._observer_type == "average": + # version=None corresponds to using FakeQuantize rather than + # FusedMovingAvgObsFakeQuantize which was introduced in PT1.10 + # details in https://github.com/pytorch/pytorch/issues/64564 + extra_kwargs = dict(version=None) if _TORCH_GREATER_EQUAL_1_10 else {} + model.qconfig = torch.quantization.get_default_qat_qconfig(self._qconfig, **extra_kwargs) + + elif isinstance(self._qconfig, QConfig): + model.qconfig = self._qconfig + + if self._check_feasible_fuse(model): + fuse_modules(model, self._modules_to_fuse, inplace=True) + + # Prepare the model for QAT. This inserts observers and fake_quants in + # the model that will observe weight and activation tensors during calibration. + torch.quantization.prepare_qat(model, inplace=True) + + fake_quants = tuple(module for module in model.modules() if isinstance(module, FakeQuantizeBase)) + self._fake_quant_to_initial_state_dict = { + fake_quant: copy.deepcopy(fake_quant.state_dict()) for fake_quant in fake_quants + } + self._module_prepared = True + + def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): + self._prepare_model(pl_module) + + def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if not self._convert_on_fit_end: + pl_module.forward = self.__module_forward + return + pl_module.eval() + # Convert the observed model to a quantized model. This does several things: + # quantizes the weights, computes and stores the scale and bias value to be + # used with each activation tensor, fuses modules where appropriate, + # and replaces key operators with quantized implementations. + torch.quantization.convert(pl_module, inplace=True) + # check we shall preserve wrapper + if self._input_compatible: + pl_module.forward = wrap_quantize_forward_context(model=pl_module, func=self.__module_forward) + else: + pl_module.forward = self.__module_forward + + def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if "train" in self._observer_disabled_stages: + self._disable_observer(pl_module) + + def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if "train" in self._observer_disabled_stages: + self._restore_last_observer_enabled() + + def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if "validate" in self._observer_disabled_stages and not trainer.sanity_checking: + # ``torch.quantization.MovingAveragePerChannelMinMaxObserver`` and ``torch.quantization.HistogramObserver`` + # need to see at least one batch to infer the shapes of quantization ``scale`` and ``zero_point``. So we + # don't disable observers during the sanity check so that they can infer the shapes of quantization + # parameters with validation data. + self._disable_observer(pl_module) + + def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if "validate" in self._observer_disabled_stages: + if trainer.sanity_checking: + for fake_quant, state_dict in self._fake_quant_to_initial_state_dict.items(): + fake_quant.load_state_dict(state_dict) + else: + self._restore_last_observer_enabled() + + def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if "test" in self._observer_disabled_stages: + self._disable_observer(pl_module) + + def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if "test" in self._observer_disabled_stages: + self._restore_last_observer_enabled() + + def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if "predict" in self._observer_disabled_stages: + self._disable_observer(pl_module) + + def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if "predict" in self._observer_disabled_stages: + self._restore_last_observer_enabled() + + def state_dict(self) -> Dict[str, Any]: + keys = {"_qconfig", "_observer_type", "_collect_quantization", "_modules_to_fuse", "_input_compatible"} + return {n: getattr(self, n) for n in keys} + + def _load_before_model(self, model: torch.nn.Module, state_dict: Dict[str, Any]) -> None: + """Special hook that gets called by the CheckpointConnector *before* the model gets loaded. + + This hook replaces the :meth:`on_load_checkpoint` and :meth:`load_state_dict` callback methods which get called + after the model has already loaded the weights. For quantization, we need to convert the model first before that + happens, assuming the previous training used quantization. + """ + for k, v in state_dict.items(): + setattr(self, k, v) + self._prepare_model(model) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/rich_model_summary.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/rich_model_summary.py new file mode 100644 index 0000000000000000000000000000000000000000..148de6275950e12e0b6bc570aec2dd9aba56aaec --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/rich_model_summary.py @@ -0,0 +1,109 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Tuple + +from pytorch_lightning.callbacks import ModelSummary +from pytorch_lightning.utilities.imports import _RICH_AVAILABLE +from pytorch_lightning.utilities.model_summary import get_human_readable_count + +if _RICH_AVAILABLE: + from rich import get_console + from rich.table import Table + + +class RichModelSummary(ModelSummary): + r""" + Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule` + with `rich text formatting `_. + + Install it with pip: + + .. code-block:: bash + + pip install rich + + .. code-block:: python + + from pytorch_lightning import Trainer + from pytorch_lightning.callbacks import RichModelSummary + + trainer = Trainer(callbacks=RichModelSummary()) + + You could also enable ``RichModelSummary`` using the :class:`~pytorch_lightning.callbacks.RichProgressBar` + + .. code-block:: python + + from pytorch_lightning import Trainer + from pytorch_lightning.callbacks import RichProgressBar + + trainer = Trainer(callbacks=RichProgressBar()) + + Args: + max_depth: The maximum depth of layer nesting that the summary will include. A value of 0 turns the + layer summary off. + + Raises: + ModuleNotFoundError: + If required `rich` package is not installed on the device. + """ + + def __init__(self, max_depth: int = 1) -> None: + if not _RICH_AVAILABLE: + raise ModuleNotFoundError( + "`RichModelSummary` requires `rich` to be installed. Install it by running `pip install -U rich`." + ) + super().__init__(max_depth) + + @staticmethod + def summarize( + summary_data: List[Tuple[str, List[str]]], + total_parameters: int, + trainable_parameters: int, + model_size: float, + ) -> None: + + console = get_console() + + table = Table(header_style="bold magenta") + table.add_column(" ", style="dim") + table.add_column("Name", justify="left", no_wrap=True) + table.add_column("Type") + table.add_column("Params", justify="right") + + column_names = list(zip(*summary_data))[0] + + for column_name in ["In sizes", "Out sizes"]: + if column_name in column_names: + table.add_column(column_name, justify="right", style="white") + + rows = list(zip(*(arr[1] for arr in summary_data))) + for row in rows: + table.add_row(*row) + + console.print(table) + + parameters = [] + for param in [trainable_parameters, total_parameters - trainable_parameters, total_parameters, model_size]: + parameters.append("{:<{}}".format(get_human_readable_count(int(param)), 10)) + + grid = Table.grid(expand=True) + grid.add_column() + grid.add_column() + + grid.add_row(f"[bold]Trainable params[/]: {parameters[0]}") + grid.add_row(f"[bold]Non-trainable params[/]: {parameters[1]}") + grid.add_row(f"[bold]Total params[/]: {parameters[2]}") + grid.add_row(f"[bold]Total estimated model params size (MB)[/]: {parameters[3]}") + + console.print(grid) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/stochastic_weight_avg.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/stochastic_weight_avg.py new file mode 100644 index 0000000000000000000000000000000000000000..ad9e8b8fc396b0878719692b18d298fb5627b087 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -0,0 +1,280 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +r""" +Stochastic Weight Averaging Callback +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +""" +from copy import deepcopy +from typing import Callable, List, Optional, Union + +import torch +from torch import nn +from torch.optim.swa_utils import SWALR + +import pytorch_lightning as pl +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn +from pytorch_lightning.utilities.types import LRSchedulerConfig + +_AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.LongTensor], torch.FloatTensor] + + +class StochasticWeightAveraging(Callback): + def __init__( + self, + swa_epoch_start: Union[int, float] = 0.8, + swa_lrs: Optional[Union[float, List[float]]] = None, + annealing_epochs: int = 10, + annealing_strategy: str = "cos", + avg_fn: Optional[_AVG_FN] = None, + device: Optional[Union[torch.device, str]] = torch.device("cpu"), + ): + r""" + + Implements the Stochastic Weight Averaging (SWA) Callback to average a model. + + Stochastic Weight Averaging was proposed in ``Averaging Weights Leads to + Wider Optima and Better Generalization`` by Pavel Izmailov, Dmitrii + Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson + (UAI 2018). + + This documentation is highly inspired by PyTorch's work on SWA. + The callback arguments follow the scheme defined in PyTorch's ``swa_utils`` package. + + For a SWA explanation, please take a look + `here `_. + + .. warning:: ``StochasticWeightAveraging`` is in beta and subject to change. + + .. warning:: ``StochasticWeightAveraging`` is currently not supported for multiple optimizers/schedulers. + + .. warning:: ``StochasticWeightAveraging`` is currently only supported on every epoch. + + See also how to :ref:`enable it directly on the Trainer ` + + Arguments: + + swa_epoch_start: If provided as int, the procedure will start from + the ``swa_epoch_start``-th epoch. If provided as float between 0 and 1, + the procedure will start from ``int(swa_epoch_start * max_epochs)`` epoch + + swa_lrs: The SWA learning rate to use: + + - ``None``. Use the current learning rate of the optimizer at the time the SWA procedure starts. + - ``float``. Use this value for all parameter groups of the optimizer. + - ``List[float]``. A list values for each parameter group of the optimizer. + + annealing_epochs: number of epochs in the annealing phase (default: 10) + + annealing_strategy: Specifies the annealing strategy (default: "cos"): + + - ``"cos"``. For cosine annealing. + - ``"linear"`` For linear annealing + + avg_fn: the averaging function used to update the parameters; + the function must take in the current value of the + :class:`AveragedModel` parameter, the current value of :attr:`model` + parameter and the number of models already averaged; if None, + equally weighted average is used (default: ``None``) + + device: if provided, the averaged model will be stored on the ``device``. + When None is provided, it will infer the `device` from ``pl_module``. + (default: ``"cpu"``) + + """ + + err_msg = "swa_epoch_start should be a >0 integer or a float between 0 and 1." + if isinstance(swa_epoch_start, int) and swa_epoch_start < 1: + raise MisconfigurationException(err_msg) + if isinstance(swa_epoch_start, float) and not (0 <= swa_epoch_start <= 1): + raise MisconfigurationException(err_msg) + + wrong_type = not isinstance(swa_lrs, (float, list)) + wrong_float = isinstance(swa_lrs, float) and swa_lrs <= 0 + wrong_list = isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs) + if swa_lrs is not None and (wrong_type or wrong_float or wrong_list): + raise MisconfigurationException( + "The `swa_lrs` should be `None`, a positive float, or a list of positive floats" + ) + + if avg_fn is not None and not isinstance(avg_fn, Callable): + raise MisconfigurationException("The `avg_fn` should be callable.") + + if device is not None and not isinstance(device, (torch.device, str)): + raise MisconfigurationException(f"device is expected to be a torch.device or a str. Found {device}") + + self._swa_epoch_start = swa_epoch_start + self._swa_lrs = swa_lrs + self._annealing_epochs = annealing_epochs + self._annealing_strategy = annealing_strategy + self._avg_fn = avg_fn or self.avg_fn + self._device = device + self._model_contains_batch_norm = None + self._average_model = None + + @property + def swa_start(self) -> int: + return max(self._swa_epoch_start - 1, 0) # 0-based + + @property + def swa_end(self) -> int: + return self._max_epochs - 1 # 0-based + + @staticmethod + def pl_module_contains_batch_norm(pl_module: "pl.LightningModule"): + return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules()) + + def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: + # copy the model before moving it to accelerator device. + with pl_module._prevent_trainer_and_dataloaders_deepcopy(): + self._average_model = deepcopy(pl_module) + + def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): + if len(trainer.optimizers) != 1: + raise MisconfigurationException("SWA currently works with 1 `optimizer`.") + + if len(trainer.lr_scheduler_configs) > 1: + raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.") + + if isinstance(self._swa_epoch_start, float): + self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start) + + self._model_contains_batch_norm = self.pl_module_contains_batch_norm(pl_module) + + self._max_epochs = trainer.max_epochs + if self._model_contains_batch_norm: + # virtually increase max_epochs to perform batch norm update on latest epoch. + trainer.fit_loop.max_epochs += 1 + + def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): + if trainer.current_epoch == self.swa_start: + # move average model to request device. + self._average_model = self._average_model.to(self._device or pl_module.device) + + optimizer = trainer.optimizers[0] + if self._swa_lrs is None: + self._swa_lrs = [param_group["lr"] for param_group in optimizer.param_groups] + if isinstance(self._swa_lrs, float): + self._swa_lrs = [self._swa_lrs] * len(optimizer.param_groups) + + for lr, group in zip(self._swa_lrs, optimizer.param_groups): + group["initial_lr"] = lr + + self._swa_scheduler = SWALR( + optimizer, + swa_lr=self._swa_lrs, + anneal_epochs=self._annealing_epochs, + anneal_strategy=self._annealing_strategy, + last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1, + ) + # We assert that there is only one optimizer on fit start, so know opt_idx is always 0 + default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler, opt_idx=0) + assert default_scheduler_cfg.interval == "epoch" and default_scheduler_cfg.frequency == 1 + + if trainer.lr_scheduler_configs: + scheduler_cfg = trainer.lr_scheduler_configs[0] + if scheduler_cfg.interval != "epoch" or scheduler_cfg.frequency != 1: + rank_zero_warn(f"SWA is currently only supported every epoch. Found {scheduler_cfg}") + rank_zero_info( + f"Swapping scheduler `{scheduler_cfg.scheduler.__class__.__name__}`" + f" for `{self._swa_scheduler.__class__.__name__}`" + ) + trainer.lr_scheduler_configs[0] = default_scheduler_cfg + else: + trainer.lr_scheduler_configs.append(default_scheduler_cfg) + + self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device) + + if self.swa_start <= trainer.current_epoch <= self.swa_end: + self.update_parameters(self._average_model, pl_module, self.n_averaged, self._avg_fn) + + # Note: No > here in case the callback is saved with the model and training continues + if trainer.current_epoch == self.swa_end + 1: + + # Transfer weights from average model to pl_module + self.transfer_weights(self._average_model, pl_module) + + # Reset BatchNorm for update + self.reset_batch_norm_and_save_state(pl_module) + + # There is no need to perform either backward or optimizer.step as we are + # performing only one pass over the train data-loader to compute activation statistics + # Therefore, we will virtually increase `num_training_batches` by 1 and skip backward. + trainer.num_training_batches += 1 + trainer.fit_loop._skip_backward = True + self._accumulate_grad_batches = trainer.accumulate_grad_batches + + trainer.accumulate_grad_batches = trainer.num_training_batches + + def on_train_epoch_end(self, trainer: "pl.Trainer", *args): + trainer.fit_loop._skip_backward = False + + def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): + # the trainer increases the current epoch before this hook is called + if self._model_contains_batch_norm and trainer.current_epoch - 1 == self.swa_end + 1: + # BatchNorm epoch update. Reset state + trainer.accumulate_grad_batches = self._accumulate_grad_batches + trainer.num_training_batches -= 1 + trainer.fit_loop.max_epochs -= 1 + self.reset_momenta() + elif trainer.current_epoch - 1 == self.swa_end: + # Last SWA epoch. Transfer weights from average model to pl_module + self.transfer_weights(self._average_model, pl_module) + + @staticmethod + def transfer_weights(src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule"): + for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()): + dst_param.detach().copy_(src_param.to(dst_param.device)) + + def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule"): + """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154.""" + self.momenta = {} + for module in pl_module.modules(): + if not isinstance(module, nn.modules.batchnorm._BatchNorm): + continue + module.running_mean = torch.zeros_like( + module.running_mean, device=pl_module.device, dtype=module.running_mean.dtype + ) + module.running_var = torch.ones_like( + module.running_var, device=pl_module.device, dtype=module.running_var.dtype + ) + self.momenta[module] = module.momentum + module.momentum = None + module.num_batches_tracked *= 0 + + def reset_momenta(self): + """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165.""" + for bn_module in self.momenta: + bn_module.momentum = self.momenta[bn_module] + + @staticmethod + def update_parameters( + average_model: "pl.LightningModule", model: "pl.LightningModule", n_averaged: torch.LongTensor, avg_fn: _AVG_FN + ): + """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112.""" + for p_swa, p_model in zip(average_model.parameters(), model.parameters()): + device = p_swa.device + p_swa_ = p_swa.detach() + p_model_ = p_model.detach().to(device) + src = p_model_ if n_averaged == 0 else avg_fn(p_swa_, p_model_, n_averaged.to(device)) + p_swa_.copy_(src) + n_averaged += 1 + + @staticmethod + def avg_fn( + averaged_model_parameter: torch.Tensor, model_parameter: torch.Tensor, num_averaged: torch.LongTensor + ) -> torch.FloatTensor: + """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97.""" + return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/timer.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/timer.py new file mode 100644 index 0000000000000000000000000000000000000000..92e1b13d6b1504183bdf89f3def238be812b740d --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/timer.py @@ -0,0 +1,176 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +r""" +Timer +^^^^^ +""" +import logging +import time +from datetime import timedelta +from typing import Any, Dict, Optional, Union + +import pytorch_lightning as pl +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.utilities import LightningEnum +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.rank_zero import rank_zero_info + +log = logging.getLogger(__name__) + + +class Interval(LightningEnum): + step = "step" + epoch = "epoch" + + +class Timer(Callback): + """The Timer callback tracks the time spent in the training, validation, and test loops and interrupts the + Trainer if the given time limit for the training loop is reached. + + Args: + duration: A string in the format DD:HH:MM:SS (days, hours, minutes seconds), or a :class:`datetime.timedelta`, + or a dict containing key-value compatible with :class:`~datetime.timedelta`. + interval: Determines if the interruption happens on epoch level or mid-epoch. + Can be either ``"epoch"`` or ``"step"``. + verbose: Set this to ``False`` to suppress logging messages. + + Raises: + MisconfigurationException: + If ``interval`` is not one of the supported choices. + + Example:: + from pytorch_lightning import Trainer + from pytorch_lightning.callbacks import Timer + + # stop training after 12 hours + timer = Timer(duration="00:12:00:00") + + # or provide a datetime.timedelta + from datetime import timedelta + timer = Timer(duration=timedelta(weeks=1)) + + # or provide a dictionary + timer = Timer(duration=dict(weeks=4, days=2)) + + # force training to stop after given time limit + trainer = Trainer(callbacks=[timer]) + + # query training/validation/test time (in seconds) + timer.time_elapsed("train") + timer.start_time("validate") + timer.end_time("test") + """ + + def __init__( + self, + duration: Optional[Union[str, timedelta, Dict[str, int]]] = None, + interval: str = Interval.step, + verbose: bool = True, + ) -> None: + super().__init__() + if isinstance(duration, str): + dhms = duration.strip().split(":") + dhms = [int(i) for i in dhms] + duration = timedelta(days=dhms[0], hours=dhms[1], minutes=dhms[2], seconds=dhms[3]) + if isinstance(duration, dict): + duration = timedelta(**duration) + if interval not in set(Interval): + raise MisconfigurationException( + f"Unsupported parameter value `Timer(interval={interval})`. Possible choices are:" + f" {', '.join(set(Interval))}" + ) + self._duration = duration.total_seconds() if duration is not None else None + self._interval = interval + self._verbose = verbose + self._start_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage} + self._end_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage} + self._offset = 0 + + def start_time(self, stage: str = RunningStage.TRAINING) -> Optional[float]: + """Return the start time of a particular stage (in seconds)""" + stage = RunningStage(stage) + return self._start_time[stage] + + def end_time(self, stage: str = RunningStage.TRAINING) -> Optional[float]: + """Return the end time of a particular stage (in seconds)""" + stage = RunningStage(stage) + return self._end_time[stage] + + def time_elapsed(self, stage: str = RunningStage.TRAINING) -> float: + """Return the time elapsed for a particular stage (in seconds)""" + start = self.start_time(stage) + end = self.end_time(stage) + offset = self._offset if stage == RunningStage.TRAINING else 0 + if start is None: + return offset + if end is None: + return time.monotonic() - start + offset + return end - start + offset + + def time_remaining(self, stage: str = RunningStage.TRAINING) -> Optional[float]: + """Return the time remaining for a particular stage (in seconds)""" + if self._duration is not None: + return self._duration - self.time_elapsed(stage) + + def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self._start_time[RunningStage.TRAINING] = time.monotonic() + + def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self._end_time[RunningStage.TRAINING] = time.monotonic() + + def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self._start_time[RunningStage.VALIDATING] = time.monotonic() + + def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self._end_time[RunningStage.VALIDATING] = time.monotonic() + + def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self._start_time[RunningStage.TESTING] = time.monotonic() + + def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self._end_time[RunningStage.TESTING] = time.monotonic() + + def on_fit_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: + # this checks the time after the state is reloaded, regardless of the interval. + # this is necessary in case we load a state whose timer is already depleted + if self._duration is None: + return + self._check_time_remaining(trainer) + + def on_train_batch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: + if self._interval != Interval.step or self._duration is None: + return + self._check_time_remaining(trainer) + + def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: + if self._interval != Interval.epoch or self._duration is None: + return + self._check_time_remaining(trainer) + + def state_dict(self) -> Dict[str, Any]: + return {"time_elapsed": {stage.value: self.time_elapsed(stage) for stage in list(RunningStage)}} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + time_elapsed = state_dict.get("time_elapsed", {}) + self._offset = time_elapsed.get(RunningStage.TRAINING.value, 0) + + def _check_time_remaining(self, trainer: "pl.Trainer") -> None: + assert self._duration is not None + should_stop = self.time_elapsed() >= self._duration + should_stop = trainer.strategy.broadcast(should_stop) + trainer.should_stop = trainer.should_stop or should_stop + if should_stop and self._verbose: + elapsed = timedelta(seconds=int(self.time_elapsed(RunningStage.TRAINING))) + rank_zero_info(f"Time limit reached. Elapsed time is {elapsed}. Signaling Trainer to stop.") diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/xla_stats_monitor.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/xla_stats_monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..c7fe59a59d51579d7e4af81e0b7cb331d6bbf402 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/xla_stats_monitor.py @@ -0,0 +1,114 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +XLA Stats Monitor +================= + +Monitor and logs XLA stats during training. + +""" +import time + +import pytorch_lightning as pl +from pytorch_lightning.accelerators import TPUAccelerator +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.utilities import _TPU_AVAILABLE +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info + +if _TPU_AVAILABLE: + import torch_xla.core.xla_model as xm + + +class XLAStatsMonitor(Callback): + r""" + .. deprecated:: v1.5 + The `XLAStatsMonitor` callback was deprecated in v1.5 and will be removed in v1.7. + Please use the `DeviceStatsMonitor` callback instead. + + Automatically monitors and logs XLA stats during training stage. ``XLAStatsMonitor`` is a callback and in + order to use it you need to assign a logger in the ``Trainer``. + + Args: + verbose: Set to ``True`` to print average peak and free memory, and epoch time + every epoch. + + Raises: + MisconfigurationException: + If not running on TPUs, or ``Trainer`` has no logger. + + Example:: + + >>> from pytorch_lightning import Trainer + >>> from pytorch_lightning.callbacks import XLAStatsMonitor + >>> xla_stats = XLAStatsMonitor() # doctest: +SKIP + >>> trainer = Trainer(callbacks=[xla_stats]) # doctest: +SKIP + """ + + def __init__(self, verbose: bool = True) -> None: + super().__init__() + + rank_zero_deprecation( + "The `XLAStatsMonitor` callback was deprecated in v1.5 and will be removed in v1.7." + " Please use the `DeviceStatsMonitor` callback instead." + ) + + if not _TPU_AVAILABLE: + raise MisconfigurationException("Cannot use XLAStatsMonitor with TPUs are not available") + + self._verbose = verbose + + def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if not trainer.loggers: + raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.") + + if not isinstance(trainer.accelerator, TPUAccelerator): + raise MisconfigurationException( + "You are using XLAStatsMonitor but are not running on TPU." + f" The accelerator is set to {trainer.accelerator.__class__.__name__}." + ) + + device = trainer.strategy.root_device + memory_info = xm.get_memory_info(device) + total_memory = trainer.strategy.reduce(memory_info["kb_total"]) * 0.001 + rank_zero_info(f"Average Total memory: {total_memory:.2f} MB") + + def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self._start_time = time.time() + + def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if not trainer.loggers: + raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.") + + device = trainer.strategy.root_device + memory_info = xm.get_memory_info(device) + epoch_time = time.time() - self._start_time + + free_memory = memory_info["kb_free"] + peak_memory = memory_info["kb_total"] - free_memory + + free_memory = trainer.strategy.reduce(free_memory) * 0.001 + peak_memory = trainer.strategy.reduce(peak_memory) * 0.001 + epoch_time = trainer.strategy.reduce(epoch_time) + + for logger in trainer.loggers: + logger.log_metrics( + {"avg. free memory (MB)": float(free_memory), "avg. peak memory (MB)": float(peak_memory)}, + step=trainer.current_epoch, + ) + + if self._verbose: + rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds") + rank_zero_info(f"Average Peak memory: {peak_memory:.2f} MB") + rank_zero_info(f"Average Free memory: {free_memory:.2f} MB") diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/datamodule.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..02011fd7e90bf658e9d3a5bb43c538ef9ac0b4b8 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/datamodule.py @@ -0,0 +1,264 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""LightningDataModule for loading DataLoaders with ease.""" +from argparse import ArgumentParser, Namespace +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union + +from torch.utils.data import DataLoader, Dataset, IterableDataset + +from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks +from pytorch_lightning.core.mixins import HyperparametersMixin +from pytorch_lightning.utilities import rank_zero_deprecation +from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types + + +class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin): + """A DataModule standardizes the training, val, test splits, data preparation and transforms. The main + advantage is consistent data splits, data preparation and transforms across models. + + Example:: + + class MyDataModule(LightningDataModule): + def __init__(self): + super().__init__() + def prepare_data(self): + # download, split, etc... + # only called on 1 GPU/TPU in distributed + def setup(self, stage): + # make assignments here (val/train/test split) + # called on every process in DDP + def train_dataloader(self): + train_split = Dataset(...) + return DataLoader(train_split) + def val_dataloader(self): + val_split = Dataset(...) + return DataLoader(val_split) + def test_dataloader(self): + test_split = Dataset(...) + return DataLoader(test_split) + def teardown(self): + # clean up after fit or test + # called on every process in DDP + """ + + name: str = ... + + def __init__(self, train_transforms=None, val_transforms=None, test_transforms=None, dims=None): + super().__init__() + if train_transforms is not None: + rank_zero_deprecation( + "DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7." + ) + if val_transforms is not None: + rank_zero_deprecation( + "DataModule property `val_transforms` was deprecated in v1.5 and will be removed in v1.7." + ) + if test_transforms is not None: + rank_zero_deprecation( + "DataModule property `test_transforms` was deprecated in v1.5 and will be removed in v1.7." + ) + if dims is not None: + rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.") + self._train_transforms = train_transforms + self._val_transforms = val_transforms + self._test_transforms = test_transforms + self._dims = dims if dims is not None else () + + # Pointer to the trainer object + self.trainer = None + + @property + def train_transforms(self): + """Optional transforms (or collection of transforms) you can apply to train dataset. + + .. deprecated:: v1.5 Will be removed in v1.7.0. + """ + + rank_zero_deprecation( + "DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7." + ) + return self._train_transforms + + @train_transforms.setter + def train_transforms(self, t): + rank_zero_deprecation( + "DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7." + ) + self._train_transforms = t + + @property + def val_transforms(self): + """Optional transforms (or collection of transforms) you can apply to validation dataset. + + .. deprecated:: v1.5 Will be removed in v1.7.0. + """ + + rank_zero_deprecation( + "DataModule property `val_transforms` was deprecated in v1.5 and will be removed in v1.7." + ) + return self._val_transforms + + @val_transforms.setter + def val_transforms(self, t): + rank_zero_deprecation( + "DataModule property `val_transforms` was deprecated in v1.5 and will be removed in v1.7." + ) + self._val_transforms = t + + @property + def test_transforms(self): + """Optional transforms (or collection of transforms) you can apply to test dataset. + + .. deprecated:: v1.5 Will be removed in v1.7.0. + """ + + rank_zero_deprecation( + "DataModule property `test_transforms` was deprecated in v1.5 and will be removed in v1.7." + ) + return self._test_transforms + + @test_transforms.setter + def test_transforms(self, t): + rank_zero_deprecation( + "DataModule property `test_transforms` was deprecated in v1.5 and will be removed in v1.7." + ) + self._test_transforms = t + + @property + def dims(self): + """A tuple describing the shape of your data. Extra functionality exposed in ``size``. + + .. deprecated:: v1.5 Will be removed in v1.7.0. + """ + rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.") + return self._dims + + @dims.setter + def dims(self, d): + rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.") + self._dims = d + + def size(self, dim=None) -> Union[Tuple, List[Tuple]]: + """Return the dimension of each input either as a tuple or list of tuples. You can index this just as you + would with a torch tensor. + + .. deprecated:: v1.5 Will be removed in v1.7.0. + """ + rank_zero_deprecation("DataModule property `size` was deprecated in v1.5 and will be removed in v1.7.") + + if dim is not None: + return self.dims[dim] + + return self.dims + + @classmethod + def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser: + """Extends existing argparse by default `LightningDataModule` attributes.""" + return add_argparse_args(cls, parent_parser, **kwargs) + + @classmethod + def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): + """Create an instance from CLI arguments. + + Args: + args: The parser or namespace to take arguments from. Only known arguments will be + parsed and passed to the :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. + **kwargs: Additional keyword arguments that may override ones in the parser or namespace. + These must be valid DataModule arguments. + + Example:: + + parser = ArgumentParser(add_help=False) + parser = LightningDataModule.add_argparse_args(parser) + module = LightningDataModule.from_argparse_args(args) + """ + return from_argparse_args(cls, args, **kwargs) + + @classmethod + def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: + r"""Scans the DataModule signature and returns argument names, types and default values. + + Returns: + List with tuples of 3 values: + (argument name, set with argument types, argument default value). + """ + return get_init_arguments_and_types(cls) + + @classmethod + def from_datasets( + cls, + train_dataset: Optional[Union[Dataset, Sequence[Dataset], Mapping[str, Dataset]]] = None, + val_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None, + test_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None, + batch_size: int = 1, + num_workers: int = 0, + ): + r""" + Create an instance from torch.utils.data.Dataset. + + Args: + train_dataset: (optional) Dataset to be used for train_dataloader() + val_dataset: (optional) Dataset or list of Dataset to be used for val_dataloader() + test_dataset: (optional) Dataset or list of Dataset to be used for test_dataloader() + batch_size: Batch size to use for each dataloader. Default is 1. + num_workers: Number of subprocesses to use for data loading. 0 means that the + data will be loaded in the main process. Number of CPUs available. + + """ + + def dataloader(ds: Dataset, shuffle: bool = False) -> DataLoader: + shuffle &= not isinstance(ds, IterableDataset) + return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True) + + def train_dataloader(): + if isinstance(train_dataset, Mapping): + return {key: dataloader(ds, shuffle=True) for key, ds in train_dataset.items()} + if isinstance(train_dataset, Sequence): + return [dataloader(ds, shuffle=True) for ds in train_dataset] + return dataloader(train_dataset, shuffle=True) + + def val_dataloader(): + if isinstance(val_dataset, Sequence): + return [dataloader(ds) for ds in val_dataset] + return dataloader(val_dataset) + + def test_dataloader(): + if isinstance(test_dataset, Sequence): + return [dataloader(ds) for ds in test_dataset] + return dataloader(test_dataset) + + datamodule = cls() + if train_dataset is not None: + datamodule.train_dataloader = train_dataloader + if val_dataset is not None: + datamodule.val_dataloader = val_dataloader + if test_dataset is not None: + datamodule.test_dataloader = test_dataloader + return datamodule + + def state_dict(self) -> Dict[str, Any]: + """Called when saving a checkpoint, implement to generate and save datamodule state. + + Returns: + A dictionary containing datamodule state. + """ + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """Called when loading a checkpoint, implement to reload datamodule state given datamodule state_dict. + + Args: + state_dict: the datamodule state returned by ``state_dict``. + """ + pass diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/decorators.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/decorators.py new file mode 100644 index 0000000000000000000000000000000000000000..33c83b4b10d6dbbf6774e9a9c724189502413fac --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/decorators.py @@ -0,0 +1,60 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn + +rank_zero_deprecation( + "Using `pytorch_lightning.core.decorators.parameter_validation` is deprecated in v1.5, " + "and will be removed in v1.7. It has been replaced by automatic parameters tying with " + "`pytorch_lightning.utilities.params_tying.set_shared_parameters`" +) + +from functools import wraps # noqa: E402 +from typing import Callable # noqa: E402 + + +def parameter_validation(fn: Callable) -> Callable: + """Validates that the module parameter lengths match after moving to the device. It is useful when tying + weights on TPU's. + + Args: + fn: ``model_to_device`` method + + Note: + TPU's require weights to be tied/shared after moving the module to the device. + Failure to do this results in the initialization of new weights which are not tied. + To overcome this issue, weights should be tied using the ``on_post_move_to_device`` model hook + which is called after the module has been moved to the device. + + See Also: + - `XLA Documentation `_ + """ + + @wraps(fn) + def inner_fn(self, *args, **kwargs): + pre_layer_count = len(list(self.model.parameters())) + module = fn(self, *args, **kwargs) + self.model.on_post_move_to_device() + post_layer_count = len(list(self.model.parameters())) + + if not pre_layer_count == post_layer_count: + rank_zero_warn( + "The model layers do not match after moving to the target device." + " If your model employs weight sharing on TPU," + " please tie your weights using the `on_post_move_to_device` model hook.\n" + f"Layer count: [Before: {pre_layer_count} After: {post_layer_count}]" + ) + + return module + + return inner_fn diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/hooks.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..442da2274c3603861ec4699116bedbaed6bfe166 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/hooks.py @@ -0,0 +1,828 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Various hooks to be used in the Lightning code.""" + +from typing import Any, Dict, List, Optional + +import torch +from torch.optim.optimizer import Optimizer + +from pytorch_lightning.utilities import move_data_to_device +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS + + +class ModelHooks: + """Hooks to be used in LightningModule.""" + + def on_fit_start(self) -> None: + """Called at the very beginning of fit. + + If on DDP it is called on every process + """ + + def on_fit_end(self) -> None: + """Called at the very end of fit. + + If on DDP it is called on every process + """ + + def on_train_start(self) -> None: + """Called at the beginning of training after sanity check.""" + + def on_train_end(self) -> None: + """Called at the end of training before logger experiment is closed.""" + + def on_validation_start(self) -> None: + """Called at the beginning of validation.""" + + def on_validation_end(self) -> None: + """Called at the end of validation.""" + + def on_test_start(self) -> None: + """Called at the beginning of testing.""" + + def on_test_end(self) -> None: + """Called at the end of testing.""" + + def on_predict_start(self) -> None: + """Called at the beginning of predicting.""" + + def on_predict_end(self) -> None: + """Called at the end of predicting.""" + + def on_pretrain_routine_start(self) -> None: + """Called at the beginning of the pretrain routine (between fit and train start). + + - fit + - pretrain_routine start + - pretrain_routine end + - training_start + + .. deprecated:: v1.6 + :meth:`on_pretrain_routine_start` has been deprecated in v1.6 and will be removed in v1.8. + Use ``on_fit_start`` instead. + """ + + def on_pretrain_routine_end(self) -> None: + """Called at the end of the pretrain routine (between fit and train start). + + - fit + - pretrain_routine start + - pretrain_routine end + - training_start + + .. deprecated:: v1.6 + :meth:`on_pretrain_routine_end` has been deprecated in v1.6 and will be removed in v1.8. + Use ``on_fit_start`` instead. + """ + + def on_train_batch_start(self, batch: Any, batch_idx: int, unused: int = 0) -> Optional[int]: + """Called in the training loop before anything happens for that batch. + + If you return -1 here, you will skip training for the rest of the current epoch. + + Args: + batch: The batched data as it is returned by the training DataLoader. + batch_idx: the index of the batch + unused: Deprecated argument. Will be removed in v1.7. + """ + + def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, unused: int = 0) -> None: + """Called in the training loop after the batch. + + Args: + outputs: The outputs of training_step_end(training_step(x)) + batch: The batched data as it is returned by the training DataLoader. + batch_idx: the index of the batch + unused: Deprecated argument. Will be removed in v1.7. + """ + + def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + """Called in the validation loop before anything happens for that batch. + + Args: + batch: The batched data as it is returned by the validation DataLoader. + batch_idx: the index of the batch + dataloader_idx: the index of the dataloader + """ + + def on_validation_batch_end( + self, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int + ) -> None: + """Called in the validation loop after the batch. + + Args: + outputs: The outputs of validation_step_end(validation_step(x)) + batch: The batched data as it is returned by the validation DataLoader. + batch_idx: the index of the batch + dataloader_idx: the index of the dataloader + """ + + def on_test_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + """Called in the test loop before anything happens for that batch. + + Args: + batch: The batched data as it is returned by the test DataLoader. + batch_idx: the index of the batch + dataloader_idx: the index of the dataloader + """ + + def on_test_batch_end( + self, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int + ) -> None: + """Called in the test loop after the batch. + + Args: + outputs: The outputs of test_step_end(test_step(x)) + batch: The batched data as it is returned by the test DataLoader. + batch_idx: the index of the batch + dataloader_idx: the index of the dataloader + """ + + def on_predict_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + """Called in the predict loop before anything happens for that batch. + + Args: + batch: The batched data as it is returned by the test DataLoader. + batch_idx: the index of the batch + dataloader_idx: the index of the dataloader + """ + + def on_predict_batch_end(self, outputs: Optional[Any], batch: Any, batch_idx: int, dataloader_idx: int) -> None: + """Called in the predict loop after the batch. + + Args: + outputs: The outputs of predict_step_end(test_step(x)) + batch: The batched data as it is returned by the test DataLoader. + batch_idx: the index of the batch + dataloader_idx: the index of the dataloader + """ + + def on_validation_model_eval(self) -> None: + """Sets the model to eval during the val loop.""" + self.trainer.model.eval() + + def on_validation_model_train(self) -> None: + """Sets the model to train during the val loop.""" + self.trainer.model.train() + + def on_test_model_train(self) -> None: + """Sets the model to train during the test loop.""" + self.trainer.model.train() + + def on_test_model_eval(self) -> None: + """Sets the model to eval during the test loop.""" + self.trainer.model.eval() + + def on_predict_model_eval(self) -> None: + """Sets the model to eval during the predict loop.""" + self.trainer.model.eval() + + def on_epoch_start(self) -> None: + """Called when either of train/val/test epoch begins. + + .. deprecated:: v1.6 + :meth:`on_epoch_start` has been deprecated in v1.6 and will be removed in v1.8. + Use ``on__epoch_start`` instead. + """ + + def on_epoch_end(self) -> None: + """Called when either of train/val/test epoch ends. + + .. deprecated:: v1.6 + :meth:`on_epoch_end` has been deprecated in v1.6 and will be removed in v1.8. + Use ``on__epoch_end`` instead. + """ + + def on_train_epoch_start(self) -> None: + """Called in the training loop at the very beginning of the epoch.""" + + def on_train_epoch_end(self) -> None: + """Called in the training loop at the very end of the epoch. + + To access all batch outputs at the end of the epoch, either: + + 1. Implement `training_epoch_end` in the LightningModule OR + 2. Cache data across steps on the attribute(s) of the `LightningModule` and access them in this hook + """ + + def on_validation_epoch_start(self) -> None: + """Called in the validation loop at the very beginning of the epoch.""" + + def on_validation_epoch_end(self) -> None: + """Called in the validation loop at the very end of the epoch.""" + + def on_test_epoch_start(self) -> None: + """Called in the test loop at the very beginning of the epoch.""" + + def on_test_epoch_end(self) -> None: + """Called in the test loop at the very end of the epoch.""" + + def on_predict_epoch_start(self) -> None: + """Called at the beginning of predicting.""" + + def on_predict_epoch_end(self, results: List[Any]) -> None: + """Called at the end of predicting.""" + + def on_before_zero_grad(self, optimizer: Optimizer) -> None: + """Called after ``training_step()`` and before ``optimizer.zero_grad()``. + + Called in the training loop after taking an optimizer step and before zeroing grads. + Good place to inspect weight information with weights updated. + + This is where it is called:: + + for optimizer in optimizers: + out = training_step(...) + + model.on_before_zero_grad(optimizer) # < ---- called here + optimizer.zero_grad() + + backward() + + Args: + optimizer: The optimizer for which grads should be zeroed. + """ + + def on_before_backward(self, loss: torch.Tensor) -> None: + """Called before ``loss.backward()``. + + Args: + loss: Loss divided by number of batches for gradient accumulation and scaled if using native AMP. + """ + pass + + def on_after_backward(self) -> None: + """Called after ``loss.backward()`` and before optimizers are stepped. + + Note: + If using native AMP, the gradients will not be unscaled at this point. + Use the ``on_before_optimizer_step`` if you need the unscaled gradients. + """ + + def on_before_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None: + """Called before ``optimizer.step()``. + + If using gradient accumulation, the hook is called once the gradients have been accumulated. + See: :paramref:`~pytorch_lightning.trainer.Trainer.accumulate_grad_batches`. + + If using native AMP, the loss will be unscaled before calling this hook. + See these `docs `__ + for more information on the scaling of gradients. + + If clipping gradients, the gradients will not have been clipped yet. + + Args: + optimizer: Current optimizer being used. + optimizer_idx: Index of the current optimizer being used. + + Example:: + + def on_before_optimizer_step(self, optimizer, optimizer_idx): + # example to inspect gradient information in tensorboard + if self.trainer.global_step % 25 == 0: # don't make the tf file huge + for k, v in self.named_parameters(): + self.logger.experiment.add_histogram( + tag=k, values=v.grad, global_step=self.trainer.global_step + ) + """ + + def on_post_move_to_device(self) -> None: + """Called in the ``parameter_validation`` decorator after + :meth:`~pytorch_lightning.core.LightningModule.to` is called. This is a good place to tie weights between + modules after moving them to a device. Can be used when training models with weight sharing properties on + TPU. + + Addresses the handling of shared weights on TPU: + https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks + + Example:: + + def on_post_move_to_device(self): + self.decoder.weight = self.encoder.weight + """ + + def configure_sharded_model(self) -> None: + """Hook to create modules in a distributed aware context. This is useful for when using sharded plugins, + where we'd like to shard the model instantly, which is useful for extremely large models which can save + memory and initialization time. + + This hook is called during each of fit/val/test/predict stages in the same process, so ensure that + implementation of this hook is idempotent. + """ + + +class DataHooks: + """Hooks to be used for data related stuff.""" + + def __init__(self) -> None: + """ + Attributes: + prepare_data_per_node: + If True, each LOCAL_RANK=0 will call prepare data. + Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data. + allow_zero_length_dataloader_with_multiple_devices: + If True, dataloader with zero length within local rank is allowed. + Default value is False. + """ + super().__init__() + self.prepare_data_per_node: bool = True + self.allow_zero_length_dataloader_with_multiple_devices: bool = False + + def prepare_data(self) -> None: + """Use this to download and prepare data. Downloading and saving data with multiple processes (distributed + settings) will result in corrupted data. Lightning ensures this method is called only within a single + process, so you can safely add your downloading logic within. + + .. warning:: DO NOT set state to the model (use ``setup`` instead) + since this is NOT called on every device + + Example:: + + def prepare_data(self): + # good + download_data() + tokenize() + etc() + + # bad + self.split = data_split + self.some_state = some_other_state() + + In DDP ``prepare_data`` can be called in two ways (using Trainer(prepare_data_per_node)): + + 1. Once per node. This is the default and is only called on LOCAL_RANK=0. + 2. Once in total. Only called on GLOBAL_RANK=0. + + See :ref:`prepare_data_per_node`. + + Example:: + + # DEFAULT + # called once per node on LOCAL_RANK=0 of that node + Trainer(prepare_data_per_node=True) + + # call on GLOBAL_RANK=0 (great for shared file systems) + Trainer(prepare_data_per_node=False) + + This is called before requesting the dataloaders: + + .. code-block:: python + + model.prepare_data() + initialize_distributed() + model.setup(stage) + model.train_dataloader() + model.val_dataloader() + model.test_dataloader() + """ + + def setup(self, stage: Optional[str] = None) -> None: + """Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when + you need to build models dynamically or adjust something about them. This hook is called on every process + when using DDP. + + Args: + stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` + + Example:: + + class LitModel(...): + def __init__(self): + self.l1 = None + + def prepare_data(self): + download_data() + tokenize() + + # don't do this + self.something = else + + def setup(self, stage): + data = load_data(...) + self.l1 = nn.Linear(28, data.num_classes) + """ + + def teardown(self, stage: Optional[str] = None) -> None: + """Called at the end of fit (train + validate), validate, test, or predict. + + Args: + stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` + """ + + def train_dataloader(self) -> TRAIN_DATALOADERS: + """Implement one or more PyTorch DataLoaders for training. + + Return: + A collection of :class:`torch.utils.data.DataLoader` specifying training samples. + In the case of multiple dataloaders, please see this :ref:`section `. + + The dataloader you return will not be reloaded unless you set + :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to + a positive integer. + + For data processing use the following pattern: + + - download in :meth:`prepare_data` + - process and split in :meth:`setup` + + However, the above are only necessary for distributed processing. + + .. warning:: do not assign state in prepare_data + + - :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit` + - :meth:`prepare_data` + - :meth:`setup` + + Note: + Lightning adds the correct sampler for distributed and arbitrary hardware. + There is no need to set it yourself. + + Example:: + + # single dataloader + def train_dataloader(self): + transform = transforms.Compose([transforms.ToTensor(), + transforms.Normalize((0.5,), (1.0,))]) + dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform, + download=True) + loader = torch.utils.data.DataLoader( + dataset=dataset, + batch_size=self.batch_size, + shuffle=True + ) + return loader + + # multiple dataloaders, return as list + def train_dataloader(self): + mnist = MNIST(...) + cifar = CIFAR(...) + mnist_loader = torch.utils.data.DataLoader( + dataset=mnist, batch_size=self.batch_size, shuffle=True + ) + cifar_loader = torch.utils.data.DataLoader( + dataset=cifar, batch_size=self.batch_size, shuffle=True + ) + # each batch will be a list of tensors: [batch_mnist, batch_cifar] + return [mnist_loader, cifar_loader] + + # multiple dataloader, return as dict + def train_dataloader(self): + mnist = MNIST(...) + cifar = CIFAR(...) + mnist_loader = torch.utils.data.DataLoader( + dataset=mnist, batch_size=self.batch_size, shuffle=True + ) + cifar_loader = torch.utils.data.DataLoader( + dataset=cifar, batch_size=self.batch_size, shuffle=True + ) + # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar} + return {'mnist': mnist_loader, 'cifar': cifar_loader} + """ + raise MisconfigurationException("`train_dataloader` must be implemented to be used with the Lightning Trainer") + + def test_dataloader(self) -> EVAL_DATALOADERS: + r""" + Implement one or multiple PyTorch DataLoaders for testing. + + For data processing use the following pattern: + + - download in :meth:`prepare_data` + - process and split in :meth:`setup` + + However, the above are only necessary for distributed processing. + + .. warning:: do not assign state in prepare_data + + + - :meth:`~pytorch_lightning.trainer.trainer.Trainer.test` + - :meth:`prepare_data` + - :meth:`setup` + + Note: + Lightning adds the correct sampler for distributed and arbitrary hardware. + There is no need to set it yourself. + + Return: + A :class:`torch.utils.data.DataLoader` or a sequence of them specifying testing samples. + + Example:: + + def test_dataloader(self): + transform = transforms.Compose([transforms.ToTensor(), + transforms.Normalize((0.5,), (1.0,))]) + dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform, + download=True) + loader = torch.utils.data.DataLoader( + dataset=dataset, + batch_size=self.batch_size, + shuffle=False + ) + + return loader + + # can also return multiple dataloaders + def test_dataloader(self): + return [loader_a, loader_b, ..., loader_n] + + Note: + If you don't need a test dataset and a :meth:`test_step`, you don't need to implement + this method. + + Note: + In the case where you return multiple test dataloaders, the :meth:`test_step` + will have an argument ``dataloader_idx`` which matches the order here. + """ + raise MisconfigurationException("`test_dataloader` must be implemented to be used with the Lightning Trainer") + + def val_dataloader(self) -> EVAL_DATALOADERS: + r""" + Implement one or multiple PyTorch DataLoaders for validation. + + The dataloader you return will not be reloaded unless you set + :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to + a positive integer. + + It's recommended that all data downloads and preparation happen in :meth:`prepare_data`. + + - :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit` + - :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate` + - :meth:`prepare_data` + - :meth:`setup` + + Note: + Lightning adds the correct sampler for distributed and arbitrary hardware + There is no need to set it yourself. + + Return: + A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples. + + Examples:: + + def val_dataloader(self): + transform = transforms.Compose([transforms.ToTensor(), + transforms.Normalize((0.5,), (1.0,))]) + dataset = MNIST(root='/path/to/mnist/', train=False, + transform=transform, download=True) + loader = torch.utils.data.DataLoader( + dataset=dataset, + batch_size=self.batch_size, + shuffle=False + ) + + return loader + + # can also return multiple dataloaders + def val_dataloader(self): + return [loader_a, loader_b, ..., loader_n] + + Note: + If you don't need a validation dataset and a :meth:`validation_step`, you don't need to + implement this method. + + Note: + In the case where you return multiple validation dataloaders, the :meth:`validation_step` + will have an argument ``dataloader_idx`` which matches the order here. + """ + raise MisconfigurationException("`val_dataloader` must be implemented to be used with the Lightning Trainer") + + def predict_dataloader(self) -> EVAL_DATALOADERS: + r""" + Implement one or multiple PyTorch DataLoaders for prediction. + + It's recommended that all data downloads and preparation happen in :meth:`prepare_data`. + + - :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict` + - :meth:`prepare_data` + - :meth:`setup` + + Note: + Lightning adds the correct sampler for distributed and arbitrary hardware + There is no need to set it yourself. + + Return: + A :class:`torch.utils.data.DataLoader` or a sequence of them specifying prediction samples. + + Note: + In the case where you return multiple prediction dataloaders, the :meth:`predict_step` + will have an argument ``dataloader_idx`` which matches the order here. + """ + raise MisconfigurationException( + "`predict_dataloader` must be implemented to be used with the Lightning Trainer" + ) + + def on_train_dataloader(self) -> None: + """Called before requesting the train dataloader. + + .. deprecated:: v1.5 + :meth:`on_train_dataloader` is deprecated and will be removed in v1.7.0. + Please use :meth:`train_dataloader()` directly. + """ + + def on_val_dataloader(self) -> None: + """Called before requesting the val dataloader. + + .. deprecated:: v1.5 + :meth:`on_val_dataloader` is deprecated and will be removed in v1.7.0. + Please use :meth:`val_dataloader()` directly. + """ + + def on_test_dataloader(self) -> None: + """Called before requesting the test dataloader. + + .. deprecated:: v1.5 + :meth:`on_test_dataloader` is deprecated and will be removed in v1.7.0. + Please use :meth:`test_dataloader()` directly. + """ + + def on_predict_dataloader(self) -> None: + """Called before requesting the predict dataloader. + + .. deprecated:: v1.5 + :meth:`on_predict_dataloader` is deprecated and will be removed in v1.7.0. + Please use :meth:`predict_dataloader()` directly. + """ + + def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any: + """Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors wrapped in a custom + data structure. + + The data types listed below (and any arbitrary nesting of them) are supported out of the box: + + - :class:`torch.Tensor` or anything that implements `.to(...)` + - :class:`list` + - :class:`dict` + - :class:`tuple` + - :class:`torchtext.data.batch.Batch` + + For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...). + + Note: + This hook should only transfer the data and not modify it, nor should it move the data to + any other device than the one passed in as argument (unless you know what you are doing). + To check the current state of execution of this hook you can use + ``self.trainer.training/testing/validating/predicting`` so that you can + add different logic as per your requirement. + + Note: + This hook only runs on single GPU training and DDP (no data-parallel). + Data-Parallel support will come in near future. + + Args: + batch: A batch of data that needs to be transferred to a new device. + device: The target device as defined in PyTorch. + dataloader_idx: The index of the dataloader to which the batch belongs. + + Returns: + A reference to the data on the new device. + + Example:: + + def transfer_batch_to_device(self, batch, device, dataloader_idx): + if isinstance(batch, CustomBatch): + # move all tensors in your custom data structure to the device + batch.samples = batch.samples.to(device) + batch.targets = batch.targets.to(device) + elif dataloader_idx == 0: + # skip device transfer for the first dataloader or anything you wish + pass + else: + batch = super().transfer_batch_to_device(data, device, dataloader_idx) + return batch + + Raises: + MisconfigurationException: + If using data-parallel, ``Trainer(strategy='dp')``. + + See Also: + - :meth:`move_data_to_device` + - :meth:`apply_to_collection` + """ + return move_data_to_device(batch, device) + + def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: + """Override to alter or apply batch augmentations to your batch before it is transferred to the device. + + Note: + To check the current state of execution of this hook you can use + ``self.trainer.training/testing/validating/predicting`` so that you can + add different logic as per your requirement. + + Note: + This hook only runs on single GPU training and DDP (no data-parallel). + Data-Parallel support will come in near future. + + Args: + batch: A batch of data that needs to be altered or augmented. + dataloader_idx: The index of the dataloader to which the batch belongs. + + Returns: + A batch of data + + Example:: + + def on_before_batch_transfer(self, batch, dataloader_idx): + batch['x'] = transforms(batch['x']) + return batch + + Raises: + MisconfigurationException: + If using data-parallel, ``Trainer(strategy='dp')``. + + See Also: + - :meth:`on_after_batch_transfer` + - :meth:`transfer_batch_to_device` + """ + return batch + + def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: + """Override to alter or apply batch augmentations to your batch after it is transferred to the device. + + Note: + To check the current state of execution of this hook you can use + ``self.trainer.training/testing/validating/predicting`` so that you can + add different logic as per your requirement. + + Note: + This hook only runs on single GPU training and DDP (no data-parallel). + Data-Parallel support will come in near future. + + Args: + batch: A batch of data that needs to be altered or augmented. + dataloader_idx: The index of the dataloader to which the batch belongs. + + Returns: + A batch of data + + Example:: + + def on_after_batch_transfer(self, batch, dataloader_idx): + batch['x'] = gpu_transforms(batch['x']) + return batch + + Raises: + MisconfigurationException: + If using data-parallel, ``Trainer(strategy='dp')``. + + See Also: + - :meth:`on_before_batch_transfer` + - :meth:`transfer_batch_to_device` + """ + return batch + + +class CheckpointHooks: + """Hooks to be used with Checkpointing.""" + + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + r""" + Called by Lightning to restore your model. + If you saved something with :meth:`on_save_checkpoint` this is your chance to restore this. + + Args: + checkpoint: Loaded checkpoint + + Example:: + + def on_load_checkpoint(self, checkpoint): + # 99% of the time you don't need to implement this method + self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save'] + + Note: + Lightning auto-restores global step, epoch, and train state including amp scaling. + There is no need for you to restore anything regarding training. + """ + + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + r""" + Called by Lightning when saving a checkpoint to give you a chance to store anything + else you might want to save. + + Args: + checkpoint: The full checkpoint dictionary before it gets dumped to a file. + Implementations of this hook can insert additional data into this dictionary. + + Example:: + + def on_save_checkpoint(self, checkpoint): + # 99% of use cases you don't need to implement this method + checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object + + Note: + Lightning saves all aspects of training (epoch, global step, etc...) + including amp scaling. + There is no need for you to store anything about training. + + """ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..51b156510c1b14abd68089716004067d055a90f9 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py @@ -0,0 +1,409 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from contextlib import contextmanager +from dataclasses import fields +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union +from weakref import proxy + +import torch +from torch import optim +from torch.optim import Optimizer + +import pytorch_lightning as pl +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.rank_zero import rank_zero_warn +from pytorch_lightning.utilities.types import _Stateful, LRSchedulerConfig, LRSchedulerTypeTuple, ReduceLROnPlateau + + +def do_nothing_closure() -> None: + return + + +class LightningOptimizer: + """This class is used to wrap the user optimizers and handle properly the backward and optimizer_step logic + across accelerators, AMP, accumulate_grad_batches.""" + + def __init__(self, optimizer: Optimizer): + # copy most of the `Optimizer` methods into this instance. `__del__` is skipped in case the optimizer has + # implemented custom logic which we would not want to call on destruction of the `LightningOptimizer` + self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("step", "__del__")} + + # For Horovod + if hasattr(optimizer, "skip_synchronize"): + self.__class__ = type( + "Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__.__bases__[0]), {} + ) + self.skip_synchronize = optimizer.skip_synchronize + self.synchronize = optimizer.synchronize + else: + self.__class__ = type("Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {}) + + self._optimizer = optimizer + self._strategy: Optional[pl.strategies.Strategy] = None + self._optimizer_idx = 0 + # to inject logic around the optimizer step, particularly useful with manual optimization + self._on_before_step = do_nothing_closure + self._on_after_step = do_nothing_closure + + @property + def optimizer(self) -> Optimizer: + return self._optimizer + + @classmethod + def _to_lightning_optimizer( + cls, optimizer: Union[Optimizer, "LightningOptimizer"], strategy: "pl.strategies.Strategy", opt_idx: int + ) -> "LightningOptimizer": + if isinstance(optimizer, LightningOptimizer): + # the user could return a `LightningOptimizer` from `configure_optimizers`, see test: + # tests/core/test_lightning_optimizer.py::test_lightning_optimizer[False] + lightning_optimizer = optimizer + else: + lightning_optimizer = cls(optimizer) + lightning_optimizer._strategy = proxy(strategy) + lightning_optimizer._optimizer_idx = opt_idx + return lightning_optimizer + + @contextmanager + def toggle_model(self, sync_grad: bool = True) -> Generator[None, None, None]: + """This function is just a helper for advanced users. + + Considering the current optimizer as A and all other optimizers as B. + Toggling means all parameters from B exclusive to A will have ``requires_grad`` set to False. + + When performing gradient accumulation, there is no need to perform grad synchronization + during the accumulation phase. + Setting `sync_grad` to False will block this synchronization and improve performance. + """ + # local import here to avoid circular import + from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior + + assert self._strategy is not None + lightning_module = self._strategy.lightning_module + assert lightning_module is not None + with _block_parallel_sync_behavior(self._strategy, block=(not sync_grad)): + lightning_module.toggle_optimizer(self, self._optimizer_idx) + yield + lightning_module.untoggle_optimizer(self._optimizer_idx) + + def step(self, closure: Optional[Callable[[], Any]] = None, **kwargs: Any) -> Any: + """Performs a single optimization step (parameter update). + + Args: + closure: An optional optimizer closure. + kwargs: Any additional arguments to the ``optimizer.step()`` call. + + Returns: + The output from the step call, which is generally the output of the closure execution. + + Example:: + + # Scenario for a GAN using manual optimization + def training_step(...): + opt_gen, opt_dis = self.optimizers() + + ... + + # compute generator loss + loss_gen = self.compute_generator_loss(...) + # zero_grad needs to be called before backward + opt_gen.zero_grad() + self.manual_backward(loss_gen) + opt_gen.step() + + # compute discriminator loss + loss_dis = self.compute_discriminator_loss(...) + + # zero_grad needs to be called before backward + opt_dis.zero_grad() + self.manual_backward(loss_dis) + opt_dis.step() + + + # A more advanced example + def training_step(self, batch, batch_idx, ...): + opt_gen, opt_dis = self.optimizers() + + ... + accumulated_grad_batches = batch_idx % 2 == 0 + + # compute generator loss + def closure_gen(): + loss_gen = self.compute_generator_loss(...) + self.manual_backward(loss_gen) + if accumulated_grad_batches: + opt_gen.zero_grad() + + with opt_gen.toggle_model(sync_grad=accumulated_grad_batches): + opt_gen.step(closure=closure_gen) + + def closure_dis(): + loss_dis = self.compute_discriminator_loss(...) + self.manual_backward(loss_dis) + if accumulated_grad_batches: + opt_dis.zero_grad() + + with opt_dis.toggle_model(sync_grad=accumulated_grad_batches): + opt_dis.step(closure=closure_dis) + """ + self._on_before_step() + + if closure is None: + closure = do_nothing_closure + elif not callable(closure): + raise MisconfigurationException("When `optimizer.step(closure)` is called, the closure should be callable") + + assert self._strategy is not None + step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs) + + self._on_after_step() + + return step_output + + +def _init_optimizers_and_lr_schedulers( + model: "pl.LightningModule", +) -> Tuple[List[Optimizer], List[LRSchedulerConfig], List[int]]: + """Calls `LightningModule.configure_optimizers` and parses and validates the output.""" + assert model.trainer is not None + optim_conf = model.trainer._call_lightning_module_hook("configure_optimizers", pl_module=model) + + if optim_conf is None: + rank_zero_warn( + "`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer", + ) + optim_conf = _MockOptimizer() + + optimizers, lr_schedulers, optimizer_frequencies, monitor = _configure_optimizers(optim_conf) + lr_scheduler_configs = ( + _configure_schedulers_automatic_opt(lr_schedulers, monitor) + if model.automatic_optimization + else _configure_schedulers_manual_opt(lr_schedulers) + ) + _set_scheduler_opt_idx(optimizers, lr_scheduler_configs) + _validate_scheduler_api(lr_scheduler_configs, model) + return optimizers, lr_scheduler_configs, optimizer_frequencies + + +def _configure_optimizers( + optim_conf: Union[Dict[str, Any], List, Optimizer, Tuple] +) -> Tuple[List, List, List, Optional[str]]: + optimizers, lr_schedulers, optimizer_frequencies = [], [], [] + monitor = None + + # single output, single optimizer + if isinstance(optim_conf, Optimizer): + optimizers = [optim_conf] + # two lists, optimizer + lr schedulers + elif ( + isinstance(optim_conf, (list, tuple)) + and len(optim_conf) == 2 + and isinstance(optim_conf[0], list) + and all(isinstance(opt, Optimizer) for opt in optim_conf[0]) + ): + opt, sch = optim_conf + optimizers = opt + lr_schedulers = sch if isinstance(sch, list) else [sch] + # single dictionary + elif isinstance(optim_conf, dict): + _validate_optim_conf(optim_conf) + optimizers = [optim_conf["optimizer"]] + monitor = optim_conf.get("monitor", None) + lr_schedulers = [optim_conf["lr_scheduler"]] if "lr_scheduler" in optim_conf else [] + # multiple dictionaries + elif isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf): + for opt_dict in optim_conf: + _validate_optim_conf(opt_dict) + optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf] + scheduler_dict = ( + lambda scheduler, opt_idx: dict(scheduler, opt_idx=opt_idx) + if isinstance(scheduler, dict) + else {"scheduler": scheduler, "opt_idx": opt_idx} + ) + + lr_schedulers = [ + scheduler_dict(opt_dict["lr_scheduler"], opt_idx) + for opt_idx, opt_dict in enumerate(optim_conf) + if "lr_scheduler" in opt_dict + ] + optimizer_frequencies = [ + opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency", None) is not None + ] + # assert that if frequencies are present, they are given for all optimizers + if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers): + raise ValueError("A frequency must be given to each optimizer.") + # single list or tuple, multiple optimizer + elif isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizer) for opt in optim_conf): + optimizers = list(optim_conf) + # unknown configuration + else: + raise MisconfigurationException( + "Unknown configuration for model optimizers." + " Output from `model.configure_optimizers()` should be one of:\n" + " * `Optimizer`\n" + " * [`Optimizer`]\n" + " * ([`Optimizer`], [`_LRScheduler`])\n" + ' * {"optimizer": `Optimizer`, (optional) "lr_scheduler": `_LRScheduler`}\n' + ' * A list of the previously described dict format, with an optional "frequency" key (int)' + ) + return optimizers, lr_schedulers, optimizer_frequencies, monitor + + +def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> List[LRSchedulerConfig]: + """Convert each scheduler into `LRSchedulerConfig` with relevant information, when using automatic + optimization.""" + lr_scheduler_configs = [] + for scheduler in schedulers: + if isinstance(scheduler, dict): + # check provided keys + supported_keys = {field.name for field in fields(LRSchedulerConfig)} + extra_keys = scheduler.keys() - supported_keys + if extra_keys: + rank_zero_warn( + f"Found unsupported keys in the lr scheduler dict: {extra_keys}." + " HINT: remove them from the output of `configure_optimizers`.", + category=RuntimeWarning, + ) + scheduler = {k: v for k, v in scheduler.items() if k in supported_keys} + if "scheduler" not in scheduler: + raise MisconfigurationException( + 'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler' + ) + if "interval" in scheduler and scheduler["interval"] not in ("step", "epoch"): + raise MisconfigurationException( + 'The "interval" key in lr scheduler dict must be "step" or "epoch"' + f' but is "{scheduler["interval"]}"' + ) + scheduler["reduce_on_plateau"] = isinstance(scheduler["scheduler"], optim.lr_scheduler.ReduceLROnPlateau) + if scheduler["reduce_on_plateau"] and scheduler.get("monitor", None) is None: + raise MisconfigurationException( + "The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used." + ' For example: {"optimizer": optimizer, "lr_scheduler":' + ' {"scheduler": scheduler, "monitor": "your_loss"}}' + ) + is_one_cycle = isinstance(scheduler["scheduler"], optim.lr_scheduler.OneCycleLR) + if is_one_cycle and scheduler.get("interval", "epoch") == "epoch": + rank_zero_warn( + "A `OneCycleLR` scheduler is using 'interval': 'epoch'." + " Are you sure you didn't mean 'interval': 'step'?", + category=RuntimeWarning, + ) + config = LRSchedulerConfig(**scheduler) + elif isinstance(scheduler, ReduceLROnPlateau): + if monitor is None: + raise MisconfigurationException( + "`configure_optimizers` must include a monitor when a `ReduceLROnPlateau`" + " scheduler is used. For example:" + ' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}' + ) + config = LRSchedulerConfig(scheduler, reduce_on_plateau=True, monitor=monitor) + else: + config = LRSchedulerConfig(scheduler) + lr_scheduler_configs.append(config) + return lr_scheduler_configs + + +def _configure_schedulers_manual_opt(schedulers: list) -> List[LRSchedulerConfig]: + """Convert each scheduler into `LRSchedulerConfig` structure with relevant information, when using manual + optimization.""" + lr_scheduler_configs = [] + for scheduler in schedulers: + if isinstance(scheduler, dict): + invalid_keys = {"interval", "frequency", "reduce_on_plateau", "monitor", "strict"} + keys_to_warn = [k for k in scheduler.keys() if k in invalid_keys] + + if keys_to_warn: + rank_zero_warn( + f"The lr scheduler dict contains the key(s) {keys_to_warn}, but the keys will be ignored." + " You need to call `lr_scheduler.step()` manually in manual optimization.", + category=RuntimeWarning, + ) + + config = LRSchedulerConfig(**{key: scheduler[key] for key in scheduler if key not in invalid_keys}) + else: + config = LRSchedulerConfig(scheduler) + lr_scheduler_configs.append(config) + return lr_scheduler_configs + + +def _validate_scheduler_api(lr_scheduler_configs: List[LRSchedulerConfig], model: "pl.LightningModule") -> None: + for config in lr_scheduler_configs: + scheduler = config.scheduler + if not isinstance(scheduler, _Stateful): + raise TypeError( + f"The provided lr scheduler `{scheduler.__class__.__name__}` is invalid." + " It should have `state_dict` and `load_state_dict` methods defined." + ) + + if not isinstance(scheduler, LRSchedulerTypeTuple) and not is_overridden("lr_scheduler_step", model): + raise MisconfigurationException( + f"The provided lr scheduler `{scheduler.__class__.__name__}` doesn't follow PyTorch's LRScheduler" + " API. You should override the `LightningModule.lr_scheduler_step` hook with your own logic if" + " you are using a custom LR scheduler." + ) + + +def _set_scheduler_opt_idx(optimizers: List[Optimizer], lr_scheduler_configs: List[LRSchedulerConfig]) -> None: + for config in lr_scheduler_configs: + + for opt_idx, opt in enumerate(optimizers): + if config.scheduler.optimizer is opt: + if config.opt_idx is not None and config.opt_idx != opt_idx: + raise MisconfigurationException( + "`opt_idx` set inside scheduler config does not match with the index" + " of the respective optimizer returned from `configure_optimizers`." + ) + + config.opt_idx = opt_idx + break + else: + raise MisconfigurationException( + "Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`." + ) + + +def _validate_optim_conf(optim_conf: Dict[str, Any]) -> None: + valid_keys = {"optimizer", "lr_scheduler", "frequency", "monitor"} + extra_keys = optim_conf.keys() - valid_keys + if extra_keys: + rank_zero_warn( + f"Found unsupported keys in the optimizer configuration: {set(extra_keys)}", category=RuntimeWarning + ) + + +class _MockOptimizer(Optimizer): + """The `_MockOptimizer` will be used inplace of an optimizer in the event that `None` is returned from + `configure_optimizers`.""" + + def __init__(self) -> None: + super().__init__([torch.zeros(1)], {}) + + def add_param_group(self, param_group: Dict[Any, Any]) -> None: + pass # Do Nothing + + def load_state_dict(self, state_dict: Dict[Any, Any]) -> None: + pass # Do Nothing + + def state_dict(self) -> Dict[str, Any]: + return {} # Return Empty + + def step(self, closure: Callable = None) -> None: + if closure is not None: + closure() + + def zero_grad(self, set_to_none: Optional[bool] = False) -> None: + pass # Do Nothing + + def __repr__(self) -> str: + return "No Optimizer" diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/saving.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/saving.py new file mode 100644 index 0000000000000000000000000000000000000000..fa0f92eb3b971bb872f71ba57065cf8bb26e4c5b --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/saving.py @@ -0,0 +1,419 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ast +import csv +import inspect +import logging +import os +from argparse import Namespace +from copy import deepcopy +from enum import Enum +from typing import Any, Callable, Dict, IO, MutableMapping, Optional, Union +from warnings import warn + +import torch +import yaml + +from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict +from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.cloud_io import get_filesystem +from pytorch_lightning.utilities.cloud_io import load as pl_load +from pytorch_lightning.utilities.migration import pl_legacy_patch +from pytorch_lightning.utilities.parsing import parse_class_init_keys +from pytorch_lightning.utilities.rank_zero import rank_zero_warn + +log = logging.getLogger(__name__) +PRIMITIVE_TYPES = (bool, int, float, str) +ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace) + +if _OMEGACONF_AVAILABLE: + from omegaconf import OmegaConf + from omegaconf.dictconfig import DictConfig + from omegaconf.errors import UnsupportedValueType, ValidationError + +# the older shall be on the top +CHECKPOINT_PAST_HPARAMS_KEYS = ("hparams", "module_arguments") # used in 0.7.6 + + +class ModelIO: + CHECKPOINT_HYPER_PARAMS_KEY = "hyper_parameters" + CHECKPOINT_HYPER_PARAMS_NAME = "hparams_name" + CHECKPOINT_HYPER_PARAMS_TYPE = "hparams_type" + + @classmethod + def load_from_checkpoint( + cls, + checkpoint_path: Union[str, IO], + map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None, + hparams_file: Optional[str] = None, + strict: bool = True, + **kwargs, + ): + r""" + Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint + it stores the arguments passed to ``__init__`` in the checkpoint under ``"hyper_parameters"``. + + Any arguments specified through \*\*kwargs will override args stored in ``"hyper_parameters"``. + + Args: + checkpoint_path: Path to checkpoint. This can also be a URL, or file-like object + map_location: + If your checkpoint saved a GPU model and you now load on CPUs + or a different number of GPUs, use this to map to the new setup. + The behaviour is the same as in :func:`torch.load`. + hparams_file: Optional path to a .yaml file with hierarchical structure + as in this example:: + + drop_prob: 0.2 + dataloader: + batch_size: 32 + + You most likely won't need this since Lightning will always save the hyperparameters + to the checkpoint. + However, if your checkpoint weights don't have the hyperparameters saved, + use this method to pass in a .yaml file with the hparams you'd like to use. + These will be converted into a :class:`~dict` and passed into your + :class:`LightningModule` for use. + + If your model's ``hparams`` argument is :class:`~argparse.Namespace` + and .yaml file has hierarchical structure, you need to refactor your model to treat + ``hparams`` as :class:`~dict`. + strict: Whether to strictly enforce that the keys in :attr:`checkpoint_path` match the keys + returned by this module's state dict. + kwargs: Any extra keyword args needed to init the model. Can also be used to override saved + hyperparameter values. + + Return: + :class:`LightningModule` instance with loaded weights and hyperparameters (if available). + + Note: + ``load_from_checkpoint`` is a **class** method. You should use your :class:`LightningModule` + **class** to call it instead of the :class:`LightningModule` instance. + + Example:: + + # load weights without mapping ... + model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt') + + # or load weights mapping all weights from GPU 1 to GPU 0 ... + map_location = {'cuda:1':'cuda:0'} + model = MyLightningModule.load_from_checkpoint( + 'path/to/checkpoint.ckpt', + map_location=map_location + ) + + # or load weights and hyperparameters from separate files. + model = MyLightningModule.load_from_checkpoint( + 'path/to/checkpoint.ckpt', + hparams_file='/path/to/hparams_file.yaml' + ) + + # override some of the params with new values + model = MyLightningModule.load_from_checkpoint( + PATH, + num_layers=128, + pretrained_ckpt_path=NEW_PATH, + ) + + # predict + pretrained_model.eval() + pretrained_model.freeze() + y_hat = pretrained_model(x) + """ + with pl_legacy_patch(): + if map_location is not None: + checkpoint = pl_load(checkpoint_path, map_location=map_location) + else: + checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) + + if hparams_file is not None: + extension = hparams_file.split(".")[-1] + if extension.lower() == "csv": + hparams = load_hparams_from_tags_csv(hparams_file) + elif extension.lower() in ("yml", "yaml"): + hparams = load_hparams_from_yaml(hparams_file) + else: + raise ValueError(".csv, .yml or .yaml is required for `hparams_file`") + + hparams["on_gpu"] = False + + # overwrite hparams by the given file + checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams + + # for past checkpoint need to add the new key + if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint: + checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {} + # override the hparams with values that were passed in + checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs) + + model = cls._load_model_state(checkpoint, strict=strict, **kwargs) + return model + + @classmethod + def _load_model_state(cls, checkpoint: Dict[str, Any], strict: bool = True, **cls_kwargs_new): + cls_spec = inspect.getfullargspec(cls.__init__) + cls_init_args_name = inspect.signature(cls.__init__).parameters.keys() + + self_var, args_var, kwargs_var = parse_class_init_keys(cls) + drop_names = [n for n in (self_var, args_var, kwargs_var) if n] + cls_init_args_name = list(filter(lambda n: n not in drop_names, cls_init_args_name)) + + cls_kwargs_loaded = {} + # pass in the values we saved automatically + if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint: + + # 1. (backward compatibility) Try to restore model hparams from checkpoint using old/past keys + for _old_hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS: + cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key, {})) + + # 2. Try to restore model hparams from checkpoint using the new key + _new_hparam_key = cls.CHECKPOINT_HYPER_PARAMS_KEY + cls_kwargs_loaded.update(checkpoint.get(_new_hparam_key)) + + # 3. Ensure that `cls_kwargs_old` has the right type, back compatibility between dict and Namespace + cls_kwargs_loaded = _convert_loaded_hparams( + cls_kwargs_loaded, checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE) + ) + + # 4. Update cls_kwargs_new with cls_kwargs_old, such that new has higher priority + args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME) + if args_name and args_name in cls_init_args_name: + cls_kwargs_loaded = {args_name: cls_kwargs_loaded} + + _cls_kwargs = {} + _cls_kwargs.update(cls_kwargs_loaded) + _cls_kwargs.update(cls_kwargs_new) + + if not cls_spec.varkw: + # filter kwargs according to class init unless it allows any argument via kwargs + _cls_kwargs = {k: v for k, v in _cls_kwargs.items() if k in cls_init_args_name} + + model = cls(**_cls_kwargs) + + # give model a chance to load something + model.on_load_checkpoint(checkpoint) + + # load the state_dict on the model automatically + keys = model.load_state_dict(checkpoint["state_dict"], strict=strict) + + if not strict: + if keys.missing_keys: + rank_zero_warn( + f"Found keys that are in the model state dict but not in the checkpoint: {keys.missing_keys}" + ) + if keys.unexpected_keys: + rank_zero_warn( + f"Found keys that are not in the model state dict but in the checkpoint: {keys.unexpected_keys}" + ) + + return model + + # ------------------------- + # OPTIONAL HOOKS + # ------------------------- + def on_hpc_save(self, checkpoint: Dict[str, Any]) -> None: + """Hook to do whatever you need right before Slurm manager saves the model. + + Args: + checkpoint: A dictionary in which you can save variables to save in a checkpoint. + Contents need to be pickleable. + + .. deprecated:: v1.6 + This method is deprecated in v1.6 and will be removed in v1.8. + Please use ``LightningModule.on_save_checkpoint`` instead. + """ + + def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None: + """Hook to do whatever you need right before Slurm manager loads the model. + + Args: + checkpoint: A dictionary with variables from the checkpoint. + + .. deprecated:: v1.6 + This method is deprecated in v1.6 and will be removed in v1.8. + Please use ``LightningModule.on_load_checkpoint`` instead. + """ + + +def _convert_loaded_hparams(model_args: dict, hparams_type: Optional[Union[Callable, str]] = None) -> object: + """Convert hparams according given type in callable or string (past) format.""" + # if not hparams type define + if not hparams_type: + return model_args + # if past checkpoint loaded, convert str to callable + if isinstance(hparams_type, str): + hparams_type = AttributeDict + # convert hparams + return hparams_type(model_args) + + +def update_hparams(hparams: dict, updates: dict) -> None: + """Overrides hparams with new values. + + >>> hparams = {'c': 4} + >>> update_hparams(hparams, {'a': {'b': 2}, 'c': 1}) + >>> hparams['a']['b'], hparams['c'] + (2, 1) + >>> update_hparams(hparams, {'a': {'b': 4}, 'c': 7}) + >>> hparams['a']['b'], hparams['c'] + (4, 7) + + Args: + hparams: the original params and also target object + updates: new params to be used as update + """ + for k, v in updates.items(): + # if missing, add the key + if k not in hparams: + hparams[k] = v + continue + + # recurse if dictionary + if isinstance(v, dict): + update_hparams(hparams[k], updates[k]) + else: + # update the value + hparams.update({k: v}) + + +def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]: + """Load hparams from a file. + + >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here') + >>> path_csv = os.path.join('.', 'testing-hparams.csv') + >>> save_hparams_to_tags_csv(path_csv, hparams) + >>> hparams_new = load_hparams_from_tags_csv(path_csv) + >>> vars(hparams) == hparams_new + True + >>> os.remove(path_csv) + """ + fs = get_filesystem(tags_csv) + if not fs.exists(tags_csv): + rank_zero_warn(f"Missing Tags: {tags_csv}.", category=RuntimeWarning) + return {} + + with fs.open(tags_csv, "r", newline="") as fp: + csv_reader = csv.reader(fp, delimiter=",") + tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]} + + return tags + + +def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None: + fs = get_filesystem(tags_csv) + if not fs.isdir(os.path.dirname(tags_csv)): + raise RuntimeError(f"Missing folder: {os.path.dirname(tags_csv)}.") + + if isinstance(hparams, Namespace): + hparams = vars(hparams) + + with fs.open(tags_csv, "w", newline="") as fp: + fieldnames = ["key", "value"] + writer = csv.DictWriter(fp, fieldnames=fieldnames) + writer.writerow({"key": "key", "value": "value"}) + for k, v in hparams.items(): + writer.writerow({"key": k, "value": v}) + + +def load_hparams_from_yaml(config_yaml: str, use_omegaconf: bool = True) -> Dict[str, Any]: + """Load hparams from a file. + + Args: + config_yaml: Path to config yaml file + use_omegaconf: If omegaconf is available and ``use_omegaconf=True``, + the hparams will be converted to ``DictConfig`` if possible. + + >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here') + >>> path_yaml = './testing-hparams.yaml' + >>> save_hparams_to_yaml(path_yaml, hparams) + >>> hparams_new = load_hparams_from_yaml(path_yaml) + >>> vars(hparams) == hparams_new + True + >>> os.remove(path_yaml) + """ + fs = get_filesystem(config_yaml) + if not fs.exists(config_yaml): + rank_zero_warn(f"Missing Tags: {config_yaml}.", category=RuntimeWarning) + return {} + + with fs.open(config_yaml, "r") as fp: + hparams = yaml.full_load(fp) + + if _OMEGACONF_AVAILABLE: + if use_omegaconf: + try: + return OmegaConf.create(hparams) + except (UnsupportedValueType, ValidationError): + pass + return hparams + + +def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace], use_omegaconf: bool = True) -> None: + """ + Args: + config_yaml: path to new YAML file + hparams: parameters to be saved + use_omegaconf: If omegaconf is available and ``use_omegaconf=True``, + the hparams will be converted to ``DictConfig`` if possible. + + """ + fs = get_filesystem(config_yaml) + if not fs.isdir(os.path.dirname(config_yaml)): + raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.") + + # convert Namespace or AD to dict + if isinstance(hparams, Namespace): + hparams = vars(hparams) + elif isinstance(hparams, AttributeDict): + hparams = dict(hparams) + + # saving with OmegaConf objects + if _OMEGACONF_AVAILABLE and use_omegaconf: + # deepcopy: hparams from user shouldn't be resolved + hparams = deepcopy(hparams) + hparams = apply_to_collection(hparams, DictConfig, OmegaConf.to_container, resolve=True) + with fs.open(config_yaml, "w", encoding="utf-8") as fp: + try: + OmegaConf.save(hparams, fp) + return + except (UnsupportedValueType, ValidationError): + pass + + if not isinstance(hparams, dict): + raise TypeError("hparams must be dictionary") + + hparams_allowed = {} + # drop parameters which contain some strange datatypes as fsspec + for k, v in hparams.items(): + try: + v = v.name if isinstance(v, Enum) else v + yaml.dump(v) + except TypeError: + warn(f"Skipping '{k}' parameter because it is not possible to safely dump to YAML.") + hparams[k] = type(v).__name__ + else: + hparams_allowed[k] = v + + # saving the standard way + with fs.open(config_yaml, "w", newline="") as fp: + yaml.dump(hparams_allowed, fp) + + +def convert(val: str) -> Union[int, float, bool, str]: + try: + return ast.literal_eval(val) + except (ValueError, SyntaxError) as err: + log.debug(err) + return val diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/distributed/__init__.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/distributed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ea060e551ad9da277e03cb3d7d267626d948af14 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/distributed/__init__.py @@ -0,0 +1,14 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.distributed.dist import LightningDistributed # noqa: F401 diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/distributed/dist.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/distributed/dist.py new file mode 100644 index 0000000000000000000000000000000000000000..1799450e3ce059fff68cdb443b31cc9aa6629d56 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/distributed/dist.py @@ -0,0 +1,47 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any + +import torch.distributed + +from pytorch_lightning.utilities import rank_zero_deprecation +from pytorch_lightning.utilities.distributed import group as _group + + +class LightningDistributed: + """ + .. deprecated:: v1.5 + This class is deprecated in v1.5 and will be removed in v1.7. + The broadcast logic will be moved to the :class:`DDPStrategy` and :class`DDPSpawnStrategy` classes. + + """ + + def __init__(self, rank=None, device=None): + rank_zero_deprecation( + "LightningDistributed is deprecated in v1.5 and will be removed in v1.7." + "Broadcast logic is implemented directly in the :class:`Strategy` implementations." + ) + self.rank = rank + self.device = device + + def broadcast(self, obj: Any, group=_group.WORLD): + # always wrap into a list so it can be broadcasted. + obj = [obj] + + if self.rank != 0: + obj = [None] * len(obj) + + torch.distributed.broadcast_object_list(obj, 0, group=group or _group.WORLD) + + return obj[0] diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/__pycache__/__init__.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf60ed3809da9b92c276488849a897e58328be83 Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/__pycache__/__init__.cpython-38.pyc differ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/__pycache__/layer_sync.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/__pycache__/layer_sync.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da9adc7de0cd6dc4c98bd832ec903f2a18b51902 Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/__pycache__/layer_sync.cpython-38.pyc differ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/__init__.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eab64cfe2daf5fb1b4cedc7ef1f0171ad6512686 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/__init__.py @@ -0,0 +1,20 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.plugins.environments.bagua_environment import BaguaEnvironment # noqa: F401 +from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment # noqa: F401 +from pytorch_lightning.plugins.environments.kubeflow_environment import KubeflowEnvironment # noqa: F401 +from pytorch_lightning.plugins.environments.lightning_environment import LightningEnvironment # noqa: F401 +from pytorch_lightning.plugins.environments.lsf_environment import LSFEnvironment # noqa: F401 +from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment # noqa: F401 +from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment # noqa: F401 diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/__pycache__/__init__.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d01aca1295660b5fe2d456e1ad7eaa5625528daf Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/__pycache__/__init__.cpython-38.pyc differ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/__pycache__/bagua_environment.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/__pycache__/bagua_environment.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..777771f0e3f122ccd5839c474d8292223d7dd405 Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/__pycache__/bagua_environment.cpython-38.pyc differ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/bagua_environment.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/bagua_environment.py new file mode 100644 index 0000000000000000000000000000000000000000..0516b264c2ac3f033c3f81c669d69cdfdf1a3c55 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/bagua_environment.py @@ -0,0 +1,62 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os + +from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment + +log = logging.getLogger(__name__) + + +class BaguaEnvironment(ClusterEnvironment): + """Environment for distributed training with `Bagua `_""" + + @property + def creates_processes_externally(self) -> bool: + return True + + @property + def main_address(self) -> str: + return os.environ.get("MASTER_ADDR", "127.0.0.1") + + @property + def main_port(self) -> int: + return int(os.environ.get("MASTER_PORT", -1)) + + @property + def service_port(self) -> int: + return int(os.environ.get("BAGUA_SERVICE_PORT", -1)) + + @staticmethod + def detect() -> bool: + return "BAGUA_SERVICE_PORT" in os.environ + + def world_size(self) -> int: + return int(os.environ["WORLD_SIZE"]) + + def set_world_size(self, size: int) -> None: + log.debug("`BaguaEnvironment.set_world_size` was called, but setting world size is not allowed. Ignored.") + + def global_rank(self) -> int: + return int(os.environ["RANK"]) + + def set_global_rank(self, rank: int) -> None: + log.debug("`BaguaEnvironment.set_global_rank` was called, but setting global rank is not allowed. Ignored.") + + def local_rank(self) -> int: + return int(os.environ.get("LOCAL_RANK", 0)) + + def node_rank(self) -> int: + return int(os.environ.get("NODE_RANK", 0)) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/cluster_environment.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/cluster_environment.py new file mode 100644 index 0000000000000000000000000000000000000000..1871f0afdf193ccf4e10015b28c1fd1578e232e5 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/cluster_environment.py @@ -0,0 +1,87 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod +from typing import Any, Type + +from pytorch_lightning.utilities import rank_zero_deprecation + + +class ClusterEnvironment(ABC): + """Specification of a cluster environment.""" + + def __new__(cls, *args: Any, **kwargs: Any) -> "ClusterEnvironment": + # TODO: remove in 1.7 + _check_for_deprecated_methods(cls) + return super().__new__(cls) + + @property + @abstractmethod + def creates_processes_externally(self) -> bool: + """Whether the environment creates the subprocesses or not.""" + + @property + @abstractmethod + def main_address(self) -> str: + """The main address through which all processes connect and communicate.""" + + @property + @abstractmethod + def main_port(self) -> int: + """An open and configured port in the main node through which all processes communicate.""" + + @staticmethod + @abstractmethod + def detect() -> bool: + """Detects the environment settings corresponding to this cluster and returns ``True`` if they match.""" + + @abstractmethod + def world_size(self) -> int: + """The number of processes across all devices and nodes.""" + + @abstractmethod + def set_world_size(self, size: int) -> None: + pass + + @abstractmethod + def global_rank(self) -> int: + """The rank (index) of the currently running process across all nodes and devices.""" + + @abstractmethod + def set_global_rank(self, rank: int) -> None: + pass + + @abstractmethod + def local_rank(self) -> int: + """The rank (index) of the currently running process inside of the current node.""" + + @abstractmethod + def node_rank(self) -> int: + """The rank (index) of the node on which the current process runs.""" + + def teardown(self) -> None: + """Clean up any state set after execution finishes.""" + pass + + +def _check_for_deprecated_methods(cls: Type[ClusterEnvironment]) -> None: + if hasattr(cls, "master_address") and callable(cls.master_address): + rank_zero_deprecation( + f"`{cls.__name__}.master_address` has been deprecated in v1.6 and will be removed in v1.7." + " Implement the property `main_address` instead (do not forget to add the `@property` decorator)." + ) + if hasattr(cls, "master_port") and callable(cls.master_port): + rank_zero_deprecation( + f"`{cls.__name__}.master_port` has been deprecated in v1.6 and will be removed in v1.7." + " Implement the property `main_port` instead (do not forget to add the `@property` decorator)." + ) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/kubeflow_environment.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/kubeflow_environment.py new file mode 100644 index 0000000000000000000000000000000000000000..03dfdde9d78a0871e4b5517cdafe48bbf13f7d2b --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/kubeflow_environment.py @@ -0,0 +1,78 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os + +from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.utilities import rank_zero_deprecation + +log = logging.getLogger(__name__) + + +class KubeflowEnvironment(ClusterEnvironment): + """Environment for distributed training using the `PyTorchJob`_ operator from `Kubeflow`_ + + .. _PyTorchJob: https://www.kubeflow.org/docs/components/training/pytorch/ + .. _Kubeflow: https://www.kubeflow.org + """ + + def __init__(self) -> None: + super().__init__() + # TODO: remove in 1.7 + if hasattr(self, "is_using_kubeflow") and callable(self.is_using_kubeflow): + rank_zero_deprecation( + f"`{self.__class__.__name__}.is_using_kubeflow` has been deprecated in v1.6 and will be removed in" + f" v1.7. Implement the static method `detect()` instead (do not forget to add the `@staticmethod`" + f" decorator)." + ) + + @property + def creates_processes_externally(self) -> bool: + return True + + @property + def main_address(self) -> str: + return os.environ["MASTER_ADDR"] + + @property + def main_port(self) -> int: + return int(os.environ["MASTER_PORT"]) + + @staticmethod + def detect() -> bool: + """Returns ``True`` if the current process was launched using Kubeflow PyTorchJob.""" + required_env_vars = {"KUBERNETES_PORT", "MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"} + # torchelastic sets these. Make sure we're not in torchelastic + excluded_env_vars = {"GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE"} + env_vars = os.environ.keys() + return required_env_vars.issubset(env_vars) and excluded_env_vars.isdisjoint(env_vars) + + def world_size(self) -> int: + return int(os.environ["WORLD_SIZE"]) + + def set_world_size(self, size: int) -> None: + log.debug("KubeflowEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.") + + def global_rank(self) -> int: + return int(os.environ["RANK"]) + + def set_global_rank(self, rank: int) -> None: + log.debug("KubeflowEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.") + + def local_rank(self) -> int: + return 0 + + def node_rank(self) -> int: + return self.global_rank() diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/lightning_environment.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/lightning_environment.py new file mode 100644 index 0000000000000000000000000000000000000000..5792d7cc16a67933c62349639e077add2c350702 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/lightning_environment.py @@ -0,0 +1,101 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import socket + +from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.utilities.rank_zero import rank_zero_only + + +class LightningEnvironment(ClusterEnvironment): + """The default environment used by Lightning for a single node or free cluster (not managed). + + There are two modes the Lightning environment can operate with: + + 1. The user only launches the main process by :code:`python train.py ...` with no additional environment variables + set. Lightning will spawn new worker processes for distributed training in the current node. + 2. The user launches all processes manually or with utilities like :code:`torch.distributed.launch`. + The appropriate environment variables need to be set, and at minimum :code:`LOCAL_RANK`. + + If the main address and port are not provided, the default environment will choose them + automatically. It is recommended to use this default environment for single-node distributed + training as it provides a convenient way to launch the training script. + """ + + def __init__(self) -> None: + super().__init__() + self._main_port: int = -1 + self._global_rank: int = 0 + self._world_size: int = 1 + + @property + def creates_processes_externally(self) -> bool: + """Returns whether the cluster creates the processes or not. + + If at least :code:`LOCAL_RANK` is available as environment variable, Lightning assumes the user acts as the + process launcher/job scheduler and Lightning will not launch new processes. + """ + return "LOCAL_RANK" in os.environ + + @property + def main_address(self) -> str: + return os.environ.get("MASTER_ADDR", "127.0.0.1") + + @property + def main_port(self) -> int: + if self._main_port == -1: + self._main_port = int(os.environ.get("MASTER_PORT", find_free_network_port())) + return self._main_port + + @staticmethod + def detect() -> bool: + return True + + def world_size(self) -> int: + return self._world_size + + def set_world_size(self, size: int) -> None: + self._world_size = size + + def global_rank(self) -> int: + return self._global_rank + + def set_global_rank(self, rank: int) -> None: + self._global_rank = rank + rank_zero_only.rank = rank + + def local_rank(self) -> int: + return int(os.environ.get("LOCAL_RANK", 0)) + + def node_rank(self) -> int: + group_rank = os.environ.get("GROUP_RANK", 0) + return int(os.environ.get("NODE_RANK", group_rank)) + + def teardown(self) -> None: + if "WORLD_SIZE" in os.environ: + del os.environ["WORLD_SIZE"] + + +def find_free_network_port() -> int: + """Finds a free port on localhost. + + It is useful in single-node training when we don't want to connect to a real main node but have to set the + `MASTER_PORT` environment variable. + """ + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("", 0)) + port = s.getsockname()[1] + s.close() + return port diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/lsf_environment.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/lsf_environment.py new file mode 100644 index 0000000000000000000000000000000000000000..150328dbf706acbf3b40a250e9d4e30392d68d37 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/lsf_environment.py @@ -0,0 +1,190 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import socket +from typing import Dict, List + +from pytorch_lightning import _logger as log +from pytorch_lightning.plugins.environments import ClusterEnvironment +from pytorch_lightning.utilities import rank_zero_deprecation +from pytorch_lightning.utilities.cloud_io import get_filesystem + + +class LSFEnvironment(ClusterEnvironment): + """An environment for running on clusters managed by the LSF resource manager. + + It is expected that any execution using this ClusterEnvironment was executed + using the Job Step Manager i.e. ``jsrun``. + + This plugin expects the following environment variables: + + ``LSB_JOBID`` + The LSF assigned job ID + + ``LSB_DJOB_RANKFILE`` + The OpenMPI compatible rank file for the LSF job + + ``JSM_NAMESPACE_LOCAL_RANK`` + The node local rank for the task. This environment variable is set by ``jsrun`` + + ``JSM_NAMESPACE_SIZE`` + The world size for the task. This environment variable is set by ``jsrun`` + + ``JSM_NAMESPACE_RANK`` + The global rank for the task. This environment variable is set by ``jsrun`` + """ + + def __init__(self) -> None: + super().__init__() + # TODO: remove in 1.7 + if hasattr(self, "is_using_lsf") and callable(self.is_using_lsf): + rank_zero_deprecation( + f"`{self.__class__.__name__}.is_using_lsf` has been deprecated in v1.6 and will be removed in v1.7." + " Implement the static method `detect()` instead (do not forget to add the `@staticmethod` decorator)." + ) + self._main_address = self._get_main_address() + self._main_port = self._get_main_port() + self._node_rank = self._get_node_rank() + self._set_init_progress_group_env_vars() + + def _set_init_progress_group_env_vars(self) -> None: + # set environment variables needed for initializing torch distributed process group + os.environ["MASTER_ADDR"] = str(self._main_address) + log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}") + os.environ["MASTER_PORT"] = str(self._main_port) + log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}") + + @property + def creates_processes_externally(self) -> bool: + """LSF creates subprocesses, i.e., PyTorch Lightning does not need to spawn them.""" + return True + + @property + def main_address(self) -> str: + """The main address is read from an OpenMPI host rank file in the environment variable + ``LSB_DJOB_RANKFILE``.""" + return self._main_address + + @property + def main_port(self) -> int: + """The main port is calculated from the LSF job ID.""" + return self._main_port + + @staticmethod + def detect() -> bool: + """Returns ``True`` if the current process was launched using the ``jsrun`` command.""" + required_env_vars = {"LSB_JOBID", "LSB_DJOB_RANKFILE", "JSM_NAMESPACE_LOCAL_RANK", "JSM_NAMESPACE_SIZE"} + return required_env_vars.issubset(os.environ.keys()) + + def world_size(self) -> int: + """The world size is read from the environment variable ``JSM_NAMESPACE_SIZE``.""" + world_size = os.environ.get("JSM_NAMESPACE_SIZE") + if world_size is None: + raise ValueError( + "Cannot determine world size. Environment variable `JSM_NAMESPACE_SIZE` not found." + "Make sure you run your executable with `jsrun`." + ) + return int(world_size) + + def set_world_size(self, size: int) -> None: + log.debug("LSFEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.") + + def global_rank(self) -> int: + """The world size is read from the environment variable ``JSM_NAMESPACE_RANK``.""" + global_rank = os.environ.get("JSM_NAMESPACE_RANK") + if global_rank is None: + raise ValueError( + "Cannot determine global rank. Environment variable `JSM_NAMESPACE_RANK` not found." + "Make sure you run your executable with `jsrun`." + ) + return int(global_rank) + + def set_global_rank(self, rank: int) -> None: + log.debug("LSFEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.") + + def local_rank(self) -> int: + """The local rank is read from the environment variable `JSM_NAMESPACE_LOCAL_RANK`.""" + local_rank = os.environ.get("JSM_NAMESPACE_LOCAL_RANK") + if local_rank is None: + raise ValueError( + "Cannot determine local rank. Environment variable `JSM_NAMESPACE_LOCAL_RANK` not found." + "Make sure you run your executable with `jsrun`." + ) + return int(local_rank) + + def node_rank(self) -> int: + """The node rank is determined by the position of the current hostname in the OpenMPI host rank file stored + in ``LSB_DJOB_RANKFILE``.""" + return self._node_rank + + def _get_node_rank(self) -> int: + """A helper method for getting the node rank. + + The node rank is determined by the position of the current node in the list of hosts used in the job. This is + calculated by reading all hosts from ``LSB_DJOB_RANKFILE`` and finding this node's hostname in the list. + """ + hosts = self._read_hosts() + count: Dict[str, int] = {} + for host in hosts: + if host not in count: + count[host] = len(count) + return count[socket.gethostname()] + + @staticmethod + def _read_hosts() -> List[str]: + """Read compute hosts that are a part of the compute job. + + LSF uses the Job Step Manager (JSM) to manage job steps. Job steps are executed by the JSM from "launch" nodes. + Each job is assigned a launch node. This launch node will be the first node in the list contained in + ``LSB_DJOB_RANKFILE``. + """ + var = "LSB_DJOB_RANKFILE" + rankfile = os.environ.get(var) + if rankfile is None: + raise ValueError("Did not find the environment variable `LSB_DJOB_RANKFILE`") + if not rankfile: + raise ValueError("The environment variable `LSB_DJOB_RANKFILE` is empty") + + fs = get_filesystem(rankfile) + with fs.open(rankfile, "r") as f: + ret = [line.strip() for line in f] + # remove the launch node (i.e. the first node in LSB_DJOB_RANKFILE) from the list + return ret[1:] + + def _get_main_address(self) -> str: + """A helper for getting the main address. + + The main address is assigned to the first node in the list of nodes used for the job. + """ + hosts = self._read_hosts() + return hosts[0] + + @staticmethod + def _get_main_port() -> int: + """A helper function for accessing the main port. + + Uses the LSF job ID so all ranks can compute the main port. + """ + # check for user-specified main port + if "MASTER_PORT" in os.environ: + log.debug(f"Using externally specified main port: {os.environ['MASTER_PORT']}") + return int(os.environ["MASTER_PORT"]) + if "LSB_JOBID" in os.environ: + port = int(os.environ["LSB_JOBID"]) + # all ports should be in the 10k+ range + port = port % 1000 + 10000 + log.debug(f"calculated LSF main port: {port}") + return port + raise ValueError("Could not find job id in environment variable LSB_JOBID") diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/slurm_environment.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/slurm_environment.py new file mode 100644 index 0000000000000000000000000000000000000000..c17d2d765464e6dacf76ed44653e0ba666e96e92 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/slurm_environment.py @@ -0,0 +1,134 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import re +from typing import Optional + +from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment + +log = logging.getLogger(__name__) + + +class SLURMEnvironment(ClusterEnvironment): + """Cluster environment for training on a cluster managed by SLURM. + + Args: + auto_requeue: Whether automatic job resubmission is enabled or not. How and under which conditions a job gets + rescheduled gets determined by the owner of this plugin. + """ + + def __init__(self, auto_requeue: bool = True) -> None: + super().__init__() + self.auto_requeue = auto_requeue + + @property + def creates_processes_externally(self) -> bool: + return True + + @property + def main_address(self) -> str: + # figure out the root node addr + slurm_nodelist = os.environ.get("SLURM_NODELIST") + if slurm_nodelist: + root_node = slurm_nodelist.split(" ")[0].split(",")[0] + else: + root_node = "127.0.0.1" + + root_node = self.resolve_root_node_address(root_node) + os.environ["MASTER_ADDR"] = root_node + log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}") + return root_node + + @property + def main_port(self) -> int: + # ----------------------- + # SLURM JOB = PORT number + # ----------------------- + # this way every process knows what port to use + job_id = os.environ.get("SLURM_JOB_ID") + if job_id is not None: + # use the last 4 numbers in the job id as the id + default_port = job_id[-4:] + # all ports should be in the 10k+ range + default_port = int(default_port) + 15000 + else: + default_port = 12910 + + # ----------------------- + # PORT NUMBER = MASTER_PORT + # ----------------------- + # in case the user passed it in + if "MASTER_PORT" in os.environ: + default_port = int(os.environ["MASTER_PORT"]) + else: + os.environ["MASTER_PORT"] = str(default_port) + + log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}") + return default_port + + @staticmethod + def detect() -> bool: + """Returns ``True`` if the current process was launched on a SLURM cluster.""" + return "SLURM_NTASKS" in os.environ + + @staticmethod + def job_name() -> Optional[str]: + return os.environ.get("SLURM_JOB_NAME") + + @staticmethod + def job_id() -> Optional[int]: + # in interactive mode, don't make logs use the same job id + in_slurm_interactive_mode = SLURMEnvironment.job_name() == "bash" + if in_slurm_interactive_mode: + return None + + job_id = os.environ.get("SLURM_JOB_ID") + if job_id is None: + return None + try: + return int(job_id) + except ValueError: + return None + + def world_size(self) -> int: + return int(os.environ["SLURM_NTASKS"]) + + def set_world_size(self, size: int) -> None: + log.debug("SLURMEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.") + + def global_rank(self) -> int: + return int(os.environ["SLURM_PROCID"]) + + def set_global_rank(self, rank: int) -> None: + log.debug("SLURMEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.") + + def local_rank(self) -> int: + return int(os.environ["SLURM_LOCALID"]) + + def node_rank(self) -> int: + return int(os.environ["SLURM_NODEID"]) + + def resolve_root_node_address(self, root_node: str) -> str: + if "[" in root_node: + name, numbers = root_node.split("[", maxsplit=1) + number = numbers.split(",", maxsplit=1)[0] + if "-" in number: + number = number.split("-")[0] + + number = re.sub("[^0-9]", "", number) + root_node = name + number + + return root_node diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/torchelastic_environment.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/torchelastic_environment.py new file mode 100644 index 0000000000000000000000000000000000000000..98cad39a0a4714807e87f4c7769ce69321303159 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/torchelastic_environment.py @@ -0,0 +1,88 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os + +import torch.distributed + +from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_9_1 +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn + +log = logging.getLogger(__name__) + + +class TorchElasticEnvironment(ClusterEnvironment): + """Environment for fault-tolerant and elastic training with `torchelastic `_""" + + def __init__(self) -> None: + super().__init__() + # TODO: remove in 1.7 + if hasattr(self, "is_using_torchelastic") and callable(self.is_using_torchelastic): + rank_zero_deprecation( + f"`{self.__class__.__name__}.is_using_torchelastic` has been deprecated in v1.6 and will be removed in" + " v1.7. Implement the static method `detect()` instead (do not forget to add the `@staticmethod`" + " decorator)." + ) + + @property + def creates_processes_externally(self) -> bool: + return True + + @property + def main_address(self) -> str: + if "MASTER_ADDR" not in os.environ: + rank_zero_warn("MASTER_ADDR environment variable is not defined. Set as localhost") + os.environ["MASTER_ADDR"] = "127.0.0.1" + log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}") + return os.environ["MASTER_ADDR"] + + @property + def main_port(self) -> int: + if "MASTER_PORT" not in os.environ: + rank_zero_warn("MASTER_PORT environment variable is not defined. Set as 12910") + os.environ["MASTER_PORT"] = "12910" + log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}") + + return int(os.environ["MASTER_PORT"]) + + @staticmethod + def detect() -> bool: + """Returns ``True`` if the current process was launched using the torchelastic command.""" + if _TORCH_GREATER_EQUAL_1_9_1: + # if not available (for example on MacOS), `is_torchelastic_launched` is not defined + return torch.distributed.is_available() and torch.distributed.is_torchelastic_launched() + required_env_vars = {"RANK", "GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE"} + return required_env_vars.issubset(os.environ.keys()) + + def world_size(self) -> int: + return int(os.environ["WORLD_SIZE"]) + + def set_world_size(self, size: int) -> None: + log.debug("TorchElasticEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.") + + def global_rank(self) -> int: + return int(os.environ["RANK"]) + + def set_global_rank(self, rank: int) -> None: + log.debug( + "TorchElasticEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored." + ) + + def local_rank(self) -> int: + return int(os.environ["LOCAL_RANK"]) + + def node_rank(self) -> int: + return int(os.environ.get("GROUP_RANK", 0)) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/__init__.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..abd196eb2b1e339d84363d6770dfcabc9781d88b --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/__init__.py @@ -0,0 +1,17 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO # noqa: F401 +from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO # noqa: F401 +from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO # noqa: F401 +from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO # noqa: F401 diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/__pycache__/xla_plugin.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/__pycache__/xla_plugin.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f55cb84531f9e6e9cb559e3941180118993daa9 Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/__pycache__/xla_plugin.cpython-38.pyc differ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/checkpoint_plugin.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/checkpoint_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..1425a229963b77acfbe10124b9d382dbd4e1ffe3 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/checkpoint_plugin.py @@ -0,0 +1,62 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional + +from pytorch_lightning.utilities.types import _PATH + + +class CheckpointIO(ABC): + """Interface to save/load checkpoints as they are saved through the ``Strategy``. + + Typically most plugins either use the Torch based IO Plugin; ``TorchCheckpointIO`` but may + require particular handling depending on the plugin. + + In addition, you can pass a custom ``CheckpointIO`` by extending this class and passing it + to the Trainer, i.e ``Trainer(plugins=[MyCustomCheckpointIO()])``. + + .. note:: + + For some plugins, it is not possible to use a custom checkpoint plugin as checkpointing logic is not + modifiable. + """ + + @abstractmethod + def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: + """Save model/training states as a checkpoint file through state-dump and file-write. + + Args: + checkpoint: dict containing model and trainer state + path: write-target path + storage_options: Optional parameters when saving the model/training states. + """ + + @abstractmethod + def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> Dict[str, Any]: + """Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages. + + Args: + path: Path to checkpoint + storage_options: Optional parameters when loading the model/training states. + + Returns: The loaded checkpoint. + """ + + @abstractmethod + def remove_checkpoint(self, path: _PATH) -> None: + """Remove checkpoint file from the filesystem. + + Args: + path: Path to checkpoint + """ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/hpu_plugin.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/hpu_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..c72d1d9fcd1126026609046eca38e6a3453a3257 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/hpu_plugin.py @@ -0,0 +1,52 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Any, Dict, Optional + +import torch + +from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO +from pytorch_lightning.utilities.apply_func import move_data_to_device +from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem +from pytorch_lightning.utilities.types import _PATH + + +class HPUCheckpointIO(TorchCheckpointIO): + """CheckpointIO to save checkpoints for HPU training strategies.""" + + def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: + """Save model/training states as a checkpoint file through state-dump and file-write. + + Args: + checkpoint: dict containing model and trainer state + path: write-target path + storage_options: not used in ``XLACheckpointIO.save_checkpoint`` + + Raises: + TypeError: + If ``storage_options`` arg is passed in + """ + if storage_options is not None: + raise TypeError( + "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg" + f" is not supported for `{self.__class__.__name__}`. Please implement your custom `CheckpointIO`" + " to define how you'd like to use `storage_options`." + ) + fs = get_filesystem(path) + fs.makedirs(os.path.dirname(path), exist_ok=True) + + checkpoint = move_data_to_device(checkpoint, torch.device("cpu")) + # write the checkpoint dictionary to the provided path + atomic_save(checkpoint, path) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/torch_plugin.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/torch_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..be10bf967ab05fede330e0a4b457d181ae09903d --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/torch_plugin.py @@ -0,0 +1,96 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from typing import Any, Callable, Dict, Optional + +import pytorch_lightning as pl +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem +from pytorch_lightning.utilities.cloud_io import load as pl_load +from pytorch_lightning.utilities.rank_zero import rank_zero_warn +from pytorch_lightning.utilities.types import _PATH + +log = logging.getLogger(__name__) + + +class TorchCheckpointIO(CheckpointIO): + """CheckpointIO that utilizes :func:`torch.save` and :func:`torch.load` to save and load checkpoints + respectively, common for most use cases.""" + + def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: + """Save model/training states as a checkpoint file through state-dump and file-write. + + Args: + checkpoint: dict containing model and trainer state + path: write-target path + storage_options: not used in ``TorchCheckpointIO.save_checkpoint`` + + Raises: + TypeError: + If ``storage_options`` arg is passed in + """ + if storage_options is not None: + raise TypeError( + "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg" + f" is not supported for `{self.__class__.__name__}`. Please implement your custom `CheckpointIO`" + " to define how you'd like to use `storage_options`." + ) + fs = get_filesystem(path) + fs.makedirs(os.path.dirname(path), exist_ok=True) + try: + # write the checkpoint dictionary on the file + atomic_save(checkpoint, path) + except AttributeError as err: + # todo (sean): is this try catch necessary still? + # https://github.com/PyTorchLightning/pytorch-lightning/pull/431 + key = pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY + checkpoint.pop(key, None) + rank_zero_warn(f"Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}") + atomic_save(checkpoint, path) + + def load_checkpoint( + self, path: _PATH, map_location: Optional[Callable] = lambda storage, loc: storage + ) -> Dict[str, Any]: + """Loads checkpoint using :func:`torch.load`, with additional handling for ``fsspec`` remote loading of + files. + + Args: + path: Path to checkpoint + map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage + locations. + + Returns: The loaded checkpoint. + + Raises: + FileNotFoundError: If ``path`` is not found by the ``fsspec`` filesystem + """ + + # Try to read the checkpoint at `path`. If not exist, do not restore checkpoint. + fs = get_filesystem(path) + if not fs.exists(path): + raise FileNotFoundError(f"Checkpoint at {path} not found. Aborting training.") + + return pl_load(path, map_location=map_location) + + def remove_checkpoint(self, path: _PATH) -> None: + """Remove checkpoint file from the filesystem. + + Args: + path: Path to checkpoint + """ + fs = get_filesystem(path) + if fs.exists(path): + fs.rm(path, recursive=True) + log.debug(f"Removed checkpoint: {path}") diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/xla_plugin.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/xla_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..3868995eea2c75da2a024e0a3586e884345f2bd8 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/xla_plugin.py @@ -0,0 +1,57 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Any, Dict, Optional + +from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO +from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE +from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.cloud_io import get_filesystem +from pytorch_lightning.utilities.types import _PATH + +if _TPU_AVAILABLE: + import torch_xla.core.xla_model as xm + +if _OMEGACONF_AVAILABLE: + from omegaconf import DictConfig, ListConfig, OmegaConf + + +class XLACheckpointIO(TorchCheckpointIO): + """CheckpointIO that utilizes :func:`xm.save` to save checkpoints for TPU training strategies.""" + + def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: + """Save model/training states as a checkpoint file through state-dump and file-write. + + Args: + checkpoint: dict containing model and trainer state + path: write-target path + storage_options: not used in ``XLACheckpointIO.save_checkpoint`` + + Raises: + TypeError: + If ``storage_options`` arg is passed in + """ + if storage_options is not None: + raise TypeError( + "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg" + f" is not supported for `{self.__class__.__name__}`. Please implement your custom `CheckpointIO`" + " to define how you'd like to use `storage_options`." + ) + fs = get_filesystem(path) + fs.makedirs(os.path.dirname(path), exist_ok=True) + # Todo: TypeError: 'mappingproxy' object does not support item assignment + # Ref: https://github.com/pytorch/xla/issues/2773 + if _OMEGACONF_AVAILABLE: + checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container) + xm.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, path) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__init__.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4bc29c1be18649cf1b4c5e1d334a2033263e0a45 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__init__.py @@ -0,0 +1,27 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.fully_sharded_native_amp import ( # noqa: F401 + FullyShardedNativeMixedPrecisionPlugin, +) +from pytorch_lightning.plugins.precision.hpu import HPUPrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.ipu import IPUPrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.tpu import TPUPrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.tpu_bf16 import TPUBf16PrecisionPlugin # noqa: F401 diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/apex_amp.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/apex_amp.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a36eb5a4a816fe627201fbdf1f915b9a2b6344b Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/apex_amp.cpython-38.pyc differ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/deepspeed.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/deepspeed.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5dbcbe44733b8e7748a3c009a49a86aedff2bf90 Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/deepspeed.cpython-38.pyc differ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/double.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/double.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84339a1e7b54af23c192d213937c6eef152f13f0 Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/double.cpython-38.pyc differ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/fully_sharded_native_amp.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/fully_sharded_native_amp.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..478c9475febad1bcef2fcaa2258864d91d088fba Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/fully_sharded_native_amp.cpython-38.pyc differ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/mixed.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/mixed.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a63041b1f90e51d84c3834fb2a9ab689ec840f6 Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/mixed.cpython-38.pyc differ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/native_amp.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/native_amp.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a0b727158614ec823ce63f4ffe9189823078bf4 Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/native_amp.cpython-38.pyc differ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/sharded_native_amp.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/sharded_native_amp.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34eaaeeaae4a743398a607abe0dd91ae6c66aa9e Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/sharded_native_amp.cpython-38.pyc differ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/tpu_bf16.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/tpu_bf16.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4d1affe8450befce95785b5ee3f172bec8d0104 Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/tpu_bf16.cpython-38.pyc differ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/apex_amp.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/apex_amp.py new file mode 100644 index 0000000000000000000000000000000000000000..fd29efeb9f4fbc0264af34500d3a01246848dc50 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/apex_amp.py @@ -0,0 +1,101 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Dict, Optional, Union + +from torch import Tensor +from torch.nn import Module +from torch.optim import LBFGS, Optimizer + +import pytorch_lightning as pl +from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin +from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import _PARAMETERS + +if _APEX_AVAILABLE: + from apex import amp + + +class ApexMixedPrecisionPlugin(MixedPrecisionPlugin): + """Mixed Precision Plugin based on Nvidia/Apex (https://github.com/NVIDIA/apex)""" + + backend = AMPType.APEX + + def __init__(self, amp_level: str = "O2") -> None: + if not _APEX_AVAILABLE: + raise MisconfigurationException( + "You have asked for Apex AMP but you have not installed it." + " Install `apex` using this guide: https://github.com/NVIDIA/apex" + ) + super().__init__() + self.amp_level = amp_level + self._connected = False + + def main_params(self, optimizer: Optimizer) -> _PARAMETERS: + return amp.master_params(optimizer) + + def dispatch(self, trainer: "pl.Trainer") -> None: + if not self._connected: + strategy = trainer.strategy + _, strategy.optimizers = amp.initialize( + trainer.lightning_module, strategy.optimizers, opt_level=self.amp_level + ) + self._connected = True + return super().dispatch(trainer) + + def backward( + self, + model: "pl.LightningModule", + closure_loss: Tensor, + optimizer: Optional[Optimizer], + *args: Any, + **kwargs: Any, + ) -> None: + """Run before precision plugin executes backward. + + Args: + model: the model to be optimized + closure_loss: the loss value obtained from the closure + optimizer: current optimizer being used. ``None`` if using manual optimization + """ + assert model.trainer is not None + opt = optimizer or model.trainer.optimizers + with amp.scale_loss(closure_loss, opt) as closure_loss: + super().backward(model, closure_loss, optimizer, *args, **kwargs) + + def optimizer_step( + self, + model: Union["pl.LightningModule", Module], + optimizer: Optimizer, + optimizer_idx: int, + closure: Callable[[], Any], + **kwargs: Any, + ) -> Any: + if isinstance(optimizer, LBFGS): + raise MisconfigurationException( + f"apex AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})." + ) + closure_result = closure() + self._after_closure(model, optimizer, optimizer_idx) + skipped_backward = closure_result is None + # in manual optimization, the closure does not return a value + if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward: + return optimizer.step(**kwargs) + return closure_result + + def state_dict(self) -> Dict[str, Any]: + return amp.state_dict() + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + amp.load_state_dict(state_dict) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/deepspeed.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/deepspeed.py new file mode 100644 index 0000000000000000000000000000000000000000..3b70096dd5058f5c4e3a27cacb390033838d6a75 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/deepspeed.py @@ -0,0 +1,102 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional, Union + +from torch import Tensor +from torch.nn import Module +from torch.optim import LBFGS, Optimizer + +import pytorch_lightning as pl +from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin +from pytorch_lightning.utilities import GradClipAlgorithmType +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE +from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.warnings import WarningCache + +if _DEEPSPEED_AVAILABLE: + from deepspeed import DeepSpeedEngine + +warning_cache = WarningCache() + + +class DeepSpeedPrecisionPlugin(PrecisionPlugin): + """Precision plugin for DeepSpeed integration.""" + + def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optional[str] = None) -> None: + super().__init__() + self.precision = precision + self.amp_type = amp_type + self.amp_level = amp_level + + def backward(self, model: "pl.LightningModule", closure_loss: Tensor, *args: Any, **kwargs: Any) -> None: + if is_overridden("backward", model): + warning_cache.warn( + "You have overridden the `LightningModule.backward` hook but it will be ignored since DeepSpeed handles" + " the backward logic internally." + ) + assert model.trainer is not None + deepspeed_engine: DeepSpeedEngine = model.trainer.model + deepspeed_engine.backward(closure_loss, *args, **kwargs) + + def _run_backward(self, tensor: Tensor, model: Optional["DeepSpeedEngine"], *args: Any, **kwargs: Any) -> None: + if model is None: + raise ValueError("Please provide the model as input to `backward`.") + model.backward(tensor, *args, **kwargs) + + def optimizer_step( + self, + model: Union["pl.LightningModule", Module], + optimizer: Optimizer, + optimizer_idx: int, + closure: Callable[[], Any], + **kwargs: Any, + ) -> Any: + if isinstance(optimizer, LBFGS): + raise MisconfigurationException( + f"DeepSpeed and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})." + ) + closure_result = closure() + self._after_closure(model, optimizer, optimizer_idx) + skipped_backward = closure_result is None + # in manual optimization, the closure does not return a value + if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward: + raise MisconfigurationException( + "Skipping backward by returning `None` from your `training_step` is not supported by `DeepSpeed`" + ) + # DeepSpeed handles the optimizer step internally + deepspeed_engine: DeepSpeedEngine + if isinstance(model, pl.LightningModule): + assert model.trainer is not None + deepspeed_engine = model.trainer.model + else: + deepspeed_engine = model + return deepspeed_engine.step(**kwargs) + + def clip_gradients( + self, + optimizer: Optimizer, + clip_val: Union[int, float] = 0.0, + gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, + ) -> None: + """DeepSpeed handles gradient clipping internally.""" + + def _track_grad_norm(self, trainer: "pl.Trainer") -> None: + if trainer.track_grad_norm == -1: + return + # the gradients are not available in the model due to gradient partitioning in zero stage >= 2 + warning_cache.warn( + f"You set `Trainer(track_grad_norm={trainer.track_grad_norm!r})' but this is not supported for DeepSpeed." + " The setting will be ignored." + ) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/double.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/double.py new file mode 100644 index 0000000000000000000000000000000000000000..5e9e8bd43b820588e686c71dd6b3df37b9bcb39c --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/double.py @@ -0,0 +1,102 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from contextlib import contextmanager +from typing import Any, cast, Generator, List, Tuple + +import torch +import torch.nn as nn +from torch.optim import Optimizer + +import pytorch_lightning as pl +from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase +from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin +from pytorch_lightning.utilities.apply_func import apply_to_collection + + +class LightningDoublePrecisionModule(_LightningPrecisionModuleWrapperBase): + """LightningModule wrapper which converts incoming floating point data in ``*_step`` and ``forward`` to double + (``torch.float64``) precision. + + Args: + pl_module: the model to wrap + """ + + @staticmethod + def _to_double_precision(data: torch.Tensor) -> torch.Tensor: + if data.is_floating_point(): + return data.double() + return data + + @staticmethod + def _move_float_tensors_to_double(collection: Any) -> Any: + return apply_to_collection(collection, torch.Tensor, LightningDoublePrecisionModule._to_double_precision) + + def training_step(self, *args: Any, **kwargs: Any) -> Any: + return self.module.training_step( + *LightningDoublePrecisionModule._move_float_tensors_to_double(args), + **LightningDoublePrecisionModule._move_float_tensors_to_double(kwargs), + ) + + def validation_step(self, *args: Any, **kwargs: Any) -> Any: + return self.module.validation_step( + *LightningDoublePrecisionModule._move_float_tensors_to_double(args), + **LightningDoublePrecisionModule._move_float_tensors_to_double(kwargs), + ) + + def test_step(self, *args: Any, **kwargs: Any) -> Any: + return self.module.test_step( + *LightningDoublePrecisionModule._move_float_tensors_to_double(args), + **LightningDoublePrecisionModule._move_float_tensors_to_double(kwargs), + ) + + def predict_step(self, *args: Any, **kwargs: Any) -> Any: + return self.module.predict_step( + *LightningDoublePrecisionModule._move_float_tensors_to_double(args), + **LightningDoublePrecisionModule._move_float_tensors_to_double(kwargs), + ) + + def forward(self, *args: Any, **kwargs: Any) -> Any: + return self.module( + *LightningDoublePrecisionModule._move_float_tensors_to_double(args), + **LightningDoublePrecisionModule._move_float_tensors_to_double(kwargs), + ) + + +class DoublePrecisionPlugin(PrecisionPlugin): + """Plugin for training with double (``torch.float64``) precision.""" + + precision: int = 64 + + def connect( + self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any] + ) -> Tuple[nn.Module, List["Optimizer"], List[Any]]: + """Converts the model to double precision and wraps it in a ``LightningDoublePrecisionModule`` to convert + incoming floating point data to double (``torch.float64``) precision. + + Does not alter `optimizers` or `lr_schedulers`. + """ + model = cast(pl.LightningModule, model.double()) + model = LightningDoublePrecisionModule(model) + + return super().connect(model, optimizers, lr_schedulers) + + @contextmanager + def forward_context(self) -> Generator[None, None, None]: + """A context manager to change the default tensor type. + + See: :meth:`torch.set_default_tensor_type` + """ + torch.set_default_tensor_type(torch.DoubleTensor) + yield + torch.set_default_tensor_type(torch.FloatTensor) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py new file mode 100644 index 0000000000000000000000000000000000000000..870e658bfc9c327296446b68174e74376649a36b --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py @@ -0,0 +1,31 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any + +from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin): + """Native AMP for Fully Sharded Training.""" + + def clip_grad_by_norm(self, *_: Any, **__: Any) -> None: + # see https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html + # section `Gradient Clipping`, using `torch.nn.utils.clip_grad_norm_` is incorrect + # for FSDP module. To overcome this, needs to call sharded_module.clip_grad_norm(clip_val) + # however we rely on LightningModule's configure_sharded_model to wrap FSDP, it would be hard to + # trace back the root FSDP. Now we only support clip by value. + raise MisconfigurationException( + f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`" + ) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/hpu.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/hpu.py new file mode 100644 index 0000000000000000000000000000000000000000..4f8db7dabb46041926645b202b919674f6f3bb62 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/hpu.py @@ -0,0 +1,57 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Union + +from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin +from pytorch_lightning.utilities.enums import PrecisionType +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _HPU_AVAILABLE + +if _HPU_AVAILABLE: + from habana_frameworks.torch.hpex import hmp + + +class HPUPrecisionPlugin(PrecisionPlugin): + """Plugin that enables bfloat/half support on HPUs. + + Args: + precision: The precision to use. + opt_level: Choose optimization level for hmp. + bf16_file_path: Path to bf16 ops list in hmp O1 mode. + fp32_file_path: Path to fp32 ops list in hmp O1 mode. + verbose: Enable verbose mode for hmp. + """ + + def __init__( + self, + precision: Union[str, int], + opt_level: str = "O2", + bf16_file_path: Optional[str] = None, + fp32_file_path: Optional[str] = None, + verbose: bool = False, + ) -> None: + if not _HPU_AVAILABLE: + raise MisconfigurationException("HPU precision plugin requires HPU devices.") + supported_precision_values = (16, 32, "bf16") + if precision not in supported_precision_values: + raise ValueError( + f"`Trainer(accelerator='hpu', precision={precision!r})` is not supported." + f" `precision` must be one of: {supported_precision_values}." + ) + super().__init__() + self.precision = precision + if precision in (PrecisionType.HALF, PrecisionType.BFLOAT): + hmp.convert( + opt_level=opt_level, bf16_file_path=bf16_file_path, fp32_file_path=fp32_file_path, isVerbose=verbose + ) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/ipu.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/ipu.py new file mode 100644 index 0000000000000000000000000000000000000000..9df0edb53913b1be47daaab991e25be21d0f21c9 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/ipu.py @@ -0,0 +1,88 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Union + +from torch.nn import Module +from torch.optim import LBFGS, Optimizer + +import pytorch_lightning as pl +from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin +from pytorch_lightning.utilities import GradClipAlgorithmType +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.warnings import WarningCache + +warning_cache = WarningCache() + + +class IPUPrecisionPlugin(PrecisionPlugin): + """Precision plugin for IPU integration. + + Raises: + ValueError: + If the precision is neither 16 nor 32. + """ + + def __init__(self, precision: int) -> None: + supported_precision_values = (16, 32) + if precision not in supported_precision_values: + raise ValueError( + f"`Trainer(accelerator='ipu', precision={precision!r})` is not supported." + f" `precision` must be one of: {supported_precision_values}." + ) + super().__init__() + self.precision = precision + + def backward(self, model: "pl.LightningModule", *args: Any, **kwargs: Any) -> None: + if is_overridden("backward", model): + warning_cache.warn( + "You have overridden the `LightningModule.backward` hook but it will be ignored since IPUs handle" + " the backward logic internally." + ) + + def optimizer_step( + self, + model: Union["pl.LightningModule", Module], + optimizer: Optimizer, + optimizer_idx: int, + closure: Callable[[], Any], + **kwargs: Any, + ) -> Any: + """IPUs handle the optimizer step internally.""" + if isinstance(optimizer, LBFGS): + raise MisconfigurationException( + f"IPUs and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})." + ) + closure_result = closure() + self._after_closure(model, optimizer, optimizer_idx) + skipped_backward = closure_result is None + # in manual optimization, the closure does not return a value + if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward: + # we lack coverage here and IPUs are (currently) limited - something to explore if there's demand + raise MisconfigurationException( + "Skipping backward by returning `None` from your `training_step` is not implemented for IPUs." + " Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`" + " requesting this feature." + ) + return closure_result + + def clip_gradients( + self, + optimizer: Optimizer, + clip_val: Union[int, float] = 0.0, + gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, + ) -> None: + if clip_val <= 0: + return + raise MisconfigurationException("IPUs currently do not support clipping gradients.") diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/mixed.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/mixed.py new file mode 100644 index 0000000000000000000000000000000000000000..52c8b96d42882497ed34b347e61a7cc6c4258ef2 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/mixed.py @@ -0,0 +1,26 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING, Union + +from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin + +if TYPE_CHECKING: + from pytorch_lightning.utilities import AMPType + + +class MixedPrecisionPlugin(PrecisionPlugin): + """Base Class for mixed precision.""" + + backend: "AMPType" + precision: Union[str, int] = "mixed" diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/native_amp.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/native_amp.py new file mode 100644 index 0000000000000000000000000000000000000000..fa749af1d4a088f6ca5db8fc192f1d7170c8d4cc --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/native_amp.py @@ -0,0 +1,118 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from contextlib import contextmanager +from typing import Any, Callable, Dict, Generator, Optional, Union + +import torch +from torch import Tensor +from torch.nn import Module +from torch.optim import LBFGS, Optimizer + +import pytorch_lightning as pl +from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin +from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, AMPType +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +if _TORCH_GREATER_EQUAL_1_10: + from torch import autocast as new_autocast +else: + from torch.cuda.amp import autocast as old_autocast + + +class NativeMixedPrecisionPlugin(MixedPrecisionPlugin): + """Plugin for Native Mixed Precision (AMP) training with ``torch.autocast``. + + Args: + precision: Whether to use ``torch.float16`` (``16``) or ``torch.bfloat16`` (``'bf16'``). + device: The device for ``torch.autocast``. + scaler: An optional :class:`torch.cuda.amp.GradScaler` to use. + """ + + backend = AMPType.NATIVE + + def __init__( + self, precision: Union[str, int], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None + ) -> None: + super().__init__() + if precision == "bf16" and not _TORCH_GREATER_EQUAL_1_10: + raise MisconfigurationException( + "To use bfloat16 with native amp you must install torch greater or equal to 1.10." + ) + if scaler is None and precision == 16: + scaler = torch.cuda.amp.GradScaler() + if scaler is not None and precision == "bf16": + raise MisconfigurationException(f"`precision='bf16'` does not use a scaler, found {scaler}.") + self.precision = precision + self.device = device + self.scaler = scaler + + def pre_backward(self, model: "pl.LightningModule", closure_loss: torch.Tensor) -> torch.Tensor: + if self.scaler is not None: + closure_loss = self.scaler.scale(closure_loss) + return super().pre_backward(model, closure_loss) + + def _run_backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None: + if self.scaler is not None: + tensor = self.scaler.scale(tensor) + super()._run_backward(tensor, model, *args, **kwargs) + + def optimizer_step( + self, + model: Union["pl.LightningModule", Module], + optimizer: Optimizer, + optimizer_idx: int, + closure: Callable[[], Any], + **kwargs: Any, + ) -> Any: + if self.scaler is None: + # skip scaler logic, as bfloat16 does not require scaler + return super().optimizer_step(model, optimizer, optimizer_idx, closure, **kwargs) + if isinstance(optimizer, LBFGS): + raise MisconfigurationException( + f"Native AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})." + ) + closure_result = closure() + # `unscale` after the closure is executed but before the `on_before_optimizer_step` hook. + self.scaler.unscale_(optimizer) + self._after_closure(model, optimizer, optimizer_idx) + skipped_backward = closure_result is None + # in manual optimization, the closure does not return a value + if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward: + # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found + step_output = self.scaler.step(optimizer, **kwargs) + self.scaler.update() + return step_output + return closure_result + + def autocast_context_manager(self) -> Union["old_autocast", "new_autocast"]: + if _TORCH_GREATER_EQUAL_1_10: + # the dtype could be automatically inferred but we need to manually set it due to a bug upstream + # https://github.com/pytorch/pytorch/issues/67233 + return new_autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16" else torch.half) + return old_autocast() + + @contextmanager + def forward_context(self) -> Generator[None, None, None]: + """Enable autocast context.""" + with self.autocast_context_manager(): + yield + + def state_dict(self) -> Dict[str, Any]: + if self.scaler is not None: + return self.scaler.state_dict() + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + if self.scaler is not None: + self.scaler.load_state_dict(state_dict) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..bdd63bba1785474a7377066880b8a2e896a4e3ce --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py @@ -0,0 +1,279 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib +from functools import partial +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union + +import torch +from torch import Tensor +from torch.nn import Module +from torch.optim import Optimizer + +import pytorch_lightning as pl +from pytorch_lightning.core.hooks import CheckpointHooks +from pytorch_lightning.utilities import grad_norm, GradClipAlgorithmType +from pytorch_lightning.utilities.types import _PARAMETERS + + +class PrecisionPlugin(CheckpointHooks): + """Base class for all plugins handling the precision-specific parts of the training. + + The class attribute precision must be overwritten in child classes. The default value reflects fp32 training. + """ + + precision: Union[str, int] = 32 + + def main_params(self, optimizer: Optimizer) -> _PARAMETERS: + """The main params of the model. + + Returns the plain model params here. Maybe different in other precision plugins. + """ + for group in optimizer.param_groups: + yield from group["params"] + + def connect( + self, model: Module, optimizers: List[Optimizer], lr_schedulers: List[Any] + ) -> Tuple[Module, List[Optimizer], List[Any]]: + """Connects this plugin to the accelerator and the training process.""" + return model, optimizers, lr_schedulers + + def pre_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Tensor: + """Run before precision plugin executes backward. + + Args: + model: the model to be optimized + closure_loss: the loss value obtained from the closure + """ + assert model.trainer is not None + model.trainer._call_callback_hooks("on_before_backward", closure_loss) + model.trainer._call_lightning_module_hook("on_before_backward", closure_loss) + return closure_loss + + def backward( + self, + model: "pl.LightningModule", + closure_loss: Tensor, + optimizer: Optional[Optimizer], + *args: Any, + **kwargs: Any, + ) -> None: + """Performs the actual backpropagation. + + Args: + model: the model to be optimized + closure_loss: the loss value obtained from the closure + optimizer: current optimizer being used. ``None`` if using manual optimization + """ + # do backward pass + if model is not None and isinstance(model, pl.LightningModule): + model.backward(closure_loss, optimizer, *args, **kwargs) + else: + self._run_backward(closure_loss, *args, **kwargs) + + def post_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Tensor: + """Run after precision plugin executes backward. + + Args: + model: the model to be optimized + closure_loss: the loss value obtained from the closure + """ + # once backward has been applied, release graph + closure_loss = closure_loss.detach() + assert model.trainer is not None + model.trainer._call_callback_hooks("on_after_backward") + model.trainer._call_lightning_module_hook("on_after_backward") + return closure_loss + + def _run_backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None: + """Lightning-independent backward logic. + + Currently only used by Lightning Lite. Subject to further refactors. + """ + tensor.backward(*args, **kwargs) + + def _after_closure( + self, model: Union["pl.LightningModule", Module], optimizer: Optimizer, optimizer_idx: int + ) -> None: + """Utility to share some code after the closure has been run.""" + if not isinstance(model, pl.LightningModule): + # none of this applies to Lite + return + trainer = model.trainer + assert trainer is not None + trainer._call_callback_hooks("on_before_optimizer_step", optimizer, optimizer_idx) + trainer._call_lightning_module_hook("on_before_optimizer_step", optimizer, optimizer_idx) + # TODO: this is done for the entire model but should be changed to per-optimizer + if optimizer_idx == 0: + self._track_grad_norm(trainer) + self._clip_gradients( + model, + optimizer, + optimizer_idx, + trainer.gradient_clip_val, + gradient_clip_algorithm=trainer.gradient_clip_algorithm, + ) + + def _wrap_closure( + self, + model: "pl.LightningModule", + optimizer: Optimizer, + optimizer_idx: int, + closure: Callable[[], Any], + ) -> Any: + """This double-closure allows makes sure the ``closure`` is executed before the + ``on_before_optimizer_step`` hook is called. + + The closure (generally) runs ``backward`` so this allows inspecting gradients in this hook. This structure is + consistent with the ``PrecisionPlugin`` subclasses that cannot pass ``optimizer.step(closure)`` directly. + """ + closure_result = closure() + self._after_closure(model, optimizer, optimizer_idx) + return closure_result + + def optimizer_step( + self, + model: Union["pl.LightningModule", Module], + optimizer: Optimizer, + optimizer_idx: int, + closure: Callable[[], Any], + **kwargs: Any, + ) -> Any: + """Hook to run the optimizer step.""" + if isinstance(model, pl.LightningModule): + closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure) + return optimizer.step(closure=closure, **kwargs) + + def _track_grad_norm(self, trainer: "pl.Trainer") -> None: + if trainer.track_grad_norm == -1: + return + + kwargs = {} + if len(trainer.loggers) == 1: + kwargs["group_separator"] = trainer.loggers[0].group_separator + + grad_norm_dict = grad_norm(trainer.lightning_module, trainer.track_grad_norm, **kwargs) + if grad_norm_dict: + prev_fx = trainer.lightning_module._current_fx_name + trainer.lightning_module._current_fx_name = "on_before_optimizer_step" + trainer.lightning_module.log_grad_norm(grad_norm_dict) + trainer.lightning_module._current_fx_name = prev_fx + + def _clip_gradients( + self, + model: Union["pl.LightningModule", Module], + optimizer: Optimizer, + optimizer_idx: int, + clip_val: Optional[Union[int, float]] = None, + gradient_clip_algorithm: Optional[GradClipAlgorithmType] = None, + ) -> None: + if not isinstance(model, pl.LightningModule) or not model.automatic_optimization: + # the configuration validator disallows clipping on manual + return + model.configure_gradient_clipping( + optimizer, + optimizer_idx, + gradient_clip_val=clip_val, + gradient_clip_algorithm=gradient_clip_algorithm, + ) + + def clip_gradients( + self, + optimizer: Optimizer, + clip_val: Union[int, float] = 0.0, + gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, + ) -> None: + """Clips the gradients.""" + if clip_val <= 0: + return + if gradient_clip_algorithm == GradClipAlgorithmType.VALUE: + self.clip_grad_by_value(optimizer, clip_val) + elif gradient_clip_algorithm == GradClipAlgorithmType.NORM: + self.clip_grad_by_norm(optimizer, clip_val) + + def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None: + """Clip gradients by value.""" + parameters = self.main_params(optimizer) + torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val) + + def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None: + """Clip gradients by norm.""" + parameters = self.main_params(optimizer) + torch.nn.utils.clip_grad_norm_(parameters, clip_val) + + def dispatch(self, trainer: "pl.Trainer") -> None: + """Hook to do something when ``Strategy.dispatch()`` gets called.""" + + @contextlib.contextmanager + def forward_context(self) -> Generator[None, None, None]: + """A contextmanager for managing model forward/training_step/evaluation_step/predict_step.""" + yield + + @contextlib.contextmanager + def train_step_context(self) -> Generator[None, None, None]: + """A contextmanager for the training step.""" + with self.forward_context(): + yield + + @contextlib.contextmanager + def val_step_context(self) -> Generator[None, None, None]: + """A contextmanager for the validation step.""" + with self.forward_context(): + yield + + @contextlib.contextmanager + def test_step_context(self) -> Generator[None, None, None]: + """A contextmanager for the test step.""" + with self.forward_context(): + yield + + @contextlib.contextmanager + def predict_step_context(self) -> Generator[None, None, None]: + """A contextmanager for the predict step.""" + with self.forward_context(): + yield + + def teardown(self) -> None: + """This method is called to teardown the training process. + + It is the right place to release memory and free other resources. + """ + + def state_dict(self) -> Dict[str, Any]: + """Called when saving a checkpoint, implement to generate precision plugin state_dict. + + Returns: + A dictionary containing precision plugin state. + """ + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """Called when loading a checkpoint, implement to reload precision plugin state given precision plugin + state_dict. + + Args: + state_dict: the precision plugin state returned by ``state_dict``. + """ + pass + + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + """``PrecisionPlugin.on_save_checkpoint`` was deprecated in v1.6 and will be removed in v1.8. + + Use ``state_dict`` instead. + """ + + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + """``PrecisionPlugin.on_load_checkpoint`` was deprecated in v1.6 and will be removed in v1.8. + + Use ``load_state_dict`` instead. + """ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/sharded_native_amp.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/sharded_native_amp.py new file mode 100644 index 0000000000000000000000000000000000000000..e40aea8ecf4eba1bafe23239c7fe4dc134d32679 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/sharded_native_amp.py @@ -0,0 +1,41 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Union + +import torch + +from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin +from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +if _FAIRSCALE_AVAILABLE: + from fairscale.optim import OSS + from fairscale.optim.grad_scaler import ShardedGradScaler + + +class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin): + """Native AMP for Sharded Training.""" + + def __init__( + self, precision: Union[str, int], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None + ) -> None: + if not _FAIRSCALE_AVAILABLE: + raise MisconfigurationException( + "You have asked for sharded AMP but you have not installed it." + " Install `fairscale` using this guide: https://https://github.com/facebookresearch/fairscale" + ) + super().__init__(precision, device, scaler=ShardedGradScaler() if scaler is None and precision == 16 else None) + + def clip_grad_by_norm(self, optimizer: "OSS", clip_val: Union[int, float]) -> None: + optimizer.clip_grad_norm(clip_val) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/tpu.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/tpu.py new file mode 100644 index 0000000000000000000000000000000000000000..1afd34264c60cc297eb0c1e721923a52e5ae1580 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/tpu.py @@ -0,0 +1,52 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial +from typing import Any, Callable, Union + +from torch.nn import Module +from torch.optim import Optimizer + +import pytorch_lightning as pl +from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin +from pytorch_lightning.utilities import _XLA_AVAILABLE +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +if _XLA_AVAILABLE: + import torch_xla.core.xla_model as xm + + +class TPUPrecisionPlugin(PrecisionPlugin): + """Precision plugin for TPU integration.""" + + def optimizer_step( + self, + model: Union["pl.LightningModule", Module], + optimizer: Optimizer, + optimizer_idx: int, + closure: Callable[[], Any], + **kwargs: Any + ) -> Any: + if isinstance(model, pl.LightningModule): + closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure) + closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs}) + skipped_backward = closure_result is None + # in manual optimization, the closure does not return a value + if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward: + # we lack coverage here so disable this - something to explore if there's demand + raise MisconfigurationException( + "Skipping backward by returning `None` from your `training_step` is not implemented for TPUs." + " Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`" + " requesting this feature." + ) + return closure_result diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/tpu_bf16.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/tpu_bf16.py new file mode 100644 index 0000000000000000000000000000000000000000..94254313b85beabfaaa1a4050ae880fc5dc84f13 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/tpu_bf16.py @@ -0,0 +1,35 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Any, List, Tuple + +import torch.nn as nn +from torch.optim import Optimizer + +from pytorch_lightning.plugins.precision import TPUPrecisionPlugin + + +class TPUBf16PrecisionPlugin(TPUPrecisionPlugin): + """Plugin that enables bfloats on TPUs.""" + + precision: str = "bf16" + + def connect( + self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any] + ) -> Tuple[nn.Module, List[Optimizer], List[Any]]: + os.environ["XLA_USE_BF16"] = "1" + return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers) + + def teardown(self) -> None: + os.environ.pop("XLA_USE_BF16", None) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/__pycache__/utils.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10ab5ac0030ab732d75401bf4d8feee24ecac9d0 Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/__pycache__/utils.cpython-38.pyc differ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ipu.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ipu.py new file mode 100644 index 0000000000000000000000000000000000000000..3959e84a943d6c7b9274cc0515ece204ca34e96a --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ipu.py @@ -0,0 +1,24 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.strategies.ipu import IPUStrategy +from pytorch_lightning.utilities import rank_zero_deprecation + + +class IPUPlugin(IPUStrategy): + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + rank_zero_deprecation( + "The `pl.plugins.training_type.ipu.IPUPlugin` is deprecated in v1.6 and will be removed in v1.8." + " Use `pl.strategies.ipu.IPUStrategy` instead." + ) + super().__init__(*args, **kwargs) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/parallel.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..5822e17c61f23b17202c5fdd87b07b736b1ce0f5 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/parallel.py @@ -0,0 +1,26 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC + +from pytorch_lightning.strategies import ParallelStrategy +from pytorch_lightning.utilities import rank_zero_deprecation + + +class ParallelPlugin(ParallelStrategy, ABC): + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + rank_zero_deprecation( + "The `pl.plugins.training_type.parallel.ParallelPlugin` is deprecated in v1.6 and will be removed in v1.8." + " Use `pl.strategies.parallel.ParallelStrategy` instead." + ) + super().__init__(*args, **kwargs) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/sharded.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/sharded.py new file mode 100644 index 0000000000000000000000000000000000000000..d66442d565ad98fdefa3b6910074f39b00f44d4d --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/sharded.py @@ -0,0 +1,24 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.strategies import DDPShardedStrategy +from pytorch_lightning.utilities import rank_zero_deprecation + + +class DDPShardedPlugin(DDPShardedStrategy): + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + rank_zero_deprecation( + "The `pl.plugins.training_type.sharded.DDPShardedPlugin` is deprecated in v1.6 and will be removed in v1.8." + " Use `pl.strategies.sharded.DDPShardedStrategy` instead." + ) + super().__init__(*args, **kwargs) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/strategies/hpu_parallel.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/strategies/hpu_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..9bbae8008d6e99c1347a71e00a999f3ce5f2ae20 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/strategies/hpu_parallel.py @@ -0,0 +1,144 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from typing import Any, Callable, Dict, List, Optional + +import torch.distributed + +import pytorch_lightning as pl +from pytorch_lightning.overrides import LightningDistributedModule +from pytorch_lightning.overrides.torch_distributed import broadcast_object_list +from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin +from pytorch_lightning.strategies.ddp import DDPStrategy +from pytorch_lightning.utilities.distributed import group as _group +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _HPU_AVAILABLE, _TORCH_LESSER_EQUAL_1_10_2 + +if _HPU_AVAILABLE: + import habana_frameworks.torch.core.hccl # noqa: F401 + from habana_frameworks.torch.utils.library_loader import load_habana_module + +log = logging.getLogger(__name__) + + +class HPUParallelStrategy(DDPStrategy): + """Strategy for distributed training on multiple HPU devices.""" + + strategy_name = "hpu_parallel" + + def __init__( + self, + accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None, + parallel_devices: Optional[List[torch.device]] = None, + cluster_environment: Optional[ClusterEnvironment] = None, + checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, + ddp_comm_state: Optional[object] = None, + ddp_comm_hook: Optional[Callable] = None, + ddp_comm_wrapper: Optional[Callable] = None, + model_averaging_period: Optional[int] = None, + process_group_backend: Optional[str] = "hccl", + **kwargs: Any, + ) -> None: + + if not _HPU_AVAILABLE: + raise MisconfigurationException("`HPUParallelStrategy` requires HPU devices to run") + + super().__init__( + accelerator=accelerator, + parallel_devices=parallel_devices, + cluster_environment=cluster_environment, + checkpoint_io=checkpoint_io or HPUCheckpointIO(), + precision_plugin=precision_plugin, + ddp_comm_state=ddp_comm_state, + ddp_comm_hook=ddp_comm_hook, + ddp_comm_wrapper=ddp_comm_wrapper, + model_averaging_period=model_averaging_period, + process_group_backend=process_group_backend, + **kwargs, + ) + + def setup_environment(self) -> None: + # This function is used to load Habana libraries required for PyTorch + # to register HPU as one of the available devices. + load_habana_module() + + os.environ["ID"] = str(self.local_rank) + if self._process_group_backend == "hccl": + # this env is used in overrides to check the backend initiated + os.environ["HCCL_DISTRIBUTED_BACKEND"] = str(1) + super().setup_environment() + + def determine_ddp_device_ids(self) -> None: + return None + + def _pre_configure_ddp(self) -> None: + # if unset, default `find_unused_parameters` `True` + # Many models require setting this parameter to True, as there are corner cases + # when not all parameter backward hooks are fired by the autograd engine even if require_grad is set to True. + # This flag does come with a performance hit, so it is suggested to disable in cases where it is possible. + self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True) + + self._static_graph = False + static_graph = self._ddp_kwargs.get("static_graph") + if static_graph: + # when _set_static_graph() is called find_unused_parameters does not have any significance. + # Resetting the value of find_unused_parameters to False which is the default value to DDP + self._ddp_kwargs["find_unused_parameters"] = False + self._static_graph = True + if static_graph is not None: + # DDP does not accept static_graph as a parameter, hence removing it from the list + del self._ddp_kwargs["static_graph"] + + def configure_ddp(self) -> None: + # DDP does not accept static graph as param with torch < 1.11 + if _TORCH_LESSER_EQUAL_1_10_2: + log.detail(f"{self.__class__.__name__}: configuring DistributedDataParallel") + self._pre_configure_ddp() + self.model = self._setup_model(LightningDistributedModule(self.model)) # type: ignore + if self.root_device.type == "hpu" and self._static_graph: + self._model._set_static_graph() # type: ignore + self._register_ddp_hooks() + else: + super().configure_ddp() + + def broadcast(self, obj: object, src: int = 0) -> object: # type: ignore + obj = [obj] + if self.global_rank != src: + obj = [None] + + broadcast_object_list(obj, src, group=_group.WORLD) + return obj[0] + + def teardown(self) -> None: + log.detail(f"{self.__class__.__name__}: tearing down strategy.") + super().teardown() + + log.detail(f"{self.__class__.__name__}: moving model to CPU") + self.lightning_module.cpu() # type: ignore + # Was set to local rank + os.environ.pop("ID", None) + os.environ.pop("HCCL_DISTRIBUTED_BACKEND", None) + + @classmethod + def register_strategies(cls, strategy_registry: Dict) -> None: + strategy_registry.register( + cls.strategy_name, + cls, + description=f"{cls.__class__.__name__}", + ) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/strategies/single_device.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/strategies/single_device.py new file mode 100644 index 0000000000000000000000000000000000000000..da80bad32ad1375dbe83de1ffe30201683174e71 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/strategies/single_device.py @@ -0,0 +1,98 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import Any + +import torch + +import pytorch_lightning as pl +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin +from pytorch_lightning.strategies.strategy import Strategy +from pytorch_lightning.utilities.types import _DEVICE + + +class SingleDeviceStrategy(Strategy): + """Strategy that handles communication on a single device.""" + + strategy_name = "single_device" + + def __init__( + self, + device: _DEVICE = "cpu", + accelerator: pl.accelerators.accelerator.Accelerator | None = None, + checkpoint_io: CheckpointIO | None = None, + precision_plugin: PrecisionPlugin | None = None, + ): + super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin) + self._root_device = torch.device(device) + self.global_rank = 0 + self.local_rank = 0 + self.world_size = 1 + + def reduce(self, tensor: Any | torch.Tensor, *args: Any, **kwargs: Any) -> Any | torch.Tensor: + """Reduces a tensor from several distributed processes to one aggregated tensor. As this plugin only + operates with a single device, the reduction is simply the identity. + + Args: + tensor: the tensor to sync and reduce + *args: ignored + **kwargs: ignored + + Return: + the unmodified input as reduction is not needed for single process operation + """ + return tensor + + def all_gather(self, tensor: torch.Tensor, group: Any | None = None, sync_grads: bool = False) -> torch.Tensor: + """Perform a all_gather on all processes.""" + return tensor + + @property + def root_device(self) -> torch.device: + return self._root_device + + def model_to_device(self) -> None: + self.model.to(self.root_device) + + def setup(self, trainer: pl.Trainer) -> None: + self.model_to_device() + super().setup(trainer) + + @property + def is_global_zero(self) -> bool: + return True + + def barrier(self, *args, **kwargs) -> None: + pass + + def broadcast(self, obj: object, src: int = 0) -> object: + return obj + + @classmethod + def register_strategies(cls, strategy_registry: dict) -> None: + strategy_registry.register( + cls.strategy_name, + cls, + description=f"{cls.__class__.__name__}", + ) + + def teardown(self) -> None: + super().teardown() + if self.root_device.type == "cuda": + # GPU teardown + self.lightning_module.cpu() + # clean up memory + torch.cuda.empty_cache() diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..87c5c171d0eced36e9263883518ea0e9f9c2f473 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py @@ -0,0 +1,507 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, TypeVar, Union + +import torch +from torch import Tensor +from torch.nn import Module +from torch.optim import Optimizer +from torch.utils.data import DataLoader + +import pytorch_lightning as pl +from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers, LightningOptimizer +from pytorch_lightning.overrides.base import unwrap_lightning_module +from pytorch_lightning.plugins import TorchCheckpointIO +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin +from pytorch_lightning.strategies.launchers.base import _Launcher +from pytorch_lightning.trainer.states import TrainerFn +from pytorch_lightning.utilities import rank_zero_deprecation +from pytorch_lightning.utilities.apply_func import move_data_to_device +from pytorch_lightning.utilities.distributed import ReduceOp +from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.optimizer import optimizer_to_device, optimizers_to_device +from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, STEP_OUTPUT + +TBroadcast = TypeVar("TBroadcast") + + +class Strategy(ABC): + """Base class for all strategies that change the behaviour of the training, validation and test- loop.""" + + def __init__( + self, + accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None, + checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, + ) -> None: + self.accelerator = accelerator + self._launcher: Optional[_Launcher] = None + self._model: Optional[Module] = None + self.checkpoint_io = checkpoint_io + self.precision_plugin = precision_plugin + self._optimizers: List[Optimizer] = [] + self._lightning_optimizers: Dict[int, LightningOptimizer] = {} + self.lr_scheduler_configs: List[LRSchedulerConfig] = [] + self.optimizer_frequencies: List[int] = [] + if is_overridden("post_dispatch", self, parent=Strategy): + rank_zero_deprecation( + f"`{self.__class__.__name__}.post_dispatch()` has been deprecated in v1.6 and will be removed in v1.7." + f" Move your implementation to `{self.__class__.__name__}.teardown()` instead." + ) + + @property + def launcher(self) -> Optional[_Launcher]: + return self._launcher + + @property + def accelerator(self) -> "pl.accelerators.accelerator.Accelerator": + return self._accelerator + + @accelerator.setter + def accelerator(self, accelerator: "pl.accelerators.accelerator.Accelerator") -> None: + self._accelerator = accelerator + + @property + def checkpoint_io(self) -> CheckpointIO: + return self._checkpoint_io if self._checkpoint_io is not None else TorchCheckpointIO() + + @checkpoint_io.setter + def checkpoint_io(self, io: Optional[CheckpointIO]) -> None: + self._checkpoint_io = io + + @property + def precision_plugin(self) -> PrecisionPlugin: + return self._precision_plugin if self._precision_plugin is not None else PrecisionPlugin() + + @precision_plugin.setter + def precision_plugin(self, precision_plugin: Optional[PrecisionPlugin]) -> None: + self._precision_plugin = precision_plugin + + @property + def optimizers(self) -> List[Optimizer]: + return self._optimizers + + @optimizers.setter + def optimizers(self, optimizers: List[Optimizer]) -> None: + self._optimizers = optimizers + self._lightning_optimizers = { + idx: LightningOptimizer._to_lightning_optimizer(opt, self, idx) for idx, opt in enumerate(self.optimizers) + } + + def connect(self, model: Module) -> None: + """Called by the accelerator to connect the accelerator and the model with this plugin.""" + self.model = model + + def _configure_launcher(self): + """Attach the launcher based on Strategy.""" + + def setup_environment(self) -> None: + """Setup any processes or distributed connections. + + This is called before the LightningModule/DataModule setup hook which allows the user to access the accelerator + environment before setup is complete. + """ + self.accelerator.setup_environment(self.root_device) + + def setup_optimizers(self, trainer: "pl.Trainer") -> None: + """Creates optimizers and schedulers. + + Args: + trainer: the Trainer, these optimizers should be connected to + """ + if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING): + return + self.optimizers, self.lr_scheduler_configs, self.optimizer_frequencies = _init_optimizers_and_lr_schedulers( + self.lightning_module + ) + + def setup(self, trainer: "pl.Trainer") -> None: + """Setup plugins for the trainer fit and creates optimizers. + + Args: + trainer: the trainer instance + """ + self.accelerator.setup(trainer) + self.setup_optimizers(trainer) + self.setup_precision_plugin() + optimizers_to_device(self.optimizers, self.root_device) + + def setup_precision_plugin(self) -> None: + """Attaches the precision plugin to the accelerator.""" + model, optimizers, lr_scheduler_configs = self.precision_plugin.connect( + self.model, self.optimizers, self.lr_scheduler_configs + ) + self.model = model + self.optimizers = optimizers + self.lr_scheduler_configs = lr_scheduler_configs + + def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: + """Returns state of an optimizer. + + Allows for syncing/collating optimizer state from processes in custom plugins. + """ + return optimizer.state_dict() + + def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor: + """Forwards backward-calls to the precision plugin. + + Args: + closure_loss: a tensor holding the loss value to backpropagate + """ + self.pre_backward(closure_loss) + closure_loss = self.precision_plugin.pre_backward(self.lightning_module, closure_loss) + + self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs) + + closure_loss = self.precision_plugin.post_backward(self.lightning_module, closure_loss) + self.post_backward(closure_loss) + + return closure_loss + + def optimizer_step( + self, + optimizer: Optimizer, + opt_idx: int, + closure: Callable[[], Any], + model: Optional[Union["pl.LightningModule", Module]] = None, + **kwargs: Any, + ) -> Any: + """Performs the actual optimizer step. + + Args: + optimizer: the optimizer performing the step + opt_idx: index of the current optimizer + closure: closure calculating the loss value + model: reference to the model, optionally defining optimizer step related hooks + **kwargs: Any extra arguments to ``optimizer.step`` + """ + model = model or self.lightning_module + return self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs) + + def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]: + """Setup a model and multiple optimizers together. + + The returned objects are expected to be in the same order they were passed in. The default implementation will + call :meth:`_setup_model` and :meth:`_setup_optimizer` on the inputs. + """ + # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324 + model = self._setup_model(model) + optimizers = [self._setup_optimizer(optimizer) for optimizer in optimizers] + return model, optimizers + + def _setup_model(self, model: Module) -> Module: + """Performs setup for the model, e.g., by wrapping it by another class.""" + # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324 + return model + + def _setup_optimizer(self, optimizer: Optimizer) -> Optimizer: + """Performs setup for the optimizer, e.g., by wrapping it by another class.""" + # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324 + return optimizer + + def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any: + """Moves the batch to the correct device. + + The returned batch is of the same type as the input batch, just + having all tensors on the correct device. + + Args: + batch: The batch of samples to move to the correct device + device: The target device + dataloader_idx: The index of the dataloader to which the batch belongs. + """ + model = self.lightning_module + device = device or self.root_device + if model is not None: + return model._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx) + return move_data_to_device(batch, device) + + @property + @abstractmethod + def root_device(self) -> torch.device: + """Returns the root device.""" + + @abstractmethod + def model_to_device(self) -> None: + """Moves the model to the correct device.""" + + @property + @abstractmethod + def is_global_zero(self) -> bool: + """Whether the current process is the rank zero process not only on the local node, but for all nodes.""" + + @abstractmethod + def reduce( + self, + tensor: Union[torch.Tensor, Any], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = "mean", + ) -> Union[torch.Tensor, Any]: + """Reduces the given tensor (e.g. across GPUs/processes). + + Args: + tensor: the tensor to sync and reduce + group: the process group to reduce + reduce_op: the reduction operation. Defaults to 'mean'. + Can also be a string 'sum' or ReduceOp. + """ + + @abstractmethod + def barrier(self, name: Optional[str] = None) -> None: + """Synchronizes all processes which blocks processes until the whole group enters this function. + + Args: + name: an optional name to pass into barrier. + """ + + @abstractmethod + def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: + """Broadcasts an object to all processes. + + Args: + obj: the object to broadcast + src: source rank + """ + + @abstractmethod + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + """Perform an all_gather on all processes. + + Args: + tensor: the tensor to all_gather + group: the process group to gather results from + sync_grads: flag that allows users to synchronize gradients for all_gather op + """ + + def reduce_boolean_decision(self, decision: bool) -> bool: + """Reduce the early stopping decision across all processes.""" + return decision + + def pre_backward(self, closure_loss: torch.Tensor) -> None: + """Run before precision plugin executes backward.""" + + def post_backward(self, closure_loss: torch.Tensor) -> None: + """Run after precision plugin executes backward.""" + + @property + def model(self) -> Optional[Module]: + """Returns the potentially wrapped LightningModule.""" + return self._model + + @model.setter + def model(self, new_model: Optional[Module]) -> None: + self._model = new_model + + @property + def lightning_module(self) -> Optional["pl.LightningModule"]: + """Returns the pure LightningModule without potential wrappers.""" + return unwrap_lightning_module(self.model) if self.model is not None else None + + def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: + torch.cuda.empty_cache() + return self.checkpoint_io.load_checkpoint(checkpoint_path) + + def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: + self.lightning_module.load_state_dict(checkpoint["state_dict"]) + + def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: + optimizer_states = checkpoint["optimizer_states"] + for optimizer, opt_state in zip(self.optimizers, optimizer_states): + optimizer.load_state_dict(opt_state) + optimizer_to_device(optimizer, self.root_device) + + def training_step(self, *args, **kwargs) -> STEP_OUTPUT: + """The actual training step. + + See :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` for more details + """ + with self.precision_plugin.train_step_context(): + return self.model.training_step(*args, **kwargs) + + def post_training_step(self): + pass + + def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]: + """The actual validation step. + + See :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` for more details + """ + with self.precision_plugin.val_step_context(): + return self.model.validation_step(*args, **kwargs) + + def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]: + """The actual test step. + + See :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step` for more details + """ + with self.precision_plugin.test_step_context(): + return self.model.test_step(*args, **kwargs) + + def predict_step(self, *args, **kwargs) -> STEP_OUTPUT: + """The actual predict step. + + See :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` for more details + """ + with self.precision_plugin.predict_step_context(): + return self.model.predict_step(*args, **kwargs) + + def training_step_end(self, output): + return output + + def validation_step_end(self, output): + return output + + def test_step_end(self, output): + return output + + def process_dataloader(self, dataloader: DataLoader) -> DataLoader: + """Wraps the dataloader if necessary. + + Args: + dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader` + """ + return dataloader + + @property + def restore_checkpoint_after_setup(self) -> bool: + """Override to delay restoring from checkpoint till after pre-dispatch. This is useful when the plugin + requires all the setup hooks to run before loading checkpoint. + + Returns: + If true, restore checkpoint after pre_dispatch. + """ + return False + + @property + def lightning_restore_optimizer(self) -> bool: + """Override to disable Lightning restoring optimizers/schedulers. + + This is useful for plugins which manage restoring optimizers/schedulers. + """ + return True + + @property + def handles_gradient_accumulation(self) -> bool: + """Whether the plugin handles gradient accumulation internally.""" + return False + + def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]: + """Returns model state.""" + model = self.lightning_module + return model.state_dict() + + def save_checkpoint( + self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None + ) -> None: + """Save model/training states as a checkpoint file through state-dump and file-write. + + Args: + checkpoint: dict containing model and trainer state + filepath: write-target file's path + storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin + """ + if self.is_global_zero: + self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options) + + def remove_checkpoint(self, filepath: _PATH) -> None: + """Remove checkpoint filepath from the filesystem. + + Args: + filepath: Path to checkpoint + """ + if self.is_global_zero: + self.checkpoint_io.remove_checkpoint(filepath) + + @contextlib.contextmanager + def model_sharded_context(self) -> Generator: + """Provide hook to create modules in a distributed aware context. This is useful for when we'd like to + shard the model instantly, which is useful for extremely large models which can save memory and + initialization time. + + Returns: Model parallel context. + """ + yield + + def teardown(self) -> None: + """This method is called to teardown the training process. + + It is the right place to release memory and free other resources. + """ + optimizers_to_device(self.optimizers, torch.device("cpu")) + self.precision_plugin.teardown() + + @classmethod + def register_strategies(cls, strategy_registry) -> None: + pass + + def on_train_start(self) -> None: + """Called when train begins.""" + pass + + def on_validation_start(self) -> None: + """Called when validation begins.""" + pass + + def on_test_start(self) -> None: + """Called when test begins.""" + pass + + def on_predict_start(self) -> None: + """Called when predict begins.""" + pass + + def on_train_end(self) -> None: + """Called when train ends.""" + pass + + def on_validation_end(self) -> None: + """Called when validation ends.""" + pass + + def on_test_end(self) -> None: + """Called when test end.""" + pass + + def on_predict_end(self): + """Called when predict ends.""" + pass + + def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: + """Called in the training loop before anything happens for that batch.""" + pass + + def dispatch(self, trainer: "pl.Trainer") -> None: + """Hook to do something before the training/evaluation/prediction starts.""" + self.precision_plugin.dispatch(trainer) + + def __getstate__(self) -> Dict: + # `LightningOptimizer` overrides `self.__class__` so they cannot be pickled + state = dict(vars(self)) # copy + state["_lightning_optimizers"] = {} + return state + + def __setstate__(self, state: Dict) -> None: + self.__dict__ = state + self.optimizers = self.optimizers # re-create the `_lightning_optimizers` + + def post_dispatch(self, trainer: "pl.Trainer") -> None: + r""" + .. deprecated:: + v1.6 This method has been deprecated in v1.6 and will be removed in v1.7. Use :meth:`teardown` instead. + + Hook to do something after the training/evaluation/prediction finishes. + """ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/tuner/__init__.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/tuner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/tuner/auto_gpu_select.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/tuner/auto_gpu_select.py new file mode 100644 index 0000000000000000000000000000000000000000..d87eba64494f04a1f10f99aa2263fb6e2fc0ea6e --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/tuner/auto_gpu_select.py @@ -0,0 +1,72 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List + +import torch + +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +def pick_multiple_gpus(nb: int) -> List[int]: + """ + Raises: + MisconfigurationException: + If ``gpus`` or ``devices`` is set to 0, when ``auto_select_gpus=True``, or when the requested number is + higher than the number of GPUs available on the machine. + """ + if nb == 0: + raise MisconfigurationException( + "auto_select_gpus=True, gpus=0 is not a valid configuration." + " Please select a valid number of GPU resources when using auto_select_gpus." + ) + + num_gpus = torch.cuda.device_count() + if nb > num_gpus: + raise MisconfigurationException(f"You requested {nb} GPUs but your machine only has {num_gpus} GPUs.") + nb = num_gpus if nb == -1 else nb + + picked: List[int] = [] + for _ in range(nb): + picked.append(pick_single_gpu(exclude_gpus=picked)) + + return picked + + +def pick_single_gpu(exclude_gpus: List[int]) -> int: + """ + Raises: + RuntimeError: + If you try to allocate a GPU, when no GPUs are available. + """ + previously_used_gpus = [] + unused_gpus = [] + for i in range(torch.cuda.device_count()): + if i in exclude_gpus: + continue + + if torch.cuda.memory_reserved(f"cuda:{i}") > 0: + previously_used_gpus.append(i) + else: + unused_gpus.append(i) + + # Prioritize previously used GPUs + for i in previously_used_gpus + unused_gpus: + # Try to allocate on device: + device = torch.device(f"cuda:{i}") + try: + torch.ones(1).to(device) + except RuntimeError: + continue + return i + raise RuntimeError("No GPUs available.") diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/tuner/batch_size_scaling.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/tuner/batch_size_scaling.py new file mode 100644 index 0000000000000000000000000000000000000000..e8ac418441fdee616deb2fe45e38d4f2a6de2690 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/tuner/batch_size_scaling.py @@ -0,0 +1,251 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License +import logging +import os +import uuid +from typing import Any, Dict, Optional, Tuple + +from torch.utils.data import DataLoader + +import pytorch_lightning as pl +from pytorch_lightning.loggers.base import DummyLogger +from pytorch_lightning.utilities.data import has_len_all_ranks +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error +from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_hasattr, lightning_setattr +from pytorch_lightning.utilities.rank_zero import rank_zero_warn + +log = logging.getLogger(__name__) + + +def scale_batch_size( + trainer: "pl.Trainer", + model: "pl.LightningModule", + mode: str = "power", + steps_per_trial: int = 3, + init_val: int = 2, + max_trials: int = 25, + batch_arg_name: str = "batch_size", +) -> Optional[int]: + """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.scale_batch_size`""" + if trainer.fast_dev_run: + rank_zero_warn("Skipping batch size scaler since fast_dev_run is enabled.") + return + + if not lightning_hasattr(model, batch_arg_name): + raise MisconfigurationException(f"Field {batch_arg_name} not found in both `model` and `model.hparams`") + if hasattr(model, batch_arg_name) and hasattr(model, "hparams") and batch_arg_name in model.hparams: + rank_zero_warn( + f"Field `model.{batch_arg_name}` and `model.hparams.{batch_arg_name}` are mutually exclusive!" + f" `model.{batch_arg_name}` will be used as the initial batch size for scaling." + " If this is not the intended behavior, please remove either one." + ) + + if not trainer._data_connector._train_dataloader_source.is_module(): + raise MisconfigurationException( + "The batch scaling feature cannot be used with dataloaders passed directly to `.fit()`." + " Please disable the feature or incorporate the dataloader into the model." + ) + + # Save initial model, that is loaded after batch size is found + ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt") + trainer.save_checkpoint(ckpt_path) + params = __scale_batch_dump_params(trainer) + + # Set to values that are required by the algorithm + __scale_batch_reset_params(trainer, steps_per_trial) + + if trainer.progress_bar_callback: + trainer.progress_bar_callback.disable() + + # Initially we just double in size until an OOM is encountered + new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val) # initially set to init_val + if mode == "power": + new_size = _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials) + elif mode == "binsearch": + new_size = _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials) + else: + raise ValueError("mode in method `scale_batch_size` could either be `power` or `binsearch`") + + garbage_collection_cuda() + log.info(f"Finished batch size finder, will continue with full run using batch size {new_size}") + + __scale_batch_restore_params(trainer, params) + + if trainer.progress_bar_callback: + trainer.progress_bar_callback.enable() + + # Restore initial state of model + trainer._checkpoint_connector.restore(ckpt_path) + trainer.strategy.remove_checkpoint(ckpt_path) + + return new_size + + +def __scale_batch_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]: + return { + "max_steps": trainer.fit_loop.max_steps, + "logger": trainer.logger, + "callbacks": trainer.callbacks, + "auto_scale_batch_size": trainer.auto_scale_batch_size, + "auto_lr_find": trainer.auto_lr_find, + "limit_train_batches": trainer.limit_train_batches, + } + + +def __scale_batch_reset_params(trainer: "pl.Trainer", steps_per_trial: int) -> None: + trainer.auto_scale_batch_size = None # prevent recursion + trainer.auto_lr_find = False # avoid lr find being called multiple times + trainer.fit_loop.max_steps = steps_per_trial # take few steps + trainer.loggers = [DummyLogger()] if trainer.loggers else [] + trainer.callbacks = [] # not needed before full run + trainer.limit_train_batches = 1.0 + + +def __scale_batch_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: + trainer.auto_scale_batch_size = params["auto_scale_batch_size"] + trainer.auto_lr_find = params["auto_lr_find"] + trainer.fit_loop.max_steps = params["max_steps"] + trainer.logger = params["logger"] + trainer.callbacks = params["callbacks"] + trainer.limit_train_batches = params["limit_train_batches"] + + +def _run_power_scaling( + trainer: "pl.Trainer", model: "pl.LightningModule", new_size: int, batch_arg_name: str, max_trials: int +) -> int: + """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered.""" + for _ in range(max_trials): + garbage_collection_cuda() + trainer.fit_loop.global_step = 0 # reset after each try + try: + # Try fit + trainer.tuner._run(model) + # Double in size + new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc="succeeded") + except RuntimeError as exception: + # Only these errors should trigger an adjustment + if is_oom_error(exception): + # If we fail in power mode, half the size and return + garbage_collection_cuda() + new_size, _ = _adjust_batch_size(trainer, batch_arg_name, factor=0.5, desc="failed") + break + else: + raise # some other error not memory related + + if changed: + # Force the train dataloader to reset as the batch size has changed + trainer.reset_train_dataloader(model) + trainer.reset_val_dataloader(model) + else: + break + return new_size + + +def _run_binsearch_scaling( + trainer: "pl.Trainer", model: "pl.LightningModule", new_size: int, batch_arg_name: str, max_trials: int +) -> int: + """Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is + encountered. + + Hereafter, the batch size is further refined using a binary search + """ + low = 1 + high = None + count = 0 + while True: + garbage_collection_cuda() + trainer.fit_loop.global_step = 0 # reset after each try + try: + # Try fit + trainer.tuner._run(model) + count += 1 + if count > max_trials: + break + # Double in size + low = new_size + if high: + if high - low <= 1: + break + midval = (high + low) // 2 + new_size, changed = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc="succeeded") + else: + new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc="succeeded") + + if changed: + # Force the train dataloader to reset as the batch size has changed + trainer.reset_train_dataloader(model) + trainer.reset_val_dataloader(model) + else: + break + + except RuntimeError as exception: + # Only these errors should trigger an adjustment + if is_oom_error(exception): + # If we fail in power mode, half the size and return + garbage_collection_cuda() + high = new_size + midval = (high + low) // 2 + new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc="failed") + if high - low <= 1: + break + else: + raise # some other error not memory related + + return new_size + + +def _adjust_batch_size( + trainer: "pl.Trainer", + batch_arg_name: str = "batch_size", + factor: float = 1.0, + value: Optional[int] = None, + desc: Optional[str] = None, +) -> Tuple[int, bool]: + """Helper function for adjusting the batch size. + + Args: + trainer: instance of pytorch_lightning.Trainer + + batch_arg_name: name of the field where batch_size is stored. + + factor: value which the old batch size is multiplied by to get the + new batch size + + value: if a value is given, will override the batch size with this value. + Note that the value of `factor` will not have an effect in this case + + desc: either `succeeded` or `failed`. Used purely for logging + + Returns: + The new batch size for the next trial and a bool that signals whether the + new value is different than the previous batch size. + """ + model = trainer.lightning_module + batch_size = lightning_getattr(model, batch_arg_name) + new_size = value if value is not None else int(batch_size * factor) + if desc: + log.info(f"Batch size {batch_size} {desc}, trying batch size {new_size}") + + if not _is_valid_batch_size(new_size, trainer.train_dataloader, trainer): + new_size = min(new_size, len(trainer.train_dataloader.dataset)) + + changed = new_size != batch_size + lightning_setattr(model, batch_arg_name, new_size) + return new_size, changed + + +def _is_valid_batch_size(batch_size: int, dataloader: DataLoader, trainer: "pl.Trainer"): + module = trainer.lightning_module or trainer.datamodule + return not has_len_all_ranks(dataloader, trainer.strategy, module) or batch_size <= len(dataloader) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/tuner/lr_finder.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/tuner/lr_finder.py new file mode 100644 index 0000000000000000000000000000000000000000..a33eec16c757ab4f8d24bcaaa4798ee28d5d0ebf --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/tuner/lr_finder.py @@ -0,0 +1,426 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import logging +import os +import uuid +from functools import wraps +from typing import Any, Dict, Optional, Sequence + +import numpy as np +import torch +from torch.optim.lr_scheduler import _LRScheduler + +import pytorch_lightning as pl +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers, _set_scheduler_opt_idx +from pytorch_lightning.loggers.base import DummyLogger +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr +from pytorch_lightning.utilities.rank_zero import rank_zero_warn +from pytorch_lightning.utilities.types import LRSchedulerConfig + +# check if ipywidgets is installed before importing tqdm.auto +# to ensure it won't fail and a progress bar is displayed +if importlib.util.find_spec("ipywidgets") is not None: + from tqdm.auto import tqdm +else: + from tqdm import tqdm + +log = logging.getLogger(__name__) + + +def _determine_lr_attr_name(trainer: "pl.Trainer", model: "pl.LightningModule") -> str: + if isinstance(trainer.auto_lr_find, str): + if not lightning_hasattr(model, trainer.auto_lr_find): + raise MisconfigurationException( + f"`auto_lr_find` was set to {trainer.auto_lr_find}, however" + " could not find this as a field in `model` or `model.hparams`." + ) + return trainer.auto_lr_find + + attr_options = ("lr", "learning_rate") + for attr in attr_options: + if lightning_hasattr(model, attr): + return attr + + raise MisconfigurationException( + "When `auto_lr_find=True`, either `model` or `model.hparams` should" + f" have one of these fields: {attr_options} overridden." + ) + + +class _LRFinder: + """LR finder object. This object stores the results of lr_find(). + + Args: + mode: either `linear` or `exponential`, how to increase lr after each step + + lr_min: lr to start search from + + lr_max: lr to stop search + + num_training: number of steps to take between lr_min and lr_max + + Example:: + # Run lr finder + lr_finder = trainer.lr_find(model) + + # Results stored in + lr_finder.results + + # Plot using + lr_finder.plot() + + # Get suggestion + lr = lr_finder.suggestion() + """ + + def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int): + assert mode in ("linear", "exponential"), "mode should be either `linear` or `exponential`" + + self.mode = mode + self.lr_min = lr_min + self.lr_max = lr_max + self.num_training = num_training + + self.results = {} + self._total_batch_idx = 0 # for debug purpose + + def _exchange_scheduler(self, trainer: "pl.Trainer", model: "pl.LightningModule"): + """Decorate `trainer.strategy.setup_optimizers` method such that it sets the user's originally specified + optimizer together with a new scheduler that takes care of the learning rate search.""" + setup_optimizers = trainer.strategy.setup_optimizers + + @wraps(setup_optimizers) + def func(trainer): + # Decide the structure of the output from _init_optimizers_and_lr_schedulers + optimizers, _, _ = _init_optimizers_and_lr_schedulers(trainer.lightning_module) + + if len(optimizers) != 1: + raise MisconfigurationException( + f"`model.configure_optimizers()` returned {len(optimizers)}, but" + " learning rate finder only works with single optimizer" + ) + + optimizer = optimizers[0] + + new_lrs = [self.lr_min] * len(optimizer.param_groups) + for param_group, new_lr in zip(optimizer.param_groups, new_lrs): + param_group["lr"] = new_lr + param_group["initial_lr"] = new_lr + + args = (optimizer, self.lr_max, self.num_training) + scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args) + + trainer.strategy.optimizers = [optimizer] + trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)] + trainer.strategy.optimizer_frequencies = [] + _set_scheduler_opt_idx(trainer.optimizers, trainer.lr_scheduler_configs) + + return func + + def plot(self, suggest: bool = False, show: bool = False): + """Plot results from lr_find run + Args: + suggest: if True, will mark suggested lr to use with a red point + + show: if True, will show figure + """ + import matplotlib.pyplot as plt + + lrs = self.results["lr"] + losses = self.results["loss"] + + fig, ax = plt.subplots() + + # Plot loss as a function of the learning rate + ax.plot(lrs, losses) + if self.mode == "exponential": + ax.set_xscale("log") + ax.set_xlabel("Learning rate") + ax.set_ylabel("Loss") + + if suggest: + _ = self.suggestion() + if self._optimal_idx: + ax.plot(lrs[self._optimal_idx], losses[self._optimal_idx], markersize=10, marker="o", color="red") + + if show: + plt.show() + + return fig + + def suggestion(self, skip_begin: int = 10, skip_end: int = 1): + """This will propose a suggestion for choice of initial learning rate as the point with the steepest + negative gradient. + + Returns: + lr: suggested initial learning rate to use + skip_begin: how many samples to skip in the beginning. Prevent too naive estimates + skip_end: how many samples to skip in the end. Prevent too optimistic estimates + """ + try: + loss = np.array(self.results["loss"][skip_begin:-skip_end]) + loss = loss[np.isfinite(loss)] + min_grad = np.gradient(loss).argmin() + self._optimal_idx = min_grad + skip_begin + return self.results["lr"][self._optimal_idx] + # todo: specify the possible exception + except Exception: + log.exception("Failed to compute suggesting for `lr`. There might not be enough points.") + self._optimal_idx = None + + +def lr_find( + trainer: "pl.Trainer", + model: "pl.LightningModule", + min_lr: float = 1e-8, + max_lr: float = 1, + num_training: int = 100, + mode: str = "exponential", + early_stop_threshold: float = 4.0, + update_attr: bool = False, +) -> Optional[_LRFinder]: + """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find`""" + if trainer.fast_dev_run: + rank_zero_warn("Skipping learning rate finder since fast_dev_run is enabled.") + return + + # Determine lr attr + if update_attr: + lr_attr_name = _determine_lr_attr_name(trainer, model) + + # Save initial model, that is loaded after learning rate is found + ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt") + trainer.save_checkpoint(ckpt_path) + params = __lr_finder_dump_params(trainer) + + # Set to values that are required by the algorithm + __lr_finder_reset_params(trainer, num_training, early_stop_threshold) + + # Initialize lr finder object (stores results) + lr_finder = _LRFinder(mode, min_lr, max_lr, num_training) + + # Disable standard progress bar for fit + if trainer.progress_bar_callback: + trainer.progress_bar_callback.disable() + + # Configure optimizer and scheduler + trainer.strategy.setup_optimizers = lr_finder._exchange_scheduler(trainer, model) + + # Fit, lr & loss logged in callback + trainer.tuner._run(model) + + # Prompt if we stopped early + if trainer.global_step != num_training: + log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.") + + # Transfer results from callback to lr finder object + lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses}) + lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx # for debug purpose + + __lr_finder_restore_params(trainer, params) + + if trainer.progress_bar_callback: + trainer.progress_bar_callback.enable() + + # Update lr attr if required + if update_attr: + lr = lr_finder.suggestion() + + # TODO: log lr.results to self.logger + lightning_setattr(model, lr_attr_name, lr) + log.info(f"Learning rate set to {lr}") + + # Restore initial state of model + trainer._checkpoint_connector.restore(ckpt_path) + trainer.strategy.remove_checkpoint(ckpt_path) + + return lr_finder + + +def __lr_finder_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]: + return { + "auto_lr_find": trainer.auto_lr_find, + "callbacks": trainer.callbacks, + "logger": trainer.logger, + "max_steps": trainer.fit_loop.max_steps, + } + + +def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_stop_threshold: float) -> None: + # avoid lr find being called multiple times + trainer.auto_lr_find = False + # Use special lr logger callback + trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)] + # No logging + trainer.loggers = [DummyLogger()] if trainer.loggers else [] + # Max step set to number of iterations + trainer.fit_loop.max_steps = num_training + + +def __lr_finder_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: + trainer.auto_lr_find = params["auto_lr_find"] + trainer.callbacks = params["callbacks"] + trainer.logger = params["logger"] + trainer.fit_loop.max_steps = params["max_steps"] + + +class _LRCallback(Callback): + """Special callback used by the learning rate finder. This callbacks log the learning rate before each batch + and log the corresponding loss after each batch. + + Args: + num_training: number of iterations done by the learning rate finder + early_stop_threshold: threshold for stopping the search. If the + loss at any point is larger than ``early_stop_threshold*best_loss`` + then the search is stopped. To disable, set to ``None``. + progress_bar_refresh_rate: rate to refresh the progress bar for + the learning rate finder + beta: smoothing value, the loss being logged is a running average of + loss values logged until now. ``beta`` controls the forget rate i.e. + if ``beta=0`` all past information is ignored. + """ + + def __init__( + self, + num_training: int, + early_stop_threshold: float = 4.0, + progress_bar_refresh_rate: int = 0, + beta: float = 0.98, + ): + self.num_training = num_training + self.early_stop_threshold = early_stop_threshold + self.beta = beta + self.losses = [] + self.lrs = [] + self.avg_loss = 0.0 + self.best_loss = 0.0 + self.progress_bar_refresh_rate = progress_bar_refresh_rate + self.progress_bar = None + + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): + """Called before each training batch, logs the lr that will be used.""" + if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0: + return + + if self.progress_bar_refresh_rate and self.progress_bar is None: + self.progress_bar = tqdm(desc="Finding best initial lr", total=self.num_training) + + self.lrs.append(trainer.lr_scheduler_configs[0].scheduler.lr[0]) + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + """Called when the training batch ends, logs the calculated loss.""" + if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0: + return + + if self.progress_bar: + self.progress_bar.update() + + current_loss = trainer.fit_loop.running_loss.last().item() + current_step = trainer.global_step + + # Avg loss (loss with momentum) + smoothing + self.avg_loss = self.beta * self.avg_loss + (1 - self.beta) * current_loss + smoothed_loss = self.avg_loss / (1 - self.beta ** (current_step + 1)) + + # Check if we diverging + if self.early_stop_threshold is not None: + if current_step > 1 and smoothed_loss > self.early_stop_threshold * self.best_loss: + trainer.fit_loop.max_steps = current_step # stop signal + if self.progress_bar: + self.progress_bar.close() + + # Save best loss for diverging checking + if smoothed_loss < self.best_loss or current_step == 1: + self.best_loss = smoothed_loss + + self.losses.append(smoothed_loss) + + +class _LinearLR(_LRScheduler): + """Linearly increases the learning rate between two boundaries over a number of iterations. + + Args: + + optimizer: wrapped optimizer. + + end_lr: the final learning rate. + + num_iter: the number of iterations over which the test occurs. + + last_epoch: the index of last epoch. Default: -1. + """ + + last_epoch: int + base_lrs: Sequence + + def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1): + self.end_lr = end_lr + self.num_iter = num_iter + super().__init__(optimizer, last_epoch) + + def get_lr(self): + curr_iter = self.last_epoch + 1 + r = curr_iter / self.num_iter + + if self.last_epoch > 0: + val = [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs] + else: + val = [base_lr for base_lr in self.base_lrs] + self._lr = val + return val + + @property + def lr(self): + return self._lr + + +class _ExponentialLR(_LRScheduler): + """Exponentially increases the learning rate between two boundaries over a number of iterations. + + Arguments: + + optimizer: wrapped optimizer. + + end_lr: the final learning rate. + + num_iter: the number of iterations over which the test occurs. + + last_epoch: the index of last epoch. Default: -1. + """ + + last_epoch: int + base_lrs: Sequence + + def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1): + self.end_lr = end_lr + self.num_iter = num_iter + super().__init__(optimizer, last_epoch) + + def get_lr(self): + curr_iter = self.last_epoch + 1 + r = curr_iter / self.num_iter + + if self.last_epoch > 0: + val = [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs] + else: + val = [base_lr for base_lr in self.base_lrs] + self._lr = val + return val + + @property + def lr(self): + return self._lr diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/tuner/tuning.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/tuner/tuning.py new file mode 100644 index 0000000000000000000000000000000000000000..b1a38bd27688ca53c2f8926ab1afb36155bcdff2 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/tuner/tuning.py @@ -0,0 +1,207 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional, Union + +import pytorch_lightning as pl +from pytorch_lightning.trainer.states import TrainerStatus +from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size +from pytorch_lightning.tuner.lr_finder import _LRFinder, lr_find +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS + + +class Tuner: + """Tuner class to tune your model.""" + + def __init__(self, trainer: "pl.Trainer") -> None: + self.trainer = trainer + + def on_trainer_init(self, auto_lr_find: Union[str, bool], auto_scale_batch_size: Union[str, bool]) -> None: + self.trainer.auto_lr_find = auto_lr_find + self.trainer.auto_scale_batch_size = auto_scale_batch_size + + def _tune( + self, + model: "pl.LightningModule", + scale_batch_size_kwargs: Optional[Dict[str, Any]] = None, + lr_find_kwargs: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Optional[Union[int, _LRFinder]]]: + scale_batch_size_kwargs = scale_batch_size_kwargs or {} + lr_find_kwargs = lr_find_kwargs or {} + # return a dict instead of a tuple so BC is not broken if a new tuning procedure is added + result = {} + + self.trainer.strategy.connect(model) + + is_tuning = self.trainer.auto_scale_batch_size or self.trainer.auto_lr_find + if self.trainer._accelerator_connector.is_distributed and is_tuning: + raise MisconfigurationException( + "`trainer.tune()` is currently not supported with" + f" `Trainer(strategy={self.trainer.strategy.strategy_name!r})`." + ) + + # Run auto batch size scaling + if self.trainer.auto_scale_batch_size: + if isinstance(self.trainer.auto_scale_batch_size, str): + scale_batch_size_kwargs.setdefault("mode", self.trainer.auto_scale_batch_size) + result["scale_batch_size"] = scale_batch_size(self.trainer, model, **scale_batch_size_kwargs) + + # Run learning rate finder: + if self.trainer.auto_lr_find: + lr_find_kwargs.setdefault("update_attr", True) + result["lr_find"] = lr_find(self.trainer, model, **lr_find_kwargs) + + self.trainer.state.status = TrainerStatus.FINISHED + + return result + + def _run(self, *args: Any, **kwargs: Any) -> None: + """`_run` wrapper to set the proper state during tuning, as this can be called multiple times.""" + self.trainer.state.status = TrainerStatus.RUNNING # last `_run` call might have set it to `FINISHED` + self.trainer.training = True + self.trainer._run(*args, **kwargs) + self.trainer.tuning = True + + def scale_batch_size( + self, + model: "pl.LightningModule", + train_dataloaders: Optional[Union[TRAIN_DATALOADERS, "pl.LightningDataModule"]] = None, + val_dataloaders: Optional[EVAL_DATALOADERS] = None, + datamodule: Optional["pl.LightningDataModule"] = None, + mode: str = "power", + steps_per_trial: int = 3, + init_val: int = 2, + max_trials: int = 25, + batch_arg_name: str = "batch_size", + ) -> Optional[int]: + """Iteratively try to find the largest batch size for a given model that does not give an out of memory + (OOM) error. + + Args: + model: Model to tune. + + train_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a + :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying training samples. + In the case of multiple dataloaders, please see this :ref:`section `. + + val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples. + + datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. + + mode: Search strategy to update the batch size: + + - ``'power'`` (default): Keep multiplying the batch size by 2, until we get an OOM error. + - ``'binsearch'``: Initially keep multiplying by 2 and after encountering an OOM error + do a binary search between the last successful batch size and the batch size that failed. + + steps_per_trial: number of steps to run with a given batch size. + Ideally 1 should be enough to test if a OOM error occurs, + however in practise a few are needed + + init_val: initial batch size to start the search with + + max_trials: max number of increase in batch size done before + algorithm is terminated + + batch_arg_name: name of the attribute that stores the batch size. + It is expected that the user has provided a model or datamodule that has a hyperparameter + with that name. We will look for this attribute name in the following places + + - ``model`` + - ``model.hparams`` + - ``trainer.datamodule`` (the datamodule passed to the tune method) + """ + self.trainer.auto_scale_batch_size = True + result = self.trainer.tune( + model, + train_dataloaders=train_dataloaders, + val_dataloaders=val_dataloaders, + datamodule=datamodule, + scale_batch_size_kwargs={ + "mode": mode, + "steps_per_trial": steps_per_trial, + "init_val": init_val, + "max_trials": max_trials, + "batch_arg_name": batch_arg_name, + }, + ) + self.trainer.auto_scale_batch_size = False + return result["scale_batch_size"] + + def lr_find( + self, + model: "pl.LightningModule", + train_dataloaders: Optional[Union[TRAIN_DATALOADERS, "pl.LightningDataModule"]] = None, + val_dataloaders: Optional[EVAL_DATALOADERS] = None, + datamodule: Optional["pl.LightningDataModule"] = None, + min_lr: float = 1e-8, + max_lr: float = 1, + num_training: int = 100, + mode: str = "exponential", + early_stop_threshold: float = 4.0, + update_attr: bool = False, + ) -> Optional[_LRFinder]: + """Enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in + picking a good starting learning rate. + + Args: + model: Model to tune. + + train_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a + :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying training samples. + In the case of multiple dataloaders, please see this :ref:`section `. + + val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples. + + datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. + + min_lr: minimum learning rate to investigate + + max_lr: maximum learning rate to investigate + + num_training: number of learning rates to test + + mode: Search strategy to update learning rate after each batch: + + - ``'exponential'`` (default): Will increase the learning rate exponentially. + - ``'linear'``: Will increase the learning rate linearly. + + early_stop_threshold: threshold for stopping the search. If the + loss at any point is larger than early_stop_threshold*best_loss + then the search is stopped. To disable, set to None. + + update_attr: Whether to update the learning rate attribute or not. + + Raises: + MisconfigurationException: + If learning rate/lr in ``model`` or ``model.hparams`` isn't overridden when ``auto_lr_find=True``, + or if you are using more than one optimizer. + """ + self.trainer.auto_lr_find = True + result = self.trainer.tune( + model, + train_dataloaders=train_dataloaders, + val_dataloaders=val_dataloaders, + datamodule=datamodule, + lr_find_kwargs={ + "min_lr": min_lr, + "max_lr": max_lr, + "num_training": num_training, + "mode": mode, + "early_stop_threshold": early_stop_threshold, + "update_attr": update_attr, + }, + ) + self.trainer.auto_lr_find = False + return result["lr_find"] diff --git a/tmp_inputs_bgUwSA7K/case00001.nii.gz b/tmp_inputs_bgUwSA7K/case00001.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..82ebb7fa0142ecd64f75249db79c91e5f68a1150 --- /dev/null +++ b/tmp_inputs_bgUwSA7K/case00001.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:736504da8f2ad06160335b27cce279aece181b2de831736390aa56259f3546bd +size 41096835