Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/imageio/resources/images/stent.npz +3 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/__init__.py +56 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/base.py +368 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/device_stats_monitor.py +104 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/early_stopping.py +261 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/finetuning.py +417 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/gpu_stats_monitor.py +262 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/lambda_function.py +96 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/lr_monitor.py +354 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py +720 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_summary.py +73 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/prediction_writer.py +119 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/pruning.py +486 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/quantization.py +344 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/rich_model_summary.py +109 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/stochastic_weight_avg.py +280 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/timer.py +176 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/xla_stats_monitor.py +114 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/datamodule.py +264 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/decorators.py +60 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/hooks.py +828 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py +409 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/saving.py +419 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/distributed/__init__.py +14 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/distributed/dist.py +47 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/__pycache__/__init__.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/__pycache__/layer_sync.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/__init__.py +20 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/__pycache__/__init__.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/__pycache__/bagua_environment.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/bagua_environment.py +62 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/cluster_environment.py +87 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/kubeflow_environment.py +78 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/lightning_environment.py +101 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/lsf_environment.py +190 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/slurm_environment.py +134 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/torchelastic_environment.py +88 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/__init__.py +17 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/__pycache__/xla_plugin.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/checkpoint_plugin.py +62 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/hpu_plugin.py +52 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/torch_plugin.py +96 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/xla_plugin.py +57 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__init__.py +27 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/apex_amp.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/deepspeed.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/double.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/fully_sharded_native_amp.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/mixed.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/native_amp.cpython-38.pyc +0 -0
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/imageio/resources/images/stent.npz
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:60a83d2296b51ee6a53153e9ba96ba9020391b0c8952895d9d60a0a629ac6bb6
+size 824612
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/__init__.py
ADDED
@@ -0,0 +1,56 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pytorch_lightning.callbacks.base import Callback
+from pytorch_lightning.callbacks.device_stats_monitor import DeviceStatsMonitor
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+from pytorch_lightning.callbacks.finetuning import BackboneFinetuning, BaseFinetuning
+from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor
+from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
+from pytorch_lightning.callbacks.lambda_function import LambdaCallback
+from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
+from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
+from pytorch_lightning.callbacks.model_summary import ModelSummary
+from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
+from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase, RichProgressBar, TQDMProgressBar
+from pytorch_lightning.callbacks.pruning import ModelPruning
+from pytorch_lightning.callbacks.quantization import QuantizationAwareTraining
+from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary
+from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging
+from pytorch_lightning.callbacks.timer import Timer
+from pytorch_lightning.callbacks.xla_stats_monitor import XLAStatsMonitor
+
+__all__ = [
+    "BackboneFinetuning",
+    "BaseFinetuning",
+    "Callback",
+    "DeviceStatsMonitor",
+    "EarlyStopping",
+    "GPUStatsMonitor",
+    "XLAStatsMonitor",
+    "GradientAccumulationScheduler",
+    "LambdaCallback",
+    "LearningRateMonitor",
+    "ModelCheckpoint",
+    "ModelPruning",
+    "ModelSummary",
+    "BasePredictionWriter",
+    "ProgressBar",
+    "ProgressBarBase",
+    "QuantizationAwareTraining",
+    "RichModelSummary",
+    "RichProgressBar",
+    "StochasticWeightAveraging",
+    "Timer",
+    "TQDMProgressBar",
+]
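
The __init__ module above simply re-exports every callback class, so downstream code imports them from pytorch_lightning.callbacks directly. A minimal usage sketch under that assumption (the model class is a hypothetical placeholder, not part of this diff):

# Sketch: wiring a few of the re-exported callbacks into a Trainer.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint

trainer = Trainer(
    max_epochs=10,
    callbacks=[
        ModelCheckpoint(monitor="val_loss", mode="min"),  # keep the best checkpoint by val_loss
        EarlyStopping(monitor="val_loss", patience=3),    # stop once val_loss plateaus
        LearningRateMonitor(logging_interval="epoch"),    # log learning-rate schedules
    ],
)
# trainer.fit(MyLightningModule())  # hypothetical LightningModule
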
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/base.py
ADDED
@@ -0,0 +1,368 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+r"""
+Base class used to build new callbacks.
+
+"""
+
+from typing import Any, Dict, List, Optional, Type
+
+import torch
+from torch.optim import Optimizer
+
+import pytorch_lightning as pl
+from pytorch_lightning.utilities.types import STEP_OUTPUT
+
+
+class Callback:
+    r"""
+    Abstract base class used to build new callbacks.
+
+    Subclass this class and override any of the relevant hooks
+    """
+
+    @property
+    def state_key(self) -> str:
+        """Identifier for the state of the callback.
+
+        Used to store and retrieve a callback's state from the checkpoint dictionary by
+        ``checkpoint["callbacks"][state_key]``. Implementations of a callback need to provide a unique state key if 1)
+        the callback has state and 2) it is desired to maintain the state of multiple instances of that callback.
+        """
+        return self.__class__.__qualname__
+
+    @property
+    def _legacy_state_key(self) -> Type["Callback"]:
+        """State key for checkpoints saved prior to version 1.5.0."""
+        return type(self)
+
+    def _generate_state_key(self, **kwargs: Any) -> str:
+        """Formats a set of key-value pairs into a state key string with the callback class name prefixed. Useful
+        for defining a :attr:`state_key`.
+
+        Args:
+            **kwargs: A set of key-value pairs. Must be serializable to :class:`str`.
+        """
+        return f"{self.__class__.__qualname__}{repr(kwargs)}"
+
+    def on_configure_sharded_model(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        r"""
+        .. deprecated:: v1.6
+            This callback hook was deprecated in v1.6 and will be removed in v1.8. Use `setup()` instead.
+
+        Called before configure sharded model.
+        """
+
+    def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        r"""
+        .. deprecated:: v1.6
+            This callback hook was deprecated in v1.6 and will be removed in v1.8. Use ``setup()`` instead.
+
+        Called before accelerator is being setup.
+        """
+
+    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
+        """Called when fit, validate, test, predict, or tune begins."""
+
+    def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
+        """Called when fit, validate, test, predict, or tune ends."""
+
+    def on_init_start(self, trainer: "pl.Trainer") -> None:
+        r"""
+        .. deprecated:: v1.6
+            This callback hook was deprecated in v1.6 and will be removed in v1.8.
+
+        Called when the trainer initialization begins, model has not yet been set.
+        """
+
+    def on_init_end(self, trainer: "pl.Trainer") -> None:
+        r"""
+        .. deprecated:: v1.6
+            This callback hook was deprecated in v1.6 and will be removed in v1.8.
+
+        Called when the trainer initialization ends, model has not yet been set.
+        """
+
+    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Called when fit begins."""
+
+    def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Called when fit ends."""
+
+    def on_sanity_check_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Called when the validation sanity check starts."""
+
+    def on_sanity_check_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Called when the validation sanity check ends."""
+
+    def on_train_batch_start(
+        self,
+        trainer: "pl.Trainer",
+        pl_module: "pl.LightningModule",
+        batch: Any,
+        batch_idx: int,
+        unused: int = 0,
+    ) -> None:
+        """Called when the train batch begins."""
+
+    def on_train_batch_end(
+        self,
+        trainer: "pl.Trainer",
+        pl_module: "pl.LightningModule",
+        outputs: STEP_OUTPUT,
+        batch: Any,
+        batch_idx: int,
+        unused: int = 0,
+    ) -> None:
+        """Called when the train batch ends."""
+
+    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Called when the train epoch begins."""
+
+    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Called when the train epoch ends.
+
+        To access all batch outputs at the end of the epoch, either:
+
+        1. Implement `training_epoch_end` in the `LightningModule` and access outputs via the module OR
+        2. Cache data across train batch hooks inside the callback implementation to post-process in this hook.
+        """
+
+    def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Called when the val epoch begins."""
+
+    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Called when the val epoch ends."""
+
+    def on_test_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Called when the test epoch begins."""
+
+    def on_test_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Called when the test epoch ends."""
+
+    def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Called when the predict epoch begins."""
+
+    def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: List[Any]) -> None:
+        """Called when the predict epoch ends."""
+
+    def on_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        r"""
+        .. deprecated:: v1.6
+            This callback hook was deprecated in v1.6 and will be removed in v1.8. Use
+            ``on_<train/validation/test>_epoch_start`` instead.
+
+        Called when either of train/val/test epoch begins.
+        """
+
+    def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        r"""
+        .. deprecated:: v1.6
+            This callback hook was deprecated in v1.6 and will be removed in v1.8. Use
+            ``on_<train/validation/test>_epoch_end`` instead.
+
+        Called when either of train/val/test epoch ends.
+        """
+
+    def on_batch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        r"""
+        .. deprecated:: v1.6
+            This callback hook was deprecated in v1.6 and will be removed in v1.8. Use
+            ``on_train_batch_start`` instead.
+
+        Called when the training batch begins.
+        """
+
+    def on_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        r"""
+        .. deprecated:: v1.6
+            This callback hook was deprecated in v1.6 and will be removed in v1.8. Use
+            ``on_train_batch_end`` instead.
+
+        Called when the training batch ends.
+        """
+
+    def on_validation_batch_start(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
+        """Called when the validation batch begins."""
+
+    def on_validation_batch_end(
+        self,
+        trainer: "pl.Trainer",
+        pl_module: "pl.LightningModule",
+        outputs: Optional[STEP_OUTPUT],
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int,
+    ) -> None:
+        """Called when the validation batch ends."""
+
+    def on_test_batch_start(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
+        """Called when the test batch begins."""
+
+    def on_test_batch_end(
+        self,
+        trainer: "pl.Trainer",
+        pl_module: "pl.LightningModule",
+        outputs: Optional[STEP_OUTPUT],
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int,
+    ) -> None:
+        """Called when the test batch ends."""
+
+    def on_predict_batch_start(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
+        """Called when the predict batch begins."""
+
+    def on_predict_batch_end(
+        self,
+        trainer: "pl.Trainer",
+        pl_module: "pl.LightningModule",
+        outputs: Any,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int,
+    ) -> None:
+        """Called when the predict batch ends."""
+
+    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Called when the train begins."""
+
+    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Called when the train ends."""
+
+    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        r"""
+        .. deprecated:: v1.6
+
+            This callback hook was deprecated in v1.6 and will be removed in v1.8. Use ``on_fit_start`` instead.
+
+        Called when the pretrain routine begins.
+        """
+
+    def on_pretrain_routine_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        r"""
+        .. deprecated:: v1.6
+
+            This callback hook was deprecated in v1.6 and will be removed in v1.8. Use ``on_fit_start`` instead.
+
+        Called when the pretrain routine ends.
+        """
+
+    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Called when the validation loop begins."""
+
+    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Called when the validation loop ends."""
+
+    def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Called when the test begins."""
+
+    def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Called when the test ends."""
+
+    def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Called when the predict begins."""
+
+    def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Called when predict ends."""
+
+    def on_keyboard_interrupt(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        r"""
+        .. deprecated:: v1.5
+            This callback hook was deprecated in v1.5 in favor of `on_exception` and will be removed in v1.7.
+
+        Called when any trainer execution is interrupted by KeyboardInterrupt.
+        """
+
+    def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None:
+        """Called when any trainer execution is interrupted by an exception."""
+
+    def state_dict(self) -> Dict[str, Any]:
+        """Called when saving a checkpoint, implement to generate callback's ``state_dict``.
+
+        Returns:
+            A dictionary containing callback state.
+        """
+        return {}
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        """Called when loading a checkpoint, implement to reload callback state given callback's ``state_dict``.
+
+        Args:
+            state_dict: the callback state returned by ``state_dict``.
+        """
+        pass
+
+    def on_save_checkpoint(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
+    ) -> Optional[dict]:
+        r"""
+        Called when saving a checkpoint to give you a chance to store anything else you might want to save.
+
+        Args:
+            trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance.
+            pl_module: the current :class:`~pytorch_lightning.core.lightning.LightningModule` instance.
+            checkpoint: the checkpoint dictionary that will be saved.
+
+        Returns:
+            None or the callback state. Support for returning callback state will be removed in v1.8.
+
+        .. deprecated:: v1.6
+            Returning a value from this method was deprecated in v1.6 and will be removed in v1.8.
+            Implement ``Callback.state_dict`` instead to return state.
+            In v1.8 ``Callback.on_save_checkpoint`` can only return None.
+        """
+
+    def on_load_checkpoint(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any]
+    ) -> None:
+        r"""
+        Called when loading a model checkpoint, use to reload state.
+
+        Args:
+            trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance.
+            pl_module: the current :class:`~pytorch_lightning.core.lightning.LightningModule` instance.
+            callback_state: the callback state returned by ``on_save_checkpoint``.
+
+        Note:
+            The ``on_load_checkpoint`` won't be called with an undefined state.
+            If your ``on_load_checkpoint`` hook behavior doesn't rely on a state,
+            you will still need to override ``on_save_checkpoint`` to return a ``dummy state``.
+
+        .. deprecated:: v1.6
+            This callback hook will change its signature and behavior in v1.8.
+            If you wish to load the state of the callback, use ``Callback.load_state_dict`` instead.
+            In v1.8 ``Callback.on_load_checkpoint(checkpoint)`` will receive the entire loaded
+            checkpoint dictionary instead of only the callback state from the checkpoint.
+        """
+
+    def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", loss: torch.Tensor) -> None:
+        """Called before ``loss.backward()``."""
+
+    def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Called after ``loss.backward()`` and before optimizers are stepped."""
+
+    def on_before_optimizer_step(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer, opt_idx: int
+    ) -> None:
+        """Called before ``optimizer.step()``."""
+
+    def on_before_zero_grad(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer) -> None:
+        """Called before ``optimizer.zero_grad()``."""
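
The Callback base class above consists entirely of no-op hooks plus the state_dict/load_state_dict pair used for checkpointing. A minimal stateful subclass sketch based on those signatures (the counter itself is illustrative, not part of this diff):

from typing import Any, Dict

from pytorch_lightning.callbacks import Callback


class BatchCounter(Callback):
    """Counts train batches and persists the count across checkpoint save/load."""

    def __init__(self) -> None:
        self.batches_seen = 0

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, unused=0) -> None:
        # Cache data across batch hooks, as suggested by the on_train_epoch_end docstring.
        self.batches_seen += 1

    def state_dict(self) -> Dict[str, Any]:
        # Stored under checkpoint["callbacks"][self.state_key].
        return {"batches_seen": self.batches_seen}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.batches_seen = state_dict["batches_seen"]
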
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/device_stats_monitor.py
ADDED
@@ -0,0 +1,104 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Device Stats Monitor
+====================
+
+Monitors and logs device stats during training.
+
+"""
+from typing import Any, Dict, Optional
+
+import pytorch_lightning as pl
+from pytorch_lightning.callbacks.base import Callback
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.types import STEP_OUTPUT
+from pytorch_lightning.utilities.warnings import rank_zero_deprecation
+
+
+class DeviceStatsMonitor(Callback):
+    r"""
+    Automatically monitors and logs device stats during training stage. ``DeviceStatsMonitor``
+    is a special callback as it requires a ``logger`` to passed as argument to the ``Trainer``.
+
+    Raises:
+        MisconfigurationException:
+            If ``Trainer`` has no logger.
+
+    Example:
+        >>> from pytorch_lightning import Trainer
+        >>> from pytorch_lightning.callbacks import DeviceStatsMonitor
+        >>> device_stats = DeviceStatsMonitor()  # doctest: +SKIP
+        >>> trainer = Trainer(callbacks=[device_stats])  # doctest: +SKIP
+    """
+
+    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
+        if not trainer.loggers:
+            raise MisconfigurationException("Cannot use DeviceStatsMonitor callback with Trainer that has no logger.")
+
+    def on_train_batch_start(
+        self,
+        trainer: "pl.Trainer",
+        pl_module: "pl.LightningModule",
+        batch: Any,
+        batch_idx: int,
+        unused: int = 0,
+    ) -> None:
+        if not trainer.loggers:
+            raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.")
+
+        if not trainer._logger_connector.should_update_logs:
+            return
+
+        device = trainer.strategy.root_device
+        device_stats = trainer.accelerator.get_device_stats(device)
+        for logger in trainer.loggers:
+            separator = logger.group_separator
+            prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_start", separator)
+            logger.log_metrics(prefixed_device_stats, step=trainer.fit_loop.epoch_loop._batches_that_stepped)
+
+    def on_train_batch_end(
+        self,
+        trainer: "pl.Trainer",
+        pl_module: "pl.LightningModule",
+        outputs: STEP_OUTPUT,
+        batch: Any,
+        batch_idx: int,
+        unused: int = 0,
+    ) -> None:
+        if not trainer.loggers:
+            raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.")
+
+        if not trainer._logger_connector.should_update_logs:
+            return
+
+        device = trainer.strategy.root_device
+        device_stats = trainer.accelerator.get_device_stats(device)
+        for logger in trainer.loggers:
+            separator = logger.group_separator
+            prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_end", separator)
+            logger.log_metrics(prefixed_device_stats, step=trainer.fit_loop.epoch_loop._batches_that_stepped)
+
+
+def _prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]:
+    return {prefix + separator + k: v for k, v in metrics_dict.items()}
+
+
+def prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str) -> Dict[str, float]:
+    rank_zero_deprecation(
+        "`pytorch_lightning.callbacks.device_stats_monitor.prefix_metrics`"
+        " is deprecated in v1.6 and will be removed in v1.8."
+    )
+    sep = ""
+    return _prefix_metric_keys(metrics_dict, prefix, sep)
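
Both batch hooks above funnel their statistics through the _prefix_metric_keys helper, which namespaces every entry with the hook name and the logger's group separator. A quick illustration with made-up values (the stats dictionary and separator are not part of this diff):

# Illustration only: a hypothetical stats dictionary and "/" separator.
stats = {"gpu_memory_used": 1024.0, "gpu_utilization": 57.0}
print(_prefix_metric_keys(stats, "on_train_batch_end", "/"))
# -> {'on_train_batch_end/gpu_memory_used': 1024.0, 'on_train_batch_end/gpu_utilization': 57.0}
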
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/early_stopping.py
ADDED
@@ -0,0 +1,261 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+r"""
+Early Stopping
+^^^^^^^^^^^^^^
+
+Monitor a metric and stop training when it stops improving.
+
+"""
+import logging
+from typing import Any, Callable, Dict, Optional, Tuple
+
+import numpy as np
+import torch
+
+import pytorch_lightning as pl
+from pytorch_lightning.callbacks.base import Callback
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.rank_zero import rank_zero_warn
+
+log = logging.getLogger(__name__)
+
+
+class EarlyStopping(Callback):
+    r"""
+    Monitor a metric and stop training when it stops improving.
+
+    Args:
+        monitor: quantity to be monitored.
+        min_delta: minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute
+            change of less than or equal to `min_delta`, will count as no improvement.
+        patience: number of checks with no improvement
+            after which training will be stopped. Under the default configuration, one check happens after
+            every training epoch. However, the frequency of validation can be modified by setting various parameters on
+            the ``Trainer``, for example ``check_val_every_n_epoch`` and ``val_check_interval``.
+
+            .. note::
+
+                It must be noted that the patience parameter counts the number of validation checks with
+                no improvement, and not the number of training epochs. Therefore, with parameters
+                ``check_val_every_n_epoch=10`` and ``patience=3``, the trainer will perform at least 40 training
+                epochs before being stopped.
+
+        verbose: verbosity mode.
+        mode: one of ``'min'``, ``'max'``. In ``'min'`` mode, training will stop when the quantity
+            monitored has stopped decreasing and in ``'max'`` mode it will stop when the quantity
+            monitored has stopped increasing.
+        strict: whether to crash the training if `monitor` is not found in the validation metrics.
+        check_finite: When set ``True``, stops training when the monitor becomes NaN or infinite.
+        stopping_threshold: Stop training immediately once the monitored quantity reaches this threshold.
+        divergence_threshold: Stop training as soon as the monitored quantity becomes worse than this threshold.
+        check_on_train_epoch_end: whether to run early stopping at the end of the training epoch.
+            If this is ``False``, then the check runs at the end of the validation.
+
+    Raises:
+        MisconfigurationException:
+            If ``mode`` is none of ``"min"`` or ``"max"``.
+        RuntimeError:
+            If the metric ``monitor`` is not available.
+
+    Example::
+
+        >>> from pytorch_lightning import Trainer
+        >>> from pytorch_lightning.callbacks import EarlyStopping
+        >>> early_stopping = EarlyStopping('val_loss')
+        >>> trainer = Trainer(callbacks=[early_stopping])
+
+    .. tip:: Saving and restoring multiple early stopping callbacks at the same time is supported under variation in the
+        following arguments:
+
+        *monitor, mode*
+
+        Read more: :ref:`Persisting Callback State`
+    """
+    mode_dict = {"min": torch.lt, "max": torch.gt}
+
+    order_dict = {"min": "<", "max": ">"}
+
+    def __init__(
+        self,
+        monitor: str,
+        min_delta: float = 0.0,
+        patience: int = 3,
+        verbose: bool = False,
+        mode: str = "min",
+        strict: bool = True,
+        check_finite: bool = True,
+        stopping_threshold: Optional[float] = None,
+        divergence_threshold: Optional[float] = None,
+        check_on_train_epoch_end: Optional[bool] = None,
+    ):
+        super().__init__()
+        self.monitor = monitor
+        self.min_delta = min_delta
+        self.patience = patience
+        self.verbose = verbose
+        self.mode = mode
+        self.strict = strict
+        self.check_finite = check_finite
+        self.stopping_threshold = stopping_threshold
+        self.divergence_threshold = divergence_threshold
+        self.wait_count = 0
+        self.stopped_epoch = 0
+        self._check_on_train_epoch_end = check_on_train_epoch_end
+
+        if self.mode not in self.mode_dict:
+            raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")
+
+        self.min_delta *= 1 if self.monitor_op == torch.gt else -1
+        torch_inf = torch.tensor(np.Inf)
+        self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf
+
+    @property
+    def state_key(self) -> str:
+        return self._generate_state_key(monitor=self.monitor, mode=self.mode)
+
+    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
+        if self._check_on_train_epoch_end is None:
+            # if the user runs validation multiple times per training epoch or multiple training epochs without
+            # validation, then we run after validation instead of on train epoch end
+            self._check_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1
+
+    def _validate_condition_metric(self, logs: Dict[str, float]) -> bool:
+        monitor_val = logs.get(self.monitor)
+
+        error_msg = (
+            f"Early stopping conditioned on metric `{self.monitor}` which is not available."
+            " Pass in or modify your `EarlyStopping` callback to use any of the following:"
+            f' `{"`, `".join(list(logs.keys()))}`'
+        )
+
+        if monitor_val is None:
+            if self.strict:
+                raise RuntimeError(error_msg)
+            if self.verbose > 0:
+                rank_zero_warn(error_msg, category=RuntimeWarning)
+
+            return False
+
+        return True
+
+    @property
+    def monitor_op(self) -> Callable:
+        return self.mode_dict[self.mode]
+
+    def state_dict(self) -> Dict[str, Any]:
+        return {
+            "wait_count": self.wait_count,
+            "stopped_epoch": self.stopped_epoch,
+            "best_score": self.best_score,
+            "patience": self.patience,
+        }
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        self.wait_count = state_dict["wait_count"]
+        self.stopped_epoch = state_dict["stopped_epoch"]
+        self.best_score = state_dict["best_score"]
+        self.patience = state_dict["patience"]
+
+    def _should_skip_check(self, trainer: "pl.Trainer") -> bool:
+        from pytorch_lightning.trainer.states import TrainerFn
+
+        return trainer.state.fn != TrainerFn.FITTING or trainer.sanity_checking
+
+    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        if not self._check_on_train_epoch_end or self._should_skip_check(trainer):
+            return
+        self._run_early_stopping_check(trainer)
+
+    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        if self._check_on_train_epoch_end or self._should_skip_check(trainer):
+            return
+        self._run_early_stopping_check(trainer)
+
+    def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
+        """Checks whether the early stopping condition is met and if so tells the trainer to stop the training."""
+        logs = trainer.callback_metrics
+
+        if trainer.fast_dev_run or not self._validate_condition_metric(  # disable early_stopping with fast_dev_run
+            logs
+        ):  # short circuit if metric not present
+            return
+
+        current = logs[self.monitor].squeeze()
+        should_stop, reason = self._evaluate_stopping_criteria(current)
+
+        # stop every ddp process if any world process decides to stop
+        should_stop = trainer.strategy.reduce_boolean_decision(should_stop)
+        trainer.should_stop = trainer.should_stop or should_stop
+        if should_stop:
+            self.stopped_epoch = trainer.current_epoch
+        if reason and self.verbose:
+            self._log_info(trainer, reason)
+
+    def _evaluate_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, Optional[str]]:
+        should_stop = False
+        reason = None
+        if self.check_finite and not torch.isfinite(current):
+            should_stop = True
+            reason = (
+                f"Monitored metric {self.monitor} = {current} is not finite."
+                f" Previous best value was {self.best_score:.3f}. Signaling Trainer to stop."
+            )
+        elif self.stopping_threshold is not None and self.monitor_op(current, self.stopping_threshold):
+            should_stop = True
+            reason = (
+                "Stopping threshold reached:"
+                f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.stopping_threshold}."
+                " Signaling Trainer to stop."
+            )
+        elif self.divergence_threshold is not None and self.monitor_op(-current, -self.divergence_threshold):
+            should_stop = True
+            reason = (
+                "Divergence threshold reached:"
+                f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.divergence_threshold}."
+                " Signaling Trainer to stop."
+            )
+        elif self.monitor_op(current - self.min_delta, self.best_score.to(current.device)):
+            should_stop = False
+            reason = self._improvement_message(current)
+            self.best_score = current
+            self.wait_count = 0
+        else:
+            self.wait_count += 1
+            if self.wait_count >= self.patience:
+                should_stop = True
+                reason = (
+                    f"Monitored metric {self.monitor} did not improve in the last {self.wait_count} records."
+                    f" Best score: {self.best_score:.3f}. Signaling Trainer to stop."
+                )
+
+        return should_stop, reason
+
+    def _improvement_message(self, current: torch.Tensor) -> str:
+        """Formats a log message that informs the user about an improvement in the monitored score."""
+        if torch.isfinite(self.best_score):
+            msg = (
+                f"Metric {self.monitor} improved by {abs(self.best_score - current):.3f} >="
+                f" min_delta = {abs(self.min_delta)}. New best score: {current:.3f}"
+            )
+        else:
+            msg = f"Metric {self.monitor} improved. New best score: {current:.3f}"
+        return msg
+
+    @staticmethod
+    def _log_info(trainer: Optional["pl.Trainer"], message: str) -> None:
+        if trainer is not None and trainer.world_size > 1:
+            log.info(f"[rank: {trainer.global_rank}] {message}")
+        else:
+            log.info(message)
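
The stopping logic above checks, in order, non-finite values, the stopping threshold, the divergence threshold, and finally the patience counter. A configuration sketch exercising those knobs (the metric name is an assumption, not part of this diff):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

early_stop = EarlyStopping(
    monitor="val_accuracy",     # assumed metric logged by the LightningModule
    mode="max",                 # higher is better, so torch.gt becomes monitor_op
    patience=5,                 # tolerate 5 validation checks without improvement
    min_delta=0.001,            # smaller improvements do not reset the patience counter
    stopping_threshold=0.99,    # stop immediately once the metric reaches 0.99
    divergence_threshold=0.05,  # stop if the metric collapses below 0.05
)
trainer = Trainer(callbacks=[early_stop])
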
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/finetuning.py
ADDED
@@ -0,0 +1,417 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+r"""
+Finetuning Callback
+^^^^^^^^^^^^^^^^^^^^
+Freeze and unfreeze models for finetuning purposes
+"""
+import logging
+from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union
+
+import torch
+from torch.nn import Module, ModuleDict
+from torch.nn.modules.batchnorm import _BatchNorm
+from torch.optim.optimizer import Optimizer
+
+import pytorch_lightning as pl
+from pytorch_lightning.callbacks.base import Callback
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.rank_zero import rank_zero_warn
+
+log = logging.getLogger(__name__)
+
+
+def multiplicative(epoch):
+    return 2
+
+
+class BaseFinetuning(Callback):
+    r"""
+    This class implements the base logic for writing your own Finetuning Callback.
+
+    Override ``freeze_before_training`` and ``finetune_function`` methods with your own logic.
+
+    ``freeze_before_training``: This method is called before ``configure_optimizers``
+        and should be used to freeze any modules parameters.
+
+    ``finetune_function``: This method is called on every train epoch start and should be used to
+        ``unfreeze`` any parameters. Those parameters needs to be added in a new ``param_group``
+        within the optimizer.
+
+    .. note:: Make sure to filter the parameters based on ``requires_grad``.
+
+    Example::
+
+        >>> from torch.optim import Adam
+        >>> class MyModel(pl.LightningModule):
+        ...     def configure_optimizer(self):
+        ...         # Make sure to filter the parameters based on `requires_grad`
+        ...         return Adam(filter(lambda p: p.requires_grad, self.parameters()))
+        ...
+        >>> class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
+        ...     def __init__(self, unfreeze_at_epoch=10):
+        ...         super().__init__()
+        ...         self._unfreeze_at_epoch = unfreeze_at_epoch
+        ...
+        ...     def freeze_before_training(self, pl_module):
+        ...         # freeze any module you want
+        ...         # Here, we are freezing `feature_extractor`
+        ...         self.freeze(pl_module.feature_extractor)
+        ...
+        ...     def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx):
+        ...         # When `current_epoch` is 10, feature_extractor will start training.
+        ...         if current_epoch == self._unfreeze_at_epoch:
+        ...             self.unfreeze_and_add_param_group(
+        ...                 modules=pl_module.feature_extractor,
+        ...                 optimizer=optimizer,
+        ...                 train_bn=True,
+        ...             )
+    """
+
+    def __init__(self):
+        self._internal_optimizer_metadata: Dict[int, List[Dict[str, Any]]] = {}
+        self._restarting = False
+
+    def state_dict(self) -> Dict[str, Any]:
+        return {
+            "internal_optimizer_metadata": self._internal_optimizer_metadata,
+        }
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        self._restarting = True
+        if "internal_optimizer_metadata" in state_dict:
+            self._internal_optimizer_metadata = state_dict["internal_optimizer_metadata"]
+        else:
+            # compatibility to load from old checkpoints before PR #11887
+            self._internal_optimizer_metadata = state_dict
+
+    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        # restore the param_groups created during the previous training.
+        if self._restarting:
+            named_parameters = dict(pl_module.named_parameters())
+            for opt_idx, optimizer in enumerate(trainer.optimizers):
+                param_groups = self._apply_mapping_to_param_groups(
+                    self._internal_optimizer_metadata[opt_idx], named_parameters
+                )
+                optimizer.param_groups = param_groups
+            self._restarting = False
+
+    @staticmethod
+    def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]:
+        """This function is used to flatten a module or an iterable of modules into a list of its leaf modules
+        (modules with no children) and parent modules that have parameters directly themselves.
+
+        Args:
+            modules: A given module or an iterable of modules
+
+        Returns:
+            List of modules
+        """
+        if isinstance(modules, ModuleDict):
+            modules = modules.values()
+
+        if isinstance(modules, Iterable):
+            _modules = []
+            for m in modules:
+                _modules.extend(BaseFinetuning.flatten_modules(m))
+
+        else:
+            _modules = modules.modules()
+
+        # Capture all leaf modules as well as parent modules that have parameters directly themselves
+        return [m for m in _modules if not list(m.children()) or m._parameters]
+
+    @staticmethod
+    def filter_params(
+        modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn: bool = True, requires_grad: bool = True
+    ) -> Generator:
+        """Yields the `requires_grad` parameters of a given module or list of modules.
+
+        Args:
+            modules: A given module or an iterable of modules
+            train_bn: Whether to train BatchNorm module
+            requires_grad: Whether to create a generator for trainable or non-trainable parameters.
+        Returns:
+            Generator
+        """
+        modules = BaseFinetuning.flatten_modules(modules)
+        for mod in modules:
+            if isinstance(mod, _BatchNorm) and not train_bn:
+                continue
+            # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
+            for param in mod.parameters(recurse=False):
+                if param.requires_grad == requires_grad:
+                    yield param
+
+    @staticmethod
+    def make_trainable(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> None:
+        """Unfreezes the parameters of the provided modules.
+
+        Args:
+            modules: A given module or an iterable of modules
+        """
+        modules = BaseFinetuning.flatten_modules(modules)
+        for module in modules:
+            # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
+            for param in module.parameters(recurse=False):
+                param.requires_grad = True
+
+    @staticmethod
+    def freeze(modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn: bool = True) -> None:
+        """Freezes the parameters of the provided modules.
+
+        Args:
+            modules: A given module or an iterable of modules
+            train_bn: If True, leave the BatchNorm layers in training mode
+
+        Returns:
+            None
+        """
+        modules = BaseFinetuning.flatten_modules(modules)
+        for mod in modules:
|
| 183 |
+
if isinstance(mod, _BatchNorm) and train_bn:
|
| 184 |
+
BaseFinetuning.make_trainable(mod)
|
| 185 |
+
else:
|
| 186 |
+
# recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
|
| 187 |
+
for param in mod.parameters(recurse=False):
|
| 188 |
+
param.requires_grad = False
|
| 189 |
+
|
| 190 |
+
@staticmethod
|
| 191 |
+
def filter_on_optimizer(optimizer: Optimizer, params: Iterable) -> List:
|
| 192 |
+
"""This function is used to exclude any parameter which already exists in this optimizer.
|
| 193 |
+
|
| 194 |
+
Args:
|
| 195 |
+
optimizer: Optimizer used for parameter exclusion
|
| 196 |
+
params: Iterable of parameters used to check against the provided optimizer
|
| 197 |
+
|
| 198 |
+
Returns:
|
| 199 |
+
List of parameters not contained in this optimizer param groups
|
| 200 |
+
"""
|
| 201 |
+
out_params = []
|
| 202 |
+
removed_params = []
|
| 203 |
+
for param in params:
|
| 204 |
+
if not any(torch.equal(p, param) for group in optimizer.param_groups for p in group["params"]):
|
| 205 |
+
out_params.append(param)
|
| 206 |
+
else:
|
| 207 |
+
removed_params.append(param)
|
| 208 |
+
|
| 209 |
+
if removed_params:
|
| 210 |
+
rank_zero_warn(
|
| 211 |
+
"The provided params to be frozen already exist within another group of this optimizer."
|
| 212 |
+
" Those parameters will be skipped.\n"
|
| 213 |
+
"HINT: Did you init your optimizer in `configure_optimizer` as such:\n"
|
| 214 |
+
f" {type(optimizer)}(filter(lambda p: p.requires_grad, self.parameters()), ...) ",
|
| 215 |
+
)
|
| 216 |
+
return out_params
|
| 217 |
+
|
| 218 |
+
@staticmethod
|
| 219 |
+
def unfreeze_and_add_param_group(
|
| 220 |
+
modules: Union[Module, Iterable[Union[Module, Iterable]]],
|
| 221 |
+
optimizer: Optimizer,
|
| 222 |
+
lr: Optional[float] = None,
|
| 223 |
+
initial_denom_lr: float = 10.0,
|
| 224 |
+
train_bn: bool = True,
|
| 225 |
+
) -> None:
|
| 226 |
+
"""Unfreezes a module and adds its parameters to an optimizer.
|
| 227 |
+
|
| 228 |
+
Args:
|
| 229 |
+
modules: A module or iterable of modules to unfreeze.
|
| 230 |
+
Their parameters will be added to an optimizer as a new param group.
|
| 231 |
+
optimizer: The provided optimizer will receive new parameters and will add them to
|
| 232 |
+
`add_param_group`
|
| 233 |
+
lr: Learning rate for the new param group.
|
| 234 |
+
initial_denom_lr: If no lr is provided, the learning from the first param group will be used
|
| 235 |
+
and divided by `initial_denom_lr`.
|
| 236 |
+
train_bn: Whether to train the BatchNormalization layers.
|
| 237 |
+
"""
|
| 238 |
+
BaseFinetuning.make_trainable(modules)
|
| 239 |
+
params_lr = optimizer.param_groups[0]["lr"] if lr is None else float(lr)
|
| 240 |
+
denom_lr = initial_denom_lr if lr is None else 1.0
|
| 241 |
+
params = BaseFinetuning.filter_params(modules, train_bn=train_bn, requires_grad=True)
|
| 242 |
+
params = BaseFinetuning.filter_on_optimizer(optimizer, params)
|
| 243 |
+
if params:
|
| 244 |
+
optimizer.add_param_group({"params": params, "lr": params_lr / denom_lr})
|
| 245 |
+
|
| 246 |
+
def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
|
| 247 |
+
self.freeze_before_training(pl_module)
|
| 248 |
+
|
| 249 |
+
@staticmethod
|
| 250 |
+
def _apply_mapping_to_param_groups(param_groups: List[Dict[str, Any]], mapping: dict) -> List[Dict[str, Any]]:
|
| 251 |
+
output = []
|
| 252 |
+
for g in param_groups:
|
| 253 |
+
# skip params to save memory
|
| 254 |
+
group_state = {k: v for k, v in g.items() if k != "params"}
|
| 255 |
+
group_state["params"] = [mapping[p] for p in g["params"]]
|
| 256 |
+
output.append(group_state)
|
| 257 |
+
return output
|
| 258 |
+
|
| 259 |
+
def _store(
|
| 260 |
+
self,
|
| 261 |
+
pl_module: "pl.LightningModule",
|
| 262 |
+
opt_idx: int,
|
| 263 |
+
num_param_groups: int,
|
| 264 |
+
current_param_groups: List[Dict[str, Any]],
|
| 265 |
+
) -> None:
|
| 266 |
+
mapping = {p: n for n, p in pl_module.named_parameters()}
|
| 267 |
+
if opt_idx not in self._internal_optimizer_metadata:
|
| 268 |
+
self._internal_optimizer_metadata[opt_idx] = self._apply_mapping_to_param_groups(
|
| 269 |
+
current_param_groups, mapping
|
| 270 |
+
)
|
| 271 |
+
elif num_param_groups != len(current_param_groups):
|
| 272 |
+
# save new param_groups possibly created by the users.
|
| 273 |
+
self._internal_optimizer_metadata[opt_idx].extend(
|
| 274 |
+
self._apply_mapping_to_param_groups(current_param_groups[num_param_groups:], mapping)
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
| 278 |
+
"""Called when the epoch begins."""
|
| 279 |
+
# import is here to avoid circular imports
|
| 280 |
+
from pytorch_lightning.loops.utilities import _get_active_optimizers
|
| 281 |
+
|
| 282 |
+
for opt_idx, optimizer in _get_active_optimizers(trainer.optimizers, trainer.optimizer_frequencies):
|
| 283 |
+
num_param_groups = len(optimizer.param_groups)
|
| 284 |
+
self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx)
|
| 285 |
+
current_param_groups = optimizer.param_groups
|
| 286 |
+
self._store(pl_module, opt_idx, num_param_groups, current_param_groups)
|
| 287 |
+
|
| 288 |
+
def finetune_function(
|
| 289 |
+
self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer, opt_idx: int
|
| 290 |
+
) -> None:
|
| 291 |
+
"""Override to add your unfreeze logic."""
|
| 292 |
+
raise NotImplementedError
|
| 293 |
+
|
| 294 |
+
def freeze_before_training(self, pl_module: "pl.LightningModule") -> None:
|
| 295 |
+
"""Override to add your freeze logic."""
|
| 296 |
+
raise NotImplementedError
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
class BackboneFinetuning(BaseFinetuning):
|
| 300 |
+
r"""Finetune a backbone model based on a learning rate user-defined scheduling.
|
| 301 |
+
|
| 302 |
+
When the backbone learning rate reaches the current model learning rate
|
| 303 |
+
and ``should_align`` is set to True, it will align with it for the rest of the training.
|
| 304 |
+
|
| 305 |
+
Args:
|
| 306 |
+
unfreeze_backbone_at_epoch: Epoch at which the backbone will be unfreezed.
|
| 307 |
+
lambda_func: Scheduling function for increasing backbone learning rate.
|
| 308 |
+
backbone_initial_ratio_lr:
|
| 309 |
+
Used to scale down the backbone learning rate compared to rest of model
|
| 310 |
+
backbone_initial_lr: Optional, Initial learning rate for the backbone.
|
| 311 |
+
By default, we will use ``current_learning / backbone_initial_ratio_lr``
|
| 312 |
+
should_align: Whether to align with current learning rate when backbone learning
|
| 313 |
+
reaches it.
|
| 314 |
+
initial_denom_lr: When unfreezing the backbone, the initial learning rate will
|
| 315 |
+
``current_learning_rate / initial_denom_lr``.
|
| 316 |
+
train_bn: Whether to make Batch Normalization trainable.
|
| 317 |
+
verbose: Display current learning rate for model and backbone
|
| 318 |
+
rounding: Precision for displaying learning rate
|
| 319 |
+
|
| 320 |
+
Example::
|
| 321 |
+
|
| 322 |
+
>>> from pytorch_lightning import Trainer
|
| 323 |
+
>>> from pytorch_lightning.callbacks import BackboneFinetuning
|
| 324 |
+
>>> multiplicative = lambda epoch: 1.5
|
| 325 |
+
>>> backbone_finetuning = BackboneFinetuning(200, multiplicative)
|
| 326 |
+
>>> trainer = Trainer(callbacks=[backbone_finetuning])
|
| 327 |
+
|
| 328 |
+
"""
|
| 329 |
+
|
| 330 |
+
def __init__(
|
| 331 |
+
self,
|
| 332 |
+
unfreeze_backbone_at_epoch: int = 10,
|
| 333 |
+
lambda_func: Callable = multiplicative,
|
| 334 |
+
backbone_initial_ratio_lr: float = 10e-2,
|
| 335 |
+
backbone_initial_lr: Optional[float] = None,
|
| 336 |
+
should_align: bool = True,
|
| 337 |
+
initial_denom_lr: float = 10.0,
|
| 338 |
+
train_bn: bool = True,
|
| 339 |
+
verbose: bool = False,
|
| 340 |
+
rounding: int = 12,
|
| 341 |
+
) -> None:
|
| 342 |
+
super().__init__()
|
| 343 |
+
|
| 344 |
+
self.unfreeze_backbone_at_epoch: int = unfreeze_backbone_at_epoch
|
| 345 |
+
self.lambda_func: Callable = lambda_func
|
| 346 |
+
self.backbone_initial_ratio_lr: float = backbone_initial_ratio_lr
|
| 347 |
+
self.backbone_initial_lr: Optional[float] = backbone_initial_lr
|
| 348 |
+
self.should_align: bool = should_align
|
| 349 |
+
self.initial_denom_lr: float = initial_denom_lr
|
| 350 |
+
self.train_bn: bool = train_bn
|
| 351 |
+
self.verbose: bool = verbose
|
| 352 |
+
self.rounding: int = rounding
|
| 353 |
+
self.previous_backbone_lr: Optional[float] = None
|
| 354 |
+
|
| 355 |
+
def state_dict(self) -> Dict[str, Any]:
|
| 356 |
+
return {
|
| 357 |
+
"internal_optimizer_metadata": self._internal_optimizer_metadata,
|
| 358 |
+
"previous_backbone_lr": self.previous_backbone_lr,
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
|
| 362 |
+
self.previous_backbone_lr = state_dict["previous_backbone_lr"]
|
| 363 |
+
super().load_state_dict(state_dict)
|
| 364 |
+
|
| 365 |
+
def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
| 366 |
+
"""
|
| 367 |
+
Raises:
|
| 368 |
+
MisconfigurationException:
|
| 369 |
+
If LightningModule has no nn.Module `backbone` attribute.
|
| 370 |
+
"""
|
| 371 |
+
if hasattr(pl_module, "backbone") and isinstance(pl_module.backbone, Module):
|
| 372 |
+
return super().on_fit_start(trainer, pl_module)
|
| 373 |
+
raise MisconfigurationException("The LightningModule should have a nn.Module `backbone` attribute")
|
| 374 |
+
|
| 375 |
+
def freeze_before_training(self, pl_module: "pl.LightningModule") -> None:
|
| 376 |
+
self.freeze(pl_module.backbone)
|
| 377 |
+
|
| 378 |
+
def finetune_function(
|
| 379 |
+
self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer, opt_idx: int
|
| 380 |
+
) -> None:
|
| 381 |
+
"""Called when the epoch begins."""
|
| 382 |
+
if epoch == self.unfreeze_backbone_at_epoch:
|
| 383 |
+
current_lr = optimizer.param_groups[0]["lr"]
|
| 384 |
+
initial_backbone_lr = (
|
| 385 |
+
self.backbone_initial_lr
|
| 386 |
+
if self.backbone_initial_lr is not None
|
| 387 |
+
else current_lr * self.backbone_initial_ratio_lr
|
| 388 |
+
)
|
| 389 |
+
self.previous_backbone_lr = initial_backbone_lr
|
| 390 |
+
self.unfreeze_and_add_param_group(
|
| 391 |
+
pl_module.backbone,
|
| 392 |
+
optimizer,
|
| 393 |
+
initial_backbone_lr,
|
| 394 |
+
train_bn=self.train_bn,
|
| 395 |
+
initial_denom_lr=self.initial_denom_lr,
|
| 396 |
+
)
|
| 397 |
+
if self.verbose:
|
| 398 |
+
log.info(
|
| 399 |
+
f"Current lr: {round(current_lr, self.rounding)}, "
|
| 400 |
+
f"Backbone lr: {round(initial_backbone_lr, self.rounding)}"
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
elif epoch > self.unfreeze_backbone_at_epoch:
|
| 404 |
+
current_lr = optimizer.param_groups[0]["lr"]
|
| 405 |
+
next_current_backbone_lr = self.lambda_func(epoch + 1) * self.previous_backbone_lr
|
| 406 |
+
next_current_backbone_lr = (
|
| 407 |
+
current_lr
|
| 408 |
+
if (self.should_align and next_current_backbone_lr > current_lr)
|
| 409 |
+
else next_current_backbone_lr
|
| 410 |
+
)
|
| 411 |
+
optimizer.param_groups[-1]["lr"] = next_current_backbone_lr
|
| 412 |
+
self.previous_backbone_lr = next_current_backbone_lr
|
| 413 |
+
if self.verbose:
|
| 414 |
+
log.info(
|
| 415 |
+
f"Current lr: {round(current_lr, self.rounding)}, "
|
| 416 |
+
f"Backbone lr: {round(next_current_backbone_lr, self.rounding)}"
|
| 417 |
+
)
|
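A minimal usage sketch for the file above (not part of the vendored source; the `LitClassifier` module and its `backbone`/`head` split are hypothetical). `BackboneFinetuning` freezes `pl_module.backbone` before training and unfreezes it into a new optimizer param group at the chosen epoch:

import torch
from torch import nn
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import BackboneFinetuning


class LitClassifier(pl.LightningModule):
    # hypothetical example module: BackboneFinetuning only requires a `backbone` nn.Module attribute
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(32, 64), nn.ReLU())
        self.head = nn.Linear(64, 10)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.cross_entropy(self.head(self.backbone(x)), y)

    def configure_optimizers(self):
        # filter on `requires_grad` so the frozen backbone is excluded from the initial param group
        return torch.optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=1e-3)


# unfreeze the backbone at epoch 5, then grow its learning rate by 1.5x per epoch afterwards
backbone_finetuning = BackboneFinetuning(unfreeze_backbone_at_epoch=5, lambda_func=lambda epoch: 1.5)
trainer = Trainer(callbacks=[backbone_finetuning])
# trainer.fit(LitClassifier(), train_dataloaders=...)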
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/gpu_stats_monitor.py
ADDED
@@ -0,0 +1,262 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
GPU Stats Monitor
=================

Monitor and logs GPU stats during training.

"""

import os
import shutil
import subprocess
import time
from typing import Any, Dict, List, Optional, Tuple

import torch

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only
from pytorch_lightning.utilities.types import STEP_OUTPUT


class GPUStatsMonitor(Callback):
    r"""
    .. deprecated:: v1.5
        The `GPUStatsMonitor` callback was deprecated in v1.5 and will be removed in v1.7.
        Please use the `DeviceStatsMonitor` callback instead.

    Automatically monitors and logs GPU stats during training stage. ``GPUStatsMonitor``
    is a callback and in order to use it you need to assign a logger in the ``Trainer``.

    Args:
        memory_utilization: Set to ``True`` to monitor used, free and percentage of memory
            utilization at the start and end of each step. Default: ``True``.
        gpu_utilization: Set to ``True`` to monitor percentage of GPU utilization
            at the start and end of each step. Default: ``True``.
        intra_step_time: Set to ``True`` to monitor the time of each step. Default: ``False``.
        inter_step_time: Set to ``True`` to monitor the time between the end of one step
            and the start of the next step. Default: ``False``.
        fan_speed: Set to ``True`` to monitor percentage of fan speed. Default: ``False``.
        temperature: Set to ``True`` to monitor the memory and gpu temperature in degree Celsius.
            Default: ``False``.

    Raises:
        MisconfigurationException:
            If NVIDIA driver is not installed, not running on GPUs, or ``Trainer`` has no logger.

    Example::

        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.callbacks import GPUStatsMonitor
        >>> gpu_stats = GPUStatsMonitor() # doctest: +SKIP
        >>> trainer = Trainer(callbacks=[gpu_stats]) # doctest: +SKIP

    GPU stats are mainly based on `nvidia-smi --query-gpu` command. The description of the queries is as follows:

    - **fan.speed** – The fan speed value is the percent of maximum speed that the device's fan is currently
      intended to run at. It ranges from 0 to 100 %. Note: The reported speed is the intended fan speed.
      If the fan is physically blocked and unable to spin, this output will not match the actual fan speed.
      Many parts do not report fan speeds because they rely on cooling via fans in the surrounding enclosure.
    - **memory.used** – Total memory allocated by active contexts.
    - **memory.free** – Total free memory.
    - **utilization.gpu** – Percent of time over the past sample period during which one or more kernels was
      executing on the GPU. The sample period may be between 1 second and 1/6 second depending on the product.
    - **utilization.memory** – Percent of time over the past sample period during which global (device) memory was
      being read or written. The sample period may be between 1 second and 1/6 second depending on the product.
    - **temperature.gpu** – Core GPU temperature, in degrees C.
    - **temperature.memory** – HBM memory temperature, in degrees C.

    """

    def __init__(
        self,
        memory_utilization: bool = True,
        gpu_utilization: bool = True,
        intra_step_time: bool = False,
        inter_step_time: bool = False,
        fan_speed: bool = False,
        temperature: bool = False,
    ):
        super().__init__()

        rank_zero_deprecation(
            "The `GPUStatsMonitor` callback was deprecated in v1.5 and will be removed in v1.7."
            " Please use the `DeviceStatsMonitor` callback instead."
        )

        if shutil.which("nvidia-smi") is None:
            raise MisconfigurationException(
                "Cannot use GPUStatsMonitor callback because NVIDIA driver is not installed."
            )

        self._log_stats = AttributeDict(
            {
                "memory_utilization": memory_utilization,
                "gpu_utilization": gpu_utilization,
                "intra_step_time": intra_step_time,
                "inter_step_time": inter_step_time,
                "fan_speed": fan_speed,
                "temperature": temperature,
            }
        )

        # The logical device IDs for selected devices
        self._device_ids: List[int] = []  # will be assigned later in setup()

        # The unmasked real GPU IDs
        self._gpu_ids: List[str] = []  # will be assigned later in setup()

    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
        if not trainer.loggers:
            raise MisconfigurationException("Cannot use GPUStatsMonitor callback with Trainer that has no logger.")

        if trainer.strategy.root_device.type != "cuda":
            raise MisconfigurationException(
                "You are using GPUStatsMonitor but are not running on GPU."
                f" The root device type is {trainer.strategy.root_device.type}."
            )

        # The logical device IDs for selected devices
        self._device_ids = sorted(set(trainer.device_ids))

        # The unmasked real GPU IDs
        self._gpu_ids = self._get_gpu_ids(self._device_ids)

    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self._snap_intra_step_time: Optional[float] = None
        self._snap_inter_step_time: Optional[float] = None

    @rank_zero_only
    def on_train_batch_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int
    ) -> None:
        if self._log_stats.intra_step_time:
            self._snap_intra_step_time = time.time()

        if not trainer._logger_connector.should_update_logs:
            return

        gpu_stat_keys = self._get_gpu_stat_keys()
        gpu_stats = self._get_gpu_stats([k for k, _ in gpu_stat_keys])
        logs = self._parse_gpu_stats(self._device_ids, gpu_stats, gpu_stat_keys)

        if self._log_stats.inter_step_time and self._snap_inter_step_time:
            # First log at beginning of second step
            logs["batch_time/inter_step (ms)"] = (time.time() - self._snap_inter_step_time) * 1000

        for logger in trainer.loggers:
            logger.log_metrics(logs, step=trainer.fit_loop.epoch_loop._batches_that_stepped)

    @rank_zero_only
    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
    ) -> None:
        if self._log_stats.inter_step_time:
            self._snap_inter_step_time = time.time()

        if not trainer._logger_connector.should_update_logs:
            return

        gpu_stat_keys = self._get_gpu_stat_keys() + self._get_gpu_device_stat_keys()
        gpu_stats = self._get_gpu_stats([k for k, _ in gpu_stat_keys])
        logs = self._parse_gpu_stats(self._device_ids, gpu_stats, gpu_stat_keys)

        if self._log_stats.intra_step_time and self._snap_intra_step_time:
            logs["batch_time/intra_step (ms)"] = (time.time() - self._snap_intra_step_time) * 1000

        for logger in trainer.loggers:
            logger.log_metrics(logs, step=trainer.fit_loop.epoch_loop._batches_that_stepped)

    @staticmethod
    def _get_gpu_ids(device_ids: List[int]) -> List[str]:
        """Get the unmasked real GPU IDs."""
        # All devices if `CUDA_VISIBLE_DEVICES` unset
        default = ",".join(str(i) for i in range(torch.cuda.device_count()))
        cuda_visible_devices: List[str] = os.getenv("CUDA_VISIBLE_DEVICES", default=default).split(",")
        return [cuda_visible_devices[device_id].strip() for device_id in device_ids]

    def _get_gpu_stats(self, queries: List[str]) -> List[List[float]]:
        if not queries:
            return []

        """Run nvidia-smi to get the gpu stats"""
        gpu_query = ",".join(queries)
        format = "csv,nounits,noheader"
        gpu_ids = ",".join(self._gpu_ids)
        result = subprocess.run(
            [
                # it's ok to suppress the warning here since we ensure nvidia-smi exists during init
                shutil.which("nvidia-smi"),  # type: ignore
                f"--query-gpu={gpu_query}",
                f"--format={format}",
                f"--id={gpu_ids}",
            ],
            encoding="utf-8",
            capture_output=True,
            check=True,
        )

        def _to_float(x: str) -> float:
            try:
                return float(x)
            except ValueError:
                return 0.0

        stats = [[_to_float(x) for x in s.split(", ")] for s in result.stdout.strip().split(os.linesep)]
        return stats

    @staticmethod
    def _parse_gpu_stats(
        device_ids: List[int], stats: List[List[float]], keys: List[Tuple[str, str]]
    ) -> Dict[str, float]:
        """Parse the gpu stats into a loggable dict."""
        logs = {}
        for i, device_id in enumerate(device_ids):
            for j, (x, unit) in enumerate(keys):
                logs[f"device_id: {device_id}/{x} ({unit})"] = stats[i][j]
        return logs

    def _get_gpu_stat_keys(self) -> List[Tuple[str, str]]:
        """Get the GPU stats keys."""
        stat_keys = []

        if self._log_stats.gpu_utilization:
            stat_keys.append(("utilization.gpu", "%"))

        if self._log_stats.memory_utilization:
            stat_keys.extend([("memory.used", "MB"), ("memory.free", "MB"), ("utilization.memory", "%")])

        return stat_keys

    def _get_gpu_device_stat_keys(self) -> List[Tuple[str, str]]:
        """Get the device stats keys."""
        stat_keys = []

        if self._log_stats.fan_speed:
            stat_keys.append(("fan.speed", "%"))

        if self._log_stats.temperature:
            stat_keys.extend([("temperature.gpu", "°C"), ("temperature.memory", "°C")])

        return stat_keys
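For illustration only (made-up numbers, not part of the vendored file): the dictionary produced by `GPUStatsMonitor._parse_gpu_stats` keys each value by device id, query name and unit, which is exactly what gets passed to the loggers:

# one row of nvidia-smi output per selected device, one column per query
device_ids = [0, 1]
keys = [("utilization.gpu", "%"), ("memory.used", "MB")]
stats = [[63.0, 4100.0], [12.0, 900.0]]  # hypothetical sampled values

logs = {
    f"device_id: {device_id}/{query} ({unit})": stats[i][j]
    for i, device_id in enumerate(device_ids)
    for j, (query, unit) in enumerate(keys)
}
# {'device_id: 0/utilization.gpu (%)': 63.0, 'device_id: 0/memory.used (MB)': 4100.0,
#  'device_id: 1/utilization.gpu (%)': 12.0, 'device_id: 1/memory.used (MB)': 900.0}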
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/lambda_function.py
ADDED
@@ -0,0 +1,96 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Lambda Callback
^^^^^^^^^^^^^^^

Create a simple callback on the fly using lambda functions.

"""

from typing import Callable, Optional

from pytorch_lightning.callbacks.base import Callback


class LambdaCallback(Callback):
    r"""
    Create a simple callback on the fly using lambda functions.

    Args:
        **kwargs: hooks supported by :class:`~pytorch_lightning.callbacks.base.Callback`

    Example::

        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.callbacks import LambdaCallback
        >>> trainer = Trainer(callbacks=[LambdaCallback(setup=lambda *args: print('setup'))])
    """

    def __init__(
        self,
        on_before_accelerator_backend_setup: Optional[Callable] = None,
        setup: Optional[Callable] = None,
        on_configure_sharded_model: Optional[Callable] = None,
        teardown: Optional[Callable] = None,
        on_init_start: Optional[Callable] = None,
        on_init_end: Optional[Callable] = None,
        on_fit_start: Optional[Callable] = None,
        on_fit_end: Optional[Callable] = None,
        on_sanity_check_start: Optional[Callable] = None,
        on_sanity_check_end: Optional[Callable] = None,
        on_train_batch_start: Optional[Callable] = None,
        on_train_batch_end: Optional[Callable] = None,
        on_train_epoch_start: Optional[Callable] = None,
        on_train_epoch_end: Optional[Callable] = None,
        on_validation_epoch_start: Optional[Callable] = None,
        on_validation_epoch_end: Optional[Callable] = None,
        on_test_epoch_start: Optional[Callable] = None,
        on_test_epoch_end: Optional[Callable] = None,
        on_epoch_start: Optional[Callable] = None,
        on_epoch_end: Optional[Callable] = None,
        on_batch_start: Optional[Callable] = None,
        on_validation_batch_start: Optional[Callable] = None,
        on_validation_batch_end: Optional[Callable] = None,
        on_test_batch_start: Optional[Callable] = None,
        on_test_batch_end: Optional[Callable] = None,
        on_batch_end: Optional[Callable] = None,
        on_train_start: Optional[Callable] = None,
        on_train_end: Optional[Callable] = None,
        on_pretrain_routine_start: Optional[Callable] = None,
        on_pretrain_routine_end: Optional[Callable] = None,
        on_validation_start: Optional[Callable] = None,
        on_validation_end: Optional[Callable] = None,
        on_test_start: Optional[Callable] = None,
        on_test_end: Optional[Callable] = None,
        on_keyboard_interrupt: Optional[Callable] = None,
        on_exception: Optional[Callable] = None,
        on_save_checkpoint: Optional[Callable] = None,
        on_load_checkpoint: Optional[Callable] = None,
        on_before_backward: Optional[Callable] = None,
        on_after_backward: Optional[Callable] = None,
        on_before_optimizer_step: Optional[Callable] = None,
        on_before_zero_grad: Optional[Callable] = None,
        on_predict_start: Optional[Callable] = None,
        on_predict_end: Optional[Callable] = None,
        on_predict_batch_start: Optional[Callable] = None,
        on_predict_batch_end: Optional[Callable] = None,
        on_predict_epoch_start: Optional[Callable] = None,
        on_predict_epoch_end: Optional[Callable] = None,
    ):
        for k, v in locals().items():
            if k == "self":
                continue
            if v is not None:
                setattr(self, k, v)
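A short usage sketch (not part of the vendored file; the hook bodies are illustrative): `LambdaCallback` simply assigns each provided callable as the hook of the same name, so ad-hoc behaviour can be attached to several hooks without subclassing `Callback`:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LambdaCallback

# each keyword matches a Callback hook name; the lambdas receive that hook's arguments
debug_callback = LambdaCallback(
    on_fit_start=lambda trainer, pl_module: print("fit is starting"),
    on_train_epoch_start=lambda trainer, pl_module: print(f"epoch {trainer.current_epoch} begins"),
    on_fit_end=lambda trainer, pl_module: print("fit is done"),
)
trainer = Trainer(callbacks=[debug_callback])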
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/lr_monitor.py
ADDED
@@ -0,0 +1,354 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""

Learning Rate Monitor
=====================

Monitor and logs learning rate for lr schedulers during training.

"""
import itertools
from collections import defaultdict
from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple, Type

from torch.optim.optimizer import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.types import LRSchedulerConfig


class LearningRateMonitor(Callback):
    r"""
    Automatically monitor and logs learning rate for learning rate schedulers during training.

    Args:
        logging_interval: set to ``'epoch'`` or ``'step'`` to log ``lr`` of all optimizers
            at the same interval, set to ``None`` to log at individual interval
            according to the ``interval`` key of each scheduler. Defaults to ``None``.
        log_momentum: option to also log the momentum values of the optimizer, if the optimizer
            has the ``momentum`` or ``betas`` attribute. Defaults to ``False``.

    Raises:
        MisconfigurationException:
            If ``logging_interval`` is none of ``"step"``, ``"epoch"``, or ``None``.

    Example::

        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.callbacks import LearningRateMonitor
        >>> lr_monitor = LearningRateMonitor(logging_interval='step')
        >>> trainer = Trainer(callbacks=[lr_monitor])

    Logging names are automatically determined based on optimizer class name.
    In case of multiple optimizers of same type, they will be named ``Adam``,
    ``Adam-1`` etc. If a optimizer has multiple parameter groups they will
    be named ``Adam/pg1``, ``Adam/pg2`` etc. To control naming, pass in a
    ``name`` keyword in the construction of the learning rate schedulers.
    A ``name`` keyword can also be used for parameter groups in the
    construction of the optimizer.

    Example::

        def configure_optimizer(self):
            optimizer = torch.optim.Adam(...)
            lr_scheduler = {
                'scheduler': torch.optim.lr_scheduler.LambdaLR(optimizer, ...)
                'name': 'my_logging_name'
            }
            return [optimizer], [lr_scheduler]

    Example::

        def configure_optimizer(self):
            optimizer = torch.optim.SGD(
                [{
                    'params': [p for p in self.parameters()],
                    'name': 'my_parameter_group_name'
                }],
                lr=0.1
            )
            lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, ...)
            return [optimizer], [lr_scheduler]

    """

    def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = False) -> None:
        if logging_interval not in (None, "step", "epoch"):
            raise MisconfigurationException("logging_interval should be `step` or `epoch` or `None`.")

        self.logging_interval = logging_interval
        self.log_momentum = log_momentum
        self.lrs: Dict[str, List[float]] = {}
        self._lr_sch_names: List[str] = []

    def on_train_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
        """Called before training, determines unique names for all lr schedulers in the case of multiple of the
        same type or in the case of multiple parameter groups.

        Raises:
            MisconfigurationException:
                If ``Trainer`` has no ``logger``.
        """
        if not trainer.loggers:
            raise MisconfigurationException(
                "Cannot use `LearningRateMonitor` callback with `Trainer` that has no logger."
            )

        if self.log_momentum:

            def _check_no_key(key: str) -> bool:
                if trainer.lr_scheduler_configs:
                    return any(
                        key not in config.scheduler.optimizer.defaults for config in trainer.lr_scheduler_configs
                    )

                return any(key not in optimizer.defaults for optimizer in trainer.optimizers)

            if _check_no_key("momentum") and _check_no_key("betas"):
                rank_zero_warn(
                    "You have set log_momentum=True, but some optimizers do not"
                    " have momentum. This will log a value 0 for the momentum.",
                    category=RuntimeWarning,
                )

        # Find names for schedulers
        names: List[List[str]] = []
        (
            sched_hparam_keys,
            optimizers_with_scheduler,
            optimizers_with_scheduler_types,
        ) = self._find_names_from_schedulers(trainer.lr_scheduler_configs)
        names.extend(sched_hparam_keys)

        # Find names for leftover optimizers
        optimizer_hparam_keys, _ = self._find_names_from_optimizers(
            trainer.optimizers,
            seen_optimizers=optimizers_with_scheduler,
            seen_optimizer_types=optimizers_with_scheduler_types,
        )
        names.extend(optimizer_hparam_keys)

        # Initialize for storing values
        names_flatten = list(itertools.chain.from_iterable(names))
        self.lrs = {name: [] for name in names_flatten}
        self.last_momentum_values = {name + "-momentum": None for name in names_flatten}

    def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
        if not trainer._logger_connector.should_update_logs:
            return

        if self.logging_interval != "epoch":
            interval = "step" if self.logging_interval is None else "any"
            latest_stat = self._extract_stats(trainer, interval)

            if latest_stat:
                for logger in trainer.loggers:
                    logger.log_metrics(latest_stat, step=trainer.fit_loop.epoch_loop._batches_that_stepped)

    def on_train_epoch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
        if self.logging_interval != "step":
            interval = "epoch" if self.logging_interval is None else "any"
            latest_stat = self._extract_stats(trainer, interval)

            if latest_stat:
                for logger in trainer.loggers:
                    logger.log_metrics(latest_stat, step=trainer.fit_loop.epoch_loop._batches_that_stepped)

    def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, float]:
        latest_stat = {}

        (
            scheduler_hparam_keys,
            optimizers_with_scheduler,
            optimizers_with_scheduler_types,
        ) = self._find_names_from_schedulers(trainer.lr_scheduler_configs, add_lr_sch_names=False)
        self._remap_keys(scheduler_hparam_keys)

        for name, config in zip(scheduler_hparam_keys, trainer.lr_scheduler_configs):
            if interval in [config.interval, "any"]:
                opt = config.scheduler.optimizer
                current_stat = self._get_lr_momentum_stat(opt, name)
                latest_stat.update(current_stat)

        optimizer_hparam_keys, optimizers_without_scheduler = self._find_names_from_optimizers(
            trainer.optimizers,
            seen_optimizers=optimizers_with_scheduler,
            seen_optimizer_types=optimizers_with_scheduler_types,
            add_lr_sch_names=False,
        )
        self._remap_keys(optimizer_hparam_keys)

        for opt, names in zip(optimizers_without_scheduler, optimizer_hparam_keys):
            current_stat = self._get_lr_momentum_stat(opt, names)
            latest_stat.update(current_stat)

        return latest_stat

    def _get_lr_momentum_stat(self, optimizer: Optimizer, names: List[str]) -> Dict[str, float]:
        lr_momentum_stat = {}
        param_groups = optimizer.param_groups
        use_betas = "betas" in optimizer.defaults

        for pg, name in zip(param_groups, names):
            lr = self._extract_lr(pg, name)
            lr_momentum_stat.update(lr)
            momentum = self._extract_momentum(
                param_group=pg, name=name.replace(name, f"{name}-momentum"), use_betas=use_betas
            )
            lr_momentum_stat.update(momentum)

        return lr_momentum_stat

    def _extract_lr(self, param_group: Dict[str, Any], name: str) -> Dict[str, Any]:
        lr = param_group["lr"]
        self.lrs[name].append(lr)
        return {name: lr}

    def _remap_keys(self, names: List[List[str]], token: str = "/pg1") -> None:
        """This function is used the remap the keys if param groups for a given optimizer increased."""
        for group_new_names in names:
            for new_name in group_new_names:
                old_name = new_name.replace(token, "")
                if token in new_name and old_name in self.lrs:
                    self.lrs[new_name] = self.lrs.pop(old_name)
                elif new_name not in self.lrs:
                    self.lrs[new_name] = []

    def _extract_momentum(self, param_group: Dict[str, List], name: str, use_betas: bool) -> Dict[str, float]:
        if not self.log_momentum:
            return {}

        momentum = param_group["betas"][0] if use_betas else param_group.get("momentum", 0)
        self.last_momentum_values[name] = momentum
        return {name: momentum}

    def _add_prefix(
        self, name: str, optimizer_cls: Type[Optimizer], seen_optimizer_types: DefaultDict[Type[Optimizer], int]
    ) -> str:
        if optimizer_cls not in seen_optimizer_types:
            return name
        count = seen_optimizer_types[optimizer_cls]
        return name + f"-{count - 1}" if count > 1 else name

    def _add_suffix(self, name: str, param_groups: List[Dict], param_group_index: int, use_names: bool = True) -> str:
        if len(param_groups) > 1:
            if not use_names:
                return f"{name}/pg{param_group_index+1}"
            pg_name = param_groups[param_group_index].get("name", f"pg{param_group_index+1}")
            return f"{name}/{pg_name}"
        elif use_names:
            pg_name = param_groups[param_group_index].get("name")
            return f"{name}/{pg_name}" if pg_name else name
        return name

    def _duplicate_param_group_names(self, param_groups: List[Dict]) -> Set[str]:
        names = [pg.get("name", f"pg{i}") for i, pg in enumerate(param_groups, start=1)]
        unique = set(names)
        if len(names) == len(unique):
            return set()
        return {n for n in names if names.count(n) > 1}

    def _find_names_from_schedulers(
        self, lr_scheduler_configs: List[LRSchedulerConfig], add_lr_sch_names: bool = True
    ) -> Tuple[List[List[str]], List[Optimizer], DefaultDict[Type[Optimizer], int]]:
        # Create unique names in the case we have multiple of the same learning
        # rate scheduler + multiple parameter groups
        names = []
        seen_optimizers: List[Optimizer] = []
        seen_optimizer_types: DefaultDict[Type[Optimizer], int] = defaultdict(int)
        for config in lr_scheduler_configs:
            sch = config.scheduler
            if config.name is not None:
                name = config.name
            else:
                name = "lr-" + sch.optimizer.__class__.__name__

            updated_names = self._check_duplicates_and_update_name(
                sch.optimizer, name, seen_optimizers, seen_optimizer_types, config, add_lr_sch_names
            )
            names.append(updated_names)

        return names, seen_optimizers, seen_optimizer_types

    def _find_names_from_optimizers(
        self,
        optimizers: List[Any],
        seen_optimizers: List[Optimizer],
        seen_optimizer_types: DefaultDict[Type[Optimizer], int],
        add_lr_sch_names: bool = True,
    ) -> Tuple[List[List[str]], List[Optimizer]]:
        names = []
        optimizers_without_scheduler = []

        for optimizer in optimizers:
            # Deepspeed optimizer wraps the native optimizer
            optimizer = optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer
            if optimizer in seen_optimizers:
                continue

            name = "lr-" + optimizer.__class__.__name__
            updated_names = self._check_duplicates_and_update_name(
                optimizer, name, seen_optimizers, seen_optimizer_types, None, add_lr_sch_names
            )
            names.append(updated_names)
            optimizers_without_scheduler.append(optimizer)

        return names, optimizers_without_scheduler

    def _check_duplicates_and_update_name(
        self,
        optimizer: Optimizer,
        name: str,
        seen_optimizers: List[Optimizer],
        seen_optimizer_types: DefaultDict[Type[Optimizer], int],
        lr_scheduler_config: Optional[LRSchedulerConfig],
        add_lr_sch_names: bool = True,
    ) -> List[str]:
        seen_optimizers.append(optimizer)
        optimizer_cls = type(optimizer)
        if lr_scheduler_config is not None and lr_scheduler_config.name is None:
            seen_optimizer_types[optimizer_cls] += 1
        elif lr_scheduler_config is None:
            seen_optimizer_types[optimizer_cls] += 1

        # Multiple param groups for the same optimizer
        param_groups = optimizer.param_groups
        duplicates = self._duplicate_param_group_names(param_groups)
        if duplicates:
            raise MisconfigurationException(
                "A single `Optimizer` cannot have multiple parameter groups with identical "
                f"`name` values. {name} has duplicated parameter group names {duplicates}"
            )

        name = self._add_prefix(name, optimizer_cls, seen_optimizer_types)
        name_list = [self._add_suffix(name, param_groups, i) for i in range(len(param_groups))]

        if add_lr_sch_names:
            self._lr_sch_names.append(name)

        return name_list

    @property
    def lr_sch_names(self) -> List[str]:
        # TODO remove `lr_sch_names` and `add_lr_sch_names` argument in v1.7.0
        rank_zero_deprecation(
            "`LearningRateMonitor.lr_sch_names` has been deprecated in v1.5 and will be removed in 1.7."
            " Consider accessing them using `LearningRateMonitor.lrs.keys()` which will return"
            " the names of all the optimizers, even those without a scheduler."
        )
        return self._lr_sch_names
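A minimal usage sketch (not part of the vendored file; `LitModel` and the scheduler choice are hypothetical): giving the scheduler a `name` controls the metric key that `LearningRateMonitor` logs and stores in its `lrs` dict:

import torch
from torch import nn
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor


class LitModel(pl.LightningModule):
    # hypothetical module; training_step and dataloaders elided
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1, momentum=0.9)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
        # without 'name' the logged key would default to "lr-SGD"
        return [optimizer], [{"scheduler": scheduler, "interval": "epoch", "name": "backbone-lr"}]


lr_monitor = LearningRateMonitor(logging_interval="epoch", log_momentum=True)
trainer = Trainer(callbacks=[lr_monitor])
# after trainer.fit(...), lr_monitor.lrs maps each name to the values recorded per interval,
# e.g. {"backbone-lr": [0.1, 0.1, ...]}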
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py
ADDED
@@ -0,0 +1,720 @@
| 1 |
+
# Copyright The PyTorch Lightning team.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
Model Checkpointing
|
| 16 |
+
===================
|
| 17 |
+
|
| 18 |
+
Automatically save model checkpoints during training.
|
| 19 |
+
|
| 20 |
+
"""
|
| 21 |
+
import logging
|
| 22 |
+
import os
|
| 23 |
+
import re
|
| 24 |
+
import time
|
| 25 |
+
import warnings
|
| 26 |
+
from copy import deepcopy
|
| 27 |
+
from datetime import timedelta
|
| 28 |
+
from typing import Any, Dict, Optional
|
| 29 |
+
from weakref import proxy
|
| 30 |
+
|
| 31 |
+
import numpy as np
|
| 32 |
+
import torch
|
| 33 |
+
import yaml
|
| 34 |
+
|
| 35 |
+
import pytorch_lightning as pl
|
| 36 |
+
from pytorch_lightning.callbacks.base import Callback
|
| 37 |
+
from pytorch_lightning.utilities.cloud_io import get_filesystem
|
| 38 |
+
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
| 39 |
+
from pytorch_lightning.utilities.logger import _name, _version
|
| 40 |
+
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
|
| 41 |
+
from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT
|
| 42 |
+
from pytorch_lightning.utilities.warnings import WarningCache
|
| 43 |
+
|
| 44 |
+
log = logging.getLogger(__name__)
|
| 45 |
+
warning_cache = WarningCache()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class ModelCheckpoint(Callback):
|
| 49 |
+
r"""
|
| 50 |
+
Save the model periodically by monitoring a quantity. Every metric logged with
|
| 51 |
+
:meth:`~pytorch_lightning.core.lightning.log` or :meth:`~pytorch_lightning.core.lightning.log_dict` in
|
| 52 |
+
LightningModule is a candidate for the monitor key. For more information, see
|
| 53 |
+
:ref:`checkpointing`.
|
| 54 |
+
|
| 55 |
+
After training finishes, use :attr:`best_model_path` to retrieve the path to the
|
| 56 |
+
best checkpoint file and :attr:`best_model_score` to retrieve its score.
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
dirpath: directory to save the model file.
|
| 60 |
+
|
| 61 |
+
Example::
|
| 62 |
+
|
| 63 |
+
# custom path
|
| 64 |
+
# saves a file like: my/path/epoch=0-step=10.ckpt
|
| 65 |
+
>>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
|
| 66 |
+
|
| 67 |
+
By default, dirpath is ``None`` and will be set at runtime to the location
|
| 68 |
+
specified by :class:`~pytorch_lightning.trainer.trainer.Trainer`'s
|
| 69 |
+
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.default_root_dir` or
|
| 70 |
+
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.weights_save_path` arguments,
|
| 71 |
+
and if the Trainer uses a logger, the path will also contain logger name and version.
|
| 72 |
+
|
| 73 |
+
filename: checkpoint filename. Can contain named formatting options to be auto-filled.
|
| 74 |
+
|
| 75 |
+
Example::
|
| 76 |
+
|
| 77 |
+
# save any arbitrary metrics like `val_loss`, etc. in name
|
| 78 |
+
# saves a file like: my/path/epoch=2-val_loss=0.02-other_metric=0.03.ckpt
|
| 79 |
+
>>> checkpoint_callback = ModelCheckpoint(
|
| 80 |
+
... dirpath='my/path',
|
| 81 |
+
... filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
|
| 82 |
+
... )
|
| 83 |
+
|
| 84 |
+
By default, filename is ``None`` and will be set to ``'{epoch}-{step}'``.
|
| 85 |
+
monitor: quantity to monitor. By default it is ``None`` which saves a checkpoint only for the last epoch.
|
| 86 |
+
verbose: verbosity mode. Default: ``False``.
|
| 87 |
+
save_last: When ``True``, saves an exact copy of the checkpoint to a file `last.ckpt` whenever a checkpoint
|
| 88 |
+
file gets saved. This allows accessing the latest checkpoint in a deterministic manner. Default: ``None``.
|
| 89 |
+
save_top_k: if ``save_top_k == k``,
|
| 90 |
+
the best k models according to
|
| 91 |
+
the quantity monitored will be saved.
|
| 92 |
+
if ``save_top_k == 0``, no models are saved.
|
| 93 |
+
if ``save_top_k == -1``, all models are saved.
|
| 94 |
+
Please note that the monitors are checked every ``every_n_epochs`` epochs.
|
| 95 |
+
if ``save_top_k >= 2`` and the callback is called multiple
|
| 96 |
+
times inside an epoch, the name of the saved file will be
|
| 97 |
+
appended with a version count starting with ``v1``.
|
| 98 |
+
mode: one of {min, max}.
|
| 99 |
+
If ``save_top_k != 0``, the decision to overwrite the current save file is made
|
| 100 |
+
based on either the maximization or the minimization of the monitored quantity.
|
| 101 |
+
For ``'val_acc'``, this should be ``'max'``, for ``'val_loss'`` this should be ``'min'``, etc.
|
| 102 |
+
auto_insert_metric_name: When ``True``, the checkpoints filenames will contain the metric name.
|
| 103 |
+
For example, ``filename='checkpoint_{epoch:02d}-{acc:02.0f}`` with epoch ``1`` and acc ``1.12`` will resolve
|
| 104 |
+
to ``checkpoint_epoch=01-acc=01.ckpt``. Is useful to set it to ``False`` when metric names contain ``/``
|
| 105 |
+
as this will result in extra folders.
|
| 106 |
+
save_weights_only: if ``True``, then only the model's weights will be
|
| 107 |
+
saved. Otherwise, the optimizer states, lr-scheduler states, etc are added in the checkpoint too.
|
| 108 |
+
every_n_train_steps: Number of training steps between checkpoints.
|
| 109 |
+
If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training.
|
| 110 |
+
To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative.
|
| 111 |
+
This must be mutually exclusive with ``train_time_interval`` and ``every_n_epochs``.
|
| 112 |
+
train_time_interval: Checkpoints are monitored at the specified time interval.
|
| 113 |
+
For all practical purposes, this cannot be smaller than the amount
|
| 114 |
+
of time it takes to process a single training batch. This is not
|
| 115 |
+
guaranteed to execute at the exact time specified, but should be close.
|
| 116 |
+
This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``.
|
| 117 |
+
every_n_epochs: Number of epochs between checkpoints.
|
| 118 |
+
This value must be ``None`` or non-negative.
|
| 119 |
+
To disable saving top-k checkpoints, set ``every_n_epochs = 0``.
|
| 120 |
+
This argument does not impact the saving of ``save_last=True`` checkpoints.
|
| 121 |
+
If all of ``every_n_epochs``, ``every_n_train_steps`` and
|
| 122 |
+
``train_time_interval`` are ``None``, we save a checkpoint at the end of every epoch
|
| 123 |
+
(equivalent to ``every_n_epochs = 1``).
|
| 124 |
+
If ``every_n_epochs == None`` and either ``every_n_train_steps != None`` or ``train_time_interval != None``,
|
| 125 |
+
saving at the end of each epoch is disabled
|
| 126 |
+
(equivalent to ``every_n_epochs = 0``).
|
| 127 |
+
This must be mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``.
|
| 128 |
+
Setting both ``ModelCheckpoint(..., every_n_epochs=V, save_on_train_epoch_end=False)`` and
|
| 129 |
+
``Trainer(max_epochs=N, check_val_every_n_epoch=M)``
|
| 130 |
+
will only save checkpoints at epochs 0 < E <= N
|
| 131 |
+
where both values for ``every_n_epochs`` and ``check_val_every_n_epoch`` evenly divide E.
|
| 132 |
+
save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch.
|
| 133 |
+
If this is ``False``, then the check runs at the end of the validation.
|
| 134 |
+
|
| 135 |
+
Note:
|
| 136 |
+
For extra customization, ModelCheckpoint includes the following attributes:
|
| 137 |
+
|
| 138 |
+
- ``CHECKPOINT_JOIN_CHAR = "-"``
|
| 139 |
+
- ``CHECKPOINT_NAME_LAST = "last"``
|
| 140 |
+
- ``FILE_EXTENSION = ".ckpt"``
|
| 141 |
+
- ``STARTING_VERSION = 1``
|
| 142 |
+
|
| 143 |
+
For example, you can change the default last checkpoint name by doing
|
| 144 |
+
``checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"``
|
| 145 |
+
|
| 146 |
+
If you want to checkpoint every N hours, every M train batches, and/or every K val epochs,
|
| 147 |
+
then you should create multiple ``ModelCheckpoint`` callbacks.
|
| 148 |
+
|
| 149 |
+
If the checkpoint's ``dirpath`` changed from what it was before while resuming the training,
|
| 150 |
+
only ``best_model_path`` will be reloaded and a warning will be issued.
|
| 151 |
+
|
| 152 |
+
Raises:
|
| 153 |
+
MisconfigurationException:
|
| 154 |
+
If ``save_top_k`` is smaller than ``-1``,
|
| 155 |
+
if ``monitor`` is ``None`` and ``save_top_k`` is none of ``None``, ``-1``, and ``0``, or
|
| 156 |
+
if ``mode`` is none of ``"min"`` or ``"max"``.
|
| 157 |
+
ValueError:
|
| 158 |
+
If ``trainer.save_checkpoint`` is ``None``.
|
| 159 |
+
|
| 160 |
+
Example::
|
| 161 |
+
|
| 162 |
+
>>> from pytorch_lightning import Trainer
|
| 163 |
+
>>> from pytorch_lightning.callbacks import ModelCheckpoint
|
| 164 |
+
|
| 165 |
+
# saves checkpoints to 'my/path/' at every epoch
|
| 166 |
+
>>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
|
| 167 |
+
>>> trainer = Trainer(callbacks=[checkpoint_callback])
|
| 168 |
+
|
| 169 |
+
# save epoch and val_loss in name
|
| 170 |
+
# saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
|
| 171 |
+
>>> checkpoint_callback = ModelCheckpoint(
|
| 172 |
+
... monitor='val_loss',
|
| 173 |
+
... dirpath='my/path/',
|
| 174 |
+
... filename='sample-mnist-{epoch:02d}-{val_loss:.2f}'
|
| 175 |
+
... )
|
| 176 |
+
|
| 177 |
+
# save epoch and val_loss in name, but specify the formatting yourself (e.g. to avoid problems with Tensorboard
|
| 178 |
+
# or Neptune, due to the presence of characters like '=' or '/')
|
| 179 |
+
# saves a file like: my/path/sample-mnist-epoch02-val_loss0.32.ckpt
|
| 180 |
+
>>> checkpoint_callback = ModelCheckpoint(
|
| 181 |
+
... monitor='val/loss',
|
| 182 |
+
... dirpath='my/path/',
|
| 183 |
+
... filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}',
|
| 184 |
+
... auto_insert_metric_name=False
|
| 185 |
+
... )
|
| 186 |
+
|
| 187 |
+
# retrieve the best checkpoint after training
|
| 188 |
+
checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
|
| 189 |
+
trainer = Trainer(callbacks=[checkpoint_callback])
|
| 190 |
+
model = ...
|
| 191 |
+
trainer.fit(model)
|
| 192 |
+
checkpoint_callback.best_model_path
|
| 193 |
+
|
| 194 |
+
.. tip:: Saving and restoring multiple checkpoint callbacks at the same time is supported under variation in the
|
| 195 |
+
following arguments:
|
| 196 |
+
|
| 197 |
+
*monitor, mode, every_n_train_steps, every_n_epochs, train_time_interval, save_on_train_epoch_end*
|
| 198 |
+
|
| 199 |
+
Read more: :ref:`Persisting Callback State`
|
| 200 |
+
"""
|
| 201 |
+
|
| 202 |
+
CHECKPOINT_JOIN_CHAR = "-"
|
| 203 |
+
CHECKPOINT_NAME_LAST = "last"
|
| 204 |
+
FILE_EXTENSION = ".ckpt"
|
| 205 |
+
STARTING_VERSION = 1
|
| 206 |
+
|
| 207 |
+
def __init__(
|
| 208 |
+
self,
|
| 209 |
+
dirpath: Optional[_PATH] = None,
|
| 210 |
+
filename: Optional[str] = None,
|
| 211 |
+
monitor: Optional[str] = None,
|
| 212 |
+
verbose: bool = False,
|
| 213 |
+
save_last: Optional[bool] = None,
|
| 214 |
+
save_top_k: int = 1,
|
| 215 |
+
save_weights_only: bool = False,
|
| 216 |
+
mode: str = "min",
|
| 217 |
+
auto_insert_metric_name: bool = True,
|
| 218 |
+
every_n_train_steps: Optional[int] = None,
|
| 219 |
+
train_time_interval: Optional[timedelta] = None,
|
| 220 |
+
every_n_epochs: Optional[int] = None,
|
| 221 |
+
save_on_train_epoch_end: Optional[bool] = None,
|
| 222 |
+
):
|
| 223 |
+
super().__init__()
|
| 224 |
+
self.monitor = monitor
|
| 225 |
+
self.verbose = verbose
|
| 226 |
+
self.save_last = save_last
|
| 227 |
+
self.save_top_k = save_top_k
|
| 228 |
+
self.save_weights_only = save_weights_only
|
| 229 |
+
self.auto_insert_metric_name = auto_insert_metric_name
|
| 230 |
+
self._save_on_train_epoch_end = save_on_train_epoch_end
|
| 231 |
+
self._last_global_step_saved = 0 # no need to save when no steps were taken
|
| 232 |
+
self._last_time_checked: Optional[float] = None
|
| 233 |
+
self.current_score = None
|
| 234 |
+
self.best_k_models = {}
|
| 235 |
+
self.kth_best_model_path = ""
|
| 236 |
+
self.best_model_score = None
|
| 237 |
+
self.best_model_path = ""
|
| 238 |
+
self.last_model_path = ""
|
| 239 |
+
|
| 240 |
+
self.__init_monitor_mode(mode)
|
| 241 |
+
self.__init_ckpt_dir(dirpath, filename)
|
| 242 |
+
self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval)
|
| 243 |
+
self.__validate_init_configuration()
|
| 244 |
+
|
| 245 |
+
@property
|
| 246 |
+
def state_key(self) -> str:
|
| 247 |
+
return self._generate_state_key(
|
| 248 |
+
monitor=self.monitor,
|
| 249 |
+
mode=self.mode,
|
| 250 |
+
every_n_train_steps=self._every_n_train_steps,
|
| 251 |
+
every_n_epochs=self._every_n_epochs,
|
| 252 |
+
train_time_interval=self._train_time_interval,
|
| 253 |
+
save_on_train_epoch_end=self._save_on_train_epoch_end,
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
|
| 257 |
+
self.__resolve_ckpt_dir(trainer)
|
| 258 |
+
if trainer.is_global_zero and stage == "fit":
|
| 259 |
+
self.__warn_if_dir_not_empty(self.dirpath)
|
| 260 |
+
|
| 261 |
+
# NOTE: setting these attributes needs to happen as early as possible BEFORE reloading callback states,
|
| 262 |
+
# because the attributes are part of the state_key which needs to be fully defined before reloading.
|
| 263 |
+
if self._save_on_train_epoch_end is None:
|
| 264 |
+
# if the user runs validation multiple times per training epoch or multiple training epochs without
|
| 265 |
+
# validation, then we run after validation instead of on train epoch end
|
| 266 |
+
self._save_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1
|
| 267 |
+
|
| 268 |
+
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
| 269 |
+
self._last_time_checked = time.monotonic()
|
| 270 |
+
|
| 271 |
+
def on_train_batch_end(
|
| 272 |
+
self,
|
| 273 |
+
trainer: "pl.Trainer",
|
| 274 |
+
pl_module: "pl.LightningModule",
|
| 275 |
+
outputs: STEP_OUTPUT,
|
| 276 |
+
batch: Any,
|
| 277 |
+
batch_idx: int,
|
| 278 |
+
) -> None:
|
| 279 |
+
"""Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`"""
|
| 280 |
+
if self._should_skip_saving_checkpoint(trainer):
|
| 281 |
+
return
|
| 282 |
+
skip_batch = self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0)
|
| 283 |
+
|
| 284 |
+
train_time_interval = self._train_time_interval
|
| 285 |
+
skip_time = True
|
| 286 |
+
now = time.monotonic()
|
| 287 |
+
if train_time_interval:
|
| 288 |
+
prev_time_check = self._last_time_checked
|
| 289 |
+
skip_time = prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds()
|
| 290 |
+
# in case we have time differences across ranks
|
| 291 |
+
# broadcast the decision on whether to checkpoint from rank 0 to avoid possible hangs
|
| 292 |
+
skip_time = trainer.strategy.broadcast(skip_time)
|
| 293 |
+
|
| 294 |
+
if skip_batch and skip_time:
|
| 295 |
+
return
|
| 296 |
+
if not skip_time:
|
| 297 |
+
self._last_time_checked = now
|
| 298 |
+
|
| 299 |
+
monitor_candidates = self._monitor_candidates(trainer)
|
| 300 |
+
self._save_topk_checkpoint(trainer, monitor_candidates)
|
| 301 |
+
self._save_last_checkpoint(trainer, monitor_candidates)
|
| 302 |
+
|
| 303 |
+
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
| 304 |
+
"""Save a checkpoint at the end of the training epoch."""
|
| 305 |
+
if not self._should_skip_saving_checkpoint(trainer) and self._save_on_train_epoch_end:
|
| 306 |
+
monitor_candidates = self._monitor_candidates(trainer)
|
| 307 |
+
if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
|
| 308 |
+
self._save_topk_checkpoint(trainer, monitor_candidates)
|
| 309 |
+
self._save_last_checkpoint(trainer, monitor_candidates)
|
| 310 |
+
|
| 311 |
+
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
| 312 |
+
"""Save a checkpoint at the end of the validation stage."""
|
| 313 |
+
if not self._should_skip_saving_checkpoint(trainer) and not self._save_on_train_epoch_end:
|
| 314 |
+
monitor_candidates = self._monitor_candidates(trainer)
|
| 315 |
+
if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
|
| 316 |
+
self._save_topk_checkpoint(trainer, monitor_candidates)
|
| 317 |
+
self._save_last_checkpoint(trainer, monitor_candidates)
|
| 318 |
+
|
| 319 |
+
def state_dict(self) -> Dict[str, Any]:
|
| 320 |
+
return {
|
| 321 |
+
"monitor": self.monitor,
|
| 322 |
+
"best_model_score": self.best_model_score,
|
| 323 |
+
"best_model_path": self.best_model_path,
|
| 324 |
+
"current_score": self.current_score,
|
| 325 |
+
"dirpath": self.dirpath,
|
| 326 |
+
"best_k_models": self.best_k_models,
|
| 327 |
+
"kth_best_model_path": self.kth_best_model_path,
|
| 328 |
+
"kth_value": self.kth_value,
|
| 329 |
+
"last_model_path": self.last_model_path,
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
|
| 333 |
+
dirpath_from_ckpt = state_dict.get("dirpath", self.dirpath)
|
| 334 |
+
|
| 335 |
+
if self.dirpath == dirpath_from_ckpt:
|
| 336 |
+
self.best_model_score = state_dict["best_model_score"]
|
| 337 |
+
self.kth_best_model_path = state_dict.get("kth_best_model_path", self.kth_best_model_path)
|
| 338 |
+
self.kth_value = state_dict.get("kth_value", self.kth_value)
|
| 339 |
+
self.best_k_models = state_dict.get("best_k_models", self.best_k_models)
|
| 340 |
+
self.last_model_path = state_dict.get("last_model_path", self.last_model_path)
|
| 341 |
+
else:
|
| 342 |
+
warnings.warn(
|
| 343 |
+
f"The dirpath has changed from {dirpath_from_ckpt!r} to {self.dirpath!r},"
|
| 344 |
+
" therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and"
|
| 345 |
+
" `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded."
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
self.best_model_path = state_dict["best_model_path"]
|
| 349 |
+
|
| 350 |
+
def save_checkpoint(self, trainer: "pl.Trainer") -> None: # pragma: no-cover
|
| 351 |
+
"""Performs the main logic around saving a checkpoint.
|
| 352 |
+
|
| 353 |
+
This method runs on all ranks. It is the responsibility of `trainer.save_checkpoint` to correctly handle the
|
| 354 |
+
behaviour in distributed training, i.e., saving only on rank 0 for data parallel use cases.
|
| 355 |
+
"""
|
| 356 |
+
rank_zero_deprecation(
|
| 357 |
+
f"`{self.__class__.__name__}.save_checkpoint()` was deprecated in v1.6 and will be removed in v1.8."
|
| 358 |
+
" Instead, you can use `trainer.save_checkpoint()` to manually save a checkpoint."
|
| 359 |
+
)
|
| 360 |
+
monitor_candidates = self._monitor_candidates(trainer)
|
| 361 |
+
self._save_topk_checkpoint(trainer, monitor_candidates)
|
| 362 |
+
self._save_last_checkpoint(trainer, monitor_candidates)
|
| 363 |
+
|
| 364 |
+
def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
|
| 365 |
+
if self.save_top_k == 0:
|
| 366 |
+
return
|
| 367 |
+
|
| 368 |
+
# validate metric
|
| 369 |
+
if self.monitor is not None:
|
| 370 |
+
if self.monitor not in monitor_candidates:
|
| 371 |
+
m = (
|
| 372 |
+
f"`ModelCheckpoint(monitor={self.monitor!r})` could not find the monitored key in the returned"
|
| 373 |
+
f" metrics: {list(monitor_candidates)}."
|
| 374 |
+
f" HINT: Did you call `log({self.monitor!r}, value)` in the `LightningModule`?"
|
| 375 |
+
)
|
| 376 |
+
if trainer.fit_loop.epoch_loop.val_loop._has_run:
|
| 377 |
+
raise MisconfigurationException(m)
|
| 378 |
+
warning_cache.warn(m)
|
| 379 |
+
self._save_monitor_checkpoint(trainer, monitor_candidates)
|
| 380 |
+
else:
|
| 381 |
+
self._save_none_monitor_checkpoint(trainer, monitor_candidates)
|
| 382 |
+
|
| 383 |
+
def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
|
| 384 |
+
trainer.save_checkpoint(filepath, self.save_weights_only)
|
| 385 |
+
|
| 386 |
+
self._last_global_step_saved = trainer.global_step
|
| 387 |
+
|
| 388 |
+
# notify loggers
|
| 389 |
+
if trainer.is_global_zero:
|
| 390 |
+
for logger in trainer.loggers:
|
| 391 |
+
logger.after_save_checkpoint(proxy(self))
|
| 392 |
+
|
| 393 |
+
def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool:
|
| 394 |
+
from pytorch_lightning.trainer.states import TrainerFn
|
| 395 |
+
|
| 396 |
+
return (
|
| 397 |
+
trainer.fast_dev_run # disable checkpointing with fast_dev_run
|
| 398 |
+
or trainer.state.fn != TrainerFn.FITTING # don't save anything during non-fit
|
| 399 |
+
or trainer.sanity_checking # don't save anything during sanity check
|
| 400 |
+
or self._last_global_step_saved == trainer.global_step # already saved at the last step
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
def __validate_init_configuration(self) -> None:
|
| 404 |
+
if self.save_top_k < -1:
|
| 405 |
+
raise MisconfigurationException(f"Invalid value for save_top_k={self.save_top_k}. Must be >= -1")
|
| 406 |
+
if self._every_n_train_steps < 0:
|
| 407 |
+
raise MisconfigurationException(
|
| 408 |
+
f"Invalid value for every_n_train_steps={self._every_n_train_steps}. Must be >= 0"
|
| 409 |
+
)
|
| 410 |
+
if self._every_n_epochs < 0:
|
| 411 |
+
raise MisconfigurationException(f"Invalid value for every_n_epochs={self._every_n_epochs}. Must be >= 0")
|
| 412 |
+
|
| 413 |
+
every_n_train_steps_triggered = self._every_n_train_steps >= 1
|
| 414 |
+
every_n_epochs_triggered = self._every_n_epochs >= 1
|
| 415 |
+
train_time_interval_triggered = self._train_time_interval is not None
|
| 416 |
+
if every_n_train_steps_triggered + every_n_epochs_triggered + train_time_interval_triggered > 1:
|
| 417 |
+
raise MisconfigurationException(
|
| 418 |
+
f"Combination of parameters every_n_train_steps={self._every_n_train_steps}, "
|
| 419 |
+
f"every_n_epochs={self._every_n_epochs} and train_time_interval={self._train_time_interval} "
|
| 420 |
+
"should be mutually exclusive."
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
if self.monitor is None:
|
| 424 |
+
# -1: save all epochs, 0: nothing is saved, 1: save last epoch
|
| 425 |
+
if self.save_top_k not in (-1, 0, 1):
|
| 426 |
+
raise MisconfigurationException(
|
| 427 |
+
f"ModelCheckpoint(save_top_k={self.save_top_k}, monitor=None) is not a valid"
|
| 428 |
+
" configuration. No quantity for top_k to track."
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
if self.save_top_k == -1 and self.save_last:
|
| 432 |
+
rank_zero_info(
|
| 433 |
+
"ModelCheckpoint(save_last=True, save_top_k=-1, monitor=None)"
|
| 434 |
+
" will duplicate the last checkpoint saved."
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
def __init_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) -> None:
|
| 438 |
+
self._fs = get_filesystem(dirpath if dirpath else "")
|
| 439 |
+
|
| 440 |
+
if dirpath and self._fs.protocol == "file":
|
| 441 |
+
dirpath = os.path.realpath(dirpath)
|
| 442 |
+
|
| 443 |
+
self.dirpath = dirpath
|
| 444 |
+
self.filename = filename
|
| 445 |
+
|
| 446 |
+
def __init_monitor_mode(self, mode: str) -> None:
|
| 447 |
+
torch_inf = torch.tensor(np.Inf)
|
| 448 |
+
mode_dict = {"min": (torch_inf, "min"), "max": (-torch_inf, "max")}
|
| 449 |
+
|
| 450 |
+
if mode not in mode_dict:
|
| 451 |
+
raise MisconfigurationException(f"`mode` can be {', '.join(mode_dict.keys())} but got {mode}")
|
| 452 |
+
|
| 453 |
+
self.kth_value, self.mode = mode_dict[mode]
|
| 454 |
+
|
| 455 |
+
def __init_triggers(
|
| 456 |
+
self,
|
| 457 |
+
every_n_train_steps: Optional[int],
|
| 458 |
+
every_n_epochs: Optional[int],
|
| 459 |
+
train_time_interval: Optional[timedelta],
|
| 460 |
+
) -> None:
|
| 461 |
+
|
| 462 |
+
# Default to running once after each validation epoch if neither
|
| 463 |
+
# every_n_train_steps nor every_n_epochs is set
|
| 464 |
+
if every_n_train_steps is None and every_n_epochs is None and train_time_interval is None:
|
| 465 |
+
every_n_epochs = 1
|
| 466 |
+
every_n_train_steps = 0
|
| 467 |
+
log.debug("Both every_n_train_steps and every_n_epochs are not set. Setting every_n_epochs=1")
|
| 468 |
+
else:
|
| 469 |
+
every_n_epochs = every_n_epochs or 0
|
| 470 |
+
every_n_train_steps = every_n_train_steps or 0
|
| 471 |
+
|
| 472 |
+
self._train_time_interval: Optional[timedelta] = train_time_interval
|
| 473 |
+
self._every_n_epochs: int = every_n_epochs
|
| 474 |
+
self._every_n_train_steps: int = every_n_train_steps
|
| 475 |
+
|
| 476 |
+
@property
|
| 477 |
+
def every_n_epochs(self) -> Optional[int]:
|
| 478 |
+
return self._every_n_epochs
|
| 479 |
+
|
| 480 |
+
def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[torch.Tensor] = None) -> bool:
|
| 481 |
+
if current is None:
|
| 482 |
+
return False
|
| 483 |
+
|
| 484 |
+
if self.save_top_k == -1:
|
| 485 |
+
return True
|
| 486 |
+
|
| 487 |
+
less_than_k_models = len(self.best_k_models) < self.save_top_k
|
| 488 |
+
if less_than_k_models:
|
| 489 |
+
return True
|
| 490 |
+
|
| 491 |
+
monitor_op = {"min": torch.lt, "max": torch.gt}[self.mode]
|
| 492 |
+
should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])
|
| 493 |
+
|
| 494 |
+
# If using multiple devices, make sure all processes are unanimous on the decision.
|
| 495 |
+
should_update_best_and_save = trainer.strategy.reduce_boolean_decision(should_update_best_and_save)
|
| 496 |
+
|
| 497 |
+
return should_update_best_and_save
|
| 498 |
+
|
| 499 |
+
@classmethod
|
| 500 |
+
def _format_checkpoint_name(
|
| 501 |
+
cls,
|
| 502 |
+
filename: Optional[str],
|
| 503 |
+
metrics: Dict[str, _METRIC],
|
| 504 |
+
prefix: str = "",
|
| 505 |
+
auto_insert_metric_name: bool = True,
|
| 506 |
+
) -> str:
|
| 507 |
+
if not filename:
|
| 508 |
+
# filename is not set, use default name
|
| 509 |
+
filename = "{epoch}" + cls.CHECKPOINT_JOIN_CHAR + "{step}"
|
| 510 |
+
|
| 511 |
+
# check and parse user passed keys in the string
|
| 512 |
+
groups = re.findall(r"(\{.*?)[:\}]", filename)
|
| 513 |
+
if len(groups) >= 0:
|
| 514 |
+
for group in groups:
|
| 515 |
+
name = group[1:]
|
| 516 |
+
|
| 517 |
+
if auto_insert_metric_name:
|
| 518 |
+
filename = filename.replace(group, name + "={" + name)
|
| 519 |
+
|
| 520 |
+
# support for dots: https://stackoverflow.com/a/7934969
|
| 521 |
+
filename = filename.replace(group, f"{{0[{name}]")
|
| 522 |
+
|
| 523 |
+
if name not in metrics:
|
| 524 |
+
metrics[name] = 0
|
| 525 |
+
filename = filename.format(metrics)
|
| 526 |
+
|
| 527 |
+
if prefix:
|
| 528 |
+
filename = cls.CHECKPOINT_JOIN_CHAR.join([prefix, filename])
|
| 529 |
+
|
| 530 |
+
return filename
|
| 531 |
+
|
| 532 |
+
def format_checkpoint_name(
|
| 533 |
+
self, metrics: Dict[str, _METRIC], filename: Optional[str] = None, ver: Optional[int] = None
|
| 534 |
+
) -> str:
|
| 535 |
+
"""Generate a filename according to the defined template.
|
| 536 |
+
|
| 537 |
+
Example::
|
| 538 |
+
|
| 539 |
+
>>> tmpdir = os.path.dirname(__file__)
|
| 540 |
+
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}')
|
| 541 |
+
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=0)))
|
| 542 |
+
'epoch=0.ckpt'
|
| 543 |
+
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch:03d}')
|
| 544 |
+
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=5)))
|
| 545 |
+
'epoch=005.ckpt'
|
| 546 |
+
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}')
|
| 547 |
+
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.123456)))
|
| 548 |
+
'epoch=2-val_loss=0.12.ckpt'
|
| 549 |
+
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.12), filename='{epoch:d}'))
|
| 550 |
+
'epoch=2.ckpt'
|
| 551 |
+
>>> ckpt = ModelCheckpoint(dirpath=tmpdir,
|
| 552 |
+
... filename='epoch={epoch}-validation_loss={val_loss:.2f}',
|
| 553 |
+
... auto_insert_metric_name=False)
|
| 554 |
+
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.123456)))
|
| 555 |
+
'epoch=2-validation_loss=0.12.ckpt'
|
| 556 |
+
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}')
|
| 557 |
+
>>> os.path.basename(ckpt.format_checkpoint_name({}))
|
| 558 |
+
'missing=0.ckpt'
|
| 559 |
+
>>> ckpt = ModelCheckpoint(filename='{step}')
|
| 560 |
+
>>> os.path.basename(ckpt.format_checkpoint_name(dict(step=0)))
|
| 561 |
+
'step=0.ckpt'
|
| 562 |
+
"""
|
| 563 |
+
filename = filename or self.filename
|
| 564 |
+
filename = self._format_checkpoint_name(filename, metrics, auto_insert_metric_name=self.auto_insert_metric_name)
|
| 565 |
+
|
| 566 |
+
if ver is not None:
|
| 567 |
+
filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
|
| 568 |
+
|
| 569 |
+
ckpt_name = f"{filename}{self.FILE_EXTENSION}"
|
| 570 |
+
return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name
|
| 571 |
+
|
| 572 |
+
def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
|
| 573 |
+
"""Determines model checkpoint save directory at runtime. References attributes from the trainer's logger
|
| 574 |
+
to determine where to save checkpoints. The base path for saving weights is set in this priority:
|
| 575 |
+
|
| 576 |
+
1. Checkpoint callback's path (if passed in)
|
| 577 |
+
2. The default_root_dir from trainer if trainer has no logger
|
| 578 |
+
3. The weights_save_path from trainer, if user provides it (deprecated)
|
| 579 |
+
4. User provided weights_saved_path
|
| 580 |
+
|
| 581 |
+
The base path gets extended with logger name and version (if these are available)
|
| 582 |
+
and subfolder "checkpoints".
|
| 583 |
+
"""
|
| 584 |
+
if self.dirpath is not None:
|
| 585 |
+
return # short circuit
|
| 586 |
+
|
| 587 |
+
# TODO: Remove weights_save_path logic here in v1.8
|
| 588 |
+
if trainer.loggers:
|
| 589 |
+
if trainer._weights_save_path_internal != trainer.default_root_dir:
|
| 590 |
+
# the user has changed weights_save_path, it overrides anything
|
| 591 |
+
save_dir = trainer._weights_save_path_internal
|
| 592 |
+
elif len(trainer.loggers) == 1:
|
| 593 |
+
save_dir = trainer.logger.save_dir or trainer.default_root_dir
|
| 594 |
+
else:
|
| 595 |
+
save_dir = trainer.default_root_dir
|
| 596 |
+
|
| 597 |
+
name = _name(trainer.loggers)
|
| 598 |
+
version = _version(trainer.loggers)
|
| 599 |
+
version = version if isinstance(version, str) else f"version_{version}"
|
| 600 |
+
|
| 601 |
+
ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints")
|
| 602 |
+
else:
|
| 603 |
+
ckpt_path = os.path.join(trainer._weights_save_path_internal, "checkpoints")
|
| 604 |
+
|
| 605 |
+
ckpt_path = trainer.strategy.broadcast(ckpt_path)
|
| 606 |
+
|
| 607 |
+
self.dirpath = ckpt_path
|
| 608 |
+
|
| 609 |
+
def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
|
| 610 |
+
if self.save_top_k != 0 and self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0:
|
| 611 |
+
rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
|
| 612 |
+
|
| 613 |
+
def _get_metric_interpolated_filepath_name(
|
| 614 |
+
self, monitor_candidates: Dict[str, _METRIC], trainer: "pl.Trainer", del_filepath: Optional[str] = None
|
| 615 |
+
) -> str:
|
| 616 |
+
filepath = self.format_checkpoint_name(monitor_candidates)
|
| 617 |
+
|
| 618 |
+
version_cnt = self.STARTING_VERSION
|
| 619 |
+
while self.file_exists(filepath, trainer) and filepath != del_filepath:
|
| 620 |
+
filepath = self.format_checkpoint_name(monitor_candidates, ver=version_cnt)
|
| 621 |
+
version_cnt += 1
|
| 622 |
+
|
| 623 |
+
return filepath
|
| 624 |
+
|
| 625 |
+
def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, _METRIC]:
|
| 626 |
+
monitor_candidates = deepcopy(trainer.callback_metrics)
|
| 627 |
+
# cast to int if necessary because `self.log("epoch", 123)` will convert it to float. if it's not a tensor
|
| 628 |
+
# or does not exist we overwrite it as it's likely an error
|
| 629 |
+
epoch = monitor_candidates.get("epoch")
|
| 630 |
+
monitor_candidates["epoch"] = (
|
| 631 |
+
epoch.int() if isinstance(epoch, torch.Tensor) else torch.tensor(trainer.current_epoch)
|
| 632 |
+
)
|
| 633 |
+
step = monitor_candidates.get("step")
|
| 634 |
+
monitor_candidates["step"] = step.int() if isinstance(step, torch.Tensor) else torch.tensor(trainer.global_step)
|
| 635 |
+
return monitor_candidates
|
| 636 |
+
|
| 637 |
+
def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
|
| 638 |
+
if not self.save_last:
|
| 639 |
+
return
|
| 640 |
+
|
| 641 |
+
filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST)
|
| 642 |
+
# set the last model path before saving because it will be part of the state.
|
| 643 |
+
previous, self.last_model_path = self.last_model_path, filepath
|
| 644 |
+
self._save_checkpoint(trainer, filepath)
|
| 645 |
+
if previous and previous != filepath:
|
| 646 |
+
trainer.strategy.remove_checkpoint(previous)
|
| 647 |
+
|
| 648 |
+
def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
|
| 649 |
+
current = monitor_candidates.get(self.monitor)
|
| 650 |
+
if self.check_monitor_top_k(trainer, current):
|
| 651 |
+
self._update_best_and_save(current, trainer, monitor_candidates)
|
| 652 |
+
elif self.verbose:
|
| 653 |
+
epoch = monitor_candidates["epoch"]
|
| 654 |
+
step = monitor_candidates["step"]
|
| 655 |
+
rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} was not in top {self.save_top_k}")
|
| 656 |
+
|
| 657 |
+
def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
|
| 658 |
+
filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
|
| 659 |
+
# set the best model path before saving because it will be part of the state.
|
| 660 |
+
previous, self.best_model_path = self.best_model_path, filepath
|
| 661 |
+
self._save_checkpoint(trainer, filepath)
|
| 662 |
+
if self.save_top_k == 1 and previous and previous != filepath:
|
| 663 |
+
trainer.strategy.remove_checkpoint(previous)
|
| 664 |
+
|
| 665 |
+
def _update_best_and_save(
|
| 666 |
+
self, current: torch.Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]
|
| 667 |
+
) -> None:
|
| 668 |
+
k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k
|
| 669 |
+
|
| 670 |
+
del_filepath = None
|
| 671 |
+
if len(self.best_k_models) == k and k > 0:
|
| 672 |
+
del_filepath = self.kth_best_model_path
|
| 673 |
+
self.best_k_models.pop(del_filepath)
|
| 674 |
+
|
| 675 |
+
# do not save nan, replace with +/- inf
|
| 676 |
+
if isinstance(current, torch.Tensor) and torch.isnan(current):
|
| 677 |
+
current = torch.tensor(float("inf" if self.mode == "min" else "-inf"), device=current.device)
|
| 678 |
+
|
| 679 |
+
filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer, del_filepath)
|
| 680 |
+
|
| 681 |
+
# save the current score
|
| 682 |
+
self.current_score = current
|
| 683 |
+
self.best_k_models[filepath] = current
|
| 684 |
+
|
| 685 |
+
if len(self.best_k_models) == k:
|
| 686 |
+
# monitor dict has reached k elements
|
| 687 |
+
_op = max if self.mode == "min" else min
|
| 688 |
+
self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get)
|
| 689 |
+
self.kth_value = self.best_k_models[self.kth_best_model_path]
|
| 690 |
+
|
| 691 |
+
_op = min if self.mode == "min" else max
|
| 692 |
+
self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)
|
| 693 |
+
self.best_model_score = self.best_k_models[self.best_model_path]
|
| 694 |
+
|
| 695 |
+
if self.verbose:
|
| 696 |
+
epoch = monitor_candidates["epoch"]
|
| 697 |
+
step = monitor_candidates["step"]
|
| 698 |
+
rank_zero_info(
|
| 699 |
+
f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} reached {current:0.5f}"
|
| 700 |
+
f" (best {self.best_model_score:0.5f}), saving model to {filepath!r} as top {k}"
|
| 701 |
+
)
|
| 702 |
+
self._save_checkpoint(trainer, filepath)
|
| 703 |
+
|
| 704 |
+
if del_filepath is not None and filepath != del_filepath:
|
| 705 |
+
trainer.strategy.remove_checkpoint(del_filepath)
|
| 706 |
+
|
| 707 |
+
def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
|
| 708 |
+
"""Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML
|
| 709 |
+
file."""
|
| 710 |
+
best_k = {k: v.item() for k, v in self.best_k_models.items()}
|
| 711 |
+
if filepath is None:
|
| 712 |
+
filepath = os.path.join(self.dirpath, "best_k_models.yaml")
|
| 713 |
+
with self._fs.open(filepath, "w") as fp:
|
| 714 |
+
yaml.dump(best_k, fp)
|
| 715 |
+
|
| 716 |
+
def file_exists(self, filepath: _PATH, trainer: "pl.Trainer") -> bool:
|
| 717 |
+
"""Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal
|
| 718 |
+
state to diverge between ranks."""
|
| 719 |
+
exists = self._fs.exists(filepath)
|
| 720 |
+
return trainer.strategy.broadcast(exists)
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_summary.py
ADDED
@@ -0,0 +1,73 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Model Summary
=============

Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`.

The string representation of this summary prints a table with columns containing
the name, type and number of parameters for each layer.

"""
import logging
from typing import List, Tuple

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities.model_summary import _format_summary_table, summarize

log = logging.getLogger(__name__)


class ModelSummary(Callback):
    r"""
    Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`.

    Args:
        max_depth: The maximum depth of layer nesting that the summary will include. A value of 0 turns the
            layer summary off.

    Example::

        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.callbacks import ModelSummary
        >>> trainer = Trainer(callbacks=[ModelSummary(max_depth=1)])
    """

    def __init__(self, max_depth: int = 1) -> None:
        self._max_depth: int = max_depth

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if not self._max_depth:
            return None

        model_summary = summarize(pl_module, max_depth=self._max_depth)
        summary_data = model_summary._get_summary_data()
        total_parameters = model_summary.total_parameters
        trainable_parameters = model_summary.trainable_parameters
        model_size = model_summary.model_size

        if trainer.is_global_zero:
            self.summarize(summary_data, total_parameters, trainable_parameters, model_size)

    @staticmethod
    def summarize(
        summary_data: List[Tuple[str, List[str]]],
        total_parameters: int,
        trainable_parameters: int,
        model_size: float,
    ) -> None:
        summary_table = _format_summary_table(total_parameters, trainable_parameters, model_size, *summary_data)
        log.info("\n" + summary_table)
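
The callback's ``on_fit_start`` is a thin wrapper around ``summarize`` from ``pytorch_lightning.utilities.model_summary``, so the same table can be produced outside a Trainer run; a small sketch, assuming ``model`` is any LightningModule instance (an illustrative name, not defined in this file)::

    from pytorch_lightning.utilities.model_summary import summarize

    # same helper the callback calls in on_fit_start
    summary = summarize(model, max_depth=2)  # model: an illustrative LightningModule
    print(summary)  # assumed to render the name/type/#params table described above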
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/prediction_writer.py
ADDED
@@ -0,0 +1,119 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
BasePredictionWriter
====================

Aids in saving predictions
"""
from typing import Any, Optional, Sequence

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import LightningEnum
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class WriteInterval(LightningEnum):
    BATCH = "batch"
    EPOCH = "epoch"
    BATCH_AND_EPOCH = "batch_and_epoch"

    @property
    def on_batch(self) -> bool:
        return self in (self.BATCH, self.BATCH_AND_EPOCH)

    @property
    def on_epoch(self) -> bool:
        return self in (self.EPOCH, self.BATCH_AND_EPOCH)


class BasePredictionWriter(Callback):
    """Base class to implement how the predictions should be stored.

    Args:
        write_interval: When to write.

    Example::

        import torch
        from pytorch_lightning.callbacks import BasePredictionWriter

        class CustomWriter(BasePredictionWriter):

            def __init__(self, output_dir: str, write_interval: str):
                super().__init__(write_interval)
                self.output_dir = output_dir

            def write_on_batch_end(
                self, trainer, pl_module: 'LightningModule', prediction: Any, batch_indices: List[int], batch: Any,
                batch_idx: int, dataloader_idx: int
            ):
                torch.save(prediction, os.path.join(self.output_dir, dataloader_idx, f"{batch_idx}.pt"))

            def write_on_epoch_end(
                self, trainer, pl_module: 'LightningModule', predictions: List[Any], batch_indices: List[Any]
            ):
                torch.save(predictions, os.path.join(self.output_dir, "predictions.pt"))
    """

    def __init__(self, write_interval: str = "batch") -> None:
        if write_interval not in list(WriteInterval):
            raise MisconfigurationException(f"`write_interval` should be one of {[i.value for i in WriteInterval]}.")
        self.interval = WriteInterval(write_interval)

    def write_on_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        prediction: Any,
        batch_indices: Optional[Sequence[int]],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        """Override with the logic to write a single batch."""
        raise NotImplementedError()

    def write_on_epoch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        predictions: Sequence[Any],
        batch_indices: Optional[Sequence[Any]],
    ) -> None:
        """Override with the logic to write all batches."""
        raise NotImplementedError()

    def on_predict_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        if not self.interval.on_batch:
            return
        batch_indices = trainer.predict_loop.epoch_loop.current_batch_indices
        self.write_on_batch_end(trainer, pl_module, outputs, batch_indices, batch, batch_idx, dataloader_idx)

    def on_predict_epoch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: Sequence[Any]
    ) -> None:
        if not self.interval.on_epoch:
            return
        epoch_batch_indices = trainer.predict_loop.epoch_batch_indices
        self.write_on_epoch_end(trainer, pl_module, trainer.predict_loop.predictions, epoch_batch_indices)
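
For completeness, a short sketch of how a concrete writer would be wired into prediction, assuming the ``CustomWriter`` class has been defined as in the docstring example above and that ``model`` and ``predict_dataloader`` are illustrative names for a LightningModule and a DataLoader::

    from pytorch_lightning import Trainer

    # "epoch" triggers write_on_epoch_end only; "batch" or "batch_and_epoch" also exist
    pred_writer = CustomWriter(output_dir="pred_output", write_interval="epoch")
    trainer = Trainer(callbacks=[pred_writer])
    # predictions are written by the callback, so they need not be returned in memory
    trainer.predict(model, dataloaders=predict_dataloader, return_predictions=False)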
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/pruning.py
ADDED
@@ -0,0 +1,486 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
ModelPruning
^^^^^^^^^^^^
"""
import inspect
import logging
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn.utils.prune as pytorch_prune
from torch import nn
from typing_extensions import TypedDict

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_only

log = logging.getLogger(__name__)

_PYTORCH_PRUNING_FUNCTIONS = {
    "ln_structured": pytorch_prune.ln_structured,
    "l1_unstructured": pytorch_prune.l1_unstructured,
    "random_structured": pytorch_prune.random_structured,
    "random_unstructured": pytorch_prune.random_unstructured,
}

_PYTORCH_PRUNING_METHOD = {
    "ln_structured": pytorch_prune.LnStructured,
    "l1_unstructured": pytorch_prune.L1Unstructured,
    "random_structured": pytorch_prune.RandomStructured,
    "random_unstructured": pytorch_prune.RandomUnstructured,
}

_PARAM_TUPLE = Tuple[nn.Module, str]
_PARAM_LIST = Sequence[_PARAM_TUPLE]
_MODULE_CONTAINERS = (LightningModule, nn.Sequential, nn.ModuleList, nn.ModuleDict)


class _LayerRef(TypedDict):
    data: nn.Module
    names: List[Tuple[int, str]]


class ModelPruning(Callback):
    PARAMETER_NAMES = ("weight", "bias")

    def __init__(
        self,
        pruning_fn: Union[Callable, str],
        parameters_to_prune: _PARAM_LIST = (),
        parameter_names: Optional[List[str]] = None,
        use_global_unstructured: bool = True,
        amount: Union[int, float, Callable[[int], Union[int, float]]] = 0.5,
        apply_pruning: Union[bool, Callable[[int], bool]] = True,
        make_pruning_permanent: bool = True,
        use_lottery_ticket_hypothesis: Union[bool, Callable[[int], bool]] = True,
        resample_parameters: bool = False,
        pruning_dim: Optional[int] = None,
        pruning_norm: Optional[int] = None,
        verbose: int = 0,
        prune_on_train_epoch_end: bool = True,
    ) -> None:
        """Model pruning Callback, using PyTorch's prune utilities. This callback is responsible of pruning
        networks parameters during training.

        To learn more about pruning with PyTorch, please take a look at
        `this tutorial <https://pytorch.org/tutorials/intermediate/pruning_tutorial.html>`_.

        .. warning:: ``ModelPruning`` is in beta and subject to change.

        .. code-block:: python

            parameters_to_prune = [(model.mlp_1, "weight"), (model.mlp_2, "weight")]

            trainer = Trainer(
                callbacks=[
                    ModelPruning(
                        pruning_fn="l1_unstructured",
                        parameters_to_prune=parameters_to_prune,
                        amount=0.01,
                        use_global_unstructured=True,
                    )
                ]
            )

        When ``parameters_to_prune`` is ``None``, ``parameters_to_prune`` will contain all parameters from the model.
        The user can override ``filter_parameters_to_prune`` to filter any ``nn.Module`` to be pruned.

        Args:

            pruning_fn: Function from torch.nn.utils.prune module or your own PyTorch ``BasePruningMethod`` subclass.
                Can also be string e.g. `"l1_unstructured"`. See pytorch docs for more details.

            parameters_to_prune: List of tuples ``(nn.Module, "parameter_name_string")``.

            parameter_names: List of parameter names to be pruned from the nn.Module.
                Can either be ``"weight"`` or ``"bias"``.

            use_global_unstructured: Whether to apply pruning globally on the model.
                If ``parameters_to_prune`` is provided, global unstructured will be restricted on them.

            amount: Quantity of parameters to prune:

                - ``float``. Between 0.0 and 1.0. Represents the fraction of parameters to prune.
                - ``int``. Represents the absolute number of parameters to prune.
                - ``Callable``. For dynamic values. Will be called every epoch. Should return a value.

            apply_pruning: Whether to apply pruning.

                - ``bool``. Always apply it or not.
                - ``Callable[[epoch], bool]``. For dynamic values. Will be called every epoch.

            make_pruning_permanent: Whether to remove all reparametrization pre-hooks and apply masks
                when training ends or the model is saved.

            use_lottery_ticket_hypothesis: See `The lottery ticket hypothesis <https://arxiv.org/abs/1803.03635>`_:

                - ``bool``. Whether to apply it or not.
                - ``Callable[[epoch], bool]``. For dynamic values. Will be called every epoch.

            resample_parameters: Used with ``use_lottery_ticket_hypothesis``. If True, the model parameters will
                be resampled, otherwise, the exact original parameters will be used.

            pruning_dim: If you are using a structured pruning method you need to specify the dimension.

            pruning_norm: If you are using ``ln_structured`` you need to specify the norm.

            verbose: Verbosity level. 0 to disable, 1 to log overall sparsity, 2 to log per-layer sparsity

            prune_on_train_epoch_end: whether to apply pruning at the end of the training epoch.
                If this is ``False``, then the check runs at the end of the validation epoch.

        Raises:
            MisconfigurationException:
                If ``parameter_names`` is neither ``"weight"`` nor ``"bias"``,
                if the provided ``pruning_fn`` is not supported,
                if ``pruning_dim`` is not provided when ``"unstructured"``,
                if ``pruning_norm`` is not provided when ``"ln_structured"``,
                if ``pruning_fn`` is neither ``str`` nor :class:`torch.nn.utils.prune.BasePruningMethod`, or
                if ``amount`` is none of ``int``, ``float`` and ``Callable``.
        """

        self._use_global_unstructured = use_global_unstructured
        self._parameters_to_prune = parameters_to_prune
        self._use_lottery_ticket_hypothesis = use_lottery_ticket_hypothesis
        self._resample_parameters = resample_parameters
        self._prune_on_train_epoch_end = prune_on_train_epoch_end
        self._parameter_names = parameter_names or self.PARAMETER_NAMES
        self._global_kwargs: Dict[str, Any] = {}
        self._original_layers: Optional[Dict[int, _LayerRef]] = None
        self._pruning_method_name: Optional[str] = None

        for name in self._parameter_names:
            if name not in self.PARAMETER_NAMES:
                raise MisconfigurationException(
                    f"The provided `parameter_names` name: {name} isn't in {self.PARAMETER_NAMES}"
                )

        if isinstance(pruning_fn, str):
            pruning_kwargs = {}
            pruning_fn = pruning_fn.lower()
            if pruning_fn not in _PYTORCH_PRUNING_FUNCTIONS:
                raise MisconfigurationException(
                    f"The provided `pruning_fn` {pruning_fn} isn't available in PyTorch's"
                    f" built-in functions: {list(_PYTORCH_PRUNING_FUNCTIONS.keys())} "
                )
            if pruning_fn.endswith("_structured"):
                if pruning_dim is None:
                    raise MisconfigurationException(
                        "When requesting `structured` pruning, the `pruning_dim` should be provided."
                    )
                if pruning_fn == "ln_structured":
                    if pruning_norm is None:
                        raise MisconfigurationException(
                            "When requesting `ln_structured` pruning, the `pruning_norm` should be provided."
                        )
                    pruning_kwargs["n"] = pruning_norm
                pruning_kwargs["dim"] = pruning_dim
            pruning_fn = self._create_pruning_fn(pruning_fn, **pruning_kwargs)
        elif self._is_pruning_method(pruning_fn):
            if not use_global_unstructured:
                raise MisconfigurationException(
                    "PyTorch `BasePruningMethod` is currently only supported with `use_global_unstructured=True`."
|
| 202 |
+
)
|
| 203 |
+
else:
|
| 204 |
+
raise MisconfigurationException(
|
| 205 |
+
f"`pruning_fn` is expected to be a str in {list(_PYTORCH_PRUNING_FUNCTIONS.keys())}"
|
| 206 |
+
f" or a PyTorch `BasePruningMethod`. Found: {pruning_fn}."
|
| 207 |
+
" HINT: if passing a `BasePruningMethod`, pass the the class, not an instance"
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
# need to ignore typing here since pytorch base class does not define the PRUNING_TYPE attribute
|
| 211 |
+
if use_global_unstructured and pruning_fn.PRUNING_TYPE != "unstructured": # type: ignore
|
| 212 |
+
raise MisconfigurationException(
|
| 213 |
+
'Only the "unstructured" PRUNING_TYPE is supported with `use_global_unstructured=True`.' # type: ignore
|
| 214 |
+
f" Found method {pruning_fn} of type {pruning_fn.PRUNING_TYPE}. "
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
self.pruning_fn = pruning_fn
|
| 218 |
+
self._apply_pruning = apply_pruning
|
| 219 |
+
self._make_pruning_permanent = make_pruning_permanent
|
| 220 |
+
|
| 221 |
+
if not (isinstance(amount, (int, float)) or callable(amount)):
|
| 222 |
+
raise MisconfigurationException(
|
| 223 |
+
"`amount` should be provided and be either an int, a float or Callable function."
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
self.amount = amount
|
| 227 |
+
|
| 228 |
+
if verbose not in (0, 1, 2):
|
| 229 |
+
raise MisconfigurationException("`verbose` must be any of (0, 1, 2)")
|
| 230 |
+
|
| 231 |
+
self._verbose = verbose
|
| 232 |
+
|
| 233 |
+
def filter_parameters_to_prune(self, parameters_to_prune: _PARAM_LIST = ()) -> _PARAM_LIST:
|
| 234 |
+
"""This function can be overridden to control which module to prune."""
|
| 235 |
+
return parameters_to_prune
|
| 236 |
+
|
| 237 |
+
def _create_pruning_fn(self, pruning_fn: str, **kwargs: Any) -> Union[Callable, pytorch_prune.BasePruningMethod]:
|
| 238 |
+
"""This function takes `pruning_fn`, a function name.
|
| 239 |
+
|
| 240 |
+
IF use_global_unstructured, pruning_fn will be resolved into its associated ``PyTorch BasePruningMethod`` ELSE,
|
| 241 |
+
pruning_fn will be resolved into its function counterpart from `torch.nn.utils.prune`.
|
| 242 |
+
"""
|
| 243 |
+
pruning_meth = (
|
| 244 |
+
_PYTORCH_PRUNING_METHOD[pruning_fn]
|
| 245 |
+
if self._use_global_unstructured
|
| 246 |
+
else _PYTORCH_PRUNING_FUNCTIONS[pruning_fn]
|
| 247 |
+
)
|
| 248 |
+
assert callable(pruning_meth), "Selected pruning method is not callable"
|
| 249 |
+
if self._use_global_unstructured:
|
| 250 |
+
self._global_kwargs = kwargs
|
| 251 |
+
# save the function __name__ now because partial does not include it
|
| 252 |
+
# and there are issues setting the attribute manually in ddp.
|
| 253 |
+
self._pruning_method_name = pruning_meth.__name__
|
| 254 |
+
if self._use_global_unstructured:
|
| 255 |
+
return pruning_meth
|
| 256 |
+
return ModelPruning._wrap_pruning_fn(pruning_meth, **kwargs)
|
| 257 |
+
|
| 258 |
+
@staticmethod
|
| 259 |
+
def _wrap_pruning_fn(pruning_fn: Callable, **kwargs: Any) -> Callable:
|
| 260 |
+
return partial(pruning_fn, **kwargs)
|
| 261 |
+
|
| 262 |
+
def make_pruning_permanent(self, module: nn.Module) -> None:
|
| 263 |
+
"""Removes pruning buffers from any pruned modules.
|
| 264 |
+
|
| 265 |
+
Adapted from https://github.com/pytorch/pytorch/blob/1.7.1/torch/nn/utils/prune.py#L1176-L1180
|
| 266 |
+
"""
|
| 267 |
+
for _, module in module.named_modules():
|
| 268 |
+
for k in list(module._forward_pre_hooks):
|
| 269 |
+
hook = module._forward_pre_hooks[k]
|
| 270 |
+
if isinstance(hook, pytorch_prune.BasePruningMethod):
|
| 271 |
+
hook.remove(module)
|
| 272 |
+
del module._forward_pre_hooks[k]
|
| 273 |
+
|
| 274 |
+
@staticmethod
|
| 275 |
+
def _copy_param(new: nn.Module, old: nn.Module, name: str) -> None:
|
| 276 |
+
dst = getattr(new, name)
|
| 277 |
+
src = getattr(old, name)
|
| 278 |
+
if dst is None or src is None or not isinstance(dst, torch.Tensor) or not isinstance(src, torch.Tensor):
|
| 279 |
+
return
|
| 280 |
+
dst.data = src.data.to(dst.device)
|
| 281 |
+
|
| 282 |
+
def apply_lottery_ticket_hypothesis(self) -> None:
|
| 283 |
+
r"""
|
| 284 |
+
Lottery ticket hypothesis algorithm (see page 2 of the paper):
|
| 285 |
+
|
| 286 |
+
1. Randomly initialize a neural network :math:`f(x; \theta_0)` (where :math:`\theta_0 \sim \mathcal{D}_\theta`).
|
| 287 |
+
2. Train the network for :math:`j` iterations, arriving at parameters :math:`\theta_j`.
|
| 288 |
+
3. Prune :math:`p\%` of the parameters in :math:`\theta_j`, creating a mask :math:`m`.
|
| 289 |
+
4. Reset the remaining parameters to their values in :math:`\theta_0`, creating the winning ticket :math:`f(x; m \odot \theta_0)`.
|
| 290 |
+
|
| 291 |
+
This function implements the step 4.
|
| 292 |
+
|
| 293 |
+
The ``resample_parameters`` argument can be used to reset the parameters with a new :math:`\theta_z \sim \mathcal{D}_\theta`
|
| 294 |
+
""" # noqa: E501
|
| 295 |
+
assert self._original_layers is not None
|
| 296 |
+
for d in self._original_layers.values():
|
| 297 |
+
copy = d["data"]
|
| 298 |
+
names = d["names"]
|
| 299 |
+
if self._resample_parameters and hasattr(copy, "reset_parameters") and callable(copy.reset_parameters):
|
| 300 |
+
copy = deepcopy(copy) # keep the original parameters
|
| 301 |
+
copy.reset_parameters()
|
| 302 |
+
for i, name in names:
|
| 303 |
+
new, new_name = self._parameters_to_prune[i]
|
| 304 |
+
self._copy_param(new, copy, name)
|
| 305 |
+
|
| 306 |
+
def _apply_local_pruning(self, amount: float) -> None:
|
| 307 |
+
for module, name in self._parameters_to_prune:
|
| 308 |
+
self.pruning_fn(module, name=name, amount=amount)
|
| 309 |
+
|
| 310 |
+
def _resolve_global_kwargs(self, amount: float) -> Dict[str, Any]:
|
| 311 |
+
self._global_kwargs["amount"] = amount
|
| 312 |
+
params = set(inspect.signature(self.pruning_fn).parameters)
|
| 313 |
+
params.discard("self")
|
| 314 |
+
return {k: v for k, v in self._global_kwargs.items() if k in params}
|
| 315 |
+
|
| 316 |
+
def _apply_global_pruning(self, amount: float) -> None:
|
| 317 |
+
pytorch_prune.global_unstructured(
|
| 318 |
+
self._parameters_to_prune, pruning_method=self.pruning_fn, **self._resolve_global_kwargs(amount)
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
@staticmethod
|
| 322 |
+
def _get_pruned_stats(module: nn.Module, name: str) -> Tuple[int, int]:
|
| 323 |
+
attr = f"{name}_mask"
|
| 324 |
+
if not hasattr(module, attr):
|
| 325 |
+
return 0, 1
|
| 326 |
+
mask = getattr(module, attr)
|
| 327 |
+
return (mask == 0).sum().item(), mask.numel()
|
| 328 |
+
|
| 329 |
+
def apply_pruning(self, amount: Union[int, float]) -> None:
|
| 330 |
+
"""Applies pruning to ``parameters_to_prune``."""
|
| 331 |
+
if self._verbose:
|
| 332 |
+
prev_stats = [self._get_pruned_stats(m, n) for m, n in self._parameters_to_prune]
|
| 333 |
+
|
| 334 |
+
if self._use_global_unstructured:
|
| 335 |
+
self._apply_global_pruning(amount)
|
| 336 |
+
else:
|
| 337 |
+
self._apply_local_pruning(amount)
|
| 338 |
+
|
| 339 |
+
if self._verbose:
|
| 340 |
+
curr_stats = [self._get_pruned_stats(m, n) for m, n in self._parameters_to_prune]
|
| 341 |
+
self._log_sparsity_stats(prev_stats, curr_stats, amount=amount)
|
| 342 |
+
|
| 343 |
+
@rank_zero_only
|
| 344 |
+
def _log_sparsity_stats(
|
| 345 |
+
self, prev: List[Tuple[int, int]], curr: List[Tuple[int, int]], amount: Union[int, float] = 0
|
| 346 |
+
) -> None:
|
| 347 |
+
total_params = sum(p.numel() for layer, _ in self._parameters_to_prune for p in layer.parameters())
|
| 348 |
+
prev_total_zeros = sum(zeros for zeros, _ in prev)
|
| 349 |
+
curr_total_zeros = sum(zeros for zeros, _ in curr)
|
| 350 |
+
log.info(
|
| 351 |
+
f"Applied `{self._pruning_method_name}`. Pruned:"
|
| 352 |
+
f" {prev_total_zeros}/{total_params} ({prev_total_zeros / total_params:.2%}) ->"
|
| 353 |
+
f" {curr_total_zeros}/{total_params} ({curr_total_zeros / total_params:.2%})"
|
| 354 |
+
)
|
| 355 |
+
if self._verbose == 2:
|
| 356 |
+
for i, (module, name) in enumerate(self._parameters_to_prune):
|
| 357 |
+
prev_mask_zeros, prev_mask_size = prev[i]
|
| 358 |
+
curr_mask_zeros, curr_mask_size = curr[i]
|
| 359 |
+
log.info(
|
| 360 |
+
f"Applied `{self._pruning_method_name}` to `{module!r}.{name}` with amount={amount}. Pruned:"
|
| 361 |
+
f" {prev_mask_zeros} ({prev_mask_zeros / prev_mask_size:.2%}) ->"
|
| 362 |
+
f" {curr_mask_zeros} ({curr_mask_zeros / curr_mask_size:.2%})"
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
|
| 366 |
+
parameters_to_prune = self.sanitize_parameters_to_prune(
|
| 367 |
+
pl_module, self._parameters_to_prune, parameter_names=self._parameter_names
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
self._parameters_to_prune = self.filter_parameters_to_prune(parameters_to_prune)
|
| 371 |
+
|
| 372 |
+
if self._use_lottery_ticket_hypothesis:
|
| 373 |
+
# group modules by id. Each entry has a copy of the initial data
|
| 374 |
+
# and a list of the associated parameter names to prune
|
| 375 |
+
self._original_layers = {}
|
| 376 |
+
for i, (module, name) in enumerate(self._parameters_to_prune):
|
| 377 |
+
id_ = id(module)
|
| 378 |
+
self._original_layers.setdefault(id_, _LayerRef(data=deepcopy(module), names=[]))
|
| 379 |
+
self._original_layers[id_]["names"].append((i, name))
|
| 380 |
+
|
| 381 |
+
def _run_pruning(self, current_epoch: int) -> None:
|
| 382 |
+
prune = self._apply_pruning(current_epoch) if callable(self._apply_pruning) else self._apply_pruning
|
| 383 |
+
amount = self.amount(current_epoch) if callable(self.amount) else self.amount
|
| 384 |
+
if not prune or not amount:
|
| 385 |
+
return
|
| 386 |
+
self.apply_pruning(amount)
|
| 387 |
+
|
| 388 |
+
if (
|
| 389 |
+
self._use_lottery_ticket_hypothesis(current_epoch)
|
| 390 |
+
if callable(self._use_lottery_ticket_hypothesis)
|
| 391 |
+
else self._use_lottery_ticket_hypothesis
|
| 392 |
+
):
|
| 393 |
+
self.apply_lottery_ticket_hypothesis()
|
| 394 |
+
|
| 395 |
+
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: LightningModule) -> None:
|
| 396 |
+
if self._prune_on_train_epoch_end:
|
| 397 |
+
rank_zero_debug("`ModelPruning.on_train_epoch_end`. Applying pruning")
|
| 398 |
+
self._run_pruning(pl_module.current_epoch)
|
| 399 |
+
|
| 400 |
+
def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
| 401 |
+
if not trainer.sanity_checking and not self._prune_on_train_epoch_end:
|
| 402 |
+
rank_zero_debug("`ModelPruning.on_validation_epoch_end`. Applying pruning")
|
| 403 |
+
self._run_pruning(pl_module.current_epoch)
|
| 404 |
+
|
| 405 |
+
def on_train_end(self, trainer: "pl.Trainer", pl_module: LightningModule) -> None:
|
| 406 |
+
if self._make_pruning_permanent:
|
| 407 |
+
rank_zero_debug("`ModelPruning.on_train_end`. Pruning is made permanent for this checkpoint")
|
| 408 |
+
self.make_pruning_permanent(pl_module)
|
| 409 |
+
|
| 410 |
+
def _make_pruning_permanent_on_state_dict(self, pl_module: LightningModule) -> Dict[str, Any]:
|
| 411 |
+
state_dict = pl_module.state_dict()
|
| 412 |
+
|
| 413 |
+
# find the mask and the original weights.
|
| 414 |
+
map_pruned_params = {k.replace("_mask", "") for k in state_dict.keys() if k.endswith("_mask")}
|
| 415 |
+
for tensor_name in map_pruned_params:
|
| 416 |
+
orig = state_dict.pop(tensor_name + "_orig")
|
| 417 |
+
mask = state_dict.pop(tensor_name + "_mask")
|
| 418 |
+
# make weights permanent
|
| 419 |
+
state_dict[tensor_name] = mask.to(dtype=orig.dtype) * orig
|
| 420 |
+
|
| 421 |
+
def move_to_cpu(tensor: torch.Tensor) -> torch.Tensor:
|
| 422 |
+
# each tensor and move them on cpu
|
| 423 |
+
return tensor.cpu()
|
| 424 |
+
|
| 425 |
+
return apply_to_collection(state_dict, torch.Tensor, move_to_cpu)
|
| 426 |
+
|
| 427 |
+
def on_save_checkpoint(
|
| 428 |
+
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
|
| 429 |
+
) -> Optional[dict]:
|
| 430 |
+
if self._make_pruning_permanent:
|
| 431 |
+
rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint")
|
| 432 |
+
# manually prune the weights so training can keep going with the same buffers
|
| 433 |
+
checkpoint["state_dict"] = self._make_pruning_permanent_on_state_dict(pl_module)
|
| 434 |
+
|
| 435 |
+
@staticmethod
|
| 436 |
+
def sanitize_parameters_to_prune(
|
| 437 |
+
pl_module: LightningModule, parameters_to_prune: _PARAM_LIST = (), parameter_names: Sequence[str] = ()
|
| 438 |
+
) -> _PARAM_LIST:
|
| 439 |
+
"""This function is responsible of sanitizing ``parameters_to_prune`` and ``parameter_names``. If
|
| 440 |
+
``parameters_to_prune is None``, it will be generated with all parameters of the model.
|
| 441 |
+
|
| 442 |
+
Raises:
|
| 443 |
+
MisconfigurationException:
|
| 444 |
+
If ``parameters_to_prune`` doesn't exist in the model, or
|
| 445 |
+
if ``parameters_to_prune`` is neither a list nor a tuple.
|
| 446 |
+
"""
|
| 447 |
+
parameters = parameter_names or ModelPruning.PARAMETER_NAMES
|
| 448 |
+
|
| 449 |
+
current_modules = [m for m in pl_module.modules() if not isinstance(m, _MODULE_CONTAINERS)]
|
| 450 |
+
|
| 451 |
+
if not parameters_to_prune:
|
| 452 |
+
parameters_to_prune = [
|
| 453 |
+
(m, p) for p in parameters for m in current_modules if getattr(m, p, None) is not None
|
| 454 |
+
]
|
| 455 |
+
elif (
|
| 456 |
+
isinstance(parameters_to_prune, (list, tuple))
|
| 457 |
+
and len(parameters_to_prune) > 0
|
| 458 |
+
and all(len(p) == 2 for p in parameters_to_prune)
|
| 459 |
+
and all(isinstance(a, nn.Module) and isinstance(b, str) for a, b in parameters_to_prune)
|
| 460 |
+
):
|
| 461 |
+
missing_modules, missing_parameters = [], []
|
| 462 |
+
for module, name in parameters_to_prune:
|
| 463 |
+
if module not in current_modules:
|
| 464 |
+
missing_modules.append(module)
|
| 465 |
+
continue
|
| 466 |
+
if not hasattr(module, name):
|
| 467 |
+
missing_parameters.append(name)
|
| 468 |
+
|
| 469 |
+
if missing_modules or missing_parameters:
|
| 470 |
+
raise MisconfigurationException(
|
| 471 |
+
"Some provided `parameters_to_tune` don't exist in the model."
|
| 472 |
+
f" Found missing modules: {missing_modules} and missing parameters: {missing_parameters}"
|
| 473 |
+
)
|
| 474 |
+
else:
|
| 475 |
+
raise MisconfigurationException(
|
| 476 |
+
"The provided `parameters_to_prune` should either be list of tuple"
|
| 477 |
+
" with 2 elements: (nn.Module, parameter_name_to_prune) or None"
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
return parameters_to_prune
|
| 481 |
+
|
| 482 |
+
@staticmethod
|
| 483 |
+
def _is_pruning_method(method: Any) -> bool:
|
| 484 |
+
if not inspect.isclass(method):
|
| 485 |
+
return False
|
| 486 |
+
return issubclass(method, pytorch_prune.BasePruningMethod)
|
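
As a quick illustration of the callable ``amount`` scheduling described in the ``ModelPruning`` docstring above, here is a minimal usage sketch; the schedule values and the ``LitModel`` module are assumptions for illustration, not part of the file:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelPruning


    def prune_amount(epoch: int) -> float:
        # hypothetical schedule: skip the warm-up epochs, then prune a bit every epoch
        if epoch < 5:
            return 0.0  # a falsy amount makes the callback skip pruning for this epoch
        return 0.1 if epoch < 20 else 0.05


    trainer = Trainer(
        max_epochs=30,
        callbacks=[
            ModelPruning(
                pruning_fn="l1_unstructured",  # resolved to torch.nn.utils.prune
                amount=prune_amount,           # called every epoch with the current epoch index
                use_global_unstructured=True,
                make_pruning_permanent=True,   # masks are folded into saved checkpoints
            )
        ],
    )
    # trainer.fit(LitModel())  # LitModel is a user-defined LightningModule (assumed)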
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/quantization.py
ADDED
@@ -0,0 +1,344 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Quantization
^^^^^^^^^^^^

"""
import copy
import functools
from typing import Any, Callable, Dict, Optional, Sequence, Union

import torch
from torch import Tensor
from torch.quantization import FakeQuantizeBase

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, _TORCH_GREATER_EQUAL_1_11
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _TORCH_GREATER_EQUAL_1_10:
    from torch.ao.quantization.qconfig import QConfig
else:
    from torch.quantization import QConfig

if _TORCH_GREATER_EQUAL_1_11:
    from torch.ao.quantization import fuse_modules_qat as fuse_modules
else:
    from torch.quantization import fuse_modules


def wrap_qat_forward_context(
    quant_cb, model: "pl.LightningModule", func: Callable, trigger_condition: Optional[Union[Callable, int]] = None
) -> Callable:
    """Decorator to wrap forward path as it is needed to quantize inputs and dequantize outputs for in/out
    compatibility Moreover this version has the (de)quantization conditional as it may not be needed for the
    training all the time."""
    # todo: consider using registering hook before/after forward
    @functools.wraps(func)
    def wrapper(data) -> Any:
        _is_func_true = isinstance(trigger_condition, Callable) and trigger_condition(model.trainer)
        _is_count_true = isinstance(trigger_condition, int) and quant_cb._forward_calls < trigger_condition
        _quant_run = trigger_condition is None or _is_func_true or _is_count_true
        # apply custom trigger
        if _quant_run:
            quant_cb._forward_calls += 1
            data = model.quant(data)
        data = func(data)
        # apply custom trigger
        if _quant_run:
            data = model.dequant(data)
        return data

    return wrapper


def wrap_quantize_forward_context(model: "pl.LightningModule", func: Callable) -> Callable:
    """Decorator to wrap forward path as it is needed to quantize inputs and dequantize outputs for in/out
    compatibility."""
    # todo: consider using registering hook before/after forward
    @functools.wraps(func)
    def wrapper(data) -> Any:
        data = model.quant(data)
        data = func(data)
        data = model.dequant(data)
        return data

    return wrapper


def _recursive_hasattr(obj: Any, attribs: str, state: bool = True) -> bool:
    """recursive check if model has some layers denoted with '.'."""
    if "." in attribs:
        attrib, attribs = attribs.split(".", 1)
        if hasattr(obj, attrib):
            return _recursive_hasattr(getattr(obj, attrib), attribs, state)
        return False
    return state and hasattr(obj, attribs)


class QuantizationAwareTraining(Callback):
    """Quantization allows speeding up inference and decreasing memory requirements by performing computations and
    storing tensors at lower bitwidths (such as INT8 or FLOAT16) than floating point precision. We use native
    PyTorch API so for more information see `PyTorch Quantization`_.

    .. warning:: ``QuantizationAwareTraining`` is in beta and subject to change.

    The ``LightningModule`` is prepared for QAT training in the ``on_fit_start`` hook. Checkpoints saved during training
    include already collected stats to perform the Quantization conversion, but it doesn't contain the quantized or
    fused model/layers. The quantization is performed in the ``on_fit_end`` hook so the model needs to be saved after
    training finishes if quantization is desired.

    Args:

        qconfig: quantization configuration:

            - 'fbgemm' for server inference.
            - 'qnnpack' for mobile inference.
            - a custom `torch.quantization.QConfig`_.

        observer_type: allows switching between ``MovingAverageMinMaxObserver`` as "average" (default)
            and ``HistogramObserver`` as "histogram" which is more computationally expensive.

        collect_quantization: count or custom function to collect quantization statistics:

            - ``None`` (default). The quantization observer is called in each module forward
              (useful for collecting extended statistic when using image/data augmentation).
            - ``int``. Use to set a fixed number of calls, starting from the beginning.
            - ``Callable``. Custom function with single trainer argument.
              See this example to trigger only the last epoch:

              .. code-block:: python

                  def custom_trigger_last(trainer):
                      return trainer.current_epoch == (trainer.max_epochs - 1)


                  QuantizationAwareTraining(collect_quantization=custom_trigger_last)

        modules_to_fuse: allows you fuse a few layers together as shown in
            `diagram <https://pytorch.org/docs/stable/quantization.html#quantization-aware-training>`_
            to find which layer types can be fused, check https://github.com/pytorch/pytorch/pull/43286.

        input_compatible: preserve quant/dequant layers. This allows to feat any input as to the original model,
            but break compatibility to torchscript and export with ``torch.save``.

        quantize_on_fit_end: perform the quantization in `on_fit_end`.
            Note that once converted, the model cannot be put in training mode again.

        observer_enabled_stages: allow fake-quantization modules' observers to do calibration during provided stages:

            - ``'train'``: the observers can do calibration during training.
            - ``'validate'``: the observers can do calibration during validating.
              Note that we don't disable observers during the sanity check as the model hasn't been calibrated with
              training data yet. After the sanity check, the fake-quantization modules are restored to initial states.
            - ``'test'``: the observers can do calibration during testing.
            - ``'predict'``: the observers can do calibration during predicting.

        Note that we only handle observers belonging to fake-quantization modules. When ``qconfig`` is a ``str`` and
        ``observer_type`` is ``'histogram'``, the observers won't belong to any fake-quantization modules and will
        not be controlled by the callback.

    .. _PyTorch Quantization: https://pytorch.org/docs/stable/quantization.html#quantization-aware-training
    .. _torch.quantization.QConfig: https://pytorch.org/docs/stable/torch.quantization.html#torch.quantization.QConfig
    """

    OBSERVER_TYPES = ("histogram", "average")
    OBSERVER_STAGES = ("train", "validate", "test", "predict")

    def __init__(
        self,
        qconfig: Union[str, QConfig] = "fbgemm",
        observer_type: str = "average",
        collect_quantization: Optional[Union[int, Callable]] = None,
        modules_to_fuse: Optional[Sequence] = None,
        input_compatible: bool = True,
        quantize_on_fit_end: bool = True,
        observer_enabled_stages: Sequence[str] = ("train",),
    ) -> None:
        _valid_qconf_str = isinstance(qconfig, str) and qconfig in torch.backends.quantized.supported_engines
        if not isinstance(qconfig, QConfig) and not _valid_qconf_str:
            raise MisconfigurationException(
                f"Unsupported qconfig: f{qconfig}.\nTry one of defaults: {torch.backends.quantized.supported_engines}"
            )
        self._qconfig = qconfig

        if observer_type not in self.OBSERVER_TYPES:
            raise MisconfigurationException(
                f'Unsupported observer type "{observer_type}", allowed are {self.OBSERVER_TYPES}.'
            )
        self._observer_type = observer_type

        if collect_quantization is not None and not isinstance(collect_quantization, (int, Callable)):
            raise MisconfigurationException(
                f'Unsupported `collect_quantization` "{collect_quantization}", allowed are `int` or `Callable`.'
            )
        self._collect_quantization = collect_quantization

        self._modules_to_fuse = modules_to_fuse
        self._input_compatible = input_compatible
        self._convert_on_fit_end = quantize_on_fit_end

        observer_enabled_stages = set(observer_enabled_stages)
        unsupported_stages = observer_enabled_stages - set(self.OBSERVER_STAGES)
        if unsupported_stages:
            raise MisconfigurationException(
                f'Unsupported stages "{tuple(sorted(unsupported_stages))}", allowed are {self.OBSERVER_STAGES}.'
            )
        self._observer_disabled_stages = set(self.OBSERVER_STAGES) - observer_enabled_stages

        self._forward_calls = 0
        self._fake_quant_to_initial_state_dict = {}
        self._last_fake_quant_to_observer_enabled = {}
        self._module_prepared = False

    def _check_feasible_fuse(self, model: "pl.LightningModule") -> bool:
        if not self._modules_to_fuse:
            return False
        for group in self._modules_to_fuse:
            if not all(_recursive_hasattr(model, m) for m in group):
                raise MisconfigurationException(
                    f"You have requested to fuse {group} but one or more of them is not your model attributes"
                )
        return True

    def _collect_observer_enabled(self) -> Dict[FakeQuantizeBase, Tensor]:
        return {
            fake_quant: fake_quant.observer_enabled.clone() for fake_quant in self._fake_quant_to_initial_state_dict
        }

    def _disable_observer(self, pl_module: "pl.LightningModule") -> None:
        self._last_fake_quant_to_observer_enabled = self._collect_observer_enabled()
        pl_module.apply(torch.quantization.disable_observer)

    def _restore_last_observer_enabled(self) -> None:
        for fake_quant, observer_enabled in self._last_fake_quant_to_observer_enabled.items():
            fake_quant.observer_enabled.copy_(observer_enabled)

    def _prepare_model(self, model: torch.nn.Module) -> None:
        if self._module_prepared:
            return
        # QuantStub converts tensors from floating point to quantized
        model.quant = torch.quantization.QuantStub()
        # DeQuantStub converts tensors from quantized to floating point
        model.dequant = torch.quantization.DeQuantStub()
        # manually specify where tensors will be converted from quantized
        # to floating point in the quantized model
        self.__module_forward = model.forward
        model.forward = wrap_qat_forward_context(
            quant_cb=self, model=model, func=model.forward, trigger_condition=self._collect_quantization
        )

        # attach a global qconfig, which contains information about what kind
        # of observers to attach. Use 'fbgemm' for server inference
        if isinstance(self._qconfig, str):
            if self._observer_type == "histogram":
                model.qconfig = torch.quantization.get_default_qconfig(self._qconfig)
            elif self._observer_type == "average":
                # version=None corresponds to using FakeQuantize rather than
                # FusedMovingAvgObsFakeQuantize which was introduced in PT1.10
                # details in https://github.com/pytorch/pytorch/issues/64564
                extra_kwargs = dict(version=None) if _TORCH_GREATER_EQUAL_1_10 else {}
                model.qconfig = torch.quantization.get_default_qat_qconfig(self._qconfig, **extra_kwargs)

        elif isinstance(self._qconfig, QConfig):
            model.qconfig = self._qconfig

        if self._check_feasible_fuse(model):
            fuse_modules(model, self._modules_to_fuse, inplace=True)

        # Prepare the model for QAT. This inserts observers and fake_quants in
        # the model that will observe weight and activation tensors during calibration.
        torch.quantization.prepare_qat(model, inplace=True)

        fake_quants = tuple(module for module in model.modules() if isinstance(module, FakeQuantizeBase))
        self._fake_quant_to_initial_state_dict = {
            fake_quant: copy.deepcopy(fake_quant.state_dict()) for fake_quant in fake_quants
        }
        self._module_prepared = True

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
        self._prepare_model(pl_module)

    def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if not self._convert_on_fit_end:
            pl_module.forward = self.__module_forward
            return
        pl_module.eval()
        # Convert the observed model to a quantized model. This does several things:
        # quantizes the weights, computes and stores the scale and bias value to be
        # used with each activation tensor, fuses modules where appropriate,
        # and replaces key operators with quantized implementations.
        torch.quantization.convert(pl_module, inplace=True)
        # check we shall preserve wrapper
        if self._input_compatible:
            pl_module.forward = wrap_quantize_forward_context(model=pl_module, func=self.__module_forward)
        else:
            pl_module.forward = self.__module_forward

    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if "train" in self._observer_disabled_stages:
            self._disable_observer(pl_module)

    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if "train" in self._observer_disabled_stages:
            self._restore_last_observer_enabled()

    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if "validate" in self._observer_disabled_stages and not trainer.sanity_checking:
            # ``torch.quantization.MovingAveragePerChannelMinMaxObserver`` and ``torch.quantization.HistogramObserver``
            # need to see at least one batch to infer the shapes of quantization ``scale`` and ``zero_point``. So we
            # don't disable observers during the sanity check so that they can infer the shapes of quantization
            # parameters with validation data.
            self._disable_observer(pl_module)

    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if "validate" in self._observer_disabled_stages:
            if trainer.sanity_checking:
                for fake_quant, state_dict in self._fake_quant_to_initial_state_dict.items():
                    fake_quant.load_state_dict(state_dict)
            else:
                self._restore_last_observer_enabled()

    def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if "test" in self._observer_disabled_stages:
            self._disable_observer(pl_module)

    def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if "test" in self._observer_disabled_stages:
            self._restore_last_observer_enabled()

    def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if "predict" in self._observer_disabled_stages:
            self._disable_observer(pl_module)

    def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if "predict" in self._observer_disabled_stages:
            self._restore_last_observer_enabled()

    def state_dict(self) -> Dict[str, Any]:
        keys = {"_qconfig", "_observer_type", "_collect_quantization", "_modules_to_fuse", "_input_compatible"}
        return {n: getattr(self, n) for n in keys}

    def _load_before_model(self, model: torch.nn.Module, state_dict: Dict[str, Any]) -> None:
        """Special hook that gets called by the CheckpointConnector *before* the model gets loaded.

        This hook replaces the :meth:`on_load_checkpoint` and :meth:`load_state_dict` callback methods which get called
        after the model has already loaded the weights. For quantization, we need to convert the model first before that
        happens, assuming the previous training used quantization.
        """
        for k, v in state_dict.items():
            setattr(self, k, v)
        self._prepare_model(model)
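
A minimal usage sketch for the ``QuantizationAwareTraining`` callback defined above, assuming a user-defined ``LitModel`` LightningModule (hypothetical) and default observer settings:

    import torch
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import QuantizationAwareTraining

    qcb = QuantizationAwareTraining(
        qconfig="fbgemm",                    # server-style default qconfig
        observer_type="average",             # MovingAverageMinMaxObserver-based fake-quant
        observer_enabled_stages=("train",),  # calibrate only while training
        quantize_on_fit_end=True,            # convert to the quantized model in `on_fit_end`
    )

    trainer = Trainer(callbacks=[qcb], max_epochs=10)
    # trainer.fit(LitModel())  # LitModel is an assumption, not part of this file
    # the converted model must be saved explicitly after fit, e.g.:
    # torch.save(trainer.lightning_module.state_dict(), "qat_model.pt")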
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/rich_model_summary.py
ADDED
@@ -0,0 +1,109 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple

from pytorch_lightning.callbacks import ModelSummary
from pytorch_lightning.utilities.imports import _RICH_AVAILABLE
from pytorch_lightning.utilities.model_summary import get_human_readable_count

if _RICH_AVAILABLE:
    from rich import get_console
    from rich.table import Table


class RichModelSummary(ModelSummary):
    r"""
    Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`
    with `rich text formatting <https://github.com/willmcgugan/rich>`_.

    Install it with pip:

    .. code-block:: bash

        pip install rich

    .. code-block:: python

        from pytorch_lightning import Trainer
        from pytorch_lightning.callbacks import RichModelSummary

        trainer = Trainer(callbacks=RichModelSummary())

    You could also enable ``RichModelSummary`` using the :class:`~pytorch_lightning.callbacks.RichProgressBar`

    .. code-block:: python

        from pytorch_lightning import Trainer
        from pytorch_lightning.callbacks import RichProgressBar

        trainer = Trainer(callbacks=RichProgressBar())

    Args:
        max_depth: The maximum depth of layer nesting that the summary will include. A value of 0 turns the
            layer summary off.

    Raises:
        ModuleNotFoundError:
            If required `rich` package is not installed on the device.
    """

    def __init__(self, max_depth: int = 1) -> None:
        if not _RICH_AVAILABLE:
            raise ModuleNotFoundError(
                "`RichModelSummary` requires `rich` to be installed. Install it by running `pip install -U rich`."
            )
        super().__init__(max_depth)

    @staticmethod
    def summarize(
        summary_data: List[Tuple[str, List[str]]],
        total_parameters: int,
        trainable_parameters: int,
        model_size: float,
    ) -> None:

        console = get_console()

        table = Table(header_style="bold magenta")
        table.add_column(" ", style="dim")
        table.add_column("Name", justify="left", no_wrap=True)
        table.add_column("Type")
        table.add_column("Params", justify="right")

        column_names = list(zip(*summary_data))[0]

        for column_name in ["In sizes", "Out sizes"]:
            if column_name in column_names:
                table.add_column(column_name, justify="right", style="white")

        rows = list(zip(*(arr[1] for arr in summary_data)))
        for row in rows:
            table.add_row(*row)

        console.print(table)

        parameters = []
        for param in [trainable_parameters, total_parameters - trainable_parameters, total_parameters, model_size]:
            parameters.append("{:<{}}".format(get_human_readable_count(int(param)), 10))

        grid = Table.grid(expand=True)
        grid.add_column()
        grid.add_column()

        grid.add_row(f"[bold]Trainable params[/]: {parameters[0]}")
        grid.add_row(f"[bold]Non-trainable params[/]: {parameters[1]}")
        grid.add_row(f"[bold]Total params[/]: {parameters[2]}")
        grid.add_row(f"[bold]Total estimated model params size (MB)[/]: {parameters[3]}")

        console.print(grid)
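
The ``summary_data`` argument of ``RichModelSummary.summarize`` above is column-oriented: each entry is a ``(column header, list of cell values)`` pair. A small self-contained sketch with made-up layer names and parameter counts (requires ``rich`` to be installed):

    from pytorch_lightning.callbacks import RichModelSummary

    summary_data = [
        (" ", ["0", "1"]),  # layer index column
        ("Name", ["encoder", "decoder"]),
        ("Type", ["Linear", "Linear"]),
        ("Params", ["33.6 K", "33.8 K"]),
    ]
    # prints the layer table followed by the parameter-count grid to the console
    RichModelSummary.summarize(
        summary_data, total_parameters=67_400, trainable_parameters=67_400, model_size=0.27
    )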
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/stochastic_weight_avg.py
ADDED
@@ -0,0 +1,280 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Stochastic Weight Averaging Callback
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
"""
from copy import deepcopy
from typing import Callable, List, Optional, Union

import torch
from torch import nn
from torch.optim.swa_utils import SWALR

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.types import LRSchedulerConfig

_AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.LongTensor], torch.FloatTensor]


class StochasticWeightAveraging(Callback):
    def __init__(
        self,
        swa_epoch_start: Union[int, float] = 0.8,
        swa_lrs: Optional[Union[float, List[float]]] = None,
        annealing_epochs: int = 10,
        annealing_strategy: str = "cos",
        avg_fn: Optional[_AVG_FN] = None,
        device: Optional[Union[torch.device, str]] = torch.device("cpu"),
    ):
        r"""

        Implements the Stochastic Weight Averaging (SWA) Callback to average a model.

        Stochastic Weight Averaging was proposed in ``Averaging Weights Leads to
        Wider Optima and Better Generalization`` by Pavel Izmailov, Dmitrii
        Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson
        (UAI 2018).

        This documentation is highly inspired by PyTorch's work on SWA.
        The callback arguments follow the scheme defined in PyTorch's ``swa_utils`` package.

        For a SWA explanation, please take a look
        `here <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging>`_.

        .. warning:: ``StochasticWeightAveraging`` is in beta and subject to change.

        .. warning:: ``StochasticWeightAveraging`` is currently not supported for multiple optimizers/schedulers.

        .. warning:: ``StochasticWeightAveraging`` is currently only supported on every epoch.

        See also how to :ref:`enable it directly on the Trainer <advanced/training_tricks:Stochastic Weight Averaging>`

        Arguments:

            swa_epoch_start: If provided as int, the procedure will start from
                the ``swa_epoch_start``-th epoch. If provided as float between 0 and 1,
                the procedure will start from ``int(swa_epoch_start * max_epochs)`` epoch

            swa_lrs: The SWA learning rate to use:

                - ``None``. Use the current learning rate of the optimizer at the time the SWA procedure starts.
                - ``float``. Use this value for all parameter groups of the optimizer.
                - ``List[float]``. A list values for each parameter group of the optimizer.

            annealing_epochs: number of epochs in the annealing phase (default: 10)

            annealing_strategy: Specifies the annealing strategy (default: "cos"):

                - ``"cos"``. For cosine annealing.
                - ``"linear"`` For linear annealing

            avg_fn: the averaging function used to update the parameters;
                the function must take in the current value of the
                :class:`AveragedModel` parameter, the current value of :attr:`model`
                parameter and the number of models already averaged; if None,
                equally weighted average is used (default: ``None``)

            device: if provided, the averaged model will be stored on the ``device``.
                When None is provided, it will infer the `device` from ``pl_module``.
                (default: ``"cpu"``)

        """

        err_msg = "swa_epoch_start should be a >0 integer or a float between 0 and 1."
        if isinstance(swa_epoch_start, int) and swa_epoch_start < 1:
            raise MisconfigurationException(err_msg)
        if isinstance(swa_epoch_start, float) and not (0 <= swa_epoch_start <= 1):
            raise MisconfigurationException(err_msg)

        wrong_type = not isinstance(swa_lrs, (float, list))
        wrong_float = isinstance(swa_lrs, float) and swa_lrs <= 0
        wrong_list = isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs)
        if swa_lrs is not None and (wrong_type or wrong_float or wrong_list):
            raise MisconfigurationException(
                "The `swa_lrs` should be `None`, a positive float, or a list of positive floats"
            )

        if avg_fn is not None and not isinstance(avg_fn, Callable):
            raise MisconfigurationException("The `avg_fn` should be callable.")

        if device is not None and not isinstance(device, (torch.device, str)):
            raise MisconfigurationException(f"device is expected to be a torch.device or a str. Found {device}")

        self._swa_epoch_start = swa_epoch_start
        self._swa_lrs = swa_lrs
        self._annealing_epochs = annealing_epochs
        self._annealing_strategy = annealing_strategy
        self._avg_fn = avg_fn or self.avg_fn
        self._device = device
        self._model_contains_batch_norm = None
        self._average_model = None

    @property
    def swa_start(self) -> int:
        return max(self._swa_epoch_start - 1, 0)  # 0-based

    @property
    def swa_end(self) -> int:
        return self._max_epochs - 1  # 0-based

    @staticmethod
    def pl_module_contains_batch_norm(pl_module: "pl.LightningModule"):
        return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules())

    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
        # copy the model before moving it to accelerator device.
        with pl_module._prevent_trainer_and_dataloaders_deepcopy():
            self._average_model = deepcopy(pl_module)

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
        if len(trainer.optimizers) != 1:
            raise MisconfigurationException("SWA currently works with 1 `optimizer`.")

        if len(trainer.lr_scheduler_configs) > 1:
            raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.")

        if isinstance(self._swa_epoch_start, float):
            self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start)

        self._model_contains_batch_norm = self.pl_module_contains_batch_norm(pl_module)

        self._max_epochs = trainer.max_epochs
        if self._model_contains_batch_norm:
            # virtually increase max_epochs to perform batch norm update on latest epoch.
            trainer.fit_loop.max_epochs += 1

    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
        if trainer.current_epoch == self.swa_start:
            # move average model to request device.
            self._average_model = self._average_model.to(self._device or pl_module.device)

            optimizer = trainer.optimizers[0]
            if self._swa_lrs is None:
                self._swa_lrs = [param_group["lr"] for param_group in optimizer.param_groups]
            if isinstance(self._swa_lrs, float):
                self._swa_lrs = [self._swa_lrs] * len(optimizer.param_groups)

            for lr, group in zip(self._swa_lrs, optimizer.param_groups):
                group["initial_lr"] = lr

            self._swa_scheduler = SWALR(
                optimizer,
                swa_lr=self._swa_lrs,
                anneal_epochs=self._annealing_epochs,
                anneal_strategy=self._annealing_strategy,
                last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,
            )
            # We assert that there is only one optimizer on fit start, so know opt_idx is always 0
            default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler, opt_idx=0)
            assert default_scheduler_cfg.interval == "epoch" and default_scheduler_cfg.frequency == 1

            if trainer.lr_scheduler_configs:
                scheduler_cfg = trainer.lr_scheduler_configs[0]
                if scheduler_cfg.interval != "epoch" or scheduler_cfg.frequency != 1:
                    rank_zero_warn(f"SWA is currently only supported every epoch. Found {scheduler_cfg}")
                rank_zero_info(
                    f"Swapping scheduler `{scheduler_cfg.scheduler.__class__.__name__}`"
                    f" for `{self._swa_scheduler.__class__.__name__}`"
                )
                trainer.lr_scheduler_configs[0] = default_scheduler_cfg
            else:
                trainer.lr_scheduler_configs.append(default_scheduler_cfg)

            self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)

        if self.swa_start <= trainer.current_epoch <= self.swa_end:
            self.update_parameters(self._average_model, pl_module, self.n_averaged, self._avg_fn)

        # Note: No > here in case the callback is saved with the model and training continues
        if trainer.current_epoch == self.swa_end + 1:

            # Transfer weights from average model to pl_module
            self.transfer_weights(self._average_model, pl_module)

            # Reset BatchNorm for update
            self.reset_batch_norm_and_save_state(pl_module)

            # There is no need to perform either backward or optimizer.step as we are
            # performing only one pass over the train data-loader to compute activation statistics
            # Therefore, we will virtually increase `num_training_batches` by 1 and skip backward.
            trainer.num_training_batches += 1
            trainer.fit_loop._skip_backward = True
            self._accumulate_grad_batches = trainer.accumulate_grad_batches

            trainer.accumulate_grad_batches = trainer.num_training_batches

    def on_train_epoch_end(self, trainer: "pl.Trainer", *args):
        trainer.fit_loop._skip_backward = False

    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
        # the trainer increases the current epoch before this hook is called
        if self._model_contains_batch_norm and trainer.current_epoch - 1 == self.swa_end + 1:
            # BatchNorm epoch update. Reset state
            trainer.accumulate_grad_batches = self._accumulate_grad_batches
            trainer.num_training_batches -= 1
            trainer.fit_loop.max_epochs -= 1
            self.reset_momenta()
        elif trainer.current_epoch - 1 == self.swa_end:
            # Last SWA epoch. Transfer weights from average model to pl_module
            self.transfer_weights(self._average_model, pl_module)

    @staticmethod
    def transfer_weights(src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule"):
        for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()):
            dst_param.detach().copy_(src_param.to(dst_param.device))

    def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule"):
        """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154."""
        self.momenta = {}
        for module in pl_module.modules():
            if not isinstance(module, nn.modules.batchnorm._BatchNorm):
                continue
            module.running_mean = torch.zeros_like(
                module.running_mean, device=pl_module.device, dtype=module.running_mean.dtype
            )
            module.running_var = torch.ones_like(
                module.running_var, device=pl_module.device, dtype=module.running_var.dtype
            )
            self.momenta[module] = module.momentum
            module.momentum = None
            module.num_batches_tracked *= 0

    def reset_momenta(self):
        """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165."""
        for bn_module in self.momenta:
            bn_module.momentum = self.momenta[bn_module]

    @staticmethod
    def update_parameters(
        average_model: "pl.LightningModule", model: "pl.LightningModule", n_averaged: torch.LongTensor, avg_fn: _AVG_FN
    ):
        """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112."""
        for p_swa, p_model in zip(average_model.parameters(), model.parameters()):
            device = p_swa.device
            p_swa_ = p_swa.detach()
            p_model_ = p_model.detach().to(device)
src = p_model_ if n_averaged == 0 else avg_fn(p_swa_, p_model_, n_averaged.to(device))
|
| 272 |
+
p_swa_.copy_(src)
|
| 273 |
+
n_averaged += 1
|
| 274 |
+
|
| 275 |
+
@staticmethod
|
| 276 |
+
def avg_fn(
|
| 277 |
+
averaged_model_parameter: torch.Tensor, model_parameter: torch.Tensor, num_averaged: torch.LongTensor
|
| 278 |
+
) -> torch.FloatTensor:
|
| 279 |
+
"""Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97."""
|
| 280 |
+
return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1)
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/timer.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright The PyTorch Lightning team.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
r"""
|
| 15 |
+
Timer
|
| 16 |
+
^^^^^
|
| 17 |
+
"""
|
| 18 |
+
import logging
|
| 19 |
+
import time
|
| 20 |
+
from datetime import timedelta
|
| 21 |
+
from typing import Any, Dict, Optional, Union
|
| 22 |
+
|
| 23 |
+
import pytorch_lightning as pl
|
| 24 |
+
from pytorch_lightning.callbacks.base import Callback
|
| 25 |
+
from pytorch_lightning.trainer.states import RunningStage
|
| 26 |
+
from pytorch_lightning.utilities import LightningEnum
|
| 27 |
+
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
| 28 |
+
from pytorch_lightning.utilities.rank_zero import rank_zero_info
|
| 29 |
+
|
| 30 |
+
log = logging.getLogger(__name__)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class Interval(LightningEnum):
|
| 34 |
+
step = "step"
|
| 35 |
+
epoch = "epoch"
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class Timer(Callback):
|
| 39 |
+
"""The Timer callback tracks the time spent in the training, validation, and test loops and interrupts the
|
| 40 |
+
Trainer if the given time limit for the training loop is reached.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
duration: A string in the format DD:HH:MM:SS (days, hours, minutes seconds), or a :class:`datetime.timedelta`,
|
| 44 |
+
or a dict containing key-value compatible with :class:`~datetime.timedelta`.
|
| 45 |
+
interval: Determines if the interruption happens on epoch level or mid-epoch.
|
| 46 |
+
Can be either ``"epoch"`` or ``"step"``.
|
| 47 |
+
verbose: Set this to ``False`` to suppress logging messages.
|
| 48 |
+
|
| 49 |
+
Raises:
|
| 50 |
+
MisconfigurationException:
|
| 51 |
+
If ``interval`` is not one of the supported choices.
|
| 52 |
+
|
| 53 |
+
Example::
|
| 54 |
+
from pytorch_lightning import Trainer
|
| 55 |
+
from pytorch_lightning.callbacks import Timer
|
| 56 |
+
|
| 57 |
+
# stop training after 12 hours
|
| 58 |
+
timer = Timer(duration="00:12:00:00")
|
| 59 |
+
|
| 60 |
+
# or provide a datetime.timedelta
|
| 61 |
+
from datetime import timedelta
|
| 62 |
+
timer = Timer(duration=timedelta(weeks=1))
|
| 63 |
+
|
| 64 |
+
# or provide a dictionary
|
| 65 |
+
timer = Timer(duration=dict(weeks=4, days=2))
|
| 66 |
+
|
| 67 |
+
# force training to stop after given time limit
|
| 68 |
+
trainer = Trainer(callbacks=[timer])
|
| 69 |
+
|
| 70 |
+
# query training/validation/test time (in seconds)
|
| 71 |
+
timer.time_elapsed("train")
|
| 72 |
+
timer.start_time("validate")
|
| 73 |
+
timer.end_time("test")
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
def __init__(
|
| 77 |
+
self,
|
| 78 |
+
duration: Optional[Union[str, timedelta, Dict[str, int]]] = None,
|
| 79 |
+
interval: str = Interval.step,
|
| 80 |
+
verbose: bool = True,
|
| 81 |
+
) -> None:
|
| 82 |
+
super().__init__()
|
| 83 |
+
if isinstance(duration, str):
|
| 84 |
+
dhms = duration.strip().split(":")
|
| 85 |
+
dhms = [int(i) for i in dhms]
|
| 86 |
+
duration = timedelta(days=dhms[0], hours=dhms[1], minutes=dhms[2], seconds=dhms[3])
|
| 87 |
+
if isinstance(duration, dict):
|
| 88 |
+
duration = timedelta(**duration)
|
| 89 |
+
if interval not in set(Interval):
|
| 90 |
+
raise MisconfigurationException(
|
| 91 |
+
f"Unsupported parameter value `Timer(interval={interval})`. Possible choices are:"
|
| 92 |
+
f" {', '.join(set(Interval))}"
|
| 93 |
+
)
|
| 94 |
+
self._duration = duration.total_seconds() if duration is not None else None
|
| 95 |
+
self._interval = interval
|
| 96 |
+
self._verbose = verbose
|
| 97 |
+
self._start_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage}
|
| 98 |
+
self._end_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage}
|
| 99 |
+
self._offset = 0
|
| 100 |
+
|
| 101 |
+
def start_time(self, stage: str = RunningStage.TRAINING) -> Optional[float]:
|
| 102 |
+
"""Return the start time of a particular stage (in seconds)"""
|
| 103 |
+
stage = RunningStage(stage)
|
| 104 |
+
return self._start_time[stage]
|
| 105 |
+
|
| 106 |
+
def end_time(self, stage: str = RunningStage.TRAINING) -> Optional[float]:
|
| 107 |
+
"""Return the end time of a particular stage (in seconds)"""
|
| 108 |
+
stage = RunningStage(stage)
|
| 109 |
+
return self._end_time[stage]
|
| 110 |
+
|
| 111 |
+
def time_elapsed(self, stage: str = RunningStage.TRAINING) -> float:
|
| 112 |
+
"""Return the time elapsed for a particular stage (in seconds)"""
|
| 113 |
+
start = self.start_time(stage)
|
| 114 |
+
end = self.end_time(stage)
|
| 115 |
+
offset = self._offset if stage == RunningStage.TRAINING else 0
|
| 116 |
+
if start is None:
|
| 117 |
+
return offset
|
| 118 |
+
if end is None:
|
| 119 |
+
return time.monotonic() - start + offset
|
| 120 |
+
return end - start + offset
|
| 121 |
+
|
| 122 |
+
def time_remaining(self, stage: str = RunningStage.TRAINING) -> Optional[float]:
|
| 123 |
+
"""Return the time remaining for a particular stage (in seconds)"""
|
| 124 |
+
if self._duration is not None:
|
| 125 |
+
return self._duration - self.time_elapsed(stage)
|
| 126 |
+
|
| 127 |
+
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
| 128 |
+
self._start_time[RunningStage.TRAINING] = time.monotonic()
|
| 129 |
+
|
| 130 |
+
def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
| 131 |
+
self._end_time[RunningStage.TRAINING] = time.monotonic()
|
| 132 |
+
|
| 133 |
+
def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
| 134 |
+
self._start_time[RunningStage.VALIDATING] = time.monotonic()
|
| 135 |
+
|
| 136 |
+
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
| 137 |
+
self._end_time[RunningStage.VALIDATING] = time.monotonic()
|
| 138 |
+
|
| 139 |
+
def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
| 140 |
+
self._start_time[RunningStage.TESTING] = time.monotonic()
|
| 141 |
+
|
| 142 |
+
def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
| 143 |
+
self._end_time[RunningStage.TESTING] = time.monotonic()
|
| 144 |
+
|
| 145 |
+
def on_fit_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
|
| 146 |
+
# this checks the time after the state is reloaded, regardless of the interval.
|
| 147 |
+
# this is necessary in case we load a state whose timer is already depleted
|
| 148 |
+
if self._duration is None:
|
| 149 |
+
return
|
| 150 |
+
self._check_time_remaining(trainer)
|
| 151 |
+
|
| 152 |
+
def on_train_batch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
|
| 153 |
+
if self._interval != Interval.step or self._duration is None:
|
| 154 |
+
return
|
| 155 |
+
self._check_time_remaining(trainer)
|
| 156 |
+
|
| 157 |
+
def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
|
| 158 |
+
if self._interval != Interval.epoch or self._duration is None:
|
| 159 |
+
return
|
| 160 |
+
self._check_time_remaining(trainer)
|
| 161 |
+
|
| 162 |
+
def state_dict(self) -> Dict[str, Any]:
|
| 163 |
+
return {"time_elapsed": {stage.value: self.time_elapsed(stage) for stage in list(RunningStage)}}
|
| 164 |
+
|
| 165 |
+
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
|
| 166 |
+
time_elapsed = state_dict.get("time_elapsed", {})
|
| 167 |
+
self._offset = time_elapsed.get(RunningStage.TRAINING.value, 0)
|
| 168 |
+
|
| 169 |
+
def _check_time_remaining(self, trainer: "pl.Trainer") -> None:
|
| 170 |
+
assert self._duration is not None
|
| 171 |
+
should_stop = self.time_elapsed() >= self._duration
|
| 172 |
+
should_stop = trainer.strategy.broadcast(should_stop)
|
| 173 |
+
trainer.should_stop = trainer.should_stop or should_stop
|
| 174 |
+
if should_stop and self._verbose:
|
| 175 |
+
elapsed = timedelta(seconds=int(self.time_elapsed(RunningStage.TRAINING)))
|
| 176 |
+
rank_zero_info(f"Time limit reached. Elapsed time is {elapsed}. Signaling Trainer to stop.")
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/xla_stats_monitor.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright The PyTorch Lightning team.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
XLA Stats Monitor
|
| 16 |
+
=================
|
| 17 |
+
|
| 18 |
+
Monitor and logs XLA stats during training.
|
| 19 |
+
|
| 20 |
+
"""
|
| 21 |
+
import time
|
| 22 |
+
|
| 23 |
+
import pytorch_lightning as pl
|
| 24 |
+
from pytorch_lightning.accelerators import TPUAccelerator
|
| 25 |
+
from pytorch_lightning.callbacks.base import Callback
|
| 26 |
+
from pytorch_lightning.utilities import _TPU_AVAILABLE
|
| 27 |
+
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
| 28 |
+
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info
|
| 29 |
+
|
| 30 |
+
if _TPU_AVAILABLE:
|
| 31 |
+
import torch_xla.core.xla_model as xm
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class XLAStatsMonitor(Callback):
|
| 35 |
+
r"""
|
| 36 |
+
.. deprecated:: v1.5
|
| 37 |
+
The `XLAStatsMonitor` callback was deprecated in v1.5 and will be removed in v1.7.
|
| 38 |
+
Please use the `DeviceStatsMonitor` callback instead.
|
| 39 |
+
|
| 40 |
+
Automatically monitors and logs XLA stats during training stage. ``XLAStatsMonitor`` is a callback and in
|
| 41 |
+
order to use it you need to assign a logger in the ``Trainer``.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
verbose: Set to ``True`` to print average peak and free memory, and epoch time
|
| 45 |
+
every epoch.
|
| 46 |
+
|
| 47 |
+
Raises:
|
| 48 |
+
MisconfigurationException:
|
| 49 |
+
If not running on TPUs, or ``Trainer`` has no logger.
|
| 50 |
+
|
| 51 |
+
Example::
|
| 52 |
+
|
| 53 |
+
>>> from pytorch_lightning import Trainer
|
| 54 |
+
>>> from pytorch_lightning.callbacks import XLAStatsMonitor
|
| 55 |
+
>>> xla_stats = XLAStatsMonitor() # doctest: +SKIP
|
| 56 |
+
>>> trainer = Trainer(callbacks=[xla_stats]) # doctest: +SKIP
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
def __init__(self, verbose: bool = True) -> None:
|
| 60 |
+
super().__init__()
|
| 61 |
+
|
| 62 |
+
rank_zero_deprecation(
|
| 63 |
+
"The `XLAStatsMonitor` callback was deprecated in v1.5 and will be removed in v1.7."
|
| 64 |
+
" Please use the `DeviceStatsMonitor` callback instead."
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
if not _TPU_AVAILABLE:
|
| 68 |
+
raise MisconfigurationException("Cannot use XLAStatsMonitor with TPUs are not available")
|
| 69 |
+
|
| 70 |
+
self._verbose = verbose
|
| 71 |
+
|
| 72 |
+
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
| 73 |
+
if not trainer.loggers:
|
| 74 |
+
raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")
|
| 75 |
+
|
| 76 |
+
if not isinstance(trainer.accelerator, TPUAccelerator):
|
| 77 |
+
raise MisconfigurationException(
|
| 78 |
+
"You are using XLAStatsMonitor but are not running on TPU."
|
| 79 |
+
f" The accelerator is set to {trainer.accelerator.__class__.__name__}."
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
device = trainer.strategy.root_device
|
| 83 |
+
memory_info = xm.get_memory_info(device)
|
| 84 |
+
total_memory = trainer.strategy.reduce(memory_info["kb_total"]) * 0.001
|
| 85 |
+
rank_zero_info(f"Average Total memory: {total_memory:.2f} MB")
|
| 86 |
+
|
| 87 |
+
def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
| 88 |
+
self._start_time = time.time()
|
| 89 |
+
|
| 90 |
+
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
| 91 |
+
if not trainer.loggers:
|
| 92 |
+
raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")
|
| 93 |
+
|
| 94 |
+
device = trainer.strategy.root_device
|
| 95 |
+
memory_info = xm.get_memory_info(device)
|
| 96 |
+
epoch_time = time.time() - self._start_time
|
| 97 |
+
|
| 98 |
+
free_memory = memory_info["kb_free"]
|
| 99 |
+
peak_memory = memory_info["kb_total"] - free_memory
|
| 100 |
+
|
| 101 |
+
free_memory = trainer.strategy.reduce(free_memory) * 0.001
|
| 102 |
+
peak_memory = trainer.strategy.reduce(peak_memory) * 0.001
|
| 103 |
+
epoch_time = trainer.strategy.reduce(epoch_time)
|
| 104 |
+
|
| 105 |
+
for logger in trainer.loggers:
|
| 106 |
+
logger.log_metrics(
|
| 107 |
+
{"avg. free memory (MB)": float(free_memory), "avg. peak memory (MB)": float(peak_memory)},
|
| 108 |
+
step=trainer.current_epoch,
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
if self._verbose:
|
| 112 |
+
rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
|
| 113 |
+
rank_zero_info(f"Average Peak memory: {peak_memory:.2f} MB")
|
| 114 |
+
rank_zero_info(f"Average Free memory: {free_memory:.2f} MB")
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/datamodule.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright The PyTorch Lightning team.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""LightningDataModule for loading DataLoaders with ease."""
|
| 15 |
+
from argparse import ArgumentParser, Namespace
|
| 16 |
+
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
|
| 17 |
+
|
| 18 |
+
from torch.utils.data import DataLoader, Dataset, IterableDataset
|
| 19 |
+
|
| 20 |
+
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
|
| 21 |
+
from pytorch_lightning.core.mixins import HyperparametersMixin
|
| 22 |
+
from pytorch_lightning.utilities import rank_zero_deprecation
|
| 23 |
+
from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
|
| 27 |
+
"""A DataModule standardizes the training, val, test splits, data preparation and transforms. The main
|
| 28 |
+
advantage is consistent data splits, data preparation and transforms across models.
|
| 29 |
+
|
| 30 |
+
Example::
|
| 31 |
+
|
| 32 |
+
class MyDataModule(LightningDataModule):
|
| 33 |
+
def __init__(self):
|
| 34 |
+
super().__init__()
|
| 35 |
+
def prepare_data(self):
|
| 36 |
+
# download, split, etc...
|
| 37 |
+
# only called on 1 GPU/TPU in distributed
|
| 38 |
+
def setup(self, stage):
|
| 39 |
+
# make assignments here (val/train/test split)
|
| 40 |
+
# called on every process in DDP
|
| 41 |
+
def train_dataloader(self):
|
| 42 |
+
train_split = Dataset(...)
|
| 43 |
+
return DataLoader(train_split)
|
| 44 |
+
def val_dataloader(self):
|
| 45 |
+
val_split = Dataset(...)
|
| 46 |
+
return DataLoader(val_split)
|
| 47 |
+
def test_dataloader(self):
|
| 48 |
+
test_split = Dataset(...)
|
| 49 |
+
return DataLoader(test_split)
|
| 50 |
+
def teardown(self):
|
| 51 |
+
# clean up after fit or test
|
| 52 |
+
# called on every process in DDP
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
name: str = ...
|
| 56 |
+
|
| 57 |
+
def __init__(self, train_transforms=None, val_transforms=None, test_transforms=None, dims=None):
|
| 58 |
+
super().__init__()
|
| 59 |
+
if train_transforms is not None:
|
| 60 |
+
rank_zero_deprecation(
|
| 61 |
+
"DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7."
|
| 62 |
+
)
|
| 63 |
+
if val_transforms is not None:
|
| 64 |
+
rank_zero_deprecation(
|
| 65 |
+
"DataModule property `val_transforms` was deprecated in v1.5 and will be removed in v1.7."
|
| 66 |
+
)
|
| 67 |
+
if test_transforms is not None:
|
| 68 |
+
rank_zero_deprecation(
|
| 69 |
+
"DataModule property `test_transforms` was deprecated in v1.5 and will be removed in v1.7."
|
| 70 |
+
)
|
| 71 |
+
if dims is not None:
|
| 72 |
+
rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")
|
| 73 |
+
self._train_transforms = train_transforms
|
| 74 |
+
self._val_transforms = val_transforms
|
| 75 |
+
self._test_transforms = test_transforms
|
| 76 |
+
self._dims = dims if dims is not None else ()
|
| 77 |
+
|
| 78 |
+
# Pointer to the trainer object
|
| 79 |
+
self.trainer = None
|
| 80 |
+
|
| 81 |
+
@property
|
| 82 |
+
def train_transforms(self):
|
| 83 |
+
"""Optional transforms (or collection of transforms) you can apply to train dataset.
|
| 84 |
+
|
| 85 |
+
.. deprecated:: v1.5 Will be removed in v1.7.0.
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
+
rank_zero_deprecation(
|
| 89 |
+
"DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7."
|
| 90 |
+
)
|
| 91 |
+
return self._train_transforms
|
| 92 |
+
|
| 93 |
+
@train_transforms.setter
|
| 94 |
+
def train_transforms(self, t):
|
| 95 |
+
rank_zero_deprecation(
|
| 96 |
+
"DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7."
|
| 97 |
+
)
|
| 98 |
+
self._train_transforms = t
|
| 99 |
+
|
| 100 |
+
@property
|
| 101 |
+
def val_transforms(self):
|
| 102 |
+
"""Optional transforms (or collection of transforms) you can apply to validation dataset.
|
| 103 |
+
|
| 104 |
+
.. deprecated:: v1.5 Will be removed in v1.7.0.
|
| 105 |
+
"""
|
| 106 |
+
|
| 107 |
+
rank_zero_deprecation(
|
| 108 |
+
"DataModule property `val_transforms` was deprecated in v1.5 and will be removed in v1.7."
|
| 109 |
+
)
|
| 110 |
+
return self._val_transforms
|
| 111 |
+
|
| 112 |
+
@val_transforms.setter
|
| 113 |
+
def val_transforms(self, t):
|
| 114 |
+
rank_zero_deprecation(
|
| 115 |
+
"DataModule property `val_transforms` was deprecated in v1.5 and will be removed in v1.7."
|
| 116 |
+
)
|
| 117 |
+
self._val_transforms = t
|
| 118 |
+
|
| 119 |
+
@property
|
| 120 |
+
def test_transforms(self):
|
| 121 |
+
"""Optional transforms (or collection of transforms) you can apply to test dataset.
|
| 122 |
+
|
| 123 |
+
.. deprecated:: v1.5 Will be removed in v1.7.0.
|
| 124 |
+
"""
|
| 125 |
+
|
| 126 |
+
rank_zero_deprecation(
|
| 127 |
+
"DataModule property `test_transforms` was deprecated in v1.5 and will be removed in v1.7."
|
| 128 |
+
)
|
| 129 |
+
return self._test_transforms
|
| 130 |
+
|
| 131 |
+
@test_transforms.setter
|
| 132 |
+
def test_transforms(self, t):
|
| 133 |
+
rank_zero_deprecation(
|
| 134 |
+
"DataModule property `test_transforms` was deprecated in v1.5 and will be removed in v1.7."
|
| 135 |
+
)
|
| 136 |
+
self._test_transforms = t
|
| 137 |
+
|
| 138 |
+
@property
|
| 139 |
+
def dims(self):
|
| 140 |
+
"""A tuple describing the shape of your data. Extra functionality exposed in ``size``.
|
| 141 |
+
|
| 142 |
+
.. deprecated:: v1.5 Will be removed in v1.7.0.
|
| 143 |
+
"""
|
| 144 |
+
rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")
|
| 145 |
+
return self._dims
|
| 146 |
+
|
| 147 |
+
@dims.setter
|
| 148 |
+
def dims(self, d):
|
| 149 |
+
rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")
|
| 150 |
+
self._dims = d
|
| 151 |
+
|
| 152 |
+
def size(self, dim=None) -> Union[Tuple, List[Tuple]]:
|
| 153 |
+
"""Return the dimension of each input either as a tuple or list of tuples. You can index this just as you
|
| 154 |
+
would with a torch tensor.
|
| 155 |
+
|
| 156 |
+
.. deprecated:: v1.5 Will be removed in v1.7.0.
|
| 157 |
+
"""
|
| 158 |
+
rank_zero_deprecation("DataModule property `size` was deprecated in v1.5 and will be removed in v1.7.")
|
| 159 |
+
|
| 160 |
+
if dim is not None:
|
| 161 |
+
return self.dims[dim]
|
| 162 |
+
|
| 163 |
+
return self.dims
|
| 164 |
+
|
| 165 |
+
@classmethod
|
| 166 |
+
def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:
|
| 167 |
+
"""Extends existing argparse by default `LightningDataModule` attributes."""
|
| 168 |
+
return add_argparse_args(cls, parent_parser, **kwargs)
|
| 169 |
+
|
| 170 |
+
@classmethod
|
| 171 |
+
def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
|
| 172 |
+
"""Create an instance from CLI arguments.
|
| 173 |
+
|
| 174 |
+
Args:
|
| 175 |
+
args: The parser or namespace to take arguments from. Only known arguments will be
|
| 176 |
+
parsed and passed to the :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
|
| 177 |
+
**kwargs: Additional keyword arguments that may override ones in the parser or namespace.
|
| 178 |
+
These must be valid DataModule arguments.
|
| 179 |
+
|
| 180 |
+
Example::
|
| 181 |
+
|
| 182 |
+
parser = ArgumentParser(add_help=False)
|
| 183 |
+
parser = LightningDataModule.add_argparse_args(parser)
|
| 184 |
+
module = LightningDataModule.from_argparse_args(args)
|
| 185 |
+
"""
|
| 186 |
+
return from_argparse_args(cls, args, **kwargs)
|
| 187 |
+
|
| 188 |
+
@classmethod
|
| 189 |
+
def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
|
| 190 |
+
r"""Scans the DataModule signature and returns argument names, types and default values.
|
| 191 |
+
|
| 192 |
+
Returns:
|
| 193 |
+
List with tuples of 3 values:
|
| 194 |
+
(argument name, set with argument types, argument default value).
|
| 195 |
+
"""
|
| 196 |
+
return get_init_arguments_and_types(cls)
|
| 197 |
+
|
| 198 |
+
@classmethod
|
| 199 |
+
def from_datasets(
|
| 200 |
+
cls,
|
| 201 |
+
train_dataset: Optional[Union[Dataset, Sequence[Dataset], Mapping[str, Dataset]]] = None,
|
| 202 |
+
val_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
|
| 203 |
+
test_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
|
| 204 |
+
batch_size: int = 1,
|
| 205 |
+
num_workers: int = 0,
|
| 206 |
+
):
|
| 207 |
+
r"""
|
| 208 |
+
Create an instance from torch.utils.data.Dataset.
|
| 209 |
+
|
| 210 |
+
Args:
|
| 211 |
+
train_dataset: (optional) Dataset to be used for train_dataloader()
|
| 212 |
+
val_dataset: (optional) Dataset or list of Dataset to be used for val_dataloader()
|
| 213 |
+
test_dataset: (optional) Dataset or list of Dataset to be used for test_dataloader()
|
| 214 |
+
batch_size: Batch size to use for each dataloader. Default is 1.
|
| 215 |
+
num_workers: Number of subprocesses to use for data loading. 0 means that the
|
| 216 |
+
data will be loaded in the main process. Number of CPUs available.
|
| 217 |
+
|
| 218 |
+
"""
|
| 219 |
+
|
| 220 |
+
def dataloader(ds: Dataset, shuffle: bool = False) -> DataLoader:
|
| 221 |
+
shuffle &= not isinstance(ds, IterableDataset)
|
| 222 |
+
return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True)
|
| 223 |
+
|
| 224 |
+
def train_dataloader():
|
| 225 |
+
if isinstance(train_dataset, Mapping):
|
| 226 |
+
return {key: dataloader(ds, shuffle=True) for key, ds in train_dataset.items()}
|
| 227 |
+
if isinstance(train_dataset, Sequence):
|
| 228 |
+
return [dataloader(ds, shuffle=True) for ds in train_dataset]
|
| 229 |
+
return dataloader(train_dataset, shuffle=True)
|
| 230 |
+
|
| 231 |
+
def val_dataloader():
|
| 232 |
+
if isinstance(val_dataset, Sequence):
|
| 233 |
+
return [dataloader(ds) for ds in val_dataset]
|
| 234 |
+
return dataloader(val_dataset)
|
| 235 |
+
|
| 236 |
+
def test_dataloader():
|
| 237 |
+
if isinstance(test_dataset, Sequence):
|
| 238 |
+
return [dataloader(ds) for ds in test_dataset]
|
| 239 |
+
return dataloader(test_dataset)
|
| 240 |
+
|
| 241 |
+
datamodule = cls()
|
| 242 |
+
if train_dataset is not None:
|
| 243 |
+
datamodule.train_dataloader = train_dataloader
|
| 244 |
+
if val_dataset is not None:
|
| 245 |
+
datamodule.val_dataloader = val_dataloader
|
| 246 |
+
if test_dataset is not None:
|
| 247 |
+
datamodule.test_dataloader = test_dataloader
|
| 248 |
+
return datamodule
|
| 249 |
+
|
| 250 |
+
def state_dict(self) -> Dict[str, Any]:
|
| 251 |
+
"""Called when saving a checkpoint, implement to generate and save datamodule state.
|
| 252 |
+
|
| 253 |
+
Returns:
|
| 254 |
+
A dictionary containing datamodule state.
|
| 255 |
+
"""
|
| 256 |
+
return {}
|
| 257 |
+
|
| 258 |
+
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
|
| 259 |
+
"""Called when loading a checkpoint, implement to reload datamodule state given datamodule state_dict.
|
| 260 |
+
|
| 261 |
+
Args:
|
| 262 |
+
state_dict: the datamodule state returned by ``state_dict``.
|
| 263 |
+
"""
|
| 264 |
+
pass
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/decorators.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright The PyTorch Lightning team.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
|
| 15 |
+
|
| 16 |
+
rank_zero_deprecation(
|
| 17 |
+
"Using `pytorch_lightning.core.decorators.parameter_validation` is deprecated in v1.5, "
|
| 18 |
+
"and will be removed in v1.7. It has been replaced by automatic parameters tying with "
|
| 19 |
+
"`pytorch_lightning.utilities.params_tying.set_shared_parameters`"
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
from functools import wraps # noqa: E402
|
| 23 |
+
from typing import Callable # noqa: E402
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def parameter_validation(fn: Callable) -> Callable:
|
| 27 |
+
"""Validates that the module parameter lengths match after moving to the device. It is useful when tying
|
| 28 |
+
weights on TPU's.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
fn: ``model_to_device`` method
|
| 32 |
+
|
| 33 |
+
Note:
|
| 34 |
+
TPU's require weights to be tied/shared after moving the module to the device.
|
| 35 |
+
Failure to do this results in the initialization of new weights which are not tied.
|
| 36 |
+
To overcome this issue, weights should be tied using the ``on_post_move_to_device`` model hook
|
| 37 |
+
which is called after the module has been moved to the device.
|
| 38 |
+
|
| 39 |
+
See Also:
|
| 40 |
+
- `XLA Documentation <https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks>`_
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
@wraps(fn)
|
| 44 |
+
def inner_fn(self, *args, **kwargs):
|
| 45 |
+
pre_layer_count = len(list(self.model.parameters()))
|
| 46 |
+
module = fn(self, *args, **kwargs)
|
| 47 |
+
self.model.on_post_move_to_device()
|
| 48 |
+
post_layer_count = len(list(self.model.parameters()))
|
| 49 |
+
|
| 50 |
+
if not pre_layer_count == post_layer_count:
|
| 51 |
+
rank_zero_warn(
|
| 52 |
+
"The model layers do not match after moving to the target device."
|
| 53 |
+
" If your model employs weight sharing on TPU,"
|
| 54 |
+
" please tie your weights using the `on_post_move_to_device` model hook.\n"
|
| 55 |
+
f"Layer count: [Before: {pre_layer_count} After: {post_layer_count}]"
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
return module
|
| 59 |
+
|
| 60 |
+
return inner_fn
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/hooks.py
ADDED
|
@@ -0,0 +1,828 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright The PyTorch Lightning team.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""Various hooks to be used in the Lightning code."""
|
| 15 |
+
|
| 16 |
+
from typing import Any, Dict, List, Optional
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from torch.optim.optimizer import Optimizer
|
| 20 |
+
|
| 21 |
+
from pytorch_lightning.utilities import move_data_to_device
|
| 22 |
+
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
| 23 |
+
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class ModelHooks:
|
| 27 |
+
"""Hooks to be used in LightningModule."""
|
| 28 |
+
|
| 29 |
+
def on_fit_start(self) -> None:
|
| 30 |
+
"""Called at the very beginning of fit.
|
| 31 |
+
|
| 32 |
+
If on DDP it is called on every process
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
def on_fit_end(self) -> None:
|
| 36 |
+
"""Called at the very end of fit.
|
| 37 |
+
|
| 38 |
+
If on DDP it is called on every process
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
def on_train_start(self) -> None:
|
| 42 |
+
"""Called at the beginning of training after sanity check."""
|
| 43 |
+
|
| 44 |
+
def on_train_end(self) -> None:
|
| 45 |
+
"""Called at the end of training before logger experiment is closed."""
|
| 46 |
+
|
| 47 |
+
def on_validation_start(self) -> None:
|
| 48 |
+
"""Called at the beginning of validation."""
|
| 49 |
+
|
| 50 |
+
def on_validation_end(self) -> None:
|
| 51 |
+
"""Called at the end of validation."""
|
| 52 |
+
|
| 53 |
+
def on_test_start(self) -> None:
|
| 54 |
+
"""Called at the beginning of testing."""
|
| 55 |
+
|
| 56 |
+
def on_test_end(self) -> None:
|
| 57 |
+
"""Called at the end of testing."""
|
| 58 |
+
|
| 59 |
+
def on_predict_start(self) -> None:
|
| 60 |
+
"""Called at the beginning of predicting."""
|
| 61 |
+
|
| 62 |
+
def on_predict_end(self) -> None:
|
| 63 |
+
"""Called at the end of predicting."""
|
| 64 |
+
|
| 65 |
+
def on_pretrain_routine_start(self) -> None:
|
| 66 |
+
"""Called at the beginning of the pretrain routine (between fit and train start).
|
| 67 |
+
|
| 68 |
+
- fit
|
| 69 |
+
- pretrain_routine start
|
| 70 |
+
- pretrain_routine end
|
| 71 |
+
- training_start
|
| 72 |
+
|
| 73 |
+
.. deprecated:: v1.6
|
| 74 |
+
:meth:`on_pretrain_routine_start` has been deprecated in v1.6 and will be removed in v1.8.
|
| 75 |
+
Use ``on_fit_start`` instead.
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
def on_pretrain_routine_end(self) -> None:
|
| 79 |
+
"""Called at the end of the pretrain routine (between fit and train start).
|
| 80 |
+
|
| 81 |
+
- fit
|
| 82 |
+
- pretrain_routine start
|
| 83 |
+
- pretrain_routine end
|
| 84 |
+
- training_start
|
| 85 |
+
|
| 86 |
+
.. deprecated:: v1.6
|
| 87 |
+
:meth:`on_pretrain_routine_end` has been deprecated in v1.6 and will be removed in v1.8.
|
| 88 |
+
Use ``on_fit_start`` instead.
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
def on_train_batch_start(self, batch: Any, batch_idx: int, unused: int = 0) -> Optional[int]:
|
| 92 |
+
"""Called in the training loop before anything happens for that batch.
|
| 93 |
+
|
| 94 |
+
If you return -1 here, you will skip training for the rest of the current epoch.
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
batch: The batched data as it is returned by the training DataLoader.
|
| 98 |
+
batch_idx: the index of the batch
|
| 99 |
+
unused: Deprecated argument. Will be removed in v1.7.
|
| 100 |
+
"""
|
| 101 |
+
|
| 102 |
+
def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, unused: int = 0) -> None:
|
| 103 |
+
"""Called in the training loop after the batch.
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
outputs: The outputs of training_step_end(training_step(x))
|
| 107 |
+
batch: The batched data as it is returned by the training DataLoader.
|
| 108 |
+
batch_idx: the index of the batch
|
| 109 |
+
unused: Deprecated argument. Will be removed in v1.7.
|
| 110 |
+
"""
|
| 111 |
+
|
| 112 |
+
def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
|
| 113 |
+
"""Called in the validation loop before anything happens for that batch.
|
| 114 |
+
|
| 115 |
+
Args:
|
| 116 |
+
batch: The batched data as it is returned by the validation DataLoader.
|
| 117 |
+
batch_idx: the index of the batch
|
| 118 |
+
dataloader_idx: the index of the dataloader
|
| 119 |
+
"""
|
| 120 |
+
|
| 121 |
+
def on_validation_batch_end(
|
| 122 |
+
self, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int
|
| 123 |
+
) -> None:
|
| 124 |
+
"""Called in the validation loop after the batch.
|
| 125 |
+
|
| 126 |
+
Args:
|
| 127 |
+
outputs: The outputs of validation_step_end(validation_step(x))
|
| 128 |
+
batch: The batched data as it is returned by the validation DataLoader.
|
| 129 |
+
batch_idx: the index of the batch
|
| 130 |
+
dataloader_idx: the index of the dataloader
|
| 131 |
+
"""
|
| 132 |
+
|
| 133 |
+
def on_test_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
|
| 134 |
+
"""Called in the test loop before anything happens for that batch.
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
batch: The batched data as it is returned by the test DataLoader.
|
| 138 |
+
batch_idx: the index of the batch
|
| 139 |
+
dataloader_idx: the index of the dataloader
|
| 140 |
+
"""
|
| 141 |
+
|
| 142 |
+
def on_test_batch_end(
|
| 143 |
+
self, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int
|
| 144 |
+
) -> None:
|
| 145 |
+
"""Called in the test loop after the batch.
|
| 146 |
+
|
| 147 |
+
Args:
|
| 148 |
+
outputs: The outputs of test_step_end(test_step(x))
|
| 149 |
+
batch: The batched data as it is returned by the test DataLoader.
|
| 150 |
+
batch_idx: the index of the batch
|
| 151 |
+
dataloader_idx: the index of the dataloader
|
| 152 |
+
"""
|
| 153 |
+
|
| 154 |
+
def on_predict_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
|
| 155 |
+
"""Called in the predict loop before anything happens for that batch.
|
| 156 |
+
|
| 157 |
+
Args:
|
| 158 |
+
batch: The batched data as it is returned by the test DataLoader.
|
| 159 |
+
batch_idx: the index of the batch
|
| 160 |
+
dataloader_idx: the index of the dataloader
|
| 161 |
+
"""
|
| 162 |
+
|
| 163 |
+
def on_predict_batch_end(self, outputs: Optional[Any], batch: Any, batch_idx: int, dataloader_idx: int) -> None:
|
| 164 |
+
"""Called in the predict loop after the batch.
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
outputs: The outputs of predict_step_end(test_step(x))
|
| 168 |
+
batch: The batched data as it is returned by the test DataLoader.
|
| 169 |
+
batch_idx: the index of the batch
|
| 170 |
+
dataloader_idx: the index of the dataloader
|
| 171 |
+
"""
|
| 172 |
+
|
| 173 |
+
def on_validation_model_eval(self) -> None:
|
| 174 |
+
"""Sets the model to eval during the val loop."""
|
| 175 |
+
self.trainer.model.eval()
|
| 176 |
+
|
| 177 |
+
def on_validation_model_train(self) -> None:
|
| 178 |
+
"""Sets the model to train during the val loop."""
|
| 179 |
+
self.trainer.model.train()
|
| 180 |
+
|
| 181 |
+
def on_test_model_train(self) -> None:
|
| 182 |
+
"""Sets the model to train during the test loop."""
|
| 183 |
+
self.trainer.model.train()
|
| 184 |
+
|
| 185 |
+
def on_test_model_eval(self) -> None:
|
| 186 |
+
"""Sets the model to eval during the test loop."""
|
| 187 |
+
self.trainer.model.eval()
|
| 188 |
+
|
| 189 |
+
def on_predict_model_eval(self) -> None:
|
| 190 |
+
"""Sets the model to eval during the predict loop."""
|
| 191 |
+
self.trainer.model.eval()
|
| 192 |
+
|
| 193 |
+
def on_epoch_start(self) -> None:
|
| 194 |
+
"""Called when either of train/val/test epoch begins.
|
| 195 |
+
|
| 196 |
+
.. deprecated:: v1.6
|
| 197 |
+
:meth:`on_epoch_start` has been deprecated in v1.6 and will be removed in v1.8.
|
| 198 |
+
Use ``on_<train/validation/test>_epoch_start`` instead.
|
| 199 |
+
"""
|
| 200 |
+
|
| 201 |
+
def on_epoch_end(self) -> None:
|
| 202 |
+
"""Called when either of train/val/test epoch ends.
|
| 203 |
+
|
| 204 |
+
.. deprecated:: v1.6
|
| 205 |
+
:meth:`on_epoch_end` has been deprecated in v1.6 and will be removed in v1.8.
|
| 206 |
+
Use ``on_<train/validation/test>_epoch_end`` instead.
|
| 207 |
+
"""
|
| 208 |
+
|
| 209 |
+
def on_train_epoch_start(self) -> None:
|
| 210 |
+
"""Called in the training loop at the very beginning of the epoch."""
|
| 211 |
+
|
| 212 |
+
def on_train_epoch_end(self) -> None:
|
| 213 |
+
"""Called in the training loop at the very end of the epoch.
|
| 214 |
+
|
| 215 |
+
To access all batch outputs at the end of the epoch, either:
|
| 216 |
+
|
| 217 |
+
1. Implement `training_epoch_end` in the LightningModule OR
|
| 218 |
+
2. Cache data across steps on the attribute(s) of the `LightningModule` and access them in this hook
|
| 219 |
+
"""
|
| 220 |
+
|
| 221 |
+
def on_validation_epoch_start(self) -> None:
|
| 222 |
+
"""Called in the validation loop at the very beginning of the epoch."""
|
| 223 |
+
|
| 224 |
+
def on_validation_epoch_end(self) -> None:
|
| 225 |
+
"""Called in the validation loop at the very end of the epoch."""
|
| 226 |
+
|
| 227 |
+
def on_test_epoch_start(self) -> None:
|
| 228 |
+
"""Called in the test loop at the very beginning of the epoch."""
|
| 229 |
+
|
| 230 |
+
def on_test_epoch_end(self) -> None:
|
| 231 |
+
"""Called in the test loop at the very end of the epoch."""
|
| 232 |
+
|
| 233 |
+
def on_predict_epoch_start(self) -> None:
|
| 234 |
+
"""Called at the beginning of predicting."""
|
| 235 |
+
|
| 236 |
+
def on_predict_epoch_end(self, results: List[Any]) -> None:
|
| 237 |
+
"""Called at the end of predicting."""
|
| 238 |
+
|
| 239 |
+
def on_before_zero_grad(self, optimizer: Optimizer) -> None:
|
| 240 |
+
"""Called after ``training_step()`` and before ``optimizer.zero_grad()``.
|
| 241 |
+
|
| 242 |
+
Called in the training loop after taking an optimizer step and before zeroing grads.
|
| 243 |
+
Good place to inspect weight information with weights updated.
|
| 244 |
+
|
| 245 |
+
This is where it is called::
|
| 246 |
+
|
| 247 |
+
for optimizer in optimizers:
|
| 248 |
+
out = training_step(...)
|
| 249 |
+
|
| 250 |
+
model.on_before_zero_grad(optimizer) # < ---- called here
|
| 251 |
+
optimizer.zero_grad()
|
| 252 |
+
|
| 253 |
+
backward()
|
| 254 |
+
|
| 255 |
+
Args:
|
| 256 |
+
optimizer: The optimizer for which grads should be zeroed.
|
| 257 |
+
"""
|
| 258 |
+
|
| 259 |
+
def on_before_backward(self, loss: torch.Tensor) -> None:
|
| 260 |
+
"""Called before ``loss.backward()``.
|
| 261 |
+
|
| 262 |
+
Args:
|
| 263 |
+
loss: Loss divided by number of batches for gradient accumulation and scaled if using native AMP.
|
| 264 |
+
"""
|
| 265 |
+
pass
|
| 266 |
+
|
| 267 |
+
def on_after_backward(self) -> None:
|
| 268 |
+
"""Called after ``loss.backward()`` and before optimizers are stepped.
|
| 269 |
+
|
| 270 |
+
Note:
|
| 271 |
+
If using native AMP, the gradients will not be unscaled at this point.
|
| 272 |
+
Use the ``on_before_optimizer_step`` if you need the unscaled gradients.
|
| 273 |
+
"""
|
| 274 |
+
|
| 275 |
+
def on_before_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
|
| 276 |
+
"""Called before ``optimizer.step()``.
|
| 277 |
+
|
| 278 |
+
If using gradient accumulation, the hook is called once the gradients have been accumulated.
|
| 279 |
+
See: :paramref:`~pytorch_lightning.trainer.Trainer.accumulate_grad_batches`.
|
| 280 |
+
|
| 281 |
+
If using native AMP, the loss will be unscaled before calling this hook.
|
| 282 |
+
See these `docs <https://pytorch.org/docs/stable/notes/amp_examples.html#working-with-unscaled-gradients>`__
|
| 283 |
+
for more information on the scaling of gradients.
|
| 284 |
+
|
| 285 |
+
If clipping gradients, the gradients will not have been clipped yet.
|
| 286 |
+
|
| 287 |
+
Args:
|
| 288 |
+
optimizer: Current optimizer being used.
|
| 289 |
+
optimizer_idx: Index of the current optimizer being used.
|
| 290 |
+
|
| 291 |
+
Example::
|
| 292 |
+
|
| 293 |
+
def on_before_optimizer_step(self, optimizer, optimizer_idx):
|
| 294 |
+
# example to inspect gradient information in tensorboard
|
| 295 |
+
if self.trainer.global_step % 25 == 0: # don't make the tf file huge
|
| 296 |
+
for k, v in self.named_parameters():
|
| 297 |
+
self.logger.experiment.add_histogram(
|
| 298 |
+
tag=k, values=v.grad, global_step=self.trainer.global_step
|
| 299 |
+
)
|
| 300 |
+
"""
|
| 301 |
+
|
| 302 |
+
def on_post_move_to_device(self) -> None:
|
| 303 |
+
"""Called in the ``parameter_validation`` decorator after
|
| 304 |
+
:meth:`~pytorch_lightning.core.LightningModule.to` is called. This is a good place to tie weights between
|
| 305 |
+
modules after moving them to a device. Can be used when training models with weight sharing properties on
|
| 306 |
+
TPU.
|
| 307 |
+
|
| 308 |
+
Addresses the handling of shared weights on TPU:
|
| 309 |
+
https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks
|
| 310 |
+
|
| 311 |
+
Example::
|
| 312 |
+
|
| 313 |
+
def on_post_move_to_device(self):
|
| 314 |
+
self.decoder.weight = self.encoder.weight
|
| 315 |
+
"""
|
| 316 |
+
|
| 317 |
+
def configure_sharded_model(self) -> None:
|
| 318 |
+
"""Hook to create modules in a distributed aware context. This is useful for when using sharded plugins,
|
| 319 |
+
where we'd like to shard the model instantly, which is useful for extremely large models which can save
|
| 320 |
+
memory and initialization time.
|
| 321 |
+
|
| 322 |
+
This hook is called during each of fit/val/test/predict stages in the same process, so ensure that
|
| 323 |
+
implementation of this hook is idempotent.
|
| 324 |
+
"""
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
class DataHooks:
    """Hooks to be used for data related stuff."""

    def __init__(self) -> None:
        """
        Attributes:
            prepare_data_per_node:
                If True, each LOCAL_RANK=0 will call ``prepare_data``.
                Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data.
            allow_zero_length_dataloader_with_multiple_devices:
                If True, a dataloader with zero length within the local rank is allowed.
                Default value is False.
        """
        super().__init__()
        self.prepare_data_per_node: bool = True
        self.allow_zero_length_dataloader_with_multiple_devices: bool = False

    def prepare_data(self) -> None:
        """Use this to download and prepare data. Downloading and saving data with multiple processes (distributed
        settings) will result in corrupted data. Lightning ensures this method is called only within a single
        process, so you can safely add your downloading logic within.

        .. warning:: DO NOT set state to the model (use ``setup`` instead)
            since this is NOT called on every device

        Example::

            def prepare_data(self):
                # good
                download_data()
                tokenize()
                etc()

                # bad
                self.split = data_split
                self.some_state = some_other_state()

        In DDP ``prepare_data`` can be called in two ways (using Trainer(prepare_data_per_node)):

        1. Once per node. This is the default and is only called on LOCAL_RANK=0.
        2. Once in total. Only called on GLOBAL_RANK=0.

        See :ref:`prepare_data_per_node<common/lightning_module:prepare_data_per_node>`.

        Example::

            # DEFAULT
            # called once per node on LOCAL_RANK=0 of that node
            Trainer(prepare_data_per_node=True)

            # call on GLOBAL_RANK=0 (great for shared file systems)
            Trainer(prepare_data_per_node=False)

        This is called before requesting the dataloaders:

        .. code-block:: python

            model.prepare_data()
            initialize_distributed()
            model.setup(stage)
            model.train_dataloader()
            model.val_dataloader()
            model.test_dataloader()
        """

    def setup(self, stage: Optional[str] = None) -> None:
        """Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when
        you need to build models dynamically or adjust something about them. This hook is called on every process
        when using DDP.

        Args:
            stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``

        Example::

            class LitModel(...):
                def __init__(self):
                    self.l1 = None

                def prepare_data(self):
                    download_data()
                    tokenize()

                    # don't do this
                    self.something = something_else()

                def setup(self, stage):
                    data = load_data(...)
                    self.l1 = nn.Linear(28, data.num_classes)
        """

    def teardown(self, stage: Optional[str] = None) -> None:
        """Called at the end of fit (train + validate), validate, test, or predict.

        Args:
            stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``
        """

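    # Illustrative sketch (not part of the original file): `teardown` is the natural place to release
    # per-stage resources created in `setup`; the temporary-directory attribute is hypothetical and
    # `tempfile`/`shutil` are assumed to be imported.
    #
    #     def setup(self, stage=None):
    #         self._tmp_dir = tempfile.mkdtemp()
    #
    #     def teardown(self, stage=None):
    #         shutil.rmtree(self._tmp_dir, ignore_errors=True)
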
    def train_dataloader(self) -> TRAIN_DATALOADERS:
        """Implement one or more PyTorch DataLoaders for training.

        Return:
            A collection of :class:`torch.utils.data.DataLoader` specifying training samples.
            In the case of multiple dataloaders, please see this :ref:`section <multiple-dataloaders>`.

        The dataloader you return will not be reloaded unless you set
        :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to
        a positive integer.

        For data processing use the following pattern:

        - download in :meth:`prepare_data`
        - process and split in :meth:`setup`

        However, the above are only necessary for distributed processing.

        .. warning:: do not assign state in prepare_data

        - :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`
        - :meth:`prepare_data`
        - :meth:`setup`

        Note:
            Lightning adds the correct sampler for distributed and arbitrary hardware.
            There is no need to set it yourself.

        Example::

            # single dataloader
            def train_dataloader(self):
                transform = transforms.Compose([transforms.ToTensor(),
                                                transforms.Normalize((0.5,), (1.0,))])
                dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform,
                                download=True)
                loader = torch.utils.data.DataLoader(
                    dataset=dataset,
                    batch_size=self.batch_size,
                    shuffle=True
                )
                return loader

            # multiple dataloaders, return as list
            def train_dataloader(self):
                mnist = MNIST(...)
                cifar = CIFAR(...)
                mnist_loader = torch.utils.data.DataLoader(
                    dataset=mnist, batch_size=self.batch_size, shuffle=True
                )
                cifar_loader = torch.utils.data.DataLoader(
                    dataset=cifar, batch_size=self.batch_size, shuffle=True
                )
                # each batch will be a list of tensors: [batch_mnist, batch_cifar]
                return [mnist_loader, cifar_loader]

            # multiple dataloaders, return as dict
            def train_dataloader(self):
                mnist = MNIST(...)
                cifar = CIFAR(...)
                mnist_loader = torch.utils.data.DataLoader(
                    dataset=mnist, batch_size=self.batch_size, shuffle=True
                )
                cifar_loader = torch.utils.data.DataLoader(
                    dataset=cifar, batch_size=self.batch_size, shuffle=True
                )
                # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
                return {'mnist': mnist_loader, 'cifar': cifar_loader}
        """
        raise MisconfigurationException("`train_dataloader` must be implemented to be used with the Lightning Trainer")

def test_dataloader(self) -> EVAL_DATALOADERS:
|
| 497 |
+
r"""
|
| 498 |
+
Implement one or multiple PyTorch DataLoaders for testing.
|
| 499 |
+
|
| 500 |
+
For data processing use the following pattern:
|
| 501 |
+
|
| 502 |
+
- download in :meth:`prepare_data`
|
| 503 |
+
- process and split in :meth:`setup`
|
| 504 |
+
|
| 505 |
+
However, the above are only necessary for distributed processing.
|
| 506 |
+
|
| 507 |
+
.. warning:: do not assign state in prepare_data
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
- :meth:`~pytorch_lightning.trainer.trainer.Trainer.test`
|
| 511 |
+
- :meth:`prepare_data`
|
| 512 |
+
- :meth:`setup`
|
| 513 |
+
|
| 514 |
+
Note:
|
| 515 |
+
Lightning adds the correct sampler for distributed and arbitrary hardware.
|
| 516 |
+
There is no need to set it yourself.
|
| 517 |
+
|
| 518 |
+
Return:
|
| 519 |
+
A :class:`torch.utils.data.DataLoader` or a sequence of them specifying testing samples.
|
| 520 |
+
|
| 521 |
+
Example::
|
| 522 |
+
|
| 523 |
+
def test_dataloader(self):
|
| 524 |
+
transform = transforms.Compose([transforms.ToTensor(),
|
| 525 |
+
transforms.Normalize((0.5,), (1.0,))])
|
| 526 |
+
dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform,
|
| 527 |
+
download=True)
|
| 528 |
+
loader = torch.utils.data.DataLoader(
|
| 529 |
+
dataset=dataset,
|
| 530 |
+
batch_size=self.batch_size,
|
| 531 |
+
shuffle=False
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
return loader
|
| 535 |
+
|
| 536 |
+
# can also return multiple dataloaders
|
| 537 |
+
def test_dataloader(self):
|
| 538 |
+
return [loader_a, loader_b, ..., loader_n]
|
| 539 |
+
|
| 540 |
+
Note:
|
| 541 |
+
If you don't need a test dataset and a :meth:`test_step`, you don't need to implement
|
| 542 |
+
this method.
|
| 543 |
+
|
| 544 |
+
Note:
|
| 545 |
+
In the case where you return multiple test dataloaders, the :meth:`test_step`
|
| 546 |
+
will have an argument ``dataloader_idx`` which matches the order here.
|
| 547 |
+
"""
|
| 548 |
+
raise MisconfigurationException("`test_dataloader` must be implemented to be used with the Lightning Trainer")
|
| 549 |
+
|
| 550 |
+
def val_dataloader(self) -> EVAL_DATALOADERS:
|
| 551 |
+
r"""
|
| 552 |
+
Implement one or multiple PyTorch DataLoaders for validation.
|
| 553 |
+
|
| 554 |
+
The dataloader you return will not be reloaded unless you set
|
| 555 |
+
:paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to
|
| 556 |
+
a positive integer.
|
| 557 |
+
|
| 558 |
+
It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.
|
| 559 |
+
|
| 560 |
+
- :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`
|
| 561 |
+
- :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`
|
| 562 |
+
- :meth:`prepare_data`
|
| 563 |
+
- :meth:`setup`
|
| 564 |
+
|
| 565 |
+
Note:
|
| 566 |
+
Lightning adds the correct sampler for distributed and arbitrary hardware
|
| 567 |
+
There is no need to set it yourself.
|
| 568 |
+
|
| 569 |
+
Return:
|
| 570 |
+
A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples.
|
| 571 |
+
|
| 572 |
+
Examples::
|
| 573 |
+
|
| 574 |
+
def val_dataloader(self):
|
| 575 |
+
transform = transforms.Compose([transforms.ToTensor(),
|
| 576 |
+
transforms.Normalize((0.5,), (1.0,))])
|
| 577 |
+
dataset = MNIST(root='/path/to/mnist/', train=False,
|
| 578 |
+
transform=transform, download=True)
|
| 579 |
+
loader = torch.utils.data.DataLoader(
|
| 580 |
+
dataset=dataset,
|
| 581 |
+
batch_size=self.batch_size,
|
| 582 |
+
shuffle=False
|
| 583 |
+
)
|
| 584 |
+
|
| 585 |
+
return loader
|
| 586 |
+
|
| 587 |
+
# can also return multiple dataloaders
|
| 588 |
+
def val_dataloader(self):
|
| 589 |
+
return [loader_a, loader_b, ..., loader_n]
|
| 590 |
+
|
| 591 |
+
Note:
|
| 592 |
+
If you don't need a validation dataset and a :meth:`validation_step`, you don't need to
|
| 593 |
+
implement this method.
|
| 594 |
+
|
| 595 |
+
Note:
|
| 596 |
+
In the case where you return multiple validation dataloaders, the :meth:`validation_step`
|
| 597 |
+
will have an argument ``dataloader_idx`` which matches the order here.
|
| 598 |
+
"""
|
| 599 |
+
raise MisconfigurationException("`val_dataloader` must be implemented to be used with the Lightning Trainer")
|
| 600 |
+
|
| 601 |
+
def predict_dataloader(self) -> EVAL_DATALOADERS:
|
| 602 |
+
r"""
|
| 603 |
+
Implement one or multiple PyTorch DataLoaders for prediction.
|
| 604 |
+
|
| 605 |
+
It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.
|
| 606 |
+
|
| 607 |
+
- :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`
|
| 608 |
+
- :meth:`prepare_data`
|
| 609 |
+
- :meth:`setup`
|
| 610 |
+
|
| 611 |
+
Note:
|
| 612 |
+
Lightning adds the correct sampler for distributed and arbitrary hardware
|
| 613 |
+
There is no need to set it yourself.
|
| 614 |
+
|
| 615 |
+
Return:
|
| 616 |
+
A :class:`torch.utils.data.DataLoader` or a sequence of them specifying prediction samples.
|
| 617 |
+
|
| 618 |
+
Note:
|
| 619 |
+
In the case where you return multiple prediction dataloaders, the :meth:`predict_step`
|
| 620 |
+
will have an argument ``dataloader_idx`` which matches the order here.
|
| 621 |
+
"""
|
| 622 |
+
raise MisconfigurationException(
|
| 623 |
+
"`predict_dataloader` must be implemented to be used with the Lightning Trainer"
|
| 624 |
+
)
|
| 625 |
+
|
| 626 |
+
    def on_train_dataloader(self) -> None:
        """Called before requesting the train dataloader.

        .. deprecated:: v1.5
            :meth:`on_train_dataloader` is deprecated and will be removed in v1.7.0.
            Please use :meth:`train_dataloader()` directly.
        """

    def on_val_dataloader(self) -> None:
        """Called before requesting the val dataloader.

        .. deprecated:: v1.5
            :meth:`on_val_dataloader` is deprecated and will be removed in v1.7.0.
            Please use :meth:`val_dataloader()` directly.
        """

    def on_test_dataloader(self) -> None:
        """Called before requesting the test dataloader.

        .. deprecated:: v1.5
            :meth:`on_test_dataloader` is deprecated and will be removed in v1.7.0.
            Please use :meth:`test_dataloader()` directly.
        """

    def on_predict_dataloader(self) -> None:
        """Called before requesting the predict dataloader.

        .. deprecated:: v1.5
            :meth:`on_predict_dataloader` is deprecated and will be removed in v1.7.0.
            Please use :meth:`predict_dataloader()` directly.
        """

    def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any:
        """Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors wrapped in a custom
        data structure.

        The data types listed below (and any arbitrary nesting of them) are supported out of the box:

        - :class:`torch.Tensor` or anything that implements `.to(...)`
        - :class:`list`
        - :class:`dict`
        - :class:`tuple`
        - :class:`torchtext.data.batch.Batch`

        For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...).

        Note:
            This hook should only transfer the data and not modify it, nor should it move the data to
            any other device than the one passed in as argument (unless you know what you are doing).
            To check the current state of execution of this hook you can use
            ``self.trainer.training/testing/validating/predicting`` so that you can
            add different logic as per your requirement.

        Note:
            This hook only runs on single GPU training and DDP (no data-parallel).
            Data-Parallel support will come in the near future.

        Args:
            batch: A batch of data that needs to be transferred to a new device.
            device: The target device as defined in PyTorch.
            dataloader_idx: The index of the dataloader to which the batch belongs.

        Returns:
            A reference to the data on the new device.

        Example::

            def transfer_batch_to_device(self, batch, device, dataloader_idx):
                if isinstance(batch, CustomBatch):
                    # move all tensors in your custom data structure to the device
                    batch.samples = batch.samples.to(device)
                    batch.targets = batch.targets.to(device)
                elif dataloader_idx == 0:
                    # skip device transfer for the first dataloader or anything you wish
                    pass
                else:
                    batch = super().transfer_batch_to_device(batch, device, dataloader_idx)
                return batch

        Raises:
            MisconfigurationException:
                If using data-parallel, ``Trainer(strategy='dp')``.

        See Also:
            - :meth:`move_data_to_device`
            - :meth:`apply_to_collection`
        """
        return move_data_to_device(batch, device)

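    # Illustrative sketch (not part of the original file): as noted above, anything that implements
    # `.to(...)` is moved automatically, so a custom batch type often removes the need to override
    # this hook at all; `CustomBatch` is hypothetical.
    #
    #     class CustomBatch:
    #         def __init__(self, samples, targets):
    #             self.samples, self.targets = samples, targets
    #
    #         def to(self, device):
    #             self.samples = self.samples.to(device)
    #             self.targets = self.targets.to(device)
    #             return self
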
def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
|
| 716 |
+
"""Override to alter or apply batch augmentations to your batch before it is transferred to the device.
|
| 717 |
+
|
| 718 |
+
Note:
|
| 719 |
+
To check the current state of execution of this hook you can use
|
| 720 |
+
``self.trainer.training/testing/validating/predicting`` so that you can
|
| 721 |
+
add different logic as per your requirement.
|
| 722 |
+
|
| 723 |
+
Note:
|
| 724 |
+
This hook only runs on single GPU training and DDP (no data-parallel).
|
| 725 |
+
Data-Parallel support will come in near future.
|
| 726 |
+
|
| 727 |
+
Args:
|
| 728 |
+
batch: A batch of data that needs to be altered or augmented.
|
| 729 |
+
dataloader_idx: The index of the dataloader to which the batch belongs.
|
| 730 |
+
|
| 731 |
+
Returns:
|
| 732 |
+
A batch of data
|
| 733 |
+
|
| 734 |
+
Example::
|
| 735 |
+
|
| 736 |
+
def on_before_batch_transfer(self, batch, dataloader_idx):
|
| 737 |
+
batch['x'] = transforms(batch['x'])
|
| 738 |
+
return batch
|
| 739 |
+
|
| 740 |
+
Raises:
|
| 741 |
+
MisconfigurationException:
|
| 742 |
+
If using data-parallel, ``Trainer(strategy='dp')``.
|
| 743 |
+
|
| 744 |
+
See Also:
|
| 745 |
+
- :meth:`on_after_batch_transfer`
|
| 746 |
+
- :meth:`transfer_batch_to_device`
|
| 747 |
+
"""
|
| 748 |
+
return batch
|
| 749 |
+
|
| 750 |
+
def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
|
| 751 |
+
"""Override to alter or apply batch augmentations to your batch after it is transferred to the device.
|
| 752 |
+
|
| 753 |
+
Note:
|
| 754 |
+
To check the current state of execution of this hook you can use
|
| 755 |
+
``self.trainer.training/testing/validating/predicting`` so that you can
|
| 756 |
+
add different logic as per your requirement.
|
| 757 |
+
|
| 758 |
+
Note:
|
| 759 |
+
This hook only runs on single GPU training and DDP (no data-parallel).
|
| 760 |
+
Data-Parallel support will come in near future.
|
| 761 |
+
|
| 762 |
+
Args:
|
| 763 |
+
batch: A batch of data that needs to be altered or augmented.
|
| 764 |
+
dataloader_idx: The index of the dataloader to which the batch belongs.
|
| 765 |
+
|
| 766 |
+
Returns:
|
| 767 |
+
A batch of data
|
| 768 |
+
|
| 769 |
+
Example::
|
| 770 |
+
|
| 771 |
+
def on_after_batch_transfer(self, batch, dataloader_idx):
|
| 772 |
+
batch['x'] = gpu_transforms(batch['x'])
|
| 773 |
+
return batch
|
| 774 |
+
|
| 775 |
+
Raises:
|
| 776 |
+
MisconfigurationException:
|
| 777 |
+
If using data-parallel, ``Trainer(strategy='dp')``.
|
| 778 |
+
|
| 779 |
+
See Also:
|
| 780 |
+
- :meth:`on_before_batch_transfer`
|
| 781 |
+
- :meth:`transfer_batch_to_device`
|
| 782 |
+
"""
|
| 783 |
+
return batch
|
| 784 |
+
|
| 785 |
+
|
| 786 |
+
class CheckpointHooks:
    """Hooks to be used with Checkpointing."""

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        r"""
        Called by Lightning to restore your model.
        If you saved something with :meth:`on_save_checkpoint` this is your chance to restore this.

        Args:
            checkpoint: Loaded checkpoint

        Example::

            def on_load_checkpoint(self, checkpoint):
                # 99% of the time you don't need to implement this method
                self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']

        Note:
            Lightning auto-restores global step, epoch, and train state including amp scaling.
            There is no need for you to restore anything regarding training.
        """

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        r"""
        Called by Lightning when saving a checkpoint to give you a chance to store anything
        else you might want to save.

        Args:
            checkpoint: The full checkpoint dictionary before it gets dumped to a file.
                Implementations of this hook can insert additional data into this dictionary.

        Example::

            def on_save_checkpoint(self, checkpoint):
                # 99% of use cases you don't need to implement this method
                checkpoint['something_cool_i_want_to_save'] = my_cool_picklable_object

        Note:
            Lightning saves all aspects of training (epoch, global step, etc...)
            including amp scaling.
            There is no need for you to store anything about training.
        """
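# Illustrative sketch (not part of the original file): a minimal module wiring several of the hooks
# defined above together; the layer sizes, dataset and metric choices are hypothetical.
#
#     import torch
#     import pytorch_lightning as pl
#
#     class BoringModel(pl.LightningModule):
#         def __init__(self):
#             super().__init__()
#             self.layer = torch.nn.Linear(32, 2)
#
#         def prepare_data(self):
#             pass  # download once per node (or once in total, see prepare_data_per_node)
#
#         def setup(self, stage=None):
#             self.train_set = torch.utils.data.TensorDataset(
#                 torch.randn(64, 32), torch.randint(0, 2, (64,))
#             )
#
#         def train_dataloader(self):
#             return torch.utils.data.DataLoader(self.train_set, batch_size=8, shuffle=True)
#
#         def training_step(self, batch, batch_idx):
#             x, y = batch
#             return torch.nn.functional.cross_entropy(self.layer(x), y)
#
#         def configure_optimizers(self):
#             return torch.optim.SGD(self.parameters(), lr=0.1)
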
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py
ADDED
|
@@ -0,0 +1,409 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from dataclasses import fields
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
from weakref import proxy

import torch
from torch import optim
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.types import _Stateful, LRSchedulerConfig, LRSchedulerTypeTuple, ReduceLROnPlateau

def do_nothing_closure() -> None:
    return


class LightningOptimizer:
    """This class is used to wrap the user optimizers and properly handle the backward and optimizer_step logic
    across accelerators, AMP and accumulate_grad_batches."""

    def __init__(self, optimizer: Optimizer):
        # copy most of the `Optimizer` methods into this instance. `__del__` is skipped in case the optimizer has
        # implemented custom logic which we would not want to call on destruction of the `LightningOptimizer`
        self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("step", "__del__")}

        # For Horovod
        if hasattr(optimizer, "skip_synchronize"):
            self.__class__ = type(
                "Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__.__bases__[0]), {}
            )
            self.skip_synchronize = optimizer.skip_synchronize
            self.synchronize = optimizer.synchronize
        else:
            self.__class__ = type("Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})

        self._optimizer = optimizer
        self._strategy: Optional[pl.strategies.Strategy] = None
        self._optimizer_idx = 0
        # to inject logic around the optimizer step, particularly useful with manual optimization
        self._on_before_step = do_nothing_closure
        self._on_after_step = do_nothing_closure

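# Illustrative sketch (not part of the original file): the dynamic `type(...)` call above makes the
# wrapper an instance of both `LightningOptimizer` and the wrapped optimizer class, so `isinstance`
# checks written against the raw optimizer keep working.
#
#     import torch
#     opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
#     lightning_opt = LightningOptimizer(opt)
#     assert isinstance(lightning_opt, LightningOptimizer)
#     assert isinstance(lightning_opt, torch.optim.SGD)
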
@property
|
| 61 |
+
def optimizer(self) -> Optimizer:
|
| 62 |
+
return self._optimizer
|
| 63 |
+
|
| 64 |
+
@classmethod
|
| 65 |
+
def _to_lightning_optimizer(
|
| 66 |
+
cls, optimizer: Union[Optimizer, "LightningOptimizer"], strategy: "pl.strategies.Strategy", opt_idx: int
|
| 67 |
+
) -> "LightningOptimizer":
|
| 68 |
+
if isinstance(optimizer, LightningOptimizer):
|
| 69 |
+
# the user could return a `LightningOptimizer` from `configure_optimizers`, see test:
|
| 70 |
+
# tests/core/test_lightning_optimizer.py::test_lightning_optimizer[False]
|
| 71 |
+
lightning_optimizer = optimizer
|
| 72 |
+
else:
|
| 73 |
+
lightning_optimizer = cls(optimizer)
|
| 74 |
+
lightning_optimizer._strategy = proxy(strategy)
|
| 75 |
+
lightning_optimizer._optimizer_idx = opt_idx
|
| 76 |
+
return lightning_optimizer
|
| 77 |
+
|
| 78 |
+
@contextmanager
|
| 79 |
+
def toggle_model(self, sync_grad: bool = True) -> Generator[None, None, None]:
|
| 80 |
+
"""This function is just a helper for advanced users.
|
| 81 |
+
|
| 82 |
+
Considering the current optimizer as A and all other optimizers as B.
|
| 83 |
+
Toggling means all parameters from B exclusive to A will have ``requires_grad`` set to False.
|
| 84 |
+
|
| 85 |
+
When performing gradient accumulation, there is no need to perform grad synchronization
|
| 86 |
+
during the accumulation phase.
|
| 87 |
+
Setting `sync_grad` to False will block this synchronization and improve performance.
|
| 88 |
+
"""
|
| 89 |
+
# local import here to avoid circular import
|
| 90 |
+
from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior
|
| 91 |
+
|
| 92 |
+
assert self._strategy is not None
|
| 93 |
+
lightning_module = self._strategy.lightning_module
|
| 94 |
+
assert lightning_module is not None
|
| 95 |
+
with _block_parallel_sync_behavior(self._strategy, block=(not sync_grad)):
|
| 96 |
+
lightning_module.toggle_optimizer(self, self._optimizer_idx)
|
| 97 |
+
yield
|
| 98 |
+
lightning_module.untoggle_optimizer(self._optimizer_idx)
|
| 99 |
+
|
| 100 |
+
def step(self, closure: Optional[Callable[[], Any]] = None, **kwargs: Any) -> Any:
|
| 101 |
+
"""Performs a single optimization step (parameter update).
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
closure: An optional optimizer closure.
|
| 105 |
+
kwargs: Any additional arguments to the ``optimizer.step()`` call.
|
| 106 |
+
|
| 107 |
+
Returns:
|
| 108 |
+
The output from the step call, which is generally the output of the closure execution.
|
| 109 |
+
|
| 110 |
+
Example::
|
| 111 |
+
|
| 112 |
+
# Scenario for a GAN using manual optimization
|
| 113 |
+
def training_step(...):
|
| 114 |
+
opt_gen, opt_dis = self.optimizers()
|
| 115 |
+
|
| 116 |
+
...
|
| 117 |
+
|
| 118 |
+
# compute generator loss
|
| 119 |
+
loss_gen = self.compute_generator_loss(...)
|
| 120 |
+
# zero_grad needs to be called before backward
|
| 121 |
+
opt_gen.zero_grad()
|
| 122 |
+
self.manual_backward(loss_gen)
|
| 123 |
+
opt_gen.step()
|
| 124 |
+
|
| 125 |
+
# compute discriminator loss
|
| 126 |
+
loss_dis = self.compute_discriminator_loss(...)
|
| 127 |
+
|
| 128 |
+
# zero_grad needs to be called before backward
|
| 129 |
+
opt_dis.zero_grad()
|
| 130 |
+
self.manual_backward(loss_dis)
|
| 131 |
+
opt_dis.step()
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
# A more advanced example
|
| 135 |
+
def training_step(self, batch, batch_idx, ...):
|
| 136 |
+
opt_gen, opt_dis = self.optimizers()
|
| 137 |
+
|
| 138 |
+
...
|
| 139 |
+
accumulated_grad_batches = batch_idx % 2 == 0
|
| 140 |
+
|
| 141 |
+
# compute generator loss
|
| 142 |
+
def closure_gen():
|
| 143 |
+
loss_gen = self.compute_generator_loss(...)
|
| 144 |
+
self.manual_backward(loss_gen)
|
| 145 |
+
if accumulated_grad_batches:
|
| 146 |
+
opt_gen.zero_grad()
|
| 147 |
+
|
| 148 |
+
with opt_gen.toggle_model(sync_grad=accumulated_grad_batches):
|
| 149 |
+
opt_gen.step(closure=closure_gen)
|
| 150 |
+
|
| 151 |
+
def closure_dis():
|
| 152 |
+
loss_dis = self.compute_discriminator_loss(...)
|
| 153 |
+
self.manual_backward(loss_dis)
|
| 154 |
+
if accumulated_grad_batches:
|
| 155 |
+
opt_dis.zero_grad()
|
| 156 |
+
|
| 157 |
+
with opt_dis.toggle_model(sync_grad=accumulated_grad_batches):
|
| 158 |
+
opt_dis.step(closure=closure_dis)
|
| 159 |
+
"""
|
| 160 |
+
self._on_before_step()
|
| 161 |
+
|
| 162 |
+
if closure is None:
|
| 163 |
+
closure = do_nothing_closure
|
| 164 |
+
elif not callable(closure):
|
| 165 |
+
raise MisconfigurationException("When `optimizer.step(closure)` is called, the closure should be callable")
|
| 166 |
+
|
| 167 |
+
assert self._strategy is not None
|
| 168 |
+
step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
|
| 169 |
+
|
| 170 |
+
self._on_after_step()
|
| 171 |
+
|
| 172 |
+
return step_output
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def _init_optimizers_and_lr_schedulers(
    model: "pl.LightningModule",
) -> Tuple[List[Optimizer], List[LRSchedulerConfig], List[int]]:
    """Calls `LightningModule.configure_optimizers` and parses and validates the output."""
    assert model.trainer is not None
    optim_conf = model.trainer._call_lightning_module_hook("configure_optimizers", pl_module=model)

    if optim_conf is None:
        rank_zero_warn(
            "`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer",
        )
        optim_conf = _MockOptimizer()

    optimizers, lr_schedulers, optimizer_frequencies, monitor = _configure_optimizers(optim_conf)
    lr_scheduler_configs = (
        _configure_schedulers_automatic_opt(lr_schedulers, monitor)
        if model.automatic_optimization
        else _configure_schedulers_manual_opt(lr_schedulers)
    )
    _set_scheduler_opt_idx(optimizers, lr_scheduler_configs)
    _validate_scheduler_api(lr_scheduler_configs, model)
    return optimizers, lr_scheduler_configs, optimizer_frequencies

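# Illustrative sketch (not part of the original file): the `configure_optimizers` return formats that
# the parsing below understands; the optimizers, schedulers and monitored metric name are hypothetical.
#
#     def configure_optimizers(self):
#         opt = torch.optim.Adam(self.parameters(), lr=1e-3)
#         sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt)
#         # a single optimizer ...
#         return opt
#         # ... or optimizers plus schedulers as two lists ...
#         return [opt], [torch.optim.lr_scheduler.StepLR(opt, step_size=10)]
#         # ... or a dict, which is required when a ReduceLROnPlateau needs a monitor
#         return {"optimizer": opt, "lr_scheduler": {"scheduler": sched, "monitor": "val_loss"}}
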
def _configure_optimizers(
|
| 200 |
+
optim_conf: Union[Dict[str, Any], List, Optimizer, Tuple]
|
| 201 |
+
) -> Tuple[List, List, List, Optional[str]]:
|
| 202 |
+
optimizers, lr_schedulers, optimizer_frequencies = [], [], []
|
| 203 |
+
monitor = None
|
| 204 |
+
|
| 205 |
+
# single output, single optimizer
|
| 206 |
+
if isinstance(optim_conf, Optimizer):
|
| 207 |
+
optimizers = [optim_conf]
|
| 208 |
+
# two lists, optimizer + lr schedulers
|
| 209 |
+
elif (
|
| 210 |
+
isinstance(optim_conf, (list, tuple))
|
| 211 |
+
and len(optim_conf) == 2
|
| 212 |
+
and isinstance(optim_conf[0], list)
|
| 213 |
+
and all(isinstance(opt, Optimizer) for opt in optim_conf[0])
|
| 214 |
+
):
|
| 215 |
+
opt, sch = optim_conf
|
| 216 |
+
optimizers = opt
|
| 217 |
+
lr_schedulers = sch if isinstance(sch, list) else [sch]
|
| 218 |
+
# single dictionary
|
| 219 |
+
elif isinstance(optim_conf, dict):
|
| 220 |
+
_validate_optim_conf(optim_conf)
|
| 221 |
+
optimizers = [optim_conf["optimizer"]]
|
| 222 |
+
monitor = optim_conf.get("monitor", None)
|
| 223 |
+
lr_schedulers = [optim_conf["lr_scheduler"]] if "lr_scheduler" in optim_conf else []
|
| 224 |
+
# multiple dictionaries
|
| 225 |
+
elif isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf):
|
| 226 |
+
for opt_dict in optim_conf:
|
| 227 |
+
_validate_optim_conf(opt_dict)
|
| 228 |
+
optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
|
| 229 |
+
scheduler_dict = (
|
| 230 |
+
lambda scheduler, opt_idx: dict(scheduler, opt_idx=opt_idx)
|
| 231 |
+
if isinstance(scheduler, dict)
|
| 232 |
+
else {"scheduler": scheduler, "opt_idx": opt_idx}
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
lr_schedulers = [
|
| 236 |
+
scheduler_dict(opt_dict["lr_scheduler"], opt_idx)
|
| 237 |
+
for opt_idx, opt_dict in enumerate(optim_conf)
|
| 238 |
+
if "lr_scheduler" in opt_dict
|
| 239 |
+
]
|
| 240 |
+
optimizer_frequencies = [
|
| 241 |
+
opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency", None) is not None
|
| 242 |
+
]
|
| 243 |
+
# assert that if frequencies are present, they are given for all optimizers
|
| 244 |
+
if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
|
| 245 |
+
raise ValueError("A frequency must be given to each optimizer.")
|
| 246 |
+
# single list or tuple, multiple optimizer
|
| 247 |
+
elif isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizer) for opt in optim_conf):
|
| 248 |
+
optimizers = list(optim_conf)
|
| 249 |
+
# unknown configuration
|
| 250 |
+
else:
|
| 251 |
+
raise MisconfigurationException(
|
| 252 |
+
"Unknown configuration for model optimizers."
|
| 253 |
+
" Output from `model.configure_optimizers()` should be one of:\n"
|
| 254 |
+
" * `Optimizer`\n"
|
| 255 |
+
" * [`Optimizer`]\n"
|
| 256 |
+
" * ([`Optimizer`], [`_LRScheduler`])\n"
|
| 257 |
+
' * {"optimizer": `Optimizer`, (optional) "lr_scheduler": `_LRScheduler`}\n'
|
| 258 |
+
' * A list of the previously described dict format, with an optional "frequency" key (int)'
|
| 259 |
+
)
|
| 260 |
+
return optimizers, lr_schedulers, optimizer_frequencies, monitor
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> List[LRSchedulerConfig]:
|
| 264 |
+
"""Convert each scheduler into `LRSchedulerConfig` with relevant information, when using automatic
|
| 265 |
+
optimization."""
|
| 266 |
+
lr_scheduler_configs = []
|
| 267 |
+
for scheduler in schedulers:
|
| 268 |
+
if isinstance(scheduler, dict):
|
| 269 |
+
# check provided keys
|
| 270 |
+
supported_keys = {field.name for field in fields(LRSchedulerConfig)}
|
| 271 |
+
extra_keys = scheduler.keys() - supported_keys
|
| 272 |
+
if extra_keys:
|
| 273 |
+
rank_zero_warn(
|
| 274 |
+
f"Found unsupported keys in the lr scheduler dict: {extra_keys}."
|
| 275 |
+
" HINT: remove them from the output of `configure_optimizers`.",
|
| 276 |
+
category=RuntimeWarning,
|
| 277 |
+
)
|
| 278 |
+
scheduler = {k: v for k, v in scheduler.items() if k in supported_keys}
|
| 279 |
+
if "scheduler" not in scheduler:
|
| 280 |
+
raise MisconfigurationException(
|
| 281 |
+
'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
|
| 282 |
+
)
|
| 283 |
+
if "interval" in scheduler and scheduler["interval"] not in ("step", "epoch"):
|
| 284 |
+
raise MisconfigurationException(
|
| 285 |
+
'The "interval" key in lr scheduler dict must be "step" or "epoch"'
|
| 286 |
+
f' but is "{scheduler["interval"]}"'
|
| 287 |
+
)
|
| 288 |
+
scheduler["reduce_on_plateau"] = isinstance(scheduler["scheduler"], optim.lr_scheduler.ReduceLROnPlateau)
|
| 289 |
+
if scheduler["reduce_on_plateau"] and scheduler.get("monitor", None) is None:
|
| 290 |
+
raise MisconfigurationException(
|
| 291 |
+
"The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used."
|
| 292 |
+
' For example: {"optimizer": optimizer, "lr_scheduler":'
|
| 293 |
+
' {"scheduler": scheduler, "monitor": "your_loss"}}'
|
| 294 |
+
)
|
| 295 |
+
is_one_cycle = isinstance(scheduler["scheduler"], optim.lr_scheduler.OneCycleLR)
|
| 296 |
+
if is_one_cycle and scheduler.get("interval", "epoch") == "epoch":
|
| 297 |
+
rank_zero_warn(
|
| 298 |
+
"A `OneCycleLR` scheduler is using 'interval': 'epoch'."
|
| 299 |
+
" Are you sure you didn't mean 'interval': 'step'?",
|
| 300 |
+
category=RuntimeWarning,
|
| 301 |
+
)
|
| 302 |
+
config = LRSchedulerConfig(**scheduler)
|
| 303 |
+
elif isinstance(scheduler, ReduceLROnPlateau):
|
| 304 |
+
if monitor is None:
|
| 305 |
+
raise MisconfigurationException(
|
| 306 |
+
"`configure_optimizers` must include a monitor when a `ReduceLROnPlateau`"
|
| 307 |
+
" scheduler is used. For example:"
|
| 308 |
+
' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
|
| 309 |
+
)
|
| 310 |
+
config = LRSchedulerConfig(scheduler, reduce_on_plateau=True, monitor=monitor)
|
| 311 |
+
else:
|
| 312 |
+
config = LRSchedulerConfig(scheduler)
|
| 313 |
+
lr_scheduler_configs.append(config)
|
| 314 |
+
return lr_scheduler_configs
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def _configure_schedulers_manual_opt(schedulers: list) -> List[LRSchedulerConfig]:
|
| 318 |
+
"""Convert each scheduler into `LRSchedulerConfig` structure with relevant information, when using manual
|
| 319 |
+
optimization."""
|
| 320 |
+
lr_scheduler_configs = []
|
| 321 |
+
for scheduler in schedulers:
|
| 322 |
+
if isinstance(scheduler, dict):
|
| 323 |
+
invalid_keys = {"interval", "frequency", "reduce_on_plateau", "monitor", "strict"}
|
| 324 |
+
keys_to_warn = [k for k in scheduler.keys() if k in invalid_keys]
|
| 325 |
+
|
| 326 |
+
if keys_to_warn:
|
| 327 |
+
rank_zero_warn(
|
| 328 |
+
f"The lr scheduler dict contains the key(s) {keys_to_warn}, but the keys will be ignored."
|
| 329 |
+
" You need to call `lr_scheduler.step()` manually in manual optimization.",
|
| 330 |
+
category=RuntimeWarning,
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
config = LRSchedulerConfig(**{key: scheduler[key] for key in scheduler if key not in invalid_keys})
|
| 334 |
+
else:
|
| 335 |
+
config = LRSchedulerConfig(scheduler)
|
| 336 |
+
lr_scheduler_configs.append(config)
|
| 337 |
+
return lr_scheduler_configs
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def _validate_scheduler_api(lr_scheduler_configs: List[LRSchedulerConfig], model: "pl.LightningModule") -> None:
|
| 341 |
+
for config in lr_scheduler_configs:
|
| 342 |
+
scheduler = config.scheduler
|
| 343 |
+
if not isinstance(scheduler, _Stateful):
|
| 344 |
+
raise TypeError(
|
| 345 |
+
f"The provided lr scheduler `{scheduler.__class__.__name__}` is invalid."
|
| 346 |
+
" It should have `state_dict` and `load_state_dict` methods defined."
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
if not isinstance(scheduler, LRSchedulerTypeTuple) and not is_overridden("lr_scheduler_step", model):
|
| 350 |
+
raise MisconfigurationException(
|
| 351 |
+
f"The provided lr scheduler `{scheduler.__class__.__name__}` doesn't follow PyTorch's LRScheduler"
|
| 352 |
+
" API. You should override the `LightningModule.lr_scheduler_step` hook with your own logic if"
|
| 353 |
+
" you are using a custom LR scheduler."
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def _set_scheduler_opt_idx(optimizers: List[Optimizer], lr_scheduler_configs: List[LRSchedulerConfig]) -> None:
|
| 358 |
+
for config in lr_scheduler_configs:
|
| 359 |
+
|
| 360 |
+
for opt_idx, opt in enumerate(optimizers):
|
| 361 |
+
if config.scheduler.optimizer is opt:
|
| 362 |
+
if config.opt_idx is not None and config.opt_idx != opt_idx:
|
| 363 |
+
raise MisconfigurationException(
|
| 364 |
+
"`opt_idx` set inside scheduler config does not match with the index"
|
| 365 |
+
" of the respective optimizer returned from `configure_optimizers`."
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
config.opt_idx = opt_idx
|
| 369 |
+
break
|
| 370 |
+
else:
|
| 371 |
+
raise MisconfigurationException(
|
| 372 |
+
"Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def _validate_optim_conf(optim_conf: Dict[str, Any]) -> None:
|
| 377 |
+
valid_keys = {"optimizer", "lr_scheduler", "frequency", "monitor"}
|
| 378 |
+
extra_keys = optim_conf.keys() - valid_keys
|
| 379 |
+
if extra_keys:
|
| 380 |
+
rank_zero_warn(
|
| 381 |
+
f"Found unsupported keys in the optimizer configuration: {set(extra_keys)}", category=RuntimeWarning
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
class _MockOptimizer(Optimizer):
    """The `_MockOptimizer` will be used in place of an optimizer in the event that `None` is returned from
    `configure_optimizers`."""

    def __init__(self) -> None:
        super().__init__([torch.zeros(1)], {})

    def add_param_group(self, param_group: Dict[Any, Any]) -> None:
        pass  # Do Nothing

    def load_state_dict(self, state_dict: Dict[Any, Any]) -> None:
        pass  # Do Nothing

    def state_dict(self) -> Dict[str, Any]:
        return {}  # Return Empty

    def step(self, closure: Callable = None) -> None:
        if closure is not None:
            closure()

    def zero_grad(self, set_to_none: Optional[bool] = False) -> None:
        pass  # Do Nothing

    def __repr__(self) -> str:
        return "No Optimizer"
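# Illustrative sketch (not part of the original file): when `configure_optimizers` returns `None`,
# the trainer falls back to the `_MockOptimizer` above, whose `step` simply evaluates the closure so
# the training loop still runs forward/backward without updating parameters.
#
#     mock = _MockOptimizer()
#     result = []
#     mock.step(closure=lambda: result.append("closure ran"))
#     assert result == ["closure ran"]
#     assert mock.state_dict() == {}
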
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/saving.py
ADDED
|
@@ -0,0 +1,419 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ast
import csv
import inspect
import logging
import os
from argparse import Namespace
from copy import deepcopy
from enum import Enum
from typing import Any, Callable, Dict, IO, MutableMapping, Optional, Union
from warnings import warn

import torch
import yaml

from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.migration import pl_legacy_patch
from pytorch_lightning.utilities.parsing import parse_class_init_keys
from pytorch_lightning.utilities.rank_zero import rank_zero_warn

log = logging.getLogger(__name__)
PRIMITIVE_TYPES = (bool, int, float, str)
ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)

if _OMEGACONF_AVAILABLE:
    from omegaconf import OmegaConf
    from omegaconf.dictconfig import DictConfig
    from omegaconf.errors import UnsupportedValueType, ValidationError

# the older keys shall be on the top
CHECKPOINT_PAST_HPARAMS_KEYS = ("hparams", "module_arguments")  # used in 0.7.6

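# Illustrative sketch (not part of the original file): the helpers defined further down in this file
# can round-trip hyperparameters through a tags CSV; the path and hparams values are hypothetical.
#
#     from argparse import Namespace
#     hparams = Namespace(batch_size=32, learning_rate=1e-3)
#     save_hparams_to_tags_csv("/tmp/hparams.csv", hparams)
#     restored = load_hparams_from_tags_csv("/tmp/hparams.csv")
#     assert restored == vars(hparams)
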
class ModelIO:
|
| 51 |
+
CHECKPOINT_HYPER_PARAMS_KEY = "hyper_parameters"
|
| 52 |
+
CHECKPOINT_HYPER_PARAMS_NAME = "hparams_name"
|
| 53 |
+
CHECKPOINT_HYPER_PARAMS_TYPE = "hparams_type"
|
| 54 |
+
|
| 55 |
+
@classmethod
|
| 56 |
+
def load_from_checkpoint(
|
| 57 |
+
cls,
|
| 58 |
+
checkpoint_path: Union[str, IO],
|
| 59 |
+
map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
|
| 60 |
+
hparams_file: Optional[str] = None,
|
| 61 |
+
strict: bool = True,
|
| 62 |
+
**kwargs,
|
| 63 |
+
):
|
| 64 |
+
r"""
|
| 65 |
+
Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
|
| 66 |
+
it stores the arguments passed to ``__init__`` in the checkpoint under ``"hyper_parameters"``.
|
| 67 |
+
|
| 68 |
+
Any arguments specified through \*\*kwargs will override args stored in ``"hyper_parameters"``.
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
checkpoint_path: Path to checkpoint. This can also be a URL, or file-like object
|
| 72 |
+
map_location:
|
| 73 |
+
If your checkpoint saved a GPU model and you now load on CPUs
|
| 74 |
+
or a different number of GPUs, use this to map to the new setup.
|
| 75 |
+
The behaviour is the same as in :func:`torch.load`.
|
| 76 |
+
hparams_file: Optional path to a .yaml file with hierarchical structure
|
| 77 |
+
as in this example::
|
| 78 |
+
|
| 79 |
+
drop_prob: 0.2
|
| 80 |
+
dataloader:
|
| 81 |
+
batch_size: 32
|
| 82 |
+
|
| 83 |
+
You most likely won't need this since Lightning will always save the hyperparameters
|
| 84 |
+
to the checkpoint.
|
| 85 |
+
However, if your checkpoint weights don't have the hyperparameters saved,
|
| 86 |
+
use this method to pass in a .yaml file with the hparams you'd like to use.
|
| 87 |
+
These will be converted into a :class:`~dict` and passed into your
|
| 88 |
+
:class:`LightningModule` for use.
|
| 89 |
+
|
| 90 |
+
If your model's ``hparams`` argument is :class:`~argparse.Namespace`
|
| 91 |
+
and .yaml file has hierarchical structure, you need to refactor your model to treat
|
| 92 |
+
``hparams`` as :class:`~dict`.
|
| 93 |
+
strict: Whether to strictly enforce that the keys in :attr:`checkpoint_path` match the keys
|
| 94 |
+
returned by this module's state dict.
|
| 95 |
+
kwargs: Any extra keyword args needed to init the model. Can also be used to override saved
|
| 96 |
+
hyperparameter values.
|
| 97 |
+
|
| 98 |
+
Return:
|
| 99 |
+
:class:`LightningModule` instance with loaded weights and hyperparameters (if available).
|
| 100 |
+
|
| 101 |
+
Note:
|
| 102 |
+
``load_from_checkpoint`` is a **class** method. You should use your :class:`LightningModule`
|
| 103 |
+
**class** to call it instead of the :class:`LightningModule` instance.
|
| 104 |
+
|
| 105 |
+
Example::
|
| 106 |
+
|
| 107 |
+
# load weights without mapping ...
|
| 108 |
+
model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')
|
| 109 |
+
|
| 110 |
+
# or load weights mapping all weights from GPU 1 to GPU 0 ...
|
| 111 |
+
map_location = {'cuda:1':'cuda:0'}
|
| 112 |
+
model = MyLightningModule.load_from_checkpoint(
|
| 113 |
+
'path/to/checkpoint.ckpt',
|
| 114 |
+
map_location=map_location
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
# or load weights and hyperparameters from separate files.
|
| 118 |
+
model = MyLightningModule.load_from_checkpoint(
|
| 119 |
+
'path/to/checkpoint.ckpt',
|
| 120 |
+
hparams_file='/path/to/hparams_file.yaml'
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
# override some of the params with new values
|
| 124 |
+
model = MyLightningModule.load_from_checkpoint(
|
| 125 |
+
PATH,
|
| 126 |
+
num_layers=128,
|
| 127 |
+
pretrained_ckpt_path=NEW_PATH,
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
# predict
|
| 131 |
+
pretrained_model.eval()
|
| 132 |
+
pretrained_model.freeze()
|
| 133 |
+
y_hat = pretrained_model(x)
|
| 134 |
+
"""
|
| 135 |
+
with pl_legacy_patch():
|
| 136 |
+
if map_location is not None:
|
| 137 |
+
checkpoint = pl_load(checkpoint_path, map_location=map_location)
|
| 138 |
+
else:
|
| 139 |
+
checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
|
| 140 |
+
|
| 141 |
+
if hparams_file is not None:
|
| 142 |
+
extension = hparams_file.split(".")[-1]
|
| 143 |
+
if extension.lower() == "csv":
|
| 144 |
+
hparams = load_hparams_from_tags_csv(hparams_file)
|
| 145 |
+
elif extension.lower() in ("yml", "yaml"):
|
| 146 |
+
hparams = load_hparams_from_yaml(hparams_file)
|
| 147 |
+
else:
|
| 148 |
+
raise ValueError(".csv, .yml or .yaml is required for `hparams_file`")
|
| 149 |
+
|
| 150 |
+
hparams["on_gpu"] = False
|
| 151 |
+
|
| 152 |
+
# overwrite hparams by the given file
|
| 153 |
+
checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams
|
| 154 |
+
|
| 155 |
+
# for past checkpoint need to add the new key
|
| 156 |
+
if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint:
|
| 157 |
+
checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {}
|
| 158 |
+
# override the hparams with values that were passed in
|
| 159 |
+
checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
|
| 160 |
+
|
| 161 |
+
model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
|
| 162 |
+
return model
|
| 163 |
+
|
| 164 |
+
@classmethod
|
| 165 |
+
def _load_model_state(cls, checkpoint: Dict[str, Any], strict: bool = True, **cls_kwargs_new):
|
| 166 |
+
cls_spec = inspect.getfullargspec(cls.__init__)
|
| 167 |
+
cls_init_args_name = inspect.signature(cls.__init__).parameters.keys()
|
| 168 |
+
|
| 169 |
+
self_var, args_var, kwargs_var = parse_class_init_keys(cls)
|
| 170 |
+
        drop_names = [n for n in (self_var, args_var, kwargs_var) if n]
        cls_init_args_name = list(filter(lambda n: n not in drop_names, cls_init_args_name))

        cls_kwargs_loaded = {}
        # pass in the values we saved automatically
        if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:

            # 1. (backward compatibility) Try to restore model hparams from checkpoint using old/past keys
            for _old_hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS:
                cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key, {}))

            # 2. Try to restore model hparams from checkpoint using the new key
            _new_hparam_key = cls.CHECKPOINT_HYPER_PARAMS_KEY
            cls_kwargs_loaded.update(checkpoint.get(_new_hparam_key))

            # 3. Ensure that `cls_kwargs_old` has the right type, back compatibility between dict and Namespace
            cls_kwargs_loaded = _convert_loaded_hparams(
                cls_kwargs_loaded, checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE)
            )

            # 4. Update cls_kwargs_new with cls_kwargs_old, such that new has higher priority
            args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
            if args_name and args_name in cls_init_args_name:
                cls_kwargs_loaded = {args_name: cls_kwargs_loaded}

        _cls_kwargs = {}
        _cls_kwargs.update(cls_kwargs_loaded)
        _cls_kwargs.update(cls_kwargs_new)

        if not cls_spec.varkw:
            # filter kwargs according to class init unless it allows any argument via kwargs
            _cls_kwargs = {k: v for k, v in _cls_kwargs.items() if k in cls_init_args_name}

        model = cls(**_cls_kwargs)

        # give model a chance to load something
        model.on_load_checkpoint(checkpoint)

        # load the state_dict on the model automatically
        keys = model.load_state_dict(checkpoint["state_dict"], strict=strict)

        if not strict:
            if keys.missing_keys:
                rank_zero_warn(
                    f"Found keys that are in the model state dict but not in the checkpoint: {keys.missing_keys}"
                )
            if keys.unexpected_keys:
                rank_zero_warn(
                    f"Found keys that are not in the model state dict but in the checkpoint: {keys.unexpected_keys}"
                )

        return model

    # -------------------------
    # OPTIONAL HOOKS
    # -------------------------
    def on_hpc_save(self, checkpoint: Dict[str, Any]) -> None:
        """Hook to do whatever you need right before Slurm manager saves the model.

        Args:
            checkpoint: A dictionary in which you can save variables to save in a checkpoint.
                Contents need to be pickleable.

        .. deprecated:: v1.6
            This method is deprecated in v1.6 and will be removed in v1.8.
            Please use ``LightningModule.on_save_checkpoint`` instead.
        """

    def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None:
        """Hook to do whatever you need right before Slurm manager loads the model.

        Args:
            checkpoint: A dictionary with variables from the checkpoint.

        .. deprecated:: v1.6
            This method is deprecated in v1.6 and will be removed in v1.8.
            Please use ``LightningModule.on_load_checkpoint`` instead.
        """


def _convert_loaded_hparams(model_args: dict, hparams_type: Optional[Union[Callable, str]] = None) -> object:
    """Convert hparams according given type in callable or string (past) format."""
    # if not hparams type define
    if not hparams_type:
        return model_args
    # if past checkpoint loaded, convert str to callable
    if isinstance(hparams_type, str):
        hparams_type = AttributeDict
    # convert hparams
    return hparams_type(model_args)


def update_hparams(hparams: dict, updates: dict) -> None:
    """Overrides hparams with new values.

    >>> hparams = {'c': 4}
    >>> update_hparams(hparams, {'a': {'b': 2}, 'c': 1})
    >>> hparams['a']['b'], hparams['c']
    (2, 1)
    >>> update_hparams(hparams, {'a': {'b': 4}, 'c': 7})
    >>> hparams['a']['b'], hparams['c']
    (4, 7)

    Args:
        hparams: the original params and also target object
        updates: new params to be used as update
    """
    for k, v in updates.items():
        # if missing, add the key
        if k not in hparams:
            hparams[k] = v
            continue

        # recurse if dictionary
        if isinstance(v, dict):
            update_hparams(hparams[k], updates[k])
        else:
            # update the value
            hparams.update({k: v})


def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
    """Load hparams from a file.

    >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
    >>> path_csv = os.path.join('.', 'testing-hparams.csv')
    >>> save_hparams_to_tags_csv(path_csv, hparams)
    >>> hparams_new = load_hparams_from_tags_csv(path_csv)
    >>> vars(hparams) == hparams_new
    True
    >>> os.remove(path_csv)
    """
    fs = get_filesystem(tags_csv)
    if not fs.exists(tags_csv):
        rank_zero_warn(f"Missing Tags: {tags_csv}.", category=RuntimeWarning)
        return {}

    with fs.open(tags_csv, "r", newline="") as fp:
        csv_reader = csv.reader(fp, delimiter=",")
        tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}

    return tags


def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None:
    fs = get_filesystem(tags_csv)
    if not fs.isdir(os.path.dirname(tags_csv)):
        raise RuntimeError(f"Missing folder: {os.path.dirname(tags_csv)}.")

    if isinstance(hparams, Namespace):
        hparams = vars(hparams)

    with fs.open(tags_csv, "w", newline="") as fp:
        fieldnames = ["key", "value"]
        writer = csv.DictWriter(fp, fieldnames=fieldnames)
        writer.writerow({"key": "key", "value": "value"})
        for k, v in hparams.items():
            writer.writerow({"key": k, "value": v})


def load_hparams_from_yaml(config_yaml: str, use_omegaconf: bool = True) -> Dict[str, Any]:
    """Load hparams from a file.

    Args:
        config_yaml: Path to config yaml file
        use_omegaconf: If omegaconf is available and ``use_omegaconf=True``,
            the hparams will be converted to ``DictConfig`` if possible.

    >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
    >>> path_yaml = './testing-hparams.yaml'
    >>> save_hparams_to_yaml(path_yaml, hparams)
    >>> hparams_new = load_hparams_from_yaml(path_yaml)
    >>> vars(hparams) == hparams_new
    True
    >>> os.remove(path_yaml)
    """
    fs = get_filesystem(config_yaml)
    if not fs.exists(config_yaml):
        rank_zero_warn(f"Missing Tags: {config_yaml}.", category=RuntimeWarning)
        return {}

    with fs.open(config_yaml, "r") as fp:
        hparams = yaml.full_load(fp)

    if _OMEGACONF_AVAILABLE:
        if use_omegaconf:
            try:
                return OmegaConf.create(hparams)
            except (UnsupportedValueType, ValidationError):
                pass
    return hparams


def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace], use_omegaconf: bool = True) -> None:
    """
    Args:
        config_yaml: path to new YAML file
        hparams: parameters to be saved
        use_omegaconf: If omegaconf is available and ``use_omegaconf=True``,
            the hparams will be converted to ``DictConfig`` if possible.

    """
    fs = get_filesystem(config_yaml)
    if not fs.isdir(os.path.dirname(config_yaml)):
        raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.")

    # convert Namespace or AD to dict
    if isinstance(hparams, Namespace):
        hparams = vars(hparams)
    elif isinstance(hparams, AttributeDict):
        hparams = dict(hparams)

    # saving with OmegaConf objects
    if _OMEGACONF_AVAILABLE and use_omegaconf:
        # deepcopy: hparams from user shouldn't be resolved
        hparams = deepcopy(hparams)
        hparams = apply_to_collection(hparams, DictConfig, OmegaConf.to_container, resolve=True)
        with fs.open(config_yaml, "w", encoding="utf-8") as fp:
            try:
                OmegaConf.save(hparams, fp)
                return
            except (UnsupportedValueType, ValidationError):
                pass

    if not isinstance(hparams, dict):
        raise TypeError("hparams must be dictionary")

    hparams_allowed = {}
    # drop parameters which contain some strange datatypes as fsspec
    for k, v in hparams.items():
        try:
            v = v.name if isinstance(v, Enum) else v
            yaml.dump(v)
        except TypeError:
            warn(f"Skipping '{k}' parameter because it is not possible to safely dump to YAML.")
            hparams[k] = type(v).__name__
        else:
            hparams_allowed[k] = v

    # saving the standard way
    with fs.open(config_yaml, "w", newline="") as fp:
        yaml.dump(hparams_allowed, fp)


def convert(val: str) -> Union[int, float, bool, str]:
    try:
        return ast.literal_eval(val)
    except (ValueError, SyntaxError) as err:
        log.debug(err)
        return val
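An illustrative usage sketch of the hparams helpers defined above; the file name `hparams.yaml` is hypothetical and `use_omegaconf=False` is chosen so a plain dict is returned:

```python
from argparse import Namespace

from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml, update_hparams

hparams = Namespace(batch_size=32, learning_rate=0.001)
save_hparams_to_yaml("./hparams.yaml", hparams)                        # dump to YAML
loaded = load_hparams_from_yaml("./hparams.yaml", use_omegaconf=False)  # plain dict back
assert loaded == vars(hparams)

# update_hparams merges nested updates in place, adding missing keys
update_hparams(loaded, {"learning_rate": 0.01, "optimizer": {"name": "adam"}})
```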
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/distributed/__init__.py
ADDED
@@ -0,0 +1,14 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.distributed.dist import LightningDistributed  # noqa: F401
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/distributed/dist.py
ADDED
@@ -0,0 +1,47 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

import torch.distributed

from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.distributed import group as _group


class LightningDistributed:
    """
    .. deprecated:: v1.5
        This class is deprecated in v1.5 and will be removed in v1.7.
        The broadcast logic will be moved to the :class:`DDPStrategy` and :class`DDPSpawnStrategy` classes.

    """

    def __init__(self, rank=None, device=None):
        rank_zero_deprecation(
            "LightningDistributed is deprecated in v1.5 and will be removed in v1.7."
            " Broadcast logic is implemented directly in the :class:`Strategy` implementations."
        )
        self.rank = rank
        self.device = device

    def broadcast(self, obj: Any, group=_group.WORLD):
        # always wrap into a list so it can be broadcasted.
        obj = [obj]

        if self.rank != 0:
            obj = [None] * len(obj)

        torch.distributed.broadcast_object_list(obj, 0, group=group or _group.WORLD)

        return obj[0]
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (3.51 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/__pycache__/layer_sync.cpython-38.pyc
ADDED
Binary file (3.19 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/__init__.py
ADDED
@@ -0,0 +1,20 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.plugins.environments.bagua_environment import BaguaEnvironment  # noqa: F401
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment  # noqa: F401
from pytorch_lightning.plugins.environments.kubeflow_environment import KubeflowEnvironment  # noqa: F401
from pytorch_lightning.plugins.environments.lightning_environment import LightningEnvironment  # noqa: F401
from pytorch_lightning.plugins.environments.lsf_environment import LSFEnvironment  # noqa: F401
from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment  # noqa: F401
from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment  # noqa: F401
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (897 Bytes)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/__pycache__/bagua_environment.cpython-38.pyc
ADDED
Binary file (2.78 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/bagua_environment.py
ADDED
@@ -0,0 +1,62 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment

log = logging.getLogger(__name__)


class BaguaEnvironment(ClusterEnvironment):
    """Environment for distributed training with `Bagua <https://tutorials.baguasys.com/>`_"""

    @property
    def creates_processes_externally(self) -> bool:
        return True

    @property
    def main_address(self) -> str:
        return os.environ.get("MASTER_ADDR", "127.0.0.1")

    @property
    def main_port(self) -> int:
        return int(os.environ.get("MASTER_PORT", -1))

    @property
    def service_port(self) -> int:
        return int(os.environ.get("BAGUA_SERVICE_PORT", -1))

    @staticmethod
    def detect() -> bool:
        return "BAGUA_SERVICE_PORT" in os.environ

    def world_size(self) -> int:
        return int(os.environ["WORLD_SIZE"])

    def set_world_size(self, size: int) -> None:
        log.debug("`BaguaEnvironment.set_world_size` was called, but setting world size is not allowed. Ignored.")

    def global_rank(self) -> int:
        return int(os.environ["RANK"])

    def set_global_rank(self, rank: int) -> None:
        log.debug("`BaguaEnvironment.set_global_rank` was called, but setting global rank is not allowed. Ignored.")

    def local_rank(self) -> int:
        return int(os.environ.get("LOCAL_RANK", 0))

    def node_rank(self) -> int:
        return int(os.environ.get("NODE_RANK", 0))
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/cluster_environment.py
ADDED
@@ -0,0 +1,87 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, Type

from pytorch_lightning.utilities import rank_zero_deprecation


class ClusterEnvironment(ABC):
    """Specification of a cluster environment."""

    def __new__(cls, *args: Any, **kwargs: Any) -> "ClusterEnvironment":
        # TODO: remove in 1.7
        _check_for_deprecated_methods(cls)
        return super().__new__(cls)

    @property
    @abstractmethod
    def creates_processes_externally(self) -> bool:
        """Whether the environment creates the subprocesses or not."""

    @property
    @abstractmethod
    def main_address(self) -> str:
        """The main address through which all processes connect and communicate."""

    @property
    @abstractmethod
    def main_port(self) -> int:
        """An open and configured port in the main node through which all processes communicate."""

    @staticmethod
    @abstractmethod
    def detect() -> bool:
        """Detects the environment settings corresponding to this cluster and returns ``True`` if they match."""

    @abstractmethod
    def world_size(self) -> int:
        """The number of processes across all devices and nodes."""

    @abstractmethod
    def set_world_size(self, size: int) -> None:
        pass

    @abstractmethod
    def global_rank(self) -> int:
        """The rank (index) of the currently running process across all nodes and devices."""

    @abstractmethod
    def set_global_rank(self, rank: int) -> None:
        pass

    @abstractmethod
    def local_rank(self) -> int:
        """The rank (index) of the currently running process inside of the current node."""

    @abstractmethod
    def node_rank(self) -> int:
        """The rank (index) of the node on which the current process runs."""

    def teardown(self) -> None:
        """Clean up any state set after execution finishes."""
        pass


def _check_for_deprecated_methods(cls: Type[ClusterEnvironment]) -> None:
    if hasattr(cls, "master_address") and callable(cls.master_address):
        rank_zero_deprecation(
            f"`{cls.__name__}.master_address` has been deprecated in v1.6 and will be removed in v1.7."
            " Implement the property `main_address` instead (do not forget to add the `@property` decorator)."
        )
    if hasattr(cls, "master_port") and callable(cls.master_port):
        rank_zero_deprecation(
            f"`{cls.__name__}.master_port` has been deprecated in v1.6 and will be removed in v1.7."
            " Implement the property `main_port` instead (do not forget to add the `@property` decorator)."
        )
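A minimal sketch of a custom environment built on the abstract interface above; the class and the `MY_LAUNCHER_RANK` variable are hypothetical, standing in for whatever an external launcher exports:

```python
import os

from pytorch_lightning.plugins.environments import ClusterEnvironment


class MyClusterEnvironment(ClusterEnvironment):
    @property
    def creates_processes_externally(self) -> bool:
        return True  # the launcher, not Lightning, starts every process

    @property
    def main_address(self) -> str:
        return os.environ["MASTER_ADDR"]

    @property
    def main_port(self) -> int:
        return int(os.environ["MASTER_PORT"])

    @staticmethod
    def detect() -> bool:
        return "MY_LAUNCHER_RANK" in os.environ  # hypothetical marker variable

    def world_size(self) -> int:
        return int(os.environ["WORLD_SIZE"])

    def set_world_size(self, size: int) -> None:
        pass  # fixed by the launcher

    def global_rank(self) -> int:
        return int(os.environ["MY_LAUNCHER_RANK"])

    def set_global_rank(self, rank: int) -> None:
        pass  # fixed by the launcher

    def local_rank(self) -> int:
        return int(os.environ.get("LOCAL_RANK", 0))

    def node_rank(self) -> int:
        return int(os.environ.get("NODE_RANK", 0))
```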
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/kubeflow_environment.py
ADDED
@@ -0,0 +1,78 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.utilities import rank_zero_deprecation

log = logging.getLogger(__name__)


class KubeflowEnvironment(ClusterEnvironment):
    """Environment for distributed training using the `PyTorchJob`_ operator from `Kubeflow`_

    .. _PyTorchJob: https://www.kubeflow.org/docs/components/training/pytorch/
    .. _Kubeflow: https://www.kubeflow.org
    """

    def __init__(self) -> None:
        super().__init__()
        # TODO: remove in 1.7
        if hasattr(self, "is_using_kubeflow") and callable(self.is_using_kubeflow):
            rank_zero_deprecation(
                f"`{self.__class__.__name__}.is_using_kubeflow` has been deprecated in v1.6 and will be removed in"
                f" v1.7. Implement the static method `detect()` instead (do not forget to add the `@staticmethod`"
                f" decorator)."
            )

    @property
    def creates_processes_externally(self) -> bool:
        return True

    @property
    def main_address(self) -> str:
        return os.environ["MASTER_ADDR"]

    @property
    def main_port(self) -> int:
        return int(os.environ["MASTER_PORT"])

    @staticmethod
    def detect() -> bool:
        """Returns ``True`` if the current process was launched using Kubeflow PyTorchJob."""
        required_env_vars = {"KUBERNETES_PORT", "MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"}
        # torchelastic sets these. Make sure we're not in torchelastic
        excluded_env_vars = {"GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE"}
        env_vars = os.environ.keys()
        return required_env_vars.issubset(env_vars) and excluded_env_vars.isdisjoint(env_vars)

    def world_size(self) -> int:
        return int(os.environ["WORLD_SIZE"])

    def set_world_size(self, size: int) -> None:
        log.debug("KubeflowEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

    def global_rank(self) -> int:
        return int(os.environ["RANK"])

    def set_global_rank(self, rank: int) -> None:
        log.debug("KubeflowEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")

    def local_rank(self) -> int:
        return 0

    def node_rank(self) -> int:
        return self.global_rank()
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/lightning_environment.py
ADDED
@@ -0,0 +1,101 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import socket

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.utilities.rank_zero import rank_zero_only


class LightningEnvironment(ClusterEnvironment):
    """The default environment used by Lightning for a single node or free cluster (not managed).

    There are two modes the Lightning environment can operate with:

    1. The user only launches the main process by :code:`python train.py ...` with no additional environment variables
       set. Lightning will spawn new worker processes for distributed training in the current node.
    2. The user launches all processes manually or with utilities like :code:`torch.distributed.launch`.
       The appropriate environment variables need to be set, and at minimum :code:`LOCAL_RANK`.

    If the main address and port are not provided, the default environment will choose them
    automatically. It is recommended to use this default environment for single-node distributed
    training as it provides a convenient way to launch the training script.
    """

    def __init__(self) -> None:
        super().__init__()
        self._main_port: int = -1
        self._global_rank: int = 0
        self._world_size: int = 1

    @property
    def creates_processes_externally(self) -> bool:
        """Returns whether the cluster creates the processes or not.

        If at least :code:`LOCAL_RANK` is available as environment variable, Lightning assumes the user acts as the
        process launcher/job scheduler and Lightning will not launch new processes.
        """
        return "LOCAL_RANK" in os.environ

    @property
    def main_address(self) -> str:
        return os.environ.get("MASTER_ADDR", "127.0.0.1")

    @property
    def main_port(self) -> int:
        if self._main_port == -1:
            self._main_port = int(os.environ.get("MASTER_PORT", find_free_network_port()))
        return self._main_port

    @staticmethod
    def detect() -> bool:
        return True

    def world_size(self) -> int:
        return self._world_size

    def set_world_size(self, size: int) -> None:
        self._world_size = size

    def global_rank(self) -> int:
        return self._global_rank

    def set_global_rank(self, rank: int) -> None:
        self._global_rank = rank
        rank_zero_only.rank = rank

    def local_rank(self) -> int:
        return int(os.environ.get("LOCAL_RANK", 0))

    def node_rank(self) -> int:
        group_rank = os.environ.get("GROUP_RANK", 0)
        return int(os.environ.get("NODE_RANK", group_rank))

    def teardown(self) -> None:
        if "WORLD_SIZE" in os.environ:
            del os.environ["WORLD_SIZE"]


def find_free_network_port() -> int:
    """Finds a free port on localhost.

    It is useful in single-node training when we don't want to connect to a real main node but have to set the
    `MASTER_PORT` environment variable.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.bind(("", 0))
    port = s.getsockname()[1]
    s.close()
    return port
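A small illustrative sketch of how the pieces above behave on a single machine: `find_free_network_port` binds to port 0 so the OS picks an unused port, and `LightningEnvironment` falls back to localhost and a free port when nothing is exported (printed values will vary):

```python
from pytorch_lightning.plugins.environments.lightning_environment import LightningEnvironment, find_free_network_port

port = find_free_network_port()    # OS-assigned free port on localhost
env = LightningEnvironment()
print(env.main_address)            # "127.0.0.1" unless MASTER_ADDR is set
print(env.main_port)               # MASTER_PORT if set, otherwise a freshly chosen free port
```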
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/lsf_environment.py
ADDED
@@ -0,0 +1,190 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import socket
from typing import Dict, List

from pytorch_lightning import _logger as log
from pytorch_lightning.plugins.environments import ClusterEnvironment
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.cloud_io import get_filesystem


class LSFEnvironment(ClusterEnvironment):
    """An environment for running on clusters managed by the LSF resource manager.

    It is expected that any execution using this ClusterEnvironment was executed
    using the Job Step Manager i.e. ``jsrun``.

    This plugin expects the following environment variables:

    ``LSB_JOBID``
        The LSF assigned job ID

    ``LSB_DJOB_RANKFILE``
        The OpenMPI compatible rank file for the LSF job

    ``JSM_NAMESPACE_LOCAL_RANK``
        The node local rank for the task. This environment variable is set by ``jsrun``

    ``JSM_NAMESPACE_SIZE``
        The world size for the task. This environment variable is set by ``jsrun``

    ``JSM_NAMESPACE_RANK``
        The global rank for the task. This environment variable is set by ``jsrun``
    """

    def __init__(self) -> None:
        super().__init__()
        # TODO: remove in 1.7
        if hasattr(self, "is_using_lsf") and callable(self.is_using_lsf):
            rank_zero_deprecation(
                f"`{self.__class__.__name__}.is_using_lsf` has been deprecated in v1.6 and will be removed in v1.7."
                " Implement the static method `detect()` instead (do not forget to add the `@staticmethod` decorator)."
            )
        self._main_address = self._get_main_address()
        self._main_port = self._get_main_port()
        self._node_rank = self._get_node_rank()
        self._set_init_progress_group_env_vars()

    def _set_init_progress_group_env_vars(self) -> None:
        # set environment variables needed for initializing torch distributed process group
        os.environ["MASTER_ADDR"] = str(self._main_address)
        log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
        os.environ["MASTER_PORT"] = str(self._main_port)
        log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

    @property
    def creates_processes_externally(self) -> bool:
        """LSF creates subprocesses, i.e., PyTorch Lightning does not need to spawn them."""
        return True

    @property
    def main_address(self) -> str:
        """The main address is read from an OpenMPI host rank file in the environment variable
        ``LSB_DJOB_RANKFILE``."""
        return self._main_address

    @property
    def main_port(self) -> int:
        """The main port is calculated from the LSF job ID."""
        return self._main_port

    @staticmethod
    def detect() -> bool:
        """Returns ``True`` if the current process was launched using the ``jsrun`` command."""
        required_env_vars = {"LSB_JOBID", "LSB_DJOB_RANKFILE", "JSM_NAMESPACE_LOCAL_RANK", "JSM_NAMESPACE_SIZE"}
        return required_env_vars.issubset(os.environ.keys())

    def world_size(self) -> int:
        """The world size is read from the environment variable ``JSM_NAMESPACE_SIZE``."""
        world_size = os.environ.get("JSM_NAMESPACE_SIZE")
        if world_size is None:
            raise ValueError(
                "Cannot determine world size. Environment variable `JSM_NAMESPACE_SIZE` not found."
                "Make sure you run your executable with `jsrun`."
            )
        return int(world_size)

    def set_world_size(self, size: int) -> None:
        log.debug("LSFEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

    def global_rank(self) -> int:
        """The world size is read from the environment variable ``JSM_NAMESPACE_RANK``."""
        global_rank = os.environ.get("JSM_NAMESPACE_RANK")
        if global_rank is None:
            raise ValueError(
                "Cannot determine global rank. Environment variable `JSM_NAMESPACE_RANK` not found."
                "Make sure you run your executable with `jsrun`."
            )
        return int(global_rank)

    def set_global_rank(self, rank: int) -> None:
        log.debug("LSFEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")

    def local_rank(self) -> int:
        """The local rank is read from the environment variable `JSM_NAMESPACE_LOCAL_RANK`."""
        local_rank = os.environ.get("JSM_NAMESPACE_LOCAL_RANK")
        if local_rank is None:
            raise ValueError(
                "Cannot determine local rank. Environment variable `JSM_NAMESPACE_LOCAL_RANK` not found."
                "Make sure you run your executable with `jsrun`."
            )
        return int(local_rank)

    def node_rank(self) -> int:
        """The node rank is determined by the position of the current hostname in the OpenMPI host rank file stored
        in ``LSB_DJOB_RANKFILE``."""
        return self._node_rank

    def _get_node_rank(self) -> int:
        """A helper method for getting the node rank.

        The node rank is determined by the position of the current node in the list of hosts used in the job. This is
        calculated by reading all hosts from ``LSB_DJOB_RANKFILE`` and finding this node's hostname in the list.
        """
        hosts = self._read_hosts()
        count: Dict[str, int] = {}
        for host in hosts:
            if host not in count:
                count[host] = len(count)
        return count[socket.gethostname()]

    @staticmethod
    def _read_hosts() -> List[str]:
        """Read compute hosts that are a part of the compute job.

        LSF uses the Job Step Manager (JSM) to manage job steps. Job steps are executed by the JSM from "launch" nodes.
        Each job is assigned a launch node. This launch node will be the first node in the list contained in
        ``LSB_DJOB_RANKFILE``.
        """
        var = "LSB_DJOB_RANKFILE"
        rankfile = os.environ.get(var)
        if rankfile is None:
            raise ValueError("Did not find the environment variable `LSB_DJOB_RANKFILE`")
        if not rankfile:
            raise ValueError("The environment variable `LSB_DJOB_RANKFILE` is empty")

        fs = get_filesystem(rankfile)
        with fs.open(rankfile, "r") as f:
            ret = [line.strip() for line in f]
        # remove the launch node (i.e. the first node in LSB_DJOB_RANKFILE) from the list
        return ret[1:]

    def _get_main_address(self) -> str:
        """A helper for getting the main address.

        The main address is assigned to the first node in the list of nodes used for the job.
        """
        hosts = self._read_hosts()
        return hosts[0]

    @staticmethod
    def _get_main_port() -> int:
        """A helper function for accessing the main port.

        Uses the LSF job ID so all ranks can compute the main port.
        """
        # check for user-specified main port
        if "MASTER_PORT" in os.environ:
            log.debug(f"Using externally specified main port: {os.environ['MASTER_PORT']}")
            return int(os.environ["MASTER_PORT"])
        if "LSB_JOBID" in os.environ:
            port = int(os.environ["LSB_JOBID"])
            # all ports should be in the 10k+ range
            port = port % 1000 + 10000
            log.debug(f"calculated LSF main port: {port}")
            return port
        raise ValueError("Could not find job id in environment variable LSB_JOBID")
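The port arithmetic in `_get_main_port` above maps any LSF job ID into the 10000-10999 range, so every rank derives the same port without any communication. A quick check with an illustrative job ID:

```python
job_id = 2971078              # example LSB_JOBID value (illustrative)
port = job_id % 1000 + 10000  # same formula as LSFEnvironment._get_main_port
assert port == 10078
```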
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/slurm_environment.py
ADDED
@@ -0,0 +1,134 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import re
from typing import Optional

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment

log = logging.getLogger(__name__)


class SLURMEnvironment(ClusterEnvironment):
    """Cluster environment for training on a cluster managed by SLURM.

    Args:
        auto_requeue: Whether automatic job resubmission is enabled or not. How and under which conditions a job gets
            rescheduled gets determined by the owner of this plugin.
    """

    def __init__(self, auto_requeue: bool = True) -> None:
        super().__init__()
        self.auto_requeue = auto_requeue

    @property
    def creates_processes_externally(self) -> bool:
        return True

    @property
    def main_address(self) -> str:
        # figure out the root node addr
        slurm_nodelist = os.environ.get("SLURM_NODELIST")
        if slurm_nodelist:
            root_node = slurm_nodelist.split(" ")[0].split(",")[0]
        else:
            root_node = "127.0.0.1"

        root_node = self.resolve_root_node_address(root_node)
        os.environ["MASTER_ADDR"] = root_node
        log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
        return root_node

    @property
    def main_port(self) -> int:
        # -----------------------
        # SLURM JOB = PORT number
        # -----------------------
        # this way every process knows what port to use
        job_id = os.environ.get("SLURM_JOB_ID")
        if job_id is not None:
            # use the last 4 numbers in the job id as the id
            default_port = job_id[-4:]
            # all ports should be in the 10k+ range
            default_port = int(default_port) + 15000
        else:
            default_port = 12910

        # -----------------------
        # PORT NUMBER = MASTER_PORT
        # -----------------------
        # in case the user passed it in
        if "MASTER_PORT" in os.environ:
            default_port = int(os.environ["MASTER_PORT"])
        else:
            os.environ["MASTER_PORT"] = str(default_port)

        log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
        return default_port

    @staticmethod
    def detect() -> bool:
        """Returns ``True`` if the current process was launched on a SLURM cluster."""
        return "SLURM_NTASKS" in os.environ

    @staticmethod
    def job_name() -> Optional[str]:
        return os.environ.get("SLURM_JOB_NAME")

    @staticmethod
    def job_id() -> Optional[int]:
        # in interactive mode, don't make logs use the same job id
        in_slurm_interactive_mode = SLURMEnvironment.job_name() == "bash"
        if in_slurm_interactive_mode:
            return None

        job_id = os.environ.get("SLURM_JOB_ID")
        if job_id is None:
            return None
        try:
            return int(job_id)
        except ValueError:
            return None

    def world_size(self) -> int:
        return int(os.environ["SLURM_NTASKS"])

    def set_world_size(self, size: int) -> None:
        log.debug("SLURMEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

    def global_rank(self) -> int:
        return int(os.environ["SLURM_PROCID"])

    def set_global_rank(self, rank: int) -> None:
        log.debug("SLURMEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")

    def local_rank(self) -> int:
        return int(os.environ["SLURM_LOCALID"])

    def node_rank(self) -> int:
        return int(os.environ["SLURM_NODEID"])

    def resolve_root_node_address(self, root_node: str) -> str:
        if "[" in root_node:
            name, numbers = root_node.split("[", maxsplit=1)
            number = numbers.split(",", maxsplit=1)[0]
            if "-" in number:
                number = number.split("-")[0]

            number = re.sub("[^0-9]", "", number)
            root_node = name + number

        return root_node
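`resolve_root_node_address` above collapses SLURM's compact nodelist notation to the first concrete hostname; a short illustration with made-up hostnames:

```python
from pytorch_lightning.plugins.environments import SLURMEnvironment

env = SLURMEnvironment()
# a bracketed range resolves to the first host in the range
assert env.resolve_root_node_address("host[5-9]") == "host5"
# plain hostnames pass through unchanged
assert env.resolve_root_node_address("node017") == "node017"
```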
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/torchelastic_environment.py
ADDED
@@ -0,0 +1,88 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os

import torch.distributed

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_9_1
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn

log = logging.getLogger(__name__)


class TorchElasticEnvironment(ClusterEnvironment):
    """Environment for fault-tolerant and elastic training with `torchelastic <https://pytorch.org/elastic/>`_"""

    def __init__(self) -> None:
        super().__init__()
        # TODO: remove in 1.7
        if hasattr(self, "is_using_torchelastic") and callable(self.is_using_torchelastic):
            rank_zero_deprecation(
                f"`{self.__class__.__name__}.is_using_torchelastic` has been deprecated in v1.6 and will be removed in"
                " v1.7. Implement the static method `detect()` instead (do not forget to add the `@staticmethod`"
                " decorator)."
            )

    @property
    def creates_processes_externally(self) -> bool:
        return True

    @property
    def main_address(self) -> str:
        if "MASTER_ADDR" not in os.environ:
            rank_zero_warn("MASTER_ADDR environment variable is not defined. Set as localhost")
            os.environ["MASTER_ADDR"] = "127.0.0.1"
        log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
        return os.environ["MASTER_ADDR"]

    @property
    def main_port(self) -> int:
        if "MASTER_PORT" not in os.environ:
            rank_zero_warn("MASTER_PORT environment variable is not defined. Set as 12910")
            os.environ["MASTER_PORT"] = "12910"
        log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

        return int(os.environ["MASTER_PORT"])

    @staticmethod
    def detect() -> bool:
        """Returns ``True`` if the current process was launched using the torchelastic command."""
        if _TORCH_GREATER_EQUAL_1_9_1:
            # if not available (for example on MacOS), `is_torchelastic_launched` is not defined
            return torch.distributed.is_available() and torch.distributed.is_torchelastic_launched()
        required_env_vars = {"RANK", "GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE"}
        return required_env_vars.issubset(os.environ.keys())

    def world_size(self) -> int:
        return int(os.environ["WORLD_SIZE"])

    def set_world_size(self, size: int) -> None:
        log.debug("TorchElasticEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

    def global_rank(self) -> int:
        return int(os.environ["RANK"])

    def set_global_rank(self, rank: int) -> None:
        log.debug(
            "TorchElasticEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored."
        )

    def local_rank(self) -> int:
        return int(os.environ["LOCAL_RANK"])

    def node_rank(self) -> int:
        return int(os.environ.get("GROUP_RANK", 0))
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/__init__.py
ADDED
@@ -0,0 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO  # noqa: F401
from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO  # noqa: F401
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO  # noqa: F401
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO  # noqa: F401
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/__pycache__/xla_plugin.cpython-38.pyc
ADDED
Binary file (2.38 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/checkpoint_plugin.py
ADDED
@@ -0,0 +1,62 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

from pytorch_lightning.utilities.types import _PATH


class CheckpointIO(ABC):
    """Interface to save/load checkpoints as they are saved through the ``Strategy``.

    Typically most plugins either use the Torch based IO Plugin; ``TorchCheckpointIO`` but may
    require particular handling depending on the plugin.

    In addition, you can pass a custom ``CheckpointIO`` by extending this class and passing it
    to the Trainer, i.e ``Trainer(plugins=[MyCustomCheckpointIO()])``.

    .. note::

        For some plugins, it is not possible to use a custom checkpoint plugin as checkpointing logic is not
        modifiable.
    """

    @abstractmethod
    def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            checkpoint: dict containing model and trainer state
            path: write-target path
            storage_options: Optional parameters when saving the model/training states.
        """

    @abstractmethod
    def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> Dict[str, Any]:
        """Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages.

        Args:
            path: Path to checkpoint
            storage_options: Optional parameters when loading the model/training states.

        Returns: The loaded checkpoint.
        """

    @abstractmethod
    def remove_checkpoint(self, path: _PATH) -> None:
        """Remove checkpoint file from the filesystem.

        Args:
            path: Path to checkpoint
        """
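A minimal sketch (the class name is hypothetical) of a custom plugin implementing the three abstract methods above with plain `torch.save`/`torch.load`; as the docstring notes, it could be passed to the Trainer via `Trainer(plugins=[SimpleTorchCheckpointIO()])`:

```python
import os
from typing import Any, Dict, Optional

import torch

from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO


class SimpleTorchCheckpointIO(CheckpointIO):
    def save_checkpoint(self, checkpoint: Dict[str, Any], path, storage_options: Optional[Any] = None) -> None:
        # make sure the target directory exists, then dump the state dict
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        torch.save(checkpoint, path)

    def load_checkpoint(self, path, storage_options: Optional[Any] = None) -> Dict[str, Any]:
        # load onto CPU; the Strategy moves tensors to the right device afterwards
        return torch.load(path, map_location="cpu")

    def remove_checkpoint(self, path) -> None:
        if os.path.exists(path):
            os.remove(path)
```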
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/hpu_plugin.py
ADDED
@@ -0,0 +1,52 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Any, Dict, Optional

import torch

from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
from pytorch_lightning.utilities.types import _PATH


class HPUCheckpointIO(TorchCheckpointIO):
    """CheckpointIO to save checkpoints for HPU training strategies."""

    def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            checkpoint: dict containing model and trainer state
            path: write-target path
            storage_options: not used in ``HPUCheckpointIO.save_checkpoint``

        Raises:
            TypeError:
                If ``storage_options`` arg is passed in
        """
        if storage_options is not None:
            raise TypeError(
                "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
                f" is not supported for `{self.__class__.__name__}`. Please implement your custom `CheckpointIO`"
                " to define how you'd like to use `storage_options`."
            )
        fs = get_filesystem(path)
        fs.makedirs(os.path.dirname(path), exist_ok=True)

        checkpoint = move_data_to_device(checkpoint, torch.device("cpu"))
        # write the checkpoint dictionary to the provided path
        atomic_save(checkpoint, path)
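The only behavior ``HPUCheckpointIO`` adds on top of ``TorchCheckpointIO`` is moving the checkpoint to CPU before the file write. A standalone sketch of what ``move_data_to_device`` does to a nested dict; the toy tensors and keys are invented, and on an actual HPU run the source tensors would live on an HPU device rather than already being on CPU.

# Illustrative only: move_data_to_device recursively maps tensors in a collection to a device.
import torch
from pytorch_lightning.utilities.apply_func import move_data_to_device

state = {"state_dict": {"layer.bias": torch.zeros(4)}, "epoch": 3}  # toy checkpoint-like dict
cpu_state = move_data_to_device(state, torch.device("cpu"))  # all tensors end up on CPU
print(cpu_state["state_dict"]["layer.bias"].device)  # cpu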
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/torch_plugin.py
ADDED
@@ -0,0 +1,96 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from typing import Any, Callable, Dict, Optional

import pytorch_lightning as pl
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.types import _PATH

log = logging.getLogger(__name__)


class TorchCheckpointIO(CheckpointIO):
    """CheckpointIO that utilizes :func:`torch.save` and :func:`torch.load` to save and load checkpoints
    respectively, common for most use cases."""

    def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            checkpoint: dict containing model and trainer state
            path: write-target path
            storage_options: not used in ``TorchCheckpointIO.save_checkpoint``

        Raises:
            TypeError:
                If ``storage_options`` arg is passed in
        """
        if storage_options is not None:
            raise TypeError(
                "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
                f" is not supported for `{self.__class__.__name__}`. Please implement your custom `CheckpointIO`"
                " to define how you'd like to use `storage_options`."
            )
        fs = get_filesystem(path)
        fs.makedirs(os.path.dirname(path), exist_ok=True)
        try:
            # write the checkpoint dictionary on the file
            atomic_save(checkpoint, path)
        except AttributeError as err:
            # todo (sean): is this try catch necessary still?
            # https://github.com/PyTorchLightning/pytorch-lightning/pull/431
            key = pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY
            checkpoint.pop(key, None)
            rank_zero_warn(f"Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}")
            atomic_save(checkpoint, path)

    def load_checkpoint(
        self, path: _PATH, map_location: Optional[Callable] = lambda storage, loc: storage
    ) -> Dict[str, Any]:
        """Loads checkpoint using :func:`torch.load`, with additional handling for ``fsspec`` remote loading of
        files.

        Args:
            path: Path to checkpoint
            map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
                locations.

        Returns: The loaded checkpoint.

        Raises:
            FileNotFoundError: If ``path`` is not found by the ``fsspec`` filesystem
        """

        # Try to read the checkpoint at `path`. If not exist, do not restore checkpoint.
        fs = get_filesystem(path)
        if not fs.exists(path):
            raise FileNotFoundError(f"Checkpoint at {path} not found. Aborting training.")

        return pl_load(path, map_location=map_location)

    def remove_checkpoint(self, path: _PATH) -> None:
        """Remove checkpoint file from the filesystem.

        Args:
            path: Path to checkpoint
        """
        fs = get_filesystem(path)
        if fs.exists(path):
            fs.rm(path, recursive=True)
            log.debug(f"Removed checkpoint: {path}")
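``TorchCheckpointIO`` can also be exercised directly, outside a ``Trainer``. A minimal round-trip sketch, assuming a writable local directory; the path and the toy checkpoint dict are invented for the example.

import torch
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO

io = TorchCheckpointIO()
ckpt = {"state_dict": {"w": torch.ones(2, 2)}, "epoch": 1}  # toy checkpoint dict

io.save_checkpoint(ckpt, "demo/toy.ckpt")       # atomic torch.save under the hood
loaded = io.load_checkpoint("demo/toy.ckpt")    # torch.load via the fsspec-aware loader
assert loaded["epoch"] == 1
io.remove_checkpoint("demo/toy.ckpt")           # deletes the file if it exists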
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/xla_plugin.py
ADDED
@@ -0,0 +1,57 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Dict, Optional

from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.types import _PATH

if _TPU_AVAILABLE:
    import torch_xla.core.xla_model as xm

if _OMEGACONF_AVAILABLE:
    from omegaconf import DictConfig, ListConfig, OmegaConf


class XLACheckpointIO(TorchCheckpointIO):
    """CheckpointIO that utilizes :func:`xm.save` to save checkpoints for TPU training strategies."""

    def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            checkpoint: dict containing model and trainer state
            path: write-target path
            storage_options: not used in ``XLACheckpointIO.save_checkpoint``

        Raises:
            TypeError:
                If ``storage_options`` arg is passed in
        """
        if storage_options is not None:
            raise TypeError(
                "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
                f" is not supported for `{self.__class__.__name__}`. Please implement your custom `CheckpointIO`"
                " to define how you'd like to use `storage_options`."
            )
        fs = get_filesystem(path)
        fs.makedirs(os.path.dirname(path), exist_ok=True)
        # Todo: TypeError: 'mappingproxy' object does not support item assignment
        # Ref: https://github.com/pytorch/xla/issues/2773
        if _OMEGACONF_AVAILABLE:
            checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
        xm.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, path)
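The OmegaConf conversion step above appears to exist because ``xm.save`` chokes on ``DictConfig``/``ListConfig`` values (see the linked pytorch/xla issue). A small sketch of what that ``apply_to_collection`` call does, assuming ``omegaconf`` is installed; the config contents are invented.

from omegaconf import DictConfig, ListConfig, OmegaConf
from pytorch_lightning.utilities.apply_func import apply_to_collection

ckpt = {"hyper_parameters": OmegaConf.create({"lr": 1e-3}), "epoch": 2}  # toy checkpoint dict
plain = apply_to_collection(ckpt, (DictConfig, ListConfig), OmegaConf.to_container)
print(type(plain["hyper_parameters"]))  # <class 'dict'> -- now plain containers, safe to serialize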
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__init__.py
ADDED
@@ -0,0 +1,27 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin  # noqa: F401
from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin  # noqa: F401
from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin  # noqa: F401
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import (  # noqa: F401
    FullyShardedNativeMixedPrecisionPlugin,
)
from pytorch_lightning.plugins.precision.hpu import HPUPrecisionPlugin  # noqa: F401
from pytorch_lightning.plugins.precision.ipu import IPUPrecisionPlugin  # noqa: F401
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin  # noqa: F401
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin  # noqa: F401
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin  # noqa: F401
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin  # noqa: F401
from pytorch_lightning.plugins.precision.tpu import TPUPrecisionPlugin  # noqa: F401
from pytorch_lightning.plugins.precision.tpu_bf16 import TPUBf16PrecisionPlugin  # noqa: F401
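In practice these precision plugins are rarely imported by user code; the ``Trainer`` selects one from its ``precision`` flag. A minimal sketch, assuming a CUDA-capable machine for the 16-bit case.

from pytorch_lightning import Trainer

# precision=16 makes the Trainer construct a native-AMP mixed-precision plugin for the run;
# the default precision=32 corresponds to the plain PrecisionPlugin.
trainer = Trainer(accelerator="gpu", devices=1, precision=16)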
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/apex_amp.cpython-38.pyc
ADDED
Binary file (3.74 kB).

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/deepspeed.cpython-38.pyc
ADDED
Binary file (3.86 kB).

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/double.cpython-38.pyc
ADDED
Binary file (3.99 kB).

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/fully_sharded_native_amp.cpython-38.pyc
ADDED
Binary file (999 Bytes).

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/mixed.cpython-38.pyc
ADDED
Binary file (719 Bytes).

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/native_amp.cpython-38.pyc
ADDED
Binary file (4.31 kB).